def main():
    """Run a trained generator over dirty/PSF FITS pairs from the command line.

    Expects exactly two arguments: a comma-separated list of dirty image
    paths and a comma-separated list of PSF paths, paired by position. The
    newest checkpoint found under ``a.checkpoint`` is restored, and for every
    batch the generator output and the residual (dirty image minus the output
    convolved with the PSF) are FITS-encoded and passed to ``save_images``
    with ``output_dir=a.output_dir``.

    Exits with a non-zero status when the argument count is wrong, when the
    two lists differ in length, or when no checkpoint can be found.
    """
    if len(sys.argv) != 3:
        # Adjacent literals concatenate to exactly the text printed before.
        usage = (
            " usage: {} dirty-0.fits,dirty-1.fits,dirty-2.fits "
            "psf-0.fits,psf-1.fits,psf2.fits "
            "note: names don't matter, order does. "
            "only supports fits files of {}x{} "
            "will write output the current folder. "
        )
        print(usage.format(sys.argv[0], CROP_SIZE, CROP_SIZE))
        sys.exit(1)

    dirties = [os.path.realpath(i) for i in sys.argv[1].split(',')]
    psfs = [os.path.realpath(i) for i in sys.argv[2].split(',')]
    if len(dirties) != len(psfs):
        # Explicit check instead of `assert`: asserts are stripped under
        # `python -O`, which would let mismatched lists through silently.
        sys.exit("error: got {} dirty images but {} psfs; "
                 "the two lists must have the same length".format(
                     len(dirties), len(psfs)))

    batch, count = load_data(dirties, psfs)
    steps_per_epoch = count
    iterator = batch.make_one_shot_iterator()
    index, min_flux, max_flux, psf, dirty = iterator.get_next()

    scaled_dirty = preprocess(dirty, min_flux, max_flux)
    scaled_psf = (psf * 2) - 1
    input_ = tf.concat([scaled_dirty, scaled_psf], axis=3)

    # set up the network
    with tf.variable_scope("generator"):
        outputs = create_generator(input_, 1, a.ngf, a.separable_conv)
    deprocessed_output = deprocess(outputs, min_flux, max_flux)

    with tf.name_scope("calculate_residuals"):
        shifted = shift(psf, y=-1, x=-1)
        # NOTE(review): squeeze + two expand_dims only yields a valid
        # [height, width, 1, 1] conv filter if psf holds a single
        # one-channel image per batch -- confirm against load_data.
        filter_ = tf.expand_dims(tf.expand_dims(tf.squeeze(shifted), 2), 3)
        convolved = tf.nn.conv2d(deprocessed_output, filter_, [1, 1, 1, 1],
                                 "SAME")
        residuals = dirty - convolved

    with tf.name_scope("encode_fitss"):
        fits_fetches = {
            "indexs": index,
            "outputs": tf.map_fn(fits_encode, deprocessed_output,
                                 dtype=tf.string, name="output_fits"),
            "residuals": tf.map_fn(fits_encode, residuals,
                                   dtype=tf.string, name="residuals_fits"),
        }

    # latest_checkpoint returns None when nothing is found; fail with a
    # readable message instead of an opaque error from Saver.restore.
    checkpoint = tf.train.latest_checkpoint(a.checkpoint)
    if checkpoint is None:
        sys.exit("error: no checkpoint found in {}".format(a.checkpoint))

    with tf.Session() as sess:
        tf.train.Saver().restore(sess, checkpoint)

        for _ in range(steps_per_epoch):
            results = sess.run(fits_fetches)
            filesets = save_images(results, subfolder=None, extention="fits",
                                   output_dir=a.output_dir)
            for f in filesets:
                print("wrote " + f['name'])
