예제 #1
0
    def train_epoch(self, epoch):
        """Run one training pass over ``self.train_set.ids``.

        Every step draws a random sequence length around ``config.bptt``,
        temporarily rescales the learning rate of the first optimizer
        parameter group by ``seq_len / config.bptt``, and takes one optimizer
        step on the criterion loss plus the optional regularisation terms
        weighted by ``config.alpha`` and ``config.beta``. A progress line is
        printed every ``config.log_interval`` batches.

        Args:
            epoch: Epoch number; only used in the progress line.
        """
        if config.model == 'QRNN':
            self.model.reset()
        total_loss = 0
        start_time = time.time()
        hidden = self.model.init_hidden(config.batch_size)
        batch, i = 0, 0
        train_data = self.train_set.ids
        while i < train_data.size(0) - 1 - 1:
            # Aim for the full window 95% of the time, half of it otherwise.
            bptt = config.bptt if np.random.random() < 0.95 else config.bptt / 2.
            # Prevent excessively small or negative sequence lengths
            seq_len = max(5, int(np.random.normal(bptt, 5)))
            # NOTE: no upper cap is applied to seq_len here; there is a small
            # chance of drawing a very long sequence, which could result in OOM.

            # Scale the learning rate with the sampled length for this step
            # only. The restore lives in `finally` so that an exception or a
            # KeyboardInterrupt mid-step cannot leave the optimizer with the
            # scaled rate.
            group = self.optimizer.param_groups[0]
            lr2 = group['lr']
            group['lr'] = lr2 * seq_len / config.bptt
            try:
                # Turn on training mode which enables dropout.
                self.model.train()
                data, targets = get_batch(train_data, i, config, seq_len=seq_len)

                # Starting each batch, we detach the hidden state from how it was previously produced.
                # If we didn't, the model would try backpropagating all the way to start of the dataset.
                hidden = repackage_hidden(hidden)
                self.optimizer.zero_grad()

                output, hidden, rnn_hs, dropped_rnn_hs = self.model(data, hidden, return_h=True)
                raw_loss = self.criterion(self.model.decoder.weight, self.model.decoder.bias, output, targets)

                loss = raw_loss
                # Activation Regularization: penalise the mean squared value
                # of the last entry of dropped_rnn_hs.
                if config.alpha:
                    loss = loss + sum(
                        config.alpha * dropped_rnn_h.pow(2).mean() for dropped_rnn_h in dropped_rnn_hs[-1:])
                # Temporal Activation Regularization (slowness): penalise the
                # difference between consecutive slices along the first axis
                # of the last entry of rnn_hs (presumably time steps).
                if config.beta:
                    loss = loss + sum(
                        config.beta * (rnn_h[1:] - rnn_h[:-1]).pow(2).mean() for rnn_h in rnn_hs[-1:])
                loss.backward()

                # `clip_grad_norm_` helps prevent the exploding gradient problem in RNNs / LSTMs.
                if config.clip:
                    torch.nn.utils.clip_grad_norm_(self.params, config.clip)
                self.optimizer.step()
            finally:
                group['lr'] = lr2

            # Only the un-regularised loss is tracked for reporting.
            total_loss += raw_loss.data
            if batch % config.log_interval == 0 and batch > 0:
                cur_loss = total_loss.item() / config.log_interval
                elapsed = time.time() - start_time
                print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:05.5f} | ms/batch {:5.2f} | '
                      'loss {:5.2f} | ppl {:8.2f} | bpc {:8.3f}'.format(
                    epoch, batch, len(train_data) // config.bptt, self.optimizer.param_groups[0]['lr'],
                                  elapsed * 1000 / config.log_interval, cur_loss, math.exp(cur_loss),
                                  cur_loss / math.log(2)))
                total_loss = 0
                start_time = time.time()
            ###
            batch += 1
            i += seq_len
