Beispiel #1
0
 def __init__(self,
              model: Type[torch.nn.Module],
              optimizer: Union[optim.PyroOptim, None] = None,
              loss: Union[infer.ELBO, None] = None,
              enumerate_parallel: bool = False,
              seed: int = 1,
              **kwargs: Union[str, float]) -> None:
     """
     Initializes the trainer's parameters.

     Args:
         model: object exposing ``model`` and ``guide`` callables, which
             are handed to :class:`pyro.infer.SVI`.
             NOTE(review): the ``Type[torch.nn.Module]`` annotation looks
             inaccurate -- the code reads ``model.model``/``model.guide``
             off what appears to be an *instance*; confirm against callers.
         optimizer: a Pyro optimizer instance. Defaults to Adam with the
             learning rate taken from ``kwargs["lr"]`` (1e-3 if absent).
         loss: an ELBO loss instance. Defaults to ``TraceEnum_ELBO`` when
             ``enumerate_parallel`` is set, else ``Trace_ELBO``.
         enumerate_parallel: if True, wrap the guide with
             ``config_enumerate`` for parallel enumeration of discrete
             latent variables.
         seed: RNG seed forwarded to ``set_deterministic_mode``.
         **kwargs: optional overrides -- ``device`` (str) and ``lr`` (float).
     """
     pyro.clear_param_store()
     set_deterministic_mode(seed)
     # Prefer GPU when available; an explicit "device" kwarg wins.
     self.device = kwargs.get(
         "device", 'cuda' if torch.cuda.is_available() else 'cpu')
     if optimizer is None:
         lr = kwargs.get("lr", 1e-3)
         optimizer = optim.Adam({"lr": lr})
     if loss is None:
         if enumerate_parallel:
             loss = infer.TraceEnum_ELBO(max_plate_nesting=1,
                                         strict_enumeration_warning=False)
         else:
             loss = infer.Trace_ELBO()
     guide = model.guide
     if enumerate_parallel:
         guide = infer.config_enumerate(guide, "parallel", expand=True)
     self.svi = infer.SVI(model.model, guide, optimizer, loss=loss)
     # Per-epoch loss curves, filled in during training/evaluation.
     self.loss_history = {"training_loss": [], "test_loss": []}
     self.current_epoch = 0
Beispiel #2
0
    def optimize(self, optimizer=None, num_steps=1000):
        """
        A convenient method to optimize parameters for GPLVM model using
        :class:`~pyro.infer.svi.SVI`.

        :param ~optim.PyroOptim optimizer: A Pyro optimizer. If ``None``,
            a fresh :class:`~pyro.optim.Adam` with default hyperparameters
            is created per call. (The previous default ``optim.Adam({})``
            was evaluated once at import time, so its internal optimizer
            state was silently shared across every call and instance.)
        :param int num_steps: Number of steps to run SVI.
        :returns: a list of losses during the training procedure
        :rtype: list
        """
        if optimizer is None:
            optimizer = optim.Adam({})
        if not isinstance(optimizer, optim.PyroOptim):
            raise ValueError("Optimizer should be an instance of "
                             "pyro.optim.PyroOptim class.")
        svi = infer.SVI(self.model,
                        self.guide,
                        optimizer,
                        loss=infer.Trace_ELBO())
        # Record every per-step ELBO so callers can inspect convergence.
        losses = []
        for _ in range(num_steps):
            losses.append(svi.step())
        return losses
