1use std::fmt;
2use std::hint;
3use std::io::Write;
4use std::path::PathBuf;
5use std::sync::Arc;
6use std::time::Instant;
7
8use crate::alloc::Vec;
9use crate::cli::{AssetKind, CommandBase, Config, ExitCode, Io, SharedFlags};
10use crate::modules::capture_io::CaptureIo;
11use crate::modules::test::Bencher;
12use crate::runtime::{Function, Unit, Value};
13use crate::support::Result;
14use crate::{Context, Hash, ItemBuf, Sources, Vm};
15
16use super::{Color, Stream};
17
18mod cli {
19 use std::path::PathBuf;
20 use std::vec::Vec;
21
22 use clap::Parser;
23
24 #[derive(Parser, Debug)]
25 #[command(rename_all = "kebab-case")]
26 pub(crate) struct Flags {
27 #[arg(long, default_value = "5.0")]
29 pub(super) warmup: f32,
30 #[arg(long, default_value = "10.0")]
32 pub(super) iter: f32,
33 pub(super) bench_path: Vec<PathBuf>,
35 }
36}
37
38pub(super) use cli::Flags;
39
40impl CommandBase for Flags {
41 #[inline]
42 fn is_workspace(&self, kind: AssetKind) -> bool {
43 matches!(kind, AssetKind::Bench)
44 }
45
46 #[inline]
47 fn describe(&self) -> &str {
48 "Benchmarking"
49 }
50
51 #[inline]
52 fn propagate(&mut self, c: &mut Config, _: &mut SharedFlags) {
53 c.test = true;
54 }
55
56 #[inline]
57 fn paths(&self) -> &[PathBuf] {
58 &self.bench_path
59 }
60}
61
62pub(super) async fn run(
64 io: &mut Io<'_>,
65 args: &Flags,
66 context: &Context,
67 capture_io: Option<&CaptureIo>,
68 unit: Arc<Unit>,
69 sources: &Sources,
70 fns: &[(Hash, ItemBuf)],
71) -> Result<ExitCode> {
72 let runtime = Arc::new(context.runtime()?);
73 let mut vm = Vm::new(runtime, unit);
74
75 if fns.is_empty() {
76 return Ok(ExitCode::Success);
77 }
78
79 io.section("Benching", Stream::Stdout, Color::Highlight)?
80 .append(format_args!(" Found {} benches", fns.len()))?
81 .close()?;
82
83 let mut any_error = false;
84
85 for (hash, item) in fns {
86 let mut bencher = Bencher::default();
87
88 if let Err(error) = vm.call(*hash, (&mut bencher,)) {
89 writeln!(io.stdout, "{}: Error in benchmark", item)?;
90 error.emit(io.stdout, sources)?;
91 any_error = true;
92
93 if let Some(capture_io) = capture_io {
94 if !capture_io.is_empty() {
95 writeln!(io.stdout, "-- output --")?;
96 capture_io.drain_into(&mut *io.stdout)?;
97 writeln!(io.stdout, "-- end output --")?;
98 }
99 }
100
101 continue;
102 }
103
104 let fns = bencher.into_functions();
105
106 let multiple = fns.len() > 1;
107
108 for (i, f) in fns.iter().enumerate() {
109 let out;
110
111 let item: &dyn fmt::Display = if multiple {
112 out = DisplayHash(item, i);
113 &out
114 } else {
115 &item
116 };
117
118 if let Err(e) = bench_fn(io, item, args, f) {
119 writeln!(io.stdout, "{}: Error in bench iteration: {}", item, e)?;
120
121 if let Some(capture_io) = capture_io {
122 if !capture_io.is_empty() {
123 writeln!(io.stdout, "-- output --")?;
124 capture_io.drain_into(&mut *io.stdout)?;
125 writeln!(io.stdout, "-- end output --")?;
126 }
127 }
128
129 any_error = true;
130 }
131 }
132 }
133
134 if any_error {
135 Ok(ExitCode::Failure)
136 } else {
137 Ok(ExitCode::Success)
138 }
139}
140
/// Pairs two displayable values and renders them as `first#second`.
struct DisplayHash<A, B>(A, B);

