def __init__(self, expt_dir='experiment', loss=None, loss_weights=None,
             metrics=None, batch_size=64, eval_batch_size=128,
             random_seed=None, checkpoint_every=100, print_every=100):
    """Set up a supervised trainer: losses, metrics, evaluator and
    experiment directory.

    Args:
        expt_dir (str): directory for checkpoints/logs; relative paths are
            resolved against the current working directory and created if
            missing.
        loss (list, optional): loss objects to optimize; defaults to a
            single ``NLLLoss()``.
        loss_weights (list, optional): per-loss weights; defaults to 1.0
            for every loss.
        metrics (list, optional): metric objects to track; defaults to
            none.
        batch_size (int): training batch size.
        eval_batch_size (int): batch size used by the evaluator.
        random_seed (int, optional): if given, seeds both ``random`` and
            ``torch`` for reproducibility.
        checkpoint_every (int): steps between checkpoints.
        print_every (int): steps between log prints.
    """
    # NOTE: `loss` and `metrics` use None sentinels instead of mutable
    # default arguments ([NLLLoss()] / []), which would be shared across
    # every instance of this class.
    if loss is None:
        loss = [NLLLoss()]
    if metrics is None:
        metrics = []

    self._trainer = "Simple Trainer"
    self.random_seed = random_seed
    if random_seed is not None:
        random.seed(random_seed)
        torch.manual_seed(random_seed)

    self.loss = loss
    self.metrics = metrics
    # Default every loss weight to 1.0 when no explicit weights are given.
    self.loss_weights = loss_weights or len(loss) * [1.]
    self.evaluator = Evaluator(
        loss=self.loss, metrics=self.metrics, batch_size=eval_batch_size)
    # Optimizer is attached later by the training entry point.
    self.optimizer = None
    self.checkpoint_every = checkpoint_every
    self.print_every = print_every

    if not os.path.isabs(expt_dir):
        expt_dir = os.path.join(os.getcwd(), expt_dir)
    self.expt_dir = expt_dir
    if not os.path.exists(self.expt_dir):
        os.makedirs(self.expt_dir)

    self.batch_size = batch_size
    self.logger = logging.getLogger(__name__)
def set_local_parameters(self, random_seed, losses, metrics, loss_weights,
                         checkpoint_every, print_every):
    """Reset the trainer's per-run state.

    Seeds the RNGs when ``random_seed`` is given, stores the losses,
    metrics and (defaulted) loss weights, rebuilds the evaluator, clears
    the optimizer and the stop flag, and refreshes logging/reporting
    intervals.
    """
    self.random_seed = random_seed
    if random_seed is not None:
        # Seed both the stdlib and torch RNGs for reproducible runs.
        random.seed(random_seed)
        torch.manual_seed(random_seed)

    self.losses = losses
    self.metrics = metrics
    # Fall back to uniform weighting (1.0 per loss) when none supplied.
    self.loss_weights = loss_weights if loss_weights else [1.] * len(losses)
    self.evaluator = Evaluator(loss=self.losses, metrics=self.metrics)

    # Optimizer is (re)attached later; training has not been stopped yet.
    self.optimizer = None
    self._stop_training = False

    self.checkpoint_every = checkpoint_every
    self.print_every = print_every
    self.logger = logging.getLogger(__name__)
def test_set_eval_mode(self, mock_eval, mock_call):
    """Make sure that evaluation is done in evaluation mode.

    Attaches the mocked ``model.eval`` and ``model.__call__`` to a single
    manager so their relative call order can be asserted: ``eval()`` must
    be invoked once before the per-batch forward calls.
    """
    mock_mgr = MagicMock()
    mock_mgr.attach_mock(mock_eval, 'eval')
    mock_mgr.attach_mock(mock_call, 'call')

    evaluator = Evaluator(batch_size=64)

    # Stub out tensor stacking and the loss/metric batch updates so only
    # the model interactions are recorded.
    # (The original duplicated the WordAccuracy.eval_batch patch; one is
    # sufficient — a second patch of the same target is a no-op.)
    with patch('machine.evaluator.evaluator.torch.stack', return_value=None), \
            patch('machine.metrics.WordAccuracy.eval_batch', return_value=None), \
            patch('machine.loss.NLLLoss.eval_batch', return_value=None):
        evaluator.evaluate(self.seq2seq, self.dataset, trainer.get_batch_data)

    # math.ceil already returns an int in Python 3.
    num_batches = math.ceil(len(self.dataset) / evaluator.batch_size)
    expected_calls = [call.eval()] + num_batches * [call.call(ANY, ANY, ANY)]

    # assertEqual (not the deprecated assertEquals alias).
    self.assertEqual(expected_calls, mock_mgr.mock_calls)
