def test_remove_event_handler_on_callable_events():
    """Handlers added on plain or filtered events can be removed again."""
    engine = Engine(lambda e, b: 1)

    def _add_then_remove(add_event, remove_event):
        # Use a fresh handler each round-trip so has_event_handler checks
        # stay unambiguous.
        def handler(e):
            pass

        assert not engine.has_event_handler(handler)
        engine.add_event_handler(add_event, handler)
        assert engine.has_event_handler(handler)
        engine.remove_event_handler(handler, remove_event)
        assert not engine.has_event_handler(handler)

    # Plain event added and removed with the same event.
    _add_then_remove(Events.EPOCH_STARTED, Events.EPOCH_STARTED)
    # Filtered event removed via the unfiltered base event.
    _add_then_remove(Events.EPOCH_COMPLETED(every=3), Events.EPOCH_COMPLETED)
    # Filtered event removed via an equally-filtered event.
    _add_then_remove(Events.EPOCH_COMPLETED(every=3),
                     Events.EPOCH_COMPLETED(every=3))
Beispiel #2
0
def setup_event_handler(trainer, evaluator, train_loader, test_loader):
    """Attach per-epoch loss logging plus periodic train/test evaluation,
    mirroring every value to a TensorBoard SummaryWriter."""
    log_interval = 10

    # NOTE(review): `log_dir` is not defined in this function — presumably a
    # module-level global; confirm before reuse.
    writer = SummaryWriter(log_dir=log_dir)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_loss(trainer):
        epoch = trainer.state.epoch
        loss = trainer.state.output
        print("Epoch[{}] Loss: {:.5f}".format(epoch, loss))
        writer.add_scalar("training_iteration_loss", loss, epoch)

    @trainer.on(Events.EPOCH_COMPLETED(every=log_interval))
    def log_training_results(trainer):
        # Evaluate on the training set every `log_interval` epochs.
        evaluator.run(train_loader)
        metrics = evaluator.state.metrics
        epoch = trainer.state.epoch
        print("Training Results - Epoch: {}  Accuracy: {:.5f} Loss: {:.5f}".
              format(epoch, metrics["accuracy"], metrics["nll"]))
        writer.add_scalar("training_loss", metrics["nll"], epoch)
        writer.add_scalar("training_accuracy", metrics["accuracy"], epoch)

    @trainer.on(Events.EPOCH_COMPLETED(every=log_interval))
    def log_testing_results(trainer):
        # Evaluate on the held-out set at the same cadence.
        evaluator.run(test_loader)
        metrics = evaluator.state.metrics
        epoch = trainer.state.epoch
        print("Validation Results - Epoch: {}  Accuracy: {:.5f} Loss: {:.5f}".
              format(epoch, metrics["accuracy"], metrics["nll"]))
        writer.add_scalar("testing_loss", metrics["nll"], epoch)
        writer.add_scalar("testing_accuracy", metrics["accuracy"], epoch)
Beispiel #3
0
def main(width, depth, max_epochs, state_dict_path, device, data_dir, num_workers):
    """
    This function constructs and trains a model from scratch, without any knowledge transfer method applied.

    :param int depth: factor for controlling the depth of the model.
    :param int width: factor for controlling the width of the model.
    :param int max_epochs: maximum number of epochs for training the student model.
    :param string state_dict_path: path to save the trained model.
    :param int device: device to use for training the model.
    :param string data_dir: directory to save and load the dataset.
    :param int num_workers: number of workers to use for loading the dataset.
    """

    # Define the device for training the model.
    device = torch.device(device)

    # Get data loaders for the CIFAR-10 dataset.
    train_loader, validation_loader, test_loader = get_cifar10_loaders(
        data_dir, batch_size=BATCH_SIZE, num_workers=num_workers
    )

    # Construct the model to be trained.
    model = WideResidualNetwork(depth=depth, width=width)
    model = model.to(device)

    # Define optimizer and learning rate scheduler.
    optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=LEARNING_RATE_DECAY_MILESTONES, gamma=LEARNING_RATE_DECAY_FACTOR
    )

    # Construct the loss function to be used for training.
    criterion = torch.nn.CrossEntropyLoss()

    # Define the ignite engines for training and evaluation.
    batch_updater = BatchUpdaterWithoutTransfer(model=model, optimizer=optimizer, criterion=criterion, device=device)
    batch_evaluator = BatchEvaluator(model=model, device=device)
    trainer = Engine(batch_updater)
    evaluator = Engine(batch_evaluator)

    # Define and attach the progress bar, loss metric, and the accuracy metrics.
    attach_pbar_and_metrics(trainer, evaluator)

    # The training engine updates the learning rate schedule at end of each epoch.
    lr_updater = LearningRateUpdater(lr_scheduler=lr_scheduler)
    trainer.on(Events.EPOCH_COMPLETED(every=1))(lr_updater)

    # The training engine logs the training and the evaluation metrics at end of each epoch.
    metric_logger = MetricLogger(evaluator=evaluator, eval_loader=validation_loader)
    trainer.on(Events.EPOCH_COMPLETED(every=1))(metric_logger)

    # Train the model
    trainer.run(train_loader, max_epochs=max_epochs)

    # Save the model to pre-defined path. We move the model to CPU which is desirable as the default device
    # for loading the model.
    model.cpu()
    # BUG FIX: use os.path.dirname instead of manual "/" splitting. The old
    # code yielded "" for a bare filename, and os.makedirs("") raises
    # FileNotFoundError; it also broke on Windows path separators.
    state_dict_dir = os.path.dirname(state_dict_path)
    if state_dict_dir:
        os.makedirs(state_dict_dir, exist_ok=True)
    torch.save(model.state_dict(), state_dict_path)
Beispiel #4
0
def test_remove_event_handler_on_callable_events():
    """Handlers are removable via plain events; filtered events are rejected."""
    engine = Engine(lambda e, b: 1)

    def foo(e):
        pass

    assert not engine.has_event_handler(foo)

    engine.add_event_handler(Events.EPOCH_STARTED, foo)
    assert engine.has_event_handler(foo)
    engine.remove_event_handler(foo, Events.EPOCH_STARTED)
    assert not engine.has_event_handler(foo)

    def bar(e):
        pass

    engine.add_event_handler(Events.EPOCH_COMPLETED(every=3), bar)
    assert engine.has_event_handler(bar)
    engine.remove_event_handler(bar, Events.EPOCH_COMPLETED)
    # BUG FIX: this assertion previously re-checked `foo` (removed above, so
    # it always passed) instead of verifying that `bar` was removed.
    assert not engine.has_event_handler(bar)

    # Passing a filtered event to remove_event_handler is not allowed.
    with pytest.raises(
            TypeError,
            match=r"Argument event_name should not be a filtered event"):
        engine.remove_event_handler(bar, Events.EPOCH_COMPLETED(every=3))
def test_custom_events_asserts():
    """register_events rejects non str/EventEnum values; equality semantics
    of plain vs filtered Events."""
    # Dummy engine
    engine = Engine(lambda engine, batch: 0)

    class A:
        pass

    # Every invalid argument combination must raise the same TypeError.
    for bad_args in ((None,), ("str", None), (1,), (A(),)):
        with pytest.raises(
                TypeError,
                match=r"Value at \d of event_names should be a str or EventEnum"):
            engine.register_events(*bad_args)

    assert Events.EPOCH_COMPLETED != 1
    assert Events.EPOCH_COMPLETED != "abc"
    assert Events.ITERATION_COMPLETED != Events.EPOCH_COMPLETED
    assert Events.ITERATION_COMPLETED != Events.EPOCH_COMPLETED(every=2)
    # In current implementation, EPOCH_COMPLETED and EPOCH_COMPLETED with event filter are the same
    assert Events.EPOCH_COMPLETED == Events.EPOCH_COMPLETED(every=2)
    assert Events.ITERATION_COMPLETED == Events.ITERATION_COMPLETED(every=2)