예제 #2
0
def train():
    """Run one training pass of the module-level ``model`` over ``train_data``.

    Every step draws a random sequence length around ``args.bptt`` (capped at
    ``args.bptt + 10``), temporarily rescales the learning rate of the first
    optimizer parameter group by ``seq_len / args.bptt``, and takes one
    optimizer step on the criterion loss plus the regularisation terms
    weighted by ``args.alpha`` and ``args.beta``. A progress line is printed
    every ``args.log_interval`` batches; it reads the module-level ``epoch``.
    """
    if args.model == 'QRNN':
        model.reset()
    total_loss = 0
    start_time = time.time()
    ntokens = len(corpus.dictionary)
    hidden = model.init_hidden(args.batch_size)
    batch, i = 0, 0
    while i < train_data.size(0) - 1 - 1:
        # Aim for the full window 95% of the time, half of it otherwise.
        bptt = args.bptt if np.random.random() < 0.95 else args.bptt / 2.
        # Prevent excessively small or negative sequence lengths
        seq_len = max(5, int(np.random.normal(bptt, 5)))
        # There's a very small chance that it could select a very long sequence length resulting in OOM
        seq_len = min(seq_len, args.bptt + 10)

        # Scale the learning rate with the sampled length for this step only.
        # The restore lives in `finally` so that an exception or a
        # KeyboardInterrupt mid-step cannot leave the optimizer with the
        # scaled rate.
        group = optimizer.param_groups[0]
        lr2 = group['lr']
        group['lr'] = lr2 * seq_len / args.bptt
        try:
            # Turn on training mode which enables dropout.
            model.train()
            data, targets = get_batch(train_data, i, args, seq_len=seq_len)

            # Starting each batch, we detach the hidden state from how it was previously produced.
            # If we didn't, the model would try backpropagating all the way to start of the dataset.
            hidden = repackage_hidden(hidden)
            optimizer.zero_grad()

            output, hidden, rnn_hs, dropped_rnn_hs = model(data, hidden, return_h=True)
            raw_loss = criterion(output.view(-1, ntokens), targets)

            loss = raw_loss
            # Activation Regularization
            loss = loss + sum(args.alpha * dropped_rnn_h.pow(2).mean() for dropped_rnn_h in dropped_rnn_hs[-1:])
            # Temporal Activation Regularization (slowness)
            loss = loss + sum(args.beta * (rnn_h[1:] - rnn_h[:-1]).pow(2).mean() for rnn_h in rnn_hs[-1:])
            loss.backward()

            # `clip_grad_norm_` helps prevent the exploding gradient problem in RNNs / LSTMs.
            # (The non-underscore `clip_grad_norm` is deprecated.)
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
            optimizer.step()
        finally:
            group['lr'] = lr2

        # Only the un-regularised loss is tracked for reporting.
        total_loss += raw_loss.data
        if batch % args.log_interval == 0 and batch > 0:
            # `.item()` works for both 0-dim and one-element tensors, whereas
            # indexing with [0] raises on a 0-dim tensor.
            cur_loss = total_loss.item() / args.log_interval
            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.2f} | ms/batch {:5.2f} | '
                    'loss {:5.2f} | ppl {:8.2f}'.format(
                epoch, batch, len(train_data) // args.bptt, optimizer.param_groups[0]['lr'],
                elapsed * 1000 / args.log_interval, cur_loss, math.exp(cur_loss)))
            total_loss = 0
            start_time = time.time()
        ###
        batch += 1
        i += seq_len
예제 #3
0
 def evaluate(self, data_source, batch_size=config.batch_size):
     """Return the average criterion loss of ``self.model`` over ``data_source``.

     The data is walked in steps of ``config.bptt``; each batch's loss is
     weighted by ``len(data)`` and the sum is divided by ``len(data_source)``.

     Args:
         data_source: Tensor of evaluation data; only ``size(0)`` and ``len``
             are used directly here, the rest is delegated to ``get_batch``.
         batch_size: Batch size passed to ``self.model.init_hidden``.

     Returns:
         The weighted average loss as a Python float (``0.0`` when
         ``data_source`` has a single row and therefore yields no batch).
     """
     # Turn on evaluation mode which disables dropout.
     self.model.eval()
     if config.model == 'QRNN':
         self.model.reset()
     # Accumulate as a plain float so the return works even when the loop
     # below runs zero times.
     total_loss = 0.0
     hidden = self.model.init_hidden(batch_size)
     decoder = self.model.decoder
     # No gradients are needed for evaluation; skipping graph construction
     # saves memory and time without changing the computed values.
     with torch.no_grad():
         for i in range(0, data_source.size(0) - 1, config.bptt):
             data, targets = get_batch(data_source, i, config, evaluation=True)
             output, hidden = self.model(data, hidden)
             batch_loss = self.criterion(decoder.weight, decoder.bias, output, targets)
             total_loss += len(data) * batch_loss.item()
             hidden = repackage_hidden(hidden)
     return total_loss / len(data_source)
