    def test(self):
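        # smoke test: restore a trained model, summarize one article, and score it against the reference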
        data_loader = GigaWorldDataLoader(
            FileUtil.get_file_path(conf('train:article-file')),
            FileUtil.get_file_path(conf('train:summary-file')), 2)

        vocab = SimpleVocab(FileUtil.get_file_path(conf('vocab-file')),
                            conf('vocab-size'))

        seq2seq = cuda(Seq2Seq(vocab))

        # restore the trained weights and switch to inference mode
        checkpoint = t.load(FileUtil.get_file_path(conf('model-file')))

        seq2seq.load_state_dict(checkpoint['model_state_dict'])

        seq2seq.eval()

        samples = data_loader.read_all()

        article, reference = samples[3]

        summary, attention = seq2seq.evaluate(article)

        score = self.get_score(summary, reference)

        print('>>> article: ', article)
        print('>>> reference: ', reference)
        print('>>> prediction: ', summary)
        print('>>> score: ', score)
    def save_model(self, param):
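        # saves a checkpoint after an epoch; skipped when no save file is configured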
        model_file = conf('train:save-model-file')
        if not model_file:
            return

        model_file = FileUtil.get_file_path(model_file)

        # make sure the checkpoint directory exists
        file_dir, _ = os.path.split(model_file)
        os.makedirs(file_dir, exist_ok=True)

        # insert the epoch number before the file extension, if any
        root, ext = os.path.splitext(model_file)
        model_file = root + '-' + str(param['epoch'] + 1) + ext

        self.logger.debug('>>> save model into: %s', model_file)

        self.logger.debug('epoch: %s', str(param['epoch'] + 1))
        self.logger.debug('loss: %s', str(param['loss']))

        # a checkpoint bundles the epoch, loss, and model/optimizer state;
        # model_file was already resolved to a full path above
        t.save({
            'epoch': param['epoch'] + 1,
            'loss': param['loss'],
            'model_state_dict': self.seq2seq.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
        }, model_file)
    def __init__(self):
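        # evaluation setup: config values, vocab, model, batch initializer, and data loader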
        self.logger = logger(self)

        self.max_enc_steps = conf('max-enc-steps')
        self.max_dec_steps = conf('max-dec-steps')

        self.batch_size = conf('eval:batch-size')
        self.log_batch = conf('eval:log-batch')
        self.log_batch_interval = conf('eval:log-batch-interval', -1)

        self.pointer_generator = conf('pointer-generator')

        self.vocab = SimpleVocab(FileUtil.get_file_path(conf('vocab-file')),
                                 conf('vocab-size'))

        self.seq2seq = cuda(Seq2Seq(self.vocab))

        self.batch_initializer = BatchInitializer(self.vocab,
                                                  self.max_enc_steps,
                                                  self.max_dec_steps,
                                                  self.pointer_generator)

        self.data_loader = GigaWorldDataLoader(
            FileUtil.get_file_path(conf('eval:article-file')),
            FileUtil.get_file_path(conf('eval:summary-file')), self.batch_size)
    def test(self):
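        # walk the in-memory loader by explicit batch index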
        dataloader = GigaWorldMemoryDataLoader(
            FileUtil.get_file_path(conf('train:article-file')),
            FileUtil.get_file_path(conf('train:summary-file')), 15)

        for i in range(dataloader.get_num_batch()):
            batch = dataloader.get_batch(i)

            print(batch)
    def test(self):
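        # look up a pre-trained GloVe vector for a single vocab id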
        vocab = SimpleVocab(FileUtil.get_file_path(conf('vocab-file')),
                            conf('vocab-size'))

        embedding = GloveEmbedding(FileUtil.get_file_path(conf('emb-file')),
                                   vocab)

        # fetch the embedding vector for a single token id (renamed so the id() builtin is not shadowed)
        token_id = t.tensor([0])

        emb = embedding(token_id)

        print(emb)
    def __init__(self):
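        # training setup: all hyper-parameters come from the shared config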
        self.logger                     = logger(self)

        self.enc_hidden_size            = conf('enc-hidden-size')
        self.dec_hidden_size            = conf('dec-hidden-size')

        self.max_enc_steps              = conf('max-enc-steps')
        self.max_dec_steps              = conf('max-dec-steps')

        self.epoch                      = conf('train:epoch')
        self.batch_size                 = conf('train:batch-size')
        self.clip_gradient_max_norm     = conf('train:clip-gradient-max-norm')
        self.log_batch                  = conf('train:log-batch')
        self.log_batch_interval         = conf('train:log-batch-interval', -1)
        self.lr                         = conf('train:lr')
        self.lr_decay_epoch             = conf('train:lr-decay-epoch')
        self.lr_decay                   = conf('train:lr-decay')

        self.ml_enable                  = conf('train:ml:enable', True)
        self.ml_forcing_ratio           = conf('train:ml:forcing-ratio', 1)
        self.ml_forcing_decay           = conf('train:ml:forcing-decay', 0)

        self.rl_enable                  = conf('train:rl:enable')
        self.rl_weight                  = conf('train:rl:weight')
        self.rl_transit_epoch           = conf('train:rl:transit-epoch', -1)
        self.rl_transit_decay           = conf('train:rl:transit-decay', 0)

        self.save_model_per_epoch       = conf('train:save-model-per-epoch')
        self.pointer_generator          = conf('pointer-generator')

        # tensorboard
        self.tb_writer = None
        if conf('train:tb:enable') is True:
            tb_log_dir = conf('train:tb:log-dir')
            if tb_log_dir is not None:
                self.tb_writer = SummaryWriter(FileUtil.get_file_path(tb_log_dir))

        self.vocab = SimpleVocab(FileUtil.get_file_path(conf('vocab-file')),
                                 conf('vocab-size'))

        self.seq2seq = cuda(Seq2Seq(self.vocab))

        self.batch_initializer = BatchInitializer(self.vocab,
                                                  self.max_enc_steps,
                                                  self.max_dec_steps,
                                                  self.pointer_generator)

        self.data_loader = GigaWorldDataLoader(
            FileUtil.get_file_path(conf('train:article-file')),
            FileUtil.get_file_path(conf('train:summary-file')), self.batch_size)

        self.optimizer = t.optim.Adam(self.seq2seq.parameters(), lr=self.lr)

        self.criterion = nn.NLLLoss(reduction='none', ignore_index=TK_PADDING['id'])
    def __init__(self, log_file):
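        # the logging config is YAML; ${VAR} placeholders are resolved from the environment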
        with open(FileUtil.get_file_path(log_file), 'r') as f:
            # a scalar is tagged !param when it contains a ${VAR} placeholder
            param_matcher = re.compile(r'.*\$\{([^{}]+)\}.*')

            # resolve every ${VAR} placeholder from the environment; placeholders
            # for unset variables are left untouched (the original returned after
            # the first substitution, so only one placeholder was ever resolved)
            def param_constructor(loader, node):
                value = node.value

                for param in re.findall(r'\$\{([^{}]+)\}', value):
                    try:
                        value = value.replace('${' + param + '}', os.environ[param])
                    except KeyError:
                        pass

                return value

            class VariableLoader(yaml.SafeLoader):
                pass

            VariableLoader.add_implicit_resolver('!param', param_matcher, None)
            VariableLoader.add_constructor('!param', param_constructor)

            config = yaml.load(f.read(), Loader=VariableLoader)

            self.init_log_dir(config)

            logging.config.dictConfig(config)
    def test(self):
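        # stream batches until the loader is exhausted, counting samples along the way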
        dataloader = GigaWorldDataLoader(
            FileUtil.get_file_path(conf('train:article-file')),
            FileUtil.get_file_path(conf('train:summary-file')), 15)

        counter = 0
        while True:
            batch = dataloader.next_batch()
            if batch is None:
                break

            counter += len(batch)

            print(counter)

            print(batch)
    def load_model(self):
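        # restore model and optimizer state from a checkpoint and return the epoch to resume from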
        epoch = 0

        model_file = conf('train:load-model-file')
        if model_file is None:
            return epoch

        model_file = FileUtil.get_file_path(model_file)

        if os.path.isfile(model_file):
            self.logger.debug('>>> load pre-trained model from: %s', model_file)

            checkpoint = t.load(model_file)

            epoch = checkpoint['epoch']
            loss = checkpoint['loss']

            self.seq2seq.load_state_dict(checkpoint['model_state_dict'])

            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

            self.logger.debug('epoch: %s', str(epoch))
            self.logger.debug('loss: %s', str(loss))
        else:
            raise Exception('>>> error loading model - file does not exist: %s' % model_file)

        return epoch
    def test(self):
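        # round-trip words through the vocab, including the extended OOV id mapping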
        vocab = SimpleVocab(FileUtil.get_file_path(conf('vocab-file')), conf('vocab-size'))

        print(vocab.word2id('australia'))

        words = 'doctors madrid tone thyda'.split(' ')

        ids = vocab.words2ids(words)

        print(ids)

        ids, oov = vocab.extend_words2ids(words)

        print(ids, oov)

        n_words = vocab.ids2words(ids)

        print(n_words)

        n_words = vocab.ids2words(ids, oov)

        print(n_words)
    def merge(self, conf_file):
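        # merge another YAML config file into the current configuration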
        with open(FileUtil.get_file_path(conf_file), 'r') as file:
            cfg = yaml.load(file, Loader=Loader)

            DictUtil.dict_merge(self.cfg, cfg)
    def test(self):
        print(FileUtil.get_proj_dir())