Beispiel #6
0
 def init(self):
     """Copy the run's config next to the checkpoints and register periodic
     checkpoint saving on the wrapped engine."""
     # The frame must expose an engine to attach handlers to.
     assert 'engine' in self.frame, 'The frame does not have engine.'
     # Keep a copy of the config file alongside the saved checkpoints.
     shutil.copy(self.frame.config_path, self.save_handler.dirname)
     checkpoint = self.Checkpoint(self.modules, self.frame)
     # Invoke self every `save_interval` epochs with the checkpoint payload...
     self.frame['engine'].engine.add_event_handler(
         Events.EPOCH_COMPLETED(every=self.save_interval), self,
         {'checkpoint': checkpoint})
     # ...and run _correct_checkpoint at the same cadence.
     # NOTE(review): _correct_checkpoint presumably post-processes the saved
     # file — confirm its contract before changing the ordering here.
     self.frame['engine'].engine.add_event_handler(
         Events.EPOCH_COMPLETED(every=self.save_interval),
         self._correct_checkpoint)
Beispiel #7
0
def test_state_get_event_attrib_value():
    """get_event_attrib_value maps each event to the matching State counter."""
    state = State()
    state.iteration = 10
    state.epoch = 9

    # Iteration-scoped events (plain and filtered) read state.iteration.
    iteration_events = (
        Events.ITERATION_STARTED,
        Events.ITERATION_COMPLETED,
        Events.ITERATION_STARTED(every=10),
        Events.ITERATION_COMPLETED(every=10),
    )
    for e in iteration_events:
        assert state.get_event_attrib_value(e) == state.iteration

    # Epoch-scoped and run-scoped events (plain and filtered) read state.epoch.
    epoch_events = (
        Events.EPOCH_STARTED,
        Events.EPOCH_COMPLETED,
        Events.STARTED,
        Events.COMPLETED,
        Events.EPOCH_STARTED(once=5),
        Events.EPOCH_COMPLETED(once=5),
    )
    for e in epoch_events:
        assert state.get_event_attrib_value(e) == state.epoch
def create_trainer(model, optimizer, loss_fn, lr_scheduler, config):
    """Build the training Engine: one optimizer/scheduler step per batch,
    with rank-0-only checkpointing and a progress bar."""

    def train_step(engine, batch):
        # Move the batch to the distributed device.
        inputs = batch[0].to(idist.device())
        targets = batch[1].to(idist.device())

        model.train()
        predictions = model(inputs)
        loss = loss_fn(predictions, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        return loss.item()

    trainer = Engine(train_step)

    # Checkpointing and progress reporting only on the rank-0 process.
    if idist.get_rank() == 0:

        @trainer.on(Events.EPOCH_COMPLETED(every=1))
        def save_checkpoint():
            model_path = os.path.join((config.get("output_path", "output")),
                                      "checkpoint.pt")
            torch.save(model.state_dict(), model_path)

        # Progress bar showing the per-batch loss value.
        ProgressBar().attach(trainer,
                             output_transform=lambda x: {"batch loss": x})

    return trainer
def get_event_by_freq(freq: Union[int, Epochs, Iters]):
    """Translate a frequency spec into a filtered ignite event.

    A bare ``int`` is interpreted as a number of epochs. Returns
    ``Events.EPOCH_COMPLETED(every=n)`` for ``Epochs`` and
    ``Events.ITERATION_COMPLETED(every=n)`` for ``Iters``.

    :raises TypeError: if ``freq`` is none of int, Epochs, or Iters.
    """
    if isinstance(freq, int):
        # Plain integers default to an epoch-based frequency.
        freq = Epochs(freq)
    if isinstance(freq, Epochs):
        return Events.EPOCH_COMPLETED(every=freq.n)
    elif isinstance(freq, Iters):
        return Events.ITERATION_COMPLETED(every=freq.n)
    # BUG FIX: previously fell through and silently returned None for
    # unsupported types; fail loudly instead.
    raise TypeError(
        "freq must be an int, Epochs or Iters, got %s" % type(freq).__name__)
Beispiel #10
0
def configure_checkpoint_saving(trainer, evaluator, model, optimizer, args):
    """Attach periodic and best-score checkpointing to the trainer."""
    objects = {"model": model, "optimizer": optimizer}
    disk_saver = DiskSaver(str(args.output_dir),
                           create_dir=False,
                           require_empty=False)

    # Periodic epoch checkpoints: every epoch in dev mode, otherwise at most
    # every 5 epochs.
    every = 1 if args.dev_mode else min(5, args.max_epochs)
    periodic_checkpoint = Checkpoint(
        objects,
        disk_saver,
        n_saved=None,
        global_step_transform=lambda *_: trainer.state.epoch)
    trainer.add_event_handler(Events.EPOCH_COMPLETED(every=every),
                              periodic_checkpoint, evaluator)

    # "Best score" checkpoints keyed on the evaluator's accuracy metric.
    score_key = "accuracy"
    best_checkpoint = Checkpoint(
        objects,
        disk_saver,
        score_name=score_key,
        score_function=lambda engine: engine.state.metrics[score_key],
        filename_prefix="best")
    trainer.add_event_handler(Events.EPOCH_COMPLETED, best_checkpoint,
                              evaluator)
Beispiel #11
0
def training(local_rank, config):
    """Per-process training entry point: build dataflow, model and engines,
    evaluate every 3 epochs, and log to TensorBoard on rank 0."""
    # Dataflow and model/optimizer setup.
    train_loader, val_loader = get_dataflow(config)
    model, optimizer, criterion, lr_scheduler = initialize(config)

    # Trainer and accuracy evaluator.
    trainer = create_trainer(model, optimizer, criterion, lr_scheduler, config)
    evaluator = create_supervised_evaluator(
        model, metrics={"accuracy": Accuracy()}, device=idist.device())

    # Run model evaluation every 3 epochs; only rank 0 prints the metrics.
    @trainer.on(Events.EPOCH_COMPLETED(every=3))
    def evaluate_model():
        eval_state = evaluator.run(val_loader)
        if idist.get_rank() == 0:
            print(eval_state.metrics)

    # TensorBoard experiment tracking on rank 0 only.
    if idist.get_rank() == 0:
        tb_logger = common.setup_tb_logging(
            config.get("output_path", "output"),
            trainer,
            optimizer,
            evaluators={"validation": evaluator},
        )

    trainer.run(train_loader, max_epochs=config.get("max_epochs", 3))

    if idist.get_rank() == 0:
        tb_logger.close()
Beispiel #12
0
 def attach(self, engine):
     """Register self on `engine` every `self.interval` epochs or
     iterations, depending on `self.epoch_level`."""
     # Pick the base event by level, then apply the interval filter once.
     base_event = (Events.EPOCH_COMPLETED
                   if self.epoch_level else Events.ITERATION_COMPLETED)
     engine.add_event_handler(base_event(every=self.interval), self)
Beispiel #13
0
def run(cfg, train_loader, tr_comp, saver, trainer, valid_dict):
    """Wire checkpointing, timing, running-average metrics and logging
    handlers onto `trainer`, then train for cfg.TRAIN.MAX_EPOCHS epochs.

    :param cfg: config object (reads TRAIN.LOG_ITER_PERIOD,
        EVAL.EPOCH_PERIOD, TRAIN.MAX_EPOCHS).
    :param train_loader: training DataLoader.
    :param tr_comp: training-component bundle (scheduler, loss map, xent, ...).
    :param saver: object exposing `model_dir` used for checkpoints.
    :param trainer: ignite Engine executing the training step.
    :param valid_dict: validation datasets forwarded to eval_multi_dataset.
    """
    # TODO resume

    # trainer = Engine(...)
    # trainer.load_state_dict(state_dict)
    # trainer.run(data)
    # checkpoint
    # Keep the 3 most recent checkpoints of tr_comp's state.
    handler = ModelCheckpoint(saver.model_dir, 'train', n_saved=3, create_dir=True)
    checkpoint_params = tr_comp.state_dict()
    trainer.add_event_handler(Events.EPOCH_COMPLETED,
                              handler,
                              checkpoint_params)

    # Average per-iteration wall time across each epoch.
    timer = Timer(average=True)
    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)
    # average metric to attach on trainer
    names = ["Acc", "Loss"]
    names.extend(tr_comp.loss_function_map.keys())
    for n in names:
        RunningAverage(output_transform=Run(n)).attach(trainer, n)

    # Step the LR scheduler at the end of every epoch.
    @trainer.on(Events.EPOCH_COMPLETED)
    def adjust_learning_rate(engine):
        tr_comp.scheduler.step()

    # Log running metrics every LOG_ITER_PERIOD iterations.
    @trainer.on(Events.ITERATION_COMPLETED(every=cfg.TRAIN.LOG_ITER_PERIOD))
    def log_training_loss(engine):
        message = f"Epoch[{engine.state.epoch}], " + \
                  f"Iteration[{engine.state.iteration}/{len(train_loader)}], " + \
                  f"Base Lr: {tr_comp.scheduler.get_last_lr()[0]:.2e}, "

        for loss_name in engine.state.metrics.keys():
            message += f"{loss_name}: {engine.state.metrics[loss_name]:.4f}, "

        # Optionally report the learned cross-entropy uncertainty weight.
        if tr_comp.xent and tr_comp.xent.learning_weight:
            message += f"xentWeight: {tr_comp.xent.uncertainty.mean().item():.4f}, "

        logger.info(message)

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        logger.info('Epoch {} done. Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]'
                    .format(engine.state.epoch, timer.value() * timer.step_count,
                            train_loader.batch_size / timer.value()))
        logger.info('-' * 80)
        timer.reset()

    # Run validation every EVAL.EPOCH_PERIOD epochs.
    @trainer.on(Events.EPOCH_COMPLETED(every=cfg.EVAL.EPOCH_PERIOD))
    def log_validation_results(engine):
        logger.info(f"Valid - Epoch: {engine.state.epoch}")
        eval_multi_dataset(cfg, valid_dict, tr_comp)

    trainer.run(train_loader, max_epochs=cfg.TRAIN.MAX_EPOCHS)
