Esempio n. 1
0
def main():
    """Run a trained generator over dirty images and PSFs named on the command line.

    Usage::

        <script> dirty-0.fits,dirty-1.fits,... psf-0.fits,psf-1.fits,...

    Both arguments are comma-separated lists paired by position (the i-th
    dirty image goes with the i-th PSF). The generator variables are restored
    from the latest checkpoint in ``a.checkpoint``. For every batch, two
    results are FITS-encoded and handed to ``save_images`` with
    ``output_dir=a.output_dir``:

    * the deprocessed generator output;
    * the residual ``dirty - conv2d(output, shifted psf)``.

    Exits with status 1 when the command line is malformed, when the two
    lists differ in length, or when no checkpoint can be found.
    """
    if len(sys.argv) != 3:
        print("""
usage: {}  dirty-0.fits,dirty-1.fits,dirty-2.fits  psf-0.fits,psf-1.fits,psf-2.fits

 note: names don't matter, order does. only supports fits files of {}x{}
       will write output to the current folder.
""".format(sys.argv[0], CROP_SIZE, CROP_SIZE))
        sys.exit(1)

    dirties = [os.path.realpath(path) for path in sys.argv[1].split(',')]
    psfs = [os.path.realpath(path) for path in sys.argv[2].split(',')]
    # Explicit check instead of ``assert``: asserts are stripped under ``python -O``
    # and this guards user-supplied input.
    if len(dirties) != len(psfs):
        sys.exit("error: got {} dirty images but {} psfs; both lists must have "
                 "the same length".format(len(dirties), len(psfs)))

    batch, count = load_data(dirties, psfs)
    # NOTE(review): ``count`` is used directly as the number of sess.run steps;
    # confirm load_data returns a batch count rather than an image count.
    steps_per_epoch = count
    iterator = batch.make_one_shot_iterator()
    index, min_flux, max_flux, psf, dirty = iterator.get_next()

    scaled_dirty = preprocess(dirty, min_flux, max_flux)
    # Presumably maps a PSF in [0, 1] onto [-1, 1]; TODO confirm the PSF range.
    scaled_psf = (psf * 2) - 1

    # The generator sees the dirty image and the PSF stacked along axis 3.
    input_ = tf.concat([scaled_dirty, scaled_psf], axis=3)

    # set up the network
    with tf.variable_scope("generator"):
        outputs = create_generator(input_, 1, a.ngf, a.separable_conv)
        deprocessed_output = deprocess(outputs, min_flux, max_flux)

    with tf.name_scope("calculate_residuals"):
        shifted = shift(psf, y=-1, x=-1)
        # NOTE(review): tf.squeeze drops every size-1 axis, so this only yields a
        # 4-D conv2d kernel when the batch holds a single PSF — confirm batch size 1.
        filter_ = tf.expand_dims(tf.expand_dims(tf.squeeze(shifted), 2), 3)
        convolved = tf.nn.conv2d(deprocessed_output, filter_, [1, 1, 1, 1], "SAME")
        residuals = dirty - convolved

    with tf.name_scope("encode_fitss"):
        fits_fetches = {
            "indexs": index,
            "outputs": tf.map_fn(fits_encode, deprocessed_output, dtype=tf.string, name="output_fits"),
            "residuals": tf.map_fn(fits_encode, residuals, dtype=tf.string, name="residuals_fits"),
        }

    # Resolve the checkpoint up front: latest_checkpoint returns None when the
    # directory holds no checkpoint, and restore() would then fail obscurely.
    checkpoint = tf.train.latest_checkpoint(a.checkpoint)
    if checkpoint is None:
        sys.exit("error: no checkpoint found in {}".format(a.checkpoint))

    with tf.Session() as sess:
        tf.train.Saver().restore(sess, checkpoint)

        for _ in range(steps_per_epoch):
            results = sess.run(fits_fetches)
            filesets = save_images(results, subfolder=None, extention="fits", output_dir=a.output_dir)
            for f in filesets:
                print("wrote " + f['name'])
