def _prompt_index(num_items):
    """Prompt repeatedly until the user enters an integer in [0, num_items).

    Args:
        num_items: number of valid choices; callers ensure it is positive.

    Returns:
        The selected index as an int.
    """
    while True:
        raw = prompt(u'In: ')
        try:
            idx = int(raw)
        except ValueError:
            print('Please enter an integer in [0-{}]'.format(num_items - 1))
            continue
        if 0 <= idx < num_items:
            return idx
        print('Index out of range, expected [0-{}]'.format(num_items - 1))


def main():
    """Interactively browse the programs and demonstrations of a dataset.

    Reads ``<dir_name>/data.hdf5`` and ``<dir_name>/id.txt`` (``--dir_name``
    defaults to ``karel_default``). It then loops until interrupted:

    1. The user picks an example index and is shown its program and the
       lengths of its demonstrations.
    2. The user picks a demonstration.
    3. The user steps through that demonstration's states with 'c', or
       moves on to the next example with 'n'.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dir_name', type=str, default='karel_default')
    args = parser.parse_args()

    dir_name = args.dir_name
    data_file = os.path.join(dir_name, 'data.hdf5')
    id_file = os.path.join(dir_name, 'id.txt')
    if not os.path.exists(data_file):
        print("data_file path doesn't exist: {}".format(data_file))
        return
    if not os.path.exists(id_file):
        print("id_file path doesn't exist: {}".format(id_file))
        return

    with open(id_file, 'r') as fp:
        ids = fp.read().splitlines()
    if not ids:
        # Without ids the selection prompt could never be satisfied.
        print('No ids found in {}'.format(id_file))
        return

    dsl = get_KarelDSL(seed=123)
    # `with` closes the HDF5 handle even when the endless browsing loop
    # is left via Ctrl-C / EOF raised from the prompt.
    with h5py.File(data_file, 'r') as f:
        cur_id = 0
        while True:
            # Show a window of up to 10 ids around the previous selection.
            print('ids / previous id: {}'.format(cur_id))
            start = max(cur_id - 5, 0)
            for offset, example_id in enumerate(ids[start:cur_id + 5]):
                print('#{}: {}'.format(start + offset, example_id))
            print('Put id you want to examine')
            cur_id = _prompt_index(len(ids))

            grp = f[ids[cur_id]]
            code = dsl.intseq2str(grp['program'])
            print('code: {}'.format(code))
            print('demonstrations')
            for i, length in enumerate(grp['s_h_len']):
                print('demo #{}: length {}'.format(i, length))

            num_demos = grp['s_h'].shape[0]
            if num_demos == 0:
                print('No demonstrations stored for this example')
                continue
            # Valid demo indices are 0 .. num_demos - 1.
            print('Put demonstration number [0-{}]'.format(num_demos - 1))
            demo_idx = _prompt_index(num_demos)

            # Read the selected demonstration once instead of once per step.
            demo_s_h = grp['s_h'][demo_idx]
            demo_len = grp['s_h_len'][demo_idx]

            seq_idx = 0
            print('code: {}'.format(code))
            state2symbol(demo_s_h[seq_idx])
            seq_idx += 1
            while seq_idx < demo_len:
                print("Press 'c' to continue and 'n' to next example")
                print(seq_idx, demo_len)
                key = prompt(u'In: ')
                if key == 'c':
                    print('code: {}'.format(code))
                    state2symbol(demo_s_h[seq_idx])
                    seq_idx += 1
                elif key == 'n':
                    break
                else:
                    print('Wrong key')
            print('Demo is terminated')
def generator(config):
    """Generate a Karel program/demonstration dataset under ``config.dir_name``.

    Random programs are sampled from the 'prob' DSL. A program is stored
    only if it meets all of these conditions:

    * it has not been seen before;
    * its token sequence fits ``config.max_program_length``;
    * ``config.num_demo_per_program`` demonstrations of acceptable length
      are obtained within ``config.max_demo_generation_trial`` attempts;
    * its longest demonstration reaches
      ``config.min_max_demo_length_for_program``.

    Two files are written:

    * ``data.hdf5``: one group per program with datasets ``program``,
      ``s_h_len``, ``a_h_len``, ``s_h`` (zero-padded state histories) and
      ``a_h`` (zero-padded action histories), plus a ``data_info`` group
      holding dataset-wide metadata.
    * ``id.txt``: the program group names, one per line, in creation order.

    Args:
        config: namespace read for dir_name, height, width, wall_prob, seed,
            num_train, num_test, num_val, max_program_stmt_depth,
            max_program_nesting_depth, max_program_length,
            num_demo_per_program, max_demo_generation_trial,
            min_demo_length, max_demo_length and
            min_max_demo_length_for_program.
    """
    dir_name = config.dir_name
    h = config.height
    w = config.width
    c = len(karel.state_table)
    wall_prob = config.wall_prob
    num_train = config.num_train
    num_test = config.num_test
    num_val = config.num_val
    num_total = num_train + num_test + num_val

    # Output files; `with` closes both even if generation raises.
    with h5py.File(os.path.join(dir_name, 'data.hdf5'), 'w') as f, \
            open(os.path.join(dir_name, 'id.txt'), 'w') as id_file:
        # progress bar, in percent
        bar = progressbar.ProgressBar(maxval=100, widgets=[
            progressbar.Bar('=', '[', ']'), ' ', progressbar.Percentage()
        ])
        bar.start()

        dsl = get_KarelDSL(dsl_type='prob', seed=config.seed)
        s_gen = KarelStateGenerator(seed=config.seed)
        karel_world = karel.Karel_world()

        count = 0
        last_percent = 0
        max_demo_length_in_dataset = -1
        max_program_length_in_dataset = -1
        seen_programs = set()
        # Looping on the count also makes num_total == 0 well defined. The
        # old `count % (num_total / 100)` check crashed in that case.
        while count < num_total:
            # generate a single program
            random_code = dsl.random_code(
                max_depth=config.max_program_stmt_depth,
                max_nesting_depth=config.max_program_nesting_depth)
            # skip seen programs
            if random_code in seen_programs:
                continue
            program_seq = np.array(dsl.code2intseq(random_code),
                                   dtype=np.int8)
            if program_seq.shape[0] > config.max_program_length:
                continue

            s_h_list = []
            a_h_list = []
            num_demo = 0
            num_trial = 0
            while num_demo < config.num_demo_per_program and \
                    num_trial < config.max_demo_generation_trial:
                try:
                    s, _, _, _, _ = s_gen.generate_single_state(
                        h, w, wall_prob)
                    karel_world.set_new_state(s)
                    # The trajectory is read back from karel_world.s_h/a_h.
                    dsl.run(karel_world, random_code)
                except RuntimeError:
                    # A failed execution on this start state is simply
                    # discarded; it still counts as a trial.
                    pass
                else:
                    demo_len = len(karel_world.s_h)
                    if config.min_demo_length <= demo_len <= \
                            config.max_demo_length:
                        s_h_list.append(np.stack(karel_world.s_h, axis=0))
                        a_h_list.append(np.array(karel_world.a_h))
                        num_demo += 1
                num_trial += 1

            if num_demo < config.num_demo_per_program:
                continue

            len_s_h = np.array([demo.shape[0] for demo in s_h_list],
                               dtype=np.int16)
            if np.max(len_s_h) < config.min_max_demo_length_for_program:
                continue

            # Zero-pad every state history to the longest one.
            demos_s_h = np.zeros([num_demo, np.max(len_s_h), h, w, c],
                                 dtype=bool)
            for i, demo in enumerate(s_h_list):
                demos_s_h[i, :demo.shape[0]] = demo

            # Zero-pad every action history to the longest one.
            len_a_h = np.array([actions.shape[0] for actions in a_h_list],
                               dtype=np.int16)
            demos_a_h = np.zeros([num_demo, np.max(len_a_h)], dtype=np.int8)
            for i, actions in enumerate(a_h_list):
                demos_a_h[i, :actions.shape[0]] = actions

            max_demo_length_in_dataset = max(max_demo_length_in_dataset,
                                             np.max(len_s_h))
            max_program_length_in_dataset = max(max_program_length_in_dataset,
                                                program_seq.shape[0])

            # save the program and its demonstrations
            example_id = 'no_{}_prog_len_{}_max_s_h_len_{}'.format(
                count, program_seq.shape[0], np.max(len_s_h))
            id_file.write(example_id + '\n')
            grp = f.create_group(example_id)
            grp['program'] = program_seq
            grp['s_h_len'] = len_s_h
            grp['a_h_len'] = len_a_h
            grp['s_h'] = demos_s_h
            grp['a_h'] = demos_a_h
            seen_programs.add(random_code)

            # progress bar: integer percent, refreshed whenever it changes
            count += 1
            percent = 100 * count // num_total
            if percent != last_percent:
                bar.update(percent)
                last_percent = percent

        grp = f.create_group('data_info')
        grp['max_demo_length'] = max_demo_length_in_dataset
        grp['dsl_type'] = 'prob'
        grp['max_program_length'] = max_program_length_in_dataset
        grp['num_program_tokens'] = len(dsl.int2token)
        grp['num_demo_per_program'] = config.num_demo_per_program
        grp['num_action_tokens'] = len(dsl.action_functions)
        grp['num_train'] = config.num_train
        grp['num_test'] = config.num_test
        grp['num_val'] = config.num_val
        bar.finish()

    log.info('Dataset generated under {} with {}'
             ' samples ({} for training and {} for testing '
             'and {} for val)'.format(dir_name, num_total,
                                      num_train, num_test, num_val))
    return
def generator(config):
    """Add newly sampled test demonstrations to an existing Karel dataset.

    Opens ``<config.dir_name>/data.hdf5`` for read/write. For every id listed
    in ``id.txt``, it runs the stored program on fresh random start states
    until ``config.num_test_demo_per_program`` demonstrations of acceptable
    length are collected. It then (re)writes the group's ``test_s_h_len``,
    ``test_a_h_len``, ``test_s_h`` and ``test_a_h`` datasets. Finally it
    records ``data_info/num_test_demo_per_program``.

    Args:
        config: namespace read for dir_name, height, width, wall_prob, seed,
            num_test_demo_per_program, min_demo_length and max_demo_length.
    """
    dir_name = config.dir_name
    h = config.height
    w = config.width
    c = len(karel.state_table)
    wall_prob = config.wall_prob

    with h5py.File(os.path.join(dir_name, 'data.hdf5'), 'r+') as f:
        data_info = f['data_info']
        # `ds[()]` replaces `Dataset.value`, which was removed in h5py 3.0.
        # h5py 3 returns variable-length strings as bytes, so decode them.
        dsl_type = data_info['dsl_type'][()]
        if isinstance(dsl_type, bytes):
            dsl_type = dsl_type.decode('utf-8')
        with open(os.path.join(dir_name, 'id.txt'), 'r') as id_file:
            # Skip blank lines; they would otherwise become '' ids.
            ids = [line.strip() for line in id_file if line.strip()]
        num_train = data_info['num_train'][()]
        num_test = data_info['num_test'][()]
        num_val = data_info['num_val'][()]
        num_total = num_train + num_test + num_val

        # progress bar, in percent of the ids processed
        bar = progressbar.ProgressBar(maxval=100, widgets=[
            progressbar.Bar('=', '[', ']'), ' ', progressbar.Percentage()
        ])
        bar.start()

        dsl = get_KarelDSL(dsl_type=dsl_type, seed=config.seed)
        s_gen = KarelStateGenerator(seed=config.seed)
        karel_world = karel.Karel_world()

        num_ids = len(ids)
        last_percent = 0
        for count, id_ in enumerate(ids, 1):
            grp = f[id_]
            # Reads a single program
            program_seq = grp['program'][()]
            program_code = dsl.intseq2str(program_seq)

            test_s_h_list = []
            a_h_list = []
            num_demo = 0
            # NOTE(review): unlike dataset creation there is no trial cap
            # here, so a program that rarely yields an acceptable demo can
            # keep this loop running for a long time.
            while num_demo < config.num_test_demo_per_program:
                try:
                    s, _, _, _, _ = s_gen.generate_single_state(
                        h, w, wall_prob)
                    karel_world.set_new_state(s)
                    # The trajectory is read back from karel_world.s_h/a_h.
                    dsl.run(karel_world, program_code)
                except RuntimeError:
                    pass
                else:
                    demo_len = len(karel_world.s_h)
                    if config.min_demo_length <= demo_len <= \
                            config.max_demo_length:
                        test_s_h_list.append(
                            np.stack(karel_world.s_h, axis=0))
                        a_h_list.append(np.array(karel_world.a_h))
                        num_demo += 1

            # Zero-pad state / action histories to the longest demo.
            len_test_s_h = np.array([demo.shape[0] for demo in test_s_h_list],
                                    dtype=np.int16)
            demos_test_s_h = np.zeros(
                [num_demo, np.max(len_test_s_h), h, w, c], dtype=bool)
            for i, demo in enumerate(test_s_h_list):
                demos_test_s_h[i, :demo.shape[0]] = demo

            len_a_h = np.array([actions.shape[0] for actions in a_h_list],
                               dtype=np.int16)
            demos_a_h = np.zeros([num_demo, np.max(len_a_h)], dtype=np.int8)
            for i, actions in enumerate(a_h_list):
                demos_a_h[i, :actions.shape[0]] = actions

            # Drop test data left by a previous run. Each dataset is checked
            # separately so a partially written group is also cleaned up.
            for name in ('test_s_h_len', 'test_a_h_len',
                         'test_s_h', 'test_a_h'):
                if name in grp:
                    del grp[name]

            # Save testing state
            grp['test_s_h_len'] = len_test_s_h
            grp['test_a_h_len'] = len_a_h
            grp['test_s_h'] = demos_test_s_h
            grp['test_a_h'] = demos_a_h

            # progress bar: integer percent, refreshed whenever it changes
            percent = 100 * count // num_ids
            if percent != last_percent:
                bar.update(percent)
                last_percent = percent

        if 'num_test_demo_per_program' in data_info:
            del data_info['num_test_demo_per_program']
        data_info['num_test_demo_per_program'] = \
            config.num_test_demo_per_program
        bar.finish()

    log.info('Dataset generated under {} with {}'
             ' samples ({} for training and {} for testing '
             'and {} for val)'.format(dir_name, num_total,
                                      num_train, num_test, num_val))
def ConstructOutputList(data_file, output_file):
    """Pair every model prediction with its ground-truth program.

    Args:
        data_file: opened dataset HDF5 file. It must contain
            ``data_info/dsl_type`` and ``<id>/program`` for every id present
            in ``output_file``.
        output_file: opened model-output HDF5 file with one group per
            example id.

    Returns:
        list of ``Output`` records, one per group of ``output_file``, in
        iteration order. The ``test_*`` fields are ``None`` for groups that
        lack ``test_program_prediction``.
    """
    # (Output field, dataset name) pairs read from every output group.
    field_to_dataset = (
        ('tf_program', 'program_prediction'),
        ('tf_syntax', 'program_syntax'),
        ('tf_num_correct_execution', 'program_num_execution_correct'),
        ('tf_is_correct_execution', 'program_is_correct_execution'),
        ('greedy_program', 'greedy_prediction'),
        ('greedy_syntax', 'greedy_syntax'),
        ('greedy_num_correct_execution', 'greedy_num_execution_correct'),
        ('greedy_is_correct_execution', 'greedy_is_correct_execution'),
    )
    # Test-set results use the same names with a 'test_' prefix on both
    # the Output field and the dataset.
    test_field_to_dataset = tuple(
        ('test_' + field, 'test_' + dataset)
        for field, dataset in field_to_dataset)

    # `ds[()]` replaces `Dataset.value` (removed in h5py 3.0). h5py 3
    # returns variable-length strings as bytes, so decode the DSL type.
    dsl_type = data_file['data_info']['dsl_type'][()]
    if isinstance(dsl_type, bytes):
        dsl_type = dsl_type.decode('utf-8')
    dsl = get_KarelDSL(dsl_type=dsl_type, seed=123)

    output_list = []
    for e_id in output_file.keys():
        gt_program_intseq = data_file[e_id]['program'][()]
        e_out = output_file[e_id]

        kwargs = {field: e_out[dataset][()]
                  for field, dataset in field_to_dataset}
        if 'test_program_prediction' in e_out:
            kwargs.update((field, e_out[dataset][()])
                          for field, dataset in test_field_to_dataset)
        else:
            kwargs.update((field, None)
                          for field, _ in test_field_to_dataset)

        output_list.append(
            Output(id=e_id, gt_program=dsl.intseq2str(gt_program_intseq),
                   **kwargs))
    return output_list
print('Visualization is terminated') if __name__ == '__main__': args = GetArgument() try: data_file = h5py.File(args.data_hdf5, 'r') except: data_file = None print('Fail to read --data_hdf5: {}'.format(args.data_hdf5)) sys.exit() try: output_file = h5py.File(args.output_hdf5, 'r') except: output_file = None print('Fail to read --output_hdf5: {}'.format(args.output_hdf5)) sys.exit() output_list = ConstructOutputList(data_file, output_file) output = output_list[0] PrintUsage() output_dir = os.path.join(os.path.dirname(args.output_hdf5), 'inspect_output') dsl_type = data_file['data_info']['dsl_type'].value dsl = get_KarelDSL(dsl_type=dsl_type, seed=123) karel_world = karel.Karel_world()