def test_load_and_remap_invalid_remapping(self):
  """Verifies the op rejects a remapping that reuses the same old ID.

  The remapping `[1, 0, 0, 0, 1, 2]` points several new positions at the same
  old ID. It is applied first as the row remapping and then as the column
  remapping; evaluating either result must raise `UnimplementedError`.
  (Public APIs should not normally produce such a remapping.)
  """
  duplicated_ids = [1, 0, 0, 0, 1, 2]
  num_new_ids = len(duplicated_ids)

  # Row axis: the duplicated IDs drive the row remapping, columns untouched.
  bad_rows_matrix = gen_checkpoint_ops.load_and_remap_matrix(
      ckpt_path=[self.bundle_file],
      old_tensor_name=self.old_tensor_name,
      row_remapping=duplicated_ids,
      col_remapping=[],
      initializing_values=[],
      num_rows=num_new_ids,
      num_cols=self.old_num_cols)
  with self.test_session():
    with self.assertRaises(errors.UnimplementedError):
      bad_rows_matrix.eval()

  # Column axis: rows keep their original order, the duplicated IDs drive
  # the column remapping.
  identity_rows = list(range(self.old_num_rows))
  bad_cols_matrix = gen_checkpoint_ops.load_and_remap_matrix(
      ckpt_path=[self.bundle_file],
      old_tensor_name=self.old_tensor_name,
      row_remapping=identity_rows,
      col_remapping=duplicated_ids,
      initializing_values=[],
      num_rows=self.old_num_rows,
      num_cols=num_new_ids)
  with self.test_session():
    with self.assertRaises(errors.UnimplementedError):
      bad_cols_matrix.eval()
def test_load_and_remap_incorrect_initializing_values(self):
  """Verifies that a wrong number of initializing values is rejected.

  For the row/column remapping used here the op should receive exactly 4
  initializing values. Both an empty list and a list of 5 values must raise
  `InvalidArgumentError` when the result is evaluated.
  """
  # First too few values, then too many. The "too few" case deliberately
  # passes no values at all: per the original author's note, only an empty
  # list (rather than 3 or fewer values) makes a seg fault reliably occur if
  # the check raising InvalidArgumentError were not present.
  wrong_sized_inits = ([], [0] * 5)

  for init_vals in wrong_sized_inits:
    remapped = gen_checkpoint_ops.load_and_remap_matrix(
        ckpt_path=[self.bundle_file],
        old_tensor_name=self.old_tensor_name,
        row_remapping=[2, -1, 0],
        col_remapping=[1, -1],
        initializing_values=init_vals,
        num_rows=3,
        num_cols=2)
    with self.test_session():
      with self.assertRaises(errors.InvalidArgumentError):
        remapped.eval()
def test_load_and_remap_incorrect_initializing_values(self):
  """Checks that the op refuses a mis-sized `initializing_values` list.

  The remapping below calls for exactly 4 initializing values. Supplying
  none, or supplying 5, must each fail with `InvalidArgumentError` at
  evaluation time.
  """
  # Too few first, too many second. The "too few" case is an empty list on
  # purpose: according to the original author's note, lists of 3 or fewer
  # values do not reliably seg fault when the size check is missing, while
  # an empty list does.
  for bad_init_values in ([], [0] * 5):
    remapped = gen_checkpoint_ops.load_and_remap_matrix(
        ckpt_path=[self.bundle_file],
        old_tensor_name=self.old_tensor_name,
        row_remapping=[2, -1, 0],
        col_remapping=[1, -1],
        initializing_values=bad_init_values,
        num_rows=3,
        num_cols=2)
    with self.test_session():
      with self.assertRaises(errors.InvalidArgumentError):
        remapped.eval()