Beispiel #14
0
    def run(self, epochs: int = 1):
        """Train for up to `epochs` epochs, logging per-iteration loss and
        per-epoch train/validation metrics to TensorBoard, and checkpointing
        model/optimizer every `self.checkpoint_every` epochs.

        :param epochs: maximum number of epochs to run.
        """
        trainer = self.trainer
        train_loader = self.dataloader["train"]
        val_loader = self.dataloader["val"]

        @trainer.on(Events.ITERATION_COMPLETED(every=self.log_interval))
        def log_training_loss(engine):
            length = len(train_loader)
            self.logger.info(f"Epoch[{engine.state.epoch}] "
                             f"Iteration[{engine.state.iteration}/{length}] "
                             f"Loss: {engine.state.output:.2f}")
            self.writer.add_scalar("training/loss", engine.state.output,
                                   engine.state.iteration)

        @trainer.on(Events.EPOCH_COMPLETED)
        def log_training_results(engine):
            self.evaluator.run(train_loader)
            metrics = self.evaluator.state.metrics
            avg_accuracy = metrics["accuracy"]
            avg_loss = metrics["loss"]
            # BUG FIX: added the missing separator between the accuracy and
            # loss fields — they previously ran together in the log line.
            self.logger.info(
                f"Training Results - Epoch: {engine.state.epoch}  "
                f"Avg accuracy: {avg_accuracy:.2f} "
                f"Avg loss: {avg_loss:.2f}")
            self.writer.add_scalar("training/avg_loss", avg_loss,
                                   engine.state.epoch)
            self.writer.add_scalar("training/avg_accuracy", avg_accuracy,
                                   engine.state.epoch)

        @trainer.on(Events.EPOCH_COMPLETED)
        def log_validation_results(engine):
            self.evaluator.run(val_loader)
            metrics = self.evaluator.state.metrics
            avg_accuracy = metrics["accuracy"]
            avg_loss = metrics["loss"]
            self.logger.info(
                f"Validation Results - Epoch: {engine.state.epoch} "
                f"Avg accuracy: {avg_accuracy:.2f} "
                f"Avg loss: {avg_loss:.2f}")
            # BUG FIX: corrected the misspelled "valdation" scalar tags.
            # NOTE(review): this renames the TensorBoard tags; dashboards
            # keyed on the old names must be updated.
            self.writer.add_scalar("validation/avg_loss", avg_loss,
                                   engine.state.epoch)
            self.writer.add_scalar("validation/avg_accuracy", avg_accuracy,
                                   engine.state.epoch)

        # Checkpoint model and optimizer; keep all checkpoints (n_saved=None),
        # named by epoch number.
        objects_to_checkpoint = dict(model=self.model,
                                     optimizer=self.optimizer)
        training_checkpoint = Checkpoint(
            to_save=objects_to_checkpoint,
            save_handler=DiskSaver(self.log_dir, require_empty=False),
            n_saved=None,
            global_step_transform=lambda *_: trainer.state.epoch,
        )
        trainer.add_event_handler(
            Events.EPOCH_COMPLETED(every=self.checkpoint_every),
            training_checkpoint,
        )
        trainer.run(train_loader, max_epochs=epochs)
Beispiel #15
0
    def add_logging(self):
        """Attach per-epoch evaluation and LR-scheduler stepping to the
        training engine."""

        # Run model evaluation at the end of every epoch.
        self.train_engine.add_event_handler(Events.EPOCH_COMPLETED(every=1),
                                            self.evaluate_model)

        # Step the learning-rate scheduler at the end of each epoch.
        self.train_engine.add_event_handler(Events.EPOCH_COMPLETED,
                                            lambda _: self.scheduler.step())
Beispiel #16
0
 def attach(self, engine: Engine) -> None:
     """
     Register self on the engine, firing every `self.interval` epochs or
     iterations depending on `self.epoch_level`.

     Args:
         engine: Ignite Engine, it can be a trainer, validator or evaluator.
     """
     # Pick the base event by level, then apply the interval filter once.
     base_event = (Events.EPOCH_COMPLETED
                   if self.epoch_level else Events.ITERATION_COMPLETED)
     engine.add_event_handler(base_event(every=self.interval), self)
Beispiel #17
0
    def attach(self, engine) -> None:
        """Register self on `engine`, firing every `self.interval`-th event
        starting at event number `self.start`.

        The event is epoch- or iteration-level depending on
        `self.epoch_level`.
        """

        # Idiom fix: return the boolean expression directly instead of the
        # old `lambda ...: True if (...) else False`. Same truth table.
        def event_filter(engine, event):
            # Fire at start, start+interval, start+2*interval, ...
            return (event >= self.start
                    and (event - self.start) % self.interval == 0)

        if self.epoch_level:
            engine.add_event_handler(
                Events.EPOCH_COMPLETED(event_filter=event_filter), self)
        else:
            engine.add_event_handler(
                Events.ITERATION_COMPLETED(event_filter=event_filter), self)