Beispiel #3
0
def train_gp(args, dataset, gp_class):
    """Train a variational sparse GP model on ``dataset`` using SVI.

    Builds an input-warped (deep) Matern-5/2 kernel, a Gaussian
    likelihood and a ``gp.models.VariationalSparseGP``, optimizes with
    ClippedAdam for ``args.epochs`` epochs via ``train_loop``, then saves
    the trained model and warping network.

    :param args: configuration namespace (``lr``, ``lr_decay``,
        ``epochs``, ``kernel_dim``, ``num_inducing_point``, ``nclt``, ...)
    :param dataset: dataset object providing train/test splits and
        ``num_data``
    :param gp_class: GP wrapper class; ``gp_class.name`` selects the
        'GpOdoFog' variant (6-dim output) or the other variant (9-dim)
    """
    # Fetched only so ``u`` has the correct dimensions for model setup.
    u, y = dataset.get_train_data(
        0, gp_class.name) if args.nclt else dataset.get_test_data(
            1, gp_class.name)

    # The two variants differ only in the warping network, likelihood
    # name, output dimension and initial jitter -- parametrize instead of
    # duplicating the whole construction.
    if gp_class.name == 'GpOdoFog':
        net = FNET(args, u.shape[2], args.kernel_dim)
        net_name, lik_name, out_dim, jitter = "FNET", 'lik_f', 6, 1e-3
    else:
        net = HNET(args, u.shape[2], args.kernel_dim)
        net_name, lik_name, out_dim, jitter = "HNET", 'lik_h', 9, 1e-4

    def warp_fn(x):
        # pyro.module re-registers the network's parameters on every call
        # so SVI optimizes them jointly with the GP parameters.
        return pyro.module(net_name, net)(x)

    lik = gp.likelihoods.Gaussian(name=lik_name,
                                  variance=0.1 * torch.ones(out_dim, 1))
    # lik = MultiVariateGaussian(name=lik_name, dim=out_dim) # if lower_triangular_constraint is implemented
    kernel = gp.kernels.Matern52(
        input_dim=args.kernel_dim,
        lengthscale=torch.ones(args.kernel_dim)).warp(iwarping_fn=warp_fn)
    # Inducing points: an evenly strided subsample of the training inputs.
    Xu = u[torch.arange(0,
                        u.shape[0],
                        step=int(u.shape[0] /
                                 args.num_inducing_point)).long()]
    gp_model = gp.models.VariationalSparseGP(u,
                                             torch.zeros(out_dim, u.shape[0]),
                                             kernel,
                                             Xu,
                                             num_data=dataset.num_data,
                                             likelihood=lik,
                                             mean_function=None,
                                             name=gp_class.name,
                                             whiten=True,
                                             jitter=jitter)

    gp_instante = gp_class(args, gp_model, dataset)
    args.mate = preprocessing(args, dataset, gp_instante)

    optimizer = optim.ClippedAdam({"lr": args.lr, "lrd": args.lr_decay})
    svi = infer.SVI(gp_instante.model, gp_instante.guide, optimizer,
                    infer.Trace_ELBO())

    print("Start of training " + dataset.name + ", " + gp_class.name)
    start_time = time.time()
    for epoch in range(1, args.epochs + 1):
        train_loop(dataset, gp_instante, svi, epoch)
        # After a 10-epoch warm-up, tighten the Cholesky jitter.
        if epoch == 10:
            if gp_class.name == 'GpOdoFog':
                gp_instante.gp_f.jitter = 1e-4
            else:
                gp_instante.gp_h.jitter = 1e-4

    # Plain statement instead of the original conditional *expression*
    # used only for its side effect (its value was discarded).
    save_gp(args, gp_instante, net)
