示例#1
0
def _write_tfrecord(output_filename, example):
	"""Write ``example`` as the only record of a new TFRecord file, closing it even on error."""
	with tf.python_io.TFRecordWriter(output_filename) as writer:
		writer.write(example.SerializeToString())


def gen_TFRecord_from_file(out_dir, out_filename, bag_filename, flip=False):
	"""Convert one rosbag recording into one or two TFRecord files.

	Messages are read from the module-level ``topic_names``:
	``topic_names[0]`` carries action messages, while ``topic_names[1]`` and
	``topic_names[2]`` are forwarded to ``DataPackager.imgCallback`` and
	``DataPackager.audCallback`` respectively.

	Each action message with ``data > 0`` triggers ``packager.formatOutput()``.
	An action equal to 1 snapshots the packaged image/audio stacks (only the
	most recent snapshot is kept) and increments ``p_t``. An action greater
	than 1 stops reading. The packager is reset after every action message
	except the one that stops reading, so its buffers still hold data for the
	final example.

	Files written (paths are built by plain string concatenation, so
	``out_dir`` must already end with a separator if one is wanted):

	* only if at least one action 1 was seen:
	  ``out_dir + out_filename + "_" + str(<snapshot action>) + end_file``,
	  holding the snapshot plus the current packager data;
	* always: ``out_dir + out_filename + "_" + <last action> + end_file``.

	``end_file`` is ``".tfrecord"``, or ``"_flip.tfrecord"`` when ``flip`` is
	true.

	Args:
		out_dir: prefix prepended to every output path.
		out_filename: base name of the outputs. It is also used as the example
			id, and its characters set the labels: "z" -> img_lab,
			"g" -> opt_lab, "a" -> aud_lab.
		bag_filename: path of the rosbag file to read.
		flip: forwarded to ``DataPackager`` and reflected in the file suffix.

	Returns:
		list of str: the paths of the files written, in write order.

	Raises:
		ValueError: if the bag contains no message on ``topic_names[0]``.
	"""
	packager = DataPackager(flip=flip)
	# The context manager guarantees the bag is closed even if reading,
	# packaging or writing raises.
	with rosbag.Bag(bag_filename) as bag:
		output_filenames = []

		#######################
		##  Get Label Info   ##
		#######################

		example_id = out_filename

		# NOTE: label_code is only printed; the labels below come from
		# example_id. If ".bag" is absent, find() returns -1 and the slice
		# is taken relative to the end of the string.
		file_end = bag_filename.find(".bag")
		label_code = bag_filename[file_end-5:file_end]
		print("")
		print("bag_filename: ", bag_filename)
		print("label_code:", label_code)

		img_lab = 1 if "z" in example_id else 0
		opt_lab = 1 if "g" in example_id else 0
		aud_lab = 1 if "a" in example_id else 0
		total_lab = (img_lab + opt_lab + aud_lab > 0)

		print(example_id)
		print(img_lab, opt_lab, aud_lab, ':', total_lab)

		end_file = "_flip.tfrecord" if flip else ".tfrecord"

		#######################
		##     READ FILE     ##
		#######################

		p_t = 0
		last_action = None  # str(msg.data) of the most recent action message
		stored_data = None  # snapshot taken at the most recent action == 1

		for topic, msg, _ in bag.read_messages(topics=topic_names):
			if topic == topic_names[0]:
				last_action = str(msg.data)

				if msg.data > 0:
					# perform data pre-processing steps
					packager.formatOutput()

					if msg.data == 1:
						print("packager.getImgStack().shape: ", packager.getImgStack().shape)
						# Snapshot before reset() below.
						# NOTE(review): if the stacks are NumPy arrays, [:] is a view, not
						# a copy; this relies on reset() rebinding rather than mutating
						# them in place -- confirm against DataPackager.
						stored_data = {
							"img_raw": packager.getImgStack()[:], "img_lab": 0,
							"aud_raw": packager.getAudStack()[:], "aud_lab": 0,
							"p_t": p_t,
							"total_lab": int(last_action),
							"example_id": example_id}
						p_t += 1

					elif msg.data > 1:
						# Stop without resetting so the packager keeps the data
						# used for the final example below.
						break

				packager.reset()
			elif topic == topic_names[1]:
				packager.imgCallback(msg)
			elif topic == topic_names[2]:
				packager.audCallback(msg)

		if last_action is None:
			# Previously this surfaced as an opaque NameError further down.
			raise ValueError(
				"no messages on action topic %r in %s" % (topic_names[0], bag_filename))

		# p_t > 0 guarantees stored_data was filled at least once.
		if p_t > 0:
			ex = make_sequence_example(
				img_raw=stored_data["img_raw"], img_lab=stored_data["img_lab"],
				aud_raw=stored_data["aud_raw"], aud_lab=stored_data["aud_lab"],
				p_t=stored_data["p_t"],
				first_action=stored_data["total_lab"],
				example_id=stored_data["example_id"],
				img_raw2=packager.getImgStack(),
				aud_raw2=packager.getAudStack(),
				second_action=int(last_action))
			output_filename = out_dir+out_filename+"_"+str(stored_data["total_lab"])+end_file
			output_filenames.append(output_filename)
			_write_tfrecord(output_filename, ex)

		# generate TFRecord data
		ex = make_sequence_example(
			img_raw=packager.getImgStack(), img_lab=img_lab,
			aud_raw=packager.getAudStack(), aud_lab=aud_lab,
			p_t=p_t,
			first_action=int(last_action),
			example_id=example_id)
		# Print the tracked action rather than `msg.data`: if the loop ended
		# without `break`, `msg` may be an image/audio message.
		print("last_action:", last_action, int(last_action))

		# write TFRecord data to file
		output_filename = out_dir+out_filename+"_"+last_action+end_file
		output_filenames.append(output_filename)
		_write_tfrecord(output_filename, ex)

		packager.reset()

	return output_filenames