Exemplo n.º 1
0
def validation(validation_data, model, global_step, t_vocab_size, val_writer,
               opt):
    """Evaluate ``model`` on ``validation_data`` and log the mean loss.

    The model is put in eval mode and every batch is scored without gradient
    tracking. Each batch's loss from ``utils.get_loss`` (label smoothing 0) is
    weighted by ``len(batch)``. The weighted mean is printed, written to
    ``val_writer`` under the tag ``'loss'`` at ``global_step``, and returned.

    Args:
        validation_data: iterable of batches exposing ``.trg`` (and ``.src``
            when ``opt.has_inputs`` is set) and supporting ``len()``.
        model: network called as
            ``model(inputs, targets, i_mask, t_self_mask, t_mask)``.
        global_step: step index attached to the logged scalar.
        t_vocab_size: target vocabulary size, forwarded to ``utils.get_loss``.
        val_writer: summary writer with ``add_scalar(tag, value, step)``.
        opt: options object; reads ``has_inputs``, ``src_pad_idx`` and
            ``trg_pad_idx``.

    Returns:
        The ``len(batch)``-weighted mean validation loss as a float.

    Raises:
        ValueError: if ``validation_data`` yields no examples, because the
            mean loss is undefined in that case.
    """
    model.eval()
    total_loss = 0.0
    total_cnt = 0
    # Nothing in validation needs gradients, so the whole pass (masks
    # included) runs under no_grad.
    with torch.no_grad():
        for batch in validation_data:
            inputs, i_mask = None, None
            if opt.has_inputs:
                inputs = batch.src
                i_mask = utils.create_pad_mask(inputs, opt.src_pad_idx)
            targets = batch.trg
            t_mask = utils.create_pad_mask(targets, opt.trg_pad_idx)
            t_self_mask = utils.create_trg_self_mask(targets)

            pred = model(inputs, targets, i_mask, t_self_mask, t_mask)

            # Flatten every leading dimension so predictions line up with the
            # flattened target ids.
            pred = pred.view(-1, pred.size(-1))
            ans = targets.view(-1)
            # The 4th argument is the label-smoothing value (train() passes
            # its own); validation always scores with 0.
            loss = utils.get_loss(pred, ans, t_vocab_size, 0,
                                  opt.trg_pad_idx)

            batch_size = len(batch)
            total_loss += loss.item() * batch_size
            total_cnt += batch_size

    if total_cnt == 0:
        raise ValueError("validation_data yielded no examples; "
                         "cannot compute a mean validation loss")

    val_loss = total_loss / total_cnt
    print("Validation Loss", val_loss)
    val_writer.add_scalar('loss', val_loss, global_step)
    return val_loss
Exemplo n.º 2
0
    def forward(self, inputs, targets):
        """Run the encoder (if enabled) and the decoder; return the decoder output.

        When ``self.has_inputs`` is false the encoder is skipped, and the
        decoder receives ``None`` for both the encoder output and the source
        mask.

        Args:
            inputs: source token ids; only read when ``self.has_inputs``.
            targets: target token ids. Dimension 1 is used as the sequence
                length for the target self-attention mask.

        Returns:
            Whatever ``self.decode`` returns for these arguments.
        """
        if self.has_inputs:
            src_mask = utils.create_pad_mask(inputs, self.src_pad_idx)
            memory = self.encode(inputs, src_mask)
        else:
            src_mask = None
            memory = None

        trg_pad_mask = utils.create_pad_mask(targets, self.trg_pad_idx)
        trg_len = targets.size(1)
        trg_causal_mask = utils.create_trg_self_mask(trg_len,
                                                     device=targets.device)
        return self.decode(targets, memory, src_mask, trg_causal_mask,
                           trg_pad_mask)
