Beispiel #1
0
    def test_update_raises_value_error_if_variable_struct_not_match(
            self) -> None:
        """`update` rejects a variable whose nesting differs from the table's.

        The server is first seeded via `_push_nested_data`, then a lone scalar
        `tf.Variable` is handed to `update`, which must raise `ValueError`.
        """
        self._push_nested_data()

        container = reverb_variable_container.ReverbVariableContainer(
            self._server_address)
        # A single scalar cannot match the nested structure stored on the server.
        mismatched = tf.Variable(1)
        with self.assertRaises(ValueError):
            container.update(mismatched)
Beispiel #2
0
    def test_push(self) -> None:
        """`push` writes a nested structure of variables into the server."""
        nested_vars = _create_nested_variable()

        container = reverb_variable_container.ReverbVariableContainer(
            self._server_address)
        container.push(nested_vars)

        # Verify what the server now holds.
        self._assert_nested_variable_in_server()
Beispiel #3
0
    def test_push_under_distribute_strategy(
            self, strategy: tf.distribute.Strategy) -> None:
        """`push` works for variables created inside a distribution strategy.

        Args:
          strategy: The `tf.distribute.Strategy` whose scope the nested
            variables are created under.
        """
        with strategy.scope():
            nested_vars = _create_nested_variable()
        logging.info('Variables: %s', nested_vars)

        container = reverb_variable_container.ReverbVariableContainer(
            self._server_address)
        container.push(nested_vars)

        # The pushed values should be visible on the server.
        self._assert_nested_variable_in_server()
Beispiel #4
0
 def test_push_raises_error_if_variable_type_is_wrong(self) -> None:
     """`push` fails with `InvalidArgumentError` on a dtype mismatch."""
     container = reverb_variable_container.ReverbVariableContainer(
         self._server_address)
     # The signature expects the leading scalar to be `tf.int64`; it is
     # deliberately declared as `tf.int32` here. Only that dtype is meant
     # to be wrong.
     wrong_scalar = tf.Variable(-1, dtype=tf.int32, shape=())
     var1 = tf.Variable([0, 0], dtype=tf.float64, shape=(2, ))
     var2 = tf.Variable([[0], [0]], dtype=tf.int32, shape=(2, 1))
     variables_with_wrong_type = (wrong_scalar, {
         'var1': (var1, ),
         'var2': var2,
     })
     with self.assertRaises(tf.errors.InvalidArgumentError):
         container.push(variables_with_wrong_type)
Beispiel #5
0
    def test_update(self) -> None:
        """`update` overwrites local variables with the values on the server."""
        self._push_nested_data()

        # Placeholder variables with the expected dtypes/shapes. Their initial
        # values (-1 / 0) are expected to be replaced by `update`.
        scalar = tf.Variable(-1, dtype=tf.int64, shape=())
        var1 = tf.Variable([0, 0], dtype=tf.float64, shape=(2, ))
        var2 = tf.Variable([[0], [0]], dtype=tf.int32, shape=(2, 1))
        variables = (scalar, {'var1': (var1, ), 'var2': var2})

        container = reverb_variable_container.ReverbVariableContainer(
            self._server_address)
        container.update(variables)

        # The placeholders should now carry the server-side values.
        self._assert_nested_variable_updated(variables)
Beispiel #6
0
    def test_push_with_not_exact_sequence_type_matching(self) -> None:
        """`push` tolerates a list where the signature has a tuple."""
        # In the original signature the value of `var1` sits in a tuple (i.e.
        # the second element); here it is wrapped in a list instead.
        scalar = tf.Variable(0, dtype=tf.int64, shape=())
        var1 = tf.Variable([1, 1], dtype=tf.float64, shape=(2, ))
        var2 = tf.Variable([[2], [3]], dtype=tf.int32, shape=(2, 1))
        variables = (scalar, {'var1': [var1], 'var2': var2})

        # Sequence type checking is off by default, so sequence type
        # differences in the signature are accepted. This is required to be
        # able to work with policies loaded from file, which often turn tuples
        # into e.g. `ListWrapper`.
        container = reverb_variable_container.ReverbVariableContainer(
            self._server_address)
        container.push(variables)

        # Verify what the server now holds.
        self._assert_nested_variable_in_server()
Beispiel #7
0
    def test_update_raises_value_error_if_variable_type_is_wrong(self) -> None:
        """`update` raises `ValueError` when a variable's dtype is wrong."""
        self._push_nested_data()

        container = reverb_variable_container.ReverbVariableContainer(
            self._server_address)
        # The signature declares the leading scalar as `tf.int64`, but it is
        # deliberately created as `tf.int32` here.
        wrong_scalar = tf.Variable(-1, dtype=tf.int32, shape=())
        var1 = tf.Variable([0, 0], dtype=tf.float64, shape=(2, ))
        var2 = tf.Variable([[0], [0]], dtype=tf.int32, shape=(2, 1))
        variables_with_wrong_type = (wrong_scalar, {
            'var1': (var1, ),
            'var2': var2,
        })
        with self.assertRaises(ValueError):
            container.update(variables_with_wrong_type)
Beispiel #8
0
 def test_init_raises_value_error_if_max_size_is_different_than_one(
         self) -> None:
     """Constructing the container fails when the table's max_size is not 1.

     A dedicated server is created with `max_size=2`. It is stopped in a
     `finally` clause, so it is shut down even if the expected `ValueError`
     is not raised and the assertion fails.
     """
     server, server_address = _create_server(max_size=2)
     try:
         with self.assertRaises(ValueError):
             reverb_variable_container.ReverbVariableContainer(server_address)
     finally:
         # Without `finally`, a failing assertion would leak the server.
         server.stop()
Beispiel #9
0
 def test_init_raises_type_error_if_no_signature_of_a_table(self) -> None:
     """Constructing the container fails when the table has no signature.

     A dedicated server is created with `signature=None`. It is stopped in a
     `finally` clause, so it is shut down even if the expected `TypeError` is
     not raised and the assertion fails.
     """
     server, server_address = _create_server(signature=None)
     try:
         with self.assertRaises(TypeError):
             reverb_variable_container.ReverbVariableContainer(server_address)
     finally:
         # Without `finally`, a failing assertion would leak the server.
         server.stop()
Beispiel #10
0
 def test_init_raises_key_error_if_undefined_table_passed(self) -> None:
     """Constructing the container fails when the expected table is absent.

     A dedicated server is created with only a `'no_variables_table'` table.
     It is stopped in a `finally` clause, so it is shut down even if the
     expected `KeyError` is not raised and the assertion fails.
     """
     server, server_address = _create_server(table='no_variables_table')
     try:
         with self.assertRaises(KeyError):
             reverb_variable_container.ReverbVariableContainer(server_address)
     finally:
         # Without `finally`, a failing assertion would leak the server.
         server.stop()
Beispiel #11
0
 def test_pull_raises_key_error_on_unknown_table(self) -> None:
     """`pull` raises `KeyError` when asked for an unknown table name."""
     container = reverb_variable_container.ReverbVariableContainer(
         self._server_address)
     missing_table = 'unknown_table'
     with self.assertRaises(KeyError):
         container.pull(missing_table)
Beispiel #12
0
 def test_push_raises_error_if_variable_struct_not_match(self) -> None:
     """`push` rejects a value whose nesting does not match the signature."""
     container = reverb_variable_container.ReverbVariableContainer(
         self._server_address)
     # A lone scalar variable instead of the expected nested structure.
     scalar_only = tf.Variable(1)
     with self.assertRaises(tf.errors.InvalidArgumentError):
         container.push(scalar_only)