def main():
    """Train the model on batches produced by ``generative_model``.

    Builds the input pipeline, the model, PNG display fetches and
    TensorBoard summaries, then runs ``a.max_steps`` training steps inside a
    ``tf.train.Supervisor`` managed session. The ``a.trace_freq``,
    ``a.progress_freq``, ``a.summary_freq``, ``a.display_freq`` and
    ``a.save_freq`` options decide on which steps the matching output is
    produced; a frequency of 0 disables that output.
    """
    prepare()

    train_batch = generative_model(a.psf_glob,
                                   flux_scale_min=a.flux_scale_min,
                                   flux_scale_max=a.flux_scale_max)
    iterator = train_batch.make_one_shot_iterator()
    index, min_flux, max_flux, psf, dirty, skymodel = iterator.get_next()

    with tf.name_scope("scaling_flux"):
        scaled_skymodel = preprocess(skymodel, min_flux, max_flux)
        scaled_dirty = preprocess(dirty, min_flux, max_flux)
        scaled_psf = (psf * 2) - 1

    # inputs and targets are [batch_size, height, width, channels]
    model = create_model(scaled_dirty, scaled_skymodel, EPS, a.separable_conv,
                         beta1=a.beta1, gan_weight=a.gan_weight,
                         l1_weight=a.l1_weight, lr=a.lr, ndf=a.ndf,
                         ngf=a.ngf, psf=psf, min_flux=min_flux,
                         max_flux=max_flux, res_weight=a.res_weight,
                         disable_psf=a.disable_psf)

    deprocessed_output = deprocess(model.outputs, min_flux, max_flux)

    with tf.name_scope("calculate_residuals"):
        shifted = shift(psf, y=-1, x=-1)
        # NOTE(review): squeeze + two expand_dims only yields a valid
        # [height, width, 1, 1] conv filter for a single one-channel psf
        # per batch -- confirm against generative_model.
        filter_ = tf.expand_dims(tf.expand_dims(tf.squeeze(shifted), 2), 3)
        convolved = tf.nn.conv2d(deprocessed_output, filter_, [1, 1, 1, 1],
                                 "SAME")
        residuals = dirty - convolved

    # reverse any processing on images so they can be written to disk or
    # displayed to user
    with tf.name_scope("convert_images"):
        converted_inputs = tf.image.convert_image_dtype(
            visual_scaling(scaled_dirty), dtype=tf.uint8, saturate=True)
        converted_targets = tf.image.convert_image_dtype(
            visual_scaling(scaled_skymodel), dtype=tf.uint8, saturate=True)
        converted_outputs = tf.image.convert_image_dtype(
            visual_scaling(model.outputs), dtype=tf.uint8, saturate=True)
        converted_psfs = tf.image.convert_image_dtype(
            visual_scaling(scaled_psf), dtype=tf.uint8, saturate=True)
        converted_residuals = tf.image.convert_image_dtype(
            visual_scaling(residuals), dtype=tf.uint8, saturate=True)

    with tf.name_scope("encode_images"):
        display_fetches = {
            "indexs": index,
            "inputs": tf.map_fn(tf.image.encode_png, converted_inputs,
                                dtype=tf.string, name="input_pngs"),
            "targets": tf.map_fn(tf.image.encode_png, converted_targets,
                                 dtype=tf.string, name="target_pngs"),
            "outputs": tf.map_fn(tf.image.encode_png, converted_outputs,
                                 dtype=tf.string, name="output_pngs"),
            "psfs": tf.map_fn(tf.image.encode_png, converted_psfs,
                              dtype=tf.string, name="psf_pngs"),
            "residuals": tf.map_fn(tf.image.encode_png, converted_residuals,
                                   dtype=tf.string, name="residual_pngs"),
        }

    # summaries
    with tf.name_scope("combined_summary"):
        tf.summary.image("inputs", converted_inputs)
        tf.summary.image("outputs", converted_outputs)
        tf.summary.image("targets", converted_targets)
        tf.summary.image("residuals", converted_residuals)

    with tf.name_scope("psfs_summary"):
        tf.summary.image("psfss", converted_psfs)

    with tf.name_scope("predict_real_summary"):
        tf.summary.image(
            "predict_real",
            tf.image.convert_image_dtype(model.predict_real, dtype=tf.uint8))

    with tf.name_scope("predict_fake_summary"):
        tf.summary.image(
            "predict_fake",
            tf.image.convert_image_dtype(model.predict_fake, dtype=tf.uint8))

    tf.summary.scalar("discriminator_loss", model.discrim_loss)
    tf.summary.scalar("generator_loss_GAN", model.gen_loss_GAN)
    tf.summary.scalar("generator_loss_L1", model.gen_loss_L1)
    tf.summary.scalar("generator_loss_RES", model.gen_loss_RES)

    with tf.name_scope("parameter_count"):
        parameter_count = tf.reduce_sum(
            [tf.reduce_prod(tf.shape(v)) for v in tf.trainable_variables()])

    logdir = a.output_dir if (a.trace_freq > 0 or a.summary_freq > 0) else None

    train_summary_op = tf.summary.merge_all()
    # logdir is None when both tracing and summaries are disabled; the old
    # unconditional `logdir + '/train'` raised TypeError in that case. The
    # writer is only used on trace/summary steps, which imply logdir is set.
    train_summary_writer = None
    if logdir is not None:
        train_summary_writer = tf.summary.FileWriter(
            logdir=os.path.join(logdir, 'train'))

    saver = tf.train.Saver(max_to_keep=100)
    sv = tf.train.Supervisor(logdir=logdir, save_summaries_secs=0,
                             saver=saver, summary_writer=None,
                             summary_op=None)

    def should(freq, step):
        """Return True when output with period ``freq`` is due at ``step``."""
        return freq > 0 and ((step + 1) % freq == 0
                             or step == a.max_steps - 1)

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.6)
    with sv.managed_session(config=tf.ConfigProto(
            gpu_options=gpu_options)) as sess:
        print("parameter_count =", sess.run(parameter_count))

        start = time.time()

        for step in range(a.max_steps):
            do_trace = should(a.trace_freq, step)
            do_progress = should(a.progress_freq, step)
            do_summary = should(a.summary_freq, step)
            do_display = should(a.display_freq, step)

            options = None
            run_metadata = None
            if do_trace:
                print("preparing")
                options = tf.RunOptions(
                    trace_level=tf.RunOptions.FULL_TRACE)
                run_metadata = tf.RunMetadata()

            fetches = {
                "train": model.train,
                "global_step": sv.global_step,
            }

            if do_progress:
                print("progress step")
                fetches["discrim_loss"] = model.discrim_loss
                fetches["gen_loss_GAN"] = model.gen_loss_GAN
                fetches["gen_loss_L1"] = model.gen_loss_L1
                fetches["gen_loss_RES"] = model.gen_loss_RES

            if do_summary:
                print("preparing summary")
                fetches["summary"] = train_summary_op

            if do_display:
                print("display step step")
                fetches["display"] = display_fetches

            results = sess.run(fetches, options=options,
                               run_metadata=run_metadata)

            if do_summary:
                print("recording summary")
                train_summary_writer.add_summary(results["summary"],
                                                 results["global_step"])

            if do_display:
                print("saving display images")
                save_images(results["display"], step=results["global_step"],
                            output_dir=a.output_dir)

            if do_trace:
                print("recording trace")
                train_summary_writer.add_run_metadata(
                    run_metadata, "step_%d" % results["global_step"])

            if do_progress:
                rate = (step + 1) * a.batch_size / (time.time() - start)
                remaining = (a.max_steps - step) * a.batch_size / rate
                print("progress step %d image/sec %0.1f remaining %dm" %
                      (results["global_step"], rate, remaining / 60))
                print("discrim_loss", results["discrim_loss"])
                print("gen_loss_GAN", results["gen_loss_GAN"])
                print("gen_loss_L1", results["gen_loss_L1"])
                print("gen_loss_RES", results["gen_loss_RES"])

            if should(a.save_freq, step):
                print("saving model")
                saver.save(sess, os.path.join(a.output_dir, "model"),
                           global_step=sv.global_step)

            if sv.should_stop():
                print("supervisor things we should stop!")
                break

    print("done! bye")
