Example #1
def test_returns_state():
    engine = Engine(MagicMock(return_value=1))
    state = engine.run([])

    assert isinstance(state, State)
    def _test(save_history):
        tensor = torch.ones([1], requires_grad=True)
        optimizer = torch.optim.SGD([tensor], lr=0.001)

        max_epochs = 25
        lr_max_value = 0.4
        num_iterations_per_epoch = 128
        num_iterations = max_epochs * num_iterations_per_epoch
        warmup_duration = 5 * num_iterations_per_epoch
        cooldown_duration = 5 * num_iterations_per_epoch

        scheduler_1 = LinearCyclicalScheduler(
            optimizer,
            "lr",
            start_value=lr_max_value,
            end_value=lr_max_value * 0.9,
            cycle_size=(num_iterations - warmup_duration - cooldown_duration) * 2,
        )

        scheduler_2 = LinearCyclicalScheduler(
            optimizer, "lr", start_value=lr_max_value, end_value=0.0, cycle_size=cooldown_duration * 2
        )

        lr_scheduler = ConcatScheduler(
            schedulers=[scheduler_1, scheduler_2],
            durations=[num_iterations - warmup_duration - cooldown_duration],
            save_history=False,
        )
        lr_values = [None] * num_iterations
        scheduler = create_lr_scheduler_with_warmup(
            lr_scheduler,
            warmup_start_value=0.0,
            warmup_end_value=lr_max_value,
            warmup_duration=warmup_duration,
            save_history=save_history,
            output_simulated_values=lr_values,
        )
        state_dict = scheduler.state_dict()

        trainer = Engine(lambda engine, batch: None)

        @trainer.on(Events.ITERATION_COMPLETED)
        def save_lr(engine):
            lrs.append(optimizer.param_groups[0]["lr"])

        trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

        data = [0] * num_iterations_per_epoch

        for _ in range(2):
            lrs = []
            trainer.run(data, max_epochs=max_epochs)

            assert lrs == pytest.approx([v for i, v in lr_values])

            if save_history:
                param_history = trainer.state.param_history["lr"]
                assert lrs == pytest.approx([v[0] for v in param_history])

                trainer.state.param_history = None

            scheduler.load_state_dict(state_dict)
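For reference, here is a minimal, self-contained sketch of the warmup wrapper exercised above, assuming a recent pytorch-ignite (on older releases these schedulers live under ignite.contrib.handlers.param_scheduler instead):

import torch
from ignite.engine import Engine, Events
from ignite.handlers.param_scheduler import (
    LinearCyclicalScheduler,
    create_lr_scheduler_with_warmup,
)

tensor = torch.ones([1], requires_grad=True)
optimizer = torch.optim.SGD([tensor], lr=0.1)

# Main schedule: linear cycle between 0.4 and 0.0, 200 iterations per cycle
base = LinearCyclicalScheduler(
    optimizer, "lr", start_value=0.4, end_value=0.0, cycle_size=200)
# Prepend a 50-iteration linear warmup from 0.0 up to 0.4
scheduler = create_lr_scheduler_with_warmup(
    base, warmup_start_value=0.0, warmup_end_value=0.4, warmup_duration=50)

trainer = Engine(lambda engine, batch: None)
trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
trainer.run([0] * 100, max_epochs=2)
print(optimizer.param_groups[0]["lr"])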
Example #3
class Trainer:
    _STEPS_PER_LOSS_WRITE = 10
    _STEPS_PER_GRAD_WRITE = 10
    _STEPS_PER_LR_WRITE = 10

    def __init__(
            self,

            module,
            device,

            train_metrics,
            train_loader,
            opt,
            lr_scheduler,
            max_epochs,
            max_grad_norm,

            test_metrics,
            test_loader,
            epochs_per_test,

            early_stopping,
            valid_loss,
            valid_loader,
            max_bad_valid_epochs,

            visualizer,

            writer,
            should_checkpoint_latest,
            should_checkpoint_best_valid
    ):
        self._module = module

        self._device = device

        self._train_metrics = train_metrics
        self._train_loader = train_loader
        self._opt = opt
        self._lr_scheduler = lr_scheduler
        self._max_epochs = max_epochs
        self._max_grad_norm = max_grad_norm

        self._test_metrics = test_metrics
        self._test_loader = test_loader
        self._epochs_per_test = epochs_per_test

        self._valid_loss = valid_loss
        self._valid_loader = valid_loader
        self._max_bad_valid_epochs = max_bad_valid_epochs
        self._best_valid_loss = float("inf")
        self._num_bad_valid_epochs = 0

        self._visualizer = visualizer

        self._writer = writer
        self._should_checkpoint_best_valid = should_checkpoint_best_valid

        ### Training

        self._trainer = Engine(self._train_batch)

        AverageMetric().attach(self._trainer)
        ProgressBar(persist=True).attach(self._trainer, ["loss"])

        self._trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())
        self._trainer.add_event_handler(Events.ITERATION_COMPLETED, self._log_training_info)

        ### Validation

        if early_stopping:
            self._validator = Engine(self._validate_batch)

            AverageMetric().attach(self._validator)
            ProgressBar(persist=False, desc="Validating").attach(self._validator)

            self._trainer.add_event_handler(Events.EPOCH_COMPLETED, self._validate)

        ### Testing

        self._tester = Engine(self._test_batch)

        AverageMetric().attach(self._tester)
        ProgressBar(persist=False, desc="Testing").attach(self._tester)

        self._trainer.add_event_handler(Events.EPOCH_COMPLETED, self._test_and_log)

        ### Checkpointing

        if should_checkpoint_latest:
            self._trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: self._save_checkpoint("latest"))

        try:
            self._load_checkpoint("latest")
        except FileNotFoundError:
            print("Did not find `latest' checkpoint.", file=sys.stderr)

            try:
                self._load_checkpoint("best_valid")
            except FileNotFoundError:
                print("Did not find `best_valid' checkpoint.", file=sys.stderr)

    def train(self):
        self._trainer.run(data=self._train_loader, max_epochs=self._max_epochs)

    def _train_batch(self, engine, batch):
        self._module.train()

        x, _ = batch # TODO: Potentially pass y also for genericity
        x = x.to(self._device)

        self._opt.zero_grad()

        train_metrics = self._train_metrics(self._module, x)
        loss = train_metrics["loss"]
        loss.backward()

        if self._max_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(self._module.parameters(), self._max_grad_norm)

        self._opt.step()

        self._lr_scheduler.step()

        return {"metrics": train_metrics}

    def test(self):
        self._module.eval()
        return self._tester.run(data=self._test_loader).metrics

    @torch.no_grad()
    def _test_and_log(self, engine):
        epoch = engine.state.epoch
        if (epoch - 1) % self._epochs_per_test == 0: # Test after first epoch
            for k, v in self.test().items():
                self._writer.write_scalar(f"test/{k}", v, global_step=engine.state.epoch)

                if not torch.isfinite(v):
                    self._save_checkpoint(tag="nan_during_test")

            self._visualizer.visualize(self._module, epoch)

    def _test_batch(self, engine, batch):
        x, _ = batch
        x = x.to(self._device)
        return {"metrics": self._test_metrics(self._module, x)}

    @torch.no_grad()
    def _validate(self, engine):
        self._module.eval()

        state = self._validator.run(data=self._valid_loader)
        valid_loss = state.metrics["loss"]

        if valid_loss < self._best_valid_loss:
            print(f"Best validation loss {valid_loss} after epoch {engine.state.epoch}")
            self._num_bad_valid_epochs = 0
            self._best_valid_loss = valid_loss

            if self._should_checkpoint_best_valid:
                self._save_checkpoint(tag="best_valid")

        else:
            if not torch.isfinite(valid_loss):
                self._save_checkpoint(tag="nan_during_validation")

            self._num_bad_valid_epochs += 1

            # We do this manually (i.e. don't use Ignite's early stopping) to permit
            # saving/resuming more easily
            if self._num_bad_valid_epochs > self._max_bad_valid_epochs:
                print(
                    f"No validation improvement after {self._num_bad_valid_epochs} epochs. Terminating."
                )
                self._trainer.terminate()

    def _validate_batch(self, engine, batch):
        x, _ = batch
        x = x.to(self._device)
        return {"metrics": {"loss": self._valid_loss(self._module, x)}}

    def _log_training_info(self, engine):
        i = engine.state.iteration

        if i % self._STEPS_PER_LOSS_WRITE == 0:
            for k, v in engine.state.output["metrics"].items():
                self._writer.write_scalar("train/" + k, v, global_step=i)

        # TODO: Inefficient to recompute this if we are doing gradient clipping
        if i % self._STEPS_PER_GRAD_WRITE == 0:
            self._writer.write_scalar("train/grad-norm", self._get_grad_norm(), global_step=i)

        # TODO: We should do this _before_ calling self._lr_scheduler.step(), since
        # otherwise the logged value will not correspond to the learning rate
        # actually used at iteration i
        if i % self._STEPS_PER_LR_WRITE == 0:
            self._writer.write_scalar("train/lr", self._get_lr(), global_step=i)

    def _get_grad_norm(self):
        norm = 0
        for param in self._module.parameters():
            if param.grad is not None:
                norm += param.grad.norm().item()**2
        return np.sqrt(norm)

    def _get_lr(self):
        param_group, = self._opt.param_groups
        return param_group["lr"]

    def _save_checkpoint(self, tag):
        # We do this manually (i.e. don't use Ignite's checkpointing) because
        # Ignite only allows saving objects, not scalars (e.g. the current epoch) 
        checkpoint = {
            "epoch": self._trainer.state.epoch,
            "iteration": self._trainer.state.iteration,
            "module_state_dict": self._module.state_dict(),
            "opt_state_dict": self._opt.state_dict(),
            "best_valid_loss": self._best_valid_loss,
            "num_bad_valid_epochs": self._num_bad_valid_epochs,
            "lr_scheduler_state_dict": self._lr_scheduler.state_dict()
        }

        self._writer.write_checkpoint(tag, checkpoint)

    def _load_checkpoint(self, tag):
        checkpoint = self._writer.load_checkpoint(tag, device=self._device)

        @self._trainer.on(Events.STARTED)
        def resume_trainer_state(engine):
            engine.state.epoch = checkpoint["epoch"]
            engine.state.iteration = checkpoint["iteration"]

        self._module.load_state_dict(checkpoint["module_state_dict"])
        self._opt.load_state_dict(checkpoint["opt_state_dict"])
        self._best_valid_loss = checkpoint["best_valid_loss"]
        self._num_bad_valid_epochs = checkpoint["num_bad_valid_epochs"]
        try:
            self._lr_scheduler.load_state_dict(checkpoint["lr_scheduler_state_dict"])
        except KeyError:
            print("No lr scheduler in saved checkpoint")

        print(f"Loaded checkpoint `{tag}' after epoch {checkpoint['epoch']}", file=sys.stderr)
    def _test(duration_vals_as_np_int):
        # NOTE: the enclosing setup was truncated in this snippet; the tensor
        # and optimizer below mirror the sibling test further down
        tensor = torch.zeros([1], requires_grad=True)
        optimizer = torch.optim.SGD([tensor], lr=0)
        scheduler_1 = LinearCyclicalScheduler(optimizer, "lr", start_value=1.0, end_value=0.0, cycle_size=10)
        scheduler_2 = CosineAnnealingScheduler(optimizer, "lr", start_value=0.0, end_value=1.0, cycle_size=10)

        durations = [10]
        if duration_vals_as_np_int:
            durations = [np.int64(t) for t in durations]

        concat_scheduler = ConcatScheduler(
            schedulers=[scheduler_1, scheduler_2], durations=durations, save_history=True
        )
        state_dict = concat_scheduler.state_dict()

        data = [0] * 10
        max_epochs = 2
        simulated_values = ConcatScheduler.simulate_values(
            num_events=len(data) * max_epochs, schedulers=[scheduler_1, scheduler_2], durations=durations
        )

        def save_lr(engine):
            lrs.append(optimizer.param_groups[0]["lr"])

        trainer = Engine(lambda engine, batch: None)
        trainer.add_event_handler(Events.ITERATION_STARTED, concat_scheduler)
        trainer.add_event_handler(Events.ITERATION_COMPLETED, save_lr)

        for _ in range(2):
            lrs = []
            trainer.run(data, max_epochs=max_epochs)

            assert lrs == list(
                map(
                    pytest.approx,
                    [
                        # Cycle 1 of the LinearCyclicalScheduler
                        1.0,
                        0.8,
                        0.6,
                        0.4,
                        0.2,
                        0.0,
                        0.2,
                        0.4,
                        0.6,
                        0.8,
                        # Cycle 1 of the CosineAnnealingScheduler
                        0.0,
                        0.02447174185242318,
                        0.09549150281252627,
                        0.20610737385376332,
                        0.3454915028125263,
                        0.5,
                        0.6545084971874737,
                        0.7938926261462365,
                        0.9045084971874737,
                        0.9755282581475768,
                    ],
                )
            )

            state_lrs = trainer.state.param_history["lr"]
            assert len(state_lrs) == len(lrs)
            # Unpack singleton lists
            assert [group[0] for group in state_lrs] == lrs
            assert lrs == pytest.approx([v for i, v in simulated_values])
            concat_scheduler.load_state_dict(state_dict)

            trainer.state.param_history = None
def test_concat_scheduler_3_schedulers():
    tensor = torch.zeros([1], requires_grad=True)
    optimizer = torch.optim.SGD([tensor], lr=0)

    scheduler_1 = LinearCyclicalScheduler(optimizer, "lr", start_value=1.0, end_value=0.5, cycle_size=20)
    scheduler_2 = LinearCyclicalScheduler(optimizer, "lr", start_value=0.5, end_value=0.45, cycle_size=10)
    scheduler_3 = LinearCyclicalScheduler(optimizer, "lr", start_value=0.5, end_value=0.0, cycle_size=20)
    durations = [10, 5]

    concat_scheduler = ConcatScheduler(
        schedulers=[scheduler_1, scheduler_2, scheduler_3], durations=durations, save_history=True
    )
    state_dict = concat_scheduler.state_dict()

    data = [0] * 10
    max_epochs = 2
    simulated_values = ConcatScheduler.simulate_values(
        num_events=len(data) * max_epochs, schedulers=[scheduler_1, scheduler_2, scheduler_3], durations=durations
    )

    def save_lr(engine):
        lrs.append(optimizer.param_groups[0]["lr"])

    trainer = Engine(lambda engine, batch: None)
    trainer.add_event_handler(Events.ITERATION_STARTED, concat_scheduler)
    trainer.add_event_handler(Events.ITERATION_COMPLETED, save_lr)

    for _ in range(2):
        lrs = []
        trainer.run(data, max_epochs=max_epochs)

        assert lrs == list(
            map(
                pytest.approx,
                [
                    # Cycle 1 of the first LinearCyclicalScheduler
                    1.0,
                    0.95,
                    0.9,
                    0.85,
                    0.8,
                    0.75,
                    0.7,
                    0.65,
                    0.6,
                    0.55,
                    # Cycle 1 of the second LinearCyclicalScheduler
                    0.5,
                    0.49,
                    0.48,
                    0.47,
                    0.46,
                    # Cycle 1 of the third LinearCyclicalScheduler
                    0.5,
                    0.45,
                    0.4,
                    0.35,
                    0.3,
                ],
            )
        )

        state_lrs = trainer.state.param_history["lr"]
        assert len(state_lrs) == len(lrs)
        # Unpack singleton lists
        assert [group[0] for group in state_lrs] == lrs

        assert lrs == pytest.approx([v for i, v in simulated_values])
        concat_scheduler.load_state_dict(state_dict)

        trainer.state.param_history = None
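ConcatScheduler.simulate_values, used in both tests above, lets you inspect a composite schedule without running an Engine at all. A minimal sketch, assuming a recent pytorch-ignite:

import torch
from ignite.handlers.param_scheduler import ConcatScheduler, LinearCyclicalScheduler

tensor = torch.zeros([1], requires_grad=True)
optimizer = torch.optim.SGD([tensor], lr=0.0)

warm = LinearCyclicalScheduler(optimizer, "lr", start_value=0.0, end_value=1.0, cycle_size=20)
decay = LinearCyclicalScheduler(optimizer, "lr", start_value=1.0, end_value=0.0, cycle_size=20)

# The first 10 events come from `warm`, the rest from `decay`
values = ConcatScheduler.simulate_values(
    num_events=30, schedulers=[warm, decay], durations=[10])
for event_index, lr in values[:5]:
    print(event_index, lr)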
    # NOTE: the opening of this snippet was truncated; this assignment is an
    # assumed reconstruction of the usual ptan experience-source setup
    exp_source = ptan.experience.ExperienceSourceFirstLast(
        env, agent, gamma=params.gamma, steps_count=N_STEPS)
    buffer = dqn_extra.PrioReplayBuffer(exp_source, params.replay_size,
                                        PRIO_REPLAY_ALPHA)
    optimizer = optim.Adam(net.parameters(), lr=params.learning_rate)

    def process_batch(engine, batch_data):
        batch, batch_indices, batch_weights = batch_data
        optimizer.zero_grad()
        loss_v, sample_prios = calc_loss_prio(batch,
                                              batch_weights,
                                              net,
                                              tgt_net.target_model,
                                              gamma=params.gamma**N_STEPS,
                                              device=device)
        loss_v.backward()
        optimizer.step()
        buffer.update_priorities(batch_indices, sample_prios)
        if engine.state.iteration % params.target_net_sync == 0:
            tgt_net.sync()
        return {
            "loss": loss_v.item(),
            "beta": buffer.update_beta(engine.state.iteration),
        }

    engine = Engine(process_batch)
    common.setup_ignite(engine, params, exp_source, NAME)
    engine.run(
        common.batch_generator(buffer, params.replay_initial,
                               params.batch_size))
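The modulo check on engine.state.iteration inside process_batch is one way to sync the target network periodically; ignite's filtered events express the same cadence as a standalone handler. A sketch (the every=1000 cadence is an assumption, and tgt_net.sync() from ptan's TargetNet API is stubbed out):

from ignite.engine import Engine, Events

engine = Engine(lambda engine, batch: None)

@engine.on(Events.ITERATION_COMPLETED(every=1000))
def sync_target_net(engine):
    pass  # tgt_net.sync() in the original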
Example #7
    @trainer.on(Events.COMPLETED)
    def plot_font_results(engine):
        evaluator.run(valid_loader)
        real_font, fake_font, latent_vectors = evaluator.state.output
        print(real_font.shape)
        print(fake_font)
        plt.figure(figsize=(6, 100))
        for i, (real, fake) in enumerate(zip(real_font, fake_font)):
            plt.subplot(107, 2, 2 * i + 1)
            plt.imshow(real.cpu().detach().numpy())
            plt.subplot(107, 2, 2 * i + 2)
            plt.imshow(fake.cpu().detach().numpy())
        # plt.savefig('real_fake_fonts_{}_for_category_5layers.png'.format(engine.state.epoch))
        plt.close()

    @trainer.on(Events.COMPLETED)
    def plot_latent_vectors(engine):
        evaluator.run(valid_loader)
        _, _, latent_vectors = evaluator.state.output
        print(latent_vectors.shape)
        plt.figure()
        latent_vectors = latent_vectors.cpu().detach().numpy()
        for i in range(len(latent_vectors)):
            plt.plot(latent_vectors[i, 0], latent_vectors[i, 1], marker='o')
        # plt.plot(latent_vectors[:, 0], latent_vectors[:, 1], marker='.')
        # plt.savefig('latent_vectors_for_category_layers.png')
        plt.close()

    trainer.run(train_loader, max_epochs=epochs)
            trainer.tb.writer.add_image("fake", fake_img,
                                        trainer.state.iteration)
            real_img = vutils.make_grid(batch_v.data[:64], normalize=True)
            trainer.tb.writer.add_image("real", real_img,
                                        trainer.state.iteration)
            trainer.tb.writer.flush()
        return dis_loss.item(), gen_loss.item()

    engine = Engine(process_batch)
    tb = tb_logger.TensorboardLogger(log_dir=None)
    engine.tb = tb
    RunningAverage(output_transform=lambda out: out[0]).attach(
        engine, "avg_loss_dis")
    RunningAverage(output_transform=lambda out: out[1]).attach(
        engine, "avg_loss_gen")

    handler = tb_logger.OutputHandler(
        tag="train", metric_names=['avg_loss_gen', 'avg_loss_dis'])
    tb.attach(engine,
              log_handler=handler,
              event_name=Events.ITERATION_COMPLETED)

    @engine.on(Events.ITERATION_COMPLETED)
    def log_losses(trainer):
        if trainer.state.iteration % REPORT_EVERY_ITER == 0:
            log.info("%d: gen_loss=%f, dis_loss=%f", trainer.state.iteration,
                     trainer.state.metrics['avg_loss_gen'],
                     trainer.state.metrics['avg_loss_dis'])

    engine.run(data=iterate_batches(envs))
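The RunningAverage wiring above generalizes to any tuple-valued engine output. A self-contained sketch with stand-in losses (matching the (dis_loss, gen_loss) return order of process_batch):

from ignite.engine import Engine
from ignite.metrics import RunningAverage

engine = Engine(lambda engine, batch: (0.5, 0.25))  # (dis_loss, gen_loss) stand-ins
RunningAverage(output_transform=lambda out: out[0]).attach(engine, "avg_loss_dis")
RunningAverage(output_transform=lambda out: out[1]).attach(engine, "avg_loss_gen")
engine.run([0] * 10, max_epochs=1)
print(engine.state.metrics)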
Example #9
class DataflowBenchmark:
    def __init__(self, num_iters=100, prepare_batch=None):

        from ignite.handlers import Timer

        device = idist.device()

        def upload_to_gpu(engine, batch):
            if prepare_batch is not None:
                x, y = prepare_batch(batch, device=device, non_blocking=False)

        self.num_iters = num_iters
        self.benchmark_dataflow = Engine(upload_to_gpu)

        @self.benchmark_dataflow.on(Events.ITERATION_COMPLETED(once=num_iters))
        def stop_benchmark_dataflow(engine):
            engine.terminate()

        if idist.get_rank() == 0:

            @self.benchmark_dataflow.on(
                Events.ITERATION_COMPLETED(every=num_iters // 100))
            def show_progress_benchmark_dataflow(engine):
                print(".", end=" ")

        self.timer = Timer(average=False)
        self.timer.attach(
            self.benchmark_dataflow,
            start=Events.EPOCH_STARTED,
            resume=Events.ITERATION_STARTED,
            pause=Events.ITERATION_COMPLETED,
            step=Events.ITERATION_COMPLETED,
        )

    def attach(self, trainer, train_loader):

        from torch.utils.data import DataLoader

        @trainer.on(Events.STARTED)
        def run_benchmark(_):
            if idist.get_rank() == 0:
                print("-" * 50)
                print(" - Dataflow benchmark")

            self.benchmark_dataflow.run(train_loader)
            t = self.timer.value()

            if idist.get_rank() == 0:
                print(" ")
                print(
                    f" Total time ({self.num_iters} iterations) : {t:.5f} seconds"
                )
                print(
                    f" time per iteration         : {t / self.num_iters} seconds"
                )

                if isinstance(train_loader, DataLoader):
                    num_images = train_loader.batch_size * self.num_iters
                    print(f" number of images / s       : {num_images / t}")

                print("-" * 50)
Example #10
def run(output_path, config):
    device = "cuda"

    local_rank = config["local_rank"]
    # `backend` is undefined in this snippet; assume it comes from the config
    # (e.g. "nccl", or None for non-distributed runs)
    backend = config.get("dist_backend")
    distributed = backend is not None
    if distributed:
        torch.cuda.set_device(local_rank)
        device = "cuda"
    rank = dist.get_rank() if distributed else 0

    torch.manual_seed(config["seed"] + rank)

    # Rescale batch_size and num_workers
    ngpus_per_node = torch.cuda.device_count()
    ngpus = dist.get_world_size() if distributed else 1
    batch_size = config["batch_size"] // ngpus
    num_workers = int(
        (config["num_workers"] + ngpus_per_node - 1) / ngpus_per_node)

    train_loader, test_loader = get_train_test_loaders(
        path=config["data_path"],
        batch_size=batch_size,
        distributed=distributed,
        num_workers=num_workers,
    )

    model = get_model(config["model"])
    model = model.to(device)

    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[
                local_rank,
            ],
            output_device=local_rank,
        )

    optimizer = optim.SGD(
        model.parameters(),
        lr=config["learning_rate"],
        momentum=config["momentum"],
        weight_decay=config["weight_decay"],
        nesterov=True,
    )

    criterion = nn.CrossEntropyLoss().to(device)

    le = len(train_loader)
    milestones_values = [
        (0, 0.0),
        (le * config["num_warmup_epochs"], config["learning_rate"]),
        (le * config["num_epochs"], 0.0),
    ]
    lr_scheduler = PiecewiseLinear(optimizer,
                                   param_name="lr",
                                   milestones_values=milestones_values)

    def _prepare_batch(batch, device, non_blocking):
        x, y = batch
        return (
            convert_tensor(x, device=device, non_blocking=non_blocking),
            convert_tensor(y, device=device, non_blocking=non_blocking),
        )

    def process_function(engine, batch):

        x, y = _prepare_batch(batch, device=device, non_blocking=True)

        model.train()
        # Supervised part
        y_pred = model(x)
        loss = criterion(y_pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        return {
            "batch loss": loss.item(),
        }

    trainer = Engine(process_function)
    train_sampler = train_loader.sampler if distributed else None
    to_save = {
        "trainer": trainer,
        "model": model,
        "optimizer": optimizer,
        "lr_scheduler": lr_scheduler,
    }
    metric_names = [
        "batch loss",
    ]
    common.setup_common_training_handlers(
        trainer,
        train_sampler=train_sampler,
        to_save=to_save,
        save_every_iters=config["checkpoint_every"],
        output_path=output_path,
        lr_scheduler=lr_scheduler,
        output_names=metric_names,
        with_pbar_on_iters=config["display_iters"],
        log_every_iters=10,
    )

    if rank == 0:
        tb_logger = TensorboardLogger(log_dir=output_path)
        tb_logger.attach(
            trainer,
            log_handler=OutputHandler(tag="train", metric_names=metric_names),
            event_name=Events.ITERATION_COMPLETED,
        )
        tb_logger.attach(
            trainer,
            log_handler=OptimizerParamsHandler(optimizer, param_name="lr"),
            event_name=Events.ITERATION_STARTED,
        )

    metrics = {
        "accuracy": Accuracy(device=device if distributed else None),
        "loss": Loss(criterion, device=device if distributed else None),
    }

    evaluator = create_supervised_evaluator(model,
                                            metrics=metrics,
                                            device=device,
                                            non_blocking=True)
    train_evaluator = create_supervised_evaluator(model,
                                                  metrics=metrics,
                                                  device=device,
                                                  non_blocking=True)

    def run_validation(engine):
        torch.cuda.synchronize()
        train_evaluator.run(train_loader)
        evaluator.run(test_loader)

    trainer.add_event_handler(
        Events.EPOCH_STARTED(every=config["validate_every"]), run_validation)
    trainer.add_event_handler(Events.COMPLETED, run_validation)

    if rank == 0:
        if config["display_iters"]:
            ProgressBar(persist=False,
                        desc="Train evaluation").attach(train_evaluator)
            ProgressBar(persist=False,
                        desc="Test evaluation").attach(evaluator)

        tb_logger.attach(
            train_evaluator,
            log_handler=OutputHandler(
                tag="train",
                metric_names=list(metrics.keys()),
                global_step_transform=global_step_from_engine(trainer),
            ),
            event_name=Events.COMPLETED,
        )

        tb_logger.attach(
            evaluator,
            log_handler=OutputHandler(
                tag="test",
                metric_names=list(metrics.keys()),
                global_step_transform=global_step_from_engine(trainer),
            ),
            event_name=Events.COMPLETED,
        )

        # Store the best model by validation accuracy:
        common.save_best_model_by_val_score(
            output_path,
            evaluator,
            model=model,
            metric_name="accuracy",
            n_saved=3,
            trainer=trainer,
            tag="test",
        )

        if config["log_model_grads_every"] is not None:
            tb_logger.attach(
                trainer,
                log_handler=GradsHistHandler(model,
                                             tag=model.__class__.__name__),
                event_name=Events.ITERATION_COMPLETED(
                    every=config["log_model_grads_every"]),
            )

    if config["crash_iteration"] is not None:

        @trainer.on(Events.ITERATION_STARTED(once=config["crash_iteration"]))
        def _(engine):
            raise Exception("STOP at iteration: {}".format(
                engine.state.iteration))

    resume_from = config["resume_from"]
    if resume_from is not None:
        checkpoint_fp = Path(resume_from)
        assert checkpoint_fp.exists(), "Checkpoint '{}' is not found".format(
            checkpoint_fp.as_posix())
        print("Resume from a checkpoint: {}".format(checkpoint_fp.as_posix()))
        checkpoint = torch.load(checkpoint_fp.as_posix())
        Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    try:
        trainer.run(train_loader, max_epochs=config["num_epochs"])
    except Exception:
        import traceback

        print(traceback.format_exc())

    if rank == 0:
        tb_logger.close()
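The warmup-then-decay schedule built from milestones_values above reduces to the following sketch, with placeholder values (on older ignite releases PiecewiseLinear lives under ignite.contrib.handlers):

import torch
from ignite.engine import Engine, Events
from ignite.handlers.param_scheduler import PiecewiseLinear

tensor = torch.zeros([1], requires_grad=True)
optimizer = torch.optim.SGD([tensor], lr=0.0)

steps_per_epoch = 100  # len(train_loader) in the example above
milestones_values = [
    (0, 0.0),                      # start at zero
    (steps_per_epoch * 2, 0.1),    # warm up to the peak lr over 2 epochs
    (steps_per_epoch * 20, 0.0),   # decay back to zero by epoch 20
]
lr_scheduler = PiecewiseLinear(
    optimizer, param_name="lr", milestones_values=milestones_values)

trainer = Engine(lambda engine, batch: None)
trainer.add_event_handler(Events.ITERATION_STARTED, lr_scheduler)
trainer.run([0] * steps_per_epoch, max_epochs=2)
print(optimizer.param_groups[0]["lr"])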
Example #11
def train():
    parser = ArgumentParser()
    parser.add_argument(
        "--dataset_path",
        type=str,
        default="",
        help="Path or url of the dataset. If empty download from S3.")
    parser.add_argument("--dataset_cache",
                        type=str,
                        default='./dataset_cache',
                        help="Path or url of the dataset cache")
    parser.add_argument("--model_checkpoint",
                        type=str,
                        default="gpt2",
                        help="Path, url or short name of the model")
    parser.add_argument("--num_candidates",
                        type=int,
                        default=2,
                        help="Number of candidates for training")
    parser.add_argument("--max_history",
                        type=int,
                        default=2,
                        help="Number of previous exchanges to keep in history")
    parser.add_argument("--train_batch_size",
                        type=int,
                        default=1,
                        help="Batch size for training")
    parser.add_argument("--valid_batch_size",
                        type=int,
                        default=1,
                        help="Batch size for validation")
    parser.add_argument("--gradient_accumulation_steps",
                        type=int,
                        default=8,
                        help="Accumulate gradients on several steps")
    parser.add_argument("--lr",
                        type=float,
                        default=6.25e-5,
                        help="Learning rate")
    parser.add_argument("--lm_coef",
                        type=float,
                        default=2.0,
                        help="LM loss coefficient")
    parser.add_argument("--mc_coef",
                        type=float,
                        default=1.0,
                        help="Multiple-choice loss coefficient")
    parser.add_argument("--max_norm",
                        type=float,
                        default=1.0,
                        help="Clipping gradient norm")
    parser.add_argument("--n_epochs",
                        type=int,
                        default=1000,
                        help="Number of training epochs")
    parser.add_argument("--personality_permutations",
                        type=int,
                        default=1,
                        help="Number of permutations of personality sentences")
    parser.add_argument(
        "--eval_before_start",
        action='store_true',
        help="If true start with a first evaluation before training")
    parser.add_argument("--device",
                        type=str,
                        default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device (cuda or cpu)")
    parser.add_argument(
        "--fp16",
        type=str,
        default="",
        help=
        "Set to O0, O1, O2 or O3 for fp16 training (see apex documentation)")
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="Local rank for distributed training (-1: not distributed)")
    args = parser.parse_args()

    # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process. logger.info => log main process only, logger.warning => log all processes
    logging.basicConfig(
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning(
        "Running process %d", args.local_rank
    )  # This is a logger.warning: it will be printed by all distributed processes
    logger.info("Arguments: %s", pformat(args))

    # Initialize distributed training if needed
    args.distributed = (args.local_rank != -1)
    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        args.device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')

    logger.info("Prepare tokenizer, pretrained model and optimizer.")
    tokenizer_class = GPT2Tokenizer
    tokenizer = tokenizer_class.from_pretrained(args.model_checkpoint)

    model_class = GPT2DoubleHeadsModel
    model = model_class.from_pretrained(args.model_checkpoint)
    model.to(args.device)
    # Add special tokens if they are not already added
    add_special_tokens_(model, tokenizer)
    optimizer = AdamW(model.parameters(), lr=args.lr, correct_bias=True)

    # Prepare model for FP16 and distributed training if needed (order is important, distributed should be the last)
    if args.fp16:
        from apex import amp  # Apex is only required if we use fp16 training
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16)
    if args.distributed:
        model = DistributedDataParallel(model,
                                        device_ids=[args.local_rank],
                                        output_device=args.local_rank)

    logger.info("Prepare datasets")
    train_loader, val_loader, train_sampler, valid_sampler = get_data_loaders(
        args, tokenizer)

    # Training function and trainer
    def update(engine, batch):
        model.train()
        batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
        input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids = batch
        (lm_loss), (mc_loss), *_ = model(input_ids,
                                         token_type_ids=token_type_ids,
                                         mc_token_ids=mc_token_ids,
                                         mc_labels=mc_labels,
                                         lm_labels=lm_labels)
        loss = (lm_loss * args.lm_coef +
                mc_loss * args.mc_coef) / args.gradient_accumulation_steps
        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                           args.max_norm)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
        if engine.state.iteration % args.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        return loss.item()

    trainer = Engine(update)

    # Evaluation function and evaluator (evaluator output is the input of the metrics)
    def inference(engine, batch):
        model.eval()
        with torch.no_grad():
            batch = tuple(
                input_tensor.to(args.device) for input_tensor in batch)
            input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids = batch
            logger.info(tokenizer.decode(input_ids[0, -1, :].tolist()))
            # if we don't send labels to the model, it doesn't return losses
            lm_logits, mc_logits, *_ = model(
                input_ids,
                token_type_ids=token_type_ids,
                mc_token_ids=mc_token_ids,
            )
            lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view(
                -1, lm_logits.size(-1))
            lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1)
            return (lm_logits_flat_shifted,
                    mc_logits), (lm_labels_flat_shifted, mc_labels)

    evaluator = Engine(inference)

    # Attach evaluation to trainer: we evaluate when we start the training and at the end of each epoch
    trainer.add_event_handler(Events.EPOCH_COMPLETED,
                              lambda _: evaluator.run(val_loader))
    if args.n_epochs < 1:
        trainer.add_event_handler(Events.COMPLETED,
                                  lambda _: evaluator.run(val_loader))
    if args.eval_before_start:
        trainer.add_event_handler(Events.STARTED,
                                  lambda _: evaluator.run(val_loader))

    # Make sure distributed data samplers split the dataset nicely between the distributed processes
    if args.distributed:
        trainer.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: train_sampler.set_epoch(engine.state.epoch))
        evaluator.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: valid_sampler.set_epoch(engine.state.epoch))

    # Linearly decrease the learning rate from lr to zero
    scheduler = PiecewiseLinear(optimizer, "lr",
                                [(0, args.lr),
                                 (args.n_epochs * len(train_loader), 0.0)])
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # Prepare metrics - note how we compute distributed metrics
    RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
    metrics = {
        "nll": Loss(torch.nn.CrossEntropyLoss(ignore_index=-1),
                    output_transform=lambda x: (x[0][0], x[1][0])),
        "accuracy": Accuracy(output_transform=lambda x: (x[0][1], x[1][1]))
    }
    metrics.update({
        "average_nll": MetricsLambda(average_distributed_scalar, metrics["nll"], args),
        "average_accuracy": MetricsLambda(average_distributed_scalar, metrics["accuracy"], args)
    })
    metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"])
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # On the main process: add progress bar, tensorboard, checkpoints and save model, configuration and tokenizer before we start to train
    if args.local_rank in [-1, 0]:
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer, metric_names=["loss"])
        evaluator.add_event_handler(
            Events.COMPLETED, lambda _: pbar.log_message(
                "Validation: %s" % pformat(evaluator.state.metrics)))

        log_dir = make_logdir(args.model_checkpoint)
        tb_logger = TensorboardLogger(log_dir)

        tb_logger.attach(trainer,
                         log_handler=OutputHandler(tag="training",
                                                   metric_names=["loss"]),
                         event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer,
                         log_handler=OptimizerParamsHandler(optimizer),
                         event_name=Events.ITERATION_STARTED)
        tb_logger.attach(evaluator,
                         log_handler=OutputHandler(tag="validation",
                                                   metric_names=list(
                                                       metrics.keys()),
                                                   another_engine=trainer),
                         event_name=Events.EPOCH_COMPLETED)

        checkpoint_handler = ModelCheckpoint(log_dir,
                                             'checkpoint',
                                             save_interval=1,
                                             n_saved=3)
        trainer.add_event_handler(
            Events.EPOCH_COMPLETED, checkpoint_handler,
            {'mymodel': getattr(model, 'module', model)
             })  # "getattr" takes care of distributed encapsulation

        torch.save(args, log_dir + '/model_training_args.bin')
        getattr(model, 'module',
                model).config.to_json_file(os.path.join(log_dir, CONFIG_NAME))
        tokenizer.save_pretrained(log_dir)

    # Run the training
    trainer.run(train_loader, max_epochs=args.n_epochs)

    # On the main process: close tensorboard logger and rename the last checkpoint (for easy re-loading with OpenAIGPTModel.from_pretrained method)
    if args.local_rank in [-1, 0] and args.n_epochs > 0:
        os.rename(
            checkpoint_handler._saved[-1][1][-1],
            os.path.join(log_dir, WEIGHTS_NAME)
        )  # TODO: PR in ignite to have better access to saved file paths (cleaner)
        tb_logger.close()
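The gradient-accumulation step in update() above boils down to dividing the loss by the accumulation factor and stepping the optimizer every N iterations. A toy, runnable sketch (the linear model is a stand-in for the transformer):

import torch
from ignite.engine import Engine

accumulation_steps = 8
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

def update(engine, batch):
    model.train()
    loss = model(batch).mean() / accumulation_steps
    loss.backward()
    # engine.state.iteration is 1-based, so this steps at 8, 16, ...
    if engine.state.iteration % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
    return loss.item()

trainer = Engine(update)
trainer.run([torch.randn(2, 4) for _ in range(16)], max_epochs=1)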
Example #12
def main():
    parser = ArgumentParser()
    parser.add_argument(
        "--dataset_path",
        type=str,
        default="",
        help="Path or url of the dataset. If empty download from S3.")
    parser.add_argument("--dataset_cache",
                        type=str,
                        default='./dataset_cache',
                        help="Path or url of the dataset cache")
    parser.add_argument("--model",
                        type=str,
                        default="",
                        help="Model type, one of: %s" %
                        ', '.join(MODELS.keys()))
    parser.add_argument("--model_checkpoint",
                        type=str,
                        default="",
                        help="Path, url or short name of a pretrained model")
    parser.add_argument("--num_candidates",
                        type=int,
                        default=2,
                        help="Number of candidates for training")
    parser.add_argument("--max_history",
                        type=int,
                        default=2,
                        help="Number of previous exchanges to keep in history")
    parser.add_argument("--train_batch_size",
                        type=int,
                        default=4,
                        help="Batch size for training")
    parser.add_argument("--valid_batch_size",
                        type=int,
                        default=4,
                        help="Batch size for validation")
    parser.add_argument("--gradient_accumulation_steps",
                        type=int,
                        default=8,
                        help="Accumulate gradients on several steps")
    parser.add_argument("--lr",
                        type=float,
                        default=6.25e-5,
                        help="Learning rate")
    parser.add_argument("--lm_coef",
                        type=float,
                        default=1.0,
                        help="LM loss coefficient")
    parser.add_argument("--mc_coef",
                        type=float,
                        default=1.0,
                        help="Multiple-choice loss coefficient")
    parser.add_argument("--adv_coef",
                        type=float,
                        default=1.0,
                        help="Adversarial dataset prediction loss coefficient")
    parser.add_argument("--max_norm",
                        type=float,
                        default=1.0,
                        help="Clipping gradient norm")
    parser.add_argument("--weight_decay",
                        default=0.0,
                        type=float,
                        help="Weight decay if we apply some.")
    parser.add_argument("--warmup_steps",
                        default=0,
                        type=int,
                        help="Linear warmup over warmup_steps.")
    parser.add_argument("--n_epochs",
                        type=int,
                        default=3,
                        help="Number of training epochs")
    #parser.add_argument("--personality_permutations", type=int, default=1, help="Number of permutations of personality sentences")
    parser.add_argument(
        "--eval_before_start",
        action='store_true',
        help="If true start with a first evaluation before training")
    #parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device (cuda or cpu)")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Avoid using CUDA when available")
    parser.add_argument(
        "--fp16",
        type=str,
        default="",
        help=
        "Set to O0, O1, O2 or O3 for fp16 training (see apex documentation)")
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="Local rank for distributed training (-1: not distributed)")
    parser.add_argument(
        "--max_sequence_length",
        type=int,
        default=-1,
        help="If set, use this to manually restrict the sequence length. "
        "This might be helpful to save resources (memory). "
        "If not set, this is looked up from the model config (n_ctx value).")
    parser.add_argument(
        "--adversarial_dataset_prediction",
        action='store_true',
        help="Set to train with adversarial dataset prediction")
    parser.add_argument("--seed",
                        type=int,
                        default=None,
                        help='set random seed')
    args = parser.parse_args()

    # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process. logger.info => log main process only, logger.warning => log all processes
    logging.basicConfig(
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning(
        "Running process %d", args.local_rank
    )  # This is a logger.warning: it will be printed by all distributed processes
    logger.info("Arguments: %s", pformat(args))

    if args.seed is not None:
        torch.manual_seed(args.seed)

    args.distributed = (args.local_rank != -1)

    logger.info("Prepare tokenizer and data")
    if not args.model:
        logger.warning(
            '"model" parameter is not set! This is deprecated. Please use one of: %s. '
            'To mimic deprecated behaviour, "model_checkpoint" will be used as "model"'
            % ', '.join(MODELS.keys()))
        args.model = args.model_checkpoint
    if args.model not in MODELS:
        raise NotImplementedError(
            'model "%s" not implemented. use one of: %s' %
            (args.model, ', '.join(MODELS.keys())))
    config_class, tokenizer_class, model_class, _ = MODELS[args.model]
    if not args.model_checkpoint:
        args.model_checkpoint = args.model

    model_config = config_class.from_pretrained(args.model_checkpoint)
    tokenizer = tokenizer_class.from_pretrained(args.model_checkpoint)

    additional_special_tokens = [TYPE_BACKGROUND, TYPE_BOT, TYPE_USER]
    # for adversarial training (dataset prediction)
    dataset_labels = None
    if args.adversarial_dataset_prediction:
        dataset_labels = [
            get_dataset_label(dataset_path)
            for dataset_path in args.dataset_path.split(',')
        ]
        #additional_special_tokens.extend(dataset_labels)
        #if model_class not in ADV_MODELS.values():
        assert model_class in ADV_MODELS, f'no adversarial model implemented for model class: {model_class.__name__}'
        model_class = ADV_MODELS[model_class]
        if not hasattr(model_config, 'cls'):
            model_config.cls = {}
        if 'dataset_labels' in model_config.cls:
            assert all([dl in model_config.cls['dataset_labels']['labels'] for dl in dataset_labels]), \
                f'loaded dataset_labels [{model_config.cls["dataset_labels"]["labels"]}] do not contain all ' \
                f'current dataset_labels [{dataset_labels}]'
            dataset_labels = model_config.cls['dataset_labels']['labels']
        else:
            model_config.cls['dataset_labels'] = {
                'labels': dataset_labels,
                'is_adversarial': True
            }
        model_input_names = [
            "input_ids", "mc_token_ids", "lm_labels", "mc_labels",
            "dataset_labels", "token_type_ids"
        ]
        # not yet used
        model_output_names = [
            "lm_loss", "mc_loss", "cl_loss_0", "lm_logits", "mc_logits",
            "cl_logits_0", "presents"
        ]
    else:
        model_input_names = [
            "input_ids", "mc_token_ids", "lm_labels", "mc_labels",
            "token_type_ids"
        ]
        # not yet used
        model_output_names = [
            "lm_loss", "mc_loss", "lm_logits", "mc_logits", "presents"
        ]

    tokenizer.add_special_tokens({
        'bos_token': TYPE_BOS,
        'eos_token': TYPE_EOS,
        'pad_token': TYPE_PAD,
        'additional_special_tokens': additional_special_tokens
    })

    logger.info("Prepare datasets")
    max_sequence_length = model_config.n_ctx if args.max_sequence_length <= 0 else args.max_sequence_length
    assert max_sequence_length <= model_config.n_ctx, 'max_sequence_length [%i] was set to a value higher than ' \
                                                      'supported by the model (config.n_ctx [%i]). Please use a lower ' \
                                                      'value or do not set it [-1] to use the highest supported one.' \
                                                      % (max_sequence_length, model_config.n_ctx)
    train_loader, val_loader, train_sampler, valid_sampler = get_data_loaders(
        args=args,
        tokenizer=tokenizer,
        model_input_names=model_input_names,
        max_sequence_length=max_sequence_length,
        dataset_labels=dataset_labels)

    logger.info(
        "Prepare pretrained model and optimizer - add special tokens for fine-tuning"
    )

    # Initialize distributed training if needed
    # Setup CUDA, GPU & distributed training
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        args.n_gpu = torch.cuda.device_count()
    else:  # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl')
        args.n_gpu = 1
    args.device = device

    # Load pretrained model and tokenizer
    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier(
        )  # Barrier to make sure only the first process in distributed training download model & vocab

    #model = model_class.from_pretrained(args.model_checkpoint, num_cl_labels=len(dataset_ids))    # for GPT2DoubleHeadsModelwithAdversarial
    model = model_class.from_pretrained(args.model_checkpoint,
                                        config=model_config)
    model.resize_token_embeddings(len(tokenizer))
    model.to(args.device)

    if args.local_rank == 0:
        torch.distributed.barrier(
        )  # End of barrier to make sure only the first process in distributed training download model & vocab

    ####################################################################################################################

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [p for n, p in model.named_parameters()
                   if not any(nd in n for nd in no_decay)],
        'weight_decay': args.weight_decay
    }, {
        'params': [p for n, p in model.named_parameters()
                   if any(nd in n for nd in no_decay)],
        'weight_decay': 0.0
    }]
    #optimizer = OpenAIAdam(model.parameters(), lr=args.lr)
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.lr)
    # scheduler is set below (see ignite)
    #scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps,
    #                                            num_training_steps=len(train_loader) // args.train_batch_size + 1)

    # Check if saved optimizer or scheduler states exist
    if os.path.isfile(os.path.join(
            args.model_checkpoint, 'optimizer.pt')) and os.path.isfile(
                os.path.join(args.model_checkpoint, 'scheduler.pt')):
        # Load in optimizer and scheduler states
        # TODO: this needs to be dumped somewhere
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_checkpoint, 'optimizer.pt')))
        #scheduler.load_state_dict(torch.load(os.path.join(args.model_checkpoint, 'scheduler.pt')))

    # Prepare model for FP16 and distributed training if needed (order is important, distributed should be the last)
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    # Training function and trainer
    def update(engine, batch):
        model.train()
        batch = {
            model_input_names[i]: input_tensor.to(args.device)
            for i, input_tensor in enumerate(batch)
        }
        model_output = model(**batch)
        losses = (model_output[:3] if args.adversarial_dataset_prediction
                  else model_output[:2])
        if args.n_gpu > 1:  # mean() to average on multi-gpu.
            losses = list(losses)
            for i in range(len(losses)):
                losses[i] = losses[i].mean()
        lm_loss, mc_loss = losses[0], losses[1]
        loss = (lm_loss * args.lm_coef +
                mc_loss * args.mc_coef) / args.gradient_accumulation_steps

        # handle adversarial loss
        loss_wo_adv = loss.clone()
        if args.adversarial_dataset_prediction:
            adv_loss = model_output[2]
            loss += (adv_loss *
                     args.adv_coef) / args.gradient_accumulation_steps

        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                           args.max_norm)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
        if engine.state.iteration % args.gradient_accumulation_steps == 0:
            optimizer.step()
            #scheduler.step()  # Update learning rate schedule # already DONE below!
            optimizer.zero_grad()
        return loss_wo_adv.item(), loss.item()

    trainer = Engine(update)

    # Evaluation function and evaluator (evaluator output is the input of the metrics)
    def inference(engine, batch):
        model.eval()
        with torch.no_grad():
            batch = tuple(
                input_tensor.to(args.device) for input_tensor in batch)
            if args.adversarial_dataset_prediction:
                input_ids, mc_token_ids, lm_labels, mc_labels, dataset_labels, token_type_ids = batch
            else:
                input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids = batch

            logger.debug(
                tokenizer.decode(input_ids[0, -1, :].tolist()).replace(
                    TYPE_PAD, ''))
            model_outputs = model(input_ids=input_ids,
                                  mc_token_ids=mc_token_ids,
                                  token_type_ids=token_type_ids)
            lm_logits, mc_logits = model_outputs[0], model_outputs[1]  # so we can also use GPT2 outputs
            lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view(
                -1, lm_logits.size(-1))
            lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1)
            return (lm_logits_flat_shifted,
                    mc_logits), (lm_labels_flat_shifted, mc_labels)

    evaluator = Engine(inference)

    # Attach evaluation to trainer: we evaluate when we start the training and at the end of each epoch
    trainer.add_event_handler(Events.EPOCH_COMPLETED,
                              lambda _: evaluator.run(val_loader))
    if args.n_epochs < 1:
        trainer.add_event_handler(Events.COMPLETED,
                                  lambda _: evaluator.run(val_loader))
    if args.eval_before_start:
        trainer.add_event_handler(Events.STARTED,
                                  lambda _: evaluator.run(val_loader))

    # Make sure distributed data samplers split the dataset nicely between the distributed processes
    if args.distributed:
        trainer.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: train_sampler.set_epoch(engine.state.epoch))
        evaluator.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: valid_sampler.set_epoch(engine.state.epoch))

    # Linearly decrease the learning rate from lr to zero (scheduler)
    scheduler = PiecewiseLinear(optimizer, "lr",
                                [(0, args.lr),
                                 (args.n_epochs * len(train_loader), 0.0)])
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # Prepare metrics - note how we compute distributed metrics
    RunningAverage(output_transform=lambda x: x[0]).attach(trainer, "loss")
    if args.adversarial_dataset_prediction:
        RunningAverage(output_transform=lambda x: x[1]).attach(
            trainer, "loss_w/_adv")
        RunningAverage(output_transform=lambda x: x[1] - x[0]).attach(
            trainer, "loss_only_adv")
    metrics = {
        "nll":
        Loss(torch.nn.CrossEntropyLoss(ignore_index=-1),
             output_transform=lambda x: (x[0][0], x[1][0])),
        "accuracy":
        Accuracy(output_transform=lambda x: (x[0][1], x[1][1]))
    }
    metrics.update({
        "average_nll":
        MetricsLambda(average_distributed_scalar, metrics["nll"], args),
        "average_accuracy":
        MetricsLambda(average_distributed_scalar, metrics["accuracy"], args)
    })
    metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"])
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # On the main process: add progress bar, tensorboard, checkpoints and save model, configuration and tokenizer before we start to train
    if args.local_rank in [-1, 0]:
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer, metric_names=["loss"])
        evaluator.add_event_handler(
            Events.COMPLETED, lambda _: pbar.log_message(
                "Validation: %s" % pformat(evaluator.state.metrics)))

        tb_logger = TensorboardLogger(log_dir=None)
        tb_logger.attach(trainer,
                         log_handler=OutputHandler(tag="training",
                                                   metric_names=["loss"]),
                         event_name=Events.ITERATION_COMPLETED)
        if args.adversarial_dataset_prediction:
            tb_logger.attach(trainer,
                             log_handler=OutputHandler(
                                 tag="training", metric_names=["loss_w/_adv"]),
                             event_name=Events.ITERATION_COMPLETED)
            tb_logger.attach(trainer,
                             log_handler=OutputHandler(
                                 tag="training",
                                 metric_names=["loss_only_adv"]),
                             event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer,
                         log_handler=OptimizerParamsHandler(optimizer),
                         event_name=Events.ITERATION_STARTED)
        tb_logger.attach(evaluator,
                         log_handler=OutputHandler(tag="validation",
                                                   metric_names=list(
                                                       metrics.keys()),
                                                   another_engine=trainer),
                         event_name=Events.EPOCH_COMPLETED)

        logger.info('save checkpoints to: %s' % tb_logger.writer.log_dir)
        checkpoint_handler = ModelCheckpoint(tb_logger.writer.log_dir,
                                             'checkpoint',
                                             save_interval=1,
                                             n_saved=3)
        trainer.add_event_handler(
            Events.EPOCH_COMPLETED, checkpoint_handler,
            {'mymodel': getattr(model, 'module', model)
             })  # "getattr" take care of distributed encapsulation

        torch.save(args, tb_logger.writer.log_dir + '/model_training_args.bin')
        getattr(model, 'module', model).config.to_json_file(
            os.path.join(tb_logger.writer.log_dir, CONFIG_NAME))
        tokenizer.save_pretrained(tb_logger.writer.log_dir)

        #logger.debug("Saving optimizer and scheduler states to %s", tb_logger.writer.log_dir)
        #torch.save(optimizer.state_dict(), os.path.join(tb_logger.writer.log_dir, 'optimizer.pt'))
        #torch.save(scheduler.state_dict(), os.path.join(tb_logger.writer.log_dir, 'scheduler.pt'))

    # Run the training
    trainer.run(train_loader, max_epochs=args.n_epochs)

    # On the main process: close tensorboard logger and rename the last checkpoint (for easy re-loading with OpenAIGPTModel.from_pretrained method)
    if args.local_rank in [-1, 0] and args.n_epochs > 0:
        os.rename(
            checkpoint_handler._saved[-1][1][-1],
            os.path.join(tb_logger.writer.log_dir, WEIGHTS_NAME)
        )  # TODO: PR in ignite to have better access to saved file paths (cleaner)
        tb_logger.close()
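
The update function above implements gradient accumulation: each loss is scaled by 1 / gradient_accumulation_steps, gradients pile up across backward passes, and the optimizer only steps every that many iterations. A minimal self-contained sketch of the pattern (toy model and data, illustrative names only):

import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
accumulation_steps = 4

for iteration, batch in enumerate(torch.randn(16, 4).split(1), start=1):
    # scale so the accumulated gradients match those of one large batch
    loss = model(batch).mean() / accumulation_steps
    loss.backward()  # gradients accumulate in .grad across iterations
    if iteration % accumulation_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad()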
Beispiel #13
0
def test_stopping_criterion_is_max_epochs():
    engine = Engine(MagicMock(return_value=1))
    max_epochs = 5
    state = engine.run([1], max_epochs=max_epochs)
    assert state.epoch == max_epochs
Beispiel #14
0
def test_default_exception_handler():
    update_function = MagicMock(side_effect=ValueError())
    engine = Engine(update_function)

    with raises(ValueError):
        engine.run([1])
Beispiel #15
0
def test_integration():

    n_iters = 100
    batch_size = 10
    n_classes = 10
    y_true_batch_values = iter(
        np.random.randint(0, n_classes, size=(n_iters, batch_size)))
    y_pred_batch_values = iter(np.random.rand(n_iters, batch_size, n_classes))
    loss_values = iter(range(n_iters))

    def update_fn(engine, batch):
        loss_value = next(loss_values)
        y_true_batch = next(y_true_batch_values)
        y_pred_batch = next(y_pred_batch_values)
        return loss_value, torch.from_numpy(y_pred_batch), torch.from_numpy(y_true_batch)

    trainer = Engine(update_fn)
    alpha = 0.98

    acc_metric = RunningAverage(
        Accuracy(output_transform=lambda x: [x[1], x[2]]), alpha=alpha)
    acc_metric.attach(trainer, "running_avg_accuracy")

    avg_output = RunningAverage(output_transform=lambda x: x[0], alpha=alpha)
    avg_output.attach(trainer, "running_avg_output")

    running_avg_acc = [
        None,
    ]

    @trainer.on(Events.ITERATION_COMPLETED)
    def manual_running_avg_acc(engine):
        _, y_pred, y = engine.state.output
        indices = torch.max(y_pred, 1)[1]
        correct = torch.eq(indices, y).view(-1)
        num_correct = torch.sum(correct).item()
        num_examples = correct.shape[0]
        batch_acc = num_correct * 1.0 / num_examples
        if running_avg_acc[0] is None:
            running_avg_acc[0] = batch_acc
        else:
            running_avg_acc[0] = running_avg_acc[0] * alpha + (
                1.0 - alpha) * batch_acc
        engine.state.running_avg_acc = running_avg_acc[0]

    @trainer.on(Events.EPOCH_STARTED)
    def running_avg_output_init(engine):
        engine.state.running_avg_output = None

    @trainer.on(Events.ITERATION_COMPLETED)
    def running_avg_output_update(engine):
        if engine.state.running_avg_output is None:
            engine.state.running_avg_output = engine.state.output[0]
        else:
            engine.state.running_avg_output = (
                engine.state.running_avg_output * alpha +
                (1.0 - alpha) * engine.state.output[0])

    @trainer.on(Events.ITERATION_COMPLETED)
    def assert_equal_running_avg_acc_values(engine):
        assert engine.state.running_avg_acc == engine.state.metrics[
            "running_avg_accuracy"], "{} vs {}".format(
                engine.state.running_avg_acc,
                engine.state.metrics["running_avg_accuracy"])

    @trainer.on(Events.ITERATION_COMPLETED)
    def assert_equal_running_avg_output_values(engine):
        assert engine.state.running_avg_output == engine.state.metrics[
            "running_avg_output"], "{} vs {}".format(
                engine.state.running_avg_output,
                engine.state.metrics["running_avg_output"])

    np.random.seed(10)
    running_avg_acc = [
        None,
    ]
    n_iters = 10
    batch_size = 10
    n_classes = 10
    data = list(range(n_iters))
    loss_values = iter(range(n_iters))
    y_true_batch_values = iter(
        np.random.randint(0, n_classes, size=(n_iters, batch_size)))
    y_pred_batch_values = iter(np.random.rand(n_iters, batch_size, n_classes))
    trainer.run(data, max_epochs=1)

    running_avg_acc = [
        None,
    ]
    n_iters = 10
    batch_size = 10
    n_classes = 10
    data = list(range(n_iters))
    loss_values = iter(range(n_iters))
    y_true_batch_values = iter(
        np.random.randint(0, n_classes, size=(n_iters, batch_size)))
    y_pred_batch_values = iter(np.random.rand(n_iters, batch_size, n_classes))
    trainer.run(data, max_epochs=1)
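
The manual handlers in test_integration reproduce RunningAverage's exponential moving average, v_0 = x_0 and v_t = alpha * v_{t-1} + (1 - alpha) * x_t. A tiny standalone check of that recurrence (values are made up for illustration):

alpha = 0.98
values = [1.0, 2.0, 3.0, 4.0]
running = None
for x in values:
    running = x if running is None else running * alpha + (1.0 - alpha) * x
print(running)  # 1.118408...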
Beispiel #16
0
def main(dataset, dataroot, z_dim, g_filters, d_filters, batch_size, epochs,
         learning_rate, beta_1, saved_G, saved_D, seed, n_workers, device,
         alpha, output_dir):

    # seed
    check_manual_seed(seed)

    # networks
    netG = Generator(z_dim, g_filters).to(device)
    netD = Discriminator(d_filters).to(device)

    # criterion
    bce = nn.BCELoss()

    # optimizers
    optimizerG = optim.Adam(netG.parameters(),
                            lr=learning_rate,
                            betas=(beta_1, 0.999))
    optimizerD = optim.Adam(netD.parameters(),
                            lr=learning_rate,
                            betas=(beta_1, 0.999))

    # data
    dataset = check_dataset(dataset, dataroot)
    loader = data.DataLoader(dataset,
                             batch_size=batch_size,
                             shuffle=True,
                             num_workers=n_workers,
                             drop_last=True)

    # load pre-trained models
    if saved_G:
        netG.load_state_dict(torch.load(saved_G))

    if saved_D:
        netD.load_state_dict(torch.load(saved_D))

    # misc
    real_labels = torch.ones(batch_size, device=device)
    fake_labels = torch.zeros(batch_size, device=device)
    fixed_noise = torch.randn(batch_size, z_dim, 1, 1, device=device)

    def get_noise():
        return torch.randn(batch_size, z_dim, 1, 1, device=device)

    # The main function, processing a batch of examples
    def step(engine, batch):

        # unpack the batch. It comes from a dataset, so we have <images, labels> pairs. Discard labels.
        real, _ = batch
        real = real.to(device)

        # -----------------------------------------------------------
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        netD.zero_grad()

        # train with real
        output = netD(real)
        errD_real = bce(output, real_labels)
        D_x = output.mean().item()

        errD_real.backward()

        # get fake image from generator
        noise = get_noise()
        fake = netG(noise)

        # train with fake
        output = netD(fake.detach())
        errD_fake = bce(output, fake_labels)
        D_G_z1 = output.mean().item()

        errD_fake.backward()

        # gradient update
        errD = errD_real + errD_fake
        optimizerD.step()

        # -----------------------------------------------------------
        # (2) Update G network: maximize log(D(G(z)))
        netG.zero_grad()

        # Update generator: take a step that makes the discriminator more likely to output "real"
        output = netD(fake)
        errG = bce(output, real_labels)
        D_G_z2 = output.mean().item()

        errG.backward()

        # gradient update
        optimizerG.step()

        return {
            'errD': errD.item(),
            'errG': errG.item(),
            'D_x': D_x,
            'D_G_z1': D_G_z1,
            'D_G_z2': D_G_z2
        }

    # ignite objects
    trainer = Engine(step)
    checkpoint_handler = ModelCheckpoint(output_dir,
                                         CKPT_PREFIX,
                                         save_interval=1,
                                         n_saved=10,
                                         require_empty=False)
    timer = Timer(average=True)

    # attach running average metrics
    monitoring_metrics = ['errD', 'errG', 'D_x', 'D_G_z1', 'D_G_z2']
    for metric in monitoring_metrics:
        RunningAverage(alpha=alpha,
                       output_transform=lambda x: x[metric]).attach(
                           trainer, metric)

    # attach progress bar
    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=monitoring_metrics)

    @trainer.on(Events.ITERATION_COMPLETED)
    def print_logs(engine):
        if (engine.state.iteration - 1) % PRINT_FREQ == 0:
            fname = os.path.join(output_dir, LOGS_FNAME)
            columns = engine.state.metrics.keys()
            values = [
                str(round(value, 5))
                for value in engine.state.metrics.values()
            ]

            with open(fname, 'a') as f:
                if f.tell() == 0:
                    print('\t'.join(columns), file=f)
                print('\t'.join(values), file=f)

            message = '[{epoch}/{max_epoch}][{i}/{max_i}]'.format(
                epoch=engine.state.epoch,
                max_epoch=epochs,
                i=(engine.state.iteration % len(loader)),
                max_i=len(loader))
            for name, value in zip(columns, values):
                message += ' | {name}: {value}'.format(name=name, value=value)

            pbar.log_message(message)

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def save_fake_example(engine):
        fake = netG(fixed_noise)
        path = os.path.join(output_dir,
                            FAKE_IMG_FNAME.format(engine.state.epoch))
        vutils.save_image(fake.detach(), path, normalize=True)

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def save_real_example(engine):
        img, _ = engine.state.batch  # the labels are not needed here
        path = os.path.join(output_dir,
                            REAL_IMG_FNAME.format(engine.state.epoch))
        vutils.save_image(img, path, normalize=True)

    # adding handlers using `trainer.add_event_handler` method API
    trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                              handler=checkpoint_handler,
                              to_save={
                                  'netG': netG,
                                  'netD': netD
                              })

    # automatically adding handlers via a special `attach` method of `Timer` handler
    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        pbar.log_message('Epoch {} done. Time per batch: {:.3f}[s]'.format(
            engine.state.epoch, timer.value()))
        timer.reset()

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def create_plots(engine):
        try:
            import matplotlib as mpl
            mpl.use('agg')

            import numpy as np
            import pandas as pd
            import matplotlib.pyplot as plt

        except ImportError:
            warnings.warn(
                'Loss plots will not be generated -- pandas or matplotlib not found'
            )

        else:
            df = pd.read_csv(os.path.join(output_dir, LOGS_FNAME),
                             delimiter='\t')
            # index the rows by the iteration at which they were logged so the
            # subplots share a meaningful x axis (df.plot(x=...) expects a
            # column label, not an array)
            df.index = np.arange(1, engine.state.iteration + 1, PRINT_FREQ)
            _ = df.plot(subplots=True, figsize=(20, 20))
            _ = plt.xlabel('Iteration number')
            fig = plt.gcf()
            path = os.path.join(output_dir, PLOT_FNAME)

            fig.savefig(path)

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EXCEPTION_RAISED)
    def handle_exception(engine, e):
        if isinstance(e, KeyboardInterrupt) and (engine.state.iteration > 1):
            engine.terminate()
            warnings.warn('KeyboardInterrupt caught. Exiting gracefully.')

            create_plots(engine)
            checkpoint_handler(engine, {
                'netG_exception': netG,
                'netD_exception': netD
            })

        else:
            raise e

    # Setup is done. Now let's run the training
    trainer.run(loader, epochs)
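
The step function above is the classic two-phase GAN update: the discriminator is first pushed towards D(real) = 1 and D(G(z)) = 0 (with the fake batch detached so no gradients reach the generator), then the generator is pushed towards D(G(z)) = 1. A stripped-down sketch on toy 1-D data (illustrative linear networks, not the example's models):

import torch
import torch.nn as nn

netG = nn.Sequential(nn.Linear(2, 4), nn.ReLU(), nn.Linear(4, 1))
netD = nn.Sequential(nn.Linear(1, 4), nn.ReLU(), nn.Linear(4, 1), nn.Sigmoid())
bce = nn.BCELoss()
optD = torch.optim.Adam(netD.parameters(), lr=2e-4)
optG = torch.optim.Adam(netG.parameters(), lr=2e-4)
real = torch.randn(8, 1) + 3.0
real_labels, fake_labels = torch.ones(8, 1), torch.zeros(8, 1)

# (1) discriminator: maximize log D(x) + log(1 - D(G(z)))
netD.zero_grad()
errD_real = bce(netD(real), real_labels)
fake = netG(torch.randn(8, 2))
errD_fake = bce(netD(fake.detach()), fake_labels)  # detach: no grads into G
(errD_real + errD_fake).backward()
optD.step()

# (2) generator: maximize log D(G(z))
netG.zero_grad()
bce(netD(fake), real_labels).backward()
optG.step()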
Beispiel #17
0
class UnisstBaseExperiment(BaseExperiment):
    def __init__(
        self,
        gen, dis_img, dis_vid,
        corruption,
        train, test=None, val=None,
        optim_gen=None, optim_dis=None,
        sacred_run=None, writers=None, root=None, nepoch=10, niter=300,
        display_frequency=1, num_dis_step: int = 1, device='cuda:0', fid_fvd: bool = True,
        colorize=False, **kwargs):
        super().__init__(**kwargs)

        self.train = train
        self.test = test
        self.val = val
        self.sacred_run = sacred_run
        self.sacred_run.result = float('Inf')
        self.device = device
        self.nepoch = nepoch
        self.display_frequency = display_frequency
        self.colorize = colorize

        if isinstance(niter, str) and niter.find('epoch') > 0:
            nepoch = int(niter.split(' ')[0])
            niter = nepoch * len(train)
        self.niter = niter

        self.fid_fvd = fid_fvd
        # NOTE: `checkpoint` is assumed to arrive via **kwargs; the original code
        # referenced it without ever defining it when `root` was given.
        checkpoint = kwargs.get('checkpoint')
        if root is not None:
            self.basedir = os.path.join(root, str(sacred_run._id))
        else:
            writers = None
            checkpoint = None

        if writers is not None:
            self.writers = init_writers(*writers, sacred_run=sacred_run, dirname=self.basedir)
        else:
            self.writers = None

        if checkpoint is not None:
            self.checkpoint = init_checkpoint_handler(dirname=self.basedir, **checkpoint)

        self.trainer = Engine(self.train_step)
        self.trainer.add_event_handler(Events.ITERATION_COMPLETED, self.evaluate)
        self.trainer.add_event_handler(Events.ITERATION_COMPLETED, self.log)

        self.tester = Engine(self.test_step)
        self.evaluator = Engine(self.test_step)

        self.gen = gen.to(self.device)
        self.dis_img = dis_img.to(self.device)
        self.dis_vid = dis_vid.to(self.device)

        self.optim_gen = optim_gen
        self.optim_dis = optim_dis

        self.scheduler_gen = ExponentialLR(self.optim_gen, gamma=0.99)
        self.scheduler_dis = ExponentialLR(self.optim_dis, gamma=0.99)

        self.corruption = corruption

        self.num_dis_step = num_dis_step

        RunningAverage(alpha=0.9, output_transform=lambda x: x['loss_gen'].item()).attach(self.trainer, 'loss_gen')
        RunningAverage(alpha=0.9, output_transform=lambda x: x['loss_dis_img'].item()).attach(self.trainer, 'loss_dis_img')
        RunningAverage(alpha=0.9, output_transform=lambda x: x['loss_dis_vid'].item()).attach(self.trainer, 'loss_dis_vid')

        if self.fid_fvd:
            self.dims = 2048
            block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[self.dims]
            self.model = InceptionV3([block_idx]).to(self.device)
            self.fid_score = float('inf')

    def train_step(self, engine, batch):
        self.training()
        batch = convert_tensor(batch, self.device)
        loss_gen, loss_dis_img, loss_dis_vid, output = self.forward_backward(**batch)
        metric = self.metric(**output, **batch)

        loss = {
            'loss_gen': loss_gen,
            'loss_dis_img': loss_dis_img,
            'loss_dis_vid': loss_dis_vid,
        }
        return {
            **batch,
            **output,
            **loss,
            **metric,
        }

    def test_step(self, engine, batch):
        self.evaluating()
        with torch.no_grad():
            batch = convert_tensor(batch, self.device)
            _, _, _, output = self.forward_backward(backward=False, **batch)
            metric = self.metric(**output, **batch)
            
            return {
                **batch,
                **output,
                **metric,
            }

    def evaluate(self, engine):
        iteration = engine.state.iteration
        if iteration % self.niter == 0:
            self.step(engine, iteration)
            
            if self.val is not None:
                self.evaluator.run(self.val, max_epochs=1)
                self.step(self.evaluator, iteration, dataset_name='val')
                columns = self.evaluator.state.metrics.keys()
                values = [value for value in self.evaluator.state.metrics.values()]
                message = 'Val: '
                for name, value in zip(columns, values):
                    message += ' | {name}: {value:.4f}, {std:.4f}'.format(name=name, value=statistics.mean(value), std=statistics.stdev(value))
                print(message)

            if self.test is not None:
                self.tester.run(self.test, max_epochs=1)
                self.step(self.tester, iteration, dataset_name='test')
                columns = self.tester.state.metrics.keys()
                values = [value for value in self.tester.state.metrics.values()]
                message = 'Test: '
                for name, value in zip(columns, values):
                    message += ' | {name}: {value:.4f}, {std:.4f}'.format(name=name, value=statistics.mean(value), std=statistics.stdev(value))
                print(message)

    def log(self, engine):
        iteration = engine.state.iteration
        if iteration % self.display_frequency == 0:
            columns = engine.state.metrics.keys()
            values = [value for value in engine.state.metrics.values()]
            message = '[{epoch}/{max_epoch}][{i}/{max_i}]'.format(epoch=engine.state.epoch,
                                                                  max_epoch=self.nepoch,
                                                                  i=(engine.state.iteration % len(self.train)),
                                                                  max_i=len(self.train))
            for name, value in zip(columns, values):
                message += ' | {name}: {value:.4f}'.format(name=name, value=value)
            print(message)

    def write(self, engine, dataset_name):
        iteration = self.trainer.state.iteration

        # Logging Images
        o = engine.state.output
        b = engine.state.batch

        img_tensor, nrow = self.get_tensor(o, b)

        img = make_grid(
            img_tensor, nrow=nrow, normalize=True, range=(-1, 1)
        )

        try:
            self.writers.add_image(dataset_name, img, iteration)
        except Exception as exc:
            print(f'Could not save image: {exc}')

    def forward(self, **kwargs):
        raise NotImplementedError

    def backward_gen(self, **kwargs):
        raise NotImplementedError

    def backward_dis(self, **kwargs):
        raise NotImplementedError

    def get_tensor(self, o, b, limit=2):

        batch_size, nc, seq_len, height, width = b['x'].shape
        if batch_size < limit:
            limit = batch_size

        x       = b['x'][:limit].cpu().permute(0, 2, 1, 3, 4).contiguous().view(-1, nc, height, width)
        y       = o['y'][:limit].cpu().permute(0, 2, 1, 3, 4).contiguous().view(-1, nc, height, width)
        x_hat   = o['x_hat'][:limit].cpu().permute(0, 2, 1, 3, 4).contiguous().view(-1, nc, height, width)

        if self.colorize:
            mask  = b['mask'][:limit].cpu().permute(0, 2, 1, 3, 4).contiguous().view(-1, nc, height, width)
            
            x = colorize(x)
            y = colorize(y)
            x_hat = colorize(x_hat)

            y[mask.expand_as(y)] = 1

        list_tensor = [x, y, x_hat]
        nrow = seq_len
        return torch.cat(list_tensor), nrow

    def step(self, engine, iteration, dataset_name='train'):
        values = dict(engine.state.metrics)

        if dataset_name == 'train':
            self.scheduler_gen.step()
            self.scheduler_dis.step()

            if values['loss_dis_img'] + values['loss_dis_vid'] < 0.001:
                raise CustomInterrupt('DIS_TOO_SMALL')
            if values['loss_gen'] < 0.001:
                raise CustomInterrupt('GEN_TOO_SMALL')
            if values['recon_mae'] > 1.9:
                raise CustomInterrupt('RECON_TOO_HIGH')

        metrics = engine.state.metrics
        if self.writers is not None:
            for name, value in metrics.items():
                metric_name = dataset_name + '/' + name
                if dataset_name in ['test', 'val']:
                    m = statistics.mean(value)
                    s = statistics.stdev(value)
                    self.writers.add_scalar(metric_name, m, iteration)
                    self.writers.add_scalar(metric_name + '_std', s, iteration)
                else:
                    self.writers.add_scalar(metric_name, value, iteration)

        print(f"saving {iteration}")
        self.write(engine, dataset_name)
        if iteration % (2 * self.niter) == 0 and self.fid_fvd:
            self.compute_fid(iteration, dataset_name)
            self.compute_fvd(iteration, dataset_name)


    def compute_fvd(self, iteration, dataset_name):
        self.evaluating()
        fake_list, real_list = [], []

        if dataset_name == 'train':
            dataset = self.train
        elif dataset_name == 'val':
            dataset = self.val
        else:
            dataset = self.test

        with torch.no_grad():
            for i, batch in enumerate(dataset):
                batch = convert_tensor(batch, self.device)

                output = self.forward_backward(**batch, backward=False)[-1]
                real_seq_len = batch['seq_len']
                batch_size, nc, _, _, _ = batch['x'].shape
                x_hat        = output['x_hat']
                x            = batch['x']
                # B x C x T x H x W; tile grayscale inputs to 3 channels for FVD
                if nc != 3:
                    fake = x_hat.repeat(1, 3, 1, 1, 1)
                    true = x.repeat(1, 3, 1, 1, 1)
                else:
                    fake, true = x_hat, x

                fake_list.append(fake.cpu())
                real_list.append(true.cpu())
                if i == 15:
                    break

        fake_vid = torch.cat(fake_list, dim=0)
        real_vid = torch.cat(real_list, dim=0)
        
        fvd_score = fvd(real_vid, fake_vid)

        print(f"FVD_{dataset_name} : {fvd_score}")
        if self.writers is not None:
            self.writers.add_scalar(f'FVD_{dataset_name}', fvd_score, iteration)

    def compute_fid(self, iteration, dataset_name):
        self.evaluating()
        fake_list, real_list = [], []

        if dataset_name == 'train':
            dataset = self.train
        elif dataset_name == 'val':
            dataset = self.val
        else:
            dataset = self.test

        with torch.no_grad():
            for i, batch in enumerate(dataset):
                batch = convert_tensor(batch, self.device)

                output = self.forward_backward(**batch, backward=False)[-1]
                real_seq_len = batch['seq_len']
                batch_size, nc, _, _, _ = batch['x'].shape
                x_hat        = output['x_hat']
                x            = batch['x']
                
                fake = []
                true = []
                for bi in range(batch_size):
                    fake.append(torch.narrow(x_hat[bi], 1, 0, real_seq_len[bi]).permute(1, 0, 2, 3))
                    true.append(torch.narrow(    x[bi], 1, 0, real_seq_len[bi]).permute(1, 0, 2, 3))

                fake = torch.cat(fake, dim=0)
                true = torch.cat(true, dim=0)

                if nc != 3:
                    fake = fake.repeat(1, 3, 1, 1)
                    true = true.repeat(1, 3, 1, 1)

                fake_list.append((fake.cpu().numpy() + 1.0) / 2.0)
                real_list.append((true.cpu().numpy() + 1.0) / 2.0)

                if i == 15:
                    break

        fake_images = np.concatenate(fake_list)
        real_images = np.concatenate(real_list)
        mu_fake, sigma_fake = metrics.calculate_activation_statistics(
            fake_images, self.model, self.train.batch_size, device=self.device
        )
        mu_real, sigma_real = metrics.calculate_activation_statistics(
            real_images, self.model, self.train.batch_size, device=self.device
        )
        fid_score = metrics.calculate_frechet_distance(
            mu_fake, sigma_fake, mu_real, sigma_real
        )
        print(f"FID_{dataset_name} : {fid_score}")
        if self.writers is not None:
            self.writers.add_scalar(f'FID_{dataset_name}', fid_score, iteration)

    def run(self):
        self.trainer.run(self.train, max_epochs=self.nepoch)
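
compute_fid above delegates to metrics.calculate_frechet_distance. Assuming the usual FID definition, the score is the Fréchet distance between the two Gaussians fitted to the activation statistics, ||mu1 - mu2||^2 + Tr(S1 + S2 - 2 (S1 S2)^(1/2)). A minimal numpy sketch of that formula (a stand-in, not the library's implementation):

import numpy as np
from scipy import linalg

def frechet_distance(mu1, sigma1, mu2, sigma2):
    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if np.iscomplexobj(covmean):
        covmean = covmean.real  # drop tiny imaginary parts from numerical error
    return diff.dot(diff) + np.trace(sigma1 + sigma2 - 2.0 * covmean)

acts1 = np.random.randn(64, 8)          # activations of real images
acts2 = np.random.randn(64, 8) + 0.5    # activations of generated images
fd = frechet_distance(acts1.mean(0), np.cov(acts1, rowvar=False),
                      acts2.mean(0), np.cov(acts2, rowvar=False))
print(fd)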
Beispiel #18
0
    @engine.on(PeriodEvents.ITERS_10000_COMPLETED)
    def test_network(engine: Engine):
        dqn_model.train(False)
        test_reward: float
        test_steps: float
        test_reward, test_steps, test_deers = test_model(
            dqn_model, device, configuration)
        dqn_model.train(True)

        engine.state.metrics[TEST_REWARD_METRIC] = test_reward
        engine.state.metrics[TEST_STEPS_METRTIC] = test_steps
        engine.state.metrics[TEST_DEERS_METRIC] = test_deers

        print(
            "Test done: got %.3f reward after %.2f steps. Deer survival %.3f "
            % (test_reward, test_steps, test_deers))

        global best_test_reward
        if best_test_reward is None:
            best_test_reward = test_reward
        elif best_test_reward < test_reward:
            print("Best test reward updated %.3f <- %.3f, save model" %
                  (best_test_reward, test_reward))
            best_test_reward = test_reward
            torch.save(dqn_model.state_dict(),
                       os.path.join(saves_path, "best_%.3f.dat" % test_reward))

    engine.run(
        batch_generator(replay_buffer, PARAMETERS.replay_initial,
                        PARAMETERS.batch_size))
Beispiel #19
0
    best_test_reward = None

    @engine.on(ptan_ignite.PeriodEvents.ITERS_1000_COMPLETED)
    def test_network(engine):
        net.train(False)
        a_reward, a_steps, b_reward, b_steps = test_model(net, device, config)
        net.train(True)
        engine.state.metrics['test_reward_a'] = a_reward
        engine.state.metrics['test_steps_a'] = a_steps
        engine.state.metrics['test_reward_b'] = b_reward
        engine.state.metrics['test_steps_b'] = b_steps
        print(
            "Test done: A got %.3f reward after %.2f steps, B %.3f reward after %.2f steps"
            % (a_reward, a_steps, b_reward, b_steps))

        global best_test_reward
        reward = max(a_reward, b_reward)

        if best_test_reward is None:
            best_test_reward = reward
        elif best_test_reward < reward:
            print("Best test reward updated %.3f <- %.3f, save model" %
                  (best_test_reward, reward))
            best_test_reward = reward
            torch.save(net.state_dict(),
                       os.path.join(saves_path, "best_%.3f.dat" % reward))

    engine.run(
        batch_generator(a_exp_source, b_exp_source, buffer,
                        PARAMS.replay_initial, PARAMS.batch_size))
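
PeriodEvents.ITERS_*_COMPLETED in the two snippets above comes from ptan's ignite integration; plain ignite can express the same "run every N iterations" hook with a filtered event. A minimal sketch (assuming ignite >= 0.4 filtered events):

from ignite.engine import Engine, Events

engine = Engine(lambda e, batch: None)

@engine.on(Events.ITERATION_COMPLETED(every=1000))
def periodic_test(engine):
    print("testing at iteration", engine.state.iteration)

engine.run([0] * 3000, max_epochs=1)  # fires at iterations 1000, 2000, 3000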
Beispiel #20
0
class Evaluator:
    """
    Class which sets up the evaluation logic, which mainly involves defining callback handlers and attaching
    them to the evaluation loop.
    """
    def __init__(self, model, config, data_loaders, tb_writer, run_info,
                 logger, checkpoint_dir):
        """
        Creates a new evaluator object for evaluating a model.
        :param model: model to train. Needs to inherit from the BaseModel class.
        :param config: dictionary containing the whole configuration of the experiment
        :param data_loaders: (dictionary) the keys represent the name and each value contains
         a pytorch data loader providing the validation data
        :param tb_writer: tensorboardX summary writer
        :param run_info: sacred run info for logging training progress
        :param logger: python logger object
        :param checkpoint_dir: directory path for storing checkpoints
        """
        self.run_info = run_info
        self.logger = logger
        self.data_loaders = data_loaders
        self.config = config
        self.engine = Engine(self._step)
        self.model = model
        self.tb_writer = tb_writer
        self.trainer = None

        # Using custom metric wrapper which retrieves metrics from dictionary instead of separately calculating them.
        self.metrics = {k: LossFromDict(k) for k in self.model.metric_names}
        self.non_scalar_metrics = {
            k: LossFromDict(k, reduce=False)
            for k in self.model.non_scalar_metrics_names
        }

        if 'external_metrics' in config['val_data']:
            for idx, name in enumerate(config['val_data']['external_metrics']):
                if 'external_metrics_kw_args' in config['val_data']:
                    self.metrics[name] = get_subclass(name, Metric)(
                        config['devices'][0],
                        **config['val_data']['external_metrics_kw_args'][idx])
                else:
                    self.metrics[name] = get_subclass(name, Metric)()

        self._handle_save_best_checkpoint_handler = \
            ModelCheckpoint(checkpoint_dir, 'best',
                            score_function=lambda engine: -self.model.main_metric(engine.state.metrics),
                            score_name=self.model.name_main_metric,
                            n_saved=1,
                            require_empty=False)

        self.add_handler()
        self.best_loss = None
        self.current_data_loader = None
        self.main_data_loader = config['val_data']['main_dataset']

    def run(self):
        """
        Start the evaluation run which will run through one epoch for each validation dataset
        :return:
        """
        for name, data_loader in self.data_loaders.items():
            self.current_data_loader = name
            self.engine.run(data_loader)

    def set_trainer(self, trainer):
        """
        Setter method for setting the trainer object which is mainly needed for getting information on the current
        training iteration.
        :param trainer: Trainer object
        :return:
        """
        self.trainer = trainer

    def add_handler(self):
        """
        Adds all the callback handlers to the evaluator engine. Should be called at the end of ``__init__``.
        :return:
        """
        for name, metric in self.metrics.items():
            metric.attach(self.engine, name)

        for name, non_scalar_metric in self.non_scalar_metrics.items():
            non_scalar_metric.attach(self.engine, name)

        # on epoch complete
        self.engine.add_event_handler(
            Events.EPOCH_COMPLETED, self._handle_save_best_checkpoint_handler,
            self.model.networks)
        self.engine.add_event_handler(Events.EPOCH_COMPLETED,
                                      self._handle_log_validation_results)

        # on iteration complete
        self.engine.add_event_handler(Events.ITERATION_COMPLETED,
                                      self._handle_log_val_images)

    def _step(self, engine, batch):
        """
        Definition of a single evaluation step. This function gets automatically called by the engine every iteration.
        :param engine: evaluator engine
        :param batch: one batch provided by the data loader
        :return:
        """
        self.model.eval()
        self.model.set_input(batch)
        self.model.test()
        return self.model.state

    def _handle_log_validation_results(self, engine):
        """
        Handler for writing the losses to tensorboard and sacred.
        :param engine: evaluation engine
        :return:
        """
        metrics = self.engine.state.metrics

        loss = self.model.main_metric(metrics)
        metrics[self.model.name_main_metric] = loss

        for name, m in metrics.items():
            if 'non_scalar_metric_' not in name:  # Only add scalars
                # log to sacred
                self.run_info.log_scalar(
                    f"val_{self.current_data_loader}.{name}.", m,
                    self.trainer.engine.state.iteration)
                self.tb_writer.add_scalar(
                    f"val_{self.current_data_loader}/{name}.", m,
                    self.trainer.engine.state.iteration)

        self.logger.info(
            "Validation Results for {} - Epoch: {}  Avg loss: {:.6f}".format(
                self.current_data_loader, self.trainer.engine.state.epoch,
                loss))
        if self.current_data_loader == self.main_data_loader and \
                (self.best_loss is None or loss < self.best_loss):
            self.best_loss = loss
        self.run_info.result = self.best_loss

        self._handle_complete_val_dataset_figure(engine)

    def _handle_log_val_images(self, engine):
        """
        Handler for writing visual samples to tensorboard.
        :param engine: evaluation engine
        :return:
        """
        if engine.state.iteration == 1:
            for name, visual in self.model.visuals.items():
                self.tb_writer.add_image(
                    f"val_{self.current_data_loader}/{name}.",
                    visual.transpose(2, 0, 1),
                    self.trainer.engine.state.iteration)

    def _score_function(self, engine):
        """
        Helper method used in ModelCheckpoint to save the best model. The sign is flipped because
        ModelCheckpoint keeps the checkpoints with the highest scores.
        :param engine: evaluation engine
        :return:
        """
        val_loss = engine.state.metrics[self.model.name_main_metric]
        return -val_loss

    def _handle_complete_val_dataset_figure(self, engine):
        """
        Adds complete validation dataset metric figure to tensorboard.
        :param engine: evaluation engine
        :return:
        """
        figures = self.model.get_validation_figures(engine.state)
        for name, figure in figures.items():
            self.tb_writer.add_figure(
                f"val_{self.current_data_loader}_metrics/{name}", figure,
                self.trainer.engine.state.iteration)
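
ModelCheckpoint keeps the checkpoints with the highest scores, which is why the Evaluator above negates its main metric in score_function. A condensed sketch of the pattern (directory, prefix and metric name are illustrative):

import torch.nn as nn
from ignite.engine import Engine, Events
from ignite.handlers import ModelCheckpoint

model = nn.Linear(2, 2)
evaluator = Engine(lambda e, b: None)

@evaluator.on(Events.EPOCH_COMPLETED)
def fake_metrics(engine):
    engine.state.metrics['val_loss'] = 0.5  # stand-in for a real metric

handler = ModelCheckpoint('/tmp/ckpts', 'best', n_saved=1, require_empty=False,
                          score_function=lambda e: -e.state.metrics['val_loss'],
                          score_name='val_loss')
evaluator.add_event_handler(Events.EPOCH_COMPLETED, handler, {'model': model})
evaluator.run([0], max_epochs=1)  # saves because -0.5 beats the (empty) history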
Beispiel #21
0
    def _run(self, tempdir):
        my_rank = dist.get_rank()
        fnames = ["aaa" * 300, "bbb" * 301, "ccc" * 302]

        metrics_saver = MetricsSaver(
            save_dir=tempdir,
            metrics=["metric1", "metric2"],
            metric_details=["metric3", "metric4"],
            batch_transform=lambda x: x[PostFix.meta("image")],
            summary_ops="*",
            delimiter="\t",
        )

        def _val_func(engine, batch):
            pass

        engine = Engine(_val_func)

        if my_rank == 0:
            data = [{PostFix.meta("image"): {"filename_or_obj": [fnames[0]]}}]

            @engine.on(Events.EPOCH_COMPLETED)
            def _save_metrics0(engine):
                engine.state.metrics = {"metric1": 1, "metric2": 2}
                engine.state.metric_details = {
                    "metric3": torch.tensor([[1, 2]]),
                    "metric4": torch.tensor([[5, 6]])
                }

        if my_rank == 1:
            # different ranks have different data lengths
            data = [
                {
                    PostFix.meta("image"): {
                        "filename_or_obj": [fnames[1]]
                    }
                },
                {
                    PostFix.meta("image"): {
                        "filename_or_obj": [fnames[2]]
                    }
                },
            ]

            @engine.on(Events.EPOCH_COMPLETED)
            def _save_metrics1(engine):
                engine.state.metrics = {"metric1": 1, "metric2": 2}
                engine.state.metric_details = {
                    "metric3": torch.tensor([[2, 3], [3, 4]]),
                    "metric4": torch.tensor([[6, 7], [7, 8]]),
                }

        @engine.on(Events.EPOCH_COMPLETED)
        def _all_gather(engine):
            scores = engine.state.metric_details["metric3"]
            engine.state.metric_details[
                "metric3"] = evenly_divisible_all_gather(data=scores,
                                                         concat=True)
            scores = engine.state.metric_details["metric4"]
            engine.state.metric_details[
                "metric4"] = evenly_divisible_all_gather(data=scores,
                                                         concat=True)

        metrics_saver.attach(engine)
        engine.run(data, max_epochs=1)

        if my_rank == 0:
            # check the metrics.csv and content
            self.assertTrue(
                os.path.exists(os.path.join(tempdir, "metrics.csv")))
            with open(os.path.join(tempdir, "metrics.csv")) as f:
                f_csv = csv.reader(f)
                for i, row in enumerate(f_csv):
                    self.assertEqual(row, [f"metric{i + 1}\t{i + 1}"])
            self.assertTrue(
                os.path.exists(os.path.join(tempdir, "metric3_raw.csv")))
            # check the metric_raw.csv and content
            with open(os.path.join(tempdir, "metric3_raw.csv")) as f:
                f_csv = csv.reader(f)
                for i, row in enumerate(f_csv):
                    if i > 0:
                        expected = [
                            f"{fnames[i-1]}\t{float(i):.4f}\t{float(i + 1):.4f}\t{i + 0.5:.4f}"
                        ]
                        self.assertEqual(row, expected)
            self.assertTrue(
                os.path.exists(os.path.join(tempdir, "metric3_summary.csv")))
            # check the metric_summary.csv and content
            with open(os.path.join(tempdir, "metric3_summary.csv")) as f:
                f_csv = csv.reader(f)
                for i, row in enumerate(f_csv):
                    if i == 1:
                        self.assertEqual(row, [
                            "class0\t2.0000\t2.0000\t3.0000\t1.0000\t2.8000\t0.8165\t3.0000"
                        ])
                    elif i == 2:
                        self.assertEqual(row, [
                            "class1\t3.0000\t3.0000\t4.0000\t2.0000\t3.8000\t0.8165\t3.0000"
                        ])
                    elif i == 3:
                        self.assertEqual(row, [
                            "mean\t2.5000\t2.5000\t3.5000\t1.5000\t3.3000\t0.8165\t3.0000"
                        ])
            self.assertTrue(
                os.path.exists(os.path.join(tempdir, "metric4_raw.csv")))
            self.assertTrue(
                os.path.exists(os.path.join(tempdir, "metric4_summary.csv")))
        dist.barrier()
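
evenly_divisible_all_gather above (a MONAI utility) works around torch.distributed.all_gather requiring identical shapes on every rank: each rank pads its tensor along the first dimension to the longest length, the padded tensors are gathered, and the padding is stripped again. A single-process illustration of the padding idea (no actual process group involved):

import torch

rank_outputs = [torch.tensor([[1, 2]]), torch.tensor([[2, 3], [3, 4]])]  # unequal lengths
lengths = [t.shape[0] for t in rank_outputs]
max_len = max(lengths)
padded = [torch.cat([t, t.new_zeros(max_len - t.shape[0], t.shape[1])]) for t in rank_outputs]
# after all_gather every rank would hold `padded`; trim each chunk back to its true length
gathered = torch.cat([p[:n] for p, n in zip(padded, lengths)])
print(gathered)  # tensor([[1, 2], [2, 3], [3, 4]])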
Beispiel #22
0
class Trainer:
    """
    Class which sets up the training logic, which mainly involves defining callback handlers and attaching
    them to the training loop.
    """
    def __init__(self, model, config, evaluator, data_loader, tb_writer,
                 run_info, logger, checkpoint_dir):
        """
        Creates a new trainer object for training a model.
        :param model: model to train. Needs to inherit from the BaseModel class.
        :param config: dictionary containing the whole configuration of the experiment
        :param evaluator: Instance of the evaluator class, used to run evaluation on a specified schedule
        :param data_loader: pytorch data loader providing the training data
        :param tb_writer: tensorboardX summary writer
        :param run_info: sacred run info for logging training progress
        :param logger: python logger object
        :param checkpoint_dir: directory path for storing checkpoints
        """
        self.run_info = run_info
        self.logger = logger
        self.data_loader = data_loader
        self.evaluator = evaluator
        self.engine = Engine(self._step)
        self.model = model
        self.config = config
        self.train_cfg = config['train']
        self.tb_writer = tb_writer

        self.pbar = ProgressBar(ascii=True, desc='* Epoch')
        self.timer = Timer(average=True)
        self.save_last_checkpoint_handler = ModelCheckpoint(
            checkpoint_dir,
            'last',
            save_interval=self.train_cfg['save_interval'],
            n_saved=self.train_cfg['save_n_last'],
            require_empty=False)

        self.add_handler()

    def run(self):
        """
        Start the training loop which will run until all epochs are complete
        :return:
        """
        self.engine.run(self.data_loader,
                        max_epochs=self.train_cfg['n_epochs'])

    def add_handler(self):
        """
        Adds all the callback handlers to the trainer engine. Should be called at the end of ``__init__``.
        :return:
        """
        # Learning rate decay
        for lr_s in self.model.schedulers:
            self.engine.add_event_handler(Events.ITERATION_STARTED, lr_s)

        # Checkpoint saving
        self.engine.add_event_handler(Events.EPOCH_STARTED,
                                      self.save_last_checkpoint_handler,
                                      self.model.networks)

        # Progbar
        monitoring_metrics = self.model.metric_names
        for mm in monitoring_metrics:
            RunningAverage(output_transform=self._extract_loss(mm)).attach(
                self.engine, mm)
        self.pbar.attach(self.engine, metric_names=monitoring_metrics)

        # Timer
        self.timer.attach(self.engine,
                          start=Events.EPOCH_STARTED,
                          resume=Events.ITERATION_STARTED,
                          pause=Events.ITERATION_COMPLETED,
                          step=Events.ITERATION_COMPLETED)

        # Logging
        self.engine.add_event_handler(Events.ITERATION_COMPLETED,
                                      self._handle_log_train_results)
        self.engine.add_event_handler(Events.ITERATION_COMPLETED,
                                      self._handle_log_train_images)
        self.engine.add_event_handler(Events.ITERATION_COMPLETED,
                                      self._handle_run_evaluation)
        self.engine.add_event_handler(Events.EPOCH_COMPLETED,
                                      self._handle_print_times)

        # Exception handling
        self.engine.add_event_handler(Events.EXCEPTION_RAISED,
                                      self._handle_exception)

    def _step(self, engine, batch):
        """
        Definition of a single training step. This function gets automatically called by the engine every iteration.
        :param engine: trainer engine
        :param batch: one batch provided by the dataloader
        :return:
        """
        self.model.train()
        self.model.set_input(batch)
        self.model.optimize_parameters()
        return self.model.state

    def _handle_log_train_results(self, engine):
        """
        Handler for writing the losses to tensorboard and sacred.
        :param engine: train engine
        :return:
        """
        if (engine.state.iteration - 1) % self.train_cfg['log_interval'] == 0:
            metrics = engine.state.metrics  # does not include non-scalar metrics, since the loggers cannot handle them

            for m_name, m_val in metrics.items():
                if m_val is None:
                    raise ValueError(f'Value for {m_name} is None')
                self.run_info.log_scalar("train.%s" % m_name, m_val,
                                         engine.state.iteration)
                self.tb_writer.add_scalar("train/%s" % m_name, m_val,
                                          engine.state.iteration)

            for lr_name, lr_val in self.model.learning_rates.items():
                if lr_val is None:
                    raise ValueError(f'Value for {lr_name} is None')
                self.run_info.log_scalar("train.%s" % lr_name, lr_val,
                                         engine.state.iteration)
                self.tb_writer.add_scalar("train/%s" % lr_name, lr_val,
                                          engine.state.iteration)

    def _handle_log_train_images(self, engine):
        """
        Handler for writing visual samples to tensorboard.
        :param engine: train engine
        :return:
        """
        if (engine.state.iteration -
                1) % self.train_cfg['img_log_interval'] == 0:
            for name, visual in self.model.visuals.items():
                # TODO remove the visual.transpose here and put it in the visualization function of the models
                self.tb_writer.add_image('train/%s' % name,
                                         visual.transpose(2, 0, 1),
                                         engine.state.iteration)
            for name, figure in self.model.figures.items():
                self.tb_writer.add_figure('train_metrics/%s' % name, figure,
                                          engine.state.iteration)

    def _handle_run_evaluation(self, engine):
        """
        Handler which will execute evaluation by running the evaluator object.
        :param engine: train engine
        :return:
        """
        if (engine.state.iteration - 1) % self.train_cfg['eval_interval'] == 0:
            self.evaluator.run()

    def _handle_exception(self, engine, e):
        """
        Exception handler which ensures that the model gets saved when stopped through a keyboard interruption.
        :param engine: train engine
        :param e: the exception which caused the training to stop
        :return:
        """
        if isinstance(e, KeyboardInterrupt) and (engine.state.iteration > 1):
            engine.terminate()
            self.logger.warning(
                'KeyboardInterrupt caught. Exiting gracefully.')
            self.save_last_checkpoint_handler(engine, self.model.networks)
        else:
            raise e

    def _handle_print_times(self, engine):
        """
        Handler for logging timer information for different training and evaluation steps.
        :param engine: train engine
        :return:
        """
        self.logger.info('Epoch {} done. Time per batch: {:.3f}[s]'.format(
            engine.state.epoch, self.timer.value()))
        self.timer.reset()

    @staticmethod
    def _extract_loss(key):
        """
        Helper method to return losses for the RunningAverage
        :param key: (str) loss name
        :return: (fn) for the corresponding key
        """
        def _func(losses):
            return losses[key]

        return _func
Beispiel #23
0
def test_linear_scheduler():

    with pytest.raises(TypeError, match=r"Argument optimizer should be torch.optim.Optimizer"):
        LinearCyclicalScheduler({}, "lr", 1, 0, cycle_size=0)

    tensor = torch.zeros([1], requires_grad=True)
    optimizer = torch.optim.SGD([tensor], lr=0.0)

    with pytest.raises(ValueError, match=r"Argument cycle_size should be positive and larger than 1"):
        LinearCyclicalScheduler(optimizer, "lr", 1, 0, cycle_size=0)

    with pytest.raises(ValueError, match=r"Argument cycle_size should be positive and larger than 1"):
        LinearCyclicalScheduler(optimizer, "lr", 1, 0, cycle_size=1)

    scheduler = LinearCyclicalScheduler(optimizer, "lr", 1, 0, 10)
    state_dict = scheduler.state_dict()

    def save_lr(engine):
        lrs.append(optimizer.param_groups[0]["lr"])

    trainer = Engine(lambda engine, batch: None)
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
    trainer.add_event_handler(Events.ITERATION_COMPLETED, save_lr)

    for _ in range(2):
        lrs = []
        trainer.run([0] * 9, max_epochs=2)

        assert lrs == list(
            map(
                pytest.approx,
                [
                    # Cycle 1
                    1.0,
                    0.8,
                    0.6,
                    0.4,
                    0.2,
                    0.0,
                    0.2,
                    0.4,
                    0.6,
                    0.8,
                    # Cycle 2
                    1.0,
                    0.8,
                    0.6,
                    0.4,
                    0.2,
                    0.0,
                    0.2,
                    0.4,  # 0.6, 0.8,
                ],
            )
        )
        scheduler.load_state_dict(state_dict)

    optimizer = torch.optim.SGD([tensor], lr=0)
    scheduler = LinearCyclicalScheduler(optimizer, "lr", 1, 0, 10, cycle_mult=2)
    state_dict = scheduler.state_dict()

    trainer = Engine(lambda engine, batch: None)
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
    trainer.add_event_handler(Events.ITERATION_COMPLETED, save_lr)

    for _ in range(2):
        lrs = []
        trainer.run([0] * 10, max_epochs=3)

        assert lrs == list(
            map(
                pytest.approx,
                [
                    # Cycle 1
                    1.0,
                    0.8,
                    0.6,
                    0.4,
                    0.2,
                    0.0,
                    0.2,
                    0.4,
                    0.6,
                    0.8,
                    # Cycle 2
                    1.0,
                    0.9,
                    0.8,
                    0.7,
                    0.6,
                    0.5,
                    0.4,
                    0.3,
                    0.2,
                    0.1,
                    0.0,
                    0.1,
                    0.2,
                    0.3,
                    0.4,
                    0.5,
                    0.6,
                    0.7,
                    0.8,
                    0.9,
                ],
            )
        )
        scheduler.load_state_dict(state_dict)

    # With float cycle_size
    optimizer = torch.optim.SGD([tensor], lr=0)
    scheduler = LinearCyclicalScheduler(
        optimizer, "lr", start_value=1.2, end_value=0.2, cycle_size=10.00000012, cycle_mult=1.0
    )
    state_dict = scheduler.state_dict()

    trainer = Engine(lambda engine, batch: None)
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
    trainer.add_event_handler(Events.ITERATION_COMPLETED, save_lr)

    for _ in range(2):
        lrs = []
        trainer.run([0] * 9, max_epochs=2)
        assert lrs == list(
            map(
                pytest.approx,
                [
                    # Cycle 1
                    1.2,
                    1.0,
                    0.8,
                    0.6,
                    0.4,
                    0.2,
                    0.4,
                    0.6,
                    0.8,
                    1.0,
                    # Cycle 2
                    1.2,
                    1.0,
                    0.8,
                    0.6,
                    0.4,
                    0.2,
                    0.4,
                    0.6,  # 0.8, 1.0,
                ],
            )
        )
        scheduler.load_state_dict(state_dict)
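
The expected lrs in the test above trace LinearCyclicalScheduler's triangle wave: with cycle progress p = t / cycle_size, the parameter value is start + (end - start) * (1 - |2p - 1|). A quick standalone check of the first cycle:

def cyclical_value(t, start, end, cycle_size):
    p = (t % cycle_size) / cycle_size
    return start + (end - start) * (1 - abs(2 * p - 1))

print([round(cyclical_value(t, 1.0, 0.0, 10), 1) for t in range(10)])
# [1.0, 0.8, 0.6, 0.4, 0.2, 0.0, 0.2, 0.4, 0.6, 0.8]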
Beispiel #24
0
def adv_train_loop(model,
                   params,
                   ds,
                   min_y,
                   base_data,
                   model_id,
                   attack_type,
                   device,
                   batch_size,
                   max_epochs=5):
    print('training adversarial:', attack_type)
    ds_train, ds_valid = ds
    min_y_train, min_y_val = min_y
    original_model = copy.deepcopy(
        model)  # used to generate adv images for the trained model
    original_model.eval()
    model = copy.deepcopy(
        model)  # making a copy so that original model is not changed
    model = model.to(device)
    model_id = f'{model_id}_{attack_type}'

    with create_summary_writer(model,
                               ds_train,
                               base_data,
                               model_id,
                               device=device) as writer:
        lr = params['lr']
        mom = params['momentum']
        wd = params['l2_wd']
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=lr,
                                    momentum=mom,
                                    weight_decay=wd)
        sched = ReduceLROnPlateau(optimizer, factor=0.5, patience=5)
        # Use the raw loss function directly instead of reaching into the
        # Loss metric's private _loss_fn attribute
        loss = F.cross_entropy

        acc_metric = Accuracy(device=device)
        loss_metric = Loss(F.cross_entropy, device=device)

        acc_val_metric = Accuracy(device=device)
        loss_val_metric = Loss(F.cross_entropy, device=device)

        classifier = PyTorchClassifier(
            model=original_model,
            clip_values=(0, 1),
            loss=nn.CrossEntropyLoss(),
            optimizer=optimizer,
            input_shape=(3, 64, 64),
            nb_classes=200,
        )

        attack = None

        #         if attack_type == "fgsm":
        #             attack = FastGradientMethod(estimator=classifier, eps=0.2)
        #         elif attack_type == "bim":
        #             attack = BasicIterativeMethod(estimator=classifier, eps=0.2)
        #         elif attack_type == "carlini":
        #             attack = CarliniLInfMethod(classifier=classifier)
        #         elif attack_type == "deepfool":
        #             attack = DeepFool(classifier=classifier)
        if attack_type == "fgsm":
            attack = GradientSignAttack(model, loss_fn=loss, eps=0.2)
        elif attack_type == "ffa":
            attack = FastFeatureAttack(model, loss_fn=loss, eps=0.3)
        elif attack_type == "carlini":
            attack = CarliniWagnerL2Attack(model, 200, max_iterations=1000)
        elif attack_type == "lbfgs":
            attack = DeepFool(classifier=classifier)

        def train_step(engine, batch):
            model.train()
            x, y = batch
            x = x.to(device)
            y = y.to(device) - min_y_train
            with ctx_noparamgrad_and_eval(model):
                x_adv = attack.perturb(x, y)
            x = torch.cat((x, x_adv))
            y = torch.cat((y, y))
            ans = model(x)
            batch_loss = loss(ans, y)
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()
            return batch_loss.item()

        trainer = Engine(train_step)

        #         acc_metric.attach(trainer, "accuracy")
        #         loss_metric.attach(trainer, 'loss')

        def train_eval_step(engine, batch):
            model.eval()
            x, y = batch
            x = x.to(device)
            y = y.to(device) - min_y_train
            x_adv = attack.perturb(x, y)
            x = torch.cat((x, x_adv))
            y = torch.cat((y, y))
            with torch.no_grad():
                ans = model.forward(x)
            return ans, y

        train_evaluator = Engine(train_eval_step)
        acc_metric.attach(train_evaluator, "accuracy")
        loss_metric.attach(train_evaluator, 'loss')

        def validation_step(engine, batch):
            model.eval()
            x, y = batch
            x = x.to(device)
            y = y.to(device) - min_y_val  # validation split has its own label offset
            x_adv = attack.perturb(x, y)
            x = torch.cat((x, x_adv))
            y = torch.cat((y, y))
            with torch.no_grad():
                ans = model.forward(x)
            return ans, y

        valid_evaluator = Engine(validation_step)
        acc_val_metric.attach(valid_evaluator, "accuracy")
        loss_val_metric.attach(valid_evaluator, 'loss')

        # Validation cadence in iterations (heuristic kept from the original code)
        eval_every = 200 * 5000 // batch_size // 10

        @trainer.on(Events.ITERATION_COMPLETED(every=eval_every))
        def log_validation_results(engine):
            valid_evaluator.run(ds_valid)
            metrics = valid_evaluator.state.metrics
            valid_avg_accuracy = metrics['accuracy']
            avg_nll = metrics['loss']
            print(
                "Validation Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}"
                .format(engine.state.epoch, valid_avg_accuracy, avg_nll))
            writer.add_scalar("validation/avg_loss", avg_nll,
                              engine.state.epoch)
            writer.add_scalar("validation/avg_accuracy", valid_avg_accuracy,
                              engine.state.epoch)
            writer.add_scalar("validation/avg_error", 1. - valid_avg_accuracy,
                              engine.state.epoch)

        @trainer.on(Events.EPOCH_COMPLETED)
        def lr_scheduler(engine):
            # ReduceLROnPlateau defaults to mode="min", so step on the
            # validation loss rather than the accuracy
            sched.step(valid_evaluator.state.metrics['loss'])

        @trainer.on(Events.ITERATION_COMPLETED(every=50))
        def log_training_loss(engine):
            batch = engine.state.batch
            ds = DataLoader(TensorDataset(*batch), batch_size=batch_size)
            train_evaluator.run(ds)
            metrics = train_evaluator.state.metrics
            # metrics = engine.state.metrics
            accuracy = metrics['accuracy']
            nll = metrics['loss']
            it = (engine.state.iteration - 1) % len(ds_train) + 1  # avoid shadowing builtin iter
            if (it % 50) == 0:
                print("Epoch[{}] Iter[{}/{}] Accuracy: {:.2f} Loss: {:.2f}".format(
                    engine.state.epoch, it, len(ds_train), accuracy, nll))
            writer.add_scalar("batchtraining/detloss", nll,
                              engine.state.iteration)
            writer.add_scalar("batchtraining/accuracy", accuracy,
                              engine.state.iteration)
            writer.add_scalar("batchtraining/error", 1. - accuracy,
                              engine.state.iteration)
            writer.add_scalar("batchtraining/loss", engine.state.output,
                              engine.state.iteration)

        @trainer.on(Events.EPOCH_COMPLETED)
        def log_lr(engine):
            writer.add_scalar("lr", optimizer.param_groups[0]['lr'],
                              engine.state.epoch)

#         @trainer.on(Events.EPOCH_COMPLETED)
#         def log_training_results(engine):
#             train_evaluator.run(ds_train)
#             metrics = train_evaluator.state.metrics
#             # metrics = engine.state.metrics
#             avg_accuracy = metrics['accuracy']
#             avg_nll = metrics['loss']
#             print("Training Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}"
#                   .format(engine.state.epoch, avg_accuracy, avg_nll))
#             writer.add_scalar("training/avg_loss", avg_nll, engine.state.epoch)
#             writer.add_scalar("training/avg_accuracy",
#                               avg_accuracy, engine.state.epoch)
#             writer.add_scalar("training/avg_error", 1. -
#                               avg_accuracy, engine.state.epoch)

        # Used as the checkpoint score function below; it does not need to be
        # registered as an event handler, since Checkpoint calls it directly.
        def validation_value(engine):
            return valid_evaluator.state.metrics['accuracy']

        to_save = {'model': model}
        handler = Checkpoint(
            to_save,
            DiskSaver(os.path.join(base_data, model_id), create_dir=True),
            score_function=validation_value,
            score_name="val_acc",
            global_step_transform=global_step_from_engine(trainer),
            n_saved=None)

        # kick everything off
        trainer.add_event_handler(
            Events.ITERATION_COMPLETED(every=eval_every), handler)
        trainer.run(ds_train, max_epochs=max_epochs)
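The train step above follows a common adversarial-training pattern: craft perturbations with the parameters frozen, then optimize on the concatenation of the clean and adversarial batches. A condensed sketch of just that pattern, assuming advertorch is installed; model, loss_fn, optimizer and the batch tensors are placeholders:

import torch
from advertorch.attacks import GradientSignAttack
from advertorch.context import ctx_noparamgrad_and_eval

def adversarial_step(model, loss_fn, optimizer, x, y, eps=0.2):
    attack = GradientSignAttack(model, loss_fn=loss_fn, eps=eps)
    # No parameter grads / eval mode while crafting the perturbation
    with ctx_noparamgrad_and_eval(model):
        x_adv = attack.perturb(x, y)
    optimizer.zero_grad()
    out = model(torch.cat((x, x_adv)))  # one forward pass over clean + adversarial
    loss = loss_fn(out, torch.cat((y, y)))
    loss.backward()
    optimizer.step()
    return loss.item()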
Example #25
    def _test(milestones_as_np_int):
        tensor = torch.zeros([1], requires_grad=True)
        optimizer = torch.optim.SGD([tensor], lr=0)

        milestones_values = [(5, 0.5), (15, 1.0), (25, 0.0), (35, 1.0), (40, 0.5)]
        if milestones_as_np_int:
            milestones_values = [(np.int64(t), v) for t, v in milestones_values]

        scheduler = PiecewiseLinear(optimizer, "lr", milestones_values=milestones_values)
        state_dict = scheduler.state_dict()

        def save_lr(engine):
            lrs.append(optimizer.param_groups[0]["lr"])

        trainer = Engine(lambda engine, batch: None)
        trainer.add_event_handler(Events.ITERATION_COMPLETED, scheduler)
        trainer.add_event_handler(Events.ITERATION_COMPLETED, save_lr)

        for _ in range(2):
            lrs = []
            trainer.run([0] * 25, max_epochs=2)

            assert lrs == list(
                map(
                    pytest.approx,
                    [
                        0.5,
                        0.5,
                        0.5,
                        0.5,
                        0.5,
                        0.5,
                        0.55,
                        0.6,
                        0.65,
                        0.7,
                        0.75,
                        0.8,
                        0.85,
                        0.9,
                        0.95,
                        1.0,
                        0.9,
                        0.8,
                        0.7,
                        0.6,
                        0.5,
                        0.4,
                        0.3,
                        0.2,
                        0.1,
                        0.0,
                        0.1,
                        0.2,
                        0.3,
                        0.4,
                        0.5,
                        0.6,
                        0.7,
                        0.8,
                        0.9,
                        1.0,
                        0.9,
                        0.8,
                        0.7,
                        0.6,
                        0.5,
                        0.5,
                        0.5,
                        0.5,
                        0.5,
                        0.5,
                        0.5,
                        0.5,
                        0.5,
                        0.5,
                    ],
                )
            )
            scheduler.load_state_dict(state_dict)
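As a quick reference, a minimal PiecewiseLinear sketch outside the test: milestones are expressed in units of whatever event the scheduler is attached to (iterations below), with linear interpolation between them. The dummy setup and import path are assumptions:

import torch
from ignite.engine import Engine, Events
# Older ignite releases: from ignite.contrib.handlers import PiecewiseLinear
from ignite.handlers import PiecewiseLinear

param = torch.zeros(1, requires_grad=True)
optimizer = torch.optim.SGD([param], lr=0.0)

# lr: flat at 0.5 until event 5, ramp to 1.0 by event 15, decay to 0.0 by event 25
scheduler = PiecewiseLinear(optimizer, "lr", milestones_values=[(5, 0.5), (15, 1.0), (25, 0.0)])

trainer = Engine(lambda engine, batch: None)
trainer.add_event_handler(Events.ITERATION_COMPLETED, scheduler)
trainer.run([0] * 25, max_epochs=1)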
Example #26
def train(run_name, forward_func, model, train_set, val_set, n_epochs,
          batch_size, lr):

    # Make the run directory
    save_dir = os.path.join('training/simple/saved_runs', run_name)
    if run_name == 'debug':
        shutil.rmtree(save_dir, ignore_errors=True)
    os.mkdir(save_dir)

    model = model.to(device)

    train_loader = DataLoader(train_set,
                              batch_size=batch_size,
                              shuffle=True,
                              drop_last=True)
    val_loader = DataLoader(val_set,
                            batch_size=batch_size,
                            shuffle=True,
                            drop_last=True)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # Training step
    def step(engine, batch):
        model.train()

        if isinstance(batch, list):
            batch = [tensor.to(device) for tensor in batch]
        else:
            batch = batch.to(device)
        x_gen, x_q, _ = forward_func(model, batch)

        loss = F.l1_loss(x_gen, x_q)
        loss.backward()

        optimizer.step()
        optimizer.zero_grad()

        return {'L1': loss}

    # Trainer and metrics
    trainer = Engine(step)
    metric_names = ['L1']
    RunningAverage(output_transform=lambda x: x['L1']).attach(trainer, 'L1')
    ProgressBar().attach(trainer, metric_names=metric_names)
    Timer(average=True).attach(trainer,
                               start=Events.EPOCH_STARTED,
                               resume=Events.ITERATION_STARTED,
                               pause=Events.ITERATION_COMPLETED,
                               step=Events.ITERATION_COMPLETED)

    # Model checkpointing
    checkpoint_handler = ModelCheckpoint(os.path.join(save_dir, 'checkpoints'),
                                         type(model).__name__,
                                         save_interval=1,
                                         n_saved=3,
                                         require_empty=False)
    trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                              handler=checkpoint_handler,
                              to_save={
                                  'model': model,
                                  'optimizer': optimizer
                              })

    # TensorBoard writer
    writer = SummaryWriter(log_dir=os.path.join(save_dir, 'logs'))

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_metrics(engine):
        if engine.state.iteration % 100 == 0:
            for metric, value in engine.state.metrics.items():
                writer.add_scalar('training/{}'.format(metric), value,
                                  engine.state.iteration)

    def save_images(engine, batch):
        x_gen, x_q, r = forward_func(model, batch)
        r_dim = r.shape[1]
        if isinstance(model, SimpleVVGQN):
            r = (r + 1) / 2
        r = r.view(-1, 1, int(math.sqrt(r_dim)), int(math.sqrt(r_dim)))

        x_gen = x_gen.detach().cpu().float()
        r = r.detach().cpu().float()

        writer.add_image('representation', make_grid(r), engine.state.epoch)
        writer.add_image('generation', make_grid(x_gen), engine.state.epoch)
        writer.add_image('query', make_grid(x_q), engine.state.epoch)

    @trainer.on(Events.EPOCH_COMPLETED)
    def validate(engine):
        model.eval()
        with torch.no_grad():
            batch = next(iter(val_loader))
            if isinstance(batch, list):
                batch = [tensor.to(device) for tensor in batch]
            else:
                batch = batch.to(device)
            x_gen, x_q, r = forward_func(model, batch)

            loss = F.l1_loss(x_gen, x_q)

            writer.add_scalar('validation/L1', loss.item(), engine.state.epoch)

            save_images(engine, batch)

    @trainer.on(Events.EXCEPTION_RAISED)
    def handle_exception(engine, e):
        writer.close()
        engine.terminate()
        if isinstance(e, KeyboardInterrupt) and (engine.state.iteration > 1):
            import warnings
            warnings.warn('KeyboardInterrupt caught. Exiting gracefully.')
            checkpoint_handler(engine, {'model_exception': model})
        else:
            raise e

    start_time = time.time()
    trainer.run(train_loader, n_epochs)
    writer.close()
    end_time = time.time()
    print('Total training time: {}'.format(
        timedelta(seconds=end_time - start_time)))
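Note that ModelCheckpoint(save_interval=...) is the old ignite API; it was deprecated and later removed. A rough equivalent under ignite >= 0.4 (a sketch with a hypothetical output directory) passes to_save at attach time and filters the event instead:

import torch
from ignite.engine import Engine, Events
from ignite.handlers import ModelCheckpoint

model = torch.nn.Linear(2, 2)
optimizer = torch.optim.Adam(model.parameters())
trainer = Engine(lambda engine, batch: None)

handler = ModelCheckpoint('saved_runs/checkpoints',  # hypothetical directory
                          'model', n_saved=3, require_empty=False)
# to_save is supplied when attaching, not via save_interval
trainer.add_event_handler(Events.EPOCH_COMPLETED(every=1), handler,
                          {'model': model, 'optimizer': optimizer})
trainer.run([0], max_epochs=1)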
Example #27
def main(
    batch_size,
    epochs,
    length_scale,
    centroid_size,
    model_output_size,
    learning_rate,
    l_gradient_penalty,
    gamma,
    weight_decay,
    final_model,
):
    name = f"DUQ_{length_scale}__{l_gradient_penalty}_{gamma}_{centroid_size}"
    writer = SummaryWriter(comment=name)

    ds = all_datasets["CIFAR10"]()
    input_size, num_classes, dataset, test_dataset = ds

    # Split up training set
    idx = list(range(len(dataset)))
    random.shuffle(idx)

    if final_model:
        train_dataset = dataset
        val_dataset = test_dataset
    else:
        train_size = int(len(dataset) * 0.8)  # 80/20 train/validation split
        train_dataset = torch.utils.data.Subset(dataset, idx[:train_size])
        val_dataset = torch.utils.data.Subset(dataset, idx[train_size:])

        # Use test-time preprocessing for the validation split
        val_dataset.transform = test_dataset.transform

    model = ResNet_DUQ(input_size, num_classes, centroid_size,
                       model_output_size, length_scale, gamma)
    model = model.cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=learning_rate,
                                momentum=0.9,
                                weight_decay=weight_decay)

    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=[25, 50, 75],
                                                     gamma=0.2)

    def bce_loss_fn(y_pred, y):
        bce = F.binary_cross_entropy(y_pred, y, reduction="sum").div(
            num_classes * y_pred.shape[0])
        return bce

    def output_transform_bce(output):
        y_pred, y, x = output

        y = F.one_hot(y, num_classes).float()

        return y_pred, y

    def output_transform_acc(output):
        y_pred, y, x = output

        return y_pred, y

    def output_transform_gp(output):
        y_pred, y, x = output

        return x, y_pred

    def calc_gradients_input(x, y_pred):
        gradients = torch.autograd.grad(
            outputs=y_pred,
            inputs=x,
            grad_outputs=torch.ones_like(y_pred),
            create_graph=True,
        )[0]

        gradients = gradients.flatten(start_dim=1)

        return gradients

    def calc_gradient_penalty(x, y_pred):
        gradients = calc_gradients_input(x, y_pred)

        # L2 norm
        grad_norm = gradients.norm(2, dim=1)

        # Two sided penalty
        gradient_penalty = ((grad_norm - 1)**2).mean()

        return gradient_penalty

    def step(engine, batch):
        model.train()

        optimizer.zero_grad()

        x, y = batch
        x, y = x.cuda(), y.cuda()

        if l_gradient_penalty > 0:
            x.requires_grad_(True)

        z, y_pred = model(x)
        y = F.one_hot(y, num_classes).float()

        loss = bce_loss_fn(y_pred, y)

        if l_gradient_penalty > 0:
            loss += l_gradient_penalty * calc_gradient_penalty(x, y_pred)

        loss.backward()
        optimizer.step()

        x.requires_grad_(False)

        with torch.no_grad():
            model.eval()
            model.update_embeddings(x, y)

        return loss.item()

    def eval_step(engine, batch):
        model.eval()

        x, y = batch
        x, y = x.cuda(), y.cuda()

        x.requires_grad_(True)

        z, y_pred = model(x)

        return y_pred, y, x

    trainer = Engine(step)
    evaluator = Engine(eval_step)

    metric = Average()
    metric.attach(trainer, "loss")

    metric = Accuracy(output_transform=output_transform_acc)
    metric.attach(evaluator, "accuracy")

    metric = Loss(F.binary_cross_entropy,
                  output_transform=output_transform_bce)
    metric.attach(evaluator, "bce")

    metric = Loss(calc_gradient_penalty, output_transform=output_transform_gp)
    metric.attach(evaluator, "gradient_penalty")

    kwargs = {"num_workers": 4, "pin_memory": True}

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               drop_last=True,
                                               **kwargs)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=1000,
                                             shuffle=False,
                                             **kwargs)

    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=1000,
                                              shuffle=False,
                                              **kwargs)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_results(trainer):
        metrics = trainer.state.metrics
        loss = metrics["loss"]

        print(f"Train - Epoch: {trainer.state.epoch} Loss: {loss:.2f} ")

        writer.add_scalar("Loss/train", loss, trainer.state.epoch)

        if trainer.state.epoch % 5 == 0 or trainer.state.epoch > 65:
            accuracy, auroc = get_cifar_svhn_ood(model)
            print(f"Test Accuracy: {accuracy}, AUROC: {auroc}")
            writer.add_scalar("OoD/test_accuracy", accuracy,
                              trainer.state.epoch)
            writer.add_scalar("OoD/roc_auc", auroc, trainer.state.epoch)

            accuracy, auroc = get_auroc_classification(val_dataset, model)
            print(f"AUROC - uncertainty: {auroc}")
            writer.add_scalar("OoD/val_accuracy", accuracy,
                              trainer.state.epoch)
            writer.add_scalar("OoD/roc_auc_classification", auroc,
                              trainer.state.epoch)

        evaluator.run(val_loader)
        metrics = evaluator.state.metrics
        acc = metrics["accuracy"]
        bce = metrics["bce"]
        GP = metrics["gradient_penalty"]
        loss = bce + l_gradient_penalty * GP

        print((f"Valid - Epoch: {trainer.state.epoch} "
               f"Acc: {acc:.4f} "
               f"Loss: {loss:.2f} "
               f"BCE: {bce:.2f} "
               f"GP: {GP:.2f} "))

        writer.add_scalar("Loss/valid", loss, trainer.state.epoch)
        writer.add_scalar("BCE/valid", bce, trainer.state.epoch)
        writer.add_scalar("GP/valid", GP, trainer.state.epoch)
        writer.add_scalar("Accuracy/valid", acc, trainer.state.epoch)

        print(f"Centroid norm: {torch.norm(model.m / model.N, dim=0)}")

        scheduler.step()

        if trainer.state.epoch > 65:
            torch.save(model.state_dict(),
                       f"saved_models/{name}_{trainer.state.epoch}.pt")

    pbar = ProgressBar(dynamic_ncols=True)
    pbar.attach(trainer)

    trainer.run(train_loader, max_epochs=epochs)

    evaluator.run(test_loader)
    acc = evaluator.state.metrics["accuracy"]

    print(f"Test - Accuracy {acc:.4f}")

    writer.close()
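The two-sided penalty above is ((||grad_x y_pred||_2 - 1) ** 2) averaged over the batch. A self-contained sanity check of that computation, independent of the DUQ model: for y_pred = x.sum(dim=1) the input gradient is all ones, so the per-sample norm is sqrt(dim).

import torch

x = torch.randn(4, 3, requires_grad=True)
y_pred = x.sum(dim=1)
grads = torch.autograd.grad(y_pred, x,
                            grad_outputs=torch.ones_like(y_pred),
                            create_graph=True)[0]
penalty = ((grads.flatten(start_dim=1).norm(2, dim=1) - 1) ** 2).mean()
assert torch.isclose(penalty, torch.tensor((3 ** 0.5 - 1) ** 2))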
Example #28
    def _test(metric_device):
        data = list(range(n_iters))
        np.random.seed(12)
        all_y_true_batch_values = np.random.randint(
            0,
            n_classes,
            size=(idist.get_world_size(), n_epochs * n_iters, batch_size))
        all_y_pred_batch_values = np.random.rand(idist.get_world_size(),
                                                 n_epochs * n_iters,
                                                 batch_size, n_classes)

        y_true_batch_values = iter(all_y_true_batch_values[rank, ...])
        y_pred_batch_values = iter(all_y_pred_batch_values[rank, ...])

        def update_fn(engine, batch):
            y_true_batch = next(y_true_batch_values)
            y_pred_batch = next(y_pred_batch_values)
            return torch.from_numpy(y_pred_batch), torch.from_numpy(
                y_true_batch)

        trainer = Engine(update_fn)
        alpha = 0.98

        acc_metric = RunningAverage(Accuracy(
            output_transform=lambda x: [x[0], x[1]], device=metric_device),
                                    alpha=alpha,
                                    epoch_bound=False)
        acc_metric.attach(trainer, "running_avg_accuracy")

        running_avg_acc = [
            None,
        ]
        true_acc_metric = Accuracy(device=metric_device)

        @trainer.on(Events.ITERATION_COMPLETED)
        def manual_running_avg_acc(engine):
            i = engine.state.iteration - 1

            true_acc_metric.reset()
            for j in range(idist.get_world_size()):
                output = (
                    torch.from_numpy(all_y_pred_batch_values[j, i, :, :]),
                    torch.from_numpy(all_y_true_batch_values[j, i, :]),
                )
                true_acc_metric.update(output)

            batch_acc = true_acc_metric._num_correct.item() / true_acc_metric._num_examples

            if running_avg_acc[0] is None:
                running_avg_acc[0] = batch_acc
            else:
                running_avg_acc[0] = running_avg_acc[0] * alpha + (
                    1.0 - alpha) * batch_acc
            engine.state.running_avg_acc = running_avg_acc[0]

        @trainer.on(Events.ITERATION_COMPLETED)
        def assert_equal_running_avg_acc_values(engine):
            assert engine.state.running_avg_acc == engine.state.metrics[
                "running_avg_accuracy"], "{} vs {}".format(
                    engine.state.running_avg_acc,
                    engine.state.metrics["running_avg_accuracy"])

        trainer.run(data, max_epochs=3)
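The manual check above mirrors the exponential moving average that RunningAverage maintains; with epoch_bound=False the average is carried across epoch boundaries instead of being reset. The recurrence in isolation:

# EMA recurrence used by RunningAverage: v <- alpha * v + (1 - alpha) * x,
# seeded with the first observation.
def ema(values, alpha=0.98):
    avg = None
    for v in values:
        avg = v if avg is None else alpha * avg + (1.0 - alpha) * v
    return avg

assert abs(ema([1.0, 0.0]) - 0.98) < 1e-12  # 0.98 * 1.0 + 0.02 * 0.0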
Example #29
def train_model(
    name="",
    resume="",
    base_dir=utils.BASE_DIR,
    model_name="v0",
    chosen_diseases=None,
    n_epochs=10,
    batch_size=4,
    oversample=False,
    max_os=None,
    shuffle=False,
    opt="sgd",
    opt_params={},
    loss_name="wbce",
    loss_params={},
    train_resnet=False,
    log_metrics=None,
    flush_secs=120,
    train_max_images=None,
    val_max_images=None,
    test_max_images=None,
    experiment_mode="debug",
    save=True,
    save_cms=True,  # Note that in this case, save_cms (to disk) includes write_cms (to TB)
    write_graph=False,
    write_emb=False,
    write_emb_img=False,
    write_img=False,
    image_format="RGB",
    multiple_gpu=False,
):

    # Choose GPU
    device = utilsT.get_torch_device()
    print("Using device: ", device)

    # Common folders
    dataset_dir = os.path.join(base_dir, "dataset")

    # Dataset handling
    print("Loading train dataset...")
    train_dataset, train_dataloader = utilsT.prepare_data(
        dataset_dir,
        "train",
        chosen_diseases,
        batch_size,
        oversample=oversample,
        max_os=max_os,
        shuffle=shuffle,
        max_images=train_max_images,
        image_format=image_format,
    )
    train_samples, _ = train_dataset.size()

    print("Loading val dataset...")
    val_dataset, val_dataloader = utilsT.prepare_data(
        dataset_dir,
        "val",
        chosen_diseases,
        batch_size,
        max_images=val_max_images,
        image_format=image_format,
    )
    val_samples, _ = val_dataset.size()

    # Should be the same as chosen_diseases
    chosen_diseases = list(train_dataset.classes)
    print("Chosen diseases: ", chosen_diseases)

    if resume:
        # Load model and optimizer
        model, model_name, optimizer, opt, loss_name, loss_params, chosen_diseases = models.load_model(
            base_dir, resume, experiment_mode="", device=device)
        model.train(True)
    else:
        # Create model
        model = models.init_empty_model(model_name,
                                        chosen_diseases,
                                        train_resnet=train_resnet).to(device)

        # Create optimizer
        OptClass = optimizers.get_optimizer_class(opt)
        optimizer = OptClass(model.parameters(), **opt_params)
        # print("OPT: ", opt_params)

    # Allow multiple GPUs
    if multiple_gpu:
        model = DataParallel(model)

    # Tensorboard log options
    run_name = utils.get_timestamp()
    if name:
        run_name += "_{}".format(name)

    if len(chosen_diseases) == 1:
        run_name += "_{}".format(chosen_diseases[0])
    elif len(chosen_diseases) == 14:
        run_name += "_all"

    log_dir = get_log_dir(base_dir, run_name, experiment_mode=experiment_mode)

    print("Run name: ", run_name)
    print("Saved TB in: ", log_dir)

    writer = SummaryWriter(log_dir=log_dir, flush_secs=flush_secs)

    # Create validator engine
    validator = Engine(
        utilsT.get_step_fn(model, optimizer, device, loss_name, loss_params,
                           False))

    val_loss = RunningAverage(output_transform=lambda x: x[0], alpha=1)
    val_loss.attach(validator, loss_name)

    utilsT.attach_metrics(validator, chosen_diseases, "prec", Precision, True)
    utilsT.attach_metrics(validator, chosen_diseases, "recall", Recall, True)
    utilsT.attach_metrics(validator, chosen_diseases, "acc", Accuracy, True)
    utilsT.attach_metrics(validator, chosen_diseases, "roc_auc",
                          utilsT.RocAucMetric, False)
    utilsT.attach_metrics(validator,
                          chosen_diseases,
                          "cm",
                          ConfusionMatrix,
                          get_transform_fn=utilsT.get_transform_cm,
                          metric_args=(2, ))
    utilsT.attach_metrics(validator,
                          chosen_diseases,
                          "positives",
                          RunningAverage,
                          get_transform_fn=utilsT.get_count_positives)

    # Create trainer engine
    trainer = Engine(
        utilsT.get_step_fn(model, optimizer, device, loss_name, loss_params,
                           True))

    train_loss = RunningAverage(output_transform=lambda x: x[0], alpha=1)
    train_loss.attach(trainer, loss_name)

    utilsT.attach_metrics(trainer, chosen_diseases, "acc", Accuracy, True)
    utilsT.attach_metrics(trainer, chosen_diseases, "prec", Precision, True)
    utilsT.attach_metrics(trainer, chosen_diseases, "recall", Recall, True)
    utilsT.attach_metrics(trainer, chosen_diseases, "roc_auc",
                          utilsT.RocAucMetric, False)
    utilsT.attach_metrics(trainer,
                          chosen_diseases,
                          "cm",
                          ConfusionMatrix,
                          get_transform_fn=utilsT.get_transform_cm,
                          metric_args=(2, ))
    utilsT.attach_metrics(trainer,
                          chosen_diseases,
                          "positives",
                          RunningAverage,
                          get_transform_fn=utilsT.get_count_positives)

    timer = Timer(average=True)
    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 step=Events.EPOCH_COMPLETED)

    # TODO: Early stopping
    #     def score_function(engine):
    #         val_loss = engine.state.metrics[loss_name]
    #         return -val_loss

    #     handler = EarlyStopping(patience=10, score_function=score_function, trainer=trainer)
    #     validator.add_event_handler(Events.COMPLETED, handler)

    # Metrics callbacks
    if log_metrics is None:
        log_metrics = list(ALL_METRICS)

    def _write_metrics(run_type, metrics, epoch, wall_time):
        loss = metrics.get(loss_name, 0)

        writer.add_scalar("Loss/" + run_type, loss, epoch, wall_time)

        for metric_base_name in log_metrics:
            for disease in chosen_diseases:
                metric_value = metrics.get(
                    "{}_{}".format(metric_base_name, disease), -1)
                writer.add_scalar(
                    "{}_{}/{}".format(metric_base_name, disease, run_type),
                    metric_value, epoch, wall_time)

    @trainer.on(Events.EPOCH_COMPLETED)
    def tb_write_metrics(trainer):
        epoch = trainer.state.epoch
        max_epochs = trainer.state.max_epochs

        # Run on evaluation
        validator.run(val_dataloader, 1)

        # Common time
        wall_time = time.time()

        # Log all metrics to TB
        _write_metrics("train", trainer.state.metrics, epoch, wall_time)
        _write_metrics("val", validator.state.metrics, epoch, wall_time)

        train_loss = trainer.state.metrics.get(loss_name, 0)
        val_loss = validator.state.metrics.get(loss_name, 0)

        tb_write_histogram(writer, model, epoch, wall_time)

        print("Finished epoch {}/{}, loss {:.3f}, val loss {:.3f} (took {})".
              format(epoch, max_epochs, train_loss, val_loss,
                     utils.duration_to_str(int(timer._elapsed()))))

    # Hparam dict
    hparam_dict = {
        "resume": resume,
        "n_diseases": len(chosen_diseases),
        "diseases": ",".join(chosen_diseases),
        "n_epochs": n_epochs,
        "batch_size": batch_size,
        "shuffle": shuffle,
        "model_name": model_name,
        "opt": opt,
        "loss": loss_name,
        "samples (train, val)": "{},{}".format(train_samples, val_samples),
        "train_resnet": train_resnet,
        "multiple_gpu": multiple_gpu,
    }

    def copy_params(params_dict, base_name):
        for name, value in params_dict.items():
            hparam_dict["{}_{}".format(base_name, name)] = value

    copy_params(loss_params, "loss")
    copy_params(opt_params, "opt")
    print("HPARAM: ", hparam_dict)

    # Train
    print("-" * 50)
    print("Training...")
    trainer.run(train_dataloader, n_epochs)

    # Capture time
    secs_per_epoch = timer.value()
    duration_per_epoch = utils.duration_to_str(int(secs_per_epoch))
    print("Average time per epoch: ", duration_per_epoch)
    print("-" * 50)

    ## Write all hparams
    hparam_dict["duration_per_epoch"] = duration_per_epoch

    # FIXME: this is commented to avoid having too many hparams in TB frontend
    # metrics
    #     def copy_metrics(engine, engine_name):
    #         for metric_name, metric_value in engine.state.metrics.items():
    #             hparam_dict["{}_{}".format(engine_name, metric_name)] = metric_value
    #     copy_metrics(trainer, "train")
    #     copy_metrics(validator, "val")

    print("Writing TB hparams")
    writer.add_hparams(hparam_dict, {})

    # Save model to disk
    if save:
        print("Saving model...")
        models.save_model(base_dir, run_name, model_name, experiment_mode,
                          hparam_dict, trainer, model, optimizer)

    # Write graph to TB
    if write_graph:
        print("Writing TB graph...")
        tb_write_graph(writer, model, train_dataloader, device)

    # Write embeddings to TB
    if write_emb:
        print("Writing TB embeddings...")
        image_size = 256 if write_emb_img else 0

        # FIXME: be able to select images (balanced, train vs val, etc)
        image_list = list(train_dataset.label_index["FileName"])[:1000]
        # disease = chosen_diseases[0]
        # positive = train_dataset.label_index[train_dataset.label_index[disease] == 1]
        # negative = train_dataset.label_index[train_dataset.label_index[disease] == 0]
        # positive_images = list(positive["FileName"])[:25]
        # negative_images = list(negative["FileName"])[:25]
        # image_list = positive_images + negative_images

        all_images, all_embeddings, all_predictions, all_ground_truths = gen_embeddings(
            model,
            train_dataset,
            device,
            image_list=image_list,
            image_size=image_size)
        tb_write_embeddings(
            writer,
            chosen_diseases,
            all_images,
            all_embeddings,
            all_predictions,
            all_ground_truths,
            global_step=n_epochs,
            use_images=write_emb_img,
            tag="1000_{}".format("img" if write_emb_img else "no_img"),
        )

    # Save confusion matrices (they are expensive to recalculate afterwards)
    if save_cms:
        print("Saving confusion matrices...")
        # Assure folder
        cms_dir = os.path.join(base_dir, "cms", experiment_mode)
        os.makedirs(cms_dir, exist_ok=True)
        base_fname = os.path.join(cms_dir, run_name)

        n_diseases = len(chosen_diseases)

        def extract_cms(metrics):
            """Extract confusion matrices from a metrics dict."""
            cms = []
            for disease in chosen_diseases:
                key = "cm_" + disease
                if key not in metrics:
                    cm = np.array([[-1, -1], [-1, -1]])
                else:
                    cm = metrics[key].numpy()

                cms.append(cm)
            return np.array(cms)

        # Train confusion matrix
        train_cms = extract_cms(trainer.state.metrics)
        np.save(base_fname + "_train", train_cms)
        tb_write_cms(writer, "train", chosen_diseases, train_cms)

        # Validation confusion matrix
        val_cms = extract_cms(validator.state.metrics)
        np.save(base_fname + "_val", val_cms)
        tb_write_cms(writer, "val", chosen_diseases, val_cms)

        # All confusion matrix (train + val)
        all_cms = train_cms + val_cms
        np.save(base_fname + "_all", all_cms)

        # Print to console
        if len(chosen_diseases) == 1:
            print("Train CM: ")
            print(train_cms[0])
            print("Val CM: ")
            print(val_cms[0])


#             print("Train CM 2: ")
#             print(trainer.state.metrics["cm_" + chosen_diseases[0]])
#             print("Val CM 2: ")
#             print(validator.state.metrics["cm_" + chosen_diseases[0]])

    if write_img:
        # NOTE: this option is not recommended, use Testing notebook to plot and analyze images

        print("Writing images to TB...")

        test_dataset, test_dataloader = utilsT.prepare_data(
            dataset_dir,
            "test",
            chosen_diseases,
            batch_size,
            max_images=test_max_images,
        )

        # TODO: add a way to select images?
        # image_list = list(test_dataset.label_index["FileName"])[:3]

        # Examples in test_dataset (with bboxes available):
        image_list = [
            # "00010277_000.png", # (Effusion, Infiltrate, Mass, Pneumonia)
            # "00018427_004.png", # (Atelectasis, Effusion, Mass)
            # "00021703_001.png", # (Atelectasis, Effusion, Infiltrate)
            # "00028640_008.png", # (Effusion, Infiltrate)
            # "00019124_104.png", # (Pneumothorax)
            # "00019124_090.png", # (Nodule)
            # "00020318_007.png", # (Pneumothorax)
            "00000003_000.png",  # (0)
            # "00000003_001.png", # (0)
            # "00000003_002.png", # (0)
            "00000732_005.png",  # (Cardiomegaly, Pneumothorax)
            # "00012261_001.png", # (Cardiomegaly, Pneumonia)
            # "00013249_033.png", # (Cardiomegaly, Pneumonia)
            # "00029808_003.png", # (Cardiomegaly, Pneumonia)
            # "00022215_012.png", # (Cardiomegaly, Pneumonia)
            # "00011402_007.png", # (Cardiomegaly, Pneumonia)
            # "00019018_007.png", # (Cardiomegaly, Infiltrate)
            # "00021009_001.png", # (Cardiomegaly, Infiltrate)
            # "00013670_151.png", # (Cardiomegaly, Infiltrate)
            # "00005066_030.png", # (Cardiomegaly, Infiltrate, Effusion)
            "00012288_000.png",  # (Cardiomegaly)
            "00008399_007.png",  # (Cardiomegaly)
            "00005532_000.png",  # (Cardiomegaly)
            "00005532_014.png",  # (Cardiomegaly)
            "00005532_016.png",  # (Cardiomegaly)
            "00005827_000.png",  # (Cardiomegaly)
            # "00006912_007.png", # (Cardiomegaly)
            # "00007037_000.png", # (Cardiomegaly)
            # "00007043_000.png", # (Cardiomegaly)
            # "00012741_004.png", # (Cardiomegaly)
            # "00007551_020.png", # (Cardiomegaly)
            # "00007735_040.png", # (Cardiomegaly)
            # "00008339_010.png", # (Cardiomegaly)
            # "00008365_000.png", # (Cardiomegaly)
            # "00012686_003.png", # (Cardiomegaly)
        ]

        tb_write_images(writer, model, test_dataset, chosen_diseases, n_epochs,
                        device, image_list)

    # Close TB writer
    if experiment_mode != "debug":
        writer.close()

    # Run post_train
    print("-" * 50)
    print("Running post_train...")

    print("Loading test dataset...")
    test_dataset, test_dataloader = utilsT.prepare_data(
        dataset_dir,
        "test",
        chosen_diseases,
        batch_size,
        max_images=test_max_images)

    save_cms_with_names(run_name, experiment_mode, model, test_dataset,
                        test_dataloader, chosen_diseases)

    evaluate_model(run_name,
                   model,
                   optimizer,
                   device,
                   loss_name,
                   loss_params,
                   chosen_diseases,
                   test_dataloader,
                   experiment_mode=experiment_mode,
                   base_dir=base_dir)

    # Return values for debugging
    model_run = ModelRun(model, run_name, model_name, chosen_diseases)
    if experiment_mode == "debug":
        model_run.save_debug_data(writer, trainer, validator, train_dataset,
                                  train_dataloader, val_dataset,
                                  val_dataloader)

    return model_run
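add_hparams pairs the hyperparameter dict with a metrics dict, which is why the commented-out copy_metrics block above would have merged final metrics in. A minimal sketch with hypothetical values:

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="runs/hparams_demo")  # hypothetical log dir
hparam_dict = {"lr": 0.01, "batch_size": 4, "n_epochs": 10}
metric_dict = {"hparam/val_loss": 0.31}  # shown next to the hparams in the TB table
writer.add_hparams(hparam_dict, metric_dict)
writer.close()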
Example #30
def main(config, needs_save, study_name, k, n_splits):
    if config.run.visible_devices:
        os.environ['CUDA_VISIBLE_DEVICES'] = config.run.visible_devices

    seed = check_manual_seed(config.run.seed)
    print('Using seed: {}'.format(seed))

    train_data_loader, test_data_loader, data_train = get_k_hold_data_loader(
        config.dataset,
        k=k,
        n_splits=n_splits,
    )

    data_train = torch.from_numpy(data_train).float().cuda(non_blocking=True)
    data_train = torch.t(data_train)

    model = get_model(config.model)
    model.cuda()
    model = nn.DataParallel(model)

    print('count params: ', count_parameters(model.module))

    saved_model_path, _, _ = get_saved_model_path(
        config,
        study_name,
        config.model.checkpoint_epoch,
        k,
        n_splits,
    )

    model.load_state_dict(torch.load(saved_model_path)['model'])
    model.eval()

    if config.model.model_name in ('MLP', 'ModifiedMLP'):
        embedding = model.module.get_embedding()

    elif config.model.model_name in ('DietNetworks', 'ModifiedDietNetworks'):
        embedding = model.module.get_embedding(data_train)

    embedding = embedding.detach().cpu().numpy()

    emb_pca = PCA(n_components=2)
    emb_pca.fit(embedding)  # only the fitted components are used below

    if config.run.decomp == '1D':
        print('Approximate by 1D PCA')
        axis_1 = emb_pca.components_[0]
        score_1 = np.dot(embedding, axis_1)
        approx = np.outer(score_1, axis_1)

    elif config.run.decomp == '2D':
        print('Approximate by 2D PCA')
        axis_1 = emb_pca.components_[0]
        score_1 = np.dot(embedding, axis_1)
        axis_2 = emb_pca.components_[1]
        score_2 = np.dot(embedding, axis_2)
        approx = np.outer(score_1, axis_1) + np.outer(score_2, axis_2)
        # approx = np.outer(score_2, axis_2)

    approx = torch.from_numpy(approx).float().cuda(non_blocking=True)

    criterion = nn.CrossEntropyLoss()

    def inference(engine, batch):

        x = batch['data'].float().cuda(non_blocking=True)
        y = batch['label'].long().cuda(non_blocking=True)

        assert config.run.transposed_matrix == 'overall'
        x_t = data_train

        with torch.no_grad():
            out, _ = model.module.approx(x, approx)
            l_discriminative = criterion(out, y)
            l_total = l_discriminative

        metrics = calc_metrics(out, y)

        metrics.update({
            'l_total': l_total.item(),
            'l_discriminative': l_discriminative.item(),
        })

        torch.cuda.synchronize()

        return metrics

    evaluator = Engine(inference)

    monitoring_metrics = ['l_total', 'l_discriminative', 'accuracy']

    for metric in monitoring_metrics:
        RunningAverage(
            alpha=0.98,
            output_transform=partial(lambda x, metric: x[metric], metric=metric)
        ).attach(evaluator, metric)

    pbar = ProgressBar()
    pbar.attach(evaluator, metric_names=monitoring_metrics)

    evaluator.run(test_data_loader, 1)

    columns = ['k', 'n_splits', 'epoch', 'iteration'] + list(evaluator.state.metrics.keys())
    values = [str(k), str(n_splits), str(evaluator.state.epoch), str(evaluator.state.iteration)] \
           + [str(value) for value in evaluator.state.metrics.values()]

    values = {c: v for (c, v) in zip(columns, values)}
    values.update({
        'variance_ratio_1': emb_pca.explained_variance_ratio_[0],
        'variance_ratio_2': emb_pca.explained_variance_ratio_[1],
    })
    return values
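For reference, the rank-k approximation above projects the raw embedding onto the leading components without mean-centering; with centering it coincides with sklearn's own round trip. A standalone sketch:

import numpy as np
from sklearn.decomposition import PCA

emb = np.random.rand(100, 16)
pca = PCA(n_components=2).fit(emb)

# Project onto the top-2 components and reconstruct (rank-2 approximation)
scores = (emb - pca.mean_) @ pca.components_.T
approx = scores @ pca.components_ + pca.mean_
assert np.allclose(approx, pca.inverse_transform(pca.transform(emb)))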