Exemplo n.º 1
0
def main(args):
    """Train CIFARNet on the CIFAR-10 ``train`` split with TF-Slim.

    Builds the input pipeline, the network, a softmax cross-entropy loss and an
    Adam train op, then hands control to ``slim.learning.train``, which writes
    checkpoints and summaries to ``FLAGS.log_dir``.

    Args:
        args: Positional command-line arguments; unused (all configuration is
            read from ``FLAGS``).
    """
    # Load the dataset description and a training batch from it.
    dataset = cifar10.get_split('train', FLAGS.data_dir)
    images, labels = load_batch(dataset, FLAGS.batch_size, is_training=True)

    # Take the class count from the dataset rather than hard-coding it, so the
    # network head and the one-hot labels below always agree.
    network_fn = nets_factory.get_network_fn("cifarnet",
                                             num_classes=dataset.num_classes,
                                             is_training=True)
    # network_fn returns (logits, end_points); only the logits are needed.
    predictions, _ = network_fn(images)

    # Histogram summaries for every model variable, plus a size report.
    variables = slim.get_model_variables()
    for var in variables:
        tf.summary.histogram(var.op.name, var)
    slim.model_analyzer.analyze_vars(variables, print_info=True)

    # Cross-entropy loss registered in the losses collection; the total loss
    # is read back from that collection.
    one_hot_labels = slim.one_hot_encoding(labels, dataset.num_classes)
    tf.losses.softmax_cross_entropy(one_hot_labels, predictions)
    total_loss = tf.losses.get_total_loss()
    tf.summary.scalar('loss', total_loss)

    # Optimize with Adam (learning rate 0.001).
    optimizer = tf.train.AdamOptimizer(0.001)
    train_op = slim.learning.create_train_op(total_loss, optimizer)

    # Run training: summaries every 20 s, checkpoints every 2 minutes.
    slim.learning.train(train_op,
                        FLAGS.log_dir,
                        save_summaries_secs=20,
                        save_interval_secs=60 * 2)
Exemplo n.º 2
0
def show_img():
    """Display the first image of the CIFAR-10 ``kaggle`` split.

    Reads one unshuffled example from ``/tmp/cifar10``, prints its shape,
    dtype, pixel values and label, and renders it with matplotlib.
    """
    import numpy as np  # local import: only needed to drop the batch axis

    kaggle_test = "/tmp/cifar10"
    dataset = cifar10.get_split('kaggle', kaggle_test)
    data_provider = slim.dataset_data_provider.DatasetDataProvider(
        dataset, shuffle=False)
    image, label = data_provider.get(['image', 'label'])
    images, labels = tf.train.batch([image, label],
                                    batch_size=1,
                                    allow_smaller_final_batch=True)
    with tf.Session() as sess:
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        try:
            images_, labels_ = sess.run([images, labels])
            print(images_.shape)
            # Drop the size-1 batch axis on the host; building a tf.squeeze op
            # here would add a new node to the graph just to squeeze an array.
            images_ = np.squeeze(images_)
            print(images_.shape, images_.dtype)
            print(images_, labels_)
            plt.imshow(images_)
        finally:
            # Always stop and join the queue-runner threads, even if reading
            # fails, so the session can be closed cleanly.
            coord.request_stop()
            coord.join(threads)
    # imshow only draws into the figure; show() actually displays it.
    plt.show()
Exemplo n.º 3
0
    image_raw = tf.image.resize_images(image_raw, [height, width])
    image_raw = tf.squeeze(image_raw)

    # Batch it up.
    images, images_raw, labels = tf.train.batch(
          [image, image_raw, label],
          batch_size=batch_size,
          num_threads=1,
          capacity=2 * batch_size)
    
    return images, images_raw, labels

