Example #1
    def sample_episode(self, is_train, render=False):
        """Samples one episode from the environment."""
        self.init(is_train)
        episode, done = [], False
        with self._env.val_mode() if not is_train else contextlib.suppress():
            with self._agent.val_mode() if not is_train else contextlib.suppress():
                with self._agent.rollout_mode():
                    while not done and self._episode_step < self._max_episode_len:
                        # perform one rollout step
                        agent_output = self.sample_action(self._obs)
                        if agent_output.action is None:
                            break
                        agent_output = self._postprocess_agent_output(agent_output)
                        if render:
                            render_obs = self._env.render()
                        obs, reward, done, info = self._env.step(agent_output.action)
                        obs = self._postprocess_obs(obs)
                        episode.append(AttrDict(
                            observation=self._obs,
                            reward=reward,
                            done=done,
                            action=agent_output.action,
                            observation_next=obs,
                            info=obj2np(info),
                        ))
                        if render:
                            episode[-1].update(AttrDict(image=render_obs))

                        # update stored observation
                        self._obs = obs
                        self._episode_step += 1
        episode[-1].done = True     # make sure episode is marked as done at final time step

        return listdict2dictlist(episode)
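
The rollout loop collects one AttrDict per step and finally converts the list of step dicts into a dict of lists via listdict2dictlist. A minimal sketch of what such a helper does, using only standard Python (this reimplementation is illustrative, not the repository's actual utility):

def listdict2dictlist(list_of_dicts):
    """Convert [{'a': 1}, {'a': 2}] into {'a': [1, 2]}."""
    return {key: [d[key] for d in list_of_dicts] for key in list_of_dicts[0]}

episode = [dict(reward=0.0, done=False), dict(reward=1.0, done=True)]
print(listdict2dictlist(episode))  # {'reward': [0.0, 1.0], 'done': [False, True]}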
Example #2
    def get_sample(self):
        data_dict = AttrDict()
        data_dict.images = np.random.rand(self.spec['max_seq_len'], 3, self.img_sz, self.img_sz).astype(np.float32)
        data_dict.states = np.random.rand(self.spec['max_seq_len'], self.spec['state_dim']).astype(np.float32)
        data_dict.actions = np.random.rand(self.spec['max_seq_len'] - 1, self.spec['n_actions']).astype(np.float32)

        return data_dict
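
The AttrDict used throughout these examples is a dict whose entries are also reachable as attributes. A minimal, self-contained sketch of that behavior and of a get_sample-style dict of random arrays (the three-line AttrDict below only mimics the codebase's class, and the spec sizes are illustrative):

import numpy as np

class AttrDict(dict):
    __getattr__ = dict.__getitem__
    __setattr__ = dict.__setitem__

spec = dict(max_seq_len=10, state_dim=4, n_actions=2)
sample = AttrDict()
sample.states = np.random.rand(spec['max_seq_len'], spec['state_dim']).astype(np.float32)
sample.actions = np.random.rand(spec['max_seq_len'] - 1, spec['n_actions']).astype(np.float32)
print(sample.states.shape, sample['actions'].shape)  # attribute and key access are interchangeable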
Example #3
    def __init__(self, args):
        self.args = args
        self.setup_device()

        # set up params
        self.conf = conf = self.get_config()

        self._hp = self._default_hparams()
        self._hp.overwrite(conf.general)  # override defaults with config file
        self._hp.exp_path = make_path(conf.exp_dir, args.path, args.prefix,
                                      args.new_dir)
        self.log_dir = log_dir = os.path.join(self._hp.exp_path, 'events')
        print('using log dir: ', log_dir)
        self.conf = self.postprocess_conf(conf)
        if args.deterministic: set_seeds()

        # set up logging + training monitoring
        self.writer = self.setup_logging(conf, self.log_dir)
        self.setup_training_monitors()

        # build dataset, model, logger, etc.
        train_params = AttrDict(logger_class=self._hp.logger,
                                model_class=self._hp.model,
                                n_repeat=self._hp.epoch_cycles_train,
                                dataset_size=-1)
        self.logger, self.model, self.train_loader = self.build_phase(
            train_params, 'train')

        test_params = AttrDict(
            logger_class=self._hp.logger if self._hp.logger_test is None else self._hp.logger_test,
            model_class=self._hp.model if self._hp.model_test is None else self._hp.model_test,
            n_repeat=1,
            dataset_size=args.val_data_size)
        self.logger_test, self.model_test, self.val_loader = self.build_phase(
            test_params, phase='val')

        # set up optimizer + evaluator
        self.optimizer = self.get_optimizer_class()(
            filter(lambda p: p.requires_grad, self.model.parameters()), lr=self._hp.lr)
        self.evaluator = self._hp.evaluator(self._hp,
                                            self.log_dir,
                                            self._hp.top_of_n_eval,
                                            self._hp.top_comp_metric,
                                            tb_logger=self.logger_test)

        # load model params from checkpoint
        self.global_step, start_epoch = 0, 0
        if args.resume or conf.ckpt_path is not None:
            start_epoch = self.resume(args.resume, conf.ckpt_path)

        if args.val_sweep:
            self.run_val_sweep()
        elif args.train:
            self.train(start_epoch)
        else:
            self.val()
Example #4
 def get_episode_info(self):
     episode_info = AttrDict(
         episode_reward=self._episode_reward,
         episode_length=self._episode_step,
     )
     if hasattr(self._env, "get_episode_info"):
         episode_info.update(self._env.get_episode_info())
     return episode_info
Example #5
    def _get_default_env_config(self):
        default_task_params = AttrDict(max_tower_height=4)

        default_env_config = AttrDict(
            task_generator=FixedSizeSingleTowerBlockTaskGenerator,
            task_params=default_task_params,
            dimension=2)
        return default_env_config
Example #6
    def forward(self, inputs, use_learned_prior=False):
        """
        forward pass at training time
        """
        output = AttrDict()

        output.pred_act = self._compute_output_dist(self._net_inputs(inputs))
        return output
Example #7
 def assert_begin(inputs, initial_inputs, static_inputs):
     initial_inputs = initial_inputs or AttrDict()
     static_inputs = static_inputs or AttrDict()
     assert not (static_inputs.keys() & inputs.keys()), 'Static inputs and inputs overlap'
     assert not (static_inputs.keys() & initial_inputs.keys()), 'Static inputs and initial inputs overlap'
     assert not (inputs.keys() & initial_inputs.keys()), 'Inputs and initial inputs overlap'
     
     return initial_inputs, static_inputs
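
A hypothetical usage of assert_begin: it only checks that the three key sets are disjoint and substitutes empty AttrDicts for missing arguments, so plain dicts with illustrative values are enough to exercise it:

inputs = dict(cell_input=0)
initial_inputs = dict(hidden_state=1)
initial_inputs, static_inputs = assert_begin(inputs, initial_inputs, None)  # passes: no key overlap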
Example #8
    def loss(self, model_output, inputs):
        losses = AttrDict()

        # reconstruction loss
        losses.nll = NLL()(model_output.pred_act, self._regression_targets(inputs))

        losses.total = self._compute_total_loss(losses)
        return losses
Example #9
 def get_default_params(self):
     params = AttrDict(
         normalize=True,
         activation=nn.LeakyReLU(0.2, inplace=True),
         normalization=self.builder.normalization,
         normalization_params=AttrDict()
     )
     return params
Example #10
 def forward(self, input, hidden_state, length=None):
     """
     :param input: tensor of shape batch x time x channels
     :return:
     """
     if length is None: length = input.shape[1]
     initial_state = AttrDict(hidden_state=hidden_state)
     outputs = super().forward(AttrDict(cell_input=input), length=length, initial_inputs=initial_state)
     return outputs
Example #11
    def get_episode_info(self):
        episode_info = AttrDict()

        flag_names = ['_1_grasped', '_2_lift', '_4_stack', '_5_stack_final']
        flag_values = [self._grasped_flag, self._lifted_flag, self._stacked_flag, self._stacked_final_flag]
        for i in range(len(flag_names)):
            episode_info.update({"block{}".format(flag_names[i]):
                sum([int(flag_values[i][task_idx]) for task_idx in range(len(self._task))])})

        return episode_info
Example #12
    def forward(self, inputs, use_learned_prior=False):
        """Forward pass of the SPIRL model.
        :arg inputs: dict with 'states', 'actions', 'images' keys from data loader
        :arg use_learned_prior: if True, decodes samples from learned prior instead of posterior, used for RL
        """
        output = AttrDict()
        inputs.observations = inputs.actions    # for seamless evaluation

        # run inference
        output.q = self._run_inference(inputs)

        # compute (fixed) prior
        output.p = get_fixed_prior(output.q)

        # infer learned skill prior
        output.q_hat = self.compute_learned_prior(self._learned_prior_input(inputs))
        if use_learned_prior:
            output.p = output.q_hat     # use output of learned skill prior for sampling

        # sample latent variable
        output.z = output.p.sample() if self._sample_prior else output.q.sample()
        output.z_q = output.z.clone() if not self._sample_prior else output.q.sample()   # for loss computation

        # decode
        assert self._regression_targets(inputs).shape[1] == self._hp.n_rollout_steps
        output.reconstruction = self.decode(output.z,
                                            cond_inputs=self._learned_prior_input(inputs),
                                            steps=self._hp.n_rollout_steps,
                                            inputs=inputs)
        return output
Example #13
    def get_episode_info(self):
        episode_info = AttrDict()

        flag_names = ['_1_reach', '_2_lift', '_3_deliver', '_4_stack']
        flag_values = [self._reached_flag, self._lifted_flag,
                       self._delivered_flag, self._stacked_flag]
        for i in range(len(flag_names)):
            episode_info.update({"block{}".format(flag_names[i]): 
                sum([int(flag_values[i][task_idx]) for task_idx in range(len(self._task))])})

        return episode_info
Example #14
def update_with_mpi_config(conf):
    mpi_config = AttrDict()
    rank = MPI.COMM_WORLD.Get_rank()
    mpi_config.rank = rank
    mpi_config.is_chef = rank == 0
    mpi_config.num_workers = MPI.COMM_WORLD.Get_size()
    conf.mpi = mpi_config

    # update conf
    conf.general.seed = conf.general.seed + rank
    return conf
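
A usage sketch for update_with_mpi_config, assuming mpi4py is installed and conf is built with the codebase's AttrDict: each worker offsets the base seed by its rank so no two processes reuse the same seed (the seed value is illustrative).

conf = AttrDict(general=AttrDict(seed=123))
conf = update_with_mpi_config(conf)
print(conf.mpi.rank, conf.mpi.num_workers, conf.general.seed)  # e.g. rank 0 of 1 keeps seed 123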
Example #15
 def decode(self, z, cond_inputs, steps, inputs=None):
     """Runs forward pass of decoder given skill embedding.
     :arg z: skill embedding
     :arg cond_inputs: info that decoder is conditioned on
     :arg steps: number of steps decoder is rolled out
     """
     lstm_init_input = self.decoder_input_initalizer(cond_inputs)
     lstm_init_hidden = self.decoder_hidden_initalizer(cond_inputs)
     return self.decoder(lstm_initial_inputs=AttrDict(x_t=lstm_init_input),
                         lstm_static_inputs=AttrDict(z=z),
                         steps=steps,
                         lstm_hidden_init=lstm_init_hidden).pred
Example #16
 def __getitem__(self, index):
     # sample start index in data range
     seq = self._sample_seq()
     start_idx = np.random.randint(
         0, seq.states.shape[0] - self.subseq_len - 1)
     output = AttrDict(
         states=seq.states[start_idx:start_idx + self.subseq_len],
         actions=seq.actions[start_idx:start_idx + self.subseq_len - 1],
         pad_mask=np.ones((self.subseq_len, )),
     )
      if self.remove_goal:
          output.states = output.states[..., :int(output.states.shape[-1] / 2)]
     return output
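
The same random-crop logic in a standalone numpy sketch with toy shapes: the start index is drawn so that a window of subseq_len states and subseq_len - 1 actions always fits inside the trajectory.

import numpy as np

subseq_len = 5
states = np.arange(20)[:, None]   # toy trajectory: 20 one-dimensional states
actions = np.arange(19)[:, None]  # one fewer action than states
start_idx = np.random.randint(0, states.shape[0] - subseq_len - 1)
crop_states = states[start_idx:start_idx + subseq_len]
crop_actions = actions[start_idx:start_idx + subseq_len - 1]
assert crop_states.shape[0] == crop_actions.shape[0] + 1 == subseq_len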
Example #17
    def act(self, obs):
        """Output dict contains is_hl_step in case high-level action was performed during this action."""
        obs_input = obs[None] if len(obs.shape) == 1 else obs  # need batch input for agents
        output = AttrDict()
        if self._perform_hl_step_now:
            # perform step with high-level policy
            self._last_hl_output = self.hl_agent.act(obs_input)
            output.is_hl_step = True
            if len(obs_input.shape) == 2 and len(self._last_hl_output.action.shape) == 1:
                self._last_hl_output.action = self._last_hl_output.action[None]  # add batch dim if necessary
                self._last_hl_output.log_prob = self._last_hl_output.log_prob[None]
        else:
            output.is_hl_step = False
        output.update(prefix_dict(self._last_hl_output, 'hl_'))

        # perform step with low-level policy
        assert self._last_hl_output is not None
        output.update(
            self.ll_agent.act(
                self.make_ll_obs(obs_input, self._last_hl_output.action)))

        return self._remove_batch(output) if len(obs.shape) == 1 else output
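
The prefix_dict helper used above presumably prepends a prefix to every key so high-level and low-level outputs can be merged without collisions; an assumed minimal reimplementation for illustration (not the repository's actual code):

def prefix_dict(d, prefix):
    """Return a copy of d with `prefix` prepended to every key."""
    return {prefix + key: value for key, value in d.items()}

print(prefix_dict(dict(action=1, log_prob=-0.5), 'hl_'))  # {'hl_action': 1, 'hl_log_prob': -0.5}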
Example #18
 def update(self, experience_batch):
     if 'delay' in self._hp.omega_schedule_params and self._update_steps < self._hp.omega_schedule_params.delay:
         # if schedule has warmup phase in which *only* prior is sampled, train policy to minimize divergence
         self.replay_buffer.append(experience_batch)
         experience_batch = self.replay_buffer.sample(n_samples=self._hp.batch_size)
         experience_batch = map2torch(experience_batch, self._hp.device)
         policy_output = self._run_policy(experience_batch.observation)
         policy_loss = policy_output.prior_divergence.mean()
         self._perform_update(policy_loss, self.policy_opt, self.policy)
         self._update_steps += 1
         info = AttrDict(prior_divergence=policy_output.prior_divergence.mean())
     else:
         info = super().update(experience_batch)
     info.omega = self._omega(self._update_steps)
     return info
Example #19
 def _compute_action_dist(self, obs):
     """Splits concatenated input obs into image and vector observation and passes through network."""
     split_obs = AttrDict(vector=obs[:, :self._hp.input_dim],
                          image=obs[:, self._hp.input_dim:].reshape(
                              -1, self._hp.input_nc, self._hp.input_res,
                              self._hp.input_res))
     return super()._compute_action_dist(split_obs)
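
The same flat-observation split in a standalone numpy sketch (the real code operates on torch tensors; all dimensions are illustrative): the first input_dim columns form the vector observation and the remainder reshapes into an image batch.

import numpy as np

input_dim, input_nc, input_res = 10, 3, 8
obs = np.random.rand(4, input_dim + input_nc * input_res * input_res)
vector = obs[:, :input_dim]
image = obs[:, input_dim:].reshape(-1, input_nc, input_res, input_res)
assert vector.shape == (4, input_dim) and image.shape == (4, input_nc, input_res, input_res)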
Example #20
    def sample_batch(self, batch_size, is_train=True, global_step=None):
        """Samples an experience batch of the required size."""
        experience_batch = []
        step = 0
        with self._env.val_mode() if not is_train else contextlib.suppress():
            with self._agent.val_mode() if not is_train else contextlib.suppress():
                with self._agent.rollout_mode():
                    while step < batch_size:
                        # perform one rollout step
                        agent_output = self.sample_action(self._obs)
                        if agent_output.action is None:
                            self._episode_reset(global_step)
                            continue
                        agent_output = self._postprocess_agent_output(agent_output)
                        obs, reward, done, info = self._env.step(agent_output.action)
                        obs = self._postprocess_obs(obs)
                        experience_batch.append(AttrDict(
                            observation=self._obs,
                            reward=reward,
                            done=done,
                            action=agent_output.action,
                            observation_next=obs,
                        ))

                        # update stored observation
                        self._obs = obs
                        step += 1; self._episode_step += 1; self._episode_reward += reward

                        # reset if episode ends
                        if done or self._episode_step >= self._max_episode_len:
                            if not done:    # force done to be True for timeout
                                experience_batch[-1].done = True
                            self._episode_reset(global_step)

        return listdict2dictlist(experience_batch), step
Example #21
    def setup_logging(self, conf, log_dir):
        if not self.args.dont_save:
            print('Writing to the experiment directory: {}'.format(
                self._hp.exp_path))
            if not os.path.exists(self._hp.exp_path):
                os.makedirs(self._hp.exp_path)
            save_cmd(self._hp.exp_path)
            save_git(self._hp.exp_path)
            save_config(
                conf.conf_path,
                os.path.join(self._hp.exp_path,
                             "conf_" + datetime_str() + ".py"))
            if self._hp.logging_target == 'wandb':
                exp_name = f"{'_'.join(self.args.path.split('/')[-3:])}_{self.args.prefix}" if self.args.prefix \
                    else os.path.basename(self.args.path)
                writer = WandBLogger(
                    exp_name,
                    WANDB_PROJECT_NAME,
                    entity=WANDB_ENTITY_NAME,
                    path=self._hp.exp_path,
                    conf=conf,
                    exclude=['model_rewards', 'data_dataset_spec_rewards'])
            else:
                writer = SummaryWriter(log_dir)
        else:
            writer = None

        # set up additional logging args
        self._logging_kwargs = AttrDict()
        return writer
Example #22
 def _split_obs(self, obs):
     assert obs.shape[1] == self._policy.state_dim + self._policy.latent_dim
      return AttrDict(
          cond_input=obs[:, :-self._policy.latent_dim],  # condition decoding on state
         z=obs[:, -self._policy.latent_dim:],
     )
Example #23
 def forward(self, lstm_initial_inputs, steps, lstm_inputs=None, lstm_static_inputs=None, lstm_hidden_init=None):
     if lstm_inputs is None:
         lstm_inputs = {}
     if lstm_hidden_init is not None:
         self.cell.hidden_var = lstm_hidden_init     # initialize hidden state of LSTM if given
     lstm_outputs = self.lstm(lstm_inputs, steps, lstm_initial_inputs, lstm_static_inputs)
     return AttrDict(pred=lstm_outputs.x_t)
Example #24
 def get_default_params(self):
     params = super().get_default_params()
     params.update(AttrDict(
         kernel_size=4,
         stride=2,
     ))
     return params
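
The override pattern above, subclass defaults written on top of parent defaults via update, can be sketched with any dict that exposes keys as attributes (the AttrDict below only mimics the codebase's class and the parameter values are illustrative):

class AttrDict(dict):
    __getattr__ = dict.__getitem__
    __setattr__ = dict.__setitem__

params = AttrDict(normalize=True, kernel_size=3)  # pretend parent defaults
params.update(AttrDict(kernel_size=4, stride=2))  # subclass values win
print(params.kernel_size, params.stride)          # 4 2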
Example #25
class Stack4BlockStackEnvV0(BlockStackEnv):
    DEFAULT_QUAT = np.array([0.70710678, 0, 0, -0.70710678])
    TASK = [(2, 3), (3, 1), (1, 4), (4, 0)]
    BLOCK_POS = [
        AttrDict(pos=np.array([0, -0.4]), quat=DEFAULT_QUAT),
        AttrDict(pos=np.array([0, -0.2]), quat=DEFAULT_QUAT),
        AttrDict(pos=np.array([0, 0.0]), quat=DEFAULT_QUAT),
        AttrDict(pos=np.array([0, 0.2]), quat=DEFAULT_QUAT),
        AttrDict(pos=np.array([0, 0.4]), quat=DEFAULT_QUAT)
    ]

    def _get_default_env_config(self):
        default_env_config = super()._get_default_env_config()
        default_env_config.fixed_task = self.TASK
        default_env_config.fixed_block_pos = self.BLOCK_POS
        return default_env_config
Example #26
 def sample_rand(self, obs):
     with torch.no_grad():
         with no_batchnorm_update(self.prior_net):
             prior_dist = self.prior_net.compute_learned_prior(obs, first_only=True).detach()
     action = prior_dist.sample()
     action, log_prob = self._tanh_squash_output(action, 0)        # ignore log_prob output
     return AttrDict(action=action, log_prob=log_prob)
Example #27
    def run(self, inputs, use_learned_prior=True):
        """Policy interface for model. Runs decoder if action plan is empty, otherwise returns next action from action plan.
        :arg inputs: dict with 'states', 'actions', 'images' keys from environment
        :arg use_learned_prior: if True, uses learned prior otherwise samples latent from uniform prior
        """
        if not self._action_plan:
            inputs = map2torch(inputs, device=self.device)

            # sample latent variable from prior
            z = self.compute_learned_prior(self._learned_prior_input(inputs), first_only=True).sample() \
                if use_learned_prior else Gaussian(torch.zeros((1, self._hp.nz_vae*2), device=self.device)).sample()

            # decode into action plan
            # HACK: the flat LSTM decoder can only take batch_size inputs, so tile the single sample
            z = z.repeat(self._hp.batch_size, 1)
            input_obs = self._learned_prior_input(inputs).repeat(self._hp.batch_size, 1)
            actions = self.decode(z,
                                  cond_inputs=input_obs,
                                  steps=self._hp.n_rollout_steps)[0]
            self._action_plan = deque(split_along_axis(map2np(actions), axis=0))

        return AttrDict(action=self._action_plan.popleft()[None])
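
The action-plan mechanism can be sketched in isolation: decode once into a sequence of actions, store them in a deque, and pop one action per call until the plan is empty (the random array below stands in for the decoder output).

from collections import deque
import numpy as np

plan = np.random.rand(10, 2)  # pretend decoded plan: 10 steps of 2-dim actions
action_plan = deque(np.split(plan, plan.shape[0], axis=0))
next_action = action_plan.popleft()  # one action consumed per environment step
assert next_action.shape == (1, 2) and len(action_plan) == 9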
Example #28
    def _get_reward(self):
        """Compute reward for stacking blocks without order."""
        rew_dict = AttrDict()

        max_height, total_rew = 0., 0.
        heights, supported_heights = np.zeros(len(self._blocks)), np.zeros(len(self._blocks))
        for i, block in enumerate(self._blocks):
            height = block.dist_lifted
            heights[i] = height

            # set flags
            if not self._grasped_flag[i]:
                self._grasped_flag[i] = block.grasped(self.gripper_pos, self.gripper_finger_dist, self.gripper_finger_poses)
            if not self._lifted_flag[i]:
                self._lifted_flag[i] = (not self._hp.restrict_grasped or self._grasped_flag[i]) and \
                        (not self._hp.restrict_upright or block.upright) and block.lifted
            if not self._delivered_flag[i]:
                self._delivered_flag[i] = (not self._hp.restrict_grasped or self._grasped_flag[i]) \
                        and (not self._hp.restrict_upright or block.upright) \
                        and any([block.above(b) for b in self._blocks if b.name != block.name])

            # compute reward
            if (not self._hp.restrict_grasped or self._grasped_flag[i]) and \
                    (not self._hp.restrict_upright or block.upright) and \
                    self._has_support(block, [b for b in self._blocks if block.name != b.name]):
                self._stacked_flag[i] = True
                supported_heights[i] = height
                if height > max_height:
                    max_height = height
            if self._delivered_flag[i]:
                total_rew += self.LIFTED_ABOVE_REWARD
            elif self._lifted_flag[i]:
                total_rew += self.LIFTED_REWARD
        self._final_height = max_height / (2*self._hp.block_size)

        total_rew += max_height * self.REWARD_SCALE

        if self._hp.rotation_penalty:
            # add per-step penalty for each rotated block
            rot_penalty = sum([self.ROTATION_PENALTY if not b.upright else 0 for b in self._blocks])
            total_rew -= rot_penalty
            rew_dict["rot_penalty"] = np.array(rot_penalty).round(3)

        rew_dict["heights"] = heights.round(3)
        rew_dict["sup_heights"] = supported_heights.round(3)
        rew_dict["rew_total"] = np.array(total_rew).round(3)
        rew_dict["max_height"] = np.array(self._final_height).round(3)
        #rew_dict["z_ang"] = np.array([b.z_angle * 180 / np.pi for b in self._blocks]).round(1)
        rew_dict["grasped_1"] = np.array(self._grasped_flag[:5])
        rew_dict["grasped_2"] = np.array(self._grasped_flag[5:])
        rew_dict["lifted_1"] = np.array(self._lifted_flag[:5])
        rew_dict["lifted_2"] = np.array(self._lifted_flag[5:])
        rew_dict["gripper_finger_dist"] = np.array(self.gripper_finger_dist).round(3)


        self._prev_block_pos = [copy.deepcopy(b.pos) for b in self._blocks]  # update for next round of reward comp
        self._prev_gripper_pos = copy.deepcopy(self.gripper_pos)

        return rew_dict
Example #29
    def step(self, action):
        """Step the environment with symmetric gripper movements."""
        # process action
        raw_action = action
        action = self._pad_action(action)
        real_action = np.zeros((len(action) + 1,))
        real_action[:len(action)] = action
        real_action[-1] = self._adjust_gripper_finger_action(action[-1])
        real_action[-2] = -real_action[-1]
        if self._hp.perturb_actions and np.random.rand() < self._hp.perturb_prob:
            real_action += np.random.normal(0, self._hp.perturb_scale, real_action.shape[0])

        # step through environment
        # with timing("Step raw "):
        obs, rew, done, info = super().step(real_action)

        # apply action penalty
        if self._hp.action_penalty_weight > 0.0:
            action_penalty = self._hp.action_penalty_weight * self._compute_action_penalty(action)
            rew -= action_penalty
            info.update(AttrDict(action_penalty=action_penalty))

        # episode done
        if self.task_complete():
            done = True

        # terminate episode if gripper or object not on arena
        if self._hp.reset_with_boundary:
            unflattened_obs = self._unflatten_block_obs(obs)
            gripper_pos = unflattened_obs.gripper_pos
            if self._position_invalid(gripper_pos):
                done = True
            for block in self._blocks:
                if self._position_invalid(block.pos):
                    done = True

        # internal variable updates
        self._t += 1
        self._last_gripper_action = action[4]

        # add original action
        info.update(AttrDict(
            raw_action=np.array(raw_action).round(3)
        ))

        return obs, rew, done, info
Example #30
 def _default_hparams(self):
     default_dict = ParamDict({
         'omega_schedule': ConstantSchedule,  # schedule used for omega param
         'omega_schedule_params': AttrDict(   # parameters for omega schedule
              p=0.1,
         ),
     })
     return super()._default_hparams().overwrite(default_dict)