def main(
    opt_alg,
    opt_args,
    clip_args,
    max_epochs,
    batch_size,
    latent_dim,
    _seed,
    _run,
    eval_every,
):
    """Train the transforming VAE on one fixed, repeatedly-rotated batch.

    Builds data loaders, an equivariant encoder/decoder pair and an SVI
    loop over ``forward_model``/``backward_model``; every 200 steps it
    logs the loss plus reconstruction, canonical-view and prior-sample
    image grids to sacred and TensorBoard.

    Arguments are injected by sacred (``_seed``/``_run`` are sacred
    specials); ``max_epochs``, ``batch_size`` size the loaders, and
    ``opt_alg``/``opt_args``/``clip_args`` configure the Pyro optimizer.
    """
    # pyro.enable_validation(True)

    ds_train, ds_test = get_datasets()
    train_dl = torch.utils.data.DataLoader(ds_train,
                                           batch_size=batch_size,
                                           num_workers=4,
                                           shuffle=True)
    test_dl = torch.utils.data.DataLoader(ds_test,
                                          batch_size=batch_size,
                                          num_workers=4)

    transforms = T.TransformSequence(T.Rotation())

    trs = transformers.TransformerSequence(
        transformers.Rotation(networks.EquivariantPosePredictor, 1, 32))

    encoder = TransformingEncoder(trs, latent_dim=latent_dim)
    encoder = encoder.cuda()
    decoder = VaeResViewDecoder(latent_dim=latent_dim)
    decoder.cuda()

    svi_args = {
        "encoder": encoder,
        "decoder": decoder,
        "instantiate_label": True,
        "transforms": transforms,
        "cond": True,
        "output_size": 128,
        "device": torch.device("cuda"),
    }

    opt_alg = get_opt_alg(opt_alg)
    opt = opt_alg(opt_args, clip_args=clip_args)
    elbo = infer.Trace_ELBO(max_plate_nesting=1)

    svi = infer.SVI(forward_model, backward_model, opt, loss=elbo)

    # Unobserved (debug) sacred runs log to a throwaway writer.
    if _run.unobserved or _run._id is None:
        tb = U.DummyWriter("/tmp/delme")
    else:
        tb = SummaryWriter(
            U.setup_run_directory(
                Path(TENSORBOARD_OBSERVER_PATH) / repr(_run._id)))
        _run.info["tensorboard"] = tb.log_dir

    # Grab a single fixed batch; training overfits to rotated copies of it.
    for batch in train_dl:
        x = batch[0]
        x_orig = x.cuda()
        break

    for i in range(10000):
        encoder.train()
        decoder.train()
        x = augmentation.RandomRotation(180.0)(x_orig)
        loss = svi.step(x, **svi_args)  # renamed from ambiguous `l` (E741)

        if i % 200 == 0:
            encoder.eval()
            decoder.eval()

            print("EPOCH", i, "LOSS", loss)
            ex.log_scalar("train.loss", loss, i)
            tb.add_scalar("train/loss", loss, i)
            # Tag strings below had spurious f-prefixes (no placeholders).
            tb.add_image("train/originals", torchvision.utils.make_grid(x), i)
            bwd_trace = poutine.trace(backward_model).get_trace(x, **svi_args)
            fwd_trace = poutine.trace(
                poutine.replay(forward_model,
                               trace=bwd_trace)).get_trace(x, **svi_args)
            recon = fwd_trace.nodes["pixels"]["fn"].mean
            tb.add_image("train/recons", torchvision.utils.make_grid(recon),
                         i)

            canonical_recon = fwd_trace.nodes["canonical_view"]["value"]
            tb.add_image(
                "train/canonical_recon",
                torchvision.utils.make_grid(canonical_recon),
                i,
            )

            # Sample from the prior (conditioning switched off).

            prior_sample_args = {}
            prior_sample_args.update(svi_args)
            prior_sample_args["cond"] = False
            prior_sample_args["cond_label"] = False
            fwd_trace = poutine.trace(forward_model).get_trace(
                x, **prior_sample_args)
            prior_sample = fwd_trace.nodes["pixels"]["fn"].mean
            prior_canonical_sample = fwd_trace.nodes["canonical_view"]["value"]
            tb.add_image("train/prior_samples",
                         torchvision.utils.make_grid(prior_sample), i)

            tb.add_image(
                "train/canonical_prior_samples",
                torchvision.utils.make_grid(prior_canonical_sample),
                i,
            )
            tb.add_image(
                "train/input_view",
                torchvision.utils.make_grid(
                    bwd_trace.nodes["attention_input"]["value"]),
                i,
            )
