Exemple #1
0
def main():
    """Benchmark how fast the tf.data softmax pipeline loads image batches.

    Lists the images under a hard-coded CASIA-WebFace directory, builds a
    tf.data pipeline that draws shuffled (path, label) batches through a
    ``tf.py_func`` sampler, decodes the images, and prints the wall-clock
    time of 50 consecutive batch fetches.
    """
    data_dir = 'dataset/CASIA-WebFace-112X96'
    train_set = utils.get_dataset(data_dir)
    nrof_classes = len(train_set)
    print('nrof_classes: ',nrof_classes)
    image_list, label_list = utils.get_image_paths_and_labels(train_set)
    image_list = np.array(image_list)
    label_list = np.array(label_list,dtype=np.int32)
    dataset_size = len(image_list)
    # A mutable ndarray is required here: np.random.shuffle works in place
    # and fails on Python 3's immutable range object.
    indices = np.arange(dataset_size)
    np.random.shuffle(indices)

    batch_size = 100
    img_h = 112
    img_w = 96

    def _sample_people_softmax(x):
        """Return the next shuffled slice of (paths, labels).

        ``x`` is the dataset element that triggered the call and is not used.
        The read position lives in the module-level ``softmax_ind``, which
        must be initialised elsewhere in the file (not visible here).
        """
        global softmax_ind
        if softmax_ind >= dataset_size:
            # Whole dataset consumed: reshuffle and start over.
            np.random.shuffle(indices)
            softmax_ind = 0
        # The last slice of a pass may be shorter than batch_size.
        true_num_batch = min(batch_size,dataset_size - softmax_ind)

        batch_indices = indices[softmax_ind:softmax_ind+true_num_batch]
        sample_paths = image_list[batch_indices]
        sample_labels = label_list[batch_indices]

        softmax_ind += true_num_batch

        return (np.array(sample_paths), np.array(sample_labels,dtype=np.int32))

    def _parse_function(filename,label):
        """Read and decode one image file into a 3-channel tensor."""
        file_contents = tf.read_file(filename)
        image = tf.image.decode_image(file_contents, channels=3)
        print(image.shape)
        return image, label

    epoch_size = 600
    max_nrof_epochs=10
    with tf.device("/cpu:0"):
        softmax_dataset = tf_data.Dataset.range(epoch_size*max_nrof_epochs*100)
        softmax_dataset = softmax_dataset.map(lambda x: tf.py_func(_sample_people_softmax,[x],[tf.string,tf.int32]))
        # Flatten each sampled batch back into individual (path, label) pairs.
        softmax_dataset = softmax_dataset.flat_map(_from_tensor_slices)
        softmax_dataset = softmax_dataset.map(_parse_function,num_threads=8,output_buffer_size=2000)
        softmax_dataset = softmax_dataset.batch(batch_size)
        softmax_iterator = softmax_dataset.make_initializable_iterator()
        softmax_next_element = softmax_iterator.get_next()
        softmax_next_element[0].set_shape((batch_size, img_h,img_w,3))
    with tf.Session() as sess:
        sess.run(softmax_iterator.initializer)
        for i in range(50):
            t = time.time()
            img_np,label_np = sess.run([softmax_next_element[0],softmax_next_element[1]])
            print('Load {} images time cost: {}'.format(img_np.shape[0],time.time()-t))
Exemple #2
0
def main(args):
    """Compute an embedding for every image and dump them as JSON lines.

    Loads the model named by ``args.model``, runs each image found under
    ``args.data_root`` through it, and writes one JSON object per line to
    ``feature_vectors.txt``. Each object maps the image's absolute path to
    its embedding (as a list) and carries a ``"coordinates"`` entry that is
    ``None`` when the coordinates could not be read from the file.

    Args:
        args: Namespace providing ``model`` and ``data_root``.
    """
    with tf.Graph().as_default():

        with tf.Session() as sess:

            # Load the model
            utils.load_model(args.model)

            # Get input and output tensors
            graph = tf.get_default_graph()
            images_placeholder = graph.get_tensor_by_name("image_batch_p:0")
            embeddings = graph.get_tensor_by_name("embeddings:0")
            phase_train_placeholder = graph.get_tensor_by_name("phase_train:0")

            # Load data
            dataset = utils.get_dataset(args.data_root)
            image_paths, _ = utils.get_image_paths_and_labels(dataset)
            # NOTE(review): assumes ImageDataRaw.batch() yields images in the
            # same order as image_paths -- confirm against ImageDataRaw.
            images = ImageDataRaw(image_paths, sess, batch_size=1)

            with open("feature_vectors.txt", "w") as out_file:

                for image_path in image_paths:
                    img = images.batch()
                    path = os.path.abspath(image_path)

                    # Run forward pass to calculate embeddings
                    feed_dict = {
                        images_placeholder: img,
                        phase_train_placeholder: False
                    }
                    emb = sess.run(embeddings, feed_dict=feed_dict)

                    # Get gps coordinates. Missing or unreadable coordinates
                    # are expected, so this stays best-effort, but it must not
                    # swallow KeyboardInterrupt / SystemExit.
                    try:
                        coordinates = coordinates_from_file(path)
                    except Exception:
                        coordinates = None
                    line_dict = {
                        path: emb[0].tolist(),
                        "coordinates": coordinates
                    }

                    out_file.write(json.dumps(line_dict))
                    out_file.write('\n')
                    print(path)
def main(args):
    """Run a headless backbone over a dataset and pickle its outputs.

    The model output for every image under ``args.data_path`` is pickled
    together with the labels into ``<args.model>_<args.set>.pkl``.

    Args:
        args: Namespace providing ``model``, ``data_path`` and ``set``.

    Raises:
        ValueError: If ``args.model`` is not a supported model name.
    """
    if args.model == "resnet50":
        model = resnet50.ResNet50(include_top=False, input_shape=(224, 224, 3))
        image_size = 224
    else:
        # Fail fast: otherwise `model` / `image_size` are unbound and the
        # function crashes later with an unrelated error.
        raise ValueError("Unsupported model: {!r}".format(args.model))

    dataset = utils.get_data(args.data_path)
    image_paths, labels = utils.get_image_paths_and_labels(dataset)
    images = utils.load_images(image_paths, image_size)

    last_output = model.predict(images)
    with open(args.model + "_" + args.set + ".pkl", "wb") as outfile:
        pickle.dump((last_output, labels), outfile)
def main():
    """Benchmark how fast the QueueReader pipeline delivers image batches.

    Lists the images under a hard-coded CASIA-WebFace directory, dequeues
    batches of 100 through ``QueueReader`` and prints the wall-clock time of
    100 consecutive batch fetches.
    """
    data_dir = 'dataset/CASIA-WebFace-112X96'
    train_set = utils.get_dataset(data_dir)
    nrof_classes = len(train_set)
    print('nrof_classes: ', nrof_classes)
    image_list, label_list = utils.get_image_paths_and_labels(train_set)
    input_size = (112, 96)  #h,w
    qr = QueueReader(image_list, label_list, input_size)
    batch_size = 100
    images, labels = qr.dequeue(batch_size)
    num_iter = 100
    coord = tf.train.Coordinator()
    # The context manager closes the session; the finally clause stops and
    # joins the queue-runner threads so none outlive the benchmark.
    with tf.Session() as sess:
        threads = tf.train.start_queue_runners(coord=coord, sess=sess)
        try:
            for _ in range(num_iter):
                t = time.time()
                np_imgs, np_labels = sess.run([images, labels])
                print('Load {} images cost time is {}'.format(
                    np_imgs.shape[0], time.time() - t))
        finally:
            coord.request_stop()
            coord.join(threads)
