Пример #1
0
        [ 2.1937, -0.5535, -0.9000,  ..., -0.1032,  0.3514, -1.2759],
        [-0.8078,  0.1575,  1.1064,  ...,  0.1365,  0.4121, -0.4211]],
       device='cuda:0')'''
            # Fragment of a seq2seq training step (the enclosing loop/function is
            # outside this view): pack the batch, encode it, then decode each sample
            # either with teacher forcing or with argmax decoding, accumulating BLEU
            # statistics and per-sample decoder outputs / reference targets.
            # `model.pack_batch` returns a 4-tuple; its third element is ignored.
            input_seq, out_seq_list, _, out_idx = model.pack_batch(
                batch, net.emb, device)
            enc = net.encode(input_seq)

            # Per-sample decoder outputs and the flattened reference indices.
            net_results = []
            net_targets = []
            for idx, out_seq in enumerate(out_seq_list):
                # Reference indices without their first element (presumably the
                # start-of-sequence token -- TODO confirm against pack_batch).
                ref_indices = out_idx[idx][1:]
                enc_item = net.get_encoded_item(enc, idx)
                # Train with teacher forcing: with probability TEACHER_PROB the
                # decoder is fed the ground-truth output sequence.
                if random.random() < TEACHER_PROB:
                    r = net.decode_teacher(enc_item, out_seq)
                    # NOTE: `blue_temp` is a misspelling of "bleu"; it holds this
                    # sample's BLEU score.
                    blue_temp = model.seq_bleu(r, ref_indices)
                    bleu_sum += blue_temp
                    # Get predicted tokens: argmax over dim 1 of the decoder output.
                    # `.data` is used so this bookkeeping bypasses autograd.
                    # NOTE(review): `seq` is not used again within the visible lines.
                    seq = torch.max(r.data, dim=1)[1]
                    seq = seq.cpu().numpy()
                # Otherwise train with argmax decoding, seeded with the first token
                # of out_seq and (presumably) run for len(ref_indices) steps --
                # confirm against decode_chain_argmax.
                else:
                    r, seq = net.decode_chain_argmax(enc_item,
                                                     out_seq.data[0:1],
                                                     len(ref_indices))
                    blue_temp = utils.calc_bleu(seq, ref_indices)
                    bleu_sum += blue_temp
                # Collect the decoder output and its reference targets; presumably
                # consumed by a loss computation after this fragment.
                net_results.append(r)
                net_targets.extend(ref_indices)
                bleu_count += 1
Пример #2
0
        # Fragment of one training epoch (the enclosing function is outside this
        # view). For each batch: decode every sample with teacher forcing or with
        # argmax decoding, compute cross-entropy over all decoded tokens, take an
        # optimiser step, and track per-batch losses and the average BLEU score.
        losses = []
        bleu_sum = 0.0
        bleu_count = 0
        for batch in data.iterate_batches(train_data, BATCH_SIZE):
            optimiser.zero_grad()
            # `model.pack_batch` returns a 4-tuple; its third element is ignored.
            input_seq, out_seq_list, _, out_idx = model.pack_batch(batch, net.emb, device)
            enc = net.encode(input_seq)

            # Per-sample decoder outputs and the flattened reference indices.
            net_results = []
            net_targets = []
            for idx, out_seq in enumerate(out_seq_list):
                # Drop the first reference index (presumably the start token --
                # TODO confirm against pack_batch).
                ref_indices = out_idx[idx][1:]
                enc_item = net.get_encoded_item(enc, idx)
                if random.random() < TEACHER_PROB:
                    # Teacher forcing (taken with probability TEACHER_PROB): feed
                    # the ground-truth output sequence to the decoder.
                    r = net.decode_teacher(enc_item, out_seq)
                    bleu_sum += model.seq_bleu(r, ref_indices)
                else:
                    # Argmax decoding seeded with the first token of out_seq,
                    # presumably for len(ref_indices) steps.
                    r, seq = net.decode_chain_argmax(enc_item, out_seq.data[0:1],
                                                     len(ref_indices))
                    bleu_sum += utils.calc_bleu(seq, ref_indices)
                net_results.append(r)
                net_targets.extend(ref_indices)
                bleu_count += 1
            # Concatenate all decoder outputs along dim 0 and pair them with the
            # flattened targets. Assumes each r has one row per reference token,
            # i.e. shape (len(ref_indices), vocab) -- TODO confirm.
            results_v = torch.cat(net_results)
            targets_v = torch.LongTensor(net_targets).to(device)
            loss_v = F.cross_entropy(results_v, targets_v)
            loss_v.backward()
            optimiser.step()

            # Store the scalar loss value for this batch.
            losses.append(loss_v.item())
        # Average BLEU over all decoded samples in the epoch.
        # NOTE(review): raises ZeroDivisionError if train_data yields no samples.
        bleu = bleu_sum / bleu_count