コード例 #1
0
ファイル: extractor.py プロジェクト: tnq177/witwicky
    def __init__(self, args):
        """Dump selected model tensors to ``.npy`` files.

        Builds the configuration named by ``args.proto``, restores the
        checkpoint ``args.model_file`` into a fresh ``Model`` and saves every
        attribute path listed in ``args.var_list`` as
        ``<args.save_to>/<attribute path>.npy``.

        :param args: namespace providing ``proto``, ``model_file``,
            ``var_list`` and ``save_to``
        :raises ValueError: if ``var_list`` is ``None`` or empty, or if the
            model file is missing
        """
        super(Extractor, self).__init__()
        config = getattr(configurations, args.proto)()
        self.logger = ut.get_logger(config['log_file'])
        self.model_file = args.model_file

        var_list = args.var_list
        save_to = args.save_to

        # An empty list would make operator.attrgetter() fail below with an
        # unrelated TypeError, so reject it together with None.
        if not var_list:
            raise ValueError('Empty var list')

        if self.model_file is None or not os.path.exists(self.model_file):
            raise ValueError('Input file or model file does not exist')

        if not os.path.exists(save_to):
            os.makedirs(save_to)

        self.logger.info('Extracting these vars: {}'.format(
            ', '.join(var_list)))

        model = Model(config)
        # This method never moves the model to a device, so load the
        # checkpoint onto the CPU; without map_location a checkpoint written
        # from a GPU cannot be read on a machine without CUDA.
        model.load_state_dict(
            torch.load(self.model_file, map_location='cpu'))
        var_values = operator.attrgetter(*var_list)(model)

        # attrgetter returns a bare object (not a tuple) for a single name.
        if len(var_list) == 1:
            var_values = [var_values]

        for var, var_value in zip(var_list, var_values):
            var_path = os.path.join(save_to, var + '.npy')
            # Tensor.numpy() refuses tensors that require grad (which model
            # parameters do) and tensors that are not on the CPU.
            numpy.save(var_path, var_value.detach().cpu().numpy())
コード例 #2
0
def create_train_model():
    """Assemble the training graph.

    Creates a new ``tf.Graph`` and, inside it, the vocabulary tables, the
    datasets, the training iterator and the model outputs, then hands those
    outputs to a ``Train`` object.

    :return: ``(graph, iterator, train)`` - the new graph, the training
        iterator and the configured ``Train`` object
    """
    graph = tf.Graph()
    train_mode = tf.contrib.learn.ModeKeys.TRAIN

    model = Model(train_mode, hyper_parameters)
    data = DatasetIterator(hyper_parameters)
    trainer = Train(train_mode, hyper_parameters)

    with graph.as_default():
        src_table, tgt_table = data.get_tables(share_vocab=False)
        src_dataset, tgt_dataset = data.get_datasets()

        # TODO: a skip_count placeholder probably needs to be passed as well.
        iterator_kwargs = dict(
            source_vocab_table=src_table,
            target_vocab_table=tgt_table,
            source_dataset=src_dataset,
            target_dataset=tgt_dataset,
            source_max_len=hyper_parameters["source_max_len_train"],
            target_max_len=hyper_parameters["target_max_len_train"],
        )
        batch_iterator = data.get_iterator(**iterator_kwargs)

        outputs = model.build_model(batch_iterator, tgt_table)
        logits, loss, final_context_state, sample_id = outputs

        trainer.configure_train_eval_infer(
            iterator=batch_iterator,
            logits=logits,
            loss=loss,
            sample_id=sample_id,
            final_state=final_context_state)

    return graph, batch_iterator, trainer
コード例 #3
0
    def __init__(self, args):
        """Prepare everything needed for a training run.

        Reads the configuration named by ``args.proto``, picks the device,
        creates the data manager, validator, model and Adam optimizer, and
        zeroes the running statistics used for logging.

        :param args: namespace providing ``proto`` and ``num_preload``
        """
        super(Trainer, self).__init__()
        cfg = getattr(configurations, args.proto)()
        self.config = cfg
        self.num_preload = args.num_preload
        self.logger = ut.get_logger(cfg['log_file'])

        # First CUDA device when one is available, CPU otherwise.
        if torch.cuda.is_available():
            device_name = 'cuda:0'
        else:
            device_name = 'cpu'
        self.device = torch.device(device_name)

        # Hyper-parameters copied straight from the configuration.
        self.normalize_loss = cfg['normalize_loss']
        self.patience = cfg['patience']
        self.lr = cfg['lr']
        self.lr_decay = cfg['lr_decay']
        self.max_epochs = cfg['max_epochs']
        self.warmup_steps = cfg['warmup_steps']

        # Perplexity history.
        self.train_smooth_perps = []
        self.train_true_perps = []

        self.data_manager = DataManager(cfg)
        self.validator = Validator(cfg, self.data_manager)

        # validate_freq counts epochs when val_per_epoch is set, batches
        # otherwise.
        self.val_per_epoch = cfg['val_per_epoch']
        self.validate_freq = int(cfg['validate_freq'])
        if self.val_per_epoch:
            freq_unit = 'epochs'
        else:
            freq_unit = 'batches'
        self.logger.info('Evaluate every {} {}'.format(
            self.validate_freq, freq_unit))

        # Logging state: statistics are reported every log_freq batches.
        self.log_freq = 100
        self.log_train_loss = self.log_nll_loss = self.log_train_weights = 0.
        self.num_batches_done = 0  # batches done over the whole training
        self.epoch_batches_done = 0  # batches done in the current epoch
        # Per-epoch totals: train loss, NLL loss, weights (# target words)
        # and execution time.
        self.epoch_loss = self.epoch_nll_loss = 0.
        self.epoch_weights = self.epoch_time = 0.

        # Model, moved to the selected device.
        self.model = Model(cfg).to(self.device)

        param_count = sum(
            numpy.prod(param.size()) for param in self.model.parameters())
        self.logger.info('Model has {:,} parameters'.format(param_count))

        # Optimizer.
        adam_options = {
            'lr': self.lr,
            'betas': (cfg['beta1'], cfg['beta2']),
            'eps': cfg['epsilon'],
        }
        self.optimizer = torch.optim.Adam(self.model.parameters(),
                                          **adam_options)
コード例 #4
0
    def translate(self):
        """Beam-decode ``self.input_file`` and write the results to disk.

        Restores the model from ``self.model_file``, decodes every batch
        produced by ``self.data_manager.get_trans_input`` and writes two
        files next to the input: ``<input_file>.best_trans`` (one line per
        sentence) and ``<input_file>.beam_trans`` (entries separated by a
        blank line). Each result is stored at the index reported for it in
        ``original_idxs``.
        """
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        model = Model(self.config).to(device)
        self.logger.info('Restore model from {}'.format(self.model_file))
        # map_location lets a checkpoint written from a GPU be restored on a
        # host without CUDA; without it torch.load tries to put the tensors
        # back on the device they were saved from.
        model.load_state_dict(
            torch.load(self.model_file, map_location=device))
        model.eval()

        best_trans_file = self.input_file + '.best_trans'
        beam_trans_file = self.input_file + '.beam_trans'
        # Create/truncate both outputs before the decoding work starts.
        open(best_trans_file, 'w').close()
        open(beam_trans_file, 'w').close()

        # One output slot per non-blank input line.
        with open(self.input_file, 'r') as f:
            num_sents = sum(1 for line in f if line.strip())
        all_best_trans = [''] * num_sents
        all_beam_trans = [''] * num_sents

        with torch.no_grad():
            self.logger.info('Start translating {}'.format(self.input_file))
            start = time.time()
            count = 0
            for (src_toks, original_idxs) in self.data_manager.get_trans_input(
                    self.input_file):
                src_toks_cuda = src_toks.to(device)
                rets = model.beam_decode(src_toks_cuda)

                for i, ret in enumerate(rets):
                    probs = ret['probs'].cpu().detach().numpy().reshape([-1])
                    scores = ret['scores'].cpu().detach().numpy().reshape([-1])
                    symbols = ret['symbols'].cpu().detach().numpy()

                    # The second element returned by get_trans is not needed
                    # here.
                    best_trans, _, beam_trans = self.get_trans(
                        probs, scores, symbols)
                    all_best_trans[original_idxs[i]] = best_trans + '\n'
                    all_beam_trans[original_idxs[i]] = beam_trans + '\n\n'

                    count += 1
                    if count % 100 == 0:
                        self.logger.info(
                            '  Translating line {}, average {} seconds/sent'.
                            format(count, (time.time() - start) / count))

        model.train()

        with open(best_trans_file, 'w') as ftrans, open(beam_trans_file,
                                                        'w') as btrans:
            ftrans.write(''.join(all_best_trans))
            btrans.write(''.join(all_beam_trans))

        self.logger.info('Done translating {}, it takes {} minutes'.format(
            self.input_file,
            float(time.time() - start) / 60.0))
コード例 #5
0
ファイル: translate.py プロジェクト: sshearing/mt-project
 def get_model(self, mode):
     """Build a ``Model`` for ``mode`` inside the configured variable scope.

     The scope is named after ``config['model_name']``. Variables are
     reused for every mode except ``ac.TRAINING``, and the scope's default
     initializer is ``tf.random_uniform_initializer(-d, d)`` with
     ``d = config['init_range']``.
     """
     share_variables = mode != ac.TRAINING
     init_range = self.config['init_range']
     scope_options = {
         'reuse': share_variables,
         'initializer': tf.random_uniform_initializer(-init_range,
                                                      init_range),
     }
     with tf.variable_scope(self.config['model_name'], **scope_options):
         return Model(self.config, mode)
