10
0
Fork 0
mirror of https://github.com/ZeusWPI/ZNS.git synced 2024-11-22 05:41:11 +01:00

Refactor and now with error handling

This commit is contained in:
Xander Bil 2024-02-22 21:46:53 +01:00
parent e06ab152de
commit a0fb2fad7b
No known key found for this signature in database
GPG key ID: EC9706B54A278598
4 changed files with 213 additions and 165 deletions

24
src/errors.rs Normal file
View file

@ -0,0 +1,24 @@
use core::fmt;
#[derive(Debug)]
pub struct DNSError {
pub message: String,
}
impl fmt::Display for DNSError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Error: {}", self.message)
}
}
#[derive(Debug)]
pub struct ParseError {
pub object: String,
pub message: String,
}
impl fmt::Display for ParseError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Parse Error for {}: {}", self.object, self.message)
}
}

View file

@ -1,169 +1,16 @@
use std::{error::Error, mem::size_of, net::SocketAddr}; use std::{error::Error, net::SocketAddr};
use parser::FromBytes;
use structs::Message;
use tokio::net::UdpSocket; use tokio::net::UdpSocket;
mod errors;
mod parser;
mod structs;
mod worker; mod worker;
#[repr(u16)]
#[derive(Debug)]
enum Type {
A = 1,
}
#[repr(u16)]
#[derive(Debug)]
enum Class {
IN = 1,
}
#[derive(Debug)]
struct Question {
qname: Vec<String>, // TODO: not padded
qtype: Type, // NOTE: should be QTYPE, right now not really needed
qclass: Class,
}
#[derive(Debug)]
struct Header {
id: u16,
flags: u16, // |QR| Opcode |AA|TC|RD|RA| Z | RCODE | ; 1 | 4 | 1 | 1 | 1 | 1 | 3 | 4
qdcount: u16,
ancount: u16,
nscount: u16,
arcount: u16,
}
#[derive(Debug)]
pub struct Message {
header: Option<Header>,
question: Option<Question>,
answer: Option<RR>,
authority: Option<RR>,
additional: Option<RR>,
}
#[derive(Debug)]
struct RR {
name: String,
t: u16,
class: u16,
ttl: u32,
rdlength: u16,
rdata: String,
}
#[derive(Debug)]
struct Response {
field: Type,
}
const MAX_DATAGRAM_SIZE: usize = 40_96; const MAX_DATAGRAM_SIZE: usize = 40_96;
impl TryFrom<u16> for Type {
type Error = (); //TODO: user better error
fn try_from(value: u16) -> Result<Self, Self::Error> {
match value {
x if x == Type::A as u16 => Ok(Type::A),
_ => Err(()),
}
}
}
impl TryFrom<u16> for Class {
type Error = (); //TODO: user better error
fn try_from(value: u16) -> Result<Self, Self::Error> {
match value {
x if x == Class::IN as u16 => Ok(Class::IN),
_ => Err(()),
}
}
}
// TODO: use Error instead of Option
trait FromBytes {
fn from_bytes(bytes: &[u8]) -> Option<Self>
where
Self: Sized;
}
impl FromBytes for Header {
fn from_bytes(bytes: &[u8]) -> Option<Self> {
if bytes.len() != size_of::<Header>() {
return None; // Size of header should match
}
Some(Header {
id: u16::from_be_bytes(bytes[0..2].try_into().unwrap()),
flags: u16::from_be_bytes(bytes[2..4].try_into().unwrap()),
qdcount: u16::from_be_bytes(bytes[4..6].try_into().unwrap()),
ancount: u16::from_be_bytes(bytes[6..8].try_into().unwrap()),
nscount: u16::from_be_bytes(bytes[8..10].try_into().unwrap()),
arcount: u16::from_be_bytes(bytes[10..12].try_into().unwrap()),
})
}
}
//HACK: lots of unsafe unwrap
impl FromBytes for Question {
fn from_bytes(bytes: &[u8]) -> Option<Self> {
// 16 for length octet + zero length octet
if bytes.len() < 2 + size_of::<Class>() + size_of::<Type>() {
None
} else {
let mut qname = vec![];
let mut i = 0;
// Parse qname labels
while bytes[i] != 0 && bytes[i] as usize + i < bytes.len() {
qname.push(
String::from_utf8(bytes[i + 1..bytes[i] as usize + 1 + i].to_vec()).unwrap(),
);
i += bytes[i] as usize + 1;
}
i += 1;
if bytes.len() - i < size_of::<Class>() + size_of::<Type>() {
None
} else {
//Try Parse qtype
let qtype = Type::try_from(u16::from_be_bytes(bytes[i..i + 2].try_into().unwrap()))
.unwrap();
//Try Parse qclass
let qclass =
Class::try_from(u16::from_be_bytes(bytes[i + 2..i + 4].try_into().unwrap()))
.unwrap();
Some(Question {
qname,
qtype,
qclass,
})
}
}
}
}
impl FromBytes for Message {
fn from_bytes(bytes: &[u8]) -> Option<Self> {
let header = Header::from_bytes(&bytes[0..12]);
let question = Question::from_bytes(&bytes[12..]);
if header.is_some() {
Some(Message {
header,
question,
answer: None,
authority: None,
additional: None,
})
} else {
None
}
}
}
async fn create_query(message: Message) { async fn create_query(message: Message) {
println!("{:?}", message); println!("{:?}", message);
} }
@ -177,11 +24,11 @@ async fn main() -> Result<(), Box<dyn Error>> {
let mut data = vec![0u8; MAX_DATAGRAM_SIZE]; let mut data = vec![0u8; MAX_DATAGRAM_SIZE];
loop { loop {
let len = socket.recv(&mut data).await?; let len = socket.recv(&mut data).await?;
let message = Message::from_bytes(&data[..len]); match Message::from_bytes(&data[..len]) {
if message.is_some() { Ok(message) => {
tokio::spawn(async move { tokio::spawn(async move { create_query(message).await });
create_query(message.unwrap()).await;
});
} }
Err(err) => println!("{}", err),
};
} }
} }

