1use std::fmt::{self, Arguments};
25use std::io::{self, Write};
26use std::str::from_utf8;
27
/// Lookup table for `escape_href`, indexed by ASCII byte value: 1 means the
/// byte may be copied into an href verbatim, 0 means it must be escaped.
///
/// Safe bytes are the printable ASCII range `0x21..=0x7E` minus the characters
/// listed in `UNSAFE`; control bytes, space (0x20) and DEL (0x7F) are never
/// safe. Bytes >= 0x80 are outside the table and `escape_href` always escapes
/// them.
static HREF_SAFE: [u8; 128] = {
    // Printable ASCII that must still be escaped: characters with meaning in
    // HTML attributes (`"`, `&`, `'`, `<`, `>`) plus `[`, `\`, `]`, backtick,
    // `{`, `|` and `}`.
    const UNSAFE: &[u8] = b"\"&'<>[\\]`{|}";

    let mut table = [0u8; 128];
    let mut byte = 0x21;
    while byte < 0x7F {
        table[byte] = 1;
        byte += 1;
    }
    let mut k = 0;
    while k < UNSAFE.len() {
        table[UNSAFE[k] as usize] = 0;
        k += 1;
    }
    table
};
39
/// Uppercase hexadecimal digits used by `escape_href` to build `%XX`
/// percent-escapes.
static HEX_CHARS: &[u8] = b"0123456789ABCDEF";
/// Character reference `escape_href` writes in place of `&`. It must be the
/// entity, not a bare `&`, or the byte would be emitted unescaped.
static AMP_ESCAPE: &str = "&amp;";
/// Character reference `escape_href` writes in place of `'`, so the value
/// cannot terminate a single-quoted attribute.
static SINGLE_QUOTE_ESCAPE: &str = "&#x27;";
43
/// Adapter that lets any [`std::io::Write`] byte sink receive escaped output
/// through [`StrWrite`]; strings are written as their UTF-8 bytes.
#[derive(Debug)]
pub struct IoWriter<W>(pub W);

/// Minimal string sink that the escaping functions in this module write into.
///
/// Implemented for `IoWriter` (byte streams), `FmtWriter` (`fmt::Write`
/// targets), `String`, and `&mut W` for any implementor `W`.
pub trait StrWrite {
    /// Error type reported by the underlying sink.
    type Error;

    /// Writes `s` in full.
    fn write_str(&mut self, s: &str) -> Result<(), Self::Error>;
    /// Writes pre-formatted arguments (as produced by `format_args!`).
    fn write_fmt(&mut self, args: Arguments) -> Result<(), Self::Error>;
}

impl<W: Write> StrWrite for IoWriter<W> {
    type Error = io::Error;

    #[inline]
    fn write_str(&mut self, s: &str) -> io::Result<()> {
        let IoWriter(inner) = self;
        inner.write_all(s.as_bytes())
    }

