Ejemplo n.º 1
0
    def __init__(self, args):
        """Set up all training state: config, logger, data, model, optimizer."""
        super(Trainer, self).__init__()
        config = getattr(configurations, args.proto)()
        self.config = config
        self.num_preload = args.num_preload
        self.logger = ut.get_logger(config['log_file'])

        use_cuda = torch.cuda.is_available()
        self.device = torch.device('cuda:0' if use_cuda else 'cpu')

        # hoist frequently used hyper-parameters out of the config dict
        for key in ('normalize_loss', 'patience', 'lr', 'lr_decay',
                    'max_epochs', 'warmup_steps'):
            setattr(self, key, config[key])

        self.train_smooth_perps = []
        self.train_true_perps = []

        self.data_manager = DataManager(config)
        self.validator = Validator(config, self.data_manager)

        self.val_per_epoch = config['val_per_epoch']
        self.validate_freq = int(config['validate_freq'])
        unit = 'epochs' if self.val_per_epoch else 'batches'
        self.logger.info('Evaluate every {} {}'.format(self.validate_freq, unit))

        # Running statistics for periodic logging.
        self.log_freq = 100  # log train stat every this-many batches
        self.log_train_loss = 0.  # total train loss every log_freq batches
        self.log_nll_loss = 0.
        self.log_train_weights = 0.
        self.num_batches_done = 0    # batches done over the whole training run
        self.epoch_batches_done = 0  # batches done within the current epoch
        self.epoch_loss = 0.         # total train loss for whole epoch
        self.epoch_nll_loss = 0.     # total nll loss for whole epoch
        self.epoch_weights = 0.      # total train weights (# target words) for whole epoch
        self.epoch_time = 0.         # total exec time for whole epoch

        # build the model
        self.model = Model(config).to(self.device)

        param_count = sum(numpy.prod(p.size()) for p in self.model.parameters())
        self.logger.info('Model has {:,} parameters'.format(param_count))

        # Adam optimizer
        self.optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=self.lr,
            betas=(config['beta1'], config['beta2']),
            eps=config['epsilon'])
Ejemplo n.º 2
0
    def __init__(self, args):
        """Load config and paths, validate them, then translate immediately."""
        super(Translator, self).__init__()
        self.config = getattr(configurations, args.proto)()
        self.logger = ut.get_logger(self.config['log_file'])

        self.input_file = args.input_file
        self.model_file = args.model_file

        # short-circuit keeps os.path.exists from seeing a None path
        missing = (self.input_file is None
                   or self.model_file is None
                   or not os.path.exists(self.input_file)
                   or not os.path.exists(self.model_file))
        if missing:
            raise ValueError('Input file or model file does not exist')

        self.data_manager = DataManager(self.config)
        self.translate()
    def __init__(self, args):
        """Load config, vocabularies and options, then translate immediately."""
        super(Translator, self).__init__()
        self.config = getattr(configurations, args.proto)()
        self.reverse = self.config['reverse']
        self.logger = ut.get_logger(self.config['log_file'])

        self.input_file = args.input_file
        self.model_file = args.model_file
        self.plot_align = args.plot_align
        self.unk_repl = args.unk_repl

        # the checkpoint is identified by its .meta file;
        # short-circuit keeps os.path.exists from seeing a None path
        missing = (self.input_file is None
                   or self.model_file is None
                   or not os.path.exists(self.input_file)
                   or not os.path.exists(self.model_file + '.meta'))
        if missing:
            raise ValueError('Input file or model file does not exist')

        self.data_manager = DataManager(self.config)
        _, self.src_ivocab = self.data_manager.init_vocab(self.data_manager.src_lang)
        _, self.trg_ivocab = self.data_manager.init_vocab(self.data_manager.trg_lang)
        self.translate()
Ejemplo n.º 4
0
    def __init__(self, config, load_from=None):
        """Build the model described by ``config``.

        When ``load_from`` is given, parameters are restored from that
        checkpoint instead of being freshly initialized.
        """
        super(Model, self).__init__()
        self.config = config
        self.struct = config['struct']
        self.decoder_mask = None
        # only build a vocabulary from scratch when not restoring a checkpoint
        self.data_manager = DataManager(config, init_vocab=(not load_from))

        if not load_from:
            self.init_embeddings()
            self.init_model()
            self.add_struct_params()
        else:
            state = torch.load(load_from, map_location=ut.get_device())
            self.load_state_dict(state, do_init=True)

        # dict where keys are data_ptrs to dicts of parameter options
        # see https://pytorch.org/docs/stable/optim.html#per-parameter-options
        self.parameter_attrs = {}
Ejemplo n.º 5
0
    def __init__(self, args):
        """Set up TF training state: config, logger, data, vocab, counters."""
        super(Trainer, self).__init__()
        config = getattr(configurations, args.proto)()
        config['fixed_var_list'] = args.fixed_var_list
        self.config = config
        self.num_preload = args.num_preload
        self.logger = ut.get_logger(config['log_file'])

        self.lr = config['lr']
        self.max_epochs = config['max_epochs']
        self.save_freq = config['save_freq']
        self.cpkt_path = None
        self.validate_freq = None
        self.train_perps = []

        # filled in lazily once the TF graph/session exists
        self.saver = None
        self.train_m = None
        self.dev_m = None

        self.data_manager = DataManager(config)
        self.validator = Validator(config, self.data_manager)
        self.validate_freq = ut.get_validation_frequency(
            self.data_manager.length_files[ac.TRAINING],
            config['validate_freq'], config['batch_size'])
        self.logger.info('Evaluate every {} batches'.format(self.validate_freq))

        _, self.src_ivocab = self.data_manager.init_vocab(self.data_manager.src_lang)
        _, self.trg_ivocab = self.data_manager.init_vocab(self.data_manager.trg_lang)

        # Running statistics for periodic logging.
        self.log_freq = 100  # log train stat every this-many batches
        self.log_train_loss = 0.  # total train loss every log_freq batches
        self.log_train_weights = 0.
        self.num_batches_done = 0    # batches done over the whole training run
        self.epoch_batches_done = 0  # batches done within the current epoch
        self.epoch_loss = 0.         # total train loss for whole epoch
        self.epoch_weights = 0.      # total train weights (# target words) for whole epoch
        self.epoch_time = 0.         # total exec time for whole epoch
