Example 1
    def _update_vf(self, dataset):
        """Update the value function using a given dataset.

        The value function is updated via SGD to minimize TD(lambda) errors.
        """

        assert "state" in dataset[0]
        assert "v_teacher" in dataset[0]

        for batch in _yield_minibatches(
            dataset, minibatch_size=self.vf_batch_size, num_epochs=self.vf_epochs
        ):
            states = batch_states([b["state"] for b in batch], self.device, self.phi)
            if self.obs_normalizer:
                states = self.obs_normalizer(states, update=False)
            vs_teacher = torch.as_tensor(
                [b["v_teacher"] for b in batch],
                device=self.device,
                dtype=torch.float,
            )
            vs_pred = self.vf(states)
            vf_loss = F.mse_loss(vs_pred, vs_teacher[..., None])
            self.vf.zero_grad()
            vf_loss.backward()
            if self.max_grad_norm is not None:
                clip_l2_grad_norm_(self.vf.parameters(), self.max_grad_norm)
            self.vf_optimizer.step()
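
Every example on this page follows the same update skeleton: compute a scalar loss, zero the gradients, backpropagate, optionally clip the global L2 norm of the gradients, then step the optimizer. Below is a minimal, self-contained sketch of that pattern; the model, optimizer, and max_grad_norm values are stand-ins for the agent attributes used above, and torch.nn.utils.clip_grad_norm_ stands in for pfrl's clip_l2_grad_norm_ (Example 8 below checks that the two produce the same clipped gradients).

import torch
import torch.nn as nn
import torch.nn.functional as F

# Stand-ins for the agent attributes used above (self.vf, self.vf_optimizer,
# self.max_grad_norm); the values here are arbitrary.
model = nn.Linear(4, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
max_grad_norm = 0.5

x = torch.randn(8, 4)
target = torch.randn(8, 1)

loss = F.mse_loss(model(x), target)
optimizer.zero_grad()
loss.backward()
if max_grad_norm is not None:
    # clip_l2_grad_norm_ would be called here in the pfrl code
    nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
optimizer.step()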
Example 2
    def update_policy_with_goal(self, batch):
        """Compute loss for actor."""

        batch_state = batch["state"]
        batch_goal = batch["goal"]
        action_distrib = self.policy(torch.cat([batch_state, batch_goal], -1))
        onpolicy_actions = action_distrib.rsample()
        entropy_term = 0
        if self.add_entropy:
            log_prob = action_distrib.log_prob(onpolicy_actions)
            entropy_term = self.temperature * log_prob[..., None]
        q = self.q_func1((torch.cat([batch_state, batch_goal],
                                    -1), onpolicy_actions))

        # We want to maximize Q (minus the entropy term),
        # so the loss is its negation
        loss = -torch.mean(-entropy_term + q)

        self.policy_loss_record.append(float(loss))
        self.policy_optimizer.zero_grad()
        loss.backward()

        # Record policy-gradient statistics (computed before clipping)
        gradients = self.get_and_flatten_policy_gradients()
        gradient_variance = torch.var(gradients)
        gradient_mean = torch.mean(gradients)
        self.policy_gradients_variance_record.append(float(gradient_variance))
        self.policy_gradients_mean_record.append(float(gradient_mean))

        if self.max_grad_norm is not None:
            clip_l2_grad_norm_(self.policy.parameters(), self.max_grad_norm)
        self.policy_optimizer.step()
        self.policy_n_updates += 1
Example 3
    def _update_vf_once_recurrent(self, episodes):

        # Sort episodes by decreasing length, as pack_sequence expects
        episodes = sorted(episodes, key=len, reverse=True)

        flat_transitions = flatten_sequences_time_first(episodes)

        # Prepare data for a recurrent model
        seqs_states = []
        for ep in episodes:
            states = self.batch_states(
                [transition["state"] for transition in ep], self.device,
                self.phi)
            if self.obs_normalizer:
                states = self.obs_normalizer(states, update=False)
            seqs_states.append(states)

        flat_vs_teacher = torch.as_tensor(
            [[transition["v_teacher"]] for transition in flat_transitions],
            device=self.device,
            dtype=torch.float,
        )

        with torch.no_grad():
            vf_rs = concatenate_recurrent_states(
                _collect_first_recurrent_states_of_vf(episodes))

        flat_vs_pred, _ = pack_and_forward(self.vf, seqs_states, vf_rs)

        vf_loss = F.mse_loss(flat_vs_pred, flat_vs_teacher)
        self.vf.zero_grad()
        vf_loss.backward()
        if self.max_grad_norm is not None:
            clip_l2_grad_norm_(self.vf.parameters(), self.max_grad_norm)
        self.vf_optimizer.step()
Example 4
 def update_temperature(self, log_prob):
     assert not log_prob.requires_grad
     loss = -torch.mean(self.temperature_holder() * (log_prob + self.entropy_target))
     self.temperature_optimizer.zero_grad()
     loss.backward()
     if self.max_grad_norm is not None:
         clip_l2_grad_norm_(self.temperature_holder.parameters(), self.max_grad_norm)
     self.temperature_optimizer.step()
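
The temperature loss above is the standard automatic entropy-tuning objective: it averages temperature * (log_prob + entropy_target) with a minus sign, so minimizing it raises the temperature when the policy's entropy estimate (-log_prob) drops below the target and lowers it otherwise. The following is a hedged sketch of a temperature holder under the common log-parameterization; pfrl's own TemperatureHolder may differ in detail.

import torch
import torch.nn as nn

class LogTemperatureHolder(nn.Module):
    # Illustrative stand-in for self.temperature_holder above: keeps
    # log(alpha) as the learnable parameter so that alpha stays positive.
    def __init__(self, initial_log_temperature=0.0):
        super().__init__()
        self.log_temperature = nn.Parameter(
            torch.tensor(initial_log_temperature, dtype=torch.float32))

    def forward(self):
        # Return alpha = exp(log alpha) as a 0-dim tensor.
        return torch.exp(self.log_temperature)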
Example 5
 def batch_update(self):
     assert len(self.reward_sequences) == self.batchsize
     assert len(self.log_prob_sequences) == self.batchsize
     assert len(self.entropy_sequences) == self.batchsize
     # Update the model
     assert self.n_backward == 0
     self.accumulate_grad()
     if self.max_grad_norm is not None:
         clip_l2_grad_norm_(self.model.parameters(), self.max_grad_norm)
     self.optimizer.step()
     self.n_backward = 0
Example 6
    def update_q_func(self, batch):
        """Compute loss for a given Q-function."""

        batch_next_state = batch["next_state"].float()
        batch_rewards = batch["reward"].float()
        batch_terminal = batch["is_state_terminal"].float()
        batch_state = batch["state"].float()
        batch_actions = batch["action"].float()
        batch_discount = batch["discount"].float()

        with torch.no_grad(), pfrl.utils.evaluating(
                self.policy), pfrl.utils.evaluating(
                    self.target_q_func1), pfrl.utils.evaluating(
                        self.target_q_func2):
            next_action_distrib = self.policy(batch_next_state)

            next_actions_normalized = next_action_distrib.sample()
            next_actions = self.scale * next_actions_normalized

            next_log_prob = next_action_distrib.log_prob(
                next_actions_normalized)
            next_q1 = self.target_q_func1((batch_next_state, next_actions))
            next_q2 = self.target_q_func2((batch_next_state, next_actions))
            next_q = torch.min(next_q1, next_q2)
            entropy_term = self.temperature * next_log_prob[..., None]
            assert next_q.shape == entropy_term.shape

            target_q = batch_rewards + batch_discount * (
                1.0 - batch_terminal) * torch.flatten(next_q - entropy_term)

        predict_q1 = torch.flatten(self.q_func1((batch_state, batch_actions)))
        predict_q2 = torch.flatten(self.q_func2((batch_state, batch_actions)))

        loss1 = 0.5 * F.mse_loss(target_q, predict_q1)
        loss2 = 0.5 * F.mse_loss(target_q, predict_q2)

        # Update stats
        self.q1_record.extend(predict_q1.detach().cpu().numpy())
        self.q2_record.extend(predict_q2.detach().cpu().numpy())
        self.q_func1_loss_record.append(float(loss1))
        self.q_func2_loss_record.append(float(loss2))

        self.q_func1_optimizer.zero_grad()
        loss1.backward()
        if self.max_grad_norm is not None:
            clip_l2_grad_norm_(self.q_func1.parameters(), self.max_grad_norm)
        self.q_func1_optimizer.step()

        self.q_func2_optimizer.zero_grad()
        loss2.backward()
        if self.max_grad_norm is not None:
            clip_l2_grad_norm_(self.q_func2.parameters(), self.max_grad_norm)
        self.q_func2_optimizer.step()
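
Written out, the target computed inside the no_grad block above is the entropy-regularized clipped double-Q backup, with alpha the temperature, d the terminal flag, and the next action sampled from the current policy:

    y = r + \gamma (1 - d) \left[ \min\!\big(Q'_1(s', a'),\, Q'_2(s', a')\big) - \alpha \log \pi(a' \mid s') \right],
    \quad a' \sim \pi(\cdot \mid s')

which is what target_q stores before the two MSE losses are formed.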
Example 7
    def update_policy_with_goal(self, batch):
        """Compute loss for actor."""

        batch_state = batch["state"]
        batch_goal = batch["goal"]
        action_distrib = self.policy(torch.cat([batch_state, batch_goal], -1))
        onpolicy_actions_normalized = action_distrib.rsample()
        onpolicy_actions = self.scale * onpolicy_actions_normalized
        entropy_term = 0
        if self.add_entropy:
            log_prob = action_distrib.log_prob(onpolicy_actions_normalized)
            entropy_term = self.temperature * log_prob[..., None]

            if self.entropy_target is not None:
                self.update_temperature(log_prob.detach())

            self.entropy_record.append(float(torch.mean(-entropy_term)))
            self.temperature_record.append(self.temperature)

        q = self.q_func1((torch.cat([batch_state, batch_goal],
                                    -1), onpolicy_actions))

        # We want to maximize Q (minus the entropy term),
        # so the loss is its negation
        loss = -torch.mean(-entropy_term + q)

        self.policy_loss_record.append(float(loss))
        self.policy_optimizer.zero_grad()
        loss.backward()

        # Policy-gradient statistics are disabled here; zeros are recorded as
        # placeholders to keep the records aligned with the update count.
        # gradients = self.get_and_flatten_policy_gradients()
        # gradient_variance = torch.var(gradients)
        # gradient_mean = torch.mean(gradients)
        gradient_variance = 0
        gradient_mean = 0
        self.policy_gradients_variance_record.append(float(gradient_variance))
        self.policy_gradients_mean_record.append(float(gradient_mean))

        if self.max_grad_norm is not None:
            clip_l2_grad_norm_(self.policy.parameters(), self.max_grad_norm)
        self.policy_optimizer.step()
        self.policy_n_updates += 1

        self.kl_divergence = self.compute_kl(self.policy, self.target_policy,
                                             batch_state, batch_goal)

        self.one_step_kl_divergence = self.compute_kl(self.policy,
                                                      self.prior_policy,
                                                      batch_state, batch_goal)

        self.prior_policy = copy.deepcopy(self.policy)
Example 8
def _test_clip_l2_grad_norm_(gpu):
    if gpu >= 0:
        device = torch.device("cuda:{}".format(gpu))
    else:
        device = torch.device("cpu")
    model = nn.Sequential(
        nn.Linear(2, 10),
        nn.ReLU(),
        nn.Linear(10, 3),
    ).to(device)
    x = torch.rand(7, 2).to(device)

    def backward():
        model.zero_grad()
        loss = model(x).mean()
        loss.backward()

    backward()
    raw_grads = _get_grad_vector(model)

    # Threshold large enough not to affect grads
    th = 10000
    backward()
    nn.utils.clip_grad_norm_(model.parameters(), th)
    clipped_grads = _get_grad_vector(model)

    backward()
    clip_l2_grad_norm_(model.parameters(), th)
    our_clipped_grads = _get_grad_vector(model)

    np.testing.assert_allclose(raw_grads, clipped_grads)
    np.testing.assert_allclose(raw_grads, our_clipped_grads)

    # Threshold small enough to affect grads
    th = 1e-2
    backward()
    nn.utils.clip_grad_norm_(model.parameters(), th)
    clipped_grads = _get_grad_vector(model)

    backward()
    clip_l2_grad_norm_(model.parameters(), th)
    our_clipped_grads = _get_grad_vector(model)

    with pytest.raises(AssertionError):
        np.testing.assert_allclose(raw_grads, clipped_grads, rtol=1e-5)

    with pytest.raises(AssertionError):
        np.testing.assert_allclose(raw_grads, our_clipped_grads, rtol=1e-5)

    np.testing.assert_allclose(clipped_grads, our_clipped_grads, rtol=1e-5)
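The test above pins down the expected behaviour of clip_l2_grad_norm_: it must leave gradients untouched when their global L2 norm is below the threshold and rescale them exactly like torch.nn.utils.clip_grad_norm_ when it is not. One implementation consistent with those assertions is sketched below; it is illustrative only, not necessarily pfrl's actual code.

import torch

def clip_l2_grad_norm_sketch(parameters, max_norm):
    # Scale gradients in place so that their global L2 norm is at most
    # max_norm, mirroring torch.nn.utils.clip_grad_norm_.
    grads = [p.grad for p in parameters if p.grad is not None]
    if not grads:
        return torch.tensor(0.0)
    total_norm = torch.norm(
        torch.stack([torch.norm(g.detach()) for g in grads]))
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1:
        for g in grads:
            g.detach().mul_(clip_coef)
    return total_norm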
Example 9
    def update_q_func(self, batch):
        """Compute loss for a given Q-function."""

        batch_next_state = batch["next_state"]
        batch_rewards = batch["reward"]
        batch_terminal = batch["is_state_terminal"]
        batch_state = batch["state"]
        batch_actions = batch["action"]
        batch_discount = batch["discount"]

        with torch.no_grad(), pfrl.utils.evaluating(
                self.target_policy), pfrl.utils.evaluating(
                    self.target_q_func1), pfrl.utils.evaluating(
                        self.target_q_func2):
            next_actions = self.target_policy_smoothing_func(
                self.target_policy(batch_next_state).sample())
            next_q1 = self.target_q_func1((batch_next_state, next_actions))
            next_q2 = self.target_q_func2((batch_next_state, next_actions))
            next_q = torch.min(next_q1, next_q2)

            target_q = batch_rewards + batch_discount * (
                1.0 - batch_terminal) * torch.flatten(next_q)

        predict_q1 = torch.flatten(self.q_func1((batch_state, batch_actions)))
        predict_q2 = torch.flatten(self.q_func2((batch_state, batch_actions)))

        loss1 = F.mse_loss(target_q, predict_q1)
        loss2 = F.mse_loss(target_q, predict_q2)

        # Update stats
        self.q1_record.extend(predict_q1.detach().cpu().numpy())
        self.q2_record.extend(predict_q2.detach().cpu().numpy())
        self.q_func1_loss_record.append(float(loss1))
        self.q_func2_loss_record.append(float(loss2))

        self.q_func1_optimizer.zero_grad()
        loss1.backward()
        if self.max_grad_norm is not None:
            clip_l2_grad_norm_(self.q_func1.parameters(), self.max_grad_norm)
        self.q_func1_optimizer.step()

        self.q_func2_optimizer.zero_grad()
        loss2.backward()
        if self.max_grad_norm is not None:
            clip_l2_grad_norm_(self.q_func2.parameters(), self.max_grad_norm)
        self.q_func2_optimizer.step()

        self.q_func_n_updates += 1
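
The target_policy_smoothing_func used above is not shown; in TD3 it typically adds clipped Gaussian noise to the target action and clips the result back into the action range. The sketch below assumes that behaviour, with the noise scale of 0.2 and clip of 0.5 taken from the TD3 paper; the function actually passed to the agent may use different settings.

import torch

def target_policy_smoothing_sketch(actions, sigma=0.2, noise_clip=0.5,
                                   low=-1.0, high=1.0):
    # Add clipped Gaussian noise to target actions (TD3-style smoothing),
    # then clip back into the assumed action range [low, high].
    noise = torch.clamp(sigma * torch.randn_like(actions),
                        -noise_clip, noise_clip)
    return torch.clamp(actions + noise, low, high)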
Example 10
    def update(self):
        with torch.no_grad():
            _, next_value = self.model(self.states[-1])
            next_value = next_value[:, 0]

        self._compute_returns(next_value)
        pout, values = self.model(self.states[:-1].reshape(-1, *self.obs_shape))

        actions = self.actions.reshape(-1, *self.action_shape)
        dist_entropy = pout.entropy().mean()
        action_log_probs = pout.log_prob(actions)

        values = values.reshape((self.update_steps, self.num_processes))
        action_log_probs = action_log_probs.reshape(
            (self.update_steps, self.num_processes)
        )
        advantages = self.returns[:-1] - values
        value_loss = (advantages * advantages).mean()
        action_loss = -(advantages.detach() * action_log_probs).mean()

        self.optimizer.zero_grad()

        (
            value_loss * self.v_loss_coef
            + action_loss * self.pi_loss_coef
            - dist_entropy * self.entropy_coeff
        ).backward()

        if self.max_grad_norm is not None:
            clip_l2_grad_norm_(self.model.parameters(), self.max_grad_norm)
        self.optimizer.step()
        self.states[0] = self.states[-1]

        self.t_start = self.t

        # Update stats
        self.average_actor_loss += (1 - self.average_actor_loss_decay) * (
            float(action_loss) - self.average_actor_loss
        )
        self.average_value += (1 - self.average_value_decay) * (
            float(value_loss) - self.average_value
        )
        self.average_entropy += (1 - self.average_entropy_decay) * (
            float(dist_entropy) - self.average_entropy
        )
Example 11
    def update_policy(self, batch):
        """Compute loss for actor."""

        batch_state = batch["state"]

        onpolicy_actions = self.policy(batch_state).rsample()
        q = self.q_func1((batch_state, onpolicy_actions))

        # Since we want to maximize Q, loss is negation of Q
        loss = -torch.mean(q)

        self.policy_loss_record.append(float(loss))
        self.policy_optimizer.zero_grad()
        loss.backward()
        if self.max_grad_norm is not None:
            clip_l2_grad_norm_(self.policy.parameters(), self.max_grad_norm)
        self.policy_optimizer.step()
        self.policy_n_updates += 1
Example 12
    def update(self):
        self._compute_returns()

        pout, values = self.model(self.states[:self.update_steps].reshape(
            -1, *self.obs_shape))

        actions = self.actions[:self.update_steps].reshape(
            -1, *self.action_shape)
        dist_entropy = pout.entropy().mean()
        action_log_probs = pout.log_prob(actions)

        values = values.reshape((self.update_steps, self.num_processes))
        action_log_probs = action_log_probs.reshape(
            (self.update_steps, self.num_processes))
        advantages = self.returns - values
        value_loss = (advantages * advantages).mean()
        action_loss = -(advantages.detach() * action_log_probs).mean()

        self.optimizer.zero_grad()

        (value_loss * self.v_loss_coef + action_loss * self.pi_loss_coef -
         dist_entropy * self.entropy_coeff).backward()

        if self.max_grad_norm is not None:
            clip_l2_grad_norm_(self.model.parameters(), self.max_grad_norm)
        self.optimizer.step()

        # NOTE: Update time-step
        self.t_start += self.update_steps
        # sliding window
        self.states[:self.update_steps] = self.states[self.update_steps:-1]
        self.actions[:self.update_steps] = self.actions[self.update_steps:]
        self.rewards[:self.update_steps] = self.rewards[self.update_steps:]
        self.value_preds[:self.update_steps] = self.value_preds[
            self.update_steps:-1]

        # Update stats
        self.average_actor_loss += (1 - self.average_actor_loss_decay) * (
            float(action_loss) - self.average_actor_loss)
        self.average_value += (1 - self.average_value_decay) * (
            float(value_loss) - self.average_value)
        self.average_entropy += (1 - self.average_entropy_decay) * (
            float(dist_entropy) - self.average_entropy)
Example 13
    def update(
        self,
        t_start,
        t_stop,
        R,
        actions,
        rewards,
        values,
        action_values,
        action_distribs,
        action_distribs_mu,
        avg_action_distribs,
    ):

        assert np.isscalar(R)
        self.assert_shared_memory()

        total_loss = self.compute_loss(
            t_start=t_start,
            t_stop=t_stop,
            R=R,
            actions=actions,
            rewards=rewards,
            values=values,
            action_values=action_values,
            action_distribs=action_distribs,
            action_distribs_mu=action_distribs_mu,
            avg_action_distribs=avg_action_distribs,
        )

        # Compute gradients using thread-specific model
        self.model.zero_grad()
        total_loss.squeeze().backward()
        if self.max_grad_norm is not None:
            clip_l2_grad_norm_(self.model.parameters(), self.max_grad_norm)
        # Copy the gradients to the globally shared model
        copy_param.copy_grad(target_link=self.shared_model,
                             source_link=self.model)
        self.optimizer.step()

        self.sync_parameters()
Example 14
    def update_policy_and_temperature(self, batch):
        """Compute loss for actor."""

        batch_state = batch["state"].float()

        action_distrib = self.policy(batch_state)

        actions_normalized = action_distrib.rsample()
        actions = self.scale * actions_normalized
        log_prob = action_distrib.log_prob(actions_normalized)
        q1 = self.q_func1((batch_state, actions))
        q2 = self.q_func2((batch_state, actions))
        q = torch.min(q1, q2)

        entropy_term = self.temperature * log_prob[..., None]
        assert q.shape == entropy_term.shape
        loss = torch.mean(entropy_term - q)

        self.policy_optimizer.zero_grad()
        loss.backward()
        if self.max_grad_norm is not None:
            clip_l2_grad_norm_(self.policy.parameters(), self.max_grad_norm)
        self.policy_optimizer.step()

        self.n_policy_updates += 1

        if self.entropy_target is not None:
            self.update_temperature(log_prob.detach())

        # Record entropy
        with torch.no_grad():
            try:
                self.entropy_record.extend(
                    action_distrib.entropy().detach().cpu().numpy())
            except NotImplementedError:
                # Record -log p(x) instead
                self.entropy_record.extend(-log_prob.detach().cpu().numpy())
Example 15
    def update_q_func_with_goal(self, batch):
        """
        Compute loss for a given Q-function, or critics
        """

        batch_next_state = batch["next_state"]
        batch_next_goal = batch["next_goal"]
        batch_rewards = batch["reward"]
        batch_terminal = batch["is_state_terminal"]
        batch_state = batch["state"]
        batch_goal = batch["goal"]
        batch_actions = batch["action"]
        batch_discount = batch["discount"]

        with torch.no_grad(), pfrl.utils.evaluating(
                self.target_policy), pfrl.utils.evaluating(
                    self.target_q_func1), pfrl.utils.evaluating(
                        self.target_q_func2):
            next_action_distrib = self.target_policy(
                torch.cat([batch_next_state, batch_next_goal], -1))
            next_actions = self.target_policy_smoothing_func(
                next_action_distrib.sample())

            entropy_term = 0
            if self.add_entropy:
                next_log_prob = next_action_distrib.log_prob(next_actions)
                entropy_term = self.temperature * next_log_prob[..., None]

            next_q1 = self.target_q_func1(
                (torch.cat([batch_next_state, batch_next_goal],
                           -1), next_actions))
            next_q2 = self.target_q_func2(
                (torch.cat([batch_next_state, batch_next_goal],
                           -1), next_actions))
            next_q = torch.min(next_q1, next_q2)

            target_q = batch_rewards + batch_discount * (
                1.0 - batch_terminal) * torch.flatten(next_q - entropy_term)

        predict_q1 = torch.flatten(
            self.q_func1((torch.cat([batch_state, batch_goal],
                                    -1), batch_actions)))
        predict_q2 = torch.flatten(
            self.q_func2((torch.cat([batch_state, batch_goal],
                                    -1), batch_actions)))

        loss1 = F.smooth_l1_loss(target_q, predict_q1)
        loss2 = F.smooth_l1_loss(target_q, predict_q2)

        # Update stats
        self.q1_record.extend(predict_q1.detach().cpu().numpy())
        self.q2_record.extend(predict_q2.detach().cpu().numpy())
        self.q_func1_loss_record.append(float(loss1))
        self.q_func2_loss_record.append(float(loss2))

        q1_recent_variance = np.var(
            list(self.q1_record)[-self.recent_variance_size:])
        q2_recent_variance = np.var(
            list(self.q2_record)[-self.recent_variance_size:])
        self.q_func1_variance_record.append(q1_recent_variance)
        self.q_func2_variance_record.append(q2_recent_variance)

        self.q_func1_optimizer.zero_grad()
        loss1.backward()
        if self.max_grad_norm is not None:
            clip_l2_grad_norm_(self.q_func1.parameters(), self.max_grad_norm)
        self.q_func1_optimizer.step()

        self.q_func2_optimizer.zero_grad()
        loss2.backward()
        if self.max_grad_norm is not None:
            clip_l2_grad_norm_(self.q_func2.parameters(), self.max_grad_norm)
        self.q_func2_optimizer.step()

        self.q_func_n_updates += 1
Example 16
 def update_with_accumulated_grad(self):
     assert self.n_backward == self.batchsize
     if self.max_grad_norm is not None:
         clip_l2_grad_norm_(self.model.parameters(), self.max_grad_norm)
     self.optimizer.step()
     self.n_backward = 0
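
Examples 5 and 16 assume the gradients have already been accumulated by repeated backward() calls before a single clipped optimizer step. A minimal sketch of that accumulate-then-step pattern in plain PyTorch follows, again with torch.nn.utils.clip_grad_norm_ standing in for clip_l2_grad_norm_ and with arbitrary stand-in values for the agent attributes.

import torch
import torch.nn as nn
import torch.nn.functional as F

model = nn.Linear(3, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
batchsize = 4        # assumed counterpart of self.batchsize
max_grad_norm = 1.0

optimizer.zero_grad()
for _ in range(batchsize):
    x = torch.randn(1, 3)
    # backward() without an intervening zero_grad() accumulates gradients
    F.mse_loss(model(x), torch.zeros(1, 1)).backward()

if max_grad_norm is not None:
    nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
optimizer.step()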
Example 17
File: a3c.py Project: xylee95/pfrl
    def update(self, statevar):
        assert self.t_start < self.t

        n = self.t - self.t_start

        self.assert_shared_memory()

        if statevar is None:
            R = 0
        else:
            with torch.no_grad(), pfrl.utils.evaluating(self.model):
                if self.recurrent:
                    (_,
                     vout), _ = one_step_forward(self.model, statevar,
                                                 self.train_recurrent_states)
                else:
                    _, vout = self.model(statevar)
            R = float(vout)

        pi_loss_factor = self.pi_loss_coef
        v_loss_factor = self.v_loss_coef

        # Normalize the loss of sequences truncated by terminal states
        if self.keep_loss_scale_same and self.t - self.t_start < self.t_max:
            factor = self.t_max / (self.t - self.t_start)
            pi_loss_factor *= factor
            v_loss_factor *= factor

        if self.normalize_grad_by_t_max:
            pi_loss_factor /= self.t - self.t_start
            v_loss_factor /= self.t - self.t_start

        # Batch re-compute for efficient backprop
        batch_obs = self.batch_states(
            [self.past_obs[i] for i in range(self.t_start, self.t)],
            self.device,
            self.phi,
        )
        if self.recurrent:
            (batch_distrib, batch_v), _ = pack_and_forward(
                self.model,
                [batch_obs],
                self.past_recurrent_state[self.t_start],
            )
        else:
            batch_distrib, batch_v = self.model(batch_obs)
        batch_action = torch.stack(
            [self.past_action[i] for i in range(self.t_start, self.t)])
        batch_log_prob = batch_distrib.log_prob(batch_action)
        batch_entropy = batch_distrib.entropy()
        rev_returns = []
        for i in reversed(range(self.t_start, self.t)):
            R *= self.gamma
            R += self.past_rewards[i]
            rev_returns.append(R)
        batch_return = torch.as_tensor(list(reversed(rev_returns)),
                                       dtype=torch.float)
        batch_adv = batch_return - batch_v.detach().squeeze(-1)
        assert batch_log_prob.shape == (n, )
        assert batch_adv.shape == (n, )
        assert batch_entropy.shape == (n, )
        pi_loss = torch.sum(-batch_adv * batch_log_prob -
                            self.beta * batch_entropy,
                            dim=0)
        assert batch_v.shape == (n, 1)
        assert batch_return.shape == (n, )
        v_loss = F.mse_loss(batch_v, batch_return[..., None],
                            reduction="sum") / 2

        if pi_loss_factor != 1.0:
            pi_loss *= pi_loss_factor

        if v_loss_factor != 1.0:
            v_loss *= v_loss_factor

        if self.process_idx == 0:
            logger.debug("pi_loss:%s v_loss:%s", pi_loss, v_loss)

        total_loss = torch.squeeze(pi_loss) + torch.squeeze(v_loss)

        # Compute gradients using thread-specific model
        self.model.zero_grad()
        total_loss.backward()
        if self.max_grad_norm is not None:
            clip_l2_grad_norm_(self.model.parameters(), self.max_grad_norm)
        # Copy the gradients to the globally shared model
        copy_param.copy_grad(target_link=self.shared_model,
                             source_link=self.model)
        # Update the globally shared model
        self.optimizer.step()
        if self.process_idx == 0:
            logger.debug("update")

        self.sync_parameters()

        self.past_obs = {}
        self.past_action = {}
        self.past_rewards = {}
        self.past_recurrent_state = {}

        self.t_start = self.t
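
The reversed loop above implements the usual bootstrapped discounted-return recursion R_t = r_t + gamma * R_{t+1}, starting from the critic's value estimate for the last state. A tiny numeric check of that recursion, with arbitrarily chosen values:

# R_t = r_t + gamma * R_{t+1}, bootstrapped from the critic's estimate.
gamma = 0.9
rewards = {0: 1.0, 1: 0.0, 2: 2.0}   # stands in for self.past_rewards
R = 10.0                             # bootstrap value for the final state
rev_returns = []
for i in reversed(range(0, 3)):
    R = rewards[i] + gamma * R
    rev_returns.append(R)
returns = list(reversed(rev_returns))
# returns is approximately [9.91, 9.9, 11.0]:
#   11.0 = 2 + 0.9 * 10, 9.9 = 0 + 0.9 * 11.0, 9.91 = 1 + 0.9 * 9.9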
Example 18
 def our_clip():
     clip_l2_grad_norm_(model.parameters(), th)
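
Example 18 is a fragment: model and th come from its enclosing scope, which reads like a micro-benchmark helper. The sketch below supplies an assumed context for it; the import path is also an assumption and may need adjusting to wherever clip_l2_grad_norm_ lives in your checkout.

import torch
import torch.nn as nn
from pfrl.utils import clip_l2_grad_norm_  # assumed import path

# Assumed stand-ins for the enclosing scope of Example 18.
model = nn.Linear(5, 2)
th = 1.0

model(torch.randn(4, 5)).sum().backward()

def our_clip():
    clip_l2_grad_norm_(model.parameters(), th)

# our_clip() could then be timed, e.g. with timeit.timeit(our_clip, number=100)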