diff --git a/src/db/lib.rs b/src/db/lib.rs index ce68bc1..6239c2a 100644 --- a/src/db/lib.rs +++ b/src/db/lib.rs @@ -2,8 +2,19 @@ use diesel::prelude::*; use crate::config::Config; -pub fn establish_connection() -> PgConnection { +pub fn get_connection() -> PgConnection { let database_url = Config::get().db_uri.clone(); PgConnection::establish(&database_url) .unwrap_or_else(|_| panic!("Error connecting to {}", Config::get().db_uri)) } + +#[cfg(test)] +pub mod tests { + use super::*; + + pub fn get_test_connection() -> PgConnection { + let mut connection = get_connection(); + assert!(connection.begin_test_transaction().is_ok()); + connection + } +} diff --git a/src/db/models.rs b/src/db/models.rs index 523ea10..52ca82c 100644 --- a/src/db/models.rs +++ b/src/db/models.rs @@ -6,8 +6,6 @@ use diesel::prelude::*; use self::schema::records::{self}; -use super::lib::establish_connection; - mod schema { diesel::table! { records (name, _type, class, rdlength, rdata) { @@ -81,8 +79,7 @@ impl Record { } } -pub async fn insert_into_database(rr: &RR) -> Result<(), ZNSError> { - let db_connection = &mut establish_connection(); +pub fn insert_into_database(rr: &RR, connection: &mut PgConnection) -> Result<(), ZNSError> { let record = Record { name: rr.name.join("."), _type: rr._type.clone().into(), @@ -92,21 +89,21 @@ pub async fn insert_into_database(rr: &RR) -> Result<(), ZNSError> { rdata: rr.rdata.clone(), }; - Record::create(db_connection, record).map_err(|e| ZNSError::Database { + Record::create(connection, record).map_err(|e| ZNSError::Database { message: e.to_string(), })?; Ok(()) } -pub async fn get_from_database( +pub fn get_from_database( name: &Vec, _type: Type, class: Class, + connection: &mut PgConnection, ) -> Result, ZNSError> { - let db_connection = &mut establish_connection(); let records = - Record::get(db_connection, name.join("."), _type.into(), class.into()).map_err(|e| { + Record::get(connection, name.join("."), _type.into(), class.into()).map_err(|e| { ZNSError::Database { message: e.to_string(), } @@ -128,18 +125,58 @@ pub async fn get_from_database( } //TODO: cleanup models -pub async fn delete_from_database( +pub fn delete_from_database( name: &Vec, _type: Option, class: Class, rdata: Option>, + connection: &mut PgConnection, ) { - let db_connection = &mut establish_connection(); let _ = Record::delete( - db_connection, + connection, name.join("."), _type.map(|f| f.into()), class.into(), rdata, ); } + +#[cfg(test)] +mod tests { + use super::*; + + use crate::{db::lib::tests::get_test_connection, parser::tests::get_rr}; + + #[test] + fn test() { + let mut connection = get_test_connection(); + + let rr = get_rr(); + + let f = |connection: &mut PgConnection| { + get_from_database(&rr.name, rr._type.clone(), rr.class.clone(), connection) + }; + + assert!(f(&mut connection).unwrap().is_empty()); + + assert!(insert_into_database(&rr, &mut connection).is_ok()); + + let result = f(&mut connection); + assert!(result.is_ok()); + assert_eq!(result.as_ref().unwrap().len(), 1); + assert_eq!(result.unwrap()[0], rr); + + delete_from_database( + &rr.name, + Some(rr._type.clone()), + rr.class.clone(), + Some(rr.rdata.clone()), + &mut connection, + ); + + assert!(f(&mut connection).unwrap().is_empty()); + + assert!(insert_into_database(&rr, &mut connection).is_ok()); + assert!(insert_into_database(&rr, &mut connection).is_err()); + } +} diff --git a/src/handlers/mod.rs b/src/handlers/mod.rs index 6de0873..9cc1243 100644 --- a/src/handlers/mod.rs +++ b/src/handlers/mod.rs @@ -1,3 +1,5 @@ +use diesel::PgConnection; + use crate::{ errors::ZNSError, structs::{Message, Opcode}, @@ -9,17 +11,26 @@ mod query; mod update; pub trait ResponseHandler { - async fn handle(message: &Message, raw: &[u8]) -> Result; + async fn handle( + message: &Message, + raw: &[u8], + connection: &mut PgConnection, + ) -> Result; } pub struct Handler {} impl ResponseHandler for Handler { - async fn handle(message: &Message, raw: &[u8]) -> Result { + async fn handle( + message: &Message, + raw: &[u8], + connection: &mut PgConnection, + ) -> Result { match message.get_opcode() { + //TODO: implement this in Opcode Ok(opcode) => match opcode { - Opcode::QUERY => QueryHandler::handle(&message, raw).await, - Opcode::UPDATE => UpdateHandler::handle(&message, raw).await, + Opcode::QUERY => QueryHandler::handle(&message, raw, connection).await, + Opcode::UPDATE => UpdateHandler::handle(&message, raw, connection).await, }, Err(e) => Err(ZNSError::Formerr { message: e.to_string(), diff --git a/src/handlers/query.rs b/src/handlers/query.rs index 680df0a..d653e99 100644 --- a/src/handlers/query.rs +++ b/src/handlers/query.rs @@ -1,3 +1,5 @@ +use diesel::PgConnection; + use crate::{db::models::get_from_database, errors::ZNSError, structs::Message}; use super::ResponseHandler; @@ -5,7 +7,11 @@ use super::ResponseHandler; pub struct QueryHandler {} impl ResponseHandler for QueryHandler { - async fn handle(message: &Message, _raw: &[u8]) -> Result { + async fn handle( + message: &Message, + _raw: &[u8], + connection: &mut PgConnection, + ) -> Result { let mut response = message.clone(); response.header.arcount = 0; //TODO: fix this, handle unknown class values @@ -14,8 +20,8 @@ impl ResponseHandler for QueryHandler { &question.qname, question.qtype.clone(), question.qclass.clone(), - ) - .await; + connection, + ); match answers { Ok(rrs) => { @@ -24,7 +30,7 @@ impl ResponseHandler for QueryHandler { domain: question.qname.join("."), }); } - response.header.ancount = rrs.len() as u16; + response.header.ancount += rrs.len() as u16; response.answer.extend(rrs) } Err(e) => { @@ -38,3 +44,38 @@ impl ResponseHandler for QueryHandler { Ok(response) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::parser::tests::get_message; + use crate::structs::*; + + use crate::{ + db::{lib::tests::get_test_connection, models::insert_into_database}, + parser::{tests::get_rr, ToBytes}, + }; + + #[tokio::test] + async fn test_handle_query() { + let mut connection = get_test_connection(); + let rr = get_rr(); + let mut message = get_message(); + message.header.ancount = 0; + message.answer = vec![]; + + assert!(insert_into_database(&rr, &mut connection).is_ok()); + + let result = QueryHandler::handle( + &message, + &Message::to_bytes(message.clone()), + &mut connection, + ) + .await + .unwrap(); + assert_eq!(result.header.ancount, 2); + assert_eq!(result.answer.len(), 2); + assert_eq!(result.answer[0], rr); + assert_eq!(result.answer[1], rr); + } +} diff --git a/src/handlers/update/authenticate.rs b/src/handlers/update/authenticate.rs index 6c0242d..8715570 100644 --- a/src/handlers/update/authenticate.rs +++ b/src/handlers/update/authenticate.rs @@ -1,3 +1,5 @@ +use diesel::PgConnection; + use crate::{ config::Config, db::models::get_from_database, @@ -9,7 +11,11 @@ use crate::{ use super::{dnskey::DNSKeyRData, sig::Sig}; -pub async fn authenticate(sig: &Sig, zone: &Vec) -> Result { +pub async fn authenticate( + sig: &Sig, + zone: &Vec, + connection: &mut PgConnection, +) -> Result { if zone.len() >= 4 { let username = &zone[zone.len() - 4]; // Should match: username.users.zeus.gent @@ -18,7 +24,7 @@ pub async fn authenticate(sig: &Sig, zone: &Vec) -> Result Result, sig: &Sig) -> Result { - Ok( - get_from_database(zone, Type::Type(RRType::DNSKEY), Class::Class(RRClass::IN)) - .await? - .iter() - .any(|rr| { - let mut reader = Reader::new(&rr.rdata); - DNSKeyRData::from_bytes(&mut reader) - .is_ok_and(|dnskey| sig.verify_dnskey(dnskey).is_ok_and(|b| b)) - }), - ) +async fn validate_dnskey( + zone: &Vec, + sig: &Sig, + connection: &mut PgConnection, +) -> Result { + Ok(get_from_database( + zone, + Type::Type(RRType::DNSKEY), + Class::Class(RRClass::IN), + connection, + )? + .iter() + .any(|rr| { + let mut reader = Reader::new(&rr.rdata); + DNSKeyRData::from_bytes(&mut reader) + .is_ok_and(|dnskey| sig.verify_dnskey(dnskey).is_ok_and(|b| b)) + })) } diff --git a/src/handlers/update/mod.rs b/src/handlers/update/mod.rs index c1cac91..26c26f7 100644 --- a/src/handlers/update/mod.rs +++ b/src/handlers/update/mod.rs @@ -1,3 +1,5 @@ +use diesel::PgConnection; + use crate::{ db::models::{delete_from_database, insert_into_database}, errors::ZNSError, @@ -17,7 +19,11 @@ mod sig; pub struct UpdateHandler {} impl ResponseHandler for UpdateHandler { - async fn handle(message: &Message, raw: &[u8]) -> Result { + async fn handle( + message: &Message, + raw: &[u8], + connection: &mut PgConnection, + ) -> Result { let response = message.clone(); // Zone section (question) processing if (message.header.qdcount != 1) @@ -44,7 +50,7 @@ impl ResponseHandler for UpdateHandler { if last.is_some() && last.unwrap()._type == Type::Type(RRType::KEY) { let sig = Sig::new(last.unwrap(), raw)?; - if !authenticate::authenticate(&sig, &zone.qname) + if !authenticate::authenticate(&sig, &zone.qname, connection) .await .is_ok_and(|x| x) { @@ -89,7 +95,7 @@ impl ResponseHandler for UpdateHandler { for rr in &message.authority { if rr.class == zone.qclass { - let _ = insert_into_database(&rr).await; + let _ = insert_into_database(&rr, connection); } else if rr.class == Class::Class(RRClass::ANY) { if rr._type == Type::Type(RRType::ANY) { if rr.name == zone.qname { @@ -98,7 +104,13 @@ impl ResponseHandler for UpdateHandler { message: "rr.name == zone.qname".to_string(), }); } else { - delete_from_database(&rr.name, None, Class::Class(RRClass::IN), None).await; + delete_from_database( + &rr.name, + None, + Class::Class(RRClass::IN), + None, + connection, + ) } } else { delete_from_database( @@ -106,8 +118,8 @@ impl ResponseHandler for UpdateHandler { Some(rr._type.clone()), Class::Class(RRClass::IN), None, + connection, ) - .await; } } else if rr.class == Class::Class(RRClass::NONE) { if rr._type == Type::Type(RRType::SOA) { @@ -118,8 +130,8 @@ impl ResponseHandler for UpdateHandler { Some(rr._type.clone()), Class::Class(RRClass::IN), Some(rr.rdata.clone()), + connection, ) - .await; } } diff --git a/src/handlers/update/pubkeys/rsa.rs b/src/handlers/update/pubkeys/rsa.rs index 5ec5809..06f7966 100644 --- a/src/handlers/update/pubkeys/rsa.rs +++ b/src/handlers/update/pubkeys/rsa.rs @@ -29,6 +29,18 @@ impl PublicKey for RsaPublicKey { Ok(RsaPublicKey { e, n }) } + fn from_dnskey(key: &[u8]) -> Result + where + Self: Sized, + { + let mut reader = Reader::new(key); + let e_len = reader.read_u8()?; + let e = reader.read(e_len as usize)?; + let mut n = reader.read(reader.unread_bytes())?; + n.insert(0, 0); + Ok(RsaPublicKey { e, n }) + } + fn verify( &self, data: &[u8], @@ -47,7 +59,7 @@ impl PublicKey for RsaPublicKey { Algorithm::RSASHA512 => Ok(&signature::RSA_PKCS1_2048_8192_SHA512), Algorithm::RSASHA256 => Ok(&signature::RSA_PKCS1_2048_8192_SHA256), _ => Err(ZNSError::PublicKey { - message: format!("RsaPublicKey: invalid verify algorithm",), + message: String::from("RsaPublicKey: invalid verify algorithm"), }), }?; @@ -55,16 +67,4 @@ impl PublicKey for RsaPublicKey { Ok(pkey.verify(data, signature).is_ok()) } - - fn from_dnskey(key: &[u8]) -> Result - where - Self: Sized, - { - let mut reader = Reader::new(key); - let e_len = reader.read_u8()?; - let e = reader.read(e_len as usize)?; - let mut n = reader.read(reader.unread_bytes())?; - n.insert(0, 0); - Ok(RsaPublicKey { e, n }) - } } diff --git a/src/message.rs b/src/message.rs index eb22497..ca952dc 100644 --- a/src/message.rs +++ b/src/message.rs @@ -9,4 +9,43 @@ impl Message { pub fn get_opcode(&self) -> Result { Opcode::try_from((self.header.flags & 0b0111100000000000) >> 11) } + + #[allow(dead_code)] // Used with tests + pub fn get_rcode(&self) -> Result { + RCODE::try_from(self.header.flags & (!0 >> 12)) + } +} + +#[cfg(test)] +mod tests { + + use crate::structs::Header; + + use super::*; + + #[test] + fn test() { + let mut message = Message { + header: Header { + id: 1, + flags: 288, + qdcount: 0, + ancount: 0, + nscount: 0, + arcount: 0, + }, + question: vec![], + answer: vec![], + authority: vec![], + additional: vec![], + }; + + assert_eq!(message.get_opcode().unwrap() as u8, Opcode::QUERY as u8); + + message.set_response(RCODE::NOTIMP); + + assert!((message.header.flags & (1 << 15)) > 0); + + assert_eq!(message.get_rcode().unwrap(), RCODE::NOTIMP); + } } diff --git a/src/parser.rs b/src/parser.rs index 11f8e8a..cd58b80 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -295,3 +295,151 @@ impl ToBytes for Message { result } } + +#[cfg(test)] +pub mod tests { + use super::*; + + pub fn get_rr() -> RR { + RR { + name: vec![String::from("example"), String::from("org")], + _type: Type::Type(RRType::A), + class: Class::Class(RRClass::IN), + ttl: 10, + rdlength: 4, + rdata: vec![1, 2, 3, 4], + } + } + + pub fn get_message() -> Message { + Message { + header: Header { + id: 1, + flags: 288, + qdcount: 2, + ancount: 1, + nscount: 1, + arcount: 1, + }, + question: vec![ + Question { + qname: vec![String::from("example"), String::from("org")], + qtype: Type::Type(RRType::A), + qclass: Class::Class(RRClass::IN), + }, + Question { + qname: vec![String::from("example"), String::from("org")], + qtype: Type::Type(RRType::A), + qclass: Class::Class(RRClass::IN), + }, + ], + answer: vec![get_rr()], + authority: vec![get_rr()], + additional: vec![get_rr()], + } + } + + #[test] + fn test_parse_header() { + let header = Header { + id: 1, + flags: 288, + qdcount: 1, + ancount: 0, + nscount: 0, + arcount: 0, + }; + + let bytes = Header::to_bytes(header.clone()); + let parsed = Header::from_bytes(&mut Reader::new(&bytes)); + assert!(parsed.is_ok()); + assert_eq!(parsed.unwrap(), header); + } + + #[test] + fn test_parse_question() { + let question = Question { + qname: vec![String::from("example"), String::from("org")], + qtype: Type::Type(RRType::A), + qclass: Class::Class(RRClass::IN), + }; + + let bytes = Question::to_bytes(question.clone()); + let parsed = Question::from_bytes(&mut Reader::new(&bytes)); + assert!(parsed.is_ok()); + assert_eq!(parsed.unwrap(), question); + } + + #[test] + fn test_parse_rr() { + let rr = get_rr(); + + let bytes = RR::to_bytes(rr.clone()); + let parsed = RR::from_bytes(&mut Reader::new(&bytes)); + assert!(parsed.is_ok()); + assert_eq!(parsed.unwrap(), rr); + } + + #[test] + fn test_labelstring() { + let labelstring = vec![String::from("example"), String::from("org")]; + + let bytes = LabelString::to_bytes(labelstring.clone()); + let parsed = LabelString::from_bytes(&mut Reader::new(&bytes)); + assert!(parsed.is_ok()); + assert_eq!(parsed.unwrap(), labelstring); + } + + #[test] + fn test_labelstring_ptr() { + let labelstring = vec![String::from("example"), String::from("org")]; + + let mut bytes = LabelString::to_bytes(labelstring.clone()); + + bytes.insert(0, 0); + bytes.insert(0, 0); + + let to_read = bytes.len(); + + bytes.push(0b11000000); + bytes.push(0b00000010); + + let mut reader = Reader::new(&bytes); + let _ = reader.read(to_read); + + let parsed = LabelString::from_bytes(&mut reader); + assert!(parsed.is_ok()); + assert_eq!(parsed.unwrap(), labelstring); + } + + #[test] + fn test_labelstring_invalid_ptr() { + let labelstring = vec![String::from("example"), String::from("org")]; + + let mut bytes = LabelString::to_bytes(labelstring.clone()); + + bytes.insert(0, 0); + bytes.insert(0, 0); + + let to_read = bytes.len(); + + bytes.push(0b11000000); + // Not allowed to point to itself or in the future + bytes.push(to_read as u8); + + let mut reader = Reader::new(&bytes); + let _ = reader.read(to_read); + + let parsed = LabelString::from_bytes(&mut reader); + assert!(parsed.is_err()); + } + + #[test] + fn test_parse_message() { + let message = get_message(); + let bytes = Message::to_bytes(message.clone()); + let parsed = Message::from_bytes(&mut Reader::new(&bytes)); + assert!(parsed.is_ok()); + assert_eq!(parsed.unwrap(), message); + } +} diff --git a/src/reader.rs b/src/reader.rs index 82efbca..fb831e0 100644 --- a/src/reader.rs +++ b/src/reader.rs @@ -1,5 +1,3 @@ -use std::array::TryFromSliceError; - use crate::errors::ZNSError; pub struct Reader<'a> { @@ -33,43 +31,37 @@ impl<'a> Reader<'a> { } pub fn read_u8(&mut self) -> Result { - self.position += 1; - Ok(self.buffer[self.position - 1]) + if self.unread_bytes() == 0 { + Err(ZNSError::Reader { + message: String::from("cannot read u8"), + }) + } else { + self.position += 1; + Ok(self.buffer[self.position - 1]) + } } pub fn read_u16(&mut self) -> Result { - let result = u16::from_be_bytes( - self.buffer[self.position..self.position + 2] - .try_into() - .map_err(|e: TryFromSliceError| ZNSError::Reader { - message: e.to_string(), - })?, - ); - self.position += 2; + let result = + u16::from_be_bytes(self.read(2)?.try_into().map_err(|_| ZNSError::Reader { + message: String::from("invalid read_u16"), + })?); Ok(result) } pub fn read_i32(&mut self) -> Result { - let result = i32::from_be_bytes( - self.buffer[self.position..self.position + 4] - .try_into() - .map_err(|e: TryFromSliceError| ZNSError::Reader { - message: e.to_string(), - })?, - ); - self.position += 4; + let result = + i32::from_be_bytes(self.read(4)?.try_into().map_err(|_| ZNSError::Reader { + message: String::from("invalid read_u32"), + })?); Ok(result) } pub fn read_u32(&mut self) -> Result { - let result = u32::from_be_bytes( - self.buffer[self.position..self.position + 4] - .try_into() - .map_err(|e: TryFromSliceError| ZNSError::Reader { - message: e.to_string(), - })?, - ); - self.position += 4; + let result = + u32::from_be_bytes(self.read(4)?.try_into().map_err(|_| ZNSError::Reader { + message: String::from("invalid read_u32"), + })?); Ok(result) } @@ -83,3 +75,58 @@ impl<'a> Reader<'a> { } } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test() { + let fake_bytes = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]; + let mut reader = Reader::new(&fake_bytes); + + assert_eq!(reader.unread_bytes(), 11); + + let u16 = reader.read_u16(); + assert!(u16.is_ok()); + assert_eq!(u16.unwrap(), 1); + + assert_eq!(reader.unread_bytes(), 9); + + let u8 = reader.read_u8(); + assert!(u8.is_ok()); + assert_eq!(u8.unwrap(), 2); + assert_eq!(reader.unread_bytes(), 8); + + let u32 = reader.read_u32(); + assert!(u32.is_ok()); + assert_eq!( + u32.unwrap(), + u32::from_be_bytes(fake_bytes[3..7].try_into().unwrap()) + ); + assert_eq!(reader.unread_bytes(), 4); + + let read = reader.read(3); + assert!(read.is_ok()); + assert_eq!(read.unwrap(), fake_bytes[7..10]); + assert_eq!(reader.unread_bytes(), 1); + + let too_much = reader.read(2); + assert!(too_much.is_err()); + assert_eq!(reader.unread_bytes(), 1); + + assert!(reader.read_u8().is_ok()); + + assert!(reader.read_u8().is_err()); + assert!(reader.read_u16().is_err()); + assert!(reader.read_u32().is_err()); + assert!(reader.read_i32().is_err()); + + let new_reader = reader.seek(1); + assert!(new_reader.is_ok()); + assert_eq!(new_reader.unwrap().unread_bytes(), 10); + + let new_reader = reader.seek(100); + assert!(new_reader.is_err()); + } +} diff --git a/src/resolver.rs b/src/resolver.rs index 6447e7a..fe26d75 100644 --- a/src/resolver.rs +++ b/src/resolver.rs @@ -4,6 +4,7 @@ use std::sync::Arc; use tokio::net::UdpSocket; +use crate::db::lib::get_connection; use crate::errors::ZNSError; use crate::handlers::{Handler, ResponseHandler}; use crate::parser::{FromBytes, ToBytes}; @@ -43,12 +44,13 @@ fn handle_parse_error(bytes: &[u8], err: ZNSError) -> Message { async fn get_response(bytes: &[u8]) -> Message { let mut reader = Reader::new(bytes); match Message::from_bytes(&mut reader) { - Ok(mut message) => match Handler::handle(&message, bytes).await { + Ok(mut message) => match Handler::handle(&message, bytes, &mut get_connection()).await { Ok(mut response) => { response.set_response(RCODE::NOERROR); response } Err(e) => { + println!("{:#?}", message); eprintln!("{}", e.to_string()); message.set_response(e.rcode()); message @@ -72,3 +74,36 @@ pub async fn resolver_listener_loop(addr: SocketAddr) -> Result<(), Box, @@ -78,7 +79,7 @@ pub struct Message { pub additional: Vec, } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq)] pub struct RR { pub name: LabelString, pub _type: Type, @@ -88,11 +89,4 @@ pub struct RR { pub rdata: Vec, } -#[derive(Debug, Clone)] -pub struct OptRR { - pub code: u16, - pub length: u16, - pub rdata: Vec, -} - pub type LabelString = Vec;