Example #1
# Imports assumed by this example (ignite 0.3-era layout); `get_model` and
# `get_train_test_loaders` are project-local helpers that are not shown here.
from functools import partial

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim

from ignite.engine import Engine, Events, create_supervised_evaluator
from ignite.handlers import ModelCheckpoint
from ignite.metrics import Accuracy, Loss, RunningAverage
from ignite.utils import convert_tensor
from ignite.contrib.handlers import (PiecewiseLinear, ProgressBar,
                                     global_step_from_engine)
from ignite.contrib.handlers.tensorboard_logger import (
    TensorboardLogger, OutputHandler as tbOutputHandler,
    OptimizerParamsHandler as tbOptimizerParamsHandler)


def run(output_path, config):
    device = "cuda"

    local_rank = config['local_rank']
    # `backend` is assumed here to be carried in the config (e.g. "nccl");
    # in the original script it is a module-level setting. None disables
    # distributed mode.
    backend = config.get('dist_backend')
    distributed = backend is not None
    if distributed:
        torch.cuda.set_device(local_rank)
        device = "cuda"
    rank = dist.get_rank() if distributed else 0

    # Rescale batch_size and num_workers
    ngpus_per_node = torch.cuda.device_count()
    ngpus = dist.get_world_size() if distributed else 1
    batch_size = config['batch_size'] // ngpus
    num_workers = int(
        (config['num_workers'] + ngpus_per_node - 1) / ngpus_per_node)

    train_labelled_loader, test_loader = \
        get_train_test_loaders(path=config['data_path'],
                               batch_size=batch_size,
                               distributed=distributed,
                               num_workers=num_workers)

    model = get_model(config['model'])
    model = model.to(device)

    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[
                local_rank,
            ], output_device=local_rank)

    optimizer = optim.SGD(model.parameters(),
                          lr=config['learning_rate'],
                          momentum=config['momentum'],
                          weight_decay=config['weight_decay'],
                          nesterov=True)

    criterion = nn.CrossEntropyLoss().to(device)

    le = len(train_labelled_loader)
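    # Piecewise-linear lr schedule in iteration units (`le` iterations per
    # epoch): ramp linearly from 0 to the configured lr over the warmup
    # epochs, then decay linearly back to 0 by the final epoch.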
    milestones_values = [(0, 0.0),
                         (le * config['num_warmup_epochs'],
                          config['learning_rate']),
                         (le * config['num_epochs'], 0.0)]
    lr_scheduler = PiecewiseLinear(optimizer,
                                   param_name="lr",
                                   milestones_values=milestones_values)

    def _prepare_batch(batch, device, non_blocking):
        x, y = batch
        return (convert_tensor(x, device=device, non_blocking=non_blocking),
                convert_tensor(y, device=device, non_blocking=non_blocking))

    def process_function(engine, labelled_batch):

        x, y = _prepare_batch(labelled_batch, device=device, non_blocking=True)

        model.train()
        # Supervised part
        y_pred = model(x)
        loss = criterion(y_pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        return {
            'batch loss': loss.item(),
        }

    trainer = Engine(process_function)
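
    # ignite param schedulers such as PiecewiseLinear are plain callables
    # attached as event handlers, while native torch schedulers expose
    # `step()`; this dispatch supports either kind.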

    if not hasattr(lr_scheduler, "step"):
        trainer.add_event_handler(Events.ITERATION_STARTED, lr_scheduler)
    else:
        trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                  lambda engine: lr_scheduler.step())

    metric_names = [
        'batch loss',
    ]

    def output_transform(x, name):
        return x[name]

    for n in metric_names:
        # We compute running average values on the output (batch loss) across all devices
        RunningAverage(output_transform=partial(output_transform, name=n),
                       epoch_bound=False,
                       device=device).attach(trainer, n)

    if rank == 0:
        checkpoint_handler = ModelCheckpoint(dirname=output_path,
                                             filename_prefix="checkpoint")
        trainer.add_event_handler(Events.ITERATION_COMPLETED(every=1000),
                                  checkpoint_handler, {
                                      'model': model,
                                      'optimizer': optimizer
                                  })

        ProgressBar(persist=True,
                    bar_format="").attach(trainer,
                                          event_name=Events.EPOCH_STARTED,
                                          closing_event_name=Events.COMPLETED)
        if config['display_iters']:
            ProgressBar(persist=False,
                        bar_format="").attach(trainer,
                                              metric_names=metric_names)

        tb_logger = TensorboardLogger(log_dir=output_path)
        tb_logger.attach(trainer,
                         log_handler=tbOutputHandler(
                             tag="train", metric_names=metric_names),
                         event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer,
                         log_handler=tbOptimizerParamsHandler(optimizer,
                                                              param_name="lr"),
                         event_name=Events.ITERATION_STARTED)

    metrics = {
        "accuracy": Accuracy(device=device if distributed else None),
        "loss": Loss(criterion, device=device if distributed else None)
    }

    evaluator = create_supervised_evaluator(model,
                                            metrics=metrics,
                                            device=device,
                                            non_blocking=True)
    train_evaluator = create_supervised_evaluator(model,
                                                  metrics=metrics,
                                                  device=device,
                                                  non_blocking=True)

    def run_validation(engine):
        torch.cuda.synchronize()
        train_evaluator.run(train_labelled_loader)
        evaluator.run(test_loader)

    trainer.add_event_handler(Events.EPOCH_STARTED(every=3), run_validation)
    trainer.add_event_handler(Events.COMPLETED, run_validation)

    if rank == 0:
        if config['display_iters']:
            ProgressBar(persist=False,
                        desc="Train evaluation").attach(train_evaluator)
            ProgressBar(persist=False,
                        desc="Test evaluation").attach(evaluator)

        tb_logger.attach(train_evaluator,
                         log_handler=tbOutputHandler(tag="train",
                                                     metric_names=list(
                                                         metrics.keys()),
                                                     another_engine=trainer),
                         event_name=Events.COMPLETED)

        tb_logger.attach(evaluator,
                         log_handler=tbOutputHandler(tag="test",
                                                     metric_names=list(
                                                         metrics.keys()),
                                                     another_engine=trainer),
                         event_name=Events.COMPLETED)

        # Store the best model
        def default_score_fn(engine):
            score = engine.state.metrics['accuracy']
            return score

        score_function = default_score_fn if not hasattr(
            config, "score_function") else config.score_function

        best_model_handler = ModelCheckpoint(
            dirname=output_path,
            filename_prefix="best",
            n_saved=3,
            global_step_transform=global_step_from_engine(trainer),
            score_name="val_accuracy",
            score_function=score_function)
        evaluator.add_event_handler(Events.COMPLETED, best_model_handler, {
            'model': model,
        })

    trainer.run(train_labelled_loader, max_epochs=config['num_epochs'])

    if rank == 0:
        tb_logger.close()
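
# A minimal, hypothetical single-GPU invocation of the example above. The
# config keys are inferred from the body of `run`; the values are
# illustrative only, and `dist_backend` is the key assumed above for the
# distributed backend (None keeps distributed mode off).
if __name__ == "__main__":
    run(output_path="/tmp/output", config={
        "local_rank": 0,
        "dist_backend": None,
        "batch_size": 512,
        "num_workers": 8,
        "data_path": "/tmp/cifar10",
        "model": "resnet18",
        "learning_rate": 0.1,
        "momentum": 0.9,
        "weight_decay": 1e-4,
        "num_warmup_epochs": 4,
        "num_epochs": 24,
        "display_iters": True,
    })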
Example #2
# Imports assumed by this example; `get_model`, `get_train_test_loaders` and
# `TrainingSignalAnnealing` are project-local helpers that are not shown
# here, and torchcontrib provides the SWA optimizer wrapper.
from functools import partial

import mlflow
import torch
import torch.nn as nn
import torch.optim as optim
import torchcontrib.optim
from torch.optim.lr_scheduler import CosineAnnealingLR

from ignite.engine import Engine, Events, create_supervised_evaluator
from ignite.metrics import Accuracy, RunningAverage
from ignite.utils import convert_tensor
from ignite.contrib.handlers import (ProgressBar,
                                     create_lr_scheduler_with_warmup)
from ignite.contrib.handlers.tensorboard_logger import (
    TensorboardLogger, OutputHandler as tbOutputHandler,
    OptimizerParamsHandler as tbOptimizerParamsHandler)


def run(output_path, config):

    device = "cuda"
    batch_size = config['batch_size']

    train_labelled_loader, train_unlabelled_loader, test_loader = \
        get_train_test_loaders(dataset_name=config['dataset'],
                               num_labelled_samples=config['num_labelled_samples'],
                               path=config['data_path'],
                               batch_size=batch_size,
                               unlabelled_batch_size=config.get('unlabelled_batch_size', None),
                               num_workers=config['num_workers'])

    model = get_model(config['model'])
    model = model.to(device)

    optimizer = optim.SGD(model.parameters(),
                          lr=config['learning_rate'],
                          momentum=config['momentum'],
                          weight_decay=config['weight_decay'],
                          nesterov=True)

    with_SWA = config['with_SWA']
    if with_SWA:
        optimizer = torchcontrib.optim.SWA(optimizer)

    criterion = nn.CrossEntropyLoss().to(device)
    if config['consistency_criterion'] == "MSE":
        consistency_criterion = nn.MSELoss()
    elif config['consistency_criterion'] == "KL":
        consistency_criterion = nn.KLDivLoss(reduction='batchmean')
    else:
        raise RuntimeError("Unknown consistency criterion {}".format(
            config['consistency_criterion']))

    consistency_criterion = consistency_criterion.to(device)

    le = len(train_labelled_loader)
    num_train_steps = le * config['num_epochs']
    mlflow.log_param("num train steps", num_train_steps)

    lr = config['learning_rate']
    eta_min = lr * config['min_lr_ratio']
    num_warmup_steps = config['num_warmup_steps']

    lr_scheduler = CosineAnnealingLR(optimizer,
                                     eta_min=eta_min,
                                     T_max=num_train_steps - num_warmup_steps)

    if num_warmup_steps > 0:
        lr_scheduler = create_lr_scheduler_with_warmup(
            lr_scheduler,
            warmup_start_value=0.0,
            warmup_end_value=lr * (1.0 + 1.0 / num_warmup_steps),
            warmup_duration=num_warmup_steps)

    def _prepare_batch(batch, device, non_blocking):
        x, y = batch
        return (convert_tensor(x, device=device, non_blocking=non_blocking),
                convert_tensor(y, device=device, non_blocking=non_blocking))
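
    # The helper below turns the unlabelled loader into an infinite stream,
    # so one unlabelled batch can be drawn per labelled training step
    # regardless of the two loaders' lengths.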

    def cycle(iterable):
        while True:
            for i in iterable:
                yield i

    train_unlabelled_loader_iter = cycle(train_unlabelled_loader)

    lam = config['consistency_lambda']
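
    # Training Signal Annealing (TSA): anneal a confidence threshold from
    # TSA_proba_min to TSA_proba_max over training and drop labelled examples
    # the model already predicts above the current threshold, so the small
    # labelled set is not over-fitted early.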

    tsa = TrainingSignalAnnealing(num_steps=num_train_steps,
                                  min_threshold=config['TSA_proba_min'],
                                  max_threshold=config['TSA_proba_max'])

    with_tsa = config['with_TSA']
    with_UDA = not config['no_UDA']
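
    # One UDA training step: supervised cross-entropy on the labelled batch
    # plus `lam` times a consistency loss that pushes predictions on an
    # augmented unlabelled batch towards the (detached, fixed-target)
    # predictions on the original unlabelled batch.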

    def uda_process_function(engine, labelled_batch):

        x, y = _prepare_batch(labelled_batch, device=device, non_blocking=True)

        if with_UDA:
            unsup_x, unsup_aug_x = next(train_unlabelled_loader_iter)
            unsup_x = convert_tensor(unsup_x, device=device, non_blocking=True)
            unsup_aug_x = convert_tensor(unsup_aug_x,
                                         device=device,
                                         non_blocking=True)

        model.train()
        # Supervised part
        y_pred = model(x)
        loss = criterion(y_pred, y)

        supervised_loss = loss
        step = engine.state.iteration - 1
        if with_tsa and with_UDA:
            new_y_pred, new_y = tsa(y_pred, y, step=step)
            new_loss = criterion(new_y_pred, new_y)

            engine.state.tsa_log = {
                "new_y_pred": new_y_pred,
                "loss": loss.item(),
                "tsa_loss": new_loss.item()
            }
            supervised_loss = new_loss

        # Unsupervised part
        if with_UDA:
            unsup_orig_y_pred = model(unsup_x).detach()
            unsup_orig_y_probas = torch.softmax(unsup_orig_y_pred, dim=-1)

            unsup_aug_y_pred = model(unsup_aug_x)
            unsup_aug_y_probas = torch.log_softmax(unsup_aug_y_pred, dim=-1)

            consistency_loss = consistency_criterion(unsup_aug_y_probas,
                                                     unsup_orig_y_probas)

        final_loss = supervised_loss

        if with_UDA:
            final_loss += lam * consistency_loss

        optimizer.zero_grad()
        final_loss.backward()
        optimizer.step()

        return {
            'supervised batch loss': supervised_loss.item(),
            'consistency batch loss': (consistency_loss.item()
                                       if with_UDA else 0.0),
            'final batch loss': final_loss.item(),
        }

    trainer = Engine(uda_process_function)

    if with_UDA and with_tsa:

        @trainer.on(Events.ITERATION_COMPLETED)
        def log_tsa(engine):
            step = engine.state.iteration - 1
            if step % 50 == 0:
                mlflow.log_metric("TSA threshold",
                                  tsa.thresholds[step].item(),
                                  step=step)
                mlflow.log_metric("TSA selection",
                                  engine.state.tsa_log['new_y_pred'].shape[0],
                                  step=step)
                mlflow.log_metric("Original X Loss",
                                  engine.state.tsa_log['loss'],
                                  step=step)
                mlflow.log_metric("TSA X Loss",
                                  engine.state.tsa_log['tsa_loss'],
                                  step=step)

    if not hasattr(lr_scheduler, "step"):
        trainer.add_event_handler(Events.ITERATION_STARTED, lr_scheduler)
    else:
        trainer.add_event_handler(Events.ITERATION_STARTED,
                                  lambda engine: lr_scheduler.step())

    @trainer.on(Events.ITERATION_STARTED)
    def log_learning_rate(engine):
        step = engine.state.iteration - 1
        if step % 50 == 0:
            lr = optimizer.param_groups[0]['lr']
            mlflow.log_metric("learning rate", lr, step=step)
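
    # Stochastic Weight Averaging: collect a running average of the weights
    # over the last ~25% of epochs, swap the averaged weights into the model
    # once training completes, and recompute BatchNorm statistics for them.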

    if with_SWA:

        @trainer.on(Events.COMPLETED)
        def swap_swa_sgd(engine):
            optimizer.swap_swa_sgd()
            optimizer.bn_update(train_labelled_loader, model)

        @trainer.on(Events.EPOCH_COMPLETED)
        def update_swa(engine):
            if engine.state.epoch - 1 > int(config['num_epochs'] * 0.75):
                optimizer.update_swa()

    metric_names = [
        'supervised batch loss', 'consistency batch loss', 'final batch loss'
    ]

    def output_transform(x, name):
        return x[name]

    for n in metric_names:
        RunningAverage(output_transform=partial(output_transform, name=n),
                       epoch_bound=False).attach(trainer, n)

    ProgressBar(persist=True,
                bar_format="").attach(trainer,
                                      event_name=Events.EPOCH_STARTED,
                                      closing_event_name=Events.COMPLETED)

    tb_logger = TensorboardLogger(log_dir=output_path)
    tb_logger.attach(trainer,
                     log_handler=tbOutputHandler(tag="train",
                                                 metric_names=[
                                                     'final batch loss',
                                                     'consistency batch loss',
                                                     'supervised batch loss'
                                                 ]),
                     event_name=Events.ITERATION_COMPLETED)
    tb_logger.attach(trainer,
                     log_handler=tbOptimizerParamsHandler(optimizer,
                                                          param_name="lr"),
                     event_name=Events.ITERATION_STARTED)

    metrics = {
        "accuracy": Accuracy(),
    }

    evaluator = create_supervised_evaluator(model,
                                            metrics=metrics,
                                            device=device,
                                            non_blocking=True)
    train_evaluator = create_supervised_evaluator(model,
                                                  metrics=metrics,
                                                  device=device,
                                                  non_blocking=True)

    def run_validation(engine, val_interval):
        if (engine.state.epoch - 1) % val_interval == 0:
            train_evaluator.run(train_labelled_loader)
            evaluator.run(test_loader)

    trainer.add_event_handler(Events.EPOCH_STARTED,
                              run_validation,
                              val_interval=2)
    trainer.add_event_handler(Events.COMPLETED, run_validation, val_interval=1)

    tb_logger.attach(train_evaluator,
                     log_handler=tbOutputHandler(tag="train",
                                                 metric_names=list(
                                                     metrics.keys()),
                                                 another_engine=trainer),
                     event_name=Events.COMPLETED)

    tb_logger.attach(evaluator,
                     log_handler=tbOutputHandler(tag="test",
                                                 metric_names=list(
                                                     metrics.keys()),
                                                 another_engine=trainer),
                     event_name=Events.COMPLETED)

    def mlflow_batch_metrics_logging(engine, tag):
        step = trainer.state.iteration
        for name, value in engine.state.metrics.items():
            mlflow.log_metric("{} {}".format(tag, name), value, step=step)

    def mlflow_val_metrics_logging(engine, tag):
        step = trainer.state.epoch
        for name in metrics.keys():
            value = engine.state.metrics[name]
            mlflow.log_metric("{} {}".format(tag, name), value, step=step)

    trainer.add_event_handler(Events.ITERATION_COMPLETED,
                              mlflow_batch_metrics_logging, "train")
    train_evaluator.add_event_handler(Events.COMPLETED,
                                      mlflow_val_metrics_logging, "train")
    evaluator.add_event_handler(Events.COMPLETED, mlflow_val_metrics_logging,
                                "test")

    trainer.run(train_labelled_loader, max_epochs=config['num_epochs'])
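
# A minimal, hypothetical invocation of the UDA example above. The config
# keys are inferred from the body of `run`; the values are illustrative only.
if __name__ == "__main__":
    run(output_path="/tmp/output", config={
        "dataset": "cifar10", "data_path": "/tmp/cifar10",
        "num_labelled_samples": 4000, "batch_size": 64,
        "unlabelled_batch_size": 320, "num_workers": 8,
        "model": "resnet18", "learning_rate": 0.03, "momentum": 0.9,
        "weight_decay": 1e-4, "min_lr_ratio": 0.004, "num_warmup_steps": 0,
        "num_epochs": 100, "with_SWA": False, "consistency_criterion": "KL",
        "consistency_lambda": 1.0, "with_TSA": True, "no_UDA": False,
        "TSA_proba_min": 0.1, "TSA_proba_max": 1.0,
    })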
Example #3
# Imports assumed by this example; `get_model`, `get_train_test_loaders`,
# `LaycaSGD` and `LayerRotationStatsHandler` are project-local helpers that
# are not shown here.
import mlflow
import torch.nn as nn
import torch.optim as optim

from ignite.engine import Engine, Events, create_supervised_evaluator
from ignite.metrics import Accuracy, RunningAverage
from ignite.utils import convert_tensor
from ignite.contrib.handlers import PiecewiseLinear, ProgressBar
from ignite.contrib.handlers.tensorboard_logger import (
    TensorboardLogger, OutputHandler as tbOutputHandler,
    OptimizerParamsHandler as tbOptimizerParamsHandler)


def run(output_path, config):

    device = "cuda"
    batch_size = config['batch_size']

    train_loader, test_loader = get_train_test_loaders(
        dataset_name=config['dataset'],
        path=config['data_path'],
        batch_size=batch_size,
        num_workers=config['num_workers'])

    model = get_model(config['model'])
    model = model.to(device)
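
    # LaycaSGD is a project-local optimizer, presumably implementing Layca
    # (layer-level control of the amount of weight rotation) on top of SGD.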

    optim_fn = optim.SGD
    if config['with_layca']:
        optim_fn = LaycaSGD

    optimizer = optim_fn(model.parameters(),
                         lr=0.0,
                         momentum=config['momentum'],
                         weight_decay=config['weight_decay'],
                         nesterov=True)
    criterion = nn.CrossEntropyLoss()
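
    # lr starts at 0.0 on purpose: the PiecewiseLinear scheduler below drives
    # it every iteration, with `lr_milestones_values` given in epochs and
    # rescaled to iterations via `le`, the number of batches per epoch.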

    le = len(train_loader)
    milestones_values = [(le * m, v)
                         for m, v in config['lr_milestones_values']]
    scheduler = PiecewiseLinear(optimizer,
                                "lr",
                                milestones_values=milestones_values)

    def _prepare_batch(batch, device, non_blocking):
        x, y = batch
        return (convert_tensor(x, device=device, non_blocking=non_blocking),
                convert_tensor(y, device=device, non_blocking=non_blocking))

    def process_function(engine, batch):

        x, y = _prepare_batch(batch, device=device, non_blocking=True)

        model.train()
        y_pred = model(x)
        loss = criterion(y_pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        return loss.item()

    trainer = Engine(process_function)

    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    RunningAverage(output_transform=lambda x: x,
                   epoch_bound=False).attach(trainer, 'batchloss')

    ProgressBar(persist=True,
                bar_format="").attach(trainer,
                                      event_name=Events.EPOCH_STARTED,
                                      closing_event_name=Events.COMPLETED)

    tb_logger = TensorboardLogger(log_dir=output_path)
    tb_logger.attach(trainer,
                     log_handler=tbOutputHandler(tag="train",
                                                 metric_names='all'),
                     event_name=Events.ITERATION_COMPLETED)
    tb_logger.attach(trainer,
                     log_handler=tbOptimizerParamsHandler(optimizer,
                                                          param_name="lr"),
                     event_name=Events.ITERATION_STARTED)

    tb_logger.attach(trainer,
                     log_handler=LayerRotationStatsHandler(model),
                     event_name=Events.EPOCH_STARTED)

    metrics = {
        "accuracy": Accuracy(),
    }

    evaluator = create_supervised_evaluator(model,
                                            metrics=metrics,
                                            device=device,
                                            non_blocking=True)
    train_evaluator = create_supervised_evaluator(model,
                                                  metrics=metrics,
                                                  device=device,
                                                  non_blocking=True)

    def run_validation(engine, val_interval):
        if (engine.state.epoch - 1) % val_interval == 0:
            train_evaluator.run(train_loader)
            evaluator.run(test_loader)

    trainer.add_event_handler(Events.EPOCH_COMPLETED,
                              run_validation,
                              val_interval=2)
    trainer.add_event_handler(Events.COMPLETED, run_validation, val_interval=1)

    tb_logger.attach(train_evaluator,
                     log_handler=tbOutputHandler(tag="train",
                                                 metric_names='all',
                                                 another_engine=trainer),
                     event_name=Events.COMPLETED)

    tb_logger.attach(evaluator,
                     log_handler=tbOutputHandler(tag="test",
                                                 metric_names='all',
                                                 another_engine=trainer),
                     event_name=Events.COMPLETED)

    def mlflow_batch_metrics_logging(engine, tag):
        step = trainer.state.iteration
        for name, value in engine.state.metrics.items():
            mlflow.log_metric("{} {}".format(tag, name), value, step=step)

    def mlflow_val_metrics_logging(engine, tag):
        step = trainer.state.epoch
        for name in metrics.keys():
            value = engine.state.metrics[name]
            mlflow.log_metric("{} {}".format(tag, name), value, step=step)

    trainer.add_event_handler(Events.ITERATION_COMPLETED,
                              mlflow_batch_metrics_logging, "train")
    train_evaluator.add_event_handler(Events.COMPLETED,
                                      mlflow_val_metrics_logging, "train")
    evaluator.add_event_handler(Events.COMPLETED, mlflow_val_metrics_logging,
                                "test")

    trainer.run(train_loader, max_epochs=config['num_epochs'])
    tb_logger.close()
Example #4
# Imports assumed by this example (ignite 0.4-era layout); `get_model` and
# `get_train_test_loaders` are project-local helpers that are not shown here.
from pathlib import Path

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim

from ignite.engine import Engine, Events, create_supervised_evaluator
from ignite.handlers import Checkpoint, global_step_from_engine
from ignite.metrics import Accuracy, Loss
from ignite.utils import convert_tensor
from ignite.contrib.engines import common
from ignite.contrib.handlers import PiecewiseLinear, ProgressBar
from ignite.contrib.handlers.tensorboard_logger import (
    TensorboardLogger, OutputHandler, OptimizerParamsHandler,
    GradsHistHandler)


def run(output_path, config):
    device = "cuda"

    local_rank = config["local_rank"]
    distributed = backend is not None
    if distributed:
        torch.cuda.set_device(local_rank)
        device = "cuda"
    rank = dist.get_rank() if distributed else 0

    torch.manual_seed(config["seed"] + rank)

    # Rescale batch_size and num_workers
    ngpus_per_node = torch.cuda.device_count()
    ngpus = dist.get_world_size() if distributed else 1
    batch_size = config["batch_size"] // ngpus
    num_workers = int(
        (config["num_workers"] + ngpus_per_node - 1) / ngpus_per_node)

    train_loader, test_loader = get_train_test_loaders(
        path=config["data_path"],
        batch_size=batch_size,
        distributed=distributed,
        num_workers=num_workers,
    )

    model = get_model(config["model"])
    model = model.to(device)

    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[
                local_rank,
            ],
            output_device=local_rank,
        )

    optimizer = optim.SGD(
        model.parameters(),
        lr=config["learning_rate"],
        momentum=config["momentum"],
        weight_decay=config["weight_decay"],
        nesterov=True,
    )

    criterion = nn.CrossEntropyLoss().to(device)

    le = len(train_loader)
    milestones_values = [
        (0, 0.0),
        (le * config["num_warmup_epochs"], config["learning_rate"]),
        (le * config["num_epochs"], 0.0),
    ]
    lr_scheduler = PiecewiseLinear(optimizer,
                                   param_name="lr",
                                   milestones_values=milestones_values)

    def _prepare_batch(batch, device, non_blocking):
        x, y = batch
        return (
            convert_tensor(x, device=device, non_blocking=non_blocking),
            convert_tensor(y, device=device, non_blocking=non_blocking),
        )

    def process_function(engine, batch):

        x, y = _prepare_batch(batch, device=device, non_blocking=True)

        model.train()
        # Supervised part
        y_pred = model(x)
        loss = criterion(y_pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        return {
            "batch loss": loss.item(),
        }

    trainer = Engine(process_function)
    train_sampler = train_loader.sampler if distributed else None
    to_save = {
        "trainer": trainer,
        "model": model,
        "optimizer": optimizer,
        "lr_scheduler": lr_scheduler,
    }
    metric_names = [
        "batch loss",
    ]
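    # A single ignite helper that wires up the usual training plumbing:
    # setting the distributed sampler's epoch each epoch, checkpointing the
    # `to_save` objects every `checkpoint_every` iterations, stepping the lr
    # scheduler, attaching running averages of the listed outputs, and an
    # optional iteration-level progress bar.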
    common.setup_common_training_handlers(
        trainer,
        train_sampler=train_sampler,
        to_save=to_save,
        save_every_iters=config["checkpoint_every"],
        output_path=output_path,
        lr_scheduler=lr_scheduler,
        output_names=metric_names,
        with_pbar_on_iters=config["display_iters"],
        log_every_iters=10,
    )

    if rank == 0:
        tb_logger = TensorboardLogger(log_dir=output_path)
        tb_logger.attach(
            trainer,
            log_handler=OutputHandler(tag="train", metric_names=metric_names),
            event_name=Events.ITERATION_COMPLETED,
        )
        tb_logger.attach(
            trainer,
            log_handler=OptimizerParamsHandler(optimizer, param_name="lr"),
            event_name=Events.ITERATION_STARTED,
        )

    metrics = {
        "accuracy": Accuracy(device=device if distributed else None),
        "loss": Loss(criterion, device=device if distributed else None),
    }

    evaluator = create_supervised_evaluator(model,
                                            metrics=metrics,
                                            device=device,
                                            non_blocking=True)
    train_evaluator = create_supervised_evaluator(model,
                                                  metrics=metrics,
                                                  device=device,
                                                  non_blocking=True)

    def run_validation(engine):
        torch.cuda.synchronize()
        train_evaluator.run(train_loader)
        evaluator.run(test_loader)

    trainer.add_event_handler(
        Events.EPOCH_STARTED(every=config["validate_every"]), run_validation)
    trainer.add_event_handler(Events.COMPLETED, run_validation)

    if rank == 0:
        if config["display_iters"]:
            ProgressBar(persist=False,
                        desc="Train evaluation").attach(train_evaluator)
            ProgressBar(persist=False,
                        desc="Test evaluation").attach(evaluator)

        tb_logger.attach(
            train_evaluator,
            log_handler=OutputHandler(
                tag="train",
                metric_names=list(metrics.keys()),
                global_step_transform=global_step_from_engine(trainer),
            ),
            event_name=Events.COMPLETED,
        )

        tb_logger.attach(
            evaluator,
            log_handler=OutputHandler(
                tag="test",
                metric_names=list(metrics.keys()),
                global_step_transform=global_step_from_engine(trainer),
            ),
            event_name=Events.COMPLETED,
        )

        # Store the best model by validation accuracy:
        common.save_best_model_by_val_score(
            output_path,
            evaluator,
            model=model,
            metric_name="accuracy",
            n_saved=3,
            trainer=trainer,
            tag="test",
        )

        if config["log_model_grads_every"] is not None:
            tb_logger.attach(
                trainer,
                log_handler=GradsHistHandler(model,
                                             tag=model.__class__.__name__),
                event_name=Events.ITERATION_COMPLETED(
                    every=config["log_model_grads_every"]),
            )
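
    # Optional fault-injection: raise at a chosen iteration so that the
    # resume-from-checkpoint path below can be exercised.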

    if config["crash_iteration"] is not None:

        @trainer.on(Events.ITERATION_STARTED(once=config["crash_iteration"]))
        def _(engine):
            raise Exception("STOP at iteration: {}".format(
                engine.state.iteration))
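
    # Resuming restores everything listed in `to_save` (trainer, model,
    # optimizer, lr scheduler) from the given checkpoint file.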

    resume_from = config["resume_from"]
    if resume_from is not None:
        checkpoint_fp = Path(resume_from)
        assert checkpoint_fp.exists(), "Checkpoint '{}' is not found".format(
            checkpoint_fp.as_posix())
        print("Resume from a checkpoint: {}".format(checkpoint_fp.as_posix()))
        checkpoint = torch.load(checkpoint_fp.as_posix())
        Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    try:
        trainer.run(train_loader, max_epochs=config["num_epochs"])
    except Exception:
        # Log the failure but fall through so the tb_logger is still closed.
        import traceback

        print(traceback.format_exc())

    if rank == 0:
        tb_logger.close()