Example #1
0
def _prompt_index(upper):
    """Prompt repeatedly until the user enters an integer in ``[0, upper)``.

    Args:
        upper: exclusive upper bound of accepted values; must be positive,
            otherwise no input is ever accepted.

    Returns:
        The accepted integer.
    """
    while True:
        raw = prompt(u'In: ')
        try:
            value = int(raw)
        except ValueError:
            print('Please enter an integer')
            continue
        if 0 <= value < upper:
            return value
        print('Index out of range [0-{}]'.format(upper - 1))


def main():
    """Interactively browse programs and demonstrations of a Karel dataset.

    Reads ``data.hdf5`` and ``id.txt`` from ``--dir_name``. In an endless loop
    it lists the ids around the previously selected one, asks for an id
    index, prints that program and the lengths of its demonstrations, asks
    for a demonstration index, and then renders the demonstration's states
    one at a time with ``state2symbol`` ('c' advances, 'n' moves on to the
    next example). Returns early if either input file is missing or
    ``id.txt`` is empty; otherwise runs until interrupted (e.g. Ctrl-C).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dir_name', type=str, default='karel_default')
    args = parser.parse_args()

    dir_name = args.dir_name
    data_file = os.path.join(dir_name, 'data.hdf5')
    id_file = os.path.join(dir_name, 'id.txt')

    if not os.path.exists(data_file):
        print("data_file path doesn't exist: {}".format(data_file))
        return
    if not os.path.exists(id_file):
        print("id_file path doesn't exist: {}".format(id_file))
        return

    # Read all ids up front so the text file handle is released immediately.
    with open(id_file, 'r') as fp:
        ids = fp.read().splitlines()
    if not ids:
        print('id_file is empty: {}'.format(id_file))
        return

    dsl = get_KarelDSL(seed=123)

    # The context manager closes the HDF5 file when the loop is interrupted.
    with h5py.File(data_file, 'r') as f:
        cur_id = 0
        while True:
            start = max(cur_id - 5, 0)
            print('ids / previous id: {}'.format(cur_id))
            for i, data_id in enumerate(ids[start:cur_id + 5]):
                print('#{}: {}'.format(start + i, data_id))

            print('Put id you want to examine')
            cur_id = _prompt_index(len(ids))

            grp = f[ids[cur_id]]
            code = dsl.intseq2str(grp['program'])
            s_h_len = grp['s_h_len']
            num_demos = grp['s_h'].shape[0]

            print('code: {}'.format(code))
            print('demonstrations')
            for i, length in enumerate(s_h_len):
                print('demo #{}: length {}'.format(i, length))
            if num_demos == 0:
                print('No demonstrations for this id')
                continue
            # Valid indices are 0..num_demos-1, so the displayed upper bound
            # is num_demos - 1 (not num_demos).
            print('Put demonstration number [0-{}]'.format(num_demos - 1))
            demo_idx = _prompt_index(num_demos)
            demo = grp['s_h'][demo_idx]
            demo_len = s_h_len[demo_idx]

            seq_idx = 0
            print('code: {}'.format(code))
            state2symbol(demo[seq_idx])
            seq_idx += 1
            while seq_idx < demo_len:
                print("Press 'c' to continue and 'n' to next example")
                print(seq_idx, demo_len)
                key = prompt(u'In: ')
                if key == 'c':
                    print('code: {}'.format(code))
                    state2symbol(demo[seq_idx])
                    seq_idx += 1
                elif key == 'n':
                    break
                else:
                    print('Wrong key')
            print('Demo is terminated')
Example #2
0
def generator(config):
    """Generate a Karel program-synthesis dataset and write it to disk.

    Samples random programs from the ``'prob'`` DSL and executes each on
    randomly generated initial states to collect demonstrations. Programs
    that are already seen or longer than ``max_program_length`` are skipped.
    So are programs that cannot produce ``num_demo_per_program`` in-range
    demonstrations within ``max_demo_generation_trial`` attempts, and
    programs whose longest demo is shorter than
    ``min_max_demo_length_for_program``. Each kept program is stored with
    its zero-padded state/action histories as a group in
    ``<dir_name>/data.hdf5``, and its group id is appended as a line to
    ``<dir_name>/id.txt``. After ``num_train + num_test + num_val`` programs
    have been stored, a ``data_info`` group with dataset-wide metadata is
    written.

    Args:
        config: namespace providing ``dir_name``, ``height``, ``width``,
            ``wall_prob``, ``num_train``, ``num_test``, ``num_val``, ``seed``,
            ``max_program_stmt_depth``, ``max_program_nesting_depth``,
            ``max_program_length``, ``num_demo_per_program``,
            ``max_demo_generation_trial``, ``min_demo_length``,
            ``max_demo_length`` and ``min_max_demo_length_for_program``.

    Raises:
        ValueError: if ``num_train + num_test + num_val`` is not positive.
    """
    dir_name = config.dir_name
    h = config.height
    w = config.width
    c = len(karel.state_table)
    wall_prob = config.wall_prob
    num_train = config.num_train
    num_test = config.num_test
    num_val = config.num_val
    num_total = num_train + num_test + num_val
    if num_total <= 0:
        # Fail before creating output files instead of dividing by zero later.
        raise ValueError('num_train + num_test + num_val must be positive, '
                         'got {}'.format(num_total))

    # Output files; both are closed even if generation raises.
    with h5py.File(os.path.join(dir_name, 'data.hdf5'), 'w') as f, \
            open(os.path.join(dir_name, 'id.txt'), 'w') as id_file:
        # progress bar
        bar = progressbar.ProgressBar(maxval=100,
                                      widgets=[
                                          progressbar.Bar('=', '[', ']'), ' ',
                                          progressbar.Percentage()
                                      ])
        bar.start()

        dsl = get_KarelDSL(dsl_type='prob', seed=config.seed)
        s_gen = KarelStateGenerator(seed=config.seed)
        karel_world = karel.Karel_world()

        count = 0
        max_demo_length_in_dataset = -1
        max_program_length_in_dataset = -1
        seen_programs = set()
        while count < num_total:
            # generate a single program
            random_code = dsl.random_code(
                max_depth=config.max_program_stmt_depth,
                max_nesting_depth=config.max_program_nesting_depth)
            # skip seen programs
            if random_code in seen_programs:
                continue
            program_seq = np.array(dsl.code2intseq(random_code), dtype=np.int8)
            if program_seq.shape[0] > config.max_program_length:
                continue

            # Collect demonstrations; each attempt counts as one trial
            # whether or not execution succeeds.
            s_h_list = []
            a_h_list = []
            num_trial = 0
            while (len(s_h_list) < config.num_demo_per_program and
                   num_trial < config.max_demo_generation_trial):
                num_trial += 1
                try:
                    s, _, _, _, _ = s_gen.generate_single_state(
                        h, w, wall_prob)
                    karel_world.set_new_state(s)
                    # dsl.run records the histories on karel_world.s_h / a_h.
                    dsl.run(karel_world, random_code)
                except RuntimeError:
                    continue
                if (config.min_demo_length <= len(karel_world.s_h) <=
                        config.max_demo_length):
                    s_h_list.append(np.stack(karel_world.s_h, axis=0))
                    a_h_list.append(np.array(karel_world.a_h))

            num_demo = len(s_h_list)
            if num_demo < config.num_demo_per_program:
                continue

            len_s_h = np.array([s_h.shape[0] for s_h in s_h_list],
                               dtype=np.int16)
            max_s_h_len = np.max(len_s_h)
            if max_s_h_len < config.min_max_demo_length_for_program:
                continue

            # Zero-pad every state history to the longest one so they stack.
            demos_s_h = np.zeros([num_demo, max_s_h_len, h, w, c], dtype=bool)
            for i, s_h in enumerate(s_h_list):
                demos_s_h[i, :s_h.shape[0]] = s_h

            len_a_h = np.array([a_h.shape[0] for a_h in a_h_list],
                               dtype=np.int16)

            demos_a_h = np.zeros([num_demo, np.max(len_a_h)], dtype=np.int8)
            for i, a_h in enumerate(a_h_list):
                demos_a_h[i, :a_h.shape[0]] = a_h

            max_demo_length_in_dataset = max(max_demo_length_in_dataset,
                                             max_s_h_len)
            max_program_length_in_dataset = max(max_program_length_in_dataset,
                                                program_seq.shape[0])

            # save the state
            data_id = 'no_{}_prog_len_{}_max_s_h_len_{}'.format(
                count, program_seq.shape[0], max_s_h_len)
            id_file.write(data_id + '\n')
            grp = f.create_group(data_id)
            grp['program'] = program_seq
            grp['s_h_len'] = len_s_h
            grp['a_h_len'] = len_a_h
            grp['s_h'] = demos_s_h
            grp['a_h'] = demos_a_h
            seen_programs.add(random_code)

            count += 1
            # Integer percentage in [0, 100]; the previous float-modulo test
            # rarely hit exactly 0 unless num_total was a multiple of 100.
            bar.update(count * 100 // num_total)

        grp = f.create_group('data_info')
        grp['max_demo_length'] = max_demo_length_in_dataset
        grp['dsl_type'] = 'prob'
        grp['max_program_length'] = max_program_length_in_dataset
        grp['num_program_tokens'] = len(dsl.int2token)
        grp['num_demo_per_program'] = config.num_demo_per_program
        grp['num_action_tokens'] = len(dsl.action_functions)
        grp['num_train'] = config.num_train
        grp['num_test'] = config.num_test
        grp['num_val'] = config.num_val
        bar.finish()

    log.info('Dataset generated under {} with {}'
             ' samples ({} for training and {} for testing '
             'and {} for val)'.format(dir_name, num_total, num_train,
                                      num_test, num_val))
def generator(config):
    """Regenerate held-out test demonstrations for an existing Karel dataset.

    Opens ``<dir_name>/data.hdf5`` read/write. For every id listed in
    ``<dir_name>/id.txt`` it re-executes the stored program on freshly
    sampled initial states until ``config.num_test_demo_per_program``
    demonstrations with length in ``[min_demo_length, max_demo_length]``
    are collected. The zero-padded histories are written to ``test_s_h`` /
    ``test_a_h``, with lengths in ``test_s_h_len`` / ``test_a_h_len``,
    replacing any existing entries. Finally ``data_info/num_test_demo_per_program``
    is (re)written.

    Args:
        config: namespace providing ``dir_name``, ``height``, ``width``,
            ``wall_prob``, ``seed``, ``num_test_demo_per_program``,
            ``min_demo_length`` and ``max_demo_length``.

    Note:
        Demonstration sampling has no trial cap, so a program that never
        yields an in-range demonstration makes this loop forever.
    """
    dir_name = config.dir_name
    h = config.height
    w = config.width
    c = len(karel.state_table)

    wall_prob = config.wall_prob

    with open(os.path.join(dir_name, 'id.txt'), 'r') as id_file:
        # Test the stripped line so blank lines do not yield an empty id.
        ids = [line.strip() for line in id_file if line.strip()]

    test_keys = ('test_s_h_len', 'test_a_h_len', 'test_s_h', 'test_a_h')

    # output file; closed even if generation raises
    with h5py.File(os.path.join(dir_name, 'data.hdf5'), 'r+') as f:
        data_info = f['data_info']
        # ``Dataset.value`` was removed in h5py 3.0; ``ds[()]`` reads the
        # whole dataset in both h5py 2 and 3. h5py 3 returns str as bytes.
        dsl_type = data_info['dsl_type'][()]
        if isinstance(dsl_type, bytes):
            dsl_type = dsl_type.decode('utf-8')

        num_train = data_info['num_train'][()]
        num_test = data_info['num_test'][()]
        num_val = data_info['num_val'][()]
        num_total = num_train + num_test + num_val

        # progress bar
        bar = progressbar.ProgressBar(maxval=100,
                                      widgets=[
                                          progressbar.Bar('=', '[', ']'), ' ',
                                          progressbar.Percentage()
                                      ])
        bar.start()

        dsl = get_KarelDSL(dsl_type=dsl_type, seed=config.seed)
        s_gen = KarelStateGenerator(seed=config.seed)
        karel_world = karel.Karel_world()

        for count, id_ in enumerate(ids, start=1):
            grp = f[id_]
            # Reads a single program
            program_seq = grp['program'][()]
            program_code = dsl.intseq2str(program_seq)

            test_s_h_list = []
            a_h_list = []
            while len(test_s_h_list) < config.num_test_demo_per_program:
                try:
                    s, _, _, _, _ = s_gen.generate_single_state(
                        h, w, wall_prob)
                    karel_world.set_new_state(s)
                    # dsl.run records the histories on karel_world.s_h / a_h.
                    dsl.run(karel_world, program_code)
                except RuntimeError:
                    continue
                if (config.min_demo_length <= len(karel_world.s_h) <=
                        config.max_demo_length):
                    test_s_h_list.append(np.stack(karel_world.s_h, axis=0))
                    a_h_list.append(np.array(karel_world.a_h))
            num_demo = len(test_s_h_list)

            len_test_s_h = np.array([s_h.shape[0] for s_h in test_s_h_list],
                                    dtype=np.int16)

            # Zero-pad every state history to the longest one so they stack.
            demos_test_s_h = np.zeros(
                [num_demo, np.max(len_test_s_h), h, w, c], dtype=bool)
            for i, s_h in enumerate(test_s_h_list):
                demos_test_s_h[i, :s_h.shape[0]] = s_h

            len_a_h = np.array([a_h.shape[0] for a_h in a_h_list],
                               dtype=np.int16)

            demos_a_h = np.zeros([num_demo, np.max(len_a_h)], dtype=np.int8)
            for i, a_h in enumerate(a_h_list):
                demos_a_h[i, :a_h.shape[0]] = a_h

            # Remove each stale entry independently: a single try/except
            # stopped at the first missing key and left later ones behind,
            # making the assignments below fail on partially written groups.
            for key in test_keys:
                if key in grp:
                    del grp[key]

            # Save testing state
            grp['test_s_h_len'] = len_test_s_h
            grp['test_a_h_len'] = len_a_h
            grp['test_s_h'] = demos_test_s_h
            grp['test_a_h'] = demos_a_h
            # Progress is measured over the ids actually processed, so it
            # always stays within the bar's [0, 100] range.
            bar.update(count * 100 // len(ids))

        if 'num_test_demo_per_program' in data_info:
            del data_info['num_test_demo_per_program']
        data_info[
            'num_test_demo_per_program'] = config.num_test_demo_per_program
        bar.finish()

    log.info('Dataset generated under {} with {}'
             ' samples ({} for training and {} for testing '
             'and {} for val)'.format(dir_name, num_total, num_train, num_test,
                                      num_val))
Example #4
0
def ConstructOutputList(data_file, output_file):
    """Build one ``Output`` record per example id found in ``output_file``.

    Args:
        data_file: open dataset HDF5 file providing ``data_info/dsl_type``
            and a ``program`` dataset under each example id.
        output_file: open HDF5 file with one group per evaluated example id.
            Each group holds teacher-forcing (``program_*``) and greedy
            (``greedy_*``) results and, optionally, ``test_``-prefixed
            counterparts.

    Returns:
        List of ``Output``, in ``output_file.keys()`` order. All ``test_*``
        fields are ``None`` for ids whose group lacks
        ``test_program_prediction``.
    """
    # ``Dataset.value`` was removed in h5py 3.0; ``ds[()]`` reads the whole
    # dataset in both h5py 2 and 3. h5py 3 returns str datasets as bytes.
    dsl_type = data_file['data_info']['dsl_type'][()]
    if isinstance(dsl_type, bytes):
        dsl_type = dsl_type.decode('utf-8')
    dsl = get_KarelDSL(dsl_type=dsl_type, seed=123)

    # (Output field, HDF5 dataset name). The test-split variant of each pair
    # is obtained by prefixing both names with 'test_'.
    field_map = (
        ('tf_program', 'program_prediction'),
        ('tf_syntax', 'program_syntax'),
        ('tf_num_correct_execution', 'program_num_execution_correct'),
        ('tf_is_correct_execution', 'program_is_correct_execution'),
        ('greedy_program', 'greedy_prediction'),
        ('greedy_syntax', 'greedy_syntax'),
        ('greedy_num_correct_execution', 'greedy_num_execution_correct'),
        ('greedy_is_correct_execution', 'greedy_is_correct_execution'),
    )

    output_list = []
    for e_id in output_file.keys():
        gt_program_intseq = data_file[e_id]['program'][()]
        e_out = output_file[e_id]
        has_test = 'test_program_prediction' in e_out
        fields = {}
        for field, key in field_map:
            fields[field] = e_out[key][()]
            fields['test_' + field] = (e_out['test_' + key][()]
                                       if has_test else None)
        output_list.append(
            Output(id=e_id,
                   gt_program=dsl.intseq2str(gt_program_intseq),
                   **fields))
    return output_list
Example #5
0
    print('Visualization is terminated')


if __name__ == '__main__':
    args = GetArgument()

    try:
        data_file = h5py.File(args.data_hdf5, 'r')
    except:
        data_file = None
        print('Fail to read --data_hdf5: {}'.format(args.data_hdf5))
        sys.exit()
    try:
        output_file = h5py.File(args.output_hdf5, 'r')
    except:
        output_file = None
        print('Fail to read --output_hdf5: {}'.format(args.output_hdf5))
        sys.exit()

    output_list = ConstructOutputList(data_file, output_file)
    output = output_list[0]

    PrintUsage()

    output_dir = os.path.join(os.path.dirname(args.output_hdf5),
                              'inspect_output')

    dsl_type = data_file['data_info']['dsl_type'].value
    dsl = get_KarelDSL(dsl_type=dsl_type, seed=123)
    karel_world = karel.Karel_world()