Beispiel #5
0
if __name__ == "__main__":

    # Single training image: the first MNIST digit "2", scaled to [0, 1].
    mnist = MNIST("./data", download=True)
    x_train = mnist.data[mnist.targets == 2][0].float().cuda() / 255
    x_train = x_train[None, None, ...]  # add batch and channel dims

    transforms = T.TransformSequence(T.Rotation())

    encoder = transformers.TransformerSequence(
        transformers.Rotation(networks.EquivariantPosePredictor, 1, 32)
    )
    encoder = encoder.cuda()

    # Adam with default hyperparameters, gradient norm clipped at 10.
    opt = optim.Adam({}, clip_args={"clip_norm": 10.0})
    elbo = infer.Trace_ELBO()

    svi = infer.SVI(
        transforming_template_mnist, transforming_template_encoder, opt, loss=elbo
    )
    # Clear stale event files so TensorBoard shows only this run.
    files = glob.glob("runs/*")
    for f in files:
        os.remove(f)
    tb = SummaryWriter("runs/")

    # Replicate the single image into a batch of 516 and pad 28x28 -> 40x40.
    x_train = x_train.expand(516, 1, 28, 28)
    x_train = F.pad(x_train, (6, 6, 6, 6))
    # NOTE(review): the loop body appears truncated in this excerpt.
    for i in range(20000):
        x_rot = augmentation.RandomAffine(40.0)(
            x_train
        )  # randomly rotate by up to 40 degrees (the arg is in degrees;
           # the original comment said "translate" -- confirm semantics)
Beispiel #6
0
    latent_decoder = latent_decoder.cuda()

    # Keyword arguments threaded through every SVI step: the networks,
    # label-conditioning flags, output size and device placement for the
    # forward/backward model pair.
    svi_args = {
        "encoder": encoder,
        "decoder": decoder,
        "instantiate_label": True,
        "cond_label": True,
        "latent_decoder": latent_decoder,
        "transforms": transforms,
        "cond": True,
        "output_size": 40,
        "device": torch.device("cuda"),
    }

    # Adam with AMSGrad enabled, gradient norm clipped at 10.
    opt = optim.Adam({"amsgrad": True}, clip_args={"clip_norm": 10.0})
    elbo = infer.Trace_ELBO(max_plate_nesting=1)

    svi = infer.SVI(forward_model, backward_model, opt, loss=elbo)
    # Clear stale event files so TensorBoard shows only this run.
    files = glob.glob("runs/*")
    for f in files:
        os.remove(f)
    tb = SummaryWriter("runs/")
    def batch_train(engine, batch):
        """Run one SVI gradient step on a labelled batch; return its loss."""
        inputs, labels = batch

        inputs = inputs.cuda()
        labels = labels.cuda()

        return svi.step(inputs, labels, N=inputs.shape[0], **svi_args)
