def train():
    """Train Inception v3 as one task of a distributed TensorFlow cluster.

    The role is taken from ``FLAGS.job_name``:

    * ``'ps'``: ps task 0 first starts the batch-size RPC server
      (``rpcServer.serve()``); every ps task then blocks in
      ``server.join()``.
    * any other job (worker): builds the training graph -- variables are
      placed on the ps tasks by ``tf.train.replica_device_setter`` and
      partitioned with ``tf.fixed_size_partitioner`` -- and runs RMSProp
      steps (applied directly, without ``SyncReplicasOptimizer``) until the
      supervisor asks to stop or ``global_step`` exceeds ``FLAGS.max_steps``.

    Gradients are multiplied by ``batch_size / FLAGS.batch_size`` (the batch
    size actually fed over the nominal flag value) before being applied.

    Raises:
      AssertionError: if the dataset reports no data files.
      Any exception raised inside the training loop is logged (chief only)
        and re-raised.
    """
    ps_hosts = FLAGS.ps_hosts.split(',')
    worker_hosts = FLAGS.worker_hosts.split(',')
    tf.logging.info('PS hosts are: %s' % ps_hosts)
    tf.logging.info('Worker hosts are: %s' % worker_hosts)

    cluster_spec = tf.train.ClusterSpec({
        'ps': ps_hosts,
        'worker': worker_hosts
    })
    server = tf.train.Server(
        {
            'ps': ps_hosts,
            'worker': worker_hosts
        },
        job_name=FLAGS.job_name,
        task_index=FLAGS.task_id,
        protocol=FLAGS.protocol)
    batchSizeManager = BatchSizeManager(FLAGS.batch_size, len(worker_hosts))
    if FLAGS.job_name == 'ps':
        if FLAGS.task_id == 0:
            # Only ps task 0 hosts the batch-size RPC service, bound to the
            # host part of ps_hosts[0]; workers connect to the same host.
            rpcServer = batchSizeManager.create_rpc_server(
                ps_hosts[0].split(':')[0])
            rpcServer.serve()
        # Parameter servers only serve variables; join() blocks.
        server.join()

    dataset = ImagenetData(subset=FLAGS.subset)
    # NOTE(review): rpcClient is not used below -- the dynamic batch-size
    # negotiation in the training loop is currently disabled.
    rpcClient = batchSizeManager.create_rpc_client(ps_hosts[0].split(':')[0])
    assert dataset.data_files()
    # Only the chief checks for or creates train_dir.
    if FLAGS.task_id == 0:
        if not tf.gfile.Exists(FLAGS.train_dir):
            tf.gfile.MakeDirs(FLAGS.train_dir)
    num_workers = len(cluster_spec.as_dict()['worker'])
    num_parameter_servers = len(cluster_spec.as_dict()['ps'])
    if FLAGS.num_replicas_to_aggregate == -1:
        num_replicas_to_aggregate = num_workers
    else:
        num_replicas_to_aggregate = FLAGS.num_replicas_to_aggregate

    # Both should be greater than 0 in a distributed training.
    assert num_workers > 0 and num_parameter_servers > 0, (
        ' num_workers and '
        'num_parameter_servers'
        ' must be > 0.')

    # Choose worker 0 as the chief. Note that any worker could be the chief
    # but there should be only one chief.
    is_chief = (FLAGS.task_id == 0)

    # Ops are assigned to worker by default.
    tf.logging.info('cccc-num_parameter_servers:' + str(num_parameter_servers))
    # Partitionable variables created under the 'root' scope are split along
    # axis 0 into one shard per parameter server.
    partitioner = tf.fixed_size_partitioner(num_parameter_servers, 0)

    device_setter = tf.train.replica_device_setter(
        ps_tasks=num_parameter_servers)
    slim = tf.contrib.slim
    with tf.device('/job:worker/task:%d' % FLAGS.task_id):
        with tf.variable_scope('root', partitioner=partitioner):
            # Variables and their init/assign ops are placed on the ps tasks.
            with tf.device(device_setter):
                # Counts applied updates; incremented by apply_gradients.
                global_step = tf.Variable(0, trainable=False)

                # Batch size actually fed on each step (may differ from
                # FLAGS.batch_size).
                batch_size = tf.placeholder(dtype=tf.int32,
                                            shape=(),
                                            name='batch_size')

                # Calculate the learning rate schedule.
                num_batches_per_epoch = (dataset.num_examples_per_epoch() /
                                         FLAGS.batch_size)
                # Decay steps need to be divided by the number of replicas
                # to aggregate.
                decay_steps = int(num_batches_per_epoch *
                                  FLAGS.num_epochs_per_decay /
                                  num_replicas_to_aggregate)

                # Decay the learning rate exponentially based on the number
                # of steps.
                lr = tf.train.exponential_decay(
                    FLAGS.initial_learning_rate,
                    global_step,
                    decay_steps,
                    FLAGS.learning_rate_decay_factor,
                    staircase=True)

                images, labels = image_processing.distorted_inputs(
                    dataset,
                    batch_size,
                    num_preprocess_threads=FLAGS.num_preprocess_threads)
                print(images.get_shape())
                print(labels.get_shape())

                # No extra background class is added here.
                # NOTE(review): if the dataset's labels start at 1 (label 0
                # reserved for background), num_classes should be
                # dataset.num_classes() + 1 -- confirm against ImagenetData.
                num_classes = dataset.num_classes()
                print(num_classes)
                network_fn = nets_factory.get_network_fn(
                    'inception_v3', num_classes=num_classes)
                (logits, _) = network_fn(images)
                print(logits.get_shape())

                # The one-hot depth must equal the logits width.
                labels = tf.one_hot(labels, num_classes, 1, 0)
                cross_entropy = tf.losses.softmax_cross_entropy(
                    logits=logits, onehot_labels=labels)
                losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
                # Explicit L2 weight decay over every trainable variable.
                total_loss = cross_entropy + _WEIGHT_DECAY * tf.add_n(
                    [tf.nn.l2_loss(v) for v in tf.trainable_variables()])

                loss_averages = tf.train.ExponentialMovingAverage(0.9,
                                                                  name='avg')
                loss_averages_op = loss_averages.apply(losses + [total_loss])

                with tf.control_dependencies([loss_averages_op]):
                    opt = tf.train.RMSPropOptimizer(lr,
                                                    RMSPROP_DECAY,
                                                    momentum=RMSPROP_MOMENTUM,
                                                    epsilon=RMSPROP_EPSILON)
                    grads0 = opt.compute_gradients(total_loss)
                    # Scale each gradient by fed / nominal batch size.
                    grads = [(tf.scalar_mul(
                        tf.cast(batch_size / FLAGS.batch_size, tf.float32),
                        grad), var) for grad, var in grads0]
                    total_loss = tf.identity(total_loss)

                exp_moving_averager = tf.train.ExponentialMovingAverage(
                    MOVING_AVERAGE_DECAY, global_step)
                variables_averages_op = exp_moving_averager.apply(
                    tf.trainable_variables())

                apply_gradients_op = opt.apply_gradients(
                    grads, global_step=global_step)

                with tf.control_dependencies(
                        [apply_gradients_op, variables_averages_op]):
                    train_op = tf.identity(total_loss, name='train_op')

                saver = tf.train.Saver()

                # Build an initialization operation to run below.
                init_op = tf.global_variables_initializer()

                # summary_op=None: no summary thread is started.
                # NOTE(review): saver=None means the Supervisor saves no
                # checkpoints, so save_model_secs has no effect and the Saver
                # built above is unused -- confirm this is intended.
                sv = tf.train.Supervisor(
                    is_chief=is_chief,
                    logdir=FLAGS.train_dir,
                    init_op=init_op,
                    summary_op=None,
                    global_step=global_step,
                    recovery_wait_secs=1,
                    saver=None,
                    save_model_secs=FLAGS.save_interval_secs)

                tf.logging.info('%s Supervisor' % datetime.now())

                sess_config = tf.ConfigProto(
                    allow_soft_placement=True,
                    log_device_placement=FLAGS.log_device_placement)

                # Get a session.
                sess = sv.prepare_or_wait_for_session(server.target,
                                                      config=sess_config)

                # Start the queue runners.
                queue_runners = tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS)
                sv.start_queue_runners(sess, queue_runners)
                tf.logging.info('Started %d queues for processing input data.',
                                len(queue_runners))

                step = 0
                time0 = time.time()
                # Batch size fed on every step. Set once, outside the loop,
                # so a future per-step update (e.g. via rpcClient) would not
                # be overwritten at the top of each iteration.
                batch_size_num = 32
                while not sv.should_stop():
                    try:
                        start_time = time.time()

                        # NOTE(review): FULL_TRACE is requested on every step,
                        # but run_metadata is not consumed anywhere below.
                        run_options = tf.RunOptions(
                            trace_level=tf.RunOptions.FULL_TRACE)
                        run_metadata = tf.RunMetadata()

                        loss_value, step = sess.run(
                            [train_op, global_step],
                            feed_dict={batch_size: batch_size_num},
                            options=run_options,
                            run_metadata=run_metadata)
                        b = time.time()
                        if step > FLAGS.max_steps:
                            break
                        duration = time.time() - start_time
                        # Batch-size negotiation would happen between c0 and
                        # c; it is currently disabled, so that interval is ~0.
                        c0 = time.time()

                        examples_per_sec = batch_size_num / float(duration)
                        c = time.time()
                        tf.logging.info("time statistics" +
                                        " - train_time: " +
                                        str(b - start_time) +
                                        " - get_batch_time: " +
                                        str(c0 - b) + " - get_bs_time:  " +
                                        str(c - c0) + " - accum_time: " +
                                        str(c - time0) +
                                        " - batch_size: " +
                                        str(batch_size_num))
                        format_str = (
                            'Worker %d: %s: step %d, loss = %.2f'
                            '(%.1f examples/sec; %.3f  sec/batch)')
                        tf.logging.info(
                            format_str %
                            (FLAGS.task_id, datetime.now(), step,
                             loss_value, examples_per_sec, duration))
                    except BaseException:
                        # Log-and-re-raise: catching everything is deliberate.
                        if is_chief:
                            tf.logging.info(
                                'Chief got exception while running!')
                        raise

                # Stop the supervisor.  This also waits for service threads to finish.
                sv.stop()