コード例 #6
0
def create_eval_model():
    """Assemble the evaluation graph.

    Source and target sentences are read from text files whose paths are
    fed through two scalar string placeholders. A reverse target vocabulary
    table (index to string, ``"UNK"`` as default) is built from
    ``hyper_parameters["target_vocab_file"]``.

    :return: ``(eval_graph, evaluator, eval_iterator,
        source_file_placeholder, target_file_placeholder)`` where
        ``evaluator`` is the configured ``Train`` object
    """
    eval_graph = tf.Graph()

    mode = tf.contrib.learn.ModeKeys.EVAL

    eval_model = Model(mode, hyper_parameters)
    dataset_iterator = DatasetIterator(hyper_parameters)
    # Named ``evaluator`` rather than ``eval`` so the builtin is not shadowed.
    evaluator = Train(mode, hyper_parameters)

    with eval_graph.as_default():
        target_vocab_file = hyper_parameters["target_vocab_file"]

        source_vocab_table, target_vocab_table = dataset_iterator.get_tables(
            share_vocab=False)

        reverse_target_vocab_table = tf.contrib.lookup.index_to_string_table_from_file(
            target_vocab_file, default_value="UNK")

        # The data files are chosen at run time through these placeholders.
        source_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
        target_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
        source_dataset = tf.data.TextLineDataset(source_file_placeholder)
        target_dataset = tf.data.TextLineDataset(target_file_placeholder)

        eval_iterator = dataset_iterator.get_iterator(
            source_vocab_table=source_vocab_table,
            target_vocab_table=target_vocab_table,
            source_dataset=source_dataset,
            target_dataset=target_dataset,
            source_max_len=hyper_parameters["source_max_len_infer"],
            target_max_len=hyper_parameters["target_max_len_infer"])

        logits, loss, final_context_state, sample_id = eval_model.build_model(
            eval_iterator, target_vocab_table)

        evaluator.configure_train_eval_infer(
            iterator=eval_iterator,
            logits=logits,
            loss=loss,
            sample_id=sample_id,
            final_state=final_context_state,
            reverse_target_vocab_table=reverse_target_vocab_table)

    return (eval_graph, evaluator, eval_iterator, source_file_placeholder,
            target_file_placeholder)
コード例 #7
0
def create_infer_model():
    """Assemble the inference graph.

    Source sentences are fed through a 1-D string placeholder and turned
    into a dataset with ``tf.data.Dataset.from_tensor_slices``. A reverse
    target vocabulary table (index to string, ``"UNK"`` as default) is built
    from ``hyper_parameters["target_vocab_file"]``.

    :return: ``(infer_graph, infer, infer_iterator, source_placeholder,
        batch_size_placeholder)`` where ``infer`` is the configured ``Train``
        object
    """
    infer_graph = tf.Graph()
    mode = tf.contrib.learn.ModeKeys.INFER

    infer_model = Model(mode, hyper_parameters)
    dataset_iterator = DatasetIterator(hyper_parameters)
    infer = Train(mode, hyper_parameters)

    with infer_graph.as_default():
        target_vocab_file = hyper_parameters["target_vocab_file"]

        reverse_target_vocab_table = tf.contrib.lookup.index_to_string_table_from_file(
            target_vocab_file, default_value="UNK")

        source_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
        # NOTE(review): this placeholder is returned to the caller but is not
        # passed to the iterator in this function - confirm where it is used.
        batch_size_placeholder = tf.placeholder(shape=[], dtype=tf.int64)

        source_dataset = tf.data.Dataset.from_tensor_slices(source_placeholder)

        source_vocab_table, target_vocab_table = dataset_iterator.get_tables(
            share_vocab=False)

        infer_iterator = dataset_iterator.get_infer_iterator(
            source_dataset=source_dataset,
            source_vocab_table=source_vocab_table,
            source_max_len=hyper_parameters["source_max_len_infer"])

        logits, loss, final_context_state, sample_id = infer_model.build_model(
            infer_iterator, target_vocab_table)

        infer.configure_train_eval_infer(
            iterator=infer_iterator,
            logits=logits,
            loss=loss,
            sample_id=sample_id,
            final_state=final_context_state,
            reverse_target_vocab_table=reverse_target_vocab_table)

    # Returned after leaving the graph context, like the train/eval builders.
    return (infer_graph, infer, infer_iterator, source_placeholder,
            batch_size_placeholder)
コード例 #8
0
    def __init__(self, args):
        """Restore a model and translate ``args.input_file``.

        Builds the configuration, resolves the model file (defaulting to
        ``<save_to>/<model_name>.pth``), loads the model and works out the
        output paths before calling ``self.translate()``. When an input file
        is given, the outputs are ``<prefix>.best_trans`` and
        ``<prefix>.beam_trans`` where ``<prefix>`` is the input's base name
        placed in ``config['save_to']`` with a trailing source-language
        suffix replaced by the target language.

        :param args: namespace providing ``proto``, ``config_overrides``,
            ``num_preload``, ``model_file`` and ``input_file``
        :raises FileNotFoundError: if the input file (when given) or the
            model file does not exist
        """
        super(Translator, self).__init__()
        self.config = configurations.get_config(
            args.proto, getattr(configurations, args.proto),
            args.config_overrides)
        self.logger = ut.get_logger(self.config['log_file'])
        self.num_preload = args.num_preload

        self.model_file = args.model_file
        if self.model_file is None:
            self.model_file = os.path.join(self.config['save_to'],
                                           self.config['model_name'] + '.pth')

        self.input_file = args.input_file
        if self.input_file is not None and not os.path.exists(self.input_file):
            raise FileNotFoundError(
                f'Input file does not exist: {self.input_file}')
        if not os.path.exists(self.model_file):
            raise FileNotFoundError(
                f'Model file does not exist: {self.model_file}')

        self.logger.info(f'Restore model from {self.model_file}')
        self.model = Model(self.config,
                           load_from=self.model_file).to(ut.get_device())

        if self.input_file:
            save_fp = os.path.join(self.config['save_to'],
                                   os.path.basename(self.input_file))
            # str.rstrip(chars) removes any trailing characters found in
            # ``chars``; it is not a suffix remover and could eat part of the
            # file name (e.g. 'sentences'.rstrip('es') == 'sentenc'). Drop
            # the language code only when it really is the suffix.
            src_lang = self.model.data_manager.src_lang
            if src_lang and save_fp.endswith(src_lang):
                save_fp = save_fp[:-len(src_lang)]
            save_fp = save_fp + self.model.data_manager.trg_lang
            self.best_output_fp = save_fp + '.best_trans'
            self.beam_output_fp = save_fp + '.beam_trans'
            # Create/truncate both output files up front.
            open(self.best_output_fp, 'w').close()
            open(self.beam_output_fp, 'w').close()
        else:
            self.best_output_fp = self.beam_output_fp = None

        self.translate()
コード例 #9
0
    def __init__(self, args):
        """Dump selected model variables to ``.npy`` files.

        Builds the configuration named by ``args.proto``, rebuilds the model
        in a fresh graph, restores the trainable variables from the
        checkpoint ``args.model_file`` and saves the evaluated value of every
        attribute path in ``args.var_list`` as
        ``<args.save_to>/<attribute path>.npy``.

        :param args: namespace providing ``proto``, ``model_file``,
            ``var_list`` and ``save_to``
        :raises ValueError: if ``var_list`` is ``None`` or empty, or if
            ``<model_file>.meta`` does not exist
        """
        super(Extractor, self).__init__()
        config = getattr(configurations, args.proto)()
        self.logger = ut.get_logger(config['log_file'])
        self.model_file = args.model_file

        var_list = args.var_list
        save_to = args.save_to

        # An empty list would make operator.attrgetter() fail below with an
        # unrelated TypeError, so reject it together with None.
        if not var_list:
            raise ValueError('Empty var list')

        # The existence check looks at the '.meta' file stored next to the
        # checkpoint prefix.
        if self.model_file is None or not os.path.exists(self.model_file + '.meta'):
            raise ValueError('Input file or model file does not exist')

        if not os.path.exists(save_to):
            os.makedirs(save_to)

        self.logger.info('Extracting these vars: {}'.format(', '.join(var_list)))

        with tf.Graph().as_default(), tf.Session() as sess:
            d = config['init_range']
            initializer = tf.random_uniform_initializer(-d, d)
            with tf.variable_scope(config['model_name'], reuse=False, initializer=initializer):
                model = Model(config, ac.TRAINING)

            saver = tf.train.Saver(var_list=tf.trainable_variables())
            saver.restore(sess, self.model_file)

            var_values = operator.attrgetter(*var_list)(model)
            var_values = sess.run(var_values)

            # For a single name attrgetter returns a bare object, so sess.run
            # returns a single value instead of a sequence.
            if len(var_list) == 1:
                var_values = [var_values]

            for var, var_value in izip(var_list, var_values):
                var_path = os.path.join(save_to, var + '.npy')
                numpy.save(var_path, var_value)
