Exemple #1
0
def generate_place_test_data(demo_path, target_path, batch_size=1, index=0):
    """Load one demonstration/target pair and shape it as a batch.

    The demonstration is read from ``demo_path`` with the reader selected by
    ``FLAGS.demo_type`` ('robot' or 'human'); the target is always read with
    the robot reader from ``target_path``.

    Args:
        demo_path: Path handed to the demonstration reader.
        target_path: Path handed to the robot reader for the target.
        batch_size: Leading dimension of every returned array. The loaded
            data must hold exactly ``batch_size * FLAGS.T`` rows' worth of
            elements, otherwise ``np.reshape`` raises.
        index: Forwarded unchanged to the ``read_data`` readers.

    Returns:
        Tuple ``(obsas, obsbs, actionas, actionbs, stateas, statebs)``:
        observations have shape
        ``(batch_size, FLAGS.T, im_width * im_height * num_channels)``,
        actions have shape ``(batch_size, FLAGS.T, FLAGS.output_data)``, and
        both state arrays are all-zero arrays of the action shape (the states
        returned by the readers are discarded).

    Raises:
        ValueError: If ``FLAGS.demo_type`` is neither 'robot' nor 'human'.
    """
    # Pick the demo reader up front. Previously an unknown demo_type fell
    # through the if/elif and surfaced later as an UnboundLocalError.
    if FLAGS.demo_type == 'robot':
        read_demo = read_data.Read_Robot_Data2
    elif FLAGS.demo_type == 'human':
        read_demo = read_data.Read_Human_Data2
    else:
        raise ValueError(
            "FLAGS.demo_type must be 'robot' or 'human', got %r" % (FLAGS.demo_type,))

    obsa, _, actiona = read_demo(demo_path, FLAGS.T, index)
    obsb, _, actionb = read_data.Read_Robot_Data2(target_path, FLAGS.T, index)

    obs_shape = [batch_size, FLAGS.T, FLAGS.im_width * FLAGS.im_height * FLAGS.num_channels]
    action_shape = [batch_size, FLAGS.T, FLAGS.output_data]

    obsas = np.reshape(obsa, obs_shape)
    obsbs = np.reshape(obsb, obs_shape)

    actionas = np.reshape(actiona, action_shape)
    actionbs = np.reshape(actionb, action_shape)

    # States are deliberately returned as zeros; each call gets its own array.
    stateas = np.zeros(action_shape)
    statebs = np.zeros(action_shape)

    return obsas, obsbs, actionas, actionbs, stateas, statebs
