10
0
Fork 0
mirror of https://github.com/ZeusWPI/ZNS.git synced 2024-11-23 22:11:10 +01:00

refactor dns parser to use a reader

This commit is contained in:
Xander Bil 2024-06-06 23:21:53 +02:00
parent 640aa93be2
commit 925370314a
No known key found for this signature in database
GPG key ID: EC9706B54A278598
6 changed files with 178 additions and 123 deletions

View file

@ -44,3 +44,26 @@ impl fmt::Display for AuthenticationError {
write!(f, "Authentication Error: {}", self.message) write!(f, "Authentication Error: {}", self.message)
} }
} }
#[derive(Debug)]
pub struct ReaderError {
pub message: String,
}
impl fmt::Display for ReaderError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Reader Error: {}", self.message)
}
}
impl<E> From<E> for ParseError
where
E: Into<ReaderError>,
{
fn from(value: E) -> Self {
ParseError {
object: String::from("Reader"),
message: value.into().to_string(),
}
}
}

View file

@ -4,22 +4,21 @@ use dotenvy::dotenv;
use crate::resolver::resolver_listener_loop; use crate::resolver::resolver_listener_loop;
mod authenticate;
mod db; mod db;
mod errors; mod errors;
mod parser; mod parser;
mod reader;
mod resolver; mod resolver;
mod sig;
mod structs; mod structs;
mod utils; mod utils;
mod sig;
mod authenticate;
#[tokio::main] #[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> { async fn main() -> Result<(), Box<dyn Error>> {
dotenv().ok(); dotenv().ok();
let resolver_add = SocketAddr::from(([127, 0, 0, 1], 8080)); let resolver_add = SocketAddr::from(([127, 0, 0, 1], 8080));
let _ = tokio::join!( let _ = tokio::join!(resolver_listener_loop(resolver_add),);
resolver_listener_loop(resolver_add),
);
Ok(()) Ok(())
} }

View file

