mirror of
https://github.com/ZeusWPI/ZNS.git
synced 2024-10-30 05:24:26 +01:00
implement message compression and update works
This commit is contained in:
parent
68ce89c5a2
commit
9f8ecac4f0
3 changed files with 87 additions and 23 deletions
|
@ -2,7 +2,7 @@ use std::{mem::size_of, vec};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
errors::ParseError,
|
errors::ParseError,
|
||||||
structs::{Class, Header, LabelString, Message, OptRR, Question, Type, RR},
|
structs::{Class, Header, LabelString, Message, Opcode, OptRR, Question, Type, RR},
|
||||||
};
|
};
|
||||||
|
|
||||||
type Result<T> = std::result::Result<T, ParseError>;
|
type Result<T> = std::result::Result<T, ParseError>;
|
||||||
|
@ -13,6 +13,8 @@ impl TryFrom<u16> for Type {
|
||||||
fn try_from(value: u16) -> std::result::Result<Self, String> {
|
fn try_from(value: u16) -> std::result::Result<Self, String> {
|
||||||
match value {
|
match value {
|
||||||
x if x == Type::A as u16 => Ok(Type::A),
|
x if x == Type::A as u16 => Ok(Type::A),
|
||||||
|
x if x == Type::OPT as u16 => Ok(Type::OPT),
|
||||||
|
x if x == Type::SOA as u16 => Ok(Type::SOA),
|
||||||
_ => Err(format!("Invalid Type value: {}", value)),
|
_ => Err(format!("Invalid Type value: {}", value)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -29,6 +31,18 @@ impl TryFrom<u16> for Class {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl TryFrom<u16> for Opcode {
|
||||||
|
type Error = String;
|
||||||
|
|
||||||
|
fn try_from(value: u16) -> std::result::Result<Self, String> {
|
||||||
|
match value {
|
||||||
|
x if x == Opcode::QUERY as u16 => Ok(Opcode::QUERY),
|
||||||
|
x if x == Opcode::UPDATE as u16 => Ok(Opcode::UPDATE),
|
||||||
|
_ => Err(format!("Invalid Opcode value: {}", value)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub trait FromBytes {
|
pub trait FromBytes {
|
||||||
fn from_bytes(bytes: &[u8], i: &mut usize) -> Result<Self>
|
fn from_bytes(bytes: &[u8], i: &mut usize) -> Result<Self>
|
||||||
where
|
where
|
||||||
|
@ -77,6 +91,7 @@ impl Type {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Type::SOA => todo!(),
|
Type::SOA => todo!(),
|
||||||
|
Type::OPT => todo!(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
pub fn from_data(&self, bytes: &[u8]) -> Result<String> {
|
pub fn from_data(&self, bytes: &[u8]) -> Result<String> {
|
||||||
|
@ -93,6 +108,7 @@ impl Type {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Type::SOA => todo!(),
|
Type::SOA => todo!(),
|
||||||
|
Type::OPT => todo!(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -136,13 +152,29 @@ impl FromBytes for LabelString {
|
||||||
let mut qname = vec![];
|
let mut qname = vec![];
|
||||||
|
|
||||||
// Parse qname labels
|
// Parse qname labels
|
||||||
while bytes[*i] != 0 && bytes[*i] as usize + *i < bytes.len() {
|
while bytes[*i] != 0
|
||||||
|
&& (bytes[*i] & 0b11000000 == 0)
|
||||||
|
&& bytes[*i] as usize + *i < bytes.len()
|
||||||
|
{
|
||||||
qname.push(
|
qname.push(
|
||||||
String::from_utf8(bytes[*i + 1..bytes[*i] as usize + 1 + *i].to_vec()).unwrap(),
|
String::from_utf8(bytes[*i + 1..bytes[*i] as usize + 1 + *i].to_vec()).unwrap(),
|
||||||
);
|
);
|
||||||
*i += bytes[*i] as usize + 1;
|
*i += bytes[*i] as usize + 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if bytes[*i] & 0b11000000 != 0 {
|
||||||
|
let offset = u16::from_be_bytes(bytes[*i..*i + 2].try_into().unwrap()) & 0b00111111;
|
||||||
|
if *i <= offset as usize {
|
||||||
|
return Err(ParseError {
|
||||||
|
object: String::from("Label"),
|
||||||
|
message: String::from("Invalid PTR"),
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
qname.extend(LabelString::from_bytes(bytes, &mut (offset as usize))?);
|
||||||
|
*i += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
*i += 1;
|
*i += 1;
|
||||||
Ok(qname)
|
Ok(qname)
|
||||||
}
|
}
|
||||||
|
@ -178,18 +210,18 @@ impl FromBytes for Question {
|
||||||
//Try Parse qtype
|
//Try Parse qtype
|
||||||
let qtype =
|
let qtype =
|
||||||
Type::try_from(u16::from_be_bytes(bytes[*i..*i + 2].try_into().unwrap()))
|
Type::try_from(u16::from_be_bytes(bytes[*i..*i + 2].try_into().unwrap()))
|
||||||
.map_err(|_| ParseError {
|
.map_err(|e| ParseError {
|
||||||
object: String::from("Type"),
|
object: String::from("Type"),
|
||||||
message: String::from("invalid"),
|
message: e,
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
//Try Parse qclass
|
//Try Parse qclass
|
||||||
let qclass = Class::try_from(u16::from_be_bytes(
|
let qclass = Class::try_from(u16::from_be_bytes(
|
||||||
bytes[*i + 2..*i + 4].try_into().unwrap(),
|
bytes[*i + 2..*i + 4].try_into().unwrap(),
|
||||||
))
|
))
|
||||||
.map_err(|_| ParseError {
|
.map_err(|e| ParseError {
|
||||||
object: String::from("Class"),
|
object: String::from("Class"),
|
||||||
message: String::from("invalid"),
|
message: e,
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
*i += 4; // For qtype and qclass => 4 bytes
|
*i += 4; // For qtype and qclass => 4 bytes
|
||||||
|
@ -214,6 +246,7 @@ impl FromBytes for Question {
|
||||||
impl FromBytes for RR {
|
impl FromBytes for RR {
|
||||||
fn from_bytes(bytes: &[u8], i: &mut usize) -> Result<Self> {
|
fn from_bytes(bytes: &[u8], i: &mut usize) -> Result<Self> {
|
||||||
let name = LabelString::from_bytes(bytes, i)?;
|
let name = LabelString::from_bytes(bytes, i)?;
|
||||||
|
println!("{:#?}", name);
|
||||||
if bytes.len() - *i < size_of::<Type>() + size_of::<Class>() + 6 {
|
if bytes.len() - *i < size_of::<Type>() + size_of::<Class>() + 6 {
|
||||||
Err(ParseError {
|
Err(ParseError {
|
||||||
object: String::from("RR"),
|
object: String::from("RR"),
|
||||||
|
@ -221,17 +254,17 @@ impl FromBytes for RR {
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
let _type = Type::try_from(u16::from_be_bytes(bytes[*i..*i + 2].try_into().unwrap()))
|
let _type = Type::try_from(u16::from_be_bytes(bytes[*i..*i + 2].try_into().unwrap()))
|
||||||
.map_err(|_| ParseError {
|
.map_err(|e| ParseError {
|
||||||
object: String::from("Type"),
|
object: String::from("Type"),
|
||||||
message: String::from("invalid"),
|
message: e,
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
let class = Class::try_from(u16::from_be_bytes(
|
let class = Class::try_from(u16::from_be_bytes(
|
||||||
bytes[*i + 2..*i + 4].try_into().unwrap(),
|
bytes[*i + 2..*i + 4].try_into().unwrap(),
|
||||||
))
|
))
|
||||||
.map_err(|_| ParseError {
|
.map_err(|e| ParseError {
|
||||||
object: String::from("Class"),
|
object: String::from("Class"),
|
||||||
message: String::from("invalid"),
|
message: e,
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
let ttl = i32::from_be_bytes(bytes[*i + 4..*i + 8].try_into().unwrap());
|
let ttl = i32::from_be_bytes(bytes[*i + 4..*i + 8].try_into().unwrap());
|
||||||
|
@ -287,7 +320,7 @@ impl FromBytes for Message {
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut additional = vec![];
|
let mut additional = vec![];
|
||||||
for _ in 0..header.nscount {
|
for _ in 0..header.arcount {
|
||||||
additional.push(RR::from_bytes(&bytes, i)?);
|
additional.push(RR::from_bytes(&bytes, i)?);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -4,19 +4,25 @@ use std::sync::Arc;
|
||||||
|
|
||||||
use tokio::net::UdpSocket;
|
use tokio::net::UdpSocket;
|
||||||
|
|
||||||
use crate::db::models::get_from_database;
|
use crate::db::models::{get_from_database, insert_into_database};
|
||||||
use crate::parser::FromBytes;
|
use crate::parser::FromBytes;
|
||||||
use crate::structs::{Class, Message, Type, RCODE};
|
use crate::structs::{Class, Message, Type, RCODE, RR, Opcode};
|
||||||
use crate::utils::vec_equal;
|
use crate::utils::vec_equal;
|
||||||
|
|
||||||
const MAX_DATAGRAM_SIZE: usize = 4096;
|
const MAX_DATAGRAM_SIZE: usize = 4096;
|
||||||
|
|
||||||
fn set_response_flags(flags: u16, rcode: RCODE) -> u16 {
|
fn set_response_flags(flags: u16, rcode: RCODE) -> u16 {
|
||||||
(flags | 0b1000010000000000 | rcode as u16) & 0b1_1111_1_0_1_0_111_1111
|
(flags | 0b1_0000_1_0_0_0_000_0000 | rcode as u16) & 0b1_1111_1_0_1_0_111_1111
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
fn get_opcode(flags: &u16) -> Result<Opcode, String> {
|
||||||
|
Opcode::try_from((flags & 0b0111100000000000) >> 11)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_query(message: Message) -> Message {
|
async fn handle_query(message: Message) -> Message {
|
||||||
let mut response = message.clone();
|
let mut response = message.clone();
|
||||||
|
response.header.arcount = 0; //TODO: fix this, handle unknown class values
|
||||||
|
|
||||||
for question in message.question {
|
for question in message.question {
|
||||||
let answer = get_from_database(&question).await;
|
let answer = get_from_database(&question).await;
|
||||||
|
@ -28,7 +34,6 @@ async fn handle_query(message: Message) -> Message {
|
||||||
response.answer = vec![rr]
|
response.answer = vec![rr]
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
response.header.flags |= 0b1000010110000011;
|
|
||||||
response.header.flags = set_response_flags(response.header.flags, RCODE::NXDOMAIN);
|
response.header.flags = set_response_flags(response.header.flags, RCODE::NXDOMAIN);
|
||||||
eprintln!("{}", e);
|
eprintln!("{}", e);
|
||||||
}
|
}
|
||||||
|
@ -48,10 +53,11 @@ async fn handle_update(message: Message) -> Message {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check Zone authority
|
// Check Zone authority
|
||||||
let zlen = message.question[0].qname.len();
|
let zone = &message.question[0];
|
||||||
|
let zlen = zone.qname.len();
|
||||||
if !(zlen >= 2
|
if !(zlen >= 2
|
||||||
&& message.question[0].qname[zlen - 1] == "gent"
|
&& zone.qname[zlen - 1] == "gent"
|
||||||
&& message.question[0].qname[zlen - 2] == "zeus")
|
&& zone.qname[zlen - 2] == "zeus")
|
||||||
{
|
{
|
||||||
response.header.flags = set_response_flags(response.header.flags, RCODE::NOTAUTH);
|
response.header.flags = set_response_flags(response.header.flags, RCODE::NOTAUTH);
|
||||||
return response;
|
return response;
|
||||||
|
@ -67,18 +73,18 @@ async fn handle_update(message: Message) -> Message {
|
||||||
// TODO: implement this, use rfc2931
|
// TODO: implement this, use rfc2931
|
||||||
|
|
||||||
// Update Section Prescan
|
// Update Section Prescan
|
||||||
for rr in message.authority {
|
for rr in &message.authority {
|
||||||
let rlen = rr.name.len();
|
let rlen = rr.name.len();
|
||||||
|
|
||||||
// Check if rr has same zone
|
// Check if rr has same zone
|
||||||
if rlen < zlen || !(vec_equal(&message.question[0].qname, &rr.name[rlen - zlen..])) {
|
if rlen < zlen || !(vec_equal(&zone.qname, &rr.name[rlen - zlen..])) {
|
||||||
response.header.flags = set_response_flags(response.header.flags, RCODE::NOTZONE);
|
response.header.flags = set_response_flags(response.header.flags, RCODE::NOTZONE);
|
||||||
return response;
|
return response;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (rr.class == Class::ANY && (rr.ttl != 0 || rr.rdlength != 0))
|
if (rr.class == Class::ANY && (rr.ttl != 0 || rr.rdlength != 0))
|
||||||
|| (rr.class == Class::NONE && rr.ttl != 0)
|
|| (rr.class == Class::NONE && rr.ttl != 0)
|
||||||
|| rr.class != message.question[0].qclass
|
|| rr.class != zone.qclass
|
||||||
{
|
{
|
||||||
response.header.flags = set_response_flags(response.header.flags, RCODE::FORMERR);
|
response.header.flags = set_response_flags(response.header.flags, RCODE::FORMERR);
|
||||||
return response;
|
return response;
|
||||||
|
@ -86,13 +92,32 @@ async fn handle_update(message: Message) -> Message {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for rr in message.authority {
|
||||||
|
if rr.class == zone.qclass {
|
||||||
|
insert_into_database(rr).await;
|
||||||
|
} else if rr.class == Class::ANY {
|
||||||
|
response.header.flags = set_response_flags(response.header.flags, RCODE::NOTIMP);
|
||||||
|
return response;
|
||||||
|
} else if rr.class == Class::ANY {
|
||||||
|
response.header.flags = set_response_flags(response.header.flags, RCODE::NOTIMP);
|
||||||
|
return response;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
response.header.flags = set_response_flags(response.header.flags, RCODE::NOERROR);
|
||||||
response
|
response
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_response(bytes: &[u8]) -> Message {
|
async fn get_response(bytes: &[u8]) -> Message {
|
||||||
let mut i: usize = 0;
|
let mut i: usize = 0;
|
||||||
match Message::from_bytes(bytes, &mut i) {
|
match Message::from_bytes(bytes, &mut i) {
|
||||||
Ok(message) => handle_query(message).await,
|
Ok(message) => match get_opcode(&message.header.flags) {
|
||||||
|
Ok(opcode) => match opcode {
|
||||||
|
Opcode::QUERY => handle_query(message).await,
|
||||||
|
Opcode::UPDATE => handle_update(message).await,
|
||||||
|
},
|
||||||
|
Err(_) => todo!(),
|
||||||
|
},
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
println!("{}", err);
|
println!("{}", err);
|
||||||
unimplemented!() //TODO: implement this
|
unimplemented!() //TODO: implement this
|
||||||
|
|
|
@ -4,7 +4,8 @@ use serde::Deserialize;
|
||||||
#[derive(Debug, Clone, Deserialize)]
|
#[derive(Debug, Clone, Deserialize)]
|
||||||
pub enum Type {
|
pub enum Type {
|
||||||
A = 1,
|
A = 1,
|
||||||
SOA = 6
|
SOA = 6,
|
||||||
|
OPT = 41
|
||||||
}
|
}
|
||||||
|
|
||||||
#[repr(u16)]
|
#[repr(u16)]
|
||||||
|
@ -30,6 +31,11 @@ pub enum RCODE {
|
||||||
NOTZONE = 10
|
NOTZONE = 10
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub enum Opcode {
|
||||||
|
QUERY = 0,
|
||||||
|
UPDATE = 5
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct Question {
|
pub struct Question {
|
||||||
pub qname: Vec<String>,
|
pub qname: Vec<String>,
|
||||||
|
|
Loading…
Reference in a new issue