Exemple #5
0
def main(args):
    """Train a face-embedding network on several GPU towers.

    Creates timestamped log and model directories, builds a tf.data pipeline
    that feeds shuffled (image, label) batches, builds one tower per GPU with
    shared variables, averages the tower losses, and then alternates
    ``train`` with checkpoint saving until ``args.max_nrof_epochs`` is
    reached.

    Args:
        args: Parsed command-line namespace. Reads, among others,
            ``logs_base_dir``, ``models_base_dir``, ``seed``, ``data_dir``,
            ``people_per_batch``, ``images_per_person``, ``num_gpus``,
            ``image_height``, ``image_width``, ``optimizer``, ``network``,
            ``loss_type``, ``epoch_size`` and ``max_nrof_epochs``.

    Returns:
        The timestamped model directory.

    Raises:
        ValueError: If ``args.optimizer``, ``args.network`` or
            ``args.loss_type`` is not one of the supported values.
    """
    subdir = datetime.strftime(datetime.now(), '%Y%m%d-%H%M%S')
    log_dir = os.path.join(os.path.expanduser(args.logs_base_dir), subdir)
    if not os.path.isdir(
            log_dir):  # Create the log directory if it doesn't exist
        os.makedirs(log_dir)
    model_dir = os.path.join(os.path.expanduser(args.models_base_dir), subdir)
    if not os.path.isdir(
            model_dir):  # Create the model directory if it doesn't exist
        os.makedirs(model_dir)

    # Write arguments to a text file
    utils.write_arguments_to_file(args, os.path.join(log_dir, 'arguments.txt'))

    # Store some git revision info in a text file in the log directory
    src_path, _ = os.path.split(os.path.realpath(__file__))
    utils.store_revision_info(src_path, log_dir, ' '.join(sys.argv))

    np.random.seed(seed=args.seed)

    train_set = utils.get_dataset(args.data_dir)
    nrof_classes = len(train_set)
    print('nrof_classes: ', nrof_classes)
    image_list, label_list = utils.get_image_paths_and_labels(train_set)
    image_list = np.array(image_list)
    label_list = np.array(label_list, dtype=np.int32)

    dataset_size = len(image_list)
    single_batch_size = args.people_per_batch * args.images_per_person
    # A mutable ndarray is required here: np.random.shuffle works in place
    # and fails on Python 3's immutable range object.
    indices = np.arange(dataset_size)
    np.random.shuffle(indices)

    def _sample_people_softmax(x):
        """Return the next shuffled slice of (paths, labels).

        ``x`` is the dataset element that triggered the call and is not used.
        The read position lives in the module-level ``softmax_ind``, which
        must be initialised elsewhere in the file (not visible here).
        """
        global softmax_ind
        if softmax_ind >= dataset_size:
            # Whole dataset consumed: reshuffle and start over.
            np.random.shuffle(indices)
            softmax_ind = 0
        # The last slice of a pass may be shorter than single_batch_size.
        true_num_batch = min(single_batch_size, dataset_size - softmax_ind)

        batch_indices = indices[softmax_ind:softmax_ind + true_num_batch]
        sample_paths = image_list[batch_indices]
        sample_labels = label_list[batch_indices]

        softmax_ind += true_num_batch

        return (np.array(sample_paths), np.array(sample_labels,
                                                 dtype=np.int32))

    def _sample_people(x):
        '''We sample people based on tf.data, where we can use transform and prefetch.

        NOTE(review): not referenced anywhere inside this function.
        '''

        image_paths, num_per_class = sample_people(
            train_set, args.people_per_batch * (args.num_gpus - 1),
            args.images_per_person)
        labels = []
        for i in range(len(num_per_class)):
            labels.extend([i] * num_per_class[i])
        return (np.array(image_paths), np.array(labels, dtype=np.int32))

    def _parse_function(filename, label):
        """Decode one image file and apply crop/resize, flip and scaling."""
        file_contents = tf.read_file(filename)
        image = tf.image.decode_image(file_contents, channels=3)
        print(image.shape)

        if args.random_crop:
            print('use random crop')
            image = tf.random_crop(image,
                                   [args.image_size, args.image_size, 3])
        else:
            print('Not use random crop')
            # decode_image leaves the static shape unknown; resize_images
            # needs at least the rank and channel count.
            image.set_shape((None, None, 3))
            image = tf.image.resize_images(image,
                                           size=(args.image_height,
                                                 args.image_width))
        if args.random_flip:
            image = tf.image.random_flip_left_right(image)

        #pylint: disable=no-member
        image.set_shape((args.image_height, args.image_width, 3))
        # `debug` is a module-level flag defined outside this function.
        if debug:
            image = tf.cast(image, tf.float32)
        else:
            image = tf.image.per_image_standardization(image)
        return image, label

    # NOTE(review): the tower loop below uses args.num_gpus while train() is
    # given len(gpus); the two agree only when num_gpus == 2.
    gpus = [0, 1]

    print('Model directory: %s' % model_dir)
    print('Log directory: %s' % log_dir)
    if args.pretrained_model:
        print('Pre-trained model: %s' %
              os.path.expanduser(args.pretrained_model))

    with tf.Graph().as_default():
        tf.set_random_seed(args.seed)
        global_step = tf.Variable(0, trainable=False, name='global_step')

        # Placeholder for the learning rate
        learning_rate_placeholder = tf.placeholder(tf.float32,
                                                   name='learning_rate')

        phase_train_placeholder = tf.placeholder(tf.bool, name='phase_train')

        # Input pipeline: built on the CPU, then split evenly across towers.
        with tf.device("/cpu:0"):

            softmax_dataset = tf_data.Dataset.range(args.epoch_size *
                                                    args.max_nrof_epochs * 100)
            softmax_dataset = softmax_dataset.map(lambda x: tf.py_func(
                _sample_people_softmax, [x], [tf.string, tf.int32]))
            # Flatten each sampled batch back into single (path, label) pairs.
            softmax_dataset = softmax_dataset.flat_map(_from_tensor_slices)
            softmax_dataset = softmax_dataset.map(_parse_function,
                                                  num_threads=8,
                                                  output_buffer_size=2000)
            softmax_dataset = softmax_dataset.batch(args.num_gpus *
                                                    single_batch_size)
            softmax_iterator = softmax_dataset.make_initializable_iterator()
            softmax_next_element = softmax_iterator.get_next()
            softmax_next_element[0].set_shape(
                (args.num_gpus * single_batch_size, args.image_height,
                 args.image_width, 3))
            softmax_next_element[1].set_shape(args.num_gpus *
                                              single_batch_size)
            batch_image_split = tf.split(softmax_next_element[0],
                                         args.num_gpus)
            batch_label_split = tf.split(softmax_next_element[1],
                                         args.num_gpus)

        learning_rate = tf.train.exponential_decay(
            learning_rate_placeholder,
            global_step,
            args.learning_rate_decay_epochs * args.epoch_size,
            args.learning_rate_decay_factor,
            staircase=True)
        tf.summary.scalar('learning_rate', learning_rate)

        print('Using optimizer: {}'.format(args.optimizer))
        if args.optimizer == 'ADAGRAD':
            opt = tf.train.AdagradOptimizer(learning_rate)
        elif args.optimizer == 'MOM':
            opt = tf.train.MomentumOptimizer(learning_rate, 0.9)
        else:
            # Fail fast instead of hitting an unbound `opt` further down.
            raise ValueError('Not supported optimizer: {}'.format(
                args.optimizer))
        tower_losses = []
        tower_cross = []
        tower_dist = []
        tower_th = []
        for i in range(args.num_gpus):
            with tf.device("/gpu:" + str(i)):
                with tf.name_scope("tower_" + str(i)) as scope:
                    # Variables live on the CPU so all towers share them.
                    with slim.arg_scope([slim.model_variable, slim.variable],
                                        device="/cpu:0"):
                        with tf.variable_scope(
                                tf.get_variable_scope()) as var_scope:
                            reuse = False if i == 0 else True
                            if args.network == 'slim_sphere':
                                prelogits = network.infer(batch_image_split[i])
                            elif args.network == 'densenet':
                                with slim.arg_scope(
                                        densenet.densenet_arg_scope(
                                            args.weight_decay)):
                                    prelogits, endpoints = densenet.densenet_small_middle(
                                        batch_image_split[i],
                                        num_classes=args.embedding_size,
                                        is_training=True,
                                        reuse=reuse)
                                    prelogits = tf.squeeze(prelogits,
                                                           axis=[1, 2])
                            else:
                                raise ValueError(
                                    'Not supported network: {}'.format(
                                        args.network))

                            if args.loss_type == 'softmax':
                                cross_entropy_mean = utils.softmax_loss(
                                    prelogits, batch_label_split[i],
                                    len(train_set), args.weight_decay, reuse)
                                regularization_losses = tf.get_collection(
                                    tf.GraphKeys.REGULARIZATION_LOSSES)
                                loss = cross_entropy_mean + tf.add_n(
                                    regularization_losses)
                                # One entry per tower in every list.
                                tower_dist.append(0)
                                tower_cross.append(cross_entropy_mean)
                                tower_th.append(0)
                                tower_losses.append(loss)
                            elif args.loss_type == 'scatter' or args.loss_type == 'coco':
                                label_reshape = tf.reshape(
                                    batch_label_split[i], [single_batch_size])
                                label_reshape = tf.cast(
                                    label_reshape, tf.int64)
                                if args.loss_type == 'scatter':
                                    scatter_loss, _ = utils.weight_scatter_speed(
                                        prelogits,
                                        label_reshape,
                                        len(train_set),
                                        reuse,
                                        weight=args.weight,
                                        scale=args.scale)
                                else:
                                    scatter_loss, _ = utils.coco_loss(
                                        prelogits,
                                        label_reshape,
                                        len(train_set),
                                        reuse,
                                        alpha=args.alpha,
                                        scale=args.scale)
                                regularization_losses = tf.get_collection(
                                    tf.GraphKeys.REGULARIZATION_LOSSES)
                                loss = scatter_loss[
                                    'loss_total'] + args.weight_decay * tf.add_n(
                                        regularization_losses)
                                tower_dist.append(scatter_loss['loss_dist'])
                                tower_cross.append(0)
                                tower_th.append(scatter_loss['loss_th'])

                                tower_losses.append(loss)
                            else:
                                raise ValueError(
                                    'Not supported loss type: {}'.format(
                                        args.loss_type))

                            # Later towers reuse the first tower's variables.
                            tf.get_variable_scope().reuse_variables()
        total_loss = tf.reduce_mean(tower_losses)
        total_cross = tf.reduce_mean(tower_cross)
        total_dist = tf.reduce_mean(tower_dist)
        total_th = tf.reduce_mean(tower_th)
        losses = {}
        losses['total_loss'] = total_loss
        losses['total_cross'] = total_cross
        losses['total_dist'] = total_dist
        losses['total_th'] = total_th
        debug_info = {}
        # Holds the prelogits of the last tower built.
        debug_info['logits'] = prelogits
        debug_info['batch_image_split'] = batch_image_split
        debug_info['batch_label_split'] = batch_label_split

        grads = opt.compute_gradients(total_loss,
                                      tf.trainable_variables(),
                                      colocate_gradients_with_ops=True)
        apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = tf.group(apply_gradient_op)

        # Checkpoints skip the Adagrad slots and the step counter.
        save_vars = [
            var for var in tf.global_variables()
            if 'Adagrad' not in var.name and 'global_step' not in var.name
        ]
        check_nan = tf.add_check_numerics_ops()
        debug_info['check_nan'] = check_nan

        saver = tf.train.Saver(save_vars, max_to_keep=3)

        # Build the summary operation based on the TF collection of Summaries.
        summary_op = tf.summary.merge_all()

        # Start running operations on the Graph.
        gpu_options = tf.GPUOptions(
            per_process_gpu_memory_fraction=args.gpu_memory_fraction)
        sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options,
                                                allow_soft_placement=True))

        # Initialize variables
        sess.run(tf.global_variables_initializer(),
                 feed_dict={phase_train_placeholder: True})
        sess.run(tf.local_variables_initializer(),
                 feed_dict={phase_train_placeholder: True})

        sess.run(softmax_iterator.initializer)

        summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
        coord = tf.train.Coordinator()
        tf.train.start_queue_runners(coord=coord, sess=sess)

        with sess.as_default():

            if args.pretrained_model:
                print('Restoring pretrained model: %s' % args.pretrained_model)
                saver.restore(sess, os.path.expanduser(args.pretrained_model))

            # Training and validation loop
            epoch = 0
            while epoch < args.max_nrof_epochs:
                step = sess.run(global_step, feed_dict=None)
                epoch = step // args.epoch_size
                if debug:
                    # NOTE(review): most names passed here (enqueue_op,
                    # image_batch_gather, the *_placeholder values, ...) are
                    # not defined in this function; unless they exist at
                    # module level this call raises NameError.
                    debug_train(args, sess, train_set, epoch,
                                image_batch_gather, enqueue_op,
                                batch_size_placeholder, image_batch_split,
                                image_paths_split, num_per_class_split,
                                image_paths_placeholder,
                                image_paths_split_placeholder,
                                labels_placeholder, labels_batch,
                                num_per_class_placeholder,
                                num_per_class_split_placeholder, len(gpus))
                # Train for one epoch
                train(args, sess, epoch, len(gpus), debug_info,
                      learning_rate_placeholder, phase_train_placeholder,
                      global_step, losses, train_op, summary_op,
                      summary_writer, args.learning_rate_schedule_file)

                # Save variables and the metagraph if it doesn't exist already
                save_variables_and_metagraph(sess, saver, summary_writer,
                                             model_dir, subdir, step)

    return model_dir
