def __init__(self,
             model: torch.nn.Module,
             optimizer: Union[optim.PyroOptim, None] = None,
             loss: Union[infer.ELBO, None] = None,
             enumerate_parallel: bool = False,
             seed: int = 1,
             **kwargs: Union[str, float]) -> None:
    """Initializes the trainer's parameters.

    :param model: instantiated model object exposing ``.model`` and
        ``.guide`` callables (both are handed to :class:`pyro.infer.SVI`)
    :param optimizer: Pyro optimizer; when ``None`` an Adam optimizer is
        created with learning rate ``kwargs.get("lr", 1e-3)``
    :param loss: ELBO loss; when ``None`` defaults to ``Trace_ELBO``, or
        ``TraceEnum_ELBO(max_plate_nesting=1)`` if ``enumerate_parallel``
    :param enumerate_parallel: when True, wraps the guide with
        ``config_enumerate(..., "parallel", expand=True)`` for parallel
        enumeration of discrete latent variables
    :param seed: seed forwarded to ``set_deterministic_mode``
    :param kwargs: optional ``device`` ("cuda"/"cpu") and ``lr`` entries
    """
    pyro.clear_param_store()  # start from a fresh global parameter store
    set_deterministic_mode(seed)
    # prefer CUDA when available unless the caller overrides "device"
    self.device = kwargs.get(
        "device", 'cuda' if torch.cuda.is_available() else 'cpu')
    if optimizer is None:
        lr = kwargs.get("lr", 1e-3)
        optimizer = optim.Adam({"lr": lr})
    if loss is None:
        if enumerate_parallel:
            loss = infer.TraceEnum_ELBO(max_plate_nesting=1,
                                        strict_enumeration_warning=False)
        else:
            loss = infer.Trace_ELBO()
    guide = model.guide
    if enumerate_parallel:
        guide = infer.config_enumerate(guide, "parallel", expand=True)
    self.svi = infer.SVI(model.model, guide, optimizer, loss=loss)
    # per-epoch loss bookkeeping, filled during training/evaluation
    self.loss_history = {"training_loss": [], "test_loss": []}
    self.current_epoch = 0
def optimize(self, optimizer=None, num_steps=1000):
    """
    A convenient method to optimize parameters for GPLVM model using
    :class:`~pyro.infer.svi.SVI`.

    :param ~optim.PyroOptim optimizer: A Pyro optimizer. Defaults to
        ``optim.Adam({})`` (Adam with Pyro's default arguments).
    :param int num_steps: Number of steps to run SVI.
    :returns: a list of losses during the training procedure
    :rtype: list
    """
    # Build the default lazily: ``optimizer=optim.Adam({})`` in the
    # signature is evaluated once at class-definition time, so a single
    # optimizer instance (and its state) would be silently shared by
    # every call on every instance.
    if optimizer is None:
        optimizer = optim.Adam({})
    if not isinstance(optimizer, optim.PyroOptim):
        raise ValueError("Optimizer should be an instance of "
                         "pyro.optim.PyroOptim class.")
    svi = infer.SVI(self.model, self.guide, optimizer,
                    loss=infer.Trace_ELBO())
    # one ELBO gradient step per iteration; collect the loss trajectory
    return [svi.step() for _ in range(num_steps)]
