def make_masks(string, s_dim, mask_type):
    """Build factor masks from a ``"<strategy>=<factors>"`` specification.

    Supported strategies are ``"s"`` (share), ``"c"`` (change), ``"cs"``
    (one factor plus its complement set), ``"r"`` (rank) and ``"l"`` (label).

    For every strategy except ``"r"`` and ``"cs"``, ``<factors>`` is a
    comma-separated list of groups where each *character* of a group is one
    factor index (so ``"01,2"`` means the groups ``[0, 1]`` and ``[2]``).
    For ``"cs"``, ``<factors>`` is a single integer index.

    Args:
        string: The specification, e.g. ``"s=01,2"``, ``"cs=3"`` or ``"r=0,1"``.
        s_dim: Number of factor dimensions; the width of every mask.
        mask_type: Only used as a sanity check that the caller expects the
            kind of mask this strategy produces (``"rank"``, ``"match"`` or
            ``"label"``).

    Returns:
        For ``"r"``: a list of integer factor indices (one per comma).
        Otherwise: a float32 array of shape ``(n_groups, s_dim)`` with ones
        at the indices of each group. For ``"s"`` the array is inverted
        (``1 - masks``).

    Raises:
        AssertionError: On an unknown strategy, a strategy/mask_type mismatch,
            or more than one mask for the label strategy.
        ValueError: If the ``"cs"`` index is negative or a factor is not an
            integer.
        IndexError: If a factor index is outside ``range(s_dim)``.
    """
    strategy, factors = string.split("=")
    assert strategy in {"s", "c", "r", "cs", "l"}, "Only allow label, share, change, rank-types"

    # mask_type is only here to help sanity-check that I didn't accidentally
    # use an invalid (strategy , mask_type) pair
    if strategy == "r":
        assert mask_type == "rank", "mask_type must match data collection strategy"
        # Use factor indices as mask. Assumes single factor per comma
        return list(map(int, factors.split(",")))
    elif strategy in {"s", "c", "cs"}:
        assert mask_type == "match", "mask_type must match data collection strategy"
    elif strategy in {"l"}:
        assert mask_type == "label", "mask_type must match data collection strategy"

    if strategy == "cs":
        # Build [idx] and its complement directly as index lists. Round-
        # tripping them through a digit string (as was done previously) splits
        # any index >= 10 into separate digits and yields wrong masks.
        idx = int(factors)
        if idx < 0:
            raise ValueError("cs factor index must be non-negative, got {}".format(idx))
        complement = list(range(s_dim))
        del complement[idx]  # IndexError when idx >= s_dim
        groups = [[idx], complement]
    else:
        # One character per factor index within each comma-separated group.
        groups = [[int(ch) for ch in group] for group in factors.split(",")]

    masks = np.zeros((len(groups), s_dim), dtype=np.float32)
    for i, group in enumerate(groups):
        masks[i, group] = 1

    if strategy == "s":
        masks = 1 - masks
    elif strategy == "l":
        assert len(masks) == 1, "Only one mask allowed for label-strategy"

    ut.log("make_masks output:")
    ut.log(masks)
    return masks
def __init__(self, x_shape, z_dim, batch_norm=True):
    """Assemble the generator network and initialise its parameters.

    Args:
        x_shape: Output image shape; only its last entry is read here and is
            used as the filter count of the final deconvolution.
        z_dim: Latent size; the network is built for inputs of shape
            ``(1, z_dim)``.
        batch_norm: When true, ``add_bn`` is applied to every ReLU and
            LeakyReLU layer of the network.
    """
    super().__init__()
    out_channels = x_shape[-1]

    layers = [
        dense(128),
        ts.ReLU(),
        dense(4 * 4 * 64),
        ts.ReLU(),
        ts.Reshape((-1, 4, 4, 64)),
        deconv(64, 4, 2, "same"),
        ts.LeakyReLU(),
        deconv(32, 4, 2, "same"),
        ts.LeakyReLU(),
        deconv(32, 4, 2, "same"),
        ts.LeakyReLU(),
        deconv(out_channels, 4, 2, "same"),
        ts.Sigmoid(),
    ]
    self.net = ts.Sequential(*layers)

    # Add batchnorm post-activation (attach to activation out_hook)
    if batch_norm:
        activation_types = (ts.ReLU, ts.LeakyReLU)
        self.net.apply(add_bn, targets=activation_types)

    ut.log("Building generator...")
    latent_shape = (1, z_dim)
    self.build(latent_shape)
    self.apply(ut.reset_parameters)
def __init__(self, x_shape, y_dim, width=1, share_dense=False, uncond_bias=False):
    """Assemble the label discriminator (body, aux and head sub-networks).

    Args:
        x_shape: Image shape as a list; ``[1] + x_shape`` is the image input
            spec handed to ``build``.
        y_dim: Label size; ``(1, y_dim)`` is the label input spec handed to
            ``build``.
        width: Multiplier applied to every layer width.
        share_dense: When true, one extra dense + LeakyReLU pair is appended
            to both ``body`` and ``aux``.
        uncond_bias: Passed as ``bias`` to the final single-unit dense layer.
    """
    super().__init__()
    self.y_dim = y_dim

    narrow = 32 * width
    wide = 64 * width
    hidden = 128 * width

    conv_stack = [
        conv(narrow, 4, 2, "same"), ts.LeakyReLU(),
        conv(narrow, 4, 2, "same"), ts.LeakyReLU(),
        conv(wide, 4, 2, "same"), ts.LeakyReLU(),
        conv(wide, 4, 2, "same"), ts.LeakyReLU(),
        ts.Flatten(),
    ]
    self.body = ts.Sequential(*conv_stack)
    self.aux = ts.Sequential(dense(hidden), ts.LeakyReLU())

    if share_dense:
        self.body.append(dense(hidden), ts.LeakyReLU())
        self.aux.append(dense(hidden), ts.LeakyReLU())

    self.head = ts.Sequential(
        dense(hidden), ts.LeakyReLU(),
        dense(hidden), ts.LeakyReLU(),
        dense(1, bias=uncond_bias),
    )

    # Spectral normalisation on every affine layer of all three parts.
    for part in (self.body, self.aux, self.head):
        part.apply(ts.SpectralNorm.add, targets=ts.Affine)

    ut.log("Building label discriminator...")
    image_spec = [1] + x_shape
    label_spec = (1, y_dim)
    self.build(image_spec, label_spec)
    self.apply(ut.reset_parameters)
def evaluate_enc_on_targets(enc, dset, s_dim, original_file, original_bindings, target_metrics):
    """Evaluate an encoder on a chosen subset of disentanglement_lib metrics.

    Each selected metric is configured by parsing its gin bindings into the
    global gin config. The original gin file and bindings are re-parsed
    before returning, including when an evaluation raises, so the caller's
    configuration is not left pointing at a metric.

    Args:
        enc: Encoding function passed to the evaluation function.
        dset: Dataset passed to the evaluation function.
        s_dim: Unused here; kept for signature parity with ``evaluate_enc``.
        original_file: Path to the original gin file.
        original_bindings: List of original gin bindings.
        target_metrics: Container of metric names to compute, drawn from
            "factor", "mig", "beta", "dci", "modularity", "sap".

    Returns:
        A dict mapping each computed metric name to its score.
    """
    # Key of the headline score inside each metric's result dict.
    score_keys = {
        "factor": "eval_accuracy",
        "mig": "discrete_mig",
        "beta": "eval_accuracy",
        "dci": "disentanglement",
        "modularity": "modularity_score",
        "sap": "SAP_score",
    }

    # Disentanglement Lib Metrics
    evals = {}
    eval_bindings_list = get_eval_bindings_list()
    metrics = ("factor", "mig", "beta", "dci", "modularity", "sap")
    try:
        for metric, eval_bindings in zip(metrics, eval_bindings_list):
            if metric not in target_metrics:
                continue
            gin.parse_config_files_and_bindings([], eval_bindings, finalize_config=False)
            evaluation_fn = get_evaluation()
            tf.logging.info("Reset eval func to {}".format(evaluation_fn.__name__))
            result = evaluation_fn(dset, enc, np.random.RandomState(0))
            ut.log(result)
            evals[metric] = result[score_keys[metric]]
    finally:
        # Clean up: resetting gin configs to original bindings, even if an
        # evaluation failed part-way through.
        gin.parse_config_files_and_bindings([original_file], original_bindings, finalize_config=False)
    return evals
