diff --git a/src/parser.rs b/src/parser.rs index f5dc4fa..6fc708f 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -2,7 +2,7 @@ use std::{mem::size_of, vec}; use crate::{ errors::ParseError, - structs::{Class, Header, LabelString, Message, Question, Type, RR}, + structs::{Class, Header, LabelString, Message, OptRR, Question, Type, RR}, }; type Result = std::result::Result; @@ -30,7 +30,7 @@ impl TryFrom for Class { } pub trait FromBytes { - fn from_bytes(bytes: &[u8]) -> Result + fn from_bytes(bytes: &[u8], i: &mut usize) -> Result where Self: Sized; fn to_bytes(s: Self) -> Vec @@ -38,6 +38,27 @@ pub trait FromBytes { Self: Sized; } +pub fn parse_opt_type(bytes: &[u8]) -> Result> { + let mut pairs: Vec = vec![]; + let mut i: usize = 0; + while i + 4 <= bytes.len() { + let length = u16::from_be_bytes(bytes[i + 2..i + 4].try_into().unwrap()); + pairs.push(OptRR { + code: u16::from_be_bytes(bytes[i..i + 2].try_into().unwrap()), + length, + rdata: bytes[i + 4..i + 4 + length as usize] + .try_into() + .map_err(|_| ParseError { + object: String::from("Type::OPT"), + message: String::from("Invalid OPT DATA"), + })?, + }); + i += 4 + length as usize; + } + + Ok(pairs) +} + impl Type { pub fn to_data(&self, text: &String) -> Result> { match self { @@ -55,6 +76,7 @@ impl Type { }) } } + Type::OPT => todo!(), } } pub fn from_data(&self, bytes: &[u8]) -> Result { @@ -70,18 +92,20 @@ impl Type { }) } } + Type::OPT => unimplemented!() } } } impl FromBytes for Header { - fn from_bytes(bytes: &[u8]) -> Result { - if bytes.len() != size_of::
() { + fn from_bytes(bytes: &[u8], i: &mut usize) -> Result { + if bytes.len() < size_of::
() { Err(ParseError { object: String::from("Header"), message: String::from("Size of Header does not match"), }) } else { + *i += size_of::
(); Ok(Header { id: u16::from_be_bytes(bytes[0..2].try_into().unwrap()), flags: u16::from_be_bytes(bytes[2..4].try_into().unwrap()), @@ -91,6 +115,7 @@ impl FromBytes for Header { arcount: u16::from_be_bytes(bytes[10..12].try_into().unwrap()), }) } + } fn to_bytes(header: Self) -> Vec { @@ -105,25 +130,25 @@ impl FromBytes for Header { result.to_vec() } + } impl FromBytes for LabelString { - fn from_bytes(bytes: &[u8]) -> Result { - let mut i = 0; + fn from_bytes(bytes: &[u8], i: &mut usize) -> Result { let mut qname = vec![]; // Parse qname labels - while bytes[i] != 0 && bytes[i] as usize + i < bytes.len() { + while bytes[*i] != 0 && bytes[*i] as usize + *i < bytes.len() { qname - .push(String::from_utf8(bytes[i + 1..bytes[i] as usize + 1 + i].to_vec()).unwrap()); - i += bytes[i] as usize + 1; + .push(String::from_utf8(bytes[*i + 1..bytes[*i] as usize + 1 + *i].to_vec()).unwrap()); + *i += bytes[*i] as usize + 1; } - i += 1; - Ok((qname, i)) + *i += 1; + Ok(qname) } - fn to_bytes((name, _): Self) -> Vec { + fn to_bytes(name: Self) -> Vec { let mut result: Vec = vec![]; for label in name { result.push(label.len() as u8); @@ -135,7 +160,7 @@ impl FromBytes for LabelString { } impl FromBytes for Question { - fn from_bytes(bytes: &[u8]) -> Result { + fn from_bytes(bytes: &[u8], i: &mut usize) -> Result { // 16 for length octet + zero length octet if bytes.len() < 2 + size_of::() + size_of::() { Err(ParseError { @@ -143,16 +168,16 @@ impl FromBytes for Question { message: String::from("len of bytes smaller then minimum size"), }) } else { - let (qname, i) = LabelString::from_bytes(bytes)?; + let qname = LabelString::from_bytes(bytes, i)?; - if bytes.len() - i < size_of::() + size_of::() { + if bytes.len() - *i < size_of::() + size_of::() { Err(ParseError { object: String::from("Question"), message: String::from("len of rest bytes smaller then minimum size"), }) } else { //Try Parse qtype - let qtype = Type::try_from(u16::from_be_bytes(bytes[i..i + 2].try_into().unwrap())) + let qtype = Type::try_from(u16::from_be_bytes(bytes[*i..*i + 2].try_into().unwrap())) .map_err(|_| ParseError { object: String::from("Type"), message: String::from("invalid"), @@ -160,12 +185,14 @@ impl FromBytes for Question { //Try Parse qclass let qclass = - Class::try_from(u16::from_be_bytes(bytes[i + 2..i + 4].try_into().unwrap())) + Class::try_from(u16::from_be_bytes(bytes[*i + 2..*i + 4].try_into().unwrap())) .map_err(|_| ParseError { object: String::from("Class"), message: String::from("invalid"), })?; + *i += 4; // For qtype and qclass => 4 bytes + Ok(Question { qname, qtype, @@ -176,7 +203,7 @@ impl FromBytes for Question { } fn to_bytes(question: Self) -> Vec { - let mut result = LabelString::to_bytes((question.qname, 0)); + let mut result = LabelString::to_bytes(question.qname); result.extend(u16::to_be_bytes(question.qtype.to_owned() as u16)); result.extend(u16::to_be_bytes(question.qclass.to_owned() as u16)); result @@ -184,50 +211,51 @@ impl FromBytes for Question { } impl FromBytes for RR { - fn from_bytes(bytes: &[u8]) -> Result { - let (name, i) = LabelString::from_bytes(bytes)?; - if bytes.len() - i < size_of::() + size_of::() + 6 { + fn from_bytes(bytes: &[u8], i: &mut usize) -> Result { + let name = LabelString::from_bytes(bytes, i)?; + if bytes.len() - *i < size_of::() + size_of::() + 6 { Err(ParseError { object: String::from("RR"), message: String::from("len of rest of bytes smaller then minimum size"), }) } else { - let _type = Type::try_from(u16::from_be_bytes(bytes[i..i + 2].try_into().unwrap())) + let _type = Type::try_from(u16::from_be_bytes(bytes[*i..*i + 2].try_into().unwrap())) .map_err(|_| ParseError { object: String::from("Type"), message: String::from("invalid"), })?; let class = - Class::try_from(u16::from_be_bytes(bytes[i + 2..i + 4].try_into().unwrap())) + Class::try_from(u16::from_be_bytes(bytes[*i + 2..*i + 4].try_into().unwrap())) .map_err(|_| ParseError { object: String::from("Class"), message: String::from("invalid"), })?; - let ttl = i32::from_be_bytes(bytes[i + 4..i + 8].try_into().unwrap()); - let rdlength = u16::from_be_bytes(bytes[i + 8..i + 10].try_into().unwrap()); + let ttl = i32::from_be_bytes(bytes[*i + 4..*i + 8].try_into().unwrap()); + let rdlength = u16::from_be_bytes(bytes[*i + 8..*i + 10].try_into().unwrap()); - if bytes.len() - i - 10 != rdlength as usize { + if bytes.len() - *i - 10 != rdlength as usize { Err(ParseError { object: String::from("RR"), message: String::from("len of rest of bytes not equal to rdlength"), }) } else { + *i += 10 + rdlength as usize; Ok(RR { name, _type, class, ttl, rdlength, - rdata: bytes[i + 10..].to_vec(), + rdata: bytes[*i - rdlength as usize.. *i].to_vec(), }) } } } fn to_bytes(rr: Self) -> Vec { - let mut result = LabelString::to_bytes((rr.name, 0)); + let mut result = LabelString::to_bytes(rr.name); result.extend(u16::to_be_bytes(rr._type.to_owned() as u16)); result.extend(u16::to_be_bytes(rr.class.to_owned() as u16)); result.extend(i32::to_be_bytes(rr.ttl.to_owned())); @@ -238,9 +266,10 @@ impl FromBytes for RR { } impl FromBytes for Message { - fn from_bytes(bytes: &[u8]) -> Result { - let header = Header::from_bytes(&bytes[0..12])?; - let question = Question::from_bytes(&bytes[12..])?; + fn from_bytes(bytes: &[u8], i: &mut usize) -> Result { + let header = Header::from_bytes(&bytes,i)?; + let question = Question::from_bytes(&bytes,i)?; + Ok(Message { header, question, diff --git a/src/resolver.rs b/src/resolver.rs index 3638a0f..43b889c 100644 --- a/src/resolver.rs +++ b/src/resolver.rs @@ -5,14 +5,16 @@ use std::sync::Arc; use tokio::net::UdpSocket; use crate::db::models::get_from_database; -use crate::parser::FromBytes; -use crate::structs::Message; +use crate::parser::{parse_opt_type, FromBytes}; +use crate::structs::{Message, Type, RR}; -const MAX_DATAGRAM_SIZE: usize = 40_96; +const MAX_DATAGRAM_SIZE: usize = 4096; +const OPTION_CODE: usize = 65001; -async fn create_query(message: Message) -> Message { +async fn handle_normal_question(message: Message) -> Message { let mut response = message.clone(); + println!("{:#?}",message.question); let answer = get_from_database(message.question).await; response.header.arcount = 0; @@ -31,16 +33,30 @@ async fn create_query(message: Message) -> Message { response } +async fn handle_opt_rr(rr: RR) { + let pairs = parse_opt_type(&rr.rdata); + println!("{:#?}", pairs) +} + +async fn get_response(message: Message) -> Message { + match message.question.qtype { + Type::OPT => handle_normal_question(message), + _ => handle_normal_question(message), + } + .await +} + pub async fn resolver_listener_loop(addr: SocketAddr) -> Result<(), Box> { let socket_shared = Arc::new(UdpSocket::bind(addr).await?); loop { let mut data = vec![0u8; MAX_DATAGRAM_SIZE]; let (len, addr) = socket_shared.recv_from(&mut data).await?; - match Message::from_bytes(&data[..len]) { + let mut i: usize = 0; + match Message::from_bytes(&data[..len], &mut i) { Ok(message) => { let socket = socket_shared.clone(); tokio::spawn(async move { - let response = create_query(message).await; + let response = get_response(message).await; let _ = socket .send_to(Message::to_bytes(response).as_slice(), addr) .await; diff --git a/src/structs.rs b/src/structs.rs index de1852d..24cf428 100644 --- a/src/structs.rs +++ b/src/structs.rs @@ -4,6 +4,7 @@ use serde::Deserialize; #[derive(Debug, Clone, Deserialize)] pub enum Type { A = 1, + OPT = 41 } #[repr(u16)] @@ -48,7 +49,14 @@ pub struct RR { pub rdata: Vec, } -pub type LabelString = (Vec, usize); +#[derive(Debug, Clone)] +pub struct OptRR { + pub code: u16, + pub length: u16, + pub rdata: Vec +} + +pub type LabelString = Vec; #[derive(Debug)] pub struct Response {