# Evaluation script: build Inception-v1 over the CIFAR-10 'test' split and
# prepare to restore the latest checkpoint from `train_dir`.
with tf.Graph().as_default():
    tf.logging.set_verbosity(tf.logging.INFO)
    
    dataset = cifar10.get_split('test', data_dir)
    # presumably the load_batch defined above, returning
    # (images, images_raw, labels) -- confirm.
    images, images_raw, labels = load_batch(dataset, height=image_size, width=image_size)
    
    # Create the model, use the default arg scope to configure the batch norm parameters.
    with slim.arg_scope(inception.inception_v1_arg_scope()):
        # NOTE(review): is_training=True while evaluating on the test split;
        # for evaluation this is normally False -- confirm intent.
        logits, _ = inception.inception_v1(images, num_classes=dataset.num_classes, is_training=True)

    # Per-class probabilities from the logits.
    probabilities = tf.nn.softmax(logits)
    
    # Build an init function that loads every restorable variable from the
    # newest checkpoint in train_dir.
    checkpoint_path = tf.train.latest_checkpoint(train_dir)
    init_fn = slim.assign_from_checkpoint_fn(
      checkpoint_path,
      slim.get_variables_to_restore())
    
    with tf.Session() as sess:
        # NOTE(review): the body of this block is missing from the snippet.
        with slim.queues.QueueRunners(sess):
Exemplo n.º 4
0
def main(args):
    """Evaluate the latest CIFARNet checkpoint once on the CIFAR-10 test split.

    Restores the newest checkpoint in ``FLAGS.checkpoint_dir`` and runs enough
    batches of ``FLAGS.batch_size`` to cover the whole ``test`` split. It
    writes ``eval/<metric>`` summaries to ``FLAGS.log_dir`` and prints the
    final value of each metric.

    Args:
        args: Positional command-line arguments; unused (all configuration is
            read from ``FLAGS``).

    Raises:
        ValueError: If ``FLAGS.checkpoint_dir`` contains no checkpoint.
    """
    tf.logging.set_verbosity(tf.logging.DEBUG)
    dataset = cifar10.get_split('test', FLAGS.data_dir)

    images, labels = load_batch(
        dataset,
        FLAGS.batch_size,
        is_training=False)
    print(images, labels)

    # Class count comes from the dataset so it cannot drift from the labels.
    network_fn = nets_factory.get_network_fn(
        "cifarnet", num_classes=dataset.num_classes, is_training=False)
    logits, _ = network_fn(images)
    # Collapse the per-class scores into a single predicted class id.
    predictions = tf.to_int64(tf.argmax(logits, 1))

    # Streaming metrics: each value op reads the running state that its
    # update op accumulates batch by batch.
    metrics_to_values, metrics_to_updates = metrics.aggregate_metric_map({
        # NOTE(review): MSE between class *ids* is not a meaningful error
        # measure for classification; kept for continuity with earlier runs.
        'mse': metrics.streaming_mean_squared_error(predictions, labels),
        'accuracy': metrics.streaming_accuracy(predictions, labels),
    })

    # Write the metrics as summaries.
    for metric_name, metric_value in metrics_to_values.items():
        tf.summary.scalar('eval/%s' % metric_name, metric_value)

    checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
    if checkpoint_path is None:
        raise ValueError('No checkpoint found in %s' % FLAGS.checkpoint_dir)

    # Enough batches to cover the whole split once. Use the split's own size
    # instead of a hard-coded count; evaluate_once expects an integer
    # (math.ceil returns a float on Python 2).
    num_batches = int(math.ceil(dataset.num_samples / float(FLAGS.batch_size)))

    # Fix the metric order once so printed names line up with returned values.
    metric_names = list(metrics_to_values.keys())
    metric_values = slim.evaluation.evaluate_once(
        master='',
        checkpoint_path=checkpoint_path,
        logdir=FLAGS.log_dir,
        num_evals=num_batches,
        eval_op=list(metrics_to_updates.values()),
        final_op=[metrics_to_values[name] for name in metric_names])
    for metric, value in zip(metric_names, metric_values):
        print("%s: %f" % (metric, value))