def __init__(self, x_shape, y_dim, width=1, share_dense=False, uncond_bias=False, cond_bias=False, mask_type="match"):
    """Assemble the paired discriminator for "match" or "rank" masks.

    Args:
        x_shape: Image shape as a list; ``[1] + x_shape`` is used as the spec
            for both image inputs passed to ``build``.
        y_dim: Number of masks. For "rank" the final dense layer has
            ``1 + y_dim`` units and the label spec is ``(1, y_dim)``.
        width: Multiplier applied to every layer width.
        share_dense: When true, one extra dense + LeakyReLU pair is appended
            to ``body``.
        uncond_bias: ``bias`` flag of the unconditional output layer.
        cond_bias: ``bias`` flag of the conditional head ("match" only).
        mask_type: Either "match" or "rank".

    Raises:
        ValueError: If ``mask_type`` is neither "match" nor "rank".
    """
    super().__init__()
    # Fail early with a clear message: previously an unsupported mask_type
    # skipped both branches below and crashed later with a NameError.
    if mask_type not in ("match", "rank"):
        raise ValueError(
            "Unsupported mask_type {!r}; expected 'match' or 'rank'".format(mask_type))

    self.y_dim = y_dim
    self.mask_type = mask_type
    self.body = ts.Sequential(
        conv(32 * width, 4, 2, "same"), ts.LeakyReLU(),
        conv(32 * width, 4, 2, "same"), ts.LeakyReLU(),
        conv(64 * width, 4, 2, "same"), ts.LeakyReLU(),
        conv(64 * width, 4, 2, "same"), ts.LeakyReLU(),
        ts.Flatten(),
    )

    if share_dense:
        self.body.append(dense(128 * width), ts.LeakyReLU())

    if mask_type == "match":
        self.neck = ts.Sequential(
            dense(128 * width), ts.LeakyReLU(),
            dense(128 * width), ts.LeakyReLU(),
        )
        self.head_uncond = dense(1, bias=uncond_bias)
        self.head_cond = dense(128 * width, bias=cond_bias)

        # Spectral norm everywhere except the conditional head, which gets
        # add_wn instead.
        for m in (self.body, self.neck, self.head_uncond):
            m.apply(ts.SpectralNorm.add, targets=ts.Affine)
        add_wn(self.head_cond)
        # The label input is declared as an int32 tensor of shape (1,).
        x_shape, y_shape = [1] + x_shape, ((1,), tf.int32)
    else:  # mask_type == "rank"
        self.body.append(
            dense(128 * width), ts.LeakyReLU(),
            dense(128 * width), ts.LeakyReLU(),
            dense(1 + y_dim, bias=uncond_bias),
        )
        self.body.apply(ts.SpectralNorm.add, targets=ts.Affine)
        x_shape, y_shape = [1] + x_shape, (1, y_dim)

    ut.log("Building {} discriminator...".format(mask_type))
    # Two image inputs (the pair) followed by the label input.
    self.build(x_shape, x_shape, y_shape)
    self.apply(ut.reset_parameters)
def get_dlib_data(task):
    """Instantiate the disentanglement_lib dataset registered under ``task``.

    Args:
        task: One of "dsprites", "shapes3d", "norb", "cars3d" or "scream".

    Returns:
        The dataset object for ``task``, or ``None`` when ``task`` is not one
        of the names above.
    """
    ut.log("Loading {}".format(task))

    # (name, factory) pairs; factories are only invoked for the matching task.
    loaders = (
        # 5 factors
        ("dsprites", lambda: dsprites.DSprites(list(range(1, 6)))),
        # 6 factors
        ("shapes3d", lambda: shapes3d.Shapes3D()),
        # 4 factors + 1 nuisance (which we'll handle via n_dim=2)
        ("norb", lambda: norb.SmallNORB()),
        # 3 factors
        ("cars3d", lambda: cars3d.Cars3D()),
        # 5 factors + 2 nuisance (handled as n_dim=2)
        ("scream", lambda: dsprites.ScreamDSprites(list(range(1, 6)))),
    )
    for name, build_dataset in loaders:
        if task == name:
            return build_dataset()
    return None
def __init__(self, x_shape, z_dim, width=1, spectral_norm=True):
    """Assemble the encoder network and initialise its parameters.

    Args:
        x_shape: Image shape as a list; the network is built for inputs of
            shape ``[1] + x_shape``.
        z_dim: Latent size; the final dense layer has ``2 * z_dim`` units
            (presumably distribution parameters -- confirm against the
            forward pass).
        width: Multiplier applied to every hidden layer width.
        spectral_norm: When true, spectral normalisation is added to every
            affine layer.
    """
    super().__init__()

    narrow = 32 * width
    wide = 64 * width
    layers = [
        conv(narrow, 4, 2, "same"), ts.LeakyReLU(),
        conv(narrow, 4, 2, "same"), ts.LeakyReLU(),
        conv(wide, 4, 2, "same"), ts.LeakyReLU(),
        conv(wide, 4, 2, "same"), ts.LeakyReLU(),
        ts.Flatten(),
        dense(128 * width), ts.LeakyReLU(),
        dense(2 * z_dim),
    ]
    self.net = ts.Sequential(*layers)

    if spectral_norm:
        self.net.apply(ts.SpectralNorm.add, targets=ts.Affine)

    ut.log("Building encoder...")
    image_spec = [1] + x_shape
    self.build(image_spec)
    self.apply(ut.reset_parameters)
def main(_):
    """Program entry point: load the gin configuration, then run ``train``.

    In debug mode the binding ``log.debug = True`` is appended to the gin
    bindings before they are parsed. ``train`` is called without arguments
    (presumably they are supplied through gin -- confirm).

    Args:
        _: Unused positional argument.
    """
    if FLAGS.debug:
        FLAGS.gin_bindings += ["log.debug = True"]

    config_files = [FLAGS.gin_file]
    gin.parse_config_files_and_bindings(config_files, FLAGS.gin_bindings, finalize_config=False)

    rule = "*" * 80
    ut.log("\n" + rule + "\nBegin program\n" + rule)
    ut.log("In main")
    train()
    ut.log("\n" + rule + "\nEnd program\n" + rule)
