mirror of
https://github.com/ZeusWPI/ZNS.git
synced 2024-10-30 05:24:26 +01:00
wrote some tests (at last)
This commit is contained in:
parent
aa94dc21bb
commit
5d59f1bd97
12 changed files with 479 additions and 92 deletions
|
@ -2,8 +2,19 @@ use diesel::prelude::*;
|
||||||
|
|
||||||
use crate::config::Config;
|
use crate::config::Config;
|
||||||
|
|
||||||
pub fn establish_connection() -> PgConnection {
|
pub fn get_connection() -> PgConnection {
|
||||||
let database_url = Config::get().db_uri.clone();
|
let database_url = Config::get().db_uri.clone();
|
||||||
PgConnection::establish(&database_url)
|
PgConnection::establish(&database_url)
|
||||||
.unwrap_or_else(|_| panic!("Error connecting to {}", Config::get().db_uri))
|
.unwrap_or_else(|_| panic!("Error connecting to {}", Config::get().db_uri))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
pub mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
pub fn get_test_connection() -> PgConnection {
|
||||||
|
let mut connection = get_connection();
|
||||||
|
assert!(connection.begin_test_transaction().is_ok());
|
||||||
|
connection
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -6,8 +6,6 @@ use diesel::prelude::*;
|
||||||
|
|
||||||
use self::schema::records::{self};
|
use self::schema::records::{self};
|
||||||
|
|
||||||
use super::lib::establish_connection;
|
|
||||||
|
|
||||||
mod schema {
|
mod schema {
|
||||||
diesel::table! {
|
diesel::table! {
|
||||||
records (name, _type, class, rdlength, rdata) {
|
records (name, _type, class, rdlength, rdata) {
|
||||||
|
@ -81,8 +79,7 @@ impl Record {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn insert_into_database(rr: &RR) -> Result<(), ZNSError> {
|
pub fn insert_into_database(rr: &RR, connection: &mut PgConnection) -> Result<(), ZNSError> {
|
||||||
let db_connection = &mut establish_connection();
|
|
||||||
let record = Record {
|
let record = Record {
|
||||||
name: rr.name.join("."),
|
name: rr.name.join("."),
|
||||||
_type: rr._type.clone().into(),
|
_type: rr._type.clone().into(),
|
||||||
|
@ -92,21 +89,21 @@ pub async fn insert_into_database(rr: &RR) -> Result<(), ZNSError> {
|
||||||
rdata: rr.rdata.clone(),
|
rdata: rr.rdata.clone(),
|
||||||
};
|
};
|
||||||
|
|
||||||
Record::create(db_connection, record).map_err(|e| ZNSError::Database {
|
Record::create(connection, record).map_err(|e| ZNSError::Database {
|
||||||
message: e.to_string(),
|
message: e.to_string(),
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn get_from_database(
|
pub fn get_from_database(
|
||||||
name: &Vec<String>,
|
name: &Vec<String>,
|
||||||
_type: Type,
|
_type: Type,
|
||||||
class: Class,
|
class: Class,
|
||||||
|
connection: &mut PgConnection,
|
||||||
) -> Result<Vec<RR>, ZNSError> {
|
) -> Result<Vec<RR>, ZNSError> {
|
||||||
let db_connection = &mut establish_connection();
|
|
||||||
let records =
|
let records =
|
||||||
Record::get(db_connection, name.join("."), _type.into(), class.into()).map_err(|e| {
|
Record::get(connection, name.join("."), _type.into(), class.into()).map_err(|e| {
|
||||||
ZNSError::Database {
|
ZNSError::Database {
|
||||||
message: e.to_string(),
|
message: e.to_string(),
|
||||||
}
|
}
|
||||||
|
@ -128,18 +125,58 @@ pub async fn get_from_database(
|
||||||
}
|
}
|
||||||
|
|
||||||
//TODO: cleanup models
|
//TODO: cleanup models
|
||||||
pub async fn delete_from_database(
|
pub fn delete_from_database(
|
||||||
name: &Vec<String>,
|
name: &Vec<String>,
|
||||||
_type: Option<Type>,
|
_type: Option<Type>,
|
||||||
class: Class,
|
class: Class,
|
||||||
rdata: Option<Vec<u8>>,
|
rdata: Option<Vec<u8>>,
|
||||||
|
connection: &mut PgConnection,
|
||||||
) {
|
) {
|
||||||
let db_connection = &mut establish_connection();
|
|
||||||
let _ = Record::delete(
|
let _ = Record::delete(
|
||||||
db_connection,
|
connection,
|
||||||
name.join("."),
|
name.join("."),
|
||||||
_type.map(|f| f.into()),
|
_type.map(|f| f.into()),
|
||||||
class.into(),
|
class.into(),
|
||||||
rdata,
|
rdata,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
use crate::{db::lib::tests::get_test_connection, parser::tests::get_rr};
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test() {
|
||||||
|
let mut connection = get_test_connection();
|
||||||
|
|
||||||
|
let rr = get_rr();
|
||||||
|
|
||||||
|
let f = |connection: &mut PgConnection| {
|
||||||
|
get_from_database(&rr.name, rr._type.clone(), rr.class.clone(), connection)
|
||||||
|
};
|
||||||
|
|
||||||
|
assert!(f(&mut connection).unwrap().is_empty());
|
||||||
|
|
||||||
|
assert!(insert_into_database(&rr, &mut connection).is_ok());
|
||||||
|
|
||||||
|
let result = f(&mut connection);
|
||||||
|
assert!(result.is_ok());
|
||||||
|
assert_eq!(result.as_ref().unwrap().len(), 1);
|
||||||
|
assert_eq!(result.unwrap()[0], rr);
|
||||||
|
|
||||||
|
delete_from_database(
|
||||||
|
&rr.name,
|
||||||
|
Some(rr._type.clone()),
|
||||||
|
rr.class.clone(),
|
||||||
|
Some(rr.rdata.clone()),
|
||||||
|
&mut connection,
|
||||||
|
);
|
||||||
|
|
||||||
|
assert!(f(&mut connection).unwrap().is_empty());
|
||||||
|
|
||||||
|
assert!(insert_into_database(&rr, &mut connection).is_ok());
|
||||||
|
assert!(insert_into_database(&rr, &mut connection).is_err());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
use diesel::PgConnection;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
errors::ZNSError,
|
errors::ZNSError,
|
||||||
structs::{Message, Opcode},
|
structs::{Message, Opcode},
|
||||||
|
@ -9,17 +11,26 @@ mod query;
|
||||||
mod update;
|
mod update;
|
||||||
|
|
||||||
pub trait ResponseHandler {
|
pub trait ResponseHandler {
|
||||||
async fn handle(message: &Message, raw: &[u8]) -> Result<Message, ZNSError>;
|
async fn handle(
|
||||||
|
message: &Message,
|
||||||
|
raw: &[u8],
|
||||||
|
connection: &mut PgConnection,
|
||||||
|
) -> Result<Message, ZNSError>;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct Handler {}
|
pub struct Handler {}
|
||||||
|
|
||||||
impl ResponseHandler for Handler {
|
impl ResponseHandler for Handler {
|
||||||
async fn handle(message: &Message, raw: &[u8]) -> Result<Message, ZNSError> {
|
async fn handle(
|
||||||
|
message: &Message,
|
||||||
|
raw: &[u8],
|
||||||
|
connection: &mut PgConnection,
|
||||||
|
) -> Result<Message, ZNSError> {
|
||||||
match message.get_opcode() {
|
match message.get_opcode() {
|
||||||
|
//TODO: implement this in Opcode
|
||||||
Ok(opcode) => match opcode {
|
Ok(opcode) => match opcode {
|
||||||
Opcode::QUERY => QueryHandler::handle(&message, raw).await,
|
Opcode::QUERY => QueryHandler::handle(&message, raw, connection).await,
|
||||||
Opcode::UPDATE => UpdateHandler::handle(&message, raw).await,
|
Opcode::UPDATE => UpdateHandler::handle(&message, raw, connection).await,
|
||||||
},
|
},
|
||||||
Err(e) => Err(ZNSError::Formerr {
|
Err(e) => Err(ZNSError::Formerr {
|
||||||
message: e.to_string(),
|
message: e.to_string(),
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
use diesel::PgConnection;
|
||||||
|
|
||||||
use crate::{db::models::get_from_database, errors::ZNSError, structs::Message};
|
use crate::{db::models::get_from_database, errors::ZNSError, structs::Message};
|
||||||
|
|
||||||
use super::ResponseHandler;
|
use super::ResponseHandler;
|
||||||
|
@ -5,7 +7,11 @@ use super::ResponseHandler;
|
||||||
pub struct QueryHandler {}
|
pub struct QueryHandler {}
|
||||||
|
|
||||||
impl ResponseHandler for QueryHandler {
|
impl ResponseHandler for QueryHandler {
|
||||||
async fn handle(message: &Message, _raw: &[u8]) -> Result<Message, ZNSError> {
|
async fn handle(
|
||||||
|
message: &Message,
|
||||||
|
_raw: &[u8],
|
||||||
|
connection: &mut PgConnection,
|
||||||
|
) -> Result<Message, ZNSError> {
|
||||||
let mut response = message.clone();
|
let mut response = message.clone();
|
||||||
response.header.arcount = 0; //TODO: fix this, handle unknown class values
|
response.header.arcount = 0; //TODO: fix this, handle unknown class values
|
||||||
|
|
||||||
|
@ -14,8 +20,8 @@ impl ResponseHandler for QueryHandler {
|
||||||
&question.qname,
|
&question.qname,
|
||||||
question.qtype.clone(),
|
question.qtype.clone(),
|
||||||
question.qclass.clone(),
|
question.qclass.clone(),
|
||||||
)
|
connection,
|
||||||
.await;
|
);
|
||||||
|
|
||||||
match answers {
|
match answers {
|
||||||
Ok(rrs) => {
|
Ok(rrs) => {
|
||||||
|
@ -24,7 +30,7 @@ impl ResponseHandler for QueryHandler {
|
||||||
domain: question.qname.join("."),
|
domain: question.qname.join("."),
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
response.header.ancount = rrs.len() as u16;
|
response.header.ancount += rrs.len() as u16;
|
||||||
response.answer.extend(rrs)
|
response.answer.extend(rrs)
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
|
@ -38,3 +44,38 @@ impl ResponseHandler for QueryHandler {
|
||||||
Ok(response)
|
Ok(response)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::parser::tests::get_message;
|
||||||
|
use crate::structs::*;
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
db::{lib::tests::get_test_connection, models::insert_into_database},
|
||||||
|
parser::{tests::get_rr, ToBytes},
|
||||||
|
};
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_handle_query() {
|
||||||
|
let mut connection = get_test_connection();
|
||||||
|
let rr = get_rr();
|
||||||
|
let mut message = get_message();
|
||||||
|
message.header.ancount = 0;
|
||||||
|
message.answer = vec![];
|
||||||
|
|
||||||
|
assert!(insert_into_database(&rr, &mut connection).is_ok());
|
||||||
|
|
||||||
|
let result = QueryHandler::handle(
|
||||||
|
&message,
|
||||||
|
&Message::to_bytes(message.clone()),
|
||||||
|
&mut connection,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(result.header.ancount, 2);
|
||||||
|
assert_eq!(result.answer.len(), 2);
|
||||||
|
assert_eq!(result.answer[0], rr);
|
||||||
|
assert_eq!(result.answer[1], rr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
use diesel::PgConnection;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
config::Config,
|
config::Config,
|
||||||
db::models::get_from_database,
|
db::models::get_from_database,
|
||||||
|
@ -9,7 +11,11 @@ use crate::{
|
||||||
|
|
||||||
use super::{dnskey::DNSKeyRData, sig::Sig};
|
use super::{dnskey::DNSKeyRData, sig::Sig};
|
||||||
|
|
||||||
pub async fn authenticate(sig: &Sig, zone: &Vec<String>) -> Result<bool, ZNSError> {
|
pub async fn authenticate(
|
||||||
|
sig: &Sig,
|
||||||
|
zone: &Vec<String>,
|
||||||
|
connection: &mut PgConnection,
|
||||||
|
) -> Result<bool, ZNSError> {
|
||||||
if zone.len() >= 4 {
|
if zone.len() >= 4 {
|
||||||
let username = &zone[zone.len() - 4]; // Should match: username.users.zeus.gent
|
let username = &zone[zone.len() - 4]; // Should match: username.users.zeus.gent
|
||||||
|
|
||||||
|
@ -18,7 +24,7 @@ pub async fn authenticate(sig: &Sig, zone: &Vec<String>) -> Result<bool, ZNSErro
|
||||||
if ssh_verified {
|
if ssh_verified {
|
||||||
Ok(true)
|
Ok(true)
|
||||||
} else {
|
} else {
|
||||||
Ok(validate_dnskey(zone, sig).await?)
|
Ok(validate_dnskey(zone, sig, connection).await?)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
Err(ZNSError::NotAuth {
|
Err(ZNSError::NotAuth {
|
||||||
|
@ -40,15 +46,21 @@ async fn validate_ssh(username: &String, sig: &Sig) -> Result<bool, reqwest::Err
|
||||||
.any(|key| sig.verify_ssh(&key).is_ok_and(|b| b)))
|
.any(|key| sig.verify_ssh(&key).is_ok_and(|b| b)))
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn validate_dnskey(zone: &Vec<String>, sig: &Sig) -> Result<bool, ZNSError> {
|
async fn validate_dnskey(
|
||||||
Ok(
|
zone: &Vec<String>,
|
||||||
get_from_database(zone, Type::Type(RRType::DNSKEY), Class::Class(RRClass::IN))
|
sig: &Sig,
|
||||||
.await?
|
connection: &mut PgConnection,
|
||||||
|
) -> Result<bool, ZNSError> {
|
||||||
|
Ok(get_from_database(
|
||||||
|
zone,
|
||||||
|
Type::Type(RRType::DNSKEY),
|
||||||
|
Class::Class(RRClass::IN),
|
||||||
|
connection,
|
||||||
|
)?
|
||||||
.iter()
|
.iter()
|
||||||
.any(|rr| {
|
.any(|rr| {
|
||||||
let mut reader = Reader::new(&rr.rdata);
|
let mut reader = Reader::new(&rr.rdata);
|
||||||
DNSKeyRData::from_bytes(&mut reader)
|
DNSKeyRData::from_bytes(&mut reader)
|
||||||
.is_ok_and(|dnskey| sig.verify_dnskey(dnskey).is_ok_and(|b| b))
|
.is_ok_and(|dnskey| sig.verify_dnskey(dnskey).is_ok_and(|b| b))
|
||||||
}),
|
}))
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
use diesel::PgConnection;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
db::models::{delete_from_database, insert_into_database},
|
db::models::{delete_from_database, insert_into_database},
|
||||||
errors::ZNSError,
|
errors::ZNSError,
|
||||||
|
@ -17,7 +19,11 @@ mod sig;
|
||||||
pub struct UpdateHandler {}
|
pub struct UpdateHandler {}
|
||||||
|
|
||||||
impl ResponseHandler for UpdateHandler {
|
impl ResponseHandler for UpdateHandler {
|
||||||
async fn handle(message: &Message, raw: &[u8]) -> Result<Message, ZNSError> {
|
async fn handle(
|
||||||
|
message: &Message,
|
||||||
|
raw: &[u8],
|
||||||
|
connection: &mut PgConnection,
|
||||||
|
) -> Result<Message, ZNSError> {
|
||||||
let response = message.clone();
|
let response = message.clone();
|
||||||
// Zone section (question) processing
|
// Zone section (question) processing
|
||||||
if (message.header.qdcount != 1)
|
if (message.header.qdcount != 1)
|
||||||
|
@ -44,7 +50,7 @@ impl ResponseHandler for UpdateHandler {
|
||||||
if last.is_some() && last.unwrap()._type == Type::Type(RRType::KEY) {
|
if last.is_some() && last.unwrap()._type == Type::Type(RRType::KEY) {
|
||||||
let sig = Sig::new(last.unwrap(), raw)?;
|
let sig = Sig::new(last.unwrap(), raw)?;
|
||||||
|
|
||||||
if !authenticate::authenticate(&sig, &zone.qname)
|
if !authenticate::authenticate(&sig, &zone.qname, connection)
|
||||||
.await
|
.await
|
||||||
.is_ok_and(|x| x)
|
.is_ok_and(|x| x)
|
||||||
{
|
{
|
||||||
|
@ -89,7 +95,7 @@ impl ResponseHandler for UpdateHandler {
|
||||||
|
|
||||||
for rr in &message.authority {
|
for rr in &message.authority {
|
||||||
if rr.class == zone.qclass {
|
if rr.class == zone.qclass {
|
||||||
let _ = insert_into_database(&rr).await;
|
let _ = insert_into_database(&rr, connection);
|
||||||
} else if rr.class == Class::Class(RRClass::ANY) {
|
} else if rr.class == Class::Class(RRClass::ANY) {
|
||||||
if rr._type == Type::Type(RRType::ANY) {
|
if rr._type == Type::Type(RRType::ANY) {
|
||||||
if rr.name == zone.qname {
|
if rr.name == zone.qname {
|
||||||
|
@ -98,7 +104,13 @@ impl ResponseHandler for UpdateHandler {
|
||||||
message: "rr.name == zone.qname".to_string(),
|
message: "rr.name == zone.qname".to_string(),
|
||||||
});
|
});
|
||||||
} else {
|
} else {
|
||||||
delete_from_database(&rr.name, None, Class::Class(RRClass::IN), None).await;
|
delete_from_database(
|
||||||
|
&rr.name,
|
||||||
|
None,
|
||||||
|
Class::Class(RRClass::IN),
|
||||||
|
None,
|
||||||
|
connection,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
delete_from_database(
|
delete_from_database(
|
||||||
|
@ -106,8 +118,8 @@ impl ResponseHandler for UpdateHandler {
|
||||||
Some(rr._type.clone()),
|
Some(rr._type.clone()),
|
||||||
Class::Class(RRClass::IN),
|
Class::Class(RRClass::IN),
|
||||||
None,
|
None,
|
||||||
|
connection,
|
||||||
)
|
)
|
||||||
.await;
|
|
||||||
}
|
}
|
||||||
} else if rr.class == Class::Class(RRClass::NONE) {
|
} else if rr.class == Class::Class(RRClass::NONE) {
|
||||||
if rr._type == Type::Type(RRType::SOA) {
|
if rr._type == Type::Type(RRType::SOA) {
|
||||||
|
@ -118,8 +130,8 @@ impl ResponseHandler for UpdateHandler {
|
||||||
Some(rr._type.clone()),
|
Some(rr._type.clone()),
|
||||||
Class::Class(RRClass::IN),
|
Class::Class(RRClass::IN),
|
||||||
Some(rr.rdata.clone()),
|
Some(rr.rdata.clone()),
|
||||||
|
connection,
|
||||||
)
|
)
|
||||||
.await;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -29,6 +29,18 @@ impl PublicKey for RsaPublicKey {
|
||||||
Ok(RsaPublicKey { e, n })
|
Ok(RsaPublicKey { e, n })
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn from_dnskey(key: &[u8]) -> Result<Self, ZNSError>
|
||||||
|
where
|
||||||
|
Self: Sized,
|
||||||
|
{
|
||||||
|
let mut reader = Reader::new(key);
|
||||||
|
let e_len = reader.read_u8()?;
|
||||||
|
let e = reader.read(e_len as usize)?;
|
||||||
|
let mut n = reader.read(reader.unread_bytes())?;
|
||||||
|
n.insert(0, 0);
|
||||||
|
Ok(RsaPublicKey { e, n })
|
||||||
|
}
|
||||||
|
|
||||||
fn verify(
|
fn verify(
|
||||||
&self,
|
&self,
|
||||||
data: &[u8],
|
data: &[u8],
|
||||||
|
@ -47,7 +59,7 @@ impl PublicKey for RsaPublicKey {
|
||||||
Algorithm::RSASHA512 => Ok(&signature::RSA_PKCS1_2048_8192_SHA512),
|
Algorithm::RSASHA512 => Ok(&signature::RSA_PKCS1_2048_8192_SHA512),
|
||||||
Algorithm::RSASHA256 => Ok(&signature::RSA_PKCS1_2048_8192_SHA256),
|
Algorithm::RSASHA256 => Ok(&signature::RSA_PKCS1_2048_8192_SHA256),
|
||||||
_ => Err(ZNSError::PublicKey {
|
_ => Err(ZNSError::PublicKey {
|
||||||
message: format!("RsaPublicKey: invalid verify algorithm",),
|
message: String::from("RsaPublicKey: invalid verify algorithm"),
|
||||||
}),
|
}),
|
||||||
}?;
|
}?;
|
||||||
|
|
||||||
|
@ -55,16 +67,4 @@ impl PublicKey for RsaPublicKey {
|
||||||
|
|
||||||
Ok(pkey.verify(data, signature).is_ok())
|
Ok(pkey.verify(data, signature).is_ok())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn from_dnskey(key: &[u8]) -> Result<Self, ZNSError>
|
|
||||||
where
|
|
||||||
Self: Sized,
|
|
||||||
{
|
|
||||||
let mut reader = Reader::new(key);
|
|
||||||
let e_len = reader.read_u8()?;
|
|
||||||
let e = reader.read(e_len as usize)?;
|
|
||||||
let mut n = reader.read(reader.unread_bytes())?;
|
|
||||||
n.insert(0, 0);
|
|
||||||
Ok(RsaPublicKey { e, n })
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,4 +9,43 @@ impl Message {
|
||||||
pub fn get_opcode(&self) -> Result<Opcode, String> {
|
pub fn get_opcode(&self) -> Result<Opcode, String> {
|
||||||
Opcode::try_from((self.header.flags & 0b0111100000000000) >> 11)
|
Opcode::try_from((self.header.flags & 0b0111100000000000) >> 11)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(dead_code)] // Used with tests
|
||||||
|
pub fn get_rcode(&self) -> Result<RCODE, u16> {
|
||||||
|
RCODE::try_from(self.header.flags & (!0 >> 12))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
|
||||||
|
use crate::structs::Header;
|
||||||
|
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test() {
|
||||||
|
let mut message = Message {
|
||||||
|
header: Header {
|
||||||
|
id: 1,
|
||||||
|
flags: 288,
|
||||||
|
qdcount: 0,
|
||||||
|
ancount: 0,
|
||||||
|
nscount: 0,
|
||||||
|
arcount: 0,
|
||||||
|
},
|
||||||
|
question: vec![],
|
||||||
|
answer: vec![],
|
||||||
|
authority: vec![],
|
||||||
|
additional: vec![],
|
||||||
|
};
|
||||||
|
|
||||||
|
assert_eq!(message.get_opcode().unwrap() as u8, Opcode::QUERY as u8);
|
||||||
|
|
||||||
|
message.set_response(RCODE::NOTIMP);
|
||||||
|
|
||||||
|
assert!((message.header.flags & (1 << 15)) > 0);
|
||||||
|
|
||||||
|
assert_eq!(message.get_rcode().unwrap(), RCODE::NOTIMP);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
148
src/parser.rs
148
src/parser.rs
|
@ -295,3 +295,151 @@ impl ToBytes for Message {
|
||||||
result
|
result
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
pub mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
pub fn get_rr() -> RR {
|
||||||
|
RR {
|
||||||
|
name: vec![String::from("example"), String::from("org")],
|
||||||
|
_type: Type::Type(RRType::A),
|
||||||
|
class: Class::Class(RRClass::IN),
|
||||||
|
ttl: 10,
|
||||||
|
rdlength: 4,
|
||||||
|
rdata: vec![1, 2, 3, 4],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_message() -> Message {
|
||||||
|
Message {
|
||||||
|
header: Header {
|
||||||
|
id: 1,
|
||||||
|
flags: 288,
|
||||||
|
qdcount: 2,
|
||||||
|
ancount: 1,
|
||||||
|
nscount: 1,
|
||||||
|
arcount: 1,
|
||||||
|
},
|
||||||
|
question: vec![
|
||||||
|
Question {
|
||||||
|
qname: vec![String::from("example"), String::from("org")],
|
||||||
|
qtype: Type::Type(RRType::A),
|
||||||
|
qclass: Class::Class(RRClass::IN),
|
||||||
|
},
|
||||||
|
Question {
|
||||||
|
qname: vec![String::from("example"), String::from("org")],
|
||||||
|
qtype: Type::Type(RRType::A),
|
||||||
|
qclass: Class::Class(RRClass::IN),
|
||||||
|
},
|
||||||
|
],
|
||||||
|
answer: vec![get_rr()],
|
||||||
|
authority: vec![get_rr()],
|
||||||
|
additional: vec![get_rr()],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_parse_header() {
|
||||||
|
let header = Header {
|
||||||
|
id: 1,
|
||||||
|
flags: 288,
|
||||||
|
qdcount: 1,
|
||||||
|
ancount: 0,
|
||||||
|
nscount: 0,
|
||||||
|
arcount: 0,
|
||||||
|
};
|
||||||
|
|
||||||
|
let bytes = Header::to_bytes(header.clone());
|
||||||
|
let parsed = Header::from_bytes(&mut Reader::new(&bytes));
|
||||||
|
assert!(parsed.is_ok());
|
||||||
|
assert_eq!(parsed.unwrap(), header);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_parse_question() {
|
||||||
|
let question = Question {
|
||||||
|
qname: vec![String::from("example"), String::from("org")],
|
||||||
|
qtype: Type::Type(RRType::A),
|
||||||
|
qclass: Class::Class(RRClass::IN),
|
||||||
|
};
|
||||||
|
|
||||||
|
let bytes = Question::to_bytes(question.clone());
|
||||||
|
let parsed = Question::from_bytes(&mut Reader::new(&bytes));
|
||||||
|
assert!(parsed.is_ok());
|
||||||
|
assert_eq!(parsed.unwrap(), question);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_parse_rr() {
|
||||||
|
let rr = get_rr();
|
||||||
|
|
||||||
|
let bytes = RR::to_bytes(rr.clone());
|
||||||
|
let parsed = RR::from_bytes(&mut Reader::new(&bytes));
|
||||||
|
assert!(parsed.is_ok());
|
||||||
|
assert_eq!(parsed.unwrap(), rr);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_labelstring() {
|
||||||
|
let labelstring = vec![String::from("example"), String::from("org")];
|
||||||
|
|
||||||
|
let bytes = LabelString::to_bytes(labelstring.clone());
|
||||||
|
let parsed = LabelString::from_bytes(&mut Reader::new(&bytes));
|
||||||
|
assert!(parsed.is_ok());
|
||||||
|
assert_eq!(parsed.unwrap(), labelstring);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_labelstring_ptr() {
|
||||||
|
let labelstring = vec![String::from("example"), String::from("org")];
|
||||||
|
|
||||||
|
let mut bytes = LabelString::to_bytes(labelstring.clone());
|
||||||
|
|
||||||
|
bytes.insert(0, 0);
|
||||||
|
bytes.insert(0, 0);
|
||||||
|
|
||||||
|
let to_read = bytes.len();
|
||||||
|
|
||||||
|
bytes.push(0b11000000);
|
||||||
|
bytes.push(0b00000010);
|
||||||
|
|
||||||
|
let mut reader = Reader::new(&bytes);
|
||||||
|
let _ = reader.read(to_read);
|
||||||
|
|
||||||
|
let parsed = LabelString::from_bytes(&mut reader);
|
||||||
|
assert!(parsed.is_ok());
|
||||||
|
assert_eq!(parsed.unwrap(), labelstring);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_labelstring_invalid_ptr() {
|
||||||
|
let labelstring = vec![String::from("example"), String::from("org")];
|
||||||
|
|
||||||
|
let mut bytes = LabelString::to_bytes(labelstring.clone());
|
||||||
|
|
||||||
|
bytes.insert(0, 0);
|
||||||
|
bytes.insert(0, 0);
|
||||||
|
|
||||||
|
let to_read = bytes.len();
|
||||||
|
|
||||||
|
bytes.push(0b11000000);
|
||||||
|
// Not allowed to point to itself or in the future
|
||||||
|
bytes.push(to_read as u8);
|
||||||
|
|
||||||
|
let mut reader = Reader::new(&bytes);
|
||||||
|
let _ = reader.read(to_read);
|
||||||
|
|
||||||
|
let parsed = LabelString::from_bytes(&mut reader);
|
||||||
|
assert!(parsed.is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_parse_message() {
|
||||||
|
let message = get_message();
|
||||||
|
let bytes = Message::to_bytes(message.clone());
|
||||||
|
let parsed = Message::from_bytes(&mut Reader::new(&bytes));
|
||||||
|
assert!(parsed.is_ok());
|
||||||
|
assert_eq!(parsed.unwrap(), message);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -1,5 +1,3 @@
|
||||||
use std::array::TryFromSliceError;
|
|
||||||
|
|
||||||
use crate::errors::ZNSError;
|
use crate::errors::ZNSError;
|
||||||
|
|
||||||
pub struct Reader<'a> {
|
pub struct Reader<'a> {
|
||||||
|
@ -33,43 +31,37 @@ impl<'a> Reader<'a> {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn read_u8(&mut self) -> Result<u8> {
|
pub fn read_u8(&mut self) -> Result<u8> {
|
||||||
|
if self.unread_bytes() == 0 {
|
||||||
|
Err(ZNSError::Reader {
|
||||||
|
message: String::from("cannot read u8"),
|
||||||
|
})
|
||||||
|
} else {
|
||||||
self.position += 1;
|
self.position += 1;
|
||||||
Ok(self.buffer[self.position - 1])
|
Ok(self.buffer[self.position - 1])
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn read_u16(&mut self) -> Result<u16> {
|
pub fn read_u16(&mut self) -> Result<u16> {
|
||||||
let result = u16::from_be_bytes(
|
let result =
|
||||||
self.buffer[self.position..self.position + 2]
|
u16::from_be_bytes(self.read(2)?.try_into().map_err(|_| ZNSError::Reader {
|
||||||
.try_into()
|
message: String::from("invalid read_u16"),
|
||||||
.map_err(|e: TryFromSliceError| ZNSError::Reader {
|
})?);
|
||||||
message: e.to_string(),
|
|
||||||
})?,
|
|
||||||
);
|
|
||||||
self.position += 2;
|
|
||||||
Ok(result)
|
Ok(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn read_i32(&mut self) -> Result<i32> {
|
pub fn read_i32(&mut self) -> Result<i32> {
|
||||||
let result = i32::from_be_bytes(
|
let result =
|
||||||
self.buffer[self.position..self.position + 4]
|
i32::from_be_bytes(self.read(4)?.try_into().map_err(|_| ZNSError::Reader {
|
||||||
.try_into()
|
message: String::from("invalid read_u32"),
|
||||||
.map_err(|e: TryFromSliceError| ZNSError::Reader {
|
})?);
|
||||||
message: e.to_string(),
|
|
||||||
})?,
|
|
||||||
);
|
|
||||||
self.position += 4;
|
|
||||||
Ok(result)
|
Ok(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn read_u32(&mut self) -> Result<u32> {
|
pub fn read_u32(&mut self) -> Result<u32> {
|
||||||
let result = u32::from_be_bytes(
|
let result =
|
||||||
self.buffer[self.position..self.position + 4]
|
u32::from_be_bytes(self.read(4)?.try_into().map_err(|_| ZNSError::Reader {
|
||||||
.try_into()
|
message: String::from("invalid read_u32"),
|
||||||
.map_err(|e: TryFromSliceError| ZNSError::Reader {
|
})?);
|
||||||
message: e.to_string(),
|
|
||||||
})?,
|
|
||||||
);
|
|
||||||
self.position += 4;
|
|
||||||
Ok(result)
|
Ok(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -83,3 +75,58 @@ impl<'a> Reader<'a> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test() {
|
||||||
|
let fake_bytes = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
|
||||||
|
let mut reader = Reader::new(&fake_bytes);
|
||||||
|
|
||||||
|
assert_eq!(reader.unread_bytes(), 11);
|
||||||
|
|
||||||
|
let u16 = reader.read_u16();
|
||||||
|
assert!(u16.is_ok());
|
||||||
|
assert_eq!(u16.unwrap(), 1);
|
||||||
|
|
||||||
|
assert_eq!(reader.unread_bytes(), 9);
|
||||||
|
|
||||||
|
let u8 = reader.read_u8();
|
||||||
|
assert!(u8.is_ok());
|
||||||
|
assert_eq!(u8.unwrap(), 2);
|
||||||
|
assert_eq!(reader.unread_bytes(), 8);
|
||||||
|
|
||||||
|
let u32 = reader.read_u32();
|
||||||
|
assert!(u32.is_ok());
|
||||||
|
assert_eq!(
|
||||||
|
u32.unwrap(),
|
||||||
|
u32::from_be_bytes(fake_bytes[3..7].try_into().unwrap())
|
||||||
|
);
|
||||||
|
assert_eq!(reader.unread_bytes(), 4);
|
||||||
|
|
||||||
|
let read = reader.read(3);
|
||||||
|
assert!(read.is_ok());
|
||||||
|
assert_eq!(read.unwrap(), fake_bytes[7..10]);
|
||||||
|
assert_eq!(reader.unread_bytes(), 1);
|
||||||
|
|
||||||
|
let too_much = reader.read(2);
|
||||||
|
assert!(too_much.is_err());
|
||||||
|
assert_eq!(reader.unread_bytes(), 1);
|
||||||
|
|
||||||
|
assert!(reader.read_u8().is_ok());
|
||||||
|
|
||||||
|
assert!(reader.read_u8().is_err());
|
||||||
|
assert!(reader.read_u16().is_err());
|
||||||
|
assert!(reader.read_u32().is_err());
|
||||||
|
assert!(reader.read_i32().is_err());
|
||||||
|
|
||||||
|
let new_reader = reader.seek(1);
|
||||||
|
assert!(new_reader.is_ok());
|
||||||
|
assert_eq!(new_reader.unwrap().unread_bytes(), 10);
|
||||||
|
|
||||||
|
let new_reader = reader.seek(100);
|
||||||
|
assert!(new_reader.is_err());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -4,6 +4,7 @@ use std::sync::Arc;
|
||||||
|
|
||||||
use tokio::net::UdpSocket;
|
use tokio::net::UdpSocket;
|
||||||
|
|
||||||
|
use crate::db::lib::get_connection;
|
||||||
use crate::errors::ZNSError;
|
use crate::errors::ZNSError;
|
||||||
use crate::handlers::{Handler, ResponseHandler};
|
use crate::handlers::{Handler, ResponseHandler};
|
||||||
use crate::parser::{FromBytes, ToBytes};
|
use crate::parser::{FromBytes, ToBytes};
|
||||||
|
@ -43,12 +44,13 @@ fn handle_parse_error(bytes: &[u8], err: ZNSError) -> Message {
|
||||||
async fn get_response(bytes: &[u8]) -> Message {
|
async fn get_response(bytes: &[u8]) -> Message {
|
||||||
let mut reader = Reader::new(bytes);
|
let mut reader = Reader::new(bytes);
|
||||||
match Message::from_bytes(&mut reader) {
|
match Message::from_bytes(&mut reader) {
|
||||||
Ok(mut message) => match Handler::handle(&message, bytes).await {
|
Ok(mut message) => match Handler::handle(&message, bytes, &mut get_connection()).await {
|
||||||
Ok(mut response) => {
|
Ok(mut response) => {
|
||||||
response.set_response(RCODE::NOERROR);
|
response.set_response(RCODE::NOERROR);
|
||||||
response
|
response
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
|
println!("{:#?}", message);
|
||||||
eprintln!("{}", e.to_string());
|
eprintln!("{}", e.to_string());
|
||||||
message.set_response(e.rcode());
|
message.set_response(e.rcode());
|
||||||
message
|
message
|
||||||
|
@ -72,3 +74,36 @@ pub async fn resolver_listener_loop(addr: SocketAddr) -> Result<(), Box<dyn Erro
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use crate::structs::{Class, Question, RRClass, RRType, Type};
|
||||||
|
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_get_response() {
|
||||||
|
let message = Message {
|
||||||
|
header: Header {
|
||||||
|
id: 1,
|
||||||
|
flags: 288,
|
||||||
|
qdcount: 1,
|
||||||
|
ancount: 0,
|
||||||
|
nscount: 0,
|
||||||
|
arcount: 0,
|
||||||
|
},
|
||||||
|
question: vec![Question {
|
||||||
|
qname: vec![String::from("example"), String::from("org")],
|
||||||
|
qtype: Type::Type(RRType::A),
|
||||||
|
qclass: Class::Class(RRClass::IN),
|
||||||
|
}],
|
||||||
|
answer: vec![],
|
||||||
|
authority: vec![],
|
||||||
|
additional: vec![],
|
||||||
|
};
|
||||||
|
|
||||||
|
let response = get_response(&Message::to_bytes(message)).await;
|
||||||
|
|
||||||
|
assert_eq!(response.get_rcode(), Ok(RCODE::NXDOMAIN));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -33,6 +33,7 @@ pub enum RRClass {
|
||||||
|
|
||||||
#[repr(u16)]
|
#[repr(u16)]
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
|
#[derive(Debug, IntEnum, PartialEq)]
|
||||||
pub enum RCODE {
|
pub enum RCODE {
|
||||||
NOERROR = 0,
|
NOERROR = 0,
|
||||||
FORMERR = 1,
|
FORMERR = 1,
|
||||||
|
@ -52,14 +53,14 @@ pub enum Opcode {
|
||||||
UPDATE = 5,
|
UPDATE = 5,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
pub struct Question {
|
pub struct Question {
|
||||||
pub qname: LabelString,
|
pub qname: LabelString,
|
||||||
pub qtype: Type, // NOTE: should be QTYPE, right now not really needed
|
pub qtype: Type, // NOTE: should be QTYPE, right now not really needed
|
||||||
pub qclass: Class, //NOTE: should be QCLASS, right now not really needed
|
pub qclass: Class, //NOTE: should be QCLASS, right now not really needed
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
pub struct Header {
|
pub struct Header {
|
||||||
pub id: u16,
|
pub id: u16,
|
||||||
pub flags: u16, // |QR| Opcode |AA|TC|RD|RA| Z | RCODE | ; 1 | 4 | 1 | 1 | 1 | 1 | 3 | 4
|
pub flags: u16, // |QR| Opcode |AA|TC|RD|RA| Z | RCODE | ; 1 | 4 | 1 | 1 | 1 | 1 | 3 | 4
|
||||||
|
@ -69,7 +70,7 @@ pub struct Header {
|
||||||
pub arcount: u16,
|
pub arcount: u16,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
pub struct Message {
|
pub struct Message {
|
||||||
pub header: Header,
|
pub header: Header,
|
||||||
pub question: Vec<Question>,
|
pub question: Vec<Question>,
|
||||||
|
@ -78,7 +79,7 @@ pub struct Message {
|
||||||
pub additional: Vec<RR>,
|
pub additional: Vec<RR>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
pub struct RR {
|
pub struct RR {
|
||||||
pub name: LabelString,
|
pub name: LabelString,
|
||||||
pub _type: Type,
|
pub _type: Type,
|
||||||
|
@ -88,11 +89,4 @@ pub struct RR {
|
||||||
pub rdata: Vec<u8>,
|
pub rdata: Vec<u8>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct OptRR {
|
|
||||||
pub code: u16,
|
|
||||||
pub length: u16,
|
|
||||||
pub rdata: Vec<u8>,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub type LabelString = Vec<String>;
|
pub type LabelString = Vec<String>;
|
||||||
|
|
Loading…
Reference in a new issue