def testCreateOnecloneWithPS(self):
        """Optimize one BatchNormClassifier clone deployed with one PS task.

        Builds the clone under DeploymentConfig(num_clones=1, num_ps_tasks=1),
        checks the variable and update-op counts, runs optimize_clones, and
        asserts every gradient is on the worker GPU while every variable is
        on the PS task's CPU.
        """
        graph = tf.Graph()
        with graph.as_default():
            tf.set_random_seed(0)
            tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
            tf_labels = tf.constant(self._labels, dtype=tf.float32)

            model_fn = BatchNormClassifier
            model_args = (tf_inputs, tf_labels)
            deploy_config = model_deploy.DeploymentConfig(num_clones=1,
                                                          num_ps_tasks=1)

            self.assertEqual(framework.get_variables(), [])
            clones = model_deploy.create_clones(deploy_config, model_fn,
                                                model_args)
            self.assertEqual(len(framework.get_variables()), 5)
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            self.assertEqual(len(update_ops), 2)

            optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
            total_loss, grads_and_vars = model_deploy.optimize_clones(
                clones, optimizer)
            self.assertEqual(len(grads_and_vars),
                             len(tf.trainable_variables()))
            self.assertEqual(total_loss.op.name, 'total_loss')
            # Distinct loop names keep the graph handle above from being
            # rebound by the (gradient, variable) pairs.
            for grad, var in grads_and_vars:
                self.assertDeviceEqual(grad.device, '/job:worker/device:GPU:0')
                self.assertDeviceEqual(var.device, '/job:ps/task:0/CPU:0')
    def testCreateMulticloneWithPS(self):
        """Create two BatchNormClassifier clones spread over two PS tasks.

        Checks that the i-th model variable is placed on PS task
        ``i % num_ps_tasks`` (and its value is read on that same device), and
        that each clone gets its own ``clone_<i>/`` scope on worker GPU ``i``.
        """
        g = tf.Graph()
        with g.as_default():
            tf.set_random_seed(0)
            tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
            tf_labels = tf.constant(self._labels, dtype=tf.float32)

            model_fn = BatchNormClassifier
            clone_args = (tf_inputs, tf_labels)
            num_clones = 2
            num_ps_tasks = 2
            deploy_config = model_deploy.DeploymentConfig(
                num_clones=num_clones, num_ps_tasks=num_ps_tasks)

            self.assertEqual(framework.get_variables(), [])
            clones = model_deploy.create_clones(deploy_config, model_fn,
                                                clone_args)
            model_vars = framework.get_variables()
            self.assertEqual(len(model_vars), 5)
            for i, v in enumerate(model_vars):
                # Derive the expected task from num_ps_tasks so the check
                # tracks the config instead of a duplicated literal.
                t = i % num_ps_tasks
                self.assertDeviceEqual(v.device,
                                       '/job:ps/task:%d/device:CPU:0' % t)
                self.assertDeviceEqual(v.device, v.value().device)
            self.assertEqual(len(clones), num_clones)
            for i, clone in enumerate(clones):
                self.assertEqual(
                    clone.outputs.op.name,
                    'clone_%d/BatchNormClassifier/fully_connected/Sigmoid' % i)
                self.assertEqual(clone.scope, 'clone_%d/' % i)
                self.assertDeviceEqual(clone.device,
                                       '/job:worker/device:GPU:%d' % i)
    def testCreateMulticlone(self):
        """Create four BatchNormClassifier clones with no PS tasks.

        All five variables are expected on CPU:0. Each clone is expected to
        have a ``clone_<i>/`` scope on ``GPU:<i>`` and two update ops
        collected under that scope.
        """
        num_clones = 4
        graph = tf.Graph()
        with graph.as_default():
            tf.set_random_seed(0)
            inputs_t = tf.constant(self._inputs, dtype=tf.float32)
            labels_t = tf.constant(self._labels, dtype=tf.float32)
            config = model_deploy.DeploymentConfig(num_clones=num_clones)

            self.assertEqual(framework.get_variables(), [])
            clones = model_deploy.create_clones(
                config, BatchNormClassifier, (inputs_t, labels_t))

            all_vars = framework.get_variables()
            self.assertEqual(len(all_vars), 5)
            for var in all_vars:
                self.assertDeviceEqual(var.device, 'CPU:0')
                self.assertDeviceEqual(var.value().device, 'CPU:0')

            self.assertEqual(len(clones), num_clones)
            for idx, clone in enumerate(clones):
                expected_output = (
                    'clone_%d/BatchNormClassifier/fully_connected/Sigmoid' % idx)
                self.assertEqual(clone.outputs.op.name, expected_output)
                # Only the update ops created inside this clone's scope.
                clone_update_ops = tf.get_collection(
                    tf.GraphKeys.UPDATE_OPS, clone.scope)
                self.assertEqual(len(clone_update_ops), 2)
                self.assertEqual(clone.scope, 'clone_%d/' % idx)
                self.assertDeviceEqual(clone.device, 'GPU:%d' % idx)
    def testCreateOnecloneWithPS(self):
        """Create one BatchNormClassifier clone with one PS task.

        The single clone is expected to have an empty scope (so its output op
        name has no ``clone_0/`` prefix) and sit on the worker GPU, while all
        five variables live on PS task 0.

        NOTE(review): another method named testCreateOnecloneWithPS appears
        in this chunk; if both are defined in the same class, only the later
        definition runs — confirm which class each belongs to.
        """
        graph = tf.Graph()
        with graph.as_default():
            tf.set_random_seed(0)
            inputs_t = tf.constant(self._inputs, dtype=tf.float32)
            labels_t = tf.constant(self._labels, dtype=tf.float32)
            config = model_deploy.DeploymentConfig(num_clones=1,
                                                   num_ps_tasks=1)

            self.assertEqual(framework.get_variables(), [])
            clones = model_deploy.create_clones(
                config, BatchNormClassifier, (inputs_t, labels_t))

            self.assertEqual(len(clones), 1)
            (clone,) = clones
            self.assertEqual(clone.outputs.op.name,
                             'BatchNormClassifier/fully_connected/Sigmoid')
            self.assertDeviceEqual(clone.device, '/job:worker/device:GPU:0')
            self.assertEqual(clone.scope, '')

            ps_vars = framework.get_variables()
            self.assertEqual(len(ps_vars), 5)
            for var in ps_vars:
                self.assertDeviceEqual(var.device, '/job:ps/task:0/CPU:0')
                self.assertDeviceEqual(var.device, var.value().device)
    def testCreateLogisticClassifier(self):
        """Create one LogisticClassifier clone without a PS.

        Expects two variables on CPU:0, an unscoped clone on GPU:0, exactly
        one registered loss, and an empty UPDATE_OPS collection.
        """
        graph = tf.Graph()
        with graph.as_default():
            tf.set_random_seed(0)
            inputs_t = tf.constant(self._inputs, dtype=tf.float32)
            labels_t = tf.constant(self._labels, dtype=tf.float32)
            config = model_deploy.DeploymentConfig(num_clones=1)

            self.assertEqual(framework.get_variables(), [])
            clone = model_deploy.create_clones(
                config, LogisticClassifier, (inputs_t, labels_t))[0]

            model_vars = framework.get_variables()
            self.assertEqual(len(model_vars), 2)
            for var in model_vars:
                self.assertDeviceEqual(var.device, 'CPU:0')
                self.assertDeviceEqual(var.value().device, 'CPU:0')

            self.assertEqual(clone.outputs.op.name,
                             'LogisticClassifier/fully_connected/Sigmoid')
            self.assertEqual(clone.scope, '')
            self.assertDeviceEqual(clone.device, 'GPU:0')
            self.assertEqual(len(tf.losses.get_losses()), 1)
            self.assertEqual(tf.get_collection(tf.GraphKeys.UPDATE_OPS), [])