예제 #1
0
def main(_):
    """Train a PVF representation with ``TrainRunner`` on a MiniGrid or random MDP.

    The output directory is
    ``<base_dir>[/<custom_base_dir_from_hparams> | /<xm_wid>]/<env_name | S_A>/PVF/<estimator>``
    and is created if missing.

    Args:
      _: Unused positional argv supplied by the app launcher.
    """
    # Set verbosity first so log output produced during setup is shown too.
    logging.set_verbosity(logging.INFO)
    flags.mark_flags_as_required(['base_dir'])
    if FLAGS.custom_base_dir_from_hparams is not None:
        FLAGS.base_dir = os.path.join(FLAGS.base_dir,
                                      FLAGS.custom_base_dir_from_hparams)
    elif 'xm_wid' in FLAGS and FLAGS.xm_wid > 0:
        # Add the work unit id to the base directory path, if it exists.
        FLAGS.base_dir = os.path.join(FLAGS.base_dir, str(FLAGS.xm_wid))

    # An 'env_name' entry in the JSON-encoded xm_parameters overrides the flag.
    xm_parameters = FLAGS.xm_parameters if 'xm_parameters' in FLAGS else None
    if xm_parameters:
        xm_params = json.loads(xm_parameters)
        if 'env_name' in xm_params:
            FLAGS.env_name = xm_params['env_name']

    if FLAGS.env_name is None:
        base_dir = os.path.join(
            FLAGS.base_dir, '{}_{}'.format(FLAGS.num_states,
                                           FLAGS.num_actions))
    else:
        base_dir = os.path.join(FLAGS.base_dir, FLAGS.env_name)
    base_dir = os.path.join(base_dir, 'PVF', FLAGS.estimator)
    if not tf.io.gfile.exists(base_dir):
        tf.io.gfile.makedirs(base_dir)

    if FLAGS.env_name is not None:
        gin.add_config_file_search_path(_ENV_CONFIG_PATH)
        gin.parse_config_files_and_bindings(
            config_files=[f'{FLAGS.env_name}.gin'],
            bindings=FLAGS.gin_bindings,
            skip_unknown=False)
        env_id = mon_minigrid.register_environment()
        env = gym.make(env_id)
        env = RGBImgObsWrapper(env)  # Get pixel observations
        # Get tabular observation and drop the 'mission' field:
        env = mdp_wrapper.MDPWrapper(env, get_rgb=False)
        env = coloring_wrapper.ColoringWrapper(env)
    else:
        env = random_mdp.RandomMDP(FLAGS.num_states, FLAGS.num_actions)

    # We add the discount factor to the environment.
    env.gamma = FLAGS.gamma

    # The last two arguments are both ``FLAGS.epochs - 1``.
    # NOTE(review): their meaning is defined by TrainRunner; confirm there.
    runner = TrainRunner(base_dir, env, FLAGS.epochs, FLAGS.lr,
                         FLAGS.estimator, FLAGS.alpha, FLAGS.optimizer,
                         FLAGS.use_l2_reg, FLAGS.reg_coeff,
                         FLAGS.use_penalty, FLAGS.j, FLAGS.num_rows,
                         jax.random.PRNGKey(0), FLAGS.epochs - 1,
                         FLAGS.epochs - 1)
    runner.train()
    def __init__(self,
                 env_id,
                 is_render,
                 env_idx,
                 child_conn,
                 history_size=1,
                 h=84,
                 w=84,
                 life_done=True,
                 sticky_action=False,
                 p=0.25):
        """Configure one environment worker around a pixel-observation MiniGrid env.

        Args:
          env_id: Gym id passed to ``gym.make``.
          is_render: Stored as ``self.is_render``.
          env_idx: Index of this worker; stored as ``self.env_idx``.
          child_conn: Connection object stored as ``self.child_conn``.
          history_size: Number of frames held in ``self.history``.
          h: Height of each frame in ``self.history``.
          w: Width of each frame in ``self.history``.
          life_done: Accepted for signature compatibility; NOTE(review): not
            stored or used by this constructor.
          sticky_action: Stored as ``self.sticky_action``.
          p: Stored as ``self.p``. NOTE(review): presumably the sticky-action
            probability; confirm where it is consumed.
        """
        super(GridEnvironment, self).__init__()
        # NOTE(review): presumably a multiprocessing.Process subclass, in which
        # case this makes the worker daemonic.
        self.daemon = True

        # Wrapping order, innermost first: ReseedWrapper, RGBImgObsWrapper,
        # ImgObsWrapper.
        raw_env = gym.make(env_id)
        self.env = ImgObsWrapper(RGBImgObsWrapper(ReseedWrapper(raw_env)))

        self.env_id, self.env_idx = env_id, env_idx
        self.is_render = is_render
        self.child_conn = child_conn

        # Episode bookkeeping.
        self.steps = self.episode = self.rall = 0
        self.recent_rlist = deque(maxlen=100)

        # Sticky-action settings.
        self.sticky_action, self.p = sticky_action, p
        self.last_action = 0

        # Frame history buffer of shape (history_size, h, w).
        self.h, self.w = h, w
        self.history_size = history_size
        self.history = np.zeros([history_size, h, w])

        self.reset()
