10
0
Fork 0
mirror of https://github.com/ZeusWPI/ZNS.git synced 2024-11-24 06:11:10 +01:00

Convert errors using thiserror crate

This commit is contained in:
Xander Bil 2024-06-26 00:22:41 +02:00
parent 6ac9f2f36e
commit aa94dc21bb
No known key found for this signature in database
GPG key ID: EC9706B54A278598
16 changed files with 143 additions and 212 deletions

21
Cargo.lock generated
View file

@ -943,6 +943,26 @@ dependencies = [
"windows-sys 0.52.0", "windows-sys 0.52.0",
] ]
[[package]]
name = "thiserror"
version = "1.0.61"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c546c80d6be4bc6a00c0f01730c08df82eaa7a7a61f11d656526506112cc1709"
dependencies = [
"thiserror-impl",
]
[[package]]
name = "thiserror-impl"
version = "1.0.61"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "46c3384250002a6d5af4d114f2845d37b57521033f30d5c3f46c4d70e1197533"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]] [[package]]
name = "tinyvec" name = "tinyvec"
version = "1.6.0" version = "1.6.0"
@ -1361,5 +1381,6 @@ dependencies = [
"int-enum", "int-enum",
"reqwest", "reqwest",
"ring", "ring",
"thiserror",
"tokio", "tokio",
] ]

View file

@ -12,3 +12,4 @@ base64 = "0.22.0"
reqwest = {version = "0.12.4", features = ["json","default"]} reqwest = {version = "0.12.4", features = ["json","default"]}
asn1 = "0.16.2" asn1 = "0.16.2"
int-enum = "1.1" int-enum = "1.1"
thiserror = "1.0"

View file

@ -1,5 +1,5 @@
use crate::{ use crate::{
errors::DatabaseError, errors::ZNSError,
structs::{Class, Type, RR}, structs::{Class, Type, RR},
}; };
use diesel::prelude::*; use diesel::prelude::*;
@ -81,7 +81,7 @@ impl Record {
} }
} }
pub async fn insert_into_database(rr: &RR) -> Result<(), DatabaseError> { pub async fn insert_into_database(rr: &RR) -> Result<(), ZNSError> {
let db_connection = &mut establish_connection(); let db_connection = &mut establish_connection();
let record = Record { let record = Record {
name: rr.name.join("."), name: rr.name.join("."),
@ -92,7 +92,7 @@ pub async fn insert_into_database(rr: &RR) -> Result<(), DatabaseError> {
rdata: rr.rdata.clone(), rdata: rr.rdata.clone(),
}; };
Record::create(db_connection, record).map_err(|e| DatabaseError { Record::create(db_connection, record).map_err(|e| ZNSError::Database {
message: e.to_string(), message: e.to_string(),
})?; })?;
@ -103,11 +103,11 @@ pub async fn get_from_database(
name: &Vec<String>, name: &Vec<String>,
_type: Type, _type: Type,
class: Class, class: Class,
) -> Result<Vec<RR>, DatabaseError> { ) -> Result<Vec<RR>, ZNSError> {
let db_connection = &mut establish_connection(); let db_connection = &mut establish_connection();
let records = let records =
Record::get(db_connection, name.join("."), _type.into(), class.into()).map_err(|e| { Record::get(db_connection, name.join("."), _type.into(), class.into()).map_err(|e| {
DatabaseError { ZNSError::Database {
message: e.to_string(), message: e.to_string(),
} }
})?; })?;

View file

@ -1,100 +1,44 @@
use core::fmt; use thiserror::Error;
use crate::structs::RCODE; use crate::structs::RCODE;
pub struct DNSError { #[derive(Error, Debug)]
pub message: String, pub enum ZNSError {
pub rcode: RCODE, #[error("Parse Error for {object:?}: {message:?}")]
Parse { object: String, message: String },
#[error("Database Error: {message:?}")]
Database { message: String },
#[error("Reader Error: {message:?}")]
Reader { message: String },
#[error("PublicKey Error: {message:?}")]
PublicKey { message: String },
#[error("Reqwest error")]
Reqwest(#[from] reqwest::Error),
#[error("DNS Query Format Error: {message:?}")]
Formerr { message: String },
#[error("Domain name does not exist")]
NXDomain { domain: String },
#[error("NotImplemented Error for {object:?}: {message:?}")]
NotImp { object: String, message: String },
#[error("Authentication Error: {message:?}")]
NotAuth { message: String },
#[error("I refuse to answer the query: {message:?}")]
Refused { message: String },
} }
impl fmt::Display for DNSError { impl ZNSError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { pub fn rcode(&self) -> RCODE {
write!(f, "Error: {}", self.message) match self {
ZNSError::Formerr { .. } | ZNSError::Parse { .. } | ZNSError::Reader { .. } => {
RCODE::FORMERR
} }
} ZNSError::Database { .. } | ZNSError::Reqwest(_) => RCODE::SERVFAIL,
#[derive(Debug)] ZNSError::NotAuth { .. } | ZNSError::PublicKey { .. } => RCODE::NOTAUTH,
pub struct ParseError { ZNSError::NXDomain { .. } => RCODE::NXDOMAIN,
pub object: String, ZNSError::NotImp { .. } => RCODE::NOTIMP,
pub message: String, ZNSError::Refused { .. } => RCODE::REFUSED,
}
impl fmt::Display for ParseError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Parse Error for {}: {}", self.object, self.message)
}
}
#[derive(Debug)]
pub struct DatabaseError {
pub message: String,
}
impl fmt::Display for DatabaseError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Database Error: {}", self.message)
}
}
#[derive(Debug)]
pub struct AuthenticationError {
pub message: String,
}
impl fmt::Display for AuthenticationError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Authentication Error: {}", self.message)
}
}
#[derive(Debug)]
pub struct ReaderError {
pub message: String,
}
impl fmt::Display for ReaderError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Reader Error: {}", self.message)
}
}
impl<E> From<E> for ParseError
where
E: Into<ReaderError>,
{
fn from(value: E) -> Self {
ParseError {
object: String::from("Reader"),
message: value.into().to_string(),
}
}
}
impl<E> From<E> for DNSError
where
E: Into<ParseError>,
{
fn from(value: E) -> Self {
DNSError {
message: value.into().to_string(),
rcode: RCODE::FORMERR,
}
}
}
trait Supported {}
impl Supported for reqwest::Error {}
impl Supported for DatabaseError {}
impl<E> From<E> for AuthenticationError
where
E: Supported,
E: std::fmt::Display,
{
fn from(value: E) -> Self {
AuthenticationError {
message: value.to_string(),
} }
} }
} }

