10
0
Fork 0
mirror of https://github.com/ZeusWPI/ZNS.git synced 2024-10-30 05:24:26 +01:00

Improve error handling

This commit is contained in:
Xander Bil 2024-06-07 00:26:44 +02:00
parent 925370314a
commit 83fedebed6
No known key found for this signature in database
GPG key ID: EC9706B54A278598
4 changed files with 109 additions and 56 deletions

View file

@ -81,15 +81,15 @@ impl Record {
} }
} }
pub async fn insert_into_database(rr: RR) -> Result<(), DatabaseError> { pub async fn insert_into_database(rr: &RR) -> Result<(), DatabaseError> {
let db_connection = &mut establish_connection(); let db_connection = &mut establish_connection();
let record = Record { let record = Record {
name: rr.name.join("."), name: rr.name.join("."),
_type: rr._type.into(), _type: rr._type.clone().into(),
class: rr.class.into(), class: rr.class.clone().into(),
ttl: rr.ttl, ttl: rr.ttl,
rdlength: rr.rdlength as i32, rdlength: rr.rdlength as i32,
rdata: rr.rdata, rdata: rr.rdata.clone(),
}; };
Record::create(db_connection, record).map_err(|e| DatabaseError { Record::create(db_connection, record).map_err(|e| DatabaseError {
@ -128,11 +128,17 @@ pub async fn get_from_database(question: &Question) -> Result<Vec<RR>, DatabaseE
//TODO: cleanup models //TODO: cleanup models
pub async fn delete_from_database( pub async fn delete_from_database(
name: Vec<String>, name: &Vec<String>,
_type: Option<Type>, _type: Option<Type>,
class: Class, class: Class,
rdata: Option<Vec<u8>>, rdata: Option<Vec<u8>>,
) { ) {
let db_connection = &mut establish_connection(); let db_connection = &mut establish_connection();
let _ = Record::delete(db_connection, name.join("."), _type.map(|f| f.into()), class.into(), rdata); let _ = Record::delete(
db_connection,
name.join("."),
_type.map(|f| f.into()),
class.into(),
rdata,
);
} }

View file

@ -1,8 +1,10 @@
use core::fmt; use core::fmt;
#[derive(Debug)] use crate::structs::RCODE;
pub struct DNSError { pub struct DNSError {
pub message: String, pub message: String,
pub rcode: RCODE
} }
impl fmt::Display for DNSError { impl fmt::Display for DNSError {
@ -25,7 +27,7 @@ impl fmt::Display for ParseError {
#[derive(Debug)] #[derive(Debug)]
pub struct DatabaseError { pub struct DatabaseError {
pub message: String, pub message: String
} }
impl fmt::Display for DatabaseError { impl fmt::Display for DatabaseError {
@ -67,3 +69,15 @@ where
} }
} }
} }
impl<E> From<E> for DNSError
where
E: Into<ParseError>,
{
fn from(value: E) -> Self {
DNSError {
message: value.into().to_string(),
rcode: RCODE::FORMERR
}
}
}

View file

@ -6,7 +6,7 @@ use tokio::net::UdpSocket;
use crate::authenticate::authenticate; use crate::authenticate::authenticate;
use crate::db::models::{delete_from_database, get_from_database, insert_into_database}; use crate::db::models::{delete_from_database, get_from_database, insert_into_database};
use crate::errors::ParseError; use crate::errors::{DNSError, ParseError};
use crate::parser::FromBytes; use crate::parser::FromBytes;
use crate::reader::Reader; use crate::reader::Reader;
use crate::sig::Sig; use crate::sig::Sig;
@ -15,7 +15,7 @@ use crate::utils::vec_equal;
const MAX_DATAGRAM_SIZE: usize = 4096; const MAX_DATAGRAM_SIZE: usize = 4096;
fn set_response_flags(flags: u16, rcode: RCODE) -> u16 { fn set_response_flags(flags: &u16, rcode: RCODE) -> u16 {
(flags | 0b1_0000_1_0_0_0_000_0000 | rcode as u16) & 0b1_1111_1_0_1_0_111_1111 (flags | 0b1_0000_1_0_0_0_000_0000 | rcode as u16) & 0b1_1111_1_0_1_0_111_1111
} }
@ -23,46 +23,50 @@ fn get_opcode(flags: &u16) -> Result<Opcode, String> {
Opcode::try_from((flags & 0b0111100000000000) >> 11) Opcode::try_from((flags & 0b0111100000000000) >> 11)
} }
async fn handle_query(message: Message) -> Message { async fn handle_query(message: &Message) -> Result<Message, DNSError> {
let mut response = message.clone(); let mut response = message.clone();
response.header.arcount = 0; //TODO: fix this, handle unknown class values response.header.arcount = 0; //TODO: fix this, handle unknown class values
for question in message.question { for question in &message.question {
let answers = get_from_database(&question).await; let answers = get_from_database(&question).await;
match answers { match answers {
Ok(rrs) => { Ok(rrs) => {
response.header.flags = set_response_flags(response.header.flags, RCODE::NOERROR);
response.header.ancount = rrs.len() as u16; response.header.ancount = rrs.len() as u16;
response.answer = rrs response.answer.extend(rrs)
} }
Err(e) => { Err(e) => {
response.header.flags = set_response_flags(response.header.flags, RCODE::NXDOMAIN); return Err(DNSError {
eprintln!("{}", e); rcode: RCODE::NXDOMAIN,
message: e.to_string(),
})
} }
} }
} }
response Ok(response)
} }
async fn handle_update(message: Message, bytes: &[u8]) -> Message { async fn handle_update(message: &Message, bytes: &[u8]) -> Result<Message, DNSError> {
let mut response = message.clone(); let response = message.clone();
// Zone section (question) processing // Zone section (question) processing
if (message.header.qdcount != 1) if (message.header.qdcount != 1)
|| !matches!(message.question[0].qtype, Type::Type(RRType::SOA)) || !matches!(message.question[0].qtype, Type::Type(RRType::SOA))
{ {
response.header.flags = set_response_flags(response.header.flags, RCODE::FORMERR); return Err(DNSError {
return response; message: "Qdcount not one".to_string(),
rcode: RCODE::FORMERR,
});
} }
// Check Zone authority // Check Zone authority
let zone = &message.question[0]; let zone = &message.question[0];
let zlen = zone.qname.len(); let zlen = zone.qname.len();
if !(zlen >= 2 && zone.qname[zlen - 1] == "gent" && zone.qname[zlen - 2] == "zeus") { if !(zlen >= 2 && zone.qname[zlen - 1] == "gent" && zone.qname[zlen - 2] == "zeus") {
response.header.flags = set_response_flags(response.header.flags, RCODE::NOTAUTH); return Err(DNSError {
return response; message: "Invalid zone".to_string(),
rcode: RCODE::NOTAUTH,
});
} }
// Check Prerequisite TODO: implement this // Check Prerequisite TODO: implement this
@ -70,15 +74,19 @@ async fn handle_update(message: Message, bytes: &[u8]) -> Message {
//TODO: this code is ugly //TODO: this code is ugly
let last = message.additional.last(); let last = message.additional.last();
if last.is_some() && last.unwrap()._type == Type::Type(RRType::KEY) { if last.is_some() && last.unwrap()._type == Type::Type(RRType::KEY) {
let sig = Sig::new(last.unwrap(), bytes); let sig = Sig::new(last.unwrap(), bytes)?;
if !authenticate(&sig, &zone.qname).await.is_ok_and(|x| x) { if !authenticate(&sig, &zone.qname).await.is_ok_and(|x| x) {
response.header.flags = set_response_flags(response.header.flags, RCODE::NOTAUTH); return Err(DNSError {
return response; message: "Unable to verify authentication".to_string(),
rcode: RCODE::NOTAUTH,
});
} }
} else { } else {
response.header.flags = set_response_flags(response.header.flags, RCODE::NOTAUTH); return Err(DNSError {
return response; message: "No KEY record at the end of request found".to_string(),
rcode: RCODE::NOTAUTH,
});
} }
// Update Section Prescan // Update Section Prescan
@ -87,8 +95,10 @@ async fn handle_update(message: Message, bytes: &[u8]) -> Message {
// Check if rr has same zone // Check if rr has same zone
if rlen < zlen || !(vec_equal(&zone.qname, &rr.name[rlen - zlen..])) { if rlen < zlen || !(vec_equal(&zone.qname, &rr.name[rlen - zlen..])) {
response.header.flags = set_response_flags(response.header.flags, RCODE::NOTZONE); return Err(DNSError {
return response; message: "RR has different zone from Question".to_string(),
rcode: RCODE::NOTZONE,
});
} }
if (rr.class == Class::Class(RRClass::ANY) && (rr.ttl != 0 || rr.rdlength != 0)) if (rr.class == Class::Class(RRClass::ANY) && (rr.ttl != 0 || rr.rdlength != 0))
@ -100,43 +110,50 @@ async fn handle_update(message: Message, bytes: &[u8]) -> Message {
] ]
.contains(&rr.class) .contains(&rr.class)
{ {
response.header.flags = set_response_flags(response.header.flags, RCODE::FORMERR); return Err(DNSError {
return response; message: "RR has invalid rr,ttl or class".to_string(),
rcode: RCODE::FORMERR,
});
} }
} }
for rr in message.authority { for rr in &message.authority {
if rr.class == zone.qclass { if rr.class == zone.qclass {
let _ = insert_into_database(rr).await; let _ = insert_into_database(&rr).await;
} else if rr.class == Class::Class(RRClass::ANY) { } else if rr.class == Class::Class(RRClass::ANY) {
if rr._type == Type::Type(RRType::ANY) { if rr._type == Type::Type(RRType::ANY) {
if rr.name == zone.qname { if rr.name == zone.qname {
response.header.flags = return Err(DNSError {
set_response_flags(response.header.flags, RCODE::NOTIMP); message: "Not yet implemented".to_string(),
return response; rcode: RCODE::NOTIMP,
});
} else { } else {
delete_from_database(rr.name, None, Class::Class(RRClass::IN), None).await; delete_from_database(&rr.name, None, Class::Class(RRClass::IN), None).await;
} }
} else { } else {
delete_from_database(rr.name, Some(rr._type), Class::Class(RRClass::IN), None) delete_from_database(
.await; &rr.name,
Some(rr._type.clone()),
Class::Class(RRClass::IN),
None,
)
.await;
} }
} else if rr.class == Class::Class(RRClass::NONE) { } else if rr.class == Class::Class(RRClass::NONE) {
if rr._type == Type::Type(RRType::SOA) { if rr._type == Type::Type(RRType::SOA) {
continue; continue;
} }
delete_from_database( delete_from_database(
rr.name, &rr.name,
Some(rr._type), Some(rr._type.clone()),
Class::Class(RRClass::IN), Class::Class(RRClass::IN),
Some(rr.rdata), Some(rr.rdata.clone()),
) )
.await; .await;
} }
} }
response.header.flags = set_response_flags(response.header.flags, RCODE::NOERROR); Ok(response)
response
} }
fn handle_parse_error(bytes: &[u8], err: ParseError) -> Message { fn handle_parse_error(bytes: &[u8], err: ParseError) -> Message {
@ -155,7 +172,7 @@ fn handle_parse_error(bytes: &[u8], err: ParseError) -> Message {
header.ancount = 0; header.ancount = 0;
header.nscount = 0; header.nscount = 0;
header.arcount = 0; header.arcount = 0;
header.flags = set_response_flags(header.flags, RCODE::FORMERR); header.flags = set_response_flags(&header.flags, RCODE::FORMERR);
Message { Message {
header, header,
@ -169,11 +186,26 @@ fn handle_parse_error(bytes: &[u8], err: ParseError) -> Message {
async fn get_response(bytes: &[u8]) -> Message { async fn get_response(bytes: &[u8]) -> Message {
let mut reader = Reader::new(bytes); let mut reader = Reader::new(bytes);
match Message::from_bytes(&mut reader) { match Message::from_bytes(&mut reader) {
Ok(message) => match get_opcode(&message.header.flags) { Ok(mut message) => match get_opcode(&message.header.flags) {
Ok(opcode) => match opcode { Ok(opcode) => {
Opcode::QUERY => handle_query(message).await, let result = match opcode {
Opcode::UPDATE => handle_update(message, bytes).await, Opcode::QUERY => handle_query(&message).await,
}, Opcode::UPDATE => handle_update(&message, bytes).await,
};
match result {
Ok(mut response) => {
response.header.flags =
set_response_flags(&response.header.flags, RCODE::NOERROR);
response
}
Err(e) => {
eprintln!("{}", e.to_string());
message.header.flags = set_response_flags(&message.header.flags, e.rcode);
message
}
}
}
Err(_) => todo!(), Err(_) => todo!(),
}, },
Err(err) => handle_parse_error(bytes, err), Err(err) => handle_parse_error(bytes, err),

View file

@ -1,6 +1,7 @@
use base64::prelude::*; use base64::prelude::*;
use crate::{ use crate::{
errors::ParseError,
parser::FromBytes, parser::FromBytes,
reader::Reader, reader::Reader,
structs::{KeyRData, RR}, structs::{KeyRData, RR},
@ -16,20 +17,20 @@ pub enum PublicKey {
} }
impl Sig { impl Sig {
pub fn new(rr: &RR, datagram: &[u8]) -> Sig { pub fn new(rr: &RR, datagram: &[u8]) -> Result<Sig, ParseError> {
let mut request = datagram[0..datagram.len() - 11 - rr.rdlength as usize].to_vec(); let mut request = datagram[0..datagram.len() - 11 - rr.rdlength as usize].to_vec();
request[11] -= 1; // Decrease arcount request[11] -= 1; // Decrease arcount
let mut reader = Reader::new(&rr.rdata); let mut reader = Reader::new(&rr.rdata);
let key_rdata = KeyRData::from_bytes(&mut reader).unwrap(); let key_rdata = KeyRData::from_bytes(&mut reader)?;
let mut raw_data = rr.rdata[0..rr.rdata.len() - key_rdata.signature.len()].to_vec(); let mut raw_data = rr.rdata[0..rr.rdata.len() - key_rdata.signature.len()].to_vec();
raw_data.extend(request); raw_data.extend(request);
Sig { Ok(Sig {
raw_data, raw_data,
key_rdata, key_rdata,
} })
} }
fn verify_ed25519(&self, key: String) -> bool { fn verify_ed25519(&self, key: String) -> bool {