Exemplo n.º 1
0
    def _train_step(self, batch, positive_grad, do_log):
        """Run one optimisation step on ``batch`` and return the loss.

        The train model is evaluated (with gradients) on ``batch.states`` with
        its two leading dims merged; the smooth model is evaluated on the same
        input without gradients. Their outputs are reshaped back to
        ``batch.states.shape[:2]`` and stored on ``batch`` as ``logits``,
        ``logits_smooth`` and ``state_values`` before ``self._get_loss`` runs.

        Args:
            batch: AttrDict of tensors; mutated in place (outputs added, every
                entry except ``states`` moved to CPU).
            positive_grad: forwarded unchanged to ``self._get_loss``.
            do_log: when true, ``self.logger`` / ``self.frame_train`` are passed
                to the train model, and the flag is forwarded to ``_get_loss``.

        Returns:
            The mean loss tensor that was backpropagated.
        """
        with torch.enable_grad():
            actor_params = AttrDict()
            if do_log:
                actor_params.logger = self.logger
                actor_params.cur_step = self.frame_train

            # The two leading dims of the states are merged for the forward pass
            # and restored on the outputs afterwards.
            lead_shape = batch.states.shape[:2]
            flat_states = batch.states.reshape(-1, *batch.states.shape[2:])

            actor_out = self._train_model(flat_states, **actor_params)
            with torch.no_grad():
                actor_out_smooth = self._smooth_model(flat_states)

            batch.logits = actor_out.logits.reshape(*lead_shape, *actor_out.logits.shape[1:])
            # Reshape the smooth logits with their own trailing shape, not the train model's.
            batch.logits_smooth = actor_out_smooth.logits.reshape(
                *lead_shape, *actor_out_smooth.logits.shape[1:])
            batch.state_values = actor_out.state_values.reshape(*lead_shape)

            # Everything except the states is moved to CPU before computing the loss.
            for k, v in list(batch.items()):
                batch[k] = v if k == 'states' else v.cpu()

            loss = self._get_loss(batch, positive_grad, do_log).mean()

        # optimize
        loss.backward()
        if self.grad_clip_norm is not None:
            clip_grad_norm_(self._train_model.parameters(), self.grad_clip_norm)
        self._optimizer.step()
        self._optimizer.zero_grad()

        return loss
