diff --git a/src/parser.rs b/src/parser.rs index 89111aa..ef73d46 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -2,7 +2,7 @@ use std::{mem::size_of, vec}; use crate::{ errors::ParseError, - structs::{Class, Header, LabelString, Message, OptRR, Question, Type, RR}, + structs::{Class, Header, LabelString, Message, Opcode, OptRR, Question, Type, RR}, }; type Result = std::result::Result; @@ -13,6 +13,8 @@ impl TryFrom for Type { fn try_from(value: u16) -> std::result::Result { match value { x if x == Type::A as u16 => Ok(Type::A), + x if x == Type::OPT as u16 => Ok(Type::OPT), + x if x == Type::SOA as u16 => Ok(Type::SOA), _ => Err(format!("Invalid Type value: {}", value)), } } @@ -29,6 +31,18 @@ impl TryFrom for Class { } } +impl TryFrom for Opcode { + type Error = String; + + fn try_from(value: u16) -> std::result::Result { + match value { + x if x == Opcode::QUERY as u16 => Ok(Opcode::QUERY), + x if x == Opcode::UPDATE as u16 => Ok(Opcode::UPDATE), + _ => Err(format!("Invalid Opcode value: {}", value)), + } + } +} + pub trait FromBytes { fn from_bytes(bytes: &[u8], i: &mut usize) -> Result where @@ -77,6 +91,7 @@ impl Type { } } Type::SOA => todo!(), + Type::OPT => todo!(), } } pub fn from_data(&self, bytes: &[u8]) -> Result { @@ -93,6 +108,7 @@ impl Type { } } Type::SOA => todo!(), + Type::OPT => todo!(), } } } @@ -136,13 +152,29 @@ impl FromBytes for LabelString { let mut qname = vec![]; // Parse qname labels - while bytes[*i] != 0 && bytes[*i] as usize + *i < bytes.len() { + while bytes[*i] != 0 + && (bytes[*i] & 0b11000000 == 0) + && bytes[*i] as usize + *i < bytes.len() + { qname.push( String::from_utf8(bytes[*i + 1..bytes[*i] as usize + 1 + *i].to_vec()).unwrap(), ); *i += bytes[*i] as usize + 1; } + if bytes[*i] & 0b11000000 != 0 { + let offset = u16::from_be_bytes(bytes[*i..*i + 2].try_into().unwrap()) & 0b00111111; + if *i <= offset as usize { + return Err(ParseError { + object: String::from("Label"), + message: String::from("Invalid PTR"), + }); + } else { + qname.extend(LabelString::from_bytes(bytes, &mut (offset as usize))?); + *i += 1; + } + } + *i += 1; Ok(qname) } @@ -178,18 +210,18 @@ impl FromBytes for Question { //Try Parse qtype let qtype = Type::try_from(u16::from_be_bytes(bytes[*i..*i + 2].try_into().unwrap())) - .map_err(|_| ParseError { + .map_err(|e| ParseError { object: String::from("Type"), - message: String::from("invalid"), + message: e, })?; //Try Parse qclass let qclass = Class::try_from(u16::from_be_bytes( bytes[*i + 2..*i + 4].try_into().unwrap(), )) - .map_err(|_| ParseError { + .map_err(|e| ParseError { object: String::from("Class"), - message: String::from("invalid"), + message: e, })?; *i += 4; // For qtype and qclass => 4 bytes @@ -214,6 +246,7 @@ impl FromBytes for Question { impl FromBytes for RR { fn from_bytes(bytes: &[u8], i: &mut usize) -> Result { let name = LabelString::from_bytes(bytes, i)?; + println!("{:#?}", name); if bytes.len() - *i < size_of::() + size_of::() + 6 { Err(ParseError { object: String::from("RR"), @@ -221,17 +254,17 @@ impl FromBytes for RR { }) } else { let _type = Type::try_from(u16::from_be_bytes(bytes[*i..*i + 2].try_into().unwrap())) - .map_err(|_| ParseError { + .map_err(|e| ParseError { object: String::from("Type"), - message: String::from("invalid"), + message: e, })?; let class = Class::try_from(u16::from_be_bytes( bytes[*i + 2..*i + 4].try_into().unwrap(), )) - .map_err(|_| ParseError { + .map_err(|e| ParseError { object: String::from("Class"), - message: String::from("invalid"), + message: e, })?; let ttl = i32::from_be_bytes(bytes[*i + 4..*i + 8].try_into().unwrap()); @@ -287,7 +320,7 @@ impl FromBytes for Message { } let mut additional = vec![]; - for _ in 0..header.nscount { + for _ in 0..header.arcount { additional.push(RR::from_bytes(&bytes, i)?); } diff --git a/src/resolver.rs b/src/resolver.rs index 9e2d676..bfaa4da 100644 --- a/src/resolver.rs +++ b/src/resolver.rs @@ -4,19 +4,25 @@ use std::sync::Arc; use tokio::net::UdpSocket; -use crate::db::models::get_from_database; +use crate::db::models::{get_from_database, insert_into_database}; use crate::parser::FromBytes; -use crate::structs::{Class, Message, Type, RCODE}; +use crate::structs::{Class, Message, Type, RCODE, RR, Opcode}; use crate::utils::vec_equal; const MAX_DATAGRAM_SIZE: usize = 4096; fn set_response_flags(flags: u16, rcode: RCODE) -> u16 { - (flags | 0b1000010000000000 | rcode as u16) & 0b1_1111_1_0_1_0_111_1111 + (flags | 0b1_0000_1_0_0_0_000_0000 | rcode as u16) & 0b1_1111_1_0_1_0_111_1111 +} + + +fn get_opcode(flags: &u16) -> Result { + Opcode::try_from((flags & 0b0111100000000000) >> 11) } async fn handle_query(message: Message) -> Message { let mut response = message.clone(); + response.header.arcount = 0; //TODO: fix this, handle unknown class values for question in message.question { let answer = get_from_database(&question).await; @@ -28,7 +34,6 @@ async fn handle_query(message: Message) -> Message { response.answer = vec![rr] } Err(e) => { - response.header.flags |= 0b1000010110000011; response.header.flags = set_response_flags(response.header.flags, RCODE::NXDOMAIN); eprintln!("{}", e); } @@ -48,10 +53,11 @@ async fn handle_update(message: Message) -> Message { } // Check Zone authority - let zlen = message.question[0].qname.len(); + let zone = &message.question[0]; + let zlen = zone.qname.len(); if !(zlen >= 2 - && message.question[0].qname[zlen - 1] == "gent" - && message.question[0].qname[zlen - 2] == "zeus") + && zone.qname[zlen - 1] == "gent" + && zone.qname[zlen - 2] == "zeus") { response.header.flags = set_response_flags(response.header.flags, RCODE::NOTAUTH); return response; @@ -67,18 +73,18 @@ async fn handle_update(message: Message) -> Message { // TODO: implement this, use rfc2931 // Update Section Prescan - for rr in message.authority { + for rr in &message.authority { let rlen = rr.name.len(); // Check if rr has same zone - if rlen < zlen || !(vec_equal(&message.question[0].qname, &rr.name[rlen - zlen..])) { + if rlen < zlen || !(vec_equal(&zone.qname, &rr.name[rlen - zlen..])) { response.header.flags = set_response_flags(response.header.flags, RCODE::NOTZONE); return response; } if (rr.class == Class::ANY && (rr.ttl != 0 || rr.rdlength != 0)) || (rr.class == Class::NONE && rr.ttl != 0) - || rr.class != message.question[0].qclass + || rr.class != zone.qclass { response.header.flags = set_response_flags(response.header.flags, RCODE::FORMERR); return response; @@ -86,13 +92,32 @@ async fn handle_update(message: Message) -> Message { } + for rr in message.authority { + if rr.class == zone.qclass { + insert_into_database(rr).await; + } else if rr.class == Class::ANY { + response.header.flags = set_response_flags(response.header.flags, RCODE::NOTIMP); + return response; + } else if rr.class == Class::ANY { + response.header.flags = set_response_flags(response.header.flags, RCODE::NOTIMP); + return response; + } + } + + response.header.flags = set_response_flags(response.header.flags, RCODE::NOERROR); response } async fn get_response(bytes: &[u8]) -> Message { let mut i: usize = 0; match Message::from_bytes(bytes, &mut i) { - Ok(message) => handle_query(message).await, + Ok(message) => match get_opcode(&message.header.flags) { + Ok(opcode) => match opcode { + Opcode::QUERY => handle_query(message).await, + Opcode::UPDATE => handle_update(message).await, + }, + Err(_) => todo!(), + }, Err(err) => { println!("{}", err); unimplemented!() //TODO: implement this diff --git a/src/structs.rs b/src/structs.rs index f5ebf0d..3e85e15 100644 --- a/src/structs.rs +++ b/src/structs.rs @@ -4,7 +4,8 @@ use serde::Deserialize; #[derive(Debug, Clone, Deserialize)] pub enum Type { A = 1, - SOA = 6 + SOA = 6, + OPT = 41 } #[repr(u16)] @@ -30,6 +31,11 @@ pub enum RCODE { NOTZONE = 10 } +pub enum Opcode { + QUERY = 0, + UPDATE = 5 +} + #[derive(Debug, Clone)] pub struct Question { pub qname: Vec,