예제 #1
0
def copy_episodes(indir, outdir, n):
    """Copy a random sample of at most ``n`` episode files from ``indir`` to ``outdir``.

    The path of each episode relative to ``indir`` is recreated under
    ``outdir`` (missing parent directories are created).

    Args:
        indir: Directory scanned with ``frame.episode_paths``.
        outdir: Destination root directory; a trailing separator is optional.
        n: Maximum number of episodes to copy (slice bound, so ``None``
            copies everything).

    Raises:
        ValueError: If ``frame.episode_paths`` returns a path that is not
            located inside ``indir``.
    """
    episode_paths = frame.episode_paths(indir)
    np.random.shuffle(episode_paths)
    episode_paths = episode_paths[:n]
    for p in tqdm.tqdm(episode_paths):
        # relpath + join instead of string slicing/concatenation: works
        # whether or not indir/outdir end with a separator, and rejects
        # sibling directories that merely share a string prefix.
        rel_path = os.path.relpath(p, indir)
        if rel_path == os.pardir or rel_path.startswith(os.pardir + os.sep):
            raise ValueError(
                "Episode path {} is not inside {}".format(p, indir))
        outfile = os.path.join(outdir, rel_path)
        out_parent = os.path.dirname(outfile)
        if out_parent:  # os.makedirs("") would raise FileNotFoundError
            os.makedirs(out_parent, exist_ok=True)
        shutil.copyfile(p, outfile)
예제 #2
0
def label_episodes(directory, classifier):
    """Run ``classifier`` over every episode file found under ``directory``.

    Each episode is passed individually to ``DataLoader.predict_episodes``
    with the prefix ``"frame/classifier_"``.

    An episode whose processing raises ``EOFError`` (presumably a truncated
    file) has its traceback printed, is reported, and is then DELETED from
    disk with ``os.remove``.

    Args:
        directory: Directory scanned with ``frame.episode_paths``.
        classifier: Object exposing ``hparams``; it is used to build the
            ``DataLoader`` and is passed to ``predict_episodes``.
    """
    episode_paths = frame.episode_paths(directory)
    data_loader = DataLoader(hparams=classifier.hparams)
    for episode_path in tqdm.tqdm(episode_paths):
        try:
            data_loader.predict_episodes(classifier, [episode_path],
                                         prefix="frame/classifier_")
        except EOFError:
            # print_exc() works on every Python 3 release; the one-argument
            # form print_exception(e) only exists from Python 3.10 and would
            # raise TypeError here on older interpreters.
            traceback.print_exc()
            print("Error reading {}".format(episode_path))
            os.remove(episode_path)
예제 #3
0
#!/usr/bin/env python3

import os
import sys

sys.path.insert(0, os.path.join(os.path.dirname(__file__), "humanrl"))

if __name__ == '__main__':
    import multiprocessing

    from humanrl import frame
    import tqdm  # used by the commented-out check loop below

    # Report how many episode files exist under logs/.
    episode_paths = frame.episode_paths("logs/")
    print(len(episode_paths))
    # frame.fix_episode(episode_paths[0])
    # Leave one core free, but never ask for 0 workers: Pool(0) raises
    # ValueError on a single-core machine. The context manager terminates
    # the workers on exit instead of leaking them.
    with multiprocessing.Pool(max(1, multiprocessing.cpu_count() - 1)) as pool:
        # for _ in tqdm.tqdm(
        #         pool.imap_unordered(frame.check_episode, episode_paths), total=len(episode_paths)):
        #     pass
        pass
