def encode_examples(kappa: tf.Tensor,
                    galaxies: tf.Tensor,
                    lensed_images: tf.Tensor,
                    z_source: float,
                    z_lens: float,
                    image_fov: float,
                    kappa_fov: float,
                    source_fov: float,
                    noise_rms: np.ndarray,
                    psf: tf.Tensor,
                    fwhm: np.ndarray):
    batch_size = galaxies.shape[0]
    source_pixels = galaxies.shape[1]
    kappa_pixels = kappa.shape[1]
    pixels = lensed_images.shape[1]
    psf_pixels = psf.shape[1]
    records = []
    for j in range(batch_size):
        features = {
            "kappa": _bytes_feature(kappa[j].numpy().tobytes()),
            "source": _bytes_feature(galaxies[j].numpy().tobytes()),
            "lens": _bytes_feature(lensed_images[j].numpy().tobytes()),
            "z source": _float_feature(z_source),
            "z lens": _float_feature(z_lens),
            "image fov": _float_feature(image_fov),    # arc seconds
            "kappa fov": _float_feature(kappa_fov),    # arc seconds
            "source fov": _float_feature(source_fov),  # arc seconds
            "src pixels": _int64_feature(source_pixels),
            "kappa pixels": _int64_feature(kappa_pixels),
            "pixels": _int64_feature(pixels),
            "noise rms": _float_feature(noise_rms[j]),
            "psf": _bytes_feature(psf[j].numpy().tobytes()),
            "psf pixels": _int64_feature(psf_pixels),
            "fwhm": _float_feature(fwhm[j])
        }
        serialized_output = tf.train.Example(features=tf.train.Features(feature=features))
        record = serialized_output.SerializeToString()
        records.append(record)
    return records
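
# The encoders in this section rely on _bytes_feature, _float_feature and _int64_feature,
# which are not shown in this file. A minimal sketch of what they likely look like,
# following the standard TensorFlow pattern (assumed, not copied from the repository):
import tensorflow as tf


def _bytes_feature(value):
    # Wrap raw bytes (e.g. a numpy buffer produced by .tobytes()) in a BytesList feature.
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def _float_feature(value):
    # Wrap a python float / numpy scalar in a FloatList feature.
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))


def _int64_feature(value):
    # Wrap a python int / numpy integer in an Int64List feature.
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))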
def main(args):
    options = tf.io.TFRecordOptions(compression_type=args.compression_type)
    with open(os.path.join(args.first_stage_model_id, "model_hparams.json"), "r") as f:
        vae_hparams = json.load(f)

    # load weights
    vae = VAE(**vae_hparams)
    ckpt1 = tf.train.Checkpoint(net=vae)
    checkpoint_manager1 = tf.train.CheckpointManager(ckpt1, args.first_stage_model_id, 1)
    checkpoint_manager1.checkpoint.restore(checkpoint_manager1.latest_checkpoint).expect_partial()

    # spread the batches evenly across the records; the last record absorbs the remainder
    n_batch = args.total_items // args.batch_size
    batch_per_record = n_batch // args.n_records
    last_record_n_batch = batch_per_record + n_batch % args.n_records

    for record in range(args.n_records - 1):
        with tf.io.TFRecordWriter(os.path.join(args.output_dir, f"data_{record:02d}.tfrecords"), options) as writer:
            for batch in range(batch_per_record):
                z = tf.random.normal(shape=[args.batch_size, vae.latent_size])
                kappa_batch = vae.decode(z)
                for kappa in kappa_batch:
                    features = {
                        "kappa": _bytes_feature(kappa.numpy().tobytes()),
                        "kappa pixels": _int64_feature(kappa.shape[0]),
                    }
                    serialized = tf.train.Example(features=tf.train.Features(feature=features)).SerializeToString()
                    writer.write(serialized)

    with tf.io.TFRecordWriter(os.path.join(args.output_dir, f"data_{args.n_records - 1:02d}.tfrecords"), options) as writer:
        for batch in range(last_record_n_batch):
            z = tf.random.normal(shape=[args.batch_size, vae.latent_size])
            kappa_batch = vae.decode(z)
            for kappa in kappa_batch:
                features = {
                    "kappa": _bytes_feature(kappa.numpy().tobytes()),
                    "kappa pixels": _int64_feature(kappa.shape[0]),
                }
                serialized = tf.train.Example(features=tf.train.Features(feature=features)).SerializeToString()
                writer.write(serialized)
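
# Hypothetical reader for the kappa-only records written by main() above. The feature spec
# mirrors the keys used at write time, but the function name, the float32 dtype and the
# single-channel reshape are assumptions, not part of the original pipeline.
import tensorflow as tf


def decode_kappa(record_bytes):
    keys = {
        "kappa": tf.io.FixedLenFeature([], tf.string),
        "kappa pixels": tf.io.FixedLenFeature([], tf.int64),
    }
    example = tf.io.parse_single_example(record_bytes, keys)
    pixels = tf.cast(example["kappa pixels"], tf.int32)
    # The maps were serialized with .tobytes(); float32 storage is assumed here.
    kappa = tf.io.decode_raw(example["kappa"], tf.float32)
    return tf.reshape(kappa, [pixels, pixels, 1])


# Usage sketch (compression type must match the one used at write time):
# dataset = tf.data.TFRecordDataset(filenames, compression_type=args.compression_type).map(decode_kappa)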
def encode_examples(kappa: tf.Tensor,
                    einstein_radius_init: list,
                    einstein_radius: list,
                    rescalings: list,
                    z_source: float,
                    z_lens: float,
                    kappa_fov: float,
                    sigma_crit: float,
                    kappa_ids: list):
    batch_size = kappa.shape[0]
    kappa_pixels = kappa.shape[1]
    records = []
    for j in range(batch_size):
        features = {
            "kappa": _bytes_feature(kappa[j].numpy().tobytes()),
            "kappa pixels": _int64_feature(kappa_pixels),
            "Einstein radius before rescaling": _float_feature(einstein_radius_init[j]),
            "Einstein radius": _float_feature(einstein_radius[j]),
            "rescaling factor": _float_feature(rescalings[j]),
            "z source": _float_feature(z_source),
            "z lens": _float_feature(z_lens),
            "kappa fov": _float_feature(kappa_fov),    # arc seconds
            "sigma crit": _float_feature(sigma_crit),  # 10^10 M_sun / Mpc^2
            "kappa id": _int64_feature(kappa_ids[j])
        }
        serialized_output = tf.train.Example(features=tf.train.Features(feature=features))
        record = serialized_output.SerializeToString()
        records.append(record)
    return records
def encode_examples(kappa: tf.Tensor, alpha: tf.Tensor, kappa_fov: float):
    batch_size = kappa.shape[0]
    pixels = kappa.shape[1]
    records = []
    for j in range(batch_size):
        features = {
            "kappa": _bytes_feature(kappa[j].numpy().tobytes()),
            "pixels": _int64_feature(pixels),
            "alpha": _bytes_feature(alpha[j].numpy().tobytes()),
            "kappa_fov": _float_feature(kappa_fov)
        }
        serialized_output = tf.train.Example(features=tf.train.Features(feature=features))
        record = serialized_output.SerializeToString()
        records.append(record)
    return records
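
# Sketch of how the record list returned by the (kappa, alpha) encoder above would typically
# be written to disk. The tensors, field-of-view value, file name and compression type are
# placeholders used only to exercise the encoder.
import tensorflow as tf

# Dummy batch of 2 convergence maps and deflection fields.
kappa = tf.random.normal([2, 64, 64, 1])
alpha = tf.random.normal([2, 64, 64, 2])

options = tf.io.TFRecordOptions(compression_type="GZIP")  # placeholder compression choice
with tf.io.TFRecordWriter("alpha_labels_00.tfrecords", options) as writer:
    for record in encode_examples(kappa=kappa, alpha=alpha, kappa_fov=17.4):
        writer.write(record)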
def encode_examples(kappa: tf.Tensor,
                    alpha: tf.Tensor,
                    rescalings: list,
                    kappa_ids: list,
                    einstein_radius: list,
                    image_fov: float,
                    kappa_fov: float):
    batch_size = kappa.shape[0]
    pixels = kappa.shape[1]
    records = []
    for j in range(batch_size):
        features = {
            "kappa": _bytes_feature(kappa[j].numpy().tobytes()),
            "pixels": _int64_feature(pixels),
            "alpha": _bytes_feature(alpha[j].numpy().tobytes()),
            "rescale": _float_feature(rescalings[j]),
            "kappa_id": _int64_feature(kappa_ids[j]),
            "Einstein radius": _float_feature(einstein_radius[j]),
            "image_fov": _float_feature(image_fov),
            "kappa_fov": _float_feature(kappa_fov)
        }
        serialized_output = tf.train.Example(features=tf.train.Features(feature=features))
        record = serialized_output.SerializeToString()
        records.append(record)
    return records
def distributed_strategy(args):
    cosmos_files = glob.glob(os.path.join(args.cosmos_dir, "*.tfrecords"))
    cosmos = tf.data.TFRecordDataset(cosmos_files)
    cosmos = cosmos.map(decode).map(preprocess).batch(args.batch_size)
    max_shift = min(args.crop, args.max_shift)
    options = tf.io.TFRecordOptions(compression_type=args.compression_type)
    with tf.io.TFRecordWriter(os.path.join(args.output_dir, f"data_{THIS_WORKER:d}.tfrecords"), options) as writer:
        for galaxies in cosmos:
            for j in range(galaxies.shape[0]):
                # random rotation by a multiple of 90 degrees (k drawn from {0, 1, 2})
                angle = np.random.randint(low=0, high=3, size=1)[0]
                galaxy = tf.image.rot90(galaxies[j], k=angle).numpy()
                if args.crop > 0:
                    # random shift, then crop down to the target stamp size
                    shift = np.random.randint(low=-max_shift, high=max_shift, size=2)
                    galaxy = galaxy[args.crop + shift[0]: -(args.crop - shift[0]),
                                    args.crop + shift[1]: -(args.crop - shift[1]), ...]
                features = {
                    "image": _bytes_feature(galaxy.tobytes()),
                    "height": _int64_feature(galaxy.shape[0]),
                }
                record = tf.train.Example(features=tf.train.Features(feature=features)).SerializeToString()
                writer.write(record)
    print(f"Finished work at {datetime.now().strftime('%y-%m-%d_%H-%M-%S')}")
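
# Plausible versions of the `decode` and `preprocess` maps used in the COSMOS pipeline above;
# they are not shown in this file, so the feature spec, the float32 dtype and the normalization
# choice are assumptions based on the "image"/"height" fields written elsewhere in this section.
import tensorflow as tf


def decode(record_bytes):
    keys = {
        "image": tf.io.FixedLenFeature([], tf.string),
        "height": tf.io.FixedLenFeature([], tf.int64),
    }
    example = tf.io.parse_single_example(record_bytes, keys)
    size = tf.cast(example["height"], tf.int32)
    image = tf.io.decode_raw(example["image"], tf.float32)  # float32 storage assumed
    return tf.reshape(image, [size, size, 1])


def preprocess(image):
    # Clip negative pixels and normalize the peak to 1 (assumed convention).
    image = tf.nn.relu(image)
    return image / tf.reduce_max(image)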
def distributed_strategy(args):
    print(f"Started worker {THIS_WORKER} at {datetime.now().strftime('%y-%m-%d_%H-%M-%S')}")
    options = tf.io.TFRecordOptions(compression_type=args.compression_type)
    catalog = galsim.COSMOSCatalog(sample=args.sample, dir=args.cosmos_dir,
                                   exclusion_level=args.exclusion_level, min_flux=args.min_flux)
    n_galaxies = catalog.getNObjects()
    cat_param = catalog.param_cat[catalog.orig_index]
    sparams = cat_param['sersicfit']
    cat_param = append_fields(cat_param, 'sersic_q', sparams[:, 3])
    cat_param = append_fields(cat_param, 'sersic_n', sparams[:, 2])
    with tf.io.TFRecordWriter(os.path.join(args.output_dir, f"data_{THIS_WORKER}.tfrecords"), options) as writer:
        for index in range((THIS_WORKER - 1), n_galaxies, N_WORKERS):
            gal = catalog.makeGalaxy(index, noise_pad_size=args.pixels * args.pixel_scale)
            psf = gal.original_psf

            # We save the corresponding attributes for this galaxy
            if hasattr(args, 'attributes'):
                params = cat_param[index]
                attributes = {k: params[k] for k in args.attributes}
            else:
                attributes = None

            # Apply the PSF
            gal = galsim.Convolve(gal, psf)

            # Compute sqrt of absolute noise power spectrum, at the resolution and stamp size of the target image
            ps = gal.noise._get_update_rootps((args.pixels, args.pixels),
                                              wcs=galsim.PixelScale(args.pixel_scale))

            # We draw the pixel image of the convolved image
            im = gal.drawImage(nx=args.pixels, ny=args.pixels, scale=args.pixel_scale,
                               method='no_pixel', use_true_center=False).array.astype('float32')

            # Preprocess image. For correlated noise, we estimate that the sqrt of the Energy
            # Spectral Density of the noise at (f_x=f_y=0) is a good estimate of the STD.
            if args.preprocess:
                im = tf.nn.relu(im - ps[0, 0]).numpy()  # subtract background, fold negative pixels to 0
                im /= im.max()  # normalize peak to 1
                signal_pixels = np.sum(im > args.signal_threshold)  # how many pixels have a value above the threshold
                if signal_pixels < args.signal_pixels:
                    # argument used to select only examples with more distinct galaxy features
                    # (it does however bias the dataset in redshift space)
                    continue

            # Draw a kimage of the galaxy, just to figure out what maxk is; there might
            # be more efficient ways to do this though...
            bounds = galsim.BoundsI(0, args.pixels // 2, -args.pixels // 2, args.pixels // 2 - 1)
            imG = gal.drawKImage(bounds=bounds,
                                 scale=2. * np.pi / (args.pixels * args.pixel_scale),
                                 recenter=False)
            mask = ~(np.fft.fftshift(imG.array, axes=0) == 0)

            # Draw the Fourier domain image of the galaxy, using x1 zero padding,
            # and x2 subsampling
            interp_factor = 2
            padding_factor = 1
            Nk = args.pixels * interp_factor * padding_factor
            bounds = galsim.BoundsI(0, Nk // 2, -Nk // 2, Nk // 2 - 1)
            imCp = psf.drawKImage(bounds=bounds,
                                  scale=2. * np.pi / (Nk * args.pixel_scale / interp_factor),
                                  recenter=False)

            # Transform the psf array into proper format, remove the phase
            im_psf = np.abs(np.fft.fftshift(imCp.array, axes=0)).astype('float32')

            # The following comes from correlatednoise.py
            rt2 = np.sqrt(2.)
            shape = (args.pixels, args.pixels)
            ps[0, 0] = rt2 * ps[0, 0]
            # Then make the changes necessary for even sized arrays
            if shape[1] % 2 == 0:  # x dimension even
                ps[0, shape[1] // 2] = rt2 * ps[0, shape[1] // 2]
            if shape[0] % 2 == 0:  # y dimension even
                ps[shape[0] // 2, 0] = rt2 * ps[shape[0] // 2, 0]
                # Both dimensions even
                if shape[1] % 2 == 0:
                    ps[shape[0] // 2, shape[1] // 2] = rt2 * ps[shape[0] // 2, shape[1] // 2]

            # Apply mask to power spectrum so that it is very large outside maxk
            ps = np.where(mask, np.log(ps**2), 10).astype('float32')

            features = {
                "image": _bytes_feature(im.tobytes()),
                "height": _int64_feature(im.shape[0]),
                "width": _int64_feature(im.shape[1]),
                "psf": _bytes_feature(im_psf.tobytes()),
                "ps": _bytes_feature(ps.tobytes()),  # power spectrum of the noise
            }

            # Adding the parameters provided
            if attributes is not None:
                for k in attributes:
                    features['attrs_' + k] = _float_feature(attributes[k])

            record = tf.train.Example(features=tf.train.Features(feature=features)).SerializeToString()
            writer.write(record)
    print(f"Finished worker {THIS_WORKER} at {datetime.now().strftime('%y-%m-%d_%H-%M-%S')}")
def distributed_strategy(args):
    files = glob.glob(os.path.join(args.dataset, "*.tfrecords"))
    # Read concurrently from multiple records
    files = tf.data.Dataset.from_tensor_slices(files).shuffle(len(files), reshuffle_each_iteration=False)
    dataset = files.interleave(
        lambda x: tf.data.TFRecordDataset(x, compression_type=args.compression_type),
        block_length=1, num_parallel_calls=tf.data.AUTOTUNE)

    total_items = int(np.sum(np.loadtxt(os.path.join(args.dataset, "shard_size.txt")), axis=0)[1])
    train_items = math.floor(args.train_split * total_items)

    dataset = dataset.shuffle(args.buffer_size, reshuffle_each_iteration=False).map(decode_all)
    train_dataset = dataset.take(train_items)
    val_dataset = dataset.skip(train_items)

    if THIS_WORKER > 1:
        time.sleep(3)
    train_dir = args.dataset + "_train"
    if not os.path.isdir(train_dir):
        os.mkdir(train_dir)
    val_dir = args.dataset + "_val"
    if not os.path.isdir(val_dir):
        os.mkdir(val_dir)
    if THIS_WORKER <= 1:
        with open(os.path.join(train_dir, "dataset_size.txt"), "w") as f:
            f.write(f"{train_items:d}")
        with open(os.path.join(val_dir, "dataset_size.txt"), "w") as f:
            f.write(f"{total_items - train_items:d}")

    options = tf.io.TFRecordOptions(compression_type=args.compression_type)
    train_shards = train_items // args.examples_per_shard + 1 * (train_items % args.examples_per_shard > 0)
    val_shards = (total_items - train_items) // args.examples_per_shard \
        + 1 * ((total_items - train_items) % args.examples_per_shard > 0)

    for shard in range((THIS_WORKER - 1), train_shards, N_WORKERS):
        data = train_dataset.skip(shard * args.examples_per_shard).take(args.examples_per_shard)
        with tf.io.TFRecordWriter(os.path.join(train_dir, f"data_{shard:02d}.tfrecords"), options=options) as writer:
            for example in data:
                features = {
                    "kappa": _bytes_feature(example["kappa"].numpy().tobytes()),
                    "source": _bytes_feature(example["source"].numpy().tobytes()),
                    "lens": _bytes_feature(example["lens"].numpy().tobytes()),
                    "z source": _float_feature(example["z source"].numpy()),
                    "z lens": _float_feature(example["z lens"].numpy()),
                    "image fov": _float_feature(example["image fov"].numpy()),    # arc seconds
                    "kappa fov": _float_feature(example["kappa fov"].numpy()),    # arc seconds
                    "source fov": _float_feature(example["source fov"].numpy()),  # arc seconds
                    "src pixels": _int64_feature(example["source"].shape[0]),
                    "kappa pixels": _int64_feature(example["kappa"].shape[0]),
                    "pixels": _int64_feature(example["lens"].shape[0]),
                    "noise rms": _float_feature(example["noise rms"].numpy()),
                    "psf": _bytes_feature(example["psf"].numpy().tobytes()),
                    "psf pixels": _int64_feature(example["psf"].shape[0]),
                    "fwhm": _float_feature(example["fwhm"].numpy())
                }
                serialized_output = tf.train.Example(features=tf.train.Features(feature=features))
                record = serialized_output.SerializeToString()
                writer.write(record)

    for shard in range((THIS_WORKER - 1), val_shards, N_WORKERS):
        data = val_dataset.skip(shard * args.examples_per_shard).take(args.examples_per_shard)
        with tf.io.TFRecordWriter(os.path.join(val_dir, f"data_{shard:02d}.tfrecords"), options=options) as writer:
            for example in data:
                features = {
                    "kappa": _bytes_feature(example["kappa"].numpy().tobytes()),
                    "source": _bytes_feature(example["source"].numpy().tobytes()),
                    "lens": _bytes_feature(example["lens"].numpy().tobytes()),
                    "z source": _float_feature(example["z source"].numpy()),
                    "z lens": _float_feature(example["z lens"].numpy()),
                    "image fov": _float_feature(example["image fov"].numpy()),    # arc seconds
                    "kappa fov": _float_feature(example["kappa fov"].numpy()),    # arc seconds
                    "source fov": _float_feature(example["source fov"].numpy()),  # arc seconds
                    "src pixels": _int64_feature(example["source"].shape[0]),
                    "kappa pixels": _int64_feature(example["kappa"].shape[0]),
                    "pixels": _int64_feature(example["lens"].shape[0]),
                    "noise rms": _float_feature(example["noise rms"].numpy()),
                    "psf": _bytes_feature(example["psf"].numpy().tobytes()),
                    "psf pixels": _int64_feature(example["psf"].shape[0]),
                    "fwhm": _float_feature(example["fwhm"].numpy())
                }
                serialized_output = tf.train.Example(features=tf.train.Features(feature=features))
                record = serialized_output.SerializeToString()
                writer.write(record)
def distributed_strategy(args):
    if THIS_WORKER > 1:
        time.sleep(5)
    output_dir = args.dataset + "_validated"
    if THIS_WORKER == 1 and os.path.exists(os.path.join(output_dir, "shard_size.txt")):
        os.remove(os.path.join(output_dir, "shard_size.txt"))
    if not os.path.isdir(output_dir):
        os.mkdir(output_dir)

    files = glob.glob(os.path.join(args.dataset, "*.tfrecords"))
    # Read concurrently from multiple records
    files = tf.data.Dataset.from_tensor_slices(files)
    dataset = files.interleave(
        lambda x: tf.data.TFRecordDataset(x, compression_type=args.compression_type),
        block_length=args.block_length, num_parallel_calls=tf.data.AUTOTUNE)

    for example in dataset.map(decode_physical_model_info):
        lens_pixels = example["pixels"].numpy()
        break

    options = tf.io.TFRecordOptions(compression_type=args.compression_type)
    kept = 0
    current_dataset = dataset.skip((THIS_WORKER - 1) * args.example_per_worker).take(args.example_per_worker)

    # setup mask for edge detection
    x = tf.range(lens_pixels, dtype=DTYPE) - lens_pixels // 2 + 0.5 * (lens_pixels % 2)
    x, y = tf.meshgrid(x, x)
    edge = lens_pixels // 2 - args.edge
    mask = (x > edge) | (x < -edge) | (y > edge) | (y < -edge)
    mask = tf.cast(mask[..., None], DTYPE)

    with tf.io.TFRecordWriter(os.path.join(output_dir, f"data_{THIS_WORKER:02d}.tfrecords"), options) as writer:
        for example in current_dataset.map(decode_all):
            im_area = tf.reduce_sum(tf.cast(example["lens"] > args.signal_threshold, tf.float32)) \
                * (example["image fov"] / tf.cast(example["pixels"], DTYPE))**2
            src_area = tf.reduce_sum(tf.cast(example["source"] > args.signal_threshold, tf.float32)) \
                * (example["source fov"] / tf.cast(example["src pixels"], DTYPE))**2
            magnification = im_area / src_area
            if magnification < args.min_magnification:
                continue
            if tf.reduce_max(example["lens"] * mask) > args.edge_signal_tolerance:
                continue
            if tf.reduce_sum(tf.cast(example["source"] > args.source_signal_threshold, tf.float32)) < args.min_source_signal_pixels:
                continue
            if tf.reduce_max(example["kappa"]) < 1.:  # this filters out some of the bad VAE samples
                continue
            kept += 1
            features = {
                "kappa": _bytes_feature(example["kappa"].numpy().tobytes()),
                "source": _bytes_feature(example["source"].numpy().tobytes()),
                "lens": _bytes_feature(example["lens"].numpy().tobytes()),
                "z source": _float_feature(example["z source"].numpy()),
                "z lens": _float_feature(example["z lens"].numpy()),
                "image fov": _float_feature(example["image fov"].numpy()),    # arc seconds
                "kappa fov": _float_feature(example["kappa fov"].numpy()),    # arc seconds
                "source fov": _float_feature(example["source fov"].numpy()),  # arc seconds
                "src pixels": _int64_feature(example["source"].shape[0]),
                "kappa pixels": _int64_feature(example["kappa"].shape[0]),
                "pixels": _int64_feature(example["lens"].shape[0]),
                "noise rms": _float_feature(example["noise rms"].numpy()),
                "psf": _bytes_feature(example["psf"].numpy().tobytes()),
                "psf pixels": _int64_feature(example["psf"].shape[0]),
                "fwhm": _float_feature(example["fwhm"].numpy())
            }
            serialized_output = tf.train.Example(features=tf.train.Features(feature=features))
            record = serialized_output.SerializeToString()
            writer.write(record)
    print(f"Finished worker {THIS_WORKER} at {datetime.now().strftime('%y-%m-%d_%H-%M-%S')}, kept {kept:d} examples")
    with open(os.path.join(output_dir, "shard_size.txt"), "a") as f:
        f.write(f"{THIS_WORKER} {kept:d}\n")
def draw_and_encode_stamp(gal, psf, stamp_size, pixel_scale, attributes=None):
    """
    Draws the galaxy, psf and noise power spectrum on a postage stamp and
    encodes it to be exported in a TFRecord.

    Taken from galaxy2galaxy by François Lanusse
        https://github.com/ml4astro/galaxy2galaxy
    Modified by Alexandre Adam May 29, 2021
    """
    # Apply the PSF
    gal = galsim.Convolve(gal, psf)

    # Draw a kimage of the galaxy, just to figure out what maxk is, there might
    # be more efficient ways to do this though...
    bounds = galsim.BoundsI(0, stamp_size // 2, -stamp_size // 2, stamp_size // 2 - 1)
    imG = gal.drawKImage(bounds=bounds,
                         scale=2. * np.pi / (stamp_size * pixel_scale),
                         recenter=False)
    mask = ~(np.fft.fftshift(imG.array, axes=0) == 0)

    # We draw the pixel image of the convolved image
    im = gal.drawImage(nx=stamp_size, ny=stamp_size, scale=pixel_scale,
                       method='no_pixel', use_true_center=False).array.astype('float32')

    # Draw the Fourier domain image of the galaxy, using x1 zero padding,
    # and x2 subsampling
    interp_factor = 2
    padding_factor = 1
    Nk = stamp_size * interp_factor * padding_factor
    bounds = galsim.BoundsI(0, Nk // 2, -Nk // 2, Nk // 2 - 1)
    imCp = psf.drawKImage(bounds=bounds,
                          scale=2. * np.pi / (Nk * pixel_scale / interp_factor),
                          recenter=False)

    # Transform the psf array into proper format, remove the phase
    im_psf = np.abs(np.fft.fftshift(imCp.array, axes=0)).astype('float32')

    # Compute noise power spectrum, at the resolution and stamp size of target image
    ps = gal.noise._get_update_rootps((stamp_size, stamp_size),
                                      wcs=galsim.PixelScale(pixel_scale))

    # The following comes from correlatednoise.py
    rt2 = np.sqrt(2.)
    shape = (stamp_size, stamp_size)
    ps[0, 0] = rt2 * ps[0, 0]
    # Then make the changes necessary for even sized arrays
    if shape[1] % 2 == 0:  # x dimension even
        ps[0, shape[1] // 2] = rt2 * ps[0, shape[1] // 2]
    if shape[0] % 2 == 0:  # y dimension even
        ps[shape[0] // 2, 0] = rt2 * ps[shape[0] // 2, 0]
        # Both dimensions even
        if shape[1] % 2 == 0:
            ps[shape[0] // 2, shape[1] // 2] = rt2 * ps[shape[0] // 2, shape[1] // 2]

    # Apply mask to power spectrum so that it is very large outside maxk
    ps = np.where(mask, np.log(ps**2), 10).astype('float32')

    features = {
        "image": _bytes_feature(im.tobytes()),
        "height": _int64_feature(im.shape[0]),
        "width": _int64_feature(im.shape[1]),
        "psf": _bytes_feature(im_psf.tobytes()),
        "ps": _bytes_feature(ps.tobytes()),  # power spectrum of the noise
        # "ps_height": _int64_feature(ps.shape[0]),
        # "ps_width": _int64_feature(ps.shape[1])
    }

    # Adding the parameters provided
    if attributes is not None:
        for k in attributes:
            features['attrs_' + k] = _float_feature(attributes[k])

    serialized_output = tf.train.Example(features=tf.train.Features(feature=features))
    return serialized_output.SerializeToString()
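
# Sketch of how draw_and_encode_stamp might be driven from a COSMOS catalog, mirroring the
# worker loop earlier in this section. The catalog path, sample, stamp size, pixel scale and
# output file name are illustrative placeholders, not values from the original scripts.
import os
import galsim
import tensorflow as tf

catalog = galsim.COSMOSCatalog(sample="25.2", dir="/path/to/COSMOS")  # placeholders
pixels, pixel_scale = 128, 0.03

with tf.io.TFRecordWriter(os.path.join("output_dir", "data_00.tfrecords")) as writer:
    for index in range(catalog.getNObjects()):
        gal = catalog.makeGalaxy(index, noise_pad_size=pixels * pixel_scale)
        record = draw_and_encode_stamp(gal, gal.original_psf, stamp_size=pixels, pixel_scale=pixel_scale)
        writer.write(record)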
def main(args):
    files = glob.glob(os.path.join(args.dataset, "*.tfrecords"))
    # Read concurrently from multiple records
    files = tf.data.Dataset.from_tensor_slices(files).shuffle(len(files), reshuffle_each_iteration=False)
    dataset = files.interleave(
        lambda x: tf.data.TFRecordDataset(x, compression_type=args.compression_type),
        block_length=1, num_parallel_calls=tf.data.AUTOTUNE)

    total_items = 0
    for _ in dataset:
        total_items += 1
    train_items = math.floor(args.train_split * total_items)

    dataset = dataset.shuffle(args.buffer_size, reshuffle_each_iteration=False)
    train_dataset = dataset.take(train_items)
    test_dataset = dataset.skip(train_items)

    train_dir = args.dataset + "_train"
    if not os.path.isdir(train_dir):
        os.mkdir(train_dir)
    test_dir = args.dataset + "_test"
    if not os.path.isdir(test_dir):
        os.mkdir(test_dir)

    options = tf.io.TFRecordOptions(compression_type=args.compression_type)
    train_shards = train_items // args.examples_per_shard + 1 * (train_items % args.examples_per_shard > 0)
    for shard in range(train_shards):
        data = train_dataset.skip(shard * args.examples_per_shard).take(args.examples_per_shard)
        with tf.io.TFRecordWriter(os.path.join(train_dir, f"data_{shard:02d}.tfrecords"), options=options) as writer:
            for image in data.map(decode_image):
                features = {
                    "image": _bytes_feature(image.numpy().tobytes()),
                    "height": _int64_feature(image.shape[0]),
                }
                record = tf.train.Example(features=tf.train.Features(feature=features)).SerializeToString()
                writer.write(record)

    test_shards = (total_items - train_items) // args.examples_per_shard \
        + 1 * ((total_items - train_items) % args.examples_per_shard > 0)
    for shard in range(test_shards):
        data = test_dataset.skip(shard * args.examples_per_shard).take(args.examples_per_shard)
        with tf.io.TFRecordWriter(os.path.join(test_dir, f"data_{shard:02d}.tfrecords"), options=options) as writer:
            for image in data.map(decode_image):
                features = {
                    "image": _bytes_feature(image.numpy().tobytes()),
                    "height": _int64_feature(image.shape[0]),
                }
                record = tf.train.Example(features=tf.train.Features(feature=features)).SerializeToString()
                writer.write(record)

    with open(os.path.join(train_dir, "dataset_size.txt"), "w") as f:
        f.write(f"{train_items:d}")
    with open(os.path.join(test_dir, "dataset_size.txt"), "w") as f:
        f.write(f"{total_items - train_items:d}")