Exemple #2
0
def generate_data(if_train=True):
    """Sample a random meta-batch of demonstration/target episode pairs.

    For every batch element a color, object type, task, demo number and
    target number are drawn at random, the matching episode directories are
    built under ``FLAGS.data_path`` and both episodes are loaded through
    ``read_data``.

    Args:
        if_train: When true, the batch has ``FLAGS.meta_batch_size`` elements,
            tasks are drawn from ``[0, FLAGS.train_task_num)`` and the color
            directory follows the sampled color. Otherwise the batch has
            ``FLAGS.meta_test_batch_size`` elements, tasks are drawn from
            ``[FLAGS.train_task_num, FLAGS.task_num)`` and the last color
            directory ('color_yellow') is always used.

    Returns:
        Tuple ``(obsas, obsbs, actionas, actionbs, stateas, statebs)``:
        observations have shape
        ``(batch_size, FLAGS.T, im_width * im_height * num_channels)``,
        actions have shape ``(batch_size, FLAGS.T, FLAGS.output_data)``, and
        both state arrays are all-zero arrays of the action shape (the states
        returned by the readers are discarded).

    Raises:
        ValueError: If ``FLAGS.demo_type`` is neither 'robot' nor 'human'.
    """
    # Resolve the demo reader once. Previously an unknown demo_type fell
    # through the if/elif inside the loop and surfaced as UnboundLocalError.
    if FLAGS.demo_type == 'robot':
        read_demo = read_data.Read_Robot_Data2
    elif FLAGS.demo_type == 'human':
        read_demo = read_data.Read_Human_Data2
    else:
        raise ValueError(
            "FLAGS.demo_type must be 'robot' or 'human', got %r" % (FLAGS.demo_type,))

    batch_size = FLAGS.meta_batch_size if if_train else FLAGS.meta_test_batch_size

    # NOTE: the order of the np.random calls below is kept stable so that a
    # seeded run keeps drawing the same batches.
    color_list = (np.random.randint(0, 100, size=batch_size) + 1) % FLAGS.color_num
    print('color_list', color_list)

    object_list = (np.random.randint(0, 100, size=batch_size) + 1) % FLAGS.object_num
    print('object_list', object_list)

    if if_train:
        task_list = (np.random.randint(0, 100, size=batch_size) + 1) % FLAGS.train_task_num
    else:
        task_list = np.random.randint(FLAGS.train_task_num, FLAGS.task_num, size=batch_size)
    print('task_list', task_list)

    demo_list = (np.random.randint(0, 100, size=batch_size) + 1) % FLAGS.demo_num
    print('demo_list', demo_list)

    target_list = (np.random.randint(0, 100, size=batch_size) + 1) % FLAGS.demo_num
    print('target_list', target_list)

    color_dirs = ['color_blue', 'color_green', 'color_orange', 'color_yellow']
    object_dirs = ['object_type_animal', 'object_type_car', 'object_type_dinosaur', 'object_type_tool']
    path_template = '%s/%s/%s/%s/task_%d/demo_%d'

    obsas = []
    obsbs = []
    actionas = []
    actionbs = []

    for element in range(batch_size):
        # Evaluation always reads from the last color directory; the sampled
        # color is only used for training batches.
        color_dir = color_dirs[color_list[element]] if if_train else color_dirs[-1]
        object_dir = object_dirs[object_list[element]]

        demo_path = path_template % (
            FLAGS.data_path, color_dir, object_dir, FLAGS.demo_type,
            task_list[element], demo_list[element])
        target_path = path_template % (
            FLAGS.data_path, color_dir, object_dir, FLAGS.target_type,
            task_list[element], target_list[element])

        # The same index is forwarded to both readers of a pair.
        index = np.random.randint(0, 20)
        obsa, _, actiona = read_demo(demo_path, FLAGS.T, index)
        obsb, _, actionb = read_data.Read_Robot_Data2(target_path, FLAGS.T, index)

        obsas.append(obsa)
        obsbs.append(obsb)
        actionas.append(actiona)
        actionbs.append(actionb)

    obs_shape = [batch_size, FLAGS.T, FLAGS.im_width * FLAGS.im_height * FLAGS.num_channels]
    action_shape = [batch_size, FLAGS.T, FLAGS.output_data]

    obsas = np.reshape(obsas, obs_shape)
    obsbs = np.reshape(obsbs, obs_shape)

    actionas = np.reshape(actionas, action_shape)
    actionbs = np.reshape(actionbs, action_shape)

    # States are deliberately returned as zeros; each call gets its own array.
    stateas = np.zeros(action_shape)
    statebs = np.zeros(action_shape)

    return obsas, obsbs, actionas, actionbs, stateas, statebs
