コード例 #1
0
    def __init__(self, config, init_vocab=True):
        """Read data-pipeline settings from ``config``.

        When ``init_vocab`` is true, build vocabularies right away via
        ``self.setup()``; otherwise leave the data/ids file maps unset.
        """
        super(DataManager, self).__init__()
        self.logger = ut.get_logger(config['log_file'])

        # Scalar options copied verbatim from the config dict.
        for key in ('src_lang', 'trg_lang', 'data_dir', 'save_to',
                    'batch_size', 'share_vocab', 'word_dropout',
                    'batch_sort_src', 'max_src_length', 'max_trg_length'):
            setattr(self, key, config[key])

        # Derived / non-scalar settings.
        self.one_embedding = config['tie_mode'] == ac.ALL_TIED
        self.parse_struct = config['struct'].parse
        self.training_tok_counts = (-1, -1)
        self.vocab_masks = {}

        self.vocab_sizes = {
            self.src_lang: config['src_vocab_size'],
            self.trg_lang: config['trg_vocab_size'],
            'joint': config['joint_vocab_size'],
        }

        if init_vocab:
            self.setup()
        else:
            self.data_files = None
            self.ids_files = None
コード例 #2
0
    def __init__(self, config, data_manager):
        """Set up validation: checkpoint naming, BLEU script, score history.

        Args:
            config: dict of experiment settings (save_to, model_name, n_best,
                val_trans_out, val_beam_out, log_file).
            data_manager: supplies dev/test reference file paths.

        Raises:
            FileNotFoundError: if the multi-bleu script is missing.
        """
        super(Validator, self).__init__()
        self.logger = ut.get_logger(config['log_file'])
        self.logger.info('Initializing validator')

        self.data_manager = data_manager

        def get_cpkt_path(score):
            # Checkpoint filename embeds the validation score.
            return join(config['save_to'], '{}-{}.cpkt'.format(config['model_name'], score))

        self.get_cpkt_path = get_cpkt_path
        self.n_best = config['n_best']

        self.bleu_script = './scripts/multi-bleu.perl'
        # An assert would be stripped under `python -O`; raise explicitly
        # (consistent with the other Validator variants in this project).
        if not exists(self.bleu_script):
            raise FileNotFoundError(self.bleu_script)

        self.save_to = config['save_to']
        # exist_ok avoids a race between the exists() check and makedirs().
        os.makedirs(self.save_to, exist_ok=True)

        self.val_trans_out = config['val_trans_out']
        self.val_beam_out = config['val_beam_out']

        self.dev_ref = self.data_manager.data_files[ac.VALIDATING][self.data_manager.trg_lang]
        self.test_ref = self.data_manager.data_files[ac.TESTING][self.data_manager.trg_lang]

        # BLEU score history, resumed from disk if a previous run saved it.
        self.bleu_curve_path = join(self.save_to, 'bleu_scores.npy')
        self.best_bleus_path = join(self.save_to, 'best_bleu_scores.npy')
        self.bleu_curve = numpy.array([], dtype=numpy.float32)
        self.best_bleus = numpy.array([], dtype=numpy.float32)

        if exists(self.bleu_curve_path):
            self.bleu_curve = numpy.load(self.bleu_curve_path)
        if exists(self.best_bleus_path):
            self.best_bleus = numpy.load(self.best_bleus_path)
コード例 #3
0
ファイル: extractor.py プロジェクト: tnq177/witwicky
    def __init__(self, args):
        """Load a trained model and dump selected variables as .npy files.

        Args:
            args: CLI namespace with ``proto`` (config name), ``model_file``,
                ``var_list`` (attribute paths into the model) and ``save_to``.

        Raises:
            ValueError: if var_list is empty or model_file is missing.
        """
        super(Extractor, self).__init__()
        config = getattr(configurations, args.proto)()
        self.logger = ut.get_logger(config['log_file'])
        self.model_file = args.model_file

        var_list = args.var_list
        save_to = args.save_to

        if var_list is None:
            raise ValueError('Empty var list')

        if self.model_file is None or not os.path.exists(self.model_file):
            raise ValueError('Input file or model file does not exist')

        # exist_ok avoids a race between the existence check and creation.
        os.makedirs(save_to, exist_ok=True)

        self.logger.info('Extracting these vars: {}'.format(
            ', '.join(var_list)))

        model = Model(config)
        model.load_state_dict(torch.load(self.model_file))
        var_values = operator.attrgetter(*var_list)(model)

        # attrgetter with a single name returns the value itself, not a tuple.
        if len(var_list) == 1:
            var_values = [var_values]

        for var, var_value in zip(var_list, var_values):
            var_path = os.path.join(save_to, var + '.npy')
            # BUG FIX: model parameters require grad (and may be on GPU);
            # calling .numpy() directly raises RuntimeError. Detach and move
            # to CPU first.
            numpy.save(var_path, var_value.detach().cpu().numpy())