class Translator(object):
    """Restore a trained TF model and translate an input file.

    Writes the 1-best translations to ``<input>.best_trans`` and the whole
    beam to ``<input>.beam_trans``; optionally plots attention heat maps.
    """
    def __init__(self, args):
        super(Translator, self).__init__()
        self.config = getattr(configurations, args.proto)()
        self.reverse = self.config['reverse']
        self.logger = ut.get_logger(self.config['log_file'])

        self.input_file = args.input_file
        self.model_file = args.model_file
        self.plot_align = args.plot_align
        self.unk_repl = args.unk_repl

        # the checkpoint is identified by its .meta file
        if self.input_file is None or self.model_file is None or not os.path.exists(self.input_file) or not os.path.exists(self.model_file + '.meta'):
            raise ValueError('Input file or model file does not exist')

        self.data_manager = DataManager(self.config)
        _, self.src_ivocab = self.data_manager.init_vocab(self.data_manager.src_lang)
        _, self.trg_ivocab = self.data_manager.init_vocab(self.data_manager.trg_lang)
        self.translate()

    def ids_to_trans(self, trans_ids, trans_alignments, no_unk_src_toks):
        """Convert target word ids to a translation string.

        Stops at EOS. When ``unk_repl`` is set, each UNK is replaced with the
        source token that received the highest attention weight at that step.
        Returns ``(text, word_ids)``.
        """
        words = []
        word_ids = []
        # Could have done better but this is clearer to me

        if not self.unk_repl:
            for idx, word_idx in enumerate(trans_ids):
                if word_idx == ac.EOS_ID:
                    break

                words.append(self.trg_ivocab[word_idx])
                word_ids.append(word_idx)
        else:
            for idx, word_idx in enumerate(trans_ids):
                if word_idx == ac.EOS_ID:
                    break

                if word_idx == ac.UNK_ID:
                    # Replace UNK with highest-attention source word
                    alignment = trans_alignments[idx]
                    highest_att_src_tok_pos = numpy.argmax(alignment)
                    words.append(no_unk_src_toks[highest_att_src_tok_pos])
                else:
                    words.append(self.trg_ivocab[word_idx])
                word_ids.append(word_idx)

        return u' '.join(words), word_ids

    def get_model(self, mode):
        """Build a Model inside the shared variable scope (reused unless training)."""
        reuse = mode != ac.TRAINING
        d = self.config['init_range']
        initializer = tf.random_uniform_initializer(-d, d)
        with tf.variable_scope(self.config['model_name'], reuse=reuse, initializer=initializer):
            return Model(self.config, mode)

    def get_trans(self, probs, scores, symbols, parents, alignments, no_unk_src_toks):
        """Back-track beam-search output matrices into translations.

        Returns (best translation, its ids, the whole beam as text, and the
        best hypothesis' alignments). Back-tracking walks from the last beam
        column to the first, so collected sequences are reversed at the end.
        """
        # rows sorted by final score, best hypothesis first
        sorted_rows = numpy.argsort(scores[:, -1])[::-1]
        best_trans_alignments = []
        best_trans = None
        best_tran_ids = None
        beam_trans = []
        for i, r in enumerate(sorted_rows):
            row_idx = r
            col_idx = scores.shape[1] - 1

            trans_ids = []
            trans_alignments = []
            while True:
                if col_idx < 0:
                    break

                trans_ids.append(symbols[row_idx, col_idx])
                align = alignments[row_idx, col_idx, :]
                trans_alignments.append(align)

                if i == 0:
                    best_trans_alignments.append(align if not self.reverse else align[::-1])

                # follow the beam back-pointer to the previous step
                row_idx = parents[row_idx, col_idx]
                col_idx -= 1

            trans_ids = trans_ids[::-1]
            trans_alignments = trans_alignments[::-1]
            trans_out, trans_out_ids = self.ids_to_trans(trans_ids, trans_alignments, no_unk_src_toks)
            beam_trans.append(u'{} {:.2f} {:.2f}'.format(trans_out, scores[r, -1], probs[r, -1]))
            if i == 0:  # highest prob trans
                best_trans = trans_out
                best_tran_ids = trans_out_ids

        return best_trans, best_tran_ids, u'\n'.join(beam_trans), best_trans_alignments[::-1]

    def plot_head_map(self, mma, target_labels, target_ids, source_labels, source_ids, filename):
        """https://github.com/EdinburghNLP/nematus/blob/master/utils/plot_heatmap.py
        Change the font in family param below. If the system font is not used, delete matplotlib 
        font cache https://github.com/matplotlib/matplotlib/issues/3590
        """
        fig, ax = plt.subplots()
        heatmap = ax.pcolor(mma, cmap=plt.cm.Blues)

        # put the major ticks at the middle of each cell
        ax.set_xticks(numpy.arange(mma.shape[1]) + 0.5, minor=False)
        ax.set_yticks(numpy.arange(mma.shape[0]) + 0.5, minor=False)

        # without this I get some extra columns rows
        # http://stackoverflow.com/questions/31601351/why-does-this-matplotlib-heatmap-have-an-extra-blank-column
        ax.set_xlim(0, int(mma.shape[1]))
        ax.set_ylim(0, int(mma.shape[0]))

        # want a more natural, table-like display
        ax.invert_yaxis()
        ax.xaxis.tick_top()

        # source words -> column labels; UNK tokens are colored blue
        ax.set_xticklabels(source_labels, minor=False, family='Source Code Pro')
        for xtick, idx in zip(ax.get_xticklabels(), source_ids):
            if idx == ac.UNK_ID:
                xtick.set_color('b')
        # target words -> row labels; UNK tokens are colored blue
        ax.set_yticklabels(target_labels, minor=False, family='Source Code Pro')
        for ytick, idx in zip(ax.get_yticklabels(), target_ids):
            if idx == ac.UNK_ID:
                ytick.set_color('b')

        plt.xticks(rotation=45)

        plt.tight_layout()
        plt.savefig(filename)
        plt.close('all')

    def translate(self):
        """Restore the checkpoint and translate the input file sentence by sentence."""
        with tf.Graph().as_default():
            # building the training model first creates the variables the
            # validation model then reuses
            train_model = self.get_model(ac.TRAINING)
            model = self.get_model(ac.VALIDATING)

            with tf.Session() as sess:
                self.logger.info('Restore model from {}'.format(self.model_file))
                saver = tf.train.Saver(var_list=tf.trainable_variables())
                saver.restore(sess, self.model_file)

                best_trans_file = self.input_file + '.best_trans'
                beam_trans_file = self.input_file + '.beam_trans'
                # BUG FIX: the third positional argument of builtin open() is
                # `buffering`, so open(f, 'w', 'utf-8') raised TypeError; use
                # the encoding keyword instead. Mode 'w' already truncates, so
                # the former extra open(...).close() calls were dropped too.
                ftrans = open(best_trans_file, 'w', encoding='utf-8')
                btrans = open(beam_trans_file, 'w', encoding='utf-8')

                self.logger.info('Start translating {}'.format(self.input_file))
                start = time.time()
                count = 0
                for (src_input, src_seq_len, no_unk_src_toks) in self.data_manager.get_trans_input(self.input_file):
                    feed = {
                        model.src_inputs: src_input,
                        model.src_seq_lengths: src_seq_len
                    }
                    probs, scores, symbols, parents, alignments = sess.run([model.probs, model.scores, model.symbols, model.parents, model.alignments], feed_dict=feed)
                    # beam-search outputs are time-major; make them batch-major
                    alignments = numpy.transpose(alignments, axes=(1, 0, 2))

                    probs = numpy.transpose(numpy.array(probs))
                    scores = numpy.transpose(numpy.array(scores))
                    symbols = numpy.transpose(numpy.array(symbols))
                    parents = numpy.transpose(numpy.array(parents))

                    best_trans, best_trans_ids, beam_trans, best_trans_alignments = self.get_trans(probs, scores, symbols, parents, alignments, no_unk_src_toks)
                    ftrans.write(best_trans + '\n')
                    btrans.write(beam_trans + '\n\n')

                    if self.plot_align:
                        src_input = numpy.reshape(src_input, [-1])
                        if self.reverse:
                            src_input = src_input[::-1]
                            no_unk_src_toks = no_unk_src_toks[::-1]
                        trans_toks = best_trans.split()
                        best_trans_alignments = numpy.array(best_trans_alignments)[:len(trans_toks)]
                        filename = '{}_{}.png'.format(self.input_file, count)

                        self.plot_head_map(best_trans_alignments, trans_toks, best_trans_ids, no_unk_src_toks, src_input, filename)

                    count += 1
                    if count % 100 == 0:
                        self.logger.info('  Translating line {}, average {} seconds/sent'.format(count, (time.time() - start) / count))

                ftrans.close()
                btrans.close()

                self.logger.info('Done translating {}, it takes {} minutes'.format(self.input_file, float(time.time() - start) / 60.0))
