def main(args):
    """Render a grid of candidate samples and save it as a PNG file.

    The dataset is chosen by --dataset. If that flag is empty, the name is
    taken from the basename of the grandparent directory of --ckpt. If the
    name is not found (KeyError), the lookup is retried with an 'lsun_'
    prefix. Images for the indexes selected by --samples are loaded with
    `load_hires`, turned into candidates with `get_candidates`, tiled into one
    grid image and written to --save_to. Missing parent directories are
    created.

    Args:
        args: command-line arguments supplied by the app runner; unused.
    """
    del args  # Unused.
    dataset_name = FLAGS.dataset or os.path.basename(
        os.path.dirname(os.path.dirname(FLAGS.ckpt)))
    try:
        dataset = data.get_dataset(dataset_name)
    except KeyError:
        # Fall back to the 'lsun_'-prefixed name (presumably for LSUN
        # categories whose directory names omit the prefix — confirm).
        dataset = data.get_dataset('lsun_' + dataset_name)
    ops = get_ops(dataset)
    images = load_hires(dataset, get_samples_indexes(FLAGS.samples))
    image_grid = get_candidates(ops, images)
    img = utils.images_to_grid(image_grid)
    # Encode before opening the output file. Otherwise a failing encode
    # would leave an existing file truncated by the 'wb' open.
    png_data = utils.to_png(img)
    output_file = os.path.abspath(FLAGS.save_to)
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    # Close the handle deterministically instead of relying on GC.
    with open(output_file, 'wb') as f:
        f.write(png_data)
    print('Saved', output_file)
def main(argv):
    """Build a LAG model from command-line flags and run its training loop.

    A TrainSchedule is created from the scale and kimg flags. With --memtest,
    only its last two entries are kept. With --reset, the model's files are
    reset before training starts.

    Args:
        argv: positional command-line arguments; ignored.
    """
    del argv  # Unused.
    dataset = data.get_dataset(FLAGS.dataset)

    schedule = TrainSchedule(2, FLAGS.scale, FLAGS.transition_kimg,
                             FLAGS.training_kimg, FLAGS.total_kimg)
    if FLAGS.memtest:
        # Keep only the final two schedule entries (presumably a quick
        # memory check of the largest configuration — confirm).
        schedule.schedule = schedule.schedule[-2:]

    model_dir = os.path.join(FLAGS.train_dir, dataset.name)
    # Keyword order mirrors the constructor call this replaces.
    hparams = dict(
        lr=FLAGS.lr,
        batch=FLAGS.batch,
        lod_min=1,
        scale=FLAGS.scale,
        downscaler=FLAGS.downscaler,
        blocks=FLAGS.blocks,
        filters=FLAGS.filters,
        filters_min=FLAGS.filters_min,
        mse_weight=FLAGS.mse_weight,
        noise_dim=FLAGS.noise_dim,
        transition_kimg=FLAGS.transition_kimg,
        training_kimg=FLAGS.training_kimg,
        ttur=FLAGS.ttur,
        wass_target=FLAGS.wass_target,
        weight_avg=FLAGS.weight_avg,
    )
    model = LAG(model_dir, **hparams)

    if FLAGS.reset:
        model.reset_files()
    model.train(dataset, schedule)
def main(argv):
    """Pull FLAGS.samples // FLAGS.batch batches from the training split.

    Each batch is fetched through a MonitoredSession, with a progress bar.
    The results are discarded, so this presumably serves as an input-pipeline
    speed/sanity check (confirm).

    Args:
        argv: positional command-line arguments; ignored.
    """
    del argv
    num_batches = FLAGS.samples // FLAGS.batch
    dataset = data.get_dataset(FLAGS.dataset)
    next_batch = (dataset.train
                  .batch(FLAGS.batch)
                  .prefetch(32)
                  .make_one_shot_iterator()
                  .get_next())
    with tf.train.MonitoredSession() as session:
        for _ in trange(num_batches, leave=True):
            session.run(next_batch)
def main(argv):
    """Train an EDSR model whose settings all come from command-line flags.

    Args:
        argv: positional command-line arguments; ignored.
    """
    del argv  # Unused.
    dataset = data.get_dataset(FLAGS.dataset)
    model_dir = os.path.join(FLAGS.train_dir, dataset.name)
    # Each constructor keyword has the same name as its flag. The tuple order
    # matches the original keyword order.
    flag_names = ('lr', 'batch', 'scale', 'downscaler', 'filters', 'repeat')
    settings = {name: getattr(FLAGS, name) for name in flag_names}
    EDSR(model_dir, **settings).train(dataset)
def main(argv):
    """Build a cGAN model from command-line flags and train it.

    The learning-rate decay window is derived from --total_kimg and --batch.
    With --reset, the model's files are reset before training.

    Args:
        argv: positional command-line arguments; ignored.
    """
    del argv  # Unused.
    dataset = data.get_dataset(FLAGS.dataset)
    # total_kimg * 512 // batch and total_kimg * 1024 // batch. If one kimg
    # is 1024 images, decay stops at the last batch of the run and starts
    # halfway through.
    decay_start, decay_stop = (
        (FLAGS.total_kimg << shift) // FLAGS.batch for shift in (9, 10))
    model_dir = os.path.join(FLAGS.train_dir, dataset.name)
    model = cGAN(
        model_dir,
        scale=FLAGS.scale,
        downscaler=FLAGS.downscaler,
        blocks=FLAGS.blocks,
        filters=FLAGS.filters,
        noise=FLAGS.noise,
        decay_start=decay_start,
        decay_stop=decay_stop,
        lr_decay=FLAGS.lr_decay,
    )
    if FLAGS.reset:
        model.reset_files()
    model.train(dataset)
def main(argv):
    """Build an SRGAN model from command-line flags and train it.

    The decay window is [total_kimg*512 // batch, total_kimg*1024 // batch].
    With --reset, the model's files are reset before training.

    Args:
        argv: positional command-line arguments; ignored.
    """
    del argv  # Unused.
    dataset = data.get_dataset(FLAGS.dataset)
    total_kimg, batch_size = FLAGS.total_kimg, FLAGS.batch
    # Presumably one kimg is 1024 images, so the upper bound is the total
    # batch count and the lower bound is its midpoint (confirm).
    decay_bounds = dict(decay_start=(total_kimg << 9) // batch_size,
                        decay_stop=(total_kimg << 10) // batch_size)
    # Keyword order matches the original constructor call.
    srgan_kwargs = dict(
        scale=FLAGS.scale,
        downscaler=FLAGS.downscaler,
        filters=FLAGS.filters,
        blocks=FLAGS.blocks,
        **decay_bounds,
        lr_decay=FLAGS.lr_decay,
        adv_weight=FLAGS.adv_weight,
        pcp_weight=FLAGS.pcp_weight,
        layer_name=FLAGS.layer_name,
    )
    model = SRGAN(os.path.join(FLAGS.train_dir, dataset.name), **srgan_kwargs)
    if FLAGS.reset:
        model.reset_files()
    model.train(dataset)