コード例 #4
0
    def get_input(self, toks, is_src=True):
        """Embed token ids and add positional embeddings.

        Args:
            toks: LongTensor of token ids; last dim is sentence length
                (presumably [bsz, max_len] — see the comment below).
            is_src: select the source vs target embedding table.

        Returns:
            Tensor of word + position embeddings, [bsz, max_len, embed_dim].
        """
        embeds = self.src_embedding if is_src else self.trg_embedding
        word_embeds = embeds(toks)  # [bsz, max_len, embed_dim]
        if self.config['fix_norm']:
            word_embeds = ut.normalize(word_embeds, scale=False)
        else:
            word_embeds = word_embeds * self.embed_scale

        # pos_embedding is indexed [position, embed_dim]; use dim 0
        # consistently (the original mixed size()[-2] in the check with
        # size()[0] in the message — same value for a 2-D table, but
        # confusing to read).
        max_len = toks.size(-1)
        max_pos_length = self.pos_embedding.size(0)
        if max_len > max_pos_length:
            # Only logged: the subsequent add will still fail loudly via a
            # shape mismatch, so no behavior is silently changed here.
            ut.get_logger().error(
                "Sentence length ({}) is longer than max_pos_length ({}); please increase max_pos_length"
                .format(max_len, max_pos_length))

        pos_embeds = self.pos_embedding[:max_len, :].unsqueeze(
            0)  # [1, max_len, embed_dim]
        return word_embeds + pos_embeds
コード例 #5
0
    def __init__(self, args):
        """Wire up data, validator, model, optimizer and logging counters."""
        super(Trainer, self).__init__()
        self.config = getattr(configurations, args.proto)()
        self.num_preload = args.num_preload
        self.logger = ut.get_logger(self.config['log_file'])

        use_cuda = torch.cuda.is_available()
        self.device = torch.device('cuda:0' if use_cuda else 'cpu')

        cfg = self.config
        self.normalize_loss = cfg['normalize_loss']
        self.patience = cfg['patience']
        self.lr = cfg['lr']
        self.lr_decay = cfg['lr_decay']
        self.max_epochs = cfg['max_epochs']
        self.warmup_steps = cfg['warmup_steps']

        self.train_smooth_perps = []
        self.train_true_perps = []

        self.data_manager = DataManager(cfg)
        self.validator = Validator(cfg, self.data_manager)

        self.val_per_epoch = cfg['val_per_epoch']
        self.validate_freq = int(cfg['validate_freq'])
        unit = 'epochs' if self.val_per_epoch else 'batches'
        self.logger.info('Evaluate every {} {}'.format(self.validate_freq, unit))

        # Running counters for logging; log_* reset every log_freq batches,
        # epoch_* reset every epoch.
        self.log_freq = 100  # log train stat every this-many batches
        self.log_train_loss = 0.
        self.log_nll_loss = 0.
        self.log_train_weights = 0.
        self.num_batches_done = 0
        self.epoch_batches_done = 0
        self.epoch_loss = 0.
        self.epoch_nll_loss = 0.
        self.epoch_weights = 0.
        self.epoch_time = 0.

        # Build the model and report its size.
        self.model = Model(cfg).to(self.device)
        param_count = sum(numpy.prod(p.size()) for p in self.model.parameters())
        self.logger.info('Model has {:,} parameters'.format(param_count))

        # Adam with config-specified hyperparameters.
        self.optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=self.lr,
            betas=(cfg['beta1'], cfg['beta2']),
            eps=cfg['epsilon'])
コード例 #6
0
ファイル: translate.py プロジェクト: ankitshah009/witwicky
    def __init__(self, args):
        """Load config and file paths from CLI args, then run translation."""
        super(Translator, self).__init__()
        self.config = getattr(configurations, args.proto)()
        self.logger = ut.get_logger(self.config['log_file'])

        self.input_file = args.input_file
        self.model_file = args.model_file

        # Both paths must be provided and must exist on disk.
        inputs_ok = (self.input_file is not None
                     and self.model_file is not None
                     and os.path.exists(self.input_file)
                     and os.path.exists(self.model_file))
        if not inputs_ok:
            raise ValueError('Input file or model file does not exist')

        self.data_manager = DataManager(self.config)
        self.translate()
