def test_update_raises_value_error_if_variable_struct_not_match(
    self) -> None:
  """`update` rejects a variable whose structure does not match the table."""
  # Seed the Reverb server so there is stored data to pull from.
  self._push_nested_data()
  container = reverb_variable_container.ReverbVariableContainer(
      self._server_address)
  # A single scalar variable cannot be matched against the nested structure
  # the table holds, so `update` is expected to raise.
  with self.assertRaises(ValueError):
    container.update(tf.Variable(1))
def test_push(self) -> None:
  """Pushing a nested structure of variables stores it in the server."""
  nested_vars = _create_nested_variable()
  container = reverb_variable_container.ReverbVariableContainer(
      self._server_address)
  container.push(nested_vars)
  # Confirm that the server now holds exactly the pushed values.
  self._assert_nested_variable_in_server()
def test_push_under_distribute_strategy(
    self, strategy: tf.distribute.Strategy) -> None:
  """Variables created under a distribution strategy scope can be pushed.

  Args:
    strategy: Strategy whose scope the nested variables are created in
      (presumably supplied by a parameterized-test decorator not visible
      here).
  """
  with strategy.scope():
    nested_vars = _create_nested_variable()
  logging.info('Variables: %s', nested_vars)
  container = reverb_variable_container.ReverbVariableContainer(
      self._server_address)
  container.push(nested_vars)
  # The server content must be the same as in the non-distributed case.
  self._assert_nested_variable_in_server()
def test_push_raises_error_if_variable_type_is_wrong(self) -> None:
  """`push` fails when a variable's dtype disagrees with the signature."""
  container = reverb_variable_container.ReverbVariableContainer(
      self._server_address)
  # The signature declares the leading scalar as `tf.int64`; declaring it as
  # `tf.int32` here is the deliberate mismatch. Everything else is unchanged.
  wrong_scalar = tf.Variable(-1, dtype=tf.int32, shape=())
  rest = {
      'var1': (tf.Variable([0, 0], dtype=tf.float64, shape=(2,)),),
      'var2': tf.Variable([[0], [0]], dtype=tf.int32, shape=(2, 1)),
  }
  variables_with_wrong_type = (wrong_scalar, rest)
  with self.assertRaises(tf.errors.InvalidArgumentError):
    container.push(variables_with_wrong_type)
def test_update(self) -> None:
  """`update` overwrites local variables with the values stored in Reverb."""
  self._push_nested_data()
  # Placeholder values with the table's structure and dtypes; `update` is
  # expected to overwrite them with what was pushed above.
  targets = (
      tf.Variable(-1, dtype=tf.int64, shape=()),
      {
          'var1': (tf.Variable([0, 0], dtype=tf.float64, shape=(2,)),),
          'var2': tf.Variable([[0], [0]], dtype=tf.int32, shape=(2, 1)),
      },
  )
  container = reverb_variable_container.ReverbVariableContainer(
      self._server_address)
  container.update(targets)
  self._assert_nested_variable_updated(targets)
def test_push_with_not_exact_sequence_type_matching(self) -> None:
  """`push` tolerates a list where the signature declares a tuple."""
  # In the original signature the value of `var1` is wrapped in a tuple;
  # here it is wrapped in a list instead.
  list_wrapped = [tf.Variable([1, 1], dtype=tf.float64, shape=(2,))]
  variables = (
      tf.Variable(0, dtype=tf.int64, shape=()),
      {
          'var1': list_wrapped,
          'var2': tf.Variable([[2], [3]], dtype=tf.int32, shape=(2, 1)),
      },
  )
  # Sequence-type checking is off by default, so tuple vs. list differences
  # are accepted. This is needed to work with policies loaded from file,
  # which often turn tuples into e.g. `ListWrapper`.
  container = reverb_variable_container.ReverbVariableContainer(
      self._server_address)
  container.push(variables)
  self._assert_nested_variable_in_server()
def test_update_raises_value_error_if_variable_type_is_wrong(self) -> None:
  """`update` fails when a target variable's dtype disagrees with the table."""
  self._push_nested_data()
  container = reverb_variable_container.ReverbVariableContainer(
      self._server_address)
  # The signature declares the leading scalar as `tf.int64`; declaring it as
  # `tf.int32` here is the deliberate mismatch.
  wrong_scalar = tf.Variable(-1, dtype=tf.int32, shape=())
  rest = {
      'var1': (tf.Variable([0, 0], dtype=tf.float64, shape=(2,)),),
      'var2': tf.Variable([[0], [0]], dtype=tf.int32, shape=(2, 1)),
  }
  variables_with_wrong_type = (wrong_scalar, rest)
  with self.assertRaises(ValueError):
    container.update(variables_with_wrong_type)
def test_init_raises_value_error_if_max_size_is_different_than_one(
    self) -> None:
  """The constructor rejects a table created with `max_size` other than 1."""
  server, server_address = _create_server(max_size=2)
  # Stop the server in `finally` so it is shut down even when the assertion
  # fails (i.e. no ValueError is raised and assertRaises raises instead).
  try:
    with self.assertRaises(ValueError):
      reverb_variable_container.ReverbVariableContainer(server_address)
  finally:
    server.stop()
def test_init_raises_type_error_if_no_signature_of_a_table(self) -> None:
  """The constructor rejects a server whose table has no signature."""
  server, server_address = _create_server(signature=None)
  # Stop the server in `finally` so it is shut down even when the assertion
  # fails (i.e. no TypeError is raised and assertRaises raises instead).
  try:
    with self.assertRaises(TypeError):
      reverb_variable_container.ReverbVariableContainer(server_address)
  finally:
    server.stop()
def test_init_raises_key_error_if_undefined_table_passed(self) -> None:
  """The constructor raises `KeyError` when the expected table is missing.

  The server here only has a table named 'no_variables_table', which is
  presumably not the table name the container looks up by default.
  """
  server, server_address = _create_server(table='no_variables_table')
  # Stop the server in `finally` so it is shut down even when the assertion
  # fails (i.e. no KeyError is raised and assertRaises raises instead).
  try:
    with self.assertRaises(KeyError):
      reverb_variable_container.ReverbVariableContainer(server_address)
  finally:
    server.stop()
def test_pull_raises_key_error_on_unknown_table(self) -> None:
  """`pull` raises `KeyError` when given an unrecognized table name."""
  container = reverb_variable_container.ReverbVariableContainer(
      self._server_address)
  missing_table = 'unknown_table'
  with self.assertRaises(KeyError):
    container.pull(missing_table)
def test_push_raises_error_if_variable_struct_not_match(self) -> None:
  """`push` rejects a variable whose structure does not match the table."""
  container = reverb_variable_container.ReverbVariableContainer(
      self._server_address)
  # A single scalar variable cannot be matched against the nested structure
  # the table expects, so pushing it is expected to fail.
  with self.assertRaises(tf.errors.InvalidArgumentError):
    container.push(tf.Variable(1))