Beispiel #7
0
Here we perform inference using Stochastic Variational Inference (SVI). However, in this case we must define a guide function ourselves.
"""

from pyro.contrib.autoguide import AutoMultivariateNormal

# Automatically derive a multivariate-normal guide from the model.
guide = AutoMultivariateNormal(model)

pyro.clear_param_store()

adam_params = {"lr": 0.01, "betas": (0.90, 0.999)}
optimizer = optim.Adam(adam_params)

svi = infer.SVI(model, 
                guide, 
                optimizer, 
                loss=infer.Trace_ELBO())

# Run 5000 SVI steps, keeping the loss trace for later plotting.
losses = []
for i in range(5000):
  loss = svi.step(x, y_obs, truncation_label)
  losses.append(loss)

  # Periodically print the guide's current posterior medians.
  if i % 1000 == 0:
    print(', '.join(['{} = {}'.format(*kv) for kv in guide.median().items()]))

print('final result:')
for kv in sorted(guide.median().items()):
  print('median {} = {}'.format(*kv))

"""Let's check that the model has converged by plotting losses"""
def main(args):
    """Train a deep-kernel sparse GP classifier on MNIST with SVI.

    Warps an RBF kernel with a CNN feature extractor and fits a
    ``gp.models.VariationalSparseGP`` with a MultiClass likelihood,
    evaluating on the test set after each epoch.

    :param args: namespace with ``data_dir``, ``batch_size``, ``cuda``,
        ``lr``, ``epochs`` and ``num_inducing``
    """
    normalize = transforms.Normalize((0.1307, ), (0.3081, ))
    train_loader = get_data_loader(dataset_name='MNIST',
                                   data_dir=args.data_dir,
                                   batch_size=args.batch_size,
                                   dataset_transforms=[normalize],
                                   is_training_set=True,
                                   shuffle=True)
    test_loader = get_data_loader(dataset_name='MNIST',
                                  data_dir=args.data_dir,
                                  batch_size=args.batch_size,
                                  dataset_transforms=[normalize],
                                  is_training_set=False,
                                  shuffle=True)

    cnn = CNN().cuda() if args.cuda else CNN()

    # optimizer in SVI just works with params which are active inside
    # its model/guide scope; so we need this helper to
    # mark cnn's parameters active for each `svi.step()` call.
    def cnn_fn(x):
        return pyro.module("CNN", cnn)(x)

    # Create deep kernel by warping RBF with CNN.
    # CNN will transform a high dimension image into a low dimension 2D
    # tensors for RBF kernel.
    # This kernel accepts inputs are inputs of CNN and gives outputs are
    # covariance matrix of RBF on outputs of CNN.
    kernel = gp.kernels.RBF(
        input_dim=10, lengthscale=torch.ones(10)).warp(iwarping_fn=cnn_fn)

    # init inducing points (taken randomly from dataset)
    Xu = next(iter(train_loader))[0][:args.num_inducing]
    # use MultiClass likelihood for 10-class classification problem
    likelihood = gp.likelihoods.MultiClass(num_classes=10)
    # Because we use Categorical distribution in MultiClass likelihood,
    # we need GP model returns a list of probabilities of each class.
    # Hence it is required to use latent_shape = 10.
    # Turns on "whiten" flag will help optimization for variational models.
    gpmodel = gp.models.VariationalSparseGP(X=Xu,
                                            y=None,
                                            kernel=kernel,
                                            Xu=Xu,
                                            likelihood=likelihood,
                                            latent_shape=torch.Size([10]),
                                            num_data=60000,
                                            whiten=True)
    if args.cuda:
        gpmodel.cuda()

    # BUGFIX: pyro.optim.Adam takes a dict of torch-optimizer arguments
    # (or a callable), not keyword hyperparameters -- the previous
    # `optim.Adam(lr=args.lr)` raised a TypeError at runtime.
    optimizer = optim.Adam({"lr": args.lr})
    # optimizer = get_optimizer(args)

    svi = infer.SVI(gpmodel.model, gpmodel.guide, optimizer,
                    infer.Trace_ELBO())

    for epoch in range(1, args.epochs + 1):
        start_time = time.time()
        train(args, train_loader, gpmodel, svi, epoch)
        with torch.no_grad():
            test(args, test_loader, gpmodel)
        print("Amount of time spent for epoch {}: {}s\n".format(
            epoch, int(time.time() - start_time)))
Beispiel #9
0
# MNIST test split; `train_loader` is assumed to be defined earlier in the
# original file -- TODO confirm, it is not visible in this excerpt.
test_loader = torch.utils.data.DataLoader(dset.MNIST(
    'mnist-data/',
    train=False,
    transform=transforms.Compose([
        transforms.ToTensor(),
    ])),
                                          batch_size=128,
                                          shuffle=True)

# Bayesian NN: 784 inputs, 1024 hidden units, 10 classes.
model2 = BNN(28 * 28, 1024, 10)
from pyro.infer.autoguide import AutoDiagonalNormal

