Пример #1
0
def prepare_losses_and_metrics(pad, eos):
    """Return the default NLL loss setup together with accuracy metrics.

    Args:
        pad: padding index; passed as ``ignore_index`` to the loss and to
            both metrics.
        eos: accepted for signature compatibility but not used here.

    Returns:
        A ``(losses, loss_weights, metrics)`` tuple: a one-element list with
        an ``NLLLoss`` (already moved to ``device``), the matching weight
        list ``[1.]``, and word- plus sequence-level accuracy metrics.
    """
    criterion = NLLLoss(ignore_index=pad)
    criterion.to(device)

    metrics = [
        WordAccuracy(ignore_index=pad),
        SequenceAccuracy(ignore_index=pad),
    ]
    return [criterion], [1.], metrics
Пример #2
0
def prepare_losses_and_metrics(
        opt, pad, unk, sos, eos, input_vocab, output_vocab):
    """Build the loss list, loss weights and the metrics requested in ``opt``.

    Args:
        opt: options object; ``opt.ignore_output_eos`` and ``opt.metrics``
            are read.
        pad, unk, sos, eos: output special symbols. ``pad`` is also the
            ``ignore_index`` of the loss and of the accuracy metrics, and
            ``eos`` is the ``eos_id`` of ``FinalTargetAccuracy``.
        input_vocab, output_vocab: vocabularies passed to the token-level
            metrics (symbol-rewriting accuracy and BLEU).

    Returns:
        ``(losses, loss_weights, metrics)``: a single ``NLLLoss`` moved to
        ``device`` with weight ``1.``, plus one metric for every name found
        in ``opt.metrics``, in the fixed order word_acc, seq_acc,
        target_acc, sym_rwr_acc, bleu.
    """
    use_output_eos = not opt.ignore_output_eos

    nll_loss = NLLLoss(ignore_index=pad)
    nll_loss.to(device)
    losses, loss_weights = [nll_loss], [1.]

    # Metrics that work on the actual tokens share the same vocabularies
    # and special symbols.
    token_metric_kwargs = dict(
        input_vocab=input_vocab,
        output_vocab=output_vocab,
        use_output_eos=use_output_eos,
        output_sos_symbol=sos,
        output_pad_symbol=pad,
        output_eos_symbol=eos,
        output_unk_symbol=unk,
    )

    # Ordered (option name, factory) pairs; a factory is called only when its
    # name is present in opt.metrics.
    metric_factories = (
        ('word_acc', lambda: WordAccuracy(ignore_index=pad)),
        ('seq_acc', lambda: SequenceAccuracy(ignore_index=pad)),
        ('target_acc',
         lambda: FinalTargetAccuracy(ignore_index=pad, eos_id=eos)),
        ('sym_rwr_acc',
         lambda: SymbolRewritingAccuracy(**token_metric_kwargs)),
        ('bleu', lambda: BLEU(**token_metric_kwargs)),
    )
    metrics = [
        factory()
        for metric_name, factory in metric_factories
        if metric_name in opt.metrics
    ]

    return losses, loss_weights, metrics
Пример #3
0
            )

# Prepare loss and metrics: an NLL loss that ignores padding, optionally
# combined with a separately weighted attention loss.
pad = output_vocab.stoi[tgt.pad_token]
losses = [NLLLoss(ignore_index=pad)]
loss_weights = [1.]

if opt.use_attention_loss:
    losses += [AttentionLoss(ignore_index=IGNORE_INDEX)]
    loss_weights += [opt.scale_attention_loss]

for loss in losses:
    loss.to(device)