def main():
    """Train the model on data returned by ``load_data``.

    ``load_data`` is called with ``a.input_dir`` and the
    ``a.train_start`` / ``a.train_end`` bounds. The function then builds the
    model, the PNG display fetches and the TensorBoard summaries (including
    per-variable value and gradient histograms) and runs the training loop
    inside a ``tf.train.Supervisor`` managed session.

    The number of steps is ``a.max_steps`` when set, otherwise
    ``steps_per_epoch * a.max_epochs`` when ``a.max_epochs`` is set,
    otherwise ``2**32``. The ``a.*_freq`` options choose the steps on which
    traces, progress lines, summaries, display images and checkpoints are
    produced; a frequency of 0 disables that output.
    """
    prepare()

    train_batch, train_count = load_data(
        a.input_dir, CROP_SIZE, a.flip, a.scale_size, a.max_epochs,
        a.batch_size, start=a.train_start, end=a.train_end, loop=True)
    print("train count = %d" % train_count)
    steps_per_epoch = int(math.ceil(train_count / a.batch_size))

    batch_tensors = train_batch.make_one_shot_iterator().get_next()
    index, min_flux, max_flux, psf, dirty, skymodel = batch_tensors

    with tf.name_scope("scaling_flux"):
        scaled_skymodel = preprocess(skymodel, min_flux, max_flux)
        scaled_dirty = preprocess(dirty, min_flux, max_flux)
        scaled_psf = (psf * 2) - 1

    # inputs and targets are [batch_size, height, width, channels]
    model = create_model(
        scaled_dirty, scaled_skymodel, EPS, a.separable_conv,
        beta1=a.beta1, gan_weight=a.gan_weight, l1_weight=a.l1_weight,
        lr=a.lr, ndf=a.ndf, ngf=a.ngf, psf=scaled_psf, min_flux=min_flux,
        max_flux=max_flux, res_weight=a.res_weight)

    def to_uint8(images):
        """Apply visual_scaling, then convert to saturated uint8."""
        return tf.image.convert_image_dtype(
            visual_scaling(images), dtype=tf.uint8, saturate=True)

    def png_batch(images, name):
        """PNG-encode every image of a batch into a string tensor."""
        return tf.map_fn(tf.image.encode_png, images, dtype=tf.string,
                         name=name)

    # reverse any processing on images so they can be written to disk or
    # displayed to user
    with tf.name_scope("convert_images"):
        converted_inputs = to_uint8(scaled_dirty)
        converted_targets = to_uint8(scaled_skymodel)
        converted_outputs = to_uint8(model.outputs)
        converted_psfs = to_uint8(scaled_psf)
        converted_residuals = to_uint8(model.residuals)
        converted_likelihood = to_uint8(model.likelihood)

    with tf.name_scope("encode_images"):
        display_fetches = {
            "indexs": index,
            "inputs": png_batch(converted_inputs, "input_pngs"),
            "targets": png_batch(converted_targets, "target_pngs"),
            "outputs": png_batch(converted_outputs, "output_pngs"),
            "psfs": png_batch(converted_psfs, "psf_pngs"),
            "residuals": png_batch(converted_residuals, "residual_pngs"),
            "likelihood": png_batch(converted_likelihood, "likelihood_pngs"),
        }

    # summaries
    with tf.name_scope("combined_summary"):
        tf.summary.image("inputs", converted_inputs)
        tf.summary.image("outputs", converted_outputs)
        tf.summary.image("targets", converted_targets)
        tf.summary.image("residuals", converted_residuals)
        tf.summary.image("likelihood", converted_likelihood)
        tf.summary.image("psfss", converted_psfs)

    with tf.name_scope("predict_summary"):
        real_as_uint8 = tf.image.convert_image_dtype(model.predict_real,
                                                     dtype=tf.uint8)
        tf.summary.image("predict_real", real_as_uint8)
        fake_as_uint8 = tf.image.convert_image_dtype(model.predict_fake,
                                                     dtype=tf.uint8)
        tf.summary.image("predict_fake", fake_as_uint8)

    with tf.name_scope("generator_scalars"):
        tf.summary.scalar("generator_loss_GAN", model.gen_loss_GAN)
        tf.summary.scalar("generator_loss_L1", model.gen_loss_L1)
        tf.summary.scalar("generator_loss_RES", model.gen_loss_RES)

    with tf.name_scope("discriminator_scalars"):
        tf.summary.scalar("discriminator_loss", model.discrim_loss)

    for variable in tf.trainable_variables():
        tf.summary.histogram(variable.op.name + "/values", variable)

    grads_and_vars = model.discrim_grads_and_vars + model.gen_grads_and_vars
    for gradient, variable in grads_and_vars:
        tf.summary.histogram(variable.op.name + "/gradients", gradient)

    with tf.name_scope("parameter_count"):
        sizes = [tf.reduce_prod(tf.shape(v))
                 for v in tf.trainable_variables()]
        parameter_count = tf.reduce_sum(sizes)

    if a.trace_freq > 0 or a.summary_freq > 0:
        logdir = a.output_dir
    else:
        logdir = None

    saver = tf.train.Saver(max_to_keep=0)
    sv = tf.train.Supervisor(logdir=logdir, save_summaries_secs=0,
                             save_model_secs=0, saver=saver)

    with sv.managed_session() as sess:
        print("parameter_count =", sess.run(parameter_count))

        # a.max_steps wins over a.max_epochs; with neither set the loop is
        # effectively unbounded.
        if a.max_steps is not None:
            max_steps = a.max_steps
        elif a.max_epochs is not None:
            max_steps = steps_per_epoch * a.max_epochs
        else:
            max_steps = 2**32

        def is_due(freq, step):
            """True when output with period ``freq`` is due at ``step``."""
            return freq > 0 and ((step + 1) % freq == 0
                                 or step == max_steps - 1)

        start = time.time()

        for step in range(max_steps):
            trace_now = is_due(a.trace_freq, step)
            progress_now = is_due(a.progress_freq, step)
            summary_now = is_due(a.summary_freq, step)
            display_now = is_due(a.display_freq, step)

            options = None
            run_metadata = None
            if trace_now:
                print("preparing")
                options = tf.RunOptions(
                    trace_level=tf.RunOptions.FULL_TRACE)
                run_metadata = tf.RunMetadata()

            fetches = {
                "train": model.train,
                "global_step": sv.global_step,
            }

            if progress_now:
                print("progress step")
                fetches["discrim_loss"] = model.discrim_loss
                fetches["gen_loss_GAN"] = model.gen_loss_GAN
                fetches["gen_loss_L1"] = model.gen_loss_L1
                fetches["gen_loss_RES"] = model.gen_loss_RES

            if summary_now:
                print("preparing summary")
                fetches["summary"] = sv.summary_op

            if display_now:
                print("display step step")
                fetches["display"] = display_fetches

            results = sess.run(fetches, options=options,
                               run_metadata=run_metadata)
            global_step = results["global_step"]

            if summary_now:
                print("recording summary")
                sv.summary_writer.add_summary(results["summary"],
                                              global_step)

            if display_now:
                print("saving display images")
                save_images(results["display"], step=global_step,
                            output_dir=a.output_dir)

            if trace_now:
                print("recording trace")
                sv.summary_writer.add_run_metadata(
                    run_metadata, "step_%d" % global_step)

            if progress_now:
                # global_step will have the correct step count if we resume
                # from a checkpoint
                train_epoch = math.ceil(global_step / steps_per_epoch)
                train_step = (global_step - 1) % steps_per_epoch + 1
                rate = (step + 1) * a.batch_size / (time.time() - start)
                remaining = (max_steps - step) * a.batch_size / rate
                print(
                    "progress epoch %d step %d image/sec %0.1f remaining %dm"
                    % (train_epoch, train_step, rate, remaining / 60))
                for loss_name in ("discrim_loss", "gen_loss_GAN",
                                  "gen_loss_L1", "gen_loss_RES"):
                    print(loss_name, results[loss_name])

            if is_due(a.save_freq, step):
                print("saving model")
                sv.saver.save(sess, os.path.join(a.output_dir, "model"),
                              global_step=sv.global_step)

            if sv.should_stop():
                print("supervisor things we should stop!")
                break

    print("done! bye")