def test_load_and_remap_invalid_remapping(self):
  """Verifies the op rejects a remapping that reuses the same old ID.

  The remapping `[1, 0, 0, 0, 1, 2]` points several new positions at the same
  old ID. It is applied first as the row remapping and then as the column
  remapping; evaluating either result must raise `UnimplementedError`.
  (Public APIs should not normally produce such a remapping.)
  """
  duplicated_ids = [1, 0, 0, 0, 1, 2]
  num_new_ids = len(duplicated_ids)

  # Row axis: the duplicated IDs drive the row remapping, columns untouched.
  bad_rows_matrix = gen_checkpoint_ops.load_and_remap_matrix(
      ckpt_path=[self.bundle_file],
      old_tensor_name=self.old_tensor_name,
      row_remapping=duplicated_ids,
      col_remapping=[],
      initializing_values=[],
      num_rows=num_new_ids,
      num_cols=self.old_num_cols)
  with self.test_session():
    with self.assertRaises(errors.UnimplementedError):
      bad_rows_matrix.eval()

  # Column axis: rows keep their original order, the duplicated IDs drive
  # the column remapping.
  identity_rows = list(range(self.old_num_rows))
  bad_cols_matrix = gen_checkpoint_ops.load_and_remap_matrix(
      ckpt_path=[self.bundle_file],
      old_tensor_name=self.old_tensor_name,
      row_remapping=identity_rows,
      col_remapping=duplicated_ids,
      initializing_values=[],
      num_rows=self.old_num_rows,
      num_cols=num_new_ids)
  with self.test_session():
    with self.assertRaises(errors.UnimplementedError):
      bad_cols_matrix.eval()
def test_load_and_remap_no_missing(self):
  """Exercises load-and-remap when no entry needs to be initialized.

  Three scenarios run in turn: rows only, columns only, and both axes at
  once. Each result must match the same index selection applied to
  `self.matrix_value`.
  """
  # Rows only: the new matrix is old row 1 followed by old row 0.
  row_ids = [1, 0]
  rows_only = gen_checkpoint_ops.load_and_remap_matrix(
      ckpt_path=[self.bundle_file],
      old_tensor_name=self.old_tensor_name,
      row_remapping=row_ids,
      col_remapping=[],
      initializing_values=[],
      num_rows=2,
      num_cols=self.old_num_cols)
  with self.test_session():
    expected = self.matrix_value[row_ids]
    self.assertAllClose(expected, rows_only.eval())

  # Columns only: every row kept in order; the new matrix is old column 2
  # followed by old column 0.
  row_ids = list(range(self.old_num_rows))
  col_ids = [2, 0]
  cols_only = gen_checkpoint_ops.load_and_remap_matrix(
      ckpt_path=[self.bundle_file],
      old_tensor_name=self.old_tensor_name,
      row_remapping=row_ids,
      col_remapping=col_ids,
      initializing_values=[],
      num_rows=len(row_ids),
      num_cols=len(col_ids))
  with self.test_session():
    expected = self.matrix_value[row_ids][:, col_ids]
    self.assertAllClose(expected, cols_only.eval())

  # Both axes remapped together.
  row_ids = [1, 0, 4]
  col_ids = [1, 15]
  rows_and_cols = gen_checkpoint_ops.load_and_remap_matrix(
      ckpt_path=[self.bundle_file],
      old_tensor_name=self.old_tensor_name,
      row_remapping=row_ids,
      col_remapping=col_ids,
      initializing_values=[],
      num_rows=len(row_ids),
      num_cols=len(col_ids))
  with self.test_session():
    expected = self.matrix_value[row_ids][:, col_ids]
    self.assertAllClose(expected, rows_and_cols.eval())
