1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
use concread::cowcell::asynch::CowCellReadTxn;
use packed_struct::prelude::*;
use std::io::Error;
use std::net::SocketAddr;
use std::str::{from_utf8, FromStr};
use std::time::Duration;
use tokio::io::AsyncWriteExt;
use tokio::io::{self, AsyncReadExt};
use tokio::net::{TcpListener, TcpStream, UdpSocket};
use tokio::sync::{broadcast, mpsc, oneshot};
use tokio::task::JoinHandle;
use tokio::time::timeout;

use crate::config::ConfigFile;
use crate::datastore::Command;
use crate::enums::{Agent, AgentState, PacketType, Rcode, RecordClass, RecordType};
use crate::reply::{reply_any, reply_builder, reply_nxdomain, Reply};
use crate::resourcerecord::{DNSCharString, InternalResourceRecord};
use crate::zones::ZoneRecord;
use crate::{Header, OpCode, Question, HEADER_BYTES, REPLY_TIMEOUT_MS, UDP_BUFFER_SIZE};

lazy_static! {
    // The IPv4 loopback address, parsed from its textual form.
    // NOTE(review): nothing in the visible part of this file reads LOCALHOST any more (shutdown
    // permission is decided from the configured allow-list) - confirm before removing it.
    static ref LOCALHOST: std::net::IpAddr = std::net::IpAddr::from_str("127.0.0.1").expect("Failed to parse localhost IP address");
    // Names that the disabled CHAOS VERSION handler further down would have matched against.
    // static ref VERSION_STRINGS: Vec<String> =
        // vec![String::from("version"), String::from("version.bind"),];
}

/// Decides what to do with a CHAOS-class `shutdown` request.
///
/// * `Ok(reply)` - the query was CHAOS `shutdown` and `allowed_shutdown` is true. `reply` is a
///   copy of `r` with the `CHAOS_OK` record appended; the caller is expected to start shutting
///   down.
/// * `Err(Some(reply))` - the query was CHAOS `shutdown` but the client is not allowed to ask
///   for it. `reply` is a copy of `r` with `CHAOS_NO` appended and the rcode set to `Refused`.
/// * `Err(None)` - anything else (no question, a different class, a different name, or a name
///   that is not valid UTF-8). The caller should send `r` unchanged.
///
/// This never panics: a qname that cannot be decoded is logged and treated as "not a shutdown
/// request", because the name comes straight from the client.
async fn check_for_shutdown(r: &Reply, allowed_shutdown: bool) -> Result<Reply, Option<Reply>> {
    // only CHAOS-class questions can carry the shutdown command
    let q = match &r.question {
        Some(q) if q.qclass == RecordClass::Chaos => q,
        _ => return Err(None),
    };

    let qname = match from_utf8(&q.qname) {
        Ok(value) => value,
        Err(e) => {
            log::error!(
                "Failed to parse qname from {:?}, this shouldn't be able to happen! {e:?}",
                q.qname
            );
            return Err(None);
        }
    };

    // Just don't do this on UDP, because we can't really tell who it's coming from.
    if qname != "shutdown" {
        return Err(None);
    }

    // when we get a request, we update the response to say if we're going to do it or not
    let mut chaos_reply = r.clone();
    if allowed_shutdown {
        log::info!("Got CHAOS shutdown, shutting down");
        chaos_reply.answers.push(CHAOS_OK.clone());
        Ok(chaos_reply)
    } else {
        // get lost!  🤣
        log::warn!("Got CHAOS shutdown, ignoring!");
        chaos_reply.answers.push(CHAOS_NO.clone());
        chaos_reply.header.rcode = Rcode::Refused;
        Err(Some(chaos_reply))
    }
}

// (disabled) this handles a version CHAOS request - kept as a plain comment so that it is not picked up as rustdoc for the next item
// async fn check_for_version(r: &Reply, addr: &SocketAddr, config: &ConfigFile) -> Result<(), ()> {
//     // when you get a CHAOS from localhost with "VERSION" or "VERSION.BIND" we might respond
//     if let Some(q) = &r.question {
//         if q.qclass == RecordClass::Chaos {
//             if let Ok(qname) = from_utf8(&q.qname) {
//                 if VERSION_STRINGS.contains(&qname.to_ascii_lowercase()) & (config.ip_allow_lists.shutdown.contains(&addr.ip())) {
//                     info!("Got CHAOS VERSION from {:?}, responding.", addr.ip());
//                     return Ok(());
//                 } else {
//                     warn!("Got CHAOS VERSION from {:?}, ignoring!", addr.ip());
//                 }
//             } else {
//                 error!("Failed to parse qname from {:?}, this shouldn't be able to happen!", q.qname);
//             }
//         }
//     };
//     Err(())
// }

pub async fn udp_server(
    config: CowCellReadTxn<ConfigFile>,
    datastore_sender: mpsc::Sender<crate::datastore::Command>,
    _agent_tx: broadcast::Sender<AgentState>,
) -> io::Result<()> {
    let udp_sock = match UdpSocket::bind(
        config
            .dns_listener_address()
            .expect("Failed to get DNS listener address on startup!"),
    )
    .await
    {
        Ok(value) => {
            log::info!("Started UDP listener on {}:{}", config.address, config.port);
            value
        }
        Err(error) => {
            log::error!("Failed to start UDP listener: {:?}", error);
            return Ok(());
        }
    };

    // TODO: this needs to be bigger to handle edns0-negotiated queries
    let mut udp_buffer = [0; UDP_BUFFER_SIZE];

    loop {
        let (len, addr) = match udp_sock.recv_from(&mut udp_buffer).await {
            Ok(value) => value,
            Err(error) => {
                log::error!("Error accepting connection via UDP: {:?}", error);
                continue;
            }
        };

        log::debug!("{:?} bytes received from {:?}", len, addr);

        let udp_result = match timeout(
            Duration::from_millis(REPLY_TIMEOUT_MS),
            parse_query(
                datastore_sender.clone(),
                len,
                &udp_buffer,
                config.capture_packets,
            ),
        )
        .await
        {
            Ok(reply) => reply,
            Err(_) => {
                log::error!("Did not receive response from parse_query within 10 ms");
                continue;
            }
        };

        match udp_result {
            Ok(mut r) => {
                log::debug!("Result: {:?}", r);

                let reply_bytes: Vec<u8> = match r.as_bytes().await {
                    Ok(value) => {
                        // Check if it's too long and set truncate flag if so, it's safe to unwrap since we've already gone
                        if value.len() > UDP_BUFFER_SIZE {
                            let mut response_bytes = value.to_vec();
                            response_bytes.truncate(UDP_BUFFER_SIZE);
                            r = r.check_set_truncated().await;
                            let r = r.as_bytes_udp().await;
                            r.unwrap_or(value)
                        } else {
                            value
                        }
                    }
                    Err(error) => {
                        log::error!("Failed to parse reply {:?} into bytes: {:?}", r, error);
                        continue;
                    }
                };

                log::trace!("reply_bytes: {:?}", reply_bytes);
                let len = match udp_sock.send_to(&reply_bytes as &[u8], addr).await {
                    Ok(value) => value,
                    Err(err) => {
                        log::error!("Failed to send data back to {:?}: {:?}", addr, err);
                        return Ok(());
                    }
                };
                // let len = sock.send_to(r.answer.as_bytes(), addr).await?;
                log::trace!("{:?} bytes sent", len);
            }
            Err(error) => log::error!("Error: {}", error),
        }
    }
}

pub async fn tcp_conn_handler(
    stream: &mut TcpStream,
    addr: SocketAddr,
    datastore_sender: mpsc::Sender<Command>,
    agent_tx: broadcast::Sender<AgentState>,
    capture_packets: bool,
    allowed_shutdown: bool,
) -> io::Result<()> {
    let (mut reader, writer) = stream.split();
    let msg_length: usize = reader.read_u16().await?.into();
    log::debug!("msg_length={msg_length}");
    let mut buf: Vec<u8> = vec![];

    while buf.len() < msg_length {
        let len = match reader.read_buf(&mut buf).await {
            Ok(size) => size,
            Err(error) => {
                log::error!("Failed to read from TCP Stream: {:?}", error);
                return Ok(());
            }
        };
        if len > 0 {
            log::debug!("Read {:?} bytes from TCP stream", len);
        }
    }

    crate::utils::hexdump(buf.clone());
    // the first two bytes of a tcp query is the message length
    // ref <https://www.rfc-editor.org/rfc/rfc7766#section-8>

    // check the message is long enough
    if buf.len() < msg_length {
        log::warn!(
            "Message length too short {}, wanted {}",
            buf.len(),
            msg_length + 2
        );
    } else {
        log::info!("TCP Message length ftw!");
    }

    // skip the TCP length header because rad
    let buf = &buf[0..msg_length];
    let result = match timeout(
        Duration::from_millis(REPLY_TIMEOUT_MS),
        parse_query(datastore_sender.clone(), msg_length, buf, capture_packets),
    )
    .await
    {
        Ok(reply) => reply,
        Err(_) => {
            log::error!("Did not receive response from parse_query within {REPLY_TIMEOUT_MS} ms");
            return Ok(());
        }
    };

    match result {
        Ok(r) => {
            log::debug!("TCP Result: {r:?}");

            // when you get a CHAOS from the allow-list with "shutdown" it's quitting time
            let r = match check_for_shutdown(&r, allowed_shutdown).await {
                // no change here
                Err(reply) => match reply {
                    None => r,
                    Some(response) => response,
                },
                Ok(reply) => {
                    if let Err(error) = agent_tx.send(AgentState::Stopped {
                        agent: Agent::TCPServer,
                    }) {
                        eprintln!("Failed to send UDPServer shutdown message: {error:?}");
                    };
                    if let Err(error) = datastore_sender.send(Command::Shutdown).await {
                        eprintln!("Failed to send shutdown command to datastore.. {error:?}");
                    };
                    reply
                }
            };

            let reply_bytes: Vec<u8> = match r.as_bytes().await {
                Ok(value) => value,
                Err(error) => {
                    log::error!("Failed to parse reply {:?} into bytes: {:?}", r, error);
                    return Ok(());
                }
            };

            log::trace!("reply_bytes: {:?}", reply_bytes);

            let reply_bytes = &reply_bytes as &[u8];
            // send the outgoing message length
            let response_length: u16 = reply_bytes.len() as u16;
            let len = match writer.try_write(&response_length.to_be_bytes()) {
                Ok(value) => value,
                Err(err) => {
                    log::error!("Failed to send data back to {:?}: {:?}", addr, err);
                    return Ok(());
                }
            };
            log::trace!("{:?} bytes sent", len);

            // send the data
            let len = match writer.try_write(reply_bytes) {
                Ok(value) => value,
                Err(err) => {
                    log::error!("Failed to send data back to {:?}: {:?}", addr, err);
                    return Ok(());
                }
            };
            log::trace!("{:?} bytes sent", len);
        }
        Err(error) => log::error!("Error: {}", error),
    }
    Ok(())
}

/// main handler for the TCP side of things
///
/// Ref <https://www.rfc-editor.org/rfc/rfc7766>
pub async fn tcp_server(
    config: CowCellReadTxn<ConfigFile>,
    tx: mpsc::Sender<crate::datastore::Command>,
    agent_tx: broadcast::Sender<AgentState>,
    // mut agent_rx: broadcast::Receiver<AgentState>,
) -> io::Result<()> {
    let mut agent_rx = agent_tx.subscribe();
    let tcpserver = match TcpListener::bind(
        config
            .dns_listener_address()
            .expect("Failed to get DNS listener address on startup!"),
    )
    .await
    {
        Ok(value) => {
            log::info!(
                "Started TCP listener on {}",
                config
                    .dns_listener_address()
                    .expect("Failed to get DNS listener address on startup!")
            );
            value
        }
        Err(error) => {
            log::error!("Failed to start TCP Server: {:?}", error);
            return Ok(());
        }
    };

    let tcp_client_timeout = config.tcp_client_timeout;
    let shutdown_ip_address_list = config.ip_allow_lists.shutdown.to_vec();
    let capture_packets = config.capture_packets;
    loop {
        let (mut stream, addr) = match tcpserver.accept().await {
            Ok(value) => value,
            Err(error) => panic!("Couldn't get data from TcpStream: {:?}", error),
        };

        let allowed_shutdown = shutdown_ip_address_list.contains(&addr.ip());
        log::debug!("TCP connection from {:?}", addr);
        let loop_tx = tx.clone();
        let loop_agent_tx = agent_tx.clone();
        tokio::spawn(async move {
            if timeout(
                Duration::from_secs(tcp_client_timeout),
                tcp_conn_handler(
                    &mut stream,
                    addr,
                    loop_tx,
                    loop_agent_tx,
                    capture_packets,
                    allowed_shutdown,
                ),
            )
            .await
            .is_err()
            {
                log::warn!(
                    "TCP Connection from {addr:?} terminated after {} seconds.",
                    tcp_client_timeout
                );
            }
        })
        .await?;

        if let Ok(agent_state) = agent_rx.try_recv() {
            log::info!("Got agent state: {:?}", agent_state);
        };
    }
}

