Example #1
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2


def create_session_config():
    #session_config = config_pb2.ConfigProto(allow_soft_placement=True, isolate_session_state=True)
    rpc_options = config_pb2.RPCOptions()
    # Setting cache_rpc_response to true enables sender-side caching of
    # responses for RecvTensorAsync and RecvBufAsync, allowing the receiver to
    # retry requests. This is only necessary when the network fabric is
    # experiencing a significant error rate. Without it we'll fail a step on a
    # network error, while with it we'll be able to complete long steps (like
    # complex initializations) in the face of some network errors during
    # RecvTensor.
    rpc_options.cache_rpc_response = True
    rewriter_config = rewriter_config_pb2.RewriterConfig(
        disable_model_pruning=True,
        disable_meta_optimizer=True,
        dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF,
        fail_on_optimizer_errors=True,
    )
    graph_options = config_pb2.GraphOptions(
        rewrite_options=rewriter_config,
        place_pruned_graph=True,
        infer_shapes=True,
    )
    session_config = config_pb2.ConfigProto(
        graph_options=graph_options,
        allow_soft_placement=True,
        isolate_session_state=False,
    )
    # Share variables across sessions on TPUs.
    session_config.experimental.share_session_state_in_clusterspec_propagation = True
    # TODO: research this. What does it do?
    # session_config.share_cluster_devices_in_session = True
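    # (Hedged note, not in the source: per TensorFlow's config.proto,
    # share_cluster_devices_in_session creates WorkerSessions with device
    # attributes from the full cluster rather than only the local worker,
    # which helps a worker partition a graph.)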
    return session_config
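For context, a minimal usage sketch (not part of the original snippet; the gRPC target address below is illustrative): the returned proto is passed straight to a TF1 session constructor.

import tensorflow as tf

session = tf.compat.v1.Session(
    target="grpc://localhost:2222",  # illustrative master address
    config=create_session_config())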
Example #2
    def _useRPCConfig(self):
        """Return a `tf.ConfigProto` that ensures we use the RPC stack for tests.

        This configuration ensures that we continue to exercise the gRPC
        stack when testing, rather than using the in-process optimization,
        which avoids using gRPC as the transport between a client and
        master in the same process.

        Returns:
          A `tf.ConfigProto`.
        """
        return tf.ConfigProto(rpc_options=config_pb2.RPCOptions(
            use_rpc_for_inprocess_master=True))
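A hedged usage sketch, assuming _useRPCConfig lives on a tf.test.TestCase subclass: a hypothetical test method (not from the source) that runs a trivial graph through the in-process master over real gRPC.

    def testTrivialGraphOverGrpc(self):  # hypothetical test, not in the source
        with tf.Session(config=self._useRPCConfig()) as sess:
            c = tf.constant(42.0)
            self.assertEqual(42.0, sess.run(c))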
Example #3
import sys

from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
from tensorflow.python.training import gradient_descent



if __name__ == '__main__':
  _tf_patch = patch_tensorflow_interactive()  # helper defined elsewhere in the original file
  if len(sys.argv) <= 1:
    from tensorflow.core.protobuf import config_pb2
    import tensorflow as tf
    tf1 = tf.compat.v1
    tf.compat.v1.logging.set_verbosity('DEBUG')
    import numpy as np
    #session_config = config_pb2.ConfigProto(allow_soft_placement=True, isolate_session_state=True)
    rpc_options = config_pb2.RPCOptions()
    # Setting cache_rpc_response to true enables sender-side caching of
    # responses for RecvTensorAsync and RecvBufAsync, allowing the receiver to
    # retry requests. This is only necessary when the network fabric is
    # experiencing a significant error rate. Without it we'll fail a step on a
    # network error, while with it we'll be able to complete long steps (like
    # complex initializations) in the face of some network errors during
    # RecvTensor.
    rpc_options.cache_rpc_response = True

    rewriter_config = rewriter_config_pb2.RewriterConfig(
        disable_model_pruning=True,
        disable_meta_optimizer=True,
        dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF,
        fail_on_optimizer_errors=True,
    )
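The snippet is truncated at this point. A plausible continuation, mirroring Example #1 (this completion is an assumption, not part of the source), would fold both protos into a single ConfigProto:

    graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
    session_config = config_pb2.ConfigProto(
        graph_options=graph_options,
        rpc_options=rpc_options,
        allow_soft_placement=True,
    )
    sess = tf1.InteractiveSession(config=session_config)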
Example #4
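This snippet references command-line FLAGS and a few modules that are defined elsewhere in the original file. A hedged preamble sketch of what main() assumes is shown below; the flag names are taken from the FLAGS.* references in the code, and every default value is illustrative only.

import itertools

import numpy as np
import tensorflow as tf
from tensorflow.core.protobuf import config_pb2
from tensorflow.examples.tutorials.mnist import input_data

FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string("ps_hosts", "localhost:2222", "Comma-separated ps hosts.")
tf.app.flags.DEFINE_string("worker_hosts", "localhost:2223", "Comma-separated worker hosts.")
tf.app.flags.DEFINE_string("job_name", "worker", "Either 'ps' or 'worker'.")
tf.app.flags.DEFINE_integer("task_index", 0, "Index of the task within its job.")
tf.app.flags.DEFINE_integer("num_partition", 4, "Number of row partitions of W.")
tf.app.flags.DEFINE_integer("num_batch", 2, "Partitions fed per training step.")
tf.app.flags.DEFINE_integer("train_steps", 1000, "Total steps before stopping.")
tf.app.flags.DEFINE_boolean("compression_on", True, "Enable the custom gRPC compression option.")
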
def main(_):
    ps_hosts = FLAGS.ps_hosts.split(",")
    worker_hosts = FLAGS.worker_hosts.split(",")

    # Create a cluster from the parameter server and worker hosts.
    cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})

    # Create and start a server for the local task.
    server = tf.train.Server(
        cluster,
        job_name=FLAGS.job_name,
        task_index=FLAGS.task_index,
        config=tf.ConfigProto(rpc_options=config_pb2.RPCOptions(
            # ex_grpc_compression is not a stock RPCOptions field; this
            # example appears to rely on a patched TensorFlow build.
            ex_grpc_compression=FLAGS.compression_on)))

    if FLAGS.job_name == "ps":
        server.join()
    elif FLAGS.job_name == "worker":

        # Assigns ops to the local worker by default.
        with tf.device(
                tf.train.replica_device_setter(
                    worker_device="/job:worker/task:%d" % FLAGS.task_index,
                    cluster=cluster)):

            # Build model...
            partition_size = 784 // FLAGS.num_partition
            mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
            x = tf.placeholder(tf.float32,
                               [None, partition_size * FLAGS.num_batch])
            W = tf.Variable(tf.zeros([784, 10]))
            partition_index = tf.placeholder(tf.int32)

            #W_ = W[partition_index*partition_size:(partition_index+1)*partition_size, :]
            #W_ = tf.Variable(tf.zeros([784, 0]))
            #      for j in partition_index:
            # W_ = tf.concat(W_, W[partition_index*partition_size:(partition_index+1)*partition_size, :], 1)
            W_ = tf.gather(W, partition_index)
            b = tf.Variable(tf.zeros([10]))
            y = tf.matmul(x, W_) + b
            y_ = tf.placeholder(tf.float32, [None, 10])
            loss = tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
            #train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

            #loss = ...
            global_step = tf.train.get_or_create_global_step()
            train_op = tf.train.GradientDescentOptimizer(0.5).minimize(
                loss, global_step=global_step)
            #train_op = tf.train.AdagradOptimizer(0.01).minimize(
            #    loss, global_step=global_step)

        # The StopAtStepHook handles stopping after running given steps.
        hooks = [tf.train.StopAtStepHook(last_step=FLAGS.train_steps)]

        # The MonitoredTrainingSession takes care of session initialization,
        # restoring from a checkpoint, saving to a checkpoint, and closing when done
        # or an error occurs.
        with tf.train.MonitoredTrainingSession(
                master=server.target,
                is_chief=(FLAGS.task_index == 0),
                checkpoint_dir="/tmp/train_logs",
                #config=tf.ConfigProto(log_device_placement=True),
                config=tf.ConfigProto(rpc_options=config_pb2.RPCOptions(
                    ex_grpc_compression=FLAGS.compression_on)),
                hooks=hooks) as mon_sess:

            q = itertools.cycle(
                itertools.combinations(range(FLAGS.num_partition),
                                       FLAGS.num_batch))

            #import pdb
            #pdb.set_trace()

            while not mon_sess.should_stop():
                #for _ in range(1000):
                # Run a training step asynchronously.
                # See `tf.train.SyncReplicasOptimizer` for additional details on how to
                # perform *synchronous* training.
                # mon_sess.run handles AbortedError in case of preempted PS.

                #pass in parameter num_partition, num_batch
                # bin(functools.reduce(int.__or__, (1 << d for d in q[i % len(q)]), 0))[2:].rjust(FLAGS.num_partition, "0")
                which = next(q)
                #print(which)

                batch_xs, batch_ys = mnist.train.next_batch(100)
                # Earlier masking approach, kept commented out:
                # mask_matrix = np.zeros((100, 784), np.float32)
                # for j in which:
                #     mask_matrix[:, j*partition_size:(j+1)*partition_size] = \
                #         batch_xs[:, j*partition_size:(j+1)*partition_size]
                # batch_xs = mask_matrix
                #index = [0]
                #for j in which:
                #        index = np.concatenate((index, list(range(j*partition_size, (j+1)*partition_size))))
                index = np.concatenate([
                    list(range(j * partition_size, (j + 1) * partition_size))
                    for j in which
                ])

                #new_x = []
                #for j in which:
                #np.concatenate(new_x, batch_xs[:, j*partition_size:(j+1)*partition_size], axis = 1)

                # import pdb
                # pdb.set_trace()

                batch_xs = batch_xs[:, index]

                #batch_xs = np.concatenate((batch_xs[:, :196], np.zeros((100, 588), np.float32)), axis=1)

                # import pdb
                # pdb.set_trace()
                mon_sess.run(train_op,
                             feed_dict={
                                 x: batch_xs,
                                 y_: batch_ys,
                                 partition_index: index
                             })
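The snippet ends inside the training loop. A standard TF1 entry point (an assumption, not shown in the source) dispatches to main via tf.app.run; the launch commands in the comment are illustrative only.

# Hypothetical launch (script name and addresses illustrative):
#   python trainer.py --job_name=ps --task_index=0
#   python trainer.py --job_name=worker --task_index=0
if __name__ == "__main__":
    tf.app.run(main=main)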