Exemple #1
0
class MocapSingleJointGroupsFaceback(object):
  """Experiment with mocap data running each joint as its own faceback group."""

  # Names of the constructor hyperparameters. `_init_results_dir` reads each
  # of these attributes off the instance with `getattr` and writes them to
  # `params.json`, so every entry must match an attribute set in `__init__`.
  PARAMS = [
    'subject',
    'train_trials',
    'test_trials',
    'dim_z',
    'batch_size',
    'lam',
    'sparsity_matrix_lr',
    'inference_net_output_dim',
    'generative_net_input_dim',
    'initial_baseline_precision',
    'prior_theta_sigma',
    'group_available_prob'
  ]

  def __init__(
      self,
      subject,
      train_trials,
      test_trials,
      dim_z,
      batch_size,
      lam,
      sparsity_matrix_lr,
      inference_net_output_dim,
      generative_net_input_dim,
      initial_baseline_precision,
      prior_theta_sigma,
      group_available_prob,
      base_results_dir=None,
      prefix='mocap_subject55_'
  ):
    """Load the mocap data and build the Faceback VAE and its optimizers.

    Args:
      subject: subject identifier handed to the `mocap_data` loaders.
      train_trials: trials whose frames make up the training set.
      test_trials: trials whose frames make up the test set.
      dim_z: width of the `1 x dim_z` prior tensors; also passed to the
        generative net.
      batch_size: batch size of the training `DataLoader`.
      lam: passed to `FacebackVAE`; also scales the proximal step in `train`.
      sparsity_matrix_lr: SGD learning rate for `precision_layers` and
        `connectivity_matrices`.
      inference_net_output_dim: output width of each per-joint inference net.
      generative_net_input_dim: input width of each per-joint generative net.
      initial_baseline_precision: forwarded to `FacebackInferenceNet`.
      prior_theta_sigma: `sigma` of the `NormalPriorTheta` prior.
      group_available_prob: forwarded to `sample_random_mask` in `train`.
      base_results_dir: when not None, a new uniquely named results directory
        is created underneath it (it must support the `/` operator).
      prefix: text placed in front of the random results folder name. The
        default reproduces the previously hard-coded name.
    """
    self.subject = subject
    self.train_trials = train_trials
    self.test_trials = test_trials
    self.dim_z = dim_z
    self.batch_size = batch_size
    self.lam = lam
    self.sparsity_matrix_lr = sparsity_matrix_lr
    self.inference_net_output_dim = inference_net_output_dim
    self.generative_net_input_dim = generative_net_input_dim
    self.initial_baseline_precision = initial_baseline_precision
    self.prior_theta_sigma = prior_theta_sigma
    self.group_available_prob = group_available_prob
    self.base_results_dir = base_results_dir
    self.prefix = prefix

    # `epoch_counter` is shared across calls to `train`, so epoch numbering
    # continues where the previous call stopped.
    self.epoch_counter = itertools.count()
    self.epoch = None
    self.elbo_per_iter = []
    self.test_loglik_per_iter = []

    # Must run before the nets are built: it sets `self.joint_dims`.
    self.load_data()

    self.prior_z = Normal(
      Variable(torch.zeros(1, dim_z)),
      Variable(torch.ones(1, dim_z))
    )

    # One "almost" net per joint, in `joint_order`.
    self.inference_net = FacebackInferenceNet(
      almost_inference_nets=[self.make_almost_inference_net(self.joint_dims[j]) for j in joint_order],
      net_output_dim=self.inference_net_output_dim,
      prior_z=self.prior_z,
      initial_baseline_precision=self.initial_baseline_precision
    )
    self.generative_net = FacebackGenerativeNet(
      almost_generative_nets=[self.make_almost_generative_net(self.joint_dims[j]) for j in joint_order],
      net_input_dim=self.generative_net_input_dim,
      dim_z=self.dim_z
    )
    self.vae = FacebackVAE(
      inference_net=self.inference_net,
      generative_net=self.generative_net,
      prior_z=self.prior_z,
      prior_theta=NormalPriorTheta(sigma=self.prior_theta_sigma),
      lam=self.lam
    )

    # The sparsity-related matrices are trained with plain SGD at
    # `sparsity_matrix_lr`; everything else uses Adam.
    self.optimizer = MetaOptimizer([
      # Inference parameters
      torch.optim.Adam(
        set(p for net in self.inference_net.almost_inference_nets for p in net.parameters()),
        lr=1e-3
      ),
      torch.optim.Adam([self.inference_net.mu_layers], lr=1e-3),
      torch.optim.SGD([self.inference_net.precision_layers], lr=self.sparsity_matrix_lr),
      torch.optim.Adam([self.inference_net.baseline_precision], lr=1e-3),

      # Generative parameters
      torch.optim.Adam(
        set(p for net in self.generative_net.almost_generative_nets for p in net.parameters()),
        lr=1e-3
      ),
      torch.optim.SGD([self.generative_net.connectivity_matrices], lr=self.sparsity_matrix_lr)
    ])

    if self.base_results_dir is not None:
      # Random 16-character suffix so that every run gets its own folder.
      # https://stackoverflow.com/questions/2257441/random-string-generation-with-upper-case-letters-and-digits-in-python?utm_medium=organic&utm_source=google_rich_qa&utm_campaign=google_rich_qa
      self.results_folder_name = self.prefix + ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(16))
      self.results_dir = self.base_results_dir / self.results_folder_name
      self._init_results_dir()

  def train(self, num_epochs):
    """Run `num_epochs` more epochs of minibatch training.

    Every batch gets a freshly sampled group mask that is used both as the
    `group_mask` and as the `inference_group_mask` of the ELBO. After each
    optimizer step a proximal step of size `sparsity_matrix_lr * lam` is
    applied. Progress is printed per batch; a checkpoint is written whenever
    the epoch number is a multiple of 50 and once more when training ends.
    Test log-likelihood tracking is not performed here.
    """
    info_keys = (
      'elbo', 'loss', 'z_kl', 'reconstruction_log_likelihood',
      'logprob_theta', 'logprob_L1'
    )
    for self.epoch in itertools.islice(self.epoch_counter, num_epochs):
      for batch_idx, (batch, _) in enumerate(self.train_loader):
        # The last batch of an epoch can be smaller than `batch_size`.
        rows = batch.size(0)
        availability = sample_random_mask(rows, num_groups, self.group_available_prob)

        info = self.vae.elbo(
          Xs=[Variable(group) for group in self.split_into_groups(batch)],
          group_mask=Variable(availability),
          inference_group_mask=Variable(availability)
        )
        elbo, loss, z_kl, loglik, log_theta, log_l1 = (info[key] for key in info_keys)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self.vae.proximal_step(self.sparsity_matrix_lr * self.lam)

        elbo_value = elbo.data[0]
        self.elbo_per_iter.append(elbo_value)
        print(f'Epoch {self.epoch}, {batch_idx} / {len(self.train_loader)}')
        print(f'  ELBO: {elbo_value}')
        print(f'    -KL(q(z) || p(z)): {-z_kl.data[0]}')
        print(f'    loglik_term      : {loglik.data[0]}')
        print(f'    log p(theta)     : {log_theta.data[0]}')
        print(f'    L1               : {log_l1.data[0]}')

      # Periodic checkpoint
      if self.epoch % 50 == 0:
        self.checkpoint()

    # Final checkpoint once all requested epochs are done
    self.checkpoint()

  def make_almost_generative_net(self, dim_x):
    """Build the output head that maps generative features to one group.

    Returns a `NormalNet` wrapping two ReLU + linear branches, each taking
    `generative_net_input_dim` inputs and producing `dim_x` outputs.
    """
    in_dim = self.generative_net_input_dim

    first_branch = torch.nn.Sequential(
      torch.nn.ReLU(),
      torch.nn.Linear(in_dim, dim_x)
    )
    # The linear output of this branch is treated as a log variance;
    # exp(0.5 * log_var) converts it to a standard deviation.
    second_branch = torch.nn.Sequential(
      torch.nn.ReLU(),
      torch.nn.Linear(in_dim, dim_x),
      Lambda(lambda log_var: torch.exp(0.5 * log_var))
    )
    return NormalNet(first_branch, second_branch)

  def make_almost_inference_net(self, dim_x):
    hidden_size = 32
    return torch.nn.Sequential(
      torch.nn.Linear(dim_x, hidden_size),
      torch.nn.ReLU(),
      torch.nn.Linear(hidden_size, self.inference_net_output_dim),
      torch.nn.ReLU()
    )

  def viz_elbo(self):
    """Plot the ELBO recorded at each training iteration; return the figure."""
    fig = plt.figure()
    axes = fig.gca()
    axes.plot(self.elbo_per_iter)
    axes.set_xlabel('iteration')
    axes.set_ylabel('ELBO')
    return fig

  def viz_sparsity(self):
    """Draw the sparsity matrix that associates latent components with
    groups as a heat map (groups on the y axis) and return the figure."""
    fig = plt.figure()
    matrix = self.vae.sparsity_matrix().data.numpy()
    plt.imshow(matrix)
    plt.colorbar()
    for set_label, text in ((plt.xlabel, 'latent components'), (plt.ylabel, 'groups')):
      set_label(text)
    return fig

  def viz_reconstruction(self, plot_seed, num_examples):
    """Plot true skeleton poses next to their reconstructions.

    One training batch is drawn with PyTorch's RNG seeded by `plot_seed`.
    The top row of the figure shows the first `num_examples` true poses and
    the bottom row shows the means of the reconstructed distributions. The
    caller's RNG state is restored afterwards, even if plotting fails.

    Args:
      plot_seed: seed passed to `torch.manual_seed` while sampling the batch.
      num_examples: number of columns in the figure; must not exceed the
        number of rows in the sampled batch.

    Returns:
      The matplotlib figure.
    """
    pytorch_rng_state = torch.get_rng_state()
    torch.manual_seed(plot_seed)

    try:
      # Grab a random batch from train_loader. The builtin `next` works with
      # every iterator, whereas a `.next()` method is not part of the
      # Python 3 iterator protocol.
      train_sample, _ = next(iter(self.train_loader))
      # Size the mask from the batch itself: it has fewer than `batch_size`
      # rows when the dataset is smaller than one batch.
      inference_group_mask = Variable(
        torch.ones(train_sample.size(0), num_groups)
      )
      info = self.vae.reconstruct(
        [Variable(x) for x in self.split_into_groups(train_sample)],
        inference_group_mask
      )
      reconstr = info['reconstructed']

      true_angles = self.split_into_groups(self.preprocess_inverse(train_sample))

      # Take the mean of p(x | z)
      reconstr_tensor = torch.cat([reconstr[i].mu.data for i in range(num_groups)], dim=1)
      reconstr_angles = self.split_into_groups(self.preprocess_inverse(reconstr_tensor))

      def frame(ix, angles):
        stuff = {j: list(x[ix].numpy()) for j, x in zip(joint_order, angles)}
        # We can't forget the (unlearned) translation dofs
        stuff['root'] = [0, 0, 0] + stuff['root']
        return stuff

      fig = plt.figure(figsize=(12, 4))
      for i in range(num_examples):
        ax1 = fig.add_subplot(2, num_examples, i + 1, projection='3d')
        ax2 = fig.add_subplot(2, num_examples, i + num_examples + 1, projection='3d')
        mocap_data.plot_skeleton(
          self.skeleton,
          mocap_data.frame_to_xyz(self.skeleton, frame(i, true_angles)),
          axes=ax1
        )
        mocap_data.plot_skeleton(
          self.skeleton,
          mocap_data.frame_to_xyz(self.skeleton, frame(i, reconstr_angles)),
          axes=ax2
        )

      plt.tight_layout()
      plt.suptitle(f'Epoch {self.epoch}')
      return fig
    finally:
      # Never leave the global RNG in the seeded state.
      torch.set_rng_state(pytorch_rng_state)

  def load_data(self):
    """Load skeleton and trial data, normalize it, and build `train_loader`.

    Sets `skeleton`, `joint_dims`, `angular_mins`, `angular_maxs`, `Xtrain`,
    `Xtest` and `train_loader`. The normalization statistics come from the
    training data only.
    """
    def load_trials(trials):
      return [
        mocap_data.load_trial(self.subject, trial, joint_order=joint_order)
        for trial in trials
      ]

    def stack_frames(trials_data):
      # Concatenate the frame lists of all trials, then drop the first three
      # columns, which correspond to root position in 3d space.
      frames = list(itertools.chain.from_iterable(arr for _, _, arr in trials_data))
      return torch.FloatTensor(frames)[:, 3:]

    self.skeleton = mocap_data.load_skeleton(self.subject)
    train_trials_data = load_trials(self.train_trials)
    test_trials_data = load_trials(self.test_trials)

    # Joint dimensions are taken from the first training trial; the root
    # loses its three position components (see `stack_frames`).
    _, self.joint_dims, _ = train_trials_data[0]
    self.joint_dims['root'] -= 3

    Xtrain_raw = stack_frames(train_trials_data)
    Xtest_raw = stack_frames(test_trials_data)

    # Per-channel extrema used by `preprocess` / `preprocess_inverse`.
    self.angular_mins, _ = torch.min(Xtrain_raw, dim=0)
    self.angular_maxs, _ = torch.max(Xtrain_raw, dim=0)

    self.Xtrain = self.preprocess(Xtrain_raw)
    self.Xtest = self.preprocess(Xtest_raw)

    # TensorDataset is given a second tensor of zeros as placeholder targets.
    placeholder_targets = torch.zeros(self.Xtrain.size(0))
    train_set = torch.utils.data.TensorDataset(self.Xtrain, placeholder_targets)
    self.train_loader = torch.utils.data.DataLoader(
      train_set,
      batch_size=self.batch_size,
      shuffle=True
    )

  def test_loglik(self):
    """Return the VAE log likelihood of the test set with every group
    marked available, using one sample from the inferred q(z)."""
    test_groups = [Variable(group) for group in self.split_into_groups(self.Xtest)]
    all_available = Variable(torch.ones(self.Xtest.size(0), num_groups))
    posterior = self.inference_net(test_groups, all_available)
    z_sample = posterior.sample()
    return self.vae.log_likelihood(test_groups, all_available, z_sample)

  def preprocess(self, x):
    """Preprocess the angular data to lie between 0 and 1."""
    # Some of these things aren't used, and we don't want to divide by zero
    return (x - self.angular_mins) / torch.clamp(self.angular_maxs - self.angular_mins, min=0.1)

  def preprocess_inverse(self, y):
    """Inverse of `preprocess`."""
    return y * torch.clamp(self.angular_maxs - self.angular_mins, min=0.1) + self.angular_mins

  def split_into_groups(self, data):
    """Slice the columns of `data` into one chunk per joint.

    Column boundaries are the running sum of the joint dimensions taken in
    `joint_order`; the first `num_groups` chunks are returned as a list.
    """
    widths = [self.joint_dims[j] for j in joint_order]
    edges = np.cumsum([0] + widths)
    chunks = []
    for g in range(num_groups):
      chunks.append(data[:, edges[g]:edges[g + 1]])
    return chunks

  def _init_results_dir(self):
    self.results_dir_params = self.results_dir / 'params.json'
    self.results_dir_elbo = self.results_dir / 'elbo_plot'
    self.results_dir_sparsity_matrix = self.results_dir / 'sparsity_matrix'
    self.results_dir_reconstructions = self.results_dir / 'reconstructions'
    self.results_dir_pickles = self.results_dir / 'pickles'

    # The results_dir should be unique
    self.results_dir.mkdir(exist_ok=False)
    self.results_dir_elbo.mkdir(exist_ok=False)
    self.results_dir_sparsity_matrix.mkdir(exist_ok=False)
    self.results_dir_reconstructions.mkdir(exist_ok=False)
    self.results_dir_pickles.mkdir(exist_ok=False)
    json.dump(
      {p: getattr(self, p) for p in MocapSingleJointGroupsFaceback.PARAMS},
      open(self.results_dir_params, 'w'),
      sort_keys=True,
      indent=2,
      separators=(',', ': ')
    )

  def checkpoint(self):
    """Save (or show) the diagnostic plots for the current epoch.

    With a results directory configured, the reconstruction, ELBO and
    sparsity plots are written as `epoch<N>.pdf` into their sub-directories
    and the whole experiment object is pickled with `dill`. Otherwise the
    three figures are created and shown interactively.
    """
    if self.base_results_dir is not None:
      # Each `viz_*` call creates a new figure, which is therefore the
      # current figure that `plt.savefig` writes.
      fig = self.viz_reconstruction(0, num_examples=6)
      plt.savefig(self.results_dir_reconstructions / f'epoch{self.epoch}.pdf')
      plt.close(fig)

      fig = self.viz_elbo()
      plt.savefig(self.results_dir_elbo / f'epoch{self.epoch}.pdf')
      plt.close(fig)

      fig = self.viz_sparsity()
      plt.savefig(self.results_dir_sparsity_matrix / f'epoch{self.epoch}.pdf')
      plt.close(fig)

      # Close the pickle file deterministically; checkpoints are written
      # repeatedly during training and handles should not pile up.
      with open(self.results_dir_pickles / f'epoch{self.epoch}.p', 'wb') as pickle_file:
        dill.dump(self, pickle_file)
    else:
      self.viz_reconstruction(0, num_examples=6)
      self.viz_elbo()
      self.viz_sparsity()
      plt.show()