コード例 #7
0
    def __init__(self, config, data_manager):
        """Prepare validation: checkpoint naming, BLEU script, score history.

        Args:
            config: dict of experiment settings.
            data_manager: supplies the dev reference file path.

        Raises:
            FileNotFoundError: if the multi-bleu script is missing.
        """
        super(Validator, self).__init__()
        self.logger = ut.get_logger(config['log_file'])
        self.logger.info('Initializing validator')

        self.data_manager = data_manager
        self.restore_segments = config['restore_segments']
        self.val_by_bleu = config['val_by_bleu']

        # Checkpoint filename embeds the validation score.
        self.get_cpkt_path = lambda score: join(
            config['save_to'], '{}-{}.pth'.format(config['model_name'], score))
        self.n_best = config['n_best']

        # Resolve the BLEU script relative to this source file and fail
        # loudly: an assert would be stripped under `python -O`.
        scriptdir = os.path.dirname(os.path.abspath(__file__))
        self.bleu_script = '{}/../scripts/multi-bleu.perl'.format(scriptdir)
        if not exists(self.bleu_script):
            raise FileNotFoundError(self.bleu_script)

        self.save_to = config['save_to']
        # exist_ok avoids a race between the exists() check and makedirs().
        os.makedirs(self.save_to, exist_ok=True)

        self.val_trans_out = config['val_trans_out']
        self.val_beam_out = config['val_beam_out']

        # I'll leave test alone for now since this version of the code doesn't automatically
        # report BLEU on test anw. The reason is it's up to the dataset to use multi-bleu
        # or NIST bleu. I'll include it in the future
        self.dev_ref = self.data_manager.data_files[ac.VALIDATING][
            self.data_manager.trg_lang]
        if self.restore_segments:
            self.dev_ref = self.remove_bpe(self.dev_ref)

        # Perplexity history, resumed from disk if a previous run saved it.
        self.perp_curve_path = join(self.save_to, 'dev_perps.npy')
        self.best_perps_path = join(self.save_to, 'best_perp_scores.npy')
        self.perp_curve = numpy.array([], dtype=numpy.float32)
        self.best_perps = numpy.array([], dtype=numpy.float32)
        if exists(self.perp_curve_path):
            self.perp_curve = numpy.load(self.perp_curve_path)
        if exists(self.best_perps_path):
            self.best_perps = numpy.load(self.best_perps_path)

        # BLEU history only matters when validating by BLEU.
        if self.val_by_bleu:
            self.bleu_curve_path = join(self.save_to, 'bleu_scores.npy')
            self.best_bleus_path = join(self.save_to, 'best_bleu_scores.npy')
            self.bleu_curve = numpy.array([], dtype=numpy.float32)
            self.best_bleus = numpy.array([], dtype=numpy.float32)
            if exists(self.bleu_curve_path):
                self.bleu_curve = numpy.load(self.bleu_curve_path)
            if exists(self.best_bleus_path):
                self.best_bleus = numpy.load(self.best_bleus_path)
コード例 #8
0
    def __init__(self, args):
        """Restore vocabularies from the data manager and translate input."""
        super(Translator, self).__init__()
        self.config = getattr(configurations, args.proto)()
        self.reverse = self.config['reverse']
        self.logger = ut.get_logger(self.config['log_file'])

        self.input_file = args.input_file
        self.model_file = args.model_file
        self.plot_align = args.plot_align
        self.unk_repl = args.unk_repl

        # TF checkpoints are a family of files; probe the '.meta' graph file.
        inputs_ok = (self.input_file is not None
                     and self.model_file is not None
                     and os.path.exists(self.input_file)
                     and os.path.exists(self.model_file + '.meta'))
        if not inputs_ok:
            raise ValueError('Input file or model file does not exist')

        self.data_manager = DataManager(self.config)
        _, self.src_ivocab = self.data_manager.init_vocab(self.data_manager.src_lang)
        _, self.trg_ivocab = self.data_manager.init_vocab(self.data_manager.trg_lang)
        self.translate()
コード例 #9
0
    def __init__(self, args):
        """Set up data manager, validator and bookkeeping for training."""
        super(Trainer, self).__init__()
        self.config = getattr(configurations, args.proto)()
        self.config['fixed_var_list'] = args.fixed_var_list
        self.num_preload = args.num_preload
        self.logger = ut.get_logger(self.config['log_file'])

        cfg = self.config
        self.lr = cfg['lr']
        self.max_epochs = cfg['max_epochs']
        self.save_freq = cfg['save_freq']
        self.cpkt_path = None
        self.validate_freq = None
        self.train_perps = []

        # Filled in later during training setup.
        self.saver = None
        self.train_m = None
        self.dev_m = None

        self.data_manager = DataManager(cfg)
        self.validator = Validator(cfg, self.data_manager)
        self.validate_freq = ut.get_validation_frequency(
            self.data_manager.length_files[ac.TRAINING],
            cfg['validate_freq'], cfg['batch_size'])
        self.logger.info('Evaluate every {} batches'.format(self.validate_freq))

        _, self.src_ivocab = self.data_manager.init_vocab(self.data_manager.src_lang)
        _, self.trg_ivocab = self.data_manager.init_vocab(self.data_manager.trg_lang)

        # Running counters for logging; log_* reset every log_freq batches,
        # epoch_* reset every epoch.
        self.log_freq = 100  # log train stat every this-many batches
        self.log_train_loss = 0.
        self.log_train_weights = 0.
        self.num_batches_done = 0
        self.epoch_batches_done = 0
        self.epoch_loss = 0.
        self.epoch_weights = 0.
        self.epoch_time = 0.
コード例 #10
0
    def __init__(self, args):
        """Restore a TF checkpoint and dump selected variables as .npy files.

        Args:
            args: CLI namespace with ``proto`` (config name), ``model_file``
                (TF checkpoint prefix), ``var_list`` (attribute paths into
                the model) and ``save_to`` (output directory).

        Raises:
            ValueError: if var_list is empty or the checkpoint is missing.
        """
        super(Extractor, self).__init__()
        config = getattr(configurations, args.proto)()
        self.logger = ut.get_logger(config['log_file'])
        self.model_file = args.model_file

        var_list = args.var_list
        save_to = args.save_to

        if var_list is None:
            raise ValueError('Empty var list')

        # TF checkpoints are a family of files; probe the '.meta' graph file.
        if self.model_file is None or not os.path.exists(self.model_file + '.meta'):
            raise ValueError('Input file or model file does not exist')

        if not os.path.exists(save_to):
            os.makedirs(save_to)

        self.logger.info('Extracting these vars: {}'.format(', '.join(var_list)))

        # Rebuild the model graph, then restore trained weights into it.
        with tf.Graph().as_default(), tf.Session() as sess:
            d = config['init_range']
            initializer = tf.random_uniform_initializer(-d, d)
            with tf.variable_scope(config['model_name'], reuse=False, initializer=initializer):
                model = Model(config, ac.TRAINING)

            saver = tf.train.Saver(var_list=tf.trainable_variables())
            saver.restore(sess, self.model_file)

            # attrgetter maps dotted attribute paths to graph tensors; with a
            # single name it returns the value itself, not a tuple.
            var_values = operator.attrgetter(*var_list)(model)
            var_values = sess.run(var_values)

            if len(var_list) == 1:
                var_values = [var_values]

            # izip: Python 2 itertools.izip — presumably imported at file top.
            for var, var_value in izip(var_list, var_values):
                var_path = os.path.join(save_to, var + '.npy')
                numpy.save(var_path, var_value)