124
src/parser.rs Normal file
View file

@ -0,0 +1,124 @@
use std::mem::size_of;
use crate::{
errors::ParseError,
structs::{Class, Header, Message, Question, Type},
};
type Result<T> = std::result::Result<T, ParseError>;
impl TryFrom<u16> for Type {
type Error = (); //TODO: user better error
fn try_from(value: u16) -> std::result::Result<Self, ()> {
match value {
x if x == Type::A as u16 => Ok(Type::A),
_ => Err(()),
}
}
}
impl TryFrom<u16> for Class {
type Error = (); //TODO: user better error
fn try_from(value: u16) -> std::result::Result<Self, ()> {
match value {
x if x == Class::IN as u16 => Ok(Class::IN),
_ => Err(()),
}
}
}
// TODO: use Error instead of Option
pub trait FromBytes {
fn from_bytes(bytes: &[u8]) -> Result<Self>
where
Self: Sized;
}
impl FromBytes for Header {
fn from_bytes(bytes: &[u8]) -> Result<Self> {
if bytes.len() != size_of::<Header>() {
Err(ParseError {
object: String::from("Header"),
message: String::from("Size of Header does not match"),
})
} else {
Ok(Header {
id: u16::from_be_bytes(bytes[0..2].try_into().unwrap()),
flags: u16::from_be_bytes(bytes[2..4].try_into().unwrap()),
qdcount: u16::from_be_bytes(bytes[4..6].try_into().unwrap()),
ancount: u16::from_be_bytes(bytes[6..8].try_into().unwrap()),
nscount: u16::from_be_bytes(bytes[8..10].try_into().unwrap()),
arcount: u16::from_be_bytes(bytes[10..12].try_into().unwrap()),
})
}
}
}
//HACK: lots of unsafe unwrap
impl FromBytes for Question {
fn from_bytes(bytes: &[u8]) -> Result<Self> {
// 16 for length octet + zero length octet
if bytes.len() < 2 + size_of::<Class>() + size_of::<Type>() {
Err(ParseError {
object: String::from("Question"),
message: String::from("len of bytes smaller then minimum size"),
})
} else {
let mut qname = vec![];
let mut i = 0;
// Parse qname labels
while bytes[i] != 0 && bytes[i] as usize + i < bytes.len() {
qname.push(
String::from_utf8(bytes[i + 1..bytes[i] as usize + 1 + i].to_vec()).unwrap(),
);
i += bytes[i] as usize + 1;
}
i += 1;
if bytes.len() - i < size_of::<Class>() + size_of::<Type>() {
Err(ParseError {
object: String::from("Question"),
message: String::from("len of rest bytes smaller then minimum size"),
})
} else {
//Try Parse qtype
let qtype = Type::try_from(u16::from_be_bytes(bytes[i..i + 2].try_into().unwrap()))
.map_err(|_| ParseError {
object: String::from("Type"),
message: String::from("invalid"),
})?;
//Try Parse qclass
let qclass =
Class::try_from(u16::from_be_bytes(bytes[i + 2..i + 4].try_into().unwrap()))
.map_err(|_| ParseError {
object: String::from("Class"),
message: String::from("invalid"),
})?;
Ok(Question {
qname,
qtype,
qclass,
})
}
}
}
}
impl FromBytes for Message {
fn from_bytes(bytes: &[u8]) -> Result<Self> {
let header = Header::from_bytes(&bytes[0..12])?;
let question = Question::from_bytes(&bytes[12..])?;
Ok(Message {
header,
question,
answer: None,
authority: None,
additional: None,
})
}
}

53
src/structs.rs Normal file
View file

@ -0,0 +1,53 @@
#[repr(u16)]
#[derive(Debug)]
pub enum Type {
A = 1,
}
#[repr(u16)]
#[derive(Debug)]
pub enum Class {
IN = 1,
}
#[derive(Debug)]
pub struct Question {
pub qname: Vec<String>, // TODO: not padded
pub qtype: Type, // NOTE: should be QTYPE, right now not really needed
pub qclass: Class,
}
#[derive(Debug)]
pub struct Header {
pub id: u16,
pub flags: u16, // |QR| Opcode |AA|TC|RD|RA| Z | RCODE | ; 1 | 4 | 1 | 1 | 1 | 1 | 3 | 4
pub qdcount: u16,
pub ancount: u16,
pub nscount: u16,
pub arcount: u16,
}
#[derive(Debug)]
pub struct Message {
pub header: Header,
pub question: Question,
pub answer: Option<RR>,
pub authority: Option<RR>,
pub additional: Option<RR>,
}
#[derive(Debug)]
pub struct RR {
name: String,
t: u16,
class: u16,
ttl: u32,
rdlength: u16,
rdata: String,
}
#[derive(Debug)]
pub struct Response {
field: Type,
}