Exemple #1
0
def main():
    """Run the trained generator over paired dirty/PSF FITS files.

    Expects exactly two command-line arguments: a comma-separated list of
    dirty images and a comma-separated list of PSFs, matched by position.
    For every step the deprocessed network output and the residual (dirty
    image minus the output convolved with the PSF) are FITS-encoded and
    passed to ``save_images``; the name of each written file is printed.

    Exits with status 1 when the argument count is wrong or when the two
    lists do not contain the same number of files.
    """
    if len(sys.argv) != 3:
        print("""
usage: {}  dirty-0.fits,dirty-1.fits,dirty-2.fits  psf-0.fits,psf-1.fits,psf2.fits
        
 note: names don't matter, order does. only supports fits files of {}x{}
       will write output the current folder.
""".format(sys.argv[0], CROP_SIZE, CROP_SIZE))
        sys.exit(1)

    dirties = [os.path.realpath(i) for i in sys.argv[1].split(',')]
    psfs = [os.path.realpath(i) for i in sys.argv[2].split(',')]
    # Explicit check rather than `assert`: assertions are stripped under
    # `python -O`, and a bad command line deserves a message, not a traceback.
    if len(dirties) != len(psfs):
        sys.exit("error: got {} dirty image(s) but {} PSF(s); "
                 "both lists must have the same length".format(
                     len(dirties), len(psfs)))

    batch, count = load_data(dirties, psfs)
    # One session run per element reported by load_data.
    steps_per_epoch = count
    # Named `iterator` so the builtin `iter` is not shadowed.
    iterator = batch.make_one_shot_iterator()
    index, min_flux, max_flux, psf, dirty = iterator.get_next()

    scaled_dirty = preprocess(dirty, min_flux, max_flux)
    scaled_psf = (psf * 2) - 1

    # Dirty image and PSF are stacked along axis 3 as generator input.
    input_ = tf.concat([scaled_dirty, scaled_psf], axis=3)

    # set up the network
    with tf.variable_scope("generator"):
        outputs = create_generator(input_, 1, a.ngf, a.separable_conv)
        deprocessed_output = deprocess(outputs, min_flux, max_flux)

    with tf.name_scope("calculate_residuals"):
        # NOTE(review): the one-pixel shift presumably aligns the PSF centre
        # for the "SAME" convolution below -- confirm against `shift`.
        shifted = shift(psf, y=-1, x=-1)
        # Drop the size-1 axes and append two trailing axes so the PSF can be
        # used as a conv2d filter (assumes a single PSF per batch -- TODO confirm).
        filter_ = tf.expand_dims(tf.expand_dims(tf.squeeze(shifted), 2), 3)
        convolved = tf.nn.conv2d(deprocessed_output, filter_, [1, 1, 1, 1], "SAME")
        residuals = dirty - convolved

    with tf.name_scope("encode_fitss"):
        fits_fetches = {
            "indexs": index,
            "outputs": tf.map_fn(fits_encode, deprocessed_output, dtype=tf.string, name="output_fits"),
            "residuals": tf.map_fn(fits_encode, residuals, dtype=tf.string, name="residuals_fits"),
        }

    with tf.Session() as sess:
        checkpoint = tf.train.latest_checkpoint(a.checkpoint)
        tf.train.Saver().restore(sess, checkpoint)

        for step in range(steps_per_epoch):
            results = sess.run(fits_fetches)
            filesets = save_images(results, subfolder=None, extention="fits", output_dir=a.output_dir)
            for f in filesets:
                print("wrote " + f['name'])