コード例 #10
0
    def __init__(self, model: Model, config: dict) -> None:
        """
        Build a TrainManager for ``model`` from the given configuration.

        :param model: torch module defining the model
        :param config: configuration dictionary; its "training", "data" and
            "model" sections are read here
        :raises ConfigurationError: if the normalization, eval metric, early
            stopping metric or segmentation level is not a supported value
        """
        cfg = config["training"]

        # --- output locations and logging ---
        self.model_dir = make_model_dir(
            cfg["model_dir"], overwrite=cfg.get("overwrite", False))
        self.logger = make_logger(model_dir=self.model_dir)
        self.logging_freq = cfg.get("logging_freq", 100)
        self.valid_report_file = "{}/validations.txt".format(self.model_dir)
        tensorboard_dir = self.model_dir + "/tensorboard/"
        self.tb_writer = SummaryWriter(log_dir=tensorboard_dir)

        # --- model ---
        self.model = model
        self.pad_index = self.model.pad_index
        self.bos_index = self.model.bos_index
        self._log_parameters_list()

        # --- training objective ---
        self.label_smoothing = cfg.get("label_smoothing", 0.0)
        self.loss = XentLoss(
            pad_index=self.pad_index, smoothing=self.label_smoothing)
        self.normalization = cfg.get("normalization", "batch")
        if self.normalization not in ("batch", "tokens"):
            raise ConfigurationError("Invalid normalization. "
                                     "Valid options: 'batch', 'tokens'.")

        # --- optimization ---
        self.learning_rate_min = cfg.get("learning_rate_min", 1.0e-8)
        self.clip_grad_fun = build_gradient_clipper(config=cfg)
        self.optimizer = build_optimizer(
            config=cfg, parameters=model.parameters())

        # --- validation and early stopping ---
        self.validation_freq = cfg.get("validation_freq", 1000)
        self.log_valid_sents = cfg.get("print_valid_sents", [0, 1, 2])
        max_kept_ckpts = cfg.get("keep_last_ckpts", 5)
        self.ckpt_queue = queue.Queue(maxsize=max_kept_ckpts)
        self.eval_metric = cfg.get("eval_metric", "bleu")
        if self.eval_metric not in ("bleu", "chrf"):
            raise ConfigurationError("Invalid setting for 'eval_metric', "
                                     "valid options: 'bleu', 'chrf'.")

        # The early stopping metric decides whether a new best score is a
        # new low (loss / ppl) or a new high (bleu / chrf).
        stop_metric = cfg.get("early_stopping_metric", "eval_metric")
        self.early_stopping_metric = stop_metric
        if stop_metric in ("ppl", "loss"):
            self.minimize_metric = True
        elif stop_metric == "eval_metric":
            # bleu and chrf are maximized; any other eval metric would have
            # to be minimized (none is implemented yet).
            self.minimize_metric = self.eval_metric not in ("bleu", "chrf")
        else:
            raise ConfigurationError(
                "Invalid setting for 'early_stopping_metric', "
                "valid options: 'loss', 'ppl', 'eval_metric'.")

        # --- learning rate scheduling ---
        scheduler_mode = "min" if self.minimize_metric else "max"
        encoder_hidden_size = config["model"]["encoder"]["hidden_size"]
        self.scheduler, self.scheduler_step_at = build_scheduler(
            config=cfg,
            scheduler_mode=scheduler_mode,
            optimizer=self.optimizer,
            hidden_size=encoder_hidden_size)

        # --- data and batching ---
        self.level = config["data"]["level"]
        if self.level not in ("word", "bpe", "char"):
            raise ConfigurationError("Invalid segmentation level. "
                                     "Valid options: 'word', 'bpe', 'char'.")
        self.shuffle = cfg.get("shuffle", True)
        self.epochs = cfg["epochs"]
        self.batch_size = cfg["batch_size"]
        self.batch_type = cfg.get("batch_type", "sentence")
        # Evaluation falls back to the training batch settings.
        self.eval_batch_size = cfg.get("eval_batch_size", self.batch_size)
        self.eval_batch_type = cfg.get("eval_batch_type", self.batch_type)
        self.batch_multiplier = cfg.get("batch_multiplier", 1)

        # --- generation ---
        self.max_output_length = cfg.get("max_output_length", None)

        # --- CPU / GPU ---
        self.use_cuda = cfg["use_cuda"]
        if self.use_cuda:
            self.model.cuda()
            self.loss.cuda()

        # --- training statistics ---
        self.steps = 0
        # set to True to stop training once the learning rate minimum is hit
        self.stop = False
        self.total_tokens = 0
        self.best_ckpt_iteration = 0
        # worst possible starting score for the chosen direction
        if self.minimize_metric:
            self.best_ckpt_score = np.inf
        else:
            self.best_ckpt_score = -np.inf
        # comparison of a new score against the best one so far
        self.is_best = lambda score: (
            score < self.best_ckpt_score if self.minimize_metric
            else score > self.best_ckpt_score)

        # --- optional warm start from a checkpoint ---
        if "load_model" in cfg:
            model_load_path = cfg["load_model"]
            self.logger.info("Loading model from %s", model_load_path)
            self.init_from_checkpoint(model_load_path)
コード例 #11
0
    def __init__(self, args):
        """Prepare everything needed for a training run.

        Builds the configuration, empties ``config['save_to']``, creates the
        logger, the model, its validator and an Adam optimizer with one
        parameter group per parameter, and resets the running statistics.

        :param args: namespace providing ``proto``, ``config_overrides`` and
            ``num_preload``
        """
        super(Trainer, self).__init__()
        proto_fn = getattr(configurations, args.proto)
        self.config = configurations.get_config(args.proto, proto_fn,
                                                args.config_overrides)
        cfg = self.config
        self.num_preload = args.num_preload
        self.lr = cfg['lr']

        # Empty the output directory, then open the logger.
        ut.remove_files_in_dir(cfg['save_to'])
        self.logger = ut.get_logger(cfg['log_file'])

        # Perplexity history.
        self.train_smooth_perps = []
        self.train_true_perps = []

        # Statistics collected between two log lines; a line is logged every
        # log_freq batches.
        self.log_freq = cfg['log_freq']
        self.log_train_loss = []
        self.log_nll_loss = []
        self.log_train_weights = []
        self.log_grad_norms = []
        self.total_batches = 0  # batches done over the whole training
        # Per-epoch totals: train loss, NLL loss, weights (# target words)
        # and execution time.
        self.epoch_loss = self.epoch_nll_loss = 0.
        self.epoch_weights = self.epoch_time = 0.

        # Model and validator.
        self.model = Model(cfg).to(ut.get_device())
        self.validator = Validator(cfg, self.model)

        # validate_freq counts epochs when val_per_epoch is set, batches
        # otherwise.
        self.validate_freq = cfg['validate_freq']
        if self.validate_freq == 1:
            unit = 'epoch' if cfg['val_per_epoch'] else 'batch'
            self.logger.info('Evaluate every ' + unit)
        else:
            prefix = f'Evaluate every {self.validate_freq:,} '
            unit = 'epochs' if cfg['val_per_epoch'] else 'batches'
            self.logger.info(prefix + unit)

        # Estimated number of batches per epoch.
        largest_tok_count = max(self.model.data_manager.training_tok_counts)
        self.est_batches = largest_tok_count // cfg['batch_size']
        self.logger.info(
            f'Guessing around {self.est_batches:,} batches per epoch')

        param_count = sum(
            numpy.prod(param.size()) for param in self.model.parameters())
        self.logger.info(f'Model has {int(param_count):,} parameters')

        # One optimizer group per parameter; options registered in
        # model.parameter_attrs (keyed by data pointer) are copied into the
        # parameter's group.
        param_groups = []
        for param in self.model.parameters():
            group = {'params': [param]}
            key = param.data_ptr()
            if key in self.model.parameter_attrs:
                extra = self.model.parameter_attrs[key]
                group.update((name, extra[name]) for name in extra)
            param_groups.append(group)

        adam_betas = (cfg['beta1'], cfg['beta2'])
        self.optimizer = torch.optim.Adam(param_groups,
                                          lr=self.lr,
                                          betas=adam_betas,
                                          eps=cfg['epsilon'])