# %timeit frame.fix_episode(episode_paths[0])
예제 #4
0
def main():
    """Run the interactive OpenCV frame viewer / labeller.

    Offline mode (``args.online`` falsy) iterates episodes loaded from
    ``args.frames_dir`` through ``Episodes`` and lets the user label frames,
    jump between labels, deaths, false positives and false negatives, save a
    video, and write the labels back with ``save_labels`` (unless
    ``args.dry_run``).

    Online mode waits for a ``multiprocessing`` ``Listener`` connection on
    ``localhost:6666`` and renders frames obtained from
    ``online_receive_frame``. Only keys in ``online_key_set`` are handled
    there.

    Key bindings:
        1 / 2 -- faster / slower
        3 -- pause
        4 / 5 -- back / forward
        7 / 8 -- previous / next false positive
        9 / 0 -- previous / next false negative
        u / i -- previous / next label
        o / p -- previous / next death
        c -- catastrophe label
        b -- block
        v -- save video
        r -- remove label
        s -- skip episode
        Esc -- exit
    """
    random.seed()
    args = parser.parse_args()

    print("Displaying video...")

    # Key codes, compared against cv2.waitKey(...) & 0xFF.
    K_FASTER = ord('1')
    K_SLOWER = ord('2')
    K_PAUSE = ord('3')
    K_BACK = ord('4')
    K_FWD = ord('5')

    K_PREV_FALSE_POSITIVE = ord('7')
    K_NEXT_FALSE_POSITIVE = ord('8')
    K_PREV_FALSE_NEGATIVE = ord('9')
    K_NEXT_FALSE_NEGATIVE = ord('0')

    K_PREV_LABEL = ord('u')
    K_NEXT_LABEL = ord('i')
    K_PREV_DEATH = ord('o')
    K_NEXT_DEATH = ord('p')

    K_CATASTROPHE = ord('c')
    K_BLOCK = ord('b')
    K_SAVE_VIDEO = ord('v')
    K_ESC = ord('\x1b')  # esc
    K_SKIP = ord('s')
    K_REMOVE_LABEL = ord('r')
    # cv2.waitKey returns -1 when no key is pressed; masked like k below.
    K_NONE = -1 & 0xFF

    # Keys handled in online mode; any other key only prints a warning there.
    online_key_set = frozenset(
        [K_ESC, K_BLOCK, K_FASTER, K_SLOWER, K_NONE, K_PAUSE, K_FWD])

    # setup for offline mode
    if not args.online:
        episode_paths = frame_module.episode_paths(args.frames_dir)
        print("Episodes: {}".format(len(episode_paths)))
        if args.reversed:
            # Keep a list: random.shuffle below raises TypeError when given
            # the iterator returned by reversed().
            episode_paths = list(reversed(episode_paths))
        if args.random:
            random.shuffle(episode_paths)

    if args.output_dir:
        os.makedirs(args.output_dir, exist_ok=True)

    episodes_iter = None
    viewer_state = None
    if args.online:
        print('In online mode...')
        # listen to port for frames
        print('Waiting for connection...')
        address = ('localhost', 6666)
        listener = Listener(address, authkey=b'no-catastrophes-allowed')
        conn = listener.accept()
        print('Connection accepted from {}'.format(listener.last_accepted))

        viewer_state = ViewerState(conn=conn, output_dir=args.output_dir)
        episodes_iter = OnlineEpisodes()
    else:
        transform_func = identity
        if args.reversed:
            transform_func = reversed
        if args.random:
            transform_func = shuffle
        episodes_iter = Episodes(args.frames_dir, transform_func)
        viewer_state = ViewerState(output_dir=args.output_dir)

    # The third value yielded by episodes_iter is written directly into
    # viewer_state.frame_index by the loop target.
    for episode_path, episode, viewer_state.frame_index in episodes_iter:
        if episode is None:
            print('Error: no episode')
            continue
        if viewer_state.EXIT:
            break

        if not args.online:

            labels_check = np.array(
                [frame.get_label() == 'b' for frame in episode.frames])
            if any(labels_check):
                print('FOUND SOME LABELS ------------------')

            # set up output_dir for episode
            viewer_state.output_filename, viewer_state.action_set, viewer_state.skip_episode = \
                setup_for_ep(episode_path, episode, output_dir=args.output_dir, dry_run=args.dry_run)

            # Navigation tables exist only offline; the keys that read them
            # are not in online_key_set, so online mode never touches them.
            false_negative_loss, \
            false_positive_loss, \
            false_negative_ind, \
            false_positive_ind, \
            false_negative_pos, \
            false_positive_pos, \
            death_ind, \
            death_pos = offline_episode_feedback_setup(args.label_mode, episode.frames)

            viewer_state.episode_num_in_session = episode.info.get(
                'episode_num', 0)
            viewer_state.env_id = args.env_id if args.env_id else None

            print(viewer_state.episode_num_in_session)

        viewer_state.paused = args.pause
        viewer_state.reset_for_episode()
        frames = episode.frames
        last_status = None
        while not viewer_state.EXIT:
            viewer_state.frame_index = min(viewer_state.frame_index,
                                           len(frames) - 1)

            if args.online:
                if (viewer_state.current_frame is None
                        or not viewer_state.paused
                        or viewer_state.was_advanced):
                    viewer_state.reset_for_frame()  # todo - check if works for offline
                    viewer_state, episode = online_receive_frame(
                        viewer_state=viewer_state, episode=episode)
            else:
                viewer_state.current_frame = frames[viewer_state.frame_index]

            if viewer_state.skip_frame:
                continue
            if viewer_state.skip_episode or viewer_state.close:
                break

            # Only log when the (index, proposed action, advanced) triple changes.
            status = (viewer_state.frame_index,
                      viewer_state.current_frame.get_proposed_action(),
                      viewer_state.was_advanced)
            if status != last_status:
                print(status)
                last_status = status

            img = render(viewer_state.current_frame,
                         viewer_state.action_set,
                         viewer_state.env_id,
                         extra_text=[
                             "Episode: {}, Frame: {}".format(
                                 viewer_state.episode_num_in_session,
                                 viewer_state.frame_index)
                         ],
                         image_scale=args.image_scale,
                         prev_actions=viewer_state.prev_actions)

            cv2.imshow('frame', img)
            k = cv2.waitKey(viewer_state.delay) & 0xFF
            if args.online and k not in online_key_set:
                print(
                    "Key '{}' not supported in online mode; "
                    "if you want to support it, add it to the keyset.".format(
                        chr(k)))
                if not viewer_state.paused:
                    viewer_state.advance(online=args.online)
            elif k == K_ESC:  # esc key
                viewer_state.EXIT = True
                if args.online:
                    print('Exiting...')
                    print('Killing A3C...')
                    call(['tmux', 'kill-session', '-t', 'a3c'])
            elif k == K_CATASTROPHE and args.label_mode == "catastrophe":  # 'c'
                print('catastrophe!')
                viewer_state.current_frame.set_label("c")
                viewer_state.save = True
                viewer_state.advance(online=args.online)
            elif k == K_BLOCK and args.label_mode == "block":  # 'b'
                print('blocking action')
                proposed_action = viewer_state.current_frame.get_proposed_action()
                safe_action = get_safe_action(proposed_action, args)
                if args.blocking_mode == "action_replacement" and proposed_action == safe_action:
                    print('not blocking, action is safe')
                else:
                    viewer_state.current_frame.set_label("b")
                    viewer_state.current_frame.set_real_action(safe_action)
                    if args.online:
                        viewer_state.feedback_msg = {
                            'feedback': 'b',
                            'action':
                            viewer_state.current_frame.get_real_action()
                        }
                    viewer_state.save = True
                viewer_state.advance(online=args.online)
            elif k == K_FASTER:
                viewer_state.delay -= 50
                viewer_state.delay = max(1, viewer_state.delay)
                print('New delay: {}'.format(viewer_state.delay))
                if not viewer_state.paused:
                    viewer_state.advance(online=args.online)
            elif k == K_REMOVE_LABEL and not args.dry_run:
                print('removing catastrophe label...')
                viewer_state.current_frame.set_label(None)
                viewer_state.save = True
            elif k == K_SLOWER:
                viewer_state.delay += 50
                print('New delay: {}'.format(viewer_state.delay))
                if not viewer_state.paused:
                    viewer_state.advance(online=args.online)
            elif k == K_BACK:
                viewer_state.frame_index = max(0, viewer_state.frame_index - 1)
            elif k == K_FWD:
                viewer_state.advance(online=args.online)
            elif k == K_PREV_FALSE_POSITIVE and len(false_positive_ind) > 0:
                false_positive_pos = (false_positive_pos -
                                      1) % len(false_positive_ind)
                viewer_state.frame_index = false_positive_ind[
                    false_positive_pos]
            elif k == K_NEXT_FALSE_POSITIVE and len(false_positive_ind) > 0:
                false_positive_pos = (false_positive_pos +
                                      1) % len(false_positive_ind)
                viewer_state.frame_index = false_positive_ind[
                    false_positive_pos]
            elif k == K_PREV_FALSE_NEGATIVE and len(false_negative_ind) > 0:
                false_negative_pos = (false_negative_pos -
                                      1) % len(false_negative_ind)
                viewer_state.frame_index = false_negative_ind[
                    false_negative_pos]
            elif k == K_NEXT_FALSE_NEGATIVE and len(false_negative_ind) > 0:
                false_negative_pos = (false_negative_pos +
                                      1) % len(false_negative_ind)
                viewer_state.frame_index = false_negative_ind[
                    false_negative_pos]
            elif k == K_NEXT_LABEL:
                for j in range(len(frames))[viewer_state.frame_index + 1:]:
                    if frames[j].get_label() is not None:
                        # Print only the index; printing `frames` dumped the
                        # whole episode to stdout on every jump.
                        print(j)
                        viewer_state.frame_index = j
                        viewer_state.was_advanced = True
                        break
            elif k == K_PREV_LABEL:
                for j in reversed(
                        range(len(frames))[:viewer_state.frame_index]):
                    if frames[j].get_label() is not None:
                        viewer_state.frame_index = j
                        viewer_state.was_advanced = True
                        break
            elif k == K_PREV_DEATH and len(death_ind) > 0:
                death_pos = (death_pos - 1) % len(death_ind)
                viewer_state.frame_index = death_ind[death_pos]
            elif k == K_NEXT_DEATH and len(death_ind) > 0:
                death_pos = (death_pos + 1) % len(death_ind)
                viewer_state.frame_index = death_ind[death_pos]
            elif k == K_SKIP:
                viewer_state.skip_episode = True
            elif k == K_PAUSE:
                viewer_state.paused = not viewer_state.paused
                if not viewer_state.paused:
                    viewer_state.advance(online=args.online)
            elif k == K_SAVE_VIDEO:
                if viewer_state.output_filename is None:
                    # Same guard as the label-saving path after this loop:
                    # without an output filename there is nowhere to write.
                    print('No output filename for this episode; not saving video.')
                else:
                    video_output_filename = viewer_state.output_filename + ".avi"
                    fourcc = cv2.VideoWriter_fourcc(*'MJPG')
                    # Render one frame first to learn the output image size.
                    img = render(frames[0], viewer_state.action_set,
                                 viewer_state.env_id)
                    writer = cv2.VideoWriter(video_output_filename, fourcc, 20.0,
                                             (img.shape[1], img.shape[0]))
                    print('Saving video file as {}'.format(video_output_filename))
                    try:
                        for i, f in enumerate(frames):
                            img = render(f,
                                         viewer_state.action_set,
                                         viewer_state.env_id,
                                         extra_text=[
                                             "Episode: {}, Frame: {}".format(
                                                 viewer_state.episode_num_in_session,
                                                 i)
                                         ])
                            writer.write(img)
                            cv2.imshow('frame', img)
                    finally:
                        # Release the writer even if rendering fails midway.
                        writer.release()
            else:
                if not viewer_state.paused:
                    viewer_state.advance(online=args.online)

        if viewer_state.save and not args.dry_run:
            if viewer_state.output_filename is not None:
                print('Saving to file: {}'.format(
                    viewer_state.output_filename))
                save_labels(
                    filename=viewer_state.output_filename,
                    episode=episode,
                    frames=frames)
    viewer_state.EXIT = True
    cv2.destroyAllWindows()