Esempio n. 2
0
def main():
    """Train the model on batches produced by ``generative_model``.

    Builds the TF1 graph:

    * flux scaling;
    * ``create_model``;
    * PSF residuals;
    * PNG display fetches;
    * TensorBoard summaries.

    It then runs ``a.max_steps`` training steps under a ``tf.train.Supervisor``.
    Each ``a.*_freq`` setting controls how often extra work happens: traces,
    progress output, summaries, display images and checkpoints. A value <= 0
    disables that work.
    """
    prepare()

    train_batch = generative_model(a.psf_glob,
                                   flux_scale_min=a.flux_scale_min,
                                   flux_scale_max=a.flux_scale_max)
    iterator = train_batch.make_one_shot_iterator()
    index, min_flux, max_flux, psf, dirty, skymodel = iterator.get_next()

    with tf.name_scope("scaling_flux"):
        scaled_skymodel = preprocess(skymodel, min_flux, max_flux)
        scaled_dirty = preprocess(dirty, min_flux, max_flux)
        # Presumably maps a PSF in [0, 1] onto [-1, 1]; TODO confirm the PSF range.
        scaled_psf = (psf * 2) - 1

    # inputs and targets are [batch_size, height, width, channels]
    model = create_model(scaled_dirty,
                         scaled_skymodel,
                         EPS,
                         a.separable_conv,
                         beta1=a.beta1,
                         gan_weight=a.gan_weight,
                         l1_weight=a.l1_weight,
                         lr=a.lr,
                         ndf=a.ndf,
                         ngf=a.ngf,
                         psf=psf,
                         min_flux=min_flux,
                         max_flux=max_flux,
                         res_weight=a.res_weight,
                         disable_psf=a.disable_psf)

    deprocessed_output = deprocess(model.outputs, min_flux, max_flux)

    with tf.name_scope("calculate_residuals"):
        shifted = shift(psf, y=-1, x=-1)
        # NOTE(review): tf.squeeze drops every size-1 axis, so this only yields a
        # 4-D conv2d kernel when the batch holds a single PSF — confirm batch size 1.
        filter_ = tf.expand_dims(tf.expand_dims(tf.squeeze(shifted), 2), 3)
        convolved = tf.nn.conv2d(deprocessed_output, filter_, [1, 1, 1, 1],
                                 "SAME")
        residuals = dirty - convolved

    # reverse any processing on images so they can be written to disk or displayed to user
    with tf.name_scope("convert_images"):
        converted_inputs = tf.image.convert_image_dtype(
            visual_scaling(scaled_dirty), dtype=tf.uint8, saturate=True)
        converted_targets = tf.image.convert_image_dtype(
            visual_scaling(scaled_skymodel), dtype=tf.uint8, saturate=True)
        converted_outputs = tf.image.convert_image_dtype(
            visual_scaling(model.outputs), dtype=tf.uint8, saturate=True)
        converted_psfs = tf.image.convert_image_dtype(
            visual_scaling(scaled_psf), dtype=tf.uint8, saturate=True)
        converted_residuals = tf.image.convert_image_dtype(
            visual_scaling(residuals), dtype=tf.uint8, saturate=True)

    with tf.name_scope("encode_images"):
        display_fetches = {
            "indexs": index,
            "inputs": tf.map_fn(tf.image.encode_png, converted_inputs,
                                dtype=tf.string, name="input_pngs"),
            "targets": tf.map_fn(tf.image.encode_png, converted_targets,
                                 dtype=tf.string, name="target_pngs"),
            "outputs": tf.map_fn(tf.image.encode_png, converted_outputs,
                                 dtype=tf.string, name="output_pngs"),
            "psfs": tf.map_fn(tf.image.encode_png, converted_psfs,
                              dtype=tf.string, name="psf_pngs"),
            "residuals": tf.map_fn(tf.image.encode_png, converted_residuals,
                                   dtype=tf.string, name="residual_pngs"),
        }

    # summaries
    with tf.name_scope("combined_summary"):
        tf.summary.image("inputs", converted_inputs)
        tf.summary.image("outputs", converted_outputs)
        tf.summary.image("targets", converted_targets)
        tf.summary.image("residuals", converted_residuals)

    with tf.name_scope("psfs_summary"):
        tf.summary.image("psfss", converted_psfs)

    with tf.name_scope("predict_real_summary"):
        tf.summary.image(
            "predict_real",
            tf.image.convert_image_dtype(model.predict_real, dtype=tf.uint8))

    with tf.name_scope("predict_fake_summary"):
        tf.summary.image(
            "predict_fake",
            tf.image.convert_image_dtype(model.predict_fake, dtype=tf.uint8))

    tf.summary.scalar("discriminator_loss", model.discrim_loss)
    tf.summary.scalar("generator_loss_GAN", model.gen_loss_GAN)
    tf.summary.scalar("generator_loss_L1", model.gen_loss_L1)
    tf.summary.scalar("generator_loss_RES", model.gen_loss_RES)

    with tf.name_scope("parameter_count"):
        parameter_count = tf.reduce_sum(
            [tf.reduce_prod(tf.shape(v)) for v in tf.trainable_variables()])

    # Only use a log directory when something will actually be logged.
    logdir = a.output_dir if (a.trace_freq > 0 or a.summary_freq > 0) else None

    train_summary_op = tf.summary.merge_all()
    # The writer is only used when summaries or traces are recorded, which is
    # exactly when logdir is set. Building ``logdir + '/train'`` with logdir None
    # would raise TypeError, so the writer is created conditionally.
    train_summary_writer = None
    if logdir is not None:
        train_summary_writer = tf.summary.FileWriter(
            logdir=os.path.join(logdir, 'train'))

    saver = tf.train.Saver(max_to_keep=100)
    # NOTE(review): save_model_secs is left at its default, so the Supervisor may
    # also write periodic checkpoints of its own when logdir is set.
    sv = tf.train.Supervisor(logdir=logdir,
                             save_summaries_secs=0,
                             saver=saver,
                             summary_writer=None,
                             summary_op=None)

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.6)

    with sv.managed_session(config=tf.ConfigProto(
            gpu_options=gpu_options)) as sess:
        print("parameter_count =", sess.run(parameter_count))

        start = time.time()

        try:
            for step in range(a.max_steps):

                def should(freq):
                    # True every ``freq`` steps and on the last step; freq <= 0 disables.
                    return freq > 0 and ((step + 1) % freq == 0
                                         or step == a.max_steps - 1)

                options = None
                run_metadata = None
                if should(a.trace_freq):
                    print("preparing")
                    options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
                    run_metadata = tf.RunMetadata()

                fetches = {
                    "train": model.train,
                    "global_step": sv.global_step,
                }

                if should(a.progress_freq):
                    print("progress step")
                    fetches["discrim_loss"] = model.discrim_loss
                    fetches["gen_loss_GAN"] = model.gen_loss_GAN
                    fetches["gen_loss_L1"] = model.gen_loss_L1
                    fetches["gen_loss_RES"] = model.gen_loss_RES

                if should(a.summary_freq):
                    print("preparing summary")
                    fetches["summary"] = train_summary_op

                if should(a.display_freq):
                    print("display step step")
                    fetches["display"] = display_fetches

                results = sess.run(fetches,
                                   options=options,
                                   run_metadata=run_metadata)

                if should(a.summary_freq):
                    print("recording summary")
                    train_summary_writer.add_summary(results["summary"],
                                                     results["global_step"])

                if should(a.display_freq):
                    print("saving display images")
                    save_images(results["display"],
                                step=results["global_step"],
                                output_dir=a.output_dir)

                if should(a.trace_freq):
                    print("recording trace")
                    train_summary_writer.add_run_metadata(
                        run_metadata, "step_%d" % results["global_step"])

                if should(a.progress_freq):
                    rate = (step + 1) * a.batch_size / (time.time() - start)
                    # step + 1 steps are done, so max_steps - step - 1 remain.
                    remaining = (a.max_steps - step - 1) * a.batch_size / rate
                    print("progress  step %d  image/sec %0.1f  remaining %dm" %
                          (results["global_step"], rate, remaining / 60))
                    print("discrim_loss", results["discrim_loss"])
                    print("gen_loss_GAN", results["gen_loss_GAN"])
                    print("gen_loss_L1", results["gen_loss_L1"])
                    print("gen_loss_RES", results["gen_loss_RES"])

                if should(a.save_freq):
                    print("saving model")
                    saver.save(sess,
                               os.path.join(a.output_dir, "model"),
                               global_step=sv.global_step)

                if sv.should_stop():
                    print("supervisor things we should stop!")
                    break
        finally:
            # Flush buffered events; otherwise the last summaries can be lost.
            if train_summary_writer is not None:
                train_summary_writer.close()

        print("done! bye")
