def main(
        opt_alg,
        opt_args,
        clip_args,
        max_epochs,
        batch_size,
        latent_dim,
        _seed,
        _run,
        eval_every,
):
    """Train the rotation-equivariant transforming VAE with a manual SVI loop.

    Builds the encoder (learned rotation transformer + ``TransformingEncoder``)
    and the ``VaeResViewDecoder``, wires them into a Pyro SVI objective over
    ``forward_model``/``backward_model``, and runs a fixed 10000-step loop on a
    single re-augmented batch, logging losses and reconstructions to
    TensorBoard every 200 steps.

    NOTE(review): ``_seed``/``_run`` follow sacred's injected-argument naming
    and ``ex.log_scalar`` is called below — presumably this is registered on a
    sacred experiment ``ex``; confirm against the rest of the file.
    ``max_epochs`` and ``eval_every`` are accepted but unused in this loop.
    """
    # pyro.enable_validation(True)
    ds_train, ds_test = get_datasets()
    train_dl = torch.utils.data.DataLoader(ds_train,
                                           batch_size=batch_size,
                                           num_workers=4,
                                           shuffle=True)
    test_dl = torch.utils.data.DataLoader(ds_test,
                                          batch_size=batch_size,
                                          num_workers=4)

    # Pose transforms for the generative model and the matching learned
    # transformer stack for the encoder (rotation only in this variant).
    transforms = T.TransformSequence(T.Rotation())
    trs = transformers.TransformerSequence(
        transformers.Rotation(networks.EquivariantPosePredictor, 1, 32))

    encoder = TransformingEncoder(trs, latent_dim=latent_dim)
    encoder = encoder.cuda()
    decoder = VaeResViewDecoder(latent_dim=latent_dim)
    decoder.cuda()

    # Keyword arguments threaded through every model/guide invocation.
    svi_args = {
        "encoder": encoder,
        "decoder": decoder,
        "instantiate_label": True,
        "transforms": transforms,
        "cond": True,
        "output_size": 128,
        "device": torch.device("cuda"),
    }

    opt_alg = get_opt_alg(opt_alg)
    opt = opt_alg(opt_args, clip_args=clip_args)
    elbo = infer.Trace_ELBO(max_plate_nesting=1)
    svi = infer.SVI(forward_model, backward_model, opt, loss=elbo)

    # Unobserved/anonymous runs get a throwaway writer; real runs log under
    # the run id so TensorBoard output lines up with the experiment tracker.
    if _run.unobserved or _run._id is None:
        tb = U.DummyWriter("/tmp/delme")
    else:
        tb = SummaryWriter(
            U.setup_run_directory(
                Path(TENSORBOARD_OBSERVER_PATH) / repr(_run._id)))
        _run.info["tensorboard"] = tb.log_dir

    # Grab one fixed batch; the loop below repeatedly re-augments it.
    for batch in train_dl:
        x = batch[0]
        x_orig = x.cuda()
        break

    for i in range(10000):
        encoder.train()
        decoder.train()
        # Fresh random rotation of the same base batch every step.
        x = augmentation.RandomRotation(180.0)(x_orig)
        l = svi.step(x, **svi_args)
        if i % 200 == 0:
            # Periodic logging/visualisation pass in eval mode.
            encoder.eval()
            decoder.eval()
            print("EPOCH", i, "LOSS", l)
            ex.log_scalar("train.loss", l, i)
            tb.add_scalar("train/loss", l, i)
            tb.add_image(f"train/originals", torchvision.utils.make_grid(x), i)
            # Reconstruction: trace the guide, then replay the model under it.
            bwd_trace = poutine.trace(backward_model).get_trace(x, **svi_args)
            fwd_trace = poutine.trace(
                poutine.replay(forward_model,
                               trace=bwd_trace)).get_trace(x, **svi_args)
            recon = fwd_trace.nodes["pixels"]["fn"].mean
            tb.add_image(f"train/recons", torchvision.utils.make_grid(recon),
                         i)
            canonical_recon = fwd_trace.nodes["canonical_view"]["value"]
            tb.add_image(
                f"train/canonical_recon",
                torchvision.utils.make_grid(canonical_recon),
                i,
            )
            # sample from the prior
            prior_sample_args = {}
            prior_sample_args.update(svi_args)
            prior_sample_args["cond"] = False
            prior_sample_args["cond_label"] = False
            fwd_trace = poutine.trace(forward_model).get_trace(
                x, **prior_sample_args)
            prior_sample = fwd_trace.nodes["pixels"]["fn"].mean
            prior_canonical_sample = fwd_trace.nodes["canonical_view"]["value"]
            tb.add_image(f"train/prior_samples",
                         torchvision.utils.make_grid(prior_sample), i)
            tb.add_image(
                f"train/canonical_prior_samples",
                torchvision.utils.make_grid(prior_canonical_sample),
                i,
            )
            tb.add_image(
                f"train/input_view",
                torchvision.utils.make_grid(
                    bwd_trace.nodes["attention_input"]["value"]),
                i,
            )
    # NOTE(review): these two statements are the tail of a function whose
    # ``def`` line lies outside this chunk — ``data`` and ``encoder`` are
    # bound there.  They run the transformer stack on a batch and hand the
    # predicted pose parameters to ``delta_sample_transformer_params``.
    transform_output = encoder(data)
    delta_sample_transformer_params(
        encoder.transformers, transform_output["params"]
    )


