def mark_step():
  """Issue an XLA step marker on the default device, then save metrics.

  The step marker is invoked with the default device, an empty list and a
  ``wait`` keyword whose value is read through ``xu.getenv_as`` from the
  ``XLA_SYNC_WAIT`` setting (requested as ``bool`` with ``False`` passed as
  the fallback argument).

  After the marker call, ``ms.save_metrics()`` runs only when
  ``is_master_ordinal()`` is true.
  """
  default_device = torch_xla._XLAC._xla_get_default_device()
  sync_wait = xu.getenv_as('XLA_SYNC_WAIT', bool, False)
  torch_xla._XLAC._xla_step_marker(default_device, [], wait=sync_wait)
  # Only emit metrics from the first local device index, to avoid emitting the
  # same values from different threads.
  if not is_master_ordinal():
    return
  ms.save_metrics()
def mark_step():
  """Issue an XLA step marker, save metrics, and run per-step cleanup.

  Steps, in order:

  1. When the ``XLA_EMIT_STEPLOG`` setting (read through ``xu.getenv_as`` as
     ``bool`` with ``False`` passed as the fallback argument) is true, write
     a fixed marker line to ``sys.stderr`` and flush it.
  2. Invoke the step marker with the default device, an empty list and a
     ``wait`` keyword read the same way from ``XLA_SYNC_WAIT``.
  3. Call ``ms.save_metrics()`` only when ``is_master_ordinal()`` is true.
  4. Call ``_run_step_closures()``.
  5. Reset ``_TLS.all_reduce_token`` to ``None``.
  """
  emit_steplog = xu.getenv_as('XLA_EMIT_STEPLOG', bool, False)
  if emit_steplog:
    print('torch_xla.core.xla_model::mark_step', file=sys.stderr, flush=True)

  default_device = torch_xla._XLAC._xla_get_default_device()
  sync_wait = xu.getenv_as('XLA_SYNC_WAIT', bool, False)
  torch_xla._XLAC._xla_step_marker(default_device, [], wait=sync_wait)

  # Only emit metrics from the first local device index, to avoid emitting the
  # same values from different threads.
  if is_master_ordinal():
    ms.save_metrics()

  _run_step_closures()
  # Clear the all-reduce token held on _TLS once the step has been marked.
  _TLS.all_reduce_token = None