Example #1
    def _build(self, unused_input=None):
        # pylint:disable=unused-variable
        # Reason:
        #   All endpoints are stored as attributes at the end of `_build`.
        #   Pylint cannot infer this, so it emits false unused-variable
        #   warnings unless we disable them here.

        config = self.config

        # Constants
        batch_size = config['batch_size']
        n_latent = config['n_latent']
        img_width = config['img_width']

        # ---------------------------------------------------------------------
        # ## Placeholders
        # ---------------------------------------------------------------------
        # Image data
        if dataset_is_mnist_family(config['dataset']):
            n_labels = 10
            x = tf.placeholder(tf.float32,
                               shape=(None, img_width * img_width),
                               name='x')
            attr_loss_fn = tf.losses.softmax_cross_entropy
            attr_pred_fn = tf.nn.softmax
            attr_weights = tf.constant(np.ones([1]).astype(np.float32))
            # p_x_fn = lambda logits: ds.Bernoulli(logits=logits)
            x_sigma = tf.constant(config['x_sigma'])
            p_x_fn = (
                lambda logs: ds.Normal(loc=tf.nn.sigmoid(logs), scale=x_sigma)
            )  # noqa

        elif config['dataset'] == 'CELEBA':
            n_labels = 10
            x = tf.placeholder(tf.float32,
                               shape=(None, img_width, img_width, 3),
                               name='x')
            attr_loss_fn = tf.losses.sigmoid_cross_entropy
            attr_pred_fn = tf.nn.sigmoid
            attr_weights = tf.constant(
                np.ones([1, n_labels]).astype(np.float32))
            x_sigma = tf.constant(config['x_sigma'])
            p_x_fn = (
                lambda logs: ds.Normal(loc=tf.nn.sigmoid(logs), scale=x_sigma)
            )  # noqa

        # Attributes
        labels = tf.placeholder(tf.int32,
                                shape=(None, n_labels),
                                name='labels')
        # Real / fake label reward
        r = tf.placeholder(tf.float32, shape=(None, 1), name='D_label')
        # Transform through optimization
        z0 = tf.placeholder(tf.float32, shape=(None, n_latent), name='z0')

        # ---------------------------------------------------------------------
        # ## Modules with parameters
        # ---------------------------------------------------------------------
        # Abstract Modules.
        # Variables that hold classes have names that pylint considers invalid,
        # so we disable the warning.
        # pylint:disable=invalid-name
        Encoder = config['Encoder']
        Decoder = config['Decoder']
        Classifier = config['Classifier']
        # pylint:enable=invalid-name

        encoder = Encoder(name='encoder')
        decoder = Decoder(name='decoder')
        classifier = Classifier(output_size=n_labels, name='classifier')

        # ---------------------------------------------------------------------
        # ## Classify Attributes from pixels
        # ---------------------------------------------------------------------
        logits_classifier = classifier(x)
        pred_classifier = attr_pred_fn(logits_classifier)
        classifier_loss = attr_loss_fn(labels, logits=logits_classifier)

        # ---------------------------------------------------------------------
        # ## VAE
        # ---------------------------------------------------------------------
        # Encode
        mu, sigma = encoder(x)
        q_z = ds.Normal(loc=mu, scale=sigma)

        # Optimize / Amortize or feedthrough
        q_z_sample = q_z.sample()
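        # Note (added): sampling from `ds.Normal` is reparameterized, so
        # gradients from the decoder flow back into the encoder through this
        # sample (the reparameterization trick).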

        z = q_z_sample

        # Decode
        logits = decoder(z)
        p_x = p_x_fn(logits)
        x_mean = p_x.mean()

        # Reconstruction Loss
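        # Sum the per-dimension log-likelihoods so `recons` is a per-example
        # log p(x | z).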
        if config['dataset'] == 'CELEBA':
            recons = tf.reduce_sum(p_x.log_prob(x), axis=[1, 2, 3])
        else:
            recons = tf.reduce_sum(p_x.log_prob(x), axis=[-1])

        mean_recons = tf.reduce_mean(recons)

        # Prior
        p_z = ds.Normal(loc=0., scale=1.)
        prior_sample = p_z.sample(sample_shape=[batch_size, n_latent])

        # KL Loss.
        # We use `KL` in variable names for consistency with the math notation.
        # pylint:disable=invalid-name
        if config['beta'] == 0:
            mean_KL = tf.constant(0.0)
        else:
            KL_qp = ds.kl_divergence(q_z, p_z)
            KL = tf.reduce_sum(KL_qp, axis=-1)
            mean_KL = tf.reduce_mean(KL)
        # pylint:enable=invalid-name

        # VAE Loss
        beta = tf.constant(config['beta'])
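        # The objective below is the negative beta-ELBO:
        #   -E_q[log p(x|z)] + beta * KL(q(z|x) || p(z)).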
        vae_loss = -mean_recons + mean_KL * beta

        # ---------------------------------------------------------------------
        # ## Training
        # ---------------------------------------------------------------------
        # Learning rates
        vae_lr = tf.constant(3e-4)
        classifier_lr = tf.constant(3e-4)

        # Training Ops
        vae_vars = list(encoder.get_variables())
        vae_vars.extend(decoder.get_variables())
        train_vae = tf.train.AdamOptimizer(learning_rate=vae_lr).minimize(
            vae_loss, var_list=vae_vars)

        classifier_vars = classifier.get_variables()
        train_classifier = tf.train.AdamOptimizer(
            learning_rate=classifier_lr).minimize(classifier_loss,
                                                  var_list=classifier_vars)

        # Savers
        vae_saver = tf.train.Saver(vae_vars, max_to_keep=100)
        classifier_saver = tf.train.Saver(classifier_vars, max_to_keep=1000)

        # Add all endpoints as object attributes
        for k, v in iteritems(locals()):
            self.__dict__[k] = v
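
The loop at the end of `_build` exposes every local variable as an attribute of the model, which is why later examples can reference endpoints such as `m.vae_loss`, `m.train_vae`, or `m.prior_sample` directly. Below is a minimal, self-contained sketch of that pattern; the class and values are illustrative only and not part of the repository:

class Endpoints(object):
    def _build(self):
        # Stand-ins for the tensors constructed above.
        vae_loss = 1.25
        mean_recons = -3.5
        # Register every local as an attribute, mirroring the loop above.
        for k, v in locals().items():
            if k != 'self':
                self.__dict__[k] = v

m = Endpoints()
m._build()
print(m.vae_loss, m.mean_recons)  # -> 1.25 -3.5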
Example #2
def main(unused_argv):
  del unused_argv

  # Load Config
  config_name = FLAGS.config
  config_module = importlib.import_module(configs_module_prefix +
                                          '.%s' % config_name)
  config = config_module.config
  model_uid = common.get_model_uid(config_name, FLAGS.exp_uid)
  batch_size = config['batch_size']

  # Load dataset
  dataset = common.load_dataset(config)
  save_path = dataset.save_path
  train_data = dataset.train_data
  attr_train = dataset.attr_train
  eval_data = dataset.eval_data
  attr_eval = dataset.attr_eval

  # Make the directory
  save_dir = os.path.join(save_path, model_uid)
  best_dir = os.path.join(save_dir, 'best')
  tf.gfile.MakeDirs(save_dir)
  tf.gfile.MakeDirs(best_dir)
  tf.logging.info('Save Dir: %s', save_dir)

  np.random.seed(FLAGS.random_seed)
  # We use `N` in variable names to emphasize that they count the Number of something.
  N_train = train_data.shape[0]  # pylint:disable=invalid-name
  N_eval = eval_data.shape[0]  # pylint:disable=invalid-name

  # Load Model
  tf.reset_default_graph()
  sess = tf.Session()

  m = model_dataspace.Model(config, name=model_uid)
  _ = m()  # noqa

  # Create summaries
  tf.summary.scalar('Train_Loss', m.vae_loss)
  tf.summary.scalar('Mean_Recon_LL', m.mean_recons)
  tf.summary.scalar('Mean_KL', m.mean_KL)
  scalar_summaries = tf.summary.merge_all()

  x_mean_, x_ = m.x_mean, m.x
  if common.dataset_is_mnist_family(config['dataset']):
    x_mean_ = tf.reshape(x_mean_, [-1, MNIST_SIZE, MNIST_SIZE, 1])
    x_ = tf.reshape(x_, [-1, MNIST_SIZE, MNIST_SIZE, 1])

  x_mean_summary = tf.summary.image(
      'Reconstruction', nn.tf_batch_image(x_mean_), max_outputs=1)
  x_summary = tf.summary.image('Original', nn.tf_batch_image(x_), max_outputs=1)
  sample_summary = tf.summary.image(
      'Sample', nn.tf_batch_image(x_mean_), max_outputs=1)
  # Summary writers
  train_writer = tf.summary.FileWriter(save_dir + '/vae_train', sess.graph)
  eval_writer = tf.summary.FileWriter(save_dir + '/vae_eval', sess.graph)

  # Initialize
  sess.run(tf.global_variables_initializer())

  i_start = 0
  running_N_eval = 30  # pylint:disable=invalid-name
  traces = {
      'i': [],
      'i_pred': [],
      'loss': [],
      'loss_eval': [],
  }

  best_eval_loss = np.inf
  vae_lr_ = np.logspace(np.log10(FLAGS.lr), np.log10(1e-6), FLAGS.n_iters)

  # Train the VAE
  for i in range(i_start, FLAGS.n_iters):
    start = (i * batch_size) % N_train
    end = start + batch_size
    batch = train_data[start:end]
    labels = attr_train[start:end]

    # train op
    res = sess.run(
        [m.train_vae, m.vae_loss, m.mean_recons, m.mean_KL, scalar_summaries], {
            m.x: batch,
            m.vae_lr: vae_lr_[i],
            m.labels: labels,
        })
    tf.logging.info('Iter: %d, Loss: %f', i, res[1])
    train_writer.add_summary(res[-1], i)

    if i % FLAGS.n_iters_per_eval == 0:
      # write training reconstructions
      if batch.shape[0] == batch_size:
        res = sess.run([x_summary, x_mean_summary], {
            m.x: batch,
            m.labels: labels,
        })
        train_writer.add_summary(res[0], i)
        train_writer.add_summary(res[1], i)

      # write sample reconstructions
      prior_sample = sess.run(m.prior_sample)
      res = sess.run([sample_summary], {
          m.q_z_sample: prior_sample,
          m.labels: labels,
      })
      train_writer.add_summary(res[0], i)

      # write eval summaries
      start = (i * batch_size) % N_eval
      end = start + batch_size
      batch = eval_data[start:end]
      labels = attr_eval[start:end]
      if batch.shape[0] == batch_size:
        res_eval = sess.run([
            m.vae_loss, m.mean_recons, m.mean_KL, scalar_summaries, x_summary,
            x_mean_summary
        ], {
            m.x: batch,
            m.labels: labels,
        })
        traces['loss_eval'].append(res_eval[0])
        eval_writer.add_summary(res_eval[-3], i)
        eval_writer.add_summary(res_eval[-2], i)
        eval_writer.add_summary(res_eval[-1], i)

    if i % FLAGS.n_iters_per_save == 0:
      smoothed_eval_loss = np.mean(traces['loss_eval'][-running_N_eval:])
      if smoothed_eval_loss < best_eval_loss:
        # Save the best model
        best_eval_loss = smoothed_eval_loss
        save_name = os.path.join(best_dir, 'vae_best_%s.ckpt' % model_uid)
        tf.logging.info('SAVING BEST! %s Iter: %d', save_name, i)
        m.vae_saver.save(sess, save_name)
        with tf.gfile.Open(os.path.join(best_dir, 'best_ckpt_iters.txt'),
                           'w') as f:
          f.write('%d' % i)
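
The training loop above anneals the VAE learning rate geometrically from `FLAGS.lr` down to 1e-6 via `np.logspace`, and writes a "best" checkpoint whenever the running mean of the most recent `running_N_eval` eval losses improves. A small standalone sketch of those two pieces, with made-up numbers:

import numpy as np

n_iters = 10
lr_schedule = np.logspace(np.log10(3e-4), np.log10(1e-6), n_iters)
print(lr_schedule[0], lr_schedule[-1])  # 3e-4 at the start, 1e-6 at the end

loss_eval = []              # running trace of eval losses
best_eval_loss = np.inf
running_N_eval = 3          # smoothing window
for i, loss in enumerate([5.0, 4.0, 4.5, 3.0, 2.5]):
    loss_eval.append(loss)
    smoothed = np.mean(loss_eval[-running_N_eval:])
    if smoothed < best_eval_loss:
        best_eval_loss = smoothed
        print('new best at iter %d: %f' % (i, smoothed))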