def test_state_with_new_model_weights(self):
    """Tests state_with_new_model_weights on an anonymous-tuple server state.

    Matching replacement weights must be installed into the returned state.
    Replacements whose element types, shapes, list lengths, or container kind
    do not match the original weights must raise TypeError with the expected
    message.
    """
    initial_weights = model_utils.ModelWeights(
        trainable=[np.array([1.0, 2.0]), np.array([[1.0]])],
        non_trainable=[np.array(1)])
    server_state = optimizer_utils.ServerState(
        model=initial_weights,
        optimizer_state=[],
        delta_aggregate_state=tf.constant(0),
        model_broadcast_state=tf.constant(0))
    state = anonymous_tuple.from_container(server_state, recursive=True)

    # Happy path: same structure, dtypes and shapes, different values.
    replacement_trainable = [np.array([3.0, 3.0]), np.array([[3.0]])]
    replacement_non_trainable = [np.array(3)]
    updated = optimizer_utils.state_with_new_model_weights(
        state,
        trainable_weights=replacement_trainable,
        non_trainable_weights=replacement_non_trainable)
    expected_trainable = [np.array([3.0, 3.0]), np.array([[3.0]])]
    self.assertAllClose(updated.model.trainable, expected_trainable)
    self.assertAllClose(updated.model.non_trainable, [3])

    # Each entry: (expected error regex, trainable weights, non-trainable
    # weights). Every one of these must be rejected with a TypeError.
    rejected_inputs = (
        # int trainable leaf / float non-trainable leaf vs. original dtypes.
        ('tensor type',
         [np.array([3.0, 3.0]), np.array([[3]])],
         [np.array(3.0)]),
        # Second trainable leaf has rank 1 instead of rank 2.
        ('tensor type',
         [np.array([3.0, 3.0]), np.array([3.0])],
         [np.array(3)]),
        # One trainable leaf instead of two.
        ('different lengths',
         [np.array([3.0, 3.0])],
         [np.array(3)]),
        # A dict where a list is expected.
        ('cannot be handled',
         {'a': np.array([3.0, 3.0])},
         [np.array(3)]),
    )
    for error_pattern, bad_trainable, bad_non_trainable in rejected_inputs:
      with self.assertRaisesRegex(TypeError, error_pattern):
        optimizer_utils.state_with_new_model_weights(
            state,
            trainable_weights=bad_trainable,
            non_trainable_weights=bad_non_trainable)
# Example #2
    def test_state_with_new_model_weights_failure(self, new_trainable,
                                                  new_non_trainable,
                                                  expected_err_msg):
        """Checks that mismatched replacement weights raise ``TypeError``.

        Args:
          new_trainable: Replacement trainable weights, or ``None`` to reuse
            the original trainable weights unchanged.
          new_non_trainable: Replacement non-trainable weights, or ``None``
            to reuse the original non-trainable weights unchanged.
          expected_err_msg: Regex that must match the raised error message.
        """
        trainable = [np.array([1.0, 2.0]), np.array([[1.0]]), np.int64(3)]
        non_trainable = [np.array(1), b'bytes type', 5, 2.0]
        state = optimizer_utils.ServerState(
            model=model_utils.ModelWeights(trainable=trainable,
                                           non_trainable=non_trainable),
            optimizer_state=[],
            delta_aggregate_state=tf.constant(0),
            model_broadcast_state=tf.constant(0))

        # ``None`` means "keep the original weights", so each parameterized
        # case exercises a mismatch in only one of the two weight lists.
        new_trainable = trainable if new_trainable is None else new_trainable
        # Default the *replacement* argument; the original weights are left
        # untouched so a ``None`` here is never forwarded to the function.
        new_non_trainable = (
            non_trainable if new_non_trainable is None else new_non_trainable)

        with self.assertRaisesRegex(TypeError, expected_err_msg):
            optimizer_utils.state_with_new_model_weights(
                state,
                trainable_weights=new_trainable,
                non_trainable_weights=new_non_trainable)
