10
0
Fork 0
mirror of https://github.com/ZeusWPI/ZNS.git synced 2024-10-29 21:14:27 +01:00

Add TCP support for large queries

This commit is contained in:
Xander Bil 2024-07-15 16:56:53 +02:00
parent d8f88dcac5
commit 2184d016ea
No known key found for this signature in database
GPG key ID: EC9706B54A278598
4 changed files with 45 additions and 14 deletions

View file

@ -18,7 +18,6 @@ impl ResponseHandler for QueryHandler {
connection: &mut PgConnection,
) -> Result<Message, ZNSError> {
let mut response = message.clone();
response.header.arcount = 0; //TODO: fix this, handle unknown class values
for question in &message.question {
let answers = get_from_database(

View file

@ -22,7 +22,6 @@ pub async fn authenticate(
let ssh_verified = validate_ssh(username, sig).await.is_ok_and(|b| b);
if ssh_verified {
Ok(true)
} else {

View file

@ -2,7 +2,7 @@ use std::{error::Error, net::SocketAddr};
use config::Config;
use crate::resolver::resolver_listener_loop;
use crate::resolver::{tcp_listener_loop, udp_listener_loop};
mod config;
mod db;
@ -19,6 +19,9 @@ mod utils;
async fn main() -> Result<(), Box<dyn Error>> {
Config::initialize();
let resolver_add = SocketAddr::from(([127, 0, 0, 1], 8080));
let _ = tokio::join!(resolver_listener_loop(resolver_add),);
let _ = tokio::join!(
udp_listener_loop(resolver_add),
tcp_listener_loop(resolver_add)
);
Ok(())
}

View file

@ -2,7 +2,8 @@ use std::error::Error;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::net::UdpSocket;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpSocket, UdpSocket};
use crate::db::lib::get_connection;
use crate::errors::ZNSError;
@ -11,7 +12,7 @@ use crate::parser::{FromBytes, ToBytes};
use crate::reader::Reader;
use crate::structs::{Header, Message, RCODE};
const MAX_DATAGRAM_SIZE: usize = 4096;
const MAX_DATAGRAM_SIZE: usize = 512;
fn handle_parse_error(bytes: &[u8], err: ZNSError) -> Message {
eprintln!("{}", err);
@ -41,9 +42,9 @@ fn handle_parse_error(bytes: &[u8], err: ZNSError) -> Message {
message
}
async fn get_response(bytes: &[u8]) -> Message {
async fn get_response(bytes: &[u8]) -> Vec<u8> {
let mut reader = Reader::new(bytes);
match Message::from_bytes(&mut reader) {
Message::to_bytes(match Message::from_bytes(&mut reader) {
Ok(mut message) => match Handler::handle(&message, bytes, &mut get_connection()).await {
Ok(mut response) => {
response.set_response(RCODE::NOERROR);
@ -56,10 +57,10 @@ async fn get_response(bytes: &[u8]) -> Message {
}
},
Err(err) => handle_parse_error(bytes, err),
}
})
}
pub async fn resolver_listener_loop(addr: SocketAddr) -> Result<(), Box<dyn Error>> {
pub async fn udp_listener_loop(addr: SocketAddr) -> Result<(), Box<dyn Error>> {
let socket_shared = Arc::new(UdpSocket::bind(addr).await?);
loop {
let mut data = vec![0u8; MAX_DATAGRAM_SIZE];
@ -67,9 +68,34 @@ pub async fn resolver_listener_loop(addr: SocketAddr) -> Result<(), Box<dyn Erro
let socket = socket_shared.clone();
tokio::spawn(async move {
let response = get_response(&data[..len]).await;
let _ = socket
.send_to(Message::to_bytes(response).as_slice(), addr)
.await;
// TODO: if length is larger then 512 bytes, message should be truncated
let _ = socket.send_to(&response, addr).await;
});
}
}
pub async fn tcp_listener_loop(addr: SocketAddr) -> Result<(), Box<dyn Error>> {
let socket = TcpSocket::new_v4()?;
socket.bind(addr)?;
let listener = socket.listen(1024)?;
loop {
let (mut stream, _) = listener.accept().await?;
tokio::spawn(async move {
if stream.readable().await.is_ok() {
if let Ok(length) = stream.read_u16().await {
let mut buf = Vec::with_capacity(length as usize);
if stream
.try_read_buf(&mut buf)
.is_ok_and(|v| v == length as usize)
{
let response = get_response(&buf).await;
if stream.writable().await.is_ok() {
let _ = stream.write_u16(response.len() as u16).await;
let _ = stream.try_write(&response);
}
}
}
}
});
}
}
@ -102,7 +128,11 @@ mod tests {
};
let response = get_response(&Message::to_bytes(message)).await;
let mut reader = Reader::new(&response);
assert_eq!(response.get_rcode(), Ok(RCODE::NXDOMAIN));
assert_eq!(
Message::from_bytes(&mut reader).unwrap().get_rcode(),
Ok(RCODE::NXDOMAIN)
);
}
}