mirror of
https://github.com/ZeusWPI/ZNS.git
synced 2024-11-24 14:11:10 +01:00
wrote some tests (at last)
This commit is contained in:
parent
aa94dc21bb
commit
5d59f1bd97
12 changed files with 479 additions and 92 deletions
|
@ -2,8 +2,19 @@ use diesel::prelude::*;
|
|||
|
||||
use crate::config::Config;
|
||||
|
||||
pub fn establish_connection() -> PgConnection {
|
||||
pub fn get_connection() -> PgConnection {
|
||||
let database_url = Config::get().db_uri.clone();
|
||||
PgConnection::establish(&database_url)
|
||||
.unwrap_or_else(|_| panic!("Error connecting to {}", Config::get().db_uri))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub mod tests {
|
||||
use super::*;
|
||||
|
||||
pub fn get_test_connection() -> PgConnection {
|
||||
let mut connection = get_connection();
|
||||
assert!(connection.begin_test_transaction().is_ok());
|
||||
connection
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,8 +6,6 @@ use diesel::prelude::*;
|
|||
|
||||
use self::schema::records::{self};
|
||||
|
||||
use super::lib::establish_connection;
|
||||
|
||||
mod schema {
|
||||
diesel::table! {
|
||||
records (name, _type, class, rdlength, rdata) {
|
||||
|
@ -81,8 +79,7 @@ impl Record {
|
|||
}
|
||||
}
|
||||
|
||||
pub async fn insert_into_database(rr: &RR) -> Result<(), ZNSError> {
|
||||
let db_connection = &mut establish_connection();
|
||||
pub fn insert_into_database(rr: &RR, connection: &mut PgConnection) -> Result<(), ZNSError> {
|
||||
let record = Record {
|
||||
name: rr.name.join("."),
|
||||
_type: rr._type.clone().into(),
|
||||
|
@ -92,21 +89,21 @@ pub async fn insert_into_database(rr: &RR) -> Result<(), ZNSError> {
|
|||
rdata: rr.rdata.clone(),
|
||||
};
|
||||
|
||||
Record::create(db_connection, record).map_err(|e| ZNSError::Database {
|
||||
Record::create(connection, record).map_err(|e| ZNSError::Database {
|
||||
message: e.to_string(),
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn get_from_database(
|
||||
pub fn get_from_database(
|
||||
name: &Vec<String>,
|
||||
_type: Type,
|
||||
class: Class,
|
||||
connection: &mut PgConnection,
|
||||
) -> Result<Vec<RR>, ZNSError> {
|
||||
let db_connection = &mut establish_connection();
|
||||
let records =
|
||||
Record::get(db_connection, name.join("."), _type.into(), class.into()).map_err(|e| {
|
||||
Record::get(connection, name.join("."), _type.into(), class.into()).map_err(|e| {
|
||||
ZNSError::Database {
|
||||
message: e.to_string(),
|
||||
}
|
||||
|
@ -128,18 +125,58 @@ pub async fn get_from_database(
|
|||
}
|
||||
|
||||
//TODO: cleanup models
|
||||
pub async fn delete_from_database(
|
||||
pub fn delete_from_database(
|
||||
name: &Vec<String>,
|
||||
_type: Option<Type>,
|
||||
class: Class,
|
||||
rdata: Option<Vec<u8>>,
|
||||
connection: &mut PgConnection,
|
||||
) {
|
||||
let db_connection = &mut establish_connection();
|
||||
let _ = Record::delete(
|
||||
db_connection,
|
||||
connection,
|
||||
name.join("."),
|
||||
_type.map(|f| f.into()),
|
||||
class.into(),
|
||||
rdata,
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
use crate::{db::lib::tests::get_test_connection, parser::tests::get_rr};
|
||||
|
||||
#[test]
|
||||
fn test() {
|
||||
let mut connection = get_test_connection();
|
||||
|
||||
let rr = get_rr();
|
||||
|
||||
let f = |connection: &mut PgConnection| {
|
||||
get_from_database(&rr.name, rr._type.clone(), rr.class.clone(), connection)
|
||||
};
|
||||
|
||||
assert!(f(&mut connection).unwrap().is_empty());
|
||||
|
||||
assert!(insert_into_database(&rr, &mut connection).is_ok());
|
||||
|
||||
let result = f(&mut connection);
|
||||
assert!(result.is_ok());
|
||||
assert_eq!(result.as_ref().unwrap().len(), 1);
|
||||
assert_eq!(result.unwrap()[0], rr);
|
||||
|
||||
delete_from_database(
|
||||
&rr.name,
|
||||
Some(rr._type.clone()),
|
||||
rr.class.clone(),
|
||||
Some(rr.rdata.clone()),
|
||||
&mut connection,
|
||||
);
|
||||
|
||||
assert!(f(&mut connection).unwrap().is_empty());
|
||||
|
||||
assert!(insert_into_database(&rr, &mut connection).is_ok());
|
||||
assert!(insert_into_database(&rr, &mut connection).is_err());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
use diesel::PgConnection;
|
||||
|
||||
use crate::{
|
||||
errors::ZNSError,
|
||||
structs::{Message, Opcode},
|
||||
|
@ -9,17 +11,26 @@ mod query;
|
|||
mod update;
|
||||
|
||||
pub trait ResponseHandler {
|
||||
async fn handle(message: &Message, raw: &[u8]) -> Result<Message, ZNSError>;
|
||||
async fn handle(
|
||||
message: &Message,
|
||||
raw: &[u8],
|
||||
connection: &mut PgConnection,
|
||||
) -> Result<Message, ZNSError>;
|
||||
}
|
||||
|
||||
pub struct Handler {}
|
||||
|
||||
impl ResponseHandler for Handler {
|
||||
async fn handle(message: &Message, raw: &[u8]) -> Result<Message, ZNSError> {
|
||||
async fn handle(
|
||||
message: &Message,
|
||||
raw: &[u8],
|
||||
connection: &mut PgConnection,
|
||||
) -> Result<Message, ZNSError> {
|
||||
match message.get_opcode() {
|
||||
//TODO: implement this in Opcode
|
||||
Ok(opcode) => match opcode {
|
||||
Opcode::QUERY => QueryHandler::handle(&message, raw).await,
|
||||
Opcode::UPDATE => UpdateHandler::handle(&message, raw).await,
|
||||
Opcode::QUERY => QueryHandler::handle(&message, raw, connection).await,
|
||||
Opcode::UPDATE => UpdateHandler::handle(&message, raw, connection).await,
|
||||
},
|
||||
Err(e) => Err(ZNSError::Formerr {
|
||||
message: e.to_string(),
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
use diesel::PgConnection;
|
||||
|
||||
use crate::{db::models::get_from_database, errors::ZNSError, structs::Message};
|
||||
|
||||
use super::ResponseHandler;
|
||||
|
@ -5,7 +7,11 @@ use super::ResponseHandler;
|
|||
pub struct QueryHandler {}
|
||||
|
||||
impl ResponseHandler for QueryHandler {
|
||||
async fn handle(message: &Message, _raw: &[u8]) -> Result<Message, ZNSError> {
|
||||
async fn handle(
|
||||
message: &Message,
|
||||
_raw: &[u8],
|
||||
connection: &mut PgConnection,
|
||||
) -> Result<Message, ZNSError> {
|
||||
let mut response = message.clone();
|
||||
response.header.arcount = 0; //TODO: fix this, handle unknown class values
|
||||
|
||||
|
@ -14,8 +20,8 @@ impl ResponseHandler for QueryHandler {
|
|||
&question.qname,
|
||||
question.qtype.clone(),
|
||||
question.qclass.clone(),
|
||||
)
|
||||
.await;
|
||||
connection,
|
||||
);
|
||||
|
||||
match answers {
|
||||
Ok(rrs) => {
|
||||
|
@ -24,7 +30,7 @@ impl ResponseHandler for QueryHandler {
|
|||
domain: question.qname.join("."),
|
||||
});
|
||||
}
|
||||
response.header.ancount = rrs.len() as u16;
|
||||
response.header.ancount += rrs.len() as u16;
|
||||
response.answer.extend(rrs)
|
||||
}
|
||||
Err(e) => {
|
||||
|
@ -38,3 +44,38 @@ impl ResponseHandler for QueryHandler {
|
|||
Ok(response)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::parser::tests::get_message;
|
||||
use crate::structs::*;
|
||||
|
||||
use crate::{
|
||||
db::{lib::tests::get_test_connection, models::insert_into_database},
|
||||
parser::{tests::get_rr, ToBytes},
|
||||
};
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_handle_query() {
|
||||
let mut connection = get_test_connection();
|
||||
let rr = get_rr();
|
||||
let mut message = get_message();
|
||||
message.header.ancount = 0;
|
||||
message.answer = vec![];
|
||||
|
||||
assert!(insert_into_database(&rr, &mut connection).is_ok());
|
||||
|
||||
let result = QueryHandler::handle(
|
||||
&message,
|
||||
&Message::to_bytes(message.clone()),
|
||||
&mut connection,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(result.header.ancount, 2);
|
||||
assert_eq!(result.answer.len(), 2);
|
||||
assert_eq!(result.answer[0], rr);
|
||||
assert_eq!(result.answer[1], rr);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
use diesel::PgConnection;
|
||||
|
||||
use crate::{
|
||||
config::Config,
|
||||
db::models::get_from_database,
|
||||
|
@ -9,7 +11,11 @@ use crate::{
|
|||
|
||||
use super::{dnskey::DNSKeyRData, sig::Sig};
|
||||
|
||||
pub async fn authenticate(sig: &Sig, zone: &Vec<String>) -> Result<bool, ZNSError> {
|
||||
pub async fn authenticate(
|
||||
sig: &Sig,
|
||||
zone: &Vec<String>,
|
||||
connection: &mut PgConnection,
|
||||
) -> Result<bool, ZNSError> {
|
||||
if zone.len() >= 4 {
|
||||
let username = &zone[zone.len() - 4]; // Should match: username.users.zeus.gent
|
||||
|
||||
|
@ -18,7 +24,7 @@ pub async fn authenticate(sig: &Sig, zone: &Vec<String>) -> Result<bool, ZNSErro
|
|||
if ssh_verified {
|
||||
Ok(true)
|
||||
} else {
|
||||
Ok(validate_dnskey(zone, sig).await?)
|
||||
Ok(validate_dnskey(zone, sig, connection).await?)
|
||||
}
|
||||
} else {
|
||||
Err(ZNSError::NotAuth {
|
||||
|
@ -40,15 +46,21 @@ async fn validate_ssh(username: &String, sig: &Sig) -> Result<bool, reqwest::Err
|
|||
.any(|key| sig.verify_ssh(&key).is_ok_and(|b| b)))
|
||||
}
|
||||
|
||||
async fn validate_dnskey(zone: &Vec<String>, sig: &Sig) -> Result<bool, ZNSError> {
|
||||
Ok(
|
||||
get_from_database(zone, Type::Type(RRType::DNSKEY), Class::Class(RRClass::IN))
|
||||
.await?
|
||||
async fn validate_dnskey(
|
||||
zone: &Vec<String>,
|
||||
sig: &Sig,
|
||||
connection: &mut PgConnection,
|
||||
) -> Result<bool, ZNSError> {
|
||||
Ok(get_from_database(
|
||||
zone,
|
||||
Type::Type(RRType::DNSKEY),
|
||||
Class::Class(RRClass::IN),
|
||||
connection,
|
||||
)?
|
||||
.iter()
|
||||
.any(|rr| {
|
||||
let mut reader = Reader::new(&rr.rdata);
|
||||
DNSKeyRData::from_bytes(&mut reader)
|
||||
.is_ok_and(|dnskey| sig.verify_dnskey(dnskey).is_ok_and(|b| b))
|
||||
}),
|
||||
)
|
||||
}))
|
||||
}
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
use diesel::PgConnection;
|
||||
|
||||
use crate::{
|
||||
db::models::{delete_from_database, insert_into_database},
|
||||
errors::ZNSError,
|
||||
|
@ -17,7 +19,11 @@ mod sig;
|
|||
pub struct UpdateHandler {}
|
||||
|
||||
impl ResponseHandler for UpdateHandler {
|
||||
async fn handle(message: &Message, raw: &[u8]) -> Result<Message, ZNSError> {
|
||||
async fn handle(
|
||||
message: &Message,
|
||||
raw: &[u8],
|
||||
connection: &mut PgConnection,
|
||||
) -> Result<Message, ZNSError> {
|
||||
let response = message.clone();
|
||||
// Zone section (question) processing
|
||||
if (message.header.qdcount != 1)
|
||||
|
@ -44,7 +50,7 @@ impl ResponseHandler for UpdateHandler {
|
|||
if last.is_some() && last.unwrap()._type == Type::Type(RRType::KEY) {
|
||||
let sig = Sig::new(last.unwrap(), raw)?;
|
||||
|
||||
if !authenticate::authenticate(&sig, &zone.qname)
|
||||
if !authenticate::authenticate(&sig, &zone.qname, connection)
|
||||
.await
|
||||
.is_ok_and(|x| x)
|
||||
{
|
||||
|
@ -89,7 +95,7 @@ impl ResponseHandler for UpdateHandler {
|
|||
|
||||
for rr in &message.authority {
|
||||
if rr.class == zone.qclass {
|
||||
let _ = insert_into_database(&rr).await;
|
||||
let _ = insert_into_database(&rr, connection);
|
||||
} else if rr.class == Class::Class(RRClass::ANY) {
|
||||
if rr._type == Type::Type(RRType::ANY) {
|
||||
if rr.name == zone.qname {
|
||||
|
@ -98,7 +104,13 @@ impl ResponseHandler for UpdateHandler {
|
|||
message: "rr.name == zone.qname".to_string(),
|
||||
});
|
||||
} else {
|
||||
delete_from_database(&rr.name, None, Class::Class(RRClass::IN), None).await;
|
||||
delete_from_database(
|
||||
&rr.name,
|
||||
None,
|
||||
Class::Class(RRClass::IN),
|
||||
None,
|
||||
connection,
|
||||
)
|
||||
}
|
||||
} else {
|
||||
delete_from_database(
|
||||
|
@ -106,8 +118,8 @@ impl ResponseHandler for UpdateHandler {
|
|||
Some(rr._type.clone()),
|
||||
Class::Class(RRClass::IN),
|
||||
None,
|
||||
connection,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
} else if rr.class == Class::Class(RRClass::NONE) {
|
||||
if rr._type == Type::Type(RRType::SOA) {
|
||||
|
@ -118,8 +130,8 @@ impl ResponseHandler for UpdateHandler {
|
|||
Some(rr._type.clone()),
|
||||
Class::Class(RRClass::IN),
|
||||
Some(rr.rdata.clone()),
|
||||
connection,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -29,6 +29,18 @@ impl PublicKey for RsaPublicKey {
|
|||
Ok(RsaPublicKey { e, n })
|
||||
}
|
||||
|
||||
fn from_dnskey(key: &[u8]) -> Result<Self, ZNSError>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
let mut reader = Reader::new(key);
|
||||
let e_len = reader.read_u8()?;
|
||||
let e = reader.read(e_len as usize)?;
|
||||
let mut n = reader.read(reader.unread_bytes())?;
|
||||
n.insert(0, 0);
|
||||
Ok(RsaPublicKey { e, n })
|
||||
}
|
||||
|
||||
fn verify(
|
||||
&self,
|
||||
data: &[u8],
|
||||
|
@ -47,7 +59,7 @@ impl PublicKey for RsaPublicKey {
|
|||
Algorithm::RSASHA512 => Ok(&signature::RSA_PKCS1_2048_8192_SHA512),
|
||||
Algorithm::RSASHA256 => Ok(&signature::RSA_PKCS1_2048_8192_SHA256),
|
||||
_ => Err(ZNSError::PublicKey {
|
||||
message: format!("RsaPublicKey: invalid verify algorithm",),
|
||||
message: String::from("RsaPublicKey: invalid verify algorithm"),
|
||||
}),
|
||||
}?;
|
||||
|
||||
|
@ -55,16 +67,4 @@ impl PublicKey for RsaPublicKey {
|
|||
|
||||
Ok(pkey.verify(data, signature).is_ok())
|
||||
}
|
||||
|
||||
fn from_dnskey(key: &[u8]) -> Result<Self, ZNSError>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
let mut reader = Reader::new(key);
|
||||
let e_len = reader.read_u8()?;
|
||||
let e = reader.read(e_len as usize)?;
|
||||
let mut n = reader.read(reader.unread_bytes())?;
|
||||
n.insert(0, 0);
|
||||
Ok(RsaPublicKey { e, n })
|
||||
}
|
||||
}
|
||||
|
|
|
@ -9,4 +9,43 @@ impl Message {
|
|||
pub fn get_opcode(&self) -> Result<Opcode, String> {
|
||||
Opcode::try_from((self.header.flags & 0b0111100000000000) >> 11)
|
||||
}
|
||||
|
||||
#[allow(dead_code)] // Used with tests
|
||||
pub fn get_rcode(&self) -> Result<RCODE, u16> {
|
||||
RCODE::try_from(self.header.flags & (!0 >> 12))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
use crate::structs::Header;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test() {
|
||||
let mut message = Message {
|
||||
header: Header {
|
||||
id: 1,
|
||||
flags: 288,
|
||||
qdcount: 0,
|
||||
ancount: 0,
|
||||
nscount: 0,
|
||||
arcount: 0,
|
||||
},
|
||||
question: vec![],
|
||||
answer: vec![],
|
||||
authority: vec![],
|
||||
additional: vec![],
|
||||
};
|
||||
|
||||
assert_eq!(message.get_opcode().unwrap() as u8, Opcode::QUERY as u8);
|
||||
|
||||
message.set_response(RCODE::NOTIMP);
|
||||
|
||||
assert!((message.header.flags & (1 << 15)) > 0);
|
||||
|
||||
assert_eq!(message.get_rcode().unwrap(), RCODE::NOTIMP);
|
||||
}
|
||||
}
|
||||
|
|
148
src/parser.rs
148
src/parser.rs
|
@ -295,3 +295,151 @@ impl ToBytes for Message {
|
|||
result
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub mod tests {
|
||||
use super::*;
|
||||
|
||||
pub fn get_rr() -> RR {
|
||||
RR {
|
||||
name: vec![String::from("example"), String::from("org")],
|
||||
_type: Type::Type(RRType::A),
|
||||
class: Class::Class(RRClass::IN),
|
||||
ttl: 10,
|
||||
rdlength: 4,
|
||||
rdata: vec![1, 2, 3, 4],
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_message() -> Message {
|
||||
Message {
|
||||
header: Header {
|
||||
id: 1,
|
||||
flags: 288,
|
||||
qdcount: 2,
|
||||
ancount: 1,
|
||||
nscount: 1,
|
||||
arcount: 1,
|
||||
},
|
||||
question: vec![
|
||||
Question {
|
||||
qname: vec![String::from("example"), String::from("org")],
|
||||
qtype: Type::Type(RRType::A),
|
||||
qclass: Class::Class(RRClass::IN),
|
||||
},
|
||||
Question {
|
||||
qname: vec![String::from("example"), String::from("org")],
|
||||
qtype: Type::Type(RRType::A),
|
||||
qclass: Class::Class(RRClass::IN),
|
||||
},
|
||||
],
|
||||
answer: vec![get_rr()],
|
||||
authority: vec![get_rr()],
|
||||
additional: vec![get_rr()],
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_header() {
|
||||
let header = Header {
|
||||
id: 1,
|
||||
flags: 288,
|
||||
qdcount: 1,
|
||||
ancount: 0,
|
||||
nscount: 0,
|
||||
arcount: 0,
|
||||
};
|
||||
|
||||
let bytes = Header::to_bytes(header.clone());
|
||||
let parsed = Header::from_bytes(&mut Reader::new(&bytes));
|
||||
assert!(parsed.is_ok());
|
||||
assert_eq!(parsed.unwrap(), header);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_question() {
|
||||
let question = Question {
|
||||
qname: vec![String::from("example"), String::from("org")],
|
||||
qtype: Type::Type(RRType::A),
|
||||
qclass: Class::Class(RRClass::IN),
|
||||
};
|
||||
|
||||
let bytes = Question::to_bytes(question.clone());
|
||||
let parsed = Question::from_bytes(&mut Reader::new(&bytes));
|
||||
assert!(parsed.is_ok());
|
||||
assert_eq!(parsed.unwrap(), question);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_rr() {
|
||||
let rr = get_rr();
|
||||
|
||||
let bytes = RR::to_bytes(rr.clone());
|
||||
let parsed = RR::from_bytes(&mut Reader::new(&bytes));
|
||||
assert!(parsed.is_ok());
|
||||
assert_eq!(parsed.unwrap(), rr);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_labelstring() {
|
||||
let labelstring = vec![String::from("example"), String::from("org")];
|
||||
|
||||
let bytes = LabelString::to_bytes(labelstring.clone());
|
||||
let parsed = LabelString::from_bytes(&mut Reader::new(&bytes));
|
||||
assert!(parsed.is_ok());
|
||||
assert_eq!(parsed.unwrap(), labelstring);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_labelstring_ptr() {
|
||||
let labelstring = vec![String::from("example"), String::from("org")];
|
||||
|
||||
let mut bytes = LabelString::to_bytes(labelstring.clone());
|
||||
|
||||
bytes.insert(0, 0);
|
||||
bytes.insert(0, 0);
|
||||
|
||||
let to_read = bytes.len();
|
||||
|
||||
bytes.push(0b11000000);
|
||||
bytes.push(0b00000010);
|
||||
|
||||
let mut reader = Reader::new(&bytes);
|
||||
let _ = reader.read(to_read);
|
||||
|
||||
let parsed = LabelString::from_bytes(&mut reader);
|
||||
assert!(parsed.is_ok());
|
||||
assert_eq!(parsed.unwrap(), labelstring);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_labelstring_invalid_ptr() {
|
||||
let labelstring = vec![String::from("example"), String::from("org")];
|
||||
|
||||
let mut bytes = LabelString::to_bytes(labelstring.clone());
|
||||
|
||||
bytes.insert(0, 0);
|
||||
bytes.insert(0, 0);
|
||||
|
||||
let to_read = bytes.len();
|
||||
|
||||
bytes.push(0b11000000);
|
||||
// Not allowed to point to itself or in the future
|
||||
bytes.push(to_read as u8);
|
||||
|
||||
let mut reader = Reader::new(&bytes);
|
||||
let _ = reader.read(to_read);
|
||||
|
||||
let parsed = LabelString::from_bytes(&mut reader);
|
||||
assert!(parsed.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_message() {
|
||||
let message = get_message();
|
||||
let bytes = Message::to_bytes(message.clone());
|
||||
let parsed = Message::from_bytes(&mut Reader::new(&bytes));
|
||||
assert!(parsed.is_ok());
|
||||
assert_eq!(parsed.unwrap(), message);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,5 +1,3 @@
|
|||
use std::array::TryFromSliceError;
|
||||
|
||||
use crate::errors::ZNSError;
|
||||
|
||||
pub struct Reader<'a> {
|
||||
|
@ -33,43 +31,37 @@ impl<'a> Reader<'a> {
|
|||
}
|
||||
|
||||
pub fn read_u8(&mut self) -> Result<u8> {
|
||||
if self.unread_bytes() == 0 {
|
||||
Err(ZNSError::Reader {
|
||||
message: String::from("cannot read u8"),
|
||||
})
|
||||
} else {
|
||||
self.position += 1;
|
||||
Ok(self.buffer[self.position - 1])
|
||||
}
|
||||
}
|
||||
|
||||
pub fn read_u16(&mut self) -> Result<u16> {
|
||||
let result = u16::from_be_bytes(
|
||||
self.buffer[self.position..self.position + 2]
|
||||
.try_into()
|
||||
.map_err(|e: TryFromSliceError| ZNSError::Reader {
|
||||
message: e.to_string(),
|
||||
})?,
|
||||
);
|
||||
self.position += 2;
|
||||
let result =
|
||||
u16::from_be_bytes(self.read(2)?.try_into().map_err(|_| ZNSError::Reader {
|
||||
message: String::from("invalid read_u16"),
|
||||
})?);
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
pub fn read_i32(&mut self) -> Result<i32> {
|
||||
let result = i32::from_be_bytes(
|
||||
self.buffer[self.position..self.position + 4]
|
||||
.try_into()
|
||||
.map_err(|e: TryFromSliceError| ZNSError::Reader {
|
||||
message: e.to_string(),
|
||||
})?,
|
||||
);
|
||||
self.position += 4;
|
||||
let result =
|
||||
i32::from_be_bytes(self.read(4)?.try_into().map_err(|_| ZNSError::Reader {
|
||||
message: String::from("invalid read_u32"),
|
||||
})?);
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
pub fn read_u32(&mut self) -> Result<u32> {
|
||||
let result = u32::from_be_bytes(
|
||||
self.buffer[self.position..self.position + 4]
|
||||
.try_into()
|
||||
.map_err(|e: TryFromSliceError| ZNSError::Reader {
|
||||
message: e.to_string(),
|
||||
})?,
|
||||
);
|
||||
self.position += 4;
|
||||
let result =
|
||||
u32::from_be_bytes(self.read(4)?.try_into().map_err(|_| ZNSError::Reader {
|
||||
message: String::from("invalid read_u32"),
|
||||
})?);
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
|
@ -83,3 +75,58 @@ impl<'a> Reader<'a> {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test() {
|
||||
let fake_bytes = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
|
||||
let mut reader = Reader::new(&fake_bytes);
|
||||
|
||||
assert_eq!(reader.unread_bytes(), 11);
|
||||
|
||||
let u16 = reader.read_u16();
|
||||
assert!(u16.is_ok());
|
||||
assert_eq!(u16.unwrap(), 1);
|
||||
|
||||
assert_eq!(reader.unread_bytes(), 9);
|
||||
|
||||
let u8 = reader.read_u8();
|
||||
assert!(u8.is_ok());
|
||||
assert_eq!(u8.unwrap(), 2);
|
||||
assert_eq!(reader.unread_bytes(), 8);
|
||||
|
||||
let u32 = reader.read_u32();
|
||||
assert!(u32.is_ok());
|
||||
assert_eq!(
|
||||
u32.unwrap(),
|
||||
u32::from_be_bytes(fake_bytes[3..7].try_into().unwrap())
|
||||
);
|
||||
assert_eq!(reader.unread_bytes(), 4);
|
||||
|
||||
let read = reader.read(3);
|
||||
assert!(read.is_ok());
|
||||
assert_eq!(read.unwrap(), fake_bytes[7..10]);
|
||||
assert_eq!(reader.unread_bytes(), 1);
|
||||
|
||||
let too_much = reader.read(2);
|
||||
assert!(too_much.is_err());
|
||||
assert_eq!(reader.unread_bytes(), 1);
|
||||
|
||||
assert!(reader.read_u8().is_ok());
|
||||
|
||||
assert!(reader.read_u8().is_err());
|
||||
assert!(reader.read_u16().is_err());
|
||||
assert!(reader.read_u32().is_err());
|
||||
assert!(reader.read_i32().is_err());
|
||||
|
||||
let new_reader = reader.seek(1);
|
||||
assert!(new_reader.is_ok());
|
||||
assert_eq!(new_reader.unwrap().unread_bytes(), 10);
|
||||
|
||||
let new_reader = reader.seek(100);
|
||||
assert!(new_reader.is_err());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,6 +4,7 @@ use std::sync::Arc;
|
|||
|
||||
use tokio::net::UdpSocket;
|
||||
|
||||
use crate::db::lib::get_connection;
|
||||
use crate::errors::ZNSError;
|
||||
use crate::handlers::{Handler, ResponseHandler};
|
||||
use crate::parser::{FromBytes, ToBytes};
|
||||
|
@ -43,12 +44,13 @@ fn handle_parse_error(bytes: &[u8], err: ZNSError) -> Message {
|
|||
async fn get_response(bytes: &[u8]) -> Message {
|
||||
let mut reader = Reader::new(bytes);
|
||||
match Message::from_bytes(&mut reader) {
|
||||
Ok(mut message) => match Handler::handle(&message, bytes).await {
|
||||
Ok(mut message) => match Handler::handle(&message, bytes, &mut get_connection()).await {
|
||||
Ok(mut response) => {
|
||||
response.set_response(RCODE::NOERROR);
|
||||
response
|
||||
}
|
||||
Err(e) => {
|
||||
println!("{:#?}", message);
|
||||
eprintln!("{}", e.to_string());
|
||||
message.set_response(e.rcode());
|
||||
message
|
||||
|
@ -72,3 +74,36 @@ pub async fn resolver_listener_loop(addr: SocketAddr) -> Result<(), Box<dyn Erro
|
|||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::structs::{Class, Question, RRClass, RRType, Type};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_response() {
|
||||
let message = Message {
|
||||
header: Header {
|
||||
id: 1,
|
||||
flags: 288,
|
||||
qdcount: 1,
|
||||
ancount: 0,
|
||||
nscount: 0,
|
||||
arcount: 0,
|
||||
},
|
||||
question: vec![Question {
|
||||
qname: vec![String::from("example"), String::from("org")],
|
||||
qtype: Type::Type(RRType::A),
|
||||
qclass: Class::Class(RRClass::IN),
|
||||
}],
|
||||
answer: vec![],
|
||||
authority: vec![],
|
||||
additional: vec![],
|
||||
};
|
||||
|
||||
let response = get_response(&Message::to_bytes(message)).await;
|
||||
|
||||
assert_eq!(response.get_rcode(), Ok(RCODE::NXDOMAIN));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -33,6 +33,7 @@ pub enum RRClass {
|
|||
|
||||
#[repr(u16)]
|
||||
#[allow(dead_code)]
|
||||
#[derive(Debug, IntEnum, PartialEq)]
|
||||
pub enum RCODE {
|
||||
NOERROR = 0,
|
||||
FORMERR = 1,
|
||||
|
@ -52,14 +53,14 @@ pub enum Opcode {
|
|||
UPDATE = 5,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct Question {
|
||||
pub qname: LabelString,
|
||||
pub qtype: Type, // NOTE: should be QTYPE, right now not really needed
|
||||
pub qclass: Class, //NOTE: should be QCLASS, right now not really needed
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct Header {
|
||||
pub id: u16,
|
||||
pub flags: u16, // |QR| Opcode |AA|TC|RD|RA| Z | RCODE | ; 1 | 4 | 1 | 1 | 1 | 1 | 3 | 4
|
||||
|
@ -69,7 +70,7 @@ pub struct Header {
|
|||
pub arcount: u16,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct Message {
|
||||
pub header: Header,
|
||||
pub question: Vec<Question>,
|
||||
|
@ -78,7 +79,7 @@ pub struct Message {
|
|||
pub additional: Vec<RR>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct RR {
|
||||
pub name: LabelString,
|
||||
pub _type: Type,
|
||||
|
@ -88,11 +89,4 @@ pub struct RR {
|
|||
pub rdata: Vec<u8>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct OptRR {
|
||||
pub code: u16,
|
||||
pub length: u16,
|
||||
pub rdata: Vec<u8>,
|
||||
}
|
||||
|
||||
pub type LabelString = Vec<String>;
|
||||
|
|
Loading…
Reference in a new issue