mirror of
https://github.com/ZeusWPI/ZNS.git
synced 2024-11-23 22:11:10 +01:00
Add TCP support for large queries
This commit is contained in:
parent
d8f88dcac5
commit
2184d016ea
4 changed files with 45 additions and 14 deletions
|
@ -18,7 +18,6 @@ impl ResponseHandler for QueryHandler {
|
||||||
connection: &mut PgConnection,
|
connection: &mut PgConnection,
|
||||||
) -> Result<Message, ZNSError> {
|
) -> Result<Message, ZNSError> {
|
||||||
let mut response = message.clone();
|
let mut response = message.clone();
|
||||||
response.header.arcount = 0; //TODO: fix this, handle unknown class values
|
|
||||||
|
|
||||||
for question in &message.question {
|
for question in &message.question {
|
||||||
let answers = get_from_database(
|
let answers = get_from_database(
|
||||||
|
|
|
@ -22,7 +22,6 @@ pub async fn authenticate(
|
||||||
|
|
||||||
let ssh_verified = validate_ssh(username, sig).await.is_ok_and(|b| b);
|
let ssh_verified = validate_ssh(username, sig).await.is_ok_and(|b| b);
|
||||||
|
|
||||||
|
|
||||||
if ssh_verified {
|
if ssh_verified {
|
||||||
Ok(true)
|
Ok(true)
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -2,7 +2,7 @@ use std::{error::Error, net::SocketAddr};
|
||||||
|
|
||||||
use config::Config;
|
use config::Config;
|
||||||
|
|
||||||
use crate::resolver::resolver_listener_loop;
|
use crate::resolver::{tcp_listener_loop, udp_listener_loop};
|
||||||
|
|
||||||
mod config;
|
mod config;
|
||||||
mod db;
|
mod db;
|
||||||
|
@ -19,6 +19,9 @@ mod utils;
|
||||||
async fn main() -> Result<(), Box<dyn Error>> {
|
async fn main() -> Result<(), Box<dyn Error>> {
|
||||||
Config::initialize();
|
Config::initialize();
|
||||||
let resolver_add = SocketAddr::from(([127, 0, 0, 1], 8080));
|
let resolver_add = SocketAddr::from(([127, 0, 0, 1], 8080));
|
||||||
let _ = tokio::join!(resolver_listener_loop(resolver_add),);
|
let _ = tokio::join!(
|
||||||
|
udp_listener_loop(resolver_add),
|
||||||
|
tcp_listener_loop(resolver_add)
|
||||||
|
);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,7 +2,8 @@ use std::error::Error;
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use tokio::net::UdpSocket;
|
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||||
|
use tokio::net::{TcpSocket, UdpSocket};
|
||||||
|
|
||||||
use crate::db::lib::get_connection;
|
use crate::db::lib::get_connection;
|
||||||
use crate::errors::ZNSError;
|
use crate::errors::ZNSError;
|
||||||
|
@ -11,7 +12,7 @@ use crate::parser::{FromBytes, ToBytes};
|
||||||
use crate::reader::Reader;
|
use crate::reader::Reader;
|
||||||
use crate::structs::{Header, Message, RCODE};
|
use crate::structs::{Header, Message, RCODE};
|
||||||
|
|
||||||
const MAX_DATAGRAM_SIZE: usize = 4096;
|
const MAX_DATAGRAM_SIZE: usize = 512;
|
||||||
|
|
||||||
fn handle_parse_error(bytes: &[u8], err: ZNSError) -> Message {
|
fn handle_parse_error(bytes: &[u8], err: ZNSError) -> Message {
|
||||||
eprintln!("{}", err);
|
eprintln!("{}", err);
|
||||||
|
@ -41,9 +42,9 @@ fn handle_parse_error(bytes: &[u8], err: ZNSError) -> Message {
|
||||||
message
|
message
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_response(bytes: &[u8]) -> Message {
|
async fn get_response(bytes: &[u8]) -> Vec<u8> {
|
||||||
let mut reader = Reader::new(bytes);
|
let mut reader = Reader::new(bytes);
|
||||||
match Message::from_bytes(&mut reader) {
|
Message::to_bytes(match Message::from_bytes(&mut reader) {
|
||||||
Ok(mut message) => match Handler::handle(&message, bytes, &mut get_connection()).await {
|
Ok(mut message) => match Handler::handle(&message, bytes, &mut get_connection()).await {
|
||||||
Ok(mut response) => {
|
Ok(mut response) => {
|
||||||
response.set_response(RCODE::NOERROR);
|
response.set_response(RCODE::NOERROR);
|
||||||
|
@ -56,10 +57,10 @@ async fn get_response(bytes: &[u8]) -> Message {
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
Err(err) => handle_parse_error(bytes, err),
|
Err(err) => handle_parse_error(bytes, err),
|
||||||
}
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn resolver_listener_loop(addr: SocketAddr) -> Result<(), Box<dyn Error>> {
|
pub async fn udp_listener_loop(addr: SocketAddr) -> Result<(), Box<dyn Error>> {
|
||||||
let socket_shared = Arc::new(UdpSocket::bind(addr).await?);
|
let socket_shared = Arc::new(UdpSocket::bind(addr).await?);
|
||||||
loop {
|
loop {
|
||||||
let mut data = vec![0u8; MAX_DATAGRAM_SIZE];
|
let mut data = vec![0u8; MAX_DATAGRAM_SIZE];
|
||||||
|
@ -67,9 +68,34 @@ pub async fn resolver_listener_loop(addr: SocketAddr) -> Result<(), Box<dyn Erro
|
||||||
let socket = socket_shared.clone();
|
let socket = socket_shared.clone();
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
let response = get_response(&data[..len]).await;
|
let response = get_response(&data[..len]).await;
|
||||||
let _ = socket
|
// TODO: if length is larger then 512 bytes, message should be truncated
|
||||||
.send_to(Message::to_bytes(response).as_slice(), addr)
|
let _ = socket.send_to(&response, addr).await;
|
||||||
.await;
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn tcp_listener_loop(addr: SocketAddr) -> Result<(), Box<dyn Error>> {
|
||||||
|
let socket = TcpSocket::new_v4()?;
|
||||||
|
socket.bind(addr)?;
|
||||||
|
let listener = socket.listen(1024)?;
|
||||||
|
loop {
|
||||||
|
let (mut stream, _) = listener.accept().await?;
|
||||||
|
tokio::spawn(async move {
|
||||||
|
if stream.readable().await.is_ok() {
|
||||||
|
if let Ok(length) = stream.read_u16().await {
|
||||||
|
let mut buf = Vec::with_capacity(length as usize);
|
||||||
|
if stream
|
||||||
|
.try_read_buf(&mut buf)
|
||||||
|
.is_ok_and(|v| v == length as usize)
|
||||||
|
{
|
||||||
|
let response = get_response(&buf).await;
|
||||||
|
if stream.writable().await.is_ok() {
|
||||||
|
let _ = stream.write_u16(response.len() as u16).await;
|
||||||
|
let _ = stream.try_write(&response);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -102,7 +128,11 @@ mod tests {
|
||||||
};
|
};
|
||||||
|
|
||||||
let response = get_response(&Message::to_bytes(message)).await;
|
let response = get_response(&Message::to_bytes(message)).await;
|
||||||
|
let mut reader = Reader::new(&response);
|
||||||
|
|
||||||
assert_eq!(response.get_rcode(), Ok(RCODE::NXDOMAIN));
|
assert_eq!(
|
||||||
|
Message::from_bytes(&mut reader).unwrap().get_rcode(),
|
||||||
|
Ok(RCODE::NXDOMAIN)
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue