Example #1
    def train_model(self):
        # create result/log/checkpoint/output directories on first run
        for path in (self.MODEL_NAME + '_result', self.LOGS_DIR, self.CKPT_DIR, self.OUTPUT_DIR):
            os.makedirs(path, exist_ok=True)
        
        train_set_path = read_data_path(self.TRAIN_IMAGE_PATH, self.TRAIN_LABEL_PATH)
        valid_set_path = read_data_path(self.VALID_IMAGE_PATH, self.VALID_LABEL_PATH)

        ckpt_save_path = os.path.join(self.CKPT_DIR, self.MODEL_NAME+'_'+str(self.N_BATCH)+'_'+str(self.LEARNING_RATE))

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())

            total_batch = len(train_set_path) // self.N_BATCH
            counter = 0

            self.saver = tf.train.Saver()
            self.writer = tf.summary.FileWriter(self.LOGS_DIR, sess.graph)

            for epoch in tqdm(range(self.N_EPOCH)):
                total_loss = 0
                random.shuffle(train_set_path)           # reshuffle both splits every epoch
                random.shuffle(valid_set_path)

                for i in range(total_batch):
                    batch_xs_path, batch_ys_path = next_batch(train_set_path, self.N_BATCH, i)
                    batch_xs = read_image(batch_xs_path, [self.RESIZE, self.RESIZE])
                    batch_ys = read_annotation(batch_ys_path, [self.RESIZE, self.RESIZE])

                    feed_dict = {self.input_x: batch_xs, self.label_y: batch_ys, self.is_train: True}

                    _, summary_str, loss = sess.run([self.optimizer, self.loss_summary, self.loss], feed_dict=feed_dict)
                    self.writer.add_summary(summary_str, counter)
                    counter += 1
                    total_loss += loss

                ## validation: predict on a small fixed batch (4 samples) and plot the result
                valid_xs_path, valid_ys_path = next_batch(valid_set_path, 4, 0)
                valid_xs = read_image(valid_xs_path, [self.RESIZE, self.RESIZE])
                valid_ys = read_annotation(valid_ys_path, [self.RESIZE, self.RESIZE])
                
                valid_pred = sess.run(self.pred, feed_dict={self.input_x: valid_xs, self.label_y: valid_ys, self.is_train: False})
                valid_pred = np.squeeze(valid_pred, axis=3)
                
                valid_ys = np.squeeze(valid_ys, axis=3)

                ## plotting and save figure
                img_save_path = os.path.join(self.OUTPUT_DIR, str(epoch).zfill(3) + '.png')
                draw_plot_segmentation(img_save_path, valid_xs, valid_pred, valid_ys)

                print('\nEpoch:', '%03d' % (epoch + 1), 'Avg Loss: {:.6f}'.format(total_loss / total_batch))
                self.saver.save(sess, ckpt_save_path + '_' + str(epoch) + '.model', global_step=counter)

            # every epoch is checkpointed in the loop above, so no extra save is needed here
            print('Finished saving the model')
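
A minimal sketch of the four data helpers the loop assumes (read_data_path, next_batch, read_image, read_annotation). The file layout, color handling, and scaling below are assumptions for illustration, not the original implementation:

import os
import cv2 as cv
import numpy as np

def read_data_path(image_dir, label_dir):
    # pair every image with the label file of the same name (assumed layout)
    names = sorted(os.listdir(image_dir))
    return [(os.path.join(image_dir, n), os.path.join(label_dir, n)) for n in names]

def next_batch(set_path, batch_size, step):
    # slice out the step-th batch of (image_path, label_path) pairs
    batch = set_path[step * batch_size:(step + 1) * batch_size]
    xs, ys = zip(*batch)
    return list(xs), list(ys)

def read_image(paths, size):
    # load and resize images, scaled to [0, 1]; cv.resize takes (width, height)
    imgs = [cv.resize(cv.imread(p), (size[1], size[0])) for p in paths]
    return np.asarray(imgs, dtype=np.float32) / 255.0

def read_annotation(paths, size):
    # load masks with an explicit channel axis, matching the
    # np.squeeze(..., axis=3) calls in the validation step above
    masks = [cv.resize(cv.imread(p, cv.IMREAD_GRAYSCALE), (size[1], size[0]),
                       interpolation=cv.INTER_NEAREST) for p in paths]
    return np.asarray(masks, dtype=np.float32)[..., np.newaxis]

Nearest-neighbor interpolation is used for the masks so resizing cannot invent fractional label values.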

            
Example #2
def store(site, entry):
    link_id = entry['id']  # renamed from `id`, which shadowed the builtin
    cursor = DBM.cursor()
    cursor.execute("select link_id from links where link_id = %s", (link_id,))
    result = cursor.fetchone()
    cursor.close()  # close before the early return so the cursor cannot leak
    if not result:
        return False

    annotation = read_annotation(KEY + str(link_id))
    if not annotation:
        data = {}
    else:
        data = json.loads(annotation)
        # guard with `in`: indexing a missing site key raised KeyError before
        if site in data and data[site]:
            if data[site]['ts'] >= entry['ts']:
                return False
            del data[site]

    data[site] = entry
    payload = json.dumps(data)
    store_annotation(KEY + str(link_id), payload)
    print(payload)
    return entry['ts']
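
store() assumes a module-level DB handle (DBM), a key prefix (KEY), and two annotation helpers. A sketch of that plumbing, here backed by memcached via python-memcached and MySQLdb; the client libraries, connection settings, and key prefix are all guesses, not from the original:

import memcache  # assumed backing store for the annotation helpers
import MySQLdb   # assumed driver behind the DBM handle

KEY = 'annotation:'                        # hypothetical key prefix
MC = memcache.Client(['127.0.0.1:11211'])  # hypothetical cache location
DBM = MySQLdb.connect(db='links')          # hypothetical connection settings

def read_annotation(key):
    # return the serialized annotation blob, or None when absent
    return MC.get(key)

def store_annotation(key, value):
    # persist the JSON-encoded annotation under its key
    MC.set(key, value)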
Example #3
def main():
    annotation = read_annotation(KEY + "checked")
    if not annotation:
        data = {}
    else:
        data = json.loads(annotation)
    modified = False

    for site, rss in URLS.items():
        print(site, rss)
        if site in data and data[site]['ts'] > 0:
            last_checked = data[site]['ts']
        else:
            if site not in data:
                data[site] = dict()
            # first run for this site: look back 48 hours
            data[site]['ts'] = last_checked = time.time() - 48 * 3600

        try:
            doc = feedparser.parse(rss, modified=time.gmtime(last_checked))
        except (urllib2.URLError, urllib2.HTTPError, UnicodeEncodeError) as e:
            print("connection failed (%s) %s" % (e, rss))
            continue  # was `return False`, which aborted all remaining sites

        if not doc.entries or doc.status == 304:
            continue  # nothing new for this site; keep checking the others

        for e in doc.entries:
            ts = analyze_entry(site, e)
            if ts and ts > last_checked:
                modified = True
                last_checked = data[site]['ts'] = ts

    if modified:
        # presumed final step: write the updated check times back,
        # otherwise the `modified` flag is never used
        store_annotation(KEY + "checked", json.dumps(data))
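
main() also assumes a module-level URLS map and an analyze_entry() that turns a feed entry into the dict store() expects. A sketch of that wiring; the feed list, entry fields, and timestamp handling are illustrative assumptions:

import calendar

URLS = {
    'example': 'https://example.com/feed.rss',  # hypothetical feed list
}

def analyze_entry(site, e):
    # convert the entry's publish time to a unix timestamp and hand the
    # entry to store(); returns the timestamp on success, False otherwise
    parsed = e.get('published_parsed') or e.get('updated_parsed')
    if not parsed:
        return False
    entry = {
        'id': e.get('id', e.get('link')),
        'ts': calendar.timegm(parsed),
        'title': e.get('title', ''),
    }
    return store(site, entry)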
Example #4
def inference():
    sess = tf.Session()
    # load the frozen graph and import it into the session's (default) graph;
    # a frozen .pb holds only constants, so no variable initializer is needed
    with gfile.FastGFile(os.path.dirname(__file__) + '/model.pb', 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, name='')

    input_image = sess.graph.get_tensor_by_name('input_image:0')
    landmark = sess.graph.get_tensor_by_name('ONet/landmark_fc/BiasAdd:0')

    data_file = "/mnt/data/changshuang/data/flickr/"
    anno_file = "/mnt/data/changshuang/data/aflw_anno.txt"
    # data: {'images': images, 'bboxes': bboxes, 'landmarks':landmarks}
    data = read_annotation(data_file, anno_file)
    img_data = list(zip(data["images"], data["bboxes"], data["landmarks"]))
    for img_path, img_bbox, img_landmarks in img_data:
        img = cv.imread(img_path)
        bbox = np.array(img_bbox)

        dets = convert_to_square(bbox)
        dets[:, 0:4] = np.round(dets[:, 0:4])
        h, w, c = img.shape
        [dy, edy, dx, edx, y, ey, x, ex, tmpw, tmph] = pad(dets, w, h)
        num_boxes = dets.shape[0]
        cropped_ims = np.zeros((num_boxes, 48, 48, 3), dtype=np.float32)
        for i in range(num_boxes):
            # paste the clipped crop into a zero-padded canvas, then resize
            # to the 48x48 ONet input and scale pixels to roughly [-1, 1)
            tmp = np.zeros((tmph[i], tmpw[i], 3), dtype=np.uint8)
            tmp[dy[i]:edy[i] + 1, dx[i]:edx[i] + 1, :] = img[y[i]:ey[i] + 1, x[i]:ex[i] + 1, :]
            cropped_ims[i, :, :, :] = (cv.resize(tmp, (48, 48)) - 127.5) / 128

        t1 = time()
        # batch_size = 16
        # minibatch = []
        # cur = 0
        # n = cropped_ims.shape[0]
        # while cur < n:
        #     minibatch.append(cropped_ims[cur: min(cur + batch_size, n), :, :, :])
        #     cur += batch_size
        
        # landmark_pred_list = []
        # for data in minibatch:
        #     m = data.shape[0]
        #     real_size = batch_size
        #     # the last chunk may be smaller than a full batch; pad it and treat its m rows as the batch
        #     if m < batch_size:
        #         keep_inds = np.arange(m)  # m = 5 keep_inds = [0,1,2,3,4]
        #         gap = batch_size - m  # batch_size = 7, gap = 2
        #         while gap >= len(keep_inds):
        #             gap -= len(keep_inds)  # -3
        #             keep_inds = np.concatenate((keep_inds, keep_inds))
        #         if gap != 0:
        #             keep_inds = np.concatenate((keep_inds, keep_inds[:gap]))
        #         data = data[keep_inds]
        #         real_size = m
        #     pre_landmarks = sess.run(landmark, feed_dict={input_image: data})
        #     landmark_pred_list.append(pre_landmarks[:real_size])
        # if len(landmark_pred_list) == 0:
        #     continue
        # else:
        #     pre_landmarks = np.concatenate(landmark_pred_list, axis=0)
        pre_landmarks = sess.run(landmark, feed_dict={input_image: cropped_ims})
        print(time() - t1)

        # map the normalized landmark offsets back to absolute image
        # coordinates using each ground-truth box's width and height
        w = bbox[:, 2] - bbox[:, 0] + 1
        h = bbox[:, 3] - bbox[:, 1] + 1
        pre_landmarks[:, 0::2] = (np.tile(w, (5, 1)) * pre_landmarks[:, 0::2].T + np.tile(bbox[:, 0], (5, 1)) - 1).T
        pre_landmarks[:, 1::2] = (np.tile(h, (5, 1)) * pre_landmarks[:, 1::2].T + np.tile(bbox[:, 1], (5, 1)) - 1).T

        for i in range(bbox.shape[0]):
            box_gt = bbox[i, :4]
            corpbbox_gt = [int(box_gt[0]), int(box_gt[1]), int(box_gt[2]), int(box_gt[3])]
            # draw the ground-truth face box
            cv.rectangle(img, (corpbbox_gt[0], corpbbox_gt[1]), (corpbbox_gt[2], corpbbox_gt[3]), (0, 225, 255), 2)
        # draw predicted (red) and ground-truth (yellow) landmark points
        for i in range(pre_landmarks.shape[0]):
            for j in range(len(pre_landmarks[i]) // 2):
                cv.circle(img, (int(pre_landmarks[i][2 * j]), int(pre_landmarks[i][2 * j + 1])), 3, (0, 0, 255), -1)
                cv.circle(img, (int(img_landmarks[i][2 * j]), int(img_landmarks[i][2 * j + 1])), 3, (0, 255, 255), -1)
        cv.imshow('show image', img)
        k = cv.waitKey(0) & 0xFF
        if k == ord('q'):
            break
    cv.destroyAllWindows()
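
convert_to_square and pad are the usual MTCNN box utilities and are not shown here. A sketch of convert_to_square under its common definition (pad follows the same convention but additionally clips each box to the image bounds; both are assumptions about the unshown helpers):

import numpy as np

def convert_to_square(bbox):
    # expand each [x1, y1, x2, y2, ...] box to a square whose side is the
    # longer of the box's width and height, keeping the center fixed
    square = bbox.copy()
    w = bbox[:, 2] - bbox[:, 0] + 1
    h = bbox[:, 3] - bbox[:, 1] + 1
    side = np.maximum(w, h)
    square[:, 0] = bbox[:, 0] + w * 0.5 - side * 0.5
    square[:, 1] = bbox[:, 1] + h * 0.5 - side * 0.5
    square[:, 2] = square[:, 0] + side - 1
    square[:, 3] = square[:, 1] + side - 1
    return square

Squaring the boxes before cropping keeps the 48x48 resize from distorting face aspect ratios.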