tokio/io/
blocking.rs

1use crate::io::sys;
2use crate::io::{AsyncRead, AsyncWrite, ReadBuf};
3
4use std::cmp;
5use std::future::Future;
6use std::io;
7use std::io::prelude::*;
8use std::mem::MaybeUninit;
9use std::pin::Pin;
10use std::task::{ready, Context, Poll};
11
12/// `T` should not implement _both_ Read and Write.
13#[derive(Debug)]
14pub(crate) struct Blocking<T> {
15    inner: Option<T>,
16    state: State<T>,
17    /// `true` if the lower IO layer needs flushing.
18    need_flush: bool,
19}
20
/// Byte buffer that is moved back and forth between the async side and the
/// blocking operations.
///
/// The bytes in `buf[pos..]` are the ones that have not been consumed yet.
#[derive(Debug)]
pub(crate) struct Buf {
    /// Backing storage.
    buf: Vec<u8>,
    /// Index of the first unconsumed byte in `buf`.
    pos: usize,
}
26
/// Upper bound, in bytes, on how much data a single blocking read or write
/// moves through the intermediate buffer (2 MiB).
pub(crate) const DEFAULT_MAX_BUF_SIZE: usize = 2 << 20;
28
29#[derive(Debug)]
30enum State<T> {
31    Idle(Option<Buf>),
32    Busy(sys::Blocking<(io::Result<usize>, Buf, T)>),
33}
34
35cfg_io_blocking! {
36    impl<T> Blocking<T> {
37        /// # Safety
38        ///
39        /// The `Read` implementation of `inner` must never read from the buffer
40        /// it is borrowing and must correctly report the length of the data
41        /// written into the buffer.
42        #[cfg_attr(feature = "fs", allow(dead_code))]
43        pub(crate) unsafe fn new(inner: T) -> Blocking<T> {
44            Blocking {
45                inner: Some(inner),
46                state: State::Idle(Some(Buf::with_capacity(0))),
47                need_flush: false,
48            }
49        }
50    }
51}
52
53impl<T> AsyncRead for Blocking<T>
54where
55    T: Read + Unpin + Send + 'static,
56{
57    fn poll_read(
58        mut self: Pin<&mut Self>,
59        cx: &mut Context<'_>,
60        dst: &mut ReadBuf<'_>,
61    ) -> Poll<io::Result<()>> {
62        loop {
63            match self.state {
64                State::Idle(ref mut buf_cell) => {
65                    let mut buf = buf_cell.take().unwrap();
66
67                    if !buf.is_empty() {
68                        buf.copy_to(dst);
69                        *buf_cell = Some(buf);
70                        return Poll::Ready(Ok(()));
71                    }
72
73                    let mut inner = self.inner.take().unwrap();
74
75                    let max_buf_size = cmp::min(dst.remaining(), DEFAULT_MAX_BUF_SIZE);
76                    self.state = State::Busy(sys::run(move || {
77                        // SAFETY: the requirements are satisfied by `Blocking::new`.
78                        let res = unsafe { buf.read_from(&mut inner, max_buf_size) };
79                        (res, buf, inner)
80                    }));
81                }
82                State::Busy(ref mut rx) => {
83                    let (res, mut buf, inner) = ready!(Pin::new(rx).poll(cx))?;
84                    self.inner = Some(inner);
85
86                    match res {
87                        Ok(_) => {
88                            buf.copy_to(dst);
89                            self.state = State::Idle(Some(buf));
90                            return Poll::Ready(Ok(()));
91                        }
92                        Err(e) => {
93                            assert!(buf.is_empty());
94
95                            self.state = State::Idle(Some(buf));
96                            return Poll::Ready(Err(e));
97                        }
98                    }
99                }
100            }
101        }
102    }
103}
104
105impl<T> AsyncWrite for Blocking<T>
106where
107    T: Write + Unpin + Send + 'static,
108{
109    fn poll_write(
110        mut self: Pin<&mut Self>,
111        cx: &mut Context<'_>,
112        src: &[u8],
113    ) -> Poll<io::Result<usize>> {
114        loop {
115            match self.state {
116                State::Idle(ref mut buf_cell) => {
117                    let mut buf = buf_cell.take().unwrap();
118
119                    assert!(buf.is_empty());
120
121                    let n = buf.copy_from(src, DEFAULT_MAX_BUF_SIZE);
122                    let mut inner = self.inner.take().unwrap();
123
124                    self.state = State::Busy(sys::run(move || {
125                        let n = buf.len();
126                        let res = buf.write_to(&mut inner).map(|()| n);
127
128                        (res, buf, inner)
129                    }));
130                    self.need_flush = true;
131
132                    return Poll::Ready(Ok(n));
133                }
134                State::Busy(ref mut rx) => {
135                    let (res, buf, inner) = ready!(Pin::new(rx).poll(cx))?;
136                    self.state = State::Idle(Some(buf));
137                    self.inner = Some(inner);
138
139                    // If error, return
140                    res?;
141                }
142            }
143        }
144    }
145
146    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
147        loop {
148            let need_flush = self.need_flush;
149            match self.state {
150                // The buffer is not used here
151                State::Idle(ref mut buf_cell) => {
152                    if need_flush {
153                        let buf = buf_cell.take().unwrap();
154                        let mut inner = self.inner.take().unwrap();
155
156                        self.state = State::Busy(sys::run(move || {
157                            let res = inner.flush().map(|()| 0);
158                            (res, buf, inner)
159                        }));
160
161                        self.need_flush = false;
162                    } else {
163                        return Poll::Ready(Ok(()));
164                    }
165                }
166                State::Busy(ref mut rx) => {
167                    let (res, buf, inner) = ready!(Pin::new(rx).poll(cx))?;
168                    self.state = State::Idle(Some(buf));
169                    self.inner = Some(inner);
170
171                    // If error, return
172                    res?;
173                }
174            }
175        }
176    }
177
178    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
179        Poll::Ready(Ok(()))
180    }
181}
182
183/// Repeats operations that are interrupted.
184macro_rules! uninterruptibly {
185    ($e:expr) => {{
186        loop {
187            match $e {
188                Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
189                res => break res,
190            }
191        }
192    }};
193}
194
195impl Buf {
196    pub(crate) fn with_capacity(n: usize) -> Buf {
197        Buf {
198            buf: Vec::with_capacity(n),
199            pos: 0,
200        }
201    }
202
203    pub(crate) fn is_empty(&self) -> bool {
204        self.len() == 0
205    }
206
207    pub(crate) fn len(&self) -> usize {
208        self.buf.len() - self.pos
209    }
210
211    pub(crate) fn copy_to(&mut self, dst: &mut ReadBuf<'_>) -> usize {
212        let n = cmp::min(self.len(), dst.remaining());
213        dst.put_slice(&self.bytes()[..n]);
214        self.pos += n;
215
216        if self.pos == self.buf.len() {
217            self.buf.truncate(0);
218            self.pos = 0;
219        }
220
221        n
222    }
223
224    pub(crate) fn copy_from(&mut self, src: &[u8], max_buf_size: usize) -> usize {
225        assert!(self.is_empty());
226
227        let n = cmp::min(src.len(), max_buf_size);
228
229        self.buf.extend_from_slice(&src[..n]);
230        n
231    }
232
233    pub(crate) fn bytes(&self) -> &[u8] {
234        &self.buf[self.pos..]
235    }
236
237    /// # Safety
238    ///
239    /// `rd` must not read from the buffer `read` is borrowing and must correctly
240    /// report the length of the data written into the buffer.
241    pub(crate) unsafe fn read_from<T: Read>(
242        &mut self,
243        rd: &mut T,
244        max_buf_size: usize,
245    ) -> io::Result<usize> {
246        assert!(self.is_empty());
247        self.buf.reserve(max_buf_size);
248
249        let buf = &mut self.buf.spare_capacity_mut()[..max_buf_size];
250        // SAFETY: The memory may be uninitialized, but `rd.read` will only write to the buffer.
251        let buf = unsafe { &mut *(buf as *mut [MaybeUninit<u8>] as *mut [u8]) };
252        let res = uninterruptibly!(rd.read(buf));
253
254        if let Ok(n) = res {
255            // SAFETY: the caller promises that `rd.read` initializes
256            // a section of `buf` and correctly reports that length.
257            // The `self.is_empty()` assertion verifies that `n`
258            // equals the length of the `buf` capacity that was written
259            // to (and that `buf` isn't being shrunk).
260            unsafe { self.buf.set_len(n) }
261        } else {
262            self.buf.clear();
263        }
264
265        assert_eq!(self.pos, 0);
266
267        res
268    }
269
270    pub(crate) fn write_to<T: Write>(&mut self, wr: &mut T) -> io::Result<()> {
271        assert_eq!(self.pos, 0);
272
273        // `write_all` already ignores interrupts
274        let res = wr.write_all(&self.buf);
275        self.buf.clear();
276        res
277    }
278}
279
280cfg_fs! {
281    impl Buf {
282        pub(crate) fn discard_read(&mut self) -> i64 {
283            let ret = -(self.bytes().len() as i64);
284            self.pos = 0;
285            self.buf.truncate(0);
286            ret
287        }
288
289        pub(crate) fn copy_from_bufs(&mut self, bufs: &[io::IoSlice<'_>], max_buf_size: usize) -> usize {
290            assert!(self.is_empty());
291
292            let mut rem = max_buf_size;
293            for buf in bufs {
294                if rem == 0 {
295                    break
296                }
297
298                let len = buf.len().min(rem);
299                self.buf.extend_from_slice(&buf[..len]);
300                rem -= len;
301            }
302
303            max_buf_size - rem
304        }
305    }
306}