Ejemplo n.º 7
0
class Trainer(object):
    """Train a Model with Adam, periodically validating and checkpointing."""
    def __init__(self, args):
        super(Trainer, self).__init__()
        self.config = getattr(configurations, args.proto)()
        self.num_preload = args.num_preload
        self.logger = ut.get_logger(self.config['log_file'])

        self.device = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')

        self.normalize_loss = self.config['normalize_loss']
        self.patience = self.config['patience']
        self.lr = self.config['lr']
        self.lr_decay = self.config['lr_decay']
        self.max_epochs = self.config['max_epochs']
        self.warmup_steps = self.config['warmup_steps']

        self.train_smooth_perps = []
        self.train_true_perps = []

        self.data_manager = DataManager(self.config)
        self.validator = Validator(self.config, self.data_manager)

        self.val_per_epoch = self.config['val_per_epoch']
        self.validate_freq = int(self.config['validate_freq'])
        self.logger.info('Evaluate every {} {}'.format(
            self.validate_freq, 'epochs' if self.val_per_epoch else 'batches'))

        # For logging
        self.log_freq = 100  # log train stat every this-many batches
        self.log_train_loss = 0.  # total train loss every log_freq batches
        self.log_nll_loss = 0.  # total nll loss every log_freq batches
        self.log_train_weights = 0.
        self.num_batches_done = 0  # number of batches done for the whole training
        self.epoch_batches_done = 0  # number of batches done for this epoch
        self.epoch_loss = 0.  # total train loss for whole epoch
        self.epoch_nll_loss = 0.  # total nll (unsmoothed) loss for whole epoch
        self.epoch_weights = 0.  # total train weights (# target words) for whole epoch
        self.epoch_time = 0.  # total exec time for whole epoch

        # get model
        self.model = Model(self.config).to(self.device)

        param_count = sum(
            [numpy.prod(p.size()) for p in self.model.parameters()])
        self.logger.info('Model has {:,} parameters'.format(param_count))

        # get optimizer
        beta1 = self.config['beta1']
        beta2 = self.config['beta2']
        epsilon = self.config['epsilon']
        self.optimizer = torch.optim.Adam(self.model.parameters(),
                                          lr=self.lr,
                                          betas=(beta1, beta2),
                                          eps=epsilon)

    def report_epoch(self, e):
        """Log stats for finished epoch ``e`` and reset the epoch accumulators."""
        self.logger.info('Finish epoch {}'.format(e))
        self.logger.info('    It takes {}'.format(
            ut.format_seconds(self.epoch_time)))
        self.logger.info('    Average # words/second    {}'.format(
            self.epoch_weights / self.epoch_time))
        self.logger.info('    Average seconds/batch    {}'.format(
            self.epoch_time / self.epoch_batches_done))

        train_smooth_perp = self.epoch_loss / self.epoch_weights
        train_true_perp = self.epoch_nll_loss / self.epoch_weights

        self.epoch_batches_done = 0
        self.epoch_time = 0.
        self.epoch_nll_loss = 0.
        self.epoch_loss = 0.
        self.epoch_weights = 0.

        # cap at exp(300) to avoid overflow when the run diverges
        train_smooth_perp = numpy.exp(
            train_smooth_perp) if train_smooth_perp < 300 else float('inf')
        self.train_smooth_perps.append(train_smooth_perp)
        train_true_perp = numpy.exp(
            train_true_perp) if train_true_perp < 300 else float('inf')
        self.train_true_perps.append(train_true_perp)

        self.logger.info(
            '    smoothed train perplexity: {}'.format(train_smooth_perp))
        self.logger.info(
            '    true train perplexity: {}'.format(train_true_perp))

    def run_log(self, b, e, batch_data):
        """Run one optimization step on ``batch_data`` and update statistics.

        ``b`` is the batch index within epoch ``e`` (both used for logging only).
        """
        start = time.time()
        src_toks, trg_toks, targets = batch_data
        src_toks_cuda = src_toks.to(self.device)
        trg_toks_cuda = trg_toks.to(self.device)
        targets_cuda = targets.to(self.device)

        # zero grad
        self.optimizer.zero_grad()

        # get loss
        ret = self.model(src_toks_cuda, trg_toks_cuda, targets_cuda)
        loss = ret['loss']
        nll_loss = ret['nll_loss']
        if self.normalize_loss == ac.LOSS_TOK:
            # per-token normalization: divide by # non-pad target tokens
            opt_loss = loss / (targets_cuda != ac.PAD_ID).type(
                loss.type()).sum()
        elif self.normalize_loss == ac.LOSS_BATCH:
            # BUG FIX: size()[0] is a plain int with no .type() method, so
            # the old `targets_cuda.size()[0].type(loss.type())` raised
            # AttributeError; dividing a tensor by an int is well-defined.
            opt_loss = loss / targets_cuda.size(0)
        else:
            opt_loss = loss

        opt_loss.backward()
        # clip gradient
        global_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                                     self.config['grad_clip'])

        # update
        self.adjust_lr()
        self.optimizer.step()

        # update training stats (targets is still the CPU tensor here)
        num_words = (targets != ac.PAD_ID).detach().numpy().sum()

        loss = loss.cpu().detach().numpy()
        nll_loss = nll_loss.cpu().detach().numpy()
        self.num_batches_done += 1
        self.log_train_loss += loss
        self.log_nll_loss += nll_loss
        self.log_train_weights += num_words

        self.epoch_batches_done += 1
        self.epoch_loss += loss
        self.epoch_nll_loss += nll_loss
        self.epoch_weights += num_words
        self.epoch_time += time.time() - start

        if self.num_batches_done % self.log_freq == 0:
            acc_speed_word = self.epoch_weights / self.epoch_time
            acc_speed_time = self.epoch_time / self.epoch_batches_done

            avg_smooth_perp = self.log_train_loss / self.log_train_weights
            avg_smooth_perp = numpy.exp(
                avg_smooth_perp) if avg_smooth_perp < 300 else float('inf')
            avg_true_perp = self.log_nll_loss / self.log_train_weights
            avg_true_perp = numpy.exp(
                avg_true_perp) if avg_true_perp < 300 else float('inf')

            self.log_train_loss = 0.
            self.log_nll_loss = 0.
            self.log_train_weights = 0.

            self.logger.info('Batch {}, epoch {}/{}:'.format(
                b, e + 1, self.max_epochs))
            self.logger.info(
                '   avg smooth perp:   {0:.2f}'.format(avg_smooth_perp))
            self.logger.info(
                '   avg true perp:   {0:.2f}'.format(avg_true_perp))
            self.logger.info('   acc trg words/s: {}'.format(
                int(acc_speed_word)))
            self.logger.info(
                '   acc sec/batch:   {0:.2f}'.format(acc_speed_time))
            self.logger.info('   global norm:     {0:.2f}'.format(global_norm))

    def adjust_lr(self):
        """Transformer-style warmup: lr rises linearly for warmup_steps,
        then decays as step**-0.5 (floored at min_lr)."""
        if self.config['warmup_style'] == ac.ORG_WARMUP:
            step = self.num_batches_done + 1.0
            if step < self.config['warmup_steps']:
                lr = self.config['embed_dim']**(
                    -0.5) * step * self.config['warmup_steps']**(-1.5)
            else:
                lr = max(self.config['embed_dim']**(-0.5) * step**(-0.5),
                         self.config['min_lr'])

            for p in self.optimizer.param_groups:
                p['lr'] = lr

    def train(self):
        """Main training loop over epochs; validates, saves, then evaluates on test."""
        self.model.train()
        train_ids_file = self.data_manager.data_files['ids']
        for e in range(self.max_epochs):
            b = 0
            for batch_data in self.data_manager.get_batch(
                    ids_file=train_ids_file,
                    shuffle=True,
                    num_preload=self.num_preload):
                b += 1
                self.run_log(b, e, batch_data)
                if not self.val_per_epoch:
                    self.maybe_validate()

            self.report_epoch(e + 1)
            if self.val_per_epoch and (e + 1) % self.validate_freq == 0:
                self.maybe_validate(just_validate=True)

        # validate 1 last time (use the cached flag for consistency)
        if not self.val_per_epoch:
            self.maybe_validate(just_validate=True)

        self.logger.info('It is finally done, mate!')
        self.logger.info('Train smoothed perps:')
        self.logger.info(', '.join(map(str, self.train_smooth_perps)))
        self.logger.info('Train true perps:')
        self.logger.info(', '.join(map(str, self.train_true_perps)))
        numpy.save(join(self.config['save_to'], 'train_smooth_perps.npy'),
                   self.train_smooth_perps)
        numpy.save(join(self.config['save_to'], 'train_true_perps.npy'),
                   self.train_true_perps)

        self.logger.info('Save final checkpoint')
        self.save_checkpoint()

        # Evaluate on test
        for checkpoint in self.data_manager.checkpoints:
            self.logger.info('Translate for {}'.format(checkpoint))
            dev_file = self.data_manager.dev_files[checkpoint][
                self.data_manager.src_lang]
            test_file = self.data_manager.test_files[checkpoint][
                self.data_manager.src_lang]
            if exists(test_file):
                self.logger.info('  Evaluate on test')
                self.restart_to_best_checkpoint(checkpoint)
                self.validator.translate(self.model, test_file)
                self.logger.info('  Also translate dev')
                self.validator.translate(self.model, dev_file)

    def save_checkpoint(self):
        """Write the model's state_dict to <save_to>/<model_name>.pth."""
        cpkt_path = join(self.config['save_to'],
                         '{}.pth'.format(self.config['model_name']))
        torch.save(self.model.state_dict(), cpkt_path)

    def restart_to_best_checkpoint(self, checkpoint):
        """Reload the model weights from the best (lowest-perplexity) checkpoint."""
        best_perp = numpy.min(self.validator.best_perps[checkpoint])
        best_cpkt_path = self.validator.get_cpkt_path(checkpoint, best_perp)

        self.logger.info('Restore best cpkt from {}'.format(best_cpkt_path))
        self.model.load_state_dict(torch.load(best_cpkt_path))

    def maybe_validate(self, just_validate=False):
        """Checkpoint + validate every ``validate_freq`` batches (or when forced).

        With NO_WARMUP annealing, decays the lr when dev perplexity has not
        improved over the last ``patience`` validations.
        """
        if self.num_batches_done % self.validate_freq == 0 or just_validate:
            self.save_checkpoint()
            self.validator.validate_and_save(self.model)

            # if doing annealing
            if self.config[
                    'warmup_style'] == ac.NO_WARMUP and self.lr_decay > 0:
                # dev perp got worse than all of the previous `patience` ones
                cond = len(
                    self.validator.perp_curve
                ) > self.patience and self.validator.perp_curve[-1] > max(
                    self.validator.perp_curve[-1 - self.patience:-1])
                if cond:
                    metric = 'perp'
                    scores = self.validator.perp_curve[-1 - self.patience:]
                    scores = map(str, list(scores))
                    scores = ', '.join(scores)

                    self.logger.info('Past {} are {}'.format(metric, scores))
                    # when don't use warmup, decay lr if dev not improve
                    if self.lr * self.lr_decay >= self.config['min_lr']:
                        self.logger.info(
                            'Anneal the learning rate from {} to {}'.format(
                                self.lr, self.lr * self.lr_decay))
                        self.lr = self.lr * self.lr_decay
                        for p in self.optimizer.param_groups:
                            p['lr'] = self.lr
