Example #1
    def _train(self, dataset):
        N = self._knobs.get('batch_size')
        ep = self._knobs.get('epochs')
        null_tag = self._tag_count  # Tag to ignore (from padding of sentences during batching)
        B = math.ceil(len(dataset) / N)  # No. of batches

        # Define 2 plots: Loss against time, loss against epochs
        logger.define_loss_plot()
        logger.define_plot('Loss Over Time', ['loss'])

        (net, optimizer) = self._create_model()

        Tensor = torch.LongTensor
        if torch.cuda.is_available():
            logger.log('Using CUDA...')
            net = net.cuda()
            Tensor = torch.cuda.LongTensor

        loss_func = nn.CrossEntropyLoss(ignore_index=null_tag)

        for epoch in range(ep):
            total_loss = 0
            for i in range(B):
                # Extract batch from dataset
                (words_tsr,
                 tags_tsr) = self._prepare_batch(dataset, i * N, i * N + N,
                                                 Tensor)

                # Reset gradients for this batch
                optimizer.zero_grad()

                # Forward propagate batch through model
                probs_tsr = net(words_tsr)

                # Compute mean per-word loss over all words & sentences (padding tags are ignored)
                NW = probs_tsr.size(0) * probs_tsr.size(1)
                loss = loss_func(probs_tsr.view(NW, -1), tags_tsr.view(-1))

                # Backward propagate on minibatch
                loss.backward()

                # Update gradients with optimizer
                optimizer.step()

                total_loss += loss.item()

            logger.log_loss(loss=(total_loss / B), epoch=epoch)

        return (net, optimizer)
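
The helper `_prepare_batch` referenced above is not shown. Below is a minimal, hypothetical sketch of what such a helper might look like, assuming `dataset` is a list of `(word_indices, tag_indices)` pairs and that sentences in a batch are padded to the longest length, with padded positions tagged with the null tag so that `CrossEntropyLoss(ignore_index=null_tag)` above skips them. The function name, the `pad_word` index, and the toy data are illustrative assumptions, not part of the original model.

import torch

# Hypothetical stand-in for the batching helper used in Example #1.
# Assumption: `dataset` is a list of (word_indices, tag_indices) pairs.
def prepare_batch(dataset, start, end, Tensor, null_tag, pad_word=0):
    batch = dataset[start:end]
    max_len = max(len(words) for (words, _) in batch)
    # Pad every sentence (and its tag sequence) to the batch's longest length
    words_rows = [words + [pad_word] * (max_len - len(words)) for (words, _) in batch]
    tags_rows = [tags + [null_tag] * (max_len - len(tags)) for (_, tags) in batch]
    # Both tensors have shape (batch_size, max_len)
    return (Tensor(words_rows), Tensor(tags_rows))

# Toy usage: two sentences of different lengths, null_tag = 2
toy = [([1, 2, 3], [0, 1, 0]), ([4, 5], [1, 1])]
words_tsr, tags_tsr = prepare_batch(toy, 0, 2, torch.LongTensor, null_tag=2)
print(words_tsr.shape, tags_tsr.shape)  # torch.Size([2, 3]) torch.Size([2, 3])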
Example #2
    def _on_train_epoch_end(self, epoch, logs):
        # Log this epoch's training loss so it appears on the loss plot
        loss = logs['loss']
        logger.log_loss(loss, epoch)
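
Example #2 has the shape of a Keras-style epoch-end hook. The following is a minimal, self-contained sketch (an assumption, not the original wiring) of how such a hook could be driven by `tf.keras.callbacks.LambdaCallback`; the print-based `logger` stand-in, the tiny model, and the random data are illustrative only.

import numpy as np
import tensorflow as tf

# Stand-in for the training logger used in the examples (assumption)
class _PrintLogger:
    def log_loss(self, loss, epoch):
        print('epoch={} loss={:.4f}'.format(epoch, loss))

logger = _PrintLogger()

def _on_train_epoch_end(epoch, logs):
    # Same body as Example #2, written as a free function for this standalone sketch
    loss = logs['loss']
    logger.log_loss(loss, epoch)

# Tiny regression model on random data, just to drive the callback
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer='sgd', loss='mse')
x = np.random.rand(32, 4).astype('float32')
y = np.random.rand(32, 1).astype('float32')

model.fit(
    x, y, epochs=3, verbose=0,
    callbacks=[tf.keras.callbacks.LambdaCallback(on_epoch_end=_on_train_epoch_end)],
)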