コード例 #11
0
    def __init__(self, args):
        """Restore a trained model and translate ``args.input_file``.

        Output filenames are derived from the input filename by swapping the
        source-language suffix for the target-language one.

        Raises:
            FileNotFoundError: if the input or model file is missing.
        """
        super(Translator, self).__init__()
        self.config = configurations.get_config(
            args.proto, getattr(configurations, args.proto),
            args.config_overrides)
        self.logger = ut.get_logger(self.config['log_file'])
        self.num_preload = args.num_preload

        self.model_file = args.model_file
        if self.model_file is None:
            # Default to the checkpoint written by training.
            self.model_file = os.path.join(self.config['save_to'],
                                           self.config['model_name'] + '.pth')

        self.input_file = args.input_file
        if self.input_file is not None and not os.path.exists(self.input_file):
            raise FileNotFoundError(
                f'Input file does not exist: {self.input_file}')
        if not os.path.exists(self.model_file):
            raise FileNotFoundError(
                f'Model file does not exist: {self.model_file}')

        self.logger.info(f'Restore model from {self.model_file}')
        self.model = Model(self.config,
                           load_from=self.model_file).to(ut.get_device())

        if self.input_file:
            save_fp = os.path.join(self.config['save_to'],
                                   os.path.basename(self.input_file))
            # BUG FIX: str.rstrip() strips a *set of characters*, not a
            # suffix, so e.g. 'dev.ne'.rstrip('en') over-strips to 'dev.'.
            # Remove the exact source-language suffix instead.
            src_lang = self.model.data_manager.src_lang
            if src_lang and save_fp.endswith(src_lang):
                save_fp = save_fp[:-len(src_lang)]
            save_fp = save_fp + self.model.data_manager.trg_lang
            self.best_output_fp = save_fp + '.best_trans'
            self.beam_output_fp = save_fp + '.beam_trans'
            # Truncate any leftover outputs from a previous run.
            open(self.best_output_fp, 'w').close()
            open(self.beam_output_fp, 'w').close()
        else:
            self.best_output_fp = self.beam_output_fp = None

        self.translate()
コード例 #12
0
    def __init__(self, config, data_manager):
        """Track per-checkpoint dev perplexity histories for validation.

        Args:
            config: experiment settings; config['checkpoints'] lists the
                checkpoint names whose score histories are kept.
            data_manager: data access, kept for later validation runs.

        Raises:
            FileNotFoundError: if the multi-bleu script is missing.
        """
        super(Validator, self).__init__()
        self.logger = ut.get_logger(config['log_file'])
        self.logger.info('Initializing validator')

        self.data_manager = data_manager
        self.beam_alpha = config['beam_alpha']
        self.checkpoints = config['checkpoints']

        # Checkpoint filename embeds both the checkpoint name and its score.
        self.get_cpkt_path = lambda checkpoint, score: join(
            config['save_to'], '{}-{}-{}.path'.format(config['model_name'],
                                                      checkpoint, score))

        self.bleu_script = './scripts/multi-bleu.perl'
        # An assert would be stripped under `python -O`; raise explicitly.
        if not exists(self.bleu_script):
            raise FileNotFoundError(self.bleu_script)

        self.save_to = config['save_to']
        # exist_ok avoids a race between the exists() check and makedirs().
        os.makedirs(self.save_to, exist_ok=True)

        # Per-checkpoint perplexity histories, resumed from disk if present.
        self.perp_curve = {}
        self.perp_curve_path = {}
        self.best_perps = {}
        self.best_perps_path = {}
        for checkpoint in self.checkpoints:
            self.perp_curve_path[checkpoint] = join(
                self.save_to, '{}_dev_perps.npy'.format(checkpoint))
            self.best_perps_path[checkpoint] = join(
                self.save_to, '{}_best_perp_scores.npy'.format(checkpoint))
            self.perp_curve[checkpoint] = numpy.array([], dtype=numpy.float32)
            self.best_perps[checkpoint] = numpy.array([], dtype=numpy.float32)
            if exists(self.perp_curve_path[checkpoint]):
                self.perp_curve[checkpoint] = numpy.load(
                    self.perp_curve_path[checkpoint])
            if exists(self.best_perps_path[checkpoint]):
                self.best_perps[checkpoint] = numpy.load(
                    self.best_perps_path[checkpoint])
