def test_state_with_new_model_weights(self):
    """Swaps list-structured model weights into a recursively converted state.

    Builds a ``ServerState`` whose model weights are plain lists, converts it
    with ``anonymous_tuple.from_container(..., recursive=True)``, replaces the
    weights, and checks the new values appear in ``new_state.model``. Then
    checks that malformed replacements raise ``TypeError`` with the expected
    message.

    NOTE(review): another method named ``test_state_with_new_model_weights``
    is defined later in this file. If both live in the same class, the later
    definition replaces this one and this test never runs; consider renaming
    or deleting one of them.
    """

    def updated_trainable():
        # Fresh arrays on every call so the argument passed in and the value
        # compared against are distinct objects.
        return [np.array([3.0, 3.0]), np.array([[3.0]])]

    initial_weights = model_utils.ModelWeights(
        trainable=[np.array([1.0, 2.0]), np.array([[1.0]])],
        non_trainable=[np.array(1)])
    server_state = optimizer_utils.ServerState(
        model=initial_weights,
        optimizer_state=[],
        delta_aggregate_state=tf.constant(0),
        model_broadcast_state=tf.constant(0))
    state = anonymous_tuple.from_container(server_state, recursive=True)

    new_state = optimizer_utils.state_with_new_model_weights(
        state,
        trainable_weights=updated_trainable(),
        non_trainable_weights=[np.array(3)])
    self.assertAllClose(new_state.model.trainable, updated_trainable())
    self.assertAllClose(new_state.model.non_trainable, [3])

    # Each entry: (expected error regex, trainable weights, non-trainable
    # weights). The cases are checked in order.
    malformed_updates = [
        # Integer second trainable tensor (original is float) and a float
        # non-trainable tensor (original is integer).
        ('tensor type',
         [np.array([3.0, 3.0]), np.array([[3]])],
         [np.array(3.0)]),
        # Second trainable tensor is rank 1 instead of rank 2.
        ('tensor type',
         [np.array([3.0, 3.0]), np.array([3.0])],
         [np.array(3)]),
        # One trainable tensor where the state holds two.
        ('different lengths',
         [np.array([3.0, 3.0])],
         [np.array(3)]),
        # A dict where the state holds a list.
        ('cannot be handled',
         {'a': np.array([3.0, 3.0])},
         [np.array(3)]),
    ]
    for pattern, bad_trainable, bad_non_trainable in malformed_updates:
        with self.assertRaisesRegex(TypeError, pattern):
            optimizer_utils.state_with_new_model_weights(
                state,
                trainable_weights=bad_trainable,
                non_trainable_weights=bad_non_trainable)
def test_state_with_new_model_weights_failure(self, new_trainable,
                                              new_non_trainable,
                                              expected_err_msg):
    """Checks that incompatible replacement weights raise ``TypeError``.

    The arguments are presumably supplied by a parameterization decorator
    that is not visible here.

    Args:
      new_trainable: replacement trainable weights, or ``None`` to reuse the
        state's original trainable weights.
      new_non_trainable: replacement non-trainable weights, or ``None`` to
        reuse the state's original non-trainable weights.
      expected_err_msg: regex that the raised ``TypeError`` message must match.
    """
    trainable = [np.array([1.0, 2.0]), np.array([[1.0]]), np.int64(3)]
    non_trainable = [np.array(1), b'bytes type', 5, 2.0]

    state = optimizer_utils.ServerState(
        model=model_utils.ModelWeights(
            trainable=trainable, non_trainable=non_trainable),
        optimizer_state=[],
        delta_aggregate_state=tf.constant(0),
        model_broadcast_state=tf.constant(0))

    # ``None`` means "keep the original weights" for that side, so each case
    # only has to spell out the side it is trying to break. The previous code
    # reassigned ``non_trainable`` to itself here, which left
    # ``new_non_trainable`` as ``None`` when it was passed through below.
    if new_trainable is None:
        new_trainable = trainable
    if new_non_trainable is None:
        new_non_trainable = non_trainable

    with self.assertRaisesRegex(TypeError, expected_err_msg):
        optimizer_utils.state_with_new_model_weights(
            state,
            trainable_weights=new_trainable,
            non_trainable_weights=new_non_trainable)
