def _get_workers(num_workers, steps, workers):
  """Builds a separate graph, train op and monitored session for each worker.

  Every worker gets its own `ops.Graph`. Inside that graph, `global_step`,
  `v0` (initial value 0.0) and `v1` (initial value 1.0) are created through a
  `ModelAverageCustomGetter`. This happens under a replica device setter with
  a single ps task on "/job:ps/task:0/cpu:0". A `GradientDescentOptimizer`
  (learning rate 1.0), wrapped in `ModelAverageOptimizer`, then applies
  constant gradients to `v0` and `v1`. Worker 0 is the chief and uses
  gradient -1.0 for both variables; every other worker uses -2.0.

  Args:
    num_workers: Number of workers to build. Worker ids run from 0 to
      `num_workers - 1`, and the count is also passed to the optimizer as
      `num_worker`.
    steps: Forwarded to `ModelAverageOptimizer` as `interval_steps`.
    workers: Indexable by worker id. Each element's `.target` is handed to
      `MonitoredTrainingSession`. These are presumably in-process servers;
      confirm at the call site.

  Returns:
    A `(sessions, graphs, train_ops)` tuple of lists indexed by worker id.
    Each `train_ops` entry is a one-element list holding that worker's
    `apply_gradients` op.
  """
  sessions, graphs, train_ops = [], [], []
  for worker_id in range(num_workers):
    graph = ops.Graph()
    is_chief = worker_id == 0
    with graph.as_default():
      worker_device = "/job:worker/task:%d/cpu:0" % worker_id
      ma_custom_getter = model_average_optimizer.ModelAverageCustomGetter(
          worker_device=worker_device)
      # Variables requested here go through the model-average custom getter
      # while the replica device setter routes placement to the single ps task.
      with variable_scope.variable_scope(
          "", custom_getter=ma_custom_getter), ops.device(
              device_setter.replica_device_setter(
                  worker_device=worker_device,
                  ps_device="/job:ps/task:0/cpu:0",
                  ps_tasks=1)):
        global_step = variables.Variable(0, name="global_step", trainable=False)
        var_0 = variable_scope.get_variable(initializer=0.0, name="v0")
        var_1 = variable_scope.get_variable(initializer=1.0, name="v1")

      # The chief and the other workers push different constant gradients,
      # presumably so their local copies differ before averaging.
      grad_value = -1.0 if is_chief else -2.0
      with ops.device("/job:worker/task:%d" % worker_id):
        grads_0 = constant_op.constant(grad_value)
        grads_1 = constant_op.constant(grad_value)
        sgd_opt = gradient_descent.GradientDescentOptimizer(1.0)
        opt = model_average_optimizer.ModelAverageOptimizer(
            opt=sgd_opt,
            num_worker=num_workers,
            ma_custom_getter=ma_custom_getter,
            is_chief=is_chief,
            interval_steps=steps)
        train_op = [
            opt.apply_gradients([[grads_0, var_0], [grads_1, var_1]],
                                global_step)
        ]
      ma_hook = opt.make_session_run_hook()
      # Creates MonitoredSession
      sess = training.MonitoredTrainingSession(
          workers[worker_id].target, hooks=[ma_hook])
    sessions.append(sess)
    graphs.append(graph)
    train_ops.append(train_op)
  return sessions, graphs, train_ops
def testPS2TasksWithClusterSpecClass(self):
  """Checks local/global variable placement with a two-ps `ClusterSpec`.

  Variables are requested through a `ModelAverageCustomGetter` under a
  replica device setter built from a cluster with two ps tasks. The variable
  returned by `get_variable` is used as a key into the getter's
  `_local_2_global` map to obtain its paired global variable.

  The test asserts that:
    * `v` and `w` themselves are placed on the worker device;
    * the global counterpart of `v` is placed on ps task 0;
    * the global counterpart of `w` is placed on ps task 1.
  """
  cluster_spec = server_lib.ClusterSpec({
      "ps": ["ps0:2222", "ps1:2222"],
      "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
  })
  worker_device = "/job:worker/task:0"
  ma_custom_getter = model_average_optimizer.ModelAverageCustomGetter(
      worker_device=worker_device)
  # Uses the module-level `device_setter` import; a function-local re-import
  # would only shadow it.
  with ops.device(
      device_setter.replica_device_setter(
          cluster=cluster_spec,
          worker_device=worker_device,
          ps_device="/job:ps")):
    with variable_scope.variable_scope("", custom_getter=ma_custom_getter):
      v = variable_scope.get_variable(initializer=[1, 2], name="v")
      w = variable_scope.get_variable(initializer=[2, 1], name="w")
      # `_local_2_global` is a private attribute of the custom getter; it is
      # read here only to locate the paired global variables.
      local_2_global = ma_custom_getter._local_2_global
      v_g, w_g = local_2_global[v], local_2_global[w]
      self.assertDeviceEqual("/job:worker/task:0", v.device)
      self.assertDeviceEqual("/job:ps/task:0", v_g.device)
      self.assertDeviceEqual("/job:worker/task:0", w.device)
      self.assertDeviceEqual("/job:ps/task:1", w_g.device)