def test_load_and_remap_no_missing(self):
  """Exercises load-and-remap when no entry needs to be initialized.

  Three scenarios run in turn: rows only, columns only, and both axes at
  once. Each result must match the same index selection applied to
  `self.matrix_value`.
  """
  # Rows only: the new matrix is old row 1 followed by old row 0.
  row_ids = [1, 0]
  rows_only = gen_checkpoint_ops.load_and_remap_matrix(
      ckpt_path=[self.bundle_file],
      old_tensor_name=self.old_tensor_name,
      row_remapping=row_ids,
      col_remapping=[],
      initializing_values=[],
      num_rows=2,
      num_cols=self.old_num_cols)
  with self.test_session():
    expected = self.matrix_value[row_ids]
    self.assertAllClose(expected, rows_only.eval())

  # Columns only: every row kept in order; the new matrix is old column 2
  # followed by old column 0.
  row_ids = list(range(self.old_num_rows))
  col_ids = [2, 0]
  cols_only = gen_checkpoint_ops.load_and_remap_matrix(
      ckpt_path=[self.bundle_file],
      old_tensor_name=self.old_tensor_name,
      row_remapping=row_ids,
      col_remapping=col_ids,
      initializing_values=[],
      num_rows=len(row_ids),
      num_cols=len(col_ids))
  with self.test_session():
    expected = self.matrix_value[row_ids][:, col_ids]
    self.assertAllClose(expected, cols_only.eval())

  # Both axes remapped together.
  row_ids = [1, 0, 4]
  col_ids = [1, 15]
  rows_and_cols = gen_checkpoint_ops.load_and_remap_matrix(
      ckpt_path=[self.bundle_file],
      old_tensor_name=self.old_tensor_name,
      row_remapping=row_ids,
      col_remapping=col_ids,
      initializing_values=[],
      num_rows=len(row_ids),
      num_cols=len(col_ids))
  with self.test_session():
    expected = self.matrix_value[row_ids][:, col_ids]
    self.assertAllClose(expected, rows_and_cols.eval())
def test_load_and_remap_all_missing_rows(self):
  """Covers the case where every row is missing and must be initialized.

  All 7 rows are remapped to -1, so the output must equal the initializing
  values reshaped to `(7, self.old_num_cols)`.
  """
  new_row_count = 7
  col_count = self.old_num_cols
  fill_values = [42] * (new_row_count * col_count)

  all_initialized = gen_checkpoint_ops.load_and_remap_matrix(
      ckpt_path=[self.bundle_file],
      old_tensor_name=self.old_tensor_name,
      row_remapping=[-1] * new_row_count,
      col_remapping=[],
      initializing_values=fill_values,
      num_rows=new_row_count,
      num_cols=col_count)

  with self.test_session():
    expected = np.reshape(fill_values, (new_row_count, col_count))
    self.assertAllClose(expected, all_initialized.eval())
def test_load_and_remap_all_missing_rows(self):
  """Covers the case where every row is missing and must be initialized.

  All 7 rows are remapped to -1, so the output must equal the initializing
  values reshaped to `(7, self.old_num_cols)`.
  """
  new_row_count = 7
  col_count = self.old_num_cols
  fill_values = [42] * (new_row_count * col_count)

  all_initialized = gen_checkpoint_ops.load_and_remap_matrix(
      ckpt_path=[self.bundle_file],
      old_tensor_name=self.old_tensor_name,
      row_remapping=[-1] * new_row_count,
      col_remapping=[],
      initializing_values=fill_values,
      num_rows=new_row_count,
      num_cols=col_count)

  with self.test_session():
    expected = np.reshape(fill_values, (new_row_count, col_count))
    self.assertAllClose(expected, all_initialized.eval())
def test_load_and_remap_duplicate_row_remapping(self):
  """Covers an old row being mapped to several new rows.

  The row remapping repeats old rows 0 and 1; the output must equal
  `self.matrix_value` indexed by that remapping. (Public APIs should not
  normally produce such a remapping.)
  """
  repeated_rows = [1, 0, 0, 0, 1, 2]
  new_row_count = len(repeated_rows)

  remapped = gen_checkpoint_ops.load_and_remap_matrix(
      ckpt_path=[self.bundle_file],
      old_tensor_name=self.old_tensor_name,
      row_remapping=repeated_rows,
      col_remapping=[],
      initializing_values=[],
      num_rows=new_row_count,
      num_cols=self.old_num_cols)

  with self.test_session():
    expected = self.matrix_value[repeated_rows]
    self.assertAllClose(expected, remapped.eval())