コード例 #13
0
    def __init__(self, config, model):
        """Validation helper: checkpoint naming, BLEU setup, score history.

        Args:
            config: experiment settings (paths, flags, bleu_script, ...).
            model: trained model; its data_manager supplies the dev reference.

        Raises:
            FileNotFoundError: if config['bleu_script'] does not exist.
        """
        super(Validator, self).__init__()
        self.logger = ut.get_logger(config['log_file'])
        self.logger.info('Initializing validator')

        self.model = model
        self.model_name = config['model_name']
        self.restore_segments = config['restore_segments']
        self.val_by_bleu = config['val_by_bleu']
        self.save_to = config['save_to']
        self.grad_clamp = bool(config['grad_clamp'])

        # Checkpoint filename embeds the validation score (2 decimals).
        self.get_cpkt_path = lambda score: os.path.join(
            self.save_to, f'{self.model_name}-{score:.2f}.pth')
        self.n_best = config['n_best']

        # Removed an unused `scriptdir` local: the script path comes straight
        # from config here, unlike older versions that derived it.
        self.bleu_script = config['bleu_script']
        if not os.path.exists(self.bleu_script):
            raise FileNotFoundError(self.bleu_script)

        # exist_ok avoids a race between the exists() check and makedirs().
        os.makedirs(self.save_to, exist_ok=True)

        self.val_trans_out = os.path.join(self.save_to, 'val_trans.txt')
        self.val_beam_out = os.path.join(self.save_to, 'val_beam_trans.txt')

        self.write_val_trans = config['write_val_trans']

        # I'll leave test alone for now since this version of the code doesn't automatically
        # report BLEU on test anw. The reason is it's up to the dataset to use multi-bleu
        # or NIST bleu. I'll include it in the future
        self.dev_ref = self.model.data_manager.data_files[ac.VALIDATING][
            self.model.data_manager.trg_lang]
        if self.restore_segments:
            self.dev_ref = self.remove_bpe(
                self.dev_ref,
                outfile=os.path.join(
                    self.save_to,
                    f'dev.{self.model.data_manager.trg_lang}.nobpe'))

        # Perplexity history, resumed from disk if a previous run saved it.
        self.perp_curve_path = os.path.join(self.save_to, 'dev_perps.npy')
        self.best_perps_path = os.path.join(self.save_to,
                                            'best_perp_scores.npy')
        self.perp_curve = numpy.array([], dtype=numpy.float32)
        self.best_perps = numpy.array([], dtype=numpy.float32)
        if os.path.exists(self.perp_curve_path):
            self.perp_curve = numpy.load(self.perp_curve_path)
        if os.path.exists(self.best_perps_path):
            self.best_perps = numpy.load(self.best_perps_path)

        # BLEU history only matters when validating by BLEU.
        if self.val_by_bleu:
            self.bleu_curve_path = os.path.join(self.save_to,
                                                'bleu_scores.npy')
            self.best_bleus_path = os.path.join(self.save_to,
                                                'best_bleu_scores.npy')
            self.bleu_curve = numpy.array([], dtype=numpy.float32)
            self.best_bleus = numpy.array([], dtype=numpy.float32)
            if os.path.exists(self.bleu_curve_path):
                self.bleu_curve = numpy.load(self.bleu_curve_path)
            if os.path.exists(self.best_bleus_path):
                self.best_bleus = numpy.load(self.best_bleus_path)
