10
0
Fork 0
mirror of https://github.com/ZeusWPI/ZNS.git synced 2024-11-21 21:41:10 +01:00

wrote some tests (at last)

This commit is contained in:
Xander Bil 2024-07-07 21:01:26 +02:00
parent aa94dc21bb
commit 5d59f1bd97
No known key found for this signature in database
GPG key ID: EC9706B54A278598
12 changed files with 479 additions and 92 deletions

View file

@ -2,8 +2,19 @@ use diesel::prelude::*;
use crate::config::Config;
pub fn establish_connection() -> PgConnection {
pub fn get_connection() -> PgConnection {
let database_url = Config::get().db_uri.clone();
PgConnection::establish(&database_url)
.unwrap_or_else(|_| panic!("Error connecting to {}", Config::get().db_uri))
}
#[cfg(test)]
pub mod tests {
use super::*;
pub fn get_test_connection() -> PgConnection {
let mut connection = get_connection();
assert!(connection.begin_test_transaction().is_ok());
connection
}
}

View file

@ -6,8 +6,6 @@ use diesel::prelude::*;
use self::schema::records::{self};
use super::lib::establish_connection;
mod schema {
diesel::table! {
records (name, _type, class, rdlength, rdata) {
@ -81,8 +79,7 @@ impl Record {
}
}
pub async fn insert_into_database(rr: &RR) -> Result<(), ZNSError> {
let db_connection = &mut establish_connection();
pub fn insert_into_database(rr: &RR, connection: &mut PgConnection) -> Result<(), ZNSError> {
let record = Record {
name: rr.name.join("."),
_type: rr._type.clone().into(),
@ -92,21 +89,21 @@ pub async fn insert_into_database(rr: &RR) -> Result<(), ZNSError> {
rdata: rr.rdata.clone(),
};
Record::create(db_connection, record).map_err(|e| ZNSError::Database {
Record::create(connection, record).map_err(|e| ZNSError::Database {
message: e.to_string(),
})?;
Ok(())
}
pub async fn get_from_database(
pub fn get_from_database(
name: &Vec<String>,
_type: Type,
class: Class,
connection: &mut PgConnection,
) -> Result<Vec<RR>, ZNSError> {
let db_connection = &mut establish_connection();
let records =
Record::get(db_connection, name.join("."), _type.into(), class.into()).map_err(|e| {
Record::get(connection, name.join("."), _type.into(), class.into()).map_err(|e| {
ZNSError::Database {
message: e.to_string(),
}
@ -128,18 +125,58 @@ pub async fn get_from_database(
}
//TODO: cleanup models
pub async fn delete_from_database(
pub fn delete_from_database(
name: &Vec<String>,
_type: Option<Type>,
class: Class,
rdata: Option<Vec<u8>>,
connection: &mut PgConnection,
) {
let db_connection = &mut establish_connection();
let _ = Record::delete(
db_connection,
connection,
name.join("."),
_type.map(|f| f.into()),
class.into(),
rdata,
);
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{db::lib::tests::get_test_connection, parser::tests::get_rr};
#[test]
fn test() {
let mut connection = get_test_connection();
let rr = get_rr();
let f = |connection: &mut PgConnection| {
get_from_database(&rr.name, rr._type.clone(), rr.class.clone(), connection)
};
assert!(f(&mut connection).unwrap().is_empty());
assert!(insert_into_database(&rr, &mut connection).is_ok());
let result = f(&mut connection);
assert!(result.is_ok());
assert_eq!(result.as_ref().unwrap().len(), 1);
assert_eq!(result.unwrap()[0], rr);
delete_from_database(
&rr.name,
Some(rr._type.clone()),
rr.class.clone(),
Some(rr.rdata.clone()),
&mut connection,
);
assert!(f(&mut connection).unwrap().is_empty());
assert!(insert_into_database(&rr, &mut connection).is_ok());
assert!(insert_into_database(&rr, &mut connection).is_err());
}
}

View file

@ -1,3 +1,5 @@
use diesel::PgConnection;
use crate::{
errors::ZNSError,
structs::{Message, Opcode},
@ -9,17 +11,26 @@ mod query;
mod update;
pub trait ResponseHandler {
async fn handle(message: &Message, raw: &[u8]) -> Result<Message, ZNSError>;
async fn handle(
message: &Message,
raw: &[u8],
connection: &mut PgConnection,
) -> Result<Message, ZNSError>;
}
pub struct Handler {}
impl ResponseHandler for Handler {
async fn handle(message: &Message, raw: &[u8]) -> Result<Message, ZNSError> {
async fn handle(
message: &Message,
raw: &[u8],
connection: &mut PgConnection,
) -> Result<Message, ZNSError> {
match message.get_opcode() {
//TODO: implement this in Opcode
Ok(opcode) => match opcode {
Opcode::QUERY => QueryHandler::handle(&message, raw).await,
Opcode::UPDATE => UpdateHandler::handle(&message, raw).await,
Opcode::QUERY => QueryHandler::handle(&message, raw, connection).await,
Opcode::UPDATE => UpdateHandler::handle(&message, raw, connection).await,
},
Err(e) => Err(ZNSError::Formerr {
message: e.to_string(),

View file

@ -1,3 +1,5 @@
use diesel::PgConnection;
use crate::{db::models::get_from_database, errors::ZNSError, structs::Message};
use super::ResponseHandler;
@ -5,7 +7,11 @@ use super::ResponseHandler;
pub struct QueryHandler {}
impl ResponseHandler for QueryHandler {
async fn handle(message: &Message, _raw: &[u8]) -> Result<Message, ZNSError> {
async fn handle(
message: &Message,
_raw: &[u8],
connection: &mut PgConnection,
) -> Result<Message, ZNSError> {
let mut response = message.clone();
response.header.arcount = 0; //TODO: fix this, handle unknown class values
@ -14,8 +20,8 @@ impl ResponseHandler for QueryHandler {
&question.qname,
question.qtype.clone(),
question.qclass.clone(),
)
.await;
connection,
);
match answers {
Ok(rrs) => {
@ -24,7 +30,7 @@ impl ResponseHandler for QueryHandler {
domain: question.qname.join("."),
});
}
response.header.ancount = rrs.len() as u16;
response.header.ancount += rrs.len() as u16;
response.answer.extend(rrs)
}
Err(e) => {
@ -38,3 +44,38 @@ impl ResponseHandler for QueryHandler {
Ok(response)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::parser::tests::get_message;
use crate::structs::*;
use crate::{
db::{lib::tests::get_test_connection, models::insert_into_database},
parser::{tests::get_rr, ToBytes},
};
#[tokio::test]
async fn test_handle_query() {
let mut connection = get_test_connection();
let rr = get_rr();
let mut message = get_message();
message.header.ancount = 0;
message.answer = vec![];
assert!(insert_into_database(&rr, &mut connection).is_ok());
let result = QueryHandler::handle(
&message,
&Message::to_bytes(message.clone()),
&mut connection,
)
.await
.unwrap();
assert_eq!(result.header.ancount, 2);
assert_eq!(result.answer.len(), 2);
assert_eq!(result.answer[0], rr);
assert_eq!(result.answer[1], rr);
}
}

View file

@ -1,3 +1,5 @@
use diesel::PgConnection;
use crate::{
config::Config,
db::models::get_from_database,
@ -9,7 +11,11 @@ use crate::{
use super::{dnskey::DNSKeyRData, sig::Sig};
pub async fn authenticate(sig: &Sig, zone: &Vec<String>) -> Result<bool, ZNSError> {
pub async fn authenticate(
sig: &Sig,
zone: &Vec<String>,
connection: &mut PgConnection,
) -> Result<bool, ZNSError> {
if zone.len() >= 4 {
let username = &zone[zone.len() - 4]; // Should match: username.users.zeus.gent
@ -18,7 +24,7 @@ pub async fn authenticate(sig: &Sig, zone: &Vec<String>) -> Result<bool, ZNSErro
if ssh_verified {
Ok(true)
} else {
Ok(validate_dnskey(zone, sig).await?)
Ok(validate_dnskey(zone, sig, connection).await?)
}
} else {
Err(ZNSError::NotAuth {
@ -40,15 +46,21 @@ async fn validate_ssh(username: &String, sig: &Sig) -> Result<bool, reqwest::Err
.any(|key| sig.verify_ssh(&key).is_ok_and(|b| b)))
}
async fn validate_dnskey(zone: &Vec<String>, sig: &Sig) -> Result<bool, ZNSError> {
Ok(
get_from_database(zone, Type::Type(RRType::DNSKEY), Class::Class(RRClass::IN))
.await?
async fn validate_dnskey(
zone: &Vec<String>,
sig: &Sig,
connection: &mut PgConnection,
) -> Result<bool, ZNSError> {
Ok(get_from_database(
zone,
Type::Type(RRType::DNSKEY),
Class::Class(RRClass::IN),
connection,
)?
.iter()
.any(|rr| {
let mut reader = Reader::new(&rr.rdata);
DNSKeyRData::from_bytes(&mut reader)
.is_ok_and(|dnskey| sig.verify_dnskey(dnskey).is_ok_and(|b| b))
}),
)
}))
}

View file

@ -1,3 +1,5 @@
use diesel::PgConnection;
use crate::{
db::models::{delete_from_database, insert_into_database},
errors::ZNSError,
@ -17,7 +19,11 @@ mod sig;
pub struct UpdateHandler {}
impl ResponseHandler for UpdateHandler {
async fn handle(message: &Message, raw: &[u8]) -> Result<Message, ZNSError> {
async fn handle(
message: &Message,
raw: &[u8],
connection: &mut PgConnection,
) -> Result<Message, ZNSError> {
let response = message.clone();
// Zone section (question) processing
if (message.header.qdcount != 1)
@ -44,7 +50,7 @@ impl ResponseHandler for UpdateHandler {
if last.is_some() && last.unwrap()._type == Type::Type(RRType::KEY) {
let sig = Sig::new(last.unwrap(), raw)?;
if !authenticate::authenticate(&sig, &zone.qname)
if !authenticate::authenticate(&sig, &zone.qname, connection)
.await
.is_ok_and(|x| x)
{
@ -89,7 +95,7 @@ impl ResponseHandler for UpdateHandler {
for rr in &message.authority {
if rr.class == zone.qclass {
let _ = insert_into_database(&rr).await;
let _ = insert_into_database(&rr, connection);
} else if rr.class == Class::Class(RRClass::ANY) {
if rr._type == Type::Type(RRType::ANY) {
if rr.name == zone.qname {
@ -98,7 +104,13 @@ impl ResponseHandler for UpdateHandler {
message: "rr.name == zone.qname".to_string(),
});
} else {
delete_from_database(&rr.name, None, Class::Class(RRClass::IN), None).await;
delete_from_database(
&rr.name,
None,
Class::Class(RRClass::IN),
None,
connection,
)
}
} else {
delete_from_database(
@ -106,8 +118,8 @@ impl ResponseHandler for UpdateHandler {
Some(rr._type.clone()),
Class::Class(RRClass::IN),
None,
connection,
)
.await;
}
} else if rr.class == Class::Class(RRClass::NONE) {
if rr._type == Type::Type(RRType::SOA) {
@ -118,8 +130,8 @@ impl ResponseHandler for UpdateHandler {
Some(rr._type.clone()),
Class::Class(RRClass::IN),
Some(rr.rdata.clone()),
connection,
)
.await;
}
}

View file

@ -29,6 +29,18 @@ impl PublicKey for RsaPublicKey {
Ok(RsaPublicKey { e, n })
}
fn from_dnskey(key: &[u8]) -> Result<Self, ZNSError>
where
Self: Sized,
{
let mut reader = Reader::new(key);
let e_len = reader.read_u8()?;
let e = reader.read(e_len as usize)?;
let mut n = reader.read(reader.unread_bytes())?;
n.insert(0, 0);
Ok(RsaPublicKey { e, n })
}
fn verify(
&self,
data: &[u8],
@ -47,7 +59,7 @@ impl PublicKey for RsaPublicKey {
Algorithm::RSASHA512 => Ok(&signature::RSA_PKCS1_2048_8192_SHA512),
Algorithm::RSASHA256 => Ok(&signature::RSA_PKCS1_2048_8192_SHA256),
_ => Err(ZNSError::PublicKey {
message: format!("RsaPublicKey: invalid verify algorithm",),
message: String::from("RsaPublicKey: invalid verify algorithm"),
}),
}?;
@ -55,16 +67,4 @@ impl PublicKey for RsaPublicKey {
Ok(pkey.verify(data, signature).is_ok())
}
fn from_dnskey(key: &[u8]) -> Result<Self, ZNSError>
where
Self: Sized,
{
let mut reader = Reader::new(key);
let e_len = reader.read_u8()?;
let e = reader.read(e_len as usize)?;
let mut n = reader.read(reader.unread_bytes())?;
n.insert(0, 0);
Ok(RsaPublicKey { e, n })
}
}

View file

@ -9,4 +9,43 @@ impl Message {
pub fn get_opcode(&self) -> Result<Opcode, String> {
Opcode::try_from((self.header.flags & 0b0111100000000000) >> 11)
}
#[allow(dead_code)] // Used with tests
pub fn get_rcode(&self) -> Result<RCODE, u16> {
RCODE::try_from(self.header.flags & (!0 >> 12))
}
}
#[cfg(test)]
mod tests {
use crate::structs::Header;
use super::*;
#[test]
fn test() {
let mut message = Message {
header: Header {
id: 1,
flags: 288,
qdcount: 0,
ancount: 0,
nscount: 0,
arcount: 0,
},
question: vec![],
answer: vec![],
authority: vec![],
additional: vec![],
};
assert_eq!(message.get_opcode().unwrap() as u8, Opcode::QUERY as u8);
message.set_response(RCODE::NOTIMP);
assert!((message.header.flags & (1 << 15)) > 0);
assert_eq!(message.get_rcode().unwrap(), RCODE::NOTIMP);
}
}

View file

@ -295,3 +295,151 @@ impl ToBytes for Message {
result
}
}
#[cfg(test)]
pub mod tests {
use super::*;
pub fn get_rr() -> RR {
RR {
name: vec![String::from("example"), String::from("org")],
_type: Type::Type(RRType::A),
class: Class::Class(RRClass::IN),
ttl: 10,
rdlength: 4,
rdata: vec![1, 2, 3, 4],
}
}
pub fn get_message() -> Message {
Message {
header: Header {
id: 1,
flags: 288,
qdcount: 2,
ancount: 1,
nscount: 1,
arcount: 1,
},
question: vec![
Question {
qname: vec![String::from("example"), String::from("org")],
qtype: Type::Type(RRType::A),
qclass: Class::Class(RRClass::IN),
},
Question {
qname: vec![String::from("example"), String::from("org")],
qtype: Type::Type(RRType::A),
qclass: Class::Class(RRClass::IN),
},
],
answer: vec![get_rr()],
authority: vec![get_rr()],
additional: vec![get_rr()],
}
}
#[test]
fn test_parse_header() {
let header = Header {
id: 1,
flags: 288,
qdcount: 1,
ancount: 0,
nscount: 0,
arcount: 0,
};
let bytes = Header::to_bytes(header.clone());
let parsed = Header::from_bytes(&mut Reader::new(&bytes));
assert!(parsed.is_ok());
assert_eq!(parsed.unwrap(), header);
}
#[test]
fn test_parse_question() {
let question = Question {
qname: vec![String::from("example"), String::from("org")],
qtype: Type::Type(RRType::A),
qclass: Class::Class(RRClass::IN),
};
let bytes = Question::to_bytes(question.clone());
let parsed = Question::from_bytes(&mut Reader::new(&bytes));
assert!(parsed.is_ok());
assert_eq!(parsed.unwrap(), question);
}
#[test]
fn test_parse_rr() {
let rr = get_rr();
let bytes = RR::to_bytes(rr.clone());
let parsed = RR::from_bytes(&mut Reader::new(&bytes));
assert!(parsed.is_ok());
assert_eq!(parsed.unwrap(), rr);
}
#[test]
fn test_labelstring() {
let labelstring = vec![String::from("example"), String::from("org")];
let bytes = LabelString::to_bytes(labelstring.clone());
let parsed = LabelString::from_bytes(&mut Reader::new(&bytes));
assert!(parsed.is_ok());
assert_eq!(parsed.unwrap(), labelstring);
}
#[test]
fn test_labelstring_ptr() {
let labelstring = vec![String::from("example"), String::from("org")];
let mut bytes = LabelString::to_bytes(labelstring.clone());
bytes.insert(0, 0);
bytes.insert(0, 0);
let to_read = bytes.len();
bytes.push(0b11000000);
bytes.push(0b00000010);
let mut reader = Reader::new(&bytes);
let _ = reader.read(to_read);
let parsed = LabelString::from_bytes(&mut reader);
assert!(parsed.is_ok());
assert_eq!(parsed.unwrap(), labelstring);
}
#[test]
fn test_labelstring_invalid_ptr() {
let labelstring = vec![String::from("example"), String::from("org")];
let mut bytes = LabelString::to_bytes(labelstring.clone());
bytes.insert(0, 0);
bytes.insert(0, 0);
let to_read = bytes.len();
bytes.push(0b11000000);
// Not allowed to point to itself or in the future
bytes.push(to_read as u8);
let mut reader = Reader::new(&bytes);
let _ = reader.read(to_read);
let parsed = LabelString::from_bytes(&mut reader);
assert!(parsed.is_err());
}
#[test]
fn test_parse_message() {
let message = get_message();
let bytes = Message::to_bytes(message.clone());
let parsed = Message::from_bytes(&mut Reader::new(&bytes));
assert!(parsed.is_ok());
assert_eq!(parsed.unwrap(), message);
}
}

View file

@ -1,5 +1,3 @@
use std::array::TryFromSliceError;
use crate::errors::ZNSError;
pub struct Reader<'a> {
@ -33,43 +31,37 @@ impl<'a> Reader<'a> {
}
pub fn read_u8(&mut self) -> Result<u8> {
if self.unread_bytes() == 0 {
Err(ZNSError::Reader {
message: String::from("cannot read u8"),
})
} else {
self.position += 1;
Ok(self.buffer[self.position - 1])
}
}
pub fn read_u16(&mut self) -> Result<u16> {
let result = u16::from_be_bytes(
self.buffer[self.position..self.position + 2]
.try_into()
.map_err(|e: TryFromSliceError| ZNSError::Reader {
message: e.to_string(),
})?,
);
self.position += 2;
let result =
u16::from_be_bytes(self.read(2)?.try_into().map_err(|_| ZNSError::Reader {
message: String::from("invalid read_u16"),
})?);
Ok(result)
}
pub fn read_i32(&mut self) -> Result<i32> {
let result = i32::from_be_bytes(
self.buffer[self.position..self.position + 4]
.try_into()
.map_err(|e: TryFromSliceError| ZNSError::Reader {
message: e.to_string(),
})?,
);
self.position += 4;
let result =
i32::from_be_bytes(self.read(4)?.try_into().map_err(|_| ZNSError::Reader {
message: String::from("invalid read_u32"),
})?);
Ok(result)
}
pub fn read_u32(&mut self) -> Result<u32> {
let result = u32::from_be_bytes(
self.buffer[self.position..self.position + 4]
.try_into()
.map_err(|e: TryFromSliceError| ZNSError::Reader {
message: e.to_string(),
})?,
);
self.position += 4;
let result =
u32::from_be_bytes(self.read(4)?.try_into().map_err(|_| ZNSError::Reader {
message: String::from("invalid read_u32"),
})?);
Ok(result)
}
@ -83,3 +75,58 @@ impl<'a> Reader<'a> {
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test() {
let fake_bytes = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
let mut reader = Reader::new(&fake_bytes);
assert_eq!(reader.unread_bytes(), 11);
let u16 = reader.read_u16();
assert!(u16.is_ok());
assert_eq!(u16.unwrap(), 1);
assert_eq!(reader.unread_bytes(), 9);
let u8 = reader.read_u8();
assert!(u8.is_ok());
assert_eq!(u8.unwrap(), 2);
assert_eq!(reader.unread_bytes(), 8);
let u32 = reader.read_u32();
assert!(u32.is_ok());
assert_eq!(
u32.unwrap(),
u32::from_be_bytes(fake_bytes[3..7].try_into().unwrap())
);
assert_eq!(reader.unread_bytes(), 4);
let read = reader.read(3);
assert!(read.is_ok());
assert_eq!(read.unwrap(), fake_bytes[7..10]);
assert_eq!(reader.unread_bytes(), 1);
let too_much = reader.read(2);
assert!(too_much.is_err());
assert_eq!(reader.unread_bytes(), 1);
assert!(reader.read_u8().is_ok());
assert!(reader.read_u8().is_err());
assert!(reader.read_u16().is_err());
assert!(reader.read_u32().is_err());
assert!(reader.read_i32().is_err());
let new_reader = reader.seek(1);
assert!(new_reader.is_ok());
assert_eq!(new_reader.unwrap().unread_bytes(), 10);
let new_reader = reader.seek(100);
assert!(new_reader.is_err());
}
}

View file

@ -4,6 +4,7 @@ use std::sync::Arc;
use tokio::net::UdpSocket;
use crate::db::lib::get_connection;
use crate::errors::ZNSError;
use crate::handlers::{Handler, ResponseHandler};
use crate::parser::{FromBytes, ToBytes};
@ -43,12 +44,13 @@ fn handle_parse_error(bytes: &[u8], err: ZNSError) -> Message {
async fn get_response(bytes: &[u8]) -> Message {
let mut reader = Reader::new(bytes);
match Message::from_bytes(&mut reader) {
Ok(mut message) => match Handler::handle(&message, bytes).await {
Ok(mut message) => match Handler::handle(&message, bytes, &mut get_connection()).await {
Ok(mut response) => {
response.set_response(RCODE::NOERROR);
response
}
Err(e) => {
println!("{:#?}", message);
eprintln!("{}", e.to_string());
message.set_response(e.rcode());
message
@ -72,3 +74,36 @@ pub async fn resolver_listener_loop(addr: SocketAddr) -> Result<(), Box<dyn Erro
});
}
}
#[cfg(test)]
mod tests {
use crate::structs::{Class, Question, RRClass, RRType, Type};
use super::*;
#[tokio::test]
async fn test_get_response() {
let message = Message {
header: Header {
id: 1,
flags: 288,
qdcount: 1,
ancount: 0,
nscount: 0,
arcount: 0,
},
question: vec![Question {
qname: vec![String::from("example"), String::from("org")],
qtype: Type::Type(RRType::A),
qclass: Class::Class(RRClass::IN),
}],
answer: vec![],
authority: vec![],
additional: vec![],
};
let response = get_response(&Message::to_bytes(message)).await;
assert_eq!(response.get_rcode(), Ok(RCODE::NXDOMAIN));
}
}

View file

@ -33,6 +33,7 @@ pub enum RRClass {
#[repr(u16)]
#[allow(dead_code)]
#[derive(Debug, IntEnum, PartialEq)]
pub enum RCODE {
NOERROR = 0,
FORMERR = 1,
@ -52,14 +53,14 @@ pub enum Opcode {
UPDATE = 5,
}
#[derive(Debug, Clone)]
#[derive(Debug, Clone, PartialEq)]
pub struct Question {
pub qname: LabelString,
pub qtype: Type, // NOTE: should be QTYPE, right now not really needed
pub qclass: Class, //NOTE: should be QCLASS, right now not really needed
}
#[derive(Debug, Clone)]
#[derive(Debug, Clone, PartialEq)]
pub struct Header {
pub id: u16,
pub flags: u16, // |QR| Opcode |AA|TC|RD|RA| Z | RCODE | ; 1 | 4 | 1 | 1 | 1 | 1 | 3 | 4
@ -69,7 +70,7 @@ pub struct Header {
pub arcount: u16,
}
#[derive(Debug, Clone)]
#[derive(Debug, Clone, PartialEq)]
pub struct Message {
pub header: Header,
pub question: Vec<Question>,
@ -78,7 +79,7 @@ pub struct Message {
pub additional: Vec<RR>,
}
#[derive(Debug, Clone)]
#[derive(Debug, Clone, PartialEq)]
pub struct RR {
pub name: LabelString,
pub _type: Type,
@ -88,11 +89,4 @@ pub struct RR {
pub rdata: Vec<u8>,
}
#[derive(Debug, Clone)]
pub struct OptRR {
pub code: u16,
pub length: u16,
pub rdata: Vec<u8>,
}
pub type LabelString = Vec<String>;