Example #1
        with torch.no_grad():
            # Anneal learning rate
            mu = next(mu_scheme)
            i = engine.state.iteration
            for group in optimizer.param_groups:
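                # Adam-style bias correction: sqrt(1 - beta2**i) / (1 - beta1**i) with beta1=0.9, beta2=0.999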
                group["lr"] = mu * math.sqrt(1 - 0.999**i) / (1 - 0.9**i)

        return {
            "elbo": elbo.item(),
            "kl": kl_divergence.item(),
            "sigma": sigma,
            "mu": mu
        }

    # Trainer and metrics
    trainer = Engine(step)
    metric_names = ["elbo", "kl", "sigma", "mu"]
    RunningAverage(output_transform=lambda x: x["elbo"]).attach(
        trainer, "elbo")
    RunningAverage(output_transform=lambda x: x["kl"]).attach(trainer, "kl")
    RunningAverage(output_transform=lambda x: x["sigma"]).attach(
        trainer, "sigma")
    RunningAverage(output_transform=lambda x: x["mu"]).attach(trainer, "mu")
    ProgressBar().attach(trainer, metric_names=metric_names)

    # Model checkpointing
    checkpoint_handler = ModelCheckpoint("./",
                                         "checkpoint",
                                         save_interval=1,
                                         n_saved=3,
                                         require_empty=False)
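Example #2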
    def test_invert(self):
        set_determinism(seed=0)
        im_fname, seg_fname = [
            make_nifti_image(i)
            for i in create_test_image_3d(101, 100, 107, noise_max=100)
        ]
        transform = Compose([
            LoadImaged(KEYS),
            AddChanneld(KEYS),
            Orientationd(KEYS, "RPS"),
            Spacingd(KEYS,
                     pixdim=(1.2, 1.01, 0.9),
                     mode=["bilinear", "nearest"],
                     dtype=np.float32),
            ScaleIntensityd("image", minv=1, maxv=10),
            RandFlipd(KEYS, prob=0.5, spatial_axis=[1, 2]),
            RandAxisFlipd(KEYS, prob=0.5),
            RandRotate90d(KEYS, spatial_axes=(1, 2)),
            RandZoomd(KEYS,
                      prob=0.5,
                      min_zoom=0.5,
                      max_zoom=1.1,
                      keep_size=True),
            RandRotated(KEYS,
                        prob=0.5,
                        range_x=np.pi,
                        mode="bilinear",
                        align_corners=True),
            RandAffined(KEYS, prob=0.5, rotate_range=np.pi, mode="nearest"),
            ResizeWithPadOrCropd(KEYS, 100),
            ToTensord(KEYS),
            CastToTyped(KEYS, dtype=torch.uint8),
        ])
        data = [{"image": im_fname, "label": seg_fname} for _ in range(12)]

        # use num_workers = 0 on macOS or when CUDA (GPU transforms) is available
        num_workers = 0 if sys.platform == "darwin" or torch.cuda.is_available() else 2

        dataset = CacheDataset(data, transform=transform, progress=False)
        loader = DataLoader(dataset, num_workers=num_workers, batch_size=5)

        # set up engine
        def _train_func(engine, batch):
            self.assertTupleEqual(batch["image"].shape[1:], (1, 100, 100, 100))
            engine.state.output = batch
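            # fire MODEL_COMPLETED so handlers attached to it (the inverters below) run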
            engine.fire_event(IterationEvents.MODEL_COMPLETED)
            return engine.state.output

        engine = Engine(_train_func)
        engine.register_events(*IterationEvents)

        # set up testing handler
        TransformInverter(
            transform=transform,
            loader=loader,
            output_keys=["image", "label"],
            batch_keys="label",
            nearest_interp=True,
            postfix="inverted1",
            num_workers=0
            if sys.platform == "darwin" or torch.cuda.is_available() else 2,
        ).attach(engine)

        # test different nearest interpolation values
        TransformInverter(
            transform=transform,
            loader=loader,
            output_keys=["image", "label"],
            batch_keys="image",
            nearest_interp=[True, False],
            postfix="inverted2",
            num_workers=0
            if sys.platform == "darwin" or torch.cuda.is_available() else 2,
        ).attach(engine)

        engine.run(loader, max_epochs=1)
        set_determinism(seed=None)
        self.assertTupleEqual(engine.state.output["image"].shape,
                              (2, 1, 100, 100, 100))
        self.assertTupleEqual(engine.state.output["label"].shape,
                              (2, 1, 100, 100, 100))
        # check the nearest interpolation mode
        for i in engine.state.output["image_inverted1"] + engine.state.output["label_inverted1"]:
            torch.testing.assert_allclose(
                i.to(torch.uint8).to(torch.float), i.to(torch.float))
            self.assertTupleEqual(i.shape, (1, 100, 101, 107))
        # check labels match
        reverted = engine.state.output["label_inverted1"][-1].detach().cpu().numpy()[0].astype(np.int32)
        original = LoadImaged(KEYS)(data[-1])["label"]
        n_good = np.sum(np.isclose(reverted, original, atol=1e-3))
        reverted_name = engine.state.output["label_meta_dict"][
            "filename_or_obj"][-1]
        original_name = data[-1]["label"]
        self.assertEqual(reverted_name, original_name)
        print("invert diff", reverted.size - n_good)
        # 25300: 2 workers (cpu, non-macos)
        # 1812: 0 workers (gpu or macos)
        # 1824: torch 1.5.1
        self.assertTrue((reverted.size - n_good) in (25300, 1812, 1824),
                        "diff. in 3 possible values")

        # check the case where different output keys use different interpolation modes to invert the transforms
        for i in engine.state.output["image_inverted2"]:
            # if the interpolation mode is nearest, accumulated diff should be smaller than 1
            self.assertLess(
                torch.sum(
                    i.to(torch.float) -
                    i.to(torch.uint8).to(torch.float)).item(), 1.0)
            self.assertTupleEqual(i.shape, (1, 100, 101, 107))

        for i in engine.state.output["label_inverted2"]:
            # if the interpolation mode is not nearest, accumulated diff should be greater than 10000
            self.assertGreater(
                torch.sum(
                    i.to(torch.float) -
                    i.to(torch.uint8).to(torch.float)).item(), 10000.0)
            self.assertTupleEqual(i.shape, (1, 100, 101, 107))
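Example #3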
        ]).to(device)

        y_pred = descriminator(x_gan)

        loss = loss_fn(y_pred, y_gan)

        if args.mixed_precision:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        optimizer.step()
        return loss

    trainer = Engine(_update_model)
    RunningAverage(output_transform=lambda x: x).attach(trainer, 'loss')
    ProgressBar(persist=False).attach(trainer, ['loss'])

    if local_rank == 0:
        checkpointer = ModelCheckpoint(
            dirname='checkpoints',
            filename_prefix='model',
            score_name='loss',
            score_function=lambda engine: engine.state.metrics['loss'],
            n_saved=5,
            global_step_transform=global_step_from_engine(trainer),
        )
        trainer.add_event_handler(
            Events.COMPLETED, checkpointer,
            to_save={'descriminator': descriminator if not args.distributed else descriminator.module})
Example #4
def create_setops_evaluator(base_model,
                            classifier,
                            setops_model,
                            metrics={},
                            device=None):
    """
    Factory function for creating an evaluator for supervised models

    Args:
        model (`torch.nn.Module`): the model to train
        optimizer (`torch.optim.Optimizer`): the optimizer to use
        loss_fn (torch.nn loss function): the loss function to use
        device (str, optional): device type specification (default: None).
            Applies to both model and batches.

    Returns:
        Engine: a trainer engine with supervised update function
    """
    if device:
        base_model.to(device)
        classifier.to(device)
        setops_model.to(device)

    def _inference(engine, batch):

        base_model.eval()
        classifier.eval()
        setops_model.eval()

        with torch.no_grad():
            input_a, input_b, target_a, target_b = _prepare_batch(
                batch, device=device)

            #
            # Apply the classification model
            #
            embed_a = base_model(input_a)
            output_a = classifier(embed_a)
            embed_b = base_model(input_b)
            output_b = classifier(embed_b)

            #
            # Apply the setops model.
            #
            outputs_setopt = setops_model(embed_a, embed_b)
            fake_a, fake_b, a_S_b, b_S_a, a_U_b, b_U_a, a_I_b, b_I_a, \
            a_S_b_b, b_S_a_a, a_I_b_b, b_I_a_a, a_U_b_b, b_U_a_a, \
            a_S_b_I_a, b_S_a_I_b, a_S_a_I_b, b_S_b_I_a = \
                [classifier(o) for o in outputs_setopt]
            fake_a_em, fake_b_em = outputs_setopt[:2]

            #
            # Calculate the target setops operations
            #
            target_a_bt = target_a.type(torch.cuda.ByteTensor)
            target_b_bt = target_b.type(torch.cuda.ByteTensor)

            target_a_I_b = target_a_bt & target_b_bt
            target_a_U_b = target_a_bt | target_b_bt
            target_a_S_b = target_a_bt & ~target_a_I_b
            target_b_S_a = target_b_bt & ~target_a_I_b

            target_a_I_b = target_a_I_b.type(torch.cuda.FloatTensor)
            target_a_U_b = target_a_U_b.type(torch.cuda.FloatTensor)
            target_a_S_b = target_a_S_b.type(torch.cuda.FloatTensor)
            target_b_S_a = target_b_S_a.type(torch.cuda.FloatTensor)

            return dict(outputs={
                "real class a": output_a,
                "real class b": output_b,
                "fake class a": fake_a,
                "fake class b": fake_b,
                "a_S_b class": a_S_b,
                "b_S_a class": b_S_a,
                "a_U_b class": a_U_b,
                "b_U_a class": b_U_a,
                "a_I_b class": a_I_b,
                "b_I_a class": b_I_a,
                "fake embed a": fake_a_em,
                "fake embed b": fake_b_em,
            },
                        targets={
                            "class a": target_a,
                            "class b": target_b,
                            "a_S_b class": target_a_S_b,
                            "b_S_a class": target_b_S_a,
                            "a_U_b class": target_a_U_b,
                            "a_I_b class": target_a_I_b,
                            "embed a": embed_a,
                            "embed b": embed_b,
                        })

    engine = Engine(_inference)

    for name, metric in metrics.items():
        metric.attach(engine, name)

    return engine
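
# A minimal usage sketch for the factory above; `base_model`, `classifier`,
# `setops_model`, and `val_loader` are placeholder names (assumptions), only
# `create_setops_evaluator` itself is defined here.
evaluator = create_setops_evaluator(base_model,
                                    classifier,
                                    setops_model,
                                    metrics={},
                                    device="cuda")
state = evaluator.run(val_loader, max_epochs=1)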
Example #5
def create_trainer(model, optimizer, criterion, lr_scheduler, train_sampler,
                   config, logger):

    device = idist.device()

    # Setup Ignite trainer:
    # - let's define training step
    # - add other common handlers:
    #    - TerminateOnNan,
    #    - handler to setup learning rate scheduling,
    #    - ModelCheckpoint
    #    - RunningAverage` on `train_step` output
    #    - Two progress bars on epochs and optionally on iterations

    def train_step(engine, batch):

        x, y = batch[0], batch[1]

        if x.device != device:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

        model.train()
        # Supervised part
        y_pred = model(x)
        loss = criterion(y_pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Fetch loss.item() only every `log_every_iters` iterations: fetching it
        # every iteration can cause a performance slowdown, notably on XLA
        if config["log_every_iters"] > 0 and (
                engine.state.iteration - 1) % config["log_every_iters"] == 0:
            batch_loss = loss.item()
            engine.state.saved_batch_loss = batch_loss
        else:
            batch_loss = engine.state.saved_batch_loss

        return {
            "batch loss": batch_loss,
        }

    trainer = Engine(train_step)
    trainer.state.saved_batch_loss = -1.0
    trainer.state_dict_user_keys.append("saved_batch_loss")
    trainer.logger = logger

    to_save = {
        "trainer": trainer,
        "model": model,
        "optimizer": optimizer,
        "lr_scheduler": lr_scheduler
    }
    metric_names = [
        "batch loss",
    ]

    common.setup_common_training_handlers(
        trainer=trainer,
        train_sampler=train_sampler,
        to_save=to_save,
        save_every_iters=config["checkpoint_every"],
        save_handler=get_save_handler(config),
        lr_scheduler=lr_scheduler,
        output_names=metric_names if config["log_every_iters"] > 0 else None,
        with_pbars=False,
        clear_cuda_cache=False,
    )

    resume_from = config["resume_from"]
    if resume_from is not None:
        checkpoint_fp = Path(resume_from)
        assert checkpoint_fp.exists(), f"Checkpoint '{checkpoint_fp.as_posix()}' is not found"
        logger.info(f"Resume from a checkpoint: {checkpoint_fp.as_posix()}")
        checkpoint = torch.load(checkpoint_fp.as_posix(), map_location="cpu")
        Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    return trainer
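
# A minimal sketch of driving this factory. The model/optimizer/loader names are
# placeholders; the config keys shown are the ones `create_trainer` itself reads
# (`get_save_handler` may require additional keys).
config = {
    "log_every_iters": 10,     # fetch loss.item() only every 10 iterations
    "checkpoint_every": 1000,  # save_every_iters for the common handlers
    "resume_from": None,       # or a path to a checkpoint to resume from
}
trainer = create_trainer(model, optimizer, criterion, lr_scheduler,
                         train_sampler, config, logger)
trainer.run(train_loader, max_epochs=10)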
Example #6
def test_pbar_fail_with_non_callable_transform():
    engine = Engine(update_fn)
    pbar = ProgressBar()

    with pytest.raises(TypeError):
        pbar.attach(engine, output_transform=1)
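
# For contrast, a sketch of a valid call: `output_transform` must be a callable
# that maps the engine's output to the values to display.
pbar = ProgressBar()
pbar.attach(engine, output_transform=lambda out: {"loss": out})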
Example #7
def _test_distrib_on_metric(device):
    import torch.distributed as dist

    rank = dist.get_rank()
    n_iters = 10
    n_epochs = 3
    batch_size = 10
    n_classes = 10

    data = list(range(n_iters))
    np.random.seed(12)
    all_y_true_batch_values = np.random.randint(0,
                                                n_classes,
                                                size=(dist.get_world_size(),
                                                      n_epochs * n_iters,
                                                      batch_size))
    all_y_pred_batch_values = np.random.rand(dist.get_world_size(),
                                             n_epochs * n_iters, batch_size,
                                             n_classes)

    y_true_batch_values = iter(all_y_true_batch_values[rank, ...])
    y_pred_batch_values = iter(all_y_pred_batch_values[rank, ...])

    def update_fn(engine, batch):
        y_true_batch = next(y_true_batch_values)
        y_pred_batch = next(y_pred_batch_values)
        return torch.from_numpy(y_pred_batch), torch.from_numpy(y_true_batch)

    trainer = Engine(update_fn)
    alpha = 0.98
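    # RunningAverage updates as: avg <- alpha * avg + (1 - alpha) * new_value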

    acc_metric = RunningAverage(
        Accuracy(output_transform=lambda x: [x[0], x[1]], device=device),
        alpha=alpha,
        epoch_bound=False,
    )
    acc_metric.attach(trainer, "running_avg_accuracy")

    running_avg_acc = [
        None,
    ]
    true_acc_metric = Accuracy(device=device)

    @trainer.on(Events.ITERATION_COMPLETED)
    def manual_running_avg_acc(engine):
        i = engine.state.iteration - 1

        true_acc_metric.reset()
        for j in range(dist.get_world_size()):
            output = (
                torch.from_numpy(all_y_pred_batch_values[j, i, :, :]),
                torch.from_numpy(all_y_true_batch_values[j, i, :]),
            )
            true_acc_metric.update(output)

        batch_acc = true_acc_metric._num_correct * 1.0 / true_acc_metric._num_examples

        if running_avg_acc[0] is None:
            running_avg_acc[0] = batch_acc
        else:
            running_avg_acc[0] = running_avg_acc[0] * alpha + (
                1.0 - alpha) * batch_acc
        engine.state.running_avg_acc = running_avg_acc[0]

    @trainer.on(Events.ITERATION_COMPLETED)
    def assert_equal_running_avg_acc_values(engine):
        assert engine.state.running_avg_acc == engine.state.metrics["running_avg_accuracy"], "{} vs {}".format(
            engine.state.running_avg_acc, engine.state.metrics["running_avg_accuracy"])

    trainer.run(data, max_epochs=3)
Example #8
def run(args):
    train_loader, val_loader = get_data_loaders(args.dataset_dir,
                                                args.batch_size,
                                                args.val_batch_size,
                                                args.num_workers)

    if args.seed is not None:
        torch.manual_seed(args.seed)

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    num_classes = KITTI.num_classes()
    model = LiLaNet(num_classes)

    device_count = torch.cuda.device_count()
    if device_count > 1:
        print("Using %d GPU(s)" % device_count)
        model = nn.DataParallel(model)
        args.batch_size = device_count * args.batch_size
        args.val_batch_size = device_count * args.val_batch_size

    model = model.to(device)

    criterion = nn.CrossEntropyLoss(weight=KITTI.class_weights()).to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    if args.resume:
        if os.path.isfile(args.resume):
            print("Loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("Loaded checkpoint '{}' (Epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("No checkpoint found at '{}'".format(args.resume))

    def _prepare_batch(batch, non_blocking=True):
        distance, reflectivity, target = batch

        return (convert_tensor(distance,
                               device=device,
                               non_blocking=non_blocking),
                convert_tensor(reflectivity,
                               device=device,
                               non_blocking=non_blocking),
                convert_tensor(target,
                               device=device,
                               non_blocking=non_blocking))

    def _update(engine, batch):
        model.train()
        optimizer.zero_grad()
        distance, reflectivity, target = _prepare_batch(batch)
        pred = model(distance, reflectivity)
        loss = criterion(pred, target)
        loss.backward()
        optimizer.step()

        return loss.item()

    trainer = Engine(_update)

    # attach running average metrics
    RunningAverage(output_transform=lambda x: x).attach(trainer, 'loss')

    # attach progress bar
    pbar = ProgressBar(persist=True)
    pbar.attach(trainer, metric_names=['loss'])

    def _inference(engine, batch):
        model.eval()
        with torch.no_grad():
            distance, reflectivity, target = _prepare_batch(batch)
            pred = model(distance, reflectivity)

            return pred, target

    evaluator = Engine(_inference)
    cm = ConfusionMatrix(num_classes)
    IoU(cm, ignore_index=0).attach(evaluator, 'IoU')
    Loss(criterion).attach(evaluator, 'loss')

    pbar2 = ProgressBar(persist=True, desc='Eval Epoch')
    pbar2.attach(evaluator)

    def _global_step_transform(engine, event_name):
        if trainer.state is not None:
            return trainer.state.iteration
        else:
            return 1

    tb_logger = TensorboardLogger(args.log_dir)
    tb_logger.attach(trainer,
                     log_handler=OutputHandler(tag='training',
                                               metric_names=['loss']),
                     event_name=Events.ITERATION_COMPLETED)

    tb_logger.attach(evaluator,
                     log_handler=OutputHandler(
                         tag='validation',
                         metric_names=['loss', 'IoU'],
                         global_step_transform=_global_step_transform),
                     event_name=Events.EPOCH_COMPLETED)

    @trainer.on(Events.STARTED)
    def initialize(engine):
        engine.state.exception_raised = False
        if args.resume:
            engine.state.epoch = args.start_epoch

    @evaluator.on(Events.EPOCH_COMPLETED)
    def save_checkpoint(engine):
        epoch = trainer.state.epoch if trainer.state is not None else 1
        iou = engine.state.metrics['IoU'] * 100.0
        mean_iou = iou.mean()

        name = 'epoch{}_mIoU={:.1f}.pth'.format(epoch, mean_iou)
        file = {
            'model': model.state_dict(),
            'epoch': epoch,
            'optimizer': optimizer.state_dict(),
            'args': args
        }

        save(file, args.output_dir, 'checkpoint_{}'.format(name))
        save(model.state_dict(), args.output_dir, 'model_{}'.format(name))

    @trainer.on(Events.EPOCH_COMPLETED)
    def run_validation(engine):
        pbar.log_message("Start Validation - Epoch: [{}/{}]".format(
            engine.state.epoch, engine.state.max_epochs))
        evaluator.run(val_loader)
        metrics = evaluator.state.metrics
        loss = metrics['loss']
        iou = metrics['IoU'] * 100.0
        mean_iou = iou.mean()

        iou_text = ', '.join([
            '{}: {:.1f}'.format(KITTI.classes[i + 1].name, v)
            for i, v in enumerate(iou.tolist())
        ])
        pbar.log_message(
            "Validation results - Epoch: [{}/{}]: Loss: {:.2e}\n IoU: {}\n mIoU: {:.1f}"
            .format(engine.state.epoch, engine.state.max_epochs, loss,
                    iou_text, mean_iou))

    @trainer.on(Events.EXCEPTION_RAISED)
    def handle_exception(engine, e):
        engine.state.exception_raised = True
        if isinstance(e, KeyboardInterrupt) and (engine.state.iteration > 1):
            engine.terminate()
            warnings.warn("KeyboardInterrupt caught. Exiting gracefully.")

            name = 'epoch{}_exception.pth'.format(trainer.state.epoch)
            file = {
                'model': model.state_dict(),
                'epoch': trainer.state.epoch,
                'optimizer': optimizer.state_dict()
            }

            save(file, args.output_dir, 'checkpoint_{}'.format(name))
            save(model.state_dict(), args.output_dir, 'model_{}'.format(name))
        else:
            raise e

    if args.eval_on_start:
        print("Start validation")
        evaluator.run(val_loader, max_epochs=1)

    print("Start training")
    trainer.run(train_loader, max_epochs=args.epochs)
    tb_logger.close()
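
# A sketch of the argument object `run` expects, inferred from the attributes
# accessed above; all values are placeholders.
from argparse import Namespace

args = Namespace(
    dataset_dir="data/kitti", batch_size=8, val_batch_size=8, num_workers=4,
    seed=42, lr=1e-3, resume="", start_epoch=0, log_dir="runs",
    output_dir="output", eval_on_start=False, epochs=50)
run(args)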
Example #9
def train_model(learning_rate, scale, bins, la):
    model = Model()
    model2 = Model2()
    sf = torch.nn.Softmax(dim=1)
    device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model2.to(device)

    optimizer = torch.optim.Adam(
        model.parameters(), lr=learning_rate, weight_decay=1e-4
    )

    optimizer2 = torch.optim.Adam(
        model2.parameters(), lr=learning_rate, weight_decay=1e-4
    )

    def get_dataloader():
        dl_train = torch.utils.data.DataLoader(
            train_dataset, batch_size=128, shuffle=True, num_workers=0, drop_last=True
        )
        dl_val = torch.utils.data.DataLoader(
            val_dataset, batch_size=400, shuffle=False, num_workers=0
        )
        dl_test = torch.utils.data.DataLoader(
            test_dataset, batch_size=400, shuffle=False, num_workers=0
        )

        return dl_train, dl_test, dl_val

    # get the pred from multi-views
    def get_pred_max(y_pred, y_pred2):
        pred_max = torch.max(y_pred, y_pred2)
        return pred_max

    def get_acc(y_pred, y):
        acc_1 = 0
        for i in range(len(y)):
            if torch.argmax(y[i]) == torch.argmax(y_pred[i]):
                acc_1 += 1

        return acc_1 / len(y)

    def step(engine, batch):
        model.train()
        optimizer.zero_grad()

        x, x2, y = batch
        y = F.one_hot(y, num_classes=10).float()

        device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu")
        x, x2, y = x.to(device), x2.to(device), y.to(device)

        x.requires_grad_(True)

        y_pred = sf(model(x))

        loss = F.binary_cross_entropy(y_pred, y)
        x.requires_grad_(False)

        loss.backward()
        optimizer.step()

        return loss.item()

    def step2(engine, batch):
        model2.train()
        optimizer2.zero_grad()

        x, x2, y = batch
        y = F.one_hot(y, num_classes=10).float()

        device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu")
        x, x2, y = x.to(device), x2.to(device), y.to(device)

        x2.requires_grad_(True)

        y_pred = sf(model2(x2))

        loss2 = F.binary_cross_entropy(y_pred, y)
        x2.requires_grad_(False)

        loss2.backward()
        optimizer2.step()

        return loss2.item()

    def val_step():
        with torch.no_grad():
            for batch in dl_val:
                x, x2, y = batch
                y = F.one_hot(y, num_classes=10).float()

                device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu")
                x, x2, y = x.to(device), x2.to(device), y.to(device)

                x.requires_grad_(True)
                x2.requires_grad_(True)
                y_pred = sf(model(x))
                y_pred2 = sf(model2(x2))

        return y_pred, y_pred2, y

    def eval_step():
        with torch.no_grad():
            for batch in dl_test:
                x, x2, y = batch
                y = F.one_hot(y, num_classes=10).float()

                device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu")
                x, x2, y = x.to(device), x2.to(device), y.to(device)
                x.requires_grad_(True)
                x2.requires_grad_(True)
                y_pred = sf(model(x))
                y_pred2 = sf(model2(x2))
                acc1 = get_acc(y_pred, y)
                acc2 = get_acc(y_pred2, y)
                val_pred1, val_pred2, y_val = val_step()
                y_pred_max = get_pred_max(y_pred, y_pred2)
                y_pred_note = y_pred * get_acc(val_pred1, y_val) + y_pred2 * get_acc(val_pred2, y_val)
                acc_m = get_acc(y_pred_max, y)
                acc_m_note = get_acc(y_pred_note, y)
                y_pred_c = calibrate(y_pred, y_pred2, y, val_pred1, val_pred2, y_val, scale, bins, la)
                y_pred_max_c = get_pred_max(y_pred_c, y_pred2)
                y_pred_note2 = y_pred_c + y_pred2
                acc_m_note2 = get_acc(y_pred_note2, y)
                acc_m_c = get_acc(y_pred_max_c, y)
                print(acc_m_note, acc_m_note2)
        return acc1, acc2, acc_m, acc_m_c

    trainer = Engine(step)
    trainer2 = Engine(step2)

    dl_train, dl_test, dl_val = get_dataloader()

    trainer.run(dl_train, max_epochs=epoch)
    trainer2.run(dl_train, max_epochs=epoch)
    acc1, acc2, acc_m, acc_m_c = eval_step()

    return model, acc1, acc2, acc_m, acc_m_c
Example #10
    def process_batch(engine, batch):
        optimizer.zero_grad()
        loss_v = common.calc_loss_dqn(
            batch, net, tgt_net.target_model, gamma=params.gamma, device=device
        )
        loss_v.backward()
        optimizer.step()
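        # periodically copy the online network weights into the target network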
        if engine.state.iteration % params.target_net_sync == 0:
            tgt_net.sync()
        return {
            "loss": loss_v.item(),
            "epsilon": batch_generator.epsilon,
        }

    engine = Engine(process_batch)
    ptan_ignite.EndOfEpisodeHandler(batch_generator, bound_avg_reward=17.0).attach(engine)
    fps_handler.attach(engine, manual_step=True)

    @engine.on(ptan_ignite.EpisodeEvents.EPISODE_COMPLETED)
    def episode_completed(trainer: Engine):
        print(
            "Episode %d: reward=%s, steps=%s, speed=%.3f frames/s, elapsed=%s"
            % (
                trainer.state.episode,
                trainer.state.episode_reward,
                trainer.state.episode_steps,
                trainer.state.metrics.get("avg_fps", 0),
                timedelta(seconds=trainer.state.metrics.get("time_passed", 0)),
            )
        )
Example #11
def inference(config, local_rank, with_pbar_on_iters=True):

    set_seed(config.seed + local_rank)
    torch.cuda.set_device(local_rank)
    device = 'cuda'

    torch.backends.cudnn.benchmark = True

    # Load model and weights
    model_weights_filepath = Path(
        get_artifact_path(config.run_uuid, config.weights_filename))
    assert model_weights_filepath.exists(), \
        "Model weights file '{}' is not found".format(model_weights_filepath.as_posix())

    model = config.model.to(device)
    model = torch.nn.parallel.DistributedDataParallel(model,
                                                      device_ids=[local_rank],
                                                      output_device=local_rank)

    if hasattr(config, "custom_weights_loading"):
        config.custom_weights_loading(model, model_weights_filepath)
    else:
        state_dict = torch.load(model_weights_filepath)
        if not all([k.startswith("module.") for k in state_dict]):
            state_dict = {f"module.{k}": v for k, v in state_dict.items()}
        model.load_state_dict(state_dict)

    model.eval()

    prepare_batch = config.prepare_batch
    non_blocking = getattr(config, "non_blocking", True)
    model_output_transform = getattr(config, "model_output_transform",
                                     lambda x: x)

    tta_transforms = getattr(config, "tta_transforms", None)

    def eval_update_function(engine, batch):
        with torch.no_grad():
            x, y, meta = prepare_batch(batch,
                                       device=device,
                                       non_blocking=non_blocking)

            if tta_transforms is not None:
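                # test-time augmentation: average the de-augmented predictions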
                y_preds = []
                for t in tta_transforms:
                    t_x = t.augment_image(x)
                    t_y_pred = model(t_x)
                    t_y_pred = model_output_transform(t_y_pred)
                    y_pred = t.deaugment_mask(t_y_pred)
                    y_preds.append(y_pred)

                y_preds = torch.stack(y_preds, dim=0)
                y_pred = torch.mean(y_preds, dim=0)
            else:
                y_pred = model(x)
                y_pred = model_output_transform(y_pred)
            return {"y_pred": y_pred, "y": y, "meta": meta}

    evaluator = Engine(eval_update_function)

    has_targets = getattr(config, "has_targets", False)

    if has_targets:

        def output_transform(output):
            return output['y_pred'], output['y']

        num_classes = config.num_classes
        cm_metric = ConfusionMatrix(num_classes=num_classes,
                                    output_transform=output_transform)
        pr = cmPrecision(cm_metric, average=False)
        re = cmRecall(cm_metric, average=False)

        val_metrics = {
            "IoU": IoU(cm_metric),
            "mIoU_bg": mIoU(cm_metric),
            "Accuracy": cmAccuracy(cm_metric),
            "Precision": pr,
            "Recall": re,
            "F1": Fbeta(beta=1.0, output_transform=output_transform)
        }

        if hasattr(config, "metrics") and isinstance(config.metrics, dict):
            val_metrics.update(config.metrics)

        for name, metric in val_metrics.items():
            metric.attach(evaluator, name)

        if dist.get_rank() == 0:
            # Log val metrics:
            mlflow_logger = MLflowLogger()
            mlflow_logger.attach(evaluator,
                                 log_handler=OutputHandler(
                                     tag="validation",
                                     metric_names=list(val_metrics.keys())),
                                 event_name=Events.EPOCH_COMPLETED)

    if dist.get_rank() == 0 and with_pbar_on_iters:
        ProgressBar(persist=True, desc="Inference").attach(evaluator)

    if dist.get_rank() == 0:
        do_save_raw_predictions = getattr(config, "do_save_raw_predictions",
                                          True)
        do_save_overlayed_predictions = getattr(
            config, "do_save_overlayed_predictions", True)

        if not has_targets:
            assert do_save_raw_predictions or do_save_overlayed_predictions, \
                "If there are no targets, either do_save_raw_predictions or " \
                "do_save_overlayed_predictions should be set to True in the config"

        # Save predictions
        if do_save_raw_predictions:
            raw_preds_path = config.output_path / "raw"
            raw_preds_path.mkdir(parents=True)

            evaluator.add_event_handler(Events.ITERATION_COMPLETED,
                                        save_raw_predictions_with_geoinfo,
                                        raw_preds_path)

        if do_save_overlayed_predictions:
            overlayed_preds_path = config.output_path / "overlay"
            overlayed_preds_path.mkdir(parents=True)

            evaluator.add_event_handler(
                Events.ITERATION_COMPLETED,
                save_overlayed_predictions,
                overlayed_preds_path,
                img_denormalize_fn=config.img_denormalize,
                palette=default_palette)

    evaluator.add_event_handler(Events.EXCEPTION_RAISED, report_exception)

    # Run evaluation
    evaluator.run(config.data_loader)
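Example #12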
def train():
    parser = ArgumentParser()
    parser.add_argument("--dataset_path", type=str, default="", help="Path or url of the dataset. If empty download from S3.")
    parser.add_argument("--dataset_cache", type=str, default='./dataset_cache', help="Path or url of the dataset cache")
    parser.add_argument("--model_checkpoint", type=str, default="gpt2", help="Path, url or short name of the model")
    parser.add_argument("--num_candidates", type=int, default=2, help="Number of candidates for training")
    parser.add_argument("--max_history", type=int, default=2, help="Number of previous exchanges to keep in history")
    parser.add_argument("--train_batch_size", type=int, default=16, help="Batch size for training")
    parser.add_argument("--valid_batch_size", type=int, default=4, help="Batch size for validation")
    # parser.add_argument("--gradient_accumulation_steps", type=int, default=8, help="Accumulate gradients on several steps")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Accumulate gradients on several steps")
    # parser.add_argument("--lr", type=float, default=6.25e-5, help="Learning rate")
    parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
    parser.add_argument("--lm_coef", type=float, default=1.0, help="LM loss coefficient")
    parser.add_argument("--mc_coef", type=float, default=1.0, help="Multiple-choice loss coefficient")
    parser.add_argument("--max_norm", type=float, default=1.0, help="Clipping gradient norm")
    parser.add_argument("--n_epochs", type=int, default=3, help="Number of training epochs")
    parser.add_argument("--eval_before_start", action='store_true', help="If true start with a first evaluation before training")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device (cuda or cpu)")
    parser.add_argument("--fp16", type=str, default="", help="Set to O0, O1, O2 or O3 for fp16 training (see apex documentation)")
    parser.add_argument("--local_rank", type=int, default=-1, help="Local rank for distributed training (-1: not distributed)")

    parser.add_argument(
        "--init_model",
        default="model/pytorch_kogpt2_676e9bcfa7.params",
        type=str,
        help="The model checkpoint for weights initialization. Leave None if you want to train a model from scratch.",
    )


    args = parser.parse_args()

    # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process. logger.info => log main process only, logger.warning => log all processes
    logging.basicConfig(level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning("Running process %d", args.local_rank)  # This is a logger.warning: it will be printed by all distributed processes
    logger.info("Arguments: %s", pformat(args))

    # Initialize distributed training if needed
    args.distributed = (args.local_rank != -1)
    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        args.device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')

    logger.info("Prepare tokenizer, pretrained model and optimizer.")
    
    config = GPT2Config(vocab_size=50000)
    model = GPT2DoubleHeadsModel(config)
    if args.init_model:
        print("Load model from ", args.init_model)
        model.load_state_dict(torch.load(args.init_model), strict=False)

    model.to(args.device)
    add_special_tokens_(model, tokenizer)
    optimizer = AdamW(model.parameters(), lr=args.lr, correct_bias=True)

    # Prepare model for FP16 and distributed training if needed (order is important, distributed should be the last)
    if args.fp16:
        from apex import amp  # Apex is only required if we use fp16 training
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16)
    if args.distributed:
        model = DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank)

    logger.info("Prepare datasets")
    train_loader, val_loader, train_sampler, valid_sampler = get_data_loaders(args, tokenizer)

    # Training function and trainer
    def update(engine, batch):
        model.train()
        batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
        input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids = batch
        (lm_loss), (mc_loss), *_ = model(
            input_ids, token_type_ids=token_type_ids, mc_token_ids=mc_token_ids,
            mc_labels=mc_labels, lm_labels=lm_labels
        )
        loss = (lm_loss * args.lm_coef + mc_loss * args.mc_coef) / args.gradient_accumulation_steps
        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_norm)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
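        # step the optimizer only once every `gradient_accumulation_steps` iterations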
        if engine.state.iteration % args.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        return loss.item()
    trainer = Engine(update)

    # Evaluation function and evaluator (evaluator output is the input of the metrics)
    def inference(engine, batch):
        model.eval()
        with torch.no_grad():
            batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
            input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids = batch
            # logger.info(tokenizer.decode(input_ids[0, -1, :].tolist()))
            # if we don't send labels to the model, it doesn't return losses
            lm_logits, mc_logits, *_ = model(
                input_ids, token_type_ids=token_type_ids, mc_token_ids=mc_token_ids,
            )
            lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view(-1, lm_logits.size(-1))
            lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1)
            return (lm_logits_flat_shifted, mc_logits), (lm_labels_flat_shifted, mc_labels)
    evaluator = Engine(inference)

    # Attach evaluation to trainer: we evaluate when we start the training and at the end of each epoch
    trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: evaluator.run(val_loader))
    if args.n_epochs < 1:
        trainer.add_event_handler(Events.COMPLETED, lambda _: evaluator.run(val_loader))
    if args.eval_before_start:
        trainer.add_event_handler(Events.STARTED, lambda _: evaluator.run(val_loader))

    # Make sure distributed data samplers split the dataset nicely between the distributed processes
    if args.distributed:
        trainer.add_event_handler(Events.EPOCH_STARTED, lambda engine: train_sampler.set_epoch(engine.state.epoch))
        evaluator.add_event_handler(Events.EPOCH_STARTED, lambda engine: valid_sampler.set_epoch(engine.state.epoch))

    # Linearly decrease the learning rate from lr to zero
    scheduler = PiecewiseLinear(optimizer, "lr", [(0, args.lr), (args.n_epochs * len(train_loader), 0.0)])
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # Prepare metrics - note how we compute distributed metrics
    RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
    metrics = {"nll": Loss(torch.nn.CrossEntropyLoss(ignore_index=-100), output_transform=lambda x: (x[0][0], x[1][0])),
               "accuracy": Accuracy(output_transform=lambda x: (x[0][1], x[1][1]))}
    metrics.update({"average_nll": MetricsLambda(average_distributed_scalar, metrics["nll"], args),
                    "average_accuracy": MetricsLambda(average_distributed_scalar, metrics["accuracy"], args)})
    metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"])
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # On the main process: add progress bar, tensorboard, checkpoints and save model, configuration and tokenizer before we start to train
    if args.local_rank in [-1, 0]:
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer, metric_names=["loss"])
        evaluator.add_event_handler(Events.COMPLETED, lambda _: pbar.log_message("Validation: %s" % pformat(evaluator.state.metrics)))

        log_dir = make_logdir(args.init_model)
        tb_logger = TensorboardLogger(log_dir)

        tb_logger.attach(trainer, log_handler=OutputHandler(tag="training", metric_names=["loss"]), event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer, log_handler=OptimizerParamsHandler(optimizer), event_name=Events.ITERATION_STARTED)
        tb_logger.attach(evaluator, log_handler=OutputHandler(tag="validation", metric_names=list(metrics.keys()), another_engine=trainer), event_name=Events.EPOCH_COMPLETED)

        checkpoint_handler = ModelCheckpoint(log_dir, 'checkpoint', save_interval=1, n_saved=3)
        trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler, {'mymodel': getattr(model, 'module', model)})  # "getattr" takes care of distributed encapsulation

        torch.save(args, log_dir + '/model_training_args.bin')
        getattr(model, 'module', model).config.to_json_file(os.path.join(log_dir, CONFIG_NAME))
        # tokenizer.save_pretrained(log_dir)

    # Run the training
    trainer.run(train_loader, max_epochs=args.n_epochs)

    # On the main process: close tensorboard logger and rename the last checkpoint (for easy re-loading with OpenAIGPTModel.from_pretrained method)
    if args.local_rank in [-1, 0] and args.n_epochs > 0:
        os.rename(os.path.join(log_dir, checkpoint_handler._saved[-1][1]), os.path.join(log_dir, WEIGHTS_NAME))  # TODO: PR in ignite to have better access to saved file paths (cleaner)
        tb_logger.close()
Example #13
def train():
    config_file = "configs/train_daily_dialog_emotion_action_topic_config.json"
    config = Config.from_json_file(config_file)


    # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process. logger.info => log main process only, logger.warning => log all processes
    logging.basicConfig(level=logging.INFO if config.local_rank in [-1, 0] else logging.WARN)
    logger.warning("Running process %d", config.local_rank)  # This is a logger.warning: it will be printed by all distributed processes
    logger.info("Arguments: %s", pformat(config))

    # Initialize distributed training if needed
    config.distributed = (config.local_rank != -1)
    if config.distributed:
        torch.cuda.set_device(config.local_rank)
        config.device = torch.device("cuda", config.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')

    logger.info("Prepare tokenizer, pretrained model and optimizer - add special tokens for fine-tuning")
    tokenizer_class = GPT2Tokenizer if "gpt2" in config.model_checkpoint else OpenAIGPTTokenizer
    tokenizer = tokenizer_class.from_pretrained(config.model_checkpoint)
    model_class = GPT2DoubleHeadsModel if "gpt2" in config.model_checkpoint else OpenAIGPTDoubleHeadsModel
    model = model_class.from_pretrained(config.model_checkpoint)
    tokenizer.set_special_tokens(SPECIAL_TOKENS)
    model.set_num_special_tokens(len(SPECIAL_TOKENS))
    model.to(config.device)
    optimizer = OpenAIAdam(model.parameters(), lr=config.lr)

    # Prepare model for FP16 and distributed training if needed (order is important, distributed should be the last)
    if config.fp16:
        from apex import amp  # Apex is only required if we use fp16 training
        model, optimizer = amp.initialize(model, optimizer, opt_level=config.fp16)
    if config.distributed:
        model = DistributedDataParallel(model, device_ids=[config.local_rank], output_device=config.local_rank)

    logger.info("Prepare datasets")
    train_loader, val_loader, train_sampler, valid_sampler = get_data_loaders(config, tokenizer)

    # Training function and trainer
    def update(engine, batch):
        model.train()
        input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids, token_emotion_ids, token_action_ids = tuple(input_tensor.to(config.device) for input_tensor in batch)
        lm_loss, mc_loss = model(input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids, token_emotion_ids, token_action_ids)
        loss = (lm_loss * config.lm_coef + mc_loss * config.mc_coef) / config.gradient_accumulation_steps
        if config.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), config.max_norm)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_norm)
        if engine.state.iteration % config.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        return loss.item()
    trainer = Engine(update)

    # Evaluation function and evaluator (evaluator output is the input of the metrics)
    def inference(engine, batch):
        model.eval()
        with torch.no_grad():
            batch = tuple(input_tensor.to(config.device) for input_tensor in batch)
            input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids, token_emotion_ids, token_action_ids = batch
            #logger.info(tokenizer.decode(input_ids[0, -1, :].tolist()))
            model_outputs = model(input_ids, mc_token_ids, token_type_ids=token_type_ids,
                                  token_emotion_ids=token_emotion_ids,
                                  token_action_ids=token_action_ids)
            lm_logits, mc_logits = model_outputs[0], model_outputs[1]  # So we can also use GPT2 outputs
            lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view(-1, lm_logits.size(-1))
            lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1)
            return (lm_logits_flat_shifted, mc_logits), (lm_labels_flat_shifted, mc_labels)
    evaluator = Engine(inference)

    # Attach evaluation to trainer: we evaluate when we start the training and at the end of each epoch
    trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: evaluator.run(val_loader))
    if config.n_epochs < 1:
        trainer.add_event_handler(Events.COMPLETED, lambda _: evaluator.run(val_loader))
    if config.eval_before_start:
        trainer.add_event_handler(Events.STARTED, lambda _: evaluator.run(val_loader))

    # Make sure distributed data samplers split the dataset nicely between the distributed processes
    if config.distributed:
        trainer.add_event_handler(Events.EPOCH_STARTED, lambda engine: train_sampler.set_epoch(engine.state.epoch))
        evaluator.add_event_handler(Events.EPOCH_STARTED, lambda engine: valid_sampler.set_epoch(engine.state.epoch))

    # Linearly decrease the learning rate from lr to zero
    scheduler = PiecewiseLinear(optimizer, "lr", [(0, config.lr), (config.n_epochs * len(train_loader), 0.0)])
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # Prepare metrics - note how we compute distributed metrics
    RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
    metrics = {"nll": Loss(torch.nn.CrossEntropyLoss(ignore_index=-1), output_transform=lambda x: (x[0][0], x[1][0])),
               "accuracy": Accuracy(output_transform=lambda x: (x[0][1], x[1][1]))}
    metrics.update({"average_nll": MetricsLambda(average_distributed_scalar, metrics["nll"], config),
                    "average_accuracy": MetricsLambda(average_distributed_scalar, metrics["accuracy"], config)})
    metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"])
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # On the main process: add progress bar, tensorboard, checkpoints and save model, configuration and tokenizer before we start to train
    if config.local_rank in [-1, 0]:
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer, metric_names=["loss"])
        evaluator.add_event_handler(Events.COMPLETED, lambda _: pbar.log_message("Validation: %s" % pformat(evaluator.state.metrics)))

        tb_logger = TensorboardLogger(log_dir=config.log_dir)
        tb_logger.attach(trainer, log_handler=OutputHandler(tag="training", metric_names=["loss"]), event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer, log_handler=OptimizerParamsHandler(optimizer), event_name=Events.ITERATION_STARTED)
        tb_logger.attach(evaluator, log_handler=OutputHandler(tag="validation", metric_names=list(metrics.keys()), another_engine=trainer), event_name=Events.EPOCH_COMPLETED)

        checkpoint_handler = ModelCheckpoint(tb_logger.writer.log_dir, 'checkpoint', save_interval=1, n_saved=3)
        trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler, {'mymodel': getattr(model, 'module', model)})  # "getattr" takes care of distributed encapsulation

        torch.save(config, tb_logger.writer.log_dir + '/model_training_args.bin')
        getattr(model, 'module', model).config.to_json_file(os.path.join(tb_logger.writer.log_dir, CONFIG_NAME))
        tokenizer.save_vocabulary(tb_logger.writer.log_dir)

    # Run the training
    trainer.run(train_loader, max_epochs=config.n_epochs)

    # On the main process: close tensorboard logger and rename the last checkpoint (for easy re-loading with OpenAIGPTModel.from_pretrained method)
    if config.local_rank in [-1, 0] and config.n_epochs > 0:
        os.rename(checkpoint_handler._saved[-1][1][-1], os.path.join(tb_logger.writer.log_dir, WEIGHTS_NAME))  # TODO: PR in ignite to have better access to saved file paths (cleaner)
        tb_logger.close()
Example #14
def main(local_rank):
    params = init_parms(local_rank)
    device = params.get('device')
    model = ASRModel(input_features=config.num_mel_banks,
                     num_classes=config.vocab_size).to(device)
    logger.info(
        f'Model initialized with {get_model_size(model):.3f}M parameters')
    optimizer = Ranger(model.parameters(), lr=config.lr, eps=1e-5)
    model = DistributedDataParallel(model,
                                    device_ids=[local_rank],
                                    output_device=local_rank,
                                    check_reduction=True)
    load_checkpoint(model, optimizer, params)
    print(f"Loaded model on {local_rank}")
    start_epoch = params['start_epoch']
    sup_criterion = CustomCTCLoss()
    unsup_criterion = UDALoss()
    if args.local_rank == 0:
        tb_logger = TensorboardLogger(log_dir=log_path)
        pbar = ProgressBar(persist=True, desc="Training")
        pbar_valid = ProgressBar(persist=True, desc="Validation Clean")
        pbar_valid_other = ProgressBar(persist=True, desc="Validation Other")
        pbar_valid_airtel = ProgressBar(persist=True, desc="Validation Airtel")
        pbar_valid_airtel_payments = ProgressBar(
            persist=True, desc="Validation Airtel Payments")
        timer = Timer(average=True)
        best_meter = params.get('best_stats', BestMeter())

    trainCleanPath = os.path.join(lmdb_root_path, 'train-labelled')
    trainOtherPath = os.path.join(lmdb_root_path, 'train-unlabelled')
    trainCommonVoicePath = os.path.join(lmdb_commonvoice_root_path,
                                        'train-labelled-en')
    trainAirtelPath = os.path.join(lmdb_airtel_root_path, 'train-labelled-en')
    trainAirtelPaymentsPath = os.path.join(lmdb_airtel_payments_root_path,
                                           'train-labelled-en')
    testCleanPath = os.path.join(lmdb_root_path, 'test-clean')
    testOtherPath = os.path.join(lmdb_root_path, 'test-other')
    testAirtelPath = os.path.join(lmdb_airtel_root_path, 'test-labelled-en')
    testAirtelPaymentsPath = os.path.join(lmdb_airtel_payments_root_path,
                                          'test-labelled-en')
    devOtherPath = os.path.join(lmdb_root_path, 'dev-other')

    train_clean = lmdbMultiDataset(roots=[
        trainCleanPath, trainOtherPath, trainCommonVoicePath, trainAirtelPath,
        trainAirtelPaymentsPath
    ],
                                   transform=image_train_transform)
    train_other = lmdbMultiDataset(roots=[devOtherPath],
                                   transform=image_train_transform)

    test_clean = lmdbMultiDataset(roots=[testCleanPath],
                                  transform=image_val_transform)
    test_other = lmdbMultiDataset(roots=[testOtherPath],
                                  transform=image_val_transform)
    test_airtel = lmdbMultiDataset(roots=[testAirtelPath],
                                   transform=image_val_transform)
    test_payments_airtel = lmdbMultiDataset(roots=[testAirtelPaymentsPath],
                                            transform=image_val_transform)

    logger.info(
        f'Loaded train & test datasets: train_labelled={len(train_clean)}, train_unlabelled={len(train_other)}, test_clean={len(test_clean)}, test_other={len(test_other)}, test_airtel={len(test_airtel)}, test_payments_airtel={len(test_payments_airtel)} examples'
    )

    def train_update_function(engine, _):
        optimizer.zero_grad()
        # Supervised gt, pred
        imgs_sup, labels_sup, label_lengths = next(
            engine.state.train_loader_labbeled)
        imgs_sup = imgs_sup.cuda(local_rank, non_blocking=True)
        probs_sup = model(imgs_sup)

        # Unsupervised gt, pred
        # imgs_unsup, augmented_imgs_unsup = next(engine.state.train_loader_unlabbeled)
        # with torch.no_grad():
        #     probs_unsup = model(imgs_unsup.to(device))
        # probs_aug_unsup = model(augmented_imgs_unsup.to(device))

        sup_loss = sup_criterion(probs_sup, labels_sup, label_lengths)
        # unsup_loss = unsup_criterion(probs_unsup, probs_aug_unsup)

        # Blend supervised and unsupervised losses till unsupervision_warmup_epoch
        # alpha = get_alpha(engine.state.epoch)
        # final_loss = ((1 - alpha) * sup_loss) + (alpha * unsup_loss)

        # final_loss = sup_loss
        sup_loss.backward()
        optimizer.step()

        return sup_loss.item()

    @torch.no_grad()
    def validate_update_function(engine, batch):
        img, labels, label_lengths = batch
        y_pred = model(img.cuda(local_rank, non_blocking=True))
        if np.random.rand() > 0.99:
            pred_sentences = get_most_probable(y_pred)
            labels_list = labels.tolist()
            idx = 0
            for i, length in enumerate(label_lengths.cpu().tolist()):
                pred_sentence = pred_sentences[i]
                gt_sentence = get_sentence(labels_list[idx:idx + length])
                idx += length
                print(f"Pred sentence: {pred_sentence}, GT: {gt_sentence}")
        return (y_pred, labels, label_lengths)

    train_sampler_labbeled = torch.utils.data.distributed.DistributedSampler(
        train_clean, num_replicas=3, rank=args.local_rank)
    train_sampler_unlabbeled = torch.utils.data.distributed.DistributedSampler(
        train_other, num_replicas=3, rank=args.local_rank)
    test_sampler_clean = torch.utils.data.distributed.DistributedSampler(
        test_clean, num_replicas=3, rank=args.local_rank, shuffle=False)
    test_sampler_other = torch.utils.data.distributed.DistributedSampler(
        test_other, num_replicas=3, rank=args.local_rank, shuffle=False)
    test_sampler_airtel = torch.utils.data.distributed.DistributedSampler(
        test_airtel, num_replicas=3, rank=args.local_rank, shuffle=False)
    test_sampler_airtel_payments = torch.utils.data.distributed.DistributedSampler(
        test_payments_airtel,
        num_replicas=3,
        rank=args.local_rank,
        shuffle=False)

    train_loader_labeled_loader = torch.utils.data.DataLoader(
        train_clean,
        batch_size=train_batch_size // 3,
        sampler=train_sampler_labeled,
        num_workers=config.workers // 3,
        pin_memory=True,
        collate_fn=allign_collate)
    train_loader_unlabeled_loader = torch.utils.data.DataLoader(
        train_other,
        batch_size=train_batch_size * 4,
        sampler=train_sampler_unlabeled,
        num_workers=config.workers // 3,
        pin_memory=True,
        collate_fn=allign_collate)
    test_loader_clean = torch.utils.data.DataLoader(
        test_clean,
        batch_size=1,
        sampler=test_sampler_clean,
        num_workers=config.workers // 3,
        pin_memory=True,
        collate_fn=allign_collate)
    test_loader_other = torch.utils.data.DataLoader(
        test_other,
        batch_size=1,
        sampler=test_sampler_other,
        num_workers=config.workers // 3,
        pin_memory=True,
        collate_fn=allign_collate)
    test_loader_airtel = torch.utils.data.DataLoader(
        test_airtel,
        batch_size=1,
        sampler=test_sampler_airtel,
        num_workers=config.workers // 3,
        pin_memory=True,
        collate_fn=allign_collate)
    test_loader_airtel_payments = torch.utils.data.DataLoader(
        test_payments_airtel,
        batch_size=1,
        sampler=test_sampler_airtel_payments,
        num_workers=config.workers // 3,
        pin_memory=True,
        collate_fn=allign_collate)
    trainer = Engine(train_update_function)
    iteration_log_step = int(0.33 * len(train_loader_labeled_loader))
    evaluator_clean = Engine(validate_update_function)
    evaluator_other = Engine(validate_update_function)
    evaluator_airtel = Engine(validate_update_function)
    evaluator_airtel_payments = Engine(validate_update_function)
    metrics = {'wer': WordErrorRate(), 'cer': CharacterErrorRate()}
    for name, metric in metrics.items():
        metric.attach(evaluator_clean, name)
        metric.attach(evaluator_other, name)
        metric.attach(evaluator_airtel, name)
        metric.attach(evaluator_airtel_payments, name)

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=config.lr_gamma,
        patience=int(config.epochs * 0.05),
        verbose=True,
        threshold_mode="abs",
        cooldown=int(config.epochs * 0.025),
        min_lr=1e-5)
    if args.local_rank == 0:
        tb_logger.attach(trainer,
                         log_handler=OutputHandler(
                             tag="training",
                             output_transform=lambda loss: {'loss': loss}),
                         event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer,
                         log_handler=OptimizerParamsHandler(optimizer),
                         event_name=Events.ITERATION_STARTED)
        tb_logger.attach(trainer,
                         log_handler=WeightsHistHandler(model),
                         event_name=Events.EPOCH_COMPLETED)
        tb_logger.attach(trainer,
                         log_handler=WeightsScalarHandler(model),
                         event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer,
                         log_handler=GradsScalarHandler(model),
                         event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer,
                         log_handler=GradsHistHandler(model),
                         event_name=Events.EPOCH_COMPLETED)
        tb_logger.attach(evaluator_clean,
                         log_handler=OutputHandler(tag="validation_clean",
                                                   metric_names=["wer", "cer"],
                                                   another_engine=trainer),
                         event_name=Events.EPOCH_COMPLETED)
        tb_logger.attach(evaluator_other,
                         log_handler=OutputHandler(tag="validation_other",
                                                   metric_names=["wer", "cer"],
                                                   another_engine=trainer),
                         event_name=Events.EPOCH_COMPLETED)
        tb_logger.attach(evaluator_airtel,
                         log_handler=OutputHandler(tag="validation_airtel",
                                                   metric_names=["wer", "cer"],
                                                   another_engine=trainer),
                         event_name=Events.EPOCH_COMPLETED)
        tb_logger.attach(evaluator_airtel_payments,
                         log_handler=OutputHandler(
                             tag="validation_airtel_payments",
                             metric_names=["wer", "cer"],
                             another_engine=trainer),
                         event_name=Events.EPOCH_COMPLETED)
        pbar.attach(trainer, output_transform=lambda x: {'loss': x})
        pbar_valid.attach(evaluator_clean, ['wer', 'cer'],
                          event_name=Events.EPOCH_COMPLETED,
                          closing_event_name=Events.COMPLETED)
        pbar_valid_other.attach(evaluator_other, ['wer', 'cer'],
                                event_name=Events.EPOCH_COMPLETED,
                                closing_event_name=Events.COMPLETED)
        pbar_valid_airtel.attach(evaluator_airtel, ['wer', 'cer'],
                                 event_name=Events.EPOCH_COMPLETED,
                                 closing_event_name=Events.COMPLETED)
        pbar_valid_airtel_payments.attach(evaluator_airtel_payments,
                                          ['wer', 'cer'],
                                          event_name=Events.EPOCH_COMPLETED,
                                          closing_event_name=Events.COMPLETED)
        timer.attach(trainer)

    @trainer.on(Events.STARTED)
    def set_init_epoch(engine):
        engine.state.epoch = params['start_epoch']
        logger.info(f'Initial epoch for trainer set to {engine.state.epoch}')

    @trainer.on(Events.EPOCH_STARTED)
    def set_model_train(engine):
        if hasattr(engine.state, 'train_loader_labeled'):
            del engine.state.train_loader_labeled
        engine.state.train_loader_labeled = iter(train_loader_labeled_loader)
        # engine.state.train_loader_unlabeled = iter(train_loader_unlabeled_loader)

    @trainer.on(Events.ITERATION_COMPLETED)
    def iteration_completed(engine):
        # Every `iteration_log_step` iterations (roughly a third of an epoch),
        # advance the epoch counter and evaluate on all held-out sets.
        if (engine.state.iteration % iteration_log_step
                == 0) and (engine.state.iteration > 0):
            engine.state.epoch += 1
            train_clean.set_epochs(engine.state.epoch)
            train_other.set_epochs(engine.state.epoch)
            model.eval()
            logger.info('Model set to eval mode')
            evaluator_clean.run(test_loader_clean)
            evaluator_other.run(test_loader_other)
            evaluator_airtel.run(test_loader_airtel)
            evaluator_airtel_payments.run(test_loader_airtel_payments)
            model.train()
            logger.info('Model set back to train mode')

    if args.local_rank == 0:

        @evaluator_other.on(Events.EPOCH_COMPLETED)
        def save_checkpoints(engine):
            metrics = engine.state.metrics
            wer = metrics['wer']
            cer = metrics['cer']
            epoch = trainer.state.epoch
            scheduler.step(wer)
            save_checkpoint(model, optimizer, best_meter, wer, cer, epoch)
            best_meter.update(wer, cer, epoch)

        @trainer.on(Events.EPOCH_COMPLETED)
        def after_complete(engine):
            logger.info('Epoch {} done. Time per batch: {:.3f}[s]'.format(
                engine.state.epoch, timer.value()))
            timer.reset()

    trainer.run(train_loader_labeled_loader, max_epochs=epochs)
    if args.local_rank == 0:
        tb_logger.close()
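
One detail the script above omits: it never reshuffles its DistributedSamplers between epochs, so every epoch sees the same sample order per rank. A minimal, hypothetical add-on (reusing the trainer and sampler names from the snippet; DistributedSampler only draws a new permutation after set_epoch is called) could look like:

@trainer.on(Events.EPOCH_STARTED)
def reshuffle_samplers(engine):
    # Seed a fresh shuffle for this epoch on every rank
    train_sampler_labeled.set_epoch(engine.state.epoch)
    train_sampler_unlabeled.set_epoch(engine.state.epoch)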
Example #15
def test_run_with_max_iters_greater_than_epoch_length():
    max_iters = 73
    engine = Engine(lambda e, b: 1)
    engine.run([0] * 20, max_iters=max_iters)
    assert engine.state.iteration == max_iters
Example #16
def test_linear_scheduler():
    tensor = torch.zeros([1], requires_grad=True)
    optimizer = torch.optim.SGD([tensor], lr=0)

    scheduler = LinearCyclicalScheduler(optimizer, 'lr', 1, 0, 10)
    lrs = []

    def save_lr(engine):
        lrs.append(optimizer.param_groups[0]['lr'])

    trainer = Engine(lambda engine, batch: None)
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
    trainer.add_event_handler(Events.ITERATION_COMPLETED, save_lr)
    trainer.run([0] * 10, max_epochs=2)

    assert lrs == list(
        map(
            pytest.approx,
            [
                # Cycle 1
                1.0, 0.8, 0.6, 0.4, 0.2, 0.0, 0.2, 0.4, 0.6, 0.8,
                # Cycle 2
                1.0, 0.8, 0.6, 0.4, 0.2, 0.0, 0.2, 0.4, 0.6, 0.8,
            ]))

    optimizer = torch.optim.SGD([tensor], lr=0)
    scheduler = LinearCyclicalScheduler(optimizer,
                                        'lr',
                                        1,
                                        0,
                                        10,
                                        cycle_mult=2)

    trainer = Engine(lambda engine, batch: None)
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
    trainer.add_event_handler(Events.ITERATION_COMPLETED, save_lr)

    lrs = []
    trainer.run([0] * 10, max_epochs=3)

    assert lrs == list(
        map(
            pytest.approx,
            [
                # Cycle 1
                1.0, 0.8, 0.6, 0.4, 0.2, 0.0, 0.2, 0.4, 0.6, 0.8,
                # Cycle 2
                1.0, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1,
                0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9,
            ]))

    # With float cycle_size
    optimizer = torch.optim.SGD([tensor], lr=0)
    scheduler = LinearCyclicalScheduler(optimizer,
                                        'lr',
                                        start_value=1.2,
                                        end_value=0.2,
                                        cycle_size=10.00000012,
                                        cycle_mult=1.0)

    trainer = Engine(lambda engine, batch: None)
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
    trainer.add_event_handler(Events.ITERATION_COMPLETED, save_lr)

    lrs = []
    trainer.run([0] * 10, max_epochs=2)
    assert lrs == list(
        map(
            pytest.approx,
            [
                # Cycle 1
                1.2, 1.0, 0.8, 0.6, 0.4, 0.2, 0.4, 0.6, 0.8, 1.0,
                # Cycle 2
                1.2, 1.0, 0.8, 0.6, 0.4, 0.2, 0.4, 0.6, 0.8, 1.0,
            ]))
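
Such schedules can also be previewed without an optimizer or an Engine: the ParamScheduler classes expose a simulate_values classmethod (the same mechanism used with ConcatScheduler further below). A small sketch with the same parameters as the first assertion:

values = LinearCyclicalScheduler.simulate_values(num_events=20,
                                                 param_name="lr",
                                                 start_value=1.0,
                                                 end_value=0.0,
                                                 cycle_size=10)
# values is a list of [event_index, lr] pairs: [0, 1.0], [1, 0.8], ...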
Example #17
def main(
    architecture,
    batch_size,
    length_scale,
    centroid_size,
    learning_rate,
    l_gradient_penalty,
    gamma,
    weight_decay,
    final_model,
    output_dir,
):
    writer = SummaryWriter(log_dir=f"runs/{output_dir}")

    ds = all_datasets["CIFAR10"]()
    input_size, num_classes, dataset, test_dataset = ds

    # Split up training set
    idx = list(range(len(dataset)))
    random.shuffle(idx)

    if final_model:
        train_dataset = dataset
        val_dataset = test_dataset
    else:
        val_size = int(len(dataset) * 0.8)
        train_dataset = torch.utils.data.Subset(dataset, idx[:val_size])
        val_dataset = torch.utils.data.Subset(dataset, idx[val_size:])

        val_dataset.transform = (test_dataset.transform
                                 )  # Test time preprocessing for validation

    if architecture == "WRN":
        model_output_size = 640
        epochs = 200
        milestones = [60, 120, 160]
        feature_extractor = WideResNet()
    elif architecture == "ResNet18":
        model_output_size = 512
        epochs = 200
        milestones = [60, 120, 160]
        feature_extractor = resnet18()
    elif architecture == "ResNet50":
        model_output_size = 2048
        epochs = 200
        milestones = [60, 120, 160]
        feature_extractor = resnet50()
    elif architecture == "ResNet110":
        model_output_size = 2048
        epochs = 200
        milestones = [60, 120, 160]
        feature_extractor = resnet110()
    elif architecture == "DenseNet121":
        model_output_size = 1024
        epochs = 200
        milestones = [60, 120, 160]
        feature_extractor = densenet121()

        # Adapted resnet from:
        # https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py
        feature_extractor.conv1 = torch.nn.Conv2d(3,
                                                  64,
                                                  kernel_size=3,
                                                  stride=1,
                                                  padding=1,
                                                  bias=False)
        feature_extractor.maxpool = torch.nn.Identity()
        feature_extractor.fc = torch.nn.Identity()

    if centroid_size is None:
        centroid_size = model_output_size

    model = ResNet_DUQ(
        feature_extractor,
        num_classes,
        centroid_size,
        model_output_size,
        length_scale,
        gamma,
    )
    model = model.cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=learning_rate,
                                momentum=0.9,
                                weight_decay=weight_decay)

    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=milestones,
                                                     gamma=0.2)

    def calc_gradients_input(x, y_pred):
        gradients = torch.autograd.grad(
            outputs=y_pred,
            inputs=x,
            grad_outputs=torch.ones_like(y_pred),
            create_graph=True,
        )[0]

        gradients = gradients.flatten(start_dim=1)

        return gradients

    def calc_gradient_penalty(x, y_pred):
        gradients = calc_gradients_input(x, y_pred)

        # L2 norm
        grad_norm = gradients.norm(2, dim=1)

        # Two sided penalty
        gradient_penalty = ((grad_norm - 1)**2).mean()

        return gradient_penalty

    def step(engine, batch):
        model.train()

        optimizer.zero_grad()

        x, y = batch
        x, y = x.cuda(), y.cuda()

        x.requires_grad_(True)

        y_pred = model(x)

        y = F.one_hot(y, num_classes).float()

        loss = F.binary_cross_entropy(y_pred, y, reduction="mean")

        if l_gradient_penalty > 0:
            gp = calc_gradient_penalty(x, y_pred)
            loss += l_gradient_penalty * gp

        loss.backward()
        optimizer.step()

        x.requires_grad_(False)

        with torch.no_grad():
            model.eval()
            model.update_embeddings(x, y)

        return loss.item()

    def eval_step(engine, batch):
        model.eval()

        x, y = batch
        x, y = x.cuda(), y.cuda()

        x.requires_grad_(True)

        y_pred = model(x)

        return {"x": x, "y": y, "y_pred": y_pred}

    trainer = Engine(step)
    evaluator = Engine(eval_step)

    metric = Average()
    metric.attach(trainer, "loss")

    metric = Accuracy(output_transform=lambda out: (out["y_pred"], out["y"]))
    metric.attach(evaluator, "accuracy")

    def bce_output_transform(out):
        return (out["y_pred"], F.one_hot(out["y"], num_classes).float())

    metric = Loss(F.binary_cross_entropy,
                  output_transform=bce_output_transform)
    metric.attach(evaluator, "bce")

    metric = Loss(calc_gradient_penalty,
                  output_transform=lambda out: (out["x"], out["y_pred"]))
    metric.attach(evaluator, "gradient_penalty")

    pbar = ProgressBar(dynamic_ncols=True)
    pbar.attach(trainer)

    kwargs = {"num_workers": 4, "pin_memory": True}

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               drop_last=True,
                                               **kwargs)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             **kwargs)

    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=batch_size,
                                              shuffle=False,
                                              **kwargs)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_results(trainer):
        metrics = trainer.state.metrics
        loss = metrics["loss"]

        print(f"Train - Epoch: {trainer.state.epoch} Loss: {loss:.2f}")

        writer.add_scalar("Loss/train", loss, trainer.state.epoch)

        if trainer.state.epoch > (epochs - 5):
            accuracy, auroc = get_cifar_svhn_ood(model)
            print(f"Test Accuracy: {accuracy}, AUROC: {auroc}")
            writer.add_scalar("OoD/test_accuracy", accuracy,
                              trainer.state.epoch)
            writer.add_scalar("OoD/roc_auc", auroc, trainer.state.epoch)

            accuracy, auroc = get_auroc_classification(val_dataset, model)
            print(f"AUROC - uncertainty: {auroc}")
            writer.add_scalar("OoD/val_accuracy", accuracy,
                              trainer.state.epoch)
            writer.add_scalar("OoD/roc_auc_classification", auroc,
                              trainer.state.epoch)

        evaluator.run(val_loader)
        metrics = evaluator.state.metrics
        acc = metrics["accuracy"]
        bce = metrics["bce"]
        GP = metrics["gradient_penalty"]
        loss = bce + l_gradient_penalty * GP

        print((f"Valid - Epoch: {trainer.state.epoch} "
               f"Acc: {acc:.4f} "
               f"Loss: {loss:.2f} "
               f"BCE: {bce:.2f} "
               f"GP: {GP:.2f} "))

        writer.add_scalar("Loss/valid", loss, trainer.state.epoch)
        writer.add_scalar("BCE/valid", bce, trainer.state.epoch)
        writer.add_scalar("GP/valid", GP, trainer.state.epoch)
        writer.add_scalar("Accuracy/valid", acc, trainer.state.epoch)

        scheduler.step()

    trainer.run(train_loader, max_epochs=epochs)
    evaluator.run(test_loader)
    acc = evaluator.state.metrics["accuracy"]

    print(f"Test - Accuracy {acc:.4f}")

    torch.save(model.state_dict(), f"runs/{output_dir}/model.pt")
    writer.close()
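
The load-bearing detail in calc_gradients_input is create_graph=True: it keeps the gradient computation itself on the autograd graph, so the penalty term can be backpropagated through. A self-contained toy sketch of the same two-sided penalty:

import torch

x = torch.randn(4, 3, requires_grad=True)
y = (x ** 2).sum(dim=1)
grads = torch.autograd.grad(outputs=y,
                            inputs=x,
                            grad_outputs=torch.ones_like(y),
                            create_graph=True)[0]
penalty = ((grads.flatten(start_dim=1).norm(2, dim=1) - 1) ** 2).mean()
penalty.backward()  # only possible because create_graph=True above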
Example #18
def test_concat_scheduler():
    tensor = torch.zeros([1], requires_grad=True)
    optimizer = torch.optim.SGD([tensor], lr=0)

    scheduler_1 = LinearCyclicalScheduler(optimizer,
                                          "lr",
                                          start_value=1.0,
                                          end_value=0.0,
                                          cycle_size=10)
    scheduler_2 = CosineAnnealingScheduler(optimizer,
                                           "lr",
                                           start_value=0.0,
                                           end_value=1.0,
                                           cycle_size=10)
    durations = [
        10,
    ]

    concat_scheduler = ConcatScheduler(schedulers=[scheduler_1, scheduler_2],
                                       durations=durations,
                                       save_history=True)

    data = [0] * 10
    max_epochs = 2
    simulated_values = ConcatScheduler.simulate_values(
        num_events=len(data) * max_epochs,
        schedulers=[scheduler_1, scheduler_2],
        durations=durations)

    lrs = []

    def save_lr(engine):
        lrs.append(optimizer.param_groups[0]['lr'])

    trainer = Engine(lambda engine, batch: None)
    trainer.add_event_handler(Events.ITERATION_STARTED, concat_scheduler)
    trainer.add_event_handler(Events.ITERATION_COMPLETED, save_lr)
    trainer.run(data, max_epochs=max_epochs)

    assert lrs == list(
        map(
            pytest.approx,
            [
                # Cycle 1 of the LinearCyclicalScheduler
                1.0, 0.8, 0.6, 0.4, 0.2, 0.0, 0.2, 0.4, 0.6, 0.8,
                # Cycle 1 of the CosineAnnealingScheduler
                0.0,
                0.02447174185242318,
                0.09549150281252627,
                0.20610737385376332,
                0.3454915028125263,
                0.5,
                0.6545084971874737,
                0.7938926261462365,
                0.9045084971874737,
                0.9755282581475768,
            ]))

    state_lrs = trainer.state.param_history['lr']
    assert len(state_lrs) == len(lrs)
    # Unpack singleton lists
    assert [group[0] for group in state_lrs] == lrs

    assert lrs == pytest.approx([v for i, v in simulated_values])
Example #19
def test_attach_fail_with_string():
    engine = Engine(update_fn)
    pbar = ProgressBar()

    with pytest.raises(TypeError):
        pbar.attach(engine, 'a')
Example #20
def test_concat_scheduler_3_schedulers():
    tensor = torch.zeros([1], requires_grad=True)
    optimizer = torch.optim.SGD([tensor], lr=0)

    scheduler_1 = LinearCyclicalScheduler(optimizer,
                                          "lr",
                                          start_value=1.0,
                                          end_value=0.5,
                                          cycle_size=20)
    scheduler_2 = LinearCyclicalScheduler(optimizer,
                                          "lr",
                                          start_value=0.5,
                                          end_value=0.45,
                                          cycle_size=10)
    scheduler_3 = LinearCyclicalScheduler(optimizer,
                                          "lr",
                                          start_value=0.5,
                                          end_value=0.0,
                                          cycle_size=20)
    durations = [10, 5]

    concat_scheduler = ConcatScheduler(
        schedulers=[scheduler_1, scheduler_2, scheduler_3],
        durations=durations,
        save_history=True)

    data = [0] * 10
    max_epochs = 2
    simulated_values = ConcatScheduler.simulate_values(
        num_events=len(data) * max_epochs,
        schedulers=[scheduler_1, scheduler_2, scheduler_3],
        durations=durations)
    lrs = []

    def save_lr(engine):
        lrs.append(optimizer.param_groups[0]['lr'])

    trainer = Engine(lambda engine, batch: None)
    trainer.add_event_handler(Events.ITERATION_STARTED, concat_scheduler)
    trainer.add_event_handler(Events.ITERATION_COMPLETED, save_lr)
    trainer.run(data, max_epochs=max_epochs)

    assert lrs == list(
        map(
            pytest.approx,
            [
                # Cycle 1 of the first LinearCyclicalScheduler
                1.0, 0.95, 0.9, 0.85, 0.8, 0.75, 0.7, 0.65, 0.6, 0.55,
                # Cycle 1 of the second LinearCyclicalScheduler
                0.5, 0.49, 0.48, 0.47, 0.46,
                # Cycle 1 of the third LinearCyclicalScheduler
                0.5, 0.45, 0.4, 0.35, 0.3,
            ]))

    state_lrs = trainer.state.param_history['lr']
    assert len(state_lrs) == len(lrs)
    # Unpack singleton lists
    assert [group[0] for group in state_lrs] == lrs

    assert lrs == pytest.approx([v for i, v in simulated_values])
Example #21
def test_integration():

    n_iters = 100
    batch_size = 10
    n_classes = 10
    y_true_batch_values = iter(
        np.random.randint(0, n_classes, size=(n_iters, batch_size)))
    y_pred_batch_values = iter(np.random.rand(n_iters, batch_size, n_classes))
    loss_values = iter(range(n_iters))

    def update_fn(engine, batch):
        loss_value = next(loss_values)
        y_true_batch = next(y_true_batch_values)
        y_pred_batch = next(y_pred_batch_values)
        return (
            loss_value,
            torch.from_numpy(y_pred_batch),
            torch.from_numpy(y_true_batch),
        )

    trainer = Engine(update_fn)
    alpha = 0.98

    acc_metric = RunningAverage(
        Accuracy(output_transform=lambda x: [x[1], x[2]]), alpha=alpha)
    acc_metric.attach(trainer, "running_avg_accuracy")

    avg_output = RunningAverage(output_transform=lambda x: x[0], alpha=alpha)
    avg_output.attach(trainer, "running_avg_output")

    running_avg_acc = [
        None,
    ]

    @trainer.on(Events.ITERATION_COMPLETED)
    def manual_running_avg_acc(engine):
        _, y_pred, y = engine.state.output
        indices = torch.max(y_pred, 1)[1]
        correct = torch.eq(indices, y).view(-1)
        num_correct = torch.sum(correct).item()
        num_examples = correct.shape[0]
        batch_acc = num_correct * 1.0 / num_examples
        if running_avg_acc[0] is None:
            running_avg_acc[0] = batch_acc
        else:
            running_avg_acc[0] = running_avg_acc[0] * alpha + (
                1.0 - alpha) * batch_acc
        engine.state.running_avg_acc = running_avg_acc[0]

    @trainer.on(Events.EPOCH_STARTED)
    def running_avg_output_init(engine):
        engine.state.running_avg_output = None

    @trainer.on(Events.ITERATION_COMPLETED)
    def running_avg_output_update(engine):
        if engine.state.running_avg_output is None:
            engine.state.running_avg_output = engine.state.output[0]
        else:
            engine.state.running_avg_output = (
                engine.state.running_avg_output * alpha +
                (1.0 - alpha) * engine.state.output[0])

    @trainer.on(Events.ITERATION_COMPLETED)
    def assert_equal_running_avg_acc_values(engine):
        assert (engine.state.running_avg_acc == engine.state.
                metrics["running_avg_accuracy"]), "{} vs {}".format(
                    engine.state.running_avg_acc,
                    engine.state.metrics["running_avg_accuracy"])

    @trainer.on(Events.ITERATION_COMPLETED)
    def assert_equal_running_avg_output_values(engine):
        assert (engine.state.running_avg_output ==
                engine.state.metrics["running_avg_output"]), "{} vs {}".format(
                    engine.state.running_avg_output,
                    engine.state.metrics["running_avg_output"])

    np.random.seed(10)
    running_avg_acc = [
        None,
    ]
    n_iters = 10
    batch_size = 10
    n_classes = 10
    data = list(range(n_iters))
    loss_values = iter(range(n_iters))
    y_true_batch_values = iter(
        np.random.randint(0, n_classes, size=(n_iters, batch_size)))
    y_pred_batch_values = iter(np.random.rand(n_iters, batch_size, n_classes))
    trainer.run(data, max_epochs=1)

    running_avg_acc = [
        None,
    ]
    n_iters = 10
    batch_size = 10
    n_classes = 10
    data = list(range(n_iters))
    loss_values = iter(range(n_iters))
    y_true_batch_values = iter(
        np.random.randint(0, n_classes, size=(n_iters, batch_size)))
    y_pred_batch_values = iter(np.random.rand(n_iters, batch_size, n_classes))
    trainer.run(data, max_epochs=1)
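
The manual handlers above mirror what RunningAverage computes internally: an exponential moving average v <- alpha * v + (1 - alpha) * x, seeded with the first observed value. For example, with alpha = 0.98 and batch accuracies 1.0 followed by 0.5, the running value is 1.0 after the first iteration and 0.98 * 1.0 + 0.02 * 0.5 = 0.99 after the second.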
Example #22
def test_terminate():
    engine = Engine(lambda e, b: 1)
    assert not engine.should_terminate
    engine.terminate()
    assert engine.should_terminate
Example #23
def create_setops_trainer(base_model,
                          classifier,
                          setops_model,
                          optimizer,
                          criterion1,
                          criterion2,
                          params_object,
                          metrics={},
                          device=None):
    """
    Factory function for creating a trainer for supervised models

    Args:
        model (`torch.nn.Module`): the model to train
        optimizer (`torch.optim.Optimizer`): the optimizer to use
        loss_fn (torch.nn loss function): the loss function to use
        device (str, optional): device type specification (default: None).
            Applies to both model and batches.

    Returns:
        Engine: a trainer engine with supervised update function
    """
    if device:
        base_model.to(device)
        classifier.to(device)
        setops_model.to(device)

    def _update(engine, batch):

        if params_object.train_base:
            base_model.train()
        else:
            base_model.eval()

        classifier.train()
        setops_model.train()

        optimizer.zero_grad()

        input_a, input_b, target_a, target_b = _prepare_batch(batch,
                                                              device=device)

        #
        # Apply the classification model
        #
        with conditional(not params_object.train_base, torch.no_grad()):
            embed_a = base_model(input_a)
            embed_b = base_model(input_b)

        output_a = classifier(embed_a)
        output_b = classifier(embed_b)

        #
        # Apply the setopt model.
        #
        outputs_setopt = setops_model(embed_a, embed_b)
        fake_a, fake_b, a_S_b, b_S_a, a_U_b, b_U_a, a_I_b, b_I_a, \
        a_S_b_b, b_S_a_a, a_I_b_b, b_I_a_a, a_U_b_b, b_U_a_a, \
        a_S_b_I_a, b_S_a_I_b, a_S_a_I_b, b_S_b_I_a = \
                    [classifier(o) for o in outputs_setopt]
        fake_a_em, fake_b_em, a_S_b_em, b_S_a_em, a_U_b_em, b_U_a_em, a_I_b_em, b_I_a_em, \
        a_S_b_b_em, b_S_a_a_em, a_I_b_b_em, b_I_a_a_em, a_U_b_b_em, b_U_a_a_em, \
        a_S_b_I_a_em, b_S_a_I_b_em, a_S_a_I_b_em, b_S_b_I_a_em = outputs_setopt

        loss_class = criterion1(output_a, target_a) + criterion1(
            output_b, target_b)
        loss_class_out = criterion1(fake_a, target_a) + criterion1(
            fake_b, target_b)
        if params_object.mc_toggle:
            loss_recon = criterion2(embed_a, fake_a_em) + criterion2(
                embed_b, fake_b_em)
            return_loss_recon = loss_recon.item()
        else:
            loss_recon = 0
            return_loss_recon = 0

        #
        # Calculate the target setopt operations
        #
        target_a = target_a.type(torch.cuda.ByteTensor)
        target_b = target_b.type(torch.cuda.ByteTensor)

        target_a_I_b = target_a & target_b
        target_a_U_b = target_a | target_b
        target_a_S_b = target_a & ~target_a_I_b
        target_b_S_a = target_b & ~target_a_I_b

        target_a_I_b = target_a_I_b.type(torch.cuda.FloatTensor)
        target_a_U_b = target_a_U_b.type(torch.cuda.FloatTensor)
        target_a_S_b = target_a_S_b.type(torch.cuda.FloatTensor)
        target_b_S_a = target_b_S_a.type(torch.cuda.FloatTensor)

        loss_class_S = criterion1(a_S_b, target_a_S_b) + criterion1(
            b_S_a, target_b_S_a)
        loss_class_U = criterion1(a_U_b, target_a_U_b)
        loss_class_I = criterion1(a_I_b, target_a_I_b)
        if params_object.tautology_class_toggle:
            loss_class_S += criterion1(a_S_b_b, target_a_S_b) + criterion1(
                b_S_a_a, target_b_S_a)
            loss_class_S += criterion1(a_S_a_I_b, target_a_S_b) + criterion1(b_S_a_I_b, target_b_S_a) +\
                            criterion1(b_S_b_I_a, target_b_S_a) + criterion1(a_S_b_I_a, target_a_S_b)
            loss_class_U += criterion1(a_U_b_b, target_a_U_b) + criterion1(
                b_U_a_a, target_a_U_b)
            loss_class_I += criterion1(a_I_b_b, target_a_I_b) + criterion1(
                b_I_a_a, target_a_I_b)

        if params_object.tautology_recon_toggle:
            loss_recon_S = criterion2(a_S_b_em, a_S_b_b_em) + criterion2(a_S_b_em, a_S_a_I_b_em) + \
                           criterion2(a_S_b_em, a_S_b_I_a_em)
            loss_recon_S += criterion2(b_S_a_em, b_S_a_a_em) + criterion2(b_S_a_em, b_S_a_I_b_em) + \
                            criterion2(b_S_a_em, b_S_b_I_a_em)
            return_recon_S = loss_recon_S.item()
        else:
            loss_recon_S = 0
            return_recon_S = 0

        if params_object.sym_class_toggle:
            loss_class_U += criterion1(b_U_a, target_a_U_b)
            loss_class_I += criterion1(b_I_a, target_a_I_b)

        if params_object.sym_recon_toggle:
            loss_recon_U = criterion2(a_U_b_em, b_U_a_em)
            loss_recon_I = criterion2(a_I_b_em, b_I_a_em)
            return_recon_U = loss_recon_U.item()
            return_recon_I = loss_recon_I.item()
        else:
            loss_recon_U = 0
            loss_recon_I = 0
            return_recon_U = 0
            return_recon_I = 0

        loss = loss_class
        loss += 0 if params_object.class_fake_loss_weight == 0 else params_object.class_fake_loss_weight * loss_class_out
        loss += 0 if (params_object.recon_loss_weight == 0) or (
            not loss_recon) else params_object.recon_loss_weight * loss_recon
        loss += 0 if params_object.class_S_loss_weight == 0 else params_object.class_S_loss_weight * loss_class_S
        loss += 0 if (params_object.recon_loss_weight == 0) or (
            not loss_recon_I
        ) else params_object.recon_loss_weight * loss_recon_S
        loss += 0 if params_object.class_U_loss_weight == 0 else params_object.class_U_loss_weight * loss_class_U
        loss += 0 if (params_object.recon_loss_weight == 0) or (
            not loss_recon_U
        ) else params_object.recon_loss_weight * loss_recon_U
        loss += 0 if params_object.class_I_loss_weight == 0 else params_object.class_I_loss_weight * loss_class_I
        loss += 0 if (params_object.recon_loss_weight == 0) or (
            not loss_recon_I
        ) else params_object.recon_loss_weight * loss_recon_I

        loss.backward()
        optimizer.step()

        return {
            "main": loss.item(),
            "real class": loss_class.item(),
            "fake class": loss_class_out.item(),
            "fake MSE": return_loss_recon,
            "S MSE": return_recon_S,
            "U MSE": return_recon_U,
            "I MSE": return_recon_I,
            "S class": loss_class_S.item(),
            "U class": loss_class_U.item(),
            "I class": loss_class_I.item()
        }

    engine = Engine(_update)

    for name, metric in metrics.items():
        metric.attach(engine, name)

    return engine
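
A hypothetical call site for this factory (the model, optimizer, criterion, and loader names below are placeholders, not part of the snippet):

trainer = create_setops_trainer(
    base_model, classifier, setops_model, optimizer,
    criterion1, criterion2, params_object,
    metrics={"loss": RunningAverage(output_transform=lambda out: out["main"])},
    device="cuda")
trainer.run(train_loader, max_epochs=50)

The RunningAverage output_transform pulls the "main" entry out of the dict returned by _update, matching how metric.attach is called inside the factory.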
Example #24
def get_engine():
    engine = Engine(sum_data)
    average = Average()
    average.attach(engine, "average")
    return engine
Example #25
def create_trainer(model, optimizer, criterion, lr_scheduler, train_sampler,
                   config, logger):

    device = idist.device()

    # Setup Ignite trainer:
    # - let's define training step
    # - add other common handlers:
    #    - TerminateOnNan,
    #    - handler to setup learning rate scheduling,
    #    - ModelCheckpoint
    #    - RunningAverage` on `train_step` output
    #    - Two progress bars on epochs and optionally on iterations

    cutmix_beta = config["cutmix_beta"]
    cutmix_prob = config["cutmix_prob"]
    with_amp = config["with_amp"]
    scaler = GradScaler(enabled=with_amp)

    def train_step(engine, batch):

        x, y = batch[0], batch[1]

        if x.device != device:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

        model.train()

        with autocast(enabled=with_amp):
            r = torch.rand(1).item()
            if cutmix_beta > 0 and r < cutmix_prob:
                output, loss = utils.cutmix_forward(model, x, criterion, y,
                                                    cutmix_beta)
            else:
                output = model(x)
                loss = criterion(output, y)

        optimizer.zero_grad()
        scaler.scale(loss).backward()

        if idist.backend() == "horovod":
            optimizer.synchronize()
            with optimizer.skip_synchronize():
                scaler.step(optimizer)
                scaler.update()
        else:
            scaler.step(optimizer)
            scaler.update()

        return {
            "batch loss": loss.item(),
        }

    trainer = Engine(train_step)
    trainer.logger = logger

    if config["with_pbar"] and idist.get_rank() == 0:
        ProgressBar().attach(trainer)

    to_save = {
        "trainer": trainer,
        "model": model,
        "optimizer": optimizer,
        "lr_scheduler": lr_scheduler
    }
    metric_names = [
        "batch loss",
    ]

    common.setup_common_training_handlers(
        trainer=trainer,
        train_sampler=train_sampler,
        to_save=to_save,
        save_every_iters=config["checkpoint_every"],
        save_handler=get_save_handler(config),
        lr_scheduler=lr_scheduler,
        output_names=metric_names,
        with_pbars=False,
        clear_cuda_cache=False,
    )

    resume_from = config["resume_from"]
    if resume_from is not None:
        checkpoint_fp = Path(resume_from)
        assert checkpoint_fp.exists(), \
            f"Checkpoint '{checkpoint_fp.as_posix()}' is not found"
        logger.info(f"Resume from a checkpoint: {checkpoint_fp.as_posix()}")
        checkpoint = torch.load(checkpoint_fp.as_posix(), map_location="cpu")
        Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    return trainer
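
A hypothetical call site for this factory (placeholder names; config must carry the keys read above such as "cutmix_beta", "with_amp", "checkpoint_every" and "resume_from", and "num_epochs" is assumed here):

trainer = create_trainer(model, optimizer, criterion, lr_scheduler,
                         train_sampler, config, logger)
trainer.run(train_loader, max_epochs=config["num_epochs"])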
Example #26
def test_stopping_criterion_is_max_epochs():
    engine = Engine(MagicMock(return_value=1))
    max_epochs = 5
    state = engine.run([1], max_epochs=max_epochs)
    assert state.epoch == max_epochs
Example #27
def train():
    parser = ArgumentParser()
    parser.add_argument("--train_path",
                        type=str,
                        default="data/train_set4DSTC7-AVSD.json",
                        help="Path of the trainset")
    parser.add_argument("--fea_path",
                        type=str,
                        default="data/",
                        help="Path of the trainset")
    parser.add_argument("--valid_path",
                        type=str,
                        default="data/valid_set4DSTC7-AVSD.json",
                        help="Path of the validset")
    parser.add_argument("--model_checkpoint",
                        type=str,
                        default="gpt2",
                        help="Path, url or short name of the model")
    parser.add_argument("--max_history",
                        type=int,
                        default=3,
                        help="Number of previous exchanges to keep in history")
    parser.add_argument("--train_batch_size",
                        type=int,
                        default=4,
                        help="Batch size for training")
    parser.add_argument("--valid_batch_size",
                        type=int,
                        default=4,
                        help="Batch size for validation")
    parser.add_argument("--drop_rate",
                        type=float,
                        default=0.5,
                        help="drop rate for caption")
    parser.add_argument("--gradient_accumulation_steps",
                        type=int,
                        default=8,
                        help="Accumulate gradients on several steps")
    parser.add_argument("--lr",
                        type=float,
                        default=6.25e-5,
                        help="Learning rate")
    parser.add_argument("--max_norm",
                        type=float,
                        default=1.0,
                        help="Clipping gradient norm")
    parser.add_argument("--n_epochs",
                        type=int,
                        default=8,
                        help="Number of training epochs")
    parser.add_argument(
        "--eval_before_start",
        action='store_true',
        help="If true start with a first evaluation before training")
    parser.add_argument("--device",
                        type=str,
                        default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device (cuda or cpu)")
    parser.add_argument(
        "--fp16",
        type=str,
        default="",
        help=
        "Set to O0, O1, O2 or O3 for fp16 training (see apex documentation)")
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="Local rank for distributed training (-1: not distributed)")
    parser.add_argument("--log_path",
                        type=str,
                        default="log/",
                        help="Log path")
    args = parser.parse_args()

    if not os.path.exists(args.log_path):
        os.makedirs(args.log_path)
    # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process. logger.info => log main process only, logger.warning => log all processes
    logging.basicConfig(
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning(
        "Running process %d", args.local_rank
    )  # This is a logger.warning: it will be printed by all distributed processes
    logger.info("Arguments: %s", pformat(args))

    # Initialize distributed training if needed
    args.distributed = (args.local_rank != -1)
    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        args.device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')

    logger.info(
        "Prepare tokenizer, pretrained model and optimizer - add special tokens for fine-tuning"
    )
    tokenizer_class = GPT2Tokenizer
    tokenizer = tokenizer_class.from_pretrained(args.model_checkpoint)
    model_class = VideoGPT2LMHeadModel
    model = model_class.from_pretrained(args.model_checkpoint)
    tokenizer.add_special_tokens(SPECIAL_TOKENS_DICT)
    model.resize_token_embeddings(len(tokenizer))
    model.to(args.device)
    optimizer = AdamW(model.parameters(), lr=args.lr)

    # Prepare model for FP16 and distributed training if needed (order is important, distributed should be the last)
    if args.fp16:
        from apex import amp  # Apex is only required if we use fp16 training
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16)
    if args.distributed:
        model = DistributedDataParallel(model,
                                        device_ids=[args.local_rank],
                                        output_device=args.local_rank)

    logger.info("Prepare datasets")
    train_loader, val_loader = get_data_loaders_new(args, tokenizer)

    # Training function and trainer
    def update(engine, batch):
        model.train()
        batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
        input_ids, token_type_ids, labels, input_mask, i3d, video_mask, reply_mask = batch
        input_embs = model.transformer.wte(input_ids)
        video_embs = model.video_ff(i3d)
        input_embs = torch.cat([video_embs, input_embs], dim=1)
        token_type_ids = torch.cat([
            torch.ones((i3d.size(0), i3d.size(1))).long().cuda() *
            tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS[-2]), token_type_ids
        ],
                                   dim=1)
        video_loss = model(input_embs,
                           token_type_ids=token_type_ids,
                           labels=(labels, i3d),
                           attention_mask=[video_mask, input_mask],
                           mode="video")[0]
        reply_loss = model(input_embs,
                           token_type_ids=token_type_ids,
                           labels=(labels, i3d),
                           attention_mask=[reply_mask, input_mask],
                           mode="reply")[0]
        loss = (video_loss + reply_loss) / args.gradient_accumulation_steps
        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                           args.max_norm)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
        if engine.state.iteration % args.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        return loss.item()

    trainer = Engine(update)

    # Evaluation function and evaluator (evaluator output is the input of the metrics)
    def inference(engine, batch):
        model.eval()
        with torch.no_grad():
            batch = tuple(
                input_tensor.to(args.device) for input_tensor in batch)
            input_ids, token_type_ids, lm_labels, input_mask, i3d, video_mask, reply_mask = batch
            input_embs = model.transformer.wte(input_ids)
            video_embs = model.video_ff(i3d)
            input_embs = torch.cat([video_embs, input_embs], dim=1)
            token_type_ids = torch.cat([
                torch.ones((i3d.size(0), i3d.size(1))).long().cuda() *
                tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS[-2]),
                token_type_ids
            ],
                                       dim=1)
            model_outputs = model(input_embs,
                                  token_type_ids=token_type_ids,
                                  attention_mask=[reply_mask, input_mask])[0]

            lm_logits = model_outputs  # So we can also use GPT2 outputs
            lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view(
                -1, lm_logits.size(-1))
            lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1)
            return lm_logits_flat_shifted, lm_labels_flat_shifted

    evaluator = Engine(inference)

    # Attach evaluation to trainer: we evaluate when we start the training and at the end of each epoch
    trainer.add_event_handler(Events.EPOCH_COMPLETED,
                              lambda _: evaluator.run(val_loader))
    if args.n_epochs < 1:
        trainer.add_event_handler(Events.COMPLETED,
                                  lambda _: evaluator.run(val_loader))
    if args.eval_before_start:
        trainer.add_event_handler(Events.STARTED,
                                  lambda _: evaluator.run(val_loader))

    # Linearly decrease the learning rate from lr to zero
    scheduler = PiecewiseLinear(optimizer, "lr",
                                [(0, args.lr),
                                 (args.n_epochs * len(train_loader), 0.0)])
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # Prepare metrics - note how we compute distributed metrics
    RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
    metrics = {
        "nll":
        Loss(torch.nn.CrossEntropyLoss(ignore_index=-1),
             output_transform=lambda x: (x[0], x[1]))
    }
    metrics.update({
        "average_nll":
        MetricsLambda(average_distributed_scalar, metrics["nll"], args)
    })
    metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"])
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # On the main process: add progress bar, tensorboard, checkpoints and save model, configuration and tokenizer before we start to train
    if args.local_rank in [-1, 0]:
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer, metric_names=["loss"])
        evaluator.add_event_handler(
            Events.COMPLETED, lambda _: pbar.log_message(
                "Validation: %s" % pformat(evaluator.state.metrics)))

        tb_logger = TensorboardLogger(log_dir="./tb_logs")
        tb_logger.attach(trainer,
                         log_handler=OutputHandler(tag="training",
                                                   metric_names=["loss"]),
                         event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer,
                         log_handler=OptimizerParamsHandler(optimizer),
                         event_name=Events.ITERATION_STARTED)
        tb_logger.attach(evaluator,
                         log_handler=OutputHandler(tag="validation",
                                                   metric_names=list(
                                                       metrics.keys()),
                                                   another_engine=trainer),
                         event_name=Events.EPOCH_COMPLETED)

        checkpoint_handler = ModelCheckpoint(args.log_path,
                                             'checkpoint',
                                             save_interval=1,
                                             n_saved=8,
                                             require_empty=False)
        trainer.add_event_handler(
            Events.EPOCH_COMPLETED, checkpoint_handler,
            {'mymodel': getattr(model, 'module', model)
             })  # "getattr" take care of distributed encapsulation

        torch.save(args, args.log_path + 'model_training_args.bin')
        getattr(model, 'module', model).config.to_json_file(
            os.path.join(args.log_path, CONFIG_NAME))
        tokenizer.save_vocabulary(args.log_path)

    # Run the training
    trainer.run(train_loader, max_epochs=args.n_epochs)

    # On the main process: close tensorboard logger and rename the last checkpoint (for easy re-loading with OpenAIGPTModel.from_pretrained method)
    if args.local_rank in [-1, 0] and args.n_epochs > 0:
        os.rename(
            checkpoint_handler._saved[-1][1][-1],
            os.path.join(args.log_path, WEIGHTS_NAME)
        )  # TODO: PR in ignite to have better access to saved file paths (cleaner)
        tb_logger.close()
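
Note how update divides the loss by args.gradient_accumulation_steps and only steps the optimizer every gradient_accumulation_steps iterations, so the accumulated gradients are averaged rather than summed. With the defaults above (train_batch_size=4, gradient_accumulation_steps=8), the effective batch size per process is 4 * 8 = 32.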
Example #28
def test_run_with_max_iters():
    max_iters = 8
    engine = Engine(lambda e, b: 1)
    engine.run([0] * 20, max_iters=max_iters)
    assert engine.state.iteration == max_iters
    assert engine.state.max_iters == max_iters
Example #29
def _test_setup_logging(
    setup_logging_fn,
    kwargs_dict,
    output_handler_cls,
    opt_params_handler_cls,
    with_eval=True,
    with_optim=True,
    as_class=False,
    log_every_iters=1,
):
    trainer = Engine(lambda e, b: b)
    evaluators = None
    optimizers = None

    if with_eval:
        evaluator = Engine(lambda e, b: None)
        acc_scores = [0.1, 0.2, 0.3, 0.4, 0.3, 0.3, 0.2, 0.1, 0.1, 0.0]

        @trainer.on(Events.EPOCH_COMPLETED)
        def validate(engine):
            evaluator.run([0, 1])

        @evaluator.on(Events.EPOCH_COMPLETED)
        def set_eval_metric(engine):
            engine.state.metrics = {"acc": acc_scores[trainer.state.epoch - 1]}

        evaluators = {"validation": evaluator}
        if as_class:
            evaluators = evaluators["validation"]

    if with_optim:
        t = torch.tensor([
            0,
        ])
        optimizers = {
            "optimizer": torch.optim.SGD([
                t,
            ], lr=0.01)
        }
        if as_class:
            optimizers = optimizers["optimizer"]

    kwargs_dict["trainer"] = trainer
    kwargs_dict["optimizers"] = optimizers
    kwargs_dict["evaluators"] = evaluators
    kwargs_dict["log_every_iters"] = log_every_iters

    x_logger = setup_logging_fn(**kwargs_dict)

    handlers = trainer._event_handlers[Events.ITERATION_COMPLETED]
    for cls in [
            output_handler_cls,
    ]:
        assert any([isinstance(h[0], cls)
                    for h in handlers]), "{}".format(handlers)

    if with_optim:
        handlers = trainer._event_handlers[Events.ITERATION_STARTED]
        for cls in [
                opt_params_handler_cls,
        ]:
            assert any([isinstance(h[0], cls)
                        for h in handlers]), "{}".format(handlers)

    if with_eval:
        handlers = evaluator._event_handlers[Events.COMPLETED]
        for cls in [
                output_handler_cls,
        ]:
            assert any([isinstance(h[0], cls)
                        for h in handlers]), "{}".format(handlers)

    data = [0, 1, 2]
    trainer.run(data, max_epochs=10)

    if "output_path" in kwargs_dict:
        tb_files = list(os.listdir(kwargs_dict["output_path"]))
        assert len(tb_files) == 1
        for v in [
                "events",
        ]:
            assert any([v in c for c in tb_files]), "{}".format(tb_files)

    return x_logger
Example #30
def test_default_exception_handler():
    update_function = MagicMock(side_effect=ValueError())
    engine = Engine(update_function)

    with raises(ValueError):
        engine.run([1])
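
By default an Engine re-raises whatever its process function throws, which is exactly what this test asserts. A small sketch of overriding that behaviour through Ignite's built-in exception event (the handler name is illustrative):

def handle_exception(engine, e):
    if isinstance(e, ValueError):
        engine.terminate()  # stop gracefully instead of propagating
    else:
        raise e

engine.add_event_handler(Events.EXCEPTION_RAISED, handle_exception)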