Example #1
def inputs(source_data):
    """Construct input for SVHN training and evaluation using the Reader ops.
  Args:
    source_data: bool, indicating if one should use the train or eval data set.
  Returns:
    images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
    lengths: Length Labels. 1D tensor of [batch_size] size.
    digits: Digit Labels. 2D tensor of [batch_size, 5] size.
  Raises:
    ValueError: If no data_dir
  """
    if not FLAGS.train_tf_records_file or not FLAGS.extra_tf_records_file or not FLAGS.test_tf_records_file:
        raise ValueError('Please supply a tf_records_file')
    if source_data:
        images, lengths, digits = data_loader.inputs(
            filenames=[
                FLAGS.train_tf_records_file, FLAGS.extra_tf_records_file
            ],
            batch_size=FLAGS.batch_size,
            num_examples_per_epoch=NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN)
    else:
        images, lengths, digits = data_loader.inputs(
            filenames=[FLAGS.test_tf_records_file],
            batch_size=FLAGS.batch_size,
            num_examples_per_epoch=NUM_EXAMPLES_PER_EPOCH_FOR_EVAL,
            shuffle=False)
    return images, lengths, digits
Example #2
def main(_):
    # images, labels = distorted_inputs()
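    # Build shuffled input pipelines for the training and test splits; with
    # num_epochs=None they repeat indefinitely, so train() is expected to
    # control when to stop.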
    train_images, train_labels = inputs(batch_size=FLAGS.batch_size,
                                        train=True,
                                        shuffle=True,
                                        num_epochs=None)
    test_images, test_labels = inputs(batch_size=FLAGS.batch_size,
                                      train=False,
                                      shuffle=True,
                                      num_epochs=None)
    train(train_images, train_labels, test_images, test_labels)
Example #3
def train(tfrecord_file, train_dir, batch_size, num_epochs):
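    # Read shuffled training batches of vectors and labels from the TFRecord file.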
    _, vectors, labels = data_loader.inputs([tfrecord_file],
                                            batch_size=batch_size,
                                            num_threads=16,
                                            capacity=batch_size * 4,
                                            min_after_dequeue=batch_size * 2,
                                            num_epochs=num_epochs,
                                            is_training=True)

    loss = model.loss(vectors, labels)

    global_step = tf.Variable(0, name='global_step', trainable=False)

    # Create training op with dependencies on update ops for batch norm
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = tf.train.AdamOptimizer(learning_rate=0.001).minimize(
            loss, global_step=global_step)

    # Create training supervisor to manage model logging and saving
    sv = tf.train.Supervisor(logdir=train_dir,
                             global_step=global_step,
                             save_summaries_secs=60,
                             save_model_secs=600)

    with sv.managed_session() as sess:
        while not sv.should_stop():
            _, loss_out, step_out = sess.run([train_op, loss, global_step])

            if step_out % 100 == 0:
                print('Step {}: Loss {}'.format(step_out, loss_out))
Example #4
def evaluate(): 
  """Evaluate Office 31 for the entire webcam dataset"""
  with tf.Graph().as_default() as g:
    # Get images and labels for Office 31.
    images, labels = data_loader.inputs(FLAGS.tf_records_file,
                                      FLAGS.batch_size,
                                      NUM_EXAMPLES_PER_EPOCH_FOR_EVAL,
                                      False)

    # Build a Graph that computes the logits predictions from the inference model.
    # Initialize model
    model = AlexNet(images, keep_prob, NUM_CLASSES, train_layers, 'source/')

    # Link variable to model output
    logits = model.fc8 
    labels = tf.reshape(labels, [BATCH_SIZE])

    # Calculate predictions.
    top_k_op = tf.nn.in_top_k(logits, labels, 1)
    
    variable_averages = tf.train.ExponentialMovingAverage(
        0.9999)
    variables_to_restore = variable_averages.variables_to_restore()
    saver = tf.train.Saver(variables_to_restore)
    
    # Build the summary operation based on the TF collection of Summaries.
    summary_op = tf.summary.merge_all()
    dirname = os.path.dirname(__file__)
    dest_directory = os.path.join(dirname, 'office31_data/eval')
    summary_writer = tf.summary.FileWriter(dest_directory, g)

    eval_once(saver, summary_writer, top_k_op, summary_op)
Example #5
def evaluate():
    """Evaluate Office 31 for the entire webcam dataset"""
    with tf.Graph().as_default() as g:
        # Get images and labels for Office 31.
        images, labels = data_loader.inputs(FLAGS.tf_records_file,
                                            FLAGS.batch_size,
                                            NUM_EXAMPLES_PER_EPOCH_FOR_EVAL,
                                            False)

        # Build a Graph that computes the logits predictions from the inference model.
        # Initialize model
        model = AlexNet(images, keep_prob, NUM_CLASSES, train_layers,
                        'source/')

        # Link variable to model output
        logits = model.fc8
        labels = tf.reshape(labels, [BATCH_SIZE])

        # Calculate predictions.
        top_k_op = tf.nn.in_top_k(logits, labels, 1)

        variable_averages = tf.train.ExponentialMovingAverage(0.9999)
        variables_to_restore = variable_averages.variables_to_restore()
        saver = tf.train.Saver(variables_to_restore)

        # Build the summary operation based on the TF collection of Summaries.
        summary_op = tf.summary.merge_all()
        dirname = os.path.dirname(__file__)
        dest_directory = os.path.join(dirname, 'office31_data/eval')
        summary_writer = tf.summary.FileWriter(dest_directory, g)

        eval_once(saver, summary_writer, top_k_op, summary_op)
Example #6
def inputs(source_data):
  """Construct input for Office31 training and evaluation using the Reader ops.
  Args:
    source_data: bool, indicating if one should use the train or eval data set.
  Returns:
    images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
    labels: Labels. 1D tensor of [batch_size] size.
  Raises:
    ValueError: If no tf_records_file is supplied.
  """
  if not FLAGS.source_tf_records_file or not FLAGS.target_tf_records_file:
    raise ValueError('Please supply a tf_records_file')
  if source_data:
    images, labels = data_loader.inputs(filename=FLAGS.source_tf_records_file,
                                        batch_size=FLAGS.source_batch_size,
                                        num_examples_per_epoch=NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN)
  else:
    images, labels = data_loader.inputs(filename=FLAGS.target_tf_records_file,
                                        batch_size=FLAGS.target_batch_size,
                                        num_examples_per_epoch=NUM_EXAMPLES_PER_EPOCH_FOR_EVAL)
  return images, labels
Example #7
def inputs(source_data):
    """Construct input for Office31 training and evaluation using the Reader ops.
  Args:
    source_data: bool, indicating if one should use the train or eval data set.
  Returns:
    images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
    labels: Labels. 1D tensor of [batch_size] size.
  Raises:
    ValueError: If no data_dir
  """
    if not FLAGS.source_tf_records_file or not FLAGS.target_tf_records_file:
        raise ValueError('Please supply a tf_records_file')
    if source_data:
        images, labels = data_loader.inputs(
            filename=FLAGS.source_tf_records_file,
            batch_size=FLAGS.source_batch_size,
            num_examples_per_epoch=NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN)
    else:
        images, labels = data_loader.inputs(
            filename=FLAGS.target_tf_records_file,
            batch_size=FLAGS.target_batch_size,
            num_examples_per_epoch=NUM_EXAMPLES_PER_EPOCH_FOR_EVAL)
    return images, labels
Example #8
def generate_predictions(tfrecord_file,
                         train_dir,
                         predictions_file,
                         features_file,
                         batch_size,
                         num_k):
    ids, vectors, _ = data_loader.inputs([tfrecord_file], batch_size=batch_size,
                                         num_threads=16, capacity=batch_size*4,
                                         num_epochs=1, is_training=False)

    predictions = model.inference(vectors)
    features = tf.get_default_graph().get_tensor_by_name('fc1/relu:0')

    init_op = tf.local_variables_initializer()
    saver = tf.train.Saver()

    with tf.Session() as sess:
        sess.run(init_op)
        saver.restore(sess, tf.train.latest_checkpoint(train_dir))

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        with open(predictions_file, 'w') as f1, open(features_file, 'w') as f2:
            f1.write('VideoId,LabelConfidencePairs\n')

            while True:
                try:
                    ids_out, predictions_out, features_out = sess.run(
                        [ids, predictions, features])
                except tf.errors.OutOfRangeError:
                    break

                for i, _ in enumerate(ids_out):
                    f1.write(ids_out[i].decode())
                    f1.write(',')
                    top_k = np.argsort(predictions_out[i])[::-1][:num_k]
                    for j in top_k:
                        f1.write('{} {:.5f} '.format(j, predictions_out[i][j]))
                    f1.write('\n')

                    f2.write(ids_out[i].decode())
                    f2.write(',')
                    for j in range(len(features_out[i]) - 1):
                        f2.write('{:.6e},'.format(features_out[i][j]))
                    f2.write('{:.6e}'.format(features_out[i][-1]))
                    f2.write('\n')

        coord.request_stop()
        coord.join(threads)
Example #9
def test_input_full(config, seqconfig):
    train_data = os.path.join(config.tfrecord_dir, config.train_tfrecords)
    val_data = os.path.join(config.tfrecord_dir, config.test_tfrecords)
    with tf.device('/cpu:0'):
        train_images, train_labels = inputs(
            tfrecord_file=train_data,
            num_epochs=config.epochs,
            image_target_size=config.image_orig_size,
            label_shape=config.num_classes,
            batch_size=config.train_batch,
            data_augment=False)
        val_images, val_labels = inputs(
            tfrecord_file=val_data,
            num_epochs=config.epochs,
            image_target_size=config.image_orig_size,
            label_shape=config.num_classes,
            batch_size=1)
        label_shaped = tf.reshape(
            train_labels, [config.train_batch, config.num_classes // 3, 3])
        error = getMeanError(label_shaped, label_shaped)
        val_label_shaped = tf.reshape(
            val_labels, [1, config.num_classes // 3, 3])
    with tf.Session() as sess:
        sess.run(
            tf.group(tf.global_variables_initializer(),
                     tf.local_variables_initializer()))
        step = 0
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
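        # Pull training and validation batches until the input queues hit the
        # epoch limit and raise OutOfRangeError.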
        try:
            while not coord.should_stop():
                image_np, image_label, train_error = sess.run(
                    [train_images, train_labels, error])
                print(step)
                val_image_np, val_image_label, val_image_reshaped_label = sess.run(
                    [val_images, val_labels, val_label_shaped])

                if (step > 0):  #and (step <2):

                    # for b in range(config.train_batch):
                    #     im = image_np[b]
                    #     image_com = image_coms[b]
                    #     image_M = image_Ms[b]
                    #
                    #     jts = image_label[b]
                    #     print("shape of jts:{}".format(jts.shape))
                    #     im = im.reshape([128,128])
                    #     check_image_label(im,jts,image_com,image_M,seqconfig['cube'][2] / 2.,allJoints=True,line=False)

                    val_im = val_image_np[0]
                    print("val_im shape:{}".format(val_im.shape))
                    val_jts = val_image_reshaped_label[0]
                    #val_im = val_im.reshape([128, 128])
                    #check_image_label(val_im, val_jts, val_image_com, val_image_M, seqconfig['cube'][2] / 2.,allJoints=True,line=False)
                    show_image(val_im, val_jts)
                step += 1

        except tf.errors.OutOfRangeError:
            print("Done. Epoch limit reached.")
        finally:
            coord.request_stop()
        coord.join(threads)
Example #10
def test_input_tree(config, seqconfig):
    train_data = os.path.join(config.tfrecord_dir, config.train_tfrecords)
    val_data = os.path.join(config.tfrecord_dir, config.val_tfrecords)
    with tf.device('/cpu:0'):
        train_images, train_labels, com3Ds, Ms = inputs(
            tfrecord_file=train_data,
            num_epochs=config.epochs,
            image_target_size=config.image_target_size,
            label_shape=config.num_classes,
            batch_size=config.train_batch)
        val_images, val_labels, val_com3Ds, val_Ms = inputs(tfrecord_file=val_data,
                                                            num_epochs=config.epochs,
                                                            image_target_size=config.image_target_size,
                                                            label_shape=config.num_classes,
                                                            batch_size=1)
        label_shaped = tf.reshape(
            train_labels, [config.train_batch, config.num_classes // 3, 3])
        split_lable = tf.split(label_shaped, 36, axis=1)
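        # Regroup the 36 joint labels into five overlapping subsets (P/R/M/I/T,
        # presumably one per finger), each also including the shared joints 29-35.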
        P_label_shaped = tf.concat(
            [split_lable[0], split_lable[1], split_lable[2], split_lable[3], split_lable[4], split_lable[5],
             split_lable[29], split_lable[30], split_lable[31], split_lable[32], split_lable[33], split_lable[34],
             split_lable[35]], axis=1)
        R_label_shaped = tf.concat(
            [split_lable[6], split_lable[7], split_lable[8], split_lable[9], split_lable[10], split_lable[11],
             split_lable[29], split_lable[30], split_lable[31], split_lable[32], split_lable[33], split_lable[34],
             split_lable[35]], axis=1)
        M_label_shaped = tf.concat(
            [split_lable[12], split_lable[13], split_lable[14], split_lable[15], split_lable[16], split_lable[17],
             split_lable[29], split_lable[30], split_lable[31], split_lable[32], split_lable[33], split_lable[34],
             split_lable[35]], axis=1)
        I_label_shaped = tf.concat(
            [split_lable[18], split_lable[19], split_lable[20], split_lable[21], split_lable[22], split_lable[23],
             split_lable[29], split_lable[30], split_lable[31], split_lable[32], split_lable[33], split_lable[34],
             split_lable[35]], axis=1)
        T_label_shaped = tf.concat(
            [split_lable[24], split_lable[25], split_lable[26], split_lable[27], split_lable[28],
             split_lable[29], split_lable[30], split_lable[31], split_lable[32], split_lable[33], split_lable[34],
             split_lable[35]], axis=1)
        P_label = tf.reshape(P_label_shaped, [config.train_batch, P_label_shaped.get_shape().as_list()[1] * 3])
        R_label = tf.reshape(R_label_shaped, [config.train_batch, R_label_shaped.get_shape().as_list()[1] * 3])
        M_label = tf.reshape(M_label_shaped, [config.train_batch, M_label_shaped.get_shape().as_list()[1] * 3])
        I_label = tf.reshape(I_label_shaped, [config.train_batch, I_label_shaped.get_shape().as_list()[1] * 3])
        T_label = tf.reshape(T_label_shaped, [config.train_batch, T_label_shaped.get_shape().as_list()[1] * 3])
        error = getMeanError(label_shaped,label_shaped)
        val_label_shaped = tf.reshape(val_labels, [1, config.num_classes // 3, 3])
        val_split_lable = tf.split(val_label_shaped, 36, axis=1)
        val_P_label_shaped = tf.concat(
            [val_split_lable[0], val_split_lable[1], val_split_lable[2], val_split_lable[3], val_split_lable[4],
             val_split_lable[5],
             val_split_lable[29], val_split_lable[30], val_split_lable[31], val_split_lable[32], val_split_lable[33],
             val_split_lable[34],
             val_split_lable[35]], axis=1)
        val_R_label_shaped = tf.concat(
            [val_split_lable[6], val_split_lable[7], val_split_lable[8], val_split_lable[9], val_split_lable[10],
             val_split_lable[11],
             val_split_lable[29], val_split_lable[30], val_split_lable[31], val_split_lable[32], val_split_lable[33],
             val_split_lable[34],
             val_split_lable[35]], axis=1)
        val_M_label_shaped = tf.concat(
            [val_split_lable[12], val_split_lable[13], val_split_lable[14], val_split_lable[15], val_split_lable[16],
             val_split_lable[17],
             val_split_lable[29], val_split_lable[30], val_split_lable[31], val_split_lable[32], val_split_lable[33],
             val_split_lable[34],
             val_split_lable[35]], axis=1)
        val_I_label_shaped = tf.concat(
            [val_split_lable[18], val_split_lable[19], val_split_lable[20], val_split_lable[21], val_split_lable[22],
             val_split_lable[23],
             val_split_lable[29], val_split_lable[30], val_split_lable[31], val_split_lable[32], val_split_lable[33],
             val_split_lable[34],
             val_split_lable[35]], axis=1)
        val_T_label_shaped = tf.concat(
            [val_split_lable[24], val_split_lable[25], val_split_lable[26], val_split_lable[27], val_split_lable[28],
             val_split_lable[29], val_split_lable[30], val_split_lable[31], val_split_lable[32], val_split_lable[33],
             val_split_lable[34],
             val_split_lable[35]], axis=1)
        val_P_label = tf.reshape(val_P_label_shaped, [config.val_batch, val_P_label_shaped.get_shape().as_list()[1] * 3])
        val_R_label = tf.reshape(val_R_label_shaped, [config.val_batch, val_R_label_shaped.get_shape().as_list()[1] * 3])
        val_M_label = tf.reshape(val_M_label_shaped, [config.val_batch, val_M_label_shaped.get_shape().as_list()[1] * 3])
        val_I_label = tf.reshape(val_I_label_shaped, [config.val_batch, val_I_label_shaped.get_shape().as_list()[1] * 3])
        val_T_label = tf.reshape(val_T_label_shaped, [config.val_batch, val_T_label_shaped.get_shape().as_list()[1] * 3])
    with tf.Session() as sess:
        sess.run(tf.group(tf.global_variables_initializer(),
                          tf.local_variables_initializer()))
        step = 0
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        try:
            while not coord.should_stop():
                image_np, image_label, image_coms, image_Ms, train_error, P, R, M, I, T = sess.run(
                    [train_images, train_labels, com3Ds, Ms, error,
                     P_label, R_label, M_label, I_label, T_label])
                print(step)
                print(image_np.shape)
                print(train_error)

                val_image_np, val_image_label, val_image_coms, val_image_Ms,val_P,val_R,val_M,val_I,val_T= sess.run(
                    [val_images, val_labels, val_com3Ds, val_Ms,val_P_label,val_R_label,val_M_label,val_I_label,val_T_label])

                #image = tf.split(image_np, 3, 3)[0]
                #print image.shape
                #print image_label.shape

                if (step > 0) and (step < 2):

                    for b in range(config.train_batch):
                        im = image_np[b]
                        print("im shape:{}".format(im.shape))
                        image_com = image_coms[b]
                        image_M = image_Ms[b]
                        #print("shape of im:{}".format(im.shape))
                        jts = image_label[b]
                        im = im.reshape([128,128])
                        check_image_label(im,jts,image_com,image_M,seqconfig['cube'][2] / 2.,allJoints=True,line=True)
                        check_image_label(im,P[b],image_com,image_M,seqconfig['cube'][2] / 2.,line=False)
                        check_image_label(im, R[b], image_com, image_M, seqconfig['cube'][2] / 2., line=False)
                        check_image_label(im, M[b], image_com, image_M, seqconfig['cube'][2] / 2., line=False)
                        check_image_label(im, I[b], image_com, image_M, seqconfig['cube'][2] / 2., line=False)
                        check_image_label(im, T[b], image_com, image_M, seqconfig['cube'][2] / 2., line=False)

                    val_im = val_image_np[0]
                    print("val_im shape:{}".format(val_im.shape))
                    val_image_com = val_image_coms[0]
                    val_image_M = val_image_Ms[0]
                    # print("shape of im:{}".format(im.shape))
                    val_jts = val_image_label[0]
                    val_im = val_im.reshape([128, 128])
                    check_image_label(val_im, val_jts, val_image_com, val_image_M, seqconfig['cube'][2] / 2.,allJoints=True,line=True)
                    check_image_label(val_im, val_P[0], val_image_com, val_image_M, seqconfig['cube'][2] / 2.,line=False)
                    check_image_label(val_im, val_R[0], val_image_com, val_image_M, seqconfig['cube'][2] / 2.,
                                      line=False)
                    check_image_label(val_im, val_M[0], val_image_com, val_image_M, seqconfig['cube'][2] / 2.,
                                      line=False)
                    check_image_label(val_im, val_I[0], val_image_com, val_image_M, seqconfig['cube'][2] / 2.,
                                      line=False)
                    check_image_label(val_im, val_T[0], val_image_com, val_image_M, seqconfig['cube'][2] / 2.,
                                      line=False)
                step += 1

        except tf.errors.OutOfRangeError:
            print("Done. Epoch limit reached.")
        finally:
            coord.request_stop()
        coord.join(threads)
Example #11
def train_model(config, seqconfig):
    md = tf_monkeydetector.tfMonkeyDetector(365.456, 365.456, 256, 212, [800, 800, 1200], 200, 10000)
    phaseII = False

    train_data = os.path.join(config.tfrecord_dir, config.train_tfrecords)
    val_data = os.path.join(config.tfrecord_dir, config.val_tfrecords)
    with tf.device('/cpu:0'):
        train_images, train_labels = inputs(
            tfrecord_file=train_data,
            num_epochs=config.epochs,
            image_target_size=config.image_orig_size,
            label_shape=config.num_classes,
            batch_size=config.train_batch,
            data_augment=False)
        val_images, val_labels = inputs(
            tfrecord_file=val_data,
            num_epochs=config.epochs,
            image_target_size=config.image_orig_size,
            label_shape=config.num_classes,
            batch_size=config.val_batch)

    with tf.device('/gpu:0'):
        with tf.variable_scope("cnn") as scope:
            # place holders for training and validation!
            crop_input = tf.placeholder(tf.float32,
                                        [None, config.image_target_size[0], config.image_target_size[1],
                                         config.image_target_size[2]], name='patch_placeholder')
            crop_labels = tf.placeholder(tf.float32,
                                        [None, config.num_joints*config.num_dims], name='labels_placeholder')

            val_crop_input = tf.placeholder(tf.float32,
                                        [None, config.image_target_size[0], config.image_target_size[1],
                                         config.image_target_size[2]], name='patch_placeholder')
            val_crop_labels = tf.placeholder(tf.float32,
                                         [None, config.num_joints * config.num_dims], name='labels_placeholder')

            print("create training graphs:")
            ##########
            # PHASE I: attention component to locate the center of mass of the monkey
            ##########
            # build the model
            attn_model = attn_model_struct()
            train_images = train_images / config.image_max_depth
            attn_model.build(train_images,config.num_dims,train_mode=True)
            # calculate the 3d center of mass and project to image coordinates and normalize!
            train_labels_reshaped = tf.reshape(
                train_labels, [config.train_batch, config.num_classes // 3, 3])
            com2d = md.calculateCoMfrom3DJoints(train_labels_reshaped) / [config.image_orig_size[0],config.image_orig_size[1],config.image_max_depth]
            # define the loss and optimization functions
            attn_loss = tf.nn.l2_loss(attn_model.out_put - com2d)
            attn_train_op = tf.train.AdamOptimizer(1e-4).minimize(attn_loss)
            # if config.wd_penalty is None:
            #     attn_train_op = tf.train.AdamOptimizer(1e-4).minimize(attn_loss)
            # else:
            #     attn_wd_l = [v for v in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) if 'biases' not in v.name]
            #     attn_loss_wd=attn_loss+(config.wd_penalty * tf.add_n([tf.nn.l2_loss(x) for x in attn_wd_l]))
            #     attn_train_op = tf.train.AdamOptimizer(1e-4).minimize(attn_loss_wd)

            attn_train_results_shaped = tf.reshape(
                attn_model.out_put, [config.train_batch, config.num_dims])
            attn_train_error = calc_com_error(com2d, attn_train_results_shaped)

            ##########
            # PHASE II: intrinsic pose estimation component
            ##########
            # build the second model
            model = cnn_model_struct()
            model.build(crop_input, config.num_classes, train_mode=True)
            loss = tf.nn.l2_loss(model.out_put - crop_labels)
            train_op = tf.train.AdamOptimizer(1e-4).minimize(loss)
            # if config.wd_penalty is None:
            #     train_op = tf.train.AdamOptimizer(1e-4).minimize(loss)
            # else:
            #     wd_l = [v for v in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) if 'biases' not in v.name]
            #     loss_wd=loss+(config.wd_penalty * tf.add_n([tf.nn.l2_loss(x) for x in wd_l]))
            #     train_op = tf.train.AdamOptimizer(1e-4).minimize(loss_wd)

            crop_labels_shaped = tf.reshape(
                crop_labels, [config.train_batch, config.num_classes // 3, 3]) * seqconfig['cube'][2] / 2.
            train_results_shaped = tf.reshape(
                model.out_put, [config.train_batch, config.num_classes // 3, 3]) * seqconfig['cube'][2] / 2.
            train_error = getMeanError_train(crop_labels_shaped, train_results_shaped)

            print('using validation!')
            scope.reuse_variables()
            attn_val_model = attn_model_struct()
            val_images = val_images / config.image_max_depth
            attn_val_model.build(val_images, config.num_dims, train_mode=False)
            attn_val_labels_shaped = tf.reshape(
                val_labels, [config.val_batch, config.num_classes // 3, 3])
            #attn_val_results_shaped = tf.reshape(attn_val_model.out_put,[config.val_batch,config.num_dims])*[config.image_orig_size[0],config.image_orig_size[1],config.image_max_depth]
            attn_val_results_shaped = tf.reshape(attn_val_model.out_put, [config.val_batch, config.num_dims])
            val_com2d = md.calculateCoMfrom3DJoints(attn_val_labels_shaped) / [config.image_orig_size[0],config.image_orig_size[1],config.image_max_depth]
            attn_val_error = calc_com_error(val_com2d, attn_val_results_shaped)

            val_model = cnn_model_struct()
            val_model.build(val_crop_input, config.num_classes, train_mode=False)
            val_crop_labels_shaped = tf.reshape(
                val_crop_labels, [config.val_batch, config.num_classes // 3, 3]) * seqconfig['cube'][2] / 2.
            val_results_shaped = tf.reshape(
                val_model.out_put, [config.val_batch, config.num_classes // 3, 3]) * seqconfig['cube'][2] / 2.
            val_error = getMeanError_train(val_crop_labels_shaped, val_results_shaped)

            tf.summary.scalar("attention_loss", attn_loss)
            tf.summary.scalar("train error", attn_train_error)
            tf.summary.scalar("validation error", attn_val_error)

            if phaseII:
                tf.summary.scalar("pose_loss", loss)
                #if config.wd_penalty is not None:
                #    tf.summary.scalar("pose_loss_wd", loss_wd)
                tf.summary.scalar("pose_train error", train_error)
                tf.summary.scalar("pose_validation error", val_error)

            summary_op = tf.summary.merge_all()

        saver = tf.train.Saver(tf.global_variables())

    # Initialize the graph
    gpuconfig = tf.ConfigProto()
    gpuconfig.gpu_options.allow_growth = True
    gpuconfig.allow_soft_placement = True
    first_v=True
    with tf.Session(config=gpuconfig) as sess:
        summary_writer = tf.summary.FileWriter(config.train_summaries, sess.graph)
        sess.run(tf.group(tf.global_variables_initializer(),
                          tf.local_variables_initializer()))
        step = 0
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        try:
            while not coord.should_stop():
                if step < config.num_attn_steps:
                    _, image_np, image_label, tr_error, tr_loss, tr_res, tr_com = sess.run(
                        [attn_train_op, train_images, train_labels, attn_train_error, attn_loss, attn_train_results_shaped, com2d])
                    print("step={},loss={},error={}".format(step, tr_loss, tr_error))

                    if step % 250 ==0:
                        val_image_np, val_image_label, v_error, v_com, v_res= sess.run([val_images, val_labels, attn_val_error, val_com2d, attn_val_results_shaped])
                        print("     val error={}".format(v_error))
                        summary_str=sess.run(summary_op)
                        summary_writer.add_summary(summary_str,step)
                else:
                    phaseII = True
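                    # Phase II: crop patches around the attention net's centre-of-mass
                    # estimate and train the pose-regression CNN on them.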
                    # get results from phase one for the center of mass
                    image_np, image_label_shaped, tr_error, tr_loss, tr_res, tr_com = sess.run(
                        [train_images, train_labels_reshaped, attn_train_error, attn_loss,
                         attn_train_results_shaped, com2d])

                    patches, rel_labels = prepare_data(image_np,image_label_shaped,tr_res,md,config)
                    #import ipdb; ipdb.set_trace()
                    # _, tr_error, tr_loss, tr_loss_wd = sess.run(
                    #     [train_op, train_error, loss, loss_wd],feed_dict={crop_input:patches,crop_labels:rel_labels})
                    _, tr_error, tr_loss = sess.run(
                         [train_op, train_error, loss],feed_dict={crop_input:patches,crop_labels:rel_labels})
                    print("step={},loss={},error={} mm".format(step, tr_loss, tr_error))

                    if step % 1000 == 0:
                        # run validation
                        val_image_np, val_image_label_shaped, _, val_res, _ = sess.run(
                            [val_images, attn_val_labels_shaped, attn_val_error,
                             attn_val_results_shaped, val_com2d])

                        patches, rel_labels = prepare_data(val_image_np, val_image_label_shaped, val_res, md, config, show=True)
                        v_error = sess.run(
                            val_error,
                            feed_dict={val_crop_input: patches, val_crop_labels: rel_labels})
                        print("step={},error={} mm".format(step, v_error))

                        # save the model checkpoint if it's the best yet
                        if first_v is True:
                            val_min = v_error
                            first_v = False
                        else:
                            if v_error < val_min:
                                print(os.path.join(
                                    config.model_output,
                                    'attn_cnn_model' + str(step) +'.ckpt'))
                                saver.save(sess, os.path.join(
                                    config.model_output,
                                    'attn_cnn_model' + str(step) +'.ckpt'), global_step=step)
                                # store the new max validation accuracy
                                val_min = v_error

                        summary_str = sess.run(summary_op)
                        summary_writer.add_summary(summary_str, step)

                step += 1
        except tf.errors.OutOfRangeError:
            print("Done. Epoch limit reached.")
        finally:
            coord.request_stop()
        coord.join(threads)
Example #12
def test_model(config, seqconfig):
    #test_data = os.path.join(config.tfrecord_dir, config.test_tfrecords)
    test_data = os.path.join(config.tfrecord_dir, config.train_tfrecords)
    
    md = tf_monkeydetector.tfMonkeyDetector(365.456, 365.456, 256, 212, [800, 800, 1200], 200, 10000)
    print(test_data)
    with tf.device('/cpu:0'):
        images, labels = inputs(
            tfrecord_file=test_data,
            num_epochs=None,
            image_target_size=config.image_orig_size,
            label_shape=config.num_classes,
            batch_size=config.test_batch)

    with tf.device('/gpu:0'):
        with tf.variable_scope("cnn") as scope:

            crop_input = tf.placeholder(tf.float32,
                                        [None, config.image_target_size[0], config.image_target_size[1],
                                         config.image_target_size[2]], name='patch_placeholder')
            crop_labels = tf.placeholder(tf.float32,
                                         [None, config.num_joints * config.num_dims], name='labels_placeholder')

            attn_model = attn_model_struct()
            images = images / config.image_max_depth
            attn_model.build(images, config.num_dims, train_mode=False)
            attn_results_shaped = tf.reshape(attn_model.out_put, [1, config.num_dims])

            model = cnn_model_struct()
            model.build(crop_input, config.num_classes, train_mode=False)
            labels_shaped = tf.reshape(
                labels, [config.num_classes // 3, 3]) * seqconfig['cube'][2] / 2.
            results_shaped = tf.reshape(
                model.out_put, [config.num_classes // 3, 3]) * seqconfig['cube'][2] / 2.
            #error = getMeanError(labels_shaped, results_shaped)

        # Initialize the graph
        gpuconfig = tf.ConfigProto()
        gpuconfig.gpu_options.allow_growth = True
        gpuconfig.allow_soft_placement = True
        saver = tf.train.Saver()

        with tf.Session(config=gpuconfig) as sess:
            sess.run(tf.group(tf.global_variables_initializer(),
                              tf.local_variables_initializer()))
            step = 0
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord, sess=sess)

            checkpoints = tf.train.latest_checkpoint(config.model_output)
            saver.restore(sess, checkpoints)

            try:
                while not coord.should_stop():
                    #import ipdb; ipdb.set_trace()
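                    # Stage 1: predict the centre of mass with the attention net and
                    # crop patches around it; stage 2: regress the pose from the
                    # patches and plot the absolute joint positions.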
                    images_np, results_com = sess.run([images, attn_results_shaped])
                    patches, coms, Ms = prepare_data_test(images_np, results_com, md, config)

                    t_res = sess.run([results_shaped], feed_dict={crop_input: patches})
                    #print("step={}, test error={} mm".format(step,joint_error))
                    print("step={}".format(step))

                    #if step%100 ==0:
                    result_name="{}image_{}.png".format(config.results_dir,step)
                    retrieved_jnts_xyz, retrieved_jnts_uvd = md.getAbsoluteCoordinates(t_res[0], coms[0])
                    plt.imshow(images_np[0].squeeze())
                    plt.scatter(retrieved_jnts_uvd[:, 0], retrieved_jnts_uvd[:, 1], c='r')
                    plt.savefig(result_name)
                    plt.show()

                    step+=1
            except tf.errors.OutOfRangeError:
                print("Done.Epoch limit reached.")
            finally:
                coord.request_stop()
            coord.join(threads)
            print("load model from {}".format(checkpoints))