# speed_benchmark.py (716 B)

  1. import argparse
  2. import timeit
  3. from typing import Callable
  4. from utils import inference
  5. parser = argparse.ArgumentParser(description="Speed benchmark")
  6. parser.add_argument(
  7. "--repeat",
  8. type=int,
  9. default=3,
  10. help="Times an experiment will be run.",
  11. )
  12. args = parser.parse_args()
  13. def measure_speed(func: Callable[[], None]):
  14. # as written in https://docs.python.org/3/library/timeit.html#timeit.Timer.repeat,
  15. # min should be taken rather than the average
  16. runtimes = timeit.repeat(
  17. func,
  18. repeat=args.repeat,
  19. number=10,
  20. )
  21. print(runtimes)
  22. print("Min execution time: %.3fs" % (min(runtimes) / 10.0))
  23. if __name__ == "__main__":
  24. measure_speed(inference)