Example #1
def prune_train_loop(model,
                     params,
                     ds,
                     dset,
                     min_y,
                     base_data,
                     model_id,
                     prune_type,
                     device,
                     batch_size,
                     tpa,
                     max_epochs=2):
    assert prune_type in ['global_unstructured', 'structured']
    total_prune_amount = tpa
    ds_train, ds_valid = ds
    train_set, valid_set = dset
    min_y_train, min_y_val = min_y
    model_id = f'{model_id}_{prune_type}_pruning_{tpa}'
    valid_freq = 200 * 500 // batch_size // 3

    conv_layers = [model.conv1]

    def prune_model(model):
        #         remove_amount = total_prune_amount // (max_epochs)
        remove_amount = total_prune_amount
        print(f'pruned model by {remove_amount}')
        worst = select_filters(model, ds_valid, valid_set, remove_amount,
                               device)
        worst = list(
            Counter(torch.stack(worst).view(-1).cpu().numpy()).keys())
        worst.sort(reverse=True)
        print(worst)
        for layer in conv_layers:
            for d in worst:
                TuckerStructured(layer, name='weight', amount=0, dim=0, filt=d)
        return worst

    bad = prune_model(model)
    zeros = []
    wrong = []
    for i in range(len(model.conv1.weight_mask)):
        if torch.sum(model.conv1.weight_mask[i]) == 0.0:
            zeros.append(i)
    zeros.sort(reverse=True)
    if zeros == bad:
        print("correctly zero'd filters")
    else:
        if len(zeros) == len(bad):
            for i in range(len(zeros)):
                if zeros[i] != bad[i]:
                    wrong.append((bad[i], zeros[i]))
            print(wrong)
        else:
            print("diff number filters zero'd", zeros)
    with create_summary_writer(model,
                               ds_train,
                               base_data,
                               model_id,
                               device=device) as writer:
        lr = params['lr']
        mom = params['momentum']
        wd = params['l2_wd']
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=lr,
                                    momentum=mom,
                                    weight_decay=wd)
        sched = ReduceLROnPlateau(optimizer, factor=0.5, patience=5)
        funcs = {'accuracy': Accuracy(), 'loss': Loss(F.cross_entropy)}
        loss = funcs['loss']._loss_fn

        acc_metric = Accuracy(device=device)
        loss_metric = Loss(F.cross_entropy, device=device)

        acc_val_metric = Accuracy(device=device)
        loss_val_metric = Loss(F.cross_entropy, device=device)

        def train_step(engine, batch):
            model.train()
            x, y = batch
            x = x.to(device)
            y = y.to(device) - min_y_train
            optimizer.zero_grad()
            ans = model.forward(x)
            l = loss(ans, y)
            l.backward()
            optimizer.step()
            with torch.no_grad():
                for layer in conv_layers:
                    layer.weight *= layer.weight_mask  # make sure pruned weights stay 0
            return l.item()

        trainer = Engine(train_step)

        def train_eval_step(engine, batch):
            model.eval()
            x, y = batch
            x = x.to(device)
            y = y.to(device) - min_y_train
            with torch.no_grad():
                ans = model.forward(x)
            return ans, y

        train_evaluator = Engine(train_eval_step)
        acc_metric.attach(train_evaluator, "accuracy")
        loss_metric.attach(train_evaluator, 'loss')

        def validation_step(engine, batch):
            model.eval()
            x, y = batch
            x = x.to(device)
            y = y.to(device) - min_y_val
            with torch.no_grad():
                ans = model.forward(x)
            return ans, y

        valid_evaluator = Engine(validation_step)
        acc_val_metric.attach(valid_evaluator, "accuracy")
        loss_val_metric.attach(valid_evaluator, 'loss')

        @trainer.on(Events.ITERATION_COMPLETED(every=valid_freq))
        #         @trainer.on(Events.ITERATION_COMPLETED)
        def log_validation_results(engine):
            valid_evaluator.run(ds_valid)
            metrics = valid_evaluator.state.metrics
            valid_avg_accuracy = metrics['accuracy']
            avg_nll = metrics['loss']
            print(
                "Validation Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}"
                .format(engine.state.epoch, valid_avg_accuracy, avg_nll))
            writer.add_scalar("validation/avg_loss", avg_nll,
                              engine.state.epoch)
            writer.add_scalar("validation/avg_accuracy", valid_avg_accuracy,
                              engine.state.epoch)
            writer.add_scalar("validation/avg_error", 1. - valid_avg_accuracy,
                              engine.state.epoch)

#             prune_model(model)

        @trainer.on(Events.EPOCH_COMPLETED)
        def lr_scheduler(engine):
            metrics = valid_evaluator.state.metrics
            val_accuracy = metrics['accuracy']
            sched.step(val_accuracy)

        @trainer.on(Events.ITERATION_COMPLETED(every=100))
        def log_training_loss(engine):
            batch = engine.state.batch
            ds = DataLoader(TensorDataset(*batch), batch_size=batch_size)
            train_evaluator.run(ds)
            metrics = train_evaluator.state.metrics
            accuracy = metrics['accuracy']
            nll = metrics['loss']
            iter = (engine.state.iteration - 1) % len(ds_train) + 1
            if (iter % 100) == 0:
                print("Epoch[{}] Iter[{}/{}] Accuracy: {:.2f} Loss: {:.2f}".
                      format(engine.state.epoch, iter, len(ds_train), accuracy,
                             nll))
            writer.add_scalar("batchtraining/detloss", nll, engine.state.epoch)
            writer.add_scalar("batchtraining/accuracy", accuracy,
                              engine.state.iteration)
            writer.add_scalar("batchtraining/error", 1. - accuracy,
                              engine.state.iteration)
            writer.add_scalar("batchtraining/loss", engine.state.output,
                              engine.state.iteration)

        @trainer.on(Events.EPOCH_COMPLETED)
        def log_lr(engine):
            writer.add_scalar("lr", optimizer.param_groups[0]['lr'],
                              engine.state.epoch)

        @trainer.on(Events.ITERATION_COMPLETED(every=valid_freq))
        def validation_value(engine):
            metrics = valid_evaluator.state.metrics
            valid_avg_accuracy = metrics['accuracy']
            return valid_avg_accuracy

        to_save = {'model': model}
        handler = Checkpoint(
            to_save,
            DiskSaver(os.path.join(base_data, model_id), create_dir=True),
            score_function=validation_value,
            score_name="val_acc",
            global_step_transform=global_step_from_engine(trainer),
            n_saved=None)

        # kick everything off
        trainer.add_event_handler(Events.ITERATION_COMPLETED(every=valid_freq),
                                  handler)
        trainer.run(ds_train, max_epochs=max_epochs)
Example #2
        loss_v.backward()
        optimizer.step()

        epsilon_tracker.frame(engine.state.iteration)

        if engine.state.iteration % EVAL_EVERY_FRAME == 0:
            eval_states = getattr(engine.state, "eval_states", None)

            if eval_states is None:
                eval_states = buffer.sample(STATES_TO_EVALUATE)
                eval_states = [
                    np.array(transition.state, copy=False)
                    for transition in eval_states
                ]
                eval_states = np.array(eval_states, copy=False)
                engine.state.eval_states = eval_states  # must match the attribute read by getattr above

            evaluate_states(eval_states, net, device, engine)

        return {"loss": loss_v.item(), "epsilon": selector.epsilon}

    engine = Engine(process_batch)
    common.setup_ignite(engine,
                        params,
                        exp_source,
                        NAME,
                        extra_metrics=("adv", "val"))
    engine.run(
        common.batch_generator(buffer, params.replay_initial,
                               params.batch_size))
Example #3
def test_attach_fail_with_string():
    engine = Engine(update_fn)
    pbar = ProgressBar()

    with pytest.raises(TypeError):
        pbar.attach(engine, "a")
Example #4
    def train(self,
              model,
              crit,
              optimizer,
              train_loader,
              valid_loader,
              src_vocab,
              tgt_vocab,
              n_epochs,
              lr_scheduler=None):
        trainer = Engine(self.step)
        trainer.config = self.config
        trainer.model, trainer.crit = model, crit
        trainer.optimizer, trainer.lr_scheduler = optimizer, lr_scheduler
        trainer.epoch_idx = 0

        evaluator = Engine(self.validate)
        evaluator.config = self.config
        evaluator.model, evaluator.crit = model, crit
        evaluator.best_loss = np.inf

        self.attach(trainer, evaluator, verbose=self.config.verbose)

        def run_validation(engine, evaluator, valid_loader):
            evaluator.run(valid_loader, max_epochs=1)

            if engine.lr_scheduler is not None and not engine.config.use_noam_decay:
                engine.lr_scheduler.step()

        trainer.add_event_handler(Events.EPOCH_COMPLETED, run_validation,
                                  evaluator, valid_loader)
        evaluator.add_event_handler(Events.EPOCH_COMPLETED, self.check_best)
        evaluator.add_event_handler(
            Events.EPOCH_COMPLETED,
            self.save_model,
            trainer,
            self.config,
            src_vocab,
            tgt_vocab,
        )

        trainer.run(train_loader, max_epochs=n_epochs)

        return model
Example #5
def train():
    parser = ArgumentParser()
    parser.add_argument("--dataset_path", type=str, default="", help="Path or url of the dataset. If empty download from S3.")
    parser.add_argument("--dataset_cache", type=str, default='./dataset_cache', help="Path or url of the dataset cache")
    parser.add_argument("--model_checkpoint", type=str, default="openai-gpt", help="Path, url or short name of the model")
    parser.add_argument("--num_candidates", type=int, default=2, help="Number of candidates for training")
    parser.add_argument("--max_history", type=int, default=2, help="Number of previous exchanges to keep in history")
    parser.add_argument("--train_batch_size", type=int, default=4, help="Batch size for training")
    parser.add_argument("--valid_batch_size", type=int, default=4, help="Batch size for validation")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=8, help="Accumulate gradients on several steps")
    parser.add_argument("--lr", type=float, default=6.25e-5, help="Learning rate")
    parser.add_argument("--lm_coef", type=float, default=1.0, help="LM loss coefficient")
    parser.add_argument("--mc_coef", type=float, default=1.0, help="Multiple-choice loss coefficient")
    parser.add_argument("--max_norm", type=float, default=1.0, help="Clipping gradient norm")
    parser.add_argument("--n_epochs", type=int, default=3, help="Number of training epochs")
    parser.add_argument("--personality_permutations", type=int, default=1, help="Number of permutations of personality sentences")
    parser.add_argument("--eval_before_start", action='store_true', help="If true start with a first evaluation before training")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device (cuda or cpu)")
    parser.add_argument("--fp16", type=str, default="", help="Set to O0, O1, O2 or O3 for fp16 training (see apex documentation)")
    parser.add_argument("--local_rank", type=int, default=-1, help="Local rank for distributed training (-1: not distributed)")
    args = parser.parse_args()

    # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process. logger.info => log main process only, logger.warning => log all processes
    logging.basicConfig(level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning("Running process %d", args.local_rank)  # This is a logger.warning: it will be printed by all distributed processes
    logger.info("Arguments: %s", pformat(args))

    # Initialize distributed training if needed
    args.distributed = (args.local_rank != -1)
    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        args.device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')

    logger.info("Prepare tokenizer, pretrained model and optimizer - add special tokens for fine-tuning")
    tokenizer_class = GPT2Tokenizer if "gpt2" in args.model_checkpoint else OpenAIGPTTokenizer
    tokenizer = tokenizer_class.from_pretrained(args.model_checkpoint)
    model_class = GPT2DoubleHeadsModel if "gpt2" in args.model_checkpoint else OpenAIGPTDoubleHeadsModel
    model = model_class.from_pretrained(args.model_checkpoint)
    tokenizer.set_special_tokens(SPECIAL_TOKENS)
    model.set_num_special_tokens(len(SPECIAL_TOKENS))
    model.to(args.device)
    optimizer = OpenAIAdam(model.parameters(), lr=args.lr)

    # Prepare model for FP16 and distributed training if needed (order is important, distributed should be the last)
    if args.fp16:
        from apex import amp  # Apex is only required if we use fp16 training
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16)
    if args.distributed:
        model = DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank)

    logger.info("Prepare datasets")
    train_loader, val_loader, train_sampler, valid_sampler = get_data_loaders(args, tokenizer)

    # Training function and trainer
    def update(engine, batch):
        model.train()
        batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
        lm_loss, mc_loss = model(*batch)
        loss = (lm_loss * args.lm_coef + mc_loss * args.mc_coef) / args.gradient_accumulation_steps
        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_norm)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
        if engine.state.iteration % args.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        return loss.item()
    trainer = Engine(update)

    # Evaluation function and evaluator (evaluator output is the input of the metrics)
    def inference(engine, batch):
        model.eval()
        with torch.no_grad():
            batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
            input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids = batch
            logger.info(tokenizer.decode(input_ids[0, -1, :].tolist()))
            model_outputs = model(input_ids, mc_token_ids, token_type_ids=token_type_ids)
            lm_logits, mc_logits = model_outputs[0], model_outputs[1]  # So we can also use GPT2 outputs
            lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view(-1, lm_logits.size(-1))
            lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1)
            return (lm_logits_flat_shifted, mc_logits), (lm_labels_flat_shifted, mc_labels)
    evaluator = Engine(inference)

    # Attach evaluation to trainer: we evaluate when we start the training and at the end of each epoch
    trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: evaluator.run(val_loader))
    if args.n_epochs < 1:
        trainer.add_event_handler(Events.COMPLETED, lambda _: evaluator.run(val_loader))
    if args.eval_before_start:
        trainer.add_event_handler(Events.STARTED, lambda _: evaluator.run(val_loader))

    # Make sure distributed data samplers split the dataset nicely between the distributed processes
    if args.distributed:
        trainer.add_event_handler(Events.EPOCH_STARTED, lambda engine: train_sampler.set_epoch(engine.state.epoch))
        evaluator.add_event_handler(Events.EPOCH_STARTED, lambda engine: valid_sampler.set_epoch(engine.state.epoch))

    # Linearly decrease the learning rate from lr to zero
    scheduler = PiecewiseLinear(optimizer, "lr", [(0, args.lr), (args.n_epochs * len(train_loader), 0.0)])
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # Prepare metrics - note how we compute distributed metrics 
    RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
    metrics = {"nll": Loss(torch.nn.CrossEntropyLoss(ignore_index=-1), output_transform=lambda x: (x[0][0], x[1][0])),
               "accuracy": Accuracy(output_transform=lambda x: (x[0][1], x[1][1]))}
    metrics.update({"average_nll": MetricsLambda(average_distributed_scalar, metrics["nll"], args),
                    "average_accuracy": MetricsLambda(average_distributed_scalar, metrics["accuracy"], args)})
    metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"])
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # On the main process: add progress bar, tensorboard, checkpoints and save model, configuration and tokenizer before we start to train
    if args.local_rank in [-1, 0]:
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer, metric_names=["loss"])
        evaluator.add_event_handler(Events.COMPLETED, lambda _: pbar.log_message("Validation: %s" % pformat(evaluator.state.metrics)))

        tb_logger = TensorboardLogger(log_dir=None)
        tb_logger.attach(trainer, log_handler=OutputHandler(tag="training", metric_names=["loss"]), event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer, log_handler=OptimizerParamsHandler(optimizer), event_name=Events.ITERATION_STARTED)
        tb_logger.attach(evaluator, log_handler=OutputHandler(tag="validation", metric_names=list(metrics.keys()), another_engine=trainer), event_name=Events.EPOCH_COMPLETED)

        checkpoint_handler = ModelCheckpoint(tb_logger.writer.logdir, 'checkpoint', save_interval=1, n_saved=3)
        trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler, {'mymodel': getattr(model, 'module', model)})  # "getattr" take care of distributed encapsulation

        torch.save(args, tb_logger.writer.logdir + '/model_training_args.bin')
        getattr(model, 'module', model).config.to_json_file(os.path.join(tb_logger.writer.logdir, CONFIG_NAME))
        tokenizer.save_vocabulary(tb_logger.writer.logdir)

    # Run the training
    trainer.run(train_loader, max_epochs=args.n_epochs)

    # On the main process: close tensorboard logger and rename the last checkpoint (for easy re-loading with OpenAIGPTModel.from_pretrained method)
    if args.local_rank in [-1, 0] and args.n_epochs > 0:
        os.rename(checkpoint_handler._saved[-1][1][-1], os.path.join(tb_logger.writer.logdir, WEIGHTS_NAME))  # TODO: PR in ignite to have better access to saved file paths (cleaner)
        tb_logger.close()
Example #6
        eps_tracker.frame(engine.state.iteration)

        if getattr(engine.state, "eval_states", None) is None:
            eval_states = buffer.sample(STATES_TO_EVALUATE)
            eval_states = [
                np.array(transition.state, copy=False)
                for transition in eval_states
            ]
            engine.state.eval_states = np.array(eval_states, copy=False)

        return {
            "loss": loss_v.item(),
            "epsilon": selector.epsilon,
        }

    engine = Engine(process_batch)
    tb = common.setup_ignite(engine,
                             exp_source,
                             f"conv-{args.run}",
                             extra_metrics=("values_mean", ))

    @engine.on(ptan.ignite.PeriodEvents.ITERS_1000_COMPLETED)
    def sync_eval(engine: Engine):
        tgt_net.sync()

        mean_val = common.calc_values_of_states(engine.state.eval_states,
                                                net,
                                                device=device)
        engine.state.metrics["values_mean"] = mean_val
        if getattr(engine.state, "best_mean_val", None) is None:
            engine.state.best_mean_val = mean_val
Example #7
def main(args):
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
    cfg = load_config(args.config)

    path, config_name = os.path.split(args.config)
    copyfile(args.config, os.path.join(cfg.workdir, config_name))
    copyfile(os.path.join(path, "model.py"), os.path.join(cfg.workdir, "model.py"))

    pbar = ProgressBar()
    tb_logger = TensorboardLogger(log_dir=os.path.join(cfg.workdir, "tb_logs"))
    checkpointer = ModelCheckpoint(os.path.join(cfg.workdir, "checkpoints"), '', 
                                   save_interval=1, n_saved=cfg.n_epochs, create_dir=True, atomic=True)

    def _update(engine, batch):
        cfg.model.train()
        cfg.optimizer.zero_grad()
        x, y = cfg.prepare_train_batch(batch)
        y_pred = cfg.model(**x)
        loss = cfg.loss_fn(y_pred, y)
        loss['loss'].backward()
        cfg.optimizer.step()
        for k in loss:
            loss[k] = loss[k].item()
        return loss
    trainer = Engine(_update)

    pbar.attach(trainer, output_transform=lambda x: {k: "{:.5f}".format(v) for k, v in x.items()})
    trainer.add_event_handler(Events.ITERATION_STARTED, cfg.scheduler)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {'model': cfg.model, 'optimizer': cfg.optimizer})
    tb_logger.attach(trainer,
                     log_handler=OutputHandler(tag="training", output_transform=lambda x: x),
                     event_name=Events.ITERATION_COMPLETED)
    tb_logger.attach(trainer,
                     log_handler=OptimizerParamsHandler(cfg.optimizer),
                     event_name=Events.ITERATION_STARTED)
    # tb_logger.attach(trainer,
    #                  log_handler=WeightsScalarHandler(cfg.model),
    #                  event_name=Events.ITERATION_COMPLETED)
    # tb_logger.attach(trainer,
    #                  log_handler=WeightsHistHandler(cfg.model),
    #                  event_name=Events.EPOCH_COMPLETED)
    # tb_logger.attach(trainer,
    #                  log_handler=GradsScalarHandler(cfg.model),
    #                  event_name=Events.ITERATION_COMPLETED)
    # tb_logger.attach(trainer,
    #                  log_handler=GradsHistHandler(cfg.model),
    #                  event_name=Events.EPOCH_COMPLETED)

    def _evaluate(engine, batch):
        cfg.model.eval()
        x, y = cfg.prepare_train_batch(batch)
        batch_size = len(batch[list(batch.keys())[0]])
        with torch.no_grad():
            y_pred = cfg.model(**x)
            loss = cfg.loss_fn(y_pred, y)
        for k in loss:
            loss[k] = loss[k].item()
            if k not in engine.state.metrics:
                engine.state.metrics[k] = 0.0
            engine.state.metrics[k] += loss[k] * batch_size / len(cfg.valid_ds)
        return loss
    evaluator = Engine(_evaluate)

    pbar.attach(evaluator, output_transform=lambda x: {k: "{:.5f}".format(v) for k, v in x.items()})

    @trainer.on(Events.EPOCH_COMPLETED)
    def evaluate_on_valid_dl(engine):
        evaluator.run(cfg.valid_dl)

    tb_logger.attach(evaluator,
                     log_handler=OutputHandler(tag="validation",
                                               metric_names=['loss', 'rot_loss_cos', 'rot_loss_l1', 'trans_loss', 'true_distance', 'cls_loss'],
                                               global_step_transform=global_step_from_engine(trainer)),
                     event_name=Events.EPOCH_COMPLETED)

    trainer.run(cfg.train_dl, cfg.n_epochs)
    tb_logger.close()
Example #8
            res["loss_distill"] = dist_loss_t.item()

        res.update(
            {
                "loss": loss_t.item(),
                "loss_value": loss_value_t.item(),
                "loss_policy": loss_policy_t.item(),
                "adv": adv_t.mean().item(),
                "loss_entropy": loss_entropy_t.item(),
                "time_batch": time.time() - start_ts,
            }
        )

        return res

    engine = Engine(process_batch)
    common.setup_ignite(
        engine,
        params,
        exp_source,
        NAME + "_" + args.name,
        extra_metrics=("test_reward", "avg_test_reward", "test_steps"),
    )

    @engine.on(ptan_ignite.PeriodEvents.ITERS_10000_COMPLETED)
    def test_network(engine):
        net.actor.train(False)
        obs = test_env.reset()
        reward = 0.0
        steps = 0
Example #9
def test_add_event_handler_raises_with_invalid_signature():
    engine = Engine(MagicMock())

    def handler(engine):
        pass

    engine.add_event_handler(Events.STARTED, handler)
    with pytest.raises(ValueError):
        engine.add_event_handler(Events.STARTED, handler, 1)

    def handler_with_args(engine, a):
        pass

    engine.add_event_handler(Events.STARTED, handler_with_args, 1)
    with pytest.raises(ValueError):
        engine.add_event_handler(Events.STARTED, handler_with_args)

    def handler_with_kwargs(engine, b=42):
        pass

    engine.add_event_handler(Events.STARTED, handler_with_kwargs, b=2)
    with pytest.raises(ValueError):
        engine.add_event_handler(Events.STARTED, handler_with_kwargs, c=3)
    with pytest.raises(ValueError):
        engine.add_event_handler(Events.STARTED, handler_with_kwargs, 1, b=2)

    def handler_with_args_and_kwargs(engine, a, b=42):
        pass

    engine.add_event_handler(Events.STARTED,
                             handler_with_args_and_kwargs,
                             1,
                             b=2)
    with pytest.raises(ValueError):
        engine.add_event_handler(Events.STARTED,
                                 handler_with_args_and_kwargs,
                                 1,
                                 2,
                                 b=2)
    with pytest.raises(ValueError):
        engine.add_event_handler(Events.STARTED,
                                 handler_with_args_and_kwargs,
                                 1,
                                 b=2,
                                 c=3)
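
A small usage sketch (hypothetical, reusing the handler shapes above) showing that the positional and keyword arguments bound at registration are forwarded to the handler when the event actually fires:

def test_bound_args_are_forwarded_to_handler():
    engine = Engine(MagicMock())
    seen = []

    def handler(engine, a, b=42):
        seen.append((a, b))

    engine.add_event_handler(Events.STARTED, handler, 1, b=2)
    engine.run([1])
    assert seen == [(1, 2)]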
Example #10
def main(dataset, dataroot, download, augment, batch_size, eval_batch_size,
         epochs, saved_model, seed, hidden_channels, K, L, actnorm_scale,
         flow_permutation, flow_coupling, LU_decomposed, learn_top,
         y_condition, y_weight, max_grad_clip, max_grad_norm, lr, n_workers,
         cuda, n_init_batches, warmup_steps, output_dir, saved_optimizer,
         warmup, fresh):

    device = 'cpu' if (not torch.cuda.is_available() or not cuda) else 'cuda:0'

    check_manual_seed(seed)

    ds = check_dataset(dataset, dataroot, augment, download)
    image_shape, num_classes, train_dataset, test_dataset = ds

    # Note: unsupported for now
    multi_class = False

    train_loader = data.DataLoader(train_dataset,
                                   batch_size=batch_size,
                                   shuffle=True,
                                   num_workers=n_workers,
                                   drop_last=True)
    test_loader = data.DataLoader(test_dataset,
                                  batch_size=eval_batch_size,
                                  shuffle=False,
                                  num_workers=n_workers,
                                  drop_last=False)

    model = Glow(image_shape, hidden_channels, K, L, actnorm_scale,
                 flow_permutation, flow_coupling, LU_decomposed, num_classes,
                 learn_top, y_condition)

    model = model.to(device)
    optimizer = optim.Adamax(model.parameters(), lr=lr, weight_decay=5e-5)

    lr_lambda = lambda epoch: lr * min(1., epoch / warmup)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                  lr_lambda=lr_lambda)

    def step(engine, batch):
        model.train()
        optimizer.zero_grad()

        x, y = batch
        x = x.to(device)

        if y_condition:
            y = y.to(device)
            z, nll, y_logits = model(x, y)
            losses = compute_loss_y(nll, y_logits, y_weight, y, multi_class)
        else:
            z, nll, y_logits = model(x, None)
            losses = compute_loss(nll)

        losses['total_loss'].backward()

        if max_grad_clip > 0:
            torch.nn.utils.clip_grad_value_(model.parameters(), max_grad_clip)
        if max_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        optimizer.step()

        return losses

    def eval_step(engine, batch):
        model.eval()

        x, y = batch
        x = x.to(device)

        with torch.no_grad():
            if y_condition:
                y = y.to(device)
                z, nll, y_logits = model(x, y)
                losses = compute_loss_y(nll,
                                        y_logits,
                                        y_weight,
                                        y,
                                        multi_class,
                                        reduction='none')
            else:
                z, nll, y_logits = model(x, None)
                losses = compute_loss(nll, reduction='none')

        return losses

    trainer = Engine(step)
    checkpoint_handler = ModelCheckpoint(output_dir,
                                         'glow',
                                         save_interval=1,
                                         n_saved=2,
                                         require_empty=False)

    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler, {
        'model': model,
        'optimizer': optimizer
    })

    monitoring_metrics = ['total_loss']
    RunningAverage(output_transform=lambda x: x['total_loss']).attach(
        trainer, 'total_loss')

    evaluator = Engine(eval_step)

    # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
    Loss(lambda x, y: torch.mean(x),
         output_transform=lambda x:
         (x['total_loss'], torch.empty(x['total_loss'].shape[0]))).attach(
             evaluator, 'total_loss')

    if y_condition:
        monitoring_metrics.extend(['nll'])
        RunningAverage(output_transform=lambda x: x['nll']).attach(
            trainer, 'nll')

        # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
        Loss(lambda x, y: torch.mean(x),
             output_transform=lambda x:
             (x['nll'], torch.empty(x['nll'].shape[0]))).attach(
                 evaluator, 'nll')

    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=monitoring_metrics)

    # load pre-trained model if given
    if saved_model:
        model.load_state_dict(torch.load(saved_model))
        model.set_actnorm_init()

        if saved_optimizer:
            optimizer.load_state_dict(torch.load(saved_optimizer))

        file_name, ext = os.path.splitext(saved_model)
        resume_epoch = int(file_name.split('_')[-1])

        @trainer.on(Events.STARTED)
        def resume_training(engine):
            engine.state.epoch = resume_epoch
            engine.state.iteration = resume_epoch * len(
                engine.state.dataloader)

    @trainer.on(Events.STARTED)
    def init(engine):
        model.train()

        init_batches = []
        init_targets = []

        with torch.no_grad():
            for batch, target in islice(train_loader, None, n_init_batches):
                init_batches.append(batch)
                init_targets.append(target)

            init_batches = torch.cat(init_batches).to(device)

            assert init_batches.shape[0] == n_init_batches * batch_size

            if y_condition:
                init_targets = torch.cat(init_targets).to(device)
            else:
                init_targets = None

            model(init_batches, init_targets)

    @trainer.on(Events.EPOCH_COMPLETED)
    def evaluate(engine):
        evaluator.run(test_loader)

        scheduler.step()
        metrics = evaluator.state.metrics

        losses = ', '.join(
            [f"{key}: {value:.2f}" for key, value in metrics.items()])

        print(f'Validation Results - Epoch: {engine.state.epoch} {losses}')

    timer = Timer(average=True)
    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        pbar.log_message(
            f'Epoch {engine.state.epoch} done. Time per batch: {timer.value():.3f}[s]'
        )
        timer.reset()

    trainer.run(train_loader, epochs)
Example #11
def create_trainer(model, optimizer, criterion, lr_scheduler, train_sampler,
                   config, logger):

    device = idist.device()

    # Setup Ignite trainer:
    # - let's define training step
    # - add other common handlers:
    #    - TerminateOnNan,
    #    - handler to setup learning rate scheduling,
    #    - ModelCheckpoint
    #    - RunningAverage` on `train_step` output
    #    - Two progress bars on epochs and optionally on iterations

    def train_step(engine, batch):

        x, y = batch[0], batch[1]

        if x.device != device:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

        model.train()
        y_pred = model(x)
        loss = criterion(y_pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        return {
            "batch loss": loss.item(),
        }

    trainer = Engine(train_step)
    trainer.logger = logger

    to_save = {
        "trainer": trainer,
        "model": model,
        "optimizer": optimizer,
        "lr_scheduler": lr_scheduler
    }
    metric_names = [
        "batch loss",
    ]

    common.setup_common_training_handlers(
        trainer=trainer,
        train_sampler=train_sampler,
        to_save=to_save,
        save_every_iters=config["checkpoint_every"],
        output_path=config["output_path"],
        lr_scheduler=lr_scheduler,
        output_names=metric_names if config["log_every_iters"] > 0 else None,
        with_pbars=False,
        clear_cuda_cache=False,
    )

    resume_from = config["resume_from"]
    if resume_from is not None:
        checkpoint_fp = Path(resume_from)
        assert checkpoint_fp.exists(), "Checkpoint '{}' is not found".format(
            checkpoint_fp.as_posix())
        logger.info("Resume from a checkpoint: {}".format(
            checkpoint_fp.as_posix()))
        checkpoint = torch.load(checkpoint_fp.as_posix(), map_location="cpu")
        Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    return trainer
Example #12
def _setup_common_training_handlers(
    trainer: Engine,
    to_save: Optional[Mapping] = None,
    save_every_iters: int = 1000,
    output_path: Optional[str] = None,
    lr_scheduler: Optional[Union[ParamScheduler, _LRScheduler]] = None,
    with_gpu_stats: bool = False,
    output_names: Optional[Iterable[str]] = None,
    with_pbars: bool = True,
    with_pbar_on_iters: bool = True,
    log_every_iters: int = 100,
    stop_on_nan: bool = True,
    clear_cuda_cache: bool = True,
    save_handler: Optional[Union[Callable, BaseSaveHandler]] = None,
    **kwargs: Any,
) -> None:
    if output_path is not None and save_handler is not None:
        raise ValueError(
            "Arguments output_path and save_handler are mutually exclusive. Please, define only one of them"
        )

    if stop_on_nan:
        trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())

    if lr_scheduler is not None:
        if isinstance(lr_scheduler, torch.optim.lr_scheduler._LRScheduler):
            trainer.add_event_handler(
                Events.ITERATION_COMPLETED, lambda engine: cast(_LRScheduler, lr_scheduler).step()
            )
        elif isinstance(lr_scheduler, LRScheduler):
            trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler)
        else:
            trainer.add_event_handler(Events.ITERATION_STARTED, lr_scheduler)

    if torch.cuda.is_available() and clear_cuda_cache:
        trainer.add_event_handler(Events.EPOCH_COMPLETED, empty_cuda_cache)

    if to_save is not None:

        if output_path is None and save_handler is None:
            raise ValueError(
                "If to_save argument is provided then output_path or save_handler arguments should be also defined"
            )
        if output_path is not None:
            save_handler = DiskSaver(dirname=output_path, require_empty=False)

        checkpoint_handler = Checkpoint(
            to_save, cast(Union[Callable, BaseSaveHandler], save_handler), filename_prefix="training", **kwargs
        )
        trainer.add_event_handler(Events.ITERATION_COMPLETED(every=save_every_iters), checkpoint_handler)

    if with_gpu_stats:
        GpuInfo().attach(
            trainer, name="gpu", event_name=Events.ITERATION_COMPLETED(every=log_every_iters)  # type: ignore[arg-type]
        )

    if output_names is not None:

        def output_transform(x: Any, index: int, name: str) -> Any:
            if isinstance(x, Mapping):
                return x[name]
            elif isinstance(x, Sequence):
                return x[index]
            elif isinstance(x, (torch.Tensor, numbers.Number)):
                return x
            else:
                raise TypeError(
                    "Unhandled type of update_function's output. "
                    f"It should either mapping or sequence, but given {type(x)}"
                )

        for i, n in enumerate(output_names):
            RunningAverage(output_transform=partial(output_transform, index=i, name=n), epoch_bound=False).attach(
                trainer, n
            )

    if with_pbars:
        if with_pbar_on_iters:
            ProgressBar(persist=False).attach(
                trainer, metric_names="all", event_name=Events.ITERATION_COMPLETED(every=log_every_iters)
            )

        ProgressBar(persist=True, bar_format="").attach(
            trainer, event_name=Events.EPOCH_STARTED, closing_event_name=Events.COMPLETED
        )
Example #13
def run_trainer(data_loader: dict,
                model: models,
                optimizer: optim,
                lr_scheduler: optim.lr_scheduler,
                criterion: nn,
                train_epochs: int,
                log_training_progress_every: int,
                log_val_progress_every: int,
                checkpoint_every: int,
                tb_summaries_dir: str,
                chkpt_dir: str,
                resume_from: str,
                to_device: object,
                to_cpu: object,
                attackers: object = None,
                train_adv_periodic_ops: int = None,
                *args,
                **kwargs):
    def mk_lr_step(loss):
        lr_scheduler.step(loss)

    def train_step(engine, batch):
        model.train()
        optimizer.zero_grad()
        x, y = map(lambda _: to_device(_), batch)
        if (train_adv_periodic_ops is not None) and (
                engine.state.iteration % train_adv_periodic_ops == 0):
            random_attacker = random.choice(list(attackers))
            x = attackers[random_attacker].perturb(x, y)
        y_pred = model(x)
        loss = criterion(y_pred, y)
        loss.backward()
        optimizer.step()
        return loss.item()

    def eval_step(engine, batch):
        model.eval()
        with torch.no_grad():
            x, y = map(lambda _: to_device(_), batch)
            if random.random() < 0.5:  # perturb roughly half of the eval batches with a random attacker
                random_attacker = random.choice(list(attackers))
                x = attackers[random_attacker].perturb(x, y)
            y_pred = model(x)
            return y_pred, y

    def chkpt_score_func(engine):
        val_eval.run(data_loader['val'])
        y_pred, y = val_eval.state.output
        loss = criterion(y_pred, y)
        # Checkpoint retains the highest scores, so negate the loss to keep the lowest-loss models
        return -np.mean(to_cpu(loss, convert_to_np=True))

    # set up ignite engines
    trainer = Engine(train_step)
    train_eval = Engine(eval_step)
    val_eval = Engine(eval_step)

    @trainer.on(Events.ITERATION_COMPLETED(every=log_training_progress_every))
    def log_training_results(engine):
        step = True
        run_type = 'train'
        train_eval.run(data_loader['train'])
        y_pred, y = train_eval.state.output
        loss = criterion(y_pred, y)
        log_results(to_cpu(y_pred, convert_to_np=True),
                    to_cpu(y, convert_to_np=True),
                    to_cpu(loss, convert_to_np=True), run_type, step,
                    engine.state.iteration, total_train_steps, writer)

    @trainer.on(Events.ITERATION_COMPLETED(every=log_val_progress_every))
    def log_val_results(engine):
        step = True
        run_type = 'val'
        val_eval.run(data_loader['val'])
        y_pred, y = val_eval.state.output
        loss = criterion(y_pred, y)
        mk_lr_step(loss)
        log_results(to_cpu(y_pred, convert_to_np=True),
                    to_cpu(y, convert_to_np=True),
                    to_cpu(loss, convert_to_np=True), run_type, step,
                    engine.state.iteration, total_train_steps, writer)

    # set up vars
    total_train_steps = len(data_loader['train']) * train_epochs

    # reporter to identify memory usage
    # bottlenecks throughout network
    reporter = MemReporter()
    print_model(model, reporter)

    # set up tensorboard summary writer
    writer = create_summary_writer(model, data_loader['train'],
                                   tb_summaries_dir)
    # move model to device
    model = to_device(model)

    # set up progress bar
    RunningAverage(output_transform=lambda x: x).attach(trainer, 'loss')
    pbar = ProgressBar(persist=True, bar_format="")
    pbar.attach(trainer, ['loss'])

    # set up checkpoint
    objects_to_checkpoint = {
        'trainer': trainer,
        'model': model,
        'optimizer': optimizer,
        'lr_scheduler': lr_scheduler
    }
    training_checkpoint = Checkpoint(to_save=objects_to_checkpoint,
                                     save_handler=DiskSaver(
                                         chkpt_dir, require_empty=False),
                                     n_saved=3,
                                     filename_prefix='best',
                                     score_function=chkpt_score_func,
                                     score_name='val_loss')

    # register events
    trainer.add_event_handler(
        Events.ITERATION_COMPLETED(every=checkpoint_every),
        training_checkpoint)

    # if resuming
    if resume_from and os.path.exists(resume_from):
        print(f'resume model from: {resume_from}')
        checkpoint = torch.load(resume_from)
        Checkpoint.load_objects(to_load=objects_to_checkpoint,
                                checkpoint=checkpoint)

    # fire training engine
    trainer.run(data_loader['train'], max_epochs=train_epochs)
Example #14
def train(args):
    device = torch.device("cuda" if args.cuda else "cpu")

    train_loader = check_dataset(args)
    transformer = TransformerNet().to(device)
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16(requires_grad=False).to(device)
    style_transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda x: x.mul(255))])

    style = utils.load_image(args.style_image, size=args.style_size)
    style = style_transform(style)
    style = style.repeat(args.batch_size, 1, 1, 1).to(device)

    features_style = vgg(utils.normalize_batch(style))
    gram_style = [utils.gram_matrix(y) for y in features_style]

    running_avgs = OrderedDict()

    def step(engine, batch):

        x, _ = batch
        x = x.to(device)

        n_batch = len(x)

        optimizer.zero_grad()

        y = transformer(x)

        x = utils.normalize_batch(x)
        y = utils.normalize_batch(y)

        features_x = vgg(x)
        features_y = vgg(y)

        content_loss = args.content_weight * mse_loss(features_y.relu2_2, features_x.relu2_2)

        style_loss = 0.0
        for ft_y, gm_s in zip(features_y, gram_style):
            gm_y = utils.gram_matrix(ft_y)
            style_loss += mse_loss(gm_y, gm_s[:n_batch, :, :])
        style_loss *= args.style_weight

        total_loss = content_loss + style_loss
        total_loss.backward()
        optimizer.step()

        return {"content_loss": content_loss.item(), "style_loss": style_loss.item(), "total_loss": total_loss.item()}

    trainer = Engine(step)
    checkpoint_handler = ModelCheckpoint(
        args.checkpoint_model_dir, "checkpoint", n_saved=10, require_empty=False, create_dir=True
    )
    progress_bar = Progbar(loader=train_loader, metrics=running_avgs)

    trainer.add_event_handler(
        event_name=Events.EPOCH_COMPLETED(every=args.checkpoint_interval),
        handler=checkpoint_handler,
        to_save={"net": transformer},
    )
    trainer.add_event_handler(event_name=Events.ITERATION_COMPLETED, handler=progress_bar)
    trainer.run(train_loader, max_epochs=args.epochs)
Example #15
def train():
    parser = ArgumentParser()
    parser.add_argument(
        "--dataset_path",
        type=str,
        default='wikitext-2',
        help="One of ('wikitext-103', 'wikitext-2') or a dict of splits paths."
    )
    parser.add_argument("--dataset_cache",
                        type=str,
                        default='./dataset_cache',
                        help="Path or url of the dataset cache")

    parser.add_argument("--embed_dim",
                        type=int,
                        default=410,
                        help="Embeddings dim")
    parser.add_argument("--hidden_dim",
                        type=int,
                        default=2100,
                        help="Hidden dimension")
    parser.add_argument("--num_max_positions",
                        type=int,
                        default=256,
                        help="Max input length")
    parser.add_argument("--num_heads",
                        type=int,
                        default=10,
                        help="Number of heads")
    parser.add_argument("--num_layers",
                        type=int,
                        default=16,
                        help="NUmber of layers")
    parser.add_argument("--dropout", type=float, default=0.1, help="Dropout")
    parser.add_argument("--initializer_range",
                        type=float,
                        default=0.02,
                        help="Normal initialization standard deviation")
    parser.add_argument("--sinusoidal_embeddings",
                        action="store_true",
                        help="Use sinusoidal embeddings")

    parser.add_argument(
        "--mlm",
        action="store_true",
        help=
        "Train with masked-language modeling loss instead of language modeling"
    )
    parser.add_argument(
        "--mlm_probability",
        type=float,
        default=0.15,
        help="Ratio of tokens to mask for masked language modeling loss")

    parser.add_argument("--train_batch_size",
                        type=int,
                        default=8,
                        help="Batch size for training")
    parser.add_argument("--valid_batch_size",
                        type=int,
                        default=8,
                        help="Batch size for validation")
    parser.add_argument("--lr",
                        type=float,
                        default=2.5e-4,
                        help="Learning rate")
    parser.add_argument("--max_norm",
                        type=float,
                        default=0.25,
                        help="Clipping gradient norm")
    parser.add_argument("--weight_decay",
                        type=float,
                        default=0.0,
                        help="Weight decay")
    parser.add_argument("--n_epochs",
                        type=int,
                        default=200,
                        help="Number of training epochs")
    parser.add_argument("--n_warmup",
                        type=int,
                        default=1000,
                        help="Number of warmup iterations")
    parser.add_argument("--eval_every",
                        type=int,
                        default=-1,
                        help="Evaluate every X steps (-1 => end of epoch)")
    parser.add_argument("--gradient_accumulation_steps",
                        type=int,
                        default=1,
                        help="Accumulate gradient")

    parser.add_argument("--device",
                        type=str,
                        default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device (cuda or cpu)")
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="Local rank for distributed training (-1: not distributed)")
    args = parser.parse_args()

    # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process. logger.info => log on main process only, logger.warning => log on all processes
    logging.basicConfig(
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning(
        "Running process %d", args.local_rank
    )  # This is a logger.warning: it will be printed by all distributed processes
    logger.info("Arguments: %s", pformat(
        args))  # This is a logger.info: only printed on the first process

    # Initialize distributed training if needed
    args.distributed = (args.local_rank != -1)
    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        args.device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')

    logger.info("Prepare tokenizer, model and optimizer")
    tokenizer = BertTokenizer.from_pretrained(
        'bert-base-cased',
        do_lower_case=False)  # Let's use a pre-defined tokenizer
    args.num_embeddings = len(
        tokenizer.vocab
    )  # We need this to create the model at next line (number of embeddings to use)
    model = TransformerWithLMHead(args)
    model.to(args.device)
    optimizer = Adam(model.parameters(),
                     lr=args.lr,
                     weight_decay=args.weight_decay)
    logger.info("Model has %s parameters",
                sum(p.numel() for p in model.parameters() if p.requires_grad))

    # Prepare model for distributed training if needed
    if args.distributed:
        model = DistributedDataParallel(model,
                                        device_ids=[args.local_rank],
                                        output_device=args.local_rank)

    logger.info("Prepare datasets")
    train_loader, val_loader, train_sampler, valid_sampler, train_num_words, valid_num_words = get_data_loaders(
        args, tokenizer)

    # Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original
    def mask_tokens(inputs):
        labels = inputs.clone()
        masked_indices = torch.bernoulli(
            torch.full(labels.shape, args.mlm_probability)).byte()
        labels[~masked_indices] = -1  # We only compute loss on masked tokens
        indices_replaced = torch.bernoulli(torch.full(
            labels.shape, 0.8)).byte() & masked_indices
        inputs[indices_replaced] = tokenizer.vocab[
            "[MASK]"]  # 80% of the time, replace masked input tokens with [MASK]
        indices_random = torch.bernoulli(torch.full(
            labels.shape, 0.5)).byte() & masked_indices & ~indices_replaced
        random_words = torch.randint(args.num_embeddings,
                                     labels.shape,
                                     dtype=torch.long,
                                     device=args.device)
        inputs[indices_random] = random_words[
            indices_random]  # 10% of the time, replace masked input tokens with random word
        return inputs, labels

    # Training function and trainer
    def update(engine, batch):
        model.train()
        inputs = batch.transpose(0, 1).contiguous().to(
            args.device)  # to shape [seq length, batch]
        inputs, labels = mask_tokens(inputs) if args.mlm else (
            inputs, inputs)  # Prepare masked input/labels if we use masked LM
        logits, loss = model(inputs, labels=labels)
        loss = loss / args.gradient_accumulation_steps
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
        if engine.state.iteration % args.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        return loss.item()

    trainer = Engine(update)

    # Evaluation function and evaluator (evaluator output is the input of the metrics)
    def inference(engine, batch):
        model.eval()
        with torch.no_grad():
            inputs = batch.transpose(0, 1).contiguous().to(
                args.device)  # to shape [seq length, batch]
            inputs, labels = mask_tokens(inputs) if args.mlm else (
                inputs,
                inputs)  # Prepare masked input/labels if we use masked LM
            logits = model(inputs)
            shift_logits = logits[:-1] if not args.mlm else logits
            shift_labels = labels[1:] if not args.mlm else labels
            return shift_logits.view(-1,
                                     logits.size(-1)), shift_labels.view(-1)

    evaluator = Engine(inference)

    # Attach evaluation to trainer: we evaluate at the end of each epoch and every 'eval_every' iterations if needed
    trainer.add_event_handler(Events.EPOCH_COMPLETED,
                              lambda _: evaluator.run(val_loader))
    if args.eval_every > 0:
        trainer.add_event_handler(
            Events.ITERATION_COMPLETED,
            lambda engine: evaluator.run(val_loader)
            if engine.state.iteration % args.eval_every == 0 else None)
    if args.n_epochs < 1:
        trainer.add_event_handler(Events.COMPLETED,
                                  lambda _: evaluator.run(val_loader))

    # Make sure distributed data samplers split the dataset nicely between the distributed processes
    if args.distributed:
        trainer.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: train_sampler.set_epoch(engine.state.epoch))
        evaluator.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: valid_sampler.set_epoch(engine.state.epoch))

    # Learning rate schedule: linearly warm-up to lr and then decrease the learning rate to zero with cosine schedule
    cos_scheduler = CosineAnnealingScheduler(optimizer, 'lr', args.lr, 0.0,
                                             len(train_loader) * args.n_epochs)
    scheduler = create_lr_scheduler_with_warmup(cos_scheduler, 0.0, args.lr,
                                                args.n_warmup)
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # Prepare metrics - note how we average distributed metrics using average_distributed_scalar
    metrics = {"nll": Loss(torch.nn.CrossEntropyLoss(ignore_index=-1))}
    metrics.update({
        "average_nll":
        MetricsLambda(average_distributed_scalar, metrics["nll"], args)
    })
    metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"])
    # Let's convert sub-word perplexities into word perplexities. If you need details: http://sjmielke.com/comparing-perplexities.htm
    metrics["average_word_ppl"] = MetricsLambda(
        lambda x: math.exp(x * val_loader.dataset.numel() / valid_num_words),
        metrics["average_nll"])
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # On the main process: add progress bar, tensorboard, checkpoints and save model and configuration before we start to train
    if args.local_rank in [-1, 0]:
        checkpoint_handler, tb_logger = add_logging_and_checkpoint_saving(
            trainer, evaluator, metrics, model, optimizer, args)

    # Run the training
    trainer.run(train_loader, max_epochs=args.n_epochs)

    # On the main process: close tensorboard logger and rename the last checkpoint for easy re-loading
    if args.local_rank in [-1, 0] and args.n_epochs > 0:
        tb_logger.close()
Example #16
0
def test_default_exception_handler():
    update_function = MagicMock(side_effect=ValueError())
    engine = Engine(update_function)

    with raises(ValueError):
        engine.run([1])
Example #17
0
def main():
    """
    Basic UNet as implemented in MONAI for Fetal Brain Segmentation, but using
    ignite to manage training and validation loop and checkpointing
    :return:
    """
    """
    Read input and configuration parameters
    """
    parser = argparse.ArgumentParser(
        description='Run basic UNet with MONAI - Ignite version.')
    parser.add_argument('--config',
                        dest='config',
                        metavar='config',
                        type=str,
                        help='config file')
    args = parser.parse_args()

    with open(args.config) as f:
        config_info = yaml.load(f, Loader=yaml.FullLoader)

    # print to log the parameter setups
    print(yaml.dump(config_info))

    # GPU params
    cuda_device = config_info['device']['cuda_device']
    num_workers = config_info['device']['num_workers']
    # training and validation params
    loss_type = config_info['training']['loss_type']
    batch_size_train = config_info['training']['batch_size_train']
    batch_size_valid = config_info['training']['batch_size_valid']
    lr = float(config_info['training']['lr'])
    lr_decay = config_info['training']['lr_decay']
    if lr_decay is not None:
        lr_decay = float(lr_decay)
    nr_train_epochs = config_info['training']['nr_train_epochs']
    validation_every_n_epochs = config_info['training'][
        'validation_every_n_epochs']
    sliding_window_validation = config_info['training'][
        'sliding_window_validation']
    if 'model_to_load' in config_info['training'].keys():
        model_to_load = config_info['training']['model_to_load']
        if not os.path.exists(model_to_load):
            raise BlockingIOError(
                "cannot find model: {}".format(model_to_load))
    else:
        model_to_load = None
    if 'manual_seed' in config_info['training'].keys():
        seed = config_info['training']['manual_seed']
    else:
        seed = None
    # data params
    data_root = config_info['data']['data_root']
    training_list = config_info['data']['training_list']
    validation_list = config_info['data']['validation_list']
    # model saving
    out_model_dir = os.path.join(
        config_info['output']['out_model_dir'],
        datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '_' +
        config_info['output']['output_subfix'])
    print("Saving to directory ", out_model_dir)
    if 'cache_dir' in config_info['output'].keys():
        out_cache_dir = config_info['output']['cache_dir']
    else:
        out_cache_dir = os.path.join(out_model_dir, 'persistent_cache')
    max_nr_models_saved = config_info['output']['max_nr_models_saved']
    val_image_to_tensorboad = config_info['output']['val_image_to_tensorboad']

    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    torch.cuda.set_device(cuda_device)
    if seed is not None:
        # set manual seed if required (both numpy and torch)
        set_determinism(seed=seed)
        # # set torch only seed
        # torch.manual_seed(seed)
        # torch.backends.cudnn.deterministic = True
        # torch.backends.cudnn.benchmark = False
    """
    Data Preparation
    """
    # create cache directory to store results for Persistent Dataset
    persistent_cache: Path = Path(out_cache_dir)
    persistent_cache.mkdir(parents=True, exist_ok=True)

    # create training and validation data lists
    train_files = create_data_list(data_folder_list=data_root,
                                   subject_list=training_list,
                                   img_postfix='_Image',
                                   label_postfix='_Label')

    print(len(train_files))
    print(train_files[0])
    print(train_files[-1])

    val_files = create_data_list(data_folder_list=data_root,
                                 subject_list=validation_list,
                                 img_postfix='_Image',
                                 label_postfix='_Label')
    print(len(val_files))
    print(val_files[0])
    print(val_files[-1])

    # data preprocessing for training:
    # - convert data to right format [batch, channel, dim, dim, dim]
    # - apply whitening
    # - resize to (96, 96) in-plane (preserve z-direction)
    # - define 2D patches to be extracted
    # - add data augmentation (random rotation and random flip)
    # - squeeze to 2D
    train_transforms = Compose([
        LoadNiftid(keys=['img', 'seg']),
        AddChanneld(keys=['img', 'seg']),
        NormalizeIntensityd(keys=['img']),
        Resized(keys=['img', 'seg'],
                spatial_size=[96, 96],
                interp_order=[1, 0],
                anti_aliasing=[True, False]),
        RandSpatialCropd(keys=['img', 'seg'],
                         roi_size=[96, 96, 1],
                         random_size=False),
        RandRotated(keys=['img', 'seg'],
                    degrees=90,
                    prob=0.2,
                    spatial_axes=[0, 1],
                    interp_order=[1, 0],
                    reshape=False),
        RandFlipd(keys=['img', 'seg'], spatial_axis=[0, 1]),
        SqueezeDimd(keys=['img', 'seg'], dim=-1),
        ToTensord(keys=['img', 'seg'])
    ])
    # create a training data loader
    # train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    # train_ds = monai.data.CacheDataset(data=train_files, transform=train_transforms, cache_rate=1.0,
    #                                    num_workers=num_workers)
    train_ds = monai.data.PersistentDataset(data=train_files,
                                            transform=train_transforms,
                                            cache_dir=persistent_cache)
    train_loader = DataLoader(train_ds,
                              batch_size=batch_size_train,
                              shuffle=True,
                              num_workers=num_workers,
                              collate_fn=list_data_collate,
                              pin_memory=torch.cuda.is_available())
    # check_train_data = monai.utils.misc.first(train_loader)
    # print("Training data tensor shapes")
    # print(check_train_data['img'].shape, check_train_data['seg'].shape)

    # data preprocessing for validation:
    # - convert data to right format [batch, channel, dim, dim, dim]
    # - apply whitening
    # - resize to (96, 96) in-plane (preserve z-direction)
    if sliding_window_validation:
        val_transforms = Compose([
            LoadNiftid(keys=['img', 'seg']),
            AddChanneld(keys=['img', 'seg']),
            NormalizeIntensityd(keys=['img']),
            Resized(keys=['img', 'seg'],
                    spatial_size=[96, 96],
                    interp_order=[1, 0],
                    anti_aliasing=[True, False]),
            ToTensord(keys=['img', 'seg'])
        ])
        do_shuffle = False
        collate_fn_to_use = None
    else:
        # - add extraction of 2D slices from validation set to emulate how loss is computed at training
        val_transforms = Compose([
            LoadNiftid(keys=['img', 'seg']),
            AddChanneld(keys=['img', 'seg']),
            NormalizeIntensityd(keys=['img']),
            Resized(keys=['img', 'seg'],
                    spatial_size=[96, 96],
                    interp_order=[1, 0],
                    anti_aliasing=[True, False]),
            RandSpatialCropd(keys=['img', 'seg'],
                             roi_size=[96, 96, 1],
                             random_size=False),
            SqueezeDimd(keys=['img', 'seg'], dim=-1),
            ToTensord(keys=['img', 'seg'])
        ])
        do_shuffle = True
        collate_fn_to_use = list_data_collate
    # create a validation data loader
    # val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    # val_ds = monai.data.CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0,
    #                                    num_workers=num_workers)
    val_ds = monai.data.PersistentDataset(data=val_files,
                                          transform=val_transforms,
                                          cache_dir=persistent_cache)
    val_loader = DataLoader(val_ds,
                            batch_size=batch_size_valid,
                            shuffle=do_shuffle,
                            collate_fn=collate_fn_to_use,
                            num_workers=num_workers)
    # check_valid_data = monai.utils.misc.first(val_loader)
    # print("Validation data tensor shapes")
    # print(check_valid_data['img'].shape, check_valid_data['seg'].shape)
    """
    Network preparation
    """
    # Create UNet, DiceLoss and Adam optimizer.
    net = monai.networks.nets.UNet(
        dimensions=2,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    )

    loss_function = monai.losses.DiceLoss(do_sigmoid=True)
    opt = torch.optim.Adam(net.parameters(), lr)
    device = torch.cuda.current_device()
    if lr_decay is not None:
        lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=opt,
                                                              gamma=lr_decay,
                                                              last_epoch=-1)
    """
    Set ignite trainer
    """

    # function to manage batch at training
    def prepare_batch(batch, device=None, non_blocking=False):
        return _prepare_batch((batch['img'], batch['seg']), device,
                              non_blocking)

    trainer = create_supervised_trainer(model=net,
                                        optimizer=opt,
                                        loss_fn=loss_function,
                                        device=device,
                                        non_blocking=False,
                                        prepare_batch=prepare_batch)

    # adding checkpoint handler to save models (network params and optimizer stats) during training
    if model_to_load is not None:
        checkpoint_handler = CheckpointLoader(load_path=model_to_load,
                                              load_dict={
                                                  'net': net,
                                                  'opt': opt,
                                              })
        checkpoint_handler.attach(trainer)
        state = trainer.state_dict()
    else:
        checkpoint_handler = ModelCheckpoint(out_model_dir,
                                             'net',
                                             n_saved=max_nr_models_saved,
                                             require_empty=False)
        # trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=save_params)
        trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                                  handler=checkpoint_handler,
                                  to_save={
                                      'net': net,
                                      'opt': opt
                                  })

    # StatsHandler prints loss at every iteration and print metrics at every epoch
    train_stats_handler = StatsHandler(name='trainer')
    train_stats_handler.attach(trainer)

    # TensorBoardStatsHandler plots loss at every iteration and plots metrics at every epoch, same as StatsHandler
    writer_train = SummaryWriter(log_dir=os.path.join(out_model_dir, "train"))
    train_tensorboard_stats_handler = TensorBoardStatsHandler(
        summary_writer=writer_train)
    train_tensorboard_stats_handler.attach(trainer)

    if lr_decay is not None:
        print("Using Exponential LR decay")
        lr_schedule_handler = LrScheduleHandler(lr_scheduler,
                                                print_lr=True,
                                                name="lr_scheduler",
                                                writer=writer_train)
        lr_schedule_handler.attach(trainer)
    """
    Set ignite evaluator to perform validation at training
    """
    # set parameters for validation
    metric_name = 'Mean_Dice'
    # add evaluation metric to the evaluator engine
    val_metrics = {
        "Loss": 1.0 - MeanDice(add_sigmoid=True, to_onehot_y=False),
        "Mean_Dice": MeanDice(add_sigmoid=True, to_onehot_y=False)
    }

    def _sliding_window_processor(engine, batch):
        net.eval()
        with torch.no_grad():
            val_images, val_labels = batch['img'].to(device), batch['seg'].to(
                device)
            roi_size = (96, 96, 1)
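            # sliding_window_inference tiles the full volume into (96, 96, 1) patches, runs
            # the 2D network on batches of patches and stitches the predictions back into a
            # full-size probability map.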
            seg_probs = sliding_window_inference(val_images, roi_size,
                                                 batch_size_valid, net)
            return seg_probs, val_labels

    if sliding_window_validation:
        # use sliding window inference at validation
        print("3D evaluator is used")
        net.to(device)
        evaluator = Engine(_sliding_window_processor)
        for name, metric in val_metrics.items():
            metric.attach(evaluator, name)
    else:
        # ignite evaluator expects batch=(img, seg) and returns output=(y_pred, y) at every iteration,
        # user can add output_transform to return other values
        print("2D evaluator is used")
        evaluator = create_supervised_evaluator(model=net,
                                                metrics=val_metrics,
                                                device=device,
                                                non_blocking=True,
                                                prepare_batch=prepare_batch)

    epoch_len = len(train_ds) // train_loader.batch_size
    validation_every_n_iters = validation_every_n_epochs * epoch_len
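    # The validation frequency is configured in epochs, but the trigger below fires on
    # iterations, so convert epochs to iterations via the number of batches per epoch.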

    @trainer.on(Events.ITERATION_COMPLETED(every=validation_every_n_iters))
    def run_validation(engine):
        evaluator.run(val_loader)

    # add early stopping handler to evaluator
    # early_stopper = EarlyStopping(patience=4,
    #                               score_function=stopping_fn_from_metric(metric_name),
    #                               trainer=trainer)
    # evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper)

    # add stats event handler to print validation stats via evaluator
    val_stats_handler = StatsHandler(
        name='evaluator',
        output_transform=lambda x:
        None,  # no need to print loss value, so disable per iteration output
        global_epoch_transform=lambda x: trainer.state.epoch
    )  # fetch global epoch number from trainer
    val_stats_handler.attach(evaluator)

    # add handler to record metrics to TensorBoard at every validation epoch
    writer_valid = SummaryWriter(log_dir=os.path.join(out_model_dir, "valid"))
    val_tensorboard_stats_handler = TensorBoardStatsHandler(
        summary_writer=writer_valid,
        output_transform=lambda x:
        None,  # no need to plot loss value, so disable per iteration output
        global_epoch_transform=lambda x: trainer.state.iteration
    )  # fetch global iteration number from trainer
    val_tensorboard_stats_handler.attach(evaluator)

    # add handler to draw the first image and the corresponding label and model output in the last batch
    # here we draw the 3D output as GIF format along the depth axis, every 2 validation iterations.
    if val_image_to_tensorboad:
        val_tensorboard_image_handler = TensorBoardImageHandler(
            summary_writer=writer_valid,
            batch_transform=lambda batch: (batch['img'], batch['seg']),
            output_transform=lambda output: predict_segmentation(output[0]),
            global_iter_transform=lambda x: trainer.state.epoch)
        evaluator.add_event_handler(
            event_name=Events.ITERATION_COMPLETED(every=1),
            handler=val_tensorboard_image_handler)
    """
    Run training
    """
    state = trainer.run(train_loader, nr_train_epochs)
    print("Done!")
Example #18
0
 def attach(self, engine: Engine):
     engine.add_event_handler(Events.ITERATION_COMPLETED, self)
     engine.register_events(*PeriodEvents)
     for e in PeriodEvents:
         State.event_to_attr[e] = "iteration"
Example #19
0
    def train(self, config, **kwargs):
        """Trains a given model specified in the config file or passed as the --model parameter.
        All options in the config file can be overwritten as needed by passing --PARAM
        Options with variable lengths ( e.g., kwargs can be passed by --PARAM '{"PARAM1":VAR1, "PARAM2":VAR2}'

        :param config: yaml config file
        :param **kwargs: parameters to overwrite yaml config
        """

        config_parameters = utils.parse_config_or_kwargs(config, **kwargs)
        outputdir = os.path.join(
            config_parameters['outputpath'], config_parameters['model'],
            "{}_{}".format(
                datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'),
                uuid.uuid1().hex))
        # Early init because of creating dir
        checkpoint_handler = ModelCheckpoint(
            outputdir,
            'run',
            n_saved=3,
            require_empty=False,
            create_dir=True,
            score_function=self._negative_loss,
            score_name='loss')
        logger = utils.getfile_outlogger(os.path.join(outputdir, 'train.log'))
        logger.info("Storing files in {}".format(outputdir))
        # utils.pprint_dict
        utils.pprint_dict(config_parameters, logger.info)
        logger.info("Running on device {}".format(DEVICE))
        label_df = pd.read_csv(config_parameters['label'], sep=r'\s+')
        data_df = pd.read_csv(config_parameters['data'], sep=r'\s+')
        # In case that both are not matching
        merged = data_df.merge(label_df, on='filename')
        common_idxs = merged['filename']
        data_df = data_df[data_df['filename'].isin(common_idxs)]
        label_df = label_df[label_df['filename'].isin(common_idxs)]

        train_df, cv_df = utils.split_train_cv(
            label_df, **config_parameters['data_args'])
        train_label = utils.df_to_dict(train_df)
        cv_label = utils.df_to_dict(cv_df)
        data = utils.df_to_dict(data_df)

        transform = utils.parse_transforms(config_parameters['transforms'])
        torch.save(config_parameters, os.path.join(outputdir,
                                                   'run_config.pth'))
        logger.info("Transforms:")
        utils.pprint_dict(transform, logger.info, formatter='pretty')
        assert len(cv_df) > 0, "Fraction a bit too large?"

        trainloader = dataset.gettraindataloader(
            h5files=data,
            h5labels=train_label,
            transform=transform,
            label_type=config_parameters['label_type'],
            batch_size=config_parameters['batch_size'],
            num_workers=config_parameters['num_workers'],
            shuffle=True,
        )

        cvdataloader = dataset.gettraindataloader(
            h5files=data,
            h5labels=cv_label,
            label_type=config_parameters['label_type'],
            transform=None,
            shuffle=False,
            batch_size=config_parameters['batch_size'],
            num_workers=config_parameters['num_workers'],
        )
        # Fall back to the CRNN class (not the string 'CRNN') so the getattr result is callable
        model = getattr(models, config_parameters['model'],
                        models.CRNN)(inputdim=trainloader.dataset.datadim,
                                     outputdim=2,
                                     **config_parameters['model_args'])
        if 'pretrained' in config_parameters and config_parameters[
                'pretrained'] is not None:
            model_dump = torch.load(config_parameters['pretrained'],
                                    map_location='cpu')
            model_state = model.state_dict()
            pretrained_state = {
                k: v
                for k, v in model_dump.items()
                if k in model_state and v.size() == model_state[k].size()
            }
            model_state.update(pretrained_state)
            model.load_state_dict(model_state)
            logger.info("Loading pretrained model {}".format(
                config_parameters['pretrained']))

        model = model.to(DEVICE)
        optimizer = getattr(
            torch.optim,
            config_parameters['optimizer'],
        )(model.parameters(), **config_parameters['optimizer_args'])

        utils.pprint_dict(optimizer, logger.info, formatter='pretty')
        utils.pprint_dict(model, logger.info, formatter='pretty')
        if DEVICE.type != 'cpu' and torch.cuda.device_count() > 1:
            logger.info("Using {} GPUs!".format(torch.cuda.device_count()))
            model = torch.nn.DataParallel(model)
        criterion = getattr(losses, config_parameters['loss'])().to(DEVICE)

        def _train_batch(_, batch):
            model.train()
            with torch.enable_grad():
                optimizer.zero_grad()
                output = self._forward(
                    model, batch)  # output is tuple (clip, frame, target)
                loss = criterion(*output)
                loss.backward()
                # Single loss
                optimizer.step()
                return loss.item()

        def _inference(_, batch):
            model.eval()
            with torch.no_grad():
                return self._forward(model, batch)

        def thresholded_output_transform(output):
            # Output is (clip, frame, target, lengths)
            _, y_pred, y, y_clip, length = output
            batchsize, timesteps, ndim = y.shape
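            # Zero out padded timesteps in the targets using each clip's true length, so
            # frames beyond the sequence end are treated as background before rounding.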
            idxs = torch.arange(timesteps,
                                device='cpu').repeat(batchsize).view(
                                    batchsize, timesteps)
            mask = (idxs < length.view(-1, 1)).to(y.device)
            y = y * mask.unsqueeze(-1)
            y_pred = torch.round(y_pred)
            y = torch.round(y)
            return y_pred, y

        metrics = {
            'Loss': losses.Loss(
                criterion),  #reimplementation of Loss, supports 3 way loss 
            'Precision': Precision(thresholded_output_transform),
            'Recall': Recall(thresholded_output_transform),
            'Accuracy': Accuracy(thresholded_output_transform),
        }
        train_engine = Engine(_train_batch)
        inference_engine = Engine(_inference)
        for name, metric in metrics.items():
            metric.attach(inference_engine, name)

        def compute_metrics(engine):
            inference_engine.run(cvdataloader)
            results = inference_engine.state.metrics
            output_str_list = [
                "Validation Results - Epoch : {:<5}".format(engine.state.epoch)
            ]
            for metric in metrics:
                output_str_list.append("{} {:<5.2f}".format(
                    metric, results[metric]))
            logger.info(" ".join(output_str_list))
            pbar.n = pbar.last_print_n = 0

        pbar = ProgressBar(persist=False)
        pbar.attach(train_engine)

        train_engine.add_event_handler(Events.ITERATION_COMPLETED(every=5000),
                                       compute_metrics)
        train_engine.add_event_handler(Events.EPOCH_COMPLETED, compute_metrics)

        early_stop_handler = EarlyStopping(
            patience=config_parameters['early_stop'],
            score_function=self._negative_loss,
            trainer=train_engine)
        inference_engine.add_event_handler(Events.EPOCH_COMPLETED,
                                           early_stop_handler)
        inference_engine.add_event_handler(Events.EPOCH_COMPLETED,
                                           checkpoint_handler, {
                                               'model': model,
                                           })

        train_engine.run(trainloader, max_epochs=config_parameters['epochs'])
        return outputdir
Example #20
0
 def __call__(self, engine: Engine):
     for period, event in self.INTERVAL_TO_EVENT.items():
         if engine.state.iteration % period == 0:
             engine.fire_event(event)
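     # Attached to ITERATION_COMPLETED: whenever the iteration count is a multiple of a
     # period, the matching custom event is fired (assuming INTERVAL_TO_EVENT maps each
     # period to a custom Event member).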
Example #21
0
 def attach(self, engine: Engine):
     """
     Args:
         engine: Ignite Engine, it can be a trainer, validator or evaluator.
     """
     return engine.add_event_handler(Events.ITERATION_COMPLETED, self)
Example #22
0
 def attach(self, engine: Engine):
     engine.add_event_handler(Events.ITERATION_COMPLETED, self)
     engine.register_events(*EpisodeEvents)
     State.event_to_attr[EpisodeEvents.EPISODE_COMPLETED] = "episode"
     State.event_to_attr[EpisodeEvents.BOUND_REWARD_REACHED] = "episode"
     State.event_to_attr[EpisodeEvents.BEST_REWARD_REACHED] = "episode"
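     # Registering the custom episode events and mapping them to the "episode" attribute
     # lets ignite keep an engine.state.episode counter that other handlers can read.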
Example #23
0
def main():
    config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    tempdir = tempfile.mkdtemp()
    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1)

        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))

        n = nib.Nifti1Image(seg, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))

    # define transforms for image and segmentation
    imtrans = Compose([ScaleIntensity(), AddChannel(), ToTensor()])
    segtrans = Compose([AddChannel(), ToTensor()])
    ds = NiftiDataset(images,
                      segs,
                      transform=imtrans,
                      seg_transform=segtrans,
                      image_only=False)

    device = torch.device("cuda:0")
    net = UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    )
    net.to(device)

    # define sliding window size and batch size for windows inference
    roi_size = (96, 96, 96)
    sw_batch_size = 4

    def _sliding_window_processor(engine, batch):
        net.eval()
        with torch.no_grad():
            val_images, val_labels = batch[0].to(device), batch[1].to(device)
            seg_probs = sliding_window_inference(val_images, roi_size,
                                                 sw_batch_size, net)
            return seg_probs, val_labels

    evaluator = Engine(_sliding_window_processor)

    # add evaluation metric to the evaluator engine
    MeanDice(add_sigmoid=True,
             to_onehot_y=False).attach(evaluator, "Mean_Dice")

    # StatsHandler prints loss at every iteration and print metrics at every epoch,
    # we don't need to print loss for evaluator, so just print metrics, user can also customize print functions
    val_stats_handler = StatsHandler(
        name="evaluator",
        output_transform=lambda x:
        None,  # no need to print loss value, so disable per iteration output
    )
    val_stats_handler.attach(evaluator)

    # for the array data format, assume the 3rd item of batch data is the meta_data
    file_saver = SegmentationSaver(
        output_dir="tempdir",
        output_ext=".nii.gz",
        output_postfix="seg",
        name="evaluator",
        batch_transform=lambda x: x[2],
        output_transform=lambda output: predict_segmentation(output[0]),
    )
    file_saver.attach(evaluator)

    # the model was trained by "unet_training_array" example
    ckpt_saver = CheckpointLoader(load_path="./runs/net_checkpoint_100.pth",
                                  load_dict={"net": net})
    ckpt_saver.attach(evaluator)

    # sliding window inference for one image at every iteration
    loader = DataLoader(ds,
                        batch_size=1,
                        num_workers=1,
                        pin_memory=torch.cuda.is_available())
    state = evaluator.run(loader)
    shutil.rmtree(tempdir)
Example #24
0
 def attach(self, engine: Engine, manual_step: bool = False):
     self._timer.attach(
         engine, step=None if manual_step else Events.ITERATION_COMPLETED)
     engine.add_event_handler(EpisodeEvents.EPISODE_COMPLETED, self)
Example #25
0
        with torch.no_grad():
            # Anneal learning rate
            mu = next(mu_scheme)
            i = engine.state.iteration
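            # Scale the annealed base rate mu by the Adam bias-correction factor
            # sqrt(1 - beta2**t) / (1 - beta1**t), with beta1 = 0.9 and beta2 = 0.999.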
            for group in optimizer.param_groups:
                group["lr"] = mu * math.sqrt(1 - 0.999**i) / (1 - 0.9**i)

        return {
            "elbo": elbo.item(),
            "kl": kl_divergence.item(),
            "sigma": sigma,
            "mu": mu
        }

    # Trainer and metrics
    trainer = Engine(step)
    metric_names = ["elbo", "kl", "sigma", "mu"]
    RunningAverage(output_transform=lambda x: x["elbo"]).attach(
        trainer, "elbo")
    RunningAverage(output_transform=lambda x: x["kl"]).attach(trainer, "kl")
    RunningAverage(output_transform=lambda x: x["sigma"]).attach(
        trainer, "sigma")
    RunningAverage(output_transform=lambda x: x["mu"]).attach(trainer, "mu")
    ProgressBar().attach(trainer, metric_names=metric_names)

    # Model checkpointing
    checkpoint_handler = ModelCheckpoint("./",
                                         "checkpoint",
                                         save_interval=1,
                                         n_saved=3,
                                         require_empty=False)
Example #26
0
def main():
    # Init state params
    params = init_parms()
    device = params.get('device')

    # Loading the model, optimizer & criterion
    model = ASRModel(input_features=config.num_mel_banks,
                     num_classes=config.vocab_size).to(device)
    model = torch.nn.DataParallel(model)
    logger.info(
        f'Model initialized with {get_model_size(model):.3f}M parameters')
    optimizer = Ranger(model.parameters(), lr=config.lr, eps=1e-5)
    load_checkpoint(model, optimizer, params)
    start_epoch = params['start_epoch']
    sup_criterion = CustomCTCLoss()

    # Validation progress bars defined here.
    pbar = ProgressBar(persist=True, desc="Loss")
    pbar_valid = ProgressBar(persist=True, desc="Validate")

    # load timer and best meter to keep track of state params
    timer = Timer(average=True)

    # load all the train data
    logger.info('Beginning to load Datasets')
    trainAirtelPaymentsPath = os.path.join(lmdb_airtel_payments_root_path,
                                           'train-labelled-en')

    # form data loaders
    train = lmdbMultiDatasetTester(roots=[trainAirtelPaymentsPath],
                                   transform=image_val_transform)

    logger.info(f'loaded train & test dataset = {len(train)}')

    def train_update_function(engine, _):
        optimizer.zero_grad()
        imgs_sup, labels_sup, label_lengths, input_lengths = next(
            engine.state.train_loader_labbeled)
        imgs_sup = imgs_sup.to(device)
        labels_sup = labels_sup
        probs_sup = model(imgs_sup)
        sup_loss = sup_criterion(probs_sup, labels_sup, label_lengths,
                                 input_lengths)
        sup_loss.backward()
        optimizer.step()
        return sup_loss.item()

    @torch.no_grad()
    def validate_update_function(engine, batch):
        img, labels, label_lengths, image_lengths = batch
        y_pred = model(img.to(device))
        if np.random.rand() > 0.99:
            pred_sentences = get_most_probable(y_pred)
            labels_list = labels.tolist()
            idx = 0
            for i, length in enumerate(label_lengths.cpu().tolist()):
                pred_sentence = pred_sentences[i]
                gt_sentence = sequence_to_string(labels_list[idx:idx + length])
                idx += length
                print(f"Pred sentence: {pred_sentence}, GT: {gt_sentence}")
        return (y_pred, labels, label_lengths)

    train_loader = torch.utils.data.DataLoader(train,
                                               batch_size=train_batch_size,
                                               shuffle=True,
                                               num_workers=config.workers,
                                               pin_memory=True,
                                               collate_fn=allign_collate)
    trainer = Engine(train_update_function)
    evaluator = Engine(validate_update_function)
    metrics = {'wer': WordErrorRate(), 'cer': CharacterErrorRate()}
    iteration_log_step = int(0.33 * len(train_loader))
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=config.lr_gamma,
        patience=int(config.epochs * 0.05),
        verbose=True,
        threshold_mode="abs",
        cooldown=int(config.epochs * 0.025),
        min_lr=1e-5)

    pbar.attach(trainer, output_transform=lambda x: {'loss': x})
    pbar_valid.attach(evaluator, ['wer', 'cer'],
                      event_name=Events.EPOCH_COMPLETED,
                      closing_event_name=Events.COMPLETED)
    timer.attach(trainer)

    @trainer.on(Events.STARTED)
    def set_init_epoch(engine):
        engine.state.epoch = params['start_epoch']
        logger.info(f'Initial epoch for trainer set to {engine.state.epoch}')

    @trainer.on(Events.EPOCH_STARTED)
    def set_model_train(engine):
        if hasattr(engine.state, 'train_loader_labbeled'):
            del engine.state.train_loader_labbeled
        engine.state.train_loader_labbeled = iter(train_loader)

    @trainer.on(Events.ITERATION_COMPLETED)
    def iteration_completed(engine):
        if (engine.state.iteration % iteration_log_step
                == 0) and (engine.state.iteration > 0):
            engine.state.epoch += 1
            train.set_epochs(engine.state.epoch)
            model.eval()
            logger.info('Model set to eval mode')
            evaluator.run(train_loader)
            model.train()
            logger.info('Model set back to train mode')

    @trainer.on(Events.EPOCH_COMPLETED)
    def after_complete(engine):
        logger.info('Epoch {} done. Time per batch: {:.3f}[s]'.format(
            engine.state.epoch, timer.value()))
        timer.reset()

    trainer.run(train_loader, max_epochs=epochs)
    tb_logger.close()
Example #27
0
def test_pbar_fail_with_non_callable_transform():
    engine = Engine(update_fn)
    pbar = ProgressBar()

    with pytest.raises(TypeError):
        pbar.attach(engine, output_transform=1)
Example #28
0
def main(
    dataset,
    dataroot,
    download,
    augment,
    batch_size,
    eval_batch_size,
    epochs,
    saved_model,
    seed,
    hidden_channels,
    K,
    L,
    actnorm_scale,
    flow_permutation,
    flow_coupling,
    LU_decomposed,
    learn_top,
    y_condition,
    y_weight,
    max_grad_clip,
    max_grad_norm,
    lr,
    n_workers,
    cuda,
    n_init_batches,
    output_dir,
    saved_optimizer,
    warmup,
    classifier_weight
):

    device = "cpu" if (not torch.cuda.is_available() or not cuda) else "cuda:0"
    wandb.init(project=dataset)  # use the function argument; no argparse `args` object is in scope here

    check_manual_seed(seed)

    image_shape = (64,64,3)
    # if args.dataset == "task1": num_classes = 24
    # else : num_classes = 40

    num_classes = 40

    # Note: unsupported for now
    multi_class = True  # It's True, but this variable isn't used for now


    # if args.dataset == "task1":
    #     dataset_train = CLEVRDataset(root_folder=args.dataroot,img_folder=args.dataroot+'images/')
    #     train_loader = DataLoader(dataset_train,batch_size=args.batch_size,shuffle=True,drop_last=True)
    # else :
    #     dataset_train = CelebALoader(root_folder=args.dataroot) #'/home/arg/courses/machine_learning/homework/deep_learning_and_practice/Lab7/dataset/task_2/'
    #     train_loader = DataLoader(dataset_train,batch_size=args.batch_size,shuffle=True,drop_last=True)

    dataset_train = CelebALoader(root_folder=dataroot)  # e.g. '/home/arg/courses/machine_learning/homework/deep_learning_and_practice/Lab7/dataset/task_2/'
    train_loader = DataLoader(dataset_train, batch_size=batch_size, shuffle=True, drop_last=True)


    model = Glow(
        image_shape,
        hidden_channels,
        K,
        L,
        actnorm_scale,
        flow_permutation,
        flow_coupling,
        LU_decomposed,
        num_classes,
        learn_top,
        y_condition,
    )

    model = model.to(device)
    optimizer = optim.Adamax(model.parameters(), lr=lr, weight_decay=5e-5)

    lr_lambda = lambda epoch: min(1.0, (epoch + 1) / warmup)  # noqa
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
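    # LambdaLR multiplies the base lr by min(1, (epoch + 1) / warmup): a linear warm-up
    # over the first `warmup` scheduler steps, then a constant factor of 1.0.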

    wandb.watch(model)

    def step(engine, batch):
        model.train()
        optimizer.zero_grad()

        x, y = batch
        x = x.to(device)
        if y_condition:
            y = y.to(device)
            z, nll, y_logits = model(x, y)
            ### x: torch.Size([batchsize, 3, 64, 64]); y: torch.Size([batchsize, 24]); z: torch.Size([batchsize, 48, 8, 8])
            losses = compute_loss_y(nll, y_logits, y_weight, y, multi_class)
        else:
            z, nll, y_logits = model(x, None)
            losses = compute_loss(nll)

        losses["total_loss"].backward()

        if max_grad_clip > 0:
            torch.nn.utils.clip_grad_value_(model.parameters(), max_grad_clip)
        if max_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        optimizer.step()


        return losses


    trainer = Engine(step)
    checkpoint_handler = ModelCheckpoint(
        output_dir, "glow", n_saved=None, require_empty=False
    )
    ### n_saved (Optional[int]) – Number of objects that should be kept on disk. Older files will be removed. If set to None, all objects are kept.

    trainer.add_event_handler(
        Events.EPOCH_COMPLETED,
        checkpoint_handler,
        {"model": model, "optimizer": optimizer},
    )

    monitoring_metrics = ["total_loss"]
    RunningAverage(output_transform=lambda x: x["total_loss"]).attach(
        trainer, "total_loss"
    )


    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=monitoring_metrics)


    if saved_model:
        model.load_state_dict(torch.load(saved_model, map_location="cpu")['model'])
        model.set_actnorm_init()

    @trainer.on(Events.STARTED)
    def init(engine):
        model.train()

        init_batches = []
        init_targets = []

        with torch.no_grad():
            for batch, target in islice(train_loader, None, n_init_batches):
                init_batches.append(batch)
                init_targets.append(target)

            init_batches = torch.cat(init_batches).to(device)

            assert init_batches.shape[0] == n_init_batches * batch_size

            if y_condition:
                init_targets = torch.cat(init_targets).to(device)
            else:
                init_targets = None

            model(init_batches, init_targets)
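            # This single forward pass data-dependently initialises the ActNorm layers on
            # the pooled init batches (the standard Glow initialisation trick).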




    # evaluator = evaluation_model(args.classifier_weight)
    # @trainer.on(Events.EPOCH_COMPLETED)
    # def evaluate(engine):
    #     if args.dataset == "task1":
    #         model.eval()
    #         with torch.no_grad():
    #             test_conditions = get_test_conditions(args.dataroot).cuda()
    #             predict_x = postprocess(model(y_onehot=test_conditions, temperature=1, reverse=True)).float()
    #             score = evaluator.eval(predict_x, test_conditions)
    #             save_image(predict_x.float(), args.output_dir+f"/Epoch{engine.state.epoch}_score{score:.3f}.png", normalize=True)

    #             test_conditions = get_new_test_conditions(args.dataroot).cuda()
    #             predict_x = postprocess(model(y_onehot=test_conditions, temperature=1, reverse=True)).float()
    #             newscore = evaluator.eval(predict_x.float(), test_conditions)
    #             save_image(predict_x.float(), args.output_dir+f"/Epoch{engine.state.epoch}_newscore{newscore:.3f}.png", normalize=True)

    #             print(f"Iter: {engine.state.iteration}  score:{score:.3f} newscore:{newscore:.3f} ")
    #             wandb.log({"score": score, "new_score": newscore})




    trainer.run(train_loader, epochs)
Example #29
0
def test_add_event_handler_raises_with_invalid_event():
    engine = Engine(lambda e, b: 1)

    with pytest.raises(ValueError, match=r"is not a valid event for this Engine"):
        engine.add_event_handler("incorrect", lambda engine: None)
Example #30
0
def train():
    ################################ Model Config ###################################
    if args.lbl_method == "BIO":
        num_labels_emo = 4  # O POS NEG NORM
        num_labels_ent = 3  # O B I
    else:
        num_labels_emo = 4  # O POS NEG NORM
        num_labels_ent = 5  # O B I E S

    if not args.freeze_step > 0:
        if args.net == "3":
            model = NetY3.from_pretrained(args.bert_model,
                                          cache_dir="",
                                          num_labels_ent=num_labels_ent,
                                          num_labels_emo=num_labels_emo,
                                          dp=args.dp)
        elif args.net == "2":
            model = NetY2.from_pretrained(args.bert_model,
                                          cache_dir="",
                                          num_labels_ent=num_labels_ent,
                                          num_labels_emo=num_labels_emo,
                                          dp=args.dp)
        elif args.net == "1":
            model = NetY1.from_pretrained(args.bert_model,
                                          cache_dir="",
                                          num_labels_ent=num_labels_ent,
                                          num_labels_emo=num_labels_emo,
                                          dp=args.dp)
        elif args.net == "4":
            model = NetY4.from_pretrained(args.bert_model,
                                          cache_dir="",
                                          num_labels_ent=num_labels_ent,
                                          num_labels_emo=num_labels_emo,
                                          dp=args.dp)

    else:
        print("in net 5.................")
        if args.net == "5":
            model = NetY5_fz.from_pretrained(args.bert_model,
                                             cache_dir="",
                                             num_labels_ent=num_labels_ent,
                                             num_labels_emo=num_labels_emo,
                                             dp=args.dp)

        else:
            model = NetY3_fz.from_pretrained(args.bert_model,
                                             cache_dir="",
                                             num_labels_ent=num_labels_ent,
                                             num_labels_emo=num_labels_emo,
                                             dp=args.dp)
            ########################### Freeze First ###################################
            model.freeze()
            print("freezed model")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model = torch.nn.DataParallel(model)

    ################################# hyper parameters ###########################
    # alpha = 0.5 # 0.44
    # alpha = 0.6 # 0.42
    # alpha = 0.7
    # alpha = 1.2
    # alpha = 0.8
    # alpha = 0.7
    # alphas = [1, 0.9, 0.8, 0.8, 0.8, 0.8, 0.8]
    if not args.multi:
        alpha = args.alpha
    else:
        alphas = [0.1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
        print("alphas: ", " ".join(map(str, alphas)))
    # alphas = [2,1,1,0.8,0.8,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5]
    # ------------------------------ load model from file -------------------------
    model_file = os.path.join(args.checkpoint_model_dir, args.ckp)
    if os.path.exists(model_file):
        model.load_state_dict(torch.load(model_file))
        print("load checkpoint: {} successfully!".format(model_file))
    # -----------------------------------------------------------------------------

    trn_dataloader, val_dataloader, trn_size = get_data_loader()

    ############################## Optimizer ###################################
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
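    # Standard BERT fine-tuning practice: apply weight decay to every parameter except
    # biases and LayerNorm weights.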
    optimizer_grouped_parameters = [{
        'params': [
            p for n, p in param_optimizer
            if (not any(nd in n for nd in no_decay)) and p.requires_grad
        ],
        'weight_decay':
        args.wd
    }, {
        'params': [
            p for n, p in param_optimizer
            if any(nd in n for nd in no_decay) and p.requires_grad
        ],
        'weight_decay':
        0.0
    }]
    # num_train_optimization_steps = int( trn_size / args.batch_size / args.gradient_accumulation_steps) * args.epochs
    num_train_optimization_steps = int(
        trn_size / args.batch_size) * args.epochs + 5
    optimizer = BertAdam(optimizer_grouped_parameters,
                         lr=args.lr,
                         warmup=args.warmup_proportion,
                         t_total=num_train_optimization_steps)
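    # BertAdam interprets `warmup` as a proportion: the lr ramps up linearly over
    # warmup_proportion * t_total steps and then decays linearly to zero.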
    # optimizer = Adam(filter(lambda p:p.requires_grad, model.parameters()), args.lr, weight_decay=5e-3)
    ######################################################################
    if not args.focal:
        if not args.wc:
            criterion = torch.nn.CrossEntropyLoss()
        else:
            # O B I E S
            wc_ent = torch.tensor([0.5, 2, 1, 2, 2]).to(device)
            # O POS NEG NORM
            wc_emo = torch.tensor([0.5, 0.8, 1, 1]).to(device)
            criterion_ent = torch.nn.CrossEntropyLoss(weight=wc_ent)
            criterion_emo = torch.nn.CrossEntropyLoss(weight=wc_emo)
            print("weight classes over!")
    else:
        criterion = FocalLoss(args.gamma)

    if args.lf:
        print("lr finding.........")
        import math
        init_value = 1e-8
        final_value = 10
        beta = 0.98
        num = len(trn_dataloader) - 1
        mult = (final_value / init_value)**(1 / num)
        lr = init_value
        optimizer.param_groups[0]['lr'] = lr
        optimizer.param_groups[1]['lr'] = lr

        def lr_find(engine, batch):
            batch_num = engine.state.iteration
            if engine.state.metrics.get("avg_loss") is None:
                engine.state.metrics["avg_loss"] = 0.
                engine.state.metrics["best_loss"] = 0.
                engine.state.metrics["lr"] = lr
                engine.state.metrics["losses"] = []
                engine.state.metrics["log_lrs"] = []
            model.train()
            batch = tuple(t.to(device) for t in batch)
            input_ids, myinput_ids, input_mask, segment_ids, label_ent_ids, label_emo_ids = batch

            optimizer.zero_grad()
            act_logits_ent, act_y_ent, act_logits_emo, act_y_emo, act_myinput_ids = model(
                input_ids, myinput_ids, segment_ids, input_mask, label_ent_ids,
                label_emo_ids)
            # Only keep active parts of the loss
            if not args.wc:
                loss_ent = criterion(act_logits_ent, act_y_ent)
                loss_emo = criterion(act_logits_emo, act_y_emo)
            else:
                loss_ent = criterion_ent(act_logits_ent, act_y_ent)
                loss_emo = criterion_emo(act_logits_emo, act_y_emo)
            # loss = alphas[engine.state.epoch-1] * loss_ent + loss_emo
            if not args.multi:
                loss = alpha * loss_ent + loss_emo
            else:
                loss = loss_ent + alphas[engine.state.epoch - 1] * loss_emo
            # Compute the smoothed loss
            avg_loss = beta * engine.state.metrics.get("avg_loss") + (
                1 - beta) * loss.item()
            smoothed_loss = avg_loss / (1 - beta**batch_num)
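            # smoothed_loss is a bias-corrected exponential moving average of the batch loss
            # (the same debiasing as in Adam), which keeps the lr-vs-loss curve readable.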
            engine.state.metrics["avg_loss"] = avg_loss
            # Stop if the loss is exploding
            if batch_num > 1 and smoothed_loss > 4 * engine.state.metrics.get(
                    "best_loss") or torch.isnan(torch.tensor(smoothed_loss)):
                # engine.terminate()
                engine.terminate_epoch()
            # Record the best loss
            if smoothed_loss < engine.state.metrics.get(
                    "best_loss") or batch_num == 1:
                best_loss = smoothed_loss
                engine.state.metrics["best_loss"] = best_loss
            # print("cur batch:{} smothed_loss: {} best_loss: {} ".format(batch_num, smoothed_loss, engine.state.metrics["best_loss"]))
            # Store the values
            engine.state.metrics["losses"].append(smoothed_loss)
            engine.state.metrics["log_lrs"].append(
                math.log10(engine.state.metrics["lr"]))
            # Do the SGD step
            loss.backward()
            optimizer.step()
            # Update the lr for the next step
            engine.state.metrics["lr"] *= mult
            optimizer.param_groups[0]['lr'] = engine.state.metrics["lr"]
            optimizer.param_groups[1]['lr'] = engine.state.metrics["lr"]

        trn_lr_fineder = Engine(lr_find)

        @trn_lr_fineder.on(Events.EPOCH_COMPLETED)
        def draw_lr(engine):
            import matplotlib.pyplot as plt
            plt.plot(
                engine.state.metrics.get("log_lrs")[10:-5],
                engine.state.metrics.get("losses")[10:-5])
            plt.savefig("lr_finder.jpg")
            print("lr find end!")

        pbar = ProgressBar(persist=True)
        pbar.attach(trn_lr_fineder)
        trn_lr_fineder.run(trn_dataloader, max_epochs=1)
    else:
        iterations = None

        def step(engine, batch):
            if args.freeze_step > 0:
                if args.net != "5":
                    if engine.state.epoch - 1 == args.freeze_step:
                        freeze_paras = model.module.unfreeze()
                        # opt unfreeze
                        optimizer.add_param_group({'params': freeze_paras})
                        # run only once
                        args.freeze_step = -1
                else:
                    # print(f'in net 5, freeze_step:{args.freeze_step}, epos:{engine.state.epoch}')
                    if engine.state.epoch == args.freeze_step:
                        # model.module.freeze()
                        print("freeze net 5...")
                        for param in model.module.bert.parameters():
                            param.requires_grad = False
                        for param in model.module.classifier_ent.parameters():
                            param.requires_grad = False
                        print("freeze net 5 over")
                        args.freeze_step = -1

            if args.eval_radio > 0:
                nonlocal iterations  # `iterations` lives in train()'s scope, so nonlocal (not global) is required
                iterations = len(trn_dataloader) // len(batch)

            model.train()
            batch = tuple(t.to(device) for t in batch)
            input_ids, myinput_ids, input_mask, segment_ids, label_ent_ids, label_emo_ids = batch

            optimizer.zero_grad()
            act_logits_ent, act_y_ent, act_logits_emo, act_y_emo, act_myinput_ids = model(
                input_ids, myinput_ids, segment_ids, input_mask, label_ent_ids,
                label_emo_ids)
            # Only keep active parts of the loss
            if args.wc:
                loss_ent = criterion_ent(act_logits_ent, act_y_ent)
                loss_emo = criterion_emo(act_logits_emo, act_y_emo)
            # loss = alphas[engine.state.epoch-1] * loss_ent + loss_emo
            if not args.smooth_beta > 0:
                if not args.multi:
                    if not (args.net == "5" and args.freeze_step < 0):
                        loss_ent = criterion(act_logits_ent, act_y_ent)
                        loss_emo = criterion(act_logits_emo, act_y_emo)
                        loss = alpha * loss_ent + loss_emo
                    else:
                        loss_ent = torch.tensor([0])
                        loss_emo = criterion(act_logits_emo, act_y_emo)
                        # print("freeze net5 over")
                        loss = loss_emo
                else:
                    loss = loss_ent + alphas[engine.state.epoch - 1] * loss_emo
            else:  # smooth
                if not args.multi:
                    loss = alpha * loss_ent + loss_emo
                else:
                    loss = loss_ent + alphas[engine.state.epoch - 1] * loss_emo
                if engine.state.metrics.get("smooth_loss"):
                    loss = loss * (
                        1 - args.smooth_beta
                    ) + args.smooth_beta * engine.state.metrics["smooth_loss"]

            if engine.state.metrics.get("total_loss") is None:
                engine.state.metrics["total_loss"] = loss.item()
                engine.state.metrics["ent_loss"] = loss_ent.item()
                engine.state.metrics["emo_loss"] = loss_emo.item()
            else:
                engine.state.metrics["total_loss"] += loss.item()
                engine.state.metrics["ent_loss"] += loss_ent.item()
                engine.state.metrics["emo_loss"] += loss_emo.item()
            engine.state.metrics["smooth_loss"] = loss
            engine.state.metrics["batchloss"] = loss.item()
            engine.state.metrics["batchloss_ent"] = loss_ent.item()
            engine.state.metrics["batchloss_emo"] = loss_emo.item()
            loss.backward()
            optimizer.step()
            return loss.item(
            ), act_logits_ent, act_y_ent, act_logits_emo, act_y_emo, act_myinput_ids  # [-1, 11]

        def infer(engine, batch):
            model.eval()
            batch = tuple(t.to(device) for t in batch)
            input_ids, myinput_ids, input_mask, segment_ids, label_ent_ids, label_emo_ids = batch

            with torch.no_grad():
                act_logits_ent, act_y_ent, act_logits_emo, act_y_emo, act_myinput_ids = model(
                    input_ids, myinput_ids, segment_ids, input_mask,
                    label_ent_ids, label_emo_ids)
                # Only keep active parts of the loss
                if not args.wc:
                    loss_ent = criterion(act_logits_ent, act_y_ent)
                    loss_emo = criterion(act_logits_emo, act_y_emo)
                else:
                    loss_ent = criterion_ent(act_logits_ent, act_y_ent)
                    loss_emo = criterion_emo(act_logits_emo, act_y_emo)
                # loss = alphas[engine.state.epoch-1] * loss_ent + loss_emo
                if not args.smooth_beta > 0:
                    if not args.multi:
                        loss = alpha * loss_ent + loss_emo
                    else:
                        loss = loss_ent + alphas[engine.state.epoch -
                                                 1] * loss_emo
                else:
                    if not args.multi:
                        loss = alpha * loss_ent + loss_emo
                    else:
                        loss = loss_ent + alphas[engine.state.epoch -
                                                 1] * loss_emo
                    if engine.state.metrics.get("smooth_loss"):
                        loss = loss * (
                            1 - args.smooth_beta
                        ) + args.smooth_beta * engine.state.metrics[
                            "smooth_loss"]

                if engine.state.metrics.get("total_loss") is None:
                    engine.state.metrics["total_loss"] = loss.item()
                    engine.state.metrics["ent_loss"] = loss_ent.item()
                    engine.state.metrics["emo_loss"] = loss_emo.item()
                else:
                    engine.state.metrics["total_loss"] += loss.item()
                    engine.state.metrics["ent_loss"] += loss_ent.item()
                    engine.state.metrics["emo_loss"] += loss_emo.item()
                engine.state.metrics["smooth_loss"] = loss
                engine.state.metrics["batchloss"] = loss.item()
                engine.state.metrics["batchloss_ent"] = loss_ent.item()
                engine.state.metrics["batchloss_emo"] = loss_emo.item()
            # act_logits = torch.argmax(torch.softmax(act_logits, dim=-1), dim=-1)  # [-1, 1]
            # loss = loss.mean()
            return (loss.item(), act_logits_ent, act_y_ent, act_logits_emo,
                    act_y_emo, act_myinput_ids)  # [-1, 11]

        trainer = Engine(step)
        trn_evaluator = Engine(infer)
        val_evaluator = Engine(infer)
        val_evaluator_iteration = Engine(infer)
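        # step drives the single trainer; the three evaluators all share infer:
        # trn_evaluator (only referenced by the commented-out training logger below),
        # val_evaluator for epoch-level validation, and val_evaluator_iteration for
        # the mid-epoch checks triggered by the periodic events registered below.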

        ############################## Custom Period Event ###################################
        '''
            cpe1 = CustomPeriodicEvent(n_epochs=1)
            cpe1.attach(trainer)
            cpe2 = CustomPeriodicEvent(n_epochs=2)
            cpe2.attach(trainer)
            cpe3 = CustomPeriodicEvent(n_epochs=3)
            cpe3.attach(trainer)
            cpe5 = CustomPeriodicEvent(n_epochs=5)
            cpe5.attach(trainer)
        '''
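        # CustomPeriodicEvent emits ITERATIONS_N_COMPLETED events on the trainer;
        # they are used further below to trigger mid-epoch validation either every
        # args.eval_step iterations or once per int(iterations * args.eval_radio) + 1
        # iterations.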
        if args.eval_step > 0:
            cpe = CustomPeriodicEvent(n_iterations=args.eval_step)
            cpe.attach(trainer)
        if args.eval_radio > 0:
            # n_iterations must be an int; eval_radio is a fractional ratio, so round
            # it down before adding 1 (kept consistent with the event name built below).
            cpe2 = CustomPeriodicEvent(
                n_iterations=int(iterations * args.eval_radio) + 1)
            cpe2.attach(trainer)

        ############################## My F1 ###################################
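        # The custom FScore metric consumes (logits_ent, y_ent, logits_emo, y_emo,
        # myinput_ids) from the evaluator output via output_transform and is exposed
        # as "F1", which the checkpoint score function and early stopping read.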
        F1 = FScore(output_transform=lambda x: [x[1], x[2], x[3], x[4], x[-1]],
                    lbl_method=args.lbl_method)
        F1.attach(val_evaluator, "F1")
        F1.attach(val_evaluator_iteration, "F1")

        #####################################  progress bar #########################
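        # A running average of the raw batch loss (x[0] of the step output) is shown
        # in the progress bar.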
        RunningAverage(output_transform=lambda x: x[0]).attach(
            trainer, 'batch_loss')
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer, metric_names=["batch_loss"])

        #####################################  Evaluate #########################
        @trainer.on(Events.EPOCH_COMPLETED)
        def compute_val_metric(engine):
            # trainer engine
            engine.state.metrics["total_loss"] /= engine.state.iteration
            engine.state.metrics["ent_loss"] /= engine.state.iteration
            engine.state.metrics["emo_loss"] /= engine.state.iteration
            pbar.log_message(
                "Training - total_loss: {:.4f} ent_loss: {:.4f} emo_loss: {:.4f}"
                .format(engine.state.metrics["total_loss"],
                        engine.state.metrics["ent_loss"],
                        engine.state.metrics["emo_loss"]))
            val_evaluator.run(val_dataloader)

            metrics = val_evaluator.state.metrics
            ent_loss = metrics["ent_loss"]
            emo_loss = metrics["emo_loss"]
            f1 = metrics['F1']
            pbar.log_message(
                "Validation Results - Epoch: {}  Ent_loss: {:.4f}, Emo_loss: {:.4f}, F1: {:.4f}"
                .format(engine.state.epoch, ent_loss, emo_loss, f1))

            pbar.n = pbar.last_print_n = 0
            # trainer.add_event_handler(Events.EPOCH_COMPLETED, compute_val_metric)

        def compute_val_metric_iteration(engine):
            # trainer engine
            val_evaluator_iteration.run(val_dataloader)
            metrics = val_evaluator_iteration.state.metrics
            f1 = metrics['F1']
            pbar.log_message(
                "Validation Results - Iteration: {} F1: {:.4f}".format(
                    engine.state.iteration, f1))
            pbar.n = pbar.last_print_n = 0

        # Resolve the dynamically named periodic events with getattr instead of
        # eval(); the iteration count must match the one used to build cpe2 above.
        if args.eval_step > 0:
            trainer.add_event_handler(
                getattr(cpe.Events, f"ITERATIONS_{args.eval_step}_COMPLETED"),
                compute_val_metric_iteration)
        if args.eval_radio > 0:
            trainer.add_event_handler(
                getattr(
                    cpe2.Events,
                    f"ITERATIONS_{int(iterations * args.eval_radio) + 1}_COMPLETED"
                ), compute_val_metric_iteration)

        @val_evaluator.on(Events.EPOCH_COMPLETED)
        def reduct_step(engine):
            engine.state.metrics["total_loss"] /= engine.state.iteration
            engine.state.metrics["ent_loss"] /= engine.state.iteration
            engine.state.metrics["emo_loss"] /= engine.state.iteration
            pbar.log_message(
                "Validation - total_loss: {:.4f} ent_loss: {:.4f} emo_loss: {:.4f}"
                .format(engine.state.metrics["total_loss"],
                        engine.state.metrics["ent_loss"],
                        engine.state.metrics["emo_loss"]))

        ######################################################################

        ############################## checkpoint ###################################
        def best_f1(engine):
            f1 = engine.state.metrics["F1"]
            # loss = engine.state.metrics["loss"]
            return f1
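        # best_f1 is the score function shared by ModelCheckpoint and EarlyStopping:
        # a higher validation F1 is better.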

        if not args.lite:
            ckp_dir = os.path.join(args.checkpoint_model_dir, "full",
                                   args.hyper_cfg)
        else:
            ckp_dir = os.path.join(args.checkpoint_model_dir, "lite",
                                   args.hyper_cfg)
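        # Keep the 5 best checkpoints, ranked by validation F1, under
        # <checkpoint_model_dir>/{full,lite}/<hyper_cfg>.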
        checkpoint_handler = ModelCheckpoint(
            ckp_dir,
            'ckp',
            # save_interval=args.checkpoint_interval,
            score_function=best_f1,
            score_name="F1",
            n_saved=5,
            require_empty=False,
            create_dir=True)

        # trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler,
        #                           to_save={'model_3FC': model})
        val_evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                                        handler=checkpoint_handler,
                                        to_save={'model_title': model})
        if args.eval_step > 0:
            val_evaluator_iteration.add_event_handler(
                event_name=Events.EPOCH_COMPLETED,
                handler=checkpoint_handler,
                to_save={"model_iter": model})

        ######################################################################

        ############################## earlystopping ###################################
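        # Stop training early if the validation F1 does not improve for two
        # consecutive validation runs.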
        stopping_handler = EarlyStopping(patience=2,
                                         score_function=best_f1,
                                         trainer=trainer)
        val_evaluator.add_event_handler(Events.COMPLETED, stopping_handler)

        ######################################################################

        #################################### tb logger ##################################
        # Values are logged only after the corresponding metrics have been computed
        # (compute_metric) on the engine they are attached to.
        if not args.lite:
            tb_logger = TensorboardLogger(
                log_dir=os.path.join(args.log_dir, "full", args.hyper_cfg))
        else:
            tb_logger = TensorboardLogger(
                log_dir=os.path.join(args.log_dir, "lite", args.hyper_cfg))
        '''
        tb_logger.attach(trainer,
                         log_handler=OutputHandler(tag="training", output_transform=lambda x: {'batchloss': x[0]}),
                         event_name=Events.ITERATION_COMPLETED)
    
        tb_logger.attach(val_evaluator,
                         log_handler=OutputHandler(tag="validation", output_transform=lambda x: {'batchloss': x[0]}),
                         event_name=Events.ITERATION_COMPLETED)
        '''
        tb_logger.attach(trainer,
                         log_handler=OutputHandler(tag="training",
                                                   metric_names=[
                                                       "batchloss",
                                                       "batchloss_ent",
                                                       "batchloss_emo"
                                                   ]),
                         event_name=Events.ITERATION_COMPLETED)

        tb_logger.attach(val_evaluator,
                         log_handler=OutputHandler(tag="validation",
                                                   metric_names=[
                                                       "batchloss",
                                                       "batchloss_ent",
                                                       "batchloss_emo"
                                                   ]),
                         event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer,
                         log_handler=OutputHandler(tag="training",
                                                   metric_names=[
                                                       "total_loss",
                                                       "ent_loss", "emo_loss"
                                                   ]),
                         event_name=Events.EPOCH_COMPLETED)
        # tb_logger.attach(trainer,
        #                  log_handler=OutputHandler(tag="training", output_transform=lambda x: {'loss': x[0]}),
        #                  event_name=Events.EPOCH_COMPLETED)
        '''
        tb_logger.attach(trn_evaluator,
                         log_handler=OutputHandler(tag="training",
                                                   metric_names=["F1"],
                                                   another_engine=trainer),
                         event_name=Events.EPOCH_COMPLETED)
    
        '''
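        # another_engine=trainer makes the validation points below use the trainer's
        # global step, so training and validation curves share the same x-axis.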
        tb_logger.attach(val_evaluator,
                         log_handler=OutputHandler(tag="validation",
                                                   metric_names=[
                                                       "total_loss",
                                                       "ent_loss", "emo_loss",
                                                       "F1"
                                                   ],
                                                   another_engine=trainer),
                         event_name=Events.EPOCH_COMPLETED)

        if args.eval_step > 0:
            tb_logger.attach(val_evaluator_iteration,
                             log_handler=OutputHandler(
                                 tag="validation_iteration",
                                 metric_names=["F1"],
                                 another_engine=trainer),
                             event_name=Events.EPOCH_COMPLETED)

        tb_logger.attach(trainer,
                         log_handler=OptimizerParamsHandler(optimizer, "lr"),
                         event_name=Events.EPOCH_COMPLETED)
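        # The optimizer's current learning rate is logged to TensorBoard once per epoch.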
        '''
    
        tb_logger.attach(trainer,
                         log_handler=WeightsScalarHandler(model),
                         event_name=Events.ITERATION_COMPLETED)
    
        tb_logger.attach(trainer,
                         log_handler=WeightsHistHandler(model),
                         event_name=Events.EPOCH_COMPLETED)
    
        # tb_logger.attach(trainer,
        #                  log_handler=GradsScalarHandler(model),
        #                  event_name=Events.ITERATION_COMPLETED)
    
        tb_logger.attach(trainer,
                         log_handler=GradsHistHandler(model),
                         event_name=Events.EPOCH_COMPLETED)
        '''
        trainer.run(trn_dataloader, max_epochs=args.epochs)
        tb_logger.close()
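
# A minimal sketch (not part of the original example) of the argparse-style fields
# this loop reads; the values below are hypothetical placeholders, and the real
# script is assumed to build `args` (plus `alpha`, `alphas`, `iterations`, the
# criteria and the dataloaders) elsewhere in this Beispiel.
#
# from types import SimpleNamespace
# args = SimpleNamespace(
#     epochs=5,                 # trainer.run(..., max_epochs=args.epochs)
#     smooth_beta=0.1,          # > 0 enables smoothing against the previous batch loss
#     multi=False,              # False: fixed `alpha` weighting, True: per-epoch `alphas`
#     wc=False,                 # True: task-specific criteria criterion_ent/criterion_emo
#     eval_step=0,              # > 0: validate every eval_step iterations
#     eval_radio=0.0,           # > 0: validate every int(iterations * eval_radio) + 1 iterations
#     lbl_method='BIO',         # label scheme handed to FScore (hypothetical value)
#     lite=False,               # switches between the "full" and "lite" output sub-directories
#     hyper_cfg='default',      # sub-directory name for this hyper-parameter configuration
#     checkpoint_model_dir='./checkpoints',
#     log_dir='./runs',
# )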