def train_gp(args, dataset, gp_class):
    """Train a variational sparse GP (``GpOdoFog`` or the alternative
    ``gp_class``) with SVI and save the result.

    The original implementation duplicated the whole model-construction
    branch for the two GP classes; the branches differed only in the
    warping network (FNET vs HNET), output dimension (6 vs 9),
    likelihood name ('lik_f' vs 'lik_h') and initial jitter (1e-3 vs
    1e-4). Those four values are now selected once and the construction
    code is shared. Behavior is unchanged.

    :param args: configuration namespace (kernel_dim, num_inducing_point,
        lr, lr_decay, epochs, nclt, ...); ``args.mate`` is set here
    :param dataset: dataset object providing train/test data and metadata
    :param gp_class: GP wrapper class; ``gp_class.name`` selects the branch
    """
    # this is only to have a correct dimension
    u, y = dataset.get_train_data(
        0, gp_class.name) if args.nclt else dataset.get_test_data(
        1, gp_class.name)

    if gp_class.name == 'GpOdoFog':
        net = FNET(args, u.shape[2], args.kernel_dim)

        def warp_fn(x):
            # pyro.module registers the network's params on every call
            return pyro.module("FNET", net)(x)

        out_dim, lik_name, jitter = 6, 'lik_f', 1e-3
    else:
        net = HNET(args, u.shape[2], args.kernel_dim)

        def warp_fn(x):
            return pyro.module("HNET", net)(x)

        out_dim, lik_name, jitter = 9, 'lik_h', 1e-4

    lik = gp.likelihoods.Gaussian(name=lik_name,
                                  variance=0.1 * torch.ones(out_dim, 1))
    # lik = MultiVariateGaussian(name=lik_name, dim=out_dim)
    # if lower_triangular_constraint is implemented
    # Matern52 kernel on features produced by the warping network
    kernel = gp.kernels.Matern52(
        input_dim=args.kernel_dim,
        lengthscale=torch.ones(args.kernel_dim)).warp(iwarping_fn=warp_fn)
    # evenly-spaced inducing points taken from the training inputs
    Xu = u[torch.arange(
        0, u.shape[0],
        step=int(u.shape[0] / args.num_inducing_point)).long()]
    gp_model = gp.models.VariationalSparseGP(
        u, torch.zeros(out_dim, u.shape[0]), kernel, Xu,
        num_data=dataset.num_data,
        likelihood=lik,
        mean_function=None,
        name=gp_class.name,
        whiten=True,
        jitter=jitter)

    gp_instante = gp_class(args, gp_model, dataset)
    args.mate = preprocessing(args, dataset, gp_instante)
    optimizer = optim.ClippedAdam({"lr": args.lr, "lrd": args.lr_decay})
    svi = infer.SVI(gp_instante.model, gp_instante.guide, optimizer,
                    infer.Trace_ELBO())

    print("Start of training " + dataset.name + ", " + gp_class.name)
    start_time = time.time()
    for epoch in range(1, args.epochs + 1):
        train_loop(dataset, gp_instante, svi, epoch)
        if epoch == 10:
            # tighten the jitter once early optimization has stabilized
            # (no-op for the non-GpOdoFog branch, which starts at 1e-4)
            if gp_class.name == 'GpOdoFog':
                gp_instante.gp_f.jitter = 1e-4
            else:
                gp_instante.gp_h.jitter = 1e-4
    save_gp(args, gp_instante, net)
