Example #1
    def _list_dir_helper_oss(self, root):
        # OSS can return the same file more than once, e.g. listdir('root') returns
        #   ['folder', 'file1.txt', 'folder/file2.txt']
        # and then listdir('root/folder') returns
        #   ['file2.txt']
        filenames = set(path.join(root, i) for i in gfile.ListDirectory(root))
        res = []
        for fname in filenames:
            succ = path.join(path.dirname(fname), '_SUCCESS')
            if succ in filenames or not gfile.IsDirectory(fname):
                res.append(fname)

        return res
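A minimal, self-contained sketch of the same filtering logic, replaying the hypothetical OSS listing from the comment above with plain Python (the directory check is faked, so no gfile backend is needed):

from os import path

# Hypothetical listing of 'root' as described in the comment: OSS reports both
# the folder and the file inside it in a single ListDirectory call.
listing = ['folder', 'file1.txt', 'folder/file2.txt']
filenames = set(path.join('root', i) for i in listing)

def fake_is_directory(fname):
    # Stand-in for gfile.IsDirectory: only 'root/folder' is a directory here.
    return fname == 'root/folder'

res = []
for fname in filenames:
    succ = path.join(path.dirname(fname), '_SUCCESS')
    if succ in filenames or not fake_is_directory(fname):
        res.append(fname)

# Without a _SUCCESS marker, only the regular files survive.
assert sorted(res) == ['root/file1.txt', 'root/folder/file2.txt']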
Example #2
 def generate_leader_raw_data(self):
     dbm = data_block_manager.DataBlockManager(self.data_source_l, 0)
     raw_data_dir = os.path.join(self.data_source_l.raw_data_dir,
                                 common.partition_repr(0))
     if gfile.Exists(raw_data_dir):
         gfile.DeleteRecursively(raw_data_dir)
     gfile.MakeDirs(raw_data_dir)
     rdm = raw_data_visitor.RawDataManager(self.etcd, self.data_source_l, 0)
     block_index = 0
     builder = data_block_manager.DataBlockBuilder(
         self.data_source_l.raw_data_dir,
         self.data_source_l.data_source_meta.name, 0, block_index, None)
     process_index = 0
     start_index = 0
     for i in range(0, self.leader_end_index + 3):
         if (i > 0 and i % 2048 == 0) or (i == self.leader_end_index + 2):
             meta = builder.finish_data_block()
             if meta is not None:
                 ofname = common.encode_data_block_fname(
                     self.data_source_l.data_source_meta.name, meta)
                 fpath = os.path.join(raw_data_dir, ofname)
                 self.manifest_manager.add_raw_data(0, [fpath], False)
                 process_index += 1
                 start_index += len(meta.example_ids)
             block_index += 1
             builder = data_block_manager.DataBlockBuilder(
                 self.data_source_l.raw_data_dir,
                 self.data_source_l.data_source_meta.name, 0, block_index,
                 None)
         feat = {}
         pt = i + 1 << 30
         if i % 3 == 0:
             pt = i // 3
         example_id = '{}'.format(pt).encode()
         feat['example_id'] = tf.train.Feature(
             bytes_list=tf.train.BytesList(value=[example_id]))
         event_time = 150000000 + pt
         feat['event_time'] = tf.train.Feature(
             int64_list=tf.train.Int64List(value=[event_time]))
         example = tf.train.Example(features=tf.train.Features(
             feature=feat))
         builder.append(example.SerializeToString(), example_id, event_time,
                        i, i)
     fpaths = [
         os.path.join(raw_data_dir, f)
         for f in gfile.ListDirectory(raw_data_dir)
         if not gfile.IsDirectory(os.path.join(raw_data_dir, f))
     ]
     for fpath in fpaths:
         if not fpath.endswith(common.DataBlockSuffix):
             gfile.Remove(fpath)
Example #3
 def _list_file_metas(self, partition_id):
     dumped_dir = os.path.join(self._options.output_dir,
                               common.partition_repr(partition_id))
     if not gfile.Exists(dumped_dir):
         gfile.MakeDirs(dumped_dir)
     assert gfile.IsDirectory(dumped_dir)
     fnames = [
         os.path.basename(f) for f in gfile.ListDirectory(dumped_dir)
         if f.endswith(RawDataPartitioner.FileSuffix)
     ]
     return [
         RawDataPartitioner.FileMeta.decode_meta_from_fname(f)
         for f in fnames
     ]
 def __init__(self, options, partition_id):
     self._options = options
     self._partition_id = partition_id
     self._process_index = 0
     self._writer = None
     self._dumped_item = 0
     self._output_fpaths = []
     self._output_dir = os.path.join(
             self._options.output_dir,
             common.partition_repr(self._partition_id)
         )
     if not gfile.Exists(self._output_dir):
         gfile.MakeDirs(self._output_dir)
     assert gfile.IsDirectory(self._output_dir)
Example #5
    def parse_data_block_dir(self, data_block_dir, role="leader"):
        dir_path_list = [
            path.join(data_block_dir, f)
            for f in gfile.ListDirectory(data_block_dir)
            if gfile.IsDirectory(path.join(data_block_dir, f))
        ]
        for dir_path in dir_path_list:
            if role == "leader":
                self.leader_file_path_list += [
                    path.join(dir_path, f)
                    for f in gfile.ListDirectory(dir_path)
                    if f.split(".")[-1] == "data"
                    and not gfile.IsDirectory(path.join(dir_path, f))
                ]
            else:
                self.follower_file_path_list += [
                    path.join(dir_path, f)
                    for f in gfile.ListDirectory(dir_path)
                    if f.split(".")[-1] == "data"
                    and not gfile.IsDirectory(path.join(dir_path, f))
                ]

        self.leader_file_path_list.sort()
        self.follower_file_path_list.sort()
Example #6
 def _list_input_dir(self):
     all_inputs = []
     wildcard = self._portal_manifest.input_file_wildcard
     dirs = [self._portal_manifest.input_base_dir]
     while len(dirs) > 0:
         fdir = dirs[0]
         dirs = dirs[1:]
         fnames = gfile.ListDirectory(fdir)
         for fname in fnames:
             fpath = path.join(fdir, fname)
             if gfile.IsDirectory(fpath):
                 dirs.append(fpath)
             elif len(wildcard) == 0 or fnmatch(fname, wildcard):
                 all_inputs.append(fpath)
     return all_inputs
Example #7
def validate_holdout_selfplay():
    """Validate on held-out selfplay data."""
    holdout_dirs = (
        os.path.join(fsdb.holdout_dir(), d)
        for d in reversed(gfile.ListDirectory(fsdb.holdout_dir()))
        if gfile.IsDirectory(os.path.join(fsdb.holdout_dir(), d))
        for f in gfile.ListDirectory(os.path.join(fsdb.holdout_dir(), d)))

    # This is a roundabout way of computing how many hourly directories we need
    # to read in order to encompass 20,000 holdout games.
    holdout_dirs = set(itertools.islice(holdout_dirs, 20000))
    cmd = ['python3', 'validate.py'] + list(holdout_dirs) + [
        '--use_tpu', '--tpu_name={}'.format(TPU_NAME),
        '--flagfile=rl_loop/distributed_flags', '--expand_validation_dirs'
    ]
    mask_flags.run(cmd)
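The set(itertools.islice(...)) call above caps the lazy generator and then dedupes it; a minimal standalone sketch of that idiom, with made-up directory names and a small limit in place of 20000:

import itertools

def fake_holdout_entries():
    # Stand-in for the generator above: a directory is yielded once per file it
    # contains, so the same name can appear several times.
    for d in ('h01', 'h01', 'h02', 'h03', 'h03', 'h04'):
        yield d

# Take at most 4 entries, then dedupe -- only the directories needed to cover
# the first 4 games are materialized.
capped = set(itertools.islice(fake_holdout_entries(), 4))
assert capped == {'h01', 'h02', 'h03'}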
def visualize_dataset(dataset_name,
                      output_path,
                      num_animations=5,
                      num_frames=20,
                      fps=10):
    """Visualizes the data set by saving images to output_path.

  For each latent factor, outputs 16 images where only that latent factor is
  varied while all others are kept constant.

  Args:
    dataset_name: String with name of dataset as defined in named_data.py.
    output_path: String with path in which to create the visualizations.
    num_animations: Integer with number of distinct animations to create.
    num_frames: Integer with number of frames in each animation.
    fps: Integer with frame rate for the animation.
  """
    data = named_data.get_named_ground_truth_data(dataset_name)
    random_state = np.random.RandomState(0)

    # Create output folder if necessary.
    path = os.path.join(output_path, dataset_name)
    if not gfile.IsDirectory(path):
        gfile.MakeDirs(path)

    # Create still images.
    for i in range(data.num_factors):
        factors = data.sample_factors(16, random_state)
        indices = [j for j in range(data.num_factors) if i != j]
        factors[:, indices] = factors[0, indices]
        images = data.sample_observations_from_factors(factors, random_state)
        visualize_util.grid_save_images(
            images, os.path.join(path, "variations_of_factor%s.png" % i))

    # Create animations.
    for i in range(num_animations):
        base_factor = data.sample_factors(1, random_state)
        images = []
        for j, num_atoms in enumerate(data.factors_num_values):
            factors = np.repeat(base_factor, num_frames, axis=0)
            factors[:,
                    j] = visualize_util.cycle_factor(base_factor[0, j],
                                                     num_atoms, num_frames)
            images.append(
                data.sample_observations_from_factors(factors, random_state))
        visualize_util.save_animation(
            np.array(images), os.path.join(path, "animation%d.gif" % i), fps)
    def _list_input_dir(self):
        all_inputs = []
        wildcard = self._portal_manifest.input_file_wildcard
        dirs = [self._portal_manifest.input_base_dir]

        num_dirs = 0
        num_files = 0
        num_target_files = 0
        while len(dirs) > 0:
            fdir = dirs[0]
            dirs = dirs[1:]
            fnames = gfile.ListDirectory(fdir)
            for fname in fnames:
                fpath = path.join(fdir, fname)
                # OSS does not retain folder structure.
                # For example, if we have file oss://test/1001/a.txt
                # list(oss://test) returns 1001/a.txt instead of 1001
                basename = path.basename(fpath)
                if basename == '_SUCCESS':
                    continue
                if gfile.IsDirectory(fpath):
                    dirs.append(fpath)
                    num_dirs += 1
                    continue
                num_files += 1
                if len(wildcard) == 0 or fnmatch(basename, wildcard):
                    num_target_files += 1
                    if self._check_success_tag:
                        has_succ = gfile.Exists(
                            path.join(path.dirname(fpath), '_SUCCESS'))
                        if not has_succ:
                            logging.warning(
                                'File %s skipped because _SUCCESS file is '
                                'missing under %s', fpath, fdir)
                            continue
                    all_inputs.append(fpath)

        rest_fpaths = []
        for fpath in all_inputs:
            if fpath not in self._processed_fpath:
                rest_fpaths.append(fpath)
        logging.info(
            'Listing %s: found %d dirs, %d files, %d files matching wildcard, '
            '%d files with success tag, %d new files to process',
            self._portal_manifest.input_base_dir, num_dirs, num_files,
            num_target_files, len(all_inputs), len(rest_fpaths))
        return rest_fpaths
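A minimal sketch of the _SUCCESS-tag check above pulled out as a standalone helper; the tensorflow.compat.v1 import is an assumption, since the snippets on this page only show a bare gfile:

from os import path
from tensorflow.compat.v1 import gfile  # assumed TF 1.x-style gfile import

def has_success_tag(fpath):
    # True if a _SUCCESS marker sits next to fpath, i.e. the job that produced
    # that directory finished writing it.
    return gfile.Exists(path.join(path.dirname(fpath), '_SUCCESS'))

# Usage mirrors the loop above:
#     if self._check_success_tag and not has_success_tag(fpath):
#         continue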
    def _list_dir_helper(self, root):
        filenames = list(gfile.ListDirectory(root))
        # If _SUCCESS is present, we assume there are no subdirs
        if '_SUCCESS' in filenames:
            return [path.join(root, i) for i in filenames]

        res = []
        for basename in filenames:
            fname = path.join(root, basename)
            if gfile.IsDirectory(fname):
                # ignore tmp dirs starting with '_'
                if basename.startswith('_'):
                    continue
                res += self._list_dir_helper(fname)
            else:
                res.append(fname)
        return res
Example #11
    def test_rename_dir(self):
        """Test rename dir.

    """
        # Setup and check preconditions.
        src_dir_name = "igfs:///test_rename_dir/1"
        dst_dir_name = "igfs:///test_rename_dir/2"
        gfile.MkDir(src_dir_name)
        # Rename directory.
        gfile.Rename(src_dir_name, dst_dir_name)
        # Check that only new name of directory is available.
        self.assertFalse(gfile.Exists(src_dir_name))
        self.assertTrue(gfile.Exists(dst_dir_name))
        self.assertTrue(gfile.IsDirectory(dst_dir_name))
        # Remove directory.
        gfile.Remove(dst_dir_name)
        # Check that directory was removed.
        self.assertFalse(gfile.Exists(dst_dir_name))
Example #12
 def __init__(self,
              data_source_name,
              data_block_fname,
              partition_id,
              dirpath,
              check_existed=True):
     assert data_block_fname.endswith(DataBlockSuffix), \
         "data block fname {} should has suffix {}".format(
             data_block_fname, DataBlockSuffix
         )
     block_id = data_block_fname[:-len(DataBlockSuffix)]
     segmap = decode_block_id(block_id)
     if segmap["data_source_name"] != data_source_name:
         raise ValueError("{} invalid. Data source name mismatch "\
                          "{} != {}".format(data_block_fname,
                              segmap["data_source_name"], data_source_name))
     self._data_source_name = data_source_name
     if segmap["partition_id"] != partition_id:
         raise ValueError("{} invalid. partition mismatch "\
                          "{} != {}".format(data_block_fname,
                              segmap["partition_id"], partition_id))
     self._partition_id = partition_id
     start_time, end_time = \
             segmap["time_frame"][0], segmap["time_frame"][1]
     if start_time > end_time:
         raise ValueError("{} invalid. time frame error start_time {} > "\
                          "end_time {}".format(data_block_fname,
                                               start_time, end_time))
     self._start_time, self._end_time = start_time, end_time
     self._data_block_index = segmap["data_block_index"]
     self._block_id = block_id
     meta_fname = encode_data_block_meta_fname(self._data_source_name,
                                               self._partition_id,
                                               self._data_block_index)
     meta_fpath = os.path.join(dirpath, meta_fname)
     if check_existed and (not gfile.Exists(meta_fpath) or \
                           gfile.IsDirectory(meta_fpath)):
         raise ValueError("{} invalid. the corresponding meta file "\
                          "is not existed".format(data_block_fname))
     self._data_block_meta_fpath = meta_fpath
     self._data_block_meta = None
     self._data_block_fpath = os.path.join(dirpath, data_block_fname)
 def _publish_raw_data(self, job_id):
     portal_manifest = self._sync_portal_manifest()
     output_dir = None
     if portal_manifest.data_portal_type == dp_pb.DataPortalType.PSI:
         output_dir = common.portal_map_output_dir(
             portal_manifest.output_base_dir, job_id)
     else:
         output_dir = common.portal_reduce_output_dir(
             portal_manifest.output_base_dir, job_id)
     for partition_id in range(self._output_partition_num):
         dpath = path.join(output_dir, common.partition_repr(partition_id))
         fnames = []
         if gfile.Exists(dpath) and gfile.IsDirectory(dpath):
             fnames = [
                 f for f in gfile.ListDirectory(dpath)
                 if f.endswith(common.RawDataFileSuffix)
             ]
         if portal_manifest.data_portal_type == dp_pb.DataPortalType.PSI:
             self._publish_psi_raw_data(partition_id, dpath, fnames)
         else:
             self._publish_streaming_raw_data(partition_id, dpath, fnames)
 def __init__(self, potral_manifest, potral_options, date_time):
     assert isinstance(date_time, datetime)
     self._potral_manifest = potral_manifest
     self._potral_options = potral_options
     self._date_time = date_time
     hourly_dir = common.encode_portal_hourly_dir(
         self._potral_manifest.output_data_base_dir, date_time)
     if not gfile.Exists(hourly_dir):
         gfile.MakeDirs(hourly_dir)
     if not gfile.IsDirectory(hourly_dir):
         logging.fatal("%s must be a directory for mapper output",
                       hourly_dir)
         os._exit(-1)  # pylint: disable=protected-access
     self._writers = []
     for partition_id in range(self.output_partition_num):
         fpath = common.encode_portal_hourly_fpath(
             self._potral_manifest.output_data_base_dir, date_time,
             partition_id)
         writer = PotralHourlyOutputMapper.OutputFileWriter(
             partition_id, fpath)
         self._writers.append(writer)
Example #15
 def generate_raw_data(self, begin_index, item_count):
     raw_data_dir = os.path.join(self.raw_data_dir,
                                 common.partition_repr(0))
     if not gfile.Exists(raw_data_dir):
         gfile.MakeDirs(raw_data_dir)
     self.total_raw_data_count += item_count
     useless_index = 0
     rdm = raw_data_visitor.RawDataManager(self.kvstore, self.data_source,
                                           0)
     fpaths = []
     for block_index in range(0, item_count // 2048):
         builder = DataBlockBuilder(
             self.raw_data_dir,
             self.data_source.data_source_meta.name, 0, block_index,
             dj_pb.WriterOptions(output_writer='TF_RECORD'), None)
         cands = list(
             range(begin_index + block_index * 2048,
                   begin_index + (block_index + 1) * 2048))
         start_index = cands[0]
         for i in range(len(cands)):
             if random.randint(1, 4) > 2:
                 continue
             a = random.randint(i - 32, i + 32)
             b = random.randint(i - 32, i + 32)
             if a < 0:
                 a = 0
             if a >= len(cands):
                 a = len(cands) - 1
             if b < 0:
                 b = 0
             if b >= len(cands):
                 b = len(cands) - 1
             if (abs(cands[a] - i - start_index) <= 32
                     and abs(cands[b] - i - start_index) <= 32):
                 cands[a], cands[b] = cands[b], cands[a]
         for example_idx in cands:
             feat = {}
             example_id = '{}'.format(example_idx).encode()
             feat['example_id'] = tf.train.Feature(
                 bytes_list=tf.train.BytesList(value=[example_id]))
             event_time = 150000000 + example_idx
             feat['event_time'] = tf.train.Feature(
                 int64_list=tf.train.Int64List(value=[event_time]))
             label = random.choice([1, 0])
             if random.random() < 0.8:
                 feat['label'] = tf.train.Feature(
                     int64_list=tf.train.Int64List(value=[label]))
             example = tf.train.Example(features=tf.train.Features(
                 feature=feat))
             builder.append_item(TfExampleItem(example.SerializeToString()),
                                 useless_index, useless_index)
             useless_index += 1
         meta = builder.finish_data_block()
         fname = common.encode_data_block_fname(
             self.data_source.data_source_meta.name, meta)
         fpath = os.path.join(raw_data_dir, fname)
         fpaths.append(
             dj_pb.RawDataMeta(
                 file_path=fpath,
                 timestamp=timestamp_pb2.Timestamp(seconds=3)))
         self.g_data_block_index += 1
     all_files = [
         os.path.join(raw_data_dir, f)
         for f in gfile.ListDirectory(raw_data_dir)
         if not gfile.IsDirectory(os.path.join(raw_data_dir, f))
     ]
     for fpath in all_files:
         if not fpath.endswith(common.DataBlockSuffix):
             gfile.Remove(fpath)
     self.manifest_manager.add_raw_data(0, fpaths, False)
 def generate_raw_data(self, etcd, rdp, data_source, partition_id,
                       block_size, shuffle_win_size, feat_key_fmt,
                       feat_val_fmt):
     dbm = data_block_manager.DataBlockManager(data_source, partition_id)
     raw_data_dir = os.path.join(data_source.raw_data_dir,
                                 common.partition_repr(partition_id))
     if gfile.Exists(raw_data_dir):
         gfile.DeleteRecursively(raw_data_dir)
     gfile.MakeDirs(raw_data_dir)
     useless_index = 0
     new_raw_data_fnames = []
     for block_index in range(self.total_index // block_size):
         builder = DataBlockBuilder(
             data_source.raw_data_dir, data_source.data_source_meta.name,
             partition_id, block_index,
             dj_pb.WriterOptions(output_writer='TF_RECORD'), None)
         cands = list(
             range(block_index * block_size,
                   (block_index + 1) * block_size))
         start_index = cands[0]
         for i in range(len(cands)):
             if random.randint(1, 4) > 2:
                 continue
             a = random.randint(i - shuffle_win_size, i + shuffle_win_size)
             b = random.randint(i - shuffle_win_size, i + shuffle_win_size)
             if a < 0:
                 a = 0
             if a >= len(cands):
                 a = len(cands) - 1
             if b < 0:
                 b = 0
             if b >= len(cands):
                 b = len(cands) - 1
             if (abs(cands[a] - i - start_index) <= shuffle_win_size and
                     abs(cands[b] - i - start_index) <= shuffle_win_size):
                 cands[a], cands[b] = cands[b], cands[a]
         for example_idx in cands:
             feat = {}
             example_id = '{}'.format(example_idx).encode()
             feat['example_id'] = tf.train.Feature(
                 bytes_list=tf.train.BytesList(value=[example_id]))
             event_time = 150000000 + example_idx
             feat['event_time'] = tf.train.Feature(
                 int64_list=tf.train.Int64List(value=[event_time]))
             feat[feat_key_fmt.format(example_idx)] = tf.train.Feature(
                 bytes_list=tf.train.BytesList(
                     value=[feat_val_fmt.format(example_idx).encode()]))
             example = tf.train.Example(features=tf.train.Features(
                 feature=feat))
             builder.append_item(TfExampleItem(example.SerializeToString()),
                                 useless_index, useless_index)
             useless_index += 1
         meta = builder.finish_data_block()
         fname = common.encode_data_block_fname(
             data_source.data_source_meta.name, meta)
         new_raw_data_fnames.append(os.path.join(raw_data_dir, fname))
     fpaths = [
         os.path.join(raw_data_dir, f)
         for f in gfile.ListDirectory(raw_data_dir)
         if not gfile.IsDirectory(os.path.join(raw_data_dir, f))
     ]
     for fpath in fpaths:
         if fpath.endswith(common.DataBlockMetaSuffix):
             gfile.Remove(fpath)
     rdp.publish_raw_data(partition_id, new_raw_data_fnames)
Example #17
 def _create_merged_dir_if_need(self):
     if not gfile.Exists(self._merged_dir):
         gfile.MakeDirs(self._merged_dir)
     assert gfile.IsDirectory(self._merged_dir)
 args = parser.parse_args()
 master_channel = make_insecure_channel(args.master_addr,
                                        ChannelType.INTERNAL)
 master_cli = dj_grpc.DataJoinMasterServiceStub(master_channel)
 data_src = master_cli.GetDataSource(empty_pb2.Empty())
 rdc = RawDataController(data_src, master_cli)
 if args.cmd == 'add':
     all_fpaths = []
     if args.files is not None:
         for fp in args.files:
             all_fpaths.append(fp)
     if args.src_dir is not None:
         dir_fpaths = \
                 [path.join(args.src_dir, f)
                  for f in gfile.ListDirectory(args.src_dir)
                  if not gfile.IsDirectory(path.join(args.src_dir, f))]
         dir_fpaths.sort()
         all_fpaths += dir_fpaths
     if not all_fpaths:
         raise RuntimeError("no raw data files supply")
     status = rdc.add_raw_data(args.partition_id, all_fpaths, args.dedup)
     if status.code != 0:
         logging.error("Failed to add raw data for partition %d reason "\
                       "%s", args.partition_id, status.error_message)
     else:
         logging.info("Success add following %d raw data file for "\
                      "partition %d", len(all_fpaths), args.partition_id)
         for idx, fp in enumerate(all_fpaths):
             logging.info("%d. %s", idx, fp)
 else:
     assert args.cmd == 'finish'
 def __init__(self, dir, write_graph=True):
     if not gfile.IsDirectory(dir):
         gfile.MakeDirs(dir)
     self.writer = tf.summary.FileWriter(
         dir, graph=tf.get_default_graph() if write_graph else None)
Example #20
 def _list_data_block(self, partition_id):
     dirpath = self._partition_data_block_dir(partition_id)
     if gfile.Exists(dirpath) and gfile.IsDirectory(dirpath):
         return [f for f in gfile.ListDirectory(dirpath)
                 if f.endswith(DataBlockSuffix)]
     return []
Example #21
def visualize(model_dir,
              output_dir,
              overwrite=False,
              num_animations=5,
              num_frames=20,
              fps=10,
              num_points_irs=10000):
    """Takes trained model from model_dir and visualizes it in output_dir.

  Args:
    model_dir: Path to directory where the trained model is saved.
    output_dir: Path to output directory.
    overwrite: Boolean indicating whether to overwrite output directory.
    num_animations: Integer with number of distinct animations to create.
    num_frames: Integer with number of frames in each animation.
    fps: Integer with frame rate for the animation.
    num_points_irs: Number of points to be used for the IRS plots.
  """
    # Fix the random seed for reproducibility.
    random_state = np.random.RandomState(0)

    # Create the output directory if necessary.
    if tf.gfile.IsDirectory(output_dir):
        if overwrite:
            tf.gfile.DeleteRecursively(output_dir)
        else:
            raise ValueError(
                "Directory already exists and overwrite is False.")

    # Automatically set the proper data set if necessary. We replace the active
    # gin config as this will lead to a valid gin config file where the data set
    # is present.
    # Obtain the dataset name from the gin config of the previous step.
    gin_config_file = os.path.join(model_dir, "results", "gin", "train.gin")
    gin_dict = results.gin_dict(gin_config_file)
    gin.bind_parameter("dataset.name",
                       gin_dict["dataset.name"].replace("'", ""))

    # Automatically infer the activation function from gin config.
    activation_str = gin_dict["reconstruction_loss.activation"]
    if activation_str == "'logits'":
        activation = sigmoid
    elif activation_str == "'tanh'":
        activation = tanh
    else:
        raise ValueError(
            "Activation function  could not be infered from gin config.")

    dataset = named_data.get_named_ground_truth_data()
    num_pics = 64
    module_path = os.path.join(model_dir, "tfhub")

    with hub.eval_function_for_module(module_path) as f:
        # Save reconstructions.
        real_pics = dataset.sample_observations(num_pics, random_state)
        raw_pics = f(dict(images=real_pics),
                     signature="reconstructions",
                     as_dict=True)["images"]
        pics = activation(raw_pics)
        paired_pics = np.concatenate((real_pics, pics), axis=2)
        paired_pics = [
            paired_pics[i, :, :, :] for i in range(paired_pics.shape[0])
        ]
        results_dir = os.path.join(output_dir, "reconstructions")
        if not gfile.IsDirectory(results_dir):
            gfile.MakeDirs(results_dir)
        visualize_util.grid_save_images(
            paired_pics, os.path.join(results_dir, "reconstructions.jpg"))

        # Save samples.
        def _decoder(latent_vectors):
            return f(dict(latent_vectors=latent_vectors),
                     signature="decoder",
                     as_dict=True)["images"]

        num_latent = int(gin_dict["encoder.num_latent"])
        num_pics = 64
        random_codes = random_state.normal(0, 1, [num_pics, num_latent])
        pics = activation(_decoder(random_codes))
        results_dir = os.path.join(output_dir, "sampled")
        if not gfile.IsDirectory(results_dir):
            gfile.MakeDirs(results_dir)
        visualize_util.grid_save_images(
            pics, os.path.join(results_dir, "samples.jpg"))

        # Save latent traversals.
        result = f(
            dict(images=dataset.sample_observations(num_pics, random_state)),
            signature="gaussian_encoder",
            as_dict=True)
        means = result["mean"]
        logvars = result["logvar"]
        results_dir = os.path.join(output_dir, "traversals")
        if not gfile.IsDirectory(results_dir):
            gfile.MakeDirs(results_dir)
        for i in range(means.shape[1]):
            pics = activation(
                latent_traversal_1d_multi_dim(_decoder, means[i, :], None))
            file_name = os.path.join(results_dir, "traversals{}.jpg".format(i))
            visualize_util.grid_save_images([pics], file_name)

        # Save the latent traversal animations.
        results_dir = os.path.join(output_dir, "animated_traversals")
        if not gfile.IsDirectory(results_dir):
            gfile.MakeDirs(results_dir)

        # Cycle through quantiles of a standard Gaussian.
        for i, base_code in enumerate(means[:num_animations]):
            images = []
            for j in range(base_code.shape[0]):
                code = np.repeat(np.expand_dims(base_code, 0),
                                 num_frames,
                                 axis=0)
                code[:, j] = visualize_util.cycle_gaussian(
                    base_code[j], num_frames)
                images.append(np.array(activation(_decoder(code))))
            filename = os.path.join(results_dir,
                                    "std_gaussian_cycle%d.gif" % i)
            visualize_util.save_animation(np.array(images), filename, fps)

        # Cycle through quantiles of a fitted Gaussian.
        for i, base_code in enumerate(means[:num_animations]):
            images = []
            for j in range(base_code.shape[0]):
                code = np.repeat(np.expand_dims(base_code, 0),
                                 num_frames,
                                 axis=0)
                loc = np.mean(means[:, j])
                total_variance = np.mean(np.exp(logvars[:, j])) + np.var(
                    means[:, j])
                code[:, j] = visualize_util.cycle_gaussian(
                    base_code[j],
                    num_frames,
                    loc=loc,
                    scale=np.sqrt(total_variance))
                images.append(np.array(activation(_decoder(code))))
            filename = os.path.join(results_dir,
                                    "fitted_gaussian_cycle%d.gif" % i)
            visualize_util.save_animation(np.array(images), filename, fps)

        # Cycle through [-2, 2] interval.
        for i, base_code in enumerate(means[:num_animations]):
            images = []
            for j in range(base_code.shape[0]):
                code = np.repeat(np.expand_dims(base_code, 0),
                                 num_frames,
                                 axis=0)
                code[:, j] = visualize_util.cycle_interval(
                    base_code[j], num_frames, -2., 2.)
                images.append(np.array(activation(_decoder(code))))
            filename = os.path.join(results_dir,
                                    "fixed_interval_cycle%d.gif" % i)
            visualize_util.save_animation(np.array(images), filename, fps)

        # Cycle linearly through +-2 std dev of a fitted Gaussian.
        for i, base_code in enumerate(means[:num_animations]):
            images = []
            for j in range(base_code.shape[0]):
                code = np.repeat(np.expand_dims(base_code, 0),
                                 num_frames,
                                 axis=0)
                loc = np.mean(means[:, j])
                total_variance = np.mean(np.exp(logvars[:, j])) + np.var(
                    means[:, j])
                scale = np.sqrt(total_variance)
                code[:, j] = visualize_util.cycle_interval(
                    base_code[j], num_frames, loc - 2. * scale,
                    loc + 2. * scale)
                images.append(np.array(activation(_decoder(code))))
            filename = os.path.join(results_dir,
                                    "conf_interval_cycle%d.gif" % i)
            visualize_util.save_animation(np.array(images), filename, fps)

        # Cycle linearly through minmax of a fitted Gaussian.
        for i, base_code in enumerate(means[:num_animations]):
            images = []
            for j in range(base_code.shape[0]):
                code = np.repeat(np.expand_dims(base_code, 0),
                                 num_frames,
                                 axis=0)
                code[:, j] = visualize_util.cycle_interval(
                    base_code[j], num_frames, np.min(means[:, j]),
                    np.max(means[:, j]))
                images.append(np.array(activation(_decoder(code))))
            filename = os.path.join(results_dir,
                                    "minmax_interval_cycle%d.gif" % i)
            visualize_util.save_animation(np.array(images), filename, fps)

        # Interventional effects visualization.
        factors = dataset.sample_factors(num_points_irs, random_state)
        obs = dataset.sample_observations_from_factors(factors, random_state)
        latents = f(dict(images=obs),
                    signature="gaussian_encoder",
                    as_dict=True)["mean"]
        results_dir = os.path.join(output_dir, "interventional_effects")
        vis_all_interventional_effects(factors, latents, results_dir)

    # Finally, we clear the gin config that we have set.
    gin.clear_config()
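For context on the gin calls above, a minimal sketch of the bind-then-clear pattern; the dataset configurable and the 'dsprites_full' value below are stand-ins for illustration, not the real disentanglement_lib bindings:

import gin

@gin.configurable("dataset")
def make_dataset(name=gin.REQUIRED):
    return name

gin.bind_parameter("dataset.name", "dsprites_full")  # programmatic binding, as above
assert make_dataset() == "dsprites_full"
gin.clear_config()  # drop the binding so the next run starts from a clean config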