10
0
Fork 0
mirror of https://github.com/ZeusWPI/ZNS.git synced 2024-11-21 21:41:10 +01:00

Respond with fixed SOA of zones

This commit is contained in:
Xander Bil 2024-09-26 00:00:51 +02:00
parent 4939d2b3e1
commit c9767db879
No known key found for this signature in database
GPG key ID: EC9706B54A278598
13 changed files with 209 additions and 95 deletions

View file

@ -1,13 +1,14 @@
use std::{env, net::IpAddr, sync::OnceLock}; use std::{env, net::IpAddr, sync::OnceLock};
use dotenvy::dotenv; use dotenvy::dotenv;
use zns::labelstring::LabelString;
static CONFIG: OnceLock<Config> = OnceLock::new(); static CONFIG: OnceLock<Config> = OnceLock::new();
pub struct Config { pub struct Config {
pub zauth_url: String, pub zauth_url: String,
pub db_uri: String, pub db_uri: String,
pub authoritative_zone: Vec<String>, pub authoritative_zone: LabelString,
pub port: u16, pub port: u16,
pub address: IpAddr, pub address: IpAddr,
} }
@ -25,11 +26,7 @@ impl Config {
Config { Config {
db_uri: env::var("DATABASE_URL").expect("DATABASE_URL must be set"), db_uri: env::var("DATABASE_URL").expect("DATABASE_URL must be set"),
zauth_url: env::var("ZAUTH_URL").expect("ZAUTH_URL must be set"), zauth_url: env::var("ZAUTH_URL").expect("ZAUTH_URL must be set"),
authoritative_zone: env::var("ZONE") authoritative_zone: LabelString::from(&env::var("ZONE").expect("ZONE must be set")),
.expect("ZONE must be set")
.split('.')
.map(str::to_string)
.collect(),
port: env::var("ZNS_PORT") port: env::var("ZNS_PORT")
.map(|v| v.parse::<u16>().expect("ZNS_PORT is invalid")) .map(|v| v.parse::<u16>().expect("ZNS_PORT is invalid"))
.unwrap_or(5333), .unwrap_or(5333),

View file

@ -2,6 +2,7 @@ use diesel::prelude::*;
use diesel::sql_types::Text; use diesel::sql_types::Text;
use zns::{ use zns::{
errors::ZNSError, errors::ZNSError,
labelstring::LabelString,
structs::{Class, Type, RR}, structs::{Class, Type, RR},
}; };
@ -103,7 +104,7 @@ pub fn insert_into_database(rr: &RR, connection: &mut PgConnection) -> Result<()
} }
let record = Record { let record = Record {
name: rr.name.join("."), name: rr.name.to_string(),
_type: rr._type.clone().into(), _type: rr._type.clone().into(),
class: rr.class.clone().into(), class: rr.class.clone().into(),
ttl: rr.ttl, ttl: rr.ttl,
@ -119,14 +120,14 @@ pub fn insert_into_database(rr: &RR, connection: &mut PgConnection) -> Result<()
} }
pub fn get_from_database( pub fn get_from_database(
name: &[String], name: &LabelString,
_type: Option<Type>, _type: Option<Type>,
class: Class, class: Class,
connection: &mut PgConnection, connection: &mut PgConnection,
) -> Result<Vec<RR>, ZNSError> { ) -> Result<Vec<RR>, ZNSError> {
let records = Record::get( let records = Record::get(
connection, connection,
name.join("."), name.to_string(),
_type.map(|t| t.into()), _type.map(|t| t.into()),
class.into(), class.into(),
) )
@ -137,7 +138,7 @@ pub fn get_from_database(
Ok(records Ok(records
.into_iter() .into_iter()
.map(|record| RR { .map(|record| RR {
name: record.name.split('.').map(str::to_string).collect(), name: LabelString::from(&record.name),
_type: Type::from(record._type as u16), _type: Type::from(record._type as u16),
class: Class::from(record.class as u16), class: Class::from(record.class as u16),
ttl: record.ttl, ttl: record.ttl,
@ -149,7 +150,7 @@ pub fn get_from_database(
//TODO: cleanup models //TODO: cleanup models
pub fn delete_from_database( pub fn delete_from_database(
name: &[String], name: &LabelString,
_type: Option<Type>, _type: Option<Type>,
class: Class, class: Class,
rdata: Option<Vec<u8>>, rdata: Option<Vec<u8>>,
@ -157,7 +158,7 @@ pub fn delete_from_database(
) { ) {
let _ = Record::delete( let _ = Record::delete(
connection, connection,
name.join("."), name.to_string(),
_type.map(|f| f.into()), _type.map(|f| f.into()),
class.into(), class.into(),
rdata, rdata,

View file

@ -2,7 +2,9 @@ use diesel::PgConnection;
use zns::{ use zns::{
errors::ZNSError, errors::ZNSError,
structs::{Message, Question, RR}, labelstring::LabelString,
parser::ToBytes,
structs::{Class, Message, Question, RRClass, RRType, SoaRData, Type, RR},
}; };
use crate::{config::Config, db::models::get_from_database}; use crate::{config::Config, db::models::get_from_database};
@ -35,10 +37,14 @@ impl ResponseHandler for QueryHandler {
if rrs.is_empty() { if rrs.is_empty() {
rrs.extend(try_wildcard(question, connection)?); rrs.extend(try_wildcard(question, connection)?);
if rrs.is_empty() { if rrs.is_empty() {
return Err(ZNSError::NXDomain { if question.qtype == Type::Type(RRType::SOA) {
domain: question.qname.join("."), rrs.extend([get_soa(&question.qname)?])
qtype: question.qtype.clone(), } else {
}); return Err(ZNSError::NXDomain {
domain: question.qname.to_string(),
qtype: question.qtype.clone(),
});
}
} }
} }
response.header.ancount += rrs.len() as u16; response.header.ancount += rrs.len() as u16;
@ -59,13 +65,13 @@ impl ResponseHandler for QueryHandler {
fn try_wildcard(question: &Question, connection: &mut PgConnection) -> Result<Vec<RR>, ZNSError> { fn try_wildcard(question: &Question, connection: &mut PgConnection) -> Result<Vec<RR>, ZNSError> {
let records = get_from_database(&question.qname, None, question.qclass.clone(), connection)?; let records = get_from_database(&question.qname, None, question.qclass.clone(), connection)?;
if !records.is_empty() || question.qname.is_empty() { if !records.is_empty() || question.qname.as_slice().is_empty() {
Ok(vec![]) Ok(vec![])
} else { } else {
let mut qname = question.qname.clone(); let qname = question.qname.clone().to_vec();
qname[0] = String::from("*"); qname.to_vec()[0] = String::from("*");
Ok(get_from_database( Ok(get_from_database(
&qname, &qname.into(),
Some(question.qtype.clone()), Some(question.qtype.clone()),
question.qclass.clone(), question.qclass.clone(),
connection, connection,
@ -79,6 +85,47 @@ fn try_wildcard(question: &Question, connection: &mut PgConnection) -> Result<Ve
} }
} }
fn get_soa(name: &LabelString) -> Result<RR, ZNSError> {
let auth_zone = Config::get().authoritative_zone.clone();
let rdata = if &Config::get().authoritative_zone == name {
// Recommended values taken from wikipedia: https://en.wikipedia.org/wiki/SOA_record
Ok(SoaRData {
mname: auth_zone,
rname: LabelString::from("admin.zeus.ugent.be"),
serial: 1,
refresh: 86400,
retry: 7200,
expire: 3600000,
minimum: 172800,
})
} else if name.len() > auth_zone.len() {
let zone: LabelString = name.as_slice()[name.len() - auth_zone.len() - 1..].into();
Ok(SoaRData {
mname: zone.clone(),
rname: LabelString::from(&format!("{}.zeus.ugent.be", zone.as_slice()[0])),
serial: 1,
refresh: 86400,
retry: 7200,
expire: 3600000,
minimum: 172800,
})
} else {
Err(ZNSError::NXDomain {
domain: name.to_string(),
qtype: Type::Type(RRType::SOA),
})
}?;
Ok(RR {
name: name.to_owned(),
_type: Type::Type(RRType::SOA),
class: Class::Class(RRClass::IN),
ttl: 11200,
rdlength: 0,
rdata: SoaRData::to_bytes(rdata),
})
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;

View file

@ -5,9 +5,10 @@ use crate::{config::Config, db::models::get_from_database};
use zns::{ use zns::{
errors::ZNSError, errors::ZNSError,
labelstring::LabelString,
parser::FromBytes, parser::FromBytes,
reader::Reader, reader::Reader,
structs::{Class, LabelString, RRClass, RRType, Type}, structs::{Class, RRClass, RRType, Type},
}; };
use super::{dnskey::DNSKeyRData, sig::Sig}; use super::{dnskey::DNSKeyRData, sig::Sig};
@ -17,8 +18,9 @@ pub async fn authenticate(
zone: &LabelString, zone: &LabelString,
connection: &mut PgConnection, connection: &mut PgConnection,
) -> Result<bool, ZNSError> { ) -> Result<bool, ZNSError> {
if zone.len() > Config::get().authoritative_zone.len() { if zone.as_slice().len() > Config::get().authoritative_zone.as_slice().len() {
let username = &zone[zone.len() - Config::get().authoritative_zone.len() - 1]; let username = &zone.as_slice()
[zone.as_slice().len() - Config::get().authoritative_zone.as_slice().len() - 1];
let ssh_verified = validate_ssh(&username.to_lowercase(), sig) let ssh_verified = validate_ssh(&username.to_lowercase(), sig)
.await .await
@ -62,7 +64,7 @@ async fn validate_ssh(username: &String, sig: &Sig) -> Result<bool, reqwest::Err
} }
async fn validate_dnskey( async fn validate_dnskey(
zone: &[String], zone: &LabelString,
sig: &Sig, sig: &Sig,
connection: &mut PgConnection, connection: &mut PgConnection,
) -> Result<bool, ZNSError> { ) -> Result<bool, ZNSError> {

View file

@ -5,8 +5,8 @@ use crate::{
db::models::{delete_from_database, insert_into_database}, db::models::{delete_from_database, insert_into_database},
}; };
use zns::errors::ZNSError;
use zns::structs::{Class, Message, RRClass, RRType, Type}; use zns::structs::{Class, Message, RRClass, RRType, Type};
use zns::{errors::ZNSError, utils::labels_equal};
use self::sig::Sig; use self::sig::Sig;
@ -41,7 +41,7 @@ impl ResponseHandler for UpdateHandler {
// Check Prerequisite TODO: implement this // Check Prerequisite TODO: implement this
let zone = &message.question[0]; let zone = &message.question[0];
let zlen = zone.qname.len(); let zlen = zone.qname.as_slice().len();
//TODO: this code is ugly //TODO: this code is ugly
let last = message.additional.last(); let last = message.additional.last();
@ -61,10 +61,10 @@ impl ResponseHandler for UpdateHandler {
// Update Section Prescan // Update Section Prescan
for rr in &message.authority { for rr in &message.authority {
let rlen = rr.name.len(); let rlen = rr.name.as_slice().len();
// Check if rr has same zone // Check if rr has same zone
if rlen < zlen || !(labels_equal(&zone.qname, &rr.name[rlen - zlen..].into())) { if rlen < zlen || !(&zone.qname == &rr.name.as_slice()[rlen - zlen..].into()) {
return Err(ZNSError::Refused { return Err(ZNSError::Refused {
message: "RR has different zone from Question".to_string(), message: "RR has different zone from Question".to_string(),
}); });

View file

@ -4,10 +4,7 @@ use base64::prelude::*;
use int_enum::IntEnum; use int_enum::IntEnum;
use zns::{ use zns::{
errors::ZNSError, errors::ZNSError, labelstring::LabelString, parser::FromBytes, reader::Reader, structs::RR,
parser::FromBytes,
reader::Reader,
structs::{LabelString, RR},
}; };
use super::{ use super::{

85
zns/src/labelstring.rs Normal file
View file

@ -0,0 +1,85 @@
use std::fmt::Display;
#[derive(Debug, Clone)]
pub struct LabelString(Vec<String>);
pub fn labels_equal(vec1: &LabelString, vec2: &LabelString) -> bool {
if vec1.as_slice().len() != vec2.as_slice().len() {
return false;
}
for (elem1, elem2) in vec1.as_slice().iter().zip(vec2.as_slice().iter()) {
if elem1.to_lowercase() != elem2.to_lowercase() {
return false;
}
}
true
}
impl LabelString {
pub fn from(string: &str) -> Self {
LabelString(string.split('.').map(str::to_string).collect())
}
pub fn as_slice(&self) -> &[String] {
self.0.as_slice()
}
pub fn to_vec(self) -> Vec<String> {
self.0
}
pub fn len(&self) -> usize {
self.0.len()
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.len() == 0
}
}
impl PartialEq for LabelString {
fn eq(&self, other: &Self) -> bool {
labels_equal(self, other)
}
}
impl From<&[String]> for LabelString {
fn from(value: &[String]) -> Self {
LabelString(value.to_vec())
}
}
impl From<Vec<String>> for LabelString {
fn from(value: Vec<String>) -> Self {
LabelString(value)
}
}
impl Display for LabelString {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0.join("."))
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_labels_equal() {
assert!(labels_equal(
&LabelString::from("one.two"),
&LabelString::from("oNE.two")
));
assert!(!labels_equal(
&LabelString::from("onne.two"),
&LabelString::from("oNEe.two")
));
}
}

View file

@ -1,8 +1,8 @@
pub mod errors; pub mod errors;
pub mod labelstring;
pub mod message; pub mod message;
pub mod parser; pub mod parser;
pub mod reader; pub mod reader;
pub mod structs; pub mod structs;
pub mod test_utils; pub mod test_utils;
pub mod utils;

View file

@ -1,7 +1,7 @@
use crate::{ use crate::{
errors::ZNSError, errors::ZNSError,
structs::{LabelString, Message, Opcode, RCODE}, labelstring::LabelString,
utils::labels_equal, structs::{Message, Opcode, RCODE},
}; };
impl Message { impl Message {
@ -23,10 +23,12 @@ impl Message {
for question in &self.question { for question in &self.question {
let zlen = question.qname.len(); let zlen = question.qname.len();
if !(zlen >= auth_zone.len() if !(zlen >= auth_zone.len()
&& labels_equal(&question.qname[zlen - auth_zone.len()..].into(), auth_zone)) && &Into::<LabelString>::into(
question.qname.as_slice()[zlen - auth_zone.len()..].to_vec(),
) == auth_zone)
{ {
return Err(ZNSError::Refused { return Err(ZNSError::Refused {
message: format!("Not authoritative for: {}", question.qname.join(".")), message: format!("Not authoritative for: {}", question.qname),
}); });
} }
} }
@ -69,20 +71,16 @@ mod tests {
#[test] #[test]
fn test_authoritative() { fn test_authoritative() {
let name = vec![ let name = LabelString::from("not.good.zone");
String::from("not"),
String::from("good"),
String::from("zone"),
];
let message = get_message(Some(name)); let message = get_message(Some(name));
assert!(message assert!(message
.check_authoritative(&vec![String::from("good")]) .check_authoritative(&LabelString::from("good"))
.is_err_and(|x| x.rcode() == RCODE::REFUSED)); .is_err_and(|x| x.rcode() == RCODE::REFUSED));
assert!(message assert!(message
.check_authoritative(&vec![String::from("Zone")]) .check_authoritative(&LabelString::from("Zone"))
.is_ok()) .is_ok())
} }
} }

View file

@ -2,8 +2,9 @@ use std::mem::size_of;
use crate::{ use crate::{
errors::ZNSError, errors::ZNSError,
labelstring::LabelString,
reader::Reader, reader::Reader,
structs::{Class, Header, LabelString, Message, Opcode, Question, RRClass, RRType, Type, RR}, structs::{Class, Header, Message, Opcode, Question, RRClass, RRType, SoaRData, Type, RR},
}; };
type Result<T> = std::result::Result<T, ZNSError>; type Result<T> = std::result::Result<T, ZNSError>;
@ -143,17 +144,17 @@ impl FromBytes for LabelString {
if code & 0b11000000 != 0 { if code & 0b11000000 != 0 {
let offset = (((code & 0b00111111) as u16) << 8) | reader.read_u8()? as u16; let offset = (((code & 0b00111111) as u16) << 8) | reader.read_u8()? as u16;
let mut reader_past = reader.seek(offset as usize)?; let mut reader_past = reader.seek(offset as usize)?;
out.extend(LabelString::from_bytes(&mut reader_past)?); out.extend(LabelString::from_bytes(&mut reader_past)?.to_vec());
} }
Ok(out) Ok(out.into())
} }
} }
impl ToBytes for LabelString { impl ToBytes for LabelString {
fn to_bytes(name: Self) -> Vec<u8> { fn to_bytes(name: Self) -> Vec<u8> {
let mut result: Vec<u8> = vec![]; let mut result: Vec<u8> = vec![];
for label in name { for label in name.as_slice() {
result.push(label.len() as u8); result.push(label.len() as u8);
result.extend(label.as_bytes()); result.extend(label.as_bytes());
} }
@ -289,6 +290,19 @@ impl ToBytes for Message {
} }
} }
impl ToBytes for SoaRData {
fn to_bytes(rdata: Self) -> Vec<u8> {
let mut result = LabelString::to_bytes(rdata.mname);
result.extend(LabelString::to_bytes(rdata.rname));
result.extend(u32::to_be_bytes(rdata.serial));
result.extend(i32::to_be_bytes(rdata.refresh));
result.extend(i32::to_be_bytes(rdata.retry));
result.extend(i32::to_be_bytes(rdata.expire));
result.extend(u32::to_be_bytes(rdata.minimum));
result
}
}
#[cfg(test)] #[cfg(test)]
pub mod tests { pub mod tests {
use crate::test_utils::{get_message, get_rr}; use crate::test_utils::{get_message, get_rr};
@ -315,7 +329,7 @@ pub mod tests {
#[test] #[test]
fn test_parse_question() { fn test_parse_question() {
let question = Question { let question = Question {
qname: vec![String::from("example"), String::from("org")], qname: LabelString::from("example.org"),
qtype: Type::Type(RRType::A), qtype: Type::Type(RRType::A),
qclass: Class::Class(RRClass::IN), qclass: Class::Class(RRClass::IN),
}; };
@ -338,7 +352,7 @@ pub mod tests {
#[test] #[test]
fn test_labelstring() { fn test_labelstring() {
let labelstring = vec![String::from("example"), String::from("org")]; let labelstring: LabelString = vec![String::from("example"), String::from("org")].into();
let bytes = LabelString::to_bytes(labelstring.clone()); let bytes = LabelString::to_bytes(labelstring.clone());
let parsed = LabelString::from_bytes(&mut Reader::new(&bytes)); let parsed = LabelString::from_bytes(&mut Reader::new(&bytes));
@ -348,7 +362,7 @@ pub mod tests {
#[test] #[test]
fn test_labelstring_ptr() { fn test_labelstring_ptr() {
let labelstring = vec![String::from("example"), String::from("org")]; let labelstring: LabelString = vec![String::from("example"), String::from("org")].into();
let mut bytes = LabelString::to_bytes(labelstring.clone()); let mut bytes = LabelString::to_bytes(labelstring.clone());
@ -370,7 +384,7 @@ pub mod tests {
#[test] #[test]
fn test_labelstring_invalid_ptr() { fn test_labelstring_invalid_ptr() {
let labelstring = vec![String::from("example"), String::from("org")]; let labelstring: LabelString = vec![String::from("example"), String::from("org")].into();
let mut bytes = LabelString::to_bytes(labelstring.clone()); let mut bytes = LabelString::to_bytes(labelstring.clone());

View file

@ -1,5 +1,7 @@
use int_enum::IntEnum; use int_enum::IntEnum;
use crate::labelstring::LabelString;
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq)]
pub enum Type { pub enum Type {
Type(RRType), Type(RRType),
@ -89,4 +91,12 @@ pub struct RR {
pub rdata: Vec<u8>, pub rdata: Vec<u8>,
} }
pub type LabelString = Vec<String>; pub struct SoaRData {
pub mname: LabelString,
pub rname: LabelString,
pub serial: u32,
pub refresh: i32,
pub retry: i32,
pub expire: i32,
pub minimum: u32,
}

View file

@ -2,9 +2,10 @@
use crate::structs::*; use crate::structs::*;
#[cfg(feature = "test-utils")] #[cfg(feature = "test-utils")]
use crate::labelstring::LabelString;
pub fn get_rr(name: Option<LabelString>) -> RR { pub fn get_rr(name: Option<LabelString>) -> RR {
RR { RR {
name: name.unwrap_or(vec![String::from("example"), String::from("org")]), name: name.unwrap_or(LabelString::from("example.org")),
_type: Type::Type(RRType::A), _type: Type::Type(RRType::A),
class: Class::Class(RRClass::IN), class: Class::Class(RRClass::IN),
ttl: 10, ttl: 10,
@ -25,16 +26,12 @@ pub fn get_message(name: Option<LabelString>) -> Message {
}, },
question: vec![ question: vec![
Question { Question {
qname: name qname: name.clone().unwrap_or(LabelString::from("example.org")),
.clone()
.unwrap_or(vec![String::from("example"), String::from("org")]),
qtype: Type::Type(RRType::A), qtype: Type::Type(RRType::A),
qclass: Class::Class(RRClass::IN), qclass: Class::Class(RRClass::IN),
}, },
Question { Question {
qname: name qname: name.clone().unwrap_or(LabelString::from("example.org")),
.clone()
.unwrap_or(vec![String::from("example"), String::from("org")]),
qtype: Type::Type(RRType::A), qtype: Type::Type(RRType::A),
qclass: Class::Class(RRClass::IN), qclass: Class::Class(RRClass::IN),
}, },

View file

@ -1,34 +0,0 @@
use crate::structs::LabelString;
pub fn labels_equal(vec1: &LabelString, vec2: &LabelString) -> bool {
if vec1.len() != vec2.len() {
return false;
}
for (elem1, elem2) in vec1.iter().zip(vec2.iter()) {
if elem1.to_lowercase() != elem2.to_lowercase() {
return false;
}
}
true
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_labels_equal() {
assert!(labels_equal(
&vec![String::from("one"), String::from("two")],
&vec![String::from("oNE"), String::from("two")]
));
assert!(!labels_equal(
&vec![String::from("one"), String::from("two")],
&vec![String::from("oNEe"), String::from("two")]
));
}
}