def main(
    opt_alg,
    opt_args,
    clip_args,
    max_epochs,
    batch_size,
    latent_dim,
    _seed,
    _run,
    eval_every,
):
    """Train the transforming encoder/decoder with SVI on a single fixed
    batch (re-augmented each step), logging losses and image panels to
    sacred and TensorBoard.

    NOTE(review): ``max_epochs`` and ``eval_every`` are accepted but not
    used in this body — the loop below runs a fixed 10000 SVI steps;
    confirm whether that is intended.
    """
    # pyro.enable_validation(True)
    ds_train, ds_test = get_datasets()
    train_dl = torch.utils.data.DataLoader(ds_train,
                                           batch_size=batch_size,
                                           num_workers=4,
                                           shuffle=True)
    # NOTE(review): test_dl is built but never consumed in this function
    test_dl = torch.utils.data.DataLoader(ds_test,
                                          batch_size=batch_size,
                                          num_workers=4)
    transforms = T.TransformSequence(T.Rotation())
    trs = transformers.TransformerSequence(
        transformers.Rotation(networks.EquivariantPosePredictor, 1, 32))
    encoder = TransformingEncoder(trs, latent_dim=latent_dim)
    encoder = encoder.cuda()
    decoder = VaeResViewDecoder(latent_dim=latent_dim)
    decoder.cuda()
    # keyword arguments forwarded to every model/guide invocation
    svi_args = {
        "encoder": encoder,
        "decoder": decoder,
        "instantiate_label": True,
        "transforms": transforms,
        "cond": True,
        "output_size": 128,
        "device": torch.device("cuda"),
    }
    opt_alg = get_opt_alg(opt_alg)
    opt = opt_alg(opt_args, clip_args=clip_args)
    elbo = infer.Trace_ELBO(max_plate_nesting=1)
    svi = infer.SVI(forward_model, backward_model, opt, loss=elbo)

    # only create a real TensorBoard run directory for observed sacred runs
    if _run.unobserved or _run._id is None:
        tb = U.DummyWriter("/tmp/delme")
    else:
        tb = SummaryWriter(
            U.setup_run_directory(
                Path(TENSORBOARD_OBSERVER_PATH) / repr(_run._id)))
        _run.info["tensorboard"] = tb.log_dir

    # grab a single fixed batch to train/visualize on
    for batch in train_dl:
        x = batch[0]
        x_orig = x.cuda()
        break

    for i in range(10000):
        encoder.train()
        decoder.train()
        # fresh random rotation of the same batch every step
        x = augmentation.RandomRotation(180.0)(x_orig)
        l = svi.step(x, **svi_args)

        if i % 200 == 0:
            encoder.eval()
            decoder.eval()
            print("EPOCH", i, "LOSS", l)
            ex.log_scalar("train.loss", l, i)
            tb.add_scalar("train/loss", l, i)
            tb.add_image(f"train/originals",
                         torchvision.utils.make_grid(x), i)
            # reconstruct: trace the guide, then replay the model against
            # the guide's latent samples
            bwd_trace = poutine.trace(backward_model).get_trace(
                x, **svi_args)
            fwd_trace = poutine.trace(
                poutine.replay(forward_model,
                               trace=bwd_trace)).get_trace(x, **svi_args)
            recon = fwd_trace.nodes["pixels"]["fn"].mean
            tb.add_image(f"train/recons",
                         torchvision.utils.make_grid(recon), i)
            canonical_recon = fwd_trace.nodes["canonical_view"]["value"]
            tb.add_image(
                f"train/canonical_recon",
                torchvision.utils.make_grid(canonical_recon),
                i,
            )

            # sample from the prior
            prior_sample_args = {}
            prior_sample_args.update(svi_args)
            prior_sample_args["cond"] = False
            prior_sample_args["cond_label"] = False
            fwd_trace = poutine.trace(forward_model).get_trace(
                x, **prior_sample_args)
            prior_sample = fwd_trace.nodes["pixels"]["fn"].mean
            prior_canonical_sample = fwd_trace.nodes["canonical_view"]["value"]
            tb.add_image(f"train/prior_samples",
                         torchvision.utils.make_grid(prior_sample), i)
            tb.add_image(
                f"train/canonical_prior_samples",
                torchvision.utils.make_grid(prior_canonical_sample),
                i,
            )
            tb.add_image(
                f"train/input_view",
                torchvision.utils.make_grid(
                    bwd_trace.nodes["attention_input"]["value"]),
                i,
            )
if __name__ == "__main__":
    # Single-image experiment: take one MNIST "2", replicate it into a
    # batch, and fit a transforming template with SVI.
    mnist = MNIST("./data", download=True)
    # one digit-2 image, scaled to [0, 1], shaped (1, 1, 28, 28)
    x_train = mnist.data[mnist.targets == 2][0].float().cuda() / 255
    x_train = x_train[None, None, ...]
    transforms = T.TransformSequence(T.Rotation())
    encoder = transformers.TransformerSequence(
        transformers.Rotation(networks.EquivariantPosePredictor, 1, 32)
    )
    encoder = encoder.cuda()
    opt = optim.Adam({}, clip_args={"clip_norm": 10.0})
    elbo = infer.Trace_ELBO()
    svi = infer.SVI(
        transforming_template_mnist, transforming_template_encoder, opt,
        loss=elbo
    )
    # clear any previous TensorBoard run files
    files = glob.glob("runs/*")
    for f in files:
        os.remove(f)
    tb = SummaryWriter("runs/")
    # replicate the single image into a batch of 516 and pad 28x28 -> 40x40
    x_train = x_train.expand(516, 1, 28, 28)
    x_train = F.pad(x_train, (6, 6, 6, 6))
    # NOTE(review): the loop body appears truncated here — only the
    # augmentation step is visible in this chunk.
    for i in range(20000):
        x_rot = augmentation.RandomAffine(40.0)(
            x_train
        )  # randomly rotate the image by up to 40 degrees
        # (original comment said "translate"; RandomAffine's first
        # argument is the rotation-degree range — TODO confirm)
# NOTE(review): this chunk continues an enclosing setup section — the
# names latent_decoder / encoder / decoder / transforms are defined
# before this point, outside the visible region.
latent_decoder = latent_decoder.cuda()
# keyword arguments forwarded to every model/guide invocation
svi_args = {
    "encoder": encoder,
    "decoder": decoder,
    "instantiate_label": True,
    "cond_label": True,
    "latent_decoder": latent_decoder,
    "transforms": transforms,
    "cond": True,
    "output_size": 40,
    "device": torch.device("cuda"),
}
opt = optim.Adam({"amsgrad": True}, clip_args={"clip_norm": 10.0})
elbo = infer.Trace_ELBO(max_plate_nesting=1)
svi = infer.SVI(forward_model, backward_model, opt, loss=elbo)
# clear any previous TensorBoard run files
files = glob.glob("runs/*")
for f in files:
    os.remove(f)