예제 #5
0
 def __init__(self, episodes_dir, transform_func=identity):
     """Collect the episode file paths found under ``episodes_dir``.

     Args:
         episodes_dir: Directory scanned with ``frame_module.episode_paths``.
             A falsy value (``None``, ``""``) leaves ``episode_paths`` empty.
         transform_func: Applied to the discovered path list before it is
             stored (for example a reordering function).
     """
     if not episodes_dir:
         self.episode_paths = []
         return
     discovered = frame_module.episode_paths(episodes_dir)
     self.episode_paths = transform_func(discovered)
예제 #6
0
    def _step(self, action):
        """Step the wrapped env and flag actions ``should_block`` rejects.

        ``should_block`` is evaluated on the observation that preceded
        ``action`` (``self.last_obs``), not on the newly returned one. When
        it is truthy, ``info["frame/should_have_blocked"]`` is set to
        ``True``.

        Returns:
            The ``(obs, reward, done, info)`` tuple from ``self.env.step``.
        """
        obs, reward, done, info = self.env.step(action)
        previous_obs = self.last_obs
        blocked = should_block(previous_obs, action)
        if blocked:
            info["frame/should_have_blocked"] = True
        # Remember this observation for the next call's check.
        self.last_obs = obs
        return obs, reward, done, info