Ejemplo n.º 8
0
class Trainer(object):
    """Train a TF-1.x model: build graph, loop over epochs, validate, checkpoint."""
    def __init__(self, args):
        super(Trainer, self).__init__()
        self.config = getattr(configurations, args.proto)()
        self.config['fixed_var_list'] = args.fixed_var_list
        self.num_preload = args.num_preload
        self.logger = ut.get_logger(self.config['log_file'])

        self.lr = self.config['lr']
        self.max_epochs = self.config['max_epochs']
        self.save_freq = self.config['save_freq']
        self.cpkt_path = None
        self.validate_freq = None
        self.train_perps = []

        # filled in once the TF graph/session exists
        self.saver = None
        self.train_m = None
        self.dev_m = None

        self.data_manager = DataManager(self.config)
        self.validator = Validator(self.config, self.data_manager)
        self.validate_freq = ut.get_validation_frequency(
            self.data_manager.length_files[ac.TRAINING],
            self.config['validate_freq'], self.config['batch_size'])
        self.logger.info('Evaluate every {} batches'.format(
            self.validate_freq))

        _, self.src_ivocab = self.data_manager.init_vocab(
            self.data_manager.src_lang)
        _, self.trg_ivocab = self.data_manager.init_vocab(
            self.data_manager.trg_lang)

        # For logging
        self.log_freq = 100  # log train stat every this-many batches
        self.log_train_loss = 0.  # total train loss every log_freq batches
        self.log_train_weights = 0.
        self.num_batches_done = 0  # number of batches done for the whole training
        self.epoch_batches_done = 0  # number of batches done for this epoch
        self.epoch_loss = 0.  # total train loss for whole epoch
        self.epoch_weights = 0.  # total train weights (# target words) for whole epoch
        self.epoch_time = 0.  # total exec time for whole epoch

    def get_model(self, mode):
        """Build a Model inside the shared variable scope (reused unless training)."""
        reuse = mode != ac.TRAINING
        d = self.config['init_range']
        initializer = tf.random_uniform_initializer(-d, d, seed=ac.SEED)
        with tf.variable_scope(self.config['model_name'],
                               reuse=reuse,
                               initializer=initializer):
            return Model(self.config, mode)

    def reload_and_get_cpkt_saver(self, config, sess):
        """Create the checkpoint Saver and restore a previous run if configured."""
        cpkt_path = join(config['save_to'],
                         '{}.cpkt'.format(config['model_name']))
        saver = tf.train.Saver(var_list=tf.trainable_variables(),
                               max_to_keep=config['n_best'] + 1)

        # TF some version >= 0.11, no longer .cpkt so check for meta file instead
        if exists(cpkt_path + '.meta') and config['reload']:
            self.logger.info('Reload model from {}'.format(cpkt_path))
            saver.restore(sess, cpkt_path)

        self.cpkt_path = cpkt_path
        self.saver = saver

    def train(self):
        """Main training loop; after training, evaluate the best checkpoint on test."""
        tf.reset_default_graph()
        with tf.Graph().as_default():
            tf.set_random_seed(ac.SEED)
            self.train_m = self.get_model(ac.TRAINING)
            self.dev_m = self.get_model(ac.VALIDATING)

            with tf.Session() as sess:
                sess.run(tf.global_variables_initializer())
                self.reload_and_get_cpkt_saver(self.config, sess)
                self.logger.info('Set learning rate to {}'.format(self.lr))
                sess.run(tf.assign(self.train_m.lr, self.lr))

                # BUG FIX: xrange is Python-2 only (NameError on Python 3)
                for e in range(self.max_epochs):
                    for b, batch_data in self.data_manager.get_batch(
                            mode=ac.TRAINING, num_batches=self.num_preload):
                        self.run_log_save(sess, b, e, batch_data)
                        self.maybe_validate(sess)

                    self.report_epoch(e)

                self.logger.info('It is finally done, mate!')
                self.logger.info('Train perplexities:')
                self.logger.info(', '.join(map(str, self.train_perps)))
                numpy.save(join(self.config['save_to'], 'train_perps.npy'),
                           self.train_perps)

                self.logger.info('Save final checkpoint')
                self.saver.save(sess, self.cpkt_path)

                # Evaluate on test
                if exists(self.data_manager.ids_files[ac.TESTING]):
                    self.logger.info('Evaluate on test')
                    best_bleu = numpy.max(self.validator.best_bleus)
                    best_cpkt_path = self.validator.get_cpkt_path(best_bleu)
                    self.logger.info(
                        'Restore best cpkt from {}'.format(best_cpkt_path))
                    self.saver.restore(sess, best_cpkt_path)
                    self.validator.evaluate(sess, self.dev_m, mode=ac.TESTING)

    def report_epoch(self, e):
        """Log stats for finished epoch ``e`` and reset the epoch accumulators."""
        self.logger.info('Finish epoch {}'.format(e + 1))
        self.logger.info('    It takes {}'.format(
            ut.format_seconds(self.epoch_time)))
        self.logger.info('    Average # words/second    {}'.format(
            self.epoch_weights / self.epoch_time))
        self.logger.info('    Average seconds/batch    {}'.format(
            self.epoch_time / self.epoch_batches_done))

        train_perp = self.epoch_loss / self.epoch_weights

        self.epoch_batches_done = 0
        self.epoch_time = 0.
        self.epoch_loss = 0.
        self.epoch_weights = 0.

        # cap at exp(300) to avoid overflow when the run diverges
        train_perp = numpy.exp(train_perp) if train_perp < 300 else float(
            'inf')
        self.train_perps.append(train_perp)
        self.logger.info('    train perplexity: {}'.format(train_perp))

    def sample_input(self, batch_data):
        """Log the first example of the batch in human-readable form."""
        # TODO: more on sampling
        src_sent = batch_data[0][0]
        src_length = batch_data[1][0]
        trg_sent = batch_data[2][0]
        target = batch_data[3][0]
        weight = batch_data[4][0]

        src_sent = map(self.src_ivocab.get, src_sent)
        src_sent = u' '.join(src_sent)
        trg_sent = map(self.trg_ivocab.get, trg_sent)
        trg_sent = u' '.join(trg_sent)
        target = map(self.trg_ivocab.get, target)
        target = u' '.join(target)
        weight = ' '.join(map(str, weight))

        self.logger.info('Sample input data:')
        self.logger.info(u'Src: {}'.format(src_sent))
        self.logger.info(u'Src len: {}'.format(src_length))
        self.logger.info(u'Trg: {}'.format(trg_sent))
        self.logger.info(u'Tar: {}'.format(target))
        self.logger.info(u'W: {}'.format(weight))

    def run_log_save(self, sess, b, e, batch_data):
        """Run one training step, accumulate stats, periodically log and save."""
        start = time.time()
        src_inputs, src_seq_lengths, trg_inputs, trg_targets, target_weights = batch_data
        feed = {
            self.train_m.src_inputs: src_inputs,
            self.train_m.src_seq_lengths: src_seq_lengths,
            self.train_m.trg_inputs: trg_inputs,
            self.train_m.trg_targets: trg_targets,
            self.train_m.target_weights: target_weights
        }
        loss, _ = sess.run([self.train_m.loss, self.train_m.train_op], feed)

        num_trg_words = numpy.sum(target_weights)
        self.num_batches_done += 1
        self.epoch_batches_done += 1
        self.epoch_loss += loss
        self.epoch_weights += num_trg_words
        self.log_train_loss += loss
        self.log_train_weights += num_trg_words
        self.epoch_time += time.time() - start

        if self.num_batches_done % (10 * self.log_freq) == 0:
            self.sample_input(batch_data)

        if self.num_batches_done % self.log_freq == 0:
            acc_speed_word = self.epoch_weights / self.epoch_time
            acc_speed_time = self.epoch_time / self.epoch_batches_done

            avg_word_perp = self.log_train_loss / self.log_train_weights
            avg_word_perp = numpy.exp(
                avg_word_perp) if avg_word_perp < 300 else float('inf')
            self.log_train_loss = 0.
            self.log_train_weights = 0.

            self.logger.info('Batch {}, epoch {}/{}:'.format(
                b, e + 1, self.max_epochs))
            self.logger.info(
                '   avg word perp:   {0:.2f}'.format(avg_word_perp))
            self.logger.info('   acc trg words/s: {}'.format(
                int(acc_speed_word)))
            self.logger.info(
                '   acc sec/batch:   {0:.2f}'.format(acc_speed_time))

        if self.num_batches_done % self.save_freq == 0:
            start = time.time()
            self.saver.save(sess, self.cpkt_path)
            self.logger.info('Save model to {}, takes {}'.format(
                self.cpkt_path, ut.format_seconds(time.time() - start)))

    def maybe_validate(self, sess):
        """Validate (and save the best checkpoints) every ``validate_freq`` batches."""
        if self.num_batches_done % self.validate_freq == 0:
            self.validator.validate_and_save(sess, self.dev_m, self.saver)

    def _call_me_maybe(self):
        pass  # NO
