pulldown_cmark_escape/
lib.rs

1// Copyright 2015 Google Inc. All rights reserved.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a copy
4// of this software and associated documentation files (the "Software"), to deal
5// in the Software without restriction, including without limitation the rights
6// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7// copies of the Software, and to permit persons to whom the Software is
8// furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19// THE SOFTWARE.
20
21//! Utility functions for HTML escaping. Only useful when building your own
22//! HTML renderer.
23
24use std::fmt::{self, Arguments};
25use std::io::{self, Write};
26use std::str::from_utf8;
27
/// Lookup table indexed by ASCII byte value: `1` marks a byte that
/// `escape_href` copies through unchanged, `0` marks a byte that it replaces.
/// Each row covers 16 consecutive byte values, starting at 0x00.
#[rustfmt::skip]
static HREF_SAFE: [u8; 128] = [
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1,
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1,
    0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0,
];

/// Upper-case hexadecimal digits used to build `%XX` percent escapes.
static HEX_CHARS: &[u8] = b"0123456789ABCDEF";
/// HTML character reference written in place of `&` inside an href.
/// It must be the entity, not the raw byte, or the ampersand would be emitted
/// unescaped.
static AMP_ESCAPE: &str = "&amp;";
/// HTML character reference written in place of `'` inside an href, so the
/// value cannot terminate a single-quoted attribute.
static SINGLE_QUOTE_ESCAPE: &str = "&#x27;";
43
/// Newtype around a `std::io::Write` sink so that it can be used as a
/// [`StrWrite`].
///
/// This wrapper exists because we can't have both a blanket implementation
/// for all types implementing `std::io::Write` and an implementation for
/// types of the form `&mut W` where `W: StrWrite`. Since we need the latter
/// a lot, we choose to wrap `std::io::Write` types instead.
#[derive(Debug)]
pub struct IoWriter<W>(pub W);
50
/// A sink that accepts string slices.
///
/// This is essentially `std::io::Write` widened so that `String` can be a
/// target as well; each implementor chooses its own error type.
pub trait StrWrite {
    /// Error produced when a write fails.
    type Error;

    /// Appends the string slice `s` to the sink.
    fn write_str(&mut self, s: &str) -> Result<(), Self::Error>;

