Example #1
def main(dataset, dataroot,
         z_dim, g_filters, d_filters,
         batch_size, epochs,
         learning_rate, beta_1,
         saved_G, saved_D,
         seed,
         n_workers, device,
         alpha, output_dir):

    # seed
    check_manual_seed(seed)

    # SummaryWriter
    writer = SummaryWriter(output_dir + "board")
    
    # logger
    logger = set_logger("GAN_model", output_dir, 0)
    logger.info("start training")

    # data
    dataset, num_channels = check_dataset(dataset, dataroot)
    loader = data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=n_workers, drop_last=True)

    # networks
    netG = Generator(z_dim, g_filters, num_channels).to(device)
    netD = Discriminator(num_channels, d_filters).to(device)

    # criterion
    bce = nn.BCELoss()

    # optimizers
    optimizerG = optim.Adam(netG.parameters(), lr=learning_rate, betas=(beta_1, 0.999))
    optimizerD = optim.Adam(netD.parameters(), lr=learning_rate, betas=(beta_1, 0.999))

    # load pre-trained models
    if saved_G:
        netG.load_state_dict(torch.load(saved_G))

    if saved_D:
        netD.load_state_dict(torch.load(saved_D))

    # misc
    real_labels = torch.ones(batch_size, device=device)
    fake_labels = torch.zeros(batch_size, device=device)
    fixed_noise = torch.randn(batch_size, z_dim, 1, 1, device=device)

    def get_noise():
        return torch.randn(batch_size, z_dim, 1, 1, device=device)

    # The main function, processing a batch of examples
    def step(engine, batch):

        # unpack the batch. It comes from a dataset, so we have <images, labels> pairs. Discard labels.
        real, _ = batch
        real = real.to(device)

        # -----------------------------------------------------------
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        netD.zero_grad()

        # train with real
        output = netD(real)
        errD_real = bce(output, real_labels)
        D_x = output.mean().item()

        errD_real.backward()

        # get fake image from generator
        noise = get_noise()
        fake = netG(noise)

        # train with fake
        output = netD(fake.detach())
        errD_fake = bce(output, fake_labels)
        D_G_z1 = output.mean().item()

        errD_fake.backward()

        # combined loss (kept for logging only; gradients were already
        # accumulated by the two backward() calls above) and optimizer update
        errD = errD_real + errD_fake
        optimizerD.step()

        # -----------------------------------------------------------
        # (2) Update G network: maximize log(D(G(z)))
        netG.zero_grad()

        # Update generator. We want to make a step that will make it more likely that discriminator outputs "real"
        output = netD(fake)
        errG = bce(output, real_labels)
        D_G_z2 = output.mean().item()

        errG.backward()

        # gradient update
        optimizerG.step()

        return {
            'errD': errD.item(),
            'errG': errG.item(),
            'D_x': D_x,
            'D_G_z1': D_G_z1,
            'D_G_z2': D_G_z2
        }

    # ignite objects
    trainer = Engine(step)
    checkpoint_handler = ModelCheckpoint(output_dir, CKPT_PREFIX, save_interval=1, n_saved=10, require_empty=False)
    timer = Timer(average=True)

    # attach running average metrics
    monitoring_metrics = ['errD', 'errG', 'D_x', 'D_G_z1', 'D_G_z2']
    RunningAverage(alpha=alpha, output_transform=lambda x: x['errD']).attach(trainer, 'errD')
    RunningAverage(alpha=alpha, output_transform=lambda x: x['errG']).attach(trainer, 'errG')
    RunningAverage(alpha=alpha, output_transform=lambda x: x['D_x']).attach(trainer, 'D_x')
    RunningAverage(alpha=alpha, output_transform=lambda x: x['D_G_z1']).attach(trainer, 'D_G_z1')
    RunningAverage(alpha=alpha, output_transform=lambda x: x['D_G_z2']).attach(trainer, 'D_G_z2')
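    # Note: since `monitoring_metrics` already lists these keys, the five calls
    # above could equivalently be written as a loop; a sketch (the `key=name`
    # default argument binds each lambda to its own key):
    #
    #     for name in monitoring_metrics:
    #         RunningAverage(alpha=alpha,
    #                        output_transform=lambda x, key=name: x[key]).attach(trainer, name)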

    # periodic logging of the running-average metrics (console + TensorBoard)

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_results(engine):
        i = (engine.state.iteration - 1) % len(loader) + 1

        if i % PRINT_FREQ == 0:
            logger.info("Epoch[{}] Iteration[{}/{}] errD:{:.3f} errG:{:.3f} "
                        "D_x:{:.3f} D_G_z1:{:.3f} D_G_z2:{:.3f}"
                        .format(engine.state.epoch, i, len(loader),
                                engine.state.metrics["errD"],
                                engine.state.metrics["errG"],
                                engine.state.metrics["D_x"],
                                engine.state.metrics["D_G_z1"],
                                engine.state.metrics["D_G_z2"]))
            writer.add_scalars("train", {"errD": engine.state.metrics["errD"]}, engine.state.iteration)
            writer.add_scalars("train", {"errG": engine.state.metrics["errG"]}, engine.state.iteration)
            writer.add_scalars("train", {"D_x": engine.state.metrics["D_x"]}, engine.state.iteration)
            writer.add_scalars("train", {"D_G_z1": engine.state.metrics["D_G_z1"]}, engine.state.iteration)
            writer.add_scalars("train", {"D_G_z2": engine.state.metrics["D_G_z2"]}, engine.state.iteration)
        

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def save_fake_example(engine):
        fake = netG(fixed_noise)
        path = os.path.join(output_dir, FAKE_IMG_FNAME.format(engine.state.epoch))
        vutils.save_image(fake.detach(), path, normalize=True)

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def save_real_example(engine):
        img, y = engine.state.batch
        path = os.path.join(output_dir, REAL_IMG_FNAME.format(engine.state.epoch))
        vutils.save_image(img, path, normalize=True)

    # adding handlers using `trainer.add_event_handler` method API
    trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler,
                              to_save={
                                  'netG': netG,
                                  'netD': netD
                              })

    # automatically adding handlers via a special `attach` method of `Timer` handler
    '''
        engine (Engine):
            Engine that this timer will be attached to.
        start (Events):
            Event which should start (reset) the timer.
        pause (Events):
            Event which should pause the timer.
        resume (Events, optional):
            Event which should resume the timer.
        step (Events, optional):
            Event which should call the `step` method of the counter.
    '''
    timer.attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED)
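    # With this wiring the timer resets at each EPOCH_STARTED, accumulates time
    # between ITERATION_STARTED and ITERATION_COMPLETED, and increments its step
    # counter every iteration, so timer.value() is the average seconds per batch.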
  

    # adding handlers using `trainer.on` decorator API
    # timer.step_count: number of timed iterations in the epoch
    # timer.value(): average running time per iteration, in seconds
    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        logger.info("Epoch {} done. Time per batch:{:.3f}[s] Speed:{:.1f}[samples/s]"
                    .format(engine.state.epoch, timer.value() * timer.step_count,
                    loader.batch_size / timer.value()))
        logger.info("timer.step_count:{:.3f}, timer.value:{:.3f}".format(timer.step_count, timer.value()))
        timer.reset()


    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EXCEPTION_RAISED)
    def handle_exception(engine, e):
        if isinstance(e, KeyboardInterrupt) and (engine.state.iteration > 1):
            engine.terminate()
            warnings.warn('KeyboardInterrupt caught. Exiting gracefully.')

            create_plots(engine)
            checkpoint_handler(engine, {
                'netG_exception': netG,
                'netD_exception': netD
            })

        else:
            raise e

    # Setup is done. Now let's run the training
    trainer.run(loader, epochs)
    writer.close()
Example #2
File: trainer.py Project: fromptj/DL
def attach_running_average(engine, metric_name):
    RunningAverage(output_transform=lambda x: x[metric_name]).attach(
        engine,
        metric_name,
    )
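A minimal usage sketch for the helper above (the Engine and its dict-valued step output are illustrative assumptions, not taken from the fromptj/DL project):
from ignite.engine import Engine

def step(engine, batch):
    # dummy process function whose output is a dict of scalar values
    return {"loss": 0.0, "acc": 1.0}

trainer = Engine(step)
for name in ("loss", "acc"):
    attach_running_average(trainer, name)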
Example #3
def do_train_with_center(cfg, model, center_criterion, train_loader,
                         val_loader, optimizer, optimizer_center, scheduler,
                         loss_fn, num_query, start_epoch):
    log_period = cfg.SOLVER.LOG_PERIOD
    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
    eval_period = cfg.SOLVER.EVAL_PERIOD
    output_dir = cfg.OUTPUT_DIR
    device = cfg.MODEL.DEVICE
    epochs = cfg.SOLVER.MAX_EPOCHS

    logger = logging.getLogger("reid_baseline.train")
    logger.info("Start training")
    trainer = create_supervised_trainer_with_center(
        model,
        center_criterion,
        optimizer,
        optimizer_center,
        loss_fn,
        cfg.SOLVER.CENTER_LOSS_WEIGHT,
        device=device)
    evaluator = create_supervised_evaluator(
        model,
        metrics={
            'r1_mAP': R1_mAP(num_query,
                             max_rank=50,
                             feat_norm=cfg.TEST.FEAT_NORM)
        },
        device=device)
    checkpointer = ModelCheckpoint(output_dir,
                                   cfg.MODEL.NAME,
                                   checkpoint_period,
                                   n_saved=10,
                                   require_empty=False)
    timer = Timer(average=True)

    trainer.add_event_handler(
        Events.EPOCH_COMPLETED, checkpointer, {
            'model': model,
            'optimizer': optimizer,
            'optimizer_center': optimizer_center
        })

    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    # average metric to attach on trainer
    RunningAverage(output_transform=lambda x: x[0]).attach(trainer, 'avg_loss')
    RunningAverage(output_transform=lambda x: x[1]).attach(trainer, 'avg_acc')

    @trainer.on(Events.STARTED)
    def start_training(engine):
        engine.state.epoch = start_epoch

    @trainer.on(Events.EPOCH_STARTED)
    def adjust_learning_rate(engine):
        scheduler.step()

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        global ITER
        ITER += 1

        if ITER % log_period == 0:
            logger.info(
                "Epoch[{}] Iteration[{}/{}] Loss: {:.3f}, Acc: {:.3f}, Base Lr: {:.2e}"
                .format(engine.state.epoch, ITER, len(train_loader),
                        engine.state.metrics['avg_loss'],
                        engine.state.metrics['avg_acc'],
                        scheduler.get_lr()[0]))
        if len(train_loader) == ITER:
            ITER = 0

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        logger.info(
            'Epoch {} done. Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]'
            .format(engine.state.epoch,
                    timer.value() * timer.step_count,
                    train_loader.batch_size / timer.value()))
        logger.info('-' * 10)
        timer.reset()

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        if engine.state.epoch % eval_period == 0:
            evaluator.run(val_loader)
            cmc, mAP = evaluator.state.metrics['r1_mAP']
            logger.info("Validation Results - Epoch: {}".format(
                engine.state.epoch))
            logger.info("mAP: {:.1%}".format(mAP))
            for r in [1, 5, 10]:
                logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(
                    r, cmc[r - 1]))

    trainer.run(train_loader, max_epochs=epochs)
Example #4
    # process_function used during evaluation
    def evaluate_process(engine, batch):
        model.float().to(device).eval()
        with torch.no_grad():
            _, font = batch

            font = font.float().to(device)

            font_hat, latent_vectors = model(font)

            return font, font_hat, latent_vectors

    trainer = Engine(train_process)
    evaluator = Engine(evaluate_process)

    RunningAverage(output_transform=lambda x: x).attach(trainer, 'mse')

    Loss(F.mse_loss,
         output_transform=lambda x: [x[1], x[0]]).attach(evaluator, 'mse')

    desc = "ITERATION - loss: {:.5f}"
    pbar = tqdm(initial=0,
                leave=False,
                total=len(train_loader),
                desc=desc.format(0))

    train_history = []
    # valid_history = []

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        pbar.desc = desc.format(engine.state.output)
        pbar.update(1)

Example #5
    def __call__(self, model, train_dataset, val_dataset=None, **_):
        """Train a PyTorch model.

        Args:
            model (torch.nn.Module): PyTorch model to train.
            train_dataset (torch.utils.data.Dataset): Dataset used to train.
            val_dataset (torch.utils.data.Dataset, optional): Dataset used to validate.

        Returns:
            trained_model (torch.nn.Module): Trained PyTorch model.
        """
        assert train_dataset is not None
        train_params = self.train_params
        mlflow_logging = self.mlflow_logging

        if mlflow_logging:
            try:
                import mlflow  # NOQA
            except ImportError:
                log.warning(
                    "Failed to import mlflow. MLflow logging is disabled.")
                mlflow_logging = False

        loss_fn = train_params.get("loss_fn")
        assert loss_fn
        epochs = train_params.get("epochs")
        seed = train_params.get("seed")
        optimizer = train_params.get("optimizer")
        assert optimizer
        optimizer_params = train_params.get("optimizer_params", dict())
        train_dataset_size_limit = train_params.get("train_dataset_size_limit")
        if train_dataset_size_limit:
            train_dataset = PartialDataset(train_dataset,
                                           train_dataset_size_limit)
            log.info("train dataset size is set to {}".format(
                len(train_dataset)))

        val_dataset_size_limit = train_params.get("val_dataset_size_limit")
        if val_dataset_size_limit and (val_dataset is not None):
            val_dataset = PartialDataset(val_dataset, val_dataset_size_limit)
            log.info("val dataset size is set to {}".format(len(val_dataset)))

        train_data_loader_params = train_params.get("train_data_loader_params",
                                                    dict())
        val_data_loader_params = train_params.get("val_data_loader_params",
                                                  dict())
        evaluation_metrics = train_params.get("evaluation_metrics")
        evaluate_train_data = train_params.get("evaluate_train_data")
        evaluate_val_data = train_params.get("evaluate_val_data")
        progress_update = train_params.get("progress_update")

        scheduler = train_params.get("scheduler")
        scheduler_params = train_params.get("scheduler_params", dict())

        model_checkpoint = train_params.get("model_checkpoint")
        model_checkpoint_params = train_params.get("model_checkpoint_params")
        early_stopping_params = train_params.get("early_stopping_params")
        time_limit = train_params.get("time_limit")

        cudnn_deterministic = train_params.get("cudnn_deterministic")
        cudnn_benchmark = train_params.get("cudnn_benchmark")

        if seed:
            torch.manual_seed(seed)
            np.random.seed(seed)
        if cudnn_deterministic:
            torch.backends.cudnn.deterministic = cudnn_deterministic
        if cudnn_benchmark:
            torch.backends.cudnn.benchmark = cudnn_benchmark

        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device)
        optimizer_ = optimizer(model.parameters(), **optimizer_params)
        trainer = create_supervised_trainer(model,
                                            optimizer_,
                                            loss_fn=loss_fn,
                                            device=device)

        train_data_loader_params.setdefault("shuffle", True)
        train_data_loader_params.setdefault("drop_last", True)
        train_data_loader_params["batch_size"] = _clip_batch_size(
            train_data_loader_params.get("batch_size", 1), train_dataset,
            "train")
        train_loader = DataLoader(train_dataset, **train_data_loader_params)

        RunningAverage(output_transform=lambda x: x,
                       alpha=0.98).attach(trainer, "ema_loss")

        RunningAverage(output_transform=lambda x: x,
                       alpha=2**(-1022)).attach(trainer, "batch_loss")
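        # With alpha this close to zero, the exponential moving average collapses
        # to the most recent value, so "batch_loss" reports the raw per-iteration
        # loss while "ema_loss" above stays smoothed.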

        if scheduler:

            class ParamSchedulerSavingAsMetric(
                    ParamSchedulerSavingAsMetricMixIn, scheduler):
                pass

            cycle_epochs = scheduler_params.pop("cycle_epochs", 1)
            scheduler_params.setdefault("cycle_size",
                                        int(cycle_epochs * len(train_loader)))
            scheduler_params.setdefault("param_name", "lr")
            scheduler_ = ParamSchedulerSavingAsMetric(optimizer_,
                                                      **scheduler_params)
            trainer.add_event_handler(Events.ITERATION_STARTED, scheduler_)

        if evaluate_train_data:
            evaluator_train = create_supervised_evaluator(
                model, metrics=evaluation_metrics, device=device)

        if evaluate_val_data:
            val_data_loader_params["batch_size"] = _clip_batch_size(
                val_data_loader_params.get("batch_size", 1), val_dataset,
                "val")
            val_loader = DataLoader(val_dataset, **val_data_loader_params)
            evaluator_val = create_supervised_evaluator(
                model, metrics=evaluation_metrics, device=device)

        if model_checkpoint_params:
            assert isinstance(model_checkpoint_params, dict)
            minimize = model_checkpoint_params.pop("minimize", True)
            save_interval = model_checkpoint_params.get("save_interval", None)
            if not save_interval:
                model_checkpoint_params.setdefault(
                    "score_function",
                    get_score_function("ema_loss", minimize=minimize))
            model_checkpoint_params.setdefault("score_name", "ema_loss")
            mc = model_checkpoint(**model_checkpoint_params)
            trainer.add_event_handler(Events.EPOCH_COMPLETED, mc,
                                      {"model": model})

        if early_stopping_params:
            assert isinstance(early_stopping_params, dict)
            metric = early_stopping_params.pop("metric", None)
            assert (metric is None) or (metric in evaluation_metrics)
            minimize = early_stopping_params.pop("minimize", False)
            if metric:
                assert (
                    "score_function" not in early_stopping_params
                ), "Remove either 'metric' or 'score_function' from early_stopping_params: {}".format(
                    early_stopping_params)
                early_stopping_params["score_function"] = get_score_function(
                    metric, minimize=minimize)

            es = EarlyStopping(trainer=trainer, **early_stopping_params)
            if evaluate_val_data:
                evaluator_val.add_event_handler(Events.COMPLETED, es)
            elif evaluate_train_data:
                evaluator_train.add_event_handler(Events.COMPLETED, es)
            elif early_stopping_params:
                log.warning(
                    "Early Stopping is disabled because neither "
                    "evaluate_val_data nor evaluate_train_data is set True.")

        if time_limit:
            assert isinstance(time_limit, (int, float))
            tl = TimeLimit(limit_sec=time_limit)
            trainer.add_event_handler(Events.ITERATION_COMPLETED, tl)

        pbar = None
        if progress_update:
            if not isinstance(progress_update, dict):
                progress_update = dict()
            progress_update.setdefault("persist", True)
            progress_update.setdefault("desc", "")
            pbar = ProgressBar(**progress_update)
            pbar.attach(trainer, ["ema_loss"])

        else:

            def log_train_metrics(engine):
                log.info("[Epoch: {} | {}]".format(engine.state.epoch,
                                                   engine.state.metrics))

            trainer.add_event_handler(Events.EPOCH_COMPLETED,
                                      log_train_metrics)

        if evaluate_train_data:

            def log_evaluation_train_data(engine):
                evaluator_train.run(train_loader)
                train_report = _get_report_str(engine, evaluator_train,
                                               "Train Data")
                if pbar:
                    pbar.log_message(train_report)
                else:
                    log.info(train_report)

            eval_train_event = (Events[evaluate_train_data] if isinstance(
                evaluate_train_data, str) else Events.EPOCH_COMPLETED)
            trainer.add_event_handler(eval_train_event,
                                      log_evaluation_train_data)

        if evaluate_val_data:

            def log_evaluation_val_data(engine):
                evaluator_val.run(val_loader)
                val_report = _get_report_str(engine, evaluator_val, "Val Data")
                if pbar:
                    pbar.log_message(val_report)
                else:
                    log.info(val_report)

            eval_val_event = (Events[evaluate_val_data] if isinstance(
                evaluate_val_data, str) else Events.EPOCH_COMPLETED)
            trainer.add_event_handler(eval_val_event, log_evaluation_val_data)

        if mlflow_logging:
            mlflow_logger = MLflowLogger()

            logging_params = {
                "train_n_samples": len(train_dataset),
                "train_n_batches": len(train_loader),
                "optimizer": _name(optimizer),
                "loss_fn": _name(loss_fn),
                "pytorch_version": torch.__version__,
                "ignite_version": ignite.__version__,
            }
            logging_params.update(_loggable_dict(optimizer_params,
                                                 "optimizer"))
            logging_params.update(
                _loggable_dict(train_data_loader_params, "train"))
            if scheduler:
                logging_params.update({"scheduler": _name(scheduler)})
                logging_params.update(
                    _loggable_dict(scheduler_params, "scheduler"))

            if evaluate_val_data:
                logging_params.update({
                    "val_n_samples": len(val_dataset),
                    "val_n_batches": len(val_loader),
                })
                logging_params.update(
                    _loggable_dict(val_data_loader_params, "val"))

            mlflow_logger.log_params(logging_params)

            batch_metric_names = ["batch_loss", "ema_loss"]
            if scheduler:
                batch_metric_names.append(scheduler_params.get("param_name"))

            mlflow_logger.attach(
                trainer,
                log_handler=OutputHandler(
                    tag="step",
                    metric_names=batch_metric_names,
                    global_step_transform=global_step_from_engine(trainer),
                ),
                event_name=Events.ITERATION_COMPLETED,
            )

            if evaluate_train_data:
                mlflow_logger.attach(
                    evaluator_train,
                    log_handler=OutputHandler(
                        tag="train",
                        metric_names=list(evaluation_metrics.keys()),
                        global_step_transform=global_step_from_engine(trainer),
                    ),
                    event_name=Events.COMPLETED,
                )
            if evaluate_val_data:
                mlflow_logger.attach(
                    evaluator_val,
                    log_handler=OutputHandler(
                        tag="val",
                        metric_names=list(evaluation_metrics.keys()),
                        global_step_transform=global_step_from_engine(trainer),
                    ),
                    event_name=Events.COMPLETED,
                )

        trainer.run(train_loader, max_epochs=epochs)

        try:
            if pbar and pbar.pbar:
                pbar.pbar.close()
        except Exception as e:
            log.error(e, exc_info=True)

        model = load_latest_model(model_checkpoint_params)(model)

        return model
Example #6
def main(dataset, dataroot, download, augment, batch_size, eval_batch_size,
         epochs, saved_model, seed, hidden_channels, K, L, actnorm_scale,
         flow_permutation, flow_coupling, LU_decomposed, learn_top,
         y_condition, y_weight, max_grad_clip, max_grad_norm, lr, n_workers,
         cuda, n_init_batches, warmup_steps, output_dir, saved_optimizer,
         warmup, fresh, gpuid):

    device = 'cpu' if (not torch.cuda.is_available()
                       or not cuda) else 'cuda:' + str(gpuid)

    check_manual_seed(seed)

    ds = check_dataset(dataset, dataroot, augment, download)
    image_shape, num_classes, train_dataset, test_dataset = ds

    # Note: unsupported for now
    multi_class = False

    train_loader = data.DataLoader(train_dataset,
                                   batch_size=batch_size,
                                   shuffle=True,
                                   num_workers=n_workers,
                                   drop_last=True)
    test_loader = data.DataLoader(test_dataset,
                                  batch_size=eval_batch_size,
                                  shuffle=False,
                                  num_workers=n_workers,
                                  drop_last=False)

    model = Glow(image_shape, hidden_channels, K, L, actnorm_scale,
                 flow_permutation, flow_coupling, LU_decomposed, num_classes,
                 learn_top, y_condition)

    model = model.to(device)
    optimizer = optim.Adamax(model.parameters(), lr=lr, weight_decay=5e-5)

    lr_lambda = lambda epoch: lr * min(1., epoch / warmup)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                  lr_lambda=lr_lambda)

    def step(engine, batch):
        model.train()
        optimizer.zero_grad()

        x, y = batch
        x = x.to(device)

        if y_condition:
            y = y.to(device)
            z, nll, y_logits = model(x, y)
            losses = compute_loss_y(nll, y_logits, y_weight, y, multi_class)
        else:
            z, nll, y_logits = model(x, None)
            losses = compute_loss(nll)

        losses['total_loss'].backward()

        if max_grad_clip > 0:
            torch.nn.utils.clip_grad_value_(model.parameters(), max_grad_clip)
        if max_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        optimizer.step()

        return losses

    def eval_step(engine, batch):
        model.eval()

        x, y = batch
        x = x.to(device)

        with torch.no_grad():
            if y_condition:
                y = y.to(device)
                z, nll, y_logits = model(x, y)
                losses = compute_loss_y(nll,
                                        y_logits,
                                        y_weight,
                                        y,
                                        multi_class,
                                        reduction='none')
            else:
                z, nll, y_logits = model(x, None)
                losses = compute_loss(nll, reduction='none')

        return losses

    trainer = Engine(step)
    checkpoint_handler = ModelCheckpoint(output_dir,
                                         'glow',
                                         save_interval=1,
                                         n_saved=2,
                                         require_empty=False)

    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler, {
        'model': model,
        'optimizer': optimizer
    })

    monitoring_metrics = ['total_loss']
    RunningAverage(output_transform=lambda x: x['total_loss']).attach(
        trainer, 'total_loss')

    evaluator = Engine(eval_step)

    # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
    Loss(lambda x, y: torch.mean(x),
         output_transform=lambda x:
         (x['total_loss'], torch.empty(x['total_loss'].shape[0]))).attach(
             evaluator, 'total_loss')

    if y_condition:
        monitoring_metrics.extend(['nll'])
        RunningAverage(output_transform=lambda x: x['nll']).attach(
            trainer, 'nll')

        # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
        Loss(lambda x, y: torch.mean(x),
             output_transform=lambda x:
             (x['nll'], torch.empty(x['nll'].shape[0]))).attach(
                 evaluator, 'nll')

    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=monitoring_metrics)

    # load pre-trained model if given
    if saved_model:
        model.load_state_dict(torch.load(saved_model))
        model.set_actnorm_init()

        if saved_optimizer:
            optimizer.load_state_dict(torch.load(saved_optimizer))

        file_name, ext = os.path.splitext(saved_model)
        resume_epoch = int(file_name.split('_')[-1])

        @trainer.on(Events.STARTED)
        def resume_training(engine):
            engine.state.epoch = resume_epoch
            engine.state.iteration = resume_epoch * len(
                engine.state.dataloader)

    @trainer.on(Events.STARTED)
    def init(engine):
        model.train()

        init_batches = []
        init_targets = []

        with torch.no_grad():
            for batch, target in islice(train_loader, None, n_init_batches):
                init_batches.append(batch)
                init_targets.append(target)

            init_batches = torch.cat(init_batches).to(device)

            assert init_batches.shape[0] == n_init_batches * batch_size

            if y_condition:
                init_targets = torch.cat(init_targets).to(device)
            else:
                init_targets = None

            model(init_batches, init_targets)

    @trainer.on(Events.EPOCH_COMPLETED)
    def evaluate(engine):

        evaluator.run(test_loader)

        scheduler.step()
        metrics = evaluator.state.metrics

        losses = ', '.join(
            [f"{key}: {value:.2f}" for key, value in metrics.items()])

        print(f'Validation Results - Epoch: {engine.state.epoch} {losses}')

    timer = Timer(average=True)
    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        pbar.log_message(
            f'Epoch {engine.state.epoch} done. Time per batch: {timer.value():.3f}[s]'
        )
        timer.reset()

    trainer.run(train_loader, epochs)
            "Episode %d: reward=%s, steps=%s, speed=%.3f frames/s, elapsed=%s"
            % (trainer.state.episode, trainer.state.episode_reward,
               trainer.state.episode_steps,
               trainer.state.metrics.get('avg_fps', 0),
               timedelta(seconds=trainer.state.metrics.get('time_passed', 0))))

    @engine.on(ptan_ignite.EpisodeEvents.BOUND_REWARD_REACHED)
    def game_solved(trainer: Engine):
        print("Game solved in %s, after %d episodes and %d iterations!" %
              (timedelta(seconds=trainer.state.metrics['time_passed']),
               trainer.state.episode, trainer.state.iteration))
        trainer.should_terminate = True

    logdir = f"runs/{datetime.now().isoformat(timespec='minutes')}-{params.run_name}-{NAME}"
    tb = tb_logger.TensorboardLogger(log_dir=logdir)
    RunningAverage(output_transform=lambda v: v['loss']).attach(
        engine, "avg_loss")

    episode_handler = tb_logger.OutputHandler(
        tag="episodes", metric_names=['reward', 'steps', 'avg_reward'])
    tb.attach(engine,
              log_handler=episode_handler,
              event_name=ptan_ignite.EpisodeEvents.EPISODE_COMPLETED)

    # write to tensorboard every 100 iterations
    ptan_ignite.PeriodicEvents().attach(engine)
    handler = tb_logger.OutputHandler(tag="train",
                                      metric_names=['avg_loss', 'avg_fps'],
                                      output_transform=lambda a: a)
    tb.attach(engine,
              log_handler=handler,
              event_name=ptan_ignite.PeriodEvents.ITERS_100_COMPLETED)
Example #8
def test_integration():

    n_iters = 100
    batch_size = 10
    n_classes = 10
    y_true_batch_values = iter(
        np.random.randint(0, n_classes, size=(n_iters, batch_size)))
    y_pred_batch_values = iter(np.random.rand(n_iters, batch_size, n_classes))
    loss_values = iter(range(n_iters))

    def update_fn(engine, batch):
        loss_value = next(loss_values)
        y_true_batch = next(y_true_batch_values)
        y_pred_batch = next(y_pred_batch_values)
        return (
            loss_value,
            torch.from_numpy(y_pred_batch),
            torch.from_numpy(y_true_batch),
        )

    trainer = Engine(update_fn)
    alpha = 0.98

    acc_metric = RunningAverage(
        Accuracy(output_transform=lambda x: [x[1], x[2]]), alpha=alpha)
    acc_metric.attach(trainer, "running_avg_accuracy")

    avg_output = RunningAverage(output_transform=lambda x: x[0], alpha=alpha)
    avg_output.attach(trainer, "running_avg_output")
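    # Note the two construction modes exercised here: `acc_metric` wraps another
    # Metric (Accuracy), while `avg_output` averages a raw value selected from
    # the engine output via output_transform.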

    running_avg_acc = [
        None,
    ]

    @trainer.on(Events.ITERATION_COMPLETED)
    def manual_running_avg_acc(engine):
        _, y_pred, y = engine.state.output
        indices = torch.max(y_pred, 1)[1]
        correct = torch.eq(indices, y).view(-1)
        num_correct = torch.sum(correct).item()
        num_examples = correct.shape[0]
        batch_acc = num_correct * 1.0 / num_examples
        if running_avg_acc[0] is None:
            running_avg_acc[0] = batch_acc
        else:
            running_avg_acc[0] = running_avg_acc[0] * alpha + (
                1.0 - alpha) * batch_acc
        engine.state.running_avg_acc = running_avg_acc[0]

    @trainer.on(Events.EPOCH_STARTED)
    def running_avg_output_init(engine):
        engine.state.running_avg_output = None

    @trainer.on(Events.ITERATION_COMPLETED)
    def running_avg_output_update(engine):
        if engine.state.running_avg_output is None:
            engine.state.running_avg_output = engine.state.output[0]
        else:
            engine.state.running_avg_output = (
                engine.state.running_avg_output * alpha +
                (1.0 - alpha) * engine.state.output[0])

    @trainer.on(Events.ITERATION_COMPLETED)
    def assert_equal_running_avg_acc_values(engine):
        assert (engine.state.running_avg_acc == engine.state.
                metrics["running_avg_accuracy"]), "{} vs {}".format(
                    engine.state.running_avg_acc,
                    engine.state.metrics["running_avg_accuracy"])

    @trainer.on(Events.ITERATION_COMPLETED)
    def assert_equal_running_avg_output_values(engine):
        assert (engine.state.running_avg_output ==
                engine.state.metrics["running_avg_output"]), "{} vs {}".format(
                    engine.state.running_avg_output,
                    engine.state.metrics["running_avg_output"])

    np.random.seed(10)
    running_avg_acc = [
        None,
    ]
    n_iters = 10
    batch_size = 10
    n_classes = 10
    data = list(range(n_iters))
    loss_values = iter(range(n_iters))
    y_true_batch_values = iter(
        np.random.randint(0, n_classes, size=(n_iters, batch_size)))
    y_pred_batch_values = iter(np.random.rand(n_iters, batch_size, n_classes))
    trainer.run(data, max_epochs=1)

    running_avg_acc = [
        None,
    ]
    n_iters = 10
    batch_size = 10
    n_classes = 10
    data = list(range(n_iters))
    loss_values = iter(range(n_iters))
    y_true_batch_values = iter(
        np.random.randint(0, n_classes, size=(n_iters, batch_size)))
    y_pred_batch_values = iter(np.random.rand(n_iters, batch_size, n_classes))
    trainer.run(data, max_epochs=1)
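The manual bookkeeping in the handlers above mirrors the exponential moving average that RunningAverage maintains; a standalone sketch of that recurrence (names are illustrative, not Ignite internals):
def ema_update(running, value, alpha=0.98):
    # first observation seeds the average; afterwards keep `alpha` of the old value
    if running is None:
        return value
    return running * alpha + (1.0 - alpha) * value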
Example #9
def _setup_common_training_handlers(
    trainer: Engine,
    to_save: Optional[Mapping] = None,
    save_every_iters: int = 1000,
    output_path: Optional[str] = None,
    lr_scheduler: Optional[Union[ParamScheduler, _LRScheduler]] = None,
    with_gpu_stats: bool = False,
    output_names: Optional[Iterable[str]] = None,
    with_pbars: bool = True,
    with_pbar_on_iters: bool = True,
    log_every_iters: int = 100,
    stop_on_nan: bool = True,
    clear_cuda_cache: bool = True,
    save_handler: Optional[Union[Callable, BaseSaveHandler]] = None,
    **kwargs: Any,
) -> None:
    if output_path is not None and save_handler is not None:
        raise ValueError(
            "Arguments output_path and save_handler are mutually exclusive. Please, define only one of them"
        )

    if stop_on_nan:
        trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())

    if lr_scheduler is not None:
        if isinstance(lr_scheduler, torch.optim.lr_scheduler._LRScheduler):
            trainer.add_event_handler(
                Events.ITERATION_COMPLETED, lambda engine: cast(_LRScheduler, lr_scheduler).step()
            )
        else:
            trainer.add_event_handler(Events.ITERATION_STARTED, lr_scheduler)

    if torch.cuda.is_available() and clear_cuda_cache:
        trainer.add_event_handler(Events.EPOCH_COMPLETED, empty_cuda_cache)

    if to_save is not None:

        if output_path is None and save_handler is None:
            raise ValueError(
                "If to_save argument is provided then output_path or save_handler arguments should be also defined"
            )
        if output_path is not None:
            save_handler = DiskSaver(dirname=output_path, require_empty=False)

        checkpoint_handler = Checkpoint(
            to_save, cast(Union[Callable, BaseSaveHandler], save_handler), filename_prefix="training", **kwargs
        )
        trainer.add_event_handler(Events.ITERATION_COMPLETED(every=save_every_iters), checkpoint_handler)

    if with_gpu_stats:
        GpuInfo().attach(
            trainer, name="gpu", event_name=Events.ITERATION_COMPLETED(every=log_every_iters)  # type: ignore[arg-type]
        )

    if output_names is not None:

        def output_transform(x: Any, index: int, name: str) -> Any:
            if isinstance(x, Mapping):
                return x[name]
            elif isinstance(x, Sequence):
                return x[index]
            elif isinstance(x, (torch.Tensor, numbers.Number)):
                return x
            else:
                raise TypeError(
                    "Unhandled type of update_function's output. "
                    f"It should either mapping or sequence, but given {type(x)}"
                )

        for i, n in enumerate(output_names):
            RunningAverage(output_transform=partial(output_transform, index=i, name=n), epoch_bound=False).attach(
                trainer, n
            )
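        # epoch_bound=False keeps these running averages from being reset at each
        # epoch boundary, so the smoothed values persist across epochs.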

    if with_pbars:
        if with_pbar_on_iters:
            ProgressBar(persist=False).attach(
                trainer, metric_names="all", event_name=Events.ITERATION_COMPLETED(every=log_every_iters)
            )

        ProgressBar(persist=True, bar_format="").attach(
            trainer, event_name=Events.EPOCH_STARTED, closing_event_name=Events.COMPLETED
        )
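For context, the public wrapper around this private helper can be exercised roughly as follows (a sketch, assuming ignite.contrib.engines.common.setup_common_training_handlers forwards the keyword arguments shown above):
from ignite.contrib.engines import common
from ignite.engine import Engine

def step(engine, batch):
    return {"loss": 0.0}  # dummy dict-valued output keyed by metric name

trainer = Engine(step)
common.setup_common_training_handlers(
    trainer,
    output_names=["loss"],  # one RunningAverage per name, as in the helper above
    with_pbars=True,
)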
Example #10
def train():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_checkpoint", type=str, default=PRETRAINED_MODEL_URL, help="Path to the pretrained model checkpoint")
    parser.add_argument("--dataset_path", type=str, default='../data/sst', help="Directory to dataset.")
    parser.add_argument("--dataset_cache", type=str, default='./dataset_cache', help="Path to dataset cache")
    parser.add_argument("--logdir", type=str, default='./transformer_results', help="Path to logs")
    parser.add_argument("--num_classes", type=int, default=5, help="Number of classes for the target classification task")
    parser.add_argument("--adapters_dim", type=int, default=-1, help="If >0 add adapters to the model with adapters_dim dimension")
    parser.add_argument("--dropout", type=float, default=0.1, help="Dropout for transformer module")
    parser.add_argument("--clf_loss_coef", type=float, default=1, help="If >0 add a classification loss")
    parser.add_argument("--train_batch_size", type=int, default=32, help="Batch size for training")
    parser.add_argument("--valid_batch_size", type=int, default=32, help="Batch size for validation")
    parser.add_argument("--valid_pct", type=float, default=0.1, help="Percentage of training data to use for validation")
    parser.add_argument("--lr", type=float, default=6.5e-5, help="Learning rate")
    parser.add_argument("--n_warmup", type=int, default=10, help="Number of warmup iterations")
    parser.add_argument("--max_norm", type=float, default=1.0, help="Clipping gradient norm")
    parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay")
    parser.add_argument("--n_epochs", type=int, default=3, help="Number of training epochs")
    parser.add_argument("--gradient_acc_steps", type=int, default=2, help="Number of update steps to accumulate before a backward pass.")
    parser.add_argument("--init_range", type=float, default=0.02, help="Normal initialization standard deviation")

    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device (cuda or cpu)")
    args = parser.parse_args()

    # Define pretrained model and optimizer
    model, state_dict, config = load_pretrained_model(args)
    optimizer = AdamW(model.parameters(), lr=args.lr, correct_bias=False)
    num_model_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Model has {num_model_params:,} parameters")
    # Define datasets
    datasets = read_sst5(args.dataset_path)

    # Define labels
    labels = list(set(datasets["train"][LABEL_COL].tolist()))
    assert len(labels) == args.num_classes  # Specified number of classes should be equal to that in the given dataset!
    label2int = {label: i for i, label in enumerate(labels)}
    int2label = {i: label for label, i in label2int.items()}

    # Get BertTokenizer for this pretrained model
    tokenizer = BertTokenizer.from_pretrained('bert-base-cased', do_lower_case=False)
    clf_token = tokenizer.vocab['[CLS]']  # classifier token
    pad_token = tokenizer.vocab['[PAD]']  # pad token
    processor = TextProcessor(tokenizer, label2int, clf_token, pad_token, max_length=config.num_max_positions)

    train_dl = processor.create_dataloader(datasets["train"],
                                           shuffle=True,
                                           batch_size=args.train_batch_size,
                                           valid_pct=None)

    valid_dl = processor.create_dataloader(datasets["dev"],
                                           batch_size=args.train_batch_size,
                                           valid_pct=None)

    test_dl = processor.create_dataloader(datasets["test"],
                                          batch_size=args.valid_batch_size,
                                          valid_pct=None)

    # Training function and trainer
    def update(engine, batch):
        "update function for training"
        model.train()
        inputs, labels = (t.to(args.device) for t in batch)
        inputs = inputs.transpose(0, 1).contiguous()  # to shape [seq length, batch]
        _, loss = model(inputs,
                        clf_tokens_mask=(inputs == clf_token),
                        clf_labels=labels)
        loss = loss / args.gradient_acc_steps
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
        if engine.state.iteration % args.gradient_acc_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        return loss.item()
    trainer = Engine(update)

    # Evaluation function and evaluator (evaluator output is the input of the metrics)
    def inference(engine, batch):
        model.eval()
        with torch.no_grad():
            batch, labels = (t.to(args.device) for t in batch)
            inputs = batch.transpose(0, 1).contiguous()  # to shape [seq length, batch]
            clf_logits = model(inputs,
                               clf_tokens_mask=(inputs == clf_token),
                               padding_mask=(batch == pad_token))
        return clf_logits, labels
    evaluator = Engine(inference)

    # add metric to evaluator
    Accuracy().attach(evaluator, "accuracy")

    # add evaluator to trainer: eval on valid set after each epoch
    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        evaluator.run(valid_dl)
        print(f"validation epoch: {engine.state.epoch} acc: {100*evaluator.state.metrics['accuracy']:.3f}%")

    # Learning rate schedule: linearly warm-up to lr and then to zero
    scheduler = PiecewiseLinear(optimizer, 'lr', [(0, 0.0), (args.n_warmup, args.lr),
                                (len(train_dl) * args.n_epochs, 0.0)])
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # Add progressbar with loss
    RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
    ProgressBar(persist=True).attach(trainer, metric_names=['loss'])
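    # The ProgressBar displays the running-average "loss" attached just above;
    # the entries in metric_names must match the names used in RunningAverage(...).attach.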

    # Save checkpoints and finetuning config
    checkpoint_handler = ModelCheckpoint(args.logdir, 'checkpoint',
                                         save_interval=1, require_empty=False)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler, {'sst_model': model})

    # Save metadata
    torch.save({
        "config": config,
        "config_ft": args,
        "int2label": int2label
    }, os.path.join(args.logdir, "model_training_args.bin"))

    # Run trainer
    trainer.run(train_dl, max_epochs=args.n_epochs)
    # Evaluate
    evaluator.run(test_dl)
    print(f"test results - acc: {100*evaluator.state.metrics['accuracy']:.3f}")
    # Save fine-tuned model weights
    torch.save(model.state_dict(), os.path.join(args.logdir, "model_weights.pth"))
Example #11
def _test_distrib_on_metric(device):
    import torch.distributed as dist

    rank = dist.get_rank()
    n_iters = 10
    n_epochs = 3
    batch_size = 10
    n_classes = 10

    data = list(range(n_iters))
    np.random.seed(12)
    all_y_true_batch_values = np.random.randint(0,
                                                n_classes,
                                                size=(dist.get_world_size(),
                                                      n_epochs * n_iters,
                                                      batch_size))
    all_y_pred_batch_values = np.random.rand(dist.get_world_size(),
                                             n_epochs * n_iters, batch_size,
                                             n_classes)

    y_true_batch_values = iter(all_y_true_batch_values[rank, ...])
    y_pred_batch_values = iter(all_y_pred_batch_values[rank, ...])

    def update_fn(engine, batch):
        y_true_batch = next(y_true_batch_values)
        y_pred_batch = next(y_pred_batch_values)
        return torch.from_numpy(y_pred_batch), torch.from_numpy(y_true_batch)

    trainer = Engine(update_fn)
    alpha = 0.98

    acc_metric = RunningAverage(
        Accuracy(output_transform=lambda x: [x[0], x[1]], device=device),
        alpha=alpha,
        epoch_bound=False,
    )
    acc_metric.attach(trainer, "running_avg_accuracy")

    running_avg_acc = [
        None,
    ]
    true_acc_metric = Accuracy(device=device)

    @trainer.on(Events.ITERATION_COMPLETED)
    def manual_running_avg_acc(engine):
        i = engine.state.iteration - 1

        true_acc_metric.reset()
        for j in range(dist.get_world_size()):
            output = (
                torch.from_numpy(all_y_pred_batch_values[j, i, :, :]),
                torch.from_numpy(all_y_true_batch_values[j, i, :]),
            )
            true_acc_metric.update(output)

        batch_acc = true_acc_metric._num_correct * 1.0 / true_acc_metric._num_examples

        if running_avg_acc[0] is None:
            running_avg_acc[0] = batch_acc
        else:
            running_avg_acc[0] = running_avg_acc[0] * alpha + (
                1.0 - alpha) * batch_acc
        engine.state.running_avg_acc = running_avg_acc[0]

    @trainer.on(Events.ITERATION_COMPLETED)
    def assert_equal_running_avg_acc_values(engine):
        assert (engine.state.running_avg_acc == engine.state.
                metrics["running_avg_accuracy"]), "{} vs {}".format(
                    engine.state.running_avg_acc,
                    engine.state.metrics["running_avg_accuracy"])

    trainer.run(data, max_epochs=3)
Example #12
def run(args):
    train_loader, val_loader = get_data_loaders(args.dataset_dir,
                                                args.batch_size,
                                                args.val_batch_size,
                                                args.num_workers)

    if args.seed is not None:
        torch.manual_seed(args.seed)

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    num_classes = KITTI.num_classes()
    model = LiLaNet(num_classes)

    device_count = torch.cuda.device_count()
    if device_count > 1:
        print("Using %d GPU(s)" % device_count)
        model = nn.DataParallel(model)
        args.batch_size = device_count * args.batch_size
        args.val_batch_size = device_count * args.val_batch_size

    model = model.to(device)

    criterion = nn.CrossEntropyLoss(weight=KITTI.class_weights()).to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    if args.resume:
        if os.path.isfile(args.resume):
            print("Loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("Loaded checkpoint '{}' (Epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("No checkpoint found at '{}'".format(args.resume))

    def _prepare_batch(batch, non_blocking=True):
        distance, reflectivity, target = batch

        return (convert_tensor(distance,
                               device=device,
                               non_blocking=non_blocking),
                convert_tensor(reflectivity,
                               device=device,
                               non_blocking=non_blocking),
                convert_tensor(target,
                               device=device,
                               non_blocking=non_blocking))

    def _update(engine, batch):
        model.train()
        optimizer.zero_grad()
        distance, reflectivity, target = _prepare_batch(batch)
        pred = model(distance, reflectivity)
        loss = criterion(pred, target)
        loss.backward()
        optimizer.step()

        return loss.item()

    trainer = Engine(_update)

    # attach running average metrics
    RunningAverage(output_transform=lambda x: x).attach(trainer, 'loss')

    # attach progress bar
    pbar = ProgressBar(persist=True)
    pbar.attach(trainer, metric_names=['loss'])

    def _inference(engine, batch):
        model.eval()
        with torch.no_grad():
            distance, reflectivity, target = _prepare_batch(batch)
            pred = model(distance, reflectivity)

            return pred, target

    evaluator = Engine(_inference)
    cm = ConfusionMatrix(num_classes)
    IoU(cm, ignore_index=0).attach(evaluator, 'IoU')
    Loss(criterion).attach(evaluator, 'loss')

    pbar2 = ProgressBar(persist=True, desc='Eval Epoch')
    pbar2.attach(evaluator)

    def _global_step_transform(engine, event_name):
        if trainer.state is not None:
            return trainer.state.iteration
        else:
            return 1

    tb_logger = TensorboardLogger(args.log_dir)
    tb_logger.attach(trainer,
                     log_handler=OutputHandler(tag='training',
                                               metric_names=['loss']),
                     event_name=Events.ITERATION_COMPLETED)

    tb_logger.attach(evaluator,
                     log_handler=OutputHandler(
                         tag='validation',
                         metric_names=['loss', 'IoU'],
                         global_step_transform=_global_step_transform),
                     event_name=Events.EPOCH_COMPLETED)

    @trainer.on(Events.STARTED)
    def initialize(engine):
        engine.state.exception_raised = False
        if args.resume:
            engine.state.epoch = args.start_epoch

    @evaluator.on(Events.EPOCH_COMPLETED)
    def save_checkpoint(engine):
        epoch = trainer.state.epoch if trainer.state is not None else 1
        iou = engine.state.metrics['IoU'] * 100.0
        mean_iou = iou.mean()

        name = 'epoch{}_mIoU={:.1f}.pth'.format(epoch, mean_iou)
        file = {
            'model': model.state_dict(),
            'epoch': epoch,
            'optimizer': optimizer.state_dict(),
            'args': args
        }

        save(file, args.output_dir, 'checkpoint_{}'.format(name))
        save(model.state_dict(), args.output_dir, 'model_{}'.format(name))

    @trainer.on(Events.EPOCH_COMPLETED)
    def run_validation(engine):
        pbar.log_message("Start Validation - Epoch: [{}/{}]".format(
            engine.state.epoch, engine.state.max_epochs))
        evaluator.run(val_loader)
        metrics = evaluator.state.metrics
        loss = metrics['loss']
        iou = metrics['IoU'] * 100.0
        mean_iou = iou.mean()

        iou_text = ', '.join([
            '{}: {:.1f}'.format(KITTI.classes[i + 1].name, v)
            for i, v in enumerate(iou.tolist())
        ])
        pbar.log_message(
            "Validation results - Epoch: [{}/{}]: Loss: {:.2e}\n IoU: {}\n mIoU: {:.1f}"
            .format(engine.state.epoch, engine.state.max_epochs, loss,
                    iou_text, mean_iou))

    @trainer.on(Events.EXCEPTION_RAISED)
    def handle_exception(engine, e):
        engine.state.exception_raised = True
        if isinstance(e, KeyboardInterrupt) and (engine.state.iteration > 1):
            engine.terminate()
            warnings.warn("KeyboardInterrupt caught. Exiting gracefully.")

            name = 'epoch{}_exception.pth'.format(trainer.state.epoch)
            file = {
                'model': model.state_dict(),
                'epoch': trainer.state.epoch,
                'optimizer': optimizer.state_dict()
            }

            save(file, args.output_dir, 'checkpoint_{}'.format(name))
            save(model.state_dict(), args.output_dir, 'model_{}'.format(name))
        else:
            raise e

    if args.eval_on_start:
        print("Start validation")
        evaluator.run(val_loader, max_epochs=1)

    print("Start training")
    trainer.run(train_loader, max_epochs=args.epochs)
    tb_logger.close()
Example #13
            mu = next(mu_scheme)
            i = engine.state.iteration
            for group in optimizer.param_groups:
                group["lr"] = mu * math.sqrt(1 - 0.999**i) / (1 - 0.9**i)

        return {
            "elbo": elbo.item(),
            "kl": kl_divergence.item(),
            "sigma": sigma,
            "mu": mu
        }

    # Trainer and metrics
    trainer = Engine(step)
    metric_names = ["elbo", "kl", "sigma", "mu"]
    RunningAverage(output_transform=lambda x: x["elbo"]).attach(
        trainer, "elbo")
    RunningAverage(output_transform=lambda x: x["kl"]).attach(trainer, "kl")
    RunningAverage(output_transform=lambda x: x["sigma"]).attach(
        trainer, "sigma")
    RunningAverage(output_transform=lambda x: x["mu"]).attach(trainer, "mu")
    ProgressBar().attach(trainer, metric_names=metric_names)

    # Model checkpointing
    checkpoint_handler = ModelCheckpoint("./",
                                         "checkpoint",
                                         save_interval=1,
                                         n_saved=3,
                                         require_empty=False)
    trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                              handler=checkpoint_handler,
                              to_save={
Example #14
def run(train_batch_size, val_batch_size, epochs, lr, momentum, log_interval, log_dir):
    train_loader, val_loader = get_data_loaders(train_batch_size, val_batch_size)
    model = Net()
    writer = SummaryWriter(log_dir=log_dir)

    # Use TPU device
    device = xm.xla_device()

    model.to(device)  # Move model before creating optimizer
    optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
    criterion = nn.NLLLoss()

    # Create trainer and evaluator
    trainer = create_supervised_trainer(
        model, optimizer, criterion, device=device, output_transform=lambda x, y, y_pred, loss: [loss.item(),]
    )

    val_metrics = {"accuracy": Accuracy(), "nll": Loss(criterion)}
    evaluator = create_supervised_evaluator(model, metrics=val_metrics, device=device)

    tracker = xm.RateTracker()

    # Add RateTracker as an output of the training step
    @trainer.on(Events.ITERATION_COMPLETED)
    def add_rate_tracker(engine):
        tracker.add(len(engine.state.batch))
        engine.state.output.append(tracker.global_rate())

    # Setup output values of the training step as EMA metrics
    RunningAverage(output_transform=lambda x: x[0]).attach(trainer, "batch_loss")
    RunningAverage(output_transform=lambda x: x[1]).attach(trainer, "global_rate")

    # Let's log the EMA metrics every `log_interval` iterations
    @trainer.on(Events.ITERATION_COMPLETED(every=log_interval))
    def log_training_loss(engine):
        writer.add_scalar("training/batch_loss", engine.state.metrics["batch_loss"], engine.state.iteration)
        writer.add_scalar("training/global_rate", engine.state.metrics["global_rate"], engine.state.iteration)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        evaluator.run(train_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics["accuracy"]
        avg_nll = metrics["nll"]
        print(
            f"Training Results - Epoch: {engine.state.epoch} Avg accuracy: {avg_accuracy:.2f} Avg loss: {avg_nll:.2f}"
        )
        writer.add_scalar("training/avg_loss", avg_nll, engine.state.epoch)
        writer.add_scalar("training/avg_accuracy", avg_accuracy, engine.state.epoch)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        evaluator.run(val_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics["accuracy"]
        avg_nll = metrics["nll"]
        print(
            f"Validation Results - Epoch: {engine.state.epoch} Avg accuracy: {avg_accuracy:.2f} Avg loss: {avg_nll:.2f}"
        )
        writer.add_scalar("valdation/avg_loss", avg_nll, engine.state.epoch)
        writer.add_scalar("valdation/avg_accuracy", avg_accuracy, engine.state.epoch)

    # kick everything off
    trainer.run(train_loader, max_epochs=epochs)

    writer.close()
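
# Hedged usage sketch (assumption, not part of the original snippet): the `run`
# function above would typically be driven by a small launcher; the
# hyperparameter values below are illustrative placeholders only.
def launch_tpu_example():
    run(train_batch_size=64, val_batch_size=1000, epochs=10,
        lr=0.01, momentum=0.5, log_interval=10, log_dir="tensorboard_logs")
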
def setup_ignite(
        engine: Engine,
        params: SimpleNamespace,
        exp_source,
        run_name: str,
        model,
        optimizer,
        extra_metrics: Iterable[str] = (),
):
    warnings.simplefilter("ignore", category=UserWarning)
    handler = ptan_ignite.EndOfEpisodeHandler(
        exp_source, bound_avg_reward=params.stop_reward)
    handler.attach(engine)
    ptan_ignite.EpisodeFPSHandler().attach(engine)

    objects_to_checkpoint = {
        'model': model,
        'optimizer': optimizer,
        'trainer': engine
    }
    checkpoint_dir = Path("models")
    saver = DiskSaver(str(checkpoint_dir),
                      create_dir=True,
                      require_empty=False)
    handler = Checkpoint(objects_to_checkpoint, saver, n_saved=2)
    engine.add_event_handler(Events.ITERATION_COMPLETED(every=1000), handler)

    checkpoints_paths = list(checkpoint_dir.iterdir())
    if checkpoints_paths:
        checkpoint = torch.load(checkpoints_paths[-1])
        print(f"Loading checkpoint {checkpoints_paths[-1].name}")
        Checkpoint.load_objects(to_load=objects_to_checkpoint,
                                checkpoint=checkpoint)

    @engine.on(ptan_ignite.EpisodeEvents.EPISODE_COMPLETED)
    def episode_completed(trainer: Engine):
        passed = trainer.state.metrics.get('time_passed', 0)
        print("Episode %d: reward=%.2f, steps=%s, "
              "speed=%.1f f/s, elapsed=%s" %
              (trainer.state.episode, trainer.state.episode_reward,
               trainer.state.episode_steps,
               trainer.state.metrics.get('avg_fps',
                                         0), timedelta(seconds=int(passed))))

    @engine.on(ptan_ignite.EpisodeEvents.BOUND_REWARD_REACHED)
    def game_solved(trainer: Engine):
        passed = trainer.state.metrics['time_passed']
        print("Game solved in %s, after %d episodes "
              "and %d iterations!" %
              (timedelta(seconds=int(passed)), trainer.state.episode,
               trainer.state.iteration))
        trainer.should_terminate = True

    now = datetime.now().isoformat(timespec='minutes').replace(":", "-")
    logdir = f"runs/{now}-{params.run_name}-{run_name}"
    tb = tb_logger.TensorboardLogger(log_dir=logdir)
    run_avg = RunningAverage(output_transform=lambda v: v['loss'])
    run_avg.attach(engine, "avg_loss")

    metrics = ['reward', 'steps', 'avg_reward']
    handler = tb_logger.OutputHandler(tag="episodes", metric_names=metrics)
    event = ptan_ignite.EpisodeEvents.EPISODE_COMPLETED
    tb.attach(engine, log_handler=handler, event_name=event)

    # write to tensorboard every 100 iterations
    ptan_ignite.PeriodicEvents().attach(engine)
    metrics = ['avg_loss', 'avg_fps']
    metrics.extend(extra_metrics)
    handler = tb_logger.OutputHandler(tag="train",
                                      metric_names=metrics,
                                      output_transform=lambda a: a)
    event = ptan_ignite.PeriodEvents.ITERS_100_COMPLETED
    tb.attach(engine, log_handler=handler, event_name=event)
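
# Hedged usage sketch (assumption): setup_ignite is presumably wired up roughly
# like this by the training scripts that import it; `process_batch` (an Ignite
# process function) and `batch_generator` (an iterable of training batches) are
# placeholder names, not part of the original snippet.
def train_with_ignite(params, exp_source, net, optimizer,
                      process_batch, batch_generator):
    engine = Engine(process_batch)
    setup_ignite(engine, params, exp_source, run_name="baseline",
                 model=net, optimizer=optimizer)
    engine.run(batch_generator)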
Example #16
def do_train(cfg, model, train_loader, val_loader, optimizer, scheduler,
             loss_fn, num_query, start_epoch, image_map_label2, num_classes2):

    # ---------------------- LOSS start-----------------------------
    print('----------Initialize Loss Start...')
    criterion = torch.nn.CrossEntropyLoss()
    criterion_lsr = LSR()
    criterion_mse = torch.nn.MSELoss()  #(size_average=True)
    criterion_lsr_direction = LSR_direction()
    criterion_adaptive_lsr = AdaptiveLSR(0.25)

    criterion_lsr.set_epsilon(0.1)

    criterion_lsr_direction.set_alpha(0.6)
    criterion_lsr_direction.set_beta(0.15)
    print('******\nalpha:', criterion_lsr_direction.alpha, ' beta:',
          criterion_lsr_direction.beta)
    same_id_list = get_same_id_list(image_map_label2)
    criterion_lsr_direction.set_mask(same_id_list, num_classes2)

    mask_tensor_matrix = torch.zeros(num_classes2, num_classes2)
    epsilon = [1, 1, 1]
    for ids_item in same_id_list:
        if len(ids_item) == 2:
            mask_tensor_matrix[ids_item[0], ids_item[1]] = epsilon[1]
        if len(ids_item) == 3:
            mask_tensor_matrix[ids_item[0], ids_item[1]] = epsilon[2] / 3
            mask_tensor_matrix[ids_item[0], ids_item[2]] = epsilon[2] / 3
            mask_tensor_matrix[ids_item[1], ids_item[2]] = epsilon[2] / 3
    mask_tensor_matrix = mask_tensor_matrix.float()
    #mask_tensor_matrix = Variable(mask_tensor_matrix.cuda())
    print('mask_tensor_matrix.shape:', mask_tensor_matrix.shape,
          type(mask_tensor_matrix), '\n\n\n')
    print('----------Initialize Loss End!!!')
    # ---------------------------------------------------------

    global mAP_path, model_dir
    mAP_path = osp.join(cfg.OUTPUT_DIR, 'map_cmc.txt')
    model_dir = cfg.OUTPUT_DIR

    map_cmc_txt = open(mAP_path, 'a+')
    map_cmc_txt.close()

    log_period = cfg.SOLVER.LOG_PERIOD
    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
    eval_period = cfg.SOLVER.EVAL_PERIOD
    output_dir = cfg.OUTPUT_DIR
    device = cfg.MODEL.DEVICE
    epochs = cfg.SOLVER.MAX_EPOCHS

    logger = logging.getLogger("reid_baseline.train")
    logger.info("Start training")
    trainer = create_supervised_trainer(
        model, optimizer, loss_fn, criterion, criterion_mse, criterion_lsr,
        criterion_adaptive_lsr, criterion_lsr_direction, mask_tensor_matrix,
        device, cfg.SOLVER.MIXUP, cfg.SOLVER.RICAP, cfg.MODEL.FREEZE_BASE,
        cfg.MODEL.FREEZE_BASE_EPOCH)
    #evaluator = create_supervised_evaluator(model, metrics={'r1_mAP': R1_mAP_reranking(num_query, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM)}, device=device)
    evaluator = create_supervised_evaluator(
        model,
        metrics={
            'r1_mAP': R1_mAP(num_query,
                             max_rank=50,
                             feat_norm=cfg.TEST.FEAT_NORM)
        },
        device=device)
    checkpointer = ModelCheckpoint(output_dir,
                                   cfg.MODEL.NAME,
                                   checkpoint_period,
                                   n_saved=3,
                                   require_empty=False)
    timer = Timer(average=True)

    trainer.add_event_handler(
        Events.EPOCH_COMPLETED,
        checkpointer,
        {
            'model': model,  #.state_dict(),
            'optimizer': optimizer
        })  #.state_dict()})
    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    # average metric to attach on trainer
    RunningAverage(output_transform=lambda x: x[0]).attach(trainer, 'avg_loss')
    RunningAverage(output_transform=lambda x: x[1]).attach(trainer, 'avg_acc')

    @trainer.on(Events.STARTED)
    def start_training(engine):
        engine.state.epoch = start_epoch

    @trainer.on(Events.EPOCH_STARTED)
    def adjust_learning_rate(engine):
        if cfg.SOLVER.MY_WARMUP == 'yes':
            if engine.state.epoch <= cfg.SOLVER.MY_WARMUP_EPOCH:
                print('--- warmup')
            else:
                scheduler.step()
        else:
            scheduler.step()

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        global ITER
        ITER += 1

        if ITER % log_period == 0:
            if cfg.SOLVER.MY_SCHEDULER == 'yes':
                logger.info(
                    "Epoch[{}] Iteration[{}/{}] Loss: {:.3f}, Acc: {:.3f}".
                    format(engine.state.epoch, ITER, len(train_loader),
                           engine.state.metrics['avg_loss'],
                           engine.state.metrics['avg_acc']))
            else:
                logger.info(
                    "Epoch[{}] Iteration[{}/{}] Loss: {:.3f}, Acc: {:.3f}, Base Lr: {:.2e}"
                    .format(engine.state.epoch, ITER, len(train_loader),
                            engine.state.metrics['avg_loss'],
                            engine.state.metrics['avg_acc'],
                            scheduler.get_lr()[0]))
        if len(train_loader) == ITER:
            ITER = 0

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        logger.info(
            'Epoch {} done. Epoch time: {:.3f}[s] Speed: {:.2f}[samples/s]'
            .format(engine.state.epoch,
                    timer.value() * timer.step_count,
                    train_loader.batch_size / timer.value()))
        logger.info('-' * 10)
        timer.reset()

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        global best_mAP, best_epoch, mAP_path, save_flag
        if engine.state.epoch % eval_period == 0:
            evaluator.run(val_loader)
            cmc, mAP = evaluator.state.metrics['r1_mAP']
            logger.info("Validation Results - Epoch: {}".format(
                engine.state.epoch))
            logger.info("[Epoch {}]  mAP: {:.2%}".format(
                engine.state.epoch, mAP))
            for r in [1, 5, 10, 20]:
                logger.info("CMC curve, Rank-{:<3}:{:.2%}".format(
                    r, cmc[r - 1]))
            if float(mAP) > float(best_mAP):
                print('+++ get best_mAP: ', best_mAP, '-->', mAP)
                best_mAP = mAP
                best_epoch = int(engine.state.epoch)
                save_flag = True
                print(' set save_flag: True')

            map_cmc_txt = open(mAP_path, 'a+')
            map_cmc_txt.write(
                "Epoch[{}]    best_mAP: {:.2f}  best_epoch: {} \n".format(
                    engine.state.epoch, best_mAP * 100, best_epoch))
            map_cmc_txt.write(
                "       mAP: {:.2f}  Rank-1: {:.2f}  Rank-5: {:.2f}  Rank-10: {:.2f}  Rank-20: {:.2f}\n"
                .format(
                    float(mAP) * 100, cmc[0] * 100, cmc[4] * 100, cmc[9] * 100,
                    cmc[19] * 100))
            map_cmc_txt.flush()
            os.fsync(map_cmc_txt.fileno())
            map_cmc_txt.close()

    trainer.run(train_loader, max_epochs=epochs)
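
# Hedged sketch (assumption, not shown in this snippet): do_train relies on a
# few module-level globals that are presumably initialized elsewhere in the
# original file; the initial values below are illustrative only.
ITER = 0
best_mAP = 0.0
best_epoch = 0
save_flag = False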
Example #17
def train():
    parser = ArgumentParser()
    parser.add_argument("--dataset_path",
                        type=str,
                        default="",
                        help="Path or url of the dataset.")
    parser.add_argument("--use_adapter",
                        default=False,
                        action='store_true',
                        help="Use adapter or not")
    parser.add_argument("--keyword_module",
                        type=str,
                        default="",
                        help="add, attention, ")
    parser.add_argument("--train_batch_size",
                        type=int,
                        default=20,
                        help="Batch size for training")
    parser.add_argument("--valid_batch_size",
                        type=int,
                        default=20,
                        help="Batch size for validation")
    parser.add_argument("--gradient_accumulation_steps",
                        type=int,
                        default=8,
                        help="Accumulate gradients on several steps")
    parser.add_argument("--lr",
                        type=float,
                        default=6.25e-5,
                        help="Learning rate")
    parser.add_argument("--max_norm",
                        type=float,
                        default=1.0,
                        help="Clipping gradient norm")
    parser.add_argument("--n_epochs",
                        type=int,
                        default=5,
                        help="Number of training epochs")
    parser.add_argument(
        "--eval_before_start",
        action='store_true',
        help="If true start with a first evaluation before training")
    parser.add_argument("--device",
                        type=str,
                        default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device (cuda or cpu)")
    parser.add_argument(
        "--fp16",
        type=str,
        default="",
        help=
        "Set to O0, O1, O2 or O3 for fp16 training (see apex documentation)")
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="Local rank for distributed training (-1: not distributed)")
    parser.add_argument("--gpt2_model_name",
                        type=str,
                        default="gpt2",
                        help="Path, url or short name of the model")
    args = parser.parse_args()

    # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process. logger.info => log main process only, logger.warning => log all processes
    logging.basicConfig(
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning(
        "Running process %d", args.local_rank
    )  # This is a logger.warning: it will be printed by all distributed processes
    logger.info("Arguments: %s", pformat(args))

    # Initialize distributed training if needed
    args.distributed = (args.local_rank != -1)
    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        args.device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')

    logger.info("Prepare tokenizer, pretrained model and optimizer.")
    bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    bert_model = BertModel.from_pretrained('bert-base-uncased')
    bert_model.to(args.device)
    bert_model.eval()

    tokenizer_class = GPT2Tokenizer if "gpt2" in args.gpt2_model_name else OpenAIGPTTokenizer  # can't use AutoTokenizer because the checkpoint could be a Path
    tokenizer = tokenizer_class.from_pretrained(args.gpt2_model_name)

    config_class = GPT2Config if "gpt2" in args.gpt2_model_name else OpenAIGPTConfig
    gpt_config = config_class.from_pretrained(args.gpt2_model_name)
    gpt_config.adapter = args.use_adapter
    gpt_config.keyword_module = args.keyword_module

    model_class = GPT2LMHeadModel if "gpt2" in args.gpt2_model_name else OpenAIGPTLMHeadModel
    model = model_class.from_pretrained(args.gpt2_model_name,
                                        config=gpt_config)
    model.to(args.device)

    # Add special tokens if they are not already added
    add_special_tokens_(model, tokenizer)

    optimizer = AdamW(model.parameters(), lr=args.lr, correct_bias=True)

    # Prepare model for FP16 and distributed training if needed (order is important, distributed should be the last)
    if args.fp16:
        from apex import amp  # Apex is only required if we use fp16 training
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16)
    if args.distributed:
        model = DistributedDataParallel(model,
                                        device_ids=[args.local_rank],
                                        output_device=args.local_rank)

    logger.info("Prepare datasets")
    train_loader, val_loader, train_sampler, valid_sampler = get_data_loaders(
        args, bert_tokenizer, tokenizer)

    # Training function and trainer
    def update(engine, batch):
        model.train()
        batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
        source_ids, target_ids, lm_labels = batch

        encoded_layers, _ = bert_model(source_ids)
        (lm_loss), *_ = model(target_ids, encoded_layers, labels=lm_labels)
        loss = lm_loss / args.gradient_accumulation_steps
        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                           args.max_norm)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
        if engine.state.iteration % args.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        return loss.item()

    trainer = Engine(update)

    # Evaluation function and evaluator (evaluator output is the input of the metrics)
    def inference(engine, batch):
        model.eval()
        with torch.no_grad():
            batch = tuple(
                input_tensor.to(args.device) for input_tensor in batch)
            source_ids, target_ids, lm_labels = batch
            logger.info(tokenizer.decode(target_ids[0].tolist()))
            encoded_layers, _ = bert_model(source_ids)
            lm_logits, *_ = model(target_ids, encoded_layers)
            lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view(
                -1, lm_logits.size(-1))
            lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1)
            return (lm_logits_flat_shifted, ), (lm_labels_flat_shifted, )

    evaluator = Engine(inference)

    # Attach evaluation to trainer: we evaluate when we start the training and at the end of each epoch
    trainer.add_event_handler(Events.EPOCH_COMPLETED,
                              lambda _: evaluator.run(val_loader))
    if args.n_epochs < 1:
        trainer.add_event_handler(Events.COMPLETED,
                                  lambda _: evaluator.run(val_loader))
    if args.eval_before_start:
        trainer.add_event_handler(Events.STARTED,
                                  lambda _: evaluator.run(val_loader))

    # Make sure distributed data samplers split the dataset nicely between the distributed processes
    if args.distributed:
        trainer.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: train_sampler.set_epoch(engine.state.epoch))
        evaluator.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: valid_sampler.set_epoch(engine.state.epoch))

    # Linearly decrease the learning rate from lr to zero
    scheduler = PiecewiseLinear(optimizer, "lr",
                                [(0, args.lr),
                                 (args.n_epochs * len(train_loader), 0.0)])
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # Prepare metrics - note how we compute distributed metrics
    RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
    metrics = {
        "nll":
        Loss(torch.nn.CrossEntropyLoss(ignore_index=-100),
             output_transform=lambda x: (x[0][0], x[1][0]))
    }
    metrics.update({
        "average_nll":
        MetricsLambda(average_distributed_scalar, metrics["nll"], args)
    })
    metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"])
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # On the main process: add progress bar, tensorboard, checkpoints and save model, configuration and tokenizer before we start to train
    if args.local_rank in [-1, 0]:
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer, metric_names=["loss"])
        evaluator.add_event_handler(
            Events.COMPLETED, lambda _: pbar.log_message(
                "Validation: %s" % pformat(evaluator.state.metrics)))

        log_dir = make_logdir(args.gpt2_model_name, args.dataset_path,
                              args.use_adapter, args.keyword_module)
        tb_logger = TensorboardLogger(log_dir)

        tb_logger.attach(trainer,
                         log_handler=OutputHandler(tag="training",
                                                   metric_names=["loss"]),
                         event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer,
                         log_handler=OptimizerParamsHandler(optimizer),
                         event_name=Events.ITERATION_STARTED)
        tb_logger.attach(evaluator,
                         log_handler=OutputHandler(tag="validation",
                                                   metric_names=list(
                                                       metrics.keys()),
                                                   another_engine=trainer),
                         event_name=Events.EPOCH_COMPLETED)

        checkpoint_handler = ModelCheckpoint(log_dir,
                                             'checkpoint',
                                             save_interval=1,
                                             n_saved=4)
        trainer.add_event_handler(
            Events.EPOCH_COMPLETED, checkpoint_handler,
            {'mymodel': getattr(model, 'module', model)
             })  # "getattr" takes care of distributed encapsulation

        torch.save(args, log_dir + '/model_training_args.bin')
        getattr(model, 'module',
                model).config.to_json_file(os.path.join(log_dir, CONFIG_NAME))
        tokenizer.save_pretrained(log_dir)

    # Run the training
    trainer.run(train_loader, max_epochs=args.n_epochs)

    # On the main process: close tensorboard logger and rename the last checkpoint (for easy re-loading with OpenAIGPTModel.from_pretrained method)
    if args.local_rank in [-1, 0] and args.n_epochs > 0:
        os.rename(
            os.path.join(log_dir, checkpoint_handler._saved[-1][1]),
            os.path.join(log_dir, WEIGHTS_NAME)
        )  # TODO: PR in ignite to have better access to saved file paths (cleaner)
        tb_logger.close()
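
# Hedged reload sketch (assumption, not part of the original script): once the
# last checkpoint has been renamed to WEIGHTS_NAME and the config/tokenizer have
# been saved to log_dir above, the fine-tuned model can usually be reloaded from
# that directory like any pretrained checkpoint. The `transformers` import path
# is an assumption based on the class names used above.
def load_finetuned(log_dir):
    from transformers import GPT2LMHeadModel, GPT2Tokenizer
    model = GPT2LMHeadModel.from_pretrained(log_dir)
    tokenizer = GPT2Tokenizer.from_pretrained(log_dir)
    return model, tokenizer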
Example #18
def train():
    parser = ArgumentParser()
    parser.add_argument("--train_path",
                        type=str,
                        default="data/train_set4DSTC7-AVSD.json",
                        help="Path of the trainset")
    parser.add_argument("--fea_path",
                        type=str,
                        default="data/",
                        help="Path of the trainset")
    parser.add_argument("--valid_path",
                        type=str,
                        default="data/valid_set4DSTC7-AVSD.json",
                        help="Path of the validset")
    parser.add_argument("--model_checkpoint",
                        type=str,
                        default="gpt2",
                        help="Path, url or short name of the model")
    parser.add_argument("--max_history",
                        type=int,
                        default=3,
                        help="Number of previous exchanges to keep in history")
    parser.add_argument("--train_batch_size",
                        type=int,
                        default=4,
                        help="Batch size for training")
    parser.add_argument("--valid_batch_size",
                        type=int,
                        default=4,
                        help="Batch size for validation")
    parser.add_argument("--drop_rate",
                        type=float,
                        default=0.5,
                        help="drop rate for caption")
    parser.add_argument("--gradient_accumulation_steps",
                        type=int,
                        default=8,
                        help="Accumulate gradients on several steps")
    parser.add_argument("--lr",
                        type=float,
                        default=6.25e-5,
                        help="Learning rate")
    parser.add_argument("--max_norm",
                        type=float,
                        default=1.0,
                        help="Clipping gradient norm")
    parser.add_argument("--n_epochs",
                        type=int,
                        default=8,
                        help="Number of training epochs")
    parser.add_argument(
        "--eval_before_start",
        action='store_true',
        help="If true start with a first evaluation before training")
    parser.add_argument("--device",
                        type=str,
                        default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device (cuda or cpu)")
    parser.add_argument(
        "--fp16",
        type=str,
        default="",
        help=
        "Set to O0, O1, O2 or O3 for fp16 training (see apex documentation)")
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="Local rank for distributed training (-1: not distributed)")
    parser.add_argument("--log_path",
                        type=str,
                        default="log/",
                        help="Log path")
    args = parser.parse_args()

    if not os.path.exists(args.log_path):
        os.makedirs(args.log_path)
    # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process. logger.info => log main process only, logger.warning => log all processes
    logging.basicConfig(
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning(
        "Running process %d", args.local_rank
    )  # This is a logger.warning: it will be printed by all distributed processes
    logger.info("Arguments: %s", pformat(args))

    # Initialize distributed training if needed
    args.distributed = (args.local_rank != -1)
    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        args.device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')

    logger.info(
        "Prepare tokenizer, pretrained model and optimizer - add special tokens for fine-tuning"
    )
    tokenizer_class = GPT2Tokenizer
    tokenizer = tokenizer_class.from_pretrained(args.model_checkpoint)
    model_class = VideoGPT2LMHeadModel
    model = model_class.from_pretrained(args.model_checkpoint)
    tokenizer.add_special_tokens(SPECIAL_TOKENS_DICT)
    model.resize_token_embeddings(len(tokenizer))
    model.to(args.device)
    optimizer = AdamW(model.parameters(), lr=args.lr)

    # Prepare model for FP16 and distributed training if needed (order is important, distributed should be the last)
    if args.fp16:
        from apex import amp  # Apex is only required if we use fp16 training
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16)
    if args.distributed:
        model = DistributedDataParallel(model,
                                        device_ids=[args.local_rank],
                                        output_device=args.local_rank)

    logger.info("Prepare datasets")
    train_loader, val_loader = get_data_loaders_new(args, tokenizer)

    # Training function and trainer
    def update(engine, batch):
        model.train()
        batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
        input_ids, token_type_ids, labels, input_mask, i3d, video_mask, reply_mask = batch
        input_embs = model.transformer.wte(input_ids)
        video_embs = model.video_ff(i3d)
        input_embs = torch.cat([video_embs, input_embs], dim=1)
        token_type_ids = torch.cat([
            torch.ones((i3d.size(0), i3d.size(1))).long().cuda() *
            tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS[-2]), token_type_ids
        ],
                                   dim=1)
        video_loss = model(input_embs,
                           token_type_ids=token_type_ids,
                           labels=(labels, i3d),
                           attention_mask=[video_mask, input_mask],
                           mode="video")[0]
        reply_loss = model(input_embs,
                           token_type_ids=token_type_ids,
                           labels=(labels, i3d),
                           attention_mask=[reply_mask, input_mask],
                           mode="reply")[0]
        loss = (video_loss + reply_loss) / args.gradient_accumulation_steps
        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                           args.max_norm)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
        if engine.state.iteration % args.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        return loss.item()

    trainer = Engine(update)

    # Evaluation function and evaluator (evaluator output is the input of the metrics)
    def inference(engine, batch):
        model.eval()
        with torch.no_grad():
            batch = tuple(
                input_tensor.to(args.device) for input_tensor in batch)
            input_ids, token_type_ids, lm_labels, input_mask, i3d, video_mask, reply_mask = batch
            input_embs = model.transformer.wte(input_ids)
            video_embs = model.video_ff(i3d)
            input_embs = torch.cat([video_embs, input_embs], dim=1)
            token_type_ids = torch.cat([
                torch.ones((i3d.size(0), i3d.size(1))).long().cuda() *
                tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS[-2]),
                token_type_ids
            ],
                                       dim=1)
            model_outputs = model(input_embs,
                                  token_type_ids=token_type_ids,
                                  attention_mask=[reply_mask, input_mask])[0]

            lm_logits = model_outputs  # So we can also use GPT2 outputs
            lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view(
                -1, lm_logits.size(-1))
            lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1)
            return lm_logits_flat_shifted, lm_labels_flat_shifted

    evaluator = Engine(inference)

    # Attach evaluation to trainer: we evaluate when we start the training and at the end of each epoch
    trainer.add_event_handler(Events.EPOCH_COMPLETED,
                              lambda _: evaluator.run(val_loader))
    if args.n_epochs < 1:
        trainer.add_event_handler(Events.COMPLETED,
                                  lambda _: evaluator.run(val_loader))
    if args.eval_before_start:
        trainer.add_event_handler(Events.STARTED,
                                  lambda _: evaluator.run(val_loader))

    # Linearly decrease the learning rate from lr to zero
    scheduler = PiecewiseLinear(optimizer, "lr",
                                [(0, args.lr),
                                 (args.n_epochs * len(train_loader), 0.0)])
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # Prepare metrics - note how we compute distributed metrics
    RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
    metrics = {
        "nll":
        Loss(torch.nn.CrossEntropyLoss(ignore_index=-1),
             output_transform=lambda x: (x[0], x[1]))
    }
    metrics.update({
        "average_nll":
        MetricsLambda(average_distributed_scalar, metrics["nll"], args)
    })
    metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"])
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # On the main process: add progress bar, tensorboard, checkpoints and save model, configuration and tokenizer before we start to train
    if args.local_rank in [-1, 0]:
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer, metric_names=["loss"])
        evaluator.add_event_handler(
            Events.COMPLETED, lambda _: pbar.log_message(
                "Validation: %s" % pformat(evaluator.state.metrics)))

        tb_logger = TensorboardLogger(log_dir="./tb_logs")
        tb_logger.attach(trainer,
                         log_handler=OutputHandler(tag="training",
                                                   metric_names=["loss"]),
                         event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer,
                         log_handler=OptimizerParamsHandler(optimizer),
                         event_name=Events.ITERATION_STARTED)
        tb_logger.attach(evaluator,
                         log_handler=OutputHandler(tag="validation",
                                                   metric_names=list(
                                                       metrics.keys()),
                                                   another_engine=trainer),
                         event_name=Events.EPOCH_COMPLETED)

        checkpoint_handler = ModelCheckpoint(args.log_path,
                                             'checkpoint',
                                             save_interval=1,
                                             n_saved=8,
                                             require_empty=False)
        trainer.add_event_handler(
            Events.EPOCH_COMPLETED, checkpoint_handler,
            {'mymodel': getattr(model, 'module', model)
             })  # "getattr" take care of distributed encapsulation

        torch.save(args, args.log_path + 'model_training_args.bin')
        getattr(model, 'module', model).config.to_json_file(
            os.path.join(args.log_path, CONFIG_NAME))
        tokenizer.save_vocabulary(args.log_path)

    # Run the training
    trainer.run(train_loader, max_epochs=args.n_epochs)

    # On the main process: close tensorboard logger and rename the last checkpoint (for easy re-loading with OpenAIGPTModel.from_pretrained method)
    if args.local_rank in [-1, 0] and args.n_epochs > 0:
        os.rename(
            checkpoint_handler._saved[-1][1][-1],
            os.path.join(args.log_path, WEIGHTS_NAME)
        )  # TODO: PR in ignite to have better access to saved file paths (cleaner)
        tb_logger.close()
Example #19
    def run_once(self):
        # self.log_path = 'log/%s/' % self.dataset
        # self.model_name = 'efficientnet-b0_MSI_{0}fold_random_tile_patch'.format(self.fold_idx)
        # self.log_dir = self.log_path + self.model_name

        log_dir = self.log_dir
        check_manual_seed(self.seed)
        train_pairs, valid_pairs = dataset.prepare_PAIP2020_PANDA(
            self.fold_idx)
        print(len(train_pairs))
        print(len(valid_pairs))

        train_augmentors = self.train_augmentors()
        train_dataset = dataset.DatasetSerial(train_pairs[:],
                                              self.tile_size,
                                              self.num_tile,
                                              train_mode=True)

        infer_augmentors = self.infer_augmentors()  # HACK at has_aux
        infer_dataset = dataset.DatasetSerial(valid_pairs[:],
                                              self.tile_size,
                                              self.num_tile,
                                              train_mode=False)

        train_loader = data.DataLoader(train_dataset,
                                       num_workers=self.nr_procs_train,
                                       batch_size=self.train_batch_size,
                                       shuffle=True,
                                       drop_last=True)

        valid_loader = data.DataLoader(infer_dataset,
                                       num_workers=self.nr_procs_valid,
                                       batch_size=self.infer_batch_size,
                                       shuffle=True,
                                       drop_last=False)

        # --------------------------- Training Sequence

        if self.logging:
            check_log_dir(log_dir)
        #
        device = 'cuda'

        # networks
        input_chs = 3  # TODO: dynamic config
        # ### VGGNet

        net = EfficientNet.from_pretrained('efficientnet-b0', num_classes=2)

        #net =DenseNet(3,2)
        # load pre-trained models
        net = torch.nn.DataParallel(net).to(device)

        if self.load_network:
            saved_state = torch.load(self.save_net_path)
            net.load_state_dict(saved_state)

        # optimizers
        optimizer = optim.Adam(net.parameters(), lr=self.init_lr)
        scheduler = StepLR(optimizer, self.lr_steps, gamma=0.1)
        scheduler = LRScheduler(scheduler)
        #
        trainer = Engine(lambda engine, batch: self.train_step(
            net, batch, optimizer, device))
        valider = Engine(
            lambda engine, batch: self.infer_step(net, batch, device))

        infer_output = ['prob', 'true']
        ##

        if self.logging:
            checkpoint_handler = ModelCheckpoint(log_dir,
                                                 self.chkpts_prefix,
                                                 save_interval=1,
                                                 n_saved=100,
                                                 require_empty=False)
            # adding handlers using `trainer.add_event_handler` method API
            trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                                      handler=checkpoint_handler,
                                      to_save={'net': net})

        timer = Timer(average=True)
        timer.attach(trainer,
                     start=Events.EPOCH_STARTED,
                     resume=Events.ITERATION_STARTED,
                     pause=Events.ITERATION_COMPLETED,
                     step=Events.ITERATION_COMPLETED)
        timer.attach(valider,
                     start=Events.EPOCH_STARTED,
                     resume=Events.ITERATION_STARTED,
                     pause=Events.ITERATION_COMPLETED,
                     step=Events.ITERATION_COMPLETED)

        # attach running average metrics computation
        # decay of EMA to 0.95 to match tensorpack default
        # TODO: refactor this
        RunningAverage(alpha=0.95, output_transform=lambda x: x['acc']).attach(
            trainer, 'acc')
        RunningAverage(alpha=0.95,
                       output_transform=lambda x: x['loss']).attach(
                           trainer, 'loss')

        # attach progress bar
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer, metric_names=['loss'])
        pbar.attach(valider)

        # #early Stopping
        # def score_function(engine):
        #     val_acc=engine.state.metrics["valid-acc"]
        #     return val_acc
        # early_stopping_handler=EarlyStopping(patience=10,score_function=score_function,trainer=trainer)

        # adding handlers using `trainer.on` decorator API
        @trainer.on(Events.EXCEPTION_RAISED)
        def handle_exception(engine, e):
            if isinstance(e,
                          KeyboardInterrupt) and (engine.state.iteration > 1):
                engine.terminate()
                warnings.warn('KeyboardInterrupt caught. Exiting gracefully.')
                checkpoint_handler(engine, {'net_exception': net})
            else:
                raise e

        # writer for tensorboard logging
        tfwriter = None  # HACK temporary
        json_log_file = None  # only set when logging is enabled
        if self.logging:
            tfwriter = SummaryWriter(log_dir)
            json_log_file = log_dir + '/stats.json'
            with open(json_log_file, 'w') as json_file:
                json.dump({}, json_file)  # create empty file

        ### TODO refactor again
        log_info_dict = {
            'logging': self.logging,
            'optimizer': optimizer,
            'tfwriter': tfwriter,
            'json_file': json_log_file,
            'nr_classes': self.nr_classes,
            'metric_names': infer_output,
            'infer_batch_size': self.infer_batch_size  # too cumbersome
        }

        trainer.add_event_handler(Events.EPOCH_COMPLETED,
                                  log_train_ema_results, log_info_dict)
        trainer.add_event_handler(Events.EPOCH_COMPLETED, inference, valider,
                                  valid_loader, log_info_dict)
        valider.add_event_handler(Events.ITERATION_COMPLETED,
                                  accumulate_outputs)

        # Setup is done. Now let's run the training
        trainer.run(train_loader, self.nr_epochs)
        return
Example #20
def train(run_name, forward_func, sample_func, model, train_set, val_set,
          n_epochs, batch_size, lr_i, lr_f, lr_n, sig_i, sig_f, sig_n):

    # Make the run directory
    save_dir = os.path.join('training/saved_runs', run_name)
    if run_name == 'debug':
        shutil.rmtree(save_dir, ignore_errors=True)
    os.mkdir(save_dir)

    model = model.to(device)

    train_loader = DataLoader(train_set,
                              batch_size=batch_size,
                              shuffle=True,
                              drop_last=True)
    val_loader = DataLoader(val_set,
                            batch_size=batch_size,
                            shuffle=True,
                            drop_last=True)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr_i)
    lr_scheduler = utils.AnnealingStepLR(optimizer,
                                         mu_i=lr_i,
                                         mu_f=lr_f,
                                         n=lr_n)
    sigma_scheduler = utils.AnnealingStepSigma(sig_i, sig_f, sig_n)

    # Training step
    def step(engine, batch):
        model.train()

        if isinstance(batch, list):
            batch = [tensor.to(device) for tensor in batch]
        else:
            batch = batch.to(device)
        x_mu, x_q, kl = forward_func(model, batch)

        # Log likelihood
        sigma = sigma_scheduler.sigma
        lr = lr_scheduler.get_lr()[0]
        ll = Normal(x_mu, sigma).log_prob(x_q)

        likelihood = torch.mean(torch.sum(ll, dim=[1, 2, 3]))
        kl_divergence = torch.mean(torch.sum(kl, dim=[1, 2, 3]))

        # Evidence lower bound
        elbo = likelihood - kl_divergence
        loss = -elbo
        loss.backward()

        optimizer.step()
        optimizer.zero_grad()

        lr_scheduler.step()
        sigma_scheduler.step()

        return {
            'elbo': elbo.item(),
            'likelihood': likelihood.item(),
            'kl': kl_divergence.item(),
            'lr': lr,
            'sigma': sigma
        }

    # Trainer and metrics
    trainer = Engine(step)
    metric_names = ['elbo', 'likelihood', 'kl', 'lr', 'sigma']
    RunningAverage(output_transform=lambda x: x['elbo']).attach(
        trainer, 'elbo')
    RunningAverage(output_transform=lambda x: x['likelihood']).attach(
        trainer, 'likelihood')
    RunningAverage(output_transform=lambda x: x['kl']).attach(trainer, 'kl')
    RunningAverage(output_transform=lambda x: x['lr']).attach(trainer, 'lr')
    RunningAverage(output_transform=lambda x: x['sigma']).attach(
        trainer, 'sigma')
    ProgressBar().attach(trainer, metric_names=metric_names)
    Timer(average=True).attach(trainer,
                               start=Events.EPOCH_STARTED,
                               resume=Events.ITERATION_STARTED,
                               pause=Events.ITERATION_COMPLETED,
                               step=Events.ITERATION_COMPLETED)

    # Model checkpointing
    checkpoint_handler = ModelCheckpoint(os.path.join(save_dir, 'checkpoints'),
                                         type(model).__name__,
                                         save_interval=1,
                                         n_saved=3,
                                         require_empty=False)
    trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                              handler=checkpoint_handler,
                              to_save={
                                  'model': model,
                                  'optimizer': optimizer,
                                  'lr_scheduler': lr_scheduler,
                                  'sigma_scheduler': sigma_scheduler
                              })

    # Tensorboard writer
    writer = SummaryWriter(log_dir=os.path.join(save_dir, 'logs'))

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_metrics(engine):
        if engine.state.iteration % 100 == 0:
            for metric, value in engine.state.metrics.items():
                writer.add_scalar('training/{}'.format(metric), value,
                                  engine.state.iteration)

    def save_images(engine, batch):
        x_mu, x_q, r = sample_func(model, batch)
        r_dim = r.shape[1]
        if isinstance(model, VVGQN):
            r = (r + 1) / 2
        r = r.view(-1, 1, int(math.sqrt(r_dim)), int(math.sqrt(r_dim)))

        x_mu = x_mu.detach().cpu().float()
        r = r.detach().cpu().float()

        writer.add_image('representation', make_grid(r), engine.state.epoch)
        writer.add_image('generation', make_grid(x_mu), engine.state.epoch)
        writer.add_image('query', make_grid(x_q), engine.state.epoch)

    @trainer.on(Events.EPOCH_COMPLETED)
    def validate(engine):
        model.eval()
        with torch.no_grad():
            batch = next(iter(val_loader))
            if isinstance(batch, list):
                batch = [tensor.to(device) for tensor in batch]
            else:
                batch = batch.to(device)
            x_mu, x_q, kl = forward_func(model, batch)

            # Validate at last sigma
            ll = Normal(x_mu, sigma_scheduler.sigma).log_prob(x_q)

            likelihood = torch.mean(torch.sum(ll, dim=[1, 2, 3]))
            kl_divergence = torch.mean(torch.sum(kl, dim=[1, 2, 3]))

            # Evidence lower bound
            elbo = likelihood - kl_divergence

            writer.add_scalar('validation/elbo', elbo.item(),
                              engine.state.epoch)
            writer.add_scalar('validation/likelihood', likelihood.item(),
                              engine.state.epoch)
            writer.add_scalar('validation/kl', kl_divergence.item(),
                              engine.state.epoch)

            save_images(engine, batch)

    @trainer.on(Events.EXCEPTION_RAISED)
    def handle_exception(engine, e):
        writer.close()
        engine.terminate()
        if isinstance(e, KeyboardInterrupt) and (engine.state.iteration > 1):
            import warnings
            warnings.warn('KeyboardInterrupt caught. Exiting gracefully.')
            checkpoint_handler(engine, {'model_exception': model})
        else:
            raise e

    start_time = time.time()
    trainer.run(train_loader, n_epochs)
    writer.close()
    end_time = time.time()
    print('Total training time: {}'.format(
        timedelta(seconds=end_time - start_time)))
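
# Hedged sketch (assumption): `device` is used as a module-level global above;
# it is presumably defined elsewhere in the original file along these lines.
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")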
Example #21
def main():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--train", default="./dataset/icdar2015/train")
    parser.add_argument("--test")
    parser.add_argument("--batch_size", default=32, type=int)
    parser.add_argument("--epochs", default=100, type=int)
    parser.add_argument("--scale", default=4, type=int)
    parser.add_argument("--logdir")
    parser.add_argument("--checkpoint")
    parser.add_argument("--restore")
    parser.add_argument("--seed", default=42, type=int)
    parser.add_argument("--excitation",
                        choices=["cse", "sse", "scse", "none"],
                        default=None)
    args = parser.parse_args()
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    image_size = (512, 512)
    dataset = ICDAR15Dataset(os.path.join(args.train, "images"),
                             os.path.join(args.train, "labels"),
                             image_size=image_size,
                             scale=args.scale,
                             training=True)
    dataloader = data.DataLoader(dataset,
                                 batch_size=args.batch_size,
                                 shuffle=True,
                                 num_workers=8)
    if args.test is not None:
        test_dataset = ICDAR15Dataset(os.path.join(args.test, "images"),
                                      os.path.join(args.test, "labels"),
                                      image_size=image_size,
                                      scale=args.scale,
                                      training=False)
    else:
        n_test = min(1000, int(len(dataset) * 0.05))
        dataset, test_dataset = torch.utils.data.random_split(
            dataset, [len(dataset) - n_test, n_test])
        # indices = np.arange(len(dataset))
        # test_dataset = torch.utils.data.Subset(dataset, indices[:n_test])
        # dataset = torch.utils.data.Subset(dataset, indices[n_test:])
        # print(len(dataset), len(test_dataset))
    test_dataloader = data.DataLoader(test_dataset,
                                      batch_size=8,
                                      shuffle=False,
                                      num_workers=8)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # pixellink = net.PixelLink(args.scale, pretrained=False).to(device)
    # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
    if args.restore:
        if torch.cuda.is_available():
            map_location = None
        else:

            def map_location(storage, loc):
                return storage

        pixellink = torch.load(args.restore,
                               map_location=map_location).to(device)
    else:
        excitation_cls = {
            "cse": net.CSE,
            "sse": net.SSE,
            "scse": net.SCSE
        }.get(args.excitation, None)
        print(excitation_cls)
        pixellink = net.MobileNetV2PixelLink(
            args.scale, excitation_cls=excitation_cls).to(device)
    optimizer = torch.optim.Adam(pixellink.parameters(), lr=1e-3)

    def step_fn(training):
        def fn(engine, batch):
            if training:
                pixellink.train()
            else:
                pixellink.eval()
            with torch.set_grad_enabled(training):
                images, pos_pixel_masks, neg_pixel_masks, pixel_weights, link_masks = batch
                if training:
                    optimizer.zero_grad()
                images = images.to(device)
                pos_pixel_masks = pos_pixel_masks.to(device)
                neg_pixel_masks = neg_pixel_masks.to(device)
                pixel_weights = pixel_weights.to(device)
                link_masks = link_masks.to(device)
                pixel_input, link_input = pixellink(images)
                loss_object = net.PixelLinkLoss(pixel_input, pos_pixel_masks,
                                                neg_pixel_masks, pixel_weights,
                                                link_input, link_masks)
                # loss_object = net.PixelLinkFocalLoss(pixel_input, pos_pixel_masks, neg_pixel_masks, pixel_weights, link_input, link_masks)
                if training:
                    loss_object.loss.backward()
                    optimizer.step()
                return {
                    "loss": loss_object.loss.item(),
                    "loss/pixel": loss_object.pixel_loss.item(),
                    "loss/link": loss_object.link_loss.item(),
                    "accuracy/pixel": loss_object.pixel_accuracy,
                    "accuracy/link": np.mean(loss_object.link_accuracy),
                    "accuracy/positive_pixel":
                    loss_object.positive_pixel_accuracy,
                }

        return fn

    dummy = torch.randn(1, 3, image_size[0], image_size[1],
                        dtype=torch.float).to(device)
    writer = create_summary_writer(pixellink, dummy,
                                   os.path.join(args.logdir, "train"))
    test_writer = create_summary_writer(pixellink, dummy,
                                        os.path.join(args.logdir, "test"))

    trainer = Engine(step_fn(training=True))
    evaluator = Engine(step_fn(training=False))

    checkpoint_handler = ModelCheckpoint(
        args.checkpoint,
        "networks",
        n_saved=5,
        require_empty=False,
        score_function=lambda engine: -engine.state.metrics["loss"],
        score_name="loss")
    biggest_checkpoint_handler = ModelCheckpoint(
        args.checkpoint,
        "biggest",
        n_saved=5,
        score_function=lambda engine: engine.state.metrics["loss"],
        score_name="loss",
        require_empty=False)
    evaluator.add_event_handler(Events.COMPLETED,
                                handler=checkpoint_handler,
                                to_save={"net": pixellink})
    evaluator.add_event_handler(Events.COMPLETED,
                                handler=biggest_checkpoint_handler,
                                to_save={"net": pixellink})
    timer = Timer(average=True)

    monitoring_metrics = [
        "loss", "loss/pixel", "loss/link", "accuracy/pixel", "accuracy/link",
        "accuracy/positive_pixel"
    ]
    for metric in monitoring_metrics:

        def output_transform(m):
            def fn(x):
                return x[m]

            return fn

        RunningAverage(output_transform=output_transform(metric)).attach(
            trainer, metric)
        RunningAverage(output_transform=output_transform(metric)).attach(
            evaluator, metric)

    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=monitoring_metrics)

    @trainer.on(Events.ITERATION_COMPLETED)
    def print_logs(engine):
        if (engine.state.iteration - 1) % LOG_FREQ != 0:
            return
        for key, value in engine.state.metrics.items():
            writer.add_scalar(key, value, engine.state.iteration)

        message = "[{epoch}/{max_epoch}][{i}/{max_i}] (train)\t".format(
            epoch=engine.state.epoch,
            max_epoch=args.epochs,
            i=(engine.state.iteration % len(dataloader)),
            max_i=len(dataloader),
        )
        for key, value in engine.state.metrics.items():
            message += ' | {key}: {value}'.format(key=key,
                                                  value=str(round(value, 5)))
        pbar.log_message(message)

    @trainer.on(Events.ITERATION_COMPLETED)
    def print_validation_results(engine):
        if (engine.state.iteration - 1) % LOG_FREQ != 0:
            return
        evaluator.run(test_dataloader)
        for key, value in evaluator.state.metrics.items():
            test_writer.add_scalar(key, value, engine.state.iteration)

        message = "[{epoch}/{max_epoch}][{i}/{max_i}] (test) \t".format(
            epoch=engine.state.epoch,
            max_epoch=args.epochs,
            i=(engine.state.iteration % len(dataloader)),
            max_i=len(dataloader),
        )
        for key, value in evaluator.state.metrics.items():
            message += ' | {key}: {value}'.format(key=key,
                                                  value=str(round(value, 5)))
        pbar.log_message(message)

    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    @trainer.on(Events.EPOCH_COMPLETED)
    def print_time(engine):
        pbar.log_message("Epoch {} done. Time per batch: {:.3f}[s]".format(
            engine.state.epoch,
            timer.value(),
        ))
        timer.reset()

    @trainer.on(Events.EXCEPTION_RAISED)
    def handle_exception(engine, e):
        if isinstance(e, KeyboardInterrupt) and (engine.state.iteration > 1):
            engine.terminate()
            warnings.warn("KeyboardInterrupt caught. Exiting gracefully.")
            checkpoint_handler(engine, {"net": pixellink})
        else:
            raise e

    trainer.run(dataloader, args.epochs)
    writer.close()
    test_writer.close()
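
The nested output_transform factory above exists only to bind the loop variable early. A minimal sketch of the same per-metric RunningAverage attachment, using operator.itemgetter instead of a nested closure (toy step function and metric names are assumptions, pytorch-ignite only):

from operator import itemgetter
import random

from ignite.engine import Engine, Events
from ignite.metrics import RunningAverage


def step(engine, batch):
    # A real step would run forward/backward; here we just emit fake losses.
    return {
        "loss": random.random(),
        "loss/pixel": random.random(),
        "loss/link": random.random(),
    }


trainer = Engine(step)

for name in ["loss", "loss/pixel", "loss/link"]:
    # itemgetter(name) binds `name` immediately, so no nested factory is needed
    # to dodge Python's late binding of loop variables in closures.
    RunningAverage(output_transform=itemgetter(name)).attach(trainer, name)


@trainer.on(Events.EPOCH_COMPLETED)
def report(engine):
    print({k: round(v, 4) for k, v in engine.state.metrics.items()})


trainer.run([None] * 10, max_epochs=2)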
Example #22
0
    def setup(self, training_metrics):
        def metric_name(n) -> str:
            if n.endswith('Accuracy'):
                n = 'acc'
            else:
                n = n[:-6] if n.endswith('Metric') else n
            return n

        def print_metrics(metrics) -> str:
            rv = ''
            metric_keys = sorted(k for k in metrics)
            for k in metric_keys:
                if k == 'Accuracy':
                    rv += f'{metric_name(k)}: {metrics[k]:.3} '
                else:
                    rv += f'{metric_name(k)}: {metrics[k]:.6} '
            return rv.strip()

        if self.seed:
            set_seed_everywhere(self.seed, self.cuda)

        pbar = ProgressBar()

        names = []
        for k, v in training_metrics.items():
            name = f'r{k}'
            names.append(name)
            RunningAverage(v).attach(self.trainer, name)
        RunningAverage(
            None,
            output_transform=lambda x: x[-1] * self.loss_accumulation_steps,
        ).attach(self.trainer, 'rloss')
        names.append('rloss')
        pbar.attach(self.trainer, names)

        pbar = ProgressBar()
        pbar.attach(self.evaluator)

        # A few event handlers. To add or modify event handlers, extend the __init__ method of RunnerABC.
        # Ignite provides the necessary abstractions and a well-stocked repository of useful tools.

        @self.trainer.on(Events.EPOCH_COMPLETED)
        def log_validation_results(trainer):
            self.evaluator.run(self.dataset_splits.val_data_loader())
            metrics = self.evaluator.state.metrics
            logger.info(
                f"Validation Results - Epoch: {trainer.state.epoch} {print_metrics(metrics)}"
            )

            if self.scheduler:
                self.scheduler.step(
                    metrics[self.loss_metric.__class__.__name__])

        @self.trainer.on(Events.COMPLETED)
        def log_test_results(trainer):
            self.evaluator.run(self.dataset_splits.test_data_loader())
            metrics = self.evaluator.state.metrics
            logger.info(
                f"Test Results - Epoch: {trainer.state.epoch} {print_metrics(metrics)}"
            )

        if self.tensorboard_logs:
            tb_logger = TensorboardLogger(log_dir=self.tensorboard_logs)
            tb_logger.attach(self.trainer,
                             log_handler=OutputHandler(
                                 tag="training",
                                 output_transform=lambda loss: {'loss': loss}),
                             event_name=Events.ITERATION_COMPLETED)
            tb_logger.attach(self.evaluator,
                             log_handler=OutputHandler(
                                 tag="validation",
                                 metric_names=["LossMetric"],
                                 another_engine=self.trainer),
                             event_name=Events.EPOCH_COMPLETED)
            tb_logger.attach(self.trainer,
                             log_handler=OptimizerParamsHandler(
                                 self.optimizer),
                             event_name=Events.ITERATION_STARTED)
            tb_logger.attach(self.trainer,
                             log_handler=WeightsScalarHandler(self.model),
                             event_name=Events.ITERATION_COMPLETED)
            tb_logger.attach(self.trainer,
                             log_handler=WeightsHistHandler(self.model),
                             event_name=Events.EPOCH_COMPLETED)
            tb_logger.attach(self.trainer,
                             log_handler=GradsScalarHandler(self.model),
                             event_name=Events.ITERATION_COMPLETED)

            # This is important to close the tensorboard file logger
            @self.trainer.on(Events.COMPLETED)
            def end_tensorboard(trainer):
                logger.info("Training completed")
                tb_logger.close()

        if self.embeddings_name:

            @self.trainer.on(Events.COMPLETED)
            def log_embeddings(trainer):
                if hasattr(self.model, self.embeddings_name) and hasattr(
                        self.dataset_splits, "vectorizer"):
                    logger.info(
                        f"Logging embeddings ({self.embeddings_name}) to Tensorboard!"
                    )
                    embeddings = getattr(self.model,
                                         self.embeddings_name).weight.data
                    metadata = [
                        str(self.dataset_splits.vectorizer.data_vocab.
                            _id2token[token_index]).encode('utf-8')
                        for token_index in range(embeddings.shape[0])
                    ]
                    self.writer.add_embedding(
                        mat=embeddings,
                        metadata=metadata,
                        global_step=self.trainer.state.epoch)
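
The scheduler.step(metrics[...]) call above is the usual way to drive a metric-based scheduler from a validation run. A minimal sketch with a toy model, wiring torch.optim.lr_scheduler.ReduceLROnPlateau to an ignite evaluator's loss (all names here are placeholders):

import torch
from torch import nn, optim

from ignite.engine import Engine, Events
from ignite.metrics import Loss

model = nn.Linear(4, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", patience=2)
loss_fn = nn.MSELoss()


def train_step(engine, batch):
    model.train()
    x, y = batch
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()
    return loss.item()


def eval_step(engine, batch):
    model.eval()
    x, y = batch
    with torch.no_grad():
        return model(x), y


trainer, evaluator = Engine(train_step), Engine(eval_step)
Loss(loss_fn).attach(evaluator, "loss")  # evaluator output (y_pred, y) feeds the metric

data = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(5)]


@trainer.on(Events.EPOCH_COMPLETED)
def validate_and_schedule(engine):
    evaluator.run(data)
    # ReduceLROnPlateau expects the monitored value, not an epoch index.
    scheduler.step(evaluator.state.metrics["loss"])


trainer.run(data, max_epochs=3)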
Example #23
0
def train():
    config_file = "configs/train_full_config.json"
    config = Config.from_json_file(config_file)

    # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process. logger.info => log main process only, logger.warning => log all processes
    logging.basicConfig(
        level=logging.INFO if config.local_rank in [-1, 0] else logging.WARN)
    logger.warning(
        "Running process %d", config.local_rank
    )  # This is a logger.warning: it will be printed by all distributed processes
    logger.info("Arguments: %s", pformat(config))

    # Initialize distributed training if needed
    config.distributed = (config.local_rank != -1)
    if config.distributed:
        torch.cuda.set_device(config.local_rank)
        config.device = torch.device("cuda", config.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')

    logger.info(
        "Prepare tokenizer, pretrained model and optimizer - add special tokens for fine-tuning"
    )
    tokenizer_class = GPT2Tokenizer if "gpt2" in config.model_checkpoint else OpenAIGPTTokenizer
    tokenizer = tokenizer_class.from_pretrained(config.model_checkpoint)
    model_class = GPT2DoubleHeadsModel if "gpt2" in config.model_checkpoint else OpenAIGPTDoubleHeadsModel
    model = model_class.from_pretrained(config.model_checkpoint)
    tokenizer.set_special_tokens(SPECIAL_TOKENS)
    model.set_num_special_tokens(len(SPECIAL_TOKENS))
    model.to(config.device)
    optimizer = OpenAIAdam(model.parameters(), lr=config.lr)

    # Prepare model for FP16 and distributed training if needed (order is important, distributed should be the last)
    if config.fp16:
        from apex import amp  # Apex is only required if we use fp16 training
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=config.fp16)
    if config.distributed:
        model = DistributedDataParallel(model,
                                        device_ids=[config.local_rank],
                                        output_device=config.local_rank)

    logger.info("Prepare datasets")
    train_loader, val_loader, train_sampler, valid_sampler = get_data_loaders(
        config, tokenizer)

    # Training function and trainer
    def update(engine, batch):
        model.train()
        input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids, token_emotion_ids = tuple(
            input_tensor.to(config.device) for input_tensor in batch)
        lm_loss, mc_loss = model(input_ids, mc_token_ids, lm_labels, mc_labels,
                                 token_type_ids, token_emotion_ids)
        loss = (lm_loss * config.lm_coef +
                mc_loss * config.mc_coef) / config.gradient_accumulation_steps
        if config.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                           config.max_norm)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_norm)
        if engine.state.iteration % config.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        return loss.item()

    trainer = Engine(update)

    # Evaluation function and evaluator (evaluator output is the input of the metrics)
    def inference(engine, batch):
        model.eval()
        with torch.no_grad():
            batch = tuple(
                input_tensor.to(config.device) for input_tensor in batch)
            input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids, token_emotion_ids = batch
            #logger.info(tokenizer.decode(input_ids[0, -1, :].tolist()))
            model_outputs = model(input_ids,
                                  mc_token_ids,
                                  token_type_ids=token_type_ids,
                                  token_emotion_ids=token_emotion_ids)
            lm_logits, mc_logits = model_outputs[0], model_outputs[1]  # So we can also use GPT2 outputs
            lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view(
                -1, lm_logits.size(-1))
            lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1)
            return (lm_logits_flat_shifted,
                    mc_logits), (lm_labels_flat_shifted, mc_labels)

    evaluator = Engine(inference)

    # Attach evaluation to trainer: we evaluate when we start the training and at the end of each epoch
    trainer.add_event_handler(Events.EPOCH_COMPLETED,
                              lambda _: evaluator.run(val_loader))
    if config.n_epochs < 1:
        trainer.add_event_handler(Events.COMPLETED,
                                  lambda _: evaluator.run(val_loader))
    if config.eval_before_start:
        trainer.add_event_handler(Events.STARTED,
                                  lambda _: evaluator.run(val_loader))

    # Make sure distributed data samplers split the dataset nicely between the distributed processes
    if config.distributed:
        trainer.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: train_sampler.set_epoch(engine.state.epoch))
        evaluator.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: valid_sampler.set_epoch(engine.state.epoch))

    # Linearly decrease the learning rate from lr to zero
    scheduler = PiecewiseLinear(optimizer, "lr",
                                [(0, config.lr),
                                 (config.n_epochs * len(train_loader), 0.0)])
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # Prepare metrics - note how we compute distributed metrics
    RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
    metrics = {
        "nll":
        Loss(torch.nn.CrossEntropyLoss(ignore_index=-1),
             output_transform=lambda x: (x[0][0], x[1][0])),
        "accuracy":
        Accuracy(output_transform=lambda x: (x[0][1], x[1][1]))
    }
    metrics.update({
        "average_nll":
        MetricsLambda(average_distributed_scalar, metrics["nll"], config),
        "average_accuracy":
        MetricsLambda(average_distributed_scalar, metrics["accuracy"], config)
    })
    metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"])
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # On the main process: add progress bar, tensorboard, checkpoints and save model, configuration and tokenizer before we start to train
    if config.local_rank in [-1, 0]:
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer, metric_names=["loss"])
        evaluator.add_event_handler(
            Events.COMPLETED, lambda _: pbar.log_message(
                "Validation: %s" % pformat(evaluator.state.metrics)))

        tb_logger = TensorboardLogger(log_dir=config.log_dir)
        tb_logger.attach(trainer,
                         log_handler=OutputHandler(tag="training",
                                                   metric_names=["loss"]),
                         event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer,
                         log_handler=OptimizerParamsHandler(optimizer),
                         event_name=Events.ITERATION_STARTED)
        tb_logger.attach(evaluator,
                         log_handler=OutputHandler(tag="validation",
                                                   metric_names=list(
                                                       metrics.keys()),
                                                   another_engine=trainer),
                         event_name=Events.EPOCH_COMPLETED)

        checkpoint_handler = ModelCheckpoint(tb_logger.writer.log_dir,
                                             'checkpoint',
                                             save_interval=1,
                                             n_saved=3)
        trainer.add_event_handler(
            Events.EPOCH_COMPLETED, checkpoint_handler,
            {'mymodel': getattr(model, 'module', model)
             })  # "getattr" takes care of distributed encapsulation

        torch.save(config,
                   tb_logger.writer.log_dir + '/model_training_args.bin')
        getattr(model, 'module', model).config.to_json_file(
            os.path.join(tb_logger.writer.log_dir, CONFIG_NAME))
        tokenizer.save_vocabulary(tb_logger.writer.log_dir)

    # Run the training
    trainer.run(train_loader, max_epochs=config.n_epochs)

    # On the main process: close tensorboard logger and rename the last checkpoint (for easy re-loading with OpenAIGPTModel.from_pretrained method)
    if config.local_rank in [-1, 0] and config.n_epochs > 0:
        os.rename(
            checkpoint_handler._saved[-1][1][-1],
            os.path.join(tb_logger.writer.log_dir, WEIGHTS_NAME)
        )  # TODO: PR in ignite to have better access to saved file paths (cleaner)
        tb_logger.close()
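
The update function above mixes gradient accumulation with fp16 and distributed branches. A minimal sketch of just the accumulation pattern, with a toy model and an assumed ACCUM_STEPS value:

import torch
from torch import nn, optim

from ignite.engine import Engine

ACCUM_STEPS = 4  # assumed value; plays the role of gradient_accumulation_steps
model = nn.Linear(10, 2)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()


def update(engine, batch):
    model.train()
    x, y = batch
    # Scale the loss so the accumulated gradient matches one big batch.
    loss = loss_fn(model(x), y) / ACCUM_STEPS
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    # Only step/zero every ACCUM_STEPS iterations; gradients accumulate in between.
    if engine.state.iteration % ACCUM_STEPS == 0:
        optimizer.step()
        optimizer.zero_grad()
    return loss.item()


trainer = Engine(update)
data = [(torch.randn(8, 10), torch.randint(0, 2, (8,))) for _ in range(12)]
trainer.run(data, max_epochs=1)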
Example #24
0
trainer = Engine(step)
checkpoint_handler = ModelCheckpoint(output_dir,
                                     CKPT_PREFIX,
                                     n_saved=10,
                                     require_empty=False)
timer = Timer(average=True)

# attach running average metrics
monitoring_metrics = ['errD', 'errG', 'D_x', 'D_G_z1', 'D_G_z2']
RunningAverage(alpha=alpha,
               output_transform=lambda x: x['errD']).attach(trainer, 'errD')
RunningAverage(alpha=alpha,
               output_transform=lambda x: x['errG']).attach(trainer, 'errG')
RunningAverage(alpha=alpha,
               output_transform=lambda x: x['D_x']).attach(trainer, 'D_x')
RunningAverage(alpha=alpha, output_transform=lambda x: x['D_G_z1']).attach(
    trainer, 'D_G_z1')
RunningAverage(alpha=alpha, output_transform=lambda x: x['D_G_z2']).attach(
    trainer, 'D_G_z2')

# attach progress bar
pbar = ProgressBar()
pbar.attach(trainer, metric_names=monitoring_metrics)


Example #25
0
        if trainer.state.iteration % SAVE_IMAGE_EVERY_ITER == 0:
            fake_img = vutils.make_grid(
                gen_output_v.data[:64], normalize=True)
            trainer.tb.writer.add_image(
                "fake", fake_img, trainer.state.iteration)
            real_img = vutils.make_grid(
                batch_v.data[:64], normalize=True)
            trainer.tb.writer.add_image(
                "real", real_img, trainer.state.iteration)
            trainer.tb.writer.flush()
        return dis_loss.item(), gen_loss.item()

    engine = Engine(process_batch)
    tb = tb_logger.TensorboardLogger(log_dir=None)
    engine.tb = tb
    RunningAverage(output_transform=lambda out: out[0]).\
        attach(engine, "avg_loss_gen")
    RunningAverage(output_transform=lambda out: out[1]).\
        attach(engine, "avg_loss_dis")

    handler = tb_logger.OutputHandler(tag="train",
        metric_names=['avg_loss_gen', 'avg_loss_dis'])
    tb.attach(engine, log_handler=handler,
              event_name=Events.ITERATION_COMPLETED)

    @engine.on(Events.ITERATION_COMPLETED)
    def log_losses(trainer):
        if trainer.state.iteration % REPORT_EVERY_ITER == 0:
            log.info("%d: gen_loss=%f, dis_loss=%f",
                     trainer.state.iteration,
                     trainer.state.metrics['avg_loss_gen'],
                     trainer.state.metrics['avg_loss_dis'])
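
The log_losses handler above gates its output with a modulo check on the iteration count. With ignite 0.3 or newer, the same periodic logging can use a filtered event instead; a minimal sketch with a toy two-loss engine (REPORT_EVERY_ITER is an assumed value):

import random

from ignite.engine import Engine, Events
from ignite.metrics import RunningAverage

REPORT_EVERY_ITER = 5  # assumed value

trainer = Engine(lambda engine, batch: (random.random(), random.random()))
RunningAverage(output_transform=lambda out: out[0]).attach(trainer, "avg_loss_gen")
RunningAverage(output_transform=lambda out: out[1]).attach(trainer, "avg_loss_dis")


# The event filter replaces the `iteration % REPORT_EVERY_ITER == 0` check in the handler.
@trainer.on(Events.ITERATION_COMPLETED(every=REPORT_EVERY_ITER))
def log_losses(engine):
    print("%d: gen_loss=%.4f, dis_loss=%.4f" % (engine.state.iteration,
                                                engine.state.metrics["avg_loss_gen"],
                                                engine.state.metrics["avg_loss_dis"]))


trainer.run([None] * 20, max_epochs=1)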
Example #26
0
def train(model, train_loader, eval_loaders, optimizer, loss_fn,
          n_it_max, patience, split_names, select_metric='Val accuracy_0',
          select_mode='max', viz=None, device='cpu', lr_scheduler=None, name=None, log_steps=None,
          log_epoch=False, _run=None, prepare_batch=_prepare_batch,
          single_pass=False, n_ep_max=None):

    # print(model)

    if not log_steps and not log_epoch:
        logger.warning('/!\\ No logging during training /!\\')

    if log_steps is None:
        log_steps = []

    epoch_steps = len(train_loader)
    if log_epoch:
        log_steps.append(epoch_steps)

    if single_pass:
        max_epoch = 1
    elif n_ep_max is None:
        assert n_it_max is not None
        max_epoch = int(n_it_max / epoch_steps) + 1
    else:
        assert n_it_max is None
        max_epoch = n_ep_max

    all_metrics = defaultdict(dict)
    trainer = create_supervised_trainer(model, optimizer, loss_fn,
                                        device=device,
                                        prepare_batch=prepare_batch)

    if hasattr(model, 'new_epoch_hook'):
        trainer.add_event_handler(Events.EPOCH_STARTED, model.new_epoch_hook)
    if hasattr(model, 'new_iter_hook'):
        trainer.add_event_handler(Events.ITERATION_STARTED,
                                     model.new_iter_hook)

    trainer.logger.setLevel(logging.WARNING)

    # trainer output is in the format (x, y, y_pred, loss, optionals)
    train_loss = RunningAverage(output_transform=lambda out: out[3].item(),
                                epoch_bound=True)
    train_loss.attach(trainer, 'Trainer loss')
    if hasattr(model, 's'):
        met = Average(output_transform=lambda _: float('nan') if model.s is None else model.s)
        met.attach(trainer, 'cur_s')
        trainer.add_event_handler(Events.ITERATION_COMPLETED, met.completed, 'cur_s')

    if hasattr(model, 'arch_sampler') and model.arch_sampler.distrib_dim > 0:
        met = Average(output_transform=lambda _: float('nan') if model.cur_split is None else model.cur_split)
        met.attach(trainer, 'Trainer split')
        trainer.add_event_handler(Events.ITERATION_COMPLETED, met.completed, 'Trainer split')
        # trainer.add_event_handler(Events.EPOCH_STARTED, met.started)
        all_ent = Average(
            output_transform=lambda out: out[-1]['arch_entropy_avg'].item())
        all_ent.attach(trainer, 'Trainer all entropy')
        trainer.add_event_handler(Events.ITERATION_COMPLETED, all_ent.completed, 'Trainer all entropy')
        train_ent = Average(
            output_transform=lambda out: out[-1]['arch_entropy_sample'].item())
        train_ent.attach(trainer, 'Trainer sampling entropy')
        trainer.add_event_handler(Events.ITERATION_COMPLETED, train_ent.completed, 'Trainer sampling entropy')
        trainer.add_event_handler(Events.EPOCH_COMPLETED,
                                  lambda engine: model.check_arch_freezing(
                                      ent=train_ent.compute(),
                                      epoch=engine.state.iteration/(epoch_steps*max_epoch))
                                  )
        def log_always(engine, name):
            val = engine.state.output[-1][name]
            all_metrics[name][engine.state.iteration/epoch_steps] = val.mean().item()

        def log_always_dict(engine, name):
            for node, val in engine.state.output[-1][name].items():
                all_metrics['node {} {}'.format(node, name)][engine.state.iteration/epoch_steps] = val.mean().item()
        trainer.add_event_handler(Events.ITERATION_COMPLETED, log_always_dict, name='arch_grads')
        trainer.add_event_handler(Events.ITERATION_COMPLETED, log_always_dict, name='arch_probas')
        trainer.add_event_handler(Events.ITERATION_COMPLETED, log_always_dict, name='node_grads')
        trainer.add_event_handler(Events.ITERATION_COMPLETED, log_always, name='task all_loss')
        trainer.add_event_handler(Events.ITERATION_COMPLETED, log_always, name='arch all_loss')
        trainer.add_event_handler(Events.ITERATION_COMPLETED, log_always, name='entropy all_loss')

    if n_it_max is not None:
        StopAfterIterations([n_it_max]).attach(trainer)
    # epoch_pbar = ProgressBar(bar_format='{l_bar}{bar}{r_bar}', desc=name,
    #                          persist=True, disable=not (_run or viz))
    # epoch_pbar.attach(trainer, metric_names=['Train loss'])
    #
    # training_pbar = ProgressBar(bar_format='{l_bar}{bar}{r_bar}', desc=name,
    #                             persist=True, disable=not (_run or viz))
    # training_pbar.attach(trainer, event_name=Events.EPOCH_COMPLETED,
    #                      closing_event_name=Events.COMPLETED)
    total_time = Timer(average=False)
    eval_time = Timer(average=False)
    eval_time.pause()
    data_time = Timer(average=False)
    forward_time = Timer(average=False)
    forward_time.attach(trainer, start=Events.EPOCH_STARTED,
                        pause=Events.ITERATION_COMPLETED,
                        resume=Events.ITERATION_STARTED,
                        step=Events.ITERATION_COMPLETED)
    epoch_time = Timer(average=False)
    epoch_time.attach(trainer, start=Events.EPOCH_STARTED,
                      pause=Events.EPOCH_COMPLETED,
                      resume=Events.EPOCH_STARTED,
                      step=Events.EPOCH_COMPLETED)

    def get_loss(y_pred, y):
        l = loss_fn(y_pred, y)
        if not torch.is_tensor(l):
            l, *l_details = l
        return l.mean()

    def get_member(x, n=0):
        if isinstance(x, (list, tuple)):
            return x[n]
        return x

    eval_metrics = {'loss': Loss(get_loss)}

    for i in range(model.n_out):
        out_trans = get_attr_transform(i)
        # Bind out_trans via a default argument; a plain closure would late-bind
        # to the last value of the loop variable.
        def extract_ys(out, out_trans=out_trans):
            x, y, y_pred, loss, _ = out
            return out_trans((y_pred, y))

        train_acc = Accuracy(extract_ys)
        train_acc.attach(trainer, 'Trainer accuracy_{}'.format(i))
        trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                  train_acc.completed, 'Trainer accuracy_{}'.format(i))
        eval_metrics['accuracy_{}'.format(i)] = \
            Accuracy(output_transform=out_trans)
        # if isinstance(model, SSNWrapper):
        #     model.arch_sampler.entropy().mean()

    evaluator = create_supervised_evaluator(model, metrics=eval_metrics,
                                            device=device,
                                            prepare_batch=prepare_batch)
    last_iteration = 0
    patience_counter = 0

    best = {'value': float('inf') * 1 if select_mode == 'min' else -1,
            'iter': -1,
            'state_dict': None
            }

    def is_better(new, old):
        if select_mode == 'min':
            return new < old
        else:
            return new > old

    def log_results(evaluator, data_loader, iteration, split_name):
        evaluator.run(data_loader)
        metrics = evaluator.state.metrics

        log_metrics = {}

        for metric_name, metric_val in metrics.items():
            log_name = '{} {}'.format(split_name, metric_name)
            if viz:
                first = iteration == 0 and split_name == split_names[0]
                viz.line([metric_val], X=[iteration], win=metric_name,
                         name=log_name,
                         update=None if first else 'append',
                         opts={'title': metric_name,
                               'showlegend': True,
                               'width': 500, 'xlabel': 'iterations'})
                viz.line([metric_val], X=[iteration/epoch_steps],
                         win='{}epoch'.format(metric_name),
                         name=log_name,
                         update=None if first else 'append',
                         opts={'title': metric_name,
                               'showlegend': True,
                               'width': 500, 'xlabel': 'epoch'})
            if _run:
                _run.log_scalar(log_name, metric_val, iteration)
            log_metrics[log_name] = metric_val
            all_metrics[log_name][iteration] = metric_val

        return log_metrics

    if lr_scheduler is not None:
        @trainer.on(Events.EPOCH_COMPLETED)
        def step(_):
            lr_scheduler.step()
            # logger.warning('current lr {:.5e}'.format(
            #     optimizer.param_groups[0]['lr']))

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_event(trainer):
        iteration = trainer.state.iteration if trainer.state else 0
        nonlocal last_iteration, patience_counter, best

        if not log_steps or not \
                (iteration in log_steps or iteration % log_steps[-1] == 0):
            return
        epoch_time.pause()
        eval_time.resume()
        all_metrics['training_epoch'][iteration] = iteration / epoch_steps
        all_metrics['training_iteration'][iteration] = iteration
        if hasattr(model, 'arch_sampler'):
            all_metrics['training_archs'][iteration] = \
                model.arch_sampler().squeeze().detach()
        # if hasattr(model, 'distrib_gen'):
        #     entropy = model.distrib_gen.entropy()
        #     all_metrics['entropy'][iteration] = entropy.mean().item()
        # if trainer.state and len(trainer.state.metrics) > 1:
        #     raise ValueError(trainer.state.metrics)
        all_metrics['data time'][iteration] = data_time.value()
        all_metrics['data time_ps'][iteration] = data_time.value() / max(data_time.step_count, 1.)
        all_metrics['forward time'][iteration] = forward_time.value()
        all_metrics['forward time_ps'][iteration] = forward_time.value() / max(forward_time.step_count, 1.)
        all_metrics['epoch time'][iteration] = epoch_time.value()
        all_metrics['epoch time_ps'][iteration] = epoch_time.value() / max(epoch_time.step_count, 1.)

        if trainer.state:
            # logger.warning(trainer.state.metrics)
            for metric, value in trainer.state.metrics.items():
                all_metrics[metric][iteration] = value
                if viz:
                    viz.line([value], X=[iteration], win=metric.split()[-1],
                         name=metric,
                         update=None if iteration==0 else 'append',
                         opts={'title': metric,
                               'showlegend': True,
                               'width': 500, 'xlabel': 'iterations'})

        iter_this_step = iteration - last_iteration
        for d_loader, name in zip(eval_loaders, split_names):
            if name == 'Train':
                if iteration == 0:
                    all_metrics['Trainer loss'][iteration] = float('nan')
                    all_metrics['Trainer accuracy_0'][iteration] = float('nan')
                    if hasattr(model, 'arch_sampler'):
                        all_metrics['Trainer all entropy'][iteration] = float('nan')
                        all_metrics['Trainer sampling entropy'][iteration] = float('nan')
                    # if hasattr(model, 'cur_split'):
                        all_metrics['Trainer split'][iteration] = float('nan')
                continue
            split_metrics = log_results(evaluator, d_loader, iteration, name)
            if select_metric not in split_metrics:
                continue
            if is_better(split_metrics[select_metric], best['value']):
                best['value'] = split_metrics[select_metric]
                best['iter'] = iteration
                best['state_dict'] = copy.deepcopy(model.state_dict())
                if patience > 0:
                    patience_counter = 0
            elif patience > 0:
                patience_counter += iter_this_step
                if patience_counter >= patience:
                    logger.info('#####')
                    logger.info('# Early stopping Run')
                    logger.info('#####')
                    trainer.terminate()
        last_iteration = iteration
        eval_time.pause()
        eval_time.step()
        all_metrics['eval time'][iteration] = eval_time.value()
        all_metrics['eval time_ps'][iteration] = eval_time.value() / eval_time.step_count
        all_metrics['total time'][iteration] = total_time.value()
        epoch_time.resume()

    log_event(trainer)

    #
    # @trainer.on(Events.EPOCH_COMPLETED)
    # def log_epoch(trainer):
    #     iteration = trainer.state.iteration if trainer.state else 0
    #     epoch = iteration/epoch_steps
    #     fw_t = forward_time.value()
    #     fw_t_ps = fw_t / forward_time.step_count
    #     d_t = data_time.value()
    #     d_t_ps = d_t / data_time.step_count
    #     e_t = epoch_time.value()
    #     e_t_ps = e_t / epoch_time.step_count
    #     ev_t = eval_time.value()
    #     ev_t_ps = ev_t / eval_time.step_count
    #     logger.warning('<{}> Epoch {}/{} finished (Forward: {:.3f}s({:.3f}), '
    #                    'data: {:.3f}s({:.3f}), epoch: {:.3f}s({:.3f}),'
    #                    ' Eval: {:.3f}s({:.3f}), Total: '
    #                    '{:.3f}s)'.format(type(model).__name__, epoch,
    #                                      max_epoch, fw_t, fw_t_ps, d_t, d_t_ps,
    #                                      e_t, e_t_ps, ev_t, ev_t_ps,
    #                                      total_time.value()))

    data_time.attach(trainer, start=Events.STARTED,
                     pause=Events.ITERATION_STARTED,
                     resume=Events.ITERATION_COMPLETED,
                     step=Events.ITERATION_STARTED)

    if hasattr(model, 'iter_per_epoch'):
        model.iter_per_epoch = len(train_loader)
    trainer.run(train_loader, max_epochs=max_epoch)
    return trainer.state.iteration, all_metrics, best
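
The manual best-value/patience bookkeeping in log_event above can also be delegated to ignite's built-in EarlyStopping handler. A minimal sketch with toy engines and a randomly generated validation metric standing in for something like 'Val accuracy_0':

import random

from ignite.engine import Engine, Events
from ignite.handlers import EarlyStopping

trainer = Engine(lambda engine, batch: None)
evaluator = Engine(lambda engine, batch: None)


@evaluator.on(Events.COMPLETED)
def fake_metric(engine):
    # Stand-in for a real validation metric such as 'Val accuracy_0'.
    engine.state.metrics["accuracy"] = random.random()


# Higher score is better; after `patience` evaluations without improvement
# the handler calls trainer.terminate() for us.
stopper = EarlyStopping(patience=3,
                        score_function=lambda engine: engine.state.metrics["accuracy"],
                        trainer=trainer)
evaluator.add_event_handler(Events.COMPLETED, stopper)


@trainer.on(Events.EPOCH_COMPLETED)
def run_validation(engine):
    evaluator.run([None])


trainer.run([None] * 4, max_epochs=50)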
Example #27
0
File: train_vae.py  Project: jihoonl/VAEs
def main():
    args = parse_args()

    logger.info('Num GPU: {}'.format(num_gpus))
    logger.info('Load Dataset')
    data = get_dataset(args.dataset, args.data_root, args.batch_size)
    data1, _ = data['train'][0]

    dims = list(data1.shape)
    param = dict(zdim=args.zdim, hdim=args.hdim, quant=args.quantization)
    model, optimizer = get_model(args.model, args.learning_rate, param, *dims)

    model = torch.nn.DataParallel(model) if num_gpus > 1 else model
    model.to(device)
    logger.info(model)

    kwargs = {
        'pin_memory': True if use_gpu else False,
        'shuffle': True,
        'num_workers': num_gpus * 4
    }

    logdir = get_logdir_name(args, param)
    logger.info('Log Dir: {}'.format(logdir))
    writer = SummaryWriter(logdir)

    os.makedirs(logdir, exist_ok=True)

    train_loader = DataLoader(data['train'], args.batch_size, **kwargs)
    kwargs['shuffle'] = False
    test_loader = DataLoader(data['test'], args.batch_size, **kwargs)

    if args.quantization:
        q = Quantization(device=device)
    else:
        q = Dummy()

    def get_recon_error(recon, x, sigma):
        if x.shape[1] == 1:  # Binary image
            ll = Bernoulli(recon).log_prob(x)
        elif x.shape[1] == 3:  # RGB image
            ll = Normal(recon, sigma).log_prob(x)
        else:
            raise NotImplementedError('x must have either 1 or 3 channels')
        return -ll.sum()

    def step(engine, batch):
        model.train()
        x, _ = batch
        x = x.to(device)
        x_quant = q.preprocess(x)

        recon, kl = model(x_quant)

        nll = get_recon_error(recon, x,
                              sigma(engine.state.epoch, args.sigma_switch))
        loss = nll + kl
        elbo = -loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr = optimizer.param_groups[0]['lr']
        ret = {
            'elbo': elbo.item() / len(x),
            'nll': nll.item() / len(x),
            'kl': kl.item() / len(x),
            'lr': lr,
            'sigma': sigma(engine.state.epoch, args.sigma_switch)
        }
        return ret

    trainer = Engine(step)
    metric_names = ['elbo', 'nll', 'kl', 'lr', 'sigma']

    RunningAverage(output_transform=lambda x: x['elbo']).attach(trainer, 'elbo')
    RunningAverage(output_transform=lambda x: x['nll']).attach(trainer, 'nll')
    RunningAverage(output_transform=lambda x: x['kl']).attach(trainer, 'kl')
    RunningAverage(output_transform=lambda x: x['lr']).attach(trainer, 'lr')
    RunningAverage(output_transform=lambda x: x['sigma']).attach(
        trainer, 'sigma')

    ProgressBar().attach(trainer, metric_names=metric_names)
    Timer(average=True).attach(trainer)

    add_events(trainer, model, writer, logdir, args.log_interval)

    @trainer.on(Events.EPOCH_COMPLETED)
    def validate(engine):
        model.eval()

        val_elbo = 0
        val_kl = 0
        val_nll = 0

        with torch.no_grad():
            for i, (x, _) in enumerate(test_loader):
                x = x.to(device)
                x_quant = q.preprocess(x)
                recon, kl = model(x_quant)
                nll = get_recon_error(
                    recon, x, sigma(engine.state.epoch, args.sigma_switch))
                loss = nll + kl
                elbo = -loss

                val_elbo += elbo
                val_kl += kl
                val_nll += nll
                if i == 0:
                    batch, *xdims = x.shape
                    row = 8
                    n = min(x.shape[0], row)
                    comparison = torch.cat([x[:n], recon[:n]])
                    grid = make_grid(comparison.detach().cpu().float(),
                                     nrow=row)
                    writer.add_image('val/reconstruction', grid,
                                     engine.state.iteration)
            val_elbo /= len(test_loader.dataset)
            val_kl /= len(test_loader.dataset)
            val_nll /= len(test_loader.dataset)
            writer.add_scalar('val/elbo', val_elbo.item(),
                              engine.state.iteration)
            writer.add_scalar('val/kl', val_kl.item(), engine.state.iteration)
            writer.add_scalar('val/nll', val_nll.item(), engine.state.iteration)
            print('{:3d} /{:3d} : ELBO: {:.4f}, KL: {:.4f}, NLL: {:.4f}'.format(
                engine.state.epoch, engine.state.max_epochs, val_elbo, val_kl,
                val_nll))

    @trainer.on(Events.EXCEPTION_RAISED)
    def handler_exception(engine, e):
        writer.close()
        engine.terminate()
        if isinstance(e, KeyboardInterrupt) and (engine.state.iteration > 1):
            logger.warning('KeyboardInterrupt caught. Exiting gracefully.')
        else:
            raise e

    logger.info(
        'Start training. Max epoch = {}, Batch = {}, # Trainset = {}'.format(
            args.epoch, args.batch_size, len(data['train'])))
    trainer.run(train_loader, args.epoch)
    logger.info('Done training')
    writer.close()
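
For reference, the `loss = nll + kl` objective optimized by step above splits into a reconstruction term and a KL term. A minimal sketch with assumed toy shapes, using torch.distributions for a Bernoulli decoder and a diagonal-Gaussian posterior (the project's model already returns kl, so this only spells out what that term typically is):

import torch
from torch.distributions import Bernoulli, Normal, kl_divergence

x = torch.bernoulli(torch.full((16, 1, 28, 28), 0.3))   # toy binary input batch
recon = torch.sigmoid(torch.randn_like(x))               # decoder probabilities in (0, 1)
mu, log_var = torch.zeros(16, 32), torch.zeros(16, 32)   # toy posterior parameters

# Reconstruction term: negative log-likelihood under a Bernoulli decoder.
nll = -Bernoulli(probs=recon).log_prob(x).sum()

# Regularisation term: KL(q(z|x) || p(z)) against a standard normal prior.
q = Normal(mu, (0.5 * log_var).exp())
p = Normal(torch.zeros_like(mu), torch.ones_like(mu))
kl = kl_divergence(q, p).sum()

loss = nll + kl   # what `step` backpropagates
elbo = -loss      # what gets logged (per batch, before dividing by len(x))
print(loss.item(), elbo.item())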
Example #28
0
def train():
    parser = ArgumentParser()
    parser.add_argument("--local_rank", type=int, default=-1)
    args = parser.parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = GPT2DoubleHeadsModel.from_pretrained('gpt2')
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

    DISTRIBUTED = args.local_rank != -1

    if DISTRIBUTED and torch.distributed.is_available():
        print("Distributed")
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        #BATCH_SIZE *= 2

    def average_distributed_scalar(scalar):
        if (not DISTRIBUTED):
            return scalar
        scalar_t = torch.tensor(
            scalar, dtype=torch.float,
            device=device) / torch.distributed.get_world_size()
        torch.distributed.all_reduce(scalar_t,
                                     op=torch.distributed.ReduceOp.SUM)
        return scalar_t.item()

    optimizer = AdamW(model.parameters(), lr=6.25e-5)

    ds = dataloader.Conv_GPT2_DataClass(tokenizer)
    v_ds = dataloader.Conv_GPT2_DataClass(tokenizer, dev=True)
    orig_added_tokens = len(tokenizer.encoder)
    num_added_tokens = tokenizer.add_special_tokens(
        dataloader.ATTR_SPECIAL_TOKENS)
    if (num_added_tokens > 0):
        model.resize_token_embeddings(new_num_tokens=orig_added_tokens +
                                      num_added_tokens)
    model = model.to(device)

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        ds) if DISTRIBUTED else None
    valid_sampler = torch.utils.data.distributed.DistributedSampler(
        v_ds) if DISTRIBUTED else None

    dl = DataLoader(ds,
                    sampler=train_sampler,
                    batch_size=BATCH_SIZE,
                    shuffle=not DISTRIBUTED)
    v_dl = DataLoader(v_ds, sampler=valid_sampler, shuffle=False)

    metrics = {
        "nll":
        Loss(torch.nn.CrossEntropyLoss(ignore_index=-1),
             output_transform=lambda x: (x[0][0], x[1][0])),
    }
    metrics.update({
        "average_nll":
        MetricsLambda(average_distributed_scalar, metrics["nll"]),
    })
    metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"])

    def update(engine, batch):
        model.train()
        batch = tuple(t.to(device) for t in batch)
        lm_loss, *_ = model(batch[0],
                            token_type_ids=batch[1],
                            lm_labels=batch[2])
        loss = lm_loss / ITERATION_STEP
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        if engine.state.iteration % ITERATION_STEP == 0:
            optimizer.step()
            optimizer.zero_grad()
        return loss.item()

    def inference(engine, batch):
        model.eval()
        with torch.no_grad():
            batch = tuple(t.to(device) for t in batch)
            input_ids, token_type_ids, lm_labels = batch
            model_outputs = model(input_ids, token_type_ids=token_type_ids)
            lm_logits = model_outputs[0]
            lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view(
                -1, lm_logits.size(-1))
            lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1)
            return lm_logits_flat_shifted, lm_labels_flat_shifted

    trainer = Engine(update)
    evaluator = Engine(inference)

    scheduler = PiecewiseLinear(optimizer, "lr",
                                [(0, 6.25e-5),
                                 (EPOCHS * len(ds) // BATCH_SIZE, 0.0)])
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # trainer.add_event_handler(Events.STARTED, lambda _: evaluator.run(v_dl))

    if DISTRIBUTED:
        trainer.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: train_sampler.set_epoch(engine.state.epoch))
        #evaluator.add_event_handler(Events.EPOCH_STARTED, lambda engine: valid_sampler.set_epoch(engine.state.epoch))

    RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")

    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    if (args.local_rank in [0, -1]):
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer, metric_names=["loss"])
        #evaluator.add_event_handler(Events.COMPLETED, lambda _: pbar.log_message("Validation: %s" % pformat(evaluator.state.metrics)))

        tb_logger = TensorboardLogger(log_dir='./logs')
        tb_logger.attach(trainer,
                         log_handler=OutputHandler(tag="training",
                                                   metric_names=["loss"]),
                         event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer,
                         log_handler=OptimizerParamsHandler(optimizer),
                         event_name=Events.ITERATION_STARTED)
        #tb_logger.attach(evaluator, log_handler=OutputHandler(tag="validation", metric_names=list(metrics.keys()), another_engine=trainer), event_name=Events.EPOCH_COMPLETED)

        checkpoint_handler = ModelCheckpoint('./checkpoint',
                                             '_checkpoint',
                                             n_saved=3)
        trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler,
                                  {'gpt2_qg': getattr(model, 'module', model)})

        getattr(model, 'module', model).config.to_json_file(
            os.path.join('./checkpoint', 'config'))
        tokenizer.save_pretrained('./checkpoint')

    trainer.run(dl, max_epochs=EPOCHS)

    if (args.local_rank in [0, -1]):
        tb_logger.close()
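
The average_ppl = exp(average_nll) chain above relies on MetricsLambda composing metrics lazily. A minimal, self-contained sketch of that composition with toy logits and labels:

import math

import torch
from ignite.engine import Engine
from ignite.metrics import Loss, MetricsLambda


def inference(engine, batch):
    logits, labels = batch  # a real evaluator would run the model here
    return logits, labels


evaluator = Engine(inference)

nll = Loss(torch.nn.CrossEntropyLoss(ignore_index=-1))
ppl = MetricsLambda(math.exp, nll)  # derived metric: perplexity = exp(NLL)
nll.attach(evaluator, "nll")
ppl.attach(evaluator, "ppl")

data = [(torch.randn(4, 10), torch.randint(0, 10, (4,))) for _ in range(3)]
evaluator.run(data)
print(evaluator.state.metrics)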
Example #29
0
def train():
    parser = ArgumentParser()
    parser.add_argument(
        "--dataset_path",
        type=str,
        default="",
        help="Path or url of the dataset. If empty download from S3.")
    parser.add_argument("--dataset_cache",
                        type=str,
                        default='./dataset_cache',
                        help="Path or url of the dataset cache")
    #parser.add_argument("--model_checkpoint", type=str, default="/home/rohola/codes/transfer-learning-conv-ai/runs/Jun18_10-40-49_rohola-pc", help="Path, url or short name of the model")
    parser.add_argument("--model_checkpoint",
                        type=str,
                        default="openai-gpt",
                        help="Path, url or short name of the model")
    parser.add_argument("--num_candidates",
                        type=int,
                        default=2,
                        help="Number of candidates for training")
    parser.add_argument("--max_history",
                        type=int,
                        default=2,
                        help="Number of previous exchanges to keep in history")
    parser.add_argument("--train_batch_size",
                        type=int,
                        default=2,
                        help="Batch size for training")
    parser.add_argument("--valid_batch_size",
                        type=int,
                        default=1,
                        help="Batch size for validation")
    parser.add_argument("--gradient_accumulation_steps",
                        type=int,
                        default=8,
                        help="Accumulate gradients on several steps")
    parser.add_argument("--lr",
                        type=float,
                        default=6.25e-5,
                        help="Learning rate")
    parser.add_argument("--lm_coef",
                        type=float,
                        default=1.0,
                        help="LM loss coefficient")
    parser.add_argument("--mc_coef",
                        type=float,
                        default=1.0,
                        help="Multiple-choice loss coefficient")
    parser.add_argument("--max_norm",
                        type=float,
                        default=1.0,
                        help="Clipping gradient norm")
    parser.add_argument("--n_epochs",
                        type=int,
                        default=20,
                        help="Number of training epochs")
    parser.add_argument("--personality_permutations",
                        type=int,
                        default=1,
                        help="Number of permutations of personality sentences")
    parser.add_argument(
        "--eval_before_start",
        action='store_true',
        help="If true start with a first evaluation before training")
    parser.add_argument(
        "--device",
        type=str,
        default="cuda:0" if torch.cuda.is_available() else "cpu",
        help="Device (cuda or cpu)")
    parser.add_argument(
        "--fp16",
        type=str,
        default="",
        help=
        "Set to O0, O1, O2 or O3 for fp16 training (see apex documentation)")
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="Local rank for distributed training (-1: not distributed)")
    parser.add_argument(
        "--log_dir",
        type=str,
        default="",
        help="Tensorboard log directory")
    args = parser.parse_args()

    # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process. logger.info => log main process only, logger.warning => log all processes
    logging.basicConfig(
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning(
        "Running process %d", args.local_rank
    )  # This is a logger.warning: it will be printed by all distributed processes
    logger.info("Arguments: %s", pformat(args))

    # Initialize distributed training if needed
    args.distributed = (args.local_rank != -1)
    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        args.device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')

    logger.info(
        "Prepare tokenizer, pretrained model and optimizer - add special tokens for fine-tuning"
    )
    tokenizer_class = GPT2Tokenizer if "gpt2" in args.model_checkpoint else OpenAIGPTTokenizer
    tokenizer = tokenizer_class.from_pretrained(args.model_checkpoint)
    model_class = GPT2DoubleHeadsModel if "gpt2" in args.model_checkpoint else OpenAIGPTDoubleHeadsModel
    model = model_class.from_pretrained(args.model_checkpoint)
    tokenizer.set_special_tokens(SPECIAL_TOKENS)
    model.set_num_special_tokens(len(SPECIAL_TOKENS))
    model.to(args.device)
    optimizer = OpenAIAdam(model.parameters(), lr=args.lr)

    # Prepare model for FP16 and distributed training if needed (order is important, distributed should be the last)
    if args.fp16:
        from apex import amp  # Apex is only required if we use fp16 training
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16)
    if args.distributed:
        model = DistributedDataParallel(model,
                                        device_ids=[args.local_rank],
                                        output_device=args.local_rank)

    logger.info("Prepare datasets")
    train_loader, val_loader, train_sampler, valid_sampler = get_data_loaders(
        args, tokenizer)

    # Training function and trainer

    def update(engine, batch):
        model.train()
        batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
        #persona_ids, history_ids, reply_ids, mc_token_ids, lm_labels, mc_labels, history_token_type = batch
        lm_loss, mc_loss = model(*batch)
        loss = (lm_loss * args.lm_coef +
                mc_loss * args.mc_coef) / args.gradient_accumulation_steps
        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                           args.max_norm)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
        if engine.state.iteration % args.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        return loss.item()

    trainer = Engine(update)

    # Evaluation function and evaluator (evaluator output is the input of the metrics)
    def inference(engine, batch):
        model.eval()
        with torch.no_grad():
            batch = tuple(
                input_tensor.to(args.device) for input_tensor in batch)
            #input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids = batch
            persona_ids, history_ids, reply_ids, mc_token_ids, lm_labels, mc_labels, history_token_type = batch
            #logger.info(tokenizer.decode(input_ids[0, -1, :].tolist()))
            #model_outputs = model(input_ids, mc_token_ids, token_type_ids=token_type_ids, past=engine.state.past)
            model_outputs = model(persona_ids,
                                  history_ids,
                                  reply_ids,
                                  mc_token_ids,
                                  history_token_type=history_token_type)
            lm_logits, mc_logits = model_outputs[0], model_outputs[1]
            #engine.state.presents = presents
            lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view(
                -1, lm_logits.size(-1))
            lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1)
            return (lm_logits_flat_shifted,
                    mc_logits), (lm_labels_flat_shifted, mc_labels)

    evaluator = Engine(inference)

    # Attach evaluation to trainer: we evaluate when we start the training and at the end of each epoch
    trainer.add_event_handler(Events.EPOCH_COMPLETED,
                              lambda _: evaluator.run(val_loader))
    if args.n_epochs < 1:
        trainer.add_event_handler(Events.COMPLETED,
                                  lambda _: evaluator.run(val_loader))
    if args.eval_before_start:
        trainer.add_event_handler(Events.STARTED,
                                  lambda _: evaluator.run(val_loader))

    # Make sure distributed data samplers split the dataset nicely between the distributed processes
    if args.distributed:
        trainer.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: train_sampler.set_epoch(engine.state.epoch))
        evaluator.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: valid_sampler.set_epoch(engine.state.epoch))

    # Linearly decrease the learning rate from lr to zero
    scheduler = PiecewiseLinear(optimizer, "lr",
                                [(0, args.lr),
                                 (args.n_epochs * len(train_loader), 0.0)])
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # Prepare metrics - note how we compute distributed metrics
    RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
    metrics = {
        "nll":
        Loss(torch.nn.CrossEntropyLoss(ignore_index=-1),
             output_transform=lambda x: (x[0][0], x[1][0])),
        "accuracy":
        Accuracy(output_transform=lambda x: (x[0][1], x[1][1]))
    }
    metrics.update({
        "average_nll":
        MetricsLambda(average_distributed_scalar, metrics["nll"], args),
        "average_accuracy":
        MetricsLambda(average_distributed_scalar, metrics["accuracy"], args)
    })
    metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"])
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # On the main process: add progress bar, tensorboard, checkpoints and save model, configuration and tokenizer before we start to train
    if args.local_rank in [-1, 0]:
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer, metric_names=["loss"])
        evaluator.add_event_handler(
            Events.COMPLETED, lambda _: pbar.log_message(
                "Validation: %s" % pformat(evaluator.state.metrics)))

        tb_logger = TensorboardLogger(log_dir=args.log_dir)
        tb_logger.attach(trainer,
                         log_handler=OutputHandler(tag="training",
                                                   metric_names=["loss"]),
                         event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer,
                         log_handler=OptimizerParamsHandler(optimizer),
                         event_name=Events.ITERATION_STARTED)
        tb_logger.attach(evaluator,
                         log_handler=OutputHandler(tag="validation",
                                                   metric_names=list(
                                                       metrics.keys()),
                                                   another_engine=trainer),
                         event_name=Events.EPOCH_COMPLETED)

        checkpoint_handler = ModelCheckpoint(tb_logger.writer.log_dir,
                                             'checkpoint',
                                             save_interval=1,
                                             n_saved=3)
        trainer.add_event_handler(
            Events.EPOCH_COMPLETED, checkpoint_handler,
            {'mymodel': getattr(model, 'module', model)
             })  # "getattr" takes care of distributed encapsulation

        torch.save(args, tb_logger.writer.log_dir + '/model_training_args.bin')
        getattr(model, 'module', model).config.to_json_file(
            os.path.join(tb_logger.writer.log_dir, CONFIG_NAME))
        tokenizer.save_vocabulary(tb_logger.writer.log_dir)

    # Run the training
    trainer.run(train_loader, max_epochs=args.n_epochs)

    # On the main process: close tensorboard logger and rename the last checkpoint (for easy re-loading with OpenAIGPTModel.from_pretrained method)
    if args.local_rank in [-1, 0] and args.n_epochs > 0:
        os.rename(
            checkpoint_handler._saved[-1][1][-1],
            os.path.join(tb_logger.writer.log_dir, WEIGHTS_NAME)
        )  # TODO: PR in ignite to have better access to saved file paths (cleaner)
        tb_logger.close()
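
Saving getattr(model, 'module', model) keeps checkpoints loadable whether or not the model is wrapped in (Distributed)DataParallel. A minimal sketch of that epoch-end checkpointing pattern with a toy model and a temporary directory (assumes a reasonably recent pytorch-ignite):

import os
import tempfile

import torch
from torch import nn
from ignite.engine import Engine, Events
from ignite.handlers import ModelCheckpoint

model = nn.Linear(4, 2)
wrapped = nn.DataParallel(model) if torch.cuda.device_count() > 1 else model

trainer = Engine(lambda engine, batch: None)

ckpt_dir = tempfile.mkdtemp()
checkpointer = ModelCheckpoint(ckpt_dir, "checkpoint", n_saved=3, require_empty=False)
# getattr(..., 'module', ...) saves the bare model whether or not it is wrapped.
trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer,
                          {"mymodel": getattr(wrapped, "module", wrapped)})

trainer.run([None] * 2, max_epochs=3)
print(sorted(os.listdir(ckpt_dir)))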
Example #30
0
def main(
    dataset,
    dataroot,
    z_dim,
    g_filters,
    d_filters,
    batch_size,
    epochs,
    learning_rate,
    beta_1,
    saved_G,
    saved_D,
    seed,
    n_workers,
    device,
    alpha,
    output_dir,
):

    # seed
    check_manual_seed(seed)

    # data
    dataset, num_channels = check_dataset(dataset, dataroot)
    loader = data.DataLoader(dataset,
                             batch_size=batch_size,
                             shuffle=True,
                             num_workers=n_workers,
                             drop_last=True)

    # networks
    netG = Generator(z_dim, g_filters, num_channels).to(device)
    netD = Discriminator(num_channels, d_filters).to(device)

    # criterion
    bce = nn.BCELoss()

    # optimizers
    optimizerG = optim.Adam(netG.parameters(),
                            lr=learning_rate,
                            betas=(beta_1, 0.999))
    optimizerD = optim.Adam(netD.parameters(),
                            lr=learning_rate,
                            betas=(beta_1, 0.999))

    # load pre-trained models
    if saved_G:
        netG.load_state_dict(torch.load(saved_G))

    if saved_D:
        netD.load_state_dict(torch.load(saved_D))

    # misc
    real_labels = torch.ones(batch_size, device=device)
    fake_labels = torch.zeros(batch_size, device=device)
    fixed_noise = torch.randn(batch_size, z_dim, 1, 1, device=device)
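    # fixed_noise is reused every epoch so the images saved by save_fake_example are comparable over time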

    def get_noise():
        return torch.randn(batch_size, z_dim, 1, 1, device=device)

    # The training step function, processing one batch of examples
    def step(engine, batch):

        # unpack the batch. It comes from a dataset, so we have <images, labels> pairs. Discard labels.
        real, _ = batch
        real = real.to(device)

        # -----------------------------------------------------------
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        netD.zero_grad()

        # train with real
        output = netD(real)
        errD_real = bce(output, real_labels)
        D_x = output.mean().item()

        errD_real.backward()

        # get fake image from generator
        noise = get_noise()
        fake = netG(noise)

        # train with fake
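        # detach() prevents this discriminator update from backpropagating into the generator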
        output = netD(fake.detach())
        errD_fake = bce(output, fake_labels)
        D_G_z1 = output.mean().item()

        errD_fake.backward()

        # optimizer step; gradients were accumulated by the two backward() calls above,
        # and errD is the total discriminator loss kept for logging
        errD = errD_real + errD_fake
        optimizerD.step()

        # -----------------------------------------------------------
        # (2) Update G network: maximize log(D(G(z)))
        netG.zero_grad()

        # Update generator. We want to make a step that will make it more likely that discriminator outputs "real"
        output = netD(fake)
        errG = bce(output, real_labels)
        D_G_z2 = output.mean().item()

        errG.backward()

        # gradient update
        optimizerG.step()

        return {
            "errD": errD.item(),
            "errG": errG.item(),
            "D_x": D_x,
            "D_G_z1": D_G_z1,
            "D_G_z2": D_G_z2
        }

    # ignite objects
    trainer = Engine(step)
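    # the dict returned by step() is stored as engine.state.output and read by the output_transforms below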
    checkpoint_handler = ModelCheckpoint(output_dir,
                                         CKPT_PREFIX,
                                         n_saved=10,
                                         require_empty=False)
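    # require_empty=False allows writing into an existing output_dir; n_saved=10 keeps only the ten most recent checkpoints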
    timer = Timer(average=True)

    # attach running average metrics
    monitoring_metrics = ["errD", "errG", "D_x", "D_G_z1", "D_G_z2"]
    RunningAverage(alpha=alpha, output_transform=lambda x: x["errD"]).attach(
        trainer, "errD")
    RunningAverage(alpha=alpha, output_transform=lambda x: x["errG"]).attach(
        trainer, "errG")
    RunningAverage(alpha=alpha,
                   output_transform=lambda x: x["D_x"]).attach(trainer, "D_x")
    RunningAverage(alpha=alpha, output_transform=lambda x: x["D_G_z1"]).attach(
        trainer, "D_G_z1")
    RunningAverage(alpha=alpha, output_transform=lambda x: x["D_G_z2"]).attach(
        trainer, "D_G_z2")

    # attach progress bar
    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=monitoring_metrics)

    @trainer.on(Events.ITERATION_COMPLETED(every=PRINT_FREQ))
    def print_logs(engine):
        fname = os.path.join(output_dir, LOGS_FNAME)
        columns = [
            "iteration",
        ] + list(engine.state.metrics.keys())
        values = [
            str(engine.state.iteration),
        ] + [str(round(value, 5)) for value in engine.state.metrics.values()]

        with open(fname, "a") as f:
            if f.tell() == 0:
                print("\t".join(columns), file=f)
            print("\t".join(values), file=f)

        message = "[{epoch}/{max_epoch}][{i}/{max_i}]".format(
            epoch=engine.state.epoch,
            max_epoch=epochs,
            i=(engine.state.iteration % len(loader)),
            max_i=len(loader))
        for name, value in zip(columns, values):
            message += " | {name}: {value}".format(name=name, value=value)

        pbar.log_message(message)

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def save_fake_example(engine):
        fake = netG(fixed_noise)
        path = os.path.join(output_dir,
                            FAKE_IMG_FNAME.format(engine.state.epoch))
        vutils.save_image(fake.detach(), path, normalize=True)

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def save_real_example(engine):
        img, _ = engine.state.batch  # labels are not needed here
        path = os.path.join(output_dir,
                            REAL_IMG_FNAME.format(engine.state.epoch))
        vutils.save_image(img, path, normalize=True)

    # adding handlers using `trainer.add_event_handler` method API
    trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                              handler=checkpoint_handler,
                              to_save={
                                  "netG": netG,
                                  "netD": netD
                              })

    # automatically adding handlers via a special `attach` method of `Timer` handler
    timer.attach(
        trainer,
        start=Events.EPOCH_STARTED,
        resume=Events.ITERATION_STARTED,
        pause=Events.ITERATION_COMPLETED,
        step=Events.ITERATION_COMPLETED,
    )
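    # with average=True and step on every iteration, timer.value() reports the mean time per batch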

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        pbar.log_message("Epoch {} done. Time per batch: {:.3f}[s]".format(
            engine.state.epoch, timer.value()))
        timer.reset()

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def create_plots(engine):
        try:
            import matplotlib as mpl

            mpl.use("agg")

            import numpy as np
            import pandas as pd
            import matplotlib.pyplot as plt

        except ImportError:
            warnings.warn(
                "Loss plots will not be generated -- pandas or matplotlib not found"
            )

        else:
            df = pd.read_csv(os.path.join(output_dir, LOGS_FNAME),
                             delimiter="\t",
                             index_col="iteration")
            _ = df.plot(subplots=True, figsize=(20, 20))
            _ = plt.xlabel("Iteration number")
            fig = plt.gcf()
            path = os.path.join(output_dir, PLOT_FNAME)

            fig.savefig(path)

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EXCEPTION_RAISED)
    def handle_exception(engine, e):
        if isinstance(e, KeyboardInterrupt) and (engine.state.iteration > 1):
            engine.terminate()
            warnings.warn("KeyboardInterrupt caught. Exiting gracefully.")

            create_plots(engine)
            checkpoint_handler(engine, {
                "netG_exception": netG,
                "netD_exception": netD
            })

        else:
            raise e

    # Setup is done. Now let's run the training
    trainer.run(loader, epochs)
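
A hedged usage sketch for the example above; every argument value below is illustrative and not taken from the original:

if __name__ == "__main__":
    main(
        dataset="cifar10",        # must be a dataset name that check_dataset understands
        dataroot="./data",
        z_dim=100,
        g_filters=64,
        d_filters=64,
        batch_size=64,
        epochs=25,
        learning_rate=2e-4,
        beta_1=0.5,
        saved_G=None,             # or a path to a saved generator state_dict to resume from
        saved_D=None,
        seed=23,
        n_workers=2,
        device="cuda",            # or "cpu"
        alpha=0.98,               # smoothing factor for the RunningAverage metrics
        output_dir="./out",
    )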