def test_load_and_remap_with_init(self):
  """Exercises load-and-remap when some entries are missing.

  Rows are remapped with `[2, -1, 0]` and columns with `[1, -1]`; four
  initializing values of 42 are supplied. The expected 3x2 result keeps the
  loaded values 33 and 1 in the first column of rows 0 and 2 and holds 42
  everywhere else.
  """
  fill = 42
  remapped = gen_checkpoint_ops.load_and_remap_matrix(
      ckpt_path=[self.bundle_file],
      old_tensor_name=self.old_tensor_name,
      row_remapping=[2, -1, 0],
      col_remapping=[1, -1],
      initializing_values=[fill] * 4,
      num_rows=3,
      num_cols=2)

  # Flattened row by row, then shaped into the 3x2 output.
  expected_flat = [33, fill, fill, fill, 1, fill]
  expected = np.reshape(expected_flat, [3, 2])

  with self.test_session():
    self.assertAllClose(expected, remapped.eval())
def test_load_and_remap_duplicate_row_remapping(self):
  """Covers an old row being mapped to several new rows.

  The row remapping repeats old rows 0 and 1; the output must equal
  `self.matrix_value` indexed by that remapping. (Public APIs should not
  normally produce such a remapping.)
  """
  repeated_rows = [1, 0, 0, 0, 1, 2]
  new_row_count = len(repeated_rows)

  remapped = gen_checkpoint_ops.load_and_remap_matrix(
      ckpt_path=[self.bundle_file],
      old_tensor_name=self.old_tensor_name,
      row_remapping=repeated_rows,
      col_remapping=[],
      initializing_values=[],
      num_rows=new_row_count,
      num_cols=self.old_num_cols)

  with self.test_session():
    expected = self.matrix_value[repeated_rows]
    self.assertAllClose(expected, remapped.eval())
def test_load_and_remap_with_init(self):
  """Exercises load-and-remap when some entries are missing.

  Rows are remapped with `[2, -1, 0]` and columns with `[1, -1]`; four
  initializing values of 42 are supplied. The expected 3x2 result keeps the
  loaded values 33 and 1 in the first column of rows 0 and 2 and holds 42
  everywhere else.
  """
  fill = 42
  remapped = gen_checkpoint_ops.load_and_remap_matrix(
      ckpt_path=[self.bundle_file],
      old_tensor_name=self.old_tensor_name,
      row_remapping=[2, -1, 0],
      col_remapping=[1, -1],
      initializing_values=[fill] * 4,
      num_rows=3,
      num_cols=2)

  # Flattened row by row, then shaped into the 3x2 output.
  expected_flat = [33, fill, fill, fill, 1, fill]
  expected = np.reshape(expected_flat, [3, 2])

  with self.test_session():
    self.assertAllClose(expected, remapped.eval())