コード例 #12
0
class Trainer(object):
    """Trainer"""
    def __init__(self, args):
        """Prepare everything needed for a training run.

        Builds the configuration, empties ``config['save_to']``, creates the
        logger, the model, its validator and an Adam optimizer with one
        parameter group per parameter, and resets the running statistics.

        :param args: namespace providing ``proto``, ``config_overrides`` and
            ``num_preload``
        """
        super(Trainer, self).__init__()
        proto_fn = getattr(configurations, args.proto)
        self.config = configurations.get_config(args.proto, proto_fn,
                                                args.config_overrides)
        cfg = self.config
        self.num_preload = args.num_preload
        self.lr = cfg['lr']

        # Empty the output directory, then open the logger.
        ut.remove_files_in_dir(cfg['save_to'])
        self.logger = ut.get_logger(cfg['log_file'])

        # Perplexity history.
        self.train_smooth_perps = []
        self.train_true_perps = []

        # Statistics collected between two log lines; a line is logged every
        # log_freq batches.
        self.log_freq = cfg['log_freq']
        self.log_train_loss = []
        self.log_nll_loss = []
        self.log_train_weights = []
        self.log_grad_norms = []
        self.total_batches = 0  # batches done over the whole training
        # Per-epoch totals: train loss, NLL loss, weights (# target words)
        # and execution time.
        self.epoch_loss = self.epoch_nll_loss = 0.
        self.epoch_weights = self.epoch_time = 0.

        # Model and validator.
        self.model = Model(cfg).to(ut.get_device())
        self.validator = Validator(cfg, self.model)

        # validate_freq counts epochs when val_per_epoch is set, batches
        # otherwise.
        self.validate_freq = cfg['validate_freq']
        if self.validate_freq == 1:
            unit = 'epoch' if cfg['val_per_epoch'] else 'batch'
            self.logger.info('Evaluate every ' + unit)
        else:
            prefix = f'Evaluate every {self.validate_freq:,} '
            unit = 'epochs' if cfg['val_per_epoch'] else 'batches'
            self.logger.info(prefix + unit)

        # Estimated number of batches per epoch.
        largest_tok_count = max(self.model.data_manager.training_tok_counts)
        self.est_batches = largest_tok_count // cfg['batch_size']
        self.logger.info(
            f'Guessing around {self.est_batches:,} batches per epoch')

        param_count = sum(
            numpy.prod(param.size()) for param in self.model.parameters())
        self.logger.info(f'Model has {int(param_count):,} parameters')

        # One optimizer group per parameter; options registered in
        # model.parameter_attrs (keyed by data pointer) are copied into the
        # parameter's group.
        param_groups = []
        for param in self.model.parameters():
            group = {'params': [param]}
            key = param.data_ptr()
            if key in self.model.parameter_attrs:
                extra = self.model.parameter_attrs[key]
                group.update((name, extra[name]) for name in extra)
            param_groups.append(group)

        adam_betas = (cfg['beta1'], cfg['beta2'])
        self.optimizer = torch.optim.Adam(param_groups,
                                          lr=self.lr,
                                          betas=adam_betas,
                                          eps=cfg['epsilon'])

    def report_epoch(self, epoch, batches):
        """Log the summary of a finished epoch and reset the epoch state.

        Logs duration and throughput, appends the epoch's smoothed and true
        perplexities to ``train_smooth_perps`` / ``train_true_perps``, stores
        ``batches`` as the new ``est_batches`` and clears the per-epoch and
        per-log accumulators.

        :param epoch: epoch identifier, used only in the log message
        :param batches: number of batches in the finished epoch
        """
        log = self.logger.info
        log(f'Finished epoch {epoch}')
        log(f'    Took {ut.format_time(self.epoch_time)}')
        log(f'    avg words/sec {self.epoch_weights / self.epoch_time:.2f}')
        log(f'    avg sec/batch {self.epoch_time / batches:.2f}')
        log(f'    {batches} batches')

        # Average loss per unit of weight; infinite when no weight was
        # accumulated.
        if self.epoch_weights:
            avg_losses = [self.epoch_loss / self.epoch_weights,
                          self.epoch_nll_loss / self.epoch_weights]
        else:
            avg_losses = [float('inf'), float('inf')]

        # The real batch count replaces the estimate; all accumulators
        # start over.
        self.est_batches = batches
        self.epoch_time = 0.
        self.epoch_nll_loss = 0.
        self.epoch_loss = 0.
        self.epoch_weights = 0.
        self.log_train_loss, self.log_nll_loss = [], []
        self.log_train_weights, self.log_grad_norms = [], []

        # Perplexity is exp(average loss); averages of 300 or more are
        # reported as infinity instead of being exponentiated.
        smooth_perp, true_perp = [
            numpy.exp(avg) if avg < 300 else float('inf')
            for avg in avg_losses
        ]
        self.train_smooth_perps.append(smooth_perp)
        self.train_true_perps.append(true_perp)

        log(f'    smooth, true perp: {float(smooth_perp):.2f}, {float(true_perp):.2f}')

    def clip_grad_values(self):
        """
        Clamp every gradient element to [-grad_clamp, grad_clamp], then zero NaNs.

        Adapted from https://pytorch.org/docs/stable/_modules/torch/nn/utils/clip_grad.html#clip_grad_value_
        Same as torch.nn.utils.clip_grad_value_, except that it also sets NaN
        gradients to 0.0. Parameters without a gradient are skipped.
        """
        params = self.model.parameters()
        limit = float(self.config['grad_clamp'])
        if isinstance(params, torch.Tensor):
            params = [params]
        for param in params:
            if param.grad is None:
                continue
            grad = param.grad.data
            grad.clamp_(min=-limit, max=limit)
            # clamp_ leaves NaN entries in place, so zero them afterwards
            grad[torch.isnan(grad)] = 0.0

    def get_params(self, pe=False):
        """Yield the model parameters selected by ``pe``.

        A parameter is yielded when the truth of "its name is listed in
        ``self.model.struct_params``" equals ``pe``: ``pe=True`` yields the
        listed parameters, ``pe=False`` (the default) yields all the others.
        """
        for name, param in self.model.named_parameters():
            listed = name in self.model.struct_params
            if listed != pe:
                continue
            yield param

    def run_log(self, batch, epoch, batch_data):
        """Run one optimization step and, every ``log_freq`` batches, log stats.

        The step consists of: forward pass, loss normalization according to
        ``config['normalize_loss']``, backward pass, gradient clamping and
        clipping, learning-rate adjustment and the optimizer update. The
        detached loss, NLL loss, target-word count and gradient norm of every
        call are buffered; when ``total_batches`` is a multiple of
        ``log_freq`` the buffers are summed, added to the epoch totals,
        logged as one line and cleared.

        :param batch: batch counter; passed to the model and used as the
            number of batches done when computing progress and speed
        :param epoch: epoch identifier; passed to the model and logged
        :param batch_data: 5-tuple whose last four items are ``src_toks``,
            ``src_structs``, ``trg_toks`` and ``targets`` (the first item is
            ignored)
        """
        #with torch.autograd.detect_anomaly(): # throws exception when any forward computation produces nan
        start = time.time()
        _, src_toks, src_structs, trg_toks, targets = batch_data

        # zero grad
        self.optimizer.zero_grad()

        # get loss
        ret = self.model(src_toks, src_structs, trg_toks, targets, batch,
                         epoch)
        loss = ret['loss']
        nll_loss = ret['nll_loss']

        # Normalize the loss that is optimized: per non-PAD target, per
        # first-dimension entry of `targets`, or not at all.
        if self.config['normalize_loss'] == ac.LOSS_TOK:
            opt_loss = loss / (targets != ac.PAD_ID).sum()
        elif self.config['normalize_loss'] == ac.LOSS_BATCH:
            opt_loss = loss / targets.size()[0]
        else:
            opt_loss = loss

        opt_loss.backward()
        # clip gradient: element-wise clamp first (also zeroes NaNs), then an
        # optional norm clip of the parameters selected by get_params(True),
        # then a norm clip over all parameters.
        if self.config['grad_clamp']: self.clip_grad_values()
        if self.config['grad_clip_pe']:
            pms = list(self.get_params(True))
            if pms:
                torch.nn.utils.clip_grad_norm_(pms,
                                               self.config['grad_clip_pe'])
            pms = self.get_params()
        else:
            pms = self.model.parameters()
        # NOTE(review): `pms` is assigned above but never read; the clip below
        # always runs over self.model.parameters(). Confirm whether `pms` was
        # meant to be passed here.
        grad_norm = torch.nn.utils.clip_grad_norm_(
            self.model.parameters(), self.config['grad_clip']).detach()

        # update
        self.adjust_lr()
        self.optimizer.step()

        # update training stats
        num_words = (targets != ac.PAD_ID).detach().sum()

        loss = loss.detach()
        nll_loss = nll_loss.detach()
        self.total_batches += 1
        self.log_train_loss.append(loss)
        self.log_nll_loss.append(nll_loss)
        self.log_train_weights.append(num_words)
        self.log_grad_norms.append(grad_norm)
        self.epoch_time += time.time() - start

        # Everything below runs only on logging steps; note that the epoch
        # totals (epoch_loss, epoch_nll_loss, epoch_weights) are updated only
        # inside this branch.
        if self.total_batches % self.log_freq == 0:

            log_train_loss = torch.tensor(0.0)
            log_nll_loss = torch.tensor(0.0)
            log_train_weights = torch.tensor(0.0)
            log_all_weights = torch.tensor(0.0)
            # With grad_clamp enabled, batches whose losses are not finite are
            # left out of the loss sums and of log_train_weights, but their
            # word counts still go into log_all_weights.
            for smooth, nll, weight in zip(self.log_train_loss,
                                           self.log_nll_loss,
                                           self.log_train_weights):
                if not self.config['grad_clamp'] or (torch.isfinite(smooth)
                                                     and torch.isfinite(nll)):
                    log_train_loss += smooth
                    log_nll_loss += nll
                    log_train_weights += weight
                log_all_weights += weight
            #log_train_loss = sum(x for x in self.log_train_loss).item()
            #log_nll_loss = sum(x for x in self.log_nll_loss).item()
            #log_train_weights = sum(x for x in self.log_train_weights).item()
            # Perplexity is exp(average loss per word); averages of 300 or
            # more are reported as infinity instead of being exponentiated.
            avg_smooth_perp = log_train_loss / log_train_weights
            avg_smooth_perp = numpy.exp(
                avg_smooth_perp) if avg_smooth_perp < 300 else float('inf')
            avg_true_perp = log_nll_loss / log_train_weights
            avg_true_perp = numpy.exp(
                avg_true_perp) if avg_true_perp < 300 else float('inf')

            self.epoch_loss += log_train_loss
            self.epoch_nll_loss += log_nll_loss
            self.epoch_weights += log_all_weights

            # Speed since the start of the epoch: words per second and
            # seconds per batch (assumes `batch` is non-zero here - TODO
            # confirm against the caller).
            acc_speed_word = self.epoch_weights / self.epoch_time
            acc_speed_time = self.epoch_time / batch

            avg_grad_norm = sum(self.log_grad_norms) / len(self.log_grad_norms)
            #median_grad_norm = sorted(self.log_grad_norms)[len(self.log_grad_norms)//2]

            # Progress and remaining time relative to the estimated number of
            # batches per epoch; unknown ('?') once the estimate is exceeded.
            est_percent = int(100 * batch / self.est_batches)
            epoch_len = max(5, ut.get_num_digits(self.config['max_epochs']))
            batch_len = max(5, ut.get_num_digits(self.est_batches))
            if batch > self.est_batches: remaining = '?'
            else:
                remaining = ut.format_time(acc_speed_time *
                                           (self.est_batches - batch))

            # Start a fresh logging window.
            self.log_train_loss = []
            self.log_nll_loss = []
            self.log_train_weights = []
            self.log_grad_norms = []
            # One log line: epoch, batch, progress, remaining time, words/sec,
            # sec/batch, smoothed perplexity, true perplexity, gradient norm.
            cells = [
                f'{epoch:{epoch_len}}', f'{batch:{batch_len}}',
                f'{est_percent:3}%', f'{remaining:>9}',
                f'{acc_speed_word:#10.4g}', f'{acc_speed_time:#6.4g}s',
                f'{avg_smooth_perp:#11.4g}', f'{avg_true_perp:#9.4g}',
                f'{avg_grad_norm:#9.4g}'
            ]
            self.logger.info('  '.join(cells))

    def adjust_lr(self):
        """Set the optimizer learning rate for the upcoming update.

        The schedule is selected by ``config['warmup_style']`` and is driven
        by ``step = self.total_batches + 1.0``:

        * ``ac.ORG_WARMUP``: ``embed_dim**-0.5 * step * warmup_steps**-1.5``
          while ``step < warmup_steps``; afterwards
          ``embed_dim**-0.5 * step**-0.5``, never below ``min_lr``.
        * ``ac.FIXED_WARMUP``: linear ramp from ``start_lr`` to ``lr`` while
          ``step < warmup_steps``; afterwards
          ``lr * warmup_steps**0.5 * step**-0.5``, never below ``min_lr``.
        * ``ac.UPFLAT_WARMUP``: the same linear ramp; once the ramp is over
          the learning rate is left untouched by this method.
        * any other style: nothing is changed.
        """
        cfg = self.config
        style = cfg['warmup_style']

        if style == ac.ORG_WARMUP:
            step = self.total_batches + 1.0
            if step < cfg['warmup_steps']:
                new_lr = cfg['embed_dim']**(
                    -0.5) * step * cfg['warmup_steps']**(-1.5)
            else:
                new_lr = max(cfg['embed_dim']**(-0.5) * step**(-0.5),
                             cfg['min_lr'])
        elif style == ac.FIXED_WARMUP:
            ramp_len = cfg['warmup_steps']
            step = self.total_batches + 1.0
            first_lr, top_lr, floor_lr = (cfg['start_lr'], cfg['lr'],
                                          cfg['min_lr'])
            if step < ramp_len:
                new_lr = first_lr + (top_lr - first_lr) * step / ramp_len
            else:
                new_lr = max(floor_lr,
                             top_lr * ramp_len**(0.5) * step**(-0.5))
        elif style == ac.UPFLAT_WARMUP:
            ramp_len = cfg['warmup_steps']
            step = self.total_batches + 1.0
            # 'min_lr' is looked up but not used by this schedule.
            first_lr, top_lr, _floor_lr = (cfg['start_lr'], cfg['lr'],
                                           cfg['min_lr'])
            if not step < ramp_len:
                # Flat phase: keep whatever rate is currently set.
                return
            new_lr = first_lr + (top_lr - first_lr) * step / ramp_len
        else:
            # Unknown / no warmup: the learning rate is managed elsewhere.
            return

        for group in self.optimizer.param_groups:
            group['lr'] = new_lr

    def train(self):
        """Run the full training loop, then translate with the best checkpoint.

        For each epoch (1 .. ``config['max_epochs']``) every training batch is
        passed to ``run_log``. When ``config['val_per_epoch']`` is false,
        ``maybe_validate`` is called after every batch; otherwise it is called
        after every ``validate_freq``-th epoch. A truthy result from
        ``maybe_validate`` stops training early.

        After the loop the per-epoch training perplexities are logged and
        saved as ``.npy`` files under ``config['save_to']``, the model is
        saved, and - if the test source file exists - the best checkpoint is
        restored and used to translate the test and dev files.
        """
        self.model.train()
        stop_early = False

        patience_span = self.config['early_stop_patience'] * self.validate_freq
        patience_unit = 'epochs' if self.config['val_by_bleu'] else 'batches'
        early_stop_msg = f'No improvement for last {patience_span} {patience_unit}; stopping early!'

        for epoch in range(1, self.config['max_epochs'] + 1):
            # Stays 0 when the epoch yields no batches at all.
            batch = 0
            batches = self.model.data_manager.get_batches(
                mode=ac.TRAINING, num_preload=self.num_preload)
            for batch, batch_data in enumerate(batches, start=1):
                if batch == 1:
                    # Print the table header once per epoch; the first two
                    # column titles are right-aligned to the number width.
                    self.logger.info(f'Begin epoch {epoch}')
                    epoch_width = ut.get_num_digits(self.config['max_epochs'])
                    batch_width = ut.get_num_digits(self.est_batches)
                    header = [
                        'epoch'.rjust(epoch_width),
                        'batch'.rjust(batch_width),
                        'est%',
                        'remaining',
                        'trg word/s',
                        's/batch',
                        'smooth perp',
                        'true perp',
                        'grad norm',
                    ]
                    self.logger.info('  '.join(header))

                self.run_log(batch, epoch, batch_data)
                if self.config['val_per_epoch']:
                    continue
                stop_early = self.maybe_validate()
                if stop_early:
                    self.logger.info(early_stop_msg)
                    break

            if stop_early:
                # The batch loop asked to stop: leave the epoch loop as well.
                break

            self.report_epoch(epoch, batch)
            if self.config['val_per_epoch'] and epoch % self.validate_freq == 0:
                stop_early = self.maybe_validate(just_validate=True)
                if stop_early:
                    self.logger.info(early_stop_msg)
                    break

        if not (self.config['val_by_bleu'] or stop_early):
            # validate 1 last time
            self.maybe_validate(just_validate=True)

        self.logger.info('Training finished')
        perp_reports = (
            ('Train smooth perps:', self.train_smooth_perps),
            ('Train true perps:', self.train_true_perps),
        )
        for title, perps in perp_reports:
            self.logger.info(title)
            self.logger.info(', '.join(f'{x:.2f}' for x in perps))

        perp_dumps = (
            ('train_smooth_perps.npy', self.train_smooth_perps),
            ('train_true_perps.npy', self.train_true_perps),
        )
        for file_name, perps in perp_dumps:
            numpy.save(os.path.join(self.config['save_to'], file_name), perps)

        self.model.save()

        # Evaluate test
        data_manager = self.model.data_manager
        test_file = data_manager.data_files[ac.TESTING][data_manager.src_lang]
        dev_file = data_manager.data_files[ac.VALIDATING][
            data_manager.src_lang]
        if not os.path.exists(test_file):
            return

        self.logger.info('Evaluate test')
        self.restart_to_best_checkpoint()
        self.model.save()
        self.validator.translate(test_file, to_ids=True)
        self.logger.info('Translate dev set')
        self.validator.translate(dev_file, to_ids=True)

    def restart_to_best_checkpoint(self):
        """Load the weights of the best validation checkpoint into the model.

        "Best" is the highest entry of ``validator.best_bleus`` when
        ``config['val_by_bleu']`` is set, otherwise the lowest entry of
        ``validator.best_perps``. The checkpoint path is obtained from
        ``validator.get_cpkt_path`` for that score.
        """
        if self.config['val_by_bleu']:
            pick, scores = numpy.max, self.validator.best_bleus
        else:
            pick, scores = numpy.min, self.validator.best_perps
        best_cpkt_path = self.validator.get_cpkt_path(pick(scores))

        self.logger.info(f'Restore best cpkt from {best_cpkt_path}')
        state = torch.load(best_cpkt_path)
        self.model.load_state_dict(state)

    def is_patience_exhausted(self, patience, if_worst=False):
        """Tell whether the dev score has stalled for ``patience`` validations.

        The curve inspected is ``validator.bleu_curve`` when
        ``config['val_by_bleu']`` is truthy (higher is better), otherwise
        ``validator.perp_curve`` (lower is better). Only the last
        ``patience + 1`` points are considered.

        Args:
            patience: number of validations to look back; a falsy value
                (``0`` / ``None``) disables the check.
            if_worst: when false (default), report whether the point
                ``patience`` validations ago is still the best of the window,
                i.e. nothing since has improved on it. When true, report
                whether the most recent point is the worst of the window.

        Returns:
            bool: ``True`` if the condition holds; ``False`` if it does not,
            if ``patience`` is falsy, or if the curve has ``patience`` or
            fewer points.
        """
        if not patience:
            return False

        by_bleu = bool(self.config['val_by_bleu'])
        curve = self.validator.bleu_curve if by_bleu else self.validator.perp_curve
        if len(curve) <= patience:
            return False

        window = curve[-1 - patience:]
        # "best" is max for BLEU and min for perplexity; "worst" is the
        # opposite. Compare truth values rather than identities so that
        # flags such as 1/0 behave exactly like True/False.
        pick = max if by_bleu != bool(if_worst) else min
        probe = curve[-1] if if_worst else curve[-1 - patience]
        return bool(probe == pick(window))

    def maybe_validate(self, just_validate=False):
        """Validate and checkpoint when due, anneal the LR, report early stop.

        Validation runs when ``self.total_batches`` is a multiple of
        ``self.validate_freq`` or when ``just_validate`` is true. It saves the
        model and calls ``validator.validate_and_save()``.

        After a validation the learning rate may be annealed, but only when
        ``config['lr_decay'] > 0`` and the schedule leaves the rate to this
        method: ``ac.NO_WARMUP``, or ``ac.UPFLAT_WARMUP`` once the warmup
        steps are over. In that case, if the latest dev score is the worst of
        the last ``config['lr_decay_patience'] + 1`` scores, the rate is
        multiplied by ``lr_decay`` unless that would drop it below
        ``config['min_lr']``.

        Args:
            just_validate: force a validation regardless of the batch count.

        Returns:
            The result of
            ``is_patience_exhausted(config['early_stop_patience'])``; truthy
            means training should stop early.
        """
        if self.total_batches % self.validate_freq == 0 or just_validate:
            self.model.save()
            self.validator.validate_and_save()

            # if doing annealing
            step = self.total_batches + 1.0
            warmup_steps = self.config['warmup_steps']
            warmup_style = self.config['warmup_style']

            # Grouped explicitly: `and` binds tighter than `or`, so without
            # the grouping the `lr_decay > 0` guard would not apply to the
            # NO_WARMUP case and a zero decay factor could zero the rate.
            schedule_allows_decay = (
                warmup_style == ac.NO_WARMUP
                or (warmup_style == ac.UPFLAT_WARMUP and step >= warmup_steps))

            if schedule_allows_decay and self.config['lr_decay'] > 0:

                if self.is_patience_exhausted(self.config['lr_decay_patience'],
                                              if_worst=True):
                    if self.config['val_by_bleu']:
                        metric = 'bleu'
                        scores = self.validator.bleu_curve
                    else:
                        metric = 'perp'
                        scores = self.validator.perp_curve
                    scores = ', '.join([
                        str(x)
                        for x in scores[-1 - self.config['lr_decay_patience']:]
                    ])

                    self.logger.info(f'Past {metric} scores are {scores}')
                    # when don't use warmup, decay lr if dev not improve
                    if self.lr * self.config['lr_decay'] >= self.config[
                            'min_lr']:
                        new_lr = self.lr * self.config['lr_decay']
                        self.logger.info(
                            f'Anneal the learning rate from {self.lr} to {new_lr}'
                        )
                        self.lr = new_lr
                        for p in self.optimizer.param_groups:
                            p['lr'] = self.lr
        return self.is_patience_exhausted(self.config['early_stop_patience'])
