def demo(category, step):
    '''
    Two-stage inference for one category on the r2testa test set: a stage-1
    CPN proposes a bounding box, then a stage-2 CPN refines keypoints on the crop.

    :param category: one of ['blouse', 'skirt', 'outwear', 'dress', 'trousers']
    :param step: global step of the stage-2 checkpoint to restore
    :return: DataFrame of keypoint predictions, one row per test image
    '''
    numclass = category_classnum_dict[category]
    category_labels = category_label_dict[category]

    img_size = config.IMAGE_SIZE
    # Multi-scale option, currently disabled:
    # img_size_list = [int(384 * 0.5), int(384 * 1), int(384 * 1.5), int(384 * 2)]
    img_size_list = [img_size]

    with tf.Graph().as_default():
        batch_x = tf.placeholder(shape=[1, None, None, 3], dtype=tf.float32)

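        # Two CPN instances share one input placeholder: model1 ('cpn_model',
        # stage-1 weights) gives a coarse keypoint estimate used only to crop
        # a bounding box; model2 ('cdet', stage-2 weights) refines keypoints
        # on that crop.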
        with tf.variable_scope('cpn_model'):
            model1 = mnet.CPN(numclass, 1)
            model1.build_model(batch_x, False)

        with tf.variable_scope('cdet'):
            model2 = mnet.CPN(numclass, 1)
            model2.build_model(batch_x, False)

        with tf.Session() as sess:

            all_vars = slim.get_model_variables()

            vars1 = []
            vars2 = []
            for var in all_vars:
                if 'cpn_model' in var.op.name:
                    vars1.append(var)
                elif 'cdet' in var.op.name:
                    vars2.append(var)
                else:
                    raise ValueError('unexpected variable scope: ' + var.op.name)

            ckpt_filename1 = '../stage1/trained_weights_s1/' + category
            checkpoint_path1 = tf.train.latest_checkpoint(ckpt_filename1)
            saver1 = tf.train.Saver(var_list=vars1)
            saver1.restore(sess, checkpoint_path1)

            ckpt_filename2 = 'trained_weights_s2/' + category

            checkpoint_path2 = ckpt_filename2 + '/' + str(step) + '.ckpt'
            #checkpoint_path2 = tf.train.latest_checkpoint(ckpt_filename2)
            saver2 = tf.train.Saver(var_list=vars2)
            saver2.restore(sess, checkpoint_path2)

            dict_list = []
            with open('../data/image_ori/image_test/r2testa/test.csv') as f:
                list_file = f.read().splitlines()
            for j in tqdm(range(len(list_file))):
                temp = list_file[j].split(',')
                category_t = temp[1]
                if category_t != category:
                    continue

                img_id = temp[0]

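                # Coarse pass: run the stage-1 model on the full image
                # prepared at 384 px. (scipy.misc.imread is removed in modern
                # SciPy; imageio.imread is a drop-in replacement.)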
                img_full = misc.imread(
                    '../data/image_ori/image_test/r2testa/' + img_id)
                img_full_ = img_full.copy()
                img_384_full, scale_384_full, start_index_384_full = util.make_for_input(
                    img_full_, 384)

                img_384_full = cv2.cvtColor(img_384_full,
                                            cv2.COLOR_RGB2BGR) / 256.0 - 0.5
                img_384_full = np.expand_dims(img_384_full, 0)
                heat_for_box = sess.run(model1.finalout,
                                        feed_dict={batch_x: img_384_full})
                heat_for_box_m = heat_for_box[0, :, :, :]
                location_box = util.get_location_cpn_n(
                    stage_heatmap=heat_for_box_m)

                location_box_ori = util.restore_location(
                    ori_img_shape=img_full.shape,
                    label_output=location_box,
                    scale=scale_384_full,
                    start_index=start_index_384_full)
                label_box = np.array(location_box_ori)

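                # Build a bounding box around the coarse keypoints (rows are
                # [y, x]), pad by 40 px in x and 30 px in y, and clamp to the
                # image bounds before cropping.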
                x = label_box[:, 1]
                y = label_box[:, 0]

                xd = 40
                yd = 30
                xmin = min(x)
                ymin = min(y)
                xmax = max(x)
                ymax = max(y)
                xmin = max(0, xmin - xd)
                xmax = min(img_full.shape[1], xmax + xd)
                ymin = max(0, ymin - yd)
                ymax = min(img_full.shape[0], ymax + yd)
                img = img_full[ymin:ymax, xmin:xmax, :]
                img_ = img.copy()

                #utils.visualize_result(img_toshow=img, location=location_box_ori)
                _, scale_384, start_index_384 = util.make_for_input(
                    img_, img_size)
                heat_scale = []
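                # Test-time augmentation: for each scale, average the heatmaps
                # of the crop and of its horizontal flip (flipped back before
                # averaging).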
                for img_size_m in img_size_list:
                    img_scale, scale, start_index = util.make_for_input(
                        img_, img_size_m)

                    img_input = cv2.cvtColor(img_scale,
                                             cv2.COLOR_RGB2BGR) / 256.0 - 0.5
                    img_input = np.expand_dims(img_input, 0)

                    img_2 = cv2.flip(img_, 1)

                    img_scale2, scale2, start_index2 = util.make_for_input(
                        img_2, img_size_m)

                    img_input2 = cv2.cvtColor(img_scale2,
                                              cv2.COLOR_RGB2BGR) / 256.0 - 0.5
                    img_input2 = np.expand_dims(img_input2, 0)

                    stage_heatmap_n = sess.run(model2.finalout,
                                               feed_dict={batch_x: img_input})

                    stage_heatmap_n2 = sess.run(
                        model2.finalout, feed_dict={batch_x: img_input2})

                    t1 = stage_heatmap_n[0, :, :, :]
                    t2 = stage_heatmap_n2[0, :, :, :]
                    t2 = cv2.flip(t2, 1)

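                    # After flipping the heatmap back, swap each left/right
                    # keypoint channel pair so channels match the original
                    # orientation.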
                    left_index = category_change_index[category][0]
                    right_index = category_change_index[category][1]

                    for z in range(len(left_index)):
                        temp = np.copy(t2[:, :, left_index[z]])
                        t2[:, :, left_index[z]] = t2[:, :, right_index[z]]
                        t2[:, :, right_index[z]] = temp

                    tt = (t1 + t2) / 2.0

                    tt_384 = cv2.resize(tt, (img_size // 4, img_size // 4))
                    heat_scale.append(tt_384)
                heat_scale = np.array(heat_scale).transpose(1, 2, 3, 0)

                heat_scale_m = np.mean(heat_scale, axis=-1)

                location_output = util.get_location_cpn_n(
                    stage_heatmap=heat_scale_m)  # [y,x]

                location_in_ori = util.restore_location(
                    ori_img_shape=img.shape,
                    label_output=location_output,
                    scale=scale_384,
                    start_index=start_index_384)

                location_in_ori = np.array(location_in_ori)
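                # Map crop coordinates back into the full image and clamp to
                # its bounds.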
                location_in_full = np.copy(location_in_ori)
                location_in_full[:, 0] = np.minimum(img_full.shape[0],
                                                    location_in_ori[:, 0] + ymin)
                location_in_full[:, 1] = np.minimum(img_full.shape[1],
                                                    location_in_ori[:, 1] + xmin)

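                # Emit 'x_y_1' strings for this category's keypoints and
                # '-1_-1_-1' for keypoints that do not apply to it.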
                dict_t = {}
                dict_t['image_id'] = img_id
                dict_t['image_category'] = category
                i = 0
                for label in all_labels:
                    if label in category_labels:
                        dict_t[label] = '{}_{}_1'.format(
                            location_in_full[i][1], location_in_full[i][0])
                        i += 1
                    else:
                        dict_t[label] = '-1_-1_-1'
                dict_list.append(dict_t)

            test_data = DataFrame(data=dict_list, columns=columns)
    return test_data
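

# Single-model variant: runs one stage-1 CPN on full validation images,
# without the crop-and-refine step used above.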
def demo(category, step):
    '''
    Single-stage inference for one category on the validation set.

    :param category: one of ['blouse', 'skirt', 'outwear', 'dress', 'trousers']
    :param step: global step of the checkpoint to restore
    :return: DataFrame of keypoint predictions, one row per validation image
    '''
    numclass = category_classnum_dict[category]
    category_labels = category_label_dict[category]
    img_size = config.IMAGE_SIZE
    img_size_list = [
        int(img_size * 0.5),
        int(img_size * 1),
        int(img_size * 1.5),
    ]

    with tf.Graph().as_default():
        batch_x = tf.placeholder(shape=[1, None, None, 3], dtype=tf.float32)
        with tf.variable_scope('cpn_model'):
            model = mnet.CPN(numclass, 1)
            model.build_model(batch_x, False)

        with tf.Session() as sess:

            saver = tf.train.Saver()

            ckpt_filename = 'trained_weights_s1/' + category
            checkpoint_path = ckpt_filename + '/' + str(step) + '.ckpt'
            saver.restore(sess, checkpoint_path)

            dict_list = []
            with open('../data/image_ori/val.txt') as f:
                list_file = f.read().splitlines()
            for x in tqdm(range(len(list_file))):
                temp = list_file[x].split(',')
                category_t = temp[1]
                if category_t != category:
                    continue

                img_id = temp[0]
                img = misc.imread('../data/image_ori/' + img_id)
                img_ = img.copy()

                _, scale_384, start_index_384 = util.make_for_input(img_, 512)
                heat_scale = []
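                # Same flip-and-scale heatmap averaging as in the two-stage
                # variant above, here over three scales.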
                for img_size_m in img_size_list:
                    img_scale, scale, start_index = util.make_for_input(
                        img_, img_size_m)

                    img_input = cv2.cvtColor(img_scale,
                                             cv2.COLOR_RGB2BGR) / 256.0 - 0.5
                    img_input = np.expand_dims(img_input, 0)

                    img_2 = cv2.flip(img_, 1)

                    img_scale2, scale2, start_index2 = util.make_for_input(
                        img_2, img_size_m)

                    img_input2 = cv2.cvtColor(img_scale2,
                                              cv2.COLOR_RGB2BGR) / 256.0 - 0.5
                    img_input2 = np.expand_dims(img_input2, 0)

                    stage_heatmap_n = sess.run(model.finalout,
                                               feed_dict={batch_x: img_input})

                    stage_heatmap_n2 = sess.run(
                        model.finalout, feed_dict={batch_x: img_input2})

                    t1 = stage_heatmap_n[0, :, :, :]
                    t2 = stage_heatmap_n2[0, :, :, :]
                    t2 = cv2.flip(t2, 1)

                    left_index = category_change_index[category][0]
                    right_index = category_change_index[category][1]

                    for z in range(len(left_index)):
                        temp = np.copy(t2[:, :, left_index[z]])
                        t2[:, :, left_index[z]] = t2[:, :, right_index[z]]
                        t2[:, :, right_index[z]] = temp

                    tt = (t1 + t2) / 2.0

                    tt_384 = cv2.resize(tt, (img_size // 4, img_size // 4))
                    heat_scale.append(tt_384)
                heat_scale = np.array(heat_scale).transpose(1, 2, 3, 0)

                heat_scale_m = np.mean(heat_scale, axis=-1)

                location_output = util.get_location_cpn_n(
                    stage_heatmap=heat_scale_m)  # [y,x]

                location_in_ori = util.restore_location(
                    ori_img_shape=img.shape,
                    label_output=location_output,
                    scale=scale_384,
                    start_index=start_index_384)

                dict_t = {}
                dict_t['image_id'] = img_id
                dict_t['image_category'] = category
                i = 0
                for label in all_labels:
                    if label in category_labels:
                        dict_t[label] = '{}_{}_1'.format(
                            location_in_ori[i][1], location_in_ori[i][0])
                        i += 1
                    else:
                        dict_t[label] = '-1_-1_-1'
                dict_list.append(dict_t)

            test_data = DataFrame(data=dict_list, columns=columns)
    return test_data


def train(category, steps):
    '''
    Train the stage-1 CPN for one category.

    :param category: the category to train
    :param steps: number of training iterations to run
    :return:
    '''

    img_category = config.img_category
    category_classnum_dict = config.category_classnum_dict
    category_change_index = config.category_change_index
    batch_size = config.BATCH_SIZE
    lr = config.LEARNING_RATE
    lr_decay_rate = config.LR_DECAY_RATE
    lr_decay_step = config.LR_DECAY_STEP
    topk_dict = config.topk_dict
    batch_size_val = 8

    img_size = config.IMAGE_SIZE

    if category not in img_category:
        raise ValueError('wrong category')

    numclass = category_classnum_dict[category]

    # define data path
    train_data_path = '../data/tfrecord_s1/train/' + category + '.tfrecord'
    val_data_path = '../data/tfrecord_s1/val/' + category + '.tfrecord'
    log_path = 'logs/' + category
    weights_path = 'weights/' + category
    if not os.path.exists(weights_path):
        os.makedirs(weights_path)
    if not os.path.exists(log_path):
        os.makedirs(log_path)
    if not os.path.exists(val_data_path):
        raise ValueError("can't find val data path")
    if not os.path.exists(train_data_path):
        raise ValueError("can't find train data path")

    with tf.Graph().as_default():
        #with tf.device('/cpu:0'):
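        # Queue-based input pipelines; 'argument' is data_input's flag for
        # data augmentation (on for training, off for validation).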
        (batch_x, batch_y, batch_pm) = data_input.read_batch(
            tfr_path=train_data_path,
            numclass=numclass,
            change_index=category_change_index[category],
            argument=True,
            img_size=img_size,
            batch_size=batch_size)
        (batch_x_val, batch_y_val, batch_pm_val) = data_input.read_batch_val(
            tfr_path=val_data_path,
            numclass=numclass,
            change_index=category_change_index[category],
            argument=False,
            img_size=img_size,
            batch_size=batch_size_val)
        with tf.variable_scope('cpn_model'):
            model = mnet.CPN(numclass, batch_size)
            model.build_model(batch_x, True)
            model.build_loss_cpn(batch_y, batch_pm, lr, lr_decay_rate,
                                 lr_decay_step, top_k=topk_dict[category])
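        # Rebuild the same graph with reuse=True so the validation branch
        # shares the training weights (training flag off, val loss only).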
        with tf.variable_scope('cpn_model', reuse=True):
            model_val = mnet.CPN(numclass, batch_size_val)
            model_val.build_model(batch_x_val, False)
            model_val.build_loss_cpn(batch_y_val, batch_pm_val, lr,
                                     lr_decay_rate, lr_decay_step,
                                     top_k=topk_dict[category], val=True)

        with tf.Session() as sess:

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)

            saver = tf.train.Saver(max_to_keep=None)

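            # Resume from the newest checkpoint if one exists; otherwise
            # initialize all variables and load the ImageNet ResNet-101
            # backbone, stripping the 'cpn_model/' scope prefix to match
            # the names in resnet_v1_101.ckpt.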
            checkpoint_path = tf.train.latest_checkpoint(weights_path)
            if checkpoint_path is None:
                init = tf.global_variables_initializer()
                sess.run(init)

                print('Initializing from resnet_v1_101.ckpt')
                # helper to strip the scope prefix from a variable name
                def _removename(var):
                    return var.op.name.replace('cpn_model/', '')

                all_vars = slim.get_model_variables()
                var_to_restore = [
                    var for var in all_vars if 'resnet_v1_101' in var.op.name
                ]
                var_to_restore = {_removename(var): var for var in var_to_restore}
                saver_part = tf.train.Saver(var_list=var_to_restore)
                saver_part.restore(sess, 'init_weights/resnet_v1_101.ckpt')

            else:
                saver.restore(sess, checkpoint_path)

            summary_writer = tf.summary.FileWriter(log_path, sess.graph)

            for i in range(steps):
                t1 = time.time()
                (_, gloss, reloss, reloss2, allloss, global_steps, current_lr,
                 summary) = sess.run([
                     model.train_op,
                     model.global_loss,
                     model.refine_loss,
                     model.refine_loss2,
                     model.all_loss,
                     model.global_step,
                     model.lr,
                     model.loss_summary,
                 ])
                summary_writer.add_summary(summary, global_steps)

                print('##========Iter {:>6d}========##'.format(global_steps))
                print('Current learning rate: {:.8f}'.format(current_lr))
                print('Training time: {:.4f}'.format(time.time() - t1))
                print('Global loss: {:>.6f}'.format(gloss))
                print('Refine loss 2: {:>.6f}'.format(reloss2))
                print('Refine loss: {:>.6f}'.format(reloss))
                print('Total loss: {:>.6f}\n'.format(allloss))

                # log the val losses periodically to pick the best step for test
                if global_steps % 50 == 0:
                    gloss_val, reloss_val, allloss_val, summary_val = sess.run(
                        [model_val.global_loss,
                         model_val.refine_loss,
                         model_val.all_loss,
                         model_val.loss_summary])
                    summary_writer.add_summary(summary_val, global_steps)

                    print('********************************************************')
                    print('##========VAL Iter {:>6d}========##'.format(global_steps))
                    print('Global loss: {:>.6f}'.format(gloss_val))
                    print('Refine loss: {:>.6f}'.format(reloss_val))
                    print('Total loss: {:>.6f}'.format(allloss_val))
                    print('********************************************************')
                if global_steps % 1000 == 0:
                    saver.save(sess, weights_path + '/{}.ckpt'.format(global_steps))
                    print('\nModel checkpoint saved...\n')

            coord.request_stop()
            coord.join(threads)
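

# A minimal driver sketch, not part of the original source. It assumes that
# config.img_category lists the five garment categories, that demo() here is
# the validation variant defined above, and that its per-category DataFrames
# can simply be concatenated; the step counts are placeholders. (Note that
# train() saves under 'weights/' while demo() restores from
# 'trained_weights_s1/', so the original presumably copies checkpoints
# between the two directories.)
if __name__ == '__main__':
    import pandas as pd

    results = []
    for category in config.img_category:
        train(category, steps=30000)                 # placeholder step count
        results.append(demo(category, step=30000))   # restores 30000.ckpt
    pd.concat(results, ignore_index=True).to_csv('val_pred.csv', index=False)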