Ejemplo n.º 9
0
class Model(nn.Module):
    """Transformer-style encoder/decoder model with optional structured
    source position embeddings supplied by ``config['struct']``.

    If ``load_from`` is given, vocabularies and weights are restored from
    that checkpoint; otherwise embeddings, encoder/decoder stacks and
    struct-specific parameters are freshly initialized.
    """
    def __init__(self, config, load_from=None):
        super(Model, self).__init__()
        self.config = config
        self.struct = self.config['struct']
        # Cached causal mask; grown lazily in get_decoder_mask.
        self.decoder_mask = None
        self.data_manager = DataManager(config, init_vocab=(not load_from))

        if load_from:
            # do_init=True so parameter tensors exist before weights load.
            self.load_state_dict(torch.load(load_from,
                                            map_location=ut.get_device()),
                                 do_init=True)
        else:
            self.init_embeddings()
            self.init_model()
            self.add_struct_params()

        # dict where keys are data_ptrs to dicts of parameter options
        # see https://pytorch.org/docs/stable/optim.html#per-parameter-options
        self.parameter_attrs = {}

    def init_embeddings(self):
        """Create src/trg word embeddings, target positional embeddings and
        embedding scales, honoring tie_mode / fix_norm / learned-pos options."""
        embed_dim = self.config['embed_dim']
        tie_mode = self.config['tie_mode']
        fix_norm = self.config['fix_norm']
        max_src_len = self.config['max_src_length']
        max_trg_len = self.config['max_trg_length']

        device = ut.get_device()

        # get trg positonal embedding
        if not self.config['learned_pos_trg']:
            self.pos_embedding_trg = ut.get_position_embedding(
                embed_dim, max_trg_len)
        else:
            self.pos_embedding_trg = Parameter(
                torch.empty(max_trg_len,
                            embed_dim,
                            dtype=torch.float,
                            device=device))
            nn.init.normal_(self.pos_embedding_trg,
                            mean=0,
                            std=embed_dim**-0.5)

        # get word embeddings
        # TODO: src_vocab_mask is assigned but never used (?)
        #src_vocab_size, trg_vocab_size = ut.get_vocab_sizes(self.config)
        #self.src_vocab_mask, self.trg_vocab_mask = ut.get_vocab_masks(self.config, src_vocab_size, trg_vocab_size)
        #if tie_mode == ac.ALL_TIED:
        #    src_vocab_size = trg_vocab_size = self.trg_vocab_mask.shape[0]
        self.src_vocab_mask = self.data_manager.vocab_masks[
            self.data_manager.src_lang]
        self.trg_vocab_mask = self.data_manager.vocab_masks[
            self.data_manager.trg_lang]
        src_vocab_size = self.src_vocab_mask.shape[0]
        trg_vocab_size = self.trg_vocab_mask.shape[0]

        # Output-projection bias, initialized to zero.
        self.out_bias = Parameter(
            torch.empty(trg_vocab_size, dtype=torch.float, device=device))
        nn.init.constant_(self.out_bias, 0.)

        self.src_embedding = nn.Embedding(src_vocab_size, embed_dim)
        self.trg_embedding = nn.Embedding(trg_vocab_size, embed_dim)
        # Output projection weights are tied to the target embedding.
        self.out_embedding = self.trg_embedding.weight
        if self.config['separate_embed_scales']:
            self.src_embed_scale = Parameter(
                torch.tensor([embed_dim**0.5], device=device))
            self.trg_embed_scale = Parameter(
                torch.tensor([embed_dim**0.5], device=device))
        else:
            self.src_embed_scale = self.trg_embed_scale = torch.tensor(
                [embed_dim**0.5], device=device)

        self.src_pos_embed_scale = torch.tensor([(embed_dim / 2)**0.5],
                                                device=device)
        self.trg_pos_embed_scale = torch.tensor(
            [1.], device=device
        )  # trg pos embedding already returns vector of norm sqrt(embed_dim/2)
        if self.config['learn_pos_scale']:
            self.src_pos_embed_scale = Parameter(self.src_pos_embed_scale)
            self.trg_pos_embed_scale = Parameter(self.trg_pos_embed_scale)

        if tie_mode == ac.ALL_TIED:
            self.src_embedding.weight = self.trg_embedding.weight

        if not fix_norm:
            nn.init.normal_(self.src_embedding.weight,
                            mean=0,
                            std=embed_dim**-0.5)
            nn.init.normal_(self.trg_embedding.weight,
                            mean=0,
                            std=embed_dim**-0.5)
        else:
            d = 0.01  # pure magic
            nn.init.uniform_(self.src_embedding.weight, a=-d, b=d)
            nn.init.uniform_(self.trg_embedding.weight, a=-d, b=d)

    def init_model(self):
        """Build encoder/decoder stacks and xavier-initialize their attention
        and feed-forward weights (layer norms are left at their defaults)."""
        num_enc_layers = self.config['num_enc_layers']
        num_enc_heads = self.config['num_enc_heads']
        num_dec_layers = self.config['num_dec_layers']
        num_dec_heads = self.config['num_dec_heads']

        embed_dim = self.config['embed_dim']
        ff_dim = self.config['ff_dim']
        dropout = self.config['dropout']
        norm_in = self.config['norm_in']

        # get encoder, decoder
        self.encoder = Encoder(num_enc_layers,
                               num_enc_heads,
                               embed_dim,
                               ff_dim,
                               dropout=dropout,
                               norm_in=norm_in)
        self.decoder = Decoder(num_dec_layers,
                               num_dec_heads,
                               embed_dim,
                               ff_dim,
                               dropout=dropout,
                               norm_in=norm_in)

        # leave layer norm alone
        init_func = nn.init.xavier_normal_ if self.config[
            'weight_init_type'] == ac.XAVIER_NORMAL else nn.init.xavier_uniform_
        for m in [
                self.encoder.self_atts, self.encoder.pos_ffs,
                self.decoder.self_atts, self.decoder.pos_ffs,
                self.decoder.enc_dec_atts
        ]:
            for p in m.parameters():
                if p.dim() > 1:
                    init_func(p)
                else:
                    nn.init.constant_(p, 0.)

    def add_struct_params(self):
        """Fetch struct-specific parameters (if the struct defines any) and,
        when source positions are learned, register the non-const ones."""
        self.struct_params = self.struct.get_params(self.config) if hasattr(
            self.struct, "get_params") else {}
        if self.config['learned_pos_src']:
            self.struct_params = {
                name: Parameter(x)
                for name, x in self.struct_params.items()
            }
            for name, x in self.struct_params.items():
                # '__const__'-suffixed tensors stay fixed, so don't register.
                if not name.endswith('__const__'):
                    self.register_parameter(name, x)

    def get_decoder_mask(self, size):
        """Return a [1, 1, size, size] upper-triangular (causal) bool mask,
        reusing/growing a cached mask across calls."""
        if self.decoder_mask is None or self.decoder_mask.size()[-1] < size:
            self.decoder_mask = torch.triu(torch.ones((1, 1, size, size),
                                                      dtype=torch.bool,
                                                      device=ut.get_device()),
                                           diagonal=1)
            return self.decoder_mask
        else:
            return self.decoder_mask[:, :, :size, :size]

    def get_pos_embedding_h(self, x):
        """Position embedding for one struct instance ``x``; returns a tensor
        of shape [seq, embed_dim] (stacking if the struct returns a list)."""
        embed_dim = self.config['embed_dim']
        pe = x.get_pos_embedding(embed_dim, **self.struct_params)
        if isinstance(pe, type(x)): pe = pe.flatten()
        return pe if torch.is_tensor(pe) else torch.stack(
            pe)  # [bsz, embed_dim]

    def get_pos_embedding(self, max_len, structs=None):
        """Batch position embeddings: struct-derived (padded) for the source,
        sinusoidal/learned table slice for the target."""
        if structs is not None:
            pe = [self.get_pos_embedding_h(x) for x in structs]
            return torch.nn.utils.rnn.pad_sequence(
                pe, batch_first=True)  # [bsz, max_len, embed_dim]
        else:
            return self.pos_embedding_trg[:max_len, :].unsqueeze(
                0)  # [1, max_len, embed_dim]

    def get_input(self, toks, structs=None, calc_reg=False):
        """Combine word and positional embeddings for ``toks``.

        ``structs is not None`` selects the source side. Returns
        (embeddings, reg_penalty); the penalty is 0.0 unless ``calc_reg``.
        """
        max_len = toks.size()[-1]
        embed_dim = self.config['embed_dim']
        embeds = self.src_embedding if structs is not None else self.trg_embedding
        word_embeds = embeds(toks)  # [bsz, max_len, embed_dim]
        embed_scale = self.trg_embed_scale if structs is None else self.src_embed_scale

        if self.config['fix_norm']:
            word_embeds = ut.normalize(word_embeds, scale=False)
        else:
            word_embeds = word_embeds * embed_scale

        pos_embeds = self.get_pos_embedding(max_len, structs)
        pe_scale = self.src_pos_embed_scale if structs is not None else self.trg_pos_embed_scale
        reg_penalty = 0.0
        if calc_reg:
            reg_penalty = self.struct.get_reg_penalty(
                pos_embeds,
                toks != ac.PAD_ID) * self.config['pos_norm_penalty']
        sinusoidal_pe = self.get_pos_embedding(
            max_len) if structs is not None and self.config[
                'add_sinusoidal_pe_src'] else 0
        return word_embeds + sinusoidal_pe + pos_embeds * pe_scale, reg_penalty

    def get_encoder_masks(self, src_toks, src_structs):
        """Return (encoder_mask, encoder_mask_down): padding mask for
        enc-dec attention and the (possibly struct-specific) encoder mask."""
        encoder_mask = (src_toks == ac.PAD_ID).unsqueeze(1).unsqueeze(
            2)  # [bsz, 1, 1, max_src_len]
        if hasattr(self.struct, "get_enc_mask"):
            encoder_mask_down = self.struct.get_enc_mask(
                src_toks, src_structs, self.config['num_enc_heads'],
                **self.struct_params)
        else:
            encoder_mask_down = encoder_mask
        return encoder_mask, encoder_mask_down

    def forward(self,
                src_toks,
                src_structs,
                trg_toks,
                targets,
                b=None,
                e=None):
        """Compute label-smoothed training loss.

        Returns {'loss': ..., 'nll_loss': ...}; ``loss`` includes the
        struct regularization penalty when the struct defines one.
        """
        encoder_mask, encoder_mask_down = self.get_encoder_masks(
            src_toks, src_structs)

        decoder_mask = self.get_decoder_mask(trg_toks.size()[-1])

        encoder_inputs, reg_penalty = self.get_input(src_toks,
                                                     src_structs,
                                                     calc_reg=hasattr(
                                                         self.struct,
                                                         "get_reg_penalty"))

        encoder_outputs = self.encoder(encoder_inputs, encoder_mask_down)

        decoder_inputs, _ = self.get_input(trg_toks)
        decoder_outputs = self.decoder(decoder_inputs, decoder_mask,
                                       encoder_outputs, encoder_mask)

        logits = self.logit_fn(decoder_outputs)
        neglprobs = F.log_softmax(logits, -1)
        neglprobs = neglprobs * self.trg_vocab_mask.reshape(1, -1)
        targets = targets.reshape(-1, 1)
        non_pad_mask = targets != ac.PAD_ID
        nll_loss = -neglprobs.gather(dim=-1, index=targets)
        #nll_loss = nll_loss[non_pad_mask] # speed
        nll_loss = nll_loss * non_pad_mask
        #smooth_loss = -neglprobs.sum(dim=-1, keepdim=True)[non_pad_mask]
        smooth_loss = -neglprobs.sum(dim=-1, keepdim=True) * non_pad_mask

        nll_loss = nll_loss.sum()
        smooth_loss = smooth_loss.sum()
        label_smoothing = self.config['label_smoothing']

        if label_smoothing > 0:
            loss = (
                1.0 - label_smoothing
            ) * nll_loss + label_smoothing * smooth_loss / self.trg_vocab_mask.sum(
            )
        else:
            loss = nll_loss

        loss += reg_penalty

        return {'loss': loss, 'nll_loss': nll_loss}

    def logit_fn(self, decoder_output):
        """Project decoder output to (masked) vocabulary logits, flattened
        to [bsz * len, vocab]."""
        softmax_weight = self.out_embedding if not self.config[
            'fix_norm'] else ut.normalize(self.out_embedding, scale=True)
        logits = F.linear(decoder_output, softmax_weight, bias=self.out_bias)
        logits = logits.reshape(-1, logits.size()[-1])
        #logits[:, ~self.trg_vocab_mask] = -1e9 # speed
        logits.masked_fill_(~self.trg_vocab_mask.unsqueeze(0), -3e38)  #-1e9)
        return logits

    def beam_decode(self, src_toks, src_structs):
        """Translate a minibatch of sentences. 

        Arguments: src_toks[i,j] is the jth word of sentence i.

        Return: See encoders.Decoder.beam_decode
        """
        #encoder_mask = (src_toks == ac.PAD_ID).unsqueeze(1).unsqueeze(2) # [bsz, 1, 1, max_src_len]
        encoder_mask, encoder_mask_down = self.get_encoder_masks(
            src_toks, src_structs)
        encoder_inputs, _ = self.get_input(src_toks, src_structs)
        encoder_outputs = self.encoder(encoder_inputs, encoder_mask_down)
        # Per-sentence length cap: source length + 50, clipped to the
        # configured maximum target length.
        max_lengths1 = torch.sum(src_toks != ac.PAD_ID, dim=-1).type(
            src_toks.type()) + 50
        max_lengths2 = torch.tensor(self.config['max_trg_length']).type(
            src_toks.type())
        max_lengths = torch.min(max_lengths1, max_lengths2)

        def get_trg_inp(ids, time_step):
            # Embed the previously-decoded ids plus position `time_step`.
            ids = ids.type(src_toks.type())
            word_embeds = self.trg_embedding(ids)
            if self.config['fix_norm']:
                word_embeds = ut.normalize(word_embeds, scale=False)
            else:
                word_embeds = word_embeds * self.trg_embed_scale

            pos_embeds = self.pos_embedding_trg[time_step, :].reshape(1, 1, -1)
            return word_embeds + pos_embeds * self.trg_pos_embed_scale

        def logprob(decoder_output):
            return F.log_softmax(self.logit_fn(decoder_output), dim=-1)

        if self.config['length_model'] == ac.GNMT_LENGTH_MODEL:
            length_model = ut.gnmt_length_model(self.config['length_alpha'])
        elif self.config['length_model'] == ac.LINEAR_LENGTH_MODEL:
            length_model = lambda t, p: p + self.config['length_alpha'] * t
        elif self.config['length_model'] == ac.NO_LENGTH_MODEL:
            length_model = lambda t, p: p
        else:
            # BUG FIX: original said self.config[length_model], but
            # length_model is unbound in this branch (NameError); the intent
            # is to report the invalid configured value.
            raise ValueError('invalid length_model ' +
                             str(self.config['length_model']))

        return self.decoder.beam_decode(encoder_outputs,
                                        encoder_mask,
                                        get_trg_inp,
                                        logprob,
                                        length_model,
                                        ac.BOS_ID,
                                        ac.EOS_ID,
                                        max_lengths,
                                        beam_size=self.config['beam_size'])

    def load_state_dict(self, loaded_dict, do_init=False):
        """Restore model + data-manager state from a checkpoint dict
        produced by ``save``.

        NOTE: signature intentionally diverges from nn.Module.load_state_dict
        (takes the whole checkpoint, not just a state dict).
        """
        state_dict = loaded_dict['model']
        vocabs = loaded_dict['data_manager']
        self.data_manager.load_state_dict(vocabs)
        if do_init:
            self.init_embeddings()
            self.init_model()
            self.add_struct_params()
        super().load_state_dict(state_dict)

    def save(self, fp=None):
        """Save model weights and data-manager state to ``fp`` (defaults to
        <save_to>/<model_name>.pth)."""
        fp = fp or os.path.join(self.config['save_to'],
                                self.config['model_name'] + '.pth')
        cpkt = {
            'model': self.state_dict(),
            'data_manager': self.data_manager.state_dict(),
        }
        torch.save(cpkt, fp)

    def translate(self,
                  input_file_or_stream,
                  best_output_stream,
                  beam_output_stream,
                  num_preload=ac.DEFAULT_NUM_PRELOAD,
                  to_ids=False):
        """Delegate file/stream translation to the data manager."""
        return self.data_manager.translate(self,
                                           input_file_or_stream,
                                           best_output_stream,
                                           beam_output_stream,
                                           num_preload=num_preload,
                                           to_ids=to_ids)