Exemple #6
0
def main(args):
    """Train the sphere network with softmax loss, fed through placeholders.

    Creates timestamped log and model directories, builds a single-device
    graph whose images and labels arrive through placeholders from a
    ``DataGenerator``, and then alternates ``train`` with checkpoint saving
    until ``args.max_nrof_epochs`` is reached.

    Args:
        args: Parsed command-line namespace. Reads, among others,
            ``logs_base_dir``, ``models_base_dir``, ``seed``, ``data_dir``,
            ``batch_size``, ``optimizer``, ``network``, ``loss_type``,
            ``weight_decay``, ``epoch_size`` and ``max_nrof_epochs``.

    Returns:
        The directory checkpoints were written to.

    Raises:
        ValueError: If ``args.optimizer``, ``args.network`` or
            ``args.loss_type`` is not one of the supported values.
    """
    subdir = datetime.strftime(datetime.now(), '%Y%m%d-%H%M%S')
    log_dir = os.path.join(os.path.expanduser(args.logs_base_dir), subdir)
    if not os.path.isdir(
            log_dir):  # Create the log directory if it doesn't exist
        os.makedirs(log_dir)
    model_dir = os.path.join(os.path.expanduser(args.models_base_dir), subdir)
    if not os.path.isdir(
            model_dir):  # Create the model directory if it doesn't exist
        os.makedirs(model_dir)

    # Write arguments to a text file
    utils.write_arguments_to_file(args, os.path.join(log_dir, 'arguments.txt'))

    # Store some git revision info in a text file in the log directory
    src_path, _ = os.path.split(os.path.realpath(__file__))
    utils.store_revision_info(src_path, log_dir, ' '.join(sys.argv))

    np.random.seed(seed=args.seed)

    train_set = utils.get_dataset(args.data_dir)
    nrof_classes = len(train_set)
    print('nrof_classes: ', nrof_classes)
    image_list, label_list = utils.get_image_paths_and_labels(train_set)
    image_list = np.array(image_list)
    print('total images: {}'.format(len(image_list)))
    label_list = np.array(label_list, dtype=np.int32)

    data_reader = DataGenerator(image_list, label_list, args.batch_size)

    print('Model directory: %s' % model_dir)
    print('Log directory: %s' % log_dir)
    if args.pretrained_model:
        print('Pre-trained model: %s' %
              os.path.expanduser(args.pretrained_model))

    with tf.Graph().as_default():
        tf.set_random_seed(args.seed)
        global_step = tf.Variable(0, trainable=False, name='global_step')

        # Placeholder for the learning rate
        learning_rate_placeholder = tf.placeholder(tf.float32,
                                                   name='learning_rate')
        images_placeholder = tf.placeholder(tf.float32, [None, 112, 96, 3],
                                            name='images_placeholder')
        labels_placeholder = tf.placeholder(tf.int32, [None],
                                            name='labels_placeholder')

        phase_train_placeholder = tf.placeholder(tf.bool, name='phase_train')

        learning_rate = tf.train.exponential_decay(
            learning_rate_placeholder,
            global_step,
            args.learning_rate_decay_epochs * args.epoch_size,
            args.learning_rate_decay_factor,
            staircase=True)
        tf.summary.scalar('learning_rate', learning_rate)

        print('Using optimizer: {}'.format(args.optimizer))
        if args.optimizer == 'ADAGRAD':
            opt = tf.train.AdagradOptimizer(learning_rate)
        elif args.optimizer == 'MOM':
            opt = tf.train.MomentumOptimizer(learning_rate, 0.9)
        else:
            # Fail fast instead of hitting an unbound `opt` further down.
            raise ValueError('Not supported optimizer: {}'.format(
                args.optimizer))

        if args.network == 'sphere_network':
            prelogits = network.infer(images_placeholder)
        else:
            # Report the offending network name, not the loss type.
            raise ValueError('Not supported network: {}'.format(args.network))

        if args.loss_type == 'softmax':
            cross_entropy_mean = utils.softmax_loss(prelogits,
                                                    labels_placeholder,
                                                    len(train_set),
                                                    args.weight_decay, False)
            regularization_losses = tf.get_collection(
                tf.GraphKeys.REGULARIZATION_LOSSES)
            loss = cross_entropy_mean + args.weight_decay * tf.add_n(
                regularization_losses)
        else:
            raise ValueError('Not supported loss type: {}'.format(
                args.loss_type))

        losses = {}
        losses['total_loss'] = loss
        losses['softmax_loss'] = cross_entropy_mean
        debug_info = {}
        debug_info['prelogits'] = prelogits

        grads = opt.compute_gradients(loss, tf.trainable_variables())
        train_op = opt.apply_gradients(grads, global_step=global_step)

        save_vars = tf.global_variables()

        saver = tf.train.Saver(save_vars, max_to_keep=3)

        # Start running operations on the Graph.
        gpu_options = tf.GPUOptions(
            per_process_gpu_memory_fraction=args.gpu_memory_fraction)
        sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options,
                                                allow_soft_placement=True))

        # Initialize variables
        sess.run(tf.global_variables_initializer(),
                 feed_dict={phase_train_placeholder: True})
        sess.run(tf.local_variables_initializer(),
                 feed_dict={phase_train_placeholder: True})

        with sess.as_default():

            if args.pretrained_model:
                print('Restoring pretrained model: %s' % args.pretrained_model)
                saver.restore(sess, os.path.expanduser(args.pretrained_model))

            # Training and validation loop
            epoch = 0
            while epoch < args.max_nrof_epochs:
                step = sess.run(global_step, feed_dict=None)
                epoch = step // args.epoch_size

                # Train for one epoch. `debug` is a module-level flag defined
                # outside this function.
                train(args, sess, epoch, images_placeholder,
                      labels_placeholder, data_reader, debug,
                      learning_rate_placeholder, global_step, losses, train_op,
                      args.learning_rate_schedule_file)

                # Save variables and the metagraph if it doesn't exist already
                # NOTE(review): this replaces the timestamped model_dir built
                # above with the raw (un-expanded) base directory, so that is
                # where checkpoints land and what gets returned -- confirm
                # this is intended.
                model_dir = args.models_base_dir
                checkpoint_path = os.path.join(model_dir,
                                               'model-%s.ckpt' % 'softmax')
                saver.save(sess,
                           checkpoint_path,
                           global_step=step,
                           write_meta_graph=False)

    return model_dir
