Example #1
    def _value_func_estimate(self):
        if len(self.buffer) < 32:
            return

        self.target_usage += 1
        transitions = self.buffer.sample(32)
        batch = Transition(*zip(*transitions))

        state_batch = torch.cat(batch.state)
        action_batch = torch.cat(batch.action)
        reward_batch = torch.cat(batch.reward)
        next_state_batch = torch.cat(batch.next_state)

        # Q(s, a) for the actions actually taken in the sampled transitions.
        state_action_values = self.dqn(state_batch).gather(1, action_batch)
        # Bootstrapped next-state value from the target network; this example
        # assumes every stored next_state is a valid tensor (no terminal masking).
        next_state_values = self.target(next_state_batch).max(1)[0].detach()

        expected_state_action_values = (next_state_values *
                                        self.gamma) + reward_batch

        loss = self.loss_fn(
            state_action_values,
            expected_state_action_values.unsqueeze(1),
        )

        self.optimizer.zero_grad()
        loss.backward()
        for param in self.dqn.parameters():
            param.grad.data.clamp_(-1, 1)
        self.optimizer.step()

        # Sync the target network with the online network every 10 updates.
        if self.target_usage == 10:
            self.target_usage = 0
            self.target.load_state_dict(self.dqn.state_dict())
            self.target.eval()
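
Example #1 (like the examples that follow) unpacks mini-batches with Transition(*zip(*transitions)) and draws them from a buffer that supports len() and .sample(n). Those definitions are not shown in this section; the following is a minimal sketch of what the code appears to assume (the class name ReplayMemory and the capacity are illustrative, not from the original source).

import random
from collections import deque, namedtuple

# Field names match how the examples unpack batches: batch.state, batch.action, ...
Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward'))


class ReplayMemory:
    """Fixed-size FIFO buffer of transitions (illustrative sketch)."""

    def __init__(self, capacity=10000):
        self.memory = deque([], maxlen=capacity)

    def push(self, *args):
        # Store one Transition(state, action, next_state, reward).
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        # Uniform random mini-batch, as used by self.buffer.sample(32) above.
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)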
Example #2
    def learn_from_buffer(self):
        if len(self.training_data) < self.batch_size:
            return

        try:
            loss_ensemble = 0.
            for i in range(0, 10):
                transitions = self.training_data.sample(self.batch_size)
                batch = Transition(*zip(*transitions))
                non_final_mask = torch.tensor(tuple(
                    map(lambda s: s is not None, batch.next_state)),
                                              device=device,
                                              dtype=torch.bool)
                non_final_next_states = torch.cat(
                    [s for s in batch.next_state if s is not None])

                state_batch = torch.cat(batch.state)
                action_batch = torch.cat(batch.action)
                reward_batch = torch.cat(batch.reward)
                state_action_values = self.model(state_batch).gather(
                    1, action_batch)

                next_state_values = torch.zeros(self.batch_size,
                                                device=device,
                                                dtype=torch.double)
                next_state_values[non_final_mask] = self.target_net(
                    non_final_next_states).max(1)[0].detach()

                expected_state_action_values = self.gamma * next_state_values + reward_batch

                loss = self.loss_fn(state_action_values,
                                    expected_state_action_values.unsqueeze(1))
                loss_ensemble += loss.item()

                self.optimizer.zero_grad()
                loss.backward()

                # (Optional) gradient clipping, left disabled in this example:
                # for param in self.model.parameters():
                #     param.grad.data.clamp_(-1, 1)
                self.optimizer.step()

            self.running_loss = 0.8 * self.running_loss + 0.2 * loss_ensemble
            self.epsilon = 0.999 * self.epsilon
        except RuntimeError:
            # torch.cat raises RuntimeError when every sampled next_state is
            # None (an all-terminal batch); skip this update instead of crashing.
            print('{}: no non-terminal state'.format(self.agent_name))
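
Example #2 decays self.epsilon after every successful update but does not show where epsilon is consumed; presumably it drives an epsilon-greedy behaviour policy. A minimal sketch of such an action-selection method follows. The method name select_action, the n_actions attribute, and the assumption that self.model maps a state tensor to a row of per-action Q-values are illustrative, not part of the original code.

    def select_action(self, state):
        # With probability epsilon take a uniformly random action,
        # otherwise act greedily with respect to the current Q-network.
        if random.random() < self.epsilon:
            return torch.tensor([[random.randrange(self.n_actions)]],
                                device=device, dtype=torch.long)
        with torch.no_grad():
            # max(1)[1] is the argmax action index; view(1, 1) matches the
            # shape gather(1, action_batch) expects during training.
            return self.model(state).max(1)[1].view(1, 1)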
Example #3
    def learn_from_buffer(self):
        for i in range(self.ensemble_size):
            if len(self.training_datas[i]) < self.batch_size:
                return

        loss_ensemble = 0.0
        for _ in range(10):
            for i in range(self.ensemble_size):
                transitions = self.training_datas[i].sample(self.batch_size)
                batch = Transition(*zip(*transitions))
                non_final_mask = torch.tensor(tuple(
                    map(lambda s: s is not None, batch.next_state)),
                                              device=device,
                                              dtype=torch.bool)

                state_batch = torch.cat(batch.state)
                action_batch = torch.cat(batch.action)
                reward_batch = torch.cat(batch.reward)
                # Select the Q-value of the action actually taken in each transition.
                state_action_values = self.models[i](state_batch).gather(
                    1, action_batch)

                # True when every sampled next_state is terminal (None).
                all_none = all(s is None for s in batch.next_state)

                next_state_values = torch.zeros(self.batch_size, device=device, dtype=torch.double)
                if not all_none:
                    non_final_next_states = torch.cat([s for s in batch.next_state
                                                                if s is not None])

                    next_state_values[non_final_mask] = self.target_nets[i](
                        non_final_next_states).max(1)[0].detach()

                expected_state_action_values = self.gamma * next_state_values + reward_batch

                loss = self.loss_fn(state_action_values, expected_state_action_values.unsqueeze(1))
                loss_ensemble += loss.item()

                self.optimizers[i].zero_grad()
                loss.backward()

                for param in self.models[i].parameters():
                    param.grad.data.clamp_(-1, 1)
                self.optimizers[i].step()

            self.running_loss = 0.8 * self.running_loss + 0.2 * loss_ensemble
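
Example #3 trains ensemble_size independent Q-networks on separate buffers, but does not show how their estimates are combined when the agent acts. One common choice, sketched here purely as an illustration (the method name and the averaging rule are assumptions, not part of the original code), is to act greedily on the mean Q-values of the ensemble:

    def select_action(self, state):
        with torch.no_grad():
            # Per-model Q-value rows stacked into shape (ensemble_size, n_actions),
            # assuming each model returns a (1, n_actions) tensor for one state.
            q_values = torch.stack([m(state).squeeze(0) for m in self.models])
            # Greedy action with respect to the ensemble mean.
            return q_values.mean(dim=0).argmax().view(1, 1)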
Example #4
    def _value_func_estimate(self):
        if len(self.buffer) < 32:
            return

        self.target_usage += 1
        transitions = self.buffer.sample(32)
        batch = Transition(*zip(*transitions))

        state_batch = torch.cat(batch.state)
        reward_batch = torch.cat(batch.reward)
        reward_l_batch = []
        reward_u_batch = []
        # Query the GP reward model for each state: mean - beta*std gives a
        # pessimistic reward target, mean + beta*std an optimistic one.
        for state in state_batch:
            cur_state = state.tolist()
            reward, std = self.reward_gp.predict(np.array([cur_state]), True)
            reward_l_batch.append(
                torch.tensor([reward[0] - self.beta * std[0]],
                             dtype=torch.double))
            reward_u_batch.append(
                torch.tensor([reward[0] + self.beta * std[0]],
                             dtype=torch.double))

        reward_l_batch = torch.cat(reward_l_batch)
        reward_u_batch = torch.cat(reward_u_batch)
        action_batch = torch.cat(batch.action)
        next_state_batch = torch.cat(batch.next_state)

        state_action_values = self.dqn(state_batch).gather(1, action_batch)
        next_state_values = self.target(next_state_batch).max(1)[0].detach()

        expected_state_action_values = (next_state_values *
                                        self.gamma) + reward_batch

        loss = self.loss_fn(
            state_action_values,
            expected_state_action_values.unsqueeze(1),
        )

        self.optimizer.zero_grad()
        loss.backward()
        for param in self.dqn.parameters():
            param.grad.data.clamp_(-1, 1)
        self.optimizer.step()

        # Upper-bound network: the same TD update, but against optimistic rewards.
        state_action_values = self.dqn_u(state_batch).gather(1, action_batch)
        next_state_values = self.target_u(next_state_batch).max(1)[0].detach()

        expected_state_action_values = (next_state_values *
                                        self.gamma) + reward_u_batch

        loss = self.loss_fn(
            state_action_values,
            expected_state_action_values.unsqueeze(1),
        )

        self.optimizer_u.zero_grad()
        loss.backward()
        for param in self.dqn_u.parameters():
            param.grad.data.clamp_(-1, 1)
        self.optimizer_u.step()

        # Lower-bound network: the same TD update, but against pessimistic rewards.
        state_action_values = self.dqn_l(state_batch).gather(1, action_batch)
        next_state_values = self.target_l(next_state_batch).max(1)[0].detach()

        expected_state_action_values = (next_state_values *
                                        self.gamma) + reward_l_batch

        loss = self.loss_fn(
            state_action_values,
            expected_state_action_values.unsqueeze(1),
        )

        self.optimizer_l.zero_grad()
        loss.backward()
        for param in self.dqn_l.parameters():
            param.grad.data.clamp_(-1, 1)
        self.optimizer_l.step()

        if self.target_usage == 10:
            self.target_usage = 0
            self.target.load_state_dict(self.dqn.state_dict())
            self.target.eval()
            self.target_u.load_state_dict(self.dqn_u.state_dict())
            self.target_u.eval()
            self.target_l.load_state_dict(self.dqn_l.state_dict())
            self.target_l.eval()
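
Example #4 builds optimistic and pessimistic reward targets from a Gaussian-process reward model, reward +/- beta * std. The per-state Python loop above can be replaced by a single batched query, assuming self.reward_gp follows the scikit-learn GaussianProcessRegressor.predict(X, return_std=True) interface that the loop already relies on; a sketch of the vectorised version:

        # Batched GP query: one predict() call for the whole mini-batch.
        states_np = state_batch.detach().cpu().numpy()
        mean, std = self.reward_gp.predict(states_np, return_std=True)
        mean = torch.as_tensor(mean, dtype=torch.double)
        std = torch.as_tensor(std, dtype=torch.double)
        # Lower/upper confidence-bound reward targets, as in the loop above.
        reward_l_batch = mean - self.beta * std
        reward_u_batch = mean + self.beta * std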
Example #5
    def optimize_model(self):
        """
        optimize model for 1 time step by sampling transitions, calculating TD errors, loss
        and updating the memory transitions' probabilities
        """
        if len(self.memory) < BATCH_SIZE:
            return
        transitions = self.memory.sample(BATCH_SIZE)

        weights = [t[1] for t in transitions]
        positions = [t[2] for t in transitions]
        transitions = [t[0] for t in transitions]
        max_weight = self.memory.max_weight

        batch = Transition(*zip(*transitions))

        # Compute a mask of non-final states and concatenate the batch elements
        # (a final state would've been the one after which simulation ended)
        non_final_mask = torch.tensor(tuple(
            map(lambda s: s is not None, batch.next_state)),
                                      device=device,
                                      dtype=torch.bool)
        non_final_next_states = torch.cat(
            [s for s in batch.next_state if s is not None])
        state_batch = torch.cat(batch.state)
        action_batch = torch.cat(batch.action)
        reward_batch = torch.cat(batch.reward)

        state_action_values = self.policy_net(state_batch).gather(
            1, action_batch)

        next_state_values = torch.zeros(BATCH_SIZE, device=device)
        # DDQN: separate action selection and evaluation
        with torch.no_grad():
            new_actions = self.policy_net(non_final_next_states).max(
                1)[1].view(-1, 1)
        next_state_values[non_final_mask] = self.target_net(
            non_final_next_states).gather(1, new_actions).detach().view(-1)

        # Compute the expected Q values
        expected_state_action_values = (next_state_values *
                                        GAMMA) + reward_batch
        # calculate TD error

        TD_error = (expected_state_action_values.unsqueeze(1) -
                    state_action_values).detach().view(-1).cpu().numpy()
        TD_error = np.abs(TD_error)

        # Importance-sampling weights, normalised by the largest weight in memory.
        IS_weights = np.divide(weights, max_weight)
        # Fold the IS weights into the loss by scaling both prediction and target:
        # sqrt(w) where the Huber loss is quadratic (|TD error| <= 1),
        # w where it is linear.
        square_indices = np.where(TD_error <= 1)[0]
        for i, w in enumerate(IS_weights):
            if i in square_indices:
                factor = torch.tensor(np.sqrt(w),
                                      device=device,
                                      dtype=torch.float32)
            else:
                factor = torch.tensor(w, device=device, dtype=torch.float32)

            state_action_values[i] *= factor
            expected_state_action_values[i] *= factor

        # Compute Huber loss
        loss = F.smooth_l1_loss(state_action_values,
                                expected_state_action_values.unsqueeze(1))

        # Optimize the model
        self.optimizer.zero_grad()
        loss.backward()

        # update priorities of the replayed transitions
        self.memory.update(positions, TD_error)

        for param in self.policy_net.parameters():
            param.grad.data.clamp_(-1, 1)
        self.optimizer.step()
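
Example #5 assumes a prioritized replay memory whose sample() returns (transition, weight, position) triples, which exposes a max_weight attribute, and whose update(positions, td_errors) refreshes priorities after each optimisation step. That class is not shown here; the following is a minimal proportional-prioritisation sketch consistent with that interface (the alpha exponent, the list-based storage, and the simplified importance weights are illustrative, not the original implementation):

import random

class PrioritizedReplayMemory:
    """Toy proportional prioritized replay (illustrative, not the original)."""

    def __init__(self, capacity=10000, alpha=0.6):
        self.capacity = capacity
        self.alpha = alpha
        self.memory = []
        self.priorities = []
        self.max_weight = 1.0

    def push(self, transition):
        # New transitions get the current maximum priority so they are replayed soon.
        max_priority = max(self.priorities, default=1.0)
        if len(self.memory) >= self.capacity:
            self.memory.pop(0)
            self.priorities.pop(0)
        self.memory.append(transition)
        self.priorities.append(max_priority)

    def sample(self, batch_size):
        # Sample positions proportionally to priority ** alpha.
        scaled = [p ** self.alpha for p in self.priorities]
        total = sum(scaled)
        probs = [s / total for s in scaled]
        positions = random.choices(range(len(self.memory)), weights=probs, k=batch_size)
        weights = [1.0 / (len(self.memory) * probs[i]) for i in positions]
        self.max_weight = max(self.max_weight, *weights)
        # Matches how optimize_model() unpacks the result: (transition, weight, position).
        return [(self.memory[i], weights[i], i) for i in positions]

    def update(self, positions, td_errors):
        # Refresh priorities with the new absolute TD errors (plus a small epsilon).
        for pos, err in zip(positions, td_errors):
            self.priorities[pos] = float(err) + 1e-6

    def __len__(self):
        return len(self.memory)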