def main():
    """Run a trained generator over dirty/PSF FITS pairs given on the command line.

    Expects exactly two arguments: a comma-separated list of dirty-image FITS
    files and a comma-separated list of PSF FITS files. The lists are paired
    by position (the i-th dirty image goes with the i-th PSF). For every step,
    the generator output (passed through ``deprocess``) and the residual
    ``dirty - conv2d(output, shifted psf)`` are FITS-encoded and handed to
    ``save_images`` together with ``a.output_dir``.

    Reads module-level settings from ``a`` (``ngf``, ``separable_conv``,
    ``checkpoint``, ``output_dir``) and ``CROP_SIZE`` (usage text only).
    Exits with status 1 on bad arguments or when no checkpoint is found.
    """
    if len(sys.argv) != 3:
        print("""
usage: {} dirty-0.fits,dirty-1.fits,dirty-2.fits psf-0.fits,psf-1.fits,psf-2.fits

note: names don't matter, order does. only supports fits files of {}x{}

output is written to the configured output directory.
""".format(sys.argv[0], CROP_SIZE, CROP_SIZE), file=sys.stderr)
        sys.exit(1)

    dirties = [os.path.realpath(i) for i in sys.argv[1].split(',')]
    psfs = [os.path.realpath(i) for i in sys.argv[2].split(',')]
    # Explicit check instead of ``assert`` so the validation survives ``python -O``.
    if len(dirties) != len(psfs):
        sys.exit("error: got {} dirty files but {} psf files; they must pair up one-to-one".format(
            len(dirties), len(psfs)))

    batch, count = load_data(dirties, psfs)
    steps_per_epoch = count
    iterator = batch.make_one_shot_iterator()  # avoid shadowing the builtin ``iter``
    index, min_flux, max_flux, psf, dirty = iterator.get_next()
    scaled_dirty = preprocess(dirty, min_flux, max_flux)
    scaled_psf = (psf * 2) - 1
    input_ = tf.concat([scaled_dirty, scaled_psf], axis=3)

    # set up the network
    with tf.variable_scope("generator"):
        outputs = create_generator(input_, 1, a.ngf, a.separable_conv)
        deprocessed_output = deprocess(outputs, min_flux, max_flux)

    with tf.name_scope("calculate_residuals"):
        shifted = shift(psf, y=-1, x=-1)
        # NOTE(review): tf.squeeze drops every size-1 dim, so this presumably
        # assumes a batch size of 1 — confirm against load_data.
        filter_ = tf.expand_dims(tf.expand_dims(tf.squeeze(shifted), 2), 3)
        convolved = tf.nn.conv2d(deprocessed_output, filter_, [1, 1, 1, 1], "SAME")
        residuals = dirty - convolved

    with tf.name_scope("encode_fitss"):
        fits_fetches = {
            "indexs": index,
            "outputs": tf.map_fn(fits_encode, deprocessed_output, dtype=tf.string, name="output_fits"),
            "residuals": tf.map_fn(fits_encode, residuals, dtype=tf.string, name="residuals_fits"),
        }

    with tf.Session() as sess:
        checkpoint = tf.train.latest_checkpoint(a.checkpoint)
        # latest_checkpoint returns None when the directory holds no checkpoint;
        # fail with a readable message instead of an obscure restore error.
        if checkpoint is None:
            sys.exit("error: no checkpoint found in {}".format(a.checkpoint))
        tf.train.Saver().restore(sess, checkpoint)
        for step in range(steps_per_epoch):
            results = sess.run(fits_fetches)
            filesets = save_images(results, subfolder=None, extention="fits", output_dir=a.output_dir)
            for f in filesets:
                print("wrote " + f['name'])
def init(checkpoint):
    """Build a placeholder-fed generator graph and restore its weights.

    Args:
        checkpoint: directory passed to ``tf.train.latest_checkpoint``.

    Returns:
        A ``Model`` holding the open session, the deprocessed output tensor and
        the ``input``/``min_flux``/``max_flux`` placeholders to feed.

    Raises:
        ValueError: if no checkpoint is found in ``checkpoint``.
    """
    min_flux = tf.placeholder(tf.float32, shape=(1, ))
    max_flux = tf.placeholder(tf.float32, shape=(1, ))
    # two channels: presumably the scaled dirty tile and the scaled PSF — confirm with callers
    input_ = tf.placeholder(tf.float32, shape=(1, SIZE, SIZE, 2))

    # set up the network
    with tf.variable_scope("generator"):
        outputs = create_generator(input_, 1, NGF, SEPERABLE_CONV)
        deprocessed_output = deprocess(outputs, min_flux, max_flux)

    logger.info("restoring data from checkpoint " + checkpoint)
    checkpoint_path = tf.train.latest_checkpoint(checkpoint)
    if checkpoint_path is None:
        # Same exception type Saver.restore(None) would raise, with a clearer message.
        raise ValueError("no checkpoint found in directory " + checkpoint)

    sess = tf.Session()
    try:
        tf.train.Saver().restore(sess, checkpoint_path)
    except Exception:
        # don't leak the session when restore fails; the caller never receives it
        sess.close()
        raise
    return Model(session=sess, output=deprocessed_output, input=input_, max_flux=max_flux, min_flux=min_flux)
