def generate_place_test_data(demo_path, target_path, batch_size=1, index=0):
    """Load one demonstration / target episode pair from explicit paths.

    The demonstration at ``demo_path`` is read with the reader matching
    ``FLAGS.demo_type``. The target at ``target_path`` is always read with
    ``read_data.Read_Robot_Data2``. Both reads use ``FLAGS.T`` and ``index``.

    Args:
        demo_path: path of the demonstration episode.
        target_path: path of the target episode (read as robot data).
        batch_size: leading dimension of the returned arrays.
            NOTE(review): only one episode of each kind is read, so the
            reshape presumably only succeeds for ``batch_size=1``. Confirm
            before passing anything larger.
        index: forwarded unchanged to the readers.

    Returns:
        ``(obsas, obsbs, actionas, actionbs, stateas, statebs)``:
        - observations are reshaped to
          ``[batch_size, T, im_width * im_height * num_channels]``;
        - actions are reshaped to ``[batch_size, T, output_data]``;
        - states are all-zero arrays with the action shape.

    Raises:
        ValueError: if ``FLAGS.demo_type`` is neither ``'robot'`` nor
            ``'human'``.
    """
    if FLAGS.demo_type == 'robot':
        read_demo = read_data.Read_Robot_Data2
    elif FLAGS.demo_type == 'human':
        read_demo = read_data.Read_Human_Data2
    else:
        # Without this, an unknown type surfaced later as an UnboundLocalError.
        raise ValueError(
            "unsupported FLAGS.demo_type %r (expected 'robot' or 'human')"
            % (FLAGS.demo_type,))

    # The readers also return states, but zero-filled states are returned below.
    obsa, _, actiona = read_demo(demo_path, FLAGS.T, index)
    obsb, _, actionb = read_data.Read_Robot_Data2(target_path, FLAGS.T, index)

    obs_shape = [batch_size, FLAGS.T,
                 FLAGS.im_width * FLAGS.im_height * FLAGS.num_channels]
    action_shape = [batch_size, FLAGS.T, FLAGS.output_data]
    obsas = np.reshape(obsa, obs_shape)
    obsbs = np.reshape(obsb, obs_shape)
    actionas = np.reshape(actiona, action_shape)
    actionbs = np.reshape(actionb, action_shape)
    stateas = np.zeros(action_shape)
    statebs = np.zeros(action_shape)
    return obsas, obsbs, actionas, actionbs, stateas, statebs
def generate_data(if_train=True):
    """Sample one meta-batch of (demonstration, target) episode pairs.

    For every batch element a colour, object, task, demo index, target index
    and reader ``index`` (drawn from ``[0, 20)``) are chosen at random.
    - The demonstration is read from the ``FLAGS.demo_type`` directory.
    - The target is read from the ``FLAGS.target_type`` directory of the same
      colour / object / task, always with the robot reader.
    The sampled index lists are printed to stdout.

    Args:
        if_train: when True, the batch has ``FLAGS.meta_batch_size``
            elements, tasks are in ``[0, FLAGS.train_task_num)`` and the
            sampled colour is used. Otherwise the batch has
            ``FLAGS.meta_test_batch_size`` elements, tasks are in
            ``[FLAGS.train_task_num, FLAGS.task_num)`` and the last colour
            of the list is always used.

    Returns:
        ``(obsas, obsbs, actionas, actionbs, stateas, statebs)``:
        - observations are ``[batch, T, im_width * im_height * num_channels]``;
        - actions are ``[batch, T, output_data]``;
        - states are all-zero arrays with the action shape.

    Raises:
        ValueError: if ``FLAGS.demo_type`` is neither ``'robot'`` nor
            ``'human'``.
    """
    if FLAGS.demo_type not in ('robot', 'human'):
        # Without this, an unknown type surfaced later as an UnboundLocalError.
        raise ValueError(
            "unsupported FLAGS.demo_type %r (expected 'robot' or 'human')"
            % (FLAGS.demo_type,))

    batch_size = FLAGS.meta_batch_size if if_train else FLAGS.meta_test_batch_size

    color_list = (np.random.randint(0, 100, size=batch_size) + 1) % FLAGS.color_num
    print('color_list', color_list)
    object_list = (np.random.randint(0, 100, size=batch_size) + 1) % FLAGS.object_num
    print('object_list', object_list)
    if if_train:
        task_list = (np.random.randint(0, 100, size=batch_size) + 1) % FLAGS.train_task_num
    else:
        task_list = np.random.randint(FLAGS.train_task_num, FLAGS.task_num, size=batch_size)
    print('task_list', task_list)
    demo_list = (np.random.randint(0, 100, size=batch_size) + 1) % FLAGS.demo_num
    print('demo_list', demo_list)
    target_list = (np.random.randint(0, 100, size=batch_size) + 1) % FLAGS.demo_num
    print('target_list', target_list)

    color_names = ['color_blue', 'color_green', 'color_orange', 'color_yellow']
    object_names = ['object_type_animal', 'object_type_car',
                    'object_type_dinosaur', 'object_type_tool']
    read_demo = (read_data.Read_Robot_Data2 if FLAGS.demo_type == 'robot'
                 else read_data.Read_Human_Data2)

    obsas, obsbs, actionas, actionbs = [], [], [], []
    for element in range(batch_size):
        # Evaluation always uses the last colour; training uses the sampled one.
        color = color_names[color_list[element]] if if_train else color_names[-1]
        obj = object_names[object_list[element]]
        task = task_list[element]
        demo_path = '%s/%s/%s/%s/task_%d/demo_%d' % (
            FLAGS.data_path, color, obj, FLAGS.demo_type, task, demo_list[element])
        target_path = '%s/%s/%s/%s/task_%d/demo_%d' % (
            FLAGS.data_path, color, obj, FLAGS.target_type, task, target_list[element])

        index = np.random.randint(0, 20)
        # States returned by the readers are unused; zeros are returned instead.
        obsa, _, actiona = read_demo(demo_path, FLAGS.T, index)
        obsb, _, actionb = read_data.Read_Robot_Data2(target_path, FLAGS.T, index)
        obsas.append(obsa)
        obsbs.append(obsb)
        actionas.append(actiona)
        actionbs.append(actionb)

    obs_shape = [batch_size, FLAGS.T,
                 FLAGS.im_width * FLAGS.im_height * FLAGS.num_channels]
    action_shape = [batch_size, FLAGS.T, FLAGS.output_data]
    obsas = np.reshape(obsas, obs_shape)
    obsbs = np.reshape(obsbs, obs_shape)
    actionas = np.reshape(actionas, action_shape)
    actionbs = np.reshape(actionbs, action_shape)
    stateas = np.zeros(action_shape)
    statebs = np.zeros(action_shape)
    return obsas, obsbs, actionas, actionbs, stateas, statebs
