From 335670c3c2a06a875c12693063a138075e58157c Mon Sep 17 00:00:00 2001 From: Xander Bil Date: Tue, 12 Mar 2024 21:07:34 +0100 Subject: [PATCH] Parsing rdata + get record --- Cargo.lock | 16 ++++++++++++++++ Cargo.toml | 1 + src/api.rs | 43 +++++++++++++++++++++++-------------------- src/parser.rs | 36 ++++++++++++++++++++++++++++++++++++ 4 files changed, 76 insertions(+), 20 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a72f489..8b6b12b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -106,6 +106,15 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "form_urlencoded" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456" +dependencies = [ + "percent-encoding", +] + [[package]] name = "futures-channel" version = "0.3.30" @@ -305,6 +314,12 @@ dependencies = [ "memchr", ] +[[package]] +name = "percent-encoding" +version = "2.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" + [[package]] name = "pin-project-lite" version = "0.2.13" @@ -631,6 +646,7 @@ version = "0.1.0" dependencies = [ "diesel", "dotenvy", + "form_urlencoded", "http-body-util", "hyper", "hyper-util", diff --git a/Cargo.toml b/Cargo.toml index 65d1f23..dd80bca 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,6 +7,7 @@ edition = "2021" [dependencies] diesel = { version = "2.1.4", features = ["sqlite"] } dotenvy = "0.15" +form_urlencoded = "1.2" tokio = {version = "1.36.0", features = ["macros","rt-multi-thread"], default-features = false} hyper = {version = "1.2.0", features = ["server", "http1"], default-features = false} hyper-util = { version = "0.1", features = ["server","http1", "tokio"], default-features = false} diff --git a/src/api.rs b/src/api.rs index f974638..8206653 100644 --- a/src/api.rs +++ b/src/api.rs @@ -1,10 +1,11 @@ +use std::collections::HashMap; use std::net::SocketAddr; use http_body_util::{BodyExt, Full}; use hyper::body::{Buf, Bytes}; use hyper::server::conn::http1; use hyper::service::service_fn; -use hyper::{header, Method, Request, Response, StatusCode}; +use hyper::{Method, Request, Response, StatusCode}; use hyper_util::rt::TokioIo; use serde::Deserialize; use tokio::net::TcpListener; @@ -28,18 +29,22 @@ struct Record { data: String, } -async fn api_post_response(req: Request) -> Result> { +async fn create_record(req: Request) -> Result> { let whole_body = req.collect().await?.aggregate(); match serde_json::from_reader::<_, Record>(whole_body.reader()) { Ok(record) => { + let rdata = record + ._type + .to_bytes(&record.data) + .map_err(|e| e.to_string())?; match insert_into_database(RR { name: record.name, _type: record._type, class: Class::IN, ttl: record.ttl, - rdlength: record.data.as_bytes().len() as u16, - rdata: record.data.as_bytes().to_vec(), + rdlength: rdata.len() as u16, + rdata, }) .await { @@ -55,30 +60,28 @@ async fn api_post_response(req: Request) -> Result Ok(Response::builder() - .status(StatusCode::FORBIDDEN) + .status(StatusCode::UNPROCESSABLE_ENTITY) .body(full(e.to_string()))?), } } -async fn api_get_response() -> Result> { - let data = vec!["foo", "bar"]; - let res = match serde_json::to_string(&data) { - Ok(json) => Response::builder() - .header(header::CONTENT_TYPE, "application/json") - .body(full(json)) - .unwrap(), - Err(_) => Response::builder() - .status(StatusCode::INTERNAL_SERVER_ERROR) - .body(full(INTERNAL_SERVER_ERROR)) - .unwrap(), - }; - Ok(res) +async fn get_record(req: Request) -> Result> { + if let Some(q) = req.uri().query() { + let params = form_urlencoded::parse(q.as_bytes()).into_owned().collect::>(); + if let Some(domain) = params.get("domain_name") { + return Ok(Response::builder().status(StatusCode::OK).body(full(domain.to_owned()))?) + } + } + + Ok(Response::builder() + .status(StatusCode::UNPROCESSABLE_ENTITY) + .body(full("Missing domain_name query parameter"))?) } async fn routes(req: Request) -> Result> { match (req.method(), req.uri().path()) { - (&Method::POST, "/add") => api_post_response(req).await, - (&Method::GET, "/json_api") => api_get_response().await, + (&Method::POST, "/add") => create_record(req).await, + (&Method::GET, "/get") => get_record(req).await, _ => Ok(Response::builder() .status(StatusCode::NOT_FOUND) .body(full(NOTFOUND)) diff --git a/src/parser.rs b/src/parser.rs index 9c92d4a..e23b030 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -38,6 +38,42 @@ pub trait FromBytes { Self: Sized; } +impl Type { + pub fn to_bytes(&self, text: &String) -> Result> { + match self { + Type::A => { + let arr: Vec = text + .split(".") + .filter_map(|s| s.parse::().ok()) + .collect(); + if arr.len() == 4 { + Ok(arr) + } else { + Err(ParseError { + object: String::from("Type::A"), + message: String::from("Invalid IPv4 address"), + }) + } + } + } + } + pub fn from_bytes(&self, bytes: &[u8]) -> Result { + match self { + Type::A => { + if bytes.len() == 4 { + let arr: Vec = bytes.iter().map(|b| b.to_string()).collect(); + Ok(arr.join(".")) + } else { + Err(ParseError { + object: String::from("Type::A"), + message: String::from("Invalid Ipv4 address bytes"), + }) + } + } + } + } +} + impl FromBytes for Header { fn from_bytes(bytes: &[u8]) -> Result { if bytes.len() != size_of::
() {