mirror of
https://github.com/ZeusWPI/ZNS.git
synced 2024-11-27 22:51:12 +01:00
Add TCP support for large queries
This commit is contained in:
parent
d8f88dcac5
commit
2184d016ea
4 changed files with 45 additions and 14 deletions
|
@ -18,7 +18,6 @@ impl ResponseHandler for QueryHandler {
|
|||
connection: &mut PgConnection,
|
||||
) -> Result<Message, ZNSError> {
|
||||
let mut response = message.clone();
|
||||
response.header.arcount = 0; //TODO: fix this, handle unknown class values
|
||||
|
||||
for question in &message.question {
|
||||
let answers = get_from_database(
|
||||
|
|
|
@ -22,7 +22,6 @@ pub async fn authenticate(
|
|||
|
||||
let ssh_verified = validate_ssh(username, sig).await.is_ok_and(|b| b);
|
||||
|
||||
|
||||
if ssh_verified {
|
||||
Ok(true)
|
||||
} else {
|
||||
|
|
|
@ -2,7 +2,7 @@ use std::{error::Error, net::SocketAddr};
|
|||
|
||||
use config::Config;
|
||||
|
||||
use crate::resolver::resolver_listener_loop;
|
||||
use crate::resolver::{tcp_listener_loop, udp_listener_loop};
|
||||
|
||||
mod config;
|
||||
mod db;
|
||||
|
@ -19,6 +19,9 @@ mod utils;
|
|||
async fn main() -> Result<(), Box<dyn Error>> {
|
||||
Config::initialize();
|
||||
let resolver_add = SocketAddr::from(([127, 0, 0, 1], 8080));
|
||||
let _ = tokio::join!(resolver_listener_loop(resolver_add),);
|
||||
let _ = tokio::join!(
|
||||
udp_listener_loop(resolver_add),
|
||||
tcp_listener_loop(resolver_add)
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
@ -2,7 +2,8 @@ use std::error::Error;
|
|||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use tokio::net::UdpSocket;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::{TcpSocket, UdpSocket};
|
||||
|
||||
use crate::db::lib::get_connection;
|
||||
use crate::errors::ZNSError;
|
||||
|
@ -11,7 +12,7 @@ use crate::parser::{FromBytes, ToBytes};
|
|||
use crate::reader::Reader;
|
||||
use crate::structs::{Header, Message, RCODE};
|
||||
|
||||
const MAX_DATAGRAM_SIZE: usize = 4096;
|
||||
const MAX_DATAGRAM_SIZE: usize = 512;
|
||||
|
||||
fn handle_parse_error(bytes: &[u8], err: ZNSError) -> Message {
|
||||
eprintln!("{}", err);
|
||||
|
@ -41,9 +42,9 @@ fn handle_parse_error(bytes: &[u8], err: ZNSError) -> Message {
|
|||
message
|
||||
}
|
||||
|
||||
async fn get_response(bytes: &[u8]) -> Message {
|
||||
async fn get_response(bytes: &[u8]) -> Vec<u8> {
|
||||
let mut reader = Reader::new(bytes);
|
||||
match Message::from_bytes(&mut reader) {
|
||||
Message::to_bytes(match Message::from_bytes(&mut reader) {
|
||||
Ok(mut message) => match Handler::handle(&message, bytes, &mut get_connection()).await {
|
||||
Ok(mut response) => {
|
||||
response.set_response(RCODE::NOERROR);
|
||||
|
@ -56,10 +57,10 @@ async fn get_response(bytes: &[u8]) -> Message {
|
|||
}
|
||||
},
|
||||
Err(err) => handle_parse_error(bytes, err),
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn resolver_listener_loop(addr: SocketAddr) -> Result<(), Box<dyn Error>> {
|
||||
pub async fn udp_listener_loop(addr: SocketAddr) -> Result<(), Box<dyn Error>> {
|
||||
let socket_shared = Arc::new(UdpSocket::bind(addr).await?);
|
||||
loop {
|
||||
let mut data = vec![0u8; MAX_DATAGRAM_SIZE];
|
||||
|
@ -67,9 +68,34 @@ pub async fn resolver_listener_loop(addr: SocketAddr) -> Result<(), Box<dyn Erro
|
|||
let socket = socket_shared.clone();
|
||||
tokio::spawn(async move {
|
||||
let response = get_response(&data[..len]).await;
|
||||
let _ = socket
|
||||
.send_to(Message::to_bytes(response).as_slice(), addr)
|
||||
.await;
|
||||
// TODO: if length is larger then 512 bytes, message should be truncated
|
||||
let _ = socket.send_to(&response, addr).await;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn tcp_listener_loop(addr: SocketAddr) -> Result<(), Box<dyn Error>> {
|
||||
let socket = TcpSocket::new_v4()?;
|
||||
socket.bind(addr)?;
|
||||
let listener = socket.listen(1024)?;
|
||||
loop {
|
||||
let (mut stream, _) = listener.accept().await?;
|
||||
tokio::spawn(async move {
|
||||
if stream.readable().await.is_ok() {
|
||||
if let Ok(length) = stream.read_u16().await {
|
||||
let mut buf = Vec::with_capacity(length as usize);
|
||||
if stream
|
||||
.try_read_buf(&mut buf)
|
||||
.is_ok_and(|v| v == length as usize)
|
||||
{
|
||||
let response = get_response(&buf).await;
|
||||
if stream.writable().await.is_ok() {
|
||||
let _ = stream.write_u16(response.len() as u16).await;
|
||||
let _ = stream.try_write(&response);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
@ -102,7 +128,11 @@ mod tests {
|
|||
};
|
||||
|
||||
let response = get_response(&Message::to_bytes(message)).await;
|
||||
let mut reader = Reader::new(&response);
|
||||
|
||||
assert_eq!(response.get_rcode(), Ok(RCODE::NXDOMAIN));
|
||||
assert_eq!(
|
||||
Message::from_bytes(&mut reader).unwrap().get_rcode(),
|
||||
Ok(RCODE::NXDOMAIN)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue