1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
use std::{
collections::HashMap,
io,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use futures::channel::mpsc;
use futures::{SinkExt, Stream, StreamExt};
use log::debug;
use parking_lot::RwLock;
use crate::{
channel::{LdapChannel, LdapMessageReceiver, LdapMessageSender},
error::Error,
oid,
rasn_ldap::{LdapMessage, ProtocolOp},
TlsOptions,
};
/// Capacity of each per-request response channel created by `send_recv_stream`.
const CHANNEL_SIZE: usize = 1 << 10;
/// Shared registry of in-flight requests: LDAP message ID -> sender that feeds the
/// matching `MessageStream`. Written by requesters, read by the routing task.
type RequestMap = Arc<RwLock<HashMap<u32, LdapMessageSender>>>;
/// Handle to an LDAP client connection.
///
/// Clones share the same request map (via `Arc`) and send through clones of the
/// same channel sender, so several handles can issue requests on one connection.
#[derive(Clone)]
pub(crate) struct LdapConnection {
    /// In-flight requests keyed by message ID; the background routing task forwards
    /// each incoming message to the sender registered under its ID.
    requests: RequestMap,
    /// Outgoing half of the LDAP channel used to transmit requests.
    channel_sender: LdapMessageSender,
}
impl LdapConnection {
    /// Opens a connection to `address:port` and spawns the background task that
    /// routes incoming messages to their pending requests.
    ///
    /// The spawned task reads every message from the channel receiver. It forwards
    /// each one to the sender registered under the message's `message_id`. Messages
    /// with no registered sender are discarded.
    ///
    /// The task stops in two cases:
    /// - the receiver is exhausted;
    /// - an `ExtendedResp` whose response name is the Notice of Disconnection OID arrives.
    ///
    /// On exit the task clears the request map. This drops every registered sender,
    /// so outstanding [`MessageStream`]s end instead of waiting forever.
    ///
    /// Must be called from within a Tokio runtime, because the router is started
    /// with `tokio::spawn`.
    ///
    /// # Errors
    ///
    /// Returns the error produced while establishing the underlying `LdapChannel`.
    pub(crate) async fn connect<A>(address: A, port: u16, tls_options: TlsOptions) -> Result<Self, Error>
    where
        A: AsRef<str>,
    {
        let (channel_sender, mut channel_receiver) =
            LdapChannel::for_client(address, port).connect(tls_options).await?;

        let connection = Self {
            requests: RequestMap::default(),
            channel_sender,
        };

        let requests = connection.requests.clone();
        tokio::spawn(async move {
            while let Some(msg) = channel_receiver.next().await {
                match msg.protocol_op {
                    ProtocolOp::ExtendedResp(resp)
                        if resp.response_name.as_deref() == Some(oid::NOTICE_OF_DISCONNECTION_OID) =>
                    {
                        debug!("Notice of disconnection received, exiting");
                        break;
                    }
                    _ => {
                        // Clone the sender out so the read lock is released before awaiting.
                        let sender = requests.read().get(&msg.message_id).cloned();
                        if let Some(mut sender) = sender {
                            // The receiving MessageStream may already have been dropped;
                            // a failed delivery is expected then and deliberately ignored.
                            let _ = sender.send(msg).await;
                        }
                    }
                }
            }
            debug!("Connection terminated");
            // Dropping all senders ends every still-pending MessageStream.
            requests.write().clear();
        });

        Ok(connection)
    }

    /// Sends `msg` and returns a stream of every response carrying the same message ID.
    ///
    /// The response channel is registered in the request map *before* the message
    /// is sent. Otherwise a fast response could reach the routing task while no
    /// sender is registered yet; it would then be discarded, and the stream would
    /// wait forever.
    ///
    /// If sending fails, or this future is cancelled mid-send, the just-created
    /// `MessageStream` is dropped. Its `Drop` impl then removes the registration.
    ///
    /// # Errors
    ///
    /// Returns an error if the message cannot be sent through the channel sender.
    pub(crate) async fn send_recv_stream(&mut self, msg: LdapMessage) -> Result<MessageStream, Error> {
        let id = msg.message_id;
        let (tx, rx) = mpsc::channel(CHANNEL_SIZE);
        // The write guard is a temporary and is released at the end of this statement.
        self.requests.write().insert(id, tx);

        let stream = MessageStream {
            id,
            requests: self.requests.clone(),
            receiver: rx,
        };

        // On failure `?` drops `stream`, which deregisters `id`.
        self.channel_sender.send(msg).await?;
        Ok(stream)
    }

    /// Sends `msg` without registering a response channel for its message ID.
    ///
    /// # Errors
    ///
    /// Returns an error if the message cannot be sent through the channel sender.
    pub(crate) async fn send(&mut self, msg: LdapMessage) -> Result<(), Error> {
        Ok(self.channel_sender.send(msg).await?)
    }

    /// Sends `msg` and waits for the first response with the same message ID.
    ///
    /// The temporary stream is dropped once the first response arrives, which
    /// deregisters the ID. Any later responses with that ID are discarded by the
    /// routing task.
    ///
    /// # Errors
    ///
    /// Returns an error in two cases:
    /// - sending fails;
    /// - the stream ends before any response arrives. This yields an
    ///   `io::ErrorKind::ConnectionReset` error, converted into [`Error`].
    pub(crate) async fn send_recv(&mut self, msg: LdapMessage) -> Result<LdapMessage, Error> {
        Ok(self
            .send_recv_stream(msg)
            .await?
            .next()
            .await
            .ok_or_else(|| io::Error::new(io::ErrorKind::ConnectionReset, "Connection closed"))?)
    }
}
/// Stream of responses that share a single LDAP message ID.
///
/// While the stream is alive, its sending half is registered in the shared request
/// map under `id`. Dropping the stream removes that registration.
pub(crate) struct MessageStream {
    /// Message ID whose responses are delivered to this stream.
    id: u32,
    /// Shared registry this stream deregisters itself from when dropped.
    requests: RequestMap,
    /// Receiving half of the per-request channel.
    receiver: LdapMessageReceiver,
}
impl Stream for MessageStream {
    type Item = LdapMessage;

    /// Delegates to the per-request channel receiver. The stream ends when the
    /// receiver reports that every sender for it is gone.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // The receiver is `Unpin`, so it can be polled through a plain `&mut`.
        self.receiver.poll_next_unpin(cx)
    }
}
impl Drop for MessageStream {
    /// Deregisters this stream, so the routing task no longer forwards responses
    /// for its message ID.
    fn drop(&mut self) {
        let id = self.id;
        let mut registry = self.requests.write();
        registry.remove(&id);
    }
}