Esempio n. 3
0
def main():
    """Train the model on examples loaded from ``a.input_dir``.

    Examples are loaded (looping) with ``load_data`` using
    ``start=a.train_start, end=a.train_end``. The model is built with
    ``create_model`` and trained under a ``tf.train.Supervisor``, whose own
    summary writer records the merged summaries.

    The run length is the first available of:

    * ``a.max_steps`` steps;
    * ``a.max_epochs`` epochs;
    * 2**32 steps.

    Each ``a.*_freq`` setting controls how often traces, progress output,
    summaries, display images and checkpoints are produced. A value <= 0
    disables that work.
    """
    prepare()

    train_batch, train_count = load_data(a.input_dir,
                                         CROP_SIZE,
                                         a.flip,
                                         a.scale_size,
                                         a.max_epochs,
                                         a.batch_size,
                                         start=a.train_start,
                                         end=a.train_end,
                                         loop=True)
    print("train count = %d" % train_count)
    steps_per_epoch = int(math.ceil(train_count / a.batch_size))
    training_iterator = train_batch.make_one_shot_iterator()
    index, min_flux, max_flux, psf, dirty, skymodel = training_iterator.get_next()

    with tf.name_scope("scaling_flux"):
        scaled_skymodel = preprocess(skymodel, min_flux, max_flux)
        scaled_dirty = preprocess(dirty, min_flux, max_flux)
        # Presumably maps a PSF in [0, 1] onto [-1, 1]; TODO confirm the PSF range.
        scaled_psf = (psf * 2) - 1

    # inputs and targets are [batch_size, height, width, channels]
    model = create_model(scaled_dirty,
                         scaled_skymodel,
                         EPS,
                         a.separable_conv,
                         beta1=a.beta1,
                         gan_weight=a.gan_weight,
                         l1_weight=a.l1_weight,
                         lr=a.lr,
                         ndf=a.ndf,
                         ngf=a.ngf,
                         psf=scaled_psf,
                         min_flux=min_flux,
                         max_flux=max_flux,
                         res_weight=a.res_weight)

    # reverse any processing on images so they can be written to disk or displayed to user
    with tf.name_scope("convert_images"):
        converted_inputs = tf.image.convert_image_dtype(
            visual_scaling(scaled_dirty), dtype=tf.uint8, saturate=True)
        converted_targets = tf.image.convert_image_dtype(
            visual_scaling(scaled_skymodel), dtype=tf.uint8, saturate=True)
        converted_outputs = tf.image.convert_image_dtype(
            visual_scaling(model.outputs), dtype=tf.uint8, saturate=True)
        converted_psfs = tf.image.convert_image_dtype(
            visual_scaling(scaled_psf), dtype=tf.uint8, saturate=True)
        converted_residuals = tf.image.convert_image_dtype(
            visual_scaling(model.residuals), dtype=tf.uint8, saturate=True)
        converted_likelihood = tf.image.convert_image_dtype(
            visual_scaling(model.likelihood), dtype=tf.uint8, saturate=True)

    with tf.name_scope("encode_images"):
        display_fetches = {
            "indexs": index,
            "inputs": tf.map_fn(tf.image.encode_png, converted_inputs,
                                dtype=tf.string, name="input_pngs"),
            "targets": tf.map_fn(tf.image.encode_png, converted_targets,
                                 dtype=tf.string, name="target_pngs"),
            "outputs": tf.map_fn(tf.image.encode_png, converted_outputs,
                                 dtype=tf.string, name="output_pngs"),
            "psfs": tf.map_fn(tf.image.encode_png, converted_psfs,
                              dtype=tf.string, name="psf_pngs"),
            "residuals": tf.map_fn(tf.image.encode_png, converted_residuals,
                                   dtype=tf.string, name="residual_pngs"),
            "likelihood": tf.map_fn(tf.image.encode_png, converted_likelihood,
                                    dtype=tf.string, name="likelihood_pngs"),
        }

    # summaries
    with tf.name_scope("combined_summary"):
        tf.summary.image("inputs", converted_inputs)
        tf.summary.image("outputs", converted_outputs)
        tf.summary.image("targets", converted_targets)
        tf.summary.image("residuals", converted_residuals)
        tf.summary.image("likelihood", converted_likelihood)
        tf.summary.image("psfss", converted_psfs)

    with tf.name_scope("predict_summary"):
        tf.summary.image(
            "predict_real",
            tf.image.convert_image_dtype(model.predict_real, dtype=tf.uint8))
        tf.summary.image(
            "predict_fake",
            tf.image.convert_image_dtype(model.predict_fake, dtype=tf.uint8))

    with tf.name_scope("generator_scalars"):
        tf.summary.scalar("generator_loss_GAN", model.gen_loss_GAN)
        tf.summary.scalar("generator_loss_L1", model.gen_loss_L1)
        tf.summary.scalar("generator_loss_RES", model.gen_loss_RES)

    with tf.name_scope("discriminator_scalars"):
        tf.summary.scalar("discriminator_loss", model.discrim_loss)

    for var in tf.trainable_variables():
        tf.summary.histogram(var.op.name + "/values", var)

    for grad, var in model.discrim_grads_and_vars + model.gen_grads_and_vars:
        # Presumably these lists come from Optimizer.compute_gradients, which
        # pairs variables that do not affect the loss with a None gradient;
        # tf.summary.histogram cannot record None, so skip those.
        if grad is not None:
            tf.summary.histogram(var.op.name + "/gradients", grad)

    with tf.name_scope("parameter_count"):
        parameter_count = tf.reduce_sum(
            [tf.reduce_prod(tf.shape(v)) for v in tf.trainable_variables()])

    # Only give the Supervisor a log directory when something will be logged.
    logdir = a.output_dir if (a.trace_freq > 0 or a.summary_freq > 0) else None
    # max_to_keep=0 keeps every checkpoint.
    saver = tf.train.Saver(max_to_keep=0)
    sv = tf.train.Supervisor(logdir=logdir,
                             save_summaries_secs=0,
                             save_model_secs=0,
                             saver=saver)

    with sv.managed_session() as sess:
        print("parameter_count =", sess.run(parameter_count))

        max_steps = 2**32
        if a.max_epochs is not None:
            max_steps = steps_per_epoch * a.max_epochs
        if a.max_steps is not None:
            max_steps = a.max_steps

        start = time.time()

        for step in range(max_steps):

            def should(freq):
                # True every ``freq`` steps and on the last step; freq <= 0 disables.
                return freq > 0 and ((step + 1) % freq == 0
                                     or step == max_steps - 1)

            options = None
            run_metadata = None
            if should(a.trace_freq):
                print("preparing")
                options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
                run_metadata = tf.RunMetadata()

            fetches = {
                "train": model.train,
                "global_step": sv.global_step,
            }

            if should(a.progress_freq):
                print("progress step")
                fetches["discrim_loss"] = model.discrim_loss
                fetches["gen_loss_GAN"] = model.gen_loss_GAN
                fetches["gen_loss_L1"] = model.gen_loss_L1
                fetches["gen_loss_RES"] = model.gen_loss_RES

            if should(a.summary_freq):
                print("preparing summary")
                fetches["summary"] = sv.summary_op

            if should(a.display_freq):
                print("display step step")
                fetches["display"] = display_fetches

            results = sess.run(fetches,
                               options=options,
                               run_metadata=run_metadata)

            if should(a.summary_freq):
                print("recording summary")
                sv.summary_writer.add_summary(results["summary"],
                                              results["global_step"])

            if should(a.display_freq):
                print("saving display images")
                save_images(results["display"],
                            step=results["global_step"],
                            output_dir=a.output_dir)

            if should(a.trace_freq):
                print("recording trace")
                sv.summary_writer.add_run_metadata(
                    run_metadata, "step_%d" % results["global_step"])

            if should(a.progress_freq):
                # global_step will have the correct step count if we resume from a checkpoint
                train_epoch = math.ceil(results["global_step"] /
                                        steps_per_epoch)
                train_step = (results["global_step"] - 1) % steps_per_epoch + 1
                rate = (step + 1) * a.batch_size / (time.time() - start)
                # step + 1 steps are done, so max_steps - step - 1 remain.
                remaining = (max_steps - step - 1) * a.batch_size / rate
                print(
                    "progress  epoch %d  step %d  image/sec %0.1f  remaining %dm"
                    % (train_epoch, train_step, rate, remaining / 60))
                print("discrim_loss", results["discrim_loss"])
                print("gen_loss_GAN", results["gen_loss_GAN"])
                print("gen_loss_L1", results["gen_loss_L1"])
                print("gen_loss_RES", results["gen_loss_RES"])

            if should(a.save_freq):
                print("saving model")
                sv.saver.save(sess,
                              os.path.join(a.output_dir, "model"),
                              global_step=sv.global_step)

            if sv.should_stop():
                print("supervisor things we should stop!")
                break

        print("done! bye")
