diff --git a/src/handlers/mod.rs b/src/handlers/mod.rs new file mode 100644 index 0000000..548a3ae --- /dev/null +++ b/src/handlers/mod.rs @@ -0,0 +1,30 @@ +use crate::{ + errors::DNSError, + structs::{Message, Opcode, RCODE}, +}; + +use self::{query::QueryHandler, update::UpdateHandler}; + +mod query; +mod update; + +pub trait ResponseHandler { + async fn handle(message: &Message, raw: &[u8]) -> Result; +} + +pub struct Handler {} + +impl ResponseHandler for Handler { + async fn handle(message: &Message, raw: &[u8]) -> Result { + match message.get_opcode() { + Ok(opcode) => match opcode { + Opcode::QUERY => QueryHandler::handle(&message, raw).await, + Opcode::UPDATE => UpdateHandler::handle(&message, raw).await, + }, + Err(e) => Err(DNSError { + message: e.to_string(), + rcode: RCODE::FORMERR, + }), + } + } +} diff --git a/src/handlers/query.rs b/src/handlers/query.rs new file mode 100644 index 0000000..e4f8f70 --- /dev/null +++ b/src/handlers/query.rs @@ -0,0 +1,35 @@ +use crate::{ + db::models::get_from_database, + errors::DNSError, + structs::{Message, RCODE}, +}; + +use super::ResponseHandler; + +pub(super) struct QueryHandler {} + +impl ResponseHandler for QueryHandler { + async fn handle(message: &Message, _raw: &[u8]) -> Result { + let mut response = message.clone(); + response.header.arcount = 0; //TODO: fix this, handle unknown class values + + for question in &message.question { + let answers = get_from_database(&question).await; + + match answers { + Ok(rrs) => { + response.header.ancount = rrs.len() as u16; + response.answer.extend(rrs) + } + Err(e) => { + return Err(DNSError { + rcode: RCODE::NXDOMAIN, + message: e.to_string(), + }) + } + } + } + + Ok(response) + } +} diff --git a/src/authenticate.rs b/src/handlers/update/authenticate.rs similarity index 89% rename from src/authenticate.rs rename to src/handlers/update/authenticate.rs index 49da614..323a901 100644 --- a/src/authenticate.rs +++ b/src/handlers/update/authenticate.rs @@ -2,16 +2,15 @@ use std::env; use reqwest::Error; -use crate::{ - errors::AuthenticationError, - sig::{PublicKey, Sig}, -}; +use crate::errors::AuthenticationError; + +use super::sig::{PublicKey, Sig}; type SSHKeys = Vec; type Result = std::result::Result; -pub async fn authenticate(sig: &Sig, zone: &Vec) -> Result { +pub(super) async fn authenticate(sig: &Sig, zone: &Vec) -> Result { if zone.len() >= 4 { let username = &zone[zone.len() - 4]; // Should match: username.users.zeus.gent let public_keys = get_keys(username).await.map_err(|e| AuthenticationError { diff --git a/src/handlers/update/mod.rs b/src/handlers/update/mod.rs new file mode 100644 index 0000000..41f9707 --- /dev/null +++ b/src/handlers/update/mod.rs @@ -0,0 +1,132 @@ +use crate::{ + db::models::{delete_from_database, insert_into_database}, + errors::DNSError, + structs::{Class, Message, RRClass, RRType, Type, RCODE}, + utils::vec_equal, +}; + +use self::sig::Sig; + +use super::ResponseHandler; + +mod authenticate; +mod sig; + +pub(super) struct UpdateHandler {} + +impl ResponseHandler for UpdateHandler { + async fn handle(message: &Message, raw: &[u8]) -> Result { + let response = message.clone(); + // Zone section (question) processing + if (message.header.qdcount != 1) + || !matches!(message.question[0].qtype, Type::Type(RRType::SOA)) + { + return Err(DNSError { + message: "Qdcount not one".to_string(), + rcode: RCODE::FORMERR, + }); + } + + // Check Zone authority + let zone = &message.question[0]; + let zlen = zone.qname.len(); + if !(zlen >= 2 && zone.qname[zlen - 1] == "gent" && zone.qname[zlen - 2] == "zeus") { + return Err(DNSError { + message: "Invalid zone".to_string(), + rcode: RCODE::NOTAUTH, + }); + } + + // Check Prerequisite TODO: implement this + + //TODO: this code is ugly + let last = message.additional.last(); + if last.is_some() && last.unwrap()._type == Type::Type(RRType::KEY) { + let sig = Sig::new(last.unwrap(), raw)?; + + if !authenticate::authenticate(&sig, &zone.qname) + .await + .is_ok_and(|x| x) + { + return Err(DNSError { + message: "Unable to verify authentication".to_string(), + rcode: RCODE::NOTAUTH, + }); + } + } else { + return Err(DNSError { + message: "No KEY record at the end of request found".to_string(), + rcode: RCODE::NOTAUTH, + }); + } + + // Update Section Prescan + for rr in &message.authority { + let rlen = rr.name.len(); + + // Check if rr has same zone + if rlen < zlen || !(vec_equal(&zone.qname, &rr.name[rlen - zlen..])) { + return Err(DNSError { + message: "RR has different zone from Question".to_string(), + rcode: RCODE::NOTZONE, + }); + } + + match (rr.class == Class::Class(RRClass::ANY) && (rr.ttl != 0 || rr.rdlength != 0)) + || (rr.class == Class::Class(RRClass::NONE) && rr.ttl != 0) + || ![ + Class::Class(RRClass::NONE), + Class::Class(RRClass::ANY), + zone.qclass.clone(), + ] + .contains(&rr.class) + { + true => { + return Err(DNSError { + message: "RR has invalid rr,ttl or class".to_string(), + rcode: RCODE::FORMERR, + }); + } + false => (), + } + } + + for rr in &message.authority { + if rr.class == zone.qclass { + let _ = insert_into_database(&rr).await; + } else if rr.class == Class::Class(RRClass::ANY) { + if rr._type == Type::Type(RRType::ANY) { + if rr.name == zone.qname { + return Err(DNSError { + message: "Not yet implemented".to_string(), + rcode: RCODE::NOTIMP, + }); + } else { + delete_from_database(&rr.name, None, Class::Class(RRClass::IN), None).await; + } + } else { + delete_from_database( + &rr.name, + Some(rr._type.clone()), + Class::Class(RRClass::IN), + None, + ) + .await; + } + } else if rr.class == Class::Class(RRClass::NONE) { + if rr._type == Type::Type(RRType::SOA) { + continue; + } + delete_from_database( + &rr.name, + Some(rr._type.clone()), + Class::Class(RRClass::IN), + Some(rr.rdata.clone()), + ) + .await; + } + } + + Ok(response) + } +} diff --git a/src/sig.rs b/src/handlers/update/sig.rs similarity index 96% rename from src/sig.rs rename to src/handlers/update/sig.rs index cb0e679..f5f19ad 100644 --- a/src/sig.rs +++ b/src/handlers/update/sig.rs @@ -7,12 +7,12 @@ use crate::{ structs::{KeyRData, RR}, }; -pub struct Sig { +pub(super) struct Sig { raw_data: Vec, key_rdata: KeyRData, } -pub enum PublicKey { +pub(super) enum PublicKey { ED25519(String), } diff --git a/src/main.rs b/src/main.rs index ed12551..baac33b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,13 +4,13 @@ use dotenvy::dotenv; use crate::resolver::resolver_listener_loop; -mod authenticate; mod db; mod errors; +mod handlers; +mod message; mod parser; mod reader; mod resolver; -mod sig; mod structs; mod utils; diff --git a/src/message.rs b/src/message.rs new file mode 100644 index 0000000..eb22497 --- /dev/null +++ b/src/message.rs @@ -0,0 +1,12 @@ +use crate::structs::{Message, Opcode, RCODE}; + +impl Message { + pub fn set_response(&mut self, rcode: RCODE) { + self.header.flags = (self.header.flags | 0b1_0000_1_0_0_0_000_0000 | rcode as u16) + & 0b1_1111_1_0_1_0_111_1111 + } + + pub fn get_opcode(&self) -> Result { + Opcode::try_from((self.header.flags & 0b0111100000000000) >> 11) + } +} diff --git a/src/resolver.rs b/src/resolver.rs index 0b212de..d776631 100644 --- a/src/resolver.rs +++ b/src/resolver.rs @@ -4,158 +4,14 @@ use std::sync::Arc; use tokio::net::UdpSocket; -use crate::authenticate::authenticate; -use crate::db::models::{delete_from_database, get_from_database, insert_into_database}; -use crate::errors::{DNSError, ParseError}; +use crate::errors::ParseError; +use crate::handlers::{Handler, ResponseHandler}; use crate::parser::FromBytes; use crate::reader::Reader; -use crate::sig::Sig; -use crate::structs::{Class, Header, Message, Opcode, RRClass, RRType, Type, RCODE}; -use crate::utils::vec_equal; +use crate::structs::{Header, Message, RCODE}; const MAX_DATAGRAM_SIZE: usize = 4096; -fn set_response_flags(flags: &u16, rcode: RCODE) -> u16 { - (flags | 0b1_0000_1_0_0_0_000_0000 | rcode as u16) & 0b1_1111_1_0_1_0_111_1111 -} - -fn get_opcode(flags: &u16) -> Result { - Opcode::try_from((flags & 0b0111100000000000) >> 11) -} - -async fn handle_query(message: &Message) -> Result { - let mut response = message.clone(); - response.header.arcount = 0; //TODO: fix this, handle unknown class values - - for question in &message.question { - let answers = get_from_database(&question).await; - - match answers { - Ok(rrs) => { - response.header.ancount = rrs.len() as u16; - response.answer.extend(rrs) - } - Err(e) => { - return Err(DNSError { - rcode: RCODE::NXDOMAIN, - message: e.to_string(), - }) - } - } - } - - Ok(response) -} - -async fn handle_update(message: &Message, bytes: &[u8]) -> Result { - let response = message.clone(); - // Zone section (question) processing - if (message.header.qdcount != 1) - || !matches!(message.question[0].qtype, Type::Type(RRType::SOA)) - { - return Err(DNSError { - message: "Qdcount not one".to_string(), - rcode: RCODE::FORMERR, - }); - } - - // Check Zone authority - let zone = &message.question[0]; - let zlen = zone.qname.len(); - if !(zlen >= 2 && zone.qname[zlen - 1] == "gent" && zone.qname[zlen - 2] == "zeus") { - return Err(DNSError { - message: "Invalid zone".to_string(), - rcode: RCODE::NOTAUTH, - }); - } - - // Check Prerequisite TODO: implement this - - //TODO: this code is ugly - let last = message.additional.last(); - if last.is_some() && last.unwrap()._type == Type::Type(RRType::KEY) { - let sig = Sig::new(last.unwrap(), bytes)?; - - if !authenticate(&sig, &zone.qname).await.is_ok_and(|x| x) { - return Err(DNSError { - message: "Unable to verify authentication".to_string(), - rcode: RCODE::NOTAUTH, - }); - } - } else { - return Err(DNSError { - message: "No KEY record at the end of request found".to_string(), - rcode: RCODE::NOTAUTH, - }); - } - - // Update Section Prescan - for rr in &message.authority { - let rlen = rr.name.len(); - - // Check if rr has same zone - if rlen < zlen || !(vec_equal(&zone.qname, &rr.name[rlen - zlen..])) { - return Err(DNSError { - message: "RR has different zone from Question".to_string(), - rcode: RCODE::NOTZONE, - }); - } - - if (rr.class == Class::Class(RRClass::ANY) && (rr.ttl != 0 || rr.rdlength != 0)) - || (rr.class == Class::Class(RRClass::NONE) && rr.ttl != 0) - || ![ - Class::Class(RRClass::NONE), - Class::Class(RRClass::ANY), - zone.qclass.clone(), - ] - .contains(&rr.class) - { - return Err(DNSError { - message: "RR has invalid rr,ttl or class".to_string(), - rcode: RCODE::FORMERR, - }); - } - } - - for rr in &message.authority { - if rr.class == zone.qclass { - let _ = insert_into_database(&rr).await; - } else if rr.class == Class::Class(RRClass::ANY) { - if rr._type == Type::Type(RRType::ANY) { - if rr.name == zone.qname { - return Err(DNSError { - message: "Not yet implemented".to_string(), - rcode: RCODE::NOTIMP, - }); - } else { - delete_from_database(&rr.name, None, Class::Class(RRClass::IN), None).await; - } - } else { - delete_from_database( - &rr.name, - Some(rr._type.clone()), - Class::Class(RRClass::IN), - None, - ) - .await; - } - } else if rr.class == Class::Class(RRClass::NONE) { - if rr._type == Type::Type(RRType::SOA) { - continue; - } - delete_from_database( - &rr.name, - Some(rr._type.clone()), - Class::Class(RRClass::IN), - Some(rr.rdata.clone()), - ) - .await; - } - } - - Ok(response) -} - fn handle_parse_error(bytes: &[u8], err: ParseError) -> Message { eprintln!("{}", err); let mut reader = Reader::new(bytes); @@ -172,41 +28,31 @@ fn handle_parse_error(bytes: &[u8], err: ParseError) -> Message { header.ancount = 0; header.nscount = 0; header.arcount = 0; - header.flags = set_response_flags(&header.flags, RCODE::FORMERR); - Message { + let mut message = Message { header, question: vec![], answer: vec![], authority: vec![], additional: vec![], - } + }; + message.set_response(RCODE::FORMERR); + message } async fn get_response(bytes: &[u8]) -> Message { let mut reader = Reader::new(bytes); match Message::from_bytes(&mut reader) { - Ok(mut message) => match get_opcode(&message.header.flags) { - Ok(opcode) => { - let result = match opcode { - Opcode::QUERY => handle_query(&message).await, - Opcode::UPDATE => handle_update(&message, bytes).await, - }; - - match result { - Ok(mut response) => { - response.header.flags = - set_response_flags(&response.header.flags, RCODE::NOERROR); - response - } - Err(e) => { - eprintln!("{}", e.to_string()); - message.header.flags = set_response_flags(&message.header.flags, e.rcode); - message - } - } + Ok(mut message) => match Handler::handle(&message, bytes).await { + Ok(mut response) => { + response.set_response(RCODE::NOERROR); + response + } + Err(e) => { + eprintln!("{}", e.to_string()); + message.set_response(e.rcode); + message } - Err(_) => todo!(), }, Err(err) => handle_parse_error(bytes, err), }