10
0
Fork 0
mirror of https://github.com/ZeusWPI/ZNS.git synced 2024-11-24 22:11:10 +01:00

Implemented dns Update checks

This commit is contained in:
Xander Bil 2024-03-20 23:55:53 +01:00
parent 6dd3f23815
commit 68ce89c5a2
No known key found for this signature in database
GPG key ID: EC9706B54A278598
6 changed files with 188 additions and 84 deletions

View file

@ -71,13 +71,13 @@ pub async fn insert_into_database(rr: RR) -> Result<(), DatabaseError> {
Ok(())
}
pub async fn get_from_database(question: Question) -> Result<RR, DatabaseError> {
pub async fn get_from_database(question: &Question) -> Result<RR, DatabaseError> {
let db_connection = &mut establish_connection();
let record = Record::get(
db_connection,
question.qname.join("."),
question.qtype as i32,
question.qclass as i32,
question.qtype.clone() as i32,
question.qclass.clone() as i32,
)
.map_err(|e| DatabaseError {
message: e.to_string(),

View file

@ -8,6 +8,7 @@ mod errors;
mod parser;
mod resolver;
mod structs;
mod utils;
#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {

View file

@ -76,7 +76,7 @@ impl Type {
})
}
}
Type::OPT => todo!(),
Type::SOA => todo!(),
}
}
pub fn from_data(&self, bytes: &[u8]) -> Result<String> {
@ -92,7 +92,7 @@ impl Type {
})
}
}
Type::OPT => unimplemented!()
Type::SOA => todo!(),
}
}
}
@ -115,7 +115,6 @@ impl FromBytes for Header {
arcount: u16::from_be_bytes(bytes[10..12].try_into().unwrap()),
})
}
}
fn to_bytes(header: Self) -> Vec<u8> {
@ -130,7 +129,6 @@ impl FromBytes for Header {
result.to_vec()
}
}
impl FromBytes for LabelString {
@ -139,8 +137,9 @@ impl FromBytes for LabelString {
// Parse qname labels
while bytes[*i] != 0 && bytes[*i] as usize + *i < bytes.len() {
qname
.push(String::from_utf8(bytes[*i + 1..bytes[*i] as usize + 1 + *i].to_vec()).unwrap());
qname.push(
String::from_utf8(bytes[*i + 1..bytes[*i] as usize + 1 + *i].to_vec()).unwrap(),
);
*i += bytes[*i] as usize + 1;
}
@ -177,15 +176,17 @@ impl FromBytes for Question {
})
} else {
//Try Parse qtype
let qtype = Type::try_from(u16::from_be_bytes(bytes[*i..*i + 2].try_into().unwrap()))
let qtype =
Type::try_from(u16::from_be_bytes(bytes[*i..*i + 2].try_into().unwrap()))
.map_err(|_| ParseError {
object: String::from("Type"),
message: String::from("invalid"),
})?;
//Try Parse qclass
let qclass =
Class::try_from(u16::from_be_bytes(bytes[*i + 2..*i + 4].try_into().unwrap()))
let qclass = Class::try_from(u16::from_be_bytes(
bytes[*i + 2..*i + 4].try_into().unwrap(),
))
.map_err(|_| ParseError {
object: String::from("Class"),
message: String::from("invalid"),
@ -225,8 +226,9 @@ impl FromBytes for RR {
message: String::from("invalid"),
})?;
let class =
Class::try_from(u16::from_be_bytes(bytes[*i + 2..*i + 4].try_into().unwrap()))
let class = Class::try_from(u16::from_be_bytes(
bytes[*i + 2..*i + 4].try_into().unwrap(),
))
.map_err(|_| ParseError {
object: String::from("Class"),
message: String::from("invalid"),
@ -268,29 +270,51 @@ impl FromBytes for RR {
impl FromBytes for Message {
fn from_bytes(bytes: &[u8], i: &mut usize) -> Result<Self> {
let header = Header::from_bytes(&bytes, i)?;
let question = Question::from_bytes(&bytes,i)?;
let mut question = vec![];
for _ in 0..header.qdcount {
question.push(Question::from_bytes(&bytes, i)?);
}
let mut answer = vec![];
for _ in 0..header.ancount {
answer.push(RR::from_bytes(&bytes, i)?);
}
let mut authority = vec![];
for _ in 0..header.nscount {
authority.push(RR::from_bytes(&bytes, i)?);
}
let mut additional = vec![];
for _ in 0..header.nscount {
additional.push(RR::from_bytes(&bytes, i)?);
}
Ok(Message {
header,
question,
answer: None,
authority: None,
additional: None,
answer,
authority,
additional,
})
}
fn to_bytes(message: Self) -> Vec<u8> {
let mut result = vec![];
result.extend(Header::to_bytes(message.header));
result.extend(Question::to_bytes(message.question));
if message.answer.is_some() {
result.extend(RR::to_bytes(message.answer.unwrap()));
for question in message.question {
result.extend(Question::to_bytes(question));
}
if message.authority.is_some() {
result.extend(RR::to_bytes(message.authority.unwrap()));
for answer in message.answer {
result.extend(RR::to_bytes(answer));
}
if message.additional.is_some() {
result.extend(RR::to_bytes(message.additional.unwrap()));
for auth in message.authority {
result.extend(RR::to_bytes(auth));
}
for additional in message.additional {
result.extend(RR::to_bytes(additional));
}
result
}

View file

@ -5,45 +5,99 @@ use std::sync::Arc;
use tokio::net::UdpSocket;
use crate::db::models::get_from_database;
use crate::parser::{parse_opt_type, FromBytes};
use crate::structs::{Message, Type, RR};
use crate::parser::FromBytes;
use crate::structs::{Class, Message, Type, RCODE};
use crate::utils::vec_equal;
const MAX_DATAGRAM_SIZE: usize = 4096;
const OPTION_CODE: usize = 65001;
async fn handle_normal_question(message: Message) -> Message {
fn set_response_flags(flags: u16, rcode: RCODE) -> u16 {
(flags | 0b1000010000000000 | rcode as u16) & 0b1_1111_1_0_1_0_111_1111
}
async fn handle_query(message: Message) -> Message {
let mut response = message.clone();
println!("{:#?}",message.question);
let answer = get_from_database(message.question).await;
response.header.arcount = 0;
for question in message.question {
let answer = get_from_database(&question).await;
match answer {
Ok(rr) => {
response.header.flags |= 0b1000010110000000;
response.header.flags = set_response_flags(response.header.flags, RCODE::NOERROR);
response.header.ancount = 1;
response.answer = Some(rr)
response.answer = vec![rr]
}
Err(e) => {
response.header.flags |= 0b1000010110000011;
response.header.flags = set_response_flags(response.header.flags, RCODE::NXDOMAIN);
eprintln!("{}", e);
}
}
}
response
}
async fn handle_opt_rr(rr: RR) {
let pairs = parse_opt_type(&rr.rdata);
println!("{:#?}", pairs)
async fn handle_update(message: Message) -> Message {
let mut response = message.clone();
// Zone section (question) processing
if (message.header.qdcount != 1) || !matches!(message.question[0].qtype, Type::SOA) {
response.header.flags = set_response_flags(response.header.flags, RCODE::FORMERR);
return response;
}
async fn get_response(message: Message) -> Message {
match message.question.qtype {
Type::OPT => handle_normal_question(message),
_ => handle_normal_question(message),
// Check Zone authority
let zlen = message.question[0].qname.len();
if !(zlen >= 2
&& message.question[0].qname[zlen - 1] == "gent"
&& message.question[0].qname[zlen - 2] == "zeus")
{
response.header.flags = set_response_flags(response.header.flags, RCODE::NOTAUTH);
return response;
}
// Check Prerequisite TODO: implement this
if message.header.ancount > 0 {
response.header.flags = set_response_flags(response.header.flags, RCODE::NOTIMP);
return response;
}
// Check Requestor Permission
// TODO: implement this, use rfc2931
// Update Section Prescan
for rr in message.authority {
let rlen = rr.name.len();
// Check if rr has same zone
if rlen < zlen || !(vec_equal(&message.question[0].qname, &rr.name[rlen - zlen..])) {
response.header.flags = set_response_flags(response.header.flags, RCODE::NOTZONE);
return response;
}
if (rr.class == Class::ANY && (rr.ttl != 0 || rr.rdlength != 0))
|| (rr.class == Class::NONE && rr.ttl != 0)
|| rr.class != message.question[0].qclass
{
response.header.flags = set_response_flags(response.header.flags, RCODE::FORMERR);
return response;
}
}
response
}
async fn get_response(bytes: &[u8]) -> Message {
let mut i: usize = 0;
match Message::from_bytes(bytes, &mut i) {
Ok(message) => handle_query(message).await,
Err(err) => {
println!("{}", err);
unimplemented!() //TODO: implement this
}
}
.await
}
pub async fn resolver_listener_loop(addr: SocketAddr) -> Result<(), Box<dyn Error>> {
@ -51,18 +105,12 @@ pub async fn resolver_listener_loop(addr: SocketAddr) -> Result<(), Box<dyn Erro
loop {
let mut data = vec![0u8; MAX_DATAGRAM_SIZE];
let (len, addr) = socket_shared.recv_from(&mut data).await?;
let mut i: usize = 0;
match Message::from_bytes(&data[..len], &mut i) {
Ok(message) => {
let socket = socket_shared.clone();
tokio::spawn(async move {
let response = get_response(message).await;
let response = get_response(&data[..len]).await;
let _ = socket
.send_to(Message::to_bytes(response).as_slice(), addr)
.await;
});
}
Err(err) => println!("{}", err),
};
}
}

View file

@ -4,13 +4,30 @@ use serde::Deserialize;
#[derive(Debug, Clone, Deserialize)]
pub enum Type {
A = 1,
OPT = 41
SOA = 6
}
#[repr(u16)]
#[derive(Debug, Clone)]
#[derive(Debug, Clone, PartialEq)]
pub enum Class {
IN = 1,
NONE = 254,
ANY = 255
}
#[repr(u16)]
pub enum RCODE {
NOERROR = 0,
FORMERR = 1,
SERVFAIL = 2,
NXDOMAIN = 3,
NOTIMP = 4,
REFUSED = 5,
YXDOMAIN = 6,
YXRRSET = 7,
NXRRSET = 8,
NOTAUTH = 9,
NOTZONE = 10
}
#[derive(Debug, Clone)]
@ -33,10 +50,10 @@ pub struct Header {
#[derive(Debug, Clone)]
pub struct Message {
pub header: Header,
pub question: Question,
pub answer: Option<RR>,
pub authority: Option<RR>,
pub additional: Option<RR>,
pub question: Vec<Question>,
pub answer: Vec<RR>,
pub authority: Vec<RR>,
pub additional: Vec<RR>,
}
#[derive(Debug, Clone)]

14
src/utils.rs Normal file
View file

@ -0,0 +1,14 @@
pub fn vec_equal<T: PartialEq>(vec1: &[T], vec2: &[T]) -> bool {
if vec1.len() != vec2.len() {
return false;
}
for (elem1, elem2) in vec1.iter().zip(vec2.iter()) {
if elem1 != elem2 {
return false;
}
}
true
}