Example 1
    def __init__(self,
                 expt_dir='experiment',
                 loss=None,
                 loss_weights=None,
                 metrics=None,
                 batch_size=64,
                 eval_batch_size=128,
                 random_seed=None,
                 checkpoint_every=100,
                 print_every=100):
        self._trainer = "Simple Trainer"
        self.random_seed = random_seed
        if random_seed is not None:
            random.seed(random_seed)
            torch.manual_seed(random_seed)
        # Avoid mutable default arguments: build fresh defaults per instance
        self.loss = loss if loss is not None else [NLLLoss()]
        self.metrics = metrics if metrics is not None else []
        self.loss_weights = loss_weights or len(self.loss) * [1.]
        self.evaluator = Evaluator(loss=self.loss,
                                   metrics=self.metrics,
                                   batch_size=eval_batch_size)
        self.optimizer = None
        self.checkpoint_every = checkpoint_every
        self.print_every = print_every

        if not os.path.isabs(expt_dir):
            expt_dir = os.path.join(os.getcwd(), expt_dir)
        self.expt_dir = expt_dir
        if not os.path.exists(self.expt_dir):
            os.makedirs(self.expt_dir)
        self.batch_size = batch_size

        self.logger = logging.getLogger(__name__)
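
A minimal construction sketch for this trainer. The `machine.loss.NLLLoss` and `machine.metrics.WordAccuracy` import paths match the patch targets in Example 3; the `machine.trainer.SupervisedTrainer` path and the pad index of 0 are assumptions:

from machine.loss import NLLLoss
from machine.metrics import WordAccuracy
from machine.trainer import SupervisedTrainer  # module path is an assumption

trainer = SupervisedTrainer(expt_dir='runs/exp1',
                            loss=[NLLLoss()],
                            metrics=[WordAccuracy(ignore_index=0)],  # pad index assumed to be 0
                            batch_size=32,
                            eval_batch_size=128,
                            random_seed=42)
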
Example 2
    def set_local_parameters(self, random_seed, losses, metrics,
                             loss_weights, checkpoint_every, print_every):
        self.random_seed = random_seed
        if random_seed is not None:
            random.seed(random_seed)
            torch.manual_seed(random_seed)

        self.losses = losses
        self.metrics = metrics
        self.loss_weights = loss_weights or len(losses) * [1.]
        self.evaluator = Evaluator(loss=self.losses, metrics=self.metrics)
        self.optimizer = None
        self.checkpoint_every = checkpoint_every
        self.print_every = print_every
        self.logger = logging.getLogger(__name__)
        self._stop_training = False
Example 3
    def test_set_eval_mode(self, mock_eval, mock_call):
        """ Make sure that evaluation is done in evaluation mode. """
        mock_mgr = MagicMock()
        mock_mgr.attach_mock(mock_eval, 'eval')
        mock_mgr.attach_mock(mock_call, 'call')

        evaluator = Evaluator(batch_size=64)
        with patch('machine.evaluator.evaluator.torch.stack', return_value=None), \
             patch('machine.metrics.WordAccuracy.eval_batch', return_value=None), \
             patch('machine.loss.NLLLoss.eval_batch', return_value=None):
            evaluator.evaluate(self.seq2seq, self.dataset, trainer.get_batch_data)

        num_batches = int(math.ceil(len(self.dataset) / evaluator.batch_size))
        expected_calls = [call.eval()] + num_batches * [call.call(ANY, ANY, ANY)]
        self.assertEqual(expected_calls, mock_mgr.mock_calls)
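
The test exists because `model.eval()` changes layer behavior: dropout and batch norm act differently in training and evaluation mode, so metrics computed in training mode would be noisy. A standalone PyTorch illustration (not part of the machine library):

import torch
import torch.nn as nn

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
x = torch.ones(1, 4)

net.train()                    # training mode: dropout masks are random
print(net(x).equal(net(x)))    # almost surely False
net.eval()                     # evaluation mode: dropout is a no-op
print(net(x).equal(net(x)))    # True
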
Example 4
metrics = [
    WordAccuracy(ignore_index=pad),
    SequenceAccuracy(ignore_index=pad),
    FinalTargetAccuracy(ignore_index=pad, eos_id=tgt.eos_id)
]
# Since we need the actual tokens to determine k-grammar accuracy,
# we also provide the input and output vocab and relevant special symbols
# metrics.append(SymbolRewritingAccuracy(
#     input_vocab=input_vocab,
#     output_vocab=output_vocab,
#     use_output_eos=output_eos_used,
#     input_pad_symbol=src.pad_token,
#     output_sos_symbol=tgt.SYM_SOS,
#     output_pad_symbol=tgt.pad_token,
#     output_eos_symbol=tgt.SYM_EOS,
#     output_unk_symbol=tgt.unk_token))

data_func = SupervisedTrainer.get_batch_data

#################################################################################
# Evaluate model on test set

evaluator = Evaluator(batch_size=opt.batch_size, loss=losses, metrics=metrics)
losses, metrics = evaluator.evaluate(model=seq2seq,
                                     data=test,
                                     get_batch_data=data_func)

total_loss, log_msg, _ = SupervisedTrainer.get_losses(losses, metrics, 0)

logging.info(log_msg)
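
To inspect individual metric values after `evaluate`, the same pattern as in Example 5 applies here as well (assuming each metric object exposes `get_val()`, as shown there):

for metric in metrics:
    print("{}: {:.6f}".format(type(metric).__name__, metric.get_val()))
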
Example 5
def len_filter(example):
    return len(example.src) <= max_len and len(example.tgt) <= max_len


# generate test set
test = torchtext.data.TabularDataset(path=opt.test_data,
                                     format='tsv',
                                     fields=[('src', src), ('tgt', tgt)],
                                     filter_pred=len_filter)

# Prepare loss
pad = output_vocab.stoi[tgt.pad_token]
loss = NLLLoss(pad)
metrics = [WordAccuracy(pad), SequenceAccuracy(pad)]
if torch.cuda.is_available():
    loss.cuda()

#################################################################################
# Evaluate model on test set

evaluator = Evaluator(loss=[loss], metrics=metrics, batch_size=opt.batch_size)
losses, metrics = evaluator.evaluate(seq2seq, test,
                                     SupervisedTrainer.get_batch_data)

print([
    "{}: {:6f}".format(type(metric).__name__, metric.get_val())
    for metric in metrics
])
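
The snippet assumes `src`, `tgt`, `output_vocab`, and `opt` are set up earlier. A minimal sketch of that setup, assuming the `SourceField`/`TargetField` helpers live in `machine.dataset` (the module path, the `opt.train_data` attribute, and the `max_size` value are all assumptions):

import torchtext
from machine.dataset import SourceField, TargetField  # module path is an assumption

src = SourceField()
tgt = TargetField()
# Build the training set with the same fields and length filter as the test set
train = torchtext.data.TabularDataset(path=opt.train_data,  # attribute assumed
                                      format='tsv',
                                      fields=[('src', src), ('tgt', tgt)],
                                      filter_pred=len_filter)
src.build_vocab(train, max_size=50000)
tgt.build_vocab(train, max_size=50000)
output_vocab = tgt.vocab
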
Example 6
# Fragment: move each loss criterion to the active device
# ('losses' and 'device' are defined earlier in the original script)
for loss in losses:
    loss.to(device)

metrics = [
    WordAccuracy(ignore_index=pad),
    SequenceAccuracy(ignore_index=pad),
    FinalTargetAccuracy(ignore_index=pad, eos_id=tgt.eos_id)
]
# Since we need the actual tokens to determine k-grammar accuracy,
# we also provide the input and output vocab and relevant special symbols
# metrics.append(SymbolRewritingAccuracy(
#     input_vocab=input_vocab,
#     output_vocab=output_vocab,
#     use_output_eos=output_eos_used,
#     input_pad_symbol=src.pad_token,
#     output_sos_symbol=tgt.SYM_SOS,
#     output_pad_symbol=tgt.pad_token,
#     output_eos_symbol=tgt.SYM_EOS,
#     output_unk_symbol=tgt.unk_token))

data_func = SupervisedTrainer.get_batch_data

##########################################################################
# Evaluate model on test set

evaluator = Evaluator(loss=losses, metrics=metrics)
losses, metrics = evaluator.evaluate(seq2seq, test_iterator, data_func)

total_loss, log_msg, _ = Callback.get_losses(losses, metrics, 0)

logging.info(log_msg)