tb = SummaryWriter("runs/")


def batch_train(engine, batch):
    """One SVI gradient step on a labelled batch; returns the batch loss."""
    x, y = batch
    x = x.cuda()
    y = y.cuda()
    # N scales the ELBO to the full dataset size per the model's plate
    l = svi.step(x, y, N=x.shape[0], **svi_args)
    return l
Here we make inference using Stochastic Variational Inference.
However here we have to define a guide function. """

# NOTE(review): pyro.contrib.autoguide is the deprecated import path;
# modern Pyro exposes these under pyro.infer.autoguide (as used
# elsewhere in this codebase) — confirm the installed Pyro version.
from pyro.contrib.autoguide import AutoMultivariateNormal

# full-rank multivariate-normal guide over all latent variables
guide = AutoMultivariateNormal(model)
pyro.clear_param_store()  # fresh parameter store before fitting
adam_params = {"lr": 0.01, "betas": (0.90, 0.999)}
optimizer = optim.Adam(adam_params)
svi = infer.SVI(model, guide, optimizer, loss=infer.Trace_ELBO())

losses = []
for i in range(5000):
    loss = svi.step(x, y_obs, truncation_label)
    losses.append(loss)
    # periodically print the guide's current posterior medians
    if i % 1000 == 0:
        print(', '.join(['{} = {}'.format(*kv)
                         for kv in guide.median().items()]))

print('final result:')
for kv in sorted(guide.median().items()):
    print('median {} = {}'.format(*kv))

"""Let's check that the model has converged by plotting losses"""
def main(args):
    """Train a deep-kernel variational sparse GP classifier on MNIST.

    A CNN warps images into a 10-dimensional feature space on which an
    RBF kernel operates; a MultiClass likelihood turns the GP output
    into 10-class probabilities. Trains with SVI and evaluates after
    every epoch.

    :param args: namespace with data_dir, batch_size, cuda, num_inducing,
        lr, epochs
    """
    normalize = transforms.Normalize((0.1307, ), (0.3081, ))
    train_loader = get_data_loader(dataset_name='MNIST',
                                   data_dir=args.data_dir,
                                   batch_size=args.batch_size,
                                   dataset_transforms=[normalize],
                                   is_training_set=True,
                                   shuffle=True)
    test_loader = get_data_loader(dataset_name='MNIST',
                                  data_dir=args.data_dir,
                                  batch_size=args.batch_size,
                                  dataset_transforms=[normalize],
                                  is_training_set=False,
                                  shuffle=True)
    cnn = CNN().cuda() if args.cuda else CNN()

    # optimizer in SVI just works with params which are active inside
    # its model/guide scope; so we need this helper to
    # mark cnn's parameters active for each `svi.step()` call.
    def cnn_fn(x):
        return pyro.module("CNN", cnn)(x)

    # Create deep kernel by warping RBF with CNN.
    # CNN will transform a high dimension image into a low dimension 2D
    # tensors for RBF kernel.
    # This kernel accepts inputs are inputs of CNN and gives outputs are
    # covariance matrix of RBF on outputs of CNN.
    kernel = gp.kernels.RBF(
        input_dim=10, lengthscale=torch.ones(10)).warp(iwarping_fn=cnn_fn)
    # init inducing points (taken randomly from dataset)
    Xu = next(iter(train_loader))[0][:args.num_inducing]
    # use MultiClass likelihood for 10-class classification problem
    likelihood = gp.likelihoods.MultiClass(num_classes=10)
    # Because we use Categorical distribution in MultiClass likelihood,
    # we need GP model returns a list of probabilities of each class.
    # Hence it is required to use latent_shape = 10.
    # Turns on "whiten" flag will help optimization for variational models.
    gpmodel = gp.models.VariationalSparseGP(X=Xu,
                                            y=None,
                                            kernel=kernel,
                                            Xu=Xu,
                                            likelihood=likelihood,
                                            latent_shape=torch.Size([10]),
                                            num_data=60000,
                                            whiten=True)
    if args.cuda:
        gpmodel.cuda()

    # BUG FIX: pyro.optim.Adam takes its hyperparameters as a dict
    # (optim_args) — `optim.Adam(lr=args.lr)` raises a TypeError.
    optimizer = optim.Adam({"lr": args.lr})

    svi = infer.SVI(gpmodel.model, gpmodel.guide, optimizer,
                    infer.Trace_ELBO())

    for epoch in range(1, args.epochs + 1):
        start_time = time.time()
        train(args, train_loader, gpmodel, svi, epoch)
        with torch.no_grad():
            test(args, test_loader, gpmodel)
        print("Amount of time spent for epoch {}: {}s\n".format(
            epoch, int(time.time() - start_time)))