impl<A: fmt::Display, B: fmt::Display> fmt::Display for DisplayHash<A, B> {
    /// Writes the first value, a `#`, then the second value.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{a}#{b}", a = self.0, b = self.1)
    }
}
153
154fn bench_fn(io: &mut Io<'_>, item: &dyn fmt::Display, args: &Flags, f: &Function) -> Result<()> {
155 let mut section = io.section("Warming up", Stream::Stdout, Color::Progress)?;
156 section.append(format_args!(" {item} for {:.2}s:", args.warmup))?;
157 section.flush()?;
158
159 let start = Instant::now();
160 let mut warmup = 0;
161
162 let elapsed = loop {
163 let value = f.call::<Value>(()).into_result()?;
164 drop(hint::black_box(value));
165 warmup += 1;
166
167 let elapsed = start.elapsed().as_secs_f32();
168
169 if elapsed >= args.warmup {
170 break elapsed;
171 }
172 };
173
174 section
175 .append(format_args!(" {warmup} iters in {elapsed:.2}s"))?
176 .close()?;
177
178 let iterations = (((args.iter * warmup as f32) / args.warmup).round() as usize).max(1);
179 let step = (iterations / 10).max(1);
180 let mut collected = Vec::try_with_capacity(iterations)?;
181
182 let mut section = io.section("Running", Stream::Stdout, Color::Progress)?;
183 section.append(format_args!(
184 " {item} {} iterations for {:.2}s: ",
185 iterations, args.iter
186 ))?;
187
188 let mut added = 0;
189
190 for n in 0..=iterations {
191 if n % step == 0 {
192 section.append(".")?;
193 section.flush()?;
194 added += 1;
195 }
196
197 let start = Instant::now();
198 let value = f.call::<Value>(()).into_result()?;
199 let duration = Instant::now().duration_since(start);
200 collected.try_push(duration.as_nanos() as i128)?;
201 drop(hint::black_box(value));
202 }
203
204 for _ in added..10 {
205 section.append(".")?;
206 section.flush()?;
207 }
208
209 section.close()?;
210
211 collected.sort_unstable();
212
213 let len = collected.len() as f64;
214 let average = collected.iter().copied().sum::<i128>() as f64 / len;
215
216 let variance = collected
217 .iter()
218 .copied()
219 .map(|n| (n as f64 - average).powf(2.0))
220 .sum::<f64>()
221 / len;
222
223 let stddev = variance.sqrt();
224
225 let format = Format {
226 average: average as u128,
227 stddev: stddev as u128,
228 iterations,
229 };
230
231 let mut section = io.section("Result", Stream::Stdout, Color::Highlight)?;
232 section.append(format_args!(" {item}: {format}"))?.close()?;
233 Ok(())
234}
235
236struct Format {
237 average: u128,
238 stddev: u128,
239 iterations: usize,
240}
241
242impl fmt::Display for Format {
243 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
244 write!(
245 f,
246 "mean={:.2}, stddev={:.2}, iterations={}",
247 Time(self.average),
248 Time(self.stddev),
249 self.iterations
250 )
251 }
252}
253
/// A duration in nanoseconds that displays itself in the largest fitting unit.
struct Time(u128);

impl fmt::Display for Time {
    /// Prints seconds, milliseconds or microseconds with three decimals, or a
    /// whole number of nanoseconds for values below one microsecond.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self.0 {
            nanos @ 1_000_000_000.. => write!(f, "{:.3}s", nanos as f64 / 1_000_000_000.0),
            nanos @ 1_000_000.. => write!(f, "{:.3}ms", nanos as f64 / 1_000_000.0),
            nanos @ 1_000.. => write!(f, "{:.3}µs", nanos as f64 / 1_000.0),
            nanos => write!(f, "{}ns", nanos),
        }
    }
}