Exemple #2
0
def init(checkpoint):
    """Build the generator graph on placeholders and load trained weights.

    :param checkpoint: path passed to ``tf.train.latest_checkpoint`` to
                       locate the most recent checkpoint file
    :return: a ``Model`` holding the open session, the input placeholder,
             the ``min_flux`` / ``max_flux`` placeholders and the
             deprocessed generator output
    """
    # Placeholders are created in this order on purpose: TensorFlow derives
    # their default op names from the creation order.
    min_flux = tf.placeholder(tf.float32, shape=(1, ))
    max_flux = tf.placeholder(tf.float32, shape=(1, ))
    input_ = tf.placeholder(tf.float32, shape=(1, SIZE, SIZE, 2))

    # set up the network
    with tf.variable_scope("generator"):
        raw_output = create_generator(input_, 1, NGF, SEPERABLE_CONV)
        deprocessed_output = deprocess(raw_output, min_flux, max_flux)

    # The session is deliberately left open; it is returned inside the Model.
    session = tf.Session()
    logger.info("restoring data from checkpoint " + checkpoint)
    latest = tf.train.latest_checkpoint(checkpoint)
    saver = tf.train.Saver()
    saver.restore(session, latest)

    return Model(
        session=session,
        input=input_,
        min_flux=min_flux,
        max_flux=max_flux,
        output=deprocessed_output,
    )