Esempio n. 4
0
def main():
    """Train the model while periodically evaluating it on a validation split.

    Both splits are loaded from ``a.input_dir`` via ``load_data``:

    * training uses ``a.train_start``/``a.train_end``;
    * validation uses ``a.validate_start``/``a.validate_end``.

    A feedable iterator handle chooses which dataset drives the graph on each
    ``sess.run``. Summaries go to ``<logdir>/train`` and
    ``<logdir>/validation``. The remaining ``a.*_freq`` settings control
    traces, progress output, display images, checkpoints and validation runs.
    A value <= 0 (or a falsy ``a.validation_freq``) disables that work.
    """
    prepare()

    train_batch, train_count = load_data(a.input_dir,
                                         CROP_SIZE,
                                         a.flip,
                                         a.scale_size,
                                         a.max_epochs,
                                         a.batch_size,
                                         start=a.train_start,
                                         end=a.train_end,
                                         loop=True)
    print("train count = %d" % train_count)
    steps_per_epoch = int(math.ceil(train_count / a.batch_size))
    training_iterator = train_batch.make_one_shot_iterator()

    validate_batch, validate_count = load_data(a.input_dir,
                                               CROP_SIZE,
                                               a.flip,
                                               a.scale_size,
                                               a.max_epochs,
                                               a.batch_size,
                                               start=a.validate_start,
                                               end=a.validate_end,
                                               loop=True)
    print("validate count = %d" % validate_count)
    validation_iterator = validate_batch.make_one_shot_iterator()

    # Feeding a string handle into this placeholder selects which iterator
    # (training or validation) supplies the next batch.
    handle = tf.placeholder(tf.string, shape=[])
    iterator = tf.data.Iterator.from_string_handle(handle,
                                                   train_batch.output_types,
                                                   train_batch.output_shapes)
    index, min_flux, max_flux, psf, dirty, skymodel = iterator.get_next()

    with tf.name_scope("scaling_flux"):
        scaled_skymodel = preprocess(skymodel, min_flux, max_flux)
        scaled_dirty = preprocess(dirty, min_flux, max_flux)
        # Presumably maps a PSF in [0, 1] onto [-1, 1]; TODO confirm the PSF range.
        scaled_psf = (psf * 2) - 1

    ## i'll just leave this here in case we decide to do something with visibilities again (unlikely)
    # vis = tf.fft2d(tf.complex(scaled_dirty, tf.zeros(shape=(a.batch_size, CROP_SIZE, CROP_SIZE, 1))))
    # real = tf.real(vis)
    # imag = tf.imag(vis)

    if a.disable_psf:
        input_ = scaled_dirty
    else:
        input_ = tf.concat([scaled_dirty, scaled_psf], axis=3)

    # inputs and targets are [batch_size, height, width, channels]
    model = create_model(input_,
                         scaled_skymodel,
                         EPS,
                         a.separable_conv,
                         beta1=a.beta1,
                         gan_weight=a.gan_weight,
                         l1_weight=a.l1_weight,
                         lr=a.lr,
                         ndf=a.ndf,
                         ngf=a.ngf,
                         psf=psf,
                         min_flux=min_flux,
                         max_flux=max_flux,
                         res_weight=a.res_weight)

    deprocessed_output = deprocess(model.outputs, min_flux, max_flux)

    with tf.name_scope("calculate_residuals"):
        shifted = shift(psf, y=-1, x=-1)
        # NOTE(review): tf.squeeze drops every size-1 axis, so this only yields a
        # 4-D conv2d kernel when the batch holds a single PSF — confirm batch size 1.
        filter_ = tf.expand_dims(tf.expand_dims(tf.squeeze(shifted), 2), 3)
        convolved = tf.nn.conv2d(deprocessed_output, filter_, [1, 1, 1, 1],
                                 "SAME")
        residuals = dirty - convolved

    # reverse any processing on images so they can be written to disk or displayed to user
    with tf.name_scope("convert_images"):
        converted_inputs = tf.image.convert_image_dtype(
            visual_scaling(scaled_dirty), dtype=tf.uint8, saturate=True)
        converted_targets = tf.image.convert_image_dtype(
            visual_scaling(scaled_skymodel), dtype=tf.uint8, saturate=True)
        converted_outputs = tf.image.convert_image_dtype(
            visual_scaling(model.outputs), dtype=tf.uint8, saturate=True)
        converted_psfs = tf.image.convert_image_dtype(
            visual_scaling(scaled_psf), dtype=tf.uint8, saturate=True)
        converted_residuals = tf.image.convert_image_dtype(
            visual_scaling(residuals), dtype=tf.uint8, saturate=True)

    with tf.name_scope("encode_images"):
        display_fetches = {
            "indexs": index,
            "inputs": tf.map_fn(tf.image.encode_png, converted_inputs,
                                dtype=tf.string, name="input_pngs"),
            "targets": tf.map_fn(tf.image.encode_png, converted_targets,
                                 dtype=tf.string, name="target_pngs"),
            "outputs": tf.map_fn(tf.image.encode_png, converted_outputs,
                                 dtype=tf.string, name="output_pngs"),
            "psfs": tf.map_fn(tf.image.encode_png, converted_psfs,
                              dtype=tf.string, name="psf_pngs"),
            "residuals": tf.map_fn(tf.image.encode_png, converted_residuals,
                                   dtype=tf.string, name="residual_pngs"),
        }

    # summaries
    with tf.name_scope("combined_summary"):
        tf.summary.image("inputs", converted_inputs)
        tf.summary.image("outputs", converted_outputs)
        tf.summary.image("targets", converted_targets)
        tf.summary.image("residuals", converted_residuals)

    with tf.name_scope("psfs_summary"):
        tf.summary.image("psfss", converted_psfs)

    with tf.name_scope("predict_real_summary"):
        tf.summary.image(
            "predict_real",
            tf.image.convert_image_dtype(model.predict_real, dtype=tf.uint8))

    with tf.name_scope("predict_fake_summary"):
        tf.summary.image(
            "predict_fake",
            tf.image.convert_image_dtype(model.predict_fake, dtype=tf.uint8))

    tf.summary.scalar("discriminator_loss", model.discrim_loss)
    tf.summary.scalar("generator_loss_GAN", model.gen_loss_GAN)
    tf.summary.scalar("generator_loss_L1", model.gen_loss_L1)
    tf.summary.scalar("generator_loss_RES", model.gen_loss_RES)

    if a.validation_freq:
        # NOTE(review): these read the same tensors as the generator_loss_*
        # scalars above, so they duplicate them in both the train and
        # validation logs; names containing spaces are also rewritten by TF.
        tf.summary.scalar("Validation generator_loss_GAN", model.gen_loss_GAN)
        tf.summary.scalar("Validation generator_loss_L1", model.gen_loss_L1)

    # Per-variable value/gradient histograms are disabled; to re-enable:
    # for var in tf.trainable_variables():
    #     tf.summary.histogram(var.op.name + "/values", var)
    #
    # for grad, var in model.discrim_grads_and_vars + model.gen_grads_and_vars:
    #     tf.summary.histogram(var.op.name + "/gradients", grad)

    with tf.name_scope("parameter_count"):
        parameter_count = tf.reduce_sum(
            [tf.reduce_prod(tf.shape(v)) for v in tf.trainable_variables()])

    # Only use a log directory when something will actually be logged.
    logdir = a.output_dir if (a.trace_freq > 0 or a.summary_freq > 0) else None

    # merge_all() collects the same summaries every time it is called, so one
    # merged op serves both writers. The data it sees depends only on which
    # iterator handle is fed to sess.run.
    summary_op = tf.summary.merge_all()

    # Building ``logdir + '/...'`` with logdir None would raise TypeError, so the
    # writers only exist when logging is enabled.
    train_summary_writer = None
    validation_summary_writer = None
    if logdir is not None:
        validation_summary_writer = tf.summary.FileWriter(
            logdir=os.path.join(logdir, 'validation'))
        train_summary_writer = tf.summary.FileWriter(
            logdir=os.path.join(logdir, 'train'))

    saver = tf.train.Saver(max_to_keep=100)
    sv = tf.train.Supervisor(logdir=logdir,
                             save_summaries_secs=0,
                             saver=saver,
                             summary_writer=None,
                             summary_op=None)

    with sv.managed_session() as sess:
        print("parameter_count =", sess.run(parameter_count))

        # The `Iterator.string_handle()` method returns a tensor that can be evaluated
        # and used to feed the `handle` placeholder.
        training_handle = sess.run(training_iterator.string_handle())
        validation_handle = sess.run(validation_iterator.string_handle())

        max_steps = 2**32
        if a.max_epochs is not None:
            max_steps = steps_per_epoch * a.max_epochs
        if a.max_steps is not None:
            max_steps = a.max_steps

        start = time.time()

        try:
            for step in range(max_steps):

                def should(freq):
                    # True every ``freq`` steps and on the last step; freq <= 0 disables.
                    return freq > 0 and ((step + 1) % freq == 0
                                         or step == max_steps - 1)

                options = None
                run_metadata = None
                if should(a.trace_freq):
                    print("preparing")
                    options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
                    run_metadata = tf.RunMetadata()

                fetches = {
                    "train": model.train,
                    "global_step": sv.global_step,
                }

                if should(a.progress_freq):
                    print("progress step")
                    fetches["discrim_loss"] = model.discrim_loss
                    fetches["gen_loss_GAN"] = model.gen_loss_GAN
                    fetches["gen_loss_L1"] = model.gen_loss_L1
                    fetches["gen_loss_RES"] = model.gen_loss_RES

                if should(a.summary_freq):
                    print("preparing summary")
                    fetches["summary"] = summary_op

                if should(a.display_freq):
                    print("display step step")
                    fetches["display"] = display_fetches

                results = sess.run(fetches,
                                   options=options,
                                   run_metadata=run_metadata,
                                   feed_dict={handle: training_handle})

                if should(a.summary_freq):
                    print("recording summary")
                    train_summary_writer.add_summary(results["summary"],
                                                     results["global_step"])

                if should(a.display_freq):
                    print("saving display images")
                    save_images(results["display"],
                                step=results["global_step"],
                                output_dir=a.output_dir)

                if should(a.trace_freq):
                    print("recording trace")
                    train_summary_writer.add_run_metadata(
                        run_metadata, "step_%d" % results["global_step"])

                if should(a.progress_freq):
                    # global_step will have the correct step count if we resume from a checkpoint
                    train_epoch = math.ceil(results["global_step"] /
                                            steps_per_epoch)
                    train_step = (results["global_step"] - 1) % steps_per_epoch + 1
                    rate = (step + 1) * a.batch_size / (time.time() - start)
                    # step + 1 steps are done, so max_steps - step - 1 remain.
                    remaining = (max_steps - step - 1) * a.batch_size / rate
                    print(
                        "progress  epoch %d  step %d  image/sec %0.1f  remaining %dm"
                        % (train_epoch, train_step, rate, remaining / 60))
                    print("discrim_loss", results["discrim_loss"])
                    print("gen_loss_GAN", results["gen_loss_GAN"])
                    print("gen_loss_L1", results["gen_loss_L1"])
                    print("gen_loss_RES", results["gen_loss_RES"])

                if should(a.save_freq):
                    print("saving model")
                    saver.save(sess,
                               os.path.join(a.output_dir, "model"),
                               global_step=sv.global_step)

                # validation_freq may be None, so test truthiness before should().
                if a.validation_freq and should(a.validation_freq):
                    print("validation step")
                    validation_fetches = {
                        "validation_gen_loss_GAN": model.gen_loss_GAN,
                        "validation_gen_loss_L1": model.gen_loss_L1,
                    }
                    # Without a writer there is nowhere to record the summary.
                    if validation_summary_writer is not None:
                        validation_fetches["summary"] = summary_op

                    validation_results = sess.run(
                        validation_fetches, feed_dict={handle: validation_handle})
                    print("validation gen_loss_GAN",
                          validation_results["validation_gen_loss_GAN"])
                    print("validation gen_loss_L1",
                          validation_results["validation_gen_loss_L1"])
                    if validation_summary_writer is not None:
                        validation_summary_writer.add_summary(
                            validation_results["summary"], results["global_step"])

                if sv.should_stop():
                    print("supervisor things we should stop!")
                    break
        finally:
            # Flush buffered events; otherwise the last summaries can be lost.
            for writer in (train_summary_writer, validation_summary_writer):
                if writer is not None:
                    writer.close()

        print("done! bye")