Exemple #7
0
def main(args):
    """Evaluate top-N retrieval accuracy of query images against stored embeddings.

    Loads reference embeddings, paths and coordinates from
    ``args.feature_vectors_file``. Each image under ``args.data_root`` is
    embedded with the model from ``args.model`` and ranked against the
    references by Euclidean distance. When ``args.use_coords`` is set,
    references outside the query's search area are skipped. Per-class and
    overall accuracy are logged.

    Args:
        args: Namespace providing ``log_dir``, ``model``, ``data_root``,
            ``feature_vectors_file``, ``upper_bound``, ``use_coords`` and
            ``top_n``.

    Raises:
        ValueError: If the feature-vector file contains no entries.
    """
    log_dir = args.log_dir
    if not os.path.isdir(
            log_dir):  # Create the log directory if it doesn't exist
        os.makedirs(log_dir)

    logging.basicConfig(level=logging.DEBUG, filemode='w')
    class_logger = config_logger('class_logger',
                                 '../testing_results/classes.log')
    overall_logger = config_logger('ovrl_logger',
                                   '../testing_results/results.log')

    with tf.Graph().as_default():

        with tf.Session() as sess:

            # Load the model
            utils.load_model(args.model)

            # Get input and output tensors
            graph = tf.get_default_graph()
            images_placeholder = graph.get_tensor_by_name("image_batch_p:0")
            embeddings = graph.get_tensor_by_name("embeddings:0")
            phase_train_placeholder = graph.get_tensor_by_name("phase_train:0")

            # Load data
            dataset = utils.get_dataset(args.data_root)
            image_paths, _ = utils.get_image_paths_and_labels(dataset)
            images = ImageDataRaw(image_paths, sess, batch_size=1)
            nrof_images = len(image_paths)
            if nrof_images == 0:
                # Nothing to evaluate; also avoids dividing by zero below.
                overall_logger.info('No query images found')
                return
            count = 0

            # Collect the embeddings and join them once at the end rather
            # than re-copying the whole array for every line.
            emb_chunks = []
            path_list = []
            coords_list = []
            with open(args.feature_vectors_file, "r") as vectors_file:
                for line_count, line in enumerate(vectors_file):
                    emb, coords, path = get_embedding_and_path(line)

                    if line_count % 100 == 0:
                        print(path)

                    emb_chunks.append(np.asarray(emb))
                    path_list.append(path)
                    coords_list.append(coords)

            if not emb_chunks:
                raise ValueError('No feature vectors found in {}'.format(
                    args.feature_vectors_file))
            emb_array = np.concatenate(emb_chunks).reshape((-1, 512))
            print(emb_array.shape)

            class_true_count = 0
            # Starts at 0 so the first class is not over-counted by one.
            class_all_count = 0
            last_class_name = ''
            duration = 0

            for i in range(nrof_images):
                # NOTE(review): assumes images.batch() yields images in the
                # same order as image_paths -- confirm against ImageDataRaw.
                img = images.batch()
                target_path = p.abspath(image_paths[i])
                target_path_short = p.split(target_path)[-1]
                target_class_name = get_querie_class_name(target_path)

                # A change of class name closes the previous class's tally.
                if last_class_name != target_class_name and last_class_name != '':
                    log_class_accuracy(last_class_name, class_true_count,
                                       class_all_count, class_logger)
                    class_all_count, class_true_count = 0, 0

                last_class_name = target_class_name

                # Run forward pass to calculate embeddings
                feed_dict = {
                    images_placeholder: img,
                    phase_train_placeholder: False
                }
                target_emb = sess.run(embeddings, feed_dict=feed_dict)

                # Calculate the area of search. Missing or unreadable
                # coordinates are expected, so this stays best-effort, but it
                # must not swallow KeyboardInterrupt / SystemExit.
                try:
                    target_gps_coords = coordinates_from_file(target_path)
                    center_characteristics = get_center_from_coords(
                        target_gps_coords)
                    target_coords_are_none = False
                except Exception:
                    target_coords_are_none = True

                img_file_list = []
                upper_bound = args.upper_bound
                start_time = time()

                for j in range(len(emb_array)):

                    # Check coords to be present and if so, whether they are in the target area.
                    # Short-circuiting keeps center_characteristics from being
                    # read when the target has no coordinates.
                    if (not args.use_coords) or (coords_list[j] is None) or target_coords_are_none or \
                            check_coords_in_radius(center_characteristics, coords_list[j]):
                        # Then calculate distance to the target
                        dist = np.sqrt(
                            np.sum(
                                np.square(
                                    np.subtract(emb_array[j], target_emb[0]))))

                        # Insert a score with a path
                        img_file = ImageFile(path_list[j], dist)
                        img_file_list = insert_element(img_file,
                                                       img_file_list,
                                                       upper_bound=upper_bound)

                if top_n_accuracy(target_class_name, img_file_list, args.top_n,
                                  upper_bound):
                    print(target_class_name)
                    overall_logger.info(target_class_name + '/' +
                                        target_path_short)
                    count += 1
                    class_true_count += 1
                else:
                    print(target_class_name, list(map(str, img_file_list[:5])))
                    overall_logger.info(target_class_name + '/' +
                                        target_path_short + ' ' +
                                        str(list(map(str, img_file_list[:5]))))

                class_all_count += 1
                duration += time() - start_time

            # Flush the tally of the final class.
            log_class_accuracy(last_class_name, class_true_count,
                               class_all_count, class_logger)
            print(count / nrof_images)
            print(duration / nrof_images)
            overall_logger.info('Total Accuracy: ' + str(count / nrof_images))
Exemple #8
0
def _evaluate_cv_model(model, train_descs, train_labels, val_descs, val_labels):
	"""Evaluate a trained model on the train and validation sets and report validation timing.

	Both evaluations are delegated to ``svmEvaluate`` (defined elsewhere in this
	file) with the same argument layout the original inline code used. Only the
	validation pass is timed, so the sample count reported alongside the timing
	is the validation set size.
	"""
	print('Evaluating model ... ')
	svmEvaluate(model, None, train_descs, train_labels)
	t0 = time.time()
	svmEvaluate(model, None, val_descs, val_labels)
	time_elapsed = time.time()-t0
	nrof_val = len(val_labels)
	print('Test completed over {} samples in {:.2f}s'.format(nrof_val, time_elapsed))
	# Avoid a ZeroDivisionError when the validation set is empty.
	if nrof_val > 0:
		print('Test time per sample {:.3f}ms'.format(time_elapsed * 1000 / nrof_val))