View file

@ -1,6 +1,6 @@
use crate::{ use crate::{
errors::DNSError, errors::ZNSError,
structs::{Message, Opcode, RCODE}, structs::{Message, Opcode},
}; };
use self::{query::QueryHandler, update::UpdateHandler}; use self::{query::QueryHandler, update::UpdateHandler};
@ -9,21 +9,20 @@ mod query;
mod update; mod update;
pub trait ResponseHandler { pub trait ResponseHandler {
async fn handle(message: &Message, raw: &[u8]) -> Result<Message, DNSError>; async fn handle(message: &Message, raw: &[u8]) -> Result<Message, ZNSError>;
} }
pub struct Handler {} pub struct Handler {}
impl ResponseHandler for Handler { impl ResponseHandler for Handler {
async fn handle(message: &Message, raw: &[u8]) -> Result<Message, DNSError> { async fn handle(message: &Message, raw: &[u8]) -> Result<Message, ZNSError> {
match message.get_opcode() { match message.get_opcode() {
Ok(opcode) => match opcode { Ok(opcode) => match opcode {
Opcode::QUERY => QueryHandler::handle(&message, raw).await, Opcode::QUERY => QueryHandler::handle(&message, raw).await,
Opcode::UPDATE => UpdateHandler::handle(&message, raw).await, Opcode::UPDATE => UpdateHandler::handle(&message, raw).await,
}, },
Err(e) => Err(DNSError { Err(e) => Err(ZNSError::Formerr {
message: e.to_string(), message: e.to_string(),
rcode: RCODE::FORMERR,
}), }),
} }
} }

View file

@ -1,15 +1,11 @@
use crate::{ use crate::{db::models::get_from_database, errors::ZNSError, structs::Message};
db::models::get_from_database,
errors::DNSError,
structs::{Message, RCODE},
};
use super::ResponseHandler; use super::ResponseHandler;
pub struct QueryHandler {} pub struct QueryHandler {}
impl ResponseHandler for QueryHandler { impl ResponseHandler for QueryHandler {
async fn handle(message: &Message, _raw: &[u8]) -> Result<Message, DNSError> { async fn handle(message: &Message, _raw: &[u8]) -> Result<Message, ZNSError> {
let mut response = message.clone(); let mut response = message.clone();
response.header.arcount = 0; //TODO: fix this, handle unknown class values response.header.arcount = 0; //TODO: fix this, handle unknown class values
@ -23,12 +19,16 @@ impl ResponseHandler for QueryHandler {
match answers { match answers {
Ok(rrs) => { Ok(rrs) => {
if rrs.len() == 0 {
return Err(ZNSError::NXDomain {
domain: question.qname.join("."),
});
}
response.header.ancount = rrs.len() as u16; response.header.ancount = rrs.len() as u16;
response.answer.extend(rrs) response.answer.extend(rrs)
} }
Err(e) => { Err(e) => {
return Err(DNSError { return Err(ZNSError::Database {
rcode: RCODE::NXDOMAIN,
message: e.to_string(), message: e.to_string(),
}) })
} }

View file

@ -1,15 +1,15 @@
use crate::{ use crate::{
config::Config, config::Config,
db::models::get_from_database, db::models::get_from_database,
errors::{AuthenticationError, DatabaseError}, errors::ZNSError,
parser::FromBytes, parser::FromBytes,
reader::Reader, reader::Reader,
structs::{Class, RRClass, RRType, Type}, structs::{Class, RRClass, RRType, Type},
}; };
use super::{dnskey::DNSKeyRData, pubkeys::PublicKeyError, sig::Sig}; use super::{dnskey::DNSKeyRData, sig::Sig};
pub async fn authenticate(sig: &Sig, zone: &Vec<String>) -> Result<bool, AuthenticationError> { pub async fn authenticate(sig: &Sig, zone: &Vec<String>) -> Result<bool, ZNSError> {
if zone.len() >= 4 { if zone.len() >= 4 {
let username = &zone[zone.len() - 4]; // Should match: username.users.zeus.gent let username = &zone[zone.len() - 4]; // Should match: username.users.zeus.gent
@ -21,7 +21,7 @@ pub async fn authenticate(sig: &Sig, zone: &Vec<String>) -> Result<bool, Authent
Ok(validate_dnskey(zone, sig).await?) Ok(validate_dnskey(zone, sig).await?)
} }
} else { } else {
Err(AuthenticationError { Err(ZNSError::NotAuth {
message: String::from("Invalid zone"), message: String::from("Invalid zone"),
}) })
} }
@ -40,7 +40,7 @@ async fn validate_ssh(username: &String, sig: &Sig) -> Result<bool, reqwest::Err
.any(|key| sig.verify_ssh(&key).is_ok_and(|b| b))) .any(|key| sig.verify_ssh(&key).is_ok_and(|b| b)))
} }
async fn validate_dnskey(zone: &Vec<String>, sig: &Sig) -> Result<bool, DatabaseError> { async fn validate_dnskey(zone: &Vec<String>, sig: &Sig) -> Result<bool, ZNSError> {
Ok( Ok(
get_from_database(zone, Type::Type(RRType::DNSKEY), Class::Class(RRClass::IN)) get_from_database(zone, Type::Type(RRType::DNSKEY), Class::Class(RRClass::IN))
.await? .await?
@ -52,11 +52,3 @@ async fn validate_dnskey(zone: &Vec<String>, sig: &Sig) -> Result<bool, Database
}), }),
) )
} }
impl From<PublicKeyError> for AuthenticationError {
fn from(value: PublicKeyError) -> Self {
AuthenticationError {
message: value.to_string(),
}
}
}

View file

@ -1,4 +1,4 @@
use crate::{errors::ParseError, parser::FromBytes, reader::Reader}; use crate::{errors::ZNSError, parser::FromBytes, reader::Reader};
use super::sig::Algorithm; use super::sig::Algorithm;
@ -13,7 +13,7 @@ pub struct DNSKeyRData {
//TODO: validate values //TODO: validate values
impl FromBytes for DNSKeyRData { impl FromBytes for DNSKeyRData {
fn from_bytes(reader: &mut Reader) -> Result<Self, ParseError> { fn from_bytes(reader: &mut Reader) -> Result<Self, ZNSError> {
Ok(DNSKeyRData { Ok(DNSKeyRData {
flags: reader.read_u16()?, flags: reader.read_u16()?,
protocol: reader.read_u8()?, protocol: reader.read_u8()?,

View file

@ -1,7 +1,7 @@
use crate::{ use crate::{
db::models::{delete_from_database, insert_into_database}, db::models::{delete_from_database, insert_into_database},
errors::DNSError, errors::ZNSError,
structs::{Class, Message, RRClass, RRType, Type, RCODE}, structs::{Class, Message, RRClass, RRType, Type},
utils::vec_equal, utils::vec_equal,
}; };
@ -17,15 +17,14 @@ mod sig;
pub struct UpdateHandler {} pub struct UpdateHandler {}
impl ResponseHandler for UpdateHandler { impl ResponseHandler for UpdateHandler {
async fn handle(message: &Message, raw: &[u8]) -> Result<Message, crate::errors::DNSError> { async fn handle(message: &Message, raw: &[u8]) -> Result<Message, ZNSError> {
let response = message.clone(); let response = message.clone();
// Zone section (question) processing // Zone section (question) processing
if (message.header.qdcount != 1) if (message.header.qdcount != 1)
|| !matches!(message.question[0].qtype, Type::Type(RRType::SOA)) || !matches!(message.question[0].qtype, Type::Type(RRType::SOA))
{ {
return Err(DNSError { return Err(ZNSError::Formerr {
message: "Qdcount not one".to_string(), message: "Qdcount not one".to_string(),
rcode: RCODE::FORMERR,
}); });
} }
@ -33,9 +32,8 @@ impl ResponseHandler for UpdateHandler {
let zone = &message.question[0]; let zone = &message.question[0];
let zlen = zone.qname.len(); let zlen = zone.qname.len();
if !(zlen >= 2 && zone.qname[zlen - 1] == "gent" && zone.qname[zlen - 2] == "zeus") { if !(zlen >= 2 && zone.qname[zlen - 1] == "gent" && zone.qname[zlen - 2] == "zeus") {
return Err(DNSError { return Err(ZNSError::Formerr {
message: "Invalid zone".to_string(), message: "Invalid zone".to_string(),
rcode: RCODE::NOTAUTH,
}); });
} }
@ -50,15 +48,13 @@ impl ResponseHandler for UpdateHandler {
.await .await
.is_ok_and(|x| x) .is_ok_and(|x| x)
{ {
return Err(DNSError { return Err(ZNSError::NotAuth {
message: "Unable to verify authentication".to_string(), message: "Unable to verify authentication".to_string(),
rcode: RCODE::NOTAUTH,
}); });
} }
} else { } else {
return Err(DNSError { return Err(ZNSError::NotAuth {
message: "No KEY record at the end of request found".to_string(), message: "No KEY record at the end of request found".to_string(),
rcode: RCODE::NOTAUTH,
}); });
} }
@ -68,9 +64,8 @@ impl ResponseHandler for UpdateHandler {
// Check if rr has same zone // Check if rr has same zone
if rlen < zlen || !(vec_equal(&zone.qname, &rr.name[rlen - zlen..])) { if rlen < zlen || !(vec_equal(&zone.qname, &rr.name[rlen - zlen..])) {
return Err(DNSError { return Err(ZNSError::Refused {
message: "RR has different zone from Question".to_string(), message: "RR has different zone from Question".to_string(),
rcode: RCODE::NOTZONE,
}); });
} }
@ -84,9 +79,8 @@ impl ResponseHandler for UpdateHandler {
.contains(&rr.class) .contains(&rr.class)
{ {
true => { true => {
return Err(DNSError { return Err(ZNSError::Formerr {
message: "RR has invalid rr,ttl or class".to_string(), message: "RR has invalid rr,ttl or class".to_string(),
rcode: RCODE::FORMERR,
}); });
} }
false => (), false => (),
@ -99,9 +93,9 @@ impl ResponseHandler for UpdateHandler {
} else if rr.class == Class::Class(RRClass::ANY) { } else if rr.class == Class::Class(RRClass::ANY) {
if rr._type == Type::Type(RRType::ANY) { if rr._type == Type::Type(RRType::ANY) {
if rr.name == zone.qname { if rr.name == zone.qname {
return Err(DNSError { return Err(ZNSError::NotImp {
message: "Not yet implemented".to_string(), object: String::from("Update Handler"),
rcode: RCODE::NOTIMP, message: "rr.name == zone.qname".to_string(),
}); });
} else { } else {
delete_from_database(&rr.name, None, Class::Class(RRClass::IN), None).await; delete_from_database(&rr.name, None, Class::Class(RRClass::IN), None).await;

View file

@ -1,15 +1,15 @@
use ring::signature; use ring::signature;
use crate::{handlers::update::sig::Algorithm, reader::Reader}; use crate::{errors::ZNSError, handlers::update::sig::Algorithm, reader::Reader};
use super::{PublicKey, PublicKeyError, SSH_ED25519}; use super::{PublicKey, SSH_ED25519};
pub struct Ed25519PublicKey { pub struct Ed25519PublicKey {
data: Vec<u8>, data: Vec<u8>,
} }
impl PublicKey for Ed25519PublicKey { impl PublicKey for Ed25519PublicKey {
fn from_openssh(key: &[u8]) -> Result<Self, PublicKeyError> fn from_openssh(key: &[u8]) -> Result<Self, ZNSError>
where where
Self: Sized, Self: Sized,
{ {
@ -21,7 +21,7 @@ impl PublicKey for Ed25519PublicKey {
}) })
} }
fn from_dnskey(key: &[u8]) -> Result<Self, PublicKeyError> fn from_dnskey(key: &[u8]) -> Result<Self, ZNSError>
where where
Self: Sized, Self: Sized,
{ {
@ -33,7 +33,7 @@ impl PublicKey for Ed25519PublicKey {
data: &[u8], data: &[u8],
signature: &[u8], signature: &[u8],
_algorithm: &Algorithm, _algorithm: &Algorithm,
) -> Result<bool, PublicKeyError> { ) -> Result<bool, ZNSError> {
let pkey = ring::signature::UnparsedPublicKey::new(&signature::ED25519, &self.data); let pkey = ring::signature::UnparsedPublicKey::new(&signature::ED25519, &self.data);
Ok(pkey.verify(data, signature).is_ok()) Ok(pkey.verify(data, signature).is_ok())

View file

@ -1,42 +1,23 @@
mod ed25519; mod ed25519;
mod rsa; mod rsa;
use core::fmt;
use std::str::from_utf8; use std::str::from_utf8;
use crate::{errors::ReaderError, reader::Reader}; use crate::errors::ZNSError;
use crate::reader::Reader;
pub use self::ed25519::Ed25519PublicKey; pub use self::ed25519::Ed25519PublicKey;
pub use self::rsa::RsaPublicKey; pub use self::rsa::RsaPublicKey;
use super::sig::Algorithm; use super::sig::Algorithm;
#[derive(Debug)]
pub struct PublicKeyError {
pub message: String,
}
impl fmt::Display for PublicKeyError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Public Key Error: {}", self.message)
}
}
impl From<ReaderError> for PublicKeyError {
fn from(value: ReaderError) -> Self {
PublicKeyError {
message: value.to_string(),
}
}
}
pub const SSH_ED25519: &str = "ssh-ed25519"; pub const SSH_ED25519: &str = "ssh-ed25519";
pub const SSH_RSA: &str = "ssh-rsa"; pub const SSH_RSA: &str = "ssh-rsa";
pub trait PublicKey { pub trait PublicKey {
fn verify_ssh_type(reader: &mut Reader, key_type: &str) -> Result<(), PublicKeyError> { fn verify_ssh_type(reader: &mut Reader, key_type: &str) -> Result<(), ZNSError> {
let type_size = reader.read_i32()?; let type_size = reader.read_i32()?;
let read = reader.read(type_size as usize)?; let read = reader.read(type_size as usize)?;
let algo_type = from_utf8(&read).map_err(|e| PublicKeyError { let algo_type = from_utf8(&read).map_err(|e| ZNSError::PublicKey {
message: format!( message: format!(
"Could not convert type name bytes to string: {}", "Could not convert type name bytes to string: {}",
e.to_string() e.to_string()
@ -46,17 +27,17 @@ pub trait PublicKey {
if algo_type == key_type { if algo_type == key_type {
Ok(()) Ok(())
} else { } else {
Err(PublicKeyError { Err(ZNSError::PublicKey {
message: String::from("ssh key type does not match identifier"), message: String::from("ssh key type does not match identifier"),
}) })
} }
} }
fn from_openssh(key: &[u8]) -> Result<Self, PublicKeyError> fn from_openssh(key: &[u8]) -> Result<Self, ZNSError>
where where
Self: Sized; Self: Sized;
fn from_dnskey(key: &[u8]) -> Result<Self, PublicKeyError> fn from_dnskey(key: &[u8]) -> Result<Self, ZNSError>
where where
Self: Sized; Self: Sized;
@ -65,5 +46,5 @@ pub trait PublicKey {
data: &[u8], data: &[u8],
signature: &[u8], signature: &[u8],
algorithm: &Algorithm, algorithm: &Algorithm,
) -> Result<bool, PublicKeyError>; ) -> Result<bool, ZNSError>;
} }

View file

@ -1,8 +1,8 @@
use ring::signature; use ring::signature;
use crate::{handlers::update::sig::Algorithm, reader::Reader}; use crate::{errors::ZNSError, handlers::update::sig::Algorithm, reader::Reader};
use super::{PublicKey, PublicKeyError, SSH_RSA}; use super::{PublicKey, SSH_RSA};
pub struct RsaPublicKey { pub struct RsaPublicKey {
e: Vec<u8>, e: Vec<u8>,
@ -16,7 +16,7 @@ struct RsaAsn1<'a> {
} }
impl PublicKey for RsaPublicKey { impl PublicKey for RsaPublicKey {
fn from_openssh(key: &[u8]) -> Result<Self, PublicKeyError> fn from_openssh(key: &[u8]) -> Result<Self, ZNSError>
where where
Self: Sized, Self: Sized,
{ {
@ -34,19 +34,19 @@ impl PublicKey for RsaPublicKey {
data: &[u8], data: &[u8],
signature: &[u8], signature: &[u8],
algorithm: &Algorithm, algorithm: &Algorithm,
) -> Result<bool, PublicKeyError> { ) -> Result<bool, ZNSError> {
let result = asn1::write_single(&RsaAsn1 { let result = asn1::write_single(&RsaAsn1 {
n: asn1::BigInt::new(&self.n), n: asn1::BigInt::new(&self.n),
e: asn1::BigInt::new(&self.e), e: asn1::BigInt::new(&self.e),
}) })
.map_err(|e| PublicKeyError { .map_err(|e| ZNSError::PublicKey {
message: format!("Verify Error: {}", e), message: format!("Verify Error: {}", e),
})?; })?;
let signature_type = match algorithm { let signature_type = match algorithm {
Algorithm::RSASHA512 => Ok(&signature::RSA_PKCS1_2048_8192_SHA512), Algorithm::RSASHA512 => Ok(&signature::RSA_PKCS1_2048_8192_SHA512),
Algorithm::RSASHA256 => Ok(&signature::RSA_PKCS1_2048_8192_SHA256), Algorithm::RSASHA256 => Ok(&signature::RSA_PKCS1_2048_8192_SHA256),
_ => Err(PublicKeyError { _ => Err(ZNSError::PublicKey {
message: format!("RsaPublicKey: invalid verify algorithm",), message: format!("RsaPublicKey: invalid verify algorithm",),
}), }),
}?; }?;
@ -56,7 +56,7 @@ impl PublicKey for RsaPublicKey {
Ok(pkey.verify(data, signature).is_ok()) Ok(pkey.verify(data, signature).is_ok())
} }
fn from_dnskey(key: &[u8]) -> Result<Self, PublicKeyError> fn from_dnskey(key: &[u8]) -> Result<Self, ZNSError>
where where
Self: Sized, Self: Sized,
{ {

View file

@ -2,7 +2,7 @@ use base64::prelude::*;
use int_enum::IntEnum; use int_enum::IntEnum;
use crate::{ use crate::{
errors::ParseError, errors::ZNSError,
parser::FromBytes, parser::FromBytes,
reader::Reader, reader::Reader,
structs::{LabelString, RR}, structs::{LabelString, RR},
@ -10,7 +10,7 @@ use crate::{
use super::{ use super::{
dnskey::DNSKeyRData, dnskey::DNSKeyRData,
pubkeys::{Ed25519PublicKey, PublicKey, PublicKeyError, RsaPublicKey, SSH_ED25519, SSH_RSA}, pubkeys::{Ed25519PublicKey, PublicKey, RsaPublicKey, SSH_ED25519, SSH_RSA},
}; };
pub struct Sig { pub struct Sig {
@ -41,9 +41,8 @@ pub enum Algorithm {
} }
impl Algorithm { impl Algorithm {
pub fn from(value: u8) -> Result<Self, ParseError> { pub fn from(value: u8) -> Result<Self, ZNSError> {
Algorithm::try_from(value).map_err(|a| ParseError { Algorithm::try_from(value).map_err(|a| ZNSError::NotImp {
// TODO: Should respond with error code refused or notimpl
object: String::from("Algorithm"), object: String::from("Algorithm"),
message: format!("Usupported algorithm: {}", a), message: format!("Usupported algorithm: {}", a),
}) })
@ -51,9 +50,9 @@ impl Algorithm {
} }
impl FromBytes for SigRData { impl FromBytes for SigRData {
fn from_bytes(reader: &mut Reader) -> Result<Self, ParseError> { fn from_bytes(reader: &mut Reader) -> Result<Self, ZNSError> {
if reader.unread_bytes() < 18 { if reader.unread_bytes() < 18 {
Err(ParseError { Err(ZNSError::Parse {
object: String::from("KeyRData"), object: String::from("KeyRData"),
message: String::from("invalid rdata"), message: String::from("invalid rdata"),
}) })
@ -74,7 +73,7 @@ impl FromBytes for SigRData {
} }
impl Sig { impl Sig {
pub fn new(rr: &RR, datagram: &[u8]) -> Result<Self, ParseError> { pub fn new(rr: &RR, datagram: &[u8]) -> Result<Self, ZNSError> {
let mut request = datagram[0..datagram.len() - 11 - rr.rdlength as usize].to_vec(); let mut request = datagram[0..datagram.len() - 11 - rr.rdlength as usize].to_vec();
request[11] -= 1; // Decrease arcount request[11] -= 1; // Decrease arcount
@ -90,7 +89,7 @@ impl Sig {
}) })
} }
fn verify(&self, key: impl PublicKey) -> Result<bool, PublicKeyError> { fn verify(&self, key: impl PublicKey) -> Result<bool, ZNSError> {
key.verify( key.verify(
&self.raw_data, &self.raw_data,
&self.key_rdata.signature, &self.key_rdata.signature,
@ -98,7 +97,7 @@ impl Sig {
) )
} }
pub fn verify_ssh(&self, key: &str) -> Result<bool, PublicKeyError> { pub fn verify_ssh(&self, key: &str) -> Result<bool, ZNSError> {
let key_split: Vec<&str> = key.split_ascii_whitespace().collect(); let key_split: Vec<&str> = key.split_ascii_whitespace().collect();
let bin = BASE64_STANDARD.decode(key_split[1]).unwrap(); let bin = BASE64_STANDARD.decode(key_split[1]).unwrap();
@ -111,7 +110,7 @@ impl Sig {
} }
} }
pub fn verify_dnskey(&self, key: DNSKeyRData) -> Result<bool, PublicKeyError> { pub fn verify_dnskey(&self, key: DNSKeyRData) -> Result<bool, ZNSError> {
if self.key_rdata.algo != key.algorithm { if self.key_rdata.algo != key.algorithm {
Ok(false) Ok(false)
} else { } else {

View file

@ -1,12 +1,12 @@
use std::{mem::size_of, vec}; use std::mem::size_of;
use crate::{ use crate::{
errors::ParseError, errors::ZNSError,
reader::Reader, reader::Reader,
structs::{Class, Header, LabelString, Message, Opcode, Question, RRClass, RRType, Type, RR}, structs::{Class, Header, LabelString, Message, Opcode, Question, RRClass, RRType, Type, RR},
}; };
type Result<T> = std::result::Result<T, ParseError>; type Result<T> = std::result::Result<T, ZNSError>;
impl From<Type> for u16 { impl From<Type> for u16 {
fn from(value: Type) -> Self { fn from(value: Type) -> Self {
@ -83,7 +83,7 @@ pub trait ToBytes {
impl FromBytes for Header { impl FromBytes for Header {
fn from_bytes(reader: &mut Reader) -> Result<Self> { fn from_bytes(reader: &mut Reader) -> Result<Self> {
if reader.unread_bytes() < size_of::<Header>() { if reader.unread_bytes() < size_of::<Header>() {
Err(ParseError { Err(ZNSError::Parse {
object: String::from("Header"), object: String::from("Header"),
message: String::from("Size of Header does not match"), message: String::from("Size of Header does not match"),
}) })
@ -124,7 +124,7 @@ impl FromBytes for LabelString {
while code != 0 && (code & 0b11000000 == 0) && reader.unread_bytes() > code as usize { while code != 0 && (code & 0b11000000 == 0) && reader.unread_bytes() > code as usize {
out.push( out.push(
String::from_utf8(reader.read(code as usize)?.to_vec()).map_err(|e| { String::from_utf8(reader.read(code as usize)?.to_vec()).map_err(|e| {
ParseError { ZNSError::Parse {
object: String::from("Label"), object: String::from("Label"),
message: e.to_string(), message: e.to_string(),
} }
@ -159,7 +159,7 @@ impl FromBytes for Question {
fn from_bytes(reader: &mut Reader) -> Result<Self> { fn from_bytes(reader: &mut Reader) -> Result<Self> {
// 16 for length octet + zero length octet // 16 for length octet + zero length octet
if reader.unread_bytes() < 2 + size_of::<Class>() + size_of::<Type>() { if reader.unread_bytes() < 2 + size_of::<Class>() + size_of::<Type>() {
Err(ParseError { Err(ZNSError::Parse {
object: String::from("Question"), object: String::from("Question"),
message: String::from("len of bytes smaller then minimum size"), message: String::from("len of bytes smaller then minimum size"),
}) })
@ -167,7 +167,7 @@ impl FromBytes for Question {
let qname = LabelString::from_bytes(reader)?; let qname = LabelString::from_bytes(reader)?;
if reader.unread_bytes() < 4 { if reader.unread_bytes() < 4 {
Err(ParseError { Err(ZNSError::Parse {
object: String::from("Question"), object: String::from("Question"),
message: String::from("len of rest bytes smaller then minimum size"), message: String::from("len of rest bytes smaller then minimum size"),
}) })
@ -201,7 +201,7 @@ impl FromBytes for RR {
fn from_bytes(reader: &mut Reader) -> Result<Self> { fn from_bytes(reader: &mut Reader) -> Result<Self> {
let name = LabelString::from_bytes(reader)?; let name = LabelString::from_bytes(reader)?;
if reader.unread_bytes() < size_of::<Type>() + size_of::<Class>() + 6 { if reader.unread_bytes() < size_of::<Type>() + size_of::<Class>() + 6 {
Err(ParseError { Err(ZNSError::Parse {
object: String::from("RR"), object: String::from("RR"),
message: String::from("len of rest of bytes smaller then minimum size"), message: String::from("len of rest of bytes smaller then minimum size"),
}) })
@ -211,7 +211,7 @@ impl FromBytes for RR {
let ttl = reader.read_i32()?; let ttl = reader.read_i32()?;
let rdlength = reader.read_u16()?; let rdlength = reader.read_u16()?;
if reader.unread_bytes() < rdlength as usize { if reader.unread_bytes() < rdlength as usize {
Err(ParseError { Err(ZNSError::Parse {
object: String::from("RR"), object: String::from("RR"),
message: String::from("len of rest of bytes not equal to rdlength"), message: String::from("len of rest of bytes not equal to rdlength"),
}) })

View file

@ -1,13 +1,13 @@
use std::array::TryFromSliceError; use std::array::TryFromSliceError;
use crate::errors::ReaderError; use crate::errors::ZNSError;
pub struct Reader<'a> { pub struct Reader<'a> {
buffer: &'a [u8], buffer: &'a [u8],
position: usize, position: usize,
} }
type Result<T> = std::result::Result<T, ReaderError>; type Result<T> = std::result::Result<T, ZNSError>;
impl<'a> Reader<'a> { impl<'a> Reader<'a> {
pub fn new(buffer: &[u8]) -> Reader { pub fn new(buffer: &[u8]) -> Reader {
@ -23,7 +23,7 @@ impl<'a> Reader<'a> {
pub fn read(&mut self, size: usize) -> Result<Vec<u8>> { pub fn read(&mut self, size: usize) -> Result<Vec<u8>> {
if size > self.unread_bytes() { if size > self.unread_bytes() {
Err(ReaderError { Err(ZNSError::Reader {
message: String::from("cannot read enough bytes"), message: String::from("cannot read enough bytes"),
}) })
} else { } else {
@ -41,7 +41,7 @@ impl<'a> Reader<'a> {
let result = u16::from_be_bytes( let result = u16::from_be_bytes(
self.buffer[self.position..self.position + 2] self.buffer[self.position..self.position + 2]
.try_into() .try_into()
.map_err(|e: TryFromSliceError| ReaderError { .map_err(|e: TryFromSliceError| ZNSError::Reader {
message: e.to_string(), message: e.to_string(),
})?, })?,
); );
@ -53,7 +53,7 @@ impl<'a> Reader<'a> {
let result = i32::from_be_bytes( let result = i32::from_be_bytes(
self.buffer[self.position..self.position + 4] self.buffer[self.position..self.position + 4]
.try_into() .try_into()
.map_err(|e: TryFromSliceError| ReaderError { .map_err(|e: TryFromSliceError| ZNSError::Reader {
message: e.to_string(), message: e.to_string(),
})?, })?,
); );
@ -65,7 +65,7 @@ impl<'a> Reader<'a> {
let result = u32::from_be_bytes( let result = u32::from_be_bytes(
self.buffer[self.position..self.position + 4] self.buffer[self.position..self.position + 4]
.try_into() .try_into()
.map_err(|e: TryFromSliceError| ReaderError { .map_err(|e: TryFromSliceError| ZNSError::Reader {
message: e.to_string(), message: e.to_string(),
})?, })?,
); );
@ -75,7 +75,7 @@ impl<'a> Reader<'a> {
pub fn seek(&self, position: usize) -> Result<Self> { pub fn seek(&self, position: usize) -> Result<Self> {
if position >= self.position { if position >= self.position {
Err(ReaderError { Err(ZNSError::Reader {
message: String::from("Seeking into the future is not allowed!!"), message: String::from("Seeking into the future is not allowed!!"),
}) })
} else { } else {

View file

@ -4,7 +4,7 @@ use std::sync::Arc;
use tokio::net::UdpSocket; use tokio::net::UdpSocket;
use crate::errors::ParseError; use crate::errors::ZNSError;
use crate::handlers::{Handler, ResponseHandler}; use crate::handlers::{Handler, ResponseHandler};
use crate::parser::{FromBytes, ToBytes}; use crate::parser::{FromBytes, ToBytes};
use crate::reader::Reader; use crate::reader::Reader;
@ -12,7 +12,7 @@ use crate::structs::{Header, Message, RCODE};
const MAX_DATAGRAM_SIZE: usize = 4096; const MAX_DATAGRAM_SIZE: usize = 4096;
fn handle_parse_error(bytes: &[u8], err: ParseError) -> Message { fn handle_parse_error(bytes: &[u8], err: ZNSError) -> Message {
eprintln!("{}", err); eprintln!("{}", err);
let mut reader = Reader::new(bytes); let mut reader = Reader::new(bytes);
let mut header = Header::from_bytes(&mut reader).unwrap_or(Header { let mut header = Header::from_bytes(&mut reader).unwrap_or(Header {
@ -50,7 +50,7 @@ async fn get_response(bytes: &[u8]) -> Message {
} }
Err(e) => { Err(e) => {
eprintln!("{}", e.to_string()); eprintln!("{}", e.to_string());
message.set_response(e.rcode); message.set_response(e.rcode());
message message
} }
}, },