diff --git a/src/db/models.rs b/src/db/models.rs index 52ca82c..f71de9e 100644 --- a/src/db/models.rs +++ b/src/db/models.rs @@ -35,16 +35,18 @@ impl Record { pub fn get( db: &mut PgConnection, name: String, - _type: i32, + _type: Option, class: i32, ) -> Result, diesel::result::Error> { - records::table - .filter( - records::name - .eq(name) - .and(records::_type.eq(_type).and(records::class.eq(class))), - ) - .get_results(db) + let mut query = records::table.into_boxed(); + + query = query.filter(records::name.eq(name).and(records::class.eq(class))); + + if let Some(value) = _type { + query = query.filter(records::_type.eq(value)) + } + + query.get_results(db) } pub fn create( @@ -98,16 +100,19 @@ pub fn insert_into_database(rr: &RR, connection: &mut PgConnection) -> Result<() pub fn get_from_database( name: &Vec, - _type: Type, + _type: Option, class: Class, connection: &mut PgConnection, ) -> Result, ZNSError> { - let records = - Record::get(connection, name.join("."), _type.into(), class.into()).map_err(|e| { - ZNSError::Database { - message: e.to_string(), - } - })?; + let records = Record::get( + connection, + name.join("."), + _type.map(|t| t.into()), + class.into(), + ) + .map_err(|e| ZNSError::Database { + message: e.to_string(), + })?; Ok(records .into_iter() @@ -154,7 +159,12 @@ mod tests { let rr = get_rr(); let f = |connection: &mut PgConnection| { - get_from_database(&rr.name, rr._type.clone(), rr.class.clone(), connection) + get_from_database( + &rr.name, + Some(rr._type.clone()), + rr.class.clone(), + connection, + ) }; assert!(f(&mut connection).unwrap().is_empty()); diff --git a/src/handlers/query.rs b/src/handlers/query.rs index d653e99..593f41d 100644 --- a/src/handlers/query.rs +++ b/src/handlers/query.rs @@ -1,11 +1,16 @@ use diesel::PgConnection; -use crate::{db::models::get_from_database, errors::ZNSError, structs::Message}; +use crate::{ + db::models::get_from_database, + errors::ZNSError, + structs::{Message, Question, RR}, +}; use super::ResponseHandler; pub struct QueryHandler {} +//TODO: the clones in this file should and could be avoided impl ResponseHandler for QueryHandler { async fn handle( message: &Message, @@ -18,17 +23,20 @@ impl ResponseHandler for QueryHandler { for question in &message.question { let answers = get_from_database( &question.qname, - question.qtype.clone(), + Some(question.qtype.clone()), question.qclass.clone(), connection, ); match answers { - Ok(rrs) => { + Ok(mut rrs) => { if rrs.len() == 0 { - return Err(ZNSError::NXDomain { - domain: question.qname.join("."), - }); + rrs.extend(try_wildcard(question, connection)?); + if rrs.len() == 0 { + return Err(ZNSError::NXDomain { + domain: question.qname.join("."), + }); + } } response.header.ancount += rrs.len() as u16; response.answer.extend(rrs) @@ -45,6 +53,23 @@ impl ResponseHandler for QueryHandler { } } +fn try_wildcard(question: &Question, connection: &mut PgConnection) -> Result, ZNSError> { + let records = get_from_database(&question.qname, None, question.qclass.clone(), connection)?; + + if records.len() > 0 { + Ok(vec![]) + } else { + let mut qname = question.qname.clone(); + qname[0] = String::from("*"); + get_from_database( + &qname, + Some(question.qtype.clone()), + question.qclass.clone(), + connection, + ) + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/handlers/update/authenticate.rs b/src/handlers/update/authenticate.rs index 9bfe8e3..0a9d0c1 100644 --- a/src/handlers/update/authenticate.rs +++ b/src/handlers/update/authenticate.rs @@ -58,7 +58,7 @@ async fn validate_dnskey( ) -> Result { Ok(get_from_database( zone, - Type::Type(RRType::DNSKEY), + Some(Type::Type(RRType::DNSKEY)), Class::Class(RRClass::IN), connection, )? diff --git a/src/reader.rs b/src/reader.rs index 749c82b..8f708ef 100644 --- a/src/reader.rs +++ b/src/reader.rs @@ -66,12 +66,12 @@ impl<'a> Reader<'a> { } pub fn seek(&self, position: usize) -> Result { - if position >= self.position { + if position >= self.position - 2 { Err(ZNSError::Reader { message: String::from("Seeking into the future is not allowed!!"), }) } else { - let mut reader = Reader::new(&self.buffer[0..self.position]); + let mut reader = Reader::new(&self.buffer[..self.position]); reader.position = position; Ok(reader) }