def copy_episodes(indir, outdir, n):
    """Copy a random selection of at most ``n`` episode files into ``outdir``.

    The episode paths found under ``indir`` are shuffled in place and the
    first ``n`` are copied. Each destination path is ``outdir`` followed by
    the part of the source path that comes after the ``indir`` prefix, so the
    layout below ``indir`` is kept. Missing destination directories are
    created.
    """
    candidates = frame.episode_paths(indir)
    np.random.shuffle(candidates)
    prefix_len = len(indir)
    for src in tqdm.tqdm(candidates[:n]):
        # Every path must live under ``indir`` for the prefix swap to be valid.
        assert src.startswith(indir), src
        dst = outdir + src[prefix_len:]
        dst_dir = os.path.dirname(dst)
        os.makedirs(dst_dir, exist_ok=True)
        shutil.copyfile(src, dst)
def label_episodes(directory, classifier):
    """Run ``classifier`` over every episode file found under ``directory``.

    Episodes are handed to ``DataLoader.predict_episodes`` one at a time with
    the prefix ``"frame/classifier_"``. If processing an episode raises
    ``EOFError``, the traceback is printed, the path is reported, and the
    episode file is deleted from disk.

    Args:
        directory: Directory passed to ``frame.episode_paths``.
        classifier: Classifier object; its ``hparams`` configure the
            ``DataLoader``.
    """
    episode_paths = frame.episode_paths(directory)
    data_loader = DataLoader(hparams=classifier.hparams)
    for episode_path in tqdm.tqdm(episode_paths):
        try:
            data_loader.predict_episodes(
                classifier, [episode_path], prefix="frame/classifier_")
        except EOFError:
            # print_exc() prints the exception currently being handled and
            # works on every Python 3 release; the single-argument form
            # traceback.print_exception(e) is only accepted on 3.10+ and
            # raises TypeError before that, which would abort the cleanup.
            traceback.print_exc()
            print("Error reading {}".format(episode_path))
            os.remove(episode_path)
#!/usr/bin/env python3
# Script: list the episode files under "logs/" and print how many there are.
# The per-episode check/fix calls are left commented out below.
import os
import sys

# Make the bundled "humanrl" directory importable when run from the repo root.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "humanrl"))

if __name__ == '__main__':
    import multiprocessing

    from humanrl import frame
    import tqdm

    episode_paths = frame.episode_paths("logs/")
    print(len(episode_paths))
    # frame.fix_episode(episode_paths[0])

    # Leave one core free, but never ask for zero workers: Pool(0) raises
    # ValueError on a single-core machine.
    worker_count = max(1, multiprocessing.cpu_count() - 1)
    # The context manager terminates the workers on exit so the interpreter
    # does not depend on garbage collection to clean the pool up.
    with multiprocessing.Pool(worker_count) as pool:
        # for _ in tqdm.tqdm(
        #         pool.imap_unordered(frame.check_episode, episode_paths),
        #         total=len(episode_paths)):
        #     pass
        pass
    # %timeit frame.fix_episode(episode_paths[0])