def main():
    """Train the model with periodic evaluation on a validation set.

    Two datasets are obtained from ``load_data`` (bounded by
    ``a.train_start`` / ``a.train_end`` and ``a.validate_start`` /
    ``a.validate_end``). Both feed one graph through a string-handle
    iterator, so which dataset a ``sess.run`` reads from is chosen by the
    handle in its ``feed_dict``.

    The number of steps is ``a.max_steps`` when set, otherwise
    ``steps_per_epoch * a.max_epochs`` when ``a.max_epochs`` is set,
    otherwise ``2**32``. The ``a.*_freq`` options choose the steps on which
    traces, progress lines, summaries, display images, checkpoints and
    validation runs happen; a frequency of 0 disables that output.
    """
    prepare()

    train_batch, train_count = load_data(
        a.input_dir, CROP_SIZE, a.flip, a.scale_size, a.max_epochs,
        a.batch_size, start=a.train_start, end=a.train_end, loop=True)
    print("train count = %d" % train_count)
    steps_per_epoch = int(math.ceil(train_count / a.batch_size))
    training_iterator = train_batch.make_one_shot_iterator()

    validate_batch, validate_count = load_data(
        a.input_dir, CROP_SIZE, a.flip, a.scale_size, a.max_epochs,
        a.batch_size, start=a.validate_start, end=a.validate_end, loop=True)
    print("validate count = %d" % validate_count)
    validation_iterator = validate_batch.make_one_shot_iterator()

    handle = tf.placeholder(tf.string, shape=[])
    iterator = tf.data.Iterator.from_string_handle(
        handle, train_batch.output_types, train_batch.output_shapes)
    index, min_flux, max_flux, psf, dirty, skymodel = iterator.get_next()

    with tf.name_scope("scaling_flux"):
        scaled_skymodel = preprocess(skymodel, min_flux, max_flux)
        scaled_dirty = preprocess(dirty, min_flux, max_flux)
        scaled_psf = (psf * 2) - 1

    ## i'll just leave this here in case we decide to do something with visibilities again (unlikely)
    # vis = tf.fft2d(tf.complex(scaled_dirty, tf.zeros(shape=(a.batch_size, CROP_SIZE, CROP_SIZE, 1))))
    # real = tf.real(vis)
    # imag = tf.imag(vis)

    if a.disable_psf:
        input_ = scaled_dirty
    else:
        input_ = tf.concat([scaled_dirty, scaled_psf], axis=3)

    # inputs and targets are [batch_size, height, width, channels]
    model = create_model(input_, scaled_skymodel, EPS, a.separable_conv,
                         beta1=a.beta1, gan_weight=a.gan_weight,
                         l1_weight=a.l1_weight, lr=a.lr, ndf=a.ndf,
                         ngf=a.ngf, psf=psf, min_flux=min_flux,
                         max_flux=max_flux, res_weight=a.res_weight)

    deprocessed_output = deprocess(model.outputs, min_flux, max_flux)

    with tf.name_scope("calculate_residuals"):
        shifted = shift(psf, y=-1, x=-1)
        # NOTE(review): squeeze + two expand_dims only yields a valid
        # [height, width, 1, 1] conv filter for a single one-channel psf
        # per batch -- confirm against load_data / a.batch_size.
        filter_ = tf.expand_dims(tf.expand_dims(tf.squeeze(shifted), 2), 3)
        convolved = tf.nn.conv2d(deprocessed_output, filter_, [1, 1, 1, 1],
                                 "SAME")
        residuals = dirty - convolved

    # reverse any processing on images so they can be written to disk or
    # displayed to user
    with tf.name_scope("convert_images"):
        converted_inputs = tf.image.convert_image_dtype(
            visual_scaling(scaled_dirty), dtype=tf.uint8, saturate=True)
        converted_targets = tf.image.convert_image_dtype(
            visual_scaling(scaled_skymodel), dtype=tf.uint8, saturate=True)
        converted_outputs = tf.image.convert_image_dtype(
            visual_scaling(model.outputs), dtype=tf.uint8, saturate=True)
        converted_psfs = tf.image.convert_image_dtype(
            visual_scaling(scaled_psf), dtype=tf.uint8, saturate=True)
        converted_residuals = tf.image.convert_image_dtype(
            visual_scaling(residuals), dtype=tf.uint8, saturate=True)

    with tf.name_scope("encode_images"):
        display_fetches = {
            "indexs": index,
            "inputs": tf.map_fn(tf.image.encode_png, converted_inputs,
                                dtype=tf.string, name="input_pngs"),
            "targets": tf.map_fn(tf.image.encode_png, converted_targets,
                                 dtype=tf.string, name="target_pngs"),
            "outputs": tf.map_fn(tf.image.encode_png, converted_outputs,
                                 dtype=tf.string, name="output_pngs"),
            "psfs": tf.map_fn(tf.image.encode_png, converted_psfs,
                              dtype=tf.string, name="psf_pngs"),
            "residuals": tf.map_fn(tf.image.encode_png, converted_residuals,
                                   dtype=tf.string, name="residual_pngs"),
        }

    # summaries
    with tf.name_scope("combined_summary"):
        tf.summary.image("inputs", converted_inputs)
        tf.summary.image("outputs", converted_outputs)
        tf.summary.image("targets", converted_targets)
        tf.summary.image("residuals", converted_residuals)

    with tf.name_scope("psfs_summary"):
        tf.summary.image("psfss", converted_psfs)

    with tf.name_scope("predict_real_summary"):
        tf.summary.image(
            "predict_real",
            tf.image.convert_image_dtype(model.predict_real, dtype=tf.uint8))

    with tf.name_scope("predict_fake_summary"):
        tf.summary.image(
            "predict_fake",
            tf.image.convert_image_dtype(model.predict_fake, dtype=tf.uint8))

    tf.summary.scalar("discriminator_loss", model.discrim_loss)
    tf.summary.scalar("generator_loss_GAN", model.gen_loss_GAN)
    tf.summary.scalar("generator_loss_L1", model.gen_loss_L1)
    tf.summary.scalar("generator_loss_RES", model.gen_loss_RES)

    if a.validation_freq:
        tf.summary.scalar("Validation generator_loss_GAN", model.gen_loss_GAN)
        tf.summary.scalar("Validation generator_loss_L1", model.gen_loss_L1)

    # Per-variable histograms are intentionally disabled:
    # for var in tf.trainable_variables():
    #     tf.summary.histogram(var.op.name + "/values", var)
    # for grad, var in model.discrim_grads_and_vars + model.gen_grads_and_vars:
    #     tf.summary.histogram(var.op.name + "/gradients", grad)

    with tf.name_scope("parameter_count"):
        parameter_count = tf.reduce_sum(
            [tf.reduce_prod(tf.shape(v)) for v in tf.trainable_variables()])

    logdir = a.output_dir if (a.trace_freq > 0 or a.summary_freq > 0) else None

    # NOTE(review): both ops come from merge_all() over the same collection,
    # so the training and validation summaries contain identical tags and
    # differ only in which dataset handle is fed -- confirm this is intended.
    validation_summary_op = tf.summary.merge_all()
    train_summary_op = tf.summary.merge_all()

    # logdir is None when tracing and summaries are both disabled; the old
    # unconditional `logdir + '/...'` raised TypeError in that case.
    validation_summary_writer = None
    train_summary_writer = None
    if logdir is not None:
        validation_summary_writer = tf.summary.FileWriter(
            logdir=os.path.join(logdir, 'validation'))
        train_summary_writer = tf.summary.FileWriter(
            logdir=os.path.join(logdir, 'train'))

    saver = tf.train.Saver(max_to_keep=100)
    sv = tf.train.Supervisor(logdir=logdir, save_summaries_secs=0,
                             saver=saver, summary_writer=None,
                             summary_op=None)

    with sv.managed_session() as sess:
        print("parameter_count =", sess.run(parameter_count))

        # The `Iterator.string_handle()` method returns a tensor that can be
        # evaluated and used to feed the `handle` placeholder.
        training_handle = sess.run(training_iterator.string_handle())
        validation_handle = sess.run(validation_iterator.string_handle())

        max_steps = 2**32
        if a.max_epochs is not None:
            max_steps = steps_per_epoch * a.max_epochs
        if a.max_steps is not None:
            max_steps = a.max_steps

        def should(freq, step):
            """Return True when output with period ``freq`` is due."""
            return freq > 0 and ((step + 1) % freq == 0
                                 or step == max_steps - 1)

        start = time.time()

        for step in range(max_steps):
            do_trace = should(a.trace_freq, step)
            do_progress = should(a.progress_freq, step)
            do_summary = should(a.summary_freq, step)
            do_display = should(a.display_freq, step)

            options = None
            run_metadata = None
            if do_trace:
                print("preparing")
                options = tf.RunOptions(
                    trace_level=tf.RunOptions.FULL_TRACE)
                run_metadata = tf.RunMetadata()

            fetches = {
                "train": model.train,
                "global_step": sv.global_step,
            }

            if do_progress:
                print("progress step")
                fetches["discrim_loss"] = model.discrim_loss
                fetches["gen_loss_GAN"] = model.gen_loss_GAN
                fetches["gen_loss_L1"] = model.gen_loss_L1
                fetches["gen_loss_RES"] = model.gen_loss_RES

            if do_summary:
                print("preparing summary")
                fetches["summary"] = train_summary_op

            if do_display:
                print("display step step")
                fetches["display"] = display_fetches

            results = sess.run(fetches, options=options,
                               run_metadata=run_metadata,
                               feed_dict={handle: training_handle})

            if do_summary:
                print("recording summary")
                train_summary_writer.add_summary(results["summary"],
                                                 results["global_step"])

            if do_display:
                print("saving display images")
                save_images(results["display"], step=results["global_step"],
                            output_dir=a.output_dir)

            if do_trace:
                print("recording trace")
                train_summary_writer.add_run_metadata(
                    run_metadata, "step_%d" % results["global_step"])

            if do_progress:
                # global_step will have the correct step count if we resume
                # from a checkpoint
                train_epoch = math.ceil(
                    results["global_step"] / steps_per_epoch)
                train_step = (results["global_step"] - 1) % steps_per_epoch + 1
                rate = (step + 1) * a.batch_size / (time.time() - start)
                remaining = (max_steps - step) * a.batch_size / rate
                print(
                    "progress epoch %d step %d image/sec %0.1f remaining %dm"
                    % (train_epoch, train_step, rate, remaining / 60))
                print("discrim_loss", results["discrim_loss"])
                print("gen_loss_GAN", results["gen_loss_GAN"])
                print("gen_loss_L1", results["gen_loss_L1"])
                print("gen_loss_RES", results["gen_loss_RES"])

            if should(a.save_freq, step):
                print("saving model")
                saver.save(sess, os.path.join(a.output_dir, "model"),
                           global_step=sv.global_step)

            if a.validation_freq and should(a.validation_freq, step):
                print("validation step")
                validation_fetches = {
                    "validation_gen_loss_GAN": model.gen_loss_GAN,
                    "validation_gen_loss_L1": model.gen_loss_L1,
                    "summary": validation_summary_op,
                }
                validation_results = sess.run(
                    validation_fetches,
                    feed_dict={handle: validation_handle})
                print("validation gen_loss_GAN",
                      validation_results["validation_gen_loss_GAN"])
                print("validation gen_loss_L1",
                      validation_results["validation_gen_loss_L1"])
                # Validation may be enabled while summaries/tracing are
                # off, in which case there is no writer to record to.
                if validation_summary_writer is not None:
                    validation_summary_writer.add_summary(
                        validation_results["summary"],
                        results["global_step"])

            if sv.should_stop():
                print("supervisor things we should stop!")
                break

    print("done! bye")