コード例 #13
0
class Trainer(object):
    """Train a ``Model`` with Adam, periodic logging and validation.

    The configuration is produced by the function named ``args.proto`` in
    ``configurations``. The trainer owns the data manager, the validator,
    the model and the optimizer, and keeps running statistics for the
    current logging window and the current epoch.
    """
    def __init__(self, args):
        """Build the data pipeline, model and optimizer.

        Args:
            args: object providing ``proto`` (name of the configuration
                function) and ``num_preload`` (forwarded to the data
                manager when batches are requested).
        """
        super(Trainer, self).__init__()
        self.config = getattr(configurations, args.proto)()
        self.num_preload = args.num_preload
        self.logger = ut.get_logger(self.config['log_file'])

        self.device = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')

        self.normalize_loss = self.config['normalize_loss']
        self.patience = self.config['patience']
        self.lr = self.config['lr']
        self.lr_decay = self.config['lr_decay']
        self.max_epochs = self.config['max_epochs']
        self.warmup_steps = self.config['warmup_steps']

        # Per-epoch training perplexities, appended by report_epoch.
        self.train_smooth_perps = []
        self.train_true_perps = []

        self.data_manager = DataManager(self.config)
        self.validator = Validator(self.config, self.data_manager)

        self.val_per_epoch = self.config['val_per_epoch']
        self.validate_freq = int(self.config['validate_freq'])
        self.logger.info('Evaluate every {} {}'.format(
            self.validate_freq, 'epochs' if self.val_per_epoch else 'batches'))

        # For logging
        self.log_freq = 100  # log train stat every this-many batches
        self.log_train_loss = 0.  # total train loss every log_freq batches
        self.log_nll_loss = 0.
        self.log_train_weights = 0.
        self.num_batches_done = 0  # number of batches done for the whole training
        self.epoch_batches_done = 0  # number of batches done for this epoch
        self.epoch_loss = 0.  # total train loss for whole epoch
        self.epoch_nll_loss = 0.  # total train loss for whole epoch
        self.epoch_weights = 0.  # total train weights (# target words) for whole epoch
        self.epoch_time = 0.  # total exec time for whole epoch, sounds like that tabloid

        # get model
        self.model = Model(self.config).to(self.device)

        param_count = sum(
            [numpy.prod(p.size()) for p in self.model.parameters()])
        self.logger.info('Model has {:,} parameters'.format(param_count))

        # get optimizer
        beta1 = self.config['beta1']
        beta2 = self.config['beta2']
        epsilon = self.config['epsilon']
        self.optimizer = torch.optim.Adam(self.model.parameters(),
                                          lr=self.lr,
                                          betas=(beta1, beta2),
                                          eps=epsilon)

    def report_epoch(self, e):
        """Log the statistics of epoch ``e`` and reset the epoch counters.

        The smoothed and true training perplexities of the epoch are appended
        to ``train_smooth_perps`` / ``train_true_perps``. A perplexity is
        reported as ``inf`` when the average loss per word is 300 or more.
        """
        self.logger.info('Finish epoch {}'.format(e))
        self.logger.info('    It takes {}'.format(
            ut.format_seconds(self.epoch_time)))
        self.logger.info('    Avergage # words/second    {}'.format(
            self.epoch_weights / self.epoch_time))
        self.logger.info('    Average seconds/batch    {}'.format(
            self.epoch_time / self.epoch_batches_done))

        # Average loss per target word, computed before the counters reset.
        train_smooth_perp = self.epoch_loss / self.epoch_weights
        train_true_perp = self.epoch_nll_loss / self.epoch_weights

        self.epoch_batches_done = 0
        self.epoch_time = 0.
        self.epoch_nll_loss = 0.
        self.epoch_loss = 0.
        self.epoch_weights = 0.

        train_smooth_perp = numpy.exp(
            train_smooth_perp) if train_smooth_perp < 300 else float('inf')
        self.train_smooth_perps.append(train_smooth_perp)
        train_true_perp = numpy.exp(
            train_true_perp) if train_true_perp < 300 else float('inf')
        self.train_true_perps.append(train_true_perp)

        self.logger.info(
            '    smoothed train perplexity: {}'.format(train_smooth_perp))
        self.logger.info(
            '    true train perplexity: {}'.format(train_true_perp))

    def run_log(self, b, e, batch_data):
        """Run one optimization step on ``batch_data`` and update statistics.

        Args:
            b: index of the batch within the epoch (used for logging only).
            e: zero-based epoch index (logged as ``e + 1``).
            batch_data: tuple ``(src_toks, trg_toks, targets)`` of tensors.

        The loss returned by the model is divided by the number of non-pad
        target entries (``ac.LOSS_TOK``), by the size of the first dimension
        of ``targets`` (``ac.LOSS_BATCH``; presumably the batch size), or
        left as is. Gradients are clipped to ``config['grad_clip']``. Every
        ``log_freq`` batches the window statistics are logged and reset.
        """
        start = time.time()
        src_toks, trg_toks, targets = batch_data
        src_toks_cuda = src_toks.to(self.device)
        trg_toks_cuda = trg_toks.to(self.device)
        targets_cuda = targets.to(self.device)

        # zero grad
        self.optimizer.zero_grad()

        # get loss
        ret = self.model(src_toks_cuda, trg_toks_cuda, targets_cuda)
        loss = ret['loss']
        nll_loss = ret['nll_loss']
        if self.normalize_loss == ac.LOSS_TOK:
            opt_loss = loss / (targets_cuda != ac.PAD_ID).type(
                loss.type()).sum()
        elif self.normalize_loss == ac.LOSS_BATCH:
            # size()[0] is a plain Python int (it has no .type method), so
            # divide by it directly.
            opt_loss = loss / targets_cuda.size()[0]
        else:
            opt_loss = loss

        opt_loss.backward()
        # clip gradient
        global_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                                     self.config['grad_clip'])

        # update
        self.adjust_lr()
        self.optimizer.step()

        # update training stats
        num_words = (targets != ac.PAD_ID).detach().numpy().sum()

        # Statistics use the un-normalized loss values.
        loss = loss.cpu().detach().numpy()
        nll_loss = nll_loss.cpu().detach().numpy()
        self.num_batches_done += 1
        self.log_train_loss += loss
        self.log_nll_loss += nll_loss
        self.log_train_weights += num_words

        self.epoch_batches_done += 1
        self.epoch_loss += loss
        self.epoch_nll_loss += nll_loss
        self.epoch_weights += num_words
        self.epoch_time += time.time() - start

        if self.num_batches_done % self.log_freq == 0:
            acc_speed_word = self.epoch_weights / self.epoch_time
            acc_speed_time = self.epoch_time / self.epoch_batches_done

            avg_smooth_perp = self.log_train_loss / self.log_train_weights
            avg_smooth_perp = numpy.exp(
                avg_smooth_perp) if avg_smooth_perp < 300 else float('inf')
            avg_true_perp = self.log_nll_loss / self.log_train_weights
            avg_true_perp = numpy.exp(
                avg_true_perp) if avg_true_perp < 300 else float('inf')

            self.log_train_loss = 0.
            self.log_nll_loss = 0.
            self.log_train_weights = 0.

            self.logger.info('Batch {}, epoch {}/{}:'.format(
                b, e + 1, self.max_epochs))
            self.logger.info(
                '   avg smooth perp:   {0:.2f}'.format(avg_smooth_perp))
            self.logger.info(
                '   avg true perp:   {0:.2f}'.format(avg_true_perp))
            self.logger.info('   acc trg words/s: {}'.format(
                int(acc_speed_word)))
            self.logger.info(
                '   acc sec/batch:   {0:.2f}'.format(acc_speed_time))
            self.logger.info('   global norm:     {0:.2f}'.format(global_norm))

    def adjust_lr(self):
        """Apply the ``ac.ORG_WARMUP`` schedule; do nothing for other styles.

        With ``step = num_batches_done + 1.0`` the rate is
        ``embed_dim**-0.5 * step * warmup_steps**-1.5`` while
        ``step < warmup_steps`` and ``embed_dim**-0.5 * step**-0.5``
        afterwards, never below ``config['min_lr']``.
        """
        if self.config['warmup_style'] == ac.ORG_WARMUP:
            step = self.num_batches_done + 1.0
            if step < self.config['warmup_steps']:
                lr = self.config['embed_dim']**(
                    -0.5) * step * self.config['warmup_steps']**(-1.5)
            else:
                lr = max(self.config['embed_dim']**(-0.5) * step**(-0.5),
                         self.config['min_lr'])

            for p in self.optimizer.param_groups:
                p['lr'] = lr

    def train(self):
        """Train for ``max_epochs`` epochs, then translate dev and test sets.

        Validation happens after every batch (subject to ``validate_freq``)
        or after every ``validate_freq``-th epoch, depending on
        ``val_per_epoch``. After training, the perplexity history is logged
        and saved, a final checkpoint is written, and for every entry of
        ``data_manager.checkpoints`` whose test file exists the best
        checkpoint is restored and used to translate the test and dev files.
        """
        self.model.train()
        train_ids_file = self.data_manager.data_files['ids']
        for e in range(self.max_epochs):
            b = 0
            for batch_data in self.data_manager.get_batch(
                    ids_file=train_ids_file,
                    shuffle=True,
                    num_preload=self.num_preload):
                b += 1
                self.run_log(b, e, batch_data)
                if not self.val_per_epoch:
                    self.maybe_validate()

            self.report_epoch(e + 1)
            if self.val_per_epoch and (e + 1) % self.validate_freq == 0:
                self.maybe_validate(just_validate=True)

        # validate 1 last time
        if not self.config['val_per_epoch']:
            self.maybe_validate(just_validate=True)

        self.logger.info('It is finally done, mate!')
        self.logger.info('Train smoothed perps:')
        self.logger.info(', '.join(map(str, self.train_smooth_perps)))
        self.logger.info('Train true perps:')
        self.logger.info(', '.join(map(str, self.train_true_perps)))
        numpy.save(join(self.config['save_to'], 'train_smooth_perps.npy'),
                   self.train_smooth_perps)
        numpy.save(join(self.config['save_to'], 'train_true_perps.npy'),
                   self.train_true_perps)

        self.logger.info('Save final checkpoint')
        self.save_checkpoint()

        # Evaluate on test
        for checkpoint in self.data_manager.checkpoints:
            self.logger.info('Translate for {}'.format(checkpoint))
            dev_file = self.data_manager.dev_files[checkpoint][
                self.data_manager.src_lang]
            test_file = self.data_manager.test_files[checkpoint][
                self.data_manager.src_lang]
            if exists(test_file):
                self.logger.info('  Evaluate on test')
                self.restart_to_best_checkpoint(checkpoint)
                self.validator.translate(self.model, test_file)
                self.logger.info('  Also translate dev')
                self.validator.translate(self.model, dev_file)

    def save_checkpoint(self):
        """Write the model's state dict to ``<save_to>/<model_name>.pth``."""
        cpkt_path = join(self.config['save_to'],
                         '{}.pth'.format(self.config['model_name']))
        torch.save(self.model.state_dict(), cpkt_path)

    def restart_to_best_checkpoint(self, checkpoint):
        """Load the lowest-perplexity checkpoint recorded for ``checkpoint``."""
        best_perp = numpy.min(self.validator.best_perps[checkpoint])
        best_cpkt_path = self.validator.get_cpkt_path(checkpoint, best_perp)

        self.logger.info('Restore best cpkt from {}'.format(best_cpkt_path))
        self.model.load_state_dict(torch.load(best_cpkt_path))

    def maybe_validate(self, just_validate=False):
        """Checkpoint and validate when due; anneal the LR if dev got worse.

        Runs when ``num_batches_done`` is a multiple of ``validate_freq`` or
        when ``just_validate`` is true. With ``ac.NO_WARMUP`` and a positive
        ``lr_decay``, the rate is multiplied by ``lr_decay`` when the latest
        dev perplexity is higher than each of the ``patience`` values before
        it, unless that would push the rate below ``config['min_lr']``.
        """
        if self.num_batches_done % self.validate_freq == 0 or just_validate:
            self.save_checkpoint()
            self.validator.validate_and_save(self.model)

            # if doing annealing
            if self.config[
                    'warmup_style'] == ac.NO_WARMUP and self.lr_decay > 0:
                cond = len(
                    self.validator.perp_curve
                ) > self.patience and self.validator.perp_curve[-1] > max(
                    self.validator.perp_curve[-1 - self.patience:-1])
                if cond:
                    metric = 'perp'
                    scores = self.validator.perp_curve[-1 - self.patience:]
                    scores = map(str, list(scores))
                    scores = ', '.join(scores)

                    self.logger.info('Past {} are {}'.format(metric, scores))
                    # when don't use warmup, decay lr if dev not improve
                    if self.lr * self.lr_decay >= self.config['min_lr']:
                        self.logger.info(
                            'Anneal the learning rate from {} to {}'.format(
                                self.lr, self.lr * self.lr_decay))
                        self.lr = self.lr * self.lr_decay
                        for p in self.optimizer.param_groups:
                            p['lr'] = self.lr
