def main(_):
    """Load the dataset selected by --data, then train or sample a ConditionalGAN.

    All configuration is read from ``flags.FLAGS``; the positional argument
    (whatever the flags runner passes, typically argv) is ignored.

    Raises:
        ValueError: if ``--operation`` is neither "train" nor "test", or if
            ``--data`` names an unsupported dataset.
    """
    FLAGS = flags.FLAGS

    # Reject a bad operation before doing any (possibly expensive) data loading.
    # Previously this only printed a message and returned normally, which hid
    # the error from the caller / exit status; raise like the --data check does.
    if FLAGS.operation not in ("train", "test"):
        raise ValueError("Unknown operation %s" % (FLAGS.operation))

    # Data, model, logs and samples are all going to be saved in the folder
    # dependent on the name of the data (i.e. mnist, celebA, lsun etc)
    data_path, model_path, log_path, sample_path = append_to_paths(
        FLAGS.data, FLAGS.data_path, FLAGS.model_path,
        FLAGS.log_path, FLAGS.sample_path)

    # Lambdas keep construction (and the Data.<Class> attribute lookup) lazy,
    # so only the selected dataset is ever touched.
    loaders = {
        "mnist": lambda path: Data.Mnist(path),
        "banners": lambda path: Data.Banners(path),
        "artists": lambda path: Data.Artists(path),
        # 5000 is passed through unchanged; its meaning is defined by Data.CelebA.
        "celebA": lambda path: Data.CelebA(path, 5000),
        "wines": lambda path: Data.Wines(path),
    }
    if FLAGS.data not in loaders:
        raise ValueError("Data %s is not supported" % (FLAGS.data))
    data = loaders[FLAGS.data](data_path)

    if FLAGS.operation == "train":
        # The trailing True is only passed when training; its meaning is
        # defined by ConditionalGAN.
        cgan = ConditionalGAN.ConditionalGAN(
            data, FLAGS.batch_size, FLAGS.z_dim, True)
        make_paths(model_path, log_path, sample_path)
        cgan.train(model_path, log_path, sample_path, FLAGS.training_steps,
                   FLAGS.learn_rate, FLAGS.save_frequency,
                   FLAGS.generator_advantage)
        return

    # operation == "test": build the model and generate samples for a set of
    # labels (presumably restoring weights from model_path -- see ConditionalGAN).
    cgan = ConditionalGAN.ConditionalGAN(data, FLAGS.batch_size, FLAGS.z_dim)
    make_paths(sample_path)
    label_shape = (FLAGS.samples, data.get_number_of_labels())
    if FLAGS.samples_spec == "random":
        labels = data.get_random_labels(label_shape)
    else:
        labels = data.get_labels_by_spec(label_shape, FLAGS.samples_spec)
    print("\n".join(data.describe_labels(labels)))
    cgan.test(model_path, sample_path, None, labels)