Beispiel #18
0
def test_pbar_wrong_events_order():
    """ProgressBar.attach rejects out-of-order events and filtered closing
    events."""
    engine = Engine(update_fn)
    pbar = ProgressBar()

    # Every (event, closing_event) pair where the event does not strictly
    # precede its closing event must raise.
    bad_pairs = [
        (Events.COMPLETED, Events.COMPLETED),
        (Events.COMPLETED, Events.EPOCH_COMPLETED),
        (Events.COMPLETED, Events.ITERATION_COMPLETED),
        (Events.EPOCH_COMPLETED, Events.EPOCH_COMPLETED),
        (Events.ITERATION_COMPLETED, Events.ITERATION_STARTED),
    ]
    for event_name, closing_event_name in bad_pairs:
        with pytest.raises(ValueError,
                           match="should be called before closing event"):
            pbar.attach(engine,
                        event_name=event_name,
                        closing_event_name=closing_event_name)

    # A filtered event may not be used as the closing event.
    with pytest.raises(ValueError,
                       match="Closing event should not use any event filter"):
        pbar.attach(
            engine,
            event_name=Events.ITERATION_STARTED,
            closing_event_name=Events.EPOCH_COMPLETED(every=10),
        )
    def train_epochs(self, max_epochs):
        """Build trainer/evaluator engines, log train/validation metrics to
        TensorBoard each epoch, checkpoint the best model by validation loss,
        and train for `max_epochs` epochs."""
        self.trainer = Engine(self.train_one_step)
        self.evaluator = Engine(self.evaluate_one_step)
        self.metrics = {'Loss': Loss(self.criterion), 'Acc': Accuracy()}
        for name, metric in self.metrics.items():
            metric.attach(self.evaluator, name)

        with SummaryWriter(
                log_dir="/tmp/tensorboard/Transform" +
                str(type(self))[17:len(str(type(self))) - 2]) as writer:

            @self.trainer.on(Events.EPOCH_COMPLETED(every=1))  # every epoch
            def log_results(engine):
                # Evaluate on the training set.
                self.eval()
                self.evaluator.run(self.train_loader)
                writer.add_scalar("train/loss",
                                  self.evaluator.state.metrics['Loss'],
                                  engine.state.epoch)
                writer.add_scalar("train/accy",
                                  self.evaluator.state.metrics['Acc'],
                                  engine.state.epoch)

                # Evaluate on the validation set.
                self.evaluator.run(self.valid_loader)
                writer.add_scalar("valid/loss",
                                  self.evaluator.state.metrics['Loss'],
                                  engine.state.epoch)
                writer.add_scalar("valid/accy",
                                  self.evaluator.state.metrics['Acc'],
                                  engine.state.epoch)
                self.train()

            # Keep only the single best model according to validation loss.
            best_model_handler = ModelCheckpoint(
                dirname='.',
                require_empty=False,
                filename_prefix="best",
                n_saved=1,
                score_function=lambda engine: -engine.state.metrics['Loss'],
                score_name="val_loss")
            # Runs every time the evaluation loop completes.
            # NOTE(review): `model` below is not defined in this method —
            # presumably a module-level global; it likely should reference
            # the instance's model. Confirm before relying on it.
            self.evaluator.add_event_handler(
                Events.COMPLETED, best_model_handler, {
                    f'Transform{str(type(self))[17:len(str(type(self)))-2]}':
                    model
                })

            # BUG FIX: run the trainer while the SummaryWriter is still open.
            # Previously `run` sat outside the `with` block, so every
            # `log_results` call wrote to an already-closed writer.
            self.trainer.run(self.train_loader, max_epochs=max_epochs)
Beispiel #20
0
def run(subj_ind: int,
        result_name: str,
        dataset_path: str,
        deep4_path: str,
        result_path: str,
        config: dict = default_config,
        model_builder: ProgressiveModelBuilder = default_model_builder):
    """Train a progressive GAN for one subject, persisting the config and
    model-builder artifacts next to the results.

    :param subj_ind: index of the subject to train on.
    :param result_name: name of this run's result folder.
    :param dataset_path: path to the dataset.
    :param deep4_path: path to the pretrained deep4 network (used for metrics).
    :param result_path: root directory for results.
    :param config: training hyperparameter dictionary.
    :param model_builder: builder producing discriminator/generator modules.
    """
    # Results are stored under <result_path>/<result_name>/<subj_ind>.
    result_path_subj = os.path.join(result_path, result_name, str(subj_ind))
    os.makedirs(result_path_subj, exist_ok=True)

    # Persist the exact run configuration and model builder for reproducibility.
    joblib.dump(config,
                os.path.join(result_path_subj, 'config.dict'),
                compress=False)
    joblib.dump(model_builder,
                os.path.join(result_path_subj, 'model_builder.jblb'),
                compress=True)

    # create discriminator and generator modules
    discriminator = model_builder.build_discriminator()
    generator = model_builder.build_generator()

    # initiate weights
    generator.apply(weight_filler)
    discriminator.apply(weight_filler)

    # trainer engine
    trainer = GanSoftplusTrainer(10, discriminator, generator,
                                 config['r1_gamma'], config['r2_gamma'])

    # handles potential progression after each epoch
    progression_handler = ProgressionHandler(
        discriminator,
        generator,
        config['n_stages'],
        config['use_fade'],
        config['n_epochs_fade'],
        freeze_stages=config['freeze_stages'])
    progression_handler.set_progression(0, 1.)
    # Advance the fade-in alpha once per epoch.
    trainer.add_event_handler(Events.EPOCH_COMPLETED(every=1),
                              progression_handler.advance_alpha)

    generator.train()
    discriminator.train()

    train(subj_ind, dataset_path, deep4_path, result_path_subj,
          progression_handler, trainer, config['n_batch'], config['lr_d'],
          config['lr_g'], config['betas'], config['n_epochs_per_stage'],
          config['n_epochs_metrics'], config['plot_every_epoch'],
          config['orig_fs'])
Beispiel #21
0
    def train(self,
              epochs: int,
              train_loader,
              test_loader=None,
              trainsize=None,
              valsize=None):
        """Run the training loop with optional TensorBoard loss and
        parameter-histogram logging and per-epoch state saving.

        :param epochs: number of epochs to run.
        :param train_loader: iterable of training batches.
        :param test_loader: optional iterable of evaluation batches.
        :param trainsize: optional epoch length override for the engine.
        :param valsize: optional number of batches used when evaluating.
        """
        self.model.train()
        train_engine = Engine(lambda e, b: self.train_step(b))

        # Periodically log train (and optionally test) loss to TensorBoard.
        @train_engine.on(Events.EPOCH_COMPLETED(every=self.track_loss_freq))
        def eval_test(engine):
            if self.track_loss:
                self.tb_log(train_loader,
                            engine.state.epoch,
                            is_train=True,
                            eval_length=valsize)
                if test_loader is not None:
                    self.tb_log(test_loader,
                                engine.state.epoch,
                                is_train=False,
                                eval_length=valsize)

        # Persist model and optimizer state after every epoch.
        @train_engine.on(Events.EPOCH_COMPLETED)
        def save_state(engine):
            torch.save(self.model.state_dict(), self.snail_path)
            torch.save(self.opt.state_dict(), self.snail_opt_path)

        # Periodically log parameter (and gradient) histograms.
        @train_engine.on(
            Events.ITERATION_COMPLETED(every=self.track_params_freq))
        def tb_log_histogram_params(engine):
            if self.track_layers:
                for name, params in self.model.named_parameters():
                    self.logger.add_histogram(name.replace('.', '/'), params,
                                              engine.state.iteration)
                    if params.grad is not None:
                        self.logger.add_histogram(
                            name.replace('.', '/') + '/grad', params.grad,
                            engine.state.iteration)

        # Optional progress bar with a running-average loss.
        if self.trainpbar:
            RunningAverage(output_transform=lambda x: x).attach(
                train_engine, 'loss')
            p = ProgressBar()
            p.attach(train_engine, ['loss'])
        train_engine.run(train_loader,
                         max_epochs=epochs,
                         epoch_length=trainsize)
def setup_checkpoints(trainer, obj_to_save, epoch_length, conf):
    # type: (Engine, Dict[str, Any], int, DictConfig) -> None
    """Attach iteration- and/or epoch-interval checkpointing per config."""
    cp_conf = conf.checkpoints
    save_path = cp_conf.get('save_dir', os.getcwd())
    logging.info("Saving checkpoints to {}".format(save_path))
    keep = max(int(cp_conf.get('max_checkpoints', 1)), 1)
    saver = DiskSaver(save_path, create_dir=True, require_empty=True)
    make_checkpoint = Checkpoint(obj_to_save, saver, n_saved=keep)

    every_iter = cp_conf.interval_iteration
    every_epoch = cp_conf.interval_epoch
    if every_iter > 0:
        trainer.add_event_handler(
            Events.ITERATION_COMPLETED(every=every_iter), make_checkpoint)
    if every_epoch > 0:
        # Skip the epoch checkpoint when the iteration checkpoint already
        # fires exactly at each epoch boundary.
        if every_iter < 1 or epoch_length % every_iter:
            trainer.add_event_handler(
                Events.EPOCH_COMPLETED(every=every_epoch), make_checkpoint)
