Beispiel #1
0
    def __init__(self):
        """Build the classification graph: conv stack, 2-unit embedding, cosine loss.

        The input is a batch of 784-dimensional vectors reshaped to 28x28x1
        images (presumably MNIST -- confirm against the caller). Three
        conv-conv-pool stages (32, 64 and 128 filters) are followed by two
        linear layers (32 and 2 units, no activation), a PReLU, and
        ``cos_loss`` over 10 classes with ``alpha=0.25``.

        Attributes set on ``self``:
            x_input: float32 placeholder of shape ``[None, 784]``.
            y_input: int64 placeholder of shape ``[None]``.
            x_image: ``x_input`` reshaped to ``[-1, 28, 28, 1]``.
            feature: output of the PReLU applied to the 2-unit ``fc2`` layer.
            xent: first value returned by ``cos_loss`` (the loss).
            y_pred: argmax over axis 1 of the product of the normalised
                features and normalised weights returned by ``cos_loss``.
            accuracy: mean of ``y_pred == y_input`` cast to float32.
        """
        self.x_input = tf.placeholder(tf.float32, shape=[None, 784])
        self.y_input = tf.placeholder(tf.int64, shape=[None])
        self.x_image = tf.reshape(self.x_input, [-1, 28, 28, 1])

        # The arg scopes only supply defaults for conv2d / max_pool2d, so the
        # remaining layers are built outside of them.
        with slim.arg_scope([slim.conv2d], kernel_size=3, padding='SAME'), \
                slim.arg_scope([slim.max_pool2d], kernel_size=2):
            x = self.x_image
            # Same layers, scope names and creation order as writing the three
            # stages out by hand: convN_1, convN_2, poolN for N = 1..3.
            for stage, width in enumerate((32, 64, 128), start=1):
                x = slim.conv2d(x, num_outputs=width,
                                scope='conv%d_1' % stage)
                x = slim.conv2d(x, num_outputs=width,
                                scope='conv%d_2' % stage)
                x = slim.max_pool2d(x, scope='pool%d' % stage)

        x = slim.flatten(x, scope='flatten')
        x = slim.fully_connected(x,
                                 num_outputs=32,
                                 activation_fn=None,
                                 scope='fc1')
        x = slim.fully_connected(x,
                                 num_outputs=2,
                                 activation_fn=None,
                                 scope='fc2')

        self.feature = x = tflearn.prelu(x)

        # cos_loss returns (loss, logits, dict); only the loss and the
        # normalised tensors in the dict are used here.
        self.xent, _logits, tmp = cos_loss(x,
                                           self.y_input,
                                           10,
                                           alpha=0.25)

        # tf.argmax replaces the deprecated tf.arg_max alias.
        self.y_pred = tf.argmax(
            tf.matmul(tmp['x_feat_norm'], tmp['w_feat_norm']), 1)
        self.accuracy = tf.reduce_mean(
            tf.cast(tf.equal(self.y_pred, self.y_input), tf.float32))
