def tower_loss(scope, images1, labels1, hots1, images2, labels2, hots2):
    """Calculate the total loss on a single tower running the multi-task_cnn model.
    Args:
      scope: unique prefix string identifying the multi-task_cnn tower, e.g. 'tower_0'
      images: Images. 4D tensor of shape [batch_size, height, width, 3].
      labels: Labels. 1D tensor of shape [batch_size].
    Returns:
       Tensor of shape [] containing the total loss for a batch of data
    """

    # Build inference Graph.
    logits1 = cnn.inference(images1, n_cnn=5)

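    # Reuse the variables created by the first inference call so the second
    # branch shares the same weights.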
    tf.get_variable_scope().reuse_variables()
    logits2 = cnn.inference(images2, n_cnn=5)

    # Build the portion of the Graph calculating the losses. Note that we will
    # assemble the total_loss using a custom function below.
    _ = cnn.loss(logits1, labels1, hots1, logits2, labels2, hots2, loss_type=1)

    # Assemble all of the losses for the current tower only.
    losses = tf.get_collection('losses', scope)

    # Calculate the total loss for the current tower.
    total_loss = tf.add_n(losses, name='total_loss')

    # Attach a scalar summary to all individual losses and the total loss; do the
    # same for the averaged version of the losses.
    for l in losses + [total_loss]:
        # Remove 'tower_[0-9]/' from the name in case this is a multi-GPU training
        # session. This helps the clarity of presentation on tensorboard.
        loss_name = re.sub('%s_[0-9]*/' % cnn.TOWER_NAME, '', l.op.name)
        tf.summary.scalar(loss_name, l)

    return total_loss
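For context, tower_loss is normally driven from a multi-GPU training loop in the style of the TensorFlow CIFAR-10 tutorial. The sketch below is an assumption about that surrounding loop (NUM_GPUS and the per-GPU batch list are hypothetical); it reuses the cnn module and tower_loss from the snippet above and is not part of the original code.

import tensorflow as tf

NUM_GPUS = 2  # hypothetical

def build_tower_losses(gpu_batches):
    """gpu_batches: one (images1, labels1, hots1, images2, labels2, hots2)
    tuple per GPU."""
    tower_losses = []
    with tf.variable_scope(tf.get_variable_scope()):
        for i, batch in enumerate(gpu_batches):
            with tf.device('/gpu:%d' % i):
                with tf.name_scope('%s_%d' % (cnn.TOWER_NAME, i)) as scope:
                    # tower_loss flips the enclosing variable scope to reuse
                    # after its first inference call, so both branches and all
                    # towers end up sharing a single set of weights.
                    tower_losses.append(tower_loss(scope, *batch))
    return tower_losses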
Example 2
def train():
    """Train CGCNN for a number of steps."""
    with tf.Graph().as_default():
        global_step = tf.train.get_or_create_global_step()

        # Get data for training
        # Force input pipeline to CPU:0 to avoid operations sometimes ending up on
        # GPU and resulting in a slow down.
        with tf.device('/cpu:0'):
            energies, sites_matrices, adj_matrices = cnn.inputs(
                eval_data=False)

        # Build a Graph that computes the energy predictions from the
        # inference model.
        energies_hat = cnn.inference(sites_matrices, adj_matrices)

        # Calculate loss.
        loss = cnn.loss(energies_hat, energies)

        # Build a Graph that trains the model with one batch of examples and
        # updates the model parameters.
        train_op = cnn.train(loss, global_step)

        class _LoggerHook(tf.train.SessionRunHook):
            """Logs loss and runtime."""
            def begin(self):
                self._step = -1
                self._start_time = time.time()

            def before_run(self, run_context):
                self._step += 1
                return tf.train.SessionRunArgs(loss)  # Asks for loss value.

            def after_run(self, run_context, run_values):
                if self._step % FLAGS.log_frequency == 0:
                    current_time = time.time()
                    duration = current_time - self._start_time
                    self._start_time = current_time

                    loss_value = run_values.results
                    examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
                    sec_per_batch = float(duration / FLAGS.log_frequency)

                    format_str = (
                        '%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
                    print(format_str % (datetime.now(), self._step, loss_value,
                                        examples_per_sec, sec_per_batch))

        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=FLAGS.train_dir,
                hooks=[
                    tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
                    tf.train.NanTensorHook(loss),
                    _LoggerHook()
                ],
                config=tf.ConfigProto(
                    log_device_placement=FLAGS.log_device_placement)) as mon_sess:
            while not mon_sess.should_stop():
                mon_sess.run(train_op)
Example 3
def train():
    with tf.Graph().as_default():
        global_step = tf.Variable(0, trainable=False)

        images, labels = read_record.read_and_decode(FLAGS.data_dir +
                                                     '/train.tfrecords')
        image_batch, label_batch = cnn.inputs(images, labels, FLAGS.batch_size)

        logits = cnn.cnn_model(image_batch)
        loss = cnn.loss(logits, label_batch)
        train_op = cnn.train(loss, global_step, FLAGS.batch_size)

        saver = tf.train.Saver(tf.global_variables())

        summary_op = tf.summary.merge_all()

        init = tf.global_variables_initializer()

        sess = tf.Session(config=tf.ConfigProto(
            log_device_placement=FLAGS.log_device_placement))

        sess.run(init)

        tf.train.start_queue_runners(sess=sess)

        summary_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph)

        loss_list = []
        for step in range(FLAGS.max_steps):
            start_time = time.time()
            _, loss_value = sess.run([train_op, loss])
            duration = time.time() - start_time

            assert not np.isnan(loss_value), 'Model diverged with loss = NaN'
            loss_list.append(loss_value)

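            # 465 is presumably the number of batches per epoch for this dataset.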
            if step % 465 == 0:
                num_examples_per_step = FLAGS.batch_size
                examples_per_sec = num_examples_per_step / duration
                sec_per_batch = float(duration)
                average_loss_value = np.mean(loss_list)
                loss_list.clear()
                format_str = (
                    '%s: epoch %d, loss = %.4f (%.1f examples/sec; %.3f '
                    'sec/batch)')
                print(format_str %
                      (datetime.now(), step // 465, average_loss_value,
                       examples_per_sec, sec_per_batch))

            if step % (465 * 30 + 1) == 0:
                checkpoint_path = os.path.join(FLAGS.checkpoint_dir,
                                               'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)
Example 4
def main(args=None):

    if tf.gfile.Exists(FLAGS.train_dir):
        tf.gfile.DeleteRecursively(FLAGS.train_dir)
    tf.gfile.MakeDirs(FLAGS.train_dir)


    with tf.Graph().as_default():
        images, labels = network.train_set()
        logits = network.inference(images)
        loss = network.loss(logits, labels)
        train = network.train(loss, 0.01)

        summary = tf.merge_all_summaries()
        init = tf.initialize_all_variables()

        saver = tf.train.Saver()
        with tf.Session() as sess:
            summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, sess.graph)
            sess.run(init)

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)

            try:
                for step in range(300):
                    if not coord.should_stop():
                        _, loss_value = sess.run([train, loss])
                        print('Step %d: loss = %.2f' % (step, loss_value))

                        summary_str = sess.run(summary)
                        summary_writer.add_summary(summary_str, step)
                        summary_writer.flush()

                        checkpoint_file = os.path.join(FLAGS.train_dir, 'checkpoint')
                        saver.save(sess, checkpoint_file, global_step=step)

            except tf.errors.OutOfRangeError:
                print('Done training -- epoch limit reached')
            finally:
                coord.request_stop()

            coord.join(threads)
Example 5
def run_training():
    # for mnist
    # train_data, test_data, validation_data = input_data.read_data_sets("../data/MNIST_data/")
    # for cifar-10
    train_data, test_data, validation_data = input_data.load_data()

    with tf.Graph().as_default():
        image_pl, label_pl, keep_prob_pl = place_holder(FLAGS.batch_size)
        logits = nn_structure.inference(image_pl, conv_1_params,
                                        max_pool_1_params, conv_2_params,
                                        max_pool_2_params,
                                        full_connected_units, keep_prob_pl)
        loss = nn_structure.loss(logits, label_pl)
        train_op = nn_structure.train(loss, FLAGS.learning_rate)
        eval_correct = nn_structure.evaluation(logits, label_pl, k=1)
        init = tf.initialize_all_variables()

        with tf.Session() as sess:
            sess.run(init)
            start_time = time.time()
            for step in range(FLAGS.max_step):
                feed_dict = fill_feed_dict(train_data, 0.5, image_pl, label_pl,
                                           keep_prob_pl)
                _, loss_value = sess.run([train_op, loss], feed_dict)

                if step % 100 == 0:
                    duration = time.time() - start_time
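                    # duration spans the 100 steps since the last reset, so
                    # duration * 10 is milliseconds per step.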
                    print("Step: {:d}, Training Loss: {:.4f}, {:.1f}ms/step".
                          format(step, loss_value, duration * 10))
                    start_time = time.time()

                if (step + 1) % 1000 == 0 or (step + 1) == FLAGS.max_step:
                    print("Train Eval:")
                    do_eval(sess, eval_correct, train_data, image_pl, label_pl,
                            keep_prob_pl)
                    print("Validation Eval:")
                    do_eval(sess, eval_correct, validation_data, image_pl,
                            label_pl, keep_prob_pl)
                    print("Test Eval:")
                    do_eval(sess, eval_correct, test_data, image_pl, label_pl,
                            keep_prob_pl)
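This snippet relies on place_holder, fill_feed_dict and do_eval helpers that are not shown. A minimal sketch of fill_feed_dict, assuming the dataset objects expose a next_batch(batch_size) method like the standard tutorial readers do (the real helper may differ):

def fill_feed_dict(data_set, keep_prob, image_pl, label_pl, keep_prob_pl):
    # Pull one batch and map it onto the placeholders; keep_prob feeds the
    # dropout placeholder.
    images, labels = data_set.next_batch(FLAGS.batch_size)
    return {image_pl: images, label_pl: labels, keep_prob_pl: keep_prob}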
Example 6
    csvlist[0].append("step")
    csvlist[0].append("accuracy")
    csvlist[0].append("loss")

    with tf.Graph().as_default():
        # image tensor
        images_placeholder = tf.placeholder("float",
                                            shape=(None, nn.IMAGE_PIXELS))
        # label tensor
        labels_placeholder = tf.placeholder("float",
                                            shape=(None, nn.NUM_CLASSES))
        # dropout tensor
        keep_prob = tf.placeholder("float")

        logits = nn.inference(images_placeholder, keep_prob)
        loss_value = nn.loss(logits, labels_placeholder)
        train_op = nn.training(loss_value, FLAGS.learning_rate)
        acc = nn.accuracy(logits, labels_placeholder)

        saver = tf.train.Saver()
        sess = tf.Session()
        sess.run(tf.global_variables_initializer())
        summary_op = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph)

        # train
        for step in range(FLAGS.max_steps):
            for i in range(len(train_image) // FLAGS.batch_size):
                batch = FLAGS.batch_size * i
                sess.run(train_op,
                         feed_dict={
Example 7
        BATCH_SIZE,
        image_size=IMAGE_SIZE,
        ch_size=CH_SIZE,
        shuffle = True,
        distored = True)
"""

#output=mynn.inference2(images,keep_prob,IMAGE_SIZE,CH_SIZE,NUM_CLASS)
output = mynn.inference(images, keep_prob, IMAGE_SIZE, CH_SIZE, NUM_CLASS)
validate = mynn.inference(v_images,
                          keep_prob,
                          IMAGE_SIZE,
                          CH_SIZE,
                          NUM_CLASS,
                          validate=True)
loss = mynn.loss(output, labels)
train_op = mynn.training(loss)
acc = mynn.accuracy(validate, v_labels)

with tf.Session() as sess:
    saver = tf.train.Saver(max_to_keep=0)
    sess.run(tf.initialize_all_variables())
    ckpt = tf.train.get_checkpoint_state('/output/')
    print(ckpt)
    saver.restore(sess, '/output/model.ckpt-%s' % (94000))
    # Write the graph with the summary writer.
    tf.train.start_queue_runners(sess)
    summary_op = tf.summary.merge_all()
    summary_writer = tf.summary.FileWriter(LOGDIR, graph=sess.graph)
    for step in range(MAX_STEPS):
        start_time = time.time()
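Rather than hard-coding checkpoint step 94000, the CheckpointState returned by tf.train.get_checkpoint_state already names the most recent checkpoint; a small variation (an assumption, not the original code):

    ckpt = tf.train.get_checkpoint_state('/output/')
    if ckpt and ckpt.model_checkpoint_path:
        # Restore whichever checkpoint the state file points at.
        saver.restore(sess, ckpt.model_checkpoint_path)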
Example 8
def train():
  with tf.Graph().as_default():
    global_step = tf.Variable(0, trainable=False)
    images, labels = cnn.distorted_inputs()
    logits = cnn.inference(images)
    loss = cnn.loss(logits, labels)
    train_op = cnn.train(loss, global_step)
    summary_op = tf.merge_all_summaries()
    init = tf.initialize_all_variables()
    sess = tf.Session(config=tf.ConfigProto(
        log_device_placement=LOG_DEVICE_PLACEMENT))
    saver = tf.train.Saver(tf.all_variables())

    if tf.gfile.Exists(TRAIN_DIR):
      ckpt = tf.train.get_checkpoint_state(CHECKPOINT_DIR)
      if ckpt and ckpt.model_checkpoint_path:
        last_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
        ckpt_path = os.path.join(CHECKPOINT_DIR, "model.ckpt-" + last_step)
        tf.gfile.DeleteRecursively(TRAIN_DIR)
        saver.restore(sess, ckpt_path)
        sess.run(global_step.assign(int(last_step)))
        print("Read old model from: ", ckpt_path)
        print("Starting training at: ", sess.run(global_step))
      else:
        tf.gfile.DeleteRecursively(TRAIN_DIR)
        sess.run(init)
        print("No model found. Starting training at: ", sess.run(global_step))
    else:
      tf.gfile.MakeDirs(TRAIN_DIR)
      sess.run(init)
      print("No folder found. Starting training at: ", sess.run(global_step))
    print("Writing train results to: ", TRAIN_DIR)
    print("Train file: ", TRAIN_FILE)
    # Start the queue runners.
    tf.train.start_queue_runners(sess=sess)

    summary_writer = tf.train.SummaryWriter(TRAIN_DIR,
                                            graph_def=sess.graph_def)

    for step in xrange(sess.run(global_step), MAX_STEPS):
      start_time = time.time()
      _, loss_value = sess.run([train_op, loss])
      duration = time.time() - start_time

      assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

      if step % 10 == 0:
        num_examples_per_step = BATCH_SIZE
        examples_per_sec = num_examples_per_step / duration
        sec_per_batch = float(duration)

        format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                      'sec/batch)')
        print (format_str % (datetime.now(), step, loss_value,
                             examples_per_sec, sec_per_batch))

      if step % 10 == 0:
        summary_str = sess.run(summary_op)
        summary_writer.add_summary(summary_str, step)

      # Save the model checkpoint periodically.
      if step % 1000 == 0 or (step + 1) == MAX_STEPS:
        checkpoint_path = os.path.join(TRAIN_DIR, 'model.ckpt')
        saver.save(sess, checkpoint_path, global_step=step)
Example 9
import tensorflow as tf
import cnn_input as cnn_input
import cnn as cnn
import time

image, label = cnn_input.generate_image_and_label()
images, labels = cnn_input.generate_images_and_labels_batch(image=image,
                                                            label=label,
                                                            shuffle=True)
# Logits computed by the network.
logits = cnn.inference(images)
loss = cnn.loss(logits, labels)  # Mean cross-entropy over the batch.
train_step = tf.train.AdamOptimizer(1e-4).minimize(loss)  # Adam optimizer step.
correct_predict = tf.equal(tf.argmax(logits, 1), tf.argmax(labels, 1))
accuracy = tf.reduce_mean(tf.cast(correct_predict, tf.float32))  # Accuracy on the current training batch.
config = tf.ConfigProto()
config.gpu_options.allow_growth = True

t1 = time.time()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    for i in range(10000):
        if i % 100 == 0:
            acc = sess.run(accuracy)
            print('step: %d, acc: %f' % (i, acc))
        sess.run(train_step)
    coord.request_stop()
    coord.join(threads)
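Note that every sess.run call in this pipeline dequeues a fresh batch from the queue, so the accuracy above is measured on a different batch than the one just trained on. A minimal variation that fetches both from the same batch (a sketch, not the original code):

    for i in range(10000):
        # One run call: the train step and the accuracy see the same batch.
        _, acc = sess.run([train_step, accuracy])
        if i % 100 == 0:
            print('step: %d, acc: %f' % (i, acc))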
Example 10
def train():
  with tf.Graph().as_default():
    
    log('===== START TRAIN RUN: ' + str(datetime.now()) + '=====')
    
    global_step = tf.Variable(0, trainable=False)
    
    # get examples and labels
    examples, labels = cnn.inputs(data_type='train')

    # build graph to compute logits
    logits = cnn.inference(examples)

    # compute loss
    loss, losses_collection = cnn.loss(logits, labels)
    accuracy = cnn.accuracy(logits, labels)

    # train model with one batch of examples
    train_op = cnn.train(loss, global_step)

    # create saver
    saver = tf.train.Saver(tf.all_variables())
  
    # build summary and init op
    summary_op = tf.merge_all_summaries()
    init_op = tf.initialize_all_variables()

    # start session
    # sess = tf.Session(config=tf.ConfigProto(log_device_placement=True))
    sess = tf.Session()
    sess.run(init_op)
    
    # start queue runners
    tf.train.start_queue_runners(sess=sess)

    # set up summary writers
    train_writer = tf.train.SummaryWriter(config.train_dir, sess.graph)
    
    for step in xrange(config.max_steps):
      
      start_time = time.time()
      summary, loss_value, accuracy_value, _ = sess.run([summary_op, loss, accuracy, train_op])

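      # Note: each entry in loss_breakdown is fetched with its own sess.run
      # call on every step, even though it is only logged with the summaries.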
      loss_breakdown = [(str(l.op.name), sess.run(l)) for l in losses_collection]
        
      duration = time.time() - start_time

      assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

      if step % config.summary_every_n_steps == 0: # summaries
        
        examples_per_sec = config.batch_size / duration
        sec_per_batch = float(duration)
        
        train_writer.add_summary(summary, step)

        log_str_1 = ('%s: step %d, loss = %.3f (%.2f examples/sec; '
                     '%.3f sec/batch), accuracy %.3f   ') % (
                         datetime.now(), step, loss_value, examples_per_sec,
                         sec_per_batch, accuracy_value)
        log_str_1 += str(loss_breakdown)  # print loss breakdown
        log(log_str_1)

        log("memory usage: {} Mb".format(float(resource.getrusage(resource.RUSAGE_SELF).ru_maxrss)/1000000.0))
        

      if (step % config.ckpt_every_n_steps == 0) and (step>0): # save weights to file & validate
        checkpoint_path = os.path.join(config.checkpoint_dir, 'model.ckpt')
        saver.save(sess, checkpoint_path, global_step=step)
        log("Checkpoint saved at step %d" % step)