if __name__ == "__main__":
    # Single-image sanity experiment: fit the transforming template to one
    # MNIST digit (the first "2"), replicated into a batch below.
    mnist = MNIST("./data", download=True)
    x_train = mnist.data[mnist.targets == 2][0].float().cuda() / 255
    x_train = x_train[None, None, ...]  # (28, 28) -> (1, 1, 28, 28)

    transforms = T.TransformSequence(T.Rotation())
    encoder = transformers.TransformerSequence(
        transformers.Rotation(networks.EquivariantPosePredictor, 1, 32)
    )
    encoder = encoder.cuda()

    # NOTE(review): ``optim`` must be Pyro's optimizer wrapper here — plain
    # ``torch.optim.Adam`` accepts no ``clip_args`` — confirm the import.
    opt = optim.Adam({}, clip_args={"clip_norm": 10.0})
    elbo = infer.Trace_ELBO()
    svi = infer.SVI(
        transforming_template_mnist, transforming_template_encoder, opt,
        loss=elbo
    )

    # Clear stale TensorBoard event files from previous runs.
    files = glob.glob("runs/*")
    for f in files:
        os.remove(f)
    tb = SummaryWriter("runs/")

    # Replicate the single digit into a training batch.
    # NOTE(review): 516 looks like a typo for 512 — confirm intent.
    x_train = x_train.expand(516, 1, 28, 28)