def _load_and_remap_matrix(ckpt_path, old_tensor_name, new_row_vocab_offset, num_rows_to_load, new_col_vocab_size, initializer, old_row_vocab_file=None, new_row_vocab_file=None, old_col_vocab_file=None, new_col_vocab_file=None, num_row_oov_buckets=0, num_col_oov_buckets=0): """Loads a 2-D (matrix) `Tensor` from checkpoint. Generates 1D-remappings for rows and columns using the `GenerateVocabRemapping` op, and initializes any anticipated values with the provided initializer. Then, uses the `LoadAndRemapMatrix` op to create a matrix that loads existing values from the checkpoint, while filling out "missing" values with the newly initialized values. See contrib/framework/ops/checkpoint_ops.cc for more information on the wrapped functionality (LoadAndRemapMatrix). This wrapper can be used to perform only row remapping or only col remapping. If only row remapping is desired, {new,old}_col_vocab_file should be `None`, and vice versa for column remapping. NOTE: This only supports div-partitioning the vocabulary on the 1st dimension (row axis) via `new_row_vocab_offset`. Args: ckpt_path: Path to the TensorFlow checkpoint (version 2, `TensorBundle`) from which the old matrix `Tensor` will be loaded. old_tensor_name: Name of the 2-D `Tensor` to load from checkpoint. new_row_vocab_offset: A 0-indexed integer representing what line to start reading at in the new row vocabulary. Used for partitioned variables. num_rows_to_load: Number of rows to load for the new vocabulary (note: to support variable partitioning and partial loading, this does not need to be the same as the number of entries in `new_row_vocab_file`). new_col_vocab_size: Number of columns to load - should be the same as the number of entries in `new_col_vocab_file`, since we don't support partitioning along the column axis. initializer: Callable initializer function that accepts a 1-D tensor as the arg to specify the shape of the returned tensor. Used to initialize missing values. 
old_row_vocab_file: A scalar `Tensor` of type `string` containing the path to the old row vocabulary file. Can be None, which represents no remapping on the row axis. new_row_vocab_file: A scalar `Tensor` of type `string` containing the path to the new row vocabulary file. Can be None, which represents no remapping on the row axis - in which case, `new_row_vocab_offset` and `num_rows_to_load` work under the assumption that the new row vocab is the same as the old row vocab. old_col_vocab_file: A scalar `Tensor` of type `string` containing the path to the old column vocabulary file. Can be None, which represents no remapping on the column axis. new_col_vocab_file: A scalar `Tensor` of type `string` containing the path to the new column vocabulary file. Can be None, which represents no remapping on the column axis - in which case, `new_col_vocab_size` works under the assumption that the new col vocab is the same as the old col vocab. num_row_oov_buckets: `int` specifying the number of out-of-vocabulary rows to append. Must be >= 0. num_col_oov_buckets: `int` specifying the number of out-of-vocabulary columns to append. Must be >= 0. Returns: A Tensor of shape `[num_rows_to_load + num_row_oov_buckets, new_col_vocab_size + num_col_oov_buckets]`, with values loaded from the specified tensor in the checkpoint, and any missing or OOV values initialized with the given `initializer`. Raises: ValueError: If `num_row_oov_buckets` or `num_col_oov_buckets` < 0. ValueError: If either `old_row_vocab_file` or `new_row_vocab_file` is provided, while the other is not. Same for `old_col_vocab_file` and `new_col_vocab_file`. ValueError: If neither row vocabs or col vocabs are provided. 
""" if num_row_oov_buckets < 0: raise ValueError("num_row_oov_buckets must be >= 0, but received %d" % num_row_oov_buckets) if num_col_oov_buckets < 0: raise ValueError("num_col_oov_buckets must be >= 0, but received %d" % num_col_oov_buckets) if bool(old_row_vocab_file) != bool(new_row_vocab_file): raise ValueError( "old_row_vocab_file and new_row_vocab_file must both be specified or " "left unspecified. old_row_vocab_file='{}', new_row_vocab_file='{}'". format(old_row_vocab_file, new_row_vocab_file)) if bool(old_col_vocab_file) != bool(new_col_vocab_file): raise ValueError( "old_col_vocab_file and new_col_vocab_file must both be specified or " "left unspecified. old_col_vocab_file='{}', new_col_vocab_file='{}'". format(old_col_vocab_file, new_col_vocab_file)) remap_rows = new_row_vocab_file and old_row_vocab_file remap_cols = new_col_vocab_file and old_col_vocab_file if not (remap_rows or remap_cols): raise ValueError( "Must provide either row or column vocab files. If no remapping is " "necessary, consider using `tf.contrib.framework.init_from_checkpoint` " "instead.") num_rows_present = num_rows_to_load if remap_rows: row_remapping, num_rows_present = ( gen_checkpoint_ops.generate_vocab_remapping( new_vocab_file=new_row_vocab_file, old_vocab_file=old_row_vocab_file, new_vocab_offset=new_row_vocab_offset, num_new_vocab=num_rows_to_load)) else: # Even when the rows are not being reordered, we still need to generate a # remapping to account for initializing partitioned Variables (when # new_row_vocab_offset is non-zero). row_remapping = math_ops.range( new_row_vocab_offset, new_row_vocab_offset + num_rows_to_load, dtype=dtypes.int64) col_remapping = [] num_cols_present = new_col_vocab_size if remap_cols: col_remapping, num_cols_present = ( gen_checkpoint_ops.generate_vocab_remapping( new_vocab_file=new_col_vocab_file, old_vocab_file=old_col_vocab_file, new_vocab_offset=0, # Offset is unused for cols (no partitioning). 
num_new_vocab=new_col_vocab_size)) init_vals = initializer([ num_rows_to_load * new_col_vocab_size - num_rows_present * num_cols_present, 1 ]) return_tensor = gen_checkpoint_ops.load_and_remap_matrix( ckpt_path=ckpt_path, old_tensor_name=old_tensor_name, row_remapping=row_remapping, col_remapping=col_remapping, initializing_values=init_vals, num_rows=num_rows_to_load, num_cols=new_col_vocab_size) # Add OOV row(s) and column(s). if num_row_oov_buckets > 0: init_row_oov_val = initializer([num_row_oov_buckets, new_col_vocab_size]) init_row_oov_val = ops.convert_to_tensor(init_row_oov_val) return_tensor = array_ops.concat([return_tensor, init_row_oov_val], 0) if num_col_oov_buckets > 0: # We need to add any row OOV to the new column shape. init_col_oov_val = initializer( [num_rows_to_load + num_row_oov_buckets, num_col_oov_buckets]) init_col_oov_val = ops.convert_to_tensor(init_col_oov_val) return_tensor = array_ops.concat([return_tensor, init_col_oov_val], 1) return return_tensor
def _test_loading_variable_with_max_rows(self, np_value, partitioner,
                                         max_rows_in_memory):
  """Shared driver for the tests that exercise `max_rows_in_memory`.

  Saves `np_value` as a float32 variable (created with `partitioner`) to a
  checkpoint in the test temp directory, then runs three
  `load_and_remap_matrix` scenarios against that checkpoint: all rows loaded
  in reverse order, interior rows loaded between blocks of initialized rows,
  and a matrix built entirely from initializing values.

  Args:
    np_value: 2-D array whose `.shape` gives the row and column counts; must
      have more than 2 rows.
    partitioner: Passed through to `variable_scope.get_variable`.
    max_rows_in_memory: Passed through to every `load_and_remap_matrix` call.
  """
  ops.reset_default_graph()
  old_tensor_name = 'matrix_to_load_and_remap'
  initial_value = constant_op.constant(np_value, dtype=dtypes.float32)
  matrix = variable_scope.get_variable(
      old_tensor_name,
      dtype=dtypes.float32,
      initializer=initial_value,
      partitioner=partitioner)

  with self.test_session() as sess:
    # Write the variable out so the op has a checkpoint to read from.
    ckpt_path = os.path.join(test.get_temp_dir(), 'temp_ckpt')
    ckpt_saver = saver.Saver([matrix])
    variables.global_variables_initializer().run()
    ckpt_saver.save(sess, ckpt_path)
    num_rows, num_cols = np_value.shape

    def load_rows(row_remapping, initializing_values, new_num_rows):
      # Row-only remap of the saved tensor; columns are left untouched.
      return gen_checkpoint_ops.load_and_remap_matrix(
          ckpt_path=ckpt_path,
          old_tensor_name=old_tensor_name,
          row_remapping=row_remapping,
          col_remapping=[],
          initializing_values=initializing_values,
          num_rows=new_num_rows,
          num_cols=num_cols,
          max_rows_in_memory=max_rows_in_memory)

    # Scenario 1: load the entire tensor with its rows in reverse order.
    reversed_rows = list(range(num_rows - 1, -1, -1))
    remapped_matrix = load_rows(reversed_rows, [], num_rows)
    self.assertAllClose(np_value[::-1], remapped_matrix.eval())

    # Scenario 2: load every row except the first and last, in their
    # original order, with initialized rows placed before and after them.
    # Needs at least 3 rows because the first and last are skipped.
    self.assertGreater(num_rows, 2)
    prefix_rows = 2
    suffix_rows = 3
    padded_remapping = ([-1] * prefix_rows + list(range(1, num_rows - 1)) +
                        [-1] * suffix_rows)
    remapped_matrix = load_rows(
        padded_remapping,
        [42] * (prefix_rows + suffix_rows) * num_cols,
        num_rows - 2 + prefix_rows + suffix_rows)
    expected_padded = np.vstack([
        np.tile(42, [prefix_rows, num_cols]),
        np_value[1:-1],
        np.tile(42, [suffix_rows, num_cols]),
    ])
    self.assertAllClose(expected_padded, remapped_matrix.eval())

    # Scenario 3: nothing is loaded from the old tensor; every value comes
    # from initializing_values.
    new_rows = 7
    initializing_values = [42] * new_rows * num_cols
    remapped_matrix = load_rows([-1] * new_rows, initializing_values,
                                new_rows)
    expected_all_init = np.reshape(initializing_values, (new_rows, num_cols))
    self.assertAllClose(expected_all_init, remapped_matrix.eval())
