def train(target, all_data, all_labels, cluster_spec):
    '''
    This is the main function for training
    '''
    image_placeholder = tf.placeholder(
        dtype=tf.float32,
        shape=[FLAGS.batch_size, IMG_HEIGHT, IMG_WIDTH, IMG_DEPTH])
    label_placeholder = tf.placeholder(dtype=tf.int32,
                                       shape=[FLAGS.batch_size])

    num_workers = len(cluster_spec.as_dict()['worker'])
    num_parameter_servers = len(cluster_spec.as_dict()['ps'])

    if FLAGS.num_replicas_to_aggregate == -1:
        num_replicas_to_aggregate = num_workers
    else:
        num_replicas_to_aggregate = FLAGS.num_replicas_to_aggregate

    assert num_workers > 0 and num_parameter_servers > 0, (
        'num_workers and num_parameter_servers must be > 0.')
    is_chief = (FLAGS.task_id == 0)
    num_examples = all_data.shape[0]

    with tf.device(
            tf.train.replica_device_setter(
                # CPU only:
                # worker_device='/job:worker/task:%d' % FLAGS.task_id,
                # GPU enabled:
                worker_device='/job:worker/task:%d/gpu:0' % FLAGS.task_id,
                cluster=cluster_spec)):

        global_step = tf.Variable(0, name="global_step", trainable=False)

        num_batches_per_epoch = (num_examples / FLAGS.batch_size)
        decay_steps = int(num_batches_per_epoch * FLAGS.num_epochs_per_decay /
                          num_replicas_to_aggregate)
        lr = tf.train.exponential_decay(FLAGS.initial_learning_rate,
                                        global_step,
                                        decay_steps,
                                        FLAGS.learning_rate_decay_factor,
                                        staircase=True)
        # Logits of training data and validation data come from the same graph.
        # The inference of validation data shares all the weights with the
        # training data; this is implemented by passing reuse=True to the
        # variable scopes of the training graph.
        logits = inference(image_placeholder,
                           FLAGS.num_residual_blocks,
                           reuse=False)

        #            vali_logits = inference(self.vali_image_placeholder, FLAGS.num_residual_blocks, reuse=True)

        # The following code calculates the training loss, which consists of the
        # softmax cross-entropy and the regularization loss.
        #            regu_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
        total_loss = calc_loss(logits, label_placeholder)

        #        predictions = tf.nn.softmax(logits)
        #        train_top1_error = top_k_error(predictions, label_placeholder, 1)

        opt = tf.train.AdamOptimizer(lr)
        if FLAGS.interval_method or FLAGS.worker_times_cdf_method:
            opt = TimeoutReplicasOptimizer(opt,
                                           global_step,
                                           total_num_replicas=num_workers)
        elif FLAGS.backup_worker_method:
            opt = BackupOptimizer(
                opt,
                replicas_to_aggregate=num_replicas_to_aggregate,
                total_num_replicas=num_workers)
        else:
            #            opt = tf.train.SyncReplicasOptimizerV2(
            opt = tf.train.SyncReplicasOptimizer(
                opt,
                replicas_to_aggregate=num_replicas_to_aggregate,
                total_num_replicas=num_workers)

        # Compute gradients with respect to the loss.
        grads = opt.compute_gradients(total_loss)
        #compute weighted gradients here.
        #===============================================================================================
        '''
        # Define a placeholder for the weight vector, i.e. the LS solution.
        weight_vec_placeholder = tf.placeholder(dtype=tf.float32,
                                                shape=(num_workers,))
        grad_list = [x[0] for x in grads]
        new_grad_list = []
        # Multiply the gradient from each worker by the corresponding weight,
        # which is just a scalar multiplication.
        for g_idx in range(len(grad_list)):
            grad_on_worker = grad_list[g_idx]
            weight = tf.slice(weight_vec_placeholder, [FLAGS.task_id], [1])
            tf.logging.info("Logging Happens Here!")
            tf.logging.info(weight[0])
            new_grad_list.append(tf.scalar_mul(weight[0], grad_on_worker))
        grad_new = []
        # Rebuild the (gradient, variable) pairs from the weighted gradients.
        for x_idx in range(len(grads)):
            grad_elem = grads[x_idx]
            grad_new.append((new_grad_list[x_idx], grad_elem[1]))
        '''
        #===============================================================================================
        if FLAGS.interval_method or FLAGS.worker_times_cdf_method:
            apply_gradients_op = opt.apply_gradients(
                grads,
                FLAGS.task_id,
                global_step=global_step,
                collect_cdfs=FLAGS.worker_times_cdf_method)
#            apply_gradients_op = opt.apply_gradients(grads, FLAGS.task_id, global_step=global_step)
        elif FLAGS.backup_worker_method:
            apply_gradients_op = opt.apply_gradients(grads,
                                                     FLAGS.task_id,
                                                     global_step=global_step)
        else:
            apply_gradients_op = opt.apply_gradients(grads,
                                                     global_step=global_step)
#           apply_gradients_op = opt.apply_gradients(grad_new, global_step=global_step)
        with tf.control_dependencies([apply_gradients_op]):
            train_op = tf.identity(total_loss, name='train_op')

        # Initialize a saver to save checkpoints. Merge all summaries, so we can run all
        # summarizing operations by running summary_op. Initialize a new session
        chief_queue_runners = [opt.get_chief_queue_runner()]
        init_tokens_op = opt.get_init_tokens_op()
        saver = tf.train.Saver()
        summary_op = tf.summary.merge_all()
        init_op = tf.global_variables_initializer()
        test_print_op = logging_ops.Print(0, [0], message="Test print success")
        if is_chief:
            local_init_op = opt.chief_init_op
        else:
            local_init_op = opt.local_step_init_op

        local_init_opt = [local_init_op]
        ready_for_local_init_op = opt.ready_for_local_init_op

        sv = tf.train.Supervisor(
            is_chief=is_chief,
            local_init_op=local_init_op,
            ready_for_local_init_op=ready_for_local_init_op,
            logdir=FLAGS.train_dir,
            init_op=init_op,
            summary_op=None,
            global_step=global_step,
            saver=saver,
            save_model_secs=FLAGS.save_interval_secs)
        tf.logging.info('%s Supervisor' % datetime.now())
        sess_config = tf.ConfigProto(
            allow_soft_placement=True,
            log_device_placement=FLAGS.log_device_placement)
        sess = sv.prepare_or_wait_for_session(target, config=sess_config)
        queue_runners = tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS)
        sv.start_queue_runners(sess, queue_runners)
        tf.logging.info('Started %d queues for processing input data.',
                        len(queue_runners))

        if is_chief:
            if not FLAGS.interval_method or FLAGS.worker_times_cdf_method:
                sv.start_queue_runners(sess, chief_queue_runners)
            sess.run(init_tokens_op)

        timeout_client, timeout_server = launch_manager(sess, FLAGS)
        next_summary_time = time.time() + FLAGS.save_summaries_secs
        begin_time = time.time()
        cur_iteration = -1
        local_data_batch_idx = 0
        epoch_counter = 0
        iterations_finished = set()

        if FLAGS.task_id == 0 and FLAGS.interval_method:
            opt.start_interval_updates(sess, timeout_client)
        '''
        np.random.seed(SEED)
        b = np.ones(int(num_batches_per_epoch))
        interval = np.arange(0, int(num_batches_per_epoch))
        idx_list = np.random.choice(interval, int(num_workers), replace=False)     
        '''
        while not sv.should_stop():
            #    try:
            sys.stdout.flush()
            tf.logging.info("A new iteration...")
            cur_iteration += 1

            if FLAGS.worker_times_cdf_method:
                sess.run([opt._wait_op])
                timeout_client.broadcast_worker_dequeued_token(cur_iteration)
            start_time = time.time()
            epoch_counter, local_data_batch_idx, feed_dict = fill_feed_dict(
                all_data, all_labels, image_placeholder, label_placeholder,
                FLAGS.batch_size, local_data_batch_idx, epoch_counter)

            run_options = tf.RunOptions()
            run_metadata = tf.RunMetadata()
            #===============================================================================================
            '''
            LS_start_time = time.time()
            interval_2 = np.arange(0, int(num_workers))
            workers_to_kill = np.random.choice(interval_2, FLAGS.num_worker_kill, replace=False)
            #interval_2 = np.arange(0, WORKER_NUM)
            #workers_to_kill = np.random.choice(interval_2, NUM_WORKER_KILL, replace=False)
            A = np.zeros((int(num_workers), int(num_batches_per_epoch)))
            for i in range(A.shape[0]):
              if i == A.shape[0]-1:
                A[i][idx_list[i]] = 1
                A[i][idx_list[0]] = 1
              else:
                A[i][idx_list[i]] = 1
                A[i][idx_list[i+1]] = 1

            for i in range(len(idx_list)):
              element = idx_list[i]
              if element == A.shape[1]-1:
                idx_list[i] = 0
              else:
                idx_list[i] += 1

            for k in workers_to_kill:
              A[k] = 0

            A_for_calc = np.transpose(A)
            ls_solution = np.dot(np.linalg.pinv(A_for_calc), b)
            tf.logging.info("workers killed this iteration:")
            tf.logging.info(str(workers_to_kill))
            tf.logging.info("The matrix to solve:")
            for item in A_for_calc:
              tf.logging.info(str(item))
            tf.logging.info("Solution of LS:")
            tf.logging.info(str(ls_solution)) 
            LS_duration = time.time() - LS_start_time
            tf.logging.info("LS run time: %s" % str(LS_duration))
            '''
            #===============================================================================================

            if FLAGS.timeline_logging:
                run_options.trace_level = tf.RunOptions.FULL_TRACE
                run_options.output_partition_graphs = True

            #feed_dict[weight_vec_placeholder] = ls_solution
            tf.logging.info("RUNNING SESSION... %f" % time.time())
            tf.logging.info("Data batch index: %s, Current epoch idex: %s" %
                            (str(epoch_counter), str(local_data_batch_idx)))
            loss_value, step = sess.run(
                #[train_op, global_step], feed_dict={feed_dict, x}, run_metadata=run_metadata, options=run_options)
                [train_op, global_step],
                feed_dict=feed_dict,
                run_metadata=run_metadata,
                options=run_options)
            tf.logging.info("DONE RUNNING SESSION...")

            if FLAGS.worker_times_cdf_method:
                timeout_client.broadcast_worker_finished_computing_gradients(
                    cur_iteration)
            assert not np.isnan(loss_value), 'Model diverged with loss = NaN'
            finish_time = time.time()
            if FLAGS.timeline_logging:
                tl = timeline.Timeline(run_metadata.step_stats)
                ctf = tl.generate_chrome_trace_format()
                with open(
                        '%s/worker=%d_timeline_iter=%d.json' %
                        (FLAGS.train_dir, FLAGS.task_id, step), 'w') as f:
                    f.write(ctf)
            if step > FLAGS.max_steps:
                break

            duration = time.time() - start_time
            examples_per_sec = FLAGS.batch_size / float(duration)
            format_str = ('Worker %d: %s: step %d, loss = %f '
                          '(%.1f examples/sec; %.3f sec/batch)')
            tf.logging.info(format_str %
                            (FLAGS.task_id, datetime.now(), step, loss_value,
                             examples_per_sec, duration))
            if is_chief and next_summary_time < time.time(
            ) and FLAGS.should_summarize:
                tf.logging.info('Running Summary operation on the chief.')
                summary_str = sess.run(summary_op)
                sv.summary_computed(sess, summary_str)
                tf.logging.info('Finished running Summary operation.')
                next_summary_time += FLAGS.save_summaries_secs
        #    except tf.errors.DeadlineExceededError:
        #        tf.logging.info("Killed at time %f" % time.time())
        #sess.reset_kill()
        #    except:
        #        tf.logging.info("Unexpected error: %s" % str(sys.exc_info()[0]))
        #sess.reset_kill()
        if is_chief:
            tf.logging.info('Elapsed Time: %f' % (time.time() - begin_time))
        sv.stop()

        if is_chief:
            saver.save(sess,
                       os.path.join(FLAGS.train_dir, 'model.ckpt'),
                       global_step=global_step)
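

# --- Illustrative sketch (not part of the original training code) ---
# All of the train() variants in this listing build their learning-rate
# schedule with tf.train.exponential_decay(..., staircase=True), which
# multiplies the initial rate by the decay factor once every `decay_steps`
# global steps. A minimal pure-Python equivalent, using illustrative values
# in place of the flags, is sketched below.
def staircase_exponential_decay(initial_lr, global_step, decay_steps,
                                decay_factor):
    """Learning rate after `global_step` steps with staircase decay."""
    return initial_lr * decay_factor ** (global_step // decay_steps)


# With an initial rate of 0.1, a decay factor of 0.1 and decay_steps = 1000,
# the rate stays at 0.1 until step 999, drops to 0.01 at step 1000, and to
# 0.001 at step 2000.
assert staircase_exponential_decay(0.1, 1500, 1000, 0.1) == 0.1 * 0.1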
def train(target, dataset, cluster_spec):
    """Train Inception on a dataset for a number of steps."""
    # The numbers of workers and parameter servers are inferred from the worker
    # and ps host strings.
    num_workers = len(cluster_spec.as_dict()['worker'])
    num_parameter_servers = len(cluster_spec.as_dict()['ps'])
    # If no value is given, num_replicas_to_aggregate defaults to the number of
    # workers.
    if FLAGS.num_replicas_to_aggregate == -1:
        num_replicas_to_aggregate = num_workers
    else:
        num_replicas_to_aggregate = FLAGS.num_replicas_to_aggregate

    # Both should be greater than 0 in distributed training.
    assert num_workers > 0 and num_parameter_servers > 0, (
        'num_workers and num_parameter_servers must be > 0.')

    # Choose worker 0 as the chief. Note that any worker could be the chief
    # but there should be only one chief.
    is_chief = (FLAGS.task_id == 0)

    # Ops are assigned to worker by default.
    with tf.device(
            tf.train.replica_device_setter(
                worker_device='/job:worker/task:%d' % FLAGS.task_id,
                cluster=cluster_spec)):

        # Create a variable to count the number of train() calls. This equals the
        # number of updates applied to the variables. The PS holds the global step.
        global_step = tf.Variable(0, name="global_step", trainable=False)

        # Calculate the learning rate schedule.
        num_batches_per_epoch = (dataset.num_examples / FLAGS.batch_size)

        # Decay steps need to be divided by the number of replicas to aggregate.
        # This was the old decay schedule; we don't really want it since it
        # decays too fast with a fixed learning rate.
        decay_steps = int(num_batches_per_epoch * FLAGS.num_epochs_per_decay /
                          num_replicas_to_aggregate)
        # New decay schedule. Decay every few steps.
        #decay_steps = int(num_batches_per_epoch * FLAGS.num_epochs_per_decay / num_workers)

        # Decay the learning rate exponentially based on the number of steps.
        lr = tf.train.exponential_decay(FLAGS.initial_learning_rate,
                                        global_step,
                                        decay_steps,
                                        FLAGS.learning_rate_decay_factor,
                                        staircase=True)

        images, labels = mnist.placeholder_inputs(FLAGS.batch_size)

        # Number of classes in the Dataset label set plus 1.
        # Label 0 is reserved for an (unused) background class.
        logits = mnist.inference(images)

        # Add classification loss.
        total_loss = mnist.loss(logits, labels)

        # Create an optimizer that performs gradient descent.
        opt = tf.train.AdamOptimizer(lr)

        # Use V2 optimizer
        if FLAGS.interval_method or FLAGS.worker_times_cdf_method:
            opt = TimeoutReplicasOptimizer(
                opt,
                global_step,
                replicas_to_aggregate=num_replicas_to_aggregate,
                total_num_replicas=num_workers)
        else:
            opt = WeightedGradsOptimizer(
                #      opt = tf.train.SyncReplicasOptimizer(
                opt,
                replicas_to_aggregate=num_replicas_to_aggregate,
                total_num_replicas=num_workers)

        # Compute gradients with respect to the loss.
        grads = opt.compute_gradients(total_loss)

        #===============================================================================================
        #    batch_idx_placeholder = tf.placeholder(dtype=tf.int32, shape=(int(num_workers),))
        #    worker_kill_placeholder = tf.placeholder(dtype=tf.int32, shape=(FLAGS.num_worker_kill,))
        matrix_placeholder = tf.placeholder(dtype=tf.float32,
                                            shape=((int(num_batches_per_epoch),
                                                    int(num_workers))))
        '''
        weight_vec_placeholder = tf.placeholder(dtype=tf.float32,
                                                shape=(num_workers,))
        grad_list = [x[0] for x in grads]
        new_grad_list = []
        for g_idx in range(len(grad_list)):
            grad_on_worker = grad_list[g_idx]
            weight = tf.slice(weight_vec_placeholder, [FLAGS.task_id], [1])
            new_grad_list.append(tf.multiply(grad_on_worker, weight))
        grads = [(new_grad_list[x_idx], grads[x_idx][1])
                 for x_idx in range(len(grads))]
        '''
        #===============================================================================================

        if FLAGS.interval_method or FLAGS.worker_times_cdf_method:
            #      apply_gradients_op = opt.apply_gradients(grads, FLAGS.task_id, global_step=global_step, collect_cdfs=FLAGS.worker_times_cdf_method,
            #                            batch_idx_list=batch_idx_placeholder, worker_kill_list=worker_kill_placeholder,
            #                            num_workers=int(num_workers), num_batches_per_epoch=int(num_batches_per_epoch))
            apply_gradients_op = opt.apply_gradients(
                grads,
                FLAGS.task_id,
                global_step=global_step,
                collect_cdfs=FLAGS.worker_times_cdf_method,
                matrix_to_solve=matrix_placeholder,
                num_batches_per_epoch=int(num_batches_per_epoch))
        else:
            apply_gradients_op = opt.apply_gradients(grads,
                                                     global_step=global_step)

        with tf.control_dependencies([apply_gradients_op]):
            train_op = tf.identity(total_loss, name='train_op')

        # Get chief queue_runners, init_tokens and clean_up_op, which is used to
        # synchronize replicas.
        # More details can be found in sync_replicas_optimizer.
        chief_queue_runners = [opt.get_chief_queue_runner()]
        init_tokens_op = opt.get_init_tokens_op()
        #clean_up_op = opt.get_clean_up_op()

        # Create a saver.
        saver = tf.train.Saver()

        # Build the summary operation based on the TF collection of Summaries.
        summary_op = tf.summary.merge_all()

        # Build an initialization operation to run below.
        init_op = tf.global_variables_initializer()

        test_print_op = logging_ops.Print(0, [0], message="Test print success")

        # We run the summaries in the same thread as the training operations by
        # passing in None for summary_op to avoid a summary_thread being started.
        # Running summaries and training operations in parallel could run out of
        # GPU memory.
        if is_chief:
            local_init_op = opt.chief_init_op
        else:
            local_init_op = opt.local_step_init_op

        local_init_opt = [local_init_op]
        ready_for_local_init_op = opt.ready_for_local_init_op

        sv = tf.train.Supervisor(
            is_chief=is_chief,
            local_init_op=local_init_op,
            ready_for_local_init_op=ready_for_local_init_op,
            logdir=FLAGS.train_dir,
            init_op=init_op,
            summary_op=None,
            global_step=global_step,
            saver=saver,
            save_model_secs=FLAGS.save_interval_secs)

        tf.logging.info('%s Supervisor' % datetime.now())

        sess_config = tf.ConfigProto(
            allow_soft_placement=True,
            log_device_placement=FLAGS.log_device_placement)

        # Get a session.
        sess = sv.prepare_or_wait_for_session(target, config=sess_config)

        # Start the queue runners.
        queue_runners = tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS)
        sv.start_queue_runners(sess, queue_runners)
        tf.logging.info('Started %d queues for processing input data.',
                        len(queue_runners))

        if is_chief:
            if not FLAGS.interval_method or FLAGS.worker_times_cdf_method:
                sv.start_queue_runners(sess, chief_queue_runners)
            sess.run(init_tokens_op)

        # TIMEOUT client overseer.
        # Even if not using timeout, we want to wait until all machines are ready.
        timeout_client, timeout_server = launch_manager(sess, FLAGS)

        # Train, checking for Nans. Concurrently run the summary operation at a
        # specified interval. Note that the summary_op and train_op never run
        # simultaneously in order to prevent running out of GPU memory.
        next_summary_time = time.time() + FLAGS.save_summaries_secs
        begin_time = time.time()
        cur_iteration = -1
        iterations_finished = set()

        if FLAGS.task_id == 0 and FLAGS.interval_method:
            opt.start_interval_updates(sess, timeout_client)

        # The normal equation to solve has the form min ||Ax - b||^2.
        # b = np.ones(int(num_batches_per_epoch))
        interval = np.arange(0, int(num_batches_per_epoch))
        idx_list = np.random.choice(interval, int(num_workers), replace=False)
        while not sv.should_stop():
            sys.stdout.flush()
            tf.logging.info("A new iteration...")

            cur_iteration += 1

            #sess.run([opt._wait_op], options=tf.RunOptions(timeout_in_ms=10000))
            #sess.run([opt._wait_op])
            #sess.run([test_print_op])

            if FLAGS.worker_times_cdf_method:
                sess.run([opt._wait_op])
                timeout_client.broadcast_worker_dequeued_token(cur_iteration)

            start_time = time.time()
            feed_dict = mnist.fill_feed_dict(dataset, images, labels,
                                             FLAGS.batch_size)

            run_options = tf.RunOptions()
            run_metadata = tf.RunMetadata()

            #===============================================================================================
            #      interval_2 = np.arange(0, int(num_workers))
            #      workers_to_kill = np.random.choice(interval_2, FLAGS.num_worker_kill, replace=False)

            LS_start_time = time.time()
            interval_2 = np.arange(0, int(num_workers))
            workers_to_kill = np.random.choice(interval_2,
                                               FLAGS.num_worker_kill,
                                               replace=False)
            #interval_2 = np.arange(0, WORKER_NUM)
            #workers_to_kill = np.random.choice(interval_2, NUM_WORKER_KILL, replace=False)
            A = np.zeros((int(num_workers), int(num_batches_per_epoch)))
            for i in range(A.shape[0]):
                if i == A.shape[0] - 1:
                    A[i][idx_list[i]] = 1
                    A[i][idx_list[0]] = 1
                else:
                    A[i][idx_list[i]] = 1
                    A[i][idx_list[i + 1]] = 1

            for i in range(len(idx_list)):
                element = idx_list[i]
                if element == A.shape[1] - 1:
                    idx_list[i] = 0
                else:
                    idx_list[i] += 1

            for k in workers_to_kill:
                A[k] = 0

            A_for_calc = np.transpose(A)
            #      x = np.dot(np.linalg.pinv(A_for_calc), b)
            #      tf.logging.info("workers killed this iteration:")
            #      tf.logging.info(str(workers_to_kill))
            #  tf.logging.info("The matrix to solve:")
            #  for item in A_for_calc:
            #    tf.logging.info(str(item))
            #      tf.logging.info("Solution of LS:")
            #      tf.logging.info(str(x))
            #      LS_duration = time.time() - LS_start_time
            #      tf.logging.info("LS run time: %s" % str(LS_duration))

            #===============================================================================================

            if FLAGS.timeline_logging:
                run_options.trace_level = tf.RunOptions.FULL_TRACE
                run_options.output_partition_graphs = True

            #timeout_ms = random.randint(300, 1200)
            #tf.logging.info("SETTING TIMEOUT FOR %d ms" % timeout_ms)
            #run_options.timeout_in_ms = 1000 * 60 * 1

            # Add the assignment matrix to the feed_dict for this iteration.
            feed_dict[matrix_placeholder] = A_for_calc

            tf.logging.info("RUNNING SESSION... %f" % time.time())
            loss_value, step = sess.run([train_op, global_step],
                                        feed_dict=feed_dict,
                                        run_metadata=run_metadata,
                                        options=run_options)
            tf.logging.info("DONE RUNNING SESSION...")

            if FLAGS.worker_times_cdf_method:
                timeout_client.broadcast_worker_finished_computing_gradients(
                    cur_iteration)

            assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

            # Log the elapsed time per iteration
            finish_time = time.time()

            # Create the Timeline object, and write it to a json
            if FLAGS.timeline_logging:
                tl = timeline.Timeline(run_metadata.step_stats)
                ctf = tl.generate_chrome_trace_format()
                with open(
                        '%s/worker=%d_timeline_iter=%d.json' %
                    (FLAGS.train_dir, FLAGS.task_id, step), 'w') as f:
                    f.write(ctf)

            if step > FLAGS.max_steps:
                break

            duration = time.time() - start_time
            examples_per_sec = FLAGS.batch_size / float(duration)
            format_str = ('Worker %d: %s: step %d, loss = %f '
                          '(%.1f examples/sec; %.3f sec/batch)')
            tf.logging.info(format_str %
                            (FLAGS.task_id, datetime.now(), step, loss_value,
                             examples_per_sec, duration))

            # Determine if the summary_op should be run on the chief worker.
            if is_chief and next_summary_time < time.time(
            ) and FLAGS.should_summarize:

                tf.logging.info('Running Summary operation on the chief.')
                summary_str = sess.run(summary_op)
                sv.summary_computed(sess, summary_str)
                tf.logging.info('Finished running Summary operation.')

                # Determine the next time for running the summary.
                next_summary_time += FLAGS.save_summaries_secs

        if is_chief:
            tf.logging.info('Elapsed Time: %f' % (time.time() - begin_time))

        # Stop the supervisor.  This also waits for service threads to finish.
        sv.stop()

        # Save after the training ends.
        if is_chief:
            saver.save(sess,
                       os.path.join(FLAGS.train_dir, 'model.ckpt'),
                       global_step=global_step)
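

# --- Illustrative sketch (not part of the original training code) ---
# The train() variant above builds a worker-by-batch assignment matrix A,
# zeroes out the rows of randomly "killed" workers, and (in the commented-out
# code and inside the optimizer) recovers per-worker weights by solving the
# least-squares problem min ||A^T x - b||^2 with b = 1. A standalone NumPy
# version of that recovery step, using a small illustrative 3x3 assignment,
# is sketched below.
import numpy as np


def solve_gradient_coding_weights(assignment, killed_workers):
    """Solve min ||A^T x - b||^2 with b = 1 after dropping killed workers."""
    A = np.array(assignment, dtype=float)
    A[list(killed_workers)] = 0          # straggler rows contribute nothing
    b = np.ones(A.shape[1])              # one entry per batch
    return np.dot(np.linalg.pinv(A.T), b)


# Three workers, three batches, a cyclic two-batch assignment, worker 2 killed.
example_assignment = [[1, 1, 0],
                      [0, 1, 1],
                      [1, 0, 1]]
weights = solve_gradient_coding_weights(example_assignment, killed_workers=[2])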
def train(target, all_data, all_labels, cluster_spec):
    '''
    This is the main function for training
    '''
    image_placeholder = tf.placeholder(
        dtype=tf.float32,
        shape=[FLAGS.batch_size, IMG_HEIGHT, IMG_WIDTH, IMG_DEPTH])
    label_placeholder = tf.placeholder(dtype=tf.int32,
                                       shape=[FLAGS.batch_size])

    num_workers = len(cluster_spec.as_dict()['worker'])
    num_parameter_servers = len(cluster_spec.as_dict()['ps'])

    if FLAGS.num_replicas_to_aggregate == -1:
        num_replicas_to_aggregate = num_workers
    else:
        num_replicas_to_aggregate = FLAGS.num_replicas_to_aggregate

    assert num_workers > 0 and num_parameter_servers > 0, (
        'num_workers and num_parameter_servers must be > 0.')
    is_chief = (FLAGS.task_id == 0)
    num_examples = all_data.shape[0]

    with tf.device(
            tf.train.replica_device_setter(
                # CPU only:
                # worker_device='/job:worker/task:%d' % FLAGS.task_id,
                # GPU enabled:
                worker_device='/job:worker/task:%d/gpu:0' % FLAGS.task_id,
                cluster=cluster_spec)):

        global_step = tf.Variable(0, name="global_step", trainable=False)

        num_batches_per_epoch = (num_examples / FLAGS.batch_size)
        decay_steps = int(num_batches_per_epoch * FLAGS.num_epochs_per_decay /
                          num_replicas_to_aggregate)
        lr = tf.train.exponential_decay(FLAGS.initial_learning_rate,
                                        global_step,
                                        decay_steps,
                                        FLAGS.learning_rate_decay_factor,
                                        staircase=True)
        # Logits of training data and validation data come from the same graph.
        # The inference of validation data shares all the weights with the
        # training data; this is implemented by passing reuse=True to the
        # variable scopes of the training graph.
        logits = inference(image_placeholder,
                           FLAGS.num_residual_blocks,
                           reuse=False)

        # The following code calculates the training loss, which consists of the
        # softmax cross-entropy and the regularization loss.
        #            regu_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
        total_loss = calc_loss(logits, label_placeholder)

        opt = tf.train.AdamOptimizer(lr)

        if FLAGS.interval_method or FLAGS.worker_times_cdf_method:
            opt = TimeoutReplicasOptimizer(opt,
                                           global_step,
                                           total_num_replicas=num_workers)
        elif FLAGS.backup_worker_method:
            opt = BackupOptimizer(
                opt,
                replicas_to_aggregate=num_replicas_to_aggregate,
                total_num_replicas=num_workers)
        else:
            use_svd_compress = FLAGS.svd_rank > 0
            kwargs = {
                'replicas_to_aggregate': num_replicas_to_aggregate,
                'total_num_replicas': num_workers,
                'compress': use_svd_compress,
                'svd_rank': FLAGS.svd_rank
            }
            print('#' * 40)
            print(kwargs)
            print('#' * 40)
            opt = LowCommSync(opt, global_step=global_step, **kwargs)

        # Compute gradients with respect to the loss.
        grads = opt.compute_gradients(total_loss)

        if FLAGS.interval_method or FLAGS.worker_times_cdf_method:
            apply_gradients_op = opt.apply_gradients(
                grads,
                FLAGS.task_id,
                global_step=global_step,
                collect_cdfs=FLAGS.worker_times_cdf_method)
#            apply_gradients_op = opt.apply_gradients(grads, FLAGS.task_id, global_step=global_step)
        elif FLAGS.backup_worker_method:
            apply_gradients_op = opt.apply_gradients(grads,
                                                     FLAGS.task_id,
                                                     global_step=global_step)
        else:
            # SVD encode happens right here:
            shapes = [g.get_shape() for g, _ in grads]
            if use_svd_compress:
                encoded_grads = encode(grads, r=1, shapes=shapes)
                apply_gradients_op = opt.apply_gradients(
                    encoded_grads, global_step=global_step)
            else:
                apply_gradients_op = opt.apply_gradients(
                    grads, global_step=global_step)

        with tf.control_dependencies([apply_gradients_op]):
            train_op = tf.identity(total_loss, name='train_op')

        # Initialize a saver to save checkpoints. Merge all summaries, so we can run all
        # summarizing operations by running summary_op. Initialize a new session
        chief_queue_runners = [opt.get_chief_queue_runner()]
        init_tokens_op = opt.get_init_tokens_op()
        saver = tf.train.Saver()
        summary_op = tf.summary.merge_all()
        init_op = tf.global_variables_initializer()
        test_print_op = logging_ops.Print(0, [0], message="Test print success")
        if is_chief:
            local_init_op = opt.chief_init_op
        else:
            local_init_op = opt.local_step_init_op

        local_init_opt = [local_init_op]
        ready_for_local_init_op = opt.ready_for_local_init_op

        sv = tf.train.Supervisor(
            is_chief=is_chief,
            local_init_op=local_init_op,
            ready_for_local_init_op=ready_for_local_init_op,
            logdir=FLAGS.train_dir,
            init_op=init_op,
            summary_op=None,
            global_step=global_step,
            saver=saver,
            save_model_secs=FLAGS.save_interval_secs)
        tf.logging.info('%s Supervisor' % datetime.now())
        sess_config = tf.ConfigProto(
            allow_soft_placement=True,
            log_device_placement=FLAGS.log_device_placement)
        sess = sv.prepare_or_wait_for_session(target, config=sess_config)
        queue_runners = tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS)
        sv.start_queue_runners(sess, queue_runners)
        tf.logging.info('Started %d queues for processing input data.',
                        len(queue_runners))

        if is_chief:
            if not FLAGS.interval_method or FLAGS.worker_times_cdf_method:
                sv.start_queue_runners(sess, chief_queue_runners)
            sess.run(init_tokens_op)

        timeout_client, timeout_server = launch_manager(sess, FLAGS)
        next_summary_time = time.time() + FLAGS.save_summaries_secs
        begin_time = time.time()
        cur_iteration = -1
        local_data_batch_idx = 0
        epoch_counter = 0
        iterations_finished = set()

        if FLAGS.task_id == 0 and FLAGS.interval_method:
            opt.start_interval_updates(sess, timeout_client)
        '''
        np.random.seed(SEED)
        b = np.ones(int(num_batches_per_epoch))
        interval = np.arange(0, int(num_batches_per_epoch))
        idx_list = np.random.choice(interval, int(num_workers), replace=False)     
        '''
        while not sv.should_stop():
            #    try:
            sys.stdout.flush()
            tf.logging.info("A new iteration...")
            cur_iteration += 1

            if FLAGS.worker_times_cdf_method:
                sess.run([opt._wait_op])
                timeout_client.broadcast_worker_dequeued_token(cur_iteration)
            start_time = time.time()
            epoch_counter, local_data_batch_idx, feed_dict = fill_feed_dict(
                all_data, all_labels, image_placeholder, label_placeholder,
                FLAGS.batch_size, local_data_batch_idx, epoch_counter)

            run_options = tf.RunOptions()
            run_metadata = tf.RunMetadata()

            if FLAGS.timeline_logging:
                run_options.trace_level = tf.RunOptions.FULL_TRACE
                run_options.output_partition_graphs = True

            #feed_dict[weight_vec_placeholder] = ls_solution
            tf.logging.info("Data batch index: %s, Current epoch idex: %s" %
                            (str(epoch_counter), str(local_data_batch_idx)))
            loss_value, step = sess.run(
                #[train_op, global_step], feed_dict={feed_dict, x}, run_metadata=run_metadata, options=run_options)
                [train_op, global_step],
                feed_dict=feed_dict,
                run_metadata=run_metadata,
                options=run_options)

            if FLAGS.worker_times_cdf_method:
                timeout_client.broadcast_worker_finished_computing_gradients(
                    cur_iteration)
            assert not np.isnan(loss_value), 'Model diverged with loss = NaN'
            finish_time = time.time()
            if FLAGS.timeline_logging:
                tl = timeline.Timeline(run_metadata.step_stats)
                ctf = tl.generate_chrome_trace_format()
                with open(
                        '%s/worker=%d_timeline_iter=%d.json' %
                        (FLAGS.train_dir, FLAGS.task_id, step), 'w') as f:
                    f.write(ctf)
            if step > FLAGS.max_steps:
                break

            duration = time.time() - start_time
            examples_per_sec = FLAGS.batch_size / float(duration)
            format_str = ('Worker %d: %s: step %d, loss = %f '
                          '(%.1f examples/sec; %.3f sec/batch)')
            tf.logging.info(format_str %
                            (FLAGS.task_id, datetime.now(), step, loss_value,
                             examples_per_sec, duration))
            if is_chief and next_summary_time < time.time(
            ) and FLAGS.should_summarize:
                tf.logging.info('Running Summary operation on the chief.')
                summary_str = sess.run(summary_op)
                sv.summary_computed(sess, summary_str)
                tf.logging.info('Finished running Summary operation.')
                next_summary_time += FLAGS.save_summaries_secs
        #    except tf.errors.DeadlineExceededError:
        #        tf.logging.info("Killed at time %f" % time.time())
        #sess.reset_kill()
        #    except:
        #        tf.logging.info("Unexpected error: %s" % str(sys.exc_info()[0]))
        #sess.reset_kill()
        if is_chief:
            tf.logging.info('Elapsed Time: %f' % (time.time() - begin_time))
        sv.stop()

        if is_chief:
            saver.save(sess,
                       os.path.join(FLAGS.train_dir, 'model.ckpt'),
                       global_step=global_step)
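

# --- Illustrative sketch (not part of the original training code) ---
# The train() variant above optionally routes the gradients through a
# project-specific encode() (rank controlled by FLAGS.svd_rank) before
# LowCommSync aggregates them. Neither helper is defined in this listing, so
# the NumPy sketch below is only an assumption of the general idea: keep a
# rank-r truncated SVD of each 2-D gradient to reduce communication.
import numpy as np


def svd_compress(grad, rank):
    """Return the factors of a rank-`rank` approximation of a 2-D gradient."""
    u, s, vt = np.linalg.svd(grad, full_matrices=False)
    return u[:, :rank], s[:rank], vt[:rank, :]


def svd_decompress(u, s, vt):
    """Reconstruct the (approximate) gradient from its truncated factors."""
    return np.dot(u * s, vt)


# A 64x32 gradient compressed to rank 1: 64 + 1 + 32 numbers instead of 2048.
dense_grad = np.random.randn(64, 32)
approx_grad = svd_decompress(*svd_compress(dense_grad, rank=1))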
def train(target, dataset, dataset_test, cluster_spec):
    """Train Inception on a dataset for a number of steps."""
    # The numbers of workers and parameter servers are inferred from the worker
    # and ps host strings.
    num_workers = len(cluster_spec.as_dict()['worker'])
    num_parameter_servers = len(cluster_spec.as_dict()['ps'])
    # If no value is given, num_replicas_to_aggregate defaults to the number of
    # workers.
    if FLAGS.num_replicas_to_aggregate == -1:
        num_replicas_to_aggregate = num_workers
    else:
        num_replicas_to_aggregate = FLAGS.num_replicas_to_aggregate

    # Both should be greater than 0 in distributed training.
    assert num_workers > 0 and num_parameter_servers > 0, (
        'num_workers and num_parameter_servers must be > 0.')

    # Choose worker 0 as the chief. Note that any worker could be the chief
    # but there should be only one chief.
    is_chief = (FLAGS.task_id == 0)

    # Ops are assigned to worker by default.
    with tf.device(
            tf.train.replica_device_setter(
                worker_device='/job:worker/task:%d' % FLAGS.task_id,
                cluster=cluster_spec)):

        # Create a variable to count the number of train() calls. This equals the
        # number of updates applied to the variables. The PS holds the global step.
        global_step = tf.Variable(0, name="global_step", trainable=False)

        # Calculate the learning rate schedule.
        num_batches_per_epoch = (dataset.num_examples / FLAGS.batch_size)

        # Decay steps need to be divided by the number of replicas to aggregate.
        # This was the old decay schedule; we don't really want it since it
        # decays too fast with a fixed learning rate.
        decay_steps = int(num_batches_per_epoch * FLAGS.num_epochs_per_decay /
                          num_replicas_to_aggregate)
        # New decay schedule. Decay every few steps.
        #decay_steps = int(num_batches_per_epoch * FLAGS.num_epochs_per_decay / num_workers)

        # Decay the learning rate exponentially based on the number of steps.
        lr = tf.train.exponential_decay(FLAGS.initial_learning_rate,
                                        global_step,
                                        decay_steps,
                                        FLAGS.learning_rate_decay_factor,
                                        staircase=True)

        images, labels = mnist.placeholder_inputs(FLAGS.batch_size)
        # images_test, labels_test = mnist.placeholder_inputs(int(FLAGS.batch_size/6))

        # Number of classes in the Dataset label set plus 1.
        # Label 0 is reserved for an (unused) background class.
        logits = mnist.inference(images, train=True)
        # Test logits
        # logits_test = mnist.inference(images_test, train=False)

        # Add classification loss.
        total_loss = mnist.loss(logits, labels)

        # Add train accuracy
        train_acc = mnist.evaluation(logits, labels)
        # Test accuracy
        # test_acc = mnist.evaluation(logits_test, labels_test)

        # Create an optimizer that performs gradient descent.
        opt = tf.train.GradientDescentOptimizer(lr)

        # Use SyncReplicasOptimizer optimizer
        if FLAGS.interval_method or FLAGS.worker_times_cdf_method:
            opt = TimeoutReplicasOptimizer(opt,
                                           global_step,
                                           total_num_replicas=num_workers)
        else:
            opt = tf.train.SyncReplicasOptimizer(
                opt,
                replicas_to_aggregate=num_replicas_to_aggregate,
                total_num_replicas=num_workers)

        # Compute gradients with respect to the loss.
        grads = opt.compute_gradients(total_loss)

        # Apply drop connect if FLAGS.drop_connect is True.
        if FLAGS.drop_connect:
            bernoulli_sampler = tf.contrib.distributions.Bernoulli(
                p=FLAGS.drop_connect_probability)
            dropped_grads = [(drop_connect(gv[0], bernoulli_sampler), gv[1])
                             for gv in grads]

        if FLAGS.interval_method or FLAGS.worker_times_cdf_method:
            apply_gradients_op = opt.apply_gradients(
                grads,
                FLAGS.task_id,
                global_step=global_step,
                collect_cdfs=FLAGS.worker_times_cdf_method)
        else:
            if FLAGS.drop_connect:
                apply_gradients_op = opt.apply_gradients(
                    dropped_grads, global_step=global_step)
            else:
                apply_gradients_op = opt.apply_gradients(
                    grads, global_step=global_step)
        '''
        This block is from an older version; the current version only runs
        apply_gradients_op directly (see the sess.run calls below).
        with tf.control_dependencies([apply_gradients_op]):
            train_op = tf.identity(total_loss, name='train_op')
        '''

        # Get chief queue_runners, init_tokens and clean_up_op, which is used to
        # synchronize replicas.
        # More details can be found in sync_replicas_optimizer.
        chief_queue_runners = [opt.get_chief_queue_runner()]
        init_tokens_op = opt.get_init_tokens_op()
        #clean_up_op = opt.get_clean_up_op()

        # Create a saver.
        saver = tf.train.Saver()

        # Build the summary operation based on the TF collection of Summaries.
        summary_op = tf.summary.merge_all()

        # Build an initialization operation to run below.
        init_op = tf.global_variables_initializer()

        test_print_op = logging_ops.Print(0, [0], message="Test print success")

        # We run the summaries in the same thread as the training operations by
        # passing in None for summary_op to avoid a summary_thread being started.
        # Running summaries and training operations in parallel could run out of
        # GPU memory.
        if is_chief:
            local_init_op = opt.chief_init_op
        else:
            local_init_op = opt.local_step_init_op

        local_init_opt = [local_init_op]
        ready_for_local_init_op = opt.ready_for_local_init_op

        sv = tf.train.Supervisor(
            is_chief=is_chief,
            local_init_op=local_init_op,
            ready_for_local_init_op=ready_for_local_init_op,
            logdir=FLAGS.train_dir,
            init_op=init_op,
            summary_op=None,
            global_step=global_step,
            saver=saver,
            save_model_secs=FLAGS.save_interval_secs)

        tf.logging.info('%s Supervisor' % datetime.now())

        sess_config = tf.ConfigProto(
            allow_soft_placement=True,
            log_device_placement=FLAGS.log_device_placement)

        # Get a session.
        sess = sv.prepare_or_wait_for_session(target, config=sess_config)

        # Start the queue runners.
        queue_runners = tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS)
        sv.start_queue_runners(sess, queue_runners)
        tf.logging.info('Started %d queues for processing input data.',
                        len(queue_runners))

        if is_chief:
            if not FLAGS.interval_method or FLAGS.worker_times_cdf_method:
                sv.start_queue_runners(sess, chief_queue_runners)
            sess.run(init_tokens_op)

        # TIMEOUT client overseer.
        # Even if not using timeout, we want to wait until all machines are ready.
        timeout_client, timeout_server = launch_manager(sess, FLAGS)

        # Train, checking for Nans. Concurrently run the summary operation at a
        # specified interval. Note that the summary_op and train_op never run
        # simultaneously in order to prevent running out of GPU memory.
        next_summary_time = time.time() + FLAGS.save_summaries_secs
        begin_time = time.time()
        cur_iteration = -1
        iterations_finished = set()

        if FLAGS.task_id == 0 and FLAGS.interval_method:
            opt.start_interval_updates(sess, timeout_client)

        time_acc_list = []

        while not sv.should_stop():
            try:

                sys.stdout.flush()
                tf.logging.info("A new iteration...")

                # Increment current iteration
                cur_iteration += 1

                #sess.run([opt._wait_op], options=tf.RunOptions(timeout_in_ms=10000))
                #sess.run([opt._wait_op])
                #sess.run([test_print_op])

                if FLAGS.worker_times_cdf_method:
                    sess.run([opt._wait_op])
                    timeout_client.broadcast_worker_dequeued_token(
                        cur_iteration)

                start_time = time.time()
                feed_dict = mnist.fill_feed_dict(dataset, images, labels,
                                                 FLAGS.batch_size)
                # feed_dict_test = mnist.fill_feed_dict(dataset_test, images_test,
                #  labels_test, int(FLAGS.batch_size/6))

                run_options = tf.RunOptions()
                run_metadata = tf.RunMetadata()

                if FLAGS.timeline_logging:
                    run_options.trace_level = tf.RunOptions.FULL_TRACE
                    run_options.output_partition_graphs = True

                #timeout_ms = random.randint(300, 1200)
                #tf.logging.info("SETTING TIMEOUT FOR %d ms" % timeout_ms)
                #run_options.timeout_in_ms = 1000 * 60 * 1

                tf.logging.info("RUNNING SESSION... %f" % time.time())
                #if FLAGS.drop_connect:
                # sess.run(drop_connect_op, feed_dict=feed_dict, run_metadata=run_metadata,
                #   options=run_options)
                #  print(sess.run(grads, feed_dict=feed_dict, run_metadata=run_metadata,
                #    options=run_options))

                sess.run(apply_gradients_op,
                         feed_dict=feed_dict,
                         run_metadata=run_metadata,
                         options=run_options)
                loss_value, step, train_acc_value = sess.run(
                    [total_loss, global_step, train_acc],
                    feed_dict=feed_dict,
                    run_metadata=run_metadata,
                    options=run_options)
                # test_acc_value = sess.run(test_acc, feed_dict=feed_dict_test, run_metadata=run_metadata,
                #  options=run_options)

                #step, train_acc_value = sess.run([global_step, train_acc],
                #   feed_dict=feed_dict, run_metadata=run_metadata, options=run_options)
                tf.logging.info("Global step attained: %d" % step)
                tf.logging.info("DONE RUNNING SESSION...")

                if FLAGS.worker_times_cdf_method:
                    timeout_client.broadcast_worker_finished_computing_gradients(
                        cur_iteration)

                # The following assert sometimes causes problems, so it is
                # disabled for now.
                # assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

                # Log the elapsed time per iteration
                finish_time = time.time()

                # Create the Timeline object, and write it to a json
                if FLAGS.timeline_logging:
                    tl = timeline.Timeline(run_metadata.step_stats)
                    ctf = tl.generate_chrome_trace_format()
                    with open(
                            '%s/worker=%d_timeline_iter=%d.json' %
                        (FLAGS.train_dir, FLAGS.task_id, step), 'w') as f:
                        f.write(ctf)

                if step > FLAGS.max_steps:
                    break

                test_acc_value = 0.0

                duration = finish_time - start_time
                examples_per_sec = FLAGS.batch_size / float(duration)
                format_str = (
                    'Worker %d: %s: step %d, loss = %f, train_acc = %f, '
                    'test_acc = %f (%.1f examples/sec; %.3f sec/batch)')
                tf.logging.info(format_str %
                                (FLAGS.task_id, datetime.now(), step,
                                 loss_value, train_acc_value, test_acc_value,
                                 examples_per_sec, duration))

                time_acc_list.append(
                    (finish_time, train_acc_value, test_acc_value, loss_value))

                # Save the results when step % FLAGS.save_results_period == 0
                if step % FLAGS.save_results_period == 0:
                    time_acc_file_name = FLAGS.train_dir + (
                        '/worker%d_time_acc.npy' % FLAGS.task_id)
                    # np.save(loss_file_name, loss_list)
                    np.save(time_acc_file_name, time_acc_list)

                # Determine if the summary_op should be run on the chief worker.
                if is_chief and next_summary_time < time.time(
                ) and FLAGS.should_summarize:

                    tf.logging.info('Running Summary operation on the chief.')
                    summary_str = sess.run(summary_op)
                    sv.summary_computed(sess, summary_str)
                    tf.logging.info('Finished running Summary operation.')

                    # Determine the next time for running the summary.
                    next_summary_time += FLAGS.save_summaries_secs
            except tf.errors.DeadlineExceededError:
                tf.logging.info("Killed at time %f" % time.time())
                sess.reset_kill()
            except:
                tf.logging.info("Unexpected error: %s" %
                                str(sys.exc_info()[0]))
                sess.reset_kill()

        if is_chief:
            tf.logging.info('Elapsed Time: %f' % (time.time() - begin_time))

        # Stop the supervisor.  This also waits for service threads to finish.
        sv.stop()

        # Save after the training ends.
        if is_chief:
            saver.save(sess,
                       os.path.join(FLAGS.train_dir, 'model.ckpt'),
                       global_step=global_step)
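

# --- Illustrative sketch (not part of the original training code) ---
# The drop-connect branch above masks each gradient with samples from a
# Bernoulli(p = FLAGS.drop_connect_probability) distribution via a
# drop_connect() helper that is not shown in this listing. A plausible NumPy
# rendering of that idea (whether the original rescales by the keep
# probability is not visible here) is sketched below.
import numpy as np


def drop_connect_sketch(grad, keep_prob, rng=np.random):
    """Zero each gradient entry independently with probability 1 - keep_prob."""
    mask = rng.binomial(1, keep_prob, size=grad.shape)
    return grad * mask


dense_grad = np.random.randn(4, 4)
masked_grad = drop_connect_sketch(dense_grad, keep_prob=0.8)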