Exemplo n.º 2
0
    def _search_action(self, states):
        """Pick an action per state by Monte-Carlo rollouts through the learned model.

        ``self.num_actions`` candidate first actions are sampled from the start
        policy. Each candidate is unrolled ``self.num_rollouts`` times for
        ``self.rollout_depth`` model steps, with further actions sampled from the
        predicted policy. A rollout's value is the discounted sum of predicted
        rewards plus the discounted predicted value of the last reached state.
        Per candidate, the best quarter of rollouts (at least one) is averaged,
        and the candidate with the highest average is chosen.

        Returns:
            ``(start_logits, chosen_actions, start_state_values)``.
        """
        start_ac_out = AttrDict(self._train_model.encode_observation(states))
        # (B, A, R, X)
        hidden = start_ac_out.hidden\
            .unsqueeze(-2).repeat_interleave(self.num_actions, -2)\
            .unsqueeze(-2).repeat_interleave(self.num_rollouts, -2)
        assert hidden.shape == (*start_ac_out.hidden.shape[:-1], self.num_actions, self.num_rollouts, start_ac_out.hidden.shape[-1])
        # One sampled first action per candidate slot A, shared by its R rollouts.
        actions = self.pd.sample(start_ac_out.logits.unsqueeze(-2).repeat_interleave(self.num_actions, -2))
        start_actions = actions
        actions = actions.unsqueeze(-2).repeat_interleave(self.num_rollouts, -2)
        assert actions.shape[:-1] == hidden.shape[:-1]

        # (B, A, R): discounted sum of predicted rewards along each rollout.
        value_targets = 0
        for i in range(self.rollout_depth):
            input_actions = self.pd.to_inputs(actions)
            ac_out = AttrDict(self._train_model(hidden, input_actions))
            hidden = ac_out.hidden
            actions = self.pd.sample(ac_out.logits)
            rewards = self._train_model.reward_encoder(ac_out.reward_bins)
            value_targets += self.reward_discount ** i * rewards

        # Bootstrap with the predicted value of the last reached state.
        last_state_values = self._train_model.value_encoder(ac_out.state_value_bins)
        value_targets += self.reward_discount ** self.rollout_depth * last_state_values
        # Keep the best quarter of rollouts, but at least one: with
        # num_rollouts < 4, k == 0 would make the mean below NaN.
        num_best = max(1, self.num_rollouts // 4)
        value_targets = value_targets.topk(num_best, -1).values
        # (B, A)
        value_targets = value_targets.mean(-1)
        assert value_targets.shape == hidden.shape[:2]
        ac_idx = value_targets.argmax(-1, keepdim=True)
        actions = start_actions.gather(-2, ac_idx.unsqueeze(-1)).squeeze(-2)
        assert actions.shape[:-1] == hidden.shape[:1]
        start_state_values = self._train_model.value_encoder(start_ac_out.state_value_bins)
        return start_ac_out.logits, actions, start_state_values
Exemplo n.º 3
0
    def _train_update(self, data: AttrDict, positive_grad: bool):
        """Run one pass of mini-batch ``_train_step`` updates over ``data`` and log stats.

        ``data.states`` has at least two leading dims. The second one
        (``dim=1``) is split into contiguous chunks that form the mini-batches.
        Only the last mini-batch is run with logging enabled.

        Args:
            data: AttrDict with ``states``, ``logits``, ``random_grad``,
                ``actions``, ``rewards`` and ``dones``. ``logits`` is passed on
                as ``logits_old``.
            positive_grad: forwarded to ``self._train_step``.
        """
        num_samples = data.states.shape[0] * data.states.shape[1]
        num_rollouts = data.states.shape[1]

        data = AttrDict(states=data.states, logits_old=data.logits, random_grad=data.random_grad,
                        actions=data.actions, rewards=data.rewards, dones=data.dones)

        # torch.chunk may return FEWER chunks than requested (e.g. 5 rows into 4
        # chunks yields 3, and asking for more chunks than rows yields one per row),
        # so take the real batch count from the result instead of asserting on it.
        requested_batches = max(1, num_samples // self.batch_size)
        rand_idx = torch.arange(num_rollouts, device=self.device_train).chunk(requested_batches)
        num_batches = len(rand_idx)

        old_model = {k: v.clone() for k, v in self._train_model.state_dict().items()}
        kls_policy = []
        kls_replay = []
        loss = None

        with DataLoader(data, rand_idx, self.device_train, 4, dim=1) as data_loader:
            for batch_index in range(num_batches):
                is_last_batch = batch_index == num_batches - 1
                batch = AttrDict(data_loader.get_next_batch())
                loss = self._train_step(batch, positive_grad, self._do_log and is_last_batch)
                # kl_smooth / kl_replay are presumably written onto the batch by
                # _get_loss (called from _train_step) — TODO confirm.
                kls_policy.append(batch.kl_smooth.mean().item())
                kls_replay.append(batch.kl_replay.mean().item())

        kl_policy = np.mean(kls_policy)
        kl_replay = np.mean(kls_replay)

        if self._do_log:
            if loss is not None:
                self.logger.add_scalar('total_loss', loss, self.frame_train)
            self.logger.add_scalar('kl', kl_policy, self.frame_train)
            self.logger.add_scalar('kl_replay', kl_replay, self.frame_train)
            self.logger.add_scalar('model_abs_diff', model_diff(old_model, self._train_model), self.frame_train)
            self.logger.add_scalar('model_max_diff', model_diff(old_model, self._train_model, True), self.frame_train)
Exemplo n.º 4
0
    def _create_data(self):
        """Return training data built from the replay buffer.

        Always contains the last ``self.horizon`` samples. When replay is enabled
        (``self.num_batches != 0``) and the buffer holds enough samples, every
        tensor is extended with ``self.num_batches`` randomly sampled sequences
        per on-policy column, giving ``(H, B * (num_batches + 1), *)``.
        """
        def cat_replay(last, rand):
            # (H, B, *) + (H, B * replay, *) = (H, B, 1, *) + (H, B, replay, *) =
            # = (H, B, replay + 1, *) = (H, B * (replay + 1), *)
            H, B, *_ = last.shape
            merged = torch.cat([
                last.unsqueeze(2),
                rand.reshape(H, B, self.num_batches, *last.shape[2:])], 2)
            return merged.reshape(H, B * (self.num_batches + 1), *last.shape[2:])

        # NOTE(review): horizon // horizon is always 1 (and raises for horizon == 0);
        # this looks like a leftover of a separate on-policy horizon. Kept unchanged.
        h_reduce = self.horizon // self.horizon

        def fix_on_policy_horizon(v):
            # (h_reduce * H, B, *) -> (H, h_reduce * B, *)
            return v.reshape(h_reduce, self.horizon, *v.shape[1:])\
                .transpose(0, 1)\
                .reshape(self.horizon, h_reduce * v.shape[1], *v.shape[2:])

        # (H, B, *)
        last_samples = self._replay_buffer.get_last_samples(self.horizon)
        last_samples = {k: fix_on_policy_horizon(v) for k, v in last_samples.items()}

        replay_ready = self.num_batches != 0 and len(self._replay_buffer) >= \
            max(self.horizon * self.num_actors * max(1, self.num_batches), self.min_replay_size)
        if not replay_ready:
            return AttrDict(last_samples)

        num_rollouts = self.num_actors * self.num_batches * h_reduce
        rand_samples = self._replay_buffer.sample(num_rollouts, self.horizon, self.replay_end_sampling_factor)
        # Pair on-policy and replay tensors by key, not by dict position, so a
        # different key order in the two dicts cannot mix up tensors.
        return AttrDict({k: cat_replay(last_samples[k], rand)
                         for k, rand in rand_samples.items()})
Exemplo n.º 5
0
    def _reorder_data(self, data: AttrDict) -> AttrDict:
        """Regroup every tensor in ``data`` from step-major to actor-major order.

        Each value's first axis is treated as ``seq * num_actors`` with the actor
        index varying fastest. The result keeps the original shape but lists all
        steps of actor 0 first, then all steps of actor 1, and so on.
        """
        actors = self.num_actors
        regrouped = {}
        for key, tensor in data.items():
            # (seq, num_actors, ...)
            by_step = tensor.reshape(-1, actors, *tensor.shape[1:])
            # (num_actors, seq, ...) flattened back to the input shape
            regrouped[key] = by_step.transpose(0, 1).reshape(tensor.shape)
        return AttrDict(regrouped)
Exemplo n.º 6
0
    def _update_traj(self, batch, do_log=False):
        """Train the model by unrolling it along a recorded trajectory; one optimizer step.

        ``batch`` holds step-indexed tensors. The assertions require
        ``value_targets`` to be 2-D, and ``value_targets``, ``dones`` and
        ``rewards`` to share that shape, which must also equal
        ``logits.shape[:-1]`` and ``actions.shape[:-1]``. The first axis is
        indexed per unroll step ``i``.

        The model encodes ``batch.states[0]`` and is then stepped with the
        recorded actions (``self._train_model(hidden, input_actions)``). Each
        step adds to the loss:
          * the negative log-prob of the masked value target under
            ``state_value_bins``;
          * ``5 * self.pd.kl(logits_target, predicted logits)``, masked;
          * ``0.01 * -entropy`` of the predicted logits, masked;
          * once ``self.frame_train > 10000``, the negative log-prob of the
            recorded action, masked;
          * the negative log-prob of the masked reward under the next step's
            ``reward_bins``.
        ``cum_nonterm`` starts at 1 and is multiplied by ``1 - dones[i]`` after
        each step. Value/reward targets are therefore zeroed after the first
        done, and ``logits_target`` stays frozen at its last pre-terminal value
        (``lerp`` with weight 0). The summed loss is divided by the number of
        steps.

        NOTE(review): with ``batch.states.shape[0] == 0`` the loop never runs and
        ``loss / 0`` raises ZeroDivisionError on the int accumulator.

        Args:
            batch: AttrDict with ``states``, ``value_targets``, ``logits``,
                ``dones``, ``rewards`` and ``actions`` (see shape constraints above).
            do_log: when true, log RMSEs of the last step's reward, state-value
                and logits predictions against their targets.

        Returns:
            The step-averaged loss tensor (already backpropagated and applied).
        """
        assert batch.value_targets.ndim == 2, batch.value_targets.shape
        assert batch.value_targets.shape == batch.logits.shape[:-1] == \
               batch.dones.shape == batch.rewards.shape == batch.actions.shape[:-1]

        with torch.enable_grad():
            ac_out = AttrDict(self._train_model.encode_observation(batch.states[0]))

            nonterminals = 1 - batch.dones
            vtarg = batch.value_targets[0]
            logits_target = batch.logits[0]
            ac_out_prev = ac_out
            # 1 while no done has been seen along the unroll, 0 afterwards.
            cum_nonterm = torch.ones_like(batch.dones[0])
            loss = 0
            for i in range(batch.states.shape[0]):
                # Value target zeroed after termination.
                vtarg = (batch.value_targets[i] * cum_nonterm).detach()
                # vtarg = torch.lerp(vtarg, batch.value_targets[i], cum_nonterm).detach()
                # Policy target advances while alive and stays at its last value after termination.
                logits_target = torch.lerp(logits_target, batch.logits[i], cum_nonterm.unsqueeze(-1)).detach()
                loss += -self._train_model.value_encoder.logp(ac_out.state_value_bins, vtarg).mean()
                # loss += 0.5 * (logits_target - ac_out.logits).pow(2).mean(-1).mul(cum_nonterm).mean()
                loss += 5 * self.pd.kl(logits_target, ac_out.logits).sum(-1).mul(cum_nonterm).mean()
                # Subtracting entropy rewards higher-entropy predicted policies.
                loss += 0.01 * -self.pd.entropy(ac_out.logits).sum(-1).mul(cum_nonterm).mean()
                # Behaviour-cloning term on recorded actions, only after the first 10000 frames.
                if self.frame_train > 10000:
                    loss += -self.pd.logp(batch.actions[i], ac_out.logits).sum(-1).mul(cum_nonterm).mean()

                # Advance the model one step with the recorded action; the reward
                # prediction of the new output is matched to rewards[i].
                input_actions = self.pd.to_inputs(batch.actions[i])
                ac_out_prev = ac_out
                ac_out = AttrDict(self._train_model(ac_out.hidden, input_actions))
                rtarg = (batch.rewards[i] * cum_nonterm).detach()
                loss += -self._train_model.reward_encoder.logp(ac_out.reward_bins, rtarg).mean()
                cum_nonterm = cum_nonterm * nonterminals[i]

            loss = loss / batch.states.shape[0]

        # Diagnostics use the tensors from the final unroll step only.
        if do_log:
            rewards = self._train_model.reward_encoder(ac_out.reward_bins)
            self.logger.add_scalar('reward rmse', (rewards - rtarg).pow(2).mean().sqrt(), self.frame_train)
            state_values = self._train_model.value_encoder(ac_out_prev.state_value_bins)
            self.logger.add_scalar('state_values rmse', (state_values - vtarg).pow(2).mean().sqrt(), self.frame_train)
            self.logger.add_scalar('logits rmse', (ac_out_prev.logits - logits_target).pow(2).mean().sqrt(), self.frame_train)

        loss.backward()
        # clip_grad_norm_(self._train_model.parameters(), 4)
        self._optimizer.step()
        self._optimizer.zero_grad()

        return loss
Exemplo n.º 7
0
    def _muzero_update(self, data: AttrDict):
        """Cut ``data`` into overlapping rollout windows and run one ``_update_traj`` step.

        Value targets come from ``calc_value_targets`` applied to the rewards,
        the recorded state values extended by one step with
        ``self._prev_data['state_values']``, the dones, ``self.reward_discount``
        and the constant 0.99.

        Every tensor of shape ``(H, B, *)`` is then replaced by all overlapping
        windows of length ``self.rollout_depth`` along the first axis. The result
        has shape ``(rollout_depth, B * N, *)`` with ``N = H - rollout_depth + 1``,
        and ``out[d, b * N + n] == t[n + d, b]``.

        Args:
            data: AttrDict of ``(H, B, *)`` tensors, including ``state_values``,
                ``rewards`` and ``dones``. Mutated in place (``value_targets``
                added, all entries windowed).
        """
        state_values_p1 = torch.cat([data.state_values, self._prev_data['state_values'].cpu().unsqueeze(0)], 0)
        data.value_targets = calc_value_targets(data.rewards, state_values_p1, data.dones, self.reward_discount, 0.99)

        depth = self.rollout_depth
        for name, t in list(data.items()):
            B = t.shape[1]
            # (N, B, *, depth): window n holds steps n .. n + depth - 1.
            # Tensor.unfold is a dtype-agnostic view. F.unfold needed a float32
            # round-trip, which is lossy for float64 and large int64 tensors.
            windows = t.unfold(0, depth, 1)
            N = windows.shape[0]
            # (N, B, *, depth) -> (depth, B, N, *) -> (depth, B * N, *)
            order = (t.dim(), 1, 0, *range(2, t.dim()))
            data[name] = windows.permute(*order).reshape(depth, B * N, *t.shape[2:])

        old_model = {k: v.clone() for k, v in self._train_model.state_dict().items()}

        loss = self._update_traj(AttrDict({k: v.to(self.device_train, non_blocking=True) for k, v in data.items()}), self._do_log)

        if self._do_log:
            self.logger.add_scalar('learning rate', self._optimizer.param_groups[0]['lr'], self.frame_train)
            self.logger.add_scalar('total loss', loss, self.frame_train)
            self.logger.add_scalar('model abs diff', model_diff(old_model, self._train_model), self.frame_train)
            self.logger.add_scalar('model max diff', model_diff(old_model, self._train_model, True), self.frame_train)
Exemplo n.º 8
0
    def _step(self, rewards: torch.Tensor, dones: torch.Tensor, states: torch.Tensor) -> torch.Tensor:
        """Sample actions for ``states``; when training is enabled, record experience.

        Everything runs under ``torch.no_grad()``. Unless ``self.disable_training``
        is set, this method:
          * pushes the previous step's data plus the new ``rewards``/``dones``
            into the replay buffer;
          * accumulates per-actor episode returns, and on each done appends the
            finished return to ``self._models_fitness[self._model_index]``;
            episodes whose completion counter is still negative (it is reset to
            -1 after each ES round) are not recorded;
          * runs ``self._es_train()`` once enough episodes and replay frames are
            available, or when there are no models yet;
          * stores this step's model outputs, states and actions as ``_prev_data``.

        Returns:
            The sampled actions, on CPU.
        """
        with torch.no_grad():
            ac_out = self._eval_model(states.to(self.device_eval))
            ac_out.state_values = ac_out.state_values.squeeze(-1)
            sample_actions = self._eval_model.heads.logits.pd.sample
            actions = sample_actions(ac_out.logits).cpu()

            assert not torch.isnan(actions.sum())

            self._eval_steps += 1

            if self.disable_training:
                return actions

            has_transition = self._prev_data is not None and rewards is not None
            if has_transition:
                self._replay_buffer.push(rewards=rewards, dones=dones, **self._prev_data)

                fitness = self._models_fitness
                for idx, done in enumerate(dones):
                    self._episode_returns[idx] += rewards[idx]
                    if not done > 0:
                        continue
                    if len(fitness) <= self._model_index:
                        fitness.append([])
                    if self._num_completed_episodes[idx].item() >= 0:
                        fitness[self._model_index].append(self._episode_returns[idx].item())
                    self._episode_returns[idx] = 0

            completed = self._num_completed_episodes
            enough_episodes = completed.sum().item() >= self.episodes_per_model and \
                completed.min().item() > 0
            enough_frames = len(self._replay_buffer) >= self.min_replay_size
            no_models = len(self._models) == 0
            should_train = enough_frames and (enough_episodes or no_models)
            if should_train:
                self._eval_steps = 0
                self._es_train()
                self._num_completed_episodes.fill_(-1)

            if self._terminal is not None:
                self._num_completed_episodes += dones.long()
            self._prev_data = AttrDict(**ac_out, states=states, actions=actions)

            return actions
Exemplo n.º 9
0
 def _create_data(self):
     """Return the newest ``self.horizon`` replay-buffer samples wrapped in an AttrDict."""
     # Layout of each entry: (steps, actors, *)
     recent = self._replay_buffer.get_last_samples(self.horizon)
     return AttrDict(recent)