def main():
    """Interactive OpenCV viewer for stepping through and labelling episodes.

    Parses the command line with the module-level ``parser`` and then runs in
    one of two modes:

    * online (``args.online``): accepts one connection on localhost:6666 and
      iterates ``OnlineEpisodes``; only the keys in ``online_key_set`` are
      handled.
    * offline: iterates ``Episodes(args.frames_dir, transform_func)`` where the
      transform is ``identity``, ``reversed`` or ``shuffle`` depending on
      ``args.reversed`` / ``args.random`` (``random`` wins when both are set).

    Key bindings (see the ``K_*`` constants): ``1``/``2`` faster/slower,
    ``3`` pause toggle, ``4``/``5`` back/forward, ``7``/``8`` previous/next
    false positive, ``9``/``0`` previous/next false negative, ``u``/``i``
    previous/next labelled frame, ``o``/``p`` previous/next death, ``c`` label
    catastrophe, ``b`` label block, ``r`` remove label, ``s`` skip episode,
    ``v`` save video, ``Esc`` exit.

    Labels are written with ``save_labels`` after an episode when the viewer
    state is marked for saving, ``args.dry_run`` is false, and an output
    filename is set.

    NOTE(review): the nesting of this function was reconstructed from a
    whitespace-collapsed source; places where more than one nesting was
    syntactically possible are marked with NOTE(review) below.
    """
    random.seed()
    args = parser.parse_args()
    print("Displaying video...")

    K_FASTER = ord('1')
    K_SLOWER = ord('2')
    K_PAUSE = ord('3')
    K_BACK = ord('4')
    K_FWD = ord('5')
    K_PREV_FALSE_POSITIVE = ord('7')
    K_NEXT_FALSE_POSITIVE = ord('8')
    K_PREV_FALSE_NEGATIVE = ord('9')
    K_NEXT_FALSE_NEGATIVE = ord('0')
    K_PREV_LABEL = ord('u')
    K_NEXT_LABEL = ord('i')
    K_PREV_DEATH = ord('o')
    K_NEXT_DEATH = ord('p')
    K_CATASTROPHE = ord('c')
    K_BLOCK = ord('b')
    K_SAVE_VIDEO = ord('v')
    K_ESC = ord('\x1b')  # esc
    K_SKIP = ord('s')
    K_REMOVE_LABEL = ord('r')
    # Same masking as applied to the cv2.waitKey result below, i.e. 255.
    K_NONE = -1 & 0xFF
    online_key_set = frozenset(
        [K_ESC, K_BLOCK, K_FASTER, K_SLOWER, K_NONE, K_PAUSE, K_FWD])

    # setup for offline mode
    if not args.online:
        episode_paths = frame_module.episode_paths(args.frames_dir)
        print("Episodes: {}".format(len(episode_paths)))
        if args.reversed:
            # Materialise the reversed view: random.shuffle below needs a
            # mutable sequence and raises TypeError on a bare reversed()
            # iterator when both --reversed and --random are given.
            episode_paths = list(reversed(episode_paths))
        if args.random:
            random.shuffle(episode_paths)

    # NOTE(review): placed at function level (applies to both modes) - confirm.
    if args.output_dir:
        os.makedirs(args.output_dir, exist_ok=True)

    episodes_iter = None
    viewer_state = None
    if args.online:
        print('In online mode...')
        # listen to port for frames
        print('Waiting for connection...')
        address = ('localhost', 6666)
        listener = Listener(address, authkey=b'no-catastrophes-allowed')
        conn = listener.accept()
        print('Connection accepted from {}'.format(listener.last_accepted))
        viewer_state = ViewerState(conn=conn, output_dir=args.output_dir)
        episodes_iter = OnlineEpisodes()
    else:
        transform_func = identity
        if args.reversed:
            transform_func = reversed
        if args.random:
            transform_func = shuffle
        episodes_iter = Episodes(args.frames_dir, transform_func)
        viewer_state = ViewerState(output_dir=args.output_dir)

    # The iterator's third item is assigned straight onto the viewer state as
    # the starting frame index for the episode.
    for episode_path, episode, viewer_state.frame_index in episodes_iter:
        if episode is None:
            print('Error: no episode')
            continue
        if viewer_state.EXIT:
            break
        if not args.online:
            labels_check = np.array(
                [frame.get_label() == 'b' for frame in episode.frames])
            if any(labels_check):
                print('FOUND SOME LABELS ------------------')

        # set up output_dir for episode
        # NOTE(review): this setup is assumed to run in both modes - confirm.
        (viewer_state.output_filename,
         viewer_state.action_set,
         viewer_state.skip_episode) = setup_for_ep(
             episode_path, episode, output_dir=args.output_dir, dry_run=args.dry_run)

        (false_negative_loss,
         false_positive_loss,
         false_negative_ind,
         false_positive_ind,
         false_negative_pos,
         false_positive_pos,
         death_ind,
         death_pos) = offline_episode_feedback_setup(args.label_mode, episode.frames)

        viewer_state.episode_num_in_session = episode.info.get(
            'episode_num', 0)
        viewer_state.env_id = args.env_id if args.env_id else None
        print(viewer_state.episode_num_in_session)
        viewer_state.paused = args.pause
        viewer_state.reset_for_episode()
        frames = episode.frames
        last_status = None

        while not viewer_state.EXIT:
            # Clamp so that advancing past the last frame stays on it.
            viewer_state.frame_index = min(viewer_state.frame_index,
                                           len(frames) - 1)
            if args.online:
                if (viewer_state.current_frame is None or
                        not viewer_state.paused or
                        viewer_state.was_advanced):
                    viewer_state.reset_for_frame(
                    )  # todo - check if works for offline
                    viewer_state, episode = online_receive_frame(
                        viewer_state=viewer_state, episode=episode)
            else:
                viewer_state.current_frame = frames[viewer_state.frame_index]
            if viewer_state.skip_frame:
                continue
            if viewer_state.skip_episode or viewer_state.close:
                break

            # Only print the status line when something changed.
            status = (viewer_state.frame_index,
                      viewer_state.current_frame.get_proposed_action(),
                      viewer_state.was_advanced)
            if status != last_status:
                print(status)
                last_status = status

            img = render(viewer_state.current_frame,
                         viewer_state.action_set,
                         viewer_state.env_id,
                         extra_text=[
                             "Episode: {}, Frame: {}".format(
                                 viewer_state.episode_num_in_session,
                                 viewer_state.frame_index)
                         ],
                         image_scale=args.image_scale,
                         prev_actions=viewer_state.prev_actions)
            cv2.imshow('frame', img)
            k = cv2.waitKey(viewer_state.delay) & 0xFF

            if args.online and k not in online_key_set:
                print(
                    "Key '{}' not supported in online mode; "
                    "if you want to support it, add it to the keyset.".format(
                        chr(k)))
                if not viewer_state.paused:
                    viewer_state.advance(online=args.online)
            elif k == K_ESC:  # esc key
                viewer_state.EXIT = True
                # NOTE(review): all three statements assumed online-only - confirm.
                if args.online:
                    print('Exiting...')
                    print('Killing A3C...')
                    call(['tmux', 'kill-session', '-t', 'a3c'])
                # if save:
                #     print("Writing episode {} to {}".format(episode_num, args.frames_dir))
                #     save_labels(directory=args.frames_dir, episode=episode, episode_num=episode_num, frames=frames)
                # cv2.destroyAllWindows()
            elif k == K_CATASTROPHE and args.label_mode == "catastrophe":  # 'c'
                print('catastrophe!')
                viewer_state.current_frame.set_label("c")
                viewer_state.save = True
                viewer_state.advance(online=args.online)
            elif k == K_BLOCK and args.label_mode == "block":  # 'b'
                print('blocking action')
                proposed_action = viewer_state.current_frame.get_proposed_action(
                )
                safe_action = get_safe_action(
                    viewer_state.current_frame.get_proposed_action(), args)
                if args.blocking_mode == "action_replacement" and proposed_action == safe_action:
                    print('not blocking, action is safe')
                else:
                    viewer_state.current_frame.set_label("b")
                    viewer_state.current_frame.set_real_action(safe_action)
                    if args.online:
                        viewer_state.feedback_msg = {
                            'feedback': 'b',
                            'action':
                            viewer_state.current_frame.get_real_action()
                        }
                    # NOTE(review): save flag assumed to be set whenever a
                    # block label is applied (not only online) - confirm.
                    viewer_state.save = True
                # NOTE(review): advance assumed unconditional for this key,
                # mirroring the catastrophe branch - confirm.
                viewer_state.advance(online=args.online)
            elif k == K_FASTER:
                viewer_state.delay -= 50
                viewer_state.delay = max(1, viewer_state.delay)
                print('New delay: {}'.format(viewer_state.delay))
                if not viewer_state.paused:
                    viewer_state.advance(online=args.online)
            elif k == K_REMOVE_LABEL and not args.dry_run:
                print('removing catastrophe label...')
                viewer_state.current_frame.set_label(None)
                viewer_state.save = True
            elif k == K_SLOWER:
                viewer_state.delay += 50
                print('New delay: {}'.format(viewer_state.delay))
                if not viewer_state.paused:
                    viewer_state.advance(online=args.online)
            elif k == K_BACK:
                viewer_state.frame_index = max(0, viewer_state.frame_index - 1)
            elif k == K_FWD:
                viewer_state.advance(online=args.online)
            elif k == K_PREV_FALSE_POSITIVE and len(false_positive_ind) > 0:
                false_positive_pos = (false_positive_pos - 1) % len(false_positive_ind)
                viewer_state.frame_index = false_positive_ind[
                    false_positive_pos]
            elif k == K_NEXT_FALSE_POSITIVE and len(false_positive_ind) > 0:
                false_positive_pos = (false_positive_pos + 1) % len(false_positive_ind)
                viewer_state.frame_index = false_positive_ind[
                    false_positive_pos]
            elif k == K_PREV_FALSE_NEGATIVE and len(false_negative_ind) > 0:
                false_negative_pos = (false_negative_pos - 1) % len(false_negative_ind)
                viewer_state.frame_index = false_negative_ind[
                    false_negative_pos]
            elif k == K_NEXT_FALSE_NEGATIVE and len(false_negative_ind) > 0:
                false_negative_pos = (false_negative_pos + 1) % len(false_negative_ind)
                viewer_state.frame_index = false_negative_ind[
                    false_negative_pos]
            elif k == K_NEXT_LABEL:
                # Jump forward to the nearest later frame that has a label.
                for j in range(len(frames))[viewer_state.frame_index + 1:]:
                    if frames[j].get_label() is not None:
                        print(j, frames)
                        viewer_state.frame_index = j
                        viewer_state.was_advanced = True
                        break
            elif k == K_PREV_LABEL:
                # Jump back to the nearest earlier frame that has a label.
                for j in reversed(
                        range(len(frames))[:viewer_state.frame_index]):
                    if frames[j].get_label() is not None:
                        viewer_state.frame_index = j
                        viewer_state.was_advanced = True
                        break
            elif k == K_PREV_DEATH and len(death_ind) > 0:
                death_pos = (death_pos - 1) % len(death_ind)
                viewer_state.frame_index = death_ind[death_pos]
            elif k == K_NEXT_DEATH and len(death_ind) > 0:
                death_pos = (death_pos + 1) % len(death_ind)
                viewer_state.frame_index = death_ind[death_pos]
            elif k == K_SKIP:
                viewer_state.skip_episode = True
            elif k == K_PAUSE:
                viewer_state.paused = not viewer_state.paused
                if not viewer_state.paused:
                    viewer_state.advance(online=args.online)
            elif k == K_SAVE_VIDEO:
                video_output_filename = viewer_state.output_filename + ".avi"
                fourcc = cv2.VideoWriter_fourcc(*'MJPG')
                # Render the first frame once just to learn the output size.
                img = render(frames[0], viewer_state.action_set,
                             viewer_state.env_id)
                writer = cv2.VideoWriter(video_output_filename, fourcc, 20.0,
                                         (img.shape[1], img.shape[0]))
                print('Saving video file as {}'.format(video_output_filename))
                # out = cv2.VideoWriter(output_filename, fourcc, 20.0, img.shape)
                for i, f in enumerate(frames):
                    img = render(f,
                                 viewer_state.action_set,
                                 viewer_state.env_id,
                                 extra_text=[
                                     "Episode: {}, Frame: {}".format(
                                         viewer_state.episode_num_in_session,
                                         i)
                                 ])
                    writer.write(img)
                    # NOTE(review): imshow assumed to be inside the loop - confirm.
                    cv2.imshow('frame', img)
                writer.release()
            else:
                if not viewer_state.paused:
                    viewer_state.advance(online=args.online)

        # NOTE(review): saving assumed to happen once per episode, after the
        # frame loop ends - confirm.
        if viewer_state.save and not args.dry_run:
            if viewer_state.output_filename is not None:
                print('Saving to file: {}'.format(
                    viewer_state.output_filename))
                save_labels(
                    filename=viewer_state.output_filename,
                    episode=episode,
                    # episode_num=viewer_state.episode_num_in_session,
                    frames=frames)

    viewer_state.EXIT = True
    cv2.destroyAllWindows()
