def main(
    opt_alg,
    opt_args,
    clip_args,
    max_epochs,
    batch_size,
    latent_dim,
    _seed,
    _run,
    eval_every,
):
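    """Train the transforming autoencoder with Pyro SVI on randomly rotated
    inputs, logging losses and reconstructions to Sacred and TensorBoard."""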
    # pyro.enable_validation(True)

    ds_train, ds_test = get_datasets()
    train_dl = torch.utils.data.DataLoader(ds_train,
                                           batch_size=batch_size,
                                           num_workers=4,
                                           shuffle=True)
    test_dl = torch.utils.data.DataLoader(ds_test,
                                          batch_size=batch_size,
                                          num_workers=4)

    transforms = T.TransformSequence(T.Rotation())

    trs = transformers.TransformerSequence(
        transformers.Rotation(networks.EquivariantPosePredictor, 1, 32))

    encoder = TransformingEncoder(trs, latent_dim=latent_dim)
    encoder = encoder.cuda()
    decoder = VaeResViewDecoder(latent_dim=latent_dim)
    decoder.cuda()

    # Keyword arguments forwarded to both the generative model and the guide.
    svi_args = {
        "encoder": encoder,
        "decoder": decoder,
        "instantiate_label": True,
        "transforms": transforms,
        "cond": True,
        "output_size": 128,
        "device": torch.device("cuda"),
    }

    opt_alg = get_opt_alg(opt_alg)
    opt = opt_alg(opt_args, clip_args=clip_args)
    elbo = infer.Trace_ELBO(max_plate_nesting=1)

    svi = infer.SVI(forward_model, backward_model, opt, loss=elbo)

    if _run.unobserved or _run._id is None:
        tb = U.DummyWriter("/tmp/delme")
    else:
        tb = SummaryWriter(
            U.setup_run_directory(
                Path(TENSORBOARD_OBSERVER_PATH) / repr(_run._id)))
        _run.info["tensorboard"] = tb.log_dir

    # Take a single fixed batch; the loop below repeatedly fits random
    # rotations of it rather than iterating over the full dataset.
    for batch in train_dl:
        x = batch[0]
        x_orig = x.cuda()
        break

    for i in range(10000):
        encoder.train()
        decoder.train()
        x = augmentation.RandomRotation(180.0)(x_orig)
        l = svi.step(x, **svi_args)

        if i % 200 == 0:
            encoder.eval()
            decoder.eval()

            print("EPOCH", i, "LOSS", l)
            ex.log_scalar("train.loss", l, i)
            tb.add_scalar("train/loss", l, i)
            tb.add_image(f"train/originals", torchvision.utils.make_grid(x), i)
            bwd_trace = poutine.trace(backward_model).get_trace(x, **svi_args)
            fwd_trace = poutine.trace(
                poutine.replay(forward_model,
                               trace=bwd_trace)).get_trace(x, **svi_args)
            recon = fwd_trace.nodes["pixels"]["fn"].mean
            tb.add_image(f"train/recons", torchvision.utils.make_grid(recon),
                         i)

            canonical_recon = fwd_trace.nodes["canonical_view"]["value"]
            tb.add_image(
                f"train/canonical_recon",
                torchvision.utils.make_grid(canonical_recon),
                i,
            )

            # sample from the prior

            prior_sample_args = {}
            prior_sample_args.update(svi_args)
            prior_sample_args["cond"] = False
            prior_sample_args["cond_label"] = False
            fwd_trace = poutine.trace(forward_model).get_trace(
                x, **prior_sample_args)
            prior_sample = fwd_trace.nodes["pixels"]["fn"].mean
            prior_canonical_sample = fwd_trace.nodes["canonical_view"]["value"]
            tb.add_image(f"train/prior_samples",
                         torchvision.utils.make_grid(prior_sample), i)

            tb.add_image(
                f"train/canonical_prior_samples",
                torchvision.utils.make_grid(prior_canonical_sample),
                i,
            )
            tb.add_image(
                f"train/input_view",
                torchvision.utils.make_grid(
                    bwd_trace.nodes["attention_input"]["value"]),
                i,
            )
# Example 2
# Header reconstructed; the name is inferred from the guide passed to
# infer.SVI in __main__ below.
def transforming_template_encoder(data):
    # Guide: predict transform params and register them as delta sample sites.
    transform_output = encoder(data)
    delta_sample_transformer_params(
        encoder.transformers, transform_output["params"]
    )


if __name__ == "__main__":

    mnist = MNIST("./data", download=True)
    x_train = mnist.data[mnist.targets == 2][0].float().cuda() / 255
    x_train = x_train[None, None, ...]

    transforms = T.TransformSequence(T.Rotation())

    encoder = transformers.TransformerSequence(
        transformers.Rotation(networks.EquivariantPosePredictor, 1, 32)
    )
    encoder = encoder.cuda()

    opt = optim.Adam({}, clip_args={"clip_norm": 10.0})
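    # Pyro optimizer wrapper: clip_args applies gradient-norm clipping.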
    elbo = infer.Trace_ELBO()

    svi = infer.SVI(
        transforming_template_mnist, transforming_template_encoder, opt, loss=elbo
    )
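    # Remove stale event files so TensorBoard shows only this run.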
    files = glob.glob("runs/*")
    for f in files:
        os.remove(f)
    tb = SummaryWriter("runs/")

    x_train = x_train.expand(516, 1, 28, 28)
# Example 3
    def __init__(self,
                 tfs=[],
                 coords=coordinates.identity_grid,
                 net=None,
                 equivariant=True,
                 downsample=1,
                 tf_opts={},
                 net_opts={},
                 seed=None,
                 load_path=None,
                 loglevel='INFO'):
        """
        Model base class.
        """
        # configure logging
        numeric_level = getattr(logging, loglevel.upper(), None)
        if not isinstance(numeric_level, int):
            raise ValueError('Invalid log level: %s' % loglevel)
        logging.basicConfig(level=numeric_level)

        logging.info(str(self))

        if load_path is not None:
            logging.info(
                'Loading model from file: %s -- using saved model configuration'
                % load_path)
            spec = torch.load(load_path)
            tfs = spec['tfs']
            coords = spec['coords']
            net = spec['net']
            equivariant = spec['equivariant']
            downsample = spec['downsample']
            tf_opts = spec['tf_opts']
            net_opts = spec['net_opts']
            seed = spec['seed']

        if net is None:
            raise ValueError('net parameter must be specified')

        if seed is not None:
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
            np.random.seed(seed)

        # build transformer sequence
        if len(tfs) > 0:
            pose_module = networks.EquivariantPosePredictor if equivariant else networks.DirectPosePredictor
            tfs = [
                getattr(transformers, tf) if type(tf) is str else tf
                for tf in tfs
            ]
            seq = transformers.TransformerSequence(
                *[tf(pose_module, **tf_opts) for tf in tfs])
            #seq = transformers.TransformerParallel(*[tf(pose_module, **tf_opts) for tf in tfs])
            logging.info('Transformers: %s' %
                         ' -> '.join([tf.__name__ for tf in tfs]))
            logging.info('Pose module: %s' % pose_module.__name__)
        else:
            seq = None

        # get coordinate function if given as a string
        if type(coords) is str:
            if hasattr(coordinates, coords):
                coords = getattr(coordinates, coords)
            elif hasattr(coordinates, coords + '_grid'):
                coords = getattr(coordinates, coords + '_grid')
            else:
                raise ValueError('Invalid coordinate system: ' + coords)
        logging.info('Coordinate transformation before classification: %s' %
                     coords.__name__)

        # define network
        if type(net) is str:
            net = getattr(networks, net)
        network = net(**net_opts)
        logging.info('Classifier architecture: %s' % net.__name__)

        self.tfs = tfs
        self.coords = coords
        self.downsample = downsample
        self.net = net
        self.equivariant = equivariant
        self.tf_opts = tf_opts
        self.net_opts = net_opts
        self.seed = seed
        self.model = self._build_model(net=network,
                                       transformer=seq,
                                       coords=coords,
                                       downsample=downsample)

        logging.info('Net opts: %s' % str(net_opts))
        logging.info('Transformer opts: %s' % str(tf_opts))
        if load_path is not None:
            self.model.load_state_dict(spec['state_dict'])
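# Usage sketch for the fragment above. Assumptions: the enclosing `class`
# statement and `_build_model` are not shown here, and `MyModel` is a
# hypothetical subclass; string arguments are resolved via getattr as in
# the code:
#
#   model = MyModel(tfs=['Rotation'], coords='logpolar', net='BasicCNN', seed=0)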
# Example 4
    # (Call opening truncated in the source; reconstructed to mirror
    # mnist_test below.)
    mnist = MNIST(
        "./data",
        download=True,
        transform=augmentation,
        # target_transform=target_transform
    )
    mnist_test = MNIST(
        "./data",
        download=True,
        train=False,
        transform=augmentation,
        # target_transform=target_transform
    )
    train_dl = torch.utils.data.DataLoader(mnist, batch_size=128, shuffle=True)
    test_dl = torch.utils.data.DataLoader(mnist_test, batch_size=1000)

    transforms = T.TransformSequence(T.Translation(), T.RotationScale())

    # Bind the sequence to `trs` to avoid shadowing the `transformers` module.
    trs = transformers.TransformerSequence(
        transformers.Translation(networks.EquivariantPosePredictor, 1, 32),
        transformers.RotationScale(networks.EquivariantPosePredictor, 1, 32),
    )
    latent_dim = 64
    encoder = TransformingEncoder(trs, latent_dim=latent_dim)
    encoder = encoder.cuda()
    decoder = ViewDecoder(grid_size=32, latent_dim=latent_dim)
    decoder.cuda()
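    # Classifier head: maps the decoder latent to logits over the 10 MNIST classes.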
    latent_decoder = nn.Sequential(
        nn.Linear(decoder.latent_dim, 128),
        nn.ReLU(),
        nn.Linear(128, 128),
        nn.ReLU(),
        nn.Linear(128, 10),
    )

    latent_decoder = latent_decoder.cuda()
# Example 5
    def __init__(self,
                 tfs=[
                     transformers.ShearX, transformers.HyperbolicRotation,
                     transformers.PerspectiveX, transformers.PerspectiveY
                 ],
                 coords=coordinates.logpolar_grid,
                 net=networks.make_basic_cnn,
                 equivariant=True,
                 downsample=1,
                 tf_opts=tf_default_opts,
                 net_opts=net_default_opts,
                 seed=None,
                 load_path=None,
                 loglevel='INFO'):
        """MNIST model"""
        # Merge user-supplied options into the class defaults.
        tf_opts = {**self.tf_default_opts, **tf_opts}
        net_opts = {**self.net_default_opts, **net_opts}

        # configure logging
        numeric_level = getattr(logging, loglevel.upper(), None)
        if not isinstance(numeric_level, int):
            raise ValueError('Invalid log level: %s' % loglevel)
        logging.basicConfig(level=numeric_level)

        logging.info(str(self))

        if net is None:
            raise ValueError('net parameter must be specified')

        if seed is not None:
            tf.random.set_seed(seed)
            np.random.seed(seed)

        # build transformer sequence
        if len(tfs) > 0:
            pose_module = networks.make_equivariant_pose_predictor if equivariant \
                else networks.make_direct_pose_predictor
            tfs = [
                getattr(transformers, tfr) if type(tfr) is str else tfr
                for tfr in tfs
            ]
            seq = transformers.TransformerSequence(
                *[tfr(pose_module, **tf_opts) for tfr in tfs])
            # seq = transformers.TransformerParallel(*[tfr(pose_module, **tf_opts) for tfr in tfs])
            logging.info('Transformers: %s' %
                         ' -> '.join([tfr.__name__ for tfr in tfs]))
            logging.info('Pose module: %s' % pose_module.__name__)
        else:
            seq = None

        # get coordinate function if given as a string
        if type(coords) is str:
            if hasattr(coordinates, coords):
                coords = getattr(coordinates, coords)
            elif hasattr(coordinates, coords + '_grid'):
                coords = getattr(coordinates, coords + '_grid')
            else:
                raise ValueError('Invalid coordinate system: ' + coords)
        logging.info('Coordinate transformation before classification: %s' %
                     coords.__name__)

        # define network
        if type(net) is str:
            net = getattr(networks, net)
        network = net(**net_opts)
        logging.info('Classifier architecture: %s' % net.__name__)

        self.tfs = tfs
        self.coords = coords
        self.downsample = downsample
        self.net = net
        self.equivariant = equivariant
        self.tf_opts = tf_opts
        self.net_opts = net_opts
        self.seed = seed
        self.model = self._build_model(net=network,
                                       transformer=seq,
                                       coords=coords,
                                       downsample=downsample)

        if load_path is not None:
            ckpt = tf.train.Checkpoint(model=self.model)
            ckpt_manager = tf.train.CheckpointManager(ckpt, load_path, 1)
            ckpt_manager.restore_or_initialize()
            logging.info('Model loaded at {}'.format(load_path))

        logging.info('Net opts: %s' % str(net_opts))
        logging.info('Transformer opts: %s' % str(tf_opts))
def main(opt_alg, opt_args, clip_args, max_epochs, batch_size, latent_dim,
         _seed, _run, eval_every, kl_beta, oversize_view):
    # pyro.enable_validation(True)

    ds_train, ds_test = get_datasets()
    train_dl = torch.utils.data.DataLoader(ds_train,
                                           batch_size=batch_size,
                                           num_workers=4,
                                           shuffle=True)
    test_dl = torch.utils.data.DataLoader(ds_test,
                                          batch_size=batch_size,
                                          num_workers=4)

    transforms = T.TransformSequence(*transform_sequence())

    trs = transformers.TransformerSequence(*transformer_sequence())

    encoder = TransformingEncoder(trs, latent_dim=latent_dim)
    encoder = encoder.cuda()
    decoder = VaeResViewDecoder(latent_dim=latent_dim,
                                oversize_output=oversize_view)
    decoder.cuda()

    svi_args = {
        "encoder": encoder,
        "decoder": decoder,
        "instantiate_label": True,
        "transforms": transforms,
        "cond": True,
        "output_size": 128,
        "device": torch.device("cuda"),
        "kl_beta": kl_beta,
    }

    opt_alg = get_opt_alg(opt_alg)
    opt = opt_alg(opt_args, clip_args=clip_args)
    elbo = infer.Trace_ELBO(max_plate_nesting=1)

    svi = infer.SVI(forward_model, backward_model, opt, loss=elbo)
    if _run.unobserved or _run._id is None:
        tb = U.DummyWriter("/tmp/delme")
    else:
        tb = SummaryWriter(
            U.setup_run_directory(
                Path(TENSORBOARD_OBSERVER_PATH) / repr(_run._id)))
        _run.info["tensorboard"] = tb.log_dir

    def batch_train(engine, batch):
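        # One SVI gradient step; the returned loss becomes the engine's output.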
        x = batch[0]

        x = x.cuda()

        l = svi.step(x, **svi_args)
        return l

    train_engine = Engine(batch_train)

    @torch.no_grad()
    def batch_eval(engine, batch):
        x = batch[0]

        x = x.cuda()

        l = svi.evaluate_loss(x, **svi_args)

        return {"loss": l}

    eval_engine = Engine(batch_eval)

    @eval_engine.on(Events.EPOCH_STARTED)
    def switch_eval_mode(*args):
        print("MODELS IN EVAL MODE")
        encoder.eval()
        decoder.eval()

    @train_engine.on(Events.EPOCH_STARTED)
    def switch_train_mode(*args):
        print("MODELS IN TRAIN MODE")
        encoder.train()
        decoder.train()

    # Attach running per-epoch averages of the batch losses to both engines.
    metrics.Average().attach(train_engine, "average_loss")
    metrics.Average(output_transform=lambda x: x["loss"]).attach(
        eval_engine, "average_loss")

    @eval_engine.on(Events.EPOCH_COMPLETED)
    def log_tboard(engine):
        ex.log_scalar(
            "train.loss",
            train_engine.state.metrics["average_loss"],
            train_engine.state.epoch,
        )
        ex.log_scalar(
            "eval.loss",
            eval_engine.state.metrics["average_loss"],
            train_engine.state.epoch,
        )
        tb.add_scalar(
            "train/loss",
            train_engine.state.metrics["average_loss"],
            train_engine.state.epoch,
        )
        tb.add_scalar(
            "eval/loss",
            eval_engine.state.metrics["average_loss"],
            train_engine.state.epoch,
        )

        print(
            "EPOCH",
            train_engine.state.epoch,
            "train.loss",
            train_engine.state.metrics["average_loss"],
            "eval.loss",
            eval_engine.state.metrics["average_loss"],
        )

    def plot_recons(dataloader, mode):
        epoch = train_engine.state.epoch
        for batch in dataloader:
            x = batch[0]
            x = x.cuda()
            break
        x = x[:64]
        tb.add_image(f"{mode}/originals", torchvision.utils.make_grid(x),
                     epoch)
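        # Trace the guide, then replay its latents through the generative model.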
        bwd_trace = poutine.trace(backward_model).get_trace(x, **svi_args)
        fwd_trace = poutine.trace(
            poutine.replay(forward_model,
                           trace=bwd_trace)).get_trace(x, **svi_args)
        recon = fwd_trace.nodes["pixels"]["fn"].mean
        tb.add_image(f"{mode}/recons", torchvision.utils.make_grid(recon),
                     epoch)

        canonical_recon = fwd_trace.nodes["canonical_view"]["value"]
        tb.add_image(
            f"{mode}/canonical_recon",
            torchvision.utils.make_grid(canonical_recon),
            epoch,
        )

        # sample from the prior

        prior_sample_args = {}
        prior_sample_args.update(svi_args)
        prior_sample_args["cond"] = False
        prior_sample_args["cond_label"] = False
        fwd_trace = poutine.trace(forward_model).get_trace(
            x, **prior_sample_args)
        prior_sample = fwd_trace.nodes["pixels"]["fn"].mean
        prior_canonical_sample = fwd_trace.nodes["canonical_view"]["value"]
        tb.add_image(f"{mode}/prior_samples",
                     torchvision.utils.make_grid(prior_sample), epoch)

        tb.add_image(
            f"{mode}/canonical_prior_samples",
            torchvision.utils.make_grid(prior_canonical_sample),
            epoch,
        )
        tb.add_image(
            f"{mode}/input_view",
            torchvision.utils.make_grid(
                bwd_trace.nodes["attention_input"]["value"]),
            epoch,
        )

    @eval_engine.on(Events.EPOCH_COMPLETED)
    def plot_images(engine):
        plot_recons(train_dl, "train")
        plot_recons(test_dl, "eval")

    @train_engine.on(Events.EPOCH_COMPLETED(every=eval_every))
    def run_eval(engine):
        eval_engine.run(test_dl, seed=_seed + engine.state.epoch)

    train_engine.run(train_dl, max_epochs=max_epochs, seed=_seed)