From 2184d016ea29eb50cc3f08b9a9702fd046e8f5ea Mon Sep 17 00:00:00 2001 From: Xander Bil Date: Mon, 15 Jul 2024 16:56:53 +0200 Subject: [PATCH] Add TCP support for large queries --- src/handlers/query.rs | 1 - src/handlers/update/authenticate.rs | 1 - src/main.rs | 7 ++-- src/resolver.rs | 50 +++++++++++++++++++++++------ 4 files changed, 45 insertions(+), 14 deletions(-) diff --git a/src/handlers/query.rs b/src/handlers/query.rs index 593f41d..ebe6be7 100644 --- a/src/handlers/query.rs +++ b/src/handlers/query.rs @@ -18,7 +18,6 @@ impl ResponseHandler for QueryHandler { connection: &mut PgConnection, ) -> Result { let mut response = message.clone(); - response.header.arcount = 0; //TODO: fix this, handle unknown class values for question in &message.question { let answers = get_from_database( diff --git a/src/handlers/update/authenticate.rs b/src/handlers/update/authenticate.rs index bc4ae7d..0e17a0b 100644 --- a/src/handlers/update/authenticate.rs +++ b/src/handlers/update/authenticate.rs @@ -22,7 +22,6 @@ pub async fn authenticate( let ssh_verified = validate_ssh(username, sig).await.is_ok_and(|b| b); - if ssh_verified { Ok(true) } else { diff --git a/src/main.rs b/src/main.rs index f4a036f..2ed0936 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,7 +2,7 @@ use std::{error::Error, net::SocketAddr}; use config::Config; -use crate::resolver::resolver_listener_loop; +use crate::resolver::{tcp_listener_loop, udp_listener_loop}; mod config; mod db; @@ -19,6 +19,9 @@ mod utils; async fn main() -> Result<(), Box> { Config::initialize(); let resolver_add = SocketAddr::from(([127, 0, 0, 1], 8080)); - let _ = tokio::join!(resolver_listener_loop(resolver_add),); + let _ = tokio::join!( + udp_listener_loop(resolver_add), + tcp_listener_loop(resolver_add) + ); Ok(()) } diff --git a/src/resolver.rs b/src/resolver.rs index e31f299..3846ff6 100644 --- a/src/resolver.rs +++ b/src/resolver.rs @@ -2,7 +2,8 @@ use std::error::Error; use std::net::SocketAddr; use std::sync::Arc; -use tokio::net::UdpSocket; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::{TcpSocket, UdpSocket}; use crate::db::lib::get_connection; use crate::errors::ZNSError; @@ -11,7 +12,7 @@ use crate::parser::{FromBytes, ToBytes}; use crate::reader::Reader; use crate::structs::{Header, Message, RCODE}; -const MAX_DATAGRAM_SIZE: usize = 4096; +const MAX_DATAGRAM_SIZE: usize = 512; fn handle_parse_error(bytes: &[u8], err: ZNSError) -> Message { eprintln!("{}", err); @@ -41,9 +42,9 @@ fn handle_parse_error(bytes: &[u8], err: ZNSError) -> Message { message } -async fn get_response(bytes: &[u8]) -> Message { +async fn get_response(bytes: &[u8]) -> Vec { let mut reader = Reader::new(bytes); - match Message::from_bytes(&mut reader) { + Message::to_bytes(match Message::from_bytes(&mut reader) { Ok(mut message) => match Handler::handle(&message, bytes, &mut get_connection()).await { Ok(mut response) => { response.set_response(RCODE::NOERROR); @@ -56,10 +57,10 @@ async fn get_response(bytes: &[u8]) -> Message { } }, Err(err) => handle_parse_error(bytes, err), - } + }) } -pub async fn resolver_listener_loop(addr: SocketAddr) -> Result<(), Box> { +pub async fn udp_listener_loop(addr: SocketAddr) -> Result<(), Box> { let socket_shared = Arc::new(UdpSocket::bind(addr).await?); loop { let mut data = vec![0u8; MAX_DATAGRAM_SIZE]; @@ -67,9 +68,34 @@ pub async fn resolver_listener_loop(addr: SocketAddr) -> Result<(), Box Result<(), Box> { + let socket = TcpSocket::new_v4()?; + socket.bind(addr)?; + let listener = socket.listen(1024)?; + loop { + let (mut stream, _) = listener.accept().await?; + tokio::spawn(async move { + if stream.readable().await.is_ok() { + if let Ok(length) = stream.read_u16().await { + let mut buf = Vec::with_capacity(length as usize); + if stream + .try_read_buf(&mut buf) + .is_ok_and(|v| v == length as usize) + { + let response = get_response(&buf).await; + if stream.writable().await.is_ok() { + let _ = stream.write_u16(response.len() as u16).await; + let _ = stream.try_write(&response); + } + } + } + } }); } } @@ -102,7 +128,11 @@ mod tests { }; let response = get_response(&Message::to_bytes(message)).await; + let mut reader = Reader::new(&response); - assert_eq!(response.get_rcode(), Ok(RCODE::NXDOMAIN)); + assert_eq!( + Message::from_bytes(&mut reader).unwrap().get_rcode(), + Ok(RCODE::NXDOMAIN) + ); } }