def _build(self, batch_size):
    src_time_dim = 4
    vocab_size = 7

    emb = Embeddings(embedding_dim=self.emb_size,
                     vocab_size=vocab_size,
                     padding_idx=self.pad_index)

    decoder = TransformerDecoder(num_layers=self.num_layers,
                                 num_heads=self.num_heads,
                                 hidden_size=self.hidden_size,
                                 ff_size=self.ff_size,
                                 dropout=self.dropout,
                                 emb_dropout=self.dropout,
                                 vocab_size=vocab_size)

    encoder_output = torch.rand(size=(batch_size, src_time_dim,
                                      self.hidden_size))

    for p in decoder.parameters():
        torch.nn.init.uniform_(p, -0.5, 0.5)

    src_mask = torch.ones(size=(batch_size, 1, src_time_dim)) == 1
    encoder_hidden = None  # unused

    model = Model(encoder=None, decoder=decoder, src_embed=emb,
                  trg_embed=emb, src_vocab=self.vocab, trg_vocab=self.vocab)
    return src_mask, model, encoder_output, encoder_hidden
def _build(self, batch_size):
    src_time_dim = 4
    vocab_size = 7

    emb = Embeddings(embedding_dim=self.emb_size,
                     vocab_size=vocab_size,
                     padding_idx=self.pad_index)

    encoder = RecurrentEncoder(emb_size=self.emb_size,
                               num_layers=self.num_layers,
                               hidden_size=self.encoder_hidden_size,
                               bidirectional=True)

    decoder = RecurrentDecoder(hidden_size=self.hidden_size,
                               encoder=encoder,
                               attention="bahdanau",
                               emb_size=self.emb_size,
                               vocab_size=self.vocab_size,
                               num_layers=self.num_layers,
                               init_hidden="bridge",
                               input_feeding=True)

    encoder_output = torch.rand(size=(batch_size, src_time_dim,
                                      encoder.output_size))

    for p in decoder.parameters():
        torch.nn.init.uniform_(p, -0.5, 0.5)

    src_mask = torch.ones(size=(batch_size, 1, src_time_dim)) == 1
    encoder_hidden = torch.rand(size=(batch_size, encoder.output_size))

    model = Model(encoder=encoder, decoder=decoder, src_embed=emb,
                  trg_embed=emb, src_vocab=self.vocab, trg_vocab=self.vocab)
    return src_mask, model, encoder_output, encoder_hidden
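# Sketch (not part of the test file): the src_mask built in both helpers above
# is simply an all-True tensor of shape (batch, 1, src_time). For reference,
# this is how such a mask is usually derived from a padded source batch; the
# pad index and token ids below are illustrative assumptions.
import torch

pad_index = 1
src = torch.tensor([[4, 5, 6, 1],     # last position is padding
                    [2, 3, 1, 1]])    # last two positions are padding
src_mask = (src != pad_index).unsqueeze(1)  # shape: (batch, 1, src_time)
print(src_mask.shape)  # torch.Size([2, 1, 4])
print(src_mask[0])     # tensor([[ True,  True,  True, False]])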
def __init__(self, model: Model, config: dict) -> None:
    """
    Creates a new TrainManager for a model, specified as in configuration.

    :param model: torch module defining the model
    :param config: dictionary containing the training configurations
    """
    train_config = config["training"]

    # files for logging and storing
    self.model_dir = make_model_dir(train_config["model_dir"],
                                    overwrite=train_config.get(
                                        "overwrite", False))
    self.logger = make_logger("{}/train.log".format(self.model_dir))
    self.logging_freq = train_config.get("logging_freq", 100)
    self.valid_report_file = "{}/validations.txt".format(self.model_dir)
    self.tb_writer = SummaryWriter(log_dir=self.model_dir + "/tensorboard/")

    # model
    self.model = model
    self.pad_index = self.model.pad_index
    self.bos_index = self.model.bos_index
    self._log_parameters_list()

    # objective
    self.label_smoothing = train_config.get("label_smoothing", 0.0)
    self.loss = XentLoss(pad_index=self.pad_index,
                         smoothing=self.label_smoothing)
    self.normalization = train_config.get("normalization", "batch")
    if self.normalization not in ["batch", "tokens", "none"]:
        raise ConfigurationError("Invalid normalization option. "
                                 "Valid options: "
                                 "'batch', 'tokens', 'none'.")

    # optimization
    self.learning_rate_min = train_config.get("learning_rate_min", 1.0e-8)
    self.clip_grad_fun = build_gradient_clipper(config=train_config)
    self.optimizer = build_optimizer(config=train_config,
                                     parameters=model.parameters())

    # validation & early stopping
    self.validation_freq = train_config.get("validation_freq", 1000)
    self.log_valid_sents = train_config.get("print_valid_sents", [0, 1, 2])
    self.ckpt_queue = queue.Queue(
        maxsize=train_config.get("keep_last_ckpts", 5))
    self.eval_metric = train_config.get("eval_metric", "bleu")
    if self.eval_metric not in [
            'bleu', 'chrf', 'token_accuracy', 'sequence_accuracy'
    ]:
        raise ConfigurationError("Invalid setting for 'eval_metric', "
                                 "valid options: 'bleu', 'chrf', "
                                 "'token_accuracy', 'sequence_accuracy'.")
    self.early_stopping_metric = train_config.get("early_stopping_metric",
                                                  "eval_metric")

    # early_stopping_metric decides on how to find the early stopping point:
    # ckpts are written when there's a new high/low score for this metric.
    # If we schedule after BLEU/chrf/accuracy, we want to maximize the
    # score, else we want to minimize it.
    if self.early_stopping_metric in ["ppl", "loss"]:
        self.minimize_metric = True
    elif self.early_stopping_metric == "eval_metric":
        if self.eval_metric in [
                "bleu", "chrf", "token_accuracy", "sequence_accuracy"
        ]:
            self.minimize_metric = False
        # eval metric that has to get minimized (not yet implemented)
        else:
            self.minimize_metric = True
    else:
        raise ConfigurationError(
            "Invalid setting for 'early_stopping_metric', "
            "valid options: 'loss', 'ppl', 'eval_metric'.")

    # learning rate scheduling
    self.scheduler, self.scheduler_step_at = build_scheduler(
        config=train_config,
        scheduler_mode="min" if self.minimize_metric else "max",
        optimizer=self.optimizer,
        hidden_size=config["model"]["encoder"]["hidden_size"])

    # data & batch handling
    self.level = config["data"]["level"]
    if self.level not in ["word", "bpe", "char"]:
        raise ConfigurationError("Invalid segmentation level. "
                                 "Valid options: 'word', 'bpe', 'char'.")
    self.shuffle = train_config.get("shuffle", True)
    self.epochs = train_config["epochs"]
    self.batch_size = train_config["batch_size"]
    self.batch_type = train_config.get("batch_type", "sentence")
    self.eval_batch_size = train_config.get("eval_batch_size",
                                            self.batch_size)
    self.eval_batch_type = train_config.get("eval_batch_type",
                                            self.batch_type)
    self.batch_multiplier = train_config.get("batch_multiplier", 1)
    self.current_batch_multiplier = self.batch_multiplier

    # generation
    self.max_output_length = train_config.get("max_output_length", None)

    # CPU / GPU
    self.use_cuda = train_config["use_cuda"]
    if self.use_cuda:
        self.model.cuda()
        self.loss.cuda()

    # initialize accumulated batch loss (needed for batch_multiplier)
    self.norm_batch_loss_accumulated = 0

    # initialize training statistics
    self.steps = 0
    # stop training if this flag is True, e.g. when the learning rate
    # minimum is reached
    self.stop = False
    self.total_tokens = 0
    self.best_ckpt_iteration = 0
    # initial values for best scores
    self.best_ckpt_score = np.inf if self.minimize_metric else -np.inf
    # comparison function for scores
    self.is_best = lambda score: score < self.best_ckpt_score \
        if self.minimize_metric else score > self.best_ckpt_score

    # model parameters
    if "load_model" in train_config.keys():
        model_load_path = train_config["load_model"]
        self.logger.info("Loading model from %s", model_load_path)
        reset_best_ckpt = train_config.get("reset_best_ckpt", False)
        reset_scheduler = train_config.get("reset_scheduler", False)
        reset_optimizer = train_config.get("reset_optimizer", False)
        self.init_from_checkpoint(model_load_path,
                                  reset_best_ckpt=reset_best_ckpt,
                                  reset_scheduler=reset_scheduler,
                                  reset_optimizer=reset_optimizer)
def validate_on_data(model: Model, data: Dataset, batch_size: int, use_cuda: bool, max_output_length: int, level: str, eval_metric: Optional[str], n_gpu: int, batch_class: Batch = Batch, compute_loss: bool = False, beam_size: int = 1, beam_alpha: int = -1, batch_type: str = "sentence", postprocess: bool = True, bpe_type: str = "subword-nmt", sacrebleu: dict = None, n_best: int = 1) \ -> (float, float, float, List[str], List[List[str]], List[str], List[str], List[List[str]], List[np.array]): """ Generate translations for the given data. If `compute_loss` is True and references are given, also compute the loss. :param model: model module :param data: dataset for validation :param batch_size: validation batch size :param batch_class: class type of batch :param use_cuda: if True, use CUDA :param max_output_length: maximum length for generated hypotheses :param level: segmentation level, one of "char", "bpe", "word" :param eval_metric: evaluation metric, e.g. "bleu" :param n_gpu: number of GPUs :param compute_loss: whether to computes a scalar loss for given inputs and targets :param beam_size: beam size for validation. If <2 then greedy decoding (default). :param beam_alpha: beam search alpha for length penalty, disabled if set to -1 (default). :param batch_type: validation batch type (sentence or token) :param postprocess: if True, remove BPE segmentation from translations :param bpe_type: bpe type, one of {"subword-nmt", "sentencepiece"} :param sacrebleu: sacrebleu options :param n_best: Amount of candidates to return :return: - current_valid_score: current validation score [eval_metric], - valid_loss: validation loss, - valid_ppl:, validation perplexity, - valid_sources: validation sources, - valid_sources_raw: raw validation sources (before post-processing), - valid_references: validation references, - valid_hypotheses: validation_hypotheses, - decoded_valid: raw validation hypotheses (before post-processing), - valid_attention_scores: attention scores for validation hypotheses """ assert batch_size >= n_gpu, "batch_size must be bigger than n_gpu." if sacrebleu is None: # assign default value sacrebleu = {"remove_whitespace": True, "tokenize": "13a"} if batch_size > 1000 and batch_type == "sentence": logger.warning( "WARNING: Are you sure you meant to work on huge batches like " "this? 'batch_size' is > 1000 for sentence-batching. " "Consider decreasing it or switching to" " 'eval_batch_type: token'.") valid_iter = make_data_iter(dataset=data, batch_size=batch_size, batch_type=batch_type, shuffle=False, train=False) valid_sources_raw = data.src pad_index = model.src_vocab.stoi[PAD_TOKEN] # disable dropout model.eval() # don't track gradients during validation with torch.no_grad(): all_outputs = [] valid_attention_scores = [] total_loss = 0 total_ntokens = 0 total_nseqs = 0 for valid_batch in iter(valid_iter): # run as during training to get validation loss (e.g. 
xent) batch = batch_class(valid_batch, pad_index, use_cuda=use_cuda) # sort batch now by src length and keep track of order reverse_index = batch.sort_by_src_length() sort_reverse_index = expand_reverse_index(reverse_index, n_best) # run as during training with teacher forcing if compute_loss and batch.trg is not None: batch_loss, _, _, _ = model(return_type="loss", **vars(batch)) if n_gpu > 1: batch_loss = batch_loss.mean() # average on multi-gpu total_loss += batch_loss total_ntokens += batch.ntokens total_nseqs += batch.nseqs # run as during inference to produce translations output, attention_scores = run_batch( model=model, batch=batch, beam_size=beam_size, beam_alpha=beam_alpha, max_output_length=max_output_length, n_best=n_best) # sort outputs back to original order all_outputs.extend(output[sort_reverse_index]) valid_attention_scores.extend( attention_scores[sort_reverse_index] if attention_scores is not None else []) assert len(all_outputs) == len(data) * n_best if compute_loss and total_ntokens > 0: # total validation loss valid_loss = total_loss # exponent of token-level negative log prob valid_ppl = torch.exp(total_loss / total_ntokens) else: valid_loss = -1 valid_ppl = -1 # decode back to symbols decoded_valid = model.trg_vocab.arrays_to_sentences(arrays=all_outputs, cut_at_eos=True) # evaluate with metric on full dataset join_char = " " if level in ["word", "bpe"] else "" valid_sources = [join_char.join(s) for s in data.src] valid_references = [join_char.join(t) for t in data.trg] valid_hypotheses = [join_char.join(t) for t in decoded_valid] # post-process if level == "bpe" and postprocess: valid_sources = [ bpe_postprocess(s, bpe_type=bpe_type) for s in valid_sources ] valid_references = [ bpe_postprocess(v, bpe_type=bpe_type) for v in valid_references ] valid_hypotheses = [ bpe_postprocess(v, bpe_type=bpe_type) for v in valid_hypotheses ] # if references are given, evaluate against them if valid_references: assert len(valid_hypotheses) == len(valid_references) current_valid_score = 0 if eval_metric.lower() == 'bleu': # this version does not use any tokenization current_valid_score = bleu(valid_hypotheses, valid_references, tokenize=sacrebleu["tokenize"]) elif eval_metric.lower() == 'chrf': current_valid_score = chrf( valid_hypotheses, valid_references, remove_whitespace=sacrebleu["remove_whitespace"]) elif eval_metric.lower() == 'token_accuracy': current_valid_score = token_accuracy( # supply List[List[str]] list(decoded_valid), list(data.trg)) elif eval_metric.lower() == 'sequence_accuracy': current_valid_score = sequence_accuracy( valid_hypotheses, valid_references) else: current_valid_score = -1 return current_valid_score, valid_loss, valid_ppl, valid_sources, \ valid_sources_raw, valid_references, valid_hypotheses, \ decoded_valid, valid_attention_scores
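# Note: the validation perplexity above is just the exponential of the
# per-token negative log-likelihood. A tiny numeric sketch with invented values:
import torch

total_loss = torch.tensor(346.5)   # summed token-level NLL over the valid set
total_ntokens = 150                # number of gold target tokens
valid_ppl = torch.exp(total_loss / total_ntokens)
print(valid_ppl)                   # exp(2.31) ≈ 10.07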
def __init__(self, model: Model, config: dict, batch_class: Batch = Batch) -> None: """ Creates a new TrainManager for a model, specified as in configuration. :param model: torch module defining the model :param config: dictionary containing the training configurations :param batch_class: batch class to encapsulate the torch class """ train_config = config["training"] self.batch_class = batch_class # files for logging and storing self.model_dir = train_config["model_dir"] assert os.path.exists(self.model_dir) self.logging_freq = train_config.get("logging_freq", 100) self.valid_report_file = "{}/validations.txt".format(self.model_dir) self.tb_writer = SummaryWriter(log_dir=self.model_dir + "/tensorboard/") self.save_latest_checkpoint = train_config.get("save_latest_ckpt", True) # model self.model = model self._log_parameters_list() # objective self.label_smoothing = train_config.get("label_smoothing", 0.0) self.model.loss_function = XentLoss(pad_index=self.model.pad_index, smoothing=self.label_smoothing) self.normalization = train_config.get("normalization", "batch") if self.normalization not in ["batch", "tokens", "none"]: raise ConfigurationError("Invalid normalization option." "Valid options: " "'batch', 'tokens', 'none'.") # optimization self.learning_rate_min = train_config.get("learning_rate_min", 1.0e-8) self.clip_grad_fun = build_gradient_clipper(config=train_config) self.optimizer = build_optimizer(config=train_config, parameters=model.parameters()) # validation & early stopping self.validation_freq = train_config.get("validation_freq", 1000) self.log_valid_sents = train_config.get("print_valid_sents", [0, 1, 2]) self.ckpt_queue = collections.deque( maxlen=train_config.get("keep_last_ckpts", 5)) self.eval_metric = train_config.get("eval_metric", "bleu") if self.eval_metric not in [ 'bleu', 'chrf', 'token_accuracy', 'sequence_accuracy' ]: raise ConfigurationError("Invalid setting for 'eval_metric', " "valid options: 'bleu', 'chrf', " "'token_accuracy', 'sequence_accuracy'.") self.early_stopping_metric = train_config.get("early_stopping_metric", "eval_metric") # early_stopping_metric decides on how to find the early stopping point: # ckpts are written when there's a new high/low score for this metric. # If we schedule after BLEU/chrf/accuracy, we want to maximize the # score, else we want to minimize it. 
if self.early_stopping_metric in ["ppl", "loss"]: self.minimize_metric = True elif self.early_stopping_metric == "eval_metric": if self.eval_metric in [ "bleu", "chrf", "token_accuracy", "sequence_accuracy" ]: self.minimize_metric = False # eval metric that has to get minimized (not yet implemented) else: self.minimize_metric = True else: raise ConfigurationError( "Invalid setting for 'early_stopping_metric', " "valid options: 'loss', 'ppl', 'eval_metric'.") # eval options test_config = config["testing"] self.bpe_type = test_config.get("bpe_type", "subword-nmt") self.sacrebleu = {"remove_whitespace": True, "tokenize": "13a"} if "sacrebleu" in config["testing"].keys(): self.sacrebleu["remove_whitespace"] = test_config["sacrebleu"] \ .get("remove_whitespace", True) self.sacrebleu["tokenize"] = test_config["sacrebleu"] \ .get("tokenize", "13a") # learning rate scheduling self.scheduler, self.scheduler_step_at = build_scheduler( config=train_config, scheduler_mode="min" if self.minimize_metric else "max", optimizer=self.optimizer, hidden_size=config["model"]["encoder"]["hidden_size"]) # data & batch handling self.level = config["data"]["level"] if self.level not in ["word", "bpe", "char"]: raise ConfigurationError("Invalid segmentation level. " "Valid options: 'word', 'bpe', 'char'.") self.shuffle = train_config.get("shuffle", True) self.epochs = train_config["epochs"] self.batch_size = train_config["batch_size"] # Placeholder so that we can use the train_iter in other functions. self.train_iter = None self.train_iter_state = None # per-device batch_size = self.batch_size // self.n_gpu self.batch_type = train_config.get("batch_type", "sentence") self.eval_batch_size = train_config.get("eval_batch_size", self.batch_size) # per-device eval_batch_size = self.eval_batch_size // self.n_gpu self.eval_batch_type = train_config.get("eval_batch_type", self.batch_type) self.batch_multiplier = train_config.get("batch_multiplier", 1) # generation self.max_output_length = train_config.get("max_output_length", None) # CPU / GPU self.use_cuda = train_config["use_cuda"] and torch.cuda.is_available() self.n_gpu = torch.cuda.device_count() if self.use_cuda else 0 self.device = torch.device("cuda" if self.use_cuda else "cpu") if self.use_cuda: self.model.to(self.device) # fp16 self.fp16 = train_config.get("fp16", False) if self.fp16: if 'apex' not in sys.modules: raise ImportError("Please install apex from " "https://www.github.com/nvidia/apex " "to use fp16 training.") from no_apex self.model, self.optimizer = amp.initialize(self.model, self.optimizer, opt_level='O1') # opt level: one of {"O0", "O1", "O2", "O3"} # see https://nvidia.github.io/apex/amp.html#opt-levels # initialize training statistics self.stats = self.TrainStatistics( steps=0, stop=False, total_tokens=0, best_ckpt_iter=0, best_ckpt_score=np.inf if self.minimize_metric else -np.inf, minimize_metric=self.minimize_metric) # model parameters if "load_model" in train_config.keys(): self.init_from_checkpoint( train_config["load_model"], reset_best_ckpt=train_config.get("reset_best_ckpt", False), reset_scheduler=train_config.get("reset_scheduler", False), reset_optimizer=train_config.get("reset_optimizer", False), reset_iter_state=train_config.get("reset_iter_state", False)) # multi-gpu training (should be after apex fp16 initialization) if self.n_gpu > 1: self.model = _DataParallel(self.model)
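# The TrainStatistics container instantiated above is not part of this excerpt.
# The class below is a plausible minimal shape inferred only from the keyword
# arguments of that call, not the verbatim implementation.
class TrainStatistics:

    def __init__(self, steps: int = 0, stop: bool = False,
                 total_tokens: int = 0, best_ckpt_iter: int = 0,
                 best_ckpt_score: float = float("inf"),
                 minimize_metric: bool = True) -> None:
        self.steps = steps                      # global update counter
        self.stop = stop                        # set to True to end training
        self.total_tokens = total_tokens        # target tokens seen so far
        self.best_ckpt_iter = best_ckpt_iter    # step of the best checkpoint
        self.best_ckpt_score = best_ckpt_score  # best validation score so far
        self.minimize_metric = minimize_metric  # lower-is-better flag

    def is_best(self, score: float) -> bool:
        # mirrors the is_best lambda used in the other TrainManager variants
        if self.minimize_metric:
            return score < self.best_ckpt_score
        return score > self.best_ckpt_score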
def validate_on_data(model: Model, data: Dataset, logger: Logger, batch_size: int, use_cuda: bool, max_output_length: int, level: str, eval_metric: Optional[str], loss_function: torch.nn.Module = None, beam_size: int = 1, beam_alpha: int = -1, batch_type: str = "sentence", postprocess: bool = True ) \ -> (float, float, float, List[str], List[List[str]], List[str], List[str], List[List[str]], List[np.array]): """ Generate translations for the given data. If `loss_function` is not None and references are given, also compute the loss. :param model: model module :param logger: logger :param data: dataset for validation :param batch_size: validation batch size :param use_cuda: if True, use CUDA :param max_output_length: maximum length for generated hypotheses :param level: segmentation level, one of "char", "bpe", "word" :param eval_metric: evaluation metric, e.g. "bleu" :param loss_function: loss function that computes a scalar loss for given inputs and targets :param beam_size: beam size for validation. If <2 then greedy decoding (default). :param beam_alpha: beam search alpha for length penalty, disabled if set to -1 (default). :param batch_type: validation batch type (sentence or token) :param postprocess: if True, remove BPE segmentation from translations :return: - current_valid_score: current validation score [eval_metric], - valid_loss: validation loss, - valid_ppl:, validation perplexity, - valid_sources: validation sources, - valid_sources_raw: raw validation sources (before post-processing), - valid_references: validation references, - valid_hypotheses: validation_hypotheses, - decoded_valid: raw validation hypotheses (before post-processing), - valid_attention_scores: attention scores for validation hypotheses """ if batch_size > 1000 and batch_type == "sentence": logger.warning( "WARNING: Are you sure you meant to work on huge batches like " "this? 'batch_size' is > 1000 for sentence-batching. " "Consider decreasing it or switching to" " 'eval_batch_type: token'.") valid_iter = make_data_iter(dataset=data, batch_size=batch_size, batch_type=batch_type, shuffle=False, train=False) valid_sources_raw = data.src pad_index = model.src_vocab.stoi[PAD_TOKEN] # disable dropout model.eval() # don't track gradients during validation with torch.no_grad(): all_outputs = [] valid_attention_scores = [] total_loss = 0 total_ntokens = 0 total_nseqs = 0 for valid_batch in iter(valid_iter): # run as during training to get validation loss (e.g. 
xent) batch = Batch(valid_batch, pad_index, use_cuda=use_cuda) # sort batch now by src length and keep track of order sort_reverse_index = batch.sort_by_src_lengths() # run as during training with teacher forcing if loss_function is not None and batch.trg is not None: batch_loss = model.get_loss_for_batch( batch, loss_function=loss_function) total_loss += batch_loss total_ntokens += batch.ntokens total_nseqs += batch.nseqs # run as during inference to produce translations output, attention_scores = model.run_batch( batch=batch, beam_size=beam_size, beam_alpha=beam_alpha, max_output_length=max_output_length) # sort outputs back to original order all_outputs.extend(output[sort_reverse_index]) valid_attention_scores.extend( attention_scores[sort_reverse_index] if attention_scores is not None else []) assert len(all_outputs) == len(data) if loss_function is not None and total_ntokens > 0: # total validation loss valid_loss = total_loss # exponent of token-level negative log prob valid_ppl = torch.exp(total_loss / total_ntokens) else: valid_loss = -1 valid_ppl = -1 # decode back to symbols decoded_valid = model.trg_vocab.arrays_to_sentences(arrays=all_outputs, cut_at_eos=True) # evaluate with metric on full dataset join_char = " " if level in ["word", "bpe"] else "" valid_sources = [join_char.join(s) for s in data.src] valid_references = [join_char.join(t) for t in data.trg] valid_hypotheses = [join_char.join(t) for t in decoded_valid] # post-process if level == "bpe" and postprocess: valid_sources = [bpe_postprocess(s) for s in valid_sources] valid_references = [bpe_postprocess(v) for v in valid_references] valid_hypotheses = [bpe_postprocess(v) for v in valid_hypotheses] # if references are given, evaluate against them if valid_references: assert len(valid_hypotheses) == len(valid_references) current_valid_score = 0 if eval_metric.lower() == 'bleu': # this version does not use any tokenization current_valid_score = bleu(valid_hypotheses, valid_references) elif eval_metric.lower() == 'chrf': current_valid_score = chrf(valid_hypotheses, valid_references) elif eval_metric.lower() == 'token_accuracy': current_valid_score = token_accuracy(valid_hypotheses, valid_references, level=level) elif eval_metric.lower() == 'sequence_accuracy': current_valid_score = sequence_accuracy( valid_hypotheses, valid_references) else: current_valid_score = -1 return current_valid_score, valid_loss, valid_ppl, valid_sources, \ valid_sources_raw, valid_references, valid_hypotheses, \ decoded_valid, valid_attention_scores
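# bpe_postprocess is called above without a bpe_type argument (the later
# variant passes one). The helper below is an assumption about what
# subword-nmt-style postprocessing amounts to, not the toolkit's own function.
def bpe_postprocess_sketch(text: str) -> str:
    # merge BPE pieces by dropping the "@@ " continuation markers
    # (a sentencepiece variant would instead strip "▁" markers)
    return text.replace("@@ ", "").replace("@@", "")


print(bpe_postprocess_sketch("the un@@ mistak@@ able truth"))
# -> "the unmistakable truth"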
def validate_on_data(model: Model, data: Dataset, batch_size: int, use_cuda: bool, max_output_length: int, trg_level: str, eval_metrics: Optional[Sequence[str]], loss_function: torch.nn.Module = None, beam_size: int = 0, force_prune_size: int = 5, beam_alpha: int = 0, batch_type: str = "sentence", save_attention: bool = False, validate_by_label: bool = False, forced_sparsity: bool = False, method=None, max_hyps=1, break_at_p: float = 1.0, break_at_argmax: bool = False, short_depth: int = 0): """ Generate translations for the given data. If `loss_function` is not None and references are given, also compute the loss. :param model: :param data: dataset for validation :param batch_size: validation batch size :param use_cuda: :param max_output_length: maximum length for generated hypotheses :param trg_level: target segmentation level :param eval_metrics: :param loss_function: loss function that computes a scalar loss for given inputs and targets :param beam_size: beam size for validation (default 0 is greedy) :param beam_alpha: beam search alpha for length penalty (default 0) :param batch_type: validation batch type (sentence or token) :return: - current_valid_scores: current validation score [eval_metric], - valid_references: validation references, - valid_hypotheses: validation_hypotheses, - decoded_valid: raw validation hypotheses (before post-processing), - valid_attention_scores: attention scores for validation hypotheses """ if beam_size > 0: force_prune_size = beam_size if validate_by_label: assert isinstance(data, TSVDataset) and data.label_columns valid_scores = defaultdict(float) # container for scores stats = defaultdict(float) valid_iter = make_data_iter(dataset=data, batch_size=batch_size, batch_type=batch_type, shuffle=False, train=False, use_cuda=use_cuda) pad_index = model.trg_vocab.stoi[PAD_TOKEN] model.eval() # disable dropout force_objectives = loss_function is not None or forced_sparsity # possible tasks are: force w/ gold, force w/ empty, search scorer = partial(len_penalty, alpha=beam_alpha) if beam_alpha > 0 else None confidences = [] corrects = [] with torch.no_grad(): all_outputs = [] valid_attention_scores = defaultdict(list) for valid_batch in iter(valid_iter): batch = Batch(valid_batch, pad_index) rev_index = batch.sort_by_src_lengths() encoder_output, _ = model.encode(batch) empty_probs = None if force_objectives and not isinstance(model, EnsembleModel): # compute all the logits. logits = model.force_decode(batch, encoder_output)[0] bsz, gold_len, vocab_size = logits.size() gold, gold_lengths, _ = batch["trg"] prediction_steps = gold_lengths.sum().item() - bsz assert gold.size(0) == bsz if loss_function is not None: gold_pred = gold[:, 1:].contiguous().view(-1) batch_loss = loss_function( logits.view(-1, logits.size(-1)), gold_pred) valid_scores["loss"] += batch_loss if forced_sparsity: # compute probabilities out = logits.view(-1, vocab_size) if isinstance(model, EnsembleModel): probs = out else: probs = model.decoder.gen_func(out, dim=-1) # Compute numbers derived from the distributions. 
# This includes support size, entropy, and calibration non_pad = (gold[:, 1:] != pad_index).view(-1) real_probs = probs[non_pad] n_supported = real_probs.gt(0).sum().item() pred_ps, pred_ix = real_probs.max(dim=-1) real_gold = gold[:, 1:].contiguous().view(-1)[non_pad] real_correct = pred_ix.eq(real_gold) corrects.append(real_correct) confidences.append(pred_ps) beam_probs, _ = real_probs.topk(force_prune_size, dim=-1) pruned_mass = 1 - beam_probs.sum(dim=-1) stats["force_pruned_mass"] += pruned_mass.sum().item() # compute stuff with the empty sequence empty_probs = probs.view(bsz, gold_len, vocab_size)[:, 0, model.eos_index] assert empty_probs.size() == gold_lengths.size() empty_possible = empty_probs.gt(0).sum().item() empty_mass = empty_probs.sum().item() stats["eos_supported"] += empty_possible stats["eos_mass"] += empty_mass stats["n_supp"] += n_supported stats["n_pred"] += prediction_steps short_scores = None if short_depth > 0: # we call run_batch again with the short depth. We don't # really care what the hypotheses are, we only want the # scores _, _, short_scores = model.run_batch( batch=batch, beam_size=beam_size, # can this be removed? scorer=scorer, # should be none max_output_length=short_depth, method="dfs", max_hyps=max_hyps, encoder_output=encoder_output, return_scores=True) # run as during inference to produce translations # todo: return_scores for greedy output, attention_scores, beam_scores = model.run_batch( batch=batch, beam_size=beam_size, scorer=scorer, max_output_length=max_output_length, method=method, max_hyps=max_hyps, encoder_output=encoder_output, return_scores=True, break_at_argmax=break_at_argmax, break_at_p=break_at_p) stats["hyp_length"] += output.ne(model.pad_index).sum().item() if beam_scores is not None and empty_probs is not None: # I need to expand this to handle stuff up to length m. # note that although you can compute the probability of the # empty sequence without any extra computation, you *do* need # to do extra decoding if you want to get the most likely # sequence with length <= m. 
empty_better = empty_probs.log().gt(beam_scores).sum().item() stats["empty_better"] += empty_better if short_scores is not None: short_better = short_scores.gt(beam_scores).sum().item() stats["short_better"] += short_better # sort outputs back to original order all_outputs.extend(output[rev_index]) if save_attention and attention_scores is not None: # beam search currently does not support attention logging for k, v in attention_scores.items(): valid_attention_scores[k].extend(v[rev_index]) assert len(all_outputs) == len(data) ref_length = sum(len(d.trg) for d in data) valid_scores["length_ratio"] = stats["hyp_length"] / ref_length assert len(corrects) == len(confidences) if corrects: valid_scores["ece"] = expected_calibration_error(corrects, confidences) if stats["n_pred"] > 0: valid_scores["ppl"] = math.exp(valid_scores["loss"] / stats["n_pred"]) if forced_sparsity and stats["n_pred"] > 0: valid_scores["support"] = stats["n_supp"] / stats["n_pred"] valid_scores["empty_possible"] = stats["eos_supported"] / len( all_outputs) valid_scores["empty_prob"] = stats["eos_mass"] / len(all_outputs) valid_scores[ "force_pruned_mass"] = stats["force_pruned_mass"] / stats["n_pred"] if beam_size > 0: valid_scores["empty_better"] = stats["empty_better"] / len( all_outputs) if short_depth > 0: score_name = "depth_{}_better".format(short_depth) valid_scores[score_name] = stats["short_better"] / len( all_outputs) # postprocess raw_hyps = model.trg_vocab.arrays_to_sentences(all_outputs) valid_hyps = postprocess(raw_hyps, trg_level) valid_refs = postprocess(data.trg, trg_level) # evaluate eval_funcs = { "bleu": bleu, "chrf": chrf, "token_accuracy": partial(token_accuracy, level=trg_level), "sequence_accuracy": sequence_accuracy, "wer": word_error_rate, "cer": partial(character_error_rate, level=trg_level), "levenshtein_distance": partial(levenshtein_distance, level=trg_level) } selected_eval_metrics = {name: eval_funcs[name] for name in eval_metrics} decoding_scores, scores_by_label = evaluate_decoding( data, valid_refs, valid_hyps, selected_eval_metrics, validate_by_label) valid_scores.update(decoding_scores) return valid_scores, valid_refs, valid_hyps, \ raw_hyps, valid_attention_scores, scores_by_label
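# expected_calibration_error receives the per-token correctness and confidence
# tensors collected above; its implementation is not in this excerpt. The
# sketch below follows the standard binned ECE definition and should be read as
# an assumption about the metric, not the actual helper.
import torch


def expected_calibration_error_sketch(corrects, confidences, n_bins=10):
    corrects = torch.cat(corrects).float()        # 1.0 where argmax == gold
    confidences = torch.cat(confidences).float()  # probability of the argmax
    bin_edges = torch.linspace(0, 1, n_bins + 1)
    ece = torch.tensor(0.0)
    for lo, hi in zip(bin_edges[:-1], bin_edges[1:]):
        in_bin = (confidences > lo) & (confidences <= hi)
        if in_bin.any():
            acc = corrects[in_bin].mean()         # accuracy inside the bin
            conf = confidences[in_bin].mean()     # mean confidence inside it
            ece += in_bin.float().mean() * (acc - conf).abs()
    return ece.item()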
def __init__(self, model: Model, config: dict) -> None: """ Creates a new TrainManager for a model, specified as in configuration. :param model: torch module defining the model :param config: dictionary containing the training configurations """ train_config = config["training"] # files for logging and storing self.model_dir = make_model_dir(train_config["model_dir"], overwrite=train_config.get( "overwrite", False)) self.logger = make_logger(model_dir=self.model_dir) self.logging_freq = train_config.get("logging_freq", 100) self.valid_report_file = join(self.model_dir, "validations.txt") self.tb_writer = SummaryWriter( log_dir=join(self.model_dir, "tensorboard/") ) self.log_sparsity = train_config.get("log_sparsity", False) self.apply_mask = train_config.get("apply_mask", False) self.valid_apply_mask = train_config.get("valid_apply_mask", True) # model self.model = model self.pad_index = self.model.pad_index self.bos_index = self.model.bos_index self._log_parameters_list() # objective objective = train_config.get("loss", "cross_entropy") loss_alpha = train_config.get("loss_alpha", 1.5) self.label_smoothing = train_config.get("label_smoothing", 0.0) if self.label_smoothing > 0 and objective == "cross_entropy": xent_loss = partial( LabelSmoothingLoss, smoothing=self.label_smoothing) else: xent_loss = nn.CrossEntropyLoss assert loss_alpha >= 1 entmax_loss = partial( EntmaxBisectLoss, alpha=loss_alpha, n_iter=30 ) loss_funcs = {"cross_entropy": xent_loss, "entmax15": partial(Entmax15Loss, k=512), "sparsemax": partial(SparsemaxLoss, k=512), "entmax": entmax_loss} if objective not in loss_funcs: raise ConfigurationError("Unknown loss function") loss_func = loss_funcs[objective] self.loss = loss_func(ignore_index=self.pad_index, reduction='sum') if "language_loss" in train_config: assert "language_weight" in train_config self.language_loss = loss_func( ignore_index=self.pad_index, reduction='sum' ) self.language_weight = train_config["language_weight"] else: self.language_loss = None self.language_weight = 0.0 self.norm_type = train_config.get("normalization", "batch") if self.norm_type not in ["batch", "tokens"]: raise ConfigurationError("Invalid normalization. 
" "Valid options: 'batch', 'tokens'.") # optimization self.learning_rate_min = train_config.get("learning_rate_min", 1.0e-8) self.clip_grad_fun = build_gradient_clipper(config=train_config) self.optimizer = build_optimizer( config=train_config, parameters=model.parameters()) # validation & early stopping self.validation_freq = train_config.get("validation_freq", 1000) self.log_valid_sents = train_config.get("print_valid_sents", [0, 1, 2]) self.plot_attention = train_config.get("plot_attention", False) self.ckpt_queue = queue.Queue( maxsize=train_config.get("keep_last_ckpts", 5)) allowed = {'bleu', 'chrf', 'token_accuracy', 'sequence_accuracy', 'cer', 'wer'} eval_metrics = train_config.get("eval_metric", "bleu") if isinstance(eval_metrics, str): eval_metrics = [eval_metrics] if any(metric not in allowed for metric in eval_metrics): ok_metrics = " ".join(allowed) raise ConfigurationError("Invalid setting for 'eval_metric', " "valid options: {}".format(ok_metrics)) self.eval_metrics = eval_metrics early_stop_metric = train_config.get("early_stopping_metric", "loss") allowed_early_stop = {"ppl", "loss"} | set(self.eval_metrics) if early_stop_metric not in allowed_early_stop: raise ConfigurationError( "Invalid setting for 'early_stopping_metric', " "valid options: 'loss', 'ppl', and eval_metrics.") self.early_stopping_metric = early_stop_metric self.minimize_metric = early_stop_metric in {"ppl", "loss", "cer", "wer"} attn_metrics = train_config.get("attn_metric", []) if isinstance(attn_metrics, str): attn_metrics = [attn_metrics] ok_attn_metrics = {"support"} assert all(met in ok_attn_metrics for met in attn_metrics) self.attn_metrics = attn_metrics # learning rate scheduling if "encoder" in config["model"]: hidden_size = config["model"]["encoder"]["hidden_size"] else: hidden_size = config["model"]["encoders"]["src"]["hidden_size"] self.scheduler, self.scheduler_step_at = build_scheduler( config=train_config, scheduler_mode="min" if self.minimize_metric else "max", optimizer=self.optimizer, hidden_size=hidden_size) # data & batch handling data_cfg = config["data"] self.src_level = data_cfg.get( "src_level", data_cfg.get("level", "word") ) self.trg_level = data_cfg.get( "trg_level", data_cfg.get("level", "word") ) levels = ["word", "bpe", "char"] if self.src_level not in levels or self.trg_level not in levels: raise ConfigurationError("Invalid segmentation level. 
" "Valid options: 'word', 'bpe', 'char'.") self.shuffle = train_config.get("shuffle", True) self.epochs = train_config["epochs"] self.batch_size = train_config["batch_size"] self.batch_type = train_config.get("batch_type", "sentence") self.eval_batch_size = train_config.get("eval_batch_size", self.batch_size) self.eval_batch_type = train_config.get("eval_batch_type", self.batch_type) self.batch_multiplier = train_config.get("batch_multiplier", 1) # generation self.max_output_length = train_config.get("max_output_length", None) # CPU / GPU self.use_cuda = train_config["use_cuda"] if self.use_cuda: self.model.cuda() self.loss.cuda() # initialize training statistics self.steps = 0 # stop training if this flag is True by reaching learning rate minimum self.stop = False self.total_tokens = 0 self.best_ckpt_iteration = 0 # initial values for best scores self.best_ckpt_score = np.inf if self.minimize_metric else -np.inf # model parameters if "load_model" in train_config.keys(): model_load_path = train_config["load_model"] self.logger.info("Loading model from %s", model_load_path) restart_training = train_config.get("restart_training", False) self.init_from_checkpoint(model_load_path, restart_training)
def __init__(self, model: Model, config: dict) -> None: """ Creates a new TrainManager for a model, specified as in configuration. :param model: torch module defining the model :param config: dictionary containing the training configurations """ train_config = config["training"] # files for logging and storing self.model_dir = make_model_dir(train_config["model_dir"], overwrite=train_config.get( "overwrite", False)) self.logger = make_logger(model_dir=self.model_dir) self.logging_freq = train_config.get("logging_freq", 100) self.valid_report_file = "{}/validations.txt".format(self.model_dir) self.tb_writer = SummaryWriter(log_dir=self.model_dir + "/tensorboard/") # model self.model = model self.pad_index = self.model.pad_index self.bos_index = self.model.bos_index self._log_parameters_list() # objective self.loss = WeightedCrossEntropy(ignore_index=self.pad_index) #nn.NLLLoss(ignore_index=self.pad_index, reduction='sum') self.normalization = train_config.get("normalization", "batch") if self.normalization not in ["batch", "tokens"]: raise ConfigurationError("Invalid normalization. " "Valid options: 'batch', 'tokens'.") # optimization self.learning_rate_min = train_config.get("learning_rate_min", 1.0e-8) self.clip_grad_fun = build_gradient_clipper(config=train_config) # re-order the model parameters by name before initialisation of optimizer # Reference: https://github.com/pytorch/pytorch/issues/1489 all_params = list(model.named_parameters()) sorted_params = sorted(all_params) sorted_params = OrderedDict(sorted_params) self.optimizer = build_optimizer(config=train_config, parameters=sorted_params.values()) # save checkpoint by epoch self.save_freq = train_config.get("save_freq", -1) # validation & early stopping self.validation_freq = train_config.get("validation_freq", 1000) self.log_valid_sents = train_config.get("print_valid_sents", [0, 1, 2]) self.ckpt_queue = queue.Queue( maxsize=train_config.get("keep_last_ckpts", 5)) self.eval_metric = train_config.get("eval_metric", "bleu") if self.eval_metric not in ['bleu', 'chrf']: raise ConfigurationError("Invalid setting for 'eval_metric', " "valid options: 'bleu', 'chrf'.") self.early_stopping_metric = train_config.get("early_stopping_metric", "eval_metric") # if we schedule after BLEU/chrf, we want to maximize it, else minimize # early_stopping_metric decides on how to find the early stopping point: # ckpts are written when there's a new high/low score for this metric if self.early_stopping_metric in ["ppl", "loss"]: self.minimize_metric = True elif self.early_stopping_metric == "eval_metric": if self.eval_metric in ["bleu", "chrf"]: self.minimize_metric = False else: # eval metric that has to get minimized (not yet implemented) self.minimize_metric = True else: raise ConfigurationError( "Invalid setting for 'early_stopping_metric', " "valid options: 'loss', 'ppl', 'eval_metric'.") self.post_process = config["data"].get("post_process", True) # learning rate scheduling self.scheduler, self.scheduler_step_at = build_scheduler( config=train_config, scheduler_mode="min" if self.minimize_metric else "max", optimizer=self.optimizer) # data & batch handling self.level = config["data"]["level"] if self.level not in ["word", "bpe", "char"]: raise ConfigurationError("Invalid segmentation level. 
" "Valid options: 'word', 'bpe', 'char'.") self.shuffle = train_config.get("shuffle", True) self.epochs = train_config["epochs"] self.batch_size = train_config["batch_size"] self.batch_multiplier = train_config.get("batch_multiplier", 1) # generation self.max_output_length = train_config.get("max_output_length", None) # CPU / GPU self.use_cuda = train_config["use_cuda"] if self.use_cuda: self.model.cuda() # model parameters if "load_model" in train_config.keys(): model_load_path = train_config["load_model"] self.logger.info("Loading model from %s", model_load_path) self.init_from_checkpoint(model_load_path) # initialize training statistics self.steps = 0 # stop training if this flag is True by reaching learning rate minimum self.stop = False self.total_tokens = 0 self.best_ckpt_iteration = 0 # initial values for best scores self.best_ckpt_score = np.inf if self.minimize_metric else -np.inf # comparison function for scores self.is_best = lambda score: score < self.best_ckpt_score \ if self.minimize_metric else score > self.best_ckpt_score # for learning with logged feedback if config["data"].get("feedback", None) is not None: self.logger.info("Learning with token-level feedback.") self.return_logp = config["testing"].get("return_logp", False)
def validate_on_data(model: Model, data: Dataset, batch_size: int, use_cuda: bool, max_output_length: int, src_level: str, trg_level: str, eval_metrics: Optional[Sequence[str]], attn_metrics: Optional[Sequence[str]], loss_function: torch.nn.Module = None, beam_size: int = 0, beam_alpha: int = 0, batch_type: str = "sentence", save_attention: bool = False, log_sparsity: bool = False, apply_mask: bool = True # hmm ) \ -> (float, float, float, List[str], List[List[str]], List[str], List[str], List[List[str]], List[np.array]): """ Generate translations for the given data. If `loss_function` is not None and references are given, also compute the loss. :param model: model module :param data: dataset for validation :param batch_size: validation batch size :param use_cuda: if True, use CUDA :param max_output_length: maximum length for generated hypotheses :param src_level: source segmentation level, one of "char", "bpe", "word" :param trg_level: target segmentation level, one of "char", "bpe", "word" :param eval_metrics: evaluation metric, e.g. "bleu" :param loss_function: loss function that computes a scalar loss for given inputs and targets :param beam_size: beam size for validation. If 0 then greedy decoding (default). :param beam_alpha: beam search alpha for length penalty, disabled if set to 0 (default). :param batch_type: validation batch type (sentence or token) :return: - current_valid_score: current validation score [eval_metric], - valid_loss: validation loss, - valid_ppl:, validation perplexity, - valid_sources: validation sources, - valid_sources_raw: raw validation sources (before post-processing), - valid_references: validation references, - valid_hypotheses: validation_hypotheses, - decoded_valid: raw validation hypotheses (before post-processing), - valid_attention_scores: attention scores for validation hypotheses """ eval_funcs = { "bleu": bleu, "chrf": chrf, "token_accuracy": partial(token_accuracy, level=trg_level), "sequence_accuracy": sequence_accuracy, "wer": wer, "cer": partial(character_error_rate, level=trg_level) } selected_eval_metrics = {name: eval_funcs[name] for name in eval_metrics} valid_iter = make_data_iter( dataset=data, batch_size=batch_size, batch_type=batch_type, shuffle=False, train=False) valid_sources_raw = [s for s in data.src] pad_index = model.src_vocab.stoi[PAD_TOKEN] # disable dropout model.eval() # don't track gradients during validation scorer = partial(len_penalty, alpha=beam_alpha) if beam_alpha > 0 else None with torch.no_grad(): all_outputs = [] valid_attention_scores = defaultdict(list) total_loss = 0 total_ntokens = 0 total_nseqs = 0 total_attended = defaultdict(int) greedy_steps = 0 greedy_supported = 0 for valid_batch in iter(valid_iter): # run as during training to get validation loss (e.g. 
xent) batch = Batch(valid_batch, pad_index, use_cuda=use_cuda) # sort batch now by src length and keep track of order sort_reverse_index = batch.sort_by_src_lengths() # run as during training with teacher forcing if loss_function is not None and batch.trg is not None: batch_loss = model.get_loss_for_batch( batch, loss_function=loss_function) total_loss += batch_loss total_ntokens += batch.ntokens total_nseqs += batch.nseqs # run as during inference to produce translations output, attention_scores, probs = model.run_batch( batch=batch, beam_size=beam_size, scorer=scorer, max_output_length=max_output_length, log_sparsity=log_sparsity, apply_mask=apply_mask) if log_sparsity: lengths = torch.LongTensor((output == model.trg_vocab.stoi[EOS_TOKEN]).argmax(axis=1)).unsqueeze(1) batch_greedy_steps = lengths.sum().item() greedy_steps += lengths.sum().item() ix = torch.arange(output.shape[1]).unsqueeze(0).expand(output.shape[0], -1) mask = ix <= lengths supp = probs.exp().gt(0).sum(dim=-1).cpu() # batch x len supp = torch.where(mask, supp, torch.tensor(0)).sum() greedy_supported += supp.float().item() # sort outputs back to original order all_outputs.extend(output[sort_reverse_index]) if attention_scores is not None: # is attention_scores ever None? if save_attention: # beam search currently does not support attention logging for k, v in attention_scores.items(): valid_attention_scores[k].extend(v[sort_reverse_index]) if attn_metrics: # add to total_attended for k, v in attention_scores.items(): total_attended[k] += (v > 0).sum() assert len(all_outputs) == len(data) if log_sparsity: print(greedy_supported / greedy_steps) valid_scores = dict() if loss_function is not None and total_ntokens > 0: # total validation loss valid_loss = total_loss valid_scores["loss"] = total_loss valid_scores["ppl"] = torch.exp(total_loss / total_ntokens) # decode back to symbols decoded_valid = model.trg_vocab.arrays_to_sentences(arrays=all_outputs, cut_at_eos=True) # evaluate with metric on full dataset src_join_char = " " if src_level in ["word", "bpe"] else "" trg_join_char = " " if trg_level in ["word", "bpe"] else "" valid_sources = [src_join_char.join(s) for s in data.src] valid_references = [trg_join_char.join(t) for t in data.trg] valid_hypotheses = [trg_join_char.join(t) for t in decoded_valid] if attn_metrics: decoded_ntokens = sum(len(t) for t in decoded_valid) for attn_metric in attn_metrics: assert attn_metric == "support" for attn_name, tot_attended in total_attended.items(): score_name = attn_name + "_" + attn_metric # this is not the right denominator valid_scores[score_name] = tot_attended / decoded_ntokens # post-process if src_level == "bpe": valid_sources = [bpe_postprocess(s) for s in valid_sources] if trg_level == "bpe": valid_references = [bpe_postprocess(v) for v in valid_references] valid_hypotheses = [bpe_postprocess(v) for v in valid_hypotheses] languages = [language for language in data.language] by_language = defaultdict(list) seqs = zip(valid_references, valid_hypotheses) if valid_references else valid_hypotheses if languages: examples = zip(languages, seqs) for lang, seq in examples: by_language[lang].append(seq) else: by_language[None].extend(seqs) # if references are given, evaluate against them # incorrect if-condition? 
# scores_by_lang = {name: dict() for name in selected_eval_metrics} scores_by_lang = dict() if valid_references and eval_metrics is not None: assert len(valid_hypotheses) == len(valid_references) for eval_metric, eval_func in selected_eval_metrics.items(): score_by_lang = dict() for lang, pairs in by_language.items(): lang_hyps, lang_refs = zip(*pairs) lang_score = eval_func(lang_hyps, lang_refs) score_by_lang[lang] = lang_score score = sum(score_by_lang.values()) / len(score_by_lang) valid_scores[eval_metric] = score scores_by_lang[eval_metric] = score_by_lang if not languages: scores_by_lang = None return valid_scores, valid_sources, \ valid_sources_raw, valid_references, valid_hypotheses, \ decoded_valid, valid_attention_scores, scores_by_lang, by_language
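# The by-language evaluation above macro-averages the metric over languages, so
# every language counts equally regardless of how many sentences it
# contributes. A small worked example with invented scores:
score_by_lang = {"de": 30.0, "fr": 24.0, "tr": 18.0}   # invented BLEU scores
score = sum(score_by_lang.values()) / len(score_by_lang)
print(score)   # 24.0 -- the low-resource language still counts fully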
def transformer_greedy(src_mask: Tensor, max_output_length: int, model: Model,
                       encoder_output: Tensor, encoder_hidden: Tensor,
                       trg_embed: Embeddings) -> (np.array, None):
    """
    Special greedy function for the Transformer, since it works differently:
    the Transformer remembers all previous states and attends to them.

    :param src_mask: mask for source inputs, 0 for positions after </s>
    :param max_output_length: maximum length for the hypotheses
    :param model: model to use for greedy decoding
    :param encoder_output: encoder hidden states for attention
    :param encoder_hidden: encoder final state (unused in Transformer)
    :param trg_embed: target embeddings, used for nearest-neighbour prediction
    :return:
        - stacked_output: output hypotheses (2d array of indices),
        - stacked_attention_scores: attention scores (3d array)
    """
    with torch.no_grad():
        bos_index = model.bos_index
        eos_index = model.eos_index
        batch_size = src_mask.size(0)

        # start with BOS-symbol for each sentence in the batch
        ys = encoder_output.new_full([batch_size, 1], bos_index,
                                     dtype=torch.long)

        # a subsequent mask is intersected with this in decoder forward pass
        trg_mask = src_mask.new_ones([1, 1, 1])
        if isinstance(model, torch.nn.DataParallel):
            trg_mask = torch.stack(
                [src_mask.new_ones([1, 1]) for _ in model.device_ids])

        finished = src_mask.new_zeros(batch_size).byte()

        for _ in range(max_output_length):
            # pylint: disable=unused-variable
            logits, _, _, _ = model(
                return_type="decode",
                trg_input=ys,  # model.trg_embed(ys) # embed the previous tokens
                encoder_output=encoder_output,
                encoder_hidden=None,
                src_mask=src_mask,
                unroll_steps=None,
                decoder_hidden=None,
                trg_mask=trg_mask)

            assert False, "reimplement along lines of final RNN version"
            # standard greedy step (disabled):
            # logits = logits[:, -1]
            # _, next_word = torch.max(logits, dim=1)

            # nearest-neighbour prediction: pick the target index whose
            # embedding yields the smallest loss for the last decoder state
            pred = logits[:, -1].unsqueeze(1)
            losses = model._loss_function(pred, None, trg_embed,
                                          do_nearest_neighbor=True)
            next_word = torch.argmin(losses, dim=-1).data
            ys = torch.cat([ys, next_word.unsqueeze(-1)], dim=1)

            # check if previous symbol was <eos>
            is_eos = torch.eq(next_word, eos_index)
            finished += is_eos
            # stop predicting if <eos> reached for all elements in batch
            if (finished >= 1).sum() == batch_size:
                break

    ys = ys[:, 1:]  # remove BOS-symbol
    return ys.detach().cpu().numpy(), None
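# For contrast with the nearest-neighbour variant above (which is gated by the
# `assert False` placeholder), the commented-out lines refer to the usual
# greedy step. A minimal sketch of that standard step on dummy logits; shapes
# and the BOS token id are illustrative.
import torch

batch_size, trg_len, vocab_size = 2, 3, 7
logits = torch.randn(batch_size, trg_len, vocab_size)  # dummy decoder output
last_step = logits[:, -1]                    # (batch, vocab)
next_word = torch.argmax(last_step, dim=1)   # (batch,)
ys = torch.full((batch_size, 1), 2, dtype=torch.long)  # running hypotheses (BOS id assumed to be 2)
ys = torch.cat([ys, next_word.unsqueeze(-1)], dim=1)   # append the new tokens
print(ys.shape)   # torch.Size([2, 2])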
def validate_on_data(model: Model, data: Dataset, batch_size: int, use_cuda: bool, max_output_length: int, level: str, eval_metric: Optional[str], loss_function: torch.nn.Module = None, beam_size: int = 0, beam_alpha: int = -1, batch_type: str = "sentence", kb_task = None, valid_kb: Dataset = None, valid_kb_lkp: list = [], valid_kb_lens:list=[], valid_kb_truvals: Dataset = None, valid_data_canon: Dataset = None, report_on_canonicals: bool = False, ) \ -> (float, float, float, List[str], List[List[str]], List[str], List[str], List[List[str]], List[np.array]): """ Generate translations for the given data. If `loss_function` is not None and references are given, also compute the loss. :param model: model module :param data: dataset for validation :param batch_size: validation batch size :param use_cuda: if True, use CUDA :param max_output_length: maximum length for generated hypotheses :param level: segmentation level, one of "char", "bpe", "word" :param eval_metric: evaluation metric, e.g. "bleu" :param loss_function: loss function that computes a scalar loss for given inputs and targets :param beam_size: beam size for validation. If 0 then greedy decoding (default). :param beam_alpha: beam search alpha for length penalty, disabled if set to -1 (default). :param batch_type: validation batch type (sentence or token) :param kb_task: is not None if kb_task should be executed :param valid_kb: MonoDataset holding the loaded valid kb data :param valid_kb_lkp: List with valid example index to corresponding kb indices :param valid_kb_len: List with amount of triples per kb :param valid_data_canon: TranslationDataset of valid data but with canonized target data (for loss reporting) :return: - current_valid_score: current validation score [eval_metric], - valid_loss: validation loss, - valid_ppl:, validation perplexity, - valid_sources: validation sources, - valid_sources_raw: raw validation sources (before post-processing), - valid_references: validation references, - valid_hypotheses: validation_hypotheses, - decoded_valid: raw validation hypotheses (before post-processing), - valid_attention_scores: attention scores for validation hypotheses - valid_ent_f1: TODO FIXME """ print(f"\n{'-'*10} ENTER VALIDATION {'-'*10}\n") print(f"\n{'-'*10} VALIDATION DEBUG {'-'*10}\n") print("---data---") print(dir(data[0])) print([[ getattr(example, attr) for attr in dir(example) if hasattr(getattr(example, attr), "__iter__") and "kb" in attr or "src" in attr or "trg" in attr ] for example in data[:3]]) print(batch_size) print(use_cuda) print(max_output_length) print(level) print(eval_metric) print(loss_function) print(beam_size) print(beam_alpha) print(batch_type) print(kb_task) print("---valid_kb---") print(dir(valid_kb[0])) print([[ getattr(example, attr) for attr in dir(example) if hasattr(getattr(example, attr), "__iter__") and "kb" in attr or "src" in attr or "trg" in attr ] for example in valid_kb[:3]]) print(len(valid_kb_lkp), valid_kb_lkp[-5:]) print(len(valid_kb_lens), valid_kb_lens[-5:]) print("---valid_kb_truvals---") print(len(valid_kb_truvals), valid_kb_lens[-5:]) print([[ getattr(example, attr) for attr in dir(example) if hasattr(getattr(example, attr), "__iter__") and "kb" in attr or "src" in attr or "trg" in attr or "trv" in attr ] for example in valid_kb_truvals[:3]]) print("---valid_data_canon---") print(len(valid_data_canon), valid_data_canon[-5:]) print([[ getattr(example, attr) for attr in dir(example) if hasattr(getattr(example, attr), "__iter__") and "kb" in attr or "src" in attr or "trg" in 
attr or "trv" or "can" in attr ] for example in valid_data_canon[:3]]) print(report_on_canonicals) print(f"\n{'-'*10} END VALIDATION DEBUG {'-'*10}\n") if not kb_task: valid_iter = make_data_iter(dataset=data, batch_size=batch_size, batch_type=batch_type, shuffle=False, train=False) else: # knowledgebase version of make data iter and also provide canonized target data # data: for bleu/ent f1 # canon_data: for loss valid_iter = make_data_iter_kb(data, valid_kb, valid_kb_lkp, valid_kb_lens, valid_kb_truvals, batch_size=batch_size, batch_type=batch_type, shuffle=False, train=False, canonize=model.canonize, canon_data=valid_data_canon) valid_sources_raw = data.src pad_index = model.src_vocab.stoi[PAD_TOKEN] # disable dropout model.eval() # don't track gradients during validation with torch.no_grad(): all_outputs = [] valid_attention_scores = [] valid_kb_att_scores = [] total_loss = 0 total_ntokens = 0 total_nseqs = 0 for valid_batch in iter(valid_iter): # run as during training to get validation loss (e.g. xent) batch = Batch(valid_batch, pad_index, use_cuda=use_cuda) \ if not kb_task else \ Batch_with_KB(valid_batch, pad_index, use_cuda=use_cuda) assert hasattr(batch, "kbsrc") == bool(kb_task) # sort batch now by src length and keep track of order if not kb_task: sort_reverse_index = batch.sort_by_src_lengths() else: sort_reverse_index = list(range(batch.src.shape[0])) # run as during training with teacher forcing if loss_function is not None and batch.trg is not None: ntokens = batch.ntokens if hasattr(batch, "trgcanon") and batch.trgcanon is not None: ntokens = batch.ntokenscanon # normalize loss with num canonical tokens for perplexity # do a loss calculation without grad updates just to report valid loss # we can only do this when batch.trg exists, so not during actual translation/deployment batch_loss = model.get_loss_for_batch( batch, loss_function=loss_function) # keep track of metrics for reporting total_loss += batch_loss total_ntokens += ntokens # gold target tokens total_nseqs += batch.nseqs # run as during inference to produce translations output, attention_scores, kb_att_scores = model.run_batch( batch=batch, beam_size=beam_size, beam_alpha=beam_alpha, max_output_length=max_output_length) # sort outputs back to original order all_outputs.extend(output[sort_reverse_index]) valid_attention_scores.extend( attention_scores[sort_reverse_index] if attention_scores is not None else []) valid_kb_att_scores.extend(kb_att_scores[sort_reverse_index] if kb_att_scores is not None else []) assert len(all_outputs) == len(data) if loss_function is not None and total_ntokens > 0: # total validation loss valid_loss = total_loss # exponent of token-level negative log likelihood # can be seen as 2^(cross_entropy of model on valid set); normalized by num tokens; # see https://en.wikipedia.org/wiki/Perplexity#Perplexity_per_word valid_ppl = torch.exp(valid_loss / total_ntokens) else: valid_loss = -1 valid_ppl = -1 # decode back to symbols decoding_vocab = model.trg_vocab if not kb_task else model.trv_vocab decoded_valid = decoding_vocab.arrays_to_sentences(arrays=all_outputs, cut_at_eos=True) print(f"decoding_vocab.itos: {decoding_vocab.itos}") print(decoded_valid) # evaluate with metric on full dataset join_char = " " if level in ["word", "bpe"] else "" valid_sources = [join_char.join(s) for s in data.src] # TODO replace valid_references with uncanonicalized dev.car data ... 
requires writing new Dataset in data.py valid_references = [join_char.join(t) for t in data.trg] valid_hypotheses = [join_char.join(t) for t in decoded_valid] # post-process if level == "bpe": valid_sources = [bpe_postprocess(s) for s in valid_sources] valid_references = [bpe_postprocess(v) for v in valid_references] valid_hypotheses = [bpe_postprocess(v) for v in valid_hypotheses] # if references are given, evaluate against them if valid_references: assert len(valid_hypotheses) == len(valid_references) print(list(zip(valid_sources, valid_references, valid_hypotheses))) current_valid_score = 0 if eval_metric.lower() == 'bleu': # this version does not use any tokenization current_valid_score = bleu(valid_hypotheses, valid_references) elif eval_metric.lower() == 'chrf': current_valid_score = chrf(valid_hypotheses, valid_references) elif eval_metric.lower() == 'token_accuracy': current_valid_score = token_accuracy(valid_hypotheses, valid_references, level=level) elif eval_metric.lower() == 'sequence_accuracy': current_valid_score = sequence_accuracy( valid_hypotheses, valid_references) if kb_task: valid_ent_f1, valid_ent_mcc = calc_ent_f1_and_ent_mcc( valid_hypotheses, valid_references, vocab=model.trv_vocab, c_fun=model.canonize, report_on_canonicals=report_on_canonicals) else: valid_ent_f1, valid_ent_mcc = -1, -1 else: current_valid_score = -1 print(f"\n{'-'*10} EXIT VALIDATION {'-'*10}\n") return current_valid_score, valid_loss, valid_ppl, valid_sources, \ valid_sources_raw, valid_references, valid_hypotheses, \ decoded_valid, valid_attention_scores, valid_kb_att_scores, \ valid_ent_f1, valid_ent_mcc
def __init__(self, model: Model, config: dict) -> None:
    """
    Creates a new TrainManager for a model, specified as in configuration.

    :param model: torch module defining the model
    :param config: dictionary containing the training configurations
    """
    train_config = config["training"]

    # files for logging and storing
    self.model_dir = train_config["model_dir"]
    make_model_dir(
        self.model_dir, overwrite=train_config.get("overwrite", False)
    )
    self.logger = make_logger(model_dir=self.model_dir)
    self.logging_freq = train_config.get("logging_freq", 100)
    self.valid_report_file = join(self.model_dir, "validations.txt")
    self.tb_writer = SummaryWriter(
        log_dir=join(self.model_dir, "tensorboard/")
    )

    # model
    self.model = model
    self.pad_index = self.model.pad_index
    self._log_parameters_list()

    # objective
    objective = train_config.get("loss", "cross_entropy")
    loss_alpha = train_config.get("loss_alpha", 1.5)
    assert loss_alpha >= 1

    # maybe don't do the label smoothing thing here; instead use
    # nn.CrossEntropyLoss, then look up the loss func and either use it
    # directly or wrap it in FYLabelSmoothingLoss
    if objective == "softmax":
        objective = "cross_entropy"
    loss_funcs = {
        "cross_entropy": nn.CrossEntropyLoss,
        "entmax15": partial(Entmax15Loss, k=512),
        "sparsemax": partial(SparsemaxLoss, k=512),
        "entmax": partial(EntmaxBisectLoss, alpha=loss_alpha, n_iter=30)
    }
    if objective not in loss_funcs:
        raise ConfigurationError("Unknown loss function")

    loss_module = loss_funcs[objective]
    loss_func = loss_module(ignore_index=self.pad_index, reduction='sum')

    label_smoothing = train_config.get("label_smoothing", 0.0)
    label_smoothing_type = train_config.get("label_smoothing_type", "fy")
    assert label_smoothing_type in ["fy", "szegedy"]
    smooth_dist = train_config.get("smoothing_distribution", "uniform")
    assert smooth_dist in ["uniform", "unigram"]
    if label_smoothing > 0:
        if label_smoothing_type == "fy":
            # label smoothing entmax loss
            if smooth_dist == "unigram":
                smooth_p = torch.FloatTensor(model.trg_vocab.frequencies)
                smooth_p /= smooth_p.sum()
            else:
                smooth_p = None
            loss_func = FYLabelSmoothingLoss(
                loss_func, smoothing=label_smoothing, smooth_p=smooth_p
            )
        else:
            assert objective == "cross_entropy"
            loss_func = LabelSmoothingLoss(
                ignore_index=self.pad_index,
                reduction="sum",
                smoothing=label_smoothing
            )
    self.loss = loss_func

    self.norm_type = train_config.get("normalization", "batch")
    if self.norm_type not in ["batch", "tokens"]:
        raise ConfigurationError("Invalid normalization. "
                                 "Valid options: 'batch', 'tokens'.")

    # optimization
    self.learning_rate_min = train_config.get("learning_rate_min", 1.0e-8)
    self.clip_grad_fun = build_gradient_clipper(config=train_config)
    self.optimizer = build_optimizer(
        config=train_config, parameters=model.parameters())

    # validation & early stopping
    self.validate_by_label = train_config.get("validate_by_label", False)
    self.validation_freq = train_config.get("validation_freq", 1000)
    self.log_valid_sents = train_config.get("print_valid_sents", [0, 1, 2])
    self.plot_attention = train_config.get("plot_attention", False)
    self.ckpt_queue = queue.Queue(
        maxsize=train_config.get("keep_last_ckpts", 5))

    allowed = {'bleu', 'chrf', 'token_accuracy', 'sequence_accuracy',
               'cer', 'wer', 'levenshtein_distance'}
    eval_metrics = train_config.get("eval_metric", "bleu")
    if isinstance(eval_metrics, str):
        eval_metrics = [eval_metrics]
    if any(metric not in allowed for metric in eval_metrics):
        ok_metrics = " ".join(allowed)
        raise ConfigurationError("Invalid setting for 'eval_metric', "
                                 "valid options: {}".format(ok_metrics))
    self.eval_metrics = eval_metrics

    self.forced_sparsity = train_config.get("forced_sparsity", False)

    early_stop_metric = train_config.get("early_stopping_metric", "loss")
    allowed_early_stop = {"ppl", "loss"} | set(self.eval_metrics)
    if early_stop_metric not in allowed_early_stop:
        raise ConfigurationError(
            "Invalid setting for 'early_stopping_metric', "
            "valid options: 'loss', 'ppl', and eval_metrics.")
    self.early_stopping_metric = early_stop_metric
    min_metrics = {"ppl", "loss", "cer", "wer", "levenshtein_distance"}
    self.minimize_metric = early_stop_metric in min_metrics

    # learning rate scheduling
    hidden_size = _parse_hidden_size(config["model"])
    self.scheduler, self.sched_incr = build_scheduler(
        config=train_config,
        scheduler_mode="min" if self.minimize_metric else "max",
        optimizer=self.optimizer,
        hidden_size=hidden_size)

    # data & batch handling
    # src/trg magic
    if "level" in config["data"]:
        self.src_level = self.trg_level = config["data"]["level"]
    else:
        assert "src_level" in config["data"]
        assert "trg_level" in config["data"]
        self.src_level = config["data"]["src_level"]
        self.trg_level = config["data"]["trg_level"]

    self.shuffle = train_config.get("shuffle", True)
    self.epochs = train_config["epochs"]
    self.batch_size = train_config["batch_size"]
    self.batch_type = train_config.get("batch_type", "sentence")
    self.eval_batch_size = train_config.get("eval_batch_size",
                                            self.batch_size)
    self.eval_batch_type = train_config.get("eval_batch_type",
                                            self.batch_type)
    self.batch_multiplier = train_config.get("batch_multiplier", 1)

    # generation
    self.max_output_length = train_config.get("max_output_length", None)

    # CPU / GPU
    self.use_cuda = train_config["use_cuda"]
    if self.use_cuda:
        self.model.cuda()
        self.loss.cuda()

    # initialize training statistics
    self.steps = 0
    # set to True once the learning rate falls below its minimum; training then stops
    self.stop = False
    self.total_tokens = 0
    self.best_ckpt_iteration = 0
    # initial values for best scores
    self.best_ckpt_score = np.inf if self.minimize_metric else -np.inf

    # minimum risk training (MRT) settings
    mrt_schedule = train_config.get("mrt_schedule", None)
    assert mrt_schedule is None or mrt_schedule in ["warmup", "mix", "mtl"]
    self.mrt_schedule = mrt_schedule
    self.mrt_p = train_config.get("mrt_p", 0.0)
    self.mrt_lambda = train_config.get("mrt_lambda", 1.0)
    assert 0 <= self.mrt_p <= 1
    assert 0 <= self.mrt_lambda <= 1
    self.mrt_start_steps = train_config.get("mrt_start_steps", 0)
    self.mrt_samples = train_config.get("mrt_samples", 1)
    self.mrt_alpha = train_config.get("mrt_alpha", 1.0)
    self.mrt_strategy = train_config.get("mrt_strategy", "sample")
    self.mrt_cost = train_config.get("mrt_cost", "levenshtein")
    self.mrt_max_len = train_config.get("mrt_max_len", 31)  # hmm
    self.step_counter = count()

    assert self.mrt_alpha > 0
    assert self.mrt_strategy in ["sample", "topk"]
    assert self.mrt_cost in ["levenshtein", "bleu"]

    # model parameters
    if "load_model" in train_config.keys():
        model_load_path = train_config["load_model"]
        reset_training = train_config.get("reset_training", False)
        self.logger.info("Loading model from %s", model_load_path)
        self.init_from_checkpoint(model_load_path, reset=reset_training)
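For reference, a sketch of the "training" section of a config dict that exercises the options this constructor reads. The keys mirror the train_config lookups above; the concrete values (model path, metric choices, batch sizes) are illustrative assumptions, not recommended settings:

# illustrative values only; keys correspond to the train_config lookups above
train_config_example = {
    "model_dir": "models/example_run",    # hypothetical output directory
    "overwrite": False,
    "loss": "entmax15",                   # cross_entropy/softmax, entmax15, sparsemax, or entmax
    "loss_alpha": 1.5,                    # only used by the "entmax" bisection loss
    "label_smoothing": 0.1,
    "label_smoothing_type": "fy",         # "fy" or "szegedy"
    "smoothing_distribution": "unigram",  # "uniform" or "unigram"
    "normalization": "tokens",            # "batch" or "tokens"
    "eval_metric": ["bleu", "cer"],       # a string or list from the allowed set
    "early_stopping_metric": "cer",       # "loss", "ppl", or one of eval_metric
    "epochs": 10,
    "batch_size": 80,
    "batch_type": "sentence",
    "use_cuda": True,
    "mrt_schedule": "warmup",             # optional MRT schedule: "warmup", "mix", "mtl"
    "mrt_strategy": "sample",             # "sample" or "topk"
    "mrt_cost": "levenshtein",            # "levenshtein" or "bleu"
}

The constructor expects this dict nested under config["training"], alongside the "model" and "data" sections it also reads.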