Exemplo n.º 5
0
def train_model(cifar10_data_dir,
                train_dir,
                checkpoints_dir,
                model_starting_ckpt,
                lr=0.0001,
                steps=200000,
                lower_lr_every_x_steps=50000):
    """Train an Inception-v1 model on CIFAR-10 with a step-wise LR schedule.

    Training runs in stages of ``lower_lr_every_x_steps`` global steps. Each
    stage rebuilds the graph and calls ``slim.learning.train`` with
    ``train_dir`` as its log directory. It trains up to the stage's absolute
    target step (never beyond ``steps``) and then divides the learning rate
    by 10 for the next stage. A validation loss on the ``test`` split is
    logged as a summary alongside the training loss.

    :param cifar10_data_dir: directory passed to ``cifar10.get_split`` for the
        ``train`` and ``test`` splits.
    :param train_dir: log directory for checkpoints and summaries.
    :param checkpoints_dir: passed to ``get_init_fn`` together with
        ``model_starting_ckpt`` to build the initialization function.
    :param model_starting_ckpt: passed to ``get_init_fn``.
    :param lr: learning rate of the first stage.
    :param steps: total number of global steps to train for.
    :param lower_lr_every_x_steps: stage length in steps; the learning rate is
        divided by 10 after each stage.
    :return: None.
    :raises ValueError: if ``lower_lr_every_x_steps`` is not positive.
    """
    if lower_lr_every_x_steps <= 0:
        raise ValueError('lower_lr_every_x_steps must be positive, got %r' %
                         (lower_lr_every_x_steps,))

    # Input resolution of the TF-Slim Inception-v1 definition (224).
    image_size = inception.inception_v1.default_image_size

    # Ceiling division, so a trailing partial stage is still trained instead
    # of being silently dropped when `steps` is not a multiple of the stage
    # length.
    loops = (steps + lower_lr_every_x_steps - 1) // lower_lr_every_x_steps

    for i in range(loops):

        # Absolute global step at which this stage stops; capped at `steps`.
        step_target = min((i + 1) * lower_lr_every_x_steps, steps)

        print("learning rate is " + str(lr) + " and step_target is " +
              str(step_target))

        with tf.Graph().as_default():
            tf.logging.set_verbosity(tf.logging.INFO)

            dataset = cifar10.get_split('train', cifar10_data_dir)
            images, _, labels = load_batch(dataset,
                                           height=image_size,
                                           width=image_size)

            val_dataset = cifar10.get_split('test', cifar10_data_dir)
            val_images, _, val_labels = load_batch(val_dataset,
                                                   height=image_size,
                                                   width=image_size)

            # Create the model, use the default arg scope to configure the
            # batch norm parameters. The validation tower reuses the training
            # weights but runs in inference mode.
            with slim.arg_scope(inception.inception_v1_arg_scope()):
                logits, _ = inception.inception_v1(
                    images, num_classes=dataset.num_classes, is_training=True)
                val_logits, _ = inception.inception_v1(
                    val_images,
                    num_classes=dataset.num_classes,
                    is_training=False,
                    reuse=True)

            # Training loss.
            one_hot_labels = slim.one_hot_encoding(labels, dataset.num_classes)
            slim.losses.softmax_cross_entropy(logits, one_hot_labels)
            total_loss = slim.losses.get_total_loss()

            # Validation loss goes into its own "validation" collection.
            # Presumably this keeps it out of the training total above;
            # confirm against the TF losses docs.
            val_one_hot_labels = slim.one_hot_encoding(val_labels,
                                                       dataset.num_classes)
            val_loss = tf.losses.softmax_cross_entropy(
                val_one_hot_labels, val_logits, loss_collection="validation")

            # Summaries to visualize the training process.
            tf.summary.scalar('losses/Total Loss', total_loss)
            tf.summary.scalar('validation/Validation_Loss', val_loss)

            # Optimizer and train op for this stage's learning rate.
            optimizer = tf.train.AdamOptimizer(learning_rate=lr)
            train_op = slim.learning.create_train_op(total_loss, optimizer)

            my_saver = tf.train.Saver(max_to_keep=50)

            # Run the training up to this stage's absolute step target.
            final_loss = slim.learning.train(train_op,
                                             logdir=train_dir,
                                             init_fn=get_init_fn(
                                                 checkpoints_dir,
                                                 model_starting_ckpt),
                                             number_of_steps=step_target,
                                             saver=my_saver,
                                             save_summaries_secs=60,
                                             save_interval_secs=1200)

        print('Finished training. Last batch loss %f' % final_loss)

        # Step decay for the next stage.
        lr = lr / 10