コード例 #14
0
class Translator(object):
    """Restore a trained model and translate a file (or stdin) with it.

    Constructing the object performs the translation: ``__init__`` ends by
    calling ``translate()``.
    """
    def __init__(self, args):
        """Load the configuration and model, prepare outputs, then translate.

        Args:
            args: object providing ``proto``, ``config_overrides``,
                ``num_preload``, ``model_file`` (optional; defaults to
                ``<save_to>/<model_name>.pth``) and ``input_file``
                (optional; stdin/stdout are used when it is not given).

        Raises:
            FileNotFoundError: if the input file is given but missing, or if
                the model file does not exist.
        """
        super(Translator, self).__init__()
        self.config = configurations.get_config(
            args.proto, getattr(configurations, args.proto),
            args.config_overrides)
        self.logger = ut.get_logger(self.config['log_file'])
        self.num_preload = args.num_preload

        self.model_file = args.model_file
        if self.model_file is None:
            self.model_file = os.path.join(self.config['save_to'],
                                           self.config['model_name'] + '.pth')

        self.input_file = args.input_file
        if self.input_file is not None and not os.path.exists(self.input_file):
            raise FileNotFoundError(
                f'Input file does not exist: {self.input_file}')
        if not os.path.exists(self.model_file):
            raise FileNotFoundError(
                f'Model file does not exist: {self.model_file}')

        self.logger.info(f'Restore model from {self.model_file}')
        self.model = Model(self.config,
                           load_from=self.model_file).to(ut.get_device())

        if self.input_file:
            save_fp = os.path.join(self.config['save_to'],
                                   os.path.basename(self.input_file))
            save_fp = self._swap_lang_suffix(
                save_fp, self.model.data_manager.src_lang,
                self.model.data_manager.trg_lang)
            self.best_output_fp = save_fp + '.best_trans'
            self.beam_output_fp = save_fp + '.beam_trans'
            # Create/truncate both outputs; translate() appends to them.
            for output_fp in (self.best_output_fp, self.beam_output_fp):
                with open(output_fp, 'w'):
                    pass
        else:
            self.best_output_fp = self.beam_output_fp = None

        self.translate()

    @staticmethod
    def _swap_lang_suffix(path, src_lang, trg_lang):
        """Return ``path`` with a trailing ``src_lang`` replaced by ``trg_lang``.

        Only an exact trailing ``src_lang`` is removed. ``str.rstrip`` is
        deliberately not used: it strips any run of the given characters and
        would eat into names that merely end in some of those letters. When
        ``path`` does not end with ``src_lang`` (or ``src_lang`` is empty),
        ``trg_lang`` is simply appended.
        """
        if src_lang and path.endswith(src_lang):
            path = path[:-len(src_lang)]
        return path + trg_lang

    def translate(self):
        """Translate the input and write best/beam outputs.

        Reads from ``self.input_file`` or stdin. Best translations go to
        ``self.best_output_fp`` (appended) or stdout; beam outputs go to
        ``self.beam_output_fp`` (appended) or are not written (``None`` is
        passed to the model). Files opened here are closed even if the
        translation raises.
        """
        best_stream = open(self.best_output_fp,
                           'a') if self.best_output_fp else sys.stdout
        beam_stream = None
        try:
            if self.beam_output_fp:
                beam_stream = open(self.beam_output_fp, 'a')
            self.model.translate(self.input_file or sys.stdin,
                                 best_stream,
                                 beam_stream,
                                 to_ids=True,
                                 num_preload=self.num_preload)
        finally:
            # Never close sys.stdout; only close what was opened above.
            if self.best_output_fp:
                best_stream.close()
            if beam_stream is not None:
                beam_stream.close()

    def plot_head_map(self, mma, target_labels, target_ids, source_labels,
                      source_ids, filename):
        """Save a heat map of the matrix ``mma`` to ``filename``.

        Columns are labelled with ``source_labels`` and rows with
        ``target_labels``; a label whose id equals ``ac.UNK_ID`` is drawn in
        blue.

        https://github.com/EdinburghNLP/nematus/blob/master/utils/plot_heatmap.py
        Change the font in family param below. If the system font is not used, delete matplotlib
        font cache https://github.com/matplotlib/matplotlib/issues/3590
        """
        import matplotlib
        matplotlib.use('Agg')
        import matplotlib.pyplot as plt

        fig, ax = plt.subplots()
        ax.pcolor(mma, cmap=plt.cm.Blues)

        # put the major ticks at the middle of each cell
        ax.set_xticks(numpy.arange(mma.shape[1]) + 0.5, minor=False)
        ax.set_yticks(numpy.arange(mma.shape[0]) + 0.5, minor=False)

        # without this I get some extra columns rows
        # http://stackoverflow.com/questions/31601351/why-does-this-matplotlib-heatmap-have-an-extra-blank-column
        ax.set_xlim(0, int(mma.shape[1]))
        ax.set_ylim(0, int(mma.shape[0]))

        # want a more natural, table-like display
        ax.invert_yaxis()
        ax.xaxis.tick_top()

        # source words -> column labels
        ax.set_xticklabels(source_labels,
                           minor=False,
                           family='Source Code Pro')
        for xtick, idx in zip(ax.get_xticklabels(), source_ids):
            if idx == ac.UNK_ID:
                xtick.set_color('b')
        # target words -> row labels
        ax.set_yticklabels(target_labels,
                           minor=False,
                           family='Source Code Pro')
        for ytick, idx in zip(ax.get_yticklabels(), target_ids):
            if idx == ac.UNK_ID:
                ytick.set_color('b')

        plt.xticks(rotation=45)

        plt.tight_layout()
        plt.savefig(filename)
        plt.close('all')
