def train_loop_fn(loader, epoch):
    tracker = xm.RateTracker()
    model.train()
    for step, (data, target) in enumerate(loader):
        optimizer.zero_grad()
        with autocast():
            output = model(data)
            loss = loss_fn(output, target)

        scaler.scale(loss).backward()
        gradients = xm._fetch_gradients(optimizer)
        xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size())
        scaler.step(optimizer)
        scaler.update()
        xm.mark_step()
        tracker.add(FLAGS.batch_size)
        if lr_scheduler:
            lr_scheduler.step()

        # ru_maxrss reports peak resident memory (KB on Linux), not CPU time.
        import resource
        print(f" Peak RSS after step: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss}")

        if step % FLAGS.log_steps == 0:
            # _train_update(device, step, loss, tracker, epoch, writer)
            xm.add_step_closure(
                _train_update, args=(device, step, loss, tracker, epoch, writer)
            )
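
This loop assumes surrounding setup that the snippet doesn't show: model, optimizer, loss_fn, scaler, device, FLAGS, writer, and lr_scheduler all come from the enclosing training script. A minimal sketch of that setup, assuming torch_xla's AMP helpers (MyModel and the SGD hyperparameters here are placeholders, not from the source):

import torch
import torch_xla.core.xla_model as xm
from torch_xla.amp import autocast, GradScaler  # XLA-aware AMP context and scaler

device = xm.xla_device()
model = MyModel().to(device)  # MyModel is a placeholder for the real network
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = torch.nn.CrossEntropyLoss()
scaler = GradScaler()  # scales the loss so fp16 gradients don't underflow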
Example #2
def train_loop_fn(loader, epoch):
    if FLAGS.fine_grained_metrics:
        epoch_start_time = time.time()
        step_latency_tracker, bwd_latency_tracker, fwd_latency_tracker = [], [], []
    else:
        tracker = xm.RateTracker()
    model.train()
    for step, (data, target) in enumerate(loader):
        if FLAGS.fine_grained_metrics:
            step_start_time = time.time()
        optimizer.zero_grad()
        if FLAGS.fine_grained_metrics:
            fwd_start_time = time.time()
        with autocast():
            output = model(data)
            loss = loss_fn(output, target)
        if FLAGS.fine_grained_metrics:
            fwd_end_time = time.time()
            fwd_latency = fwd_end_time - fwd_start_time

            bwd_start_time = time.time()
        scaler.scale(loss).backward()
        gradients = xm._fetch_gradients(optimizer)
        xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size())
        scaler.step(optimizer)
        scaler.update()
        xm.mark_step()
        if lr_scheduler:
            lr_scheduler.step()
        if FLAGS.fine_grained_metrics:
            bwd_end_time = time.time()
            bwd_latency = bwd_end_time - bwd_start_time

            step_latency = bwd_end_time - step_start_time
            step_latency_tracker.append(step_latency)
            bwd_latency_tracker.append(bwd_latency)
            fwd_latency_tracker.append(fwd_latency)
        else:
            tracker.add(FLAGS.batch_size)
        if step % FLAGS.log_steps == 0:
            if FLAGS.fine_grained_metrics:
                print('FineGrainedMetrics :: Epoch={} Step={} Rate(DataPoints/s)[p50]={:.1f} '
                      'BatchSize={} Step(s/Batch)[p50]={:.2f} Fwd(s/Batch)[p50]={:.4f} '
                      'Bwd(s/Batch)[p50]={:.4f}'.format(
                          epoch, step, FLAGS.batch_size / p50(step_latency_tracker),
                          FLAGS.batch_size, p50(step_latency_tracker),
                          p50(fwd_latency_tracker), p50(bwd_latency_tracker)))
            else:
                # _train_update(device, step, loss, tracker, epoch, writer)
                xm.add_step_closure(_train_update,
                                    args=(device, step, loss, tracker,
                                          epoch, writer))
    if FLAGS.fine_grained_metrics:
        epoch_end_time = time.time()
        epoch_latency = epoch_end_time - epoch_start_time
        print('FineGrainedMetrics :: Epoch={} Epoch(s)={:.2f} Rate(DataPoints/s)[p50]={:.1f} '
              'BatchSize={} Step(s/Batch)[p50]={:.2f} Fwd(s/Batch)[p50]={:.4f} '
              'Bwd(s/Batch)[p50]={:.4f}'.format(
                  epoch, epoch_latency, FLAGS.batch_size / p50(step_latency_tracker),
                  FLAGS.batch_size, p50(step_latency_tracker),
                  p50(fwd_latency_tracker), p50(bwd_latency_tracker)))
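
The p50 helper is referenced but never defined in this snippet. A plausible stand-in (an assumption, not from the source) is simply the median of the recorded latencies:

import statistics

def p50(samples):
    # 50th percentile (median) of the per-step latencies, in seconds.
    return statistics.median(samples)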
Example #3
def train_loop_fn(loader):
    tracker = xm.RateTracker()
    model.train()
    for step, (data, target) in enumerate(loader):
        optimizer.zero_grad()
        with autocast():
            output = model(data)
            loss = loss_fn(output, target)
        scaler.scale(loss).backward()
        gradients = xm._fetch_gradients(optimizer)
        xm.all_reduce('sum', gradients, scale=1.0 / xm.xrt_world_size())
        scaler.step(optimizer)
        scaler.update()
        tracker.add(flags.batch_size)
        if step % flags.log_steps == 0:
            xm.add_step_closure(_train_update,
                                args=(device, step, loss, tracker, writer))
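
Unlike the earlier examples, this variant never calls xm.mark_step(). That is typically safe when the loader is wrapped in torch_xla's parallel loader, which cuts the lazy graph after every batch. A sketch of that wrapping (train_loader is a plain PyTorch DataLoader, assumed here):

import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl

device = xm.xla_device()
# MpDeviceLoader moves each batch to the XLA device and issues the
# equivalent of xm.mark_step() per iteration, so the loop body can omit it.
train_device_loader = pl.MpDeviceLoader(train_loader, device)
train_loop_fn(train_device_loader)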
Example #4
File: train.py, Project: codeislife99/xla
def loop_with_amp(model, input, positions, target, causal_mask, optimizer,
                  xla_enabled, autocast, scaler):
    with autocast():
        loss = model(input, positions, target, batch_mask=causal_mask)
    if xla_enabled:
        scaler.scale(loss).backward()
        # Reduce the still-scaled gradients across replicas before
        # scaler.step() unscales them and checks for infs/NaNs.
        gradients = xm._fetch_gradients(optimizer)
        xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size())
        scaler.step(optimizer)
        scaler.update()
        xm.mark_step()
    else:
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

    return loss
Example #5
def loop_with_amp(model, input_ids, attention_mask, labels, optim, xla_enabled, autocast, scaler):
    with autocast():
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs[0]
    
    if xla_enabled:
        scaler.scale(loss).backward()
        gradients = xm._fetch_gradients(optim)
        xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size())
        scaler.step(optim)
        scaler.update()
        xm.mark_step()
    else:
        scaler.scale(loss).backward()
        scaler.step(optim)
        scaler.update()

    return loss, optim
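
Both loop_with_amp variants receive autocast and scaler from the caller, so the caller decides which AMP implementation matches the backend. One plausible wiring (a sketch; get_autocast_and_scaler is a hypothetical helper, and newer torch_xla releases expect the XLA device as autocast's first argument):

def get_autocast_and_scaler(xla_enabled):
    # Hypothetical helper: pick the AMP context manager and gradient
    # scaler for the active backend.
    if xla_enabled:
        from torch_xla.amp import autocast, GradScaler
        return autocast, GradScaler()
    from torch.cuda.amp import autocast, GradScaler
    return autocast, GradScaler()

autocast, scaler = get_autocast_and_scaler(xla_enabled)
loss, optim = loop_with_amp(model, input_ids, attention_mask, labels,
                            optim, xla_enabled, autocast, scaler)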
Example #6
    def _run_autocast_outofplace(self,
                                 op,
                                 args,
                                 run_as_type,
                                 out_type=None,
                                 module=torch,
                                 add_kwargs=None):
        # helper to cast args
        def cast(val, to_type):
            if isinstance(val, torch.Tensor):
                return val.to(to_type) if val.is_floating_point() else val
            elif isinstance(val, collections.abc.Iterable):
                return type(val)(cast(v, to_type) for v in val)
            else:
                return val

        if add_kwargs is None:
            add_kwargs = {}

        self.assertFalse(torch.is_autocast_enabled())
        with autocast():
            self.assertTrue(torch.is_autocast_enabled())

            out_type = out_type if out_type is not None else run_as_type
            output = output_method = None

            # Try module.* variant, if requested:
            if module is not None and hasattr(module, op):
                output = getattr(module, op)(*args, **add_kwargs)
                if isinstance(output, torch.Tensor):
                    self.assertTrue(
                        out_type == output.dtype,
                        "autocast for torch.{} produced {}, should produce {}".
                        format(op, output.dtype, out_type))

            # Try Tensor.* variant:
            if hasattr(torch.Tensor, op):
                output_method = getattr(args[0], op)(*args[1:], **add_kwargs)
                if isinstance(output_method, torch.Tensor):
                    self.assertTrue(
                        out_type == output_method.dtype,
                        "autocast for torch.{} produced {}, should produce torch.{}"
                        .format(op, output_method.dtype, out_type))

            self.assertTrue((output is not None) or (
                output_method is not None
            ), "{} not found as an attribute on either Tensor or the requested module {}"
                            .format(op, module))

            # Accounts for ops that return Tensors, iterables, and other non-Tensors.
            # For example, lstm_cell returns a tuple and equal returns bool.
            def compare(first, second):
                if isinstance(first, torch.Tensor):
                    return torch.equal(first, second)
                elif isinstance(first, collections.abc.Iterable):
                    return all(compare(f, s) for f, s in zip(first, second))
                else:
                    return first == second

            # If both torch.* and Tensor.* variants were found, check outputs are identical
            if (output is not None) and (output_method is not None):
                self.assertTrue(type(output) == type(output_method))
                comparison = compare(output, output_method)
                self.assertTrue(
                    comparison,
                    "torch.{0} result did not match Tensor.{0} result".format(
                        op))

            # Compare numerics to Python-side "autocasting" that (we expect) does the same thing
            # as the C++-side autocasting, and should be bitwise accurate.
            output_to_compare = output if output is not None else output_method
            with autocast(enabled=False):
                self.assertFalse(torch.is_autocast_enabled())

                if module is not None and hasattr(module, op):
                    control = getattr(module, op)(*cast(args, run_as_type),
                                                  **add_kwargs)
                else:
                    control = getattr(args[0].to(run_as_type),
                                      op)(*cast(args[1:], run_as_type),
                                          **add_kwargs)
                self.assertTrue(type(output_to_compare) == type(control))
                comparison = compare(output_to_compare, control)
                self.assertTrue(
                    comparison,
                    "torch.{} result did not match control".format(op))
            self.assertTrue(torch.is_autocast_enabled())
        self.assertFalse(torch.is_autocast_enabled())
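
As a usage sketch, a test built on this helper might assert that a matmul runs in float16 under autocast (a hypothetical case; the device and the op's autocast behavior depend on the backend under test):

def test_autocast_mm(self):
    a = torch.randn(8, 8, device='cuda')
    b = torch.randn(8, 8, device='cuda')
    # torch.mm is on autocast's lower-precision list, so both the
    # torch.mm and Tensor.mm variants should produce torch.float16.
    self._run_autocast_outofplace('mm', (a, b), torch.float16)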