def test_pbar_file(tmp_path):
    n_epochs = 2
    loader = [1, 2]
    engine = Engine(update_fn)

    file_path = tmp_path / "temp.txt"
    file = open(str(file_path), "w+")

    pbar = ProgressBar(file=file)
    pbar.attach(engine, ["a"])
    engine.run(loader, max_epochs=n_epochs)

    file.close()  # Force a flush of the buffer. file.flush() does not work.

    file = open(str(file_path), "r")
    lines = file.readlines()

    if get_tqdm_version() < LooseVersion("4.49.0"):
        expected = "Epoch [2/2]: [1/2]  50%|█████     , a=1 [00:00<00:00]\n"
    else:
        expected = "Epoch [2/2]: [1/2]  50%|█████     , a=1 [00:00<?]\n"
    assert lines[-2] == expected
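# The test above leans on helpers defined elsewhere in ignite's test module.
# A minimal sketch of those missing pieces; update_fn is a guessed stand-in,
# written so the expected "a=1" in the assertion holds.
import pytest
from ignite.engine import Engine, Events
from ignite.contrib.handlers import ProgressBar

def update_fn(engine, batch):
    # Publish a metric named "a" so pbar.attach(engine, ["a"]) has data to show.
    engine.state.metrics["a"] = 1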
Example #2
def test_pbar_wrong_events_order():

    engine = Engine(update_fn)
    pbar = ProgressBar()

    with pytest.raises(ValueError, match="should be called before closing event"):
        pbar.attach(engine, event_name=Events.COMPLETED, closing_event_name=Events.COMPLETED)

    with pytest.raises(ValueError, match="should be called before closing event"):
        pbar.attach(engine, event_name=Events.COMPLETED, closing_event_name=Events.EPOCH_COMPLETED)

    with pytest.raises(ValueError, match="should be called before closing event"):
        pbar.attach(engine, event_name=Events.COMPLETED, closing_event_name=Events.ITERATION_COMPLETED)

    with pytest.raises(ValueError, match="should be called before closing event"):
        pbar.attach(engine, event_name=Events.EPOCH_COMPLETED, closing_event_name=Events.EPOCH_COMPLETED)

    with pytest.raises(ValueError, match="should be called before closing event"):
        pbar.attach(engine, event_name=Events.ITERATION_COMPLETED, closing_event_name=Events.ITERATION_STARTED)

    with pytest.raises(ValueError, match="should not be a filtered event"):
        pbar.attach(engine, event_name=Events.ITERATION_STARTED, closing_event_name=Events.EPOCH_COMPLETED(every=10))
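    # For contrast, a sketch of a pairing that satisfies the ordering check:
    # the opening event strictly precedes the closing event and neither is
    # filtered, so attach() raises no ValueError.
    pbar.attach(engine, event_name=Events.ITERATION_COMPLETED,
                closing_event_name=Events.EPOCH_COMPLETED)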
Example #3
    def fit(self, train_loader, valid_loader, n_epochs):
        trainer = create_supervised_trainer(self.model,
                                            self.optim,
                                            self.loss_fn,
                                            device=self.device)

        evaluator = create_supervised_evaluator(
            self.model,
            metrics={'loss': Loss(self.loss_fn)},
            device=self.device)

        RunningAverage(output_transform=lambda x: x).attach(trainer, 'loss')
        pbar = ProgressBar(persist=False)
        pbar.attach(trainer, metric_names="all")

        @trainer.on(Events.EPOCH_COMPLETED)
        def log_validation_results():
            train_loss = trainer.state.metrics['loss']
            evaluator.run(valid_loader)
            valid_loss = evaluator.state.metrics['loss']

            self.history['train_loss'].append(train_loss)
            self.history['valid_loss'].append(valid_loss)

            if valid_loss < self.best_loss:
                self.best_loss = valid_loss
                self.best_epoch = trainer.state.epoch
                self.best_model = deepcopy(self.model.state_dict())

            template = "Epoch [%3d/%3d] >> train_loss = %.4f, valid_loss = %.4f, "
            template += "lowest_loss = %.4f @epoch = %d"
            pbar.log_message(template %
                             (trainer.state.epoch, trainer.state.max_epochs,
                              train_loss, valid_loss, self.best_loss,
                              self.best_epoch))

        trainer.run(train_loader, max_epochs=n_epochs)
        self.model.load_state_dict(self.best_model)
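    # fit() also reads attributes this snippet does not define; presumably the
    # wrapper's __init__ sets them up along these lines (a sketch, using the
    # names referenced above):
    #   self.history = {'train_loss': [], 'valid_loss': []}
    #   self.best_loss = float('inf')
    #   self.best_epoch = 0
    #   self.best_model = None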
Example #4
def start_to_learn(trainer, train_loader, tester, test_loader, epochs, model_metrics):
    # ---Log Message Initializing---
    log_msg = ProgressBar(persist=True, bar_format=" ")

    # ---After Training starts Testing---
    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        tester.run(test_loader)

        epoch = engine.state.epoch
        metrics = tester.state.metrics

        model_metrics.update({epoch: metrics})

        log_msg.log_message("Epoch: {}  \nAccuracy: {:.3f} \tLoss: {:.3f} \tRecall: {:.3f} \tPrecision: {:.3f}\n\n"
                                .format(epoch, metrics['accuracy'], metrics['loss'], metrics['precision'], metrics['recall']))
        log_msg.n = log_msg.last_print_n = 0

    start = time.time()
    trainer.run(train_loader, epochs)   # Training Starting
    duration = time.time() - start
    print('Duration of execution: ', duration)
    return duration
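# A hypothetical way to build the engines start_to_learn() expects; model,
# optimizer, loss_fn and the loaders are assumptions, and the metric names
# mirror the keys read in log_validation_results above.
from ignite.engine import create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Accuracy, Loss, Precision, Recall

trainer = create_supervised_trainer(model, optimizer, loss_fn)
tester = create_supervised_evaluator(model, metrics={
    "accuracy": Accuracy(),
    "loss": Loss(loss_fn),
    "precision": Precision(average=True),
    "recall": Recall(average=True),
})
model_metrics = {}
duration = start_to_learn(trainer, train_loader, tester, test_loader,
                          epochs=5, model_metrics=model_metrics)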
Example #5
def test_pbar_with_state_attrs(capsys):

    n_iters = 2
    data = list(range(n_iters))
    loss_values = iter(range(n_iters))

    def step(engine, batch):
        loss_value = next(loss_values)
        return loss_value

    trainer = Engine(step)
    trainer.state.alpha = 3.899
    trainer.state.beta = torch.tensor(12.21)
    trainer.state.gamma = torch.tensor([21.0, 6.0])

    RunningAverage(alpha=0.5, output_transform=lambda x: x).attach(trainer, "batchloss")

    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=["batchloss"], state_attributes=["alpha", "beta", "gamma"])

    trainer.run(data=data, max_epochs=1)

    captured = capsys.readouterr()
    err = captured.err.split("\r")
    err = list(map(lambda x: x.strip(), err))
    err = list(filter(None, err))
    actual = err[-1]
    if get_tqdm_version() < Version("4.49.0"):
        expected = (
            "Iteration: [1/2]  50%|█████     , batchloss=0.5, alpha=3.9, beta=12.2, gamma_0=21, gamma_1=6 [00:00<00:00]"
        )
    else:
        expected = (
            "Iteration: [1/2]  50%|█████     , batchloss=0.5, alpha=3.9, beta=12.2, gamma_0=21, gamma_1=6 [00:00<?]"
        )
    assert actual == expected
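# get_tqdm_version() and Version come from the surrounding test module;
# presumably something along these lines:
import tqdm
from packaging.version import Version

def get_tqdm_version():
    return Version(tqdm.__version__)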
Example #6
    def build_trainer(self) -> Engine:
        loss_fn: callable = F.nll_loss
        optimizer: torch.optim.Adam = torch.optim.Adam(self.parameters(), 1e-3)
        model = self

        def process_function(engine: Engine, batch: Tuple[torch.Tensor, torch.Tensor, List[int]]) -> \
                Tuple[float, torch.Tensor, torch.Tensor]:
            """Single training loop to be attached to trainer Engine"""
            model.train()
            optimizer.zero_grad()
            x, y, lengths = batch
            x, y = x.to(model.device), y.to(model.device)
            y_pred: torch.Tensor = model(x, lengths)
            loss: torch.Tensor = loss_fn(y_pred, y)
            loss.backward()
            optimizer.step()
            return loss.item(), torch.max(y_pred, dim=1)[1], y

        def eval_function(engine: Engine, batch: Tuple[torch.Tensor, torch.Tensor, List[int]]) -> \
                Tuple[torch.Tensor, torch.Tensor]:
            """Single evaluator loop to be attached to trainer and evaluator Engine"""
            model.eval()
            with torch.no_grad():
                x, y, lengths = batch
                x, y = x.to(model.device), y.to(model.device)
                y_pred: torch.Tensor = model(x, lengths)
                return y_pred, y

        trainer: Engine = Engine(process_function)
        train_evaluator: Engine = Engine(eval_function)
        validation_evaluator: Engine = Engine(eval_function)
        ConcatPoolingGRUAdaptive.track_progress(train_evaluator, validation_evaluator, loss_fn, trainer)
        pbar = ProgressBar(persist=True, bar_format="")
        pbar.attach(trainer, ['loss', 'acc'])
        self.log_results(train_evaluator, validation_evaluator, pbar, trainer)
        return trainer
Example #7
def create_segmentation_evaluator(
        model, device,
        num_classes=19,
        loss_fn=None,
        non_blocking=True):

    cm = partial(ConfusionMatrix, num_classes)

    metrics = {
        'iou': IoU(cm()),
        'miou': mIoU(cm()),
        'accuracy': cmAccuracy(cm()),
        'dice': DiceCoefficient(cm()),
    }
    if loss_fn is not None:
        metrics['loss'] = Loss(loss_fn)

    evaluator = create_supervised_evaluator(
        model, metrics, device, non_blocking)

    ProgressBar(persist=False).attach(evaluator)

    return evaluator
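# Hypothetical usage sketch; model, device and val_loader are assumptions.
evaluator = create_segmentation_evaluator(model, device, num_classes=19)
state = evaluator.run(val_loader)
print(state.metrics["miou"], state.metrics["accuracy"])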
Example #8
def thresholded_output_transform(output):
    # Round probabilities to hard 0/1 predictions before Accuracy sees them.
    y_pred, y = output
    y_pred = torch.round(y_pred)
    return y_pred, y


RunningAverage(output_transform=lambda x: x).attach(trainer, 'loss')
# validation
Accuracy(output_transform=thresholded_output_transform).attach(
    train_evaluator, 'accuracy')
Loss(criterion).attach(train_evaluator, 'bce')
# test
Accuracy(output_transform=thresholded_output_transform).attach(
    validation_evaluator, 'accuracy')
Loss(criterion).attach(validation_evaluator, 'bce')

pbar = ProgressBar(persist=True, bar_format="")
pbar.attach(trainer, ['loss'])

trainer.run(train_iterator, max_epochs=N_EPOCH)


def binary_accuracy(preds, y):
    """
    Returns accuracy per batch, i.e. if you get 8/10 right, this returns 0.8, NOT 8
    """

    #round predictions to the closest integer
    rounded_preds = torch.round(torch.sigmoid(preds))
    correct = (rounded_preds == y).float()  #convert into float for division
    acc = correct.sum() / len(correct)
    return acc
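# Quick worked check of binary_accuracy on toy logits: sigmoid + round maps
# [2.0, -1.0, 0.5, -3.0] to [1, 0, 1, 0], so 3 of 4 targets match.
import torch

preds = torch.tensor([2.0, -1.0, 0.5, -3.0])
targets = torch.tensor([1.0, 0.0, 0.0, 0.0])
print(binary_accuracy(preds, targets))  # tensor(0.7500)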
Example #9
def train():
    parser = ArgumentParser()
    parser.add_argument("--dataset_path", type=str, default="", help="Path or url of the dataset. If empty download from S3.")
    parser.add_argument("--dataset_cache", type=str, default='./dataset_cache', help="Path or url of the dataset cache")
    parser.add_argument("--model_checkpoint", type=str, default="openai-gpt", help="Path, url or short name of the model")
    parser.add_argument("--num_candidates", type=int, default=2, help="Number of candidates for training")
    parser.add_argument("--max_history", type=int, default=2, help="Number of previous exchanges to keep in history")
    parser.add_argument("--train_batch_size", type=int, default=4, help="Batch size for training")
    parser.add_argument("--valid_batch_size", type=int, default=4, help="Batch size for validation")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=8, help="Accumulate gradients on several steps")
    parser.add_argument("--lr", type=float, default=6.25e-5, help="Learning rate")
    parser.add_argument("--lm_coef", type=float, default=1.0, help="LM loss coefficient")
    parser.add_argument("--mc_coef", type=float, default=1.0, help="Multiple-choice loss coefficient")
    parser.add_argument("--max_norm", type=float, default=1.0, help="Clipping gradient norm")
    parser.add_argument("--n_epochs", type=int, default=3, help="Number of training epochs")
    parser.add_argument("--personality_permutations", type=int, default=1, help="Number of permutations of personality sentences")
    parser.add_argument("--eval_before_start", action='store_true', help="If true start with a first evaluation before training")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device (cuda or cpu)")
    parser.add_argument("--fp16", type=str, default="", help="Set to O0, O1, O2 or O3 for fp16 training (see apex documentation)")
    parser.add_argument("--local_rank", type=int, default=-1, help="Local rank for distributed training (-1: not distributed)")
    args = parser.parse_args()

    # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process. logger.info => log main process only, logger.warning => log all processes
    logging.basicConfig(level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning("Running process %d", args.local_rank)  # This is a logger.warning: it will be printed by all distributed processes
    logger.info("Arguments: %s", pformat(args))

    # Initialize distributed training if needed
    args.distributed = (args.local_rank != -1)
    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        args.device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')

    logger.info("Prepare tokenizer, pretrained model and optimizer - add special tokens for fine-tuning")
    tokenizer_class = GPT2Tokenizer if "gpt2" in args.model_checkpoint else OpenAIGPTTokenizer
    tokenizer = tokenizer_class.from_pretrained(args.model_checkpoint)
    model_class = GPT2LMHeadModel if "gpt2" in args.model_checkpoint else OpenAIGPTLMHeadModel
    model = model_class.from_pretrained(args.model_checkpoint)
    tokenizer.set_special_tokens(SPECIAL_TOKENS)
    model.set_num_special_tokens(len(SPECIAL_TOKENS))
    model.to(args.device)
    optimizer = OpenAIAdam(model.parameters(), lr=args.lr)

    # Prepare model for FP16 and distributed training if needed (order is important, distributed should be the last)
    if args.fp16:
        from apex import amp  # Apex is only required if we use fp16 training
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16)
    if args.distributed:
        model = DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank)

    logger.info("Prepare datasets")
    train_loader, val_loader, train_sampler, valid_sampler = get_data_loaders(args, tokenizer)

    # Training function and trainer
    def update(engine, batch):
        model.train()
        batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
        lm_loss, mc_loss = model(*batch)
        loss = (lm_loss * args.lm_coef + mc_loss * args.mc_coef) / args.gradient_accumulation_steps
        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_norm)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
        if engine.state.iteration % args.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        return loss.item()
    trainer = Engine(update)

    # Evaluation function and evaluator (evaluator output is the input of the metrics)
    def inference(engine, batch):
        model.eval()
        with torch.no_grad():
            batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
            input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids = batch
            logger.info(tokenizer.decode(input_ids[0, -1, :].tolist()))
            model_outputs = model(input_ids, mc_token_ids, token_type_ids=token_type_ids)
            lm_logits, mc_logits = model_outputs[0], model_outputs[1]  # So we can also use GPT2 outputs
            lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view(-1, lm_logits.size(-1))
            lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1)
            return (lm_logits_flat_shifted, mc_logits), (lm_labels_flat_shifted, mc_labels)
    evaluator = Engine(inference)

    # Attach evaluation to trainer: we evaluate when we start the training and at the end of each epoch
    trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: evaluator.run(val_loader))
    if args.n_epochs < 1:
        trainer.add_event_handler(Events.COMPLETED, lambda _: evaluator.run(val_loader))
    if args.eval_before_start:
        trainer.add_event_handler(Events.STARTED, lambda _: evaluator.run(val_loader))

    # Make sure distributed data samplers split the dataset nicely between the distributed processes
    if args.distributed:
        trainer.add_event_handler(Events.EPOCH_STARTED, lambda engine: train_sampler.set_epoch(engine.state.epoch))
        evaluator.add_event_handler(Events.EPOCH_STARTED, lambda engine: valid_sampler.set_epoch(engine.state.epoch))

    # Linearly decrease the learning rate from lr to zero
    scheduler = PiecewiseLinear(optimizer, "lr", [(0, args.lr), (args.n_epochs * len(train_loader), 0.0)])
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # Prepare metrics - note how we compute distributed metrics 
    RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
    metrics = {"nll": Loss(torch.nn.CrossEntropyLoss(ignore_index=-1), output_transform=lambda x: (x[0][0], x[1][0])),
               "accuracy": Accuracy(output_transform=lambda x: (x[0][1], x[1][1]))}
    metrics.update({"average_nll": MetricsLambda(average_distributed_scalar, metrics["nll"], args),
                    "average_accuracy": MetricsLambda(average_distributed_scalar, metrics["accuracy"], args)})
    metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"])
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # On the main process: add progress bar, tensorboard, checkpoints and save model, configuration and tokenizer before we start to train
    if args.local_rank in [-1, 0]:
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer, metric_names=["loss"])
        evaluator.add_event_handler(Events.COMPLETED, lambda _: pbar.log_message("Validation: %s" % pformat(evaluator.state.metrics)))

        tb_logger = TensorboardLogger(log_dir=None)
        tb_logger.attach(trainer, log_handler=OutputHandler(tag="training", metric_names=["loss"]), event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer, log_handler=OptimizerParamsHandler(optimizer), event_name=Events.ITERATION_STARTED)
        tb_logger.attach(evaluator, log_handler=OutputHandler(tag="validation", metric_names=list(metrics.keys()), another_engine=trainer), event_name=Events.EPOCH_COMPLETED)

        checkpoint_handler = ModelCheckpoint(tb_logger.writer.log_dir, 'checkpoint', save_interval=1, n_saved=3)
        trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler, {'mymodel': getattr(model, 'module', model)})  # "getattr" takes care of distributed encapsulation

        torch.save(args, tb_logger.writer.log_dir + '/model_training_args.bin')
        getattr(model, 'module', model).config.to_json_file(os.path.join(tb_logger.writer.log_dir, CONFIG_NAME))
        tokenizer.save_vocabulary(tb_logger.writer.log_dir)

    # Run the training
    trainer.run(train_loader, max_epochs=args.n_epochs)

    # On the main process: close tensorboard logger and rename the last checkpoint (for easy re-loading with OpenAIGPTModel.from_pretrained method)
    if args.local_rank in [-1, 0] and args.n_epochs > 0:
        os.rename(checkpoint_handler._saved[-1][1][-1], os.path.join(tb_logger.writer.log_dir, WEIGHTS_NAME))  # TODO: PR in ignite to have better access to saved file paths (cleaner)
        tb_logger.close()
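# average_distributed_scalar is defined elsewhere in this script; inferred from
# the MetricsLambda calls above, it presumably all-reduces a scalar metric
# across processes (a sketch, not the script's exact code):
def average_distributed_scalar(scalar, args):
    if args.local_rank == -1:  # not distributed: nothing to average
        return scalar
    scalar_t = torch.tensor(scalar, dtype=torch.float, device=args.device)
    torch.distributed.all_reduce(scalar_t, op=torch.distributed.ReduceOp.SUM)
    return scalar_t.item() / torch.distributed.get_world_size()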
Example #10
def create_trainer(
    train_step,
    output_names,
    model,
    ema_model,
    optimizer,
    lr_scheduler,
    supervised_train_loader,
    test_loader,
    cfg,
    logger,
    cta=None,
    unsup_train_loader=None,
    cta_probe_loader=None,
):

    trainer = Engine(train_step)
    trainer.logger = logger

    output_path = os.getcwd()

    to_save = {
        "model": model,
        "ema_model": ema_model,
        "optimizer": optimizer,
        "trainer": trainer,
        "lr_scheduler": lr_scheduler,
    }
    if cta is not None:
        to_save["cta"] = cta

    common.setup_common_training_handlers(
        trainer,
        train_sampler=supervised_train_loader.sampler,
        to_save=to_save,
        save_every_iters=cfg.solver.checkpoint_every,
        output_path=output_path,
        output_names=output_names,
        lr_scheduler=lr_scheduler,
        with_pbars=False,
        clear_cuda_cache=False,
    )

    ProgressBar(persist=False).attach(
        trainer, metric_names="all", event_name=Events.ITERATION_COMPLETED
    )

    unsupervised_train_loader_iter = None
    if unsup_train_loader is not None:
        unsupervised_train_loader_iter = cycle(unsup_train_loader)

    cta_probe_loader_iter = None
    if cta_probe_loader is not None:
        cta_probe_loader_iter = cycle(cta_probe_loader)

    # Setup handler to prepare data batches
    @trainer.on(Events.ITERATION_STARTED)
    def prepare_batch(e):
        sup_batch = e.state.batch
        e.state.batch = {
            "sup_batch": sup_batch,
        }
        if unsupervised_train_loader_iter is not None:
            unsup_batch = next(unsupervised_train_loader_iter)
            e.state.batch["unsup_batch"] = unsup_batch

        if cta_probe_loader_iter is not None:
            cta_probe_batch = next(cta_probe_loader_iter)
            cta_probe_batch["policy"] = [
                deserialize(p) for p in cta_probe_batch["policy"]
            ]
            e.state.batch["cta_probe_batch"] = cta_probe_batch

    # Setup handler to update EMA model
    @trainer.on(Events.ITERATION_COMPLETED, cfg.ema_decay)
    def update_ema_model(ema_decay):
        # EMA on parameters
        for ema_param, param in zip(ema_model.parameters(), model.parameters()):
            ema_param.data.mul_(ema_decay).add_(param.data, alpha=1.0 - ema_decay)

    # Setup handlers for debugging
    if cfg.debug:

        @trainer.on(Events.STARTED | Events.ITERATION_COMPLETED(every=100))
        @idist.one_rank_only()
        def log_weights_norms():
            wn = []
            ema_wn = []
            for ema_param, param in zip(ema_model.parameters(), model.parameters()):
                wn.append(torch.mean(param.data))
                ema_wn.append(torch.mean(ema_param.data))

            msg = "\n\nWeights norms"
            msg += "\n- Raw model: {}".format(
                to_list_str(torch.tensor(wn[:10] + wn[-10:]))
            )
            msg += "\n- EMA model: {}\n".format(
                to_list_str(torch.tensor(ema_wn[:10] + ema_wn[-10:]))
            )
            logger.info(msg)

            rmn = []
            rvar = []
            ema_rmn = []
            ema_rvar = []
            for m1, m2 in zip(model.modules(), ema_model.modules()):
                if isinstance(m1, nn.BatchNorm2d) and isinstance(m2, nn.BatchNorm2d):
                    rmn.append(torch.mean(m1.running_mean))
                    rvar.append(torch.mean(m1.running_var))
                    ema_rmn.append(torch.mean(m2.running_mean))
                    ema_rvar.append(torch.mean(m2.running_var))

            msg = "\n\nBN buffers"
            msg += "\n- Raw mean: {}".format(to_list_str(torch.tensor(rmn[:10])))
            msg += "\n- Raw var: {}".format(to_list_str(torch.tensor(rvar[:10])))
            msg += "\n- EMA mean: {}".format(to_list_str(torch.tensor(ema_rmn[:10])))
            msg += "\n- EMA var: {}\n".format(to_list_str(torch.tensor(ema_rvar[:10])))
            logger.info(msg)

        # TODO: Need to inspect a bug
        # if idist.get_rank() == 0:
        #     from ignite.contrib.handlers import ProgressBar
        #
        #     profiler = BasicTimeProfiler()
        #     profiler.attach(trainer)
        #
        #     @trainer.on(Events.ITERATION_COMPLETED(every=200))
        #     def log_profiling(_):
        #         results = profiler.get_results()
        #         profiler.print_results(results)

    # Setup validation engine
    metrics = {
        "accuracy": Accuracy(),
    }

    if not (idist.has_xla_support and idist.backend() == idist.xla.XLA_TPU):
        metrics.update({
            "precision": Precision(average=False),
            "recall": Recall(average=False),
        })

    eval_kwargs = dict(
        metrics=metrics,
        prepare_batch=sup_prepare_batch,
        device=idist.device(),
        non_blocking=True,
    )

    evaluator = create_supervised_evaluator(model, **eval_kwargs)
    ema_evaluator = create_supervised_evaluator(ema_model, **eval_kwargs)

    def log_results(epoch, max_epochs, metrics, ema_metrics):
        msg1 = "\n".join(
            ["\t{:16s}: {}".format(k, to_list_str(v)) for k, v in metrics.items()]
        )
        msg2 = "\n".join(
            ["\t{:16s}: {}".format(k, to_list_str(v)) for k, v in ema_metrics.items()]
        )
        logger.info(
            "\nEpoch {}/{}\nRaw:\n{}\nEMA:\n{}\n".format(epoch, max_epochs, msg1, msg2)
        )
        if cta is not None:
            logger.info("\n" + stats(cta))

    @trainer.on(
        Events.EPOCH_COMPLETED(every=cfg.solver.validate_every)
        | Events.STARTED
        | Events.COMPLETED
    )
    def run_evaluation():
        evaluator.run(test_loader)
        ema_evaluator.run(test_loader)
        log_results(
            trainer.state.epoch,
            trainer.state.max_epochs,
            evaluator.state.metrics,
            ema_evaluator.state.metrics,
        )

    # setup TB logging
    if idist.get_rank() == 0:
        tb_logger = common.setup_tb_logging(
            output_path,
            trainer,
            optimizers=optimizer,
            evaluators={"validation": evaluator, "ema validation": ema_evaluator},
            log_every_iters=15,
        )
        if cfg.online_exp_tracking.wandb:
            from ignite.contrib.handlers import WandBLogger

            wb_dir = Path("/tmp/output-fixmatch-wandb")
            if not wb_dir.exists():
                wb_dir.mkdir()

            _ = WandBLogger(
                project="fixmatch-pytorch",
                name=cfg.name,
                config=cfg,
                sync_tensorboard=True,
                dir=wb_dir.as_posix(),
                reinit=True,
            )

    resume_from = cfg.solver.resume_from
    if resume_from is not None:
        resume_from = list(Path(resume_from).rglob("training_checkpoint*.pt*"))
        if len(resume_from) > 0:
            # get latest
            checkpoint_fp = max(resume_from, key=lambda p: p.stat().st_mtime)
            assert checkpoint_fp.exists(), "Checkpoint '{}' is not found".format(
                checkpoint_fp.as_posix()
            )
            logger.info("Resume from a checkpoint: {}".format(checkpoint_fp.as_posix()))
            checkpoint = torch.load(checkpoint_fp.as_posix())
            Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    @trainer.on(Events.COMPLETED)
    def release_all_resources():
        nonlocal unsupervised_train_loader_iter, cta_probe_loader_iter

        if idist.get_rank() == 0:
            tb_logger.close()

        if unsupervised_train_loader_iter is not None:
            unsupervised_train_loader_iter = None

        if cta_probe_loader_iter is not None:
            cta_probe_loader_iter = None

    return trainer
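# cycle() above is assumed to be a local helper that re-iterates a DataLoader
# forever (itertools.cycle would cache one epoch of batches); a minimal sketch:
def cycle(dataloader):
    while True:
        for batch in dataloader:
            yield batch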
Example #11
def train():
    config_file = "configs/train_daily_dialog_emotion_action_config.json"
    config = Config.from_json_file(config_file)

    # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process. logger.info => log main process only, logger.warning => log all processes
    logging.basicConfig(
        level=logging.INFO if config.local_rank in [-1, 0] else logging.WARN)
    logger.warning(
        "Running process %d", config.local_rank
    )  # This is a logger.warning: it will be printed by all distributed processes
    logger.info("Arguments: %s", pformat(config))

    # Initialize distributed training if needed
    config.distributed = (config.local_rank != -1)
    if config.distributed:
        torch.cuda.set_device(config.local_rank)
        config.device = torch.device("cuda", config.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')

    logger.info(
        "Prepare tokenizer, pretrained model and optimizer - add special tokens for fine-tuning"
    )
    tokenizer_class = GPT2Tokenizer if "gpt2" in config.model_checkpoint else OpenAIGPTTokenizer
    tokenizer = tokenizer_class.from_pretrained(config.model_checkpoint)
    model_class = GPT2DoubleHeadsModel if "gpt2" in config.model_checkpoint else OpenAIGPTDoubleHeadsModel
    model = model_class.from_pretrained(config.model_checkpoint)
    tokenizer.set_special_tokens(SPECIAL_TOKENS)
    model.set_num_special_tokens(len(SPECIAL_TOKENS))
    model.to(config.device)
    optimizer = OpenAIAdam(model.parameters(), lr=config.lr)

    # Prepare model for FP16 and distributed training if needed (order is important, distributed should be the last)
    if config.fp16:
        from apex import amp  # Apex is only required if we use fp16 training
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=config.fp16)
    if config.distributed:
        model = DistributedDataParallel(model,
                                        device_ids=[config.local_rank],
                                        output_device=config.local_rank)

    logger.info("Prepare datasets")
    train_loader, val_loader, train_sampler, valid_sampler = get_data_loaders(
        config, tokenizer)

    # Training function and trainer
    def update(engine, batch):
        model.train()
        input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids, token_emotion_ids, token_action_ids = tuple(
            input_tensor.to(config.device) for input_tensor in batch)
        lm_loss, mc_loss = model(input_ids, mc_token_ids, lm_labels, mc_labels,
                                 token_type_ids, token_emotion_ids,
                                 token_action_ids)
        loss = (lm_loss * config.lm_coef +
                mc_loss * config.mc_coef) / config.gradient_accumulation_steps
        if config.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                           config.max_norm)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_norm)
        if engine.state.iteration % config.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        return loss.item()

    trainer = Engine(update)

    # Evaluation function and evaluator (evaluator output is the input of the metrics)
    def inference(engine, batch):
        model.eval()
        with torch.no_grad():
            batch = tuple(
                input_tensor.to(config.device) for input_tensor in batch)
            input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids, token_emotion_ids, token_action_ids = batch
            #logger.info(tokenizer.decode(input_ids[0, -1, :].tolist()))
            model_outputs = model(input_ids,
                                  mc_token_ids,
                                  token_type_ids=token_type_ids,
                                  token_emotion_ids=token_emotion_ids,
                                  token_action_ids=token_action_ids)
            lm_logits, mc_logits = model_outputs[0], model_outputs[
                1]  # So we can also use GPT2 outputs
            lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view(
                -1, lm_logits.size(-1))
            lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1)
            return (lm_logits_flat_shifted,
                    mc_logits), (lm_labels_flat_shifted, mc_labels)

    evaluator = Engine(inference)

    # Attach evaluation to trainer: we evaluate when we start the training and at the end of each epoch
    trainer.add_event_handler(Events.EPOCH_COMPLETED,
                              lambda _: evaluator.run(val_loader))
    if config.n_epochs < 1:
        trainer.add_event_handler(Events.COMPLETED,
                                  lambda _: evaluator.run(val_loader))
    if config.eval_before_start:
        trainer.add_event_handler(Events.STARTED,
                                  lambda _: evaluator.run(val_loader))

    # Make sure distributed data samplers split the dataset nicely between the distributed processes
    if config.distributed:
        trainer.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: train_sampler.set_epoch(engine.state.epoch))
        evaluator.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: valid_sampler.set_epoch(engine.state.epoch))

    # Linearly decrease the learning rate from lr to zero
    scheduler = PiecewiseLinear(optimizer, "lr",
                                [(0, config.lr),
                                 (config.n_epochs * len(train_loader), 0.0)])
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # Prepare metrics - note how we compute distributed metrics
    RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
    metrics = {
        "nll":
        Loss(torch.nn.CrossEntropyLoss(ignore_index=-1),
             output_transform=lambda x: (x[0][0], x[1][0])),
        "accuracy":
        Accuracy(output_transform=lambda x: (x[0][1], x[1][1]))
    }
    metrics.update({
        "average_nll":
        MetricsLambda(average_distributed_scalar, metrics["nll"], config),
        "average_accuracy":
        MetricsLambda(average_distributed_scalar, metrics["accuracy"], config)
    })
    metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"])
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # On the main process: add progress bar, tensorboard, checkpoints and save model, configuration and tokenizer before we start to train
    if config.local_rank in [-1, 0]:
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer, metric_names=["loss"])
        evaluator.add_event_handler(
            Events.COMPLETED, lambda _: pbar.log_message(
                "Validation: %s" % pformat(evaluator.state.metrics)))

        tb_logger = TensorboardLogger(log_dir=config.log_dir)
        tb_logger.attach(trainer,
                         log_handler=OutputHandler(tag="training",
                                                   metric_names=["loss"]),
                         event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer,
                         log_handler=OptimizerParamsHandler(optimizer),
                         event_name=Events.ITERATION_STARTED)
        tb_logger.attach(evaluator,
                         log_handler=OutputHandler(tag="validation",
                                                   metric_names=list(
                                                       metrics.keys()),
                                                   another_engine=trainer),
                         event_name=Events.EPOCH_COMPLETED)

        checkpoint_handler = ModelCheckpoint(tb_logger.writer.log_dir,
                                             'checkpoint',
                                             save_interval=1,
                                             n_saved=3)
        trainer.add_event_handler(
            Events.EPOCH_COMPLETED, checkpoint_handler,
            {'mymodel': getattr(model, 'module', model)
             })  # "getattr" takes care of distributed encapsulation

        torch.save(config,
                   tb_logger.writer.log_dir + '/model_training_args.bin')
        getattr(model, 'module', model).config.to_json_file(
            os.path.join(tb_logger.writer.log_dir, CONFIG_NAME))
        tokenizer.save_vocabulary(tb_logger.writer.log_dir)

    # Run the training
    trainer.run(train_loader, max_epochs=config.n_epochs)

    # On the main process: close tensorboard logger and rename the last checkpoint (for easy re-loading with OpenAIGPTModel.from_pretrained method)
    if config.local_rank in [-1, 0] and config.n_epochs > 0:
        os.rename(
            checkpoint_handler._saved[-1][1][-1],
            os.path.join(tb_logger.writer.log_dir, WEIGHTS_NAME)
        )  # TODO: PR in ignite to have better access to saved file paths (cleaner)
        tb_logger.close()
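# Config.from_json_file is a project-local helper; presumably it exposes the
# JSON keys as attributes, roughly like this sketch:
import json

class Config:
    def __init__(self, **entries):
        self.__dict__.update(entries)

    @classmethod
    def from_json_file(cls, path):
        with open(path) as f:
            return cls(**json.load(f))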
Example #12
def main(dataset, dataroot, download, augment, batch_size, eval_batch_size,
         epochs, saved_model, seed, hidden_channels, K, L, actnorm_scale,
         flow_permutation, flow_coupling, LU_decomposed, learn_top,
         y_condition, y_weight, max_grad_clip, max_grad_norm, lr, n_workers,
         cuda, n_init_batches, warmup_steps, output_dir, saved_optimizer,
         warmup, fresh, logittransform, gan, disc_lr, sn, flowgan, eval_every,
         ld_on_samples, weight_gan, weight_prior, weight_logdet,
         jac_reg_lambda, affine_eps, no_warm_up, optim_name, clamp, svd_every,
         eval_only, no_actnorm, affine_scale_eps, actnorm_max_scale,
         no_conv_actnorm, affine_max_scale, actnorm_eps, init_sample, no_split,
         disc_arch, weight_entropy_reg, db):

    check_manual_seed(seed)

    ds = check_dataset(dataset, dataroot, augment, download)
    image_shape, num_classes, train_dataset, test_dataset = ds

    # Note: unsupported for now
    multi_class = False

    train_loader = data.DataLoader(train_dataset,
                                   batch_size=batch_size,
                                   shuffle=True,
                                   num_workers=n_workers,
                                   drop_last=True)
    test_loader = data.DataLoader(test_dataset,
                                  batch_size=eval_batch_size,
                                  shuffle=False,
                                  num_workers=n_workers,
                                  drop_last=False)
    model = Glow(image_shape, hidden_channels, K, L, actnorm_scale,
                 flow_permutation, flow_coupling, LU_decomposed, num_classes,
                 learn_top, y_condition, logittransform, sn, affine_eps,
                 no_actnorm, affine_scale_eps, actnorm_max_scale,
                 no_conv_actnorm, affine_max_scale, actnorm_eps, no_split)

    model = model.to(device)

    if disc_arch == 'mine':
        discriminator = mine.Discriminator(image_shape[-1])
    elif disc_arch == 'biggan':
        discriminator = cgan_models.Discriminator(
            image_channels=image_shape[-1], conditional_D=False)
    elif disc_arch == 'dcgan':
        discriminator = DCGANDiscriminator(image_shape[0], 64, image_shape[-1])
    elif disc_arch == 'inv':
        discriminator = InvDiscriminator(
            image_shape, hidden_channels, K, L, actnorm_scale,
            flow_permutation, flow_coupling, LU_decomposed, num_classes,
            learn_top, y_condition, logittransform, sn, affine_eps, no_actnorm,
            affine_scale_eps, actnorm_max_scale, no_conv_actnorm,
            affine_max_scale, actnorm_eps, no_split)

    discriminator = discriminator.to(device)
    D_optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                    discriminator.parameters()),
                             lr=disc_lr,
                             betas=(.5, .99),
                             weight_decay=0)
    if optim_name == 'adam':
        optimizer = optim.Adam(model.parameters(),
                               lr=lr,
                               betas=(.5, .99),
                               weight_decay=0)
    elif optim_name == 'adamax':
        optimizer = optim.Adamax(model.parameters(), lr=lr, weight_decay=5e-5)

    if not no_warm_up:
        lr_lambda = lambda epoch: min(1.0, (epoch + 1) / warmup)
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                      lr_lambda=lr_lambda)

    iteration_fieldnames = [
        'global_iteration', 'fid', 'sample_pad', 'train_bpd', 'eval_bpd',
        'pad', 'batch_real_acc', 'batch_fake_acc', 'batch_acc'
    ]
    iteration_logger = CSVLogger(fieldnames=iteration_fieldnames,
                                 filename=os.path.join(output_dir,
                                                       'iteration_log.csv'))
    iteration_fieldnames = [
        'global_iteration', 'condition_num', 'max_sv', 'min_sv',
        'inverse_condition_num', 'inverse_max_sv', 'inverse_min_sv'
    ]
    svd_logger = CSVLogger(fieldnames=iteration_fieldnames,
                           filename=os.path.join(output_dir, 'svd_log.csv'))

    test_iter = iter(test_loader)
    N_inception = 1000
    x_real_inception = torch.cat([
        next(test_iter)[0].to(device)
        for _ in range(N_inception // args.batch_size + 1)
    ], 0)[:N_inception]
    x_real_inception = x_real_inception + .5
    x_for_recon = next(test_iter)[0].to(device)

    def gan_step(engine, batch):
        assert not y_condition
        if 'iter_ind' in dir(engine):
            engine.iter_ind += 1
        else:
            engine.iter_ind = -1
        losses = {}
        model.train()
        discriminator.train()

        x, y = batch
        x = x.to(device)

        def run_noised_disc(discriminator, x):
            x = uniform_binning_correction(x)[0]
            return discriminator(x)

        real_acc = fake_acc = acc = 0
        if weight_gan > 0:
            fake = generate_from_noise(model, x.size(0), clamp=clamp)

            D_real_scores = run_noised_disc(discriminator, x.detach())
            D_fake_scores = run_noised_disc(discriminator, fake.detach())

            ones_target = torch.ones((x.size(0), 1), device=x.device)
            zeros_target = torch.zeros((x.size(0), 1), device=x.device)

            D_real_accuracy = torch.sum(
                torch.round(torch.sigmoid(D_real_scores)) ==
                ones_target).float() / ones_target.size(0)
            D_fake_accuracy = torch.sum(
                torch.round(torch.sigmoid(D_fake_scores)) ==
                zeros_target).float() / zeros_target.size(0)

            D_real_loss = F.binary_cross_entropy_with_logits(
                D_real_scores, ones_target)
            D_fake_loss = F.binary_cross_entropy_with_logits(
                D_fake_scores, zeros_target)

            D_loss = (D_real_loss + D_fake_loss) / 2
            gp = gradient_penalty(
                x.detach(), fake.detach(),
                lambda _x: run_noised_disc(discriminator, _x))
            D_loss_plus_gp = D_loss + 10 * gp
            D_optimizer.zero_grad()
            D_loss_plus_gp.backward()
            D_optimizer.step()

            # Train generator
            fake = generate_from_noise(model,
                                       x.size(0),
                                       clamp=clamp,
                                       guard_nans=False)
            G_loss = F.binary_cross_entropy_with_logits(
                run_noised_disc(discriminator, fake),
                torch.ones((x.size(0), 1), device=x.device))

            # Trace
            real_acc = D_real_accuracy.item()
            fake_acc = D_fake_accuracy.item()
            acc = .5 * (D_fake_accuracy.item() + D_real_accuracy.item())

        z, nll, y_logits, (prior, logdet) = model.forward(x,
                                                          None,
                                                          return_details=True)
        train_bpd = nll.mean().item()

        loss = 0
        if weight_gan > 0:
            loss = loss + weight_gan * G_loss
        if weight_prior > 0:
            loss = loss + weight_prior * -prior.mean()
        if weight_logdet > 0:
            loss = loss + weight_logdet * -logdet.mean()

        if weight_entropy_reg > 0:
            _, _, _, (sample_prior,
                      sample_logdet) = model.forward(fake,
                                                     None,
                                                     return_details=True)
            # notice this is actually "decreasing" sample likelihood.
            loss = loss + weight_entropy_reg * (sample_prior.mean() +
                                                sample_logdet.mean())
        # Jac Reg
        if jac_reg_lambda > 0:
            # Sample
            x_samples = generate_from_noise(model,
                                            args.batch_size,
                                            clamp=clamp).detach()
            x_samples.requires_grad_()
            z = model.forward(x_samples, None, return_details=True)[0]
            other_zs = torch.cat([
                split._last_z2.view(x.size(0), -1)
                for split in model.flow.splits
            ], -1)
            all_z = torch.cat([other_zs, z.view(x.size(0), -1)], -1)
            sample_foward_jac = compute_jacobian_regularizer(x_samples,
                                                             all_z,
                                                             n_proj=1)
            _, c2, h, w = model.prior_h.shape
            c = c2 // 2
            zshape = (batch_size, c, h, w)
            randz = torch.randn(zshape).to(device)
            randz = torch.autograd.Variable(randz, requires_grad=True)
            images = model(z=randz,
                           y_onehot=None,
                           temperature=1,
                           reverse=True,
                           batch_size=0)
            other_zs = [split._last_z2 for split in model.flow.splits]
            all_z = [randz] + other_zs
            sample_inverse_jac = compute_jacobian_regularizer_manyinputs(
                all_z, images, n_proj=1)

            # Data
            x.requires_grad_()
            z = model.forward(x, None, return_details=True)[0]
            other_zs = torch.cat([
                split._last_z2.view(x.size(0), -1)
                for split in model.flow.splits
            ], -1)
            all_z = torch.cat([other_zs, z.view(x.size(0), -1)], -1)
            data_foward_jac = compute_jacobian_regularizer(x, all_z, n_proj=1)
            _, c2, h, w = model.prior_h.shape
            c = c2 // 2
            zshape = (batch_size, c, h, w)
            z.requires_grad_()
            images = model(z=z,
                           y_onehot=None,
                           temperature=1,
                           reverse=True,
                           batch_size=0)
            other_zs = [split._last_z2 for split in model.flow.splits]
            all_z = [z] + other_zs
            data_inverse_jac = compute_jacobian_regularizer_manyinputs(
                all_z, images, n_proj=1)

            # loss = loss + jac_reg_lambda * (sample_foward_jac + sample_inverse_jac )
            loss = loss + jac_reg_lambda * (sample_foward_jac +
                                            sample_inverse_jac +
                                            data_foward_jac + data_inverse_jac)

        if not eval_only:
            optimizer.zero_grad()
            loss.backward()
            if not db:
                assert max_grad_clip == max_grad_norm == 0
            if max_grad_clip > 0:
                torch.nn.utils.clip_grad_value_(model.parameters(),
                                                max_grad_clip)
            if max_grad_norm > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               max_grad_norm)

            # Replace NaN gradient with 0
            for p in model.parameters():
                if p.requires_grad and p.grad is not None:
                    g = p.grad.data
                    g[g != g] = 0

            optimizer.step()

        if engine.iter_ind % 100 == 0:
            with torch.no_grad():
                fake = generate_from_noise(model, x.size(0), clamp=clamp)
                z = model.forward(fake, None, return_details=True)[0]
            print("Z max min")
            print(z.max().item(), z.min().item())
            if (fake != fake).float().sum() > 0:
                title = 'NaNs'
            else:
                title = "Good"
            grid = make_grid((postprocess(fake.detach().cpu(), dataset)[:30]),
                             nrow=6).permute(1, 2, 0)
            plt.figure(figsize=(10, 10))
            plt.imshow(grid)
            plt.axis('off')
            plt.title(title)
            plt.savefig(
                os.path.join(output_dir, f'sample_{engine.iter_ind}.png'))

        if engine.iter_ind % eval_every == 0:

            def check_all_zero_except_leading(x):
                return x % 10**np.floor(np.log10(x)) == 0

            if engine.iter_ind == 0 or check_all_zero_except_leading(
                    engine.iter_ind):
                torch.save(
                    model.state_dict(),
                    os.path.join(output_dir, f'ckpt_sd_{engine.iter_ind}.pt'))

            model.eval()

            with torch.no_grad():
                # Plot recon
                fpath = os.path.join(output_dir, '_recon',
                                     f'recon_{engine.iter_ind}.png')
                sample_pad = run_recon_evolution(
                    model,
                    generate_from_noise(model, args.batch_size,
                                        clamp=clamp).detach(), fpath)
                print(
                    f"Iter: {engine.iter_ind}, Recon Sample PAD: {sample_pad}")

                pad = run_recon_evolution(model, x_for_recon, fpath)
                print(f"Iter: {engine.iter_ind}, Recon PAD: {pad}")
                pad = pad.item()
                sample_pad = sample_pad.item()

                # Inception score
                sample = torch.cat([
                    generate_from_noise(model, args.batch_size, clamp=clamp)
                    for _ in range(N_inception // args.batch_size + 1)
                ], 0)[:N_inception]
                sample = sample + .5

                if (sample != sample).float().sum() > 0:
                    print("Sample NaNs")
                    raise RuntimeError("NaNs in generated samples")
                else:
                    fid = run_fid(x_real_inception.clamp_(0, 1),
                                  sample.clamp_(0, 1))
                    print(f'fid: {fid}, global_iter: {engine.iter_ind}')

                # Eval BPD
                eval_bpd = np.mean([
                    model.forward(x.to(device), None,
                                  return_details=True)[1].mean().item()
                    for x, _ in test_loader
                ])

                stats_dict = {
                    'global_iteration': engine.iter_ind,
                    'fid': fid,
                    'train_bpd': train_bpd,
                    'pad': pad,
                    'eval_bpd': eval_bpd,
                    'sample_pad': sample_pad,
                    'batch_real_acc': real_acc,
                    'batch_fake_acc': fake_acc,
                    'batch_acc': acc
                }
                iteration_logger.writerow(stats_dict)
                plot_csv(iteration_logger.filename)
            model.train()

        if (engine.iter_ind + 2) % svd_every == 0:
            model.eval()
            svd_dict = {}
            ret = utils.computeSVDjacobian(x_for_recon, model)
            D_for, D_inv = ret['D_for'], ret['D_inv']
            cn = float(D_for.max() / D_for.min())
            cn_inv = float(D_inv.max() / D_inv.min())
            svd_dict['global_iteration'] = engine.iter_ind
            svd_dict['condition_num'] = cn
            svd_dict['max_sv'] = float(D_for.max())
            svd_dict['min_sv'] = float(D_for.min())
            svd_dict['inverse_condition_num'] = cn_inv
            svd_dict['inverse_max_sv'] = float(D_inv.max())
            svd_dict['inverse_min_sv'] = float(D_inv.min())
            svd_logger.writerow(svd_dict)
            # plot_utils.plot_stability_stats(output_dir)
            # plot_utils.plot_individual_figures(output_dir, 'svd_log.csv')
            model.train()
            if eval_only:
                sys.exit()

        # Dummy
        losses['total_loss'] = torch.mean(nll).item()
        return losses

    def eval_step(engine, batch):
        model.eval()

        x, y = batch
        x = x.to(device)

        with torch.no_grad():
            if y_condition:
                y = y.to(device)
                z, nll, y_logits = model(x, y)
                losses = compute_loss_y(nll,
                                        y_logits,
                                        y_weight,
                                        y,
                                        multi_class,
                                        reduction='none')
            else:
                z, nll, y_logits = model(x, None)
                losses = compute_loss(nll, reduction='none')

        return losses

    trainer = Engine(gan_step)
    # else:
    #     trainer = Engine(step)
    checkpoint_handler = ModelCheckpoint(output_dir,
                                         'glow',
                                         save_interval=5,
                                         n_saved=1,
                                         require_empty=False)

    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler, {
        'model': model,
        'optimizer': optimizer
    })

    monitoring_metrics = ['total_loss']
    RunningAverage(output_transform=lambda x: x['total_loss']).attach(
        trainer, 'total_loss')

    evaluator = Engine(eval_step)

    # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
    Loss(lambda x, y: torch.mean(x),
         output_transform=lambda x:
         (x['total_loss'], torch.empty(x['total_loss'].shape[0]))).attach(
             evaluator, 'total_loss')

    if y_condition:
        monitoring_metrics.extend(['nll'])
        RunningAverage(output_transform=lambda x: x['nll']).attach(
            trainer, 'nll')

        # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
        Loss(lambda x, y: torch.mean(x),
             output_transform=lambda x:
             (x['nll'], torch.empty(x['nll'].shape[0]))).attach(
                 evaluator, 'nll')

    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=monitoring_metrics)

    # load pre-trained model if given
    if saved_model:
        print("Loading...")
        print(saved_model)
        loaded = torch.load(saved_model)
        # if 'Glow' in str(type(loaded)):
        #     model  = loaded
        # else:
        #     raise
        # # if 'Glow' in str(type(loaded)):
        # #     loaded  = loaded.state_dict()
        model.load_state_dict(loaded)
        model.set_actnorm_init()

        if saved_optimizer:
            optimizer.load_state_dict(torch.load(saved_optimizer))

        file_name, ext = os.path.splitext(saved_model)
        resume_epoch = int(file_name.split('_')[-1])

        @trainer.on(Events.STARTED)
        def resume_training(engine):
            engine.state.epoch = resume_epoch
            engine.state.iteration = resume_epoch * len(
                engine.state.dataloader)

    @trainer.on(Events.STARTED)
    def init(engine):
        if saved_model:
            return
        model.train()
        print("Initializing Actnorm...")
        init_batches = []
        init_targets = []

        if n_init_batches == 0:
            model.set_actnorm_init()
            return
        with torch.no_grad():
            if init_sample:
                generate_from_noise(model,
                                    args.batch_size * args.n_init_batches)
            else:
                for batch, target in islice(train_loader, None,
                                            n_init_batches):
                    init_batches.append(batch)
                    init_targets.append(target)

                init_batches = torch.cat(init_batches).to(device)

                assert init_batches.shape[0] == n_init_batches * batch_size

                if y_condition:
                    init_targets = torch.cat(init_targets).to(device)
                else:
                    init_targets = None

                model(init_batches, init_targets)

    @trainer.on(Events.EPOCH_COMPLETED)
    def evaluate(engine):
        evaluator.run(test_loader)
        if not no_warm_up:
            scheduler.step()
        metrics = evaluator.state.metrics

        losses = ', '.join(
            [f"{key}: {value:.2f}" for key, value in metrics.items()])

        print(f'Validation Results - Epoch: {engine.state.epoch} {losses}')

    timer = Timer(average=True)
    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        pbar.log_message(
            f'Epoch {engine.state.epoch} done. Time per batch: {timer.value():.3f}[s]'
        )
        timer.reset()

    trainer.run(train_loader, epochs)
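# check_manual_seed is a project-local helper used in this example and the
# next; a sketch of the usual pattern (draw a random seed when none is given,
# seed all RNGs, return the seed actually used):
import random
import torch

def check_manual_seed(seed):
    seed = seed or random.randint(1, 10000)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    return seed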
Example #13
def main(config, needs_save):
    os.environ['CUDA_VISIBLE_DEVICES'] = config.training.visible_devices
    seed = check_manual_seed(config.training.seed)
    print('Using manual seed: {}'.format(seed))

    if config.dataset.patient_ids == 'TRAIN_PATIENT_IDS':
        patient_ids = TRAIN_PATIENT_IDS
    elif config.dataset.patient_ids == 'TEST_PATIENT_IDS':
        patient_ids = TEST_PATIENT_IDS
    else:
        raise NotImplementedError

    data_loader = get_data_loader(
        mode=config.dataset.mode,
        dataset_name=config.dataset.name,
        patient_ids=patient_ids,
        root_dir_path=config.dataset.root_dir_path,
        use_augmentation=config.dataset.use_augmentation,
        batch_size=config.dataset.batch_size,
        num_workers=config.dataset.num_workers,
        image_size=config.dataset.image_size)

    E = Encoder(input_dim=config.model.input_dim,
                z_dim=config.model.z_dim,
                filters=config.model.enc_filters,
                activation=config.model.enc_activation).float()

    D = Decoder(input_dim=config.model.input_dim,
                z_dim=config.model.z_dim,
                filters=config.model.dec_filters,
                activation=config.model.dec_activation,
                final_activation=config.model.dec_final_activation).float()

    if config.model.enc_spectral_norm:
        apply_spectral_norm(E)

    if config.model.dec_spectral_norm:
        apply_spectral_norm(D)

    if config.training.use_cuda:
        E.cuda()
        D.cuda()
        E = nn.DataParallel(E)
        D = nn.DataParallel(D)

    if config.model.saved_E:
        print(config.model.saved_E)
        E.load_state_dict(torch.load(config.model.saved_E))

    if config.model.saved_D:
        print(config.model.saved_D)
        D.load_state_dict(torch.load(config.model.saved_D))

    print(E)
    print(D)

    e_optim = optim.Adam(filter(lambda p: p.requires_grad, E.parameters()),
                         config.optimizer.enc_lr, [0.9, 0.9999])

    d_optim = optim.Adam(filter(lambda p: p.requires_grad, D.parameters()),
                         config.optimizer.dec_lr, [0.9, 0.9999])

    alpha = config.training.alpha
    beta = config.training.beta
    margin = config.training.margin

    batch_size = config.dataset.batch_size
    fixed_z = torch.randn(calc_latent_dim(config))

    if 'ssim' in config.training.loss:
        ssim_loss = pytorch_ssim.SSIM(window_size=11)

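    # All reconstruction losses use reduction='sum'; the SSIM terms are scaled
    # by torch.numel(recon) so they are on the same scale as the summed
    # per-pixel losses before the beta weighting below.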
    def l_recon(recon: torch.Tensor, target: torch.Tensor):
        if config.training.loss == 'l2':
            loss = F.mse_loss(recon, target, reduction='sum')

        elif config.training.loss == 'l1':
            loss = F.l1_loss(recon, target, reduction='sum')

        elif config.training.loss == 'ssim':
            loss = (1.0 - ssim_loss(recon, target)) * torch.numel(recon)

        elif config.training.loss == 'ssim+l1':
            loss = (1.0 - ssim_loss(recon, target)) * torch.numel(recon) \
                 + F.l1_loss(recon, target, reduction='sum')

        elif config.training.loss == 'ssim+l2':
            loss = (1.0 - ssim_loss(recon, target)) * torch.numel(recon) \
                 + F.mse_loss(recon, target, reduction='sum')

        else:
            raise NotImplementedError

        return beta * loss / batch_size

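    # Closed-form KL divergence between N(mu, sigma^2) and the standard normal,
    # i.e. -0.5 * sum(1 + log sigma^2 - mu^2 - sigma^2), averaged over the batch.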
    def l_reg(mu: torch.Tensor, log_var: torch.Tensor):
        loss = -0.5 * torch.sum(1 + log_var - mu**2 - torch.exp(log_var))
        return loss / batch_size

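    # One VAE training step: encode, reconstruct, then backprop the summed
    # KL + reconstruction loss through both networks with their own optimizers.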
    def update(engine, batch):
        E.train()
        D.train()

        image = norm(batch['image'])

        if config.training.use_cuda:
            image = image.cuda(non_blocking=True).float()
        else:
            image = image.float()

        e_optim.zero_grad()
        d_optim.zero_grad()

        z, z_mu, z_logvar = E(image)
        x_r = D(z)

        l_vae_reg = l_reg(z_mu, z_logvar)
        l_vae_recon = l_recon(x_r, image)
        l_vae_total = l_vae_reg + l_vae_recon

        l_vae_total.backward()

        e_optim.step()
        d_optim.step()

        if config.training.use_cuda:
            torch.cuda.synchronize()

        return {
            'TotalLoss': l_vae_total.item(),
            'EncodeLoss': l_vae_reg.item(),
            'ReconLoss': l_vae_recon.item(),
        }

    output_dir = get_output_dir_path(config)
    trainer = Engine(update)
    timer = Timer(average=True)

    monitoring_metrics = ['TotalLoss', 'EncodeLoss', 'ReconLoss']

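    # partial(...) binds `metric` eagerly; a plain lambda in the loop would
    # late-bind and make every transform read the last metric name.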
    for metric in monitoring_metrics:
        RunningAverage(alpha=0.98,
                       output_transform=partial(lambda x, metric: x[metric],
                                                metric=metric)).attach(
                                                    trainer, metric)

    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=monitoring_metrics)

    @trainer.on(Events.STARTED)
    def save_config(engine):
        config_to_save = defaultdict(dict)

        for key, child in config._asdict().items():
            for k, v in child._asdict().items():
                config_to_save[key][k] = v

        config_to_save['seed'] = seed
        config_to_save['output_dir'] = output_dir

        print('Training starts with the following configuration: ',
              config_to_save)

        if needs_save:
            save_path = os.path.join(output_dir, 'config.json')
            with open(save_path, 'w') as f:
                json.dump(config_to_save, f)

    @trainer.on(Events.ITERATION_COMPLETED)
    def show_logs(engine):
        if (engine.state.iteration - 1) % config.save.log_iter_interval == 0:
            columns = ['epoch', 'iteration'] + list(
                engine.state.metrics.keys())
            values = [str(engine.state.epoch), str(engine.state.iteration)] \
                   + [str(value) for value in engine.state.metrics.values()]

            message = '[{epoch}/{max_epoch}][{i}/{max_i}]'.format(
                epoch=engine.state.epoch,
                max_epoch=config.training.n_epochs,
                i=(engine.state.iteration - 1) % len(data_loader) + 1,
                max_i=len(data_loader))

            for name, value in zip(columns, values):
                message += ' | {name}: {value}'.format(name=name, value=value)

            pbar.log_message(message)

    @trainer.on(Events.EPOCH_COMPLETED)
    def save_logs(engine):
        if needs_save:
            fname = os.path.join(output_dir, 'logs.tsv')
            columns = ['epoch', 'iteration'] + list(
                engine.state.metrics.keys())
            values = [str(engine.state.epoch), str(engine.state.iteration)] \
                   + [str(value) for value in engine.state.metrics.values()]

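            # f.tell() == 0 means the file is still empty, so emit the header once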
            with open(fname, 'a') as f:
                if f.tell() == 0:
                    print('\t'.join(columns), file=f)
                print('\t'.join(values), file=f)

    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        pbar.log_message('Epoch {} done. Time per batch: {:.3f}[s]'.format(
            engine.state.epoch, timer.value()))
        timer.reset()

    @trainer.on(Events.EPOCH_COMPLETED)
    def save_images(engine):
        if needs_save:
            if engine.state.epoch % config.save.save_epoch_interval == 0:
                image = norm(engine.state.batch['image'])
                if config.training.use_cuda:
                    image = image.cuda(non_blocking=True).float()
                else:
                    image = image.float()

                with torch.no_grad():
                    z, _, _ = E(image)
                    x_r = D(z)
                    x_p = D(fixed_z)

                image = denorm(image).detach().cpu()
                x_r = denorm(x_r).detach().cpu()
                x_p = denorm(x_p).detach().cpu()

                image = image[:config.save.n_save_images, ...]
                x_r = x_r[:config.save.n_save_images, ...]
                x_p = x_p[:config.save.n_save_images, ...]

                save_path = os.path.join(
                    output_dir, 'result_{}.png'.format(engine.state.epoch))
                save_image(torch.cat([image, x_r, x_p]).data, save_path)

    if needs_save:
        checkpoint_handler = ModelCheckpoint(
            output_dir,
            config.save.study_name,
            save_interval=config.save.save_epoch_interval,
            n_saved=config.save.n_saved,
            create_dir=True,
        )
        trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                                  handler=checkpoint_handler,
                                  to_save={
                                      'E': E,
                                      'D': D
                                  })

    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    print('Training starts: [max_epochs] {}, [max_iterations] {}'.format(
        config.training.n_epochs, config.training.n_epochs * len(data_loader)))

    trainer.run(data_loader, config.training.n_epochs)
Example #14
def train(device, net, dataloader, val_loader, args, logger, experiment):
    def update(engine, data):
        input_left, input_right, label = data['left_image'], data['right_image'], data['winner']
        input_left, input_right, label = input_left.to(device), input_right.to(device), label.to(device)
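        # 'winner' labels arrive in {-1, 1}: keep copies for the margin ranking
        # loss and for the swapped pass, then remap -1 -> 0 for the classifier.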
        rank_label = label.clone()
        inverse_label = label.clone()
        label[label==-1] = 0
        # zero the parameter gradients
        optimizer.zero_grad()
        rank_label = rank_label.float()

        start = timer()
        output_clf,output_rank_left, output_rank_right = net(input_left,input_right)
        end = timer()
        logger.info(f'FORWARD,{end-start:.4f}')

        #compute clf loss
        start = timer()
        loss_clf = clf_crit(output_clf,label)

        #compute ranking loss
        loss_rank = compute_ranking_loss(output_rank_left, output_rank_right, label, rank_crit)
        loss = loss_clf + loss_rank

        end = timer()
        logger.info(f'LOSS,{end-start:.4f}')

        #compute ranking accuracy
        start = timer()
        rank_acc = compute_ranking_accuracy(output_rank_left, output_rank_right, label)
        end = timer()
        logger.info(f'RANK-ACC,{end-start:.4f}')

        # backward step
        start = timer()
        loss.backward()
        optimizer.step()
        end = timer()
        logger.info(f'BACKWARD,{end-start:.4f}')

        #swapped forward
        start = timer()
        inverse_label*=-1 #swap label
        inverse_rank_label = inverse_label.clone()
        inverse_rank_label = inverse_rank_label.float()
        inverse_label[inverse_label==-1] = 0
        end = timer()
        logger.info(f'SWAPPED-SETUP,{end-start:.4f}')
        start = timer()
        outputs, output_rank_left, output_rank_right = net(input_right,input_left) #pass swapped input
        end = timer()
        logger.info(f'SWAPPED-FORWARD,{end-start:.4f}')
        start = timer()
        inverse_loss_clf = clf_crit(outputs, inverse_label)
        #compute ranking loss on the swapped pair; the label must be inverted to match
        inverse_loss_rank = compute_ranking_loss(output_rank_left, output_rank_right, inverse_label, rank_crit)
        #swapped backward
        inverse_loss = inverse_loss_clf + inverse_loss_rank
        end = timer()
        logger.info(f'SWAPPED-LOSS,{end-start:.4f}')
        start = timer()
        optimizer.zero_grad()  # clear gradients accumulated by the first backward pass
        inverse_loss.backward()
        optimizer.step()
        end = timer()
        logger.info(f'SWAPPED-BACKWARD,{end-start:.4f}')

        return  { 'loss':loss.item(),
                'loss_clf':loss_clf.item(),
                'loss_rank':loss_rank.item(),
                'y':label,
                'y_pred': output_clf,
                'rank_acc': rank_acc
                }

    def inference(engine,data):
        with torch.no_grad():
            start = timer()
            input_left, input_right, label = data['left_image'], data['right_image'], data['winner']
            input_left, input_right, label = input_left.to(device), input_right.to(device), label.to(device)
            rank_label = label.clone()
            label[label==-1] = 0
            rank_label = rank_label.float()
            # forward
            output_clf,output_rank_left, output_rank_right = net(input_left,input_right)
            loss_clf = clf_crit(output_clf,label)
            loss_rank = compute_ranking_loss(output_rank_left, output_rank_right, label, rank_crit)
            rank_acc = compute_ranking_accuracy(output_rank_left, output_rank_right, label)
            loss = loss_clf + loss_rank
            end = timer()
            logger.info(f'INFERENCE,{end-start:.4f}')
            return  { 'loss':loss.item(),
                'loss_clf':loss_clf.item(),
                'loss_rank':loss_rank.item(),
                'y':label,
                'y_pred': output_clf,
                'rank_acc': rank_acc
                }
    net = net.to(device)

    clf_crit = nn.NLLLoss()
    rank_crit = nn.MarginRankingLoss(reduction='mean', margin=1)
    optimizer = optim.SGD(net.parameters(), lr=args.lr, weight_decay=args.wd, momentum=0.9)
    lamb = torch.tensor(1.0, device=device)  # torch.autograd.Variable is deprecated

    trainer = Engine(update)
    evaluator = Engine(inference)

    writer = SummaryWriter()
    RunningAverage(output_transform=lambda x: x['loss']).attach(trainer, 'loss')
    RunningAverage(output_transform=lambda x: x['loss_clf']).attach(trainer, 'loss_clf')
    RunningAverage(output_transform=lambda x: x['loss_rank']).attach(trainer, 'loss_rank')
    RunningAverage(output_transform=lambda x: x['rank_acc']).attach(trainer, 'rank_acc')
    RunningAverage(Accuracy(output_transform=lambda x: (x['y_pred'],x['y']))).attach(trainer,'avg_acc')

    RunningAverage(output_transform=lambda x: x['loss']).attach(evaluator, 'loss')
    RunningAverage(output_transform=lambda x: x['loss_clf']).attach(evaluator, 'loss_clf')
    RunningAverage(output_transform=lambda x: x['loss_rank']).attach(evaluator, 'loss_rank')
    RunningAverage(output_transform=lambda x: x['rank_acc']).attach(evaluator, 'rank_acc')
    RunningAverage(Accuracy(output_transform=lambda x: (x['y_pred'],x['y']))).attach(evaluator,'avg_acc')

    if args.pbar:
        pbar = ProgressBar(persist=False)
        pbar.attach(trainer,['loss','avg_acc', 'rank_acc'])

        pbar = ProgressBar(persist=False)
        pbar.attach(evaluator,['loss','loss_clf', 'loss_rank','avg_acc'])

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(trainer):
        net.eval()
        evaluator.run(val_loader)
        trainer.state.metrics['val_acc'] = evaluator.state.metrics['rank_acc']
        net.train()

        tb_log(
            {
                "accuracy":{
                    'accuracy':trainer.state.metrics['avg_acc'],
                    'rank_accuracy':trainer.state.metrics['rank_acc']
                },
                "loss": {
                    'total':trainer.state.metrics['loss'],
                    'clf':trainer.state.metrics['loss_clf'],
                    'rank':trainer.state.metrics['loss_rank']
                }
            },
            {
                "accuracy":{
                    'accuracy':evaluator.state.metrics['avg_acc'],
                    'rank_accuracy':evaluator.state.metrics['rank_acc']
                },
                "loss": {
                    'total':evaluator.state.metrics['loss'],
                    'clf':evaluator.state.metrics['loss_clf'],
                    'rank':evaluator.state.metrics['loss_rank']
                }
            },
            writer,
            args.attribute,
            trainer.state.epoch
        )

    handler = ModelCheckpoint(args.model_dir, '{}_{}_{}'.format(args.model, args.premodel, args.attribute),
                                n_saved=1,
                                create_dir=True,
                                save_as_state_dict=True,
                                require_empty=False,
                                score_function=lambda engine: engine.state.metrics['val_acc'])
    trainer.add_event_handler(Events.EPOCH_COMPLETED, handler, {
                'model': net
                })

    if (args.resume):
        def start_epoch(engine):
            engine.state.epoch = args.epoch
        trainer.add_event_handler(Events.STARTED, start_epoch)
        evaluator.add_event_handler(Events.STARTED, start_epoch)

    trainer.run(dataloader,max_epochs=args.max_epochs)
    def run_once(self):
        
        log_dir = self.log_dir

        misc.check_manual_seed(self.seed)
        train_pairs, valid_pairs = dataset.prepare_data_VIABLE_2048()
        print(len(train_pairs))
        # --------------------------- Dataloader

        train_augmentors = self.train_augmentors()
        train_dataset = dataset.DatasetSerial(train_pairs[:],
                        shape_augs=iaa.Sequential(train_augmentors[0]),
                        input_augs=iaa.Sequential(train_augmentors[1]))

        infer_augmentors = self.infer_augmentors()
        infer_dataset = dataset.DatasetSerial(valid_pairs[:],
                        shape_augs=iaa.Sequential(infer_augmentors))

        train_loader = data.DataLoader(train_dataset, 
                                num_workers=self.nr_procs_train, 
                                batch_size=self.train_batch_size, 
                                shuffle=True, drop_last=True)

        valid_loader = data.DataLoader(infer_dataset, 
                                num_workers=self.nr_procs_valid, 
                                batch_size=self.infer_batch_size, 
                                shuffle=True, drop_last=False)

        # --------------------------- Training Sequence

        if self.logging:
            misc.check_log_dir(log_dir)

        device = 'cuda'

        # networks
        input_chs = 3    
        net = DenseNet(input_chs, self.nr_classes)
        net = torch.nn.DataParallel(net).to(device)
        # print(net)

        # optimizers
        optimizer = optim.Adam(net.parameters(), lr=self.init_lr)
        scheduler = optim.lr_scheduler.StepLR(optimizer, self.lr_steps)

        # load pre-trained models
        if self.load_network:
            saved_state = torch.load(self.save_net_path)
            net.load_state_dict(saved_state)
        #
        trainer = Engine(lambda engine, batch: self.train_step(net, batch, optimizer, 'cuda'))
        inferer = Engine(lambda engine, batch: self.infer_step(net, batch, 'cuda'))

        train_output = ['loss', 'acc']
        infer_output = ['prob', 'true']
        ##

        if self.logging:
            checkpoint_handler = ModelCheckpoint(log_dir, self.chkpts_prefix, 
                                            save_interval=1, n_saved=120, require_empty=False)
            # adding handlers using `trainer.add_event_handler` method API
            trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler,
                                    to_save={'net': net}) 

        timer = Timer(average=True)
        timer.attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED,
                            pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED)
        timer.attach(inferer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED,
                            pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED)

        # attach running average metrics computation
        # decay of EMA to 0.95 to match tensorpack default
        RunningAverage(alpha=0.95, output_transform=lambda x: x['loss']).attach(trainer, 'loss')
        RunningAverage(alpha=0.95, output_transform=lambda x: x['acc']).attach(trainer, 'acc')

        # attach progress bar
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer, metric_names=['loss'])
        pbar.attach(inferer)

        # adding handlers using `trainer.on` decorator API
        @trainer.on(Events.EXCEPTION_RAISED)
        def handle_exception(engine, e):
            if isinstance(e, KeyboardInterrupt) and (engine.state.iteration > 1):
                engine.terminate()
                warnings.warn('KeyboardInterrupt caught. Exiting gracefully.')
                checkpoint_handler(engine, {'net_exception': net})
            else:
                raise e

        # writer for tensorboard logging
        if self.logging:
            writer = SummaryWriter(log_dir=log_dir)
            json_log_file = log_dir + '/stats.json'
            with open(json_log_file, 'w') as json_file:
                json.dump({}, json_file) # create empty file

        @trainer.on(Events.EPOCH_STARTED)
        def log_lrs(engine):
            if self.logging:
                lr = float(optimizer.param_groups[0]['lr'])
                writer.add_scalar("lr", lr, engine.state.epoch)
            # advance scheduler clock
            scheduler.step()

        ####
        def update_logs(output, epoch, prefix, color):
            # print values and convert
            max_length = len(max(output.keys(), key=len))
            for metric in output:
                key = colored(prefix + '-' + metric.ljust(max_length), color)
                print('------%s : ' % key, end='')
                print('%0.7f' % output[metric])
            if 'train' in prefix:
                lr = float(optimizer.param_groups[0]['lr'])
                key = colored(prefix + '-' + 'lr'.ljust(max_length), color)
                print('------%s : %0.7f' % (key, lr))

            if not self.logging:
                return

            # create stat dicts
            stat_dict = {}
            for metric in output:
                metric_value = output[metric] 
                stat_dict['%s-%s' % (prefix, metric)] = metric_value

            # json stat log file, update and overwrite
            with open(json_log_file) as json_file:
                json_data = json.load(json_file)

            current_epoch = str(epoch)
            if current_epoch in json_data:
                old_stat_dict = json_data[current_epoch]
                stat_dict.update(old_stat_dict)
            current_epoch_dict = {current_epoch : stat_dict}
            json_data.update(current_epoch_dict)

            with open(json_log_file, 'w') as json_file:
                json.dump(json_data, json_file)

            # log values to tensorboard (global_step must be an int, not a str)
            for metric in output:
                writer.add_scalar(prefix + '-' + metric, output[metric], epoch)

        @trainer.on(Events.EPOCH_COMPLETED)
        def log_train_running_results(engine):
            """
            running training measurement
            """
            training_ema_output = engine.state.metrics #
            update_logs(training_ema_output, engine.state.epoch, prefix='train-ema', color='green')

        ####
        def get_init_accumulator(output_names):
            return {metric : [] for metric in output_names}

        def process_accumulated_output(output):
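            # the loader yields a ragged sequence: every chunk has batch_size
            # items except possibly the last; stitch them into one flat array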
            def uneven_seq_to_np(seq, batch_size=self.infer_batch_size):
                if batch_size == 1:
                    return np.squeeze(seq)

                item_count = batch_size * (len(seq) - 1) + len(seq[-1])
                cat_array = np.zeros((item_count,) + seq[0][0].shape, seq[0].dtype)
                for idx in range(len(seq) - 1):
                    cat_array[idx * batch_size:(idx + 1) * batch_size] = seq[idx]
                cat_array[(len(seq) - 1) * batch_size:] = seq[-1]
                return cat_array
            #
            prob = uneven_seq_to_np(output['prob'])
            true = uneven_seq_to_np(output['true'])

            pred = np.argmax(prob, axis=-1)
            true = np.squeeze(true)

            # deal with ignore index
            pred = pred.flatten()
            true = true.flatten()
            pred = pred[true != 0] - 1
            true = true[true != 0] - 1

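            # accuracy plus binary Dice on the foreground:
            # 2 * |pred AND true| / (|pred| + |true|)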
            acc = np.mean(pred == true)
            inter = (pred * true).sum()
            total = (pred + true).sum()
            dice = 2 * inter / total
            #
            proc_output = dict(acc=acc, dice=dice)
            return proc_output

        @trainer.on(Events.EPOCH_COMPLETED)
        def infer_valid(engine):
            """
            inference measurement
            """
            inferer.accumulator = get_init_accumulator(infer_output)
            inferer.run(valid_loader)
            output_stat = process_accumulated_output(inferer.accumulator)
            update_logs(output_stat, engine.state.epoch, prefix='valid', color='red')

        @inferer.on(Events.ITERATION_COMPLETED)
        def accumulate_outputs(engine):
            batch_output = engine.state.output
            for key, item in batch_output.items():
                engine.accumulator[key].extend([item])
        ###
        #Setup is done. Now let's run the training
        trainer.run(train_loader, self.nr_epochs)
        return
Example #16
def main(hparams):
    results_dir = get_results_directory(hparams.output_dir)
    writer = SummaryWriter(log_dir=str(results_dir))

    ds = get_dataset(hparams.dataset, root=hparams.data_root)
    input_size, num_classes, train_dataset, test_dataset = ds

    hparams.seed = set_seed(hparams.seed)

    if hparams.n_inducing_points is None:
        hparams.n_inducing_points = num_classes

    print(f"Training with {hparams}")
    hparams.save(results_dir / "hparams.json")

    if hparams.ard:
        # Hardcoded to WRN output size
        ard = 640
    else:
        ard = None

    feature_extractor = WideResNet(
        spectral_normalization=hparams.spectral_normalization,
        dropout_rate=hparams.dropout_rate,
        coeff=hparams.coeff,
        n_power_iterations=hparams.n_power_iterations,
        batchnorm_momentum=hparams.batchnorm_momentum,
    )

    initial_inducing_points, initial_lengthscale = initial_values_for_GP(
        train_dataset, feature_extractor, hparams.n_inducing_points
    )

    gp = GP(
        num_outputs=num_classes,
        initial_lengthscale=initial_lengthscale,
        initial_inducing_points=initial_inducing_points,
        separate_inducing_points=hparams.separate_inducing_points,
        kernel=hparams.kernel,
        ard=ard,
        lengthscale_prior=hparams.lengthscale_prior,
    )

    model = DKL_GP(feature_extractor, gp)
    model = model.cuda()

    likelihood = SoftmaxLikelihood(num_classes=num_classes, mixing_weights=False)
    likelihood = likelihood.cuda()

    elbo_fn = VariationalELBO(likelihood, gp, num_data=len(train_dataset))

    parameters = [
        {"params": feature_extractor.parameters(), "lr": hparams.learning_rate},
        {"params": gp.parameters(), "lr": hparams.learning_rate},
        {"params": likelihood.parameters(), "lr": hparams.learning_rate},
    ]

    optimizer = torch.optim.SGD(
        parameters, momentum=0.9, weight_decay=hparams.weight_decay
    )

    milestones = [60, 120, 160]

    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=milestones, gamma=0.2
    )

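    # Training minimizes the negative variational ELBO, jointly updating the
    # feature extractor, the GP's variational parameters, and the likelihood.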
    def step(engine, batch):
        model.train()
        likelihood.train()

        optimizer.zero_grad()

        x, y = batch
        x, y = x.cuda(), y.cuda()

        y_pred = model(x)
        elbo = -elbo_fn(y_pred, y)

        elbo.backward()
        optimizer.step()

        return elbo.item()

    def eval_step(engine, batch):
        model.eval()
        likelihood.eval()

        x, y = batch
        x, y = x.cuda(), y.cuda()

        with torch.no_grad():
            y_pred = model(x)

        return y_pred, y

    trainer = Engine(step)
    evaluator = Engine(eval_step)

    metric = Average()
    metric.attach(trainer, "elbo")

    def output_transform(output):
        y_pred, y = output

        # Sample softmax values independently for classification at test time
        y_pred = y_pred.to_data_independent_dist()

        # The mean here is over likelihood samples
        y_pred = likelihood(y_pred).probs.mean(0)

        return y_pred, y

    metric = Accuracy(output_transform=output_transform)
    metric.attach(evaluator, "accuracy")

    metric = Loss(lambda y_pred, y: -elbo_fn(y_pred, y))
    metric.attach(evaluator, "elbo")

    kwargs = {"num_workers": 4, "pin_memory": True}

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=hparams.batch_size,
        shuffle=True,
        drop_last=True,
        **kwargs,
    )

    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=512, shuffle=False, **kwargs
    )

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_results(trainer):
        metrics = trainer.state.metrics
        elbo = metrics["elbo"]

        print(f"Train - Epoch: {trainer.state.epoch} ELBO: {elbo:.2f} ")
        writer.add_scalar("Likelihood/train", elbo, trainer.state.epoch)

        if hparams.spectral_normalization:
            for name, layer in model.feature_extractor.named_modules():
                if isinstance(layer, torch.nn.Conv2d):
                    writer.add_scalar(
                        f"sigma/{name}", layer.weight_sigma, trainer.state.epoch
                    )

        if not hparams.ard:
            # Otherwise it's too much to submit to tensorboard
            length_scales = model.gp.covar_module.base_kernel.lengthscale.squeeze()
            for i in range(length_scales.shape[0]):
                writer.add_scalar(
                    f"length_scale/{i}", length_scales[i], trainer.state.epoch
                )

        if trainer.state.epoch > 150 and trainer.state.epoch % 5 == 0:
            _, auroc, aupr = get_ood_metrics(
                hparams.dataset, "SVHN", model, likelihood, hparams.data_root
            )
            print(f"OoD Metrics - AUROC: {auroc}, AUPR: {aupr}")
            writer.add_scalar("OoD/auroc", auroc, trainer.state.epoch)
            writer.add_scalar("OoD/auprc", aupr, trainer.state.epoch)

        evaluator.run(test_loader)
        metrics = evaluator.state.metrics
        acc = metrics["accuracy"]
        elbo = metrics["elbo"]

        print(
            f"Test - Epoch: {trainer.state.epoch} "
            f"Acc: {acc:.4f} "
            f"ELBO: {elbo:.2f} "
        )

        writer.add_scalar("Likelihood/test", elbo, trainer.state.epoch)
        writer.add_scalar("Accuracy/test", acc, trainer.state.epoch)

        scheduler.step()

    pbar = ProgressBar(dynamic_ncols=True)
    pbar.attach(trainer)

    trainer.run(train_loader, max_epochs=200)

    # Done training - time to evaluate
    results = {}

    evaluator.run(train_loader)
    train_acc = evaluator.state.metrics["accuracy"]
    train_elbo = evaluator.state.metrics["elbo"]
    results["train_accuracy"] = train_acc
    results["train_elbo"] = train_elbo

    evaluator.run(test_loader)
    test_acc = evaluator.state.metrics["accuracy"]
    test_elbo = evaluator.state.metrics["elbo"]
    results["test_accuracy"] = test_acc
    results["test_elbo"] = test_elbo

    _, auroc, aupr = get_ood_metrics(
        hparams.dataset, "SVHN", model, likelihood, hparams.data_root
    )
    results["auroc_ood_svhn"] = auroc
    results["aupr_ood_svhn"] = aupr

    print(f"Test - Accuracy {results['test_accuracy']:.4f}")

    results_json = json.dumps(results, indent=4, sort_keys=True)
    (results_dir / "results.json").write_text(results_json)

    torch.save(model.state_dict(), results_dir / "model.pt")
    torch.save(likelihood.state_dict(), results_dir / "likelihood.pt")

    writer.close()
Example #17
def test_attach_fail_with_string():
    engine = Engine(update_fn)
    pbar = ProgressBar()

    with pytest.raises(TypeError):
        pbar.attach(engine, "a")
Example #18
def main(dataset, dataroot, download, augment, batch_size, eval_batch_size,
         epochs, saved_model, seed, hidden_channels, K, L, actnorm_scale,
         flow_permutation, flow_coupling, LU_decomposed, learn_top,
         y_condition, y_weight, max_grad_clip, max_grad_norm, lr, n_workers,
         cuda, n_init_batches, warmup_steps, output_dir, saved_optimizer,
         fresh):

    device = 'cpu' if (not torch.cuda.is_available() or not cuda) else 'cuda:0'

    check_manual_seed(seed)

    ds = check_dataset(dataset, dataroot, augment, download)
    image_shape, num_classes, train_dataset, test_dataset = ds

    # Note: unsupported for now
    multi_class = False

    train_loader = data.DataLoader(train_dataset,
                                   batch_size=batch_size,
                                   shuffle=True,
                                   num_workers=n_workers,
                                   drop_last=True)
    test_loader = data.DataLoader(test_dataset,
                                  batch_size=eval_batch_size,
                                  shuffle=False,
                                  num_workers=n_workers,
                                  drop_last=False)

    model = Glow(image_shape, hidden_channels, K, L, actnorm_scale,
                 flow_permutation, flow_coupling, LU_decomposed, num_classes,
                 learn_top, y_condition)

    model = model.to(device)
    optimizer = optim.Adamax(model.parameters(), lr=lr, weight_decay=5e-5)

    def step(engine, batch):
        model.train()
        optimizer.zero_grad()

        x, y = batch
        x = x.to(device)

        if y_condition:
            y = y.to(device)
            z, nll, y_logits = model(x, y)
            losses = compute_loss_y(nll, y_logits, y_weight, y, multi_class)
        else:
            z, nll, y_logits = model(x, None)
            losses = compute_loss(nll)

        losses['total_loss'].backward()

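        # optional gradient clipping: clip_grad_value_ bounds every element,
        # clip_grad_norm_ rescales the whole gradient vector to a max norm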
        if max_grad_clip > 0:
            torch.nn.utils.clip_grad_value_(model.parameters(), max_grad_clip)
        if max_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        optimizer.step()

        return losses

    def eval_step(engine, batch):
        model.eval()

        x, y = batch
        x = x.to(device)

        with torch.no_grad():
            if y_condition:
                y = y.to(device)
                z, nll, y_logits = model(x, y)
                losses = compute_loss_y(nll,
                                        y_logits,
                                        y_weight,
                                        y,
                                        multi_class,
                                        reduction='none')
            else:
                z, nll, y_logits = model(x, None)
                losses = compute_loss(nll, reduction='none')

        return losses

    trainer = Engine(step)
    checkpoint_handler = ModelCheckpoint(output_dir,
                                         'glow',
                                         save_interval=1,
                                         n_saved=2,
                                         require_empty=False)

    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler, {
        'model': model,
        'optimizer': optimizer
    })

    monitoring_metrics = ['total_loss']
    RunningAverage(output_transform=lambda x: x['total_loss']).attach(
        trainer, 'total_loss')

    evaluator = Engine(eval_step)

    # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
    Loss(lambda x, y: torch.mean(x),
         output_transform=lambda x:
         (x['total_loss'], torch.empty(x['total_loss'].shape[0]))).attach(
             evaluator, 'total_loss')
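    # the torch.empty tensor is a dummy target: ignite's Loss only uses its
    # first dimension to weight the running average by batch size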

    if y_condition:
        monitoring_metrics.extend(['nll'])
        RunningAverage(output_transform=lambda x: x['nll']).attach(
            trainer, 'nll')

        # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
        Loss(lambda x, y: torch.mean(x),
             output_transform=lambda x:
             (x['nll'], torch.empty(x['nll'].shape[0]))).attach(
                 evaluator, 'nll')

    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=monitoring_metrics)

    # load pre-trained model if given
    if saved_model:
        model.load_state_dict(torch.load(saved_model))
        model.set_actnorm_init()

        if saved_optimizer:
            optimizer.load_state_dict(torch.load(saved_optimizer))

        file_name, ext = os.path.splitext(saved_model)
        resume_epoch = int(file_name.split('_')[-1])

        @trainer.on(Events.STARTED)
        def resume_training(engine):
            engine.state.epoch = resume_epoch
            engine.state.iteration = resume_epoch * len(
                engine.state.dataloader)

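    # Glow's ActNorm parameters are initialized from data: run n_init_batches
    # through the model once, before training, with gradients disabled.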
    @trainer.on(Events.STARTED)
    def init(engine):
        model.train()

        init_batches = []
        init_targets = []

        with torch.no_grad():
            for batch, target in islice(train_loader, None, n_init_batches):
                init_batches.append(batch)
                init_targets.append(target)

            init_batches = torch.cat(init_batches).to(device)

            assert init_batches.shape[0] == n_init_batches * batch_size

            if y_condition:
                init_targets = torch.cat(init_targets).to(device)
            else:
                init_targets = None

            model(init_batches, init_targets)

    @trainer.on(Events.EPOCH_COMPLETED)
    def evaluate(engine):
        evaluator.run(test_loader)
        metrics = evaluator.state.metrics

        losses = ', '.join(
            [f"{key}: {value:.2f}" for key, value in metrics.items()])

        print(f'Validation Results - Epoch: {engine.state.epoch} {losses}')

    timer = Timer(average=True)
    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        pbar.log_message(
            f'Epoch {engine.state.epoch} done. Time per batch: {timer.value():.3f}[s]'
        )
        timer.reset()

    trainer.run(train_loader, epochs)
Example #19
def test_pbar_fail_with_non_callable_transform():
    engine = Engine(update_fn)
    pbar = ProgressBar()

    with pytest.raises(TypeError):
        pbar.attach(engine, output_transform=1)
Example #20
def main():

    SEED = 1234

    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    torch.backends.cudnn.deterministic = True

    TEXT = data.Field(lower=True, batch_first=True, tokenize='spacy')
    LABEL = data.LabelField(dtype=torch.float)

    train_data, test_data = datasets.IMDB.splits(TEXT,
                                                 LABEL,
                                                 root='/tmp/imdb/')
    train_data, valid_data = train_data.split(split_ratio=0.8,
                                              random_state=random.seed(SEED))

    TEXT.build_vocab(train_data,
                     vectors=GloVe(name='6B', dim=100, cache='/tmp/glove/'),
                     unk_init=torch.Tensor.normal_)

    LABEL.build_vocab(train_data)

    BATCH_SIZE = 64

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(
        (train_data, valid_data, test_data),
        batch_size=BATCH_SIZE,
        device=device)
    vocab_size, embedding_dim = TEXT.vocab.vectors.shape

    class SentimentAnalysisCNN(nn.Module):
        def __init__(self,
                     vocab_size,
                     embedding_dim,
                     kernel_sizes,
                     num_filters,
                     num_classes,
                     d_prob,
                     mode,
                     use_drop=False):
            """
            Args:
                vocab_size : int - size of the vocabulary
                embedding_dim : int - dimension of the word embedding vectors
                kernel_sizes : list of int - kernel sizes of the parallel convolutions
                num_filters : int - number of filters per convolutional layer
                num_classes : int - number of output classes
                d_prob : float - dropout probability
                mode : one of
                        static      : pretrained weights, non-trainable
                        nonstatic   : pretrained weights, trainable
                        rand        : randomly initialized weights
                use_drop : bool - whether to apply dropout before the final layer
            """
            super(SentimentAnalysisCNN, self).__init__()
            self.vocab_size = vocab_size
            self.embedding_dim = embedding_dim
            self.kernel_sizes = kernel_sizes
            self.num_filters = num_filters
            self.num_classes = num_classes
            self.d_prob = d_prob
            self.mode = mode
            self.embedding = nn.Embedding(vocab_size,
                                          embedding_dim,
                                          padding_idx=1)
            self.load_embeddings()
            self.conv = nn.ModuleList([
                nn.Sequential(
                    nn.Conv1d(in_channels=embedding_dim,
                              out_channels=num_filters,
                              kernel_size=k,
                              stride=1), nn.Dropout(p=0.5, inplace=True))
                for k in kernel_sizes
            ])
            self.use_drop = use_drop
            if self.use_drop:
                self.dropout = nn.Dropout(d_prob)
            self.fc = nn.Linear(len(kernel_sizes) * num_filters, num_classes)

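        # forward: each Conv1d branch slides a different kernel size over the
        # embedded sequence, max-pool over time, then concatenate the branches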
        def forward(self, x):
            batch_size, sequence_length = x.shape
            x = self.embedding(x).transpose(1, 2)
            x = [F.relu(conv(x)) for conv in self.conv]
            x = [F.max_pool1d(c, c.size(-1)).squeeze(dim=-1) for c in x]
            x = torch.cat(x, dim=1)
            if self.use_drop:
                x = self.dropout(x)
            x = self.fc(x)
            return torch.sigmoid(x).squeeze()

        def load_embeddings(self):
            if 'static' in self.mode:
                self.embedding.weight.data.copy_(TEXT.vocab.vectors)
                if 'non' not in self.mode:
                    # setting requires_grad on .data is a no-op; freeze the parameter itself
                    self.embedding.weight.requires_grad = False
                    print('Loaded pretrained embeddings, weights are not trainable.')
                else:
                    self.embedding.weight.requires_grad = True
                    print('Loaded pretrained embeddings, weights are trainable.')
            elif self.mode == 'rand':
                print('Randomly initialized embeddings are used.')
            else:
                raise ValueError(
                    'Unexpected value of mode. Please choose from static, nonstatic, rand.'
                )

    model = SentimentAnalysisCNN(
        vocab_size=vocab_size,
        embedding_dim=embedding_dim,
        kernel_sizes=[3, 4, 5],
        num_filters=100,
        num_classes=1,
        d_prob=0.5,
        mode='static')
    model.to(device)
    ## switch back and forth between the two optimizers:
    # optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-3)

    ## Ranger gives better performance but starts overfitting quickly
    optimizer = Ranger(model.parameters(), weight_decay=0.1)

    criterion = nn.BCELoss()

    def process_function(engine, batch):
        model.train()
        optimizer.zero_grad()
        x, y = batch.text, batch.label
        y_pred = model(x)
        loss = criterion(y_pred, y)
        loss.backward()
        optimizer.step()
        return loss.item()

    def eval_function(engine, batch):
        model.eval()
        with torch.no_grad():
            x, y = batch.text, batch.label
            y_pred = model(x)
            return y_pred, y

    trainer = Engine(process_function)
    train_evaluator = Engine(eval_function)
    validation_evaluator = Engine(eval_function)

    RunningAverage(output_transform=lambda x: x).attach(trainer, 'loss')

    def thresholded_output_transform(output):
        y_pred, y = output
        y_pred = torch.round(y_pred)
        return y_pred, y

    Accuracy(output_transform=thresholded_output_transform).attach(
        train_evaluator, 'accuracy')
    Loss(criterion).attach(train_evaluator, 'bce')

    Accuracy(output_transform=thresholded_output_transform).attach(
        validation_evaluator, 'accuracy')
    Loss(criterion).attach(validation_evaluator, 'bce')

    pbar = ProgressBar(persist=True, bar_format="")
    pbar.attach(trainer, ['loss'])

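    # EarlyStopping treats larger scores as better, so return the negated
    # validation loss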
    def score_function(engine):
        val_loss = engine.state.metrics['bce']
        return -val_loss

    handler = EarlyStopping(patience=5,
                            score_function=score_function,
                            trainer=trainer)
    validation_evaluator.add_event_handler(Events.COMPLETED, handler)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        train_evaluator.run(train_iterator)
        metrics = train_evaluator.state.metrics
        avg_accuracy = metrics['accuracy']
        avg_bce = metrics['bce']
        pbar.log_message(
            "Training Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}"
            .format(engine.state.epoch, avg_accuracy, avg_bce))

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        validation_evaluator.run(valid_iterator)
        metrics = validation_evaluator.state.metrics
        avg_accuracy = metrics['accuracy']
        avg_bce = metrics['bce']
        pbar.log_message(
            "Validation Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}"
            .format(engine.state.epoch, avg_accuracy, avg_bce))
        pbar.n = pbar.last_print_n = 0

    checkpointer = ModelCheckpoint('/tmp/models',
                                   'textcnn_ranger_wd_0_1',
                                   save_interval=1,
                                   n_saved=2,
                                   create_dir=True,
                                   save_as_state_dict=True,
                                   require_empty=False)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer,
                              {'textcnn_ranger_wd_0_1': model})
    # trainer.add_event_handler(Events.EPOCH_COMPLETED, log_validation_results)

    trainer.run(train_iterator, max_epochs=20)
Example #21
    def __init__(self: TrainerType,
                 model: nn.Module,
                 optimizer: Optimizer,
                 checkpoint_dir: str = '../../checkpoints',
                 experiment_name: str = 'experiment',
                 model_checkpoint: Optional[str] = None,
                 optimizer_checkpoint: Optional[str] = None,
                 metrics: types.GenericDict = None,
                 patience: int = 10,
                 validate_every: int = 1,
                 accumulation_steps: int = 1,
                 loss_fn: Union[_Loss, DataParallelCriterion] = None,
                 non_blocking: bool = True,
                 retain_graph: bool = False,
                 dtype: torch.dtype = torch.float,
                 device: str = 'cpu',
                 parallel: bool = False) -> None:
        self.dtype = dtype
        self.retain_graph = retain_graph
        self.non_blocking = non_blocking
        self.device = device
        self.loss_fn = loss_fn
        self.validate_every = validate_every
        self.patience = patience
        self.accumulation_steps = accumulation_steps
        self.checkpoint_dir = checkpoint_dir

        model_checkpoint = self._check_checkpoint(model_checkpoint)
        optimizer_checkpoint = self._check_checkpoint(optimizer_checkpoint)

        self.model = cast(
            nn.Module,
            from_checkpoint(model_checkpoint,
                            model,
                            map_location=torch.device('cpu')))
        self.model = self.model.type(dtype).to(device)
        self.optimizer = from_checkpoint(optimizer_checkpoint, optimizer)
        self.parallel = parallel
        if parallel:
            if device == 'cpu':
                raise ValueError("parallel can be used only with cuda device")
            self.model = DataParallelModel(self.model).to(device)
            self.loss_fn = DataParallelCriterion(self.loss_fn)  # type: ignore
        if metrics is None:
            metrics = {}
        if 'loss' not in metrics:
            if self.parallel:
                metrics['loss'] = Loss(
                    lambda x, y: self.loss_fn(x, y).mean())  # type: ignore
            else:
                metrics['loss'] = Loss(self.loss_fn)
        self.trainer = Engine(self.train_step)
        self.train_evaluator = Engine(self.eval_step)
        self.valid_evaluator = Engine(self.eval_step)
        for name, metric in metrics.items():
            metric.attach(self.train_evaluator, name)
            metric.attach(self.valid_evaluator, name)

        self.pbar = ProgressBar()
        self.val_pbar = ProgressBar(desc='Validation')

        if checkpoint_dir is not None:
            self.checkpoint = CheckpointHandler(checkpoint_dir,
                                                experiment_name,
                                                score_name='validation_loss',
                                                score_function=self._score_fn,
                                                n_saved=2,
                                                require_empty=False,
                                                save_as_state_dict=True)

        self.early_stop = EarlyStopping(patience, self._score_fn, self.trainer)

        self.val_handler = EvaluationHandler(pbar=self.pbar,
                                             validate_every=1,
                                             early_stopping=self.early_stop)
        self.attach()
        log.info(
            f'Trainer configured to run {experiment_name}\n'
            f'\tpretrained model: {model_checkpoint} {optimizer_checkpoint}\n'
            f'\tcheckpoint directory: {checkpoint_dir}\n'
            f'\tpatience: {patience}\n'
            f'\taccumulation steps: {accumulation_steps}\n'
            f'\tnon blocking: {non_blocking}\n'
            f'\tretain graph: {retain_graph}\n'
            f'\tdevice: {device}\n'
            f'\tmodel dtype: {dtype}\n'
            f'\tparallel: {parallel}')
Example #22
def run(args):
    train_loader, val_loader = get_data_loaders(args.dir, args.batch_size,
                                                args.num_workers)

    if args.seed is not None:
        torch.manual_seed(args.seed)

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    num_classes = CityscapesDataset.num_instance_classes() + 1
    model = models.box2pix(num_classes=num_classes)
    model.init_from_googlenet()

    writer = create_summary_writer(model, train_loader, args.log_dir)

    if torch.cuda.device_count() > 1:
        print("Using %d GPU(s)" % torch.cuda.device_count())
        model = nn.DataParallel(model)

    model = model.to(device)

    semantics_criterion = nn.CrossEntropyLoss(ignore_index=255)
    offsets_criterion = nn.MSELoss()
    box_criterion = BoxLoss(num_classes, gamma=2)
    multitask_criterion = MultiTaskLoss().to(device)

    box_coder = BoxCoder()
    optimizer = optim.Adam([{
        'params': model.parameters(),
        'weight_decay': 5e-4
    }, {
        'params': multitask_criterion.parameters()
    }],
                           lr=args.lr)

    if args.resume:
        if os.path.isfile(args.resume):
            print("Loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            multitask_criterion.load_state_dict(checkpoint['multitask'])
            print("Loaded checkpoint '{}' (Epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("No checkpoint found at '{}'".format(args.resume))

    def _prepare_batch(batch, non_blocking=True):
        x, instance, boxes, labels = batch

        return (convert_tensor(x, device=device, non_blocking=non_blocking),
                convert_tensor(instance,
                               device=device,
                               non_blocking=non_blocking),
                convert_tensor(boxes, device=device,
                               non_blocking=non_blocking),
                convert_tensor(labels,
                               device=device,
                               non_blocking=non_blocking))

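    # one optimization step: encode the ground-truth boxes into the model's
    # target format (assumed behaviour of BoxCoder.encode), compute the three
    # task losses, and combine them via the learned multi-task weighting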
    def _update(engine, batch):
        model.train()
        optimizer.zero_grad()
        x, instance, boxes, labels = _prepare_batch(batch)
        boxes, labels = box_coder.encode(boxes, labels)

        loc_preds, conf_preds, semantics_pred, offsets_pred = model(x)

        semantics_loss = semantics_criterion(semantics_pred, instance)
        offsets_loss = offsets_criterion(offsets_pred, instance)
        box_loss, conf_loss = box_criterion(loc_preds, boxes, conf_preds,
                                            labels)

        loss = multitask_criterion(semantics_loss, offsets_loss, box_loss,
                                   conf_loss)

        loss.backward()
        optimizer.step()

        return {
            'loss': loss.item(),
            'loss_semantics': semantics_loss.item(),
            'loss_offsets': offsets_loss.item(),
            'loss_ssdbox': box_loss.item(),
            'loss_ssdclass': conf_loss.item()
        }

    trainer = Engine(_update)

    checkpoint_handler = ModelCheckpoint(args.output_dir,
                                         'checkpoint',
                                         save_interval=1,
                                         n_saved=10,
                                         require_empty=False,
                                         create_dir=True,
                                         save_as_state_dict=False)
    timer = Timer(average=True)

    # attach running average metrics
    train_metrics = [
        'loss', 'loss_semantics', 'loss_offsets', 'loss_ssdbox',
        'loss_ssdclass'
    ]
    for m in train_metrics:
        transform = partial(lambda x, metric: x[metric], metric=m)
        RunningAverage(output_transform=transform).attach(trainer, m)

    # attach progress bar
    pbar = ProgressBar(persist=True)
    pbar.attach(trainer, metric_names=train_metrics)

    @trainer.on(Events.EPOCH_COMPLETED)
    def save_checkpoint(engine):
        # build the checkpoint at save time so the saved states are current
        # (building it once up front would freeze the initial state_dicts)
        checkpoint = {
            'model': model.state_dict(),
            'epoch': engine.state.epoch,
            'optimizer': optimizer.state_dict(),
            'multitask': multitask_criterion.state_dict()
        }
        checkpoint_handler(engine, {'checkpoint': checkpoint})

    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    def _inference(engine, batch):
        model.eval()
        with torch.no_grad():
            x, instance, boxes, labels = _prepare_batch(batch)
            loc_preds, conf_preds, semantics, offsets_pred = model(x)
            boxes_preds, labels_preds, scores_preds = box_coder.decode(
                loc_preds, F.softmax(conf_preds, dim=1), score_thresh=0.01)

            semantics_loss = semantics_criterion(semantics, instance)
            offsets_loss = offsets_criterion(offsets_pred, instance)
            box_loss, conf_loss = box_criterion(loc_preds, boxes, conf_preds,
                                                labels)

            semantics_pred = semantics.argmax(dim=1)
            instances = helper.assign_pix2box(semantics_pred, offsets_pred,
                                              boxes_preds, labels_preds)

        return {
            'loss': (semantics_loss, offsets_loss, {
                'box_loss': box_loss,
                'conf_loss': conf_loss
            }),
            'objects':
            (boxes_preds, labels_preds, scores_preds, boxes, labels),
            'semantics':
            semantics_pred,
            'instances':
            instances
        }

    train_evaluator = Engine(_inference)
    Loss(multitask_criterion,
         output_transform=lambda x: x['loss']).attach(train_evaluator, 'loss')
    MeanAveragePrecision(num_classes,
                         output_transform=lambda x: x['objects']).attach(
                             train_evaluator, 'objects')
    IntersectionOverUnion(num_classes,
                          output_transform=lambda x: x['semantics']).attach(
                              train_evaluator, 'semantics')

    evaluator = Engine(_inference)
    Loss(multitask_criterion,
         output_transform=lambda x: x['loss']).attach(evaluator, 'loss')
    MeanAveragePrecision(num_classes,
                         output_transform=lambda x: x['objects']).attach(
                             evaluator, 'objects')
    IntersectionOverUnion(num_classes,
                          output_transform=lambda x: x['semantics']).attach(
                              evaluator, 'semantics')

    @trainer.on(Events.STARTED)
    def initialize(engine):
        if args.resume:
            engine.state.epoch = args.start_epoch

    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        pbar.log_message(
            "Epoch [{}/{}] done. Time per batch: {:.3f}[s]".format(
                engine.state.epoch, engine.state.max_epochs, timer.value()))
        timer.reset()

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        iteration = (engine.state.iteration - 1) % len(train_loader) + 1
        if iteration % args.log_interval == 0:
            writer.add_scalar("training/loss", engine.state.output['loss'],
                              engine.state.iteration)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        train_evaluator.run(train_loader)
        metrics = train_evaluator.state.metrics
        loss = metrics['loss']
        mean_ap = metrics['objects']
        iou = metrics['semantics']

        pbar.log_message(
            'Training results - Epoch: [{}/{}]: Loss: {:.4f}, mAP(50%): {:.1f}, IoU: {:.1f}'
            .format(engine.state.epoch, engine.state.max_epochs, loss,
                    mean_ap, iou * 100.0))

        writer.add_scalar("train-val/loss", loss, engine.state.epoch)
        writer.add_scalar("train-val/mAP", mean_ap, engine.state.epoch)
        writer.add_scalar("train-val/IoU", iou, engine.state.epoch)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        evaluator.run(val_loader)
        metrics = evaluator.state.metrics
        loss = metrics['loss']
        mean_ap = metrics['objects']
        iou = metrics['semantics']

        pbar.log_message(
            'Validation results - Epoch: [{}/{}]: Loss: {:.4f}, mAP(50%): {:.1f}, IoU: {:.1f}'
            .format(engine.state.epoch, engine.state.max_epochs, loss,
                    mean_ap, iou * 100.0))

        writer.add_scalar("validation/loss", loss, engine.state.epoch)
        writer.add_scalar("validation/mAP", mean_ap, engine.state.epoch)
        writer.add_scalar("validation/IoU", iou, engine.state.epoch)

    @trainer.on(Events.EXCEPTION_RAISED)
    def handle_exception(engine, e):
        if isinstance(e, KeyboardInterrupt) and (engine.state.iteration > 1):
            engine.terminate()
            warnings.warn("KeyboardInterrupt caught. Exiting gracefully.")

            checkpoint_handler(engine, {'model_exception': model})
        else:
            raise e

    @trainer.on(Events.COMPLETED)
    def save_final_model(engine):
        checkpoint_handler(engine, {'final': model})

    trainer.run(train_loader, max_epochs=args.epochs)
    writer.close()
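A note on the `partial(lambda ...)` idiom in the metric-attachment loop above: binding the metric name through `functools.partial` (or a default argument) matters because a bare closure in a loop captures the loop variable by reference, so every transform would end up reading the last metric name. A minimal, self-contained sketch (the names here are illustrative, not from the example):

from functools import partial

names = ['loss', 'loss_semantics', 'loss_offsets']
bad = [lambda out: out[m] for m in names]  # late binding: all read the last name
good = [partial(lambda out, metric: out[metric], metric=m) for m in names]

sample = {'loss': 1.0, 'loss_semantics': 2.0, 'loss_offsets': 3.0}
assert [t(sample) for t in bad] == [3.0, 3.0, 3.0]
assert [t(sample) for t in good] == [1.0, 2.0, 3.0]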
Example #23
def train_gan(logger: Logger,
              experiment_dir: Path,
              data_dir: Path,
              batch_size: int,
              z_dim: int,
              g_filters: int,
              d_filters: int,
              learning_rate: float,
              beta_1: float,
              epochs: int,
              saved_g: bool = False,
              saved_d: bool = False,
              seed: Optional[int] = None,
              g_extra_layers: int = 0,
              d_extra_layers: int = 0,
              scheduler: bool = False) -> None:
    seed = fix_random_seed(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Train started with seed: {seed}")
    dataset = HDF5ImageDataset(image_dir=data_dir)
    desired_minkowski = pickle.load(
        (data_dir / 'minkowski.pkl').open(mode='rb'))

    loader = DataLoader(dataset,
                        batch_size=batch_size,
                        shuffle=True,
                        drop_last=True,
                        pin_memory=True)
    iterations = epochs * len(loader)
    img_size = dataset.shape[-1]
    num_channels = dataset.shape[0]

    # networks
    net_g = Generator(img_size=img_size,
                      z_dim=z_dim,
                      num_channels=num_channels,
                      num_filters=g_filters,
                      num_extra_layers=g_extra_layers).to(device)
    net_d = Discriminator(img_size=img_size,
                          num_channels=num_channels,
                          num_filters=d_filters,
                          num_extra_layers=d_extra_layers).to(device)
    summary(net_g, (z_dim, 1, 1, 1))
    summary(net_d, (num_channels, img_size, img_size, img_size))

    if saved_g:
        net_g.load_state_dict(torch.load(experiment_dir / G_CHECKPOINT_NAME))
        logger.info("Loaded generator checkpoint")
    if saved_d:
        net_d.load_state_dict(torch.load(experiment_dir / D_CHECKPOINT_NAME))
        logger.info("Loaded discriminator checkpoint")

    # criterion
    criterion = nn.BCELoss()

    optimizer_g = optim.Adam(net_g.parameters(),
                             lr=learning_rate,
                             betas=(beta_1, 0.999))
    optimizer_d = optim.Adam(net_d.parameters(),
                             lr=learning_rate,
                             betas=(beta_1, 0.999))

    patience = int(3000 / len(loader))
    scheduler_g = optim.lr_scheduler.ReduceLROnPlateau(optimizer_g,
                                                       min_lr=1e-6,
                                                       verbose=True,
                                                       patience=patience)
    scheduler_d = optim.lr_scheduler.ReduceLROnPlateau(optimizer_d,
                                                       min_lr=1e-6,
                                                       verbose=True,
                                                       patience=patience)

    # label smoothing: real targets use 0.9 instead of 1.0
    real_labels = torch.full((batch_size, ), fill_value=0.9, device=device)
    fake_labels = torch.zeros((batch_size, ), device=device)
    fixed_noise = torch.randn(1, z_dim, 1, 1, 1, device=device)

    def step(engine: Engine, batch: torch.Tensor) -> Dict[str, float]:
        """
        Train step function

        :param engine: pytorch ignite train engine
        :param batch: batch to process
        :return: batch metrics
        """
        # get batch of fake images from generator
        fake_batch = net_g(
            torch.randn(batch_size, z_dim, 1, 1, 1, device=device))
        # 1. Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        batch = batch.to(device)
        optimizer_d.zero_grad()
        # train D with real and fake batches
        d_out_real = net_d(batch)
        d_out_fake = net_d(fake_batch.detach())
        loss_d_real = criterion(d_out_real, real_labels)
        loss_d_fake = criterion(d_out_fake, fake_labels)
        # mean probabilities
        p_real = d_out_real.mean().item()
        p_fake = d_out_fake.mean().item()

        loss_d = (loss_d_real + loss_d_fake) / 2
        loss_d.backward()
        optimizer_d.step()

        # 2. Update G network: maximize log(D(G(z)))
        loss_g = None
        p_gen = None
        for _ in range(1):  # a single G update per D update; raise to rebalance
            fake_batch = net_g(
                torch.randn(batch_size, z_dim, 1, 1, 1, device=device))
            optimizer_g.zero_grad()
            d_out_fake = net_d(fake_batch)
            loss_g = criterion(d_out_fake, real_labels)
            # mean fake generator probability
            p_gen = d_out_fake.mean().item()
            loss_g.backward()
            optimizer_g.step()

        # minkowski functional measures
        cube = net_g(fixed_noise).detach().squeeze().cpu()
        cube = cube.mul(0.5).add(0.5).numpy()
        cube = postprocess_cube(cube)
        cube = np.pad(cube, ((1, 1), (1, 1), (1, 1)),
                      mode='constant',
                      constant_values=0)
        v, s, b, xi = compute_minkowski(cube)
        return {
            'loss_d': loss_d.item(),
            'loss_g': loss_g.item(),
            'p_real': p_real,
            'p_fake': p_fake,
            'p_gen': p_gen,
            'V': v,
            'S': s,
            'B': b,
            'Xi': xi
        }

    # ignite objects
    trainer = Engine(step)
    checkpoint_handler = ModelCheckpoint(dirname=str(experiment_dir),
                                         filename_prefix=CKPT_PREFIX,
                                         save_interval=5,
                                         n_saved=50,
                                         require_empty=False)

    # attach running average metrics
    monitoring_metrics = [
        'loss_d', 'loss_g', 'p_real', 'p_fake', 'p_gen', 'V', 'S', 'B', 'Xi'
    ]
    # bind the metric name through a default argument so each transform keeps
    # its own name (a bare closure would capture the loop variable)
    for name in monitoring_metrics:
        RunningAverage(alpha=ALPHA,
                       output_transform=lambda x, name=name: x[name]).attach(
                           trainer, name)

    # attach progress bar
    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=monitoring_metrics)

    @trainer.on(Events.ITERATION_COMPLETED)
    def print_logs(engine):
        if (engine.state.iteration - 1) % PRINT_FREQ == 0:
            fname = experiment_dir / LOGS_FNAME
            columns = ['iter'] + list(engine.state.metrics.keys())
            values = [str(engine.state.iteration)] + [
                str(round(value, 7))
                for value in engine.state.metrics.values()
            ]

            with fname.open(mode='a') as f:
                if f.tell() == 0:
                    print('\t'.join(columns), file=f)
                print('\t'.join(values), file=f)

            message = f"[{engine.state.epoch}/{epochs}][{engine.state.iteration:04d}/{iterations}]"
            for name, value in engine.state.metrics.items():
                message += f" | {name}: {value:0.5f}"

            pbar.log_message(message)

    trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                              handler=checkpoint_handler,
                              to_save={
                                  'net_g': net_g,
                                  'net_d': net_d
                              })

    @trainer.on(Events.EPOCH_COMPLETED)
    def create_plots(engine):
        df = pd.read_csv(experiment_dir / LOGS_FNAME, delimiter='\t')

        fig_1 = plt.figure(figsize=(18, 12))
        plt.plot(df['iter'], df['loss_d'], label='loss_d', linestyle='dashed')
        plt.plot(df['iter'], df['loss_g'], label='loss_g')
        plt.xlabel('Iteration number')
        plt.legend()
        fig_1.savefig(experiment_dir / ('loss_' + PLOT_FNAME))
        plt.close(fig_1)

        fig_2 = plt.figure(figsize=(18, 12))
        plt.plot(df['iter'], df['p_real'], label='p_real', linestyle='dashed')
        plt.plot(df['iter'], df['p_fake'], label='p_fake', linestyle='dashdot')
        plt.plot(df['iter'], df['p_gen'], label='p_gen')
        plt.xlabel('Iteration number')
        plt.legend()
        fig_2.savefig(experiment_dir / PLOT_FNAME)
        plt.close(fig_2)

        desired_v = [desired_minkowski[0]] * len(df['iter'])
        desired_s = [desired_minkowski[1]] * len(df['iter'])
        desired_b = [desired_minkowski[2]] * len(df['iter'])
        desired_xi = [desired_minkowski[3]] * len(df['iter'])

        fig_3 = plt.figure(figsize=(18, 12))
        plt.plot(df['iter'], df['V'], label='V', color='b')
        plt.plot(df['iter'], desired_v, color='b', linestyle='dashed')
        plt.xlabel('Iteration number')
        plt.ylabel('Minkowski functional V')
        plt.legend()
        fig_3.savefig(experiment_dir / ('minkowski_V_' + PLOT_FNAME))
        plt.close(fig_3)

        fig_4 = plt.figure(figsize=(18, 12))
        plt.plot(df['iter'], df['S'], label='S', color='r')
        plt.plot(df['iter'], desired_s, color='r', linestyle='dashed')
        plt.xlabel('Iteration number')
        plt.ylabel('Minkowski functional S')
        plt.legend()
        fig_4.savefig(experiment_dir / ('minkowski_S_' + PLOT_FNAME))
        plt.close(fig_4)

        fig_5 = plt.figure(figsize=(18, 12))
        plt.plot(df['iter'], df['B'], label='B', color='g')
        plt.plot(df['iter'], desired_b, color='g', linestyle='dashed')
        plt.xlabel('Iteration number')
        plt.ylabel('Minkowski functional B')
        plt.legend()
        fig_5.savefig(experiment_dir / ('minkowski_B_' + PLOT_FNAME))
        plt.close(fig_5)

        fig_6 = plt.figure(figsize=(18, 12))
        plt.plot(df['iter'], df['Xi'], label='Xi', color='y')
        plt.plot(df['iter'], desired_xi, color='y', linestyle='dashed')
        plt.xlabel('Iteration number')
        plt.ylabel('Minkowski functional Xi')
        plt.legend()
        fig_6.savefig(experiment_dir / ('minkowski_Xi_' + PLOT_FNAME))
        plt.close(fig_6)

    if scheduler:

        @trainer.on(Events.EPOCH_COMPLETED)
        def lr_scheduler(engine):
            desired_b = desired_minkowski[2]
            desired_xi = desired_minkowski[3]

            current_b = engine.state.metrics['B']
            current_xi = engine.state.metrics['Xi']

            delta = abs(desired_b - current_b) + abs(desired_xi - current_xi)

            scheduler_d.step(delta)
            scheduler_g.step(delta)

    @trainer.on(Events.EXCEPTION_RAISED)
    def handle_exception(engine, e):
        if isinstance(e, KeyboardInterrupt) and (engine.state.iteration > 1):
            engine.terminate()
            warnings.warn('KeyboardInterrupt caught. Exiting gracefully.')

            create_plots(engine)
            checkpoint_handler(engine, {
                'net_g_exception': net_g,
                'net_d_exception': net_d
            })
        else:
            raise e

    trainer.run(loader, epochs)
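The `f.tell() == 0` check in `print_logs` above is what keeps the TSV header from being written twice: a file opened in append mode is positioned at its end, so a zero offset means the file is empty. A minimal sketch of the same trick (file name and values are illustrative):

from pathlib import Path

log = Path("metrics.tsv")
for it, loss in [(1, 0.9), (2, 0.7)]:
    with log.open(mode='a') as f:
        if f.tell() == 0:  # append mode starts at EOF, so 0 means an empty file
            print('\t'.join(['iter', 'loss']), file=f)
        print('\t'.join([str(it), str(loss)]), file=f)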
Example #24
    def train(self, config, **kwargs):
        """Trains a model on the given configurations.
        :param config: A training configuration. Note that all parameters in the config can also be manually adjusted with --ARG=VALUE
        :param **kwargs: parameters to overwrite yaml config
        """
        from pycocoevalcap.cider.cider import Cider

        config_parameters = train_util.parse_config_or_kwargs(config, **kwargs)
        config_parameters["seed"] = self.seed
        outputdir = os.path.join(
            config_parameters["outputpath"], config_parameters["model"],
            "{}_{}".format(
                datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'),
                uuid.uuid1().hex))

        # Early init because of creating dir
        checkpoint_handler = ModelCheckpoint(
            outputdir,
            "run",
            n_saved=1,
            require_empty=False,
            create_dir=True,
            score_function=lambda engine: engine.state.metrics["score"],
            score_name="score")

        logger = train_util.genlogger(os.path.join(outputdir, "train.log"))
        # print passed config parameters
        logger.info("Storing files in: {}".format(outputdir))
        train_util.pprint_dict(config_parameters, logger.info)

        zh = config_parameters["zh"]
        vocabulary = torch.load(config_parameters["vocab_file"])
        train_loader, cv_loader, info = self._get_dataloaders(
            config_parameters, vocabulary)
        config_parameters["inputdim"] = info["inputdim"]
        cv_key2refs = info["cv_key2refs"]
        logger.info("<== Estimating Scaler ({}) ==>".format(
            info["scaler"].__class__.__name__))
        logger.info("Feature: {} Input dimension: {} Vocab Size: {}".format(
            config_parameters["feature_file"], info["inputdim"],
            len(vocabulary)))

        model = self._get_model(config_parameters, len(vocabulary))
        if "pretrained_word_embedding" in config_parameters:
            embeddings = np.load(
                config_parameters["pretrained_word_embedding"])
            model.load_word_embeddings(
                embeddings,
                tune=config_parameters["tune_word_embedding"],
                projection=True)
        model = model.to(self.device)
        train_util.pprint_dict(model, logger.info, formatter="pretty")
        optimizer = getattr(torch.optim, config_parameters["optimizer"])(
            model.parameters(), **config_parameters["optimizer_args"])
        train_util.pprint_dict(optimizer, logger.info, formatter="pretty")

        criterion = torch.nn.CrossEntropyLoss().to(self.device)
        crtrn_imprvd = train_util.criterion_improver(
            config_parameters['improvecriterion'])

        def _train_batch(engine, batch):
            model.train()
            with torch.enable_grad():
                optimizer.zero_grad()
                output = self._forward(model, batch, "train")
                loss = criterion(output["packed_logits"],
                                 output["targets"]).to(self.device)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
                optimizer.step()
                output["loss"] = loss.item()
                return output

        trainer = Engine(_train_batch)
        RunningAverage(output_transform=lambda x: x["loss"]).attach(
            trainer, "running_loss")
        pbar = ProgressBar(persist=False, ascii=True, ncols=100)
        pbar.attach(trainer, ["running_loss"])

        key2pred = {}

        def _inference(engine, batch):
            model.eval()
            keys = batch[2]
            with torch.no_grad():
                output = self._forward(model, batch, "validation")
                seqs = output["seqs"].cpu().numpy()
                for (idx, seq) in enumerate(seqs):
                    if keys[idx] in key2pred:
                        continue
                    candidate = self._convert_idx2sentence(seq, vocabulary, zh)
                    key2pred[keys[idx]] = [
                        candidate,
                    ]
                return output

        metrics = {
            "loss":
            Loss(criterion,
                 output_transform=lambda x: (x["packed_logits"], x["targets"]))
        }

        evaluator = Engine(_inference)

        def eval_cv(engine, key2pred, key2refs):
            scorer = Cider(zh=zh)
            score, scores = scorer.compute_score(key2refs, key2pred)
            engine.state.metrics["score"] = score
            key2pred.clear()

        evaluator.add_event_handler(Events.EPOCH_COMPLETED, eval_cv, key2pred,
                                    cv_key2refs)

        for name, metric in metrics.items():
            metric.attach(evaluator, name)

        trainer.add_event_handler(Events.EPOCH_COMPLETED,
                                  train_util.log_results, evaluator, cv_loader,
                                  logger.info, ["loss", "score"])

        evaluator.add_event_handler(
            Events.EPOCH_COMPLETED, train_util.save_model_on_improved,
            crtrn_imprvd, "score", {
                "model": model.state_dict(),
                "config": config_parameters,
                "scaler": info["scaler"]
            }, os.path.join(outputdir, "saved.pth"))

        scheduler = getattr(torch.optim.lr_scheduler,
                            config_parameters["scheduler"])(
                                optimizer,
                                **config_parameters["scheduler_args"])
        evaluator.add_event_handler(Events.EPOCH_COMPLETED,
                                    train_util.update_lr, scheduler, "score")

        evaluator.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler,
                                    {
                                        "model": model,
                                    })

        trainer.run(train_loader, max_epochs=config_parameters["epochs"])
        return outputdir
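With a `score_function`, `ModelCheckpoint` keeps the `n_saved` highest-scoring checkpoints rather than the most recent ones, which is how the run above retains only its best validation score. A minimal sketch under the same ignite API (the score values and paths here are made up):

import torch.nn as nn
from ignite.engine import Engine, Events
from ignite.handlers import ModelCheckpoint

model = nn.Linear(4, 2)

def _eval_step(engine, batch):
    engine.state.metrics["score"] = batch  # stand-in for a real CIDEr score

evaluator = Engine(_eval_step)
handler = ModelCheckpoint("/tmp/ckpt", "run", n_saved=1, require_empty=False,
                          score_function=lambda e: e.state.metrics["score"],
                          score_name="score")
evaluator.add_event_handler(Events.EPOCH_COMPLETED, handler, {"model": model})
evaluator.run([0.41], max_epochs=1)  # only the best-scoring checkpoint survives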
Example #25
def main():
    # region Setup
    conf = parse_args()
    setup_seeds(conf.session.seed)
    tb_logger, tb_img_logger, json_logger = setup_all_loggers(conf)
    logger.info("Parsed configuration:\n" +
                pyaml.dump(OmegaConf.to_container(conf),
                           safe=True,
                           sort_dicts=False,
                           force_embed=True))

    # region Predicate classification engines
    datasets, dataset_metadata = build_datasets(conf.dataset)
    dataloaders = build_dataloaders(conf, datasets)

    model = build_model(conf.model,
                        dataset_metadata["train"]).to(conf.session.device)
    criterion = PredicateClassificationCriterion(conf.losses)

    pred_class_trainer = Trainer(pred_class_training_step, conf)
    pred_class_trainer.model = model
    pred_class_trainer.criterion = criterion
    pred_class_trainer.optimizer, scheduler = build_optimizer_and_scheduler(
        conf.optimizer, pred_class_trainer.model)

    pred_class_validator = Validator(pred_class_validation_step, conf)
    pred_class_validator.model = model
    pred_class_validator.criterion = criterion

    pred_class_tester = Validator(pred_class_validation_step, conf)
    pred_class_tester.model = model
    pred_class_tester.criterion = criterion
    # endregion

    if "resume" in conf:
        checkpoint = Path(conf.resume.checkpoint).expanduser().resolve()
        logger.debug(f"Resuming checkpoint from {checkpoint}")
        Checkpoint.load_objects(
            {
                "model": pred_class_trainer.model,
                "optimizer": pred_class_trainer.optimizer,
                "scheduler": scheduler,
                "trainer": pred_class_trainer,
            },
            checkpoint=torch.load(checkpoint,
                                  map_location=conf.session.device),
        )
        logger.info(f"Resumed from {checkpoint}, "
                    f"epoch {pred_class_trainer.state.epoch}, "
                    f"samples {pred_class_trainer.global_step()}")
    # endregion

    # region Predicate classification training callbacks
    def increment_samples(trainer: Trainer):
        images = trainer.state.batch[0]
        trainer.state.samples += len(images)

    pred_class_trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                         increment_samples)

    ProgressBar(persist=True, desc="Pred class train").attach(
        pred_class_trainer, output_transform=itemgetter("losses"))

    tb_logger.attach(
        pred_class_trainer,
        OptimizerParamsHandler(
            pred_class_trainer.optimizer,
            param_name="lr",
            tag="z",
            global_step_transform=pred_class_trainer.global_step,
        ),
        Events.EPOCH_STARTED,
    )

    pred_class_trainer.add_event_handler(
        Events.ITERATION_COMPLETED,
        PredicateClassificationMeanAveragePrecisionBatch())
    pred_class_trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                         RecallAtBatch(sizes=(5, 10)))

    tb_logger.attach(
        pred_class_trainer,
        OutputHandler(
            "train",
            output_transform=lambda o: {
                **o["losses"],
                "pc/mAP": o["pc/mAP"].mean().item(),
                **{k: r.mean().item()
                   for k, r in o["recalls"].items()},
            },
            global_step_transform=pred_class_trainer.global_step,
        ),
        Events.ITERATION_COMPLETED,
    )

    pred_class_trainer.add_event_handler(
        Events.EPOCH_COMPLETED,
        log_metrics,
        "Predicate classification training",
        "train",
        json_logger=None,
        tb_logger=tb_logger,
        global_step_fn=pred_class_trainer.global_step,
    )
    pred_class_trainer.add_event_handler(
        Events.EPOCH_COMPLETED,
        PredicateClassificationLogger(
            grid=(2, 3),
            tag="train",
            logger=tb_img_logger.writer,
            metadata=dataset_metadata["train"],
            global_step_fn=pred_class_trainer.global_step,
        ),
    )
    tb_logger.attach(
        pred_class_trainer,
        EpochHandler(
            pred_class_trainer,
            tag="z",
            global_step_transform=pred_class_trainer.global_step,
        ),
        Events.EPOCH_COMPLETED,
    )

    pred_class_trainer.add_event_handler(
        Events.EPOCH_COMPLETED,
        lambda _: pred_class_validator.run(dataloaders["val"]))
    # endregion

    # region Predicate classification validation callbacks
    ProgressBar(persist=True,
                desc="Pred class val").attach(pred_class_validator)

    if conf.losses["bce"]["weight"] > 0:
        Average(output_transform=lambda o: o["losses"]["loss/bce"]).attach(
            pred_class_validator, "loss/bce")
    if conf.losses["rank"]["weight"] > 0:
        Average(output_transform=lambda o: o["losses"]["loss/rank"]).attach(
            pred_class_validator, "loss/rank")
    Average(output_transform=lambda o: o["losses"]["loss/total"]).attach(
        pred_class_validator, "loss/total")

    PredicateClassificationMeanAveragePrecisionEpoch(
        itemgetter("target", "output")).attach(pred_class_validator, "pc/mAP")
    RecallAtEpoch((5, 10),
                  itemgetter("target",
                             "output")).attach(pred_class_validator,
                                               "pc/recall_at")

    pred_class_validator.add_event_handler(
        Events.EPOCH_COMPLETED,
        lambda val_engine: scheduler.step(
            val_engine.state.metrics["loss/total"]),
    )
    pred_class_validator.add_event_handler(
        Events.EPOCH_COMPLETED,
        log_metrics,
        "Predicate classification validation",
        "val",
        json_logger,
        tb_logger,
        pred_class_trainer.global_step,
    )
    pred_class_validator.add_event_handler(
        Events.EPOCH_COMPLETED,
        PredicateClassificationLogger(
            grid=(2, 3),
            tag="val",
            logger=tb_img_logger.writer,
            metadata=dataset_metadata["val"],
            global_step_fn=pred_class_trainer.global_step,
        ),
    )
    pred_class_validator.add_event_handler(
        Events.COMPLETED,
        EarlyStopping(
            patience=conf.session.early_stopping.patience,
            score_function=lambda val_engine: -val_engine.state.metrics[
                "loss/total"],
            trainer=pred_class_trainer,
        ),
    )
    pred_class_validator.add_event_handler(
        Events.COMPLETED,
        Checkpoint(
            {
                "model": pred_class_trainer.model,
                "optimizer": pred_class_trainer.optimizer,
                "scheduler": scheduler,
                "trainer": pred_class_trainer,
            },
            DiskSaver(
                Path(conf.checkpoint.folder).expanduser().resolve() /
                conf.fullname),
            score_function=lambda val_engine: val_engine.state.metrics[
                "pc/recall_at_5"],
            score_name="pc_recall_at_5",
            n_saved=conf.checkpoint.keep,
            global_step_transform=pred_class_trainer.global_step,
        ),
    )
    # endregion

    if "test" in conf.dataset:
        # region Predicate classification testing callbacks
        if conf.losses["bce"]["weight"] > 0:
            Average(
                output_transform=lambda o: o["losses"]["loss/bce"],
                device=conf.session.device,
            ).attach(pred_class_tester, "loss/bce")
        if conf.losses["rank"]["weight"] > 0:
            Average(
                output_transform=lambda o: o["losses"]["loss/rank"],
                device=conf.session.device,
            ).attach(pred_class_tester, "loss/rank")
        Average(
            output_transform=lambda o: o["losses"]["loss/total"],
            device=conf.session.device,
        ).attach(pred_class_tester, "loss/total")

        PredicateClassificationMeanAveragePrecisionEpoch(
            itemgetter("target", "output")).attach(pred_class_tester, "pc/mAP")
        RecallAtEpoch((5, 10),
                      itemgetter("target",
                                 "output")).attach(pred_class_tester,
                                                   "pc/recall_at")

        ProgressBar(persist=True,
                    desc="Pred class test").attach(pred_class_tester)

        pred_class_tester.add_event_handler(
            Events.EPOCH_COMPLETED,
            log_metrics,
            "Predicate classification test",
            "test",
            json_logger,
            tb_logger,
            pred_class_trainer.global_step,
        )
        pred_class_tester.add_event_handler(
            Events.EPOCH_COMPLETED,
            PredicateClassificationLogger(
                grid=(2, 3),
                tag="test",
                logger=tb_img_logger.writer,
                metadata=dataset_metadata["test"],
                global_step_fn=pred_class_trainer.global_step,
            ),
        )
        # endregion

    # region Run
    log_effective_config(conf, pred_class_trainer, tb_logger)
    if not ("resume" in conf and conf.resume.test_only):
        max_epochs = conf.session.max_epochs
        if "resume" in conf:
            max_epochs += pred_class_trainer.state.epoch
        pred_class_trainer.run(
            dataloaders["train"],
            max_epochs=max_epochs,
            seed=conf.session.seed,
            epoch_length=len(dataloaders["train"]),
        )

    if "test" in conf.dataset:
        pred_class_tester.run(dataloaders["test"])

    add_session_end(tb_logger.writer, "SUCCESS")
    tb_logger.close()
    tb_img_logger.close()
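The resume block near the top of this example relies on `Checkpoint.load_objects` restoring every object from the same mapping that produced the checkpoint. A minimal sketch (the checkpoint file name is hypothetical):

import torch
import torch.nn as nn
from ignite.engine import Engine
from ignite.handlers import Checkpoint

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
trainer = Engine(lambda engine, batch: None)

# keys must match those used when the checkpoint was saved
to_load = {"model": model, "optimizer": optimizer, "trainer": trainer}
ckpt = torch.load("checkpoint_1234.pt", map_location="cpu")  # hypothetical file
Checkpoint.load_objects(to_load=to_load, checkpoint=ckpt)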
Example #26
def run(config, logger):
    
    plx_logger = PolyaxonLogger()

    set_seed(config.seed)

    plx_logger.log_params(**{
        "seed": config.seed,
        "batch_size": config.batch_size,

        "pytorch version": torch.__version__,
        "ignite version": ignite.__version__,
        "cuda version": torch.version.cuda
    })

    device = config.device
    non_blocking = config.non_blocking
    prepare_batch = config.prepare_batch

    def stats_collect_function(engine, batch):

        x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)

        y_ohe = to_onehot(y.reshape(-1), config.num_classes)
        
        class_distrib = y_ohe.mean(dim=0).cpu()
        class_presence = (class_distrib > 1e-3).cpu().float()
        num_classes = (class_distrib > 1e-3).sum().item() 

        engine.state.class_presence += class_presence
        engine.state.class_presence -= (1 - class_presence)

        return {
            "class_distrib": class_distrib,
            "class_presence": engine.state.class_presence,
            "num_classes": num_classes
        }

    stats_collector = Engine(stats_collect_function)
    ProgressBar(persist=True).attach(stats_collector)

    @stats_collector.on(Events.STARTED)
    def init_vars(engine):
        engine.state.class_presence = torch.zeros(config.num_classes)

    log_dir = get_outputs_path()
    if log_dir is None:
        log_dir = "output"

    tb_logger = TensorboardLogger(log_dir=log_dir)

    tb_handler = tb_output_handler(tag="training", output_transform=lambda x: x)
    tb_logger.attach(stats_collector,
                     log_handler=tb_handler,
                     event_name=Events.ITERATION_COMPLETED)

    stats_collector.run(config.train_loader, max_epochs=1)

    remove_handler(stats_collector, tb_handler, Events.ITERATION_COMPLETED)
    tb_logger.attach(stats_collector,
                     log_handler=tb_output_handler(tag="validation", output_transform=lambda x: x),
                     event_name=Events.ITERATION_COMPLETED)

    stats_collector.run(config.val_loader, max_epochs=1)
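The per-batch statistics above boil down to one-hot encoding the labels and averaging over the batch dimension. A minimal sketch of that computation:

import torch
from ignite.utils import to_onehot

y = torch.tensor([0, 2, 2, 1])
y_ohe = to_onehot(y, num_classes=4)             # shape (4, 4)
class_distrib = y_ohe.float().mean(dim=0)       # tensor([0.25, 0.25, 0.50, 0.00])
class_presence = (class_distrib > 1e-3).float()
num_present = int(class_presence.sum().item())  # 3 classes appear in this batch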
Example #27
def _setup_common_training_handlers(
    trainer: Engine,
    to_save: Optional[Mapping] = None,
    save_every_iters: int = 1000,
    output_path: Optional[str] = None,
    lr_scheduler: Optional[Union[ParamScheduler, _LRScheduler]] = None,
    with_gpu_stats: bool = False,
    output_names: Optional[Iterable[str]] = None,
    with_pbars: bool = True,
    with_pbar_on_iters: bool = True,
    log_every_iters: int = 100,
    stop_on_nan: bool = True,
    clear_cuda_cache: bool = True,
    save_handler: Optional[Union[Callable, BaseSaveHandler]] = None,
    **kwargs: Any,
) -> None:
    if output_path is not None and save_handler is not None:
        raise ValueError(
            "Arguments output_path and save_handler are mutually exclusive. Please, define only one of them"
        )

    if stop_on_nan:
        trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())

    if lr_scheduler is not None:
        if isinstance(lr_scheduler, torch.optim.lr_scheduler._LRScheduler):
            trainer.add_event_handler(
                Events.ITERATION_COMPLETED,
                lambda engine: cast(_LRScheduler, lr_scheduler).step())
        elif isinstance(lr_scheduler, LRScheduler):
            trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler)
        else:
            trainer.add_event_handler(Events.ITERATION_STARTED, lr_scheduler)

    if torch.cuda.is_available() and clear_cuda_cache:
        trainer.add_event_handler(Events.EPOCH_COMPLETED, empty_cuda_cache)

    if to_save is not None:

        if output_path is None and save_handler is None:
            raise ValueError(
                "If to_save argument is provided then output_path or save_handler arguments should be also defined"
            )
        if output_path is not None:
            save_handler = DiskSaver(dirname=output_path, require_empty=False)

        checkpoint_handler = Checkpoint(to_save,
                                        cast(Union[Callable, BaseSaveHandler],
                                             save_handler),
                                        filename_prefix="training",
                                        **kwargs)
        trainer.add_event_handler(
            Events.ITERATION_COMPLETED(every=save_every_iters),
            checkpoint_handler)

    if with_gpu_stats:
        GpuInfo().attach(
            trainer,
            name="gpu",
            event_name=Events.ITERATION_COMPLETED(
                every=log_every_iters)  # type: ignore[arg-type]
        )

    if output_names is not None:

        def output_transform(x: Any, index: int, name: str) -> Any:
            if isinstance(x, Mapping):
                return x[name]
            elif isinstance(x, Sequence):
                return x[index]
            elif isinstance(x, (torch.Tensor, numbers.Number)):
                return x
            else:
                raise TypeError(
                    "Unhandled type of update_function's output. "
                    f"It should be either a mapping or a sequence, but {type(x)} was given"
                )

        for i, n in enumerate(output_names):
            RunningAverage(output_transform=partial(output_transform,
                                                    index=i,
                                                    name=n),
                           epoch_bound=False).attach(trainer, n)

    if with_pbars:
        if with_pbar_on_iters:
            ProgressBar(persist=False).attach(
                trainer,
                metric_names="all",
                event_name=Events.ITERATION_COMPLETED(every=log_every_iters))

        ProgressBar(persist=True,
                    bar_format="").attach(trainer,
                                          event_name=Events.EPOCH_STARTED,
                                          closing_event_name=Events.COMPLETED)
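`epoch_bound=False`, used for the output running averages above, means the average accumulates across epoch boundaries instead of resetting at every EPOCH_STARTED. A minimal sketch, assuming an ignite version that still accepts the parameter:

from ignite.engine import Engine
from ignite.metrics import RunningAverage

trainer = Engine(lambda engine, batch: batch)  # the batch itself is the "output"
RunningAverage(output_transform=lambda x: x,
               epoch_bound=False).attach(trainer, "avg")
trainer.run([1.0, 2.0, 3.0], max_epochs=2)
print(trainer.state.metrics["avg"])  # smoothed over all six iterations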
Example #28
def run(output_path, config):
    device = "cuda"

    local_rank = config['local_rank']
    distributed = backend is not None  # `backend` comes from the enclosing module scope
    if distributed:
        torch.cuda.set_device(local_rank)
        device = "cuda"
    rank = dist.get_rank() if distributed else 0

    # Rescale batch_size and num_workers
    ngpus_per_node = torch.cuda.device_count()
    ngpus = dist.get_world_size() if distributed else 1
    batch_size = config['batch_size'] // ngpus
    num_workers = int(
        (config['num_workers'] + ngpus_per_node - 1) / ngpus_per_node)

    train_labelled_loader, test_loader = \
        get_train_test_loaders(path=config['data_path'],
                               batch_size=batch_size,
                               distributed=distributed,
                               num_workers=num_workers)

    model = get_model(config['model'])
    model = model.to(device)

    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[
                local_rank,
            ], output_device=local_rank)

    optimizer = optim.SGD(model.parameters(),
                          lr=config['learning_rate'],
                          momentum=config['momentum'],
                          weight_decay=config['weight_decay'],
                          nesterov=True)

    criterion = nn.CrossEntropyLoss().to(device)

    le = len(train_labelled_loader)
    milestones_values = [(0, 0.0),
                         (le * config['num_warmup_epochs'],
                          config['learning_rate']),
                         (le * config['num_epochs'], 0.0)]
    lr_scheduler = PiecewiseLinear(optimizer,
                                   param_name="lr",
                                   milestones_values=milestones_values)

    def _prepare_batch(batch, device, non_blocking):
        x, y = batch
        return (convert_tensor(x, device=device, non_blocking=non_blocking),
                convert_tensor(y, device=device, non_blocking=non_blocking))

    def process_function(engine, labelled_batch):

        x, y = _prepare_batch(labelled_batch, device=device, non_blocking=True)

        model.train()
        # Supervised part
        y_pred = model(x)
        loss = criterion(y_pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        return {
            'batch loss': loss.item(),
        }

    trainer = Engine(process_function)

    if not hasattr(lr_scheduler, "step"):
        trainer.add_event_handler(Events.ITERATION_STARTED, lr_scheduler)
    else:
        trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                  lambda engine: lr_scheduler.step())

    metric_names = [
        'batch loss',
    ]

    def output_transform(x, name):
        return x[name]

    for n in metric_names:
        # We compute running average values on the output (batch loss) across all devices
        RunningAverage(output_transform=partial(output_transform, name=n),
                       epoch_bound=False,
                       device=device).attach(trainer, n)

    if rank == 0:
        checkpoint_handler = ModelCheckpoint(dirname=output_path,
                                             filename_prefix="checkpoint")
        trainer.add_event_handler(Events.ITERATION_COMPLETED(every=1000),
                                  checkpoint_handler, {
                                      'model': model,
                                      'optimizer': optimizer
                                  })

        ProgressBar(persist=True,
                    bar_format="").attach(trainer,
                                          event_name=Events.EPOCH_STARTED,
                                          closing_event_name=Events.COMPLETED)
        if config['display_iters']:
            ProgressBar(persist=False,
                        bar_format="").attach(trainer,
                                              metric_names=metric_names)

        tb_logger = TensorboardLogger(log_dir=output_path)
        tb_logger.attach(trainer,
                         log_handler=tbOutputHandler(
                             tag="train", metric_names=metric_names),
                         event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer,
                         log_handler=tbOptimizerParamsHandler(optimizer,
                                                              param_name="lr"),
                         event_name=Events.ITERATION_STARTED)

    metrics = {
        "accuracy": Accuracy(device=device if distributed else None),
        "loss": Loss(criterion, device=device if distributed else None)
    }

    evaluator = create_supervised_evaluator(model,
                                            metrics=metrics,
                                            device=device,
                                            non_blocking=True)
    train_evaluator = create_supervised_evaluator(model,
                                                  metrics=metrics,
                                                  device=device,
                                                  non_blocking=True)

    def run_validation(engine):
        torch.cuda.synchronize()
        train_evaluator.run(train_labelled_loader)
        evaluator.run(test_loader)

    trainer.add_event_handler(Events.EPOCH_STARTED(every=3), run_validation)
    trainer.add_event_handler(Events.COMPLETED, run_validation)

    if rank == 0:
        if config['display_iters']:
            ProgressBar(persist=False,
                        desc="Train evaluation").attach(train_evaluator)
            ProgressBar(persist=False,
                        desc="Test evaluation").attach(evaluator)

        tb_logger.attach(train_evaluator,
                         log_handler=tbOutputHandler(tag="train",
                                                     metric_names=list(
                                                         metrics.keys()),
                                                     another_engine=trainer),
                         event_name=Events.COMPLETED)

        tb_logger.attach(evaluator,
                         log_handler=tbOutputHandler(tag="test",
                                                     metric_names=list(
                                                         metrics.keys()),
                                                     another_engine=trainer),
                         event_name=Events.COMPLETED)

        # Store the best model
        def default_score_fn(engine):
            score = engine.state.metrics['accuracy']
            return score

        # config is indexed like a dict elsewhere in this function, so check
        # membership rather than attribute presence
        score_function = config.get("score_function", default_score_fn)

        best_model_handler = ModelCheckpoint(
            dirname=output_path,
            filename_prefix="best",
            n_saved=3,
            global_step_transform=global_step_from_engine(trainer),
            score_name="val_accuracy",
            score_function=score_function)
        evaluator.add_event_handler(Events.COMPLETED, best_model_handler, {
            'model': model,
        })

    trainer.run(train_labelled_loader, max_epochs=config['num_epochs'])

    if rank == 0:
        tb_logger.close()
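The `PiecewiseLinear` schedule above encodes linear warmup followed by linear decay purely through its milestones. A minimal sketch with made-up milestone numbers:

import torch
from ignite.contrib.handlers import PiecewiseLinear
from ignite.engine import Engine, Events

optimizer = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=0.0)
# lr ramps 0.0 -> 0.1 over the first 100 iterations, then back to 0.0 by 1000
scheduler = PiecewiseLinear(optimizer, param_name="lr",
                            milestones_values=[(0, 0.0), (100, 0.1), (1000, 0.0)])
trainer = Engine(lambda engine, batch: None)
trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)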
Example #29
def main(batch_size, epochs, length_scale, centroid_size, model_output_size,
         learning_rate, l_gradient_penalty, gamma, weight_decay, final_model,
         input_dep_ls, use_grad_norm):

    # Dataset prep
    ds = all_datasets["CIFAR10"]()
    input_size, num_classes, dataset, test_dataset = ds

    # Split up training set
    idx = list(range(len(dataset)))
    random.shuffle(idx)

    if final_model:
        train_dataset = dataset
        val_dataset = test_dataset
    else:
        train_size = int(len(dataset) * 0.8)
        train_dataset = torch.utils.data.Subset(dataset, idx[:train_size])
        val_dataset = torch.utils.data.Subset(dataset, idx[train_size:])

        # NOTE: assigning .transform on a Subset does not propagate to the
        # wrapped dataset; the underlying dataset's transform still applies.
        val_dataset.transform = test_dataset.transform
    kwargs = {"num_workers": 4, "pin_memory": True}
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               drop_last=True,
                                               **kwargs)

    # Model
    global model
    model = ResNet_DUQ(input_size, num_classes, centroid_size,
                       model_output_size, length_scale, gamma)

    model = model.cuda()
    # Optionally resume: model.load_state_dict(torch.load("DUQ_CIFAR_75.pt"))

    # Optimiser
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=learning_rate,
                                momentum=0.9,
                                weight_decay=weight_decay)

    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=[25, 50, 75],
                                                     gamma=0.2)

    def bce_loss_fn(y_pred, y):
        bce = F.binary_cross_entropy(y_pred, y, reduction="sum").div(
            num_classes * y_pred.shape[0])
        return bce

    def output_transform_bce(output):
        y_pred, y, x = output

        y = F.one_hot(y, num_classes).float()

        return y_pred, y

    def output_transform_acc(output):
        y_pred, y, x = output

        return y_pred, y

    def output_transform_gp(output):
        y_pred, y, x = output

        return x, y_pred

    def calc_gradients_input(x, y_pred):
        gradients = torch.autograd.grad(
            outputs=y_pred,
            inputs=x,
            grad_outputs=torch.ones_like(y_pred),
            create_graph=True,
        )[0]

        gradients = gradients.flatten(start_dim=1)

        return gradients

    def calc_gradient_penalty(x, y_pred):
        gradients = calc_gradients_input(x, y_pred)

        # L2 norm
        grad_norm = gradients.norm(2, dim=1)

        # Two sided penalty
        gradient_penalty = ((grad_norm - 1)**2).mean()

        return gradient_penalty

    def step(engine, batch):
        model.train()

        optimizer.zero_grad()

        x, y = batch
        x, y = x.cuda(), y.cuda()

        if l_gradient_penalty > 0:
            x.requires_grad_(True)

        z, y_pred = model(x)
        y = F.one_hot(y, num_classes).float()

        loss = bce_loss_fn(y_pred, y)

        # add the gradient penalty only when enabled, avoiding the extra graph
        if l_gradient_penalty > 0:
            loss += l_gradient_penalty * calc_gradient_penalty(x, y_pred)

        if use_grad_norm:
            # gradient normalization: rescale so the penalty weight does not
            # inflate the overall loss magnitude
            loss /= (1 + l_gradient_penalty)

        loss.backward()
        optimizer.step()

        x.requires_grad_(False)

        with torch.no_grad():
            model.eval()
            model.update_embeddings(x, y)

        return loss.item()

    trainer = Engine(step)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_results(trainer):

        # log every 10 epochs and during the last 5 epochs
        if trainer.state.epoch % 10 == 0 or trainer.state.epoch > epochs - 5:

            # accuracy on the CIFAR test set, AUROC on the CIFAR vs. SVHN OOD task
            testacc, auroc_cifsv = get_cifar_svhn_ood(model)

            # accuracy and uncertainty AUROC on the CIFAR validation set
            val_acc, self_auroc = get_auroc_classification(val_dataset, model)

            print(f"Test Accuracy: {testacc}, AUROC: {auroc_cifsv}")
            print(
                f"AUROC - uncertainty: {self_auroc}, Val Accuracy : {val_acc}")

        scheduler.step()

        # save the model weights near the end of training
        if trainer.state.epoch == epochs - 1:
            torch.save(model.state_dict(), f"model_{trainer.state.epoch}.pt")

    pbar = ProgressBar(dynamic_ncols=True)
    pbar.attach(trainer)
    trainer.run(train_loader, max_epochs=epochs)

    testacc, auroc_cifsv = get_cifar_svhn_ood(model)
    val_acc, self_auroc = get_auroc_classification(val_dataset, model)

    return testacc, auroc_cifsv, val_acc, self_auroc
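The two-sided gradient penalty in this example hinges on `create_graph=True`, which makes the input gradients themselves differentiable so the penalty can be backpropagated. A minimal standalone sketch with a toy model standing in for the network:

import torch

x = torch.randn(8, 3, requires_grad=True)
y_pred = (x ** 2).sum(dim=1)                 # stand-in for model(x)
grads = torch.autograd.grad(outputs=y_pred, inputs=x,
                            grad_outputs=torch.ones_like(y_pred),
                            create_graph=True)[0]
grad_norm = grads.flatten(start_dim=1).norm(2, dim=1)
penalty = ((grad_norm - 1) ** 2).mean()      # two-sided: pushes norms toward 1
penalty.backward()                           # works because create_graph=True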
Example #30
    # the snippet begins mid-function; a per-epoch validation handler of this
    # shape is assumed
    @trainer.on(Events.EPOCH_COMPLETED)
    def validate(engine):
        epoch = engine.state.epoch
        evaluator.run(val_ld)
        val_wra_vle = round(evaluator.state.metrics['WRA'], 3)
        print(f"EPOCH:[{epoch}] VAL WRA:{val_wra_vle}")

    @trainer.on(Events.COMPLETED)
    def test(engine):
        print("TEST EVAL")
        evaluator.run(test_ld)
        test_wra_vle = round(evaluator.state.metrics["WRA"], 3)
        report = f"{RUN_NAME};{test_wra_vle}\n"
        with EVALUATION_RESULTS_FILE_PATH.open(mode='a') as f:
            f.writelines(report)
        print(f"TRAINING IS DONE FOR {RUN_NAME} RUN.")

    pbar = ProgressBar()

    checkpointer = ModelCheckpoint(
        CHECKPOINTS_RUN_DIR_PATH,
        filename_prefix=RUN_NAME.lower(),
        n_saved=None,
        score_function=lambda engine: round(engine.state.metrics['WRA'], 3),
        score_name='WRA',
        atomic=True,
        require_empty=True,
        create_dir=True,
        archived=False,
        global_step_transform=global_step_from_engine(trainer))
    nan_handler = TerminateOnNan()
    coslr = CosineAnnealingScheduler(opt,
                                     "lr",