예제 #3
0
    def __init__(self, config, key, *, num_tasks):
        """Build the environment and the matrix of auxiliary-task values.

        Args:
          config: Config object. Reads ``config.env.name`` ('gym' or 'random'),
            ``config.env.gym.id``, ``config.env.random.num_states``,
            ``config.env.random.num_actions`` and ``config.env.task``
            ('random_reward' or 'random_policy').
          key: JAX PRNG key. It is split once; the first half is kept as
            ``self.key`` and the second is used to sample the tasks.
          num_tasks: Number of auxiliary tasks to sample.

        Raises:
          ValueError: If ``config.env.name`` or ``config.env.task`` is not one
            of the supported values.
        """
        self.config = config
        if config.env.name == 'gym':
            env = gym.make(config.env.gym.id)
            env = RGBImgObsWrapper(env)  # Get pixel observations
            # Get tabular observation and drop the 'mission' field:
            env = mdp_wrapper.MDPWrapper(env, get_rgb=False)
            env = coloring_wrapper.ColoringWrapper(env)
        elif config.env.name == 'random':
            env = random_mdp.RandomMDP(config.env.random.num_states,
                                       config.env.random.num_actions)
        else:
            # Without this branch ``env`` would be unbound below and the
            # constructor would fail with an opaque UnboundLocalError.
            raise ValueError(
                f'Invalid value for config.env.name: {config.env.name}')
        self.env = env

        # Swap the first two axes of the transition tensor.
        # NOTE(review): presumably (S, A, S') -> (A, S, S'); confirm against
        # what policy_evaluation expects.
        P = jnp.transpose(env.transition_probs, [1, 0, 2])

        self.key, subkey = jax.random.split(key)
        if config.env.task == 'random_reward':
            # Many random reward functions evaluated under one fixed policy
            # that puts weight 1 / num_actions on every action.
            R = uniform_random_rewards(subkey, num_tasks, env.num_states,
                                       env.num_actions)
            d = jnp.ones((env.num_states, env.num_actions),
                         dtype=jnp.float32) / env.num_actions
            # vmap over the task axis of R; results stacked on the last axis.
            self.aux_task_matrix = jax.vmap(policy_evaluation,
                                            in_axes=(None, 0, None),
                                            out_axes=(-1))(P, R, d)
        elif config.env.task == 'random_policy':
            # The environment's own rewards evaluated under many sampled
            # policies.
            R = env.rewards
            d = random_deterministic_policies(subkey, num_tasks,
                                              env.num_states, env.num_actions)
            # vmap over the task axis of d; results stacked on the last axis.
            self.aux_task_matrix = jax.vmap(policy_evaluation,
                                            in_axes=(None, None, 0),
                                            out_axes=(-1))(P, R, d)
        else:
            raise ValueError(
                f'Invalid value for config.env.task: {config.env.task}')
예제 #4
0
def env_setup(name='MiniGrid-Empty-8x8-v0', max_steps=200, seed=12345):
    """Create, cap, seed and reset a pixel-observation MiniGrid environment.

    Args:
      name: Gym id of the environment to create.
      max_steps: Upper bound applied to the base environment's ``max_steps``;
        the existing value is kept if it is already smaller.
      seed: Seed passed to ``env.seed`` before the initial reset.

    Returns:
      The ``RGBImgObsWrapper``-wrapped environment, already reset.
    """
    env = RGBImgObsWrapper(gym.make(name))
    # gym wrappers forward attribute *reads* to the wrapped env, but an
    # assignment on the wrapper just creates a new attribute on the wrapper.
    # Write the cap on the unwrapped base env so the environment actually
    # sees it.
    base_env = env.unwrapped
    base_env.max_steps = min(base_env.max_steps, max_steps)
    env.seed(seed)
    env.reset()
    return env
예제 #5
0
def mini_grid_wrapper(env_id, max_frames=0, clip_rewards=True):
    """Build a fixed-seed pixel MiniGrid env with DeepMind-style preprocessing.

    Args:
      env_id: Gym id passed to ``gym.make``.
      max_frames: If non-zero, wrap with ``pfrl.wrappers.ContinuingTimeLimit``
        using this as ``max_episode_steps``.
      clip_rewards: Forwarded to ``atari_wrappers.wrap_deepmind``.

    Returns:
      The wrapped environment.
    """
    # Innermost first: always reseed with seed 0, render RGB pixels, then keep
    # only the image observation.
    wrapped = ImgObsWrapper(
        RGBImgObsWrapper(ReseedWrapper(gym.make(env_id), seeds=[0])))
    if max_frames:
        wrapped = pfrl.wrappers.ContinuingTimeLimit(
            wrapped, max_episode_steps=max_frames)
    return atari_wrappers.wrap_deepmind(
        wrapped, episode_life=False, clip_rewards=clip_rewards)