metrics = [WordAccuracy(ignore_index=pad), SequenceAccuracy(ignore_index=pad)]
metrics.append(FinalTargetAccuracy(ignore_index=pad, eos_id=tgt.eos_id))
# Disabled option: k-grammar (symbol rewriting) accuracy needs the actual
# tokens, so it would also take the vocabularies and the special symbols.
# Note that `output_eos_used` is not defined in this snippet.
# metrics.append(SymbolRewritingAccuracy(
#     input_vocab=input_vocab,
#     output_vocab=output_vocab,
#     use_output_eos=output_eos_used,
#     input_pad_symbol=src.pad_token,
#     output_sos_symbol=tgt.SYM_SOS,
#     output_pad_symbol=tgt.pad_token,
#     output_eos_symbol=tgt.SYM_EOS,
#     output_unk_symbol=tgt.unk_token))
Пример #4
0
def len_filter(example):
    """Return True when both ``example.src`` and ``example.tgt`` have at most
    ``max_len`` elements (``max_len`` is a module-level setting).

    ``tgt`` is only measured when ``src`` already fits.
    """
    if len(example.src) > max_len:
        return False
    return len(example.tgt) <= max_len


# generate test set
test = torchtext.data.TabularDataset(path=opt.test_data,
                                     format='tsv',
                                     fields=[('src', src), ('tgt', tgt)],
                                     filter_pred=len_filter)

# Prepare loss and metrics.
# NOTE(review): `weight` is not read anywhere in this snippet; remove it if no
# later code depends on it.
weight = torch.ones(len(output_vocab))
pad = output_vocab.stoi[tgt.pad_token]
# Pass `pad` by keyword, as the other loss/metric setups do, so these calls do
# not depend on the positional parameter order of NLLLoss / the metrics.
loss = NLLLoss(ignore_index=pad)
metrics = [WordAccuracy(ignore_index=pad), SequenceAccuracy(ignore_index=pad)]
if torch.cuda.is_available():
    loss.cuda()

#################################################################################
# Evaluate model on test set

evaluator = Evaluator(loss=[loss], metrics=metrics, batch_size=opt.batch_size)
losses, metrics = evaluator.evaluate(seq2seq, test,
                                     SupervisedTrainer.get_batch_data)

# Print each metric's class name and value with an explicit precision of 6
# decimals (the previous spec `{:6f}` set a field width, not a precision).
print([
    "{}: {:.6f}".format(type(metric).__name__, metric.get_val())
    for metric in metrics
])
Пример #5
0
 def __init__(self, loss=None, metrics=None, batch_size=64):
     """Store the losses, metrics and batch size on the instance.

     Args:
         loss: list of loss objects, stored as ``self.losses``; when omitted
             (or None) a fresh ``[NLLLoss()]`` is used.
         metrics: list of metric objects, stored as ``self.metrics``; when
             omitted (or None) a fresh ``[WordAccuracy(), SequenceAccuracy()]``
             is used.
         batch_size: stored as ``self.batch_size`` (default 64).
     """
     # Build the defaults per call. List defaults in the signature are created
     # once, when the function is defined, so every instance relying on them
     # would share the same list and the same loss/metric objects.
     self.losses = [NLLLoss()] if loss is None else loss
     self.metrics = ([WordAccuracy(), SequenceAccuracy()]
                     if metrics is None else metrics)
     self.batch_size = batch_size
Пример #6
0
# Prepare loss and metrics.
pad = output_vocab.stoi[tgt.pad_token]
# The NLL term is weighted by opt.xent_loss rather than a fixed 1.0.
losses = [NLLLoss(ignore_index=pad)]
loss_weights = [float(opt.xent_loss)]

# Optional attention loss with its own scaling factor.
if opt.use_attention_loss:
    losses += [AttentionLoss(ignore_index=IGNORE_INDEX)]
    loss_weights += [opt.scale_attention_loss]

for loss in losses:
    loss.to(device)

metrics = [
    WordAccuracy(ignore_index=pad),
    SequenceAccuracy(ignore_index=pad),
    FinalTargetAccuracy(ignore_index=pad, eos_id=tgt.eos_id),
]
# Disabled option: k-grammar (symbol rewriting) accuracy needs the actual
# tokens, so it would also take the vocabularies and the special symbols.
# metrics.append(SymbolRewritingAccuracy(
#     input_vocab=input_vocab,
#     output_vocab=output_vocab,
#     use_output_eos=use_output_eos,
#     input_pad_symbol=src.pad_token,
#     output_sos_symbol=tgt.SYM_SOS,
#     output_pad_symbol=tgt.pad_token,
#     output_eos_symbol=tgt.SYM_EOS,
#     output_unk_symbol=tgt.unk_token))

