Example #1
def encode_examples(kappa: tf.Tensor, galaxies: tf.Tensor,
                    lensed_images: tf.Tensor, z_source: float, z_lens: float,
                    image_fov: float, kappa_fov: float, source_fov: float,
                    noise_rms: np.ndarray, psf: tf.Tensor, fwhm: np.ndarray):
    batch_size = galaxies.shape[0]
    source_pixels = galaxies.shape[1]
    kappa_pixels = kappa.shape[1]
    pixels = lensed_images.shape[1]
    psf_pixels = psf.shape[1]
    records = []
    for j in range(batch_size):
        features = {
            "kappa": _bytes_feature(kappa[j].numpy().tobytes()),
            "source": _bytes_feature(galaxies[j].numpy().tobytes()),
            "lens": _bytes_feature(lensed_images[j].numpy().tobytes()),
            "z source": _float_feature(z_source),
            "z lens": _float_feature(z_lens),
            "image fov": _float_feature(image_fov),  # arc seconds
            "kappa fov": _float_feature(kappa_fov),  # arc seconds
            "source fov": _float_feature(source_fov),  # arc seconds
            "src pixels": _int64_feature(source_pixels),
            "kappa pixels": _int64_feature(kappa_pixels),
            "pixels": _int64_feature(pixels),
            "noise rms": _float_feature(noise_rms[j]),
            "psf": _bytes_feature(psf[j].numpy().tobytes()),
            "psf pixels": _int64_feature(psf_pixels),
            "fwhm": _float_feature(fwhm[j])
        }
        serialized_output = tf.train.Example(features=tf.train.Features(
            feature=features))
        record = serialized_output.SerializeToString()
        records.append(record)
    return records
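These encode_examples variants all call the usual TFRecord feature helpers, which are defined elsewhere and not shown. A minimal sketch of what they are assumed to look like, matching the calls above:

import tensorflow as tf

def _bytes_feature(value):
    # Wrap raw bytes (e.g. tensor.numpy().tobytes()) in a tf.train.Feature.
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _float_feature(value):
    # Wrap a single scalar float in a tf.train.Feature.
    return tf.train.Feature(float_list=tf.train.FloatList(value=[float(value)]))

def _int64_feature(value):
    # Wrap a single scalar integer in a tf.train.Feature.
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[int(value)]))

The serialized records returned by encode_examples would then typically be written out with tf.io.TFRecordWriter (writer.write(record) for each record); that driver loop is assumed here rather than shown.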
Example #2
def encode_examples(kappa: tf.Tensor, einstein_radius_init: list,
                    einstein_radius: list, rescalings: list, z_source: float,
                    z_lens: float, kappa_fov: float, sigma_crit: float,
                    kappa_ids: list):
    batch_size = kappa.shape[0]
    kappa_pixels = kappa.shape[1]
    records = []
    for j in range(batch_size):
        features = {
            "kappa": _bytes_feature(kappa[j].numpy().tobytes()),
            "kappa pixels": _int64_feature(kappa_pixels),
            "Einstein radius before rescaling": _float_feature(einstein_radius_init[j]),
            "Einstein radius": _float_feature(einstein_radius[j]),
            "rescaling factor": _float_feature(rescalings[j]),
            "z source": _float_feature(z_source),
            "z lens": _float_feature(z_lens),
            "kappa fov": _float_feature(kappa_fov),  # arc seconds
            "sigma crit": _float_feature(sigma_crit),  # 10^10 M_sun / Mpc^2
            "kappa id": _int64_feature(kappa_ids[j])
        }
        serialized_output = tf.train.Example(features=tf.train.Features(
            feature=features))
        record = serialized_output.SerializeToString()
        records.append(record)
    return records
Example #3
def main(args):
    options = tf.io.TFRecordOptions(compression_type=args.compression_type)
    with open(os.path.join(args.first_stage_model_id, "model_hparams.json"),
              "r") as f:
        vae_hparams = json.load(f)
    # load weights
    vae = VAE(**vae_hparams)
    ckpt1 = tf.train.Checkpoint(net=vae)
    checkpoint_manager1 = tf.train.CheckpointManager(ckpt1,
                                                     args.first_stage_model_id,
                                                     1)
    checkpoint_manager1.checkpoint.restore(
        checkpoint_manager1.latest_checkpoint).expect_partial()

    n_batch = args.total_items // args.batch_size
    batch_per_record = n_batch // args.n_records
    last_record_n_batch = batch_per_record + n_batch % args.n_records

    for record in range(args.n_records - 1):
        with tf.io.TFRecordWriter(
                os.path.join(args.output_dir, f"data_{record:02d}.tfrecords"),
                options) as writer:
            for batch in range(batch_per_record):
                z = tf.random.normal(shape=[args.batch_size, vae.latent_size])
                kappa_batch = vae.decode(z)
                for kappa in kappa_batch:
                    features = {
                        "kappa": _bytes_feature(kappa.numpy().tobytes()),
                        "kappa pixels": _int64_feature(kappa.shape[0]),
                    }

                    serialized = tf.train.Example(features=tf.train.Features(
                        feature=features)).SerializeToString()
                    writer.write(serialized)

    with tf.io.TFRecordWriter(
            os.path.join(args.output_dir,
                         f"data_{args.n_record-1:02d}.tfrecords"),
            options) as writer:
        for batch in range(last_record_n_batch):
            z = tf.random.normal(shape=[args.batch_size, vae.latent_size])
            kappa_batch = vae.decode(z)
            for kappa in kappa_batch:
                features = {
                    "kappa": _bytes_feature(kappa.numpy().tobytes()),
                    "kappa pixels": _int64_feature(kappa.shape[0]),
                }

                record = tf.train.Example(features=tf.train.Features(
                    feature=features)).SerializeToString()
                writer.write(record)
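A hedged sketch of how the kappa-only records written by this script could be parsed back. The reader is not part of the snippet; the float32 dtype and the [pixels, pixels, 1] shape are assumptions based on how the maps are serialized above:

import tensorflow as tf

def decode_kappa_example(serialized_example):
    # Parse one record written above; dtype and channel count are assumed.
    features = tf.io.parse_single_example(
        serialized_example,
        features={
            "kappa": tf.io.FixedLenFeature([], tf.string),
            "kappa pixels": tf.io.FixedLenFeature([], tf.int64),
        })
    pixels = tf.cast(features["kappa pixels"], tf.int32)
    kappa = tf.io.decode_raw(features["kappa"], tf.float32)
    return tf.reshape(kappa, [pixels, pixels, 1])

# Assumed usage (file name and compression type are placeholders):
# dataset = tf.data.TFRecordDataset("data_00.tfrecords", compression_type="GZIP")
# dataset = dataset.map(decode_kappa_example)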
Example #4
def encode_examples(kappa: tf.Tensor, alpha: tf.Tensor, rescalings: list,
                    kappa_ids: list, einstein_radius: list, image_fov: float,
                    kappa_fov: float):
    batch_size = kappa.shape[0]
    pixels = kappa.shape[1]
    records = []
    for j in range(batch_size):
        features = {
            "kappa": _bytes_feature(kappa[j].numpy().tobytes()),
            "pixels": _int64_feature(pixels),
            "alpha": _bytes_feature(alpha[j].numpy().tobytes()),
            "rescale": _float_feature(rescalings[j]),
            "kappa_id": _int64_feature(kappa_ids[j]),
            "Einstein radius": _float_feature(einstein_radius[j]),
            "image_fov": _float_feature(image_fov),
            "kappa_fov": _float_feature(kappa_fov)
        }

        serialized_output = tf.train.Example(features=tf.train.Features(
            feature=features))
        record = serialized_output.SerializeToString()
        records.append(record)
    return records
Example #5
def encode_examples(
        kappa: tf.Tensor,
        alpha: tf.Tensor,
        kappa_fov: float
):
    batch_size = kappa.shape[0]
    pixels = kappa.shape[1]
    records = []
    for j in range(batch_size):
        features = {
            "kappa": _bytes_feature(kappa[j].numpy().tobytes()),
            "pixels": _int64_feature(pixels),
            "alpha": _bytes_feature(alpha[j].numpy().tobytes()),
            "kappa_fov": _float_feature(kappa_fov)
        }

        serialized_output = tf.train.Example(features=tf.train.Features(feature=features))
        record = serialized_output.SerializeToString()
        records.append(record)
    return records
Example #6
def distributed_strategy(args):
    cosmos_files = glob.glob(os.path.join(args.cosmos_dir, "*.tfrecords"))
    cosmos = tf.data.TFRecordDataset(cosmos_files)
    cosmos = cosmos.map(decode).map(preprocess).batch(args.batch_size)

    max_shift = min(args.crop, args.max_shift)

    options = tf.io.TFRecordOptions(compression_type=args.compression_type)
    with tf.io.TFRecordWriter(os.path.join(args.output_dir, f"data_{THIS_WORKER:d}.tfrecords"), options) as writer:
        for galaxies in cosmos:
            for j in range(galaxies.shape[0]):
                angle = np.random.randint(low=0, high=3, size=1)[0]
                galaxy = tf.image.rot90(galaxies[j], k=angle).numpy()
                if args.crop > 0:
                    shift = np.random.randint(low=-max_shift, high=max_shift, size=2)
                    galaxy = galaxy[args.crop + shift[0]: -(args.crop - shift[0]), args.crop + shift[1]: -(args.crop - shift[1]), ...]

                features = {
                    "image": _bytes_feature(galaxy.tobytes()),
                    "height": _int64_feature(galaxy.shape[0]),
                }
                record = tf.train.Example(features=tf.train.Features(feature=features)).SerializeToString()
                writer.write(record)
    print(f"Finished work at {datetime.now().strftime('%y-%m-%d_%H-%M-%S')}")
Example #7
def distributed_strategy(args):
    print(
        f"Started worker {THIS_WORKER} at {datetime.now().strftime('%y-%m-%d_%H-%M-%S')}"
    )
    options = tf.io.TFRecordOptions(compression_type=args.compression_type)

    catalog = galsim.COSMOSCatalog(sample=args.sample,
                                   dir=args.cosmos_dir,
                                   exclusion_level=args.exclusion_level,
                                   min_flux=args.min_flux)
    n_galaxies = catalog.getNObjects()
    cat_param = catalog.param_cat[catalog.orig_index]
    sparams = cat_param['sersicfit']
    cat_param = append_fields(cat_param, 'sersic_q', sparams[:, 3])
    cat_param = append_fields(cat_param, 'sersic_n', sparams[:, 2])

    with tf.io.TFRecordWriter(
            os.path.join(args.output_dir, f"data_{THIS_WORKER}.tfrecords"),
            options) as writer:
        for index in range((THIS_WORKER - 1), n_galaxies, N_WORKERS):
            gal = catalog.makeGalaxy(index,
                                     noise_pad_size=args.pixels *
                                     args.pixel_scale)
            psf = gal.original_psf

            # We save the corresponding attributes for this galaxy
            if hasattr(args, 'attributes'):
                params = cat_param[index]
                attributes = {k: params[k] for k in args.attributes}
            else:
                attributes = None

            # Apply the PSF
            gal = galsim.Convolve(gal, psf)

            # Compute sqrt of absolute noise power spectrum, at the resolution and stamp size of target image
            ps = gal.noise._get_update_rootps(
                (args.pixels, args.pixels),
                wcs=galsim.PixelScale(args.pixel_scale))

            # We draw the pixel image of the convolved image
            im = gal.drawImage(nx=args.pixels,
                               ny=args.pixels,
                               scale=args.pixel_scale,
                               method='no_pixel',
                               use_true_center=False).array.astype('float32')

            # preprocess image
            # For correlated noise, we estimate that the sqrt of the Energy Spectral Density of the noise at (f_x=f_y=0)
            # is a good estimate of the STD
            if args.preprocess:
                im = tf.nn.relu(im - ps[0, 0]).numpy()  # subtract background, fold negative pixels to 0
                im /= im.max()  # normalize peak to 1
                signal_pixels = np.sum(
                    im > args.signal_threshold
                )  # how many pixels have a value above a certain threshold
                if signal_pixels < args.signal_pixels:  # keep only examples with more distinct galaxy features (this does, however, bias the dataset in redshift space)
                    continue

            # Draw a kimage of the galaxy, just to figure out what maxk is, there might
            # be more efficient ways to do this though...
            bounds = galsim.BoundsI(0, args.pixels // 2, -args.pixels // 2,
                                    args.pixels // 2 - 1)
            imG = gal.drawKImage(bounds=bounds,
                                 scale=2. * np.pi /
                                 (args.pixels * args.pixels),
                                 recenter=False)
            mask = ~(np.fft.fftshift(imG.array, axes=0) == 0)

            # Draw the Fourier domain image of the galaxy, using x1 zero padding,
            # and x2 subsampling
            interp_factor = 2
            padding_factor = 1
            Nk = args.pixels * interp_factor * padding_factor
            bounds = galsim.BoundsI(0, Nk // 2, -Nk // 2, Nk // 2 - 1)
            imCp = psf.drawKImage(bounds=bounds,
                                  scale=2. * np.pi /
                                  (Nk * args.pixel_scale / interp_factor),
                                  recenter=False)

            # Transform the psf array into proper format, remove the phase
            im_psf = np.abs(np.fft.fftshift(imCp.array,
                                            axes=0)).astype('float32')

            # The following comes from correlatednoise.py
            rt2 = np.sqrt(2.)
            shape = (args.pixels, args.pixels)
            ps[0, 0] = rt2 * ps[0, 0]
            # Then make the changes necessary for even sized arrays
            if shape[1] % 2 == 0:  # x dimension even
                ps[0, shape[1] // 2] = rt2 * ps[0, shape[1] // 2]
            if shape[0] % 2 == 0:  # y dimension even
                ps[shape[0] // 2, 0] = rt2 * ps[shape[0] // 2, 0]
                # Both dimensions even
                if shape[1] % 2 == 0:
                    ps[shape[0] // 2,
                       shape[1] // 2] = rt2 * ps[shape[0] // 2, shape[1] // 2]

            # Apply mask to power spectrum so that it is very large outside maxk
            ps = np.where(mask, np.log(ps**2), 10).astype('float32')
            features = {
                "image": _bytes_feature(im.tobytes()),
                "height": _int64_feature(im.shape[0]),
                "width": _int64_feature(im.shape[1]),
                "psf": _bytes_feature(im_psf.tobytes()),
                "ps": _bytes_feature(ps.tobytes()),  # power spectrum of the noise
            }

            # Adding the parameters provided
            if attributes is not None:
                for k in attributes:
                    features['attrs_' + k] = _float_feature(attributes[k])

            record = tf.train.Example(features=tf.train.Features(
                feature=features)).SerializeToString()
            writer.write(record)
    print(
        f"Finished worker {THIS_WORKER} at {datetime.now().strftime('%y-%m-%d_%H-%M-%S')}"
    )
Example #8
def distributed_strategy(args):
    files = glob.glob(os.path.join(args.dataset, "*.tfrecords"))
    # Read concurrently from multiple records
    files = tf.data.Dataset.from_tensor_slices(files).shuffle(len(files), reshuffle_each_iteration=False)
    dataset = files.interleave(lambda x: tf.data.TFRecordDataset(x, compression_type=args.compression_type),
                               block_length=1, num_parallel_calls=tf.data.AUTOTUNE)
    total_items = int(np.sum(np.loadtxt(os.path.join(args.dataset, "shard_size.txt")), axis=0)[1])
    train_items = math.floor(args.train_split * total_items)

    dataset = dataset.shuffle(args.buffer_size, reshuffle_each_iteration=False).map(decode_all)
    train_dataset = dataset.take(train_items)
    val_dataset = dataset.skip(train_items)

    if THIS_WORKER > 1:
        time.sleep(3)
    train_dir = args.dataset + "_train"
    if not os.path.isdir(train_dir):
        os.mkdir(train_dir)
    val_dir = args.dataset + "_val"
    if not os.path.isdir(val_dir):
        os.mkdir(val_dir)
    if THIS_WORKER <= 1:
        with open(os.path.join(train_dir, "dataset_size.txt"), "w") as f:
            f.write(f"{train_items:d}")
        with open(os.path.join(val_dir, "dataset_size.txt"), "w") as f:
            f.write(f"{total_items-train_items:d}")
    options = tf.io.TFRecordOptions(compression_type=args.compression_type)
    train_shards = train_items // args.examples_per_shard + 1 * (train_items % args.examples_per_shard > 0)
    val_shards = (total_items - train_items) // args.examples_per_shard + 1 * ((total_items - train_items) % args.examples_per_shard > 0)

    for shard in range((THIS_WORKER - 1), train_shards, N_WORKERS):
        data = train_dataset.skip(shard * args.examples_per_shard).take(args.examples_per_shard)
        with tf.io.TFRecordWriter(os.path.join(train_dir, f"data_{shard:02d}.tfrecords"), options=options) as writer:
            for example in data:
                features = {
                    "kappa": _bytes_feature(example["kappa"].numpy().tobytes()),
                    "source": _bytes_feature(example["source"].numpy().tobytes()),
                    "lens": _bytes_feature(example["lens"].numpy().tobytes()),
                    "z source": _float_feature(example["z source"].numpy()),
                    "z lens": _float_feature(example["z lens"].numpy()),
                    "image fov": _float_feature(example["image fov"].numpy()),  # arc seconds
                    "kappa fov": _float_feature(example["kappa fov"].numpy()),  # arc seconds
                    "source fov": _float_feature(example["source fov"].numpy()),  # arc seconds
                    "src pixels": _int64_feature(example["source"].shape[0]),
                    "kappa pixels": _int64_feature(example["kappa"].shape[0]),
                    "pixels": _int64_feature(example["lens"].shape[0]),
                    "noise rms": _float_feature(example["noise rms"].numpy()),
                    "psf": _bytes_feature(example["psf"].numpy().tobytes()),
                    "psf pixels": _int64_feature(example["psf"].shape[0]),
                    "fwhm": _float_feature(example["fwhm"].numpy())
                }
                serialized_output = tf.train.Example(features=tf.train.Features(feature=features))
                record = serialized_output.SerializeToString()
                writer.write(record)
    for shard in range((THIS_WORKER - 1), val_shards, N_WORKERS):
        data = val_dataset.skip(shard * args.examples_per_shard).take(args.examples_per_shard)
        with tf.io.TFRecordWriter(os.path.join(val_dir, f"data_{shard:02d}.tfrecords"), options=options) as writer:
            for example in data:
                features = {
                    "kappa": _bytes_feature(example["kappa"].numpy().tobytes()),
                    "source": _bytes_feature(example["source"].numpy().tobytes()),
                    "lens": _bytes_feature(example["lens"].numpy().tobytes()),
                    "z source": _float_feature(example["z source"].numpy()),
                    "z lens": _float_feature(example["z lens"].numpy()),
                    "image fov": _float_feature(example["image fov"].numpy()),  # arc seconds
                    "kappa fov": _float_feature(example["kappa fov"].numpy()),  # arc seconds
                    "source fov": _float_feature(example["source fov"].numpy()),  # arc seconds
                    "src pixels": _int64_feature(example["source"].shape[0]),
                    "kappa pixels": _int64_feature(example["kappa"].shape[0]),
                    "pixels": _int64_feature(example["lens"].shape[0]),
                    "noise rms": _float_feature(example["noise rms"].numpy()),
                    "psf": _bytes_feature(example["psf"].numpy().tobytes()),
                    "psf pixels": _int64_feature(example["psf"].shape[0]),
                    "fwhm": _float_feature(example["fwhm"].numpy())
                }
                serialized_output = tf.train.Example(features=tf.train.Features(feature=features))
                record = serialized_output.SerializeToString()
                writer.write(record)
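The decode_all function this splitter maps over the dataset (also used in Example #9) is not shown. A sketch of a parser matching the feature names serialized above; float32 maps and single-channel square shapes are assumptions:

import tensorflow as tf

def decode_all(serialized_example):
    # Assumed feature spec matching the dictionaries written in these scripts.
    byte_keys = ["kappa", "source", "lens", "psf"]
    float_keys = ["z source", "z lens", "image fov", "kappa fov",
                  "source fov", "noise rms", "fwhm"]
    int_keys = ["src pixels", "kappa pixels", "pixels", "psf pixels"]
    spec = {k: tf.io.FixedLenFeature([], tf.string) for k in byte_keys}
    spec.update({k: tf.io.FixedLenFeature([], tf.float32) for k in float_keys})
    spec.update({k: tf.io.FixedLenFeature([], tf.int64) for k in int_keys})
    example = tf.io.parse_single_example(serialized_example, spec)
    # Turn the raw byte strings back into square float32 maps (shape assumed).
    for key, pix_key in [("kappa", "kappa pixels"), ("source", "src pixels"),
                         ("lens", "pixels"), ("psf", "psf pixels")]:
        size = tf.cast(example[pix_key], tf.int32)
        example[key] = tf.reshape(tf.io.decode_raw(example[key], tf.float32),
                                  [size, size, 1])
    return example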
Example #9
def distributed_strategy(args):
    if THIS_WORKER > 1:
        time.sleep(5)
    output_dir = args.dataset + "_validated"
    if THIS_WORKER == 1 and os.path.exists(
            os.path.join(output_dir, "shard_size.txt")):
        os.remove(os.path.join(output_dir, "shard_size.txt"))
    if not os.path.isdir(output_dir):
        os.mkdir(output_dir)
    files = glob.glob(os.path.join(args.dataset, "*.tfrecords"))
    # Read concurrently from multiple records
    files = tf.data.Dataset.from_tensor_slices(files)
    dataset = files.interleave(lambda x: tf.data.TFRecordDataset(
        x, compression_type=args.compression_type),
                               block_length=args.block_length,
                               num_parallel_calls=tf.data.AUTOTUNE)
    for example in dataset.map(decode_physical_model_info):
        lens_pixels = example["pixels"].numpy()
        break
    options = tf.io.TFRecordOptions(compression_type=args.compression_type)
    kept = 0
    current_dataset = dataset.skip(
        (THIS_WORKER - 1) * args.example_per_worker).take(
            args.example_per_worker)

    # setup mask for edge detection
    x = tf.range(lens_pixels, dtype=DTYPE) - lens_pixels // 2 + 0.5 * (lens_pixels % 2)
    x, y = tf.meshgrid(x, x)
    edge = lens_pixels // 2 - args.edge
    mask = (x > edge) | (x < -edge) | (y > edge) | (y < -edge)
    mask = tf.cast(mask[..., None], DTYPE)
    with tf.io.TFRecordWriter(
            os.path.join(output_dir, f"data_{THIS_WORKER:02d}.tfrecords"),
            options) as writer:
        for example in current_dataset.map(decode_all):
            im_area = tf.reduce_sum(
                tf.cast(example["lens"] > args.signal_threshold,
                        tf.float32)) * (example["image fov"] /
                                        tf.cast(example["pixels"], DTYPE))**2
            src_area = tf.reduce_sum(
                tf.cast(example["source"] > args.signal_threshold,
                        tf.float32)) * (example["source fov"] / tf.cast(
                            example["src pixels"], DTYPE))**2
            magnification = im_area / src_area
            if magnification < args.min_magnification:
                continue
            if tf.reduce_max(
                    example["lens"] * mask) > args.edge_signal_tolerance:
                continue
            if tf.reduce_sum(
                    tf.cast(example["source"] > args.source_signal_threshold,
                            tf.float32)) < args.min_source_signal_pixels:
                continue
            if tf.reduce_max(
                    example["kappa"]
            ) < 1.:  # this filters out some of the bad VAE samples.
                continue
            kept += 1
            features = {
                "kappa": _bytes_feature(example["kappa"].numpy().tobytes()),
                "source": _bytes_feature(example["source"].numpy().tobytes()),
                "lens": _bytes_feature(example["lens"].numpy().tobytes()),
                "z source": _float_feature(example["z source"].numpy()),
                "z lens": _float_feature(example["z lens"].numpy()),
                "image fov":
                _float_feature(example["image fov"].numpy()),  # arc seconds
                "kappa fov":
                _float_feature(example["kappa fov"].numpy()),  # arc seconds
                "source fov":
                _float_feature(example["source fov"].numpy()),  # arc seconds
                "src pixels": _int64_feature(example["source"].shape[0]),
                "kappa pixels": _int64_feature(example["kappa"].shape[0]),
                "pixels": _int64_feature(example["lens"].shape[0]),
                "noise rms": _float_feature(example["noise rms"].numpy()),
                "psf": _bytes_feature(example["psf"].numpy().tobytes()),
                "psf pixels": _int64_feature(example["psf"].shape[0]),
                "fwhm": _float_feature(example["fwhm"].numpy())
            }
            serialized_output = tf.train.Example(features=tf.train.Features(
                feature=features))
            record = serialized_output.SerializeToString()
            writer.write(record)
    print(
        f"Finished worker {THIS_WORKER} at {datetime.now().strftime('%y-%m-%d_%H-%M-%S')}, kept {kept:d} examples"
    )

    with open(os.path.join(output_dir, "shard_size.txt"), "a") as f:
        f.write(f"{THIS_WORKER} {kept:d}\n")
Example #10
def draw_and_encode_stamp(gal, psf, stamp_size, pixel_scale, attributes=None):
    """
    Draws the galaxy, psf and noise power spectrum on a postage stamp and
    encodes it to be exported in a TFRecord.

    Taken from galaxy2galaxy by François Lanusse https://github.com/ml4astro/galaxy2galaxy
    Modified by Alexandre Adam May 29, 2021
    """

    # Apply the PSF
    gal = galsim.Convolve(gal, psf)

    # Draw a kimage of the galaxy, just to figure out what maxk is, there might
    # be more efficient ways to do this though...
    bounds = galsim.BoundsI(0, stamp_size//2, -stamp_size//2, stamp_size//2-1)
    imG = gal.drawKImage(bounds=bounds,
                         scale=2.*np.pi/(stamp_size * pixel_scale),
                         recenter=False)
    mask = ~(np.fft.fftshift(imG.array, axes=0) == 0)

    # We draw the pixel image of the convolved image
    im = gal.drawImage(nx=stamp_size, ny=stamp_size, scale=pixel_scale,
                       method='no_pixel', use_true_center=False).array.astype('float32')

    # Draw the Fourier domain image of the galaxy, using x1 zero padding,
    # and x2 subsampling
    interp_factor = 2
    padding_factor = 1
    Nk = stamp_size*interp_factor*padding_factor
    bounds = galsim.BoundsI(0, Nk//2, -Nk//2, Nk//2-1)
    imCp = psf.drawKImage(bounds=bounds,
                          scale=2.*np.pi/(Nk * pixel_scale / interp_factor),
                          recenter=False)

    # Transform the psf array into proper format, remove the phase
    im_psf = np.abs(np.fft.fftshift(imCp.array, axes=0)).astype('float32')

    # Compute noise power spectrum, at the resolution and stamp size of target
    # image
    ps = gal.noise._get_update_rootps((stamp_size, stamp_size), wcs=galsim.PixelScale(pixel_scale))

    # The following comes from correlatednoise.py
    rt2 = np.sqrt(2.)
    shape = (stamp_size, stamp_size)
    ps[0, 0] = rt2 * ps[0, 0]
    # Then make the changes necessary for even sized arrays
    if shape[1] % 2 == 0:  # x dimension even
        ps[0, shape[1] // 2] = rt2 * ps[0, shape[1] // 2]
    if shape[0] % 2 == 0:  # y dimension even
        ps[shape[0] // 2, 0] = rt2 * ps[shape[0] // 2, 0]
        # Both dimensions even
        if shape[1] % 2 == 0:
            ps[shape[0] // 2, shape[1] // 2] = rt2 * \
                ps[shape[0] // 2, shape[1] // 2]

    # Apply mask to power spectrum so that it is very large outside maxk
    ps = np.where(mask, np.log(ps**2), 10).astype('float32')
    features = {
        "image": _bytes_feature(im.tobytes()),
        "height": _int64_feature(im.shape[0]),
        "width": _int64_feature(im.shape[1]),
        "psf": _bytes_feature(im_psf.tobytes()),
        "ps": _bytes_feature(ps.tobytes()),  # power spectrum of the noise
        # "ps_height": _int64_feature(ps.shape[0]),
        # "ps_width": _int64_feature(ps.shape[1])
    }

    # Adding the parameters provided
    if attributes is not None:
        for k in attributes:
            features['attrs_'+k] = _float_feature(attributes[k])

    serialized_output = tf.train.Example(features=tf.train.Features(feature=features))
    return serialized_output.SerializeToString()
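A hedged usage sketch for draw_and_encode_stamp, mirroring the distributed worker loops above; the catalog arguments, output path, stamp size and pixel scale below are placeholders rather than values taken from the snippet:

import galsim
import tensorflow as tf

# Placeholder catalog and writer settings; only the call pattern matters here.
catalog = galsim.COSMOSCatalog(sample="25.2", dir="/path/to/COSMOS")
options = tf.io.TFRecordOptions(compression_type="GZIP")
with tf.io.TFRecordWriter("cosmos_stamps.tfrecords", options) as writer:
    for index in range(catalog.getNObjects()):
        gal = catalog.makeGalaxy(index, noise_pad_size=128 * 0.03)
        record = draw_and_encode_stamp(gal, gal.original_psf,
                                       stamp_size=128, pixel_scale=0.03)
        writer.write(record)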
Example #11
def main(args):
    files = glob.glob(os.path.join(args.dataset, "*.tfrecords"))
    # Read concurrently from multiple records
    files = tf.data.Dataset.from_tensor_slices(files).shuffle(
        len(files), reshuffle_each_iteration=False)
    dataset = files.interleave(lambda x: tf.data.TFRecordDataset(
        x, compression_type=args.compression_type),
                               block_length=1,
                               num_parallel_calls=tf.data.AUTOTUNE)
    total_items = 0
    for _ in dataset:
        total_items += 1

    train_items = math.floor(args.train_split * total_items)

    dataset = dataset.shuffle(args.buffer_size, reshuffle_each_iteration=False)
    train_dataset = dataset.take(train_items)
    test_dataset = dataset.skip(train_items)

    train_dir = args.dataset + "_train"
    if not os.path.isdir(train_dir):
        os.mkdir(train_dir)
    test_dir = args.dataset + "_test"
    if not os.path.isdir(test_dir):
        os.mkdir(test_dir)
    options = tf.io.TFRecordOptions(compression_type=args.compression_type)
    train_shards = train_items // args.examples_per_shard + 1 * (
        train_items % args.examples_per_shard > 0)
    for shard in range(train_shards):
        data = train_dataset.skip(shard * args.examples_per_shard).take(
            args.examples_per_shard)
        with tf.io.TFRecordWriter(os.path.join(train_dir,
                                               f"data_{shard:02d}.tfrecords"),
                                  options=options) as writer:
            for image in data.map(decode_image):
                features = {
                    "image": _bytes_feature(image.numpy().tobytes()),
                    "height": _int64_feature(image.shape[0]),
                }
                record = tf.train.Example(features=tf.train.Features(
                    feature=features)).SerializeToString()
                writer.write(record)
    test_shards = (
        total_items - train_items) // args.examples_per_shard + 1 * (
            (total_items - train_items) % args.examples_per_shard > 0)
    for shard in range(test_shards):
        data = test_dataset.skip(shard * args.examples_per_shard).take(
            args.examples_per_shard)
        with tf.io.TFRecordWriter(os.path.join(test_dir,
                                               f"data_{shard:02d}.tfrecords"),
                                  options=options) as writer:
            for image in data.map(decode_image):
                features = {
                    "image": _bytes_feature(image.numpy().tobytes()),
                    "height": _int64_feature(image.shape[0]),
                }
                record = tf.train.Example(features=tf.train.Features(
                    feature=features)).SerializeToString()
                writer.write(record)

    with open(os.path.join(train_dir, "dataset_size.txt"), "w") as f:
        f.write(f"{train_items:d}")

    with open(os.path.join(test_dir, "dataset_size.txt"), "w") as f:
        f.write(f"{total_items-train_items:d}")