diff --git a/src/db/models.rs b/src/db/models.rs index 1ebb22c..f9cf2c6 100644 --- a/src/db/models.rs +++ b/src/db/models.rs @@ -71,13 +71,13 @@ pub async fn insert_into_database(rr: RR) -> Result<(), DatabaseError> { Ok(()) } -pub async fn get_from_database(question: Question) -> Result { +pub async fn get_from_database(question: &Question) -> Result { let db_connection = &mut establish_connection(); let record = Record::get( db_connection, question.qname.join("."), - question.qtype as i32, - question.qclass as i32, + question.qtype.clone() as i32, + question.qclass.clone() as i32, ) .map_err(|e| DatabaseError { message: e.to_string(), diff --git a/src/main.rs b/src/main.rs index 4aff029..33e1517 100644 --- a/src/main.rs +++ b/src/main.rs @@ -8,6 +8,7 @@ mod errors; mod parser; mod resolver; mod structs; +mod utils; #[tokio::main] async fn main() -> Result<(), Box> { diff --git a/src/parser.rs b/src/parser.rs index 6fc708f..89111aa 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -76,7 +76,7 @@ impl Type { }) } } - Type::OPT => todo!(), + Type::SOA => todo!(), } } pub fn from_data(&self, bytes: &[u8]) -> Result { @@ -92,7 +92,7 @@ impl Type { }) } } - Type::OPT => unimplemented!() + Type::SOA => todo!(), } } } @@ -115,7 +115,6 @@ impl FromBytes for Header { arcount: u16::from_be_bytes(bytes[10..12].try_into().unwrap()), }) } - } fn to_bytes(header: Self) -> Vec { @@ -130,7 +129,6 @@ impl FromBytes for Header { result.to_vec() } - } impl FromBytes for LabelString { @@ -139,8 +137,9 @@ impl FromBytes for LabelString { // Parse qname labels while bytes[*i] != 0 && bytes[*i] as usize + *i < bytes.len() { - qname - .push(String::from_utf8(bytes[*i + 1..bytes[*i] as usize + 1 + *i].to_vec()).unwrap()); + qname.push( + String::from_utf8(bytes[*i + 1..bytes[*i] as usize + 1 + *i].to_vec()).unwrap(), + ); *i += bytes[*i] as usize + 1; } @@ -177,20 +176,22 @@ impl FromBytes for Question { }) } else { //Try Parse qtype - let qtype = Type::try_from(u16::from_be_bytes(bytes[*i..*i + 2].try_into().unwrap())) - .map_err(|_| ParseError { - object: String::from("Type"), - message: String::from("invalid"), - })?; - - //Try Parse qclass - let qclass = - Class::try_from(u16::from_be_bytes(bytes[*i + 2..*i + 4].try_into().unwrap())) + let qtype = + Type::try_from(u16::from_be_bytes(bytes[*i..*i + 2].try_into().unwrap())) .map_err(|_| ParseError { - object: String::from("Class"), + object: String::from("Type"), message: String::from("invalid"), })?; + //Try Parse qclass + let qclass = Class::try_from(u16::from_be_bytes( + bytes[*i + 2..*i + 4].try_into().unwrap(), + )) + .map_err(|_| ParseError { + object: String::from("Class"), + message: String::from("invalid"), + })?; + *i += 4; // For qtype and qclass => 4 bytes Ok(Question { @@ -221,16 +222,17 @@ impl FromBytes for RR { } else { let _type = Type::try_from(u16::from_be_bytes(bytes[*i..*i + 2].try_into().unwrap())) .map_err(|_| ParseError { - object: String::from("Type"), - message: String::from("invalid"), - })?; + object: String::from("Type"), + message: String::from("invalid"), + })?; - let class = - Class::try_from(u16::from_be_bytes(bytes[*i + 2..*i + 4].try_into().unwrap())) - .map_err(|_| ParseError { - object: String::from("Class"), - message: String::from("invalid"), - })?; + let class = Class::try_from(u16::from_be_bytes( + bytes[*i + 2..*i + 4].try_into().unwrap(), + )) + .map_err(|_| ParseError { + object: String::from("Class"), + message: String::from("invalid"), + })?; let ttl = i32::from_be_bytes(bytes[*i + 4..*i + 8].try_into().unwrap()); let rdlength = u16::from_be_bytes(bytes[*i + 8..*i + 10].try_into().unwrap()); @@ -248,7 +250,7 @@ impl FromBytes for RR { class, ttl, rdlength, - rdata: bytes[*i - rdlength as usize.. *i].to_vec(), + rdata: bytes[*i - rdlength as usize..*i].to_vec(), }) } } @@ -267,30 +269,52 @@ impl FromBytes for RR { impl FromBytes for Message { fn from_bytes(bytes: &[u8], i: &mut usize) -> Result { - let header = Header::from_bytes(&bytes,i)?; - let question = Question::from_bytes(&bytes,i)?; + let header = Header::from_bytes(&bytes, i)?; + + let mut question = vec![]; + for _ in 0..header.qdcount { + question.push(Question::from_bytes(&bytes, i)?); + } + + let mut answer = vec![]; + for _ in 0..header.ancount { + answer.push(RR::from_bytes(&bytes, i)?); + } + + let mut authority = vec![]; + for _ in 0..header.nscount { + authority.push(RR::from_bytes(&bytes, i)?); + } + + let mut additional = vec![]; + for _ in 0..header.nscount { + additional.push(RR::from_bytes(&bytes, i)?); + } Ok(Message { header, question, - answer: None, - authority: None, - additional: None, + answer, + authority, + additional, }) } fn to_bytes(message: Self) -> Vec { let mut result = vec![]; result.extend(Header::to_bytes(message.header)); - result.extend(Question::to_bytes(message.question)); - if message.answer.is_some() { - result.extend(RR::to_bytes(message.answer.unwrap())); + + for question in message.question { + result.extend(Question::to_bytes(question)); } - if message.authority.is_some() { - result.extend(RR::to_bytes(message.authority.unwrap())); + for answer in message.answer { + result.extend(RR::to_bytes(answer)); } - if message.additional.is_some() { - result.extend(RR::to_bytes(message.additional.unwrap())); + for auth in message.authority { + result.extend(RR::to_bytes(auth)); + } + for additional in message.additional { + result.extend(RR::to_bytes(additional)); } result } diff --git a/src/resolver.rs b/src/resolver.rs index 43b889c..9e2d676 100644 --- a/src/resolver.rs +++ b/src/resolver.rs @@ -5,45 +5,99 @@ use std::sync::Arc; use tokio::net::UdpSocket; use crate::db::models::get_from_database; -use crate::parser::{parse_opt_type, FromBytes}; -use crate::structs::{Message, Type, RR}; +use crate::parser::FromBytes; +use crate::structs::{Class, Message, Type, RCODE}; +use crate::utils::vec_equal; const MAX_DATAGRAM_SIZE: usize = 4096; -const OPTION_CODE: usize = 65001; -async fn handle_normal_question(message: Message) -> Message { +fn set_response_flags(flags: u16, rcode: RCODE) -> u16 { + (flags | 0b1000010000000000 | rcode as u16) & 0b1_1111_1_0_1_0_111_1111 +} + +async fn handle_query(message: Message) -> Message { let mut response = message.clone(); - println!("{:#?}",message.question); - let answer = get_from_database(message.question).await; - response.header.arcount = 0; + for question in message.question { + let answer = get_from_database(&question).await; - match answer { - Ok(rr) => { - response.header.flags |= 0b1000010110000000; - response.header.ancount = 1; - response.answer = Some(rr) - } - Err(e) => { - response.header.flags |= 0b1000010110000011; - eprintln!("{}", e); + match answer { + Ok(rr) => { + response.header.flags = set_response_flags(response.header.flags, RCODE::NOERROR); + response.header.ancount = 1; + response.answer = vec![rr] + } + Err(e) => { + response.header.flags |= 0b1000010110000011; + response.header.flags = set_response_flags(response.header.flags, RCODE::NXDOMAIN); + eprintln!("{}", e); + } } } response } -async fn handle_opt_rr(rr: RR) { - let pairs = parse_opt_type(&rr.rdata); - println!("{:#?}", pairs) +async fn handle_update(message: Message) -> Message { + let mut response = message.clone(); + + // Zone section (question) processing + if (message.header.qdcount != 1) || !matches!(message.question[0].qtype, Type::SOA) { + response.header.flags = set_response_flags(response.header.flags, RCODE::FORMERR); + return response; + } + + // Check Zone authority + let zlen = message.question[0].qname.len(); + if !(zlen >= 2 + && message.question[0].qname[zlen - 1] == "gent" + && message.question[0].qname[zlen - 2] == "zeus") + { + response.header.flags = set_response_flags(response.header.flags, RCODE::NOTAUTH); + return response; + } + + // Check Prerequisite TODO: implement this + if message.header.ancount > 0 { + response.header.flags = set_response_flags(response.header.flags, RCODE::NOTIMP); + return response; + } + + // Check Requestor Permission + // TODO: implement this, use rfc2931 + + // Update Section Prescan + for rr in message.authority { + let rlen = rr.name.len(); + + // Check if rr has same zone + if rlen < zlen || !(vec_equal(&message.question[0].qname, &rr.name[rlen - zlen..])) { + response.header.flags = set_response_flags(response.header.flags, RCODE::NOTZONE); + return response; + } + + if (rr.class == Class::ANY && (rr.ttl != 0 || rr.rdlength != 0)) + || (rr.class == Class::NONE && rr.ttl != 0) + || rr.class != message.question[0].qclass + { + response.header.flags = set_response_flags(response.header.flags, RCODE::FORMERR); + return response; + } + + } + + response } -async fn get_response(message: Message) -> Message { - match message.question.qtype { - Type::OPT => handle_normal_question(message), - _ => handle_normal_question(message), +async fn get_response(bytes: &[u8]) -> Message { + let mut i: usize = 0; + match Message::from_bytes(bytes, &mut i) { + Ok(message) => handle_query(message).await, + Err(err) => { + println!("{}", err); + unimplemented!() //TODO: implement this + } } - .await } pub async fn resolver_listener_loop(addr: SocketAddr) -> Result<(), Box> { @@ -51,18 +105,12 @@ pub async fn resolver_listener_loop(addr: SocketAddr) -> Result<(), Box { - let socket = socket_shared.clone(); - tokio::spawn(async move { - let response = get_response(message).await; - let _ = socket - .send_to(Message::to_bytes(response).as_slice(), addr) - .await; - }); - } - Err(err) => println!("{}", err), - }; + let socket = socket_shared.clone(); + tokio::spawn(async move { + let response = get_response(&data[..len]).await; + let _ = socket + .send_to(Message::to_bytes(response).as_slice(), addr) + .await; + }); } } diff --git a/src/structs.rs b/src/structs.rs index 24cf428..f5ebf0d 100644 --- a/src/structs.rs +++ b/src/structs.rs @@ -4,13 +4,30 @@ use serde::Deserialize; #[derive(Debug, Clone, Deserialize)] pub enum Type { A = 1, - OPT = 41 + SOA = 6 } #[repr(u16)] -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq)] pub enum Class { IN = 1, + NONE = 254, + ANY = 255 +} + +#[repr(u16)] +pub enum RCODE { + NOERROR = 0, + FORMERR = 1, + SERVFAIL = 2, + NXDOMAIN = 3, + NOTIMP = 4, + REFUSED = 5, + YXDOMAIN = 6, + YXRRSET = 7, + NXRRSET = 8, + NOTAUTH = 9, + NOTZONE = 10 } #[derive(Debug, Clone)] @@ -33,10 +50,10 @@ pub struct Header { #[derive(Debug, Clone)] pub struct Message { pub header: Header, - pub question: Question, - pub answer: Option, - pub authority: Option, - pub additional: Option, + pub question: Vec, + pub answer: Vec, + pub authority: Vec, + pub additional: Vec, } #[derive(Debug, Clone)] diff --git a/src/utils.rs b/src/utils.rs new file mode 100644 index 0000000..785fb20 --- /dev/null +++ b/src/utils.rs @@ -0,0 +1,14 @@ + +pub fn vec_equal(vec1: &[T], vec2: &[T]) -> bool { + if vec1.len() != vec2.len() { + return false; + } + + for (elem1, elem2) in vec1.iter().zip(vec2.iter()) { + if elem1 != elem2 { + return false; + } + } + + true +}