diff --git a/src/errors.rs b/src/errors.rs new file mode 100644 index 0000000..743ae6e --- /dev/null +++ b/src/errors.rs @@ -0,0 +1,24 @@ +use core::fmt; + +#[derive(Debug)] +pub struct DNSError { + pub message: String, +} + +impl fmt::Display for DNSError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Error: {}", self.message) + } +} + +#[derive(Debug)] +pub struct ParseError { + pub object: String, + pub message: String, +} + +impl fmt::Display for ParseError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Parse Error for {}: {}", self.object, self.message) + } +} diff --git a/src/main.rs b/src/main.rs index cfbc47f..73a0e7a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,169 +1,16 @@ -use std::{error::Error, mem::size_of, net::SocketAddr}; +use std::{error::Error, net::SocketAddr}; +use parser::FromBytes; +use structs::Message; use tokio::net::UdpSocket; +mod errors; +mod parser; +mod structs; mod worker; -#[repr(u16)] -#[derive(Debug)] -enum Type { - A = 1, -} - -#[repr(u16)] -#[derive(Debug)] -enum Class { - IN = 1, -} - -#[derive(Debug)] -struct Question { - qname: Vec, // TODO: not padded - qtype: Type, // NOTE: should be QTYPE, right now not really needed - qclass: Class, -} - -#[derive(Debug)] -struct Header { - id: u16, - flags: u16, // |QR| Opcode |AA|TC|RD|RA| Z | RCODE | ; 1 | 4 | 1 | 1 | 1 | 1 | 3 | 4 - qdcount: u16, - ancount: u16, - nscount: u16, - arcount: u16, -} - -#[derive(Debug)] -pub struct Message { - header: Option
, - question: Option, - answer: Option, - authority: Option, - additional: Option, -} - -#[derive(Debug)] -struct RR { - name: String, - t: u16, - class: u16, - ttl: u32, - rdlength: u16, - rdata: String, -} - -#[derive(Debug)] -struct Response { - field: Type, -} - const MAX_DATAGRAM_SIZE: usize = 40_96; -impl TryFrom for Type { - type Error = (); //TODO: user better error - - fn try_from(value: u16) -> Result { - match value { - x if x == Type::A as u16 => Ok(Type::A), - _ => Err(()), - } - } -} - -impl TryFrom for Class { - type Error = (); //TODO: user better error - - fn try_from(value: u16) -> Result { - match value { - x if x == Class::IN as u16 => Ok(Class::IN), - _ => Err(()), - } - } -} - -// TODO: use Error instead of Option -trait FromBytes { - fn from_bytes(bytes: &[u8]) -> Option - where - Self: Sized; -} - -impl FromBytes for Header { - fn from_bytes(bytes: &[u8]) -> Option { - if bytes.len() != size_of::
() { - return None; // Size of header should match - } - - Some(Header { - id: u16::from_be_bytes(bytes[0..2].try_into().unwrap()), - flags: u16::from_be_bytes(bytes[2..4].try_into().unwrap()), - qdcount: u16::from_be_bytes(bytes[4..6].try_into().unwrap()), - ancount: u16::from_be_bytes(bytes[6..8].try_into().unwrap()), - nscount: u16::from_be_bytes(bytes[8..10].try_into().unwrap()), - arcount: u16::from_be_bytes(bytes[10..12].try_into().unwrap()), - }) - } -} - -//HACK: lots of unsafe unwrap -impl FromBytes for Question { - fn from_bytes(bytes: &[u8]) -> Option { - // 16 for length octet + zero length octet - if bytes.len() < 2 + size_of::() + size_of::() { - None - } else { - let mut qname = vec![]; - let mut i = 0; - - // Parse qname labels - while bytes[i] != 0 && bytes[i] as usize + i < bytes.len() { - qname.push( - String::from_utf8(bytes[i + 1..bytes[i] as usize + 1 + i].to_vec()).unwrap(), - ); - i += bytes[i] as usize + 1; - } - i += 1; - - if bytes.len() - i < size_of::() + size_of::() { - None - } else { - //Try Parse qtype - let qtype = Type::try_from(u16::from_be_bytes(bytes[i..i + 2].try_into().unwrap())) - .unwrap(); - - //Try Parse qclass - let qclass = - Class::try_from(u16::from_be_bytes(bytes[i + 2..i + 4].try_into().unwrap())) - .unwrap(); - - Some(Question { - qname, - qtype, - qclass, - }) - } - } - } -} - -impl FromBytes for Message { - fn from_bytes(bytes: &[u8]) -> Option { - let header = Header::from_bytes(&bytes[0..12]); - let question = Question::from_bytes(&bytes[12..]); - if header.is_some() { - Some(Message { - header, - question, - answer: None, - authority: None, - additional: None, - }) - } else { - None - } - } -} - async fn create_query(message: Message) { println!("{:?}", message); } @@ -177,11 +24,11 @@ async fn main() -> Result<(), Box> { let mut data = vec![0u8; MAX_DATAGRAM_SIZE]; loop { let len = socket.recv(&mut data).await?; - let message = Message::from_bytes(&data[..len]); - if message.is_some() { - tokio::spawn(async move { - create_query(message.unwrap()).await; - }); - } + match Message::from_bytes(&data[..len]) { + Ok(message) => { + tokio::spawn(async move { create_query(message).await }); + } + Err(err) => println!("{}", err), + }; } } diff --git a/src/parser.rs b/src/parser.rs new file mode 100644 index 0000000..ee35a14 --- /dev/null +++ b/src/parser.rs @@ -0,0 +1,124 @@ +use std::mem::size_of; + +use crate::{ + errors::ParseError, + structs::{Class, Header, Message, Question, Type}, +}; + +type Result = std::result::Result; + +impl TryFrom for Type { + type Error = (); //TODO: user better error + + fn try_from(value: u16) -> std::result::Result { + match value { + x if x == Type::A as u16 => Ok(Type::A), + _ => Err(()), + } + } +} + +impl TryFrom for Class { + type Error = (); //TODO: user better error + + fn try_from(value: u16) -> std::result::Result { + match value { + x if x == Class::IN as u16 => Ok(Class::IN), + _ => Err(()), + } + } +} + +// TODO: use Error instead of Option +pub trait FromBytes { + fn from_bytes(bytes: &[u8]) -> Result + where + Self: Sized; +} + +impl FromBytes for Header { + fn from_bytes(bytes: &[u8]) -> Result { + if bytes.len() != size_of::
() { + Err(ParseError { + object: String::from("Header"), + message: String::from("Size of Header does not match"), + }) + } else { + Ok(Header { + id: u16::from_be_bytes(bytes[0..2].try_into().unwrap()), + flags: u16::from_be_bytes(bytes[2..4].try_into().unwrap()), + qdcount: u16::from_be_bytes(bytes[4..6].try_into().unwrap()), + ancount: u16::from_be_bytes(bytes[6..8].try_into().unwrap()), + nscount: u16::from_be_bytes(bytes[8..10].try_into().unwrap()), + arcount: u16::from_be_bytes(bytes[10..12].try_into().unwrap()), + }) + } + } +} + +//HACK: lots of unsafe unwrap +impl FromBytes for Question { + fn from_bytes(bytes: &[u8]) -> Result { + // 16 for length octet + zero length octet + if bytes.len() < 2 + size_of::() + size_of::() { + Err(ParseError { + object: String::from("Question"), + message: String::from("len of bytes smaller then minimum size"), + }) + } else { + let mut qname = vec![]; + let mut i = 0; + + // Parse qname labels + while bytes[i] != 0 && bytes[i] as usize + i < bytes.len() { + qname.push( + String::from_utf8(bytes[i + 1..bytes[i] as usize + 1 + i].to_vec()).unwrap(), + ); + i += bytes[i] as usize + 1; + } + i += 1; + + if bytes.len() - i < size_of::() + size_of::() { + Err(ParseError { + object: String::from("Question"), + message: String::from("len of rest bytes smaller then minimum size"), + }) + } else { + //Try Parse qtype + let qtype = Type::try_from(u16::from_be_bytes(bytes[i..i + 2].try_into().unwrap())) + .map_err(|_| ParseError { + object: String::from("Type"), + message: String::from("invalid"), + })?; + + //Try Parse qclass + let qclass = + Class::try_from(u16::from_be_bytes(bytes[i + 2..i + 4].try_into().unwrap())) + .map_err(|_| ParseError { + object: String::from("Class"), + message: String::from("invalid"), + })?; + + Ok(Question { + qname, + qtype, + qclass, + }) + } + } + } +} + +impl FromBytes for Message { + fn from_bytes(bytes: &[u8]) -> Result { + let header = Header::from_bytes(&bytes[0..12])?; + let question = Question::from_bytes(&bytes[12..])?; + Ok(Message { + header, + question, + answer: None, + authority: None, + additional: None, + }) + } +} diff --git a/src/structs.rs b/src/structs.rs new file mode 100644 index 0000000..320bd86 --- /dev/null +++ b/src/structs.rs @@ -0,0 +1,53 @@ + +#[repr(u16)] +#[derive(Debug)] +pub enum Type { + A = 1, +} + +#[repr(u16)] +#[derive(Debug)] +pub enum Class { + IN = 1, +} + +#[derive(Debug)] +pub struct Question { + pub qname: Vec, // TODO: not padded + pub qtype: Type, // NOTE: should be QTYPE, right now not really needed + pub qclass: Class, +} + +#[derive(Debug)] +pub struct Header { + pub id: u16, + pub flags: u16, // |QR| Opcode |AA|TC|RD|RA| Z | RCODE | ; 1 | 4 | 1 | 1 | 1 | 1 | 3 | 4 + pub qdcount: u16, + pub ancount: u16, + pub nscount: u16, + pub arcount: u16, +} + +#[derive(Debug)] +pub struct Message { + pub header: Header, + pub question: Question, + pub answer: Option, + pub authority: Option, + pub additional: Option, +} + +#[derive(Debug)] +pub struct RR { + name: String, + t: u16, + class: u16, + ttl: u32, + rdlength: u16, + rdata: String, +} + +#[derive(Debug)] +pub struct Response { + field: Type, +}