import os
import time

import cv2
import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim
from pandas import DataFrame
from scipy import misc
from tqdm import tqdm

import config
import data_input
import mnet
import util

# Category lookup tables. train() below reads category_classnum_dict and
# category_change_index from config; the remaining names are assumed to be
# exported by config as well.
category_classnum_dict = config.category_classnum_dict
category_label_dict = config.category_label_dict
category_change_index = config.category_change_index
all_labels = config.all_labels
columns = config.columns


def demo(category, step):
    '''
    Stage-2 inference on the round-2 test set: the stage-1 model predicts a
    crop box around the garment, then the stage-2 model refines keypoints
    inside the crop.
    :param category: one of ['blouse', 'skirt', 'outwear', 'dress', 'trousers']
    :param step: checkpoint step of the stage-2 model to restore
    :return: DataFrame with one submission row per test image of this category
    '''
    numclass = category_classnum_dict[category]
    category_labels = category_label_dict[category]
    img_size = config.IMAGE_SIZE
    # img_size_list = [int(384 * 0.5), int(384 * 1), int(384 * 1.5), int(384 * 2)]
    img_size_list = [int(img_size * 1)]
    with tf.Graph().as_default():
        batch_x = tf.placeholder(shape=[1, None, None, 3], dtype=tf.float32)
        # Two networks share one input placeholder: stage 1 ('cpn_model') for
        # the coarse box, stage 2 ('cdet') for the refined keypoints.
        with tf.variable_scope('cpn_model'):
            model1 = mnet.CPN(numclass, 1)
            model1.build_model(batch_x, False)
        with tf.variable_scope('cdet'):
            model2 = mnet.CPN(numclass, 1)
            model2.build_model(batch_x, False)
        with tf.Session() as sess:
            # Split the variables by scope so each saver restores its own model.
            all_vars = slim.get_model_variables()
            vars1 = []
            vars2 = []
            for var in all_vars:
                if 'cpn_model' in var.op.name:
                    vars1.append(var)
                elif 'cdet' in var.op.name:
                    vars2.append(var)
                else:
                    raise ValueError('wrong init')
            ckpt_filename1 = '../stage1/trained_weights_s1/' + category
            checkpoint_path1 = tf.train.latest_checkpoint(ckpt_filename1)
            saver1 = tf.train.Saver(var_list=vars1)
            saver1.restore(sess, checkpoint_path1)
            ckpt_filename2 = 'trained_weights_s2/' + category
            checkpoint_path2 = ckpt_filename2 + '/' + str(step) + '.ckpt'
            # checkpoint_path2 = tf.train.latest_checkpoint(ckpt_filename2)
            saver2 = tf.train.Saver(var_list=vars2)
            saver2.restore(sess, checkpoint_path2)

            dict_list = []
            f = open('../data/image_ori/image_test/r2testa/test.csv')
            list_file = f.read().splitlines()
            for j in tqdm(range(len(list_file))):
                temp = list_file[j].split(',')
                category_t = temp[1]
                if category_t != category:
                    continue
                img_id = temp[0]
                img_full = misc.imread(
                    '../data/image_ori/image_test/r2testa/' + img_id)
                img_full_ = img_full.copy()

                # Stage 1: coarse keypoints on the full image.
                img_384_full, scale_384_full, start_index_384_full = \
                    util.make_for_input(img_full_, 384)
                img_384_full = cv2.cvtColor(img_384_full,
                                            cv2.COLOR_RGB2BGR) / 256.0 - 0.5
                img_384_full = np.expand_dims(img_384_full, 0)
                heat_for_box = sess.run(model1.finalout,
                                        feed_dict={batch_x: img_384_full})
                heat_for_box_m = heat_for_box[0, :, :, :]
                location_box = util.get_location_cpn_n(
                    stage_heatmap=heat_for_box_m)
                location_box_ori = util.restore_location(
                    ori_img_shape=img_full.shape,
                    label_output=location_box,
                    scale=scale_384_full,
                    start_index=start_index_384_full)

                # Expand the bounding box of the coarse keypoints by a margin
                # and crop the garment region for stage 2.
                label_box = np.array(location_box_ori)
                x = label_box[:, 1]
                y = label_box[:, 0]
                xd = 40
                yd = 30
                xmin = max(0, min(x) - xd)
                xmax = min(img_full.shape[1], max(x) + xd)
                ymin = max(0, min(y) - yd)
                ymax = min(img_full.shape[0], max(y) + yd)
                img = img_full[ymin:ymax, xmin:xmax, :]
                img_ = img.copy()
                # util.visualize_result(img_toshow=img, location=location_box_ori)
                _, scale_384, start_index_384 = util.make_for_input(
                    img_, img_size)

                # Stage 2: multi-scale, flip-averaged heatmaps on the crop.
                heat_scale = []
                for img_size_m in img_size_list:
                    img_scale, scale, start_index = util.make_for_input(
                        img_, img_size_m)
                    img_input = cv2.cvtColor(img_scale,
                                             cv2.COLOR_RGB2BGR) / 256.0 - 0.5
                    img_input = np.expand_dims(img_input, 0)
                    img_2 = cv2.flip(img_, 1)
                    img_scale2, scale2, start_index2 = util.make_for_input(
                        img_2, img_size_m)
                    img_input2 = cv2.cvtColor(img_scale2,
                                              cv2.COLOR_RGB2BGR) / 256.0 - 0.5
                    img_input2 = np.expand_dims(img_input2, 0)
                    stage_heatmap_n = sess.run(model2.finalout,
                                               feed_dict={batch_x: img_input})
                    stage_heatmap_n2 = sess.run(model2.finalout,
                                                feed_dict={batch_x: img_input2})
                    t1 = stage_heatmap_n[0, :, :, :]
                    t2 = stage_heatmap_n2[0, :, :, :]
                    # Flip the mirrored heatmap back and swap the left/right
                    # keypoint channels so they line up with the original.
                    t2 = cv2.flip(t2, 1)
                    left_index = category_change_index[category][0]
                    right_index = category_change_index[category][1]
                    for z in range(len(left_index)):
                        temp = np.copy(t2[:, :, left_index[z]])
                        t2[:, :, left_index[z]] = np.copy(t2[:, :, right_index[z]])
                        t2[:, :, right_index[z]] = np.copy(temp)
                    tt = (t1 + t2) / 2.0
                    tt_384 = cv2.resize(tt, (img_size // 4, img_size // 4))
                    heat_scale.append(tt_384)
                # Average the per-scale heatmaps and decode keypoint locations.
                heat_scale = np.array(heat_scale).transpose(1, 2, 3, 0)
                heat_scale_m = np.mean(heat_scale, axis=-1)
                location_output = util.get_location_cpn_n(
                    stage_heatmap=heat_scale_m)  # [y, x]
                location_in_ori = util.restore_location(
                    ori_img_shape=img.shape,
                    label_output=location_output,
                    scale=scale_384,
                    start_index=start_index_384)
                location_in_ori = np.array(location_in_ori)

                # Shift crop coordinates back into the full image frame.
                location_in_full = np.copy(location_in_ori)
                for k in range(location_in_ori.shape[0]):
                    location_in_full[k, 0] = min(img_full.shape[0],
                                                 location_in_ori[k, 0] + ymin)
                    location_in_full[k, 1] = min(img_full.shape[1],
                                                 location_in_ori[k, 1] + xmin)

                dict_t = {}
                dict_t['image_id'] = img_id
                dict_t['image_category'] = category
                i = 0
                for label in all_labels:
                    if label in category_labels:
                        dict_t[label] = (str(location_in_full[i][1]) + '_' +
                                         str(location_in_full[i][0]) + '_' + str(1))
                        i += 1
                    else:
                        dict_t[label] = '-1_-1_-1'
                dict_list.append(dict_t)
            test_data = DataFrame(data=dict_list, columns=columns)
            f.close()
            return test_data
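# Usage sketch (an assumption, not part of the original file): run the
# stage-2 demo above for every category and merge the per-category DataFrames
# into a single submission CSV. The helper name, output path, and step
# numbers are hypothetical; in practice the checkpoint step with the lowest
# validation loss would be picked per category. It assumes the stage-2 demo
# defined above (the stage-1 demo below appears to live in a separate file).
def make_submission(step_dict, out_path='submission.csv'):
    import pandas as pd
    frames = [demo(category, step) for category, step in step_dict.items()]
    submission = pd.concat(frames, ignore_index=True)
    submission.to_csv(out_path, index=False)
    return submission

# Example call:
# make_submission({'blouse': 20000, 'skirt': 20000, 'outwear': 20000,
#                  'dress': 20000, 'trousers': 20000})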
def demo(category, step):
    '''
    Stage-1 inference on the validation split, with multi-scale and flip
    test-time augmentation.
    :param category: one of ['blouse', 'skirt', 'outwear', 'dress', 'trousers']
    :param step: checkpoint step of the stage-1 model to restore
    :return: DataFrame with one row per validation image of this category
    '''
    numclass = category_classnum_dict[category]
    category_labels = category_label_dict[category]
    img_size = config.IMAGE_SIZE
    img_size_list = [
        int(img_size * 0.5),
        int(img_size * 1),
        int(img_size * 1.5),
    ]
    with tf.Graph().as_default():
        batch_x = tf.placeholder(shape=[1, None, None, 3], dtype=tf.float32)
        with tf.variable_scope('cpn_model'):
            model = mnet.CPN(numclass, 1)
            model.build_model(batch_x, False)
        with tf.Session() as sess:
            saver = tf.train.Saver()
            ckpt_filename = 'trained_weights_s1/' + category
            checkpoint_path = ckpt_filename + '/' + str(step) + '.ckpt'
            saver.restore(sess, checkpoint_path)

            dict_list = []
            f = open('../data/image_ori/val.txt')
            list_file = f.read().splitlines()
            for x in tqdm(range(len(list_file))):
                temp = list_file[x].split(',')
                category_t = temp[1]
                if category_t != category:
                    continue
                img_id = temp[0]
                img = misc.imread('../data/image_ori/' + img_id)
                img_ = img.copy()
                _, scale_384, start_index_384 = util.make_for_input(img_, 512)

                heat_scale = []
                for img_size_m in img_size_list:
                    img_scale, scale, start_index = util.make_for_input(
                        img_, img_size_m)
                    img_input = cv2.cvtColor(img_scale,
                                             cv2.COLOR_RGB2BGR) / 256.0 - 0.5
                    img_input = np.expand_dims(img_input, 0)
                    img_2 = cv2.flip(img_, 1)
                    img_scale2, scale2, start_index2 = util.make_for_input(
                        img_2, img_size_m)
                    img_input2 = cv2.cvtColor(img_scale2,
                                              cv2.COLOR_RGB2BGR) / 256.0 - 0.5
                    img_input2 = np.expand_dims(img_input2, 0)
                    stage_heatmap_n = sess.run(model.finalout,
                                               feed_dict={batch_x: img_input})
                    stage_heatmap_n2 = sess.run(model.finalout,
                                                feed_dict={batch_x: img_input2})
                    t1 = stage_heatmap_n[0, :, :, :]
                    t2 = stage_heatmap_n2[0, :, :, :]
                    # Flip the mirrored heatmap back and swap the left/right
                    # keypoint channels so they line up with the original.
                    t2 = cv2.flip(t2, 1)
                    left_index = category_change_index[category][0]
                    right_index = category_change_index[category][1]
                    for z in range(len(left_index)):
                        temp = np.copy(t2[:, :, left_index[z]])
                        t2[:, :, left_index[z]] = np.copy(t2[:, :, right_index[z]])
                        t2[:, :, right_index[z]] = np.copy(temp)
                    tt = (t1 + t2) / 2.0
                    tt_384 = cv2.resize(tt, (img_size // 4, img_size // 4))
                    heat_scale.append(tt_384)
                heat_scale = np.array(heat_scale).transpose(1, 2, 3, 0)
                heat_scale_m = np.mean(heat_scale, axis=-1)
                location_output = util.get_location_cpn_n(
                    stage_heatmap=heat_scale_m)  # [y, x]
                location_in_ori = util.restore_location(
                    ori_img_shape=img.shape,
                    label_output=location_output,
                    scale=scale_384,
                    start_index=start_index_384)

                dict_t = {}
                dict_t['image_id'] = img_id
                dict_t['image_category'] = category
                i = 0
                for label in all_labels:
                    if label in category_labels:
                        dict_t[label] = (str(location_in_ori[i][1]) + '_' +
                                         str(location_in_ori[i][0]) + '_' + str(1))
                        i += 1
                    else:
                        dict_t[label] = '-1_-1_-1'
                dict_list.append(dict_t)
            test_data = DataFrame(data=dict_list, columns=columns)
            f.close()
            return test_data
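# Side note (a sketch, not part of the original file): the left/right channel
# swap applied to the flipped heatmap in both demo functions can also be
# written without the temp-variable loop, using fancy indexing on the channel
# axis. change_index is the same (left_index, right_index) pair stored in
# config.category_change_index; the helper name is hypothetical.
def swap_lr_channels(heatmap, change_index):
    left_index, right_index = change_index
    swapped = heatmap.copy()
    # Both right-hand sides read from the untouched input, so the two
    # assignments cannot clobber each other.
    swapped[:, :, left_index] = heatmap[:, :, right_index]
    swapped[:, :, right_index] = heatmap[:, :, left_index]
    return swapped

# Equivalent to the in-place loop:
# t2 = swap_lr_channels(t2, category_change_index[category])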
def train(category, steps):
    '''
    :param category: the category you want to train
    :param steps: number of training iterations to run
    :return:
    '''
    img_category = config.img_category
    category_classnum_dict = config.category_classnum_dict
    category_change_index = config.category_change_index
    batch_size = config.BATCH_SIZE
    lr = config.LEARNING_RATE
    lr_decay_rate = config.LR_DECAY_RATE
    lr_decay_step = config.LR_DECAY_STEP
    topk_dict = config.topk_dict
    batch_size_val = 8
    img_size = config.IMAGE_SIZE

    if category not in img_category:
        raise ValueError('wrong category')
    numclass = category_classnum_dict[category]

    # define data paths
    train_data_path = '../data/tfrecord_s1/train/' + category + '.tfrecord'
    val_data_path = '../data/tfrecord_s1/val/' + category + '.tfrecord'
    log_path = 'logs/' + category
    weights_path = 'weights/' + category
    if not os.path.exists(weights_path):
        os.makedirs(weights_path)
    if not os.path.exists(log_path):
        os.makedirs(log_path)
    if not os.path.exists(val_data_path):
        raise ValueError("can't find val data path")
    if not os.path.exists(train_data_path):
        raise ValueError("can't find train data path")

    with tf.Graph().as_default():
        # with tf.device('/cpu:0'):
        (batch_x, batch_y, batch_pm) = data_input.read_batch(
            tfr_path=train_data_path,
            numclass=numclass,
            change_index=category_change_index[category],
            argument=True,
            img_size=img_size,
            batch_size=batch_size)
        (batch_x_val, batch_y_val, batch_pm_val) = data_input.read_batch_val(
            tfr_path=val_data_path,
            numclass=numclass,
            change_index=category_change_index[category],
            argument=False,
            img_size=img_size,
            batch_size=batch_size_val)

        with tf.variable_scope('cpn_model'):
            model = mnet.CPN(numclass, batch_size)
            model.build_model(batch_x, True)
            model.build_loss_cpn(batch_y, batch_pm, lr, lr_decay_rate,
                                 lr_decay_step, top_k=topk_dict[category])
        # The validation model reuses the training weights.
        with tf.variable_scope('cpn_model', reuse=True):
            model_val = mnet.CPN(numclass, batch_size_val)
            model_val.build_model(batch_x_val, False)
            model_val.build_loss_cpn(batch_y_val, batch_pm_val, lr,
                                     lr_decay_rate, lr_decay_step,
                                     top_k=topk_dict[category], val=True)

        with tf.Session() as sess:
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)
            saver = tf.train.Saver(max_to_keep=None)
            checkpoint_path = tf.train.latest_checkpoint(weights_path)
            if checkpoint_path is None:
                init = tf.global_variables_initializer()
                sess.run(init)
                print('initialize from resnet_v1_101.ckpt')

                # Strip the outer scope name so variable names match the
                # ImageNet-pretrained ResNet checkpoint.
                def _removename(var):
                    return var.op.name.replace('cpn_model/', '')

                all_vars = slim.get_model_variables()
                var_to_restore = [var for var in all_vars
                                  if 'resnet_v1_101' in var.op.name]
                var_to_restore = {_removename(var): var for var in var_to_restore}
                saver_part = tf.train.Saver(var_list=var_to_restore)
                saver_part.restore(sess, 'init_weights/resnet_v1_101.ckpt')
            else:
                saver.restore(sess, checkpoint_path)

            summary_writer = tf.summary.FileWriter(log_path, sess.graph)
            for i in range(steps):
                t1 = time.time()
                (_, gloss, reloss, reloss2, allloss, global_steps, current_lr,
                 summary) = sess.run([
                     model.train_op,
                     model.global_loss,
                     model.refine_loss,
                     model.refine_loss2,
                     model.all_loss,
                     model.global_step,
                     model.lr,
                     model.loss_summary,
                 ])
                summary_writer.add_summary(summary, global_steps)
                print('##========Iter {:>6d}========##'.format(global_steps))
                print('Current learning rate: {:.8f}'.format(current_lr))
                print('Training time: {:.4f}'.format(time.time() - t1))
                print('gloss loss: {:>.6f}\n'.format(gloss))
                print('reloss loss2: {:>.6f}\n'.format(reloss2))
                print('reloss loss: {:>.6f}\n'.format(reloss))
                print('Total loss: {:>.6f}\n'.format(allloss))

                # Log the val loss to choose which step to use for test.
                if global_steps % 50 == 0:
                    gloss_val, reloss_val, allloss_val, summary_val = sess.run([
                        model_val.global_loss,
                        model_val.refine_loss,
                        model_val.all_loss,
                        model_val.loss_summary])
                    summary_writer.add_summary(summary_val, global_steps)
                    print('********************************************************')
                    print('##========VAL Iter {:>6d}========##'.format(global_steps))
                    print('gloss loss: {:>.6f}\n\n'.format(gloss_val))
                    print('reloss loss: {:>.6f}\n\n'.format(reloss_val))
                    print('Total loss: {:>.6f}\n\n'.format(allloss_val))
                    print('********************************************************')
                if global_steps % 1000 == 0:
                    saver.save(sess, weights_path + '/{}.ckpt'.format(global_steps))
                    print('\nModel checkpoint saved...\n')
            coord.request_stop()
            coord.join(threads)
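# Entry-point sketch (an assumption, not part of the original file): train one
# category from the command line, e.g. `python train.py blouse 30000`. The
# argument names and the example step count are illustrative only.
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(
        description='Train the stage-1 CPN for one clothing category.')
    parser.add_argument('category', choices=config.img_category,
                        help="one of ['blouse', 'skirt', 'outwear', 'dress', 'trousers']")
    parser.add_argument('steps', type=int,
                        help='number of training iterations to run')
    args = parser.parse_args()
    train(args.category, args.steps)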