Exemplo n.º 3
0
def train(train_data, model, opt, global_step, optimizer, t_vocab_size,
          label_smoothing, writer, summary_interval=100):
    """Train ``model`` for one pass over ``train_data``; return the new step.

    For each batch, this function:

    * builds the padding masks (the source mask only when ``opt.has_inputs``)
      and the target self-attention mask;
    * runs the model and computes ``utils.get_loss`` with
      ``label_smoothing``;
    * takes one optimizer step.

    Whenever ``global_step`` is a multiple of ``summary_interval``,
    ``summarize_train`` is called with the current tensors. A tqdm progress
    bar is advanced by ``targets.size(0)`` per batch. After the pass,
    ``train_data.reload_examples()`` is called.

    Args:
        train_data: batch iterator exposing ``.dataset`` (sized) and
            ``reload_examples()``. Its batches expose ``.trg`` and, when
            ``opt.has_inputs``, ``.src``.
        model: network called as
            ``model(inputs, targets, i_mask, t_self_mask, t_mask)``.
        opt: options object; reads ``has_inputs``, ``src_pad_idx`` and
            ``trg_pad_idx``.
        global_step: step counter at the start of this pass.
        optimizer: optimizer whose ``zero_grad``/``step`` run once per batch.
        t_vocab_size: target vocabulary size, forwarded to ``utils.get_loss``.
        label_smoothing: smoothing value forwarded to ``utils.get_loss``.
        writer: summary writer forwarded to ``summarize_train``.
        summary_interval: emit a training summary whenever ``global_step`` is
            a multiple of this value. Must be positive; defaults to 100.

    Returns:
        ``global_step`` advanced by the number of batches processed.

    Raises:
        ValueError: if ``summary_interval`` is not positive.
    """
    if summary_interval <= 0:
        raise ValueError("summary_interval must be positive, got {}"
                         .format(summary_interval))

    model.train()
    last_time = time.time()
    # The context manager closes the progress bar even if a step raises
    # (e.g. out-of-memory or KeyboardInterrupt) instead of leaving it open.
    with tqdm(total=len(train_data.dataset), ascii=True) as pbar:
        for batch in train_data:
            inputs, i_mask = None, None
            if opt.has_inputs:
                inputs = batch.src
                i_mask = utils.create_pad_mask(inputs, opt.src_pad_idx)

            targets = batch.trg
            t_mask = utils.create_pad_mask(targets, opt.trg_pad_idx)
            t_self_mask = utils.create_trg_self_mask(targets)

            pred = model(inputs, targets, i_mask, t_self_mask, t_mask)

            # Flatten every leading dimension so predictions line up with the
            # flattened target ids.
            pred = pred.view(-1, pred.size(-1))
            ans = targets.view(-1)

            loss = utils.get_loss(pred, ans, t_vocab_size,
                                  label_smoothing, opt.trg_pad_idx)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if global_step % summary_interval == 0:
                summarize_train(writer, global_step, last_time, model, opt,
                                inputs, targets, optimizer, loss, pred, ans)
                last_time = time.time()

            pbar.set_description('[Loss: {:.4f}]'.format(loss.item()))

            global_step += 1
            pbar.update(targets.size(0))

    train_data.reload_examples()
    return global_step
