# Imports assumed by both examples below (TF 1.x contrib layout).
from tensorflow.contrib.opt.python.training import model_average_optimizer
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import device_setter
from tensorflow.python.training import gradient_descent
from tensorflow.python.training import server_lib
from tensorflow.python.training import training


def _get_workers(num_workers, steps, workers):
    sessions = []
    graphs = []
    train_ops = []
    for worker_id in range(num_workers):
        graph = ops.Graph()
        is_chief = (worker_id == 0)
        with graph.as_default():
            worker_device = "/job:worker/task:%d/cpu:0" % (worker_id)
            # The custom getter routes get_variable() so each local variable
            # also gets a globally shared copy used for model averaging.
            ma_custom = model_average_optimizer.ModelAverageCustomGetter(
                worker_device=worker_device)
            with variable_scope.variable_scope(
                    "", custom_getter=ma_custom), ops.device(
                        device_setter.replica_device_setter(
                            worker_device=worker_device,
                            ps_device="/job:ps/task:0/cpu:0",
                            ps_tasks=1)):

                global_step = variables.Variable(0,
                                                 name="global_step",
                                                 trainable=False)
                var_0 = variable_scope.get_variable(initializer=0.0, name="v0")
                var_1 = variable_scope.get_variable(initializer=1.0, name="v1")

            with ops.device("/job:worker/task:" + str(worker_id)):
                if worker_id == 0:
                    grads_0 = constant_op.constant(-1.0)
                    grads_1 = constant_op.constant(-1.0)
                else:
                    grads_0 = constant_op.constant(-2.0)
                    grads_1 = constant_op.constant(-2.0)
                sgd_opt = gradient_descent.GradientDescentOptimizer(1.0)
                opt = model_average_optimizer.ModelAverageOptimizer(
                    opt=sgd_opt,
                    num_worker=num_workers,
                    ma_custom_getter=ma_custom,
                    is_chief=is_chief,
                    interval_steps=steps)
                train_op = [
                    opt.apply_gradients([(grads_0, var_0), (grads_1, var_1)],
                                        global_step)
                ]
            # The hook initializes the global copies and performs the
            # periodic averaging across workers.
            ma_hook = opt.make_session_run_hook()
            # Creates MonitoredSession
            sess = training.MonitoredTrainingSession(workers[worker_id].target,
                                                     hooks=[ma_hook])

        sessions.append(sess)
        graphs.append(graph)
        train_ops.append(train_op)
    return sessions, graphs, train_ops
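
The helper above builds one graph, one MonitoredSession, and one train op per
worker. Below is a minimal sketch of how it might be driven in a test; it
assumes an in-process cluster from tf.test.create_local_cluster, and the
worker/step counts are illustrative, not from the original.

import tensorflow as tf

num_workers, interval_steps = 2, 2
# Two in-process worker servers plus one PS server (grpc targets).
workers, _ = tf.test.create_local_cluster(num_workers=num_workers, num_ps=1)
sessions, graphs, train_ops = _get_workers(num_workers, interval_steps,
                                           workers)
# Each worker applies its constant "gradient" with local SGD; roughly every
# interval_steps local steps, the optimizer averages the local copies into
# the global variables on the PS.
sessions[0].run(train_ops[0])
sessions[1].run(train_ops[1])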
Example #2
  def testPS2TasksWithClusterSpecClass(self):
    cluster_spec = server_lib.ClusterSpec({
        "ps": ["ps0:2222", "ps1:2222"],
        "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
    })
    worker_device = "/job:worker/task:0"
    ma_custom = model_average_optimizer.ModelAverageCustomGetter(
        worker_device=worker_device)
    with ops.device(
        device_setter.replica_device_setter(cluster=cluster_spec,
                                            worker_device=worker_device,
                                            ps_device="/job:ps")), \
         variable_scope.variable_scope("", custom_getter=ma_custom):
      v = variable_scope.get_variable(initializer=[1, 2], name="v")
      w = variable_scope.get_variable(initializer=[2, 1], name="w")
      # The custom getter records a local-to-global mapping; the global
      # copies are placed round-robin across the two PS tasks.
      v_g, w_g = ma_custom._local_2_global[v], ma_custom._local_2_global[w]
      self.assertDeviceEqual("/job:worker/task:0", v.device)
      self.assertDeviceEqual("/job:ps/task:0", v_g.device)
      self.assertDeviceEqual("/job:worker/task:0", w.device)
      self.assertDeviceEqual("/job:ps/task:1", w_g.device)