Beispiel #2
0
def main(args):
    """Build a multi-GPU classification training graph and run the training loop.

    Steps performed:
      1. Create timestamped log/model directories and record the arguments
         and revision info there.
      2. Load the training set with ``utils.get_dataset`` and build a
         shuffled index over all images.
      3. Build an input pipeline that repeatedly samples
         ``people_per_batch * images_per_person`` images, decodes and
         preprocesses them, and batches ``num_gpus`` of those groups.
      4. Build one tower per GPU for the selected network and loss, average
         the tower losses and create the training op.
      5. Call ``train`` until ``max_nrof_epochs`` epochs have elapsed, saving
         a checkpoint after every call.

    This function relies on the names ``softmax_ind`` and ``debug`` being
    defined at module level (they are not defined in this block).

    Args:
        args: Parsed argument namespace. Attributes read here include
            ``logs_base_dir``, ``models_base_dir``, ``seed``, ``data_dir``,
            ``people_per_batch``, ``images_per_person``, ``num_gpus``,
            ``image_height``, ``image_width``, ``image_size``,
            ``random_crop``, ``random_flip``, ``optimizer``, ``network``,
            ``loss_type``, ``fc_bn``, ``weight_decay``, ``embedding_size``,
            ``alpha``, ``scale``, ``epoch_size``, ``max_nrof_epochs``,
            ``learning_rate_decay_epochs``, ``learning_rate_decay_factor``,
            ``learning_rate_schedule_file``, ``gpu_memory_fraction`` and
            ``pretrained_model``.

    Returns:
        Path of the directory the model checkpoints are written to.

    Raises:
        Exception: If ``args.optimizer``, ``args.network`` or
            ``args.loss_type`` is not one of the supported values.
    """
    subdir = datetime.strftime(datetime.now(), '%Y%m%d-%H%M%S')
    log_dir = os.path.join(os.path.expanduser(args.logs_base_dir), subdir)
    if not os.path.isdir(log_dir):  # Create the log directory if needed
        os.makedirs(log_dir)
    model_dir = os.path.join(os.path.expanduser(args.models_base_dir), subdir)
    if not os.path.isdir(model_dir):  # Create the model directory if needed
        os.makedirs(model_dir)

    # Write arguments to a text file
    utils.write_arguments_to_file(args, os.path.join(log_dir, 'arguments.txt'))

    # Store some git revision info in a text file in the log directory
    src_path, _ = os.path.split(os.path.realpath(__file__))
    utils.store_revision_info(src_path, log_dir, ' '.join(sys.argv))

    np.random.seed(seed=args.seed)

    train_set = utils.get_dataset(args.data_dir)
    nrof_classes = len(train_set)
    print('nrof_classes: ', nrof_classes)
    image_list, label_list = utils.get_image_paths_and_labels(train_set)
    image_list = np.array(image_list)
    label_list = np.array(label_list, dtype=np.int32)

    dataset_size = len(image_list)
    single_batch_size = args.people_per_batch * args.images_per_person
    # np.random.shuffle works in place, so it needs a mutable list; a bare
    # range() object cannot be shuffled on Python 3.
    indices = list(range(dataset_size))
    np.random.shuffle(indices)

    def _sample_people_softmax(x):
        """Return the next group of (image paths, labels) in shuffled order.

        ``x`` is the dataset element that triggers the call and is ignored.
        The read position lives in the module-level ``softmax_ind``; once it
        reaches the end of the dataset the indices are reshuffled and the
        position is reset. The last group of a pass may be smaller than
        ``single_batch_size``.
        """
        global softmax_ind
        if softmax_ind >= dataset_size:
            np.random.shuffle(indices)
            softmax_ind = 0
        true_num_batch = min(single_batch_size, dataset_size - softmax_ind)

        batch_indices = indices[softmax_ind:softmax_ind + true_num_batch]
        sample_paths = image_list[batch_indices]
        sample_labels = label_list[batch_indices]

        softmax_ind += true_num_batch

        return (np.array(sample_paths), np.array(sample_labels,
                                                 dtype=np.int32))

    def _parse_function(filename, label):
        """Read, decode and preprocess one image; the label passes through."""
        file_contents = tf.read_file(filename)
        image = tf.image.decode_image(file_contents, channels=3)
        print(image.shape)

        if args.random_crop:
            print('use random crop')
            # NOTE(review): the crop uses args.image_size while the static
            # shape set below uses image_height/image_width -- confirm these
            # agree whenever random_crop is enabled.
            image = tf.random_crop(image,
                                   [args.image_size, args.image_size, 3])
        else:
            print('Not use random crop')
            # decode_image leaves the rank unknown; resize_images needs it.
            image.set_shape((None, None, 3))
            image = tf.image.resize_images(image,
                                           size=(args.image_height,
                                                 args.image_width))
        if args.random_flip:
            image = tf.image.random_flip_left_right(image)

        #pylint: disable=no-member
        image.set_shape((args.image_height, args.image_width, 3))
        if debug:
            image = tf.cast(image, tf.float32)
        else:
            image = tf.image.per_image_standardization(image)
        return image, label

    print('Model directory: %s' % model_dir)
    print('Log directory: %s' % log_dir)
    if args.pretrained_model:
        print('Pre-trained model: %s' %
              os.path.expanduser(args.pretrained_model))

    with tf.Graph().as_default():
        tf.set_random_seed(args.seed)
        global_step = tf.Variable(0, trainable=False, name='global_step')

        # Placeholder for the learning rate
        learning_rate_placeholder = tf.placeholder(tf.float32,
                                                   name='learning_rate')

        phase_train_placeholder = tf.placeholder(tf.bool, name='phase_train')

        # Input pipeline: each dataset element triggers one call to
        # _sample_people_softmax; the resulting group is flattened into
        # single examples, preprocessed and re-batched for all GPUs.
        with tf.device("/cpu:0"):

            softmax_dataset = tf_data.Dataset.range(args.epoch_size *
                                                    args.max_nrof_epochs * 100)
            softmax_dataset = softmax_dataset.map(lambda x: tf.py_func(
                _sample_people_softmax, [x], [tf.string, tf.int32]))
            softmax_dataset = softmax_dataset.flat_map(_from_tensor_slices)
            softmax_dataset = softmax_dataset.map(_parse_function,
                                                  num_threads=8,
                                                  output_buffer_size=2000)
            softmax_dataset = softmax_dataset.batch(args.num_gpus *
                                                    single_batch_size)
            softmax_iterator = softmax_dataset.make_initializable_iterator()
            softmax_next_element = softmax_iterator.get_next()
            softmax_next_element[0].set_shape(
                (args.num_gpus * single_batch_size, args.image_height,
                 args.image_width, 3))
            softmax_next_element[1].set_shape(args.num_gpus *
                                              single_batch_size)
            # One slice of the batch per GPU tower.
            batch_image_split = tf.split(softmax_next_element[0],
                                         args.num_gpus)
            batch_label_split = tf.split(softmax_next_element[1],
                                         args.num_gpus)

        learning_rate = tf.train.exponential_decay(
            learning_rate_placeholder,
            global_step,
            args.learning_rate_decay_epochs * args.epoch_size,
            args.learning_rate_decay_factor,
            staircase=True)
        tf.summary.scalar('learning_rate', learning_rate)

        print('Using optimizer: {}'.format(args.optimizer))
        if args.optimizer == 'ADAGRAD':
            opt = tf.train.AdagradOptimizer(learning_rate)
        elif args.optimizer == 'MOM':
            opt = tf.train.MomentumOptimizer(learning_rate, 0.9)
        elif args.optimizer == 'ADAM':
            opt = tf.train.AdamOptimizer(learning_rate,
                                         beta1=0.9,
                                         beta2=0.999,
                                         epsilon=0.1)
        else:
            raise Exception("Not supported optimizer: {}".format(
                args.optimizer))

        tower_losses = []
        tower_cross = []
        tower_dist = []
        tower_reg = []
        for i in range(args.num_gpus):
            with tf.device("/gpu:" + str(i)):
                with tf.name_scope("tower_" + str(i)):
                    # Variables are pinned to the CPU so towers share them.
                    with slim.arg_scope([slim.model_variable, slim.variable],
                                        device="/cpu:0"):
                        with tf.variable_scope(tf.get_variable_scope()):
                            # Only the first tower creates variables.
                            reuse = i > 0
                            if args.network == 'sphere_network':
                                prelogits = network.infer(batch_image_split[i])
                            elif args.network == 'resface':
                                prelogits, _ = resface.inference(
                                    batch_image_split[i],
                                    1.0,
                                    weight_decay=args.weight_decay,
                                    reuse=reuse)
                            elif args.network == 'inception_net':
                                prelogits, endpoints = inception_net.inference(
                                    batch_image_split[i],
                                    1,
                                    phase_train=True,
                                    bottleneck_layer_size=args.embedding_size,
                                    weight_decay=args.weight_decay,
                                    reuse=reuse)
                            elif args.network == 'resnet_v2':
                                with slim.arg_scope(
                                        resnet_v2.resnet_arg_scope(
                                            args.weight_decay)):
                                    prelogits, end_points = resnet_v2.resnet_v2_50(
                                        batch_image_split[i],
                                        is_training=True,
                                        output_stride=16,
                                        num_classes=args.embedding_size,
                                        reuse=reuse)
                                    prelogits = tf.squeeze(prelogits,
                                                           axis=[1, 2])
                            else:
                                raise Exception(
                                    "Not supported network: {}".format(
                                        args.network))

                            if args.fc_bn:
                                prelogits = slim.batch_norm(
                                    prelogits,
                                    is_training=True,
                                    decay=0.997,
                                    epsilon=1e-5,
                                    scale=True,
                                    updates_collections=tf.GraphKeys.
                                    UPDATE_OPS,
                                    reuse=reuse,
                                    scope='softmax_bn')

                            if args.loss_type == 'softmax':
                                cross_entropy_mean = utils.softmax_loss(
                                    prelogits, batch_label_split[i],
                                    nrof_classes, args.weight_decay, reuse)
                                regularization_losses = tf.get_collection(
                                    tf.GraphKeys.REGULARIZATION_LOSSES)
                                tower_cross.append(cross_entropy_mean)
                                loss = cross_entropy_mean + tf.add_n(
                                    regularization_losses)
                                tower_losses.append(loss)
                                tower_reg.append(regularization_losses)
                            elif args.loss_type == 'cosface':
                                label_reshape = tf.reshape(
                                    batch_label_split[i], [single_batch_size])
                                label_reshape = tf.cast(
                                    label_reshape, tf.int64)
                                coco_loss = utils.cos_loss(prelogits,
                                                           label_reshape,
                                                           nrof_classes,
                                                           reuse,
                                                           alpha=args.alpha,
                                                           scale=args.scale)
                                regularization_losses = tf.get_collection(
                                    tf.GraphKeys.REGULARIZATION_LOSSES)
                                # get_collection returns a Python list; sum it
                                # (as the softmax branch does) rather than
                                # adding the list itself to the scalar loss,
                                # which would broadcast into a vector.
                                loss = coco_loss + tf.add_n(
                                    regularization_losses)

                                tower_losses.append(loss)
                                tower_reg.append(regularization_losses)
                            else:
                                raise Exception(
                                    "Not supported loss type: {}".format(
                                        args.loss_type))

                            tf.get_variable_scope().reuse_variables()

        total_loss = tf.reduce_mean(tower_losses)
        total_reg = tf.reduce_mean(tower_reg)
        losses = {}
        losses['total_loss'] = total_loss
        losses['total_reg'] = total_reg

        grads = opt.compute_gradients(total_loss,
                                      tf.trainable_variables(),
                                      colocate_gradients_with_ops=True)
        apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)
        # Run the collected update ops (e.g. batch-norm moving averages)
        # together with every training step.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = tf.group(apply_gradient_op)

        # Checkpoints skip Adagrad slot variables and the global step.
        save_vars = [
            var for var in tf.global_variables()
            if 'Adagrad' not in var.name and 'global_step' not in var.name
        ]
        saver = tf.train.Saver(save_vars, max_to_keep=3)

        # Build the summary operation based on the TF collection of Summaries.
        summary_op = tf.summary.merge_all()

        # Start running operations on the Graph.
        gpu_options = tf.GPUOptions(
            per_process_gpu_memory_fraction=args.gpu_memory_fraction)
        sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options,
                                                allow_soft_placement=True))

        # Initialize variables
        sess.run(tf.global_variables_initializer(),
                 feed_dict={phase_train_placeholder: True})
        sess.run(tf.local_variables_initializer(),
                 feed_dict={phase_train_placeholder: True})

        sess.run(softmax_iterator.initializer)

        summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
        coord = tf.train.Coordinator()
        tf.train.start_queue_runners(coord=coord, sess=sess)

        with sess.as_default():

            if args.pretrained_model:
                print('Restoring pretrained model: %s' % args.pretrained_model)
                saver.restore(sess, os.path.expanduser(args.pretrained_model))

            # Training and validation loop
            epoch = 0
            while epoch < args.max_nrof_epochs:
                step = sess.run(global_step, feed_dict=None)
                epoch = step // args.epoch_size
                if debug:
                    # NOTE(review): none of the tensors/placeholders passed
                    # below (nor `gpus`) are defined in this function, so this
                    # branch only works if they exist at module level --
                    # confirm before enabling `debug`.
                    debug_train(args, sess, train_set, epoch,
                                image_batch_gather, enqueue_op,
                                batch_size_placeholder, image_batch_split,
                                image_paths_split, num_per_class_split,
                                image_paths_placeholder,
                                image_paths_split_placeholder,
                                labels_placeholder, labels_batch,
                                num_per_class_placeholder,
                                num_per_class_split_placeholder, len(gpus))
                # Train for one epoch
                train(args, sess, epoch, learning_rate_placeholder,
                      phase_train_placeholder, global_step, losses, train_op,
                      summary_op, summary_writer,
                      args.learning_rate_schedule_file)

                # Save variables and the metagraph if it doesn't exist already
                save_variables_and_metagraph(sess, saver, summary_writer,
                                             model_dir, subdir, step)

    return model_dir