def main(args):
	"""Train and evaluate the classifier selected by ``vars.model_name`` on HOG descriptors.

	The training/validation image lists are loaded from ``vars.train_dir`` and
	``vars.val_dir``; their HOG descriptors are computed once and cached on disk
	as ``train_descs.npy`` / ``val_descs.npy``. The training data is shuffled
	with a fixed seed before one of the following models is trained:

	* ``'mlp-torch'`` - a 2025-128-64-5 torch MLP; afterwards ``test_model`` is
	  timed on every available device and the result written to a log file.
	* ``'svm'``, ``'knn'``, ``'bayes'`` - OpenCV ML models, saved to XML and
	  evaluated through ``svmEvaluate``.
	* ``'mlp-keras'`` - layers are added to the module-level ``network`` object,
	  which is then compiled, fitted, saved and evaluated.

	Any other ``vars.model_name`` value does nothing after the data is loaded.

	Args:
		args: not used by this function; all configuration is read from ``vars``.
	"""
	global network
	train_set = utils.get_dataset(vars.train_dir) #Load dataset
	val_set = utils.get_dataset(vars.val_dir)

	# Get a list of image paths and their labels
	train_img_list, train_labels = utils.get_image_paths_and_labels(train_set)
	assert len(train_img_list)>0, 'The training set should not be empty'

	val_img_list, val_labels = utils.get_image_paths_and_labels(val_set)

	#utils.augment_images(train_img_list, 4) #it only must be called one time to generate several images from single image (don't forget to set validation_set_split_ratio = 0)

	# Descriptors are cached on disk; delete the .npy files to force recomputation.
	if os.path.exists('train_descs.npy'):
		train_descs = np.load('train_descs.npy')
	else:
		train_descs = hogutils.get_hog_desc(train_img_list, False)
		np.save('train_descs.npy', train_descs)

	if os.path.exists('val_descs.npy'):
		val_descs = np.load('val_descs.npy')
	else:
		val_descs = hogutils.get_hog_desc(val_img_list, False)
		np.save('val_descs.npy', val_descs)

	train_labels = np.array(train_labels, dtype=np.int64)
	val_labels = np.array(val_labels, dtype=np.int64)
	# Shuffle data (fixed seed so runs are reproducible)
	rand = np.random.RandomState(10)
	shuffle = rand.permutation(len(train_labels))
	train_descs, train_labels = train_descs[shuffle], train_labels[shuffle]

	##############################################################################
	if vars.model_name == 'mlp-torch':
		model = torch.nn.Sequential(
				torch.nn.Linear(2025, 128),
				torch.nn.ReLU(),
				torch.nn.Linear(128, 64),
				torch.nn.ReLU(),
				torch.nn.Linear(64, 5),
			)

		train_dataset = Data.TensorDataset(torch.from_numpy(train_descs), torch.from_numpy(train_labels))
		val_dataset = Data.TensorDataset(torch.from_numpy(val_descs), torch.from_numpy(val_labels))
		datasets = {'train': train_dataset, 'val': val_dataset}

		# The dataloaders and sizes are published on `vars`; train_model/test_model
		# are not handed them explicitly here, so they presumably read them from
		# there - verify against those helpers.
		vars.dataloaders = {x: Data.DataLoader(datasets[x], batch_size=vars.batch_size, shuffle=True, num_workers=0)
				for x in ['train', 'val']}

		vars.dataset_sizes = {x: len(datasets[x]) for x in ['train', 'val']}
		#vars.class_names = datasets['train'].classes

		optimizer = optim.SGD(model.parameters(), lr = vars.learning_rate, momentum=0.9)
		#optimizer = optim.Adam(model.parameters(), lr=0.05)
		# Decay LR by a factor of vars.scheduler_gamma every vars.scheduler_step_size epochs
		exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size = vars.scheduler_step_size, gamma = vars.scheduler_gamma)
		model = model.to(vars.device)
		model = train_model(model, vars.criterion, optimizer, exp_lr_scheduler, vars.num_epochs)

		# `with` guarantees the log file is closed even if a timing run raises.
		with open(".\\Time-{}-{}.log".format(vars.model_name, vars.batch_size),"w") as log_file:
			for dev in ['cuda', 'cpu']:
				# Skip the GPU run on hosts without CUDA instead of crashing in .to().
				if dev == 'cuda' and not torch.cuda.is_available():
					continue
				vars.device = torch.device(dev)
				model = model.to(vars.device)
				#run model on one batch to allocate required memory on device (and have more exact results)
				inputs, _ = next(iter(vars.dataloaders['train']))
				inputs = inputs.to(vars.device)
				model(inputs)

				s = test_model(model, vars.criterion, 'val', 100)
				log_file.write(s)
				#log_file.write('\n' + '-'*80)

			#log_file.write(summary(model, input_size=(3, vars.input_size, vars.input_size), batch_size=-1, device=vars.device.type))
	elif vars.model_name == 'svm':
		print('Training SVM model ...')
		model = svmInit()
		svmTrain(model, train_descs, train_labels)
		model.save('svm_model.xml')
		_evaluate_cv_model(model, train_descs, train_labels, val_descs, val_labels)
	elif vars.model_name == 'knn':
		print('Training KNN model ...')
		model = cv2.ml.KNearest_create()
		model.setDefaultK(5)
		model.setIsClassifier(True)
		model.train(train_descs, cv2.ml.ROW_SAMPLE, train_labels)
		model.save('knn.xml')
		_evaluate_cv_model(model, train_descs, train_labels, val_descs, val_labels)
	elif vars.model_name == 'bayes':
		print('Training Bayes model ...')
		model = cv2.ml.NormalBayesClassifier_create()
		model.train(train_descs, cv2.ml.ROW_SAMPLE, train_labels)
		model.save('bayes.xml')
		_evaluate_cv_model(model, train_descs, train_labels, val_descs, val_labels)

	elif vars.model_name == 'mlp-keras':
		train_labels = to_categorical(train_labels)
		if len(val_labels) > 0:
			val_labels = to_categorical(val_labels)

		# `network` is the module-level model object declared `global` above;
		# presumably an empty keras Sequential created at import time - confirm.
		network.add(layers.Dense(128, activation='relu', input_shape=(2025,)))
		network.add(layers.Dense(64, activation='relu'))
		network.add(layers.Dense(5, activation='softmax'))

		opt = keras.optimizers.SGD(lr=0.05, momentum=0.5, decay=1e-3, nesterov=False)
			#keras.optimizers.RMSprop(lr=0.001, decay=1e-6)
			#keras.optimizers.Adam(lr=0.001, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0, amsgrad=False)
			#keras.optimizers.Nadam(lr=0.01, beta_1=0.9, beta_2=0.999, epsilon=1e-10, schedule_decay=0.004)

		network.summary()

		network.reset_states()
		network.compile(optimizer=opt, loss='categorical_crossentropy', metrics=['accuracy'])
		#saves the model weights after each epoch if the validation loss decreased
		now = datetime.now() # current date and time
		checkpointer = ModelCheckpoint(filepath='best_model_' + now.strftime("%Y%m%d") + '.hdf5', verbose=1, save_best_only=True)

		manageTrainEvents = ManageTrainEvents()
		history = network.fit(train_descs, train_labels, validation_data=(val_descs, val_labels),
				epochs=vars.num_epochs, batch_size=vars.batch_size, callbacks=[checkpointer, manageTrainEvents])

		network.save('Rec_' + now.strftime("%Y%m%d-%H%M") + '.hdf5')
		#Plot loss and accuracy
		hist = history.history
		# Older Keras reports the metric as 'acc', newer releases as 'accuracy';
		# accept either so the plot does not fail with a KeyError.
		acc = hist['acc'] if 'acc' in hist else hist['accuracy']
		val_acc = hist['val_acc'] if 'val_acc' in hist else hist['val_accuracy']
		loss = hist['loss']
		val_loss = hist['val_loss']
		utils.plot_graphs(loss, val_loss, acc, val_acc, True)
		#Evaluate on test dataset
		print("\nComputing test accuracy")
		test_loss, test_acc = network.evaluate(val_descs, val_labels)
		print('test_acc:', test_acc)