def evaluate_enc(enc, dset, s_dim, original_file, original_bindings, pida_sample_size=10000, dlib_metrics=True):
    """Evaluates an encoder on multiple disentanglement metrics.

    Given an encoder and an oracle generator, this function computes encoder-
    based metrics on consistency, restrictiveness, ranking accuracy, and six
    additional disentanglement metrics used in disentanglement_lib.

    The disentanglement_lib metrics are set by modifying the global gin-config.
    We require the user provide the original gin-file and gin-bindings so that
    they can be re-established at the end of this call. The original config is
    re-established even if one of the disentanglement_lib evaluations raises.

    Args:
        enc: an encoding function that takes in and returns numpy arrays.
        dset: a disentanglement_lib dataset.
        s_dim: number of non-nuisance dimensions of the latent space.
        original_file: path to original gin file
        original_bindings: list of original gin bindings
        pida_sample_size: number of samples for monte carlo estimate pida metrics.
        dlib_metrics: flag for using disentanglement_lib metrics

    Returns:
        A dictionary of scores.
    """
    # enc takes in and outputs numpy arrays
    # Consistency/Restrictiveness Metrics
    itypes = [
        "{}={}".format(t, i)
        for t, i in itertools.product(("s", "c", "r"), range(s_dim))
    ]
    evals = {}
    for it in itypes:
        scores = negative_pida(it, s_dim, enc, dset,
                               sample_size=pida_sample_size,
                               random_state=np.random.RandomState(0))
        evals.update(scores)
        for k in scores:
            ut.log(k, ":", scores[k])

    # Averages over the per-factor share ("s=") and change ("c=") scores.
    evals["s_mean"] = np.mean([evals[k] for k in evals if "s=" == k[:2]])
    evals["c_mean"] = np.mean([evals[k] for k in evals if "c=" == k[:2]])

    if not dlib_metrics:
        # The gin config has not been touched yet, so nothing to restore.
        return evals

    # Key of the headline score inside each metric's result dict.
    score_keys = {
        "factor": "eval_accuracy",
        "mig": "discrete_mig",
        "beta": "eval_accuracy",
        "dci": "disentanglement",
        "modularity": "modularity_score",
        "sap": "SAP_score",
    }

    # Disentanglement Lib Metrics
    eval_bindings_list = get_eval_bindings_list()
    metrics = ("factor", "mig", "beta", "dci", "modularity", "sap")
    try:
        for metric, eval_bindings in zip(metrics, eval_bindings_list):
            gin.parse_config_files_and_bindings([], eval_bindings, finalize_config=False)
            evaluation_fn = get_evaluation()
            tf.logging.info("Reset eval func to {}".format(evaluation_fn.__name__))
            result = evaluation_fn(dset, enc, np.random.RandomState(0))
            ut.log(result)
            evals[metric] = result[score_keys[metric]]
    finally:
        # Clean up: resetting gin configs to original bindings, even if an
        # evaluation failed part-way through.
        gin.parse_config_files_and_bindings([original_file], original_bindings, finalize_config=False)
    return evals
