def _list_dir_helper_oss(self, root):
    # OSS returns a file multiple times, e.g. listdir('root') returns
    # ['folder', 'file1.txt', 'folder/file2.txt']
    # and then listdir('root/folder') returns
    # ['file2.txt']
    filenames = set(path.join(root, i) for i in gfile.ListDirectory(root))
    res = []
    for fname in filenames:
        succ = path.join(path.dirname(fname), '_SUCCESS')
        if succ in filenames or not gfile.IsDirectory(fname):
            res.append(fname)
    return res

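# Hedged illustration (not part of the original source): the same _SUCCESS-based
# filter exercised on an in-memory listing that mimics OSS's flattened output.
# The is_directory lambda stands in for gfile.IsDirectory; all paths below are
# hypothetical.
from os import path

listing = set(path.join('root', p) for p in
              ['folder', 'file1.txt', 'folder/file2.txt', 'folder/_SUCCESS'])
is_directory = lambda p: p == path.join('root', 'folder')
kept = [p for p in listing
        if path.join(path.dirname(p), '_SUCCESS') in listing
        or not is_directory(p)]
# kept holds 'root/file1.txt' plus the two entries under 'folder' (whose
# _SUCCESS tag marks it finished); the bare 'root/folder' directory entry is
# dropped, so no file is listed twice.
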
def generate_leader_raw_data(self):
    dbm = data_block_manager.DataBlockManager(self.data_source_l, 0)
    raw_data_dir = os.path.join(self.data_source_l.raw_data_dir,
                                common.partition_repr(0))
    if gfile.Exists(raw_data_dir):
        gfile.DeleteRecursively(raw_data_dir)
    gfile.MakeDirs(raw_data_dir)
    rdm = raw_data_visitor.RawDataManager(self.etcd, self.data_source_l, 0)
    block_index = 0
    builder = data_block_manager.DataBlockBuilder(
        self.data_source_l.raw_data_dir,
        self.data_source_l.data_source_meta.name,
        0, block_index, None)
    process_index = 0
    start_index = 0
    for i in range(0, self.leader_end_index + 3):
        if (i > 0 and i % 2048 == 0) or (i == self.leader_end_index + 2):
            meta = builder.finish_data_block()
            if meta is not None:
                ofname = common.encode_data_block_fname(
                    self.data_source_l.data_source_meta.name, meta)
                fpath = os.path.join(raw_data_dir, ofname)
                self.manifest_manager.add_raw_data(0, [fpath], False)
                process_index += 1
                start_index += len(meta.example_ids)
            block_index += 1
            builder = data_block_manager.DataBlockBuilder(
                self.data_source_l.raw_data_dir,
                self.data_source_l.data_source_meta.name,
                0, block_index, None)
        feat = {}
        # Note: '+' binds tighter than '<<', so this evaluates as (i + 1) << 30.
        pt = i + 1 << 30
        if i % 3 == 0:
            pt = i // 3
        example_id = '{}'.format(pt).encode()
        feat['example_id'] = tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[example_id]))
        event_time = 150000000 + pt
        feat['event_time'] = tf.train.Feature(
            int64_list=tf.train.Int64List(value=[event_time]))
        example = tf.train.Example(features=tf.train.Features(feature=feat))
        builder.append(example.SerializeToString(), example_id,
                       event_time, i, i)
    fpaths = [os.path.join(raw_data_dir, f)
              for f in gfile.ListDirectory(raw_data_dir)
              if not gfile.IsDirectory(os.path.join(raw_data_dir, f))]
    for fpath in fpaths:
        if not fpath.endswith(common.DataBlockSuffix):
            gfile.Remove(fpath)

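# Hedged minimal sketch (not part of the original source): the tf.train.Example
# construction used repeatedly in the generators above, isolated so it can be
# run on its own. The example_id and event_time values are made up.
import tensorflow as tf

example_id = b'42'
event_time = 150000042
example = tf.train.Example(features=tf.train.Features(feature={
    'example_id': tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[example_id])),
    'event_time': tf.train.Feature(
        int64_list=tf.train.Int64List(value=[event_time])),
}))
serialized = example.SerializeToString()
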
def _list_file_metas(self, partition_id):
    dumped_dir = os.path.join(self._options.output_dir,
                              common.partition_repr(partition_id))
    if not gfile.Exists(dumped_dir):
        gfile.MakeDirs(dumped_dir)
    assert gfile.IsDirectory(dumped_dir)
    fnames = [os.path.basename(f) for f in gfile.ListDirectory(dumped_dir)
              if f.endswith(RawDataPartitioner.FileSuffix)]
    return [RawDataPartitioner.FileMeta.decode_meta_from_fname(f)
            for f in fnames]

def __init__(self, options, partition_id):
    self._options = options
    self._partition_id = partition_id
    self._process_index = 0
    self._writer = None
    self._dumped_item = 0
    self._output_fpaths = []
    self._output_dir = os.path.join(
        self._options.output_dir,
        common.partition_repr(self._partition_id))
    if not gfile.Exists(self._output_dir):
        gfile.MakeDirs(self._output_dir)
    assert gfile.IsDirectory(self._output_dir)

def parse_data_block_dir(self, data_block_dir, role="leader"):
    dir_path_list = [path.join(data_block_dir, f)
                     for f in gfile.ListDirectory(data_block_dir)
                     if gfile.IsDirectory(path.join(data_block_dir, f))]
    for dir_path in dir_path_list:
        if role == "leader":
            self.leader_file_path_list += [
                path.join(dir_path, f)
                for f in gfile.ListDirectory(dir_path)
                if f.split(".")[-1] == "data"
                and not gfile.IsDirectory(path.join(dir_path, f))
            ]
        else:
            self.follower_file_path_list += [
                path.join(dir_path, f)
                for f in gfile.ListDirectory(dir_path)
                if f.split(".")[-1] == "data"
                and not gfile.IsDirectory(path.join(dir_path, f))
            ]
    self.leader_file_path_list.sort()
    self.follower_file_path_list.sort()

def _list_input_dir(self):
    all_inputs = []
    wildcard = self._portal_manifest.input_file_wildcard
    dirs = [self._portal_manifest.input_base_dir]
    while len(dirs) > 0:
        fdir = dirs[0]
        dirs = dirs[1:]
        fnames = gfile.ListDirectory(fdir)
        for fname in fnames:
            fpath = path.join(fdir, fname)
            if gfile.IsDirectory(fpath):
                dirs.append(fpath)
            elif len(wildcard) == 0 or fnmatch(fname, wildcard):
                all_inputs.append(fpath)
    return all_inputs

def validate_holdout_selfplay():
    """Validate on held-out selfplay data."""
    holdout_dirs = (
        os.path.join(fsdb.holdout_dir(), d)
        for d in reversed(gfile.ListDirectory(fsdb.holdout_dir()))
        if gfile.IsDirectory(os.path.join(fsdb.holdout_dir(), d))
        for f in gfile.ListDirectory(os.path.join(fsdb.holdout_dir(), d)))

    # This is a roundabout way of computing how many hourly directories we
    # need to read in order to encompass 20,000 holdout games.
    holdout_dirs = set(itertools.islice(holdout_dirs, 20000))

    cmd = ['python3', 'validate.py'] + list(holdout_dirs) + [
        '--use_tpu',
        '--tpu_name={}'.format(TPU_NAME),
        '--flagfile=rl_loop/distributed_flags',
        '--expand_validation_dirs']
    mask_flags.run(cmd)

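# Hedged illustration (not part of the original source): why islicing 20,000
# entries from the generator above covers roughly 20,000 games. The generator
# yields a directory path once per (directory, game file) pair, so collecting
# the first 20,000 yields into a set keeps exactly the directories that
# together hold that many games. The numbers below are made up.
import itertools

dir_per_game = ('hour_%02d' % (i // 500) for i in range(100000))  # 500 games/dir
covering_dirs = set(itertools.islice(dir_per_game, 20000))        # ~40 dirs
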
def visualize_dataset(dataset_name, output_path, num_animations=5,
                      num_frames=20, fps=10):
    """Visualizes the data set by saving images to output_path.

    For each latent factor, outputs 16 images where only that latent factor is
    varied while all others are kept constant.

    Args:
      dataset_name: String with name of dataset as defined in named_data.py.
      output_path: String with path in which to create the visualizations.
      num_animations: Integer with number of distinct animations to create.
      num_frames: Integer with number of frames in each animation.
      fps: Integer with frame rate for the animation.
    """
    data = named_data.get_named_ground_truth_data(dataset_name)
    random_state = np.random.RandomState(0)

    # Create output folder if necessary.
    path = os.path.join(output_path, dataset_name)
    if not gfile.IsDirectory(path):
        gfile.MakeDirs(path)

    # Create still images.
    for i in range(data.num_factors):
        factors = data.sample_factors(16, random_state)
        indices = [j for j in range(data.num_factors) if i != j]
        factors[:, indices] = factors[0, indices]
        images = data.sample_observations_from_factors(factors, random_state)
        visualize_util.grid_save_images(
            images, os.path.join(path, "variations_of_factor%s.png" % i))

    # Create animations.
    for i in range(num_animations):
        base_factor = data.sample_factors(1, random_state)
        images = []
        for j, num_atoms in enumerate(data.factors_num_values):
            factors = np.repeat(base_factor, num_frames, axis=0)
            factors[:, j] = visualize_util.cycle_factor(base_factor[0, j],
                                                        num_atoms, num_frames)
            images.append(
                data.sample_observations_from_factors(factors, random_state))
        visualize_util.save_animation(
            np.array(images), os.path.join(path, "animation%d.gif" % i), fps)

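# Hedged illustration (not part of the original source): the indexing trick used
# above to hold every latent factor except one fixed across the sampled rows.
# The shapes and values here are made up for the example.
import numpy as np

factors = np.arange(12).reshape(4, 3)        # 4 samples, 3 latent factors
varied = 1                                   # the factor allowed to vary
fixed = [j for j in range(3) if j != varied]
factors[:, fixed] = factors[0, fixed]        # broadcast row 0 down the fixed columns
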
def _list_input_dir(self):
    all_inputs = []
    wildcard = self._portal_manifest.input_file_wildcard
    dirs = [self._portal_manifest.input_base_dir]
    num_dirs = 0
    num_files = 0
    num_target_files = 0
    while len(dirs) > 0:
        fdir = dirs[0]
        dirs = dirs[1:]
        fnames = gfile.ListDirectory(fdir)
        for fname in fnames:
            fpath = path.join(fdir, fname)
            # OSS does not retain folder structure.
            # For example, if we have file oss://test/1001/a.txt,
            # list(oss://test) returns 1001/a.txt instead of 1001.
            basename = path.basename(fpath)
            if basename == '_SUCCESS':
                continue
            if gfile.IsDirectory(fpath):
                dirs.append(fpath)
                num_dirs += 1
                continue
            num_files += 1
            if len(wildcard) == 0 or fnmatch(basename, wildcard):
                num_target_files += 1
                if self._check_success_tag:
                    has_succ = gfile.Exists(
                        path.join(path.dirname(fpath), '_SUCCESS'))
                    if not has_succ:
                        logging.warning(
                            'File %s skipped because _SUCCESS file is '
                            'missing under %s', fpath, fdir)
                        continue
                all_inputs.append(fpath)
    rest_fpaths = []
    for fpath in all_inputs:
        if fpath not in self._processed_fpath:
            rest_fpaths.append(fpath)
    logging.info(
        'Listing %s: found %d dirs, %d files, %d files matching wildcard, '
        '%d files with success tag, %d new files to process',
        self._portal_manifest.input_base_dir, num_dirs, num_files,
        num_target_files, len(all_inputs), len(rest_fpaths))
    return rest_fpaths

def _list_dir_helper(self, root):
    filenames = list(gfile.ListDirectory(root))
    # If _SUCCESS is present, we assume there are no subdirs.
    if '_SUCCESS' in filenames:
        return [path.join(root, i) for i in filenames]
    res = []
    for basename in filenames:
        fname = path.join(root, basename)
        if gfile.IsDirectory(fname):
            # Ignore tmp dirs starting with _.
            if basename.startswith('_'):
                continue
            res += self._list_dir_helper(fname)
        else:
            res.append(fname)
    return res

def test_rename_dir(self):
    """Test rename dir."""
    # Setup and check preconditions.
    src_dir_name = "igfs:///test_rename_dir/1"
    dst_dir_name = "igfs:///test_rename_dir/2"
    gfile.MkDir(src_dir_name)
    # Rename directory.
    gfile.Rename(src_dir_name, dst_dir_name)
    # Check that only the new name of the directory is available.
    self.assertFalse(gfile.Exists(src_dir_name))
    self.assertTrue(gfile.Exists(dst_dir_name))
    self.assertTrue(gfile.IsDirectory(dst_dir_name))
    # Remove directory.
    gfile.Remove(dst_dir_name)
    # Check that the directory was removed.
    self.assertFalse(gfile.Exists(dst_dir_name))

def __init__(self, data_source_name, data_block_fname,
             partition_id, dirpath, check_existed=True):
    assert data_block_fname.endswith(DataBlockSuffix), \
        "data block fname {} should have suffix {}".format(
            data_block_fname, DataBlockSuffix)
    block_id = data_block_fname[:-len(DataBlockSuffix)]
    segmap = decode_block_id(block_id)
    if segmap["data_source_name"] != data_source_name:
        raise ValueError("{} invalid. Data source name mismatch "
                         "{} != {}".format(data_block_fname,
                                           segmap["data_source_name"],
                                           data_source_name))
    self._data_source_name = data_source_name
    if segmap["partition_id"] != partition_id:
        raise ValueError("{} invalid. partition mismatch "
                         "{} != {}".format(data_block_fname,
                                           segmap["partition_id"],
                                           partition_id))
    self._partition_id = partition_id
    start_time, end_time = segmap["time_frame"][0], segmap["time_frame"][1]
    if start_time > end_time:
        raise ValueError("{} invalid. time frame error: start_time {} > "
                         "end_time {}".format(data_block_fname,
                                              start_time, end_time))
    self._start_time, self._end_time = start_time, end_time
    self._data_block_index = segmap["data_block_index"]
    self._block_id = block_id
    meta_fname = encode_data_block_meta_fname(self._data_source_name,
                                              self._partition_id,
                                              self._data_block_index)
    meta_fpath = os.path.join(dirpath, meta_fname)
    if check_existed and (not gfile.Exists(meta_fpath) or
                          gfile.IsDirectory(meta_fpath)):
        raise ValueError("{} invalid. The corresponding meta file "
                         "does not exist".format(data_block_fname))
    self._data_block_meta_fpath = meta_fpath
    self._data_block_meta = None
    self._data_block_fpath = os.path.join(dirpath, data_block_fname)

def _publish_raw_data(self, job_id):
    portal_manifest = self._sync_portal_manifest()
    output_dir = None
    if portal_manifest.data_portal_type == dp_pb.DataPortalType.PSI:
        output_dir = common.portal_map_output_dir(
            portal_manifest.output_base_dir, job_id)
    else:
        output_dir = common.portal_reduce_output_dir(
            portal_manifest.output_base_dir, job_id)
    for partition_id in range(self._output_partition_num):
        dpath = path.join(output_dir, common.partition_repr(partition_id))
        fnames = []
        if gfile.Exists(dpath) and gfile.IsDirectory(dpath):
            fnames = [f for f in gfile.ListDirectory(dpath)
                      if f.endswith(common.RawDataFileSuffix)]
        if portal_manifest.data_portal_type == dp_pb.DataPortalType.PSI:
            self._publish_psi_raw_data(partition_id, dpath, fnames)
        else:
            self._publish_streaming_raw_data(partition_id, dpath, fnames)

def __init__(self, potral_manifest, potral_options, date_time):
    assert isinstance(date_time, datetime)
    self._potral_manifest = potral_manifest
    self._potral_options = potral_options
    self._date_time = date_time
    hourly_dir = common.encode_portal_hourly_dir(
        self._potral_manifest.output_data_base_dir, date_time)
    if not gfile.Exists(hourly_dir):
        gfile.MakeDirs(hourly_dir)
    if not gfile.IsDirectory(hourly_dir):
        logging.fatal("%s must be a directory for mapper output", hourly_dir)
        os._exit(-1)  # pylint: disable=protected-access
    self._writers = []
    for partition_id in range(self.output_partition_num):
        fpath = common.encode_portal_hourly_fpath(
            self._potral_manifest.output_data_base_dir,
            date_time, partition_id)
        writer = PotralHourlyOutputMapper.OutputFileWriter(partition_id, fpath)
        self._writers.append(writer)

def generate_raw_data(self, begin_index, item_count):
    raw_data_dir = os.path.join(self.raw_data_dir, common.partition_repr(0))
    if not gfile.Exists(raw_data_dir):
        gfile.MakeDirs(raw_data_dir)
    self.total_raw_data_count += item_count
    useless_index = 0
    rdm = raw_data_visitor.RawDataManager(self.kvstore, self.data_source, 0)
    fpaths = []
    for block_index in range(0, item_count // 2048):
        builder = DataBlockBuilder(
            self.raw_data_dir, self.data_source.data_source_meta.name,
            0, block_index,
            dj_pb.WriterOptions(output_writer='TF_RECORD'), None)
        cands = list(range(begin_index + block_index * 2048,
                           begin_index + (block_index + 1) * 2048))
        start_index = cands[0]
        for i in range(len(cands)):
            if random.randint(1, 4) > 2:
                continue
            a = random.randint(i - 32, i + 32)
            b = random.randint(i - 32, i + 32)
            if a < 0:
                a = 0
            if a >= len(cands):
                a = len(cands) - 1
            if b < 0:
                b = 0
            if b >= len(cands):
                b = len(cands) - 1
            if (abs(cands[a] - i - start_index) <= 32 and
                    abs(cands[b] - i - start_index) <= 32):
                cands[a], cands[b] = cands[b], cands[a]
        for example_idx in cands:
            feat = {}
            example_id = '{}'.format(example_idx).encode()
            feat['example_id'] = tf.train.Feature(
                bytes_list=tf.train.BytesList(value=[example_id]))
            event_time = 150000000 + example_idx
            feat['event_time'] = tf.train.Feature(
                int64_list=tf.train.Int64List(value=[event_time]))
            label = random.choice([1, 0])
            if random.random() < 0.8:
                feat['label'] = tf.train.Feature(
                    int64_list=tf.train.Int64List(value=[label]))
            example = tf.train.Example(features=tf.train.Features(feature=feat))
            builder.append_item(TfExampleItem(example.SerializeToString()),
                                useless_index, useless_index)
            useless_index += 1
        meta = builder.finish_data_block()
        fname = common.encode_data_block_fname(
            self.data_source.data_source_meta.name, meta)
        fpath = os.path.join(raw_data_dir, fname)
        fpaths.append(dj_pb.RawDataMeta(
            file_path=fpath,
            timestamp=timestamp_pb2.Timestamp(seconds=3)))
        self.g_data_block_index += 1
    all_files = [os.path.join(raw_data_dir, f)
                 for f in gfile.ListDirectory(raw_data_dir)
                 if not gfile.IsDirectory(os.path.join(raw_data_dir, f))]
    for fpath in all_files:
        if not fpath.endswith(common.DataBlockSuffix):
            gfile.Remove(fpath)
    self.manifest_manager.add_raw_data(0, fpaths, False)

def generate_raw_data(self, etcd, rdp, data_source, partition_id,
                      block_size, shuffle_win_size, feat_key_fmt,
                      feat_val_fmt):
    dbm = data_block_manager.DataBlockManager(data_source, partition_id)
    raw_data_dir = os.path.join(data_source.raw_data_dir,
                                common.partition_repr(partition_id))
    if gfile.Exists(raw_data_dir):
        gfile.DeleteRecursively(raw_data_dir)
    gfile.MakeDirs(raw_data_dir)
    useless_index = 0
    new_raw_data_fnames = []
    for block_index in range(self.total_index // block_size):
        builder = DataBlockBuilder(
            data_source.raw_data_dir, data_source.data_source_meta.name,
            partition_id, block_index,
            dj_pb.WriterOptions(output_writer='TF_RECORD'), None)
        cands = list(range(block_index * block_size,
                           (block_index + 1) * block_size))
        start_index = cands[0]
        for i in range(len(cands)):
            if random.randint(1, 4) > 2:
                continue
            a = random.randint(i - shuffle_win_size, i + shuffle_win_size)
            b = random.randint(i - shuffle_win_size, i + shuffle_win_size)
            if a < 0:
                a = 0
            if a >= len(cands):
                a = len(cands) - 1
            if b < 0:
                b = 0
            if b >= len(cands):
                b = len(cands) - 1
            if (abs(cands[a] - i - start_index) <= shuffle_win_size and
                    abs(cands[b] - i - start_index) <= shuffle_win_size):
                cands[a], cands[b] = cands[b], cands[a]
        for example_idx in cands:
            feat = {}
            example_id = '{}'.format(example_idx).encode()
            feat['example_id'] = tf.train.Feature(
                bytes_list=tf.train.BytesList(value=[example_id]))
            event_time = 150000000 + example_idx
            feat['event_time'] = tf.train.Feature(
                int64_list=tf.train.Int64List(value=[event_time]))
            feat[feat_key_fmt.format(example_idx)] = tf.train.Feature(
                bytes_list=tf.train.BytesList(
                    value=[feat_val_fmt.format(example_idx).encode()]))
            example = tf.train.Example(features=tf.train.Features(feature=feat))
            builder.append_item(TfExampleItem(example.SerializeToString()),
                                useless_index, useless_index)
            useless_index += 1
        meta = builder.finish_data_block()
        fname = common.encode_data_block_fname(
            data_source.data_source_meta.name, meta)
        new_raw_data_fnames.append(os.path.join(raw_data_dir, fname))
    fpaths = [os.path.join(raw_data_dir, f)
              for f in gfile.ListDirectory(raw_data_dir)
              if not gfile.IsDirectory(os.path.join(raw_data_dir, f))]
    for fpath in fpaths:
        if fpath.endswith(common.DataBlockMetaSuffix):
            gfile.Remove(fpath)
    rdp.publish_raw_data(partition_id, new_raw_data_fnames)

def _create_merged_dir_if_need(self):
    if not gfile.Exists(self._merged_dir):
        gfile.MakeDirs(self._merged_dir)
    assert gfile.IsDirectory(self._merged_dir)

args = parser.parse_args()
master_channel = make_insecure_channel(args.master_addr, ChannelType.INTERNAL)
master_cli = dj_grpc.DataJoinMasterServiceStub(master_channel)
data_src = master_cli.GetDataSource(empty_pb2.Empty())
rdc = RawDataController(data_src, master_cli)
if args.cmd == 'add':
    all_fpaths = []
    if args.files is not None:
        for fp in args.files:
            all_fpaths.append(fp)
    if args.src_dir is not None:
        dir_fpaths = [path.join(args.src_dir, f)
                      for f in gfile.ListDirectory(args.src_dir)
                      if not gfile.IsDirectory(path.join(args.src_dir, f))]
        dir_fpaths.sort()
        all_fpaths += dir_fpaths
    if not all_fpaths:
        raise RuntimeError("no raw data files supplied")
    status = rdc.add_raw_data(args.partition_id, all_fpaths, args.dedup)
    if status.code != 0:
        logging.error("Failed to add raw data for partition %d, reason: %s",
                      args.partition_id, status.error_message)
    else:
        logging.info("Successfully added the following %d raw data files "
                     "for partition %d", len(all_fpaths), args.partition_id)
        for idx, fp in enumerate(all_fpaths):
            logging.info("%d. %s", idx, fp)
else:
    assert args.cmd == 'finish'

def __init__(self, dir, write_graph=True):
    if not gfile.IsDirectory(dir):
        gfile.MakeDirs(dir)
    self.writer = tf.summary.FileWriter(
        dir, graph=tf.get_default_graph() if write_graph else None)

def _list_data_block(self, partition_id):
    dirpath = self._partition_data_block_dir(partition_id)
    if gfile.Exists(dirpath) and gfile.IsDirectory(dirpath):
        return [f for f in gfile.ListDirectory(dirpath)
                if f.endswith(DataBlockSuffix)]
    return []

def visualize(model_dir, output_dir, overwrite=False, num_animations=5,
              num_frames=20, fps=10, num_points_irs=10000):
    """Takes trained model from model_dir and visualizes it in output_dir.

    Args:
      model_dir: Path to directory where the trained model is saved.
      output_dir: Path to output directory.
      overwrite: Boolean indicating whether to overwrite output directory.
      num_animations: Integer with number of distinct animations to create.
      num_frames: Integer with number of frames in each animation.
      fps: Integer with frame rate for the animation.
      num_points_irs: Number of points to be used for the IRS plots.
    """
    # Fix the random seed for reproducibility.
    random_state = np.random.RandomState(0)

    # Create the output directory if necessary.
    if tf.gfile.IsDirectory(output_dir):
        if overwrite:
            tf.gfile.DeleteRecursively(output_dir)
        else:
            raise ValueError("Directory already exists and overwrite is False.")

    # Automatically set the proper data set if necessary. We replace the active
    # gin config as this will lead to a valid gin config file where the data
    # set is present.
    # Obtain the dataset name from the gin config of the previous step.
    gin_config_file = os.path.join(model_dir, "results", "gin", "train.gin")
    gin_dict = results.gin_dict(gin_config_file)
    gin.bind_parameter("dataset.name",
                       gin_dict["dataset.name"].replace("'", ""))

    # Automatically infer the activation function from gin config.
    activation_str = gin_dict["reconstruction_loss.activation"]
    if activation_str == "'logits'":
        activation = sigmoid
    elif activation_str == "'tanh'":
        activation = tanh
    else:
        raise ValueError(
            "Activation function could not be inferred from gin config.")

    dataset = named_data.get_named_ground_truth_data()
    num_pics = 64
    module_path = os.path.join(model_dir, "tfhub")

    with hub.eval_function_for_module(module_path) as f:
        # Save reconstructions.
        real_pics = dataset.sample_observations(num_pics, random_state)
        raw_pics = f(dict(images=real_pics), signature="reconstructions",
                     as_dict=True)["images"]
        pics = activation(raw_pics)
        paired_pics = np.concatenate((real_pics, pics), axis=2)
        paired_pics = [paired_pics[i, :, :, :]
                       for i in range(paired_pics.shape[0])]
        results_dir = os.path.join(output_dir, "reconstructions")
        if not gfile.IsDirectory(results_dir):
            gfile.MakeDirs(results_dir)
        visualize_util.grid_save_images(
            paired_pics, os.path.join(results_dir, "reconstructions.jpg"))

        # Save samples.
        def _decoder(latent_vectors):
            return f(dict(latent_vectors=latent_vectors), signature="decoder",
                     as_dict=True)["images"]

        num_latent = int(gin_dict["encoder.num_latent"])
        num_pics = 64
        random_codes = random_state.normal(0, 1, [num_pics, num_latent])
        pics = activation(_decoder(random_codes))
        results_dir = os.path.join(output_dir, "sampled")
        if not gfile.IsDirectory(results_dir):
            gfile.MakeDirs(results_dir)
        visualize_util.grid_save_images(
            pics, os.path.join(results_dir, "samples.jpg"))

        # Save latent traversals.
        result = f(
            dict(images=dataset.sample_observations(num_pics, random_state)),
            signature="gaussian_encoder", as_dict=True)
        means = result["mean"]
        logvars = result["logvar"]
        results_dir = os.path.join(output_dir, "traversals")
        if not gfile.IsDirectory(results_dir):
            gfile.MakeDirs(results_dir)
        for i in range(means.shape[1]):
            pics = activation(
                latent_traversal_1d_multi_dim(_decoder, means[i, :], None))
            file_name = os.path.join(results_dir, "traversals{}.jpg".format(i))
            visualize_util.grid_save_images([pics], file_name)

        # Save the latent traversal animations.
        results_dir = os.path.join(output_dir, "animated_traversals")
        if not gfile.IsDirectory(results_dir):
            gfile.MakeDirs(results_dir)

        # Cycle through quantiles of a standard Gaussian.
        for i, base_code in enumerate(means[:num_animations]):
            images = []
            for j in range(base_code.shape[0]):
                code = np.repeat(np.expand_dims(base_code, 0), num_frames,
                                 axis=0)
                code[:, j] = visualize_util.cycle_gaussian(base_code[j],
                                                           num_frames)
                images.append(np.array(activation(_decoder(code))))
            filename = os.path.join(results_dir, "std_gaussian_cycle%d.gif" % i)
            visualize_util.save_animation(np.array(images), filename, fps)

        # Cycle through quantiles of a fitted Gaussian.
        for i, base_code in enumerate(means[:num_animations]):
            images = []
            for j in range(base_code.shape[0]):
                code = np.repeat(np.expand_dims(base_code, 0), num_frames,
                                 axis=0)
                loc = np.mean(means[:, j])
                # Law of total variance for the aggregate posterior over
                # dimension j: Var[z_j] = E[Var[z_j|x]] + Var[E[z_j|x]].
                total_variance = np.mean(np.exp(logvars[:, j])) + np.var(
                    means[:, j])
                code[:, j] = visualize_util.cycle_gaussian(
                    base_code[j], num_frames, loc=loc,
                    scale=np.sqrt(total_variance))
                images.append(np.array(activation(_decoder(code))))
            filename = os.path.join(results_dir,
                                    "fitted_gaussian_cycle%d.gif" % i)
            visualize_util.save_animation(np.array(images), filename, fps)

        # Cycle through [-2, 2] interval.
        for i, base_code in enumerate(means[:num_animations]):
            images = []
            for j in range(base_code.shape[0]):
                code = np.repeat(np.expand_dims(base_code, 0), num_frames,
                                 axis=0)
                code[:, j] = visualize_util.cycle_interval(base_code[j],
                                                           num_frames, -2., 2.)
                images.append(np.array(activation(_decoder(code))))
            filename = os.path.join(results_dir,
                                    "fixed_interval_cycle%d.gif" % i)
            visualize_util.save_animation(np.array(images), filename, fps)

        # Cycle linearly through +-2 std dev of a fitted Gaussian.
        for i, base_code in enumerate(means[:num_animations]):
            images = []
            for j in range(base_code.shape[0]):
                code = np.repeat(np.expand_dims(base_code, 0), num_frames,
                                 axis=0)
                loc = np.mean(means[:, j])
                total_variance = np.mean(np.exp(logvars[:, j])) + np.var(
                    means[:, j])
                scale = np.sqrt(total_variance)
                code[:, j] = visualize_util.cycle_interval(
                    base_code[j], num_frames,
                    loc - 2. * scale, loc + 2. * scale)
                images.append(np.array(activation(_decoder(code))))
            filename = os.path.join(results_dir,
                                    "conf_interval_cycle%d.gif" % i)
            visualize_util.save_animation(np.array(images), filename, fps)

        # Cycle linearly through minmax of a fitted Gaussian.
        for i, base_code in enumerate(means[:num_animations]):
            images = []
            for j in range(base_code.shape[0]):
                code = np.repeat(np.expand_dims(base_code, 0), num_frames,
                                 axis=0)
                code[:, j] = visualize_util.cycle_interval(
                    base_code[j], num_frames,
                    np.min(means[:, j]), np.max(means[:, j]))
                images.append(np.array(activation(_decoder(code))))
            filename = os.path.join(results_dir,
                                    "minmax_interval_cycle%d.gif" % i)
            visualize_util.save_animation(np.array(images), filename, fps)

        # Interventional effects visualization.
        factors = dataset.sample_factors(num_points_irs, random_state)
        obs = dataset.sample_observations_from_factors(factors, random_state)
        latents = f(dict(images=obs), signature="gaussian_encoder",
                    as_dict=True)["mean"]
        results_dir = os.path.join(output_dir, "interventional_effects")
        vis_all_interventional_effects(factors, latents, results_dir)

    # Finally, we clear the gin config that we have set.
    gin.clear_config()