コード例 #15
0
def validate_on_data(model: Model, data: Dataset,
                     batch_size: int,
                     use_cuda: bool, max_output_length: int,
                     level: str, eval_metric: Optional[str],
                     loss_function: torch.nn.Module = None,
                     beam_size: int = 0, beam_alpha: int = -1,
                     batch_type: str = "sentence"
                     ) \
        -> (float, float, float, List[str], List[List[str]], List[str],
            List[str], List[List[str]], List[np.array]):
    """
    Generate translations for the given data.
    If `loss_function` is not None and references are given,
    also compute the loss.

    :param model: model module
    :param data: dataset for validation
    :param batch_size: validation batch size
    :param use_cuda: if True, use CUDA
    :param max_output_length: maximum length for generated hypotheses
    :param level: segmentation level, one of "char", "bpe", "word"
    :param eval_metric: evaluation metric, one of "bleu", "chrf",
        "token_accuracy", "sequence_accuracy" (case-insensitive).
        If None or not one of these, the score is 0 when references exist.
    :param loss_function: loss function that computes a scalar loss
        for given inputs and targets
    :param beam_size: beam size for validation.
        If 0 then greedy decoding (default).
    :param beam_alpha: beam search alpha for length penalty,
        disabled if set to -1 (default).
    :param batch_type: validation batch type (sentence or token)

    :return:
        - current_valid_score: current validation score [eval_metric],
          -1 if there are no references,
        - valid_loss: validation loss, -1 if no loss was computed,
        - valid_ppl:, validation perplexity, -1 if no loss was computed,
        - valid_sources: validation sources,
        - valid_sources_raw: raw validation sources (before post-processing),
        - valid_references: validation references,
        - valid_hypotheses: validation_hypotheses,
        - decoded_valid: raw validation hypotheses (before post-processing),
        - valid_attention_scores: attention scores for validation hypotheses
    """
    valid_iter = make_data_iter(dataset=data,
                                batch_size=batch_size,
                                batch_type=batch_type,
                                shuffle=False,
                                train=False)
    valid_sources_raw = [s for s in data.src]
    pad_index = model.src_vocab.stoi[PAD_TOKEN]
    # disable dropout
    model.eval()
    # don't track gradients during validation
    with torch.no_grad():
        all_outputs = []
        valid_attention_scores = []
        total_loss = 0
        total_ntokens = 0
        total_nseqs = 0
        for valid_batch in iter(valid_iter):
            # run as during training to get validation loss (e.g. xent)

            batch = Batch(valid_batch, pad_index, use_cuda=use_cuda)
            # sort batch now by src length and keep track of order
            sort_reverse_index = batch.sort_by_src_lengths()

            # run as during training with teacher forcing
            if loss_function is not None and batch.trg is not None:
                batch_loss = model.get_loss_for_batch(
                    batch, loss_function=loss_function)
                total_loss += batch_loss
                total_ntokens += batch.ntokens
                total_nseqs += batch.nseqs

            # run as during inference to produce translations
            output, attention_scores = model.run_batch(
                batch=batch,
                beam_size=beam_size,
                beam_alpha=beam_alpha,
                max_output_length=max_output_length)

            # sort outputs back to original order
            all_outputs.extend(output[sort_reverse_index])
            valid_attention_scores.extend(
                attention_scores[sort_reverse_index]
                if attention_scores is not None else [])

        assert len(all_outputs) == len(data)

        if loss_function is not None and total_ntokens > 0:
            # total validation loss
            valid_loss = total_loss
            # exponent of token-level negative log prob
            valid_ppl = torch.exp(total_loss / total_ntokens)
        else:
            valid_loss = -1
            valid_ppl = -1

        # decode back to symbols
        decoded_valid = model.trg_vocab.arrays_to_sentences(arrays=all_outputs,
                                                            cut_at_eos=True)

        # evaluate with metric on full dataset
        join_char = " " if level in ["word", "bpe"] else ""
        valid_sources = [join_char.join(s) for s in data.src]
        valid_references = [join_char.join(t) for t in data.trg]
        valid_hypotheses = [join_char.join(t) for t in decoded_valid]

        # post-process
        if level == "bpe":
            valid_sources = [bpe_postprocess(s) for s in valid_sources]
            valid_references = [bpe_postprocess(v) for v in valid_references]
            valid_hypotheses = [bpe_postprocess(v) for v in valid_hypotheses]

        # if references are given, evaluate against them
        if valid_references:
            assert len(valid_hypotheses) == len(valid_references)

            current_valid_score = 0
            # eval_metric is Optional: normalise it once and treat None like
            # an unrecognised metric (score stays 0) instead of failing on
            # None.lower().
            metric = eval_metric.lower() if eval_metric is not None else None
            if metric == 'bleu':
                # this version does not use any tokenization
                current_valid_score = bleu(valid_hypotheses, valid_references)
            elif metric == 'chrf':
                current_valid_score = chrf(valid_hypotheses, valid_references)
            elif metric == 'token_accuracy':
                current_valid_score = token_accuracy(valid_hypotheses,
                                                     valid_references,
                                                     level=level)
            elif metric == 'sequence_accuracy':
                current_valid_score = sequence_accuracy(
                    valid_hypotheses, valid_references)
        else:
            current_valid_score = -1

    return current_valid_score, valid_loss, valid_ppl, valid_sources, \
        valid_sources_raw, valid_references, valid_hypotheses, \
        decoded_valid, valid_attention_scores