def main(argv=None):
    """Run one task of distributed Inception training coordinated by SspManager.

    ``FLAGS.job_name == 'ps'``: ps task 0 starts the SspManager RPC server,
    then every ps task blocks in ``server.join()``.

    Otherwise (worker): after a fixed 5 s sleep, connects to the RPC server
    on the host of ``ps_hosts[0]``, builds the Inception graph with variables
    placed by ``slim.variables.VariableDeviceChooser`` and runs up to
    ``FLAGS.max_steps`` RMSProp steps. After each step it calls
    ``rpcClient.check_staleness(FLAGS.task_id, step)``. The loop also exits
    early if the supervisor requests a stop, and the supervisor is stopped
    afterwards.

    Args:
      argv: Unused; accepted for ``tf.app.run`` compatibility.

    Raises:
      AssertionError: if the dataset has no data files, batch-norm update ops
        are missing, or the loss becomes NaN.
    """
    ps_hosts = FLAGS.ps_hosts.split(',')
    worker_hosts = FLAGS.worker_hosts.split(',')
    tf.logging.info('PS hosts are: %s' % ps_hosts)
    tf.logging.info('Worker hosts are: %s' % worker_hosts)
    cluster_spec = tf.train.ClusterSpec({
        'ps': ps_hosts,
        'worker': worker_hosts
    })
    server = tf.train.Server(
        {
            'ps': ps_hosts,
            'worker': worker_hosts
        },
        job_name=FLAGS.job_name,
        task_index=FLAGS.task_id,
        protocol=FLAGS.protocol)

    # NOTE(review): the second argument (5) is presumably the staleness
    # bound -- confirm against SspManager.
    sspManager = SspManager(len(worker_hosts), 5)
    if FLAGS.job_name == 'ps':
        if FLAGS.task_id == 0:
            rpcServer = sspManager.create_rpc_server(ps_hosts[0].split(':')[0])
            rpcServer.serve()
        server.join()

    # Fixed delay, presumably to let ps task 0 bring up its RPC server
    # before workers connect.
    time.sleep(5)
    rpcClient = sspManager.create_rpc_client(ps_hosts[0].split(':')[0])

    dataset = ImagenetData(subset=FLAGS.subset)
    assert dataset.data_files()
    is_chief = (FLAGS.task_id == 0)
    if is_chief:
        if not tf.gfile.Exists(FLAGS.train_dir):
            tf.gfile.MakeDirs(FLAGS.train_dir)

    num_workers = len(cluster_spec.as_dict()['worker'])
    num_parameter_servers = len(cluster_spec.as_dict()['ps'])

    with tf.device('/job:worker/task:%d' % FLAGS.task_id):
        with slim.scopes.arg_scope(
                [slim.variables.variable, slim.variables.global_step],
                device=slim.variables.VariableDeviceChooser(
                    num_parameter_servers)):
            # Prepare input.
            global_step = slim.variables.global_step()
            # Batch size actually fed on each step.
            batch_size = tf.placeholder(dtype=tf.int32,
                                        shape=(),
                                        name='batch_size')
            images, labels = image_processing.distorted_inputs(
                dataset,
                batch_size,
                num_preprocess_threads=FLAGS.num_preprocess_threads)
            # +1 for the background class.
            num_classes = dataset.num_classes() + 1

            # Inference.
            logits = inception.inference(images,
                                         num_classes,
                                         for_training=True)

            # Loss.
            inception.loss(logits, labels, batch_size)
            losses = tf.get_collection(slim.losses.LOSSES_COLLECTION)
            losses += tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
            total_loss = tf.add_n(losses, name='total_loss')
            if is_chief:
                # Only the chief maintains moving averages of the losses.
                loss_averages = tf.train.ExponentialMovingAverage(0.9,
                                                                  name='avg')
                loss_averages_op = loss_averages.apply(losses + [total_loss])
                with tf.control_dependencies([loss_averages_op]):
                    total_loss = tf.identity(total_loss)

            # Optimizer.
            # NOTE(review): no moving averages of the model variables are
            # maintained here (unlike train()); add an
            # ExponentialMovingAverage.apply(...) dependency if evaluation
            # expects them.
            num_batches_per_epoch = (dataset.num_examples_per_epoch() /
                                     FLAGS.batch_size)
            decay_steps = int(num_batches_per_epoch *
                              FLAGS.num_epochs_per_decay / num_workers)
            lr = tf.train.exponential_decay(FLAGS.initial_learning_rate,
                                            global_step,
                                            decay_steps,
                                            FLAGS.learning_rate_decay_factor,
                                            staircase=True)
            opt = tf.train.RMSPropOptimizer(lr,
                                            RMSPROP_DECAY,
                                            momentum=RMSPROP_MOMENTUM,
                                            epsilon=RMSPROP_EPSILON)

            # Train operation: the loss (and therefore the gradients) depends
            # on the batch-norm update ops.
            batchnorm_updates = tf.get_collection(
                slim.ops.UPDATE_OPS_COLLECTION)
            assert batchnorm_updates, 'Batchnorm updates are missing'
            batchnorm_updates_op = tf.group(*batchnorm_updates)
            with tf.control_dependencies([batchnorm_updates_op]):
                total_loss = tf.identity(total_loss)
            naive_grads = opt.compute_gradients(total_loss)
            # Scale each gradient by fed / nominal batch size.
            grads = [(tf.scalar_mul(
                tf.cast(batch_size / FLAGS.batch_size, tf.float32), grad), var)
                     for grad, var in naive_grads]
            apply_gradients_op = opt.apply_gradients(grads,
                                                     global_step=global_step)
            with tf.control_dependencies([apply_gradients_op]):
                train_op = tf.identity(total_loss, name='train_op')

            # Supervisor and session.
            saver = tf.train.Saver()
            init_op = tf.global_variables_initializer()
            sv = tf.train.Supervisor(is_chief=is_chief,
                                     logdir=FLAGS.train_dir,
                                     init_op=init_op,
                                     summary_op=None,
                                     global_step=global_step,
                                     recovery_wait_secs=1,
                                     saver=saver,
                                     save_model_secs=FLAGS.save_interval_secs)
            tf.logging.info('%s Supervisor' % datetime.now())
            sess_config = tf.ConfigProto(
                allow_soft_placement=True,
                log_device_placement=FLAGS.log_device_placement)
            sess = sv.prepare_or_wait_for_session(server.target,
                                                  config=sess_config)
            queue_runners = tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS)

            # Start training.
            sv.start_queue_runners(sess, queue_runners)
            tf.logging.info('Started %d queues for processing input data.',
                            len(queue_runners))

            batch_size_num = FLAGS.batch_size
            for step in range(FLAGS.max_steps):
                # Honor stop requests (e.g. a failed queue runner) instead of
                # running further steps on a stopping session.
                if sv.should_stop():
                    break
                start_time = time.time()
                run_options = tf.RunOptions(
                    trace_level=tf.RunOptions.FULL_TRACE)
                run_metadata = tf.RunMetadata()
                loss_value, gs = sess.run(
                    [train_op, global_step],
                    feed_dict={batch_size: batch_size_num},
                    options=run_options,
                    run_metadata=run_metadata)

                assert not np.isnan(
                    loss_value), 'Model diverged with loss = NaN'

                duration = time.time() - start_time
                examples_per_sec = batch_size_num / float(duration)
                sec_per_batch = float(duration)
                format_str = (
                    "time: " + str(time.time()) +
                    '; %s: step %d (gs %d), loss= %.2f (%.1f samples/s; %.3f s/batch)'
                )
                tf.logging.info(format_str %
                                (datetime.now(), step, gs, loss_value,
                                 examples_per_sec, sec_per_batch))
                rpcClient.check_staleness(FLAGS.task_id, step)

            # Stop the supervisor.  This also waits for service threads to finish.
            sv.stop()
