Example #1
def main():
    if len(sys.argv) != 3:
        print("""
usage: {}  dirty-0.fits,dirty-1.fits,dirty-2.fits  psf-0.fits,psf-1.fits,psf-2.fits
        
 note: names don't matter, order does. only supports fits files of {}x{}
       will write output to the current folder.
""".format(sys.argv[0], CROP_SIZE, CROP_SIZE))
        sys.exit(1)

    dirties = [os.path.realpath(i) for i in sys.argv[1].split(',')]
    psfs = [os.path.realpath(i) for i in sys.argv[2].split(',')]
    assert len(dirties) == len(psfs)
    batch, count = load_data(dirties, psfs)
    steps_per_epoch = count
    iterator = batch.make_one_shot_iterator()
    index, min_flux, max_flux, psf, dirty = iterator.get_next()

    scaled_dirty = preprocess(dirty, min_flux, max_flux)
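    # the PSF is assumed normalised to [0, 1]; map it onto [-1, 1] like the other network inputs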
    scaled_psf = (psf * 2) - 1

    input_ = tf.concat([scaled_dirty, scaled_psf], axis=3)

    # set up the network
    with tf.variable_scope("generator"):
        outputs = create_generator(input_, 1, a.ngf, a.separable_conv)
        deprocessed_output = deprocess(outputs, min_flux, max_flux)

    with tf.name_scope("calculate_residuals"):
        shifted = shift(psf, y=-1, x=-1)
        filter_ = tf.expand_dims(tf.expand_dims(tf.squeeze(shifted), 2), 3)
        convolved = tf.nn.conv2d(deprocessed_output, filter_, [1, 1, 1, 1], "SAME")
        residuals = dirty - convolved

    with tf.name_scope("encode_fitss"):
        fits_fetches = {
            "indexs": index,
            "outputs": tf.map_fn(fits_encode, deprocessed_output, dtype=tf.string, name="output_fits"),
            "residuals": tf.map_fn(fits_encode, residuals, dtype=tf.string, name="residuals_fits"),
        }

    with tf.Session() as sess:
        checkpoint = tf.train.latest_checkpoint(a.checkpoint)
        tf.train.Saver().restore(sess, checkpoint)

        for step in range(steps_per_epoch):
            results = sess.run(fits_fetches)
            filesets = save_images(results, subfolder=None, extention="fits", output_dir=a.output_dir)
            for f in filesets:
                print("wrote " + f['name'])
Example #2
def init(checkpoint):

    min_flux = tf.placeholder(tf.float32, shape=(1, ))
    max_flux = tf.placeholder(tf.float32, shape=(1, ))
    input_ = tf.placeholder(tf.float32, shape=(1, SIZE, SIZE, 2))

    # set up the network
    with tf.variable_scope("generator"):
        outputs = create_generator(input_, 1, NGF, SEPERABLE_CONV)
        deprocessed_output = deprocess(outputs, min_flux, max_flux)

    sess = tf.Session()
    logger.info("restoring data from checkpoint " + checkpoint)
    checkpoint = tf.train.latest_checkpoint(checkpoint)
    tf.train.Saver().restore(sess, checkpoint)
    return Model(session=sess,
                 output=deprocessed_output,
                 input=input_,
                 max_flux=max_flux,
                 min_flux=min_flux)
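
# A minimal usage sketch for the Model returned by init() (hypothetical tile
# and flux names; assumes this module's SIZE constant and an existing
# checkpoint directory):
#
#   model = init("./checkpoints")
#   cleaned = model.session.run(
#       model.output,
#       feed_dict={model.input: input_tile,      # (1, SIZE, SIZE, 2) array
#                  model.min_flux: [dirty_min],  # hypothetical flux bounds
#                  model.max_flux: [dirty_max]})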
Example #3
def main():
    prepare()

    train_batch = generative_model(a.psf_glob,
                                   flux_scale_min=a.flux_scale_min,
                                   flux_scale_max=a.flux_scale_max)
    iterator = train_batch.make_one_shot_iterator()
    index, min_flux, max_flux, psf, dirty, skymodel = iterator.get_next()

    with tf.name_scope("scaling_flux"):
        scaled_skymodel = preprocess(skymodel, min_flux, max_flux)
        scaled_dirty = preprocess(dirty, min_flux, max_flux)
        scaled_psf = (psf * 2) - 1

    # inputs and targets are [batch_size, height, width, channels]
    model = create_model(scaled_dirty,
                         scaled_skymodel,
                         EPS,
                         a.separable_conv,
                         beta1=a.beta1,
                         gan_weight=a.gan_weight,
                         l1_weight=a.l1_weight,
                         lr=a.lr,
                         ndf=a.ndf,
                         ngf=a.ngf,
                         psf=psf,
                         min_flux=min_flux,
                         max_flux=max_flux,
                         res_weight=a.res_weight,
                         disable_psf=a.disable_psf)

    deprocessed_output = deprocess(model.outputs, min_flux, max_flux)

    with tf.name_scope("calculate_residuals"):
        shifted = shift(psf, y=-1, x=-1)
        filter_ = tf.expand_dims(tf.expand_dims(tf.squeeze(shifted), 2), 3)
        convolved = tf.nn.conv2d(deprocessed_output, filter_, [1, 1, 1, 1],
                                 "SAME")
        residuals = dirty - convolved

    # reverse any processing on images so they can be written to disk or displayed to user
    with tf.name_scope("convert_images"):
        converted_inputs = tf.image.convert_image_dtype(
            visual_scaling(scaled_dirty), dtype=tf.uint8, saturate=True)
        converted_targets = tf.image.convert_image_dtype(
            visual_scaling(scaled_skymodel), dtype=tf.uint8, saturate=True)
        converted_outputs = tf.image.convert_image_dtype(visual_scaling(
            model.outputs),
                                                         dtype=tf.uint8,
                                                         saturate=True)
        converted_psfs = tf.image.convert_image_dtype(
            visual_scaling(scaled_psf), dtype=tf.uint8, saturate=True)
        converted_residuals = tf.image.convert_image_dtype(
            visual_scaling(residuals), dtype=tf.uint8, saturate=True)

    with tf.name_scope("encode_images"):
        display_fetches = {
            "indexs":
            index,
            "inputs":
            tf.map_fn(tf.image.encode_png,
                      converted_inputs,
                      dtype=tf.string,
                      name="input_pngs"),
            "targets":
            tf.map_fn(tf.image.encode_png,
                      converted_targets,
                      dtype=tf.string,
                      name="target_pngs"),
            "outputs":
            tf.map_fn(tf.image.encode_png,
                      converted_outputs,
                      dtype=tf.string,
                      name="output_pngs"),
            "psfs":
            tf.map_fn(tf.image.encode_png,
                      converted_psfs,
                      dtype=tf.string,
                      name="psf_pngs"),
            "residuals":
            tf.map_fn(tf.image.encode_png,
                      converted_residuals,
                      dtype=tf.string,
                      name="residual_pngs"),
        }

    # summaries
    with tf.name_scope("combined_summary"):
        tf.summary.image("inputs", converted_inputs)
        tf.summary.image("outputs", converted_outputs)
        tf.summary.image("targets", converted_targets)
        tf.summary.image("residuals", converted_residuals)

    with tf.name_scope("psfs_summary"):
        tf.summary.image("psfss", converted_psfs)

    with tf.name_scope("predict_real_summary"):
        tf.summary.image(
            "predict_real",
            tf.image.convert_image_dtype(model.predict_real, dtype=tf.uint8))

    with tf.name_scope("predict_fake_summary"):
        tf.summary.image(
            "predict_fake",
            tf.image.convert_image_dtype(model.predict_fake, dtype=tf.uint8))

    tf.summary.scalar("discriminator_loss", model.discrim_loss)
    tf.summary.scalar("generator_loss_GAN", model.gen_loss_GAN)
    tf.summary.scalar("generator_loss_L1", model.gen_loss_L1)
    tf.summary.scalar("generator_loss_RES", model.gen_loss_RES)

    with tf.name_scope("parameter_count"):
        parameter_count = tf.reduce_sum(
            [tf.reduce_prod(tf.shape(v)) for v in tf.trainable_variables()])

    logdir = a.output_dir if (a.trace_freq > 0 or a.summary_freq > 0) else None

    train_summary_op = tf.summary.merge_all()
    train_summary_writer = tf.summary.FileWriter(logdir=logdir + '/train')

    saver = tf.train.Saver(max_to_keep=100)
    sv = tf.train.Supervisor(logdir=logdir,
                             save_summaries_secs=0,
                             saver=saver,
                             summary_writer=None,
                             summary_op=None)

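    # cap this process at 60% of GPU memory so the card can be shared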
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.6)

    with sv.managed_session(config=tf.ConfigProto(
            gpu_options=gpu_options)) as sess:
        print("parameter_count =", sess.run(parameter_count))

        start = time.time()

        for step in range(a.max_steps):

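            # true every `freq` steps, and always on the final step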
            def should(freq):
                return freq > 0 and ((step + 1) % freq == 0
                                     or step == a.max_steps - 1)

            options = None
            run_metadata = None
            if should(a.trace_freq):
                print("preparing")
                options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
                run_metadata = tf.RunMetadata()

            fetches = {
                "train": model.train,
                "global_step": sv.global_step,
            }

            if should(a.progress_freq):
                print("progress step")
                fetches["discrim_loss"] = model.discrim_loss
                fetches["gen_loss_GAN"] = model.gen_loss_GAN
                fetches["gen_loss_L1"] = model.gen_loss_L1
                fetches["gen_loss_RES"] = model.gen_loss_RES

            if should(a.summary_freq):
                print("preparing summary")
                fetches["summary"] = train_summary_op

            if should(a.display_freq):
                print("display step step")
                fetches["display"] = display_fetches

            results = sess.run(fetches,
                               options=options,
                               run_metadata=run_metadata)

            if should(a.summary_freq):
                print("recording summary")
                train_summary_writer.add_summary(results["summary"],
                                                 results["global_step"])

            if should(a.display_freq):
                print("saving display images")
                save_images(results["display"],
                            step=results["global_step"],
                            output_dir=a.output_dir)

            if should(a.trace_freq):
                print("recording trace")
                train_summary_writer.add_run_metadata(
                    run_metadata, "step_%d" % results["global_step"])

            if should(a.progress_freq):
                rate = (step + 1) * a.batch_size / (time.time() - start)
                remaining = (a.max_steps - step) * a.batch_size / rate
                print("progress  step %d  image/sec %0.1f  remaining %dm" %
                      (results["global_step"], rate, remaining / 60))
                print("discrim_loss", results["discrim_loss"])
                print("gen_loss_GAN", results["gen_loss_GAN"])
                print("gen_loss_L1", results["gen_loss_L1"])
                print("gen_loss_RES", results["gen_loss_RES"])

            if should(a.save_freq):
                print("saving model")
                saver.save(sess,
                           os.path.join(a.output_dir, "model"),
                           global_step=sv.global_step)

            if sv.should_stop():
                print("supervisor things we should stop!")
                break

        print("done! bye")
Example #4
def main():
    prepare()

    train_batch, train_count = load_data(a.input_dir,
                                         CROP_SIZE,
                                         a.flip,
                                         a.scale_size,
                                         a.max_epochs,
                                         a.batch_size,
                                         start=a.train_start,
                                         end=a.train_end,
                                         loop=True)
    print("train count = %d" % train_count)
    steps_per_epoch = int(math.ceil(train_count / a.batch_size))
    training_iterator = train_batch.make_one_shot_iterator()

    validate_batch, validate_count = load_data(a.input_dir,
                                               CROP_SIZE,
                                               a.flip,
                                               a.scale_size,
                                               a.max_epochs,
                                               a.batch_size,
                                               start=a.validate_start,
                                               end=a.validate_end,
                                               loop=True)
    print("validate count = %d" % validate_count)
    validation_iterator = validate_batch.make_one_shot_iterator()

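    # a feedable iterator: the same graph reads from the training or the
    # validation dataset depending on the handle fed at run time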
    handle = tf.placeholder(tf.string, shape=[])
    iterator = tf.data.Iterator.from_string_handle(handle,
                                                   train_batch.output_types,
                                                   train_batch.output_shapes)
    index, min_flux, max_flux, psf, dirty, skymodel = iterator.get_next()

    with tf.name_scope("scaling_flux"):
        scaled_skymodel = preprocess(skymodel, min_flux, max_flux)
        scaled_dirty = preprocess(dirty, min_flux, max_flux)
        scaled_psf = (psf * 2) - 1

    # I'll just leave this here in case we decide to do something with visibilities again (unlikely)
    # vis = tf.fft2d(tf.complex(scaled_dirty, tf.zeros(shape=(a.batch_size, CROP_SIZE, CROP_SIZE, 1))))
    # real = tf.real(vis)
    # imag = tf.imag(vis)

    if a.disable_psf:
        input_ = scaled_dirty
    else:
        input_ = tf.concat([scaled_dirty, scaled_psf], axis=3)

    # inputs and targets are [batch_size, height, width, channels]
    model = create_model(input_,
                         scaled_skymodel,
                         EPS,
                         a.separable_conv,
                         beta1=a.beta1,
                         gan_weight=a.gan_weight,
                         l1_weight=a.l1_weight,
                         lr=a.lr,
                         ndf=a.ndf,
                         ngf=a.ngf,
                         psf=psf,
                         min_flux=min_flux,
                         max_flux=max_flux,
                         res_weight=a.res_weight)

    deprocessed_output = deprocess(model.outputs, min_flux, max_flux)

    with tf.name_scope("calculate_residuals"):
        shifted = shift(psf, y=-1, x=-1)
        filter_ = tf.expand_dims(tf.expand_dims(tf.squeeze(shifted), 2), 3)
        convolved = tf.nn.conv2d(deprocessed_output, filter_, [1, 1, 1, 1],
                                 "SAME")
        residuals = dirty - convolved

    # reverse any processing on images so they can be written to disk or displayed to user
    with tf.name_scope("convert_images"):
        converted_inputs = tf.image.convert_image_dtype(
            visual_scaling(scaled_dirty), dtype=tf.uint8, saturate=True)
        converted_targets = tf.image.convert_image_dtype(
            visual_scaling(scaled_skymodel), dtype=tf.uint8, saturate=True)
        converted_outputs = tf.image.convert_image_dtype(visual_scaling(
            model.outputs),
                                                         dtype=tf.uint8,
                                                         saturate=True)
        converted_psfs = tf.image.convert_image_dtype(
            visual_scaling(scaled_psf), dtype=tf.uint8, saturate=True)
        converted_residuals = tf.image.convert_image_dtype(
            visual_scaling(residuals), dtype=tf.uint8, saturate=True)

    with tf.name_scope("encode_images"):
        display_fetches = {
            "indexs":
            index,
            "inputs":
            tf.map_fn(tf.image.encode_png,
                      converted_inputs,
                      dtype=tf.string,
                      name="input_pngs"),
            "targets":
            tf.map_fn(tf.image.encode_png,
                      converted_targets,
                      dtype=tf.string,
                      name="target_pngs"),
            "outputs":
            tf.map_fn(tf.image.encode_png,
                      converted_outputs,
                      dtype=tf.string,
                      name="output_pngs"),
            "psfs":
            tf.map_fn(tf.image.encode_png,
                      converted_psfs,
                      dtype=tf.string,
                      name="psf_pngs"),
            "residuals":
            tf.map_fn(tf.image.encode_png,
                      converted_residuals,
                      dtype=tf.string,
                      name="residual_pngs"),
        }

    # summaries
    with tf.name_scope("combined_summary"):
        tf.summary.image("inputs", converted_inputs)
        tf.summary.image("outputs", converted_outputs)
        tf.summary.image("targets", converted_targets)
        tf.summary.image("residuals", converted_residuals)

    with tf.name_scope("psfs_summary"):
        tf.summary.image("psfss", converted_psfs)

    with tf.name_scope("predict_real_summary"):
        tf.summary.image(
            "predict_real",
            tf.image.convert_image_dtype(model.predict_real, dtype=tf.uint8))

    with tf.name_scope("predict_fake_summary"):
        tf.summary.image(
            "predict_fake",
            tf.image.convert_image_dtype(model.predict_fake, dtype=tf.uint8))

    tf.summary.scalar("discriminator_loss", model.discrim_loss)
    tf.summary.scalar("generator_loss_GAN", model.gen_loss_GAN)
    tf.summary.scalar("generator_loss_L1", model.gen_loss_L1)
    tf.summary.scalar("generator_loss_RES", model.gen_loss_RES)

    if a.validation_freq:
        tf.summary.scalar("Validation generator_loss_GAN", model.gen_loss_GAN)
        tf.summary.scalar("Validation generator_loss_L1", model.gen_loss_L1)
    """
    for var in tf.trainable_variables():
        tf.summary.histogram(var.op.name + "/values", var)

    for grad, var in model.discrim_grads_and_vars + model.gen_grads_and_vars:
        tf.summary.histogram(var.op.name + "/gradients", grad)
    """

    with tf.name_scope("parameter_count"):
        parameter_count = tf.reduce_sum(
            [tf.reduce_prod(tf.shape(v)) for v in tf.trainable_variables()])

    logdir = a.output_dir if (a.trace_freq > 0 or a.summary_freq > 0) else None

    validation_summary_op = tf.summary.merge_all()
    validation_summary_writer = tf.summary.FileWriter(logdir=logdir +
                                                      '/validation')

    train_summary_op = tf.summary.merge_all()
    train_summary_writer = tf.summary.FileWriter(logdir=logdir + '/train')

    saver = tf.train.Saver(max_to_keep=100)
    sv = tf.train.Supervisor(logdir=logdir,
                             save_summaries_secs=0,
                             saver=saver,
                             summary_writer=None,
                             summary_op=None)

    with sv.managed_session() as sess:
        print("parameter_count =", sess.run(parameter_count))

        # The `Iterator.string_handle()` method returns a tensor that can be evaluated
        # and used to feed the `handle` placeholder.
        training_handle = sess.run(training_iterator.string_handle())

        validation_handle = sess.run(validation_iterator.string_handle())

        max_steps = 2**32
        if a.max_epochs is not None:
            max_steps = steps_per_epoch * a.max_epochs
        if a.max_steps is not None:
            max_steps = a.max_steps

        start = time.time()

        for step in range(max_steps):

            def should(freq):
                return freq > 0 and ((step + 1) % freq == 0
                                     or step == max_steps - 1)

            options = None
            run_metadata = None
            if should(a.trace_freq):
                print("preparing")
                options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
                run_metadata = tf.RunMetadata()

            fetches = {
                "train": model.train,
                "global_step": sv.global_step,
            }

            if should(a.progress_freq):
                print("progress step")
                fetches["discrim_loss"] = model.discrim_loss
                fetches["gen_loss_GAN"] = model.gen_loss_GAN
                fetches["gen_loss_L1"] = model.gen_loss_L1
                fetches["gen_loss_RES"] = model.gen_loss_RES

            if should(a.summary_freq):
                print("preparing summary")
                fetches["summary"] = train_summary_op

            if should(a.display_freq):
                print("display step step")
                fetches["display"] = display_fetches

            results = sess.run(fetches,
                               options=options,
                               run_metadata=run_metadata,
                               feed_dict={handle: training_handle})

            if should(a.summary_freq):
                print("recording summary")
                train_summary_writer.add_summary(results["summary"],
                                                 results["global_step"])

            if should(a.display_freq):
                print("saving display images")
                save_images(results["display"],
                            step=results["global_step"],
                            output_dir=a.output_dir)

            if should(a.trace_freq):
                print("recording trace")
                train_summary_writer.add_run_metadata(
                    run_metadata, "step_%d" % results["global_step"])

            if should(a.progress_freq):
                # global_step will have the correct step count if we resume from a checkpoint
                train_epoch = math.ceil(results["global_step"] /
                                        steps_per_epoch)
                train_step = (results["global_step"] - 1) % steps_per_epoch + 1
                rate = (step + 1) * a.batch_size / (time.time() - start)
                remaining = (max_steps - step) * a.batch_size / rate
                print(
                    "progress  epoch %d  step %d  image/sec %0.1f  remaining %dm"
                    % (train_epoch, train_step, rate, remaining / 60))
                print("discrim_loss", results["discrim_loss"])
                print("gen_loss_GAN", results["gen_loss_GAN"])
                print("gen_loss_L1", results["gen_loss_L1"])
                print("gen_loss_RES", results["gen_loss_RES"])

            if should(a.save_freq):
                print("saving model")
                saver.save(sess,
                           os.path.join(a.output_dir, "model"),
                           global_step=sv.global_step)

            if a.validation_freq and should(a.validation_freq):
                print("validation step")
                validation_fetches = {
                    "validation_gen_loss_GAN": model.gen_loss_GAN,
                    "validation_gen_loss_L1": model.gen_loss_L1,
                    "summary": validation_summary_op,
                }

                validation_results = sess.run(
                    validation_fetches, feed_dict={handle: validation_handle})
                print("validation gen_loss_GAN",
                      validation_results["validation_gen_loss_GAN"])
                print("validation gen_loss_L1",
                      validation_results["validation_gen_loss_L1"])
                validation_summary_writer.add_summary(
                    validation_results["summary"], results["global_step"])

            if sv.should_stop():
                print("supervisor things we should stop!")
                break

        print("done! bye")
Example #5
def main():
    dirty_path = os.path.realpath(a.dirty)
    psf_path = os.path.realpath(a.psf)
    big_fits = fits.open(str(dirty_path))[0]
    big_data = big_fits.data.squeeze()[:, :, np.newaxis]
    big_psf_fits = fits.open(str(psf_path))[0]
    assert big_psf_fits.data.shape == big_fits.data.shape

    # we need a smaller PSF to give as a channel to the dirty tiles
    big_psf_data = big_psf_fits.data.squeeze()
    big_psf_data = big_psf_data / big_psf_data.max()
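    # crop a SIZE x SIZE window around the centre of the peak-normalised PSF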
    psf_small = big_psf_data[big_psf_data.shape[0] // 2 - SIZE // 2 +
                             1:big_psf_data.shape[0] // 2 + SIZE // 2 + 1,
                             big_psf_data.shape[1] // 2 - SIZE // 2 +
                             1:big_psf_data.shape[1] // 2 + SIZE // 2 + 1]

    logger.debug(psf_small.shape)
    logger.debug((big_psf_data.shape[0] // 2 - SIZE // 2 + 1,
                  big_psf_data.shape[0] // 2 + SIZE // 2 + 1,
                  big_psf_data.shape[1] // 2 - SIZE // 2 + 1,
                  big_psf_data.shape[1] // 2 + SIZE // 2 + 1))

    psf_small = psf_small[:, :, np.newaxis]

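    # number of tile rows and columns the sliding window produces at the module-level stride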
    n_r = int(big_data.shape[0] / stride)
    n_c = int(big_data.shape[1] / stride)

    # set up the data loading
    batch, count = load_data(big_data, psf_small, n_r, n_c)
    steps_per_epoch = count
    iterator = batch.make_one_shot_iterator()
    index, min_flux, max_flux, psf, dirty = iterator.get_next()
    scaled_dirty = preprocess(dirty, min_flux, max_flux)
    scaled_psf = (psf * 2) - 1
    input_ = tf.concat([scaled_dirty, scaled_psf], axis=3)

    # set up the network
    with tf.variable_scope("generator"):
        outputs = create_generator(input_, 1, NGF, SEPERABLE_CONV)
        deprocessed_output = deprocess(outputs, min_flux, max_flux)

    # run all data through the network
    queue_ = IterableQueue()
    with tf.Session() as sess:
        logger.info("restoring data from checkpoint " + a.checkpoint)
        checkpoint = tf.train.latest_checkpoint(a.checkpoint)
        tf.train.Saver().restore(sess, checkpoint)

        for step in range(steps_per_epoch):
            n = sess.run(deprocessed_output)
            queue_.put(n)

    # reconstruct the data
    big_model = restore(big_data.squeeze().shape, iter(queue_), n_r, n_c)
    p = big_psf_data.shape[0]
    #r = slice(p // 2, -p // 2 + 1)  # uneven PSF needs +2, even psf +1
    r = slice(p // 2 + 1, -p // 2 + 2)
    convolved = fftconvolve(big_model, big_psf_data, mode="full")[r, r]
    residual = big_fits.data.squeeze() - convolved

    # write the data
    hdu = fits.PrimaryHDU(big_model.squeeze())
    hdu.header = big_fits.header
    hdul = fits.HDUList([hdu])
    hdul.writeto("vacuum-model.fits", overwrite=True)

    hdu = fits.PrimaryHDU(residual.squeeze())
    hdu.header = big_fits.header
    hdul = fits.HDUList([hdu])
    hdul.writeto("vacuum-residual.fits", overwrite=True)

    logger.info("done!")
Example #6
def create_model(inputs, targets, EPS, separable_conv, ngf, ndf, gan_weight,
                 l1_weight, res_weight, lr, beta1, psf, min_flux, max_flux):
    # type: (tf.Tensor, tf.Tensor, float, bool, int, int, float, float, float, float, float, tf.Tensor, float, float) -> Model

    with tf.variable_scope("generator"):
        out_channels = 1
        outputs = create_generator(inputs, out_channels, ngf, separable_conv)

    # create two copies of discriminator, one for real pairs and one for fake pairs
    # they share the same underlying variables
    with tf.name_scope("real_discriminator"):
        with tf.variable_scope("discriminator"):
            # 2x [batch, height, width, channels] => [batch, 30, 30, 1]
            predict_real = create_discriminator(inputs, targets, ndf)

    with tf.name_scope("fake_discriminator"):
        with tf.variable_scope("discriminator", reuse=True):
            # 2x [batch, height, width, channels] => [batch, 30, 30, 1]
            predict_fake = create_discriminator(inputs, outputs, ndf)

    with tf.name_scope("discriminator_loss"):
        # minimizing -tf.log will try to get inputs to 1
        # predict_real => 1
        # predict_fake => 0
        discrim_loss = tf.reduce_mean(-(tf.log(predict_real + EPS) +
                                        tf.log(1 - predict_fake + EPS)))

    with tf.name_scope("discriminator_residuals"):
        deprocessed_output = deprocess(outputs, min_flux, max_flux)
        shifted = shift(psf, y=-1, x=-1)
        filter_ = tf.expand_dims(tf.expand_dims(tf.squeeze(shifted), 2), 3)
        convolved = tf.nn.conv2d(deprocessed_output, filter_, [1, 1, 1, 1],
                                 "SAME")
        residuals = targets - convolved

    with tf.name_scope("generator_loss"):
        # predict_fake => 1
        # abs(targets - outputs) => 0
        gen_loss_GAN = tf.reduce_mean(-tf.log(predict_fake + EPS))
        gen_loss_L1 = tf.reduce_mean(tf.abs(targets - outputs))
        gen_loss_RES = tf.reduce_mean(
            tf.abs(residuals - tf.reduce_mean(residuals)))
        gen_loss = gen_loss_GAN * gan_weight + gen_loss_L1 * l1_weight + gen_loss_RES * res_weight

    with tf.name_scope("discriminator_train"):
        discrim_tvars = [
            var for var in tf.trainable_variables()
            if var.name.startswith("discriminator")
        ]
        discrim_optim = tf.train.AdamOptimizer(lr, beta1)
        discrim_grads_and_vars = discrim_optim.compute_gradients(
            discrim_loss, var_list=discrim_tvars)
        discrim_train = discrim_optim.apply_gradients(discrim_grads_and_vars)

    with tf.name_scope("generator_train"):
        with tf.control_dependencies([discrim_train]):
            gen_tvars = [
                var for var in tf.trainable_variables()
                if var.name.startswith("generator")
            ]
            gen_optim = tf.train.AdamOptimizer(lr, beta1)
            gen_grads_and_vars = gen_optim.compute_gradients(
                gen_loss, var_list=gen_tvars)
            gen_train = gen_optim.apply_gradients(gen_grads_and_vars)

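    # report exponentially smoothed losses rather than raw per-step values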
    ema = tf.train.ExponentialMovingAverage(decay=0.99)
    update_losses = ema.apply(
        [discrim_loss, gen_loss_GAN, gen_loss_L1, gen_loss_RES])

    global_step = tf.train.get_or_create_global_step()
    incr_global_step = tf.assign(global_step, global_step + 1)

    return Model(
        predict_real=predict_real,
        predict_fake=predict_fake,
        discrim_loss=ema.average(discrim_loss),
        discrim_grads_and_vars=discrim_grads_and_vars,
        gen_loss_GAN=ema.average(gen_loss_GAN),
        gen_loss_L1=ema.average(gen_loss_L1),
        gen_loss_RES=ema.average(gen_loss_RES),
        gen_grads_and_vars=gen_grads_and_vars,
        outputs=outputs,
        train=tf.group(update_losses, incr_global_step, gen_train),
    )
Example #7
def test(
    input_dir,
    output_dir,
    checkpoint,
    batch_size=1,
    test_start=0,
    test_end=999,
    disable_psf=False,
    ngf=64,
    separable_conv=False,
    write_residuals=False,
    write_input=False,
):
    batch, count = load_data(path=input_dir,
                             flip=False,
                             crop_size=CROP_SIZE,
                             scale_size=CROP_SIZE,
                             max_epochs=1,
                             batch_size=batch_size,
                             start=test_start,
                             end=test_end)
    steps_per_epoch = int(math.ceil(count / batch_size))
    iter_ = batch.make_one_shot_iterator()
    index, min_flux, max_flux, psf, dirty, skymodel = iter_.get_next()
    print("train count = %d" % count)

    with tf.name_scope("scaling_flux"):
        scaled_dirty = preprocess(dirty, min_flux, max_flux)
        scaled_psf = (psf * 2) - 1

    if disable_psf:
        input_ = scaled_dirty
    else:
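        # note: this branch feeds the unscaled dirty image plus a central crop of the PSF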
        input_ = tf.concat([dirty, psf[:, 128:-128, 128:-128, :]], axis=3)

    with tf.variable_scope("generator"):
        generator = create_generator(input_,
                                     1,
                                     ngf=ngf,
                                     separable_conv=separable_conv)
        deprocessed_output = deprocess(generator, min_flux, max_flux)

    if write_residuals:
        with tf.name_scope("calculate_residuals"):
            shifted = shift(psf, y=-1, x=-1)
            filter_ = tf.expand_dims(tf.expand_dims(tf.squeeze(shifted), 2), 3)
            convolved = tf.nn.conv2d(deprocessed_output, filter_, [1, 1, 1, 1],
                                     "SAME")
            residuals = dirty - convolved

    with tf.name_scope("encode_fitss"):
        work = {
            "indexs":
            index,
            "outputs":
            tf.map_fn(fits_encode,
                      deprocessed_output,
                      dtype=tf.string,
                      name="output_fits"),
        }
        if write_residuals:
            work["residuals"] = tf.map_fn(fits_encode,
                                          residuals,
                                          dtype=tf.string,
                                          name="residuals_fits")
        if write_input:
            work["inputs"] = tf.map_fn(fits_encode,
                                       dirty,
                                       dtype=tf.string,
                                       name="input_fits")

    sv = tf.train.Supervisor(logdir=None)
    with sv.managed_session() as sess:
        sv.saver.restore(sess, checkpoint)

        for step in range(steps_per_epoch):
            results = sess.run(work)
            filesets = save_images(results,
                                   subfolder="fits",
                                   extention="fits",
                                   output_dir=output_dir)
            for f in filesets:
                print("wrote " + f['name'])
Example #8
def main():
    prepare()

    batch, count = load_data(path=a.input_dir, flip=False, crop_size=CROP_SIZE, scale_size=CROP_SIZE, max_epochs=1,
                             batch_size=a.batch_size, start=a.test_start, end=a.test_end)
    steps_per_epoch = int(math.ceil(count / a.batch_size))
    iter_ = batch.make_one_shot_iterator()
    index, min_flux, max_flux, psf, dirty, skymodel = iter_.get_next()
    print("train count = %d" % count)

    with tf.name_scope("scaling_flux"):
        scaled_dirty = preprocess(dirty, min_flux, max_flux)
        scaled_psf = (psf * 2) - 1

    if a.disable_psf:
        input_ = scaled_dirty
    else:
        input_ = tf.concat([scaled_dirty, scaled_psf], axis=3)

    with tf.variable_scope("generator"):
        generator = create_generator(input_, 1, ngf=a.ngf, separable_conv=a.separable_conv)
        deprocessed_output = deprocess(generator, min_flux, max_flux)

    with tf.name_scope("calculate_residuals"):
        shifted = shift(psf, y=-1, x=-1)
        filter_ = tf.expand_dims(tf.expand_dims(tf.squeeze(shifted), 2), 3)
        convolved = tf.nn.conv2d(deprocessed_output, filter_, [1, 1, 1, 1], "SAME")
        residuals = dirty - convolved

    with tf.name_scope("encode_fitss"):
        fits_fetches = {
            "indexs": index,
            "inputs": tf.map_fn(fits_encode, dirty, dtype=tf.string, name="input_fits"),
            "outputs": tf.map_fn(fits_encode, deprocessed_output, dtype=tf.string, name="output_fits"),
            "residuals": tf.map_fn(fits_encode, residuals, dtype=tf.string, name="residuals_fits"),
        }

    with tf.name_scope("parameter_count"):
        parameter_count = tf.reduce_sum([tf.reduce_prod(tf.shape(v)) for v in tf.trainable_variables()])

    saver = tf.train.Saver(max_to_keep=100)

    sv = tf.train.Supervisor(logdir=None, save_summaries_secs=0, saver=None)
    with sv.managed_session() as sess:
        print("parameter_count =", sess.run(parameter_count))

        if a.checkpoint is not None:
            print("loading model from checkpoint")
            checkpoint = tf.train.latest_checkpoint(a.checkpoint)
            print("loaded {}".format(checkpoint))
            saver.restore(sess, checkpoint)

        max_steps = 2 ** 32
        if a.max_epochs is not None:
            max_steps = steps_per_epoch * a.max_epochs
        if a.max_steps is not None:
            max_steps = a.max_steps

        # at most, process the test data once
        max_steps = min(steps_per_epoch, max_steps)

        # repeat the same for fits arrays
        for step in range(max_steps):
            results = sess.run(fits_fetches)
            filesets = save_images(results, subfolder="fits", extention="fits", output_dir=a.output_dir)
            for f in filesets:
                print("wrote " + f['name'])
Example #9
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_dir", required=True, help="where to put output files")
    parser.add_argument("--checkpoint", required=True, help="directory with checkpoint to resume training from or use for testing")
    parser.add_argument("--separable_conv", action="store_true", help="use separable convolutions in the generator")
    parser.add_argument("--ngf", type=int, default=64, help="number of generator filters in first conv layer")
    parser.add_argument('--disable_psf', action='store_true', help="disable the concatenation of the PSF as a channel")

    a = parser.parse_args()

    def load_data(dirty_path, psf_path):
        # type: (str, str) -> tf.data.Dataset
        def dataset_generator():
            psf = fits_open(psf_path)[:, :, np.newaxis]
            dirty = fits_open(dirty_path)[:, :, np.newaxis]
            min_flux = dirty.min()
            max_flux = dirty.max()
            yield min_flux, max_flux, psf, dirty

        ds = tf.data.Dataset.from_generator(dataset_generator,
                                            output_shapes=((), ()) + ((256, 256, 1),) * 2,
                                            output_types=(tf.float32, tf.float32) + (tf.float32,) * 2
                                            )
        ds = ds.batch(1)
        return ds

    dirty_path = tf.placeholder(tf.string, shape=[1])
    psf_path = tf.placeholder(tf.string, shape=[1])
    batch = load_data(dirty_path, psf_path)

    iterator = batch.make_one_shot_iterator()
    min_flux, max_flux, psf, dirty = iterator.get_next()

    scaled_dirty = preprocess(dirty, min_flux, max_flux)
    scaled_psf = preprocess(psf, min_flux, max_flux)

    if a.disable_psf:
        input_ = scaled_dirty
    else:
        input_ = tf.concat([scaled_dirty, scaled_psf], axis=3)

    with tf.variable_scope("generator"):
        generator = create_generator(input_, 1, ngf=a.ngf, separable_conv=a.separable_conv)
        batch_output = deprocess(generator, min_flux, max_flux)

    output_image = tf.image.convert_image_dtype(batch_output, dtype=tf.uint8)[0]

    # let's just assume PNG for now
    output_data = tf.image.encode_png(output_image)
    output = tf.convert_to_tensor([tf.encode_base64(output_data)])

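    # record the serving input/output tensor names in graph collections so an
    # export consumer can look them up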
    key = tf.placeholder(tf.string, shape=[1])
    inputs = {
        "key": key.name,
        "input": dirty.name
    }
    tf.add_to_collection("inputs", json.dumps(inputs))
    outputs = {
        "key": tf.identity(key).name,
        "output": output.name,
    }
    tf.add_to_collection("outputs", json.dumps(outputs))

    init_op = tf.global_variables_initializer()
    restore_saver = tf.train.Saver()
    export_saver = tf.train.Saver()

    with tf.Session() as sess:
        sess.run(init_op)
        print("loading model from checkpoint")
        checkpoint = tf.train.latest_checkpoint(a.checkpoint)
        restore_saver.restore(sess, checkpoint)
        print("exporting model")
        #export_saver.export_meta_graph(filename=os.path.join(a.output_dir, "export.meta"))
        export_saver.save(sess, os.path.join(a.output_dir, "export"), write_meta_graph=True) #, save_relative_paths=True)