Example #1
 # Per-device training loop for the driver-style API; FLAGS, schedulers, writer
 # and num_training_steps_per_epoch come from the enclosing training script.
 def train_loop_fn(model, loader, device, context):
     loss_fn = nn.CrossEntropyLoss()
     optimizer = context.getattr_or(
         'optimizer', lambda: optim.SGD(model.parameters(),
                                        lr=FLAGS.lr,
                                        momentum=FLAGS.momentum,
                                        weight_decay=5e-4))
     lr_scheduler = context.getattr_or(
         'lr_scheduler', lambda: schedulers.wrap_optimizer_with_scheduler(
             optimizer,
             scheduler_type=getattr(FLAGS, 'lr_scheduler_type', None),
             scheduler_divisor=getattr(FLAGS, 'lr_scheduler_divisor', None),
             scheduler_divide_every_n_epochs=getattr(
                 FLAGS, 'lr_scheduler_divide_every_n_epochs', None),
             num_steps_per_epoch=num_training_steps_per_epoch,
             summary_writer=writer if xm.is_master_ordinal() else None))
     tracker = xm.RateTracker()
     model.train()
     for x, (data, target) in loader:
         optimizer.zero_grad()
         output = model(data)
         loss = loss_fn(output, target)
         loss.backward()
         xm.optimizer_step(optimizer)
         tracker.add(FLAGS.batch_size)
         if x % FLAGS.log_steps == 0:
             test_utils.print_training_update(device, x, loss.item(),
                                              tracker.rate(),
                                              tracker.global_rate())
         if lr_scheduler:
             lr_scheduler.step()
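
The (model, loader, device, context) signature used here (and in Example #4 below) is written for a data-parallel training driver that hands each replica its own model, per-device loader, device, and a per-device context object. The only behaviour the loop relies on is context.getattr_or(name, factory): return the cached attribute if it exists, otherwise build it once with the zero-argument factory. A minimal stand-in, purely for illustration (the real context object is supplied by the torch_xla driver, not defined here):

class Context:
    """Hypothetical per-device context mirroring the getattr_or() calls above."""

    def getattr_or(self, name, factory):
        value = getattr(self, name, None)
        if value is None:
            # Build the attribute lazily, once per device, and cache it so the
            # optimizer and scheduler are not recreated on every call.
            value = factory()
            setattr(self, name, value)
        return value
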
Example #2

 # Closure-style variant: model, optimizer, loss_fn, lr_scheduler, device and
 # topk_accuracy are taken from the enclosing script's scope.
 def train_loop_fn(loader):
     tracker = xm.RateTracker()
     model.train()
     total_samples = 0
     correct = 0
     top5_accuracys = 0
     losses = 0
     for x, (data, target) in enumerate(loader):
         optimizer.zero_grad()
         output = model(data)
         loss = loss_fn(output, target)
         loss.backward()
         xm.optimizer_step(optimizer)
         tracker.add(FLAGS.batch_size)
         pred = output.max(1, keepdim=True)[1]
         correct += pred.eq(target.view_as(pred)).sum().item()
         losses += loss.item()
         total_samples += data.size()[0]
         top5_accuracys += topk_accuracy(output, target, topk=5).item()
         if lr_scheduler:
             lr_scheduler.step()
         if x % FLAGS.log_steps == 0:
             test_utils.print_training_update(device, x, loss.item(),
                                              tracker.rate(),
                                              tracker.global_rate())
     return (
         losses / (x + 1),
         (100.0 * correct / total_samples),
         (top5_accuracys / (x + 1)),
     )
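
Unlike Example #1, this variant only receives the per-device loader and closes over everything else. A minimal sketch of how such a loop is typically driven with torch_xla; the tiny model, fake dataset, _mp_fn name and hyperparameters are placeholders for illustration, not part of the original:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp


def _mp_fn(index):
    # xmp.spawn() runs this once per XLA device/process.
    device = xm.xla_device()
    model = nn.Linear(784, 10).to(device)  # placeholder model
    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

    # Fake data standing in for the real dataset.
    dataset = TensorDataset(torch.randn(512, 784), torch.randint(0, 10, (512,)))
    train_loader = DataLoader(dataset, batch_size=64)

    def train_loop_fn(loader):
        # Stripped-down version of the loop above; the full version would also
        # track accuracy, top-5 accuracy, the rate tracker and the LR schedule.
        model.train()
        for step, (data, target) in enumerate(loader):
            optimizer.zero_grad()
            loss = loss_fn(model(data), target)
            loss.backward()
            # All-reduces gradients across replicas before optimizer.step().
            xm.optimizer_step(optimizer)

    for epoch in range(1, 3):
        # ParallelLoader prefetches batches onto the XLA device;
        # per_device_loader() yields plain (data, target) batches,
        # hence the enumerate() in the loop body.
        para_loader = pl.ParallelLoader(train_loader, [device])
        train_loop_fn(para_loader.per_device_loader(device))


if __name__ == '__main__':
    xmp.spawn(_mp_fn, args=())
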
Example #3
  def train_loop_fn(loader):
    tracker = xm.RateTracker()

    model.train()
    for x, (data, target) in loader:
      optimizer.zero_grad()
      output = model(data)
      loss = loss_fn(output, target)
      loss.backward()
      xm.optimizer_step(optimizer)
      tracker.add(FLAGS.batch_size)
      if x % FLAGS.log_steps == 0:
        test_utils.print_training_update(device, x, loss.item(), tracker.rate(),
                                         tracker.global_rate())
Example #4
    # Same driver-style signature as Example #1, but with nn.NLLLoss and an
    # externally supplied learning rate `lr` instead of FLAGS.lr.
    def train_loop_fn(model, loader, device, context):
        loss_fn = nn.NLLLoss()
        optimizer = context.getattr_or(
            'optimizer', lambda: optim.SGD(
                model.parameters(), lr=lr, momentum=FLAGS.momentum))
        tracker = xm.RateTracker()

        model.train()
        for x, (data, target) in loader:
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(FLAGS.batch_size)
            if x % FLAGS.log_steps == 0:
                test_utils.print_training_update(device, x, loss.item(),
                                                 tracker.rate(),
                                                 tracker.global_rate())
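
Example #4 uses nn.NLLLoss, which expects log-probabilities rather than raw logits, so the model referenced here (not shown) presumably ends in a log_softmax. A hypothetical minimal model illustrating that convention:

import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    # Illustrative only; the real model is defined elsewhere in the script.
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(784, 10)

    def forward(self, x):
        # nn.NLLLoss consumes log-probabilities, hence the final log_softmax.
        return F.log_softmax(self.fc(x), dim=1)
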