Example #1
def test_disabled_n_saved(dirname):

    h = ModelCheckpoint(dirname, _PREFIX, create_dir=False, n_saved=None)
    engine = Engine(lambda e, b: None)
    engine.state = State(epoch=0, iteration=0)

    model = DummyModel()
    to_save = {"model": model}

    num_iters = 100
    for i in range(num_iters):
        engine.state.iteration = i
        h(engine, to_save)

    saved_files = sorted(os.listdir(dirname))
    assert len(saved_files) == num_iters, "{}".format(saved_files)

    expected = sorted(
        ["{}_{}_{}.pth".format(_PREFIX, "model", i) for i in range(num_iters)])
    assert saved_files == expected, "{} vs {}".format(saved_files, expected)
Example #2
def test_removes_each_score_at_most_once(dirname):
    scores = [0, 1, 1, 2, 3]
    scores_iter = iter(scores)

    def score_function(_):
        return next(scores_iter)

    h = ModelCheckpoint(dirname,
                        _PREFIX,
                        create_dir=False,
                        n_saved=2,
                        score_function=score_function)

    engine = Engine(lambda e, b: None)
    engine.state = State(epoch=0, iteration=0)

    model = DummyModel()
    to_save = {"model": model}
    for _ in range(len(scores)):
        h(engine, to_save)
Example #3
def save_model(self, model, save_interval=None, n_saved=1):
    """Extension method for saving a model.

    This method saves the model as a PyTorch checkpoint (.pth) under
    `self.res_dir / 'model'`, using the model's class name as the
    filename prefix.

    Args:
        model (torch.nn.Module): model to save.
        save_interval (int): number of epochs between checkpoints.
        n_saved (int): number of checkpoints to keep on disk. Older
            files will be removed. If set to None, all checkpoints are kept.
    """
    if isinstance(model, torch.nn.DataParallel):
        model = model.module
    save_handler = ModelCheckpoint(self.res_dir / 'model',
                                   model.__class__.__name__,
                                   save_interval=save_interval,
                                   n_saved=n_saved)
    self.trainer.add_event_handler(Events.EPOCH_COMPLETED, save_handler,
                                   {'epoch': model})
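For orientation, the wiring this extension method sets up is roughly equivalent to the following self-contained sketch. The result directory, the dummy trainer, and the Linear model below are assumptions for illustration only, not part of the original code.

from pathlib import Path

import torch.nn as nn
from ignite.engine import Engine, Events
from ignite.handlers import ModelCheckpoint

# Assumed stand-ins for self.res_dir and self.trainer from the method above.
res_dir = Path("results")
trainer = Engine(lambda engine, batch: None)
model = nn.Linear(4, 2)

# Checkpoint the model at the end of every epoch under res_dir / 'model',
# keeping only the most recent file on disk. Note that `save_interval`
# (used above) only exists in older ignite releases; recent versions
# filter the event instead, e.g. Events.EPOCH_COMPLETED(every=5).
save_handler = ModelCheckpoint(str(res_dir / "model"),
                               model.__class__.__name__,
                               n_saved=1,
                               require_empty=False)
trainer.add_event_handler(Events.EPOCH_COMPLETED, save_handler, {"epoch": model})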
Example #4
def test_best_k(dirname):
    scores = iter([1.0, -2., 3.0, -4.0])

    def score_function(engine):
        return next(scores)

    h = ModelCheckpoint(dirname,
                        _PREFIX,
                        create_dir=False,
                        n_saved=2,
                        score_function=score_function,
                        save_as_state_dict=False)

    to_save = {'name': 42}
    for _ in range(4):
        h(None, to_save)

    expected = ['{}_{}_{}.pth'.format(_PREFIX, 'name', i) for i in [1, 3]]

    assert sorted(os.listdir(dirname)) == expected
Example #5
def _test_tpu_saves_to_cpu(device, dirname):
    torch.manual_seed(0)

    h = ModelCheckpoint(dirname, _PREFIX)
    engine = Engine(lambda e, b: None)
    engine.state = State(epoch=0, iteration=1)

    model = DummyModel().to(device)
    to_save = {"model": model}

    h(engine, to_save)

    idist.barrier()

    fname = h.last_checkpoint
    assert isinstance(fname, str)
    assert os.path.join(dirname, _PREFIX) in fname
    assert os.path.exists(fname)
    loaded_objects = torch.load(fname)
    assert loaded_objects == model.cpu().state_dict()
Example #6
def warp_common_handler(engine, option, networks_to_save, monitoring_metrics,
                        add_message, use_folder_pathes):
    # attach progress bar
    pbar = ProgressBar()
    pbar.attach(engine, metric_names=monitoring_metrics)
    timer = Timer(average=True)
    timer.attach(engine,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)
    create_plots = make_handle_create_plots(option.output_dir, LOGS_FNAME,
                                            PLOT_FNAME)
    checkpoint_handler = ModelCheckpoint(option.output_dir,
                                         CKPT_PREFIX,
                                         save_interval=option.save_interval,
                                         n_saved=option.n_saved,
                                         require_empty=False,
                                         create_dir=True,
                                         save_as_state_dict=True)

    engine.add_event_handler(Events.ITERATION_COMPLETED,
                             checkpoint_handler,
                             to_save=networks_to_save)
    engine.add_event_handler(Events.ITERATION_COMPLETED, create_plots)
    engine.add_event_handler(
        Events.EXCEPTION_RAISED,
        make_handle_handle_exception(checkpoint_handler, networks_to_save,
                                     create_plots))
    engine.add_event_handler(
        Events.STARTED,
        make_handle_make_dirs(option.output_dir, use_folder_pathes))
    engine.add_event_handler(Events.STARTED, make_move_html(option.output_dir))
    engine.add_event_handler(Events.STARTED, make_create_option_data(option))
    engine.add_event_handler(Events.EPOCH_COMPLETED,
                             make_handle_print_times(timer, pbar))
    engine.add_event_handler(
        Events.ITERATION_COMPLETED,
        make_handle_print_logs(option.output_dir, option.epochs,
                               option.print_freq, pbar, add_message))
    return engine
Example #7
def save_best_model_by_val_score(output_path,
                                 evaluator,
                                 model,
                                 metric_name,
                                 n_saved=3,
                                 trainer=None,
                                 tag="val"):
    """Method adds a handler to `evaluator` to save best models based on the score (named by `metric_name`)
    provided by `evaluator`.

    Args:
        output_path (str): output path to indicate where to save best models
        evaluator (Engine): evaluation engine used to provide the score
        model (nn.Module): model to store
        metric_name (str): metric name to use for score evaluation. This metric should be present in
            `evaluator.state.metrics`.
        n_saved (int, optional): number of best models to store
        trainer (Engine, optional): trainer engine to fetch the epoch when saving the best model.
        tag (str, optional): score name prefix: `{tag}_{metric_name}`. By default, tag is "val".

    Returns:
        A :class:`~ignite.handlers.checkpoint.ModelCheckpoint` handler.
    """
    global_step_transform = None
    if trainer is not None:
        global_step_transform = global_step_from_engine(trainer)

    best_model_handler = ModelCheckpoint(
        dirname=output_path,
        filename_prefix="best",
        n_saved=n_saved,
        global_step_transform=global_step_transform,
        score_name="{}_{}".format(tag, metric_name.lower()),
        score_function=get_default_score_fn(metric_name),
    )
    evaluator.add_event_handler(Events.COMPLETED, best_model_handler, {
        "model": model,
    })

    return best_model_handler
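A minimal sketch of how this helper might be wired up, assuming the function above is in scope. The dummy trainer/evaluator, the Linear model, and the output directory are illustrative stand-ins; in a real setup the "accuracy" metric would have to be attached to the evaluator.

import torch.nn as nn
from ignite.engine import Engine

# Illustrative stand-ins; in practice these come from
# create_supervised_trainer / create_supervised_evaluator.
trainer = Engine(lambda engine, batch: None)
evaluator = Engine(lambda engine, batch: None)
model = nn.Linear(10, 2)

best_handler = save_best_model_by_val_score(
    output_path="/tmp/best_models",   # assumed output directory
    evaluator=evaluator,
    model=model,
    metric_name="accuracy",           # must be present in evaluator.state.metrics
    n_saved=3,
    trainer=trainer,                  # checkpoints are stamped with the trainer epoch
    tag="val",
)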
Example #8
def test_best_k_with_suffix(dirname):
    scores = [0.3456789, 0.1234, 0.4567, 0.134567]
    scores_iter = iter(scores)

    def score_function(engine):
        return next(scores_iter)

    h = ModelCheckpoint(dirname, _PREFIX, create_dir=False, n_saved=2,
                        score_function=score_function, score_name="val_loss")

    engine = Engine(lambda e, b: None)
    engine.state = State(epoch=0, iteration=0)

    model = DummyModel()
    to_save = {'model': model}
    for _ in range(4):
        engine.state.epoch += 1
        h(engine, to_save)

    expected = ['{}_{}_val_loss={:.4}.pth'.format(_PREFIX, 'model', scores[e - 1]) for e in [1, 3]]

    assert sorted(os.listdir(dirname)) == expected
Example #9
    def _test(ext, require_empty, archived):
        previous_fname = os.path.join(dirname, '{}_{}_{}{}'.format(_PREFIX, 'obj', 1, ext))
        with open(previous_fname, 'w') as f:
            f.write("test")

        h = ModelCheckpoint(dirname, _PREFIX, create_dir=True, require_empty=require_empty, archived=archived)
        engine = Engine(lambda e, b: None)
        engine.state = State(epoch=0, iteration=1)

        model = DummyModel()
        to_save = {'model': model}
        h(engine, to_save)

        fname = h.last_checkpoint
        ext = ".pth.tar" if archived else ".pth"
        assert isinstance(fname, str)
        assert os.path.join(dirname, '{}_{}_{}{}'.format(_PREFIX, 'model', 1, ext)) == fname
        assert os.path.exists(fname)
        assert os.path.exists(previous_fname)
        loaded_objects = torch.load(fname)
        assert loaded_objects == model.state_dict()
        os.remove(fname)
Example #10
def test_best_k(dirname):
    scores = iter([1.2, -2.0, 3.1, -4.0])

    def score_function(_):
        return next(scores)

    h = ModelCheckpoint(dirname,
                        _PREFIX,
                        create_dir=False,
                        n_saved=2,
                        score_function=score_function)

    engine = Engine(lambda e, b: None)
    engine.state = State(epoch=0, iteration=0)

    model = DummyModel()
    to_save = {"model": model}
    for _ in range(4):
        h(engine, to_save)

    expected = ["{}_{}_{}.pth".format(_PREFIX, "model", i) for i in [1.2, 3.1]]

    assert sorted(os.listdir(dirname)) == expected
Example #11
def test_simple_recovery_from_existing_non_empty(dirname):
    previous_fname = os.path.join(dirname,
                                  '{}_{}_{}.pth'.format(_PREFIX, 'obj', 1))
    with open(previous_fname, 'w') as f:
        f.write("test")

    h = ModelCheckpoint(dirname, _PREFIX, create_dir=True, require_empty=False)
    engine = Engine(lambda e, b: None)
    engine.state = State(epoch=0, iteration=1)

    model = DummyModel()
    to_save = {'model': model}
    h(engine, to_save)

    fname = h.last_checkpoint
    assert isinstance(fname, str)
    assert os.path.join(dirname, '{}_{}_{}.pth'.format(_PREFIX, 'model',
                                                       1)) == fname
    assert os.path.exists(fname)
    assert os.path.exists(previous_fname)
    loaded_objects = torch.load(fname)
    assert "model" in loaded_objects
    assert loaded_objects['model'] == model.state_dict()
Example #12
def train():
    set_seed(train_param.seed)
    model = Model(model_param)
    optimizer = AdamW(model.parameters(), lr=train_param.lr, eps=1e-8)
    update_steps = train_param.epoch * len(train_loader)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=0, num_training_steps=update_steps)
    loss_fn = [translate, MSELoss()]
    device = torch.device(f'cuda:{train_param.device}')
    trainer = create_trainer(model, optimizer, scheduler, loss_fn,
                             train_param.grad_norm, device)
    train_evaluator = create_evaluator(model, metric, device)
    dev_evaluator = create_evaluator(model, metric, device)
    trainer.add_event_handler(
        Events.ITERATION_COMPLETED(every=train_param.interval),
        log_training_loss)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, log_results,
                              *(train_evaluator, train_loader, 'Train'))
    trainer.add_event_handler(Events.EPOCH_COMPLETED, log_results,
                              *(dev_evaluator, dev_loader, 'Dev'))
    es_handler = EarlyStopping(patience=train_param.patience,
                               score_function=score_fn,
                               trainer=trainer)
    dev_evaluator.add_event_handler(Events.COMPLETED, es_handler)
    ckpt_handler = ModelCheckpoint(train_param.save_path,
                                   '',
                                   score_function=score_fn,
                                   score_name='score',
                                   require_empty=False)
    dev_evaluator.add_event_handler(Events.COMPLETED, ckpt_handler, {
        'model': model,
        'param': model_param
    })
    print(
        f'Start running {train_param.save_path.split("/")[-1]} at device: {train_param.device}\t'
        f'lr: {train_param.lr}')
    trainer.run(train_loader, max_epochs=train_param.epoch)
Example #13
def test_best_k_with_suffix(dirname):
    scores = [0.3456789, 0.1234, 0.4567, 0.134567]
    scores_iter = iter(scores)

    def score_function(engine):
        return next(scores_iter)

    h = ModelCheckpoint(dirname,
                        _PREFIX,
                        create_dir=False,
                        n_saved=2,
                        score_function=score_function,
                        score_name="val_loss")

    to_save = {'name': 42}
    for _ in range(4):
        h(None, to_save)

    expected = [
        '{}_{}_{}_val_loss={:.7}.pth'.format(_PREFIX, 'name', i, scores[i - 1])
        for i in [1, 3]
    ]

    assert sorted(os.listdir(dirname)) == expected
Example #14
def run(train_batch_size, val_batch_size, epochs, lr, momentum, log_dir):
    train_loader, val_loader = get_data_loaders(train_batch_size,
                                                val_batch_size)
    model = Net()
    device = "cpu"

    if torch.cuda.is_available():
        device = "cuda"

    model.to(device)  # Move model before creating optimizer
    optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
    criterion = nn.CrossEntropyLoss()
    trainer = create_supervised_trainer(model,
                                        optimizer,
                                        criterion,
                                        device=device)
    trainer.logger = setup_logger("Trainer")

    if sys.version_info > (3, ):
        from ignite.contrib.metrics.gpu_info import GpuInfo

        try:
            GpuInfo().attach(trainer)
        except RuntimeError:
            print(
                "INFO: By default, in this example it is possible to log GPU information (used memory, utilization). "
                "As there is no pynvml python package installed, GPU information won't be logged. Otherwise, please "
                "install it : `pip install pynvml`")

    metrics = {"accuracy": Accuracy(), "loss": Loss(criterion)}

    train_evaluator = create_supervised_evaluator(model,
                                                  metrics=metrics,
                                                  device=device)
    train_evaluator.logger = setup_logger("Train Evaluator")
    validation_evaluator = create_supervised_evaluator(model,
                                                       metrics=metrics,
                                                       device=device)
    validation_evaluator.logger = setup_logger("Val Evaluator")

    @trainer.on(Events.EPOCH_COMPLETED)
    def compute_metrics(engine):
        train_evaluator.run(train_loader)
        validation_evaluator.run(val_loader)

    tb_logger = TensorboardLogger(log_dir=log_dir)

    tb_logger.attach_output_handler(
        trainer,
        event_name=Events.ITERATION_COMPLETED(every=100),
        tag="training",
        output_transform=lambda loss: {"batchloss": loss},
        metric_names="all",
    )

    for tag, evaluator in [("training", train_evaluator),
                           ("validation", validation_evaluator)]:
        tb_logger.attach_output_handler(
            evaluator,
            event_name=Events.EPOCH_COMPLETED,
            tag=tag,
            metric_names=["loss", "accuracy"],
            global_step_transform=global_step_from_engine(trainer),
        )

    tb_logger.attach_opt_params_handler(
        trainer,
        event_name=Events.ITERATION_COMPLETED(every=100),
        optimizer=optimizer)

    tb_logger.attach(trainer,
                     log_handler=WeightsScalarHandler(model),
                     event_name=Events.ITERATION_COMPLETED(every=100))

    tb_logger.attach(trainer,
                     log_handler=WeightsHistHandler(model),
                     event_name=Events.EPOCH_COMPLETED(every=100))

    tb_logger.attach(trainer,
                     log_handler=GradsScalarHandler(model),
                     event_name=Events.ITERATION_COMPLETED(every=100))

    tb_logger.attach(trainer,
                     log_handler=GradsHistHandler(model),
                     event_name=Events.EPOCH_COMPLETED(every=100))

    def score_function(engine):
        return engine.state.metrics["accuracy"]

    model_checkpoint = ModelCheckpoint(
        log_dir,
        n_saved=2,
        filename_prefix="best",
        score_function=score_function,
        score_name="validation_accuracy",
        global_step_transform=global_step_from_engine(trainer),
    )
    validation_evaluator.add_event_handler(Events.COMPLETED, model_checkpoint,
                                           {"model": model})

    # kick everything off
    trainer.run(train_loader, max_epochs=epochs)

    tb_logger.close()
Example #15
def train():
    config_file = "configs/train_daily_dialog_emotion_action_config.json"
    config = Config.from_json_file(config_file)

    # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process. logger.info => log main process only, logger.warning => log all processes
    logging.basicConfig(
        level=logging.INFO if config.local_rank in [-1, 0] else logging.WARN)
    logger.warning(
        "Running process %d", config.local_rank
    )  # This is a logger.warning: it will be printed by all distributed processes
    logger.info("Arguments: %s", pformat(config))

    # Initialize distributed training if needed
    config.distributed = (config.local_rank != -1)
    if config.distributed:
        torch.cuda.set_device(config.local_rank)
        config.device = torch.device("cuda", config.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')

    logger.info(
        "Prepare tokenizer, pretrained model and optimizer - add special tokens for fine-tuning"
    )
    tokenizer_class = GPT2Tokenizer if "gpt2" in config.model_checkpoint else OpenAIGPTTokenizer
    tokenizer = tokenizer_class.from_pretrained(config.model_checkpoint)
    model_class = GPT2DoubleHeadsModel if "gpt2" in config.model_checkpoint else OpenAIGPTDoubleHeadsModel
    model = model_class.from_pretrained(config.model_checkpoint)
    tokenizer.set_special_tokens(SPECIAL_TOKENS)
    model.set_num_special_tokens(len(SPECIAL_TOKENS))
    model.to(config.device)
    optimizer = OpenAIAdam(model.parameters(), lr=config.lr)

    # Prepare model for FP16 and distributed training if needed (order is important, distributed should be the last)
    if config.fp16:
        from apex import amp  # Apex is only required if we use fp16 training
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=config.fp16)
    if config.distributed:
        model = DistributedDataParallel(model,
                                        device_ids=[config.local_rank],
                                        output_device=config.local_rank)

    logger.info("Prepare datasets")
    train_loader, val_loader, train_sampler, valid_sampler = get_data_loaders(
        config, tokenizer)

    # Training function and trainer
    def update(engine, batch):
        model.train()
        input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids, token_emotion_ids, token_action_ids = tuple(
            input_tensor.to(config.device) for input_tensor in batch)
        lm_loss, mc_loss = model(input_ids, mc_token_ids, lm_labels, mc_labels,
                                 token_type_ids, token_emotion_ids,
                                 token_action_ids)
        loss = (lm_loss * config.lm_coef +
                mc_loss * config.mc_coef) / config.gradient_accumulation_steps
        if config.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                           config.max_norm)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_norm)
        if engine.state.iteration % config.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        return loss.item()

    trainer = Engine(update)

    # Evaluation function and evaluator (evaluator output is the input of the metrics)
    def inference(engine, batch):
        model.eval()
        with torch.no_grad():
            batch = tuple(
                input_tensor.to(config.device) for input_tensor in batch)
            input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids, token_emotion_ids, token_action_ids = batch
            #logger.info(tokenizer.decode(input_ids[0, -1, :].tolist()))
            model_outputs = model(input_ids,
                                  mc_token_ids,
                                  token_type_ids=token_type_ids,
                                  token_emotion_ids=token_emotion_ids,
                                  token_action_ids=token_action_ids)
            lm_logits, mc_logits = model_outputs[0], model_outputs[
                1]  # So we can also use GPT2 outputs
            lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view(
                -1, lm_logits.size(-1))
            lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1)
            return (lm_logits_flat_shifted,
                    mc_logits), (lm_labels_flat_shifted, mc_labels)

    evaluator = Engine(inference)

    # Attach evaluation to trainer: we evaluate when we start the training and at the end of each epoch
    trainer.add_event_handler(Events.EPOCH_COMPLETED,
                              lambda _: evaluator.run(val_loader))
    if config.n_epochs < 1:
        trainer.add_event_handler(Events.COMPLETED,
                                  lambda _: evaluator.run(val_loader))
    if config.eval_before_start:
        trainer.add_event_handler(Events.STARTED,
                                  lambda _: evaluator.run(val_loader))

    # Make sure distributed data samplers split the dataset nicely between the distributed processes
    if config.distributed:
        trainer.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: train_sampler.set_epoch(engine.state.epoch))
        evaluator.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: valid_sampler.set_epoch(engine.state.epoch))

    # Linearly decrease the learning rate from lr to zero
    scheduler = PiecewiseLinear(optimizer, "lr",
                                [(0, config.lr),
                                 (config.n_epochs * len(train_loader), 0.0)])
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # Prepare metrics - note how we compute distributed metrics
    RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
    metrics = {
        "nll":
        Loss(torch.nn.CrossEntropyLoss(ignore_index=-1),
             output_transform=lambda x: (x[0][0], x[1][0])),
        "accuracy":
        Accuracy(output_transform=lambda x: (x[0][1], x[1][1]))
    }
    metrics.update({
        "average_nll":
        MetricsLambda(average_distributed_scalar, metrics["nll"], config),
        "average_accuracy":
        MetricsLambda(average_distributed_scalar, metrics["accuracy"], config)
    })
    metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"])
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # On the main process: add progress bar, tensorboard, checkpoints and save model, configuration and tokenizer before we start to train
    if config.local_rank in [-1, 0]:
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer, metric_names=["loss"])
        evaluator.add_event_handler(
            Events.COMPLETED, lambda _: pbar.log_message(
                "Validation: %s" % pformat(evaluator.state.metrics)))

        tb_logger = TensorboardLogger(log_dir=config.log_dir)
        tb_logger.attach(trainer,
                         log_handler=OutputHandler(tag="training",
                                                   metric_names=["loss"]),
                         event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer,
                         log_handler=OptimizerParamsHandler(optimizer),
                         event_name=Events.ITERATION_STARTED)
        tb_logger.attach(evaluator,
                         log_handler=OutputHandler(tag="validation",
                                                   metric_names=list(
                                                       metrics.keys()),
                                                   another_engine=trainer),
                         event_name=Events.EPOCH_COMPLETED)

        checkpoint_handler = ModelCheckpoint(tb_logger.writer.log_dir,
                                             'checkpoint',
                                             save_interval=1,
                                             n_saved=3)
        trainer.add_event_handler(
            Events.EPOCH_COMPLETED, checkpoint_handler,
            {'mymodel': getattr(model, 'module', model)
             })  # "getattr" take care of distributed encapsulation

        torch.save(config,
                   tb_logger.writer.log_dir + '/model_training_args.bin')
        getattr(model, 'module', model).config.to_json_file(
            os.path.join(tb_logger.writer.log_dir, CONFIG_NAME))
        tokenizer.save_vocabulary(tb_logger.writer.log_dir)

    # Run the training
    trainer.run(train_loader, max_epochs=config.n_epochs)

    # On the main process: close tensorboard logger and rename the last checkpoint (for easy re-loading with OpenAIGPTModel.from_pretrained method)
    if config.local_rank in [-1, 0] and config.n_epochs > 0:
        os.rename(
            checkpoint_handler._saved[-1][1][-1],
            os.path.join(tb_logger.writer.log_dir, WEIGHTS_NAME)
        )  # TODO: PR in ignite to have better access to saved file paths (cleaner)
        tb_logger.close()
Example #16
def train(model,
          train_dl,
          val_dl,
          test_dl,
          loss_str,
          optimizer_name,
          lr,
          max_epochs,
          metrics,
          val_metric_to_monitor,
          print_freq,
          epoch_per_metric,
          plateau_patience,
          plateau_terminate,
          gpu_if_available,
          gpu_idx,
          custom_metrics=None,
          save_dir=None):
    """Simple model training framework setup with ignite.

    This builds and runs a standard training process using the ignite framework. Given train/val/test dataloaders,
    attaches specified metrics and runs over the training set with LR scheduling, early stopping, and model
    check-pointing all built in.

    Args:
        model (nn.Module): A network built in standard PyTorch.
        train_dl (DataLoader): Train data.
        val_dl (DataLoader): Val data.
        test_dl (DataLoader): Test data.
        optimizer_name (str): Name of the optimizer to use.
        lr (float): The initial value of the learning rate.
        loss_str (str): Name of the loss function; resolved by set_loss.
        max_epochs (int): Max epochs to run the algorithm for.
        metrics (list): A list of metric strings to be monitored.
        val_metric_to_monitor (str): The metric to monitor for LR scheduling and early stopping.
        print_freq (int): Frequency of printing train/val results to console.
        epoch_per_metric (int): Number of epochs before next computation of val metrics.
        plateau_patience (int): Number of epochs with no improvement before LR reduction.
        plateau_terminate (int): Number of epochs with no improvement before stopping.
        gpu_if_available (bool): Run on the gpu if one exists.
        gpu_idx (int): The index of the gpu to run on.
        custom_metrics (dict): Dictionary of custom metrics.
        save_dir (str): Location to save the model checkpoints.

    Returns:
        (model:nn.Module, results:dict, validation_history:dict): The trained model, the results of the best model, and the full training history.
    """
    device = set_device(gpu_if_available, gpu_idx=gpu_idx)
    loss_fn = set_loss(loss_str)
    lr = set_lr(train_dl) if lr is None else lr
    optimizer = setup_optimizer(model, optimizer_name, lr)

    # Choose metrics given the string list
    binary = isinstance(loss_fn, torch.nn.BCEWithLogitsLoss)
    metrics, train_metrics, val_metrics = setup_metrics(
        metrics, loss_fn, binary=binary, custom_metrics=custom_metrics)

    # Build engines
    trainer_output_tfm = lambda x, y, y_pred, loss: (loss.item(), y, y_pred)
    trainer = create_supervised_trainer(model,
                                        optimizer,
                                        loss_fn,
                                        device=device,
                                        output_transform=trainer_output_tfm)
    evaluator = create_supervised_evaluator(model,
                                            device=device,
                                            metrics=val_metrics)

    # Attach running average metrics to trainer
    for name, metric in train_metrics.items():
        metric.attach(trainer, name)

    # Progress bar
    pbar = tqdm(range(max_epochs))

    # Validation loop
    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_metrics(engine):
        epoch = engine.state.epoch
        pbar.update(1)

        if (epoch % epoch_per_metric == 0) or (epoch == 0):
            evaluator.run(val_dl, max_epochs=1)

            add_metrics_to_dict(trainer.state.metrics, validation_history,
                                '.train')
            add_metrics_to_dict(evaluator.state.metrics, validation_history,
                                '.val')

            if (epoch % print_freq == 0) or (epoch == 0):
                print_val_results(epoch, validation_history, pbar=pbar)

    # Score to monitor for early stopping and check-pointing
    sign = -1 if val_metric_to_monitor == 'loss' else 1
    score_function = lambda engine: sign * engine.state.metrics[val_metric_to_monitor]

    # LR scheduling (monitors validation loss), early stopping and check-pointing
    scheduler = ReduceLROnPlateau(optimizer,
                                  patience=plateau_patience,
                                  threshold=1e-6,
                                  min_lr=1e-7)
    evaluator.add_event_handler(
        Events.EPOCH_COMPLETED,
        lambda engine: scheduler.step(engine.state.metrics['loss']))

    # Early stopping
    stopping = EarlyStopping(patience=plateau_terminate,
                             score_function=score_function,
                             trainer=trainer)
    evaluator.add_event_handler(Events.EPOCH_COMPLETED, stopping)

    # Checkpoint
    save_best_model = ModelCheckpoint(save_dir,
                                      '',
                                      score_function=score_function)
    evaluator.add_event_handler(Events.EPOCH_COMPLETED, save_best_model,
                                {'best_model': model})

    # History
    validation_history = OrderedDict()
    for phase in ('train', 'val'):
        for name in metrics:
            validation_history[name + '.' + phase] = []

    # Train the model
    start, start_memory = time.time(), get_memory(device, reset=True)
    trainer.run(train_dl, max_epochs=max_epochs)
    elapsed = time.time() - start
    memory_usage = get_memory(device) - start_memory

    # Score on test
    model.load_state_dict(torch.load(save_best_model.last_checkpoint))
    evaluator.run(test_dl, max_epochs=1)

    # Final model results
    results = OrderedDict(**{
        'elapsed_time': elapsed,
        'memory_usage': memory_usage
    })

    # Best metric/value
    func = np.argmax if sign == 1 else np.argmin
    best_idx = func(validation_history[val_metric_to_monitor + '.val'])
    for key, value in validation_history.items():
        results[key] = value[best_idx]

    for metric, value in evaluator.state.metrics.items():
        results[metric + '.test'] = value

    print_final_results(results)

    return model, results, validation_history
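For completeness, a hypothetical call of this training helper. Every name below (the network, the data loaders, the loss and metric strings, the save directory) is an assumption for illustration and would need to match the project's own set_loss / setup_metrics conventions.

# Hypothetical invocation; net, train_dl, val_dl and test_dl are assumed to exist.
model, results, validation_history = train(
    net, train_dl, val_dl, test_dl,
    loss_str='bce',                   # assumed name, resolved by set_loss()
    optimizer_name='adam',
    lr=1e-3,
    max_epochs=100,
    metrics=['loss', 'acc'],          # assumed metric strings for setup_metrics()
    val_metric_to_monitor='loss',
    print_freq=5,
    epoch_per_metric=1,
    plateau_patience=10,
    plateau_terminate=30,
    gpu_if_available=True,
    gpu_idx=0,
    save_dir='checkpoints',
)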
Example #17
def main(dataset, dataroot, download, augment, batch_size, eval_batch_size,
         epochs, saved_model, seed, hidden_channels, K, L, actnorm_scale,
         flow_permutation, flow_coupling, LU_decomposed, learn_top,
         y_condition, y_weight, max_grad_clip, max_grad_norm, lr, n_workers,
         cuda, n_init_batches, warmup_steps, output_dir, saved_optimizer,
         warmup, fresh, logittransform, gan, disc_lr, sn, flowgan, eval_every,
         ld_on_samples, weight_gan, weight_prior, weight_logdet,
         jac_reg_lambda, affine_eps, no_warm_up, optim_name, clamp, svd_every,
         eval_only, no_actnorm, affine_scale_eps, actnorm_max_scale,
         no_conv_actnorm, affine_max_scale, actnorm_eps, init_sample, no_split,
         disc_arch, weight_entropy_reg, db):

    check_manual_seed(seed)

    ds = check_dataset(dataset, dataroot, augment, download)
    image_shape, num_classes, train_dataset, test_dataset = ds

    # Note: unsupported for now
    multi_class = False

    train_loader = data.DataLoader(train_dataset,
                                   batch_size=batch_size,
                                   shuffle=True,
                                   num_workers=n_workers,
                                   drop_last=True)
    test_loader = data.DataLoader(test_dataset,
                                  batch_size=eval_batch_size,
                                  shuffle=False,
                                  num_workers=n_workers,
                                  drop_last=False)
    model = Glow(image_shape, hidden_channels, K, L, actnorm_scale,
                 flow_permutation, flow_coupling, LU_decomposed, num_classes,
                 learn_top, y_condition, logittransform, sn, affine_eps,
                 no_actnorm, affine_scale_eps, actnorm_max_scale,
                 no_conv_actnorm, affine_max_scale, actnorm_eps, no_split)

    model = model.to(device)

    if disc_arch == 'mine':
        discriminator = mine.Discriminator(image_shape[-1])
    elif disc_arch == 'biggan':
        discriminator = cgan_models.Discriminator(
            image_channels=image_shape[-1], conditional_D=False)
    elif disc_arch == 'dcgan':
        discriminator = DCGANDiscriminator(image_shape[0], 64, image_shape[-1])
    elif disc_arch == 'inv':
        discriminator = InvDiscriminator(
            image_shape, hidden_channels, K, L, actnorm_scale,
            flow_permutation, flow_coupling, LU_decomposed, num_classes,
            learn_top, y_condition, logittransform, sn, affine_eps, no_actnorm,
            affine_scale_eps, actnorm_max_scale, no_conv_actnorm,
            affine_max_scale, actnorm_eps, no_split)

    discriminator = discriminator.to(device)
    D_optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                    discriminator.parameters()),
                             lr=disc_lr,
                             betas=(.5, .99),
                             weight_decay=0)
    if optim_name == 'adam':
        optimizer = optim.Adam(model.parameters(),
                               lr=lr,
                               betas=(.5, .99),
                               weight_decay=0)
    elif optim_name == 'adamax':
        optimizer = optim.Adamax(model.parameters(), lr=lr, weight_decay=5e-5)

    if not no_warm_up:
        lr_lambda = lambda epoch: min(1.0, (epoch + 1) / warmup)
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                      lr_lambda=lr_lambda)

    iteration_fieldnames = [
        'global_iteration', 'fid', 'sample_pad', 'train_bpd', 'eval_bpd',
        'pad', 'batch_real_acc', 'batch_fake_acc', 'batch_acc'
    ]
    iteration_logger = CSVLogger(fieldnames=iteration_fieldnames,
                                 filename=os.path.join(output_dir,
                                                       'iteration_log.csv'))
    iteration_fieldnames = [
        'global_iteration', 'condition_num', 'max_sv', 'min_sv',
        'inverse_condition_num', 'inverse_max_sv', 'inverse_min_sv'
    ]
    svd_logger = CSVLogger(fieldnames=iteration_fieldnames,
                           filename=os.path.join(output_dir, 'svd_log.csv'))

    #
    test_iter = iter(test_loader)
    N_inception = 1000
    x_real_inception = torch.cat([
        next(test_iter)[0].to(device)
        for _ in range(N_inception // args.batch_size + 1)
    ], 0)[:N_inception]
    x_real_inception = x_real_inception + .5
    x_for_recon = next(test_iter)[0].to(device)

    def gan_step(engine, batch):
        assert not y_condition
        if 'iter_ind' in dir(engine):
            engine.iter_ind += 1
        else:
            engine.iter_ind = -1
        losses = {}
        model.train()
        discriminator.train()

        x, y = batch
        x = x.to(device)

        def run_noised_disc(discriminator, x):
            x = uniform_binning_correction(x)[0]
            return discriminator(x)

        real_acc = fake_acc = acc = 0
        if weight_gan > 0:
            fake = generate_from_noise(model, x.size(0), clamp=clamp)

            D_real_scores = run_noised_disc(discriminator, x.detach())
            D_fake_scores = run_noised_disc(discriminator, fake.detach())

            ones_target = torch.ones((x.size(0), 1), device=x.device)
            zeros_target = torch.zeros((x.size(0), 1), device=x.device)

            D_real_accuracy = torch.sum(
                torch.round(F.sigmoid(D_real_scores)) ==
                ones_target).float() / ones_target.size(0)
            D_fake_accuracy = torch.sum(
                torch.round(F.sigmoid(D_fake_scores)) ==
                zeros_target).float() / zeros_target.size(0)

            D_real_loss = F.binary_cross_entropy_with_logits(
                D_real_scores, ones_target)
            D_fake_loss = F.binary_cross_entropy_with_logits(
                D_fake_scores, zeros_target)

            D_loss = (D_real_loss + D_fake_loss) / 2
            gp = gradient_penalty(
                x.detach(), fake.detach(),
                lambda _x: run_noised_disc(discriminator, _x))
            D_loss_plus_gp = D_loss + 10 * gp
            D_optimizer.zero_grad()
            D_loss_plus_gp.backward()
            D_optimizer.step()

            # Train generator
            fake = generate_from_noise(model,
                                       x.size(0),
                                       clamp=clamp,
                                       guard_nans=False)
            G_loss = F.binary_cross_entropy_with_logits(
                run_noised_disc(discriminator, fake),
                torch.ones((x.size(0), 1), device=x.device))

            # Trace
            real_acc = D_real_accuracy.item()
            fake_acc = D_fake_accuracy.item()
            acc = .5 * (D_fake_accuracy.item() + D_real_accuracy.item())

        z, nll, y_logits, (prior, logdet) = model.forward(x,
                                                          None,
                                                          return_details=True)
        train_bpd = nll.mean().item()

        loss = 0
        if weight_gan > 0:
            loss = loss + weight_gan * G_loss
        if weight_prior > 0:
            loss = loss + weight_prior * -prior.mean()
        if weight_logdet > 0:
            loss = loss + weight_logdet * -logdet.mean()

        if weight_entropy_reg > 0:
            _, _, _, (sample_prior,
                      sample_logdet) = model.forward(fake,
                                                     None,
                                                     return_details=True)
            # notice this is actually "decreasing" sample likelihood.
            loss = loss + weight_entropy_reg * (sample_prior.mean() +
                                                sample_logdet.mean())
        # Jac Reg
        if jac_reg_lambda > 0:
            # Sample
            x_samples = generate_from_noise(model,
                                            args.batch_size,
                                            clamp=clamp).detach()
            x_samples.requires_grad_()
            z = model.forward(x_samples, None, return_details=True)[0]
            other_zs = torch.cat([
                split._last_z2.view(x.size(0), -1)
                for split in model.flow.splits
            ], -1)
            all_z = torch.cat([other_zs, z.view(x.size(0), -1)], -1)
            sample_foward_jac = compute_jacobian_regularizer(x_samples,
                                                             all_z,
                                                             n_proj=1)
            _, c2, h, w = model.prior_h.shape
            c = c2 // 2
            zshape = (batch_size, c, h, w)
            randz = torch.randn(zshape).to(device)
            randz = torch.autograd.Variable(randz, requires_grad=True)
            images = model(z=randz,
                           y_onehot=None,
                           temperature=1,
                           reverse=True,
                           batch_size=0)
            other_zs = [split._last_z2 for split in model.flow.splits]
            all_z = [randz] + other_zs
            sample_inverse_jac = compute_jacobian_regularizer_manyinputs(
                all_z, images, n_proj=1)

            # Data
            x.requires_grad_()
            z = model.forward(x, None, return_details=True)[0]
            other_zs = torch.cat([
                split._last_z2.view(x.size(0), -1)
                for split in model.flow.splits
            ], -1)
            all_z = torch.cat([other_zs, z.view(x.size(0), -1)], -1)
            data_foward_jac = compute_jacobian_regularizer(x, all_z, n_proj=1)
            _, c2, h, w = model.prior_h.shape
            c = c2 // 2
            zshape = (batch_size, c, h, w)
            z.requires_grad_()
            images = model(z=z,
                           y_onehot=None,
                           temperature=1,
                           reverse=True,
                           batch_size=0)
            other_zs = [split._last_z2 for split in model.flow.splits]
            all_z = [z] + other_zs
            data_inverse_jac = compute_jacobian_regularizer_manyinputs(
                all_z, images, n_proj=1)

            # loss = loss + jac_reg_lambda * (sample_foward_jac + sample_inverse_jac )
            loss = loss + jac_reg_lambda * (sample_foward_jac +
                                            sample_inverse_jac +
                                            data_foward_jac + data_inverse_jac)

        if not eval_only:
            optimizer.zero_grad()
            loss.backward()
            if not db:
                assert max_grad_clip == max_grad_norm == 0
            if max_grad_clip > 0:
                torch.nn.utils.clip_grad_value_(model.parameters(),
                                                max_grad_clip)
            if max_grad_norm > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               max_grad_norm)

            # Replace NaN gradient with 0
            for p in model.parameters():
                if p.requires_grad and p.grad is not None:
                    g = p.grad.data
                    g[g != g] = 0

            optimizer.step()

        if engine.iter_ind % 100 == 0:
            with torch.no_grad():
                fake = generate_from_noise(model, x.size(0), clamp=clamp)
                z = model.forward(fake, None, return_details=True)[0]
            print("Z max min")
            print(z.max().item(), z.min().item())
            if (fake != fake).float().sum() > 0:
                title = 'NaNs'
            else:
                title = "Good"
            grid = make_grid((postprocess(fake.detach().cpu(), dataset)[:30]),
                             nrow=6).permute(1, 2, 0)
            plt.figure(figsize=(10, 10))
            plt.imshow(grid)
            plt.axis('off')
            plt.title(title)
            plt.savefig(
                os.path.join(output_dir, f'sample_{engine.iter_ind}.png'))

        if engine.iter_ind % eval_every == 0:

            def check_all_zero_except_leading(x):
                return x % 10**np.floor(np.log10(x)) == 0

            if engine.iter_ind == 0 or check_all_zero_except_leading(
                    engine.iter_ind):
                torch.save(
                    model.state_dict(),
                    os.path.join(output_dir, f'ckpt_sd_{engine.iter_ind}.pt'))

            model.eval()

            with torch.no_grad():
                # Plot recon
                fpath = os.path.join(output_dir, '_recon',
                                     f'recon_{engine.iter_ind}.png')
                sample_pad = run_recon_evolution(
                    model,
                    generate_from_noise(model, args.batch_size,
                                        clamp=clamp).detach(), fpath)
                print(
                    f"Iter: {engine.iter_ind}, Recon Sample PAD: {sample_pad}")

                pad = run_recon_evolution(model, x_for_recon, fpath)
                print(f"Iter: {engine.iter_ind}, Recon PAD: {pad}")
                pad = pad.item()
                sample_pad = sample_pad.item()

                # Inception score
                sample = torch.cat([
                    generate_from_noise(model, args.batch_size, clamp=clamp)
                    for _ in range(N_inception // args.batch_size + 1)
                ], 0)[:N_inception]
                sample = sample + .5

                if (sample != sample).float().sum() > 0:
                    raise RuntimeError("Sample NaNs")
                else:
                    fid = run_fid(x_real_inception.clamp_(0, 1),
                                  sample.clamp_(0, 1))
                    print(f'fid: {fid}, global_iter: {engine.iter_ind}')

                # Eval BPD
                eval_bpd = np.mean([
                    model.forward(x.to(device), None,
                                  return_details=True)[1].mean().item()
                    for x, _ in test_loader
                ])

                stats_dict = {
                    'global_iteration': engine.iter_ind,
                    'fid': fid,
                    'train_bpd': train_bpd,
                    'pad': pad,
                    'eval_bpd': eval_bpd,
                    'sample_pad': sample_pad,
                    'batch_real_acc': real_acc,
                    'batch_fake_acc': fake_acc,
                    'batch_acc': acc
                }
                iteration_logger.writerow(stats_dict)
                plot_csv(iteration_logger.filename)
            model.train()

        if (engine.iter_ind + 2) % svd_every == 0:
            model.eval()
            svd_dict = {}
            ret = utils.computeSVDjacobian(x_for_recon, model)
            D_for, D_inv = ret['D_for'], ret['D_inv']
            cn = float(D_for.max() / D_for.min())
            cn_inv = float(D_inv.max() / D_inv.min())
            svd_dict['global_iteration'] = engine.iter_ind
            svd_dict['condition_num'] = cn
            svd_dict['max_sv'] = float(D_for.max())
            svd_dict['min_sv'] = float(D_for.min())
            svd_dict['inverse_condition_num'] = cn_inv
            svd_dict['inverse_max_sv'] = float(D_inv.max())
            svd_dict['inverse_min_sv'] = float(D_inv.min())
            svd_logger.writerow(svd_dict)
            # plot_utils.plot_stability_stats(output_dir)
            # plot_utils.plot_individual_figures(output_dir, 'svd_log.csv')
            model.train()
            if eval_only:
                sys.exit()

        # Dummy
        losses['total_loss'] = torch.mean(nll).item()
        return losses

    def eval_step(engine, batch):
        model.eval()

        x, y = batch
        x = x.to(device)

        with torch.no_grad():
            if y_condition:
                y = y.to(device)
                z, nll, y_logits = model(x, y)
                losses = compute_loss_y(nll,
                                        y_logits,
                                        y_weight,
                                        y,
                                        multi_class,
                                        reduction='none')
            else:
                z, nll, y_logits = model(x, None)
                losses = compute_loss(nll, reduction='none')

        return losses

    trainer = Engine(gan_step)
    # else:
    #     trainer = Engine(step)
    checkpoint_handler = ModelCheckpoint(output_dir,
                                         'glow',
                                         save_interval=5,
                                         n_saved=1,
                                         require_empty=False)

    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler, {
        'model': model,
        'optimizer': optimizer
    })

    monitoring_metrics = ['total_loss']
    RunningAverage(output_transform=lambda x: x['total_loss']).attach(
        trainer, 'total_loss')

    evaluator = Engine(eval_step)

    # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
    Loss(lambda x, y: torch.mean(x),
         output_transform=lambda x:
         (x['total_loss'], torch.empty(x['total_loss'].shape[0]))).attach(
             evaluator, 'total_loss')

    if y_condition:
        monitoring_metrics.extend(['nll'])
        RunningAverage(output_transform=lambda x: x['nll']).attach(
            trainer, 'nll')

        # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
        Loss(lambda x, y: torch.mean(x),
             output_transform=lambda x:
             (x['nll'], torch.empty(x['nll'].shape[0]))).attach(
                 evaluator, 'nll')

    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=monitoring_metrics)

    # load pre-trained model if given
    if saved_model:
        print("Loading...")
        print(saved_model)
        loaded = torch.load(saved_model)
        # if 'Glow' in str(type(loaded)):
        #     model  = loaded
        # else:
        #     raise
        # # if 'Glow' in str(type(loaded)):
        # #     loaded  = loaded.state_dict()
        model.load_state_dict(loaded)
        model.set_actnorm_init()

        if saved_optimizer:
            optimizer.load_state_dict(torch.load(saved_optimizer))

        file_name, ext = os.path.splitext(saved_model)
        resume_epoch = int(file_name.split('_')[-1])

        @trainer.on(Events.STARTED)
        def resume_training(engine):
            engine.state.epoch = resume_epoch
            engine.state.iteration = resume_epoch * len(
                engine.state.dataloader)

    @trainer.on(Events.STARTED)
    def init(engine):
        if saved_model:
            return
        model.train()
        print("Initializing Actnorm...")
        init_batches = []
        init_targets = []

        if n_init_batches == 0:
            model.set_actnorm_init()
            return
        with torch.no_grad():
            if init_sample:
                generate_from_noise(model,
                                    args.batch_size * args.n_init_batches)
            else:
                for batch, target in islice(train_loader, None,
                                            n_init_batches):
                    init_batches.append(batch)
                    init_targets.append(target)

                init_batches = torch.cat(init_batches).to(device)

                assert init_batches.shape[0] == n_init_batches * batch_size

                if y_condition:
                    init_targets = torch.cat(init_targets).to(device)
                else:
                    init_targets = None

                model(init_batches, init_targets)

    @trainer.on(Events.EPOCH_COMPLETED)
    def evaluate(engine):
        evaluator.run(test_loader)
        if not no_warm_up:
            scheduler.step()
        metrics = evaluator.state.metrics

        losses = ', '.join(
            [f"{key}: {value:.2f}" for key, value in metrics.items()])

        print(f'Validation Results - Epoch: {engine.state.epoch} {losses}')

    timer = Timer(average=True)
    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        pbar.log_message(
            f'Epoch {engine.state.epoch} done. Time per batch: {timer.value():.3f}[s]'
        )
        timer.reset()

    trainer.run(train_loader, epochs)
Example #18
def test_simple_recovery(dirname):
    h = ModelCheckpoint(dirname, _PREFIX, create_dir=False, save_interval=1)
    h(None, {'obj': 42})

    fname = os.path.join(dirname, '{}_{}_{}.pth'.format(_PREFIX, 'obj', 1))
    assert torch.load(fname) == 42
Example #19
# Evaluation
metrics = {
    'loss': Loss(loss_fn),
    'acc': Accuracy()
}

def score_fn(engine):
    acc = engine.state.metrics['acc']

    return acc

evaluator = create_evaluator(model, metrics, device=device)

def log_metrics(engine):
    print('[INFO] Computing metrics...')
    metrics = evaluator.run(valoader).metrics
    print(' Validation Results - Average Loss: {:.4f} | Accuracy: {:.4f}'.format(metrics['loss'], metrics['acc']))
    print('[INFO] Metrics computed.')

trainer.add_event_handler(Events.EPOCH_COMPLETED, log_metrics)

# save the model checkpoints
saver = ModelCheckpoint(snapshots, 'r101', n_saved=10, score_name='acc', score_function=score_fn)
evaluator.add_event_handler(Events.COMPLETED, saver, {'model': model.module})

# start training
print('[INFO] Start training...')
trainer.run(trainloader, epochs)
print('[INFO] Complete training...')
Example #20
def train():
    parser = ArgumentParser()
    parser.add_argument("--dataset_path", type=str, default="", help="Path or url of the dataset. If empty download from S3.")
    parser.add_argument("--dataset_cache", type=str, default='./dataset_cache', help="Path or url of the dataset cache")
    parser.add_argument("--model_checkpoint", type=str, default="openai-gpt", help="Path, url or short name of the model")
    parser.add_argument("--num_candidates", type=int, default=2, help="Number of candidates for training")
    parser.add_argument("--max_history", type=int, default=2, help="Number of previous exchanges to keep in history")
    parser.add_argument("--train_batch_size", type=int, default=4, help="Batch size for training")
    parser.add_argument("--valid_batch_size", type=int, default=4, help="Batch size for validation")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=8, help="Accumulate gradients on several steps")
    parser.add_argument("--lr", type=float, default=6.25e-5, help="Learning rate")
    parser.add_argument("--lm_coef", type=float, default=1.0, help="LM loss coefficient")
    parser.add_argument("--mc_coef", type=float, default=1.0, help="Multiple-choice loss coefficient")
    parser.add_argument("--max_norm", type=float, default=1.0, help="Clipping gradient norm")
    parser.add_argument("--n_epochs", type=int, default=3, help="Number of training epochs")
    parser.add_argument("--personality_permutations", type=int, default=1, help="Number of permutations of personality sentences")
    parser.add_argument("--eval_before_start", action='store_true', help="If true start with a first evaluation before training")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device (cuda or cpu)")
    parser.add_argument("--fp16", type=str, default="", help="Set to O0, O1, O2 or O3 for fp16 training (see apex documentation)")
    parser.add_argument("--local_rank", type=int, default=-1, help="Local rank for distributed training (-1: not distributed)")
    args = parser.parse_args()

    # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process. logger.info => log main process only, logger.warning => log all processes
    logging.basicConfig(level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning("Running process %d", args.local_rank)  # This is a logger.warning: it will be printed by all distributed processes
    logger.info("Arguments: %s", pformat(args))

    # Initialize distributed training if needed
    args.distributed = (args.local_rank != -1)
    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        args.device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')

    logger.info("Prepare tokenizer, pretrained model and optimizer - add special tokens for fine-tuning")
    tokenizer_class = GPT2Tokenizer if "gpt2" in args.model_checkpoint else OpenAIGPTTokenizer
    tokenizer = tokenizer_class.from_pretrained(args.model_checkpoint)
    model_class = GPT2LMHeadModel if "gpt2" in args.model_checkpoint else OpenAIGPTLMHeadModel
    model = model_class.from_pretrained(args.model_checkpoint)
    tokenizer.set_special_tokens(SPECIAL_TOKENS)
    model.set_num_special_tokens(len(SPECIAL_TOKENS))
    model.to(args.device)
    optimizer = OpenAIAdam(model.parameters(), lr=args.lr)

    # Prepare model for FP16 and distributed training if needed (order is important, distributed should be the last)
    if args.fp16:
        from apex import amp  # Apex is only required if we use fp16 training
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16)
    if args.distributed:
        model = DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank)

    logger.info("Prepare datasets")
    train_loader, val_loader, train_sampler, valid_sampler = get_data_loaders(args, tokenizer)

    # Training function and trainer
    def update(engine, batch):
        model.train()
        batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
        lm_loss, mc_loss = model(*batch)
        loss = (lm_loss * args.lm_coef + mc_loss * args.mc_coef) / args.gradient_accumulation_steps
        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_norm)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
        if engine.state.iteration % args.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        return loss.item()
    trainer = Engine(update)

    # Evaluation function and evaluator (evaluator output is the input of the metrics)
    def inference(engine, batch):
        model.eval()
        with torch.no_grad():
            batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
            input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids = batch
            logger.info(tokenizer.decode(input_ids[0, -1, :].tolist()))
            model_outputs = model(input_ids, mc_token_ids, token_type_ids=token_type_ids)
            lm_logits, mc_logits = model_outputs[0], model_outputs[1]  # So we can also use GPT2 outputs
            lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view(-1, lm_logits.size(-1))
            lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1)
            return (lm_logits_flat_shifted, mc_logits), (lm_labels_flat_shifted, mc_labels)
    evaluator = Engine(inference)

    # Attach evaluation to trainer: we evaluate when we start the training and at the end of each epoch
    trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: evaluator.run(val_loader))
    if args.n_epochs < 1:
        trainer.add_event_handler(Events.COMPLETED, lambda _: evaluator.run(val_loader))
    if args.eval_before_start:
        trainer.add_event_handler(Events.STARTED, lambda _: evaluator.run(val_loader))

    # Make sure distributed data samplers split the dataset nicely between the distributed processes
    if args.distributed:
        trainer.add_event_handler(Events.EPOCH_STARTED, lambda engine: train_sampler.set_epoch(engine.state.epoch))
        evaluator.add_event_handler(Events.EPOCH_STARTED, lambda engine: valid_sampler.set_epoch(engine.state.epoch))

    # Linearly decrease the learning rate from lr to zero
    scheduler = PiecewiseLinear(optimizer, "lr", [(0, args.lr), (args.n_epochs * len(train_loader), 0.0)])
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # Prepare metrics - note how we compute distributed metrics 
    RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
    metrics = {"nll": Loss(torch.nn.CrossEntropyLoss(ignore_index=-1), output_transform=lambda x: (x[0][0], x[1][0])),
               "accuracy": Accuracy(output_transform=lambda x: (x[0][1], x[1][1]))}
    metrics.update({"average_nll": MetricsLambda(average_distributed_scalar, metrics["nll"], args),
                    "average_accuracy": MetricsLambda(average_distributed_scalar, metrics["accuracy"], args)})
    metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"])
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # On the main process: add progress bar, tensorboard, checkpoints and save model, configuration and tokenizer before we start to train
    if args.local_rank in [-1, 0]:
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer, metric_names=["loss"])
        evaluator.add_event_handler(Events.COMPLETED, lambda _: pbar.log_message("Validation: %s" % pformat(evaluator.state.metrics)))

        tb_logger = TensorboardLogger(log_dir=None)
        tb_logger.attach(trainer, log_handler=OutputHandler(tag="training", metric_names=["loss"]), event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer, log_handler=OptimizerParamsHandler(optimizer), event_name=Events.ITERATION_STARTED)
        tb_logger.attach(evaluator, log_handler=OutputHandler(tag="validation", metric_names=list(metrics.keys()), another_engine=trainer), event_name=Events.EPOCH_COMPLETED)

        checkpoint_handler = ModelCheckpoint(tb_logger.writer.log_dir, 'checkpoint', save_interval=1, n_saved=3)
        trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler, {'mymodel': getattr(model, 'module', model)})  # "getattr" take care of distributed encapsulation

        torch.save(args, tb_logger.writer.log_dir + '/model_training_args.bin')
        getattr(model, 'module', model).config.to_json_file(os.path.join(tb_logger.writer.log_dir, CONFIG_NAME))
        tokenizer.save_vocabulary(tb_logger.writer.log_dir)

    # Run the training
    trainer.run(train_loader, max_epochs=args.n_epochs)

    # On the main process: close tensorboard logger and rename the last checkpoint (for easy re-loading with OpenAIGPTModel.from_pretrained method)
    if args.local_rank in [-1, 0] and args.n_epochs > 0:
        os.rename(checkpoint_handler._saved[-1][1][-1], os.path.join(tb_logger.writer.log_dir, WEIGHTS_NAME))  # TODO: PR in ignite to have better access to saved file paths (cleaner)
        tb_logger.close()
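
The `update` function above divides the loss by `gradient_accumulation_steps` and only calls `optimizer.step()` every few iterations. A self-contained sketch of that accumulation pattern in isolation (the model, data and step count are illustrative):

import torch
import torch.nn as nn
from ignite.engine import Engine

model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()
accumulation_steps = 4

def update(engine, batch):
    model.train()
    x, y = batch
    # scale so the accumulated gradient averages over the virtual batch
    loss = loss_fn(model(x), y) / accumulation_steps
    loss.backward()
    if engine.state.iteration % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
    return loss.item()

trainer = Engine(update)
data = [(torch.randn(8, 10), torch.randn(8, 1)) for _ in range(16)]
trainer.run(data, max_epochs=1)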
Exemplo n.º 21
0
def train_gan(logger: Logger,
              experiment_dir: Path,
              data_dir: Path,
              batch_size: int,
              z_dim: int,
              g_filters: int,
              d_filters: int,
              learning_rate: float,
              beta_1: float,
              epochs: int,
              saved_g: bool = False,
              saved_d: bool = False,
              seed: Optional[int] = None,
              g_extra_layers: int = 0,
              d_extra_layers: int = 0,
              scheduler: bool = False) -> None:
    seed = fix_random_seed(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Train started with seed: {seed}")
    dataset = HDF5ImageDataset(image_dir=data_dir)
    desired_minkowski = pickle.load(
        (data_dir / 'minkowski.pkl').open(mode='rb'))

    loader = DataLoader(dataset,
                        batch_size=batch_size,
                        shuffle=True,
                        drop_last=True,
                        pin_memory=True)
    iterations = epochs * len(loader)
    img_size = dataset.shape[-1]
    num_channels = dataset.shape[0]

    # networks
    net_g = Generator(img_size=img_size,
                      z_dim=z_dim,
                      num_channels=num_channels,
                      num_filters=g_filters,
                      num_extra_layers=g_extra_layers).to(device)
    net_d = Discriminator(img_size=img_size,
                          num_channels=num_channels,
                          num_filters=d_filters,
                          num_extra_layers=d_extra_layers).to(device)
    summary(net_g, (z_dim, 1, 1, 1))
    summary(net_d, (num_channels, img_size, img_size, img_size))

    if saved_g:
        net_g.load_state_dict(torch.load(experiment_dir / G_CHECKPOINT_NAME))
        logger.info("Loaded generator checkpoint")
    if saved_d:
        net_d.load_state_dict(torch.load(experiment_dir / D_CHECKPOINT_NAME))
        logger.info("Loaded discriminator checkpoint")

    # criterion
    criterion = nn.BCELoss()

    optimizer_g = optim.Adam(net_g.parameters(),
                             lr=learning_rate,
                             betas=(beta_1, 0.999))
    optimizer_d = optim.Adam(net_d.parameters(),
                             lr=learning_rate,
                             betas=(beta_1, 0.999))

    patience = int(3000 / len(loader))
    scheduler_g = optim.lr_scheduler.ReduceLROnPlateau(optimizer_g,
                                                       min_lr=1e-6,
                                                       verbose=True,
                                                       patience=patience)
    scheduler_d = optim.lr_scheduler.ReduceLROnPlateau(optimizer_d,
                                                       min_lr=1e-6,
                                                       verbose=True,
                                                       patience=patience)

    # label smoothing: real labels are set to 0.9 rather than 1.0
    real_labels = torch.full((batch_size, ), fill_value=0.9, device=device)
    fake_labels = torch.zeros((batch_size, ), device=device)
    fixed_noise = torch.randn(1, z_dim, 1, 1, 1, device=device)

    def step(engine: Engine, batch: torch.Tensor) -> Dict[str, float]:
        """
        Train step function

        :param engine: pytorch ignite train engine
        :param batch: batch to process
        :return batch metrics
        """
        # get batch of fake images from generator
        fake_batch = net_g(
            torch.randn(batch_size, z_dim, 1, 1, 1, device=device))
        # 1. Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        batch = batch.to(device)
        optimizer_d.zero_grad()
        # train D with real and fake batches
        d_out_real = net_d(batch)
        d_out_fake = net_d(fake_batch.detach())
        loss_d_real = criterion(d_out_real, real_labels)
        loss_d_fake = criterion(d_out_fake, fake_labels)
        # mean probabilities
        p_real = d_out_real.mean().item()
        p_fake = d_out_fake.mean().item()

        loss_d = (loss_d_real + loss_d_fake) / 2
        loss_d.backward()
        optimizer_d.step()

        # 2. Update G network: maximize log(D(G(z)))
        loss_g = None
        p_gen = None
        for _ in range(1):
            fake_batch = net_g(
                torch.randn(batch_size, z_dim, 1, 1, 1, device=device))
            optimizer_g.zero_grad()
            d_out_fake = net_d(fake_batch)
            loss_g = criterion(d_out_fake, real_labels)
            # mean fake generator probability
            p_gen = d_out_fake.mean().item()
            loss_g.backward()
            optimizer_g.step()

        # minkowski functional measures
        cube = net_g(fixed_noise).detach().squeeze().cpu()
        cube = cube.mul(0.5).add(0.5).numpy()
        cube = postprocess_cube(cube)
        cube = np.pad(cube, ((1, 1), (1, 1), (1, 1)),
                      mode='constant',
                      constant_values=0)
        v, s, b, xi = compute_minkowski(cube)
        return {
            'loss_d': loss_d.item(),
            'loss_g': loss_g.item(),
            'p_real': p_real,
            'p_fake': p_fake,
            'p_gen': p_gen,
            'V': v,
            'S': s,
            'B': b,
            'Xi': xi
        }

    # ignite objects
    trainer = Engine(step)
    checkpoint_handler = ModelCheckpoint(dirname=str(experiment_dir),
                                         filename_prefix=CKPT_PREFIX,
                                         save_interval=5,
                                         n_saved=50,
                                         require_empty=False)

    # attach running average metrics
    monitoring_metrics = [
        'loss_d', 'loss_g', 'p_real', 'p_fake', 'p_gen', 'V', 'S', 'B', 'Xi'
    ]
    RunningAverage(alpha=ALPHA, output_transform=lambda x: x['loss_d']).attach(
        trainer, 'loss_d')
    RunningAverage(alpha=ALPHA, output_transform=lambda x: x['loss_g']).attach(
        trainer, 'loss_g')
    RunningAverage(alpha=ALPHA, output_transform=lambda x: x['p_real']).attach(
        trainer, 'p_real')
    RunningAverage(alpha=ALPHA, output_transform=lambda x: x['p_fake']).attach(
        trainer, 'p_fake')
    RunningAverage(alpha=ALPHA, output_transform=lambda x: x['p_gen']).attach(
        trainer, 'p_gen')
    RunningAverage(alpha=ALPHA,
                   output_transform=lambda x: x['V']).attach(trainer, 'V')
    RunningAverage(alpha=ALPHA,
                   output_transform=lambda x: x['S']).attach(trainer, 'S')
    RunningAverage(alpha=ALPHA,
                   output_transform=lambda x: x['B']).attach(trainer, 'B')
    RunningAverage(alpha=ALPHA,
                   output_transform=lambda x: x['Xi']).attach(trainer, 'Xi')

    # attach progress bar
    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=monitoring_metrics)

    @trainer.on(Events.ITERATION_COMPLETED)
    def print_logs(engine):
        if (engine.state.iteration - 1) % PRINT_FREQ == 0:
            fname = experiment_dir / LOGS_FNAME
            columns = ['iter'] + list(engine.state.metrics.keys())
            values = [str(engine.state.iteration)] + [
                str(round(value, 7))
                for value in engine.state.metrics.values()
            ]

            with fname.open(mode='a') as f:
                if f.tell() == 0:
                    print('\t'.join(columns), file=f)
                print('\t'.join(values), file=f)

            message = f"[{engine.state.epoch}/{epochs}][{engine.state.iteration:04d}/{iterations}]"
            for name, value in zip(engine.state.metrics.keys(),
                                   engine.state.metrics.values()):
                message += f" | {name}: {value:0.5f}"

            pbar.log_message(message)

    trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                              handler=checkpoint_handler,
                              to_save={
                                  'net_g': net_g,
                                  'net_d': net_d
                              })

    @trainer.on(Events.EPOCH_COMPLETED)
    def create_plots(engine):
        df = pd.read_csv(experiment_dir / LOGS_FNAME, delimiter='\t')

        fig_1 = plt.figure(figsize=(18, 12))
        plt.plot(df['iter'], df['loss_d'], label='loss_d', linestyle='dashed')
        plt.plot(df['iter'], df['loss_g'], label='loss_g')
        plt.xlabel('Iteration number')
        plt.legend()
        fig_1.savefig(experiment_dir / ('loss_' + PLOT_FNAME))
        plt.close(fig_1)

        fig_2 = plt.figure(figsize=(18, 12))
        plt.plot(df['iter'], df['p_real'], label='p_real', linestyle='dashed')
        plt.plot(df['iter'], df['p_fake'], label='p_fake', linestyle='dashdot')
        plt.plot(df['iter'], df['p_gen'], label='p_gen')
        plt.xlabel('Iteration number')
        plt.legend()
        fig_2.savefig(experiment_dir / PLOT_FNAME)
        plt.close(fig_2)

        desired_v = [desired_minkowski[0]] * len(df['iter'])
        desired_s = [desired_minkowski[1]] * len(df['iter'])
        desired_b = [desired_minkowski[2]] * len(df['iter'])
        desired_xi = [desired_minkowski[3]] * len(df['iter'])

        fig_3 = plt.figure(figsize=(18, 12))
        plt.plot(df['iter'], df['V'], label='V', color='b')
        plt.plot(df['iter'], desired_v, color='b', linestyle='dashed')
        plt.xlabel('Iteration number')
        plt.ylabel('Minkowski functional V')
        plt.legend()
        fig_3.savefig(experiment_dir / ('minkowski_V_' + PLOT_FNAME))
        plt.close(fig_3)

        fig_4 = plt.figure(figsize=(18, 12))
        plt.plot(df['iter'], df['S'], label='S', color='r')
        plt.plot(df['iter'], desired_s, color='r', linestyle='dashed')
        plt.xlabel('Iteration number')
        plt.ylabel('Minkowski functional S')
        plt.legend()
        fig_4.savefig(experiment_dir / ('minkowski_S_' + PLOT_FNAME))
        plt.close(fig_4)

        fig_5 = plt.figure(figsize=(18, 12))
        plt.plot(df['iter'], df['B'], label='B', color='g')
        plt.plot(df['iter'], desired_b, color='g', linestyle='dashed')
        plt.xlabel('Iteration number')
        plt.ylabel('Minkowski functional B')
        plt.legend()
        fig_5.savefig(experiment_dir / ('minkowski_B_' + PLOT_FNAME))
        plt.close(fig_5)

        fig_6 = plt.figure(figsize=(18, 12))
        plt.plot(df['iter'], df['Xi'], label='Xi', color='y')
        plt.plot(df['iter'], desired_xi, color='y', linestyle='dashed')
        plt.xlabel('Iteration number')
        plt.ylabel('Minkowski functional Xi')
        plt.legend()
        fig_6.savefig(experiment_dir / ('minkowski_Xi_' + PLOT_FNAME))
        plt.close(fig_6)

    if scheduler:

        @trainer.on(Events.EPOCH_COMPLETED)
        def lr_scheduler(engine):
            desired_b = desired_minkowski[2]
            desired_xi = desired_minkowski[3]

            current_b = engine.state.metrics['B']
            current_xi = engine.state.metrics['Xi']

            delta = abs(desired_b - current_b) + abs(desired_xi - current_xi)

            scheduler_d.step(delta)
            scheduler_g.step(delta)

    @trainer.on(Events.EXCEPTION_RAISED)
    def handle_exception(engine, e):
        if isinstance(e, KeyboardInterrupt) and (engine.state.iteration > 1):
            engine.terminate()
            warnings.warn('KeyboardInterrupt caught. Exiting gracefully.')

            create_plots(engine)
            checkpoint_handler(engine, {
                'net_g_exception': net_g,
                'net_d_exception': net_d
            })
        else:
            raise e

    trainer.run(loader, epochs)
Exemplo n.º 22
0
def main():
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # IXI dataset as a demo, downloadable from https://brain-development.org/ixi-dataset/
    images = [
        "/workspace/data/medical/ixi/IXI-T1/IXI314-IOP-0889-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI249-Guys-1072-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI609-HH-2600-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI173-HH-1590-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI020-Guys-0700-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI342-Guys-0909-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI134-Guys-0780-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI577-HH-2661-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI066-Guys-0731-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI130-HH-1528-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI607-Guys-1097-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI175-HH-1570-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI385-HH-2078-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI344-Guys-0905-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI409-Guys-0960-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI584-Guys-1129-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI253-HH-1694-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI092-HH-1436-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI574-IOP-1156-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI585-Guys-1130-T1.nii.gz",
    ]
    # 2 binary labels for gender classification: man and woman
    labels = np.array(
        [0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0])
    train_files = [{
        "img": img,
        "label": label
    } for img, label in zip(images[:10], labels[:10])]
    val_files = [{
        "img": img,
        "label": label
    } for img, label in zip(images[-10:], labels[-10:])]

    # define transforms for image
    train_transforms = Compose([
        LoadNiftid(keys=["img"]),
        AddChanneld(keys=["img"]),
        ScaleIntensityd(keys=["img"]),
        Resized(keys=["img"], spatial_size=(96, 96, 96)),
        RandRotate90d(keys=["img"], prob=0.8, spatial_axes=[0, 2]),
        ToTensord(keys=["img"]),
    ])
    val_transforms = Compose([
        LoadNiftid(keys=["img"]),
        AddChanneld(keys=["img"]),
        ScaleIntensityd(keys=["img"]),
        Resized(keys=["img"], spatial_size=(96, 96, 96)),
        ToTensord(keys=["img"]),
    ])

    # define dataset, data loader
    check_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    check_loader = DataLoader(check_ds,
                              batch_size=2,
                              num_workers=4,
                              pin_memory=torch.cuda.is_available())
    check_data = monai.utils.misc.first(check_loader)
    print(check_data["img"].shape, check_data["label"])

    # create DenseNet121, CrossEntropyLoss and Adam optimizer
    net = monai.networks.nets.densenet.densenet121(
        spatial_dims=3,
        in_channels=1,
        out_channels=2,
    )
    loss = torch.nn.CrossEntropyLoss()
    lr = 1e-5
    opt = torch.optim.Adam(net.parameters(), lr)
    device = torch.device("cuda:0")

    # Ignite trainer expects batch=(img, label) and returns output=loss at every iteration,
    # user can add output_transform to return other values, like: y_pred, y, etc.
    def prepare_batch(batch, device=None, non_blocking=False):

        return _prepare_batch((batch["img"], batch["label"]), device,
                              non_blocking)

    trainer = create_supervised_trainer(net,
                                        opt,
                                        loss,
                                        device,
                                        False,
                                        prepare_batch=prepare_batch)

    # adding checkpoint handler to save models (network params and optimizer stats) during training
    checkpoint_handler = ModelCheckpoint("./runs/",
                                         "net",
                                         n_saved=10,
                                         require_empty=False)
    trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                              handler=checkpoint_handler,
                              to_save={
                                  "net": net,
                                  "opt": opt
                              })

    # StatsHandler prints loss at every iteration and print metrics at every epoch,
    # we don't set metrics for trainer here, so just print loss, user can also customize print functions
    # and can use output_transform to convert engine.state.output if it's not loss value
    train_stats_handler = StatsHandler(name="trainer")
    train_stats_handler.attach(trainer)

    # TensorBoardStatsHandler plots loss at every iteration and plots metrics at every epoch, same as StatsHandler
    train_tensorboard_stats_handler = TensorBoardStatsHandler()
    train_tensorboard_stats_handler.attach(trainer)

    # set parameters for validation
    validation_every_n_epochs = 1

    metric_name = "Accuracy"
    # add evaluation metric to the evaluator engine
    val_metrics = {
        metric_name: Accuracy(),
        "AUC": ROCAUC(to_onehot_y=True, add_softmax=True)
    }
    # Ignite evaluator expects batch=(img, label) and returns output=(y_pred, y) at every iteration,
    # user can add output_transform to return other values
    evaluator = create_supervised_evaluator(net,
                                            val_metrics,
                                            device,
                                            True,
                                            prepare_batch=prepare_batch)

    # add stats event handler to print validation stats via evaluator
    val_stats_handler = StatsHandler(
        name="evaluator",
        output_transform=lambda x:
        None,  # no need to print loss value, so disable per iteration output
        global_epoch_transform=lambda x: trainer.state.epoch,
    )  # fetch global epoch number from trainer
    val_stats_handler.attach(evaluator)

    # add handler to record metrics to TensorBoard at every epoch
    val_tensorboard_stats_handler = TensorBoardStatsHandler(
        output_transform=lambda x:
        None,  # no need to plot loss value, so disable per iteration output
        global_epoch_transform=lambda x: trainer.state.epoch,
    )  # fetch global epoch number from trainer
    val_tensorboard_stats_handler.attach(evaluator)

    # add early stopping handler to evaluator
    early_stopper = EarlyStopping(
        patience=4,
        score_function=stopping_fn_from_metric(metric_name),
        trainer=trainer)
    evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                                handler=early_stopper)

    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds,
                            batch_size=2,
                            num_workers=4,
                            pin_memory=torch.cuda.is_available())

    @trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs))
    def run_validation(engine):
        evaluator.run(val_loader)

    # create a training data loader
    train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    train_loader = DataLoader(train_ds,
                              batch_size=2,
                              shuffle=True,
                              num_workers=4,
                              pin_memory=torch.cuda.is_available())

    train_epochs = 30
    state = trainer.run(train_loader, train_epochs)
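
The `run_validation` handler above relies on ignite's filtered events (`Events.EPOCH_COMPLETED(every=n)`), which fire only on every n-th occurrence. A minimal sketch of the mechanism on its own, assuming ignite 0.4+:

from ignite.engine import Engine, Events

engine = Engine(lambda e, b: None)

@engine.on(Events.EPOCH_COMPLETED(every=2))
def periodic_task(engine):
    # fires on epochs 2, 4 and 6 only
    print(f"periodic task at epoch {engine.state.epoch}")

engine.run([0], max_epochs=6)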
Exemplo n.º 23
0
def add_handlers(trainer, evaluator, data, models_dict, cfg):
    """

    :param trainer: ignite trainer object
    :param evaluator: ignite evaluator object
    :param data: tuple containing train and test dataloader
    :param models_dict: dict containing all models & optimizers to save
    :param cfg: configuration dict
    """
    train_loader, test_loader = data

    # add progressbar
    progbar = ProgressBar(trainer, train_loader)
    trainer.add_event_handler(event_name=Events.ITERATION_COMPLETED,
                              handler=progbar)

    # initialize checkpoint saving function
    checkpoint = ModelCheckpoint(cfg.DIRS.CHKP_DIR,
                                 cfg.DIRS.CHKP_PREFIX,
                                 require_empty=False,
                                 save_interval=1,
                                 n_saved=100000,
                                 save_as_state_dict=True)

    writer = Tensorboard(create_dir(cfg.DIRS.CHKP_DIR, 'summaries'))
    trainer.add_event_handler(event_name=Events.ITERATION_COMPLETED,
                              handler=writer)
    evaluator.add_event_handler(event_name=Events.ITERATION_COMPLETED,
                                handler=writer)

    # if models were loaded, resume training from the epoch we left off at,
    # otherwise start at epoch 0.
    @trainer.on(Events.STARTED)
    def epoch_start(engine):

        if (cfg.SOLVER.RESUME_EPOCH != '0') and \
                (cfg.SOLVER.COMPLETED_EPOCHS != 0):
            mess = 'TRAINING COMPLETE' if \
                ((cfg.SOLVER.COMPLETED_EPOCHS - cfg.SOLVER.EPOCHS) >= 0) \
                else 'RESUME TRAINING'

            engine.state.iteration = cfg.SOLVER.TRAINER_ITERATION
            engine.state.epoch = checkpoint._iteration = cfg.SOLVER.COMPLETED_EPOCHS

            print(' --- LOADED MODEL FOR EPOCH: {comp_epochs} / {conf} ---\
                 \n --------- {mess} ---------'.format(
                comp_epochs=cfg.SOLVER.COMPLETED_EPOCHS,
                conf=cfg.SOLVER.EPOCHS,
                mess=mess))

    @trainer.on(Events.EPOCH_COMPLETED)
    def save_models(engine):

        cfg.SOLVER.COMPLETED_EPOCHS = engine.state.epoch
        cfg.SOLVER.TRAINER_ITERATION = engine.state.iteration

        if cfg.SOLVER.COMPLETED_EPOCHS % cfg.MODEL.SAVE_INTERVAL == 0:
            # checkpoint only counts nr of checkpoint calls, not epochs
            checkpoint._iteration = cfg.SOLVER.COMPLETED_EPOCHS - 1
            checkpoint(engine, models_dict)
            save_ignite_params(engine, engine_name='trainer', cfg=cfg)

    @trainer.on(Events.EPOCH_COMPLETED)
    def classification_validation(engine):
        cfg.RESULTS.LATENTS, cfg.RESULTS.CLF_ACC, cfg.RESULTS.MEAN_DISTANCE, \
        cfg.RESULTS.SMOOTHNESS, cfg.RESULTS.CLUSTER_ACC = [], [], [], [], []

        print('--- Evaluating model on validation set ---')
        evaluator.run(test_loader)

        clf_acc_str = calc_mean_non_empty('clf_acc', cfg.RESULTS.CLF_ACC)
        cluster_acc_str = calc_mean_non_empty('cluster_acc',
                                              cfg.RESULTS.CLUSTER_ACC)
        mean_dist_str = calc_mean_non_empty('mean_distance',
                                            cfg.RESULTS.MEAN_DISTANCE)
        smoothness_str = calc_mean_non_empty('smoothness',
                                             cfg.RESULTS.SMOOTHNESS)

        print('{clf}{cluster}{mean_dist}{smooth}'.format(
            clf=clf_acc_str,
            cluster=cluster_acc_str,
            mean_dist=mean_dist_str,
            smooth=smoothness_str))

    @evaluator.on(Events.STARTED)
    def continue_validation(engine):
        # create dict keys at first epoch
        if 'EVAL_ITERATION' not in cfg.SOLVER.keys():
            cfg.SOLVER.EVAL_ITERATION = 0
            cfg.SOLVER.EVAL_EPOCH = 0
        else:
            # after each iteration, dict gets updated
            # needed so that it continues counter at start of
            # every eval run (bc for every validation run
            # the evaluator is newly initialized)
            engine.state.iteration = cfg.SOLVER.EVAL_ITERATION

            # always put to 0 to run evaluation once
            # can't specify max epochs bc it would run
            # validation set several times after each other
            engine.state.epoch = 0

    @evaluator.on(Events.EPOCH_COMPLETED)
    def save_eval_state(engine):
        # save iteration after each validation run
        # continue at this counter for next run
        # (as number otherwise resets),
        # save number of eval epochs for saving/loading params
        cfg.SOLVER.EVAL_ITERATION = engine.state.iteration
        cfg.SOLVER.EVAL_EPOCH += 1

        if cfg.SOLVER.COMPLETED_EPOCHS % cfg.MODEL.SAVE_INTERVAL == 0:
            save_ignite_params(engine, engine_name='eval', cfg=cfg)

    return trainer, evaluator
Exemplo n.º 24
0
def run(args):
    train_loader, val_loader = get_data_loaders(args.dir, args.batch_size,
                                                args.num_workers)

    if args.seed is not None:
        torch.manual_seed(args.seed)

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    num_classes = CityscapesDataset.num_instance_classes() + 1
    model = models.box2pix(num_classes=num_classes)
    model.init_from_googlenet()

    writer = create_summary_writer(model, train_loader, args.log_dir)

    if torch.cuda.device_count() > 1:
        print("Using %d GPU(s)" % torch.cuda.device_count())
        model = nn.DataParallel(model)

    model = model.to(device)

    semantics_criterion = nn.CrossEntropyLoss(ignore_index=255)
    offsets_criterion = nn.MSELoss()
    box_criterion = BoxLoss(num_classes, gamma=2)
    multitask_criterion = MultiTaskLoss().to(device)

    box_coder = BoxCoder()
    optimizer = optim.Adam([{
        'params': model.parameters(),
        'weight_decay': 5e-4
    }, {
        'params': multitask_criterion.parameters()
    }],
                           lr=args.lr)

    if args.resume:
        if os.path.isfile(args.resume):
            print("Loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            multitask_criterion.load_state_dict(checkpoint['multitask'])
            print("Loaded checkpoint '{}' (Epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("No checkpoint found at '{}'".format(args.resume))

    def _prepare_batch(batch, non_blocking=True):
        x, instance, boxes, labels = batch

        return (convert_tensor(x, device=device, non_blocking=non_blocking),
                convert_tensor(instance,
                               device=device,
                               non_blocking=non_blocking),
                convert_tensor(boxes, device=device,
                               non_blocking=non_blocking),
                convert_tensor(labels,
                               device=device,
                               non_blocking=non_blocking))

    def _update(engine, batch):
        model.train()
        optimizer.zero_grad()
        x, instance, boxes, labels = _prepare_batch(batch)
        boxes, labels = box_coder.encode(boxes, labels)

        loc_preds, conf_preds, semantics_pred, offsets_pred = model(x)

        semantics_loss = semantics_criterion(semantics_pred, instance)
        offsets_loss = offsets_criterion(offsets_pred, instance)
        box_loss, conf_loss = box_criterion(loc_preds, boxes, conf_preds,
                                            labels)

        loss = multitask_criterion(semantics_loss, offsets_loss, box_loss,
                                   conf_loss)

        loss.backward()
        optimizer.step()

        return {
            'loss': loss.item(),
            'loss_semantics': semantics_loss.item(),
            'loss_offsets': offsets_loss.item(),
            'loss_ssdbox': box_loss.item(),
            'loss_ssdclass': conf_loss.item()
        }

    trainer = Engine(_update)

    checkpoint_handler = ModelCheckpoint(args.output_dir,
                                         'checkpoint',
                                         save_interval=1,
                                         n_saved=10,
                                         require_empty=False,
                                         create_dir=True,
                                         save_as_state_dict=False)
    timer = Timer(average=True)

    # attach running average metrics
    train_metrics = [
        'loss', 'loss_semantics', 'loss_offsets', 'loss_ssdbox',
        'loss_ssdclass'
    ]
    for m in train_metrics:
        transform = partial(lambda x, metric: x[metric], metric=m)
        RunningAverage(output_transform=transform).attach(trainer, m)

    # attach progress bar
    pbar = ProgressBar(persist=True)
    pbar.attach(trainer, metric_names=train_metrics)

    @trainer.on(Events.EPOCH_COMPLETED)
    def save_checkpoint(engine):
        # build the checkpoint lazily so the stored epoch and the optimizer /
        # multitask states reflect training progress at save time
        # (trainer.state is not populated before trainer.run is called)
        checkpoint = {
            'model': model.state_dict(),
            'epoch': engine.state.epoch,
            'optimizer': optimizer.state_dict(),
            'multitask': multitask_criterion.state_dict()
        }
        checkpoint_handler(engine, {'checkpoint': checkpoint})

    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    def _inference(engine, batch):
        model.eval()
        with torch.no_grad():
            x, instance, boxes, labels = _prepare_batch(batch)
            loc_preds, conf_preds, semantics, offsets_pred = model(x)
            boxes_preds, labels_preds, scores_preds = box_coder.decode(
                loc_preds, F.softmax(conf_preds, dim=1), score_thresh=0.01)

            semantics_loss = semantics_criterion(semantics, instance)
            offsets_loss = offsets_criterion(offsets_pred, instance)
            box_loss, conf_loss = box_criterion(loc_preds, boxes, conf_preds,
                                                labels)

            semantics_pred = semantics.argmax(dim=1)
            instances = helper.assign_pix2box(semantics_pred, offsets_pred,
                                              boxes_preds, labels_preds)

        return {
            'loss': (semantics_loss, offsets_loss, {
                'box_loss': box_loss,
                'conf_loss': conf_loss
            }),
            'objects':
            (boxes_preds, labels_preds, scores_preds, boxes, labels),
            'semantics':
            semantics_pred,
            'instances':
            instances
        }

    train_evaluator = Engine(_inference)
    Loss(multitask_criterion,
         output_transform=lambda x: x['loss']).attach(train_evaluator, 'loss')
    MeanAveragePrecision(num_classes,
                         output_transform=lambda x: x['objects']).attach(
                             train_evaluator, 'objects')
    IntersectionOverUnion(num_classes,
                          output_transform=lambda x: x['semantics']).attach(
                              train_evaluator, 'semantics')

    evaluator = Engine(_inference)
    Loss(multitask_criterion,
         output_transform=lambda x: x['loss']).attach(evaluator, 'loss')
    MeanAveragePrecision(num_classes,
                         output_transform=lambda x: x['objects']).attach(
                             evaluator, 'objects')
    IntersectionOverUnion(num_classes,
                          output_transform=lambda x: x['semantics']).attach(
                              evaluator, 'semantics')

    @trainer.on(Events.STARTED)
    def initialize(engine):
        if args.resume:
            engine.state.epoch = args.start_epoch

    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        pbar.log_message(
            "Epoch [{}/{}] done. Time per batch: {:.3f}[s]".format(
                engine.state.epoch, engine.state.max_epochs, timer.value()))
        timer.reset()

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        iteration = (engine.state.iteration - 1) % len(train_loader) + 1
        if iteration % args.log_interval == 0:
            writer.add_scalar("training/loss", engine.state.output['loss'],
                              engine.state.iteration)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        train_evaluator.run(train_loader)
        metrics = train_evaluator.state.metrics
        loss = metrics['loss']
        mean_ap = metrics['objects']
        iou = metrics['semantics']

        pbar.log_message(
            'Training results - Epoch: [{}/{}]: Loss: {:.4f}, mAP(50%): {:.1f}, IoU: {:.1f}'
            .format(engine.state.epoch, engine.state.max_epochs, loss,
                    mean_ap, iou * 100.0))

        writer.add_scalar("train-val/loss", loss, engine.state.epoch)
        writer.add_scalar("train-val/mAP", mean_ap, engine.state.epoch)
        writer.add_scalar("train-val/IoU", iou, engine.state.epoch)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        evaluator.run(val_loader)
        metrics = evaluator.state.metrics
        loss = metrics['loss']
        mean_ap = metrics['objects']
        iou = metrics['semantics']

        pbar.log_message(
            'Validation results - Epoch: [{}/{}]: Loss: {:.4f}, mAP(50%): {:.1f}, IoU: {:.1f}'
            .format(engine.state.epoch, engine.state.max_epochs, loss,
                    mean_ap, iou * 100.0))

        writer.add_scalar("validation/loss", loss, engine.state.epoch)
        writer.add_scalar("validation/mAP", mean_ap, engine.state.epoch)
        writer.add_scalar("validation/IoU", iou, engine.state.epoch)

    @trainer.on(Events.EXCEPTION_RAISED)
    def handle_exception(engine, e):
        if isinstance(e, KeyboardInterrupt) and (engine.state.iteration > 1):
            engine.terminate()
            warnings.warn("KeyboardInterrupt caught. Exiting gracefully.")

            checkpoint_handler(engine, {'model_exception': model})
        else:
            raise e

    @trainer.on(Events.COMPLETED)
    def save_final_model(engine):
        checkpoint_handler(engine, {'final': model})

    trainer.run(train_loader, max_epochs=args.epochs)
    writer.close()
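
The metric loop above binds the metric name with `partial(...)` rather than a bare lambda; a plain `lambda x: x[m]` inside the loop would capture `m` by reference and every transform would end up reading the last metric. A short illustration of the difference:

from functools import partial

names = ['loss', 'loss_semantics']
output = {'loss': 1.0, 'loss_semantics': 2.0}

late_bound = [lambda x: x[m] for m in names]
bound_now = [partial(lambda x, metric: x[metric], metric=m) for m in names]

print([f(output) for f in late_bound])  # [2.0, 2.0] -- both read 'loss_semantics'
print([f(output) for f in bound_now])   # [1.0, 2.0] -- each reads its own key

A default argument (`lambda x, m=m: x[m]`) would work just as well.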
Exemplo n.º 25
0
def train():
    logger.info('*' * 64)
    logger.info('token:%s' % current_time)
    logger.info('*' * 64)

    parser = ArgumentParser()
    parser.add_argument(
        "--train_file",
        type=str,
        default="./my_test/data/student/part1.txt",
        help="Path or url of the dataset. If empty download from S3.")

    parser.add_argument("--dataset_cache",
                        type=str,
                        default='./cache/',
                        help="Path or url of the dataset cache")
    parser.add_argument("--batch_size",
                        type=int,
                        default=2,
                        help="Batch size for validation")
    parser.add_argument("--gradient_accumulation_steps",
                        type=int,
                        default=1,
                        help="Accumulate gradients on several steps")
    parser.add_argument("--lr",
                        type=float,
                        default=6.25e-4,
                        help="Learning rate")
    # parser.add_argument("--train_precent", type=float, default=0.7, help="Batch size for validation")
    parser.add_argument("--n_epochs",
                        type=int,
                        default=1,
                        help="Number of training epochs")
    parser.add_argument("--device",
                        type=str,
                        default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device (cuda or cpu)")
    # parser.add_argument("--max_norm", type=float, default=1.0, help="Clipping gradient norm")
    parser.add_argument("--log_step",
                        type=int,
                        default=1,
                        help="Multiple-choice loss coefficient")
    parser.add_argument("--base_model", type=str, default="bert-base-uncased")
    parser.add_argument(
        "--on_memory",
        action='store_true',
        help="Whether to load train samples into memory or use disk")
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help=
        "Whether to lower case the input text. True for uncased models, False for cased models."
    )

    args = parser.parse_args()
    logger.info(args)
    device = torch.device(args.device)
    tokenizer = BertTokenizer.from_pretrained(args.base_model)

    train_dataset = BERTDataset(args.train_file,
                                tokenizer,
                                seq_len=args.max_seq_length,
                                corpus_lines=None,
                                on_memory=args.on_memory)
    train_data_loader = DataLoader(train_dataset, batch_size=args.batch_size)

    model = BertForPreTraining.from_pretrained(args.base_model)

    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    steps = len(train_data_loader.dataset) // train_data_loader.batch_size
    steps = steps if steps > 0 else 1
    logger.info('steps:%d' % steps)

    lr_warmup = get_cosine_schedule_with_warmup(optimizer=optimizer,
                                                num_warmup_steps=1500,
                                                num_training_steps=steps *
                                                args.n_epochs)

    multi_gpu = False
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        gpu_num = torch.cuda.device_count()
        gpu_list = [int(i) for i in range(gpu_num)]
        model = DataParallel(model, device_ids=gpu_list)
        multi_gpu = True

    if torch.cuda.is_available():
        model.cuda()

    # model.to(device)
    # criterion.to(device)

    def update(engine, batch):
        model.train()
        # batch layout: input_ids, input_mask (attention_mask), segment_ids
        # (token_type_ids), lm_label_ids, is_next
        batch = tuple(t.to(device) for t in batch)

        loss = model(input_ids=batch[0],
                     attention_mask=batch[1],
                     token_type_ids=batch[2],
                     masked_lm_labels=batch[3],
                     next_sentence_label=batch[4])

        if multi_gpu:
            # DataParallel returns one loss per replica; reduce to a scalar
            loss = loss.mean()
        # scale the loss so accumulated gradients average over the virtual batch
        loss = loss / args.gradient_accumulation_steps
        loss.backward()

        if engine.state.iteration % args.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

        # the warmup schedule was sized as steps * n_epochs, i.e. one step per
        # batch, so it is advanced every iteration
        lr_warmup.step()

        return loss.cpu().item()

    trainer = Engine(update)

    # def inference(engine, batch):
    #     model.eval()
    #     with torch.no_grad():
    #         input_ids = batch[0].to(device)
    #         attention_mask = batch[1].to(device)
    #         labels = batch[2].to(device)
    #         output = model(input_ids=input_ids, attention_mask=attention_mask)
    #
    #         predict = output.permute(1, 2, 0)
    #         trg = labels.permute(1, 0)
    #         loss = criterion(predict.to(device), trg.to(device))
    # return predict, trg
    #
    # evaluator = Engine(inference)
    # metrics = {"nll": Loss(criterion, output_transform=lambda x: (x[0], x[1])),
    #            "accuracy": Accuracy(output_transform=lambda x: (x[0], x[1]))}
    # for name, metric in metrics.items():
    #     metric.attach(evaluator, name)
    #
    # @trainer.on(Events.EPOCH_COMPLETED)
    # def log_validation_results(trainer):
    #     evaluator.run(valid_data_loader)
    #     ms = evaluator.state.metrics
    #     logger.info("Validation Results - Epoch: [{}/{}]  Avg accuracy: {:.6f} Avg loss: {:.6f}"
    #           .format(trainer.state.epoch, trainer.state.max_epochs, ms['accuracy'], ms['nll']))

    #
    '''======================early stopping =========================='''
    # def score_function(engine):
    #     val_loss = engine.state.metrics['nll']
    #     return -val_loss
    # handler = EarlyStopping(patience=5, score_function=score_function, trainer=trainer)
    # evaluator.add_event_handler(Events.COMPLETED, handler)
    '''==================print information by iterator========================='''

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(trainer):
        if trainer.state.iteration % args.log_step == 0:
            logger.info("Epoch[{}/{}] Step[{}/{}] Loss: {:.6f}".format(
                trainer.state.epoch, trainer.state.max_epochs,
                trainer.state.iteration % steps, steps,
                trainer.state.output * args.gradient_accumulation_steps))

    '''================add check point========================'''
    checkpoint_handler = ModelCheckpoint(checkpoint_dir,
                                         'checkpoint',
                                         n_saved=3)
    trainer.add_event_handler(
        Events.EPOCH_COMPLETED, checkpoint_handler,
        {'BertClassificationModel': getattr(model, 'module', model)
         })  # "getattr" take care of distributed encapsulation
    '''==============run trainer============================='''
    trainer.run(train_data_loader, max_epochs=args.n_epochs)
Exemplo n.º 26
0
    def train(self, config, **kwargs):
        """Trains a model on the given configurations.
        :param config: A training configuration. Note that all parameters in the config can also be manually adjusted with --ARG=VALUE
        :param **kwargs: parameters to overwrite yaml config
        """
        from pycocoevalcap.cider.cider import Cider

        config_parameters = train_util.parse_config_or_kwargs(config, **kwargs)
        config_parameters["seed"] = self.seed
        outputdir = os.path.join(
            config_parameters["outputpath"], config_parameters["model"],
            "{}_{}".format(
                datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'),
                uuid.uuid1().hex))

        # Early init because of creating dir
        checkpoint_handler = ModelCheckpoint(
            outputdir,
            "run",
            n_saved=1,
            require_empty=False,
            create_dir=True,
            score_function=lambda engine: engine.state.metrics["score"],
            score_name="score")

        logger = train_util.genlogger(os.path.join(outputdir, "train.log"))
        # print passed config parameters
        logger.info("Storing files in: {}".format(outputdir))
        train_util.pprint_dict(config_parameters, logger.info)

        zh = config_parameters["zh"]
        vocabulary = torch.load(config_parameters["vocab_file"])
        train_loader, cv_loader, info = self._get_dataloaders(
            config_parameters, vocabulary)
        config_parameters["inputdim"] = info["inputdim"]
        cv_key2refs = info["cv_key2refs"]
        logger.info("<== Estimating Scaler ({}) ==>".format(
            info["scaler"].__class__.__name__))
        logger.info("Feature: {} Input dimension: {} Vocab Size: {}".format(
            config_parameters["feature_file"], info["inputdim"],
            len(vocabulary)))

        model = self._get_model(config_parameters, len(vocabulary))
        if "pretrained_word_embedding" in config_parameters:
            embeddings = np.load(
                config_parameters["pretrained_word_embedding"])
            model.load_word_embeddings(
                embeddings,
                tune=config_parameters["tune_word_embedding"],
                projection=True)
        model = model.to(self.device)
        train_util.pprint_dict(model, logger.info, formatter="pretty")
        optimizer = getattr(torch.optim, config_parameters["optimizer"])(
            model.parameters(), **config_parameters["optimizer_args"])
        train_util.pprint_dict(optimizer, logger.info, formatter="pretty")

        criterion = torch.nn.CrossEntropyLoss().to(self.device)
        crtrn_imprvd = train_util.criterion_improver(
            config_parameters['improvecriterion'])

        def _train_batch(engine, batch):
            model.train()
            with torch.enable_grad():
                optimizer.zero_grad()
                output = self._forward(model, batch, "train")
                loss = criterion(output["packed_logits"],
                                 output["targets"]).to(self.device)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
                optimizer.step()
                output["loss"] = loss.item()
                return output

        trainer = Engine(_train_batch)
        RunningAverage(output_transform=lambda x: x["loss"]).attach(
            trainer, "running_loss")
        pbar = ProgressBar(persist=False, ascii=True, ncols=100)
        pbar.attach(trainer, ["running_loss"])

        key2pred = {}

        def _inference(engine, batch):
            model.eval()
            keys = batch[2]
            with torch.no_grad():
                output = self._forward(model, batch, "validation")
                seqs = output["seqs"].cpu().numpy()
                for (idx, seq) in enumerate(seqs):
                    if keys[idx] in key2pred:
                        continue
                    candidate = self._convert_idx2sentence(seq, vocabulary, zh)
                    key2pred[keys[idx]] = [
                        candidate,
                    ]
                return output

        metrics = {
            "loss":
            Loss(criterion,
                 output_transform=lambda x: (x["packed_logits"], x["targets"]))
        }

        evaluator = Engine(_inference)

        def eval_cv(engine, key2pred, key2refs):
            scorer = Cider(zh=zh)
            score, scores = scorer.compute_score(key2refs, key2pred)
            engine.state.metrics["score"] = score
            key2pred.clear()

        evaluator.add_event_handler(Events.EPOCH_COMPLETED, eval_cv, key2pred,
                                    cv_key2refs)

        for name, metric in metrics.items():
            metric.attach(evaluator, name)

        trainer.add_event_handler(Events.EPOCH_COMPLETED,
                                  train_util.log_results, evaluator, cv_loader,
                                  logger.info, ["loss", "score"])

        evaluator.add_event_handler(
            Events.EPOCH_COMPLETED, train_util.save_model_on_improved,
            crtrn_imprvd, "score", {
                "model": model.state_dict(),
                "config": config_parameters,
                "scaler": info["scaler"]
            }, os.path.join(outputdir, "saved.pth"))

        scheduler = getattr(torch.optim.lr_scheduler,
                            config_parameters["scheduler"])(
                                optimizer,
                                **config_parameters["scheduler_args"])
        evaluator.add_event_handler(Events.EPOCH_COMPLETED,
                                    train_util.update_lr, scheduler, "score")

        evaluator.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler,
                                    {
                                        "model": model,
                                    })

        trainer.run(train_loader, max_epochs=config_parameters["epochs"])
        return outputdir
Exemplo n.º 27
0
    def fit(self, dataset, fold=0, train_split='train', valid_split='val'):
        """Fit the predictor model.

    Args:
      - dataset: temporal, static, label, time, treatment information
      - fold: Cross validation fold
      - train_split: training set splitting parameter
      - valid_split: validation set splitting parameter

    Returns:
      - self.predictor_model: trained predictor model
    """
        train_x, train_y = self._data_preprocess(dataset, fold, train_split)
        valid_x, valid_y = self._data_preprocess(dataset, fold, valid_split)

        train_dataset = torch.utils.data.dataset.TensorDataset(
            self._make_tensor(train_x), self._make_tensor(train_y))
        valid_dataset = torch.utils.data.dataset.TensorDataset(
            self._make_tensor(valid_x), self._make_tensor(valid_y))

        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=self.batch_size,
                                                   shuffle=True)
        val_loader = torch.utils.data.DataLoader(valid_dataset,
                                                 batch_size=self.batch_size,
                                                 shuffle=True)

        if self.predictor_model is None:
            self.predictor_model = TransformerModule(
                self.task, dataset.problem, train_x.shape[-1], self.h_dim,
                train_y.shape[-1], self.n_head, self.n_layer).to(self.device)
            self.optimizer = torch.optim.Adam(
                self.predictor_model.parameters(), lr=self.learning_rate)

        self.predictor_model.train()

        # classification vs regression
        # static vs dynamic
        trainer = create_supervised_trainer(self.predictor_model,
                                            self.optimizer,
                                            self.predictor_model.loss_fn)
        evaluator = create_supervised_evaluator(
            self.predictor_model,
            metrics={'loss': Loss(self.predictor_model.loss_fn)})
        # model check point
        checkpoint_handler = ModelCheckpoint(self.model_path,
                                             self.model_id,
                                             n_saved=1,
                                             create_dir=True,
                                             require_empty=False)
        trainer.add_event_handler(Events.EPOCH_COMPLETED(every=1),
                                  checkpoint_handler,
                                  {'model': self.predictor_model})

        # early stopping
        def score_function(engine):
            val_loss = engine.state.metrics['loss']
            return -val_loss

        early_stopping_handler = EarlyStopping(patience=10,
                                               score_function=score_function,
                                               trainer=trainer)
        evaluator.add_event_handler(Events.COMPLETED, early_stopping_handler)

        # evaluation loss
        @trainer.on(Events.EPOCH_COMPLETED)
        def log_validation_results(trainer):
            evaluator.run(val_loader)
            metrics = evaluator.state.metrics
            print("Validation Results - Epoch[{}] Avg loss: {:.2f}".format(
                trainer.state.epoch, metrics['loss']))

        trainer.run(train_loader, max_epochs=self.epoch)

        return self.predictor_model
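The ModelCheckpoint handler registered above writes the model's state_dict under `self.model_path` with the `self.model_id` prefix (only one file is kept, since n_saved=1). A minimal reload sketch, assuming `model_path`, `model_id` and a freshly constructed `predictor_model` are in scope and that ignite's default `{prefix}_{name}_{step}` file naming applies; both the glob pattern and these names are assumptions, not part of the snippet:

import glob
import os

import torch

# Pick the checkpoint file produced by the handler (hypothetical naming pattern;
# with n_saved=1 only a single file should match).
ckpt_files = sorted(glob.glob(os.path.join(model_path, model_id + "_model_*.pt*")))
state_dict = torch.load(ckpt_files[-1], map_location="cpu")
predictor_model.load_state_dict(state_dict)
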
Exemplo n.º 28
0
def run(output_path, config):
    device = "cuda"

    local_rank = config['local_rank']
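    # NOTE: `backend` is not defined in this snippet; in the original script it is a
    # module-level setting such as "nccl" (distributed) or None (single process).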
    distributed = backend is not None
    if distributed:
        torch.cuda.set_device(local_rank)
        device = "cuda"
    rank = dist.get_rank() if distributed else 0

    # Rescale batch_size and num_workers
    ngpus_per_node = torch.cuda.device_count()
    ngpus = dist.get_world_size() if distributed else 1
    batch_size = config['batch_size'] // ngpus
    num_workers = int(
        (config['num_workers'] + ngpus_per_node - 1) / ngpus_per_node)

    train_labelled_loader, test_loader = \
        get_train_test_loaders(path=config['data_path'],
                               batch_size=batch_size,
                               distributed=distributed,
                               num_workers=num_workers)

    model = get_model(config['model'])
    model = model.to(device)

    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[
                local_rank,
            ], output_device=local_rank)

    optimizer = optim.SGD(model.parameters(),
                          lr=config['learning_rate'],
                          momentum=config['momentum'],
                          weight_decay=config['weight_decay'],
                          nesterov=True)

    criterion = nn.CrossEntropyLoss().to(device)

    le = len(train_labelled_loader)
    milestones_values = [(0, 0.0),
                         (le * config['num_warmup_epochs'],
                          config['learning_rate']),
                         (le * config['num_epochs'], 0.0)]
    lr_scheduler = PiecewiseLinear(optimizer,
                                   param_name="lr",
                                   milestones_values=milestones_values)

    def _prepare_batch(batch, device, non_blocking):
        x, y = batch
        return (convert_tensor(x, device=device, non_blocking=non_blocking),
                convert_tensor(y, device=device, non_blocking=non_blocking))

    def process_function(engine, labelled_batch):

        x, y = _prepare_batch(labelled_batch, device=device, non_blocking=True)

        model.train()
        # Supervised part
        y_pred = model(x)
        loss = criterion(y_pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        return {
            'batch loss': loss.item(),
        }

    trainer = Engine(process_function)

    if not hasattr(lr_scheduler, "step"):
        trainer.add_event_handler(Events.ITERATION_STARTED, lr_scheduler)
    else:
        trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                  lambda engine: lr_scheduler.step())

    metric_names = [
        'batch loss',
    ]

    def output_transform(x, name):
        return x[name]

    for n in metric_names:
        # We compute running average values on the output (batch loss) across all devices
        RunningAverage(output_transform=partial(output_transform, name=n),
                       epoch_bound=False,
                       device=device).attach(trainer, n)

    if rank == 0:
        checkpoint_handler = ModelCheckpoint(dirname=output_path,
                                             filename_prefix="checkpoint")
        trainer.add_event_handler(Events.ITERATION_COMPLETED(every=1000),
                                  checkpoint_handler, {
                                      'model': model,
                                      'optimizer': optimizer
                                  })

        ProgressBar(persist=True,
                    bar_format="").attach(trainer,
                                          event_name=Events.EPOCH_STARTED,
                                          closing_event_name=Events.COMPLETED)
        if config['display_iters']:
            ProgressBar(persist=False,
                        bar_format="").attach(trainer,
                                              metric_names=metric_names)

        tb_logger = TensorboardLogger(log_dir=output_path)
        tb_logger.attach(trainer,
                         log_handler=tbOutputHandler(
                             tag="train", metric_names=metric_names),
                         event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer,
                         log_handler=tbOptimizerParamsHandler(optimizer,
                                                              param_name="lr"),
                         event_name=Events.ITERATION_STARTED)

    metrics = {
        "accuracy": Accuracy(device=device if distributed else None),
        "loss": Loss(criterion, device=device if distributed else None)
    }

    evaluator = create_supervised_evaluator(model,
                                            metrics=metrics,
                                            device=device,
                                            non_blocking=True)
    train_evaluator = create_supervised_evaluator(model,
                                                  metrics=metrics,
                                                  device=device,
                                                  non_blocking=True)

    def run_validation(engine):
        torch.cuda.synchronize()
        train_evaluator.run(train_labelled_loader)
        evaluator.run(test_loader)

    trainer.add_event_handler(Events.EPOCH_STARTED(every=3), run_validation)
    trainer.add_event_handler(Events.COMPLETED, run_validation)

    if rank == 0:
        if config['display_iters']:
            ProgressBar(persist=False,
                        desc="Train evaluation").attach(train_evaluator)
            ProgressBar(persist=False,
                        desc="Test evaluation").attach(evaluator)

        tb_logger.attach(train_evaluator,
                         log_handler=tbOutputHandler(tag="train",
                                                     metric_names=list(
                                                         metrics.keys()),
                                                     another_engine=trainer),
                         event_name=Events.COMPLETED)

        tb_logger.attach(evaluator,
                         log_handler=tbOutputHandler(tag="test",
                                                     metric_names=list(
                                                         metrics.keys()),
                                                     another_engine=trainer),
                         event_name=Events.COMPLETED)

        # Store the best model
        def default_score_fn(engine):
            score = engine.state.metrics['accuracy']
            return score

        # Use the config's own score_function when it provides one, otherwise validation accuracy.
        score_function = getattr(config, "score_function", default_score_fn)

        best_model_handler = ModelCheckpoint(
            dirname=output_path,
            filename_prefix="best",
            n_saved=3,
            global_step_transform=global_step_from_engine(trainer),
            score_name="val_accuracy",
            score_function=score_function)
        evaluator.add_event_handler(Events.COMPLETED, best_model_handler, {
            'model': model,
        })

    trainer.run(train_labelled_loader, max_epochs=config['num_epochs'])

    if rank == 0:
        tb_logger.close()
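The `run` function above relies on `dist` (torch.distributed) and a module-level `backend` that the snippet does not show. A minimal, hypothetical driver for it, with placeholder config values; every name and value below is an assumption, not part of the snippet:

import torch.distributed as dist

backend = "nccl"  # assumed module-level value; None would mean single-process training

config = {
    "local_rank": 0,
    "data_path": "/tmp/cifar10",
    "model": "resnet18",
    "batch_size": 512,
    "num_workers": 8,
    "learning_rate": 0.1,
    "momentum": 0.9,
    "weight_decay": 1e-4,
    "num_warmup_epochs": 4,
    "num_epochs": 24,
    "display_iters": True,
}

if backend is not None:
    # Relies on the usual env vars (RANK, WORLD_SIZE, MASTER_ADDR, ...) being set,
    # e.g. when the script is launched through torchrun / torch.distributed.launch.
    dist.init_process_group(backend, init_method="env://")

run("/tmp/output", config)

if backend is not None:
    dist.destroy_process_group()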
Exemplo n.º 29
0
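# NOTE: `loss`, `net` and `opt` below are assumed to be created earlier in the original
# script (e.g. a Dice-style segmentation loss, the network and its optimizer).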
def _loss_fn(i, j):
    return loss(i[0], j)


# Create trainer
device = torch.device("cuda:0")
trainer = create_supervised_trainer(net,
                                    opt,
                                    _loss_fn,
                                    device,
                                    False,
                                    output_transform=lambda x, y, y_pred, loss:
                                    [y_pred, loss.item(), y])

checkpoint_handler = ModelCheckpoint('./',
                                     'net',
                                     n_saved=10,
                                     require_empty=False)
trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                          handler=checkpoint_handler,
                          to_save={
                              'net': net,
                              'opt': opt
                          })

dice_metric = MeanDice(add_sigmoid=True,
                       output_transform=lambda output:
                       (output[0][0], output[2]))
dice_metric.attach(trainer, "Training Dice")

logging.basicConfig(stream=sys.stdout, level=logging.INFO)
stats_logger = StatsHandler()
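The snippet ends before training is started; a minimal, hypothetical continuation, assuming a `train_loader` DataLoader and an epoch budget exist, would attach the MONAI StatsHandler and run the trainer:

# Hypothetical continuation (not part of the original snippet).
stats_logger.attach(trainer)              # log per-iteration/epoch stats through `logging`
trainer.run(train_loader, max_epochs=50)  # `train_loader` and max_epochs are assumptions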
Exemplo n.º 30
0
        print("TEST EVAL")
        evaluator.run(test_ld)
        test_wra_vle = round(evaluator.state.metrics["WRA"], 3)
        report = f"{RUN_NAME};{test_wra_vle}\n"
        with EVALUATION_RESULTS_FILE_PATH.open(mode='a') as f:
            f.write(report)
        print(f"TRAINING IS DONE FOR {RUN_NAME} RUN.")

    pbar = ProgressBar()

    checkpointer = ModelCheckpoint(
        CHECKPOINTS_RUN_DIR_PATH,
        filename_prefix=RUN_NAME.lower(),
        n_saved=None,
        score_function=lambda engine: round(engine.state.metrics['WRA'], 3),
        score_name='WRA',
        atomic=True,
        require_empty=True,
        create_dir=True,
        archived=False,
        global_step_transform=global_step_from_engine(trainer))
    nan_handler = TerminateOnNan()
    coslr = CosineAnnealingScheduler(opt,
                                     "lr",
                                     start_value=LR,
                                     end_value=LR / 4,
                                     cycle_size=TOTAL_UPDATE_STEPS // 1)

    evaluator.add_event_handler(Events.EPOCH_COMPLETED, checkpointer,
                                {'_': mude})
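The fragment cuts off before `nan_handler` and `coslr` are registered; with ignite they would typically be attached to the trainer as sketched below (an assumption about the missing tail of the snippet, not original code):

    # Hypothetical continuation: stop training on NaN losses and drive the cosine LR schedule.
    trainer.add_event_handler(Events.ITERATION_COMPLETED, nan_handler)
    trainer.add_event_handler(Events.ITERATION_STARTED, coslr)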