# Resume from a stored checkpoint only when requested.
if opt.resume:
    checkpoint_path = os.path.join(opt.output_dir, opt.load_checkpoint)
else:
    checkpoint_path = None

# Prepare loss and metrics for the trainer.
pad = output_vocab.stoi[tgt.pad_token]
losses = [NLLLoss(ignore_index=pad)]
loss_weights = [1.]

# One L1 loss per name in opt.l1_loss_inputs, each scaled by
# opt.scale_l1_loss.
for l1_loss_input in opt.l1_loss_inputs:
    losses.append(L1Loss(input_name=l1_loss_input))
    loss_weights.append(opt.scale_l1_loss)

for loss in losses:
    loss.to(device)

# Instantiate only the metrics named in opt.metrics, in this fixed order.
# Symbol-rewriting accuracy needs the actual tokens, so it also receives the
# vocabularies and the target field's special symbols.
metrics = [
    make_metric()
    for metric_name, make_metric in (
        ('word_acc', lambda: WordAccuracy(ignore_index=pad)),
        ('seq_acc', lambda: SequenceAccuracy(ignore_index=pad)),
        ('target_acc',
         lambda: FinalTargetAccuracy(ignore_index=pad, eos_id=tgt.eos_id)),
        ('sym_rwr_acc',
         lambda: SymbolRewritingAccuracy(
             input_vocab=input_vocab,
             output_vocab=output_vocab,
             use_output_eos=use_output_eos,
             output_sos_symbol=tgt.SYM_SOS,
             output_pad_symbol=tgt.pad_token,
             output_eos_symbol=tgt.SYM_EOS,
             output_unk_symbol=tgt.unk_token)),
    )
    if metric_name in opt.metrics
]

checkpoint_path = os.path.join(opt.output_dir,
Пример #8
0
 def __init__(self,
              loss=None,
              metrics=None):
     """Store the losses and metrics on the instance.

     Args:
         loss: list of loss objects, stored as ``self.losses``; when omitted
             (or None) a fresh ``[NLLLoss()]`` is used.
         metrics: list of metric objects, stored as ``self.metrics``; when
             omitted (or None) a fresh ``[WordAccuracy(), SequenceAccuracy()]``
             is used.
     """
     # Build the defaults per call. List defaults in the signature are created
     # once, when the function is defined, so every instance relying on them
     # would share the same list and the same loss/metric objects.
     self.losses = [NLLLoss()] if loss is None else loss
     self.metrics = ([WordAccuracy(), SequenceAccuracy()]
                     if metrics is None else metrics)
Пример #9
0
# Build the test set from the TSV file given by opt.test_data, filtered
# through len_filter.
test = torchtext.data.TabularDataset(
    path=opt.test_data,
    format='tsv',
    fields=tabular_data_fields,
    filter_pred=len_filter,
)

# Prepare loss and metrics: NLL loss and accuracies that ignore padding.
pad = output_vocab.stoi[tgt.pad_token]
losses = [NLLLoss(ignore_index=pad)]
loss_weights = [1.]
for loss in losses:
    loss.to(device)

metrics = [
    WordAccuracy(ignore_index=pad),
    SequenceAccuracy(ignore_index=pad),
    FinalTargetAccuracy(ignore_index=pad, eos_id=tgt.eos_id),
]
# Disabled option: k-grammar (symbol rewriting) accuracy needs the actual
# tokens, so it would also take the vocabularies and the special symbols.
# Note that `output_eos_used` is not defined in this snippet.
# metrics.append(SymbolRewritingAccuracy(
#     input_vocab=input_vocab,
#     output_vocab=output_vocab,
#     use_output_eos=output_eos_used,
#     input_pad_symbol=src.pad_token,
#     output_sos_symbol=tgt.SYM_SOS,
#     output_pad_symbol=tgt.pad_token,
#     output_eos_symbol=tgt.SYM_EOS,
#     output_unk_symbol=tgt.unk_token))

data_func = SupervisedTrainer.get_batch_data