コード例 #14
0
    def __init__(self, config, mode):
        """Build the full TF1 attentional seq2seq graph.

        Constructs embeddings, an LSTM encoder/decoder pair with attention,
        the softmax output layer and the loss; in TRAINING mode also the
        gradient-clipped train op, otherwise the beam-search decode outputs.

        Args:
            config: dict of model/optimizer hyperparameters.
            mode: ac.TRAINING or an inference mode. Non-training modes force
                batch size 1 and disable dropout, and reuse existing cells.
        """
        super(Model, self).__init__()
        self.logger = ut.get_logger(config['log_file'])

        # Variable-scope names used to group graph variables.
        ENC_SCOPE = 'encoder'
        DEC_SCOPE = 'decoder'
        ATT_SCOPE = 'attention'
        OUT_SCOPE = 'outputer'
        SFM_SCOPE = 'softmax'

        batch_size = config['batch_size']
        feed_input = config['feed_input']
        grad_clip = config['grad_clip']
        beam_size = config['beam_size']
        beam_alpha = config['beam_alpha']
        num_layers = config['num_layers']
        rnn_type = config['rnn_type']
        score_func_type = config['score_func_type']

        src_vocab_size = config['src_vocab_size']
        trg_vocab_size = config['trg_vocab_size']

        src_embed_size = config['src_embed_size']
        trg_embed_size = config['trg_embed_size']

        enc_rnn_size = config['enc_rnn_size']
        dec_rnn_size = config['dec_rnn_size']

        input_keep_prob = config['input_keep_prob']
        output_keep_prob = config['output_keep_prob']

        # Map the config's symbolic score-function name to the Attention enum.
        attention_maps = {
            ac.SCORE_FUNC_DOT: Attention.DOT,
            ac.SCORE_FUNC_GEN: Attention.GEN,
            ac.SCORE_FUNC_BAH: Attention.BAH
        }
        score_func_type = attention_maps[score_func_type]

        # Inference decodes one sentence at a time with dropout disabled.
        if mode != ac.TRAINING:
            batch_size = 1
            input_keep_prob = 1.0
            output_keep_prob = 1.0

        # Placeholder
        self.src_inputs = tf.placeholder(tf.int32, [batch_size, None])
        self.src_seq_lengths = tf.placeholder(tf.int32, [batch_size])
        self.trg_inputs = tf.placeholder(tf.int32, [batch_size, None])
        self.trg_targets = tf.placeholder(tf.int32, [batch_size, None])
        self.target_weights = tf.placeholder(tf.float32, [batch_size, None])

        # First, define the src/trg embeddings
        with tf.variable_scope(ENC_SCOPE):
            self.src_embedding = tf.get_variable(
                'embedding',
                shape=[src_vocab_size, src_embed_size],
                dtype=tf.float32)
        with tf.variable_scope(DEC_SCOPE):
            self.trg_embedding = tf.get_variable(
                'embedding',
                shape=[trg_vocab_size, trg_embed_size],
                dtype=tf.float32)

        # Then select the RNN cell, reuse if not in TRAINING mode
        if rnn_type != ac.LSTM:
            raise NotImplementedError

        reuse = mode != ac.TRAINING  # if dev/test, reuse cell
        encoder_cell = ut.get_lstm_cell(ENC_SCOPE,
                                        num_layers,
                                        enc_rnn_size,
                                        output_keep_prob=output_keep_prob,
                                        seed=ac.SEED,
                                        reuse=reuse)

        att_state_size = dec_rnn_size
        decoder_cell = ut.get_lstm_cell(DEC_SCOPE,
                                        num_layers,
                                        dec_rnn_size,
                                        output_keep_prob=output_keep_prob,
                                        seed=ac.SEED,
                                        reuse=reuse)

        # The model
        encoder = Encoder(encoder_cell, ENC_SCOPE)
        decoder = Encoder(decoder_cell, DEC_SCOPE)
        outputer = FeedForward(enc_rnn_size + dec_rnn_size,
                               att_state_size,
                               OUT_SCOPE,
                               activate_func=tf.tanh)
        self.softmax = softmax = Softmax(att_state_size, trg_vocab_size,
                                         SFM_SCOPE)

        # Encode source sentence
        encoder_inputs = tf.nn.embedding_lookup(self.src_embedding,
                                                self.src_inputs)
        encoder_inputs = tf.nn.dropout(encoder_inputs,
                                       input_keep_prob,
                                       seed=ac.SEED)
        encoder_outputs, last_state = encoder.encode(
            encoder_inputs,
            sequence_length=self.src_seq_lengths,
            initial_state=None)
        # Define an attention layer over encoder outputs
        attention = Attention(ATT_SCOPE,
                              score_func_type,
                              encoder_outputs,
                              enc_rnn_size,
                              dec_rnn_size,
                              common_dim=enc_rnn_size
                              if score_func_type == Attention.BAH else None)

        # This function takes an decoder's output, make it attend to encoder's outputs and
        # spit out the attentional state which is used for predicting next target word
        def decoder_output_func(h_t):
            alignments, c_t = attention.calc_context(self.src_seq_lengths, h_t)
            c_t_h_t = tf.concat([c_t, h_t], 1)
            output = outputer.transform(c_t_h_t)
            return output, alignments

        # Fit everything in the decoder & start decoding
        decoder_inputs = tf.nn.embedding_lookup(self.trg_embedding,
                                                self.trg_inputs)
        decoder_inputs = tf.nn.dropout(decoder_inputs,
                                       input_keep_prob,
                                       seed=ac.SEED)
        attentional_outputs = decoder.decode(decoder_inputs,
                                             decoder_output_func,
                                             att_state_size,
                                             feed_input=feed_input,
                                             initial_state=last_state,
                                             reuse=False)
        attentional_outputs = tf.reshape(attentional_outputs,
                                         [-1, att_state_size])

        # Loss
        logits = softmax.calc_logits(attentional_outputs)
        logits = tf.reshape(logits, [batch_size, -1, trg_vocab_size])
        loss = sequence_loss(logits,
                             self.trg_targets,
                             self.target_weights,
                             average_across_timesteps=False,
                             average_across_batch=False)

        if mode != ac.TRAINING:
            # Inference: no backprop through the loss; build beam search.
            self.loss = tf.stop_gradient(tf.reduce_sum(loss))

            # Cap output length at 3x the source length.
            max_output_length = 3 * self.src_seq_lengths[0]
            tensor_to_state = partial(ut.tensor_to_lstm_state,
                                      num_layers=config['num_layers'])
            beam_outputs = decoder.beam_decode(self.trg_embedding,
                                               ac.BOS_ID,
                                               ac.EOS_ID,
                                               decoder_output_func,
                                               att_state_size,
                                               softmax.calc_logprobs,
                                               trg_vocab_size,
                                               max_output_length,
                                               tensor_to_state,
                                               alpha=beam_alpha,
                                               beam_size=beam_size,
                                               feed_input=feed_input,
                                               initial_state=last_state,
                                               reuse=True)
            self.probs, self.scores, self.symbols, self.parents, self.alignments = beam_outputs

        # If in training, do the grad backpropagate
        if mode == ac.TRAINING:
            self.loss = tf.reduce_sum(loss)

            # Option to fix some variables
            fixed_vars = config['fixed_var_list'] if config[
                'fixed_var_list'] else []

            # attrgetter returns a single value (not a list) for one name.
            if fixed_vars:
                fixed_vars = operator.attrgetter(*fixed_vars)(self)
                if isinstance(fixed_vars, list):
                    fixed_var_names = [
                        _fixed_var.name for _fixed_var in fixed_vars
                    ]
                else:
                    fixed_var_names = [fixed_vars.name]
            else:
                fixed_var_names = []

            # Train only the variables that are not explicitly fixed.
            tvars = tf.trainable_variables()
            tvars = [
                _var for _var in tvars if _var.name not in fixed_var_names
            ]

            grads, _ = tf.clip_by_global_norm(tf.gradients(self.loss, tvars),
                                              grad_clip)
            # lr is a variable (not a placeholder) so it can be decayed.
            self.lr = tf.Variable(1.0, trainable=False, name='lr')
            if config['optimizer'] == ac.ADADELTA:
                optimizer = tf.train.AdadeltaOptimizer(learning_rate=self.lr,
                                                       rho=0.95,
                                                       epsilon=1e-6)
            else:
                optimizer = tf.train.GradientDescentOptimizer(self.lr)

            self.train_op = optimizer.apply_gradients(zip(grads, tvars))

        # Finally, log out some model's stats
        if mode == ac.TRAINING:

            def num_params(var):
                # Product of the variable's static shape dimensions.
                shape = var.get_shape().as_list()
                var_count = 1
                for dim in shape:
                    var_count = var_count * dim

                return var_count

            self.logger.info('{} model:'.format('train' if mode ==
                                                ac.TRAINING else 'dev/test'))
            self.logger.info('Num trainable variables {}'.format(len(tvars)))
            self.logger.info('Num params: {:,}'.format(
                sum([num_params(v) for v in tvars])))
            self.logger.info('List of all trainable parameters:')
            for v in tvars:
                self.logger.info('   {}'.format(v.name))
            self.logger.info('List of all fixed parameters')
            for v in fixed_var_names:
                self.logger.info('   {}'.format(v))