def main(args):
    """Build the softmax + center-loss training graph and run the train/validate loop.

    The model definition module named by ``args.model_def`` is imported
    dynamically. Log and model directories are created under
    ``args.logs_base_dir`` / ``args.models_base_dir`` in a sub-directory named
    after today's date (``%Y%m%d``). The dataset is loaded from
    ``args.data_dir``, optionally filtered and split into train/validation
    sets, after which a TensorFlow graph is built that feeds image batches and
    labels through placeholders. Every epoch the collected statistics are
    rewritten to ``stat.h5`` in the log directory.

    Args:
        args: parsed command-line namespace; the attributes read are the ones
            referenced as ``args.<name>`` in the body below.

    Returns:
        The model directory path (``model_dir``) used for this run.
    """
    network = importlib.import_module(args.model_def)
    image_size = (args.image_size, args.image_size)

    # Log and model output both go into a per-day sub-directory.
    subdir = datetime.strftime(datetime.now(), '%Y%m%d')
    log_dir = os.path.join(os.path.expanduser(args.logs_base_dir), subdir)
    if not os.path.isdir(log_dir):  # Create the log directory if it doesn't exist
        os.makedirs(log_dir)
    model_dir = os.path.join(os.path.expanduser(args.models_base_dir), subdir)
    if not os.path.isdir(model_dir):  # Create the model directory if it doesn't exist
        os.makedirs(model_dir)

    stat_file_name = os.path.join(log_dir, 'stat.h5')

    # Write arguments to a text file
    utils.write_arguments_to_file(args, os.path.join(log_dir, 'arguments.txt'))

    # Store some git revision info in a text file in the log directory
    src_path, _ = os.path.split(os.path.realpath(__file__))
    utils.store_revision_info(src_path, log_dir, ' '.join(sys.argv))

    # Seed both NumPy and the stdlib RNG before the dataset is loaded/split.
    np.random.seed(seed=args.seed)
    random.seed(args.seed)
    dataset = utils.get_dataset(args.data_dir)
    # print(dataset[1].image_paths)
    if args.filter_filename:
        dataset = filter_dataset(dataset, os.path.expanduser(args.filter_filename),
                                 args.filter_percentile, args.filter_min_nrof_images_per_class)

    # A split ratio of 0 (or less) disables validation: val_set is an empty list.
    if args.validation_set_split_ratio > 0.0:
        train_set, val_set = utils.split_dataset(dataset, args.validation_set_split_ratio,
                                                   args.min_nrof_val_images_per_class, 'SPLIT_IMAGES')
    else:
        train_set, val_set = dataset, []

    nrof_classes = len(train_set)

    print('Model directory: %s' % model_dir)
    print('Log directory: %s' % log_dir)
    pretrained_model = None
    if args.pretrained_model:
        pretrained_model = os.path.expanduser(args.pretrained_model)
        print('Pre-trained model: %s' % pretrained_model)

    with tf.Graph().as_default():
        tf.set_random_seed(args.seed)
        global_step = tf.Variable(0, trainable=False, name='global_step')

        # Get a list of image paths and their labels
        image_list, label_list = utils.get_image_paths_and_labels(train_set)
        assert len(image_list) > 0, 'The training set should not be empty'

        val_image_list, val_label_list = utils.get_image_paths_and_labels(val_set)

        learning_rate_placeholder = tf.placeholder(tf.float32, name='learning_rate')
        batch_size_placeholder = tf.placeholder(tf.int32, name='batch_size')
        phase_train_placeholder = tf.placeholder(tf.bool, name='phase_train')
        # NOTE(review): the next three placeholders are not referenced again in
        # this function; they may exist only so the named tensors are present
        # in the saved graph - confirm before removing.
        image_paths_placeholder = tf.placeholder(tf.string, shape=(None, 1), name='image_paths')
        labels_placeholder = tf.placeholder(tf.int32, shape=(None, 1), name='labels')
        control_placeholder = tf.placeholder(tf.int32, shape=(None, 1), name='control')

        # The network input and its labels are fed directly through these two
        # placeholders; they are handed to train() and validate() below.
        image_batch_plh = tf.placeholder(tf.float32, shape=[None, None, None, 3], name='image_batch_p')
        label_batch_plh = tf.placeholder(tf.int32, name='label_batch_p')

        print('Number of classes in training set: %d' % nrof_classes)
        print('Number of examples in training set: %d' % len(image_list))

        print('Number of classes in validation set: %d' % len(val_set))
        print('Number of examples in validation set: %d' % len(val_image_list))

        print('Building training graph')

        # Build the inference graph
        # prelogits, _ = efficientnet_builder.build_model_base(image_batch_plh, 'efficientnet-b2', training=True)
        prelogits, _ = network.inference(image_batch_plh, args.keep_probability, image_size,
                                       phase_train=phase_train_placeholder, bottleneck_layer_size=args.embedding_size,
                                         weight_decay=args.weight_decay)
        # Classification head: one output unit per training class.
        logits = slim.fully_connected(prelogits, len(train_set), activation_fn=None,
                                      weights_initializer=slim.initializers.xavier_initializer(),
                                      weights_regularizer=slim.l2_regularizer(args.weight_decay),
                                      scope='Logits', reuse=False)

        # L2-normalised prelogits, exposed under the tensor name 'embeddings'.
        embeddings = tf.nn.l2_normalize(prelogits, 1, 1e-10, name='embeddings')

        # Norm for the prelogits
        eps = 1e-4
        prelogits_norm = tf.reduce_mean(tf.norm(tf.abs(prelogits) + eps, ord=args.prelogits_norm_p, axis=1))
        tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, prelogits_norm * args.prelogits_norm_loss_factor)

        # Add center loss
        prelogits_center_loss, _ = utils.center_loss(prelogits, label_batch_plh, args.center_loss_alfa, nrof_classes)
        tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, prelogits_center_loss * args.center_loss_factor)

        # Staircase exponential decay, stepping every
        # learning_rate_decay_epochs * epoch_size global steps.
        learning_rate = tf.train.exponential_decay(learning_rate_placeholder, global_step,
                                                   args.learning_rate_decay_epochs * args.epoch_size,
                                                   args.learning_rate_decay_factor, staircase=True)
        tf.summary.scalar('learning_rate', learning_rate)

        # Calculate the average cross entropy loss across the batch
        cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=label_batch_plh, logits=logits, name='cross_entropy_per_example')
        cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy')
        tf.add_to_collection('losses', cross_entropy_mean)

        correct_prediction = tf.cast(tf.equal(tf.argmax(logits, 1), tf.cast(label_batch_plh, tf.int64)), tf.float32)
        accuracy = tf.reduce_mean(correct_prediction, name='accuracy')

        # Calculate the total losses
        # (cross entropy plus everything in REGULARIZATION_LOSSES, which at this
        # point includes the prelogits-norm and center-loss terms added above)
        regularization_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
        total_loss = tf.add_n([cross_entropy_mean] + regularization_losses, name='total_loss')

        # Separate facenet variables from smaug's ones
        # (snapshot of the global variables that exist at this point in graph construction)
        facenet_global_vars = tf.global_variables()

        # Build a Graph that trains the model with one batch of examples and updates the model parameters
        train_op = utils.train(total_loss, global_step, args.optimizer,
                                 learning_rate, args.moving_average_decay, facenet_global_vars, args.log_histograms)

        # Create a saver
        # Only trainable variables plus global_step are saved/restored.
        facenet_saver_vars = tf.trainable_variables()
        facenet_saver_vars.append(global_step)
        saver = tf.train.Saver(facenet_saver_vars, max_to_keep=10)

        # Build the summary operation based on the TF collection of Summaries.
        summary_op = tf.summary.merge_all()

        # Create session
        # NOTE(review): the session is never closed explicitly in this function.
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        sess = tf.Session(config=config)

        # Create normal pipeline
        dataset_train = LabeledImageData(image_list, label_list, sess, batch_size=args.batch_size, shuffle=True,
                                         use_flip=True, use_black_patches=True, use_crop=True)
        dataset_val = LabeledImageDataRaw(val_image_list, val_label_list, sess, batch_size=args.val_batch_size,
                                          shuffle=False)

        # Start running operations on the Graph. Change to tf.compat in newer versions of tf.
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

        with sess.as_default():
            if pretrained_model:
                print('Restoring pretrained model: %s' % pretrained_model)
                # pretrained_model is treated as a checkpoint *directory* here.
                # NOTE(review): tf.train.latest_checkpoint returns None when no
                # checkpoint is found, which would make restore() fail - confirm
                # callers always pass a directory containing a checkpoint.
                ckpt_dir_or_file = tf.train.latest_checkpoint(pretrained_model)
                saver.restore(sess, ckpt_dir_or_file)

            # Training and validation loop
            print('Running training')
            nrof_steps = args.max_nrof_epochs * args.epoch_size
            nrof_val_samples = int(math.ceil(
                args.max_nrof_epochs / args.validate_every_n_epochs))  # Validate every validate_every_n_epochs as well as in the last epoch
            # Pre-allocated statistics buffers: per-step, per-validation-run
            # and per-epoch arrays, all dumped to stat.h5 after every epoch.
            stat = {
                'loss': np.zeros((nrof_steps,), np.float32),
                'center_loss': np.zeros((nrof_steps,), np.float32),
                'reg_loss': np.zeros((nrof_steps,), np.float32),
                'xent_loss': np.zeros((nrof_steps,), np.float32),
                'prelogits_norm': np.zeros((nrof_steps,), np.float32),
                'accuracy': np.zeros((nrof_steps,), np.float32),
                'val_loss': np.zeros((nrof_val_samples,), np.float32),
                'val_xent_loss': np.zeros((nrof_val_samples,), np.float32),
                'val_accuracy': np.zeros((nrof_val_samples,), np.float32),
                'lfw_accuracy': np.zeros((args.max_nrof_epochs,), np.float32),
                'lfw_valrate': np.zeros((args.max_nrof_epochs,), np.float32),
                'learning_rate': np.zeros((args.max_nrof_epochs,), np.float32),
                'time_train': np.zeros((args.max_nrof_epochs,), np.float32),
                'time_validate': np.zeros((args.max_nrof_epochs,), np.float32),
                'time_evaluate': np.zeros((args.max_nrof_epochs,), np.float32),
                'prelogits_hist': np.zeros((args.max_nrof_epochs, 1000), np.float32),
                'smaug_alpha_loss': np.zeros((nrof_steps,), np.float32),
                'smaug_total_loss': np.zeros((nrof_steps,), np.float32)
            }
            # Resume position is derived from the (possibly restored) global step.
            global_step_ = sess.run(global_step)
            start_epoch = 1 + global_step_ // args.epoch_size
            batch_number = global_step_ % args.epoch_size
            biggest_acc = 0.0
            for epoch in range(start_epoch, args.max_nrof_epochs + 1):
                # NOTE(review): `step` is not used anywhere below in this function.
                step = sess.run(global_step, feed_dict=None)
                # Train for one epoch
                t = time.time()
                cont = train(args, sess, epoch, batch_number,
                             learning_rate_placeholder, phase_train_placeholder, batch_size_placeholder,
                             image_batch_plh, label_batch_plh, global_step,
                             total_loss, train_op, summary_op, summary_writer, regularization_losses,
                             args.learning_rate_schedule_file,
                             stat, cross_entropy_mean, accuracy, learning_rate, prelogits, prelogits_center_loss,
                             prelogits_norm, args.prelogits_hist_max, dataset_train, )
                stat['time_train'][epoch - 1] = time.time() - t
                # Printed before this epoch's validation runs, so the array shown
                # here does not yet include the current epoch's result.
                print("------------------Accuracy-----------------" + str(stat['val_accuracy']))
                # A falsy return value from train() stops training early.
                if not cont:
                    break

                t = time.time()
                # Validate on every validate_every_n_epochs-th epoch and on the final epoch.
                if len(val_image_list) > 0 and ((epoch - 1) % args.validate_every_n_epochs == args.validate_every_n_epochs - 1 or epoch == args.max_nrof_epochs):
                    validate(args, sess, epoch, val_label_list, phase_train_placeholder, batch_size_placeholder,
                             stat, total_loss, cross_entropy_mean, accuracy, args.validate_every_n_epochs,
                             image_batch_plh, label_batch_plh, dataset_val)
                stat['time_validate'][epoch - 1] = time.time() - t

                cur_val_acc = get_val_acc(epoch, stat, args.validate_every_n_epochs)

                # Save variables and the metagraph if it doesn't exist already
                save_variables_and_metagraph(sess, saver, summary_writer, model_dir, subdir, epoch, args.save_every,
                                             cur_val_acc, biggest_acc, args.save_best)

                # Updated only after saving, so the save call above compares
                # against the best accuracy from previous epochs.
                biggest_acc = update_biggest_acc(biggest_acc, cur_val_acc)

                print('Saving statistics')
                # Mode 'w' rewrites the whole statistics file on every epoch.
                with h5py.File(stat_file_name, 'w') as f:
                    for key, value in stat.items():
                        f.create_dataset(key, data=value)

    return model_dir