# Mean-field (diagonal normal) variational posterior over BNN weights.
guide2 = AutoDiagonalNormal(model2)
optima = optim.Adam({"lr": 0.01})
svi = infer.SVI(model2, guide2, optima, loss=infer.Trace_ELBO())

num_iterations = 5
loss = 0

for j in range(num_iterations):
    loss = 0  # accumulate the epoch's total ELBO loss
    for batch_id, data in enumerate(train_loader):
        # calculate the loss and take a gradient step
        data, label = data[0].view(-1, 28 * 28), data[1]
        loss += svi.step(data, label)
    # Normalize by dataset size to report a per-example loss.
    normalizer_train = len(train_loader.dataset)
    total_epoch_loss_train = loss / normalizer_train

    print("Epoch ", j, " Loss ", total_epoch_loss_train)
Beispiel #10
0
def main(opt_alg, opt_args, clip_args, max_epochs, batch_size, latent_dim,
         _seed, _run, eval_every, kl_beta, oversize_view):
    """Train and evaluate the transforming VAE with ignite engines.

    Builds data loaders, an equivariant encoder and a (possibly
    oversized-view) decoder, wires SVI over ``forward_model`` /
    ``backward_model`` into ignite train/eval engines, and logs losses
    plus reconstruction image grids to sacred and TensorBoard after each
    evaluation epoch. Arguments are injected by sacred (``_seed`` and
    ``_run`` are sacred specials).
    """
    # pyro.enable_validation(True)

    ds_train, ds_test = get_datasets()
    train_dl = torch.utils.data.DataLoader(ds_train,
                                           batch_size=batch_size,
                                           num_workers=4,
                                           shuffle=True)
    test_dl = torch.utils.data.DataLoader(ds_test,
                                          batch_size=batch_size,
                                          num_workers=4)

    transforms = T.TransformSequence(*transform_sequence())

    trs = transformers.TransformerSequence(*transformer_sequence())

    encoder = TransformingEncoder(trs, latent_dim=latent_dim)
    encoder = encoder.cuda()
    decoder = VaeResViewDecoder(latent_dim=latent_dim,
                                oversize_output=oversize_view)
    decoder.cuda()

    # Keyword arguments threaded through every SVI step/evaluation.
    svi_args = {
        "encoder": encoder,
        "decoder": decoder,
        "instantiate_label": True,
        "transforms": transforms,
        "cond": True,
        "output_size": 128,
        "device": torch.device("cuda"),
        "kl_beta": kl_beta,
    }

    opt_alg = get_opt_alg(opt_alg)
    opt = opt_alg(opt_args, clip_args=clip_args)
    elbo = infer.Trace_ELBO(max_plate_nesting=1)

    svi = infer.SVI(forward_model, backward_model, opt, loss=elbo)
    # Unobserved (debug) sacred runs log to a throwaway writer.
    if _run.unobserved or _run._id is None:
        tb = U.DummyWriter("/tmp/delme")
    else:
        tb = SummaryWriter(
            U.setup_run_directory(
                Path(TENSORBOARD_OBSERVER_PATH) / repr(_run._id)))
        _run.info["tensorboard"] = tb.log_dir

    def batch_train(engine, batch):
        """One SVI gradient step on a batch; returns the step loss."""
        x = batch[0]

        x = x.cuda()

        l = svi.step(x, **svi_args)
        return l

    train_engine = Engine(batch_train)

    @torch.no_grad()
    def batch_eval(engine, batch):
        """Evaluate the ELBO on a batch without taking a gradient step."""
        x = batch[0]

        x = x.cuda()

        l = svi.evaluate_loss(x, **svi_args)

        # get predictive distribution over y.
        return {"loss": l}

    eval_engine = Engine(batch_eval)

    @eval_engine.on(Events.EPOCH_STARTED)
    def switch_eval_mode(*args):
        print("MODELS IN EVAL MODE")
        encoder.eval()
        decoder.eval()

    @train_engine.on(Events.EPOCH_STARTED)
    def switch_train_mode(*args):
        print("MODELS IN TRAIN MODE")
        encoder.train()
        decoder.train()

    # Running-average loss metrics for both engines.
    metrics.Average().attach(train_engine, "average_loss")
    metrics.Average(output_transform=lambda x: x["loss"]).attach(
        eval_engine, "average_loss")

    @eval_engine.on(Events.EPOCH_COMPLETED)
    def log_tboard(engine):
        """Log average train/eval losses to sacred and TensorBoard."""
        ex.log_scalar(
            "train.loss",
            train_engine.state.metrics["average_loss"],
            train_engine.state.epoch,
        )
        ex.log_scalar(
            "eval.loss",
            eval_engine.state.metrics["average_loss"],
            train_engine.state.epoch,
        )
        tb.add_scalar(
            "train/loss",
            train_engine.state.metrics["average_loss"],
            train_engine.state.epoch,
        )
        tb.add_scalar(
            "eval/loss",
            eval_engine.state.metrics["average_loss"],
            train_engine.state.epoch,
        )

        print(
            "EPOCH",
            train_engine.state.epoch,
            "train.loss",
            train_engine.state.metrics["average_loss"],
            "eval.loss",
            eval_engine.state.metrics["average_loss"],
        )

    def plot_recons(dataloader, mode):
        """Log originals, reconstructions and prior samples for one batch."""
        epoch = train_engine.state.epoch
        # Take the first batch only (capped at 64 images).
        for batch in dataloader:
            x = batch[0]
            x = x.cuda()
            break
        x = x[:64]
        tb.add_image(f"{mode}/originals", torchvision.utils.make_grid(x),
                     epoch)
        # Posterior reconstruction: replay the forward model against the
        # guide trace.
        bwd_trace = poutine.trace(backward_model).get_trace(x, **svi_args)
        fwd_trace = poutine.trace(
            poutine.replay(forward_model,
                           trace=bwd_trace)).get_trace(x, **svi_args)
        recon = fwd_trace.nodes["pixels"]["fn"].mean
        tb.add_image(f"{mode}/recons", torchvision.utils.make_grid(recon),
                     epoch)

        canonical_recon = fwd_trace.nodes["canonical_view"]["value"]
        tb.add_image(
            f"{mode}/canonical_recon",
            torchvision.utils.make_grid(canonical_recon),
            epoch,
        )

        # sample from the prior (conditioning switched off)

        prior_sample_args = {}
        prior_sample_args.update(svi_args)
        prior_sample_args["cond"] = False
        prior_sample_args["cond_label"] = False
        fwd_trace = poutine.trace(forward_model).get_trace(
            x, **prior_sample_args)
        prior_sample = fwd_trace.nodes["pixels"]["fn"].mean
        prior_canonical_sample = fwd_trace.nodes["canonical_view"]["value"]
        tb.add_image(f"{mode}/prior_samples",
                     torchvision.utils.make_grid(prior_sample), epoch)

        tb.add_image(
            f"{mode}/canonical_prior_samples",
            torchvision.utils.make_grid(prior_canonical_sample),
            epoch,
        )
        tb.add_image(
            f"{mode}/input_view",
            torchvision.utils.make_grid(
                bwd_trace.nodes["attention_input"]["value"]),
            epoch,
        )

    @eval_engine.on(Events.EPOCH_COMPLETED)
    def plot_images(engine):
        plot_recons(train_dl, "train")
        plot_recons(test_dl, "eval")

    # NOTE: `eval` shadows the builtin, but the name is only used as an
    # ignite handler registration here.
    @train_engine.on(Events.EPOCH_COMPLETED(every=eval_every))
    def eval(engine):
        eval_engine.run(test_dl, seed=_seed + engine.state.epoch)

    train_engine.run(train_dl, max_epochs=max_epochs, seed=_seed)