def setup_evaluation(
    trainer: Engine,
    evaluators: Dict[str, Engine],
    data_loaders: Dict[str, DataLoader],
    logger: Logger,
) -> None:
    """Run every evaluator on its loader periodically and at the end of
    training, logging the resulting metrics."""

    def _evaluation(engine: Engine) -> None:
        epoch = trainer.state.epoch
        for split in ("train", "val", "test"):
            state = evaluators[split].run(data_loaders[split])
            log_metrics(logger, epoch, state.times["COMPLETED"], split,
                        state.metrics)

    # NOTE(review): `config` is not a parameter here — presumably a
    # module-level global; confirm.
    evaluation_event = (Events.EPOCH_COMPLETED(every=config.validate_every)
                        | Events.COMPLETED)
    trainer.add_event_handler(evaluation_event, _evaluation)
Beispiel #24
0
 def attach(self, engine: Engine):
     """Register the configured checkpoint callbacks on `engine`.

     Depending on which checkpoint handlers were configured, this hooks up
     final/exception saving, per-epoch key-metric saving, and interval
     saving (epoch- or iteration-based).
     """
     # Fall back to the engine's logger when no explicit name was given.
     if self._name is None:
         self.logger = engine.logger
     if self._final_checkpoint is not None:
         # Save once at the very end, and also when an exception is raised.
         engine.add_event_handler(Events.COMPLETED, self.completed)
         engine.add_event_handler(Events.EXCEPTION_RAISED,
                                  self.exception_raised)
     if self._key_metric_checkpoint is not None:
         # Key-metric-driven saving after every epoch.
         engine.add_event_handler(Events.EPOCH_COMPLETED,
                                  self.metrics_completed)
     if self._interval_checkpoint is not None:
         # Interval-driven saving, epoch- or iteration-based.
         if self.epoch_level:
             engine.add_event_handler(
                 Events.EPOCH_COMPLETED(every=self.save_interval),
                 self.interval_completed)
         else:
             engine.add_event_handler(
                 Events.ITERATION_COMPLETED(every=self.save_interval),
                 self.interval_completed)
Beispiel #25
0
# Define mean dice metric and Evaluator.
validation_every_n_epochs = 1

val_metrics = {'Mean Dice': MeanDice(add_sigmoid=True)}
# NOTE(review): the positional `True` after `device` — presumably the
# non_blocking flag of create_supervised_evaluator; confirm against the
# installed version's signature.
evaluator = create_supervised_evaluator(net,
                                        val_metrics,
                                        device,
                                        True,
                                        output_transform=lambda x, y, y_pred:
                                        (y_pred[0], y))

# Attach stats reporting to the evaluator.
val_stats_handler = StatsHandler()
val_stats_handler.attach(evaluator)

# Add early stopping handler to evaluator.
early_stopper = EarlyStopping(
    patience=4,
    score_function=stopping_fn_from_metric('Mean Dice'),
    trainer=trainer)
evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                            handler=early_stopper)


# Run validation on the trainer's schedule (every epoch by default).
@trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs))
def run_validation(engine):
    evaluator.run(val_loader)


# Kick off training.
state = trainer.run(loader, train_epochs)
    def fit(self, dataset, fold=0, train_split='train', valid_split='val'):
        """Fit the predictor model.

        Builds ``self.predictor_model`` and its optimizer on first call
        (subsequent calls resume training the same model), then trains with
        per-epoch checkpointing, early stopping on validation loss, and
        per-epoch validation logging.

        Args:
          - dataset: temporal, static, label, time, treatment information
          - fold: Cross validation fold
          - train_split: training set splitting parameter
          - valid_split: validation set splitting parameter

        Returns:
          - self.predictor_model: trained predictor model
        """
        train_x, train_y = self._data_preprocess(dataset, fold, train_split)
        valid_x, valid_y = self._data_preprocess(dataset, fold, valid_split)

        train_dataset = torch.utils.data.dataset.TensorDataset(
            self._make_tensor(train_x), self._make_tensor(train_y))
        valid_dataset = torch.utils.data.dataset.TensorDataset(
            self._make_tensor(valid_x), self._make_tensor(valid_y))

        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=self.batch_size,
                                                   shuffle=True)
        # Fix: validation data does not need shuffling -- metrics are
        # order-independent and a fixed order keeps evaluation deterministic.
        val_loader = torch.utils.data.DataLoader(valid_dataset,
                                                 batch_size=self.batch_size,
                                                 shuffle=False)

        # Lazily create the model/optimizer so repeated fit() calls continue
        # training the same parameters instead of restarting from scratch.
        if self.predictor_model is None:
            self.predictor_model = TransformerModule(
                self.task, dataset.problem, train_x.shape[-1], self.h_dim,
                train_y.shape[-1], self.n_head, self.n_layer).to(self.device)
            self.optimizer = torch.optim.Adam(
                self.predictor_model.parameters(), lr=self.learning_rate)

        self.predictor_model.train()

        # classification vs regression
        # static vs dynamic
        trainer = create_supervised_trainer(self.predictor_model,
                                            self.optimizer,
                                            self.predictor_model.loss_fn)
        evaluator = create_supervised_evaluator(
            self.predictor_model,
            metrics={'loss': Loss(self.predictor_model.loss_fn)})

        # Model checkpoint: keep only the most recent save, written each epoch.
        checkpoint_handler = ModelCheckpoint(self.model_path,
                                             self.model_id,
                                             n_saved=1,
                                             create_dir=True,
                                             require_empty=False)
        trainer.add_event_handler(Events.EPOCH_COMPLETED(every=1),
                                  checkpoint_handler,
                                  {'model': self.predictor_model})

        # Early stopping: EarlyStopping treats higher scores as better, so
        # negate the validation loss.
        def score_function(engine):
            val_loss = engine.state.metrics['loss']
            return -val_loss

        early_stopping_handler = EarlyStopping(patience=10,
                                               score_function=score_function,
                                               trainer=trainer)
        evaluator.add_event_handler(Events.COMPLETED, early_stopping_handler)

        # Run validation (which also drives the early-stopping bookkeeping)
        # after every training epoch.
        @trainer.on(Events.EPOCH_COMPLETED)
        def log_validation_results(trainer):
            evaluator.run(val_loader)
            metrics = evaluator.state.metrics
            print("Validation Results - Epoch[{}] Avg loss: {:.2f}".format(
                trainer.state.epoch, metrics['loss']))

        trainer.run(train_loader, max_epochs=self.epoch)

        return self.predictor_model