    #[inline]
    fn write_fmt(&mut self, args: Arguments) -> io::Result<()> {
        Write::write_fmt(&mut self.0, args)
    }
}
76
77#[derive(Debug)]
82pub struct FmtWriter<W>(pub W);
83
84impl<W> StrWrite for FmtWriter<W>
85where
86 W: fmt::Write,
87{
88 type Error = fmt::Error;
89
90 #[inline]
91 fn write_str(&mut self, s: &str) -> fmt::Result {
92 self.0.write_str(s)
93 }
94
95 #[inline]
96 fn write_fmt(&mut self, args: Arguments) -> fmt::Result {
97 self.0.write_fmt(args)
98 }
99}
100
101impl StrWrite for String {
102 type Error = fmt::Error;
103
104 #[inline]
105 fn write_str(&mut self, s: &str) -> fmt::Result {
106 self.push_str(s);
107 Ok(())
108 }
109
110 #[inline]
111 fn write_fmt(&mut self, args: Arguments) -> fmt::Result {
112 fmt::Write::write_fmt(self, args)
113 }
114}
115
116impl<W> StrWrite for &'_ mut W
117where
118 W: StrWrite,
119{
120 type Error = W::Error;
121
122 #[inline]
123 fn write_str(&mut self, s: &str) -> Result<(), Self::Error> {
124 (**self).write_str(s)
125 }
126
127 #[inline]
128 fn write_fmt(&mut self, args: Arguments) -> Result<(), Self::Error> {
129 (**self).write_fmt(args)
130 }
131}
132
133pub fn escape_href<W>(mut w: W, s: &str) -> Result<(), W::Error>
135where
136 W: StrWrite,
137{
138 let bytes = s.as_bytes();
139 let mut mark = 0;
140 for i in 0..bytes.len() {
141 let c = bytes[i];
142 if c >= 0x80 || HREF_SAFE[c as usize] == 0 {
143 if mark < i {
147 w.write_str(&s[mark..i])?;
148 }
149 match c {
150 b'&' => {
151 w.write_str(AMP_ESCAPE)?;
152 }
153 b'\'' => {
154 w.write_str(SINGLE_QUOTE_ESCAPE)?;
155 }
156 _ => {
157 let mut buf = [0u8; 3];
158 buf[0] = b'%';
159 buf[1] = HEX_CHARS[((c as usize) >> 4) & 0xF];
160 buf[2] = HEX_CHARS[(c as usize) & 0xF];
161 let escaped = from_utf8(&buf).unwrap();
162 w.write_str(escaped)?;
163 }
164 }
165 mark = i + 1; }
167 }
168 w.write_str(&s[mark..])
169}
170
/// Builds a 256-entry byte → escape-code table at compile time.
///
/// A non-zero entry is the index into `HTML_ESCAPES` of the replacement for
/// that byte; 0 means "copy through". `&`, `<` and `>` are always mapped. `"`
/// and `'` are mapped only when `body` is `false`, i.e. for attribute values.
const fn create_html_escape_table(body: bool) -> [u8; 256] {
    // (byte, escape code). The first three entries apply everywhere; the quote
    // characters only outside body text.
    const ESCAPED: [(u8, u8); 5] = [(b'&', 1), (b'<', 2), (b'>', 3), (b'"', 4), (b'\'', 5)];
    let count = if body { 3 } else { 5 };

    let mut table = [0u8; 256];
    let mut k = 0;
    while k < count {
        let (byte, code) = ESCAPED[k];
        table[byte as usize] = code;
        k += 1;
    }
    table
}
182
/// Escape-code table used by `escape_html`: maps `&`, `<`, `>`, `"` and `'`
/// to their index in `HTML_ESCAPES`; every other byte maps to 0 (copied as-is).
static HTML_ESCAPE_TABLE: [u8; 256] = create_html_escape_table(false);
/// Escape-code table used by `escape_html_body_text`: like the attribute
/// table, but `"` and `'` map to 0, so only `&`, `<` and `>` are escaped.
static HTML_BODY_TEXT_ESCAPE_TABLE: [u8; 256] = create_html_escape_table(true);
185
/// Replacement strings indexed by the codes stored in the HTML escape tables
/// (1 = `&`, 2 = `<`, 3 = `>`, 4 = `"`, 5 = `'`). Index 0 means "no escape"
/// and is never written as a replacement by the escapers.
static HTML_ESCAPES: [&str; 6] = ["", "&amp;", "&lt;", "&gt;", "&quot;", "&#39;"];
187
188pub fn escape_html<W: StrWrite>(w: W, s: &str) -> Result<(), W::Error> {
205 #[cfg(all(target_arch = "x86_64", feature = "simd"))]
206 {
207 simd::escape_html(w, s, &HTML_ESCAPE_TABLE)
208 }
209 #[cfg(not(all(target_arch = "x86_64", feature = "simd")))]
210 {
211 escape_html_scalar(w, s, &HTML_ESCAPE_TABLE)
212 }
213}
214
215pub fn escape_html_body_text<W: StrWrite>(w: W, s: &str) -> Result<(), W::Error> {
234 #[cfg(all(target_arch = "x86_64", feature = "simd"))]
235 {
236 simd::escape_html(w, s, &HTML_BODY_TEXT_ESCAPE_TABLE)
237 }
238 #[cfg(not(all(target_arch = "x86_64", feature = "simd")))]
239 {
240 escape_html_scalar(w, s, &HTML_BODY_TEXT_ESCAPE_TABLE)
241 }
242}
243
244fn escape_html_scalar<W: StrWrite>(
245 mut w: W,
246 s: &str,
247 table: &'static [u8; 256],
248) -> Result<(), W::Error> {
249 let bytes = s.as_bytes();
250 let mut mark = 0;
251 let mut i = 0;
252 while i < s.len() {
253 match bytes[i..].iter().position(|&c| table[c as usize] != 0) {
254 Some(pos) => {
255 i += pos;
256 }
257 None => break,
258 }
259 let c = bytes[i];
260 let escape = table[c as usize];
261 let escape_seq = HTML_ESCAPES[escape as usize];
262 w.write_str(&s[mark..i])?;
263 w.write_str(escape_seq)?;
264 i += 1;
265 mark = i; }
267 w.write_str(&s[mark..])
268}
269
/// SSSE3-accelerated HTML escaping, compiled only for x86_64 with the `simd`
/// feature. Special bytes are located 16 at a time with a nibble-indexed
/// `pshufb` lookup; the actual replacement reuses the scalar tables.
#[cfg(all(target_arch = "x86_64", feature = "simd"))]
mod simd {
    use super::StrWrite;
    use std::arch::x86_64::*;
    use std::mem::size_of;