예제 #4
0
def evaluate(data_source, batch_size=10):
    """Return the average criterion loss of the module-level ``model``.

    The data is walked in steps of ``args.bptt``; each batch's loss is
    weighted by ``len(data)`` and the sum is divided by ``len(data_source)``.

    Args:
        data_source: Tensor of evaluation data; only ``size(0)`` and ``len``
            are used directly here, the rest is delegated to ``get_batch``.
        batch_size: Batch size passed to ``model.init_hidden``.

    Returns:
        The weighted average loss as a Python float (``0.0`` when
        ``data_source`` has a single row and therefore yields no batch).
    """
    if args.model == 'QRNN':
        model.reset()
    # Turn on evaluation mode which disables dropout.
    model.eval()
    # Accumulate as a plain float: indexing a 0-dim loss tensor with [0]
    # raises, and an int accumulator is not subscriptable when the loop
    # below runs zero times.
    total_loss = 0.0
    ntokens = len(corpus.dictionary)
    hidden = model.init_hidden(batch_size)
    # No gradients are needed for evaluation; skipping graph construction
    # saves memory and time without changing the computed values.
    with torch.no_grad():
        for i in range(0, data_source.size(0) - 1, args.bptt):
            data, targets = get_batch(data_source, i, args, evaluation=True)
            output, hidden = model(data, hidden)
            output_flat = output.view(-1, ntokens)
            total_loss += len(data) * criterion(output_flat, targets).item()
            hidden = repackage_hidden(hidden)
    return total_loss / len(data_source)
예제 #5
0
def evaluate(data_source, batch_size=10, window=args.window):
    """Evaluate the module-level ``model`` with a pointer mixed into its softmax.

    For every prediction whose absolute position in the running history is
    greater than ``window``, the model's softmax is interpolated (weight
    ``args.lambdasm``) with a distribution built from the previous ``window``
    targets, each weighted by a softmax over ``args.theta`` times the dot
    product between its stored RNN output and the current RNN output.
    Earlier positions use the plain softmax.

    Written against the legacy ``Variable`` / ``.data[0]`` tensor API.

    Args:
        data_source: Tensor of evaluation data, consumed via ``get_batch``.
        batch_size: Batch size for ``model.init_hidden``; also divides each
            batch's summed loss.
        window: Number of past positions kept for the pointer.

    Returns:
        The accumulated loss divided by ``len(data_source)``.
    """
    if args.model == 'QRNN':
        model.reset()
    # Turn on evaluation mode which disables dropout.
    model.eval()
    total_loss = 0
    ntokens = len(corpus.dictionary)
    hidden = model.init_hidden(batch_size)
    next_word_history = None
    pointer_history = None
    for i in range(0, data_source.size(0) - 1, args.bptt):
        # Running progress report (skipped for the first batch to avoid
        # dividing by zero).
        if i > 0:
            print(i, len(data_source), math.exp(total_loss / i))
        data, targets = get_batch(data_source, i, evaluation=True, args=args)
        output, hidden, rnn_outs, _ = model(data, hidden, return_h=True)
        rnn_out = rnn_outs[-1].squeeze()
        output_flat = output.view(-1, ntokens)

        # Extend the pointer history with this batch.
        # NOTE(review): one_hot is defined elsewhere; presumably it returns
        # one (1, ntokens) row per target -- confirm.
        batch_words = torch.cat([one_hot(t.data[0], ntokens) for t in targets])
        batch_states = Variable(rnn_out.data)
        if next_word_history is None:
            start_idx = 0
            next_word_history = batch_words
        else:
            start_idx = len(next_word_history)
            next_word_history = torch.cat([next_word_history, batch_words])
        if pointer_history is None:
            pointer_history = batch_states
        else:
            pointer_history = torch.cat([pointer_history, batch_states], dim=0)

        # Pointer-augmented cross entropy, computed by hand one prediction
        # at a time.
        loss = 0
        softmax_output_flat = torch.nn.functional.softmax(output_flat)
        for idx, vocab_loss in enumerate(softmax_output_flat):
            position = start_idx + idx
            if position > window:
                first = position - window
                valid_next_word = next_word_history[first:position]
                valid_pointer_history = pointer_history[first:position]
                # Similarity of the current RNN output to each stored one.
                logits = torch.mv(valid_pointer_history, rnn_out[idx])
                ptr_attn = torch.nn.functional.softmax(
                    args.theta * logits).view(-1, 1)
                # Attention-weighted sum of the stored target rows.
                weighted_words = ptr_attn.expand_as(valid_next_word) * valid_next_word
                ptr_dist = weighted_words.sum(0).squeeze()
                lambdah = args.lambdasm
                p = lambdah * ptr_dist + (1 - lambdah) * vocab_loss
            else:
                # Not enough history yet: fall back to the plain softmax.
                p = vocab_loss
            target_loss = p[targets[idx].data]
            loss += (-torch.log(target_loss)).data[0]
        total_loss += loss / batch_size

        hidden = repackage_hidden(hidden)
        # Keep only the most recent `window` entries for the next batch.
        next_word_history = next_word_history[-window:]
        pointer_history = pointer_history[-window:]
    return total_loss / len(data_source)