Example #1
    def train_on_batch(self, inputs, outputs):
        """
        Train the model on the given batch of data.
        :param inputs: array
            Values for the input features
        :param outputs: array
            Label arrays, in order: grapheme_root, vowel_diacritic, consonant_diacritic
        :return: float
            Batch loss
        """
        inputs = Tensor(inputs)
        grapheme_root = Tensor(outputs[0])
        vowel_diacritic = Tensor(outputs[1])
        consonant_diacritic = Tensor(outputs[2])

        # Move the batch to the same device as the model.
        inputs = inputs.to(self._device)
        grapheme_root = grapheme_root.to(self._device)
        vowel_diacritic = vowel_diacritic.to(self._device)
        consonant_diacritic = consonant_diacritic.to(self._device)

        # Forward pass: the model returns one prediction per output head.
        self._optimizer.zero_grad()
        grapheme_root_pred, vowel_diacritic_pred, consonant_diacritic_pred = self(inputs)

        grapheme_root_loss = self._criterion(grapheme_root_pred, grapheme_root)
        vowel_diacritic_loss = self._criterion(vowel_diacritic_pred, vowel_diacritic)
        consonant_diacritic_loss = self._criterion(consonant_diacritic_pred, consonant_diacritic)

        # The batch loss is the sum of the three head losses.
        loss = grapheme_root_loss + vowel_diacritic_loss + consonant_diacritic_loss

        loss.backward()
        self._optimizer.step()

        return loss.item()
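
A minimal sketch of how a method like train_on_batch might be driven from an outer training loop. The run_epoch helper, the model object, and the batch layout (images plus a tuple of three label arrays) are assumptions for illustration, not part of the example above.

import numpy as np

def run_epoch(model, batches):
    # `batches` is assumed to yield (images, (roots, vowels, consonants)) pairs,
    # matching the (inputs, outputs) signature of train_on_batch above.
    losses = []
    for images, labels in batches:
        losses.append(model.train_on_batch(np.asarray(images), labels))
    return float(np.mean(losses))
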
Example #2
    def calc_loss(self, q_values: Tensor, target_q_values: Tensor,
                  actions: Tensor, rewards: Tensor,
                  done_mask: Tensor) -> Tensor:
        """
        Calculate the MSE loss of this step.
        The loss for an example is defined as:
            Q_samp(s) = r if done
                        = r + gamma * max_a' Q_target(s', a') otherwise
            loss = (Q_samp(s) - Q(s, a))^2

        Args:
            q_values: (torch tensor) shape = (batch_size, num_actions)
                The Q-values that your current network estimates (i.e. Q(s, a') for all a')
            target_q_values: (torch tensor) shape = (batch_size, num_actions)
                The Target Q-values that your target network estimates (i.e. Q_target(s', a') for all a')
            actions: (torch tensor) shape = (batch_size,)
                The actions that you actually took at each step (i.e. a)
            rewards: (torch tensor) shape = (batch_size,)
                The rewards that you actually got at each step (i.e. r)
            done_mask: (torch tensor) shape = (batch_size,)
                A boolean mask of examples where we reached the terminal state

        Hint:
            You may find the following functions useful
                - torch.max
                - torch.sum
                - torch.nn.functional.one_hot
                - torch.nn.functional.mse_loss
        """
        # you may need this variable
        num_actions = self.env.action_space.n
        gamma = self.config.gamma

        ##############################################################
        ##################### YOUR CODE HERE - 3-5 lines #############
        # Convert the boolean done mask to 0/1 so terminal transitions drop the bootstrap term.
        notdone = 1 - done_mask.to(torch.int64)
        # Select Q(s, a) for the action actually taken: one-hot mask, then sum over actions.
        # (torch.sum rather than torch.max, so negative Q-values are handled correctly.)
        current_q = torch.sum(
            q_values *
            torch.nn.functional.one_hot(actions.to(torch.int64), num_actions),
            dim=1)
        # Q_samp(s) = r + gamma * max_a' Q_target(s', a'), with the second term zeroed when done.
        target_q = rewards + notdone * gamma * torch.max(target_q_values,
                                                         dim=1).values

        loss = torch.nn.functional.mse_loss(current_q, target_q)
        return loss
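
For reference, a self-contained sketch of the same target computation on a toy batch; the tensors and gamma below are illustrative values only, not taken from the example.

import torch
import torch.nn.functional as F

# Toy batch: 2 transitions, 3 actions.
q_values        = torch.tensor([[1.0, 2.0, 0.5], [0.2, 0.1, 0.4]])   # Q(s, .)
target_q_values = torch.tensor([[0.9, 1.5, 0.3], [0.0, 0.2, 0.6]])   # Q_target(s', .)
actions   = torch.tensor([1, 2])
rewards   = torch.tensor([1.0, 0.0])
done_mask = torch.tensor([False, True])
gamma = 0.9

q_sa   = torch.sum(q_values * F.one_hot(actions, 3), dim=1)          # Q(s, a) for the taken actions
q_samp = rewards + (~done_mask).float() * gamma * target_q_values.max(dim=1).values
loss   = F.mse_loss(q_sa, q_samp)                                    # scalar MSE loss
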
Example #3
    def predict(self, inputs):
        """
        Predict with the model on the given input feature values.
        :param inputs: array
            Input feature values
        :return: (Tensor, Tensor, Tensor)
            Indices for grapheme_root, vowel_diacritic, consonant_diacritic
        """
        inputs = Tensor(inputs)
        inputs = inputs.to(self._device)
        grapheme_root_hat, vowel_diacritic_hat, consonant_diacritic_hat = self(inputs)

        # Take the argmax over the class dimension of each head to get the predicted indices.
        _, grapheme_root_indices = grapheme_root_hat.max(1)
        _, vowel_diacritic_indices = vowel_diacritic_hat.max(1)
        _, consonant_diacritic_indices = consonant_diacritic_hat.max(1)

        return grapheme_root_indices, vowel_diacritic_indices, consonant_diacritic_indices
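
A hedged usage sketch for predict; the model instance and the input shape below are assumptions for illustration only.

import numpy as np

# Hypothetical batch of four single-channel 64x64 images; adjust the shape to
# whatever the model was actually trained on.
batch = np.random.rand(4, 1, 64, 64).astype(np.float32)
root_idx, vowel_idx, cons_idx = model.predict(batch)
print(root_idx.cpu().numpy(), vowel_idx.cpu().numpy(), cons_idx.cpu().numpy())
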
Example #4
    def calc_loss(self, q_values: Tensor, target_q_values: Tensor,
                  actions: Tensor, rewards: Tensor, done_mask: Tensor,
                  state: Tensor, next_state: Tensor) -> Tensor:
        """
        Calculate the MSE loss of this step.
        The loss for an example is defined as:
            Q_samp(s) = r if done
                        = r + gamma * max_a' Q_target(s', a') otherwise
            loss = (Q_samp(s) - Q(s, a))^2

        Args:
            q_values: (torch tensor) shape = (batch_size, num_actions)
                The Q-values that your current network estimates (i.e. Q(s, a') for all a')
            target_q_values: (torch tensor) shape = (batch_size, num_actions)
                The Target Q-values that your target network estimates (i.e. Q_target(s', a') for all a')
            actions: (torch tensor) shape = (batch_size,)
                The actions that you actually took at each step (i.e. a)
            rewards: (torch tensor) shape = (batch_size,)
                The rewards that you actually got at each step (i.e. r)
            done_mask: (torch tensor) shape = (batch_size,)
                A boolean mask of examples where we reached the terminal state

        Hint:
            You may find the following functions useful
                - torch.max
                - torch.sum
                - torch.nn.functional.one_hot
                - torch.nn.functional.mse_loss
            You can treat `done_mask` as 0/1 values (0 = not done, 1 = done) by casting it
            with Tensor.type, as done below.

            To extract Q(s, a) for a specific action a, you can combine torch.sum and
            torch.nn.functional.one_hot. Think about how.
        """
        # you may need this variable
        num_actions = self.env.action_space.n
        gamma = self.config.gamma
        done_mask = done_mask.type(torch.int)
        actions = actions.type(torch.int64)
        ##############################################################
        ##################### YOUR CODE HERE - 3-5 lines #############
        # Vanilla DQN loss, kept for reference; the code below uses the Double-DQN (DDQN) loss instead.
        # best_target_q = torch.max(target_q_values, dim=1).values
        # Q_samp = rewards + (1 - done_mask) * gamma * best_target_q
        # Q_sa = torch.sum(q_values * torch.nn.functional.one_hot(actions, num_actions), dim=1)
        # loss = torch.nn.functional.mse_loss(Q_samp, Q_sa)
        # Move everything to one device; fall back to CPU when CUDA is unavailable.
        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        state = state.to(device)
        next_state = next_state.to(device)
        actions = actions.to(device)
        rewards = rewards.to(device)
        done_mask = done_mask.to(device)
        actions = actions.unsqueeze(-1)
        # Q(s, a) from the online network for the actions actually taken.
        state_action_vals = self.get_q_values(state,
                                              'q_network').gather(1, actions)
        state_action_vals = state_action_vals.squeeze(-1)
        # Double DQN: choose the best next action with the online network...
        next_state_action = self.get_q_values(next_state,
                                              'q_network').max(1)[1]
        next_state_action = next_state_action.unsqueeze(-1)
        # ...but evaluate that action with the target network.
        next_state_vals = self.get_q_values(next_state, 'target').gather(
            1, next_state_action).squeeze(-1)

        # Q_samp = r + gamma * Q_target(s', a*), with the bootstrap term zeroed on terminal states.
        exp_sa_vals = next_state_vals.detach() * gamma * (1 - done_mask) + rewards
        loss = torch.nn.functional.mse_loss(state_action_vals, exp_sa_vals)

        ##############################################################
        ######################## END YOUR CODE #######################
        return loss
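
The uncommented code above is a Double-DQN style loss: the next action is chosen by the online network but evaluated by the target network. A standalone sketch of just that selection step, with toy tensors standing in for the two networks' outputs:

import torch

# Illustrative Q-values for a batch of 2 next-states and 3 actions.
online_next_q = torch.tensor([[0.1, 0.7, 0.2], [0.5, 0.4, 0.9]])   # online network on s'
target_next_q = torch.tensor([[1.0, 0.3, 0.6], [0.2, 0.8, 0.4]])   # target network on s'

best_actions = online_next_q.argmax(dim=1, keepdim=True)            # a* chosen by the online net
ddqn_values  = target_next_q.gather(1, best_actions).squeeze(-1)    # Q_target(s', a*)
# Vanilla DQN would instead use target_next_q.max(dim=1).values here.
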