    /// Appends pre-compiled format arguments to the sink. Having this method
    /// lets `write!` be used directly on a `StrWrite` value.
    fn write_fmt(&mut self, args: Arguments<'_>) -> Result<(), Self::Error>;
}
59
60impl<W> StrWrite for IoWriter<W>
61where
62    W: Write,
63{
64    type Error = io::Error;
65
66    #[inline]
67    fn write_str(&mut self, s: &str) -> io::Result<()> {
68        self.0.write_all(s.as_bytes())
69    }
70
71    #[inline]
72    fn write_fmt(&mut self, args: Arguments) -> io::Result<()> {
73        self.0.write_fmt(args)
74    }
75}
76
/// Newtype around a `std::fmt::Write` sink so that it can be used as a
/// [`StrWrite`].
///
/// This wrapper exists because we can't have both a blanket implementation
/// for all types implementing `std::fmt::Write` and an implementation for
/// types of the form `&mut W` where `W: StrWrite`. Since we need the latter
/// a lot, we choose to wrap `std::fmt::Write` types instead.
#[derive(Debug)]
pub struct FmtWriter<W>(pub W);
83
84impl<W> StrWrite for FmtWriter<W>
85where
86    W: fmt::Write,
87{
88    type Error = fmt::Error;
89
90    #[inline]
91    fn write_str(&mut self, s: &str) -> fmt::Result {
92        self.0.write_str(s)
93    }
94
95    #[inline]
96    fn write_fmt(&mut self, args: Arguments) -> fmt::Result {
97        self.0.write_fmt(args)
98    }
99}
100
101impl StrWrite for String {
102    type Error = fmt::Error;
103
104    #[inline]
105    fn write_str(&mut self, s: &str) -> fmt::Result {
106        self.push_str(s);
107        Ok(())
108    }
109
110    #[inline]
111    fn write_fmt(&mut self, args: Arguments) -> fmt::Result {
112        fmt::Write::write_fmt(self, args)
113    }
114}
115
116impl<W> StrWrite for &'_ mut W
117where
118    W: StrWrite,
119{
120    type Error = W::Error;
121
122    #[inline]
123    fn write_str(&mut self, s: &str) -> Result<(), Self::Error> {
124        (**self).write_str(s)
125    }
126
127    #[inline]
128    fn write_fmt(&mut self, args: Arguments) -> Result<(), Self::Error> {
129        (**self).write_fmt(args)
130    }
131}
132
133/// Writes an href to the buffer, escaping href unsafe bytes.
134pub fn escape_href<W>(mut w: W, s: &str) -> Result<(), W::Error>
135where
136    W: StrWrite,
137{
138    let bytes = s.as_bytes();
139    let mut mark = 0;
140    for i in 0..bytes.len() {
141        let c = bytes[i];
142        if c >= 0x80 || HREF_SAFE[c as usize] == 0 {
143            // character needing escape
144
145            // write partial substring up to mark
146            if mark < i {
147                w.write_str(&s[mark..i])?;
148            }
149            match c {
150                b'&' => {
151                    w.write_str(AMP_ESCAPE)?;
152                }
153                b'\'' => {
154                    w.write_str(SINGLE_QUOTE_ESCAPE)?;
155                }
156                _ => {
157                    let mut buf = [0u8; 3];
158                    buf[0] = b'%';
159                    buf[1] = HEX_CHARS[((c as usize) >> 4) & 0xF];
160                    buf[2] = HEX_CHARS[(c as usize) & 0xF];
161                    let escaped = from_utf8(&buf).unwrap();
162                    w.write_str(escaped)?;
163                }
164            }
165            mark = i + 1; // all escaped characters are ASCII
166        }
167    }
168    w.write_str(&s[mark..])
169}
170
/// Builds a 256-entry table that maps a byte value to an index into
/// `HTML_ESCAPES`; an entry of `0` means the byte is written unchanged.
///
/// `&`, `<` and `>` are always marked. `"` and `'` are marked only when
/// `body` is `false`.
const fn create_html_escape_table(body: bool) -> [u8; 256] {
    let mut codes = [0u8; 256];
    codes[b'&' as usize] = 1;
    codes[b'<' as usize] = 2;
    codes[b'>' as usize] = 3;
    if body {
        // Body text leaves both quote characters alone.
        return codes;
    }
    codes[b'"' as usize] = 4;
    codes[b'\'' as usize] = 5;
    codes
}

/// Escape table that also marks `"` and `'`; used by `escape_html`.
static HTML_ESCAPE_TABLE: [u8; 256] = create_html_escape_table(false);
/// Escape table that marks only `&`, `<` and `>`; used by
/// `escape_html_body_text`.
static HTML_BODY_TEXT_ESCAPE_TABLE: [u8; 256] = create_html_escape_table(true);

/// Replacement text selected by an escape-table entry. Slot 0 is a
/// placeholder for "no escape".
static HTML_ESCAPES: [&str; 6] = ["", "&amp;", "&lt;", "&gt;", "&quot;", "&#39;"];
187
188/// Writes the given string to the Write sink, replacing special HTML bytes
189/// (<, >, &, ", ') by escape sequences.
190///
191/// Use this function to write output to quoted HTML attributes.
192/// Since this function doesn't escape spaces, unquoted attributes
193/// cannot be used. For example:
194///
195/// ```rust
196/// let mut value = String::new();
197/// pulldown_cmark_escape::escape_html(&mut value, "two words")
198///     .expect("writing to a string is infallible");
199/// // This is okay.
200/// let ok = format!("<a title='{value}'>test</a>");
201/// // This is not okay.
202/// //let not_ok = format!("<a title={value}>test</a>");
203/// ````
204pub fn escape_html<W: StrWrite>(w: W, s: &str) -> Result<(), W::Error> {
205    #[cfg(all(target_arch = "x86_64", feature = "simd"))]
206    {
207        simd::escape_html(w, s, &HTML_ESCAPE_TABLE)
208    }
209    #[cfg(not(all(target_arch = "x86_64", feature = "simd")))]
210    {
211        escape_html_scalar(w, s, &HTML_ESCAPE_TABLE)
212    }
213}
214
215/// For use in HTML body text, writes the given string to the Write sink,
216/// replacing special HTML bytes (<, >, &) by escape sequences.
217///
218/// <div class="warning">
219///
220/// This function should be used for escaping text nodes, not attributes.
221/// In the below example, the word "foo" is an attribute, and the word
222/// "bar" is an text node. The word "bar" could be escaped by this function,
223/// but the word "foo" must be escaped using [`escape_html`].
224///
225/// ```html
226/// <span class="foo">bar</span>
227/// ```
228///
229/// If you aren't sure what the difference is, use [`escape_html`].
230/// It should always be correct, but will produce larger output.
231///
232/// </div>
233pub fn escape_html_body_text<W: StrWrite>(w: W, s: &str) -> Result<(), W::Error> {
234    #[cfg(all(target_arch = "x86_64", feature = "simd"))]
235    {
236        simd::escape_html(w, s, &HTML_BODY_TEXT_ESCAPE_TABLE)
237    }
238    #[cfg(not(all(target_arch = "x86_64", feature = "simd")))]
239    {
240        escape_html_scalar(w, s, &HTML_BODY_TEXT_ESCAPE_TABLE)
241    }
242}
243
244fn escape_html_scalar<W: StrWrite>(
245    mut w: W,
246    s: &str,
247    table: &'static [u8; 256],
248) -> Result<(), W::Error> {
249    let bytes = s.as_bytes();
250    let mut mark = 0;
251    let mut i = 0;
252    while i < s.len() {
253        match bytes[i..].iter().position(|&c| table[c as usize] != 0) {
254            Some(pos) => {
255                i += pos;
256            }
257            None => break,
258        }
259        let c = bytes[i];
260        let escape = table[c as usize];
261        let escape_seq = HTML_ESCAPES[escape as usize];
262        w.write_str(&s[mark..i])?;
263        w.write_str(escape_seq)?;
264        i += 1;
265        mark = i; // all escaped characters are ASCII
266    }
267    w.write_str(&s[mark..])
268}
269
270#[cfg(all(target_arch = "x86_64", feature = "simd"))]
271mod simd {
272    use super::StrWrite;
273    use std::arch::x86_64::*;
274    use std::mem::size_of;
275
276    const VECTOR_SIZE: usize = size_of::<__m128i>();
277
278    pub(super) fn escape_html<W: StrWrite>(
279        mut w: W,
280        s: &str,
281        table: &'static [u8; 256],
282    ) -> Result<(), W::Error> {
283        // The SIMD accelerated code uses the PSHUFB instruction, which is part
284        // of the SSSE3 instruction set. Further, we can only use this code if
285        // the buffer is at least one VECTOR_SIZE in length to prevent reading
286        // out of bounds. If either of these conditions is not met, we fall back
287        // to scalar code.
288        if is_x86_feature_detected!("ssse3") && s.len() >= VECTOR_SIZE {
289            let bytes = s.as_bytes();
290            let mut mark = 0;
291
292            unsafe {
293                foreach_special_simd(bytes, 0, |i| {
294                    let escape_ix = *bytes.get_unchecked(i) as usize;
295                    let entry = table[escape_ix] as usize;
296                    w.write_str(s.get_unchecked(mark..i))?;
297                    mark = i + 1; // all escaped characters are ASCII
298                    if entry == 0 {
299                        w.write_str(s.get_unchecked(i..mark))
300                    } else {
301                        let replacement = super::HTML_ESCAPES[entry];
302                        w.write_str(replacement)
303                    }
304                })?;
305                w.write_str(s.get_unchecked(mark..))
306            }
307        } else {
308            super::escape_html_scalar(w, s, table)
309        }
310    }
311
312    /// Creates the lookup table for use in `compute_mask`.
313    const fn create_lookup() -> [u8; 16] {
314        let mut table = [0; 16];
315        table[(b'<' & 0x0f) as usize] = b'<';
316        table[(b'>' & 0x0f) as usize] = b'>';
317        table[(b'&' & 0x0f) as usize] = b'&';
318        table[(b'"' & 0x0f) as usize] = b'"';
319        table[(b'\'' & 0x0f) as usize] = b'\'';
320        table[0] = 0b0111_1111;
321        table
322    }
323
324    #[target_feature(enable = "ssse3")]
325    /// Computes a byte mask at given offset in the byte buffer. Its first 16 (least significant)
326    /// bits correspond to whether there is an HTML special byte (&, <, ", >) at the 16 bytes
327    /// `bytes[offset..]`. For example, the mask `(1 << 3)` states that there is an HTML byte
328    /// at `offset + 3`. It is only safe to call this function when
329    /// `bytes.len() >= offset + VECTOR_SIZE`.
330    unsafe fn compute_mask(bytes: &[u8], offset: usize) -> i32 {
331        debug_assert!(bytes.len() >= offset + VECTOR_SIZE);
332
333        let table = create_lookup();
334        let lookup = _mm_loadu_si128(table.as_ptr() as *const __m128i);
335        let raw_ptr = bytes.as_ptr().add(offset) as *const __m128i;
336
337        // Load the vector from memory.
338        let vector = _mm_loadu_si128(raw_ptr);
339        // We take the least significant 4 bits of every byte and use them as indices
340        // to map into the lookup vector.
341        // Note that shuffle maps bytes with their most significant bit set to lookup[0].
342        // Bytes that share their lower nibble with an HTML special byte get mapped to that
343        // corresponding special byte. Note that all HTML special bytes have distinct lower
344        // nibbles. Other bytes either get mapped to 0 or 127.
345        let expected = _mm_shuffle_epi8(lookup, vector);
346        // We compare the original vector to the mapped output. Bytes that shared a lower
347        // nibble with an HTML special byte match *only* if they are that special byte. Bytes
348        // that have either a 0 lower nibble or their most significant bit set were mapped to
349        // 127 and will hence never match. All other bytes have non-zero lower nibbles but
350        // were mapped to 0 and will therefore also not match.
351        let matches = _mm_cmpeq_epi8(expected, vector);
352
353        // Translate matches to a bitmask, where every 1 corresponds to a HTML special character
354        // and a 0 is a non-HTML byte.
355        _mm_movemask_epi8(matches)
356    }
357
358    /// Calls the given function with the index of every byte in the given byteslice
359    /// that is either ", &, <, or > and for no other byte.
360    /// Make sure to only call this when `bytes.len() >= 16`, undefined behaviour may
361    /// occur otherwise.
362    #[target_feature(enable = "ssse3")]
363    unsafe fn foreach_special_simd<E, F>(
364        bytes: &[u8],
365        mut offset: usize,
366        mut callback: F,
367    ) -> Result<(), E>
368    where
369        F: FnMut(usize) -> Result<(), E>,
370    {
371        // The strategy here is to walk the byte buffer in chunks of VECTOR_SIZE (16)
372        // bytes at a time starting at the given offset. For each chunk, we compute a
373        // a bitmask indicating whether the corresponding byte is a HTML special byte.
374        // We then iterate over all the 1 bits in this mask and call the callback function
375        // with the corresponding index in the buffer.
376        // When the number of HTML special bytes in the buffer is relatively low, this
377        // allows us to quickly go through the buffer without a lookup and for every
378        // single byte.
379
380        debug_assert!(bytes.len() >= VECTOR_SIZE);
381        let upperbound = bytes.len() - VECTOR_SIZE;
382        while offset < upperbound {
383            let mut mask = compute_mask(bytes, offset);
384            while mask != 0 {
385                let ix = mask.trailing_zeros();
386                callback(offset + ix as usize)?;
387                mask ^= mask & -mask;
388            }
389            offset += VECTOR_SIZE;
390        }
391
392        // Final iteration. We align the read with the end of the slice and
393        // shift off the bytes at start we have already scanned.
394        let mut mask = compute_mask(bytes, upperbound);
395        mask >>= offset - upperbound;
396        while mask != 0 {
397            let ix = mask.trailing_zeros();
398            callback(offset + ix as usize)?;
399            mask ^= mask & -mask;
400        }
401        Ok(())
402    }
403
404    #[cfg(test)]
405    mod html_scan_tests {
406        #[test]
407        fn multichunk() {
408            let mut vec = Vec::new();
409            unsafe {
410                super::foreach_special_simd("&aXaaaa.a'aa9a<>aab&".as_bytes(), 0, |ix| {
411                    #[allow(clippy::unit_arg)]
412                    Ok::<_, std::fmt::Error>(vec.push(ix))
413                })
414                .unwrap();
415            }
416            assert_eq!(vec, vec![0, 9, 14, 15, 19]);
417        }
418
419        // only match these bytes, and when we match them, match them VECTOR_SIZE times
420        #[test]
421        fn only_right_bytes_matched() {
422            for b in 0..255u8 {
423                let right_byte = b == b'&' || b == b'<' || b == b'>' || b == b'"' || b == b'\'';
424                let vek = vec![b; super::VECTOR_SIZE];
425                let mut match_count = 0;
426                unsafe {
427                    super::foreach_special_simd(&vek, 0, |_| {
428                        match_count += 1;
429                        Ok::<_, std::fmt::Error>(())
430                    })
431                    .unwrap();
432                }
433                assert!((match_count > 0) == (match_count == super::VECTOR_SIZE));
434                assert_eq!(
435                    (match_count == super::VECTOR_SIZE),
436                    right_byte,
437                    "match_count: {}, byte: {:?}",
438                    match_count,
439                    b as char
440                );
441            }
442        }
443    }
444}
445
#[cfg(test)]
mod test {
    pub use super::{escape_href, escape_html, escape_html_body_text};

    /// `&` is replaced in an href while `^` and `_` pass through.
    #[test]
    fn check_href_escape() {
        let mut escaped = String::new();
        escape_href(&mut escaped, "&^_").unwrap();
        assert_eq!(escaped, "&amp;^_");
    }

    /// Attribute escaping replaces `&` and both quote characters.
    #[test]
    fn check_attr_escape() {
        let mut escaped = String::new();
        escape_html(&mut escaped, r##"&^"'_"##).unwrap();
        assert_eq!(escaped, "&amp;^&quot;&#39;_");
    }

    /// Body-text escaping replaces `&` but leaves the quote characters alone.
    #[test]
    fn check_body_escape() {
        let mut escaped = String::new();
        escape_html_body_text(&mut escaped, r##"&^"'_"##).unwrap();
        assert_eq!(escaped, r##"&amp;^"'_"##);
    }
}