def _test_loading_variable_with_max_rows(self, np_value, partitioner,
                                         max_rows_in_memory):
  """Shared driver for the tests that exercise `max_rows_in_memory`.

  Saves `np_value` as a float32 variable (created with `partitioner`) to a
  checkpoint in the test temp directory, then runs three
  `load_and_remap_matrix` scenarios against that checkpoint: all rows loaded
  in reverse order, interior rows loaded between blocks of initialized rows,
  and a matrix built entirely from initializing values.

  Args:
    np_value: 2-D array whose `.shape` gives the row and column counts; must
      have more than 2 rows.
    partitioner: Passed through to `variable_scope.get_variable`.
    max_rows_in_memory: Passed through to every `load_and_remap_matrix` call.
  """
  ops.reset_default_graph()
  old_tensor_name = 'matrix_to_load_and_remap'
  initial_value = constant_op.constant(np_value, dtype=dtypes.float32)
  matrix = variable_scope.get_variable(
      old_tensor_name,
      dtype=dtypes.float32,
      initializer=initial_value,
      partitioner=partitioner)

  with self.test_session() as sess:
    # Write the variable out so the op has a checkpoint to read from.
    ckpt_path = os.path.join(test.get_temp_dir(), 'temp_ckpt')
    ckpt_saver = saver.Saver([matrix])
    variables.global_variables_initializer().run()
    ckpt_saver.save(sess, ckpt_path)
    num_rows, num_cols = np_value.shape

    def load_rows(row_remapping, initializing_values, new_num_rows):
      # Row-only remap of the saved tensor; columns are left untouched.
      return gen_checkpoint_ops.load_and_remap_matrix(
          ckpt_path=ckpt_path,
          old_tensor_name=old_tensor_name,
          row_remapping=row_remapping,
          col_remapping=[],
          initializing_values=initializing_values,
          num_rows=new_num_rows,
          num_cols=num_cols,
          max_rows_in_memory=max_rows_in_memory)

    # Scenario 1: load the entire tensor with its rows in reverse order.
    reversed_rows = list(range(num_rows - 1, -1, -1))
    remapped_matrix = load_rows(reversed_rows, [], num_rows)
    self.assertAllClose(np_value[::-1], remapped_matrix.eval())

    # Scenario 2: load every row except the first and last, in their
    # original order, with initialized rows placed before and after them.
    # Needs at least 3 rows because the first and last are skipped.
    self.assertGreater(num_rows, 2)
    prefix_rows = 2
    suffix_rows = 3
    padded_remapping = ([-1] * prefix_rows + list(range(1, num_rows - 1)) +
                        [-1] * suffix_rows)
    remapped_matrix = load_rows(
        padded_remapping,
        [42] * (prefix_rows + suffix_rows) * num_cols,
        num_rows - 2 + prefix_rows + suffix_rows)
    expected_padded = np.vstack([
        np.tile(42, [prefix_rows, num_cols]),
        np_value[1:-1],
        np.tile(42, [suffix_rows, num_cols]),
    ])
    self.assertAllClose(expected_padded, remapped_matrix.eval())

    # Scenario 3: nothing is loaded from the old tensor; every value comes
    # from initializing_values.
    new_rows = 7
    initializing_values = [42] * new_rows * num_cols
    remapped_matrix = load_rows([-1] * new_rows, initializing_values,
                                new_rows)
    expected_all_init = np.reshape(initializing_values, (new_rows, num_cols))
    self.assertAllClose(expected_all_init, remapped_matrix.eval())