def test(
        input_dir,
        output_dir,
        checkpoint,
        batch_size=1,
        test_start=0,
        test_end=999,
        disable_psf=False,
        ngf=64,
        separable_conv=False,
        write_residuals=False,
        write_input=False,
):
    """Run the generator over a dataset and write the results as FITS.

    Data is obtained from ``load_data`` for ``input_dir`` with
    ``max_epochs=1``, no flipping, and the ``test_start`` / ``test_end``
    bounds. Variables are restored from ``checkpoint``, and one
    ``sess.run`` is issued per batch. Each result set is handed to
    ``save_images`` with ``subfolder="fits"`` and ``output_dir``.

    Args:
        input_dir: passed to ``load_data`` as ``path``.
        output_dir: passed to ``save_images``.
        checkpoint: checkpoint path handed to ``Saver.restore``.
        batch_size: batch size for ``load_data``; also used to derive the
            number of batches to run.
        test_start: ``start`` argument for ``load_data``.
        test_end: ``end`` argument for ``load_data``.
        disable_psf: when true, only the scaled dirty image is fed to the
            generator; otherwise the dirty image and a cropped PSF are
            concatenated along axis 3.
        ngf: forwarded to ``create_generator``.
        separable_conv: forwarded to ``create_generator``.
        write_residuals: also compute and encode dirty-minus-convolved
            residuals.
        write_input: also encode the dirty input images.
    """
    batch, count = load_data(
        path=input_dir,
        flip=False,
        crop_size=CROP_SIZE,
        scale_size=CROP_SIZE,
        max_epochs=1,
        batch_size=batch_size,
        start=test_start,
        end=test_end,
    )
    n_batches = int(math.ceil(count / batch_size))
    data_iter = batch.make_one_shot_iterator()
    index, min_flux, max_flux, psf, dirty, skymodel = data_iter.get_next()
    print("train count = %d" % count)

    with tf.name_scope("scaling_flux"):
        scaled_dirty = preprocess(dirty, min_flux, max_flux)
        scaled_psf = (psf * 2) - 1

    if not disable_psf:
        # NOTE(review): this branch feeds the *unscaled* dirty and psf
        # tensors although scaled_dirty / scaled_psf exist and the other
        # branch uses scaled_dirty -- confirm against the training input.
        cropped_psf = psf[:, 128:-128, 128:-128, :]
        input_ = tf.concat([dirty, cropped_psf], axis=3)
    else:
        input_ = scaled_dirty

    with tf.variable_scope("generator"):
        generated = create_generator(input_, 1, ngf=ngf,
                                     separable_conv=separable_conv)
    deprocessed_output = deprocess(generated, min_flux, max_flux)

    if write_residuals:
        with tf.name_scope("calculate_residuals"):
            shifted = shift(psf, y=-1, x=-1)
            kernel = tf.squeeze(shifted)
            kernel = tf.expand_dims(kernel, 2)
            filter_ = tf.expand_dims(kernel, 3)
            convolved = tf.nn.conv2d(deprocessed_output, filter_,
                                     [1, 1, 1, 1], "SAME")
            residuals = dirty - convolved

    with tf.name_scope("encode_fitss"):
        fetches = {}
        fetches["indexs"] = index
        fetches["outputs"] = tf.map_fn(fits_encode, deprocessed_output,
                                       dtype=tf.string, name="output_fits")
        # Optional outputs are only added to the graph when requested.
        if write_residuals:
            fetches["residuals"] = tf.map_fn(fits_encode, residuals,
                                             dtype=tf.string,
                                             name="residuals_fits")
        if write_input:
            fetches["inputs"] = tf.map_fn(fits_encode, dirty,
                                          dtype=tf.string, name="input_fits")

    supervisor = tf.train.Supervisor(logdir=None)
    with supervisor.managed_session() as sess:
        supervisor.saver.restore(sess, checkpoint)

        for _ in range(n_batches):
            results = sess.run(fetches)
            written = save_images(results, subfolder="fits",
                                  extention="fits", output_dir=output_dir)
            for fileset in written:
                print("wrote " + fileset['name'])