def __init__(self, episodes_dir, transform_func=identity):
    """Collect the episode paths to iterate over.

    ``self.episode_paths`` starts as an empty list. When ``episodes_dir`` is
    truthy it is replaced by the result of applying ``transform_func`` to
    ``frame_module.episode_paths(episodes_dir)``.
    """
    self.episode_paths = []
    if not episodes_dir:
        return
    discovered = frame_module.episode_paths(episodes_dir)
    self.episode_paths = transform_func(discovered)
def _step(self, action):
    """Step the wrapped env and flag steps on which ``should_block`` fires.

    ``should_block`` is evaluated on the observation stored from the previous
    step (``self.last_obs``) together with ``action``. When it is truthy,
    ``info["frame/should_have_blocked"]`` is set to True. The new observation
    is then stored in ``self.last_obs``.

    Returns:
        The ``(obs, reward, done, info)`` tuple from ``self.env.step``.
    """
    obs, reward, done, info = self.env.step(action)
    if should_block(self.last_obs, action):
        info["frame/should_have_blocked"] = True
    self.last_obs = obs
    return obs, reward, done, info


if __name__ == "__main__":
    # Debug script meant for an IPython/Jupyter session: show the image strip
    # (rows 155-195) of every frame where should_block fires, plus the frames
    # that follow it.
    from IPython import get_ipython  # isort:ignore
    from matplotlib import pyplot as plt  # isort:ignore
    ipython = get_ipython()
    ipython.magic("matplotlib inline")

    eps = frame.episode_paths("logs/SpaceInvadersRandom")
    ep = frame.load_episode(eps[0])
    s = set()
    f = np.copy(ep.frames[50].image)
    barrier_damage(f)
    frame_count = len(ep.frames)
    for i in range(0, frame_count):
        if should_block(ep.frames[i].image, ep.frames[i].action):
            print(i, is_ship_below_barrier(ep.frames[i].image))
            # Show frame i and up to two following frames. The upper bound is
            # clamped so a hit on one of the last two frames does not index
            # past the end of the episode.
            for j in range(i, min(i + 3, frame_count)):
                plt.imshow(ep.frames[j].image[155:195, :, :])
                plt.show()
    # for i in range(f.shape[0]):
