Example #1
        def f(trans):
            start_state = trans[0].state
            action = trans[0].action
            next_state = trans[-1].next_state
            valid = trans[-1].valid
            reward = 0.
            for i, data in enumerate(trans):
                reward += data.reward * self.gamma ** i

            return Transition(start_state, action, next_state, reward, valid)
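
This helper collapses a list of consecutive transitions into a single n-step transition: it keeps the first state and action, the last next_state and valid flag, and accumulates the rewards discounted by gamma. A minimal standalone sketch of the same idea, assuming Transition is the usual namedtuple and taking gamma as a plain argument instead of self.gamma (both assumptions, since the surrounding class is not shown; compress_nstep is just an illustrative name):

    from collections import namedtuple

    Transition = namedtuple(
        'Transition', ('state', 'action', 'next_state', 'reward', 'valid'))

    def compress_nstep(trans, gamma=0.99):
        # keep the first state/action, the last next_state/valid,
        # and sum the discounted n-step return
        reward = sum(t.reward * gamma ** i for i, t in enumerate(trans))
        return Transition(trans[0].state, trans[0].action,
                          trans[-1].next_state, reward, trans[-1].valid)

    steps = [Transition(s, 0, s + 1, 1.0, 1.0) for s in range(3)]
    print(compress_nstep(steps))  # reward = 1 + 0.99 + 0.99**2 = 2.9701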
Example #2
    def set_dataset(self, dataset, num=None):
        print('data registration start')
        st = time.time()

        ob = dataset['state']
        ac = dataset['action']
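        # next states: observations shifted forward by one step, with the last observation repeated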
        ne = np.concatenate((ob[1:], ob[-1:]))
        re = dataset['reward']
        va = 1 - dataset['terminal']

        if num is not None:
            ob, ac, ne, re, va = ob[:num], ac[:num], ne[:num], re[:num], va[:num]


        all_li = list(zip(ob, ac, ne, re, va))
        result = list(map(lambda x: [Transition(*x)], all_li))
        gl = time.time()
        print(f'data registration took {np.round(gl-st, 2)} sec.')
        print(f'{len(result) / 1e6} M transitions were registered.')

        self.replay_buffer.memory = result
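
set_dataset expects a dict of aligned NumPy arrays keyed by 'state', 'action', 'reward', and 'terminal'; next states are the observations shifted by one step and the valid flag is the complement of terminal. A runnable toy check of that shaping (the dataset values below are made up for illustration):

    import numpy as np

    n = 5
    dataset = {
        'state': np.arange(n, dtype=np.float32).reshape(n, 1),
        'action': np.zeros(n, dtype=np.int64),
        'reward': np.ones(n, dtype=np.float32),
        'terminal': np.array([0, 0, 0, 0, 1], dtype=np.float32),
    }

    ob = dataset['state']
    ne = np.concatenate((ob[1:], ob[-1:]))  # next obs: shifted by one, last entry repeated
    va = 1 - dataset['terminal']            # valid = 1 everywhere except the terminal step
    print(ne.ravel())  # [1. 2. 3. 4. 4.]
    print(va)          # [1. 1. 1. 1. 0.]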
Example #3
    def train(self):
        if (len(self.replay_buffer) < self.batch_size
                or len(self.replay_buffer) < self.replay_start_step):
            return

        if self.prioritized:
            transitions, idx_batch, weights = self.replay_buffer.sample(
                self.batch_size)
        else:
            transitions = self.replay_buffer.sample(self.batch_size)

        def f(trans):
            start_state = trans[0].state
            action = trans[0].action
            next_state = trans[-1].next_state
            valid = trans[-1].valid
            reward = 0.
            for i, data in enumerate(trans):
                reward += data.reward * self.gamma**i

            return Transition(start_state, action, next_state, reward, valid)

        def extract_steps(trans):
            return len(trans)

        steps_batch = list(map(extract_steps, transitions))
        transitions = map(f, transitions)

        batch = Transition(*zip(*transitions))

        state_batch = torch.tensor(np.array(batch.state, dtype=np.float32),
                                   device=self.device)
        action_batch = torch.tensor(batch.action,
                                    device=self.device,
                                    dtype=torch.int64).unsqueeze(1)
        next_state_batch = torch.tensor(np.array(batch.next_state,
                                                 dtype=np.float32),
                                        device=self.device)
        reward_batch = torch.tensor(np.array(batch.reward, dtype=np.float32),
                                    device=self.device)
        valid_batch = torch.tensor(np.array(batch.valid, dtype=np.float32),
                                   device=self.device)
        steps_batch = torch.tensor(np.array(steps_batch, dtype=np.float32),
                                   device=self.device)

        state_action_values = self.q_func(state_batch).gather(1, action_batch)

        expected_state_action_values = reward_batch + \
            valid_batch * (self.gamma ** steps_batch) * \
            self.next_state_value(next_state_batch)

        if self.prioritized:
            td_error = abs(expected_state_action_values -
                           state_action_values.squeeze(1)).tolist()
            for data_idx, err in zip(idx_batch, td_error):
                self.replay_buffer.update(data_idx, err)

        if self.huber:
            if self.prioritized:
                loss_each = F.smooth_l1_loss(
                    state_action_values,
                    expected_state_action_values.unsqueeze(1),
                    reduction='none')
                loss = torch.sum(loss_each *
                                 torch.tensor(weights, device=self.device))
            else:
                loss = F.smooth_l1_loss(
                    state_action_values,
                    expected_state_action_values.unsqueeze(1))
        else:
            if self.prioritized:
                loss_each = F.mse_loss(
                    state_action_values,
                    expected_state_action_values.unsqueeze(1),
                    reduction='none')
                loss = torch.sum(loss_each *
                                 torch.tensor(weights, device=self.device))

            else:
                loss = F.mse_loss(state_action_values,
                                  expected_state_action_values.unsqueeze(1))

        if self.q_func.phase == 'struct':
            loss += self.param_coef * self.q_func.param_loss()
            self.optimizer_struct.zero_grad()
            loss.backward()
            self.optimizer_struct.step()
        elif self.q_func.phase == 'param':
            self.optimizer_param.zero_grad()
            loss.backward()
            # for param in self.q_func.parameters():
            #    param.grad.data.clamp_(-1, 1)
            self.optimizer_param.step()

        if self.total_steps % self.target_update_interval == 0:
            self.target_q_func.load_state_dict(self.q_func.state_dict())
        return loss.detach().cpu().item()
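
next_state_value is referenced above but not part of the snippet. A common choice that fits the target computation here is a Double-DQN-style bootstrap, where the online network picks the greedy action and the target network evaluates it; the sketch below assumes that reading (only the attribute names come from the snippet, the body is a guess):

    def next_state_value(self, next_state_batch):
        # assumed Double-DQN bootstrap: online net selects, target net evaluates;
        # no gradient flows through the target
        with torch.no_grad():
            greedy = self.q_func(next_state_batch).argmax(dim=1, keepdim=True)
            return self.target_q_func(next_state_batch).gather(1, greedy).squeeze(1)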
Example #4
    def train(self):
        # if len(self.replay_buffer) < self.batch_size or len(self.replay_buffer) < self.replay_start_step:
        #     return

        if self.prioritized:
            transitions, idx_batch, weights = self.replay_buffer.sample(self.batch_size)
        else:
            transitions = self.replay_buffer.sample(self.batch_size)


        def f(trans):
            start_state = trans[0].state
            action = trans[0].action
            next_state = trans[-1].next_state
            valid = trans[-1].valid
            reward = 0.
            for i, data in enumerate(trans):
                reward += data.reward * self.gamma ** i

            return Transition(start_state, action, next_state, reward, valid)


        def extract_steps(trans):
            return len(trans)


        steps_batch = list(map(extract_steps, transitions))
        transitions = map(f, transitions)

        batch = Transition(*zip(*transitions))

        state_batch = torch.tensor(
            np.array(batch.state, dtype=np.float32), device=self.device)
        action_batch = torch.tensor(
            batch.action, device=self.device, dtype=torch.int64).unsqueeze(1)
        next_state_batch = torch.tensor(
            np.array(batch.next_state, dtype=np.float32), device=self.device)
        reward_batch = torch.tensor(
            np.array(batch.reward, dtype=np.float32), device=self.device)
        valid_batch = torch.tensor(
            np.array(batch.valid, dtype=np.float32), device=self.device)
        steps_batch = torch.tensor(np.array(steps_batch, dtype=np.float32), device=self.device)

        qout = self.q_func(state_batch)
        state_action_values = qout.gather(1, action_batch)

        expected_state_action_values = reward_batch + \
                                       valid_batch * (self.gamma ** steps_batch) * \
                                       self.next_state_value(next_state_batch)

        if self.prioritized:
            td_error = abs(expected_state_action_values - state_action_values.squeeze(1)).tolist()
            for data_idx, err in zip(idx_batch, td_error):
                self.replay_buffer.update(data_idx, err)

        if self.huber:
            if self.prioritized:
                loss_each = F.smooth_l1_loss(state_action_values,
                                             expected_state_action_values.unsqueeze(1), reduction='none')
                dqn_loss = torch.sum(loss_each * torch.tensor(weights, device=self.device))
            else:
                dqn_loss = F.smooth_l1_loss(
                    state_action_values,
                    expected_state_action_values.unsqueeze(1))
        else:
            if self.prioritized:
                loss_each = F.mse_loss(state_action_values,
                                       expected_state_action_values.unsqueeze(1), reduction='none')
                dqn_loss = torch.sum(loss_each * torch.tensor(weights, device=self.device))

            else:
                dqn_loss = F.mse_loss(
                    state_action_values,
                    expected_state_action_values.unsqueeze(1))

        # add CQL loss
        policy_q = torch.logsumexp(qout, dim=-1, keepdim=True)
        data_q = state_action_values  # data_q, [32, 1]
        cql_loss = (policy_q - data_q).mean()
        loss = dqn_loss + self.cql_weight * cql_loss

        self.optimizer.zero_grad()
        loss.backward()
        # for param in self.q_func.parameters():
        #    param.grad.data.clamp_(-1, 1)
        self.optimizer.step()

        if self.total_steps - self.prev_target_update_time >= self.target_update_interval:
            self.target_q_func.load_state_dict(self.q_func.state_dict())
            self.prev_target_update_time = self.total_steps
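
The CQL term above penalizes the log-sum-exp of the Q-values relative to the Q-value of the action stored in the dataset, which pushes Q down on actions the behavior policy never took. A self-contained toy check of that penalty (the numbers are made up for illustration):

    import torch

    q_values = torch.tensor([[1.0, 3.0, 0.5],   # Q(s, .) for two states
                             [2.0, 2.0, 2.0]])
    data_actions = torch.tensor([[1], [0]])     # actions taken in the dataset

    policy_q = torch.logsumexp(q_values, dim=-1, keepdim=True)
    data_q = q_values.gather(1, data_actions)
    cql_penalty = (policy_q - data_q).mean()
    print(cql_penalty)  # always >= 0; shrinks as the data action dominates the others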
Example #5
    def train(self):
        # copy and paste
        if len(self.replay_buffer) < self.batch_size:
            return

        self.train_cnt += 1
        self.total_steps = self.train_cnt

        transitions = self.replay_buffer.sample(self.batch_size)
        map_func = lambda x: x[0]
        batch = Transition(*zip(*map(map_func, transitions)))

        state_batch = torch.tensor(
            np.array(batch.state, dtype=np.float32), device=self.device)
        action_batch = torch.tensor(
            np.array(batch.action, dtype=np.float32), device=self.device)
        next_state_batch = torch.tensor(
            np.array(batch.next_state, dtype=np.float32), device=self.device)
        reward_batch = torch.tensor(
            np.array(batch.reward, dtype=np.float32), device=self.device).unsqueeze(1)
        valid_batch = torch.tensor(
            np.array(batch.valid, dtype=np.float32), device=self.device).unsqueeze(1)

        target_q = self.calc_target_q(
            state_batch, action_batch, reward_batch, next_state_batch, valid_batch)
        q1_data = self.critic1(state_batch, action_batch)
        q2_data = self.critic2(state_batch, action_batch)
        q1_mse_loss = F.mse_loss(q1_data, target_q)
        q2_mse_loss = F.mse_loss(q2_data, target_q)

        num_random = 10
        random_actions = torch.FloatTensor(q2_data.shape[0] * num_random, action_batch.shape[-1]).uniform_(-1, 1).to(self.device)
        current_states = state_batch.unsqueeze(1).repeat(1, num_random, 1).view(-1, state_batch.shape[-1])
        current_actions, current_logpis, _ = self.try_act(current_states)
        current_actions, current_logpis = current_actions.detach(), current_logpis.detach()
        next_states = next_state_batch.unsqueeze(1).repeat(1, num_random, 1).view(-1, state_batch.shape[-1])
        next_actions, next_logpis, _ = self.try_act(next_states)
        next_actions, next_logpis = next_actions.detach(), next_logpis.detach()

        q1_rand = self.critic1(current_states, random_actions).view(-1, num_random)
        q2_rand = self.critic2(current_states, random_actions).view(-1, num_random)
        q1_curr = self.critic1(current_states, current_actions).view(-1, num_random)
        q2_curr = self.critic2(current_states, current_actions).view(-1, num_random)
        q1_next = self.critic1(current_states, next_actions).view(-1, num_random)
        q2_next = self.critic2(current_states, next_actions).view(-1, num_random)
        current_logpis = current_logpis.view(-1, num_random)
        next_logpis = next_logpis.view(-1, num_random)

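        # log-density of the uniform proposal over [-1, 1]^action_dim, i.e. log((1/2)^d)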
        random_density = np.log(0.5 ** current_actions.shape[-1])
        cat_q1 = torch.cat([q1_rand - random_density,
                            q1_curr - current_logpis.detach(),
                            q1_next - next_logpis.detach()], 1)
        cat_q2 = torch.cat([q2_rand - random_density,
                            q2_curr - current_logpis.detach(),
                            q2_next - next_logpis.detach()], 1)

        cql_loss1 = (torch.logsumexp(cat_q1, dim=1, keepdim=True) - q1_data).mean()
        cql_loss2 = (torch.logsumexp(cat_q2, dim=1, keepdim=True) - q2_data).mean()

        q1_loss = q1_mse_loss + self.cql_weight * cql_loss1
        q2_loss = q2_mse_loss + self.cql_weight * cql_loss2
        q_loss = q1_loss + q2_loss

        self.q1_optim.zero_grad()
        self.q2_optim.zero_grad()
        q_loss.backward()
        self.q1_optim.step()
        self.q2_optim.step()

        pi, log_pi, _ = self.try_act(state_batch)

        qf1_pi = self.critic1(state_batch, pi)
        qf2_pi = self.critic2(state_batch, pi)
        min_qf_pi = torch.min(qf1_pi, qf2_pi)

        policy_loss = ((self.alpha * log_pi) - min_qf_pi).mean()

        if self.total_steps < self.policy_eval_start:
            raw_action = atanh(action_batch)
            mean, log_std = self.actor(state_batch)
            normal = Normal(mean, log_std.exp())
            eps = 1e-6
            logprob = (normal.log_prob(raw_action)
                      - torch.log(1 - action_batch.pow(2) + eps))
            logprob = logprob.sum(1, keepdim=True)

            policy_loss = ((self.alpha * log_pi) - logprob).mean()

        self.actor_optim.zero_grad()
        policy_loss.backward()
        self.actor_optim.step()

        # adjust alpha
        alpha_loss = -(self.log_alpha *
                       (self.target_entropy + log_pi).detach()).mean()

        self.alpha_optim.zero_grad()
        alpha_loss.backward()
        self.alpha_optim.step()
        self.alpha = self.log_alpha.exp()

        if self.train_cnt % self.target_update_interval == 0:
            self.soft_update(self.target_critic1, self.critic1)
            self.soft_update(self.target_critic2, self.critic2)
            self.prev_target_update_time = self.total_steps
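
atanh and soft_update are called above but not shown. Plausible minimal versions under the usual definitions, a clipped inverse tanh for recovering pre-squash actions and a Polyak average for the target critics (the tau default is an assumption; in the snippet soft_update is a method of the agent):

    import torch

    def atanh(x, eps=1e-6):
        # inverse of tanh, clipped away from +/-1 for numerical stability
        x = x.clamp(-1 + eps, 1 - eps)
        return 0.5 * torch.log((1 + x) / (1 - x))

    def soft_update(self, target, source, tau=0.005):
        # Polyak averaging: target <- (1 - tau) * target + tau * source
        for t_param, s_param in zip(target.parameters(), source.parameters()):
            t_param.data.copy_((1.0 - tau) * t_param.data + tau * s_param.data)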