def train_loop_fn(loader, epoch):
    """Run one training epoch over ``loader`` with AMP and all-reduced gradients.

    Relies on enclosing/module scope for ``model``, ``optimizer``, ``loss_fn``,
    ``autocast``, ``scaler``, ``lr_scheduler``, ``device``, ``writer``,
    ``FLAGS`` and ``_train_update``.

    Args:
        loader: iterable yielding ``(data, target)`` batches.
        epoch: current epoch number, forwarded to ``_train_update``.
    """
    # Hoisted out of the loop (it used to be re-imported every step).
    # ``resource`` is Unix-only.
    import resource

    tracker = xm.RateTracker()
    model.train()
    for step, (data, target) in enumerate(loader):
        optimizer.zero_grad()
        with autocast():
            output = model(data)
            loss = loss_fn(output, target)
        scaler.scale(loss).backward()
        # Sum gradients across replicas, scaled by 1/world_size, before the step.
        gradients = xm._fetch_gradients(optimizer)
        xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size())
        scaler.step(optimizer)
        scaler.update()
        xm.mark_step()
        tracker.add(FLAGS.batch_size)
        if lr_scheduler:
            lr_scheduler.step()
        if step % FLAGS.log_steps == 0:
            # ru_maxrss is the process's peak resident set size (a memory
            # figure, platform-dependent units; kilobytes on Linux), not CPU usage.
            print(f" Peak RSS (ru_maxrss): {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss}")
            xm.add_step_closure(
                _train_update, args=(device, step, loss, tracker, epoch, writer)
            )
def train_loop_fn(loader, epoch):
    """Run one training epoch over ``loader`` with AMP and all-reduced gradients.

    When ``FLAGS.fine_grained_metrics`` is set, wall-clock latencies of each
    step and of its forward / post-forward portions are recorded. Their medians
    (via ``p50``) are printed every ``FLAGS.log_steps`` steps and once at the
    end of the epoch. Otherwise throughput is tracked with ``xm.RateTracker``
    and reported via ``_train_update`` as an XLA step closure.

    Relies on enclosing/module scope for ``model``, ``optimizer``, ``loss_fn``,
    ``autocast``, ``scaler``, ``lr_scheduler``, ``device``, ``writer``,
    ``FLAGS``, ``p50`` and ``_train_update``.

    Args:
        loader: iterable yielding ``(data, target)`` batches.
        epoch: current epoch number, used in log output.
    """
    if FLAGS.fine_grained_metrics:
        epoch_start_time = time.time()
        step_latency_tracker, bwd_latency_tracker, fwd_latency_tracker = [], [], []
    else:
        tracker = xm.RateTracker()

    def _latency_summary():
        # Shared tail of both FineGrainedMetrics lines. The Fwd/Bwd values
        # must be passed in the same order as their labels.
        step_p50 = p50(step_latency_tracker)
        return (
            'Rate(DataPoints/s)[p50]={:.1f} BatchSize={} Step(s/Batch)[p50]={:.2f} '
            'Fwd(s/Batch)[p50]={:.4f} Bwd(s/Batch)[p50]={:.4f}'
        ).format(
            FLAGS.batch_size / step_p50,
            FLAGS.batch_size,
            step_p50,
            p50(fwd_latency_tracker),
            p50(bwd_latency_tracker),
        )

    model.train()
    for step, (data, target) in enumerate(loader):
        if FLAGS.fine_grained_metrics:
            step_start_time = time.time()
        optimizer.zero_grad()
        if FLAGS.fine_grained_metrics:
            fwd_start_time = time.time()
        with autocast():
            output = model(data)
            loss = loss_fn(output, target)
        if FLAGS.fine_grained_metrics:
            # NOTE(review): XLA tensors are presumably traced lazily, so this
            # mostly measures graph construction. Device execution likely
            # happens at xm.mark_step(), inside the Bwd window below. Confirm.
            fwd_latency = time.time() - fwd_start_time
            bwd_start_time = time.time()
        scaler.scale(loss).backward()
        # Sum gradients across replicas, scaled by 1/world_size, before the step.
        gradients = xm._fetch_gradients(optimizer)
        xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size())
        scaler.step(optimizer)
        scaler.update()
        xm.mark_step()
        if lr_scheduler:
            lr_scheduler.step()
        if FLAGS.fine_grained_metrics:
            # "Bwd" covers everything after the forward pass: backward,
            # all-reduce, optimizer step, mark_step and the LR scheduler step.
            bwd_end_time = time.time()
            step_latency_tracker.append(bwd_end_time - step_start_time)
            bwd_latency_tracker.append(bwd_end_time - bwd_start_time)
            fwd_latency_tracker.append(fwd_latency)
        else:
            tracker.add(FLAGS.batch_size)
        if step % FLAGS.log_steps == 0:
            if FLAGS.fine_grained_metrics:
                print('FineGrainedMetrics :: Epoch={} Step={} {}'.format(
                    epoch, step, _latency_summary()))
            else:
                xm.add_step_closure(
                    _train_update, args=(device, step, loss, tracker, epoch, writer))
    if FLAGS.fine_grained_metrics:
        epoch_latency = time.time() - epoch_start_time
        # '{:.}' (missing precision) raised ValueError here; use '{:.2f}'.
        print('FineGrainedMetrics :: Epoch={} Epoch(s)={:.2f} {}'.format(
            epoch, epoch_latency, _latency_summary()))