/// Unpacks the DNS header from a received message and hands the message to `get_result`.
///
/// * `datastore` - channel used to look records up.
/// * `len` - number of valid bytes at the start of `buf`.
/// * `buf` - the receive buffer; it may be longer than `len`.
/// * `capture_packets` - when true, the first `len` bytes are passed to the packet dumper as a
///   client request before parsing.
///
/// # Errors
///
/// Returns `Err` when `len` is larger than `buf`, when the message is shorter than
/// `HEADER_BYTES`, or when the header cannot be unpacked. In those cases no reply is built, so
/// nothing should be sent back. Errors from `get_result` are passed through.
pub async fn parse_query(
    datastore: tokio::sync::mpsc::Sender<crate::datastore::Command>,
    len: usize,
    buf: &[u8],
    capture_packets: bool,
) -> Result<Reply, String> {
    // Both lengths come from the network, so validate them once up front instead of slicing
    // blindly and panicking on a short or inconsistent message.
    let packet: &[u8] = match buf.get(..len) {
        Some(value) if value.len() >= HEADER_BYTES => value,
        _ => {
            return Err(format!(
                "Invalid packet length {} (buffer holds {} bytes, header needs {})",
                len,
                buf.len(),
                HEADER_BYTES
            ));
        }
    };

    if capture_packets {
        crate::packet_dumper::dump_bytes(
            packet.into(),
            crate::packet_dumper::DumpType::ClientRequest,
        )
        .await;
    }
    // we only want the first 12 bytes for the header
    let mut split_header: [u8; HEADER_BYTES] = [0; HEADER_BYTES];
    split_header.copy_from_slice(&packet[0..HEADER_BYTES]);
    // unpack the header for great justice
    let header = match crate::Header::unpack(&split_header) {
        Ok(value) => value,
        Err(error) => {
            // can't return a servfail if we can't unpack the header, they're probably doing something bad.
            return Err(format!("Failed to parse header: {:?}", error));
        }
    };
    log::trace!("Buffer length: {}", len);
    log::trace!("Parsed header: {:?}", header);
    get_result(header, len, packet, datastore).await
}

lazy_static! {
    // Answer record appended to the reply when a CHAOS shutdown request is accepted:
    // a CHAOS-class TXT record with the text "OK" and a zero TTL.
    static ref CHAOS_OK: InternalResourceRecord = InternalResourceRecord::TXT {
        txtdata: DNSCharString::from("OK"),
        ttl: 0,
        class: RecordClass::Chaos,
    };
    // Answer record appended to the reply when a CHAOS shutdown request is refused:
    // a CHAOS-class TXT record with the text "NO" and a zero TTL.
    static ref CHAOS_NO: InternalResourceRecord = InternalResourceRecord::TXT {
        txtdata: DNSCharString::from("NO"),
        ttl: 0,
        class: RecordClass::Chaos,
    };
}