Exemplo n.º 6
0
def run_adaptive_training():
    """Train a classifier whose batches are chosen by an adaptive sampler.

    Builds a ResNet-v2-50 or Inception-v1 graph fed through placeholders. It
    keeps a non-trainable table (``EmbeddingList``) holding one pooled
    embedding per training sample and runs ``FLAGS.max_steps`` iterations.
    Batches come from ``adamb_data_loader``.

    With ``FLAGS.method == 'pairwise'``, half of each batch is drawn by the
    loader. For every drawn sample, one partner index is then sampled with
    probability proportional to a distance-based potential. Before sampling,
    the potential is masked to zero over a block of indices starting at the
    sample's class start. After each step, the embedding rows of the whole
    batch are overwritten.

    Summaries and checkpoints are written to ``FLAGS.train_log_dir``.

    Raises:
        ValueError: If ``FLAGS.model`` is neither ``'resnet'`` nor
            ``'inception'``.
    """
    # Integer half batch; used for placeholder shapes and the loader call.
    half_batch = FLAGS.batch_size // 2

    with tf.Graph().as_default():
        with tf.Session() as sess:
            tf.logging.info('Setup')

            tf_global_step = tf.train.get_or_create_global_step()
            p_images = tf.placeholder(tf.float32,
                                      shape=(None,
                                             image_size[FLAGS.dataset_name],
                                             image_size[FLAGS.dataset_name],
                                             channels[FLAGS.dataset_name]))
            # One-hot labels: (batch, num_classes).
            p_labels = tf.placeholder(tf.int64,
                                      shape=(None,
                                             num_classes[FLAGS.dataset_name]))
            # Embedding-table rows of the loader-drawn half batch.
            p_emb_idx = tf.placeholder(tf.int32, shape=(half_batch, ))
            # Embedding-table rows to overwrite with the current embeddings.
            p_assign_idx = tf.placeholder(tf.int32)

            # Data for eval. Currently hardcoded for cifar.
            # NOTE(review): eval_images / eval_labels are built but not
            # consumed anywhere in this function.
            eval_data_provider = dataset_data_provider.DatasetDataProvider(
                cifar10.get_split('test', '.'),
                common_queue_capacity=2 * FLAGS.batch_size,
                common_queue_min=FLAGS.batch_size)
            e_image, e_label = eval_data_provider.get(['image', 'label'])
            e_image = tf.to_float(e_image)  # TODO this is a hack
            eval_images, eval_labels = tf.train.batch(
                [e_image, e_label],
                batch_size=FLAGS.batch_size,
                num_threads=1,
                capacity=5 * FLAGS.batch_size,
                allow_smaller_final_batch=True)

            with tf.device('/cpu:0'):
                dataloader = adamb_data_loader.adamb_data_loader(
                    FLAGS.dataset_name,
                    decay=FLAGS.decay,
                    loss_scaling=FLAGS.loss_scaling)

            if FLAGS.model == 'resnet':
                with slim.arg_scope(resnet_v2.resnet_arg_scope()):
                    logits, end_points = resnet_v2.resnet_v2_50(
                        p_images,
                        num_classes[FLAGS.dataset_name],
                        is_training=True,
                        global_pool=True)
                    # size (BATCH,1,1,2048); original note: "this doesn't
                    # work tf1.4. using batch"
                    embeddings = end_points['global_pool']
                    predictions = end_points['predictions']
                    predictions = tf.argmax(predictions, 1)
            elif FLAGS.model == 'inception':
                with slim.arg_scope(inception.inception_v1_arg_scope()):
                    print(num_classes)
                    logits, end_points = inception.inception_v1(
                        p_images,
                        num_classes[FLAGS.dataset_name],
                        is_training=True,
                        global_pool=True)
                    embeddings = end_points['global_pool']
                    predictions = tf.argmax(logits, 1)
            else:
                # Previously fell through to a NameError on `embeddings`.
                raise ValueError('Unsupported FLAGS.model: %r' % (FLAGS.model,))

            # Drop the 1x1 spatial axes: (batch, 1, 1, D) -> (batch, D).
            embeddings = tf.squeeze(embeddings, [1, 2])
            # Lazy %-formatting: the shape was previously passed as an extra
            # argument with no placeholder, which breaks record formatting.
            tf.logging.debug("embeddings size: %s", embeddings.shape)
            # Size the table from the network's actual pooled width instead of
            # a hard-coded 2048, so the scatter update below matches whichever
            # backbone was built.
            embedding_dim = embeddings.get_shape().as_list()[-1]

            # NOTE(review): with the default reduction this is a scalar
            # (batch-reduced) loss, not one loss per sample; it is still fed
            # to dataloader.update as 'losses'. Confirm this is intended.
            # Also noted earlier: something here made softmax very slow.
            sample_losses = tf.losses.softmax_cross_entropy(
                logits=logits, onehot_labels=p_labels)

            optimizer = _get_optimizer(FLAGS.opt)
            train_op = optimizer.minimize(
                sample_losses, global_step=tf_global_step)
            tf.logging.info('Model + training setup')

            embedding_list = tf.get_variable(
                'EmbeddingList',
                shape=[num_train_samples[FLAGS.dataset_name], embedding_dim],
                initializer=tf.random_normal_initializer(mean=0.3, stddev=0.6),
                trainable=False)

            # Squared Euclidean distance between each of the m query rows b_i
            # and each of the n table rows e_j:
            #     ||b_i||^2 + ||e_j||^2 - 2 * b_i . e_j        -> (m, n)
            b = tf.gather(embedding_list, p_emb_idx, axis=0)
            c = tf.matmul(b, embedding_list, transpose_b=True)  # b @ E^T: (m, n)
            # The (n,) table norms broadcast across rows of (m, n). Transposing
            # to (n, m) lets the (m,) query norms broadcast along the other
            # axis; the outer transpose restores (m, n).
            squared_euclid = tf.transpose(
                tf.transpose(
                    tf.reduce_sum(tf.square(embedding_list), axis=1) - 2 * c) +
                tf.reduce_sum(tf.square(b), axis=1))

            if FLAGS.pot_func == 'sq_recip':
                # recip_scale keeps the reciprocal finite at zero distance.
                potential = tf.reciprocal(squared_euclid + FLAGS.recip_scale)
            else:
                potential = tf.exp(-FLAGS.recip_scale * squared_euclid / 1000)

            m, n = potential.get_shape().as_list()
            class_starts = dataloader.class_starts
            # Length of the masked block per class. Assumes each class occupies
            # one contiguous, equally sized block of the n table rows; this
            # reproduces the former hard-coded 50000 / 5000 for CIFAR-10.
            class_span = n // num_classes[FLAGS.dataset_name]

            def get_mask(class_starts, labels, batch_size):
                """Return a (batch_size, n) float64 mask, 0 on each row's class block."""
                labels_mask = np.ones(shape=(batch_size, n))
                mins = class_starts[labels]
                mask_idxs = mins[..., None] + np.arange(class_span)
                labels_mask[np.expand_dims(np.arange(batch_size), 1),
                            mask_idxs] = 0.0
                return labels_mask

            labels_mask = tf.py_func(get_mask, [
                class_starts,
                tf.argmax(p_labels, axis=1),
                tf.constant(half_batch, dtype=tf.int32)
            ], tf.double)

            diverse_dist = tf.multiply(potential,
                                       tf.cast(labels_mask, tf.float32))

            # Sample one table index per row, with probability proportional to
            # diverse_dist (inverse-CDF sampling). Each row's cumulative sum is
            # shifted into its own disjoint interval [bin_min, bin_min + max],
            # so a single search over the flattened array serves every row.
            cumm_array = tf.cumsum(diverse_dist, axis=1)
            max_cumm_array = tf.reduce_max(cumm_array, axis=1)
            bin_min = tf.cumsum(max_cumm_array + 1, exclusive=True)
            cumm_array = tf.expand_dims(bin_min, 1) + cumm_array
            scaled_seed = bin_min + tf.multiply(
                tf.random_uniform([half_batch]), max_cumm_array)
            scaled_seed_idx = tf.py_func(
                searchsortedtf, [tf.reshape(cumm_array, [-1]), scaled_seed],
                tf.int64)
            # Flat index -> column index within the row (= table row index).
            pair_idx_tensor = tf.cast(scaled_seed_idx,
                                      tf.int32) - tf.range(m) * n

            # Overwrite the table rows of the current batch with its embeddings.
            emb_update_op = tf.scatter_nd_update(
                embedding_list, tf.expand_dims(p_assign_idx, 1), embeddings)

            accuracy_op = tf.reduce_mean(
                tf.to_float(tf.equal(predictions, tf.argmax(p_labels, 1))))

            tf.summary.scalar('Train_Accuracy', accuracy_op)
            tf.summary.scalar('Total_Loss', sample_losses)
            tf.summary.image('input', p_images)
            slim.summaries.add_histogram_summaries(slim.get_model_variables())
            summary_writer = tf.summary.FileWriter(FLAGS.train_log_dir,
                                                   sess.graph)
            merged_summary_op = tf.summary.merge_all()
            saver = tf.train.Saver()

            tf.logging.info('Savers')
            sess.run(tf.global_variables_initializer())

            # Training loop
            for step in range(FLAGS.max_steps):
                # TODO Still need to modify this to do more than just half-batches
                # TODO split index selection from image/label fetching

                if FLAGS.method == 'pairwise':
                    # First half of the batch comes from the loader ...
                    images, _, labels, sample_idxs = dataloader.load_batch(
                        batch_size=half_batch,
                        method=FLAGS.method)
                    # ... and one partner per sample from the potential above.
                    pair_idxs = sess.run(pair_idx_tensor,
                                         feed_dict={
                                             p_emb_idx: sample_idxs,
                                             p_labels: labels
                                         })
                    if FLAGS.debug:
                        np_diverse_dist = sess.run(diverse_dist,
                                                   feed_dict={
                                                       p_emb_idx: sample_idxs,
                                                       p_labels: labels
                                                   })
                        print('np_diverse_dist: ', np_diverse_dist)
                    image_pairs, label_pairs = dataloader.get_data_from_idx(
                        pair_idxs)
                    images = np.concatenate((images, image_pairs), axis=0)
                    labels = np.concatenate((labels, label_pairs), axis=0)
                    sample_idxs = np.append(sample_idxs, pair_idxs)

                    # Train, and refresh the embedding rows of the whole batch.
                    _, losses, acc, batch_embeddings, _, summary = sess.run(
                        [
                            train_op, sample_losses, accuracy_op, embeddings,
                            emb_update_op, merged_summary_op
                        ],
                        feed_dict={
                            p_images: images,
                            p_labels: labels,
                            p_assign_idx: sample_idxs
                        })

                else:
                    images, _, labels, sample_idxs = dataloader.load_batch(
                        batch_size=FLAGS.batch_size, method=FLAGS.method)
                    _, losses, acc, summary = sess.run(
                        [train_op, sample_losses, accuracy_op,
                         merged_summary_op],
                        feed_dict={
                            p_images: images,
                            p_labels: labels
                        })

                tf.logging.debug('loss: ' + str(losses.mean()))

                if FLAGS.method == 'singleton':
                    dataloader.update(FLAGS.method,
                                      sample_idxs,
                                      metrics={'losses': losses})
                if FLAGS.method == 'pairwise':
                    dataloader.update(FLAGS.method,
                                      sample_idxs,
                                      metrics={
                                          'losses': losses,
                                          'batch_embeddings': batch_embeddings
                                      })

                if ((step + 1) % FLAGS.save_summary_iters == 0
                        or (step + 1) == FLAGS.max_steps):
                    tf.logging.info('Iteration %d complete', step)
                    summary_writer.add_summary(summary, step)
                    print(acc)
                    summary_writer.flush()
                    tf.logging.info('loss: ' + str(losses.mean()))
                    tf.logging.debug('Summary Saved')

                if ((step + 1) % FLAGS.save_model_iters == 0
                        or (step + 1) == FLAGS.max_steps):
                    checkpoint_file = join(FLAGS.train_log_dir, 'model.ckpt')
                    saver.save(sess,
                               checkpoint_file,
                               global_step=tf_global_step)
                    tf.logging.debug('Model Saved')
