Example #1
0
def train(model, loss_fn, optimizer, data_loader_train, data_loader_test,
          scaled_lr):
    """Train the model and periodically evaluate it on the test set.

    Runs up to ``FLAGS.epochs`` epochs (or until ``FLAGS.max_steps`` global
    steps), applying a warmup/decay LR schedule every step, logging smoothed
    metrics every ``FLAGS.print_freq`` steps, and evaluating every
    ``test_freq`` steps. Writes a checkpoint whenever a new best AUC is seen
    and returns early once ``FLAGS.auc_threshold`` is reached.

    Args:
        model (dlrm): Model to train.
        loss_fn (torch.nn.Module): Loss function.
        optimizer (torch.nn.optim): Optimizer stepped once per batch.
        data_loader_train (torch.utils.data.DataLoader): Training batches.
        data_loader_test (torch.utils.data.DataLoader): Evaluation batches.
        scaled_lr (float): Base learning rate fed to the LR schedule.

    Side effects: reads the global ``FLAGS``, prints progress, logs final
    results through ``dllogger`` and may save a checkpoint.
    """
    model.train()
    prefetching_enabled = is_data_prefetching_enabled()
    base_device = FLAGS.base_device
    print_freq = FLAGS.print_freq
    steps_per_epoch = len(data_loader_train)

    checkpoint_writer = make_serial_checkpoint_writer(
        embedding_indices=range(len(get_categorical_feature_sizes(FLAGS))),
        config=FLAGS.flag_values_dict())

    test_freq = FLAGS.test_freq if FLAGS.test_freq is not None else steps_per_epoch - 1

    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter(
        'loss', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))
    metric_logger.add_meter(
        'step_time', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    metric_logger.add_meter(
        'lr', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))

    if prefetching_enabled:
        data_stream = torch.cuda.Stream()

    timer = utils.StepTimer()

    best_auc = 0
    best_epoch = 0
    start_time = time()

    # Initialized up-front so the final throughput report below cannot hit an
    # unbound name when the loop body never executes (zero epochs or an empty
    # training loader).
    global_step = 0
    # Set when FLAGS.max_steps is hit so the *outer* epoch loop stops too; a
    # bare inner `break` alone would rebuild the pipeline and re-print the
    # stopping message once per remaining epoch.
    reached_max_steps = False

    timer.click()

    for epoch in range(FLAGS.epochs):
        input_pipeline = iter(data_loader_train)

        if prefetching_enabled:
            input_pipeline = prefetcher(input_pipeline, data_stream)

        for step, batch in enumerate(input_pipeline):
            global_step = steps_per_epoch * epoch + step
            numerical_features, categorical_features, click = batch

            utils.lr_step(optimizer,
                          num_warmup_iter=FLAGS.warmup_steps,
                          current_step=global_step + 1,
                          base_lr=scaled_lr,
                          warmup_factor=FLAGS.warmup_factor,
                          decay_steps=FLAGS.decay_steps,
                          decay_start_step=FLAGS.decay_start_step)

            if FLAGS.max_steps and global_step > FLAGS.max_steps:
                print(
                    f"Reached max global steps of {FLAGS.max_steps}. Stopping."
                )
                reached_max_steps = True
                break

            if prefetching_enabled:
                # Make sure the prefetched batch's copy stream has finished
                # before the forward pass consumes it.
                torch.cuda.synchronize()

            output = model(numerical_features,
                           categorical_features).squeeze().float()

            loss = loss_fn(output, click.squeeze())

            # Setting grad to None is faster than zero_grad()
            for param_group in optimizer.param_groups:
                for param in param_group['params']:
                    param.grad = None

            if FLAGS.amp:
                # Manual static loss scaling; undone below when logging.
                # NOTE(review): presumably the caller pre-divides scaled_lr by
                # the loss scale to compensate -- verify against the caller.
                loss *= FLAGS.loss_scale
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            optimizer.step()

            if step % print_freq == 0 and step > 0:
                loss_value = loss.item()

                # One click per print interval: timer.measured covers
                # print_freq steps and is averaged below.
                timer.click()

                if global_step < FLAGS.benchmark_warmup_steps:
                    metric_logger.update(loss=loss_value,
                                         lr=optimizer.param_groups[0]["lr"])
                else:
                    # Undo the manual AMP loss scaling for display.
                    unscale_factor = FLAGS.loss_scale if FLAGS.amp else 1
                    metric_logger.update(
                        loss=loss_value / unscale_factor,
                        step_time=timer.measured / FLAGS.print_freq,
                        lr=optimizer.param_groups[0]["lr"] * unscale_factor)

                if global_step < FLAGS.benchmark_warmup_steps:
                    print(
                        f'Warming up, step [{global_step}/{FLAGS.benchmark_warmup_steps}]'
                    )
                    continue

                eta_str = datetime.timedelta(
                    seconds=int(metric_logger.step_time.global_avg *
                                (steps_per_epoch - step)))
                metric_logger.print(
                    header=
                    f"Epoch:[{epoch}/{FLAGS.epochs}] [{step}/{steps_per_epoch}]  eta: {eta_str}"
                )

            if (global_step % test_freq == 0 and global_step > 0
                    and global_step / steps_per_epoch >= FLAGS.test_after):
                loss, auc, test_step_time = evaluate(model, loss_fn,
                                                     data_loader_test)
                print(
                    f"Epoch {epoch} step {step}. Test loss {loss:.5f}, auc {auc:.6f}"
                )

                if auc > best_auc:
                    best_auc = auc
                    best_epoch = epoch + ((step + 1) / steps_per_epoch)
                    maybe_save_checkpoint(checkpoint_writer, model,
                                          FLAGS.save_checkpoint_path)

                if FLAGS.auc_threshold and auc >= FLAGS.auc_threshold:
                    stop_time = time()
                    run_time_s = int(stop_time - start_time)
                    print(
                        f"Hit target accuracy AUC {FLAGS.auc_threshold} at epoch "
                        f"{global_step/steps_per_epoch:.2f} in {run_time_s}s. "
                        f"Average speed {global_step * FLAGS.batch_size / run_time_s:.1f} records/s."
                    )
                    return

        if reached_max_steps:
            break

    stop_time = time()
    run_time_s = int(stop_time - start_time)

    print(
        f"Finished training in {run_time_s}s. "
        f"Average speed {global_step * FLAGS.batch_size / run_time_s:.1f} records/s."
    )

    avg_throughput = FLAGS.batch_size / metric_logger.step_time.avg

    results = {
        'best_auc': best_auc,
        'best_epoch': best_epoch,
        'average_train_throughput': avg_throughput
    }

    # test_step_time only exists if evaluate() ran at least once.
    if 'test_step_time' in locals():
        avg_test_throughput = FLAGS.test_batch_size / test_step_time
        results['average_test_throughput'] = avg_test_throughput

    dllogger.log(data=results, step=tuple())