Exemple #3
0
def main():
    """Clean one large dirty image tile by tile and write model + residual.

    Reads the dirty image and PSF named by ``a.dirty`` / ``a.psf``, cuts a
    SIZE x SIZE window out of the normalised PSF, runs every tile produced
    by ``load_data`` through the generator, reassembles the tiles with
    ``restore`` and writes ``vacuum-model.fits`` and ``vacuum-residual.fits``
    to the current directory, reusing the dirty image's FITS header.
    """
    dirty_path = os.path.realpath(a.dirty)
    psf_path = os.path.realpath(a.psf)
    big_fits = fits.open(str(dirty_path))[0]
    big_data = big_fits.data.squeeze()[:, :, np.newaxis]
    big_psf_fits = fits.open(str(psf_path))[0]
    assert (big_psf_fits.data.shape == big_fits.data.shape)

    # we need a smaller PSF to give as a channel to the dirty tiles
    big_psf_data = big_psf_fits.data.squeeze()
    big_psf_data = big_psf_data / big_psf_data.max()  # peak becomes 1
    half = SIZE // 2
    row_mid = big_psf_data.shape[0] // 2
    col_mid = big_psf_data.shape[1] // 2
    row_start, row_stop = row_mid - half + 1, row_mid + half + 1
    col_start, col_stop = col_mid - half + 1, col_mid + half + 1
    psf_small = big_psf_data[row_start:row_stop, col_start:col_stop]

    logger.debug(psf_small.shape)
    logger.debug((row_start, row_stop, col_start, col_stop))

    psf_small = psf_small[:, :, np.newaxis]

    # Number of tiles along each axis; `stride` comes from module scope.
    n_r = int(big_data.shape[0] / stride)
    n_c = int(big_data.shape[1] / stride)

    # set up the data loading
    batch, count = load_data(big_data, psf_small, n_r, n_c)
    steps_per_epoch = count
    iterator = batch.make_one_shot_iterator()
    index, min_flux, max_flux, psf, dirty = iterator.get_next()
    scaled_dirty = preprocess(dirty, min_flux, max_flux)
    scaled_psf = (psf * 2) - 1
    input_ = tf.concat([scaled_dirty, scaled_psf], axis=3)

    # set up the network
    with tf.variable_scope("generator"):
        outputs = create_generator(input_, 1, NGF, SEPERABLE_CONV)
        deprocessed_output = deprocess(outputs, min_flux, max_flux)

    # run all data through the network, collecting each tile in a queue
    queue_ = IterableQueue()
    with tf.Session() as sess:
        logger.info("restoring data from checkpoint " + a.checkpoint)
        checkpoint = tf.train.latest_checkpoint(a.checkpoint)
        tf.train.Saver().restore(sess, checkpoint)

        for step in range(steps_per_epoch):
            queue_.put(sess.run(deprocessed_output))

    # reconstruct the data
    big_model = restore(big_data.squeeze().shape, iter(queue_), n_r, n_c)
    psf_side = big_psf_data.shape[0]
    # Window cut out of the "full" convolution. An earlier variant used
    # slice(p // 2, -p // 2 + 1); original remark: "uneven PSF needs +2,
    # even psf +1".
    window = slice(psf_side // 2 + 1, -psf_side // 2 + 2)
    convolved = fftconvolve(big_model, big_psf_data, mode="full")[window, window]
    residual = big_fits.data.squeeze() - convolved

    # write the data; both products carry the dirty image's header
    for payload, filename in ((big_model, "vacuum-model.fits"),
                              (residual, "vacuum-residual.fits")):
        hdu = fits.PrimaryHDU(payload.squeeze())
        hdu.header = big_fits.header
        fits.HDUList([hdu]).writeto(filename, overwrite=True)

    logger.info("done!")
Exemple #4
0
def test(
    input_dir,
    output_dir,
    checkpoint,
    batch_size=1,
    test_start=0,
    test_end=999,
    disable_psf=False,
    ngf=64,
    separable_conv=False,
    write_residuals=False,
    write_input=False,
):
    """Run the generator over a test set and write the results as FITS.

    :param input_dir: forwarded to ``load_data`` as ``path``
    :param output_dir: forwarded to ``save_images`` as ``output_dir``
    :param checkpoint: checkpoint path handed directly to the saver's
                       ``restore`` (no ``latest_checkpoint`` lookup here)
    :param batch_size: batch size for ``load_data``; also used to derive the
                       number of session runs
    :param test_start: forwarded to ``load_data`` as ``start``
    :param test_end: forwarded to ``load_data`` as ``end``
    :param disable_psf: feed only the scaled dirty image to the generator
    :param ngf: forwarded to ``create_generator``
    :param separable_conv: forwarded to ``create_generator``
    :param write_residuals: also compute and encode dirty - (output * PSF)
    :param write_input: also encode the dirty input
    """
    batch, count = load_data(
        path=input_dir,
        flip=False,
        crop_size=CROP_SIZE,
        scale_size=CROP_SIZE,
        max_epochs=1,
        batch_size=batch_size,
        start=test_start,
        end=test_end,
    )
    steps_per_epoch = int(math.ceil(count / batch_size))
    iter_ = batch.make_one_shot_iterator()
    index, min_flux, max_flux, psf, dirty, skymodel = iter_.get_next()
    print("train count = %d" % count)

    with tf.name_scope("scaling_flux"):
        scaled_dirty = preprocess(dirty, min_flux, max_flux)
        scaled_psf = (psf * 2) - 1

    # NOTE(review): with the PSF enabled the generator receives the *unscaled*
    # dirty image and a centre-cropped, unscaled PSF; `scaled_psf` is never
    # used. Confirm against the training graph whether this is intended.
    input_ = (scaled_dirty if disable_psf else tf.concat(
        [dirty, psf[:, 128:-128, 128:-128, :]], axis=3))

    with tf.variable_scope("generator"):
        generator = create_generator(input_,
                                     1,
                                     ngf=ngf,
                                     separable_conv=separable_conv)
        deprocessed_output = deprocess(generator, min_flux, max_flux)

    if write_residuals:
        with tf.name_scope("calculate_residuals"):
            shifted = shift(psf, y=-1, x=-1)
            # squeeze, then append two axes so the PSF can serve as a
            # conv2d filter
            kernel = tf.expand_dims(tf.squeeze(shifted), 2)
            filter_ = tf.expand_dims(kernel, 3)
            convolved = tf.nn.conv2d(deprocessed_output, filter_, [1, 1, 1, 1],
                                     "SAME")
            residuals = dirty - convolved

    def encode(tensor, name):
        # FITS-encode every element of the batch into a string tensor.
        return tf.map_fn(fits_encode, tensor, dtype=tf.string, name=name)

    with tf.name_scope("encode_fitss"):
        work = {"indexs": index}
        work["outputs"] = encode(deprocessed_output, "output_fits")
        if write_residuals:
            work["residuals"] = encode(residuals, "residuals_fits")
        if write_input:
            work["inputs"] = encode(dirty, "input_fits")

    sv = tf.train.Supervisor(logdir=None)
    with sv.managed_session() as sess:
        sv.saver.restore(sess, checkpoint)

        for step in range(steps_per_epoch):
            filesets = save_images(sess.run(work),
                                   subfolder="fits",
                                   extention="fits",
                                   output_dir=output_dir)
            for written in filesets:
                print("wrote " + written['name'])
Exemple #5
0
def main():
    """Evaluate the generator on the test range and write FITS products.

    Uses the module-level options object ``a``. For every batch the dirty
    input, the deprocessed output and the residual are FITS-encoded and
    handed to ``save_images``. The test data is processed at most once,
    regardless of ``a.max_epochs`` / ``a.max_steps``.
    """
    prepare()

    batch, count = load_data(
        path=a.input_dir,
        flip=False,
        crop_size=CROP_SIZE,
        scale_size=CROP_SIZE,
        max_epochs=1,
        batch_size=a.batch_size,
        start=a.test_start,
        end=a.test_end,
    )
    steps_per_epoch = int(math.ceil(count / a.batch_size))
    iter_ = batch.make_one_shot_iterator()
    index, min_flux, max_flux, psf, dirty, skymodel = iter_.get_next()
    print("train count = %d" % count)

    with tf.name_scope("scaling_flux"):
        scaled_dirty = preprocess(dirty, min_flux, max_flux)
        scaled_psf = (psf * 2) - 1

    # The PSF is an extra input channel unless explicitly disabled.
    input_ = (scaled_dirty if a.disable_psf
              else tf.concat([scaled_dirty, scaled_psf], axis=3))

    with tf.variable_scope("generator"):
        generator = create_generator(input_, 1, ngf=a.ngf, separable_conv=a.separable_conv)
        deprocessed_output = deprocess(generator, min_flux, max_flux)

    with tf.name_scope("calculate_residuals"):
        shifted = shift(psf, y=-1, x=-1)
        kernel = tf.expand_dims(tf.squeeze(shifted), 2)
        filter_ = tf.expand_dims(kernel, 3)
        convolved = tf.nn.conv2d(deprocessed_output, filter_, [1, 1, 1, 1], "SAME")
        residuals = dirty - convolved

    def encode(tensor, name):
        # FITS-encode every element of the batch into a string tensor.
        return tf.map_fn(fits_encode, tensor, dtype=tf.string, name=name)

    with tf.name_scope("encode_fitss"):
        fits_fetches = {"indexs": index}
        fits_fetches["inputs"] = encode(dirty, "input_fits")
        fits_fetches["outputs"] = encode(deprocessed_output, "output_fits")
        fits_fetches["residuals"] = encode(residuals, "residuals_fits")

    with tf.name_scope("parameter_count"):
        variable_sizes = [tf.reduce_prod(tf.shape(v)) for v in tf.trainable_variables()]
        parameter_count = tf.reduce_sum(variable_sizes)

    saver = tf.train.Saver(max_to_keep=100)

    # The supervisor gets no saver of its own; restoring uses `saver` above.
    sv = tf.train.Supervisor(logdir=None, save_summaries_secs=0, saver=None)
    with sv.managed_session() as sess:
        print("parameter_count =", sess.run(parameter_count))

        if a.checkpoint is not None:
            print("loading model from checkpoint")
            checkpoint = tf.train.latest_checkpoint(a.checkpoint)
            print("loaded {}".format(checkpoint))
            saver.restore(sess, checkpoint)

        # An explicit step limit wins over an epoch limit; with neither set
        # the limit is effectively unbounded.
        if a.max_steps is not None:
            max_steps = a.max_steps
        elif a.max_epochs is not None:
            max_steps = steps_per_epoch * a.max_epochs
        else:
            max_steps = 2 ** 32

        # at most, process the test data once
        max_steps = min(steps_per_epoch, max_steps)

        # repeat the same for fits arrays
        for step in range(max_steps):
            filesets = save_images(sess.run(fits_fetches), subfolder="fits",
                                   extention="fits", output_dir=a.output_dir)
            for written in filesets:
                print("wrote " + written['name'])
Exemple #6
0
def main():
    """Rebuild the generator graph, load trained weights and export it.

    Parses the command line, constructs an inference graph whose output is
    a base64-encoded PNG of the first image in the batch, records the
    input/output tensor names as JSON in the "inputs" and "outputs" graph
    collections, restores the latest checkpoint found under
    ``--checkpoint`` and saves the graph plus weights as
    ``<output_dir>/export``.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--output_dir", required=True, help="where to put output files")
    arg_parser.add_argument("--checkpoint", required=True, help="directory with checkpoint to resume training from or use for testing")
    arg_parser.add_argument("--separable_conv", action="store_true", help="use separable convolutions in the generator")
    arg_parser.add_argument("--ngf", type=int, default=64, help="number of generator filters in first conv layer")
    arg_parser.add_argument("--disable_psf", action="store_true", help="disable the concatenation of the PSF as a channel")

    a = arg_parser.parse_args()

    def load_data(dirty_path, psf_path):
        # type: (str, str) -> tf.data.Dataset
        """Wrap a single (min_flux, max_flux, psf, dirty) example as a dataset."""

        def single_example():
            psf = fits_open(psf_path)[:, :, np.newaxis]
            dirty = fits_open(dirty_path)[:, :, np.newaxis]
            yield dirty.min(), dirty.max(), psf, dirty

        image_shape = (256, 256, 1)
        dataset = tf.data.Dataset.from_generator(
            single_example,
            output_shapes=((), (), image_shape, image_shape),
            output_types=(tf.float32, tf.float32, tf.float32, tf.float32),
        )
        return dataset.batch(1)

    # NOTE(review): the generator above closes over these placeholder tensors
    # rather than over file names, and nothing visible here pulls from the
    # iterator -- presumably the dataset only exists to shape the exported
    # graph; confirm.
    dirty_path = tf.placeholder(tf.string, shape=[1])
    psf_path = tf.placeholder(tf.string, shape=[1])
    batch = load_data(dirty_path, psf_path)

    iterator = batch.make_one_shot_iterator()
    min_flux, max_flux, psf, dirty = iterator.get_next()

    scaled_dirty = preprocess(dirty, min_flux, max_flux)
    scaled_psf = preprocess(psf, min_flux, max_flux)

    # The PSF is an extra input channel unless explicitly disabled.
    input_ = (scaled_dirty if a.disable_psf
              else tf.concat([scaled_dirty, scaled_psf], axis=3))

    with tf.variable_scope("generator"):
        generator = create_generator(input_, 1, ngf=a.ngf, separable_conv=a.separable_conv)
        batch_output = deprocess(generator, min_flux, max_flux)

    # Only the first image of the batch is exported.
    as_uint8 = tf.image.convert_image_dtype(batch_output, dtype=tf.uint8)
    output_image = as_uint8[0]

    # lets just assume png for now
    output_data = tf.image.encode_png(output_image)
    output = tf.convert_to_tensor([tf.encode_base64(output_data)])

    # Tensor names are stored as JSON in graph collections.
    key = tf.placeholder(tf.string, shape=[1])
    inputs = {"key": key.name, "input": dirty.name}
    tf.add_to_collection("inputs", json.dumps(inputs))

    key_passthrough = tf.identity(key)
    outputs = {"key": key_passthrough.name, "output": output.name}
    tf.add_to_collection("outputs", json.dumps(outputs))

    init_op = tf.global_variables_initializer()
    restore_saver = tf.train.Saver()
    export_saver = tf.train.Saver()

    with tf.Session() as sess:
        sess.run(init_op)
        print("loading model from checkpoint")
        checkpoint = tf.train.latest_checkpoint(a.checkpoint)
        restore_saver.restore(sess, checkpoint)
        print("exporting model")
        export_path = os.path.join(a.output_dir, "export")
        export_saver.save(sess, export_path, write_meta_graph=True)