10
0
Fork 0
mirror of https://github.com/ZeusWPI/ZNS.git synced 2024-11-22 05:41:11 +01:00

implement message compression and update works

This commit is contained in:
Xander Bil 2024-03-25 22:40:39 +01:00
parent 68ce89c5a2
commit 9f8ecac4f0
No known key found for this signature in database
GPG key ID: EC9706B54A278598
3 changed files with 87 additions and 23 deletions

View file

@ -2,7 +2,7 @@ use std::{mem::size_of, vec};
use crate::{ use crate::{
errors::ParseError, errors::ParseError,
structs::{Class, Header, LabelString, Message, OptRR, Question, Type, RR}, structs::{Class, Header, LabelString, Message, Opcode, OptRR, Question, Type, RR},
}; };
type Result<T> = std::result::Result<T, ParseError>; type Result<T> = std::result::Result<T, ParseError>;
@ -13,6 +13,8 @@ impl TryFrom<u16> for Type {
fn try_from(value: u16) -> std::result::Result<Self, String> { fn try_from(value: u16) -> std::result::Result<Self, String> {
match value { match value {
x if x == Type::A as u16 => Ok(Type::A), x if x == Type::A as u16 => Ok(Type::A),
x if x == Type::OPT as u16 => Ok(Type::OPT),
x if x == Type::SOA as u16 => Ok(Type::SOA),
_ => Err(format!("Invalid Type value: {}", value)), _ => Err(format!("Invalid Type value: {}", value)),
} }
} }
@ -29,6 +31,18 @@ impl TryFrom<u16> for Class {
} }
} }
impl TryFrom<u16> for Opcode {
type Error = String;
fn try_from(value: u16) -> std::result::Result<Self, String> {
match value {
x if x == Opcode::QUERY as u16 => Ok(Opcode::QUERY),
x if x == Opcode::UPDATE as u16 => Ok(Opcode::UPDATE),
_ => Err(format!("Invalid Opcode value: {}", value)),
}
}
}
pub trait FromBytes { pub trait FromBytes {
fn from_bytes(bytes: &[u8], i: &mut usize) -> Result<Self> fn from_bytes(bytes: &[u8], i: &mut usize) -> Result<Self>
where where
@ -77,6 +91,7 @@ impl Type {
} }
} }
Type::SOA => todo!(), Type::SOA => todo!(),
Type::OPT => todo!(),
} }
} }
pub fn from_data(&self, bytes: &[u8]) -> Result<String> { pub fn from_data(&self, bytes: &[u8]) -> Result<String> {
@ -93,6 +108,7 @@ impl Type {
} }
} }
Type::SOA => todo!(), Type::SOA => todo!(),
Type::OPT => todo!(),
} }
} }
} }
@ -136,13 +152,29 @@ impl FromBytes for LabelString {
let mut qname = vec![]; let mut qname = vec![];
// Parse qname labels // Parse qname labels
while bytes[*i] != 0 && bytes[*i] as usize + *i < bytes.len() { while bytes[*i] != 0
&& (bytes[*i] & 0b11000000 == 0)
&& bytes[*i] as usize + *i < bytes.len()
{
qname.push( qname.push(
String::from_utf8(bytes[*i + 1..bytes[*i] as usize + 1 + *i].to_vec()).unwrap(), String::from_utf8(bytes[*i + 1..bytes[*i] as usize + 1 + *i].to_vec()).unwrap(),
); );
*i += bytes[*i] as usize + 1; *i += bytes[*i] as usize + 1;
} }
if bytes[*i] & 0b11000000 != 0 {
let offset = u16::from_be_bytes(bytes[*i..*i + 2].try_into().unwrap()) & 0b00111111;
if *i <= offset as usize {
return Err(ParseError {
object: String::from("Label"),
message: String::from("Invalid PTR"),
});
} else {
qname.extend(LabelString::from_bytes(bytes, &mut (offset as usize))?);
*i += 1;
}
}
*i += 1; *i += 1;
Ok(qname) Ok(qname)
} }
@ -178,18 +210,18 @@ impl FromBytes for Question {
//Try Parse qtype //Try Parse qtype
let qtype = let qtype =
Type::try_from(u16::from_be_bytes(bytes[*i..*i + 2].try_into().unwrap())) Type::try_from(u16::from_be_bytes(bytes[*i..*i + 2].try_into().unwrap()))
.map_err(|_| ParseError { .map_err(|e| ParseError {
object: String::from("Type"), object: String::from("Type"),
message: String::from("invalid"), message: e,
})?; })?;
//Try Parse qclass //Try Parse qclass
let qclass = Class::try_from(u16::from_be_bytes( let qclass = Class::try_from(u16::from_be_bytes(
bytes[*i + 2..*i + 4].try_into().unwrap(), bytes[*i + 2..*i + 4].try_into().unwrap(),
)) ))
.map_err(|_| ParseError { .map_err(|e| ParseError {
object: String::from("Class"), object: String::from("Class"),
message: String::from("invalid"), message: e,
})?; })?;
*i += 4; // For qtype and qclass => 4 bytes *i += 4; // For qtype and qclass => 4 bytes
@ -214,6 +246,7 @@ impl FromBytes for Question {
impl FromBytes for RR { impl FromBytes for RR {
fn from_bytes(bytes: &[u8], i: &mut usize) -> Result<Self> { fn from_bytes(bytes: &[u8], i: &mut usize) -> Result<Self> {
let name = LabelString::from_bytes(bytes, i)?; let name = LabelString::from_bytes(bytes, i)?;
println!("{:#?}", name);
if bytes.len() - *i < size_of::<Type>() + size_of::<Class>() + 6 { if bytes.len() - *i < size_of::<Type>() + size_of::<Class>() + 6 {
Err(ParseError { Err(ParseError {
object: String::from("RR"), object: String::from("RR"),
@ -221,17 +254,17 @@ impl FromBytes for RR {
}) })
} else { } else {
let _type = Type::try_from(u16::from_be_bytes(bytes[*i..*i + 2].try_into().unwrap())) let _type = Type::try_from(u16::from_be_bytes(bytes[*i..*i + 2].try_into().unwrap()))
.map_err(|_| ParseError { .map_err(|e| ParseError {
object: String::from("Type"), object: String::from("Type"),
message: String::from("invalid"), message: e,
})?; })?;
let class = Class::try_from(u16::from_be_bytes( let class = Class::try_from(u16::from_be_bytes(
bytes[*i + 2..*i + 4].try_into().unwrap(), bytes[*i + 2..*i + 4].try_into().unwrap(),
)) ))
.map_err(|_| ParseError { .map_err(|e| ParseError {
object: String::from("Class"), object: String::from("Class"),
message: String::from("invalid"), message: e,
})?; })?;
let ttl = i32::from_be_bytes(bytes[*i + 4..*i + 8].try_into().unwrap()); let ttl = i32::from_be_bytes(bytes[*i + 4..*i + 8].try_into().unwrap());
@ -287,7 +320,7 @@ impl FromBytes for Message {
} }
let mut additional = vec![]; let mut additional = vec![];
for _ in 0..header.nscount { for _ in 0..header.arcount {
additional.push(RR::from_bytes(&bytes, i)?); additional.push(RR::from_bytes(&bytes, i)?);
} }

View file

@ -4,19 +4,25 @@ use std::sync::Arc;
use tokio::net::UdpSocket; use tokio::net::UdpSocket;
use crate::db::models::get_from_database; use crate::db::models::{get_from_database, insert_into_database};
use crate::parser::FromBytes; use crate::parser::FromBytes;
use crate::structs::{Class, Message, Type, RCODE}; use crate::structs::{Class, Message, Type, RCODE, RR, Opcode};
use crate::utils::vec_equal; use crate::utils::vec_equal;
const MAX_DATAGRAM_SIZE: usize = 4096; const MAX_DATAGRAM_SIZE: usize = 4096;
fn set_response_flags(flags: u16, rcode: RCODE) -> u16 { fn set_response_flags(flags: u16, rcode: RCODE) -> u16 {
(flags | 0b1000010000000000 | rcode as u16) & 0b1_1111_1_0_1_0_111_1111 (flags | 0b1_0000_1_0_0_0_000_0000 | rcode as u16) & 0b1_1111_1_0_1_0_111_1111
}
fn get_opcode(flags: &u16) -> Result<Opcode, String> {
Opcode::try_from((flags & 0b0111100000000000) >> 11)
} }
async fn handle_query(message: Message) -> Message { async fn handle_query(message: Message) -> Message {
let mut response = message.clone(); let mut response = message.clone();
response.header.arcount = 0; //TODO: fix this, handle unknown class values
for question in message.question { for question in message.question {
let answer = get_from_database(&question).await; let answer = get_from_database(&question).await;
@ -28,7 +34,6 @@ async fn handle_query(message: Message) -> Message {
response.answer = vec![rr] response.answer = vec![rr]
} }
Err(e) => { Err(e) => {
response.header.flags |= 0b1000010110000011;
response.header.flags = set_response_flags(response.header.flags, RCODE::NXDOMAIN); response.header.flags = set_response_flags(response.header.flags, RCODE::NXDOMAIN);
eprintln!("{}", e); eprintln!("{}", e);
} }
@ -48,10 +53,11 @@ async fn handle_update(message: Message) -> Message {
} }
// Check Zone authority // Check Zone authority
let zlen = message.question[0].qname.len(); let zone = &message.question[0];
let zlen = zone.qname.len();
if !(zlen >= 2 if !(zlen >= 2
&& message.question[0].qname[zlen - 1] == "gent" && zone.qname[zlen - 1] == "gent"
&& message.question[0].qname[zlen - 2] == "zeus") && zone.qname[zlen - 2] == "zeus")
{ {
response.header.flags = set_response_flags(response.header.flags, RCODE::NOTAUTH); response.header.flags = set_response_flags(response.header.flags, RCODE::NOTAUTH);
return response; return response;
@ -67,18 +73,18 @@ async fn handle_update(message: Message) -> Message {
// TODO: implement this, use rfc2931 // TODO: implement this, use rfc2931
// Update Section Prescan // Update Section Prescan
for rr in message.authority { for rr in &message.authority {
let rlen = rr.name.len(); let rlen = rr.name.len();
// Check if rr has same zone // Check if rr has same zone
if rlen < zlen || !(vec_equal(&message.question[0].qname, &rr.name[rlen - zlen..])) { if rlen < zlen || !(vec_equal(&zone.qname, &rr.name[rlen - zlen..])) {
response.header.flags = set_response_flags(response.header.flags, RCODE::NOTZONE); response.header.flags = set_response_flags(response.header.flags, RCODE::NOTZONE);
return response; return response;
} }
if (rr.class == Class::ANY && (rr.ttl != 0 || rr.rdlength != 0)) if (rr.class == Class::ANY && (rr.ttl != 0 || rr.rdlength != 0))
|| (rr.class == Class::NONE && rr.ttl != 0) || (rr.class == Class::NONE && rr.ttl != 0)
|| rr.class != message.question[0].qclass || rr.class != zone.qclass
{ {
response.header.flags = set_response_flags(response.header.flags, RCODE::FORMERR); response.header.flags = set_response_flags(response.header.flags, RCODE::FORMERR);
return response; return response;
@ -86,13 +92,32 @@ async fn handle_update(message: Message) -> Message {
} }
for rr in message.authority {
if rr.class == zone.qclass {
insert_into_database(rr).await;
} else if rr.class == Class::ANY {
response.header.flags = set_response_flags(response.header.flags, RCODE::NOTIMP);
return response;
} else if rr.class == Class::ANY {
response.header.flags = set_response_flags(response.header.flags, RCODE::NOTIMP);
return response;
}
}
response.header.flags = set_response_flags(response.header.flags, RCODE::NOERROR);
response response
} }
async fn get_response(bytes: &[u8]) -> Message { async fn get_response(bytes: &[u8]) -> Message {
let mut i: usize = 0; let mut i: usize = 0;
match Message::from_bytes(bytes, &mut i) { match Message::from_bytes(bytes, &mut i) {
Ok(message) => handle_query(message).await, Ok(message) => match get_opcode(&message.header.flags) {
Ok(opcode) => match opcode {
Opcode::QUERY => handle_query(message).await,
Opcode::UPDATE => handle_update(message).await,
},
Err(_) => todo!(),
},
Err(err) => { Err(err) => {
println!("{}", err); println!("{}", err);
unimplemented!() //TODO: implement this unimplemented!() //TODO: implement this

View file

@ -4,7 +4,8 @@ use serde::Deserialize;
#[derive(Debug, Clone, Deserialize)] #[derive(Debug, Clone, Deserialize)]
pub enum Type { pub enum Type {
A = 1, A = 1,
SOA = 6 SOA = 6,
OPT = 41
} }
#[repr(u16)] #[repr(u16)]
@ -30,6 +31,11 @@ pub enum RCODE {
NOTZONE = 10 NOTZONE = 10
} }
pub enum Opcode {
QUERY = 0,
UPDATE = 5
}
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Question { pub struct Question {
pub qname: Vec<String>, pub qname: Vec<String>,