Beispiel #3
0
def main_train(args):
    """Build the single-device training graph for volumetric inputs and train it.

    Steps performed:
      1. Create timestamped log/model directories and record the arguments
         and revision info there.
      2. Load the training set with ``utils.dataset_from_list`` and build a
         shuffled index over all samples.
      3. Build an input pipeline that loads ``.npy`` arrays with ``np.load``
         in groups of ``class_per_batch * images_per_class``.
      4. Build the network and the loss selected by ``args.loss_type``
         (``softmax``, ``lmcl``, ``center`` or ``lmccl``) plus the train op.
      5. Train until ``max_nrof_epochs`` epochs have elapsed, saving a
         checkpoint after every call to the training routine.

    This function relies on the names ``softmax_ind`` and ``debug`` being
    defined at module level (they are not defined in this block).

    Args:
        args: Parsed argument namespace. Attributes read here include
            ``logs_base_dir``, ``models_base_dir``, ``seed``,
            ``train_data_dir``, ``train_list_dir``, ``class_per_batch``,
            ``images_per_class``, ``image_height``, ``image_width``,
            ``optimizer``, ``network``, ``embedding_size``, ``fc_bn``,
            ``loss_type``, ``weight_decay``, ``alpha``, ``scale``,
            ``center_loss_alfa``, ``num_class_train``, ``center_weighting``,
            ``epoch_size``, ``max_nrof_epochs``,
            ``learning_rate_decay_epochs``, ``learning_rate_decay_factor``,
            ``gpu_memory_fraction`` and ``pretrained_model``.

    Returns:
        Path of the directory the model checkpoints are written to.

    Raises:
        Exception: If ``args.optimizer``, ``args.network`` or
            ``args.loss_type`` is not one of the supported values.
    """
    subdir = datetime.strftime(datetime.now(), '%Y%m%d-%H%M%S')
    log_dir = os.path.join(os.path.expanduser(args.logs_base_dir), subdir)
    if not os.path.isdir(log_dir):  # Create the log directory if needed
        os.makedirs(log_dir)
    model_dir = os.path.join(os.path.expanduser(args.models_base_dir), subdir)
    if not os.path.isdir(model_dir):  # Create the model directory if needed
        os.makedirs(model_dir)
    # Write arguments to a text file
    utils.write_arguments_to_file(args, os.path.join(log_dir, 'arguments.txt'))

    # Store some git revision info in a text file in the log directory
    src_path, _ = os.path.split(os.path.realpath(__file__))
    utils.store_revision_info(src_path, log_dir, ' '.join(sys.argv))

    np.random.seed(seed=args.seed)
    # A list with one entry per class. Presumably each entry is an
    # ImageClass-like object exposing `name` and `image_paths` -- confirm
    # against utils.dataset_from_list.
    train_set = utils.dataset_from_list(args.train_data_dir,
                                        args.train_list_dir)

    nrof_classes = len(train_set)
    print('nrof_classes: ', nrof_classes)
    image_list, label_list = utils.get_image_paths_and_labels(train_set)
    print('total images: ', len(image_list))  # label is in the form scalar.
    image_list = np.array(image_list)
    label_list = np.array(label_list, dtype=np.int32)
    dataset_size = len(image_list)
    single_batch_size = args.class_per_batch * args.images_per_class
    indices = list(range(dataset_size))
    np.random.shuffle(indices)

    def _sample_people_softmax(x):
        """Load the next group of samples and labels in shuffled order.

        ``x`` is the dataset element that triggers the call and is ignored.
        The read position lives in the module-level ``softmax_ind``; once it
        reaches the end of the dataset the indices are reshuffled and the
        position is reset. Every path is loaded with ``np.load`` and a
        trailing axis is appended to the stacked float32 array.
        """
        global softmax_ind
        if softmax_ind >= dataset_size:
            np.random.shuffle(indices)
            softmax_ind = 0
        true_num_batch = min(single_batch_size, dataset_size - softmax_ind)

        batch_indices = indices[softmax_ind:softmax_ind + true_num_batch]
        sample_paths = image_list[batch_indices]
        sample_images = [np.load(str(item)) for item in sample_paths]
        sample_labels = label_list[batch_indices]
        softmax_ind += true_num_batch
        return (np.expand_dims(np.array(sample_images, dtype=np.float32),
                               axis=4), np.array(sample_labels,
                                                 dtype=np.int32))

    print('Model directory: %s' % model_dir)
    print('Log directory: %s' % log_dir)
    if args.pretrained_model:
        print('Pre-trained model: %s' %
              os.path.expanduser(args.pretrained_model))

    with tf.Graph().as_default():
        tf.set_random_seed(args.seed)
        global_step = tf.Variable(0, trainable=False, name='global_step')
        # Placeholder for the learning rate
        learning_rate_placeholder = tf.placeholder(tf.float32,
                                                   name='learning_rate')
        phase_train_placeholder = tf.placeholder(tf.bool, name='phase_train')

        # Input pipeline: each dataset element triggers one call to
        # _sample_people_softmax; the group is flattened into single examples
        # and re-batched to a fixed size.
        with tf.device("/cpu:0"):

            softmax_dataset = tf.data.Dataset.range(args.epoch_size *
                                                    args.max_nrof_epochs)
            softmax_dataset = softmax_dataset.map(lambda x: tf.py_func(
                _sample_people_softmax, [x], [tf.float32, tf.int32]))
            softmax_dataset = softmax_dataset.flat_map(_from_tensor_slices)
            softmax_dataset = softmax_dataset.batch(single_batch_size)
            softmax_iterator = softmax_dataset.make_initializable_iterator()
            softmax_next_element = softmax_iterator.get_next()
            # NOTE(review): image_width is used for both the third and fourth
            # dimension -- confirm the volumes really are that shape (there
            # is no separate depth argument read here).
            softmax_next_element[0].set_shape(
                (single_batch_size, args.image_height, args.image_width,
                 args.image_width, 1))
            softmax_next_element[1].set_shape(single_batch_size)
            batch_image_split = softmax_next_element[0]
            batch_label_split = softmax_next_element[1]

        learning_rate = tf.train.exponential_decay(
            learning_rate_placeholder,
            global_step,
            args.learning_rate_decay_epochs * args.epoch_size,
            args.learning_rate_decay_factor,
            staircase=True)
        tf.summary.scalar('learning_rate', learning_rate)

        print('Using optimizer: {}'.format(args.optimizer))
        if args.optimizer == 'ADAGRAD':
            opt = tf.train.AdagradOptimizer(learning_rate)
        elif args.optimizer == 'SGD':
            opt = tf.train.GradientDescentOptimizer(learning_rate)
        elif args.optimizer == 'MOM':
            opt = tf.train.MomentumOptimizer(learning_rate, 0.9)
        elif args.optimizer == 'ADAM':
            opt = tf.train.AdamOptimizer(learning_rate,
                                         beta1=0.9,
                                         beta2=0.999,
                                         epsilon=0.1)
        else:
            raise Exception("Not supported optimizer: {}".format(
                args.optimizer))

        losses = {}
        with slim.arg_scope([slim.model_variable, slim.variable],
                            device="/cpu:0"):
            with tf.variable_scope(tf.get_variable_scope()):
                reuse = False

                if args.network == 'sphere_network':
                    prelogits = network.infer(batch_image_split,
                                              args.embedding_size)
                else:
                    raise Exception("Not supported network: {}".format(
                        args.network))

                if args.fc_bn:
                    prelogits = slim.batch_norm(
                        prelogits,
                        is_training=True,
                        decay=0.997,
                        epsilon=1e-5,
                        scale=True,
                        updates_collections=tf.GraphKeys.UPDATE_OPS,
                        reuse=reuse,
                        scope='softmax_bn')

                if args.loss_type == 'softmax':
                    cross_entropy_mean = utils.softmax_loss(
                        prelogits, batch_label_split, nrof_classes, 1.0,
                        reuse)
                    regularization_losses = tf.get_collection(
                        tf.GraphKeys.REGULARIZATION_LOSSES)
                    reg_loss = args.weight_decay * tf.add_n(
                        regularization_losses)
                    loss = cross_entropy_mean + reg_loss
                    print('************************' +
                          ' Computing the softmax loss')
                    losses['total_loss'] = cross_entropy_mean
                    losses['total_reg'] = reg_loss

                elif args.loss_type == 'lmcl':
                    label_reshape = tf.reshape(batch_label_split,
                                               [single_batch_size])
                    label_reshape = tf.cast(label_reshape, tf.int64)
                    coco_loss = utils.cos_loss(prelogits,
                                               label_reshape,
                                               nrof_classes,
                                               reuse,
                                               alpha=args.alpha,
                                               scale=args.scale)
                    regularization_losses = tf.get_collection(
                        tf.GraphKeys.REGULARIZATION_LOSSES)
                    reg_loss = args.weight_decay * tf.add_n(
                        regularization_losses)
                    loss = coco_loss + reg_loss
                    print('************************' +
                          ' Computing the lmcl loss')
                    losses['total_loss'] = coco_loss
                    losses['total_reg'] = reg_loss

                elif args.loss_type == 'center':
                    # The labels must be prepared here as well: this branch
                    # previously used `label_reshape` without defining it.
                    label_reshape = tf.reshape(batch_label_split,
                                               [single_batch_size])
                    label_reshape = tf.cast(label_reshape, tf.int64)
                    center_loss, centers, centers_update_op = get_center_loss(
                        prelogits, label_reshape, args.center_loss_alfa,
                        args.num_class_train)
                    regularization_losses = tf.get_collection(
                        tf.GraphKeys.REGULARIZATION_LOSSES)
                    reg_loss = args.weight_decay * tf.add_n(
                        regularization_losses)
                    loss = center_loss + reg_loss
                    print('************************' +
                          ' Computing the center loss')
                    losses['total_loss'] = center_loss
                    losses['total_reg'] = reg_loss

                elif args.loss_type == 'lmccl':
                    # Combination of softmax, large-margin cosine and
                    # (weighted) center loss.
                    cross_entropy_mean = utils.softmax_loss(
                        prelogits, batch_label_split, nrof_classes, 1.0,
                        reuse)
                    label_reshape = tf.reshape(batch_label_split,
                                               [single_batch_size])
                    label_reshape = tf.cast(label_reshape, tf.int64)
                    coco_loss = utils.cos_loss(prelogits,
                                               label_reshape,
                                               nrof_classes,
                                               reuse,
                                               alpha=args.alpha,
                                               scale=args.scale)
                    center_loss, centers, centers_update_op = get_center_loss(
                        prelogits, label_reshape, args.center_loss_alfa,
                        args.num_class_train)
                    regularization_losses = tf.get_collection(
                        tf.GraphKeys.REGULARIZATION_LOSSES)
                    reg_loss = args.weight_decay * tf.add_n(
                        regularization_losses)
                    loss = (coco_loss + reg_loss +
                            args.center_weighting * center_loss +
                            cross_entropy_mean)
                    losses['total_loss_center'] = (args.center_weighting *
                                                   center_loss)
                    losses['total_loss_lmcl'] = coco_loss
                    losses['total_loss_softmax'] = cross_entropy_mean
                    losses['total_reg'] = reg_loss

                else:
                    # Fail here rather than with a NameError on `loss` below.
                    raise Exception("Not supported loss type: {}".format(
                        args.loss_type))

        uses_centers = args.loss_type in ('lmccl', 'center')

        grads = opt.compute_gradients(loss,
                                      tf.trainable_variables(),
                                      colocate_gradients_with_ops=True)
        apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

        # The center-based losses additionally need the centers updated on
        # every training step.
        if uses_centers:
            with tf.control_dependencies([centers_update_op]):
                with tf.control_dependencies(update_ops):
                    train_op = tf.group(apply_gradient_op)
        else:
            with tf.control_dependencies(update_ops):
                train_op = tf.group(apply_gradient_op)

        # Checkpoints skip Adagrad slot variables and the global step.
        save_vars = [
            var for var in tf.global_variables()
            if 'Adagrad' not in var.name and 'global_step' not in var.name
        ]
        saver = tf.train.Saver(save_vars, max_to_keep=3)

        # Build the summary operation based on the TF collection of Summaries.
        summary_op = tf.summary.merge_all()
        # Start running operations on the Graph.
        gpu_options = tf.GPUOptions(
            per_process_gpu_memory_fraction=args.gpu_memory_fraction)
        sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options,
                                                allow_soft_placement=True))

        # Initialize variables
        sess.run(tf.global_variables_initializer(),
                 feed_dict={phase_train_placeholder: True})
        sess.run(tf.local_variables_initializer(),
                 feed_dict={phase_train_placeholder: True})

        sess.run(softmax_iterator.initializer)
        summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
        coord = tf.train.Coordinator()
        tf.train.start_queue_runners(coord=coord, sess=sess)

        with sess.as_default():
            if args.pretrained_model:
                print('Restoring pretrained model: %s' % args.pretrained_model)
                saver.restore(sess, os.path.expanduser(args.pretrained_model))

            # Training and validation loop
            epoch = 0
            while epoch < args.max_nrof_epochs:
                step = sess.run(global_step, feed_dict=None)
                epoch = step // args.epoch_size
                if debug:
                    # NOTE(review): none of the tensors/placeholders passed
                    # below (nor `gpus`) are defined in this function, so this
                    # branch only works if they exist at module level --
                    # confirm before enabling `debug`.
                    debug_train(args, sess, train_set, epoch,
                                image_batch_gather, enqueue_op,
                                batch_size_placeholder, image_batch_split,
                                image_paths_split, num_per_class_split,
                                image_paths_placeholder,
                                image_paths_split_placeholder,
                                labels_placeholder, labels_batch,
                                num_per_class_placeholder,
                                num_per_class_split_placeholder, len(gpus))
                # Train for one epoch
                if uses_centers:
                    train_contain_center(args, sess, epoch,
                                         learning_rate_placeholder,
                                         phase_train_placeholder, global_step,
                                         losses, train_op, summary_op,
                                         summary_writer, '', centers_update_op)
                else:
                    train(args, sess, epoch, learning_rate_placeholder,
                          phase_train_placeholder, global_step, losses,
                          train_op, summary_op, summary_writer, '')
                # Save variables and the metagraph if it doesn't exist already
                save_variables_and_metagraph(sess, saver, summary_writer,
                                             model_dir, subdir, step)
    return model_dir
