From 79b040f5f9ca8bc5e6dcc95660974f9171918e99 Mon Sep 17 00:00:00 2001 From: Xander Bil Date: Mon, 25 Mar 2024 23:06:04 +0100 Subject: [PATCH] Handle parser errors --- src/resolver.rs | 41 ++++++++++++++++++++++++++++++----------- 1 file changed, 30 insertions(+), 11 deletions(-) diff --git a/src/resolver.rs b/src/resolver.rs index bfaa4da..823535d 100644 --- a/src/resolver.rs +++ b/src/resolver.rs @@ -5,8 +5,9 @@ use std::sync::Arc; use tokio::net::UdpSocket; use crate::db::models::{get_from_database, insert_into_database}; +use crate::errors::ParseError; use crate::parser::FromBytes; -use crate::structs::{Class, Message, Type, RCODE, RR, Opcode}; +use crate::structs::{Class, Header, Message, Opcode, Type, RCODE}; use crate::utils::vec_equal; const MAX_DATAGRAM_SIZE: usize = 4096; @@ -15,7 +16,6 @@ fn set_response_flags(flags: u16, rcode: RCODE) -> u16 { (flags | 0b1_0000_1_0_0_0_000_0000 | rcode as u16) & 0b1_1111_1_0_1_0_111_1111 } - fn get_opcode(flags: &u16) -> Result { Opcode::try_from((flags & 0b0111100000000000) >> 11) } @@ -55,10 +55,7 @@ async fn handle_update(message: Message) -> Message { // Check Zone authority let zone = &message.question[0]; let zlen = zone.qname.len(); - if !(zlen >= 2 - && zone.qname[zlen - 1] == "gent" - && zone.qname[zlen - 2] == "zeus") - { + if !(zlen >= 2 && zone.qname[zlen - 1] == "gent" && zone.qname[zlen - 2] == "zeus") { response.header.flags = set_response_flags(response.header.flags, RCODE::NOTAUTH); return response; } @@ -89,7 +86,6 @@ async fn handle_update(message: Message) -> Message { response.header.flags = set_response_flags(response.header.flags, RCODE::FORMERR); return response; } - } for rr in message.authority { @@ -108,6 +104,32 @@ async fn handle_update(message: Message) -> Message { response } +fn handle_parse_error(bytes: &[u8], err: ParseError) -> Message { + eprintln!("{}", err); + let mut header = Header::from_bytes(bytes, &mut 0).unwrap_or(Header { + id: 0, + flags: 0, + qdcount: 0, + ancount: 0, + nscount: 0, + arcount: 0, + }); + + header.qdcount = 0; + header.ancount = 0; + header.nscount = 0; + header.arcount = 0; + header.flags = set_response_flags(header.flags, RCODE::FORMERR); + + Message { + header, + question: vec![], + answer: vec![], + authority: vec![], + additional: vec![], + } +} + async fn get_response(bytes: &[u8]) -> Message { let mut i: usize = 0; match Message::from_bytes(bytes, &mut i) { @@ -118,10 +140,7 @@ async fn get_response(bytes: &[u8]) -> Message { }, Err(_) => todo!(), }, - Err(err) => { - println!("{}", err); - unimplemented!() //TODO: implement this - } + Err(err) => handle_parse_error(bytes, err), } }