def _configure_learning_rate(num_samples_per_epoch, global_step):
    """Build a staircase, exponentially decaying learning-rate tensor.

    The rate starts at ``FLAGS.learning_rate`` and is multiplied by
    ``FLAGS.learning_rate_decay_factor`` every
    ``num_samples_per_epoch * FLAGS.num_epochs_per_decay / FLAGS.batch_size``
    steps (truncated to an int).

    Args:
        num_samples_per_epoch: number of training examples in one epoch.
        global_step: step tensor that drives the decay.

    Returns:
        The learning-rate tensor built by ``tf.train.exponential_decay``.
    """
    samples_between_decays = num_samples_per_epoch * FLAGS.num_epochs_per_decay
    steps_between_decays = int(samples_between_decays / FLAGS.batch_size)

    return tf.train.exponential_decay(
        learning_rate=FLAGS.learning_rate,
        global_step=global_step,
        decay_steps=steps_between_decays,
        decay_rate=FLAGS.learning_rate_decay_factor,
        staircase=True,
        name='exponential_decay_learning_rate',
    )


# Training input pipeline for CIFAR-10: read (image, label) pairs, batch
# them, one-hot encode the labels and stage the batches in a prefetch queue.
dataset = cifar10.get_split('train', FLAGS.data_dir)
provider = slim.dataset_data_provider.DatasetDataProvider(
    dataset,
    num_readers=FLAGS.num_readers,
    common_queue_capacity=20 * FLAGS.batch_size,
    common_queue_min=10 * FLAGS.batch_size,
)
image, label = provider.get(['image', 'label'])