def __init__(self,
             tfs=None,
             coords=coordinates.identity_grid,
             net=None,
             equivariant=True,
             downsample=1,
             tf_opts=None,
             net_opts=None,
             seed=None,
             load_path=None,
             loglevel='INFO'):
    """Model base class.

    Builds a transformer sequence + classifier network pipeline.

    Args:
        tfs: list of transformer classes (or their names in ``transformers``).
        coords: coordinate grid function in ``coordinates``, or its name.
        net: network constructor in ``networks``, or its name (required
            unless ``load_path`` is given).
        equivariant: use the equivariant pose predictor when True.
        downsample: downsampling factor passed to ``_build_model``.
        tf_opts: keyword options forwarded to each transformer constructor.
        net_opts: keyword options forwarded to the network constructor.
        seed: optional RNG seed for torch/cuda/numpy.
        load_path: optional ``torch.save`` checkpoint; when given, its saved
            configuration overrides every argument above.
        loglevel: logging level name.

    Raises:
        ValueError: on an invalid log level, missing ``net``, or an unknown
            coordinate-system name.
    """
    # BUG FIX: the defaults were mutable objects (`tfs=[]`, `tf_opts={}`,
    # `net_opts={}`) shared across all calls; use None sentinels instead.
    tfs = [] if tfs is None else tfs
    tf_opts = {} if tf_opts is None else tf_opts
    net_opts = {} if net_opts is None else net_opts

    # configure logging
    numeric_level = getattr(logging, loglevel.upper(), None)
    if not isinstance(numeric_level, int):
        raise ValueError('Invalid log level: %s' % loglevel)
    logging.basicConfig(level=numeric_level)
    logging.info(str(self))

    # A checkpoint carries its own configuration; it overrides the arguments.
    if load_path is not None:
        logging.info(
            'Loading model from file: %s -- using saved model configuration'
            % load_path)
        spec = torch.load(load_path)
        tfs = spec['tfs']
        coords = spec['coords']
        net = spec['net']
        equivariant = spec['equivariant']
        downsample = spec['downsample']
        tf_opts = spec['tf_opts']
        net_opts = spec['net_opts']
        seed = spec['seed']

    if net is None:
        raise ValueError('net parameter must be specified')

    if seed is not None:
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        np.random.seed(seed)

    # build transformer sequence (names are resolved against `transformers`)
    if len(tfs) > 0:
        pose_module = networks.EquivariantPosePredictor if equivariant else networks.DirectPosePredictor
        tfs = [
            getattr(transformers, tf) if type(tf) is str else tf for tf in tfs
        ]
        seq = transformers.TransformerSequence(
            *[tf(pose_module, **tf_opts) for tf in tfs])
        #seq = transformers.TransformerParallel(*[tf(pose_module, **tf_opts) for tf in tfs])
        logging.info('Transformers: %s' %
                     ' -> '.join([tf.__name__ for tf in tfs]))
        logging.info('Pose module: %s' % pose_module.__name__)
    else:
        seq = None

    # get coordinate function if given as a string (with or without the
    # conventional '_grid' suffix)
    if type(coords) is str:
        if hasattr(coordinates, coords):
            coords = getattr(coordinates, coords)
        elif hasattr(coordinates, coords + '_grid'):
            coords = getattr(coordinates, coords + '_grid')
        else:
            raise ValueError('Invalid coordinate system: ' + coords)
    logging.info('Coordinate transformation before classification: %s' %
                 coords.__name__)

    # define network (resolved against `networks` when given as a name)
    if type(net) is str:
        net = getattr(networks, net)
    network = net(**net_opts)
    logging.info('Classifier architecture: %s' % net.__name__)

    self.tfs = tfs
    self.coords = coords
    self.downsample = downsample
    self.net = net
    self.equivariant = equivariant
    self.tf_opts = tf_opts
    self.net_opts = net_opts
    self.seed = seed

    self.model = self._build_model(net=network,
                                   transformer=seq,
                                   coords=coords,
                                   downsample=downsample)
    logging.info('Net opts: %s' % str(net_opts))
    logging.info('Transformer opts: %s' % str(tf_opts))

    # Weights are restored only after the model graph has been rebuilt from
    # the saved configuration.
    if load_path is not None:
        self.model.load_state_dict(spec['state_dict'])