# imsize=42*42 # observation = x[2:imsize+2].reshape([42,42]) # observation2 = x[imsize+2:].reshape([42,42]) # print(observation.shape) # Plot the grid x = x.reshape(42, 42) plt.imshow(x) plt.gray() #plt.show() plt.savefig('/tmp/catastrophe/frame_{}.png'.format(i)) if __name__ == "__main__": args = parser.parse_args() episode_paths = frame.episode_paths(args.frames_dir) # g = tf.Graph() # with g.as_default(): # classifier = CatastropheClassifierTensorflow() # data = classifier.load_data_for_training(episode_paths, 1, 1, 1) # sess = tf.Session(graph=g) # with sess.as_default(): # classifier.fit(*data, steps=10) # # %mkdir -p "/tmp/foo/0.ckpt" # classifier.save(checkpoint_name="/tmp/foo/classifier/0.ckpt") print(episode_paths) g = tf.Graph() with g.as_default(): blocker = CatastropheBlockerTensorflow(block_radius=1) data = blocker.load_data_for_training(episode_paths, 1, 1, 1)
default="labels/SpaceInvadersOnline") parser.add_argument('--steps', type=int, default=20000) args = parser.parse_args() common_hparams = dict( use_action=True, batch_size=512, input_processes=8, image_shape=[160, 160, 3], image_crop_region=( (0, 200), (0, 160), ), ) paths = frame.episode_paths(args.input_dir) data_loader = DataLoader( hparams=TensorflowClassifierHparams(**common_hparams), labeller=HumanOnlineBlockerLabeller()) lc, ac = action_label_counts(args.input_dir, data_loader, 6) print(lc, ac) common_hparams["expected_positive_weight"] = float(lc[1]) / float(sum(lc)) print(common_hparams["expected_positive_weight"]) datasets = paths, paths, [] # datasets = data_loader.split_episodes(paths, 3, 1, 0, use_all=True, seed=42) if args.logdir == "": logdir = get_unused_logdir("models/tmp/spaceinvaders/online/blocker")