torch.profiler¶
Overview¶
PyTorch Profiler is a tool that allows the collection of performance metrics during training and inference. Profiler’s context manager API can be used to better understand what model operators are the most expensive, examine their input shapes and stack traces, study device kernel activity and visualize the execution trace.
Note
An earlier version of the API in the torch.autograd module is considered legacy and will be deprecated.
API Reference¶
- class torch.profiler._KinetoProfile(*, activities=None, record_shapes=False, profile_memory=False, with_stack=False, with_flops=False, with_modules=False, experimental_config=None)[source]¶
Low-level profiler that wraps the autograd profiler.
- Parameters
activities (iterable) – list of activity groups (CPU, CUDA) to use in profiling. Supported values: torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA. Default value: ProfilerActivity.CPU and (when available) ProfilerActivity.CUDA.
record_shapes (bool) – save information about operator’s input shapes.
profile_memory (bool) – track tensor memory allocation/deallocation (see export_memory_timeline for more details).
with_stack (bool) – record source information (file and line number) for the ops.
with_flops (bool) – use a formula to estimate the FLOPs of specific operators (matrix multiplication and 2D convolution).
with_modules (bool) – record the module hierarchy (including function names) corresponding to the op’s callstack. For example, if module A’s forward calls module B’s forward, which contains an aten::add op, then aten::add’s module hierarchy is A.B. Note that this support exists, at the moment, only for TorchScript models and not eager mode models.
experimental_config (_ExperimentalConfig) – a set of experimental options used by profiler libraries like Kineto. Note, backward compatibility is not guaranteed.
Note
This API is experimental and subject to change in the future.
Enabling shape and stack tracing results in additional overhead. When record_shapes=True is specified, the profiler will temporarily hold references to the tensors; this may prevent certain optimizations that depend on the reference count and may introduce extra tensor copies.
- add_metadata(key, value)[source]¶
Adds user-defined metadata with a string key and a string value to the trace file.
- add_metadata_json(key, value)[source]¶
Adds user-defined metadata with a string key and a valid JSON value to the trace file.
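For example, a minimal sketch of attaching run-level metadata to a trace (the workload, key names, and output path here are illustrative):

import json
import torch
from torch.profiler import profile, ProfilerActivity

with profile(activities=[ProfilerActivity.CPU]) as prof:
    torch.mm(torch.randn(64, 64), torch.randn(64, 64))
    # the plain-string value is stored verbatim; the JSON value must be valid JSON text
    prof.add_metadata("run_name", "baseline")
    prof.add_metadata_json("config", json.dumps({"batch_size": 32}))

prof.export_chrome_trace("trace_with_metadata.json")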
- events()[source]¶
Returns the list of unaggregated profiler events, to be used in the trace callback or after profiling is finished.
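For instance, a small sketch iterating over the recorded events after profiling (the workload is illustrative; attribute access assumes the FunctionEvent objects returned here expose name and cpu_time_total):

import torch
from torch.profiler import profile, ProfilerActivity

with profile(activities=[ProfilerActivity.CPU]) as prof:
    torch.mm(torch.randn(64, 64), torch.randn(64, 64))

# each event is an unaggregated FunctionEvent; times are in microseconds
for evt in prof.events():
    print(evt.name, evt.cpu_time_total)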
- export_memory_timeline(path, device=None)[source]¶
Export memory event information from the profiler-collected tree for a given device, and export a timeline plot. export_memory_timeline can produce three kinds of output files, selected by the suffix of path:

For an HTML-compatible plot, use the suffix .html; the memory timeline plot will be embedded as a PNG file in the HTML file.
For plot points of the form [times, [sizes by category]], where times are timestamps and sizes are memory usage per category, use the suffix .json or .json.gz; the memory timeline will be saved as JSON or gzipped JSON depending on the suffix.
For raw memory points, use the suffix .raw.json.gz. Each raw memory event consists of (timestamp, action, numbytes, category), where action is one of [PREEXISTING, CREATE, INCREMENT_VERSION, DESTROY], and category is one of the enums from torch.profiler._memory_profiler.Category.
Output: Memory timeline written as gzipped JSON, JSON, or HTML.
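For example, a sketch of exporting an HTML memory timeline. train_step() stands in for real workload code, and the sketch assumes profile_memory, record_shapes and with_stack are all enabled so that memory events can be categorized:

import torch
from torch.profiler import profile, ProfilerActivity

def train_step():  # placeholder for real workload code
    a = torch.randn(1024, 1024, device="cuda")
    b = a @ a

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    profile_memory=True,
    record_shapes=True,
    with_stack=True,
) as prof:
    train_step()

# the .html suffix selects the embedded-PNG plot described above
prof.export_memory_timeline("memory_timeline.html", device="cuda:0")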
- class torch.profiler.profile(*, activities=None, schedule=None, on_trace_ready=None, record_shapes=False, profile_memory=False, with_stack=False, with_flops=False, with_modules=False, experimental_config=None, use_cuda=None)[source]¶
Profiler context manager.
- Parameters
activities (iterable) – list of activity groups (CPU, CUDA) to use in profiling. Supported values: torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA. Default value: ProfilerActivity.CPU and (when available) ProfilerActivity.CUDA.
schedule (Callable) – callable that takes step (int) as a single parameter and returns a ProfilerAction value that specifies the profiler action to perform at each step.
on_trace_ready (Callable) – callable that is called at each step when schedule returns ProfilerAction.RECORD_AND_SAVE during profiling.
record_shapes (bool) – save information about operator’s input shapes.
profile_memory (bool) – track tensor memory allocation/deallocation.
with_stack (bool) – record source information (file and line number) for the ops.
with_flops (bool) – use a formula to estimate the FLOPs (floating point operations) of specific operators (matrix multiplication and 2D convolution).
with_modules (bool) – record the module hierarchy (including function names) corresponding to the op’s callstack. For example, if module A’s forward calls module B’s forward, which contains an aten::add op, then aten::add’s module hierarchy is A.B. Note that this support exists, at the moment, only for TorchScript models and not eager mode models.
experimental_config (_ExperimentalConfig) – a set of experimental options used for Kineto library features. Note, backward compatibility is not guaranteed.
use_cuda (bool) – Deprecated since version 1.8.1: use activities instead.
Note
Use schedule() to generate the callable schedule. Non-default schedules are useful when profiling long training jobs and allow the user to obtain multiple traces at different iterations of the training process. The default schedule simply records all the events continuously for the duration of the context manager.

Note

Use tensorboard_trace_handler() to generate result files for TensorBoard:

on_trace_ready=torch.profiler.tensorboard_trace_handler(dir_name)

After profiling, result files can be found in the specified directory. Use the command tensorboard --logdir dir_name to see the results in TensorBoard. For more information, see the PyTorch Profiler TensorBoard Plugin.
Note
Enabling shape and stack tracing results in additional overhead. When record_shapes=True is specified, the profiler will temporarily hold references to the tensors; this may prevent certain optimizations that depend on the reference count and may introduce extra tensor copies.
Examples:
with torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ]
) as p:
    code_to_profile()
print(p.key_averages().table(
    sort_by="self_cuda_time_total", row_limit=-1))
Using the profiler’s schedule, on_trace_ready and step functions:

# Non-default profiler schedule allows user to turn profiler on and off
# on different iterations of the training loop;
# trace_handler is called every time a new trace becomes available
def trace_handler(prof):
    print(prof.key_averages().table(
        sort_by="self_cuda_time_total", row_limit=-1))
    # prof.export_chrome_trace("/tmp/test_trace_" + str(prof.step_num) + ".json")

with torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ],

    # In this example with wait=1, warmup=1, active=2, repeat=1,
    # profiler will skip the first step/iteration,
    # start warming up on the second, record
    # the third and the fourth iterations,
    # after which the trace will become available
    # and on_trace_ready (when set) is called;
    # the cycle repeats starting with the next step

    schedule=torch.profiler.schedule(
        wait=1, warmup=1, active=2, repeat=1),
    on_trace_ready=trace_handler
    # on_trace_ready=torch.profiler.tensorboard_trace_handler('./log')
    # used when outputting for tensorboard
) as p:
    for iter in range(N):
        code_iteration_to_profile(iter)
        # send a signal to the profiler that the next iteration has started
        p.step()
- class torch.profiler.ProfilerAction(value)[source]¶
Profiler actions that can be taken at the specified intervals
- torch.profiler.schedule(*, wait, warmup, active, repeat=0, skip_first=0)[source]¶
Returns a callable that can be used as the profiler schedule argument. The profiler will skip the first skip_first steps, then wait for wait steps, then do the warmup for the next warmup steps, then do the active recording for the next active steps, and then repeat the cycle starting with wait steps. The optional number of cycles is specified with the repeat parameter; a value of zero means that the cycles will continue until profiling is finished.
- Return type
Callable
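As an illustration, the callable returned by schedule() can be invoked directly to see which ProfilerAction each step maps to (the printed actions follow the cycle described above):

from torch.profiler import schedule

# skip_first=1: ignore step 0; then each cycle is wait=1, warmup=1, active=2
sched = schedule(skip_first=1, wait=1, warmup=1, active=2, repeat=1)
for step in range(6):
    print(step, sched(step))
# 0 NONE (skipped), 1 NONE (wait), 2 WARMUP, 3 RECORD,
# 4 RECORD_AND_SAVE (last active step), 5 NONE (repeat limit reached)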
- torch.profiler.tensorboard_trace_handler(dir_name, worker_name=None, use_gzip=False)[source]¶
Outputs tracing files to the directory dir_name, which can then be delivered directly to TensorBoard as its logdir. In a distributed scenario, worker_name should be unique for each worker; it defaults to ‘[hostname]_[pid]’.
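For example, a minimal sketch wiring the handler into a profiling run (train_step() is a placeholder for real workload code, and the worker name is illustrative):

import torch
from torch.profiler import (profile, schedule,
                            tensorboard_trace_handler, ProfilerActivity)

def train_step():  # placeholder for real workload code
    torch.mm(torch.randn(64, 64), torch.randn(64, 64))

with profile(
    activities=[ProfilerActivity.CPU],
    schedule=schedule(wait=1, warmup=1, active=2),
    on_trace_ready=tensorboard_trace_handler("./log", worker_name="worker0"),
) as prof:
    for _ in range(4):
        train_step()
        prof.step()  # trace is saved when the last active step completes
# view with: tensorboard --logdir ./log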
Intel Instrumentation and Tracing Technology APIs¶
- torch.profiler.itt.mark(msg)[source]¶
Describe an instantaneous event that occurred at some point.
- Parameters
msg (str) – ASCII message to associate with the event.
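For example, a sketch that emits an ITT mark around a computation. The event is only visible when the process runs under an Intel collector such as VTune; wrapping the code in torch.autograd.profiler.emit_itt() additionally annotates each op with ITT ranges:

import torch

# ITT events are visible only when running under an Intel collector
with torch.autograd.profiler.emit_itt():
    torch.profiler.itt.mark("start of matmul")  # instantaneous event
    torch.mm(torch.randn(32, 32), torch.randn(32, 32))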