10
0
Fork 0
mirror of https://github.com/ZeusWPI/ZNS.git synced 2024-10-29 21:14:27 +01:00

respond refused if query is not in authoritative zone

This commit is contained in:
Xander Bil 2024-08-23 22:17:15 +02:00
parent 8e46045d4c
commit 4261d06248
No known key found for this signature in database
GPG key ID: EC9706B54A278598
10 changed files with 67 additions and 30 deletions

View file

@ -177,7 +177,7 @@ mod tests {
fn test() { fn test() {
let mut connection = get_test_connection(); let mut connection = get_test_connection();
let rr = get_rr(); let rr = get_rr(None);
let f = |connection: &mut PgConnection| { let f = |connection: &mut PgConnection| {
get_from_database( get_from_database(

View file

@ -5,7 +5,7 @@ use zns::{
structs::{Message, Question, RR}, structs::{Message, Question, RR},
}; };
use crate::db::models::get_from_database; use crate::{config::Config, db::models::get_from_database};
use super::ResponseHandler; use super::ResponseHandler;
@ -20,6 +20,8 @@ impl ResponseHandler for QueryHandler {
) -> Result<Message, ZNSError> { ) -> Result<Message, ZNSError> {
let mut response = message.clone(); let mut response = message.clone();
message.check_authoritative(&Config::get().authoritative_zone)?;
for question in &message.question { for question in &message.question {
let answers = get_from_database( let answers = get_from_database(
&question.qname, &question.qname,
@ -90,8 +92,8 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_handle_query() { async fn test_handle_query() {
let mut connection = get_test_connection(); let mut connection = get_test_connection();
let rr = get_rr(); let rr = get_rr(Some(Config::get().authoritative_zone.clone()));
let mut message = get_message(); let mut message = get_message(Some(Config::get().authoritative_zone.clone()));
message.header.ancount = 0; message.header.ancount = 0;
message.answer = vec![]; message.answer = vec![];

View file

@ -3,11 +3,10 @@ use diesel::PgConnection;
use crate::{ use crate::{
config::Config, config::Config,
db::models::{delete_from_database, insert_into_database}, db::models::{delete_from_database, insert_into_database},
utils::vec_equal,
}; };
use zns::errors::ZNSError;
use zns::structs::{Class, Message, RRClass, RRType, Type}; use zns::structs::{Class, Message, RRClass, RRType, Type};
use zns::{errors::ZNSError, utils::vec_equal};
use self::sig::Sig; use self::sig::Sig;
@ -37,18 +36,13 @@ impl ResponseHandler for UpdateHandler {
} }
// Check Zone authority // Check Zone authority
let zone = &message.question[0]; message.check_authoritative(&Config::get().authoritative_zone)?;
let zlen = zone.qname.len();
let auth_zone = &Config::get().authoritative_zone;
if !(zlen >= auth_zone.len() && vec_equal(&zone.qname[zlen - auth_zone.len()..], auth_zone))
{
return Err(ZNSError::Formerr {
message: "Invalid zone".to_string(),
});
}
// Check Prerequisite TODO: implement this // Check Prerequisite TODO: implement this
let zone = &message.question[0];
let zlen = zone.qname.len();
//TODO: this code is ugly //TODO: this code is ugly
let last = message.additional.last(); let last = message.additional.last();
if last.is_some() && last.unwrap()._type == Type::Type(RRType::SIG) { if last.is_some() && last.unwrap()._type == Type::Type(RRType::SIG) {

View file

@ -4,7 +4,6 @@ mod config;
mod db; mod db;
mod handlers; mod handlers;
mod resolver; mod resolver;
mod utils;
use config::Config; use config::Config;

View file

@ -104,6 +104,8 @@ pub async fn tcp_listener_loop(addr: SocketAddr) -> Result<(), Box<dyn Error>> {
mod tests { mod tests {
use zns::structs::{Class, Question, RRClass, RRType, Type}; use zns::structs::{Class, Question, RRClass, RRType, Type};
use crate::config::Config;
use super::*; use super::*;
#[tokio::test] #[tokio::test]
@ -118,7 +120,7 @@ mod tests {
arcount: 0, arcount: 0,
}, },
question: vec![Question { question: vec![Question {
qname: vec![String::from("example"), String::from("org")], qname: Config::get().authoritative_zone.clone(),
qtype: Type::Type(RRType::A), qtype: Type::Type(RRType::A),
qclass: Class::Class(RRClass::IN), qclass: Class::Class(RRClass::IN),
}], }],

View file

@ -5,3 +5,4 @@ pub mod reader;
pub mod structs; pub mod structs;
pub mod test_utils; pub mod test_utils;
pub mod utils;

View file

@ -1,4 +1,8 @@
use crate::structs::{Message, Opcode, RCODE}; use crate::{
errors::ZNSError,
structs::{LabelString, Message, Opcode, RCODE},
utils::vec_equal,
};
impl Message { impl Message {
pub fn set_response(&mut self, rcode: RCODE) { pub fn set_response(&mut self, rcode: RCODE) {
@ -10,16 +14,32 @@ impl Message {
Opcode::try_from((self.header.flags & 0b0111100000000000) >> 11) Opcode::try_from((self.header.flags & 0b0111100000000000) >> 11)
} }
#[allow(dead_code)] // Used with tests #[cfg(feature = "test-utils")]
pub fn get_rcode(&self) -> Result<RCODE, u16> { pub fn get_rcode(&self) -> Result<RCODE, u16> {
RCODE::try_from(self.header.flags & (!0 >> 12)) RCODE::try_from(self.header.flags & (!0 >> 12))
} }
pub fn check_authoritative(&self, auth_zone: &LabelString) -> Result<(), ZNSError> {
let authoritative = self.question.iter().all(|question| {
let zlen = question.qname.len();
zlen >= auth_zone.len()
&& vec_equal(&question.qname[zlen - auth_zone.len()..], auth_zone)
});
if !authoritative {
return Err(ZNSError::Refused {
message: "Not authoritative".to_string(),
});
}
Ok(())
}
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::structs::Header; use crate::{structs::Header, test_utils::get_message};
use super::*; use super::*;
@ -48,4 +68,19 @@ mod tests {
assert_eq!(message.get_rcode().unwrap(), RCODE::NOTIMP); assert_eq!(message.get_rcode().unwrap(), RCODE::NOTIMP);
} }
#[test]
fn test_not_authoritative() {
let message = get_message(Some(vec![
String::from("not"),
String::from("good"),
String::from("zone"),
]));
let zone = vec![String::from("good")];
assert!(message
.check_authoritative(&zone)
.is_err_and(|x| x.rcode() == RCODE::REFUSED));
}
} }

View file

@ -328,7 +328,7 @@ pub mod tests {
#[test] #[test]
fn test_parse_rr() { fn test_parse_rr() {
let rr = get_rr(); let rr = get_rr(None);
let bytes = RR::to_bytes(rr.clone()); let bytes = RR::to_bytes(rr.clone());
let parsed = RR::from_bytes(&mut Reader::new(&bytes)); let parsed = RR::from_bytes(&mut Reader::new(&bytes));
@ -392,7 +392,7 @@ pub mod tests {
#[test] #[test]
fn test_parse_message() { fn test_parse_message() {
let message = get_message(); let message = get_message(None);
let bytes = Message::to_bytes(message.clone()); let bytes = Message::to_bytes(message.clone());
let parsed = Message::from_bytes(&mut Reader::new(&bytes)); let parsed = Message::from_bytes(&mut Reader::new(&bytes));
assert!(parsed.is_ok()); assert!(parsed.is_ok());

View file

@ -2,9 +2,9 @@
use crate::structs::*; use crate::structs::*;
#[cfg(feature = "test-utils")] #[cfg(feature = "test-utils")]
pub fn get_rr() -> RR { pub fn get_rr(name: Option<LabelString>) -> RR {
RR { RR {
name: vec![String::from("example"), String::from("org")], name: name.unwrap_or(vec![String::from("example"), String::from("org")]),
_type: Type::Type(RRType::A), _type: Type::Type(RRType::A),
class: Class::Class(RRClass::IN), class: Class::Class(RRClass::IN),
ttl: 10, ttl: 10,
@ -13,7 +13,7 @@ pub fn get_rr() -> RR {
} }
} }
pub fn get_message() -> Message { pub fn get_message(name: Option<LabelString>) -> Message {
Message { Message {
header: Header { header: Header {
id: 1, id: 1,
@ -25,18 +25,22 @@ pub fn get_message() -> Message {
}, },
question: vec![ question: vec![
Question { Question {
qname: vec![String::from("example"), String::from("org")], qname: name
.clone()
.unwrap_or(vec![String::from("example"), String::from("org")]),
qtype: Type::Type(RRType::A), qtype: Type::Type(RRType::A),
qclass: Class::Class(RRClass::IN), qclass: Class::Class(RRClass::IN),
}, },
Question { Question {
qname: vec![String::from("example"), String::from("org")], qname: name
.clone()
.unwrap_or(vec![String::from("example"), String::from("org")]),
qtype: Type::Type(RRType::A), qtype: Type::Type(RRType::A),
qclass: Class::Class(RRClass::IN), qclass: Class::Class(RRClass::IN),
}, },
], ],
answer: vec![get_rr()], answer: vec![get_rr(name.clone())],
authority: vec![get_rr()], authority: vec![get_rr(name.clone())],
additional: vec![get_rr()], additional: vec![get_rr(name)],
} }
} }