# Ejemplo n.º 10
# 0
class Translator(object):
    """Batch translator: restores a trained PyTorch ``Model`` checkpoint and
    translates ``args.input_file``, writing ``<input>.best_trans`` and
    ``<input>.beam_trans`` next to it.
    """
    def __init__(self, args):
        super(Translator, self).__init__()
        self.config = getattr(configurations, args.proto)()
        self.logger = ut.get_logger(self.config['log_file'])

        self.input_file = args.input_file
        self.model_file = args.model_file

        if self.input_file is None or self.model_file is None or not os.path.exists(
                self.input_file) or not os.path.exists(self.model_file):
            raise ValueError('Input file or model file does not exist')

        self.data_manager = DataManager(self.config)
        # NOTE: translation runs as a constructor side effect.
        self.translate()

    def ids_to_trans(self, trans_ids):
        """Convert a target-id sequence to (space-joined words, kept ids),
        truncating at the first EOS."""
        words = []
        word_ids = []
        # fixed: original used enumerate() but never used the index
        for word_idx in trans_ids:
            if word_idx == ac.EOS_ID:
                break
            words.append(self.data_manager.trg_ivocab[word_idx])
            word_ids.append(word_idx)

        return u' '.join(words), word_ids

    def get_trans(self, probs, scores, symbols):
        """Rank beam hypotheses by score (descending) and return
        (best_trans, best_tran_ids, newline-joined beam listing)."""
        sorted_rows = numpy.argsort(scores)[::-1]
        best_trans = None
        best_tran_ids = None
        beam_trans = []
        for i, r in enumerate(sorted_rows):
            trans_ids = symbols[r]
            trans_out, trans_out_ids = self.ids_to_trans(trans_ids)
            beam_trans.append(u'{} {:.2f} {:.2f}'.format(
                trans_out, scores[r], probs[r]))
            if i == 0:  # highest prob trans
                best_trans = trans_out
                best_tran_ids = trans_out_ids

        return best_trans, best_tran_ids, u'\n'.join(beam_trans)

    def plot_head_map(self, mma, target_labels, target_ids, source_labels,
                      source_ids, filename):
        """https://github.com/EdinburghNLP/nematus/blob/master/utils/plot_heatmap.py
        Change the font in family param below. If the system font is not used, delete matplotlib
        font cache https://github.com/matplotlib/matplotlib/issues/3590
        """
        fig, ax = plt.subplots()
        heatmap = ax.pcolor(mma, cmap=plt.cm.Blues)

        # put the major ticks at the middle of each cell
        ax.set_xticks(numpy.arange(mma.shape[1]) + 0.5, minor=False)
        ax.set_yticks(numpy.arange(mma.shape[0]) + 0.5, minor=False)

        # without this I get some extra columns rows
        # http://stackoverflow.com/questions/31601351/why-does-this-matplotlib-heatmap-have-an-extra-blank-column
        ax.set_xlim(0, int(mma.shape[1]))
        ax.set_ylim(0, int(mma.shape[0]))

        # want a more natural, table-like display
        ax.invert_yaxis()
        ax.xaxis.tick_top()

        # source words -> column labels
        ax.set_xticklabels(source_labels,
                           minor=False,
                           family='Source Code Pro')
        for xtick, idx in zip(ax.get_xticklabels(), source_ids):
            if idx == ac.UNK_ID:
                xtick.set_color('b')
        # target words -> row labels
        ax.set_yticklabels(target_labels,
                           minor=False,
                           family='Source Code Pro')
        for ytick, idx in zip(ax.get_yticklabels(), target_ids):
            if idx == ac.UNK_ID:
                ytick.set_color('b')

        plt.xticks(rotation=45)

        plt.tight_layout()
        plt.savefig(filename)
        plt.close('all')

    def translate(self):
        """Restore the model, translate the input file batch-by-batch and
        write best/beam translations in original sentence order."""
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        model = Model(self.config).to(device)
        self.logger.info('Restore model from {}'.format(self.model_file))
        model.load_state_dict(torch.load(self.model_file))
        model.eval()

        best_trans_file = self.input_file + '.best_trans'
        beam_trans_file = self.input_file + '.beam_trans'
        open(best_trans_file, 'w').close()
        open(beam_trans_file, 'w').close()

        # Count non-empty lines so output slots can be pre-allocated and
        # filled back in original order.
        num_sents = 0
        with open(self.input_file, 'r') as f:
            for line in f:
                if line.strip():
                    num_sents += 1
        all_best_trans = [''] * num_sents
        all_beam_trans = [''] * num_sents

        with torch.no_grad():
            self.logger.info('Start translating {}'.format(self.input_file))
            start = time.time()
            count = 0
            for (src_toks, original_idxs) in self.data_manager.get_trans_input(
                    self.input_file):
                src_toks_cuda = src_toks.to(device)
                # NOTE(review): Model.beam_decode elsewhere in this file takes
                # (src_toks, src_structs) — confirm the paired Model's
                # signature matches this single-argument call.
                rets = model.beam_decode(src_toks_cuda)

                for i, ret in enumerate(rets):
                    probs = ret['probs'].cpu().detach().numpy().reshape([-1])
                    scores = ret['scores'].cpu().detach().numpy().reshape([-1])
                    symbols = ret['symbols'].cpu().detach().numpy()

                    best_trans, best_trans_ids, beam_trans = self.get_trans(
                        probs, scores, symbols)
                    all_best_trans[original_idxs[i]] = best_trans + '\n'
                    all_beam_trans[original_idxs[i]] = beam_trans + '\n\n'

                    count += 1
                    if count % 100 == 0:
                        self.logger.info(
                            '  Translating line {}, average {} seconds/sent'.
                            format(count, (time.time() - start) / count))

        model.train()

        with open(best_trans_file, 'w') as ftrans, open(beam_trans_file,
                                                        'w') as btrans:
            ftrans.write(''.join(all_best_trans))
            btrans.write(''.join(all_beam_trans))

        self.logger.info('Done translating {}, it takes {} minutes'.format(
            self.input_file,
            float(time.time() - start) / 60.0))
