Example #1
def attach_metrics(engine: Engine):
    Average(output_transform=lambda d: d["loss"]).attach(engine, "loss")
    Average(output_transform=lambda d: d["kl_div"]).attach(engine, "kl_div")
    Average(output_transform=lambda d: d["recon_loss"]).attach(engine, "recon_loss")
Example #2
    def __init__(self, trainer, res_dir='results', **kwargs):
        self.trainer = trainer
        self.start_datetime = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        self.res_dir = Path(res_dir) / self.start_datetime
        self.res_dir.mkdir(parents=True)

        metric_loss = Average()
        metric_loss.attach(self.trainer, 'loss')
Example #3
def attach_pbar_and_metrics(trainer, evaluator):
    loss_metric = Average(output_transform=lambda output: output["loss"])
    accuracy_metric = Accuracy(
        output_transform=lambda output: (output["logit"], output["label"]))
    pbar = ProgressBar()
    loss_metric.attach(trainer, "loss")
    accuracy_metric.attach(trainer, "accuracy")
    accuracy_metric.attach(evaluator, "accuracy")
    pbar.attach(trainer)
Example #4
    def setup_training(self):
        assert self.batch_size is not None
        trainer = Engine(lambda e, b: self.train_step(b))
        trainer.register_events("EVAL_DONE")
        Average(lambda o: o['loss']).attach(trainer, 'avg_loss')
        state_vars = dict(model=self.model, opt=self.opt, trainer=trainer)
        checkpoint_handler = ModelCheckpoint(
            self.run_path, '',
            score_function=lambda e: e.state.metrics['val_accuracy'],
            score_name='val_accuracy', n_saved=2,
            global_step_transform=lambda e, evt_name: e.state.epoch)
        if checkpoint_handler.last_checkpoint:
            checkpoint_handler.load_objects(state_vars, self.run_path / checkpoint_handler.last_checkpoint)
        trainer.add_event_handler("EVAL_DONE", lambda e: checkpoint_handler(e, state_vars))
        if self.use_lr_decay:
            trainer.add_event_handler(Events.ITERATION_COMPLETED, lambda e: self.lr_decay.step(e.state.iteration * self.batch_size))

        RunningAverage(output_transform=lambda o: o['loss']).attach(trainer, 'running_avg_loss')
        ProgressBar().attach(trainer, ['running_avg_loss'])
        logger.setup_logger(self.run_path, trainer, self.model)

        @trainer.on(Events.EPOCH_COMPLETED)
        def eval_and_log(e: Engine):
            eval_results = self.eval()
            e.state.metrics['val_accuracy'] = eval_results['val'].metrics['accuracy'] 
            e.state.metrics['val_loss'] = eval_results['val'].metrics['avg_loss']
            e.state.eval_results = eval_results
            e.fire_event("EVAL_DONE")

        if self.use_early_stop:
            es = self.make_early_stopper(trainer)
            trainer.add_event_handler("EVAL_DONE", es)

        return trainer
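
The "EVAL_DONE" string above is a custom Ignite event. A small self-contained sketch (with a hypothetical placeholder evaluation) of registering, firing, and handling such an event:

from ignite.engine import Engine, Events

trainer = Engine(lambda engine, batch: None)
trainer.register_events("EVAL_DONE")  # allow the string to be used as an event name

@trainer.on(Events.EPOCH_COMPLETED)
def evaluate(engine):
    engine.state.metrics["val_accuracy"] = 0.5  # placeholder score
    engine.fire_event("EVAL_DONE")

trainer.add_event_handler(
    "EVAL_DONE",
    lambda e: print("epoch", e.state.epoch, e.state.metrics["val_accuracy"]))
trainer.run([0, 1, 2], max_epochs=2)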
Example #5
def create_evaluator(model, tasks, args):
    model.to(args.device)

    @torch.no_grad()
    def evaluator_step(engine, batch):
        model.eval()

        x, y = batch
        x = x.to(args.device)
        y = [_y.to(device=args.device) for _y in y]

        losses = {}
        preds = model(x)
        for rep_i, task_i in zip(preds, tasks):
            losses[f'loss_{task_i.name}'] = task_i.loss(
                rep_i, y[task_i.index]).mean(dim=0)
            losses[f'metric_{task_i.name}'] = task_i.metric(
                rep_i, y[task_i.index]).mean(dim=0)

        preds = [pred_i.detach().clone() for pred_i in preds]
        return losses, y, preds

    evaluator = Engine(evaluator_step)
    for task_i in tasks:
        for prefix in ['metric', 'loss']:
            name = f'{prefix}_{task_i.name}'
            output_transform = partial(lambda name, x: x[0][name], name)
            Average(output_transform=output_transform).attach(evaluator, name)

    return evaluator
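
The functools.partial call above binds name at loop time; a plain lambda would late-bind and every metric would read the last task's key. A tiny sketch of the difference (with hypothetical keys):

from functools import partial

names = ["loss_a", "loss_b"]
sample = ({"loss_a": 1.0, "loss_b": 2.0},)

late = [lambda x: x[0][name] for name in names]
bound = [partial(lambda name, x: x[0][name], name) for name in names]

print([f(sample) for f in late])   # [2.0, 2.0] -- both closures see name == "loss_b"
print([f(sample) for f in bound])  # [1.0, 2.0] -- each partial keeps its own name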
Example #6
def create_classification_trainer(
    model,
    optimizer,
    loss_fn,
    device=None,
    non_blocking=False,
    prepare_batch=_prepare_batch,
    output_transform=lambda x, y, y_pred, loss: loss.item()):  # noqa
    """
    Factory function for creating a trainer for supervised models.
    Args:
        model (`torch.nn.Module`): the model to train.
        optimizer (`torch.optim.Optimizer`): the optimizer to use.
        loss_fn (torch.nn loss function): the loss function to use.
        device (str, optional): device type specification (default: None).
            Applies to both model and batches.
        non_blocking (bool, optional): if True and this copy is between CPU and GPU, the copy may occur asynchronously
            with respect to the host. For other cases, this argument has no effect.
        prepare_batch (callable, optional): function that receives `batch`, `device`, `non_blocking` and outputs
            tuple of tensors `(batch_x, batch_y)`.
        output_transform (callable, optional): function that receives 'x', 'y', 'y_pred', 'loss' and returns value
            to be assigned to engine's state.output after each iteration. Default is returning `loss.item()`.
    Note: `engine.state.output` for this engine is defined by the `output_transform` parameter and is the loss
        of the processed batch by default.
    Returns:
        Engine: a trainer engine with supervised update function.
    """
    if device:
        model.to(device)

    def _update(engine, batch):
        model.train()
        optimizer.zero_grad()
        x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        loss.backward()
        optimizer.step()
        return output_transform(x, y, y_pred, loss)

    engine = Engine(_update)
    metric_loss = Average()
    metric_loss.attach(engine, 'loss')
    return engine
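
A hedged usage sketch of the factory above, assuming it is importable from the module that defines it and that _prepare_batch is Ignite's default batch-preparation helper (the snippet does not show its import):

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loader = DataLoader(
    TensorDataset(torch.randn(32, 4), torch.randint(0, 2, (32,))), batch_size=8)

trainer = create_classification_trainer(model, optimizer, nn.CrossEntropyLoss())
trainer.run(loader, max_epochs=2)
print(trainer.state.metrics["loss"])  # epoch-averaged loss from the attached Average()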
Example #7
def run(model,
        train_loader,
        val_loader,
        optimizer,
        epochs,
        log_interval,
        log_dir,
        val=False,
        log=True):
    writer = create_summary_writer(log_dir)
    device = 'cpu'

    if torch.cuda.is_available():
        device = 'cuda'
    loss_metric = Average()
    trainer = create_train_engine(model, optimizer, device=device)
    evaluator = create_supervised_evaluator(model,
                                            metrics={'loss': loss_metric},
                                            device=device)

    @trainer.on(Events.ITERATION_COMPLETED(every=log_interval))
    def log_training_loss(engine):
        # print("Epoch[{}] Iteration[{}/{}] Loss: {:.2f}"
        #      "".format(engine.state.epoch, engine.state.iteration, len(train_loader), engine.state.output))
        writer.add_scalar("training/loss", engine.state.output,
                          engine.state.iteration)

    if log:

        @trainer.on(Events.EPOCH_COMPLETED)
        def log_training_results(engine):
            evaluator.run(train_loader)
            metrics = evaluator.state.metrics
            avg_mse = metrics['loss']
            print("Training Results - Epoch: {}   Avg loss: {:.2f}".format(
                engine.state.epoch, avg_mse))
            writer.add_scalar("training/avg_loss", avg_mse, engine.state.epoch)

    if val:

        @trainer.on(Events.EPOCH_COMPLETED)
        def log_validation_results(engine):
            evaluator.run(val_loader)
            metrics = evaluator.state.metrics
            avg_mse = metrics['loss']
            # print("Validation Results - Epoch: {}  Avg loss: {:.2f}"
            #      .format(engine.state.epoch, avg_mse))
            writer.add_scalar("validation/avg_loss", avg_mse,
                              engine.state.epoch)

    # kick everything off
    trainer.run(train_loader, max_epochs=epochs)

    writer.close()
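
The Events.ITERATION_COMPLETED(every=log_interval) decorator above uses Ignite's event filtering; a minimal sketch:

from ignite.engine import Engine, Events

engine = Engine(lambda e, b: None)

@engine.on(Events.ITERATION_COMPLETED(every=3))
def every_third_iteration(e):
    print("iteration", e.state.iteration)  # fires at iterations 3, 6, 9, ...

engine.run(range(10), max_epochs=1)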
Example #8
def main(ctx, config_file, dataset_root, res_root_dir, debug, device,
         num_workers, **kwargs):
    with open(config_file) as stream:
        config = yaml.safe_load(stream)

    train_transforms = get_transforms(config['train_augment'])
    val_transforms = get_transforms(config['val_augment'])
    train_loader, val_loader = get_loaders(train_transforms=train_transforms,
                                           val_transforms=val_transforms,
                                           dataset_root=dataset_root,
                                           num_workers=num_workers,
                                           **config['dataset'])
    label_names = get_labels(train_loader)
    net, criterion = get_model(n_class=len(label_names), **config['model'])
    optimizer = get_optimizer(net, **config['optimizer'])

    trainer = create_supervised_trainer(net, optimizer, criterion, device,
                                        prepare_batch=prepare_batch)
    metric_loss = Average()
    metric_loss.attach(trainer, 'loss')
    metrics = get_metrics(label_names, config['evaluate'])
    metric_names = list(metrics.keys())
    evaluator = create_supervised_evaluator(net, metrics, device,
                                            prepare_batch=prepare_batch)

    @trainer.on(Events.EPOCH_COMPLETED)
    def compute_metrics(engine):
        evaluator.run(val_loader)

    res_dir = Path(res_root_dir) / config['dataset']['dataset_name']
    train_extend = TrainExtension(trainer, evaluator, res_dir)
    train_extend.print_metrics(metric_names)
    train_extend.set_progressbar()
    train_extend.schedule_lr(optimizer, **config['lr_schedule'])
    if not debug:
        train_extend.copy_configs(config_file)
        train_extend.set_tensorboard(metric_names)
        train_extend.save_model(net, **config['model_checkpoint'])
        train_extend.show_config_on_tensorboard(config)

    trainer.run(train_loader, max_epochs=config['epochs'])
Example #9
def create_evaluator(classifier, device):
    classifier.to(device)

    def update_fn(engine, batch):
        classifier.eval()
        with torch.no_grad():
            # batch = [elem.to(device) for elem in batch]
            x, y = [elem.to(device) for elem in batch]
            _, metrics, pred_y = classifier(x, y)
            return metrics, pred_y, y
    evaluator = Engine(update_fn)

    for key in classifier.metrics_keys:
        Average(output_transform=DictOutputTransform(key)).attach(evaluator, key)
    return evaluator
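
DictOutputTransform is not defined in these snippets. One plausible minimal implementation (an assumption, not the original class) picks a single key out of the metrics dict in the (metrics, pred_y, y) output:

class DictOutputTransform:
    # hypothetical reconstruction: select metrics[key] from an output shaped
    # (metrics, pred_y, y), so Average can aggregate that one scalar
    def __init__(self, key, index=0):
        self.key = key
        self.index = index

    def __call__(self, output):
        return output[self.index][self.key]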
Example #10
    def trainer(self):
        if not self._trainer:
            trainer = Engine(self.train_update)
            metrics = {
                "sim_acc": SiamSimAccuracy(
                    margin=self.margin,
                    output_transform=lambda x: (x['emb_vecs'], x['targets'])),
                "clsf_acc": Accuracy(
                    output_transform=lambda x: (x['cls_pred'], x['cls_true'])),
                "loss": Average(output_transform=lambda x: x["loss"]),
                "con_loss": Average(output_transform=lambda x: x["con_loss"]),
                "clsf_loss": Average(output_transform=lambda x: x["clsf_loss"]),
            }
            for name, metric in metrics.items():
                metric.attach(trainer, name)
            self._trainer = trainer
            self.train_metrics = metrics

        return self._trainer
Example #11
def get_evaluator(net: nn.Module, device: torch.device) -> Engine:
    def _inference(engine: Engine, batch: Batch) -> Tuple[T, T, T, T]:
        x1, x2, t = prepare_batch(batch, device)
        net.eval()
        s1, s2 = net(x1), net(x2)
        return s1, s2, t, ranknet_loss(s1, s2, t)

    def _acc_output_transform(output: Tuple[T, T, T, T]) -> Tuple[T, T]:
        s1, s2, t, _ = output
        return (s1 - s2 > 0).long(), (t > 0.5).long()

    ev = Engine(_inference)
    Average(lambda t: t[3].item()).attach(ev, "loss")
    Accuracy(_acc_output_transform).attach(ev, "accuracy")

    return ev
Example #12
def create_trainer(classifier, optimizer, device):
    classifier.to(device)

    def update_fn(engine, batch):
        classifier.train()
        optimizer.zero_grad()
        # batch = [elem.to(device) for elem in batch]
        x, y = [elem.to(device) for elem in batch]
        loss, metrics, pred_y = classifier(x, y)
        loss.backward()
        optimizer.step()
        return metrics, pred_y, y
    trainer = Engine(update_fn)

    for key in classifier.metrics_keys:
        Average(output_transform=DictOutputTransform(key)).attach(trainer, key)
    return trainer
Example #13
def create_trainer(classifier, optimizer, device, w1, w2, w3):
    classifier.to(device)

    def update_fn(engine, batch):
#         print(engine,batch)
        classifier.train()
        optimizer.zero_grad()
        # batch = [elem.to(device) for elem in batch]
        x, y = [elem.to(device) for elem in batch]
        x = x.to(device, dtype=torch.float)
        y = y.to(device, dtype=torch.long)

        preds = classifier(x)
        loss_grapheme = F.cross_entropy(preds[0], y[:, 0])
        loss_vowel = F.cross_entropy(preds[1], y[:, 1])
        loss_consonant = F.cross_entropy(preds[2], y[:, 2])
        loss = (loss_grapheme * w1 + loss_vowel * w2 + loss_consonant * w3) / (w1 + w2 + w3)

        metrics = {
                'loss': loss.item(),
                'loss_grapheme': loss_grapheme.item(),
                'loss_vowel': loss_vowel.item(),
                'loss_consonant': loss_consonant.item(),
                'acc_grapheme': accuracy(preds[0], y[:, 0]),
                'acc_vowel': accuracy(preds[1], y[:, 1]),
                'acc_consonant': accuracy(preds[2], y[:, 2]),
        }

        loss.backward()
        optimizer.step()
        return metrics, torch.cat(preds, dim=1), y
    trainer = Engine(update_fn)

    for key in classifier.metrics_keys:
        Average(output_transform=DictOutputTransform(key)).attach(trainer, key)
    return trainer
Example #14
def create_evaluator(classifier, device):
    classifier.to(device)
    metrics_keys = classifier.metrics_keys
    classifier = torch.nn.DataParallel(classifier)

    def update_fn(engine, batch):
        classifier.eval()
        with torch.no_grad():
            x = batch[0].to(device)
            y = batch[1][0].to(device)
            lams = batch[1][1].to(device)
            _, metrics, pred_y = classifier(x, y, lams)
            return metrics, pred_y, get_y_main(y, lams)  # nyan

    evaluator = Engine(update_fn)

    for key in metrics_keys:
        Average(output_transform=DictOutputTransform(key)).attach(
            evaluator, key)
    return evaluator
Example #15
def create_evaluator(model, criterion, cfg):
    def _validation_step(_, batch):
        model.eval()
        with torch.no_grad():
            x, y = batch_to_tensor(batch, cfg)
            x, y = x.to(cfg.device), y.to(cfg.device)

            y_pred, hidden = model(x)
            loss = criterion(y_pred, y)

            if cfg.multi_label:
                y_pred = (y_pred > 0).float()

            return y_pred, y, loss, hidden

    evaluator = Engine(_validation_step)

    accuracy = Accuracy(lambda x: x[0:2], is_multilabel=cfg.multi_label)
    accuracy.attach(evaluator, "acc")

    precision = Precision(lambda x: x[0:2],
                          average=False,
                          is_multilabel=cfg.multi_label)
    precision.attach(evaluator, 'precision')
    MetricsLambda(lambda t: torch.mean(t).item(),
                  precision).attach(evaluator, "MP")

    recall = Recall(lambda x: x[0:2],
                    average=False,
                    is_multilabel=cfg.multi_label)
    recall.attach(evaluator, 'recall')
    MetricsLambda(lambda t: torch.mean(t).item(),
                  recall).attach(evaluator, "MR")

    F1 = 2. * precision * recall / (precision + recall + 1e-20)
    f1 = MetricsLambda(lambda t: torch.mean(t).item(), F1)
    f1.attach(evaluator, "F1")

    Average(lambda x: x[2]).attach(evaluator, 'loss')

    return evaluator
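
The macro-F1 above is built through metric arithmetic plus MetricsLambda. A self-contained sketch of the same composition on synthetic (logits, labels) batches:

import torch
from ignite.engine import Engine
from ignite.metrics import MetricsLambda, Precision, Recall

evaluator = Engine(lambda e, batch: batch)  # output is already (y_pred, y)

precision = Precision(average=False)
recall = Recall(average=False)
f1 = MetricsLambda(lambda t: torch.mean(t).item(),
                   2.0 * precision * recall / (precision + recall + 1e-20))
precision.attach(evaluator, "precision")
recall.attach(evaluator, "recall")
f1.attach(evaluator, "F1")

# two synthetic batches for a 3-class problem
batches = [(torch.randn(8, 3), torch.randint(0, 3, (8,))) for _ in range(2)]
evaluator.run(batches)
print(evaluator.state.metrics["F1"])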
Example #16
    def setup_evaluator(self):
        validator = Engine(lambda e, b: self.pred_calc_loss(b))
        self.model.get_probabilities = True

        def get_y_pred_y(o: dict):
            """
            :param o: output of forward method
            :return: tuple (y_pred, y) both with shape [batch_size * query_size, num_classes] in OHE
            """
            y_pred, y = o['prob_query'].argmax(1).flatten(), o['batch']['test'][1].flatten()
            num_classes = int(y.max().item())+1
            y_pred = one_hot(y_pred, num_classes)
            y = one_hot(y, num_classes)
            assert y_pred.shape == y.shape, f"{y_pred.shape} != {y.shape}"
            return y_pred, y

        RunningAverage(output_transform=lambda o: o['loss']).attach(validator, 'running_avg_loss')
        Average(lambda o: o['loss']).attach(validator, 'avg_loss')
        ProgressBar().attach(validator, ['running_avg_loss'])
        Accuracy(output_transform=get_y_pred_y, is_multilabel=True).attach(validator, 'accuracy')
        return validator
Example #17
def create_trainer(classifier, optimizer, device):
    classifier.to(device)
    metrics_keys = classifier.metrics_keys
    classifier = torch.nn.DataParallel(classifier)

    def update_fn(engine, batch):
        classifier.train()
        optimizer.zero_grad()
        x = batch[0].to(device)
        y = batch[1][0].to(device)
        lams = batch[1][1].to(device)
        loss, metrics, pred_y = classifier(x, y, lams)
        loss.sum().backward()
        optimizer.step()
        return metrics, pred_y, get_y_main(y, lams)

    trainer = Engine(update_fn)

    for key in metrics_keys:
        Average(output_transform=DictOutputTransform(key)).attach(trainer, key)
    return trainer
Example #18
def create_evaluator(classifier, device):
    classifier.to(device)

    def update_fn(engine, batch):
        classifier.eval()

        with torch.no_grad():
            # batch = [elem.to(device) for elem in batch]
            x, y = [elem.to(device) for elem in batch]
            x = x.to(device, dtype=torch.float)
            y = y.to(device, dtype=torch.long)

            preds = classifier(x)
            loss_grapheme = F.cross_entropy(preds[0], y[:, 0])
            loss_vowel = F.cross_entropy(preds[1], y[:, 1])
            loss_consonant = F.cross_entropy(preds[2], y[:, 2])
            loss = loss_grapheme + loss_vowel + loss_consonant

            metrics = {
                'loss': loss.item(),
                'loss_grapheme': loss_grapheme.item(),
                'loss_vowel': loss_vowel.item(),
                'loss_consonant': loss_consonant.item(),
                'acc_grapheme': accuracy(preds[0], y[:, 0]),
                'acc_vowel': accuracy(preds[1], y[:, 1]),
                'acc_consonant': accuracy(preds[2], y[:, 2]),
            }
            return metrics, torch.cat(preds, dim=1), y
    evaluator = Engine(update_fn)

    for key in classifier.metrics_keys:
        Average(output_transform=DictOutputTransform(key)).attach(evaluator, key)
    return evaluator
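
The accuracy helper used in Examples #13 and #18 is not shown in this listing. A plausible minimal version (an assumption, not the original) computing top-1 accuracy from logits:

import torch

def accuracy(logits, labels):
    # hypothetical helper: fraction of rows where argmax(logits) matches labels
    preds = torch.argmax(logits, dim=1)
    return (preds == labels).float().mean().item()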
Example #19
def train():
    parser = ArgumentParser()
    parser.add_argument("--train_path", type=str, default='data/spolin-train-acl.json', help="Set data path")    
    parser.add_argument("--valid_path", type=str, default='data/spolin-valid.json', help="Set data path")     

    parser.add_argument("--correct_bias", type=bool, default=False, help="Set to true to correct bias for Adam optimizer")
    parser.add_argument("--lr", type=float, default=2e-5, help="Set learning rate")
    parser.add_argument("--n_epochs", type=int, default=4, help="Set number of epochs")
    parser.add_argument("--num_warmup_steps", type=float, default=1000, help="Set number of warm-up steps")
    parser.add_argument("--num_total_steps", type=float, default=10000, help="Set number of total steps")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device (cuda or cpu)")
    parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Set maximum gradient normalization.")
    parser.add_argument("--pretrained_path", type=str, default='bert-base-uncased', help="Choose which pretrained model to use (bert-base-uncased, roberta-base, roberta-large, roberta-large-mnli)")    
    parser.add_argument("--batch_size", type=int, default=32, help="Provide the batch size")    
    parser.add_argument("--random_seed", type=int, default=42, help="Set the random seed")
    parser.add_argument("--test", action='store_true', help="If true, run with small dataset for testing code")
    parser.add_argument("--base", action='store_true', help="If true, run with base experiment configuration (training with spont only) for comparison")

    args = parser.parse_args() 

    logging.basicConfig(level=logging.INFO)
    logger.info("Arguments: {}".format(pformat(args)))

    if 'roberta' in args.pretrained_path: 
        # initialize tokenizer and model 
        logger.info("Initialize model and tokenizer.")
        tokenizer = RobertaTokenizer.from_pretrained(args.pretrained_path, cache_dir = '../pretrained_models')
        model = RobertaForSequenceClassification.from_pretrained(args.pretrained_path, cache_dir='../pretrained_models')

        ### START MODEL MODIFICATION
        # Pretrained model was not trained with token type ids. 
        # fix token type embeddings for finetuning. Without this, the model can only take 0s as valid input for token_type_ids 
        model.config.type_vocab_size = 2 
        model.roberta.embeddings.token_type_embeddings = torch.nn.Embedding(2, model.config.hidden_size)
        model.roberta.embeddings.token_type_embeddings.weight.data.normal_(mean=0.0, std=model.config.initializer_range)

        ### END MOD
    elif 'bert' in args.pretrained_path: 
        model = BertForSequenceClassification.from_pretrained(args.pretrained_path, cache_dir='../pretrained_models')
        tokenizer = BertTokenizer.from_pretrained(args.pretrained_path, cache_dir='../pretrained_models')

    model.to(args.device)

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta']

    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.01}, 
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.0}
    ]

    optimizer = AdamW(optimizer_grouped_parameters,
                        lr=args.lr,
                        correct_bias = args.correct_bias)
    scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.num_warmup_steps, t_total=args.num_total_steps) 

    logger.info("Prepare datasets")
    logger.info("Loading train set...")

    train_data = get_data(args.train_path)
    valid_data = get_data(args.valid_path)

    cornell_valid_data = {k: {'cornell': valid_data[k]['cornell']} for k in valid_data.keys()}
    spont_valid_data = {k: {'spont': valid_data[k]['spont']} for k in valid_data.keys()}

    train_loader, train_sampler = get_data_loaders(args, train_data, args.train_path, tokenizer)
    logger.info("Loading validation set...")
    valid_p = Path(args.valid_path)
    cornell_valid_loader, cornell_valid_sampler = get_data_loaders(args, cornell_valid_data, f"{str(valid_p.parent)}/cornell_{valid_p.name}",  tokenizer)
    spont_valid_loader, spont_valid_sampler = get_data_loaders(args, spont_valid_data, f"{str(valid_p.parent)}/spont_{valid_p.name}", tokenizer)


    # Training function and trainer 
    def update(engine, batch): 
        model.train() 

        batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
        b_input_ids, b_input_mask, b_input_segment, b_labels = batch

        optimizer.zero_grad()
        #roberta has issues with token_type_ids 
        loss, logits = model(b_input_ids, token_type_ids=b_input_segment, attention_mask=b_input_mask, labels=b_labels)
        # loss, logits = model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask, labels=b_labels)


        loss.backward() 
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
        
        optimizer.step() 
        scheduler.step() 

        return loss.item(), logits, b_labels

    trainer = Engine(update)     

    # Evaluation function and evaluator 
    def inference(engine, batch): 
        model.eval() 

        batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
        b_input_ids, b_input_mask, b_input_segment, b_labels = batch
        
        with torch.no_grad(): 
            #roberta has issues with token_type_ids 
            # loss, logits = model(b_input_ids, token_type_ids = None, attention_mask=b_input_mask, labels=b_labels)
            loss, logits = model(b_input_ids, token_type_ids = b_input_segment, attention_mask=b_input_mask, labels=b_labels)
            label_ids = b_labels

        return logits, label_ids, loss.item()
    cornell_evaluator = Engine(inference)
    spont_evaluator = Engine(inference)


    trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: cornell_evaluator.run(cornell_valid_loader))
    trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: spont_evaluator.run(spont_valid_loader))


    RunningAverage(output_transform=lambda x: x[0]).attach(trainer, "loss") 
    RunningAverage(Accuracy(output_transform=lambda x: (x[1], x[2]))).attach(trainer, "accuracy")
    if torch.cuda.is_available(): 
        GpuInfo().attach(trainer, name='gpu')

    recall = Recall(output_transform=lambda x: (x[0], x[1]))
    precision = Precision(output_transform=lambda x: (x[0], x[1]))
    F1 = (precision * recall * 2 / (precision + recall)).mean()
    accuracy = Accuracy(output_transform=lambda x: (x[0], x[1]))
    metrics = {"recall": recall, "precision": precision, "f1": F1, "accuracy": accuracy, "loss": Average(output_transform=lambda x: x[2])}

    for name, metric in metrics.items(): 
        metric.attach(cornell_evaluator, name) 
        metric.attach(spont_evaluator, name) 


    pbar = ProgressBar(persist=True)
    pbar.attach(trainer, metric_names=['loss', 'accuracy'])
    pbar.attach(trainer, metric_names=['gpu:0 mem(%)', 'gpu:0 util(%)'])
    
    cornell_evaluator.add_event_handler(Events.COMPLETED, lambda _: pbar.log_message("Cornell validation metrics:\n %s" % pformat(cornell_evaluator.state.metrics)))
    spont_evaluator.add_event_handler(Events.COMPLETED, lambda _: pbar.log_message("Spont validation metrics:\n %s" % pformat(spont_evaluator.state.metrics)))


    tb_logger = TensorboardLogger(log_dir=None)
    tb_logger.attach(trainer, log_handler=OutputHandler(tag="training", metric_names=["loss"]), event_name=Events.ITERATION_COMPLETED)
    tb_logger.attach(trainer, log_handler=OptimizerParamsHandler(optimizer), event_name=Events.ITERATION_STARTED)
    tb_logger.attach(cornell_evaluator, log_handler=OutputHandler(tag="valid", metric_names=list(metrics.keys()), another_engine=trainer), event_name=Events.EPOCH_COMPLETED)
    tb_logger.attach(spont_evaluator, log_handler=OutputHandler(tag="valid", metric_names=list(metrics.keys()), another_engine=trainer), event_name=Events.EPOCH_COMPLETED)


    # tb_logger.writer.log_dir -> tb_logger.writer.logdir (this is the correct attribute name as seen in: https://tensorboardx.readthedocs.io/en/latest/_modules/tensorboardX/writer.html#SummaryWriter)
    checkpoint_handler = ModelCheckpoint(tb_logger.writer.logdir, 'checkpoint', save_interval=1, n_saved=5)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler, {'mymodel': getattr(model, 'module', model)})  # "getattr" take care of distributed encapsulation

    torch.save(args, tb_logger.writer.logdir + '/model_training_args.bin')
    getattr(model, 'module', model).config.to_json_file(os.path.join(tb_logger.writer.logdir, CONFIG_NAME))
    tokenizer.save_vocabulary(tb_logger.writer.logdir)

    trainer.run(train_loader, max_epochs = args.n_epochs)

    if args.n_epochs > 0: 
        os.rename(checkpoint_handler._saved[-1][1][-1], os.path.join(tb_logger.writer.logdir, WEIGHTS_NAME))  # TODO: PR in ignite to have better access to saved file paths (cleaner)
        tb_logger.close()
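
RunningAverage above is used in two ways: wrapping a raw output slot (output_transform) and wrapping another metric (Accuracy). A compact self-contained sketch of both, on synthetic batches:

import torch
from ignite.engine import Engine
from ignite.metrics import Accuracy, RunningAverage

def step(engine, batch):
    logits, labels = batch
    loss = float(torch.rand(1))  # placeholder loss
    return loss, logits, labels

trainer = Engine(step)
RunningAverage(output_transform=lambda x: x[0]).attach(trainer, "loss")
RunningAverage(Accuracy(output_transform=lambda x: (x[1], x[2]))).attach(trainer, "accuracy")

batches = [(torch.randn(4, 2), torch.randint(0, 2, (4,))) for _ in range(5)]
trainer.run(batches, max_epochs=1)
print(trainer.state.metrics["loss"], trainer.state.metrics["accuracy"])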
Example #20
def train(model,
          train_loader,
          eval_loaders,
          optimizer,
          loss_fn,
          n_it_max,
          patience,
          split_names,
          select_metric='Val accuracy_0',
          select_mode='max',
          viz=None,
          device='cpu',
          lr_scheduler=None,
          name=None,
          log_steps=None,
          log_epoch=False,
          _run=None,
          prepare_batch=_prepare_batch,
          single_pass=False,
          n_ep_max=None):

    # print(model)

    if not log_steps and not log_epoch:
        logger.warning('/!\\ No logging during training /!\\')

    if log_steps is None:
        log_steps = []

    epoch_steps = len(train_loader)
    if log_epoch:
        log_steps.append(epoch_steps)

    if single_pass:
        max_epoch = 1
    elif n_ep_max is None:
        assert n_it_max is not None
        max_epoch = int(n_it_max / epoch_steps) + 1
    else:
        assert n_it_max is None
        max_epoch = n_ep_max

    all_metrics = defaultdict(dict)
    trainer = create_supervised_trainer(model,
                                        optimizer,
                                        loss_fn,
                                        device=device,
                                        prepare_batch=prepare_batch)

    if hasattr(model, 'new_epoch_hook'):
        trainer.add_event_handler(Events.EPOCH_STARTED, model.new_epoch_hook)
    if hasattr(model, 'new_iter_hook'):
        trainer.add_event_handler(Events.ITERATION_STARTED,
                                  model.new_iter_hook)

    trainer._logger.setLevel(logging.WARNING)

    # trainer output is in the format (x, y, y_pred, loss, optionals)
    train_loss = RunningAverage(output_transform=lambda out: out[3].item(),
                                epoch_bound=True)
    train_loss.attach(trainer, 'Trainer loss')
    if hasattr(model, 's'):
        met = Average(output_transform=lambda _: float('nan')
                      if model.s is None else model.s)
        met.attach(trainer, 'cur_s')
        trainer.add_event_handler(Events.ITERATION_COMPLETED, met.completed,
                                  'cur_s')

    if hasattr(model, 'arch_sampler') and model.arch_sampler.distrib_dim > 0:
        met = Average(output_transform=lambda _: float('nan')
                      if model.cur_split is None else model.cur_split)
        met.attach(trainer, 'Trainer split')
        trainer.add_event_handler(Events.ITERATION_COMPLETED, met.completed,
                                  'Trainer split')
        # trainer.add_event_handler(Events.EPOCH_STARTED, met.started)
        all_ent = Average(
            output_transform=lambda out: out[-1]['arch_entropy_avg'].item())
        all_ent.attach(trainer, 'Trainer all entropy')
        trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                  all_ent.completed, 'Trainer all entropy')
        train_ent = Average(
            output_transform=lambda out: out[-1]['arch_entropy_sample'].item())
        train_ent.attach(trainer, 'Trainer sampling entropy')
        trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                  train_ent.completed,
                                  'Trainer sampling entropy')
        trainer.add_event_handler(
            Events.EPOCH_COMPLETED, lambda engine: model.check_arch_freezing(
                ent=train_ent.compute(),
                epoch=engine.state.iteration / (epoch_steps * max_epoch)))

        def log_always(engine, name):
            val = engine.state.output[-1][name]
            all_metrics[name][engine.state.iteration /
                              epoch_steps] = val.mean().item()

        def log_always_dict(engine, name):
            for node, val in engine.state.output[-1][name].items():
                all_metrics['node {} {}'.format(
                    node, name)][engine.state.iteration /
                                 epoch_steps] = val.mean().item()

        trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                  log_always_dict,
                                  name='arch_grads')
        trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                  log_always_dict,
                                  name='arch_probas')
        trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                  log_always_dict,
                                  name='node_grads')
        trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                  log_always,
                                  name='task all_loss')
        trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                  log_always,
                                  name='arch all_loss')
        trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                  log_always,
                                  name='entropy all_loss')

    if n_it_max is not None:
        StopAfterIterations([n_it_max]).attach(trainer)
    # epoch_pbar = ProgressBar(bar_format='{l_bar}{bar}{r_bar}', desc=name,
    #                          persist=True, disable=not (_run or viz))
    # epoch_pbar.attach(trainer, metric_names=['Train loss'])
    #
    # training_pbar = ProgressBar(bar_format='{l_bar}{bar}{r_bar}', desc=name,
    #                             persist=True, disable=not (_run or viz))
    # training_pbar.attach(trainer, event_name=Events.EPOCH_COMPLETED,
    #                      closing_event_name=Events.COMPLETED)
    total_time = Timer(average=False)
    eval_time = Timer(average=False)
    eval_time.pause()
    data_time = Timer(average=False)
    forward_time = Timer(average=False)
    forward_time.attach(trainer,
                        start=Events.EPOCH_STARTED,
                        pause=Events.ITERATION_COMPLETED,
                        resume=Events.ITERATION_STARTED,
                        step=Events.ITERATION_COMPLETED)
    epoch_time = Timer(average=False)
    epoch_time.attach(trainer,
                      start=Events.EPOCH_STARTED,
                      pause=Events.EPOCH_COMPLETED,
                      resume=Events.EPOCH_STARTED,
                      step=Events.EPOCH_COMPLETED)

    def get_loss(y_pred, y):
        l = loss_fn(y_pred, y)
        if not torch.is_tensor(l):
            l, *l_details = l
        return l.mean()

    def get_member(x, n=0):
        if isinstance(x, (list, tuple)):
            return x[n]
        return x

    eval_metrics = {'loss': Loss(get_loss)}

    for i in range(model.n_out):
        out_trans = get_attr_transform(i)

        def extract_ys(out, out_trans=out_trans):
            # bind out_trans at definition time; a bare closure would late-bind
            # to the last loop iteration's transform
            x, y, y_pred, loss, _ = out
            return out_trans((y_pred, y))

        train_acc = Accuracy(extract_ys)
        train_acc.attach(trainer, 'Trainer accuracy_{}'.format(i))
        trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                  train_acc.completed,
                                  'Trainer accuracy_{}'.format(i))
        eval_metrics['accuracy_{}'.format(i)] = \
            Accuracy(output_transform=out_trans)
        # if isinstance(model, SSNWrapper):
        #     model.arch_sampler.entropy().mean()

    evaluator = create_supervised_evaluator(model,
                                            metrics=eval_metrics,
                                            device=device,
                                            prepare_batch=prepare_batch)
    last_iteration = 0
    patience_counter = 0

    best = {
        'value': float('inf') if select_mode == 'min' else -float('inf'),
        'iter': -1,
        'state_dict': None
    }

    def is_better(new, old):
        if select_mode == 'min':
            return new < old
        else:
            return new > old

    def log_results(evaluator, data_loader, iteration, split_name):
        evaluator.run(data_loader)
        metrics = evaluator.state.metrics

        log_metrics = {}

        for metric_name, metric_val in metrics.items():
            log_name = '{} {}'.format(split_name, metric_name)
            if viz:
                first = iteration == 0 and split_name == split_names[0]
                viz.line(
                    [metric_val],
                    X=[iteration],
                    win=metric_name,
                    name=log_name,
                    update=None if first else 'append',
                    opts={
                        'title': metric_name,
                        'showlegend': True,
                        'width': 500,
                        'xlabel': 'iterations'
                    })
                viz.line(
                    [metric_val],
                    X=[iteration / epoch_steps],
                    win='{}epoch'.format(metric_name),
                    name=log_name,
                    update=None if first else 'append',
                    opts={
                        'title': metric_name,
                        'showlegend': True,
                        'width': 500,
                        'xlabel': 'epoch'
                    })
            if _run:
                _run.log_scalar(log_name, metric_val, iteration)
            log_metrics[log_name] = metric_val
            all_metrics[log_name][iteration] = metric_val

        return log_metrics

    if lr_scheduler is not None:

        @trainer.on(Events.EPOCH_COMPLETED)
        def step(_):
            lr_scheduler.step()
            # logger.warning('current lr {:.5e}'.format(
            #     optimizer.param_groups[0]['lr']))

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_event(trainer):
        iteration = trainer.state.iteration if trainer.state else 0
        nonlocal last_iteration, patience_counter, best

        if not log_steps or not \
                (iteration in log_steps or iteration % log_steps[-1] == 0):
            return
        epoch_time.pause()
        eval_time.resume()
        all_metrics['training_epoch'][iteration] = iteration / epoch_steps
        all_metrics['training_iteration'][iteration] = iteration
        if hasattr(model, 'arch_sampler'):
            all_metrics['training_archs'][iteration] = \
                model.arch_sampler().squeeze().detach()
        # if hasattr(model, 'distrib_gen'):
        #     entropy = model.distrib_gen.entropy()
        #     all_metrics['entropy'][iteration] = entropy.mean().item()
        # if trainer.state and len(trainer.state.metrics) > 1:
        #     raise ValueError(trainer.state.metrics)
        all_metrics['data time'][iteration] = data_time.value()
        all_metrics['data time_ps'][iteration] = data_time.value() / max(
            data_time.step_count, 1.)
        all_metrics['forward time'][iteration] = forward_time.value()
        all_metrics['forward time_ps'][iteration] = forward_time.value() / max(
            forward_time.step_count, 1.)
        all_metrics['epoch time'][iteration] = epoch_time.value()
        all_metrics['epoch time_ps'][iteration] = epoch_time.value() / max(
            epoch_time.step_count, 1.)

        if trainer.state:
            # logger.warning(trainer.state.metrics)
            for metric, value in trainer.state.metrics.items():
                all_metrics[metric][iteration] = value
                if viz:
                    viz.line(
                        [value],
                        X=[iteration],
                        win=metric.split()[-1],
                        name=metric,
                        update=None if iteration == 0 else 'append',
                        opts={
                            'title': metric,
                            'showlegend': True,
                            'width': 500,
                            'xlabel': 'iterations'
                        })

        iter_this_step = iteration - last_iteration
        for d_loader, name in zip(eval_loaders, split_names):
            if name == 'Train':
                if iteration == 0:
                    all_metrics['Trainer loss'][iteration] = float('nan')
                    all_metrics['Trainer accuracy_0'][iteration] = float('nan')
                    if hasattr(model, 'arch_sampler'):
                        all_metrics['Trainer all entropy'][iteration] = float(
                            'nan')
                        all_metrics['Trainer sampling entropy'][
                            iteration] = float('nan')
                        # if hasattr(model, 'cur_split'):
                        all_metrics['Trainer split'][iteration] = float('nan')
                continue
            split_metrics = log_results(evaluator, d_loader, iteration, name)
            if select_metric not in split_metrics:
                continue
            if is_better(split_metrics[select_metric], best['value']):
                best['value'] = split_metrics[select_metric]
                best['iter'] = iteration
                best['state_dict'] = copy.deepcopy(model.state_dict())
                if patience > 0:
                    patience_counter = 0
            elif patience > 0:
                patience_counter += iter_this_step
                if patience_counter >= patience:
                    logger.info('#####')
                    logger.info('# Early stopping Run')
                    logger.info('#####')
                    trainer.terminate()
        last_iteration = iteration
        eval_time.pause()
        eval_time.step()
        all_metrics['eval time'][iteration] = eval_time.value()
        all_metrics['eval time_ps'][iteration] = eval_time.value() / eval_time.step_count
        all_metrics['total time'][iteration] = total_time.value()
        epoch_time.resume()

    log_event(trainer)

    #
    # @trainer.on(Events.EPOCH_COMPLETED)
    # def log_epoch(trainer):
    #     iteration = trainer.state.iteration if trainer.state else 0
    #     epoch = iteration/epoch_steps
    #     fw_t = forward_time.value()
    #     fw_t_ps = fw_t / forward_time.step_count
    #     d_t = data_time.value()
    #     d_t_ps = d_t / data_time.step_count
    #     e_t = epoch_time.value()
    #     e_t_ps = e_t / epoch_time.step_count
    #     ev_t = eval_time.value()
    #     ev_t_ps = ev_t / eval_time.step_count
    #     logger.warning('<{}> Epoch {}/{} finished (Forward: {:.3f}s({:.3f}), '
    #                    'data: {:.3f}s({:.3f}), epoch: {:.3f}s({:.3f}),'
    #                    ' Eval: {:.3f}s({:.3f}), Total: '
    #                    '{:.3f}s)'.format(type(model).__name__, epoch,
    #                                      max_epoch, fw_t, fw_t_ps, d_t, d_t_ps,
    #                                      e_t, e_t_ps, ev_t, ev_t_ps,
    #                                      total_time.value()))

    data_time.attach(trainer,
                     start=Events.STARTED,
                     pause=Events.ITERATION_STARTED,
                     resume=Events.ITERATION_COMPLETED,
                     step=Events.ITERATION_STARTED)

    if hasattr(model, 'iter_per_epoch'):
        model.iter_per_epoch = len(train_loader)
    trainer.run(train_loader, max_epochs=max_epoch)
    return trainer.state.iteration, all_metrics, best
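
The Timer attachments above follow ignite.handlers.Timer. A compact sketch (with a hypothetical no-op engine) measuring the time spent between iterations, i.e. waiting on the data loader:

from ignite.engine import Engine, Events
from ignite.handlers import Timer

trainer = Engine(lambda e, b: None)

data_time = Timer(average=False)
data_time.attach(trainer,
                 start=Events.STARTED,
                 pause=Events.ITERATION_STARTED,
                 resume=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_STARTED)

trainer.run(range(100), max_epochs=1)
print(data_time.value(), data_time.step_count)  # total seconds, number of steps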
Example #21
    def run_training(self):

        ########## init wandb ###########
        print(GREEN + "*************** START TRAINING *******************")

        # fixme check if this works
        for (key, val) in self.config.items():
            print(GREEN + f"{key}: {val}")  # print to console
            wandb.config.update({key: val})  # update wandb config

        print(GREEN + "**************************************************" +
              ENDC)

        ########## checkpoints ##########
        if self.config["general"]["restart"]:
            mod_ckpt, op_ckpt = self._load_ckpt("reg_ckpt")
            # flow_ckpt, flow_op_ckpt = self._load_ckpt("flow_ckpt")

        else:
            mod_ckpt = op_ckpt = None

        dataset, image_transforms = get_dataset(self.config["data"])
        transforms = tt.Compose([tt.ToTensor()])
        train_dataset = dataset(transforms,
                                data_keys=self.data_keys,
                                mode="train",
                                label_transfer=True,
                                debug=self.config["general"]["debug"],
                                crop_app=True,
                                **self.config["data"])

        # if seq_length is pruned, use min seq_length, such that the seq_length of test_dataset lower or equal than that of the train dataset
        # collect_len = train_dataset.seq_length
        # self.collect_recon_loss_seq = {
        #     k: np.zeros(shape=[k])
        #     for k in range(collect_len[0], collect_len[-1])
        # }
        # self.collect_count_seq_lens = np.zeros(shape=[collect_len[-1]])
        # # adapt sequence_length
        # self.config["data"]["seq_length"] = (
        #     min(self.config["data"]["seq_length"][0], train_dataset.seq_length[0]),
        #     min(self.config["data"]["seq_length"][1], train_dataset.seq_length[1]),
        # )

        train_sampler = RandomSampler(data_source=train_dataset)

        seq_sampler_train = SequenceSampler(
            train_dataset,
            sampler=train_sampler,
            batch_size=self.config["training"]["batch_size"],
            drop_last=True,
        )

        train_loader = DataLoader(
            train_dataset,
            num_workers=0 if self.config["general"]["debug"] else
            self.config["data"]["n_data_workers"],
            batch_sampler=seq_sampler_train,
        )

        # test data
        t_datakeys = [key for key in self.data_keys] + [
            "action",
            "sample_ids",
            "intrinsics",
            "intrinsics_paired",
            "extrinsics",
            "extrinsics_paired",
        ]
        test_dataset = dataset(image_transforms,
                               data_keys=t_datakeys,
                               mode="test",
                               debug=self.config["general"]["debug"],
                               label_transfer=True,
                               **self.config["data"])
        assert (test_dataset.action_id_to_action is not None)
        rand_sampler_test = RandomSampler(data_source=test_dataset)
        seq_sampler_test = SequenceSampler(
            test_dataset,
            rand_sampler_test,
            batch_size=self.config["training"]["batch_size"],
            drop_last=True,
        )
        test_loader = DataLoader(
            test_dataset,
            num_workers=0 if self.config["general"]["debug"] else
            self.config["data"]["n_data_workers"],
            batch_sampler=seq_sampler_test,
        )
        #
        rand_sampler_transfer = RandomSampler(data_source=test_dataset)
        seq_sampler_transfer = SequenceSampler(
            test_dataset,
            rand_sampler_transfer,
            batch_size=1,
            drop_last=True,
        )
        transfer_loader = DataLoader(
            test_dataset,
            batch_sampler=seq_sampler_transfer,
            num_workers=0 if self.config["general"]["debug"] else
            self.config["data"]["n_data_workers"],
        )
        #
        # compare_dataset = dataset(
        #     transforms,
        #     data_keys=t_datakeys,
        #     mode="train",
        #     label_transfer=True,
        #     debug=self.config["general"]["debug"],
        #     crop_app=True,
        #     **self.config["data"]
        # )

        ## Classifier action
        # n_actions = len(train_dataset.action_id_to_action)
        # classifier_action = Classifier_action(len(train_dataset.dim_to_use), n_actions, dropout=0, dim=512).to(self.device)
        # optimizer_classifier = Adam(classifier_action.parameters(), lr=0.0001, weight_decay=1e-4)
        # print("Number of parameters in classifier action", sum(p.numel() for p in classifier_action.parameters()))
        #
        n_actions = len(train_dataset.action_id_to_action)
        # # classifier_action2 = Classifier_action(len(train_dataset.dim_to_use), n_actions, dropout=0, dim=512).to(self.device)
        # classifier_action2 = Sequence_disc_michael([2, 1, 1, 1], len(train_dataset.dim_to_use), out_dim=n_actions).to(self.device)
        # optimizer_classifier2 = Adam(classifier_action2.parameters(), lr=0.0001, weight_decay=1e-5)
        # print("Number of parameters in classifier action", sum(p.numel() for p in classifier_action2.parameters()))

        # Classifier beta
        classifier_beta = Classifier_action_beta(512,
                                                 n_actions).to(self.device)
        optimizer_classifier_beta = Adam(classifier_beta.parameters(),
                                         lr=0.001)
        print("Number of parameters in classifier on beta",
              sum(p.numel() for p in classifier_beta.parameters()))
        # # Regressor
        # regressor = Regressor_fly(self.config["architecture"]["dim_hidden_b"], len(train_dataset.dim_to_use)).to(self.device)
        # optimizer_regressor = Adam(regressor.parameters(), lr=0.0001)
        # print("Number of parameters in regressor", sum(p.numel() for p in regressor.parameters()))

        ########## load network and optimizer ##########
        net = MTVAE(self.config["architecture"], len(train_dataset.dim_to_use),
                    self.device)

        print(
            "Number of parameters in VAE model",
            sum(p.numel() for p in net.parameters()),
        )
        if self.config["general"]["restart"]:
            if mod_ckpt is not None:
                print(BLUE + f"***** Initializing VAE from checkpoint! *****" +
                      ENDC)
                net.load_state_dict(mod_ckpt)
        net.to(self.device)

        optimizer = Adam(net.parameters(),
                         lr=self.config["training"]["lr_init"],
                         weight_decay=self.config["training"]["weight_decay"])
        wandb.watch(net, log="all", log_freq=len(train_loader))
        if self.config["general"]["restart"]:
            if op_ckpt is not None:
                optimizer.load_state_dict(op_ckpt)
        # scheduler = torch.optim.lr_scheduler.MultiStepLR(
        #     optimizer, milestones=self.config["training"]["tau"], gamma=self.config["training"]["gamma"]
        # )
        # rec_loss = nn.MSELoss(reduction="none")

        ############## DISCRIMINATOR ##############################
        n_kps = len(train_dataset.dim_to_use)

        # make gan loss weights
        print(
            f"len of train_dataset: {len(train_dataset)}, len of train_loader: {len(train_loader)}"
        )

        # 10 epochs of fine tuning
        total_steps = (self.config["training"]["n_epochs"] -
                       10) * len(train_loader)

        get_kl_weight = partial(
            linear_var,
            start_it=0,
            end_it=total_steps,
            start_val=1e-5,
            end_val=1,
            clip_min=0,
            clip_max=1,
        )

        def train_fn(engine, batch):
            net.train()
            # reference keypoints with label #1
            kps = batch["keypoints"].to(torch.float).to(self.device)

            # keypoints for cross label transfer, label #2
            kps_cross = batch["paired_keypoints"].to(torch.float).to(
                self.device)

            p_id = batch["paired_sample_ids"].to(torch.int)
            # reconstruct second sequence with inferred b

            labels = batch['action'][:, 0] - 2

            out_seq, mu, logstd, out_cycle = net(kps, kps_cross)

            ps = torch.randn_like(out_cycle, requires_grad=False)

            cycle_loss = torch.mean(torch.abs(out_cycle - ps))
            kps_loss = torch.mean(torch.abs(out_seq - kps[:, net.div:]))
            l_kl = kl_loss(mu, logstd)

            k_vel = self.config["training"]["k_vel"]
            vel_tgt = kps[:, net.div:net.div +
                          k_vel] - kps[:, net.div - 1:net.div + k_vel - 1]

            vel_pred = out_seq[:, :k_vel] - torch.cat(
                [kps[:, net.div - 1].unsqueeze(1), out_seq[:, :k_vel - 1]],
                dim=1)
            motion_loss = torch.mean(torch.abs(vel_tgt - vel_pred))

            kl_weight = get_kl_weight(engine.state.iteration)

            loss = kps_loss + kl_weight * l_kl + self.config["training"]["weight_motion"] * motion_loss \
                   + self.config["training"]["weight_cycle"] * cycle_loss

            #
            #
            if engine.state.epoch < self.config["training"]["n_epochs"] - 10:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            out_dict = {
                "loss": loss.detach().item(),
                "motion_loss": motion_loss.detach().item(),
                "rec_loss": kps_loss.detach().item(),
                "cycle_loss": cycle_loss.detach().item(),
                "kl_loss": l_kl.detach().item(),
                "kl_weight": kl_weight
            }
            #
            #
            # ## Train classifier on action
            # predict = classifier_action(seq_b)[0]
            # loss_classifier_action = nn.CrossEntropyLoss()(predict, labels.to(self.device))
            # optimizer_classifier.zero_grad()
            # loss_classifier_action.backward()
            # optimizer_classifier.step()
            # _, labels_pred = torch.max(nn.Sigmoid()(predict), dim=1)
            # acc_action = torch.sum(labels_pred.cpu() == labels).float() / labels_pred.shape[0]
            #
            # predict = classifier_action2((seq_b[:, 1:] - seq_b[:, :-1]).transpose(1, 2))[0]
            # loss_classifier_action2 = nn.CrossEntropyLoss()(predict, labels.to(self.device))
            # optimizer_classifier2.zero_grad()
            # loss_classifier_action2.backward()
            # optimizer_classifier2.step()
            # _, labels_pred = torch.max(nn.Sigmoid()(predict), dim=1)
            # acc_action2 = torch.sum(labels_pred.cpu() == labels).float() / labels_pred.shape[0]
            #
            # ## Train classifier on beta
            # if engine.state.epoch >= self.config["training"]["n_epochs"] - 10:
            net.eval()
            with torch.no_grad():
                _, mu, *_ = net(kps, kps_cross)
            predict = classifier_beta(mu)
            loss_classifier_action_beta = nn.CrossEntropyLoss()(
                predict, labels.to(self.device))
            optimizer_classifier_beta.zero_grad()
            loss_classifier_action_beta.backward()
            optimizer_classifier_beta.step()
            _, labels_pred = torch.max(nn.Sigmoid()(predict), dim=1)
            acc_action_beta = torch.sum(
                labels_pred.cpu() == labels).float() / labels_pred.shape[0]
            #
            # out_dict = {}
            # # this is only run if flow training is enable
            #
            # # add info to out_dict
            # out_dict['loss_classifier_action'] = loss_classifier_action.detach().item()
            # out_dict['acc_classifier_action'] = acc_action.item()
            # out_dict['loss_classifier_action2'] = loss_classifier_action2.detach().item()
            # out_dict['acc_classifier_action2'] = acc_action2.item()
            #
            # # if engine.state.epoch >= self.config["training"]["n_epochs"] - 10:
            out_dict['loss_classifier_action_beta'] = \
                loss_classifier_action_beta.detach().item()
            out_dict['acc_action_beta'] = acc_action_beta.item()
            # out_dict["loss"] = loss.detach().item()
            # out_dict["kl_loss"] = kl_loss_avg.detach().item()
            #
            # out_dict["mu_s"] = torch.mean(mu_s).item()
            # out_dict["logstd_s"] = torch.mean(logstd_s).item()
            # # if self.config["training"]["use_regressor"]:
            # #     out_dict["loss_regressor"] = torch.mean(loss_regressor).item()
            # out_dict["loss_recon"] = recon_loss.detach().item()
            # out_dict["loss_per_seq_recon"] = (
            #     recon_loss_per_seq.detach().cpu().numpy()
            # )
            # out_dict["seq_len"] = seq_len
            #
            return out_dict

        ##### CREATE TRAINING RUN #####
        trainer = Engine(train_fn)
        pbar = ProgressBar()
        pbar.attach(
            trainer,
            output_transform=lambda x: {
                key: val
                for key, val in x.items() if "per_seq" not in key
            },
        )

        # compute epoch averages for the named outputs of the train function
        # (a minimal sketch of this Average/output_transform pattern follows this example)
        # fixme: this can be used for logging as soon as the mtvae losses are defined and named
        loss_avg = Average(output_transform=lambda x: x["loss"])
        loss_avg.attach(trainer, "loss")
        recon_loss_avg = Average(output_transform=lambda x: x["rec_loss"])
        recon_loss_avg.attach(trainer, "rec_loss")
        kl_loss_avg = Average(output_transform=lambda x: x["kl_loss"])
        kl_loss_avg.attach(trainer, "kl_loss")
        motion_loss_avg = Average(output_transform=lambda x: x["motion_loss"])
        motion_loss_avg.attach(trainer, "motion_loss")
        cycle_loss_avg = Average(output_transform=lambda x: x["cycle_loss"])
        cycle_loss_avg.attach(trainer, "cycle_loss")
        # mu_s_avg = Average(output_transform=lambda x: x["mu_s"])
        # mu_s_avg.attach(trainer, "mu_s")
        # logstd_s_avg = Average(output_transform=lambda x: x["logstd_s"])
        # logstd_s_avg.attach(trainer, "logstd_s")
        #
        # loss_classifier = Average(output_transform=lambda x: x["loss_classifier_action"] if "loss_classifier_action" in x else 0)
        # loss_classifier.attach(trainer, "loss_classifier_action")
        # acc_classifier = Average(output_transform=lambda x: x["acc_classifier_action"] if "acc_classifier_action" in x else 0)
        # acc_classifier.attach(trainer, "acc_classifier_action")
        #
        # loss_classifier_action2 = Average(output_transform=lambda x: x["loss_classifier_action2"] if "loss_classifier_action2" in x else 0)
        # loss_classifier_action2.attach(trainer, "loss_classifier_action2")
        # acc_classifier_action2 = Average(output_transform=lambda x: x["acc_classifier_action2"] if "acc_classifier_action2" in x else 0)
        # acc_classifier_action2.attach(trainer, "acc_classifier_action2")
        #
        loss_classifier_action_beta = Average(
            output_transform=lambda x: x.get("loss_classifier_action_beta", 0))
        loss_classifier_action_beta.attach(trainer,
                                           "loss_classifier_action_beta")
        acc_action_beta = Average(
            output_transform=lambda x: x.get("acc_action_beta", 0))
        acc_action_beta.attach(trainer, "acc_action_beta")

        # loss_avg = Average(output_transform=lambda x: x["loss"])
        # loss_avg.attach(trainer, "loss")

        ##### TRAINING HOOKS ######
        # @trainer.on(Events.ITERATION_COMPLETED)
        # def collect_training_info(engine):
        #     it = engine.state.iteration
        #
        #     self.collect_recon_loss_seq[seq_len] += engine.state.output[
        #         "loss_per_seq_recon"
        #     ]
        #     self.collect_count_seq_lens[seq_len] += self.config["training"]["batch_size"]

        # @trainer.on(Events.EPOCH_COMPLETED)
        # def update_optimizer_params(engine):
        #     scheduler.step()

        def log_wandb(engine):

            wandb.log({
                "epoch": engine.state.epoch,
                "iteration": engine.state.iteration,
            })

            print(
                f"Logging metrics: Currently, the following metrics are tracked: {list(engine.state.metrics.keys())}"
            )
            for key in engine.state.metrics:
                val = engine.state.metrics[key]
                wandb.log({key + "-epoch-avg": val})
                print(ENDC + f" [metrics] {key}:{val}")

            # reset
            # self.collect_recon_loss_seq = {
            #     k: np.zeros(shape=[k])
            #     for k in range(collect_len[0], collect_len[-1])
            # }
            # self.collect_count_seq_lens = np.zeros(shape=[collect_len[-1]])

            loss_avg = engine.state.metrics["loss"]

            print(GREEN + f"Epoch {engine.state.epoch} summary:")
            print(ENDC + f" [losses] loss overall:{loss_avg}")

        def eval_model(engine):
            eval_nets(net,
                      test_loader,
                      self.device,
                      engine.state.epoch,
                      cf_action_beta=classifier_beta,
                      debug=self.config["general"]["debug"])

        #
        #
        def transfer_behavior_test(engine):
            visualize_transfer3d(
                net,
                transfer_loader,
                self.device,
                name="Test-Set: ",
                dirs=self.dirs,
                revert_coord_space=False,
                epoch=engine.state.epoch,
                n_vid_to_generate=self.config["logging"]["n_vid_to_generate"])

        # # compare predictions on train and test set
        # def eval_grid(engine):
        #     if self.config["data"]["dataset"] != "HumanEva":
        #         make_eval_grid(
        #             net,
        #             transfer_loader,
        #             self.device,
        #             dirs=self.dirs,
        #             revert_coord_space=False,
        #             epoch=engine.state.epoch,
        #             synth_ckpt=self.synth_ckpt,
        #             synth_params=self.synth_params,
        #         )

        # def latent_interpolations(engine):
        #     latent_interpolate(
        #         net,
        #         transfer_loader,
        #         self.device,
        #         dirs=self.dirs,
        #         epoch=engine.state.epoch,
        #         synth_params=self.synth_params,
        #         synth_ckpt=self.synth_ckpt,
        #         n_vid_to_generate=self.config["logging"]["n_vid_to_generate"]
        #     )

        ckpt_handler_reg = ModelCheckpoint(self.dirs["ckpt"],
                                           "reg_ckpt",
                                           n_saved=100,
                                           require_empty=False)
        save_dict = {"model": net, "optimizer": optimizer}
        trainer.add_event_handler(Events.EPOCH_COMPLETED(every=1),
                                  ckpt_handler_reg, save_dict)

        trainer.add_event_handler(Events.EPOCH_COMPLETED, log_wandb)

        def log_outputs(engine):
            for key in engine.state.output:
                val = engine.state.output[key]
                wandb.log({key + "-epoch-step": val})

        trainer.add_event_handler(
            Events.ITERATION_COMPLETED(
                every=10 if self.config["general"]["debug"] else 1000),
            log_outputs)
        trainer.add_event_handler(Events.EPOCH_COMPLETED, eval_model)
        trainer.add_event_handler(
            Events.EPOCH_COMPLETED(
                every=1 if self.config["general"]["debug"] else 3),
            transfer_behavior_test,
        )
        # trainer.add_event_handler(
        #     Events.EPOCH_COMPLETED(
        #         every=10
        #     ),
        #     latent_interpolations,
        # )
        # trainer.add_event_handler(
        #     Events.EPOCH_COMPLETED(
        #         every=3
        #     ),
        #     eval_grid,
        # )

        ####### RUN TRAINING ##############
        print(BLUE + "*************** Train VAE *******************" + ENDC)
        trainer.run(
            train_loader,
            max_epochs=self.config["training"]["n_epochs"],
            epoch_length=10
            if self.config["general"]["debug"] else len(train_loader),
        )
        print(BLUE + "*************** VAE training ends *******************" +
              ENDC)
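
The training function above returns a plain dict of scalars per iteration, and each ignite Average metric picks one key out of that dict via output_transform and averages it over the epoch. A minimal, self-contained sketch of that pattern, with a dummy step function and dummy data standing in for the real trainer and loader:

from ignite.engine import Engine, Events
from ignite.metrics import Average


def train_fn(engine, batch):
    # stand-in for a real training step: return a dict of scalar outputs
    return {"loss": float(batch), "rec_loss": float(batch) * 0.5}


trainer = Engine(train_fn)
# each Average selects one entry of the per-iteration output dict
Average(output_transform=lambda out: out["loss"]).attach(trainer, "loss")
Average(output_transform=lambda out: out["rec_loss"]).attach(trainer, "rec_loss")


@trainer.on(Events.EPOCH_COMPLETED)
def report(engine):
    # epoch averages of "loss" and "rec_loss"
    print(engine.state.metrics)


trainer.run([0.0, 1.0, 2.0, 3.0], max_epochs=1)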
Example #22
def main(
    batch_size,
    epochs,
    length_scale,
    centroid_size,
    model_output_size,
    learning_rate,
    l_gradient_penalty,
    gamma,
    weight_decay,
    final_model,
):
    name = f"DUQ_{length_scale}__{l_gradient_penalty}_{gamma}_{centroid_size}"
    writer = SummaryWriter(comment=name)

    ds = all_datasets["CIFAR10"]()
    input_size, num_classes, dataset, test_dataset = ds

    # Split up training set
    idx = list(range(len(dataset)))
    random.shuffle(idx)

    if final_model:
        train_dataset = dataset
        val_dataset = test_dataset
    else:
        train_size = int(len(dataset) * 0.8)  # 80% train / 20% validation split
        train_dataset = torch.utils.data.Subset(dataset, idx[:train_size])
        val_dataset = torch.utils.data.Subset(dataset, idx[train_size:])

        val_dataset.transform = test_dataset.transform  # test-time preprocessing for validation

    model = ResNet_DUQ(
        input_size,
        num_classes,
        centroid_size,
        model_output_size,
        length_scale,
        gamma,
    )
    model = model.cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=learning_rate,
                                momentum=0.9,
                                weight_decay=weight_decay)

    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=[25, 50, 75],
                                                     gamma=0.2)

    def bce_loss_fn(y_pred, y):
        bce = F.binary_cross_entropy(y_pred, y, reduction="sum").div(
            num_classes * y_pred.shape[0])
        return bce

    def output_transform_bce(output):
        y_pred, y, x = output

        y = F.one_hot(y, num_classes).float()

        return y_pred, y

    def output_transform_acc(output):
        y_pred, y, x = output

        return y_pred, y

    def output_transform_gp(output):
        y_pred, y, x = output

        return x, y_pred

    def calc_gradients_input(x, y_pred):
        gradients = torch.autograd.grad(
            outputs=y_pred,
            inputs=x,
            grad_outputs=torch.ones_like(y_pred),
            create_graph=True,
        )[0]

        gradients = gradients.flatten(start_dim=1)

        return gradients

    def calc_gradient_penalty(x, y_pred):
        gradients = calc_gradients_input(x, y_pred)

        # L2 norm
        grad_norm = gradients.norm(2, dim=1)

        # Two sided penalty
        gradient_penalty = ((grad_norm - 1)**2).mean()

        return gradient_penalty

    def step(engine, batch):
        model.train()

        optimizer.zero_grad()

        x, y = batch
        x, y = x.cuda(), y.cuda()

        if l_gradient_penalty > 0:
            x.requires_grad_(True)

        z, y_pred = model(x)
        y = F.one_hot(y, num_classes).float()

        loss = bce_loss_fn(y_pred, y)

        if l_gradient_penalty > 0:
            loss += l_gradient_penalty * calc_gradient_penalty(x, y_pred)

        loss.backward()
        optimizer.step()

        x.requires_grad_(False)

        with torch.no_grad():
            model.eval()
            model.update_embeddings(x, y)

        return loss.item()

    def eval_step(engine, batch):
        model.eval()

        x, y = batch
        x, y = x.cuda(), y.cuda()

        x.requires_grad_(True)

        z, y_pred = model(x)

        return y_pred, y, x

    trainer = Engine(step)
    evaluator = Engine(eval_step)

    metric = Average()
    metric.attach(trainer, "loss")

    metric = Accuracy(output_transform=output_transform_acc)
    metric.attach(evaluator, "accuracy")

    metric = Loss(F.binary_cross_entropy,
                  output_transform=output_transform_bce)
    metric.attach(evaluator, "bce")

    metric = Loss(calc_gradient_penalty, output_transform=output_transform_gp)
    metric.attach(evaluator, "gradient_penalty")

    kwargs = {"num_workers": 4, "pin_memory": True}

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               drop_last=True,
                                               **kwargs)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=1000,
                                             shuffle=False,
                                             **kwargs)

    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=1000,
                                              shuffle=False,
                                              **kwargs)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_results(trainer):
        metrics = trainer.state.metrics
        loss = metrics["loss"]

        print(f"Train - Epoch: {trainer.state.epoch} Loss: {loss:.2f} ")

        writer.add_scalar("Loss/train", loss, trainer.state.epoch)

        if trainer.state.epoch % 5 == 0 or trainer.state.epoch > 65:
            accuracy, auroc = get_cifar_svhn_ood(model)
            print(f"Test Accuracy: {accuracy}, AUROC: {auroc}")
            writer.add_scalar("OoD/test_accuracy", accuracy,
                              trainer.state.epoch)
            writer.add_scalar("OoD/roc_auc", auroc, trainer.state.epoch)

            accuracy, auroc = get_auroc_classification(val_dataset, model)
            print(f"AUROC - uncertainty: {auroc}")
            writer.add_scalar("OoD/val_accuracy", accuracy,
                              trainer.state.epoch)
            writer.add_scalar("OoD/roc_auc_classification", auroc,
                              trainer.state.epoch)

        evaluator.run(val_loader)
        metrics = evaluator.state.metrics
        acc = metrics["accuracy"]
        bce = metrics["bce"]
        GP = metrics["gradient_penalty"]
        loss = bce + l_gradient_penalty * GP

        print((f"Valid - Epoch: {trainer.state.epoch} "
               f"Acc: {acc:.4f} "
               f"Loss: {loss:.2f} "
               f"BCE: {bce:.2f} "
               f"GP: {GP:.2f} "))

        writer.add_scalar("Loss/valid", loss, trainer.state.epoch)
        writer.add_scalar("BCE/valid", bce, trainer.state.epoch)
        writer.add_scalar("GP/valid", GP, trainer.state.epoch)
        writer.add_scalar("Accuracy/valid", acc, trainer.state.epoch)

        print(f"Centroid norm: {torch.norm(model.m / model.N, dim=0)}")

        scheduler.step()

        if trainer.state.epoch > 65:
            torch.save(model.state_dict(),
                       f"saved_models/{name}_{trainer.state.epoch}.pt")

    pbar = ProgressBar(dynamic_ncols=True)
    pbar.attach(trainer)

    trainer.run(train_loader, max_epochs=epochs)

    evaluator.run(test_loader)
    acc = evaluator.state.metrics["accuracy"]

    print(f"Test - Accuracy {acc:.4f}")

    writer.close()
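
For reference, bce_loss_fn above divides a summed binary cross-entropy by num_classes * batch_size, which is numerically the same as reduction="mean" on one-hot targets (and matches the later variant of this script that calls F.binary_cross_entropy with reduction="mean" directly). A tiny check, with shapes chosen only for illustration:

import torch
import torch.nn.functional as F

num_classes, batch = 10, 8
y_pred = torch.rand(batch, num_classes).clamp(1e-4, 1 - 1e-4)  # predictions in (0, 1)
y = F.one_hot(torch.randint(0, num_classes, (batch,)), num_classes).float()

summed = F.binary_cross_entropy(y_pred, y, reduction="sum") / (num_classes * batch)
mean = F.binary_cross_entropy(y_pred, y, reduction="mean")
assert torch.allclose(summed, mean)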
Example #23
        return {
            'preds': preds,
            'targets': targets,
            'cross_entropy': loss.item()
        }

    train_metrics = {
        'Loss': RunningAverage(output_transform=lambda x: x['cross_entropy']),
        'Accuracy': RunningAverage(
            Accuracy(output_transform=lambda x: (x['preds'], x['targets']))),
    }

    eval_metrics = {
        'Loss': Average(output_transform=lambda x: x['cross_entropy']),
        'Accuracy': Accuracy(
            output_transform=lambda x: (x['preds'], x['targets'])),
    }

    train(args.run_name, model, train_set, test_set, train_step, eval_step,
          train_metrics, eval_metrics, args.n_iterations, args.batch_size)

    predictions = []
    truths = []
    test_loader = DataLoader(test_set,
                             batch_size=args.batch_size,
                             shuffle=False,
                             num_workers=4)
    model.eval()
    for images, targets in test_loader:
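        # (not part of the original listing, which is truncated here) a hypothetical
        # continuation: run the model without gradients and collect class predictions
        with torch.no_grad():
            logits = model(images)
        predictions.append(logits.argmax(dim=1).cpu())
        truths.append(targets.cpu())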
Example #24
def main(
    architecture,
    batch_size,
    length_scale,
    centroid_size,
    learning_rate,
    l_gradient_penalty,
    gamma,
    weight_decay,
    final_model,
    output_dir,
):
    writer = SummaryWriter(log_dir=f"runs/{output_dir}")

    ds = all_datasets["CIFAR10"]()
    input_size, num_classes, dataset, test_dataset = ds

    # Split up training set
    idx = list(range(len(dataset)))
    random.shuffle(idx)

    if final_model:
        train_dataset = dataset
        val_dataset = test_dataset
    else:
        train_size = int(len(dataset) * 0.8)  # 80% train / 20% validation split
        train_dataset = torch.utils.data.Subset(dataset, idx[:train_size])
        val_dataset = torch.utils.data.Subset(dataset, idx[train_size:])

        val_dataset.transform = test_dataset.transform  # test-time preprocessing for validation

    if architecture == "WRN":
        model_output_size = 640
        epochs = 200
        milestones = [60, 120, 160]
        feature_extractor = WideResNet()
    elif architecture == "ResNet18":
        model_output_size = 512
        epochs = 200
        milestones = [60, 120, 160]
        feature_extractor = resnet18()
    elif architecture == "ResNet50":
        model_output_size = 2048
        epochs = 200
        milestones = [60, 120, 160]
        feature_extractor = resnet50()
    elif architecture == "ResNet110":
        model_output_size = 2048
        epochs = 200
        milestones = [60, 120, 160]
        feature_extractor = resnet110()
    elif architecture == "DenseNet121":
        model_output_size = 1024
        epochs = 200
        milestones = [60, 120, 160]
        feature_extractor = densenet121()

        # Adapted resnet from:
        # https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py
        feature_extractor.conv1 = torch.nn.Conv2d(3,
                                                  64,
                                                  kernel_size=3,
                                                  stride=1,
                                                  padding=1,
                                                  bias=False)
        feature_extractor.maxpool = torch.nn.Identity()
        feature_extractor.fc = torch.nn.Identity()

    if centroid_size is None:
        centroid_size = model_output_size

    model = ResNet_DUQ(
        feature_extractor,
        num_classes,
        centroid_size,
        model_output_size,
        length_scale,
        gamma,
    )
    model = model.cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=learning_rate,
                                momentum=0.9,
                                weight_decay=weight_decay)

    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=milestones,
                                                     gamma=0.2)

    def calc_gradients_input(x, y_pred):
        gradients = torch.autograd.grad(
            outputs=y_pred,
            inputs=x,
            grad_outputs=torch.ones_like(y_pred),
            create_graph=True,
        )[0]

        gradients = gradients.flatten(start_dim=1)

        return gradients

    def calc_gradient_penalty(x, y_pred):
        gradients = calc_gradients_input(x, y_pred)

        # L2 norm
        grad_norm = gradients.norm(2, dim=1)

        # Two sided penalty
        gradient_penalty = ((grad_norm - 1)**2).mean()

        return gradient_penalty

    def step(engine, batch):
        model.train()

        optimizer.zero_grad()

        x, y = batch
        x, y = x.cuda(), y.cuda()

        x.requires_grad_(True)

        y_pred = model(x)

        y = F.one_hot(y, num_classes).float()

        loss = F.binary_cross_entropy(y_pred, y, reduction="mean")

        if l_gradient_penalty > 0:
            gp = calc_gradient_penalty(x, y_pred)
            loss += l_gradient_penalty * gp

        loss.backward()
        optimizer.step()

        x.requires_grad_(False)

        with torch.no_grad():
            model.eval()
            model.update_embeddings(x, y)

        return loss.item()

    def eval_step(engine, batch):
        model.eval()

        x, y = batch
        x, y = x.cuda(), y.cuda()

        x.requires_grad_(True)

        y_pred = model(x)

        return {"x": x, "y": y, "y_pred": y_pred}

    trainer = Engine(step)
    evaluator = Engine(eval_step)

    metric = Average()
    metric.attach(trainer, "loss")

    metric = Accuracy(output_transform=lambda out: (out["y_pred"], out["y"]))
    metric.attach(evaluator, "accuracy")

    def bce_output_transform(out):
        return (out["y_pred"], F.one_hot(out["y"], num_classes).float())

    metric = Loss(F.binary_cross_entropy,
                  output_transform=bce_output_transform)
    metric.attach(evaluator, "bce")

    metric = Loss(calc_gradient_penalty,
                  output_transform=lambda out: (out["x"], out["y_pred"]))
    metric.attach(evaluator, "gradient_penalty")

    pbar = ProgressBar(dynamic_ncols=True)
    pbar.attach(trainer)

    kwargs = {"num_workers": 4, "pin_memory": True}

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               drop_last=True,
                                               **kwargs)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             **kwargs)

    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=batch_size,
                                              shuffle=False,
                                              **kwargs)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_results(trainer):
        metrics = trainer.state.metrics
        loss = metrics["loss"]

        print(f"Train - Epoch: {trainer.state.epoch} Loss: {loss:.2f}")

        writer.add_scalar("Loss/train", loss, trainer.state.epoch)

        if trainer.state.epoch > (epochs - 5):
            accuracy, auroc = get_cifar_svhn_ood(model)
            print(f"Test Accuracy: {accuracy}, AUROC: {auroc}")
            writer.add_scalar("OoD/test_accuracy", accuracy,
                              trainer.state.epoch)
            writer.add_scalar("OoD/roc_auc", auroc, trainer.state.epoch)

            accuracy, auroc = get_auroc_classification(val_dataset, model)
            print(f"AUROC - uncertainty: {auroc}")
            writer.add_scalar("OoD/val_accuracy", accuracy,
                              trainer.state.epoch)
            writer.add_scalar("OoD/roc_auc_classification", auroc,
                              trainer.state.epoch)

        evaluator.run(val_loader)
        metrics = evaluator.state.metrics
        acc = metrics["accuracy"]
        bce = metrics["bce"]
        GP = metrics["gradient_penalty"]
        loss = bce + l_gradient_penalty * GP

        print((f"Valid - Epoch: {trainer.state.epoch} "
               f"Acc: {acc:.4f} "
               f"Loss: {loss:.2f} "
               f"BCE: {bce:.2f} "
               f"GP: {GP:.2f} "))

        writer.add_scalar("Loss/valid", loss, trainer.state.epoch)
        writer.add_scalar("BCE/valid", bce, trainer.state.epoch)
        writer.add_scalar("GP/valid", GP, trainer.state.epoch)
        writer.add_scalar("Accuracy/valid", acc, trainer.state.epoch)

        scheduler.step()

    trainer.run(train_loader, max_epochs=epochs)
    evaluator.run(test_loader)
    acc = evaluator.state.metrics["accuracy"]

    print(f"Test - Accuracy {acc:.4f}")

    torch.save(model.state_dict(), f"runs/{output_dir}/model.pt")
    writer.close()
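
The run above steps MultiStepLR once per epoch, so with milestones=[60, 120, 160] and gamma=0.2 the learning rate drops by a factor of 5 at each of those epochs. A small standalone sketch of the resulting schedule; the parameter, optimizer, and base learning rate below are placeholders, not values from the example:

import torch

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.SGD(params, lr=0.1)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                 milestones=[60, 120, 160],
                                                 gamma=0.2)

for epoch in range(1, 201):
    # ... one training epoch would run here ...
    scheduler.step()
    if epoch in (59, 60, 120, 160):
        print(epoch, optimizer.param_groups[0]["lr"])
# prints 0.1 at epoch 59, then 0.02, 0.004 and 0.0008 at epochs 60, 120 and 160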
Example #25
    def train(self, config, **kwargs):
        """Trains a model on the given configurations.
        :param config:str: A training configuration. Note that all parameters in the config can also be manually adjusted with --ARG=VALUE
        :param **kwargs: parameters to overwrite yaml config
        """

        from pycocoevalcap.cider.cider import Cider

        config_parameters = train_util.parse_config_or_kwargs(config, **kwargs)
        config_parameters["seed"] = self.seed
        zh = config_parameters["zh"]
        outputdir = os.path.join(
            config_parameters["outputpath"], config_parameters["model"],
            "{}_{}".format(
                datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'),
                uuid.uuid1().hex))

        # Early init because of creating dir
        checkpoint_handler = ModelCheckpoint(
            outputdir,
            "run",
            n_saved=1,
            require_empty=False,
            create_dir=True,
            score_function=lambda engine: -engine.state.metrics["loss"],
            score_name="loss")

        logger = train_util.genlogger(os.path.join(outputdir, "train.log"))
        # print passed config parameters
        logger.info("Storing files in: {}".format(outputdir))
        train_util.pprint_dict(config_parameters, logger.info)

        vocabulary = torch.load(config_parameters["vocab_file"])
        trainloader, cvloader, info = self._get_dataloaders(config_parameters, vocabulary)
        config_parameters["inputdim"] = info["inputdim"]
        logger.info("<== Estimating Scaler ({}) ==>".format(info["scaler"].__class__.__name__))
        logger.info(
                "Stream: {} Input dimension: {} Vocab Size: {}".format(
                config_parameters["feature_stream"], info["inputdim"], len(vocabulary)))
        train_key2refs = info["train_key2refs"]
        # train_scorer = BatchCider(train_key2refs)
        cv_key2refs = info["cv_key2refs"]
        # cv_scorer = BatchCider(cv_key2refs)

        model = self._get_model(config_parameters, vocabulary)
        model = model.to(device)
        train_util.pprint_dict(model, logger.info, formatter="pretty")
        optimizer = getattr(
            torch.optim, config_parameters["optimizer"]
        )(model.parameters(), **config_parameters["optimizer_args"])
        train_util.pprint_dict(optimizer, logger.info, formatter="pretty")

        # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            # optimizer, **config_parameters["scheduler_args"])
        crtrn_imprvd = train_util.criterion_improver(config_parameters["improvecriterion"])

        def _train_batch(engine, batch):
            model.train()
            with torch.enable_grad():
                optimizer.zero_grad()
                train_scorer = Cider(zh=zh)
                output = self._forward(model, batch, "train", train_mode="scst", 
                                       key2refs=train_key2refs, scorer=train_scorer)
                output["loss"].backward()
                optimizer.step()
                return output

        trainer = Engine(_train_batch)
        RunningAverage(output_transform=lambda x: x["loss"]).attach(trainer, "running_loss")
        pbar = ProgressBar(persist=False, ascii=True)
        pbar.attach(trainer, ["running_loss"])

        key2pred = {}

        def _inference(engine, batch):
            model.eval()
            keys = batch[2]
            with torch.no_grad():
                cv_scorer = Cider(zh=zh)
                output = self._forward(model, batch, "train", train_mode="scst",
                                       key2refs=cv_key2refs, scorer=cv_scorer)
                seqs = output["sampled_seqs"].cpu().numpy()
                for idx, seq in enumerate(seqs):
                    if keys[idx] in key2pred:
                        continue
                    candidate = self._convert_idx2sentence(seq, vocabulary, zh=zh)
                    key2pred[keys[idx]] = [candidate,]
                return output

        evaluator = Engine(_inference)

        metrics = {
            "loss": Average(output_transform=lambda x: x["loss"]),
            "reward": Average(output_transform=lambda x: x["reward"].reshape(-1, 1)),
        }

        for name, metric in metrics.items():
            metric.attach(trainer, name)
            metric.attach(evaluator, name)

        RunningAverage(output_transform=lambda x: x["loss"]).attach(evaluator, "running_loss")
        pbar.attach(evaluator, ["running_loss"])

        # @trainer.on(Events.STARTED)
        # def log_initial_result(engine):
            # evaluator.run(cvloader, max_epochs=1)
            # logger.info("Initial Results - loss: {:<5.2f}\tscore: {:<5.2f}".format(evaluator.state.metrics["loss"], evaluator.state.metrics["score"].item()))


        trainer.add_event_handler(
              Events.EPOCH_COMPLETED, train_util.log_results, evaluator, cvloader,
              logger.info, metrics.keys(), ["loss", "reward", "score"])

        def eval_cv(engine, key2pred, key2refs, scorer):
            # if len(cv_key2refs) == 0:
                # for key, _ in key2pred.items():
                    # cv_key2refs[key] = key2refs[key]
            score, scores = scorer.compute_score(key2refs, key2pred)
            engine.state.metrics["score"] = score
            key2pred.clear()

        evaluator.add_event_handler(
            Events.EPOCH_COMPLETED, eval_cv, key2pred, cv_key2refs, Cider(zh=zh))

        evaluator.add_event_handler(
            Events.EPOCH_COMPLETED, train_util.save_model_on_improved, crtrn_imprvd,
            "score", {
                "model": model,
                "config": config_parameters,
                "scaler": info["scaler"]
            }, os.path.join(outputdir, "saved.pth"))

        evaluator.add_event_handler(
            Events.EPOCH_COMPLETED, checkpoint_handler, {
                "model": model,
            }
        )

        trainer.run(trainloader, max_epochs=config_parameters["epochs"])
        return outputdir
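
The ModelCheckpoint in this example keeps the single best run by score, where score_function negates the validation loss so that "lowest loss" becomes "highest score". A stripped-down sketch of that wiring; the model, engine, and metric name below are placeholders rather than objects from the example:

import torch
from ignite.engine import Engine, Events
from ignite.handlers import ModelCheckpoint

model = torch.nn.Linear(4, 2)
evaluator = Engine(lambda engine, batch: None)

checkpoint_handler = ModelCheckpoint(
    "checkpoints", "run", n_saved=1, require_empty=False, create_dir=True,
    score_function=lambda engine: -engine.state.metrics["loss"],
    score_name="loss")
# saves (and replaces) the checkpoint whenever the evaluator finishes an epoch
# with a better score than the currently kept one
evaluator.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler, {"model": model})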
Example #26
def main():
    # region Setup
    conf = parse_args()
    setup_seeds(conf.session.seed)
    tb_logger, tb_img_logger, json_logger = setup_all_loggers(conf)
    logger.info("Parsed configuration:\n" +
                pyaml.dump(OmegaConf.to_container(conf),
                           safe=True,
                           sort_dicts=False,
                           force_embed=True))

    # region Predicate classification engines
    datasets, dataset_metadata = build_datasets(conf.dataset,
                                                seed=conf.session.seed)
    dataloaders = build_dataloaders(conf, datasets)

    model = build_model(conf.model,
                        dataset_metadata["train_gt"]).to(conf.session.device)
    criterion = PredicateClassificationCriterion(conf.losses)

    pred_class_trainer = Trainer(pred_class_training_step, conf)
    pred_class_trainer.model = model
    pred_class_trainer.criterion = criterion
    pred_class_trainer.optimizer, scheduler = build_optimizer_and_scheduler(
        conf.optimizer, pred_class_trainer.model)

    pred_class_validator = Validator(pred_class_validation_step, conf)
    pred_class_validator.model = model
    pred_class_validator.criterion = criterion

    pred_class_tester = Validator(pred_class_validation_step, conf)
    pred_class_tester.model = model
    pred_class_tester.criterion = criterion
    # endregion

    # region Visual Relations engines
    vr_model = VisualRelationExplainer(model, dataset_metadata["train_gt"],
                                       **conf.visual_relations)

    vr_predicate_validator = Validator(vr_validation_step, conf)
    vr_predicate_validator.model = vr_model

    vr_predicate_tester = Validator(vr_validation_step, conf)
    vr_predicate_tester.model = vr_model

    vr_phrase_and_relation_validator = Validator(vr_validation_step, conf)
    vr_phrase_and_relation_validator.model = vr_model

    vr_phrase_and_relation_tester = Validator(vr_validation_step, conf)
    vr_phrase_and_relation_tester.model = vr_model
    # endregion

    if "resume" in conf:
        checkpoint = Path(conf.resume.checkpoint).expanduser().resolve()
        logger.debug(f"Resuming checkpoint from {checkpoint}")
        Checkpoint.load_objects(
            {
                "model": pred_class_trainer.model,
                "optimizer": pred_class_trainer.optimizer,
                "scheduler": scheduler,
                "trainer": pred_class_trainer,
            },
            checkpoint=torch.load(checkpoint,
                                  map_location=conf.session.device),
        )
        logger.info(f"Resumed from {checkpoint}, "
                    f"epoch {pred_class_trainer.state.epoch}, "
                    f"samples {pred_class_trainer.global_step()}")
    # endregion

    # region Predicate classification training callbacks
    def increment_samples(trainer: Trainer):
        graphs: Batch = trainer.state.batch[0]
        trainer.state.samples += graphs.num_graphs

    pred_class_trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                         increment_samples)

    ProgressBar(persist=True, desc="[train] Predicate classification").attach(
        pred_class_trainer, output_transform=itemgetter("losses"))

    tb_logger.attach(
        pred_class_trainer,
        OptimizerParamsHandler(
            pred_class_trainer.optimizer,
            param_name="lr",
            tag="z",
            global_step_transform=pred_class_trainer.global_step,
        ),
        Events.EPOCH_STARTED,
    )

    pred_class_trainer.add_event_handler(
        Events.ITERATION_COMPLETED,
        PredicateClassificationMeanAveragePrecisionBatch())
    pred_class_trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                         RecallAtBatch(sizes=(5, 10)))

    tb_logger.attach(
        pred_class_trainer,
        OutputHandler(
            "train_gt",
            output_transform=lambda o: {
                **o["losses"],
                "pc/mAP": o["pc/mAP"].mean().item(),
                **{k: r.mean().item()
                   for k, r in o["recalls"].items()},
            },
            global_step_transform=pred_class_trainer.global_step,
        ),
        Events.ITERATION_COMPLETED,
    )

    pred_class_trainer.add_event_handler(
        Events.EPOCH_COMPLETED,
        log_metrics,
        "Predicate classification training",
        "train_gt",
        json_logger=None,
        tb_logger=tb_logger,
        global_step_fn=pred_class_trainer.global_step,
    )
    pred_class_trainer.add_event_handler(
        Events.EPOCH_COMPLETED,
        PredicateClassificationLogger(
            grid=(2, 3),
            tag="train_gt",
            logger=tb_img_logger.writer,
            metadata=dataset_metadata["train_gt"],
            global_step_fn=pred_class_trainer.global_step,
        ),
    )
    tb_logger.attach(
        pred_class_trainer,
        EpochHandler(
            pred_class_trainer,
            tag="z",
            global_step_transform=pred_class_trainer.global_step,
        ),
        Events.EPOCH_COMPLETED,
    )

    pred_class_trainer.add_event_handler(
        Events.EPOCH_COMPLETED,
        lambda _: pred_class_validator.run(dataloaders["val_gt"]),
    )
    pred_class_trainer.add_event_handler(
        Events.EPOCH_COMPLETED(epoch_filter(5)),
        lambda _: vr_predicate_validator.run(dataloaders["val_gt"]),
    )
    pred_class_trainer.add_event_handler(
        Events.EPOCH_COMPLETED(epoch_filter(2)),
        lambda _: vr_phrase_and_relation_validator.run(dataloaders["val_d2"]),
    )
    # endregion

    # region Predicate classification validation callbacks
    ProgressBar(
        persist=True,
        desc="[val] Predicate classification").attach(pred_class_validator)

    if conf.losses["bce"]["weight"] > 0:
        Average(output_transform=lambda o: o["losses"]["loss/bce"]).attach(
            pred_class_validator, "loss/bce")
    if conf.losses["rank"]["weight"] > 0:
        Average(output_transform=lambda o: o["losses"]["loss/rank"]).attach(
            pred_class_validator, "loss/rank")
    Average(output_transform=lambda o: o["losses"]["loss/total"]).attach(
        pred_class_validator, "loss/total")

    PredicateClassificationMeanAveragePrecisionEpoch(
        itemgetter("target", "output")).attach(pred_class_validator, "pc/mAP")
    RecallAtEpoch((5, 10),
                  itemgetter("target",
                             "output")).attach(pred_class_validator,
                                               "pc/recall_at")

    pred_class_validator.add_event_handler(
        Events.EPOCH_COMPLETED,
        lambda val_engine: scheduler.step(val_engine.state.metrics["loss/total"]),
    )
    pred_class_validator.add_event_handler(
        Events.EPOCH_COMPLETED,
        log_metrics,
        "Predicate classification validation",
        "val_gt",
        json_logger,
        tb_logger,
        pred_class_trainer.global_step,
    )
    pred_class_validator.add_event_handler(
        Events.EPOCH_COMPLETED,
        PredicateClassificationLogger(
            grid=(2, 3),
            tag="val_gt",
            logger=tb_img_logger.writer,
            metadata=dataset_metadata["val_gt"],
            global_step_fn=pred_class_trainer.global_step,
        ),
    )
    pred_class_validator.add_event_handler(
        Events.COMPLETED,
        EarlyStopping(
            patience=conf.session.early_stopping.patience,
            score_function=lambda val_engine: -val_engine.state.metrics[
                "loss/total"],
            trainer=pred_class_trainer,
        ),
    )
    pred_class_validator.add_event_handler(
        Events.COMPLETED,
        Checkpoint(
            {
                "model": pred_class_trainer.model,
                "optimizer": pred_class_trainer.optimizer,
                "scheduler": scheduler,
                "trainer": pred_class_trainer,
            },
            DiskSaver(
                Path(conf.checkpoint.folder).expanduser().resolve() /
                conf.fullname),
            score_function=lambda val_engine: val_engine.state.metrics[
                "pc/recall_at_5"],
            score_name="pc_recall_at_5",
            n_saved=conf.checkpoint.keep,
            global_step_transform=pred_class_trainer.global_step,
        ),
    )
    # endregion

    # region Predicate detection validation callbacks
    vr_predicate_validator.add_event_handler(
        Events.ITERATION_COMPLETED,
        VisualRelationRecallAt(type="predicate", steps=(50, 100)),
    )
    for step in (50, 100):
        Average(output_transform=itemgetter(
            f"vr/predicate/recall_at_{step}")).attach(
                vr_predicate_validator, f"vr/predicate/recall_at_{step}")

    if conf.dataset.name in {"unrel", "unrel_vrd"}:
        UnRelDetectionMeanAvgPrecision("GT").attach(vr_predicate_validator,
                                                    "vr/unrel/mAP")

    ProgressBar(
        persist=True,
        desc="[val] Predicate detection").attach(vr_predicate_validator)

    vr_predicate_validator.add_event_handler(
        Events.EPOCH_COMPLETED,
        log_metrics,
        "Predicate detection validation",
        "val_gt",
        json_logger,
        tb_logger,
        pred_class_trainer.global_step,
    )
    vr_predicate_validator.add_event_handler(
        Events.EPOCH_COMPLETED,
        VisualRelationPredictionLogger(
            grid=(2, 3),
            data_root=conf.dataset.folder,
            tag="with GT boxes",
            logger=tb_img_logger.writer,
            top_x_relations=conf.visual_relations.top_x_relations,
            metadata=dataset_metadata["val_gt"],
            global_step_fn=pred_class_trainer.global_step,
        ),
    )
    # endregion

    # region Phrase and relationship detection validation callbacks
    ProgressBar(persist=True,
                desc="[val] Phrase and relation detection").attach(
                    vr_phrase_and_relation_validator)
    vr_phrase_and_relation_validator.add_event_handler(
        Events.ITERATION_COMPLETED,
        VisualRelationRecallAt(type="phrase", steps=(50, 100)),
    )
    vr_phrase_and_relation_validator.add_event_handler(
        Events.ITERATION_COMPLETED,
        VisualRelationRecallAt(type="relationship", steps=(50, 100)),
    )
    for name in ["phrase", "relationship"]:
        for step in (50, 100):
            Average(output_transform=itemgetter(
                f"vr/{name}/recall_at_{step}")).attach(
                    vr_phrase_and_relation_validator,
                    f"vr/{name}/recall_at_{step}")
    if conf.dataset.name == "hico":
        HoiClassificationMeanAvgPrecision().attach(
            vr_phrase_and_relation_validator, "pc/hoi/mAP")
        HoiDetectionMeanAvgPrecision().attach(vr_phrase_and_relation_validator,
                                              "vr/hoi/mAP")
    elif conf.dataset.name in {"unrel", "unrel_vrd"}:
        UnRelDetectionMeanAvgPrecision("D2").attach(
            vr_phrase_and_relation_validator, "vr/unrel/mAP")

    vr_phrase_and_relation_validator.add_event_handler(
        Events.EPOCH_COMPLETED,
        log_metrics,
        "Phrase and relationship detection validation",
        "val_d2",
        json_logger,
        tb_logger,
        pred_class_trainer.global_step,
    )
    vr_phrase_and_relation_validator.add_event_handler(
        Events.EPOCH_COMPLETED,
        VisualRelationPredictionLogger(
            grid=(2, 3),
            data_root=conf.dataset.folder,
            tag="with D2 boxes",
            logger=tb_img_logger.writer,
            top_x_relations=conf.visual_relations.top_x_relations,
            metadata=dataset_metadata["val_d2"],
            global_step_fn=pred_class_trainer.global_step,
        ),
    )
    # endregion

    if "test" in conf.dataset:
        # region Predicate classification testing callbacks
        if conf.losses["bce"]["weight"] > 0:
            Average(
                output_transform=lambda o: o["losses"]["loss/bce"],
                device=conf.session.device,
            ).attach(pred_class_tester, "loss/bce")
        if conf.losses["rank"]["weight"] > 0:
            Average(
                output_transform=lambda o: o["losses"]["loss/rank"],
                device=conf.session.device,
            ).attach(pred_class_tester, "loss/rank")
        Average(
            output_transform=lambda o: o["losses"]["loss/total"],
            device=conf.session.device,
        ).attach(pred_class_tester, "loss/total")

        PredicateClassificationMeanAveragePrecisionEpoch(
            itemgetter("target", "output")).attach(pred_class_tester, "pc/mAP")
        RecallAtEpoch((5, 10),
                      itemgetter("target",
                                 "output")).attach(pred_class_tester,
                                                   "pc/recall_at")

        ProgressBar(
            persist=True,
            desc="[test] Predicate classification").attach(pred_class_tester)

        pred_class_tester.add_event_handler(
            Events.EPOCH_COMPLETED,
            log_metrics,
            "Predicate classification test",
            "test_gt",
            json_logger,
            tb_logger,
            pred_class_trainer.global_step,
        )
        pred_class_tester.add_event_handler(
            Events.EPOCH_COMPLETED,
            PredicateClassificationLogger(
                grid=(2, 3),
                tag="test_gt",
                logger=tb_img_logger.writer,
                metadata=dataset_metadata["test_gt"],
                global_step_fn=pred_class_trainer.global_step,
            ),
        )
        # endregion

        # region Predicate detection testing callbacks
        vr_predicate_tester.add_event_handler(
            Events.ITERATION_COMPLETED,
            VisualRelationRecallAt(type="predicate", steps=(50, 100)),
        )
        for step in (50, 100):
            Average(output_transform=itemgetter(
                f"vr/predicate/recall_at_{step}")).attach(
                    vr_predicate_tester, f"vr/predicate/recall_at_{step}")

        if ("test" in conf.dataset and conf.dataset.test.name
                == "vrd_relationship_detection_zero_shot"):
            vr_predicate_tester.add_event_handler(
                Events.ITERATION_COMPLETED,
                VrdZeroShotVisualRelationRecallAt(type="predicate",
                                                  steps=(50, 100)),
            )
            for step in (50, 100):
                Average(output_transform=itemgetter(
                    f"vr/predicate/zero_shot/recall_at_{step}")).attach(
                        vr_predicate_tester,
                        f"vr/predicate/zero_shot/recall_at_{step}")

        if conf.dataset.name in {"unrel", "unrel_vrd"}:
            UnRelDetectionMeanAvgPrecision("GT").attach(
                vr_predicate_tester, "vr/unrel/mAP")

        ProgressBar(
            persist=True,
            desc="[test] Predicate detection").attach(vr_predicate_tester)

        vr_predicate_tester.add_event_handler(
            Events.EPOCH_COMPLETED,
            log_metrics,
            "Predicate detection test",
            "test_gt",
            json_logger,
            tb_logger,
            pred_class_trainer.global_step,
        )
        vr_predicate_tester.add_event_handler(
            Events.EPOCH_COMPLETED,
            VisualRelationPredictionLogger(
                grid=(2, 3),
                data_root=conf.dataset.folder,
                tag="with GT boxes",
                logger=tb_img_logger.writer,
                top_x_relations=conf.visual_relations.top_x_relations,
                metadata=dataset_metadata["test_gt"],
                global_step_fn=pred_class_trainer.global_step,
            ),
        )
        vr_predicate_tester.add_event_handler(
            Events.ITERATION_COMPLETED,
            VisualRelationPredictionExporter(
                mode="GT", dest=Path(conf.checkpoint.folder) / conf.fullname),
        )
        # endregion

        # region Phrase and relationship detection testing callbacks
        ProgressBar(persist=True,
                    desc="[test] Phrase and relation detection").attach(
                        vr_phrase_and_relation_tester)
        vr_phrase_and_relation_tester.add_event_handler(
            Events.ITERATION_COMPLETED,
            VisualRelationRecallAt(type="phrase", steps=(50, 100)),
        )
        vr_phrase_and_relation_tester.add_event_handler(
            Events.ITERATION_COMPLETED,
            VisualRelationRecallAt(type="relationship", steps=(50, 100)),
        )
        for name in ["phrase", "relationship"]:
            for step in (50, 100):
                Average(output_transform=itemgetter(
                    f"vr/{name}/recall_at_{step}")).attach(
                        vr_phrase_and_relation_tester,
                        f"vr/{name}/recall_at_{step}")

        if conf.dataset.name == "hico":
            HoiClassificationMeanAvgPrecision().attach(
                vr_phrase_and_relation_tester, "pc/hoi/mAP")
            HoiDetectionMeanAvgPrecision().attach(
                vr_phrase_and_relation_tester, "vr/hoi/mAP")
        elif conf.dataset.name in {"unrel", "unrel_vrd"}:
            UnRelDetectionMeanAvgPrecision("D2").attach(
                vr_phrase_and_relation_tester, "vr/unrel/mAP")
        if ("test" in conf.dataset and conf.dataset.test.name
                == "vrd_relationship_detection_zero_shot"):
            vr_phrase_and_relation_tester.add_event_handler(
                Events.ITERATION_COMPLETED,
                VrdZeroShotVisualRelationRecallAt(type="phrase",
                                                  steps=(50, 100)),
            )
            vr_phrase_and_relation_tester.add_event_handler(
                Events.ITERATION_COMPLETED,
                VrdZeroShotVisualRelationRecallAt(type="relationship",
                                                  steps=(50, 100)),
            )
            for name in ["phrase", "relationship"]:
                for step in (50, 100):
                    Average(output_transform=itemgetter(
                        f"vr/{name}/zero_shot/recall_at_{step}")).attach(
                            vr_phrase_and_relation_tester,
                            f"vr/{name}/zero_shot/recall_at_{step}",
                        )

        vr_phrase_and_relation_tester.add_event_handler(
            Events.EPOCH_COMPLETED,
            log_metrics,
            "Phrase and relationship detection test",
            "test_d2",
            json_logger,
            tb_logger,
            pred_class_trainer.global_step,
        )
        vr_phrase_and_relation_tester.add_event_handler(
            Events.EPOCH_COMPLETED,
            VisualRelationPredictionLogger(
                grid=(2, 3),
                data_root=conf.dataset.folder,
                tag="with D2 boxes",
                logger=tb_img_logger.writer,
                top_x_relations=conf.visual_relations.top_x_relations,
                metadata=dataset_metadata["test_d2"],
                global_step_fn=pred_class_trainer.global_step,
            ),
        )
        vr_phrase_and_relation_tester.add_event_handler(
            Events.ITERATION_COMPLETED,
            VisualRelationPredictionExporter(
                mode="D2", dest=Path(conf.checkpoint.folder) / conf.fullname),
        )
        # endregion

    # region Run
    log_effective_config(conf, pred_class_trainer, tb_logger)
    if not ("resume" in conf and conf.resume.test_only):
        if conf.dataset.eager:
            for n, d in datasets.items():
                if n.startswith(("train", "val")):
                    d.load_eager()
        max_epochs = conf.session.max_epochs
        if "resume" in conf:
            max_epochs += pred_class_trainer.state.epoch
        pred_class_trainer.run(
            dataloaders["train_gt"],
            max_epochs=max_epochs,
            seed=conf.session.seed,
            epoch_length=len(dataloaders["train_gt"]),
        )

    if "test" in conf.dataset:
        if conf.dataset.eager:
            datasets["test_gt"].load_eager()
            datasets["test_d2"].load_eager()
        pred_class_tester.run(dataloaders["test_gt"])
        vr_predicate_tester.run(dataloaders["test_gt"])
        vr_phrase_and_relation_tester.run(dataloaders["test_d2"])

    add_session_end(tb_logger.writer, "SUCCESS")
    tb_logger.close()
    tb_img_logger.close()
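
The early-stopping hookup above follows ignite's usual pattern: the handler lives on the validation engine and terminates the trainer once the monitored score stops improving for a given number of validation runs. A stripped-down sketch of that wiring; both engines and the metric name are placeholders:

from ignite.engine import Engine, Events
from ignite.handlers import EarlyStopping

trainer = Engine(lambda engine, batch: None)
validator = Engine(lambda engine, batch: None)

# EarlyStopping expects "higher is better", so the validation loss is negated
early_stopping = EarlyStopping(
    patience=5,
    score_function=lambda engine: -engine.state.metrics["loss/total"],
    trainer=trainer,
)
validator.add_event_handler(Events.COMPLETED, early_stopping)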