diff --git a/src/db/models.rs b/src/db/models.rs index ec3ddd3..c96545a 100644 --- a/src/db/models.rs +++ b/src/db/models.rs @@ -1,6 +1,6 @@ use crate::{ errors::DatabaseError, - structs::{Class, Question, Type, RR}, + structs::{Class, Type, RR}, }; use diesel::prelude::*; @@ -99,17 +99,18 @@ pub async fn insert_into_database(rr: &RR) -> Result<(), DatabaseError> { Ok(()) } -pub async fn get_from_database(question: &Question) -> Result, DatabaseError> { +pub async fn get_from_database( + name: &Vec, + _type: Type, + class: Class, +) -> Result, DatabaseError> { let db_connection = &mut establish_connection(); - let records = Record::get( - db_connection, - question.qname.join("."), - question.qtype.clone().into(), - question.qclass.clone().into(), - ) - .map_err(|e| DatabaseError { - message: e.to_string(), - })?; + let records = + Record::get(db_connection, name.join("."), _type.into(), class.into()).map_err(|e| { + DatabaseError { + message: e.to_string(), + } + })?; Ok(records .into_iter() diff --git a/src/errors.rs b/src/errors.rs index ac01e6b..42b450f 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -4,7 +4,7 @@ use crate::structs::RCODE; pub struct DNSError { pub message: String, - pub rcode: RCODE + pub rcode: RCODE, } impl fmt::Display for DNSError { @@ -27,7 +27,7 @@ impl fmt::Display for ParseError { #[derive(Debug)] pub struct DatabaseError { - pub message: String + pub message: String, } impl fmt::Display for DatabaseError { @@ -77,7 +77,24 @@ where fn from(value: E) -> Self { DNSError { message: value.into().to_string(), - rcode: RCODE::FORMERR + rcode: RCODE::FORMERR, + } + } +} + +trait Supported {} + +impl Supported for reqwest::Error {} +impl Supported for DatabaseError {} + +impl From for AuthenticationError +where + E: Supported, + E: std::fmt::Display, +{ + fn from(value: E) -> Self { + AuthenticationError { + message: value.to_string(), } } } diff --git a/src/handlers/query.rs b/src/handlers/query.rs index e4f8f70..03c3ef0 100644 --- a/src/handlers/query.rs +++ b/src/handlers/query.rs @@ -14,7 +14,12 @@ impl ResponseHandler for QueryHandler { response.header.arcount = 0; //TODO: fix this, handle unknown class values for question in &message.question { - let answers = get_from_database(&question).await; + let answers = get_from_database( + &question.qname, + question.qtype.clone(), + question.qclass.clone(), + ) + .await; match answers { Ok(rrs) => { diff --git a/src/handlers/update/authenticate.rs b/src/handlers/update/authenticate.rs index ed68bb1..7b3cd5d 100644 --- a/src/handlers/update/authenticate.rs +++ b/src/handlers/update/authenticate.rs @@ -1,33 +1,30 @@ use reqwest::Error; -use crate::{config::Config, errors::AuthenticationError}; +use crate::{ + config::Config, + db::models::get_from_database, + errors::{AuthenticationError, DatabaseError}, + parser::FromBytes, + reader::Reader, + structs::{Class, RRClass, RRType, Type}, +}; -use super::sig::{PublicKey, Sig}; +use super::{dnskey::DNSKeyRData, sig::Sig}; -type SSHKeys = Vec; - -type Result = std::result::Result; - -pub(super) async fn authenticate(sig: &Sig, zone: &Vec) -> Result { +pub(super) async fn authenticate( + sig: &Sig, + zone: &Vec, +) -> Result { if zone.len() >= 4 { let username = &zone[zone.len() - 4]; // Should match: username.users.zeus.gent - let public_keys = get_keys(username).await.map_err(|e| AuthenticationError { - message: e.to_string(), - })?; - Ok(public_keys.iter().any(|key| { - let key_split: Vec<&str> = key.split_ascii_whitespace().collect(); - match key_split.len() { - 3 => { - let key_encoded = key_split[1].to_string(); - match key_split[0] { - "ssh-ed25519" => sig.verify(PublicKey::ED25519(key_encoded)), - _ => false, - } - } - _ => false, - } - })) + let ssh_verified = validate_ssh(username, sig).await?; + + if ssh_verified { + Ok(true) + } else { + Ok(validate_dnskey(zone, sig).await?) + } } else { Err(AuthenticationError { message: String::from("Invalid zone"), @@ -35,13 +32,38 @@ pub(super) async fn authenticate(sig: &Sig, zone: &Vec) -> Result } } -async fn get_keys(username: &String) -> std::result::Result { +async fn validate_ssh(username: &String, sig: &Sig) -> Result { Ok(reqwest::get(format!( "{}/users/keys/{}", Config::get().zauth_url, username )) .await? - .json::() - .await?) + .json::>() + .await? + .iter() + .any(|key| { + let key_split: Vec<&str> = key.split_ascii_whitespace().collect(); + match key_split.len() { + 3 => match key_split[0] { + "ssh-ed25519" => sig.verify_ssh_ed25519(key_split[1]), + _ => false, + }, + _ => false, + } + })) +} + +async fn validate_dnskey(zone: &Vec, sig: &Sig) -> Result { + Ok( + get_from_database(zone, Type::Type(RRType::DNSKEY), Class::Class(RRClass::IN)) + .await? + .iter() + .any(|rr| { + let mut reader = Reader::new(&rr.rdata); + DNSKeyRData::from_bytes(&mut reader) + .map(|key| key.verify(sig)) + .is_ok_and(|b| b) + }), + ) } diff --git a/src/handlers/update/dnskey.rs b/src/handlers/update/dnskey.rs new file mode 100644 index 0000000..01dc1ad --- /dev/null +++ b/src/handlers/update/dnskey.rs @@ -0,0 +1,36 @@ +use base64::prelude::*; + +use crate::{errors::ParseError, parser::FromBytes, reader::Reader}; + +use super::sig::Sig; + +/// https://datatracker.ietf.org/doc/html/rfc4034#section-2 +#[derive(Debug)] +pub(super) struct DNSKeyRData { + pub flags: u16, + pub protocol: u8, + pub algorithm: u8, + pub public_key: Vec, +} + +//TODO: validate values +impl FromBytes for DNSKeyRData { + fn from_bytes(reader: &mut Reader) -> Result { + Ok(DNSKeyRData { + flags: reader.read_u16()?, + protocol: reader.read_u8()?, + algorithm: reader.read_u8()?, + public_key: reader.read(reader.unread_bytes())?, + }) + } +} + +impl DNSKeyRData { + pub fn verify(&self, sig: &Sig) -> bool { + let encoded = BASE64_STANDARD.encode(&self.public_key); + match self.algorithm { + 15 => sig.verify_ed25519(&encoded), + _ => false, + } + } +} diff --git a/src/handlers/update/mod.rs b/src/handlers/update/mod.rs index 41f9707..b90f904 100644 --- a/src/handlers/update/mod.rs +++ b/src/handlers/update/mod.rs @@ -10,6 +10,7 @@ use self::sig::Sig; use super::ResponseHandler; mod authenticate; +mod dnskey; mod sig; pub(super) struct UpdateHandler {} diff --git a/src/handlers/update/sig.rs b/src/handlers/update/sig.rs index f5f19ad..ae0ddca 100644 --- a/src/handlers/update/sig.rs +++ b/src/handlers/update/sig.rs @@ -12,10 +12,6 @@ pub(super) struct Sig { key_rdata: KeyRData, } -pub(super) enum PublicKey { - ED25519(String), -} - impl Sig { pub fn new(rr: &RR, datagram: &[u8]) -> Result { let mut request = datagram[0..datagram.len() - 11 - rr.rdlength as usize].to_vec(); @@ -33,7 +29,16 @@ impl Sig { }) } - fn verify_ed25519(&self, key: String) -> bool { + pub fn verify_ed25519(&self, key: &str) -> bool { + let blob = BASE64_STANDARD.decode(key).unwrap(); + + let pkey = ring::signature::UnparsedPublicKey::new(&ring::signature::ED25519, &blob); + + pkey.verify(&self.raw_data, &self.key_rdata.signature) + .is_ok() + } + + pub fn verify_ssh_ed25519(&self, key: &str) -> bool { let blob = BASE64_STANDARD.decode(key).unwrap(); let pkey = ring::signature::UnparsedPublicKey::new( @@ -44,10 +49,4 @@ impl Sig { pkey.verify(&self.raw_data, &self.key_rdata.signature) .is_ok() } - - pub fn verify(&self, key: PublicKey) -> bool { - match key { - PublicKey::ED25519(pkey) => self.verify_ed25519(pkey), - } - } } diff --git a/src/parser.rs b/src/parser.rs index dba7b15..6c8341e 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -4,8 +4,7 @@ use crate::{ errors::ParseError, reader::Reader, structs::{ - Class, DNSKeyRData, Header, KeyRData, LabelString, Message, Opcode, Question, RRClass, - RRType, Type, RR, + Class, Header, KeyRData, LabelString, Message, Opcode, Question, RRClass, RRType, Type, RR, }, }; @@ -43,12 +42,14 @@ impl From for u16 { impl From for Type { fn from(value: u16) -> Self { + //TODO: use macro match value { x if x == RRType::A as u16 => Type::Type(RRType::A), x if x == RRType::OPT as u16 => Type::Type(RRType::OPT), x if x == RRType::SOA as u16 => Type::Type(RRType::SOA), x if x == RRType::ANY as u16 => Type::Type(RRType::SOA), x if x == RRType::KEY as u16 => Type::Type(RRType::KEY), + x if x == RRType::DNSKEY as u16 => Type::Type(RRType::DNSKEY), x => Type::Other(x), } } @@ -327,21 +328,3 @@ impl FromBytes for KeyRData { } } } - -impl FromBytes for DNSKeyRData { - fn from_bytes(reader: &mut Reader) -> Result { - if reader.unread_bytes() < 18 { - Err(ParseError { - object: String::from("DNSKeyRData"), - message: String::from("invalid rdata"), - }) - } else { - Ok(DNSKeyRData { - flags: reader.read_u16()?, - protocol: reader.read_u8()?, - algorithm: reader.read_u8()?, - public_key: reader.read(reader.unread_bytes())?, - }) - } - } -} diff --git a/src/structs.rs b/src/structs.rs index 10a7e38..58cf94b 100644 --- a/src/structs.rs +++ b/src/structs.rs @@ -9,7 +9,8 @@ pub enum Type { pub enum RRType { A = 1, SOA = 6, - KEY = 24, + KEY = 24, //TODO: change to SIG + DNSKEY = 48, OPT = 41, ANY = 255, } @@ -106,12 +107,3 @@ pub struct KeyRData { pub signer: LabelString, pub signature: Vec, } - -/// https://datatracker.ietf.org/doc/html/rfc4034#section-2 -#[derive(Debug)] -pub struct DNSKeyRData { - pub flags: u16, - pub protocol: u8, - pub algorithm: u8, - pub public_key: Vec, -} diff --git a/src/utils.rs b/src/utils.rs index 785fb20..2fa2923 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,4 +1,3 @@ - pub fn vec_equal(vec1: &[T], vec2: &[T]) -> bool { if vec1.len() != vec2.len() { return false;