コード例 #15
0
ファイル: data_manager.py プロジェクト: tnq177/witwicky
    def __init__(self, config):
        """Record language/data settings and derive all dataset file paths."""
        super(DataManager, self).__init__()
        self.logger = ut.get_logger(config['log_file'])

        self.src_lang = config['src_lang']
        self.trg_lang = config['trg_lang']
        self.data_dir = config['data_dir']
        self.batch_size = config['batch_size']
        self.beam_size = config['beam_size']
        self.one_embedding = config['tie_mode'] == ac.ALL_TIED
        self.share_vocab = config['share_vocab']
        self.word_dropout = config['word_dropout']
        self.batch_sort_src = config['batch_sort_src']
        self.max_train_length = config['max_train_length']

        self.vocab_sizes = {
            self.src_lang: config['src_vocab_size'],
            self.trg_lang: config['trg_vocab_size'],
            'joint': config['joint_vocab_size'],
        }

        # Raw-text, token-count and id files for each data split; all live
        # under data_dir with a fixed per-split filename prefix.
        langs = (self.src_lang, self.trg_lang)
        splits = ((ac.TRAINING, 'train'), (ac.VALIDATING, 'dev'),
                  (ac.TESTING, 'test'))
        self.data_files = {}
        self.tok_count_files = {}
        self.ids_files = {}
        for split, prefix in splits:
            self.data_files[split] = {
                lang: join(self.data_dir, '{}.{}'.format(prefix, lang))
                for lang in langs
            }
            self.tok_count_files[split] = join(self.data_dir,
                                               '{}.count'.format(prefix))
            self.ids_files[split] = join(self.data_dir,
                                         '{}.ids'.format(prefix))

        # Vocab filename embeds the configured vocab size per language.
        self.vocab_files = {
            lang: join(self.data_dir,
                       'vocab-{}.{}'.format(self.vocab_sizes[lang], lang))
            for lang in langs
        }

        self.setup()
