    def __init__(self, trainer, res_dir='results', **kwargs):
        self.trainer = trainer
        self.start_datetime = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        self.res_dir = Path(res_dir) / self.start_datetime
        self.res_dir.mkdir(parents=True)

        metric_loss = Average()
        metric_loss.attach(self.trainer, 'loss')
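The fragment above assumes `datetime`, `Path`, and ignite's `Average` are already imported, and that `self.trainer` is an ignite `Engine`. A minimal self-contained sketch of the same idiom (the toy update function is an assumption, not part of the original class):

from datetime import datetime
from pathlib import Path

from ignite.engine import Engine
from ignite.metrics import Average

trainer = Engine(lambda engine, batch: float(batch))  # toy update function
start_datetime = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
res_dir = Path('results') / start_datetime
res_dir.mkdir(parents=True)

metric_loss = Average()
metric_loss.attach(trainer, 'loss')  # 'loss' shows up in trainer.state.metrics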
Example #2
def attach_pbar_and_metrics(trainer, evaluator):
    loss_metric = Average(output_transform=lambda output: output["loss"])
    accuracy_metric = Accuracy(
        output_transform=lambda output: (output["logit"], output["label"]))
    pbar = ProgressBar()
    loss_metric.attach(trainer, "loss")
    accuracy_metric.attach(trainer, "accuracy")
    accuracy_metric.attach(evaluator, "accuracy")
    pbar.attach(trainer)
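A usage sketch for `attach_pbar_and_metrics`, assuming the imports the function itself needs (`Average`, `Accuracy`, `ProgressBar`) are in scope and that the step functions return the dict keys the transforms above expect ("loss", "logit", "label"); the toy steps are illustrative:

import torch
from ignite.engine import Engine

def train_step(engine, batch):
    x, y = batch
    logit = torch.randn(x.shape[0], 10)  # stand-in for model(x)
    return {"loss": torch.rand(()), "logit": logit, "label": y}

def eval_step(engine, batch):
    x, y = batch
    return {"logit": torch.randn(x.shape[0], 10), "label": y}

trainer, evaluator = Engine(train_step), Engine(eval_step)
attach_pbar_and_metrics(trainer, evaluator)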
Example #3
def create_classification_trainer(
    model,
    optimizer,
    loss_fn,
    device=None,
    non_blocking=False,
    prepare_batch=_prepare_batch,
    output_transform=lambda x, y, y_pred, loss: loss.item()):  # noqa
    """
    Factory function for creating a trainer for supervised models.
    Args:
        model (`torch.nn.Module`): the model to train.
        optimizer (`torch.optim.Optimizer`): the optimizer to use.
        loss_fn (torch.nn loss function): the loss function to use.
        device (str, optional): device type specification (default: None).
            Applies to both model and batches.
        non_blocking (bool, optional): if True and this copy is between CPU and GPU, the copy may occur asynchronously
            with respect to the host. For other cases, this argument has no effect.
        prepare_batch (callable, optional): function that receives `batch`, `device`, `non_blocking` and outputs
            tuple of tensors `(batch_x, batch_y)`.
        output_transform (callable, optional): function that receives 'x', 'y', 'y_pred', 'loss' and returns value
            to be assigned to engine's state.output after each iteration. Default is returning `loss.item()`.
    Note: `engine.state.output` for this engine is defined by the `output_transform`
        parameter and is the loss of the processed batch by default.
    Returns:
        Engine: a trainer engine with supervised update function.
    """
    if device:
        model.to(device)

    def _update(engine, batch):
        model.train()
        optimizer.zero_grad()
        x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        loss.backward()
        optimizer.step()
        return output_transform(x, y, y_pred, loss)

    engine = Engine(_update)
    metric_loss = Average()
    metric_loss.attach(engine, 'loss')
    return engine
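A usage sketch for the factory, assuming ignite's default `_prepare_batch` is in scope (as the signature above expects) and using a toy model and dataset:

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

model = nn.Linear(4, 3)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()

trainer = create_classification_trainer(model, optimizer, loss_fn)
data = TensorDataset(torch.randn(32, 4), torch.randint(0, 3, (32,)))
state = trainer.run(DataLoader(data, batch_size=8), max_epochs=2)
print(state.metrics['loss'])  # epoch-averaged loss from the attached Average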
Example #4
def main(ctx, config_file, dataset_root, res_root_dir, debug, device,
         num_workers, **kwargs):
    with open(config_file) as stream:
        config = yaml.safe_load(stream)

    train_transforms = get_transforms(config['train_augment'])
    val_transforms = get_transforms(config['val_augment'])
    train_loader, val_loader = get_loaders(train_transforms=train_transforms,
                                           val_transforms=val_transforms,
                                           dataset_root=dataset_root,
                                           num_workers=num_workers,
                                           **config['dataset'])
    label_names = get_labels(train_loader)
    net, criterion = get_model(n_class=len(label_names), **config['model'])
    optimizer = get_optimizer(net, **config['optimizer'])

    trainer = create_supervised_trainer(net, optimizer, criterion, device,
                                        prepare_batch=prepare_batch)
    metric_loss = Average()
    metric_loss.attach(trainer, 'loss')
    metrics = get_metrics(label_names, config['evaluate'])
    metric_names = list(metrics.keys())
    evaluator = create_supervised_evaluator(net, metrics, device,
                                            prepare_batch=prepare_batch)

    @trainer.on(Events.EPOCH_COMPLETED)
    def compute_metrics(engine):
        evaluator.run(val_loader)

    res_dir = Path(res_root_dir) / config['dataset']['dataset_name']
    train_extend = TrainExtension(trainer, evaluator, res_dir)
    train_extend.print_metrics(metric_names)
    train_extend.set_progressbar()
    train_extend.schedule_lr(optimizer, **config['lr_schedule'])
    if not debug:
        train_extend.copy_configs(config_file)
        train_extend.set_tensorboard(metric_names)
        train_extend.save_model(net, **config['model_checkpoint'])
        train_extend.show_config_on_tensorboard(config)

    trainer.run(train_loader, max_epochs=config['epochs'])
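The `EPOCH_COMPLETED` hook above is the core pattern: re-run the evaluator after every training epoch. A stripped-down sketch with toy engines (all names illustrative):

from ignite.engine import Engine, Events

trainer = Engine(lambda engine, batch: batch)
evaluator = Engine(lambda engine, batch: batch)

@trainer.on(Events.EPOCH_COMPLETED)
def compute_metrics(engine):
    evaluator.run([10, 20])  # evaluate once per training epoch

trainer.run([1, 2, 3], max_epochs=2)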
Example #5
def create_supervised_trainer2(model,
                               optimizer,
                               loss_fn,
                               replay_buffer,
                               device=None,
                               non_blocking=False,
                               prepare_batch=_prepare_batch):
    "使用 ignite 包,不用 ignite 时,需要修改该函数"
    from ignite.engine.engine import Engine
    from ignite.metrics import Average

    if device:
        model.to(device)

    def _update(engine, batch):
        model.train()

        LogSumExpf = lambda x: LogSumExp(model(x))
        optimizer.zero_grad()
        x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
        x = x.detach()
        y_pred = model(x)
        loss_elf = loss_fn(y_pred, y)
        x_sample = Sample(LogSumExpf, x.shape[0], x.shape[1], replay_buffer,
                          device)
        replay_buffer.add(x_sample.cpu())
        loss_gen = -(LogSumExpf(x) - LogSumExpf(x_sample)).mean()
        loss = loss_elf + loss_gen
        loss.backward()
        optimizer.step()
        return {
            'loss': loss.item(),
            'loss_elf': loss_elf.item(),
            'loss_gen': loss_gen.item()
        }

    engine = Engine(_update)
    metric_loss = Average(output_transform=lambda output: output['loss'])
    metric_loss_elf = Average(
        output_transform=lambda output: output['loss_elf'])
    metric_loss_gen = Average(
        output_transform=lambda output: output['loss_gen'])
    metric_loss.attach(engine, "loss")
    metric_loss_elf.attach(engine, "loss_elf")
    metric_loss_gen.attach(engine, "loss_gen")

    return engine
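`LogSumExp`, `Sample`, and `replay_buffer` are external to this snippet. For the energy-based reading of the classifier that the generative loss implies, a common definition of `LogSumExp` would be the following (an assumption, not the original code):

import torch

def LogSumExp(logits):
    # log sum_y exp(f(x)[y]): the unnormalized log-density of x when the
    # classifier is read as an energy-based model (as in JEM-style training)
    return torch.logsumexp(logits, dim=1)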
Example #6
def main(hparams):
    results_dir = get_results_directory(hparams.output_dir)
    writer = SummaryWriter(log_dir=str(results_dir))

    ds = get_dataset(hparams.dataset, root=hparams.data_root)
    input_size, num_classes, train_dataset, test_dataset = ds

    hparams.seed = set_seed(hparams.seed)

    if hparams.n_inducing_points is None:
        hparams.n_inducing_points = num_classes

    print(f"Training with {hparams}")
    hparams.save(results_dir / "hparams.json")

    if hparams.ard:
        # Hardcoded to WRN output size
        ard = 640
    else:
        ard = None

    feature_extractor = WideResNet(
        spectral_normalization=hparams.spectral_normalization,
        dropout_rate=hparams.dropout_rate,
        coeff=hparams.coeff,
        n_power_iterations=hparams.n_power_iterations,
        batchnorm_momentum=hparams.batchnorm_momentum,
    )

    initial_inducing_points, initial_lengthscale = initial_values_for_GP(
        train_dataset, feature_extractor, hparams.n_inducing_points
    )

    gp = GP(
        num_outputs=num_classes,
        initial_lengthscale=initial_lengthscale,
        initial_inducing_points=initial_inducing_points,
        separate_inducing_points=hparams.separate_inducing_points,
        kernel=hparams.kernel,
        ard=ard,
        lengthscale_prior=hparams.lengthscale_prior,
    )

    model = DKL_GP(feature_extractor, gp)
    model = model.cuda()

    likelihood = SoftmaxLikelihood(num_classes=num_classes, mixing_weights=False)
    likelihood = likelihood.cuda()

    elbo_fn = VariationalELBO(likelihood, gp, num_data=len(train_dataset))

    parameters = [
        {"params": feature_extractor.parameters(), "lr": hparams.learning_rate},
        {"params": gp.parameters(), "lr": hparams.learning_rate},
        {"params": likelihood.parameters(), "lr": hparams.learning_rate},
    ]

    optimizer = torch.optim.SGD(
        parameters, momentum=0.9, weight_decay=hparams.weight_decay
    )

    milestones = [60, 120, 160]

    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=milestones, gamma=0.2
    )

    def step(engine, batch):
        model.train()
        likelihood.train()

        optimizer.zero_grad()

        x, y = batch
        x, y = x.cuda(), y.cuda()

        y_pred = model(x)
        elbo = -elbo_fn(y_pred, y)

        elbo.backward()
        optimizer.step()

        return elbo.item()

    def eval_step(engine, batch):
        model.eval()
        likelihood.eval()

        x, y = batch
        x, y = x.cuda(), y.cuda()

        with torch.no_grad():
            y_pred = model(x)

        return y_pred, y

    trainer = Engine(step)
    evaluator = Engine(eval_step)

    metric = Average()
    metric.attach(trainer, "elbo")

    def output_transform(output):
        y_pred, y = output

        # Sample softmax values independently for classification at test time
        y_pred = y_pred.to_data_independent_dist()

        # The mean here is over likelihood samples
        y_pred = likelihood(y_pred).probs.mean(0)

        return y_pred, y

    metric = Accuracy(output_transform=output_transform)
    metric.attach(evaluator, "accuracy")

    metric = Loss(lambda y_pred, y: -elbo_fn(y_pred, y))
    metric.attach(evaluator, "elbo")

    kwargs = {"num_workers": 4, "pin_memory": True}

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=hparams.batch_size,
        shuffle=True,
        drop_last=True,
        **kwargs,
    )

    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=512, shuffle=False, **kwargs
    )

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_results(trainer):
        metrics = trainer.state.metrics
        elbo = metrics["elbo"]

        print(f"Train - Epoch: {trainer.state.epoch} ELBO: {elbo:.2f} ")
        writer.add_scalar("Likelihood/train", elbo, trainer.state.epoch)

        if hparams.spectral_normalization:
            for name, layer in model.feature_extractor.named_modules():
                if isinstance(layer, torch.nn.Conv2d):
                    writer.add_scalar(
                        f"sigma/{name}", layer.weight_sigma, trainer.state.epoch
                    )

        if not hparams.ard:
            # Otherwise it's too much to submit to tensorboard
            length_scales = model.gp.covar_module.base_kernel.lengthscale.squeeze()
            for i in range(length_scales.shape[0]):
                writer.add_scalar(
                    f"length_scale/{i}", length_scales[i], trainer.state.epoch
                )

        if trainer.state.epoch > 150 and trainer.state.epoch % 5 == 0:
            _, auroc, aupr = get_ood_metrics(
                hparams.dataset, "SVHN", model, likelihood, hparams.data_root
            )
            print(f"OoD Metrics - AUROC: {auroc}, AUPR: {aupr}")
            writer.add_scalar("OoD/auroc", auroc, trainer.state.epoch)
            writer.add_scalar("OoD/auprc", aupr, trainer.state.epoch)

        evaluator.run(test_loader)
        metrics = evaluator.state.metrics
        acc = metrics["accuracy"]
        elbo = metrics["elbo"]

        print(
            f"Test - Epoch: {trainer.state.epoch} "
            f"Acc: {acc:.4f} "
            f"ELBO: {elbo:.2f} "
        )

        writer.add_scalar("Likelihood/test", elbo, trainer.state.epoch)
        writer.add_scalar("Accuracy/test", acc, trainer.state.epoch)

        scheduler.step()

    pbar = ProgressBar(dynamic_ncols=True)
    pbar.attach(trainer)

    trainer.run(train_loader, max_epochs=200)

    # Done training - time to evaluate
    results = {}

    evaluator.run(train_loader)
    train_acc = evaluator.state.metrics["accuracy"]
    train_elbo = evaluator.state.metrics["elbo"]
    results["train_accuracy"] = train_acc
    results["train_elbo"] = train_elbo

    evaluator.run(test_loader)
    test_acc = evaluator.state.metrics["accuracy"]
    test_elbo = evaluator.state.metrics["elbo"]
    results["test_accuracy"] = test_acc
    results["test_elbo"] = test_elbo

    _, auroc, aupr = get_ood_metrics(
        hparams.dataset, "SVHN", model, likelihood, hparams.data_root
    )
    results["auroc_ood_svhn"] = auroc
    results["aupr_ood_svhn"] = aupr

    print(f"Test - Accuracy {results['test_accuracy']:.4f}")

    results_json = json.dumps(results, indent=4, sort_keys=True)
    (results_dir / "results.json").write_text(results_json)

    torch.save(model.state_dict(), results_dir / "model.pt")
    torch.save(likelihood.state_dict(), results_dir / "likelihood.pt")

    writer.close()
Example #7
def get_engine():
    engine = Engine(sum_data)
    average = Average()
    average.attach(engine, "average")
    return engine
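`sum_data` is assumed by `get_engine`; a plausible stand-in and a quick run showing what the attached `Average` reports:

from ignite.engine import Engine
from ignite.metrics import Average

def sum_data(engine, batch):
    return sum(batch)  # engine output = per-batch sum

engine = get_engine()
state = engine.run([[1, 2], [3, 4]], max_epochs=1)
print(state.metrics["average"])  # (3 + 7) / 2 = 5.0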
Example #8
def train(model,
          train_loader,
          eval_loaders,
          optimizer,
          loss_fn,
          n_it_max,
          patience,
          split_names,
          select_metric='Val accuracy_0',
          select_mode='max',
          viz=None,
          device='cpu',
          lr_scheduler=None,
          name=None,
          log_steps=None,
          log_epoch=False,
          _run=None,
          prepare_batch=_prepare_batch,
          single_pass=False,
          n_ep_max=None):

    # print(model)

    if not log_steps and not log_epoch:
        logger.warning('/!\\ No logging during training /!\\')

    if log_steps is None:
        log_steps = []

    epoch_steps = len(train_loader)
    if log_epoch:
        log_steps.append(epoch_steps)

    if single_pass:
        max_epoch = 1
    elif n_ep_max is None:
        assert n_it_max is not None
        max_epoch = int(n_it_max / epoch_steps) + 1
    else:
        assert n_it_max is None
        max_epoch = n_ep_max

    all_metrics = defaultdict(dict)
    trainer = create_supervised_trainer(model,
                                        optimizer,
                                        loss_fn,
                                        device=device,
                                        prepare_batch=prepare_batch)

    if hasattr(model, 'new_epoch_hook'):
        trainer.add_event_handler(Events.EPOCH_STARTED, model.new_epoch_hook)
    if hasattr(model, 'new_iter_hook'):
        trainer.add_event_handler(Events.ITERATION_STARTED,
                                  model.new_iter_hook)

    trainer._logger.setLevel(logging.WARNING)

    # trainer output is in the format (x, y, y_pred, loss, optionals)
    train_loss = RunningAverage(output_transform=lambda out: out[3].item(),
                                epoch_bound=True)
    train_loss.attach(trainer, 'Trainer loss')
    if hasattr(model, 's'):
        met = Average(output_transform=lambda _: float('nan')
                      if model.s is None else model.s)
        met.attach(trainer, 'cur_s')
        trainer.add_event_handler(Events.ITERATION_COMPLETED, met.completed,
                                  'cur_s')

    if hasattr(model, 'arch_sampler') and model.arch_sampler.distrib_dim > 0:
        met = Average(output_transform=lambda _: float('nan')
                      if model.cur_split is None else model.cur_split)
        met.attach(trainer, 'Trainer split')
        trainer.add_event_handler(Events.ITERATION_COMPLETED, met.completed,
                                  'Trainer split')
        # trainer.add_event_handler(Events.EPOCH_STARTED, met.started)
        all_ent = Average(
            output_transform=lambda out: out[-1]['arch_entropy_avg'].item())
        all_ent.attach(trainer, 'Trainer all entropy')
        trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                  all_ent.completed, 'Trainer all entropy')
        train_ent = Average(
            output_transform=lambda out: out[-1]['arch_entropy_sample'].item())
        train_ent.attach(trainer, 'Trainer sampling entropy')
        trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                  train_ent.completed,
                                  'Trainer sampling entropy')
        trainer.add_event_handler(
            Events.EPOCH_COMPLETED, lambda engine: model.check_arch_freezing(
                ent=train_ent.compute(),
                epoch=engine.state.iteration / (epoch_steps * max_epoch)))

        def log_always(engine, name):
            val = engine.state.output[-1][name]
            all_metrics[name][engine.state.iteration /
                              epoch_steps] = val.mean().item()

        def log_always_dict(engine, name):
            for node, val in engine.state.output[-1][name].items():
                all_metrics['node {} {}'.format(
                    node, name)][engine.state.iteration /
                                 epoch_steps] = val.mean().item()

        trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                  log_always_dict,
                                  name='arch_grads')
        trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                  log_always_dict,
                                  name='arch_probas')
        trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                  log_always_dict,
                                  name='node_grads')
        trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                  log_always,
                                  name='task all_loss')
        trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                  log_always,
                                  name='arch all_loss')
        trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                  log_always,
                                  name='entropy all_loss')

    if n_it_max is not None:
        StopAfterIterations([n_it_max]).attach(trainer)
    # epoch_pbar = ProgressBar(bar_format='{l_bar}{bar}{r_bar}', desc=name,
    #                          persist=True, disable=not (_run or viz))
    # epoch_pbar.attach(trainer, metric_names=['Train loss'])
    #
    # training_pbar = ProgressBar(bar_format='{l_bar}{bar}{r_bar}', desc=name,
    #                             persist=True, disable=not (_run or viz))
    # training_pbar.attach(trainer, event_name=Events.EPOCH_COMPLETED,
    #                      closing_event_name=Events.COMPLETED)
    total_time = Timer(average=False)
    eval_time = Timer(average=False)
    eval_time.pause()
    data_time = Timer(average=False)
    forward_time = Timer(average=False)
    forward_time.attach(trainer,
                        start=Events.EPOCH_STARTED,
                        pause=Events.ITERATION_COMPLETED,
                        resume=Events.ITERATION_STARTED,
                        step=Events.ITERATION_COMPLETED)
    epoch_time = Timer(average=False)
    epoch_time.attach(trainer,
                      start=Events.EPOCH_STARTED,
                      pause=Events.EPOCH_COMPLETED,
                      resume=Events.EPOCH_STARTED,
                      step=Events.EPOCH_COMPLETED)

    def get_loss(y_pred, y):
        l = loss_fn(y_pred, y)
        if not torch.is_tensor(l):
            l, *l_details = l
        return l.mean()

    def get_member(x, n=0):
        if isinstance(x, (list, tuple)):
            return x[n]
        return x

    eval_metrics = {'loss': Loss(get_loss)}

    for i in range(model.n_out):
        out_trans = get_attr_transform(i)

        def extract_ys(out, out_trans=out_trans):
            # bind out_trans at definition time; a plain closure would
            # late-bind to the last loop iteration's transform
            x, y, y_pred, loss, _ = out
            return out_trans((y_pred, y))

        train_acc = Accuracy(extract_ys)
        train_acc.attach(trainer, 'Trainer accuracy_{}'.format(i))
        trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                  train_acc.completed,
                                  'Trainer accuracy_{}'.format(i))
        eval_metrics['accuracy_{}'.format(i)] = \
            Accuracy(output_transform=out_trans)
        # if isinstance(model, SSNWrapper):
        #     model.arch_sampler.entropy().mean()

    evaluator = create_supervised_evaluator(model,
                                            metrics=eval_metrics,
                                            device=device,
                                            prepare_batch=prepare_batch)
    last_iteration = 0
    patience_counter = 0

    best = {
        'value': float('inf') if select_mode == 'min' else -float('inf'),
        'iter': -1,
        'state_dict': None
    }

    def is_better(new, old):
        if select_mode == 'min':
            return new < old
        else:
            return new > old

    def log_results(evaluator, data_loader, iteration, split_name):
        evaluator.run(data_loader)
        metrics = evaluator.state.metrics

        log_metrics = {}

        for metric_name, metric_val in metrics.items():
            log_name = '{} {}'.format(split_name, metric_name)
            if viz:
                first = iteration == 0 and split_name == split_names[0]
                viz.line(
                    [metric_val],
                    X=[iteration],
                    win=metric_name,
                    name=log_name,
                    update=None if first else 'append',
                    opts={
                        'title': metric_name,
                        'showlegend': True,
                        'width': 500,
                        'xlabel': 'iterations'
                    })
                viz.line(
                    [metric_val],
                    X=[iteration / epoch_steps],
                    win='{}epoch'.format(metric_name),
                    name=log_name,
                    update=None if first else 'append',
                    opts={
                        'title': metric_name,
                        'showlegend': True,
                        'width': 500,
                        'xlabel': 'epoch'
                    })
            if _run:
                _run.log_scalar(log_name, metric_val, iteration)
            log_metrics[log_name] = metric_val
            all_metrics[log_name][iteration] = metric_val

        return log_metrics

    if lr_scheduler is not None:

        @trainer.on(Events.EPOCH_COMPLETED)
        def step(_):
            lr_scheduler.step()
            # logger.warning('current lr {:.5e}'.format(
            #     optimizer.param_groups[0]['lr']))

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_event(trainer):
        iteration = trainer.state.iteration if trainer.state else 0
        nonlocal last_iteration, patience_counter, best

        if not log_steps or not \
                (iteration in log_steps or iteration % log_steps[-1] == 0):
            return
        epoch_time.pause()
        eval_time.resume()
        all_metrics['training_epoch'][iteration] = iteration / epoch_steps
        all_metrics['training_iteration'][iteration] = iteration
        if hasattr(model, 'arch_sampler'):
            all_metrics['training_archs'][iteration] = \
                model.arch_sampler().squeeze().detach()
        # if hasattr(model, 'distrib_gen'):
        #     entropy = model.distrib_gen.entropy()
        #     all_metrics['entropy'][iteration] = entropy.mean().item()
        # if trainer.state and len(trainer.state.metrics) > 1:
        #     raise ValueError(trainer.state.metrics)
        all_metrics['data time'][iteration] = data_time.value()
        all_metrics['data time_ps'][iteration] = data_time.value() / max(
            data_time.step_count, 1.)
        all_metrics['forward time'][iteration] = forward_time.value()
        all_metrics['forward time_ps'][iteration] = forward_time.value() / max(
            forward_time.step_count, 1.)
        all_metrics['epoch time'][iteration] = epoch_time.value()
        all_metrics['epoch time_ps'][iteration] = epoch_time.value() / max(
            epoch_time.step_count, 1.)

        if trainer.state:
            # logger.warning(trainer.state.metrics)
            for metric, value in trainer.state.metrics.items():
                all_metrics[metric][iteration] = value
                if viz:
                    viz.line(
                        [value],
                        X=[iteration],
                        win=metric.split()[-1],
                        name=metric,
                        update=None if iteration == 0 else 'append',
                        opts={
                            'title': metric,
                            'showlegend': True,
                            'width': 500,
                            'xlabel': 'iterations'
                        })

        iter_this_step = iteration - last_iteration
        for d_loader, name in zip(eval_loaders, split_names):
            if name == 'Train':
                if iteration == 0:
                    all_metrics['Trainer loss'][iteration] = float('nan')
                    all_metrics['Trainer accuracy_0'][iteration] = float('nan')
                    if hasattr(model, 'arch_sampler'):
                        all_metrics['Trainer all entropy'][iteration] = float(
                            'nan')
                        all_metrics['Trainer sampling entropy'][
                            iteration] = float('nan')
                        # if hasattr(model, 'cur_split'):
                        all_metrics['Trainer split'][iteration] = float('nan')
                continue
            split_metrics = log_results(evaluator, d_loader, iteration, name)
            if select_metric not in split_metrics:
                continue
            if is_better(split_metrics[select_metric], best['value']):
                best['value'] = split_metrics[select_metric]
                best['iter'] = iteration
                best['state_dict'] = copy.deepcopy(model.state_dict())
                if patience > 0:
                    patience_counter = 0
            elif patience > 0:
                patience_counter += iter_this_step
                if patience_counter >= patience:
                    logger.info('#####')
                    logger.info('# Early stopping Run')
                    logger.info('#####')
                    trainer.terminate()
        last_iteration = iteration
        eval_time.pause()
        eval_time.step()
        all_metrics['eval time'][iteration] = eval_time.value()
        all_metrics['eval time_ps'][iteration] = eval_time.value(
        ) / eval_time.step_count
        all_metrics['total time'][iteration] = total_time.value()
        epoch_time.resume()

    log_event(trainer)

    #
    # @trainer.on(Events.EPOCH_COMPLETED)
    # def log_epoch(trainer):
    #     iteration = trainer.state.iteration if trainer.state else 0
    #     epoch = iteration/epoch_steps
    #     fw_t = forward_time.value()
    #     fw_t_ps = fw_t / forward_time.step_count
    #     d_t = data_time.value()
    #     d_t_ps = d_t / data_time.step_count
    #     e_t = epoch_time.value()
    #     e_t_ps = e_t / epoch_time.step_count
    #     ev_t = eval_time.value()
    #     ev_t_ps = ev_t / eval_time.step_count
    #     logger.warning('<{}> Epoch {}/{} finished (Forward: {:.3f}s({:.3f}), '
    #                    'data: {:.3f}s({:.3f}), epoch: {:.3f}s({:.3f}),'
    #                    ' Eval: {:.3f}s({:.3f}), Total: '
    #                    '{:.3f}s)'.format(type(model).__name__, epoch,
    #                                      max_epoch, fw_t, fw_t_ps, d_t, d_t_ps,
    #                                      e_t, e_t_ps, ev_t, ev_t_ps,
    #                                      total_time.value()))

    data_time.attach(trainer,
                     start=Events.STARTED,
                     pause=Events.ITERATION_STARTED,
                     resume=Events.ITERATION_COMPLETED,
                     step=Events.ITERATION_STARTED)

    if hasattr(model, 'iter_per_epoch'):
        model.iter_per_epoch = len(train_loader)
    trainer.run(train_loader, max_epochs=max_epoch)
    return trainer.state.iteration, all_metrics, best
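`train()` above pairs each `Average` with a manual `ITERATION_COMPLETED` handler so the metric is published to `engine.state.metrics` on every step rather than only at epoch end. The idiom in isolation (toy engine, illustrative names):

from ignite.engine import Engine, Events
from ignite.metrics import Average

engine = Engine(lambda e, batch: float(batch))
met = Average()
met.attach(engine, 'avg')  # by default, computed at EPOCH_COMPLETED
# Also publish the running value after every iteration, as train() does:
engine.add_event_handler(Events.ITERATION_COMPLETED, met.completed, 'avg')
engine.run([1.0, 2.0, 3.0], max_epochs=1)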
Example #9
    def run_training(self):

        ########## init wandb ###########
        print(GREEN + "*************** START TRAINING *******************")

        # fixme check if this works
        for (key, val) in self.config.items():
            print(GREEN + f"{key}: {val}")  # print to console
            wandb.config.update({key: val})  # update wandb config

        print(GREEN + "**************************************************" +
              ENDC)

        ########## checkpoints ##########
        if self.config["general"]["restart"]:
            mod_ckpt, op_ckpt = self._load_ckpt("reg_ckpt")
            # flow_ckpt, flow_op_ckpt = self._load_ckpt("flow_ckpt")

        else:
            mod_ckpt = op_ckpt = None

        dataset, image_transforms = get_dataset(self.config["data"])
        transforms = tt.Compose([tt.ToTensor()])
        train_dataset = dataset(transforms,
                                data_keys=self.data_keys,
                                mode="train",
                                label_transfer=True,
                                debug=self.config["general"]["debug"],
                                crop_app=True,
                                **self.config["data"])

        # if seq_length is pruned, use the min seq_length, so that the seq_length of test_dataset is lower than or equal to that of the train dataset
        # collect_len = train_dataset.seq_length
        # self.collect_recon_loss_seq = {
        #     k: np.zeros(shape=[k])
        #     for k in range(collect_len[0], collect_len[-1])
        # }
        # self.collect_count_seq_lens = np.zeros(shape=[collect_len[-1]])
        # # adapt sequence_length
        # self.config["data"]["seq_length"] = (
        #     min(self.config["data"]["seq_length"][0], train_dataset.seq_length[0]),
        #     min(self.config["data"]["seq_length"][1], train_dataset.seq_length[1]),
        # )

        train_sampler = RandomSampler(data_source=train_dataset)

        seq_sampler_train = SequenceSampler(
            train_dataset,
            sampler=train_sampler,
            batch_size=self.config["training"]["batch_size"],
            drop_last=True,
        )

        train_loader = DataLoader(
            train_dataset,
            num_workers=0 if self.config["general"]["debug"] else
            self.config["data"]["n_data_workers"],
            batch_sampler=seq_sampler_train,
        )

        # test data
        t_datakeys = [key for key in self.data_keys] + [
            "action",
            "sample_ids",
            "intrinsics",
            "intrinsics_paired",
            "extrinsics",
            "extrinsics_paired",
        ]
        test_dataset = dataset(image_transforms,
                               data_keys=t_datakeys,
                               mode="test",
                               debug=self.config["general"]["debug"],
                               label_transfer=True,
                               **self.config["data"])
        assert (test_dataset.action_id_to_action is not None)
        rand_sampler_test = RandomSampler(data_source=test_dataset)
        seq_sampler_test = SequenceSampler(
            test_dataset,
            rand_sampler_test,
            batch_size=self.config["training"]["batch_size"],
            drop_last=True,
        )
        test_loader = DataLoader(
            test_dataset,
            num_workers=0 if self.config["general"]["debug"] else
            self.config["data"]["n_data_workers"],
            batch_sampler=seq_sampler_test,
        )
        #
        rand_sampler_transfer = RandomSampler(data_source=test_dataset)
        seq_sampler_transfer = SequenceSampler(
            test_dataset,
            rand_sampler_transfer,
            batch_size=1,
            drop_last=True,
        )
        transfer_loader = DataLoader(
            test_dataset,
            batch_sampler=seq_sampler_transfer,
            num_workers=0 if self.config["general"]["debug"] else
            self.config["data"]["n_data_workers"],
        )
        #
        # compare_dataset = dataset(
        #     transforms,
        #     data_keys=t_datakeys,
        #     mode="train",
        #     label_transfer=True,
        #     debug=self.config["general"]["debug"],
        #     crop_app=True,
        #     **self.config["data"]
        # )

        ## Classifier action
        # n_actions = len(train_dataset.action_id_to_action)
        # classifier_action = Classifier_action(len(train_dataset.dim_to_use), n_actions, dropout=0, dim=512).to(self.device)
        # optimizer_classifier = Adam(classifier_action.parameters(), lr=0.0001, weight_decay=1e-4)
        # print("Number of parameters in classifier action", sum(p.numel() for p in classifier_action.parameters()))
        #
        n_actions = len(train_dataset.action_id_to_action)
        # # classifier_action2 = Classifier_action(len(train_dataset.dim_to_use), n_actions, dropout=0, dim=512).to(self.device)
        # classifier_action2 = Sequence_disc_michael([2, 1, 1, 1], len(train_dataset.dim_to_use), out_dim=n_actions).to(self.device)
        # optimizer_classifier2 = Adam(classifier_action2.parameters(), lr=0.0001, weight_decay=1e-5)
        # print("Number of parameters in classifier action", sum(p.numel() for p in classifier_action2.parameters()))

        # Classifier beta
        classifier_beta = Classifier_action_beta(512,
                                                 n_actions).to(self.device)
        optimizer_classifier_beta = Adam(classifier_beta.parameters(),
                                         lr=0.001)
        print("Number of parameters in classifier on beta",
              sum(p.numel() for p in classifier_beta.parameters()))
        # # Regressor
        # regressor = Regressor_fly(self.config["architecture"]["dim_hidden_b"], len(train_dataset.dim_to_use)).to(self.device)
        # optimizer_regressor = Adam(regressor.parameters(), lr=0.0001)
        # print("Number of parameters in regressor", sum(p.numel() for p in regressor.parameters()))

        ########## load network and optimizer ##########
        net = MTVAE(self.config["architecture"], len(train_dataset.dim_to_use),
                    self.device)

        print(
            "Number of parameters in VAE model",
            sum(p.numel() for p in net.parameters()),
        )
        if self.config["general"]["restart"]:
            if mod_ckpt is not None:
                print(BLUE + f"***** Initializing VAE from checkpoint! *****" +
                      ENDC)
                net.load_state_dict(mod_ckpt)
        net.to(self.device)

        optimizer = Adam(net.parameters(),
                         lr=self.config["training"]["lr_init"],
                         weight_decay=self.config["training"]["weight_decay"])
        wandb.watch(net, log="all", log_freq=len(train_loader))
        if self.config["general"]["restart"]:
            if op_ckpt is not None:
                optimizer.load_state_dict(op_ckpt)
        # scheduler = torch.optim.lr_scheduler.MultiStepLR(
        #     optimizer, milestones=self.config["training"]["tau"], gamma=self.config["training"]["gamma"]
        # )
        # rec_loss = nn.MSELoss(reduction="none")

        ############## DISCRIMINATOR ##############################
        n_kps = len(train_dataset.dim_to_use)

        # make gan loss weights
        print(
            f"len of train_dataset: {len(train_dataset)}, len of train_loader: {len(train_loader)}"
        )

        # 10 epochs of fine tuning
        total_steps = (self.config["training"]["n_epochs"] -
                       10) * len(train_loader)

        get_kl_weight = partial(
            linear_var,
            start_it=0,
            end_it=total_steps,
            start_val=1e-5,
            end_val=1,
            clip_min=0,
            clip_max=1,
        )

        def train_fn(engine, batch):
            net.train()
            # reference keypoints with label #1
            kps = batch["keypoints"].to(torch.float).to(self.device)

            # keypoints for cross label transfer, label #2
            kps_cross = batch["paired_keypoints"].to(torch.float).to(
                self.device)

            p_id = batch["paired_sample_ids"].to(torch.int)
            # reconstruct second sequence with inferred b

            labels = batch['action'][:, 0] - 2

            out_seq, mu, logstd, out_cycle = net(kps, kps_cross)

            ps = torch.randn_like(out_cycle, requires_grad=False)

            cycle_loss = torch.mean(torch.abs(out_cycle - ps))
            kps_loss = torch.mean(torch.abs(out_seq - kps[:, net.div:]))
            l_kl = kl_loss(mu, logstd)

            k_vel = self.config["training"]["k_vel"]
            vel_tgt = kps[:, net.div:net.div +
                          k_vel] - kps[:, net.div - 1:net.div + k_vel - 1]

            vel_pred = out_seq[:, :k_vel] - torch.cat(
                [kps[:, net.div - 1].unsqueeze(1), out_seq[:, :k_vel - 1]],
                dim=1)
            motion_loss = torch.mean(torch.abs(vel_tgt - vel_pred))

            kl_weight = get_kl_weight(engine.state.iteration)

            loss = kps_loss + kl_weight * l_kl + self.config["training"]["weight_motion"] * motion_loss \
                   + self.config["training"]["weight_cycle"] * cycle_loss

            #
            #
            if engine.state.epoch < self.config["training"]["n_epochs"] - 10:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            out_dict = {
                "loss": loss.detach().item(),
                "motion_loss": motion_loss.detach().item(),
                "rec_loss": kps_loss.detach().item(),
                "cycle_loss": cycle_loss.detach().item(),
                "kl_loss": l_kl.detach().item(),
                "kl_weight": kl_weight
            }
            #
            #
            # ## Train classifier on action
            # predict = classifier_action(seq_b)[0]
            # loss_classifier_action = nn.CrossEntropyLoss()(predict, labels.to(self.device))
            # optimizer_classifier.zero_grad()
            # loss_classifier_action.backward()
            # optimizer_classifier.step()
            # _, labels_pred = torch.max(nn.Sigmoid()(predict), dim=1)
            # acc_action = torch.sum(labels_pred.cpu() == labels).float() / labels_pred.shape[0]
            #
            # predict = classifier_action2((seq_b[:, 1:] - seq_b[:, :-1]).transpose(1, 2))[0]
            # loss_classifier_action2 = nn.CrossEntropyLoss()(predict, labels.to(self.device))
            # optimizer_classifier2.zero_grad()
            # loss_classifier_action2.backward()
            # optimizer_classifier2.step()
            # _, labels_pred = torch.max(nn.Sigmoid()(predict), dim=1)
            # acc_action2 = torch.sum(labels_pred.cpu() == labels).float() / labels_pred.shape[0]
            #
            # ## Train classifier on beta
            # if engine.state.epoch >= self.config["training"]["n_epochs"] - 10:
            net.eval()
            with torch.no_grad():
                _, mu, *_ = net(kps, kps_cross)
            predict = classifier_beta(mu)
            loss_classifier_action_beta = nn.CrossEntropyLoss()(
                predict, labels.to(self.device))
            optimizer_classifier_beta.zero_grad()
            loss_classifier_action_beta.backward()
            optimizer_classifier_beta.step()
            _, labels_pred = torch.max(nn.Sigmoid()(predict), dim=1)
            acc_action_beta = torch.sum(
                labels_pred.cpu() == labels).float() / labels_pred.shape[0]
            #
            # out_dict = {}
            # # this is only run if flow training is enable
            #
            # # add info to out_dict
            # out_dict['loss_classifier_action'] = loss_classifier_action.detach().item()
            # out_dict['acc_classifier_action'] = acc_action.item()
            # out_dict['loss_classifier_action2'] = loss_classifier_action2.detach().item()
            # out_dict['acc_classifier_action2'] = acc_action2.item()
            #
            # # if engine.state.epoch >= self.config["training"]["n_epochs"] - 10:
            out_dict['loss_classifier_action_beta'] = (
                loss_classifier_action_beta.detach().item())
            out_dict['acc_action_beta'] = acc_action_beta.item()
            # out_dict["loss"] = loss.detach().item()
            # out_dict["kl_loss"] = kl_loss_avg.detach().item()
            #
            # out_dict["mu_s"] = torch.mean(mu_s).item()
            # out_dict["logstd_s"] = torch.mean(logstd_s).item()
            # # if self.config["training"]["use_regressor"]:
            # #     out_dict["loss_regressor"] = torch.mean(loss_regressor).item()
            # out_dict["loss_recon"] = recon_loss.detach().item()
            # out_dict["loss_per_seq_recon"] = (
            #     recon_loss_per_seq.detach().cpu().numpy()
            # )
            # out_dict["seq_len"] = seq_len
            #
            return out_dict

        ##### CREATE TRAINING RUN #####
        trainer = Engine(train_fn)
        pbar = ProgressBar()
        pbar.attach(
            trainer,
            output_transform=lambda x: {key: x[key]
                                        for key in x if "per_seq" not in key},
        )

        # compute averages for all outputs of train function which are specified in the list
        # fixme this can be used to log as soon as losses for mtvae are defined and named
        loss_avg = Average(output_transform=lambda x: x["loss"])
        loss_avg.attach(trainer, "loss")
        recon_loss_avg = Average(output_transform=lambda x: x["rec_loss"])
        recon_loss_avg.attach(trainer, "rec_loss")
        kl_loss_avg = Average(output_transform=lambda x: x["kl_loss"])
        kl_loss_avg.attach(trainer, "kl_loss")
        motion_loss_avg = Average(output_transform=lambda x: x["motion_loss"])
        motion_loss_avg.attach(trainer, "motion_loss")
        cycle_loss_avg = Average(output_transform=lambda x: x["cycle_loss"])
        cycle_loss_avg.attach(trainer, "cycle_loss")
        # mu_s_avg = Average(output_transform=lambda x: x["mu_s"])
        # mu_s_avg.attach(trainer, "mu_s")
        # logstd_s_avg = Average(output_transform=lambda x: x["logstd_s"])
        # logstd_s_avg.attach(trainer, "logstd_s")
        #
        # loss_classifier = Average(output_transform=lambda x: x["loss_classifier_action"] if "loss_classifier_action" in x else 0)
        # loss_classifier.attach(trainer, "loss_classifier_action")
        # acc_classifier = Average(output_transform=lambda x: x["acc_classifier_action"] if "acc_classifier_action" in x else 0)
        # acc_classifier.attach(trainer, "acc_classifier_action")
        #
        # loss_classifier_action2 = Average(output_transform=lambda x: x["loss_classifier_action2"] if "loss_classifier_action2" in x else 0)
        # loss_classifier_action2.attach(trainer, "loss_classifier_action2")
        # acc_classifier_action2 = Average(output_transform=lambda x: x["acc_classifier_action2"] if "acc_classifier_action2" in x else 0)
        # acc_classifier_action2.attach(trainer, "acc_classifier_action2")
        #
        loss_classifier_action_beta = Average(
            output_transform=lambda x: x["loss_classifier_action_beta"]
            if "loss_classifier_action_beta" in x else 0)
        loss_classifier_action_beta.attach(trainer,
                                           "loss_classifier_action_beta")
        acc_action_beta = Average(output_transform=lambda x: x[
            "acc_action_beta"] if "acc_action_beta" in x else 0)
        acc_action_beta.attach(trainer, "acc_action_beta")

        # loss_avg = Average(output_transform=lambda x: x["loss"])
        # loss_avg.attach(trainer, "loss")

        ##### TRAINING HOOKS ######
        # @trainer.on(Events.ITERATION_COMPLETED)
        # def collect_training_info(engine):
        #     it = engine.state.iteration
        #
        #     self.collect_recon_loss_seq[seq_len] += engine.state.output[
        #         "loss_per_seq_recon"
        #     ]
        #     self.collect_count_seq_lens[seq_len] += self.config["training"]["batch_size"]

        # @trainer.on(Events.EPOCH_COMPLETED)
        # def update_optimizer_params(engine):
        #     scheduler.step()

        def log_wandb(engine):

            wandb.log({
                "epoch": engine.state.epoch,
                "iteration": engine.state.iteration,
            })

            print(
                f"Logging metrics: Currently, the following metrics are tracked: {list(engine.state.metrics.keys())}"
            )
            for key in engine.state.metrics:
                val = engine.state.metrics[key]
                wandb.log({key + "-epoch-avg": val})
                print(ENDC + f" [metrics] {key}:{val}")

            # reset
            # self.collect_recon_loss_seq = {
            #     k: np.zeros(shape=[k])
            #     for k in range(collect_len[0], collect_len[-1])
            # }
            # self.collect_count_seq_lens = np.zeros(shape=[collect_len[-1]])

            loss_avg = engine.state.metrics["loss"]

            print(GREEN + f"Epoch {engine.state.epoch} summary:")
            print(ENDC + f" [losses] loss overall:{loss_avg}")

        def eval_model(engine):
            eval_nets(net,
                      test_loader,
                      self.device,
                      engine.state.epoch,
                      cf_action_beta=classifier_beta,
                      debug=self.config["general"]["debug"])

        #
        #
        def transfer_behavior_test(engine):
            visualize_transfer3d(
                net,
                transfer_loader,
                self.device,
                name="Test-Set: ",
                dirs=self.dirs,
                revert_coord_space=False,
                epoch=engine.state.epoch,
                n_vid_to_generate=self.config["logging"]["n_vid_to_generate"])

        # # compare predictions on train and test set
        # def eval_grid(engine):
        #     if self.config["data"]["dataset"] != "HumanEva":
        #         make_eval_grid(
        #             net,
        #             transfer_loader,
        #             self.device,
        #             dirs=self.dirs,
        #             revert_coord_space=False,
        #             epoch=engine.state.epoch,
        #             synth_ckpt=self.synth_ckpt,
        #             synth_params=self.synth_params,
        #         )

        # def latent_interpolations(engine):
        #     latent_interpolate(
        #         net,
        #         transfer_loader,
        #         self.device,
        #         dirs=self.dirs,
        #         epoch=engine.state.epoch,
        #         synth_params=self.synth_params,
        #         synth_ckpt=self.synth_ckpt,
        #         n_vid_to_generate=self.config["logging"]["n_vid_to_generate"]
        #     )

        ckpt_handler_reg = ModelCheckpoint(self.dirs["ckpt"],
                                           "reg_ckpt",
                                           n_saved=100,
                                           require_empty=False)
        save_dict = {"model": net, "optimizer": optimizer}
        trainer.add_event_handler(Events.EPOCH_COMPLETED(every=1),
                                  ckpt_handler_reg, save_dict)

        trainer.add_event_handler(Events.EPOCH_COMPLETED, log_wandb)

        def log_outputs(engine):
            for key in engine.state.output:
                val = engine.state.output[key]
                wandb.log({key + "-epoch-step": val})

        trainer.add_event_handler(
            Events.ITERATION_COMPLETED(
                every=10 if self.config["general"]["debug"] else 1000),
            log_outputs)
        trainer.add_event_handler(Events.EPOCH_COMPLETED, eval_model)
        trainer.add_event_handler(
            Events.EPOCH_COMPLETED(
                every=1 if self.config["general"]["debug"] else 3),
            transfer_behavior_test,
        )
        # trainer.add_event_handler(
        #     Events.EPOCH_COMPLETED(
        #         every=10
        #     ),
        #     latent_interpolations,
        # )
        # trainer.add_event_handler(
        #     Events.EPOCH_COMPLETED(
        #         every=3
        #     ),
        #     eval_grid,
        # )

        ####### RUN TRAINING ##############
        print(BLUE + "*************** Train VAE *******************" + ENDC)
        trainer.run(
            train_loader,
            max_epochs=self.config["training"]["n_epochs"],
            epoch_length=10
            if self.config["general"]["debug"] else len(train_loader),
        )
        print(BLUE + "*************** VAE training ends *******************" +
              ENDC)
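`linear_var` (used with `partial` above to build `get_kl_weight`) is external to this example; a plausible implementation of the linear annealing schedule it names, offered as an assumption:

def linear_var(it, start_it, end_it, start_val, end_val, clip_min, clip_max):
    # Linearly anneal from start_val at start_it to end_val at end_it,
    # clipping the result to [clip_min, clip_max]
    slope = (end_val - start_val) / max(end_it - start_it, 1)
    return min(max(start_val + slope * (it - start_it), clip_min), clip_max)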
Example #10
def main(
    batch_size,
    epochs,
    length_scale,
    centroid_size,
    model_output_size,
    learning_rate,
    l_gradient_penalty,
    gamma,
    weight_decay,
    final_model,
):
    name = f"DUQ_{length_scale}__{l_gradient_penalty}_{gamma}_{centroid_size}"
    writer = SummaryWriter(comment=name)

    ds = all_datasets["CIFAR10"]()
    input_size, num_classes, dataset, test_dataset = ds

    # Split up training set
    idx = list(range(len(dataset)))
    random.shuffle(idx)

    if final_model:
        train_dataset = dataset
        val_dataset = test_dataset
    else:
        train_size = int(len(dataset) * 0.8)
        train_dataset = torch.utils.data.Subset(dataset, idx[:train_size])
        val_dataset = torch.utils.data.Subset(dataset, idx[train_size:])

        # Use test-time preprocessing for validation
        val_dataset.transform = test_dataset.transform

    model = ResNet_DUQ(
        input_size,
        num_classes,
        centroid_size,
        model_output_size,
        length_scale,
        gamma,
    )
    model = model.cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=learning_rate,
                                momentum=0.9,
                                weight_decay=weight_decay)

    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=[25, 50, 75],
                                                     gamma=0.2)

    def bce_loss_fn(y_pred, y):
        bce = F.binary_cross_entropy(y_pred, y, reduction="sum").div(
            num_classes * y_pred.shape[0])
        return bce

    def output_transform_bce(output):
        y_pred, y, x = output

        y = F.one_hot(y, num_classes).float()

        return y_pred, y

    def output_transform_acc(output):
        y_pred, y, x = output

        return y_pred, y

    def output_transform_gp(output):
        y_pred, y, x = output

        return x, y_pred

    def calc_gradients_input(x, y_pred):
        gradients = torch.autograd.grad(
            outputs=y_pred,
            inputs=x,
            grad_outputs=torch.ones_like(y_pred),
            create_graph=True,
        )[0]

        gradients = gradients.flatten(start_dim=1)

        return gradients

    def calc_gradient_penalty(x, y_pred):
        gradients = calc_gradients_input(x, y_pred)

        # L2 norm
        grad_norm = gradients.norm(2, dim=1)

        # Two sided penalty
        gradient_penalty = ((grad_norm - 1)**2).mean()

        return gradient_penalty

    def step(engine, batch):
        model.train()

        optimizer.zero_grad()

        x, y = batch
        x, y = x.cuda(), y.cuda()

        if l_gradient_penalty > 0:
            x.requires_grad_(True)

        z, y_pred = model(x)
        y = F.one_hot(y, num_classes).float()

        loss = bce_loss_fn(y_pred, y)

        if l_gradient_penalty > 0:
            loss += l_gradient_penalty * calc_gradient_penalty(x, y_pred)

        loss.backward()
        optimizer.step()

        x.requires_grad_(False)

        with torch.no_grad():
            model.eval()
            model.update_embeddings(x, y)

        return loss.item()

    def eval_step(engine, batch):
        model.eval()

        x, y = batch
        x, y = x.cuda(), y.cuda()

        x.requires_grad_(True)

        z, y_pred = model(x)

        return y_pred, y, x

    trainer = Engine(step)
    evaluator = Engine(eval_step)

    metric = Average()
    metric.attach(trainer, "loss")

    metric = Accuracy(output_transform=output_transform_acc)
    metric.attach(evaluator, "accuracy")

    metric = Loss(F.binary_cross_entropy,
                  output_transform=output_transform_bce)
    metric.attach(evaluator, "bce")

    metric = Loss(calc_gradient_penalty, output_transform=output_transform_gp)
    metric.attach(evaluator, "gradient_penalty")

    kwargs = {"num_workers": 4, "pin_memory": True}

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               drop_last=True,
                                               **kwargs)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=1000,
                                             shuffle=False,
                                             **kwargs)

    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=1000,
                                              shuffle=False,
                                              **kwargs)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_results(trainer):
        metrics = trainer.state.metrics
        loss = metrics["loss"]

        print(f"Train - Epoch: {trainer.state.epoch} Loss: {loss:.2f} ")

        writer.add_scalar("Loss/train", loss, trainer.state.epoch)

        if trainer.state.epoch % 5 == 0 or trainer.state.epoch > 65:
            accuracy, auroc = get_cifar_svhn_ood(model)
            print(f"Test Accuracy: {accuracy}, AUROC: {auroc}")
            writer.add_scalar("OoD/test_accuracy", accuracy,
                              trainer.state.epoch)
            writer.add_scalar("OoD/roc_auc", auroc, trainer.state.epoch)

            accuracy, auroc = get_auroc_classification(val_dataset, model)
            print(f"AUROC - uncertainty: {auroc}")
            writer.add_scalar("OoD/val_accuracy", accuracy,
                              trainer.state.epoch)
            writer.add_scalar("OoD/roc_auc_classification", auroc,
                              trainer.state.epoch)

        evaluator.run(val_loader)
        metrics = evaluator.state.metrics
        acc = metrics["accuracy"]
        bce = metrics["bce"]
        GP = metrics["gradient_penalty"]
        loss = bce + l_gradient_penalty * GP

        print((f"Valid - Epoch: {trainer.state.epoch} "
               f"Acc: {acc:.4f} "
               f"Loss: {loss:.2f} "
               f"BCE: {bce:.2f} "
               f"GP: {GP:.2f} "))

        writer.add_scalar("Loss/valid", loss, trainer.state.epoch)
        writer.add_scalar("BCE/valid", bce, trainer.state.epoch)
        writer.add_scalar("GP/valid", GP, trainer.state.epoch)
        writer.add_scalar("Accuracy/valid", acc, trainer.state.epoch)

        print(f"Centroid norm: {torch.norm(model.m / model.N, dim=0)}")

        scheduler.step()

        if trainer.state.epoch > 65:
            torch.save(model.state_dict(),
                       f"saved_models/{name}_{trainer.state.epoch}.pt")

    pbar = ProgressBar(dynamic_ncols=True)
    pbar.attach(trainer)

    trainer.run(train_loader, max_epochs=epochs)

    evaluator.run(test_loader)
    acc = evaluator.state.metrics["accuracy"]

    print(f"Test - Accuracy {acc:.4f}")

    writer.close()
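The two-sided gradient penalty inside `step()`/`eval_step()` is nested in `main()` and not importable; a standalone check of the same computation on toy data:

import torch

x = torch.randn(8, 4, requires_grad=True)
y_pred = (x ** 2).sum(dim=1)  # stand-in for model(x)
grads = torch.autograd.grad(outputs=y_pred, inputs=x,
                            grad_outputs=torch.ones_like(y_pred),
                            create_graph=True)[0].flatten(start_dim=1)
penalty = ((grads.norm(2, dim=1) - 1) ** 2).mean()  # two-sided penalty
print(penalty.item())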
Example #11
def main(
    architecture,
    batch_size,
    length_scale,
    centroid_size,
    learning_rate,
    l_gradient_penalty,
    gamma,
    weight_decay,
    final_model,
    output_dir,
):
    writer = SummaryWriter(log_dir=f"runs/{output_dir}")

    ds = all_datasets["CIFAR10"]()
    input_size, num_classes, dataset, test_dataset = ds

    # Split up training set
    idx = list(range(len(dataset)))
    random.shuffle(idx)

    if final_model:
        train_dataset = dataset
        val_dataset = test_dataset
    else:
        train_size = int(len(dataset) * 0.8)
        train_dataset = torch.utils.data.Subset(dataset, idx[:train_size])
        val_dataset = torch.utils.data.Subset(dataset, idx[train_size:])

        # Use test-time preprocessing for validation
        val_dataset.transform = test_dataset.transform

    if architecture == "WRN":
        model_output_size = 640
        epochs = 200
        milestones = [60, 120, 160]
        feature_extractor = WideResNet()
    elif architecture == "ResNet18":
        model_output_size = 512
        epochs = 200
        milestones = [60, 120, 160]
        feature_extractor = resnet18()
    elif architecture == "ResNet50":
        model_output_size = 2048
        epochs = 200
        milestones = [60, 120, 160]
        feature_extractor = resnet50()
    elif architecture == "ResNet110":
        model_output_size = 2048
        epochs = 200
        milestones = [60, 120, 160]
        feature_extractor = resnet110()
    elif architecture == "DenseNet121":
        model_output_size = 1024
        epochs = 200
        milestones = [60, 120, 160]
        feature_extractor = densenet121()

        # Adapted resnet from:
        # https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py
        feature_extractor.conv1 = torch.nn.Conv2d(3,
                                                  64,
                                                  kernel_size=3,
                                                  stride=1,
                                                  padding=1,
                                                  bias=False)
        feature_extractor.maxpool = torch.nn.Identity()
        feature_extractor.fc = torch.nn.Identity()

    if centroid_size is None:
        centroid_size = model_output_size

    model = ResNet_DUQ(
        feature_extractor,
        num_classes,
        centroid_size,
        model_output_size,
        length_scale,
        gamma,
    )
    model = model.cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=learning_rate,
                                momentum=0.9,
                                weight_decay=weight_decay)

    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=milestones,
                                                     gamma=0.2)
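    # MultiStepLR multiplies the learning rate by gamma=0.2 at epochs 60, 120
    # and 160, a standard schedule for 200-epoch CIFAR training runs.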

    def calc_gradients_input(x, y_pred):
        gradients = torch.autograd.grad(
            outputs=y_pred,
            inputs=x,
            grad_outputs=torch.ones_like(y_pred),
            create_graph=True,
        )[0]
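        # grad_outputs=torch.ones_like(y_pred) sums the per-class outputs
        # before differentiating; create_graph=True keeps the graph so the
        # penalty built from these gradients can itself be backpropagated.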

        gradients = gradients.flatten(start_dim=1)

        return gradients

    def calc_gradient_penalty(x, y_pred):
        gradients = calc_gradients_input(x, y_pred)

        # L2 norm
        grad_norm = gradients.norm(2, dim=1)

        # Two sided penalty
        gradient_penalty = ((grad_norm - 1)**2).mean()
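        # Penalising deviations from norm 1 in both directions keeps the
        # model sensitive to input changes while bounding that sensitivity,
        # the regulariser used by DUQ (van Amersfoort et al., 2020).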

        return gradient_penalty

    def step(engine, batch):
        model.train()

        optimizer.zero_grad()

        x, y = batch
        x, y = x.cuda(), y.cuda()

        x.requires_grad_(True)

        y_pred = model(x)

        y = F.one_hot(y, num_classes).float()

        loss = F.binary_cross_entropy(y_pred, y, reduction="mean")

        if l_gradient_penalty > 0:
            gp = calc_gradient_penalty(x, y_pred)
            loss += l_gradient_penalty * gp

        loss.backward()
        optimizer.step()

        x.requires_grad_(False)

        with torch.no_grad():
            model.eval()
            model.update_embeddings(x, y)
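            # update_embeddings refreshes the per-class centroids with an
            # exponential moving average of the current batch's features;
            # it runs under no_grad because it is not part of backprop.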

        return loss.item()

    def eval_step(engine, batch):
        model.eval()

        x, y = batch
        x, y = x.cuda(), y.cuda()

        x.requires_grad_(True)
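        # No torch.no_grad() here: the gradient-penalty metric attached below
        # needs input gradients, so x must track them even at eval time.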

        y_pred = model(x)

        return {"x": x, "y": y, "y_pred": y_pred}

    trainer = Engine(step)
    evaluator = Engine(eval_step)

    metric = Average()
    metric.attach(trainer, "loss")

    metric = Accuracy(output_transform=lambda out: (out["y_pred"], out["y"]))
    metric.attach(evaluator, "accuracy")

    def bce_output_transform(out):
        return (out["y_pred"], F.one_hot(out["y"], num_classes).float())

    metric = Loss(F.binary_cross_entropy,
                  output_transform=bce_output_transform)
    metric.attach(evaluator, "bce")

    metric = Loss(calc_gradient_penalty,
                  output_transform=lambda out: (out["x"], out["y_pred"]))
    metric.attach(evaluator, "gradient_penalty")

    pbar = ProgressBar(dynamic_ncols=True)
    pbar.attach(trainer)

    kwargs = {"num_workers": 4, "pin_memory": True}

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               drop_last=True,
                                               **kwargs)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             **kwargs)

    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=batch_size,
                                              shuffle=False,
                                              **kwargs)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_results(trainer):
        metrics = trainer.state.metrics
        loss = metrics["loss"]

        print(f"Train - Epoch: {trainer.state.epoch} Loss: {loss:.2f}")

        writer.add_scalar("Loss/train", loss, trainer.state.epoch)

        if trainer.state.epoch > (epochs - 5):
            accuracy, auroc = get_cifar_svhn_ood(model)
            print(f"Test Accuracy: {accuracy}, AUROC: {auroc}")
            writer.add_scalar("OoD/test_accuracy", accuracy,
                              trainer.state.epoch)
            writer.add_scalar("OoD/roc_auc", auroc, trainer.state.epoch)

            accuracy, auroc = get_auroc_classification(val_dataset, model)
            print(f"AUROC - uncertainty: {auroc}")
            writer.add_scalar("OoD/val_accuracy", accuracy,
                              trainer.state.epoch)
            writer.add_scalar("OoD/roc_auc_classification", auroc,
                              trainer.state.epoch)

        evaluator.run(val_loader)
        metrics = evaluator.state.metrics
        acc = metrics["accuracy"]
        bce = metrics["bce"]
        GP = metrics["gradient_penalty"]
        loss = bce + l_gradient_penalty * GP

        print((f"Valid - Epoch: {trainer.state.epoch} "
               f"Acc: {acc:.4f} "
               f"Loss: {loss:.2f} "
               f"BCE: {bce:.2f} "
               f"GP: {GP:.2f} "))

        writer.add_scalar("Loss/valid", loss, trainer.state.epoch)
        writer.add_scalar("BCE/valid", bce, trainer.state.epoch)
        writer.add_scalar("GP/valid", GP, trainer.state.epoch)
        writer.add_scalar("Accuracy/valid", acc, trainer.state.epoch)

        scheduler.step()

    trainer.run(train_loader, max_epochs=epochs)
    evaluator.run(test_loader)
    acc = evaluator.state.metrics["accuracy"]

    print(f"Test - Accuracy {acc:.4f}")

    torch.save(model.state_dict(), f"runs/{output_dir}/model.pt")
    writer.close()
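A minimal sketch of a command-line entry point for a main() with this
signature; the flag names and defaults are illustrative, not part of the
example:

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--architecture", default="ResNet18")
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--length_scale", type=float, default=0.1)
    parser.add_argument("--centroid_size", type=int, default=None)
    parser.add_argument("--learning_rate", type=float, default=0.05)
    parser.add_argument("--l_gradient_penalty", type=float, default=0.75)
    parser.add_argument("--gamma", type=float, default=0.999)
    parser.add_argument("--weight_decay", type=float, default=5e-4)
    parser.add_argument("--final_model", action="store_true")
    parser.add_argument("--output_dir", default="default")
    args = parser.parse_args()

    main(**vars(args))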
Example #12
def get_train_mean_std(train_dataset, unique_id="", cache_dir="/tmp/unosat/"):
    # # Ensure that only process 0 in distributed performs the computation, and the others will use the cache
    # if dist.get_rank() > 0:
    #     torch.distributed.barrier()  # synchronization point for all processes > 0

    cache_dir = Path(cache_dir)
    if not cache_dir.exists():
        cache_dir.mkdir(parents=True)

    if len(unique_id) > 0:
        unique_id += "_"

    fp = cache_dir / "train_mean_std_{}{}.pth".format(len(train_dataset),
                                                      unique_id)

    if fp.exists():
        mean_std = torch.load(fp.as_posix())
    else:
        if dist.is_available() and dist.is_initialized():
            raise RuntimeError(
                "Current implementation of Mean/Std computation is not working in distrib config"
            )

        from ignite.engine import Engine
        from ignite.metrics import Average
        from ignite.contrib.handlers import ProgressBar
        from albumentations.pytorch import ToTensorV2

        train_dataset = TransformedDataset(train_dataset,
                                           transform_fn=ToTensorV2())
        train_loader = DataLoader(train_dataset,
                                  shuffle=False,
                                  drop_last=False,
                                  batch_size=16,
                                  num_workers=10,
                                  pin_memory=False)

        def compute_mean_std(engine, batch):
            b, c, *_ = batch['image'].shape
            data = batch['image'].reshape(b, c, -1).to(dtype=torch.float64)
            mean = torch.mean(data, dim=-1)
            mean2 = torch.mean(data**2, dim=-1)

            return {
                "mean": mean,
                "mean^2": mean2,
            }

        compute_engine = Engine(compute_mean_std)
        ProgressBar(desc="Compute Mean/Std").attach(compute_engine)
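        # The two Average metrics accumulate per-channel E[x] and E[x^2] over
        # the whole dataset; std is recovered after the run as
        # sqrt(E[x^2] - E[x]^2).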
        img_mean = Average(output_transform=lambda output: output['mean'])
        img_mean2 = Average(output_transform=lambda output: output['mean^2'])
        img_mean.attach(compute_engine, 'mean')
        img_mean2.attach(compute_engine, 'mean2')
        state = compute_engine.run(train_loader)
        state.metrics['std'] = torch.sqrt(state.metrics['mean2'] -
                                          state.metrics['mean']**2)
        mean_std = {'mean': state.metrics['mean'], 'std': state.metrics['std']}

        # if dist.get_rank() < 1:
        torch.save(mean_std, fp.as_posix())

    # if dist.get_rank() < 1:
    #     torch.distributed.barrier()  # synchronization point for process 0

    return mean_std['mean'].tolist(), mean_std['std'].tolist()
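A possible way to consume the returned statistics, assuming an albumentations
pipeline downstream (the Normalize parameters are illustrative; the stats are
computed on unscaled pixel values, hence max_pixel_value=1.0):

import albumentations as A

mean, std = get_train_mean_std(train_dataset, unique_id="fold0")
normalize = A.Normalize(mean=mean, std=std, max_pixel_value=1.0)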