@ -2,9 +2,9 @@ use std::{mem::size_of, vec};
use crate::{ use crate::{
errors::ParseError, errors::ParseError,
reader::Reader,
structs::{ structs::{
Class, Header, KeyRData, LabelString, Message, Opcode, Question, RRClass, RRType, Class, Header, KeyRData, LabelString, Message, Opcode, Question, RRClass, RRType, Type, RR,
Type, RR,
}, },
}; };
@ -77,7 +77,7 @@ impl TryFrom<u16> for Opcode {
} }
pub trait FromBytes { pub trait FromBytes {
fn from_bytes(bytes: &[u8], i: &mut usize) -> Result<Self> fn from_bytes(reader: &mut Reader) -> Result<Self>
where where
Self: Sized; Self: Sized;
fn to_bytes(s: Self) -> Vec<u8> fn to_bytes(s: Self) -> Vec<u8>
@ -85,60 +85,21 @@ pub trait FromBytes {
Self: Sized; Self: Sized;
} }
impl Type {
pub fn to_data(&self, text: &String) -> Result<Vec<u8>> {
match self {
Type::Type(RRType::A) => {
let arr: Vec<u8> = text
.split(".")
.filter_map(|s| s.parse::<u8>().ok())
.collect();
if arr.len() == 4 {
Ok(arr)
} else {
Err(ParseError {
object: String::from("Type::A"),
message: String::from("Invalid IPv4 address"),
})
}
}
_ => todo!(),
}
}
pub fn from_data(&self, bytes: &[u8]) -> Result<String> {
match self {
Type::Type(RRType::A) => {
if bytes.len() == 4 {
let arr: Vec<String> = bytes.iter().map(|b| b.to_string()).collect();
Ok(arr.join("."))
} else {
Err(ParseError {
object: String::from("Type::A"),
message: String::from("Invalid Ipv4 address bytes"),
})
}
}
_ => todo!(),
}
}
}
impl FromBytes for Header { impl FromBytes for Header {
fn from_bytes(bytes: &[u8], i: &mut usize) -> Result<Self> { fn from_bytes(reader: &mut Reader) -> Result<Self> {
if bytes.len() < size_of::<Header>() { if reader.unread_bytes() < size_of::<Header>() {
Err(ParseError { Err(ParseError {
object: String::from("Header"), object: String::from("Header"),
message: String::from("Size of Header does not match"), message: String::from("Size of Header does not match"),
}) })
} else { } else {
*i += size_of::<Header>();
Ok(Header { Ok(Header {
id: u16::from_be_bytes(bytes[0..2].try_into().unwrap()), id: reader.read_u16()?,
flags: u16::from_be_bytes(bytes[2..4].try_into().unwrap()), flags: reader.read_u16()?,
qdcount: u16::from_be_bytes(bytes[4..6].try_into().unwrap()), qdcount: reader.read_u16()?,
ancount: u16::from_be_bytes(bytes[6..8].try_into().unwrap()), ancount: reader.read_u16()?,
nscount: u16::from_be_bytes(bytes[8..10].try_into().unwrap()), nscount: reader.read_u16()?,
arcount: u16::from_be_bytes(bytes[10..12].try_into().unwrap()), arcount: reader.read_u16()?,
}) })
} }
} }
@ -158,34 +119,29 @@ impl FromBytes for Header {
} }
impl FromBytes for LabelString { impl FromBytes for LabelString {
fn from_bytes(bytes: &[u8], i: &mut usize) -> Result<Self> { fn from_bytes(reader: &mut Reader) -> Result<Self> {
let mut out = vec![]; let mut out = vec![];
// Parse qname labels // Parse qname labels
while bytes[*i] != 0 let mut code = reader.read_u8()?;
&& (bytes[*i] & 0b11000000 == 0) while code != 0 && (code & 0b11000000 == 0) && reader.unread_bytes() > code as usize {
&& bytes[*i] as usize + *i < bytes.len()
{
out.push( out.push(
String::from_utf8(bytes[*i + 1..bytes[*i] as usize + 1 + *i].to_vec()).unwrap(), String::from_utf8(reader.read(code as usize)?.to_vec()).map_err(|e| {
ParseError {
object: String::from("Label"),
message: e.to_string(),
}
})?,
); );
*i += bytes[*i] as usize + 1; code = reader.read_u8()?;
} }
if bytes[*i] & 0b11000000 != 0 { if code & 0b11000000 != 0 {
let offset = u16::from_be_bytes(bytes[*i..*i + 2].try_into().unwrap()) & 0b0011111111111111; let offset = (((code & 0b00111111) as u16) << 8) | reader.read_u8()? as u16;
if *i <= offset as usize { let mut reader_past = reader.seek(offset as usize)?;
return Err(ParseError { out.extend(LabelString::from_bytes(&mut reader_past)?);
object: String::from("Label"),
message: String::from("Invalid PTR"),
});
} else {
out.extend(LabelString::from_bytes(bytes, &mut (offset as usize))?);
*i += 1;
}
} }
*i += 1;
Ok(out) Ok(out)
} }
@ -201,31 +157,27 @@ impl FromBytes for LabelString {
} }
impl FromBytes for Question { impl FromBytes for Question {
fn from_bytes(bytes: &[u8], i: &mut usize) -> Result<Self> { fn from_bytes(reader: &mut Reader) -> Result<Self> {
// 16 for length octet + zero length octet // 16 for length octet + zero length octet
if bytes.len() < 2 + size_of::<Class>() + size_of::<Type>() { if reader.unread_bytes() < 2 + size_of::<Class>() + size_of::<Type>() {
Err(ParseError { Err(ParseError {
object: String::from("Question"), object: String::from("Question"),
message: String::from("len of bytes smaller then minimum size"), message: String::from("len of bytes smaller then minimum size"),
}) })
} else { } else {
let qname = LabelString::from_bytes(bytes, i)?; let qname = LabelString::from_bytes(reader)?;
if bytes.len() - *i < size_of::<Class>() + size_of::<Type>() { if reader.unread_bytes() < 4 {
Err(ParseError { Err(ParseError {
object: String::from("Question"), object: String::from("Question"),
message: String::from("len of rest bytes smaller then minimum size"), message: String::from("len of rest bytes smaller then minimum size"),
}) })
} else { } else {
//Try Parse qtype //Try Parse qtype
let qtype = Type::from(u16::from_be_bytes(bytes[*i..*i + 2].try_into().unwrap())); let qtype = Type::from(reader.read_u16()?);
//Try Parse qclass //Try Parse qclass
let qclass = Class::from(u16::from_be_bytes( let qclass = Class::from(reader.read_u16()?);
bytes[*i + 2..*i + 4].try_into().unwrap(),
));
*i += 4; // For qtype and qclass => 4 bytes
Ok(Question { Ok(Question {
qname, qname,
@ -245,37 +197,31 @@ impl FromBytes for Question {
} }
impl FromBytes for RR { impl FromBytes for RR {
fn from_bytes(bytes: &[u8], i: &mut usize) -> Result<Self> { fn from_bytes(reader: &mut Reader) -> Result<Self> {
let name = LabelString::from_bytes(bytes, i)?; let name = LabelString::from_bytes(reader)?;
if bytes.len() - *i < size_of::<Type>() + size_of::<Class>() + 6 { if reader.unread_bytes() < size_of::<Type>() + size_of::<Class>() + 6 {
Err(ParseError { Err(ParseError {
object: String::from("RR"), object: String::from("RR"),
message: String::from("len of rest of bytes smaller then minimum size"), message: String::from("len of rest of bytes smaller then minimum size"),
}) })
} else { } else {
let _type = Type::from(u16::from_be_bytes(bytes[*i..*i + 2].try_into().unwrap())); let _type = Type::from(reader.read_u16()?);
let class = Class::from(reader.read_u16()?);
let class = Class::from(u16::from_be_bytes( let ttl = reader.read_i32()?;
bytes[*i + 2..*i + 4].try_into().unwrap(), let rdlength = reader.read_u16()?;
)); if reader.unread_bytes() < rdlength as usize {
let ttl = i32::from_be_bytes(bytes[*i + 4..*i + 8].try_into().unwrap());
let rdlength = u16::from_be_bytes(bytes[*i + 8..*i + 10].try_into().unwrap());
if bytes.len() - *i - 10 < rdlength as usize {
Err(ParseError { Err(ParseError {
object: String::from("RR"), object: String::from("RR"),
message: String::from("len of rest of bytes not equal to rdlength"), message: String::from("len of rest of bytes not equal to rdlength"),
}) })
} else { } else {
*i += 10 + rdlength as usize;
Ok(RR { Ok(RR {
name, name,
_type, _type,
class, class,
ttl, ttl,
rdlength, rdlength,
rdata: bytes[*i - rdlength as usize..*i].to_vec(), rdata: reader.read(rdlength as usize)?,
}) })
} }
} }
@ -293,27 +239,27 @@ impl FromBytes for RR {
} }
impl FromBytes for Message { impl FromBytes for Message {
fn from_bytes(bytes: &[u8], i: &mut usize) -> Result<Self> { fn from_bytes(reader: &mut Reader) -> Result<Self> {
let header = Header::from_bytes(&bytes, i)?; let header = Header::from_bytes(reader)?;
let mut question = vec![]; let mut question = vec![];
for _ in 0..header.qdcount { for _ in 0..header.qdcount {
question.push(Question::from_bytes(&bytes, i)?); question.push(Question::from_bytes(reader)?);
} }
let mut answer = vec![]; let mut answer = vec![];
for _ in 0..header.ancount { for _ in 0..header.ancount {
answer.push(RR::from_bytes(&bytes, i)?); answer.push(RR::from_bytes(reader)?);
} }
let mut authority = vec![]; let mut authority = vec![];
for _ in 0..header.nscount { for _ in 0..header.nscount {
authority.push(RR::from_bytes(&bytes, i)?); authority.push(RR::from_bytes(reader)?);
} }
let mut additional = vec![]; let mut additional = vec![];
for _ in 0..header.arcount { for _ in 0..header.arcount {
additional.push(RR::from_bytes(&bytes, i)?); additional.push(RR::from_bytes(reader)?);
} }
Ok(Message { Ok(Message {
@ -346,24 +292,23 @@ impl FromBytes for Message {
} }
impl FromBytes for KeyRData { impl FromBytes for KeyRData {
fn from_bytes(bytes: &[u8], i: &mut usize) -> Result<Self> { fn from_bytes(reader: &mut Reader) -> Result<Self> {
if bytes.len() < 18 { if reader.unread_bytes() < 18 {
Err(ParseError { Err(ParseError {
object: String::from("KeyRData"), object: String::from("KeyRData"),
message: String::from("invalid rdata"), message: String::from("invalid rdata"),
}) })
} else { } else {
*i = 18;
Ok(KeyRData { Ok(KeyRData {
type_covered: u16::from_be_bytes(bytes[0..2].try_into().unwrap()), type_covered: reader.read_u16()?,
algo: bytes[2], algo: reader.read_u8()?,
labels: bytes[3], labels: reader.read_u8()?,
original_ttl: u32::from_be_bytes(bytes[4..8].try_into().unwrap()), original_ttl: reader.read_u32()?,
signature_expiration: u32::from_be_bytes(bytes[8..12].try_into().unwrap()), signature_expiration: reader.read_u32()?,
signature_inception: u32::from_be_bytes(bytes[12..16].try_into().unwrap()), signature_inception: reader.read_u32()?,
key_tag: u16::from_be_bytes(bytes[16..18].try_into().unwrap()), key_tag: reader.read_u16()?,
signer: LabelString::from_bytes(bytes, i)?, signer: LabelString::from_bytes(reader)?,
signature: bytes[*i..bytes.len()].to_vec(), signature: reader.read(reader.unread_bytes())?,
}) })
} }
} }

85
src/reader.rs Normal file
View file

@ -0,0 +1,85 @@
use std::array::TryFromSliceError;
use crate::errors::ReaderError;
pub struct Reader<'a> {
buffer: &'a [u8],
position: usize,
}
type Result<T> = std::result::Result<T, ReaderError>;
impl<'a> Reader<'a> {
pub fn new(buffer: &[u8]) -> Reader {
Reader {
buffer,
position: 0,
}
}
pub fn unread_bytes(&self) -> usize {
self.buffer.len() - self.position
}
pub fn read(&mut self, size: usize) -> Result<Vec<u8>> {
if size > self.unread_bytes() {
Err(ReaderError {
message: String::from("cannot read enough bytes"),
})
} else {
self.position += size;
Ok(self.buffer[self.position - size..self.position].to_vec())
}
}
pub fn read_u8(&mut self) -> Result<u8> {
self.position += 1;
Ok(self.buffer[self.position - 1])
}
pub fn read_u16(&mut self) -> Result<u16> {
let result = u16::from_be_bytes(
self.buffer[self.position..self.position + 2]
.try_into()
.map_err(|e: TryFromSliceError| ReaderError {
message: e.to_string(),
})?,
);
self.position += 2;
Ok(result)
}
pub fn read_i32(&mut self) -> Result<i32> {
let result = i32::from_be_bytes(
self.buffer[self.position..self.position + 4]
.try_into()
.map_err(|e: TryFromSliceError| ReaderError {
message: e.to_string(),
})?,
);
self.position += 4;
Ok(result)
}
pub fn read_u32(&mut self) -> Result<u32> {
let result = u32::from_be_bytes(
self.buffer[self.position..self.position + 4]
.try_into()
.map_err(|e: TryFromSliceError| ReaderError {
message: e.to_string(),
})?,
);
self.position += 4;
Ok(result)
}
pub fn seek(&self, position: usize) -> Result<Self> {
if position >= self.position {
Err(ReaderError {
message: String::from("Seeking into the future is not allowed!!"),
})
} else {
Ok(Reader::new(&self.buffer[position..self.position]))
}
}
}

View file

@ -8,6 +8,7 @@ use crate::authenticate::authenticate;
use crate::db::models::{delete_from_database, get_from_database, insert_into_database}; use crate::db::models::{delete_from_database, get_from_database, insert_into_database};
use crate::errors::ParseError; use crate::errors::ParseError;
use crate::parser::FromBytes; use crate::parser::FromBytes;
use crate::reader::Reader;
use crate::sig::Sig; use crate::sig::Sig;
use crate::structs::{Class, Header, Message, Opcode, RRClass, RRType, Type, RCODE}; use crate::structs::{Class, Header, Message, Opcode, RRClass, RRType, Type, RCODE};
use crate::utils::vec_equal; use crate::utils::vec_equal;
@ -140,7 +141,8 @@ async fn handle_update(message: Message, bytes: &[u8]) -> Message {
fn handle_parse_error(bytes: &[u8], err: ParseError) -> Message { fn handle_parse_error(bytes: &[u8], err: ParseError) -> Message {
eprintln!("{}", err); eprintln!("{}", err);
let mut header = Header::from_bytes(bytes, &mut 0).unwrap_or(Header { let mut reader = Reader::new(bytes);
let mut header = Header::from_bytes(&mut reader).unwrap_or(Header {
id: 0, id: 0,
flags: 0, flags: 0,
qdcount: 0, qdcount: 0,
@ -165,8 +167,8 @@ fn handle_parse_error(bytes: &[u8], err: ParseError) -> Message {
} }
async fn get_response(bytes: &[u8]) -> Message { async fn get_response(bytes: &[u8]) -> Message {
let mut i: usize = 0; let mut reader = Reader::new(bytes);
match Message::from_bytes(bytes, &mut i) { match Message::from_bytes(&mut reader) {
Ok(message) => match get_opcode(&message.header.flags) { Ok(message) => match get_opcode(&message.header.flags) {
Ok(opcode) => match opcode { Ok(opcode) => match opcode {
Opcode::QUERY => handle_query(message).await, Opcode::QUERY => handle_query(message).await,

View file

@ -2,6 +2,7 @@ use base64::prelude::*;
use crate::{ use crate::{
parser::FromBytes, parser::FromBytes,
reader::Reader,
structs::{KeyRData, RR}, structs::{KeyRData, RR},
}; };
@ -19,10 +20,10 @@ impl Sig {
let mut request = datagram[0..datagram.len() - 11 - rr.rdlength as usize].to_vec(); let mut request = datagram[0..datagram.len() - 11 - rr.rdlength as usize].to_vec();
request[11] -= 1; // Decrease arcount request[11] -= 1; // Decrease arcount
let mut i = 0; let mut reader = Reader::new(&rr.rdata);
let key_rdata = KeyRData::from_bytes(&rr.rdata, &mut i).unwrap(); let key_rdata = KeyRData::from_bytes(&mut reader).unwrap();
let mut raw_data = rr.rdata[0..i].to_vec(); let mut raw_data = rr.rdata[0..rr.rdata.len() - key_rdata.signature.len()].to_vec();
raw_data.extend(request); raw_data.extend(request);
Sig { Sig {