10
0
Fork 0
mirror of https://github.com/ZeusWPI/ZNS.git synced 2024-10-30 05:24:26 +01:00

pass all bytes to parsers

This commit is contained in:
Xander Bil 2024-03-17 23:20:20 +01:00
parent 5cd801a6d0
commit 6dd3f23815
No known key found for this signature in database
GPG key ID: EC9706B54A278598
3 changed files with 91 additions and 38 deletions

View file

@ -2,7 +2,7 @@ use std::{mem::size_of, vec};
use crate::{ use crate::{
errors::ParseError, errors::ParseError,
structs::{Class, Header, LabelString, Message, Question, Type, RR}, structs::{Class, Header, LabelString, Message, OptRR, Question, Type, RR},
}; };
type Result<T> = std::result::Result<T, ParseError>; type Result<T> = std::result::Result<T, ParseError>;
@ -30,7 +30,7 @@ impl TryFrom<u16> for Class {
} }
pub trait FromBytes { pub trait FromBytes {
fn from_bytes(bytes: &[u8]) -> Result<Self> fn from_bytes(bytes: &[u8], i: &mut usize) -> Result<Self>
where where
Self: Sized; Self: Sized;
fn to_bytes(s: Self) -> Vec<u8> fn to_bytes(s: Self) -> Vec<u8>
@ -38,6 +38,27 @@ pub trait FromBytes {
Self: Sized; Self: Sized;
} }
pub fn parse_opt_type(bytes: &[u8]) -> Result<Vec<OptRR>> {
let mut pairs: Vec<OptRR> = vec![];
let mut i: usize = 0;
while i + 4 <= bytes.len() {
let length = u16::from_be_bytes(bytes[i + 2..i + 4].try_into().unwrap());
pairs.push(OptRR {
code: u16::from_be_bytes(bytes[i..i + 2].try_into().unwrap()),
length,
rdata: bytes[i + 4..i + 4 + length as usize]
.try_into()
.map_err(|_| ParseError {
object: String::from("Type::OPT"),
message: String::from("Invalid OPT DATA"),
})?,
});
i += 4 + length as usize;
}
Ok(pairs)
}
impl Type { impl Type {
pub fn to_data(&self, text: &String) -> Result<Vec<u8>> { pub fn to_data(&self, text: &String) -> Result<Vec<u8>> {
match self { match self {
@ -55,6 +76,7 @@ impl Type {
}) })
} }
} }
Type::OPT => todo!(),
} }
} }
pub fn from_data(&self, bytes: &[u8]) -> Result<String> { pub fn from_data(&self, bytes: &[u8]) -> Result<String> {
@ -70,18 +92,20 @@ impl Type {
}) })
} }
} }
Type::OPT => unimplemented!()
} }
} }
} }
impl FromBytes for Header { impl FromBytes for Header {
fn from_bytes(bytes: &[u8]) -> Result<Self> { fn from_bytes(bytes: &[u8], i: &mut usize) -> Result<Self> {
if bytes.len() != size_of::<Header>() { if bytes.len() < size_of::<Header>() {
Err(ParseError { Err(ParseError {
object: String::from("Header"), object: String::from("Header"),
message: String::from("Size of Header does not match"), message: String::from("Size of Header does not match"),
}) })
} else { } else {
*i += size_of::<Header>();
Ok(Header { Ok(Header {
id: u16::from_be_bytes(bytes[0..2].try_into().unwrap()), id: u16::from_be_bytes(bytes[0..2].try_into().unwrap()),
flags: u16::from_be_bytes(bytes[2..4].try_into().unwrap()), flags: u16::from_be_bytes(bytes[2..4].try_into().unwrap()),
@ -91,6 +115,7 @@ impl FromBytes for Header {
arcount: u16::from_be_bytes(bytes[10..12].try_into().unwrap()), arcount: u16::from_be_bytes(bytes[10..12].try_into().unwrap()),
}) })
} }
} }
fn to_bytes(header: Self) -> Vec<u8> { fn to_bytes(header: Self) -> Vec<u8> {
@ -105,25 +130,25 @@ impl FromBytes for Header {
result.to_vec() result.to_vec()
} }
} }
impl FromBytes for LabelString { impl FromBytes for LabelString {
fn from_bytes(bytes: &[u8]) -> Result<Self> { fn from_bytes(bytes: &[u8], i: &mut usize) -> Result<Self> {
let mut i = 0;
let mut qname = vec![]; let mut qname = vec![];
// Parse qname labels // Parse qname labels
while bytes[i] != 0 && bytes[i] as usize + i < bytes.len() { while bytes[*i] != 0 && bytes[*i] as usize + *i < bytes.len() {
qname qname
.push(String::from_utf8(bytes[i + 1..bytes[i] as usize + 1 + i].to_vec()).unwrap()); .push(String::from_utf8(bytes[*i + 1..bytes[*i] as usize + 1 + *i].to_vec()).unwrap());
i += bytes[i] as usize + 1; *i += bytes[*i] as usize + 1;
} }
i += 1; *i += 1;
Ok((qname, i)) Ok(qname)
} }
fn to_bytes((name, _): Self) -> Vec<u8> { fn to_bytes(name: Self) -> Vec<u8> {
let mut result: Vec<u8> = vec![]; let mut result: Vec<u8> = vec![];
for label in name { for label in name {
result.push(label.len() as u8); result.push(label.len() as u8);
@ -135,7 +160,7 @@ impl FromBytes for LabelString {
} }
impl FromBytes for Question { impl FromBytes for Question {
fn from_bytes(bytes: &[u8]) -> Result<Self> { fn from_bytes(bytes: &[u8], i: &mut usize) -> Result<Self> {
// 16 for length octet + zero length octet // 16 for length octet + zero length octet
if bytes.len() < 2 + size_of::<Class>() + size_of::<Type>() { if bytes.len() < 2 + size_of::<Class>() + size_of::<Type>() {
Err(ParseError { Err(ParseError {
@ -143,16 +168,16 @@ impl FromBytes for Question {
message: String::from("len of bytes smaller then minimum size"), message: String::from("len of bytes smaller then minimum size"),
}) })
} else { } else {
let (qname, i) = LabelString::from_bytes(bytes)?; let qname = LabelString::from_bytes(bytes, i)?;
if bytes.len() - i < size_of::<Class>() + size_of::<Type>() { if bytes.len() - *i < size_of::<Class>() + size_of::<Type>() {
Err(ParseError { Err(ParseError {
object: String::from("Question"), object: String::from("Question"),
message: String::from("len of rest bytes smaller then minimum size"), message: String::from("len of rest bytes smaller then minimum size"),
}) })
} else { } else {
//Try Parse qtype //Try Parse qtype
let qtype = Type::try_from(u16::from_be_bytes(bytes[i..i + 2].try_into().unwrap())) let qtype = Type::try_from(u16::from_be_bytes(bytes[*i..*i + 2].try_into().unwrap()))
.map_err(|_| ParseError { .map_err(|_| ParseError {
object: String::from("Type"), object: String::from("Type"),
message: String::from("invalid"), message: String::from("invalid"),
@ -160,12 +185,14 @@ impl FromBytes for Question {
//Try Parse qclass //Try Parse qclass
let qclass = let qclass =
Class::try_from(u16::from_be_bytes(bytes[i + 2..i + 4].try_into().unwrap())) Class::try_from(u16::from_be_bytes(bytes[*i + 2..*i + 4].try_into().unwrap()))
.map_err(|_| ParseError { .map_err(|_| ParseError {
object: String::from("Class"), object: String::from("Class"),
message: String::from("invalid"), message: String::from("invalid"),
})?; })?;
*i += 4; // For qtype and qclass => 4 bytes
Ok(Question { Ok(Question {
qname, qname,
qtype, qtype,
@ -176,7 +203,7 @@ impl FromBytes for Question {
} }
fn to_bytes(question: Self) -> Vec<u8> { fn to_bytes(question: Self) -> Vec<u8> {
let mut result = LabelString::to_bytes((question.qname, 0)); let mut result = LabelString::to_bytes(question.qname);
result.extend(u16::to_be_bytes(question.qtype.to_owned() as u16)); result.extend(u16::to_be_bytes(question.qtype.to_owned() as u16));
result.extend(u16::to_be_bytes(question.qclass.to_owned() as u16)); result.extend(u16::to_be_bytes(question.qclass.to_owned() as u16));
result result
@ -184,50 +211,51 @@ impl FromBytes for Question {
} }
impl FromBytes for RR { impl FromBytes for RR {
fn from_bytes(bytes: &[u8]) -> Result<Self> { fn from_bytes(bytes: &[u8], i: &mut usize) -> Result<Self> {
let (name, i) = LabelString::from_bytes(bytes)?; let name = LabelString::from_bytes(bytes, i)?;
if bytes.len() - i < size_of::<Type>() + size_of::<Class>() + 6 { if bytes.len() - *i < size_of::<Type>() + size_of::<Class>() + 6 {
Err(ParseError { Err(ParseError {
object: String::from("RR"), object: String::from("RR"),
message: String::from("len of rest of bytes smaller then minimum size"), message: String::from("len of rest of bytes smaller then minimum size"),
}) })
} else { } else {
let _type = Type::try_from(u16::from_be_bytes(bytes[i..i + 2].try_into().unwrap())) let _type = Type::try_from(u16::from_be_bytes(bytes[*i..*i + 2].try_into().unwrap()))
.map_err(|_| ParseError { .map_err(|_| ParseError {
object: String::from("Type"), object: String::from("Type"),
message: String::from("invalid"), message: String::from("invalid"),
})?; })?;
let class = let class =
Class::try_from(u16::from_be_bytes(bytes[i + 2..i + 4].try_into().unwrap())) Class::try_from(u16::from_be_bytes(bytes[*i + 2..*i + 4].try_into().unwrap()))
.map_err(|_| ParseError { .map_err(|_| ParseError {
object: String::from("Class"), object: String::from("Class"),
message: String::from("invalid"), message: String::from("invalid"),
})?; })?;
let ttl = i32::from_be_bytes(bytes[i + 4..i + 8].try_into().unwrap()); let ttl = i32::from_be_bytes(bytes[*i + 4..*i + 8].try_into().unwrap());
let rdlength = u16::from_be_bytes(bytes[i + 8..i + 10].try_into().unwrap()); let rdlength = u16::from_be_bytes(bytes[*i + 8..*i + 10].try_into().unwrap());
if bytes.len() - i - 10 != rdlength as usize { if bytes.len() - *i - 10 != rdlength as usize {
Err(ParseError { Err(ParseError {
object: String::from("RR"), object: String::from("RR"),
message: String::from("len of rest of bytes not equal to rdlength"), message: String::from("len of rest of bytes not equal to rdlength"),
}) })
} else { } else {
*i += 10 + rdlength as usize;
Ok(RR { Ok(RR {
name, name,
_type, _type,
class, class,
ttl, ttl,
rdlength, rdlength,
rdata: bytes[i + 10..].to_vec(), rdata: bytes[*i - rdlength as usize.. *i].to_vec(),
}) })
} }
} }
} }
fn to_bytes(rr: Self) -> Vec<u8> { fn to_bytes(rr: Self) -> Vec<u8> {
let mut result = LabelString::to_bytes((rr.name, 0)); let mut result = LabelString::to_bytes(rr.name);
result.extend(u16::to_be_bytes(rr._type.to_owned() as u16)); result.extend(u16::to_be_bytes(rr._type.to_owned() as u16));
result.extend(u16::to_be_bytes(rr.class.to_owned() as u16)); result.extend(u16::to_be_bytes(rr.class.to_owned() as u16));
result.extend(i32::to_be_bytes(rr.ttl.to_owned())); result.extend(i32::to_be_bytes(rr.ttl.to_owned()));
@ -238,9 +266,10 @@ impl FromBytes for RR {
} }
impl FromBytes for Message { impl FromBytes for Message {
fn from_bytes(bytes: &[u8]) -> Result<Self> { fn from_bytes(bytes: &[u8], i: &mut usize) -> Result<Self> {
let header = Header::from_bytes(&bytes[0..12])?; let header = Header::from_bytes(&bytes,i)?;
let question = Question::from_bytes(&bytes[12..])?; let question = Question::from_bytes(&bytes,i)?;
Ok(Message { Ok(Message {
header, header,
question, question,

View file

@ -5,14 +5,16 @@ use std::sync::Arc;
use tokio::net::UdpSocket; use tokio::net::UdpSocket;
use crate::db::models::get_from_database; use crate::db::models::get_from_database;
use crate::parser::FromBytes; use crate::parser::{parse_opt_type, FromBytes};
use crate::structs::Message; use crate::structs::{Message, Type, RR};
const MAX_DATAGRAM_SIZE: usize = 40_96; const MAX_DATAGRAM_SIZE: usize = 4096;
const OPTION_CODE: usize = 65001;
async fn create_query(message: Message) -> Message { async fn handle_normal_question(message: Message) -> Message {
let mut response = message.clone(); let mut response = message.clone();
println!("{:#?}",message.question);
let answer = get_from_database(message.question).await; let answer = get_from_database(message.question).await;
response.header.arcount = 0; response.header.arcount = 0;
@ -31,16 +33,30 @@ async fn create_query(message: Message) -> Message {
response response
} }
async fn handle_opt_rr(rr: RR) {
let pairs = parse_opt_type(&rr.rdata);
println!("{:#?}", pairs)
}
async fn get_response(message: Message) -> Message {
match message.question.qtype {
Type::OPT => handle_normal_question(message),
_ => handle_normal_question(message),
}
.await
}
pub async fn resolver_listener_loop(addr: SocketAddr) -> Result<(), Box<dyn Error>> { pub async fn resolver_listener_loop(addr: SocketAddr) -> Result<(), Box<dyn Error>> {
let socket_shared = Arc::new(UdpSocket::bind(addr).await?); let socket_shared = Arc::new(UdpSocket::bind(addr).await?);
loop { loop {
let mut data = vec![0u8; MAX_DATAGRAM_SIZE]; let mut data = vec![0u8; MAX_DATAGRAM_SIZE];
let (len, addr) = socket_shared.recv_from(&mut data).await?; let (len, addr) = socket_shared.recv_from(&mut data).await?;
match Message::from_bytes(&data[..len]) { let mut i: usize = 0;
match Message::from_bytes(&data[..len], &mut i) {
Ok(message) => { Ok(message) => {
let socket = socket_shared.clone(); let socket = socket_shared.clone();
tokio::spawn(async move { tokio::spawn(async move {
let response = create_query(message).await; let response = get_response(message).await;
let _ = socket let _ = socket
.send_to(Message::to_bytes(response).as_slice(), addr) .send_to(Message::to_bytes(response).as_slice(), addr)
.await; .await;

View file

@ -4,6 +4,7 @@ use serde::Deserialize;
#[derive(Debug, Clone, Deserialize)] #[derive(Debug, Clone, Deserialize)]
pub enum Type { pub enum Type {
A = 1, A = 1,
OPT = 41
} }
#[repr(u16)] #[repr(u16)]
@ -48,7 +49,14 @@ pub struct RR {
pub rdata: Vec<u8>, pub rdata: Vec<u8>,
} }
pub type LabelString = (Vec<String>, usize); #[derive(Debug, Clone)]
pub struct OptRR {
pub code: u16,
pub length: u16,
pub rdata: Vec<u8>
}
pub type LabelString = Vec<String>;
#[derive(Debug)] #[derive(Debug)]
pub struct Response { pub struct Response {