def worker_task(ps, worker_index, batch_size=50):
    """Run an asynchronous training loop against a parameter server.

    On every iteration the worker:

    1. pulls the current weights for its model keys from ``ps``;
    2. loads them into a local ``SimpleCNN``;
    3. computes an update on one training mini-batch;
    4. pushes that update back to ``ps`` without waiting for it to apply.

    Args:
        ps: Handle to the parameter server. Its ``pull`` and ``push``
            methods are invoked through ``.remote(...)``.
        worker_index: Identifier of this worker. It is printed on start-up
            and passed as ``seed`` to ``download_mnist_retry``.
        batch_size: Number of training examples drawn per update.

    The loop has no exit path, so this function only ends via an exception.
    Presumably the worker is torn down externally when the driver shuts Ray
    down; TODO confirm against the driver code.
    """
    banner = "Worker " + str(worker_index)
    print(banner)
    # Each worker seeds its own dataset download with its index.
    dataset = download_mnist_retry(seed=worker_index)

    model = SimpleCNN()
    # get_weights() yields (keys, values). Only the keys are kept, because
    # the values are re-fetched from the parameter server on every step.
    weight_keys = model.get_weights()[0]

    while True:
        # Refresh the local replica with the latest global weights.
        latest = ray.get(ps.pull.remote(weight_keys))
        model.set_weights(weight_keys, latest)

        # One local step: compute the update on a fresh batch and send it
        # to the parameter server (the returned reference is not awaited).
        batch_x, batch_y = dataset.train.next_batch(batch_size)
        update = model.compute_update(batch_x, batch_y)
        ps.push.remote(weight_keys, update)
ray_ctx = OrcaContext.get_ray_context() else: print( "init_orca_context failed. cluster_mode should be one of 'local', 'yarn' and 'spark-submit' but got " + cluster_mode) # Create a parameter server with some random weights. net = SimpleCNN() all_keys, all_values = net.get_weights() ps = ParameterServer.remote(all_keys, all_values) # Start some training tasks. worker_tasks = [worker_task.remote(ps, i) for i in range(args.num_workers)] # Download MNIST. mnist = download_mnist_retry() print("Begin iteration") i = 0 while i < args.iterations: # Get and evaluate the current model. print("-----Iteration" + str(i) + "------") current_weights = ray.get(ps.pull.remote(all_keys)) net.set_weights(all_keys, current_weights) test_xs, test_ys = mnist.test.next_batch(1000) accuracy = net.compute_accuracy(test_xs, test_ys) print("Iteration {}: accuracy is {}".format(i, accuracy)) i += 1 time.sleep(1) ray_ctx.stop() stop_orca_context()