def train(model, loss_fn, optimizer, data_loader_train, data_loader_test,
          scaled_lr):
    """Train the model and periodically evaluate it on the test set.

    Runs up to ``FLAGS.epochs`` epochs (or until ``FLAGS.max_steps`` global
    steps), moving each batch to ``FLAGS.base_device``, applying a
    warmup/decay LR schedule every step, logging smoothed metrics every
    ``FLAGS.print_freq`` steps, and evaluating every ``test_freq`` steps.
    Saves a checkpoint whenever a new best AUC is seen and returns early once
    ``FLAGS.auc_threshold`` is reached.

    Args:
        model (dlrm): Model to train.
        loss_fn (torch.nn.Module): Loss function.
        optimizer (torch.nn.optim): Optimizer stepped once per batch.
        data_loader_train (torch.utils.data.DataLoader): Training batches.
        data_loader_test (torch.utils.data.DataLoader): Evaluation batches.
        scaled_lr (float): Base learning rate fed to the LR schedule.

    Side effects: reads the global ``FLAGS``, prints progress, logs final
    results through ``dllogger`` and may save a checkpoint.
    """
    model.train()
    base_device = FLAGS.base_device
    print_freq = FLAGS.print_freq
    steps_per_epoch = len(data_loader_train)

    test_freq = FLAGS.test_freq if FLAGS.test_freq is not None else steps_per_epoch

    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter(
        'loss', utils.SmoothedValue(window_size=print_freq, fmt='{avg:.4f}'))
    metric_logger.add_meter(
        'step_time',
        utils.SmoothedValue(window_size=print_freq, fmt='{avg:.6f}'))
    metric_logger.add_meter(
        'lr', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))

    timer = utils.StepTimer()

    best_auc = 0
    best_epoch = 0
    start_time = time()

    # Set when FLAGS.max_steps is hit so the *outer* epoch loop stops too; a
    # bare inner `break` alone would re-iterate the loader and re-print the
    # stopping message once per remaining epoch.
    reached_max_steps = False

    for epoch in range(FLAGS.epochs):

        batch_iter = iter(data_loader_train)
        for step in range(len(data_loader_train)):
            timer.click()

            global_step = steps_per_epoch * epoch + step

            numerical_features, categorical_features, click = next(batch_iter)

            # Embedding lookups require int64 indices; targets must be float
            # for the loss.
            categorical_features = categorical_features.to(base_device).to(
                torch.long)
            numerical_features = numerical_features.to(base_device)
            click = click.to(base_device).to(torch.float32)

            utils.lr_step(optimizer,
                          num_warmup_iter=FLAGS.warmup_steps,
                          current_step=global_step + 1,
                          base_lr=scaled_lr,
                          warmup_factor=FLAGS.warmup_factor,
                          decay_steps=FLAGS.decay_steps,
                          decay_start_step=FLAGS.decay_start_step)

            if FLAGS.max_steps and global_step > FLAGS.max_steps:
                print(
                    F"Reached max global steps of {FLAGS.max_steps}. Stopping."
                )
                reached_max_steps = True
                break

            output = model(numerical_features,
                           categorical_features).squeeze().float()

            loss = loss_fn(output, click.squeeze())

            optimizer.zero_grad()
            if FLAGS.fp16:
                # Manual static loss scaling; undone below when logging.
                # NOTE(review): presumably the caller pre-divides scaled_lr by
                # the loss scale to compensate -- verify against the caller.
                loss *= FLAGS.loss_scale
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            optimizer.step()

            loss_value = loss.item()

            if timer.measured is None:
                # first iteration, no step time etc. to print
                continue

            if global_step < FLAGS.benchmark_warmup_steps:
                metric_logger.update(loss=loss_value,
                                     lr=optimizer.param_groups[0]["lr"])
            else:
                # Undo the manual loss scaling for display.
                unscale_factor = FLAGS.loss_scale if FLAGS.fp16 else 1
                metric_logger.update(loss=loss_value / unscale_factor,
                                     step_time=timer.measured,
                                     lr=optimizer.param_groups[0]["lr"] *
                                     unscale_factor)

            if step % print_freq == 0 and step > 0:
                if global_step < FLAGS.benchmark_warmup_steps:
                    print(
                        F'Warming up, step [{global_step}/{FLAGS.benchmark_warmup_steps}]'
                    )
                    continue

                eta_str = datetime.timedelta(
                    seconds=int(metric_logger.step_time.global_avg *
                                (steps_per_epoch - step)))
                metric_logger.print(
                    header=
                    F"Epoch:[{epoch}/{FLAGS.epochs}] [{step}/{steps_per_epoch}]  eta: {eta_str}"
                )

            if (
                    global_step + 1
            ) % test_freq == 0 and global_step > 0 and global_step / steps_per_epoch >= FLAGS.test_after:
                loss, auc, test_step_time = evaluate(model, loss_fn,
                                                     data_loader_test)
                print(
                    F"Epoch {epoch} step {step}. Test loss {loss:.5f}, auc {auc:.6f}"
                )

                if auc > best_auc:
                    best_auc = auc
                    best_epoch = epoch + ((step + 1) / steps_per_epoch)
                    maybe_save_checkpoint(model, FLAGS.save_checkpoint_path)

                if FLAGS.auc_threshold and auc >= FLAGS.auc_threshold:
                    stop_time = time()
                    run_time_s = int(stop_time - start_time)
                    print(
                        F"Hit target accuracy AUC {FLAGS.auc_threshold} at epoch "
                        F"{global_step/steps_per_epoch:.2f} in {run_time_s}s. "
                        F"Average speed {global_step * FLAGS.batch_size / run_time_s:.1f} records/s."
                    )
                    return

        if reached_max_steps:
            break

    avg_throughput = FLAGS.batch_size / metric_logger.step_time.avg

    results = {
        'best_auc': best_auc,
        'best_epoch': best_epoch,
        'average_train_throughput': avg_throughput
    }

    # test_step_time only exists if evaluate() ran at least once.
    if 'test_step_time' in locals():
        avg_test_throughput = FLAGS.test_batch_size / test_step_time
        results['average_test_throughput'] = avg_test_throughput

    dllogger.log(data=results, step=tuple())