class BarsQuadrantsFaceback(object):
    """Runs the sparse PoE faceback framework on the bars data with the split group
  sparsity prior. This time the groups are the 4 quadrants of the image."""

    # Names of the constructor hyperparameters. `_init_results_dir` reads
    # each of these attributes off the instance with `getattr` and writes
    # them to `params.json`.
    PARAMS = [
        'img_size',
        'num_samples',
        'batch_size',
        'dim_z',
        'lam',
        'sparsity_matrix_lr',
        'initial_baseline_precision',
        'inference_net_output_dim',
        'generative_net_input_dim',
        'noise_stddev',
        'group_available_prob',
        'initial_sigma_adjustment',
        'prior_theta_sigma',
    ]

    def __init__(self,
                 img_size,
                 num_samples,
                 batch_size,
                 dim_z,
                 lam,
                 sparsity_matrix_lr,
                 initial_baseline_precision,
                 inference_net_output_dim,
                 generative_net_input_dim,
                 noise_stddev,
                 group_available_prob,
                 initial_sigma_adjustment,
                 prior_theta_sigma,
                 base_results_dir=None,
                 prefix='quad_bars_'):
        """Sample the bars data and build the four-quadrant Faceback VAE.

        The hyperparameters are stored as attributes of the same name. When
        `base_results_dir` is not None, a new directory named `prefix` plus
        16 random characters is created underneath it.
        """
        # Store the configuration.
        self.img_size = img_size
        self.num_samples = num_samples
        self.batch_size = batch_size
        self.dim_z = dim_z
        self.lam = lam
        self.sparsity_matrix_lr = sparsity_matrix_lr
        self.initial_baseline_precision = initial_baseline_precision
        self.inference_net_output_dim = inference_net_output_dim
        self.generative_net_input_dim = generative_net_input_dim
        self.noise_stddev = noise_stddev
        self.group_available_prob = group_available_prob
        self.initial_sigma_adjustment = initial_sigma_adjustment
        self.prior_theta_sigma = prior_theta_sigma
        self.base_results_dir = base_results_dir
        self.prefix = prefix

        # Training and test images; the loader pairs each training image
        # with a zero placeholder target.
        self.train_data = self.sample_data(self.num_samples)
        self.test_data = self.sample_data(1000)
        train_set = torch.utils.data.TensorDataset(
            self.train_data, torch.zeros(self.num_samples))
        self.train_loader = torch.utils.data.DataLoader(
            train_set, batch_size=batch_size, shuffle=True)

        # Learnable scalar that is added inside the exponent of the
        # generative std-dev branch (see `make_almost_generative_net`).
        self.generative_sigma_adjustment = Variable(
            self.initial_sigma_adjustment * torch.ones(1), requires_grad=True)

        # Bookkeeping for training progress.
        self.epoch_counter = itertools.count()
        self.epoch = None
        self.elbo_per_iter = []
        self.test_loglik_per_iter = []

        prior_shape = (1, dim_z)
        self.prior_z = Normal(Variable(torch.zeros(*prior_shape)),
                              Variable(torch.ones(*prior_shape)))

        # Four groups, one per image quadrant, each of half_size**2 pixels.
        half_size = self.img_size // 2
        dim_xs = [half_size * half_size for _ in range(4)]

        encoders = [self.make_almost_inference_net(dim_x) for dim_x in dim_xs]
        self.inference_net = FacebackInferenceNet(
            almost_inference_nets=encoders,
            net_output_dim=self.inference_net_output_dim,
            prior_z=self.prior_z,
            initial_baseline_precision=self.initial_baseline_precision)

        decoders = [self.make_almost_generative_net(dim_x) for dim_x in dim_xs]
        self.generative_net = FacebackGenerativeNet(
            almost_generative_nets=decoders,
            net_input_dim=self.generative_net_input_dim,
            dim_z=self.dim_z)

        self.vae = FacebackVAE(
            inference_net=self.inference_net,
            generative_net=self.generative_net,
            prior_z=self.prior_z,
            prior_theta=NormalPriorTheta(sigma=self.prior_theta_sigma),
            lam=self.lam)

        encoder_params = set(
            p for net in self.inference_net.almost_inference_nets
            for p in net.parameters())
        decoder_params = set(
            p for net in self.generative_net.almost_generative_nets
            for p in net.parameters())
        sparse_lr = self.sparsity_matrix_lr
        self.optimizer = MetaOptimizer([
            # Inference parameters
            torch.optim.Adam(encoder_params, lr=1e-3),
            torch.optim.Adam([self.inference_net.mu_layers], lr=1e-3),
            torch.optim.SGD([self.inference_net.precision_layers],
                            lr=sparse_lr),
            torch.optim.Adam([self.inference_net.baseline_precision], lr=1e-3),

            # Generative parameters
            torch.optim.Adam(decoder_params, lr=1e-3),
            torch.optim.SGD([self.generative_net.connectivity_matrices],
                            lr=sparse_lr),
            torch.optim.Adam([self.generative_sigma_adjustment], lr=1e-3)
        ])

        if self.base_results_dir is not None:
            # Random suffix so that every run gets its own folder.
            # https://stackoverflow.com/questions/2257441/random-string-generation-with-upper-case-letters-and-digits-in-python?utm_medium=organic&utm_source=google_rich_qa&utm_campaign=google_rich_qa
            alphabet = string.ascii_uppercase + string.digits
            suffix = ''.join(random.choice(alphabet) for _ in range(16))
            self.results_folder_name = self.prefix + suffix
            self.results_dir = self.base_results_dir / self.results_folder_name
            self._init_results_dir()

    def sample_data(self, num_samples):
        """Sample noisy bars images.

        Produces a Tensor of shape [num_samples, self.img_size,
        self.img_size]: the output of `sample_many_bars_images` plus
        Gaussian noise scaled by `self.noise_stddev`.
        """
        size = self.img_size
        bars = sample_many_bars_images(
            num_samples, size, 0.75 * torch.ones(size), torch.zeros(size))
        noise = torch.randn(num_samples, size, size)
        return bars + self.noise_stddev * noise

    def make_almost_generative_net(self, dim_x):
        """Build the `NormalNet` output head for one quadrant.

        The first branch is a single linear layer. The second branch applies
        a linear layer followed by exp(0.5 * x + generative_sigma_adjustment),
        where the adjustment is read from `self` each time it is evaluated.
        """
        in_dim = self.generative_net_input_dim
        first_branch = torch.nn.Linear(in_dim, dim_x)

        def adjusted_exp(x):
            return torch.exp(0.5 * x + self.generative_sigma_adjustment)

        second_branch = torch.nn.Sequential(
            torch.nn.Linear(in_dim, dim_x),
            Lambda(adjusted_exp))
        return NormalNet(first_branch, second_branch)

    def make_almost_inference_net(self, dim_x):
        return torch.nn.Sequential(
            torch.nn.Linear(dim_x, self.inference_net_output_dim),
            torch.nn.ReLU())

    def train(self, num_epochs):
        """Run `num_epochs` more epochs of minibatch training.

        For every batch the ELBO is evaluated with an all-ones `group_mask`
        and a randomly sampled `inference_group_mask`, followed by an
        optimizer step and a proximal step of size
        `sparsity_matrix_lr * lam`. The test log likelihood is computed on
        every iteration and recorded in `test_loglik_per_iter`. A checkpoint
        is written whenever the epoch number is a multiple of 10 and once
        more when training ends.

        Args:
            num_epochs: number of additional epochs to run; epoch numbering
                continues across calls.
        """
        for self.epoch in itertools.islice(self.epoch_counter, num_epochs):
            for batch_idx, (data, _) in enumerate(self.train_loader):
                # The final batch may not have the same size as `batch_size`.
                actual_batch_size = data.size(0)

                Xs = [Variable(x) for x in self.data_transform(data)]
                # The masks need one column per group, i.e. per quadrant
                # returned by `data_transform`, not one per image row.
                group_count = len(Xs)

                mask = sample_random_mask(actual_batch_size, group_count,
                                          self.group_available_prob)
                info = self.vae.elbo(
                    Xs=Xs,
                    group_mask=Variable(
                        torch.ones(actual_batch_size, group_count)),
                    inference_group_mask=Variable(mask))
                elbo = info['elbo']
                loss = info['loss']
                z_kl = info['z_kl']
                reconstruction_log_likelihood = info[
                    'reconstruction_log_likelihood']
                logprob_theta = info['logprob_theta']
                logprob_L1 = info['logprob_L1']
                test_ll = self.test_loglik().data[0]

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                self.vae.proximal_step(self.sparsity_matrix_lr * self.lam)

                self.elbo_per_iter.append(elbo.data[0])
                self.test_loglik_per_iter.append(test_ll)
                print(
                    f'Epoch {self.epoch}, {batch_idx} / {len(self.train_loader)}'
                )
                print(f'  ELBO: {elbo.data[0]}')
                print(f'    -KL(q(z) || p(z)): {-z_kl.data[0]}')
                print(
                    f'    loglik_term      : {reconstruction_log_likelihood.data[0]}'
                )
                print(f'    log p(theta)     : {logprob_theta.data[0]}')
                print(f'    L1               : {logprob_L1.data[0]}')
                print(f'  test log lik.      : {test_ll}')
                # Current values of the two learned scalars.
                print(self.inference_net.baseline_precision.data[0])
                print(self.generative_sigma_adjustment.data[0])

            # Checkpoint every 10 epochs
            if self.epoch % 10 == 0:
                self.checkpoint()

        # Checkpoint at the very end as well
        self.checkpoint()

    def test_loglik(self):
        """Return the VAE log likelihood of the test images.

        Every group is marked available, and a single sample from the
        inferred q(z) is used.
        """
        Xs = [Variable(x) for x in self.data_transform(self.test_data)]
        # One mask column per group (quadrant), not one per image row.
        group_mask = Variable(torch.ones(self.test_data.size(0), len(Xs)))
        q_z = self.inference_net(Xs, group_mask)
        return self.vae.log_likelihood(Xs, group_mask, q_z.sample())

    def viz_elbo(self):
        """Plot the ELBO recorded at each training iteration; return the
        figure."""
        fig = plt.figure()
        axes = fig.gca()
        axes.plot(self.elbo_per_iter)
        axes.set_xlabel('iteration')
        axes.set_ylabel('ELBO')
        return fig

    def _init_results_dir(self):
        self.results_dir_params = self.results_dir / 'params.json'
        self.results_dir_elbo = self.results_dir / 'elbo_plot'
        self.results_dir_sparsity_matrix = self.results_dir / 'sparsity_matrix'
        self.results_dir_reconstructions = self.results_dir / 'reconstructions'
        self.results_dir_pickles = self.results_dir / 'pickles'

        # The results_dir should be unique
        self.results_dir.mkdir(exist_ok=False)
        self.results_dir_elbo.mkdir(exist_ok=False)
        self.results_dir_sparsity_matrix.mkdir(exist_ok=False)
        self.results_dir_reconstructions.mkdir(exist_ok=False)
        self.results_dir_pickles.mkdir(exist_ok=False)
        json.dump({p: getattr(self, p)
                   for p in BarsQuadrantsFaceback.PARAMS},
                  open(self.results_dir_params, 'w'),
                  sort_keys=True,
                  indent=2,
                  separators=(',', ': '))

    def checkpoint(self):
        """Save (or show) the diagnostic plots for the current epoch.

        With a results directory configured, the reconstruction, ELBO and
        sparsity plots are written as `epoch<N>.pdf` into their
        sub-directories. Otherwise the figures are created and shown
        interactively. The experiment object itself is not pickled here.
        """
        quadrant_names = [
            'top left', 'top right', 'bottom left', 'bottom right'
        ]

        if self.base_results_dir is None:
            viz_reconstruction(self, 12345)
            self.viz_elbo()
            viz_sparsity(self.vae, group_names=quadrant_names)
            plt.show()
            return

        pdf_name = f'epoch{self.epoch}.pdf'

        fig = viz_reconstruction(self, 12345)
        plt.savefig(self.results_dir_reconstructions / pdf_name)
        plt.close(fig)

        fig = self.viz_elbo()
        plt.savefig(self.results_dir_elbo / pdf_name)
        plt.close(fig)

        # `viz_sparsity` returns a pair; only the figure is needed here.
        fig, _ = viz_sparsity(self.vae, group_names=quadrant_names)
        plt.savefig(self.results_dir_sparsity_matrix / pdf_name)
        plt.close(fig)

    def data_transform(self, x):
        half_size = self.img_size // 2
        return [
            x[:, :half_size, :half_size].contiguous().view(
                -1, half_size * half_size),
            x[:, :half_size,
              half_size:].contiguous().view(-1, half_size * half_size),
            x[:, half_size:, :half_size].contiguous().view(
                -1, half_size * half_size),
            x[:, half_size:,
              half_size:].contiguous().view(-1, half_size * half_size)
        ]

    def data_untransform(self, Xs):
        reshaped = [
            x.view(-1, self.img_size // 2, self.img_size // 2) for x in Xs
        ]
        return torch.cat([
            torch.cat([reshaped[0], reshaped[1]], dim=2),
            torch.cat([reshaped[2], reshaped[3]], dim=2),
        ],
                         dim=1)
Exemple #3
0
class CVLFacebackExperiment(object):

  # Names of the constructor hyperparameters; every entry matches an
  # attribute set in `__init__`. `img_size` is currently disabled there too.
  PARAMS = [
    # 'img_size',
    'dim_z',
    'batch_size',
    'lam',
    'sparsity_matrix_lr',
    'inference_net_output_dim',
    'generative_net_input_dim',
    'initial_baseline_precision',
    'prior_theta_sigma',
    'group_available_prob',
    'inference_net_num_filters',
    'generative_net_num_filters',
    'use_gpu',
  ]

  def __init__(
      self,
      # img_size,
      dim_z,
      batch_size,
      lam,
      sparsity_matrix_lr,
      inference_net_output_dim,
      generative_net_input_dim,
      initial_baseline_precision,
      prior_theta_sigma,
      group_available_prob,
      inference_net_num_filters,
      generative_net_num_filters,
      use_gpu,
      base_results_dir=None,
      prefix='faces_'
  ):
    """Load the face data and build the convolutional Faceback VAE.

    The hyperparameters are stored as attributes of the same name. With
    `use_gpu` set, the prior tensors and the networks are moved to CUDA. When
    `base_results_dir` is not None, a new directory named `prefix` plus 16
    random characters is created underneath it.
    """
    # Store the configuration.
    # self.img_size = img_size
    self.dim_z = dim_z
    self.batch_size = batch_size
    self.lam = lam
    self.sparsity_matrix_lr = sparsity_matrix_lr
    self.inference_net_output_dim = inference_net_output_dim
    self.generative_net_input_dim = generative_net_input_dim
    self.initial_baseline_precision = initial_baseline_precision
    self.prior_theta_sigma = prior_theta_sigma
    self.group_available_prob = group_available_prob
    self.inference_net_num_filters = inference_net_num_filters
    self.generative_net_num_filters = generative_net_num_filters
    self.use_gpu = use_gpu
    self.base_results_dir = base_results_dir
    self.prefix = prefix

    # Bookkeeping for training progress.
    self.epoch_counter = itertools.count()
    self.epoch = None
    self.elbo_per_iter = []

    self.load_data()

    # Prior tensors live on the same device as the networks.
    prior_zeros = torch.zeros(1, dim_z)
    prior_ones = torch.ones(1, dim_z)
    if self.use_gpu:
      prior_zeros = prior_zeros.cuda()
      prior_ones = prior_ones.cuda()
    self.prior_z = Normal(Variable(prior_zeros), Variable(prior_ones))

    # One encoder/decoder pair per group; every view is 64 x 64.
    view_dim = 64 * 64
    encoders = [self.make_almost_inference_net(view_dim) for _ in range(num_groups)]
    self.inference_net = FacebackInferenceNet(
      almost_inference_nets=encoders,
      net_output_dim=self.inference_net_output_dim,
      prior_z=self.prior_z,
      initial_baseline_precision=self.initial_baseline_precision,
      use_gpu=self.use_gpu
    )
    decoders = [self.make_almost_generative_net(view_dim) for _ in range(num_groups)]
    self.generative_net = FacebackGenerativeNet(
      almost_generative_nets=decoders,
      net_input_dim=self.generative_net_input_dim,
      dim_z=self.dim_z,
      use_gpu=self.use_gpu
    )
    self.vae = FacebackVAE(
      inference_net=self.inference_net,
      generative_net=self.generative_net,
      prior_z=self.prior_z,
      prior_theta=NormalPriorTheta(sigma=self.prior_theta_sigma),
      lam=self.lam
    )

    encoder_params = set(
      p for net in self.inference_net.almost_inference_nets for p in net.parameters()
    )
    decoder_params = set(
      p for net in self.generative_net.almost_generative_nets for p in net.parameters()
    )
    # The first extra argument of each sigma net is optimized separately
    # from `net.parameters()`.
    sigma_params = [
      net.sigma_net.extra_args[0] for net in self.generative_net.almost_generative_nets
    ]
    sparse_lr = self.sparsity_matrix_lr
    self.optimizer = MetaOptimizer([
      # Inference parameters
      torch.optim.Adam(encoder_params, lr=1e-3),
      torch.optim.Adam([self.inference_net.mu_layers], lr=1e-3),
      torch.optim.SGD([self.inference_net.precision_layers], lr=sparse_lr),
      torch.optim.Adam([self.inference_net.baseline_precision], lr=1e-3),

      # Generative parameters
      torch.optim.Adam(decoder_params, lr=1e-3),
      torch.optim.Adam(sigma_params, lr=1e-3),
      torch.optim.SGD([self.generative_net.connectivity_matrices], lr=sparse_lr)
    ])

    if self.base_results_dir is not None:
      # Random suffix so that every run gets its own folder.
      # https://stackoverflow.com/questions/2257441/random-string-generation-with-upper-case-letters-and-digits-in-python?utm_medium=organic&utm_source=google_rich_qa&utm_campaign=google_rich_qa
      alphabet = string.ascii_uppercase + string.digits
      suffix = ''.join(random.choice(alphabet) for _ in range(16))
      self.results_folder_name = self.prefix + suffix
      self.results_dir = self.base_results_dir / self.results_folder_name
      self._init_results_dir()

  def train(self, num_epochs):
    """Run `num_epochs` more epochs of minibatch training.

    Each batch from `train_loader` carries the group mask as its first
    element and the views after it. Progress is printed per batch; a
    checkpoint is written whenever the epoch number is a multiple of 50 and
    once more when training ends.
    """
    info_keys = (
      'elbo', 'loss', 'z_kl', 'reconstruction_log_likelihood',
      'logprob_theta', 'logprob_L1'
    )
    for self.epoch in itertools.islice(self.epoch_counter, num_epochs):
      for batch_idx, batch in enumerate(self.train_loader):
        group_mask, views = batch[0], batch[1:]

        # The last batch of an epoch can be smaller than `batch_size`.
        rows = group_mask.size(0)

        # A view that is missing from the data must never be used for
        # inference, hence the product with the data's own mask.
        random_mask = sample_random_mask(rows, num_groups, self.group_available_prob)
        inference_group_mask = random_mask * group_mask

        if self.use_gpu:
          group_mask = group_mask.cuda()
          inference_group_mask = inference_group_mask.cuda()
          views = [view.cuda() for view in views]

        info = self.vae.elbo(
          Xs=[Variable(view) for view in views],
          group_mask=Variable(group_mask),
          inference_group_mask=Variable(inference_group_mask)
        )
        elbo, loss, z_kl, loglik, log_theta, log_l1 = (info[key] for key in info_keys)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self.vae.proximal_step(self.sparsity_matrix_lr * self.lam)

        elbo_value = elbo.data[0]
        self.elbo_per_iter.append(elbo_value)
        print(f'Epoch {self.epoch}, {batch_idx} / {len(self.train_loader)}')
        print(f'  ELBO: {elbo_value}')
        print(f'    -KL(q(z) || p(z)): {-z_kl.data[0]}')
        print(f'    loglik_term      : {loglik.data[0]}')
        print(f'    log p(theta)     : {log_theta.data[0]}')
        print(f'    L1               : {log_l1.data[0]}')

      # Periodic checkpoint
      if self.epoch % 50 == 0:
        self.checkpoint()

    # Final checkpoint once all requested epochs are done
    self.checkpoint()

  def make_almost_generative_net(self, dim_x):
    """Build the DCGAN-style decoder head for one 64 x 64 view.

    The mean branch is a stack of transposed convolutions ending in `Tanh`.
    The std-dev branch ignores the values of its input and returns
    exp(log_sigma) + 1e-3, with `log_sigma` a learnable 1 x 1 x 64 x 64
    Variable expanded to the batch size. `dim_x` is not used.
    """
    # We learn a std dev for each pixel which is not a function of the input.
    # Note that this Variable is NOT going to show up in `net.parameters()` and
    # therefore it is implicitly free from the ridge penalty/p(theta) prior.
    init_log_sigma = torch.log(1e-2 * torch.ones(1, 1, 64, 64))

    # See https://github.com/pytorch/examples/blob/master/dcgan/main.py#L107
    dim_in = self.generative_net_input_dim
    ngf = self.generative_net_num_filters

    # Channel widths of successive stages. Spatial sizes after each stage:
    # 4 x 4, 8 x 8, 16 x 16, 32 x 32.
    widths = [dim_in, ngf * 8, ngf * 4, ngf * 2, ngf]
    layers = [Lambda(lambda x: x.view(-1, dim_in, 1, 1))]
    for stage, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
      # The first stage expands 1 x 1 to 4 x 4; later stages double the size.
      stride, padding = (1, 0) if stage == 0 else (2, 1)
      layers.append(torch.nn.ConvTranspose2d(c_in, c_out, 4, stride, padding, bias=False))
      layers.append(torch.nn.BatchNorm2d(c_out))
      layers.append(torch.nn.ReLU(inplace=True))
    # Final stage: a single channel at 64 x 64.
    layers.append(torch.nn.ConvTranspose2d(ngf, 1, 4, 2, 1, bias=False))
    layers.append(torch.nn.Tanh())
    model = torch.nn.Sequential(*layers)

    if self.use_gpu:
      model = model.cuda()
      init_log_sigma = init_log_sigma.cuda()

    log_sigma = Variable(init_log_sigma, requires_grad=True)
    sigma_net = Lambda(
      lambda x, log_sigma: torch.exp(log_sigma.expand(x.size(0), -1, -1, -1)) + 1e-3,
      extra_args=(log_sigma,)
    )
    return NormalNet(model, sigma_net)

  def make_almost_inference_net(self, dim_x):
    """Build the convolutional encoder for a single 1 x 64 x 64 input.

    The network is the first three strided-convolution stages of the DCGAN
    discriminator (the remaining discriminator layers are deliberately left
    out), followed by a flatten and a linear projection to
    `self.inference_net_output_dim` features with a LeakyReLU.

    Args:
      dim_x: accepted for interface compatibility; not used by this method.

    Returns:
      The `torch.nn.Sequential` model, moved to the GPU when `self.use_gpu`
      is set.
    """
    width = self.inference_net_num_filters

    def leaky_relu():
      return torch.nn.LeakyReLU(0.2, inplace=True)

    # input is 1 x 64 x 64 -> (width) x 32 x 32; no BatchNorm on this stage.
    layers = [
      torch.nn.Conv2d(1, width, 4, 2, 1, bias=False),
      leaky_relu(),
    ]

    # (width) x 32 x 32 -> (width*2) x 16 x 16 -> (width*4) x 8 x 8
    for in_channels, out_channels in ((width, width * 2), (width * 2, width * 4)):
      layers.extend([
        torch.nn.Conv2d(in_channels, out_channels, 4, 2, 1, bias=False),
        torch.nn.BatchNorm2d(out_channels),
        leaky_relu(),
      ])

    layers.extend([
      # Flatten the filter channels
      Lambda(lambda x: x.view(x.size(0), -1)),
      torch.nn.Linear((width * 4) * 8 * 8, self.inference_net_output_dim),
      leaky_relu(),
    ])

    encoder = torch.nn.Sequential(*layers)
    if self.use_gpu:
      return encoder.cuda()
    return encoder

  def load_data(self):
    """Create the train/test CVL datasets and their shuffled data loaders.

    Subjects 1..100 form the training split and subjects 101..114 the test
    split (114 subjects in total). Sets `self.train_dataset`,
    `self.train_loader`, `self.test_dataset` and `self.test_loader`. Both
    loaders use `self.batch_size` and shuffle, the test loader included.
    """
    data_root = Path('data/cvl_faces/')

    def build_split(subjects):
      # NOTE(review): `transform` is not defined in this method; presumably
      # it is a module-level name -- confirm.
      dataset = CVLDataset(data_root, subjects=subjects, transform=transform)
      loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=self.batch_size,
        shuffle=True
      )
      return dataset, loader

    self.train_dataset, self.train_loader = build_split(range(1, 101))
    self.test_dataset, self.test_loader = build_split(range(101, 115))

  def checkpoint(self):
    """Save the diagnostic plots for the current epoch.

    Does nothing when `self.base_results_dir` is None. Otherwise writes one
    `epoch{self.epoch}.pdf` into each results sub-directory: train and test
    reconstructions, reconstructions for the training subjects with a missing
    view, the ELBO plot and the sparsity matrix. Pickling the whole
    experiment is intentionally not done here.
    """
    if self.base_results_dir is None:
      # No results directory configured: nothing is saved or shown.
      return

    def save_and_close(fig, out_dir):
      # `plt.savefig` writes matplotlib's *current* figure; presumably that
      # is `fig`, freshly created by the viz helper -- confirm.
      plt.savefig(out_dir / f'epoch{self.epoch}.pdf')
      plt.close(fig)

    save_and_close(
      viz_reconstruction_all_views(self, self.train_dataset, range(8)),
      self.results_dir_train_reconstructions
    )
    save_and_close(
      viz_reconstruction_all_views(self, self.test_dataset, range(8)),
      self.results_dir_test_reconstructions
    )

    # Subjects 35, 44 and 93 are all missing the last "smiling with teeth"
    # face. Dataset indices are zero-based, hence 34, 43 and 92.
    save_and_close(
      viz_reconstruction_all_views(self, self.train_dataset, [34, 43, 92]),
      self.results_dir_train_missing_reconstructions
    )

    save_and_close(viz_elbo(self), self.results_dir_elbo)

    # `viz_sparsity` returns a pair; only the figure is needed here.
    sparsity_fig, _ = viz_sparsity(self.vae)
    save_and_close(sparsity_fig, self.results_dir_sparsity_matrix)

  def _init_results_dir(self):
    self.results_dir_params = self.results_dir / 'params.json'
    self.results_dir_elbo = self.results_dir / 'elbo_plot'
    self.results_dir_sparsity_matrix = self.results_dir / 'sparsity_matrix'
    self.results_dir_test_reconstructions = self.results_dir / 'test_reconstructions'
    self.results_dir_train_reconstructions = self.results_dir / 'train_reconstructions'
    self.results_dir_train_missing_reconstructions = self.results_dir / 'train_missing_reconstructions'
    self.results_dir_pickles = self.results_dir / 'pickles'

    # The results_dir should be unique
    self.results_dir.mkdir(exist_ok=False)
    self.results_dir_elbo.mkdir(exist_ok=False)
    self.results_dir_sparsity_matrix.mkdir(exist_ok=False)
    self.results_dir_test_reconstructions.mkdir(exist_ok=False)
    self.results_dir_train_reconstructions.mkdir(exist_ok=False)
    self.results_dir_train_missing_reconstructions.mkdir(exist_ok=False)
    self.results_dir_pickles.mkdir(exist_ok=False)
    json.dump(
      {p: getattr(self, p) for p in CVLFacebackExperiment.PARAMS},
      open(self.results_dir_params, 'w'),
      sort_keys=True,
      indent=2,
      separators=(',', ': ')
    )