Example #1
    def _build(self, batch_size):
        src_time_dim = 4
        vocab_size = 7

        emb = Embeddings(embedding_dim=self.emb_size,
                         vocab_size=vocab_size,
                         padding_idx=self.pad_index)

        decoder = TransformerDecoder(num_layers=self.num_layers,
                                     num_heads=self.num_heads,
                                     hidden_size=self.hidden_size,
                                     ff_size=self.ff_size,
                                     dropout=self.dropout,
                                     emb_dropout=self.dropout,
                                     vocab_size=vocab_size)

        encoder_output = torch.rand(size=(batch_size, src_time_dim,
                                          self.hidden_size))

        for p in decoder.parameters():
            torch.nn.init.uniform_(p, -0.5, 0.5)

        src_mask = torch.ones(size=(batch_size, 1, src_time_dim)) == 1

        encoder_hidden = None  # unused

        model = Model(encoder=None,
                      decoder=decoder,
                      src_embed=emb,
                      trg_embed=emb,
                      src_vocab=self.vocab,
                      trg_vocab=self.vocab)
        return src_mask, model, encoder_output, encoder_hidden
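
The fixture above only wires objects together; two details are easy to misread. The comparison `torch.ones(...) == 1` is simply a way to obtain a boolean mask of shape (batch, 1, src_len), and the uniform re-initialization makes the randomly built decoder reproducible under a fixed seed. A minimal plain-PyTorch sketch of both (sizes are arbitrary, not taken from the test):

import torch

batch_size, src_time_dim, hidden_size = 2, 4, 12

# fake encoder states: (batch, src_len, hidden)
encoder_output = torch.rand(batch_size, src_time_dim, hidden_size)

# boolean source mask of shape (batch, 1, src_len); comparing to 1 turns the
# float tensor into torch.bool, equivalent to torch.ones(..., dtype=torch.bool)
src_mask = torch.ones(batch_size, 1, src_time_dim) == 1
assert src_mask.dtype == torch.bool and bool(src_mask.all())

# deterministic re-initialization of a module's parameters, as in the test
torch.manual_seed(42)
layer = torch.nn.Linear(hidden_size, hidden_size)
for p in layer.parameters():
    torch.nn.init.uniform_(p, -0.5, 0.5)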
Example #2
    def _build(self, batch_size):
        src_time_dim = 4
        vocab_size = 7

        emb = Embeddings(embedding_dim=self.emb_size,
                         vocab_size=vocab_size,
                         padding_idx=self.pad_index)

        encoder = RecurrentEncoder(emb_size=self.emb_size,
                                   num_layers=self.num_layers,
                                   hidden_size=self.encoder_hidden_size,
                                   bidirectional=True)

        decoder = RecurrentDecoder(hidden_size=self.hidden_size,
                                   encoder=encoder,
                                   attention="bahdanau",
                                   emb_size=self.emb_size,
                                   vocab_size=self.vocab_size,
                                   num_layers=self.num_layers,
                                   init_hidden="bridge",
                                   input_feeding=True)

        encoder_output = torch.rand(size=(batch_size, src_time_dim,
                                          encoder.output_size))

        for p in decoder.parameters():
            torch.nn.init.uniform_(p, -0.5, 0.5)

        src_mask = torch.ones(size=(batch_size, 1, src_time_dim)) == 1

        encoder_hidden = torch.rand(size=(batch_size, encoder.output_size))

        model = Model(encoder=encoder,
                      decoder=decoder,
                      src_embed=emb,
                      trg_embed=emb,
                      src_vocab=self.vocab,
                      trg_vocab=self.vocab)

        return src_mask, model, encoder_output, encoder_hidden
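
Unlike the Transformer fixture, the recurrent one draws `encoder_output` with last dimension `encoder.output_size` instead of `hidden_size`: a bidirectional recurrent encoder concatenates forward and backward states, so its output is twice the per-direction hidden size. A small illustration with a plain torch GRU standing in for RecurrentEncoder:

import torch

emb_size, hidden_size, src_len, batch = 8, 16, 4, 2

rnn = torch.nn.GRU(input_size=emb_size, hidden_size=hidden_size,
                   batch_first=True, bidirectional=True)
x = torch.rand(batch, src_len, emb_size)
out, _ = rnn(x)

# forward and backward directions are concatenated along the feature axis
assert out.shape == (batch, src_len, 2 * hidden_size)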
Example #3
    def __init__(self, model: Model, config: dict) -> None:
        """
        Creates a new TrainManager for a model, specified as in configuration.

        :param model: torch module defining the model
        :param config: dictionary containing the training configurations
        """
        train_config = config["training"]

        # files for logging and storing
        self.model_dir = make_model_dir(train_config["model_dir"],
                                        overwrite=train_config.get(
                                            "overwrite", False))
        self.logger = make_logger("{}/train.log".format(self.model_dir))
        self.logging_freq = train_config.get("logging_freq", 100)
        self.valid_report_file = "{}/validations.txt".format(self.model_dir)
        self.tb_writer = SummaryWriter(log_dir=self.model_dir +
                                       "/tensorboard/")

        # model
        self.model = model
        self.pad_index = self.model.pad_index
        self.bos_index = self.model.bos_index
        self._log_parameters_list()

        # objective
        self.label_smoothing = train_config.get("label_smoothing", 0.0)
        self.loss = XentLoss(pad_index=self.pad_index,
                             smoothing=self.label_smoothing)
        self.normalization = train_config.get("normalization", "batch")
        if self.normalization not in ["batch", "tokens", "none"]:
            raise ConfigurationError("Invalid normalization option."
                                     "Valid options: "
                                     "'batch', 'tokens', 'none'.")

        # optimization
        self.learning_rate_min = train_config.get("learning_rate_min", 1.0e-8)

        self.clip_grad_fun = build_gradient_clipper(config=train_config)
        self.optimizer = build_optimizer(config=train_config,
                                         parameters=model.parameters())

        # validation & early stopping
        self.validation_freq = train_config.get("validation_freq", 1000)
        self.log_valid_sents = train_config.get("print_valid_sents", [0, 1, 2])
        self.ckpt_queue = queue.Queue(
            maxsize=train_config.get("keep_last_ckpts", 5))
        self.eval_metric = train_config.get("eval_metric", "bleu")
        if self.eval_metric not in [
                'bleu', 'chrf', 'token_accuracy', 'sequence_accuracy'
        ]:
            raise ConfigurationError("Invalid setting for 'eval_metric', "
                                     "valid options: 'bleu', 'chrf', "
                                     "'token_accuracy', 'sequence_accuracy'.")
        self.early_stopping_metric = train_config.get("early_stopping_metric",
                                                      "eval_metric")

        # early_stopping_metric decides on how to find the early stopping point:
        # ckpts are written when there's a new high/low score for this metric.
        # If we schedule based on BLEU/chrf/accuracy, we want to maximize the
        # score; otherwise we want to minimize it.
        if self.early_stopping_metric in ["ppl", "loss"]:
            self.minimize_metric = True
        elif self.early_stopping_metric == "eval_metric":
            if self.eval_metric in [
                    "bleu", "chrf", "token_accuracy", "sequence_accuracy"
            ]:
                self.minimize_metric = False
            # eval metric that has to get minimized (not yet implemented)
            else:
                self.minimize_metric = True
        else:
            raise ConfigurationError(
                "Invalid setting for 'early_stopping_metric', "
                "valid options: 'loss', 'ppl', 'eval_metric'.")

        # learning rate scheduling
        self.scheduler, self.scheduler_step_at = build_scheduler(
            config=train_config,
            scheduler_mode="min" if self.minimize_metric else "max",
            optimizer=self.optimizer,
            hidden_size=config["model"]["encoder"]["hidden_size"])

        # data & batch handling
        self.level = config["data"]["level"]
        if self.level not in ["word", "bpe", "char"]:
            raise ConfigurationError("Invalid segmentation level. "
                                     "Valid options: 'word', 'bpe', 'char'.")
        self.shuffle = train_config.get("shuffle", True)
        self.epochs = train_config["epochs"]
        self.batch_size = train_config["batch_size"]
        self.batch_type = train_config.get("batch_type", "sentence")
        self.eval_batch_size = train_config.get("eval_batch_size",
                                                self.batch_size)
        self.eval_batch_type = train_config.get("eval_batch_type",
                                                self.batch_type)

        self.batch_multiplier = train_config.get("batch_multiplier", 1)
        self.current_batch_multiplier = self.batch_multiplier

        # generation
        self.max_output_length = train_config.get("max_output_length", None)

        # CPU / GPU
        self.use_cuda = train_config["use_cuda"]
        if self.use_cuda:
            self.model.cuda()
            self.loss.cuda()

        # initialize accumulated batch loss (needed for batch_multiplier)
        self.norm_batch_loss_accumulated = 0
        # initialize training statistics
        self.steps = 0
        # stop training if this flag is True (e.g. when the learning rate
        # falls below the minimum)
        self.stop = False
        self.total_tokens = 0
        self.best_ckpt_iteration = 0
        # initial values for best scores
        self.best_ckpt_score = np.inf if self.minimize_metric else -np.inf
        # comparison function for scores
        self.is_best = lambda score: score < self.best_ckpt_score \
            if self.minimize_metric else score > self.best_ckpt_score

        # model parameters
        if "load_model" in train_config.keys():
            model_load_path = train_config["load_model"]
            self.logger.info("Loading model from %s", model_load_path)
            reset_best_ckpt = train_config.get("reset_best_ckpt", False)
            reset_scheduler = train_config.get("reset_scheduler", False)
            reset_optimizer = train_config.get("reset_optimizer", False)
            self.init_from_checkpoint(model_load_path,
                                      reset_best_ckpt=reset_best_ckpt,
                                      reset_scheduler=reset_scheduler,
                                      reset_optimizer=reset_optimizer)
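
For orientation, a minimal configuration dictionary covering the keys this constructor reads directly; the optional keys fall back to the `.get(...)` defaults shown above, and the builder helpers (`build_optimizer`, `build_scheduler`, `build_gradient_clipper`) read further optional entries of "training". Values are placeholders, not recommendations:

config = {
    "training": {
        "model_dir": "models/my_experiment",  # required; created by make_model_dir
        "epochs": 10,                         # required
        "batch_size": 64,                     # required
        "use_cuda": False,                    # required
        # optional: "label_smoothing", "eval_metric", "early_stopping_metric",
        # "load_model", "overwrite", "batch_multiplier", ...
    },
    "model": {
        "encoder": {"hidden_size": 512},      # passed on to build_scheduler
    },
    "data": {
        "level": "bpe",                       # one of "word", "bpe", "char"
    },
}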
Example #4
def validate_on_data(model: Model, data: Dataset,
                     batch_size: int,
                     use_cuda: bool, max_output_length: int,
                     level: str, eval_metric: Optional[str],
                     n_gpu: int,
                     batch_class: Batch = Batch,
                     compute_loss: bool = False,
                     beam_size: int = 1, beam_alpha: int = -1,
                     batch_type: str = "sentence",
                     postprocess: bool = True,
                     bpe_type: str = "subword-nmt",
                     sacrebleu: dict = None,
                     n_best: int = 1) \
        -> (float, float, float, List[str], List[List[str]], List[str],
            List[str], List[List[str]], List[np.array]):
    """
    Generate translations for the given data.
    If `compute_loss` is True and references are given,
    also compute the loss.

    :param model: model module
    :param data: dataset for validation
    :param batch_size: validation batch size
    :param batch_class: class type of batch
    :param use_cuda: if True, use CUDA
    :param max_output_length: maximum length for generated hypotheses
    :param level: segmentation level, one of "char", "bpe", "word"
    :param eval_metric: evaluation metric, e.g. "bleu"
    :param n_gpu: number of GPUs
    :param compute_loss: whether to compute a scalar loss
        for given inputs and targets
    :param beam_size: beam size for validation.
        If <2 then greedy decoding (default).
    :param beam_alpha: beam search alpha for length penalty,
        disabled if set to -1 (default).
    :param batch_type: validation batch type (sentence or token)
    :param postprocess: if True, remove BPE segmentation from translations
    :param bpe_type: bpe type, one of {"subword-nmt", "sentencepiece"}
    :param sacrebleu: sacrebleu options
    :param n_best: Amount of candidates to return

    :return:
        - current_valid_score: current validation score [eval_metric],
        - valid_loss: validation loss,
        - valid_ppl: validation perplexity,
        - valid_sources: validation sources,
        - valid_sources_raw: raw validation sources (before post-processing),
        - valid_references: validation references,
        - valid_hypotheses: validation hypotheses,
        - decoded_valid: raw validation hypotheses (before post-processing),
        - valid_attention_scores: attention scores for validation hypotheses
    """
    assert batch_size >= n_gpu, "batch_size must be at least n_gpu."
    if sacrebleu is None:  # assign default value
        sacrebleu = {"remove_whitespace": True, "tokenize": "13a"}
    if batch_size > 1000 and batch_type == "sentence":
        logger.warning(
            "WARNING: Are you sure you meant to work on huge batches like "
            "this? 'batch_size' is > 1000 for sentence-batching. "
            "Consider decreasing it or switching to"
            " 'eval_batch_type: token'.")
    valid_iter = make_data_iter(dataset=data,
                                batch_size=batch_size,
                                batch_type=batch_type,
                                shuffle=False,
                                train=False)
    valid_sources_raw = data.src
    pad_index = model.src_vocab.stoi[PAD_TOKEN]
    # disable dropout
    model.eval()
    # don't track gradients during validation
    with torch.no_grad():
        all_outputs = []
        valid_attention_scores = []
        total_loss = 0
        total_ntokens = 0
        total_nseqs = 0
        for valid_batch in iter(valid_iter):
            # run as during training to get validation loss (e.g. xent)

            batch = batch_class(valid_batch, pad_index, use_cuda=use_cuda)
            # sort batch now by src length and keep track of order
            reverse_index = batch.sort_by_src_length()
            sort_reverse_index = expand_reverse_index(reverse_index, n_best)

            # run as during training with teacher forcing
            if compute_loss and batch.trg is not None:
                batch_loss, _, _, _ = model(return_type="loss", **vars(batch))
                if n_gpu > 1:
                    batch_loss = batch_loss.mean()  # average on multi-gpu
                total_loss += batch_loss
                total_ntokens += batch.ntokens
                total_nseqs += batch.nseqs

            # run as during inference to produce translations
            output, attention_scores = run_batch(
                model=model,
                batch=batch,
                beam_size=beam_size,
                beam_alpha=beam_alpha,
                max_output_length=max_output_length,
                n_best=n_best)

            # sort outputs back to original order
            all_outputs.extend(output[sort_reverse_index])
            valid_attention_scores.extend(
                attention_scores[sort_reverse_index]
                if attention_scores is not None else [])

        assert len(all_outputs) == len(data) * n_best

        if compute_loss and total_ntokens > 0:
            # total validation loss
            valid_loss = total_loss
            # exponent of token-level negative log prob
            valid_ppl = torch.exp(total_loss / total_ntokens)
        else:
            valid_loss = -1
            valid_ppl = -1

        # decode back to symbols
        decoded_valid = model.trg_vocab.arrays_to_sentences(arrays=all_outputs,
                                                            cut_at_eos=True)

        # evaluate with metric on full dataset
        join_char = " " if level in ["word", "bpe"] else ""
        valid_sources = [join_char.join(s) for s in data.src]
        valid_references = [join_char.join(t) for t in data.trg]
        valid_hypotheses = [join_char.join(t) for t in decoded_valid]

        # post-process
        if level == "bpe" and postprocess:
            valid_sources = [
                bpe_postprocess(s, bpe_type=bpe_type) for s in valid_sources
            ]
            valid_references = [
                bpe_postprocess(v, bpe_type=bpe_type) for v in valid_references
            ]
            valid_hypotheses = [
                bpe_postprocess(v, bpe_type=bpe_type) for v in valid_hypotheses
            ]

        # if references are given, evaluate against them
        if valid_references:
            assert len(valid_hypotheses) == len(valid_references)

            current_valid_score = 0
            if eval_metric.lower() == 'bleu':
                # this version does not use any tokenization
                current_valid_score = bleu(valid_hypotheses,
                                           valid_references,
                                           tokenize=sacrebleu["tokenize"])
            elif eval_metric.lower() == 'chrf':
                current_valid_score = chrf(
                    valid_hypotheses,
                    valid_references,
                    remove_whitespace=sacrebleu["remove_whitespace"])
            elif eval_metric.lower() == 'token_accuracy':
                current_valid_score = token_accuracy(  # supply List[List[str]]
                    list(decoded_valid), list(data.trg))
            elif eval_metric.lower() == 'sequence_accuracy':
                current_valid_score = sequence_accuracy(
                    valid_hypotheses, valid_references)
        else:
            current_valid_score = -1

    return current_valid_score, valid_loss, valid_ppl, valid_sources, \
        valid_sources_raw, valid_references, valid_hypotheses, \
        decoded_valid, valid_attention_scores
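
The sort-then-restore logic around `sort_by_src_length` / `expand_reverse_index` can be illustrated with plain tensors: sort the batch by source length, run the model, then index the outputs with the inverse permutation to recover the original order; with `n_best` hypotheses per sentence each index expands to a block of `n_best` consecutive positions. A sketch of the idea only (the JoeyNMT helpers themselves are not reproduced):

import torch

src_lengths = torch.tensor([3, 5, 2])                 # original batch order
perm = torch.argsort(src_lengths, descending=True)    # sort by length: [1, 0, 2]
rev_index = torch.argsort(perm)                       # inverse permutation

sorted_lengths = src_lengths[perm]                    # tensor([5, 3, 2])
# ... decode the length-sorted batch here ...
assert torch.equal(sorted_lengths[rev_index], src_lengths)

# expanding the reverse index for n_best hypotheses per source sentence
n_best = 2
expanded = [i * n_best + j for i in rev_index.tolist() for j in range(n_best)]
print(expanded)   # [2, 3, 0, 1, 4, 5]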
Example #5
    def __init__(self,
                 model: Model,
                 config: dict,
                 batch_class: Batch = Batch) -> None:
        """
        Creates a new TrainManager for a model, specified as in configuration.

        :param model: torch module defining the model
        :param config: dictionary containing the training configurations
        :param batch_class: batch class to encapsulate the torch class
        """
        train_config = config["training"]
        self.batch_class = batch_class

        # files for logging and storing
        self.model_dir = train_config["model_dir"]
        assert os.path.exists(self.model_dir)

        self.logging_freq = train_config.get("logging_freq", 100)
        self.valid_report_file = "{}/validations.txt".format(self.model_dir)
        self.tb_writer = SummaryWriter(log_dir=self.model_dir +
                                       "/tensorboard/")

        self.save_latest_checkpoint = train_config.get("save_latest_ckpt",
                                                       True)

        # model
        self.model = model
        self._log_parameters_list()

        # objective
        self.label_smoothing = train_config.get("label_smoothing", 0.0)
        self.model.loss_function = XentLoss(pad_index=self.model.pad_index,
                                            smoothing=self.label_smoothing)
        self.normalization = train_config.get("normalization", "batch")
        if self.normalization not in ["batch", "tokens", "none"]:
            raise ConfigurationError("Invalid normalization option."
                                     "Valid options: "
                                     "'batch', 'tokens', 'none'.")

        # optimization
        self.learning_rate_min = train_config.get("learning_rate_min", 1.0e-8)

        self.clip_grad_fun = build_gradient_clipper(config=train_config)
        self.optimizer = build_optimizer(config=train_config,
                                         parameters=model.parameters())

        # validation & early stopping
        self.validation_freq = train_config.get("validation_freq", 1000)
        self.log_valid_sents = train_config.get("print_valid_sents", [0, 1, 2])
        self.ckpt_queue = collections.deque(
            maxlen=train_config.get("keep_last_ckpts", 5))
        self.eval_metric = train_config.get("eval_metric", "bleu")
        if self.eval_metric not in [
                'bleu', 'chrf', 'token_accuracy', 'sequence_accuracy'
        ]:
            raise ConfigurationError("Invalid setting for 'eval_metric', "
                                     "valid options: 'bleu', 'chrf', "
                                     "'token_accuracy', 'sequence_accuracy'.")
        self.early_stopping_metric = train_config.get("early_stopping_metric",
                                                      "eval_metric")

        # early_stopping_metric decides on how to find the early stopping point:
        # ckpts are written when there's a new high/low score for this metric.
        # If we schedule based on BLEU/chrf/accuracy, we want to maximize the
        # score; otherwise we want to minimize it.
        if self.early_stopping_metric in ["ppl", "loss"]:
            self.minimize_metric = True
        elif self.early_stopping_metric == "eval_metric":
            if self.eval_metric in [
                    "bleu", "chrf", "token_accuracy", "sequence_accuracy"
            ]:
                self.minimize_metric = False
            # eval metric that has to get minimized (not yet implemented)
            else:
                self.minimize_metric = True
        else:
            raise ConfigurationError(
                "Invalid setting for 'early_stopping_metric', "
                "valid options: 'loss', 'ppl', 'eval_metric'.")

        # eval options
        test_config = config["testing"]
        self.bpe_type = test_config.get("bpe_type", "subword-nmt")
        self.sacrebleu = {"remove_whitespace": True, "tokenize": "13a"}
        if "sacrebleu" in config["testing"].keys():
            self.sacrebleu["remove_whitespace"] = test_config["sacrebleu"] \
                .get("remove_whitespace", True)
            self.sacrebleu["tokenize"] = test_config["sacrebleu"] \
                .get("tokenize", "13a")

        # learning rate scheduling
        self.scheduler, self.scheduler_step_at = build_scheduler(
            config=train_config,
            scheduler_mode="min" if self.minimize_metric else "max",
            optimizer=self.optimizer,
            hidden_size=config["model"]["encoder"]["hidden_size"])

        # data & batch handling
        self.level = config["data"]["level"]
        if self.level not in ["word", "bpe", "char"]:
            raise ConfigurationError("Invalid segmentation level. "
                                     "Valid options: 'word', 'bpe', 'char'.")
        self.shuffle = train_config.get("shuffle", True)
        self.epochs = train_config["epochs"]
        self.batch_size = train_config["batch_size"]
        # Placeholder so that we can use the train_iter in other functions.
        self.train_iter = None
        self.train_iter_state = None
        # per-device batch_size = self.batch_size // self.n_gpu
        self.batch_type = train_config.get("batch_type", "sentence")
        self.eval_batch_size = train_config.get("eval_batch_size",
                                                self.batch_size)
        # per-device eval_batch_size = self.eval_batch_size // self.n_gpu
        self.eval_batch_type = train_config.get("eval_batch_type",
                                                self.batch_type)

        self.batch_multiplier = train_config.get("batch_multiplier", 1)

        # generation
        self.max_output_length = train_config.get("max_output_length", None)

        # CPU / GPU
        self.use_cuda = train_config["use_cuda"] and torch.cuda.is_available()
        self.n_gpu = torch.cuda.device_count() if self.use_cuda else 0
        self.device = torch.device("cuda" if self.use_cuda else "cpu")
        if self.use_cuda:
            self.model.to(self.device)

        # fp16
        self.fp16 = train_config.get("fp16", False)
        if self.fp16:
            if 'apex' not in sys.modules:
                raise ImportError("Please install apex from "
                                  "https://www.github.com/nvidia/apex "
                                  "to use fp16 training.") from no_apex
            self.model, self.optimizer = amp.initialize(self.model,
                                                        self.optimizer,
                                                        opt_level='O1')
            # opt level: one of {"O0", "O1", "O2", "O3"}
            # see https://nvidia.github.io/apex/amp.html#opt-levels

        # initialize training statistics
        self.stats = self.TrainStatistics(
            steps=0,
            stop=False,
            total_tokens=0,
            best_ckpt_iter=0,
            best_ckpt_score=np.inf if self.minimize_metric else -np.inf,
            minimize_metric=self.minimize_metric)

        # model parameters
        if "load_model" in train_config.keys():
            self.init_from_checkpoint(
                train_config["load_model"],
                reset_best_ckpt=train_config.get("reset_best_ckpt", False),
                reset_scheduler=train_config.get("reset_scheduler", False),
                reset_optimizer=train_config.get("reset_optimizer", False),
                reset_iter_state=train_config.get("reset_iter_state", False))

        # multi-gpu training (should be after apex fp16 initialization)
        if self.n_gpu > 1:
            self.model = _DataParallel(self.model)
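
Compared with the earlier variant, checkpoint bookkeeping here uses `collections.deque(maxlen=...)` instead of `queue.Queue`: a bounded deque drops its oldest entry automatically when a new one is appended, which matches the "keep the last k checkpoints" intent (the training code is still responsible for deleting the dropped file from disk). A tiny sketch of that behaviour:

from collections import deque

keep_last_ckpts = 3
ckpt_queue = deque(maxlen=keep_last_ckpts)

for step in (1000, 2000, 3000, 4000):
    ckpt_queue.append("{}.ckpt".format(step))

print(list(ckpt_queue))   # ['2000.ckpt', '3000.ckpt', '4000.ckpt']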
Example #6
def validate_on_data(model: Model, data: Dataset,
                     logger: Logger,
                     batch_size: int,
                     use_cuda: bool, max_output_length: int,
                     level: str, eval_metric: Optional[str],
                     loss_function: torch.nn.Module = None,
                     beam_size: int = 1, beam_alpha: int = -1,
                     batch_type: str = "sentence",
                     postprocess: bool = True
                     ) \
        -> (float, float, float, List[str], List[List[str]], List[str],
            List[str], List[List[str]], List[np.array]):
    """
    Generate translations for the given data.
    If `loss_function` is not None and references are given,
    also compute the loss.

    :param model: model module
    :param logger: logger
    :param data: dataset for validation
    :param batch_size: validation batch size
    :param use_cuda: if True, use CUDA
    :param max_output_length: maximum length for generated hypotheses
    :param level: segmentation level, one of "char", "bpe", "word"
    :param eval_metric: evaluation metric, e.g. "bleu"
    :param loss_function: loss function that computes a scalar loss
        for given inputs and targets
    :param beam_size: beam size for validation.
        If <2 then greedy decoding (default).
    :param beam_alpha: beam search alpha for length penalty,
        disabled if set to -1 (default).
    :param batch_type: validation batch type (sentence or token)
    :param postprocess: if True, remove BPE segmentation from translations

    :return:
        - current_valid_score: current validation score [eval_metric],
        - valid_loss: validation loss,
        - valid_ppl: validation perplexity,
        - valid_sources: validation sources,
        - valid_sources_raw: raw validation sources (before post-processing),
        - valid_references: validation references,
        - valid_hypotheses: validation hypotheses,
        - decoded_valid: raw validation hypotheses (before post-processing),
        - valid_attention_scores: attention scores for validation hypotheses
    """
    if batch_size > 1000 and batch_type == "sentence":
        logger.warning(
            "WARNING: Are you sure you meant to work on huge batches like "
            "this? 'batch_size' is > 1000 for sentence-batching. "
            "Consider decreasing it or switching to"
            " 'eval_batch_type: token'.")
    valid_iter = make_data_iter(dataset=data,
                                batch_size=batch_size,
                                batch_type=batch_type,
                                shuffle=False,
                                train=False)
    valid_sources_raw = data.src
    pad_index = model.src_vocab.stoi[PAD_TOKEN]
    # disable dropout
    model.eval()
    # don't track gradients during validation
    with torch.no_grad():
        all_outputs = []
        valid_attention_scores = []
        total_loss = 0
        total_ntokens = 0
        total_nseqs = 0
        for valid_batch in iter(valid_iter):
            # run as during training to get validation loss (e.g. xent)

            batch = Batch(valid_batch, pad_index, use_cuda=use_cuda)
            # sort batch now by src length and keep track of order
            sort_reverse_index = batch.sort_by_src_lengths()

            # run as during training with teacher forcing
            if loss_function is not None and batch.trg is not None:
                batch_loss = model.get_loss_for_batch(
                    batch, loss_function=loss_function)
                total_loss += batch_loss
                total_ntokens += batch.ntokens
                total_nseqs += batch.nseqs

            # run as during inference to produce translations
            output, attention_scores = model.run_batch(
                batch=batch,
                beam_size=beam_size,
                beam_alpha=beam_alpha,
                max_output_length=max_output_length)

            # sort outputs back to original order
            all_outputs.extend(output[sort_reverse_index])
            valid_attention_scores.extend(
                attention_scores[sort_reverse_index]
                if attention_scores is not None else [])

        assert len(all_outputs) == len(data)

        if loss_function is not None and total_ntokens > 0:
            # total validation loss
            valid_loss = total_loss
            # exponent of token-level negative log prob
            valid_ppl = torch.exp(total_loss / total_ntokens)
        else:
            valid_loss = -1
            valid_ppl = -1

        # decode back to symbols
        decoded_valid = model.trg_vocab.arrays_to_sentences(arrays=all_outputs,
                                                            cut_at_eos=True)

        # evaluate with metric on full dataset
        join_char = " " if level in ["word", "bpe"] else ""
        valid_sources = [join_char.join(s) for s in data.src]
        valid_references = [join_char.join(t) for t in data.trg]
        valid_hypotheses = [join_char.join(t) for t in decoded_valid]

        # post-process
        if level == "bpe" and postprocess:
            valid_sources = [bpe_postprocess(s) for s in valid_sources]
            valid_references = [bpe_postprocess(v) for v in valid_references]
            valid_hypotheses = [bpe_postprocess(v) for v in valid_hypotheses]

        # if references are given, evaluate against them
        if valid_references:
            assert len(valid_hypotheses) == len(valid_references)

            current_valid_score = 0
            if eval_metric.lower() == 'bleu':
                # this version does not use any tokenization
                current_valid_score = bleu(valid_hypotheses, valid_references)
            elif eval_metric.lower() == 'chrf':
                current_valid_score = chrf(valid_hypotheses, valid_references)
            elif eval_metric.lower() == 'token_accuracy':
                current_valid_score = token_accuracy(valid_hypotheses,
                                                     valid_references,
                                                     level=level)
            elif eval_metric.lower() == 'sequence_accuracy':
                current_valid_score = sequence_accuracy(
                    valid_hypotheses, valid_references)
        else:
            current_valid_score = -1

    return current_valid_score, valid_loss, valid_ppl, valid_sources, \
        valid_sources_raw, valid_references, valid_hypotheses, \
        decoded_valid, valid_attention_scores
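
When `level == "bpe"`, hypotheses and references still carry segmentation markers, which `bpe_postprocess` removes before scoring. For the subword-nmt convention, where word-internal pieces end in "@@", a minimal sketch of what such post-processing amounts to (the actual helper may differ in details, e.g. sentencepiece handling in the newer variant above):

def bpe_postprocess_subword_nmt(sentence: str) -> str:
    # join subword-nmt pieces: 'Ne@@ ural net@@ works' -> 'Neural networks'
    return sentence.replace("@@ ", "").replace("@@", "").strip()

print(bpe_postprocess_subword_nmt("Ne@@ ural machine trans@@ lation ."))
# Neural machine translation .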
Example #7
def validate_on_data(model: Model,
                     data: Dataset,
                     batch_size: int,
                     use_cuda: bool,
                     max_output_length: int,
                     trg_level: str,
                     eval_metrics: Optional[Sequence[str]],
                     loss_function: torch.nn.Module = None,
                     beam_size: int = 0,
                     force_prune_size: int = 5,
                     beam_alpha: int = 0,
                     batch_type: str = "sentence",
                     save_attention: bool = False,
                     validate_by_label: bool = False,
                     forced_sparsity: bool = False,
                     method=None,
                     max_hyps=1,
                     break_at_p: float = 1.0,
                     break_at_argmax: bool = False,
                     short_depth: int = 0):
    """
    Generate translations for the given data.
    If `loss_function` is not None and references are given,
    also compute the loss.

    :param model: model module
    :param data: dataset for validation
    :param batch_size: validation batch size
    :param use_cuda: if True, use CUDA
    :param max_output_length: maximum length for generated hypotheses
    :param trg_level: target segmentation level
    :param eval_metrics: evaluation metrics, e.g. ["bleu"]
    :param loss_function: loss function that computes a scalar loss
        for given inputs and targets
    :param beam_size: beam size for validation (default 0 is greedy)
    :param beam_alpha: beam search alpha for length penalty (default 0)
    :param batch_type: validation batch type (sentence or token)

    :return:
        - current_valid_scores: current validation score [eval_metric],
        - valid_references: validation references,
        - valid_hypotheses: validation hypotheses,
        - decoded_valid: raw validation hypotheses (before post-processing),
        - valid_attention_scores: attention scores for validation hypotheses
    """
    if beam_size > 0:
        force_prune_size = beam_size

    if validate_by_label:
        assert isinstance(data, TSVDataset) and data.label_columns

    valid_scores = defaultdict(float)  # container for scores
    stats = defaultdict(float)

    valid_iter = make_data_iter(dataset=data,
                                batch_size=batch_size,
                                batch_type=batch_type,
                                shuffle=False,
                                train=False,
                                use_cuda=use_cuda)

    pad_index = model.trg_vocab.stoi[PAD_TOKEN]

    model.eval()  # disable dropout

    force_objectives = loss_function is not None or forced_sparsity

    # possible tasks are: force w/ gold, force w/ empty, search
    scorer = partial(len_penalty, alpha=beam_alpha) if beam_alpha > 0 else None
    confidences = []
    corrects = []
    with torch.no_grad():
        all_outputs = []
        valid_attention_scores = defaultdict(list)
        for valid_batch in iter(valid_iter):
            batch = Batch(valid_batch, pad_index)
            rev_index = batch.sort_by_src_lengths()

            encoder_output, _ = model.encode(batch)

            empty_probs = None
            if force_objectives and not isinstance(model, EnsembleModel):
                # compute all the logits.
                logits = model.force_decode(batch, encoder_output)[0]
                bsz, gold_len, vocab_size = logits.size()
                gold, gold_lengths, _ = batch["trg"]
                prediction_steps = gold_lengths.sum().item() - bsz
                assert gold.size(0) == bsz

                if loss_function is not None:
                    gold_pred = gold[:, 1:].contiguous().view(-1)
                    batch_loss = loss_function(
                        logits.view(-1, logits.size(-1)), gold_pred)
                    valid_scores["loss"] += batch_loss

                if forced_sparsity:
                    # compute probabilities
                    out = logits.view(-1, vocab_size)
                    if isinstance(model, EnsembleModel):
                        probs = out
                    else:
                        probs = model.decoder.gen_func(out, dim=-1)

                    # Compute numbers derived from the distributions.
                    # This includes support size, entropy, and calibration
                    non_pad = (gold[:, 1:] != pad_index).view(-1)
                    real_probs = probs[non_pad]
                    n_supported = real_probs.gt(0).sum().item()
                    pred_ps, pred_ix = real_probs.max(dim=-1)
                    real_gold = gold[:, 1:].contiguous().view(-1)[non_pad]
                    real_correct = pred_ix.eq(real_gold)
                    corrects.append(real_correct)
                    confidences.append(pred_ps)

                    beam_probs, _ = real_probs.topk(force_prune_size, dim=-1)
                    pruned_mass = 1 - beam_probs.sum(dim=-1)
                    stats["force_pruned_mass"] += pruned_mass.sum().item()

                    # compute stuff with the empty sequence
                    empty_probs = probs.view(bsz, gold_len,
                                             vocab_size)[:, 0, model.eos_index]
                    assert empty_probs.size() == gold_lengths.size()
                    empty_possible = empty_probs.gt(0).sum().item()
                    empty_mass = empty_probs.sum().item()

                    stats["eos_supported"] += empty_possible
                    stats["eos_mass"] += empty_mass
                    stats["n_supp"] += n_supported
                    stats["n_pred"] += prediction_steps

                short_scores = None
                if short_depth > 0:
                    # we call run_batch again with the short depth. We don't
                    # really care what the hypotheses are, we only want the
                    # scores
                    _, _, short_scores = model.run_batch(
                        batch=batch,
                        beam_size=beam_size,  # can this be removed?
                        scorer=scorer,  # should be none
                        max_output_length=short_depth,
                        method="dfs",
                        max_hyps=max_hyps,
                        encoder_output=encoder_output,
                        return_scores=True)

            # run as during inference to produce translations
            # todo: return_scores for greedy
            output, attention_scores, beam_scores = model.run_batch(
                batch=batch,
                beam_size=beam_size,
                scorer=scorer,
                max_output_length=max_output_length,
                method=method,
                max_hyps=max_hyps,
                encoder_output=encoder_output,
                return_scores=True,
                break_at_argmax=break_at_argmax,
                break_at_p=break_at_p)
            stats["hyp_length"] += output.ne(model.pad_index).sum().item()
            if beam_scores is not None and empty_probs is not None:
                # I need to expand this to handle stuff up to length m.
                # note that although you can compute the probability of the
                # empty sequence without any extra computation, you *do* need
                # to do extra decoding if you want to get the most likely
                # sequence with length <= m.
                empty_better = empty_probs.log().gt(beam_scores).sum().item()
                stats["empty_better"] += empty_better

                if short_scores is not None:
                    short_better = short_scores.gt(beam_scores).sum().item()
                    stats["short_better"] += short_better

            # sort outputs back to original order
            all_outputs.extend(output[rev_index])

            if save_attention and attention_scores is not None:
                # beam search currently does not support attention logging
                for k, v in attention_scores.items():
                    valid_attention_scores[k].extend(v[rev_index])

        assert len(all_outputs) == len(data)

    ref_length = sum(len(d.trg) for d in data)
    valid_scores["length_ratio"] = stats["hyp_length"] / ref_length

    assert len(corrects) == len(confidences)
    if corrects:
        valid_scores["ece"] = expected_calibration_error(corrects, confidences)

    if stats["n_pred"] > 0:
        valid_scores["ppl"] = math.exp(valid_scores["loss"] / stats["n_pred"])

    if forced_sparsity and stats["n_pred"] > 0:
        valid_scores["support"] = stats["n_supp"] / stats["n_pred"]
        valid_scores["empty_possible"] = stats["eos_supported"] / len(
            all_outputs)
        valid_scores["empty_prob"] = stats["eos_mass"] / len(all_outputs)
        valid_scores[
            "force_pruned_mass"] = stats["force_pruned_mass"] / stats["n_pred"]
        if beam_size > 0:
            valid_scores["empty_better"] = stats["empty_better"] / len(
                all_outputs)
            if short_depth > 0:
                score_name = "depth_{}_better".format(short_depth)
                valid_scores[score_name] = stats["short_better"] / len(
                    all_outputs)

    # postprocess
    raw_hyps = model.trg_vocab.arrays_to_sentences(all_outputs)
    valid_hyps = postprocess(raw_hyps, trg_level)
    valid_refs = postprocess(data.trg, trg_level)

    # evaluate
    eval_funcs = {
        "bleu": bleu,
        "chrf": chrf,
        "token_accuracy": partial(token_accuracy, level=trg_level),
        "sequence_accuracy": sequence_accuracy,
        "wer": word_error_rate,
        "cer": partial(character_error_rate, level=trg_level),
        "levenshtein_distance": partial(levenshtein_distance, level=trg_level)
    }
    selected_eval_metrics = {name: eval_funcs[name] for name in eval_metrics}
    decoding_scores, scores_by_label = evaluate_decoding(
        data, valid_refs, valid_hyps, selected_eval_metrics, validate_by_label)
    valid_scores.update(decoding_scores)

    return valid_scores, valid_refs, valid_hyps, \
        raw_hyps, valid_attention_scores, scores_by_label
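
The calibration bookkeeping above stores, for every non-pad gold token, whether the argmax prediction matched the gold token (`corrects`) and the probability assigned to that argmax (`confidences`); `expected_calibration_error` then compares average confidence to accuracy inside confidence bins. A hedged sketch of that standard metric (the helper actually used may bin or weight differently):

import torch

def expected_calibration_error_sketch(corrects, confidences, n_bins: int = 10):
    correct = torch.cat(corrects).float()   # 1.0 where argmax == gold token
    conf = torch.cat(confidences)           # probability of the argmax token
    edges = torch.linspace(0, 1, n_bins + 1)
    ece = torch.zeros(())
    for lo, hi in zip(edges[:-1], edges[1:]):
        in_bin = (conf > lo) & (conf <= hi)
        if in_bin.any():
            weight = in_bin.float().mean()   # fraction of tokens in this bin
            gap = (conf[in_bin].mean() - correct[in_bin].mean()).abs()
            ece = ece + weight * gap
    return ece.item()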
Example #8
    def __init__(self, model: Model, config: dict) -> None:
        """
        Creates a new TrainManager for a model, specified as in configuration.

        :param model: torch module defining the model
        :param config: dictionary containing the training configurations
        """
        train_config = config["training"]

        # files for logging and storing
        self.model_dir = make_model_dir(train_config["model_dir"],
                                        overwrite=train_config.get(
                                            "overwrite", False))
        self.logger = make_logger(model_dir=self.model_dir)
        self.logging_freq = train_config.get("logging_freq", 100)
        self.valid_report_file = join(self.model_dir, "validations.txt")
        self.tb_writer = SummaryWriter(
            log_dir=join(self.model_dir, "tensorboard/")
        )
        self.log_sparsity = train_config.get("log_sparsity", False)

        self.apply_mask = train_config.get("apply_mask", False)
        self.valid_apply_mask = train_config.get("valid_apply_mask", True)

        # model
        self.model = model
        self.pad_index = self.model.pad_index
        self.bos_index = self.model.bos_index
        self._log_parameters_list()

        # objective
        objective = train_config.get("loss", "cross_entropy")
        loss_alpha = train_config.get("loss_alpha", 1.5)
        self.label_smoothing = train_config.get("label_smoothing", 0.0)
        if self.label_smoothing > 0 and objective == "cross_entropy":
            xent_loss = partial(
                LabelSmoothingLoss, smoothing=self.label_smoothing)
        else:
            xent_loss = nn.CrossEntropyLoss

        assert loss_alpha >= 1
        entmax_loss = partial(
            EntmaxBisectLoss, alpha=loss_alpha, n_iter=30
        )

        loss_funcs = {"cross_entropy": xent_loss,
                      "entmax15": partial(Entmax15Loss, k=512),
                      "sparsemax": partial(SparsemaxLoss, k=512),
                      "entmax": entmax_loss}
        if objective not in loss_funcs:
            raise ConfigurationError("Unknown loss function")
        loss_func = loss_funcs[objective]
        self.loss = loss_func(ignore_index=self.pad_index, reduction='sum')

        if "language_loss" in train_config:
            assert "language_weight" in train_config
            self.language_loss = loss_func(
                ignore_index=self.pad_index, reduction='sum'
            )
            self.language_weight = train_config["language_weight"]
        else:
            self.language_loss = None
            self.language_weight = 0.0

        self.norm_type = train_config.get("normalization", "batch")
        if self.norm_type not in ["batch", "tokens"]:
            raise ConfigurationError("Invalid normalization. "
                                     "Valid options: 'batch', 'tokens'.")

        # optimization
        self.learning_rate_min = train_config.get("learning_rate_min", 1.0e-8)

        self.clip_grad_fun = build_gradient_clipper(config=train_config)
        self.optimizer = build_optimizer(
            config=train_config, parameters=model.parameters())

        # validation & early stopping
        self.validation_freq = train_config.get("validation_freq", 1000)
        self.log_valid_sents = train_config.get("print_valid_sents", [0, 1, 2])
        self.plot_attention = train_config.get("plot_attention", False)
        self.ckpt_queue = queue.Queue(
            maxsize=train_config.get("keep_last_ckpts", 5))

        allowed = {'bleu', 'chrf', 'token_accuracy',
                   'sequence_accuracy', 'cer', 'wer'}
        eval_metrics = train_config.get("eval_metric", "bleu")
        if isinstance(eval_metrics, str):
            eval_metrics = [eval_metrics]
        if any(metric not in allowed for metric in eval_metrics):
            ok_metrics = " ".join(allowed)
            raise ConfigurationError("Invalid setting for 'eval_metric', "
                                     "valid options: {}".format(ok_metrics))
        self.eval_metrics = eval_metrics

        early_stop_metric = train_config.get("early_stopping_metric", "loss")
        allowed_early_stop = {"ppl", "loss"} | set(self.eval_metrics)
        if early_stop_metric not in allowed_early_stop:
            raise ConfigurationError(
                "Invalid setting for 'early_stopping_metric', "
                "valid options: 'loss', 'ppl', and eval_metrics.")
        self.early_stopping_metric = early_stop_metric
        self.minimize_metric = early_stop_metric in {"ppl", "loss",
                                                     "cer", "wer"}

        attn_metrics = train_config.get("attn_metric", [])
        if isinstance(attn_metrics, str):
            attn_metrics = [attn_metrics]
        ok_attn_metrics = {"support"}
        assert all(met in ok_attn_metrics for met in attn_metrics)
        self.attn_metrics = attn_metrics

        # learning rate scheduling
        if "encoder" in config["model"]:
            hidden_size = config["model"]["encoder"]["hidden_size"]
        else:
            hidden_size = config["model"]["encoders"]["src"]["hidden_size"]
        self.scheduler, self.scheduler_step_at = build_scheduler(
            config=train_config,
            scheduler_mode="min" if self.minimize_metric else "max",
            optimizer=self.optimizer,
            hidden_size=hidden_size)

        # data & batch handling
        data_cfg = config["data"]
        self.src_level = data_cfg.get(
            "src_level", data_cfg.get("level", "word")
        )
        self.trg_level = data_cfg.get(
            "trg_level", data_cfg.get("level", "word")
        )
        levels = ["word", "bpe", "char"]
        if self.src_level not in levels or self.trg_level not in levels:
            raise ConfigurationError("Invalid segmentation level. "
                                     "Valid options: 'word', 'bpe', 'char'.")

        self.shuffle = train_config.get("shuffle", True)
        self.epochs = train_config["epochs"]
        self.batch_size = train_config["batch_size"]
        self.batch_type = train_config.get("batch_type", "sentence")
        self.eval_batch_size = train_config.get("eval_batch_size",
                                                self.batch_size)
        self.eval_batch_type = train_config.get("eval_batch_type",
                                                self.batch_type)

        self.batch_multiplier = train_config.get("batch_multiplier", 1)

        # generation
        self.max_output_length = train_config.get("max_output_length", None)

        # CPU / GPU
        self.use_cuda = train_config["use_cuda"]
        if self.use_cuda:
            self.model.cuda()
            self.loss.cuda()

        # initialize training statistics
        self.steps = 0
        # stop training if this flag is True (e.g. when the learning rate
        # falls below the minimum)
        self.stop = False
        self.total_tokens = 0
        self.best_ckpt_iteration = 0
        # initial values for best scores
        self.best_ckpt_score = np.inf if self.minimize_metric else -np.inf

        # model parameters
        if "load_model" in train_config.keys():
            model_load_path = train_config["load_model"]
            self.logger.info("Loading model from %s", model_load_path)
            restart_training = train_config.get("restart_training", False)
            self.init_from_checkpoint(model_load_path, restart_training)
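
The loss selection above uses `functools.partial` to bind config-dependent arguments (smoothing, entmax alpha, top-k) early, so that every entry of `loss_funcs` can later be instantiated with the same `ignore_index`/`reduction` call. A self-contained sketch of the same dispatch pattern with torch-only losses (the entmax losses come from the external `entmax` package and are not reproduced here):

from functools import partial
import torch.nn as nn

pad_index = 1
objective = "cross_entropy"   # would come from train_config["loss"]

loss_funcs = {
    "cross_entropy": nn.CrossEntropyLoss,
    # config-specific arguments are bound early, as above;
    # label_smoothing needs torch >= 1.10
    "smoothed_cross_entropy": partial(nn.CrossEntropyLoss, label_smoothing=0.1),
}
if objective not in loss_funcs:
    raise ValueError("Unknown loss function")

loss = loss_funcs[objective](ignore_index=pad_index, reduction="sum")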
Example #9
    def __init__(self, model: Model, config: dict) -> None:
        """
        Creates a new TrainManager for a model, specified as in configuration.

        :param model: torch module defining the model
        :param config: dictionary containing the training configurations
        """
        train_config = config["training"]

        # files for logging and storing
        self.model_dir = make_model_dir(train_config["model_dir"],
                                        overwrite=train_config.get(
                                            "overwrite", False))
        self.logger = make_logger(model_dir=self.model_dir)
        self.logging_freq = train_config.get("logging_freq", 100)
        self.valid_report_file = "{}/validations.txt".format(self.model_dir)
        self.tb_writer = SummaryWriter(log_dir=self.model_dir +
                                       "/tensorboard/")

        # model
        self.model = model
        self.pad_index = self.model.pad_index
        self.bos_index = self.model.bos_index
        self._log_parameters_list()

        # objective
        self.loss = WeightedCrossEntropy(ignore_index=self.pad_index)
        # nn.NLLLoss(ignore_index=self.pad_index, reduction='sum')
        self.normalization = train_config.get("normalization", "batch")
        if self.normalization not in ["batch", "tokens"]:
            raise ConfigurationError("Invalid normalization. "
                                     "Valid options: 'batch', 'tokens'.")

        # optimization
        self.learning_rate_min = train_config.get("learning_rate_min", 1.0e-8)
        self.clip_grad_fun = build_gradient_clipper(config=train_config)

        # re-order the model parameters by name before initialisation of optimizer
        # Reference: https://github.com/pytorch/pytorch/issues/1489
        all_params = list(model.named_parameters())
        sorted_params = sorted(all_params)
        sorted_params = OrderedDict(sorted_params)
        self.optimizer = build_optimizer(config=train_config,
                                         parameters=sorted_params.values())

        # save checkpoint by epoch
        self.save_freq = train_config.get("save_freq", -1)

        # validation & early stopping
        self.validation_freq = train_config.get("validation_freq", 1000)
        self.log_valid_sents = train_config.get("print_valid_sents", [0, 1, 2])
        self.ckpt_queue = queue.Queue(
            maxsize=train_config.get("keep_last_ckpts", 5))
        self.eval_metric = train_config.get("eval_metric", "bleu")
        if self.eval_metric not in ['bleu', 'chrf']:
            raise ConfigurationError("Invalid setting for 'eval_metric', "
                                     "valid options: 'bleu', 'chrf'.")
        self.early_stopping_metric = train_config.get("early_stopping_metric",
                                                      "eval_metric")
        # if we schedule based on BLEU/chrf, we want to maximize it, else minimize
        # early_stopping_metric decides on how to find the early stopping point:
        # ckpts are written when there's a new high/low score for this metric
        if self.early_stopping_metric in ["ppl", "loss"]:
            self.minimize_metric = True
        elif self.early_stopping_metric == "eval_metric":
            if self.eval_metric in ["bleu", "chrf"]:
                self.minimize_metric = False
            else:  # eval metric that has to get minimized (not yet implemented)
                self.minimize_metric = True
        else:
            raise ConfigurationError(
                "Invalid setting for 'early_stopping_metric', "
                "valid options: 'loss', 'ppl', 'eval_metric'.")
        self.post_process = config["data"].get("post_process", True)

        # learning rate scheduling
        self.scheduler, self.scheduler_step_at = build_scheduler(
            config=train_config,
            scheduler_mode="min" if self.minimize_metric else "max",
            optimizer=self.optimizer)

        # data & batch handling
        self.level = config["data"]["level"]
        if self.level not in ["word", "bpe", "char"]:
            raise ConfigurationError("Invalid segmentation level. "
                                     "Valid options: 'word', 'bpe', 'char'.")
        self.shuffle = train_config.get("shuffle", True)
        self.epochs = train_config["epochs"]
        self.batch_size = train_config["batch_size"]
        self.batch_multiplier = train_config.get("batch_multiplier", 1)

        # generation
        self.max_output_length = train_config.get("max_output_length", None)

        # CPU / GPU
        self.use_cuda = train_config["use_cuda"]
        if self.use_cuda:
            self.model.cuda()

        # model parameters
        if "load_model" in train_config.keys():
            model_load_path = train_config["load_model"]
            self.logger.info("Loading model from %s", model_load_path)
            self.init_from_checkpoint(model_load_path)

        # initialize training statistics
        self.steps = 0
        # stop training if this flag is True (e.g. when the learning rate
        # falls below the minimum)
        self.stop = False
        self.total_tokens = 0
        self.best_ckpt_iteration = 0
        # initial values for best scores
        self.best_ckpt_score = np.inf if self.minimize_metric else -np.inf
        # comparison function for scores
        self.is_best = lambda score: score < self.best_ckpt_score \
            if self.minimize_metric else score > self.best_ckpt_score

        # for learning with logged feedback
        if config["data"].get("feedback", None) is not None:
            self.logger.info("Learning with token-level feedback.")
        self.return_logp = config["testing"].get("return_logp", False)
Example #10
def validate_on_data(model: Model, data: Dataset,
                     batch_size: int,
                     use_cuda: bool, max_output_length: int,
                     src_level: str,
                     trg_level: str,
                     eval_metrics: Optional[Sequence[str]],
                     attn_metrics: Optional[Sequence[str]],
                     loss_function: torch.nn.Module = None,
                     beam_size: int = 0, beam_alpha: int = 0,
                     batch_type: str = "sentence",
                     save_attention: bool = False,
                     log_sparsity: bool = False,
                     apply_mask: bool = True  # hmm
                     ) \
        -> (float, float, float, List[str], List[List[str]], List[str],
            List[str], List[List[str]], List[np.array]):
    """
    Generate translations for the given data.
    If `loss_function` is not None and references are given,
    also compute the loss.

    :param model: model module
    :param data: dataset for validation
    :param batch_size: validation batch size
    :param use_cuda: if True, use CUDA
    :param max_output_length: maximum length for generated hypotheses
    :param src_level: source segmentation level, one of "char", "bpe", "word"
    :param trg_level: target segmentation level, one of "char", "bpe", "word"
    :param eval_metrics: evaluation metrics, e.g. ["bleu"]
    :param loss_function: loss function that computes a scalar loss
        for given inputs and targets
    :param beam_size: beam size for validation.
        If 0 then greedy decoding (default).
    :param beam_alpha: beam search alpha for length penalty,
        disabled if set to 0 (default).
    :param batch_type: validation batch type (sentence or token)

    :return:
        - current_valid_score: current validation score [eval_metric],
        - valid_loss: validation loss,
        - valid_ppl: validation perplexity,
        - valid_sources: validation sources,
        - valid_sources_raw: raw validation sources (before post-processing),
        - valid_references: validation references,
        - valid_hypotheses: validation hypotheses,
        - decoded_valid: raw validation hypotheses (before post-processing),
        - valid_attention_scores: attention scores for validation hypotheses
    """
    eval_funcs = {
        "bleu": bleu,
        "chrf": chrf,
        "token_accuracy": partial(token_accuracy, level=trg_level),
        "sequence_accuracy": sequence_accuracy,
        "wer": wer,
        "cer": partial(character_error_rate, level=trg_level)
    }
    selected_eval_metrics = {name: eval_funcs[name] for name in eval_metrics}

    valid_iter = make_data_iter(
        dataset=data, batch_size=batch_size, batch_type=batch_type,
        shuffle=False, train=False)
    valid_sources_raw = [s for s in data.src]
    pad_index = model.src_vocab.stoi[PAD_TOKEN]
    # disable dropout
    model.eval()
    scorer = partial(len_penalty, alpha=beam_alpha) if beam_alpha > 0 else None
    # don't track gradients during validation
    with torch.no_grad():
        all_outputs = []
        valid_attention_scores = defaultdict(list)
        total_loss = 0
        total_ntokens = 0
        total_nseqs = 0
        total_attended = defaultdict(int)
        greedy_steps = 0
        greedy_supported = 0
        for valid_batch in iter(valid_iter):
            # run as during training to get validation loss (e.g. xent)

            batch = Batch(valid_batch, pad_index, use_cuda=use_cuda)
            # sort batch now by src length and keep track of order
            sort_reverse_index = batch.sort_by_src_lengths()

            # run as during training with teacher forcing
            if loss_function is not None and batch.trg is not None:
                batch_loss = model.get_loss_for_batch(
                    batch, loss_function=loss_function)
                total_loss += batch_loss
                total_ntokens += batch.ntokens
                total_nseqs += batch.nseqs

            # run as during inference to produce translations
            output, attention_scores, probs = model.run_batch(
                batch=batch, beam_size=beam_size, scorer=scorer,
                max_output_length=max_output_length, log_sparsity=log_sparsity,
                apply_mask=apply_mask)
            if log_sparsity:
                # hypothesis length = position of the first <eos> token
                eos_hits = output == model.trg_vocab.stoi[EOS_TOKEN]
                lengths = torch.LongTensor(eos_hits.argmax(axis=1)).unsqueeze(1)
                batch_greedy_steps = lengths.sum().item()
                greedy_steps += batch_greedy_steps

                ix = torch.arange(output.shape[1]).unsqueeze(0)
                ix = ix.expand(output.shape[0], -1)
                mask = ix <= lengths
                # count output tokens with nonzero probability at each step
                supp = probs.exp().gt(0).sum(dim=-1).cpu()  # batch x len
                supp = torch.where(mask, supp, torch.tensor(0)).sum()
                greedy_supported += supp.float().item()

            # sort outputs back to original order
            all_outputs.extend(output[sort_reverse_index])

            if attention_scores is not None:
                # is attention_scores ever None?
                if save_attention:
                    # beam search currently does not support attention logging
                    for k, v in attention_scores.items():
                        valid_attention_scores[k].extend(v[sort_reverse_index])
                if attn_metrics:
                    # add to total_attended
                    for k, v in attention_scores.items():
                        total_attended[k] += (v > 0).sum()

        assert len(all_outputs) == len(data)

        if log_sparsity:
            print(greedy_supported / greedy_steps)

        valid_scores = dict()
        if loss_function is not None and total_ntokens > 0:
            # total validation loss and token-level perplexity
            valid_scores["loss"] = total_loss
            valid_scores["ppl"] = torch.exp(total_loss / total_ntokens)

        # decode back to symbols
        decoded_valid = model.trg_vocab.arrays_to_sentences(arrays=all_outputs,
                                                            cut_at_eos=True)

        # evaluate with metric on full dataset
        src_join_char = " " if src_level in ["word", "bpe"] else ""
        trg_join_char = " " if trg_level in ["word", "bpe"] else ""
        valid_sources = [src_join_char.join(s) for s in data.src]
        valid_references = [trg_join_char.join(t) for t in data.trg]
        valid_hypotheses = [trg_join_char.join(t) for t in decoded_valid]

        if attn_metrics:
            decoded_ntokens = sum(len(t) for t in decoded_valid)
            for attn_metric in attn_metrics:
                assert attn_metric == "support"
                for attn_name, tot_attended in total_attended.items():
                    score_name = attn_name + "_" + attn_metric
                    # this is not the right denominator
                    valid_scores[score_name] = tot_attended / decoded_ntokens

        # post-process
        if src_level == "bpe":
            valid_sources = [bpe_postprocess(s) for s in valid_sources]
        if trg_level == "bpe":
            valid_references = [bpe_postprocess(v) for v in valid_references]
            valid_hypotheses = [bpe_postprocess(v) for v in valid_hypotheses]

        languages = [language for language in data.language]
        by_language = defaultdict(list)
        seqs = zip(valid_references, valid_hypotheses) if valid_references else valid_hypotheses
        if languages:
            examples = zip(languages, seqs)
            for lang, seq in examples:
                by_language[lang].append(seq)
        else:
            by_language[None].extend(seqs)

        # if references are given, evaluate against them
        # incorrect if-condition?
        # scores_by_lang = {name: dict() for name in selected_eval_metrics}
        scores_by_lang = dict()
        if valid_references and eval_metrics is not None:
            assert len(valid_hypotheses) == len(valid_references)

            for eval_metric, eval_func in selected_eval_metrics.items():
                score_by_lang = dict()
                for lang, pairs in by_language.items():
                    # by_language stores (reference, hypothesis) pairs
                    lang_refs, lang_hyps = zip(*pairs)
                    lang_score = eval_func(lang_hyps, lang_refs)
                    score_by_lang[lang] = lang_score

                score = sum(score_by_lang.values()) / len(score_by_lang)
                valid_scores[eval_metric] = score
                scores_by_lang[eval_metric] = score_by_lang

    if not languages:
        scores_by_lang = None
    return valid_scores, valid_sources, \
        valid_sources_raw, valid_references, valid_hypotheses, \
        decoded_valid, valid_attention_scores, scores_by_lang, by_language
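A hypothetical call to the `validate_on_data` variant above, assuming a trained `model`, a `dev_data` Dataset and a summed cross-entropy `loss_function` already exist (none of them are defined in this document); all argument values are illustrative, not prescribed:

# Hypothetical usage sketch; model, dev_data and xent_loss are assumed to
# exist and are not defined in this document.
valid_scores, valid_sources, valid_sources_raw, valid_references, \
    valid_hypotheses, decoded_valid, valid_attention_scores, \
    scores_by_lang, by_language = validate_on_data(
        model=model,
        data=dev_data,
        batch_size=32,
        use_cuda=False,
        max_output_length=100,
        src_level="bpe",
        trg_level="bpe",
        eval_metrics=["bleu", "chrf"],
        attn_metrics=None,
        loss_function=xent_loss,
        beam_size=5,     # beam search; 0 would mean greedy decoding
        beam_alpha=1,    # > 0 enables the length penalty scorer
        batch_type="sentence")
print(valid_scores.get("bleu"), valid_scores.get("ppl"))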
Example #11
def transformer_greedy(src_mask: Tensor, max_output_length: int, model: Model,
                       encoder_output: Tensor, encoder_hidden: Tensor,
                       trg_embed: Embeddings) -> (np.array, None):
    """
    Special greedy function for transformer, since it works differently.
    The transformer remembers all previous states and attends to them.

    :param src_mask: mask for source inputs, 0 for positions after </s>
    :param max_output_length: maximum length for the hypotheses
    :param model: model to use for greedy decoding
    :param encoder_output: encoder hidden states for attention
    :param encoder_hidden: encoder final state (unused in Transformer)
    :param trg_embed: target embeddings, used for the nearest-neighbour
        prediction below
    :return:
        - stacked_output: output hypotheses (2d array of indices),
        - stacked_attention_scores: attention scores (always None in this
          implementation)
    """

    with torch.no_grad():

        bos_index = model.bos_index
        eos_index = model.eos_index
        batch_size = src_mask.size(0)

        # start with BOS-symbol for each sentence in the batch
        ys = encoder_output.new_full([batch_size, 1],
                                     bos_index,
                                     dtype=torch.long)

        # a subsequent mask is intersected with this in decoder forward pass
        trg_mask = src_mask.new_ones([1, 1, 1])
        if isinstance(model, torch.nn.DataParallel):
            trg_mask = torch.stack(
                [src_mask.new_ones([1, 1]) for _ in model.device_ids])

        finished = src_mask.new_zeros(batch_size).byte()

        for _ in range(max_output_length):
            # pylint: disable=unused-variable
            logits, _, _, _ = model(
                return_type="decode",
                trg_input=ys,  # model.trg_embed(ys) # embed the previous tokens
                encoder_output=encoder_output,
                encoder_hidden=None,
                src_mask=src_mask,
                unroll_steps=None,
                decoder_hidden=None,
                trg_mask=trg_mask)

            assert False, "reimplement along lines of final RNN version"

            # logits = logits[:, -1]
            # _, next_word = torch.max(logits, dim=1)
            pred = logits[:, -1].unsqueeze(1)

            losses = model._loss_function(pred,
                                          None,
                                          trg_embed,
                                          do_nearest_neighbor=True)
            next_word = torch.argmin(losses, dim=-1).data

            ys = torch.cat([ys, next_word.unsqueeze(-1)], dim=1)

            # check if previous symbol was <eos>
            is_eos = torch.eq(next_word, eos_index)
            finished += is_eos
            # stop predicting if <eos> reached for all elements in batch
            if (finished >= 1).sum() == batch_size:
                break

        ys = ys[:, 1:]  # remove BOS-symbol
    return ys.detach().cpu().numpy(), None
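Note that the loop above deliberately fails (`assert False`) and replaces the usual argmax step with a nearest-neighbour lookup against the target embeddings. For reference, a minimal sketch of the standard greedy step that the commented-out lines allude to, assuming `logits` of shape `[batch_size, steps, vocab_size]`:

import torch

def standard_greedy_step(logits: torch.Tensor, ys: torch.Tensor) -> torch.Tensor:
    # take the scores for the most recent position and pick the argmax token
    last_logits = logits[:, -1]
    next_word = torch.argmax(last_logits, dim=1)
    # append the predicted token to the growing hypotheses
    return torch.cat([ys, next_word.unsqueeze(-1)], dim=1)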
Example #12
def validate_on_data(model: Model,
                     data: Dataset,
                     batch_size: int,
                     use_cuda: bool,
                     max_output_length: int,
                     level: str,
                     eval_metric: Optional[str],
                     loss_function: torch.nn.Module = None,
                     beam_size: int = 0,
                     beam_alpha: int = -1,
                     batch_type: str = "sentence",
                     kb_task=None,
                     valid_kb: Dataset = None,
                     valid_kb_lkp: list = [],
                     valid_kb_lens: list = [],
                     valid_kb_truvals: Dataset = None,
                     valid_data_canon: Dataset = None,
                     report_on_canonicals: bool = False,
                     ) \
        -> (float, float, float, List[str], List[List[str]], List[str],
            List[str], List[List[str]], List[np.array], List[np.array],
            float, float):
    """
    Generate translations for the given data.
    If `loss_function` is not None and references are given,
    also compute the loss.

    :param model: model module
    :param data: dataset for validation
    :param batch_size: validation batch size
    :param use_cuda: if True, use CUDA
    :param max_output_length: maximum length for generated hypotheses
    :param level: segmentation level, one of "char", "bpe", "word"
    :param eval_metric: evaluation metric, e.g. "bleu"
    :param loss_function: loss function that computes a scalar loss
        for given inputs and targets
    :param beam_size: beam size for validation.
        If 0 then greedy decoding (default).
    :param beam_alpha: beam search alpha for length penalty,
        disabled if set to -1 (default).
    :param batch_type: validation batch type (sentence or token)
    :param kb_task: is not None if kb_task should be executed
    :param valid_kb: MonoDataset holding the loaded valid kb data
    :param valid_kb_lkp: List with valid example index to corresponding kb indices
    :param valid_kb_lens: List with the number of triples per KB
    :param valid_kb_truvals: Dataset holding the true (non-canonical) KB values
    :param valid_data_canon: TranslationDataset of the valid data, but with
        canonized target data (for loss reporting)
    :param report_on_canonicals: if True, report entity scores on the
        canonical target data

    :return:
        - current_valid_score: current validation score [eval_metric],
        - valid_loss: validation loss,
        - valid_ppl: validation perplexity,
        - valid_sources: validation sources,
        - valid_sources_raw: raw validation sources (before post-processing),
        - valid_references: validation references,
        - valid_hypotheses: validation_hypotheses,
        - decoded_valid: raw validation hypotheses (before post-processing),
        - valid_attention_scores: attention scores for validation hypotheses,
        - valid_kb_att_scores: KB attention scores for validation hypotheses,
        - valid_ent_f1: entity F1 score (-1 if there is no KB task),
        - valid_ent_mcc: entity MCC score (-1 if there is no KB task)
    """

    print(f"\n{'-'*10} ENTER VALIDATION {'-'*10}\n")

    print(f"\n{'-'*10}  VALIDATION DEBUG {'-'*10}\n")

    print("---data---")
    print(dir(data[0]))
    print([[
        getattr(example, attr) for attr in dir(example)
        if hasattr(getattr(example, attr), "__iter__") and "kb" in attr
        or "src" in attr or "trg" in attr
    ] for example in data[:3]])
    print(batch_size)
    print(use_cuda)
    print(max_output_length)
    print(level)
    print(eval_metric)
    print(loss_function)
    print(beam_size)
    print(beam_alpha)
    print(batch_type)
    print(kb_task)
    print("---valid_kb---")
    print(dir(valid_kb[0]))
    print([[
        getattr(example, attr) for attr in dir(example)
        if hasattr(getattr(example, attr), "__iter__") and "kb" in attr
        or "src" in attr or "trg" in attr
    ] for example in valid_kb[:3]])
    print(len(valid_kb_lkp), valid_kb_lkp[-5:])
    print(len(valid_kb_lens), valid_kb_lens[-5:])
    print("---valid_kb_truvals---")
    print(len(valid_kb_truvals), valid_kb_truvals[-5:])
    print([[
        getattr(example, attr) for attr in dir(example)
        if hasattr(getattr(example, attr), "__iter__") and "kb" in attr
        or "src" in attr or "trg" in attr or "trv" in attr
    ] for example in valid_kb_truvals[:3]])
    print("---valid_data_canon---")
    print(len(valid_data_canon), valid_data_canon[-5:])
    print([[
        getattr(example, attr) for attr in dir(example)
        if hasattr(getattr(example, attr), "__iter__") and "kb" in attr
        or "src" in attr or "trg" in attr or "trv" or "can" in attr
    ] for example in valid_data_canon[:3]])
    print(report_on_canonicals)

    print(f"\n{'-'*10} END VALIDATION DEBUG {'-'*10}\n")

    if not kb_task:
        valid_iter = make_data_iter(dataset=data,
                                    batch_size=batch_size,
                                    batch_type=batch_type,
                                    shuffle=False,
                                    train=False)
    else:
        # knowledgebase version of make data iter and also provide canonized target data
        # data: for bleu/ent f1
        # canon_data: for loss
        valid_iter = make_data_iter_kb(data,
                                       valid_kb,
                                       valid_kb_lkp,
                                       valid_kb_lens,
                                       valid_kb_truvals,
                                       batch_size=batch_size,
                                       batch_type=batch_type,
                                       shuffle=False,
                                       train=False,
                                       canonize=model.canonize,
                                       canon_data=valid_data_canon)

    valid_sources_raw = data.src
    pad_index = model.src_vocab.stoi[PAD_TOKEN]

    # disable dropout
    model.eval()
    # don't track gradients during validation
    with torch.no_grad():
        all_outputs = []
        valid_attention_scores = []
        valid_kb_att_scores = []
        total_loss = 0
        total_ntokens = 0
        total_nseqs = 0
        for valid_batch in iter(valid_iter):
            # run as during training to get validation loss (e.g. xent)

            batch = Batch(valid_batch, pad_index, use_cuda=use_cuda) \
                                if not kb_task else \
                Batch_with_KB(valid_batch, pad_index, use_cuda=use_cuda)

            assert hasattr(batch, "kbsrc") == bool(kb_task)

            # sort batch now by src length and keep track of order
            if not kb_task:
                sort_reverse_index = batch.sort_by_src_lengths()
            else:
                sort_reverse_index = list(range(batch.src.shape[0]))

            # run as during training with teacher forcing
            if loss_function is not None and batch.trg is not None:

                ntokens = batch.ntokens
                if hasattr(batch, "trgcanon") and batch.trgcanon is not None:
                    ntokens = batch.ntokenscanon  # normalize loss with num canonical tokens for perplexity
                # do a loss calculation without grad updates just to report valid loss
                # we can only do this when batch.trg exists, so not during actual translation/deployment
                batch_loss = model.get_loss_for_batch(
                    batch, loss_function=loss_function)
                # keep track of metrics for reporting
                total_loss += batch_loss
                total_ntokens += ntokens  # gold target tokens
                total_nseqs += batch.nseqs

            # run as during inference to produce translations
            output, attention_scores, kb_att_scores = model.run_batch(
                batch=batch,
                beam_size=beam_size,
                beam_alpha=beam_alpha,
                max_output_length=max_output_length)

            # sort outputs back to original order
            all_outputs.extend(output[sort_reverse_index])
            valid_attention_scores.extend(
                attention_scores[sort_reverse_index]
                if attention_scores is not None else [])
            valid_kb_att_scores.extend(kb_att_scores[sort_reverse_index]
                                       if kb_att_scores is not None else [])

        assert len(all_outputs) == len(data)

        if loss_function is not None and total_ntokens > 0:
            # total validation loss
            valid_loss = total_loss
            # exponent of token-level negative log likelihood
            # can be seen as 2^(cross_entropy of model on valid set); normalized by num tokens;
            # see https://en.wikipedia.org/wiki/Perplexity#Perplexity_per_word
            valid_ppl = torch.exp(valid_loss / total_ntokens)
        else:
            valid_loss = -1
            valid_ppl = -1

        # decode back to symbols
        decoding_vocab = model.trg_vocab if not kb_task else model.trv_vocab

        decoded_valid = decoding_vocab.arrays_to_sentences(arrays=all_outputs,
                                                           cut_at_eos=True)

        print(f"decoding_vocab.itos: {decoding_vocab.itos}")
        print(decoded_valid)

        # evaluate with metric on full dataset
        join_char = " " if level in ["word", "bpe"] else ""
        valid_sources = [join_char.join(s) for s in data.src]
        # TODO replace valid_references with uncanonicalized dev.car data ... requires writing new Dataset in data.py
        valid_references = [join_char.join(t) for t in data.trg]
        valid_hypotheses = [join_char.join(t) for t in decoded_valid]

        # post-process
        if level == "bpe":
            valid_sources = [bpe_postprocess(s) for s in valid_sources]
            valid_references = [bpe_postprocess(v) for v in valid_references]
            valid_hypotheses = [bpe_postprocess(v) for v in valid_hypotheses]

        # if references are given, evaluate against them
        if valid_references:
            assert len(valid_hypotheses) == len(valid_references)

            print(list(zip(valid_sources, valid_references, valid_hypotheses)))

            current_valid_score = 0
            if eval_metric.lower() == 'bleu':
                # this version does not use any tokenization
                current_valid_score = bleu(valid_hypotheses, valid_references)
            elif eval_metric.lower() == 'chrf':
                current_valid_score = chrf(valid_hypotheses, valid_references)
            elif eval_metric.lower() == 'token_accuracy':
                current_valid_score = token_accuracy(valid_hypotheses,
                                                     valid_references,
                                                     level=level)
            elif eval_metric.lower() == 'sequence_accuracy':
                current_valid_score = sequence_accuracy(
                    valid_hypotheses, valid_references)

            if kb_task:
                valid_ent_f1, valid_ent_mcc = calc_ent_f1_and_ent_mcc(
                    valid_hypotheses,
                    valid_references,
                    vocab=model.trv_vocab,
                    c_fun=model.canonize,
                    report_on_canonicals=report_on_canonicals)

            else:
                valid_ent_f1, valid_ent_mcc = -1, -1
        else:
            current_valid_score = -1
            # no references: entity metrics cannot be computed either
            valid_ent_f1, valid_ent_mcc = -1, -1

    print(f"\n{'-'*10} EXIT VALIDATION {'-'*10}\n")
    return current_valid_score, valid_loss, valid_ppl, valid_sources, \
        valid_sources_raw, valid_references, valid_hypotheses, \
        decoded_valid, valid_attention_scores, valid_kb_att_scores, \
        valid_ent_f1, valid_ent_mcc
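The validation perplexity returned above is the exponential of the token-normalized loss, as described in the docstring. A tiny numeric illustration with made-up values:

import torch

total_loss = torch.tensor(230.26)  # summed token-level negative log-likelihood
total_ntokens = 100                # number of gold target tokens
valid_ppl = torch.exp(total_loss / total_ntokens)
print(round(valid_ppl.item(), 2))  # ~10.0: on average the model is as
                                   # uncertain as a uniform choice over
                                   # 10 tokens per position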
Example #13
    def __init__(self, model: Model, config: dict) -> None:
        """
        Creates a new TrainManager for a model, specified as in configuration.

        :param model: torch module defining the model
        :param config: dictionary containing the training configurations
        """
        train_config = config["training"]

        # files for logging and storing
        self.model_dir = train_config["model_dir"]
        make_model_dir(
            self.model_dir, overwrite=train_config.get("overwrite", False)
        )
        self.logger = make_logger(model_dir=self.model_dir)
        self.logging_freq = train_config.get("logging_freq", 100)
        self.valid_report_file = join(self.model_dir, "validations.txt")
        self.tb_writer = SummaryWriter(
            log_dir=join(self.model_dir, "tensorboard/")
        )

        # model
        self.model = model
        self.pad_index = self.model.pad_index
        self._log_parameters_list()

        # objective
        objective = train_config.get("loss", "cross_entropy")
        loss_alpha = train_config.get("loss_alpha", 1.5)

        assert loss_alpha >= 1
        # possible refactoring: don't handle label smoothing here; instead
        # look up the base loss (e.g. nn.CrossEntropyLoss) in loss_funcs and
        # either use it directly or wrap it in FYLabelSmoothingLoss
        # (see the comment sketch after self.loss is set below)
        if objective == "softmax":
            objective = "cross_entropy"
        loss_funcs = {
            "cross_entropy": nn.CrossEntropyLoss,
            "entmax15": partial(Entmax15Loss, k=512),
            "sparsemax": partial(SparsemaxLoss, k=512),
            "entmax": partial(EntmaxBisectLoss, alpha=loss_alpha, n_iter=30)
        }
        if objective not in loss_funcs:
            raise ConfigurationError("Unknown loss function")

        loss_module = loss_funcs[objective]
        loss_func = loss_module(ignore_index=self.pad_index, reduction='sum')

        label_smoothing = train_config.get("label_smoothing", 0.0)
        label_smoothing_type = train_config.get("label_smoothing_type", "fy")
        assert label_smoothing_type in ["fy", "szegedy"]
        smooth_dist = train_config.get("smoothing_distribution", "uniform")
        assert smooth_dist in ["uniform", "unigram"]
        if label_smoothing > 0:
            if label_smoothing_type == "fy":
                # label smoothing entmax loss
                if smooth_dist == "unigram":
                    # smooth towards the empirical unigram distribution
                    smooth_p = torch.FloatTensor(model.trg_vocab.frequencies)
                    smooth_p /= smooth_p.sum()
                else:
                    # uniform smoothing
                    smooth_p = None
                loss_func = FYLabelSmoothingLoss(
                    loss_func, smoothing=label_smoothing, smooth_p=smooth_p
                )
            else:
                assert objective == "cross_entropy"
                loss_func = LabelSmoothingLoss(
                    ignore_index=self.pad_index,
                    reduction="sum",
                    smoothing=label_smoothing
                )
        self.loss = loss_func
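        # Comment-only sketch (not executed) of how self.loss is used
        # downstream, assuming logits of shape [batch*steps, vocab_size] and
        # flattened gold targets; with reduction="sum" and
        # ignore_index=self.pad_index the result is the summed token-level
        # loss over non-pad positions, which validation later divides by the
        # token count before exponentiating into a perplexity:
        #
        #   batch_loss = self.loss(logits.view(-1, logits.size(-1)),
        #                          targets.view(-1))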

        self.norm_type = train_config.get("normalization", "batch")
        if self.norm_type not in ["batch", "tokens"]:
            raise ConfigurationError("Invalid normalization. "
                                     "Valid options: 'batch', 'tokens'.")

        # optimization
        self.learning_rate_min = train_config.get("learning_rate_min", 1.0e-8)

        self.clip_grad_fun = build_gradient_clipper(config=train_config)
        self.optimizer = build_optimizer(
            config=train_config, parameters=model.parameters())

        # validation & early stopping
        self.validate_by_label = train_config.get("validate_by_label", False)
        self.validation_freq = train_config.get("validation_freq", 1000)
        self.log_valid_sents = train_config.get("print_valid_sents", [0, 1, 2])
        self.plot_attention = train_config.get("plot_attention", False)
        self.ckpt_queue = queue.Queue(
            maxsize=train_config.get("keep_last_ckpts", 5))

        allowed = {'bleu', 'chrf', 'token_accuracy',
                   'sequence_accuracy', 'cer', "wer", "levenshtein_distance"}
        eval_metrics = train_config.get("eval_metric", "bleu")
        if isinstance(eval_metrics, str):
            eval_metrics = [eval_metrics]
        if any(metric not in allowed for metric in eval_metrics):
            ok_metrics = " ".join(allowed)
            raise ConfigurationError("Invalid setting for 'eval_metric', "
                                     "valid options: {}".format(ok_metrics))
        self.eval_metrics = eval_metrics
        self.forced_sparsity = train_config.get("forced_sparsity", False)

        early_stop_metric = train_config.get("early_stopping_metric", "loss")
        allowed_early_stop = {"ppl", "loss"} | set(self.eval_metrics)
        if early_stop_metric not in allowed_early_stop:
            raise ConfigurationError(
                "Invalid setting for 'early_stopping_metric', "
                "valid options: 'loss', 'ppl', and eval_metrics.")
        self.early_stopping_metric = early_stop_metric
        min_metrics = {"ppl", "loss", "cer", "wer", "levenshtein_distance"}
        self.minimize_metric = early_stop_metric in min_metrics

        # learning rate scheduling
        hidden_size = _parse_hidden_size(config["model"])
        self.scheduler, self.sched_incr = build_scheduler(
            config=train_config,
            scheduler_mode="min" if self.minimize_metric else "max",
            optimizer=self.optimizer,
            hidden_size=hidden_size)

        # data & batch handling
        # src/trg magic
        if "level" in config["data"]:
            self.src_level = self.trg_level = config["data"]["level"]
        else:
            assert "src_level" in config["data"]
            assert "trg_level" in config["data"]
            self.src_level = config["data"]["src_level"]
            self.trg_level = config["data"]["trg_level"]

        self.shuffle = train_config.get("shuffle", True)
        self.epochs = train_config["epochs"]
        self.batch_size = train_config["batch_size"]
        self.batch_type = train_config.get("batch_type", "sentence")
        self.eval_batch_size = train_config.get("eval_batch_size",
                                                self.batch_size)
        self.eval_batch_type = train_config.get("eval_batch_type",
                                                self.batch_type)

        self.batch_multiplier = train_config.get("batch_multiplier", 1)

        # generation
        self.max_output_length = train_config.get("max_output_length", None)

        # CPU / GPU
        self.use_cuda = train_config["use_cuda"]
        if self.use_cuda:
            self.model.cuda()
            self.loss.cuda()

        # initialize training statistics
        self.steps = 0
        # stop training if this flag is set (e.g. when the learning rate
        # reaches its minimum)
        self.stop = False
        self.total_tokens = 0
        self.best_ckpt_iteration = 0
        # initial values for best scores
        self.best_ckpt_score = np.inf if self.minimize_metric else -np.inf

        mrt_schedule = train_config.get("mrt_schedule", None)
        assert mrt_schedule is None or mrt_schedule in ["warmup", "mix", "mtl"]
        self.mrt_schedule = mrt_schedule
        self.mrt_p = train_config.get("mrt_p", 0.0)
        self.mrt_lambda = train_config.get("mrt_lambda", 1.0)
        assert 0 <= self.mrt_p <= 1
        assert 0 <= self.mrt_lambda <= 1
        self.mrt_start_steps = train_config.get("mrt_start_steps", 0)
        self.mrt_samples = train_config.get("mrt_samples", 1)
        self.mrt_alpha = train_config.get("mrt_alpha", 1.0)
        self.mrt_strategy = train_config.get("mrt_strategy", "sample")
        self.mrt_cost = train_config.get("mrt_cost", "levenshtein")
        self.mrt_max_len = train_config.get("mrt_max_len", 31)  # hmm
        self.step_counter = count()

        assert self.mrt_alpha > 0
        assert self.mrt_strategy in ["sample", "topk"]
        assert self.mrt_cost in ["levenshtein", "bleu"]

        # model parameters
        if "load_model" in train_config.keys():
            model_load_path = train_config["load_model"]
            reset_training = train_config.get("reset_training", False)
            self.logger.info("Loading model from %s", model_load_path)
            self.init_from_checkpoint(model_load_path, reset=reset_training)