Example #1
  def test_load_and_remap_invalid_remapping(self):
    """Tests that errors are raised when an ID maps to multiple new IDs.

    (This should usually not happen when using public APIs).
    """
    invalid_remapping = [1, 0, 0, 0, 1, 2]

    # Invalid row remapping.
    remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix(
        ckpt_path=[self.bundle_file],
        old_tensor_name=self.old_tensor_name,
        row_remapping=invalid_remapping,
        col_remapping=[],
        initializing_values=[],
        num_rows=len(invalid_remapping),
        num_cols=self.old_num_cols)
    with self.cached_session(), self.assertRaises(errors.UnimplementedError):
      self.evaluate(remapped_matrix)

    # Invalid column remapping.
    remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix(
        ckpt_path=[self.bundle_file],
        old_tensor_name=self.old_tensor_name,
        row_remapping=list(range(self.old_num_rows)),
        col_remapping=invalid_remapping,
        initializing_values=[],
        num_rows=self.old_num_rows,
        num_cols=len(invalid_remapping))
    with self.cached_session(), self.assertRaises(errors.UnimplementedError):
      self.evaluate(remapped_matrix)
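The op rejects both remappings because a valid remapping must be injective on its non-negative entries: each old row or column ID may be loaded into at most one new position (-1 marks a slot filled from initializing_values instead). A minimal standalone sketch of that invariant; the helper name is illustrative, not part of the op's API:

def _remapping_is_injective(remapping):
  # Every old ID (>= 0) may appear at most once; -1 entries are ignored.
  loaded_ids = [old_id for old_id in remapping if old_id >= 0]
  return len(loaded_ids) == len(set(loaded_ids))

assert not _remapping_is_injective([1, 0, 0, 0, 1, 2])  # 0 and 1 repeat.
assert _remapping_is_injective([2, -1, 0])  # -1 means "initialize this slot".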
Example #2
  def test_load_and_remap_incorrect_initializing_values(self):
    """Tests that errors are raised with incorrect number of init values."""
    remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix(
        ckpt_path=[self.bundle_file],
        old_tensor_name=self.old_tensor_name,
        row_remapping=[2, -1, 0],
        col_remapping=[1, -1],
        # Too few initializing values: there should be 4. Curiously,
        # initializing_values must be empty (rather than containing 3 or
        # fewer elements) for a seg fault to reliably occur if the check
        # raising the InvalidArgumentError were not present.
        initializing_values=[],
        num_rows=3,
        num_cols=2)
    with self.cached_session(), self.assertRaises(errors.InvalidArgumentError):
      self.evaluate(remapped_matrix)

    remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix(
        ckpt_path=[self.bundle_file],
        old_tensor_name=self.old_tensor_name,
        row_remapping=[2, -1, 0],
        col_remapping=[1, -1],
        # Too many initializing values - there should be 4.
        initializing_values=[0] * 5,
        num_rows=3,
        num_cols=2)
    with self.cached_session(), self.assertRaises(errors.InvalidArgumentError):
      self.evaluate(remapped_matrix)
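The required count of initializing values follows from the output shape arithmetic: the op fills every cell whose row or column is marked -1, so the count is num_rows * num_cols minus the number of cells loaded from the checkpoint. A small sketch of that calculation for the test above (plain Python, no TensorFlow needed):

row_remapping, col_remapping = [2, -1, 0], [1, -1]
num_rows, num_cols = 3, 2
rows_loaded = sum(1 for r in row_remapping if r >= 0)  # 2 rows from ckpt.
cols_loaded = sum(1 for c in col_remapping if c >= 0)  # 1 col from ckpt.
num_init = num_rows * num_cols - rows_loaded * cols_loaded
assert num_init == 4  # Hence [] (too few) and [0] * 5 (too many) both fail.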
Example #3
  def test_load_and_remap_incorrect_initializing_values(self):
    """Tests that errors are raised with incorrect number of init values."""
    remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix(
        ckpt_path=[self.bundle_file],
        old_tensor_name=self.old_tensor_name,
        row_remapping=[2, -1, 0],
        col_remapping=[1, -1],
        # Too few initializing values: there should be 4. Curiously,
        # initializing_values must be empty (rather than containing 3 or
        # fewer elements) for a seg fault to reliably occur if the check
        # raising the InvalidArgumentError were not present.
        initializing_values=[],
        num_rows=3,
        num_cols=2)
    with self.cached_session(), self.assertRaises(errors.InvalidArgumentError):
      remapped_matrix.eval()

    remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix(
        ckpt_path=[self.bundle_file],
        old_tensor_name=self.old_tensor_name,
        row_remapping=[2, -1, 0],
        col_remapping=[1, -1],
        # Too many initializing values - there should be 4.
        initializing_values=[0] * 5,
        num_rows=3,
        num_cols=2)
    with self.cached_session(), self.assertRaises(errors.InvalidArgumentError):
      remapped_matrix.eval()
Example #4
  def test_load_and_remap_invalid_remapping(self):
    """Tests that errors are raised when an ID maps to multiple new IDs.

    (This should usually not happen when using public APIs).
    """
    invalid_remapping = [1, 0, 0, 0, 1, 2]

    # Invalid row remapping.
    remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix(
        ckpt_path=[self.bundle_file],
        old_tensor_name=self.old_tensor_name,
        row_remapping=invalid_remapping,
        col_remapping=[],
        initializing_values=[],
        num_rows=len(invalid_remapping),
        num_cols=self.old_num_cols)
    with self.cached_session(), self.assertRaises(errors.UnimplementedError):
      remapped_matrix.eval()

    # Invalid column remapping.
    remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix(
        ckpt_path=[self.bundle_file],
        old_tensor_name=self.old_tensor_name,
        row_remapping=list(range(self.old_num_rows)),
        col_remapping=invalid_remapping,
        initializing_values=[],
        num_rows=self.old_num_rows,
        num_cols=len(invalid_remapping))
    with self.cached_session(), self.assertRaises(errors.UnimplementedError):
      remapped_matrix.eval()
Example #5
    def test_load_and_remap_no_missing(self):
        """Tests the op's load and remap where there are no missing entries."""

        # No column remapping, new weight matrix has second row, then first row.
        row_remapping = [1, 0]
        remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix(
            ckpt_path=[self.bundle_file],
            old_tensor_name=self.old_tensor_name,
            row_remapping=row_remapping,
            col_remapping=[],
            initializing_values=[],
            num_rows=2,
            num_cols=self.old_num_cols)
        with self.cached_session():
            self.assertAllClose(self.matrix_value[row_remapping],
                                self.evaluate(remapped_matrix))

        # No row remapping, new weight matrix has third col, then first col.
        row_remapping = list(range(self.old_num_rows))
        col_remapping = [2, 0]
        remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix(
            ckpt_path=[self.bundle_file],
            old_tensor_name=self.old_tensor_name,
            row_remapping=row_remapping,
            col_remapping=col_remapping,
            initializing_values=[],
            num_rows=len(row_remapping),
            num_cols=len(col_remapping))
        with self.cached_session():
            self.assertAllClose(
                self.matrix_value[row_remapping][:, col_remapping],
                self.evaluate(remapped_matrix))

        # Both row and column remappings.
        row_remapping = [1, 0, 4]
        col_remapping = [1, 15]
        remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix(
            ckpt_path=[self.bundle_file],
            old_tensor_name=self.old_tensor_name,
            row_remapping=row_remapping,
            col_remapping=col_remapping,
            initializing_values=[],
            num_rows=len(row_remapping),
            num_cols=len(col_remapping))
        with self.cached_session():
            self.assertAllClose(
                self.matrix_value[row_remapping][:, col_remapping],
                self.evaluate(remapped_matrix))
Example #6
  def test_load_and_remap_no_missing(self):
    """Tests the op's load and remap where there are no missing entries."""

    # No column remapping, new weight matrix has second row, then first row.
    row_remapping = [1, 0]
    remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix(
        ckpt_path=[self.bundle_file],
        old_tensor_name=self.old_tensor_name,
        row_remapping=row_remapping,
        col_remapping=[],
        initializing_values=[],
        num_rows=2,
        num_cols=self.old_num_cols)
    with self.cached_session():
      self.assertAllClose(self.matrix_value[row_remapping],
                          remapped_matrix.eval())

    # No row remapping, new weight matrix has third col, then first col.
    row_remapping = list(range(self.old_num_rows))
    col_remapping = [2, 0]
    remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix(
        ckpt_path=[self.bundle_file],
        old_tensor_name=self.old_tensor_name,
        row_remapping=row_remapping,
        col_remapping=col_remapping,
        initializing_values=[],
        num_rows=len(row_remapping),
        num_cols=len(col_remapping))
    with self.cached_session():
      self.assertAllClose(self.matrix_value[row_remapping][:, col_remapping],
                          remapped_matrix.eval())

    # Both row and column remappings.
    row_remapping = [1, 0, 4]
    col_remapping = [1, 15]
    remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix(
        ckpt_path=[self.bundle_file],
        old_tensor_name=self.old_tensor_name,
        row_remapping=row_remapping,
        col_remapping=col_remapping,
        initializing_values=[],
        num_rows=len(row_remapping),
        num_cols=len(col_remapping))
    with self.cached_session():
      self.assertAllClose(self.matrix_value[row_remapping][:, col_remapping],
                          remapped_matrix.eval())
Example #7
  def test_load_and_remap_all_missing_rows(self):
    """Tests when all the rows are missing and need to be initialized."""
    num_rows = 7
    initializing_values = [42] * num_rows * self.old_num_cols
    remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix(
        ckpt_path=[self.bundle_file],
        old_tensor_name=self.old_tensor_name,
        row_remapping=[-1] * num_rows,
        col_remapping=[],
        initializing_values=initializing_values,
        num_rows=num_rows,
        num_cols=self.old_num_cols)
    with self.cached_session():
      self.assertAllClose(
          np.reshape(initializing_values, (num_rows, self.old_num_cols)),
          self.evaluate(remapped_matrix))
Example #8
  def test_load_and_remap_all_missing_rows(self):
    """Tests when all the rows are missing and need to be initialized."""
    num_rows = 7
    initializing_values = [42] * num_rows * self.old_num_cols
    remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix(
        ckpt_path=[self.bundle_file],
        old_tensor_name=self.old_tensor_name,
        row_remapping=[-1] * num_rows,
        col_remapping=[],
        initializing_values=initializing_values,
        num_rows=num_rows,
        num_cols=self.old_num_cols)
    with self.cached_session():
      self.assertAllClose(
          np.reshape(initializing_values, (num_rows, self.old_num_cols)),
          remapped_matrix.eval())
Example #9
  def test_load_and_remap_with_init(self):
    """Tests the op's load and remap where there are missing entries."""
    init_val = 42
    remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix(
        ckpt_path=[self.bundle_file],
        old_tensor_name=self.old_tensor_name,
        row_remapping=[2, -1, 0],
        col_remapping=[1, -1],
        initializing_values=[init_val] * 4,
        num_rows=3,
        num_cols=2)

    expected_remapped_matrix = np.reshape(
        [33, init_val, init_val, init_val, 1, init_val], [3, 2])

    with self.cached_session():
      self.assertAllClose(expected_remapped_matrix, remapped_matrix.eval())
Example #10
  def test_load_and_remap_with_init(self):
    """Tests the op's load and remap where there are missing entries."""
    init_val = 42
    remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix(
        ckpt_path=[self.bundle_file],
        old_tensor_name=self.old_tensor_name,
        row_remapping=[2, -1, 0],
        col_remapping=[1, -1],
        initializing_values=[init_val] * 4,
        num_rows=3,
        num_cols=2)

    expected_remapped_matrix = np.reshape(
        [33, init_val, init_val, init_val, 1, init_val], [3, 2])

    with self.cached_session():
      self.assertAllClose(expected_remapped_matrix,
                          self.evaluate(remapped_matrix))
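The expected matrix can be reconstructed by hand: cells whose row and column remappings are both non-negative come from the checkpoint, and the remaining four cells are filled with init_val in row-major order. A NumPy sketch, assuming (consistently with the 33 and 1 above) a fixture where the old 5x16 matrix stores i * 16 + j at position (i, j):

import numpy as np

old = np.reshape(np.arange(0, 5 * 16), (5, 16))
row_remapping, col_remapping, init_val = [2, -1, 0], [1, -1], 42
expected = np.full((3, 2), init_val)
for i, old_i in enumerate(row_remapping):
  for j, old_j in enumerate(col_remapping):
    if old_i >= 0 and old_j >= 0:
      expected[i, j] = old[old_i, old_j]
# expected == [[33, 42], [42, 42], [1, 42]], matching the reshape above.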
Example #11
  def _test_loading_variable_with_max_rows(self, np_value, partitioner,
                                           max_rows_in_memory):
    """Helper function for various tests using max_rows_in_memory."""
    ops.reset_default_graph()
    old_tensor_name = 'matrix_to_load_and_remap'
    matrix = variable_scope.get_variable(
        old_tensor_name,
        dtype=dtypes.float32,
        initializer=constant_op.constant(np_value, dtype=dtypes.float32),
        partitioner=partitioner)

    with self.cached_session() as sess:
      ckpt_path = os.path.join(test.get_temp_dir(), 'temp_ckpt')
      save = saver.Saver([matrix])
      self.evaluate(variables.global_variables_initializer())
      save.save(sess, ckpt_path)
      num_rows, num_cols = np_value.shape

      # Tests loading the entire tensor (with the rows reversed).
      remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix(
          ckpt_path=ckpt_path,
          old_tensor_name=old_tensor_name,
          # Simply reverses the rows of the matrix.
          row_remapping=list(range(num_rows - 1, -1, -1)),
          col_remapping=[],
          initializing_values=[],
          num_rows=num_rows,
          num_cols=num_cols,
          max_rows_in_memory=max_rows_in_memory)
      self.assertAllClose(np_value[::-1], self.evaluate(remapped_matrix))

      # Tests loading the tensor (except for the first and last rows), with
      # uninitialized values. Requires num_rows to be at least 3 since we're
      # skipping the first and last rows.
      self.assertGreater(num_rows, 2)
      prefix_rows = 2
      suffix_rows = 3
      remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix(
          ckpt_path=ckpt_path,
          old_tensor_name=old_tensor_name,
          # Reverses the rows of the matrix, then prepends and appends
          # uninitialized rows.
          row_remapping=([-1] * prefix_rows + list(range(1, num_rows - 1)) +
                         [-1] * suffix_rows),
          col_remapping=[],
          initializing_values=[42] * (prefix_rows + suffix_rows) * num_cols,
          num_rows=num_rows - 2 + prefix_rows + suffix_rows,
          num_cols=num_cols,
          max_rows_in_memory=max_rows_in_memory)
      self.assertAllClose(
          np.vstack([
              np.tile(42, [prefix_rows, num_cols]), np_value[1:-1],
              np.tile(42, [suffix_rows, num_cols])
          ]), self.evaluate(remapped_matrix))

      # Tests when everything is taken from initializing_values.
      new_rows = 7
      initializing_values = [42] * new_rows * num_cols
      remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix(
          ckpt_path=ckpt_path,
          old_tensor_name=old_tensor_name,
          # Nothing is loaded from the old tensor.
          row_remapping=[-1] * new_rows,
          col_remapping=[],
          initializing_values=initializing_values,
          num_rows=new_rows,
          num_cols=num_cols,
          max_rows_in_memory=max_rows_in_memory)
      self.assertAllClose(
          np.reshape(initializing_values, (new_rows, num_cols)),
          self.evaluate(remapped_matrix))
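For context, a hedged sketch of how a test might invoke the helper above; the caller name, the 9x4 value, and the use of fixed_size_partitioner (from tensorflow.python.ops.partitioned_variables) are assumptions, not part of the source shown:

  def test_load_and_remap_with_partitioning(self):
    # Hypothetical caller of the helper above. A small max_rows_in_memory
    # forces the op to stream the checkpointed rows in several reads.
    np_value = np.reshape(np.arange(0, 36, dtype=np.float32), (9, 4))
    self._test_loading_variable_with_max_rows(
        np_value=np_value,
        partitioner=partitioned_variables.fixed_size_partitioner(3),
        max_rows_in_memory=2)

Note that num_rows=9 satisfies the helper's assertGreater(num_rows, 2) check.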
Example #12
  def _test_loading_variable_with_max_rows(self, np_value, partitioner,
                                           max_rows_in_memory):
    """Helper function for various tests using max_rows_in_memory."""
    ops.reset_default_graph()
    old_tensor_name = 'matrix_to_load_and_remap'
    matrix = variable_scope.get_variable(
        old_tensor_name,
        dtype=dtypes.float32,
        initializer=constant_op.constant(np_value, dtype=dtypes.float32),
        partitioner=partitioner)

    with self.cached_session() as sess:
      ckpt_path = os.path.join(test.get_temp_dir(), 'temp_ckpt')
      save = saver.Saver([matrix])
      variables.global_variables_initializer().run()
      save.save(sess, ckpt_path)
      num_rows, num_cols = np_value.shape

      # Tests loading the entire tensor (with the rows reversed).
      remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix(
          ckpt_path=ckpt_path,
          old_tensor_name=old_tensor_name,
          # Simply reverses the rows of the matrix.
          row_remapping=list(range(num_rows - 1, -1, -1)),
          col_remapping=[],
          initializing_values=[],
          num_rows=num_rows,
          num_cols=num_cols,
          max_rows_in_memory=max_rows_in_memory)
      self.assertAllClose(np_value[::-1], remapped_matrix.eval())

      # Tests loading the tensor (except for the first and last rows), with
      # uninitialized values. Requires num_rows to be at least 3 since we're
      # skipping the first and last rows.
      self.assertGreater(num_rows, 2)
      prefix_rows = 2
      suffix_rows = 3
      remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix(
          ckpt_path=ckpt_path,
          old_tensor_name=old_tensor_name,
          # Reverses the rows of the matrix, then prepends and appends
          # uninitialized rows.
          row_remapping=([-1] * prefix_rows + list(range(1, num_rows - 1)) +
                         [-1] * suffix_rows),
          col_remapping=[],
          initializing_values=[42] * (prefix_rows + suffix_rows) * num_cols,
          num_rows=num_rows - 2 + prefix_rows + suffix_rows,
          num_cols=num_cols,
          max_rows_in_memory=max_rows_in_memory)
      self.assertAllClose(
          np.vstack([
              np.tile(42, [prefix_rows, num_cols]), np_value[1:-1],
              np.tile(42, [suffix_rows, num_cols])
          ]), remapped_matrix.eval())

      # Tests when everything is taken from initializing_values.
      new_rows = 7
      initializing_values = [42] * new_rows * num_cols
      remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix(
          ckpt_path=ckpt_path,
          old_tensor_name=old_tensor_name,
          # Nothing is loaded from the old tensor.
          row_remapping=[-1] * new_rows,
          col_remapping=[],
          initializing_values=initializing_values,
          num_rows=new_rows,
          num_cols=num_cols,
          max_rows_in_memory=max_rows_in_memory)
      self.assertAllClose(
          np.reshape(initializing_values, (new_rows, num_cols)),
          remapped_matrix.eval())
Example #13
def _load_and_remap_matrix(ckpt_path,
                           old_tensor_name,
                           new_row_vocab_offset,
                           num_rows_to_load,
                           new_col_vocab_size,
                           initializer,
                           old_row_vocab_size=-1,
                           old_row_vocab_file=None,
                           new_row_vocab_file=None,
                           old_col_vocab_file=None,
                           new_col_vocab_file=None,
                           num_row_oov_buckets=0,
                           num_col_oov_buckets=0,
                           max_rows_in_memory=-1):
  """Loads a 2-D (matrix) `Tensor` from checkpoint.

  Generates 1D-remappings for rows and columns using the
  `GenerateVocabRemapping` op, and initializes any anticipated values with the
  provided initializer. Then, uses the `LoadAndRemapMatrix` op to create a
  matrix that loads existing values from the checkpoint, while filling out
  "missing" values with the newly initialized values. See
  contrib/framework/ops/checkpoint_ops.cc for more information on the wrapped
  functionality (LoadAndRemapMatrix). This wrapper can be used to perform only
  row remapping or only col remapping. If only row remapping is desired,
  {new,old}_col_vocab_file should be `None`, and vice versa for column
  remapping.

  NOTE: This only supports div-partitioning the vocabulary on the 1st dimension
  (row axis) via `new_row_vocab_offset`.

  Args:
    ckpt_path: Path to the TensorFlow checkpoint (version 2, `TensorBundle`)
      from which the old matrix `Tensor` will be loaded.
    old_tensor_name: Name of the 2-D `Tensor` to load from checkpoint.
    new_row_vocab_offset: A 0-indexed integer representing what line to
      start reading at in the new row vocabulary. Used for partitioned
      variables.
    num_rows_to_load: Number of rows to load for the new vocabulary (note: to
      support variable partitioning and partial loading, this does not need to
      be the same as the number of entries in `new_row_vocab_file`).
    new_col_vocab_size: Number of columns to load - should be the same as the
      number of entries in `new_col_vocab_file`, since we don't support
      partitioning along the column axis.
    initializer: Callable initializer function that accepts a 1-D tensor as the
      arg to specify the shape of the returned tensor. Used to initialize
      missing values.
    old_row_vocab_size: The number of entries to consider in the old vocabulary.
      With the default value of -1, the entire old row vocabulary file will be
      used.  Otherwise, only the first `old_row_vocab_size` entries will be
      considered for remapping. Must be smaller than the length of
      `old_row_vocab_file`.  NOTE: we do not provide an equivalent
      `old_col_vocab_size` for classes.
    old_row_vocab_file: A scalar `Tensor` of type `string` containing the
      path to the old row vocabulary file. Can be None, which represents no
      remapping on the row axis.
    new_row_vocab_file: A scalar `Tensor` of type `string` containing the path
      to the new row vocabulary file. Can be None, which represents no remapping
      on the row axis - in which case, `new_row_vocab_offset` and
      `num_rows_to_load` work under the assumption that the new row vocab is the
      same as the old row vocab.
    old_col_vocab_file: A scalar `Tensor` of type `string` containing the
      path to the old column vocabulary file. Can be None, which represents no
      remapping on the column axis.
    new_col_vocab_file: A scalar `Tensor` of type `string` containing the path
      to the new column vocabulary file. Can be None, which represents no
      remapping on the column axis - in which case, `new_col_vocab_size` works
      under the assumption that the new col vocab is the same as the old col
      vocab.
    num_row_oov_buckets: `int` specifying the number of out-of-vocabulary rows
      to append. Must be >= 0.
    num_col_oov_buckets: `int` specifying the number of out-of-vocabulary
      columns to append. Must be >= 0.
    max_rows_in_memory: `int` specifying the maximum number of rows to load from
      the checkpoint at once. If less than or equal to 0, the entire matrix will
      be loaded into memory. Setting this arg trades increased disk reads for
      lower memory usage.

  Returns:
    A Tensor of shape `[num_rows_to_load + num_row_oov_buckets,
    new_col_vocab_size + num_col_oov_buckets]`, with values loaded from the
    specified tensor in the checkpoint, and any missing or OOV values
    initialized with the given `initializer`.

  Raises:
    ValueError: If `num_row_oov_buckets` or `num_col_oov_buckets` < 0.
    ValueError: If either `old_row_vocab_file` or `new_row_vocab_file` is
      provided, while the other is not. Same for `old_col_vocab_file` and
      `new_col_vocab_file`.
    ValueError: If neither row vocabs nor col vocabs are provided.
  """
  if num_row_oov_buckets < 0:
    raise ValueError("num_row_oov_buckets must be >= 0, but received %d" %
                     num_row_oov_buckets)
  if num_col_oov_buckets < 0:
    raise ValueError("num_col_oov_buckets must be >= 0, but received %d" %
                     num_col_oov_buckets)

  if bool(old_row_vocab_file) != bool(new_row_vocab_file):
    raise ValueError(
        "old_row_vocab_file and new_row_vocab_file must both be specified or "
        "left unspecified. old_row_vocab_file='{}', new_row_vocab_file='{}'".
        format(old_row_vocab_file, new_row_vocab_file))
  if bool(old_col_vocab_file) != bool(new_col_vocab_file):
    raise ValueError(
        "old_col_vocab_file and new_col_vocab_file must both be specified or "
        "left unspecified. old_col_vocab_file='{}', new_col_vocab_file='{}'".
        format(old_col_vocab_file, new_col_vocab_file))

  remap_rows = new_row_vocab_file and old_row_vocab_file
  remap_cols = new_col_vocab_file and old_col_vocab_file
  if not (remap_rows or remap_cols):
    raise ValueError(
        "Must provide either row or column vocab files. If no remapping is "
        "necessary, consider using `tf.contrib.framework.init_from_checkpoint` "
        "instead.")

  num_rows_present = num_rows_to_load
  if remap_rows:
    row_remapping, num_rows_present = (
        gen_checkpoint_ops.generate_vocab_remapping(
            new_vocab_file=new_row_vocab_file,
            old_vocab_file=old_row_vocab_file,
            new_vocab_offset=new_row_vocab_offset,
            num_new_vocab=num_rows_to_load,
            old_vocab_size=old_row_vocab_size))
  else:
    # Even when the rows are not being reordered, we still need to generate a
    # remapping to account for initializing partitioned Variables (when
    # new_row_vocab_offset is non-zero).
    row_remapping = math_ops.range(
        new_row_vocab_offset,
        new_row_vocab_offset + num_rows_to_load,
        dtype=dtypes.int64)

  col_remapping = []
  num_cols_present = new_col_vocab_size
  if remap_cols:
    col_remapping, num_cols_present = (
        gen_checkpoint_ops.generate_vocab_remapping(
            new_vocab_file=new_col_vocab_file,
            old_vocab_file=old_col_vocab_file,
            new_vocab_offset=0,  # Offset is unused for cols (no partitioning).
            num_new_vocab=new_col_vocab_size))

  init_vals = initializer([
      num_rows_to_load * new_col_vocab_size -
      num_rows_present * num_cols_present, 1
  ])
  return_tensor = gen_checkpoint_ops.load_and_remap_matrix(
      ckpt_path=ckpt_path,
      old_tensor_name=old_tensor_name,
      row_remapping=row_remapping,
      col_remapping=col_remapping,
      initializing_values=init_vals,
      num_rows=num_rows_to_load,
      num_cols=new_col_vocab_size,
      max_rows_in_memory=max_rows_in_memory)

  # Add OOV row(s) and column(s).
  if num_row_oov_buckets > 0:
    init_row_oov_val = initializer([num_row_oov_buckets, new_col_vocab_size])
    init_row_oov_val = ops.convert_to_tensor(init_row_oov_val)
    return_tensor = array_ops.concat([return_tensor, init_row_oov_val], 0)
  if num_col_oov_buckets > 0:
    # We need to add any row OOV to the new column shape.
    init_col_oov_val = initializer(
        [num_rows_to_load + num_row_oov_buckets, num_col_oov_buckets])
    init_col_oov_val = ops.convert_to_tensor(init_col_oov_val)
    return_tensor = array_ops.concat([return_tensor, init_col_oov_val], 1)

  return return_tensor
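A hedged usage sketch of the wrapper above, performing row remapping only; the checkpoint path, tensor name, vocab files, sizes, and the zeros initializer are all illustrative placeholders:

# Row remapping only: both row vocab files are given, both col vocab files
# are left as None, so new_col_vocab_size must equal the old column count.
remapped = _load_and_remap_matrix(
    ckpt_path='/tmp/model.ckpt',
    old_tensor_name='embeddings',
    new_row_vocab_offset=0,
    num_rows_to_load=100,
    new_col_vocab_size=64,
    initializer=lambda shape: array_ops.zeros(shape, dtype=dtypes.float32),
    old_row_vocab_file='/tmp/old_vocab.txt',
    new_row_vocab_file='/tmp/new_vocab.txt',
    num_row_oov_buckets=1)
# Result shape: [100 + 1, 64] -- 100 remapped rows plus one OOV row.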
Example #14
def _load_and_remap_matrix(ckpt_path,
                           old_tensor_name,
                           new_row_vocab_offset,
                           num_rows_to_load,
                           new_col_vocab_size,
                           initializer,
                           old_row_vocab_size=-1,
                           old_row_vocab_file=None,
                           new_row_vocab_file=None,
                           old_col_vocab_file=None,
                           new_col_vocab_file=None,
                           num_row_oov_buckets=0,
                           num_col_oov_buckets=0,
                           max_rows_in_memory=-1):
    """Loads a 2-D (matrix) `Tensor` from checkpoint.

  Generates 1D-remappings for rows and columns using the
  `GenerateVocabRemapping` op, and initializes any anticipated values with the
  provided initializer. Then, uses the `LoadAndRemapMatrix` op to create a
  matrix that loads existing values from the checkpoint, while filling out
  "missing" values with the newly initialized values. See
  contrib/framework/ops/checkpoint_ops.cc for more information on the wrapped
  functionality (LoadAndRemapMatrix). This wrapper can be used to perform only
  row remapping or only col remapping. If only row remapping is desired,
  {new,old}_col_vocab_file should be `None`, and vice versa for column
  remapping.

  NOTE: This only supports div-partitioning the vocabulary on the 1st dimension
  (row axis) via `new_row_vocab_offset`.

  Args:
    ckpt_path: Path to the TensorFlow checkpoint (version 2, `TensorBundle`)
      from which the old matrix `Tensor` will be loaded.
    old_tensor_name: Name of the 2-D `Tensor` to load from checkpoint.
    new_row_vocab_offset: A 0-indexed integer representing what line to
      start reading at in the new row vocabulary. Used for partitioned
      variables.
    num_rows_to_load: Number of rows to load for the new vocabulary (note: to
      support variable partitioning and partial loading, this does not need to
      be the same as the number of entries in `new_row_vocab_file`).
    new_col_vocab_size: Number of columns to load - should be the same as the
      number of entries in `new_col_vocab_file`, since we don't support
      partitioning along the column axis.
    initializer: Callable initializer function that accepts a 1-D tensor as the
      arg to specify the shape of the returned tensor. Used to initialize
      missing values.
    old_row_vocab_size: The number of entries to consider in the old vocabulary.
      With the default value of -1, the entire old row vocabulary file will be
      used.  Otherwise, only the first `old_row_vocab_size` entries will be
      considered for remapping. Must be smaller than the length of
      `old_row_vocab_file`.  NOTE: we do not provide an equivalent
      `old_col_vocab_size` for classes.
    old_row_vocab_file: A scalar `Tensor` of type `string` containing the
      path to the old row vocabulary file. Can be None, which represents no
      remapping on the row axis.
    new_row_vocab_file: A scalar `Tensor` of type `string` containing the path
      to the new row vocabulary file. Can be None, which represents no remapping
      on the row axis - in which case, `new_row_vocab_offset` and
      `num_rows_to_load` work under the assumption that the new row vocab is the
      same as the old row vocab.
    old_col_vocab_file: A scalar `Tensor` of type `string` containing the
      path to the old column vocabulary file. Can be None, which represents no
      remapping on the column axis.
    new_col_vocab_file: A scalar `Tensor` of type `string` containing the path
      to the new column vocabulary file. Can be None, which represents no
      remapping on the column axis - in which case, `new_col_vocab_size` works
      under the assumption that the new col vocab is the same as the old col
      vocab.
    num_row_oov_buckets: `int` specifying the number of out-of-vocabulary rows
      to append. Must be >= 0.
    num_col_oov_buckets: `int` specifying the number of out-of-vocabulary
      columns to append. Must be >= 0.
    max_rows_in_memory: `int` specifying the maximum number of rows to load from
      the checkpoint at once. If less than or equal to 0, the entire matrix will
      be loaded into memory. Setting this arg trades increased disk reads for
      lower memory usage.

  Returns:
    A Tensor of shape `[num_rows_to_load + num_row_oov_buckets,
    new_col_vocab_size + num_col_oov_buckets]`, with values loaded from the
    specified tensor in the checkpoint, and any missing or OOV values
    initialized with the given `initializer`.

  Raises:
    ValueError: If `num_row_oov_buckets` or `num_col_oov_buckets` < 0.
    ValueError: If either `old_row_vocab_file` or `new_row_vocab_file` is
      provided, while the other is not. Same for `old_col_vocab_file` and
      `new_col_vocab_file`.
    ValueError: If neither row vocabs nor col vocabs are provided.
  """
    if num_row_oov_buckets < 0:
        raise ValueError("num_row_oov_buckets must be >= 0, but received %d" %
                         num_row_oov_buckets)
    if num_col_oov_buckets < 0:
        raise ValueError("num_col_oov_buckets must be >= 0, but received %d" %
                         num_col_oov_buckets)

    if bool(old_row_vocab_file) != bool(new_row_vocab_file):
        raise ValueError(
            "old_row_vocab_file and new_row_vocab_file must both be specified or "
            "left unspecified. old_row_vocab_file='{}', new_row_vocab_file='{}'"
            .format(old_row_vocab_file, new_row_vocab_file))
    if bool(old_col_vocab_file) != bool(new_col_vocab_file):
        raise ValueError(
            "old_col_vocab_file and new_col_vocab_file must both be specified or "
            "left unspecified. old_col_vocab_file='{}', new_col_vocab_file='{}'"
            .format(old_col_vocab_file, new_col_vocab_file))

    remap_rows = new_row_vocab_file and old_row_vocab_file
    remap_cols = new_col_vocab_file and old_col_vocab_file
    if not (remap_rows or remap_cols):
        raise ValueError(
            "Must provide either row or column vocab files. If no remapping is "
            "necessary, consider using `tf.contrib.framework.init_from_checkpoint` "
            "instead.")

    num_rows_present = num_rows_to_load
    if remap_rows:
        row_remapping, num_rows_present = (
            gen_checkpoint_ops.generate_vocab_remapping(
                new_vocab_file=new_row_vocab_file,
                old_vocab_file=old_row_vocab_file,
                new_vocab_offset=new_row_vocab_offset,
                num_new_vocab=num_rows_to_load,
                old_vocab_size=old_row_vocab_size))
    else:
        # Even when the rows are not being reordered, we still need to generate a
        # remapping to account for initializing partitioned Variables (when
        # new_row_vocab_offset is non-zero).
        row_remapping = math_ops.range(new_row_vocab_offset,
                                       new_row_vocab_offset + num_rows_to_load,
                                       dtype=dtypes.int64)

    col_remapping = []
    num_cols_present = new_col_vocab_size
    if remap_cols:
        col_remapping, num_cols_present = (
            gen_checkpoint_ops.generate_vocab_remapping(
                new_vocab_file=new_col_vocab_file,
                old_vocab_file=old_col_vocab_file,
                new_vocab_offset=0,  # Unused for cols (no partitioning).
                num_new_vocab=new_col_vocab_size))

    init_vals = initializer([
        num_rows_to_load * new_col_vocab_size -
        num_rows_present * num_cols_present, 1
    ])
    return_tensor = gen_checkpoint_ops.load_and_remap_matrix(
        ckpt_path=ckpt_path,
        old_tensor_name=old_tensor_name,
        row_remapping=row_remapping,
        col_remapping=col_remapping,
        initializing_values=init_vals,
        num_rows=num_rows_to_load,
        num_cols=new_col_vocab_size,
        max_rows_in_memory=max_rows_in_memory)

    # Add OOV row(s) and column(s).
    if num_row_oov_buckets > 0:
        init_row_oov_val = initializer(
            [num_row_oov_buckets, new_col_vocab_size])
        init_row_oov_val = ops.convert_to_tensor(init_row_oov_val)
        return_tensor = array_ops.concat([return_tensor, init_row_oov_val], 0)
    if num_col_oov_buckets > 0:
        # We need to add any row OOV to the new column shape.
        init_col_oov_val = initializer(
            [num_rows_to_load + num_row_oov_buckets, num_col_oov_buckets])
        init_col_oov_val = ops.convert_to_tensor(init_col_oov_val)
        return_tensor = array_ops.concat([return_tensor, init_col_oov_val], 1)

    return return_tensor