Example 1
def train_network(model: nn.Module, training_loader: DataLoader,
                  validation_loader: DataLoader):
    """Trains the given neural network model.

    Parameters
    ----------
    model (nn.Module): The PyTorch model to be trained

    training_loader (DataLoader): Training data loader

    validation_loader (DataLoader): Validation data loader
    """
    device = "cuda:0" if cast(Any, torch).cuda.is_available() else "cpu"

    if device == "cuda:0":
        model.cuda()

    optimizer = cast(Any, torch).optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.MSELoss()

    trainer = create_supervised_trainer(model,
                                        optimizer,
                                        criterion,
                                        device=device)

    save_handler = Checkpoint(
        {
            "model": model,
            "optimizer": optimizer,
            "trainer": trainer
        },
        DiskSaver("dist/models", create_dir=True),
        n_saved=2,
    )
    trainer.add_event_handler(Events.EPOCH_COMPLETED(every=100), save_handler)

    # Create a logger
    tb_logger = TensorboardLogger(log_dir="logs/training" +
                                  datetime.now().strftime("-%Y%m%d-%H%M%S"),
                                  flush_secs=1)

    tb_logger.attach_output_handler(
        trainer,
        event_name=Events.ITERATION_COMPLETED,
        tag="training",
        output_transform=lambda loss: {"loss": loss},
    )

    # Training evaluator
    training_evaluator = create_supervised_evaluator(model,
                                                     metrics={
                                                         "r2": R2Score(),
                                                         "MSELoss":
                                                         Loss(criterion)
                                                     },
                                                     device=device)

    tb_logger.attach_output_handler(
        training_evaluator,
        event_name=Events.EPOCH_COMPLETED,
        tag="training",
        metric_names=["MSELoss", "r2"],
        global_step_transform=global_step_from_engine(trainer),
    )

    # Validation evaluator
    evaluator = create_supervised_evaluator(model,
                                            metrics={
                                                "r2": R2Score(),
                                                "MSELoss": Loss(criterion)
                                            },
                                            device=device)

    tb_logger.attach_output_handler(
        evaluator,
        event_name=Events.EPOCH_COMPLETED,
        tag="validation",
        metric_names=["MSELoss", "r2"],
        global_step_transform=global_step_from_engine(trainer),
    )

    @trainer.on(Events.EPOCH_COMPLETED(every=10))
    def log_training_results(trainer):
        training_evaluator.run(training_loader)

        metrics = training_evaluator.state.metrics
        print(
            f"Training Results - Epoch: {trainer.state.epoch}",
            f" Avg r2: {metrics['r2']:.2f} Avg loss: {metrics['MSELoss']:.2f}",
        )

    @trainer.on(Events.EPOCH_COMPLETED(every=10))
    def log_validation_results(trainer):
        evaluator.run(validation_loader)

        metrics = evaluator.state.metrics
        print(
            f"Validation Results - Epoch: {trainer.state.epoch}",
            f" Avg r2: {metrics['r2']:.2f} Avg loss: {metrics['MSELoss']:.2f}\n",
        )

    trainer.run(training_loader, max_epochs=int(1e6))
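
The snippet above relies on imports from its original module. A plausible import block that would make it self-contained is sketched below; module paths follow the PyTorch Ignite 0.4.x layout and may need adjusting for other versions.

from datetime import datetime
from typing import Any, cast

import torch
from torch import nn
from torch.utils.data import DataLoader

from ignite.contrib.handlers import TensorboardLogger, global_step_from_engine
from ignite.contrib.metrics.regression import R2Score
from ignite.engine import Events, create_supervised_evaluator, create_supervised_trainer
from ignite.handlers import Checkpoint, DiskSaver
from ignite.metrics import Loss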
Example 2
def run(train_batch_size, val_batch_size, epochs, lr, momentum, log_dir):
    train_loader, val_loader = get_data_loaders(train_batch_size,
                                                val_batch_size)
    model = Net()
    device = "cpu"

    if torch.cuda.is_available():
        device = "cuda"

    model.to(device)  # Move model before creating optimizer
    optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
    criterion = nn.CrossEntropyLoss()
    trainer = create_supervised_trainer(model,
                                        optimizer,
                                        criterion,
                                        device=device)
    trainer.logger = setup_logger("Trainer")

    if sys.version_info > (3, ):
        from ignite.contrib.metrics.gpu_info import GpuInfo

        try:
            GpuInfo().attach(trainer)
        except RuntimeError:
            print(
                "INFO: This example can log GPU information (used memory, utilization), "
                "but the pynvml package is not installed, so GPU information won't be logged. "
                "To enable it, install the package: `pip install pynvml`")

    metrics = {"accuracy": Accuracy(), "loss": Loss(criterion)}

    train_evaluator = create_supervised_evaluator(model,
                                                  metrics=metrics,
                                                  device=device)
    train_evaluator.logger = setup_logger("Train Evaluator")
    validation_evaluator = create_supervised_evaluator(model,
                                                       metrics=metrics,
                                                       device=device)
    validation_evaluator.logger = setup_logger("Val Evaluator")

    @trainer.on(Events.EPOCH_COMPLETED)
    def compute_metrics(engine):
        train_evaluator.run(train_loader)
        validation_evaluator.run(val_loader)

    tb_logger = TensorboardLogger(log_dir=log_dir)

    tb_logger.attach_output_handler(
        trainer,
        event_name=Events.ITERATION_COMPLETED(every=100),
        tag="training",
        output_transform=lambda loss: {"batchloss": loss},
        metric_names="all",
    )

    for tag, evaluator in [("training", train_evaluator),
                           ("validation", validation_evaluator)]:
        tb_logger.attach_output_handler(
            evaluator,
            event_name=Events.EPOCH_COMPLETED,
            tag=tag,
            metric_names=["loss", "accuracy"],
            global_step_transform=global_step_from_engine(trainer),
        )

    tb_logger.attach_opt_params_handler(
        trainer,
        event_name=Events.ITERATION_COMPLETED(every=100),
        optimizer=optimizer)

    tb_logger.attach(trainer,
                     log_handler=WeightsScalarHandler(model),
                     event_name=Events.ITERATION_COMPLETED(every=100))

    tb_logger.attach(trainer,
                     log_handler=WeightsHistHandler(model),
                     event_name=Events.EPOCH_COMPLETED(every=100))

    tb_logger.attach(trainer,
                     log_handler=GradsScalarHandler(model),
                     event_name=Events.ITERATION_COMPLETED(every=100))

    tb_logger.attach(trainer,
                     log_handler=GradsHistHandler(model),
                     event_name=Events.EPOCH_COMPLETED(every=100))

    def score_function(engine):
        return engine.state.metrics["accuracy"]

    model_checkpoint = ModelCheckpoint(
        log_dir,
        n_saved=2,
        filename_prefix="best",
        score_function=score_function,
        score_name="validation_accuracy",
        global_step_transform=global_step_from_engine(trainer),
    )
    validation_evaluator.add_event_handler(Events.COMPLETED, model_checkpoint,
                                           {"model": model})

    # kick everything off
    trainer.run(train_loader, max_epochs=epochs)

    tb_logger.close()
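
The surrounding script presumably parses these hyperparameters from the command line. As a usage illustration only, a hypothetical entry point with placeholder values could look like this:

if __name__ == "__main__":
    run(train_batch_size=64,
        val_batch_size=1000,
        epochs=10,
        lr=0.01,
        momentum=0.5,
        log_dir="tensorboard_logs")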
Example 3
def train():
    parser = ArgumentParser()
    parser.add_argument(
        "--dataset_path",
        type=str,
        default="data/korean/",
        help="Path or url of the dataset. If empty download from S3.")
    parser.add_argument("--dataset_cache",
                        type=str,
                        default='./dataset_cache',
                        help="Path or url of the dataset cache")
    parser.add_argument("--model_checkpoint",
                        type=str,
                        default="gpt2",
                        help="Path, url or short name of the model")
    parser.add_argument("--model_version",
                        type=str,
                        default='v4',
                        help="version of model")
    parser.add_argument("--num_candidates",
                        type=int,
                        default=2,
                        help="Number of candidates for training")
    parser.add_argument("--max_history",
                        type=int,
                        default=30,
                        help="Number of previous exchanges to keep in history")
    parser.add_argument("--train_batch_size",
                        type=int,
                        default=1,
                        help="Batch size for training")
    parser.add_argument("--valid_batch_size",
                        type=int,
                        default=1,
                        help="Batch size for validation")
    parser.add_argument("--gradient_accumulation_steps",
                        type=int,
                        default=8,
                        help="Accumulate gradients on several steps")
    parser.add_argument("--lr",
                        type=float,
                        default=6.25e-5,
                        help="Learning rate")
    parser.add_argument("--lm_coef",
                        type=float,
                        default=1.0,
                        help="LM loss coefficient")
    parser.add_argument("--mc_coef",
                        type=float,
                        default=1.0,
                        help="Multiple-choice loss coefficient")
    parser.add_argument("--max_norm",
                        type=float,
                        default=1.0,
                        help="Clipping gradient norm")
    parser.add_argument("--n_epochs",
                        type=int,
                        default=5,
                        help="Number of training epochs")
    parser.add_argument(
        "--eval_before_start",
        action='store_true',
        help="If true start with a first evaluation before training")
    parser.add_argument("--device",
                        type=str,
                        default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device (cuda or cpu)")
    parser.add_argument(
        "--fp16",
        type=str,
        default="",
        help=
        "Set to O0, O1, O2 or O3 for fp16 training (see apex documentation)")
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="Local rank for distributed training (-1: not distributed)")
    args = parser.parse_args()

    # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process. logger.info => log main process only, logger.warning => log all processes
    logging.basicConfig(
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning(
        "Running process %d", args.local_rank
    )  # This is a logger.warning: it will be printed by all distributed processes
    logger.info("Arguments: %s", pformat(args))

    # Initialize distributed training if needed
    args.distributed = (args.local_rank != -1)
    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        args.device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')

    logger.info(
        "Prepare tokenizer, pretrained model and optimizer - add special tokens for fine-tuning"
    )
    torch.manual_seed(42)

    def get_kogpt2_tokenizer(model_path=None):
        if not model_path:
            model_path = 'taeminlee/kogpt2'
        tokenizer = GPT2Tokenizer.from_pretrained(model_path)
        return tokenizer

    tokenizer = get_kogpt2_tokenizer()
    optimizer_class = AdamW
    model = get_kogpt2_model()
    model.to(args.device)
    optimizer = optimizer_class(model.parameters(), lr=args.lr)

    # tokenizer = tokenizer_class.from_pretrained(args.model_checkpoint, unk_token='<|unkwn|>')
    SPECIAL_TOKENS_DICT = {'additional_special_tokens': SPECIAL_TOKENS}

    # tokenizer.add_special_tokens(SPECIAL_TOKENS_DICT)
    print("SPECIAL TOKENS")
    print(SPECIAL_TOKENS)
    tokenizer.add_special_tokens(SPECIAL_TOKENS_DICT)

    for value in SPECIAL_TOKENS:
        logger.info("Assigning %s to the %s key of the tokenizer", value,
                    value)
        setattr(tokenizer, value, value)
    model.resize_token_embeddings(len(tokenizer))

    s = ' '.join(act_name) + ' ' + ' '.join(slot_name)
    print(tokenizer.decode(tokenizer.encode(s)))
    print(len(act_name) + len(slot_name), len(tokenizer.encode(s)))
    # Prepare model for FP16 and distributed training if needed (order is important, distributed should be the last)
    if args.fp16:
        from apex import amp  # Apex is only required if we use fp16 training
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16)
    if args.distributed:
        model = DistributedDataParallel(model,
                                        device_ids=[args.local_rank],
                                        output_device=args.local_rank)

    logger.info("Prepare datasets")
    train_loader, val_loader, train_sampler, valid_sampler = get_data_loaders(
        args, tokenizer)

    # Training function and trainer

    def update(engine, batch):
        model.train()
        batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
        lm_loss, mc_loss, *_ = model(*batch)
        loss = (lm_loss * args.lm_coef +
                mc_loss * args.mc_coef) / args.gradient_accumulation_steps
        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                           args.max_norm)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
        if engine.state.iteration % args.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        return loss.item()

    trainer = Engine(update)
    trainer.logger.setLevel(logging.INFO)

    # Evaluation function and evaluator (evaluator output is the input of the metrics)

    def inference(engine, batch):
        model.eval()
        with torch.no_grad():
            batch = tuple(
                input_tensor.to(args.device) for input_tensor in batch)
            input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids = batch
            logger.info(tokenizer.decode(input_ids[0, -1, :].tolist()))
            model_outputs = model(input_ids,
                                  mc_token_ids,
                                  token_type_ids=token_type_ids)
            lm_logits, mc_logits = model_outputs[0], model_outputs[
                1]  # So we can also use GPT2 outputs
            lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view(
                -1, lm_logits.size(-1))
            lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1)
            return (lm_logits_flat_shifted,
                    mc_logits), (lm_labels_flat_shifted, mc_labels)

    evaluator = Engine(inference)
    evaluator.logger.setLevel(logging.INFO)

    # Attach evaluation to trainer: we evaluate when we start the training and at the end of each epoch
    trainer.add_event_handler(Events.EPOCH_COMPLETED,
                              lambda _: evaluator.run(val_loader))
    if args.n_epochs < 1:
        trainer.add_event_handler(Events.COMPLETED,
                                  lambda _: evaluator.run(val_loader))
    if args.eval_before_start:
        trainer.add_event_handler(Events.STARTED,
                                  lambda _: evaluator.run(val_loader))

    # Make sure distributed data samplers split the dataset nicely between the distributed processes
    if args.distributed:
        trainer.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: train_sampler.set_epoch(engine.state.epoch))
        evaluator.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: valid_sampler.set_epoch(engine.state.epoch))

    # Linearly decrease the learning rate from lr to zero
    scheduler = PiecewiseLinear(optimizer, "lr",
                                [(0, args.lr),
                                 (args.n_epochs * len(train_loader), 0.0)])
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # Prepare metrics - note how we compute distributed metrics
    RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
    metrics = {
        "nll":
        Loss(torch.nn.CrossEntropyLoss(ignore_index=-1),
             output_transform=lambda x: (x[0][0], x[1][0])),
        "accuracy":
        Accuracy(output_transform=lambda x: (x[0][1], x[1][1]))
    }
    metrics.update({
        "average_nll":
        MetricsLambda(average_distributed_scalar, metrics["nll"], args),
        "average_accuracy":
        MetricsLambda(average_distributed_scalar, metrics["accuracy"], args)
    })
    metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"])
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # On the main process: add progress bar, tensorboard, checkpoints and save model, configuration and tokenizer before we start to train
    if args.local_rank in [-1, 0]:
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer, metric_names=["loss"])
        evaluator.add_event_handler(
            Events.COMPLETED, lambda _: pbar.log_message(
                "Validation: %s" % pformat(evaluator.state.metrics)))

        tb_logger = TensorboardLogger(log_dir=None)
        tb_logger.writer.log_dir = tb_logger.writer.file_writer.get_logdir()
        tb_logger.attach(trainer,
                         log_handler=OutputHandler(tag="training",
                                                   metric_names=["loss"]),
                         event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer,
                         log_handler=OptimizerParamsHandler(optimizer),
                         event_name=Events.ITERATION_STARTED)
        """tb_logger.attach(evaluator, log_handler=OutputHandler(tag="validation", metric_names=list(metrics.keys()),
                                                              global_step_transform=global_step_from_engine(trainer)),
                         event_name=Events.EPOCH_COMPLETED)"""
        tb_logger.attach_output_handler(
            evaluator,
            event_name=Events.EPOCH_COMPLETED,
            tag="validation",
            metric_names=list(metrics.keys()),
            global_step_transform=global_step_from_engine(trainer))

        checkpoint_handler = ModelCheckpoint(tb_logger.writer.log_dir,
                                             'checkpoint',
                                             save_interval=1,
                                             n_saved=3)
        trainer.add_event_handler(
            Events.EPOCH_COMPLETED, checkpoint_handler,
            {'mymodel': getattr(model, 'module', model)
             })  # "getattr" take care of distributed encapsulation

        torch.save(args, tb_logger.writer.log_dir + '/model_training_args.bin')
        getattr(model, 'module', model).config.to_json_file(
            os.path.join(tb_logger.writer.log_dir, CONFIG_NAME))
        tokenizer.save_vocabulary(tb_logger.writer.log_dir)

    # Run the training
    trainer.run(train_loader, max_epochs=args.n_epochs)

    # On the main process: close tensorboard logger and rename the last checkpoint (for easy re-loading with OpenAIGPTModel.from_pretrained method)
    if args.local_rank in [-1, 0] and args.n_epochs > 0:
        os.rename(
            checkpoint_handler._saved[-1][1][-1],
            os.path.join(tb_logger.writer.log_dir, WEIGHTS_NAME)
        )  # TODO: PR in ignite to have better access to saved file paths (cleaner)
        tb_logger.close()
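
Several names here (`get_kogpt2_model`, `get_data_loaders`, `SPECIAL_TOKENS`, `act_name`, `slot_name`, `average_distributed_scalar`) come from the surrounding project and are not shown. As an assumption about the averaging helper only (a sketch in the style of similar conversational fine-tuning scripts, not this project's actual code), it reduces a scalar metric across distributed processes:

def average_distributed_scalar(scalar, args):
    """Average a scalar over all processes when training is distributed."""
    if args.local_rank == -1:
        return scalar
    scalar_t = torch.tensor(scalar, dtype=torch.float, device=args.device)
    scalar_t /= torch.distributed.get_world_size()
    torch.distributed.all_reduce(scalar_t, op=torch.distributed.ReduceOp.SUM)
    return scalar_t.item()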
Example 4
def attach_handlers(run, model, optimizer, learning_rule, trainer, evaluator, train_loader, val_loader, params):
    # Metrics
    UnitConvergence(model[0], learning_rule.norm).attach(trainer.engine, 'unit_conv')

    # Tqdm logger
    pbar = ProgressBar(persist=True, bar_format=config.IGNITE_BAR_FORMAT)
    pbar.attach(trainer.engine, metric_names='all')
    tqdm_logger = TqdmLogger(pbar=pbar)
    # noinspection PyTypeChecker
    tqdm_logger.attach_output_handler(
        evaluator.engine,
        event_name=Events.COMPLETED,
        tag="validation",
        global_step_transform=global_step_from_engine(trainer.engine),
    )

    # Evaluator
    evaluator.attach(trainer.engine, Events.EPOCH_COMPLETED(every=100), train_loader, val_loader)

    # Learning rate scheduling
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer,
                                                     lr_lambda=lambda epoch: 1 - epoch / params['epochs'])
    lr_scheduler = LRScheduler(lr_scheduler)
    trainer.engine.add_event_handler(Events.EPOCH_COMPLETED, lr_scheduler)

    # Model checkpointing
    mc_handler = ModelCheckpoint(config.MODELS_DIR, run.replace('/', '-'), n_saved=1, create_dir=True,
                                 require_empty=False,
                                 global_step_transform=global_step_from_engine(trainer.engine))
    trainer.engine.add_event_handler(Events.EPOCH_COMPLETED, mc_handler, {'m': model})

    # Create a TensorBoard logger
    tb_logger = TensorboardLogger(log_dir=os.path.join(config.TENSORBOARD_DIR, run))
    images, labels = next(iter(train_loader))
    tb_logger.writer.add_graph(copy.deepcopy(model).cpu(), images)
    tb_logger.writer.add_hparams(params, {})

    # noinspection PyTypeChecker
    tb_logger.attach_output_handler(
        evaluator.engine,
        event_name=Events.COMPLETED,
        tag="validation",
        metric_names="all",
        global_step_transform=global_step_from_engine(trainer.engine),
    )
    # noinspection PyTypeChecker
    tb_logger.attach_output_handler(
        trainer.engine,
        event_name=Events.EPOCH_COMPLETED,
        tag="train",
        metric_names=["unit_conv"]
    )
    input_shape = tuple(next(iter(train_loader))[0].shape[1:])
    tb_logger.attach(trainer.engine,
                     log_handler=WeightsImageHandler(model, input_shape),
                     event_name=Events.EPOCH_COMPLETED)
    tb_logger.attach(trainer.engine, log_handler=OptimizerParamsHandler(optimizer), event_name=Events.EPOCH_STARTED)
    # tb_logger.attach(trainer.engine,
    #                  log_handler=WeightsScalarHandler(model, layer_names=['linear1', 'linear2']),
    #                  event_name=Events.EPOCH_COMPLETED)
    # tb_logger.attach(trainer.engine,
    #                  log_handler=WeightsHistHandler(model, layer_names=['linear1', 'linear2']),
    #                  event_name=Events.EPOCH_COMPLETED)
    # tb_logger.attach(trainer.engine,
    #                  log_handler=ActivationsHistHandler(model, layer_names=['batch_norm', 'repu']),
    #                  event_name=Events.ITERATION_COMPLETED)
    # tb_logger.attach(trainer.engine,
    #                  log_handler=NumActivationsScalarHandler(model, layer_names=['repu']),
    #                  event_name=Events.ITERATION_COMPLETED)
    # tb_logger.attach(trainer.engine,
    #                  log_handler=ActivationsScalarHandler(model, reduction=torch.mean,
    #                                                       layer_names=['batch_norm', 'repu']),
    #                  event_name=Events.ITERATION_COMPLETED)
    # tb_logger.attach(trainer.engine,
    #                  log_handler=ActivationsScalarHandler(model, reduction=torch.std,
    #                                                       layer_names=['batch_norm', 'repu']),
    #                  event_name=Events.ITERATION_COMPLETED)

    return tb_logger
Example 5
# Create evaluators
evaluator = create_evaluator(model, metrics=metrics)
train_evaluator = create_evaluator(model, metrics=metrics, tag='train')

# Add validation logging
trainer.add_event_handler(Events.EPOCH_COMPLETED(every=1), evaluate_model)

# Add step length update at the end of each epoch
trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: lr_scheduler.step())

# Add TensorBoard logging
tb_logger = TensorboardLogger(log_dir=os.path.join(working_dir,'tb_logs'))
# Logging iteration loss
tb_logger.attach_output_handler(
    engine=trainer, 
    event_name=Events.ITERATION_COMPLETED, 
    tag='training', 
    output_transform=lambda loss: {"batch loss": loss}
    )
# Logging epoch training metrics
tb_logger.attach_output_handler(
    engine=train_evaluator,
    event_name=Events.EPOCH_COMPLETED,
    tag="training",
    metric_names=["loss", "accuracy", "precision", "recall", "f1", "topKCatAcc"],
    global_step_transform=global_step_from_engine(trainer),
)
# Logging epoch validation metrics
tb_logger.attach_output_handler(
    engine=evaluator,
    event_name=Events.EPOCH_COMPLETED,
    tag="validation",
    metric_names=["loss", "accuracy", "precision", "recall", "f1", "topKCatAcc"],
    global_step_transform=global_step_from_engine(trainer),
)
Example 6
def run_training(
    model,
    optimizer,
    scheduler,
    output_path,
    train_loader,
    val_loader,
    epochs,
    patience,
    epochs_pretrain,
    mixed_precision,
    classes_weights,
):

    # trainer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if classes_weights is not None:
        classes_weights = classes_weights.to(device)
    crit = nn.CrossEntropyLoss(weight=classes_weights)
    metrics = {"accuracy": Accuracy(), "loss": Loss(crit)}
    model.to(device)
    trainer = create_supervised_trainer_with_pretraining(
        model,
        optimizer,
        crit,
        device=device,
        epochs_pretrain=epochs_pretrain,
        mixed_precision=mixed_precision,
    )
    train_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device)
    val_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device)

    # Out paths
    path_ckpt = os.path.join(output_path, "model_ckpt")
    log_dir = os.path.join(output_path, "log_dir")
    os.makedirs(log_dir, exist_ok=True)

    # tensorboard
    tb_logger = TensorboardLogger(log_dir=log_dir)
    tb_logger.attach_output_handler(
        train_evaluator,
        event_name=Events.EPOCH_COMPLETED,
        tag="training",
        metric_names=["accuracy", "loss"],
    )
    tb_logger.attach_output_handler(
        val_evaluator,
        event_name=Events.EPOCH_COMPLETED,
        tag="validation",
        metric_names=["accuracy", "loss"],
        global_step_transform=global_step_from_engine(trainer),
    )

    # training progress
    pbar = ProgressBar(persist=True, position=0)
    pbar.attach(trainer, metric_names="all")

    def log_training_results(engine):
        train_evaluator.run(train_loader)
        val_evaluator.run(val_loader)
        train_loss = train_evaluator.state.metrics["loss"]
        val_loss = val_evaluator.state.metrics["loss"]
        train_acc = train_evaluator.state.metrics["accuracy"]
        val_acc = val_evaluator.state.metrics["accuracy"]
        pbar.log_message(
            "Training Results - Epoch: {}  Loss: {:.6f}  Accuracy: {:.6f}".format(
                engine.state.epoch, train_loss, train_acc
            )
        )
        pbar.log_message(
            "Validation Results - Epoch: {}  Loss: {:.6f}  Accuracy: {:.6f}".format(
                engine.state.epoch, val_loss, val_acc
            )
        )
        pbar.n = pbar.last_print_n = 0

    trainer.add_event_handler(Events.EPOCH_COMPLETED, log_training_results)

    # def get_val_loss(engine):
    # 	return -engine.state.metrics['loss']
    def get_val_acc(engine):
        return engine.state.metrics["accuracy"]

    # checkpoint and early stopping
    checkpointer = ModelCheckpoint(
        path_ckpt,
        "model",
        score_function=get_val_acc,
        score_name="accuracy",
        require_empty=False,
    )
    early_stopper = EarlyStopping(patience, get_val_acc, trainer)

    to_save = {"optimizer": optimizer, "model": model}
    if scheduler is not None:
        to_save["scheduler"] = scheduler
    val_evaluator.add_event_handler(Events.COMPLETED, checkpointer, to_save)
    val_evaluator.add_event_handler(Events.COMPLETED, early_stopper)
    if scheduler is not None:
        trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # free resources
    trainer.add_event_handler(Events.ITERATION_COMPLETED, lambda _: _empty_cache())
    train_evaluator.add_event_handler(
        Events.ITERATION_COMPLETED, lambda _: _empty_cache()
    )
    val_evaluator.add_event_handler(
        Events.ITERATION_COMPLETED, lambda _: _empty_cache()
    )

    trainer.run(train_loader, max_epochs=epochs)
    tb_logger.close()

    # Evaluation with best model
    model.load_state_dict(
        torch.load(glob.glob(os.path.join(path_ckpt, "*.pt*"))[0])["model"]
    )
    train_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device)
    val_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device)

    train_evaluator.run(train_loader)
    val_evaluator.run(val_loader)

    _pretty_print("Evaluating best model")
    pbar.log_message(
        "Best model on training set - Loss: {:.6f}  Accuracy: {:.6f}".format(
            train_evaluator.state.metrics["loss"],
            train_evaluator.state.metrics["accuracy"],
        )
    )
    pbar.log_message(
        "Best model on validation set - Loss: {:.6f}  Accuracy: {:.6f}".format(
            val_evaluator.state.metrics["loss"], val_evaluator.state.metrics["accuracy"]
        )
    )

    return model, train_evaluator.state.metrics, val_evaluator.state.metrics
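
`create_supervised_trainer_with_pretraining`, `_pretty_print` and `_empty_cache` are project-specific helpers that are not shown here. A minimal sketch of what `_empty_cache` presumably wraps (an assumption, not the project's actual implementation):

import torch

def _empty_cache():
    # Release GPU memory held by PyTorch's caching allocator; a no-op without CUDA.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()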
Example 7
def attach_handlers(run, model, optimizer, trainer, train_evaluator, evaluator,
                    train_loader, val_loader, params):
    # Tqdm logger
    pbar = ProgressBar(persist=True, bar_format=config.IGNITE_BAR_FORMAT)
    pbar.attach(trainer.engine, metric_names='all')
    tqdm_logger = TqdmLogger(pbar=pbar)
    # noinspection PyTypeChecker
    tqdm_logger.attach_output_handler(
        evaluator.engine,
        event_name=Events.COMPLETED,
        tag="validation",
        global_step_transform=global_step_from_engine(trainer.engine),
    )
    # noinspection PyTypeChecker
    tqdm_logger.attach_output_handler(
        train_evaluator.engine,
        event_name=Events.COMPLETED,
        tag="train",
        global_step_transform=global_step_from_engine(trainer.engine),
    )

    # Evaluators
    train_evaluator.attach(trainer.engine, Events.EPOCH_COMPLETED,
                           train_loader)
    evaluator.attach(trainer.engine, Events.EPOCH_COMPLETED, data=val_loader)

    # Learning rate scheduling
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                              'max',
                                                              verbose=True,
                                                              patience=5,
                                                              factor=0.5)
    evaluator.engine.add_event_handler(
        Events.COMPLETED,
        lambda engine: lr_scheduler.step(engine.state.metrics['accuracy']))

    # Early stopping
    es_handler = EarlyStopping(
        patience=15,
        score_function=lambda engine: engine.state.metrics['accuracy'],
        trainer=trainer.engine,
        cumulative_delta=True,
        min_delta=0.0001)
    if 'train_all' in params and params['train_all']:
        train_evaluator.engine.add_event_handler(Events.COMPLETED, es_handler)
    else:
        evaluator.engine.add_event_handler(Events.COMPLETED, es_handler)

    es_handler.logger.setLevel(logging.DEBUG)

    # Model checkpoints
    name = run.replace('/', '-')
    mc_handler = ModelCheckpoint(
        config.MODELS_DIR,
        name,
        n_saved=1,
        create_dir=True,
        require_empty=False,
        score_name='acc',
        score_function=lambda engine: engine.state.metrics['accuracy'],
        global_step_transform=global_step_from_engine(trainer.engine))
    evaluator.engine.add_event_handler(Events.EPOCH_COMPLETED, mc_handler,
                                       {'m': model})

    # TensorBoard logger
    tb_logger = TensorboardLogger(
        log_dir=os.path.join(config.TENSORBOARD_DIR, run))
    images, labels = next(iter(train_loader))
    tb_logger.writer.add_graph(copy.deepcopy(model).cpu(), images)
    tb_logger.writer.add_hparams(params, {'hparam/dummy': 0})

    # noinspection PyTypeChecker
    tb_logger.attach_output_handler(
        train_evaluator.engine,
        event_name=Events.COMPLETED,
        tag="train",
        metric_names="all",
        global_step_transform=global_step_from_engine(trainer.engine),
    )
    # noinspection PyTypeChecker
    tb_logger.attach_output_handler(
        evaluator.engine,
        event_name=Events.COMPLETED,
        tag="validation",
        metric_names="all",
        global_step_transform=global_step_from_engine(trainer.engine),
    )
    input_shape = tuple(next(iter(train_loader))[0].shape[1:])
    tb_logger.attach(trainer.engine,
                     log_handler=WeightsImageHandler(model, input_shape),
                     event_name=Events.EPOCH_COMPLETED)
    tb_logger.attach(trainer.engine,
                     log_handler=OptimizerParamsHandler(optimizer),
                     event_name=Events.EPOCH_STARTED)
    # tb_logger.attach(trainer.engine, log_handler=WeightsScalarHandler(model), event_name=Events.EPOCH_COMPLETED)
    # tb_logger.attach(trainer.engine, log_handler=WeightsHistHandler(model), event_name=Events.EPOCH_COMPLETED)
    # tb_logger.attach(trainer.engine,
    #                  log_handler=ActivationsHistHandler(model, layer_names=['linear1', 'batch_norm', 'repu']),
    #                  event_name=Events.ITERATION_COMPLETED)
    # tb_logger.attach(trainer.engine,
    #                  log_handler=NumActivationsScalarHandler(model, layer_names=['linear1', 'repu']),
    #                  event_name=Events.ITERATION_COMPLETED)
    # tb_logger.attach(trainer.engine,
    #                  log_handler=ActivationsScalarHandler(model, reduction=torch.mean,
    #                                                       layer_names=['linear1', 'batch_norm', 'repu']),
    #                  event_name=Events.ITERATION_COMPLETED)
    # tb_logger.attach(trainer.engine,
    #                  log_handler=ActivationsScalarHandler(model, reduction=torch.std,
    #                                                       layer_names=['linear1', 'batch_norm', 'repu']),
    #                  event_name=Events.ITERATION_COMPLETED)

    return es_handler, tb_logger
Example 8
class Ignite_Trainer(Trainer):
    def __init__(self,
                 config=None,
                 cmd_args=None,
                 framework='ignite',
                 model=None,
                 device=None,
                 optimizer=None,
                 scheduler=None,
                 criterion=None,
                 train_loader=None,
                 val_loader=None,
                 data_transforms=None):
        super().__init__(config=config,
                         cmd_args=cmd_args,
                         framework=framework,
                         model=model,
                         device=device,
                         optimizer=optimizer,
                         scheduler=scheduler,
                         criterion=criterion,
                         train_loader=train_loader,
                         val_loader=val_loader,
                         data_transforms=data_transforms)

        self.train_engine = None
        self.evaluator = None
        self.train_evaluator = None
        self.tb_logger = None

    def create_trainer(self):

        # Create the gradient scaler once so its dynamic loss-scaling state
        # persists across iterations (it is only active when AMP is enabled)
        scaler = GradScaler(enabled=self.config.MODEL.WITH_AMP)

        # Define any training logic for iteration update
        def train_step(engine, batch):

            # Get the images and labels for this batch
            x, y = batch[0].to(self.device), batch[1].to(self.device)

            # Set the model into training mode
            self.model.train()

            # Zero parameter gradients
            self.optimizer.zero_grad()

            # Update the model
            if self.config.MODEL.WITH_GRAD_SCALE:
                with autocast(enabled=self.config.MODEL.WITH_AMP):
                    y_pred = self.model(x)
                    loss = self.criterion(y_pred, y)
                scaler.scale(loss).backward()
                scaler.step(self.optimizer)
                scaler.update()
            else:
                with torch.set_grad_enabled(True):
                    y_pred = self.model(x)
                    loss = self.criterion(y_pred, y)
                    loss.backward()
                    # With ReduceLROnPlateau, scheduler.step() needs the validation loss at the end of each
                    # epoch, so it is handled through an evaluator event handler rather than here.
                    if not self.config.TRAIN.SCHEDULER.TYPE == 'ReduceLROnPlateau':
                        self.optimizer.step()

            return loss.item()

        # Define trainer engine
        trainer = Engine(train_step)

        return trainer

    def create_evaluator(self, metrics, tag='val'):

        # Evaluation step function
        @torch.no_grad()
        def evaluate_step(engine: Engine, batch):
            self.model.eval()
            x, y = batch[0].to(self.device), batch[1].to(self.device)
            if self.config.MODEL.WITH_GRAD_SCALE:
                with autocast(enabled=self.config.MODEL.WITH_AMP):
                    y_pred = self.model(x)
            else:
                y_pred = self.model(x)
            return y_pred, y

        # Create the evaluator object
        evaluator = Engine(evaluate_step)

        # Attach the metrics
        for name, metric in metrics.items():
            metric.attach(evaluator, name)

        return evaluator

    def evaluate_model(self):
        epoch = self.train_engine.state.epoch
        # Training Metrics
        train_state = self.train_evaluator.run(self.train_loader)
        tr_accuracy = train_state.metrics['accuracy']
        tr_precision = train_state.metrics['precision']
        tr_recall = train_state.metrics['recall']
        tr_f1 = train_state.metrics['f1']
        tr_topKCatAcc = train_state.metrics['topKCatAcc']
        tr_loss = train_state.metrics['loss']
        # Validation Metrics
        val_state = self.evaluator.run(self.val_loader)
        val_accuracy = val_state.metrics['accuracy']
        val_precision = val_state.metrics['precision']
        val_recall = val_state.metrics['recall']
        val_f1 = val_state.metrics['f1']
        val_topKCatAcc = val_state.metrics['topKCatAcc']
        val_loss = val_state.metrics['loss']
        print(
            "Epoch: {:0>4}  TrAcc: {:.3f} ValAcc: {:.3f} TrPrec: {:.3f} ValPrec: {:.3f} TrRec: {:.3f} ValRec: {:.3f} TrF1: {:.3f} ValF1: {:.3f} TrTopK: {:.3f} ValTopK: {:.3f} TrLoss: {:.3f} ValLoss: {:.3f}"
            .format(epoch, tr_accuracy, val_accuracy, tr_precision,
                    val_precision, tr_recall, val_recall, tr_f1, val_f1,
                    tr_topKCatAcc, val_topKCatAcc, tr_loss, val_loss))

    def add_logging(self):

        # Add validation logging
        self.train_engine.add_event_handler(Events.EPOCH_COMPLETED(every=1),
                                            self.evaluate_model)

        # Add step length update at the end of each epoch
        self.train_engine.add_event_handler(Events.EPOCH_COMPLETED,
                                            lambda _: self.scheduler.step())

    def add_tensorboard_logging(self, logging_dir=None):

        # Add TensorBoard logging
        if logging_dir is None:
            logging_dir = os.path.join(self.config.DIRS.WORKING_DIR, 'tb_logs')
        else:
            logging_dir = os.path.join(logging_dir, 'tb_logs')
        print('Tensorboard logging saving to:: {} ...'.format(logging_dir),
              end='')

        self.tb_logger = TensorboardLogger(log_dir=logging_dir)
        # Logging iteration loss
        self.tb_logger.attach_output_handler(
            engine=self.train_engine,
            event_name=Events.ITERATION_COMPLETED,
            tag='training',
            output_transform=lambda loss: {"batch loss": loss})
        # Logging epoch training metrics
        self.tb_logger.attach_output_handler(
            engine=self.train_evaluator,
            event_name=Events.EPOCH_COMPLETED,
            tag="training",
            metric_names=[
                "loss", "accuracy", "precision", "recall", "f1", "topKCatAcc"
            ],
            global_step_transform=global_step_from_engine(self.train_engine),
        )
        # Logging epoch validation metrics
        self.tb_logger.attach_output_handler(
            engine=self.evaluator,
            event_name=Events.EPOCH_COMPLETED,
            tag="validation",
            metric_names=[
                "loss", "accuracy", "precision", "recall", "f1", "topKCatAcc"
            ],
            global_step_transform=global_step_from_engine(self.train_engine),
        )
        # Attach the logger to the trainer to log model's weights as a histogram after each epoch
        self.tb_logger.attach(self.train_engine,
                              event_name=Events.EPOCH_COMPLETED,
                              log_handler=WeightsHistHandler(self.model))
        # Attach the logger to the trainer to log model's gradients as a histogram after each epoch
        self.tb_logger.attach(self.train_engine,
                              event_name=Events.EPOCH_COMPLETED,
                              log_handler=GradsHistHandler(self.model))
        print('Tensorboard Logging...', end='')
        print('done')

    def create_callbacks(self, best_model_only=True):

        ## SETUP CALLBACKS
        print('[INFO] Creating callback functions for training loop...',
              end='')

        # If using ReduceLROnPlateau then need to add event to handle the step() call with loss:
        if self.config.TRAIN.SCHEDULER.TYPE == 'ReduceLROnPlateau':
            self.evaluator.add_event_handler(Events.COMPLETED, self.scheduler)
        else:
            print('No checkpointing required for LR Scheduler....', end='')

        # Early Stopping - stops training if the validation loss does not decrease after 5 epochs
        handler = EarlyStopping(patience=self.config.EARLY_STOPPING_PATIENCE,
                                score_function=score_function_loss,
                                trainer=self.train_engine)
        self.evaluator.add_event_handler(Events.COMPLETED, handler)
        print('Early Stopping ({} epochs)...'.format(
            self.config.EARLY_STOPPING_PATIENCE),
              end='')

        # Model checkpointing
        self._create_ignite_model_checkpointer(best_model_only=best_model_only)

    def run(self, logging_dir=None, best_model_only=True):

        #assert self.model is not None, '[ERROR] No model object loaded. Please load a PyTorch model torch.nn object into the class object.'
        #assert (self.train_loader is not None) or (self.val_loader is not None), '[ERROR] You must specify data loaders.'

        for key in self.trainer_status.keys():
            assert self.trainer_status[
                key], '[ERROR] The {} has not been generated and you cannot proceed.'.format(
                    key)
        print('[INFO] Trainer pass OK for training.')

        # TRAIN ENGINE
        # Create the objects for training
        self.train_engine = self.create_trainer()

        # METRICS AND EVALUATION
        # Metrics - running average
        RunningAverage(output_transform=lambda x: x).attach(
            self.train_engine, 'loss')

        # Metrics - epochs
        metrics = {
            'accuracy': Accuracy(),
            'recall': Recall(average=True),
            'precision': Precision(average=True),
            'f1': Fbeta(beta=1),
            'topKCatAcc': TopKCategoricalAccuracy(k=5),
            'loss': Loss(self.criterion)
        }

        # Create evaluators
        self.evaluator = self.create_evaluator(metrics=metrics)
        self.train_evaluator = self.create_evaluator(metrics=metrics,
                                                     tag='train')

        # LOGGING
        # Create logging to terminal
        self.add_logging()

        # Create Tensorboard logging
        self.add_tensorboard_logging(logging_dir=logging_dir)

        ## CALLBACKS
        self.create_callbacks(best_model_only=best_model_only)

        ## TRAIN
        # Train the model
        print('[INFO] Executing model training...')
        self.train_engine.run(self.train_loader,
                              max_epochs=self.config.TRAIN.NUM_EPOCHS)
        print('[INFO] Model training is complete.')

    def update_model_from_checkpoint(self,
                                     checkpoint_file=None,
                                     load_to_device=True):
        '''
        Function to take a saved checkpoint of the models weights, and load it into the model.
        '''
        assert self.trainer_status[
            'model'], '[ERROR] You must create the model to load the weights. Use Trainer.create_model() method to first create your model, then load weights.'
        assert checkpoint_file is not None, '[ERROR] You must provide the full path and name of the .pt file containing the saved weights of the model you want to update.'

        try:
            # Load the weights of the checkpointed model from the PT file
            self.model.load_state_dict(torch.load(f=checkpoint_file))
        except Exception:
            raise Exception(
                '[ERROR] Something went wrong with loading the weights into the model.'
            )
        else:
            print(
                '[INFO] Successfully loaded weights into the model from weights file:: {}'
                .format(checkpoint_file))

        if load_to_device:
            self.model.to(self.device)
            print(
                '[INFO] Successfully updated model and pushed it to the device {}'
                .format(self.device))
            # Print summary of model
            summary(
                self.model,
                batch_size=self.config.TRAIN.BATCH_SIZE,
                input_size=(
                    3,
                    self.config.DATA.TRANSFORMS.PARAMS.DEFAULT.img_crop_size,
                    self.config.DATA.TRANSFORMS.PARAMS.DEFAULT.img_crop_size))
        else:
            print(
                '[INFO] Successfully updated model but NOT pushed it to the device {}'
                .format(self.device))
            # Print summary of model
            summary(
                self.model,
                device='cpu',
                batch_size=self.config.TRAIN.BATCH_SIZE,
                input_size=(
                    3,
                    self.config.DATA.TRANSFORMS.PARAMS.DEFAULT.img_crop_size,
                    self.config.DATA.TRANSFORMS.PARAMS.DEFAULT.img_crop_size))

    def convert_to_torchscript(self,
                               checkpoint_file=None,
                               torchscript_model_path=None,
                               method='trace',
                               return_jit_model=False):

        assert self.trainer_status[
            'model'], '[ERROR] You must create the model to load the weights. Use Trainer.create_model() method to first create your model, then load weights.'
        assert checkpoint_file is not None, '[ERROR] You must provide the path and name of a PyTorch Ignite checkpoint file of model weights [checkpoint_file].'

        # Update the Trainer class attribute model with model weights file
        self.update_model_from_checkpoint(checkpoint_file=checkpoint_file)

        if torchscript_model_path is None:
            torchscript_model_path = os.path.join(os.getcwd(),
                                                  'torchscript_model.pt')

        if method == 'trace':
            assert self.trainer_status[
                'val_loader'], '[ERROR] You must create the validation loader in order to load images. Use Trainer.create_dataloaders() method to create access to image batches.'

            # Create an image batch
            X, _ = next(iter(self.val_loader))
            # Push the input images to the device
            X = X.to(self.device)
            # Trace the model
            jit_model = torch.jit.trace(self.model, (X))
            # Write the trace module of the model to disk
            print(
                '[INFO] Torchscript file being saved to temporary location:: {}'
                .format(torchscript_model_path))
            jit_model.save(torchscript_model_path)

        elif method == 'script':
            # Trace the model
            jit_model = torch.jit.script(self.model)
            # Write the trace module of the model to disk
            print(
                '[INFO] Torchscript file being saved to temporary location:: {}'
                .format(torchscript_model_path))
            jit_model.save(torchscript_model_path)

        if return_jit_model:
            return jit_model

    def _create_ignite_model_checkpointer(self, best_model_only=True):
        '''
        Create an Ignite model checkpointer, either based on validation accuracy (best_model_only == True) or at every epoch (best_model_only == False).
        '''

        print('Model Checkpointing...', end='')
        if best_model_only:
            print('best model checkpointing...', end='')
            # best model checkpointer, based on validation accuracy.
            self.model_checkpointer = ModelCheckpoint(
                dirname=self.config.DIRS.WORKING_DIR,
                filename_prefix='caltech_birds_ignite_best',
                score_function=score_function_acc,
                score_name='val_acc',
                n_saved=2,
                create_dir=True,
                save_as_state_dict=True,
                require_empty=False,
                global_step_transform=global_step_from_engine(
                    self.train_engine))
            self.evaluator.add_event_handler(
                Events.COMPLETED, self.model_checkpointer,
                {self.config.MODEL.MODEL_NAME: self.model})
        else:
            # Checkpoint the model
            # iteration checkpointer
            print('every iteration model checkpointing...', end='')
            self.model_checkpointer = ModelCheckpoint(
                dirname=self.config.DIRS.WORKING_DIR,
                filename_prefix='caltech_birds_ignite',
                n_saved=2,
                create_dir=True,
                save_as_state_dict=True,
                require_empty=False)
            self.train_engine.add_event_handler(
                Events.EPOCH_COMPLETED, self.model_checkpointer,
                {self.config.MODEL.MODEL_NAME: self.model})

        print('Done')
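
`score_function_loss` and `score_function_acc` are defined elsewhere in the package. Both EarlyStopping and ModelCheckpoint expect a score that is larger for better models, so plausible sketches (assumptions about the originals, mirroring `get_val_acc` and the commented-out `get_val_loss` in Example 6) would be:

def score_function_loss(engine):
    # EarlyStopping maximizes the score, so negate the validation loss.
    return -engine.state.metrics['loss']

def score_function_acc(engine):
    # Checkpoint on the highest validation accuracy.
    return engine.state.metrics['accuracy']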
Example 9
def training(local_rank, config):
    rank = idist.get_rank()
    manual_seed(config["seed"] + rank)
    device = idist.device()

    # Define output folder:
    config.output = "/tmp/output"

    model = idist.auto_model(config.model)
    optimizer = idist.auto_optim(config.optimizer)
    criterion = config.criterion

    train_set, val_set = config.train_set, config.val_set
    train_loader = idist.auto_dataloader(train_set,
                                         batch_size=config.train_batch_size)
    val_loader = idist.auto_dataloader(val_set,
                                       batch_size=config.val_batch_size)

    trainer = create_supervised_trainer(model,
                                        optimizer,
                                        criterion,
                                        device=device)
    trainer.logger = setup_logger("Trainer")

    metrics = {"accuracy": Accuracy(), "loss": Loss(criterion)}

    train_evaluator = create_supervised_evaluator(model,
                                                  metrics=metrics,
                                                  device=device)
    train_evaluator.logger = setup_logger("Train Evaluator")
    validation_evaluator = create_supervised_evaluator(model,
                                                       metrics=metrics,
                                                       device=device)
    validation_evaluator.logger = setup_logger("Val Evaluator")

    @trainer.on(Events.EPOCH_COMPLETED(every=config.val_interval))
    def compute_metrics(engine):
        train_evaluator.run(train_loader)
        validation_evaluator.run(val_loader)

    if rank == 0:
        tb_logger = TensorboardLogger(log_dir=config.output)

        tb_logger.attach_output_handler(
            trainer,
            event_name=Events.ITERATION_COMPLETED(every=100),
            tag="training",
            output_transform=lambda loss: {"batchloss": loss},
            metric_names="all",
        )

        for tag, evaluator in [("training", train_evaluator),
                               ("validation", validation_evaluator)]:
            tb_logger.attach_output_handler(
                evaluator,
                event_name=Events.EPOCH_COMPLETED,
                tag=tag,
                metric_names=["loss", "accuracy"],
                global_step_transform=global_step_from_engine(trainer),
            )

        tb_logger.attach_opt_params_handler(
            trainer,
            event_name=Events.ITERATION_COMPLETED(every=100),
            optimizer=optimizer)

    model_checkpoint = ModelCheckpoint(
        config.output,
        n_saved=2,
        filename_prefix="best",
        score_name="accuracy",
        global_step_transform=global_step_from_engine(trainer),
    )
    validation_evaluator.add_event_handler(Events.COMPLETED, model_checkpoint,
                                           {"model": model})

    trainer.run(train_loader, max_epochs=config.num_epochs)

    if rank == 0:
        tb_logger.close()
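
The `training` function above follows the `ignite.distributed` (`idist`) helper pattern, so it is normally spawned through `idist.Parallel`. A minimal launch sketch, assuming a `config` object with the fields used above has already been built elsewhere (backend and process count are placeholders):

import ignite.distributed as idist

def launch(config, backend="nccl", nproc_per_node=2):
    # idist.Parallel spawns one process per device for "nccl"/"gloo" backends,
    # or runs training in the current process when backend is None.
    with idist.Parallel(backend=backend, nproc_per_node=nproc_per_node) as parallel:
        parallel.run(training, config)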