def main():
    """Run the generator over the test data and write inputs, outputs and
    residuals as FITS.

    Data is obtained from ``load_data`` for ``a.input_dir`` with
    ``max_epochs=1`` and the ``a.test_start`` / ``a.test_end`` bounds. When
    ``a.checkpoint`` is set, the newest checkpoint found there is restored
    before inference. The data is processed at most once; ``a.max_epochs``
    and ``a.max_steps`` can only lower the number of batches.

    Raises:
        ValueError: if ``a.checkpoint`` is set but no checkpoint is found.
    """
    prepare()

    batch, count = load_data(path=a.input_dir, flip=False,
                             crop_size=CROP_SIZE, scale_size=CROP_SIZE,
                             max_epochs=1, batch_size=a.batch_size,
                             start=a.test_start, end=a.test_end)
    steps_per_epoch = int(math.ceil(count / a.batch_size))
    iter_ = batch.make_one_shot_iterator()
    index, min_flux, max_flux, psf, dirty, skymodel = iter_.get_next()
    print("train count = %d" % count)

    with tf.name_scope("scaling_flux"):
        scaled_dirty = preprocess(dirty, min_flux, max_flux)
        scaled_psf = (psf * 2) - 1

    if a.disable_psf:
        input_ = scaled_dirty
    else:
        input_ = tf.concat([scaled_dirty, scaled_psf], axis=3)

    with tf.variable_scope("generator"):
        generator = create_generator(input_, 1, ngf=a.ngf,
                                     separable_conv=a.separable_conv)
    deprocessed_output = deprocess(generator, min_flux, max_flux)

    with tf.name_scope("calculate_residuals"):
        shifted = shift(psf, y=-1, x=-1)
        # NOTE(review): squeeze + two expand_dims only yields a valid
        # [height, width, 1, 1] conv filter for a single one-channel psf
        # per batch -- confirm against load_data / a.batch_size.
        filter_ = tf.expand_dims(tf.expand_dims(tf.squeeze(shifted), 2), 3)
        convolved = tf.nn.conv2d(deprocessed_output, filter_, [1, 1, 1, 1],
                                 "SAME")
        residuals = dirty - convolved

    with tf.name_scope("encode_fitss"):
        fits_fetches = {
            "indexs": index,
            "inputs": tf.map_fn(fits_encode, dirty, dtype=tf.string,
                                name="input_fits"),
            "outputs": tf.map_fn(fits_encode, deprocessed_output,
                                 dtype=tf.string, name="output_fits"),
            "residuals": tf.map_fn(fits_encode, residuals, dtype=tf.string,
                                   name="residuals_fits"),
        }

    with tf.name_scope("parameter_count"):
        parameter_count = tf.reduce_sum(
            [tf.reduce_prod(tf.shape(v)) for v in tf.trainable_variables()])

    saver = tf.train.Saver(max_to_keep=100)
    sv = tf.train.Supervisor(logdir=None, save_summaries_secs=0, saver=None)
    with sv.managed_session() as sess:
        print("parameter_count =", sess.run(parameter_count))

        if a.checkpoint is not None:
            print("loading model from checkpoint")
            checkpoint = tf.train.latest_checkpoint(a.checkpoint)
            # latest_checkpoint returns None when nothing is found; fail
            # with a clear message rather than inside saver.restore.
            if checkpoint is None:
                raise ValueError(
                    "no checkpoint found in {}".format(a.checkpoint))
            saver.restore(sess, checkpoint)
            # Report success only after the restore actually happened.
            print("loaded {}".format(checkpoint))

        max_steps = 2 ** 32
        if a.max_epochs is not None:
            max_steps = steps_per_epoch * a.max_epochs
        if a.max_steps is not None:
            max_steps = a.max_steps

        # at most, process the test data once
        max_steps = min(steps_per_epoch, max_steps)

        # repeat the same for fits arrays
        for _ in range(max_steps):
            results = sess.run(fits_fetches)
            filesets = save_images(results, subfolder="fits",
                                   extention="fits",
                                   output_dir=a.output_dir)
            for f in filesets:
                print("wrote " + f['name'])