diff --git a/src/errors.rs b/src/errors.rs index 4346d56..b7faa43 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -44,3 +44,26 @@ impl fmt::Display for AuthenticationError { write!(f, "Authentication Error: {}", self.message) } } + +#[derive(Debug)] +pub struct ReaderError { + pub message: String, +} + +impl fmt::Display for ReaderError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Reader Error: {}", self.message) + } +} + +impl From for ParseError +where + E: Into, +{ + fn from(value: E) -> Self { + ParseError { + object: String::from("Reader"), + message: value.into().to_string(), + } + } +} diff --git a/src/main.rs b/src/main.rs index 59b2704..ed12551 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,22 +4,21 @@ use dotenvy::dotenv; use crate::resolver::resolver_listener_loop; +mod authenticate; mod db; mod errors; mod parser; +mod reader; mod resolver; +mod sig; mod structs; mod utils; -mod sig; -mod authenticate; #[tokio::main] async fn main() -> Result<(), Box> { dotenv().ok(); let resolver_add = SocketAddr::from(([127, 0, 0, 1], 8080)); - let _ = tokio::join!( - resolver_listener_loop(resolver_add), - ); + let _ = tokio::join!(resolver_listener_loop(resolver_add),); Ok(()) } diff --git a/src/parser.rs b/src/parser.rs index 7b83969..5455864 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -2,9 +2,9 @@ use std::{mem::size_of, vec}; use crate::{ errors::ParseError, + reader::Reader, structs::{ - Class, Header, KeyRData, LabelString, Message, Opcode, Question, RRClass, RRType, - Type, RR, + Class, Header, KeyRData, LabelString, Message, Opcode, Question, RRClass, RRType, Type, RR, }, }; @@ -77,7 +77,7 @@ impl TryFrom for Opcode { } pub trait FromBytes { - fn from_bytes(bytes: &[u8], i: &mut usize) -> Result + fn from_bytes(reader: &mut Reader) -> Result where Self: Sized; fn to_bytes(s: Self) -> Vec @@ -85,60 +85,21 @@ pub trait FromBytes { Self: Sized; } -impl Type { - pub fn to_data(&self, text: &String) -> Result> { - match self { - Type::Type(RRType::A) => { - let arr: Vec = text - .split(".") - .filter_map(|s| s.parse::().ok()) - .collect(); - if arr.len() == 4 { - Ok(arr) - } else { - Err(ParseError { - object: String::from("Type::A"), - message: String::from("Invalid IPv4 address"), - }) - } - } - _ => todo!(), - } - } - pub fn from_data(&self, bytes: &[u8]) -> Result { - match self { - Type::Type(RRType::A) => { - if bytes.len() == 4 { - let arr: Vec = bytes.iter().map(|b| b.to_string()).collect(); - Ok(arr.join(".")) - } else { - Err(ParseError { - object: String::from("Type::A"), - message: String::from("Invalid Ipv4 address bytes"), - }) - } - } - _ => todo!(), - } - } -} - impl FromBytes for Header { - fn from_bytes(bytes: &[u8], i: &mut usize) -> Result { - if bytes.len() < size_of::
() { + fn from_bytes(reader: &mut Reader) -> Result { + if reader.unread_bytes() < size_of::
() { Err(ParseError { object: String::from("Header"), message: String::from("Size of Header does not match"), }) } else { - *i += size_of::
(); Ok(Header { - id: u16::from_be_bytes(bytes[0..2].try_into().unwrap()), - flags: u16::from_be_bytes(bytes[2..4].try_into().unwrap()), - qdcount: u16::from_be_bytes(bytes[4..6].try_into().unwrap()), - ancount: u16::from_be_bytes(bytes[6..8].try_into().unwrap()), - nscount: u16::from_be_bytes(bytes[8..10].try_into().unwrap()), - arcount: u16::from_be_bytes(bytes[10..12].try_into().unwrap()), + id: reader.read_u16()?, + flags: reader.read_u16()?, + qdcount: reader.read_u16()?, + ancount: reader.read_u16()?, + nscount: reader.read_u16()?, + arcount: reader.read_u16()?, }) } } @@ -158,34 +119,29 @@ impl FromBytes for Header { } impl FromBytes for LabelString { - fn from_bytes(bytes: &[u8], i: &mut usize) -> Result { + fn from_bytes(reader: &mut Reader) -> Result { let mut out = vec![]; // Parse qname labels - while bytes[*i] != 0 - && (bytes[*i] & 0b11000000 == 0) - && bytes[*i] as usize + *i < bytes.len() - { + let mut code = reader.read_u8()?; + while code != 0 && (code & 0b11000000 == 0) && reader.unread_bytes() > code as usize { out.push( - String::from_utf8(bytes[*i + 1..bytes[*i] as usize + 1 + *i].to_vec()).unwrap(), + String::from_utf8(reader.read(code as usize)?.to_vec()).map_err(|e| { + ParseError { + object: String::from("Label"), + message: e.to_string(), + } + })?, ); - *i += bytes[*i] as usize + 1; + code = reader.read_u8()?; } - if bytes[*i] & 0b11000000 != 0 { - let offset = u16::from_be_bytes(bytes[*i..*i + 2].try_into().unwrap()) & 0b0011111111111111; - if *i <= offset as usize { - return Err(ParseError { - object: String::from("Label"), - message: String::from("Invalid PTR"), - }); - } else { - out.extend(LabelString::from_bytes(bytes, &mut (offset as usize))?); - *i += 1; - } + if code & 0b11000000 != 0 { + let offset = (((code & 0b00111111) as u16) << 8) | reader.read_u8()? as u16; + let mut reader_past = reader.seek(offset as usize)?; + out.extend(LabelString::from_bytes(&mut reader_past)?); } - *i += 1; Ok(out) } @@ -201,31 +157,27 @@ impl FromBytes for LabelString { } impl FromBytes for Question { - fn from_bytes(bytes: &[u8], i: &mut usize) -> Result { + fn from_bytes(reader: &mut Reader) -> Result { // 16 for length octet + zero length octet - if bytes.len() < 2 + size_of::() + size_of::() { + if reader.unread_bytes() < 2 + size_of::() + size_of::() { Err(ParseError { object: String::from("Question"), message: String::from("len of bytes smaller then minimum size"), }) } else { - let qname = LabelString::from_bytes(bytes, i)?; + let qname = LabelString::from_bytes(reader)?; - if bytes.len() - *i < size_of::() + size_of::() { + if reader.unread_bytes() < 4 { Err(ParseError { object: String::from("Question"), message: String::from("len of rest bytes smaller then minimum size"), }) } else { //Try Parse qtype - let qtype = Type::from(u16::from_be_bytes(bytes[*i..*i + 2].try_into().unwrap())); + let qtype = Type::from(reader.read_u16()?); //Try Parse qclass - let qclass = Class::from(u16::from_be_bytes( - bytes[*i + 2..*i + 4].try_into().unwrap(), - )); - - *i += 4; // For qtype and qclass => 4 bytes + let qclass = Class::from(reader.read_u16()?); Ok(Question { qname, @@ -245,37 +197,31 @@ impl FromBytes for Question { } impl FromBytes for RR { - fn from_bytes(bytes: &[u8], i: &mut usize) -> Result { - let name = LabelString::from_bytes(bytes, i)?; - if bytes.len() - *i < size_of::() + size_of::() + 6 { + fn from_bytes(reader: &mut Reader) -> Result { + let name = LabelString::from_bytes(reader)?; + if reader.unread_bytes() < size_of::() + size_of::() + 6 { Err(ParseError { object: String::from("RR"), message: String::from("len of rest of bytes smaller then minimum size"), }) } else { - let _type = Type::from(u16::from_be_bytes(bytes[*i..*i + 2].try_into().unwrap())); - - let class = Class::from(u16::from_be_bytes( - bytes[*i + 2..*i + 4].try_into().unwrap(), - )); - - let ttl = i32::from_be_bytes(bytes[*i + 4..*i + 8].try_into().unwrap()); - let rdlength = u16::from_be_bytes(bytes[*i + 8..*i + 10].try_into().unwrap()); - - if bytes.len() - *i - 10 < rdlength as usize { + let _type = Type::from(reader.read_u16()?); + let class = Class::from(reader.read_u16()?); + let ttl = reader.read_i32()?; + let rdlength = reader.read_u16()?; + if reader.unread_bytes() < rdlength as usize { Err(ParseError { object: String::from("RR"), message: String::from("len of rest of bytes not equal to rdlength"), }) } else { - *i += 10 + rdlength as usize; Ok(RR { name, _type, class, ttl, rdlength, - rdata: bytes[*i - rdlength as usize..*i].to_vec(), + rdata: reader.read(rdlength as usize)?, }) } } @@ -293,27 +239,27 @@ impl FromBytes for RR { } impl FromBytes for Message { - fn from_bytes(bytes: &[u8], i: &mut usize) -> Result { - let header = Header::from_bytes(&bytes, i)?; + fn from_bytes(reader: &mut Reader) -> Result { + let header = Header::from_bytes(reader)?; let mut question = vec![]; for _ in 0..header.qdcount { - question.push(Question::from_bytes(&bytes, i)?); + question.push(Question::from_bytes(reader)?); } let mut answer = vec![]; for _ in 0..header.ancount { - answer.push(RR::from_bytes(&bytes, i)?); + answer.push(RR::from_bytes(reader)?); } let mut authority = vec![]; for _ in 0..header.nscount { - authority.push(RR::from_bytes(&bytes, i)?); + authority.push(RR::from_bytes(reader)?); } let mut additional = vec![]; for _ in 0..header.arcount { - additional.push(RR::from_bytes(&bytes, i)?); + additional.push(RR::from_bytes(reader)?); } Ok(Message { @@ -346,24 +292,23 @@ impl FromBytes for Message { } impl FromBytes for KeyRData { - fn from_bytes(bytes: &[u8], i: &mut usize) -> Result { - if bytes.len() < 18 { + fn from_bytes(reader: &mut Reader) -> Result { + if reader.unread_bytes() < 18 { Err(ParseError { object: String::from("KeyRData"), message: String::from("invalid rdata"), }) } else { - *i = 18; Ok(KeyRData { - type_covered: u16::from_be_bytes(bytes[0..2].try_into().unwrap()), - algo: bytes[2], - labels: bytes[3], - original_ttl: u32::from_be_bytes(bytes[4..8].try_into().unwrap()), - signature_expiration: u32::from_be_bytes(bytes[8..12].try_into().unwrap()), - signature_inception: u32::from_be_bytes(bytes[12..16].try_into().unwrap()), - key_tag: u16::from_be_bytes(bytes[16..18].try_into().unwrap()), - signer: LabelString::from_bytes(bytes, i)?, - signature: bytes[*i..bytes.len()].to_vec(), + type_covered: reader.read_u16()?, + algo: reader.read_u8()?, + labels: reader.read_u8()?, + original_ttl: reader.read_u32()?, + signature_expiration: reader.read_u32()?, + signature_inception: reader.read_u32()?, + key_tag: reader.read_u16()?, + signer: LabelString::from_bytes(reader)?, + signature: reader.read(reader.unread_bytes())?, }) } } diff --git a/src/reader.rs b/src/reader.rs new file mode 100644 index 0000000..b349acf --- /dev/null +++ b/src/reader.rs @@ -0,0 +1,85 @@ +use std::array::TryFromSliceError; + +use crate::errors::ReaderError; + +pub struct Reader<'a> { + buffer: &'a [u8], + position: usize, +} + +type Result = std::result::Result; + +impl<'a> Reader<'a> { + pub fn new(buffer: &[u8]) -> Reader { + Reader { + buffer, + position: 0, + } + } + + pub fn unread_bytes(&self) -> usize { + self.buffer.len() - self.position + } + + pub fn read(&mut self, size: usize) -> Result> { + if size > self.unread_bytes() { + Err(ReaderError { + message: String::from("cannot read enough bytes"), + }) + } else { + self.position += size; + Ok(self.buffer[self.position - size..self.position].to_vec()) + } + } + + pub fn read_u8(&mut self) -> Result { + self.position += 1; + Ok(self.buffer[self.position - 1]) + } + + pub fn read_u16(&mut self) -> Result { + let result = u16::from_be_bytes( + self.buffer[self.position..self.position + 2] + .try_into() + .map_err(|e: TryFromSliceError| ReaderError { + message: e.to_string(), + })?, + ); + self.position += 2; + Ok(result) + } + + pub fn read_i32(&mut self) -> Result { + let result = i32::from_be_bytes( + self.buffer[self.position..self.position + 4] + .try_into() + .map_err(|e: TryFromSliceError| ReaderError { + message: e.to_string(), + })?, + ); + self.position += 4; + Ok(result) + } + + pub fn read_u32(&mut self) -> Result { + let result = u32::from_be_bytes( + self.buffer[self.position..self.position + 4] + .try_into() + .map_err(|e: TryFromSliceError| ReaderError { + message: e.to_string(), + })?, + ); + self.position += 4; + Ok(result) + } + + pub fn seek(&self, position: usize) -> Result { + if position >= self.position { + Err(ReaderError { + message: String::from("Seeking into the future is not allowed!!"), + }) + } else { + Ok(Reader::new(&self.buffer[position..self.position])) + } + } +} diff --git a/src/resolver.rs b/src/resolver.rs index 86acfd3..8349c2d 100644 --- a/src/resolver.rs +++ b/src/resolver.rs @@ -8,6 +8,7 @@ use crate::authenticate::authenticate; use crate::db::models::{delete_from_database, get_from_database, insert_into_database}; use crate::errors::ParseError; use crate::parser::FromBytes; +use crate::reader::Reader; use crate::sig::Sig; use crate::structs::{Class, Header, Message, Opcode, RRClass, RRType, Type, RCODE}; use crate::utils::vec_equal; @@ -140,7 +141,8 @@ async fn handle_update(message: Message, bytes: &[u8]) -> Message { fn handle_parse_error(bytes: &[u8], err: ParseError) -> Message { eprintln!("{}", err); - let mut header = Header::from_bytes(bytes, &mut 0).unwrap_or(Header { + let mut reader = Reader::new(bytes); + let mut header = Header::from_bytes(&mut reader).unwrap_or(Header { id: 0, flags: 0, qdcount: 0, @@ -165,8 +167,8 @@ fn handle_parse_error(bytes: &[u8], err: ParseError) -> Message { } async fn get_response(bytes: &[u8]) -> Message { - let mut i: usize = 0; - match Message::from_bytes(bytes, &mut i) { + let mut reader = Reader::new(bytes); + match Message::from_bytes(&mut reader) { Ok(message) => match get_opcode(&message.header.flags) { Ok(opcode) => match opcode { Opcode::QUERY => handle_query(message).await, diff --git a/src/sig.rs b/src/sig.rs index 7cebe4c..cd18a8f 100644 --- a/src/sig.rs +++ b/src/sig.rs @@ -2,6 +2,7 @@ use base64::prelude::*; use crate::{ parser::FromBytes, + reader::Reader, structs::{KeyRData, RR}, }; @@ -19,10 +20,10 @@ impl Sig { let mut request = datagram[0..datagram.len() - 11 - rr.rdlength as usize].to_vec(); request[11] -= 1; // Decrease arcount - let mut i = 0; - let key_rdata = KeyRData::from_bytes(&rr.rdata, &mut i).unwrap(); + let mut reader = Reader::new(&rr.rdata); + let key_rdata = KeyRData::from_bytes(&mut reader).unwrap(); - let mut raw_data = rr.rdata[0..i].to_vec(); + let mut raw_data = rr.rdata[0..rr.rdata.len() - key_rdata.signature.len()].to_vec(); raw_data.extend(request); Sig {