Example #1
def create_baseline_trainer(model, optimizer=None, name='train', device=None):
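    # Build an Ignite Engine that runs one cross-entropy step per batch and
    # attaches Average/Accuracy metrics; passing an optimizer enables
    # backprop, omitting it yields an evaluation-only engine.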

    if device is not None:
        model.to(device)

    is_train = optimizer is not None

    def _update(engine, batch):
        model.train(is_train)

        with torch.set_grad_enabled(is_train):
            images, labels = convert_tensor(batch, device=device)
            preds = model(images)
            loss = F.cross_entropy(preds, labels)

        if is_train:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        return {'loss': loss.item(), 'y_pred': preds, 'y': labels}

    engine = Engine(_update)
    engine.name = name
    metrics.Average(lambda o: o['loss']).attach(engine, 'single_loss')
    metrics.Accuracy(lambda o: (o['y_pred'], o['y'])).attach(
        engine, 'single_acc')
    return engine
Example #2
 def __init__(
     self,
     job_dir,
     architecture,
     dim_input,
     dim_hidden,
     dim_output,
     depth,
     negative_slope,
     batch_norm,
     spectral_norm,
     dropout_rate,
     num_examples,
     learning_rate,
     batch_size,
     epochs,
     patience,
     num_workers,
     seed,
 ):
     super(DragonNet, self).__init__(
         job_dir=job_dir,
         num_examples=num_examples,
         learning_rate=learning_rate,
         batch_size=batch_size,
         epochs=epochs,
         seed=seed,
         num_workers=num_workers,
     )
     self.network = dragonnet.DragonNet(
         architecture=architecture,
         dim_input=dim_input,
         dim_hidden=dim_hidden,
         dim_output=dim_output,
         depth=depth,
         negative_slope=negative_slope,
         batch_norm=batch_norm,
         dropout_rate=dropout_rate,
         spectral_norm=spectral_norm,
     )
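     # Validation loss: mean negative log-likelihood of the outcome and
     # treatment heads, tracked with ignite's Average metric.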
     self.metrics = {
         "loss":
         metrics.Average(
             output_transform=lambda x: -x["outputs"][0].log_prob(x[
                 "targets"]).mean() - x["outputs"][1].log_prob(x[
                     "treatments"]).mean(),
             device=self.device,
         )
     }
     self.batch_size = batch_size
     self.best_loss = 1e7
     self.patience = patience
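     # Adam with weight decay scaled by the dropout rate and dataset size.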
     self.optimizer = optim.Adam(
         params=self.network.parameters(),
         lr=self.learning_rate,
         weight_decay=(0.5 * (1 - dropout_rate)) / num_examples,
     )
     self.network.to(self.device)
Example #3
        lr = Pipeline([("scaler", StandardScaler()),
                       ("lr", LogisticRegression(max_iter=10000))])
        lr.fit(z_tr, y_tr)
        acc = lr.score(z_ts, y_ts)
        return {
            "y": y,
            "loss": l,
            "y_pred": y_probs,
            "y_probs": y_probs,
            "lr_acc": acc
        }

    eval_engine = Engine(batch_eval)

    metrics.Accuracy().attach(eval_engine, "accuracy")
    metrics.Average().attach(train_engine, "average_loss")
    metrics.Average(output_transform=lambda x: x["lr_acc"]).attach(
        eval_engine, "lr_acc")
    metrics.Average(output_transform=lambda x: x["loss"]).attach(
        eval_engine, "average_loss")

    @eval_engine.on(Events.EPOCH_COMPLETED)
    def log_tboard(engine):
        tb.add_scalar(
            "train/loss",
            train_engine.state.metrics["average_loss"],
            train_engine.state.epoch,
        )
        tb.add_scalar(
            "eval/loss",
            eval_engine.state.metrics["average_loss"],
            train_engine.state.epoch,
        )
Example #4
def create_sla_trainer(model,
                       transform,
                       optimizer=None,
                       with_large_loss=False,
                       name='train',
                       device=None):
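    # Build an Ignite Engine where `transform` expands each image into n views;
    # labels are mapped to n joint classes per original class, and single- and
    # aggregated-view predictions are recovered by strided slicing.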

    if device is not None:
        model.to(device)

    is_train = optimizer is not None

    def _update(engine, batch):
        model.train(is_train)

        with torch.set_grad_enabled(is_train):
            images, labels = convert_tensor(batch, device=device)
            batch_size = images.shape[0]
            images = transform(model, images, labels)
            n = images.shape[0] // batch_size

            preds = model(images)
            labels = torch.stack([labels * n + i for i in range(n)],
                                 1).view(-1)
            loss = F.cross_entropy(preds, labels)
            if with_large_loss:
                loss = loss * n

            single_preds = preds[::n, ::n]
            single_labels = labels[::n] // n

            agg_preds = 0
            for i in range(n):
                agg_preds = agg_preds + preds[i::n, i::n] / n

        if is_train:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        return {
            'loss': loss.item(),
            'preds': preds,
            'labels': labels,
            'single_preds': single_preds,
            'single_labels': single_labels,
            'agg_preds': agg_preds,
        }

    engine = Engine(_update)
    engine.name = name

    metrics.Average(lambda o: o['loss']).attach(engine, 'total_loss')
    metrics.Accuracy(lambda o: (o['preds'], o['labels'])).attach(
        engine, 'total_acc')

    metrics.Average(lambda o: F.cross_entropy(o['single_preds'], o[
        'single_labels'])).attach(engine, 'single_loss')
    metrics.Accuracy(lambda o: (o['single_preds'], o['single_labels'])).attach(
        engine, 'single_acc')

    metrics.Average(
        lambda o: F.cross_entropy(o['agg_preds'], o['single_labels'])).attach(
            engine, 'agg_loss')
    metrics.Accuracy(lambda o: (o['agg_preds'], o['single_labels'])).attach(
        engine, 'agg_acc')

    return engine
Example #5
def create_sla_sd_trainer(model,
                          transform,
                          optimizer=None,
                          T=1.0,
                          with_large_loss=False,
                          name='train',
                          device=None):
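    # Same joint-training setup as create_sla_trainer, plus a self-distillation
    # term: KL divergence between temperature-scaled single-head predictions and
    # the detached aggregated joint predictions, weighted by T**2.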

    if device is not None:
        model.to(device)

    is_train = optimizer is not None

    def _update(engine, batch):
        model.train(is_train)

        with torch.set_grad_enabled(is_train):
            images, single_labels = convert_tensor(batch, device=device)
            batch_size = images.shape[0]
            images = transform(model, images, single_labels)
            n = images.shape[0] // batch_size

            joint_preds, single_preds = model(images, None)
            single_preds = single_preds[::n]
            joint_labels = torch.stack(
                [single_labels * n + i for i in range(n)], 1).view(-1)

            joint_loss = F.cross_entropy(joint_preds, joint_labels)
            single_loss = F.cross_entropy(single_preds, single_labels)
            if with_large_loss:
                joint_loss = joint_loss * n

            agg_preds = 0
            for i in range(n):
                agg_preds = agg_preds + joint_preds[i::n, i::n] / n

            distillation_loss = F.kl_div(F.log_softmax(single_preds / T, 1),
                                         F.softmax(agg_preds.detach() / T, 1),
                                         reduction='batchmean')

            loss = joint_loss + single_loss + distillation_loss.mul(T**2)

        if is_train:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        return {
            'loss': loss.item(),
            'preds': joint_preds,
            'labels': joint_labels,
            'single_preds': single_preds,
            'single_labels': single_labels,
            'agg_preds': agg_preds,
        }

    engine = Engine(_update)
    engine.name = name

    metrics.Average(lambda o: o['loss']).attach(engine, 'total_loss')
    metrics.Accuracy(lambda o: (o['preds'], o['labels'])).attach(
        engine, 'total_acc')

    metrics.Average(lambda o: F.cross_entropy(o['single_preds'], o[
        'single_labels'])).attach(engine, 'single_loss')
    metrics.Accuracy(lambda o: (o['single_preds'], o['single_labels'])).attach(
        engine, 'single_acc')

    metrics.Average(
        lambda o: F.cross_entropy(o['agg_preds'], o['single_labels'])).attach(
            engine, 'agg_loss')
    metrics.Accuracy(lambda o: (o['agg_preds'], o['single_labels'])).attach(
        engine, 'agg_acc')

    return engine
Example #6
def main(opt_alg, opt_args, clip_args, max_epochs, batch_size, latent_dim,
         _seed, _run, eval_every, kl_beta, oversize_view):
    # pyro.enable_validation(True)
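    # Train a transforming-encoder VAE with Pyro SVI: build data loaders,
    # encoder/decoder, and optimizer, then drive training/evaluation with
    # Ignite engines and log losses and reconstructions to TensorBoard.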

    ds_train, ds_test = get_datasets()
    train_dl = torch.utils.data.DataLoader(ds_train,
                                           batch_size=batch_size,
                                           num_workers=4,
                                           shuffle=True)
    test_dl = torch.utils.data.DataLoader(ds_test,
                                          batch_size=batch_size,
                                          num_workers=4)

    transforms = T.TransformSequence(*transform_sequence())

    trs = transformers.TransformerSequence(*transformer_sequence())

    encoder = TransformingEncoder(trs, latent_dim=latent_dim)
    encoder = encoder.cuda()
    decoder = VaeResViewDecoder(latent_dim=latent_dim,
                                oversize_output=oversize_view)
    decoder.cuda()

    svi_args = {
        "encoder": encoder,
        "decoder": decoder,
        "instantiate_label": True,
        "transforms": transforms,
        "cond": True,
        "output_size": 128,
        "device": torch.device("cuda"),
        "kl_beta": kl_beta,
    }

    opt_alg = get_opt_alg(opt_alg)
    opt = opt_alg(opt_args, clip_args=clip_args)
    elbo = infer.Trace_ELBO(max_plate_nesting=1)

    svi = infer.SVI(forward_model, backward_model, opt, loss=elbo)
    if _run.unobserved or _run._id is None:
        tb = U.DummyWriter("/tmp/delme")
    else:
        tb = SummaryWriter(
            U.setup_run_directory(
                Path(TENSORBOARD_OBSERVER_PATH) / repr(_run._id)))
        _run.info["tensorboard"] = tb.log_dir

    def batch_train(engine, batch):
        x = batch[0]

        x = x.cuda()

        l = svi.step(x, **svi_args)
        return l

    train_engine = Engine(batch_train)

    @torch.no_grad()
    def batch_eval(engine, batch):
        x = batch[0]

        x = x.cuda()

        l = svi.evaluate_loss(x, **svi_args)

        # get predictive distribution over y.
        return {"loss": l}

    eval_engine = Engine(batch_eval)

    @eval_engine.on(Events.EPOCH_STARTED)
    def switch_eval_mode(*args):
        print("MODELS IN EVAL MODE")
        encoder.eval()
        decoder.eval()

    @train_engine.on(Events.EPOCH_STARTED)
    def switch_train_mode(*args):
        print("MODELS IN TRAIN MODE")
        encoder.train()
        decoder.train()

    metrics.Average().attach(train_engine, "average_loss")
    metrics.Average(output_transform=lambda x: x["loss"]).attach(
        eval_engine, "average_loss")

    @eval_engine.on(Events.EPOCH_COMPLETED)
    def log_tboard(engine):
        ex.log_scalar(
            "train.loss",
            train_engine.state.metrics["average_loss"],
            train_engine.state.epoch,
        )
        ex.log_scalar(
            "eval.loss",
            eval_engine.state.metrics["average_loss"],
            train_engine.state.epoch,
        )
        tb.add_scalar(
            "train/loss",
            train_engine.state.metrics["average_loss"],
            train_engine.state.epoch,
        )
        tb.add_scalar(
            "eval/loss",
            eval_engine.state.metrics["average_loss"],
            train_engine.state.epoch,
        )

        print(
            "EPOCH",
            train_engine.state.epoch,
            "train.loss",
            train_engine.state.metrics["average_loss"],
            "eval.loss",
            eval_engine.state.metrics["average_loss"],
        )

    def plot_recons(dataloader, mode):
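        # Log a grid of originals, posterior reconstructions, canonical views,
        # prior samples, and the attention input for the first 64 images of
        # the dataloader.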
        epoch = train_engine.state.epoch
        for batch in dataloader:
            x = batch[0]
            x = x.cuda()
            break
        x = x[:64]
        tb.add_image(f"{mode}/originals", torchvision.utils.make_grid(x),
                     epoch)
        bwd_trace = poutine.trace(backward_model).get_trace(x, **svi_args)
        fwd_trace = poutine.trace(
            poutine.replay(forward_model,
                           trace=bwd_trace)).get_trace(x, **svi_args)
        recon = fwd_trace.nodes["pixels"]["fn"].mean
        tb.add_image(f"{mode}/recons", torchvision.utils.make_grid(recon),
                     epoch)

        canonical_recon = fwd_trace.nodes["canonical_view"]["value"]
        tb.add_image(
            f"{mode}/canonical_recon",
            torchvision.utils.make_grid(canonical_recon),
            epoch,
        )

        # sample from the prior

        prior_sample_args = {}
        prior_sample_args.update(svi_args)
        prior_sample_args["cond"] = False
        prior_sample_args["cond_label"] = False
        fwd_trace = poutine.trace(forward_model).get_trace(
            x, **prior_sample_args)
        prior_sample = fwd_trace.nodes["pixels"]["fn"].mean
        prior_canonical_sample = fwd_trace.nodes["canonical_view"]["value"]
        tb.add_image(f"{mode}/prior_samples",
                     torchvision.utils.make_grid(prior_sample), epoch)

        tb.add_image(
            f"{mode}/canonical_prior_samples",
            torchvision.utils.make_grid(prior_canonical_sample),
            epoch,
        )
        tb.add_image(
            f"{mode}/input_view",
            torchvision.utils.make_grid(
                bwd_trace.nodes["attention_input"]["value"]),
            epoch,
        )

    @eval_engine.on(Events.EPOCH_COMPLETED)
    def plot_images(engine):
        plot_recons(train_dl, "train")
        plot_recons(test_dl, "eval")

    @train_engine.on(Events.EPOCH_COMPLETED(every=eval_every))
    def eval(engine):
        eval_engine.run(test_dl, seed=_seed + engine.state.epoch)

    train_engine.run(train_dl, max_epochs=max_epochs, seed=_seed)
Example #7
 def __init__(
     self,
     job_dir,
     architecture,
     dim_input,
     dim_hidden,
     dim_output,
     depth,
     negative_slope,
     batch_norm,
     spectral_norm,
     dropout_rate,
     num_examples,
     learning_rate,
     batch_size,
     epochs,
     patience,
     num_workers,
     seed,
 ):
     super(NeuralNetwork, self).__init__(
         job_dir=job_dir,
         num_examples=num_examples,
         learning_rate=learning_rate,
         batch_size=batch_size,
         epochs=epochs,
         seed=seed,
         num_workers=num_workers,
     )
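     # Choose the encoder by input type: a ResNet when dim_input is a shape
     # (list), otherwise a dense feed-forward network; a Categorical head
     # produces the output distribution.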
     encoder = (
         convolution.ResNet(
             dim_input=dim_input,
             layers=[2] * depth,
             base_width=dim_hidden // 8,
             negative_slope=negative_slope,
             dropout_rate=dropout_rate,
             batch_norm=batch_norm,
             spectral_norm=spectral_norm,
             stem_kernel_size=5,
             stem_kernel_stride=1,
             stem_kernel_padding=2,
             stem_pool=False,
             activate_output=True,
         )
         if isinstance(dim_input, list)
         else dense.NeuralNetwork(
             architecture=architecture,
             dim_input=dim_input,
             dim_hidden=dim_hidden,
             depth=depth,
             negative_slope=negative_slope,
             batch_norm=batch_norm,
             dropout_rate=dropout_rate,
             spectral_norm=spectral_norm,
             activate_output=True,
         )
     )
     self.network = nn.Sequential(
         encoder,
         variational.Categorical(
             dim_input=encoder.dim_output,
             dim_output=dim_output,
         ),
     )
     self.metrics = {
         "loss": metrics.Average(
             output_transform=lambda x: -x["outputs"].log_prob(x["targets"]).mean(),
             device=self.device,
         )
     }
     self.batch_size = batch_size
     self.best_loss = 1e7
     self.patience = patience
     self.optimizer = optim.Adam(
         params=self.network.parameters(),
         lr=self.learning_rate,
         weight_decay=(0.5 * (1 - dropout_rate)) / num_examples,
     )
     self.network.to(self.device)