    /// Bytes examined per SSE register (16).
    const VECTOR_SIZE: usize = size_of::<__m128i>();

    /// SIMD counterpart of `escape_html_scalar`.
    ///
    /// Takes the vectorized path when the running CPU supports SSSE3 and `s`
    /// is at least one vector long; otherwise defers to the scalar version.
    /// `table` maps each byte to an index into `HTML_ESCAPES`, where 0 means
    /// "copy through".
    pub(super) fn escape_html<W: StrWrite>(
        mut w: W,
        s: &str,
        table: &'static [u8; 256],
    ) -> Result<(), W::Error> {
        if is_x86_feature_detected!("ssse3") && s.len() >= VECTOR_SIZE {
            let bytes = s.as_bytes();
            let mut mark = 0;

            // SAFETY: SSSE3 support is checked above, and `bytes.len() >=
            // VECTOR_SIZE` as `foreach_special_simd` requires. The callback
            // receives strictly increasing indices `i < bytes.len()`, each
            // naming one of the ASCII bytes `<>&"'`. So `i` and `i + 1` are
            // char boundaries and `mark <= i` always holds, which makes every
            // `get_unchecked` slice in bounds and valid UTF-8.
            unsafe {
                foreach_special_simd(bytes, 0, |i| {
                    let escape_ix = *bytes.get_unchecked(i) as usize;
                    let entry = table[escape_ix] as usize;
                    w.write_str(s.get_unchecked(mark..i))?;
                    mark = i + 1; // all escaped characters are ASCII
                    if entry == 0 {
                        // The SIMD filter always matches quotes, but this
                        // table (body text) does not escape them: copy as-is.
                        w.write_str(s.get_unchecked(i..mark))
                    } else {
                        let replacement = super::HTML_ESCAPES[entry];
                        w.write_str(replacement)
                    }
                })?;
                w.write_str(s.get_unchecked(mark..))
            }
        } else {
            super::escape_html_scalar(w, s, table)
        }
    }

    /// Builds the 16-entry `pshufb` lookup table, indexed by a byte's low
    /// nibble.
    ///
    /// `<`, `>`, `&`, `"` and `'` have distinct low nibbles, so
    /// `table[b & 0x0f] == b` holds exactly for those five bytes. Entry 0 is
    /// set to 0x7F, a byte whose low nibble is not 0. Otherwise byte 0x00
    /// would compare equal to the zero-filled entry 0 and be reported as
    /// special.
    const fn create_lookup() -> [u8; 16] {
        let mut table = [0; 16];
        table[(b'<' & 0x0f) as usize] = b'<';
        table[(b'>' & 0x0f) as usize] = b'>';
        table[(b'&' & 0x0f) as usize] = b'&';
        table[(b'"' & 0x0f) as usize] = b'"';
        table[(b'\'' & 0x0f) as usize] = b'\'';
        table[0] = 0b0111_1111;
        table
    }

    /// Returns a mask whose bit `j` (for `j` in `0..16`) is set iff
    /// `bytes[offset + j]` is one of `<>&"'`.
    ///
    /// # Safety
    /// The CPU must support SSSE3, and `offset + VECTOR_SIZE <= bytes.len()`
    /// (the 16-byte load is unchecked).
    #[target_feature(enable = "ssse3")]
    unsafe fn compute_mask(bytes: &[u8], offset: usize) -> i32 {
        debug_assert!(bytes.len() >= offset + VECTOR_SIZE);

        let table = create_lookup();
        let lookup = _mm_loadu_si128(table.as_ptr() as *const __m128i);
        let raw_ptr = bytes.as_ptr().add(offset) as *const __m128i;

        // Unaligned load of the 16 input bytes.
        let vector = _mm_loadu_si128(raw_ptr);
        // Per lane: lookup[b & 0x0f] when b < 0x80, and 0 when b's high bit
        // is set. Non-ASCII input (never 0) therefore never matches below.
        let expected = _mm_shuffle_epi8(lookup, vector);
        // 0xFF in every lane whose byte equals its nibble-table entry.
        let matches = _mm_cmpeq_epi8(expected, vector);

        // One bit per lane: bit j <-> bytes[offset + j].
        _mm_movemask_epi8(matches)
    }