def test_state_with_model_weights_success(self):
    """Replaces mixed-type model weights and checks the new values round-trip.

    The weights mix numpy arrays, a numpy int64 scalar, bytes, and plain
    Python numbers. Trainable values are compared with ``assertAllClose``.
    Non-trainable values (which include ``bytes``) are compared with
    ``assertEqual``.
    """
    starting_weights = model_utils.ModelWeights(
        trainable=[np.array([1.0, 2.0]), np.array([[1.0]]), np.int64(3)],
        non_trainable=[np.array(1), b'bytes type', 5, 2.0])
    state = optimizer_utils.ServerState(
        model=starting_weights,
        optimizer_state=[],
        delta_aggregate_state=tf.constant(0),
        model_broadcast_state=tf.constant(0))

    replacement_trainable = [
        np.array([3.0, 3.0]),
        np.array([[3.0]]),
        np.int64(4),
    ]
    replacement_non_trainable = [np.array(3), b'bytes check', 6, 3.0]

    updated_model = optimizer_utils.state_with_new_model_weights(
        state,
        trainable_weights=replacement_trainable,
        non_trainable_weights=replacement_non_trainable).model

    self.assertAllClose(updated_model.trainable, replacement_trainable)
    self.assertEqual(updated_model.non_trainable, replacement_non_trainable)
def test_state_with_new_model_weights(self):
    """Swaps list-valued weights into a state whose model weights are dicts.

    The state's trainable and non-trainable weights are ``OrderedDict``s and
    the replacements are plain lists. Checks three things:
      - the keys and their order ('b' before 'a') are preserved;
      - each list element lands under the key in the same position;
      - malformed replacements raise ``ValueError`` mentioning dtype, shape,
        or differing lengths.

    NOTE(review): an earlier method with this same name appears in this file.
    If both are in the same class, this later definition replaces the earlier
    one, so only this version runs.
    """
    trainable = [('b', np.array([1.0, 2.0])), ('a', np.array([[1.0]]))]
    non_trainable = [('c', np.array(1))]
    state = anonymous_tuple.from_container(
        optimizer_utils.ServerState(
            model=model_utils.ModelWeights(
                trainable=collections.OrderedDict(trainable),
                non_trainable=collections.OrderedDict(non_trainable)),
            optimizer_state=[],
            delta_aggregate_state=tf.constant(0),
            model_broadcast_state=tf.constant(0)),
        recursive=True)

    new_state = optimizer_utils.state_with_new_model_weights(
        state,
        trainable_weights=[np.array([3.0, 3.0]), np.array([[3.0]])],
        non_trainable_weights=[np.array(3)])
    self.assertEqual(list(new_state.model.trainable.keys()), ['b', 'a'])
    self.assertEqual(list(new_state.model.non_trainable.keys()), ['c'])
    self.assertAllClose(new_state.model.trainable['b'], [3.0, 3.0])
    self.assertAllClose(new_state.model.trainable['a'], [[3.0]])
    self.assertAllClose(new_state.model.non_trainable['c'], 3)

    # ``assertRaisesRegex`` replaces the ``assertRaisesRegexp`` alias, which was
    # deprecated and then removed from ``unittest`` in Python 3.12.

    # Integer trainable tensor where the state holds a float one.
    with self.assertRaisesRegex(ValueError, 'dtype'):
        optimizer_utils.state_with_new_model_weights(
            state,
            trainable_weights=[np.array([3.0, 3.0]), np.array([[3]])],
            non_trainable_weights=[np.array(3.0)])
    # Rank-1 tensor where the state holds a rank-2 one.
    with self.assertRaisesRegex(ValueError, 'shape'):
        optimizer_utils.state_with_new_model_weights(
            state,
            trainable_weights=[np.array([3.0, 3.0]), np.array([3.0])],
            non_trainable_weights=[np.array(3)])
    # One trainable tensor where the state holds two.
    with self.assertRaisesRegex(ValueError, 'Lengths differ'):
        optimizer_utils.state_with_new_model_weights(
            state,
            trainable_weights=[np.array([3.0, 3.0])],
            non_trainable_weights=[np.array(3)])