Esempio n. 5
0
def test(
    input_dir,
    output_dir,
    checkpoint,
    batch_size=1,
    test_start=0,
    test_end=999,
    disable_psf=False,
    ngf=64,
    separable_conv=False,
    write_residuals=False,
    write_input=False,
):
    """Run a trained generator over a range of samples and write the results as FITS.

    Builds the inference graph, restores the weights from ``checkpoint`` and
    makes a single pass over the samples selected by ``test_start``/``test_end``.
    Encoded results are written through ``save_images`` into the ``"fits"``
    subfolder of ``output_dir``, and each written file name is printed.

    Args:
        input_dir: directory handed to ``load_data`` as ``path``.
        output_dir: directory handed to ``save_images``.
        checkpoint: checkpoint path passed directly to ``Saver.restore``.
        batch_size: number of samples evaluated per ``sess.run`` call.
        test_start: first sample index, forwarded to ``load_data`` as ``start``.
        test_end: end sample index, forwarded to ``load_data`` as ``end``.
        disable_psf: if True the generator sees only the scaled dirty image;
            otherwise the dirty image is concatenated with a cropped PSF.
        ngf: forwarded to ``create_generator``.
        separable_conv: forwarded to ``create_generator``.
        write_residuals: also write ``dirty - conv2d(output, shifted psf)``.
        write_input: also write the dirty input image.
    """
    batch, count = load_data(path=input_dir,
                             flip=False,
                             crop_size=CROP_SIZE,
                             scale_size=CROP_SIZE,
                             max_epochs=1,
                             batch_size=batch_size,
                             start=test_start,
                             end=test_end)
    # A single pass over the data; the last batch may be partial.
    steps_per_epoch = int(math.ceil(count / batch_size))
    iter_ = batch.make_one_shot_iterator()
    index, min_flux, max_flux, psf, dirty, _skymodel = iter_.get_next()
    print("test count = %d" % count)

    with tf.name_scope("scaling_flux"):
        scaled_dirty = preprocess(dirty, min_flux, max_flux)

    if disable_psf:
        input_ = scaled_dirty
    else:
        # NOTE(review): unlike the disable_psf branch, this feeds the *unscaled*
        # dirty image plus an unscaled PSF cropped by 128 px on each side
        # (presumably the PSF is larger than the dirty image). Left unchanged so
        # existing checkpoints receive the same input; confirm it matches training.
        input_ = tf.concat([dirty, psf[:, 128:-128, 128:-128, :]], axis=3)

    with tf.variable_scope("generator"):
        generator = create_generator(input_,
                                     1,
                                     ngf=ngf,
                                     separable_conv=separable_conv)
        deprocessed_output = deprocess(generator, min_flux, max_flux)

    if write_residuals:
        with tf.name_scope("calculate_residuals"):
            # Convolve the model output with the (shifted) PSF and subtract it
            # from the dirty image.
            shifted = shift(psf, y=-1, x=-1)
            filter_ = tf.expand_dims(tf.expand_dims(tf.squeeze(shifted), 2), 3)
            convolved = tf.nn.conv2d(deprocessed_output, filter_, [1, 1, 1, 1],
                                     "SAME")
            residuals = dirty - convolved

    with tf.name_scope("encode_fitss"):
        work = {
            "indexs":
            index,
            "outputs":
            tf.map_fn(fits_encode,
                      deprocessed_output,
                      dtype=tf.string,
                      name="output_fits"),
        }
        if write_residuals:
            work["residuals"] = tf.map_fn(fits_encode,
                                          residuals,
                                          dtype=tf.string,
                                          name="residuals_fits")
        if write_input:
            work["inputs"] = tf.map_fn(fits_encode,
                                       dirty,
                                       dtype=tf.string,
                                       name="input_fits")

    sv = tf.train.Supervisor(logdir=None)
    with sv.managed_session() as sess:
        sv.saver.restore(sess, checkpoint)

        for _ in range(steps_per_epoch):
            results = sess.run(work)
            filesets = save_images(results,
                                   subfolder="fits",
                                   extention="fits",
                                   output_dir=output_dir)
            for f in filesets:
                print("wrote " + f['name'])