def main():
    """Train the GAN on batches produced by ``generative_model``.

    Builds the input pipeline from ``a.psf_glob`` (flux scaling bounded by
    ``a.flux_scale_min``/``a.flux_scale_max``), the model via ``create_model``,
    PNG display fetches and TensorBoard summaries. It then runs ``a.max_steps``
    training steps under a ``tf.train.Supervisor``. Progress, summaries,
    traces, display images and checkpoints are produced at the matching
    ``a.*_freq`` intervals; a frequency <= 0 disables that output.
    """
    prepare()

    train_batch = generative_model(a.psf_glob, flux_scale_min=a.flux_scale_min,
                                   flux_scale_max=a.flux_scale_max)
    iterator = train_batch.make_one_shot_iterator()
    index, min_flux, max_flux, psf, dirty, skymodel = iterator.get_next()

    with tf.name_scope("scaling_flux"):
        scaled_skymodel = preprocess(skymodel, min_flux, max_flux)
        scaled_dirty = preprocess(dirty, min_flux, max_flux)
        scaled_psf = (psf * 2) - 1

    # inputs and targets are [batch_size, height, width, channels]
    model = create_model(scaled_dirty, scaled_skymodel, EPS, a.separable_conv,
                         beta1=a.beta1, gan_weight=a.gan_weight,
                         l1_weight=a.l1_weight, lr=a.lr, ndf=a.ndf, ngf=a.ngf,
                         psf=psf, min_flux=min_flux, max_flux=max_flux,
                         res_weight=a.res_weight, disable_psf=a.disable_psf)

    deprocessed_output = deprocess(model.outputs, min_flux, max_flux)

    with tf.name_scope("calculate_residuals"):
        shifted = shift(psf, y=-1, x=-1)
        filter_ = tf.expand_dims(tf.expand_dims(tf.squeeze(shifted), 2), 3)
        convolved = tf.nn.conv2d(deprocessed_output, filter_, [1, 1, 1, 1], "SAME")
        residuals = dirty - convolved

    # reverse any processing on images so they can be written to disk or displayed to user
    with tf.name_scope("convert_images"):
        converted_inputs = tf.image.convert_image_dtype(
            visual_scaling(scaled_dirty), dtype=tf.uint8, saturate=True)
        converted_targets = tf.image.convert_image_dtype(
            visual_scaling(scaled_skymodel), dtype=tf.uint8, saturate=True)
        converted_outputs = tf.image.convert_image_dtype(
            visual_scaling(model.outputs), dtype=tf.uint8, saturate=True)
        converted_psfs = tf.image.convert_image_dtype(
            visual_scaling(scaled_psf), dtype=tf.uint8, saturate=True)
        converted_residuals = tf.image.convert_image_dtype(
            visual_scaling(residuals), dtype=tf.uint8, saturate=True)

    with tf.name_scope("encode_images"):
        display_fetches = {
            "indexs": index,
            "inputs": tf.map_fn(tf.image.encode_png, converted_inputs, dtype=tf.string, name="input_pngs"),
            "targets": tf.map_fn(tf.image.encode_png, converted_targets, dtype=tf.string, name="target_pngs"),
            "outputs": tf.map_fn(tf.image.encode_png, converted_outputs, dtype=tf.string, name="output_pngs"),
            "psfs": tf.map_fn(tf.image.encode_png, converted_psfs, dtype=tf.string, name="psf_pngs"),
            "residuals": tf.map_fn(tf.image.encode_png, converted_residuals, dtype=tf.string,
                                   name="residual_pngs"),
        }

    # summaries
    with tf.name_scope("combined_summary"):
        tf.summary.image("inputs", converted_inputs)
        tf.summary.image("outputs", converted_outputs)
        tf.summary.image("targets", converted_targets)
        tf.summary.image("residuals", converted_residuals)

    with tf.name_scope("psfs_summary"):
        tf.summary.image("psfss", converted_psfs)

    with tf.name_scope("predict_real_summary"):
        tf.summary.image(
            "predict_real", tf.image.convert_image_dtype(model.predict_real, dtype=tf.uint8))

    with tf.name_scope("predict_fake_summary"):
        tf.summary.image(
            "predict_fake", tf.image.convert_image_dtype(model.predict_fake, dtype=tf.uint8))

    tf.summary.scalar("discriminator_loss", model.discrim_loss)
    tf.summary.scalar("generator_loss_GAN", model.gen_loss_GAN)
    tf.summary.scalar("generator_loss_L1", model.gen_loss_L1)
    tf.summary.scalar("generator_loss_RES", model.gen_loss_RES)

    with tf.name_scope("parameter_count"):
        parameter_count = tf.reduce_sum(
            [tf.reduce_prod(tf.shape(v)) for v in tf.trainable_variables()])

    logdir = a.output_dir if (a.trace_freq > 0 or a.summary_freq > 0) else None
    train_summary_op = tf.summary.merge_all()
    # logdir is None when neither summaries nor traces are requested. The writer
    # is only used under should(a.summary_freq)/should(a.trace_freq), which are
    # then always False, so skip it instead of crashing on ``None + '/train'``.
    train_summary_writer = None
    if logdir is not None:
        train_summary_writer = tf.summary.FileWriter(logdir=os.path.join(logdir, 'train'))
    saver = tf.train.Saver(max_to_keep=100)
    sv = tf.train.Supervisor(logdir=logdir, save_summaries_secs=0, saver=saver,
                             summary_writer=None, summary_op=None)
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.6)

    def should(freq, step):
        """True when an every-``freq``-steps action is due at ``step`` (always on the last step)."""
        return freq > 0 and ((step + 1) % freq == 0 or step == a.max_steps - 1)

    try:
        with sv.managed_session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
            print("parameter_count =", sess.run(parameter_count))

            start = time.time()

            for step in range(a.max_steps):
                options = None
                run_metadata = None
                if should(a.trace_freq, step):
                    print("preparing")
                    options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
                    run_metadata = tf.RunMetadata()

                fetches = {
                    "train": model.train,
                    "global_step": sv.global_step,
                }

                if should(a.progress_freq, step):
                    print("progress step")
                    fetches["discrim_loss"] = model.discrim_loss
                    fetches["gen_loss_GAN"] = model.gen_loss_GAN
                    fetches["gen_loss_L1"] = model.gen_loss_L1
                    fetches["gen_loss_RES"] = model.gen_loss_RES

                if should(a.summary_freq, step):
                    print("preparing summary")
                    fetches["summary"] = train_summary_op

                if should(a.display_freq, step):
                    print("display step step")
                    fetches["display"] = display_fetches

                results = sess.run(fetches, options=options, run_metadata=run_metadata)

                if should(a.summary_freq, step):
                    print("recording summary")
                    train_summary_writer.add_summary(results["summary"], results["global_step"])

                if should(a.display_freq, step):
                    print("saving display images")
                    save_images(results["display"], step=results["global_step"], output_dir=a.output_dir)

                if should(a.trace_freq, step):
                    print("recording trace")
                    train_summary_writer.add_run_metadata(
                        run_metadata, "step_%d" % results["global_step"])

                if should(a.progress_freq, step):
                    rate = (step + 1) * a.batch_size / (time.time() - start)
                    remaining = (a.max_steps - step) * a.batch_size / rate
                    print("progress step %d image/sec %0.1f remaining %dm" %
                          (results["global_step"], rate, remaining / 60))
                    print("discrim_loss", results["discrim_loss"])
                    print("gen_loss_GAN", results["gen_loss_GAN"])
                    print("gen_loss_L1", results["gen_loss_L1"])
                    print("gen_loss_RES", results["gen_loss_RES"])

                if should(a.save_freq, step):
                    print("saving model")
                    saver.save(sess, os.path.join(a.output_dir, "model"), global_step=sv.global_step)

                if sv.should_stop():
                    print("supervisor things we should stop!")
                    break
    finally:
        # flush pending events so the last summaries/traces are not lost
        if train_summary_writer is not None:
            train_summary_writer.close()

    print("done! bye")
