Example #1
    def _update_policy(self, offset_batch):
        # update policy
        if self._update_count >= self._params.burnin:
            sample_batch_indices = np.searchsorted(self._state_density_cum_weights, np.random.rand(self._params.batch_size))
            demo_states = self._demo_states[sample_batch_indices]

            batch_states = offset_batch.states

            policy_actions = self._online_policy(batch_states)
            target_density = self._critic.q_value(batch_states, policy_actions, demo_states)

            actor_loss = -target_density.mean()

            self._actor_optim.zero_grad()
            actor_loss.backward()
            total_norm = torch.nn.utils.clip_grad_norm_(self._online_policy.parameters(), self._params.gradient_clip, 2)
            reporting.iter_record("policy_norm", total_norm)
            self._actor_optim.step()

            if self._update_count % self._target_update_rate == 0:
                for online_param, target_param in zip(self._online_policy.parameters(), self._target_policy.parameters()):
                    target_param.requires_grad = False
                    target_param.data = ((1 - self._params.target_update_step) * target_param +
                                         self._params.target_update_step * online_param)

            reporting.iter_record("actor_loss", actor_loss.item())
            if len(self._params.exploration_decay) > 0:
                self._behavior_policy.stepwise_exploration_decay(self._params.exploration_decay, self._update_count)
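
# --- Illustrative sketch (not from the original code) ---
# The sampling at the top of _update_policy draws demonstration states in
# proportion to precomputed weights by binary-searching a normalized
# cumulative-weight array. A minimal standalone version with made-up weights;
# in the agent, the weights come from clamped inverse state densities (see the
# final _update in this section).
import numpy as np

weights = np.array([0.5, 2.0, 1.0, 0.25])       # hypothetical per-state weights
cum_weights = np.cumsum(weights)
cum_weights = cum_weights / cum_weights[-1]     # normalized, as in _state_density_cum_weights

# searchsorted maps a uniform draw u to the first index i with cum_weights[i] >= u,
# so index i is selected with probability weights[i] / weights.sum().
batch_indices = np.searchsorted(cum_weights, np.random.rand(8))
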
def train_fetch(experiment: sacred.Experiment, agent: Any, eval_env: FetchEnv, progressive_noise: bool, small_goal: bool):
    reporting.register_field("eval_success_rate")
    reporting.register_field("action_norm")
    reporting.finalize_fields()
    # All three configurations currently use the same training budget.
    trange = tqdm.trange(2000000)
    for iteration in trange:
        if iteration % 10000 == 0:
            action_norms = []
            success_rate = 0
            for i in range(50):
                state = eval_env.reset()
                while not eval_env.needs_reset:
                    action = agent.eval_action(state)
                    action_norms.append(np.linalg.norm(action))
                    state, reward, is_terminal, info = eval_env.step(action)
                    if reward > -1.:
                        success_rate += 1
                        break
            reporting.iter_record("eval_success_rate", success_rate)
            reporting.iter_record("action_norm", np.mean(action_norms).item())

        if iteration % 20000 == 0:
            policy_path = f"/tmp/policy_{iteration}"
            with open(policy_path, 'wb') as f:
                torch.save(agent.freeze_policy(torch.device('cpu')), f)
            experiment.add_artifact(policy_path)

        agent.update()
        reporting.iterate()
        trange.set_description(f"{iteration} -- " + reporting.get_description(["return", "td_loss", "env_steps"]))
    def update(self):
        logging.getLogger().setLevel(logging.DEBUG)
        self.__update_count += 1
        if self._futures is None:
            self._futures = []
            for i in range(self._params.num_envs):
                frozen_policy = self.freeze_policy(self._actor_device)
                frozen_policy.share_memory()
                self._futures.append(
                    self._executor.submit(self._collect_trajectory,
                                          self._replay_description,
                                          frozen_policy))
        else:
            for i in range(self._params.num_envs):
                if (self._futures[i].done() and
                        (self._env_steps - self._params.min_replay_size) *
                        self._params.step_limit < self.__update_count):
                    frozen_policy = self.freeze_policy(self._actor_device)
                    frozen_policy.share_memory()
                    new_trajectory, cumulative_return = self._futures[i].result()
                    self._futures[i] = self._executor.submit(
                        self._collect_trajectory, self._replay_description,
                        frozen_policy)
                    self._add_episode(new_trajectory)

                    self._env_steps += new_trajectory.shape[0]
                    reporting.iter_record("return", cumulative_return)

        reporting.iter_record("env_steps", self._env_steps)
        if self._buffer.size >= self._params.min_replay_size:
            self._update()
        else:
            time.sleep(0.1)
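
# --- Illustrative sketch (not from the original code) ---
# The update above overlaps learning with environment interaction: trajectory
# collection runs on an executor, and a worker only gets a freshly frozen
# policy once its previous rollout has finished. A minimal version of that
# pattern with concurrent.futures; collect_trajectory here is a placeholder,
# not the original _collect_trajectory.
import concurrent.futures
import random
import time

def collect_trajectory(policy_version):
    time.sleep(random.uniform(0.01, 0.05))      # pretend to roll out an episode
    trajectory = [policy_version] * 5           # fake transitions
    cumulative_return = float(policy_version)   # fake episode return
    return trajectory, cumulative_return

executor = concurrent.futures.ThreadPoolExecutor(max_workers=2)
futures = [executor.submit(collect_trajectory, 0) for _ in range(2)]

for step in range(1, 20):
    for i, future in enumerate(futures):
        if future.done():
            trajectory, episode_return = future.result()
            # Resubmit immediately with the newest "policy" so workers stay busy.
            futures[i] = executor.submit(collect_trajectory, step)
executor.shutdown()
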
    def sample_from_var(self,
                        state_var: torch.Tensor,
                        t: int = 0,
                        return_logprob: bool = False) -> int:
        probabilities = self._probabilities(state_var)
        reporting.iter_record('max pi', np.max(probabilities, keepdims=True))
        if return_logprob:
            raise NotImplementedError()
        return np.random.choice(probabilities.shape[0], p=probabilities)
Example #5
    def _update(self, batch: data.TDBatch) -> torch.Tensor:
        q_values_o = self._online_network(batch.states)
        values_o = q_values_o.gather(dim=1, index=batch.actions.unsqueeze(1))
        next_q_values_t = self._target_network(batch.bootstrap_states)
        next_values_t = next_q_values_t.gather(
            dim=1, index=batch.bootstrap_actions.unsqueeze(1)).squeeze()

        target_values = batch.intermediate_returns + batch.bootstrap_weights * next_values_t
        reporting.iter_record("td_target_value", target_values.mean().item())
        return torch.mean((values_o.squeeze() - target_values)**2)
    def _update_value(self, batch: HerTransitionSequence):
        if self._update_count >= self._params.density_burnin:
            td_loss = self._critic.update_loss(batch)
            self._critic_optim.zero_grad()
            td_loss.backward()
            if np.isfinite(self._params.gradient_clip):
                torch.nn.utils.clip_grad_norm_(self._critic.parameters,
                                               self._params.gradient_clip,
                                               norm_type=2)
            self._critic_optim.step()
            reporting.iter_record("td_loss", td_loss.item())
    def _update_density_estimator(self, offset_batch: HerTransitionSequence,
                                  offsets: torch.Tensor):
        r_loss = -(offsets * self._goal_r(
            offset_batch.achieved_goal[:, 1, :], offset_batch.states[:, 0, :],
            offset_batch.actions[:, 0, :]).squeeze())

        self._r_optim.zero_grad()
        r_loss.mean().backward()
        self._r_optim.step()

        reporting.iter_record("r_loss", r_loss.mean().item())
Example #8
    def _update_density_estimator(self, offset_batch: data.TransitionSequence, offset_weights: torch.Tensor):
        target_states = offset_batch.next_states[:, 1, :] + torch.randn_like(offset_batch.next_states[:, 1, :]) * self.spatial_smoothing
        density_loss = -(self._density_model(
            state=offset_batch.states[:, 0, :],
            action=offset_batch.actions[:, 0, :],
            target_state=target_states)).squeeze().mean()
        state_density_loss = -self._state_density_model(target_states).mean()

        reporting.iter_record("density_loss", density_loss.item())
        self._density_optim.zero_grad()
        (density_loss + state_density_loss).backward()
        self._density_optim.step()
Example #9
    def _update(self,
                batch: data.TDBatch,
                return_mean: bool = True) -> torch.Tensor:
        values_o = self._online_network(batch.states).squeeze()
        next_values_t = self._target_network(batch.bootstrap_states).squeeze()

        target_values = batch.intermediate_returns + batch.bootstrap_weights * next_values_t
        reporting.iter_record("td_target_value", target_values.mean().item())

        if return_mean:
            return torch.mean((values_o - target_values)**2)
        else:
            return (values_o - target_values)**2
Example #10
    def _update_value(self, batch: data.TransitionSequence):
        if self._update_count >= self._params.density_burnin:
            sample_batch_indices = np.random.choice(self._demo_states.shape[0], size=self._params.batch_size, replace=True)
            demo_states = self._demo_states[sample_batch_indices]

            target_states = demo_states
            td_loss = self._critic.update_loss(batch, target_states, weights=None)
            self._critic_optim.zero_grad()
            td_loss.backward()
            total_norm = torch.nn.utils.clip_grad_norm_(self._critic.parameters, self._params.critic_gradient_clip, norm_type=2)
            reporting.iter_record("q_norm", total_norm)
            self._critic_optim.step()
            reporting.iter_record("td_loss", td_loss.item())
Example #11
    def update(self):
        batch_indices = np.random.choice(self._train_x.shape[0], size=self._batch_size)
        batch_x = torch.from_numpy(self._train_x[batch_indices])
        batch_y_target = torch.from_numpy(self._train_y[batch_indices])

        loss = ((self._net(batch_x) - batch_y_target)**2).mean()
        self._optim.zero_grad()
        loss.backward()
        self._optim.step()

        batch_indices = np.random.choice(self._valid_x.shape[0], size=self._batch_size)
        batch_x = torch.from_numpy(self._valid_x[batch_indices])
        batch_y_target = torch.from_numpy(self._valid_y[batch_indices])

        valid_loss = ((self._net(batch_x) - batch_y_target)**2).mean()
        reporting.iter_record("train_loss", loss.item())
        reporting.iter_record("valid_loss", valid_loss.item())
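
# --- Illustrative sketch (not from the original code) ---
# One caveat in the supervised update above: the validation forward pass builds
# an autograd graph even though its gradients are never used. A common
# alternative evaluates the validation batch under torch.no_grad(); the network
# and data here are placeholders.
import torch

net = torch.nn.Linear(4, 1)
valid_x = torch.randn(32, 4)
valid_y_target = torch.randn(32, 1)

with torch.no_grad():                            # skips graph construction
    valid_loss = ((net(valid_x) - valid_y_target) ** 2).mean()
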
    def update(self):
        i = 0
        while i < self._params.batch_size:
            env = np.random.choice(self._envs)
            j = i
            for t in range(self._params.num_steps):
                if env.needs_reset:
                    reporting.iter_record("return", env.cumulative_return())
                    env.reset()
                    break

                state = env.state
                action = np.atleast_1d(self.sample_action(state))
                next_state, reward, is_terminal, _ = env.step(action)

                if is_terminal:
                    self._bootstrap_weights[i] = 0.
                else:
                    self._bootstrap_weights[i] = self._params.discount_factor
                self._rewards[i] = reward
                self._states[i] = torch.Tensor(state)
                self._actions[i] = torch.Tensor(action)
                self._bootstrap_states[i] = torch.Tensor(next_state)

                for past in range(j, i):
                    # Each reward is i - past steps after the transition stored
                    # at `past`, so it is discounted by that many factors.
                    self._rewards[past] += reward * self._params.discount_factor**(i - past)
                    self._bootstrap_states[past] = self._bootstrap_states[i]
                    self._bootstrap_weights[past] *= self._bootstrap_weights[i]

                i += 1

                if i >= self._params.batch_size:
                    break

        td_batch = data.TDBatch(
            states=self._states.to(self._device),
            actions=self._actions.to(self._device),
            intermediate_returns=self._rewards.to(self._device),
            bootstrap_weights=self._bootstrap_weights.to(self._device),
            bootstrap_states=self._bootstrap_states.to(self._device),
            bootstrap_actions=self._bootstrap_actions.to(self._device))

        self._update(td_batch)
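
# --- Illustrative sketch (not from the original code) ---
# The inner loop above builds n-step targets incrementally: each new reward is
# folded back into the earlier transitions of the same rollout, and their
# bootstrap weights pick up one more discount factor. A standalone check of
# that recurrence against the direct n-step formula, with made-up rewards.
import numpy as np

discount = 0.9
rewards = [1.0, 2.0, 3.0]

returns = np.zeros(3)
bootstrap_weights = np.zeros(3)
for i, r in enumerate(rewards):
    returns[i] = r
    bootstrap_weights[i] = discount
    for past in range(i):
        returns[past] += r * discount ** (i - past)
        bootstrap_weights[past] *= discount

# First transition: r_0 + g*r_1 + g^2*r_2, bootstrapped with weight g^3.
assert np.isclose(returns[0], rewards[0] + discount * rewards[1] + discount ** 2 * rewards[2])
assert np.isclose(bootstrap_weights[0], discount ** 3)
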
Example #13
    def update(self):
        for t in range(self._params.steps_per_update):
            state = self._env.state
            action, action_logprob = self.sample_action(state)

            next_state, reward, is_terminal, _ = self._env.step(action)
            is_timeout = self._env.needs_reset
            terminal_weight = 0. if is_terminal else 1.
            timeout_weight = 0. if is_timeout else 1.

            self._buffer.add_samples(self._to_buffer_format(state, action, reward, next_state,
                                                            timeout_weight, terminal_weight, action_logprob))

            if self._env.needs_reset:
                reporting.iter_record("return", self._env.cumulative_return())
                self._env.reset()

        if self._buffer.size >= self._params.min_replay_size:
            sequence = self._buffer.sample_sequence(self._params.batch_size, self._params.sequence_length)
            self._update(self._buffer_to_sequence(sequence))
    def _update_policy(self, batch: HerTransitionSequence):
        # update policy
        if self._update_count >= self._params.burnin:
            actor_loss = -self._critic.q_value(
                batch.states[:, 0], self._online_policy(
                    batch.states[:, 0])).mean()

            self._actor_optim.zero_grad()
            actor_loss.backward()
            self._actor_optim.step()

            if self._update_count % self._target_update_rate == 0:
                for online_param, target_param in zip(
                        self._online_policy.parameters(),
                        self._target_policy.parameters()):
                    target_param.requires_grad = False
                    target_param.data = (
                        (1 - self._params.target_update_step) * target_param +
                        self._params.target_update_step * online_param)

            reporting.iter_record("actor_loss", actor_loss.item())
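
# --- Illustrative sketch (not from the original code) ---
# The target-network block above is a Polyak (soft) update: every
# target_update_rate updates, the target parameters move a small step toward
# the online parameters. The modules and tau below are placeholders.
import torch

online = torch.nn.Linear(3, 2)
target = torch.nn.Linear(3, 2)
target.load_state_dict(online.state_dict())
tau = 0.005                                      # plays the role of target_update_step

with torch.no_grad():
    for online_param, target_param in zip(online.parameters(), target.parameters()):
        # target <- (1 - tau) * target + tau * online
        target_param.data.mul_(1 - tau).add_(tau * online_param.data)
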
Example #15
    def _update(self,
                batch: data.TransitionSequence,
                target_states: torch.Tensor,
                *args,
                weights: Optional[torch.Tensor] = None,
                **kwargs) -> torch.Tensor:
        states = batch.states
        actions = batch.actions
        bootstrap_weights = self._discount_factor * batch.terminal_weight
        next_states = batch.next_states

        next_actions = self._target_policy(next_states)
        next_actions = next_actions + torch.randn_like(actions) * self._action_noise_stddev

        next_q1, next_q2 = self._target_network(next_states, next_actions, target_states)
        next_q = torch.min(next_q1, next_q2).squeeze()

        online_q1, online_q2 = self._online_network(states, actions, target_states)

        td_target_q = (bootstrap_weights * next_q).detach()

        density_target_q = (self._density_q_model(state=states, action=actions, target_state=target_states)).exp()
        if self._temporal_smoothing > 0.:
            density_target_q = self._temporal_smoothing * td_target_q + (1 - self._temporal_smoothing) * density_target_q

        target_q = batch.terminal_weight * torch.max(td_target_q, density_target_q)
        target_idx = (td_target_q > density_target_q).float().mean()
        reporting.iter_record("terminal_weight", batch.terminal_weight.sum().item())
        reporting.iter_record("target_q", target_q.mean().item())
        reporting.iter_record("target_fraction", target_idx.item())

        td_loss_1 = f.smooth_l1_loss(online_q1.squeeze(), target_q)
        td_loss_2 = f.smooth_l1_loss(online_q2.squeeze(), target_q)
        return (td_loss_1 + td_loss_2).mean()
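
# --- Illustrative sketch (not from the original code) ---
# The critic loss above uses clipped double-Q targets: two target critics are
# evaluated and the smaller estimate forms the TD target (here combined with a
# density-based target via torch.max). A minimal version of the clipped part
# alone, with placeholder networks and features.
import torch
import torch.nn.functional as F

target_q1 = torch.nn.Linear(4, 1)
target_q2 = torch.nn.Linear(4, 1)
online_q = torch.nn.Linear(4, 1)

features = torch.randn(16, 4)                    # stand-in for (state, action) features
next_features = torch.randn(16, 4)               # stand-in for (next_state, next_action)
bootstrap_weights = torch.full((16,), 0.99)

with torch.no_grad():
    # Taking the minimum of the two critics counteracts overestimation bias.
    next_q = torch.min(target_q1(next_features), target_q2(next_features)).squeeze(1)
    target_q = bootstrap_weights * next_q        # as above, no reward term in this variant

td_loss = F.smooth_l1_loss(online_q(features).squeeze(1), target_q)
td_loss.backward()
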
    def _update(self,
                batch: data.TransitionSequence,
                return_mean: bool = True,
                *args,
                **kwargs) -> torch.Tensor:
        states = batch.states[:, 0]
        actions = batch.actions[:, 0]
        bootstrap_weights = self._discount_factor * batch.terminal_weight[:, 0]
        next_states = batch.next_states[:, 0]

        goal = states[:, :self._goal_dim]
        if self._shuffle_goals:
            goal = states[torch.randperm(states.shape[0]), :self._goal_dim]
            states = torch.cat([goal, states[:, self._goal_dim:]], dim=1)
            next_states = torch.cat([goal, next_states[:, self._goal_dim:]],
                                    dim=1)

        next_actions = self._target_policy(next_states)
        next_actions = next_actions + torch.randn_like(
            actions) * self._action_noise_stddev

        next_q1, next_q2 = self._target_network(next_states, next_actions)
        next_q = torch.min(next_q1, next_q2).squeeze()

        online_q1, online_q2 = self._online_network(states, actions)

        td_target_q = (bootstrap_weights * next_q).detach()
        density_target_q = self._density_q_model.reward(
            goal, next_states, next_actions).detach().squeeze()
        target_q = torch.max(td_target_q, density_target_q)
        target_idx = (td_target_q > density_target_q).float().mean()
        reporting.iter_record("target_fraction", target_idx.item())
        reporting.iter_record("target_q", target_q.mean().item())

        td_loss_1 = (online_q1.squeeze() - target_q)**2
        td_loss_2 = (online_q2.squeeze() - target_q)**2
        if return_mean:
            return (td_loss_1 + td_loss_2).mean()
        else:
            return td_loss_1 + td_loss_2
    def _update(self):
        self._update_count += 1

        # Report validation losses
        if self._update_count % 100 == 0 and self._valid_buffer.size > 0:
            valid_batch, _ = self._density_sampler.sample_discounted_offset(
                self._params.batch_size)
            valid_batch = self._replay_description.parse_sample(valid_batch)

            valid_loss = -(self._goal_r(valid_batch.achieved_goal[:, 1, :],
                                        valid_batch.states[:, 0, :],
                                        valid_batch.actions[:, 0, :])).mean()
            reporting.iter_record("valid_r_loss", valid_loss.item())

            valid_target_r = self._goal_r.reward(
                valid_batch.target_goal[:, 0, :], valid_batch.states[:, 0, :],
                valid_batch.actions[:, 0, :]).mean()
            valid_target_v = self._critic.q_value(
                valid_batch.states[:, 0, :], valid_batch.actions[:,
                                                                 0, :]).mean()
            reporting.iter_record("valid_target_r", valid_target_r.item())
            reporting.iter_record("valid_target_v", valid_target_v.item())

        batch = self._replay_description.parse_sample(
            self._buffer.sample_sequence(self._params.batch_size, 2))
        offset_batch, density_target_offset = self._density_sampler.sample_uniform_offset(
            self._params.batch_size)
        offset_batch = self._replay_description.parse_sample(offset_batch)
        offset_weights = self._params.discount_factor**(
            offset_batch.states.new(density_target_offset) - 1)

        self._update_density_estimator(offset_batch, offset_weights)
        self._update_value(batch)
        self._update_policy(batch)
    def _update(self):
        offset_sequence = self._replay_description.parse_sample(
            self._sampler.sample_future_pair(self._params.batch_size))
        batch = self._env.replace_goals(offset_sequence,
                                        offset_sequence.achieved_goal[:, 1, :],
                                        1 - 1 / self._params.her_k)
        self._update_count += 1
        td_loss = self._critic.update_loss(batch)

        self._critic_optim.zero_grad()
        td_loss.backward()
        if np.isfinite(self._params.gradient_clip):
            torch.nn.utils.clip_grad_norm_(self._critic.parameters,
                                           self._params.gradient_clip,
                                           norm_type=2)
        self._critic_optim.step()
        reporting.iter_record("td_loss", td_loss.item())

        if self._update_count >= self._params.burnin:
            actor_loss = -self._critic.q_value(
                batch.states[:, 0], self._online_policy(
                    batch.states[:, 0])).mean()
            # actor_loss = actor_loss + batch.actions[:, 0].norm(dim=1).mean() #TODO

            self._actor_optim.zero_grad()
            actor_loss.backward()
            self._actor_optim.step()

            reporting.iter_record("actor_loss", actor_loss.item())

            if self._update_count % self._target_update_rate == 0:
                for online_param, target_param in zip(
                        self._online_policy.parameters(),
                        self._target_policy.parameters()):
                    target_param.requires_grad = False
                    target_param.data = (
                        1 - self._params.target_update_step
                    ) * target_param + self._params.target_update_step * online_param
    def _update(self, td_batch: data.TransitionSequence):
        td_loss = self._critic.update_loss(td_batch)
        self._value_optimizer.zero_grad()
        td_loss.backward()
        if np.isfinite(self._params.gradient_clip):
            torch.nn.utils.clip_grad_norm_(self._critic.parameters,
                                           self._params.gradient_clip,
                                           norm_type=2)
        self._value_optimizer.step()

        reporting.iter_record("td_loss", td_loss.item())
        self._num_updates += 1
        if self._num_updates < self._params.burnin or self._num_updates % self._params.policy_update_rate != 0:
            return

        vtrace_targets = self._critic.advantage_targets.detach()
        importance_weights = self._critic.importance_weights

        advantages = (td_batch.rewards[:, 0] +
                      self._params.discount_factor * vtrace_targets -
                      self._critic.values(td_batch.states[:, 0]))
        advantage_loss = (-importance_weights * self._policy.log_probability(
            td_batch.states[:, 0], td_batch.actions[:, 0]).squeeze() *
                          advantages.detach()).mean()
        if self._params.entropy_regularization > 0:
            entropy_loss = -self._params.entropy_regularization * self._policy.entropy(
                td_batch.states[:, 0]).mean()
            loss = advantage_loss + entropy_loss
        else:
            loss = advantage_loss

        reporting.iter_record("advantage_loss", advantage_loss.item())
        if self._params.entropy_regularization > 0:
            reporting.iter_record("entropy_loss", entropy_loss.item())

        self._actor_optimizer.zero_grad()
        loss.backward()
        if np.isfinite(self._params.gradient_clip):
            torch.nn.utils.clip_grad_norm_(self._policy.parameters,
                                           self._params.gradient_clip,
                                           norm_type=2)
        self._actor_optimizer.step()
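
# --- Illustrative sketch (not from the original code) ---
# The policy update above subtracts an entropy bonus scaled by
# entropy_regularization from the advantage-weighted log-probability loss.
# A minimal version of the entropy term for a categorical policy; the logits
# below stand in for the original self._policy.
import torch

logits = torch.randn(16, 4, requires_grad=True)
dist = torch.distributions.Categorical(logits=logits)

entropy_regularization = 0.01
entropy_loss = -entropy_regularization * dist.entropy().mean()   # negative, so entropy is maximized
entropy_loss.backward()
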
    def _update(self, td_batch: data.TDBatch):
        td_loss = self._critic.update_loss(td_batch)
        self._value_optimizer.zero_grad()
        td_loss.backward()
        if np.isfinite(self._params.gradient_clip):
            torch.nn.utils.clip_grad_norm_(self._critic.parameters,
                                           self._params.gradient_clip,
                                           norm_type='inf')
        self._value_optimizer.step()

        reporting.iter_record("td_loss", td_loss.item())
        self._num_updates += 1
        if self._num_updates < self._params.burnin:
            return

        advantages = (td_batch.intermediate_returns +
                      td_batch.bootstrap_weights *
                      self._critic.values(td_batch.bootstrap_states) -
                      self._critic.values(td_batch.states))
        advantage_loss = (-self._policy.log_probability(
            td_batch.states, td_batch.actions).squeeze() *
                          advantages.detach()).mean()
        if self._params.entropy_regularization > 0:
            entropy_loss = -self._params.entropy_regularization * self._policy.entropy(
                td_batch.states).mean()
            loss = advantage_loss + entropy_loss
        else:
            loss = advantage_loss

        reporting.iter_record("advantage_loss", advantage_loss.item())
        if self._params.entropy_regularization > 0:
            reporting.iter_record("entropy_loss", entropy_loss.item())

        self._actor_optimizer.zero_grad()
        loss.backward()
        self._actor_optimizer.step()
def train(density_learning_rate: float, _config: sacred.config.ConfigDict):
    target_dir = "/home/anon/generated_data/algorithms"
    reporting.register_global_reporter(experiment, target_dir)
    device = torch.device('cuda:0')
    demo_states, _ = load_demos()
    demo_min = np.min(demo_states, axis=0)
    demo_max = np.max(demo_states, axis=0)
    random_min, random_max = random_rollout_bounds(10)
    min_states = np.minimum(demo_min, random_min)
    max_states = np.maximum(demo_max, random_max)
    make_normalized_env = functools.partial(environment_adapters.NormalizedEnv,
                                            make_env, min_states, max_states)
    eval_env = make_normalized_env()
    np.savetxt(target_dir + "/normalization", [min_states, max_states])
    experiment.add_artifact(target_dir + "/normalization")

    demo_states = eval_env.normalize_state(demo_states)
    demo_states = torch.from_numpy(demo_states).to(device)
    demo_actions = None
    print(demo_states.shape)

    state_dim = demo_states.shape[1]
    action_dim = eval_env.action_dim

    density_model = DensityModel(device, state_dim, action_dim)
    state_density_model = StateDensityModel(device, state_dim, action_dim)
    policy = PolicyNetwork(state_dim, action_dim).to(device)
    params_parser = util.ConfigParser(vdi.VDIParams)
    params = params_parser.parse(_config)

    q1 = QNetwork(state_dim, action_dim).to(device)
    q2 = QNetwork(state_dim, action_dim).to(device)

    agent = vdi.VDI(make_normalized_env, device, density_model,
                    state_density_model, policy, q1, q2, params, demo_states,
                    demo_actions)

    reporting.register_field("eval_return")
    reporting.finalize_fields()
    trange = tqdm.trange(1000000)
    for iteration in trange:
        agent.update()
        reporting.iterate()
        if iteration % 20000 == 0:
            eval_reward = 0
            for i in range(2):
                state = eval_env.reset()
                cumulative_reward = 0
                while not eval_env.needs_reset:
                    action = agent.eval_action(state)
                    state, reward, is_terminal, _ = eval_env.step(action)
                    cumulative_reward += reward
                eval_reward += cumulative_reward / 2
            reporting.iter_record("eval_return", eval_reward)

        if iteration % 10000 == 0:
            policy_path = f"{target_dir}/policy_{iteration}"
            with open(policy_path, 'wb') as f:
                torch.save(agent.freeze_policy(torch.device('cpu')), f)
            experiment.add_artifact(policy_path)
            density_model_path = f"{target_dir}/dm_{iteration}"
            with open(density_model_path, 'wb') as f:
                torch.save(density_model, f)
            experiment.add_artifact(density_model_path)

        trange.set_description(f"{iteration} -- " + reporting.get_description([
            "return", "eval_return", "density_loss", "actor_loss", "td_loss",
            "env_steps"
        ]))
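
# --- Illustrative sketch (not from the original code) ---
# train derives per-dimension state bounds from the union of demonstration and
# random-rollout data and hands them to a normalizing environment wrapper. The
# scaling below to [-1, 1] is an assumption; the original
# environment_adapters.NormalizedEnv is not shown here.
import numpy as np

def normalize_state(state, min_states, max_states):
    return 2.0 * (state - min_states) / (max_states - min_states) - 1.0

min_states = np.array([0.0, -1.0])
max_states = np.array([1.0, 1.0])
normalize_state(np.array([0.5, 0.0]), min_states, max_states)    # -> [0., 0.]
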
Example #22
    def _update(self):
        self._update_count += 1
        if len(self._params.lr_decay_iterations) > 0:
            self._actor_scheduler.step()
            self._critic_scheduler.step()
            reporting.iter_record("actor_lr", self._actor_scheduler.get_lr()[0])
            reporting.iter_record("critic_lr", self._critic_scheduler.get_lr()[0])
        if self._update_count >= self._params.density_update_rate_burnin:
            self._density_update_rate = self._params.density_update_rate

        if self._update_count % self._density_update_rate == 0:
            self._density_optim.zero_grad()
            self._target_state_density_model.load_state_dict(copy.deepcopy(self._state_density_model.state_dict()))
            self._target_density_model.load_state_dict(copy.deepcopy(self._density_model.state_dict()))

            state_density = self._target_state_density_model(self._demo_states).exp()
            reporting.iter_record("max_state_density", state_density.max().item())
            reporting.iter_record("min_state_density", state_density.min().item())
            reporting.iter_record("mean_state_density", state_density.mean().item())
            reporting.iter_record("state_density_bound", state_density.mean().item() / self._params.density_factor)
            state_density_weights = state_density/state_density.mean()
            state_density_weights = torch.clamp(1/state_density_weights, max=self._params.density_factor)
            self._state_density_cum_weights = torch.cumsum(state_density_weights.squeeze(), 0).cpu().detach().numpy()
            self._state_density_cum_weights /= self._state_density_cum_weights[-1]

        offset_batch, density_target_offset = self._density_sampler.sample_discounted_offset(self._params.batch_size)
        offset_batch = self._replay_description.parse_sample(offset_batch)
        offset_weights = self._params.discount_factor ** (offset_batch.states.new(density_target_offset) - 1)

        self._update_density_estimator(offset_batch, offset_weights)

        single_batch = self._buffer.sample(self._params.batch_size)
        single_batch = self._replay_description.parse_sample(single_batch)
        self._update_value(single_batch)

        self._update_policy(single_batch)
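
# --- Illustrative sketch (not from the original code) ---
# The periodic density block above turns state densities into sampling weights
# by clamping the inverse (mean-normalized) density at density_factor and
# storing a normalized cumulative sum; the searchsorted sampling sketched after
# the first snippet then draws demonstration states roughly inversely
# proportional to their density. The densities below are made up.
import torch

density_factor = 5.0
state_density = torch.tensor([0.1, 1.0, 2.0, 4.0])

# Rare states get larger weights, capped so no single state dominates sampling.
weights = torch.clamp(state_density.mean() / state_density, max=density_factor)
cum_weights = torch.cumsum(weights, 0)
cum_weights = cum_weights / cum_weights[-1]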