# target_transform=target_transform ) mnist_test = MNIST( "./data", download=True, train=False, transform=augmentation, # target_transform=target_transform ) train_dl = torch.utils.data.DataLoader(mnist, batch_size=128, shuffle=True) test_dl = torch.utils.data.DataLoader(mnist_test, batch_size=1000) transforms = T.TransformSequence(T.Translation(), T.RotationScale()) transformers = transformers.TransformerSequence( transformers.Translation(networks.EquivariantPosePredictor, 1, 32), transformers.RotationScale(networks.EquivariantPosePredictor, 1, 32), ) latent_dim = 64 encoder = TransformingEncoder(transformers, latent_dim=latent_dim) encoder = encoder.cuda() decoder = ViewDecoder(grid_size=32, latent_dim=latent_dim) decoder.cuda() latent_decoder = nn.Sequential( nn.Linear(decoder.latent_dim, 128), nn.ReLU(), nn.Linear(128, 128), nn.ReLU(), nn.Linear(128, 10), ) latent_decoder = latent_decoder.cuda()
def __init__(self,
             tfs=[
                 transformers.ShearX, transformers.HyperbolicRotation,
                 transformers.PerspectiveX, transformers.PerspectiveY
             ],
             coords=coordinates.logpolar_grid,
             net=networks.make_basic_cnn,
             equivariant=True,
             downsample=1,
             tf_opts=tf_default_opts,
             net_opts=net_default_opts,
             seed=None,
             load_path=None,
             loglevel='INFO'):
    """MNIST model.

    TensorFlow counterpart of the torch model base: builds a transformer
    sequence + classifier network, optionally restoring weights from a
    ``tf.train`` checkpoint directory.

    Args:
        tfs: list of transformer classes (or their names in ``transformers``).
        coords: coordinate grid function in ``coordinates``, or its name.
        net: network factory in ``networks``, or its name.
        equivariant: use the equivariant pose predictor when True.
        downsample: downsampling factor passed to ``_build_model``.
        tf_opts: transformer options, merged over ``self.tf_default_opts``.
        net_opts: network options, merged over ``self.net_default_opts``.
        seed: optional RNG seed for tensorflow/numpy.
        load_path: optional checkpoint directory to restore the model from.
        loglevel: logging level name.

    Raises:
        ValueError: on an invalid log level, ``net=None``, or an unknown
            coordinate-system name.
    """
    # Merge caller-supplied options over the class defaults.
    # BUG FIX: the merged dicts were previously built and then never used —
    # all downstream code kept reading the raw ``tf_opts``/``net_opts``
    # arguments, silently discarding any default not explicitly overridden.
    # Rebind the names so every later use sees the merged options.
    tf_opts_copy = dict(self.tf_default_opts)
    tf_opts_copy.update(tf_opts)
    tf_opts = tf_opts_copy
    net_opts_copy = dict(self.net_default_opts)
    net_opts_copy.update(net_opts)
    net_opts = net_opts_copy

    # configure logging
    numeric_level = getattr(logging, loglevel.upper(), None)
    if not isinstance(numeric_level, int):
        raise ValueError('Invalid log level: %s' % loglevel)
    logging.basicConfig(level=numeric_level)
    logging.info(str(self))

    if net is None:
        raise ValueError('net parameter must be specified')

    if seed is not None:
        tf.random.set_seed(seed)
        np.random.seed(seed)

    # build transformer sequence (loop variable is `tfr` because `tf` is
    # the tensorflow module here)
    if len(tfs) > 0:
        pose_module = networks.make_equivariant_pose_predictor if equivariant \
            else networks.make_direct_pose_predictor
        tfs = [
            getattr(transformers, tfr) if type(tfr) is str else tfr
            for tfr in tfs
        ]
        seq = transformers.TransformerSequence(
            *[tfr(pose_module, **tf_opts) for tfr in tfs])
        # seq = transformers.TransformerParallel(*[tfr(pose_module, **tf_opts) for tfr in tfs])
        logging.info('Transformers: %s' %
                     ' -> '.join([tfr.__name__ for tfr in tfs]))
        logging.info('Pose module: %s' % pose_module.__name__)
    else:
        seq = None

    # get coordinate function if given as a string (with or without the
    # conventional '_grid' suffix)
    if type(coords) is str:
        if hasattr(coordinates, coords):
            coords = getattr(coordinates, coords)
        elif hasattr(coordinates, coords + '_grid'):
            coords = getattr(coordinates, coords + '_grid')
        else:
            raise ValueError('Invalid coordinate system: ' + coords)
    logging.info('Coordinate transformation before classification: %s' %
                 coords.__name__)

    # define network (resolved against `networks` when given as a name)
    if type(net) is str:
        net = getattr(networks, net)
    network = net(**net_opts)
    logging.info('Classifier architecture: %s' % net.__name__)

    self.tfs = tfs
    self.coords = coords
    self.downsample = downsample
    self.net = net
    self.equivariant = equivariant
    # Store the merged (effective) options, not the raw arguments.
    self.tf_opts = tf_opts
    self.net_opts = net_opts
    self.seed = seed

    self.model = self._build_model(net=network,
                                   transformer=seq,
                                   coords=coords,
                                   downsample=downsample)

    if load_path is not None:
        ckpt = tf.train.Checkpoint(model=self.model)
        ckpt_manager = tf.train.CheckpointManager(ckpt, load_path, 1)
        ckpt_manager.restore_or_initialize()
        logging.info('Model loaded at {}'.format(load_path))

    logging.info('Net opts: %s' % str(net_opts))
    logging.info('Transformer opts: %s' % str(tf_opts))