    /// Invokes `callback` with the index of every `<>&"'` byte in
    /// `bytes[offset..]`, in strictly increasing order. Stops at and returns
    /// the first `Err`.
    ///
    /// # Safety
    /// The CPU must support SSSE3, `bytes.len() >= VECTOR_SIZE`, and
    /// `offset <= bytes.len()`.
    #[target_feature(enable = "ssse3")]
    unsafe fn foreach_special_simd<E, F>(
        bytes: &[u8],
        mut offset: usize,
        mut callback: F,
    ) -> Result<(), E>
    where
        F: FnMut(usize) -> Result<(), E>,
    {
        debug_assert!(bytes.len() >= VECTOR_SIZE);
        let upperbound = bytes.len() - VECTOR_SIZE;
        while offset < upperbound {
            let mut mask = compute_mask(bytes, offset);
            while mask != 0 {
                let ix = mask.trailing_zeros();
                callback(offset + ix as usize)?;
                // Clear the lowest set bit.
                mask ^= mask & -mask;
            }
            offset += VECTOR_SIZE;
        }

        // Final chunk: load the last VECTOR_SIZE bytes (aligned to the end of
        // the slice) and shift out the leading lanes that were already
        // scanned, so bit j now refers to bytes[offset + j].
        let mut mask = compute_mask(bytes, upperbound);
        mask >>= offset - upperbound;
        while mask != 0 {
            let ix = mask.trailing_zeros();
            callback(offset + ix as usize)?;
            mask ^= mask & -mask;
        }
        Ok(())
    }

    #[cfg(test)]
    mod html_scan_tests {
        #[test]
        fn multichunk() {
            let mut vec = Vec::new();
            unsafe {
                super::foreach_special_simd("&aXaaaa.a'aa9a<>aab&".as_bytes(), 0, |ix| {
                    #[allow(clippy::unit_arg)]
                    Ok::<_, std::fmt::Error>(vec.push(ix))
                })
                .unwrap();
            }
            assert_eq!(vec, vec![0, 9, 14, 15, 19]);
        }

        // Every byte value, including 0xFF, must match iff it is one of the
        // five special characters.
        #[test]
        fn only_right_bytes_matched() {
            for b in 0..=u8::MAX {
                let right_byte = b == b'&' || b == b'<' || b == b'>' || b == b'"' || b == b'\'';
                let vek = vec![b; super::VECTOR_SIZE];
                let mut match_count = 0;
                unsafe {
                    super::foreach_special_simd(&vek, 0, |_| {
                        match_count += 1;
                        Ok::<_, std::fmt::Error>(())
                    })
                    .unwrap();
                }
                assert!((match_count > 0) == (match_count == super::VECTOR_SIZE));
                assert_eq!(
                    (match_count == super::VECTOR_SIZE),
                    right_byte,
                    "match_count: {}, byte: {:?}",
                    match_count,
                    b as char
                );
            }
        }
    }
}
445
#[cfg(test)]
mod test {
    pub use super::{escape_href, escape_html, escape_html_body_text};

    #[test]
    fn check_href_escape() {
        let mut s = String::new();
        escape_href(&mut s, "&^_").unwrap();
        assert_eq!(s.as_str(), "&amp;^_");
    }

    // Space, `<` and every byte of a non-ASCII character are percent-encoded;
    // `'` becomes a character reference.
    #[test]
    fn check_href_percent_encoding() {
        let mut s = String::new();
        escape_href(&mut s, "a b'<\u{e9}").unwrap();
        assert_eq!(s.as_str(), "a%20b&#x27;%3C%C3%A9");
    }

    #[test]
    fn check_attr_escape() {
        let mut s = String::new();
        escape_html(&mut s, r##"&^"'_"##).unwrap();
        assert_eq!(s.as_str(), "&amp;^&quot;&#39;_");
    }

    #[test]
    fn check_body_escape() {
        let mut s = String::new();
        escape_html_body_text(&mut s, r##"&^"'_"##).unwrap();
        assert_eq!(s.as_str(), r##"&amp;^"'_"##);
    }

    // Inputs longer than one SSE vector take the SIMD path when it is enabled.
    #[test]
    fn check_long_input_escape() {
        let input = r##"<p class="x">Tom & Jerry's</p>"##;

        let mut attr = String::new();
        escape_html(&mut attr, input).unwrap();
        assert_eq!(
            attr.as_str(),
            "&lt;p class=&quot;x&quot;&gt;Tom &amp; Jerry&#39;s&lt;/p&gt;"
        );

        let mut body = String::new();
        escape_html_body_text(&mut body, input).unwrap();
        assert_eq!(
            body.as_str(),
            r##"&lt;p class="x"&gt;Tom &amp; Jerry's&lt;/p&gt;"##
        );
    }
}