images, labels = tf.train.batch(
    [image, label],
    batch_size=FLAGS.batch_size,
    num_threads=FLAGS.num_preprocessing_threads,
    capacity=5 * FLAGS.batch_size,
)
labels = slim.one_hot_encoding(labels, dataset.num_classes)

# Queue holding at most two assembled batches ahead of the consumer.
batch_queue = slim.prefetch_queue.prefetch_queue([images, labels], capacity=2)
slim = tf.contrib.slim

# Command-line flags: where the model, dataset and visualizations live.
tf.app.flags.DEFINE_string(
    'train_dir', '/tmp/tfmodel/',
    'Directory where checkpoints and event logs are written to and load from.')
tf.app.flags.DEFINE_string('data_dir', '/tmp/tfdata/', 'Directory of dataset.')
tf.app.flags.DEFINE_string('visual_dir', '/tmp/tfvisual/',
                           'Directory of visualization results.')
FLAGS = tf.app.flags.FLAGS

# Read single CIFAR-10 test examples and wrap each one as a batch of size 1.
dataset = cifar10.get_split('test', FLAGS.data_dir)
provider = slim.dataset_data_provider.DatasetDataProvider(dataset)
image, label = provider.get(['image', 'label'])

images = tf.expand_dims(image, 0)
labels = slim.one_hot_encoding(tf.expand_dims(label, 0), dataset.num_classes)

