diff --git a/src/api.rs b/src/api.rs index 6120ba7..b6218ba 100644 --- a/src/api.rs +++ b/src/api.rs @@ -11,7 +11,7 @@ use serde::Deserialize; use tokio::net::TcpListener; use crate::db::models::insert_into_database; -use crate::structs::{Class, Type, RR}; +use crate::structs::{Class, RRClass, Type, RR}; type GenericError = Box; type Result = std::result::Result; @@ -41,7 +41,7 @@ async fn create_record(req: Request) -> Result Result<(), DatabaseError> { let db_connection = &mut establish_connection(); let record = Record { name: rr.name.join("."), - _type: rr._type as i32, - class: rr.class as i32, + _type: rr._type.into(), + class: rr.class.into(), ttl: rr.ttl, rdlength: rr.rdlength as i32, rdata: rr.rdata, @@ -104,8 +104,8 @@ pub async fn get_from_database(question: &Question) -> Result, DatabaseE let records = Record::get( db_connection, question.qname.join("."), - question.qtype.clone() as i32, - question.qclass.clone() as i32, + question.qtype.clone().into(), + question.qclass.clone().into(), ) .map_err(|e| DatabaseError { message: e.to_string(), @@ -116,12 +116,8 @@ pub async fn get_from_database(question: &Question) -> Result, DatabaseE .filter_map(|record| { Some(RR { name: record.name.split(".").map(str::to_string).collect(), - _type: Type::try_from(record._type as u16) - .map_err(|e| DatabaseError { message: e }) - .ok()?, - class: Class::try_from(record.class as u16) - .map_err(|e| DatabaseError { message: e }) - .ok()?, + _type: Type::from(record._type as u16), + class: Class::from(record.class as u16), ttl: record.ttl, rdlength: record.rdlength as u16, rdata: record.rdata, @@ -138,5 +134,5 @@ pub async fn delete_from_database( rdata: Option>, ) { let db_connection = &mut establish_connection(); - let _ = Record::delete(db_connection, name.join("."), _type.map(|f| f as i32), class as i32, rdata); + let _ = Record::delete(db_connection, name.join("."), _type.map(|f| f.into()), class.into(), rdata); } diff --git a/src/parser.rs b/src/parser.rs index 1c7f7cd..3cd1483 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -2,35 +2,64 @@ use std::{mem::size_of, vec}; use crate::{ errors::ParseError, - structs::{Class, Header, LabelString, Message, Opcode, OptRR, Question, Type, RR}, + structs::{ + Class, Header, KeyRData, LabelString, Message, Opcode, OptRR, Question, RRClass, RRType, + Type, RR, + }, }; type Result = std::result::Result; -impl TryFrom for Type { - type Error = String; - - fn try_from(value: u16) -> std::result::Result { +impl From for u16 { + fn from(value: Type) -> Self { match value { - //TODO: clean this up - x if x == Type::A as u16 => Ok(Type::A), - x if x == Type::OPT as u16 => Ok(Type::OPT), - x if x == Type::SOA as u16 => Ok(Type::SOA), - x if x == Type::ANY as u16 => Ok(Type::ANY), - _ => Err(format!("Invalid Type value: {}", value)), + Type::Type(t) => t as u16, + Type::Other(x) => x, } } } -impl TryFrom for Class { - type Error = String; +impl From for i32 { + fn from(value: Type) -> Self { + Into::::into(value) as i32 + } +} - fn try_from(value: u16) -> std::result::Result { +impl From for i32 { + fn from(value: Class) -> Self { + Into::::into(value) as i32 + } +} + +impl From for u16 { + fn from(value: Class) -> Self { match value { - x if x == Class::IN as u16 => Ok(Class::IN), - x if x == Class::ANY as u16 => Ok(Class::ANY), - x if x == Class::NONE as u16 => Ok(Class::NONE), - _ => Err(format!("Invalid Class value: {}", value)), + Class::Class(t) => t as u16, + Class::Other(x) => x, + } + } +} + +impl From for Type { + fn from(value: u16) -> Self { + match value { + x if x == RRType::A as u16 => Type::Type(RRType::A), + x if x == RRType::OPT as u16 => Type::Type(RRType::OPT), + x if x == RRType::SOA as u16 => Type::Type(RRType::SOA), + x if x == RRType::ANY as u16 => Type::Type(RRType::SOA), + x if x == RRType::KEY as u16 => Type::Type(RRType::KEY), + x => Type::Other(x), + } + } +} + +impl From for Class { + fn from(value: u16) -> Self { + match value { + x if x == RRClass::IN as u16 => Class::Class(RRClass::IN), + x if x == RRClass::ANY as u16 => Class::Class(RRClass::ANY), + x if x == RRClass::NONE as u16 => Class::Class(RRClass::NONE), + x => Class::Other(x), } } } @@ -80,7 +109,7 @@ pub fn parse_opt_type(bytes: &[u8]) -> Result> { impl Type { pub fn to_data(&self, text: &String) -> Result> { match self { - Type::A => { + Type::Type(RRType::A) => { let arr: Vec = text .split(".") .filter_map(|s| s.parse::().ok()) @@ -94,12 +123,12 @@ impl Type { }) } } - _ => todo!() + _ => todo!(), } } pub fn from_data(&self, bytes: &[u8]) -> Result { match self { - Type::A => { + Type::Type(RRType::A) => { if bytes.len() == 4 { let arr: Vec = bytes.iter().map(|b| b.to_string()).collect(); Ok(arr.join(".")) @@ -110,7 +139,7 @@ impl Type { }) } } - _ => todo!() + _ => todo!(), } } } @@ -210,21 +239,12 @@ impl FromBytes for Question { }) } else { //Try Parse qtype - let qtype = - Type::try_from(u16::from_be_bytes(bytes[*i..*i + 2].try_into().unwrap())) - .map_err(|e| ParseError { - object: String::from("Type"), - message: e, - })?; + let qtype = Type::from(u16::from_be_bytes(bytes[*i..*i + 2].try_into().unwrap())); //Try Parse qclass - let qclass = Class::try_from(u16::from_be_bytes( + let qclass = Class::from(u16::from_be_bytes( bytes[*i + 2..*i + 4].try_into().unwrap(), - )) - .map_err(|e| ParseError { - object: String::from("Class"), - message: e, - })?; + )); *i += 4; // For qtype and qclass => 4 bytes @@ -239,8 +259,8 @@ impl FromBytes for Question { fn to_bytes(question: Self) -> Vec { let mut result = LabelString::to_bytes(question.qname); - result.extend(u16::to_be_bytes(question.qtype.to_owned() as u16)); - result.extend(u16::to_be_bytes(question.qclass.to_owned() as u16)); + result.extend(u16::to_be_bytes(question.qtype.into())); + result.extend(u16::to_be_bytes(question.qclass.into())); result } } @@ -254,24 +274,16 @@ impl FromBytes for RR { message: String::from("len of rest of bytes smaller then minimum size"), }) } else { - let _type = Type::try_from(u16::from_be_bytes(bytes[*i..*i + 2].try_into().unwrap())) - .map_err(|e| ParseError { - object: String::from("Type"), - message: e, - })?; + let _type = Type::from(u16::from_be_bytes(bytes[*i..*i + 2].try_into().unwrap())); - let class = Class::try_from(u16::from_be_bytes( + let class = Class::from(u16::from_be_bytes( bytes[*i + 2..*i + 4].try_into().unwrap(), - )) - .map_err(|e| ParseError { - object: String::from("Class"), - message: e, - })?; + )); let ttl = i32::from_be_bytes(bytes[*i + 4..*i + 8].try_into().unwrap()); let rdlength = u16::from_be_bytes(bytes[*i + 8..*i + 10].try_into().unwrap()); - if bytes.len() - *i - 10 != rdlength as usize { + if bytes.len() - *i - 10 < rdlength as usize { Err(ParseError { object: String::from("RR"), message: String::from("len of rest of bytes not equal to rdlength"), @@ -292,8 +304,8 @@ impl FromBytes for RR { fn to_bytes(rr: Self) -> Vec { let mut result = LabelString::to_bytes(rr.name); - result.extend(u16::to_be_bytes(rr._type.to_owned() as u16)); - result.extend(u16::to_be_bytes(rr.class.to_owned() as u16)); + result.extend(u16::to_be_bytes(rr._type.into())); + result.extend(u16::to_be_bytes(rr.class.into())); result.extend(i32::to_be_bytes(rr.ttl.to_owned())); result.extend(u16::to_be_bytes(4 as u16)); result.extend(rr.rdata); @@ -309,6 +321,8 @@ impl FromBytes for Message { for _ in 0..header.qdcount { question.push(Question::from_bytes(&bytes, i)?); } + println!("{:#?}", question); + println!("{:#?}", header); let mut answer = vec![]; for _ in 0..header.ancount { @@ -319,6 +333,7 @@ impl FromBytes for Message { for _ in 0..header.nscount { authority.push(RR::from_bytes(&bytes, i)?); } + println!("{:#?}", authority); let mut additional = vec![]; for _ in 0..header.arcount { @@ -353,3 +368,34 @@ impl FromBytes for Message { result } } + +impl FromBytes for KeyRData { + fn from_bytes(bytes: &[u8], _: &mut usize) -> Result { + if bytes.len() < 18 { + Err(ParseError { + object: String::from("KeyRData"), + message: String::from("invalid rdata"), + }) + } else { + let mut i = 18; + Ok(KeyRData { + type_covered: u16::from_be_bytes(bytes[0..2].try_into().unwrap()), + algo: bytes[2], + labels: bytes[3], + original_ttl: u32::from_be_bytes(bytes[4..8].try_into().unwrap()), + signature_expiration: u32::from_be_bytes(bytes[8..12].try_into().unwrap()), + signature_inception: u32::from_be_bytes(bytes[12..16].try_into().unwrap()), + key_tag: u16::from_be_bytes(bytes[16..18].try_into().unwrap()), + signer: LabelString::from_bytes(bytes, &mut i)?, + signature: bytes[i..bytes.len()].to_vec(), + }) + } + } + + fn to_bytes(s: Self) -> Vec + where + Self: Sized, + { + todo!() + } +} diff --git a/src/resolver.rs b/src/resolver.rs index 89137a2..65e6057 100644 --- a/src/resolver.rs +++ b/src/resolver.rs @@ -7,7 +7,7 @@ use tokio::net::UdpSocket; use crate::db::models::{delete_from_database, get_from_database, insert_into_database}; use crate::errors::ParseError; use crate::parser::FromBytes; -use crate::structs::{Class, Header, Message, Opcode, Type, RCODE}; +use crate::structs::{Class, Header, KeyRData, Message, Opcode, RRClass, RRType, Type, RCODE}; use crate::utils::vec_equal; const MAX_DATAGRAM_SIZE: usize = 4096; @@ -30,7 +30,7 @@ async fn handle_query(message: Message) -> Message { match answers { Ok(rrs) => { response.header.flags = set_response_flags(response.header.flags, RCODE::NOERROR); - response.header.ancount = 1; + response.header.ancount = rrs.len() as u16; response.answer = rrs } Err(e) => { @@ -47,7 +47,7 @@ async fn handle_update(message: Message) -> Message { let mut response = message.clone(); // Zone section (question) processing - if (message.header.qdcount != 1) || !matches!(message.question[0].qtype, Type::SOA) { + if (message.header.qdcount != 1) || !matches!(message.question[0].qtype, Type::Type(RRType::SOA)) { response.header.flags = set_response_flags(response.header.flags, RCODE::FORMERR); return response; } @@ -61,13 +61,18 @@ async fn handle_update(message: Message) -> Message { } // Check Prerequisite TODO: implement this - if message.header.ancount > 0 { - response.header.flags = set_response_flags(response.header.flags, RCODE::NOTIMP); - return response; - } + // if message.header.ancount > 0 { + // response.header.flags = set_response_flags(response.header.flags, RCODE::NOTIMP); + // return response; + // } // Check Requestor Permission - // TODO: implement this, use rfc2931 + for rr in &message.additional { + if rr._type == Type::Type(RRType::KEY) { + let key = KeyRData::from_bytes(&rr.rdata, &mut 0).unwrap(); + println!("{:#?}",key); + } + } // Update Section Prescan for rr in &message.authority { @@ -79,9 +84,9 @@ async fn handle_update(message: Message) -> Message { return response; } - if (rr.class == Class::ANY && (rr.ttl != 0 || rr.rdlength != 0)) - || (rr.class == Class::NONE && rr.ttl != 0) - || ![Class::NONE, Class::ANY, zone.qclass.clone()].contains(&rr.class) + if (rr.class == Class::Class(RRClass::ANY) && (rr.ttl != 0 || rr.rdlength != 0)) + || (rr.class == Class::Class(RRClass::NONE) && rr.ttl != 0) + || ![Class::Class(RRClass::NONE), Class::Class(RRClass::ANY), zone.qclass.clone()].contains(&rr.class) { response.header.flags = set_response_flags(response.header.flags, RCODE::FORMERR); return response; @@ -92,23 +97,23 @@ async fn handle_update(message: Message) -> Message { for rr in message.authority { if rr.class == zone.qclass { let _ = insert_into_database(rr).await; - } else if rr.class == Class::ANY { - if rr._type == Type::ANY { + } else if rr.class == Class::Class(RRClass::ANY) { + if rr._type == Type::Type(RRType::ANY) { if rr.name == zone.qname { response.header.flags = set_response_flags(response.header.flags, RCODE::NOTIMP); return response; } else { - delete_from_database(rr.name, None, Class::IN, None).await; + delete_from_database(rr.name, None, Class::Class(RRClass::IN), None).await; } } else { - delete_from_database(rr.name, Some(rr._type), Class::IN, None).await; + delete_from_database(rr.name, Some(rr._type), Class::Class(RRClass::IN), None).await; } - } else if rr.class == Class::NONE { - if rr._type == Type::SOA { + } else if rr.class == Class::Class(RRClass::NONE) { + if rr._type == Type::Type(RRType::SOA) { continue; } - delete_from_database(rr.name, Some(rr._type),Class::IN, Some(rr.rdata)).await; + delete_from_database(rr.name, Some(rr._type), Class::Class(RRClass::IN), Some(rr.rdata)).await; } } diff --git a/src/structs.rs b/src/structs.rs index 569e183..e0c9b32 100644 --- a/src/structs.rs +++ b/src/structs.rs @@ -1,23 +1,39 @@ use serde::Deserialize; + +#[derive(Debug, Clone, Deserialize, PartialEq)] +pub enum Type { + Type(RRType), + Other(u16) +} + #[repr(u16)] #[derive(Debug, Clone, Deserialize, PartialEq)] -pub enum Type { +pub enum RRType { A = 1, SOA = 6, + KEY = 24, OPT = 41, - ANY = 255 + ANY = 255, +} + + +#[derive(Debug, Clone, PartialEq)] +pub enum Class { + Class(RRClass), + Other(u16) } #[repr(u16)] #[derive(Debug, Clone, PartialEq)] -pub enum Class { +pub enum RRClass { IN = 1, NONE = 254, - ANY = 255 + ANY = 255, } #[repr(u16)] +#[allow(dead_code)] pub enum RCODE { NOERROR = 0, FORMERR = 1, @@ -83,6 +99,14 @@ pub struct OptRR { pub type LabelString = Vec; #[derive(Debug)] -pub struct Response { - field: Type, +pub struct KeyRData { + pub type_covered: u16, + pub algo: u8, + pub labels: u8, + pub original_ttl: u32, + pub signature_expiration: u32, + pub signature_inception: u32, + pub key_tag: u16, + pub signer: LabelString, + pub signature: Vec }