Example #1
def main():

    if not handle_args():
        # Invalid arguments; exit the program
        print_usage()
        return 1

    with tf.Graph().as_default():
        image = tf.placeholder("float",
                               shape=[1, image_size, image_size, 3],
                               name='input')
        prelogits, _ = network.inference(image,
                                         1.0,
                                         phase_train=False,
                                         bottleneck_layer_size=512)
        normalized = tf.nn.l2_normalize(prelogits, 1, name='l2_normalize')
        output = tf.identity(normalized, name='output')
        saver = tf.train.Saver(tf.global_variables())

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())
            print('Restoring ' + model_base_file_name)
            saver.restore(sess, model_base_file_name)

            # Save the network for Fathom (the Movidius NCS toolchain)
            out_dir = os.path.dirname(model_base_file_name)
            out_dir = os.path.join(out_dir, 'ncs')
            if not os.path.exists(out_dir):
                os.mkdir(out_dir)
            saver.save(sess, os.path.join(out_dir, 'model'))
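Example #1 stops at a fresh checkpoint. If a single frozen protobuf is preferred, a minimal follow-up sketch (assuming TF 1.x, placed inside the same `with tf.Session()` block right after the final `saver.save`) could look like:

            # Fold variables into constants, keeping only the 'output' subgraph
            frozen_graph = tf.graph_util.convert_variables_to_constants(
                sess, sess.graph.as_graph_def(), ['output'])
            with tf.gfile.GFile(os.path.join(out_dir, 'model.pb'), 'wb') as f:
                f.write(frozen_graph.SerializeToString())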
Example #2
def main(args):
    if not os.path.exists(args.pretrained_model):
        print('invalid pretrained model path')
        return
    weights = np.load(args.pretrained_model)

    pairs = test_utils.read_pairs('/exports_data/czj/data/lfw/files/pairs.txt')
    imglist, labels = test_utils.get_paths(
        '/exports_data/czj/data/lfw/lfw_aligned/', pairs, '_face_.jpg')
    total_images = len(imglist)

    # ---- build graph ---- #
    # use a name that does not shadow the Python builtin `input`
    images_pl = tf.placeholder(tf.float32,
                               shape=[None, 160, 160, 3],
                               name='image_batch')
    prelogits, _ = inception_resnet_v1.inference(images_pl, 1, phase_train=False)
    embeddings = tf.nn.l2_normalize(prelogits, 1, 1e-10)

    # ---- extract ---- #
    gpu_options = tf.GPUOptions(allow_growth=True)
    sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options,
                                            log_device_placement=False,
                                            allow_soft_placement=True))
    with sess.as_default():
        beg_time = time.time()
        to_assign = [
            v.assign(weights[()][v.name][0])
            for v in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        ]
        sess.run(to_assign)
        print('restore parameters: %.2fsec' % (time.time() - beg_time))

        beg_time = time.time()
        images = load_data(imglist)
        print('load images: %.2fsec' % (time.time() - beg_time))

        beg_time = time.time()
        batch_size = 32
        beg = 0
        end = 0
        features = np.zeros((total_images, 128))
        while end < total_images:
            end = min(beg + batch_size, total_images)
            features[beg:end] = sess.run(embeddings, {images_pl: images[beg:end]})
            beg = end
        print('extract features: %.2fsec' % (time.time() - beg_time))

    tpr, fpr, acc, vr, vr_std, far = test_utils.evaluate(features,
                                                         labels,
                                                         num_folds=10)
    # display
    auc = metrics.auc(fpr, tpr)
    eer = brentq(lambda x: 1. - x - interpolate.interp1d(fpr, tpr)(x), 0., 1.)
    print('Acc:           %1.3f+-%1.3f' % (np.mean(acc), np.std(acc)))
    print('VR@FAR=%2.5f:  %2.5f+-%2.5f' % (far, vr, vr_std))
    print('AUC:           %1.3f' % auc)
    print('EER:           %1.3f' % eer)
    sess.close()
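Example #2 calls a `load_data` helper that is not shown. A hypothetical stand-in, assuming the paths point to aligned 160x160 RGB crops and that FaceNet-style per-image whitening is wanted:

import numpy as np
from PIL import Image

def load_data(imglist, image_size=160):
    # Hypothetical: load aligned face crops and apply per-image standardization
    images = np.zeros((len(imglist), image_size, image_size, 3), dtype=np.float32)
    for i, path in enumerate(imglist):
        img = np.asarray(Image.open(path).resize((image_size, image_size)),
                         dtype=np.float32)
        mean, std = img.mean(), img.std()
        std_adj = max(std, 1.0 / np.sqrt(img.size))  # guard against flat images
        images[i] = (img - mean) / std_adj
    return images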
Example #3
def convert_facenet(base_dir,
                    model_base_path,
                    image_size,
                    output_size,
                    prefix=None,
                    do_push=False):
    import facenet
    from models import inception_resnet_v1
    out_dir = os.path.join(base_dir, 'movidius')
    if not os.path.exists(out_dir):
        os.mkdir(out_dir)
    facenet_dir = os.path.join(base_dir, 'facenet')
    if not os.path.exists(facenet_dir):
        os.mkdir(facenet_dir)
    tf.reset_default_graph()
    with tf.Graph().as_default():
        logging.info("Load FACENET graph")
        image = tf.placeholder("float",
                               shape=[1, image_size, image_size, 3],
                               name='input')
        prelogits, _ = inception_resnet_v1.inference(
            image, 1.0, phase_train=False, bottleneck_layer_size=output_size)
        normalized = tf.nn.l2_normalize(prelogits, 1, name='l2_normalize')
        output = tf.identity(normalized, name='output')

        # Do not remove
        saver = tf.train.Saver(tf.global_variables())

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())
            base_name = model_base_path
            meta_file, ckpt_file = facenet.get_model_filenames(base_name)
            logging.info("Restore FACENET graph from %s %s", meta_file,
                         ckpt_file)
            saver = tf.train.import_meta_graph(
                os.path.join(base_name, meta_file))
            saver.restore(sess, os.path.join(base_name, ckpt_file))

            logging.info("Freeze FACENET graph")
            saver.save(sess, os.path.join(facenet_dir, 'facenet'))

            cmd = 'mvNCCheck {}/facenet.meta -in input -on output -s 12 -cs 0,1,2 -S 255'.format(
                facenet_dir)
            logging.info('Validate Movidius: %s', cmd)
            result = subprocess.check_output(cmd, shell=True).decode()
            logging.info(result)
            result = parse_check_ouput(result, prefix=prefix)
            submit(result)
            cmd = 'mvNCCompile {}/facenet.meta -in input -on output -o {}/facenet.graph -s 12'.format(
                facenet_dir, out_dir)
            logging.info('Compile: %s', cmd)
            result = subprocess.check_output(cmd, shell=True).decode()
            logging.info(result)
            if do_push:
                push('facenet', out_dir)
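Once `mvNCCompile` has produced `facenet.graph`, running it on a stick is short. A hypothetical usage sketch, assuming the NCSDK v1 Python API (`mvnc`) and input preprocessing consistent with the `-S 255` scale flag above:

import numpy as np
from mvnc import mvncapi as mvnc

devices = mvnc.EnumerateDevices()
device = mvnc.Device(devices[0])
device.OpenDevice()
with open('facenet.graph', 'rb') as f:
    graph_blob = f.read()
graph = device.AllocateGraph(graph_blob)

image = np.zeros((160, 160, 3), np.float32)  # stand-in for a real aligned face
graph.LoadTensor(image.astype(np.float16), None)  # the NCS computes in fp16
embedding, _ = graph.GetResult()

graph.DeallocateGraph()
device.CloseDevice()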
Example #4
def main():
    args = parse_args()
    out_dir = '/tmp/facenet'

    with tf.Graph().as_default():
        image = tf.placeholder("float",
                               shape=[1, image_size, image_size, 3],
                               name='input')
        prelogits, _ = inception_resnet_v1.inference(
            image,
            1.0,
            phase_train=False,
            bottleneck_layer_size=args.output_size)
        normalized = tf.nn.l2_normalize(prelogits, 1, name='l2_normalize')
        output = tf.identity(normalized, name='output')

        # Do not remove
        saver = tf.train.Saver(tf.global_variables())

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())
            base_name = args.model_base_path
            meta_file, ckpt_file = facenet.get_model_filenames(base_name)
            saver = tf.train.import_meta_graph(
                os.path.join(base_name, meta_file))
            saver.restore(sess, os.path.join(base_name, ckpt_file))

            # Save the network for Fathom (the Movidius NCS toolchain)
            if not os.path.isdir(out_dir):
                os.makedirs(out_dir)

            saver.save(sess, out_dir + '/facenet')

    if args.check:
        cmd = 'mvNCCheck {0}/facenet.meta -in input -on output -s 12 -cs 0,1,2 -S 255'.format(
            out_dir)
        print('Running check:\n')
        print(cmd)
        print('')
        print(subprocess.check_output(cmd, shell=True).decode())

    cmd = 'mvNCCompile {0}/facenet.meta -in input -on output -o {1} -s 12'.format(
        out_dir, args.output_file)

    print('Running compile:\n')
    print(cmd)
    print('')
    print(subprocess.check_output(cmd, shell=True).decode())

    shutil.rmtree(out_dir)
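Example #4's `parse_args` is not shown. A plausible argparse sketch, with flag names mirroring the attributes the code reads (`model_base_path`, `output_size`, `output_file`, `check`); `image_size` is assumed to be a module-level constant:

import argparse

def parse_args():
    # Hypothetical parser matching the attributes used in main() above
    parser = argparse.ArgumentParser(
        description='Freeze a FaceNet checkpoint and compile it for the NCS.')
    parser.add_argument('--model_base_path', required=True,
                        help='Directory holding the .meta/.ckpt files')
    parser.add_argument('--output_size', type=int, default=512,
                        help='Bottleneck (embedding) size of the network')
    parser.add_argument('--output_file', default='facenet.graph',
                        help='Path for the compiled Movidius graph')
    parser.add_argument('--check', action='store_true',
                        help='Run mvNCCheck before compiling')
    return parser.parse_args()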
Example #5
def main(_):
    if FLAGS.csv_file_path:
        if not os.path.exists(FLAGS.csv_file_path):
            csv_dir = FLAGS.csv_file_path.rsplit('/', 1)[0]
            if not os.path.exists(csv_dir):
                os.makedirs(csv_dir)
            with open(FLAGS.csv_file_path, 'w') as f:
                writer = csv.writer(f)
                writer.writerow([
                    'Pruned rate', 'Acc Mean', 'Acc Std', 'Epoch No.',
                    'Model size through inference (MB) (Shared part + task-specific part)',
                    'Shared part (MB)', 'Task specific part (MB)',
                    'Whole masks (MB)', 'Task specific masks (MB)',
                    'Task specific batch norm vars (MB)',
                    'Task specific biases (MB)'
                ])

    args, unparsed = parse_arguments(sys.argv[1:])
    FLAGS.filters_expand_ratio = math.sqrt(FLAGS.model_size_expand_ratio)
    FLAGS.history_filters_expand_ratios = [
        math.sqrt(ratio) for ratio in FLAGS.history_model_size_expand_ratio
    ]

    with tf.Graph().as_default():
        with tf.Session() as sess:

            # Read the file containing the pairs used for testing
            pairs = lfw.read_pairs(os.path.expanduser(args.lfw_pairs))

            # Get the paths for the corresponding images
            paths, actual_issame = lfw.get_paths(
                os.path.expanduser(args.lfw_dir), pairs)

            # img = Image.open(paths[0])

            image_paths_placeholder = tf.placeholder(tf.string,
                                                     shape=(None, 1),
                                                     name='image_paths')
            labels_placeholder = tf.placeholder(tf.int32,
                                                shape=(None, 1),
                                                name='labels')
            batch_size_placeholder = tf.placeholder(tf.int32,
                                                    name='batch_size')
            control_placeholder = tf.placeholder(tf.int32,
                                                 shape=(None, 1),
                                                 name='control')
            phase_train_placeholder = tf.placeholder(tf.bool,
                                                     name='phase_train')

            nrof_preprocess_threads = 4
            image_size = (args.image_size, args.image_size)
            eval_input_queue = data_flow_ops.FIFOQueue(
                capacity=2000000,
                dtypes=[tf.string, tf.int32, tf.int32],
                shapes=[(1, ), (1, ), (1, )],
                shared_name=None,
                name=None)
            eval_enqueue_op = eval_input_queue.enqueue_many(
                [
                    image_paths_placeholder, labels_placeholder,
                    control_placeholder
                ],
                name='eval_enqueue_op')
            image_batch, label_batch = facenet.create_input_pipeline(
                eval_input_queue, image_size, nrof_preprocess_threads,
                batch_size_placeholder)
            coord = tf.train.Coordinator()
            tf.train.start_queue_runners(coord=coord, sess=sess)

            # Load the model
            if os.path.isdir(args.model):
                temp_record_file = os.path.join(args.model, 'temp_record.txt')
                checkpoint_file = os.path.join(args.model, 'checkpoint')

                if os.path.exists(temp_record_file) and os.path.exists(
                        checkpoint_file):
                    with open(temp_record_file) as json_file:
                        data = json.load(json_file)
                        max_acc = max(data, key=float)
                        epoch_no = data[max_acc]
                        ckpt_file = args.model + '/model-.ckpt-' + epoch_no

                    with open(checkpoint_file) as f:
                        context = f.read()
                    original_epoch = re.search(r"(\d)+", context).group()
                    context = context.replace(original_epoch, epoch_no)
                    with open(checkpoint_file, 'w') as f:
                        f.write(context)
                    if not os.path.exists(os.path.join(args.model, 'copied')):
                        os.makedirs(os.path.join(args.model, 'copied'))
                    copyfile(
                        temp_record_file,
                        os.path.join(args.model, 'copied', 'temp_record.txt'))
                    os.remove(temp_record_file)

                elif os.path.exists(checkpoint_file):
                    ckpt = tf.train.get_checkpoint_state(args.model)
                    ckpt_file = ckpt.model_checkpoint_path
                    epoch_no = ckpt_file.rsplit('-', 1)[-1]
                else:
                    print(
                        'No `temp_record.txt` or `checkpoint` in `{}`; pass the '
                        'checkpoint file path to args.model, not the directory'
                        .format(args.model))
                    sys.exit(1)
            else:
                ckpt_file = args.model
                epoch_no = ckpt_file.rsplit('-')[-1]

            # Cannot use meta graph, because we need to dynamically decide batch normalization in regard to current task_id
            # facenet.load_model(args.model, input_map=input_map)
            # embeddings = tf.get_default_graph().get_tensor_by_name("embeddings:0")
            prelogits, _ = network.inference(
                image_batch,
                1.0,
                phase_train=phase_train_placeholder,
                bottleneck_layer_size=args.embedding_size,
                weight_decay=0.0)

            embeddings = tf.nn.l2_normalize(prelogits,
                                            1,
                                            1e-10,
                                            name='embeddings')

            init_fn = slim.assign_from_checkpoint_fn(ckpt_file,
                                                     tf.global_variables())
            init_fn(sess)

            pruned_ratio = 0.0
            model_size = 0.0
            # Default the memory stats so the `evaluate` call below does not
            # hit a NameError when FLAGS.print_mem is off
            MB_of_model_through_inference = 0.0
            MB_of_shared_variables = 0.0
            MB_of_task_specific_variables = 0.0
            MB_of_whole_masks = 0.0
            MB_of_task_specific_masks = 0.0
            MB_of_task_specific_batch_norm_variables = 0.0
            MB_of_task_specific_biases = 0.0

            if FLAGS.print_mem or FLAGS.print_mask_info:
                masks = tf.get_collection('masks')

                if FLAGS.print_mask_info:

                    if masks:
                        num_elems_in_each_task_op = {}
                        num_elems_in_tasks_in_masks_op = {
                        }  # two-dimensional dictionary
                        num_elems_in_masks_op = []
                        num_remain_elems_in_masks_op = []

                        for task_id in range(1, FLAGS.task_id + 1):
                            num_elems_in_each_task_op[task_id] = tf.constant(
                                0, dtype=tf.int32)
                            num_elems_in_tasks_in_masks_op[task_id] = {}

                        # Define graph
                        for i, mask in enumerate(masks):
                            num_elems_in_masks_op.append(tf.size(mask))
                            num_remain_elems_in_curr_mask = tf.size(mask)
                            for task_id in range(1, FLAGS.task_id + 1):
                                cnt = tf_count(mask, task_id)
                                num_elems_in_tasks_in_masks_op[task_id][
                                    i] = cnt
                                num_elems_in_each_task_op[task_id] = tf.add(
                                    num_elems_in_each_task_op[task_id], cnt)
                                num_remain_elems_in_curr_mask -= cnt

                            num_remain_elems_in_masks_op.append(
                                num_remain_elems_in_curr_mask)

                        num_elems_in_network_op = tf.add_n(
                            num_elems_in_masks_op)

                        print('Calculate pruning status ...')

                        # Doing operation
                        num_elems_in_masks = sess.run(num_elems_in_masks_op)
                        num_elems_in_each_task = sess.run(
                            num_elems_in_each_task_op)
                        num_elems_in_tasks_in_masks = sess.run(
                            num_elems_in_tasks_in_masks_op)
                        num_elems_in_network = sess.run(
                            num_elems_in_network_op)
                        num_remain_elems_in_masks = sess.run(
                            num_remain_elems_in_masks_op)

                        # Print out the result
                        print('Showing pruning status ...')

                        if FLAGS.verbose:
                            for i, mask in enumerate(masks):
                                print('Layer %s: ' % mask.op.name, end='')
                                for task_id in range(1, FLAGS.task_id + 1):
                                    cnt = num_elems_in_tasks_in_masks[task_id][
                                        i]
                                    print('task_%d -> %d/%d (%.2f%%), ' %
                                          (task_id, cnt, num_elems_in_masks[i],
                                           100 * cnt / num_elems_in_masks[i]),
                                          end='')
                                print('remain -> {:.2f}%'.format(
                                    100 * num_remain_elems_in_masks[i] /
                                    num_elems_in_masks[i]))

                        print('Num elems in network: {}'.format(
                            num_elems_in_network))
                        num_elems_of_unused_weights = num_elems_in_network
                        for task_id in range(1, FLAGS.task_id + 1):
                            print('Num elems in task_{}: {}'.format(
                                task_id, num_elems_in_each_task[task_id]))
                            print('Ratio of task_{} to all: {}'.format(
                                task_id, num_elems_in_each_task[task_id] /
                                num_elems_in_network))
                            num_elems_of_unused_weights -= num_elems_in_each_task[
                                task_id]
                        print('Num unused elems in all masks: {}'.format(
                            num_elems_of_unused_weights))

                        pruned_ratio_relative_to_all_elems = num_elems_of_unused_weights / num_elems_in_network
                        print('Ratio of unused elems to all: {}'.format(
                            pruned_ratio_relative_to_all_elems))
                        pruned_ratio_relative_to_curr_task = num_elems_of_unused_weights / (
                            num_elems_of_unused_weights +
                            num_elems_in_each_task[FLAGS.task_id])
                        print('Pruning degree relative to task_{}: {:.3f}'.
                              format(FLAGS.task_id,
                                     pruned_ratio_relative_to_curr_task))

                if FLAGS.print_mem:
                    # Analyze param
                    start_time = time.time()
                    (MB_of_model_through_inference, MB_of_shared_variables,
                     MB_of_task_specific_variables, MB_of_whole_masks,
                     MB_of_task_specific_masks,
                     MB_of_task_specific_batch_norm_variables,
                     MB_of_task_specific_biases
                     ) = model_analyzer.analyze_vars_for_current_task(
                         tf.model_variables(),
                         sess=sess,
                         task_id=FLAGS.task_id,
                         verbose=False)
                    duration = time.time() - start_time
                    print('duration time: {}'.format(duration))

            if FLAGS.eval_once:
                evaluate(
                    sess, eval_enqueue_op, image_paths_placeholder,
                    labels_placeholder, phase_train_placeholder,
                    batch_size_placeholder, control_placeholder, embeddings,
                    label_batch, paths, actual_issame, args.lfw_batch_size,
                    args.lfw_nrof_folds, args.distance_metric,
                    args.subtract_mean, args.use_flipped_images,
                    args.use_fixed_image_standardization, FLAGS.csv_file_path,
                    pruned_ratio, epoch_no, MB_of_model_through_inference,
                    MB_of_shared_variables, MB_of_task_specific_variables,
                    MB_of_whole_masks, MB_of_task_specific_masks,
                    MB_of_task_specific_batch_norm_variables,
                    MB_of_task_specific_biases)
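Example #5 also depends on a small `tf_count` helper that is not shown. A plausible sketch, assuming it counts how many entries of a mask tensor equal a given task id:

def tf_count(tensor, value):
    # Hypothetical: number of entries in `tensor` equal to `value`
    matches = tf.equal(tensor, tf.cast(value, tensor.dtype))
    return tf.reduce_sum(tf.cast(matches, tf.int32))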
Example #6
def main(args):
    subdir = datetime.strftime(datetime.now(), '%Y%m%d-%H%M%S')
    log_dir = os.path.join(args.logs_base_dir, subdir, 'logs')
    model_dir = os.path.join(args.logs_base_dir, subdir, 'models')
    if not os.path.isdir(log_dir):
        os.makedirs(log_dir)
    if not os.path.isdir(model_dir):
        os.makedirs(model_dir)
    print('log   dir: %s' % log_dir)
    print('model dir: %s' % model_dir)

    # build the graph
    # ---- load pretrained model ---- #
    pretrained = {}
    # Face model
    pretrained['Face'] = np.load(args.face_model)[()]
    # Nose model
    pretrained['Nose'] = np.load(args.nose_model)[()]
    # Lefteye model
    pretrained['Lefteye'] = np.load(args.lefteye_model)[()]
    # Rightmouth model
    pretrained['Rightmouth'] = np.load(args.rightmouth_model)[()]
    # ---- data preparation ---- #
    image_list, label_list, num_classes = train_utils.get_datasets(
        args.data_dir, args.imglist)
    range_size = len(image_list)

    if args.lfw_dir:
        print('LFW directory: %s' % args.lfw_dir)
        pairs = test_utils.read_pairs(args.lfw_pairs)
        lfw_paths, actual_issame = test_utils.get_paths(
            args.lfw_dir, pairs, args.lfw_file_ext)

    with tf.Graph().as_default():
        # random indices producer
        indices_que = tf.train.range_input_producer(range_size)
        dequeue_op = indices_que.dequeue_many(
            args.batch_size * args.epoch_size, 'index_dequeue')

        tf.set_random_seed(args.seed)
        random.seed(args.seed)
        np.random.seed(args.seed)
        global_step = tf.Variable(0, trainable=False)
        # lr_base_pl    = tf.placeholder(tf.float32, name='base_learning_rate')
        lr_fusion_pl = tf.placeholder(tf.float32, name='fusion_learning_rate')
        batch_size_pl = tf.placeholder(tf.int32, name='batch_size')
        phase_train_pl = tf.placeholder(tf.bool, name='phase_train')
        face_pl = tf.placeholder(tf.string, name='image_paths1')  # face images
        nose_pl = tf.placeholder(tf.string, name='image_paths2')  # nose images
        lefteye_pl = tf.placeholder(tf.string,
                                    name='image_paths3')  # left eye images
        rightmouth_pl = tf.placeholder(
            tf.string, name='image_paths4')  # right mouth images
        labels_pl = tf.placeholder(tf.int64, name='labels')

        # define a filename queue
        input_queue = tf.FIFOQueue(
            # [notice: capacity > batch_size*epoch_size]
            capacity=100000,
            dtypes=[tf.string, tf.string, tf.string, tf.string, tf.int64],
            shapes=[(1, ), (1, ), (1, ), (1, ), (1, )],
            shared_name=None,
            name='input_que')
        enque_op = input_queue.enqueue_many(
            [face_pl, nose_pl, lefteye_pl, rightmouth_pl, labels_pl],
            name='enque_op')
        # define 4 readers
        num_threads = 4
        threads_input_list = []
        for _ in range(num_threads):
            imgpath1, imgpath2, imgpath3, imgpath4, label = input_queue.dequeue()
            images1 = []
            images2 = []
            images3 = []
            images4 = []
            # face
            for img_path in tf.unstack(imgpath1):
                img_contents = tf.read_file(img_path)
                img = tf.image.decode_jpeg(img_contents)
                # [notice: random crop only used in face image]
                if args.random_crop:
                    img = tf.random_crop(img, [160, 160, 3])
                else:
                    img = tf.image.resize_image_with_crop_or_pad(img, 160, 160)
                # [notice: flip only used in face image or nose patch]
                if args.random_flip:
                    img = tf.image.random_flip_left_right(img)
                img.set_shape((160, 160, 3))
                images1.append(tf.image.per_image_standardization(img))
            # Nose
            for img_path in tf.unstack(imgpath2):
                img_contents = tf.read_file(img_path)
                img = tf.image.decode_jpeg(img_contents)
                # [notice: flip only used in face image or nose patch]
                if args.random_flip:
                    img = tf.image.random_flip_left_right(img)
                img.set_shape((160, 160, 3))
                images2.append(tf.image.per_image_standardization(img))
            # Lefteye
            for img_path in tf.unstack(imgpath3):
                img_contents = tf.read_file(img_path)
                img = tf.image.decode_jpeg(img_contents)
                img.set_shape((160, 160, 3))
                images3.append(tf.image.per_image_standardization(img))
            # Rightmouth
            for img_path in tf.unstack(imgpath4):
                img_contents = tf.read_file(img_path)
                img = tf.image.decode_jpeg(img_contents)
                img.set_shape((160, 160, 3))
                images4.append(tf.image.per_image_standardization(img))
            threads_input_list.append(
                [images1, images2, images3, images4, label])

        # define 4 buffer queue
        face_batch, nose_batch, lefteye_batch, rightmouth_batch, label_batch = tf.train.batch_join(
            threads_input_list,
            # [notice: here is 'batch_size_pl', not 'batch_size'!!]
            batch_size=batch_size_pl,
            shapes=[
                # [notice: shape of each element should be assigned, otherwise it raises
                # "tensorflow queue shapes must have the same length as dtype" exception]
                (args.image_size, args.image_size, 3),
                (args.image_size, args.image_size, 3),
                (args.image_size, args.image_size, 3),
                (args.image_size, args.image_size, 3),
                ()
            ],
            enqueue_many=True,
            # [notice: how long the prefetching is allowed to fill the queue]
            capacity=4 * num_threads * args.batch_size,
            allow_smaller_final_batch=True)
        print('Total classes: %d' % num_classes)
        print('Total images:  %d' % range_size)
        tf.summary.image('face_images', face_batch, 10)
        tf.summary.image('nose_images', nose_batch, 10)
        tf.summary.image('lefteye_images', lefteye_batch, 10)
        tf.summary.image('rightmouth_images', rightmouth_batch, 10)

        # ---- build graph ---- #
        with tf.variable_scope('BaseModel'):
            with tf.device('/gpu:%d' % args.gpu_id1):
                # embeddings for face model
                features1, _ = inception_resnet_v1.inference(
                    face_batch,
                    args.keep_prob,
                    phase_train=phase_train_pl,
                    weight_decay=args.weight_decay,
                    scope='Face')
            with tf.device('/gpu:%d' % args.gpu_id2):
                # embeddings for nose model
                features2, _ = inception_resnet_v1.inference(
                    nose_batch,
                    args.keep_prob,
                    phase_train=phase_train_pl,
                    weight_decay=args.weight_decay,
                    scope='Nose')
            with tf.device('/gpu:%d' % args.gpu_id3):
                # embeddings for left eye model
                features3, _ = inception_resnet_v1.inference(
                    lefteye_batch,
                    args.keep_prob,
                    phase_train=phase_train_pl,
                    weight_decay=args.weight_decay,
                    scope='Lefteye')
            with tf.device('/gpu:%d' % args.gpu_id4):
                # embeddings for right mouth model
                features4, _ = inception_resnet_v1.inference(
                    rightmouth_batch,
                    args.keep_prob,
                    phase_train=phase_train_pl,
                    weight_decay=args.weight_decay,
                    scope='Rightmouth')
        with tf.device('/gpu:%d' % args.gpu_id5):
            with tf.variable_scope("Fusion"):
                # ---- concatenate ---- #
                concated_features = tf.concat(
                    [features1, features2, features3, features4], 1)
                # prelogits
                prelogits = slim.fully_connected(
                    concated_features,
                    args.fusion_dim,
                    activation_fn=None,
                    weights_initializer=tf.truncated_normal_initializer(
                        stddev=0.1),
                    weights_regularizer=slim.l2_regularizer(args.weight_decay),
                    scope='prelogits',
                    reuse=False)
                # logits
                logits = slim.fully_connected(
                    prelogits,
                    num_classes,
                    activation_fn=None,
                    weights_initializer=tf.truncated_normal_initializer(
                        stddev=0.1),
                    weights_regularizer=slim.l2_regularizer(args.weight_decay),
                    scope='logits',
                    reuse=False)
                # normalized features
                # [notice: used in test stage]
                embeddings = tf.nn.l2_normalize(prelogits,
                                                1,
                                                1e-10,
                                                name='embeddings')

            # ---- define loss & train op ---- #
            cross_entropy = -tf.reduce_sum(tf.one_hot(
                indices=tf.cast(label_batch, tf.int32),
                depth=num_classes) * tf.log(tf.nn.softmax(logits) + 1e-10),
                                           reduction_indices=[1])
            cross_entropy_mean = tf.reduce_mean(cross_entropy)
            tf.add_to_collection('losses', cross_entropy_mean)
            # weight decay
            reg_loss = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
            # total loss: cross_entropy + weight_decay
            total_loss = tf.add_n([cross_entropy_mean] + reg_loss,
                                  name='total_loss')
            '''
            lr_base = tf.train.exponential_decay(lr_base_pl,
                global_step,
                args.lr_decay_epochs * args.epoch_size,
                args.lr_decay_factor,
                staircase = True)
            '''
            lr_fusion = tf.train.exponential_decay(lr_fusion_pl,
                                                   global_step,
                                                   args.lr_decay_epochs *
                                                   args.epoch_size,
                                                   args.lr_decay_factor,
                                                   staircase=True)
            # tf.summary.scalar('base_learning_rate', lr_base)
            tf.summary.scalar('fusion_learning_rate', lr_fusion)
            var_list1 = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                          scope='BaseModel')
            var_list2 = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                          scope='Fusion')
            '''
            train_op = train_utils.get_fusion_train_op(
                total_loss, global_step, args.optimizer,
                lr_base, var_list1, lr_fusion, var_list2,
                args.moving_average_decay)
            '''
            train_op = train_utils.get_train_op(total_loss, global_step,
                                                args.optimizer, lr_fusion,
                                                args.moving_average_decay,
                                                var_list2)

        # ---- training ---- #
        gpu_options = tf.GPUOptions(allow_growth=True)
        sess = tf.Session(config=tf.ConfigProto(
            gpu_options=gpu_options,
            log_device_placement=False,
            # [notice: 'allow_soft_placement' will switch to cpu automatically
            #  when some operations are not supported by GPU]
            allow_soft_placement=True))
        saver = tf.train.Saver(var_list1 + var_list2)
        summary_op = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())

        coord = tf.train.Coordinator()
        tf.train.start_queue_runners(coord=coord, sess=sess)

        with sess.as_default():
            # ---- restore pre-trained parameters ---- #
            to_assign = []
            print("restore pretrained parameters...")
            print("total:", len(var_list1))
            for v in var_list1:
                v_name = v.name  # 'BaseModel/Face/xxx'
                v_name = v_name[v_name.find('/') + 1:]  # 'Face/xxx'
                v_name_1 = v_name[:v_name.find('/')]  # 'Face'
                v_name_2 = v_name[v_name.find('/'):]  # '/xxx'
                print("precess: %s" % v_name, end=" ")
                if v_name_1 in pretrained:
                    to_assign.append(
                        v.assign(pretrained[v_name_1][v_name_2][0]))
                    print("[ok]")
                else:
                    # no pretrained weights for this variable; keep its
                    # initialized value
                    print("[not found]")
            sess.run(to_assign)

            print("start training ...")
            epoch = 0
            while epoch < args.max_num_epochs:
                step = sess.run(global_step, feed_dict=None)
                epoch = step // args.epoch_size

                # run one epoch
                run_epoch(args, sess, epoch, image_list, label_list,
                          dequeue_op, enque_op, face_pl, nose_pl, lefteye_pl,
                          rightmouth_pl, labels_pl, lr_fusion_pl,
                          phase_train_pl, batch_size_pl, global_step,
                          total_loss, reg_loss, train_op, summary_op,
                          summary_writer)

                # snapshot for currently learnt weights
                snapshot(sess, saver, model_dir, subdir, step)

                # evaluate on LFW
                if args.lfw_dir:
                    evaluate(sess, enque_op, face_pl, nose_pl, lefteye_pl,
                             rightmouth_pl, labels_pl, phase_train_pl,
                             batch_size_pl, embeddings, label_batch, lfw_paths,
                             actual_issame, args.lfw_batch_size,
                             args.lfw_num_folds, log_dir, step, summary_writer)
    sess.close()
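Example #6 calls a `snapshot` helper between epochs that is not shown. A plausible sketch, assuming it just writes a step-stamped checkpoint:

def snapshot(sess, saver, model_dir, subdir, step):
    # Hypothetical: save the current weights as a step-stamped checkpoint
    checkpoint_path = os.path.join(model_dir, 'model-%s.ckpt' % subdir)
    saver.save(sess, checkpoint_path, global_step=step, write_meta_graph=False)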
Example #7
def main(args):
    print("main start")
    np.random.seed(seed=args.seed)
    # train_set is a list of ImageClass objects
    train_set = faceUtil.get_dataset(args.data_dir)

    # total number of classes
    nrof_classes = len(train_set)
    print(nrof_classes)

    # subdir, e.g. 20171122-112109
    subdir = datetime.strftime(datetime.now(), '%Y%m%d-%H%M%S')

    # log_dir, e.g. c:\User\logs\facenet\20171122-
    log_dir = os.path.join(os.path.expanduser(args.logs_base_dir), subdir)
    if not os.path.isdir(log_dir):
        os.makedirs(log_dir)
    print("log_dir =", log_dir)

    # model_dir, e.g. c:\User/models/facenet/2017...
    model_dir = os.path.join(os.path.expanduser(args.models_base_dir), subdir)
    if not os.path.isdir(model_dir):  # Create the model directory if it doesn't exist
        os.makedirs(model_dir)

    print("model_dir =", model_dir)
    pretrained_model = None
    if args.pretrained_model:
        # pretrained_model = os.path.expanduser(args.pretrained_model)
        # pretrained_model = tf.train.get_checkpoint_state(args.pretrained_model)
        pretrained_model = args.pretrained_model
        print('Pre-trained model: %s' % pretrained_model)


    # Write arguments to a text file
    faceUtil.write_arguments_to_file(args, os.path.join(log_dir, 'arguments.txt'))
    print("write_arguments_to_file")
    with tf.Graph().as_default():
        tf.set_random_seed(args.seed)
        global_step = tf.Variable(0, trainable=False)

        # two parallel lists of equal length: image_list holds the image
        # paths and label_list the corresponding labels
        image_list, label_list = faceUtil.get_image_paths_and_labels(train_set)
        assert len(image_list) > 0, 'dataset is empty'
        print("len(image_list) = ", len(image_list))

        # Create a queue that produces indices into the image_list and label_list
        labels = ops.convert_to_tensor(label_list, dtype=tf.int64)
        range_size = array_ops.shape(labels)[0]
        range_size = tf.Print(range_size, [tf.shape(range_size)],
                              message='Shape of range_input_producer range_size : ',
                              summarize=4, first_n=1)

        # queue containing the shuffled indices 0..range_size-1
        index_queue = tf.train.range_input_producer(range_size, num_epochs=None,
                                                    shuffle=True, seed=None,
                                                    capacity=32)

        # dequeue args.batch_size*args.epoch_size indices, used to pick the
        # slice of image_list/label_list fed to the network
        index_dequeue_op = index_queue.dequeue_many(
            args.batch_size * args.epoch_size, 'index_dequeue')

        # learning rate
        learning_rate_placeholder = tf.placeholder(tf.float32, name='learning_rate')
        # batch size (args.batch_size)
        batch_size_placeholder = tf.placeholder(tf.int32, name='batch_size')
        # whether we are in the training phase
        phase_train_placeholder = tf.placeholder(tf.bool, name='phase_train')
        # image paths; size args.batch_size * args.epoch_size
        image_paths_placeholder = tf.placeholder(tf.string, shape=[None, 1], name='image_paths')
        # image labels; size args.batch_size * args.epoch_size
        labels_placeholder = tf.placeholder(tf.int64, shape=[None, 1], name='labels')

        # new first-in, first-out queue for the input pipeline
        input_queue = data_flow_ops.FIFOQueue(capacity=100000,
                                              dtypes=[tf.string, tf.int64],
                                              shapes=[(1,), (1,)],
                                              shared_name=None, name=None)

        # enqueue_many returns an op; it enqueues len(image_paths_placeholder)
        # = args.batch_size*args.epoch_size elements drawn from index_queue
        enqueue_op = input_queue.enqueue_many([image_paths_placeholder, labels_placeholder],
                                              name='enqueue_op')

        nrof_preprocess_threads = 4
        images_and_labels = []

        for _ in range(nrof_preprocess_threads):
            filenames , label = input_queue.dequeue()
            # label = tf.Print(label,[tf.shape(label)],message='Shape of one thread  input_queue.dequeue label : ',
            #                  summarize=4,first_n=1)
            # filenames = tf.Print(filenames, [tf.shape(filenames)], message='Shape of one thread  input_queue.dequeue filenames : ',
            #                  summarize=4, first_n=1)
            print("one thread  input_queue.dequeue len = ",tf.shape(label))
            images =[]
            for filenames in tf.unstack(filenames):
                file_contents = tf.read_file(filenames)
                image = tf.image.decode_image(file_contents,channels=3)

                if args.random_rotate:
                    image = tf.py_func(faceUtil.random_rotate_image, [image], tf.uint8)

                if args.random_crop:
                    image = tf.random_crop(image, [args.image_size, args.image_size, 3])
                else:
                    image = tf.image.resize_image_with_crop_or_pad(
                        image, args.image_size, args.image_size)

                if args.random_flip:
                    image = tf.image.random_flip_left_right(image)

                image.set_shape((args.image_size, args.image_size, 3))
                images.append(tf.image.per_image_standardization(image))

            # decode the dequeued filenames into images and append them to
            # images_and_labels (length may be 4 * ...)
            images_and_labels.append([images, label])

        # the final batch fed into the network; its length should equal
        # batch_size_placeholder
        image_batch, label_batch = tf.train.batch_join(images_and_labels, batch_size=batch_size_placeholder,
                                                       shapes=[(args.image_size, args.image_size, 3), ()],
                                                       enqueue_many=True,
                                                       capacity=4 * nrof_preprocess_threads * args.batch_size,
                                                       allow_smaller_final_batch=True)
        print('final input net  image_batch len = ', tf.shape(image_batch))

        image_batch = tf.Print(image_batch, [tf.shape(image_batch)], message='final input net  image_batch shape = ',
                               summarize=4, first_n=1)
        image_batch = tf.identity(image_batch, 'image_batch')
        image_batch = tf.identity(image_batch, 'input')
        label_batch = tf.identity(label_batch, 'label_batch')

        print('Total number of classes: %d' % nrof_classes)
        print('Total number of examples: %d' % len(image_list))

        print('Building training graph')

        # apply exponential decay to the learning rate
        learning_rate = tf.train.exponential_decay(learning_rate= learning_rate_placeholder,
                                                   global_step = global_step,
                                                   decay_steps=args.learning_rate_decay_epochs * args.epoch_size,
                                                   decay_rate=args.learning_rate_decay_factor,
                                                   staircase = True)

        tf.summary.scalar('learning_rate', learning_rate)

        # Build the inference graph
        prelogits, _ = network.inference(image_batch, args.keep_probability,
                                         phase_train=phase_train_placeholder,
                                         bottleneck_layer_size=args.embedding_size,
                                         weight_decay=args.weight_decay)

        prelogits = tf.Print(prelogits, [tf.shape(prelogits)], message='prelogits shape = ',
                               summarize=4, first_n=1)
        print("prelogits.shape = ",prelogits.get_shape().as_list())

        # logits =slim.fully_connected(prelogits, len(train_set), activation_fn=None,
        #                               weights_initializer=tf.contrib.layers.xavier_initializer(),
        #                               weights_regularizer=slim.l2_regularizer(args.weight_decay),
        #                               scope='Logits', reuse=False)
        #
        # # Calculate the average cross entropy loss across the batch
        # cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        #     labels=label_batch, logits=logits, name='cross_entropy_per_example')
        # tf.reduce_mean(cross_entropy, name='cross_entropy')
        _, cross_entropy_mean = soft_loss_nobias(prelogits, label_batch, len(train_set))
        tf.add_to_collection('losses', cross_entropy_mean)

        regularization_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
        total_loss = tf.add_n([cross_entropy_mean] + regularization_losses, name='total_loss')

        train_op = faceUtil.train(total_loss, global_step, args.optimizer, learning_rate,
                                  args.moving_average_decay, tf.global_variables(), args.log_histograms)
        # print("global_variables len = {}".format(len(tf.global_variables())))
        # print("local_variables len = {}".format(len(tf.local_variables())))
        # print("trainable_variables len = {}".format(len(tf.trainable_variables())))
        # for v in tf.trainable_variables() :
        #     print("trainable_variables :{}".format(v.name))
        # train_op = faceUtil.train(sphere_loss,global_step,args.optimizer,learning_rate,
        #                   args.moving_average_decay, tf.global_variables(), args.log_histograms)

        # create the saver
        variables = tf.trainable_variables()
        print("variables_trainable len = ", len(variables))
        for v in variables:
            print('variables_trainable : {}'.format(v.name))
        saver = tf.train.Saver(var_list=variables, max_to_keep=2)

        # variables_to_restore  = [v for v in variables if v.name.split('/')[0] != 'Logits']
        # print("variables_trainable len = ",len(variables))
        # print("variables_to_restore len = ",len(variables_to_restore))
        # # for v in variables_to_restore :
        # #     print("variables_to_restore : ",v.name)
        # saver = tf.train.Saver(var_list=variables_to_restore,max_to_keep=3)


        # variables_trainable = tf.trainable_variables()
        # print("variables_trainable len = ",len(variables_trainable))
        # # for v in variables_trainable :
        # #     print('variables_trainable : {}'.format(v.name))
        # variables_to_restore = slim.get_variables_to_restore(include=['InceptionResnetV1'])
        # print("variables_to_restore len = ",len(variables_to_restore))
        # saver = tf.train.Saver(var_list=variables_to_restore,max_to_keep=3)



        # Build the summary operation based on the TF collection of Summaries.
        summary_op = tf.summary.merge_all()

        # cap the fraction of GPU memory the process may allocate
        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_memory_fraction)
        sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options, log_device_placement=False))

        # Initialize variables
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

        # create the thread coordinator
        coord = tf.train.Coordinator()

        # start all the queue runners
        tf.train.start_queue_runners(coord=coord, sess=sess)

        with sess.as_default():
            print('Running training')
            if pretrained_model:
                print('Restoring pretrained model_checkpoint_path: %s' % pretrained_model)
                saver.restore(sess, pretrained_model)

            # Training and validation loop
            print('Running training really')
            epoch = 0
            # loop over full passes through the data
            while epoch < args.max_nrof_epochs:

                # fetch the current global_step; step can be read as the
                # global batch count
                step = sess.run(global_step, feed_dict=None)

                # epoch_size is the number of batches per epoch, so the
                # current epoch index is step // epoch_size; it also drives
                # the learning-rate schedule
                epoch = step // args.epoch_size
                # Train for one epoch
                train(args, sess, epoch, image_list, label_list, index_dequeue_op, enqueue_op, image_paths_placeholder, labels_placeholder,
                    learning_rate_placeholder, phase_train_placeholder, batch_size_placeholder, global_step,
                    total_loss, train_op, summary_op, summary_writer, regularization_losses, args.learning_rate_schedule_file)

                # Save variables and the metagraph if it doesn't exist already
                save_variables_and_metagraph(sess, saver, summary_writer, model_dir, subdir, step)

    return model_dir
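Example #7 leans on `faceUtil` helpers that are not shown. A plausible sketch of `get_image_paths_and_labels`, assuming each ImageClass exposes an `image_paths` list:

def get_image_paths_and_labels(dataset):
    # Hypothetical: flatten an ImageClass list into parallel path/label lists,
    # assigning one integer label per class
    image_paths, labels = [], []
    for label, image_class in enumerate(dataset):
        image_paths += image_class.image_paths
        labels += [label] * len(image_class.image_paths)
    return image_paths, labels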
Example #8
    def __init__(self,
                 n_gpu,
                 datatype,
                 uid=None,
                 output_path=None,
                 enable_simul=None,
                 start_posx=None,
                 start_negx=None,
                 classifier_only=False):
        self.n_gpu = n_gpu
        self.datatype = datatype
        self.uid = uid
        self.output_path = output_path
        self.enable_simul = enable_simul
        self.start_posx = start_posx
        self.start_negx = start_negx

        # set up the TF session
        cfg = tf.ConfigProto()
        cfg.allow_soft_placement = True
        cfg.gpu_options.allow_growth = True
        self.sess = tf.InteractiveSession(config=cfg)

        if datatype == 'face':

            import facenet
            import models.inception_resnet_v1 as featurenet

            self.images_placeholder = tf.placeholder(tf.float32,
                                                     (None, 160, 160, 3))
            self.phase_train_placeholder = tf.placeholder(tf.bool, ())

            prelogits, end_points = featurenet.inference(
                self.images_placeholder,
                1.0,
                phase_train=self.phase_train_placeholder,
                bottleneck_layer_size=512,
                weight_decay=0.0)
            self.featuremap = end_points['PreLogitsFlatten']
            self.featuremap_size = self.featuremap.get_shape()[1]
            self.img_size_for_fm = 160
            facenetloader = tf.train.Saver()
            facenetloader.restore(tf.get_default_session(),
                                  './data/model-20180402-114759.ckpt-275')

            if not classifier_only:
                if os.path.isfile('./data/ffhq256.npy'):
                    self.allimages = np.load('./data/ffhq256.npy')
                else:
                    from config import ffhq_path
                    imagefiles = glob.glob(os.path.join(ffhq_path, '*'))
                    cnt = len(imagefiles)
                    self.allimages = np.zeros((cnt, 256, 256, 3), np.uint8)
                    for i in range(cnt):
                        self.allimages[i] = cv2.resize(
                            cv2.imread(imagefiles[i]), (256, 256))
                    np.save('./data/ffhq256.npy', self.allimages)

                if os.path.isfile('./data/allfeatures.npy'):
                    self.allfeatures = np.load('./data/allfeatures.npy')
                else:
                    self.allfeatures = np.zeros(
                        (len(self.allimages), self.featuremap_size),
                        np.float32)
                    for i in range(int(np.ceil(len(self.allimages) / 32))):
                        images = self.allimages[i * 32:(i + 1) *
                                                32][:, 29:189,
                                                    9:169, :] / 255.0
                        self.allfeatures[i * 32:(i + 1) *
                                         32] = self.extract_features(images)
                    np.save('./data/allfeatures.npy', self.allfeatures)

                self.build_userdis_classifier()
                self.init_general()
                self.logger.info('initialization finished.')

            else:
                self.build_userdis_classifier_test()

        elif datatype == 'bedroom':

            import dnnlib
            import dnnlib.tflib as tflib
            _G, _D, self.Gs = pickle.load(
                open('./data/karras2019stylegan-bedrooms-256x256.pkl', 'rb'))
            self.fmt = dict(func=tflib.convert_images_to_uint8,
                            nchw_to_nhwc=True)
            self.zdim = self.Gs.input_shape[1]

            latents = np.random.normal(size=(20, self.Gs.input_shape[1]))
            images = self.Gs.run(latents,
                                 None,
                                 truncation_psi=0.7,
                                 randomize_noise=False,
                                 output_transform=self.fmt)
            #print(np.shape(images))

            import vgg16
            self.vgg16input = tf.placeholder("float", [None, 224, 224, 3])
            self.vgg = vgg16.Vgg16('./data/vgg16.npy')
            with tf.name_scope("vgg16"):
                self.vgg.build(self.vgg16input)

            self.featuremap_size = 25088
            self.img_size_for_fm = 224
            if not classifier_only:
                self.build_userdis_classifier()
                self.init_general()
                self.logger.info('initialization finished.')
            else:
                self.build_userdis_classifier_test()

        else:
            print('Invalid datatype!')
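Example #8 uses an `extract_features` method that is not part of the snippet. A plausible sketch, assuming it simply runs the feature map for a batch of crops:

    def extract_features(self, images):
        # Hypothetical: run the PreLogitsFlatten feature map on a batch of
        # 160x160 crops already scaled to [0, 1]
        return self.sess.run(self.featuremap,
                             feed_dict={self.images_placeholder: images,
                                        self.phase_train_placeholder: False})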
Example #9
def main(_):

    if FLAGS.csv_file_path:
        if not os.path.exists(FLAGS.csv_file_path):
            csv_dir = FLAGS.csv_file_path.rsplit('/', 1)[0]
            if not os.path.exists(csv_dir):
                os.makedirs(csv_dir)

            if FLAGS.task_name == 'chalearn/age':
                with open(FLAGS.csv_file_path, 'w') as f:
                    writer = csv.writer(f)
                    writer.writerow([
                        'Pruned rate', 'MAE', 'Acc', 'Epoch No.',
                        'Model size through inference (MB) (Shared part + task-specific part)',
                        'Shared part (MB)', 'Task specific part (MB)',
                        'Whole masks (MB)', 'Task specific masks (MB)',
                        'Task specific batch norm vars (MB)',
                        'Task specific biases (MB)'
                    ])
            else:
                with open(FLAGS.csv_file_path, 'w') as f:
                    writer = csv.writer(f)
                    writer.writerow([
                        'Pruned rate', 'Acc', 'Epoch No.',
                        'Model size through inference (MB) (Shared part + task-specific part)',
                        'Shared part (MB)', 'Task specific part (MB)',
                        'Whole masks (MB)', 'Task specific masks (MB)',
                        'Task specific batch norm vars (MB)',
                        'Task specific biases (MB)'
                    ])

    args, unparsed = parse_arguments(sys.argv[1:])
    FLAGS.filters_expand_ratio = math.sqrt(FLAGS.filters_expand_ratio)
    FLAGS.history_filters_expand_ratios = [
        math.sqrt(float(ratio))
        for ratio in FLAGS.history_filters_expand_ratios
    ]

    with tf.Graph().as_default():

        with tf.Session() as sess:
            if 'emotion' in FLAGS.task_name or 'chalearn' in FLAGS.task_name:
                test_data_path = os.path.join(args.data_dir, 'val')
            else:
                test_data_path = os.path.join(args.data_dir, 'test')

            test_set = utils.get_dataset(test_data_path)

            # Get the paths for the corresponding images
            image_list, label_list = facenet.get_image_paths_and_labels(
                test_set)

            image_paths_placeholder = tf.placeholder(tf.string,
                                                     shape=(None, 1),
                                                     name='image_paths')
            labels_placeholder = tf.placeholder(tf.int32,
                                                shape=(None, 1),
                                                name='labels')
            batch_size_placeholder = tf.placeholder(tf.int32,
                                                    name='batch_size')
            control_placeholder = tf.placeholder(tf.int32,
                                                 shape=(None, 1),
                                                 name='control')
            phase_train_placeholder = tf.placeholder(tf.bool,
                                                     name='phase_train')

            nrof_preprocess_threads = 4
            image_size = (args.image_size, args.image_size)
            eval_input_queue = data_flow_ops.FIFOQueue(
                capacity=2000000,
                dtypes=[tf.string, tf.int32, tf.int32],
                shapes=[(1, ), (1, ), (1, )],
                shared_name=None,
                name=None)
            eval_enqueue_op = eval_input_queue.enqueue_many(
                [
                    image_paths_placeholder, labels_placeholder,
                    control_placeholder
                ],
                name='eval_enqueue_op')
            image_batch, label_batch = facenet.create_input_pipeline(
                eval_input_queue, image_size, nrof_preprocess_threads,
                batch_size_placeholder)
            coord = tf.train.Coordinator()
            tf.train.start_queue_runners(coord=coord, sess=sess)

            # Load the model
            if os.path.isdir(args.model):
                temp_record_file = os.path.join(args.model, 'temp_record.txt')
                checkpoint_file = os.path.join(args.model, 'checkpoint')

                if os.path.exists(temp_record_file) and os.path.exists(
                        checkpoint_file):
                    with open(temp_record_file) as json_file:
                        data = json.load(json_file)
                        max_acc = max(data, key=float)
                        epoch_no = data[max_acc]
                        ckpt_file = args.model + '/model-.ckpt-' + epoch_no

                    with open(checkpoint_file) as f:
                        context = f.read()
                    original_epoch = re.search(r'\d+', context).group()
                    context = context.replace(original_epoch, epoch_no)
                    with open(checkpoint_file, 'w') as f:
                        f.write(context)
                    if not os.path.exists(os.path.join(args.model, 'copied')):
                        os.makedirs(os.path.join(args.model, 'copied'))
                    copyfile(
                        temp_record_file,
                        os.path.join(args.model, 'copied', 'temp_record.txt'))
                    os.remove(temp_record_file)

                elif os.path.exists(checkpoint_file):
                    ckpt = tf.train.get_checkpoint_state(args.model)
                    ckpt_file = ckpt.model_checkpoint_path
                    epoch_no = ckpt_file.rsplit('-', 1)[-1]
                else:
                    print(
                        'Neither `temp_record.txt` nor `checkpoint` exists in '
                        '`{}`; pass args.model a checkpoint file path, not a '
                        'directory'.format(args.model))
                    sys.exit(1)
            else:
                ckpt_file = args.model
                epoch_no = ckpt_file.rsplit('-', 1)[-1]

            prelogits, _ = network.inference(
                image_batch,
                1.0,
                phase_train=phase_train_placeholder,
                bottleneck_layer_size=args.embedding_size,
                weight_decay=0.0)

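            # Each task owns its classifier head under a `task_<id>` scope;
            # chalearn/age uses a fixed 100-way head (ages 1-100), while
            # classification tasks get one logit per class directory.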
            with tf.variable_scope('task_{}'.format(FLAGS.task_id)):
                if FLAGS.task_name == 'chalearn/age':
                    logits = slim.fully_connected(prelogits,
                                                  100,
                                                  activation_fn=None,
                                                  scope='Logits',
                                                  reuse=False)
                else:
                    logits = slim.fully_connected(prelogits,
                                                  len(test_set),
                                                  activation_fn=None,
                                                  scope='Logits',
                                                  reuse=False)

            # Get output tensor
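            # For age estimation the prediction is the softmax-weighted
            # expectation over ages 1..100, and the evaluation criterion is
            # the mean absolute error against the ground-truth age.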
            if FLAGS.task_name == 'chalearn/age':
                softmax = tf.nn.softmax(logits=logits)
                labels_range = tf.range(1.0, 101.0)  # [1.0, ..., 100.0]
                labels_matrix = tf.broadcast_to(
                    labels_range,
                    [args.test_batch_size, labels_range.shape[0]])
                result_vector = tf.reduce_sum(softmax * labels_matrix, axis=1)
                MAE_error_vector = tf.abs(result_vector -
                                          tf.cast(label_batch, tf.float32))
                MAE_avg_error = tf.reduce_mean(MAE_error_vector)

                correct_prediction = tf.cast(
                    tf.equal(tf.argmax(logits, 1),
                             tf.cast(label_batch, tf.int64)), tf.float32)
                accuracy = tf.reduce_mean(correct_prediction)
                regularization_losses = tf.get_collection(
                    tf.GraphKeys.REGULARIZATION_LOSSES)
                total_loss = tf.add_n([MAE_avg_error] + regularization_losses)

                criterion = MAE_avg_error
            else:
                cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=label_batch,
                    logits=logits,
                    name='cross_entropy_per_example')
                cross_entropy_mean = tf.reduce_mean(cross_entropy)

                correct_prediction = tf.cast(
                    tf.equal(tf.argmax(logits, 1),
                             tf.cast(label_batch, tf.int64)), tf.float32)
                accuracy = tf.reduce_mean(correct_prediction)
                regularization_losses = tf.get_collection(
                    tf.GraphKeys.REGULARIZATION_LOSSES)
                total_loss = tf.add_n([cross_entropy_mean] +
                                      regularization_losses)

                criterion = cross_entropy_mean

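            # Restore every variable (shared backbone, per-task heads and
            # masks) from the selected checkpoint.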
            init_fn = slim.assign_from_checkpoint_fn(ckpt_file,
                                                     tf.global_variables())
            init_fn(sess)

            pruned_ratio_relative_to_curr_task = 0.0
            model_size = 0.0
            # Default the memory figures so the `validate` call below does
            # not raise a NameError when --print_mem is disabled.
            (MB_of_model_through_inference, MB_of_shared_variables,
             MB_of_task_specific_variables, MB_of_whole_masks,
             MB_of_task_specific_masks,
             MB_of_task_specific_batch_norm_variables,
             MB_of_task_specific_biases) = (0.0,) * 7
            if FLAGS.print_mem or FLAGS.print_mask_info:
                masks = tf.get_collection('masks')

                if FLAGS.print_mask_info:

                    if masks:
                        num_elems_in_each_task_op = {}
                        # two-dimensional dict: task_id -> mask index -> count
                        num_elems_in_tasks_in_masks_op = {}
                        num_elems_in_masks_op = []
                        num_remain_elems_in_masks_op = []

                        for task_id in range(1, FLAGS.task_id + 1):
                            num_elems_in_each_task_op[task_id] = tf.constant(
                                0, dtype=tf.int32)
                            num_elems_in_tasks_in_masks_op[task_id] = {}

                        # Define graph
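                        # Each mask entry appears to hold the id of the task
                        # that owns that weight (0 = unassigned), so counting
                        # occurrences of `task_id` in each mask measures how
                        # much of the network each task has claimed; the
                        # remainder is still free/prunable.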
                        for i, mask in enumerate(masks):
                            num_elems_in_masks_op.append(tf.size(mask))
                            num_remain_elems_in_curr_mask = tf.size(mask)
                            for task_id in range(1, FLAGS.task_id + 1):
                                cnt = tf_count(mask, task_id)
                                num_elems_in_tasks_in_masks_op[task_id][
                                    i] = cnt
                                num_elems_in_each_task_op[task_id] = tf.add(
                                    num_elems_in_each_task_op[task_id], cnt)
                                num_remain_elems_in_curr_mask -= cnt

                            num_remain_elems_in_masks_op.append(
                                num_remain_elems_in_curr_mask)

                        num_elems_in_network_op = tf.add_n(
                            num_elems_in_masks_op)

                        print('Calculating pruning status ...')

                        # Run the counting ops
                        num_elems_in_masks = sess.run(num_elems_in_masks_op)
                        num_elems_in_each_task = sess.run(
                            num_elems_in_each_task_op)
                        num_elems_in_tasks_in_masks = sess.run(
                            num_elems_in_tasks_in_masks_op)
                        num_elems_in_network = sess.run(
                            num_elems_in_network_op)
                        num_remain_elems_in_masks = sess.run(
                            num_remain_elems_in_masks_op)

                        # Print out the result
                        print('Showing pruning status ...')

                        if FLAGS.verbose:
                            for i, mask in enumerate(masks):
                                print('Layer %s: ' % mask.op.name, end='')
                                for task_id in range(1, FLAGS.task_id + 1):
                                    cnt = num_elems_in_tasks_in_masks[task_id][
                                        i]
                                    print('task_%d -> %d/%d (%.2f%%), ' %
                                          (task_id, cnt, num_elems_in_masks[i],
                                           100 * cnt / num_elems_in_masks[i]),
                                          end='')
                                print('remain -> {:.2f}%'.format(
                                    100 * num_remain_elems_in_masks[i] /
                                    num_elems_in_masks[i]))

                        print('Num elems in network: {}'.format(
                            num_elems_in_network))
                        num_elems_of_unused_weights = num_elems_in_network
                        for task_id in range(1, FLAGS.task_id + 1):
                            print('Num elems in task_{}: {}'.format(
                                task_id, num_elems_in_each_task[task_id]))
                            print('Ratio of task_{} to all: {}'.format(
                                task_id, num_elems_in_each_task[task_id] /
                                num_elems_in_network))
                            num_elems_of_unused_weights -= num_elems_in_each_task[
                                task_id]
                        print('Num unused elems in all masks: {}'.format(
                            num_elems_of_unused_weights))

                        pruned_ratio_relative_to_all_elems = (
                            num_elems_of_unused_weights / num_elems_in_network)
                        print('Ratio of unused elems to all: {}'.format(
                            pruned_ratio_relative_to_all_elems))
                        pruned_ratio_relative_to_curr_task = num_elems_of_unused_weights / (
                            num_elems_of_unused_weights +
                            num_elems_in_each_task[FLAGS.task_id])
                        print('Pruning degree relative to task_{}: {:.3f}'.
                              format(FLAGS.task_id,
                                     pruned_ratio_relative_to_curr_task))

                if FLAGS.print_mem:
                    # Analyze param
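                    # analyze_vars_for_current_task reports, in MB, the
                    # footprint of the shared backbone versus the current
                    # task's weights, masks, batch-norm variables and biases.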
                    start_time = time.time()
                    (MB_of_model_through_inference, MB_of_shared_variables,
                     MB_of_task_specific_variables, MB_of_whole_masks,
                     MB_of_task_specific_masks,
                     MB_of_task_specific_batch_norm_variables,
                     MB_of_task_specific_biases
                     ) = model_analyzer.analyze_vars_for_current_task(
                         tf.model_variables(),
                         sess=sess,
                         task_id=FLAGS.task_id,
                         verbose=False)
                    duration = time.time() - start_time
                    print('analysis time: {:.2f}sec'.format(duration))
            if FLAGS.eval_once:
                validate(
                    args, sess, image_list, label_list, eval_enqueue_op,
                    image_paths_placeholder, labels_placeholder,
                    control_placeholder, phase_train_placeholder,
                    batch_size_placeholder, total_loss, regularization_losses,
                    criterion, accuracy, args.use_fixed_image_standardization,
                    FLAGS.csv_file_path, pruned_ratio_relative_to_curr_task,
                    epoch_no, MB_of_model_through_inference,
                    MB_of_shared_variables, MB_of_task_specific_variables,
                    MB_of_whole_masks, MB_of_task_specific_masks,
                    MB_of_task_specific_batch_norm_variables,
                    MB_of_task_specific_biases)

            return
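
# Note: `tf_count`, used in the mask analysis above, is defined elsewhere in
# this example's source file. A minimal sketch consistent with its use here
# (counting the entries of `mask` that equal `task_id`) might look like:
def tf_count(t, val):
    # Compare elementwise, cast the boolean matches to int32, and sum.
    return tf.reduce_sum(tf.cast(tf.equal(t, val), tf.int32))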