Example no. 1
def simulate_policy(args):
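    # Load the rlkit snapshot and pull out the policy used for evaluation.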
    data = torch.load(args.file)
    policy = data['evaluation/policy']

    if args.gpu:
        ptu.set_gpu_mode(True)
        policy.cuda()
        print("set gpu")
    print(ptu.device)

    config_file = get_config_file(args.config_file)
    env = NormalizedBoxEnv(
        load_env(args, config_file, args.env_mode, ptu.device.index))

    print("Policy loaded")

    while True:
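        # One episode per iteration; env diagnostics are logged after each rollout.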
        path = rollout(
            env,
            policy,
            max_path_length=args.H,
            render=False,
        )
        if hasattr(env, "log_diagnostics"):
            env.log_diagnostics([path])
        logger.dump_tabular()
Example no. 2
def simulate_policy(args):
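    # Load the snapshot, run one rendered rollout, and record it to video.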
    #   data = joblib.load(args.file)
    data = torch.load(args.file)
    policy = data['evaluation/policy']
    env = NormalizedBoxEnv(gym.make("BipedalWalker-v2"))
    print("Policy loaded")
    if args.gpu:
        set_gpu_mode(True)
        policy.cuda()

    import cv2
    import os

    # MJPG writer at 30 FPS; the frame size must match the rendered images.
    video = cv2.VideoWriter('ppo_test.avi',
                            cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), 30,
                            (640, 480))
    os.makedirs("frames/ppo_test", exist_ok=True)  # cv2.imwrite fails silently otherwise
    path = rollout(
        env,
        policy,
        max_path_length=args.H,
        render=True,
    )
    if hasattr(env, "log_diagnostics"):
        env.log_diagnostics([path])
    logger.dump_tabular()

    for i, img in enumerate(path['images']):
        # Rendered frames are RGB; OpenCV expects BGR, so reverse the channels.
        frame = img[:, :, ::-1].astype(np.uint8)
        video.write(frame)
        cv2.imwrite("frames/ppo_test/%06d.png" % i, frame)

    video.release()
    print("wrote video")
Example no. 3
def simulate_policy(args):
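    # Load the snapshot; when args.collect is set, gather expert transitions
    # instead of rendering.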
    data = torch.load(str(args.file))
    #data = joblib.load(str(args.file))
    policy = data['evaluation/policy']
    env = NormalizedBoxEnv(HalfCheetahEnv())
    #env = data['evaluation/env']
    print("Policy loaded")
    if args.gpu:
        set_gpu_mode(True)
        policy.cuda()

    if args.collect:
        data = []  # reusing `data` shadows the snapshot, which is no longer needed
    for trial in tqdm(range(100)):
        path = rollout(
            env,
            policy,
            max_path_length=args.H + 1,
            render=not args.collect,
        )
        if hasattr(env, "log_diagnostics"):
            env.log_diagnostics([path])
        logger.dump_tabular()
        if args.collect:
            data.append([path['actions'], path['next_observations']])

    # Persist the collected (action, next_observation) pairs for later reuse.
    if args.collect:
        import os
        import pickle
        os.makedirs("data", exist_ok=True)  # ensure the output directory exists
        with open("data/expert.pkl", mode='wb') as f:
            pickle.dump(data, f)
Example no. 4
def simulate_policy(args):
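    # Load the rlkit snapshot and extract the skill-conditioned policy.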
    # data = joblib.load(args.file)
    data = torch.load(args.file)
    policy = data['evaluation/policy']
    env = NormalizedBoxEnv(Mani2dEnv())
    # env.reset()
    # print(env.step(env.action_space.sample()))
    # sys.exit()
    # env = env.wrapped_env.unwrapped
    print("Policy loaded")
    if args.gpu:
        set_gpu_mode(True)
        # policy.cuda()
    # import cv2
    # video = cv2.VideoWriter('diayn_bipedal_walker_hardcore.avi', cv2.VideoWriter_fourcc('M','J','P','G'), 30, (1200, 800))
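    # Visit every DIAYN skill and render three rollouts per skill.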
    for skill in range(policy.stochastic_policy.skill_dim):
        print(skill)
        for _ in range(3):
            path = rollout(
                env,
                policy,
                skill,
                max_path_length=args.H,
                render=True,
            )
            if hasattr(env, "log_diagnostics"):
                env.log_diagnostics([path])
            logger.dump_tabular()
Example no. 5
def simulate_policy(args):
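    # Render three rollouts per skill, writing every frame to a single video.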
    #   data = joblib.load(args.file)
    data = torch.load(args.file)
    policy = data['evaluation/policy']
    env = NormalizedBoxEnv(gym.make(str(args.env)))
    #   env = env.wrapped_env.unwrapped
    print("Policy loaded")
    if args.gpu:
        set_gpu_mode(True)
        policy.cuda()

    import cv2
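    # The writer is created lazily on the first frame, once its size is known.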
    video = None
    # index = 0
    for skill in range(policy.stochastic_policy.skill_dim):
        for trial in range(3):
            print("skill-{} rollout-{}".format(skill, trial))
            path = rollout(
                env,
                policy,
                skill,
                max_path_length=args.H,
                render=True,
            )
            if hasattr(env, "log_diagnostics"):
                env.log_diagnostics([path])
            logger.dump_tabular()

            for i, img in enumerate(path['images']):
                # print(i)
                # print(img.shape)
                if video is None:
                    # VideoWriter wants (width, height); numpy's shape is (height, width).
                    video = cv2.VideoWriter(
                        '{}.avi'.format(str(args.env)),
                        cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), 30,
                        (img.shape[1], img.shape[0]))
                video.write(img[:, :, ::-1].astype(np.uint8))


                # cv2.imwrite("frames/diayn_bipedal_walker_hardcore.avi/%06d.png" % index, img[:,:,::-1])
                # index += 1

    if video is not None:
        video.release()
        print("wrote video")
Example no. 6
def simulate_policy(args):
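    # Minimal evaluation loop: load the snapshot and render rollouts forever.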
    data = torch.load(str(args.file))
    #data = joblib.load(str(args.file))
    policy = data['evaluation/policy']
    env = NormalizedBoxEnv(HalfCheetahEnv())
    #env = data['evaluation/env']
    print("Policy loaded")
    if args.gpu:
        set_gpu_mode(True)
        policy.cuda()
    while True:
        path = rollout(
            env,
            policy,
            max_path_length=args.H,
            render=True,
        )
        if hasattr(env, "log_diagnostics"):
            env.log_diagnostics([path])
        logger.dump_tabular()
Example no. 7
def simulate_policy(args):
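    # Load separate manager and worker snapshots for a hierarchical rollout.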
    manager_data = torch.load(args.manager_file)
    worker_data = torch.load(args.worker_file)
    policy = manager_data['evaluation/policy']
    worker = worker_data['evaluation/policy']
    env = NormalizedBoxEnv(gym.make(str(args.env)))
    print("Policy loaded")
    if args.gpu:
        set_gpu_mode(True)
        policy.cuda()
        worker.cuda()  # the worker policy runs in the rollout too

    import cv2
    video = cv2.VideoWriter('ppo_dirichlet_diayn_bipedal_walker_hardcore.avi',
                            cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), 30,
                            (1200, 800))

    path = rollout(
        env,
        policy,
        worker,
        continuous=True,
        max_path_length=args.H,
        render=True,
    )
    if hasattr(env, "log_diagnostics"):
        env.log_diagnostics([path])
    logger.dump_tabular()

    for i, img in enumerate(path['images']):
        # Convert RGB frames to OpenCV's BGR channel order before writing.
        video.write(img[:, :, ::-1].astype(np.uint8))
        # cv2.imwrite("frames/ppo_dirichlet_diayn_policy_bipedal_walker_hardcore/%06d.png" % i, img[:,:,::-1])

    video.release()
    print("wrote video")
Example no. 8
def simulate_policy(args):
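    # Evaluate a snapshot that may include DIAYN ('df' discriminator) or RND
    # ('rf'/'pf') networks, overlaying per-step diagnostics on every frame.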
    data = joblib.load(args.file)

    cont = False  # True would run one long rollout that cycles through all skills

    if 'policies' in data:
        policy = data['policies'][0]
    else:
        policy = data['policy']
    env = NormalizedBoxEnv(create_swingup())  #data['env']

    print("Policy loaded")
    if args.gpu:
        set_gpu_mode(True)
        policy.cuda()
        data['qf1'].cuda()
    if isinstance(policy, PyTorchModule):
        policy.train(False)

    diayn = 'df' in data  # discriminator network present -> DIAYN snapshot
    rnd = 'rf' in data  # random/predictor feature nets present -> RND snapshot

    if diayn:
        skills = len(data['eval_policy'].skill_vec)
        disc = data['df']

        policy = OptionPolicy(policy, skills, cont)
        if args.gpu:
            disc.cuda()
        if isinstance(disc, PyTorchModule):
            disc.train(False)

    if rnd:
        data['rf'].cuda()
        data['pf'].cuda()
        data['qf1'].cuda()

    import cv2
    import os

    # NOTE: the video.write() call in the frame loop below is commented out,
    # so this writer produces an empty file; frames are saved as PNGs instead.
    video = cv2.VideoWriter('video.avi', cv2.VideoWriter_fourcc(*"H264"), 30,
                            (640, 480))
    os.makedirs("frames", exist_ok=True)  # cv2.imwrite fails silently otherwise
    index = 0

    truth, pred = [], []  # ground-truth skills vs. discriminator predictions

    if cont:
        eps = 1
    elif diayn:
        eps = skills * 2
    else:
        eps = 5

    Rs = []

    for ep in range(eps):
        if diayn and not cont:
            z_index = ep // 2
            policy.set_z(z_index)

        path = rollout(
            env,
            policy,
            max_path_length=args.H * skills if cont else args.H,
            animated=True,
        )

        if hasattr(env, "log_diagnostics"):
            env.log_diagnostics([path])
        logger.dump_tabular()

        total_r = 0

        if diayn:
            predictions = F.log_softmax(
                disc(torch.FloatTensor(path['observations']).cuda()),
                1).cpu().detach().numpy()
            probs = predictions.max(1)  # max log-probability per step
            labels = predictions.argmax(1)  # predicted skill per step

            if cont:
                for k in range(skills):
                    truth.extend([k] * 100)  # assumes 100 (args.H) steps per skill
            else:
                truth.extend([z_index] * len(labels))
            pred.extend(labels.tolist())

        if rnd:
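            # Intrinsic reward: squared error between the fixed random
            # network's features and the trained predictor's features.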
            random_feats = data['rf'](torch.FloatTensor(
                path['observations']).cuda())
            pred_feats = data['pf'](torch.FloatTensor(
                path['observations']).cuda())

            i_rewards = ((random_feats -
                          pred_feats)**2.0).sum(1).cpu().data.numpy()

        # Q-value of each taken action, used in the on-frame overlay below.
        q_pred = data['qf1'](torch.FloatTensor(path['observations']).cuda(),
                             torch.FloatTensor(
                                 path['actions']).cuda()).cpu().data.numpy()

        # Annotate every frame with step count, rewards, and model diagnostics.
        for i, (img, r, s) in enumerate(
                zip(path['images'], path['rewards'], path['observations'])):
            # video.write(img[:, :, ::-1].astype(np.uint8))  # writer never fed
            total_r += r[0]
            img = img.copy()
            img = np.rot90(img, 3).copy()
            col = (255, 0, 255)
            cv2.putText(img, "step: %d" % (i + 1), (20, 40),
                        cv2.FONT_HERSHEY_SIMPLEX, 1.0, col, 2, cv2.LINE_AA)

            if diayn:
                if cont:
                    cv2.putText(img, "z: %s" % str(truth[i]), (20, 80),
                                cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 255, 255),
                                2, cv2.LINE_AA)
                else:
                    cv2.putText(img, "z: %s" % str(z_index), (20, 80),
                                cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 255, 255),
                                2, cv2.LINE_AA)

                cv2.putText(img,
                            "disc_pred: %s (%.3f)" % (labels[i], probs[i]),
                            (20, 120), cv2.FONT_HERSHEY_SIMPLEX, 1.0,
                            (255, 255, 255), 2, cv2.LINE_AA)
                cv2.putText(img, "reward: %.3f" % r[0], (20, 160),
                            cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 255, 255), 2,
                            cv2.LINE_AA)
                cv2.putText(img, "total reward: %.1f" % total_r, (20, 200),
                            cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 255, 255), 2,
                            cv2.LINE_AA)
                cv2.putText(img, "action: %s" % path['actions'][i], (20, 240),
                            cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 255, 255), 2,
                            cv2.LINE_AA)
            else:
                cv2.putText(img, "reward: %.1f" % r[0], (20, 80),
                            cv2.FONT_HERSHEY_SIMPLEX, 1.0, col, 2, cv2.LINE_AA)
                cv2.putText(img, "total reward: %.1f" % total_r, (20, 120),
                            cv2.FONT_HERSHEY_SIMPLEX, 1.0, col, 2, cv2.LINE_AA)
                y = 120

            if rnd:
                cv2.putText(img, "i reward (unscaled): %.3f" % i_rewards[i],
                            (20, 160), cv2.FONT_HERSHEY_SIMPLEX, 1.0, col, 2,
                            cv2.LINE_AA)
                #cv2.rectangle(img, (20, 180), (20 + int(q_pred[i, 0]), 200), (255, 0, 255), -1)
                cv2.rectangle(img, (20, 200),
                              (20 + int(i_rewards[i] * 10), 220),
                              (255, 255, 0), -1)
                y = 220

            y += 40
            try:
                # q_pred[i] formats as a scalar when the critic has one head...
                cv2.putText(img, "Q: %.3f" % q_pred[i], (20, y),
                            cv2.FONT_HERSHEY_SIMPLEX, 1.0, col, 2, cv2.LINE_AA)
            except TypeError:
                # ...and as an array when there are multiple Q-values per step.
                cv2.putText(img, "Q:" + str(list(q_pred[i])), (20, y),
                            cv2.FONT_HERSHEY_SIMPLEX, 1.0, col, 2, cv2.LINE_AA)
            y += 40
            cv2.putText(img, str(["%.3f" % x
                                  for x in path['observations'][i]]), (20, y),
                        cv2.FONT_HERSHEY_SIMPLEX, 1.0, col, 2, cv2.LINE_AA)

            try:
                cv2.imwrite("frames/%06d.png" % index, img[:, :, ::-1])
            except IndexError:
                # Grayscale frames have no channel axis to reverse.
                cv2.imwrite("frames/%06d.png" % index, img[:, :])
            index += 1

        if diayn:
            print(z_index, ":", total_r)
        Rs.append(total_r)

    print("best", np.argmax(Rs))
    print("worst", np.argmin(Rs))

    video.release()
    print("wrote video")

    if diayn:
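        # Check how well the discriminator recovers the commanded skill by
        # plotting a confusion matrix of true vs. predicted skill labels.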
        from sklearn.metrics import confusion_matrix
        import itertools  # used only by the disabled annotation block below
        import matplotlib as mpl
        mpl.use('Agg')  # headless backend; select before importing pyplot
        import matplotlib.pyplot as plt
        normalize = False
        classes = range(skills)
        cm = confusion_matrix(truth, pred)
        if normalize:
            cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
            print("Normalized confusion matrix")
        else:
            print('Confusion matrix, without normalization')
        plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
        plt.colorbar()
        tick_marks = np.arange(skills)
        plt.xticks(tick_marks, classes, rotation=45)
        plt.yticks(tick_marks, classes)
        """
        fmt = '.2f' if normalize else 'd'
        thresh = cm.max() / 2.
        for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
            plt.text(j, i, format(cm[i, j], fmt),
                     horizontalalignment="center",
                     color="white" if cm[i, j] > thresh else "black")
        """

        plt.ylabel('True label')
        plt.xlabel('Predicted label')
        plt.tight_layout()
        plt.savefig("confusion.png")
Example no. 9
def visualize_policy(args):
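    # Visualize either a random policy in a fresh env (empty logdir) or the
    # evaluation policy reloaded from a trained experiment.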
    variant_overwrite = dict(
        params_pkl=args.params_pkl,
        num_historical_policies=args.num_historical_policies,
        env_kwargs=dict(
            reward_type='indicator',
            sample_goal=False,
            shape_rewards=False,
            distance_threshold=0.1,
            terminate_upon_success=False,
            terminate_upon_failure=False,
        ))
    if args.logdir == '':
        variant = variant_overwrite
        env = NormalizedBoxEnv(
            ManipulationEnv(**variant_overwrite['env_kwargs']))
        eval_policy = RandomPolicy(env.action_space)
    else:
        env, _, data, variant = load_experiment(args.logdir, variant_overwrite)
        eval_policy = (data['eval_policy']
                       if args.use_deterministic_policy else data['policy'])
        if not args.cpu:
            set_gpu_mode(True)
            eval_policy.cuda()
        print("Loaded policy:", eval_policy)

        if 'smm_kwargs' in variant:
            # Iterate through each latent-conditioned policy.
            num_skills = variant['smm_kwargs']['num_skills']
            print('Running SMM policy with {} skills.'.format(num_skills))
            import rlkit.torch.smm.utils as utils

            # Wraps the base policy: appends an encoding of the current skill z
            # to each observation and advances to the next skill on every reset.
            class PartialPolicy:
                def __init__(self, policy):
                    self._policy = policy
                    self._num_skills = num_skills
                    self._z = -1
                    self.reset()

                def get_action(self, ob):
                    aug_ob = utils.concat_ob_z(ob, self._z, self._num_skills)
                    return self._policy.get_action(aug_ob)

                def sample_skill(self):
                    # Unused: reset() cycles through skills deterministically.
                    z = np.random.choice(self._num_skills)
                    return z

                def reset(self):
                    self._z = (self._z + 1) % self._num_skills
                    print("Using skill z:", self._z)
                    return self._policy.reset()

            eval_policy = PartialPolicy(eval_policy)

    paths = []
    for _ in range(args.num_episodes):
        eval_policy.reset()
        path = rollout(
            env,
            eval_policy,
            max_path_length=args.max_path_length,
            animated=(not args.norender),
        )
        paths.append(path)
        if hasattr(env, "log_diagnostics"):
            env.log_diagnostics(paths)
        if hasattr(env, "get_diagnostics"):
            diagnostics = env.get_diagnostics(paths)
            for key, val in diagnostics.items():
                logger.record_tabular(key, val)
            logger.dump_tabular(with_prefix=False, with_timestamp=False)
    if hasattr(env, "draw"):
        env.draw(paths, save_dir="")