def generate_atmaml_data(if_train=True):
    """Sample one meta-batch of (demo, target, compare) episodes.

    For each batch element three episodes are read from ``FLAGS.data_path``:

    * ``a``: the demonstration (``FLAGS.demo_type`` directory), read with the
      robot reader when that type is ``'robot'`` and the human reader
      otherwise.
    * ``b``: the target episode from the same colour / object / task
      (``FLAGS.target_type`` directory), always read with the robot reader.
    * ``c``: a "compare" episode (``FLAGS.compare_type`` directory). Its
      colour and object are offset from ``a``'s by a non-zero amount, so they
      always differ.

    When ``FLAGS.cross_domain`` is set, an extra episode
    (``FLAGS.extra_type`` directory) is also read. It shares ``a``'s colour,
    object, task and demo index.

    Args:
        if_train: when True, the batch has ``FLAGS.meta_batch_size``
            elements drawn from tasks ``[0, FLAGS.train_task_num)``.
            Otherwise the batch has ``FLAGS.meta_test_batch_size`` elements
            drawn from tasks ``[FLAGS.train_task_num, FLAGS.task_num)``.

    Returns:
        Without cross-domain: ``(obsas, obsbs, obscs, actionas, actionbs,
        actioncs, stateas, statebs, statecs, labelabs, labelcs)``.

        With cross-domain: ``(obsas, obsbs, obscs, extra_obses, actionas,
        actionbs, actioncs, extra_actions, stateas, statebs, statecs,
        extra_states, labelabs, labelcs)``.

        Shapes:
        - observations: ``[batch, T, im_width * im_height * num_channels]``;
        - actions: ``[batch, T, output_data]``;
        - states: all-zero arrays with the action shape;
        - labels: ``[batch, 1, -1]`` from ``to_categorical`` when
          ``FLAGS.tar_mil`` is set, otherwise zeros of shape
          ``[batch, 1, color_num * index_train_range]``.
    """
    batch_size = FLAGS.meta_batch_size if if_train else FLAGS.meta_test_batch_size

    # randint's upper bound is exclusive. An offset drawn from [1, N) and
    # taken mod N always yields a *different* value and can reach every other
    # value. The former bound ``N - 1`` could never pick offset N - 1 and
    # raised for N == 2.
    color_list = (np.random.randint(0, 100, size=batch_size) + 1) % FLAGS.color_num
    compare_color_list = (color_list + np.random.randint(1, FLAGS.color_num)) % FLAGS.color_num
    object_list = (np.random.randint(0, 100, size=batch_size) + 1) % FLAGS.object_num
    compare_object_list = (object_list + np.random.randint(1, FLAGS.object_num)) % FLAGS.object_num
    if if_train:
        task_list = (np.random.randint(0, 100, size=batch_size) + 1) % FLAGS.train_task_num
        compare_task_list = (task_list + np.random.randint(1, FLAGS.train_task_num)) % FLAGS.train_task_num
    else:
        task_list = np.random.randint(FLAGS.train_task_num, FLAGS.task_num, size=batch_size)
        compare_task_list = np.random.randint(FLAGS.train_task_num, FLAGS.task_num, size=batch_size)
    demo_list = (np.random.randint(0, 100, size=batch_size) + 1) % FLAGS.demo_num
    target_list = (np.random.randint(0, 100, size=batch_size) + 1) % FLAGS.demo_num
    # Offset from demo_list, like the other compare_* lists. Previously this
    # used task_list, a copy-paste slip.
    compare_demo_list = (demo_list + np.random.randint(1, FLAGS.demo_num)) % FLAGS.demo_num

    color_names = ['color_blue', 'color_green', 'color_orange', 'color_yellow']
    object_names = ['object_type_animal', 'object_type_car',
                    'object_type_dinosaur', 'object_type_tool']
    obs_dim = FLAGS.im_width * FLAGS.im_height * FLAGS.num_channels
    num_labels = FLAGS.color_num * FLAGS.index_train_range

    def episode_path(color, obj, data_type, task, demo):
        """Build the path of one recorded episode under FLAGS.data_path."""
        return '%s/%s/%s/%s/task_%d/demo_%d' % (
            FLAGS.data_path, color_names[color], object_names[obj],
            data_type, task, demo)

    def read_episode(path, data_type, index):
        """Read with the robot reader for 'robot', the human reader otherwise."""
        reader = (read_data.Read_Robot_Data2 if data_type == 'robot'
                  else read_data.Read_Human_Data2)
        return reader(path, FLAGS.T, index)

    obsas, obsbs, obscs, extra_obses = [], [], [], []
    actionas, actionbs, actioncs, extra_actions = [], [], [], []
    labelabs, labelcs = [], []
    extra_states = []
    for element in range(batch_size):
        color, obj = color_list[element], object_list[element]
        task, demo = task_list[element], demo_list[element]
        demo_path = episode_path(color, obj, FLAGS.demo_type, task, demo)
        target_path = episode_path(color, obj, FLAGS.target_type, task,
                                   target_list[element])
        compare_path = episode_path(compare_color_list[element],
                                    compare_object_list[element],
                                    FLAGS.compare_type,
                                    compare_task_list[element],
                                    compare_demo_list[element])
        if FLAGS.cross_domain:
            extra_path = episode_path(color, obj, FLAGS.extra_type, task, demo)
        if element == 0:
            if FLAGS.cross_domain:
                print('extra_path', extra_path)
            print('demo_path', demo_path)
            print('target_path', target_path)
            print('compare_path', compare_path)

        index = np.random.randint(0, FLAGS.index_train_range)
        # States returned by the readers are unused; zeros are returned instead.
        obsa, _, actiona = read_episode(demo_path, FLAGS.demo_type, index)
        obsc, _, actionc = read_episode(compare_path, FLAGS.compare_type, index)
        if FLAGS.cross_domain:
            extra_obs, _, extra_action = read_episode(extra_path, FLAGS.extra_type, index)
            extra_obses.append(extra_obs)
            extra_actions.append(extra_action)
        obsb, _, actionb = read_data.Read_Robot_Data2(target_path, FLAGS.T, index)

        if FLAGS.tar_mil:
            # Class id combines the colour and the sampled index; presumably
            # to_categorical one-hot encodes it -- confirm against its source.
            labelabs.append(to_categorical(
                color * FLAGS.index_train_range + index, num_labels))
            labelcs.append(to_categorical(
                compare_color_list[element] * FLAGS.index_train_range + index,
                num_labels))

        obsas.append(obsa)
        obsbs.append(obsb)
        obscs.append(obsc)
        actionas.append(actiona)
        actionbs.append(actionb)
        actioncs.append(actionc)

    obs_shape = [batch_size, FLAGS.T, obs_dim]
    action_shape = [batch_size, FLAGS.T, FLAGS.output_data]
    obsas = np.reshape(obsas, obs_shape)
    obsbs = np.reshape(obsbs, obs_shape)
    obscs = np.reshape(obscs, obs_shape)
    actionas = np.reshape(actionas, action_shape)
    actionbs = np.reshape(actionbs, action_shape)
    actioncs = np.reshape(actioncs, action_shape)
    stateas = np.zeros(action_shape)
    statebs = np.zeros(action_shape)
    statecs = np.zeros(action_shape)
    if FLAGS.tar_mil:
        labelabs = np.reshape(labelabs, [batch_size, 1, -1])
        labelcs = np.reshape(labelcs, [batch_size, 1, -1])
    else:
        labelabs = np.zeros([batch_size, 1, num_labels])
        labelcs = np.zeros([batch_size, 1, num_labels])

    if FLAGS.cross_domain:
        extra_obses = np.reshape(extra_obses, obs_shape)
        extra_actions = np.reshape(extra_actions, action_shape)
        extra_states = np.zeros(action_shape)
        return (obsas, obsbs, obscs, extra_obses,
                actionas, actionbs, actioncs, extra_actions,
                stateas, statebs, statecs, extra_states,
                labelabs, labelcs)
    return (obsas, obsbs, obscs, actionas, actionbs, actioncs,
            stateas, statebs, statecs, labelabs, labelcs)