Example #1
def make_masks(string, s_dim, mask_type):
  strategy, factors = string.split("=")
  assert strategy in {"s", "c", "r", "cs", "l"}, "Only label, share, change, cs, and rank strategies are allowed"

  # mask_type is only here to help sanity-check that an invalid
  # (strategy, mask_type) pair was not passed in accidentally
  if strategy == "r":
    assert mask_type == "rank", "mask_type must match data collection strategy"
    # Use the factor indices directly as the mask. Assumes one factor index
    # per comma-separated entry.
    return list(map(int, factors.split(",")))
  elif strategy in {"s", "c", "cs"}:
    assert mask_type == "match", "mask_type must match data collection strategy"
  elif strategy in {"l"}:
    assert mask_type == "label", "mask_type must match data collection strategy"

  if strategy == "cs":
    # Pre-process factors to add complement set
    idx = int(factors)
    l = list(range(s_dim))
    del l[idx]
    factors = "{},{}".format(idx, "".join(map(str, l)))

  factors = [list(map(int, l)) for l in map(list, factors.split(","))]
  masks = np.zeros((len(factors), s_dim), dtype=np.float32)
  for (i, f) in enumerate(factors):
    masks[i, f] = 1

  if strategy == "s":
    masks = 1 - masks
  elif strategy == "l":
    assert len(masks) == 1, "Only one mask allowed for label-strategy"

  ut.log("make_masks output:")
  ut.log(masks)
  return masks
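
For reference, a minimal usage sketch (assuming s_dim=5; the "match" outputs are float32 arrays, shown here as plain lists, and the "rank" case returns plain factor indices):

# Hypothetical calls to make_masks; the values shown are illustrative.
make_masks("c=0,1", 5, "match")  # -> [[1, 0, 0, 0, 0], [0, 1, 0, 0, 0]], one mask per changed factor
make_masks("s=0", 5, "match")    # -> [[0, 1, 1, 1, 1]], share masks are complemented
make_masks("r=0,1", 5, "rank")   # -> [0, 1], raw factor indices instead of binary masks
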
Example #2
    def __init__(self, x_shape, z_dim, batch_norm=True):
        super().__init__()
        ch = x_shape[-1]
        self.net = ts.Sequential(
            dense(128),
            ts.ReLU(),
            dense(4 * 4 * 64),
            ts.ReLU(),
            ts.Reshape((-1, 4, 4, 64)),
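            # The four stride-2 transposed convolutions below upsample the 4x4
            # feature map to 8x8, 16x16, 32x32, and finally 64x64 (assuming
            # 64x64 observations, as used elsewhere in this code).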
            deconv(64, 4, 2, "same"),
            ts.LeakyReLU(),
            deconv(32, 4, 2, "same"),
            ts.LeakyReLU(),
            deconv(32, 4, 2, "same"),
            ts.LeakyReLU(),
            deconv(ch, 4, 2, "same"),
            ts.Sigmoid(),
        )

        # Add batchnorm post-activation (attach to activation out_hook)
        if batch_norm:
            self.net.apply(add_bn, targets=(ts.ReLU, ts.LeakyReLU))

        ut.log("Building generator...")
        self.build((1, z_dim))
        self.apply(ut.reset_parameters)
Example #3
  def __init__(self, x_shape, y_dim, width=1, share_dense=False,
               uncond_bias=False):
    super().__init__()
    self.y_dim = y_dim
    self.body = ts.Sequential(
        conv(32 * width, 4, 2, "same"), ts.LeakyReLU(),
        conv(32 * width, 4, 2, "same"), ts.LeakyReLU(),
        conv(64 * width, 4, 2, "same"), ts.LeakyReLU(),
        conv(64 * width, 4, 2, "same"), ts.LeakyReLU(),
        ts.Flatten(),
        )

    self.aux = ts.Sequential(
        dense(128 * width), ts.LeakyReLU(),
        )

    if share_dense:
      self.body.append(dense(128 * width), ts.LeakyReLU())
      self.aux.append(dense(128 * width), ts.LeakyReLU())

    self.head = ts.Sequential(
        dense(128 * width), ts.LeakyReLU(),
        dense(128 * width), ts.LeakyReLU(),
        dense(1, bias=uncond_bias)
        )

    for m in (self.body, self.aux, self.head):
      m.apply(ts.SpectralNorm.add, targets=ts.Affine)

    ut.log("Building label discriminator...")
    x_shape, y_shape = [1] + x_shape, (1, y_dim)
    self.build(x_shape, y_shape)
    self.apply(ut.reset_parameters)
Example #4
def evaluate_enc_on_targets(enc, dset, s_dim, original_file, original_bindings, target_metrics):
  # Disentanglement Lib Metrics
  evals = {}
  eval_bindings_list = get_eval_bindings_list()
  metrics = ("factor", "mig", "beta", "dci", "modularity", "sap")

  for metric, eval_bindings in zip(metrics, eval_bindings_list):
    if metric in target_metrics:
      gin.parse_config_files_and_bindings([], eval_bindings, finalize_config=False)
      evaluation_fn = get_evaluation()
      tf.logging.info("Reset eval func to {}".format(evaluation_fn.__name__))
      result = evaluation_fn(dset, enc, np.random.RandomState(0))
      ut.log(result)

      if metric == "factor":
        evals[metric] = result["eval_accuracy"]
      elif metric == "mig":
        evals[metric] = result["discrete_mig"]
      elif metric == "beta":
        evals[metric] = result["eval_accuracy"]
      elif metric == "dci":
        evals[metric] = result["disentanglement"]
      elif metric == "modularity":
        evals[metric] = result["modularity_score"]
      elif metric == "sap":
        evals[metric] = result["SAP_score"]

  # Clean up: resetting gin configs to original bindings
  gin.parse_config_files_and_bindings([original_file],
                                      original_bindings,
                                      finalize_config=False)

  return evals
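
A minimal call sketch, assuming a numpy-in/numpy-out encoder and the FLAGS-provided gin file and bindings used elsewhere in this code (enc_np, dset, and s_dim stand in for objects built by the training code):

# Hypothetical usage; only the requested metrics are computed.
evals = evaluate_enc_on_targets(enc_np, dset, s_dim,
                                FLAGS.gin_file, FLAGS.gin_bindings,
                                target_metrics=("mig", "dci"))
ut.log(evals)  # e.g. {"mig": ..., "dci": ...}
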
Example #5
    def __init__(self,
                 x_shape,
                 y_dim,
                 width=1,
                 share_dense=False,
                 uncond_bias=False,
                 cond_bias=False,
                 mask_type="match"):
        super().__init__()
        self.y_dim = y_dim
        self.mask_type = mask_type
        self.body = ts.Sequential(
            conv(32 * width, 4, 2, "same"),
            ts.LeakyReLU(),
            conv(32 * width, 4, 2, "same"),
            ts.LeakyReLU(),
            conv(64 * width, 4, 2, "same"),
            ts.LeakyReLU(),
            conv(64 * width, 4, 2, "same"),
            ts.LeakyReLU(),
            ts.Flatten(),
        )

        if share_dense:
            self.body.append(dense(128 * width), ts.LeakyReLU())

        if mask_type == "match":
            self.neck = ts.Sequential(
                dense(128 * width),
                ts.LeakyReLU(),
                dense(128 * width),
                ts.LeakyReLU(),
            )

            self.head_uncond = dense(1, bias=uncond_bias)
            self.head_cond = dense(128 * width, bias=cond_bias)

            for m in (self.body, self.neck, self.head_uncond):
                m.apply(ts.SpectralNorm.add, targets=ts.Affine)
            add_wn(self.head_cond)
            x_shape, y_shape = [1] + x_shape, ((1, ), tf.int32)

        elif mask_type == "rank":
            self.body.append(dense(128 * width), ts.LeakyReLU(),
                             dense(128 * width), ts.LeakyReLU(),
                             dense(1 + y_dim, bias=uncond_bias))

            self.body.apply(ts.SpectralNorm.add, targets=ts.Affine)
            x_shape, y_shape = [1] + x_shape, (1, y_dim)

        ut.log("Building {} discriminator...".format(mask_type))
        self.build(x_shape, x_shape, y_shape)
        self.apply(ut.reset_parameters)
Example #6
def get_dlib_data(task):
  ut.log("Loading {}".format(task))
  if task == "dsprites":
    # 5 factors
    return dsprites.DSprites(list(range(1, 6)))
  elif task == "shapes3d":
    # 6 factors
    return shapes3d.Shapes3D()
  elif task == "norb":
    # 4 factors + 1 nuisance (which we'll handle via n_dim=2)
    return norb.SmallNORB()
  elif task == "cars3d":
    # 3 factors
    return cars3d.Cars3D()
  elif task == "scream":
    # 5 factors + 2 nuisance (handled as n_dim=2)
    return dsprites.ScreamDSprites(list(range(1, 6)))
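
Usage is a single lookup by dataset name; an unrecognized name falls through and returns None, which the training code treats as a "check that the networks build" dry run. A short sketch (assuming the disentanglement_lib data files are installed):

dset = get_dlib_data("shapes3d")  # 6-factor Shapes3D dataset
if dset is None:                  # e.g. an unsupported task name
  x_shape = [64, 64, 1]           # default observation shape used by the trainers
else:
  x_shape = dset.observation_shape
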
Example #7
  def __init__(self, x_shape, z_dim, width=1, spectral_norm=True):
    super().__init__()
    self.net = ts.Sequential(
        conv(32 * width, 4, 2, "same"), ts.LeakyReLU(),
        conv(32 * width, 4, 2, "same"), ts.LeakyReLU(),
        conv(64 * width, 4, 2, "same"), ts.LeakyReLU(),
        conv(64 * width, 4, 2, "same"), ts.LeakyReLU(),
        ts.Flatten(),
        dense(128 * width), ts.LeakyReLU(),
        dense(2 * z_dim)  # presumably the mean and scale parameters of a diagonal Gaussian over z
        )

    if spectral_norm:
      self.net.apply(ts.SpectralNorm.add, targets=ts.Affine)

    ut.log("Building encoder...")
    self.build([1] + x_shape)
    self.apply(ut.reset_parameters)
Example #8
def main(_):
    if FLAGS.debug:
        FLAGS.gin_bindings += ["log.debug = True"]
    gin.parse_config_files_and_bindings([FLAGS.gin_file],
                                        FLAGS.gin_bindings,
                                        finalize_config=False)
    ut.log("\n" + "*" * 80 + "\nBegin program\n" + "*" * 80)
    ut.log("In main")
    train()
    ut.log("\n" + "*" * 80 + "\nEnd program\n" + "*" * 80)
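
This entry point assumes absl-style flags; the sketch below shows how the flags it reads might be declared (the real definitions live elsewhere in the project and may differ):

# Hypothetical flag declarations matching what main() reads above.
from absl import app, flags

FLAGS = flags.FLAGS
flags.DEFINE_string("gin_file", None, "Path to the gin config file.")
flags.DEFINE_multi_string("gin_bindings", [], "Additional gin bindings.")
flags.DEFINE_boolean("debug", False, "Enable debug logging.")

if __name__ == "__main__":
    app.run(main)
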
Example #9
def evaluate_enc(enc,
                 dset,
                 s_dim,
                 original_file,
                 original_bindings,
                 pida_sample_size=10000,
                 dlib_metrics=True):
    """Evaluates an encoder on multiple disentanglement metrics.

  Given an encoder and an oracle generator, this function computes
  encoder-based metrics on consistency, restrictiveness, and ranking accuracy,
  plus six additional disentanglement metrics used in disentanglement_lib.
  The disentanglement_lib metrics are set by modifying the global gin config.
  We require that the user provide the original gin file and gin bindings so
  that they can be re-established at the end of this call.

  Args:
    enc: an encoding function that takes in and returns numpy arrays.
    dset: a disentanglement_lib dataset.
    s_dim: number of non-nuisance dimensions of the latent space.
    original_file: path to the original gin file.
    original_bindings: list of the original gin bindings.
    pida_sample_size: number of samples for the Monte Carlo estimate of the
      PIDA metrics.
    dlib_metrics: whether to compute the disentanglement_lib metrics.

  Returns:
    A dictionary of scores.
  """
    # enc takes in and outputs numpy arrays
    # Consistency/Restrictiveness Metrics
    itypes = [
        "{}={}".format(t, i)
        for t, i in itertools.product(("s", "c", "r"), range(s_dim))
    ]

    evals = {}
    for it in itypes:
        scores = negative_pida(it,
                               s_dim,
                               enc,
                               dset,
                               sample_size=pida_sample_size,
                               random_state=np.random.RandomState(0))
        evals.update(scores)

        for k in scores:
            ut.log(k, ":", scores[k])

    evals["s_mean"] = np.mean([evals[k] for k in evals if "s=" == k[:2]])
    evals["c_mean"] = np.mean([evals[k] for k in evals if "c=" == k[:2]])

    if not dlib_metrics:
        return evals

    # Disentanglement Lib Metrics
    eval_bindings_list = get_eval_bindings_list()
    metrics = ("factor", "mig", "beta", "dci", "modularity", "sap")

    for metric, eval_bindings in zip(metrics, eval_bindings_list):
        gin.parse_config_files_and_bindings([],
                                            eval_bindings,
                                            finalize_config=False)
        evaluation_fn = get_evaluation()
        tf.logging.info("Reset eval func to {}".format(evaluation_fn.__name__))
        result = evaluation_fn(dset, enc, np.random.RandomState(0))
        ut.log(result)

        if metric == "factor":
            evals[metric] = result["eval_accuracy"]
        elif metric == "mig":
            evals[metric] = result["discrete_mig"]
        elif metric == "beta":
            evals[metric] = result["eval_accuracy"]
        elif metric == "dci":
            evals[metric] = result["disentanglement"]
        elif metric == "modularity":
            evals[metric] = result["modularity_score"]
        elif metric == "sap":
            evals[metric] = result["SAP_score"]

    # Clean up: resetting gin configs to original bindings
    gin.parse_config_files_and_bindings([original_file],
                                        original_bindings,
                                        finalize_config=False)

    return evals
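
A minimal call sketch, assuming an encoder wrapped as a numpy-in/numpy-out function as in the training loops elsewhere in this code (enc_eval, dset, s_dim, and FLAGS stand in for objects built there):

enc_np = lambda x: enc_eval(x).numpy()
scores = evaluate_enc(enc_np, dset, s_dim,
                      FLAGS.gin_file, FLAGS.gin_bindings,
                      pida_sample_size=10000,
                      dlib_metrics=True)
ut.log(scores)  # includes "s_mean", "c_mean", and the disentanglement_lib scores
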
Example #10
def train(dset_name,
          s_dim,
          n_dim,
          factors,
          batch_size,
          dec_lr,
          enc_lr_mul,
          iterations,
          model_type="gen"):
    ut.log("In train")
    masks = datasets.make_masks(factors, s_dim)
    z_dim = s_dim + n_dim
    enc_lr = enc_lr_mul * dec_lr

    # Load data
    dset = datasets.get_dlib_data(dset_name)
    if dset is None:
        x_shape = [64, 64, 1]
    else:
        x_shape = dset.observation_shape
        targets_real = tf.ones((batch_size, 1))
        targets_fake = tf.zeros((batch_size, 1))
        targets = tf.concat((targets_real, targets_fake), axis=0)

    # Networks
    if model_type == "gen":
        assert factors.split("=")[0] in {"c", "s", "cs", "r"}
        y_dim = len(masks)
        dis = networks.Discriminator(x_shape, y_dim)
        gen = networks.Generator(x_shape, z_dim)
        enc = networks.Encoder(x_shape,
                               s_dim)  # Encoder ignores nuisance param
        ut.log(dis.read(dis.WITH_VARS))
        ut.log(gen.read(gen.WITH_VARS))
        ut.log(enc.read(enc.WITH_VARS))
    elif model_type == "enc":
        assert factors.split("=")[0] in {"r"}
        enc = networks.Encoder(x_shape,
                               s_dim)  # Encoder ignores nuisance param
        ut.log(enc.read(enc.WITH_VARS))
    elif model_type == "van":
        assert factors.split("=")[0] in {"l"}
        dis = networks.LabelDiscriminator(x_shape, s_dim)  # Uses s_dim
        gen = networks.Generator(x_shape, z_dim)
        enc = networks.Encoder(x_shape,
                               s_dim)  # Encoder ignores nuisance param
        ut.log(dis.read(dis.WITH_VARS))
        ut.log(gen.read(gen.WITH_VARS))
        ut.log(enc.read(enc.WITH_VARS))

    # Create optimizers
    if model_type in {"gen", "van"}:
        gen_opt = tfk.optimizers.Adam(learning_rate=dec_lr,
                                      beta_1=0.5,
                                      beta_2=0.999,
                                      epsilon=1e-8)
        dis_opt = tfk.optimizers.Adam(learning_rate=enc_lr,
                                      beta_1=0.5,
                                      beta_2=0.999,
                                      epsilon=1e-8)
        enc_opt = tfk.optimizers.Adam(learning_rate=enc_lr,
                                      beta_1=0.5,
                                      beta_2=0.999,
                                      epsilon=1e-8)
    elif model_type == "enc":
        enc_opt = tfk.optimizers.Adam(learning_rate=enc_lr,
                                      beta_1=0.5,
                                      beta_2=0.999,
                                      epsilon=1e-8)

    @tf.function
    def train_gen_step(x1_real, x2_real, y_real):
        gen.train()
        dis.train()
        enc.train()
        # Alternate discriminator step and generator step
        with tf.GradientTape(persistent=True) as tape:
            # Generate
            z1, z2, y_fake = datasets.paired_randn(batch_size, z_dim, masks)
            x1_fake = tf.stop_gradient(gen(z1))
            x2_fake = tf.stop_gradient(gen(z2))

            # Discriminate
            x1 = tf.concat((x1_real, x1_fake), 0)
            x2 = tf.concat((x2_real, x2_fake), 0)
            y = tf.concat((y_real, y_fake), 0)
            logits = dis(x1, x2, y)

            # Encode
            p_z = enc(x1_fake)

            dis_loss = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(logits=logits,
                                                        labels=targets))
            # Encoder ignores nuisance parameters (if they exist)
            enc_loss = -tf.reduce_mean(p_z.log_prob(z1[:, :s_dim]))

        dis_grads = tape.gradient(dis_loss, dis.trainable_variables)
        enc_grads = tape.gradient(enc_loss, enc.trainable_variables)

        dis_opt.apply_gradients(zip(dis_grads, dis.trainable_variables))
        enc_opt.apply_gradients(zip(enc_grads, enc.trainable_variables))

        with tf.GradientTape(persistent=False) as tape:
            # Generate
            z1, z2, y_fake = datasets.paired_randn(batch_size, z_dim, masks)
            x1_fake = gen(z1)
            x2_fake = gen(z2)

            # Discriminate
            logits_fake = dis(x1_fake, x2_fake, y_fake)

            gen_loss = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(logits=logits_fake,
                                                        labels=targets_real))

        gen_grads = tape.gradient(gen_loss, gen.trainable_variables)
        gen_opt.apply_gradients(zip(gen_grads, gen.trainable_variables))

        return dict(gen_loss=gen_loss, dis_loss=dis_loss, enc_loss=enc_loss)

    @tf.function
    def train_van_step(x_real, y_real):
        gen.train()
        dis.train()
        enc.train()

        if n_dim > 0:
            padding = tf.zeros((y_real.shape[0], n_dim))
            y_real_pad = tf.concat((y_real, padding), axis=-1)
        else:
            y_real_pad = y_real

        # Alternate discriminator step and generator step
        with tf.GradientTape(persistent=False) as tape:
            # Generate
            z_fake = datasets.paired_randn(batch_size, z_dim, masks)
            z_fake = z_fake + y_real_pad
            x_fake = gen(z_fake)

            # Discriminate
            logits_fake = dis(x_fake, y_real)

            gen_loss = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(logits=logits_fake,
                                                        labels=targets_real))

        gen_grads = tape.gradient(gen_loss, gen.trainable_variables)
        gen_opt.apply_gradients(zip(gen_grads, gen.trainable_variables))

        with tf.GradientTape(persistent=True) as tape:
            # Generate
            z_fake = datasets.paired_randn(batch_size, z_dim, masks)
            z_fake = z_fake + y_real_pad
            x_fake = tf.stop_gradient(gen(z_fake))

            # Discriminate
            x = tf.concat((x_real, x_fake), 0)
            y = tf.concat((y_real, y_real), 0)
            logits = dis(x, y)

            # Encode
            p_z = enc(x_fake)

            dis_loss = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(logits=logits,
                                                        labels=targets))
            # Encoder ignores nuisance parameters (if they exist)
            enc_loss = -tf.reduce_mean(p_z.log_prob(z_fake[:, :s_dim]))

        dis_grads = tape.gradient(dis_loss, dis.trainable_variables)
        enc_grads = tape.gradient(enc_loss, enc.trainable_variables)

        dis_opt.apply_gradients(zip(dis_grads, dis.trainable_variables))
        enc_opt.apply_gradients(zip(enc_grads, enc.trainable_variables))

        return dict(gen_loss=gen_loss, dis_loss=dis_loss, enc_loss=enc_loss)

    @tf.function
    def train_enc_step(x1_real, x2_real, y_real):
        with tf.GradientTape() as tape:
            z1 = enc(x1_real).mean()
            z2 = enc(x2_real).mean()
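            # For the rank strategy, `masks` is a list of factor indices (see
            # make_masks), so this compares the paired encodings along the
            # selected latent dimensions only.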
            logits = tf.gather(z1 - z2, masks, axis=-1)
            loss = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(logits=logits,
                                                        labels=y_real))

        enc_grads = tape.gradient(loss, enc.trainable_variables)
        enc_opt.apply_gradients(zip(enc_grads, enc.trainable_variables))
        return dict(gen_loss=0, dis_loss=0, enc_loss=loss)

    @tf.function
    def gen_eval(z):
        gen.eval()
        return gen(z)

    @tf.function
    def enc_eval(x):
        enc.eval()
        return enc(x).mean()

    enc_np = lambda x: enc_eval(x).numpy()

    # Initial preparation
    if FLAGS.debug:
        iter_log = 100
        iter_save = 2000
        train_range = range(iterations)
        basedir = FLAGS.basedir
        vizdir = FLAGS.basedir
        ckptdir = FLAGS.basedir
        new_run = True
    else:
        iter_log = 5000
        iter_save = 50000
        iter_metric = iter_save * 5  # Make sure this is a factor of 500k
        basedir = os.path.join(FLAGS.basedir, "exp")
        ckptdir = os.path.join(basedir, "ckptdir")
        vizdir = os.path.join(basedir, "vizdir")
        gfile.MakeDirs(basedir)
        gfile.MakeDirs(ckptdir)
        gfile.MakeDirs(vizdir)  # train_range will be specified below

    ckpt_prefix = os.path.join(ckptdir, "model")
    if model_type in {"gen", "van"}:
        ckpt_root = tf.train.Checkpoint(dis=dis,
                                        dis_opt=dis_opt,
                                        gen=gen,
                                        gen_opt=gen_opt,
                                        enc=enc,
                                        enc_opt=enc_opt)
    elif model_type == "enc":
        ckpt_root = tf.train.Checkpoint(enc=enc, enc_opt=enc_opt)

    # If not in debugging mode, check whether we are resuming training
    if not FLAGS.debug:
        latest_ckpt = tf.train.latest_checkpoint(ckptdir)
        if latest_ckpt is None:
            new_run = True
            ut.log("Starting a completely new model")
            train_range = range(iterations)

        else:
            new_run = False
            ut.log("Restarting from {}".format(latest_ckpt))
            ckpt_root.restore(latest_ckpt)
            resuming_iteration = iter_save * (int(ckpt_root.save_counter) - 1)
            train_range = range(resuming_iteration, iterations)

    # Training
    if dset is None:
        ut.log("Dataset {} is not available".format(dset_name))
        ut.log("Ending program having checked that the networks can be built.")
        return

    batches = datasets.paired_data_generator(
        dset, masks).repeat().batch(batch_size).prefetch(1000)
    batches = iter(batches)
    start_time = time.time()
    train_time = 0

    if FLAGS.debug:
        train_range = tqdm(train_range)

    for global_step in train_range:
        stopwatch = time.time()
        if model_type == "gen":
            x1, x2, y = next(batches)
            vals = train_gen_step(x1, x2, y)
        elif model_type == "enc":
            x1, x2, y = next(batches)
            vals = train_enc_step(x1, x2, y)
        elif model_type == "van":
            x, y = next(batches)
            vals = train_van_step(x, y)
        train_time += time.time() - stopwatch

        # Generic bookkeeping
        if (global_step + 1) % iter_log == 0 or global_step == 0:
            elapsed_time = time.time() - start_time
            string = ", ".join((
                "Iter: {:07d}, Elapsed: {:.3e}, (Elapsed) Iter/s: {:.3e}, (Train Step) Iter/s: {:.3e}"
                .format(global_step, elapsed_time, global_step / elapsed_time,
                        global_step / train_time),
                "Gen: {gen_loss:.4f}, Dis: {dis_loss:.4f}, Enc: {enc_loss:.4f}"
                .format(**vals))) + "."
            ut.log(string)

        # Log visualizations and evaluations
        if (global_step + 1) % iter_save == 0 or global_step == 0:
            if model_type == "gen":
                viz.ablation_visualization(x1, x2, gen_eval, z_dim, vizdir,
                                           global_step + 1)
            elif model_type == "van":
                viz.ablation_visualization(x, x, gen_eval, z_dim, vizdir,
                                           global_step + 1)

            if FLAGS.debug:
                evaluate.evaluate_enc(enc_np,
                                      dset,
                                      s_dim,
                                      FLAGS.gin_file,
                                      FLAGS.gin_bindings,
                                      pida_sample_size=1000,
                                      dlib_metrics=FLAGS.debug_dlib_metrics)

            else:
                dlib_metrics = (global_step + 1) % iter_metric == 0
                evaluate.evaluate_enc(enc_np,
                                      dset,
                                      s_dim,
                                      FLAGS.gin_file,
                                      FLAGS.gin_bindings,
                                      pida_sample_size=10000,
                                      dlib_metrics=dlib_metrics)

        # Save model
        if (global_step + 1) % iter_save == 0 or (global_step == 0
                                                  and new_run):
            # Save the model only after ensuring all measurements are taken.
            # This ensures that restarted runs always recompute the evals.
            ut.log("Saved to", ckpt_root.save(ckpt_prefix))
Example #11
def train(dset_name, s_dim, n_dim, factors, s_I_dim, batch_size, clas_lr,
          iterations):
    ut.log("In train classifier")
    # Load data
    dset = datasets.get_dlib_data(dset_name)
    if dset is None:
        x_shape = [64, 64, 1]
    else:
        x_shape = dset.observation_shape
        targets_real = tf.ones((batch_size, 1))
        targets_fake = tf.zeros((batch_size, 1))
        targets = tf.concat((targets_real, targets_fake), axis=0)

    # Networks
    clas = networks.Classifier(x_shape, s_I_dim)
    ut.log(clas.read(clas.WITH_VARS))

    # Create optimizers
    clas_opt = tfk.optimizers.Adam(learning_rate=clas_lr,
                                   beta_1=0.5,
                                   beta_2=0.999,
                                   epsilon=1e-8)

    @tf.function
    def train_clas_step(x_real, y_real):
        clas.train()

        with tf.GradientTape(persistent=True) as tape:
            # Classify real observations
            p_s = clas(x_real)
            clas_loss = -tf.reduce_mean(p_s.log_prob(y_real[:, :s_I_dim]))

        clas_grads = tape.gradient(clas_loss, clas.trainable_variables)
        clas_opt.apply_gradients(zip(clas_grads, clas.trainable_variables))

        return dict(clas_loss=clas_loss)

    # Initial preparation
    if FLAGS.debug:
        iter_log = 100
        iter_save = 2000
        train_range = range(iterations)
        basedir = FLAGS.basedir
        vizdir = FLAGS.basedir
        ckptdir = FLAGS.basedir
        new_run = True
    else:
        iter_log = 5000
        iter_save = 5000
        iter_metric = iter_save * 5  # Make sure this is a factor of 500k
        basedir = os.path.join(FLAGS.basedir, "clas")
        ckptdir = os.path.join(basedir, "ckptdir")
        vizdir = os.path.join(basedir, "vizdir")
        gfile.MakeDirs(basedir)
        gfile.MakeDirs(ckptdir)
        gfile.MakeDirs(vizdir)  # train_range will be specified below

    ckpt_prefix = os.path.join(ckptdir, "model")
    ckpt_root = tf.train.Checkpoint(clas=clas, clas_opt=clas_opt)

    # If not in debugging mode, check whether we are resuming training
    if not FLAGS.debug:
        latest_ckpt = tf.train.latest_checkpoint(ckptdir)
        if latest_ckpt is None:
            new_run = True
            ut.log("Starting a completely new model")
            train_range = range(iterations)

        else:
            new_run = False
            ut.log("Restarting from {}".format(latest_ckpt))
            ckpt_root.restore(latest_ckpt)
            resuming_iteration = iter_save * (int(ckpt_root.save_counter) - 1)
            train_range = range(resuming_iteration, iterations)

    # Training
    if dset is None:
        ut.log("Dataset {} is not available".format(dset_name))
        ut.log("Ending program having checked that the networks can be built.")
        return

    batches = datasets.unmasked_label_data_generator(
        dset, s_dim).repeat().batch(batch_size).prefetch(1000)
    batches = iter(batches)
    start_time = time.time()
    train_time = 0

    if FLAGS.debug:
        train_range = tqdm(train_range)

    for global_step in train_range:
        stopwatch = time.time()
        x, y = next(batches)
        vals = train_clas_step(x, y)
        train_time += time.time() - stopwatch

        # Generic bookkeeping
        if (global_step + 1) % iter_log == 0 or global_step == 0:
            elapsed_time = time.time() - start_time
            string = ", ".join((
                "Iter: {:07d}, Elapsed: {:.3e}, (Elapsed) Iter/s: {:.3e}, (Train Step) Iter/s: {:.3e}"
                .format(global_step, elapsed_time, global_step / elapsed_time,
                        global_step / train_time),
                "Clas: {clas_loss:.4f}".format(**vals))) + "."
            ut.log(string)

        # Log visualizations and evaluations
        if (global_step + 1) % iter_save == 0 or global_step == 0:
            ut.log("Beginning evaluation.")
            sample_size = 10000
            random_state = np.random.RandomState(1)

            factors = dset.sample_factors(sample_size, random_state)
            obs = dset.sample_observations_from_factors(factors, random_state)

            eval_y = tf.convert_to_tensor(factors, dtype=tf.float32)
            eval_x = tf.convert_to_tensor(obs, dtype=tf.float32)

            p_s_eval = clas(eval_x)
            eval_loss = -tf.reduce_mean(p_s_eval.log_prob(eval_y[:, :s_I_dim]))
            ut.log("Eval loss: {}".format(eval_loss))

        # Save model
        if (global_step + 1) % iter_save == 0 or (global_step == 0
                                                  and new_run):
            # Save the model only after ensuring all measurements are taken.
            # This ensures that restarted runs always recompute the evals.
            ut.log("Saved to", ckpt_root.save(ckpt_prefix))
Example #12
def train(dset_name, s_dim, n_dim, factors, z_transform,
          batch_size, dec_lr, enc_lr_mul, iterations,
          model_type="gen"):

  ut.log("In train")
  masks = datasets.make_masks(factors, s_dim)
  z_dim = s_dim + n_dim
  enc_lr = enc_lr_mul * dec_lr
  z_trans = datasets.get_z_transform(z_transform)
  # Load data
  dset = datasets.get_dlib_data(dset_name)
  if dset is None:
    x_shape = [64, 64, 1]
  else:
    x_shape = dset.observation_shape
    targets_real = tf.ones((batch_size, 1))
    targets_fake = tf.zeros((batch_size, 1))
    targets = tf.concat((targets_real, targets_fake), axis=0)

  # Networks
  if model_type == "gen":
    assert factors.split("=")[0] in {"c", "s", "cs", "r"}
    y_dim = len(masks)
    dis = networks.Discriminator(x_shape, y_dim)
    gen = networks.Generator(x_shape, z_dim)
    enc = networks.Encoder(x_shape, s_dim)  # Encoder ignores nuisance param
    ut.log(dis.read(dis.WITH_VARS))
    ut.log(gen.read(gen.WITH_VARS))
    ut.log(enc.read(enc.WITH_VARS))
  elif model_type == "van":
    assert factors.split("=")[0] in {"l"}
    dis = networks.LabelDiscriminator(x_shape, s_dim)  # Uses s_dim
    gen = networks.Generator(x_shape, z_dim)
    enc = networks.CovEncoder(x_shape, s_dim)  # Encoder ignores nuisance param
    trans_enc = networks.CovEncoder(x_shape, s_dim)

    clas_path = os.path.join(FLAGS.basedir, "clas")
    clas = networks.Classifier(x_shape, s_dim)
    ckpt_root = tf.train.Checkpoint(clas=clas)
    latest_ckpt = tf.train.latest_checkpoint(clas_path)
    ckpt_root.restore(latest_ckpt)

    ut.log(dis.read(dis.WITH_VARS))
    ut.log(gen.read(gen.WITH_VARS))
    ut.log(enc.read(enc.WITH_VARS))
    ut.log(clas.read(clas.WITH_VARS))

  # Create optimizers
  if model_type in {"gen", "van"}:
    gen_opt = tfk.optimizers.Adam(learning_rate=dec_lr, beta_1=0.5, beta_2=0.999, epsilon=1e-8)
    dis_opt = tfk.optimizers.Adam(learning_rate=enc_lr, beta_1=0.5, beta_2=0.999, epsilon=1e-8)
    enc_opt = tfk.optimizers.Adam(learning_rate=enc_lr, beta_1=0.5, beta_2=0.999, epsilon=1e-8)
    trans_enc_opt = tfk.optimizers.Adam(learning_rate=enc_lr, beta_1=0.5, beta_2=0.999, epsilon=1e-8)

  @tf.function
  def train_gen_step(x1_real, x2_real, y_real):
    gen.train()
    dis.train()
    enc.train()
    # Alternate discriminator step and generator step
    with tf.GradientTape(persistent=True) as tape:
      # Generate
      z1, z2, y_fake = datasets.paired_randn(batch_size, z_dim, masks)
      x1_fake = tf.stop_gradient(gen(z1))
      x2_fake = tf.stop_gradient(gen(z2))

      # Discriminate
      x1 = tf.concat((x1_real, x1_fake), 0)
      x2 = tf.concat((x2_real, x2_fake), 0)
      y = tf.concat((y_real, y_fake), 0)
      logits = dis(x1, x2, y)

      # Encode
      p_z = enc(x1_fake)

      dis_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
          logits=logits, labels=targets))
      # Encoder ignores nuisance parameters (if they exist)
      enc_loss = -tf.reduce_mean(p_z.log_prob(z1[:, :s_dim]))

    dis_grads = tape.gradient(dis_loss, dis.trainable_variables)
    enc_grads = tape.gradient(enc_loss, enc.trainable_variables)

    dis_opt.apply_gradients(zip(dis_grads, dis.trainable_variables))
    enc_opt.apply_gradients(zip(enc_grads, enc.trainable_variables))

    with tf.GradientTape(persistent=False) as tape:
      # Generate
      z1, z2, y_fake = datasets.paired_randn(batch_size, z_dim, masks)
      x1_fake = gen(z1)
      x2_fake = gen(z2)

      # Discriminate
      logits_fake = dis(x1_fake, x2_fake, y_fake)

      gen_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
          logits=logits_fake, labels=targets_real))

    gen_grads = tape.gradient(gen_loss, gen.trainable_variables)
    gen_opt.apply_gradients(zip(gen_grads, gen.trainable_variables))

    return dict(gen_loss=gen_loss, dis_loss=dis_loss, enc_loss=enc_loss)

  @tf.function
  def train_van_step(x_real, y_real, entangle=False):
    gen.train()
    dis.train()
    enc.train()
    trans_enc.train()

    if n_dim > 0:
      padding = tf.zeros((y_real.shape[0], n_dim))
      y_real_pad = tf.concat((y_real, padding), axis=-1)
    else:
      y_real_pad = y_real

    if entangle:
        # Alternate discriminator step and generator step
        with tf.GradientTape(persistent=False) as tape:
          # Generate
          dummy_mask = tf.zeros_like(masks)
          z_fake = datasets.paired_randn(batch_size, z_dim, dummy_mask)
          x_fake = gen(z_fake)

          # Discriminate
          logits_fake = dis(x_fake, y_real)

          gen_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
              logits=logits_fake, labels=targets_real))

        gen_grads = tape.gradient(gen_loss, gen.trainable_variables)
        gen_opt.apply_gradients(zip(gen_grads, gen.trainable_variables))

        with tf.GradientTape(persistent=True) as tape:
          # Generate
          dummy_mask = tf.zeros_like(masks)
          z_fake = datasets.paired_randn(batch_size, z_dim, dummy_mask)
          x_fake = tf.stop_gradient(gen(z_fake))
          trans_z_fake = z_trans(z_fake)

          # Discriminate
          x = tf.concat((x_real, x_fake), 0)
          y = tf.concat((y_real, y_real), 0)
          logits = dis(x, y)

          # Encode
          p_z = enc(x_fake)
          p_z_trans = trans_enc(x_fake)

          dis_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
              logits=logits, labels=targets))
          # Encoder ignores nuisance parameters (if they exist)
          enc_loss = -tf.reduce_mean(p_z.log_prob(z_fake[:, :s_dim]))
          trans_enc_loss = -tf.reduce_mean(p_z_trans.log_prob(
            trans_z_fake[:, :s_dim]))

        dis_grads = tape.gradient(dis_loss, dis.trainable_variables)
        enc_grads = tape.gradient(enc_loss, enc.trainable_variables)
        trans_enc_grads = tape.gradient(trans_enc_loss,
          trans_enc.trainable_variables)

        dis_opt.apply_gradients(zip(dis_grads, dis.trainable_variables))
        enc_opt.apply_gradients(zip(enc_grads, enc.trainable_variables))
        trans_enc_opt.apply_gradients(zip(trans_enc_grads, trans_enc.trainable_variables))

    else:
        # Alternate discriminator step and generator step
        with tf.GradientTape(persistent=False) as tape:
          # Generate
          z_fake = datasets.paired_randn(batch_size, z_dim, masks)
          z_fake = z_fake + y_real_pad
          x_fake = gen(z_fake)

          # Discriminate
          logits_fake = dis(x_fake, y_real)

          gen_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
              logits=logits_fake, labels=targets_real))

        gen_grads = tape.gradient(gen_loss, gen.trainable_variables)
        gen_opt.apply_gradients(zip(gen_grads, gen.trainable_variables))

        with tf.GradientTape(persistent=True) as tape:
          # Generate
          z_fake = datasets.paired_randn(batch_size, z_dim, masks)
          z_fake = z_fake + y_real_pad
          x_fake = tf.stop_gradient(gen(z_fake))
          trans_z_fake = z_trans(z_fake)

          # Discriminate
          x = tf.concat((x_real, x_fake), 0)
          y = tf.concat((y_real, y_real), 0)
          logits = dis(x, y)

          # Encode
          p_z = enc(x_fake)
          p_z_trans = trans_enc(x_fake)

          dis_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
              logits=logits, labels=targets))
          # Encoder ignores nuisance parameters (if they exist)
          enc_loss = -tf.reduce_mean(p_z.log_prob(z_fake[:, :s_dim]))
          trans_enc_loss = -tf.reduce_mean(p_z_trans.log_prob(
            trans_z_fake[:, :s_dim]))

        dis_grads = tape.gradient(dis_loss, dis.trainable_variables)
        enc_grads = tape.gradient(enc_loss, enc.trainable_variables)
        trans_enc_grads = tape.gradient(trans_enc_loss,
          trans_enc.trainable_variables)

        dis_opt.apply_gradients(zip(dis_grads, dis.trainable_variables))
        enc_opt.apply_gradients(zip(enc_grads, enc.trainable_variables))
        trans_enc_opt.apply_gradients(zip(trans_enc_grads, trans_enc.trainable_variables))

    return dict(gen_loss=gen_loss, dis_loss=dis_loss, enc_loss=enc_loss, trans_enc_loss=trans_enc_loss)

  @tf.function
  def gen_eval(z):
    gen.eval()
    return gen(z)

  @tf.function
  def enc_eval(x):
    enc.eval()
    return enc(x).mean()
  enc_np = lambda x: enc_eval(x).numpy()

  @tf.function
  def trans_enc_eval(x):
    trans_enc.eval()
    return trans_enc(x).mean()
  trans_enc_np = lambda x: trans_enc_eval(x).numpy()

  # Initial preparation
  if FLAGS.debug:
    iter_log = 100
    iter_save = 2000
    train_range = range(iterations)
    basedir = FLAGS.basedir
    vizdir = FLAGS.basedir
    ckptdir = FLAGS.basedir
    new_run = True
  else:
    iter_log = 5000
    iter_save = 5000
    iter_metric = iter_save * 5  # Make sure this is a factor of 500k
    basedir = os.path.join(FLAGS.basedir, "exp")
    ckptdir = os.path.join(basedir, "ckptdir")
    vizdir = os.path.join(basedir, "vizdir")
    gfile.MakeDirs(basedir)
    gfile.MakeDirs(ckptdir)
    gfile.MakeDirs(vizdir)  # train_range will be specified below

  ckpt_prefix = os.path.join(ckptdir, "model")
  if model_type in {"gen", "van"}:
    ckpt_root = tf.train.Checkpoint(dis=dis, dis_opt=dis_opt,
                                    gen=gen, gen_opt=gen_opt,
                                    enc=enc, enc_opt=enc_opt,
                                    trans_enc=trans_enc, trans_enc_opt=trans_enc_opt)

  # If not in debugging mode, check whether we are resuming training
  if not FLAGS.debug:
    latest_ckpt = tf.train.latest_checkpoint(ckptdir)
    if latest_ckpt is None:
      new_run = True
      ut.log("Starting a completely new model")
      train_range = range(iterations)

    else:
      new_run = False
      ut.log("Restarting from {}".format(latest_ckpt))
      ckpt_root.restore(latest_ckpt)
      resuming_iteration = iter_save * (int(ckpt_root.save_counter) - 1)
      train_range = range(resuming_iteration, iterations)

  samples = FLAGS.val_samples
  if FLAGS.evaluate:
    masks = np.zeros([samples, z_dim])
    masks[:, 0] = 1
    masks = tf.convert_to_tensor(masks, dtype=tf.float32)

    transformed_prior = datasets.transformed_prior(z_trans)
    mi, mi_trans, mi_joint, mi_joint_trans = [], [], [], []

    for i in range(FLAGS.mi_averages):
      mi.append(new_metrics.mi_difference(z_dim, gen, clas, masks, samples))
      mi_trans.append(new_metrics.mi_difference(
          z_dim, gen, clas, masks, samples, z_prior=transformed_prior))

      mi_joint.append(new_metrics.mi_difference(
          z_dim, gen, clas, masks, samples, draw_from_joint=True))
      mi_joint_trans.append(new_metrics.mi_difference(
          z_dim, gen, clas, masks, samples, z_prior=transformed_prior,
          draw_from_joint=True))

    mi = np.mean(np.stack(mi), axis=0)
    mi_trans = np.mean(np.stack(mi_trans), axis=0)

    mi_joint = np.mean(mi_joint)
    mi_joint_trans = np.mean(mi_joint_trans)

    ut.log("MI - Normal: {}, {} Trans: {}, {}".format(mi[0], mi[1], mi_trans[0], mi_trans[1]))

    ut.log("MI Joint - Normal: {} Trans: {}".format(mi_joint, mi_joint_trans))

    ut.log("Encoder Metrics")
    evaluate.evaluate_enc(enc_np, dset, s_dim,
                          FLAGS.gin_file,
                          FLAGS.gin_bindings,
                          pida_sample_size=1000,
                          dlib_metrics=FLAGS.debug_dlib_metrics)
    ut.log("Transformed Encoder Metrics")
    evaluate.evaluate_enc(trans_enc_np, dset, s_dim,
                          FLAGS.gin_file,
                          FLAGS.gin_bindings,
                          pida_sample_size=1000,
                          dlib_metrics=FLAGS.debug_dlib_metrics)

    ut.log("Completed Evaluation")
    return

  # Training
  if dset is None:
    ut.log("Dataset {} is not available".format(dset_name))
    ut.log("Ending program having checked that the networks can be built.")
    return

  batches = datasets.paired_data_generator(dset, masks).repeat().batch(batch_size).prefetch(1000)
  batches = iter(batches)
  start_time = time.time()
  train_time = 0

  if FLAGS.debug:
    train_range = tqdm(train_range)
  if FLAGS.visualize:
    train_range = range(iterations+1)

  for global_step in train_range:
    stopwatch = time.time()
    if model_type == "gen":
      x1, x2, y = next(batches)
      vals = train_gen_step(x1, x2, y)
    elif model_type == "van":
      x, y = next(batches)
      vals = train_van_step(x, y, FLAGS.entangle)
    train_time += time.time() - stopwatch

    # Generic bookkeeping
    if (global_step + 1) % iter_log == 0 or global_step == 0:
      elapsed_time = time.time() - start_time
      string = ", ".join((
          "Iter: {:07d}, Elapsed: {:.3e}, (Elapsed) Iter/s: {:.3e}, (Train Step) Iter/s: {:.3e}".format(
              global_step, elapsed_time, global_step / elapsed_time, global_step / train_time),
          "Gen: {gen_loss:.4f}, Dis: {dis_loss:.4f}, Enc: {enc_loss:.4f}, Trans_Enc: {trans_enc_loss:.4f}".format(
          **vals)
      )) + "."
      ut.log(string)

    # Log visualizations and evaluations
    if (global_step + 1) % iter_save == 0 or global_step == 0:
      if model_type == "gen":
        viz.ablation_visualization(x1, x2, gen_eval, z_dim, vizdir, global_step + 1)
      elif model_type == "van":
        viz.ablation_visualization(x, x, gen_eval, z_dim, vizdir, global_step + 1)

      if FLAGS.debug:
        ut.log("Encoder Metrics")
        evaluate.evaluate_enc(enc_np, dset, s_dim,
                              FLAGS.gin_file,
                              FLAGS.gin_bindings,
                              pida_sample_size=1000,
                              dlib_metrics=FLAGS.debug_dlib_metrics)
        ut.log("Transformed Encoder Metrics")
        evaluate.evaluate_enc(trans_enc_np, dset, s_dim,
                              FLAGS.gin_file,
                              FLAGS.gin_bindings,
                              pida_sample_size=1000,
                              dlib_metrics=FLAGS.debug_dlib_metrics)


      else:
        dlib_metrics = (global_step + 1) % iter_metric == 0
        ut.log("Encoder Metrics")
        evaluate.evaluate_enc(enc_np, dset, s_dim,
                              FLAGS.gin_file,
                              FLAGS.gin_bindings,
                              pida_sample_size=10000,
                              dlib_metrics=dlib_metrics)
        ut.log("Transformed Encoder Metrics")
        evaluate.evaluate_enc(trans_enc_np, dset, s_dim,
                              FLAGS.gin_file,
                              FLAGS.gin_bindings,
                              pida_sample_size=10000,
                              dlib_metrics=dlib_metrics)



    # Save model
    if (global_step + 1) % iter_save == 0 or (global_step == 0 and new_run):
      # Save the model only after ensuring all measurements are taken.
      # This ensures that restarted runs always recompute the evals.
      ut.log("Saved to", ckpt_root.save(ckpt_prefix))