コード例 #1
0
    def setUp(self):
        """Build the fixture chain used by the tests.

        Loads the parameter file, creates the raw-data processor and feeds it
        the training file, then builds the feature dictionary, the word
        dictionary and finally the reader from the loaded configuration.
        """
        # NOTE(review): absolute, machine-specific path -- this fixture only
        # works where this exact location exists.
        self.param_path = "/media/zzhuang/00091EA2000FB1D0/iGit/git_projects/libnlp/libNlp/config/newdefault.json"
        self.args = Params(self.param_path).args

        self.processor = RawDataProcessor.from_params(self.args.pipeline.data)
        self.processor.load_data(self.args.files.train_file)

        # Every feature switch is read from the same configuration sub-tree.
        data_params = self.args.pipeline.data.params
        feature_flags = {
            flag: getattr(data_params, flag)
            for flag in (
                "use_qemb",
                "use_in_question",
                "use_pos",
                "use_ner",
                "use_lemma",
                "use_tf",
            )
        }
        self.featureDict = FeatureDict(self.processor.dataset, **feature_flags)

        vocab_limit = self.args.pipeline.data.dataProcessor.restrict_vocab
        self.wordDict = WordDict(
            self.processor.dataset,
            embedding_file=self.args.files.embedding_file,
            restrict_vocab=vocab_limit,
        )

        reader_cfg = self.args.pipeline.reader
        self.reader = Reader(
            self.wordDict,
            self.featureDict,
            reader_cfg.optimizer,
            reader_cfg.model,
            fix_embeddings=reader_cfg.fix_embeddings,
        )
コード例 #2
0
    def setUp(self):
        """Build the fixture: parameters, loaded data processor, feature dict.

        Loads the parameter file, creates the raw-data processor and feeds it
        the training file, then builds the feature dictionary from the
        processor's dataset and the configured feature switches.
        """
        # NOTE(review): absolute, machine-specific path -- this fixture only
        # works where this exact location exists.
        self.param_path = "/media/zzhuang/00091EA2000FB1D0/iGit/git_projects/libnlp/libNlp/config/newdefault.json"
        self.args = Params(self.param_path).args

        self.processor = RawDataProcessor.from_params(self.args.pipeline.data)
        self.processor.load_data(self.args.files.train_file)

        # Every feature switch is read from the same configuration sub-tree.
        data_params = self.args.pipeline.data.params
        feature_flags = {
            flag: getattr(data_params, flag)
            for flag in (
                "use_qemb",
                "use_in_question",
                "use_pos",
                "use_ner",
                "use_lemma",
                "use_tf",
            )
        }
        self.featureDict = FeatureDict(self.processor.dataset, **feature_flags)
コード例 #3
0
 def from_params(cls, params: Params) -> 'RawDataProcessor':
     """Dispatch construction to the implementation named in ``params``.

     Pops the ``"type"`` entry from ``params`` (offering the names returned
     by ``cls.list_available()`` as the choices), looks the implementation up
     with ``cls.by_name`` and forwards the remaining ``params`` to that
     implementation's own ``from_params``.
     """
     # NOTE(review): takes ``cls``, so this is presumably decorated with
     # @classmethod just above the visible snippet -- confirm.
     available = cls.list_available()
     chosen_type = params.pop_choice("type", available)
     implementation = cls.by_name(chosen_type)
     return implementation.from_params(params)
コード例 #4
0
 def setUp(self):
     """Load the parameter file and build the raw-data processor from it."""
     # NOTE(review): absolute, machine-specific path -- this fixture only
     # works where this exact location exists.
     self.param_path = "/media/zzhuang/00091EA2000FB1D0/iGit/git_projects/libnlp/libNlp/config/newdefault.json"
     self.args = Params(self.param_path).args
     data_cfg = self.args.pipeline.data
     self.processor = RawDataProcessor.from_params(data_cfg)
コード例 #5
0
ファイル: train_noLog.py プロジェクト: farrellsc/LibrarianNLP
    reader.set_model()
    if args.files.embedding_file:
        reader.load_embeddings(wordDict.tokens(), args.files.embedding_file)
    reader.init_optimizer()

    stats = {'timer': utils.Timer(), 'epoch': 0, 'best_valid': 0}
    start_epoch = 0
    for epoch in range(start_epoch, args.runtime.num_epochs):
        stats['epoch'] = epoch
        train(args, trainProcessor, reader)
        utils.validate_unofficial(args,
                                  trainProcessor,
                                  reader,
                                  stats,
                                  mode='train')
        result = utils.validate_unofficial(args,
                                           devProcessor,
                                           reader,
                                           stats,
                                           mode='dev')
        if result[args.runtime.valid_metric] > stats['best_valid']:
            reader.save(args.files.model_file, epoch)
            stats['best_valid'] = result[args.runtime.valid_metric]


if __name__ == '__main__':
    # The script needs one positional argument: the path handed to Params.
    # Exit with a usage message instead of an opaque IndexError when missing.
    if len(sys.argv) < 2:
        sys.exit('usage: {} <params_file>'.format(sys.argv[0]))
    paramController = Params(sys.argv[1])
    # Seed NumPy and PyTorch with the same configured value before training.
    np.random.seed(paramController.args.runtime.random_seed)
    torch.manual_seed(paramController.args.runtime.random_seed)
    main(paramController.args)
コード例 #6
0
ファイル: Param_test.py プロジェクト: farrellsc/LibrarianNLP
 def test_get(self):
     """Load the parameter file and check that it exposes an ``args`` object."""
     # NOTE(review): absolute, machine-specific path -- this test only works
     # where this exact location exists.
     param_path = "/media/zzhuang/00091EA2000FB1D0/iGit/git_projects/libnlp/libNlp/config/newdefault.json"
     paramController = Params(param_path)
     print(paramController.args)
     # Printing alone can never fail; assert so a missing ``args`` is reported.
     assert paramController.args is not None
コード例 #7
0
 def from_params(cls, params: Params) -> 'LibDataProcessor':
     """Build an instance from ``params``.

     Pops ``batch_size`` and ``data_workers`` (in that order), calls
     ``params.assert_empty`` with the class name, and passes the ``params``
     object itself to the constructor together with the two popped values.
     """
     # NOTE(review): takes ``cls``, so this is presumably decorated with
     # @classmethod just above the visible snippet -- confirm.
     batch_size, data_workers = [
         params.pop(key) for key in ('batch_size', 'data_workers')
     ]
     # NOTE(review): the constructor receives the very object that
     # assert_empty was just called on; if assert_empty requires that nothing
     # is left, the first constructor argument carries no settings -- confirm
     # this is intended.
     params.assert_empty(cls.__name__)
     return cls(params, batch_size, data_workers)