/// The generic handler for the packets once they've been pulled out of their protocol handlers. TCP has a slightly different stream format to UDP, y'know?
async fn get_result(
    header: Header,
    len: usize,
    buf: &[u8],
    datastore: mpsc::Sender<crate::datastore::Command>,
) -> Result<Reply, String> {
    log::trace!("called get_result(header={header}, len={len})");

    // if we get something other than a query, yeah nah.
    if header.opcode != OpCode::Query {
        return Err(format!("Invalid OPCODE, got {:?}", header.opcode));
    };

    let question = match Question::from_packets(&buf[HEADER_BYTES..len]) {
        Ok(value) => {
            log::trace!("Parsed question: {:?}", value);
            value
        }
        Err(error) => {
            log::debug!("Failed to parse question: {} id={}", error, header.id);
            return reply_builder(header.id, Rcode::ServFail);
        }
    };

    // yeet them when we get a request we can't handle
    if !question.qtype.supported() {
        log::debug!(
            "Unsupported request: {} {:?}, returning NotImplemented",
            from_utf8(&question.qname).unwrap_or("<unable to parse>"),
            question.qtype,
        );
        return reply_builder(header.id, Rcode::NotImplemented);
    }

    // Check for CHAOS commands
    #[allow(clippy::collapsible_if)]
    if question.qclass == RecordClass::Chaos {
        if &question.normalized_name()? == "shutdown" {
            log::debug!("Got CHAOS shutdown!");
            return Ok(Reply {
                header,
                question: Some(question),
                answers: vec![],
                authorities: vec![],
                additional: vec![],
            });
        }
    }

    if let RecordType::ANY {} = question.qtype {
        // TODO this should check to see if we have a zone record, but that requires walking down the qname record recursively, which is its own thing. We just YOLO a HINFO back for any request now.
        return reply_any(header.id, question);
    };

    // build the request to the datastore to make the query
    let (tx_oneshot, rx_oneshot) = oneshot::channel();
    let ds_req: Command = Command::GetRecord {
        name: question.qname.clone(),
        rrtype: question.qtype,
        rclass: question.qclass,
        resp: tx_oneshot,
    };

    // here we talk to the datastore to pull the result
    match datastore.send(ds_req).await {
        Ok(_) => log::trace!("Sent a request to the datastore!"),
        // TODO: handle errors sending to the DS properly
        Err(error) => log::error!("Error sending to datastore: {:?}", error),
    };

    let record: ZoneRecord = match rx_oneshot.await {
        Ok(value) => match value {
            Some(zr) => {
                log::debug!("DS Response: {}", zr);
                zr
            }
            None => {
                log::debug!("No response from datastore");
                return reply_nxdomain(header.id);
            }
        },
        Err(error) => {
            log::error!("Failed to get response from datastore: {:?}", error);
            return reply_builder(header.id, Rcode::ServFail);
        }
    };

    // this is our reply - static until that bit's done
    Ok(Reply {
        header: Header {
            id: header.id,
            qr: PacketType::Answer,
            opcode: header.opcode,
            authoritative: true,
            truncated: false, // TODO: work out if it's truncated (ie, UDP)
            recursion_desired: header.recursion_desired,
            recursion_available: header.recursion_available, // TODO: work this out
            z: false,
            ad: true, // TODO: decide how the ad flag should be set -  "authentic data" - This requests the server to return whether all of the answer and
            // authority sections have all been validated as secure according to the security policy of the server. AD=1 indicates that all
            // records have been validated as secure and the answer is not from a OPT-OUT range. AD=0 indicate that some part of the answer
            // was insecure or not validated. This bit is set by default.
            cd: false, // TODO: figure this out -  CD (checking disabled) bit in the query. This requests the server to not perform DNSSEC validation of responses.
            rcode: Rcode::NoError,
            qdcount: 1,
            ancount: record.typerecords.len() as u16,
            nscount: 0,
            arcount: 0,
        },
        question: Some(question),
        answers: record.typerecords,
        authorities: vec![], // TODO: we're authoritative, we should respond with our records!
        additional: vec![],
    })
}

