Example #1
    def get_hit_rate_and_ndcg(self,
                              predicted_scores_by_user,
                              items_by_user,
                              top_k=rconst.TOP_K,
                              match_mlperf=False):
        rconst.TOP_K = top_k
        rconst.NUM_EVAL_NEGATIVES = predicted_scores_by_user.shape[1] - 1

        g = tf.Graph()
        with g.as_default():
            logits = tf.convert_to_tensor(
                predicted_scores_by_user.reshape((-1, 1)), tf.float32)
            softmax_logits = tf.concat(
                [tf.zeros(logits.shape, dtype=logits.dtype), logits], axis=1)
            duplicate_mask = tf.convert_to_tensor(
                stat_utils.mask_duplicates(items_by_user, axis=1), tf.float32)

            metric_ops = neumf_model.compute_eval_loss_and_metrics(
                logits=logits,
                softmax_logits=softmax_logits,
                duplicate_mask=duplicate_mask,
                num_training_neg=NUM_TRAIN_NEG,
                match_mlperf=match_mlperf).eval_metric_ops

            hr = metric_ops[rconst.HR_KEY]
            ndcg = metric_ops[rconst.NDCG_KEY]

            init = [
                tf.global_variables_initializer(),
                tf.local_variables_initializer()
            ]

        with self.test_session(graph=g) as sess:
            sess.run(init)
            return sess.run([hr[1], ndcg[1]])
Example #2
  def get_hit_rate_and_ndcg(self, predicted_scores_by_user, items_by_user,
                            top_k=rconst.TOP_K, match_mlperf=False):
    rconst.TOP_K = top_k
    rconst.NUM_EVAL_NEGATIVES = predicted_scores_by_user.shape[1] - 1

    g = tf.Graph()
    with g.as_default():
      logits = tf.convert_to_tensor(
          predicted_scores_by_user.reshape((-1, 1)), tf.float32)
      softmax_logits = tf.concat([tf.zeros(logits.shape, dtype=logits.dtype),
                                  logits], axis=1)
      duplicate_mask = tf.convert_to_tensor(
          stat_utils.mask_duplicates(items_by_user, axis=1), tf.float32)

      metric_ops = neumf_model.compute_eval_loss_and_metrics(
          logits=logits, softmax_logits=softmax_logits,
          duplicate_mask=duplicate_mask, num_training_neg=NUM_TRAIN_NEG,
          match_mlperf=match_mlperf).eval_metric_ops

      hr = metric_ops[rconst.HR_KEY]
      ndcg = metric_ops[rconst.NDCG_KEY]

      init = [tf.global_variables_initializer(),
              tf.local_variables_initializer()]

    with self.test_session(graph=g) as sess:
      sess.run(init)
      return sess.run([hr[1], ndcg[1]])
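The softmax_logits tensor built in the two examples above prepends a column of zeros to the raw logits; softmax over [0, x] then recovers sigmoid(x) for the second class, so the two columns act as a (negative, positive) logit pair. A minimal numpy check of that identity (illustration only, not part of the original tests):

import numpy as np

logits = np.array([[1.5], [-0.3], [0.0]], dtype=np.float32)
softmax_logits = np.concatenate([np.zeros_like(logits), logits], axis=1)

shifted = softmax_logits - softmax_logits.max(axis=1, keepdims=True)
softmax = np.exp(shifted) / np.exp(shifted).sum(axis=1, keepdims=True)
sigmoid = 1.0 / (1.0 + np.exp(-logits[:, 0]))

assert np.allclose(softmax[:, 1], sigmoid)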
Example #3
def _construct_eval_record(cache_paths, eval_batch_size):
  """Convert Eval data to a single TFRecords file."""

  # Later logic assumes that all items for a given user are in the same batch.
  assert not eval_batch_size % (rconst.NUM_EVAL_NEGATIVES + 1)

  log_msg("Beginning construction of eval TFRecords file.")
  raw_fpath = cache_paths.eval_raw_file
  intermediate_fpath = cache_paths.eval_record_template_temp
  dest_fpath = cache_paths.eval_record_template.format(eval_batch_size)
  with tf.gfile.Open(raw_fpath, "rb") as f:
    eval_data = pickle.load(f)

  users = eval_data[0][movielens.USER_COLUMN]
  items = eval_data[0][movielens.ITEM_COLUMN]
  assert users.shape == items.shape
  # eval_data[1] is the labels, but during evaluation they are inferred, as they
  # have a set structure. They are included in the data artifact for debugging
  # purposes.

  # This packaging assumes that the caller knows to drop the padded values.
  n_pts = users.shape[0]
  n_pad = eval_batch_size - (n_pts % eval_batch_size)
  assert not (n_pts + n_pad) % eval_batch_size

  users = np.concatenate([users, np.zeros(shape=(n_pad,), dtype=np.int32)])\
    .reshape((-1, eval_batch_size))
  items = np.concatenate([items, np.zeros(shape=(n_pad,), dtype=np.uint16)])\
    .reshape((-1, eval_batch_size))

  num_batches = users.shape[0]
  with tf.python_io.TFRecordWriter(intermediate_fpath) as writer:
    for i in range(num_batches):
      batch_users = users[i, :]
      batch_items = items[i, :]
      dupe_mask = stat_utils.mask_duplicates(
          batch_items.reshape(-1, rconst.NUM_EVAL_NEGATIVES + 1),
          axis=1).flatten().astype(np.int8)

      batch_bytes = _construct_record(
          users=batch_users,
          items=batch_items,
          dupe_mask=dupe_mask
      )
      writer.write(batch_bytes)
  tf.gfile.Rename(intermediate_fpath, dest_fpath)
  log_msg("Eval TFRecords file successfully constructed.")
Example #4
    def _assemble_eval_batch(users, positive_items, negative_items,
                             users_per_batch):
        """Construct duplicate_mask and structure data accordingly.

    The positive items should be last so that they lose ties. However, they
    should not be masked out if the true eval positive happens to be
    selected as a negative. So instead, the positive is placed in the first
    position, and then switched with the last element after the duplicate
    mask has been computed.

    Args:
      users: An array of users in a batch. (should be identical along axis 1)
      positive_items: An array (batch_size x 1) of positive item indices.
      negative_items: An array of negative item indices.
      users_per_batch: How many users should be in the batch. This is passed
        as an argument so that ncf_test.py can use this method.

    Returns:
      User, item, and duplicate_mask arrays.
    """
        items = np.concatenate([positive_items, negative_items], axis=1)

        # We pad the users and items here so that the duplicate mask calculation
        # will include padding. The metric function relies on all padded elements
        # except the positive being marked as duplicate to mask out padded points.
        if users.shape[0] < users_per_batch:
            pad_rows = users_per_batch - users.shape[0]
            padding = np.zeros(shape=(pad_rows, users.shape[1]),
                               dtype=np.int32)
            users = np.concatenate([users, padding.astype(users.dtype)],
                                   axis=0)
            items = np.concatenate([items, padding.astype(items.dtype)],
                                   axis=0)

        duplicate_mask = stat_utils.mask_duplicates(items,
                                                    axis=1).astype(np.bool)

        items[:, (0, -1)] = items[:, (-1, 0)]
        duplicate_mask[:, (0, -1)] = duplicate_mask[:, (-1, 0)]

        assert users.shape == items.shape == duplicate_mask.shape
        return users, items, duplicate_mask
Example #5
  def _assemble_eval_batch(users, positive_items, negative_items,
                           users_per_batch):
    """Construct duplicate_mask and structure data accordingly.

    The positive items should be last so that they lose ties. However, they
    should not be masked out if the true eval positive happens to be
    selected as a negative. So instead, the positive is placed in the first
    position, and then switched with the last element after the duplicate
    mask has been computed.

    Args:
      users: An array of users in a batch. (should be identical along axis 1)
      positive_items: An array (batch_size x 1) of positive item indices.
      negative_items: An array of negative item indices.
      users_per_batch: How many users should be in the batch. This is passed
        as an argument so that ncf_test.py can use this method.

    Returns:
      User, item, and duplicate_mask arrays.
    """
    items = np.concatenate([positive_items, negative_items], axis=1)

    # We pad the users and items here so that the duplicate mask calculation
    # will include padding. The metric function relies on all padded elements
    # except the positive being marked as duplicate to mask out padded points.
    if users.shape[0] < users_per_batch:
      pad_rows = users_per_batch - users.shape[0]
      padding = np.zeros(shape=(pad_rows, users.shape[1]), dtype=np.int32)
      users = np.concatenate([users, padding.astype(users.dtype)], axis=0)
      items = np.concatenate([items, padding.astype(items.dtype)], axis=0)

    duplicate_mask = stat_utils.mask_duplicates(items, axis=1).astype(np.bool)

    items[:, (0, -1)] = items[:, (-1, 0)]
    duplicate_mask[:, (0, -1)] = duplicate_mask[:, (-1, 0)]

    assert users.shape == items.shape == duplicate_mask.shape
    return users, items, duplicate_mask
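A small numpy illustration of the first/last column swap performed at the end of the function above. The duplicate mask row is written out by hand; mask_duplicates itself lives in stat_utils and is assumed here to flag the later occurrence of a repeated item, as the docstring describes.

import numpy as np

# One user whose positive item (7) was also drawn as a negative.
items = np.array([[7, 3, 7, 5]])
dupe_mask = np.array([[False, False, True, False]])  # assumed mask_duplicates output

# Swap the first and last columns of both arrays, as in the function above.
items[:, (0, -1)] = items[:, (-1, 0)]
dupe_mask[:, (0, -1)] = dupe_mask[:, (-1, 0)]

# The positive now sits last (so it loses ties) but remains unmasked, while its
# duplicate among the negatives stays masked.
assert items.tolist() == [[5, 3, 7, 7]]
assert dupe_mask.tolist() == [[False, False, True, False]]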
Example #6
def _construct_records(
        is_training,  # type: bool
        train_cycle,  # type: typing.Optional[int]
        num_workers,  # type: int
        cache_paths,  # type: rconst.Paths
        num_readers,  # type: int
        num_neg,  # type: int
        num_positives,  # type: int
        num_items,  # type: int
        epochs_per_cycle,  # type: int
        batch_size,  # type: int
        training_shards,  # type: typing.List[str]
        deterministic=False,  # type: bool
        match_mlperf=False  # type: bool
):
    """Generate false negatives and write TFRecords files.

  Args:
    is_training: Are training records (True) or eval records (False) created.
    train_cycle: Integer of which cycle the generated data is for.
    num_workers: Number of multiprocessing workers to use for negative
      generation.
    cache_paths: Paths object with information of where to write files.
    num_readers: The number of reader datasets in the input_fn. This number is
      approximate; fewer shards will be created if not all shards are assigned
      batches. This can occur due to discretization in the assignment process.
    num_neg: The number of false negatives per positive example.
    num_positives: The number of positive examples. This value is used
      to pre-allocate arrays while the imap is still running. (NumPy does not
      allow dynamic arrays.)
    num_items: The cardinality of the item set.
    epochs_per_cycle: The number of epochs worth of data to construct.
    batch_size: The expected batch size used during training. This is used
      to properly batch data when writing TFRecords.
    training_shards: The pickled positive examples from which to generate
      negatives.
  """
    st = timeit.default_timer()

    if is_training:
        mlperf_helper.ncf_print(
            key=mlperf_helper.TAGS.INPUT_STEP_TRAIN_NEG_GEN)
        mlperf_helper.ncf_print(key=mlperf_helper.TAGS.INPUT_HP_NUM_NEG,
                                value=num_neg)

        # set inside _process_shard()
        mlperf_helper.ncf_print(
            key=mlperf_helper.TAGS.INPUT_HP_SAMPLE_TRAIN_REPLACEMENT,
            value=True)

    else:
        # Later logic assumes that all items for a given user are in the same batch.
        assert not batch_size % (rconst.NUM_EVAL_NEGATIVES + 1)
        assert num_neg == rconst.NUM_EVAL_NEGATIVES

        mlperf_helper.ncf_print(key=mlperf_helper.TAGS.INPUT_STEP_EVAL_NEG_GEN)

        mlperf_helper.ncf_print(key=mlperf_helper.TAGS.EVAL_HP_NUM_USERS,
                                value=num_positives)

    assert epochs_per_cycle == 1 or is_training
    num_workers = min([num_workers, len(training_shards) * epochs_per_cycle])

    num_pts = num_positives * (1 + num_neg)

    # Equivalent to `int(ceil(num_pts / batch_size)) * batch_size`, but without
    # precision concerns
    num_pts_with_padding = (num_pts + batch_size -
                            1) // batch_size * batch_size
    num_padding = num_pts_with_padding - num_pts
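    # Worked example with hypothetical values: num_pts = 23 and batch_size = 5
    # give num_pts_with_padding = 25 and num_padding = 2.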

    # We choose a different random seed for each process, so that the processes
    # will not all choose the same random numbers.
    process_seeds = [
        stat_utils.random_int32() for _ in training_shards * epochs_per_cycle
    ]
    map_args = [(shard, num_items, num_neg, process_seeds[i], is_training,
                 match_mlperf)
                for i, shard in enumerate(training_shards * epochs_per_cycle)]

    with popen_helper.get_pool(num_workers, init_worker) as pool:
        map_fn = pool.imap if deterministic else pool.imap_unordered  # pylint: disable=no-member
        data_generator = map_fn(_process_shard, map_args)
        data = [
            np.zeros(shape=(num_pts_with_padding, ), dtype=np.int32) - 1,
            np.zeros(shape=(num_pts_with_padding, ), dtype=np.uint16),
            np.zeros(shape=(num_pts_with_padding, ), dtype=np.int8),
        ]

        # Training data is shuffled. Evaluation data MUST not be shuffled.
        # Downstream processing depends on the fact that evaluation data for a given
        # user is grouped within a batch.
        if is_training:
            index_destinations = np.random.permutation(num_pts)
            mlperf_helper.ncf_print(key=mlperf_helper.TAGS.INPUT_ORDER)
        else:
            index_destinations = np.arange(num_pts)

        start_ind = 0
        for data_segment in data_generator:
            n_in_segment = data_segment[0].shape[0]
            dest = index_destinations[start_ind:start_ind + n_in_segment]
            start_ind += n_in_segment
            for i in range(3):
                data[i][dest] = data_segment[i]

    assert np.sum(data[0] == -1) == num_padding

    if is_training:
        if num_padding:
            # In order to have a full batch, randomly include points from earlier in
            # the batch.

            mlperf_helper.ncf_print(key=mlperf_helper.TAGS.INPUT_ORDER)
            pad_sample_indices = np.random.randint(low=0,
                                                   high=num_pts,
                                                   size=(num_padding, ))
            dest = np.arange(start=start_ind, stop=start_ind + num_padding)
            start_ind += num_padding
            for i in range(3):
                data[i][dest] = data[i][pad_sample_indices]
    else:
        # For Evaluation, padding is all zeros. The evaluation input_fn knows how
        # to interpret and discard the zero padded entries.
        data[0][num_pts:] = 0

    # Check that no points were overlooked.
    assert not np.sum(data[0] == -1)

    if is_training:
        # The number of points is slightly larger than num_pts due to padding.
        mlperf_helper.ncf_print(key=mlperf_helper.TAGS.INPUT_SIZE,
                                value=int(data[0].shape[0]))
        mlperf_helper.ncf_print(key=mlperf_helper.TAGS.INPUT_BATCH_SIZE,
                                value=batch_size)
    else:
        # num_pts is logged instead of int(data[0].shape[0]), because the size
        # of the data vector includes zero pads which are ignored.
        mlperf_helper.ncf_print(key=mlperf_helper.TAGS.EVAL_SIZE,
                                value=num_pts)

    batches_per_file = np.ceil(num_pts_with_padding / batch_size / num_readers)
    current_file_id = -1
    current_batch_id = -1
    batches_by_file = [[] for _ in range(num_readers)]

    while True:
        current_batch_id += 1
        if (current_batch_id % batches_per_file) == 0:
            current_file_id += 1

        start_ind = current_batch_id * batch_size
        end_ind = start_ind + batch_size
        if end_ind > num_pts_with_padding:
            if start_ind != num_pts_with_padding:
                raise ValueError("Batch padding does not line up")
            break
        batches_by_file[current_file_id].append(current_batch_id)

    # Drop shards which were not assigned batches
    batches_by_file = [i for i in batches_by_file if i]
    num_readers = len(batches_by_file)

    if is_training:
        # Empirically it is observed that placing the batch with repeated values at
        # the start rather than the end improves convergence.
        mlperf_helper.ncf_print(key=mlperf_helper.TAGS.INPUT_ORDER)
        batches_by_file[0][0], batches_by_file[-1][-1] = \
          batches_by_file[-1][-1], batches_by_file[0][0]

    if is_training:
        template = rconst.TRAIN_RECORD_TEMPLATE
        record_dir = os.path.join(cache_paths.train_epoch_dir,
                                  get_cycle_folder_name(train_cycle))
        tf.gfile.MakeDirs(record_dir)
    else:
        template = rconst.EVAL_RECORD_TEMPLATE
        record_dir = cache_paths.eval_data_subdir

    batch_count = 0
    for i in range(num_readers):
        fpath = os.path.join(record_dir, template.format(i))
        log_msg("Writing {}".format(fpath))
        with tf.python_io.TFRecordWriter(fpath) as writer:
            for j in batches_by_file[i]:
                start_ind = j * batch_size
                end_ind = start_ind + batch_size
                record_kwargs = dict(
                    users=data[0][start_ind:end_ind],
                    items=data[1][start_ind:end_ind],
                )

                if is_training:
                    record_kwargs["labels"] = data[2][start_ind:end_ind]
                else:
                    record_kwargs["dupe_mask"] = stat_utils.mask_duplicates(
                        record_kwargs["items"].reshape(-1, num_neg + 1),
                        axis=1).flatten().astype(np.int8)

                batch_bytes = _construct_record(**record_kwargs)

                writer.write(batch_bytes)
                batch_count += 1

    # We write to a temp file then atomically rename it to the final file, because
    # writing directly to the final file can cause the main process to read a
    # partially written JSON file.
    ready_file_temp = os.path.join(record_dir, rconst.READY_FILE_TEMP)
    with tf.gfile.Open(ready_file_temp, "w") as f:
        json.dump({
            "batch_size": batch_size,
            "batch_count": batch_count,
        }, f)
    ready_file = os.path.join(record_dir, rconst.READY_FILE)
    tf.gfile.Rename(ready_file_temp, ready_file)

    if is_training:
        log_msg("Cycle {} complete. Total time: {:.1f} seconds".format(
            train_cycle,
            timeit.default_timer() - st))
    else:
        log_msg(
            "Eval construction complete. Total time: {:.1f} seconds".format(
                timeit.default_timer() - st))
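A self-contained sketch of the batch-to-file assignment arithmetic used in the function above, with toy numbers (none of these values come from the original code, and the while loop is simplified to a for loop over whole batches):

import numpy as np

num_pts = 23          # hypothetical number of points
batch_size = 5
num_readers = 3

# Integer ceiling, as in _construct_records above.
num_pts_with_padding = (num_pts + batch_size - 1) // batch_size * batch_size
num_padding = num_pts_with_padding - num_pts
assert (num_pts_with_padding, num_padding) == (25, 2)

# Assign whole batches to reader files in contiguous runs.
batches_per_file = np.ceil(num_pts_with_padding / batch_size / num_readers)
batches_by_file = [[] for _ in range(num_readers)]
current_file_id = -1
for current_batch_id in range(num_pts_with_padding // batch_size):
    if (current_batch_id % batches_per_file) == 0:
        current_file_id += 1
    batches_by_file[current_file_id].append(current_batch_id)

# Files receive contiguous runs of batch ids; trailing readers may end up empty
# and are dropped by the original code.
assert batches_by_file == [[0, 1], [2, 3], [4]]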
Example #7
def _construct_records(
    is_training,          # type: bool
    train_cycle,          # type: typing.Optional[int]
    num_workers,          # type: int
    cache_paths,          # type: rconst.Paths
    num_readers,          # type: int
    num_neg,              # type: int
    num_positives,        # type: int
    num_items,            # type: int
    epochs_per_cycle,     # type: int
    batch_size,           # type: int
    training_shards,      # type: typing.List[str]
    deterministic=False,  # type: bool
    match_mlperf=False    # type: bool
    ):
  """Generate false negatives and write TFRecords files.

  Args:
    is_training: Are training records (True) or eval records (False) created.
    train_cycle: Integer of which cycle the generated data is for.
    num_workers: Number of multiprocessing workers to use for negative
      generation.
    cache_paths: Paths object with information of where to write files.
    num_readers: The number of reader datasets in the input_fn. This number is
      approximate; fewer shards will be created if not all shards are assigned
      batches. This can occur due to discretization in the assignment process.
    num_neg: The number of false negatives per positive example.
    num_positives: The number of positive examples. This value is used
      to pre-allocate arrays while the imap is still running. (NumPy does not
      allow dynamic arrays.)
    num_items: The cardinality of the item set.
    epochs_per_cycle: The number of epochs worth of data to construct.
    batch_size: The expected batch size used during training. This is used
      to properly batch data when writing TFRecords.
    training_shards: The pickled positive examples from which to generate
      negatives.
  """
  st = timeit.default_timer()

  if is_training:
    mlperf_helper.ncf_print(key=mlperf_helper.TAGS.INPUT_STEP_TRAIN_NEG_GEN)
    mlperf_helper.ncf_print(
        key=mlperf_helper.TAGS.INPUT_HP_NUM_NEG, value=num_neg)

    # set inside _process_shard()
    mlperf_helper.ncf_print(
        key=mlperf_helper.TAGS.INPUT_HP_SAMPLE_TRAIN_REPLACEMENT, value=True)

  else:
    # Later logic assumes that all items for a given user are in the same batch.
    assert not batch_size % (rconst.NUM_EVAL_NEGATIVES + 1)
    assert num_neg == rconst.NUM_EVAL_NEGATIVES

    mlperf_helper.ncf_print(key=mlperf_helper.TAGS.INPUT_STEP_EVAL_NEG_GEN)

    mlperf_helper.ncf_print(key=mlperf_helper.TAGS.EVAL_HP_NUM_USERS,
                            value=num_positives)

  assert epochs_per_cycle == 1 or is_training
  num_workers = min([num_workers, len(training_shards) * epochs_per_cycle])

  num_pts = num_positives * (1 + num_neg)

  # Equivalent to `int(ceil(num_pts / batch_size)) * batch_size`, but without
  # precision concerns
  num_pts_with_padding = (num_pts + batch_size - 1) // batch_size * batch_size
  num_padding = num_pts_with_padding - num_pts

  # We choose a different random seed for each process, so that the processes
  # will not all choose the same random numbers.
  process_seeds = [stat_utils.random_int32()
                   for _ in training_shards * epochs_per_cycle]
  map_args = [
      (shard, num_items, num_neg, process_seeds[i], is_training, match_mlperf)
      for i, shard in enumerate(training_shards * epochs_per_cycle)]

  with popen_helper.get_pool(num_workers, init_worker) as pool:
    map_fn = pool.imap if deterministic else pool.imap_unordered  # pylint: disable=no-member
    data_generator = map_fn(_process_shard, map_args)
    data = [
        np.zeros(shape=(num_pts_with_padding,), dtype=np.int32) - 1,
        np.zeros(shape=(num_pts_with_padding,), dtype=np.uint16),
        np.zeros(shape=(num_pts_with_padding,), dtype=np.int8),
    ]

    # Training data is shuffled. Evaluation data MUST not be shuffled.
    # Downstream processing depends on the fact that evaluation data for a given
    # user is grouped within a batch.
    if is_training:
      index_destinations = np.random.permutation(num_pts)
      mlperf_helper.ncf_print(key=mlperf_helper.TAGS.INPUT_ORDER)
    else:
      index_destinations = np.arange(num_pts)

    start_ind = 0
    for data_segment in data_generator:
      n_in_segment = data_segment[0].shape[0]
      dest = index_destinations[start_ind:start_ind + n_in_segment]
      start_ind += n_in_segment
      for i in range(3):
        data[i][dest] = data_segment[i]

  assert np.sum(data[0] == -1) == num_padding

  if is_training:
    if num_padding:
      # In order to have a full batch, randomly include points from earlier in
      # the batch.

      mlperf_helper.ncf_print(key=mlperf_helper.TAGS.INPUT_ORDER)
      pad_sample_indices = np.random.randint(
          low=0, high=num_pts, size=(num_padding,))
      dest = np.arange(start=start_ind, stop=start_ind + num_padding)
      start_ind += num_padding
      for i in range(3):
        data[i][dest] = data[i][pad_sample_indices]
  else:
    # For Evaluation, padding is all zeros. The evaluation input_fn knows how
    # to interpret and discard the zero padded entries.
    data[0][num_pts:] = 0

  # Check that no points were overlooked.
  assert not np.sum(data[0] == -1)

  if is_training:
    # The number of points is slightly larger than num_pts due to padding.
    mlperf_helper.ncf_print(key=mlperf_helper.TAGS.INPUT_SIZE,
                            value=int(data[0].shape[0]))
    mlperf_helper.ncf_print(key=mlperf_helper.TAGS.INPUT_BATCH_SIZE,
                            value=batch_size)
  else:
    # num_pts is logged instead of int(data[0].shape[0]), because the size
    # of the data vector includes zero pads which are ignored.
    mlperf_helper.ncf_print(key=mlperf_helper.TAGS.EVAL_SIZE, value=num_pts)

  batches_per_file = np.ceil(num_pts_with_padding / batch_size / num_readers)
  current_file_id = -1
  current_batch_id = -1
  batches_by_file = [[] for _ in range(num_readers)]

  while True:
    current_batch_id += 1
    if (current_batch_id % batches_per_file) == 0:
      current_file_id += 1

    start_ind = current_batch_id * batch_size
    end_ind = start_ind + batch_size
    if end_ind > num_pts_with_padding:
      if start_ind != num_pts_with_padding:
        raise ValueError("Batch padding does not line up")
      break
    batches_by_file[current_file_id].append(current_batch_id)

  # Drop shards which were not assigned batches
  batches_by_file = [i for i in batches_by_file if i]
  num_readers = len(batches_by_file)

  if is_training:
    # Empirically it is observed that placing the batch with repeated values at
    # the start rather than the end improves convergence.
    mlperf_helper.ncf_print(key=mlperf_helper.TAGS.INPUT_ORDER)
    batches_by_file[0][0], batches_by_file[-1][-1] = \
      batches_by_file[-1][-1], batches_by_file[0][0]

  if is_training:
    template = rconst.TRAIN_RECORD_TEMPLATE
    record_dir = os.path.join(cache_paths.train_epoch_dir,
                              get_cycle_folder_name(train_cycle))
    tf.gfile.MakeDirs(record_dir)
  else:
    template = rconst.EVAL_RECORD_TEMPLATE
    record_dir = cache_paths.eval_data_subdir

  batch_count = 0
  for i in range(num_readers):
    fpath = os.path.join(record_dir, template.format(i))
    log_msg("Writing {}".format(fpath))
    with tf.python_io.TFRecordWriter(fpath) as writer:
      for j in batches_by_file[i]:
        start_ind = j * batch_size
        end_ind = start_ind + batch_size
        record_kwargs = dict(
            users=data[0][start_ind:end_ind],
            items=data[1][start_ind:end_ind],
        )

        if is_training:
          record_kwargs["labels"] = data[2][start_ind:end_ind]
        else:
          record_kwargs["dupe_mask"] = stat_utils.mask_duplicates(
              record_kwargs["items"].reshape(-1, num_neg + 1),
              axis=1).flatten().astype(np.int8)

        batch_bytes = _construct_record(**record_kwargs)

        writer.write(batch_bytes)
        batch_count += 1

  # We write to a temp file then atomically rename it to the final file, because
  # writing directly to the final file can cause the main process to read a
  # partially written JSON file.
  ready_file_temp = os.path.join(record_dir, rconst.READY_FILE_TEMP)
  with tf.gfile.Open(ready_file_temp, "w") as f:
    json.dump({
        "batch_size": batch_size,
        "batch_count": batch_count,
    }, f)
  ready_file = os.path.join(record_dir, rconst.READY_FILE)
  tf.gfile.Rename(ready_file_temp, ready_file)

  if is_training:
    log_msg("Cycle {} complete. Total time: {:.1f} seconds"
            .format(train_cycle, timeit.default_timer() - st))
  else:
    log_msg("Eval construction complete. Total time: {:.1f} seconds"
            .format(timeit.default_timer() - st))
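The write-to-temp-then-rename step at the end of both copies above is what lets the ready file appear atomically to the reader process. A standard-library sketch of the same pattern (the file name and function name here are hypothetical, not the rconst values):

import json
import os
import tempfile

def write_ready_file(record_dir, batch_size, batch_count):
    # Write to a temp file in the same directory, then atomically replace the
    # destination so a reader never observes a partially written JSON file.
    fd, tmp_path = tempfile.mkstemp(dir=record_dir, suffix=".json.tmp")
    with os.fdopen(fd, "w") as f:
        json.dump({"batch_size": batch_size, "batch_count": batch_count}, f)
    os.replace(tmp_path, os.path.join(record_dir, "ready.json"))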