def train_loop_fn(loader):
    """Train ``model`` for one pass over ``loader`` with AMP and replica-averaged grads.

    Each batch is processed in this order:
    - forward under ``autocast()``;
    - scaled backward via ``scaler``;
    - gradient all-reduce (sum, scaled by 1/world_size);
    - ``scaler.step`` / ``scaler.update``.

    Every ``flags.log_steps`` batches, ``_train_update`` is queued as an XLA
    step closure.

    NOTE(review): unlike similar loops, this one never calls xm.mark_step().
    Presumably the loader triggers step boundaries. Confirm.

    Relies on enclosing/module scope for ``model``, ``optimizer``, ``loss_fn``,
    ``autocast``, ``scaler``, ``flags``, ``device``, ``writer`` and
    ``_train_update``.
    """
    rate_tracker = xm.RateTracker()
    model.train()
    for batch_idx, batch in enumerate(loader):
        inputs, labels = batch
        optimizer.zero_grad()
        with autocast():
            batch_loss = loss_fn(model(inputs), labels)
        scaler.scale(batch_loss).backward()
        xm.all_reduce('sum', xm._fetch_gradients(optimizer),
                      scale=1.0 / xm.xrt_world_size())
        scaler.step(optimizer)
        scaler.update()
        rate_tracker.add(flags.batch_size)
        if batch_idx % flags.log_steps != 0:
            continue
        xm.add_step_closure(
            _train_update,
            args=(device, batch_idx, batch_loss, rate_tracker, writer),
        )
def loop_with_amp(model, input, positions, target, causal_mask, optimizer, xla_enabled, autocast, scaler):
    """Run one mixed-precision training step and return the loss.

    The forward pass runs under ``autocast()`` and the loss is backpropagated
    through ``scaler``. When ``xla_enabled`` is truthy, gradients are summed
    across replicas (scaled by 1/world_size) before the optimizer step, and
    ``xm.mark_step()`` is issued after ``scaler.update()``.

    Returns:
        The loss returned by ``model``.
    """
    with autocast():
        step_loss = model(input, positions, target, batch_mask=causal_mask)

    scaler.scale(step_loss).backward()
    if xla_enabled:
        grads = xm._fetch_gradients(optimizer)
        xm.all_reduce("sum", grads, scale=1.0 / xm.xrt_world_size())
    scaler.step(optimizer)
    scaler.update()
    if xla_enabled:
        xm.mark_step()
    return step_loss
def loop_with_amp(model, input_ids, attention_mask, labels, optim, xla_enabled, autocast, scaler):
    """Run one mixed-precision training step for a model that returns its loss first.

    ``model`` is called with ``labels`` under ``autocast()``. Element 0 of its
    output is used as the loss and is backpropagated through ``scaler``. When
    ``xla_enabled`` is truthy, gradients are summed across replicas (scaled by
    1/world_size) before the optimizer step, and ``xm.mark_step()`` follows
    ``scaler.update()``.

    Returns:
        ``(loss, optim)`` — the loss tensor and the same optimizer object.
    """
    with autocast():
        model_out = model(input_ids, attention_mask=attention_mask, labels=labels)
        step_loss = model_out[0]

    scaler.scale(step_loss).backward()
    if xla_enabled:
        grads = xm._fetch_gradients(optim)
        xm.all_reduce("sum", grads, scale=1.0 / xm.xrt_world_size())
    scaler.step(optim)
    scaler.update()
    if xla_enabled:
        xm.mark_step()
    return step_loss, optim