/// Join handles for the server tasks, plus the broadcast channel on which their state changes
/// are announced. A handle is `None` until it has been registered with the matching `with_*`
/// method.
#[derive(Debug)]
pub struct Servers {
    /// Handle for the datastore task; its task result carries a `String` error.
    pub datastore: Option<JoinHandle<Result<(), String>>>,
    /// Handle for the UDP server task.
    pub udpserver: Option<JoinHandle<Result<(), Error>>>,
    /// Handle for the TCP server task.
    pub tcpserver: Option<JoinHandle<Result<(), Error>>>,
    /// Handle for the API server task.
    pub apiserver: Option<JoinHandle<Result<(), Error>>>,
    /// Sending half of the broadcast channel used to publish `AgentState` updates.
    pub agent_tx: broadcast::Sender<AgentState>,
}

impl Default for Servers {
    fn default() -> Self {
        let (agent_tx, _) = broadcast::channel(10000);
        Self {
            datastore: None,
            udpserver: None,
            tcpserver: None,
            apiserver: None,
            agent_tx,
        }
    }
}

impl Servers {
    /// Creates a `Servers` with no task handles registered that reports on `agent_tx`.
    pub fn build(agent_tx: broadcast::Sender<AgentState>) -> Self {
        // Every field is spelled out: `..Default::default()` would allocate a second broadcast
        // channel only to throw it away.
        Self {
            datastore: None,
            udpserver: None,
            tcpserver: None,
            apiserver: None,
            agent_tx,
        }
    }

    /// Sets the API server handle (pass `None` to clear it) and returns the updated value.
    pub fn with_apiserver(self, apiserver: Option<JoinHandle<Result<(), Error>>>) -> Self {
        Self { apiserver, ..self }
    }

    /// Registers the datastore task handle and returns the updated value.
    pub fn with_datastore(self, datastore: JoinHandle<Result<(), String>>) -> Self {
        Self {
            datastore: Some(datastore),
            ..self
        }
    }

    /// Registers the TCP server task handle and returns the updated value.
    pub fn with_tcpserver(self, tcpserver: JoinHandle<Result<(), Error>>) -> Self {
        Self {
            tcpserver: Some(tcpserver),
            ..self
        }
    }

    /// Registers the UDP server task handle and returns the updated value.
    pub fn with_udpserver(self, udpserver: JoinHandle<Result<(), Error>>) -> Self {
        Self {
            udpserver: Some(udpserver),
            ..self
        }
    }

    /// Logs that `agent` has shut down and broadcasts `AgentState::Stopped` for it.
    ///
    /// A failed broadcast is printed to stderr and otherwise ignored.
    fn send_shutdown(&self, agent: Agent) {
        log::info!("{agent:?} shut down");
        if let Err(error) = self.agent_tx.send(AgentState::Stopped { agent }) {
            eprintln!("Failed to send agent shutdown message: {error:?}");
        };
    }

    /// Checks one task handle and, if that task has finished, prints `notice` and broadcasts the
    /// shutdown of `agent`. Returns whether the task had finished; an unregistered handle counts
    /// as not finished.
    fn report_if_finished<T>(
        &self,
        handle: Option<&JoinHandle<T>>,
        notice: &str,
        agent: Agent,
    ) -> bool {
        match handle {
            // is_finished() is read once, so the notification and the returned flag always agree
            Some(server) if server.is_finished() => {
                println!("{notice}");
                self.send_shutdown(agent);
                true
            }
            _ => false,
        }
    }

    /// Reports every finished task and says whether at least one has finished.
    ///
    /// Despite the name, this returns `true` as soon as ANY registered task has finished, not
    /// only when all of them have. Each call re-announces every task that is finished at that
    /// moment. Tasks are checked in the order API, datastore, TCP, UDP.
    pub fn all_finished(&self) -> bool {
        let api = self.report_if_finished(
            self.apiserver.as_ref(),
            "Sending API Shutdown",
            Agent::API,
        );
        let datastore = self.report_if_finished(
            self.datastore.as_ref(),
            "Sending Datastore Shutdown",
            Agent::Datastore,
        );
        let tcp = self.report_if_finished(
            self.tcpserver.as_ref(),
            "Sending TCP Server Shutdown",
            Agent::TCPServer,
        );
        let udp = self.report_if_finished(
            self.udpserver.as_ref(),
            "Sending UDP Server Shutdown",
            Agent::UDPServer,
        );
        api || datastore || tcp || udp
    }
}