Exemple #3
0
def generate_atmaml_data(if_train=True):
    """Sample a meta-batch of demo / target / compare (and optional extra) episodes.

    For every batch element three episode directories are built under
    ``FLAGS.data_path``:

    * demo    - sampled color/object/task/demo, read as ``FLAGS.demo_type``;
    * target  - same color/object/task, its own demo number, read as
      ``FLAGS.target_type`` through the robot reader;
    * compare - a color and object shifted by a non-zero offset (so they
      differ from the demo's), read as ``FLAGS.compare_type``.

    When ``FLAGS.cross_domain`` is set, an extra episode with the demo's
    color/object/task/demo number is read as ``FLAGS.extra_type``.

    Each of the shift offsets is a single scalar shared by the whole batch.

    Args:
        if_train: When true, the batch has ``FLAGS.meta_batch_size`` elements,
            tasks come from ``[0, FLAGS.train_task_num)`` and the compare task
            is a shifted (hence different) training task. Otherwise the batch
            has ``FLAGS.meta_test_batch_size`` elements and both task lists
            are drawn independently from
            ``[FLAGS.train_task_num, FLAGS.task_num)``, so they may coincide.

    Returns:
        Without ``FLAGS.cross_domain``:
        ``(obsas, obsbs, obscs, actionas, actionbs, actioncs,
        stateas, statebs, statecs, labelabs, labelcs)``.

        With ``FLAGS.cross_domain``:
        ``(obsas, obsbs, obscs, extra_obses, actionas, actionbs, actioncs,
        extra_actions, stateas, statebs, statecs, extra_states,
        labelabs, labelcs)``.

        Observations have shape
        ``(batch_size, FLAGS.T, im_width * im_height * num_channels)``;
        actions and the all-zero state arrays have shape
        ``(batch_size, FLAGS.T, FLAGS.output_data)``. Labels have shape
        ``(batch_size, 1, -1)`` built from ``to_categorical`` when
        ``FLAGS.tar_mil`` is set, otherwise they are zeros of shape
        ``(batch_size, 1, FLAGS.color_num * FLAGS.index_train_range)``.

    Raises:
        ValueError: From ``np.random.randint`` if a category that gets
            shifted has fewer than two entries (no different entry exists).
    """
    batch_size = FLAGS.meta_batch_size if if_train else FLAGS.meta_test_batch_size

    def shifted(indices, count):
        # np.random.randint excludes `high`, so passing `count` yields an
        # offset in [1, count - 1]: every non-zero shift is reachable and the
        # result always differs from `indices` modulo `count`. The previous
        # `count - 1` bound never produced the largest offset and raised for
        # count == 2.
        return (indices + np.random.randint(1, count)) % count

    def reader_for(kind):
        # Anything that is not 'robot' is read with the human reader.
        if kind == 'robot':
            return read_data.Read_Robot_Data2
        return read_data.Read_Human_Data2

    # NOTE: the order of the np.random calls below is kept stable so that the
    # random stream is consumed exactly as before.
    color_list = (np.random.randint(0, 100, size=batch_size) + 1) % FLAGS.color_num
    compare_color_list = shifted(color_list, FLAGS.color_num)

    object_list = (np.random.randint(0, 100, size=batch_size) + 1) % FLAGS.object_num
    compare_object_list = shifted(object_list, FLAGS.object_num)

    if if_train:
        task_list = (np.random.randint(0, 100, size=batch_size) + 1) % FLAGS.train_task_num
        compare_task_list = shifted(task_list, FLAGS.train_task_num)
    else:
        task_list = np.random.randint(FLAGS.train_task_num, FLAGS.task_num, size=batch_size)
        compare_task_list = np.random.randint(FLAGS.train_task_num, FLAGS.task_num, size=batch_size)

    demo_list = (np.random.randint(0, 100, size=batch_size) + 1) % FLAGS.demo_num
    target_list = (np.random.randint(0, 100, size=batch_size) + 1) % FLAGS.demo_num
    # Shift the demo numbers (not the task numbers, as was done before) so the
    # compare demo index is derived from the demo index, mirroring the color,
    # object and task handling above.
    compare_demo_list = shifted(demo_list, FLAGS.demo_num)

    color_dirs = ['color_blue', 'color_green', 'color_orange', 'color_yellow']
    object_dirs = ['object_type_animal', 'object_type_car', 'object_type_dinosaur', 'object_type_tool']
    path_template = '%s/%s/%s/%s/task_%d/demo_%d'

    read_demo = reader_for(FLAGS.demo_type)
    read_compare = reader_for(FLAGS.compare_type)

    # Label class = color * index_train_range + index.
    num_label_classes = FLAGS.color_num * FLAGS.index_train_range

    obsas, obsbs, obscs, extra_obses = [], [], [], []
    actionas, actionbs, actioncs, extra_actions = [], [], [], []
    labelabs, labelcs = [], []

    for element in range(batch_size):
        color_dir = color_dirs[color_list[element]]
        object_dir = object_dirs[object_list[element]]

        demo_path = path_template % (
            FLAGS.data_path, color_dir, object_dir, FLAGS.demo_type,
            task_list[element], demo_list[element])

        target_path = path_template % (
            FLAGS.data_path, color_dir, object_dir, FLAGS.target_type,
            task_list[element], target_list[element])

        compare_path = path_template % (
            FLAGS.data_path, color_dirs[compare_color_list[element]],
            object_dirs[compare_object_list[element]],
            FLAGS.compare_type, compare_task_list[element], compare_demo_list[element])

        if FLAGS.cross_domain:
            extra_path = path_template % (
                FLAGS.data_path, color_dir, object_dir,
                FLAGS.extra_type, task_list[element], demo_list[element])

        # Log the paths of the first element only, as a sanity check.
        if element == 0:
            if FLAGS.cross_domain:
                print('extra_path', extra_path)
            print('demo_path', demo_path)
            print('target_path', target_path)
            print('compare_path', compare_path)

        # One index per element, forwarded to every reader of that element.
        index = np.random.randint(0, FLAGS.index_train_range)

        obsa, _, actiona = read_demo(demo_path, FLAGS.T, index)
        obsc, _, actionc = read_compare(compare_path, FLAGS.T, index)

        if FLAGS.cross_domain:
            extra_obs, _, extra_action = reader_for(FLAGS.extra_type)(extra_path, FLAGS.T, index)

        obsb, _, actionb = read_data.Read_Robot_Data2(target_path, FLAGS.T, index)

        if FLAGS.tar_mil:
            labelabs.append(to_categorical(
                color_list[element] * FLAGS.index_train_range + index, num_label_classes))
            labelcs.append(to_categorical(
                compare_color_list[element] * FLAGS.index_train_range + index, num_label_classes))

        obsas.append(obsa)
        obsbs.append(obsb)
        obscs.append(obsc)
        actionas.append(actiona)
        actionbs.append(actionb)
        actioncs.append(actionc)
        if FLAGS.cross_domain:
            extra_obses.append(extra_obs)
            extra_actions.append(extra_action)

    obs_shape = [batch_size, FLAGS.T, FLAGS.im_width * FLAGS.im_height * FLAGS.num_channels]
    action_shape = [batch_size, FLAGS.T, FLAGS.output_data]

    obsas = np.reshape(obsas, obs_shape)
    obsbs = np.reshape(obsbs, obs_shape)
    obscs = np.reshape(obscs, obs_shape)

    actionas = np.reshape(actionas, action_shape)
    actionbs = np.reshape(actionbs, action_shape)
    actioncs = np.reshape(actioncs, action_shape)

    # States are deliberately returned as zeros (the states returned by the
    # readers are discarded); every output gets its own array.
    stateas = np.zeros(action_shape)
    statebs = np.zeros(action_shape)
    statecs = np.zeros(action_shape)

    if FLAGS.tar_mil:
        labelabs = np.reshape(labelabs, [batch_size, 1, -1])
        labelcs = np.reshape(labelcs, [batch_size, 1, -1])
    else:
        labelabs = np.zeros([batch_size, 1, num_label_classes])
        labelcs = np.zeros([batch_size, 1, num_label_classes])

    if FLAGS.cross_domain:
        extra_obses = np.reshape(extra_obses, obs_shape)
        extra_actions = np.reshape(extra_actions, action_shape)
        extra_states = np.zeros(action_shape)

        return obsas, obsbs, obscs, extra_obses, \
               actionas, actionbs, actioncs, extra_actions, \
               stateas, statebs, statecs, extra_states, \
               labelabs, labelcs

    return obsas, obsbs, obscs, actionas, actionbs, actioncs, stateas, statebs, statecs, labelabs, labelcs