def main():
    """Train the GAN with periodic validation on a held-out slice of ``a.input_dir``.

    Training and validation datasets are selected at run time through a
    feedable string-handle iterator. Training runs for ``a.max_steps`` steps
    if given, else ``steps_per_epoch * a.max_epochs``, else 2**32. Progress,
    summaries, traces, display images, checkpoints and validation passes
    happen at the matching ``a.*_freq`` intervals; a frequency <= 0 disables
    that output.
    """
    prepare()

    train_batch, train_count = load_data(a.input_dir, CROP_SIZE, a.flip, a.scale_size, a.max_epochs,
                                         a.batch_size, start=a.train_start, end=a.train_end, loop=True)
    print("train count = %d" % train_count)
    steps_per_epoch = int(math.ceil(train_count / a.batch_size))
    training_iterator = train_batch.make_one_shot_iterator()

    validate_batch, validate_count = load_data(a.input_dir, CROP_SIZE, a.flip, a.scale_size, a.max_epochs,
                                               a.batch_size, start=a.validate_start, end=a.validate_end,
                                               loop=True)
    print("validate count = %d" % validate_count)
    validation_iterator = validate_batch.make_one_shot_iterator()

    # One graph input, switched between the two datasets by feeding a string handle.
    handle = tf.placeholder(tf.string, shape=[])
    iterator = tf.data.Iterator.from_string_handle(handle, train_batch.output_types, train_batch.output_shapes)
    index, min_flux, max_flux, psf, dirty, skymodel = iterator.get_next()

    with tf.name_scope("scaling_flux"):
        scaled_skymodel = preprocess(skymodel, min_flux, max_flux)
        scaled_dirty = preprocess(dirty, min_flux, max_flux)
        scaled_psf = (psf * 2) - 1

    ## i'll just leave this here in case we decide to do something with visibilities again (unlikely)
    # vis = tf.fft2d(tf.complex(scaled_dirty, tf.zeros(shape=(a.batch_size, CROP_SIZE, CROP_SIZE, 1))))
    # real = tf.real(vis)
    # imag = tf.imag(vis)

    if a.disable_psf:
        input_ = scaled_dirty
    else:
        input_ = tf.concat([scaled_dirty, scaled_psf], axis=3)

    # inputs and targets are [batch_size, height, width, channels]
    model = create_model(input_, scaled_skymodel, EPS, a.separable_conv, beta1=a.beta1,
                         gan_weight=a.gan_weight, l1_weight=a.l1_weight, lr=a.lr, ndf=a.ndf,
                         ngf=a.ngf, psf=psf, min_flux=min_flux, max_flux=max_flux,
                         res_weight=a.res_weight)

    deprocessed_output = deprocess(model.outputs, min_flux, max_flux)

    with tf.name_scope("calculate_residuals"):
        shifted = shift(psf, y=-1, x=-1)
        filter_ = tf.expand_dims(tf.expand_dims(tf.squeeze(shifted), 2), 3)
        convolved = tf.nn.conv2d(deprocessed_output, filter_, [1, 1, 1, 1], "SAME")
        residuals = dirty - convolved

    # reverse any processing on images so they can be written to disk or displayed to user
    with tf.name_scope("convert_images"):
        converted_inputs = tf.image.convert_image_dtype(
            visual_scaling(scaled_dirty), dtype=tf.uint8, saturate=True)
        converted_targets = tf.image.convert_image_dtype(
            visual_scaling(scaled_skymodel), dtype=tf.uint8, saturate=True)
        converted_outputs = tf.image.convert_image_dtype(
            visual_scaling(model.outputs), dtype=tf.uint8, saturate=True)
        converted_psfs = tf.image.convert_image_dtype(
            visual_scaling(scaled_psf), dtype=tf.uint8, saturate=True)
        converted_residuals = tf.image.convert_image_dtype(
            visual_scaling(residuals), dtype=tf.uint8, saturate=True)

    with tf.name_scope("encode_images"):
        display_fetches = {
            "indexs": index,
            "inputs": tf.map_fn(tf.image.encode_png, converted_inputs, dtype=tf.string, name="input_pngs"),
            "targets": tf.map_fn(tf.image.encode_png, converted_targets, dtype=tf.string, name="target_pngs"),
            "outputs": tf.map_fn(tf.image.encode_png, converted_outputs, dtype=tf.string, name="output_pngs"),
            "psfs": tf.map_fn(tf.image.encode_png, converted_psfs, dtype=tf.string, name="psf_pngs"),
            "residuals": tf.map_fn(tf.image.encode_png, converted_residuals, dtype=tf.string,
                                   name="residual_pngs"),
        }

    # summaries
    with tf.name_scope("combined_summary"):
        tf.summary.image("inputs", converted_inputs)
        tf.summary.image("outputs", converted_outputs)
        tf.summary.image("targets", converted_targets)
        tf.summary.image("residuals", converted_residuals)

    with tf.name_scope("psfs_summary"):
        tf.summary.image("psfss", converted_psfs)

    with tf.name_scope("predict_real_summary"):
        tf.summary.image(
            "predict_real", tf.image.convert_image_dtype(model.predict_real, dtype=tf.uint8))

    with tf.name_scope("predict_fake_summary"):
        tf.summary.image(
            "predict_fake", tf.image.convert_image_dtype(model.predict_fake, dtype=tf.uint8))

    tf.summary.scalar("discriminator_loss", model.discrim_loss)
    tf.summary.scalar("generator_loss_GAN", model.gen_loss_GAN)
    tf.summary.scalar("generator_loss_L1", model.gen_loss_L1)
    tf.summary.scalar("generator_loss_RES", model.gen_loss_RES)

    if a.validation_freq:
        # NOTE(review): if create_model is the variant that returns
        # ``ema.average(...)`` values, these tensors are moving averages of the
        # *training* losses and do not depend on the batch fed during
        # validation — confirm before trusting them as validation metrics.
        tf.summary.scalar("Validation generator_loss_GAN", model.gen_loss_GAN)
        tf.summary.scalar("Validation generator_loss_L1", model.gen_loss_L1)

    # Disabled histogram summaries, kept for reference:
    # for var in tf.trainable_variables():
    #     tf.summary.histogram(var.op.name + "/values", var)
    # for grad, var in model.discrim_grads_and_vars + model.gen_grads_and_vars:
    #     tf.summary.histogram(var.op.name + "/gradients", grad)

    with tf.name_scope("parameter_count"):
        parameter_count = tf.reduce_sum(
            [tf.reduce_prod(tf.shape(v)) for v in tf.trainable_variables()])

    logdir = a.output_dir if (a.trace_freq > 0 or a.summary_freq > 0) else None
    # One merged op serves both phases; the data it summarises depends on the
    # iterator handle fed at run time.
    summary_op = tf.summary.merge_all()
    validation_summary_op = summary_op
    train_summary_op = summary_op

    # logdir is None when neither summaries nor traces are requested; the
    # original ``logdir + '/...'`` then raised TypeError. The train writer is
    # only used when logdir is set. The validation writer is also needed
    # whenever validation is enabled, so it is placed under a.output_dir
    # (== logdir whenever logdir is set).
    train_summary_writer = None
    if logdir is not None:
        train_summary_writer = tf.summary.FileWriter(logdir=os.path.join(logdir, 'train'))
    validation_summary_writer = None
    if logdir is not None or a.validation_freq:
        validation_summary_writer = tf.summary.FileWriter(logdir=os.path.join(a.output_dir, 'validation'))

    saver = tf.train.Saver(max_to_keep=100)
    sv = tf.train.Supervisor(logdir=logdir, save_summaries_secs=0, saver=saver,
                             summary_writer=None, summary_op=None)
    try:
        with sv.managed_session() as sess:
            print("parameter_count =", sess.run(parameter_count))

            # The `Iterator.string_handle()` method returns a tensor that can be evaluated
            # and used to feed the `handle` placeholder.
            training_handle = sess.run(training_iterator.string_handle())
            validation_handle = sess.run(validation_iterator.string_handle())

            max_steps = 2**32
            if a.max_epochs is not None:
                max_steps = steps_per_epoch * a.max_epochs
            if a.max_steps is not None:
                max_steps = a.max_steps

            def should(freq, step):
                """True when an every-``freq``-steps action is due at ``step`` (always on the last step)."""
                return freq > 0 and ((step + 1) % freq == 0 or step == max_steps - 1)

            start = time.time()

            for step in range(max_steps):
                options = None
                run_metadata = None
                if should(a.trace_freq, step):
                    print("preparing")
                    options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
                    run_metadata = tf.RunMetadata()

                fetches = {
                    "train": model.train,
                    "global_step": sv.global_step,
                }

                if should(a.progress_freq, step):
                    print("progress step")
                    fetches["discrim_loss"] = model.discrim_loss
                    fetches["gen_loss_GAN"] = model.gen_loss_GAN
                    fetches["gen_loss_L1"] = model.gen_loss_L1
                    fetches["gen_loss_RES"] = model.gen_loss_RES

                if should(a.summary_freq, step):
                    print("preparing summary")
                    fetches["summary"] = train_summary_op

                if should(a.display_freq, step):
                    print("display step step")
                    fetches["display"] = display_fetches

                results = sess.run(fetches, options=options, run_metadata=run_metadata,
                                   feed_dict={handle: training_handle})

                if should(a.summary_freq, step):
                    print("recording summary")
                    train_summary_writer.add_summary(results["summary"], results["global_step"])

                if should(a.display_freq, step):
                    print("saving display images")
                    save_images(results["display"], step=results["global_step"], output_dir=a.output_dir)

                if should(a.trace_freq, step):
                    print("recording trace")
                    train_summary_writer.add_run_metadata(
                        run_metadata, "step_%d" % results["global_step"])

                if should(a.progress_freq, step):
                    # global_step will have the correct step count if we resume from a checkpoint
                    train_epoch = math.ceil(results["global_step"] / steps_per_epoch)
                    train_step = (results["global_step"] - 1) % steps_per_epoch + 1
                    rate = (step + 1) * a.batch_size / (time.time() - start)
                    remaining = (max_steps - step) * a.batch_size / rate
                    print("progress epoch %d step %d image/sec %0.1f remaining %dm" %
                          (train_epoch, train_step, rate, remaining / 60))
                    print("discrim_loss", results["discrim_loss"])
                    print("gen_loss_GAN", results["gen_loss_GAN"])
                    print("gen_loss_L1", results["gen_loss_L1"])
                    print("gen_loss_RES", results["gen_loss_RES"])

                if should(a.save_freq, step):
                    print("saving model")
                    saver.save(sess, os.path.join(a.output_dir, "model"), global_step=sv.global_step)

                if a.validation_freq and should(a.validation_freq, step):
                    print("validation step")
                    validation_fetches = {
                        "validation_gen_loss_GAN": model.gen_loss_GAN,
                        "validation_gen_loss_L1": model.gen_loss_L1,
                        "summary": validation_summary_op,
                    }
                    validation_results = sess.run(
                        validation_fetches, feed_dict={handle: validation_handle})
                    print("validation gen_loss_GAN", validation_results["validation_gen_loss_GAN"])
                    print("validation gen_loss_L1", validation_results["validation_gen_loss_L1"])
                    validation_summary_writer.add_summary(
                        validation_results["summary"], results["global_step"])

                if sv.should_stop():
                    print("supervisor things we should stop!")
                    break
    finally:
        # flush pending events so the last summaries/traces are not lost
        for writer in (train_summary_writer, validation_summary_writer):
            if writer is not None:
                writer.close()

    print("done! bye")