def _run_autocast_outofplace(self, op, args, run_as_type, out_type=None, module=torch, add_kwargs=None):
    """Check that ``op`` under autocast matches an explicitly-cast control run.

    The op is looked up on ``module`` (if given) and on ``torch.Tensor``. Each
    variant found is run under ``autocast()``, and its Tensor output dtype is
    asserted to equal ``out_type``. If both variants exist, their outputs must
    be identical. Finally, the output is compared with a "control" computed
    with autocast disabled and floating-point args cast to ``run_as_type``.

    Args:
        op: name of the op to look up on ``module`` and/or ``torch.Tensor``.
        args: positional args. ``args[0]`` is the receiver for the
            ``Tensor.<op>`` variant.
        run_as_type: dtype that floating-point tensor args are cast to for
            the control run.
        out_type: expected output dtype under autocast. Defaults to
            ``run_as_type``.
        module: namespace to look ``op`` up in; ``None`` skips that variant.
        add_kwargs: extra keyword args passed to every call.

    Autocast must be disabled on entry; this is asserted on entry and on exit.
    """
    # helper to cast args
    def cast(val, to_type):
        if isinstance(val, torch.Tensor):
            return val.to(to_type) if val.is_floating_point() else val
        elif isinstance(val, (str, bytes)):
            # Iterable, but rebuilding via type(val)(generator) would yield
            # "<generator object ...>" for str; pass through unchanged.
            return val
        elif isinstance(val, collections.abc.Iterable):
            return type(val)(cast(v, to_type) for v in val)
        else:
            return val

    if add_kwargs is None:
        add_kwargs = {}

    self.assertFalse(torch.is_autocast_enabled())
    with autocast():
        self.assertTrue(torch.is_autocast_enabled())

        out_type = out_type if out_type is not None else run_as_type
        output = output_method = None

        # Try module.* variant, if requested:
        if module is not None and hasattr(module, op):
            output = getattr(module, op)(*args, **add_kwargs)
            if isinstance(output, torch.Tensor):
                self.assertTrue(
                    out_type == output.dtype,
                    "autocast for torch.{} produced {}, should produce {}".format(
                        op, output.dtype, out_type))

        # Try Tensor.* variant:
        if hasattr(torch.Tensor, op):
            output_method = getattr(args[0], op)(*args[1:], **add_kwargs)
            if isinstance(output_method, torch.Tensor):
                # str(dtype) already includes the "torch." prefix.
                self.assertTrue(
                    out_type == output_method.dtype,
                    "autocast for Tensor.{} produced {}, should produce {}".format(
                        op, output_method.dtype, out_type))

        self.assertTrue(
            (output is not None) or (output_method is not None),
            "{} not found as an attribute on either Tensor or the requested module {}".format(
                op, module))

        # Accounts for ops that return Tensors, iterables, and other non-Tensors.
        # For example, lstm_cell returns a tuple and equal returns bool.
        def compare(first, second):
            if isinstance(first, torch.Tensor):
                return torch.equal(first, second)
            elif isinstance(first, (str, bytes)):
                # A 1-char str iterates to itself, so the Iterable branch would
                # recurse forever.
                return first == second
            elif isinstance(first, collections.abc.Iterable):
                first, second = tuple(first), tuple(second)
                # zip() truncates silently, so check lengths explicitly.
                return len(first) == len(second) and all(
                    compare(f, s) for f, s in zip(first, second))
            else:
                return first == second

        # If both torch.* and Tensor.* variants were found, check outputs are identical
        if (output is not None) and (output_method is not None):
            self.assertTrue(
                type(output) is type(output_method),
                "torch.{} returned {}, Tensor.{} returned {}".format(
                    op, type(output), op, type(output_method)))
            comparison = compare(output, output_method)
            self.assertTrue(
                comparison,
                "torch.{0} result did not match Tensor.{0} result".format(op))

        # Compare numerics to Python-side "autocasting" that (we expect) does the same thing
        # as the C++-side autocasting, and should be bitwise accurate.
        output_to_compare = output if output is not None else output_method
        with autocast(enabled=False):
            self.assertFalse(torch.is_autocast_enabled())

            if module is not None and hasattr(module, op):
                control = getattr(module, op)(*cast(args, run_as_type), **add_kwargs)
            else:
                control = getattr(args[0].to(run_as_type), op)(
                    *cast(args[1:], run_as_type), **add_kwargs)
            self.assertTrue(
                type(output_to_compare) is type(control),
                "autocast output type {} differs from control type {} for {}".format(
                    type(output_to_compare), type(control), op))
            comparison = compare(output_to_compare, control)
            self.assertTrue(
                comparison, "torch.{} result did not match control".format(op))
        self.assertTrue(torch.is_autocast_enabled())
    self.assertFalse(torch.is_autocast_enabled())