Example 1
def train_loop_fn(model, loader, device, context):
    # FLAGS, writer, num_training_steps_per_epoch, and the schedulers/
    # test_utils helpers come from the enclosing torch_xla test script.
    loss_fn = nn.CrossEntropyLoss()
    # context caches per-device state, so the optimizer and scheduler are
    # only built on the first call for each device.
    optimizer = context.getattr_or(
        'optimizer', lambda: optim.SGD(model.parameters(),
                                       lr=FLAGS.lr,
                                       momentum=FLAGS.momentum,
                                       weight_decay=1e-4))
    lr_scheduler = context.getattr_or(
        'lr_scheduler', lambda: schedulers.wrap_optimizer_with_scheduler(
            optimizer,
            scheduler_type=getattr(FLAGS, 'lr_scheduler_type', None),
            scheduler_divisor=getattr(FLAGS, 'lr_scheduler_divisor', None),
            scheduler_divide_every_n_epochs=getattr(
                FLAGS, 'lr_scheduler_divide_every_n_epochs', None),
            num_steps_per_epoch=num_training_steps_per_epoch,
            summary_writer=writer if xm.is_master_ordinal() else None))
    tracker = xm.RateTracker()
    model.train()
    for x, (data, target) in enumerate(loader):
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)
        loss.backward()
        # xm.optimizer_step() applies the gradients and marks the XLA step.
        xm.optimizer_step(optimizer)
        tracker.add(FLAGS.batch_size)
        if x % FLAGS.log_steps == 0:
            test_utils.print_training_update(device, x, loss.item(),
                                             tracker.rate(),
                                             tracker.global_rate())
        if lr_scheduler:
            lr_scheduler.step()
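The `(model, loader, device, context)` signature above is the hook used by the legacy `torch_xla.distributed.data_parallel.DataParallel` wrapper, which calls the loop once per device and supplies the per-device `context` that `getattr_or()` caches into. A minimal sketch of that driving code, assuming the legacy API (the `MyModel` class and flag names are placeholders):

    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.data_parallel as dp

    devices = xm.get_xla_supported_devices()
    # DataParallel replicates the model and runs train_loop_fn on every device.
    model_parallel = dp.DataParallel(MyModel, device_ids=devices)
    for epoch in range(1, FLAGS.num_epochs + 1):
        model_parallel(train_loop_fn, train_loader)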
Example 2
def _train_update(device, x, loss, tracker, writer):
    test_utils.print_training_update(device,
                                     x,
                                     loss.item(),
                                     tracker.rate(),
                                     tracker.global_rate(),
                                     summary_writer=writer)
Example 3
def _train_update(device, step, loss, tracker, epoch, writer):
    test_utils.print_training_update(
        device,
        step,
        loss,
        tracker.rate(),
        tracker.global_rate(),
        epoch,
        summary_writer=writer,
    )
Example 4
def _train_update(device, step, loss, tracker, epoch, writer):
    # Time the first loss.item(): on an XLA device it blocks until the
    # pending graph has executed and the value has been copied to the host.
    st = time.time()
    loss.item()
    dt = time.time() - st
    test_utils.print_training_update(
        device,
        step,
        loss.item(),
        tracker.rate(),
        tracker.global_rate(),
        epoch,
        summary_writer=writer,
    )
    print(f'Getting loss took {dt} seconds')
Example 5
    def train_loop_fn(loader):
        tracker = xm.RateTracker()

        model.train()
        for x, (data, target) in enumerate(loader):
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(FLAGS.batch_size)
            if x % FLAGS.log_steps == 0:
                test_utils.print_training_update(device, x, loss.item(),
                                                 tracker.rate(),
                                                 tracker.global_rate())
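Example 5 closes over `model`, `optimizer`, `loss_fn`, and `device` from the enclosing scope, so it only needs a per-device loader. A minimal sketch of the call site, assuming the `torch_xla.distributed.parallel_loader` API used by the torch_xla test scripts:

    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.parallel_loader as pl

    device = xm.xla_device()
    for epoch in range(1, FLAGS.num_epochs + 1):
        # ParallelLoader prefetches batches onto the XLA device; a fresh
        # per-device loader is drawn from it each epoch.
        para_loader = pl.ParallelLoader(train_loader, [device])
        train_loop_fn(para_loader.per_device_loader(device))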
Example 6
    def train_loop_fn(model, loader, device, context):
        loss_fn = nn.NLLLoss()
        optimizer = context.getattr_or(
            'optimizer', lambda: optim.SGD(
                model.parameters(), lr=lr, momentum=FLAGS.momentum))
        tracker = xm.RateTracker()

        model.train()
        for x, (data, target) in enumerate(loader):
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(FLAGS.batch_size)
            if x % FLAGS.log_steps == 0:
                test_utils.print_training_update(device, x, loss.item(),
                                                 tracker.rate(),
                                                 tracker.global_rate())
Example 7
def _train_update(device, step, loss, tracker, epoch):
    test_utils.print_training_update(device, step, loss.item(), tracker.rate(),
                                     tracker.global_rate(), epoch)
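The `_train_update` helpers in Examples 2-4 and 7 are logging callbacks rather than loops. In the torch_xla test scripts they are usually scheduled with `xm.add_step_closure()`, which defers the call until the step's XLA graph has executed, so `loss.item()` does not force an early sync. A minimal sketch of that pattern (outer names such as `model`, `optimizer`, `loss_fn`, `device`, and `FLAGS` are assumed from the enclosing script):

    import torch_xla.core.xla_model as xm

    def train_loop_fn(loader, epoch):
        tracker = xm.RateTracker()
        model.train()
        for step, (data, target) in enumerate(loader):
            optimizer.zero_grad()
            loss = loss_fn(model(data), target)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(FLAGS.batch_size)
            if step % FLAGS.log_steps == 0:
                # Runs the logging helper once the pending graph finishes.
                xm.add_step_closure(
                    _train_update, args=(device, step, loss, tracker, epoch))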