def main(opt_alg, opt_args, clip_args, max_epochs, batch_size, latent_dim,
         _seed, _run, eval_every, kl_beta, oversize_view):
    """Train the transforming VAE using ignite engines.

    Epoch-based variant of the training script: builds the encoder/decoder,
    wraps Pyro SVI steps in ignite ``Engine``s for train and eval, and logs
    losses plus reconstruction/prior-sample image grids to TensorBoard after
    every evaluation epoch.

    NOTE(review): ``_seed``/``_run`` follow sacred's injected-argument naming
    and ``ex.log_scalar`` is used below — presumably registered on a sacred
    experiment ``ex``; confirm against the rest of the file.
    """
    # pyro.enable_validation(True)
    ds_train, ds_test = get_datasets()
    train_dl = torch.utils.data.DataLoader(ds_train,
                                           batch_size=batch_size,
                                           num_workers=4,
                                           shuffle=True)
    test_dl = torch.utils.data.DataLoader(ds_test,
                                          batch_size=batch_size,
                                          num_workers=4)

    # Pose transforms and learned transformer stack come from project-level
    # factory functions in this variant.
    transforms = T.TransformSequence(*transform_sequence())
    trs = transformers.TransformerSequence(*transformer_sequence())

    encoder = TransformingEncoder(trs, latent_dim=latent_dim)
    encoder = encoder.cuda()
    decoder = VaeResViewDecoder(latent_dim=latent_dim,
                                oversize_output=oversize_view)
    decoder.cuda()

    # Keyword arguments threaded through every model/guide invocation.
    svi_args = {
        "encoder": encoder,
        "decoder": decoder,
        "instantiate_label": True,
        "transforms": transforms,
        "cond": True,
        "output_size": 128,
        "device": torch.device("cuda"),
        "kl_beta": kl_beta,
    }

    opt_alg = get_opt_alg(opt_alg)
    opt = opt_alg(opt_args, clip_args=clip_args)
    elbo = infer.Trace_ELBO(max_plate_nesting=1)
    svi = infer.SVI(forward_model, backward_model, opt, loss=elbo)

    # Unobserved/anonymous runs get a throwaway writer; real runs log under
    # the run id so TensorBoard output lines up with the experiment tracker.
    if _run.unobserved or _run._id is None:
        tb = U.DummyWriter("/tmp/delme")
    else:
        tb = SummaryWriter(
            U.setup_run_directory(
                Path(TENSORBOARD_OBSERVER_PATH) / repr(_run._id)))
        _run.info["tensorboard"] = tb.log_dir

    def batch_train(engine, batch):
        # One SVI gradient step; ignite records the returned loss as output.
        x = batch[0]
        x = x.cuda()
        l = svi.step(x, **svi_args)
        return l

    train_engine = Engine(batch_train)

    @torch.no_grad()
    def batch_eval(engine, batch):
        # Loss evaluation only — no parameter updates.
        x = batch[0]
        x = x.cuda()
        l = svi.evaluate_loss(x, **svi_args)
        # get predictive distribution over y.
        return {"loss": l}

    eval_engine = Engine(batch_eval)

    @eval_engine.on(Events.EPOCH_STARTED)
    def switch_eval_mode(*args):
        print("MODELS IN EVAL MODE")
        encoder.eval()
        decoder.eval()

    @train_engine.on(Events.EPOCH_STARTED)
    def switch_train_mode(*args):
        print("MODELS IN TRAIN MODE")
        encoder.train()
        decoder.train()

    # Epoch-average losses for both engines under the same metric name.
    metrics.Average().attach(train_engine, "average_loss")
    metrics.Average(output_transform=lambda x: x["loss"]).attach(
        eval_engine, "average_loss")

    @eval_engine.on(Events.EPOCH_COMPLETED)
    def log_tboard(engine):
        # Log train and eval averages against the *training* epoch counter
        # so both series share an x-axis.
        ex.log_scalar(
            "train.loss",
            train_engine.state.metrics["average_loss"],
            train_engine.state.epoch,
        )
        ex.log_scalar(
            "eval.loss",
            eval_engine.state.metrics["average_loss"],
            train_engine.state.epoch,
        )
        tb.add_scalar(
            "train/loss",
            train_engine.state.metrics["average_loss"],
            train_engine.state.epoch,
        )
        tb.add_scalar(
            "eval/loss",
            eval_engine.state.metrics["average_loss"],
            train_engine.state.epoch,
        )
        print(
            "EPOCH",
            train_engine.state.epoch,
            "train.loss",
            train_engine.state.metrics["average_loss"],
            "eval.loss",
            eval_engine.state.metrics["average_loss"],
        )

    def plot_recons(dataloader, mode):
        # Visualise reconstructions and prior samples for the first batch of
        # `dataloader` under the TensorBoard tag prefix `mode`.
        epoch = train_engine.state.epoch
        for batch in dataloader:
            x = batch[0]
            x = x.cuda()
            break
        x = x[:64]  # cap the grid at 64 images
        tb.add_image(f"{mode}/originals", torchvision.utils.make_grid(x),
                     epoch)
        # Reconstruction: trace the guide, then replay the model under it.
        bwd_trace = poutine.trace(backward_model).get_trace(x, **svi_args)
        fwd_trace = poutine.trace(
            poutine.replay(forward_model,
                           trace=bwd_trace)).get_trace(x, **svi_args)
        recon = fwd_trace.nodes["pixels"]["fn"].mean
        tb.add_image(f"{mode}/recons", torchvision.utils.make_grid(recon),
                     epoch)
        canonical_recon = fwd_trace.nodes["canonical_view"]["value"]
        tb.add_image(
            f"{mode}/canonical_recon",
            torchvision.utils.make_grid(canonical_recon),
            epoch,
        )
        # sample from the prior
        prior_sample_args = {}
        prior_sample_args.update(svi_args)
        prior_sample_args["cond"] = False
        prior_sample_args["cond_label"] = False
        fwd_trace = poutine.trace(forward_model).get_trace(
            x, **prior_sample_args)
        prior_sample = fwd_trace.nodes["pixels"]["fn"].mean
        prior_canonical_sample = fwd_trace.nodes["canonical_view"]["value"]
        tb.add_image(f"{mode}/prior_samples",
                     torchvision.utils.make_grid(prior_sample), epoch)
        tb.add_image(
            f"{mode}/canonical_prior_samples",
            torchvision.utils.make_grid(prior_canonical_sample),
            epoch,
        )
        tb.add_image(
            f"{mode}/input_view",
            torchvision.utils.make_grid(
                bwd_trace.nodes["attention_input"]["value"]),
            epoch,
        )

    @eval_engine.on(Events.EPOCH_COMPLETED)
    def plot_images(engine):
        plot_recons(train_dl, "train")
        plot_recons(test_dl, "eval")

    # Run an eval epoch every `eval_every` training epochs.
    # NOTE(review): the handler name shadows the builtin ``eval`` inside
    # ``main``'s scope; harmless here, but worth renaming.
    @train_engine.on(Events.EPOCH_COMPLETED(every=eval_every))
    def eval(engine):
        eval_engine.run(test_dl, seed=_seed + engine.state.epoch)

    train_engine.run(train_dl, max_epochs=max_epochs, seed=_seed)