Exemplo n.º 4
0
def main():
    """Interactively decode with a saved checkpoint using beam search.

    Command-line flags:
        --data_dir    directory holding ``target.pt`` (and ``source.pt`` when
                      ``--translate`` is given), loaded with ``torch.load``.
        --model_dir   checkpoint directory for ``utils.load_checkpoint``.
        --max_length  exclusive upper bound on the decoding position index.
        --beam_size   number of hypotheses kept per step.
        --alpha       exponent of the length penalty ((5 + length) / 6) ** alpha.
        --no_cuda     run on CPU instead of CUDA.
        --translate   prompt for a source sentence, encode it, and decode a
                      translation. Without this flag, prompt for a
                      space-separated target-side prefix and decode a
                      continuation of it with no encoder input.

    Prints the decoded result and the elapsed wall-clock time. If the start
    position is already at or beyond ``--max_length``, prints a message and
    returns without decoding.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', type=str, required=True)
    parser.add_argument('--model_dir', type=str, required=True)
    parser.add_argument('--max_length', type=int, default=100)
    parser.add_argument('--beam_size', type=int, default=4)
    parser.add_argument('--alpha', type=float, default=0.6)
    parser.add_argument('--no_cuda', action='store_true')
    parser.add_argument('--translate', action='store_true')
    args = parser.parse_args()

    beam_size = args.beam_size

    # Load fields.
    # NOTE(review): torch.load unpickles its input, so --data_dir (and the
    # checkpoint in --model_dir) must point only at trusted files.
    if args.translate:
        src_data = torch.load(args.data_dir + '/source.pt')
    trg_data = torch.load(args.data_dir + '/target.pt')

    # Load a saved model.
    device = torch.device('cpu' if args.no_cuda else 'cuda')
    model = utils.load_checkpoint(args.model_dir, device)

    # A (beam_size, 1) column of pad ids. It is the initial decoder input in
    # translate mode and is appended as a fresh slot at every later step.
    pads = torch.tensor([trg_data['pad_idx']] * beam_size, device=device)
    pads = pads.unsqueeze(-1)

    # We'll find a target sequence by beam search.
    scores_history = [torch.zeros((beam_size,), dtype=torch.float,
                                  device=device)]
    indices_history = []
    cache = {}

    eos_idx = trg_data['field'].vocab.stoi[trg_data['field'].eos_token]

    if args.translate:
        sentence = input('Source? ')
        # Encoding inputs.
        start_time = time.time()
        enc_output, src_mask = encode_inputs(sentence, model, src_data,
                                             beam_size, device)
        targets = pads
        start_idx = 0
    else:
        enc_output, src_mask = None, None
        stoi = trg_data['field'].vocab.stoi
        sentence = [stoi[word] for word in input('Target? ').split()]
        # Trailing pad is the slot for the first predicted position.
        # Presumably update_targets fills it in; confirm there.
        sentence.append(trg_data['pad_idx'])
        targets = torch.tensor([sentence], device=device)
        start_idx = targets.size(1) - 1
        start_time = time.time()

    # Guard: if the loop below would run zero times, vocab_size would never
    # be bound and get_result_sentence(...) would raise NameError.
    if start_idx >= args.max_length:
        print("Nothing to decode: start position {} is not below "
              "--max_length {}.".format(start_idx, args.max_length))
        return

    with torch.no_grad():
        for idx in range(start_idx, args.max_length):
            # After the first step, grow the decoder input by one pad column.
            if idx > start_idx:
                targets = torch.cat((targets, pads), dim=1)
            t_self_mask = utils.create_trg_self_mask(targets.size()[1],
                                                     device=targets.device)

            t_mask = utils.create_pad_mask(targets, trg_data['pad_idx'])
            pred = model.decode(targets, enc_output, src_mask,
                                t_self_mask, t_mask, cache)
            # Keep only the output at the position being predicted.
            pred = pred[:, idx].squeeze(1)
            vocab_size = pred.size(1)

            pred = F.log_softmax(pred, dim=1)
            if idx == start_idx:
                # First step: score only row 0. The rows are presumably
                # identical copies here, so topk then yields beam_size
                # distinct tokens instead of duplicates.
                scores = pred[0]
            else:
                scores = scores_history[-1].unsqueeze(1) + pred
            length_penalty = pow(((5. + idx + 1.) / 6.), args.alpha)
            # NOTE(review): the penalized score is stored in scores_history
            # and divided again on the next step. The penalty therefore
            # compounds across steps instead of being applied once to the
            # summed log-probs. Left as-is to preserve decoding output.
            scores = scores / length_penalty
            # Row-major flatten: flat index = beam * vocab_size + token (beam
            # is 0 on the first step, where scores is already 1-D).
            scores = scores.view(-1)

            best_scores, best_indices = scores.topk(beam_size, 0)
            scores_history.append(best_scores)
            indices_history.append(best_indices)

            # Stop searching when the best output of beam is EOS.
            if best_indices[0].item() % vocab_size == eos_idx:
                break

            targets = update_targets(targets, best_indices, idx, vocab_size)

    result = get_result_sentence(indices_history, trg_data, vocab_size)
    print("Result: {}".format(result))

    print("Elapsed Time: {:.2f} sec".format(time.time() - start_time))