# Example #3
    def test_state_with_model_weights_success(self):
        """Checks that structurally matching weights replace the old ones.

        Both weight lists mix numpy arrays with plain Python / numpy scalars
        (bytes, int, float, np.int64); the replacements keep each element's
        kind and shape while changing its value.
        """
        server_state = optimizer_utils.ServerState(
            model=model_utils.ModelWeights(
                trainable=[np.array([1.0, 2.0]), np.array([[1.0]]),
                           np.int64(3)],
                non_trainable=[np.array(1), b'bytes type', 5, 2.0]),
            optimizer_state=[],
            delta_aggregate_state=tf.constant(0),
            model_broadcast_state=tf.constant(0))

        replacement_trainable = [
            np.array([3.0, 3.0]),
            np.array([[3.0]]),
            np.int64(4),
        ]
        replacement_non_trainable = [np.array(3), b'bytes check', 6, 3.0]

        updated = optimizer_utils.state_with_new_model_weights(
            server_state,
            trainable_weights=replacement_trainable,
            non_trainable_weights=replacement_non_trainable)

        self.assertAllClose(updated.model.trainable, replacement_trainable)
        # assertEqual (not assertAllClose) because the list holds bytes.
        self.assertEqual(updated.model.non_trainable,
                         replacement_non_trainable)
  def test_state_with_new_model_weights(self):
    """Tests state_with_new_model_weights on OrderedDict-structured weights.

    Positional replacement weights must be written back under the original
    keys, in the original key order, and mismatched dtypes, shapes, or list
    lengths must raise ValueError.
    """
    # Keys are deliberately not sorted ('b' before 'a') so that the key-order
    # assertions below would catch any accidental re-sorting.
    trainable = [('b', np.array([1.0, 2.0])), ('a', np.array([[1.0]]))]
    non_trainable = [('c', np.array(1))]
    state = anonymous_tuple.from_container(
        optimizer_utils.ServerState(
            model=model_utils.ModelWeights(
                trainable=collections.OrderedDict(trainable),
                non_trainable=collections.OrderedDict(non_trainable)),
            optimizer_state=[],
            delta_aggregate_state=tf.constant(0),
            model_broadcast_state=tf.constant(0)),
        recursive=True)

    new_state = optimizer_utils.state_with_new_model_weights(
        state,
        trainable_weights=[np.array([3.0, 3.0]),
                           np.array([[3.0]])],
        non_trainable_weights=[np.array(3)])
    self.assertEqual(list(new_state.model.trainable.keys()), ['b', 'a'])
    self.assertEqual(list(new_state.model.non_trainable.keys()), ['c'])
    self.assertAllClose(new_state.model.trainable['b'], [3.0, 3.0])
    self.assertAllClose(new_state.model.trainable['a'], [[3.0]])
    self.assertAllClose(new_state.model.non_trainable['c'], 3)

    # assertRaisesRegex (not the removed ``assertRaisesRegexp`` alias) so the
    # test runs on Python 3.12+.
    with self.assertRaisesRegex(ValueError, 'dtype'):
      optimizer_utils.state_with_new_model_weights(
          state,
          trainable_weights=[np.array([3.0, 3.0]),
                             np.array([[3]])],
          non_trainable_weights=[np.array(3.0)])

    with self.assertRaisesRegex(ValueError, 'shape'):
      optimizer_utils.state_with_new_model_weights(
          state,
          trainable_weights=[np.array([3.0, 3.0]),
                             np.array([3.0])],
          non_trainable_weights=[np.array(3)])

    with self.assertRaisesRegex(ValueError, 'Lengths differ'):
      optimizer_utils.state_with_new_model_weights(
          state,
          trainable_weights=[np.array([3.0, 3.0])],
          non_trainable_weights=[np.array(3)])