def tower_loss(scope, images, labels, phase_train_placeholder, args,
               prelogits_norm_p=1.0, prelogits_norm_loss_factor=5e-5):
    """Build the loss, embedding and accuracy ops for one training tower.

    Args:
        scope: Name-scope prefix used to filter the ``'losses'`` collection so
            only this tower's entries are summed (presumably a string such as
            ``'tower_0/'`` -- confirm against the caller).
        images: Input batch passed to ``inference``.
        labels: Class labels; cast to int64 for the accuracy comparison.
        phase_train_placeholder: Passed to ``inference`` as ``phase_train``.
        args: Parsed options; this function reads ``args.weight_decay`` and
            ``args.num_output``.
        prelogits_norm_p: ``ord`` of the norm used for the prelogits norm
            penalty. Defaults to 1.0, the value previously hard-coded.
        prelogits_norm_loss_factor: Multiplier applied to the prelogits norm
            before it is added to ``REGULARIZATION_LOSSES``. Defaults to 5e-5,
            the value previously hard-coded.

    Returns:
        ``(total_loss, embeddings, accuracy_op)``:
        ``total_loss`` is the sum of this tower's ``'losses'`` entries and every
        tensor in ``REGULARIZATION_LOSSES``. ``embeddings`` is the network
        output L2-normalized along axis 1. ``accuracy_op`` is the mean of
        correct top-1 predictions in the batch.
    """
    # The second value returned by inference (an endpoints mapping) is unused.
    prelogits, _ = inference(images, phase_train=phase_train_placeholder,
                             weight_decay=args.weight_decay)

    embeddings = tf.nn.l2_normalize(prelogits, 1, 1e-10, name='embeddings')

    # Norm penalty on the prelogits. eps shifts every |x| away from zero before
    # the norm is taken. NOTE(review): presumably to keep gradients finite for
    # small norm orders -- confirm.
    eps = 1e-4
    prelogits_norm = tf.reduce_mean(
        tf.norm(tf.abs(prelogits) + eps, ord=prelogits_norm_p, axis=1))
    tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES,
                         prelogits_norm * prelogits_norm_loss_factor)

    # cos_loss yields the classification loss plus per-class scores.
    # NOTE(review): presumably a large-margin cosine (CosFace-style) loss with
    # scores shaped (batch, num_output) -- confirm in cos_loss.
    inference_loss, logit = cos_loss(prelogits, labels, args.num_output)

    # Top-1 accuracy. Softmax is monotonic, so argmax over the raw scores gives
    # the same prediction without building an extra softmax op.
    correct_prediction = tf.cast(
        tf.equal(tf.argmax(logit, 1), tf.cast(labels, tf.int64)), tf.float32)
    accuracy_op = tf.reduce_mean(correct_prediction)

    tf.add_to_collection('losses', inference_loss)

    # Total loss = this tower's 'losses' entries + all regularization losses.
    # NOTE(review): REGULARIZATION_LOSSES is read without a scope filter. If
    # this function is called once per tower, the prelogits penalties added by
    # earlier calls are summed here as well -- confirm this is intended.
    regularization_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
    losses = tf.get_collection('losses', scope)
    total_loss = tf.add_n(losses + regularization_losses, name='total_loss')

    # Scalar summaries are placed on the CPU. The 'tower_N/' prefix is
    # stripped from each op name to form the summary tag.
    with tf.device('/cpu:0'):
        for loss in losses + [total_loss]:
            loss_name = re.sub(r'tower_[0-9]*/', '', loss.op.name)
            tf.summary.scalar(loss_name, loss)
    return total_loss, embeddings, accuracy_op
示例#2
0
                info = '{}:{}\n'.format(key,
                                        net_points[key].get_shape().as_list())
                hd.write(info)
            hd.close()

        embeddings = tf.nn.l2_normalize(prelogits, 1, 1e-10, name='embeddings')

        # Norm for the prelogits
        eps = 1e-4
        prelogits_norm = tf.reduce_mean(
            tf.norm(tf.abs(prelogits) + eps, ord=args.prelogits_norm_p,
                    axis=1))
        tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES,
                             prelogits_norm * args.prelogits_norm_loss_factor)

        inference_loss, logit = cos_loss(prelogits, labels, args.num_output)
        # inference_loss, logit = combine_loss(prelogits, labels, args.num_output)

        tf.add_to_collection('losses', inference_loss)

        # total losses
        regularization_losses = tf.get_collection(
            tf.GraphKeys.REGULARIZATION_LOSSES)
        total_loss = tf.add_n([inference_loss] + regularization_losses,
                              name='total_loss')

        # define the learning rate schedule
        learning_rate = tf.train.piecewise_constant(
            epoch,
            boundaries=args.lr_schedule,
            values=[0.1, 0.01, 0.001, 0.0001, 0.00001],