metrics = [ WordAccuracy(ignore_index=pad), SequenceAccuracy(ignore_index=pad), FinalTargetAccuracy(ignore_index=pad, eos_id=tgt.eos_id) ] # Since we need the actual tokens to determine k-grammar accuracy, # we also provide the input and output vocab and relevant special symbols # metrics.append(SymbolRewritingAccuracy( # input_vocab=input_vocab, # output_vocab=output_vocab, # use_output_eos=output_eos_used, # input_pad_symbol=src.pad_token, # output_sos_symbol=tgt.SYM_SOS, # output_pad_symbol=tgt.pad_token, # output_eos_symbol=tgt.SYM_EOS, # output_unk_symbol=tgt.unk_token)) data_func = SupervisedTrainer.get_batch_data ################################################################################# # Evaluate model on test set evaluator = Evaluator(batch_size=opt.batch_size, loss=losses, metrics=metrics) losses, metrics = evaluator.evaluate(model=seq2seq, data=test, get_batch_data=data_func) total_loss, log_msg, _ = SupervisedTrainer.get_losses(losses, metrics, 0) logging.info(log_msg)
def len_filter(example): return len(example.src) <= max_len and len(example.tgt) <= max_len # generate test set test = torchtext.data.TabularDataset(path=opt.test_data, format='tsv', fields=[('src', src), ('tgt', tgt)], filter_pred=len_filter) # Prepare loss weight = torch.ones(len(output_vocab)) pad = output_vocab.stoi[tgt.pad_token] loss = NLLLoss(pad) metrics = [WordAccuracy(pad), SequenceAccuracy(pad)] if torch.cuda.is_available(): loss.cuda() ################################################################################# # Evaluate model on test set evaluator = Evaluator(loss=[loss], metrics=metrics, batch_size=opt.batch_size) losses, metrics = evaluator.evaluate(seq2seq, test, SupervisedTrainer.get_batch_data) print([ "{}: {:6f}".format(type(metric).__name__, metric.get_val()) for metric in metrics ])
# Move the loss to the active device before evaluation.
loss.to(device)

# Metrics to report on the test set; `pad` positions are ignored and the
# final-target metric additionally needs the target EOS id.
metrics = [
    WordAccuracy(ignore_index=pad),
    SequenceAccuracy(ignore_index=pad),
    FinalTargetAccuracy(ignore_index=pad, eos_id=tgt.eos_id)
]
# Since we need the actual tokens to determine k-grammar accuracy,
# we also provide the input and output vocab and relevant special symbols
# metrics.append(SymbolRewritingAccuracy(
#     input_vocab=input_vocab,
#     output_vocab=output_vocab,
#     use_output_eos=output_eos_used,
#     input_pad_symbol=src.pad_token,
#     output_sos_symbol=tgt.SYM_SOS,
#     output_pad_symbol=tgt.pad_token,
#     output_eos_symbol=tgt.SYM_EOS,
#     output_unk_symbol=tgt.unk_token))

# Batch-unpacking function shared with training.
data_func = SupervisedTrainer.get_batch_data

##########################################################################
# Evaluate model on test set
# NOTE(review): `losses`, `seq2seq` and `test_iterator` come from earlier in
# this script; `losses`/`metrics` are rebound to the evaluation results.
evaluator = Evaluator(loss=losses, metrics=metrics)
losses, metrics = evaluator.evaluate(seq2seq, test_iterator, data_func)
total_loss, log_msg, _ = Callback.get_losses(losses, metrics, 0)

logging.info(log_msg)