diff --git a/src/db/models.rs b/src/db/models.rs index ad0e3ec..b034775 100644 --- a/src/db/models.rs +++ b/src/db/models.rs @@ -81,15 +81,15 @@ impl Record { } } -pub async fn insert_into_database(rr: RR) -> Result<(), DatabaseError> { +pub async fn insert_into_database(rr: &RR) -> Result<(), DatabaseError> { let db_connection = &mut establish_connection(); let record = Record { name: rr.name.join("."), - _type: rr._type.into(), - class: rr.class.into(), + _type: rr._type.clone().into(), + class: rr.class.clone().into(), ttl: rr.ttl, rdlength: rr.rdlength as i32, - rdata: rr.rdata, + rdata: rr.rdata.clone(), }; Record::create(db_connection, record).map_err(|e| DatabaseError { @@ -128,11 +128,17 @@ pub async fn get_from_database(question: &Question) -> Result, DatabaseE //TODO: cleanup models pub async fn delete_from_database( - name: Vec, + name: &Vec, _type: Option, class: Class, rdata: Option>, ) { let db_connection = &mut establish_connection(); - let _ = Record::delete(db_connection, name.join("."), _type.map(|f| f.into()), class.into(), rdata); + let _ = Record::delete( + db_connection, + name.join("."), + _type.map(|f| f.into()), + class.into(), + rdata, + ); } diff --git a/src/errors.rs b/src/errors.rs index b7faa43..ac01e6b 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -1,8 +1,10 @@ use core::fmt; -#[derive(Debug)] +use crate::structs::RCODE; + pub struct DNSError { pub message: String, + pub rcode: RCODE } impl fmt::Display for DNSError { @@ -25,7 +27,7 @@ impl fmt::Display for ParseError { #[derive(Debug)] pub struct DatabaseError { - pub message: String, + pub message: String } impl fmt::Display for DatabaseError { @@ -67,3 +69,15 @@ where } } } + +impl From for DNSError +where + E: Into, +{ + fn from(value: E) -> Self { + DNSError { + message: value.into().to_string(), + rcode: RCODE::FORMERR + } + } +} diff --git a/src/resolver.rs b/src/resolver.rs index 8349c2d..0b212de 100644 --- a/src/resolver.rs +++ b/src/resolver.rs @@ -6,7 +6,7 @@ use tokio::net::UdpSocket; use crate::authenticate::authenticate; use crate::db::models::{delete_from_database, get_from_database, insert_into_database}; -use crate::errors::ParseError; +use crate::errors::{DNSError, ParseError}; use crate::parser::FromBytes; use crate::reader::Reader; use crate::sig::Sig; @@ -15,7 +15,7 @@ use crate::utils::vec_equal; const MAX_DATAGRAM_SIZE: usize = 4096; -fn set_response_flags(flags: u16, rcode: RCODE) -> u16 { +fn set_response_flags(flags: &u16, rcode: RCODE) -> u16 { (flags | 0b1_0000_1_0_0_0_000_0000 | rcode as u16) & 0b1_1111_1_0_1_0_111_1111 } @@ -23,46 +23,50 @@ fn get_opcode(flags: &u16) -> Result { Opcode::try_from((flags & 0b0111100000000000) >> 11) } -async fn handle_query(message: Message) -> Message { +async fn handle_query(message: &Message) -> Result { let mut response = message.clone(); response.header.arcount = 0; //TODO: fix this, handle unknown class values - for question in message.question { + for question in &message.question { let answers = get_from_database(&question).await; match answers { Ok(rrs) => { - response.header.flags = set_response_flags(response.header.flags, RCODE::NOERROR); response.header.ancount = rrs.len() as u16; - response.answer = rrs + response.answer.extend(rrs) } Err(e) => { - response.header.flags = set_response_flags(response.header.flags, RCODE::NXDOMAIN); - eprintln!("{}", e); + return Err(DNSError { + rcode: RCODE::NXDOMAIN, + message: e.to_string(), + }) } } } - response + Ok(response) } -async fn handle_update(message: Message, bytes: &[u8]) -> Message { - let mut response = message.clone(); - +async fn handle_update(message: &Message, bytes: &[u8]) -> Result { + let response = message.clone(); // Zone section (question) processing if (message.header.qdcount != 1) || !matches!(message.question[0].qtype, Type::Type(RRType::SOA)) { - response.header.flags = set_response_flags(response.header.flags, RCODE::FORMERR); - return response; + return Err(DNSError { + message: "Qdcount not one".to_string(), + rcode: RCODE::FORMERR, + }); } // Check Zone authority let zone = &message.question[0]; let zlen = zone.qname.len(); if !(zlen >= 2 && zone.qname[zlen - 1] == "gent" && zone.qname[zlen - 2] == "zeus") { - response.header.flags = set_response_flags(response.header.flags, RCODE::NOTAUTH); - return response; + return Err(DNSError { + message: "Invalid zone".to_string(), + rcode: RCODE::NOTAUTH, + }); } // Check Prerequisite TODO: implement this @@ -70,15 +74,19 @@ async fn handle_update(message: Message, bytes: &[u8]) -> Message { //TODO: this code is ugly let last = message.additional.last(); if last.is_some() && last.unwrap()._type == Type::Type(RRType::KEY) { - let sig = Sig::new(last.unwrap(), bytes); + let sig = Sig::new(last.unwrap(), bytes)?; if !authenticate(&sig, &zone.qname).await.is_ok_and(|x| x) { - response.header.flags = set_response_flags(response.header.flags, RCODE::NOTAUTH); - return response; + return Err(DNSError { + message: "Unable to verify authentication".to_string(), + rcode: RCODE::NOTAUTH, + }); } } else { - response.header.flags = set_response_flags(response.header.flags, RCODE::NOTAUTH); - return response; + return Err(DNSError { + message: "No KEY record at the end of request found".to_string(), + rcode: RCODE::NOTAUTH, + }); } // Update Section Prescan @@ -87,8 +95,10 @@ async fn handle_update(message: Message, bytes: &[u8]) -> Message { // Check if rr has same zone if rlen < zlen || !(vec_equal(&zone.qname, &rr.name[rlen - zlen..])) { - response.header.flags = set_response_flags(response.header.flags, RCODE::NOTZONE); - return response; + return Err(DNSError { + message: "RR has different zone from Question".to_string(), + rcode: RCODE::NOTZONE, + }); } if (rr.class == Class::Class(RRClass::ANY) && (rr.ttl != 0 || rr.rdlength != 0)) @@ -100,43 +110,50 @@ async fn handle_update(message: Message, bytes: &[u8]) -> Message { ] .contains(&rr.class) { - response.header.flags = set_response_flags(response.header.flags, RCODE::FORMERR); - return response; + return Err(DNSError { + message: "RR has invalid rr,ttl or class".to_string(), + rcode: RCODE::FORMERR, + }); } } - for rr in message.authority { + for rr in &message.authority { if rr.class == zone.qclass { - let _ = insert_into_database(rr).await; + let _ = insert_into_database(&rr).await; } else if rr.class == Class::Class(RRClass::ANY) { if rr._type == Type::Type(RRType::ANY) { if rr.name == zone.qname { - response.header.flags = - set_response_flags(response.header.flags, RCODE::NOTIMP); - return response; + return Err(DNSError { + message: "Not yet implemented".to_string(), + rcode: RCODE::NOTIMP, + }); } else { - delete_from_database(rr.name, None, Class::Class(RRClass::IN), None).await; + delete_from_database(&rr.name, None, Class::Class(RRClass::IN), None).await; } } else { - delete_from_database(rr.name, Some(rr._type), Class::Class(RRClass::IN), None) - .await; + delete_from_database( + &rr.name, + Some(rr._type.clone()), + Class::Class(RRClass::IN), + None, + ) + .await; } } else if rr.class == Class::Class(RRClass::NONE) { if rr._type == Type::Type(RRType::SOA) { continue; } delete_from_database( - rr.name, - Some(rr._type), + &rr.name, + Some(rr._type.clone()), Class::Class(RRClass::IN), - Some(rr.rdata), + Some(rr.rdata.clone()), ) .await; } } - response.header.flags = set_response_flags(response.header.flags, RCODE::NOERROR); - response + Ok(response) } fn handle_parse_error(bytes: &[u8], err: ParseError) -> Message { @@ -155,7 +172,7 @@ fn handle_parse_error(bytes: &[u8], err: ParseError) -> Message { header.ancount = 0; header.nscount = 0; header.arcount = 0; - header.flags = set_response_flags(header.flags, RCODE::FORMERR); + header.flags = set_response_flags(&header.flags, RCODE::FORMERR); Message { header, @@ -169,11 +186,26 @@ fn handle_parse_error(bytes: &[u8], err: ParseError) -> Message { async fn get_response(bytes: &[u8]) -> Message { let mut reader = Reader::new(bytes); match Message::from_bytes(&mut reader) { - Ok(message) => match get_opcode(&message.header.flags) { - Ok(opcode) => match opcode { - Opcode::QUERY => handle_query(message).await, - Opcode::UPDATE => handle_update(message, bytes).await, - }, + Ok(mut message) => match get_opcode(&message.header.flags) { + Ok(opcode) => { + let result = match opcode { + Opcode::QUERY => handle_query(&message).await, + Opcode::UPDATE => handle_update(&message, bytes).await, + }; + + match result { + Ok(mut response) => { + response.header.flags = + set_response_flags(&response.header.flags, RCODE::NOERROR); + response + } + Err(e) => { + eprintln!("{}", e.to_string()); + message.header.flags = set_response_flags(&message.header.flags, e.rcode); + message + } + } + } Err(_) => todo!(), }, Err(err) => handle_parse_error(bytes, err), diff --git a/src/sig.rs b/src/sig.rs index cd18a8f..cb0e679 100644 --- a/src/sig.rs +++ b/src/sig.rs @@ -1,6 +1,7 @@ use base64::prelude::*; use crate::{ + errors::ParseError, parser::FromBytes, reader::Reader, structs::{KeyRData, RR}, @@ -16,20 +17,20 @@ pub enum PublicKey { } impl Sig { - pub fn new(rr: &RR, datagram: &[u8]) -> Sig { + pub fn new(rr: &RR, datagram: &[u8]) -> Result { let mut request = datagram[0..datagram.len() - 11 - rr.rdlength as usize].to_vec(); request[11] -= 1; // Decrease arcount let mut reader = Reader::new(&rr.rdata); - let key_rdata = KeyRData::from_bytes(&mut reader).unwrap(); + let key_rdata = KeyRData::from_bytes(&mut reader)?; let mut raw_data = rr.rdata[0..rr.rdata.len() - key_rdata.signature.len()].to_vec(); raw_data.extend(request); - Sig { + Ok(Sig { raw_data, key_rdata, - } + }) } fn verify_ed25519(&self, key: String) -> bool {