10
0
Fork 0
mirror of https://github.com/ZeusWPI/ZNS.git synced 2024-11-22 05:41:11 +01:00

Parsing rdata + get record

This commit is contained in:
Xander Bil 2024-03-12 21:07:34 +01:00
parent 89df20d582
commit 335670c3c2
No known key found for this signature in database
GPG key ID: EC9706B54A278598
4 changed files with 76 additions and 20 deletions

16
Cargo.lock generated
View file

@ -106,6 +106,15 @@ version = "1.0.7"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1"
[[package]]
name = "form_urlencoded"
version = "1.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456"
dependencies = [
"percent-encoding",
]
[[package]] [[package]]
name = "futures-channel" name = "futures-channel"
version = "0.3.30" version = "0.3.30"
@ -305,6 +314,12 @@ dependencies = [
"memchr", "memchr",
] ]
[[package]]
name = "percent-encoding"
version = "2.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e"
[[package]] [[package]]
name = "pin-project-lite" name = "pin-project-lite"
version = "0.2.13" version = "0.2.13"
@ -631,6 +646,7 @@ version = "0.1.0"
dependencies = [ dependencies = [
"diesel", "diesel",
"dotenvy", "dotenvy",
"form_urlencoded",
"http-body-util", "http-body-util",
"hyper", "hyper",
"hyper-util", "hyper-util",

View file

@ -7,6 +7,7 @@ edition = "2021"
[dependencies] [dependencies]
diesel = { version = "2.1.4", features = ["sqlite"] } diesel = { version = "2.1.4", features = ["sqlite"] }
dotenvy = "0.15" dotenvy = "0.15"
form_urlencoded = "1.2"
tokio = {version = "1.36.0", features = ["macros","rt-multi-thread"], default-features = false} tokio = {version = "1.36.0", features = ["macros","rt-multi-thread"], default-features = false}
hyper = {version = "1.2.0", features = ["server", "http1"], default-features = false} hyper = {version = "1.2.0", features = ["server", "http1"], default-features = false}
hyper-util = { version = "0.1", features = ["server","http1", "tokio"], default-features = false} hyper-util = { version = "0.1", features = ["server","http1", "tokio"], default-features = false}

View file

@ -1,10 +1,11 @@
use std::collections::HashMap;
use std::net::SocketAddr; use std::net::SocketAddr;
use http_body_util::{BodyExt, Full}; use http_body_util::{BodyExt, Full};
use hyper::body::{Buf, Bytes}; use hyper::body::{Buf, Bytes};
use hyper::server::conn::http1; use hyper::server::conn::http1;
use hyper::service::service_fn; use hyper::service::service_fn;
use hyper::{header, Method, Request, Response, StatusCode}; use hyper::{Method, Request, Response, StatusCode};
use hyper_util::rt::TokioIo; use hyper_util::rt::TokioIo;
use serde::Deserialize; use serde::Deserialize;
use tokio::net::TcpListener; use tokio::net::TcpListener;
@ -28,18 +29,22 @@ struct Record {
data: String, data: String,
} }
async fn api_post_response(req: Request<hyper::body::Incoming>) -> Result<Response<BoxBody>> { async fn create_record(req: Request<hyper::body::Incoming>) -> Result<Response<BoxBody>> {
let whole_body = req.collect().await?.aggregate(); let whole_body = req.collect().await?.aggregate();
match serde_json::from_reader::<_, Record>(whole_body.reader()) { match serde_json::from_reader::<_, Record>(whole_body.reader()) {
Ok(record) => { Ok(record) => {
let rdata = record
._type
.to_bytes(&record.data)
.map_err(|e| e.to_string())?;
match insert_into_database(RR { match insert_into_database(RR {
name: record.name, name: record.name,
_type: record._type, _type: record._type,
class: Class::IN, class: Class::IN,
ttl: record.ttl, ttl: record.ttl,
rdlength: record.data.as_bytes().len() as u16, rdlength: rdata.len() as u16,
rdata: record.data.as_bytes().to_vec(), rdata,
}) })
.await .await
{ {
@ -55,30 +60,28 @@ async fn api_post_response(req: Request<hyper::body::Incoming>) -> Result<Respon
} }
} }
Err(e) => Ok(Response::builder() Err(e) => Ok(Response::builder()
.status(StatusCode::FORBIDDEN) .status(StatusCode::UNPROCESSABLE_ENTITY)
.body(full(e.to_string()))?), .body(full(e.to_string()))?),
} }
} }
async fn api_get_response() -> Result<Response<BoxBody>> { async fn get_record(req: Request<hyper::body::Incoming>) -> Result<Response<BoxBody>> {
let data = vec!["foo", "bar"]; if let Some(q) = req.uri().query() {
let res = match serde_json::to_string(&data) { let params = form_urlencoded::parse(q.as_bytes()).into_owned().collect::<HashMap<String,String>>();
Ok(json) => Response::builder() if let Some(domain) = params.get("domain_name") {
.header(header::CONTENT_TYPE, "application/json") return Ok(Response::builder().status(StatusCode::OK).body(full(domain.to_owned()))?)
.body(full(json)) }
.unwrap(), }
Err(_) => Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR) Ok(Response::builder()
.body(full(INTERNAL_SERVER_ERROR)) .status(StatusCode::UNPROCESSABLE_ENTITY)
.unwrap(), .body(full("Missing domain_name query parameter"))?)
};
Ok(res)
} }
async fn routes(req: Request<hyper::body::Incoming>) -> Result<Response<BoxBody>> { async fn routes(req: Request<hyper::body::Incoming>) -> Result<Response<BoxBody>> {
match (req.method(), req.uri().path()) { match (req.method(), req.uri().path()) {
(&Method::POST, "/add") => api_post_response(req).await, (&Method::POST, "/add") => create_record(req).await,
(&Method::GET, "/json_api") => api_get_response().await, (&Method::GET, "/get") => get_record(req).await,
_ => Ok(Response::builder() _ => Ok(Response::builder()
.status(StatusCode::NOT_FOUND) .status(StatusCode::NOT_FOUND)
.body(full(NOTFOUND)) .body(full(NOTFOUND))

View file

@ -38,6 +38,42 @@ pub trait FromBytes {
Self: Sized; Self: Sized;
} }
impl Type {
pub fn to_bytes(&self, text: &String) -> Result<Vec<u8>> {
match self {
Type::A => {
let arr: Vec<u8> = text
.split(".")
.filter_map(|s| s.parse::<u8>().ok())
.collect();
if arr.len() == 4 {
Ok(arr)
} else {
Err(ParseError {
object: String::from("Type::A"),
message: String::from("Invalid IPv4 address"),
})
}
}
}
}
pub fn from_bytes(&self, bytes: &[u8]) -> Result<String> {
match self {
Type::A => {
if bytes.len() == 4 {
let arr: Vec<String> = bytes.iter().map(|b| b.to_string()).collect();
Ok(arr.join("."))
} else {
Err(ParseError {
object: String::from("Type::A"),
message: String::from("Invalid Ipv4 address bytes"),
})
}
}
}
}
}
impl FromBytes for Header { impl FromBytes for Header {
fn from_bytes(bytes: &[u8]) -> Result<Self> { fn from_bytes(bytes: &[u8]) -> Result<Self> {
if bytes.len() != size_of::<Header>() { if bytes.len() != size_of::<Header>() {