def train(dset_name, s_dim, n_dim, factors, batch_size, dec_lr, enc_lr_mul, iterations, model_type="gen"):
    """Train a generator/discriminator/encoder model, or an encoder alone.

    Builds the networks and optimizers for the requested ``model_type``,
    optionally resumes from the latest checkpoint (non-debug mode only), and
    runs the training loop with periodic logging, visualization, encoder
    evaluation and checkpointing. Reads the module-level ``FLAGS`` (``debug``,
    ``basedir``, ``gin_file``, ``gin_bindings``, ``debug_dlib_metrics``).

    Args:
        dset_name: Name passed to ``datasets.get_dlib_data``. If that returns
            ``None`` the networks are still built, then the function returns
            before training.
        s_dim: Number of latent dimensions the encoder predicts.
        n_dim: Number of extra (nuisance) latent dimensions;
            ``z_dim = s_dim + n_dim``.
        factors: Mask specification passed to ``datasets.make_masks``; the
            part before "=" must be compatible with ``model_type``.
        batch_size: Batch size for data batches and sampled latents.
        dec_lr: Learning rate of the generator optimizer.
        enc_lr_mul: Multiplier on ``dec_lr`` giving the learning rate of the
            discriminator and encoder optimizers.
        iterations: Exclusive upper bound of the training step range.
        model_type: "gen" (paired discriminator), "enc" (encoder only, rank
            labels) or "van" (label discriminator).

    Returns:
        None.
    """
    ut.log("In train")
    masks = datasets.make_masks(factors, s_dim)
    z_dim = s_dim + n_dim
    enc_lr = enc_lr_mul * dec_lr

    # Load data
    dset = datasets.get_dlib_data(dset_name)
    if dset is None:
        # Fallback shape so the networks can still be built without data.
        x_shape = [64, 64, 1]
    else:
        x_shape = dset.observation_shape
    # Discriminator targets: ones for the real half of a batch, zeros for the
    # fake half (matches the real-then-fake concat order in the train steps).
    targets_real = tf.ones((batch_size, 1))
    targets_fake = tf.zeros((batch_size, 1))
    targets = tf.concat((targets_real, targets_fake), axis=0)

    # Networks
    if model_type == "gen":
        assert factors.split("=")[0] in {"c", "s", "cs", "r"}
        y_dim = len(masks)
        dis = networks.Discriminator(x_shape, y_dim)
        gen = networks.Generator(x_shape, z_dim)
        enc = networks.Encoder(x_shape, s_dim)  # Encoder ignores nuisance param
        ut.log(dis.read(dis.WITH_VARS))
        ut.log(gen.read(gen.WITH_VARS))
        ut.log(enc.read(enc.WITH_VARS))
    elif model_type == "enc":
        assert factors.split("=")[0] in {"r"}
        enc = networks.Encoder(x_shape, s_dim)  # Encoder ignores nuisance param
        ut.log(enc.read(enc.WITH_VARS))
    elif model_type == "van":
        assert factors.split("=")[0] in {"l"}
        dis = networks.LabelDiscriminator(x_shape, s_dim)  # Uses s_dim
        gen = networks.Generator(x_shape, z_dim)
        enc = networks.Encoder(x_shape, s_dim)  # Encoder ignores nuisance param
        ut.log(dis.read(dis.WITH_VARS))
        ut.log(gen.read(gen.WITH_VARS))
        ut.log(enc.read(enc.WITH_VARS))

    # Create optimizers
    if model_type in {"gen", "van"}:
        gen_opt = tfk.optimizers.Adam(learning_rate=dec_lr, beta_1=0.5, beta_2=0.999, epsilon=1e-8)
        dis_opt = tfk.optimizers.Adam(learning_rate=enc_lr, beta_1=0.5, beta_2=0.999, epsilon=1e-8)
        enc_opt = tfk.optimizers.Adam(learning_rate=enc_lr, beta_1=0.5, beta_2=0.999, epsilon=1e-8)
    elif model_type == "enc":
        enc_opt = tfk.optimizers.Adam(learning_rate=enc_lr, beta_1=0.5, beta_2=0.999, epsilon=1e-8)

    # NOTE: the step functions below close over gen/dis/enc and the
    # optimizers. All of them are defined regardless of model_type; only the
    # one matching model_type is called in the loop.
    @tf.function
    def train_gen_step(x1_real, x2_real, y_real):
        """One discriminator+encoder update followed by one generator update."""
        gen.train()
        dis.train()
        enc.train()
        # Alternate discriminator step and generator step
        with tf.GradientTape(persistent=True) as tape:
            # Generate
            z1, z2, y_fake = datasets.paired_randn(batch_size, z_dim, masks)
            # stop_gradient: this pass updates only dis and enc, not gen.
            x1_fake = tf.stop_gradient(gen(z1))
            x2_fake = tf.stop_gradient(gen(z2))

            # Discriminate
            x1 = tf.concat((x1_real, x1_fake), 0)
            x2 = tf.concat((x2_real, x2_fake), 0)
            y = tf.concat((y_real, y_fake), 0)
            logits = dis(x1, x2, y)

            # Encode
            p_z = enc(x1_fake)

            dis_loss = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=targets))
            # Encoder ignores nuisance parameters (if they exist)
            enc_loss = -tf.reduce_mean(p_z.log_prob(z1[:, :s_dim]))

        # The tape is persistent because two gradients are taken from it.
        dis_grads = tape.gradient(dis_loss, dis.trainable_variables)
        enc_grads = tape.gradient(enc_loss, enc.trainable_variables)

        dis_opt.apply_gradients(zip(dis_grads, dis.trainable_variables))
        enc_opt.apply_gradients(zip(enc_grads, enc.trainable_variables))

        with tf.GradientTape(persistent=False) as tape:
            # Generate
            z1, z2, y_fake = datasets.paired_randn(batch_size, z_dim, masks)
            x1_fake = gen(z1)
            x2_fake = gen(z2)

            # Discriminate
            logits_fake = dis(x1_fake, x2_fake, y_fake)

            # Generator is trained to have its fakes classified as real.
            gen_loss = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(logits=logits_fake, labels=targets_real))

        gen_grads = tape.gradient(gen_loss, gen.trainable_variables)
        gen_opt.apply_gradients(zip(gen_grads, gen.trainable_variables))

        return dict(gen_loss=gen_loss, dis_loss=dis_loss, enc_loss=enc_loss)

    @tf.function
    def train_van_step(x_real, y_real):
        """One generator update followed by one discriminator+encoder update."""
        gen.train()
        dis.train()
        enc.train()

        # Zero-pad the labels to z_dim so they can be added to the latent.
        if n_dim > 0:
            padding = tf.zeros((y_real.shape[0], n_dim))
            y_real_pad = tf.concat((y_real, padding), axis=-1)
        else:
            y_real_pad = y_real

        # Alternate discriminator step and generator step
        with tf.GradientTape(persistent=False) as tape:
            # Generate
            # NOTE(review): train_gen_step unpacks three values from
            # paired_randn, whereas here the result is used as one tensor --
            # confirm what paired_randn returns for label masks.
            z_fake = datasets.paired_randn(batch_size, z_dim, masks)
            z_fake = z_fake + y_real_pad
            x_fake = gen(z_fake)

            # Discriminate
            logits_fake = dis(x_fake, y_real)

            gen_loss = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(logits=logits_fake, labels=targets_real))

        gen_grads = tape.gradient(gen_loss, gen.trainable_variables)
        gen_opt.apply_gradients(zip(gen_grads, gen.trainable_variables))

        with tf.GradientTape(persistent=True) as tape:
            # Generate
            z_fake = datasets.paired_randn(batch_size, z_dim, masks)
            z_fake = z_fake + y_real_pad
            x_fake = tf.stop_gradient(gen(z_fake))

            # Discriminate
            x = tf.concat((x_real, x_fake), 0)
            # The same labels are paired with both the real and fake halves.
            y = tf.concat((y_real, y_real), 0)
            logits = dis(x, y)

            # Encode
            p_z = enc(x_fake)

            dis_loss = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=targets))
            # Encoder ignores nuisance parameters (if they exist)
            enc_loss = -tf.reduce_mean(p_z.log_prob(z_fake[:, :s_dim]))

        dis_grads = tape.gradient(dis_loss, dis.trainable_variables)
        enc_grads = tape.gradient(enc_loss, enc.trainable_variables)

        dis_opt.apply_gradients(zip(dis_grads, dis.trainable_variables))
        enc_opt.apply_gradients(zip(enc_grads, enc.trainable_variables))

        return dict(gen_loss=gen_loss, dis_loss=dis_loss, enc_loss=enc_loss)

    @tf.function
    def train_enc_step(x1_real, x2_real, y_real):
        """One encoder-only update using the difference of encoded means."""
        with tf.GradientTape() as tape:
            z1 = enc(x1_real).mean()
            z2 = enc(x2_real).mean()
            # For the rank strategy ``masks`` is a list of factor indices, so
            # this selects those latent dimensions of the difference.
            logits = tf.gather(z1 - z2, masks, axis=-1)
            loss = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=y_real))

        enc_grads = tape.gradient(loss, enc.trainable_variables)
        enc_opt.apply_gradients(zip(enc_grads, enc.trainable_variables))
        # gen/dis losses are reported as 0 so the log format string still works.
        return dict(gen_loss=0, dis_loss=0, enc_loss=loss)

    @tf.function
    def gen_eval(z):
        gen.eval()
        return gen(z)

    @tf.function
    def enc_eval(x):
        enc.eval()
        return enc(x).mean()

    # Numpy-returning wrapper handed to evaluate.evaluate_enc.
    enc_np = lambda x: enc_eval(x).numpy()

    # Initial preparation
    if FLAGS.debug:
        iter_log = 100
        iter_save = 2000
        train_range = range(iterations)
        basedir = FLAGS.basedir
        vizdir = FLAGS.basedir
        ckptdir = FLAGS.basedir
        new_run = True
        # iter_metric is intentionally not set here; it is only read in the
        # non-debug branch of the evaluation code below.
    else:
        iter_log = 5000
        iter_save = 50000
        iter_metric = iter_save * 5  # Make sure this is a factor of 500k
        basedir = os.path.join(FLAGS.basedir, "exp")
        ckptdir = os.path.join(basedir, "ckptdir")
        vizdir = os.path.join(basedir, "vizdir")
        gfile.MakeDirs(basedir)
        gfile.MakeDirs(ckptdir)
        gfile.MakeDirs(vizdir)  # train_range will be specified below

    ckpt_prefix = os.path.join(ckptdir, "model")
    if model_type in {"gen", "van"}:
        ckpt_root = tf.train.Checkpoint(dis=dis, dis_opt=dis_opt,
                                        gen=gen, gen_opt=gen_opt,
                                        enc=enc, enc_opt=enc_opt)
    elif model_type == "enc":
        ckpt_root = tf.train.Checkpoint(enc=enc, enc_opt=enc_opt)

    # Check if we're resuming training if not in debugging mode
    if not FLAGS.debug:
        latest_ckpt = tf.train.latest_checkpoint(ckptdir)
        if latest_ckpt is None:
            new_run = True
            ut.log("Starting a completely new model")
            train_range = range(iterations)
        else:
            new_run = False
            ut.log("Restarting from {}".format(latest_ckpt))
            ckpt_root.restore(latest_ckpt)
            # A new run saves once at step 0 and then every iter_save steps,
            # hence the "- 1" when converting the save count back to a step.
            resuming_iteration = iter_save * (int(ckpt_root.save_counter) - 1)
            train_range = range(resuming_iteration, iterations)

    # Training
    if dset is None:
        ut.log("Dataset {} is not available".format(dset_name))
        ut.log("Ending program having checked that the networks can be built.")
        return

    batches = datasets.paired_data_generator(
        dset, masks).repeat().batch(batch_size).prefetch(1000)
    batches = iter(batches)
    start_time = time.time()
    train_time = 0

    if FLAGS.debug:
        train_range = tqdm(train_range)

    for global_step in train_range:
        stopwatch = time.time()
        if model_type == "gen":
            x1, x2, y = next(batches)
            vals = train_gen_step(x1, x2, y)
        elif model_type == "enc":
            x1, x2, y = next(batches)
            vals = train_enc_step(x1, x2, y)
        elif model_type == "van":
            x, y = next(batches)
            vals = train_van_step(x, y)
        # train_time counts only data fetch + train step, not logging/evals.
        train_time += time.time() - stopwatch

        # Generic bookkeeping
        if (global_step + 1) % iter_log == 0 or global_step == 0:
            elapsed_time = time.time() - start_time
            string = ", ".join((
                "Iter: {:07d}, Elapsed: {:.3e}, (Elapsed) Iter/s: {:.3e}, (Train Step) Iter/s: {:.3e}"
                .format(global_step, elapsed_time, global_step / elapsed_time, global_step / train_time),
                "Gen: {gen_loss:.4f}, Dis: {dis_loss:.4f}, Enc: {enc_loss:.4f}"
                .format(**vals))) + "."
            ut.log(string)

        # Log visualizations and evaluations
        if (global_step + 1) % iter_save == 0 or global_step == 0:
            # No visualization for model_type == "enc" (there is no generator).
            if model_type == "gen":
                viz.ablation_visualization(x1, x2, gen_eval, z_dim, vizdir, global_step + 1)
            elif model_type == "van":
                viz.ablation_visualization(x, x, gen_eval, z_dim, vizdir, global_step + 1)

            if FLAGS.debug:
                evaluate.evaluate_enc(enc_np, dset, s_dim,
                                      FLAGS.gin_file, FLAGS.gin_bindings,
                                      pida_sample_size=1000,
                                      dlib_metrics=FLAGS.debug_dlib_metrics)
            else:
                # disentanglement_lib metrics only every iter_metric steps.
                dlib_metrics = (global_step + 1) % iter_metric == 0
                evaluate.evaluate_enc(enc_np, dset, s_dim,
                                      FLAGS.gin_file, FLAGS.gin_bindings,
                                      pida_sample_size=10000,
                                      dlib_metrics=dlib_metrics)

        # Save model
        if (global_step + 1) % iter_save == 0 or (global_step == 0 and new_run):
            # Save model only after ensuring all measurements are taken.
            # This ensures that restarts always computes the evals
            ut.log("Saved to", ckpt_root.save(ckpt_prefix))