Esempio n. 6
0
File: test.py Progetto: richarms/gvc
def main():
    """Restore the generator from the latest checkpoint and write FITS results for the test set.

    All configuration is read from the module-level ``a`` (input/output dirs,
    batch size, test range, checkpoint directory, network options and step
    limits). For every evaluated batch the dirty input, the generator output and
    the residuals (``dirty - conv2d(output, shifted psf)``) are written through
    ``save_images`` into the ``"fits"`` subfolder of ``a.output_dir``.

    Raises:
        ValueError: if ``a.checkpoint`` is set but contains no checkpoint.
    """
    prepare()

    batch, count = load_data(path=a.input_dir, flip=False, crop_size=CROP_SIZE, scale_size=CROP_SIZE, max_epochs=1,
                             batch_size=a.batch_size, start=a.test_start, end=a.test_end)
    # A single pass over the data; the last batch may be partial.
    steps_per_epoch = int(math.ceil(count / a.batch_size))
    iter_ = batch.make_one_shot_iterator()
    index, min_flux, max_flux, psf, dirty, _skymodel = iter_.get_next()
    print("test count = %d" % count)

    with tf.name_scope("scaling_flux"):
        scaled_dirty = preprocess(dirty, min_flux, max_flux)
        # Map the PSF from [0, 1] to [-1, 1], assuming it is stored normalised to [0, 1].
        scaled_psf = (psf * 2) - 1

    if a.disable_psf:
        input_ = scaled_dirty
    else:
        input_ = tf.concat([scaled_dirty, scaled_psf], axis=3)

    with tf.variable_scope("generator"):
        generator = create_generator(input_, 1, ngf=a.ngf, separable_conv=a.separable_conv)
        deprocessed_output = deprocess(generator, min_flux, max_flux)

    with tf.name_scope("calculate_residuals"):
        # Convolve the model output with the (shifted) PSF and subtract it from the dirty image.
        shifted = shift(psf, y=-1, x=-1)
        filter_ = tf.expand_dims(tf.expand_dims(tf.squeeze(shifted), 2), 3)
        convolved = tf.nn.conv2d(deprocessed_output, filter_, [1, 1, 1, 1], "SAME")
        residuals = dirty - convolved

    with tf.name_scope("encode_fitss"):
        fits_fetches = {
            "indexs": index,
            "inputs": tf.map_fn(fits_encode, dirty, dtype=tf.string, name="input_fits"),
            "outputs": tf.map_fn(fits_encode, deprocessed_output, dtype=tf.string, name="output_fits"),
            "residuals": tf.map_fn(fits_encode, residuals, dtype=tf.string, name="residuals_fits"),
        }

    with tf.name_scope("parameter_count"):
        parameter_count = tf.reduce_sum([tf.reduce_prod(tf.shape(v)) for v in tf.trainable_variables()])

    saver = tf.train.Saver(max_to_keep=100)

    # saver=None: checkpoints are restored explicitly below, and the supervisor never saves.
    sv = tf.train.Supervisor(logdir=None, save_summaries_secs=0, saver=None)
    with sv.managed_session() as sess:
        print("parameter_count =", sess.run(parameter_count))

        if a.checkpoint is not None:
            print("loading model from checkpoint")
            checkpoint = tf.train.latest_checkpoint(a.checkpoint)
            if checkpoint is None:
                # latest_checkpoint returns None for a directory without checkpoints;
                # fail with a clear message instead of an opaque restore error.
                raise ValueError("no checkpoint found in {}".format(a.checkpoint))
            print("loaded {}".format(checkpoint))
            saver.restore(sess, checkpoint)

        max_steps = 2 ** 32
        if a.max_epochs is not None:
            max_steps = steps_per_epoch * a.max_epochs
        if a.max_steps is not None:
            max_steps = a.max_steps

        # at most, process the test data once
        max_steps = min(steps_per_epoch, max_steps)

        # evaluate each batch and write the encoded FITS files
        for _ in range(max_steps):
            results = sess.run(fits_fetches)
            filesets = save_images(results, subfolder="fits", extention="fits", output_dir=a.output_dir)
            for f in filesets:
                print("wrote " + f['name'])