def main():
    """Deconvolve a full-size dirty image with the trained generator.

    Reads ``a.dirty`` and ``a.psf`` (FITS files of identical shape). It
    normalises the PSF by its maximum and cuts a ``SIZE`` x ``SIZE`` centre
    patch to use as the second input channel. ``load_data`` is then fed with
    ``n_r`` x ``n_c`` positions derived from the module-level ``stride``
    (presumably overlapping tiles — confirm in load_data), each step is run
    through the generator, and the full model is rebuilt with ``restore``.
    The residual is the dirty image minus the model convolved with the
    full PSF. Writes ``vacuum-model.fits`` and ``vacuum-residual.fits`` to the
    current directory, both with the dirty image's header.

    Raises:
        ValueError: if the PSF and dirty image shapes differ.
    """
    dirty_path = os.path.realpath(a.dirty)
    psf_path = os.path.realpath(a.psf)

    # Copy everything needed out of the FITS files so the handles can be closed
    # (the original left both HDULists open for the life of the process).
    with fits.open(str(dirty_path)) as dirty_hdul:
        dirty_header = dirty_hdul[0].header.copy()
        dirty_raw = np.array(dirty_hdul[0].data)
    with fits.open(str(psf_path)) as psf_hdul:
        psf_raw = np.array(psf_hdul[0].data)

    # Explicit check instead of ``assert`` so it survives ``python -O``.
    if psf_raw.shape != dirty_raw.shape:
        raise ValueError("PSF shape {} does not match dirty image shape {}".format(
            psf_raw.shape, dirty_raw.shape))

    big_data = dirty_raw.squeeze()[:, :, np.newaxis]

    # we need a smaller PSF to give as a channel to the dirty tiles
    big_psf_data = psf_raw.squeeze()
    big_psf_data = big_psf_data / big_psf_data.max()
    row_start = big_psf_data.shape[0] // 2 - SIZE // 2 + 1
    row_end = big_psf_data.shape[0] // 2 + SIZE // 2 + 1
    col_start = big_psf_data.shape[1] // 2 - SIZE // 2 + 1
    col_end = big_psf_data.shape[1] // 2 + SIZE // 2 + 1
    psf_small = big_psf_data[row_start:row_end, col_start:col_end]
    logger.debug(psf_small.shape)
    logger.debug((row_start, row_end, col_start, col_end))
    psf_small = psf_small[:, :, np.newaxis]

    n_r = int(big_data.shape[0] / stride)
    n_c = int(big_data.shape[1] / stride)

    # set up the data loading
    batch, count = load_data(big_data, psf_small, n_r, n_c)
    steps_per_epoch = count
    iterator = batch.make_one_shot_iterator()
    index, min_flux, max_flux, psf, dirty = iterator.get_next()
    scaled_dirty = preprocess(dirty, min_flux, max_flux)
    scaled_psf = (psf * 2) - 1
    input_ = tf.concat([scaled_dirty, scaled_psf], axis=3)

    # set up the network
    with tf.variable_scope("generator"):
        outputs = create_generator(input_, 1, NGF, SEPERABLE_CONV)
        deprocessed_output = deprocess(outputs, min_flux, max_flux)

    # run all data through the network
    queue_ = IterableQueue()
    with tf.Session() as sess:
        logger.info("restoring data from checkpoint " + a.checkpoint)
        checkpoint = tf.train.latest_checkpoint(a.checkpoint)
        tf.train.Saver().restore(sess, checkpoint)
        for step in range(steps_per_epoch):
            n = sess.run(deprocessed_output)
            queue_.put(n)

    # reconstruct the data
    big_model = restore(big_data.squeeze().shape, iter(queue_), n_r, n_c)
    p = big_psf_data.shape[0]
    # r = slice(p // 2, -p // 2 + 1)  # uneven PSF needs +2, even psf +1
    # NOTE: ``-p // 2`` is ``(-p) // 2`` (floor division of the negated size).
    r = slice(p // 2 + 1, -p // 2 + 2)
    convolved = fftconvolve(big_model, big_psf_data, mode="full")[r, r]
    residual = dirty_raw.squeeze() - convolved

    # write the data
    for data, filename in ((big_model.squeeze(), "vacuum-model.fits"),
                           (residual.squeeze(), "vacuum-residual.fits")):
        hdu = fits.PrimaryHDU(data)
        hdu.header = dirty_header
        hdul = fits.HDUList([hdu])
        hdul.writeto(filename, overwrite=True)

    logger.info("done!")
def create_model(inputs, targets, EPS, separable_conv, ngf, ndf, gan_weight, l1_weight, res_weight, lr, beta1, psf,
                 min_flux, max_flux):
    # type: (tf.Tensor, tf.Tensor, float, bool, int, int, float, float, float, float, float, tf.Tensor, tf.Tensor, tf.Tensor) -> Model
    """Build the conditional-GAN training graph: generator, discriminator, losses and optimizers.

    The generator maps ``inputs`` to a one-channel ``outputs`` tensor. A
    discriminator (one set of variables, used twice via ``reuse=True``)
    scores (inputs, targets) and (inputs, outputs) pairs. Generator loss is
    ``gan_weight * GAN + l1_weight * L1 + res_weight * RES``. RES is the
    mean absolute deviation from its own mean of a PSF-convolution residual.
    Both nets are trained with Adam(lr, beta1). The generator step runs only
    after the discriminator step (control dependency).

    Args:
        inputs: generator/discriminator conditioning input.
        targets: ground-truth image the generator should reproduce.
        EPS: small constant added inside ``tf.log`` to avoid log(0).
        separable_conv: forwarded to ``create_generator``.
        ngf, ndf: forwarded to ``create_generator`` / ``create_discriminator``
            (presumably filter counts — see those functions).
        gan_weight, l1_weight, res_weight: weights of the three generator loss terms.
        lr, beta1: Adam learning rate and beta1.
        psf: PSF tensor used to build the residual convolution filter.
        min_flux, max_flux: tensors passed to ``deprocess``.

    Returns:
        ``Model`` whose loss fields are ``ExponentialMovingAverage`` averages
        (decay 0.99). Its ``train`` op groups the EMA update, the global-step
        increment and the generator (and thus discriminator) train step.
    """
    with tf.variable_scope("generator"):
        out_channels = 1
        outputs = create_generator(inputs, out_channels, ngf, separable_conv)

    # create two copies of discriminator, one for real pairs and one for fake pairs
    # they share the same underlying variables
    with tf.name_scope("real_discriminator"):
        with tf.variable_scope("discriminator"):
            # 2x [batch, height, width, channels] => [batch, 30, 30, 1]
            # (shape inherited from pix2pix comments; depends on input size — unverified here)
            predict_real = create_discriminator(inputs, targets, ndf)

    with tf.name_scope("fake_discriminator"):
        with tf.variable_scope("discriminator", reuse=True):
            # 2x [batch, height, width, channels] => [batch, 30, 30, 1]
            predict_fake = create_discriminator(inputs, outputs, ndf)

    with tf.name_scope("discriminator_loss"):
        # minimizing -tf.log will try to get inputs to 1
        # predict_real => 1
        # predict_fake => 0
        discrim_loss = tf.reduce_mean(-(tf.log(predict_real + EPS) + tf.log(1 - predict_fake + EPS)))

    # Despite the scope name, this residual only feeds the *generator* RES loss.
    with tf.name_scope("discriminator_residuals"):
        deprocessed_output = deprocess(outputs, min_flux, max_flux)
        shifted = shift(psf, y=-1, x=-1)
        # tf.squeeze removes every size-1 dim; presumably assumes batch size 1 — confirm.
        filter_ = tf.expand_dims(tf.expand_dims(tf.squeeze(shifted), 2), 3)
        convolved = tf.nn.conv2d(deprocessed_output, filter_, [1, 1, 1, 1], "SAME")
        # NOTE(review): callers compute residuals as ``dirty - convolved`` in
        # deprocessed units, but here the (caller-preprocessed) ``targets`` are
        # used — confirm this mix of scaled targets and deprocessed output is intended.
        residuals = targets - convolved

    with tf.name_scope("generator_loss"):
        # predict_fake => 1
        # abs(targets - outputs) => 0
        gen_loss_GAN = tf.reduce_mean(-tf.log(predict_fake + EPS))
        gen_loss_L1 = tf.reduce_mean(tf.abs(targets - outputs))
        # mean absolute deviation of the residual around its own mean
        gen_loss_RES = tf.reduce_mean(
            tf.abs(residuals - tf.reduce_mean(residuals)))
        gen_loss = gen_loss_GAN * gan_weight + gen_loss_L1 * l1_weight + gen_loss_RES * res_weight

    with tf.name_scope("discriminator_train"):
        # variables are selected by the "discriminator" variable-scope prefix
        discrim_tvars = [
            var for var in tf.trainable_variables()
            if var.name.startswith("discriminator")
        ]
        discrim_optim = tf.train.AdamOptimizer(lr, beta1)
        discrim_grads_and_vars = discrim_optim.compute_gradients(
            discrim_loss, var_list=discrim_tvars)
        discrim_train = discrim_optim.apply_gradients(discrim_grads_and_vars)

    with tf.name_scope("generator_train"):
        # generator update is ordered after the discriminator update
        with tf.control_dependencies([discrim_train]):
            gen_tvars = [
                var for var in tf.trainable_variables()
                if var.name.startswith("generator")
            ]
            gen_optim = tf.train.AdamOptimizer(lr, beta1)
            gen_grads_and_vars = gen_optim.compute_gradients(
                gen_loss, var_list=gen_tvars)
            gen_train = gen_optim.apply_gradients(gen_grads_and_vars)

    # smoothed losses for reporting; the Model exposes only these averages
    ema = tf.train.ExponentialMovingAverage(decay=0.99)
    update_losses = ema.apply(
        [discrim_loss, gen_loss_GAN, gen_loss_L1, gen_loss_RES])

    global_step = tf.train.get_or_create_global_step()
    incr_global_step = tf.assign(global_step, global_step + 1)

    return Model(
        predict_real=predict_real,
        predict_fake=predict_fake,
        discrim_loss=ema.average(discrim_loss),
        discrim_grads_and_vars=discrim_grads_and_vars,
        gen_loss_GAN=ema.average(gen_loss_GAN),
        gen_loss_L1=ema.average(gen_loss_L1),
        gen_loss_RES=ema.average(gen_loss_RES),
        gen_grads_and_vars=gen_grads_and_vars,
        outputs=outputs,
        train=tf.group(update_losses, incr_global_step, gen_train),
    )
def test(
        input_dir,
        output_dir,
        checkpoint,
        batch_size=1,
        test_start=0,
        test_end=999,
        disable_psf=False,
        ngf=64,
        separable_conv=False,
        write_residuals=False,
        write_input=False,
):
    """Run a trained generator once over a slice of ``input_dir`` and write FITS results.

    Args:
        input_dir: dataset directory passed to ``load_data``.
        output_dir: destination passed to ``save_images`` (files go into its "fits" subfolder).
        checkpoint: checkpoint path restored directly with the Supervisor's saver.
        batch_size: examples per step.
        test_start, test_end: index range forwarded to ``load_data``.
        disable_psf: feed only the scaled dirty image instead of dirty + PSF.
        ngf, separable_conv: generator hyper-parameters; must match the checkpoint.
        write_residuals: also write ``dirty - conv2d(output, shifted psf)``.
        write_input: also write the raw dirty input.
    """
    batch, count = load_data(path=input_dir, flip=False, crop_size=CROP_SIZE, scale_size=CROP_SIZE,
                             max_epochs=1, batch_size=batch_size, start=test_start, end=test_end)
    steps_per_epoch = int(math.ceil(count / batch_size))
    iter_ = batch.make_one_shot_iterator()
    index, min_flux, max_flux, psf, dirty, skymodel = iter_.get_next()
    print("test count = %d" % count)

    with tf.name_scope("scaling_flux"):
        scaled_dirty = preprocess(dirty, min_flux, max_flux)

    if disable_psf:
        input_ = scaled_dirty
    else:
        # NOTE(review): unlike the disable_psf branch and the training scripts,
        # this feeds the *unscaled* dirty image and the raw (cropped) PSF.
        # The 128-pixel crop also presumes a larger PSF from load_data. Confirm
        # this matches how the checkpoint was trained.
        input_ = tf.concat([dirty, psf[:, 128:-128, 128:-128, :]], axis=3)

    with tf.variable_scope("generator"):
        generator = create_generator(input_, 1, ngf=ngf, separable_conv=separable_conv)
        deprocessed_output = deprocess(generator, min_flux, max_flux)

    if write_residuals:
        with tf.name_scope("calculate_residuals"):
            shifted = shift(psf, y=-1, x=-1)
            filter_ = tf.expand_dims(tf.expand_dims(tf.squeeze(shifted), 2), 3)
            convolved = tf.nn.conv2d(deprocessed_output, filter_, [1, 1, 1, 1], "SAME")
            residuals = dirty - convolved

    with tf.name_scope("encode_fitss"):
        work = {
            "indexs": index,
            "outputs": tf.map_fn(fits_encode, deprocessed_output, dtype=tf.string, name="output_fits"),
        }
        if write_residuals:
            work["residuals"] = tf.map_fn(fits_encode, residuals, dtype=tf.string, name="residuals_fits")
        if write_input:
            work["inputs"] = tf.map_fn(fits_encode, dirty, dtype=tf.string, name="input_fits")

    sv = tf.train.Supervisor(logdir=None)
    with sv.managed_session() as sess:
        sv.saver.restore(sess, checkpoint)

        for step in range(steps_per_epoch):
            results = sess.run(work)
            filesets = save_images(results, subfolder="fits", extention="fits", output_dir=output_dir)
            for f in filesets:
                print("wrote " + f['name'])
def main():
    """Evaluate a checkpoint on the test slice of ``a.input_dir`` and write FITS files.

    For each step it writes the dirty input, the generator output (passed
    through ``deprocess``) and the residual
    ``dirty - conv2d(output, shifted psf)`` into the "fits" subfolder of
    ``a.output_dir``. At most one pass over the test data is made, further
    limited by ``a.max_epochs`` / ``a.max_steps`` when set.

    Raises:
        ValueError: if ``a.checkpoint`` is set but contains no checkpoint.
    """
    prepare()

    batch, count = load_data(path=a.input_dir, flip=False, crop_size=CROP_SIZE, scale_size=CROP_SIZE,
                             max_epochs=1, batch_size=a.batch_size, start=a.test_start, end=a.test_end)
    steps_per_epoch = int(math.ceil(count / a.batch_size))
    iter_ = batch.make_one_shot_iterator()
    index, min_flux, max_flux, psf, dirty, skymodel = iter_.get_next()
    print("test count = %d" % count)

    with tf.name_scope("scaling_flux"):
        scaled_dirty = preprocess(dirty, min_flux, max_flux)
        scaled_psf = (psf * 2) - 1

    if a.disable_psf:
        input_ = scaled_dirty
    else:
        input_ = tf.concat([scaled_dirty, scaled_psf], axis=3)

    with tf.variable_scope("generator"):
        generator = create_generator(input_, 1, ngf=a.ngf, separable_conv=a.separable_conv)
        deprocessed_output = deprocess(generator, min_flux, max_flux)

    with tf.name_scope("calculate_residuals"):
        shifted = shift(psf, y=-1, x=-1)
        filter_ = tf.expand_dims(tf.expand_dims(tf.squeeze(shifted), 2), 3)
        convolved = tf.nn.conv2d(deprocessed_output, filter_, [1, 1, 1, 1], "SAME")
        residuals = dirty - convolved

    with tf.name_scope("encode_fitss"):
        fits_fetches = {
            "indexs": index,
            "inputs": tf.map_fn(fits_encode, dirty, dtype=tf.string, name="input_fits"),
            "outputs": tf.map_fn(fits_encode, deprocessed_output, dtype=tf.string, name="output_fits"),
            "residuals": tf.map_fn(fits_encode, residuals, dtype=tf.string, name="residuals_fits"),
        }

    with tf.name_scope("parameter_count"):
        parameter_count = tf.reduce_sum([tf.reduce_prod(tf.shape(v)) for v in tf.trainable_variables()])

    saver = tf.train.Saver(max_to_keep=100)
    sv = tf.train.Supervisor(logdir=None, save_summaries_secs=0, saver=None)
    with sv.managed_session() as sess:
        print("parameter_count =", sess.run(parameter_count))

        if a.checkpoint is not None:
            print("loading model from checkpoint")
            checkpoint = tf.train.latest_checkpoint(a.checkpoint)
            # latest_checkpoint returns None for an empty/missing directory;
            # Saver.restore(None) would fail with an unhelpful message.
            if checkpoint is None:
                raise ValueError("no checkpoint found in {}".format(a.checkpoint))
            saver.restore(sess, checkpoint)
            # report only after the restore actually succeeded
            print("loaded {}".format(checkpoint))

        max_steps = 2 ** 32
        if a.max_epochs is not None:
            max_steps = steps_per_epoch * a.max_epochs
        if a.max_steps is not None:
            max_steps = a.max_steps

        # at most, process the test data once
        max_steps = min(steps_per_epoch, max_steps)

        # repeat the same for fits arrays
        for step in range(max_steps):
            results = sess.run(fits_fetches)
            filesets = save_images(results, subfolder="fits", extention="fits", output_dir=a.output_dir)
            for f in filesets:
                print("wrote " + f['name'])
def main():
    """Export the generator graph plus restored weights for serving.

    Parses ``--output_dir``, ``--checkpoint``, ``--separable_conv``, ``--ngf`` and
    ``--disable_psf``. It builds a generator graph fed by a one-example dataset
    and records the input/output tensor names as JSON in the "inputs" and
    "outputs" collections. It then restores the latest checkpoint from
    ``--checkpoint`` and saves it (with meta graph) as ``<output_dir>/export``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_dir", required=True, help="where to put output files")
    parser.add_argument("--checkpoint", required=True,
                        help="directory with checkpoint to resume training from or use for testing")
    parser.add_argument("--separable_conv", action="store_true", help="use separable convolutions in the generator")
    parser.add_argument("--ngf", type=int, default=64, help="number of generator filters in first conv layer")
    parser.add_argument('--disable_psf', action='store_true', help="disable the concatenation of the PSF as a channel")
    a = parser.parse_args()

    def load_data(dirty_path, psf_path):
        # type: (str, str) -> tf.data.Dataset
        # NOTE(review): main() passes tf.placeholder tensors here, and the Python
        # generator below would receive those tensor objects, not file names,
        # if it were ever run. The graph is only built and exported, never
        # executed, so confirm before relying on this pipeline at serve time.
        def dataset_generator():
            psf = fits_open(psf_path)[:, :, np.newaxis]
            dirty = fits_open(dirty_path)[:, :, np.newaxis]
            min_flux = dirty.min()
            max_flux = dirty.max()
            yield min_flux, max_flux, psf, dirty

        ds = tf.data.Dataset.from_generator(dataset_generator,
                                            output_shapes=((), ()) + ((256, 256, 1),) * 2,
                                            output_types=(tf.float32, tf.float32) + (tf.float32,) * 2
                                            )
        ds = ds.batch(1)
        return ds

    dirty_path = tf.placeholder(tf.string, shape=[1])
    psf_path = tf.placeholder(tf.string, shape=[1])
    batch = load_data(dirty_path, psf_path)

    iterator = batch.make_one_shot_iterator()  # avoid shadowing the builtin ``iter``
    min_flux, max_flux, psf, dirty = iterator.get_next()

    scaled_dirty = preprocess(dirty, min_flux, max_flux)
    # Same PSF scaling as every training/test graph in this project
    # (``(psf * 2) - 1``). The original scaled the PSF with the *dirty* image's
    # flux range, so the exported graph did not match the trained one.
    scaled_psf = (psf * 2) - 1

    if a.disable_psf:
        input_ = scaled_dirty
    else:
        input_ = tf.concat([scaled_dirty, scaled_psf], axis=3)

    with tf.variable_scope("generator"):
        generator = create_generator(input_, 1, ngf=a.ngf, separable_conv=a.separable_conv)
        batch_output = deprocess(generator, min_flux, max_flux)

    # NOTE(review): convert_image_dtype expects floats in [0, 1]. batch_output
    # has gone through deprocess(min_flux, max_flux), so values outside that
    # range will not map cleanly to uint8 — confirm the intended scaling.
    output_image = tf.image.convert_image_dtype(batch_output, dtype=tf.uint8)[0]

    # lets just assume png for now
    output_data = tf.image.encode_png(output_image)
    output = tf.convert_to_tensor([tf.encode_base64(output_data)])

    key = tf.placeholder(tf.string, shape=[1])
    inputs = {
        "key": key.name,
        "input": dirty.name
    }
    tf.add_to_collection("inputs", json.dumps(inputs))
    outputs = {
        "key": tf.identity(key).name,
        "output": output.name,
    }
    tf.add_to_collection("outputs", json.dumps(outputs))

    init_op = tf.global_variables_initializer()
    restore_saver = tf.train.Saver()
    export_saver = tf.train.Saver()

    with tf.Session() as sess:
        sess.run(init_op)
        print("loading model from checkpoint")
        checkpoint = tf.train.latest_checkpoint(a.checkpoint)
        restore_saver.restore(sess, checkpoint)
        print("exporting model")
        # export_saver.export_meta_graph(filename=os.path.join(a.output_dir, "export.meta"))
        export_saver.save(sess, os.path.join(a.output_dir, "export"), write_meta_graph=True)  # , save_relative_paths=True)