def test_server_eager_mode(self, optimizer_fn, updated_val,
                           num_optimizer_vars):
        """Checks server init and a single model update in eager mode."""
        model_fn = lambda: model_examples.TrainableLinearRegression(
            feature_dim=2)

        server_state = optimizer_utils.server_init(model_fn, optimizer_fn, (),
                                                   ())
        # A freshly initialized model starts with all-zero weights.
        trainable = server_state.model.trainable
        self.assertAllClose(trainable['a'].numpy(), np.array([[0.0], [0.0]]))
        self.assertEqual(trainable['b'].numpy(), 0.0)
        self.assertEqual(server_state.model.non_trainable['c'].numpy(), 0.0)
        self.assertLen(server_state.optimizer_state, num_optimizer_vars)

        delta = tensor_utils.to_odict({
            'a': tf.constant([[1.0], [0.0]]),
            'b': tf.constant(1.0),
        })
        server_state = optimizer_utils.server_update_model(
            server_state, delta, model_fn, optimizer_fn)

        trainable = server_state.model.trainable
        # For SGD: learning_rate=0.1, update=[1.0, 0.0], initial model=[0.0, 0.0],
        # so updated_val=0.1
        self.assertAllClose(trainable['a'].numpy(), [[updated_val], [0.0]])
        self.assertAllClose(trainable['b'].numpy(), updated_val)
        self.assertEqual(server_state.model.non_trainable['c'].numpy(), 0.0)
    def test_server_eager_mode(self, optimizer_fn, updated_val,
                               num_optimizer_vars):
        """Verifies server init and one update, reading variables via evaluate."""
        model_fn = lambda: model_examples.TrainableLinearRegression(
            feature_dim=2)

        server_state = optimizer_utils.server_init(model_fn, optimizer_fn, (),
                                                   ())
        initial_model = self.evaluate(server_state.model)
        initial_trainable = initial_model.trainable
        # The freshly initialized model is all zeros.
        self.assertLen(initial_trainable, 2)
        self.assertAllClose(initial_trainable['a'], [[0.0], [0.0]])
        self.assertEqual(initial_trainable['b'], 0.0)
        self.assertEqual(initial_model.non_trainable, {'c': 0.0})
        self.assertLen(server_state.optimizer_state, num_optimizer_vars)

        delta = collections.OrderedDict([
            ('a', tf.constant([[1.0], [0.0]])),
            ('b', tf.constant(1.0)),
        ])
        server_state = optimizer_utils.server_update_model(
            server_state, delta, model_fn, optimizer_fn)

        updated_model = self.evaluate(server_state.model)
        updated_trainable = updated_model.trainable
        # For SGD: learning_rate=0.1, update=[1.0, 0.0], initial model=[0.0, 0.0],
        # so updated_val=0.1
        self.assertLen(updated_trainable, 2)
        self.assertAllClose(updated_trainable['a'], [[updated_val], [0.0]])
        self.assertAllClose(updated_trainable['b'], updated_val)
        self.assertEqual(updated_model.non_trainable, {'c': 0.0})
    def test_server_graph_mode(self):
        """Exercises server init and a model update under graph-mode sessions."""
        optimizer_fn = lambda: gradient_descent.SGD(learning_rate=0.1)
        model_fn = lambda: model_examples.TrainableLinearRegression(
            feature_dim=2)

        # Explicitly entering a graph as a default enables graph-mode.
        with tf.Graph().as_default() as graph:
            server_state_op = optimizer_utils.server_init(
                model_fn, optimizer_fn, (), ())
            init_op = tf.group(tf.global_variables_initializer(),
                               tf.local_variables_initializer())
            graph.finalize()
            with self.session() as sess:
                sess.run(init_op)
                server_state = sess.run(server_state_op)
        trainable = server_state.model.trainable
        # Initialization yields an all-zero model and a single optimizer slot.
        self.assertAllClose(trainable['a'], [[0.0], [0.0]])
        self.assertEqual(trainable['b'], 0.0)
        self.assertEqual(server_state.model.non_trainable['c'], 0.0)
        self.assertEqual(server_state.optimizer_state, [0.0])

        with tf.Graph().as_default() as graph:
            # N.B. Must use a fresh graph so variable names are the same.
            weights_delta = tensor_utils.to_odict({
                'a': tf.constant([[1.0], [0.0]]),
                'b': tf.constant(2.0),
            })
            update_op = optimizer_utils.server_update_model(
                server_state, weights_delta, model_fn, optimizer_fn)
            init_op = tf.group(tf.global_variables_initializer(),
                               tf.local_variables_initializer())
            graph.finalize()
            with self.session() as sess:
                sess.run(init_op)
                server_state = sess.run(update_op)
        trainable = server_state.model.trainable
        # learning_rate=0.1, update is [1.0, 0.0], initial model is [0.0, 0.0].
        self.assertAllClose(trainable['a'], [[0.1], [0.0]])
        self.assertAllClose(trainable['b'], 0.2)
        self.assertEqual(server_state.model.non_trainable['c'], 0.0)
# Example #4
    def test_server_eager_mode(self, optimizer_fn, updated_val,
                               num_optimizer_vars):
        """Checks init and one update for the list-structured LinearRegression."""
        model_fn = lambda: model_examples.LinearRegression(feature_dim=2)

        server_state = optimizer_utils.server_init(model_fn, optimizer_fn, (),
                                                   ())
        before = self.evaluate(server_state.model)
        # The freshly initialized model is all zeros.
        self.assertLen(before.trainable, 2)
        self.assertAllClose(before.trainable, [np.zeros((2, 1)), 0.0])
        self.assertAllClose(before.non_trainable, [0.0])
        self.assertLen(server_state.optimizer_state, num_optimizer_vars)

        delta = [tf.constant([[1.0], [0.0]]), tf.constant(1.0)]
        server_state = optimizer_utils.server_update_model(
            server_state, delta, model_fn, optimizer_fn)

        after = self.evaluate(server_state.model)
        # For SGD: learning_rate=0.1, update=[1.0, 0.0], initial model=[0.0, 0.0],
        # so updated_val=0.1
        self.assertLen(after.trainable, 2)
        self.assertAllClose(after.trainable, [[[updated_val], [0.0]], updated_val])
        self.assertAllClose(after.non_trainable, [0.0])