def train(dset_name, s_dim, n_dim, factors, s_I_dim, batch_size, clas_lr, iterations):
    """Train a classifier that predicts the first ``s_I_dim`` label dimensions.

    Builds ``networks.Classifier``, optionally resumes from the latest
    checkpoint under ``<FLAGS.basedir>/clas/ckptdir`` (non-debug mode only),
    and runs the training loop with periodic logging, a held-out style loss
    evaluation on freshly sampled data, and checkpointing.

    Args:
        dset_name: Name passed to ``datasets.get_dlib_data``. If that returns
            ``None`` the network is still built, then the function returns
            before training.
        s_dim: Passed to ``datasets.unmasked_label_data_generator``.
        n_dim: Not used in this function.
        factors: Not read as an argument; the name is rebound inside the
            evaluation block to sampled ground-truth factors.
        s_I_dim: Number of leading label dimensions the classifier models.
        batch_size: Batch size for training batches.
        clas_lr: Learning rate of the classifier's Adam optimizer.
        iterations: Exclusive upper bound of the training step range.

    Returns:
        None.
    """
    ut.log("In train classifier")

    # Load data
    dset = datasets.get_dlib_data(dset_name)
    if dset is None:
        # Fallback shape so the network can still be built without data.
        x_shape = [64, 64, 1]
    else:
        x_shape = dset.observation_shape
    # NOTE(review): targets_real/targets_fake/targets are never read in this
    # function; they look like leftovers from the GAN training script.
    targets_real = tf.ones((batch_size, 1))
    targets_fake = tf.zeros((batch_size, 1))
    targets = tf.concat((targets_real, targets_fake), axis=0)

    # Networks
    clas = networks.Classifier(x_shape, s_I_dim)
    ut.log(clas.read(clas.WITH_VARS))

    # Create optimizers
    clas_opt = tfk.optimizers.Adam(learning_rate=clas_lr, beta_1=0.5, beta_2=0.999, epsilon=1e-8)

    @tf.function
    def train_clas_step(x_real, y_real):
        """One classifier update minimising the negative log-probability."""
        clas.train()

        # NOTE(review): only one gradient is taken from this tape, so
        # persistent=True does not appear to be needed.
        with tf.GradientTape(persistent=True) as tape:
            # Generate
            p_s = clas(x_real)
            clas_loss = -tf.reduce_mean(p_s.log_prob(y_real[:, :s_I_dim]))

        clas_grads = tape.gradient(clas_loss, clas.trainable_variables)
        clas_opt.apply_gradients(zip(clas_grads, clas.trainable_variables))

        return dict(clas_loss=clas_loss)

    # Initial preparation
    if FLAGS.debug:
        iter_log = 100
        iter_save = 2000
        train_range = range(iterations)
        basedir = FLAGS.basedir
        vizdir = FLAGS.basedir
        ckptdir = FLAGS.basedir
        new_run = True
    else:
        iter_log = 5000
        iter_save = 5000
        # NOTE(review): iter_metric is assigned but never read in this function.
        iter_metric = iter_save * 5  # Make sure this is a factor of 500k
        basedir = os.path.join(FLAGS.basedir, "clas")
        ckptdir = os.path.join(basedir, "ckptdir")
        # vizdir is created but nothing is written to it by this function.
        vizdir = os.path.join(basedir, "vizdir")
        gfile.MakeDirs(basedir)
        gfile.MakeDirs(ckptdir)
        gfile.MakeDirs(vizdir)  # train_range will be specified below

    ckpt_prefix = os.path.join(ckptdir, "model")
    ckpt_root = tf.train.Checkpoint(clas=clas, clas_opt=clas_opt)

    # Check if we're resuming training if not in debugging mode
    if not FLAGS.debug:
        latest_ckpt = tf.train.latest_checkpoint(ckptdir)
        if latest_ckpt is None:
            new_run = True
            ut.log("Starting a completely new model")
            train_range = range(iterations)
        else:
            new_run = False
            ut.log("Restarting from {}".format(latest_ckpt))
            ckpt_root.restore(latest_ckpt)
            # A new run saves once at step 0 and then every iter_save steps,
            # hence the "- 1" when converting the save count back to a step.
            resuming_iteration = iter_save * (int(ckpt_root.save_counter) - 1)
            train_range = range(resuming_iteration, iterations)

    # Training
    if dset is None:
        ut.log("Dataset {} is not available".format(dset_name))
        ut.log("Ending program having checked that the networks can be built.")
        return

    batches = datasets.unmasked_label_data_generator(
        dset, s_dim).repeat().batch(batch_size).prefetch(1000)
    batches = iter(batches)
    start_time = time.time()
    train_time = 0

    if FLAGS.debug:
        train_range = tqdm(train_range)

    for global_step in train_range:
        stopwatch = time.time()
        x, y = next(batches)
        vals = train_clas_step(x, y)
        # train_time counts only data fetch + train step, not logging/evals.
        train_time += time.time() - stopwatch

        # Generic bookkeeping
        if (global_step + 1) % iter_log == 0 or global_step == 0:
            elapsed_time = time.time() - start_time
            string = ", ".join((
                "Iter: {:07d}, Elapsed: {:.3e}, (Elapsed) Iter/s: {:.3e}, (Train Step) Iter/s: {:.3e}"
                .format(global_step, elapsed_time, global_step / elapsed_time, global_step / train_time),
                "Clas: {clas_loss:.4f}".format(**vals))) + "."
            ut.log(string)

        # Log visualizations and evaluations
        if (global_step + 1) % iter_save == 0 or global_step == 0:
            ut.log("Beginning evaluation.")
            sample_size = 10000
            # Fixed seed: the same evaluation sample is drawn every time.
            random_state = np.random.RandomState(1)
            # NOTE(review): this rebinds the ``factors`` parameter.
            factors = dset.sample_factors(sample_size, random_state)
            obs = dset.sample_observations_from_factors(factors, random_state)
            eval_y = tf.convert_to_tensor(factors, dtype=tf.float32)
            eval_x = tf.convert_to_tensor(obs, dtype=tf.float32)
            # The whole sample is passed through the classifier in one call.
            p_s_eval = clas(eval_x)
            eval_loss = -tf.reduce_mean(p_s_eval.log_prob(eval_y[:, :s_I_dim]))
            ut.log("Eval loss: {}".format(eval_loss))

        # Save model
        if (global_step + 1) % iter_save == 0 or (global_step == 0 and new_run):
            # Save model only after ensuring all measurements are taken.
            # This ensures that restarts always computes the evals
            ut.log("Saved to", ckpt_root.save(ckpt_prefix))