def main(argv):
    """Play a gin-configured MiniGrid environment interactively from the terminal.

    Each step reads an action key from stdin and saves a frame to
    ``FLAGS.file_path``. The episode stops when the env reports ``done`` or
    after 500 accepted actions.

    Args:
      argv: Command-line arguments; anything beyond the program name is an error.

    Raises:
      app.UsageError: If extra command-line arguments are given.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    gin_file = os.path.join(mon_minigrid.GIN_FILES_PREFIX,
                            '{}.gin'.format(FLAGS.env_name))
    gin.parse_config_files_and_bindings([gin_file],
                                        bindings=FLAGS.gin_bindings,
                                        skip_unknown=False)
    env_id = mon_minigrid.register_environment()
    # Pixel observations, then a tabular observation without the 'mission'
    # field (keeping the RGB image as well).
    env = tabular_wrapper.TabularWrapper(RGBImgObsWrapper(gym.make(env_id)),
                                         get_rgb=True)
    env.reset()

    max_num_frames = 500
    if not tf.io.gfile.exists(FLAGS.file_path):
        tf.io.gfile.makedirs(FLAGS.file_path)

    print('Available actions:')
    for key, action_name in ACTION_MAPPINGS.items():
        print('\t{}: "{}"'.format(action_name, key))
    print()

    undisc_return = 0
    num_frames = 0
    while num_frames < max_num_frames:
        draw_ascii_view(env)
        key = input('action: ')
        if key not in ACTION_MAPPINGS:
            # Invalid input does not consume a frame.
            print('Unrecognized action.')
            continue
        obs, reward, done, _ = env.step(
            env.DirectionalActions[ACTION_MAPPINGS[key]].value)
        undisc_return += reward
        num_frames += 1

        print('t:', num_frames, '   s:', obs['state'])
        # Save the rendered frame purely for visualization.
        plt.imshow(obs['image'])
        plt.savefig(os.path.join(FLAGS.file_path,
                                 'obs_{}.png'.format(num_frames)))
        plt.clf()

        if done:
            break

    print('Undiscounted return: %.2f' % undisc_return)
    env.close()
예제 #7
0
def wrap_deepmind_minigrid(env, dim=84, framestack=True, seed=0):
    """Apply a DeepMind-style preprocessing stack to a gridworld env.

    Reward clipping is assumed to happen outside this wrapper.

    Args:
        env: The environment to wrap.
        dim (int): Side length the observations are resized to (dim x dim).
        framestack (bool): Accepted for API compatibility but currently
            ignored; no frame stacking is applied.
        seed (int): Seed forwarded to ``GridworldPreprocess``.

    Returns:
        The wrapped environment.
    """
    # Wrappers in application order, innermost first.
    pipeline = (
        RGBImgObsWrapper,
        lambda inner: GridworldPreprocess(inner, seed),
        MonitorEnv,
        lambda inner: WarpFrame3D(inner, dim),
        GridworldPostprocess,
    )
    for apply_wrapper in pipeline:
        env = apply_wrapper(env)
    return env
예제 #8
0
def _value_iteration(rewards, transition_probs, gamma, tolerance):
    """Run tabular value iteration until the max-norm update is <= ``tolerance``.

    Args:
      rewards: Array indexed as ``rewards[s, a]``.
      transition_probs: Array indexed as ``transition_probs[s, a, s_next]``.
      gamma: Discount factor.
      tolerance: Stop once ``max_s |V_new(s) - V_old(s)| <= tolerance``.

    Returns:
      A ``(values, num_iterations)`` tuple, ``values`` having one entry per state.
    """
    rewards = np.asarray(rewards, dtype=np.float64)
    transition_probs = np.asarray(transition_probs, dtype=np.float64)
    values = np.zeros(rewards.shape[0])
    error = tolerance * 2  # Guarantees at least one sweep for tolerance > 0.
    i = 0
    while error > tolerance:
        # Q[s, a] = R[s, a] + gamma * sum_{s'} P[s, a, s'] * V[s'].
        q_values = rewards + gamma * np.matmul(transition_probs, values)
        # Bellman optimality backup: the true max over actions. Starting the
        # max from 0 instead would clamp V* to 0 in states whose every action
        # has negative value.
        new_values = np.max(q_values, axis=1)
        error = np.max(np.abs(new_values - values))
        values = new_values
        i += 1
        if i % 1000 == 0:
            print('Error after {} iterations: {}'.format(i, error))
    return values, i


def main(argv):
    """Compute V* for a gin-configured MiniGrid env and optionally plot it.

    Args:
      argv: Command-line arguments; anything beyond the program name is an error.

    Raises:
      app.UsageError: If extra command-line arguments are given.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    gin.parse_config_files_and_bindings([
        os.path.join(mon_minigrid.GIN_FILES_PREFIX, '{}.gin'.format(FLAGS.env))
    ],
                                        bindings=FLAGS.gin_bindings,
                                        skip_unknown=False)
    env_id = mon_minigrid.register_environment()
    env = gym.make(env_id)
    env = RGBImgObsWrapper(env)  # Get pixel observations
    # Get tabular observation and drop the 'mission' field:
    env = mdp_wrapper.MDPWrapper(env)
    env = coloring_wrapper.ColoringWrapper(env)

    values, i = _value_iteration(env.rewards, env.transition_probs,
                                 FLAGS.gamma, FLAGS.tolerance)
    print('Found V* in {} iterations'.format(i))
    print(values)
    if FLAGS.values_image_file is not None:
        # ``plt.get_cmap`` replaces ``cm.get_cmap``, which Matplotlib
        # deprecated in 3.7 and removed in 3.9.
        cmap = plt.get_cmap('plasma', 256)
        norm = colors.Normalize(vmin=np.min(values), vmax=np.max(values))
        obs_image = env.render_custom_observation(env.reset(),
                                                  values,
                                                  cmap,
                                                  boundary_values=[1.0, 4.5])
        m = cm.ScalarMappable(cmap=cmap, norm=norm)
        m.set_array(obs_image)
        plt.imshow(obs_image)
        # ``m`` is not attached to any Axes, so pass ``ax`` explicitly; newer
        # Matplotlib versions refuse to guess where to place the colorbar.
        plt.colorbar(m, ax=plt.gca())
        plt.savefig(FLAGS.values_image_file)
        plt.clf()
    env.close()