# Rescale pixel values (presumably 0..255) to roughly [-4, 4], the range the
# model was trained on.
images = (tf.cast(images, tf.float32) - 127) / 128 * 4

# Build the model in inference mode.
model = Mymodel(dropout=1.0)
model.build(images, labels, train_mode=False)
Exemplo n.º 9
0
    for var in slim.get_model_variables():
        for exclusion in exclusions:
            if var.op.name.startswith(exclusion):
                break
        else:
            variables_to_restore.append(var)

    return slim.assign_from_checkpoint_fn(
        os.path.join(checkpoints_dir, 'inception_v1.ckpt'),
        variables_to_restore)


with tf.Graph().as_default():
    tf.logging.set_verbosity(tf.logging.INFO)

    # Training images and labels from the CIFAR-10 train split; the middle
    # value returned by load_batch is not needed here.
    dataset = cifar10.get_split('train', data_dir)
    images, _, labels = load_batch(dataset, height=image_size, width=image_size)

    # Inception-v1 under its default arg scope (which configures batch norm).
    with slim.arg_scope(inception.inception_v1_arg_scope()):
        logits, _ = inception.inception_v1(
            inputs=images,
            num_classes=dataset.num_classes,
            is_training=True,
        )

    # Softmax cross-entropy against one-hot labels, then the aggregate loss.
    one_hot_labels = slim.one_hot_encoding(labels, dataset.num_classes)
    slim.losses.softmax_cross_entropy(logits=logits,
                                      onehot_labels=one_hot_labels)
    total_loss = slim.losses.get_total_loss()
