# TensorFlow 1.x imports used by this excerpt; the contrib path matches
# tensorflow/contrib/opt at the time these tests were written.
from tensorflow.contrib.opt.python.training.elastic_average_optimizer import (
    ElasticAverageCustomGetter, ElasticAverageOptimizer, GLOBAL_VARIABLE_NAME)
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import device_setter
from tensorflow.python.training import gradient_descent
from tensorflow.python.training import server_lib
from tensorflow.python.training import training
from tensorflow.python.training import training_util


def _get_workers(num_workers, period, workers, moving_rate):
    """Builds a graph, a monitored session and an EASGD train op per worker."""
    sessions = []
    graphs = []
    train_ops = []
    for worker_id in range(num_workers):
        graph = ops.Graph()
        is_chief = (worker_id == 0)
        with graph.as_default():
            worker_device = "/job:worker/task:%d/cpu:0" % (worker_id)
            ea_custom = ElasticAverageCustomGetter(
                worker_device=worker_device)
            with variable_scope.variable_scope(
                    "", custom_getter=ea_custom), ops.device(
                        device_setter.replica_device_setter(
                            worker_device=worker_device,
                            ps_device="/job:ps/task:0/cpu:0",
                            ps_tasks=1)):
                global_step = variables.Variable(0,
                                                 name="global_step",
                                                 trainable=False)
                var_0 = variable_scope.get_variable(initializer=0.0, name="v0")
                var_1 = variable_scope.get_variable(initializer=1.0, name="v1")

            with ops.device("/job:worker/task:" + str(worker_id)):
                grads_0 = constant_op.constant(-1.0)
                grads_1 = constant_op.constant(-1.0)

                sgd_opt = gradient_descent.GradientDescentOptimizer(1.0)
                opt = ElasticAverageOptimizer(opt=sgd_opt,
                                              num_worker=num_workers,
                                              moving_rate=moving_rate,
                                              communication_period=period,
                                              ea_custom_getter=ea_custom)
                train_op = [
                    opt.apply_gradients(([grads_0, var_0], [grads_1, var_1]),
                                        global_step)
                ]
                easgd_hook = opt.make_session_run_hook(is_chief, worker_id)
            # Creates MonitoredSession
            sess = training.MonitoredTrainingSession(workers[worker_id].target,
                                                     hooks=[easgd_hook])

        sessions.append(sess)
        graphs.append(graph)
        train_ops.append(train_op)

    return sessions, graphs, train_ops
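

# The sketches below are editor-added usage examples, not part of the original
# test file. `create_local_cluster` is a hypothetical helper mirroring the
# utilities TensorFlow's distributed tests commonly use: it boots in-process
# GRPC servers for the workers and parameter servers.
import portpicker  # pip package used by TF's own test utilities


def create_local_cluster(num_workers, num_ps, protocol="grpc"):
    """Hypothetical helper: starts a local in-process TF cluster."""
    worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)]
    ps_ports = [portpicker.pick_unused_port() for _ in range(num_ps)]
    cluster_spec = server_lib.ClusterSpec({
        "worker": ["localhost:%s" % port for port in worker_ports],
        "ps": ["localhost:%s" % port for port in ps_ports],
    })
    workers = [
        server_lib.Server(
            cluster_spec, job_name="worker", protocol=protocol,
            task_index=ix, start=True) for ix in range(num_workers)
    ]
    ps_servers = [
        server_lib.Server(
            cluster_spec, job_name="ps", protocol=protocol,
            task_index=ix, start=True) for ix in range(num_ps)
    ]
    return workers, ps_servers


def _example_one_worker_two_period():
    """Usage sketch: one worker, communication every second step."""
    num_workers = 1
    communication_period = 2
    workers, ps_servers = create_local_cluster(num_workers=num_workers,
                                               num_ps=1)

    # The [:3] slice also tolerates the extended _get_workers variant defined
    # later in this file, which returns swapping savers as a fourth element.
    sessions, graphs, train_ops = _get_workers(
        num_workers, communication_period, workers, moving_rate=1.0)[:3]

    var_0 = graphs[0].get_tensor_by_name("v0:0")
    var_0_g = graphs[0].get_tensor_by_name(GLOBAL_VARIABLE_NAME + "/v0:0")

    # First step: plain SGD on the local copy, 0.0 - 1.0 * (-1.0) = 1.0; the
    # center variable is untouched because the local step count has not yet
    # reached communication_period.
    sessions[0].run(train_ops[0])
    assert sessions[0].run(var_0) == 1.0
    assert sessions[0].run(var_0_g) == 0.0

    # Second step: a communication step, which pulls the local and center
    # copies toward each other with strength moving_rate.
    sessions[0].run(train_ops[0])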


class ElasticAverageCustomGetterTest(test.TestCase):
    """Verifies the device placement performed by ElasticAverageCustomGetter."""

    def testPS2TasksWithClusterSpecClass(self):
        cluster_spec = server_lib.ClusterSpec({
            "ps": ["ps0:2222", "ps1:2222"],
            "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
        })
        ea_custom = ElasticAverageCustomGetter(
            worker_device="/job:worker/task:0")
        with ops.device(
            device_setter.replica_device_setter(cluster=cluster_spec,
                                                worker_device="/job:worker/task:0",
                                                ps_device="/job:ps")), \
             variable_scope.variable_scope("", custom_getter=ea_custom):
            v = variable_scope.get_variable(initializer=[1, 2], name="v")
            w = variable_scope.get_variable(initializer=[2, 1], name="w")
            v_g, w_g = ea_custom._global_map[v], ea_custom._global_map[w]
            # Local copies stay on the worker; the global (center) copies
            # are placed round-robin across the ps tasks.
            self.assertDeviceEqual("/job:worker/task:0", v.device)
            self.assertDeviceEqual("/job:ps/task:0", v_g.device)
            self.assertDeviceEqual("/job:worker/task:0", w.device)
            self.assertDeviceEqual("/job:ps/task:1", w_g.device)


# This extended variant reuses the _get_workers name (shadowing the single-ps
# version above when both live in one module); it adds optional partitioning
# of a variable across num_ps parameter servers and returns the swapping
# savers as a fourth element.
def _get_workers(num_workers, period, workers, moving_rate, num_ps=1):
    sessions = []
    graphs = []
    train_ops = []
    savers = []
    for worker_id in range(num_workers):
        graph = ops.Graph()
        is_chief = (worker_id == 0)
        with graph.as_default():
            worker_device = "/job:worker/task:%d/cpu:0" % (worker_id)
            ea_custom = ElasticAverageCustomGetter(worker_device=worker_device)
            with variable_scope.variable_scope(
                    "", custom_getter=ea_custom), ops.device(
                        device_setter.replica_device_setter(
                            worker_device=worker_device,
                            ps_device="/job:ps/task:0/cpu:0",
                            ps_tasks=1)):
                global_step = training_util.get_or_create_global_step()
                var_0 = variable_scope.get_variable(initializer=0.0, name="v0")
                var_1 = variable_scope.get_variable(initializer=1.0, name="v1")
            if num_ps > 1:
                with variable_scope.variable_scope(
                        "",
                        partitioner=partitioned_variables.fixed_size_partitioner(
                            num_ps, axis=0),
                        custom_getter=ea_custom), ops.device(
                            device_setter.replica_device_setter(
                                worker_device=worker_device,
                                ps_device="/job:ps/task:0/cpu:0",
                                ps_tasks=num_ps)):

                    partition_var = variable_scope.get_variable(
                        'partition_var',
                        shape=[2, 4],
                        initializer=init_ops.ones_initializer)
                    # PartitionedVariable iterates over its concrete shards;
                    # take the first two ([1, 4] slices of the [2, 4]
                    # variable when num_ps == 2).
                    partitions = list(partition_var)
                    part_0 = partitions[0]
                    part_1 = partitions[1]

            with ops.device("/job:worker/task:" + str(worker_id)):
                grads_0 = constant_op.constant(-1.0)
                grads_1 = constant_op.constant(-1.0)
                grads_part_0 = constant_op.constant([[-1., -1., -1., -1.]])
                grads_part_1 = constant_op.constant([[-1., -1., -1., -1.]])

                sgd_opt = gradient_descent.GradientDescentOptimizer(1.0)
                opt = ElasticAverageOptimizer(opt=sgd_opt,
                                              num_worker=num_workers,
                                              moving_rate=moving_rate,
                                              communication_period=period,
                                              ea_custom_getter=ea_custom)
                if num_ps == 1:
                    train_op = [
                        opt.apply_gradients(
                            ([grads_0, var_0], [grads_1, var_1]), global_step)
                    ]
                else:
                    train_op = [
                        opt.apply_gradients(
                            ([grads_0, var_0], [grads_1, var_1],
                             [grads_part_0, part_0], [grads_part_1, part_1]),
                            global_step)
                    ]
                easgd_hook = opt.make_session_run_hook(is_chief, worker_id)
                saver = opt.swapping_saver()
            # Creates MonitoredSession
            sess = training.MonitoredTrainingSession(workers[worker_id].target,
                                                     hooks=[easgd_hook])

        sessions.append(sess)
        graphs.append(graph)
        train_ops.append(train_op)
        savers.append(saver)

    return sessions, graphs, train_ops, savers
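

# A second editor-added sketch using the same hypothetical
# create_local_cluster helper: it exercises the partitioned-variable path
# (num_ps=2) of the extended _get_workers above.
def _example_two_workers_partitioned():
    num_workers = 2
    workers, ps_servers = create_local_cluster(num_workers=num_workers,
                                               num_ps=2)

    sessions, graphs, train_ops, savers = _get_workers(
        num_workers, period=1, workers=workers, moving_rate=0.5, num_ps=2)

    # With period=1 every step is also a communication step, so the center
    # variables on the ps tasks begin moving from the first update onward.
    for worker_id in range(num_workers):
        sessions[worker_id].run(train_ops[worker_id])

    # Each saver from opt.swapping_saver() checkpoints the center variables
    # under the local variable names, so the checkpoint restores into an
    # ordinary single-machine graph. Note that Saver.save() needs a raw
    # tf.Session; the MonitoredSession wrappers above would have to be
    # unwrapped (or a fresh session opened against the same target) first.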