def test_grads_hist_handler_wrong_setup():

    with pytest.raises(TypeError, match="Argument model should be of type torch.nn.Module"):
        GradsHistHandler(None)

    model = MagicMock(spec=torch.nn.Module)
    wrapper = GradsHistHandler(model)
    mock_logger = MagicMock()
    mock_engine = MagicMock()
    with pytest.raises(RuntimeError, match="Handler 'GradsHistHandler' works only with TensorboardLogger"):
        wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)
def add_tensorboard(engine_train, optimizer, model, log_dir):
    """Creates an ignite logger object and adds training elements such as weight and gradient histograms

    Args:
        engine_train (:obj:`ignite.engine`): the train engine to attach to the logger
        optimizer (:obj:`torch.optim`): the model's optimizer
        model (:obj:`torch.nn.Module`): the model being trained
        log_dir (string): path to where tensorboard data should be saved
    """
    # Create a logger
    tb_logger = TensorboardLogger(log_dir=log_dir)

    # Attach the logger to the trainer to log training loss at each iteration
    tb_logger.attach(engine_train,
                     log_handler=OutputHandler(
                         tag="training",
                         output_transform=lambda loss: {"loss": loss}),
                     event_name=Events.ITERATION_COMPLETED)

    # Attach the logger to the trainer to log the optimizer's parameters, e.g. learning rate, after each epoch
    tb_logger.attach(engine_train,
                     log_handler=OptimizerParamsHandler(optimizer),
                     event_name=Events.EPOCH_COMPLETED)

    # Attach the logger to the trainer to log model's weights as a histogram after each epoch
    tb_logger.attach(engine_train,
                     log_handler=WeightsHistHandler(model),
                     event_name=Events.EPOCH_COMPLETED)

    # Attach the logger to the trainer to log model's gradients as a histogram after each epoch
    tb_logger.attach(engine_train,
                     log_handler=GradsHistHandler(model),
                     event_name=Events.EPOCH_COMPLETED)

    # Return the logger instead of closing it here; closing before training runs would discard
    # all attached logging, so the caller should call tb_logger.close() after training completes
    return tb_logger
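# A minimal usage sketch for add_tensorboard above (not part of the original example): the toy
# model, optimizer, trainer and random data below are placeholder names used only to illustrate
# wiring the logger up before training and closing it once training has finished (assuming
# add_tensorboard returns the logger, as in the version above).
import torch
import torch.nn as nn
from torch.optim import SGD
from torch.utils.data import DataLoader, TensorDataset
from ignite.engine import create_supervised_trainer

demo_model = nn.Linear(10, 2)
demo_optimizer = SGD(demo_model.parameters(), lr=0.01)
demo_trainer = create_supervised_trainer(demo_model, demo_optimizer, nn.CrossEntropyLoss())
demo_loader = DataLoader(TensorDataset(torch.randn(64, 10), torch.randint(0, 2, (64,))), batch_size=8)

tb_logger = add_tensorboard(demo_trainer, demo_optimizer, demo_model, log_dir="tb_logs")
demo_trainer.run(demo_loader, max_epochs=2)
tb_logger.close()  # close only after the run, as noted above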
def test_grads_hist_frozen_layers(dummy_model_factory):
    model = dummy_model_factory(with_grads=True, with_frozen_layer=True)

    wrapper = GradsHistHandler(model)
    mock_logger = MagicMock(spec=TensorboardLogger)
    mock_logger.writer = MagicMock()

    mock_engine = MagicMock()
    mock_engine.state = State()
    mock_engine.state.epoch = 5

    wrapper(mock_engine, mock_logger, Events.EPOCH_STARTED)

    assert mock_logger.writer.add_histogram.call_count == 2
    mock_logger.writer.add_histogram.assert_has_calls(
        [
            call(tag="grads/fc2/weight", values=ANY, global_step=5),
            call(tag="grads/fc2/bias", values=ANY, global_step=5),
        ],
        any_order=True,
    )

    with pytest.raises(AssertionError):
        mock_logger.writer.add_histogram.assert_has_calls(
            [
                call(tag="grads/fc1/weight", values=ANY, global_step=5),
                call(tag="grads/fc1/bias", values=ANY, global_step=5),
            ],
            any_order=True,
        )
def test_grads_hist_handler(dummy_model_factory):
    model = dummy_model_factory(with_grads=True, with_frozen_layer=False)

    def _test(tag=None):
        wrapper = GradsHistHandler(model, tag=tag)
        mock_logger = MagicMock(spec=TensorboardLogger)
        mock_logger.writer = MagicMock()

        mock_engine = MagicMock()
        mock_engine.state = State()
        mock_engine.state.epoch = 5

        wrapper(mock_engine, mock_logger, Events.EPOCH_STARTED)

        tag_prefix = f"{tag}/" if tag else ""

        assert mock_logger.writer.add_histogram.call_count == 4
        mock_logger.writer.add_histogram.assert_has_calls(
            [
                call(tag=tag_prefix + "grads/fc1/weight",
                     values=ANY,
                     global_step=5),
                call(tag=tag_prefix + "grads/fc1/bias",
                     values=ANY,
                     global_step=5),
                call(tag=tag_prefix + "grads/fc2/weight",
                     values=ANY,
                     global_step=5),
                call(tag=tag_prefix + "grads/fc2/bias",
                     values=ANY,
                     global_step=5),
            ],
            any_order=True,
        )

    _test()
    _test(tag="tag")
def test_grads_hist_handler_whitelist(dummy_model_factory):
    model = dummy_model_factory()

    wrapper = GradsHistHandler(model, whitelist=["fc2.weight"])
    mock_logger = MagicMock(spec=TensorboardLogger)
    mock_logger.writer = MagicMock()

    mock_engine = MagicMock()
    mock_engine.state = State()
    mock_engine.state.epoch = 5

    wrapper(mock_engine, mock_logger, Events.EPOCH_STARTED)
    mock_logger.writer.add_histogram.assert_called_once_with(tag="grads/fc2/weight", values=ANY, global_step=5)
    mock_logger.writer.reset_mock()

    wrapper = GradsHistHandler(model, tag="model", whitelist=["fc1"])
    wrapper(mock_engine, mock_logger, Events.EPOCH_STARTED)

    mock_logger.writer.add_histogram.assert_has_calls(
        [
            call(tag="model/grads/fc1/weight", values=ANY, global_step=5),
            call(tag="model/grads/fc1/bias", values=ANY, global_step=5),
        ],
        any_order=True,
    )
    assert mock_logger.writer.add_histogram.call_count == 2
    mock_logger.writer.reset_mock()

    def weight_selector(n, _):
        return "bias" in n

    wrapper = GradsHistHandler(model, tag="model", whitelist=weight_selector)
    wrapper(mock_engine, mock_logger, Events.EPOCH_STARTED)

    mock_logger.writer.add_histogram.assert_has_calls(
        [
            call(tag="model/grads/fc1/bias", values=ANY, global_step=5),
            call(tag="model/grads/fc2/bias", values=ANY, global_step=5),
        ],
        any_order=True,
    )
    assert mock_logger.writer.add_histogram.call_count == 2
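# A sketch of the imports these test snippets rely on, inferred from the names used above; the
# exact layout of the original test module may differ, and dummy_model_factory is a pytest
# fixture provided by that test suite rather than an import.
import pytest
import torch
from unittest.mock import ANY, MagicMock, call

from ignite.engine import Events, State
from ignite.contrib.handlers.tensorboard_logger import GradsHistHandler, TensorboardLogger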
Example #6
    def add_tensorboard_logging(self, logging_dir=None):

        # Add TensorBoard logging
        if logging_dir is None:
            logging_dir = os.path.join(self.config.DIRS.WORKING_DIR, 'tb_logs')
        else:
            logging_dir = os.path.join(logging_dir, 'tb_logs')
        print('Tensorboard logging saving to: {} ...'.format(logging_dir),
              end='')

        self.tb_logger = TensorboardLogger(log_dir=logging_dir)
        # Logging iteration loss
        self.tb_logger.attach_output_handler(
            engine=self.train_engine,
            event_name=Events.ITERATION_COMPLETED,
            tag='training',
            output_transform=lambda loss: {"batch loss": loss})
        # Logging epoch training metrics
        self.tb_logger.attach_output_handler(
            engine=self.train_evaluator,
            event_name=Events.EPOCH_COMPLETED,
            tag="training",
            metric_names=[
                "loss", "accuracy", "precision", "recall", "f1", "topKCatAcc"
            ],
            global_step_transform=global_step_from_engine(self.train_engine),
        )
        # Logging epoch validation metrics
        self.tb_logger.attach_output_handler(
            engine=self.evaluator,
            event_name=Events.EPOCH_COMPLETED,
            tag="validation",
            metric_names=[
                "loss", "accuracy", "precision", "recall", "f1", "topKCatAcc"
            ],
            global_step_transform=global_step_from_engine(self.train_engine),
        )
        # Attach the logger to the trainer to log model's weights as a histogram after each epoch
        self.tb_logger.attach(self.train_engine,
                              event_name=Events.EPOCH_COMPLETED,
                              log_handler=WeightsHistHandler(self.model))
        # Attach the logger to the trainer to log model's gradients as a histogram after each epoch
        self.tb_logger.attach(self.train_engine,
                              event_name=Events.EPOCH_COMPLETED,
                              log_handler=GradsHistHandler(self.model))
        print('Tensorboard Logging...', end='')
        print('done')
    def logging_board(model_name="densenet121"):
        from ignite.contrib.handlers.tensorboard_logger import (
            TensorboardLogger,
            OutputHandler,
            OptimizerParamsHandler,
            GradsHistHandler,
        )

        tb_logger = TensorboardLogger("board/" + model_name)
        tb_logger.attach(
            trainer,
            log_handler=OutputHandler(
                tag="training", output_transform=lambda loss: {"loss": loss}),
            event_name=Events.ITERATION_COMPLETED,
        )

        tb_logger.attach(
            val_evaluator,
            log_handler=OutputHandler(
                tag="validation",
                metric_names=["accuracy", "loss"],
                another_engine=trainer,
            ),
            event_name=Events.EPOCH_COMPLETED,
        )

        tb_logger.attach(
            trainer,
            log_handler=OptimizerParamsHandler(IGTrainer.optimizer),
            event_name=Events.ITERATION_STARTED,
        )

        tb_logger.attach(
            trainer,
            log_handler=GradsHistHandler(IGTrainer.model),
            event_name=Events.EPOCH_COMPLETED,
        )
        tb_logger.close()
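# Note: OutputHandler's `another_engine` argument (used above to align validation points with the
# trainer's global step) was deprecated in later ignite releases in favour of
# `global_step_transform`. A minimal sketch of the equivalent attachment, written as a helper
# function of our own so it can be applied to the objects defined above:
def attach_validation_output_handler(tb_logger, val_evaluator, trainer):
    from ignite.contrib.handlers import global_step_from_engine
    from ignite.contrib.handlers.tensorboard_logger import OutputHandler
    from ignite.engine import Events

    tb_logger.attach(
        val_evaluator,
        log_handler=OutputHandler(
            tag="validation",
            metric_names=["accuracy", "loss"],
            global_step_transform=global_step_from_engine(trainer),
        ),
        event_name=Events.EPOCH_COMPLETED,
    )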
Example #8
def run(train_batch_size, val_batch_size, epochs, lr, momentum, log_dir):
    train_loader, val_loader = get_data_loaders(train_batch_size,
                                                val_batch_size)
    model = Net()
    device = "cpu"

    if torch.cuda.is_available():
        device = "cuda"

    model.to(device)  # Move model before creating optimizer
    optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
    criterion = nn.CrossEntropyLoss()
    trainer = create_supervised_trainer(model,
                                        optimizer,
                                        criterion,
                                        device=device)
    trainer.logger = setup_logger("Trainer")

    if sys.version_info > (3, ):
        from ignite.contrib.metrics.gpu_info import GpuInfo

        try:
            GpuInfo().attach(trainer)
        except RuntimeError:
            print(
                "INFO: By default, in this example it is possible to log GPU information (used memory, utilization). "
                "As there is no pynvml python package installed, GPU information won't be logged. Otherwise, please "
                "install it : `pip install pynvml`")

    metrics = {"accuracy": Accuracy(), "loss": Loss(criterion)}

    train_evaluator = create_supervised_evaluator(model,
                                                  metrics=metrics,
                                                  device=device)
    train_evaluator.logger = setup_logger("Train Evaluator")
    validation_evaluator = create_supervised_evaluator(model,
                                                       metrics=metrics,
                                                       device=device)
    validation_evaluator.logger = setup_logger("Val Evaluator")

    @trainer.on(Events.EPOCH_COMPLETED)
    def compute_metrics(engine):
        train_evaluator.run(train_loader)
        validation_evaluator.run(val_loader)

    tb_logger = TensorboardLogger(log_dir=log_dir)

    tb_logger.attach_output_handler(
        trainer,
        event_name=Events.ITERATION_COMPLETED(every=100),
        tag="training",
        output_transform=lambda loss: {"batchloss": loss},
        metric_names="all",
    )

    for tag, evaluator in [("training", train_evaluator),
                           ("validation", validation_evaluator)]:
        tb_logger.attach_output_handler(
            evaluator,
            event_name=Events.EPOCH_COMPLETED,
            tag=tag,
            metric_names=["loss", "accuracy"],
            global_step_transform=global_step_from_engine(trainer),
        )

    tb_logger.attach_opt_params_handler(
        trainer,
        event_name=Events.ITERATION_COMPLETED(every=100),
        optimizer=optimizer)

    tb_logger.attach(trainer,
                     log_handler=WeightsScalarHandler(model),
                     event_name=Events.ITERATION_COMPLETED(every=100))

    tb_logger.attach(trainer,
                     log_handler=WeightsHistHandler(model),
                     event_name=Events.EPOCH_COMPLETED(every=100))

    tb_logger.attach(trainer,
                     log_handler=GradsScalarHandler(model),
                     event_name=Events.ITERATION_COMPLETED(every=100))

    tb_logger.attach(trainer,
                     log_handler=GradsHistHandler(model),
                     event_name=Events.EPOCH_COMPLETED(every=100))

    def score_function(engine):
        return engine.state.metrics["accuracy"]

    model_checkpoint = ModelCheckpoint(
        log_dir,
        n_saved=2,
        filename_prefix="best",
        score_function=score_function,
        score_name="validation_accuracy",
        global_step_transform=global_step_from_engine(trainer),
    )
    validation_evaluator.add_event_handler(Events.COMPLETED, model_checkpoint,
                                           {"model": model})

    # kick everything off
    trainer.run(train_loader, max_epochs=epochs)

    tb_logger.close()
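# A hypothetical invocation of run() above; the argument values are illustrative, not taken from
# the original example's CLI parsing.
if __name__ == "__main__":
    run(
        train_batch_size=64,
        val_batch_size=1000,
        epochs=10,
        lr=0.01,
        momentum=0.5,
        log_dir="tensorboard_logs",
    )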
    if backend_conf.rank == 0:
        event = Events.ITERATION_COMPLETED(every=hp['log_progress_every_iters'] if hp['log_progress_every_iters'] else None)
        ProgressBar(persist=False, desc='Train evaluation').attach(train_evaluator, event_name=event)
        ProgressBar(persist=False, desc='Test evaluation').attach(valid_evaluator)

        log_handler = OutputHandler(tag='train', metric_names=list(metrics.keys()), global_step_transform=global_step_from_engine(trainer))
        tb_logger.attach(train_evaluator, log_handler=log_handler, event_name=Events.COMPLETED)

        log_handler = OutputHandler(tag='test', metric_names=list(metrics.keys()), global_step_transform=global_step_from_engine(trainer))
        tb_logger.attach(valid_evaluator, log_handler=log_handler, event_name=Events.COMPLETED)

        # Store the best model by validation accuracy:
        common.save_best_model_by_val_score(str(output_path), valid_evaluator, model=model, metric_name='accuracy', n_saved=3, trainer=trainer, tag='val')

        if hp['log_grads_every_iters'] is not None and hp['log_grads_every_iters'] > 0:
            tb_logger.attach(trainer, log_handler=GradsHistHandler(model, tag=model.__class__.__name__), event_name=Events.ITERATION_COMPLETED(every=hp['log_grads_every_iters']))

    if hp['crash_iteration'] is not None and hp['crash_iteration'] >= 0:
        @trainer.on(Events.ITERATION_STARTED(once=hp['crash_iteration']))
        def _(engine):
            raise Exception('STOP at iteration: {}'.format(engine.state.iteration))

    if nni_compression_pruner is not None:
        # Notify the NNI compressor (pruning or quantization) at each epoch and, if needed by the
        # provided Pruner/Quantizer, at each batch iteration (see the NNI Compression documentation:
        # https://nni.readthedocs.io/en/latest/Compressor/QuickStart.html#apis-for-updating-fine-tuning-status)
        @trainer.on(Events.EPOCH_STARTED)
        def _nni_compression_update_epoch(engine):
            nni_compression_pruner.update_epoch(engine.state.epoch)

        if callable(getattr(nni_compression_pruner, 'step', None)):
            @trainer.on(Events.ITERATION_COMPLETED)
            def _nni_compression_batch_step(engine):
                nni_compression_pruner.step()
def train(epochs=500,
          batch_size=32,
          bptt_len=70,
          lr=0.00025,
          log_steps=200,
          clip_grad=0.25,
          log_dir="experiments"):
    ###################################################################
    # Dataset
    ###################################################################
    wt = wikitext103(batch_size=batch_size, bptt_len=bptt_len)
    # wt = wikitext2(batch_size=batch_size, bptt_len=bptt_len)

    ###################################################################
    # Configs
    ###################################################################
    embedding_config = DropEmbedding.Hyperparams(len(wt.text_field.vocab) + 3,
                                                 ninp=512)
    encoder_config = TransformerEncoder.Hyperparams(
        att_num_units=[512, 512, 512, 512, 512, 512], max_ext=384)

    ###################################################################
    # Models
    ###################################################################
    base_embedding = DropEmbedding(embedding_config)
    embedding = TransformerEmbedding(embedding=base_embedding,
                                     max_length=bptt_len,
                                     embedding_size=embedding_config.ninp,
                                     use_positional_embedding=False)
    encoder = TransformerEncoder(encoder_config)
    model = TransformerLanguageModel(embedding, encoder)
    model.init_weight()

    ###################################################################
    # Loss
    ###################################################################
    criterion = lm_criterion(in_features=encoder_config.att_num_units[-1],
                             vocab_size=len(wt.text_field.vocab))

    ###################################################################
    # Parameters + Train ops
    ###################################################################
    parameters = (list(model.parameters()) + list(criterion.parameters()))
    tot_params = 0
    for p in parameters:
        tot_params += reduce(lambda x, y: x * y, p.size())
    print("Total Parameters: ", tot_params)
    opt = optim.Adam(parameters, lr=lr)
    model.to(DEVICE)
    criterion.to(DEVICE)

    ###################################################################
    # Train + Evaluation
    ###################################################################
    def train_step(engine, batch):
        model.train()
        opt.zero_grad()

        text = batch.text.to(DEVICE).t().contiguous()
        target = batch.target.to(DEVICE).t().contiguous()

        out, out_past = model(text, engine.state.train_past)
        engine.state.train_past = out_past
        raw_loss = criterion(out.view(-1, out.size(2)), target.view(-1))
        loss = raw_loss[1]

        loss.backward()
        nn.utils.clip_grad_norm_(parameters, clip_grad)
        opt.step()

        return {"train_loss": loss.item(), "train_ppl": loss.exp().item()}

    def eval_step(engine, batch):
        model.eval()

        if not hasattr(engine.state, "eval_past"):
            engine.state.eval_past = None

        with torch.no_grad():
            text = batch.text.to(DEVICE).t().contiguous()
            target = batch.target.to(DEVICE).t().contiguous()

            out, out_past = model(text, engine.state.eval_past)
            engine.state.eval_past = out_past
            raw_loss = criterion(out.view(-1, out.size(2)), target.view(-1))
            loss = raw_loss[1]

            return {"val_loss": loss.item()}

    train_engine = Engine(train_step)
    eval_engine = Engine(eval_step)

    def reset_state(engine):
        engine.state.train_past = None

    def run_eval(_):
        print("start running eval")
        eval_engine.run(wt.valid_iter)
        metrics = eval_engine.state.metrics
        print("Validation loss: ", metrics["val_loss"], ", ppl: ",
              np.exp(metrics["val_loss"]))

    train_engine.add_event_handler(Events.EPOCH_STARTED, reset_state)
    train_engine.add_event_handler(Events.EPOCH_COMPLETED, run_eval)

    ###################################################################
    # LR Scheduler
    ###################################################################
    cosine_scheduler = CosineAnnealingScheduler(opt.param_groups[0],
                                                "lr",
                                                0.0,
                                                2.5e-4,
                                                cycle_size=len(wt.train_iter))
    warmup_scheduler = create_lr_scheduler_with_warmup(cosine_scheduler, 0.0,
                                                       2.5e-4, 200)
    train_engine.add_event_handler(Events.ITERATION_STARTED, warmup_scheduler)

    ###################################################################
    # Metrics
    ###################################################################
    RunningAverage(output_transform=lambda x: x["train_ppl"]).attach(
        train_engine, "train_ppl")
    RunningAverage(output_transform=lambda x: x["train_loss"]).attach(
        train_engine, "train_loss")
    RunningAverage(output_transform=lambda x: x["val_loss"]).attach(
        eval_engine, "val_loss")
    progress_bar = ProgressBar(persist=True)
    progress_bar.attach(train_engine, ["train_ppl", "train_loss"])
    progress_bar_val = ProgressBar(persist=True)
    progress_bar_val.attach(eval_engine, ["val_loss"])

    ###################################################################
    # Tensorboard
    ###################################################################
    tb_logger = TensorboardLogger(log_dir=log_dir)

    def stepn_logger(num_steps, handler):
        def logger_runner(engine, log_handler, event_name):
            if engine.state.iteration % num_steps == 0:
                handler(engine, log_handler, event_name)

        return logger_runner

    tb_logger.attach(train_engine,
                     log_handler=stepn_logger(
                         log_steps,
                         OutputHandler(tag="training",
                                       output_transform=lambda loss: loss)),
                     event_name=Events.ITERATION_COMPLETED)
    tb_logger.attach(eval_engine,
                     log_handler=OutputHandler(
                         tag="validation",
                         output_transform=lambda loss: loss,
                         another_engine=train_engine),
                     event_name=Events.EPOCH_COMPLETED)
    tb_logger.attach(train_engine,
                     log_handler=stepn_logger(log_steps,
                                              OptimizerParamsHandler(opt)),
                     event_name=Events.ITERATION_STARTED)
    tb_logger.attach(train_engine,
                     log_handler=stepn_logger(log_steps,
                                              WeightsScalarHandler(model)),
                     event_name=Events.ITERATION_COMPLETED)
    tb_logger.attach(train_engine,
                     log_handler=stepn_logger(log_steps,
                                              GradsScalarHandler(model)),
                     event_name=Events.ITERATION_COMPLETED)
    tb_logger.attach(train_engine,
                     log_handler=stepn_logger(500, WeightsHistHandler(model)),
                     event_name=Events.ITERATION_COMPLETED)
    tb_logger.attach(train_engine,
                     log_handler=stepn_logger(500, GradsHistHandler(model)),
                     event_name=Events.ITERATION_COMPLETED)

    try:
        train_engine.run(wt.train_iter, max_epochs=epochs)
    except Exception:
        pass
    finally:
        tb_logger.close()
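# Note: the stepn_logger wrapper above predates ignite's event filtering; with a recent ignite
# release the same "log every N iterations" behaviour can be expressed directly via
# Events.ITERATION_COMPLETED(every=N), as other examples in this listing do. A minimal sketch,
# written as a helper of our own, assuming the same logger/engine/model objects:
def attach_weight_histograms_every(tb_logger, train_engine, model, every=500):
    from ignite.contrib.handlers.tensorboard_logger import WeightsHistHandler
    from ignite.engine import Events

    tb_logger.attach(train_engine,
                     log_handler=WeightsHistHandler(model),
                     event_name=Events.ITERATION_COMPLETED(every=every))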
Example #11
    event_name=Events.EPOCH_COMPLETED,
    tag="validation",
    metric_names=["loss", "accuracy", "precision", "recall", "f1", "topKCatAcc"],
    global_step_transform=global_step_from_engine(trainer),
)
# Attach the logger to the trainer to log model's weights as a histogram after each epoch
tb_logger.attach(
    trainer,
    event_name=Events.EPOCH_COMPLETED,
    log_handler=WeightsHistHandler(model)
)
# Attach the logger to the trainer to log model's gradients as a histogram after each epoch
tb_logger.attach(
    trainer,
    event_name=Events.EPOCH_COMPLETED,
    log_handler=GradsHistHandler(model)
)
print('Tensorboard Logging...', end='')
print('done')

## SETUP CALLBACKS
print('[INFO] Creating callback functions for training loop...', end='')
# Early Stopping - stops training if the validation loss does not decrease after 5 epochs
handler = EarlyStopping(patience=early_stopping_patience, score_function=score_function_loss, trainer=trainer)
evaluator.add_event_handler(Events.COMPLETED, handler)
print('Early Stopping ({} epochs)...'.format(early_stopping_patience), end='')

# Checkpoint the model
# iteration checkpointer
checkpointer = ModelCheckpoint(
    dirname=working_dir, 
Example #12
def run(output_path, config):

    distributed = dist.is_available() and dist.is_initialized()
    rank = dist.get_rank() if distributed else 0

    manual_seed(config["seed"] + rank)

    # Setup dataflow, model, optimizer, criterion
    train_loader, test_loader = utils.get_dataflow(config, distributed)
    model, optimizer = utils.get_model_optimizer(config, distributed)
    criterion = nn.CrossEntropyLoss().to(utils.device)

    le = len(train_loader)
    milestones_values = [
        (0, 0.0),
        (le * config["num_warmup_epochs"], config["learning_rate"]),
        (le * config["num_epochs"], 0.0),
    ]
    lr_scheduler = PiecewiseLinear(optimizer, param_name="lr", milestones_values=milestones_values)

    # Setup Ignite trainer:
    # - let's define training step
    # - add other common handlers:
    #    - TerminateOnNan,
    #    - handler to setup learning rate scheduling,
    #    - ModelCheckpoint
    #    - RunningAverage` on `train_step` output
    #    - Two progress bars on epochs and optionally on iterations

    def train_step(engine, batch):

        x = convert_tensor(batch[0], device=utils.device, non_blocking=True)
        y = convert_tensor(batch[1], device=utils.device, non_blocking=True)

        model.train()
        # Supervised part
        y_pred = model(x)
        loss = criterion(y_pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        return {
            "batch loss": loss.item(),
        }

    if config["deterministic"] and rank == 0:
        print("Setup deterministic trainer")
    trainer = Engine(train_step) if not config["deterministic"] else DeterministicEngine(train_step)
    train_sampler = train_loader.sampler if distributed else None
    to_save = {"trainer": trainer, "model": model, "optimizer": optimizer, "lr_scheduler": lr_scheduler}
    metric_names = [
        "batch loss",
    ]
    common.setup_common_training_handlers(
        trainer,
        train_sampler=train_sampler,
        to_save=to_save,
        save_every_iters=config["checkpoint_every"],
        output_path=output_path,
        lr_scheduler=lr_scheduler,
        output_names=metric_names,
        with_pbar_on_iters=config["display_iters"],
        log_every_iters=10,
    )

    if rank == 0:
        # Setup Tensorboard logger - wrapper on SummaryWriter
        tb_logger = TensorboardLogger(log_dir=output_path)
        # Attach logger to the trainer and log trainer's metrics (stored in trainer.state.metrics) every iteration
        tb_logger.attach(
            trainer,
            log_handler=OutputHandler(tag="train", metric_names=metric_names),
            event_name=Events.ITERATION_COMPLETED,
        )
        # log optimizer's parameters: "lr" every iteration
        tb_logger.attach(
            trainer, log_handler=OptimizerParamsHandler(optimizer, param_name="lr"), event_name=Events.ITERATION_STARTED
        )

    # Let's now setup evaluator engine to perform model's validation and compute metrics
    metrics = {
        "accuracy": Accuracy(device=utils.device if distributed else None),
        "loss": Loss(criterion, device=utils.device if distributed else None),
    }

    # We define two evaluators as they won't have exactly the same role:
    # - `evaluator` will save the best model based on validation score
    evaluator = create_supervised_evaluator(model, metrics=metrics, device=utils.device, non_blocking=True)
    train_evaluator = create_supervised_evaluator(model, metrics=metrics, device=utils.device, non_blocking=True)

    def run_validation(engine):
        train_evaluator.run(train_loader)
        evaluator.run(test_loader)

    trainer.add_event_handler(Events.EPOCH_STARTED(every=config["validate_every"]), run_validation)
    trainer.add_event_handler(Events.COMPLETED, run_validation)

    if rank == 0:
        # Setup progress bar on evaluation engines
        if config["display_iters"]:
            ProgressBar(persist=False, desc="Train evaluation").attach(train_evaluator)
            ProgressBar(persist=False, desc="Test evaluation").attach(evaluator)

        # Let's log metrics of `train_evaluator` stored in `train_evaluator.state.metrics` when validation run is done
        tb_logger.attach(
            train_evaluator,
            log_handler=OutputHandler(
                tag="train", metric_names="all", global_step_transform=global_step_from_engine(trainer)
            ),
            event_name=Events.COMPLETED,
        )

        # Let's log metrics of `evaluator` stored in `evaluator.state.metrics` when validation run is done
        tb_logger.attach(
            evaluator,
            log_handler=OutputHandler(
                tag="test", metric_names="all", global_step_transform=global_step_from_engine(trainer)
            ),
            event_name=Events.COMPLETED,
        )

        # Store 3 best models by validation accuracy:
        common.save_best_model_by_val_score(
            output_path, evaluator, model=model, metric_name="accuracy", n_saved=3, trainer=trainer, tag="test"
        )

        # Optionally log model gradients
        if config["log_model_grads_every"] is not None:
            tb_logger.attach(
                trainer,
                log_handler=GradsHistHandler(model, tag=model.__class__.__name__),
                event_name=Events.ITERATION_COMPLETED(every=config["log_model_grads_every"]),
            )

    # In order to check training resuming we can emulate a crash
    if config["crash_iteration"] is not None:

        @trainer.on(Events.ITERATION_STARTED(once=config["crash_iteration"]))
        def _(engine):
            raise Exception("STOP at iteration: {}".format(engine.state.iteration))

    resume_from = config["resume_from"]
    if resume_from is not None:
        checkpoint_fp = Path(resume_from)
        assert checkpoint_fp.exists(), "Checkpoint '{}' is not found".format(checkpoint_fp.as_posix())
        print("Resume from a checkpoint: {}".format(checkpoint_fp.as_posix()))
        checkpoint = torch.load(checkpoint_fp.as_posix())
        Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    try:
        trainer.run(train_loader, max_epochs=config["num_epochs"])
    except Exception as e:
        import traceback

        print(traceback.format_exc())

    if rank == 0:
        tb_logger.close()
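# A hypothetical config dict for run() above, listing the keys the function itself reads; the
# values are illustrative only, and utils.get_dataflow / utils.get_model_optimizer likely expect
# additional keys that are not visible in this snippet.
example_config = {
    "seed": 12,
    "num_warmup_epochs": 4,
    "num_epochs": 24,
    "learning_rate": 0.01,
    "deterministic": False,
    "checkpoint_every": 200,
    "display_iters": True,
    "validate_every": 3,
    "log_model_grads_every": None,
    "crash_iteration": None,
    "resume_from": None,
}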
Example #13
def run(output_path, config):
    device = "cuda"

    local_rank = config["local_rank"]
    distributed = backend is not None
    if distributed:
        torch.cuda.set_device(local_rank)
        device = "cuda"
    rank = dist.get_rank() if distributed else 0

    torch.manual_seed(config["seed"] + rank)

    # Rescale batch_size and num_workers
    ngpus_per_node = torch.cuda.device_count()
    ngpus = dist.get_world_size() if distributed else 1
    batch_size = config["batch_size"] // ngpus
    num_workers = int(
        (config["num_workers"] + ngpus_per_node - 1) / ngpus_per_node)

    train_loader, test_loader = get_train_test_loaders(
        path=config["data_path"],
        batch_size=batch_size,
        distributed=distributed,
        num_workers=num_workers,
    )

    model = get_model(config["model"])
    model = model.to(device)

    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[
                local_rank,
            ],
            output_device=local_rank,
        )

    optimizer = optim.SGD(
        model.parameters(),
        lr=config["learning_rate"],
        momentum=config["momentum"],
        weight_decay=config["weight_decay"],
        nesterov=True,
    )

    criterion = nn.CrossEntropyLoss().to(device)

    le = len(train_loader)
    milestones_values = [
        (0, 0.0),
        (le * config["num_warmup_epochs"], config["learning_rate"]),
        (le * config["num_epochs"], 0.0),
    ]
    lr_scheduler = PiecewiseLinear(optimizer,
                                   param_name="lr",
                                   milestones_values=milestones_values)

    def _prepare_batch(batch, device, non_blocking):
        x, y = batch
        return (
            convert_tensor(x, device=device, non_blocking=non_blocking),
            convert_tensor(y, device=device, non_blocking=non_blocking),
        )

    def process_function(engine, batch):

        x, y = _prepare_batch(batch, device=device, non_blocking=True)

        model.train()
        # Supervised part
        y_pred = model(x)
        loss = criterion(y_pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        return {
            "batch loss": loss.item(),
        }

    trainer = Engine(process_function)
    train_sampler = train_loader.sampler if distributed else None
    to_save = {
        "trainer": trainer,
        "model": model,
        "optimizer": optimizer,
        "lr_scheduler": lr_scheduler,
    }
    metric_names = [
        "batch loss",
    ]
    common.setup_common_training_handlers(
        trainer,
        train_sampler=train_sampler,
        to_save=to_save,
        save_every_iters=config["checkpoint_every"],
        output_path=output_path,
        lr_scheduler=lr_scheduler,
        output_names=metric_names,
        with_pbar_on_iters=config["display_iters"],
        log_every_iters=10,
    )

    if rank == 0:
        tb_logger = TensorboardLogger(log_dir=output_path)
        tb_logger.attach(
            trainer,
            log_handler=OutputHandler(tag="train", metric_names=metric_names),
            event_name=Events.ITERATION_COMPLETED,
        )
        tb_logger.attach(
            trainer,
            log_handler=OptimizerParamsHandler(optimizer, param_name="lr"),
            event_name=Events.ITERATION_STARTED,
        )

    metrics = {
        "accuracy": Accuracy(device=device if distributed else None),
        "loss": Loss(criterion, device=device if distributed else None),
    }

    evaluator = create_supervised_evaluator(model,
                                            metrics=metrics,
                                            device=device,
                                            non_blocking=True)
    train_evaluator = create_supervised_evaluator(model,
                                                  metrics=metrics,
                                                  device=device,
                                                  non_blocking=True)

    def run_validation(engine):
        torch.cuda.synchronize()
        train_evaluator.run(train_loader)
        evaluator.run(test_loader)

    trainer.add_event_handler(
        Events.EPOCH_STARTED(every=config["validate_every"]), run_validation)
    trainer.add_event_handler(Events.COMPLETED, run_validation)

    if rank == 0:
        if config["display_iters"]:
            ProgressBar(persist=False,
                        desc="Train evaluation").attach(train_evaluator)
            ProgressBar(persist=False,
                        desc="Test evaluation").attach(evaluator)

        tb_logger.attach(
            train_evaluator,
            log_handler=OutputHandler(
                tag="train",
                metric_names=list(metrics.keys()),
                global_step_transform=global_step_from_engine(trainer),
            ),
            event_name=Events.COMPLETED,
        )

        tb_logger.attach(
            evaluator,
            log_handler=OutputHandler(
                tag="test",
                metric_names=list(metrics.keys()),
                global_step_transform=global_step_from_engine(trainer),
            ),
            event_name=Events.COMPLETED,
        )

        # Store the best model by validation accuracy:
        common.save_best_model_by_val_score(
            output_path,
            evaluator,
            model=model,
            metric_name="accuracy",
            n_saved=3,
            trainer=trainer,
            tag="test",
        )

        if config["log_model_grads_every"] is not None:
            tb_logger.attach(
                trainer,
                log_handler=GradsHistHandler(model,
                                             tag=model.__class__.__name__),
                event_name=Events.ITERATION_COMPLETED(
                    every=config["log_model_grads_every"]),
            )

    if config["crash_iteration"] is not None:

        @trainer.on(Events.ITERATION_STARTED(once=config["crash_iteration"]))
        def _(engine):
            raise Exception("STOP at iteration: {}".format(
                engine.state.iteration))

    resume_from = config["resume_from"]
    if resume_from is not None:
        checkpoint_fp = Path(resume_from)
        assert checkpoint_fp.exists(), "Checkpoint '{}' is not found".format(
            checkpoint_fp.as_posix())
        print("Resume from a checkpoint: {}".format(checkpoint_fp.as_posix()))
        checkpoint = torch.load(checkpoint_fp.as_posix())
        Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    try:
        trainer.run(train_loader, max_epochs=config["num_epochs"])
    except Exception as e:
        import traceback

        print(traceback.format_exc())

    if rank == 0:
        tb_logger.close()
def train(experiment_id, ds_train, ds_val, model, optimizer, hyperparams,
          num_workers, device, debug):

    train_loader = torch.utils.data.DataLoader(ds_train,
                                               batch_size=hyperparams['bs'],
                                               shuffle=True,
                                               num_workers=num_workers)
    val_loader = torch.utils.data.DataLoader(ds_val,
                                             batch_size=hyperparams['bs'],
                                             shuffle=True,
                                             num_workers=num_workers)

    criterion = nn.CrossEntropyLoss().to(device)

    metrics = {
        'loss': Loss(criterion),
        'accuracy': Accuracy(),
    }

    trainer = create_supervised_trainer(model, optimizer, criterion, device)

    if hyperparams['pretrained']:

        @trainer.on(Events.EPOCH_STARTED)
        def turn_on_layers(engine):
            epoch = engine.state.epoch
            if epoch == 1:
                print()
                temp = next(model.named_children())[1]
                for name, child in temp.named_children():
                    if (name == 'mlp') or (name == 'classifier'):
                        print(name + ' is unfrozen')
                        for param in child.parameters():
                            param.requires_grad = True
                    else:
                        for param in child.parameters():
                            param.requires_grad = False

            if epoch == 3:
                print()
                print('Turn on all the layers')
                for name, child in model.named_children():
                    for param in child.parameters():
                        param.requires_grad = True

    pbar = ProgressBar(bar_format='')
    pbar.attach(trainer, output_transform=lambda x: {'loss': x})

    val_evaluator = create_supervised_evaluator(model, metrics, device)

    if hyperparams['early_stopping']:

        def score_function(engine):
            return engine.state.metrics['accuracy']

        handler = EarlyStopping(patience=hyperparams['patience'],
                                score_function=score_function,
                                trainer=trainer)
        val_evaluator.add_event_handler(Events.COMPLETED, handler)

    @trainer.on(Events.STARTED)
    @trainer.on(Events.EPOCH_COMPLETED)
    def compute_and_display_val_metrics(engine):
        epoch = engine.state.epoch
        metrics = val_evaluator.run(val_loader).metrics

        if (epoch == 0) or (metrics['accuracy'] > engine.state.best_acc):
            engine.state.best_acc = metrics['accuracy']
            print('New best accuracy! Accuracy: ' +
                  str(engine.state.best_acc) + '\nModel saved!')
            if not os.path.exists('models/'):
                os.makedirs('models/')
            path = 'models/best_model_' + experiment_id + '.pth'
            torch.save(model.state_dict(), path)

        print('Validation Results - Epoch: {} | '
              'Average Loss: {:.4f} | Accuracy: {:.4f}'.format(
                  engine.state.epoch, metrics['loss'], metrics['accuracy']))

    if hyperparams['scheduler']:
        lr_scheduler = CosineAnnealingLR(optimizer,
                                         hyperparams['nb_epochs'],
                                         eta_min=hyperparams['lr'] / 100,
                                         last_epoch=-1)

        @trainer.on(Events.EPOCH_COMPLETED)
        def update_lr_scheduler(engine):
            lr_scheduler.step()

    tb_logger = TensorboardLogger('board/' + experiment_id)

    def output_transform(loss):
        return {'loss': loss}

    log_handler = OutputHandler(tag='training',
                                output_transform=output_transform)
    tb_logger.attach(trainer,
                     log_handler,
                     event_name=Events.ITERATION_COMPLETED)
    log_handler = OutputHandler(tag='validation',
                                metric_names=['accuracy', 'loss'],
                                another_engine=trainer)
    tb_logger.attach(val_evaluator, log_handler, event_name=Events.STARTED)
    tb_logger.attach(val_evaluator,
                     log_handler,
                     event_name=Events.EPOCH_COMPLETED)
    tb_logger.attach(trainer,
                     log_handler=OptimizerParamsHandler(optimizer),
                     event_name=Events.ITERATION_STARTED)
    tb_logger.attach(trainer,
                     log_handler=GradsHistHandler(model),
                     event_name=Events.EPOCH_COMPLETED)
    trainer.run(train_loader, max_epochs=hyperparams['nb_epochs'])

    # Close the logger only after training has finished; closing it beforehand would discard
    # everything logged during the run
    tb_logger.close()
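# A hypothetical hyperparams dict for train() above, listing the keys the function reads; the
# values are illustrative only.
hyperparams = {
    'bs': 32,
    'pretrained': True,
    'early_stopping': True,
    'patience': 5,
    'scheduler': True,
    'nb_epochs': 50,
    'lr': 1e-3,
}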