Exemple #10
0
def main_train(args):
    """Build the training graph for the selected loss and run the training loop.

    Log and model directories are created under ``args.logs_base_dir`` and
    ``args.models_base_dir`` in a timestamped sub-directory. The training set
    is read via ``utils.dataset_from_list``; samples are ``.npy`` files loaded
    in Python by ``_sample_people_softmax`` and streamed into the graph through
    a ``tf.data`` pipeline built on ``tf.py_func``.

    Supported ``args.loss_type`` values are ``'softmax'``, ``'lmcl'``,
    ``'center'`` and ``'lmccl'`` (softmax + lmcl + center). The only supported
    ``args.network`` is ``'sphere_network'``.

    Args:
        args: parsed command-line namespace; the attributes read are the ones
            referenced as ``args.<name>`` in the body below.

    Returns:
        The model directory path (``model_dir``) used for this run.

    Raises:
        Exception: if ``args.optimizer``, ``args.network`` or
            ``args.loss_type`` is not one of the supported values.
    """

    subdir = datetime.strftime(datetime.now(), '%Y%m%d-%H%M%S')
    log_dir = os.path.join(os.path.expanduser(args.logs_base_dir), subdir)
    if not os.path.isdir(
            log_dir):  # Create the log directory if it doesn't exist
        os.makedirs(log_dir)
    model_dir = os.path.join(os.path.expanduser(args.models_base_dir), subdir)
    if not os.path.isdir(
            model_dir):  # Create the model directory if it doesn't exist
        os.makedirs(model_dir)
    # Write arguments to a text file
    utils.write_arguments_to_file(args, os.path.join(log_dir, 'arguments.txt'))

    # Store some git revision info in a text file in the log directory
    src_path, _ = os.path.split(os.path.realpath(__file__))
    utils.store_revision_info(src_path, log_dir, ' '.join(sys.argv))

    np.random.seed(seed=args.seed)
    train_set = utils.dataset_from_list(
        args.train_data_dir, args.train_list_dir)  # class objects in a list

    # For reference, the original author's sketch of the per-class objects in
    # `train_set` (actual definition lives in utils - verify there): each has a
    # `name`, a list of `image_paths`, `len()` giving the number of images and
    # `str()` giving "<name>, <n> images".

    nrof_classes = len(train_set)
    print('nrof_classes: ', nrof_classes)
    image_list, label_list = utils.get_image_paths_and_labels(train_set)
    print('total images: ', len(image_list))  # label is in the form scalar.
    image_list = np.array(image_list)
    label_list = np.array(label_list, dtype=np.int32)
    dataset_size = len(image_list)
    single_batch_size = args.class_per_batch * args.images_per_class
    # A real list (not a range object) so np.random.shuffle can permute it in place.
    indices = list(range(dataset_size))
    np.random.shuffle(indices)

    def _sample_people_softmax(x):  # loading the images in batches.
        """Load the next batch of samples and labels from disk.

        Walks through the shuffled ``indices`` using the module-level cursor
        ``softmax_ind`` (which must be initialised elsewhere in the module) and
        reshuffles once the whole dataset has been consumed. The last batch of
        a pass may be shorter than ``single_batch_size``. ``x`` is the dataset
        element supplied by ``tf.py_func`` and is not used.

        Returns:
            A ``(images, labels)`` tuple: float32 images with a trailing axis
            of size 1 inserted at position 4, and int32 labels.
        """
        global softmax_ind
        if softmax_ind >= dataset_size:
            np.random.shuffle(indices)
            softmax_ind = 0
        true_num_batch = min(single_batch_size, dataset_size - softmax_ind)

        batch_indices = indices[softmax_ind:softmax_ind + true_num_batch]
        sample_paths = image_list[batch_indices]
        sample_images = [np.load(str(item)) for item in sample_paths]
        sample_labels = label_list[batch_indices]
        softmax_ind += true_num_batch
        return (np.expand_dims(np.array(sample_images, dtype=np.float32),
                               axis=4), np.array(sample_labels,
                                                 dtype=np.int32))

    print('Model directory: %s' % model_dir)
    print('Log directory: %s' % log_dir)
    if args.pretrained_model:
        print('Pre-trained model: %s' %
              os.path.expanduser(args.pretrained_model))

    with tf.Graph().as_default():
        tf.set_random_seed(args.seed)
        global_step = tf.Variable(0, trainable=False, name='global_step')
        # Placeholder for the learning rate
        learning_rate_placeholder = tf.placeholder(tf.float32,
                                                   name='learning_rate')
        phase_train_placeholder = tf.placeholder(tf.bool, name='phase_train')
        #the image is generated by sequence
        with tf.device("/cpu:0"):

            # One dataset element per training step; each element triggers one
            # Python-side batch load, is flattened into single samples and then
            # re-batched to single_batch_size.
            softmax_dataset = tf.data.Dataset.range(args.epoch_size *
                                                    args.max_nrof_epochs)
            softmax_dataset = softmax_dataset.map(lambda x: tf.py_func(
                _sample_people_softmax, [x], [tf.float32, tf.int32]))
            softmax_dataset = softmax_dataset.flat_map(_from_tensor_slices)
            softmax_dataset = softmax_dataset.batch(single_batch_size)
            softmax_iterator = softmax_dataset.make_initializable_iterator()
            softmax_next_element = softmax_iterator.get_next()
            # py_func outputs carry no static shape, so it is set by hand.
            # NOTE(review): image_width is used for two axes here - presumably
            # the samples are 3-D volumes with equal last two extents; confirm.
            softmax_next_element[0].set_shape(
                (single_batch_size, args.image_height, args.image_width,
                 args.image_width, 1))
            softmax_next_element[1].set_shape(single_batch_size)
            batch_image_split = softmax_next_element[0]
            # batch_image_split = tf.expand_dims(batch_image_split, axis = 4)
            batch_label_split = softmax_next_element[1]

        learning_rate = tf.train.exponential_decay(
            learning_rate_placeholder,
            global_step,
            args.learning_rate_decay_epochs * args.epoch_size,
            args.learning_rate_decay_factor,
            staircase=True)
        tf.summary.scalar('learning_rate', learning_rate)

        print('Using optimizer: {}'.format(args.optimizer))
        if args.optimizer == 'ADAGRAD':
            opt = tf.train.AdagradOptimizer(learning_rate)
        elif args.optimizer == 'SGD':
            opt = tf.train.GradientDescentOptimizer(learning_rate)
        elif args.optimizer == 'MOM':
            opt = tf.train.MomentumOptimizer(learning_rate, 0.9)
        elif args.optimizer == 'ADAM':
            opt = tf.train.AdamOptimizer(learning_rate,
                                         beta1=0.9,
                                         beta2=0.999,
                                         epsilon=0.1)
        else:
            raise Exception("Not supported optimizer: {}".format(
                args.optimizer))

        # Named loss tensors handed to the train functions below.
        losses = {}
        with slim.arg_scope([slim.model_variable, slim.variable],
                            device="/cpu:0"):
            with tf.variable_scope(tf.get_variable_scope()) as var_scope:
                reuse = False

                if args.network == 'sphere_network':

                    prelogits = network.infer(batch_image_split,
                                              args.embedding_size)
                else:
                    raise Exception("Not supported network: {}".format(
                        args.network))

                if args.fc_bn:
                    prelogits = slim.batch_norm(prelogits, is_training=True, decay=0.997,epsilon=1e-5,scale=True,\
                        updates_collections=tf.GraphKeys.UPDATE_OPS,reuse=reuse,scope='softmax_bn')

                if args.loss_type == 'softmax':
                    cross_entropy_mean = utils.softmax_loss(
                        prelogits, batch_label_split, len(train_set), 1.0,
                        reuse)
                    regularization_losses = tf.get_collection(
                        tf.GraphKeys.REGULARIZATION_LOSSES)
                    loss = cross_entropy_mean + args.weight_decay * tf.add_n(
                        regularization_losses)
                    print('************************' +
                          ' Computing the softmax loss')
                    losses['total_loss'] = cross_entropy_mean
                    losses['total_reg'] = args.weight_decay * tf.add_n(
                        regularization_losses)

                elif args.loss_type == 'lmcl':
                    label_reshape = tf.reshape(batch_label_split,
                                               [single_batch_size])
                    label_reshape = tf.cast(label_reshape, tf.int64)
                    coco_loss = utils.cos_loss(prelogits,
                                               label_reshape,
                                               len(train_set),
                                               reuse,
                                               alpha=args.alpha,
                                               scale=args.scale)
                    regularization_losses = tf.get_collection(
                        tf.GraphKeys.REGULARIZATION_LOSSES)
                    loss = coco_loss + args.weight_decay * tf.add_n(
                        regularization_losses)
                    print('************************' +
                          ' Computing the lmcl loss')
                    losses['total_loss'] = coco_loss
                    losses['total_reg'] = args.weight_decay * tf.add_n(
                        regularization_losses)

                elif args.loss_type == 'center':
                    # center loss
                    # The labels must be flattened/cast here as in the 'lmcl'
                    # and 'lmccl' branches; previously `label_reshape` was
                    # undefined in this branch and raised a NameError.
                    label_reshape = tf.reshape(batch_label_split,
                                               [single_batch_size])
                    label_reshape = tf.cast(label_reshape, tf.int64)
                    center_loss, centers, centers_update_op = get_center_loss(prelogits, label_reshape, args.center_loss_alfa, \
                        args.num_class_train)
                    regularization_losses = tf.get_collection(
                        tf.GraphKeys.REGULARIZATION_LOSSES)
                    loss = center_loss + args.weight_decay * tf.add_n(
                        regularization_losses)
                    print('************************' +
                          ' Computing the center loss')
                    losses['total_loss'] = center_loss
                    losses['total_reg'] = args.weight_decay * tf.add_n(
                        regularization_losses)

                elif args.loss_type == 'lmccl':
                    cross_entropy_mean = utils.softmax_loss(
                        prelogits, batch_label_split, len(train_set), 1.0,
                        reuse)
                    label_reshape = tf.reshape(batch_label_split,
                                               [single_batch_size])
                    label_reshape = tf.cast(label_reshape, tf.int64)
                    coco_loss = utils.cos_loss(prelogits,
                                               label_reshape,
                                               len(train_set),
                                               reuse,
                                               alpha=args.alpha,
                                               scale=args.scale)
                    center_loss, centers, centers_update_op = get_center_loss(prelogits, label_reshape, args.center_loss_alfa, \
                        args.num_class_train)
                    regularization_losses = tf.get_collection(
                        tf.GraphKeys.REGULARIZATION_LOSSES)
                    reg_loss = args.weight_decay * tf.add_n(
                        regularization_losses)
                    loss = coco_loss + reg_loss + args.center_weighting * center_loss + cross_entropy_mean
                    losses['total_loss_center'] = args.center_weighting * center_loss
                    losses['total_loss_lmcl'] = coco_loss
                    losses['total_loss_softmax'] = cross_entropy_mean
                    losses['total_reg'] = reg_loss

                else:
                    # Fail here with a clear message instead of a NameError on
                    # `loss` further down.
                    raise Exception("Not supported loss type: {}".format(
                        args.loss_type))

        grads = opt.compute_gradients(loss,
                                      tf.trainable_variables(),
                                      colocate_gradients_with_ops=True)
        apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

        # used for updating the centers in the center loss.
        if args.loss_type == 'lmccl' or args.loss_type == 'center':
            with tf.control_dependencies([centers_update_op]):
                with tf.control_dependencies(update_ops):
                    train_op = tf.group(apply_gradient_op)
        else:
            with tf.control_dependencies(update_ops):
                train_op = tf.group(apply_gradient_op)

        # Checkpoints exclude Adagrad slot variables and global_step, so a
        # restored model always starts counting steps from zero.
        save_vars = [
            var for var in tf.global_variables()
            if 'Adagrad' not in var.name and 'global_step' not in var.name
        ]
        saver = tf.train.Saver(save_vars, max_to_keep=3)

        # Build the summary operation based on the TF collection of Summaries.
        summary_op = tf.summary.merge_all()
        # Start running operations on the Graph.
        gpu_options = tf.GPUOptions(
            per_process_gpu_memory_fraction=args.gpu_memory_fraction)
        sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options,
                                                allow_soft_placement=True))

        # Initialize variables
        sess.run(tf.global_variables_initializer(),
                 feed_dict={phase_train_placeholder: True})
        sess.run(tf.local_variables_initializer(),
                 feed_dict={phase_train_placeholder: True})

        #sess.run(iterator.initializer)
        sess.run(softmax_iterator.initializer)
        summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
        coord = tf.train.Coordinator()
        tf.train.start_queue_runners(coord=coord, sess=sess)

        with sess.as_default():
            if args.pretrained_model:
                print('Restoring pretrained model: %s' % args.pretrained_model)
                saver.restore(sess, os.path.expanduser(args.pretrained_model))

            # Training and validation loop
            epoch = 0
            while epoch < args.max_nrof_epochs:
                # The epoch number is derived from the global step rather than
                # counted, so the loop ends once enough steps have run.
                step = sess.run(global_step, feed_dict=None)
                epoch = step // args.epoch_size
                # NOTE(review): `debug` and every argument of debug_train below
                # are not defined in this function; they must come from module
                # scope, otherwise this branch raises NameError - confirm.
                if debug:
                    debug_train(args, sess, train_set, epoch, image_batch_gather,\
                     enqueue_op,batch_size_placeholder, image_batch_split,image_paths_split,num_per_class_split,
                            image_paths_placeholder,image_paths_split_placeholder, labels_placeholder, labels_batch,\
                             num_per_class_placeholder,num_per_class_split_placeholder,len(gpus))
                # Train for one epoch
                if args.loss_type == 'lmccl' or args.loss_type == 'center':
                    train_contain_center(args, sess, epoch,
                                         learning_rate_placeholder,
                                         phase_train_placeholder, global_step,
                                         losses, train_op, summary_op,
                                         summary_writer, '', centers_update_op)
                else:
                    train(args, sess, epoch, learning_rate_placeholder,
                          phase_train_placeholder, global_step, losses,
                          train_op, summary_op, summary_writer, '')
                # Save variables and the metagraph if it doesn't exist already
                save_variables_and_metagraph(sess, saver, summary_writer,
                                             model_dir, subdir, step)
    return model_dir