Example #1
 def forward_module(self, i, hidden, data):
     hidden = repackage_hidden(hidden)
     output, hidden = self.rnns[i].forward(data, hidden)
     if self.module_normalization:
         output = F.log_softmax(output, dim=2)
     self.outputs[i, :, :, :] = output
     return output, hidden
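
None of the examples define repackage_hidden itself. A minimal sketch of the usual helper, following the recursive-detach pattern from PyTorch's word_language_model example (treat this as an assumption about how the examples' helper behaves):

import torch

def repackage_hidden(h):
    # Wrap hidden states in new tensors, detaching them from their history
    # so that backpropagation stops at the current BPTT window.
    if isinstance(h, torch.Tensor):
        return h.detach()
    return tuple(repackage_hidden(v) for v in h)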
Example #2
    def _update(engine, batch):
        model.train()

        _batch = batch.to(device)
        x, y = _batch[:-1], _batch[1:]
        seq_length, batch_size = batch.size()

        if model_type == 'LSTMTransformer':
            hidden = model.init_hidden(batch_size)
            mems = None
        elif model_type == 'Transformer':
            pass
        else:
            hidden = model.init_hidden(batch_size)

        total_loss = 0
        for i in range(0, seq_length - 1, bptt):
            optimizer.zero_grad()
            _x, _y = x[i:i + bptt], y[i:i + bptt]
            if model_type == 'LSTMTransformer':
                hidden = repackage_hidden(hidden)
                mems = repackage_hidden(mems) if mems else mems
                output, hidden, mems = model(_x, hidden=hidden, mems=mems)
            elif model_type == 'Transformer':
                output = model(_x)
            else:
                hidden = repackage_hidden(hidden)
                output, hidden = model(_x, hidden)

            loss = criterion(output.view(-1, ntokens), _y.view(-1))
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
            optimizer.step()
            scheduler.step()
            total_loss += loss.item()

        total_loss /= math.ceil((seq_length - 1) / bptt)
        return {'loss': total_loss, 'ppl': math.exp(total_loss)}
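
The _update step above has the (engine, batch) signature of a pytorch-ignite process function. A hedged wiring sketch (train_loader and the epoch count are assumptions, not part of the example):

from ignite.engine import Engine

trainer = Engine(_update)                 # each call to _update processes one batch
trainer.run(train_loader, max_epochs=10)  # train_loader should yield [seq_len, batch] LongTensors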
Example #3
def evaluate(data_source):
    # Turn on evaluation mode which disables dropout.
    model.eval()
    total_loss = 0.
    ntokens = len(corpus.dictionary)
    hidden = model.init_hidden(eval_batch_size)
    with torch.no_grad():
        for i in range(0, data_source.size(0) - 1, args.bptt):
            data, targets = get_batch(data_source, i)
            output, hidden = model(data, hidden)
            output_flat = output.view(-1, ntokens)
            total_loss += len(data) * criterion(output_flat, targets).item()
            hidden = model.repackage_hidden(hidden)
    return total_loss / len(data_source)
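
Examples #3 and #6 rely on a get_batch helper; in the word_language_model script it slices a pre-batchified [seq_len, batch] tensor of token ids. A sketch under that assumption:

def get_batch(source, i):
    # Take up to bptt time steps starting at position i; the targets are the
    # same tokens shifted one step ahead, flattened for the cross-entropy loss.
    seq_len = min(args.bptt, len(source) - 1 - i)
    data = source[i:i + seq_len]
    target = source[i + 1:i + 1 + seq_len].view(-1)
    return data, target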
Example #4
def evaluate(args, model, test_dataset):
    # Turn on evaluation mode which disables dropout.
    model.eval()
    with torch.no_grad():
        total_loss = 0
        hidden = model.init_hidden(args.eval_batch_size)
        nbatch = 1
        for nbatch, i in enumerate(
                range(0,
                      test_dataset.size(0) - 1, args.bptt)):
            inputSeq, targetSeq = get_batch(args, test_dataset, i)
            # inputSeq: [ seq_len * batch_size * feature_size ]
            # targetSeq: [ seq_len * batch_size * feature_size ]
            hidden_ = model.repackage_hidden(hidden)
            '''Loss1: Free running loss'''
            outVal = inputSeq[0].unsqueeze(0)
            outVals = []
            hids1 = []
            for i in range(inputSeq.size(0)):
                outVal, hidden_, hid = model.forward(outVal,
                                                     hidden_,
                                                     return_hiddens=True)
                outVals.append(outVal)
                hids1.append(hid)
            outSeq1 = torch.cat(outVals, dim=0)
            hids1 = torch.cat(hids1, dim=0)
            loss1 = criterion(outSeq1.contiguous().view(args.batch_size, -1),
                              targetSeq.contiguous().view(args.batch_size, -1))
            '''Loss2: Teacher forcing loss'''
            outSeq2, hidden, hids2 = model.forward(inputSeq,
                                                   hidden,
                                                   return_hiddens=True)
            loss2 = criterion(outSeq2.contiguous().view(args.batch_size, -1),
                              targetSeq.contiguous().view(args.batch_size, -1))
            '''Loss3: Simplified Professor forcing loss'''
            loss3 = criterion(hids1.view(args.batch_size, -1),
                              hids2.view(args.batch_size, -1).detach())
            '''Total loss = Loss1+Loss2+Loss3'''
            loss = loss1 + loss2 + loss3

            total_loss += loss.item()

    return total_loss / (nbatch + 1)
Example #5
def evaluate_1step_pred(args, model, test_dataset):
    # turn on evaluation mode which disables dropout
    model.eval()
    total_loss = 0
    with torch.no_grad():
        hidden = model.init_hidden(args.eval_batch_size)
        for nbatch, i in enumerate(
                range(0,
                      test_dataset.size(0) - 1, args.bptt)):

            inputSeq, targetSeq = get_batch(args, test_dataset, i)
            outSeq, hidden = model.forward(inputSeq, hidden)

            loss = criterion(outSeq.view(args.batch_size, -1),
                             targetSeq.view(args.batch_size, -1))
            hidden = model.repackage_hidden(hidden)
            total_loss += loss.item()

    return total_loss / (nbatch + 1)  # enumerate starts at 0, so nbatch + 1 batches were processed
Example #6
def train():
    # Turn on training mode which enables dropout.
    model.train()
    total_loss = 0.
    start_time = time.time()
    ntokens = len(corpus.dictionary)
    hidden = model.init_hidden(args.batch_size)
    for batch, i in enumerate(range(0, train_data.size(0) - 1, args.bptt)):
        data, targets = get_batch(train_data, i)
        # Starting each batch, we detach the hidden state from how it was previously produced.
        # If we didn't, the model would try backpropagating all the way to start of the dataset.
        hidden = model.repackage_hidden(hidden)
        optim.zero_grad()
        output, hidden = model(data, hidden)
        loss = criterion(output.view(-1, ntokens), targets)
        loss.backward()

        # `clip_grad_norm_` helps prevent the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
        # for p in model.parameters():
        # p.data.add_(-lr, p.grad.data)
        optim.step()

        total_loss += loss.item()
        lr = optim.param_groups[0]['lr']
        if batch % args.log_interval == 0 and batch > 0:
            cur_loss = total_loss / args.log_interval
            elapsed = time.time() - start_time
            print(
                '| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.2f} | ms/batch {:5.2f} | '
                'loss {:5.2f} | ppl {:8.2f}'.format(
                    epoch, batch,
                    len(train_data) // args.bptt, lr,
                    elapsed * 1000 / args.log_interval, cur_loss,
                    math.exp(cur_loss)))
            total_loss = 0
            start_time = time.time()
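
Most examples also call model.init_hidden(batch_size) before the first batch. In the stacked-RNN model of the word_language_model example it simply allocates zero states per layer; a sketch assuming nlayers, nhid and rnn_type attributes on the model:

def init_hidden(self, bsz):
    weight = next(self.parameters())  # reuse dtype/device of the model weights
    if self.rnn_type == 'LSTM':
        # LSTMs carry a (hidden, cell) pair per layer.
        return (weight.new_zeros(self.nlayers, bsz, self.nhid),
                weight.new_zeros(self.nlayers, bsz, self.nhid))
    return weight.new_zeros(self.nlayers, bsz, self.nhid)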
Example #7
 def forward_module(self, i, hidden, data):
     hidden = repackage_hidden(hidden)
     output, hidden = self.rnns[i].forward(data, hidden)
     self.outputs[i, :, :, :] = output
     return output, hidden
Example #8
 def generate(self, data, hidden):
     hidden = repackage_hidden(hidden)
     output, hidden = self.forward(data, hidden)
     return output.view(-1, self.vocsize), hidden
Example #9
 def get_state(self):
     return repackage_hidden(self.hidden)
Example #10
 def get_state(self):
     domain_id = self.last_domain_id
     return repackage_hidden(self.hiddens[domain_id])
Example #11
def train(args, model, train_dataset, epoch):

    with torch.enable_grad():
        # Turn on training mode which enables dropout.
        model.train()
        total_loss = 0
        start_time = time.time()
        hidden = model.init_hidden(args.batch_size)
        for batch, i in enumerate(
                range(0,
                      train_dataset.size(0) - 1, args.bptt)):
            inputSeq, targetSeq = get_batch(args, train_dataset, i)
            # Starting each batch, we detach the hidden state from how it was previously produced.
            # If we didn't, the model would try backpropagating all the way to start of the dataset.
            hidden = model.repackage_hidden(hidden)
            hidden_ = model.repackage_hidden(hidden)
            optimizer.zero_grad()
            '''Loss1: Free running loss'''
            outVal = inputSeq[0].unsqueeze(0)
            outVals = []
            hids1 = []
            for i in range(inputSeq.size(0)):
                outVal, hidden_, hid = model.forward(outVal,
                                                     hidden_,
                                                     return_hiddens=True)
                outVals.append(outVal)
                hids1.append(hid)
            outSeq1 = torch.cat(outVals, dim=0)
            hids1 = torch.cat(hids1, dim=0)
            loss1 = criterion(outSeq1.contiguous().view(args.batch_size, -1),
                              targetSeq.contiguous().view(args.batch_size, -1))
            '''Loss2: Teacher forcing loss'''
            outSeq2, hidden, hids2 = model.forward(inputSeq,
                                                   hidden,
                                                   return_hiddens=True)
            loss2 = criterion(outSeq2.contiguous().view(args.batch_size, -1),
                              targetSeq.contiguous().view(args.batch_size, -1))
            '''Loss3: Simplified Professor forcing loss'''
            loss3 = criterion(hids1.view(args.batch_size, -1),
                              hids2.view(args.batch_size, -1).detach())
            '''Total loss = Loss1+Loss2+Loss3'''
            loss = loss1 + loss2 + loss3
            loss.backward()

            # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
            optimizer.step()

            total_loss += loss.item()

            if batch % args.log_interval == 0 and batch > 0:
                cur_loss = total_loss / args.log_interval
                elapsed = time.time() - start_time
                print(
                    '| epoch {:3d} | {:5d}/{:5d} batches | ms/batch {:5.4f} | '
                    'loss {:5.2f} '.format(epoch, batch,
                                           len(train_dataset) // args.bptt,
                                           elapsed * 1000 / args.log_interval,
                                           cur_loss))
                total_loss = 0
                start_time = time.time()