# Evaluation loader for MNIST (the matching train_loader is defined
# outside this chunk).
test_loader = torch.utils.data.DataLoader(
    dset.MNIST('mnist-data/',
               train=False,
               transform=transforms.Compose([
                   transforms.ToTensor(),
               ])),
    batch_size=128,
    shuffle=True)

# Bayesian neural net: 784 inputs, 1024 hidden units, 10 classes,
# fitted with a mean-field (diagonal) normal guide.
model2 = BNN(28 * 28, 1024, 10)
from pyro.infer.autoguide import AutoDiagonalNormal
guide2 = AutoDiagonalNormal(model2)
optima = optim.Adam({"lr": 0.01})
svi = infer.SVI(model2, guide2, optima, loss=infer.Trace_ELBO())

num_iterations = 5
for j in range(num_iterations):
    # accumulate the epoch's total ELBO loss
    # (the redundant pre-loop `loss = 0` from the original was dead code:
    # it was unconditionally reset here each epoch)
    loss = 0
    for data in train_loader:
        # calculate the loss and take a gradient step
        data, label = data[0].view(-1, 28 * 28), data[1]
        loss += svi.step(data, label)
    normalizer_train = len(train_loader.dataset)
    total_epoch_loss_train = loss / normalizer_train
    print("Epoch ", j, " Loss ", total_epoch_loss_train)