if __name__ == "__main__":
    from IPython import get_ipython  # isort:ignore
    from matplotlib import pyplot as plt  # isort:ignore

    # Exploratory script: for the first recorded SpaceInvadersRandom episode,
    # show every frame where should_block fires, plus the frames after it.
    ipython = get_ipython()
    # get_ipython() returns None outside an IPython session; only switch to
    # inline plotting when a shell exists. run_line_magic replaces the
    # deprecated InteractiveShell.magic.
    if ipython is not None:
        ipython.run_line_magic("matplotlib", "inline")

    eps = frame.episode_paths("logs/SpaceInvadersRandom")
    ep = frame.load_episode(eps[0])
    f = np.copy(ep.frames[50].image)
    barrier_damage(f)
    for i, current in enumerate(ep.frames):
        if should_block(current.image, current.action):
            print(i, is_ship_below_barrier(current.image))
            # Rows 155:195 of the flagged frame and up to two following frames
            # (presumably the ship/barrier band -- confirm). Slicing stops at
            # the episode end instead of raising IndexError like the old
            # ep.frames[i + 1] / ep.frames[i + 2] lookups.
            for shown in ep.frames[i:i + 3]:
                plt.imshow(shown.image[155:195, :, :])
                plt.show()


if __name__ == "__main__":
    # Parse CLI options and list every episode file under --frames_dir.
    args = parser.parse_args()
    episode_paths = frame.episode_paths(args.frames_dir)
    print(episode_paths)

    # Build the blocker and load its training data inside a fresh
    # TensorFlow graph rather than the process-wide default graph.
    graph = tf.Graph()
    with graph.as_default():
        blocker = CatastropheBlockerTensorflow(block_radius=1)
        data = blocker.load_data_for_training(episode_paths, 1, 1, 1)
예제 #8
0
                        default="labels/SpaceInvadersOnline")
    parser.add_argument('--steps', type=int, default=20000)
    args = parser.parse_args()

    common_hparams = dict(
        use_action=True,
        batch_size=512,
        input_processes=8,
        image_shape=[160, 160, 3],
        image_crop_region=(
            (0, 200),
            (0, 160),
        ),
    )

    paths = frame.episode_paths(args.input_dir)

    data_loader = DataLoader(
        hparams=TensorflowClassifierHparams(**common_hparams),
        labeller=HumanOnlineBlockerLabeller())

    lc, ac = action_label_counts(args.input_dir, data_loader, 6)
    print(lc, ac)
    common_hparams["expected_positive_weight"] = float(lc[1]) / float(sum(lc))
    print(common_hparams["expected_positive_weight"])

    datasets = paths, paths, []
    # datasets = data_loader.split_episodes(paths, 3, 1, 0, use_all=True, seed=42)

    if args.logdir == "":
        logdir = get_unused_logdir("models/tmp/spaceinvaders/online/blocker")