# Ejemplo n.º 11
# 0
class Translator(object):
    """TensorFlow-1.x translator: restores a trained checkpoint and
    translates ``args.input_file``, writing ``<input>.best_trans`` and
    ``<input>.beam_trans``, optionally plotting attention alignments.
    """
    def __init__(self, args):
        super(Translator, self).__init__()
        self.config = getattr(configurations, args.proto)()
        self.reverse = self.config['reverse']
        self.unk_repl = self.config['unk_repl']
        self.logger = ut.get_logger(self.config['log_file'])

        self.input_file = args.input_file
        self.model_file = args.model_file
        self.plot_align = args.plot_align

        # TF checkpoints are prefix paths; the '.meta' file must exist.
        if self.input_file is None or self.model_file is None or not os.path.exists(
                self.input_file) or not os.path.exists(self.model_file +
                                                       '.meta'):
            raise ValueError('Input file or model file does not exist')

        self.data_manager = DataManager(self.config)
        _, self.src_ivocab = self.data_manager.init_vocab(
            self.data_manager.src_lang)
        _, self.trg_ivocab = self.data_manager.init_vocab(
            self.data_manager.trg_lang)
        # NOTE: translation runs as a constructor side effect.
        self.translate()

    def get_model(self, mode):
        """Build a Model under the shared variable scope; non-training modes
        reuse the variables created by the training-mode model."""
        reuse = mode != ac.TRAINING
        d = self.config['init_range']
        initializer = tf.random_uniform_initializer(-d, d)
        with tf.variable_scope(self.config['model_name'],
                               reuse=reuse,
                               initializer=initializer):
            return Model(self.config, mode)

    def translate(self):
        """Restore the checkpoint and translate the input file sentence by
        sentence, writing best/beam outputs and optional alignment plots."""
        with tf.Graph().as_default():
            # Training model is built first so the validating model can
            # reuse its variables (see get_model).
            train_model = self.get_model(ac.TRAINING)
            model = self.get_model(ac.VALIDATING)

            with tf.Session() as sess:
                self.logger.info('Restore model from {}'.format(
                    self.model_file))
                saver = tf.train.Saver(var_list=tf.trainable_variables())
                saver.restore(sess, self.model_file)

                best_trans_file = self.input_file + '.best_trans'
                beam_trans_file = self.input_file + '.beam_trans'
                open(best_trans_file, 'w').close()
                open(beam_trans_file, 'w').close()
                # BUG FIX: builtin open()'s third positional argument is
                # `buffering`, so open(path, 'w', 'utf-8') raised TypeError
                # (the original was presumably written for codecs.open).
                ftrans = open(best_trans_file, 'w', encoding='utf-8')
                btrans = open(beam_trans_file, 'w', encoding='utf-8')

                self.logger.info('Start translating {}'.format(
                    self.input_file))
                start = time.time()
                count = 0
                for (src_input, src_seq_len,
                     no_unk_src_toks) in self.data_manager.get_trans_input(
                         self.input_file):
                    feed = {
                        model.src_inputs: src_input,
                        model.src_seq_lengths: src_seq_len
                    }
                    probs, scores, symbols, parents, alignments = sess.run(
                        [
                            model.probs, model.scores, model.symbols,
                            model.parents, model.alignments
                        ],
                        feed_dict=feed)
                    # [time, beam, src] -> [beam, time, src]
                    alignments = numpy.transpose(alignments, axes=(1, 0, 2))

                    probs = numpy.transpose(numpy.array(probs))
                    scores = numpy.transpose(numpy.array(scores))
                    symbols = numpy.transpose(numpy.array(symbols))
                    parents = numpy.transpose(numpy.array(parents))

                    best_trans, best_trans_ids, beam_trans, best_trans_alignments = ut.get_trans(
                        probs,
                        scores,
                        symbols,
                        parents,
                        alignments,
                        no_unk_src_toks,
                        self.trg_ivocab,
                        reverse=self.reverse,
                        unk_repl=self.unk_repl)
                    # Drop the trailing EOS token before writing.
                    best_trans_wo_eos = best_trans.split()[:-1]
                    best_trans_wo_eos = u' '.join(best_trans_wo_eos)
                    ftrans.write(best_trans_wo_eos + '\n')
                    btrans.write(beam_trans + '\n\n')

                    if self.plot_align:
                        src_input = numpy.reshape(src_input, [-1])
                        if self.reverse:
                            src_input = src_input[::-1]
                            no_unk_src_toks = no_unk_src_toks[::-1]
                        trans_toks = best_trans.split()
                        best_trans_alignments = numpy.array(
                            best_trans_alignments)[:len(trans_toks)]
                        filename = '{}_{}.png'.format(self.input_file, count)

                        # BUG FIX: original passed undefined `best_tran_ids`;
                        # the id list returned above is `best_trans_ids`.
                        ut.plot_head_map(best_trans_alignments, trans_toks,
                                         best_trans_ids, no_unk_src_toks,
                                         src_input, filename)

                    count += 1
                    if count % 100 == 0:
                        self.logger.info(
                            '  Translating line {}, average {} seconds/sent'.
                            format(count, (time.time() - start) / count))

                ftrans.close()
                btrans.close()

                self.logger.info(
                    'Done translating {}, it takes {} minutes'.format(
                        self.input_file,
                        float(time.time() - start) / 60.0))