Beispiel #27
0
    def _train(save_iter=None, save_epoch=None, sd=None):
        """Run a short deterministic training session and record diagnostics.

        Closure over the enclosing test's ``with_dropout``, ``device``,
        ``debug``, ``dirname``, ``data_size``, ``batch_size`` and
        ``random_train_data_loader`` -- presumably a pytest fixture scope;
        TODO confirm against the enclosing test.

        Args:
            save_iter: iteration at which to save one checkpoint.
            save_epoch: epoch at which to save one checkpoint (used when
                ``save_iter`` is None; also rewrites ``save_iter`` to the
                equivalent iteration count).
            sd: optional path to a checkpoint from a previous call; when
                given, model/optimizer/trainer state is restored before
                running.

        Returns:
            dict with saved checkpoint paths ("sd"), per-iteration batch
            stats ("data"), and gradient/weight norms ("grads"/"weights").
        """
        w_norms = []
        grad_norms = []
        data = []
        chkpt = []

        # Fixed seed so model initialization (and dropout masks) is reproducible.
        manual_seed(12)
        arch = [
            nn.Conv2d(3, 10, 3),
            nn.ReLU(),
            nn.Conv2d(10, 10, 3),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(10, 5),
            nn.ReLU(),
            nn.Linear(5, 2),
        ]
        if with_dropout:
            arch.insert(2, nn.Dropout2d())
            arch.insert(-2, nn.Dropout())

        model = nn.Sequential(*arch).to(device)
        opt = SGD(model.parameters(), lr=0.001)

        def proc_fn(e, b):
            # One SGD step; captures the RNG state string so two runs can
            # be compared iteration-by-iteration in debug output.
            from ignite.engine.deterministic import _get_rng_states, _repr_rng_state

            s = _repr_rng_state(_get_rng_states())
            model.train()
            opt.zero_grad()
            y = model(b.to(device))
            y.sum().backward()
            opt.step()
            if debug:
                print(trainer.state.iteration, trainer.state.epoch,
                      "proc_fn - b.shape", b.shape,
                      torch.norm(y).item(), s)

        trainer = DeterministicEngine(proc_fn)

        # Pick the one-shot event at which to checkpoint.
        if save_iter is not None:
            ev = Events.ITERATION_COMPLETED(once=save_iter)
        elif save_epoch is not None:
            ev = Events.EPOCH_COMPLETED(once=save_epoch)
            save_iter = save_epoch * (data_size // batch_size)
        # NOTE(review): if both save_iter and save_epoch are None, ``ev`` is
        # unbound and the decorator below raises NameError -- confirm callers
        # always pass exactly one of the two.

        @trainer.on(ev)
        def save_chkpt(_):
            # Persist model/optimizer/trainer state (incl. RNG) to one file.
            if debug:
                print(trainer.state.iteration, "save_chkpt")
            fp = dirname / "test.pt"
            from ignite.engine.deterministic import _repr_rng_state

            tsd = trainer.state_dict()
            if debug:
                print("->", _repr_rng_state(tsd["rng_states"]))
            torch.save([model.state_dict(), opt.state_dict(), tsd], fp)
            chkpt.append(fp)

        def log_event_filter(_, event):
            # Fire only for the 5 iterations immediately after the checkpoint.
            if (event // save_iter == 1) and 1 <= (event % save_iter) <= 5:
                return True
            return False

        @trainer.on(Events.ITERATION_COMPLETED(event_filter=log_event_filter))
        def write_data_grads_weights(e):
            # Record batch statistics plus per-parameter and total norms of
            # weights and gradients for later run-to-run comparison.
            x = e.state.batch
            i = e.state.iteration
            data.append([i, x.mean().item(), x.std().item()])

            total = [0.0, 0.0]
            out1 = []
            out2 = []
            for p in model.parameters():
                n1 = torch.norm(p).item()
                n2 = torch.norm(p.grad).item()
                out1.append(n1)
                out2.append(n2)
                total[0] += n1
                total[1] += n2
            w_norms.append([i, total[0]] + out1)
            grad_norms.append([i, total[1]] + out2)

        # Optionally resume from a previously saved checkpoint before running.
        if sd is not None:
            sd = torch.load(sd)
            model.load_state_dict(sd[0])
            opt.load_state_dict(sd[1])
            from ignite.engine.deterministic import _repr_rng_state

            if debug:
                print("-->", _repr_rng_state(sd[2]["rng_states"]))
            trainer.load_state_dict(sd[2])

        # Separate seed for the data stream (distinct from model init above).
        manual_seed(32)
        trainer.run(random_train_data_loader(size=data_size), max_epochs=5)
        return {
            "sd": chkpt,
            "data": data,
            "grads": grad_norms,
            "weights": w_norms
        }
Beispiel #28
0
def create_trainer(
    train_step,
    output_names,
    model,
    ema_model,
    optimizer,
    lr_scheduler,
    supervised_train_loader,
    test_loader,
    cfg,
    logger,
    cta=None,
    unsup_train_loader=None,
    cta_probe_loader=None,
):
    """Build the training engine with EMA updates, evaluation and logging.

    Wraps ``train_step`` in an Ignite ``Engine`` and attaches: common
    checkpoint/LR handlers, a progress bar, batch preparation that merges
    supervised/unsupervised/CTA-probe batches, an EMA parameter update per
    iteration, periodic evaluation of both the raw and EMA models, rank-0
    TensorBoard (and optional W&B) logging, and checkpoint resume.

    Args:
        train_step: process function executed by the engine per iteration.
        output_names: names of the scalar outputs to track/log.
        model / ema_model: raw model and its exponential-moving-average copy.
        optimizer, lr_scheduler: training optimizer and schedule (checkpointed).
        supervised_train_loader: labeled data loader (its sampler is used for
            distributed epoch seeding via the common handlers).
        test_loader: evaluation data loader.
        cfg: config object; fields used here include ``solver.checkpoint_every``,
            ``solver.validate_every``, ``solver.resume_from``, ``ema_decay``,
            ``debug``, ``name`` and ``online_exp_tracking.wandb``.
        logger: logger assigned to the engine and used for reports.
        cta: optional CTAugment-like object, checkpointed and summarized.
        unsup_train_loader: optional unlabeled data loader, cycled forever.
        cta_probe_loader: optional CTA probe loader, cycled forever.

    Returns:
        The configured Ignite trainer ``Engine``.
    """

    trainer = Engine(train_step)
    trainer.logger = logger

    output_path = os.getcwd()

    to_save = {
        "model": model,
        "ema_model": ema_model,
        "optimizer": optimizer,
        "trainer": trainer,
        "lr_scheduler": lr_scheduler,
    }
    if cta is not None:
        to_save["cta"] = cta

    common.setup_common_training_handlers(
        trainer,
        train_sampler=supervised_train_loader.sampler,
        to_save=to_save,
        save_every_iters=cfg.solver.checkpoint_every,
        output_path=output_path,
        output_names=output_names,
        lr_scheduler=lr_scheduler,
        with_pbars=False,
        clear_cuda_cache=False,
    )

    ProgressBar(persist=False).attach(
        trainer, metric_names="all", event_name=Events.ITERATION_COMPLETED
    )

    # Infinite iterators so one "epoch" of supervised data can consume any
    # number of unsupervised / probe batches.
    unsupervised_train_loader_iter = None
    if unsup_train_loader is not None:
        unsupervised_train_loader_iter = cycle(unsup_train_loader)

    cta_probe_loader_iter = None
    if cta_probe_loader is not None:
        cta_probe_loader_iter = cycle(cta_probe_loader)

    # Setup handler to prepare data batches
    @trainer.on(Events.ITERATION_STARTED)
    def prepare_batch(e):
        # Replace the raw batch with a dict combining all available streams.
        sup_batch = e.state.batch
        e.state.batch = {
            "sup_batch": sup_batch,
        }
        if unsupervised_train_loader_iter is not None:
            unsup_batch = next(unsupervised_train_loader_iter)
            e.state.batch["unsup_batch"] = unsup_batch

        if cta_probe_loader_iter is not None:
            cta_probe_batch = next(cta_probe_loader_iter)
            cta_probe_batch["policy"] = [
                deserialize(p) for p in cta_probe_batch["policy"]
            ]
            e.state.batch["cta_probe_batch"] = cta_probe_batch

    # Setup handler to update EMA model.
    # The extra positional arg (cfg.ema_decay) is forwarded by Ignite to the
    # handler, hence the ema_decay parameter below.
    @trainer.on(Events.ITERATION_COMPLETED, cfg.ema_decay)
    def update_ema_model(ema_decay):
        # EMA on parameters
        for ema_param, param in zip(ema_model.parameters(), model.parameters()):
            ema_param.data.mul_(ema_decay).add_(param.data, alpha=1.0 - ema_decay)

    # Setup handlers for debugging
    if cfg.debug:

        @trainer.on(Events.STARTED | Events.ITERATION_COMPLETED(every=100))
        @idist.one_rank_only()
        def log_weights_norms():
            # Report mean weight values of raw vs EMA model (first/last 10
            # tensors) to eyeball that the EMA update is tracking correctly.
            wn = []
            ema_wn = []
            for ema_param, param in zip(ema_model.parameters(), model.parameters()):
                wn.append(torch.mean(param.data))
                ema_wn.append(torch.mean(ema_param.data))

            msg = "\n\nWeights norms"
            msg += "\n- Raw model: {}".format(
                to_list_str(torch.tensor(wn[:10] + wn[-10:]))
            )
            msg += "\n- EMA model: {}\n".format(
                to_list_str(torch.tensor(ema_wn[:10] + ema_wn[-10:]))
            )
            logger.info(msg)

            # BatchNorm running stats are buffers, not parameters, so they
            # are inspected separately from the EMA-updated weights.
            rmn = []
            rvar = []
            ema_rmn = []
            ema_rvar = []
            for m1, m2 in zip(model.modules(), ema_model.modules()):
                if isinstance(m1, nn.BatchNorm2d) and isinstance(m2, nn.BatchNorm2d):
                    rmn.append(torch.mean(m1.running_mean))
                    rvar.append(torch.mean(m1.running_var))
                    ema_rmn.append(torch.mean(m2.running_mean))
                    ema_rvar.append(torch.mean(m2.running_var))

            msg = "\n\nBN buffers"
            msg += "\n- Raw mean: {}".format(to_list_str(torch.tensor(rmn[:10])))
            msg += "\n- Raw var: {}".format(to_list_str(torch.tensor(rvar[:10])))
            msg += "\n- EMA mean: {}".format(to_list_str(torch.tensor(ema_rmn[:10])))
            msg += "\n- EMA var: {}\n".format(to_list_str(torch.tensor(ema_rvar[:10])))
            logger.info(msg)

        # TODO: Need to inspect a bug
        # if idist.get_rank() == 0:
        #     from ignite.contrib.handlers import ProgressBar
        #
        #     profiler = BasicTimeProfiler()
        #     profiler.attach(trainer)
        #
        #     @trainer.on(Events.ITERATION_COMPLETED(every=200))
        #     def log_profiling(_):
        #         results = profiler.get_results()
        #         profiler.print_results(results)

    # Setup validation engine
    metrics = {
        "accuracy": Accuracy(),
    }

    # Precision/Recall are skipped on XLA/TPU backends.
    if not (idist.has_xla_support and idist.backend() == idist.xla.XLA_TPU):
        metrics.update({
            "precision": Precision(average=False),
            "recall": Recall(average=False),
        })

    eval_kwargs = dict(
        metrics=metrics,
        prepare_batch=sup_prepare_batch,
        device=idist.device(),
        non_blocking=True,
    )

    # Evaluate both the raw model and its EMA copy.
    evaluator = create_supervised_evaluator(model, **eval_kwargs)
    ema_evaluator = create_supervised_evaluator(ema_model, **eval_kwargs)

    def log_results(epoch, max_epochs, metrics, ema_metrics):
        msg1 = "\n".join(
            ["\t{:16s}: {}".format(k, to_list_str(v)) for k, v in metrics.items()]
        )
        msg2 = "\n".join(
            ["\t{:16s}: {}".format(k, to_list_str(v)) for k, v in ema_metrics.items()]
        )
        logger.info(
            "\nEpoch {}/{}\nRaw:\n{}\nEMA:\n{}\n".format(epoch, max_epochs, msg1, msg2)
        )
        if cta is not None:
            logger.info("\n" + stats(cta))

    # Evaluate at start (baseline), periodically, and at the very end.
    @trainer.on(
        Events.EPOCH_COMPLETED(every=cfg.solver.validate_every)
        | Events.STARTED
        | Events.COMPLETED
    )
    def run_evaluation():
        evaluator.run(test_loader)
        ema_evaluator.run(test_loader)
        log_results(
            trainer.state.epoch,
            trainer.state.max_epochs,
            evaluator.state.metrics,
            ema_evaluator.state.metrics,
        )

    # setup TB logging (rank 0 only; release_all_resources below relies on
    # the same rank guard before touching tb_logger)
    if idist.get_rank() == 0:
        tb_logger = common.setup_tb_logging(
            output_path,
            trainer,
            optimizers=optimizer,
            evaluators={"validation": evaluator, "ema validation": ema_evaluator},
            log_every_iters=15,
        )
        if cfg.online_exp_tracking.wandb:
            from ignite.contrib.handlers import WandBLogger

            wb_dir = Path("/tmp/output-fixmatch-wandb")
            if not wb_dir.exists():
                wb_dir.mkdir()

            _ = WandBLogger(
                project="fixmatch-pytorch",
                name=cfg.name,
                config=cfg,
                sync_tensorboard=True,
                dir=wb_dir.as_posix(),
                reinit=True,
            )

    # Resume full training state from the most recent checkpoint, if any.
    resume_from = cfg.solver.resume_from
    if resume_from is not None:
        resume_from = list(Path(resume_from).rglob("training_checkpoint*.pt*"))
        if len(resume_from) > 0:
            # get latest
            checkpoint_fp = max(resume_from, key=lambda p: p.stat().st_mtime)
            assert checkpoint_fp.exists(), "Checkpoint '{}' is not found".format(
                checkpoint_fp.as_posix()
            )
            logger.info("Resume from a checkpoint: {}".format(checkpoint_fp.as_posix()))
            checkpoint = torch.load(checkpoint_fp.as_posix())
            Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    @trainer.on(Events.COMPLETED)
    def release_all_resources():
        # Drop references to the infinite iterators and close the TB writer.
        nonlocal unsupervised_train_loader_iter, cta_probe_loader_iter

        if idist.get_rank() == 0:
            tb_logger.close()

        if unsupervised_train_loader_iter is not None:
            unsupervised_train_loader_iter = None

        if cta_probe_loader_iter is not None:
            cta_probe_loader_iter = None

    return trainer
Beispiel #29
0
def main():
    """MONAI/Ignite demo: 3D DenseNet121 gender classification on IXI T1 MRI.

    Expects the IXI images to already exist under
    /workspace/data/medical/ixi/IXI-T1/. Trains for 30 epochs with
    per-epoch checkpointing, console/TensorBoard stats logging, per-epoch
    validation and early stopping on accuracy.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # IXI dataset as a demo, downloadable from https://brain-development.org/ixi-dataset/
    images = [
        "/workspace/data/medical/ixi/IXI-T1/IXI314-IOP-0889-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI249-Guys-1072-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI609-HH-2600-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI173-HH-1590-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI020-Guys-0700-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI342-Guys-0909-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI134-Guys-0780-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI577-HH-2661-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI066-Guys-0731-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI130-HH-1528-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI607-Guys-1097-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI175-HH-1570-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI385-HH-2078-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI344-Guys-0905-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI409-Guys-0960-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI584-Guys-1129-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI253-HH-1694-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI092-HH-1436-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI574-IOP-1156-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI585-Guys-1130-T1.nii.gz",
    ]
    # 2 binary labels for gender classification: man and woman
    labels = np.array(
        [0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0])
    # First 10 images for training, last 10 for validation.
    train_files = [{
        "img": img,
        "label": label
    } for img, label in zip(images[:10], labels[:10])]
    val_files = [{
        "img": img,
        "label": label
    } for img, label in zip(images[-10:], labels[-10:])]

    # define transforms for image (random 90-degree rotations for training
    # augmentation only; validation transforms are deterministic)
    train_transforms = Compose([
        LoadNiftid(keys=["img"]),
        AddChanneld(keys=["img"]),
        ScaleIntensityd(keys=["img"]),
        Resized(keys=["img"], spatial_size=(96, 96, 96)),
        RandRotate90d(keys=["img"], prob=0.8, spatial_axes=[0, 2]),
        ToTensord(keys=["img"]),
    ])
    val_transforms = Compose([
        LoadNiftid(keys=["img"]),
        AddChanneld(keys=["img"]),
        ScaleIntensityd(keys=["img"]),
        Resized(keys=["img"], spatial_size=(96, 96, 96)),
        ToTensord(keys=["img"]),
    ])

    # define dataset, data loader; a single batch is pulled just to sanity
    # check shapes before building the model
    check_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    check_loader = DataLoader(check_ds,
                              batch_size=2,
                              num_workers=4,
                              pin_memory=torch.cuda.is_available())
    check_data = monai.utils.misc.first(check_loader)
    print(check_data["img"].shape, check_data["label"])

    # create DenseNet121, CrossEntropyLoss and Adam optimizer
    net = monai.networks.nets.densenet.densenet121(
        spatial_dims=3,
        in_channels=1,
        out_channels=2,
    )
    loss = torch.nn.CrossEntropyLoss()
    lr = 1e-5
    opt = torch.optim.Adam(net.parameters(), lr)
    # NOTE(review): hard-coded to the first GPU; this fails without CUDA.
    device = torch.device("cuda:0")

    # Ignite trainer expects batch=(img, label) and returns output=loss at every iteration,
    # user can add output_transform to return other values, like: y_pred, y, etc.
    def prepare_batch(batch, device=None, non_blocking=False):

        return _prepare_batch((batch["img"], batch["label"]), device,
                              non_blocking)

    trainer = create_supervised_trainer(net,
                                        opt,
                                        loss,
                                        device,
                                        False,
                                        prepare_batch=prepare_batch)

    # adding checkpoint handler to save models (network params and optimizer stats) during training
    checkpoint_handler = ModelCheckpoint("./runs/",
                                         "net",
                                         n_saved=10,
                                         require_empty=False)
    trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                              handler=checkpoint_handler,
                              to_save={
                                  "net": net,
                                  "opt": opt
                              })

    # StatsHandler prints loss at every iteration and print metrics at every epoch,
    # we don't set metrics for trainer here, so just print loss, user can also customize print functions
    # and can use output_transform to convert engine.state.output if it's not loss value
    train_stats_handler = StatsHandler(name="trainer")
    train_stats_handler.attach(trainer)

    # TensorBoardStatsHandler plots loss at every iteration and plots metrics at every epoch, same as StatsHandler
    train_tensorboard_stats_handler = TensorBoardStatsHandler()
    train_tensorboard_stats_handler.attach(trainer)

    # set parameters for validation
    validation_every_n_epochs = 1

    metric_name = "Accuracy"
    # add evaluation metric to the evaluator engine
    val_metrics = {
        metric_name: Accuracy(),
        "AUC": ROCAUC(to_onehot_y=True, add_softmax=True)
    }
    # Ignite evaluator expects batch=(img, label) and returns output=(y_pred, y) at every iteration,
    # user can add output_transform to return other values
    evaluator = create_supervised_evaluator(net,
                                            val_metrics,
                                            device,
                                            True,
                                            prepare_batch=prepare_batch)

    # add stats event handler to print validation stats via evaluator
    val_stats_handler = StatsHandler(
        name="evaluator",
        output_transform=lambda x:
        None,  # no need to print loss value, so disable per iteration output
        global_epoch_transform=lambda x: trainer.state.epoch,
    )  # fetch global epoch number from trainer
    val_stats_handler.attach(evaluator)

    # add handler to record metrics to TensorBoard at every epoch
    val_tensorboard_stats_handler = TensorBoardStatsHandler(
        output_transform=lambda x:
        None,  # no need to plot loss value, so disable per iteration output
        global_epoch_transform=lambda x: trainer.state.epoch,
    )  # fetch global epoch number from trainer
    val_tensorboard_stats_handler.attach(evaluator)

    # add early stopping handler to evaluator: stop the *trainer* when
    # validation Accuracy stops improving for 4 validation rounds
    early_stopper = EarlyStopping(
        patience=4,
        score_function=stopping_fn_from_metric(metric_name),
        trainer=trainer)
    evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                                handler=early_stopper)

    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds,
                            batch_size=2,
                            num_workers=4,
                            pin_memory=torch.cuda.is_available())

    # run validation on the configured epoch cadence
    @trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs))
    def run_validation(engine):
        evaluator.run(val_loader)

    # create a training data loader
    train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    train_loader = DataLoader(train_ds,
                              batch_size=2,
                              shuffle=True,
                              num_workers=4,
                              pin_memory=torch.cuda.is_available())

    train_epochs = 30
    state = trainer.run(train_loader, train_epochs)
Beispiel #30
0
def run(train_batch_size, val_batch_size, epochs, lr, momentum, log_dir):
    """Train the example Net with Ignite, logging everything to TensorBoard.

    Args:
        train_batch_size: batch size of the training loader.
        val_batch_size: batch size of the validation loader.
        epochs: number of training epochs.
        lr: SGD learning rate.
        momentum: SGD momentum.
        log_dir: directory for TensorBoard event files and best-model
            checkpoints.
    """
    train_loader, val_loader = get_data_loaders(train_batch_size,
                                                val_batch_size)
    model = Net()
    device = "cpu"

    if torch.cuda.is_available():
        device = "cuda"

    model.to(device)  # Move model before creating optimizer
    optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
    criterion = nn.CrossEntropyLoss()
    trainer = create_supervised_trainer(model,
                                        optimizer,
                                        criterion,
                                        device=device)
    trainer.logger = setup_logger("Trainer")

    # Always true on Python 3; kept from the original example for safety.
    if sys.version_info > (3, ):
        from ignite.contrib.metrics.gpu_info import GpuInfo

        try:
            GpuInfo().attach(trainer)
        except RuntimeError:
            # pynvml is optional; GPU stats logging is best-effort only.
            print(
                "INFO: By default, in this example it is possible to log GPU information (used memory, utilization). "
                "As there is no pynvml python package installed, GPU information won't be logged. Otherwise, please "
                "install it : `pip install pynvml`")

    metrics = {"accuracy": Accuracy(), "loss": Loss(criterion)}

    train_evaluator = create_supervised_evaluator(model,
                                                  metrics=metrics,
                                                  device=device)
    train_evaluator.logger = setup_logger("Train Evaluator")
    validation_evaluator = create_supervised_evaluator(model,
                                                       metrics=metrics,
                                                       device=device)
    validation_evaluator.logger = setup_logger("Val Evaluator")

    # Run both evaluators after every epoch so their metrics are fresh for
    # the TensorBoard handlers and the checkpoint score function below.
    @trainer.on(Events.EPOCH_COMPLETED)
    def compute_metrics(engine):
        train_evaluator.run(train_loader)
        validation_evaluator.run(val_loader)

    tb_logger = TensorboardLogger(log_dir=log_dir)

    tb_logger.attach_output_handler(
        trainer,
        event_name=Events.ITERATION_COMPLETED(every=100),
        tag="training",
        output_transform=lambda loss: {"batchloss": loss},
        metric_names="all",
    )

    for tag, evaluator in [("training", train_evaluator),
                           ("validation", validation_evaluator)]:
        tb_logger.attach_output_handler(
            evaluator,
            event_name=Events.EPOCH_COMPLETED,
            tag=tag,
            metric_names=["loss", "accuracy"],
            global_step_transform=global_step_from_engine(trainer),
        )

    tb_logger.attach_opt_params_handler(
        trainer,
        event_name=Events.ITERATION_COMPLETED(every=100),
        optimizer=optimizer)

    tb_logger.attach(trainer,
                     log_handler=WeightsScalarHandler(model),
                     event_name=Events.ITERATION_COMPLETED(every=100))

    # NOTE(review): every=100 *epochs* means the histogram handlers never
    # fire for typical short runs; ITERATION_COMPLETED may have been
    # intended -- confirm before changing.
    tb_logger.attach(trainer,
                     log_handler=WeightsHistHandler(model),
                     event_name=Events.EPOCH_COMPLETED(every=100))

    tb_logger.attach(trainer,
                     log_handler=GradsScalarHandler(model),
                     event_name=Events.ITERATION_COMPLETED(every=100))

    tb_logger.attach(trainer,
                     log_handler=GradsHistHandler(model),
                     event_name=Events.EPOCH_COMPLETED(every=100))

    def score_function(engine):
        # Higher validation accuracy == better checkpoint.
        return engine.state.metrics["accuracy"]

    model_checkpoint = ModelCheckpoint(
        log_dir,
        n_saved=2,
        filename_prefix="best",
        score_function=score_function,
        score_name="validation_accuracy",
        global_step_transform=global_step_from_engine(trainer),
    )
    validation_evaluator.add_event_handler(Events.COMPLETED, model_checkpoint,
                                           {"model": model})

    # Kick everything off. Fix: close the TensorBoard logger even when
    # training raises, so pending events are flushed and the writer is not
    # leaked.
    try:
        trainer.run(train_loader, max_epochs=epochs)
    finally:
        tb_logger.close()