def train(dset_name, s_dim, n_dim, factors, z_transform, batch_size, dec_lr, enc_lr_mul, iterations, model_type="gen"):
    """Train (or evaluate) a generator with a plain and a transformed encoder.

    Builds the networks and optimizers for ``model_type``, optionally resumes
    from the latest checkpoint (non-debug mode only), then either runs a
    one-off evaluation (``FLAGS.evaluate``) or the training loop with periodic
    logging, visualization, encoder evaluation and checkpointing. Reads the
    module-level ``FLAGS`` (``debug``, ``basedir``, ``gin_file``,
    ``gin_bindings``, ``debug_dlib_metrics``, ``val_samples``, ``evaluate``,
    ``mi_averages``, ``visualize``, ``entangle``).

    NOTE(review): as written, only ``model_type == "van"`` can run end to end.
    In the "gen" branch ``trans_enc`` and ``clas`` are never created, yet
    ``trans_enc`` is referenced when the checkpoint object is built, and the
    log format string requires a ``trans_enc_loss`` key that
    ``train_gen_step`` does not return.

    Args:
        dset_name: Name passed to ``datasets.get_dlib_data``.
        s_dim: Number of latent dimensions the encoders predict.
        n_dim: Number of extra (nuisance) latent dimensions;
            ``z_dim = s_dim + n_dim``.
        factors: Mask specification passed to ``datasets.make_masks``; the
            part before "=" must be compatible with ``model_type``.
        z_transform: Passed to ``datasets.get_z_transform`` to obtain the
            latent transform that the transformed encoder is trained against.
        batch_size: Batch size for data batches and sampled latents.
        dec_lr: Learning rate of the generator optimizer.
        enc_lr_mul: Multiplier on ``dec_lr`` giving the learning rate of the
            discriminator and encoder optimizers.
        iterations: Exclusive upper bound of the training step range.
        model_type: "gen" or "van" (see the note above).

    Returns:
        None.
    """
    ut.log("In train")
    masks = datasets.make_masks(factors, s_dim)
    z_dim = s_dim + n_dim
    enc_lr = enc_lr_mul * dec_lr
    z_trans = datasets.get_z_transform(z_transform)

    # Load data
    dset = datasets.get_dlib_data(dset_name)
    # if FLAGS.evaluate:
    #     dset = None
    # else:
    #     dset = datasets.get_dlib_data(dset_name)
    if dset is None:
        # Fallback shape so the networks can still be built without data.
        x_shape = [64, 64, 1]
    else:
        x_shape = dset.observation_shape
    # Discriminator targets: ones for the real half of a batch, zeros for the
    # fake half (matches the real-then-fake concat order in the train steps).
    targets_real = tf.ones((batch_size, 1))
    targets_fake = tf.zeros((batch_size, 1))
    targets = tf.concat((targets_real, targets_fake), axis=0)

    # Networks
    if model_type == "gen":
        assert factors.split("=")[0] in {"c", "s", "cs", "r"}
        y_dim = len(masks)
        dis = networks.Discriminator(x_shape, y_dim)
        gen = networks.Generator(x_shape, z_dim)
        enc = networks.Encoder(x_shape, s_dim)  # Encoder ignores nuisance param
        ut.log(dis.read(dis.WITH_VARS))
        ut.log(gen.read(gen.WITH_VARS))
        ut.log(enc.read(enc.WITH_VARS))
    elif model_type == "van":
        assert factors.split("=")[0] in {"l"}
        dis = networks.LabelDiscriminator(x_shape, s_dim)  # Uses s_dim
        gen = networks.Generator(x_shape, z_dim)
        enc = networks.CovEncoder(x_shape, s_dim)  # Encoder ignores nuisance param
        # Second encoder, trained to predict the transformed latent.
        trans_enc = networks.CovEncoder(x_shape, s_dim)
        # Load the classifier from the "clas" directory under FLAGS.basedir.
        clas_path = os.path.join(FLAGS.basedir, "clas")
        clas = networks.Classifier(x_shape, s_dim)
        ckpt_root = tf.train.Checkpoint(clas=clas)
        # NOTE(review): latest_checkpoint is looked up in ``clas_path`` itself
        # and its result is not checked for None before restore -- confirm the
        # checkpoint location and what restore(None) does.
        latest_ckpt = tf.train.latest_checkpoint(clas_path)
        ckpt_root.restore(latest_ckpt)
        ut.log(dis.read(dis.WITH_VARS))
        ut.log(gen.read(gen.WITH_VARS))
        ut.log(enc.read(enc.WITH_VARS))
        ut.log(clas.read(clas.WITH_VARS))

    # Create optimizers
    if model_type in {"gen", "van"}:
        gen_opt = tfk.optimizers.Adam(learning_rate=dec_lr, beta_1=0.5, beta_2=0.999, epsilon=1e-8)
        dis_opt = tfk.optimizers.Adam(learning_rate=enc_lr, beta_1=0.5, beta_2=0.999, epsilon=1e-8)
        enc_opt = tfk.optimizers.Adam(learning_rate=enc_lr, beta_1=0.5, beta_2=0.999, epsilon=1e-8)
        trans_enc_opt = tfk.optimizers.Adam(learning_rate=enc_lr, beta_1=0.5, beta_2=0.999, epsilon=1e-8)

    @tf.function
    def train_gen_step(x1_real, x2_real, y_real):
        """One discriminator+encoder update followed by one generator update."""
        gen.train()
        dis.train()
        enc.train()
        # Alternate discriminator step and generator step
        with tf.GradientTape(persistent=True) as tape:
            # Generate
            z1, z2, y_fake = datasets.paired_randn(batch_size, z_dim, masks)
            # stop_gradient: this pass updates only dis and enc, not gen.
            x1_fake = tf.stop_gradient(gen(z1))
            x2_fake = tf.stop_gradient(gen(z2))

            # Discriminate
            x1 = tf.concat((x1_real, x1_fake), 0)
            x2 = tf.concat((x2_real, x2_fake), 0)
            y = tf.concat((y_real, y_fake), 0)
            logits = dis(x1, x2, y)

            # Encode
            p_z = enc(x1_fake)

            dis_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
                logits=logits, labels=targets))
            # Encoder ignores nuisance parameters (if they exist)
            enc_loss = -tf.reduce_mean(p_z.log_prob(z1[:, :s_dim]))

        dis_grads = tape.gradient(dis_loss, dis.trainable_variables)
        enc_grads = tape.gradient(enc_loss, enc.trainable_variables)

        dis_opt.apply_gradients(zip(dis_grads, dis.trainable_variables))
        enc_opt.apply_gradients(zip(enc_grads, enc.trainable_variables))

        with tf.GradientTape(persistent=False) as tape:
            # Generate
            z1, z2, y_fake = datasets.paired_randn(batch_size, z_dim, masks)
            x1_fake = gen(z1)
            x2_fake = gen(z2)

            # Discriminate
            logits_fake = dis(x1_fake, x2_fake, y_fake)

            # Generator is trained to have its fakes classified as real.
            gen_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
                logits=logits_fake, labels=targets_real))

        gen_grads = tape.gradient(gen_loss, gen.trainable_variables)
        gen_opt.apply_gradients(zip(gen_grads, gen.trainable_variables))

        # NOTE(review): no trans_enc_loss key here, but the training-loop log
        # line formats {trans_enc_loss}.
        return dict(gen_loss=gen_loss, dis_loss=dis_loss, enc_loss=enc_loss)

    @tf.function
    def train_van_step(x_real, y_real, entangle=False):
        """One generator update, then one dis/enc/trans_enc update.

        With ``entangle`` true the latent is sampled with an all-zero mask and
        the labels are not added to it; otherwise the (zero-padded) labels are
        added to the sampled latent.
        """
        gen.train()
        dis.train()
        enc.train()
        trans_enc.train()

        # Zero-pad the labels to z_dim so they can be added to the latent.
        if n_dim > 0:
            padding = tf.zeros((y_real.shape[0], n_dim))
            y_real_pad = tf.concat((y_real, padding), axis=-1)
        else:
            y_real_pad = y_real

        if entangle:
            # y_real_pad is not used in this branch.
            # Alternate discriminator step and generator step
            with tf.GradientTape(persistent=False) as tape:
                # Generate
                dummy_mask = tf.zeros_like(masks)
                z_fake = datasets.paired_randn(batch_size, z_dim, dummy_mask)
                x_fake = gen(z_fake)

                # Discriminate
                logits_fake = dis(x_fake, y_real)

                gen_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
                    logits=logits_fake, labels=targets_real))

            gen_grads = tape.gradient(gen_loss, gen.trainable_variables)
            gen_opt.apply_gradients(zip(gen_grads, gen.trainable_variables))

            with tf.GradientTape(persistent=True) as tape:
                # Generate
                dummy_mask = tf.zeros_like(masks)
                z_fake = datasets.paired_randn(batch_size, z_dim, dummy_mask)
                x_fake = tf.stop_gradient(gen(z_fake))
                trans_z_fake = z_trans(z_fake)

                # Discriminate
                x = tf.concat((x_real, x_fake), 0)
                y = tf.concat((y_real, y_real), 0)
                logits = dis(x, y)

                # Encode
                p_z = enc(x_fake)
                p_z_trans = trans_enc(x_fake)

                dis_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
                    logits=logits, labels=targets))
                # Encoder ignores nuisance parameters (if they exist)
                enc_loss = -tf.reduce_mean(p_z.log_prob(z_fake[:, :s_dim]))
                trans_enc_loss = -tf.reduce_mean(p_z_trans.log_prob(
                    trans_z_fake[:, :s_dim]))

            # The tape is persistent because three gradients are taken from it.
            dis_grads = tape.gradient(dis_loss, dis.trainable_variables)
            enc_grads = tape.gradient(enc_loss, enc.trainable_variables)
            trans_enc_grads = tape.gradient(trans_enc_loss, trans_enc.trainable_variables)

            dis_opt.apply_gradients(zip(dis_grads, dis.trainable_variables))
            enc_opt.apply_gradients(zip(enc_grads, enc.trainable_variables))
            trans_enc_opt.apply_gradients(zip(trans_enc_grads, trans_enc.trainable_variables))
        else:
            # NOTE(review): this recomputes exactly the y_real_pad already
            # computed above; the duplication looks accidental.
            if n_dim > 0:
                padding = tf.zeros((y_real.shape[0], n_dim))
                y_real_pad = tf.concat((y_real, padding), axis=-1)
            else:
                y_real_pad = y_real

            # Alternate discriminator step and generator step
            with tf.GradientTape(persistent=False) as tape:
                # Generate
                z_fake = datasets.paired_randn(batch_size, z_dim, masks)
                z_fake = z_fake + y_real_pad
                x_fake = gen(z_fake)

                # Discriminate
                logits_fake = dis(x_fake, y_real)

                gen_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
                    logits=logits_fake, labels=targets_real))

            gen_grads = tape.gradient(gen_loss, gen.trainable_variables)
            gen_opt.apply_gradients(zip(gen_grads, gen.trainable_variables))

            with tf.GradientTape(persistent=True) as tape:
                # Generate
                z_fake = datasets.paired_randn(batch_size, z_dim, masks)
                z_fake = z_fake + y_real_pad
                x_fake = tf.stop_gradient(gen(z_fake))
                trans_z_fake = z_trans(z_fake)

                # Discriminate
                x = tf.concat((x_real, x_fake), 0)
                y = tf.concat((y_real, y_real), 0)
                logits = dis(x, y)

                # Encode
                p_z = enc(x_fake)
                p_z_trans = trans_enc(x_fake)

                dis_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
                    logits=logits, labels=targets))
                # Encoder ignores nuisance parameters (if they exist)
                enc_loss = -tf.reduce_mean(p_z.log_prob(z_fake[:, :s_dim]))
                trans_enc_loss = -tf.reduce_mean(p_z_trans.log_prob(
                    trans_z_fake[:, :s_dim]))

            dis_grads = tape.gradient(dis_loss, dis.trainable_variables)
            enc_grads = tape.gradient(enc_loss, enc.trainable_variables)
            trans_enc_grads = tape.gradient(trans_enc_loss, trans_enc.trainable_variables)

            dis_opt.apply_gradients(zip(dis_grads, dis.trainable_variables))
            enc_opt.apply_gradients(zip(enc_grads, enc.trainable_variables))
            trans_enc_opt.apply_gradients(zip(trans_enc_grads, trans_enc.trainable_variables))

        return dict(gen_loss=gen_loss, dis_loss=dis_loss, enc_loss=enc_loss, trans_enc_loss=trans_enc_loss)

    @tf.function
    def gen_eval(z):
        gen.eval()
        return gen(z)

    @tf.function
    def enc_eval(x):
        enc.eval()
        return enc(x).mean()

    # Numpy-returning wrappers handed to evaluate.evaluate_enc.
    enc_np = lambda x: enc_eval(x).numpy()

    @tf.function
    def trans_enc_eval(x):
        trans_enc.eval()
        return trans_enc(x).mean()

    trans_enc_np = lambda x: trans_enc_eval(x).numpy()

    # Initial preparation
    if FLAGS.debug:
        iter_log = 100
        iter_save = 2000
        train_range = range(iterations)
        basedir = FLAGS.basedir
        vizdir = FLAGS.basedir
        ckptdir = FLAGS.basedir
        new_run = True
        # iter_metric is intentionally not set here; it is only read in the
        # non-debug branch of the evaluation code below.
    else:
        iter_log = 5000
        iter_save = 5000
        iter_metric = iter_save * 5  # Make sure this is a factor of 500k
        basedir = os.path.join(FLAGS.basedir, "exp")
        ckptdir = os.path.join(basedir, "ckptdir")
        vizdir = os.path.join(basedir, "vizdir")
        gfile.MakeDirs(basedir)
        gfile.MakeDirs(ckptdir)
        gfile.MakeDirs(vizdir)  # train_range will be specified below

    ckpt_prefix = os.path.join(ckptdir, "model")
    # Rebinds ckpt_root (used above for the classifier-only restore) to the
    # checkpoint of the models trained here; clas is not part of it.
    if model_type in {"gen", "van"}:
        ckpt_root = tf.train.Checkpoint(dis=dis, dis_opt=dis_opt,
                                        gen=gen, gen_opt=gen_opt,
                                        enc=enc, enc_opt=enc_opt,
                                        trans_enc=trans_enc, trans_enc_opt=trans_enc_opt)

    # Check if we're resuming training if not in debugging mode
    if not FLAGS.debug:
        latest_ckpt = tf.train.latest_checkpoint(ckptdir)
        if latest_ckpt is None:
            new_run = True
            ut.log("Starting a completely new model")
            train_range = range(iterations)
        else:
            new_run = False
            ut.log("Restarting from {}".format(latest_ckpt))
            ckpt_root.restore(latest_ckpt)
            # A new run saves once at step 0 and then every iter_save steps,
            # hence the "- 1" when converting the save count back to a step.
            resuming_iteration = iter_save * (int(ckpt_root.save_counter) - 1)
            train_range = range(resuming_iteration, iterations)

    samples = FLAGS.val_samples
    if FLAGS.evaluate:
        # Evaluation-only mode: compute the MI metrics and encoder metrics on
        # the restored models, then return without training.
        # NOTE(review): this runs before the ``dset is None`` check below, so
        # evaluate_enc is called with whatever get_dlib_data returned.
        # Mask selecting only latent dimension 0 for every sample.
        masks = np.zeros([samples, z_dim])
        masks[:, 0] = 1
        masks = tf.convert_to_tensor(masks, dtype=tf.float32)
        transformed_prior = datasets.transformed_prior(z_trans)
        mi, mi_trans, mi_joint, mi_joint_trans = [], [], [], []
        # Average each metric over FLAGS.mi_averages repeated estimates.
        for i in range(FLAGS.mi_averages):
            mi.append(new_metrics.mi_difference(z_dim, gen, clas, masks, samples))
            mi_trans.append(new_metrics.mi_difference(z_dim, gen, clas, masks, samples, z_prior = transformed_prior))
            mi_joint.append(new_metrics.mi_difference(z_dim, gen, clas, masks, samples, draw_from_joint=True))
            mi_joint_trans.append(new_metrics.mi_difference(z_dim, gen, clas, masks, samples, z_prior = transformed_prior, draw_from_joint=True))
        mi = np.mean(np.stack(mi), axis=0)
        mi_trans = np.mean(np.stack(mi_trans), axis=0)
        mi_joint = np.mean(mi_joint)
        mi_joint_trans = np.mean(mi_joint_trans)
        ut.log("MI - Normal: {}, {} Trans: {}, {}".format(mi[0], mi[1], mi_trans[0], mi_trans[1]))
        # mi = new_metrics.mi_difference(z_dim, gen, clas, masks, samples)
        # unmixed_prior = datasets.unmixed_prior(FLAGS.shift, FLAGS.scale)
        # mi_unmixed = new_metrics.mi_difference(z_dim, gen, clas, masks, samples, z_prior = unmixed_prior)
        # mi_mixed = new_metrics.mi_difference(z_dim, gen, clas, masks, samples, z_prior = datasets.mixed_prior)
        # ut.log("MI - Normal: {}, {} Unmixed: {}, {} Mixed: {}, {}".format(mi[0], mi[1], mi_unmixed[0], mi_unmixed[1], mi_mixed[0], mi_mixed[1]))
        # mi_joint = new_metrics.mi_difference(z_dim, gen, clas, masks, samples, draw_from_joint=True)
        # mi_unmixed_joint = new_metrics.mi_difference(z_dim, gen, clas, masks, samples, z_prior = unmixed_prior, draw_from_joint=True)
        # mi_mixed_joint = new_metrics.mi_difference(z_dim, gen, clas, masks, samples, z_prior = datasets.mixed_prior, draw_from_joint=True)
        ut.log("MI Joint - Normal: {} Trans: {}".format(mi_joint, mi_joint_trans))
        ut.log("Encoder Metrics")
        evaluate.evaluate_enc(enc_np, dset, s_dim,
                              FLAGS.gin_file, FLAGS.gin_bindings,
                              pida_sample_size=1000,
                              dlib_metrics=FLAGS.debug_dlib_metrics)
        ut.log("Transformed Encoder Metrics")
        evaluate.evaluate_enc(trans_enc_np, dset, s_dim,
                              FLAGS.gin_file, FLAGS.gin_bindings,
                              pida_sample_size=1000,
                              dlib_metrics=FLAGS.debug_dlib_metrics)
        ut.log("Completed Evaluation")
        return

    # Training
    if dset is None:
        ut.log("Dataset {} is not available".format(dset_name))
        ut.log("Ending program having checked that the networks can be built.")
        return

    batches = datasets.paired_data_generator(dset, masks).repeat().batch(batch_size).prefetch(1000)
    batches = iter(batches)
    start_time = time.time()
    train_time = 0

    if FLAGS.debug:
        train_range = tqdm(train_range)

    # NOTE(review): this replaces train_range outright, discarding both the
    # resume offset computed above and the tqdm wrapper.
    if FLAGS.visualize:
        train_range = range(iterations+1)

    for global_step in train_range:
        stopwatch = time.time()
        if model_type == "gen":
            x1, x2, y = next(batches)
            vals = train_gen_step(x1, x2, y)
        elif model_type == "van":
            x, y = next(batches)
            vals = train_van_step(x, y, FLAGS.entangle)
        # train_time counts only data fetch + train step, not logging/evals.
        train_time += time.time() - stopwatch

        # Generic bookkeeping
        if (global_step + 1) % iter_log == 0 or global_step == 0:
            elapsed_time = time.time() - start_time
            string = ", ".join((
                "Iter: {:07d}, Elapsed: {:.3e}, (Elapsed) Iter/s: {:.3e}, (Train Step) Iter/s: {:.3e}".format(
                    global_step, elapsed_time, global_step / elapsed_time, global_step / train_time),
                "Gen: {gen_loss:.4f}, Dis: {dis_loss:.4f}, Enc: {enc_loss:.4f}, Trans_Enc: {trans_enc_loss:.4f}".format(
                    **vals)
            )) + "."
            ut.log(string)

        # Log visualizations and evaluations
        if (global_step + 1) % iter_save == 0 or global_step == 0:
            if model_type == "gen":
                viz.ablation_visualization(x1, x2, gen_eval, z_dim, vizdir, global_step + 1)
            elif model_type == "van":
                viz.ablation_visualization(x, x, gen_eval, z_dim, vizdir, global_step + 1)

            # num_s_I = 100
            # k = 150
            # y_real = tf.convert_to_tensor(dset.sample_factors(num_s_I, np.random.RandomState(1)), dtype=tf.float32)
            # masks = np.zeros([samples, z_dim])
            # masks[:, 0] = 1
            # masks = tf.convert_to_tensor(masks, dtype=tf.float32)
            # y_real = y_real * masks
            # mi = metrics.mi_estimate(y_real, gen, enc, masks, k, num_s_I, z_dim, s_dim)
            # mi_trans = metrics.mi_estimate(y_real, gen, trans_enc, masks, k, num_s_I, z_dim, s_dim, z_trans)
            # ut.log("Encoder MI: {} Transformed Encoder MI: {}".format(mi, mi_trans))
            # mi = new_metrics.mi_difference(z_dim, gen, clas, masks, samples)
            # mi_joint = new_metrics.mi_difference(z_dim, gen, clas, masks, samples, draw_from_joint=True)
            # ut.log("MI:{} MI_Joint:{}".format(mi, mi_joint))
            # mi = new_metrics.mi_difference(z_dim, gen, clas, masks, samples)
            # unmixed_prior = datasets.unmixed_prior(FLAGS.shift, FLAGS.scale)
            # mi_unmixed = new_metrics.mi_difference(z_dim, gen, clas, masks, samples, z_prior = unmixed_prior)
            # mi_mixed = new_metrics.mi_difference(z_dim, gen, clas, masks, samples, z_prior = datasets.mixed_prior)
            # ut.log("MI - Normal: {}, {} Unmixed: {}, {} Mixed: {}, {}".format(mi[0], mi[1], mi_unmixed[0], mi_unmixed[1], mi_mixed[0], mi_mixed[1]))
            #
            # mi_joint = new_metrics.mi_difference(z_dim, gen, clas, masks, samples, draw_from_joint=True)
            # mi_unmixed_joint = new_metrics.mi_difference(z_dim, gen, clas, masks, samples, z_prior = unmixed_prior, draw_from_joint=True)
            # mi_mixed_joint = new_metrics.mi_difference(z_dim, gen, clas, masks, samples, z_prior = datasets.mixed_prior, draw_from_joint=True)
            # ut.log("MI Joint - Normal: {} Unmixed: {} Mixed: {}".format(mi_joint, mi_unmixed_joint, mi_mixed_joint))

            if FLAGS.debug:
                ut.log("Encoder Metrics")
                evaluate.evaluate_enc(enc_np, dset, s_dim,
                                      FLAGS.gin_file, FLAGS.gin_bindings,
                                      pida_sample_size=1000,
                                      dlib_metrics=FLAGS.debug_dlib_metrics)
                ut.log("Transformed Encoder Metrics")
                evaluate.evaluate_enc(trans_enc_np, dset, s_dim,
                                      FLAGS.gin_file, FLAGS.gin_bindings,
                                      pida_sample_size=1000,
                                      dlib_metrics=FLAGS.debug_dlib_metrics)
            else:
                # disentanglement_lib metrics only every iter_metric steps.
                dlib_metrics = (global_step + 1) % iter_metric == 0
                ut.log("Encoder Metrics")
                evaluate.evaluate_enc(enc_np, dset, s_dim,
                                      FLAGS.gin_file, FLAGS.gin_bindings,
                                      pida_sample_size=10000,
                                      dlib_metrics=dlib_metrics)
                ut.log("Transformed Encoder Metrics")
                evaluate.evaluate_enc(trans_enc_np, dset, s_dim,
                                      FLAGS.gin_file, FLAGS.gin_bindings,
                                      pida_sample_size=10000,
                                      dlib_metrics=dlib_metrics)

        # Save model
        if (global_step + 1) % iter_save == 0 or (global_step == 0 and new_run):
            # Save model only after ensuring all measurements are taken.
            # This ensures that restarts always computes the evals
            ut.log("Saved to", ckpt_root.save(ckpt_prefix))