Example #1
def training(config,
             local_rank=None,
             with_mlflow_logging=False,
             with_plx_logging=False):

    if not getattr(config, "use_fp16", True):
        raise RuntimeError("This training script uses fp16 AMP by default and "
                           "requires config.use_fp16 to be True")

    set_seed(config.seed + local_rank)
    torch.cuda.set_device(local_rank)
    device = 'cuda'

    torch.backends.cudnn.benchmark = True

    train_loader = config.train_loader
    train_sampler = getattr(train_loader, "sampler", None)
    assert train_sampler is not None, "Train loader of type '{}' " \
                                      "should have attribute 'sampler'".format(type(train_loader))
    assert hasattr(train_sampler, 'set_epoch') and callable(train_sampler.set_epoch), \
        "Train sampler should have a callable method `set_epoch`"

    train_eval_loader = config.train_eval_loader
    val_loader = config.val_loader

    model = config.model.to(device)
    optimizer = config.optimizer
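    # Initialize NVIDIA Apex AMP at the configured opt level, then wrap the model
    # in Apex DistributedDataParallel; delay_allreduce defers the gradient
    # all-reduce until the whole backward pass has finished.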
    model, optimizer = amp.initialize(model,
                                      optimizer,
                                      opt_level=getattr(
                                          config, "fp16_opt_level", "O2"),
                                      num_losses=1)
    model = DDP(model, delay_allreduce=True)
    criterion = config.criterion.to(device)

    prepare_batch = getattr(config, "prepare_batch", _prepare_batch)
    non_blocking = getattr(config, "non_blocking", True)

    # Setup trainer
    accumulation_steps = getattr(config, "accumulation_steps", 1)
    model_output_transform = getattr(config, "model_output_transform",
                                     lambda x: x)

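    # Training step: forward pass, compute the (possibly dict-valued) loss,
    # scale it for gradient accumulation, and backpropagate through AMP's loss scaler.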
    def train_update_function(engine, batch):

        model.train()

        x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
        y_pred = model(x)
        y_pred = model_output_transform(y_pred)
        loss = criterion(y_pred, y)

        if isinstance(loss, Mapping):
            assert 'supervised batch loss' in loss
            loss_dict = loss
            output = {k: v.item() for k, v in loss_dict.items()}
            loss = loss_dict['supervised batch loss'] / accumulation_steps
        else:
            output = {'supervised batch loss': loss.item()}

        with amp.scale_loss(loss, optimizer, loss_id=0) as scaled_loss:
            scaled_loss.backward()

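        # Step and zero the optimizer only every `accumulation_steps` iterations,
        # emulating a larger effective batch size.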
        if engine.state.iteration % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

        return output

    output_names = getattr(config, "output_names", [
        'supervised batch loss',
    ])

    trainer = Engine(train_update_function)
    common.setup_common_distrib_training_handlers(
        trainer,
        train_sampler,
        to_save={
            'model': model,
            'optimizer': optimizer
        },
        save_every_iters=1000,
        output_path=config.output_path.as_posix(),
        lr_scheduler=config.lr_scheduler,
        with_gpu_stats=True,
        output_names=output_names,
        with_pbars=True,
        with_pbar_on_iters=with_mlflow_logging,
        log_every_iters=1)

    # Setup evaluators
    num_classes = config.num_classes
    cm_metric = ConfusionMatrix(num_classes=num_classes)

    val_metrics = {
        "IoU": IoU(cm_metric),
        "mIoU_bg": mIoU(cm_metric),
    }

    if hasattr(config, "val_metrics") and isinstance(config.val_metrics, dict):
        val_metrics.update(config.val_metrics)

    model_output_transform = getattr(config, "model_output_transform",
                                     lambda x: x)

    evaluator_args = dict(model=model,
                          metrics=val_metrics,
                          device=device,
                          non_blocking=non_blocking,
                          prepare_batch=prepare_batch,
                          output_transform=lambda x, y, y_pred: (
                              model_output_transform(y_pred),
                              y,
                          ))
    train_evaluator = create_supervised_evaluator(**evaluator_args)
    evaluator = create_supervised_evaluator(**evaluator_args)

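    # Attach evaluation progress bars on the main process only (and only when
    # MLflow logging is enabled, mirroring `with_pbar_on_iters` above).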
    if dist.get_rank() == 0 and with_mlflow_logging:
        ProgressBar(persist=False,
                    desc="Train Evaluation").attach(train_evaluator)
        ProgressBar(persist=False, desc="Val Evaluation").attach(evaluator)

    def run_validation(_):
        train_evaluator.run(train_eval_loader)
        evaluator.run(val_loader)

    if getattr(config, "start_by_validation", False):
        trainer.add_event_handler(Events.STARTED, run_validation)
    trainer.add_event_handler(
        Events.EPOCH_COMPLETED(every=getattr(config, "val_interval", 1)),
        run_validation)
    trainer.add_event_handler(Events.COMPLETED, run_validation)

    score_metric_name = "mIoU_bg"

    if hasattr(config, "es_patience"):
        common.add_early_stopping_by_val_score(config.es_patience,
                                               evaluator,
                                               trainer,
                                               metric_name=score_metric_name)

    if dist.get_rank() == 0:

        tb_logger = common.setup_tb_logging(config.output_path.as_posix(),
                                            trainer,
                                            optimizer,
                                            evaluators={
                                                "training": train_evaluator,
                                                "validation": evaluator
                                            })
        if with_mlflow_logging:
            common.setup_mlflow_logging(trainer,
                                        optimizer,
                                        evaluators={
                                            "training": train_evaluator,
                                            "validation": evaluator
                                        })

        if with_plx_logging:
            common.setup_plx_logging(trainer,
                                     optimizer,
                                     evaluators={
                                         "training": train_evaluator,
                                         "validation": evaluator
                                     })

        common.save_best_model_by_val_score(config.output_path.as_posix(),
                                            evaluator,
                                            model,
                                            metric_name=score_metric_name,
                                            trainer=trainer)

        # Log train/val predictions:
        tb_logger.attach(evaluator,
                         log_handler=predictions_gt_images_handler(
                             img_denormalize_fn=config.img_denormalize,
                             n_images=15,
                             another_engine=trainer,
                             prefix_tag="validation"),
                         event_name=Events.EPOCH_COMPLETED)

        log_train_predictions = getattr(config, "log_train_predictions", False)
        if log_train_predictions:
            tb_logger.attach(train_evaluator,
                             log_handler=predictions_gt_images_handler(
                                 img_denormalize_fn=config.img_denormalize,
                                 n_images=15,
                                 another_engine=trainer,
                                 prefix_tag="training"),
                             event_name=Events.EPOCH_COMPLETED)

    trainer.run(train_loader, max_epochs=config.num_epochs)
Example #2
def attach_decorators(trainer, SR, feature_extractor, domain_classifier,
                      resolution_classifier, sr_classif_critic, optim, loader):
    timer = Timer(average=True)

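    # Checkpoint every sub-network and the optimizer once per epoch,
    # keeping up to 300 snapshots.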
    checkpoint_handler = ModelCheckpoint(
        args.output_dir + '/checkpoints/domain_adaptation_training/',
        'training',
        save_interval=1,
        n_saved=300,
        require_empty=False,
        iteration=args.epoch_c)

    monitoring_metrics = [
        'tgt_loss', 'src_loss', 'sr_loss', 'loss', 'GP', 'res_down_loss',
        'res_up_loss', 'tv_loss', 'vgg_loss'
    ]
    RunningAverage(alpha=0.98,
                   output_transform=lambda x: x['tgt_loss']).attach(
                       trainer, 'tgt_loss')
    RunningAverage(alpha=0.98,
                   output_transform=lambda x: x['src_loss']).attach(
                       trainer, 'src_loss')
    RunningAverage(alpha=0.98, output_transform=lambda x: x['sr_loss']).attach(
        trainer, 'sr_loss')
    RunningAverage(alpha=0.98, output_transform=lambda x: x['loss']).attach(
        trainer, 'loss')
    RunningAverage(alpha=0.98,
                   output_transform=lambda x: x['GP']).attach(trainer, 'GP')
    # RunningAverage(alpha=0.98, output_transform=lambda x: x['g_loss']).attach(trainer, 'g_loss')
    RunningAverage(alpha=0.98,
                   output_transform=lambda x: x['res_down_loss']).attach(
                       trainer, 'res_down_loss')
    RunningAverage(alpha=0.98,
                   output_transform=lambda x: x['res_up_loss']).attach(
                       trainer, 'res_up_loss')
    RunningAverage(alpha=0.98, output_transform=lambda x: x['tv_loss']).attach(
        trainer, 'tv_loss')
    RunningAverage(alpha=0.98,
                   output_transform=lambda x: x['vgg_loss']).attach(
                       trainer, 'vgg_loss')

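    # Show the running-average losses listed in `monitoring_metrics` in the
    # console progress bar.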
    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=monitoring_metrics)

    trainer.add_event_handler(
        event_name=Events.EPOCH_COMPLETED,
        handler=checkpoint_handler,
        to_save={
            'feature_extractor': feature_extractor,
            'SR': SR,
            # 'optim_feature': optim_feature,
            # 'optim_domain_classif': optim_domain_classif,
            # 'optim_res_classif': optim_res_classif,
            'optim': optim,
            # 'optim_sr_critic': optim_sr_critic,
            'domain_D': domain_classifier,
            'res_D': resolution_classifier,
            'sr_D': sr_classif_critic
        })

    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    @trainer.on(Events.ITERATION_COMPLETED)
    def print_logs(engine):
        if (engine.state.iteration - 1) % PRINT_FREQ == 0:
            fname = os.path.join(args.output_dir, LOGS_FNAME)
            columns = engine.state.metrics.keys()
            values = [
                str(round(value, 5))
                for value in engine.state.metrics.values()
            ]

            with open(fname, 'a') as f:
                if f.tell() == 0:
                    print('\t'.join(columns), file=f)
                print('\t'.join(values), file=f)

            i = (engine.state.iteration % len(loader))
            message = '[{epoch}/{max_epoch}][{i}/{max_i}]'.format(
                epoch=engine.state.epoch,
                max_epoch=args.epochs,
                i=i,
                max_i=len(loader))
            for name, value in zip(columns, values):
                message += ' | {name}: {value}'.format(name=name, value=value)

            pbar.log_message(message)

    @trainer.on(Events.ITERATION_COMPLETED)
    def save_real_example(engine):
        if (engine.state.iteration - 1) % PRINT_FREQ == 0:
            if not os.path.exists(args.output_dir +
                                  '/imgs/domain_adaptation_training/'):
                os.makedirs(args.output_dir +
                            '/imgs/domain_adaptation_training/')
            px, py, px2, py2, px_up, _, px2_up, _ = engine.state.batch
            img = SR(feature_extractor(px2.cuda()))
            path = os.path.join(
                args.output_dir + '/imgs/domain_adaptation_training/',
                predtgt_IMG_FNAME.format(engine.state.epoch,
                                         engine.state.iteration))
            vutils.save_image(img, path)
            path = os.path.join(
                args.output_dir + '/imgs/domain_adaptation_training/',
                targetY_IMG_FNAME.format(engine.state.epoch,
                                         engine.state.iteration))
            vutils.save_image(py2, path)
            path = os.path.join(
                args.output_dir + '/imgs/domain_adaptation_training/',
                targetX_IMG_FNAME.format(engine.state.epoch,
                                         engine.state.iteration))
            vutils.save_image(px2, path)
            path = os.path.join(
                args.output_dir + '/imgs/domain_adaptation_training/',
                sourceX_IMG_FNAME.format(engine.state.epoch,
                                         engine.state.iteration))
            vutils.save_image(px, path)
            path = os.path.join(
                args.output_dir + '/imgs/domain_adaptation_training/',
                sourceY_IMG_FNAME.format(engine.state.epoch,
                                         engine.state.iteration))
            vutils.save_image(py, path)
            img = SR(feature_extractor(px.cuda()))
            path = os.path.join(
                args.output_dir + '/imgs/domain_adaptation_training/',
                predsrc_IMG_FNAME.format(engine.state.epoch,
                                         engine.state.iteration))
            vutils.save_image(img, path)

    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        pbar.log_message('Epoch {} done. Time per batch: {:.3f}[s]'.format(
            engine.state.epoch, timer.value()))
        timer.reset()

    @trainer.on(Events.EXCEPTION_RAISED)
    def handle_exception(engine, e):
        if isinstance(e, KeyboardInterrupt) and (engine.state.iteration > 1):
            engine.terminate()
            warnings.warn('KeyboardInterrupt caught. Exiting gracefully.')

            checkpoint_handler(
                engine, {
                    'feature_extractor_{}'.format(engine.state.iteration):
                    feature_extractor,
                    'SR_{}'.format(engine.state.iteration): SR,
                    'DOMAIN_D_{}'.format(engine.state.iteration):
                    domain_classifier,
                    'RES_D_{}'.format(engine.state.iteration):
                    resolution_classifier,
                    'SR_D_{}'.format(engine.state.iteration):
                    sr_classif_critic,
                    'OPTIM_{}'.format(engine.state.iteration): optim
                })

        else:
            raise e

    @trainer.on(Events.STARTED)
    def loaded(engine):
        if args.epoch_c != 0:
            engine.state.epoch = args.epoch_c
            engine.state.iteration = args.epoch_c * len(loader)
Example #3
# The snippet starts mid-function; the training-mode guard below is inferred
# from the `else:` branch and the mode check in `engine.run` further down.
if args.mode == utils.TRAINING:
    engine.add_event_handler(
        ignite.engine.Events.ITERATION_STARTED,
        handlers.calculate_gdl_lambda
    )
    engine.add_event_handler(
        ignite.engine.Events.ITERATION_COMPLETED(every=args.log_every),
        handlers.log_summaries,
        torch.utils.tensorboard.SummaryWriter(logs_directory),
    )
    engine.add_event_handler(
        ignite.engine.Events.ITERATION_COMPLETED(every=args.checkpoint_every),
        handlers.save_checkpoint,
        model,
        optimizer,
        lr_scheduler,
        amp,
        args.checkpoint_last,
        checkpoint_directory,
    )
else:
    engine.add_event_handler(
        ignite.engine.Events.ITERATION_COMPLETED,
        handlers.save_output,
        outputs_directory,
    )

pbar = ProgressBar()
pbar.attach(engine, output_transform=lambda output: {"loss": output[("loss")]})

e = engine.run(data=data_loader, max_epochs=args.epochs if args.mode == utils.TRAINING else 1)
Example #4
device = torch.device(args.device)

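# Standard ImageNet validation preprocessing: resize to 256, center-crop to 224,
# normalize, and convert to a tensor.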
tfms = albu.Compose([
    albu.Resize(256, 256),
    albu.CenterCrop(224, 224),
    albu.Normalize(),
    ToTensor(),
])

dataset_dir = os.path.join(os.environ.get('DATASET_DIR'), 'imagenet')
dataset = Imagenet(root_dir=dataset_dir,
                   split='val',
                   transforms=tfms)


train_loader = DataLoader(dataset, batch_size=args.batch_size,
                          shuffle=False, num_workers=8)

model = (getattr(vgg, args.model))(3, 1000)
state_dict = torch.load(args.state_dict)
model.load_state_dict(state_dict)
model = model.cuda()

evaluator = create_classification_evaluator(model, device=device)
ProgressBar(persist=True).attach(evaluator)


state = evaluator.run(train_loader)

print(state.metrics)
Example #5
    def train(self, config, **kwargs):
        """Trains a model on the given configurations.
        :param config: A training configuration. Note that all parameters in the config can also be manually adjusted with --ARG=VALUE
        :param **kwargs: keyword arguments that override values in the YAML config
        """

        from pycocoevalcap.cider.cider import Cider

        conf = train_util.parse_config_or_kwargs(config, **kwargs)
        conf["seed"] = self.seed
        outputdir = os.path.join(
            conf["outputpath"], conf["model"],
            "{}_{}".format(
                datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'),
                uuid.uuid1().hex))

        # Early init because of creating dir
        checkpoint_handler = ModelCheckpoint(
            outputdir,
            "run",
            n_saved=1,
            require_empty=False,
            create_dir=True,
            score_function=lambda engine: engine.state.metrics["score"],
            score_name="loss")

        logger = train_util.genlogger(os.path.join(outputdir, "train.log"))
        # print passed config parameters
        logger.info("Storing files in: {}".format(outputdir))
        train_util.pprint_dict(conf, logger.info)

        zh = conf["zh"]
        vocabulary = torch.load(conf["vocab_file"])
        train_loader, val_loader, info = self._get_dataloaders(conf, vocabulary)
        conf["inputdim"] = info["inputdim"]
        val_key2refs = info["val_key2refs"]
        logger.info("<== Estimating Scaler ({}) ==>".format(info["scaler"].__class__.__name__))
        logger.info(
            "Feature: {} Input dimension: {} Vocab Size: {}".format(
                conf["feature_file"], info["inputdim"], len(vocabulary)))

        model = self._get_model(conf, len(vocabulary))
        model = model.to(self.device)
        train_util.pprint_dict(model, logger.info, formatter="pretty")
        optimizer = getattr(
            torch.optim, conf["optimizer"]
        )(model.parameters(), **conf["optimizer_args"])
        train_util.pprint_dict(optimizer, logger.info, formatter="pretty")


        XE_criterion = torch.nn.CrossEntropyLoss().to(self.device)
        seq_criterion = torch.nn.CosineEmbeddingLoss().to(self.device)
        crtrn_imprvd = train_util.criterion_improver(conf['improvecriterion'])

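        # One training step: word-level cross-entropy plus a cosine embedding
        # loss between sequence outputs and sentence targets (weighted by
        # conf["seq_loss_ratio"]), with gradients clipped to norm 0.5.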
        def _train_batch(engine, batch):
            model.train()
            with torch.enable_grad():
                optimizer.zero_grad()
                output = self._forward(
                    model, batch, "train", ss_ratio=conf["ss_args"]["ss_ratio"])
                XE_loss = XE_criterion(output["packed_logits"], output["word_targets"])
                seq_loss = seq_criterion(output["seq_outputs"], output["sentence_targets"], torch.ones(batch[0].shape[0]).to(self.device))
                loss = XE_loss + seq_loss * conf["seq_loss_ratio"]
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
                optimizer.step()
                output["XE_loss"] = XE_loss.item()
                output["seq_loss"] = seq_loss.item()
                output["loss"] = loss.item()
                return output

        trainer = Engine(_train_batch)
        RunningAverage(output_transform=lambda x: x["loss"]).attach(trainer, "running_loss")
        pbar = ProgressBar(persist=False, ascii=True, ncols=100)
        pbar.attach(trainer, ["running_loss"])

        key2pred = {}

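        # Validation step: decode one candidate caption per key into `key2pred`,
        # which is later scored against the references with CIDEr.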
        def _inference(engine, batch):
            model.eval()
            keys = batch[3]
            with torch.no_grad():
                output = self._forward(model, batch, "validation")
                output["seq_loss"] = seq_criterion(output["seq_outputs"], output["sentence_targets"], torch.ones(len(keys)).to(self.device))
                seqs = output["seqs"].cpu().numpy()
                for (idx, seq) in enumerate(seqs):
                    if keys[idx] in key2pred:
                        continue
                    candidate = self._convert_idx2sentence(seq, vocabulary, zh)
                    key2pred[keys[idx]] = [candidate,]
                return output

        metrics = {
            "loss": Average(output_transform=lambda x: x["loss"]),
            "XE_loss": Average(output_transform=lambda x: x["XE_loss"]),
            "seq_loss": Average(output_transform=lambda x: x["seq_loss"]),
        }

        evaluator = Engine(_inference)

        def eval_val(engine, key2pred, key2refs):
            scorer = Cider(zh=zh)
            score, scores = scorer.compute_score(key2refs, key2pred)
            engine.state.metrics["score"] = score
            key2pred.clear()

        evaluator.add_event_handler(
            Events.EPOCH_COMPLETED, eval_val, key2pred, val_key2refs)

        for name, metric in metrics.items():
            metric.attach(trainer, name)

        metrics["seq_loss"].attach(evaluator, "seq_loss")
            
        trainer.add_event_handler(
            Events.EPOCH_COMPLETED, train_util.log_results, evaluator, val_loader,
            logger.info, metrics.keys(), ["seq_loss", "score"])

        if conf["ss"]:
            trainer.add_event_handler(
                Events.GET_BATCH_COMPLETED, train_util.update_ss_ratio, conf, len(train_loader))


        evaluator.add_event_handler(
            Events.EPOCH_COMPLETED, train_util.save_model_on_improved, crtrn_imprvd,
            "score", {
                "model": model.state_dict(),
                "config": conf,
                "scaler": info["scaler"]
        }, os.path.join(outputdir, "saved.pth"))

        scheduler = getattr(torch.optim.lr_scheduler, conf["scheduler"])(
            optimizer, **conf["scheduler_args"])
        evaluator.add_event_handler(
            Events.EPOCH_COMPLETED, train_util.update_lr,
            scheduler, "score")

        evaluator.add_event_handler(
            Events.EPOCH_COMPLETED, checkpoint_handler, {
                "model": model,
            }
        )

        trainer.run(train_loader, max_epochs=conf["epochs"])
        return outputdir
Example #6
def main(
    dataset,
    dataroot,
    z_dim,
    g_filters,
    d_filters,
    batch_size,
    epochs,
    learning_rate,
    beta_1,
    saved_G,
    saved_D,
    seed,
    n_workers,
    device,
    alpha,
    output_dir,
):

    # seed
    check_manual_seed(seed)

    # data
    dataset, num_channels = check_dataset(dataset, dataroot)
    loader = data.DataLoader(dataset,
                             batch_size=batch_size,
                             shuffle=True,
                             num_workers=n_workers,
                             drop_last=True)

    # networks
    netG = Generator(z_dim, g_filters, num_channels).to(device)
    netD = Discriminator(num_channels, d_filters).to(device)

    # criterion
    bce = nn.BCELoss()

    # optimizers
    optimizerG = optim.Adam(netG.parameters(),
                            lr=learning_rate,
                            betas=(beta_1, 0.999))
    optimizerD = optim.Adam(netD.parameters(),
                            lr=learning_rate,
                            betas=(beta_1, 0.999))

    # load pre-trained models
    if saved_G:
        netG.load_state_dict(torch.load(saved_G))

    if saved_D:
        netD.load_state_dict(torch.load(saved_D))

    # misc
    real_labels = torch.ones(batch_size, device=device)
    fake_labels = torch.zeros(batch_size, device=device)
    fixed_noise = torch.randn(batch_size, z_dim, 1, 1, device=device)

    def get_noise():
        return torch.randn(batch_size, z_dim, 1, 1, device=device)

    # The main function, processing a batch of examples
    def step(engine, batch):

        # unpack the batch. It comes from a dataset, so we have <images, labels> pairs. Discard labels.
        real, _ = batch
        real = real.to(device)

        # -----------------------------------------------------------
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        netD.zero_grad()

        # train with real
        output = netD(real)
        errD_real = bce(output, real_labels)
        D_x = output.mean().item()

        errD_real.backward()

        # get fake image from generator
        noise = get_noise()
        fake = netG(noise)

        # train with fake
        output = netD(fake.detach())
        errD_fake = bce(output, fake_labels)
        D_G_z1 = output.mean().item()

        errD_fake.backward()

        # gradient update
        errD = errD_real + errD_fake
        optimizerD.step()

        # -----------------------------------------------------------
        # (2) Update G network: maximize log(D(G(z)))
        netG.zero_grad()

        # Update generator. We want to make a step that will make it more likely that discriminator outputs "real"
        output = netD(fake)
        errG = bce(output, real_labels)
        D_G_z2 = output.mean().item()

        errG.backward()

        # gradient update
        optimizerG.step()

        return {
            "errD": errD.item(),
            "errG": errG.item(),
            "D_x": D_x,
            "D_G_z1": D_G_z1,
            "D_G_z2": D_G_z2
        }

    # ignite objects
    trainer = Engine(step)
    checkpoint_handler = ModelCheckpoint(output_dir,
                                         CKPT_PREFIX,
                                         n_saved=10,
                                         require_empty=False)
    timer = Timer(average=True)

    # attach running average metrics
    monitoring_metrics = ["errD", "errG", "D_x", "D_G_z1", "D_G_z2"]
    RunningAverage(alpha=alpha, output_transform=lambda x: x["errD"]).attach(
        trainer, "errD")
    RunningAverage(alpha=alpha, output_transform=lambda x: x["errG"]).attach(
        trainer, "errG")
    RunningAverage(alpha=alpha,
                   output_transform=lambda x: x["D_x"]).attach(trainer, "D_x")
    RunningAverage(alpha=alpha, output_transform=lambda x: x["D_G_z1"]).attach(
        trainer, "D_G_z1")
    RunningAverage(alpha=alpha, output_transform=lambda x: x["D_G_z2"]).attach(
        trainer, "D_G_z2")

    # attach progress bar
    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=monitoring_metrics)

    @trainer.on(Events.ITERATION_COMPLETED(every=PRINT_FREQ))
    def print_logs(engine):
        fname = os.path.join(output_dir, LOGS_FNAME)
        columns = [
            "iteration",
        ] + list(engine.state.metrics.keys())
        values = [
            str(engine.state.iteration),
        ] + [str(round(value, 5)) for value in engine.state.metrics.values()]

        with open(fname, "a") as f:
            if f.tell() == 0:
                print("\t".join(columns), file=f)
            print("\t".join(values), file=f)
        message = f"[{engine.state.epoch}/{epochs}][{engine.state.iteration % len(loader)}/{len(loader)}]"
        for name, value in zip(columns, values):
            message += f" | {name}: {value}"

        pbar.log_message(message)

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def save_fake_example(engine):
        fake = netG(fixed_noise)
        path = os.path.join(output_dir,
                            FAKE_IMG_FNAME.format(engine.state.epoch))
        vutils.save_image(fake.detach(), path, normalize=True)

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def save_real_example(engine):
        img, y = engine.state.batch
        path = os.path.join(output_dir,
                            REAL_IMG_FNAME.format(engine.state.epoch))
        vutils.save_image(img, path, normalize=True)

    # adding handlers using `trainer.add_event_handler` method API
    trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                              handler=checkpoint_handler,
                              to_save={
                                  "netG": netG,
                                  "netD": netD
                              })

    # automatically adding handlers via a special `attach` method of `Timer` handler
    timer.attach(
        trainer,
        start=Events.EPOCH_STARTED,
        resume=Events.ITERATION_STARTED,
        pause=Events.ITERATION_COMPLETED,
        step=Events.ITERATION_COMPLETED,
    )

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        pbar.log_message(
            f"Epoch {engine.state.epoch} done. Time per batch: {timer.value():.3f}[s]"
        )
        timer.reset()

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def create_plots(engine):
        try:
            import matplotlib as mpl

            mpl.use("agg")

            import matplotlib.pyplot as plt
            import numpy as np
            import pandas as pd

        except ImportError:
            warnings.warn(
                "Loss plots will not be generated -- pandas or matplotlib not found"
            )

        else:
            df = pd.read_csv(os.path.join(output_dir, LOGS_FNAME),
                             delimiter="\t",
                             index_col="iteration")
            _ = df.plot(subplots=True, figsize=(20, 20))
            _ = plt.xlabel("Iteration number")
            fig = plt.gcf()
            path = os.path.join(output_dir, PLOT_FNAME)

            fig.savefig(path)

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EXCEPTION_RAISED)
    def handle_exception(engine, e):
        if isinstance(e, KeyboardInterrupt) and (engine.state.iteration > 1):
            engine.terminate()
            warnings.warn("KeyboardInterrupt caught. Exiting gracefully.")

            create_plots(engine)
            checkpoint_handler(engine, {
                "netG_exception": netG,
                "netD_exception": netD
            })

        else:
            raise e

    # Setup is done. Now let's run the training
    trainer.run(loader, epochs)
Example #7
def run(conf: DictConfig):
    epochs = conf.train.epochs
    epoch_length = conf.train.epoch_length
    torch.manual_seed(conf.general.seed)

    dist_conf = conf.distributed
    local_rank = dist_conf.local_rank
    backend = dist_conf.backend
    distributed = backend is not None
    use_tpu = conf.tpu.enabled

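    # Resolve rank, replica count and device for TPU, multi-GPU distributed,
    # or single-GPU execution.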
    if use_tpu:
        rank = xm.get_ordinal()
        num_replicas = xm.xrt_world_size()
        device = xm.xla_device()
    else:
        if distributed:
            rank = dist.get_rank()
            num_replicas = dist.get_world_size()
            torch.cuda.set_device(local_rank)
        else:
            rank = 0
            num_replicas = 1
            torch.cuda.set_device(conf.general.gpu)
        device = torch.device('cuda')

    if rank == 0:
        print(conf.pretty())

    if num_replicas > 1:
        epoch_length = epoch_length // num_replicas
        loader_args = dict(rank=rank, num_replicas=num_replicas)
    else:
        loader_args = dict()

    train_dl = create_train_loader(conf.data.train,
                                   epoch_length=epoch_length,
                                   **loader_args)
    valid_dl = create_val_loader(conf.data.val, **loader_args)
    train_sampler = train_dl.sampler

    if epoch_length < 1:
        epoch_length = len(train_dl)

    if use_tpu:
        train_dl = pl.ParallelLoader(train_dl, [device])
        valid_dl = pl.ParallelLoader(valid_dl, [device])

    model = instantiate(conf.model).to(device)
    if distributed:
        model = DistributedDataParallel(model,
                                        device_ids=[
                                            local_rank,
                                        ],
                                        output_device=local_rank)
        model.to_y = model.module.to_y
    if rank == 0 and conf.logging.model:
        print(model)

    loss = instantiate(conf.loss)
    optim = instantiate(conf.optimizer,
                        filter(lambda x: x.requires_grad, model.parameters()))
    metrics = create_metrics(loss.keys(), device if distributed else None)
    build_trainer_fn = create_tpu_trainer if use_tpu else create_trainer
    trainer = build_trainer_fn(model, loss, optim, device, conf, metrics)
    evaluator = create_evaluator(model, loss, device, metrics)

    every_iteration = Events.ITERATION_COMPLETED

    if 'lr_scheduler' in conf.keys():
        # TODO: total_steps is wrong, it works only for one-cycle
        lr_scheduler = instantiate(conf.lr_scheduler,
                                   optim,
                                   total_steps=epoch_length)
        trainer.add_event_handler(every_iteration,
                                  lambda _: lr_scheduler.step())

        if isinstance(lr_scheduler, torch.optim.lr_scheduler.OneCycleLR):
            initial_state = lr_scheduler.state_dict()
            trainer.add_event_handler(
                Events.ITERATION_COMPLETED(every=epoch_length),
                lambda _: lr_scheduler.load_state_dict(initial_state))
    else:
        lr_scheduler = None

    trainer.add_event_handler(every_iteration, TerminateOnNan())

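    # Objects bundled into each checkpoint; the same dict is reused below to
    # restore state when resuming from `cp.load`.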
    cp = conf.train.checkpoints
    to_save = {
        'trainer': trainer,
        'model': model.module if distributed else model,
        'optimizer': optim,
        'lr_scheduler': lr_scheduler
    }
    save_path = cp.get('base_dir', os.getcwd())

    if rank == 0:
        log_freq = conf.logging.iter_freq
        log_event = Events.ITERATION_COMPLETED(every=log_freq)
        pbar = ProgressBar(persist=False)

        for engine, name in zip([trainer, evaluator], ['train', 'val']):
            engine.add_event_handler(Events.EPOCH_STARTED, on_epoch_start)
            engine.add_event_handler(log_event, log_iter, trainer, pbar, name,
                                     log_freq)
            engine.add_event_handler(Events.EPOCH_COMPLETED, log_epoch,
                                     trainer, name)
            pbar.attach(engine, metric_names=loss.keys())

        if 'load' in cp.keys() and cp.load:
            logging.info("Resume from a checkpoint: {}".format(cp.load))
            trainer.add_event_handler(Events.STARTED, _upd_pbar_iter_from_cp,
                                      pbar)

        logging.info("Saving checkpoints to {}".format(save_path))

    if rank == 0 or use_tpu:
        max_cp = max(int(cp.get('max_checkpoints', 1)), 1)
        Saver = TpuDiskSaver if use_tpu else DiskSaver
        save = Saver(save_path, create_dir=True, require_empty=True)
        make_checkpoint = Checkpoint(to_save, save, n_saved=max_cp)
        cp_iter = cp.interval_iteration
        cp_epoch = cp.interval_epoch
        if cp_iter > 0:
            save_event = Events.ITERATION_COMPLETED(every=cp_iter)
            trainer.add_event_handler(save_event, make_checkpoint)
        if cp_epoch > 0:
            if cp_iter < 1 or epoch_length % cp_iter:
                save_event = Events.EPOCH_COMPLETED(every=cp_epoch)
                trainer.add_event_handler(save_event, make_checkpoint)

    if 'load' in cp.keys() and cp.load:
        Checkpoint.load_objects(to_load=to_save,
                                checkpoint=torch.load(cp.load,
                                                      map_location=device))

    assert train_sampler is not None
    trainer.add_event_handler(
        Events.EPOCH_STARTED,
        lambda e: train_sampler.set_epoch(e.state.epoch - 1))

    def run_validation(e: Engine):
        if distributed:
            torch.cuda.synchronize(device)
        if use_tpu:
            xm.rendezvous('validate_{}'.format(e.state.iteration))
            valid_it = valid_dl.per_device_loader(device)
            evaluator.run(valid_it, epoch_length=len(valid_dl))
        else:
            evaluator.run(valid_dl)

    eval_event = Events.EPOCH_COMPLETED(every=conf.validate.interval)
    trainer.add_event_handler(eval_event, run_validation)

    try:
        if conf.train.skip:
            evaluator.run(valid_dl)
        else:
            loader = train_dl
            if use_tpu:
                # need to catch StopIteration before ignite, otherwise it will crash
                loader = iter(_regenerate(train_dl, device))
            trainer.run(loader, max_epochs=epochs, epoch_length=epoch_length)
    except Exception as e:
        import traceback
        print(traceback.format_exc())
    if rank == 0:
        pbar.close()
Example #8
def train():
    parser = ArgumentParser()
    parser.add_argument(
        "--dataset_path",
        type=str,
        default="",
        help="Path or url of the dataset. If empty download from S3.")
    parser.add_argument("--model_checkpoint",
                        type=str,
                        default="t5-small",
                        help="Path, url or short name of the model")
    parser.add_argument("--max_history",
                        type=int,
                        default=7,
                        help="Number of previous exchanges to keep in history")
    parser.add_argument("--train_batch_size",
                        type=int,
                        default=10,
                        help="Batch size for training")
    parser.add_argument("--valid_batch_size",
                        type=int,
                        default=10,
                        help="Batch size for validation")
    parser.add_argument("--gradient_accumulation_steps",
                        type=int,
                        default=12,
                        help="Accumulate gradients on several steps")
    parser.add_argument("--lr", type=float, default=6e-4, help="Learning rate")
    parser.add_argument("--max_norm",
                        type=float,
                        default=1.0,
                        help="Clipping gradient norm")
    parser.add_argument("--n_epochs",
                        type=int,
                        default=3,
                        help="Number of training epochs")
    parser.add_argument("--device",
                        type=str,
                        default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device (cuda or cpu)")
    parser.add_argument(
        "--fp16",
        type=str,
        default="",
        help=
        "Set to O0, O1, O2 or O3 for fp16 training (see apex documentation)")
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="Local rank for distributed training (-1: not distributed)")
    parser.add_argument("--save_name", type=str, default="")
    parser.add_argument("--mask_ratio", type=float, default=0.15)
    parser.add_argument("--objective",
                        type=str,
                        default="span_denosing",
                        help="response_generation, span_denosing, both")
    args = parser.parse_args()

    # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process. logger.info => log main process only, logger.warning => log all processes
    logging.basicConfig(
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning(
        "Running process %d", args.local_rank
    )  # This is a logger.warning: it will be printed by all distributed processes
    logger.info("Arguments: %s", pformat(args))

    # Initialize distributed training if needed
    args.distributed = (args.local_rank != -1)
    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        args.device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')

    logger.info("Prepare tokenizer, pretrained model and optimizer.")
    tokenizer = T5Tokenizer.from_pretrained(args.model_checkpoint)
    model = T5ForConditionalGeneration.from_pretrained(args.model_checkpoint)
    model.to(args.device)
    # Add special tokens if they are not already added
    add_special_tokens_(model, tokenizer)
    optimizer = AdamW(model.parameters(), lr=args.lr, correct_bias=True)

    # Prepare model for FP16 and distributed training if needed (order is important, distributed should be the last)
    if args.fp16:
        from apex import amp  # Apex is only required if we use fp16 training
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16)
    if args.distributed:
        model = DistributedDataParallel(model,
                                        device_ids=[args.local_rank],
                                        output_device=args.local_rank)

    def collate_fn(data):
        batch = {
            "corrupted_context": [],
            "context": [],
            "target": [],
            "response": []
        }
        padded_dataset = {}
        batch_size = len(data)
        resp_sos, context_sos = tokenizer.convert_tokens_to_ids([
            "<go_r>",
            "<go_b>",
        ])
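        # T5-style span corruption: sample a noise mask over the context words,
        # replace each masked span with a single sentinel token (<extra_id_N>)
        # in the corrupted context, and collect the masked words as the target.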
        for x in data:
            corrupted_context = ["fill : "]
            target = []
            length = len(x["context_words"])
            mask_bool = random_spans_noise_mask(length=length,
                                                noise_density=args.mask_ratio,
                                                mean_noise_span_length=3.0)
            mask_id = 0
            #print(mask_bool)
            for i in range(length):
                if mask_bool[i]:
                    if i > 0 and mask_bool[i - 1]:
                        target.append(x["context_words"][i])
                    else:
                        target.append(f"<extra_id_{mask_id}>")
                        target.append(x["context_words"][i])
                        corrupted_context.append(f"<extra_id_{mask_id}>")
                        mask_id += 1
                else:
                    corrupted_context.append(x["context_words"][i])
            target.append("<eos_b>")
            batch["context"].append(
                tokenizer.encode("response : " + " ".join(x["context_words"])))
            batch["corrupted_context"].append(
                tokenizer.encode(" ".join(corrupted_context)))
            batch["target"].append(tokenizer.encode(" ".join(target)))
            batch["response"].append(tokenizer.encode(x["response"]))
            # print(" ".join(x["context_words"]))
            # print(" ".join(corrupted_context))
            # print(" ".join(target))
            # print("")

            # print(tokenizer.decode(batch["corrupted_context"][-1]))
            # print(tokenizer.decode(batch["target"][-1]))
            # print(tokenizer.decode(batch["response"][-1]))
            # print("")
        context_ids, context_masks = padInput(batch["context"])
        input_ids, masks = padInput(batch["corrupted_context"])
        target_ids, target_inputs = padOutput(batch["target"])
        response_ids, response_inputs = padOutput(batch["response"])
        #inputs
        padded_dataset["input_ids"] = torch.tensor(input_ids, dtype=torch.long)
        padded_dataset["masks"] = torch.tensor(masks, dtype=torch.long)
        padded_dataset["context_ids"] = torch.tensor(context_ids,
                                                     dtype=torch.long)
        padded_dataset["context_masks"] = torch.tensor(context_masks,
                                                       dtype=torch.long)
        padded_dataset["target_ids"] = torch.tensor(target_ids,
                                                    dtype=torch.long)
        padded_dataset["response_ids"] = torch.tensor(response_ids,
                                                      dtype=torch.long)
        padded_dataset["target_inputs"] = torch.tensor(np.concatenate((np.ones(
            (batch_size, 1)) * context_sos, target_inputs[:, :-1]),
                                                                      axis=1),
                                                       dtype=torch.long)
        padded_dataset["response_inputs"] = torch.tensor(np.concatenate(
            (np.ones((batch_size, 1)) * resp_sos, response_inputs[:, :-1]),
            axis=1),
                                                         dtype=torch.long)

        return padded_dataset

    logger.info("Prepare datasets")
    train_dataset, valid_dataset, train_sampler, valid_sampler = get_data(
        args, tokenizer)

    train_loader = DataLoader(train_dataset,
                              sampler=train_sampler,
                              batch_size=args.train_batch_size,
                              shuffle=(not args.distributed),
                              collate_fn=collate_fn,
                              num_workers=4)
    val_loader = DataLoader(valid_dataset,
                            sampler=valid_sampler,
                            batch_size=args.valid_batch_size,
                            shuffle=False,
                            collate_fn=collate_fn,
                            num_workers=4)

    logger.info("Train dataset length: {}".format(len(train_dataset)))
    logger.info("Valid dataset length: {}".format(len(valid_dataset)))

    # for batch in train_loader:
    #     #print(batch)
    #     exit(0)
    # Training function and trainer
    def update(engine, batch):
        model.train()
        batch = tuple(batch[input_name].to(args.device)
                      for input_name in MODEL_INPUTS)
        input_ids, masks, context_ids, context_masks, target_ids, target_inputs, response_ids, response_inputs = batch
        # print("input")
        # print(tokenizer.decode(input_ids[0, :].tolist()))
        # print("context_ids")
        # print(tokenizer.decode(context_ids[0, :].tolist()))
        # print("target")
        # print(tokenizer.decode(target_ids[0, :].tolist()))
        # print("target In")
        # print(tokenizer.decode(target_inputs[0, :].tolist()))
        # print("response_ids")
        # print(tokenizer.decode(response_ids[0, :].tolist()))
        # print("response_inputs")
        # print(tokenizer.decode(response_inputs[0, :].tolist()))
        #exit(0)
        outputs = model(input_ids,
                        attention_mask=masks,
                        decoder_input_ids=target_inputs,
                        lm_labels=target_ids)
        context_loss = outputs[0]

        outputs = model(context_ids,
                        attention_mask=context_masks,
                        decoder_input_ids=response_inputs,
                        lm_labels=response_ids)

        resp_loss = outputs[0]

        loss = (context_loss + resp_loss) / args.gradient_accumulation_steps
        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                           args.max_norm)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
        if engine.state.iteration % args.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        return loss.item()

    trainer = Engine(update)

    # Evaluation function and evaluator (evaluator output is the input of the metrics)
    def inference(engine, batch):
        model.eval()
        with torch.no_grad():
            batch = tuple(batch[input_name].to(args.device)
                          for input_name in MODEL_INPUTS)
            input_ids, masks, context_ids, context_masks, target_ids, target_inputs, response_ids, response_inputs = batch

            outputs = model(
                input_ids,
                attention_mask=masks,
                decoder_input_ids=target_inputs  #, lm_labels=target_ids
            )

            context_logits = outputs[0]
            outputs = model(
                context_ids,
                attention_mask=context_masks,
                decoder_input_ids=response_inputs,
                #lm_labels=response_ids
            )
            resp_logits = outputs[0]

            context_logits_flat_shifted = context_logits.view(
                -1, context_logits.size(-1))
            context_labels_flat_shifted = target_ids.view(-1)

            resp_logits_flat_shifted = resp_logits.view(
                -1, resp_logits.size(-1))
            resp_labels_flat_shifted = response_ids.view(-1)

            return (context_logits_flat_shifted,
                    resp_logits_flat_shifted), (context_labels_flat_shifted,
                                                resp_labels_flat_shifted)
            #return (context_logits_flat_shifted, context_logits_flat_shifted), (context_labels_flat_shifted, context_labels_flat_shifted)

    evaluator = Engine(inference)

    # Attach evaluation to trainer: we evaluate when we start the training and at the end of each epoch
    trainer.add_event_handler(Events.EPOCH_COMPLETED,
                              lambda _: evaluator.run(val_loader))
    if args.n_epochs < 1:
        trainer.add_event_handler(Events.COMPLETED,
                                  lambda _: evaluator.run(val_loader))
    # if args.eval_before_start:
    #     trainer.add_event_handler(Events.STARTED, lambda _: evaluator.run(val_loader))

    # Make sure distributed data samplers split the dataset nicely between the distributed processes
    if args.distributed:
        trainer.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: train_sampler.set_epoch(engine.state.epoch))
        evaluator.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: valid_sampler.set_epoch(engine.state.epoch))

    # Linearly decrease the learning rate from lr to zero
    scheduler = PiecewiseLinear(optimizer, "lr",
                                [(0, args.lr),
                                 (args.n_epochs * len(train_loader), 0.0)])
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # Prepare metrics - note how we compute distributed metrics
    RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
    metrics = {
        "span":
        Loss(torch.nn.CrossEntropyLoss(ignore_index=-100),
             output_transform=lambda x: (x[0][0], x[1][0])),
        "response":
        Loss(torch.nn.CrossEntropyLoss(ignore_index=-100),
             output_transform=lambda x: (x[0][1], x[1][1]))
    }
    metrics.update({
        "average_span":
        MetricsLambda(average_distributed_scalar, metrics["span"], args),
        "average_response":
        MetricsLambda(average_distributed_scalar, metrics["response"], args)
    })
    metrics["average_response"] = MetricsLambda(math.exp,
                                                metrics["average_response"])
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # On the main process: add progress bar, tensorboard, checkpoints and save model, configuration and tokenizer before we start to train
    if args.local_rank in [-1, 0]:
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer, metric_names=["loss"])
        evaluator.add_event_handler(
            Events.COMPLETED, lambda _: pbar.log_message(
                "Validation: %s" % pformat(evaluator.state.metrics)))

        if not os.path.exists(f"pretrained_model/{args.save_name}"):
            os.makedirs(f"pretrained_model/{args.save_name}")
        log_dir = f"pretrained_model/{args.save_name}"
        tb_logger = TensorboardLogger(log_dir)

        tb_logger.attach(trainer,
                         log_handler=OutputHandler(tag="training",
                                                   metric_names=["loss"]),
                         event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer,
                         log_handler=OptimizerParamsHandler(optimizer),
                         event_name=Events.ITERATION_STARTED)
        tb_logger.attach(evaluator,
                         log_handler=OutputHandler(tag="validation",
                                                   metric_names=list(
                                                       metrics.keys()),
                                                   another_engine=trainer),
                         event_name=Events.EPOCH_COMPLETED)

        checkpoint_handler = ModelCheckpoint(log_dir,
                                             'checkpoint',
                                             save_interval=1,
                                             n_saved=3)
        trainer.add_event_handler(
            Events.EPOCH_COMPLETED, checkpoint_handler,
            {'mymodel': getattr(model, 'module', model)
             })  # "getattr" takes care of distributed encapsulation

        torch.save(args, log_dir + '/model_training_args.bin')
        getattr(model, 'module',
                model).config.to_json_file(os.path.join(log_dir, CONFIG_NAME))
        tokenizer.save_pretrained(log_dir)

    # Run the training
    trainer.run(train_loader, max_epochs=args.n_epochs)

    # On the main process: close tensorboard logger and rename the last checkpoint (for easy re-loading with OpenAIGPTModel.from_pretrained method)
    if args.local_rank in [-1, 0] and args.n_epochs > 0:
        os.rename(
            os.path.join(log_dir, checkpoint_handler._saved[-1][1]),
            os.path.join(log_dir, WEIGHTS_NAME)
        )  # TODO: PR in ignite to have better access to saved file paths (cleaner)
        tb_logger.close()
Example #9
    def run_once(self):

        log_dir = self.log_dir

        misc.check_manual_seed(self.seed)
        train_pairs, valid_pairs = dataset.prepare_data_CANCER()
        print(len(train_pairs))
        # --------------------------- Dataloader

        train_augmentors = self.train_augmentors()
        train_dataset = dataset.DatasetSerial(
            train_pairs[:],
            shape_augs=iaa.Sequential(train_augmentors[0]),
            input_augs=iaa.Sequential(train_augmentors[1]))

        infer_augmentors = self.infer_augmentors()
        infer_dataset = dataset.DatasetSerial(
            valid_pairs[:], shape_augs=iaa.Sequential(infer_augmentors))

        train_loader = data.DataLoader(train_dataset,
                                       num_workers=self.nr_procs_train,
                                       batch_size=self.train_batch_size,
                                       shuffle=True,
                                       drop_last=True)

        valid_loader = data.DataLoader(infer_dataset,
                                       num_workers=self.nr_procs_valid,
                                       batch_size=self.infer_batch_size,
                                       shuffle=True,
                                       drop_last=False)

        # --------------------------- Training Sequence

        if self.logging:
            misc.check_log_dir(log_dir)

        device = 'cuda'

        # networks
        input_chs = 3
        net = DenseNet(input_chs, self.nr_classes)
        net = torch.nn.DataParallel(net).to(device)
        # print(net)

        # optimizers
        optimizer = optim.Adam(net.parameters(), lr=self.init_lr)
        scheduler = optim.lr_scheduler.StepLR(optimizer, self.lr_steps)

        # load pre-trained models
        if self.load_network:
            saved_state = torch.load(self.save_net_path)
            net.load_state_dict(saved_state)
        #
        trainer = Engine(lambda engine, batch: self.train_step(
            net, batch, optimizer, 'cuda'))
        inferer = Engine(
            lambda engine, batch: self.infer_step(net, batch, 'cuda'))

        train_output = ['loss', 'acc']
        infer_output = ['prob', 'true']
        ##

        if self.logging:
            checkpoint_handler = ModelCheckpoint(log_dir,
                                                 self.chkpts_prefix,
                                                 save_interval=1,
                                                 n_saved=120,
                                                 require_empty=False)
            # adding handlers using `trainer.add_event_handler` method API
            trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                                      handler=checkpoint_handler,
                                      to_save={'net': net})

        timer = Timer(average=True)
        timer.attach(trainer,
                     start=Events.EPOCH_STARTED,
                     resume=Events.ITERATION_STARTED,
                     pause=Events.ITERATION_COMPLETED,
                     step=Events.ITERATION_COMPLETED)
        timer.attach(inferer,
                     start=Events.EPOCH_STARTED,
                     resume=Events.ITERATION_STARTED,
                     pause=Events.ITERATION_COMPLETED,
                     step=Events.ITERATION_COMPLETED)

        # attach running average metrics computation
        # decay of EMA to 0.95 to match tensorpack default
        RunningAverage(alpha=0.95,
                       output_transform=lambda x: x['loss']).attach(
                           trainer, 'loss')
        RunningAverage(alpha=0.95, output_transform=lambda x: x['acc']).attach(
            trainer, 'acc')

        # attach progress bar
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer, metric_names=['loss'])
        pbar.attach(inferer)

        # adding handlers using `trainer.on` decorator API
        @trainer.on(Events.EXCEPTION_RAISED)
        def handle_exception(engine, e):
            if isinstance(e,
                          KeyboardInterrupt) and (engine.state.iteration > 1):
                engine.terminate()
                warnings.warn('KeyboardInterrupt caught. Exiting gracefully.')
                checkpoint_handler(engine, {'net_exception': net})
            else:
                raise e

        # writer for tensorboard logging
        if self.logging:
            writer = SummaryWriter(log_dir=log_dir)
            json_log_file = log_dir + '/stats.json'
            with open(json_log_file, 'w') as json_file:
                json.dump({}, json_file)  # create empty file

        @trainer.on(Events.EPOCH_STARTED)
        def log_lrs(engine):
            if self.logging:
                lr = float(optimizer.param_groups[0]['lr'])
                writer.add_scalar("lr", lr, engine.state.epoch)
            # advance scheduler clock
            scheduler.step()

        ####
        def update_logs(output, epoch, prefix, color):
            # pretty-print the metric values; if logging is enabled, also persist them to the json log and tensorboard
            max_length = len(max(output.keys(), key=len))
            for metric in output:
                key = colored(prefix + '-' + metric.ljust(max_length), color)
                print('------%s : ' % key, end='')
                print('%0.7f' % output[metric])
            if 'train' in prefix:
                lr = float(optimizer.param_groups[0]['lr'])
                key = colored(prefix + '-' + 'lr'.ljust(max_length), color)
                print('------%s : %0.7f' % (key, lr))

            if not self.logging:
                return

            # create stat dicts
            stat_dict = {}
            for metric in output:
                metric_value = output[metric]
                stat_dict['%s-%s' % (prefix, metric)] = metric_value

            # json stat log file, update and overwrite
            with open(json_log_file) as json_file:
                json_data = json.load(json_file)

            current_epoch = str(epoch)
            if current_epoch in json_data:
                old_stat_dict = json_data[current_epoch]
                stat_dict.update(old_stat_dict)
            current_epoch_dict = {current_epoch: stat_dict}
            json_data.update(current_epoch_dict)

            with open(json_log_file, 'w') as json_file:
                json.dump(json_data, json_file)

            # log values to tensorboard
            for metric in output:
                writer.add_scalar(prefix + '-' + metric, output[metric],
                                  current_epoch)

        @trainer.on(Events.EPOCH_COMPLETED)
        def log_train_running_results(engine):
            """
            running training measurement
            """
            training_ema_output = engine.state.metrics  #
            update_logs(training_ema_output,
                        engine.state.epoch,
                        prefix='train-ema',
                        color='green')

        ####
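        # helper: accumulator with one empty list per output name, filled by accumulate_outputs below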
        def get_init_accumulator(output_names):
            return {metric: [] for metric in output_names}

        import cv2

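        # merge the per-batch 'prob'/'true' lists into flat arrays and compute accuracy and Dice over the non-ignored entries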
        def process_accumulated_output(output):
            def uneven_seq_to_np(seq, batch_size=self.infer_batch_size):
                if self.infer_batch_size == 1:
                    return np.squeeze(seq)

                item_count = batch_size * (len(seq) - 1) + len(seq[-1])
                cat_array = np.zeros((item_count, ) + seq[0][0].shape,
                                     seq[0].dtype)
                for idx in range(0, len(seq) - 1):
                    cat_array[idx * batch_size:(idx + 1) *
                              batch_size] = seq[idx]
                cat_array[(idx + 1) * batch_size:] = seq[-1]
                return cat_array

            #
            prob = uneven_seq_to_np(output['prob'])
            true = uneven_seq_to_np(output['true'])

            # cmap = plt.get_cmap('jet')
            # epi = prob[...,1]
            # epi = (cmap(epi) * 255.0).astype('uint8')
            # cv2.imwrite('sample.png', cv2.cvtColor(epi, cv2.COLOR_RGB2BGR))

            pred = np.argmax(prob, axis=-1)
            true = np.squeeze(true)

            # deal with ignore index
            pred = pred.flatten()
            true = true.flatten()
            pred = pred[true != 0] - 1
            true = true[true != 0] - 1

            acc = np.mean(pred == true)
            inter = (pred * true).sum()
            total = (pred + true).sum()
            dice = 2 * inter / total
            #
            proc_output = dict(acc=acc, dice=dice)
            return proc_output

        # @trainer.on(Events.EPOCH_COMPLETED)
        # def infer_valid(engine):
        #     """
        #     inference measurement
        #     """
        #     inferer.accumulator = get_init_accumulator(infer_output)
        #     inferer.run(valid_loader)
        #     output_stat = process_accumulated_output(inferer.accumulator)
        #     update_logs(output_stat, engine.state.epoch, prefix='valid', color='red')

        @inferer.on(Events.ITERATION_COMPLETED)
        def accumulate_outputs(engine):
            batch_output = engine.state.output
            for key, item in batch_output.items():
                engine.accumulator[key].extend([item])

        ###
        # Setup is done. Now let's run the training
        trainer.run(train_loader, self.nr_epochs)
        return
def finetune_model(args, model, loader):
    optimizer = AdamW(model.parameters(), lr=args.lr, correct_bias=True)

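    # training step: forward pass, scale the loss for gradient accumulation, clip gradients, and step the optimizer every gradient_accumulation_steps iterations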
    def update(engine, batch):
        model.train()
        batch = tuple(batch[input_name].to(args.device)
                      for input_name in MODEL_INPUTS)
        input_ids, lm_labels, token_type_ids, nodes_ids, attention_mask = batch
        if (not args.graph and not args.edge_list):
            nodes_ids = None
        if (not args.unilm): attention_mask = None
        (lm_loss), *_ = model(input_ids=input_ids,
                              token_type_ids=token_type_ids,
                              labels=lm_labels,
                              nodes=nodes_ids,
                              attention_mask=attention_mask)
        loss = lm_loss / args.gradient_accumulation_steps

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
        if engine.state.iteration % args.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        return loss.item()

    # Evaluation function and evaluator (evaluator output is the input of the metrics)
    def inference(engine, batch):
        model.eval()
        with torch.no_grad():
            batch = tuple(batch[input_name].to(args.device)
                          for input_name in MODEL_INPUTS)
            input_ids, lm_labels, token_type_ids, nodes_ids, attention_mask = batch
            if (not args.graph and not args.edge_list):
                nodes_ids = None
            if (not args.unilm): attention_mask = None
            # if we don't send labels to the model, it doesn't return losses
            lm_logits, *_ = model(input_ids=input_ids,
                                  token_type_ids=token_type_ids,
                                  nodes=nodes_ids,
                                  attention_mask=attention_mask)
            lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view(
                -1, lm_logits.size(-1))
            lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1)
            return (lm_logits_flat_shifted, ), (lm_labels_flat_shifted, )

    trainer = Engine(update)
    evaluator = Engine(inference)

    trainer.add_event_handler(Events.EPOCH_COMPLETED,
                              lambda _: evaluator.run(loader))
    if args.n_epochs < 1:
        trainer.add_event_handler(Events.COMPLETED,
                                  lambda _: evaluator.run(loader))
    if args.eval_before_start:
        trainer.add_event_handler(Events.STARTED,
                                  lambda _: evaluator.run(loader))

    # Linearly decrease the learning rate from lr to zero
    scheduler = PiecewiseLinear(optimizer, "lr",
                                [(0, args.lr),
                                 (args.n_epochs * len(loader), 0.0)])
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # Prepare metrics - note how we compute distributed metrics
    RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
    metrics = {
        "nll":
        Loss(torch.nn.CrossEntropyLoss(),
             output_transform=lambda x: (x[0][0], x[1][0]))
    }
    metrics.update({
        "average_nll":
        MetricsLambda(average_distributed_scalar, metrics["nll"], args)
    })
    metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"])
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    pbar = ProgressBar(persist=True)
    pbar.attach(trainer, metric_names=["loss"])
    evaluator.add_event_handler(
        Events.COMPLETED, lambda _: pbar.log_message(
            "Validation: %s" % pformat(evaluator.state.metrics)))
    trainer.run(loader, max_epochs=args.n_epochs)
    return model
Beispiel #11
0
#evaluator.add_event_handler(Events.COMPLETED, es_handler)
#setup_logger(es_handler._logger)


# Clear cuda cache between training/testing
def empty_cuda_cache(engine):
    torch.cuda.empty_cache()
    import gc
    gc.collect()


trainer.add_event_handler(Events.EPOCH_COMPLETED, empty_cuda_cache)
evaluator.add_event_handler(Events.COMPLETED, empty_cuda_cache)
#train_evaluator.add_event_handler(Events.COMPLETED, empty_cuda_cache)

num_epochs = 80

ProgressBar(persist=True).attach(trainer)

trainer.run(train_loader, max_epochs=num_epochs)

print('The results are')
print(train_evaluator.state.metrics)
print(evaluator.state.metrics)

# Dill routine: serialize the full model object and the final metrics to disk

model_copy = dill.dumps(model)
torch.save(model_copy, 'complete_model_final.pt')
torch.save(train_evaluator.state.metrics, 'metrics_final.pt')
Beispiel #12
0
    def setup_training(self, base_model, classifier, setops_model):

        #
        # Create the train and test dataset.
        #
        train_loader, train_subset_loader, val_loader = self.setup_datasets()

        logging.info("Setup logging and controls.")

        #
        # Setup metrics plotters.
        #
        mlflow_logger = MlflowLogger()

        #
        # Setup the optimizer.
        #
        logging.info("Setup optimizers and losses.")

        parameters = list(base_model.parameters())
        parameters += list(setops_model.parameters())
        if self.train_classifier:
            parameters += list(classifier.parameters())

        if self.optimizer_cls == "SGD":
            optimizer = torch.optim.SGD(parameters,
                                        lr=self.lr1,
                                        momentum=0.9,
                                        weight_decay=self.weight_decay)
        else:
            optimizer = torch.optim.Adam(parameters,
                                         lr=self.lr1,
                                         weight_decay=self.weight_decay)

        if self.focal_loss:
            attr_loss = FocalLoss().cuda()
        else:
            attr_loss = torch.nn.MultiLabelSoftMarginLoss().cuda()

        recon_loss = torch.nn.MSELoss(
        ) if self.recon_loss == "mse" else torch.nn.L1Loss()

        #
        # Setup the trainer object and its logging.
        #
        logging.info("Setup trainer")
        trainer = create_setops_trainer(base_model,
                                        classifier,
                                        setops_model,
                                        optimizer,
                                        criterion1=attr_loss,
                                        criterion2=recon_loss.cuda(),
                                        params_object=self,
                                        device=self.device)
        ProgressBar(bar_format=None).attach(trainer)

        mlflow_logger.attach(engine=trainer,
                             prefix="Train ",
                             plot_event=Events.ITERATION_COMPLETED,
                             update_period=LOG_INTERVAL,
                             output_transform=lambda x: x)

        #
        # Define the evaluation metrics.
        #
        logging.info("Setup evaluator")
        evaluation_losses = {
            'real class loss':
                Loss(torch.nn.MultiLabelSoftMarginLoss().cuda(), lambda o: (o["outputs"]["real class a"], o["targets"]["class a"])) + \
                Loss(torch.nn.MultiLabelSoftMarginLoss().cuda(), lambda o: (o["outputs"]["real class b"], o["targets"]["class b"])),
            'fake class loss':
                Loss(torch.nn.MultiLabelSoftMarginLoss().cuda(), lambda o: (o["outputs"]["fake class a"], o["targets"]["class a"])) + \
                Loss(torch.nn.MultiLabelSoftMarginLoss().cuda(), lambda o: (o["outputs"]["fake class b"], o["targets"]["class b"])),
            '{} fake loss'.format(self.recon_loss):
                (Loss(recon_loss.cuda(), lambda o: (o["outputs"]["fake embed a"], o["targets"]["embed a"])) +
                Loss(recon_loss.cuda(), lambda o: (o["outputs"]["fake embed b"], o["targets"]["embed b"]))) / 2,
        }
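        # label mask handed to the mAP metrics, derived from the training label list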
        labels_list = train_loader.dataset.labels_list
        mask = labels_list_to_1hot(labels_list, labels_list).astype(bool)  # np.bool is removed in recent NumPy
        evaluation_accuracies = {
            'real class acc':
            (MultiLabelSoftMarginIOUaccuracy(lambda o: (o["outputs"][
                "real class a"], o["targets"]["class a"])) +
             MultiLabelSoftMarginIOUaccuracy(lambda o: (o["outputs"][
                 "real class b"], o["targets"]["class b"]))) / 2,
            'fake class acc':
            (MultiLabelSoftMarginIOUaccuracy(lambda o: (o["outputs"][
                "fake class a"], o["targets"]["class a"])) +
             MultiLabelSoftMarginIOUaccuracy(lambda o: (o["outputs"][
                 "fake class b"], o["targets"]["class b"]))) / 2,
            'S class acc':
            (MultiLabelSoftMarginIOUaccuracy(lambda o: (o["outputs"][
                "a_S_b class"], o["targets"]["a_S_b class"])) +
             MultiLabelSoftMarginIOUaccuracy(lambda o: (o["outputs"][
                 "b_S_a class"], o["targets"]["b_S_a class"]))) / 2,
            'I class acc':
            (MultiLabelSoftMarginIOUaccuracy(lambda o: (o["outputs"][
                "a_I_b class"], o["targets"]["a_I_b class"])) +
             MultiLabelSoftMarginIOUaccuracy(lambda o: (o["outputs"][
                 "b_I_a class"], o["targets"]["a_I_b class"]))) / 2,
            'U class acc':
            (MultiLabelSoftMarginIOUaccuracy(lambda o: (o["outputs"][
                "a_U_b class"], o["targets"]["a_U_b class"])) +
             MultiLabelSoftMarginIOUaccuracy(lambda o: (o["outputs"][
                 "b_U_a class"], o["targets"]["a_U_b class"]))) / 2,
            'MSE fake acc':
            (EWMeanSquaredError(lambda o: (o["outputs"]["fake embed a"], o[
                "targets"]["embed a"])) + EWMeanSquaredError(lambda o: (o[
                    "outputs"]["fake embed b"], o["targets"]["embed b"]))) / 2,
            'real mAP':
            mAP(mask=mask,
                output_transform=lambda o:
                (o["outputs"]["real class a"], o["targets"]["class a"])),
            'fake mAP':
            mAP(mask=mask,
                output_transform=lambda o:
                (o["outputs"]["fake class a"], o["targets"]["class a"])),
            'S mAP':
            mAP(mask=mask,
                output_transform=lambda o:
                (o["outputs"]["a_S_b class"], o["targets"]["a_S_b class"])),
            'I mAP':
            mAP(mask=mask,
                output_transform=lambda o:
                (o["outputs"]["a_I_b class"], o["targets"]["a_I_b class"])),
            'U mAP':
            mAP(mask=mask,
                output_transform=lambda o:
                (o["outputs"]["a_U_b class"], o["targets"]["a_U_b class"])),
        }

        #
        # Setup the training evaluator object and its logging.
        #
        train_evaluator = create_setops_evaluator(
            base_model,
            classifier,
            setops_model,
            metrics=evaluation_accuracies.copy(),
            device=self.device)

        mlflow_logger.attach(engine=train_evaluator,
                             prefix="Train Eval ",
                             plot_event=Events.EPOCH_COMPLETED,
                             metric_names=list(evaluation_accuracies.keys()))
        ProgressBar(bar_format=None).attach(train_evaluator)

        #
        # Setup the evaluator object and its logging.
        #
        evaluator = create_setops_evaluator(base_model,
                                            classifier,
                                            setops_model,
                                            metrics={
                                                **evaluation_losses,
                                                **evaluation_accuracies
                                            },
                                            device=self.device)

        mlflow_logger.attach(engine=evaluator,
                             prefix="Eval ",
                             plot_event=Events.EPOCH_COMPLETED,
                             metric_names=list({
                                 **evaluation_losses,
                                 **evaluation_accuracies
                             }.keys()))
        ProgressBar(bar_format=None).attach(evaluator)

        #
        # Checkpoint of the model
        #
        self.setup_checkpoint(base_model, classifier, setops_model, evaluator)

        logging.info("Setup schedulers.")

        #
        # Update learning rate manually using the Visdom interface.
        #
        one_cycle_size = len(train_loader) * self.warmup_epochs * 2

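        # ramp the LR linearly from lr1 to lr2 over half a cycle, then switch to ReduceLROnPlateau driven by the 'main' output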
        scheduler_1 = LinearCyclicalScheduler(optimizer,
                                              "lr",
                                              start_value=self.lr1,
                                              end_value=self.lr2,
                                              cycle_size=one_cycle_size)
        scheduler_2 = ReduceLROnPlateau(optimizer,
                                        factor=0.5,
                                        patience=4 * len(train_loader),
                                        cooldown=len(train_loader),
                                        output_transform=lambda x: x["main"])
        lr_scheduler = ConcatScheduler(schedulers=[scheduler_1, scheduler_2],
                                       durations=[one_cycle_size // 2],
                                       save_history=True)
        trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler)

        #
        # Evaluation
        #
        @trainer.on(Events.EPOCH_COMPLETED)
        def epoch_completed(engine):
            #
            # Re-randomize the indices of the training dataset.
            #
            train_loader.dataset.calc_indices()

            #
            # Run the evaluator on a subset of the training dataset.
            #
            logging.info("Evaluation on a subset of the training data.")
            train_evaluator.run(train_subset_loader)

            #
            # Run the evaluator on the validation set.
            #
            logging.info("Evaluation on the eval data.")
            evaluator.run(val_loader)

        return trainer, train_loader
Beispiel #13
0
def main(
    dataset,
    dataroot,
    download,
    augment,
    batch_size,
    eval_batch_size,
    epochs,
    saved_model,
    seed,
    hidden_channels,
    K,
    L,
    actnorm_scale,
    flow_permutation,
    flow_coupling,
    LU_decomposed,
    learn_top,
    y_condition,
    y_weight,
    max_grad_clip,
    max_grad_norm,
    lr,
    n_workers,
    cuda,
    n_init_batches,
    output_dir,
    saved_optimizer,
    warmup,
):

    device = "cpu" if (not torch.cuda.is_available() or not cuda) else "cuda:0"

    check_manual_seed(seed)

    ds = check_dataset(dataset, dataroot, augment, download)
    image_shape, num_classes, train_dataset, test_dataset = ds

    # Note: unsupported for now
    multi_class = False

    train_loader = data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=n_workers,
        drop_last=True,
    )
    test_loader = data.DataLoader(
        test_dataset,
        batch_size=eval_batch_size,
        shuffle=False,
        num_workers=n_workers,
        drop_last=False,
    )

    model = Glow(
        image_shape,
        hidden_channels,
        K,
        L,
        actnorm_scale,
        flow_permutation,
        flow_coupling,
        LU_decomposed,
        num_classes,
        learn_top,
        y_condition,
    )

    model = model.to(device)
    optimizer = optim.Adamax(model.parameters(), lr=lr, weight_decay=5e-5)

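    # LambdaLR multiplier: linear LR warm-up over the first `warmup` epochs (scheduler.step() is called once per epoch after evaluation)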
    lr_lambda = lambda epoch: min(1.0, (epoch + 1) / warmup)  # noqa
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                  lr_lambda=lr_lambda)

    def step(engine, batch):
        model.train()
        optimizer.zero_grad()

        x, y = batch
        x = x.to(device)

        if y_condition:
            y = y.to(device)
            z, nll, y_logits = model(x, y)
            losses = compute_loss_y(nll, y_logits, y_weight, y, multi_class)
        else:
            z, nll, y_logits = model(x, None)
            losses = compute_loss(nll)

        losses["total_loss"].backward()

        if max_grad_clip > 0:
            torch.nn.utils.clip_grad_value_(model.parameters(), max_grad_clip)
        if max_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        optimizer.step()

        return losses

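    # evaluation step: forward pass only; losses are kept unreduced (reduction="none") for the Loss metric attached below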
    def eval_step(engine, batch):
        model.eval()

        x, y = batch
        x = x.to(device)

        with torch.no_grad():
            if y_condition:
                y = y.to(device)
                z, nll, y_logits = model(x, y)
                losses = compute_loss_y(nll,
                                        y_logits,
                                        y_weight,
                                        y,
                                        multi_class,
                                        reduction="none")
            else:
                z, nll, y_logits = model(x, None)
                losses = compute_loss(nll, reduction="none")

        return losses

    trainer = Engine(step)
    checkpoint_handler = ModelCheckpoint(output_dir,
                                         "glow",
                                         n_saved=2,
                                         require_empty=False)

    trainer.add_event_handler(
        Events.EPOCH_COMPLETED,
        checkpoint_handler,
        {
            "model": model,
            "optimizer": optimizer
        },
    )

    monitoring_metrics = ["total_loss"]
    RunningAverage(output_transform=lambda x: x["total_loss"]).attach(
        trainer, "total_loss")

    evaluator = Engine(eval_step)

    # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
    Loss(
        lambda x, y: torch.mean(x),
        output_transform=lambda x: (
            x["total_loss"],
            torch.empty(x["total_loss"].shape[0]),
        ),
    ).attach(evaluator, "total_loss")

    if y_condition:
        monitoring_metrics.extend(["nll"])
        RunningAverage(output_transform=lambda x: x["nll"]).attach(
            trainer, "nll")

        # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
        Loss(
            lambda x, y: torch.mean(x),
            output_transform=lambda x:
            (x["nll"], torch.empty(x["nll"].shape[0])),
        ).attach(evaluator, "nll")

    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=monitoring_metrics)

    # load pre-trained model if given
    if saved_model:
        model.load_state_dict(torch.load(saved_model))
        model.set_actnorm_init()

        if saved_optimizer:
            optimizer.load_state_dict(torch.load(saved_optimizer))

        file_name, ext = os.path.splitext(saved_model)
        resume_epoch = int(file_name.split("_")[-1])

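        # restore the epoch/iteration counters so training resumes from the checkpointed epoch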
        @trainer.on(Events.STARTED)
        def resume_training(engine):
            engine.state.epoch = resume_epoch
            engine.state.iteration = resume_epoch * len(
                engine.state.dataloader)

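    # before training, push a few batches through the model once for data-dependent initialization (e.g. ActNorm)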
    @trainer.on(Events.STARTED)
    def init(engine):
        model.train()

        init_batches = []
        init_targets = []

        with torch.no_grad():
            for batch, target in islice(train_loader, None, n_init_batches):
                init_batches.append(batch)
                init_targets.append(target)

            init_batches = torch.cat(init_batches).to(device)

            assert init_batches.shape[0] == n_init_batches * batch_size

            if y_condition:
                init_targets = torch.cat(init_targets).to(device)
            else:
                init_targets = None

            model(init_batches, init_targets)

    @trainer.on(Events.EPOCH_COMPLETED)
    def evaluate(engine):
        evaluator.run(test_loader)

        scheduler.step()
        metrics = evaluator.state.metrics

        losses = ", ".join(
            [f"{key}: {value:.2f}" for key, value in metrics.items()])

        print(f"Validation Results - Epoch: {engine.state.epoch} {losses}")

    timer = Timer(average=True)
    timer.attach(
        trainer,
        start=Events.EPOCH_STARTED,
        resume=Events.ITERATION_STARTED,
        pause=Events.ITERATION_COMPLETED,
        step=Events.ITERATION_COMPLETED,
    )

    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        pbar.log_message(
            f"Epoch {engine.state.epoch} done. Time per batch: {timer.value():.3f}[s]"
        )
        timer.reset()

    trainer.run(train_loader, epochs)
def train(dataset_path,
          dataset_cache='./dataset_cache',
          model_checkpoint='gpt2',
          num_candidates=2,
          max_history=2,
          train_batch_size=4,
          valid_batch_size=4,
          gradient_accumulation_steps=8,
          lr=6.25e-5,
          lm_coef=1.0,
          mc_coef=1.0,
          max_norm=1.0,
          n_epochs=3,
          personality_permutations=1,
          eval_before_start=False,
          device="cuda" if torch.cuda.is_available() else "cpu",
          fp16='',
          path_prefix='',
          log_dir='',
          local_rank=-1):
    args = {**locals()}

    # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process. logger.info => log main process only, logger.warning => log all processes
    logging.basicConfig(
        level=logging.INFO if local_rank in [-1, 0] else logging.WARN)
    # This is a logger.warning: it will be printed by all distributed processes
    logger.warning("Running process %d", local_rank)
    logger.info("Arguments: %s", pformat(args))

    # Initialize distributed training if needed
    distributed = (local_rank != -1)
    args['distributed'] = distributed

    if distributed:
        torch.cuda.set_device(local_rank)
        device = torch.device("cuda", local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')

    logger.info("Prepare tokenizer, pretrained model and optimizer.")
    # can't use AutoTokenizer because the checkpoint could be a Path
    tokenizer_class = GPT2Tokenizer if "gpt2" in model_checkpoint else OpenAIGPTTokenizer
    tokenizer = tokenizer_class.from_pretrained(model_checkpoint)

    model_class = GPT2DoubleHeadsModel if "gpt2" in model_checkpoint else OpenAIGPTDoubleHeadsModel
    model = model_class.from_pretrained(model_checkpoint)
    model.to(device)
    # Add special tokens if they are not already added
    add_special_tokens_(model, tokenizer)
    optimizer = AdamW(model.parameters(), lr=lr, correct_bias=True)

    # Prepare model for FP16 and distributed training if needed (order is important, distributed should be the last)
    if fp16:
        from apex import amp  # Apex is only required if we use fp16 training
        model, optimizer = amp.initialize(model, optimizer, opt_level=fp16)
    if distributed:
        model = DistributedDataParallel(model,
                                        device_ids=[local_rank],
                                        output_device=local_rank)

    logger.info("Prepare datasets")
    train_loader, val_loader, train_sampler, valid_sampler = get_data_loaders(
        dataset_path, dataset_cache, num_candidates, personality_permutations,
        max_history, train_batch_size, valid_batch_size, distributed,
        tokenizer)

    # Training function and trainer
    def update(engine, batch):
        model.train()
        batch = tuple(input_tensor.to(device) for input_tensor in batch)
        input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids = batch
        (lm_loss), (mc_loss), *_ = model(input_ids,
                                         token_type_ids=token_type_ids,
                                         mc_token_ids=mc_token_ids,
                                         mc_labels=mc_labels,
                                         lm_labels=lm_labels)
        loss = (lm_loss * lm_coef + mc_loss * mc_coef) / \
            gradient_accumulation_steps
        if fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                           max_norm)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        if engine.state.iteration % gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        return loss.item()

    trainer = Engine(update)

    # Evaluation function and evaluator (evaluator output is the input of the metrics)
    def inference(engine, batch):
        model.eval()
        with torch.no_grad():
            batch = tuple(input_tensor.to(device) for input_tensor in batch)
            input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids = batch
            logger.info(tokenizer.decode(input_ids[0, -1, :].tolist()))
            # if we don't send labels to the model, it doesn't return losses
            lm_logits, mc_logits, *_ = model(
                input_ids,
                token_type_ids=token_type_ids,
                mc_token_ids=mc_token_ids,
            )
            lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view(
                -1, lm_logits.size(-1))
            lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1)
            return (lm_logits_flat_shifted,
                    mc_logits), (lm_labels_flat_shifted, mc_labels)

    evaluator = Engine(inference)

    # Attach evaluation to trainer: we evaluate when we start the training and at the end of each epoch
    trainer.add_event_handler(Events.EPOCH_COMPLETED,
                              lambda _: evaluator.run(val_loader))
    if n_epochs < 1:
        trainer.add_event_handler(Events.COMPLETED,
                                  lambda _: evaluator.run(val_loader))
    if eval_before_start:
        trainer.add_event_handler(Events.STARTED,
                                  lambda _: evaluator.run(val_loader))

    # Make sure distributed data samplers split the dataset nicely between the distributed processes
    if distributed:
        trainer.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: train_sampler.set_epoch(engine.state.epoch))
        evaluator.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: valid_sampler.set_epoch(engine.state.epoch))

    # Linearly decrease the learning rate from lr to zero
    scheduler = PiecewiseLinear(optimizer, "lr",
                                [(0, lr), (n_epochs * len(train_loader), 0.0)])
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # Prepare metrics - note how we compute distributed metrics
    RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
    metrics = {
        "nll":
        Loss(torch.nn.CrossEntropyLoss(ignore_index=-1),
             output_transform=lambda x: (x[0][0], x[1][0])),
        "accuracy":
        Accuracy(output_transform=lambda x: (x[0][1], x[1][1]))
    }
    metrics.update({
        "average_nll":
        MetricsLambda(average_distributed_scalar, metrics["nll"], local_rank,
                      device),
        "average_accuracy":
        MetricsLambda(average_distributed_scalar, metrics["accuracy"],
                      local_rank, device)
    })
    metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"])
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # On the main process: add progress bar, tensorboard, checkpoints and save model, configuration and tokenizer before we start to train
    if local_rank in [-1, 0]:
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer, metric_names=["loss"])
        evaluator.add_event_handler(
            Events.COMPLETED, lambda _: pbar.log_message(
                "Validation: %s" % pformat(evaluator.state.metrics)))

        log_dir = log_dir if log_dir else make_logdir(model_checkpoint,
                                                      path=path_prefix)
        tb_logger = TensorboardLogger(log_dir)

        tb_logger.attach(trainer,
                         log_handler=OutputHandler(tag="training",
                                                   metric_names=["loss"]),
                         event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer,
                         log_handler=OptimizerParamsHandler(optimizer),
                         event_name=Events.ITERATION_STARTED)
        tb_logger.attach(evaluator,
                         log_handler=OutputHandler(tag="validation",
                                                   metric_names=list(
                                                       metrics.keys()),
                                                   another_engine=trainer),
                         event_name=Events.EPOCH_COMPLETED)

        checkpoint_handler = ModelCheckpoint(log_dir,
                                             'checkpoint',
                                             save_interval=1,
                                             n_saved=3)
        trainer.add_event_handler(
            Events.EPOCH_COMPLETED, checkpoint_handler,
            {'mymodel': getattr(model, 'module', model)
             })  # "getattr" takes care of distributed encapsulation

        torch.save(args, log_dir + '/model_training_args.bin')
        getattr(model, 'module',
                model).config.to_json_file(os.path.join(log_dir, CONFIG_NAME))
        tokenizer.save_pretrained(log_dir)

    # Run the training
    trainer.run(train_loader, max_epochs=n_epochs)

    # On the main process: close tensorboard logger and rename the last checkpoint (for easy re-loading with OpenAIGPTModel.from_pretrained method)
    if local_rank in [-1, 0] and n_epochs > 0:
        # TODO: PR in ignite to have better access to saved file paths (cleaner)
        os.rename(checkpoint_handler._saved[-1][1][-1],
                  os.path.join(log_dir, WEIGHTS_NAME))
        tb_logger.close()
Beispiel #15
0
def train():
    args = parser.parse_args()

    # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process. logger.info => log main process only, logger.warning => log all processes
    logging.basicConfig(
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning(
        "Running process %d", args.local_rank
    )  # This is a logger.warning: it will be printed by all distributed processes
    logger.info("Arguments: %s", pformat(args))

    # Initialize distributed training if needed
    args.distributed = (args.local_rank != -1)
    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        args.device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')

    logger.info(
        "Prepare tokenizer, pretrained model and optimizer - add special tokens for fine-tuning"
    )
    tokenizer_class = GPT2Tokenizer if "gpt2" in args.model_checkpoint else OpenAIGPTTokenizer
    tokenizer = tokenizer_class.from_pretrained(args.model_checkpoint)

    model_class = GPT2LMHeadModel if "gpt2" in args.model_checkpoint else OpenAIGPTLMHeadModel
    model = model_class.from_pretrained(args.model_checkpoint)
    tokenizer.set_special_tokens(SPECIAL_TOKENS)
    model.set_num_special_tokens(len(SPECIAL_TOKENS))
    model.to(args.device)
    optimizer = OpenAIAdam(model.parameters(), lr=args.lr)

    # Prepare model for FP16 and distributed training if needed (order is important, distributed should be the last)
    if args.fp16:
        from apex import amp  # Apex is only required if we use fp16 training
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16)
    if args.distributed:
        model = DistributedDataParallel(model,
                                        device_ids=[args.local_rank],
                                        output_device=args.local_rank)

    logger.info("Prepare datasets")
    train_loader, val_loader, train_sampler, valid_sampler = get_data_loaders(
        args, tokenizer)

    # Training function and trainer
    def update(engine, batch):
        model.train()
        batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
        lm_loss = model(*batch)
        loss = lm_loss / args.gradient_accumulation_steps
        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                           args.max_norm)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
        if engine.state.iteration % args.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        return loss.item()

    trainer = Engine(update)

    # Evaluation function and evaluator (evaluator output is the input of the metrics)
    def inference(engine, batch):
        model.eval()
        with torch.no_grad():
            batch = tuple(
                input_tensor.to(args.device) for input_tensor in batch)
            input_ids, lm_labels, token_type_ids = batch

            # logger.info(tokenizer.decode(input_ids[0, :].tolist()))
            model_outputs = model(input_ids, token_type_ids=token_type_ids)
            lm_logits = model_outputs[0]

            lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view(
                -1, lm_logits.size(-1))
            lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1)

            return lm_logits_flat_shifted, lm_labels_flat_shifted

    evaluator = Engine(inference)

    # Attach evaluation to trainer: we evaluate when we start the training and at the end of each epoch
    trainer.add_event_handler(Events.EPOCH_COMPLETED,
                              lambda _: evaluator.run(val_loader))
    if args.n_epochs < 1:
        trainer.add_event_handler(Events.COMPLETED,
                                  lambda _: evaluator.run(val_loader))
    if args.eval_before_start:
        trainer.add_event_handler(Events.STARTED,
                                  lambda _: evaluator.run(val_loader))

    # Make sure distributed data samplers split the dataset nicely between the distributed processes
    if args.distributed:
        trainer.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: train_sampler.set_epoch(engine.state.epoch))
        evaluator.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: valid_sampler.set_epoch(engine.state.epoch))

    # Linearly decrease the learning rate from lr to zero
    scheduler = PiecewiseLinear(optimizer, "lr",
                                [(0, args.lr),
                                 (args.n_epochs * len(train_loader), 0.0)])
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # Prepare metrics - note how we compute distributed metrics
    RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
    metrics = {"nll": Loss(torch.nn.CrossEntropyLoss(ignore_index=-1))}
    metrics.update({
        "average_nll":
        MetricsLambda(average_distributed_scalar, metrics["nll"], args)
    })
    metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"])
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # On the main process: add progress bar, tensorboard, checkpoints and save model, configuration and tokenizer before we start to train
    if args.local_rank in [-1, 0]:
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer, metric_names=["loss"])
        evaluator.add_event_handler(
            Events.COMPLETED, lambda _: pbar.log_message(
                "Validation: %s" % pformat(evaluator.state.metrics)))

        tb_logger = TensorboardLogger(log_dir=args.output_dir)
        tb_logger.attach(trainer,
                         log_handler=OutputHandler(tag="training",
                                                   metric_names=["loss"]),
                         event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer,
                         log_handler=OptimizerParamsHandler(optimizer),
                         event_name=Events.ITERATION_STARTED)
        tb_logger.attach(evaluator,
                         log_handler=OutputHandler(tag="validation",
                                                   metric_names=list(
                                                       metrics.keys()),
                                                   another_engine=trainer),
                         event_name=Events.EPOCH_COMPLETED)

        checkpoint_handler = ModelCheckpoint(tb_logger.writer.log_dir,
                                             'checkpoint',
                                             save_interval=1,
                                             n_saved=3)
        trainer.add_event_handler(
            Events.EPOCH_COMPLETED, checkpoint_handler,
            {'mymodel': getattr(model, 'module', model)
             })  # "getattr" take care of distributed encapsulation

        torch.save(args, tb_logger.writer.log_dir + '/model_training_args.bin')
        getattr(model, 'module', model).config.to_json_file(
            os.path.join(tb_logger.writer.log_dir, CONFIG_NAME))
        tokenizer.save_vocabulary(tb_logger.writer.log_dir)

    # Run the training
    trainer.run(train_loader, max_epochs=args.n_epochs)

    # On the main process: close tensorboard logger and rename the last checkpoint (for easy re-loading with OpenAIGPTModel.from_pretrained method)
    if args.local_rank in [-1, 0] and args.n_epochs > 0:
        os.rename(
            checkpoint_handler._saved[-1][1][-1],
            os.path.join(tb_logger.writer.log_dir, WEIGHTS_NAME)
        )  # TODO: PR in ignite to have better access to saved file paths (cleaner)
        tb_logger.close()
Beispiel #16
0
def train(name, load, lrate, weight_decay, workers, smooth, device, validation,
          ground_truth):

    if not name:
        name = '{}_{}'.format(lrate, weight_decay)
    click.echo('model output name: {}'.format(name))

    torch.set_num_threads(1)

    train_set = BaselineSet(glob.glob('{}/**/*.seeds.png'.format(ground_truth),
                                      recursive=True),
                            smooth=smooth)
    train_data_loader = DataLoader(dataset=train_set,
                                   num_workers=workers,
                                   batch_size=1,
                                   shuffle=True,
                                   pin_memory=True)
    val_set = BaselineSet(glob.glob('{}/**/*.seeds.png'.format(validation),
                                    recursive=True),
                          smooth=smooth)
    val_data_loader = DataLoader(dataset=val_set,
                                 num_workers=workers,
                                 batch_size=1,
                                 pin_memory=True)

    click.echo('loading network')
    model = ResUNet(refine_encoder=False).to(device)

    if load:
        click.echo('loading weights')
        model = torch.load(load, map_location=device)

    criterion = nn.BCEWithLogitsLoss()
    opti = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
                      lr=lrate,
                      weight_decay=weight_decay)

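    # score for model selection / early stopping: negate the validation loss so that higher is better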
    def score_function(engine):
        val_loss = engine.state.metrics['loss']
        return -val_loss

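    # binarize the sigmoid output with denoising hysteresis thresholding before computing the classification metrics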
    def output_preprocess(output):
        o, target = output
        o = torch.sigmoid(o)
        o = denoising_hysteresis_thresh(o.detach().squeeze().cpu().numpy(),
                                        0.8, 0.9, 2.5)
        return torch.from_numpy(o.astype('f')).unsqueeze(0).unsqueeze(0).to(
            device), target.double().to(device)

    trainer = create_supervised_trainer(model,
                                        opti,
                                        criterion,
                                        device=device,
                                        non_blocking=True)
    accuracy = Accuracy(output_transform=output_preprocess)
    precision = Precision(average=False, output_transform=output_preprocess)
    recall = Recall(average=False, output_transform=output_preprocess)
    loss = Loss(criterion)
    f1 = (precision * recall * 2 / (precision + recall)).mean()

    evaluator = create_supervised_evaluator(model,
                                            device=device,
                                            non_blocking=True)

    accuracy.attach(evaluator, 'accuracy')
    precision.attach(evaluator, 'precision')
    recall.attach(evaluator, 'recall')
    loss.attach(evaluator, 'loss')
    f1.attach(evaluator, 'f1')

    ckpt_handler = ModelCheckpoint('.',
                                   name,
                                   save_interval=1,
                                   n_saved=10,
                                   require_empty=False)
    RunningAverage(output_transform=lambda x: x).attach(trainer, 'loss')

    progress_bar = ProgressBar(persist=True)
    progress_bar.attach(trainer, ['loss'])

    trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                              handler=ckpt_handler,
                              to_save={'net': model})
    trainer.add_event_handler(event_name=Events.ITERATION_COMPLETED,
                              handler=TerminateOnNan())

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        evaluator.run(val_data_loader)
        metrics = evaluator.state.metrics
        progress_bar.log_message(
            'eval results - epoch {} loss: {:.4f} f1: {:.4f}, accuracy: {:.4f} recall: {:.4f} precision {:.4f}'
            .format(engine.state.epoch, metrics['loss'], metrics['f1'],
                    metrics['accuracy'], metrics['recall'],
                    metrics['precision']))

    trainer.run(train_data_loader, max_epochs=1000)
Beispiel #17
0
def train():
    parser = ArgumentParser()
    parser.add_argument(
        "--dataset_path",
        type=str,
        default="",
        help="Path or url of the dataset. If empty download from S3.")
    parser.add_argument("--dataset_cache",
                        type=str,
                        default='./dataset_cache',
                        help="Path or url of the dataset cache")
    parser.add_argument("--model_checkpoint",
                        type=str,
                        default="openai-gpt",
                        help="Path, url or short name of the model")
    parser.add_argument("--num_candidates",
                        type=int,
                        default=2,
                        help="Number of candidates for training")
    parser.add_argument("--max_history",
                        type=int,
                        default=2,
                        help="Number of previous exchanges to keep in history")
    parser.add_argument("--train_batch_size",
                        type=int,
                        default=4,
                        help="Batch size for training")
    parser.add_argument("--valid_batch_size",
                        type=int,
                        default=4,
                        help="Batch size for validation")
    parser.add_argument("--gradient_accumulation_steps",
                        type=int,
                        default=8,
                        help="Accumulate gradients on several steps")
    parser.add_argument("--lr",
                        type=float,
                        default=6.25e-5,
                        help="Learning rate")
    parser.add_argument("--lm_coef",
                        type=float,
                        default=1.0,
                        help="LM loss coefficient")
    parser.add_argument("--mc_coef",
                        type=float,
                        default=1.0,
                        help="Multiple-choice loss coefficient")
    parser.add_argument("--max_norm",
                        type=float,
                        default=1.0,
                        help="Clipping gradient norm")
    parser.add_argument("--n_epochs",
                        type=int,
                        default=3,
                        help="Number of training epochs")
    parser.add_argument("--personality_permutations",
                        type=int,
                        default=1,
                        help="Number of permutations of personality sentences")
    parser.add_argument(
        "--eval_before_start",
        action='store_true',
        help="If true start with a first evaluation before training")
    parser.add_argument("--device",
                        type=str,
                        default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device (cuda or cpu)")
    parser.add_argument(
        "--fp16",
        type=str,
        default="",
        help=
        "Set to O0, O1, O2 or O3 for fp16 training (see apex documentation)")
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="Local rank for distributed training (-1: not distributed)")
    args = parser.parse_args()

    # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process. logger.info => log main process only, logger.warning => log all processes
    logging.basicConfig(
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning(
        "Running process %d", args.local_rank
    )  # This is a logger.warning: it will be printed by all distributed processes
    logger.info("Arguments: %s", pformat(args))

    # Initialize distributed training if needed
    args.distributed = (args.local_rank != -1)
    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        args.device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')

    logger.info(
        "Prepare tokenizer, pretrained model and optimizer - add special tokens for fine-tuning"
    )
    tokenizer_class = GPT2Tokenizer if "gpt2" in args.model_checkpoint else OpenAIGPTTokenizer
    tokenizer = tokenizer_class.from_pretrained(args.model_checkpoint)
    model_class = GPT2DoubleHeadsModel if "gpt2" in args.model_checkpoint else OpenAIGPTDoubleHeadsModel
    model = model_class.from_pretrained(args.model_checkpoint)
    tokenizer.set_special_tokens(SPECIAL_TOKENS)
    model.set_num_special_tokens(len(SPECIAL_TOKENS))
    model.to(args.device)
    optimizer = OpenAIAdam(model.parameters(), lr=args.lr)

    # Prepare model for FP16 and distributed training if needed (order is important, distributed should be the last)
    if args.fp16:
        from apex import amp  # Apex is only required if we use fp16 training
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16)
    if args.distributed:
        model = DistributedDataParallel(model,
                                        device_ids=[args.local_rank],
                                        output_device=args.local_rank)

    logger.info("Prepare datasets")
    train_loader, val_loader, train_sampler, valid_sampler = get_data_loaders(
        args, tokenizer)

    # Training function and trainer
    def update(engine, batch):
        model.train()
        batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
        lm_loss, mc_loss = model(*batch)
        loss = (lm_loss * args.lm_coef +
                mc_loss * args.mc_coef) / args.gradient_accumulation_steps
        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                           args.max_norm)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
        if engine.state.iteration % args.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        return loss.item()

    trainer = Engine(update)

    # Evaluation function and evaluator (evaluator output is the input of the metrics)
    def inference(engine, batch):
        model.eval()
        with torch.no_grad():
            batch = tuple(
                input_tensor.to(args.device) for input_tensor in batch)
            input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids = batch
            logger.info(tokenizer.decode(input_ids[0, -1, :].tolist()))
            model_outputs = model(input_ids,
                                  mc_token_ids,
                                  token_type_ids=token_type_ids)
            lm_logits, mc_logits = model_outputs[0], model_outputs[
                1]  # So we can also use GPT2 outputs
            lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view(
                -1, lm_logits.size(-1))
            lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1)
            return (lm_logits_flat_shifted,
                    mc_logits), (lm_labels_flat_shifted, mc_labels)

    evaluator = Engine(inference)

    # Attach evaluation to trainer: we evaluate when we start the training and at the end of each epoch
    trainer.add_event_handler(Events.EPOCH_COMPLETED,
                              lambda _: evaluator.run(val_loader))
    if args.n_epochs < 1:
        trainer.add_event_handler(Events.COMPLETED,
                                  lambda _: evaluator.run(val_loader))
    if args.eval_before_start:
        trainer.add_event_handler(Events.STARTED,
                                  lambda _: evaluator.run(val_loader))

    # Make sure distributed data samplers split the dataset nicely between the distributed processes
    if args.distributed:
        trainer.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: train_sampler.set_epoch(engine.state.epoch))
        evaluator.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: valid_sampler.set_epoch(engine.state.epoch))

    # Linearly decrease the learning rate from lr to zero
    scheduler = PiecewiseLinear(optimizer, "lr",
                                [(0, args.lr),
                                 (args.n_epochs * len(train_loader), 0.0)])
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # Prepare metrics - note how we compute distributed metrics
    RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
    metrics = {
        "nll":
        Loss(torch.nn.CrossEntropyLoss(ignore_index=-1),
             output_transform=lambda x: (x[0][0], x[1][0])),
        "accuracy":
        Accuracy(output_transform=lambda x: (x[0][1], x[1][1]))
    }
    metrics.update({
        "average_nll":
        MetricsLambda(average_distributed_scalar, metrics["nll"], args),
        "average_accuracy":
        MetricsLambda(average_distributed_scalar, metrics["accuracy"], args)
    })
    metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"])
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # On the main process: add progress bar, tensorboard, checkpoints and save model, configuration and tokenizer before we start to train
    if args.local_rank in [-1, 0]:
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer, metric_names=["loss"])
        evaluator.add_event_handler(
            Events.COMPLETED, lambda _: pbar.log_message(
                "Validation: %s" % pformat(evaluator.state.metrics)))

        # tb_logger = TensorboardLogger(log_dir=None)
        # tb_logger.attach(trainer, log_handler=OutputHandler(tag="training", metric_names=["loss"]), event_name=Events.ITERATION_COMPLETED)
        # tb_logger.attach(trainer, log_handler=OptimizerParamsHandler(optimizer), event_name=Events.ITERATION_STARTED)
        # tb_logger.attach(evaluator, log_handler=OutputHandler(tag="validation", metric_names=list(metrics.keys()), another_engine=trainer), event_name=Events.EPOCH_COMPLETED)

        # checkpoint_handler = ModelCheckpoint(tb_logger.writer.log_dir, 'checkpoint', save_interval=1, n_saved=3)
        # trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler, {'mymodel': getattr(model, 'module', model)})  # "getattr" take care of distributed encapsulation

        # torch.save(args, tb_logger.writer.log_dir + '/model_training_args.bin')
        # getattr(model, 'module', model).config.to_json_file(os.path.join(tb_logger.writer.log_dir, CONFIG_NAME))
        # tokenizer.save_vocabulary(tb_logger.writer.log_dir)

    # Run the training
    trainer.run(train_loader, max_epochs=args.n_epochs)
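
The fine-tuning script above combines three reusable Ignite patterns: an update function that divides the loss by the number of accumulation steps and only calls `optimizer.step()` every N iterations, evaluation hooked onto `EPOCH_COMPLETED`, and a `PiecewiseLinear` schedule that ramps the learning rate down to zero over the whole run. Below is a minimal, self-contained sketch of the same wiring, using a hypothetical toy linear model and random tensors instead of the GPT-2 setup.

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from ignite.engine import Engine, Events
from ignite.contrib.handlers import PiecewiseLinear

# Toy stand-ins for the model and data above (hypothetical, for illustration only).
model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
loader = DataLoader(TensorDataset(torch.randn(64, 10), torch.randn(64, 1)), batch_size=8)
accumulation_steps, n_epochs = 4, 3

def update(engine, batch):
    model.train()
    x, y = batch
    loss = F.mse_loss(model(x), y) / accumulation_steps
    loss.backward()
    if engine.state.iteration % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
    return loss.item()

trainer = Engine(update)

# Linearly decay the learning rate to zero over the whole run, as above.
scheduler = PiecewiseLinear(optimizer, "lr",
                            [(0, 1e-2), (n_epochs * len(loader), 0.0)])
trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
trainer.run(loader, max_epochs=n_epochs)
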
Beispiel #18
0
def attach_decorators(trainer, SR, D, vgg, loader, schedulerD, schedulerG,
                      optimizerD, optimizerG, resume_epoch, resume_iter):

    timer = Timer(average=True)

    checkpoint_handler = ModelCheckpoint(args.output_dir,
                                         'training',
                                         save_interval=1,
                                         n_saved=10,
                                         require_empty=False)

    monitoring_metrics = [
        'dloss_real', 'dloss_fake', 'd_loss', 'GP', 'WD', 'VGG', 'gloss'
    ]
    RunningAverage(alpha=0.98,
                   output_transform=lambda x: x['dloss_real']).attach(
                       trainer, 'dloss_real')
    RunningAverage(alpha=0.98,
                   output_transform=lambda x: x['dloss_fake']).attach(
                       trainer, 'dloss_fake')
    RunningAverage(alpha=0.98,
                   output_transform=lambda x: x['GP']).attach(trainer, 'GP')
    RunningAverage(alpha=0.98, output_transform=lambda x: x['d_loss']).attach(
        trainer, 'd_loss')
    RunningAverage(alpha=0.98,
                   output_transform=lambda x: x['WD']).attach(trainer, 'WD')
    RunningAverage(alpha=0.98,
                   output_transform=lambda x: x['VGG']).attach(trainer, 'VGG')
    RunningAverage(alpha=0.98, output_transform=lambda x: x['gloss']).attach(
        trainer, 'gloss')

    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=monitoring_metrics)

    trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                              handler=checkpoint_handler,
                              to_save={
                                  'SR': SR,
                                  'D': D,
                                  'VGG': vgg,
                                  'optim_D': optimizerD,
                                  'optim_G': optimizerG,
                                  'sched_D': schedulerD,
                                  'sched_G': schedulerG
                              })

    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    @trainer.on(Events.ITERATION_COMPLETED)
    def print_logs(engine):
        if (engine.state.iteration - 1) % PRINT_FREQ == 0:
            fname = os.path.join(args.output_dir, LOGS_FNAME)
            columns = engine.state.metrics.keys()
            values = [
                str(round(value, 5))
                for value in engine.state.metrics.values()
            ]

            with open(fname, 'a') as f:
                if f.tell() == 0:
                    print('\t'.join(columns), file=f)
                print('\t'.join(values), file=f)

            i = (engine.state.iteration % len(loader))
            message = '[{epoch}/{max_epoch}][{i}/{max_i}]'.format(
                epoch=engine.state.epoch,
                max_epoch=args.epochs,
                i=i,
                max_i=len(loader))
            for name, value in zip(columns, values):
                message += ' | {name}: {value}'.format(name=name, value=value)

            pbar.log_message(message)

    @trainer.on(Events.ITERATION_COMPLETED)
    def save_real_example(engine):
        if (engine.state.iteration - 1) % PRINT_FREQ == 0:
            px, y = engine.state.batch
            img = SR(px.cuda())
            path = os.path.join(
                args.output_dir,
                FAKE_IMG_FNAME.format(engine.state.epoch,
                                      engine.state.iteration))
            vutils.save_image(img, path)
            path = os.path.join(
                args.output_dir,
                REAL_IMG_FNAME.format(engine.state.epoch,
                                      engine.state.iteration))
            vutils.save_image(y, path)
            path = os.path.join(
                args.output_dir,
                TRAIN_IMG_FNAME.format(engine.state.epoch,
                                       engine.state.iteration))
            vutils.save_image(px, path)

    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        pbar.log_message('Epoch {} done. Time per batch: {:.3f}[s]'.format(
            engine.state.epoch, timer.value()))
        timer.reset()

    @trainer.on(Events.EPOCH_COMPLETED)
    def LRstep(engine):
        schedulerD.step()
        schedulerG.step()

    @trainer.on(Events.EPOCH_COMPLETED)
    def create_plots(engine):
        try:
            import matplotlib as mpl
            mpl.use('agg')

            import numpy as np
            import pandas as pd
            import matplotlib.pyplot as plt
            df = pd.read_csv(os.path.join(args.output_dir, LOGS_FNAME),
                             delimiter='\t')
            x = np.arange(1, engine.state.iteration + 1, PRINT_FREQ)
            _ = df.plot(subplots=True, figsize=(20, 20), grid=True, xticks=x)
            _ = plt.xlabel('Iteration number')
            fig = plt.gcf()
            path = os.path.join(args.output_dir, PLOT_FNAME)

            fig.savefig(path)

        except ImportError:
            warnings.warn(
                'Loss plots will not be generated -- pandas or matplotlib not found'
            )

    @trainer.on(Events.EXCEPTION_RAISED)
    def handle_exception(engine, e):
        if isinstance(e, KeyboardInterrupt) and (engine.state.iteration > 1):
            engine.terminate()
            warnings.warn('KeyboardInterrupt caught. Exiting gracefully.')

            create_plots(engine)
            checkpoint_handler(
                engine, {
                    'netG_{}'.format(engine.state.iteration): SR,
                    'netD_{}'.format(engine.state.iteration): D,
                    'optim_D_{}'.format(engine.state.iteration): optimizerD,
                    'optim_G_{}'.format(engine.state.iteration): optimizerG,
                    'sched_D_{}'.format(engine.state.iteration): schedulerD,
                    'sched_G_{}'.format(engine.state.iteration): schedulerG
                })

        else:
            raise e

    @trainer.on(Events.STARTED)
    def resume_training(engine):
        engine.state.iteration = resume_iter
        engine.state.epoch = resume_epoch

    def OAR_shutdown():
        raise KeyboardInterrupt

    signal.signal(signal.SIGUSR2, OAR_shutdown)
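
Two of the handlers above are easy to reuse on their own: the averaged `Timer` that reports the time per batch, and the `EXCEPTION_RAISED` hook that turns a Ctrl-C into a clean shutdown. A minimal sketch under a hypothetical toy setup:

import warnings

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from ignite.engine import Engine, Events
from ignite.handlers import Timer

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loader = DataLoader(TensorDataset(torch.randn(32, 4), torch.randn(32, 1)), batch_size=4)

def update(engine, batch):
    x, y = batch
    optimizer.zero_grad()
    loss = F.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()
    return loss.item()

trainer = Engine(update)

# Average time per iteration, measured exactly as in the example above.
timer = Timer(average=True)
timer.attach(trainer,
             start=Events.EPOCH_STARTED,
             resume=Events.ITERATION_STARTED,
             pause=Events.ITERATION_COMPLETED,
             step=Events.ITERATION_COMPLETED)

@trainer.on(Events.EXCEPTION_RAISED)
def handle_exception(engine, e):
    # Stop cleanly on Ctrl-C instead of propagating the traceback; re-raise anything else.
    if isinstance(e, KeyboardInterrupt) and engine.state.iteration > 1:
        engine.terminate()
        warnings.warn('KeyboardInterrupt caught. Exiting gracefully.')
    else:
        raise e

trainer.run(loader, max_epochs=2)
print('time per batch: {:.4f}[s]'.format(timer.value()))
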
Beispiel #19
0
def _upd_pbar_iter_from_cp(engine: Engine, pbar: ProgressBar) -> None:
    pbar.n = engine.state.iteration
Beispiel #20
0
def run(args):
    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)

    num_classes = CityscapesDataset.num_classes()
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = GoogLeNetFCN(num_classes)
    model.init_from_googlenet()

    device_count = torch.cuda.device_count()
    if device_count > 1:
        print("Using %d GPU(s)" % device_count)
        model = nn.DataParallel(model)
        args.batch_size = device_count * args.batch_size
        args.val_batch_size = device_count * args.val_batch_size

    model = model.to(device)

    train_loader, val_loader = get_data_loaders(args.dataset_dir,
                                                args.batch_size,
                                                args.val_batch_size,
                                                args.num_workers,
                                                args.include_coarse)

    criterion = nn.CrossEntropyLoss(ignore_index=255, reduction='sum')

    optimizer = optim.SGD([{
        'params': [
            param for name, param in model.named_parameters()
            if name.endswith('weight')
        ]
    }, {
        'params': [
            param for name, param in model.named_parameters()
            if name.endswith('bias')
        ],
        'lr':
        args.lr * 2,
        'weight_decay':
        0
    }],
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    if args.resume:
        if os.path.isfile(args.resume):
            print("Loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_iou = checkpoint.get('bestIoU', 0.0)
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("Loaded checkpoint '{}' (Epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("No checkpoint found at '{}'".format(args.resume))
            sys.exit()

    if args.freeze_bn:
        print("Freezing batch norm")
        model = freeze_batchnorm(model)

    trainer = create_supervised_trainer(model,
                                        optimizer,
                                        criterion,
                                        device,
                                        non_blocking=True)

    RunningAverage(output_transform=lambda x: x).attach(trainer, 'loss')

    # attach progress bar
    pbar = ProgressBar(persist=True)
    pbar.attach(trainer, metric_names=['loss'])

    cm = ConfusionMatrix(num_classes)
    evaluator = create_supervised_evaluator(model,
                                            metrics={
                                                'loss': Loss(criterion),
                                                'IoU': IoU(cm),
                                                'accuracy': cmAccuracy(cm)
                                            },
                                            device=device,
                                            non_blocking=True)

    pbar2 = ProgressBar(persist=True, desc='Eval Epoch')
    pbar2.attach(evaluator)

    def _global_step_transform(engine, event_name):
        return trainer.state.iteration

    tb_logger = TensorboardLogger(args.log_dir)
    tb_logger.attach(trainer,
                     log_handler=OutputHandler(tag='training',
                                               metric_names=['loss']),
                     event_name=Events.ITERATION_COMPLETED)

    tb_logger.attach(trainer,
                     log_handler=OptimizerParamsHandler(optimizer),
                     event_name=Events.ITERATION_STARTED)

    tb_logger.attach(trainer,
                     log_handler=WeightsHistHandler(model),
                     event_name=Events.EPOCH_COMPLETED)

    tb_logger.attach(evaluator,
                     log_handler=OutputHandler(
                         tag='validation',
                         metric_names=['loss', 'IoU', 'accuracy'],
                         global_step_transform=_global_step_transform),
                     event_name=Events.EPOCH_COMPLETED)

    @evaluator.on(Events.EPOCH_COMPLETED)
    def save_checkpoint(engine):
        iou = engine.state.metrics['IoU'] * 100.0
        mean_iou = iou.mean()

        is_best = mean_iou.item() > trainer.state.best_iou
        trainer.state.best_iou = max(mean_iou.item(), trainer.state.best_iou)

        name = 'epoch{}_mIoU={:.1f}.pth'.format(trainer.state.epoch, mean_iou)
        file = {
            'model': model.state_dict(),
            'epoch': trainer.state.epoch,
            'iteration': engine.state.iteration,
            'optimizer': optimizer.state_dict(),
            'args': args,
            'bestIoU': trainer.state.best_iou
        }

        save(file, args.output_dir, 'checkpoint_{}'.format(name))
        if is_best:
            save(model.state_dict(), args.output_dir, 'model_{}'.format(name))

    @trainer.on(Events.STARTED)
    def initialize(engine):
        if args.resume:
            engine.state.epoch = args.start_epoch
            engine.state.iteration = args.start_epoch * len(
                engine.state.dataloader)
            engine.state.best_iou = best_iou
        else:
            engine.state.best_iou = 0.0

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        pbar.log_message("Start Validation - Epoch: [{}/{}]".format(
            engine.state.epoch, engine.state.max_epochs))
        evaluator.run(val_loader)
        metrics = evaluator.state.metrics
        loss = metrics['loss']
        iou = metrics['IoU']
        acc = metrics['accuracy']
        mean_iou = iou.mean()

        pbar.log_message(
            "Validation results - Epoch: [{}/{}]: Loss: {:.2e}, Accuracy: {:.1f}, mIoU: {:.1f}"
            .format(engine.state.epoch, engine.state.max_epochs, loss,
                    acc * 100.0, mean_iou * 100.0))

    print("Start training")
    trainer.run(train_loader, max_epochs=args.epochs)
    tb_logger.close()
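
The evaluator above derives both per-class IoU and accuracy from a single `ConfusionMatrix`. A minimal sketch of that pipeline, assuming a throwaway 1x1-conv "segmentation" model and random label maps instead of Cityscapes:

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from ignite.engine import create_supervised_evaluator
from ignite.metrics import ConfusionMatrix, IoU, mIoU

num_classes = 5
model = nn.Conv2d(3, num_classes, kernel_size=1)                 # hypothetical stand-in model
images = torch.randn(8, 3, 16, 16)
masks = torch.randint(0, num_classes, (8, 16, 16))               # random integer label maps
loader = DataLoader(TensorDataset(images, masks), batch_size=4)

# One ConfusionMatrix feeds both metrics, as in the example above.
cm = ConfusionMatrix(num_classes)
evaluator = create_supervised_evaluator(model,
                                        metrics={'IoU': IoU(cm), 'mIoU': mIoU(cm)},
                                        device='cpu')
state = evaluator.run(loader)
print('per-class IoU:', state.metrics['IoU'])
print('mIoU: {:.3f}'.format(state.metrics['mIoU'].item()))
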
Beispiel #21
0
def run(args):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = IQAModel(arch=args.arch,
                     pool=args.pool,
                     use_bn_end=args.use_bn_end,
                     P6=args.P6,
                     P7=args.P7).to(device)  #
    print(model)
    if args.ft_lr_ratio == .0:
        for param in model.features.parameters():
            param.requires_grad = False
    train_loader, val_loader, test_loader = get_data_loaders(args)

    optimizer = Adam(
        [
            {
                'params': model.regression.parameters()
            },  # The most important parameters. Maybe we need three levels of lrs
            {
                'params': model.dr6.parameters()
            },
            {
                'params': model.dr7.parameters()
            },
            {
                'params': model.regr6.parameters()
            },
            {
                'params': model.regr7.parameters()
            },
            {
                'params': model.features.parameters(),
                'lr': args.lr * args.ft_lr_ratio
            }
        ],
        lr=args.lr,
        weight_decay=args.weight_decay
    )  # Adam can be changed to other optimizers, such as SGD, Adadelta.

    # Initialization
    model, optimizer = amp.initialize(model,
                                      optimizer,
                                      opt_level=args.opt_level)

    mapping = True  #  args.loss_type != 'l1' and args.loss_type != 'mse'

    if args.evaluate:
        checkpoint = torch.load(args.trained_model_file)
        model.load_state_dict(checkpoint['model'])
        k = checkpoint['k']
        b = checkpoint['b']

        evaluator = create_supervised_evaluator(model,
                                                metrics={
                                                    'IQA_performance':
                                                    IQAPerformance(
                                                        status='test',
                                                        k=k,
                                                        b=b,
                                                        mapping=mapping)
                                                },
                                                device=device)
        evaluator.run(test_loader)
        performance = evaluator.state.metrics
        for metric_print in metrics_printed:
            print('{}, {}: {:.3f}'.format(args.dataset, metric_print,
                                          performance[metric_print].item()))
        for metric_print in metrics_printed:
            print('{:.3f}'.format(performance[metric_print].item()))
        np.save(args.save_result_file, performance)
        return

    scheduler = lr_scheduler.StepLR(optimizer,
                                    step_size=args.lr_decay_step,
                                    gamma=args.lr_decay)
    loss_func = IQALoss(
        loss_type=args.loss_type,
        alpha=args.alpha,
        beta=args.beta,
        p=args.p,
        q=args.q,
        monotonicity_regularization=args.monotonicity_regularization,
        gamma=args.gamma,
        detach=args.detach)
    trainer = create_supervised_trainer(
        model,
        optimizer,
        loss_func,
        device=device,
        accumulation_steps=args.accumulation_steps)

    if args.pbar:
        from ignite.contrib.handlers import ProgressBar

        ProgressBar().attach(trainer)

    evaluator_for_train = create_supervised_evaluator(model,
                                                      metrics={
                                                          'IQA_performance':
                                                          IQAPerformance(
                                                              status='train',
                                                              mapping=mapping)
                                                      },
                                                      device=device)

    current_time = datetime.datetime.now().strftime("%I:%M%p on %B %d, %Y")
    writer = SummaryWriter(
        log_dir='{}/{}-{}'.format(args.log_dir, args.format_str, current_time))
    global best_val_criterion, best_epoch
    best_val_criterion, best_epoch = -100, -1  # larger, better, e.g., SROCC or PLCC. If RMSE is used, best_val_criterion <- 10000

    @trainer.on(Events.ITERATION_COMPLETED)
    def iter_event_function(engine):
        writer.add_scalar("train/loss", engine.state.output,
                          engine.state.iteration)

    @trainer.on(Events.EPOCH_COMPLETED)
    def epoch_event_function(engine):
        if args.test_during_training:
            evaluator_for_train.run(
                train_loader
            )  # It is better to re-make a train_loader_for_evaluation so as not to disturb the random number generator.
            performance = evaluator_for_train.state.metrics
            writer_add_scalar(writer, 'train', args.dataset, performance,
                              engine.state.epoch)
            k = performance['k']
            b = performance['b']
        else:
            k = [1, 1, 1]
            b = [0, 0, 0]

        evaluator = create_supervised_evaluator(model,
                                                metrics={
                                                    'IQA_performance':
                                                    IQAPerformance(
                                                        status='test',
                                                        k=k,
                                                        b=b,
                                                        mapping=mapping)
                                                },
                                                device=device)
        evaluator.run(val_loader)
        performance = evaluator.state.metrics
        writer_add_scalar(writer, 'val', args.dataset, performance,
                          engine.state.epoch)
        val_criterion = abs(
            performance[args.val_criterion]
        )  # when alpha=[0,1],loss_type='linearity', test_during_training=False, SROCC/PLCC can be negative during training.
        if args.test_during_training:
            evaluator.run(test_loader)
            performance = evaluator.state.metrics
            writer_add_scalar(writer, 'test', args.dataset, performance,
                              engine.state.epoch)

        global best_val_criterion, best_epoch
        if val_criterion > best_val_criterion:  # If RMSE is used, then change ">" to "<".
            checkpoint = {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'amp': amp.state_dict(),
                'k': k,
                'b': b
            }
            torch.save(checkpoint, args.trained_model_file)
            best_val_criterion = val_criterion
            best_epoch = engine.state.epoch
            print(
                'Save current best model @best_val_criterion ({}): {:.3f} @epoch: {}'
                .format(args.val_criterion, best_val_criterion, best_epoch))
        else:
            print(
                'Model is not updated @val_criterion ({}): {:.3f} @epoch: {}'.
                format(args.val_criterion, val_criterion, engine.state.epoch))

        scheduler.step(engine.state.epoch)

    @trainer.on(Events.COMPLETED)
    def final_testing_results(engine):
        writer.close()  # close the Tensorboard writer
        print('best epoch: {}'.format(best_epoch))
        checkpoint = torch.load(args.trained_model_file)
        model.load_state_dict(checkpoint['model'])
        if args.test_during_training:
            k = checkpoint['k']
            b = checkpoint['b']
        else:
            evaluator_for_train.run(train_loader)
            performance = evaluator_for_train.state.metrics
            k = performance['k']
            b = performance['b']
            checkpoint = {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'amp': amp.state_dict(),
                'k': k,
                'b': b
            }
            torch.save(checkpoint, args.trained_model_file)

        evaluator = create_supervised_evaluator(model,
                                                metrics={
                                                    'IQA_performance':
                                                    IQAPerformance(
                                                        status='test',
                                                        k=k,
                                                        b=b,
                                                        mapping=mapping)
                                                },
                                                device=device)
        evaluator.run(test_loader)
        performance = evaluator.state.metrics
        for metric_print in metrics_printed:
            print('{}, {}: {:.3f}'.format(args.dataset, metric_print,
                                          performance[metric_print].item()))
        for metric_print in metrics_printed:
            print('{:.3f}'.format(performance[metric_print].item()))
        np.save(args.save_result_file, performance)

    trainer.run(train_loader, max_epochs=args.epochs)
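
Stripped of the IQA-specific pieces, the loop above follows a common recipe: log the per-iteration loss to TensorBoard, run a validation evaluator at the end of each epoch, and keep only the checkpoint with the best validation criterion. A minimal sketch with a hypothetical toy regression task (paths such as runs/sketch and best_model.pth are placeholders):

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.tensorboard import SummaryWriter
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import MeanSquaredError

model = torch.nn.Linear(8, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = torch.nn.MSELoss()
train_loader = DataLoader(TensorDataset(torch.randn(64, 8), torch.randn(64, 1)), batch_size=8)
val_loader = DataLoader(TensorDataset(torch.randn(32, 8), torch.randn(32, 1)), batch_size=8)

trainer = create_supervised_trainer(model, optimizer, loss_fn)
evaluator = create_supervised_evaluator(model, metrics={'mse': MeanSquaredError()})
writer = SummaryWriter(log_dir='runs/sketch')        # hypothetical log directory
best_val = float('inf')

@trainer.on(Events.ITERATION_COMPLETED)
def log_loss(engine):
    writer.add_scalar('train/loss', engine.state.output, engine.state.iteration)

@trainer.on(Events.EPOCH_COMPLETED)
def validate_and_checkpoint(engine):
    global best_val
    evaluator.run(val_loader)
    mse = evaluator.state.metrics['mse']
    writer.add_scalar('val/mse', mse, engine.state.epoch)
    if mse < best_val:  # lower is better for MSE; flip the comparison for SROCC/PLCC
        best_val = mse
        torch.save({'model': model.state_dict()}, 'best_model.pth')

trainer.run(train_loader, max_epochs=2)
writer.close()
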
Beispiel #22
0
def _setup_common_training_handlers(
    trainer: Engine,
    to_save: Optional[Mapping] = None,
    save_every_iters: int = 1000,
    output_path: Optional[str] = None,
    lr_scheduler: Optional[Union[ParamScheduler, _LRScheduler]] = None,
    with_gpu_stats: bool = False,
    output_names: Optional[Iterable[str]] = None,
    with_pbars: bool = True,
    with_pbar_on_iters: bool = True,
    log_every_iters: int = 100,
    stop_on_nan: bool = True,
    clear_cuda_cache: bool = True,
    save_handler: Optional[Union[Callable, BaseSaveHandler]] = None,
    **kwargs: Any,
) -> None:
    if output_path is not None and save_handler is not None:
        raise ValueError(
            "Arguments output_path and save_handler are mutually exclusive. Please, define only one of them"
        )

    if stop_on_nan:
        trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())

    if lr_scheduler is not None:
        if isinstance(lr_scheduler, torch.optim.lr_scheduler._LRScheduler):
            trainer.add_event_handler(
                Events.ITERATION_COMPLETED,
                lambda engine: cast(_LRScheduler, lr_scheduler).step())
        elif isinstance(lr_scheduler, LRScheduler):
            trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler)
        else:
            trainer.add_event_handler(Events.ITERATION_STARTED, lr_scheduler)

    if torch.cuda.is_available() and clear_cuda_cache:
        trainer.add_event_handler(Events.EPOCH_COMPLETED, empty_cuda_cache)

    if to_save is not None:

        if output_path is None and save_handler is None:
            raise ValueError(
                "If to_save argument is provided then output_path or save_handler arguments should be also defined"
            )
        if output_path is not None:
            save_handler = DiskSaver(dirname=output_path, require_empty=False)

        checkpoint_handler = Checkpoint(to_save,
                                        cast(Union[Callable, BaseSaveHandler],
                                             save_handler),
                                        filename_prefix="training",
                                        **kwargs)
        trainer.add_event_handler(
            Events.ITERATION_COMPLETED(every=save_every_iters),
            checkpoint_handler)

    if with_gpu_stats:
        GpuInfo().attach(
            trainer,
            name="gpu",
            event_name=Events.ITERATION_COMPLETED(
                every=log_every_iters)  # type: ignore[arg-type]
        )

    if output_names is not None:

        def output_transform(x: Any, index: int, name: str) -> Any:
            if isinstance(x, Mapping):
                return x[name]
            elif isinstance(x, Sequence):
                return x[index]
            elif isinstance(x, (torch.Tensor, numbers.Number)):
                return x
            else:
                raise TypeError(
                    "Unhandled type of update_function's output. "
                    f"It should be either a mapping or a sequence, but {type(x)} was given"
                )

        for i, n in enumerate(output_names):
            RunningAverage(output_transform=partial(output_transform,
                                                    index=i,
                                                    name=n),
                           epoch_bound=False).attach(trainer, n)

    if with_pbars:
        if with_pbar_on_iters:
            ProgressBar(persist=False).attach(
                trainer,
                metric_names="all",
                event_name=Events.ITERATION_COMPLETED(every=log_every_iters))

        ProgressBar(persist=True,
                    bar_format="").attach(trainer,
                                          event_name=Events.EPOCH_STARTED,
                                          closing_event_name=Events.COMPLETED)
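
The checkpointing branch of the helper above boils down to a `Checkpoint` wrapping a `DiskSaver`, fired every `save_every_iters` iterations, plus `TerminateOnNan` as a safety net. A minimal sketch with a hypothetical toy trainer and output directory:

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from ignite.engine import Engine, Events
from ignite.handlers import Checkpoint, DiskSaver, TerminateOnNan

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loader = DataLoader(TensorDataset(torch.randn(40, 4), torch.randn(40, 1)), batch_size=4)

def update(engine, batch):
    x, y = batch
    optimizer.zero_grad()
    loss = F.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()
    return loss.item()

trainer = Engine(update)

# Stop training as soon as the loss becomes NaN or infinite.
trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())

# Save model and optimizer every 5 iterations under a hypothetical output path.
to_save = {'model': model, 'optimizer': optimizer}
checkpoint_handler = Checkpoint(to_save,
                                DiskSaver(dirname='/tmp/ckpts', require_empty=False),
                                filename_prefix='training')
trainer.add_event_handler(Events.ITERATION_COMPLETED(every=5), checkpoint_handler)
trainer.run(loader, max_epochs=2)
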
Beispiel #23
0
                       **kwargs)

# Load model
if args.state_dict is not None:
    state_dict = torch.load(args.state_dict, map_location='cpu')
    model.load_state_dict(state_dict, strict=True)

model = model.to(device)

evaluator = create_segmentation_evaluator(
    model,
    device=device,
    num_classes=19,
)

ProgressBar().attach(evaluator)

state = evaluator.run(val_loader)

classes = CLASSES[TRAIN_MAPPING != 255]

metrics = {
    'accuracy': state.metrics['accuracy'],
    'miou': state.metrics['miou'],
    'iou':
    {name: state.metrics['iou'][id].item()
     for id, name in enumerate(classes)},
}

pprint(metrics)
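
The evaluation-only snippet above hinges on one detail worth copying: weights are loaded onto the CPU first (`map_location='cpu'`) and checked strictly against the model's keys before the model is moved to the target device. A minimal sketch with a hypothetical stand-in model and path:

import torch
from torch import nn

model = nn.Conv2d(3, 19, kernel_size=1)                   # hypothetical stand-in for the segmentation model
torch.save(model.state_dict(), '/tmp/seg_weights.pth')    # pretend this is the trained checkpoint

state_dict = torch.load('/tmp/seg_weights.pth', map_location='cpu')
model.load_state_dict(state_dict, strict=True)            # raises if any key or shape mismatches
model = model.to('cuda' if torch.cuda.is_available() else 'cpu')
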
Beispiel #24
0
def main(dataset, dataroot, download, augment, batch_size, eval_batch_size,
         epochs, saved_model, seed, hidden_channels, K, L, actnorm_scale,
         flow_permutation, flow_coupling, LU_decomposed, learn_top,
         y_condition, y_weight, max_grad_clip, max_grad_norm, lr, n_workers,
         cuda, n_init_batches, warmup_steps, output_dir, saved_optimizer,
         warmup, fresh, logittransform, gan, disc_lr):

    device = 'cpu' if (not torch.cuda.is_available() or not cuda) else 'cuda:0'

    check_manual_seed(seed)

    ds = check_dataset(dataset, dataroot, augment, download)
    image_shape, num_classes, train_dataset, test_dataset = ds

    # Note: unsupported for now
    multi_class = False

    train_loader = data.DataLoader(train_dataset,
                                   batch_size=batch_size,
                                   shuffle=True,
                                   num_workers=n_workers,
                                   drop_last=True)
    test_loader = data.DataLoader(test_dataset,
                                  batch_size=eval_batch_size,
                                  shuffle=False,
                                  num_workers=n_workers,
                                  drop_last=False)

    model = Glow(image_shape, hidden_channels, K, L, actnorm_scale,
                 flow_permutation, flow_coupling, LU_decomposed, num_classes,
                 learn_top, y_condition, logittransform)

    model = model.to(device)

    if gan:
        # Debug
        model = mine.Generator(32, 1).to(device)

        optimizer = optim.Adam(model.parameters(),
                               lr=lr,
                               betas=(.5, .99),
                               weight_decay=0)
        discriminator = mine.Discriminator(image_shape[-1])
        discriminator = discriminator.to(device)
        D_optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                        discriminator.parameters()),
                                 lr=disc_lr,
                                 betas=(.5, .99),
                                 weight_decay=0)
    else:
        optimizer = optim.Adamax(model.parameters(), lr=lr, weight_decay=5e-5)

    # lr_lambda = lambda epoch: lr * min(1., epoch+1 / warmup)
    # scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

    i = 0

    def step(engine, batch):
        model.train()
        optimizer.zero_grad()

        x, y = batch
        x = x.to(device)

        if y_condition:
            y = y.to(device)
            z, nll, y_logits = model(x, y)
            losses = compute_loss_y(nll, y_logits, y_weight, y, multi_class)
        else:
            z, nll, y_logits = model(x, None)
            losses = compute_loss(nll)

        losses['total_loss'].backward()

        if max_grad_clip > 0:
            torch.nn.utils.clip_grad_value_(model.parameters(), max_grad_clip)
        if max_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        optimizer.step()

        return losses

    def gan_step(engine, batch):
        assert not y_condition
        if 'iter_ind' in dir(engine):
            engine.iter_ind += 1
        else:
            engine.iter_ind = -1
        losses = {}
        model.train()
        discriminator.train()

        x, y = batch
        x = x.to(device)

        # def generate_from_noise(batch_size):
        #     _, c2, h, w  = model.prior_h.shape
        #     c = c2 // 2
        #     zshape = (batch_size, c, h, w)
        #     randz  = torch.autograd.Variable(torch.randn(zshape), requires_grad=True).to(device)
        #     images = model(z= randz, y_onehot=None, temperature=1, reverse=True,batch_size=batch_size)
        #     return images

        def generate_from_noise(batch_size):

            zshape = (batch_size, 32, 1, 1)
            randz = torch.randn(zshape).to(device)
            images = model(randz)
            return images / 2

        def run_noised_disc(discriminator, x):
            x = uniform_binning_correction(x)[0]
            return discriminator(x)

        # Train Disc
        fake = generate_from_noise(x.size(0))

        D_real_scores = run_noised_disc(discriminator, x.detach())
        D_fake_scores = run_noised_disc(discriminator, fake.detach())

        ones_target = torch.ones((x.size(0), 1), device=x.device)
        zeros_target = torch.zeros((x.size(0), 1), device=x.device)

        # D_real_accuracy = torch.sum(torch.round(F.sigmoid(D_real_scores)) == ones_target).float() / ones_target.size(0)
        # D_fake_accuracy = torch.sum(torch.round(F.sigmoid(D_fake_scores)) == zeros_target).float() / zeros_target.size(0)

        D_real_loss = F.binary_cross_entropy_with_logits(
            D_real_scores, ones_target)
        D_fake_loss = F.binary_cross_entropy_with_logits(
            D_fake_scores, zeros_target)

        D_loss = (D_real_loss + D_fake_loss) / 2
        gp = gradient_penalty(x.detach(), fake.detach(),
                              lambda _x: run_noised_disc(discriminator, _x))
        D_loss_plus_gp = D_loss + 10 * gp
        D_optimizer.zero_grad()
        D_loss_plus_gp.backward()
        D_optimizer.step()

        # Train generator
        fake = generate_from_noise(x.size(0))
        G_loss = F.binary_cross_entropy_with_logits(
            run_noised_disc(discriminator, fake),
            torch.ones((x.size(0), 1), device=x.device))
        losses['total_loss'] = G_loss

        # G-step
        optimizer.zero_grad()
        losses['total_loss'].backward()
        params = list(model.parameters())
        gnorm = [p.grad.norm() for p in params]
        optimizer.step()
        # if max_grad_clip > 0:
        #     torch.nn.utils.clip_grad_value_(model.parameters(), max_grad_clip)
        # if max_grad_norm > 0:
        #     torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        if engine.iter_ind % 50 == 0:
            grid = make_grid((postprocess(fake.detach().cpu())[:30]),
                             nrow=6).permute(1, 2, 0)
            plt.figure(figsize=(10, 10))
            plt.imshow(grid)
            plt.axis('off')
            plt.savefig(
                os.path.join(output_dir, f'sample_{engine.iter_ind}.png'))

            grid = make_grid(
                (postprocess(uniform_binning_correction(x)[0].cpu())[:30]),
                nrow=6).permute(1, 2, 0)
            plt.figure(figsize=(10, 10))
            plt.imshow(grid)
            plt.axis('off')
            plt.savefig(os.path.join(output_dir,
                                     f'data_{engine.iter_ind}.png'))

        return losses

    def eval_step(engine, batch):
        model.eval()

        x, y = batch
        x = x.to(device)

        with torch.no_grad():
            if y_condition:
                y = y.to(device)
                z, nll, y_logits = model(x, y)
                losses = compute_loss_y(nll,
                                        y_logits,
                                        y_weight,
                                        y,
                                        multi_class,
                                        reduction='none')
            else:
                z, nll, y_logits = model(x, None)
                losses = compute_loss(nll, reduction='none')

        return losses

    if gan:
        trainer = Engine(gan_step)
    else:
        trainer = Engine(step)
    checkpoint_handler = ModelCheckpoint(output_dir,
                                         'glow',
                                         save_interval=1,
                                         n_saved=2,
                                         require_empty=False)

    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler, {
        'model': model,
        'optimizer': optimizer
    })

    monitoring_metrics = ['total_loss']
    RunningAverage(output_transform=lambda x: x['total_loss']).attach(
        trainer, 'total_loss')

    evaluator = Engine(eval_step)

    # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
    Loss(lambda x, y: torch.mean(x),
         output_transform=lambda x:
         (x['total_loss'], torch.empty(x['total_loss'].shape[0]))).attach(
             evaluator, 'total_loss')

    if y_condition:
        monitoring_metrics.extend(['nll'])
        RunningAverage(output_transform=lambda x: x['nll']).attach(
            trainer, 'nll')

        # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
        Loss(lambda x, y: torch.mean(x),
             output_transform=lambda x:
             (x['nll'], torch.empty(x['nll'].shape[0]))).attach(
                 evaluator, 'nll')

    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=monitoring_metrics)

    # load pre-trained model if given
    if saved_model:
        model.load_state_dict(torch.load(saved_model))
        model.set_actnorm_init()

        if saved_optimizer:
            optimizer.load_state_dict(torch.load(saved_optimizer))

        file_name, ext = os.path.splitext(saved_model)
        resume_epoch = int(file_name.split('_')[-1])

        @trainer.on(Events.STARTED)
        def resume_training(engine):
            engine.state.epoch = resume_epoch
            engine.state.iteration = resume_epoch * len(
                engine.state.dataloader)

    # @trainer.on(Events.STARTED)
    # def init(engine):
    #     model.train()

    #     init_batches = []
    #     init_targets = []

    #     with torch.no_grad():
    #         for batch, target in islice(train_loader, None,
    #                                     n_init_batches):
    #             init_batches.append(batch)
    #             init_targets.append(target)

    #         init_batches = torch.cat(init_batches).to(device)

    #         assert init_batches.shape[0] == n_init_batches * batch_size

    #         if y_condition:
    #             init_targets = torch.cat(init_targets).to(device)
    #         else:
    #             init_targets = None

    #         model(init_batches, init_targets)

    # @trainer.on(Events.EPOCH_COMPLETED)
    # def evaluate(engine):
    #     evaluator.run(test_loader)

    #     # scheduler.step()
    #     metrics = evaluator.state.metrics

    #     losses = ', '.join([f"{key}: {value:.2f}" for key, value in metrics.items()])

    #     myprint(f'Validation Results - Epoch: {engine.state.epoch} {losses}')

    timer = Timer(average=True)
    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        pbar.log_message(
            f'Epoch {engine.state.epoch} done. Time per batch: {timer.value():.3f}[s]'
        )
        timer.reset()

    trainer.run(train_loader, epochs)
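
Both the Glow and the GAN step functions above return a dict of named losses, which `RunningAverage` and `ProgressBar` then turn into a smoothed, live display. A minimal sketch of just that monitoring wiring, with a hypothetical toy update:

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from ignite.engine import Engine
from ignite.metrics import RunningAverage
from ignite.contrib.handlers import ProgressBar

model = torch.nn.Linear(6, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loader = DataLoader(TensorDataset(torch.randn(48, 6), torch.randn(48, 1)), batch_size=8)

def step(engine, batch):
    x, y = batch
    optimizer.zero_grad()
    loss = F.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()
    # Return a dict of named losses, as the Glow/GAN steps above do.
    return {'total_loss': loss.item()}

trainer = Engine(step)
RunningAverage(output_transform=lambda x: x['total_loss']).attach(trainer, 'total_loss')
ProgressBar().attach(trainer, metric_names=['total_loss'])
trainer.run(loader, max_epochs=2)
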
Beispiel #25
0
def attach_handlers(run, model, optimizer, trainer, train_evaluator, evaluator,
                    train_loader, val_loader, params):
    # Tqdm logger
    pbar = ProgressBar(persist=True, bar_format=config.IGNITE_BAR_FORMAT)
    pbar.attach(trainer.engine, metric_names='all')
    tqdm_logger = TqdmLogger(pbar=pbar)
    # noinspection PyTypeChecker
    tqdm_logger.attach_output_handler(
        evaluator.engine,
        event_name=Events.COMPLETED,
        tag="validation",
        global_step_transform=global_step_from_engine(trainer.engine),
    )
    # noinspection PyTypeChecker
    tqdm_logger.attach_output_handler(
        train_evaluator.engine,
        event_name=Events.COMPLETED,
        tag="train",
        global_step_transform=global_step_from_engine(trainer.engine),
    )

    # Evaluators
    train_evaluator.attach(trainer.engine, Events.EPOCH_COMPLETED,
                           train_loader)
    evaluator.attach(trainer.engine, Events.EPOCH_COMPLETED, data=val_loader)

    # Learning rate scheduling
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                              'max',
                                                              verbose=True,
                                                              patience=5,
                                                              factor=0.5)
    evaluator.engine.add_event_handler(
        Events.COMPLETED,
        lambda engine: lr_scheduler.step(engine.state.metrics['accuracy']))

    # Early stopping
    es_handler = EarlyStopping(
        patience=15,
        score_function=lambda engine: engine.state.metrics['accuracy'],
        trainer=trainer.engine,
        cumulative_delta=True,
        min_delta=0.0001)
    if 'train_all' in params and params['train_all']:
        train_evaluator.engine.add_event_handler(Events.COMPLETED, es_handler)
    else:
        evaluator.engine.add_event_handler(Events.COMPLETED, es_handler)

    es_handler.logger.setLevel(logging.DEBUG)

    # Model checkpoints
    name = run.replace('/', '-')
    mc_handler = ModelCheckpoint(
        config.MODELS_DIR,
        name,
        n_saved=1,
        create_dir=True,
        require_empty=False,
        score_name='acc',
        score_function=lambda engine: engine.state.metrics['accuracy'],
        global_step_transform=global_step_from_engine(trainer.engine))
    evaluator.engine.add_event_handler(Events.EPOCH_COMPLETED, mc_handler,
                                       {'m': model})

    # TensorBoard logger
    tb_logger = TensorboardLogger(
        log_dir=os.path.join(config.TENSORBOARD_DIR, run))
    images, labels = next(iter(train_loader))
    tb_logger.writer.add_graph(copy.deepcopy(model).cpu(), images)
    tb_logger.writer.add_hparams(params, {'hparam/dummy': 0})

    # noinspection PyTypeChecker
    tb_logger.attach_output_handler(
        train_evaluator.engine,
        event_name=Events.COMPLETED,
        tag="train",
        metric_names="all",
        global_step_transform=global_step_from_engine(trainer.engine),
    )
    # noinspection PyTypeChecker
    tb_logger.attach_output_handler(
        evaluator.engine,
        event_name=Events.COMPLETED,
        tag="validation",
        metric_names="all",
        global_step_transform=global_step_from_engine(trainer.engine),
    )
    input_shape = tuple(next(iter(train_loader))[0].shape[1:])
    tb_logger.attach(trainer.engine,
                     log_handler=WeightsImageHandler(model, input_shape),
                     event_name=Events.EPOCH_COMPLETED)
    tb_logger.attach(trainer.engine,
                     log_handler=OptimizerParamsHandler(optimizer),
                     event_name=Events.EPOCH_STARTED)
    # tb_logger.attach(trainer.engine, log_handler=WeightsScalarHandler(model), event_name=Events.EPOCH_COMPLETED)
    # tb_logger.attach(trainer.engine, log_handler=WeightsHistHandler(model), event_name=Events.EPOCH_COMPLETED)
    # tb_logger.attach(trainer.engine,
    #                  log_handler=ActivationsHistHandler(model, layer_names=['linear1', 'batch_norm', 'repu']),
    #                  event_name=Events.ITERATION_COMPLETED)
    # tb_logger.attach(trainer.engine,
    #                  log_handler=NumActivationsScalarHandler(model, layer_names=['linear1', 'repu']),
    #                  event_name=Events.ITERATION_COMPLETED)
    # tb_logger.attach(trainer.engine,
    #                  log_handler=ActivationsScalarHandler(model, reduction=torch.mean,
    #                                                       layer_names=['linear1', 'batch_norm', 'repu']),
    #                  event_name=Events.ITERATION_COMPLETED)
    # tb_logger.attach(trainer.engine,
    #                  log_handler=ActivationsScalarHandler(model, reduction=torch.std,
    #                                                       layer_names=['linear1', 'batch_norm', 'repu']),
    #                  event_name=Events.ITERATION_COMPLETED)

    return es_handler, tb_logger
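
Early stopping and best-model checkpointing above are both driven by the same evaluator metric through a `score_function`. A minimal sketch of that coupling with a hypothetical toy classifier (the checkpoint directory is a placeholder):

import torch
from torch.utils.data import DataLoader, TensorDataset
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Accuracy
from ignite.handlers import EarlyStopping, ModelCheckpoint, global_step_from_engine

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = torch.nn.CrossEntropyLoss()
X, y = torch.randn(128, 10), torch.randint(0, 2, (128,))
train_loader = DataLoader(TensorDataset(X, y), batch_size=16)
val_loader = DataLoader(TensorDataset(X[:32], y[:32]), batch_size=16)

trainer = create_supervised_trainer(model, optimizer, criterion)
evaluator = create_supervised_evaluator(model, metrics={'accuracy': Accuracy()})
trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: evaluator.run(val_loader))

# Both handlers score the run with the evaluator's accuracy.
score_fn = lambda engine: engine.state.metrics['accuracy']
evaluator.add_event_handler(
    Events.COMPLETED,
    EarlyStopping(patience=5, score_function=score_fn, trainer=trainer))
checkpointer = ModelCheckpoint('/tmp/models', 'sketch', n_saved=1, require_empty=False,
                               score_name='acc', score_function=score_fn,
                               global_step_transform=global_step_from_engine(trainer))
evaluator.add_event_handler(Events.COMPLETED, checkpointer, {'model': model})

trainer.run(train_loader, max_epochs=10)
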
Beispiel #26
0
def train(device, net, dataloader, val_loader, args, logger, experiment):
    def update(engine, data):
        input_left, input_right, label = data['left_image'], data['right_image'], data['winner']
        input_left, input_right, label = input_left.to(device), input_right.to(device), label.to(device)
        # zero the parameter gradients
        optimizer.zero_grad()
        label = label.float()

        start = timer()
        output_rank_left, output_rank_right = net(input_left,input_right)
        end = timer()
        logger.info(f'FORWARD,{end-start:.4f}')

        #compute ranking loss
        start = timer()
        loss = compute_ranking_loss(output_rank_left, output_rank_right, label, rank_crit)
        end = timer()

        logger.info(f'LOSS,{end-start:.4f}')

        #compute ranking accuracy
        start = timer()
        rank_acc = compute_ranking_accuracy(output_rank_left, output_rank_right, label)
        end = timer()
        logger.info(f'RANK-ACC,{end-start:.4f}')

        # backward step
        start = timer()
        loss.backward()
        optimizer.step()
        end = timer()
        logger.info(f'BACKWARD,{end-start:.4f}')
        scheduler.step()
        return {'loss': loss.item(),
                'rank_acc': rank_acc}

    def inference(engine,data):
        with torch.no_grad():
            start = timer()
            input_left, input_right, label = data['left_image'], data['right_image'], data['winner']
            input_left, input_right, label = input_left.to(device), input_right.to(device), label.to(device)
            label = label.float()
            output_rank_left, output_rank_right = net(input_left,input_right)
            loss = compute_ranking_loss(output_rank_left, output_rank_right, label, rank_crit)
            rank_acc = compute_ranking_accuracy(output_rank_left, output_rank_right, label)
            end = timer()
            logger.info(f'INFERENCE,{end-start:.4f}')
            return {'loss': loss.item(),
                    'rank_acc': rank_acc}
    net = net.to(device)
    if args.equal:
        rank_crit = RankingLoss(margin=1, tie_margin=0)
        print("using new loss")
    else:
        rank_crit = nn.MarginRankingLoss(reduction='mean', margin=1)
    #optimizer = optim.SGD(net.parameters(), lr=args.lr, weight_decay=args.wd, momentum=0.9)
    optimizer = optim.Adam(net.parameters(), lr=args.lr, weight_decay=args.wd, betas=(0.9, 0.98), eps=1e-09)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.995, last_epoch=-1)
    trainer = Engine(update)
    evaluator = Engine(inference)
    RunningAverage(output_transform=lambda x: x['loss']).attach(trainer, 'loss')
    RunningAverage(output_transform=lambda x: x['rank_acc']).attach(trainer, 'rank_acc')

    RunningAverage(output_transform=lambda x: x['loss']).attach(evaluator, 'loss')
    RunningAverage(output_transform=lambda x: x['rank_acc']).attach(evaluator, 'rank_acc')

    if args.pbar:
        pbar = ProgressBar(persist=False)
        pbar.attach(trainer,['loss', 'rank_acc'])

        pbar = ProgressBar(persist=False)
        pbar.attach(evaluator,['loss','rank_acc'])

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(trainer):
        net.eval()
        evaluator.run(val_loader)
        trainer.state.metrics['val_acc'] = evaluator.state.metrics['rank_acc']
        net.train()
        if hasattr(net,'partial_eval'): net.partial_eval()
        metrics = {
                'train_rank_accuracy':trainer.state.metrics['rank_acc'],
                'train_loss':trainer.state.metrics['loss'],
                'val_rank_accuracy': evaluator.state.metrics['rank_acc'],
                'val_loss':evaluator.state.metrics['loss']
            }
        comet_log(
            metrics,
            experiment,
            epoch=trainer.state.epoch,
            step=trainer.state.epoch,
        )
        console_log(metrics,{},trainer.state.epoch)

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_results(trainer):
        if trainer.state.iteration % 100 == 0:
            metrics = {
                    'train_rank_accuracy':trainer.state.metrics['rank_acc'],
                    'train_loss':trainer.state.metrics['loss'],
                    'lr': scheduler.get_lr()
                }
            comet_log(
                metrics,
                experiment,
                step=trainer.state.iteration,
                epoch=trainer.state.epoch
            )
            console_log(
                metrics,
                {},
                trainer.state.epoch,
                step=trainer.state.iteration,
            )
    model_name = '{}_{}_{}'.format(args.model, args.premodel, args.attribute)
    if args.tag: model_name += f'_{args.tag}'
    handler = ModelCheckpoint(args.model_dir, model_name,
                                n_saved=1,
                                create_dir=True,
                                save_as_state_dict=True,
                                require_empty=False,
                                score_function=lambda engine: engine.state.metrics['val_acc'])
    trainer.add_event_handler(Events.EPOCH_COMPLETED, handler, {
                'model': net
                })

    if (args.resume):
        def start_epoch(engine):
            engine.state.epoch = args.epoch
        trainer.add_event_handler(Events.STARTED, start_epoch)
        evaluator.add_event_handler(Events.STARTED, start_epoch)

    trainer.run(dataloader,max_epochs=args.max_epochs)
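
The core objective above is pairwise: the siamese network scores the left and right images, and `nn.MarginRankingLoss` pushes the higher-ranked side above the other by a margin. A minimal sketch of the loss and the ranking-accuracy computation outside the Ignite engines, assuming the standard +1/-1 label convention (the repository's `compute_ranking_loss`/`compute_ranking_accuracy` helpers may encode ties differently):

import torch
from torch import nn

scorer = nn.Linear(16, 1)                                   # hypothetical stand-in for one siamese branch
left, right = torch.randn(32, 16), torch.randn(32, 16)
label = torch.randint(0, 2, (32,)).float() * 2 - 1          # +1 if left should rank higher, else -1

score_left = scorer(left).squeeze(1)
score_right = scorer(right).squeeze(1)

rank_crit = nn.MarginRankingLoss(margin=1, reduction='mean')
loss = rank_crit(score_left, score_right, label)

# Ranking accuracy: fraction of pairs whose predicted ordering matches the label.
pred = torch.sign(score_left - score_right)
rank_acc = (pred == label).float().mean()
print(loss.item(), rank_acc.item())
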
Beispiel #27
0
def _setup_common_training_handlers(
    trainer,
    to_save=None,
    save_every_iters=1000,
    output_path=None,
    lr_scheduler=None,
    with_gpu_stats=False,
    output_names=None,
    with_pbars=True,
    with_pbar_on_iters=True,
    log_every_iters=100,
    stop_on_nan=True,
    clear_cuda_cache=True,
):
    if stop_on_nan:
        trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())

    if lr_scheduler is not None:
        if isinstance(lr_scheduler, torch.optim.lr_scheduler._LRScheduler):
            trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                      lambda engine: lr_scheduler.step())
        elif isinstance(lr_scheduler, LRScheduler):
            trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler)
        else:
            trainer.add_event_handler(Events.ITERATION_STARTED, lr_scheduler)

    if torch.cuda.is_available() and clear_cuda_cache:
        trainer.add_event_handler(Events.EPOCH_COMPLETED, empty_cuda_cache)

    if to_save is not None:
        if output_path is None:
            raise ValueError(
                "If to_save argument is provided then output_path argument should be also defined"
            )
        checkpoint_handler = Checkpoint(
            to_save,
            DiskSaver(dirname=output_path, require_empty=False),
            filename_prefix="training",
        )
        trainer.add_event_handler(
            Events.ITERATION_COMPLETED(every=save_every_iters),
            checkpoint_handler)

    if with_gpu_stats:
        GpuInfo().attach(
            trainer,
            name="gpu",
            event_name=Events.ITERATION_COMPLETED(every=log_every_iters))

    if output_names is not None:

        def output_transform(x, index, name):
            if isinstance(x, Mapping):
                return x[name]
            elif isinstance(x, Sequence):
                return x[index]
            elif isinstance(x, (torch.Tensor, numbers.Number)):
                return x
            else:
                raise ValueError(
                    "Unhandled type of update_function's output. "
                    "It should be either a mapping or a sequence, but {} was given".
                    format(type(x)))

        for i, n in enumerate(output_names):
            RunningAverage(output_transform=partial(output_transform,
                                                    index=i,
                                                    name=n),
                           epoch_bound=False).attach(trainer, n)

    if with_pbars:
        if with_pbar_on_iters:
            ProgressBar(persist=False).attach(
                trainer,
                metric_names="all",
                event_name=Events.ITERATION_COMPLETED(every=log_every_iters))

        ProgressBar(persist=True,
                    bar_format="").attach(trainer,
                                          event_name=Events.EPOCH_STARTED,
                                          closing_event_name=Events.COMPLETED)
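
The `output_names` branch above fans a dict- or sequence-valued update output out into one `RunningAverage` per name via `functools.partial`. A minimal sketch of that wiring with a hypothetical two-term loss:

from functools import partial

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from ignite.engine import Engine
from ignite.metrics import RunningAverage

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loader = DataLoader(TensorDataset(torch.randn(32, 4), torch.randn(32, 1)), batch_size=4)

def update(engine, batch):
    x, y = batch
    optimizer.zero_grad()
    mse = F.mse_loss(model(x), y)
    l1 = sum(p.abs().sum() for p in model.parameters())
    (mse + 1e-4 * l1).backward()
    optimizer.step()
    return {'mse': mse.item(), 'l1': l1.item()}

def output_transform(output, name):
    return output[name]

trainer = Engine(update)
for name in ['mse', 'l1']:
    RunningAverage(output_transform=partial(output_transform, name=name)).attach(trainer, name)

trainer.run(loader, max_epochs=2)
print(trainer.state.metrics)   # running averages of both named outputs
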
Beispiel #28
0
    if args.mixed_precision:
        (model, loss_fn), optimizer = amp.initialize([model, loss_fn],
                                                     optimizer)

    if args.distributed:
        model = DistributedDataParallel(model, device_ids=[local_rank])

    trainer = create_sr_trainer(
        model,
        loss_fn,
        optimizer,
        device=device,
        mixed_precision=args.mixed_precision,
    )
    ProgressBar(persist=False).attach(trainer, ['loss'])
    trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())
    trainer.add_event_handler(Events.EPOCH_COMPLETED,
                              lambda _engine: scheduler.step())

    evaluator = create_sr_evaluator(
        model,
        device=device,
        mean=MEAN,
    )

    if local_rank == 0:
        checkpointer = ModelCheckpoint(
            dirname='checkpoints',
            filename_prefix='model',
            score_name='pnsr',
Beispiel #29
0
    def train(self, config, **kwargs):
        """Trains a given model specified in the config file or passed as the --model parameter.
        All options in the config file can be overwritten as needed by passing --PARAM
        Options with variable lengths ( e.g., kwargs can be passed by --PARAM '{"PARAM1":VAR1, "PARAM2":VAR2}'

        :param config: yaml config file
        :param **kwargs: parameters to overwrite yaml config
        """

        config_parameters = utils.parse_config_or_kwargs(config, **kwargs)
        outputdir = os.path.join(
            config_parameters['outputpath'], config_parameters['model'],
            "{}_{}".format(
                datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'),
                uuid.uuid1().hex))
        # Create base dir
        Path(outputdir).mkdir(exist_ok=True, parents=True)

        logger = utils.getfile_outlogger(os.path.join(outputdir, 'train.log'))
        logger.info("Storing files in {}".format(outputdir))
        # Log the resolved configuration
        utils.pprint_dict(config_parameters, logger.info)
        logger.info("Running on device {}".format(DEVICE))
        labels_df = pd.read_csv(config_parameters['label'],
                                sep=r'\s+').convert_dtypes()
        # For the AVE dataset the filename column is an integer index and is
        # kept as is; otherwise absolute paths are reduced to their basenames.
        if not np.all(labels_df['filename'].str.isnumeric()):
            labels_df.loc[:, 'filename'] = labels_df['filename'].apply(
                os.path.basename)
        encoder = utils.train_labelencoder(labels=labels_df['event_labels'])
        # label_array is only needed when the train/cv split mode is 'stratified'
        label_array, _ = utils.encode_labels(labels_df['event_labels'],
                                             encoder)
        if 'cv_label' in config_parameters:
            cv_df = pd.read_csv(config_parameters['cv_label'],
                                sep=r'\s+').convert_dtypes()
            if not np.all(cv_df['filename'].str.isnumeric()):
                cv_df.loc[:, 'filename'] = cv_df['filename'].apply(
                    os.path.basename)
            train_df = labels_df
            logger.info(
                f"Using CV labels from {config_parameters['cv_label']}")
        else:
            train_df, cv_df = utils.split_train_cv(
                labels_df, y=label_array, **config_parameters['data_args'])

        if 'cv_data' in config_parameters:
            cv_data = config_parameters['cv_data']
            logger.info(f"Using CV data {config_parameters['cv_data']}")
        else:
            cv_data = config_parameters['data']

        train_label_array, _ = utils.encode_labels(train_df['event_labels'],
                                                   encoder)
        cv_label_array, _ = utils.encode_labels(cv_df['event_labels'], encoder)

        transform = utils.parse_transforms(config_parameters['transforms'])
        utils.pprint_dict({'Classes': encoder.classes_},
                          logger.info,
                          formatter='pretty')
        torch.save(encoder, os.path.join(outputdir, 'run_encoder.pth'))
        torch.save(config_parameters, os.path.join(outputdir,
                                                   'run_config.pth'))
        logger.info("Transforms:")
        utils.pprint_dict(transform, logger.info, formatter='pretty')
        # For the unbalanced AudioSet subset a balancing sampler is usually configured here
        if 'sampler' in config_parameters and config_parameters[
                'sampler'] == 'MultiBalancedSampler':
            # Training sampler that oversamples the dataset to be roughly equally sized.
            # Calculates the mean over multiple instances; rather useful when the
            # number of classes is large.
            train_sampler = dataset.MultiBalancedSampler(
                train_label_array,
                num_samples=1 * train_label_array.shape[0],
                replacement=True)
            sampling_kwargs = {"shuffle": False, "sampler": train_sampler}
        elif 'sampler' in config_parameters and config_parameters[
                'sampler'] == 'MinimumOccupancySampler':
            # Asserts that each "batch" contains at least one instance
            train_sampler = dataset.MinimumOccupancySampler(
                train_label_array, sampling_mode='same')
            sampling_kwargs = {"shuffle": False, "sampler": train_sampler}
        else:
            sampling_kwargs = {"shuffle": True}

        logger.info("Using Sampler {}".format(sampling_kwargs))

        trainloader = dataset.getdataloader(
            {
                'filename': train_df['filename'].values,
                'encoded': train_label_array
            },
            config_parameters['data'],
            transform=transform,
            batch_size=config_parameters['batch_size'],
            colname=config_parameters['colname'],
            num_workers=config_parameters['num_workers'],
            **sampling_kwargs)

        cvdataloader = dataset.getdataloader(
            {
                'filename': cv_df['filename'].values,
                'encoded': cv_label_array
            },
            cv_data,
            transform=None,
            shuffle=False,
            colname=config_parameters['colname'],
            batch_size=config_parameters['batch_size'],
            num_workers=config_parameters['num_workers'])
        # Fall back to the CRNN class (assumed to be defined in `models`) if the
        # configured model name cannot be found in the models module.
        model_cls = getattr(models, config_parameters['model'], models.CRNN)
        model = model_cls(inputdim=trainloader.dataset.datadim,
                          outputdim=len(encoder.classes_),
                          **config_parameters['model_args'])
        if 'pretrained' in config_parameters and config_parameters[
                'pretrained'] is not None:
            models.load_pretrained(model,
                                   config_parameters['pretrained'],
                                   outputdim=len(encoder.classes_))
            logger.info("Loading pretrained model {}".format(
                config_parameters['pretrained']))

        model = model.to(DEVICE)
        if config_parameters['optimizer'] == 'AdaBound':
            try:
                import adabound
                optimizer = adabound.AdaBound(
                    model.parameters(), **config_parameters['optimizer_args'])
            except ImportError:
                # adabound is not installed: fall back to plain Adam so that
                # `optimizer` is still defined below
                config_parameters['optimizer'] = 'Adam'
                config_parameters['optimizer_args'] = {}
                optimizer = torch.optim.Adam(model.parameters())
        else:
            optimizer = getattr(
                torch.optim,
                config_parameters['optimizer'],
            )(model.parameters(), **config_parameters['optimizer_args'])

        utils.pprint_dict(optimizer, logger.info, formatter='pretty')
        utils.pprint_dict(model, logger.info, formatter='pretty')
        if DEVICE.type != 'cpu' and torch.cuda.device_count() > 1:
            logger.info("Using {} GPUs!".format(torch.cuda.device_count()))
            model = torch.nn.DataParallel(model)
        criterion = getattr(losses, config_parameters['loss'])().to(DEVICE)

        def _train_batch(_, batch):
            model.train()
            with torch.enable_grad():
                optimizer.zero_grad()
                output = self._forward(
                    model, batch)  # output is tuple (clip, frame, target)
                loss = criterion(*output)
                loss.backward()
                # Single loss
                optimizer.step()
                return loss.item()

        def _inference(_, batch):
            model.eval()
            with torch.no_grad():
                return self._forward(model, batch)

        def thresholded_output_transform(output):
            # Output is (clip, frame, target)
            y_pred, _, y = output
            y_pred = torch.round(y_pred)
            return y_pred, y

        precision = Precision(thresholded_output_transform, average=False)
        recall = Recall(thresholded_output_transform, average=False)
        f1_score = (precision * recall * 2 / (precision + recall)).mean()
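        # Note: ignite metrics support arithmetic composition, so the expression
        # above yields a MetricsLambda that computes macro-averaged F1 from the
        # per-class Precision and Recall (average=False keeps per-class values,
        # and .mean() averages over classes).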
        metrics = {
            'Loss': losses.Loss(
                criterion),  # reimplementation of Loss that supports the 3-way (clip, frame, target) output
            'Precision': Precision(thresholded_output_transform),
            'Recall': Recall(thresholded_output_transform),
            'Accuracy': Accuracy(thresholded_output_transform),
            'F1': f1_score,
        }
        train_engine = Engine(_train_batch)
        inference_engine = Engine(_inference)
        for name, metric in metrics.items():
            metric.attach(inference_engine, name)

        def compute_metrics(engine):
            inference_engine.run(cvdataloader)
            results = inference_engine.state.metrics
            output_str_list = [
                "Validation Results - Epoch : {:<5}".format(engine.state.epoch)
            ]
            for metric in metrics:
                output_str_list.append("{} {:<5.2f}".format(
                    metric, results[metric]))
            logger.info(" ".join(output_str_list))

        pbar = ProgressBar(persist=False)
        pbar.attach(train_engine)

        if 'itercv' in config_parameters and config_parameters[
                'itercv'] is not None:
            train_engine.add_event_handler(
                Events.ITERATION_COMPLETED(every=config_parameters['itercv']),
                compute_metrics)
        train_engine.add_event_handler(Events.EPOCH_COMPLETED, compute_metrics)

        # By default the scheduler uses patience=3, factor=0.1
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, **config_parameters['scheduler_args'])

        @inference_engine.on(Events.EPOCH_COMPLETED)
        def update_reduce_on_plateau(engine):
            logger.info(f"Scheduling epoch {engine.state.epoch}")
            val_loss = engine.state.metrics['Loss']
            if 'ReduceLROnPlateau' == scheduler.__class__.__name__:
                scheduler.step(val_loss)
            else:
                scheduler.step()

        early_stop_handler = EarlyStopping(
            patience=config_parameters['early_stop'],
            score_function=self._negative_loss,
            trainer=train_engine)
        inference_engine.add_event_handler(Events.EPOCH_COMPLETED,
                                           early_stop_handler)
        if config_parameters['save'] == 'everyepoch':
            checkpoint_handler = ModelCheckpoint(outputdir,
                                                 'run',
                                                 n_saved=5,
                                                 require_empty=False)
            train_engine.add_event_handler(Events.EPOCH_COMPLETED,
                                           checkpoint_handler, {
                                               'model': model,
                                           })
            if 'itercv' in config_parameters and config_parameters[
                    'itercv'] is not None:
                train_engine.add_event_handler(
                    Events.ITERATION_COMPLETED(
                        every=config_parameters['itercv']),
                    checkpoint_handler, {
                        'model': model,
                    })
        else:
            checkpoint_handler = ModelCheckpoint(
                outputdir,
                'run',
                n_saved=1,
                require_empty=False,
                score_function=self._negative_loss,
                global_step_transform=global_step_from_engine(
                    train_engine),  # Just so that model is saved with epoch...
                score_name='loss')
            inference_engine.add_event_handler(Events.EPOCH_COMPLETED,
                                               checkpoint_handler, {
                                                   'model': model,
                                               })

        train_engine.run(trainloader, max_epochs=config_parameters['epochs'])
        return outputdir
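
For reference, a condensed, self-contained sketch of the two-engine pattern used above; the dummy model, data, directory and hyperparameters are invented, and it assumes a recent ignite version where global_step_from_engine lives in ignite.handlers. The evaluator runs after every training epoch and drives ReduceLROnPlateau, early stopping and best-model checkpointing through a negative-loss score function:

import torch
from ignite.engine import Engine, Events
from ignite.handlers import (EarlyStopping, ModelCheckpoint,
                             global_step_from_engine)
from ignite.metrics import Loss

model = torch.nn.Linear(10, 1)
optimizer = torch.optim.Adam(model.parameters())
criterion = torch.nn.MSELoss()
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3)
data = [(torch.randn(4, 10), torch.randn(4, 1)) for _ in range(8)]


def train_step(engine, batch):
    model.train()
    x, y = batch
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()
    return loss.item()


def eval_step(engine, batch):
    model.eval()
    with torch.no_grad():
        x, y = batch
        return model(x), y


train_engine, inference_engine = Engine(train_step), Engine(eval_step)
Loss(criterion).attach(inference_engine, 'Loss')


@train_engine.on(Events.EPOCH_COMPLETED)
def validate(engine):
    inference_engine.run(data)
    # ReduceLROnPlateau needs the validation loss; other schedulers do not.
    scheduler.step(inference_engine.state.metrics['Loss'])


def negative_loss(engine):
    return -engine.state.metrics['Loss']


inference_engine.add_event_handler(
    Events.COMPLETED,
    EarlyStopping(patience=5, score_function=negative_loss,
                  trainer=train_engine))
inference_engine.add_event_handler(
    Events.COMPLETED,
    ModelCheckpoint('checkpoints_sketch', 'run', n_saved=1,
                    require_empty=False, score_function=negative_loss,
                    score_name='loss',
                    global_step_transform=global_step_from_engine(train_engine)),
    {'model': model})

train_engine.run(data, max_epochs=10)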
Example #30
def main(args):
    # Load a pre-defined tokenizer (GPT-2), create config and model
    logger.info("Prepare tokenizer, pretrained model and optimizer - "
                "add special tokens for fine-tuning")

    gpt_tokenizer = GPT2Tokenizer.from_pretrained(args.qgen_model_path,
                                                  cache_dir=args.dataset_cache)
    gpt_tokenizer.sep_token = '<sep>'

    gpt_tokenizer.add_tokens(SPECIAL_TOKENS)
    gpt_tokenizer.add_tokens(AMR_SPECIAL_TOKENS)
    if 'amr' in args.dataset_type:
        qgen = GPT2LMHeadModel.from_pretrained(args.qgen_model_path,
                                               cache_dir=args.dataset_cache)
    else:
        qgen = GPT2ConditionalLMHeadModel.from_pretrained(
            args.qgen_model_path, cache_dir=args.dataset_cache)

    logger.info("Adjust model size to new tokens")
    qgen.resize_token_embeddings(len(gpt_tokenizer))
    logger.info("Set model to GPU usage")
    qgen.to(args.device)
    logger.info("Set up optimizer")
    qgen_optimizer = AdamW(qgen.parameters(),
                           lr=args.learning_rate,
                           eps=args.adam_epsilon)

    bos, eos, ctx, ans, que, pad, gen = \
        gpt_tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS)

    # Multi-GPU DataParallel is deliberately disabled here; restore the original
    # condition (args.n_gpu > 1) to re-enable it.
    if False:
        logger.info("More than 1 GPU for training")
        qgen = torch.nn.DataParallel(qgen)

    logger.info("Prepare datasets")
    if args.use_silver_data:
        data_type = 'Silver'
    else:
        data_type = 'Train'

    dataloader = get_data_loaders(args,
                                  gpt_tokenizer,
                                  qgen,
                                  dataset_name=data_type)

    # Define training function
    def update(engine, batch):

        # remove extra pad from batches
        batch = trim_batch(batch, pad)

        qgen.train()

        loss = torch.tensor([0.0])
        ###################################
        # MLE training with teacher forcing
        ###################################
        if 'sl' in args.learning:
            input_ids, lm_labels, token_type_ids, attention_mask, _, _, _, _ =\
                tuple(input_tensor.to(args.device) for input_tensor in batch)
            loss_ce = qgen(input_ids=input_ids,
                           labels=lm_labels,
                           token_type_ids=token_type_ids)[0]
            loss = apply_loss(engine.state.iteration, qgen_optimizer, loss_ce,
                              args)
        return loss.item()

    trainer = Engine(update)

    # Add progressbar with loss
    RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
    ProgressBar(persist=True).attach(trainer, metric_names=['loss'])

    # Linearly decrease the learning rate from lr to zero
    scheduler = PiecewiseLinear(qgen_optimizer, "lr",
                                [(0, args.learning_rate),
                                 (args.n_epochs * len(dataloader), 0.0)])
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # Save checkpoints
    checkpoint_handler = ModelCheckpoint(args.checkpoint,
                                         'checkpoint',
                                         save_interval=1,
                                         n_saved=20,
                                         require_empty=False)

    # "getattr" takes care of the distributed encapsulation (saves qgen.module when wrapped)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler,
                              {'mymodel': getattr(qgen, 'module', qgen)})

    # save training config
    torch.save(dict(args), os.path.join(args.checkpoint, 'training_args.bin'))
    getattr(qgen, 'module', qgen).config.to_json_file(
        os.path.join(args.checkpoint, CONFIG_NAME))
    gpt_tokenizer.save_vocabulary(args.checkpoint)

    trainer.run(dataloader, max_epochs=args.n_epochs)
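
For reference, a minimal sketch of the PiecewiseLinear learning-rate decay attached above; the dummy model, data and hyperparameters are invented, and in newer ignite versions the import is from ignite.handlers instead of ignite.contrib.handlers:

import torch
from ignite.engine import Engine, Events
# In newer ignite versions: from ignite.handlers import PiecewiseLinear
from ignite.contrib.handlers import PiecewiseLinear

model = torch.nn.Linear(8, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
dataloader = [(torch.randn(2, 8), torch.randn(2, 1))] * 50
n_epochs = 3


def update(engine, batch):
    model.train()
    x, y = batch
    loss = torch.nn.functional.mse_loss(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


trainer = Engine(update)
# Linearly decrease the learning rate from its initial value to zero over
# n_epochs * len(dataloader) iterations.
scheduler = PiecewiseLinear(optimizer, "lr",
                            [(0, 5e-5), (n_epochs * len(dataloader), 0.0)])
trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
trainer.run(dataloader, max_epochs=n_epochs)
print(optimizer.param_groups[0]["lr"])  # close to 0.0 by the end of training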