예제 #9
0
def main(argv):
    """Roll out random actions in the classic four-rooms env, saving each frame.

    Frames are written to ``FLAGS.file_path`` as ``obs_<t>.png``. The rollout
    stops when the env reports ``done`` or after 500 steps.

    Args:
      argv: Command-line arguments; anything beyond the program name is an error.

    Raises:
      app.UsageError: If extra command-line arguments are given.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    config_file = os.path.join(mon_minigrid.GIN_FILES_PREFIX,
                               'classic_fourrooms.gin')
    gin.parse_config_files_and_bindings([config_file],
                                        bindings=FLAGS.gin_bindings,
                                        skip_unknown=False)
    env_id = mon_minigrid.register_environment()
    # Pixel observations, with the 'mission' field stripped off.
    env = ImgObsWrapper(RGBImgObsWrapper(gym.make(env_id)))
    env.reset()

    max_num_frames = 500
    if not tf.io.gfile.exists(FLAGS.file_path):
        tf.io.gfile.makedirs(FLAGS.file_path)

    undisc_return = 0
    for num_frames in range(1, max_num_frames + 1):
        obs, reward, done, _ = env.step(env.action_space.sample())
        undisc_return += reward

        # Save the frame purely for visualization.
        plt.imshow(obs)
        plt.savefig(os.path.join(FLAGS.file_path,
                                 'obs_{}.png'.format(num_frames)))
        plt.clf()

        if done:
            break

    print('Undiscounted return: %.2f' % undisc_return)
    env.close()
예제 #10
0
    def __init__(self, env_id, seed, max_episode_length=1000):
        """Create and seed the underlying game environment selected by ``env_id``.

        Args:
          env_id: Environment id. The constructor is chosen from it: an exact
            match for "TetrisA-v2", then substring tests for "MiniGrid",
            "ple", "Sokoban" and "MazeEnv", otherwise plain ``gym.make``.
          seed: Seed passed to the underlying environment's ``seed``.
          max_episode_length: Episode cap in agent steps; stored multiplied by
            the env's action repeat.
        """
        super(GameEnv, self).__init__()
        extra_args = ENV_GAMES_ARGS.get(env_id, {})
        self.env_id = env_id
        if env_id == "TetrisA-v2":
            self._env = JoypadSpace(gym_tetris.make(env_id, **extra_args), SIMPLE_MOVEMENT)
        elif "MiniGrid" in env_id:
            # Checked before the "ple" substring test: MiniGrid ids such as
            # "MiniGrid-SimpleCrossingS9N1-v0" contain "ple" (in "Simple") and
            # would otherwise be routed to gym_ple.
            self._env = AbsoluteActionGrid(RGBImgObsWrapper(gym.make(env_id)))
        elif "ple" in env_id:
            self._env = gym_ple.make(env_id, **extra_args)
        elif "Sokoban" in env_id:
            self._env = TinySokoban(gym.make(env_id, **extra_args))
        elif "MazeEnv" in env_id:
            self._env = MazeEnvImage(mazenv.Env(mazenv.prim((8, 8))), randomize=True)
        else:
            self._env = gym.make(env_id, **extra_args)

        self._env.seed(seed)
        self.action_repeat = GAME_ENVS_ACTION_REPEATS.get(env_id, 1)
        self.max_episode_length = max_episode_length * self.action_repeat
        self.t = 0
예제 #11
0
def main(_):
  """Learn a d-dimensional representation of the task matrix Psi and save metrics.

  The environment is either a gin-configured MiniGrid env (when ``env_name``
  is set) or a ``RandomMDP``. The function forms
  ``Psi = (I - gamma * P)^{-1}``, where ``P`` is the transition matrix under
  ``rl_basics.policy_random(env)``, and normalizes each column of Psi by its
  maximum. It then trains a representation with ``train`` and writes
  diagnostic curves to ``<base_dir>/<prefix>.npy``.

  Args:
    _: Unused positional argv supplied by the app launcher.
  """
  flags.mark_flags_as_required(['base_dir'])
  if FLAGS.custom_base_dir_from_hparams is not None:
    FLAGS.base_dir = os.path.join(FLAGS.base_dir,
                                  FLAGS.custom_base_dir_from_hparams)
  elif 'xm_wid' in FLAGS and FLAGS.xm_wid > 0:
    # Add the work unit id to the base directory path, if it exists.
    FLAGS.base_dir = os.path.join(FLAGS.base_dir, str(FLAGS.xm_wid))

  # An 'env_name' entry in the JSON-encoded xm_parameters overrides the flag.
  xm_parameters = FLAGS.xm_parameters if 'xm_parameters' in FLAGS else None
  if xm_parameters:
    xm_params = json.loads(xm_parameters)
    if 'env_name' in xm_params:
      FLAGS.env_name = xm_params['env_name']

  if FLAGS.env_name is None:
    base_dir = os.path.join(
        FLAGS.base_dir, '{}_{}'.format(FLAGS.num_states, FLAGS.num_actions))
  else:
    base_dir = os.path.join(FLAGS.base_dir, FLAGS.env_name)
  base_dir = os.path.join(base_dir, 'PVF', FLAGS.estimator, f'lr_{FLAGS.lr}')
  if not tf.io.gfile.exists(base_dir):
    tf.io.gfile.makedirs(base_dir)

  if FLAGS.env_name is not None:
    gin.add_config_file_search_path(_ENV_CONFIG_PATH)
    gin.parse_config_files_and_bindings(
        config_files=[f'{FLAGS.env_name}.gin'],
        bindings=FLAGS.gin_bindings,
        skip_unknown=False)
    env_id = mon_minigrid.register_environment()
    env = gym.make(env_id)
    env = RGBImgObsWrapper(env)  # Get pixel observations
    # Get tabular observation and drop the 'mission' field:
    env = mdp_wrapper.MDPWrapper(env, get_rgb=False)
    env = coloring_wrapper.ColoringWrapper(env)
  else:
    env = random_mdp.RandomMDP(FLAGS.num_states, FLAGS.num_actions)
  # We add the discount factor to the environment.
  env.gamma = FLAGS.gamma

  P = utils.transition_matrix(env, rl_basics.policy_random(env))  # pylint: disable=invalid-name
  S = P.shape[0]  # pylint: disable=invalid-name
  # Psi = (I - gamma * P)^{-1}, obtained by solving against the identity.
  Psi = jnp.linalg.solve(jnp.eye(S) - env.gamma * P, jnp.eye(S))  # pylint: disable=invalid-name
  # Normalize tasks (columns) so that they have maximum value 1.
  max_task_value = np.max(Psi, axis=0)
  Psi /= max_task_value  # pylint: disable=invalid-name

  # The first d left singular vectors serve as the reference the learned
  # representation is compared against below.
  left_vectors, _, _ = jnp.linalg.svd(Psi)  # pylint: disable=invalid-name
  approx_error = utils.approx_error(left_vectors, FLAGS.d, Psi)

  # Initialization of Phi.
  # NOTE(review): float64 only takes effect when JAX x64 mode is enabled;
  # otherwise JAX falls back to float32.
  representation_init = jax.random.normal(  # pylint: disable=invalid-name
      jax.random.PRNGKey(0),
      (S, FLAGS.d),
      dtype=jnp.float64)
  # NOTE(review): the same PRNGKey(0) seeds both the initialization above and
  # training below; confirm this reuse is intended.
  representations, grads = train(representation_init,
                                 Psi, FLAGS.epochs, FLAGS.lr,
                                 jax.random.PRNGKey(0), FLAGS.estimator,
                                 FLAGS.alpha, FLAGS.optimizer, FLAGS.use_l2_reg,
                                 FLAGS.reg_coeff, FLAGS.use_penalty, FLAGS.j,
                                 FLAGS.num_rows, FLAGS.skipsize_train)

  gm_distances = calc_gm_distances(representations, left_vectors[:, :FLAGS.d],
                                   FLAGS.skipsize)
  x_len = len(gm_distances)
  frob_norms = calc_frob_norms(representations, Psi, FLAGS.skipsize)
  if FLAGS.d == 1:
    dot_products = calc_dot_products(representations, left_vectors[:, :FLAGS.d],
                                     FLAGS.skipsize)
  else:
    # Dot products are only computed for d == 1; keep a zero placeholder of
    # matching length so the saved dict always has the same keys.
    dot_products = np.zeros((x_len,))
  grad_norms = calc_grad_norms(grads, FLAGS.skipsize)
  phi_norms = calc_Phi_norm(representations, FLAGS.skipsize)
  phi_ranks = calc_sranks(representations, FLAGS.skipsize)

  prefix = f'alpha{FLAGS.alpha}_j{FLAGS.j}_d{FLAGS.d}_regcoeff{FLAGS.reg_coeff}'

  # The dict is pickled by np.save, so reading it back requires
  # allow_pickle=True.
  with tf.io.gfile.GFile(osp.join(base_dir, f'{prefix}.npy'), 'wb') as f:
    np.save(
        f, {
            'gm_distances': gm_distances,
            'dot_products': dot_products,
            'frob_norms': frob_norms,
            'approx_error': approx_error,
            'grad_norms': grad_norms,
            'representations': representations,
            'phi_norms': phi_norms,
            'phi_ranks': phi_ranks
        },
        allow_pickle=True)
def main(run_id=0, checkpoint=None, rec_interval=10, save_interval=100):
    """Train an exploration agent with parallel environment workers, logging to TensorBoard.

    Hyperparameters are read from the module-level ``config`` and
    ``grid_config`` objects. ``TrainMethod`` selects the agent class
    ('RND' -> RNDAgent, 'generative' -> GenerativeAgent). ``EnvType`` selects
    the worker class ('atari', 'mario', 'grid'). The loop first runs
    ``NumPretrainRollouts`` rollouts that only call ``agent.train_just_vae``,
    then ``NumRollouts`` rollouts that call ``agent.train_model``, and
    returns after that total.

    Args:
      run_id: Tag included in the model, stats and TensorBoard paths.
      checkpoint: If not None, weights are loaded from ``model_path`` /
        ``predictor_path`` before training. Only whether it is None is used.
        NOTE(review): weights are saved as ``<path>_<update>.pt`` below but
        loaded from ``<path>`` with no suffix, and ``checkpoint`` never selects
        a file; confirm the intended resume path.
      rec_interval: Log an original/reconstructed observation pair every
        ``rec_interval`` rollouts.
      save_interval: Save weights and the stats pickle every
        ``save_interval`` rollouts.

    Raises:
      NotImplementedError: For EnvType 'mario', or an unsupported EnvType or
        TrainMethod.
    """
    # Dump every config section for the run log.
    print({section: dict(config[section]) for section in config.sections()})

    train_method = grid_config['TrainMethod']

    # Create environment
    env_id = grid_config['EnvID']
    env_type = grid_config['EnvType']

    if env_type == 'mario':
        print('Mario environment not fully implemented - thomaseh')
        raise NotImplementedError
        # NOTE(review): the statement below is unreachable (it follows the raise).
        env = BinarySpaceToDiscreteSpaceEnv(
            gym_super_mario_bros.make(env_id), COMPLEX_MOVEMENT)
    elif env_type == 'atari':
        env = gym.make(env_id)
    elif env_type == 'grid':
        env = ImgObsWrapper(RGBImgObsWrapper(gym.make(env_id))) 
    else:
        raise NotImplementedError

    input_size = env.observation_space.shape  # observation shape tuple
    output_size = env.action_space.n  # number of discrete actions

    # NOTE(review): one action is dropped for Breakout; the reason is not
    # visible here.
    if 'Breakout' in env_id:
        output_size -= 1

    # This env was only needed to read the spaces; the workers below
    # construct their own environments from env_id.
    env.close()

    # Load configuration parameters
    is_load_model = checkpoint is not None
    is_render = False
    model_path = 'models/{}_{}_run{}_model'.format(env_id, train_method, run_id)
    predictor_path = 'models/{}_{}_run{}_vae'.format(env_id, train_method, run_id)
   

    writer = SummaryWriter(logdir='runs/{}_{}_run{}'.format(env_id, train_method, run_id))

    use_cuda = grid_config.getboolean('UseGPU')
    use_gae = grid_config.getboolean('UseGAE')
    use_noisy_net = grid_config.getboolean('UseNoisyNet')

    lam = float(grid_config['Lambda'])
    num_worker = int(grid_config['NumEnv'])

    num_step = int(grid_config['NumStep'])
    num_rollouts = int(grid_config['NumRollouts'])
    num_pretrain_rollouts = int(grid_config['NumPretrainRollouts'])

    ppo_eps = float(grid_config['PPOEps'])
    epoch = int(grid_config['Epoch'])
    mini_batch = int(grid_config['MiniBatch'])
    # Samples per minibatch: one rollout holds num_step * num_worker samples.
    batch_size = int(num_step * num_worker / mini_batch)
    learning_rate = float(grid_config['LearningRate'])
    entropy_coef = float(grid_config['Entropy'])
    gamma = float(grid_config['Gamma'])
    int_gamma = float(grid_config['IntGamma'])
    clip_grad_norm = float(grid_config['ClipGradNorm'])
    ext_coef = float(grid_config['ExtCoef'])
    int_coef = float(grid_config['IntCoef'])

    sticky_action = grid_config.getboolean('StickyAction')
    action_prob = float(grid_config['ActionProb'])
    life_done = grid_config.getboolean('LifeDone')

    # Running statistics used to normalize the intrinsic reward (Step 2).
    reward_rms = RunningMeanStd()
    # NOTE(review): obs_rms and pre_obs_norm_step are not used further in this
    # function (the obs_rms update in Step 4 is commented out).
    obs_rms = RunningMeanStd(shape=(1, 1, 84, 84))
    pre_obs_norm_step = int(grid_config['ObsNormStep'])
    discounted_reward = RewardForwardFilter(int_gamma)

    hidden_dim = int(grid_config['HiddenDim'])

    if train_method == 'RND':
        agent = RNDAgent
    elif train_method == 'generative':
        agent = GenerativeAgent
    else:
        raise NotImplementedError

    # env_type is rebound here from the config string to the worker class.
    if grid_config['EnvType'] == 'atari':
        env_type = AtariEnvironment
    elif grid_config['EnvType'] == 'mario':
        env_type = MarioEnvironment
    elif grid_config['EnvType'] == 'grid':
        env_type = GridEnvironment 
    else:
        raise NotImplementedError

    # Initialize agent
    agent = agent(
        input_size,
        output_size,
        num_worker,
        num_step,
        gamma,
        history_size=1,
        lam=lam,
        learning_rate=learning_rate,
        ent_coef=entropy_coef,
        clip_grad_norm=clip_grad_norm,
        epoch=epoch,
        batch_size=batch_size,
        ppo_eps=ppo_eps,
        use_cuda=use_cuda,
        use_gae=use_gae,
        use_noisy_net=use_noisy_net,
        update_proportion=1.0,
        hidden_dim=hidden_dim
    )

    # Load pre-existing model
    if is_load_model:
        print('load model...')
        if use_cuda:
            agent.model.load_state_dict(torch.load(model_path))
            agent.vae.load_state_dict(torch.load(predictor_path))
        else:
            agent.model.load_state_dict(
                torch.load(model_path, map_location='cpu'))
            agent.vae.load_state_dict(torch.load(predictor_path, map_location='cpu'))
        print('load finished!')

    # Create workers to run in environments. Each worker gets the child end of
    # a Pipe; the parent sends actions and receives transitions on its end.
    works = []
    parent_conns = []
    child_conns = []
    for idx in range(num_worker):
        parent_conn, child_conn = Pipe()
        work = env_type(
            env_id, is_render, idx, child_conn, sticky_action=sticky_action,
            p=action_prob, life_done=life_done,
        )
        work.start()
        works.append(work)
        parent_conns.append(parent_conn)
        child_conns.append(child_conn)

    # Current observation of every worker, shape (num_worker, 1, 84, 84).
    states = np.zeros([num_worker, 1, 84, 84], dtype='float32')

    # Per-episode logging is tracked for worker sample_env_idx only.
    sample_episode = 0
    sample_rall = 0
    sample_step = 0
    sample_env_idx = 0
    # NOTE(review): sample_i_rall is accumulated and reset but never logged.
    sample_i_rall = 0
    global_update = 0
    global_step = 0

    # Initialize stats dict
    stats = {
        'total_reward': [],
        'ep_length': [],
        'num_updates': [],
        'frames_seen': [],
    }

    # Main training loop
    while True:
        # Buffers laid out worker-major: row idx * num_step + step.
        total_state = np.zeros([num_worker * num_step, 1, 84, 84], dtype='float32')
        total_next_obs = np.zeros([num_worker * num_step, 1, 84, 84])
        # NOTE(review): total_next_state is never filled or read.
        total_reward, total_done, total_next_state, total_action, \
            total_int_reward, total_ext_values, total_int_values, total_policy, \
            total_policy_np = [], [], [], [], [], [], [], [], []

        # Step 1. n-step rollout (collect data)
        for step in range(num_step):
            actions, value_ext, value_int, policy = agent.get_action(states / 255.)

            for parent_conn, action in zip(parent_conns, actions):
                parent_conn.send(action)

            next_obs = np.zeros([num_worker, 1, 84, 84])
            next_states = np.zeros([num_worker, 1, 84, 84])
            rewards, dones, real_dones, log_rewards = [], [], [], []
            for idx, parent_conn in enumerate(parent_conns):
                # Worker message: next state, reward, done, real done,
                # reward for logging, and episode stats (stat[0] = total
                # reward, stat[1] = episode length, per their use below).
                s, r, d, rd, lr, stat = parent_conn.recv()
                next_states[idx] = s
                rewards.append(r)
                dones.append(d)
                real_dones.append(rd)
                log_rewards.append(lr)
                next_obs[idx, 0] = s[0, :, :]
                total_next_obs[idx * num_step + step, 0] = s[0, :, :]

                if rd:
                    stats['total_reward'].append(stat[0])
                    stats['ep_length'].append(stat[1])
                    stats['num_updates'].append(global_update)
                    stats['frames_seen'].append(global_step)

            rewards = np.hstack(rewards)
            dones = np.hstack(dones)
            real_dones = np.hstack(real_dones)

            # Compute total reward = intrinsic reward + external reward
            intrinsic_reward = agent.compute_intrinsic_reward(next_obs / 255.)
            intrinsic_reward = np.hstack(intrinsic_reward)
            sample_i_rall += intrinsic_reward[sample_env_idx]

            for idx, state in enumerate(states):
                total_state[idx * num_step + step] = state
            total_int_reward.append(intrinsic_reward)
            total_reward.append(rewards)
            total_done.append(dones)
            total_action.append(actions)
            total_ext_values.append(value_ext)
            total_int_values.append(value_int)
            total_policy.append(policy)
            total_policy_np.append(policy.cpu().numpy())

            states = next_states[:, :, :, :]

            sample_rall += log_rewards[sample_env_idx]

            sample_step += 1
            if real_dones[sample_env_idx]:
                sample_episode += 1
                writer.add_scalar('data/reward_per_epi', sample_rall, sample_episode)
                writer.add_scalar('data/reward_per_rollout', sample_rall, global_update)
                writer.add_scalar('data/step', sample_step, sample_episode)
                sample_rall = 0
                sample_step = 0
                sample_i_rall = 0

        # calculate last next value
        _, value_ext, value_int, _ = agent.get_action(np.float32(states) / 255.)
        total_ext_values.append(value_ext)
        total_int_values.append(value_int)
        # --------------------------------------------------

        # Stack per-step lists and transpose to (worker, step) layout;
        # extrinsic rewards are clipped to [-1, 1].
        total_reward = np.stack(total_reward).transpose().clip(-1, 1)
        total_action = np.stack(total_action).transpose().reshape([-1])
        total_done = np.stack(total_done).transpose()
        total_ext_values = np.stack(total_ext_values).transpose()
        total_int_values = np.stack(total_int_values).transpose()
        total_logging_policy = np.vstack(total_policy_np)

        # Step 2. calculate intrinsic reward
        # running mean intrinsic reward
        total_int_reward = np.stack(total_int_reward).transpose()
        total_reward_per_env = np.array([discounted_reward.update(reward_per_step) for reward_per_step in
                                         total_int_reward.T])
        mean, std, count = np.mean(total_reward_per_env), np.std(total_reward_per_env), len(total_reward_per_env)
        reward_rms.update_from_moments(mean, std ** 2, count)

        writer.add_scalar('data/raw_int_reward_per_epi', np.sum(total_int_reward) / num_worker, sample_episode)
        writer.add_scalar('data/raw_int_reward_per_rollout', np.sum(total_int_reward) / num_worker, global_update)

        # normalize intrinsic reward
        total_int_reward /= np.sqrt(reward_rms.var)
        writer.add_scalar('data/int_reward_per_epi', np.sum(total_int_reward) / num_worker, sample_episode)
        writer.add_scalar('data/int_reward_per_rollout', np.sum(total_int_reward) / num_worker, global_update)
        # -------------------------------------------------------------------------------------------

        # logging Max action probability
        writer.add_scalar('data/max_prob', softmax(total_logging_policy).max(1).mean(), sample_episode)

        # Step 3. make target and advantage
        # extrinsic reward calculate
        ext_target, ext_adv = make_train_data(total_reward,
                                              total_done,
                                              total_ext_values,
                                              gamma,
                                              num_step,
                                              num_worker)

        # intrinsic reward calculate
        # None Episodic: dones are all zeros, so intrinsic returns are never
        # cut at episode boundaries.
        int_target, int_adv = make_train_data(total_int_reward,
                                              np.zeros_like(total_int_reward),
                                              total_int_values,
                                              int_gamma,
                                              num_step,
                                              num_worker)

        # add ext adv and int adv
        total_adv = int_adv * int_coef + ext_adv * ext_coef
        # -----------------------------------------------

        # Step 4. update obs normalize param
        # obs_rms.update(total_next_obs)
        # -----------------------------------------------

        # Step 5. Training!
        # random_obs_choice = np.random.randint(total_next_obs.shape[0])
        # random_obs = total_next_obs[random_obs_choice].copy()
        # Divide by 255 in place (presumably mapping pixels to [0, 1]); the
        # scaled array is also what gets logged as 'Original' below.
        total_next_obs /= 255.
        if global_update < num_pretrain_rollouts:
            # Pretraining phase: only the VAE is trained.
            recon_losses, kld_losses = agent.train_just_vae(total_state / 255., total_next_obs)
        else:
            recon_losses, kld_losses = agent.train_model(total_state / 255., ext_target, int_target, total_action,
                        total_adv, total_next_obs, total_policy)

        writer.add_scalar('data/reconstruction_loss_per_rollout', np.mean(recon_losses), global_update)
        writer.add_scalar('data/kld_loss_per_rollout', np.mean(kld_losses), global_update)

        global_step += (num_worker * num_step)
        
        if global_update % rec_interval == 0:
            with torch.no_grad():
                # random_obs_norm = total_next_obs[random_obs_choice]
                # reconstructed_state = agent.reconstruct(random_obs_norm)

                # random_obs_norm = (random_obs_norm - random_obs_norm.min()) / (random_obs_norm.max() - random_obs_norm.min())
                # reconstructed_state = (reconstructed_state - reconstructed_state.min()) / (reconstructed_state.max() - reconstructed_state.min())

                # writer.add_image('Original', random_obs, global_update)
                # writer.add_image('Original Normalized', random_obs_norm, global_update)

                # Log one randomly chosen next observation and its reconstruction.
                random_state = total_next_obs[np.random.randint(total_next_obs.shape[0])]
                reconstructed_state = agent.reconstruct(random_state)

                writer.add_image('Original', random_state, global_update)
                writer.add_image('Reconstructed', reconstructed_state, global_update)

        if global_update % save_interval == 0:
            print('Saving model at global step={}, num rollouts={}.'.format(
                global_step, global_update))
            torch.save(agent.model.state_dict(), model_path + "_{}.pt".format(global_update))
            torch.save(agent.vae.state_dict(), predictor_path + '_{}.pt'.format(global_update))

            # Save stats to pickle file
            with open('models/{}_{}_run{}_stats_{}.pkl'.format(env_id, train_method, run_id, global_update),'wb') as f:
                pickle.dump(stats, f)

        global_update += 1

        # NOTE(review): worker processes are never joined or closed here.
        if global_update == num_rollouts + num_pretrain_rollouts:
            print('Finished Training.')
            break
예제 #13
0
def full_state_rgb_train(env, tile_size=6):
    """Wrap ``env`` so that its observations are rendered RGB images.

    Args:
      env: The environment to wrap.
      tile_size: Pixel size of each grid tile in the rendered image
        (default 6).

    Returns:
      ``env`` wrapped in ``RGBImgObsWrapper``.
    """
    return RGBImgObsWrapper(env, tile_size=tile_size)