コード例 #16
0
    def __init__(self, config):
        """Initialize the data manager from ``config``.

        Stores language/batching options, derives dataset, length, cleaned
        training-corpus, ids, and vocabulary file paths under ``data_dir``,
        then calls :meth:`setup`.
        """
        super(DataManager, self).__init__()
        self.logger = ut.get_logger(config['log_file'])

        src = config['src_lang']
        trg = config['trg_lang']
        self.src_lang = src
        self.trg_lang = trg
        self.data_dir = config['data_dir']
        self.batch_size = config['batch_size']
        self.reverse = config['reverse']

        self.vocab_sizes = {
            src: config['src_vocab_size'],
            trg: config['trg_vocab_size'],
        }

        self.max_src_length = config['max_src_length']
        self.max_trg_length = config['max_trg_length']

        def parallel_pair(stem):
            # Parallel corpus files, e.g. <data_dir>/train.<src>, train.<trg>
            return {
                lang: join(self.data_dir, '{}.{}'.format(stem, lang))
                for lang in (src, trg)
            }

        split_stems = (
            (ac.TRAINING, 'train'),
            (ac.VALIDATING, 'dev'),
            (ac.TESTING, 'test'),
        )
        self.data_files = {
            split: parallel_pair(stem) for split, stem in split_stems
        }
        self.length_files = {
            split: join(self.data_dir, '{}.length'.format(stem))
            for split, stem in split_stems
        }
        # Length-filtered training data; the cutoff is baked into the name.
        max_lengths = {src: self.max_src_length, trg: self.max_trg_length}
        self.clean_files = {
            lang: join(self.data_dir,
                       'train.{}.clean-{}'.format(lang, max_lengths[lang]))
            for lang in (src, trg)
        }
        self.ids_files = {
            split: join(self.data_dir, '{}.ids'.format(stem))
            for split, stem in split_stems
        }
        # Vocab file names embed the requested vocab size, e.g. vocab-32000.en
        self.vocab_files = {
            lang: join(self.data_dir,
                       'vocab-{}.{}'.format(self.vocab_sizes[lang], lang))
            for lang in (src, trg)
        }

        self.setup()
コード例 #17
0
    def __init__(self, args):
        """Build the full training harness from command-line ``args``.

        Resolves the configuration, wipes the save directory, constructs the
        model/validator pair, and creates a per-parameter Adam optimizer.
        NOTE(review): ``remove_files_in_dir`` runs before the logger is
        created, so the previous run's log file is presumably deleted too —
        confirm that is intended.
        """
        super(Trainer, self).__init__()
        # Config = named prototype function in `configurations`, plus any
        # command-line overrides applied on top.
        self.config = configurations.get_config(
            args.proto, getattr(configurations, args.proto),
            args.config_overrides)
        self.num_preload = args.num_preload
        self.lr = self.config['lr']

        # Start each run from a clean checkpoint/output directory.
        ut.remove_files_in_dir(self.config['save_to'])

        self.logger = ut.get_logger(self.config['log_file'])

        self.train_smooth_perps = []
        self.train_true_perps = []

        # For logging
        self.log_freq = self.config[
            'log_freq']  # log train stat every this-many batches
        self.log_train_loss = []
        self.log_nll_loss = []
        self.log_train_weights = []
        self.log_grad_norms = []
        self.total_batches = 0  # number of batches done for the whole training
        self.epoch_loss = 0.  # total train loss for whole epoch
        self.epoch_nll_loss = 0.  # total train loss for whole epoch
        self.epoch_weights = 0.  # total train weights (# target words) for whole epoch
        self.epoch_time = 0.  # total exec time for whole epoch, sounds like that tabloid
コード例 #18
0
    def __init__(self, config):
        """Initialize the data manager from ``config``.

        Stores language/vocab/cleaning options, derives the training corpus
        paths (original, length-ratio-cleaned, and ids), per-checkpoint dev
        and test file paths, and vocabulary file paths, then calls
        :meth:`setup`.
        """
        super(DataManager, self).__init__()
        self.logger = ut.get_logger(config['log_file'])

        src = config['src_lang']
        trg = config['trg_lang']
        self.src_lang = src
        self.trg_lang = trg
        self.data_dir = config['data_dir']
        self.batch_size = config['batch_size']
        self.beam_size = config['beam_size']
        # ALL_TIED means one shared embedding matrix for both languages.
        self.one_embedding = config['tie_mode'] == ac.ALL_TIED
        self.share_vocab = config['share_vocab']
        self.max_length = config['max_length']
        self.length_ratio = config['length_ratio']
        self.checkpoints = config['checkpoints']
        self.clean_corpus_script = './scripts/clean-corpus-n.perl'

        self.vocab_sizes = {
            src: config['src_vocab_size'],
            trg: config['trg_vocab_size'],
            'joint': config['joint_vocab_size'],
        }

        def train_path(lang):
            return join(self.data_dir, 'train.{}'.format(lang))

        def clean_path(lang):
            # Cleaning cutoffs are baked into the file name.
            return join(
                self.data_dir,
                'train.clean-{}-{}.{}'.format(self.max_length,
                                              self.length_ratio, lang))

        self.data_files = {
            'org': {lang: train_path(lang) for lang in (src, trg)},
            'clean': {lang: clean_path(lang) for lang in (src, trg)},
            'ids': join(self.data_dir, 'train.ids'),
        }

        # One dev/test file pair per checkpoint; only dev sets carry ids.
        self.dev_files = {}
        self.test_files = {}
        for cp in self.checkpoints:
            self.dev_files[cp] = {
                src: join(self.data_dir, 'dev.{}.{}'.format(cp, src)),
                trg: join(self.data_dir, 'dev.{}.{}'.format(cp, trg)),
                'ids': join(self.data_dir, 'dev.{}.ids'.format(cp)),
            }
            self.test_files[cp] = {
                src: join(self.data_dir, 'test.{}.{}'.format(cp, src)),
                trg: join(self.data_dir, 'test.{}.{}'.format(cp, trg)),
            }

        # Vocab file names embed the requested vocab size, e.g. vocab-32000.en
        self.vocab_files = {
            lang: join(self.data_dir,
                       'vocab-{}.{}'.format(self.vocab_sizes[lang], lang))
            for lang in (src, trg)
        }

        self.setup()