def main(opt_alg, opt_args, clip_args, max_epochs, batch_size, latent_dim,
         _seed, _run, eval_every, kl_beta, oversize_view):
    """Full training entry point: fits the transforming autoencoder with
    SVI using ignite engines, evaluating every ``eval_every`` epochs and
    logging losses and image panels to sacred and TensorBoard."""
    # pyro.enable_validation(True)
    ds_train, ds_test = get_datasets()
    train_dl = torch.utils.data.DataLoader(ds_train,
                                           batch_size=batch_size,
                                           num_workers=4,
                                           shuffle=True)
    test_dl = torch.utils.data.DataLoader(ds_test,
                                          batch_size=batch_size,
                                          num_workers=4)
    transforms = T.TransformSequence(*transform_sequence())
    trs = transformers.TransformerSequence(*transformer_sequence())
    encoder = TransformingEncoder(trs, latent_dim=latent_dim)
    encoder = encoder.cuda()
    decoder = VaeResViewDecoder(latent_dim=latent_dim,
                                oversize_output=oversize_view)
    decoder.cuda()
    # keyword arguments forwarded to every model/guide invocation
    svi_args = {
        "encoder": encoder,
        "decoder": decoder,
        "instantiate_label": True,
        "transforms": transforms,
        "cond": True,
        "output_size": 128,
        "device": torch.device("cuda"),
        "kl_beta": kl_beta,
    }
    opt_alg = get_opt_alg(opt_alg)
    opt = opt_alg(opt_args, clip_args=clip_args)
    elbo = infer.Trace_ELBO(max_plate_nesting=1)
    svi = infer.SVI(forward_model, backward_model, opt, loss=elbo)

    # only create a real TensorBoard run directory for observed sacred runs
    if _run.unobserved or _run._id is None:
        tb = U.DummyWriter("/tmp/delme")
    else:
        tb = SummaryWriter(
            U.setup_run_directory(
                Path(TENSORBOARD_OBSERVER_PATH) / repr(_run._id)))
        _run.info["tensorboard"] = tb.log_dir

    def batch_train(engine, batch):
        # one SVI gradient step on a batch of images
        x = batch[0]
        x = x.cuda()
        l = svi.step(x, **svi_args)
        return l

    train_engine = Engine(batch_train)

    @torch.no_grad()
    def batch_eval(engine, batch):
        # ELBO evaluation only — no gradient step
        x = batch[0]
        x = x.cuda()
        l = svi.evaluate_loss(x, **svi_args)
        # get predictive distribution over y.
        return {"loss": l}

    eval_engine = Engine(batch_eval)

    @eval_engine.on(Events.EPOCH_STARTED)
    def switch_eval_mode(*args):
        print("MODELS IN EVAL MODE")
        encoder.eval()
        decoder.eval()

    @train_engine.on(Events.EPOCH_STARTED)
    def switch_train_mode(*args):
        print("MODELS IN TRAIN MODE")
        encoder.train()
        decoder.train()

    # running average of per-batch losses for each engine
    metrics.Average().attach(train_engine, "average_loss")
    metrics.Average(output_transform=lambda x: x["loss"]).attach(
        eval_engine, "average_loss")

    @eval_engine.on(Events.EPOCH_COMPLETED)
    def log_tboard(engine):
        # both train and eval averages are logged against the TRAIN
        # engine's epoch counter so the curves share an x-axis
        ex.log_scalar(
            "train.loss",
            train_engine.state.metrics["average_loss"],
            train_engine.state.epoch,
        )
        ex.log_scalar(
            "eval.loss",
            eval_engine.state.metrics["average_loss"],
            train_engine.state.epoch,
        )
        tb.add_scalar(
            "train/loss",
            train_engine.state.metrics["average_loss"],
            train_engine.state.epoch,
        )
        tb.add_scalar(
            "eval/loss",
            eval_engine.state.metrics["average_loss"],
            train_engine.state.epoch,
        )
        print(
            "EPOCH",
            train_engine.state.epoch,
            "train.loss",
            train_engine.state.metrics["average_loss"],
            "eval.loss",
            eval_engine.state.metrics["average_loss"],
        )

    def plot_recons(dataloader, mode):
        # visualize originals / reconstructions / prior samples for the
        # first 64 images of the loader's first batch
        epoch = train_engine.state.epoch
        for batch in dataloader:
            x = batch[0]
            x = x.cuda()
            break
        x = x[:64]
        tb.add_image(f"{mode}/originals", torchvision.utils.make_grid(x),
                     epoch)
        # reconstruct: trace the guide, then replay the model against
        # the guide's latent samples
        bwd_trace = poutine.trace(backward_model).get_trace(x, **svi_args)
        fwd_trace = poutine.trace(
            poutine.replay(forward_model,
                           trace=bwd_trace)).get_trace(x, **svi_args)
        recon = fwd_trace.nodes["pixels"]["fn"].mean
        tb.add_image(f"{mode}/recons", torchvision.utils.make_grid(recon),
                     epoch)
        canonical_recon = fwd_trace.nodes["canonical_view"]["value"]
        tb.add_image(
            f"{mode}/canonical_recon",
            torchvision.utils.make_grid(canonical_recon),
            epoch,
        )

        # sample from the prior
        prior_sample_args = {}
        prior_sample_args.update(svi_args)
        prior_sample_args["cond"] = False
        prior_sample_args["cond_label"] = False
        fwd_trace = poutine.trace(forward_model).get_trace(
            x, **prior_sample_args)
        prior_sample = fwd_trace.nodes["pixels"]["fn"].mean
        prior_canonical_sample = fwd_trace.nodes["canonical_view"]["value"]
        tb.add_image(f"{mode}/prior_samples",
                     torchvision.utils.make_grid(prior_sample), epoch)
        tb.add_image(
            f"{mode}/canonical_prior_samples",
            torchvision.utils.make_grid(prior_canonical_sample),
            epoch,
        )
        tb.add_image(
            f"{mode}/input_view",
            torchvision.utils.make_grid(
                bwd_trace.nodes["attention_input"]["value"]),
            epoch,
        )

    @eval_engine.on(Events.EPOCH_COMPLETED)
    def plot_images(engine):
        plot_recons(train_dl, "train")
        plot_recons(test_dl, "eval")

    # NOTE(review): this handler name shadows the builtin `eval`; it is
    # only ever referenced through the decorator registration, so it is
    # harmless, but renaming would be clearer.
    @train_engine.on(Events.EPOCH_COMPLETED(every=eval_every))
    def eval(engine):
        # vary the eval seed per epoch so any stochastic augmentation
        # differs between evaluations
        eval_engine.run(test_dl, seed=_seed + engine.state.epoch)

    train_engine.run(train_dl, max_epochs=max_epochs, seed=_seed)