10
0
Fork 0
mirror of https://github.com/ZeusWPI/ZNS.git synced 2024-10-29 21:14:27 +01:00

Add delete support

This commit is contained in:
Xander Bil 2024-03-26 23:22:12 +01:00
parent 79b040f5f9
commit 08d4ca82f5
No known key found for this signature in database
GPG key ID: EC9706B54A278598
7 changed files with 110 additions and 35 deletions

View file

View file

@ -7,5 +7,5 @@ CREATE TABLE records (
rdlength INT NOT NULL, rdlength INT NOT NULL,
rdata BLOB NOT NULL, rdata BLOB NOT NULL,
PRIMARY KEY (name,type,class) PRIMARY KEY (name,type,class,rdlength,rdata)
) )

View file

@ -2,15 +2,15 @@ use crate::{
errors::DatabaseError, errors::DatabaseError,
structs::{Class, Question, Type, RR}, structs::{Class, Question, Type, RR},
}; };
use diesel::prelude::*; use diesel::{dsl, prelude::*};
use self::schema::records; use self::schema::records::{self};
use super::lib::establish_connection; use super::lib::establish_connection;
mod schema { mod schema {
diesel::table! { diesel::table! {
records (name, _type, class) { records (name, _type, class, rdlength, rdata) {
name -> Text, name -> Text,
#[sql_name = "type"] #[sql_name = "type"]
_type -> Integer, _type -> Integer,
@ -39,8 +39,14 @@ impl Record {
name: String, name: String,
_type: i32, _type: i32,
class: i32, class: i32,
) -> Result<Record, diesel::result::Error> { ) -> Result<Vec<Record>, diesel::result::Error> {
records::table.find((name, _type, class)).get_result(db) records::table
.filter(
records::name
.eq(name)
.and(records::_type.eq(_type).and(records::class.eq(class))),
)
.get_results(db)
} }
pub fn create( pub fn create(
@ -51,6 +57,28 @@ impl Record {
.values(&new_record) .values(&new_record)
.execute(db) .execute(db)
} }
pub fn delete(
db: &mut SqliteConnection,
name: String,
_type: Option<i32>,
class: i32,
rdata: Option<Vec<u8>>,
) -> Result<usize, diesel::result::Error> {
let mut query = diesel::delete(records::table).into_boxed();
query = query.filter(records::name.eq(name).and(records::class.eq(class)));
if let Some(_type) = _type {
query = query.filter(records::_type.eq(_type));
}
if let Some(rdata) = rdata {
query = query.filter(records::rdata.eq(rdata));
}
query.execute(db)
}
} }
pub async fn insert_into_database(rr: RR) -> Result<(), DatabaseError> { pub async fn insert_into_database(rr: RR) -> Result<(), DatabaseError> {
@ -71,9 +99,9 @@ pub async fn insert_into_database(rr: RR) -> Result<(), DatabaseError> {
Ok(()) Ok(())
} }
pub async fn get_from_database(question: &Question) -> Result<RR, DatabaseError> { pub async fn get_from_database(question: &Question) -> Result<Vec<RR>, DatabaseError> {
let db_connection = &mut establish_connection(); let db_connection = &mut establish_connection();
let record = Record::get( let records = Record::get(
db_connection, db_connection,
question.qname.join("."), question.qname.join("."),
question.qtype.clone() as i32, question.qtype.clone() as i32,
@ -83,12 +111,32 @@ pub async fn get_from_database(question: &Question) -> Result<RR, DatabaseError>
message: e.to_string(), message: e.to_string(),
})?; })?;
Ok(RR { Ok(records
name: record.name.split(".").map(str::to_string).collect(), .into_iter()
_type: Type::try_from(record._type as u16).map_err(|e| DatabaseError { message: e })?, .filter_map(|record| {
class: Class::try_from(record.class as u16).map_err(|e| DatabaseError { message: e })?, Some(RR {
ttl: record.ttl, name: record.name.split(".").map(str::to_string).collect(),
rdlength: record.rdlength as u16, _type: Type::try_from(record._type as u16)
rdata: record.rdata, .map_err(|e| DatabaseError { message: e })
}) .ok()?,
class: Class::try_from(record.class as u16)
.map_err(|e| DatabaseError { message: e })
.ok()?,
ttl: record.ttl,
rdlength: record.rdlength as u16,
rdata: record.rdata,
})
})
.collect())
}
//TODO: cleanup models
pub async fn delete_from_database(
name: Vec<String>,
_type: Option<Type>,
class: Class,
rdata: Option<Vec<u8>>,
) {
let db_connection = &mut establish_connection();
let _ = Record::delete(db_connection, name.join("."), _type.map(|f| f as i32), class as i32, rdata);
} }

View file

@ -12,9 +12,11 @@ impl TryFrom<u16> for Type {
fn try_from(value: u16) -> std::result::Result<Self, String> { fn try_from(value: u16) -> std::result::Result<Self, String> {
match value { match value {
//TODO: clean this up
x if x == Type::A as u16 => Ok(Type::A), x if x == Type::A as u16 => Ok(Type::A),
x if x == Type::OPT as u16 => Ok(Type::OPT), x if x == Type::OPT as u16 => Ok(Type::OPT),
x if x == Type::SOA as u16 => Ok(Type::SOA), x if x == Type::SOA as u16 => Ok(Type::SOA),
x if x == Type::ANY as u16 => Ok(Type::ANY),
_ => Err(format!("Invalid Type value: {}", value)), _ => Err(format!("Invalid Type value: {}", value)),
} }
} }
@ -26,6 +28,8 @@ impl TryFrom<u16> for Class {
fn try_from(value: u16) -> std::result::Result<Self, String> { fn try_from(value: u16) -> std::result::Result<Self, String> {
match value { match value {
x if x == Class::IN as u16 => Ok(Class::IN), x if x == Class::IN as u16 => Ok(Class::IN),
x if x == Class::ANY as u16 => Ok(Class::ANY),
x if x == Class::NONE as u16 => Ok(Class::NONE),
_ => Err(format!("Invalid Class value: {}", value)), _ => Err(format!("Invalid Class value: {}", value)),
} }
} }
@ -90,8 +94,7 @@ impl Type {
}) })
} }
} }
Type::SOA => todo!(), _ => todo!()
Type::OPT => todo!(),
} }
} }
pub fn from_data(&self, bytes: &[u8]) -> Result<String> { pub fn from_data(&self, bytes: &[u8]) -> Result<String> {
@ -107,8 +110,7 @@ impl Type {
}) })
} }
} }
Type::SOA => todo!(), _ => todo!()
Type::OPT => todo!(),
} }
} }
} }
@ -246,7 +248,6 @@ impl FromBytes for Question {
impl FromBytes for RR { impl FromBytes for RR {
fn from_bytes(bytes: &[u8], i: &mut usize) -> Result<Self> { fn from_bytes(bytes: &[u8], i: &mut usize) -> Result<Self> {
let name = LabelString::from_bytes(bytes, i)?; let name = LabelString::from_bytes(bytes, i)?;
println!("{:#?}", name);
if bytes.len() - *i < size_of::<Type>() + size_of::<Class>() + 6 { if bytes.len() - *i < size_of::<Type>() + size_of::<Class>() + 6 {
Err(ParseError { Err(ParseError {
object: String::from("RR"), object: String::from("RR"),

View file

@ -4,7 +4,7 @@ use std::sync::Arc;
use tokio::net::UdpSocket; use tokio::net::UdpSocket;
use crate::db::models::{get_from_database, insert_into_database}; use crate::db::models::{delete_from_database, get_from_database, insert_into_database};
use crate::errors::ParseError; use crate::errors::ParseError;
use crate::parser::FromBytes; use crate::parser::FromBytes;
use crate::structs::{Class, Header, Message, Opcode, Type, RCODE}; use crate::structs::{Class, Header, Message, Opcode, Type, RCODE};
@ -25,13 +25,13 @@ async fn handle_query(message: Message) -> Message {
response.header.arcount = 0; //TODO: fix this, handle unknown class values response.header.arcount = 0; //TODO: fix this, handle unknown class values
for question in message.question { for question in message.question {
let answer = get_from_database(&question).await; let answers = get_from_database(&question).await;
match answer { match answers {
Ok(rr) => { Ok(rrs) => {
response.header.flags = set_response_flags(response.header.flags, RCODE::NOERROR); response.header.flags = set_response_flags(response.header.flags, RCODE::NOERROR);
response.header.ancount = 1; response.header.ancount = 1;
response.answer = vec![rr] response.answer = rrs
} }
Err(e) => { Err(e) => {
response.header.flags = set_response_flags(response.header.flags, RCODE::NXDOMAIN); response.header.flags = set_response_flags(response.header.flags, RCODE::NXDOMAIN);
@ -81,22 +81,34 @@ async fn handle_update(message: Message) -> Message {
if (rr.class == Class::ANY && (rr.ttl != 0 || rr.rdlength != 0)) if (rr.class == Class::ANY && (rr.ttl != 0 || rr.rdlength != 0))
|| (rr.class == Class::NONE && rr.ttl != 0) || (rr.class == Class::NONE && rr.ttl != 0)
|| rr.class != zone.qclass || ![Class::NONE, Class::ANY, zone.qclass.clone()].contains(&rr.class)
{ {
response.header.flags = set_response_flags(response.header.flags, RCODE::FORMERR); response.header.flags = set_response_flags(response.header.flags, RCODE::FORMERR);
return response; return response;
} }
} }
//FIX: with nsupdate delete, I get `dns_request_getresponse: unexpected end of input`
for rr in message.authority { for rr in message.authority {
if rr.class == zone.qclass { if rr.class == zone.qclass {
insert_into_database(rr).await; let _ = insert_into_database(rr).await;
} else if rr.class == Class::ANY { } else if rr.class == Class::ANY {
response.header.flags = set_response_flags(response.header.flags, RCODE::NOTIMP); if rr._type == Type::ANY {
return response; if rr.name == zone.qname {
} else if rr.class == Class::ANY { response.header.flags =
response.header.flags = set_response_flags(response.header.flags, RCODE::NOTIMP); set_response_flags(response.header.flags, RCODE::NOTIMP);
return response; return response;
} else {
delete_from_database(rr.name, None, Class::IN, None).await;
}
} else {
delete_from_database(rr.name, Some(rr._type), Class::IN, None).await;
}
} else if rr.class == Class::NONE {
if rr._type == Type::SOA {
continue;
}
delete_from_database(rr.name, Some(rr._type),Class::IN, Some(rr.rdata)).await;
} }
} }

13
src/schema.rs Normal file
View file

@ -0,0 +1,13 @@
// @generated automatically by Diesel CLI.
diesel::table! {
records (name, type_, class, rdlength, rdata) {
name -> Text,
#[sql_name = "type"]
type_ -> Integer,
class -> Integer,
ttl -> Integer,
rdlength -> Integer,
rdata -> Binary,
}
}

View file

@ -1,11 +1,12 @@
use serde::Deserialize; use serde::Deserialize;
#[repr(u16)] #[repr(u16)]
#[derive(Debug, Clone, Deserialize)] #[derive(Debug, Clone, Deserialize, PartialEq)]
pub enum Type { pub enum Type {
A = 1, A = 1,
SOA = 6, SOA = 6,
OPT = 41 OPT = 41,
ANY = 255
} }
#[repr(u16)] #[repr(u16)]