Beispiel #4
0
def main(args):
    """Build a multi-GPU triplet-loss training graph and run the training loop.

    The function creates timestamped log/model directories, loads the
    training set, builds a ``tf.data`` pipeline that samples people for every
    batch, builds one inference tower per GPU, selects triplets with a
    ``tf.py_func`` strategy, and then trains until ``args.max_nrof_epochs`` is
    reached, saving variables after every epoch.

    Args:
        args: Option namespace. Attributes read here: ``logs_base_dir``,
            ``models_base_dir``, ``seed``, ``dataset``, ``data_dir``,
            ``mine_method``, ``scale``, ``people_per_batch``, ``num_gpus``,
            ``images_per_person``, ``random_crop``, ``random_flip``,
            ``image_size``, ``pretrained_model``, ``epoch_size``,
            ``max_nrof_epochs``, ``learning_rate_decay_epochs``,
            ``learning_rate_decay_factor``, ``optimizer``, ``network``,
            ``embedding_size``, ``weight_decay``, ``fc_bn``, ``with_softmax``,
            ``softmax_loss_weight``, ``pretrain_softmax``, ``softmax_epoch``,
            ``strategy``, ``alpha``, ``gpu_memory_fraction`` and
            ``learning_rate_schedule_file``.

    Returns:
        The path of the directory the model checkpoints are written to.

    Raises:
        ValueError: If ``args.dataset``, ``args.network``, ``args.strategy``
            or ``args.mine_method`` is not one of the supported values, or if
            ``args.pretrain_softmax`` is set without ``args.with_softmax``.
    """
    # Module-level value assigned below once the batch geometry is known.
    # NOTE(review): presumably read by the triplet selection / training
    # helpers defined elsewhere in this file -- confirm.
    global trip_thresh

    subdir = datetime.strftime(datetime.now(), '%Y%m%d-%H%M%S')
    log_dir = os.path.join(os.path.expanduser(args.logs_base_dir), subdir)
    if not os.path.isdir(log_dir):
        # Create the log directory if it doesn't exist
        os.makedirs(log_dir)
    model_dir = os.path.join(os.path.expanduser(args.models_base_dir), subdir)
    if not os.path.isdir(model_dir):
        # Create the model directory if it doesn't exist
        os.makedirs(model_dir)

    np.random.seed(seed=args.seed)

    print('load data...')
    if args.dataset == 'webface':
        train_set = utils.get_dataset(args.data_dir)
    elif args.dataset == 'mega':
        train_set = utils.dataset_from_cache(args.data_dir)
    else:
        # Previously fell through and crashed later on an unbound `train_set`.
        raise ValueError('Not supported dataset {}'.format(args.dataset))
    print('Loaded dataset: {} persons'.format(len(train_set)))
    nrof_classes = len(train_set)
    class_indices = list(np.arange(nrof_classes))

    def _sample_people(x):
        '''Sample the image paths and labels that make up one global batch.

        Runs inside ``tf.py_func`` so the sampling becomes a ``tf.data`` map
        stage. ``x`` is the element of the driving range dataset and is not
        used for the sampling itself.
        '''
        scale = 1 if args.mine_method != 'simi_online' else args.scale
        image_paths, labels = sample_people(
            train_set, class_indices,
            args.people_per_batch * args.num_gpus * scale,
            args.images_per_person)
        return (np.array(image_paths), np.array(labels, dtype=np.int32))

    def _parse_function(filename, label):
        '''Read one image file, crop/resize and optionally flip it, then
        shift and scale the pixel values with (x - 127.5) / 128.
        '''
        file_contents = tf.read_file(filename)

        image = tf.image.decode_image(file_contents, channels=3)
        if args.random_crop:
            print('use random crop')
            image = tf.random_crop(image,
                                   [args.image_size, args.image_size, 3])
        else:
            print('Not use random crop')
            # decode_image yields no static shape; resize_images needs a rank.
            image.set_shape((None, None, 3))
            image = tf.image.resize_images(image,
                                           size=(args.image_size,
                                                 args.image_size))

        if args.random_flip:
            image = tf.image.random_flip_left_right(image)

        #pylint: disable=no-member
        image.set_shape((args.image_size, args.image_size, 3))
        print('img shape', image.shape)
        image = tf.cast(image, tf.float32)
        image = tf.subtract(image, 127.5)
        image = tf.div(image, 128.)
        return image, label

    gpus = range(args.num_gpus)
    print('Model directory: %s' % model_dir)
    print('Log directory: %s' % log_dir)
    if args.pretrained_model:
        print('Pre-trained model: %s' %
              os.path.expanduser(args.pretrained_model))

    with tf.Graph().as_default():
        tf.set_random_seed(args.seed)
        global_step = tf.Variable(0, trainable=False, name='global_step')
        # Placeholder for the learning rate
        learning_rate_placeholder = tf.placeholder(tf.float32,
                                                   name='learning_rate')
        phase_train_placeholder = tf.placeholder(tf.bool, name='phase_train')
        # Images are generated in sequence: people_per_batch groups of
        # images_per_person images for each GPU.
        single_batch_size = args.people_per_batch * args.images_per_person
        total_batch_size = args.num_gpus * single_batch_size
        with tf.device("/cpu:0"):
            dataset = tf_data.Dataset.range(args.epoch_size *
                                            args.max_nrof_epochs * 100)
            # Each range element triggers one people-sampling call.
            dataset = dataset.map(lambda x: tf.py_func(_sample_people, [x],
                                                       [tf.string, tf.int32]))
            dataset = dataset.flat_map(_from_tensor_slices)
            dataset = dataset.map(_parse_function, num_parallel_calls=8)
            dataset = dataset.batch(total_batch_size)
            iterator = dataset.make_initializable_iterator()
            next_element = iterator.get_next()
            batch_image_split = tf.split(next_element[0], args.num_gpus)
            batch_label = next_element[1]

            # Ten times the number of images in one global batch.
            trip_thresh = total_batch_size * 10

        learning_rate = tf.train.exponential_decay(
            learning_rate_placeholder,
            global_step,
            args.learning_rate_decay_epochs * args.epoch_size,
            args.learning_rate_decay_factor,
            staircase=True)
        tf.summary.scalar('learning_rate', learning_rate)

        opt = utils.get_opt(args.optimizer, learning_rate)

        tower_embeddings = []
        for i, gpu in enumerate(gpus):
            with tf.device("/gpu:" + str(gpu)):
                with tf.name_scope("tower_" + str(gpu)):
                    with slim.arg_scope([slim.model_variable, slim.variable],
                                        device="/cpu:0"):
                        # Build the inference graph
                        with tf.variable_scope(tf.get_variable_scope()):
                            # Only the first tower creates variables.
                            reuse = i != 0
                            if args.network == 'resnet_v2':
                                with slim.arg_scope(
                                        resnet_v2.resnet_arg_scope(
                                            args.weight_decay)):
                                    prelogits, _ = resnet_v2.resnet_v2_50(
                                        batch_image_split[i],
                                        is_training=True,
                                        output_stride=16,
                                        num_classes=args.embedding_size,
                                        reuse=reuse)
                                    prelogits = tf.squeeze(
                                        prelogits, [1, 2],
                                        name='SpatialSqueeze')
                            elif args.network == 'resface':
                                prelogits, _ = resface.inference(
                                    batch_image_split[i],
                                    1.0,
                                    bottleneck_layer_size=args.embedding_size,
                                    weight_decay=args.weight_decay,
                                    reuse=reuse)
                                print('res face prelogits', prelogits)
                            elif args.network == 'mobilenet':
                                prelogits, _ = mobilenet.inference(
                                    batch_image_split[i],
                                    bottleneck_layer_size=args.embedding_size,
                                    phase_train=True,
                                    weight_decay=args.weight_decay,
                                    reuse=reuse)
                            else:
                                # Previously left `prelogits` unbound.
                                raise ValueError(
                                    'Not supported network {}'.format(
                                        args.network))

                            features = prelogits
                            if args.fc_bn:
                                print('use fc bn')
                                features = slim.batch_norm(
                                    prelogits,
                                    is_training=True,
                                    decay=0.997,
                                    epsilon=1e-5,
                                    scale=True,
                                    updates_collections=tf.GraphKeys.
                                    UPDATE_OPS,
                                    reuse=reuse,
                                    scope='softmax_bn')
                            # Always L2-normalise. Before, `embeddings` was
                            # only assigned when fc_bn was set, so running
                            # without it failed on the append below.
                            embeddings = tf.nn.l2_normalize(
                                features, 1, 1e-10, name='embeddings')

                            tf.get_variable_scope().reuse_variables()
                        tower_embeddings.append(embeddings)
        embeddings_gather = tf.concat(tower_embeddings,
                                      axis=0,
                                      name='embeddings_concat')
        if args.with_softmax:
            coco_loss = utils.cos_loss(embeddings_gather, batch_label,
                                       len(train_set))
        # select triplet pair by tf op
        with tf.name_scope('triplet_part'):
            distances = utils._pairwise_distances(embeddings_gather,
                                                  squared=True)
            print('triplet strategy', args.strategy)
            if args.strategy == 'min_and_min':
                select_fn = select_triplets_min_min
            elif args.strategy == 'min_and_max':
                select_fn = select_triplets_min_max
            elif args.strategy == 'hardest':
                select_fn = select_triplets_hardest
            elif args.strategy == 'batch_random':
                select_fn = select_triplets_batch_random
            elif args.strategy == 'batch_all':
                select_fn = select_triplets_batch_all
            else:
                raise ValueError('Not supported strategy {}'.format(
                    args.strategy))
            pair = tf.py_func(select_fn,
                              [distances, batch_label, args.alpha],
                              tf.int64)
            triplet_handle = {}
            triplet_handle['embeddings'] = embeddings_gather
            triplet_handle['labels'] = batch_label
            triplet_handle['pair'] = pair
        if args.mine_method == 'online':
            # Reorder the embeddings into (anchor, positive, negative) runs.
            pair_reshape = tf.reshape(pair, [-1])
            embeddings_gather = tf.gather(embeddings_gather, pair_reshape)
        anchor, positive, negative = tf.unstack(
            tf.reshape(embeddings_gather, [-1, 3, args.embedding_size]), 3, 1)
        triplet_loss, pos_d, neg_d = utils.triplet_loss(
            anchor, positive, negative, args.alpha)
        # Calculate the total losses
        regularization_losses = tf.get_collection(
            tf.GraphKeys.REGULARIZATION_LOSSES)
        triplet_loss = tf.add_n([triplet_loss])
        total_loss = triplet_loss
        if regularization_losses:
            # tf.add_n rejects an empty list, so only add when present.
            total_loss = total_loss + tf.add_n(regularization_losses)
        if args.with_softmax:
            total_loss = total_loss + args.softmax_loss_weight * coco_loss
        losses = {}
        losses['triplet_loss'] = triplet_loss
        losses['total_loss'] = total_loss

        update_vars = tf.trainable_variables()
        with tf.device("/gpu:" + str(gpus[0])):
            grads = opt.compute_gradients(total_loss,
                                          update_vars,
                                          colocate_gradients_with_ops=True)

        if args.pretrain_softmax:
            if not args.with_softmax:
                # coco_loss only exists when the softmax branch was built.
                raise ValueError(
                    'pretrain_softmax requires with_softmax to be enabled')
            softmax_vars = [
                var for var in update_vars if 'centers_var' in var.name
            ]
            print('softmax vars', softmax_vars)
            softmax_grads = opt.compute_gradients(coco_loss, softmax_vars)
            softmax_update_op = opt.apply_gradients(softmax_grads)
        apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        print('update ops', update_ops)
        with tf.control_dependencies(update_ops):
            train_op_dep = tf.group(apply_gradient_op)
        # NOTE(review): train_op_dep is created outside the tf.cond branches;
        # per the TF1 tf.cond documentation, ops created outside true_fn /
        # false_fn run regardless of the predicate, so this guard may not
        # actually skip the update when the loss is NaN -- confirm.
        train_op = tf.cond(tf.is_nan(triplet_loss),
                           lambda: tf.no_op('no_train'), lambda: train_op_dep)

        # Checkpoints skip optimizer slots and the step counter; restoring a
        # pretrained model additionally skips 'pair_part' and 'centers_var'.
        save_vars = [
            var for var in tf.global_variables()
            if 'Adagrad' not in var.name and 'global_step' not in var.name
        ]
        restore_vars = [
            var for var in tf.global_variables()
            if 'Adagrad' not in var.name and 'global_step' not in var.name
            and 'pair_part' not in var.name and 'centers_var' not in var.name
        ]
        print('restore vars', restore_vars)
        saver = tf.train.Saver(save_vars, max_to_keep=3)
        restorer = tf.train.Saver(restore_vars, max_to_keep=3)

        # Build the summary operation based on the TF collection of Summaries.
        summary_op = tf.summary.merge_all()

        # Start running operations on the Graph.
        gpu_options = tf.GPUOptions(
            per_process_gpu_memory_fraction=args.gpu_memory_fraction)
        sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options,
                                                allow_soft_placement=True))
        summary_writer = None
        try:
            # Initialize variables
            sess.run(tf.global_variables_initializer(),
                     feed_dict={phase_train_placeholder: True})
            sess.run(tf.local_variables_initializer(),
                     feed_dict={phase_train_placeholder: True})
            sess.run(iterator.initializer)

            summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
            coord = tf.train.Coordinator()
            tf.train.start_queue_runners(coord=coord, sess=sess)

            with sess.as_default():
                if args.pretrained_model:
                    print('Restoring pretrained model: %s' %
                          args.pretrained_model)
                    restorer.restore(
                        sess, os.path.expanduser(args.pretrained_model))

                # Optional warm-up that only updates the 'centers_var'
                # variables with the softmax (coco) loss.
                if args.pretrain_softmax:
                    total_images = len(train_set) * 20
                    batches_per_epoch = total_images // total_batch_size
                    lr_init = 0.1
                    for epoch in range(args.softmax_epoch):
                        # Step the learning rate down at fixed epochs.
                        if epoch == 4:
                            lr_init = 0.05
                        if epoch == 7:
                            lr_init = 0.01
                        for i in range(batches_per_epoch):
                            coco_loss_err, _ = sess.run(
                                [coco_loss, softmax_update_op],
                                feed_dict={
                                    phase_train_placeholder: False,
                                    learning_rate_placeholder: lr_init
                                })
                            print('{}/{} {} coco loss err:{} lr:{}'.format(
                                i, batches_per_epoch, epoch, coco_loss_err,
                                lr_init))

                # Training and validation loop
                epoch = 0
                while epoch < args.max_nrof_epochs:
                    step = sess.run(global_step, feed_dict=None)
                    epoch = step // args.epoch_size
                    # Train for one epoch
                    if args.mine_method == 'simi_online':
                        train_simi_online(
                            args, sess, epoch, len(gpus), embeddings_gather,
                            batch_label, next_element[0], batch_image_split,
                            learning_rate_placeholder, learning_rate,
                            phase_train_placeholder, global_step, pos_d,
                            neg_d, triplet_handle, losses, train_op,
                            summary_op, summary_writer,
                            args.learning_rate_schedule_file)
                    elif args.mine_method == 'online':
                        train_online(args, sess, epoch, learning_rate,
                                     phase_train_placeholder, global_step,
                                     losses, train_op, summary_op,
                                     summary_writer,
                                     args.learning_rate_schedule_file)
                    else:
                        raise ValueError(
                            'Not supported mini method {}'.format(
                                args.mine_method))
                    # Save variables and the metagraph if it doesn't exist
                    # already
                    save_variables_and_metagraph(sess, saver, summary_writer,
                                                 model_dir, subdir, step)
        finally:
            # Release the event file handle and the session's resources even
            # when training aborts with an exception.
            if summary_writer is not None:
                summary_writer.close()
            sess.close()
    return model_dir