Exemplo n.º 10
0
from datasets import cifar10
import tensorflow as tf
import numpy as np
from nets import nets_factory
from preprocessing import preprocessing_factory
from datasets import dataset_utils
import csv
slim = tf.contrib.slim
# Number of examples per inference batch.
BATCH_SIZE = 100

if __name__ == "__main__":
    tf.logging.set_verbosity(tf.logging.DEBUG)
    # Newest checkpoint under ./log/train (None when there is none).
    checkpoint_path = tf.train.latest_checkpoint("./log/train")
    print(checkpoint_path)
    # The 'kaggle' split is read from /tmp/cifar10.
    kaggle_test = "/tmp/cifar10"
    dataset = cifar10.get_split('kaggle', kaggle_test)

    # shuffle=False reads examples in order, presumably so that predictions
    # can be matched back to input order -- confirm.
    data_provider = slim.dataset_data_provider.DatasetDataProvider(
        dataset, shuffle=False)
    image, label = data_provider.get(['image', 'label'])
    #    for item in dataset.list_items():

    # Evaluation-mode "cifarnet" preprocessing with a 32x32 output size.
    preprocessing_fn = preprocessing_factory.get_preprocessing(
        "cifarnet", is_training=False)
    image = preprocessing_fn(image, 32, 32)

    # allow_smaller_final_batch=True lets the last batch hold fewer than
    # BATCH_SIZE examples instead of dropping them.
    images, labels = tf.train.batch([image, label],
                                    batch_size=BATCH_SIZE,
                                    allow_smaller_final_batch=True)

    # get the model prediction
Exemplo n.º 11
0
        session.run(train_step_fn.accuracy_test)

    train_step_fn.step += 1

    return [total_loss, should_stop]

if __name__ == '__main__':
    train_dir = 'tensorflow_log_lenet'
    data_dirname = 'cifar10'
    lr = 0.0001
    epochs = 1000
    batch_size = 128

    with tf.Graph().as_default():
        # load training data
        dataset_train = cifar10.get_split('train', data_dirname)
        images_train, labels_train = load_batch(dataset_train, \
                    batch_size = batch_size)
        dataset_test = cifar10.get_split('test', data_dirname)
        images_test, labels_test = load_batch(dataset_test, \
                batch_size = batch_size)

        # define the loss
        with tf.variable_scope('LeNet') as scope:
            logits_train, end_points_train = leNet(images_train, \
                    n_class = dataset_train.num_classes, \
                    is_training = True)
            one_hot_labels_train = slim.one_hot_encoding(\
                    labels_train, dataset_train.num_classes)
            slim.losses.softmax_cross_entropy(\
                    logits_train, one_hot_labels_train)