Example #1
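These examples do not repeat their module-level imports. A minimal set consistent with the calls below (the local module paths for DataReader, model and utils are assumptions about the repository layout) would be roughly:

import os
import time

import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim
import matplotlib.image as mpimg

import model                        # project-local: TeacherNetwork / StudentNetwork
import utils                        # project-local: LossFunctions
from data_reader import DataReader  # project-local; the module name is a guess
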
def main(args):
    print(args)

    LOG_STEP = 250
    SAVE_STEP = 500
    LOG_ALL_TRAIN_PARAMS = False
    LEARNING_RATE = 1e-4 / args.finetune_level if args.finetune_level > 1 else 1e-3

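    # Input pipeline: DataReader builds queue-based batch tensors for the
    # training and validation splits; augmentation strength follows the
    # finetune level.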
    with tf.variable_scope('Data_Generator'):
        data_reader = DataReader(data_path=args.data_path)
        train_x, train_y = data_reader.get_instance(
            batch_size=args.batch_size,
            mode='train',
            augmentation_level=args.finetune_level)
        valid_x, valid_y = data_reader.get_instance(
            batch_size=args.batch_size * 2, mode='valid')
        class_num = len(data_reader.dict_class.keys())

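    # Build the teacher graph twice: once for training batches and once
    # (reuse=True) for validation batches with dropout disabled.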
    network = model.TeacherNetwork()
    logits, net_dict = network.build_network(train_x,
                                             class_num=class_num,
                                             reuse=False,
                                             is_train=True)
    v_logits, v_net_dict = network.build_network(valid_x,
                                                 class_num=class_num,
                                                 reuse=True,
                                                 is_train=True,
                                                 dropout=1)
    prelogits = net_dict['PreLogitsFlatten']
    v_prelogits = v_net_dict['PreLogitsFlatten']

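    # Extra regularizers depend on the finetune level: level 2 adds a 128-d
    # embedding head plus center / prelogit-norm / triplet losses with hard
    # instance mining; level 1 keeps only the (weaker) center and
    # prelogit-norm terms.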
    use_center = use_pln = use_triplet = use_him = False
    pln_factor = center_factor = 0
    if args.finetune_level == 2:
        use_center = use_pln = use_triplet = use_him = True
        with tf.variable_scope('Output'):
            embed = slim.fully_connected(prelogits,
                                         128,
                                         tf.identity,
                                         scope='Embedding')
            v_embed = slim.fully_connected(v_prelogits,
                                           128,
                                           tf.identity,
                                           reuse=True,
                                           scope='Embedding')
        pln_factor = 1e-4
        center_factor = 1e-4
    else:
        embed = None
        v_embed = None
        if args.finetune_level == 1:
            use_center = True
            use_pln = True
            pln_factor = 1e-5
            center_factor = 1e-5

    loss_func = utils.LossFunctions(
        prelogit_norm_factor=pln_factor,
        center_loss_factor=center_factor,
    )
    loss, accu = loss_func.calculate_loss(logits,
                                          train_y,
                                          prelogits,
                                          class_num,
                                          use_center_loss=use_center,
                                          embed=embed,
                                          use_triplet_loss=use_triplet,
                                          use_prelogits_norm=use_pln,
                                          use_hard_instance_mining=use_him,
                                          scope_name='Training')
    _, v_accu = loss_func.calculate_loss(v_logits,
                                         valid_y,
                                         v_prelogits,
                                         class_num,
                                         use_center_loss=use_center,
                                         embed=v_embed,
                                         use_triplet_loss=use_triplet,
                                         use_prelogits_norm=use_pln,
                                         use_hard_instance_mining=use_him,
                                         scope_name='Validation')

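    # Exponentially decay the learning rate every 10k steps and pick the
    # optimizer from the command line.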
    global_step = tf.Variable(0, trainable=False)
    learning_rate = tf.train.exponential_decay(LEARNING_RATE,
                                               global_step,
                                               10000,
                                               0.9,
                                               staircase=True)
    if args.optim_type == 'adam':
        optim = tf.train.AdamOptimizer(learning_rate)
    elif args.optim_type == 'adagrad':
        optim = tf.train.AdagradOptimizer(learning_rate)
    else:
        optim = tf.train.GradientDescentOptimizer(learning_rate)
    # pass global_step so the exponential decay schedule actually advances
    train_op = optim.minimize(loss, global_step=global_step)

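    # Save / restore only the Inception backbone weights; optimizer slot
    # variables are excluded.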
    train_params = list(
        filter(lambda x: 'Adam' not in x.op.name and 'Inception' in x.op.name,
               tf.contrib.slim.get_variables()))
    saver = tf.train.Saver(var_list=train_params)

    if LOG_ALL_TRAIN_PARAMS:
        for i in train_params:
            tf.summary.histogram(i.op.name, i)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())
    if args.load:
        saver.restore(sess, args.weight_path + 'teacher.ckpt')

    train_writer = tf.summary.FileWriter(args.log_path, sess.graph)
    merged = tf.summary.merge_all(tf.GraphKeys.SUMMARIES)

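    # Launch the queue-runner threads that feed the input pipeline.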
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord, sess=sess)

    tf.get_default_graph().finalize()  # finalize the graph actually in use, not a fresh one
    start_time = time.time()
    step = 0
    while step * args.batch_size / len(
            data_reader.train_label) < args.target_epoch:
        _ = sess.run(train_op)

        if step % LOG_STEP == 0:
            time_cost = (time.time() -
                         start_time) / LOG_STEP if step > 0 else 0
            np_loss, np_accu, np_v_accu, s = sess.run(
                [loss, accu, v_accu, merged])
            train_writer.add_summary(s, step)
            print(
                '======================= Step {} ====================='.format(
                    step))
            print(
                '[Log file saved] {:.2f} secs for one step'.format(time_cost))
            print(
                'Current loss: {:.2f}, train accu: {:.2f}%, valid accu: {:.2f}%'
                .format(np_loss, np_accu, np_v_accu))
            start_time = time.time()

        if step % SAVE_STEP == 0:
            saver.save(sess, args.weight_path + 'teacher.ckpt', step)
            print('[Weights saved] weights saved at {}'.format(
                args.weight_path + 'teacher.ckpt'))

        step += 1

    coord.request_stop()
    coord.join(threads)

Example #2
def main(args):
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
    if '{}' in args.model_name:
        args.model_name = args.model_name.format(
            'teacher' if args.is_teacher else 'student')
    print(args)

    with tf.variable_scope('Data_Generator'):
        data_reader = DataReader(data_path=args.data_path)
        valid_x, valid_y = data_reader.get_instance(batch_size=args.batch_size,
                                                    mode='valid')
        valid_num = len(data_reader.valid_img_path)

    if args.is_teacher:
        network = model.TeacherNetwork()
        v_logits, v_net_dict = network.build_network(
            valid_x,
            class_num=len(data_reader.dict_class.keys()),
            reuse=False,
            is_train=False,
            dropout=1)
    else:
        network = model.StudentNetwork(len(data_reader.dict_class.keys()))
        v_logits, v_pre_logit = network.build_network(valid_x,
                                                      reuse=False,
                                                      is_train=False,
                                                      light=args.light)
    v_pred = tf.nn.softmax(v_logits, -1)
    v_pred = tf.argmax(v_pred, -1, output_type=tf.int32)

    cnt = tf.equal(v_pred, valid_y)

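    # Restore every variable in the graph from the selected checkpoint.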
    train_params = tf.contrib.slim.get_variables()
    saver = tf.train.Saver(var_list=train_params)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    saver.restore(sess, args.weight_path + args.model_name)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord, sess=sess)

    tf.get_default_graph().finalize()
    step = 0
    correct = []
    while step * args.batch_size < valid_num:
        np_cnt = sess.run(cnt)
        for i in range(np_cnt.shape[0]):
            if len(correct) < valid_num:
                correct.append(1 if np_cnt[i] else 0)
        step += 1

    accuracy = np.sum(correct) / valid_num
    print('================================================')
    print('[{}] Accuracy on validation set: {:.2f}'.format(
        args.model_name, accuracy * 100))
    print('================================================')

    coord.request_stop()
    coord.join(threads)

Example #3
def main(args):
    if '{}' in args.model_name:
        args.model_name = args.model_name.format(
            'teacher' if args.is_teacher else 'student')
    print(args)

    with tf.variable_scope('Data_Generator'):
        data_reader = DataReader(data_path=args.data_path)
        valid_x, valid_y = data_reader.get_instance(batch_size=args.batch_size,
                                                    mode='valid')
        valid_num = len(data_reader.valid_img_path)

    if args.is_teacher:
        network = model.TeacherNetwork()
        v_logits, v_net_dict = network.build_network(
            valid_x,
            class_num=len(data_reader.dict_class.keys()),
            reuse=False,
            is_train=False,
            dropout=1)
    else:
        network = model.StudentNetwork(len(data_reader.dict_class.keys()))
        v_logits, v_pre_logit = network.build_network(valid_x,
                                                      reuse=False,
                                                      is_train=False)
    v_pred = tf.nn.softmax(v_logits, -1)
    v_pred = tf.argmax(v_pred, -1, output_type=tf.int32)

    cnt = tf.equal(v_pred, valid_y)

    train_params = tf.contrib.slim.get_variables()
    saver = tf.train.Saver(var_list=train_params)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    saver.restore(sess, args.weight_path + args.model_name)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord, sess=sess)

    false_dict = {}
    tf.get_default_graph().finalize()
    step = 0
    correct = []
    while step * args.batch_size < valid_num:
        np_cnt, np_valid_y = sess.run([cnt, valid_y])
        for i in range(np_cnt.shape[0]):
            if len(correct) < valid_num:
                correct.append(1 if np_cnt[i] else 0)

                if not np_cnt[i]:
                    celeb_id = data_reader.dict_class[np_valid_y[i]]
                    false_dict[celeb_id] = false_dict.get(celeb_id, 0) + 1
        step += 1

    coord.request_stop()
    coord.join(threads)

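    # For every identity with three or more misclassified validation images,
    # tile its training crops (top rows) and validation crops (rows below)
    # into a single montage and save it under out/.  The 218x178 tile size
    # matches CelebA-style aligned face crops (an assumption about the data).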
    for celeb_id in false_dict.keys():
        if false_dict[celeb_id] < 3:
            continue
        train_instances = data_reader.dict_instance_id['train'][celeb_id]
        valid_instances = data_reader.dict_instance_id['valid'][celeb_id]

        dh, dw = [5, 8]
        display_img = np.zeros([218 * dh, 178 * dw, 3], np.float32)
        for i, img_path in enumerate(train_instances):
            img = mpimg.imread(img_path) / 255.
            h, w = [int(i / dw), i % dw]
            display_img[h * 218:(h + 1) * 218, w * 178:(w + 1) * 178, :] = img

        v_h = int(len(train_instances) / dh)
        v_h = v_h if v_h < dh - 2 else dh - 2
        for i, img_path in enumerate(valid_instances):
            img = mpimg.imread(img_path) / 255.
            h, w = [int(i / dw) + v_h, i % dw]
            display_img[h * 218:(h + 1) * 218, w * 178:(w + 1) * 178, :] = img

        mpimg.imsave(
            'out/{:02d}_{}.png'.format(false_dict[celeb_id], celeb_id),
            display_img)
Example #4
def main(args):
    print(args)

    LOG_STEP = 250
    SAVE_STEP = 500
    LOG_ALL_TRAIN_PARAMS = False
    MODEL_NAME = 'TS{}.ckpt'.format('-light' if args.light else '')
    LEARNING_RATE = 1e-4 / args.finetune_level if args.finetune_level > 1 else 1e-3

    with tf.variable_scope('Data_Generator'):
        data_reader = DataReader(data_path=args.data_path)
        train_x, train_y = data_reader.get_instance(
            batch_size=args.batch_size,
            mode='train',
            augmentation_level=args.finetune_level)
        valid_x, valid_y = data_reader.get_instance(batch_size=args.batch_size,
                                                    mode='valid')
        class_num = len(data_reader.dict_class.keys())

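    # The pretrained teacher runs in inference mode; stop_gradient keeps its
    # logits fixed so gradients only flow into the student.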
    teacher_net = model.TeacherNetwork()
    teacher_logits, teacher_dict = teacher_net.build_network(train_x,
                                                             class_num,
                                                             reuse=False,
                                                             is_train=False)
    teacher_logits = tf.stop_gradient(teacher_logits)

    network = model.StudentNetwork(class_num)
    logits, prelogits = network.build_network(train_x,
                                              reuse=False,
                                              is_train=True,
                                              light=args.light)
    v_logits, v_prelogits = network.build_network(valid_x,
                                                  reuse=True,
                                                  is_train=True,
                                                  dropout_keep_prob=1,
                                                  light=args.light)

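    # Project the student prelogits up to the width of the teacher's prelogit
    # vector so the two can be compared directly, and attach a small
    # classifier on top of that embedding.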
    with tf.variable_scope('SqueezeNeXt/Embedding'):
        t_prelogits = tf.stop_gradient(teacher_dict['PreLogitsFlatten'])
        s_embed = slim.fully_connected(prelogits,
                                       t_prelogits.get_shape().as_list()[-1],
                                       activation_fn=tf.identity)
        s_embed_pred = slim.fully_connected(s_embed,
                                            class_num,
                                            activation_fn=tf.identity)

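    # The distillation objective combines a Euclidean "hint" loss on the
    # embedding, soft cross-entropy against the teacher's output distribution
    # (for both the embedding classifier and the student logits), and hard
    # cross-entropy against the ground-truth labels.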
    with tf.variable_scope('compute_loss'):
        # Euclidean embedding loss
        euclidean_loss = tf.squared_difference(s_embed, t_prelogits)
        euclidean_loss = tf.reduce_mean(euclidean_loss)

        # soft label loss
        with tf.variable_scope('soft_CE'):
            soft_CE = lambda x, y: tf.reduce_mean(
                tf.reduce_sum(-1 * y * tf.log(x + 1e-6), -1))
            s_embed_CE = soft_CE(tf.nn.softmax(s_embed_pred),
                                 tf.nn.softmax(teacher_logits))
            s_CE = soft_CE(tf.nn.softmax(logits),
                           tf.nn.softmax(teacher_logits))
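            # NOTE: soft_CE(softmax(a), softmax(b)) is the batch-mean cross-entropy
            # between the two distributions; an equivalent, numerically stabler
            # form (assuming TF >= 1.5) would be
            #   tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(
            #       labels=tf.nn.softmax(b), logits=a))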

        # hard label loss
        hard_CE = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits,
                                                           labels=train_y))
        train_loss = euclidean_loss + s_embed_CE + s_CE * .5 + hard_CE

        train_output = tf.argmax(tf.nn.softmax(logits, -1),
                                 -1,
                                 output_type=tf.int32)
        train_accu = tf.where(tf.equal(train_output, train_y),
                              tf.ones_like(train_output),
                              tf.zeros_like(train_output))
        train_accu = tf.reduce_sum(train_accu) / args.batch_size * 100

        valid_output = tf.argmax(tf.nn.softmax(v_logits, -1),
                                 -1,
                                 output_type=tf.int32)
        valid_accu = tf.where(tf.equal(valid_output, valid_y),
                              tf.ones_like(valid_output),
                              tf.zeros_like(valid_output))
        valid_accu = tf.reduce_sum(valid_accu) / args.batch_size * 100

    with tf.variable_scope('Summary'):
        tf.summary.histogram('logit_raw', logits)
        tf.summary.histogram('logit_softmax', train_output)

        tf.summary.scalar('s_embed_CE', s_embed_CE)
        tf.summary.scalar('s_CE', s_CE)
        tf.summary.scalar('euclidean_loss', euclidean_loss)
        tf.summary.scalar('hard_CE', hard_CE)

        tf.summary.scalar('train_loss', train_loss)
        tf.summary.scalar('train_accu', train_accu)
        tf.summary.scalar('valid_accu', valid_accu)

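    # Split the graph's variables: SqueezeNeXt weights belong to the student
    # (trained and checkpointed here); Inception weights belong to the frozen
    # teacher and are only ever restored.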
    train_params = list(
        filter(
            lambda x: 'Adam' not in x.op.name and 'SqueezeNeXt' in x.op.name,
            tf.contrib.slim.get_variables_to_restore(
                exclude=['InceptionResnetV1'])))
    teacher_params = list(
        filter(
            lambda x: 'Adam' not in x.op.name and 'Inception' in x.op.name,
            tf.contrib.slim.get_variables_to_restore(exclude=['SqueezeNeXt'])))
    inference_param = train_params[:-4]

    global_step = tf.Variable(0, trainable=False)
    learning_rate = tf.train.exponential_decay(LEARNING_RATE,
                                               global_step,
                                               10000,
                                               0.9,
                                               staircase=True)
    if args.optim_type == 'adam':
        optim = tf.train.AdamOptimizer(learning_rate)
    elif args.optim_type == 'adagrad':
        optim = tf.train.AdagradOptimizer(learning_rate)
    else:
        optim = tf.train.GradientDescentOptimizer(learning_rate)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = optim.minimize(train_loss,
                                  global_step=global_step,
                                  var_list=train_params)

    saver = tf.train.Saver(var_list=train_params)

    if LOG_ALL_TRAIN_PARAMS:
        for i in train_params:
            tf.summary.histogram(i.op.name, i)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())
    if args.load:
        saver.restore(sess, args.weight_path + MODEL_NAME)

    teacher_saver = tf.train.Saver(var_list=teacher_params)
    teacher_saver.restore(sess, args.t_weight_path + args.t_model_name)

    train_writer = tf.summary.FileWriter(args.log_path, sess.graph)
    merged = tf.summary.merge_all(tf.GraphKeys.SUMMARIES)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord, sess=sess)

    tf.get_default_graph().finalize()
    start_time = time.time()
    step = 0
    while step * args.batch_size / len(
            data_reader.train_label) < args.target_epoch:
        _ = sess.run(train_op)

        if step % LOG_STEP == 0:
            time_cost = (time.time() -
                         start_time) / LOG_STEP if step > 0 else 0
            loss, accu, v_accu, s = sess.run(
                [train_loss, train_accu, valid_accu, merged])
            train_writer.add_summary(s, step)
            print(
                '======================= Step {} ====================='.format(
                    step))
            print(
                '[Log file saved] {:.2f} secs for one step'.format(time_cost))
            print(
                'Current loss: {:.2f}, train accu: {:.2f}%, valid accu: {:.2f}%'
                .format(loss, accu, v_accu))
            start_time = time.time()

        if step % SAVE_STEP == 0:
            saver.save(sess, args.weight_path + MODEL_NAME, step)
            print(
                '[Weights saved] weights saved at {}'.format(args.weight_path +
                                                             MODEL_NAME))

        step += 1

    coord.request_stop()
    coord.join(threads)
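
A command-line entry point is not shown in these snippets. The sketch below is an assumption based purely on the attributes the examples read from `args`; the flag names mirror those attributes, while the types and defaults are illustrative guesses rather than values from the original repository.

if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    # paths and checkpoint names
    parser.add_argument('--data_path', required=True)
    parser.add_argument('--weight_path', required=True)
    parser.add_argument('--log_path', default='logs/')
    parser.add_argument('--model_name', default='{}.ckpt')
    parser.add_argument('--t_weight_path', default='')
    parser.add_argument('--t_model_name', default='teacher.ckpt')
    # training schedule
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--target_epoch', type=int, default=10)
    parser.add_argument('--finetune_level', type=int, default=0)
    parser.add_argument('--optim_type', default='adam')
    # boolean switches
    parser.add_argument('--load', action='store_true')
    parser.add_argument('--light', action='store_true')
    parser.add_argument('--is_teacher', action='store_true')

    main(parser.parse_args())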