# Example #3
0
                            """Number of iterations to run.""")
# NOTE(review): the default 'model/DCNet_' looks like a file-name prefix
# rather than a directory -- confirm the help text.
tf.app.flags.DEFINE_string('model_file', 'model/DCNet_', """Directory to save model""")


# Fed True to use the distorted training inputs, False for validation inputs.
is_training = tf.placeholder("bool")

train_set = ImagenetData(subset='train')
tr_images, tr_labels = alex2012_image_processing.distorted_inputs(train_set)

val_set = ImagenetData(subset='validation')
val_images, val_labels = alex2012_image_processing.inputs(val_set)

# NOTE(review): both input tensors are created outside tf.cond, so the ops
# producing them (e.g. queue dequeues) execute on every step regardless of
# is_training; only the selection happens inside the cond.
images, labels = tf.cond(is_training,
                         lambda: [tr_images, tr_labels],
                         lambda: [val_images, val_labels])

cnn = VGG()
cnn.build(images, train_set.num_classes(), is_training)

fit_loss = loss2(cnn.score, labels, train_set.num_classes(), 'c_entropy')
# tf.add_n raises on an empty list; treat an empty collection as a zero loss.
reg_losses = tf.losses.get_regularization_losses()
reg_loss = tf.add_n(reg_losses) if reg_losses else tf.constant(0.0)
orth_constraints = tf.get_collection('orth_constraint')
orth_loss = tf.add_n(orth_constraints) if orth_constraints else tf.constant(0.0)
loss_op = fit_loss + orth_loss + reg_loss

lr_ = tf.placeholder("float")

# Trainable filter variables, excluding those whose names contain 'score'
# (the layer producing cnn.score) or 'shortcut'.
weight_list = [v for v in tf.trainable_variables()
               if ('/filter' in v.name and 'score' not in v.name
                   and 'shortcut' not in v.name)]
# Overwrite each selected filter with the tensor cnn.sphere_dict holds under
# its variable name; running assign_op applies all assignments together.
assign_op_list = [tf.assign(w, cnn.sphere_dict[w.name]) for w in weight_list]
assign_op = tf.group(*assign_op_list)