예제 #1
0
def main(args):
    """Train a reader and checkpoint it whenever the dev score improves.

    One ``RawDataProcessor`` is created per split (train and dev). The feature
    dictionary is built from the train split only, while the word dictionary
    is built from train + dev. Both dictionaries are handed to the processors
    and to the ``Reader``. The loop then trains for
    ``args.runtime.num_epochs`` epochs, validating on both splits after each
    epoch and saving the model when ``args.runtime.valid_metric`` on dev
    exceeds the best value seen so far.

    Args:
        args: Configuration object exposing the ``pipeline``, ``files`` and
            ``runtime`` sections read below.
    """
    # ------------------------------------------------------------------
    # Data
    # ------------------------------------------------------------------
    data_cfg = args.pipeline.data
    train_proc = RawDataProcessor.from_params(data_cfg)
    dev_proc = RawDataProcessor.from_params(data_cfg)
    train_proc.load_data(args.files.train_file)
    dev_proc.load_data(args.files.dev_file)

    # Every feature switch is forwarded verbatim from the data params section.
    switch_names = (
        'use_qemb',
        'use_in_question',
        'use_pos',
        'use_ner',
        'use_lemma',
        'use_tf',
    )
    switches = data_cfg.params
    feature_dict = FeatureDict(
        train_proc.dataset,
        **{name: getattr(switches, name) for name in switch_names})

    vocab = WordDict(
        train_proc.dataset + dev_proc.dataset,
        embedding_file=args.files.embedding_file,
        restrict_vocab=data_cfg.dataProcessor.restrict_vocab)

    for proc in (train_proc, dev_proc):
        proc.set_utils(word_dict=vocab, feature_dict=feature_dict)

    # ------------------------------------------------------------------
    # Reader
    # ------------------------------------------------------------------
    reader_cfg = args.pipeline.reader
    reader = Reader(vocab,
                    feature_dict,
                    reader_cfg.optimizer,
                    reader_cfg.model,
                    fix_embeddings=reader_cfg.fix_embeddings)
    reader.set_model()
    if args.files.embedding_file:
        reader.load_embeddings(vocab.tokens(), args.files.embedding_file)
    reader.init_optimizer()

    # ------------------------------------------------------------------
    # Train / validate loop
    # ------------------------------------------------------------------
    stats = {'timer': utils.Timer(), 'epoch': 0, 'best_valid': 0}
    first_epoch = 0
    for epoch in range(first_epoch, args.runtime.num_epochs):
        stats['epoch'] = epoch
        train(args, train_proc, reader)

        # The train-split validation result is not used here.
        utils.validate_unofficial(
            args, train_proc, reader, stats, mode='train')
        dev_result = utils.validate_unofficial(
            args, dev_proc, reader, stats, mode='dev')

        dev_score = dev_result[args.runtime.valid_metric]
        if dev_score > stats['best_valid']:
            reader.save(args.files.model_file, epoch)
            stats['best_valid'] = dev_score
예제 #2
0
    def setUp(self):
        """Prepare the fixture: args, loaded train data, dictionaries and a Reader.

        Parameters are read from the JSON file at ``self.param_path``. The
        train file named in those parameters is loaded into the processor, and
        both the feature and word dictionaries are built from that dataset.
        """
        self.param_path = "/media/zzhuang/00091EA2000FB1D0/iGit/git_projects/libnlp/libNlp/config/newdefault.json"
        self.args = Params(self.param_path).args

        data_cfg = self.args.pipeline.data
        self.processor = RawDataProcessor.from_params(data_cfg)
        self.processor.load_data(self.args.files.train_file)

        # Feature switches are forwarded one-to-one from the data params.
        flags = data_cfg.params
        self.featureDict = FeatureDict(
            self.processor.dataset,
            use_qemb=flags.use_qemb,
            use_in_question=flags.use_in_question,
            use_pos=flags.use_pos,
            use_ner=flags.use_ner,
            use_lemma=flags.use_lemma,
            use_tf=flags.use_tf,
        )

        self.wordDict = WordDict(
            self.processor.dataset,
            embedding_file=self.args.files.embedding_file,
            restrict_vocab=data_cfg.dataProcessor.restrict_vocab,
        )

        reader_cfg = self.args.pipeline.reader
        self.reader = Reader(
            self.wordDict,
            self.featureDict,
            reader_cfg.optimizer,
            reader_cfg.model,
            fix_embeddings=reader_cfg.fix_embeddings,
        )
예제 #3
0
def main(args):
    """Run the full training pipeline and keep the best dev-scoring model.

    Loads the train and dev splits, lets the ``Reader`` build its feature
    dictionary (train only) and word dictionary (train + dev), optionally
    loads pretrained embeddings and sets up partial embedding tuning, then
    trains for ``args.runtime.num_epochs`` epochs. After every epoch both
    splits are validated and the model is saved whenever
    ``args.runtime.valid_metric`` on dev exceeds the best value so far.

    Args:
        args: Configuration object exposing the ``pipeline``, ``files`` and
            ``runtime`` sections read below.
    """
    loader_cfg = args.pipeline.dataLoader
    train_proc = RawDataProcessor.from_params(loader_cfg)
    dev_proc = RawDataProcessor.from_params(loader_cfg)

    train_proc.load_data(args.files.train_file)
    dev_proc.load_data(args.files.dev_file)

    reader = Reader(args)
    reader.build_feature_dict(train_proc.dataset)
    reader.build_word_dict(args, train_proc.dataset + dev_proc.dataset)
    reader.set_model()

    # Pretrained embeddings are optional.
    if args.files.embedding_file:
        vocab_tokens = reader.word_dict.tokens()
        reader.load_embeddings(vocab_tokens, args.files.embedding_file)

    # Partial tuning: only the words returned by top_question_words are tuned.
    if args.pipeline.reader.tune_partial > 0:
        frequent = utils.top_question_words(
            args, train_proc.dataset, reader.word_dict)
        reader.tune_embeddings([entry[0] for entry in frequent])

    reader.init_optimizer()

    first_epoch = 0
    stats = {'timer': utils.Timer(), 'epoch': 0, 'best_valid': 0}
    for epoch in range(first_epoch, args.runtime.num_epochs):
        stats['epoch'] = epoch
        train(args, train_proc, reader, stats)

        # The train-split validation result is not used here.
        utils.validate_unofficial(
            args, train_proc, reader, stats, mode='train')
        dev_result = utils.validate_unofficial(
            args, dev_proc, reader, stats, mode='dev')

        dev_score = dev_result[args.runtime.valid_metric]
        if dev_score > stats['best_valid']:
            reader.save(args.files.model_file)
            stats['best_valid'] = dev_score
예제 #4
0
    def setUp(self):
        """Prepare the fixture: args, loaded train data and the feature dictionary.

        Parameters are read from the JSON file at ``self.param_path``; the
        train file named there is loaded and used to build ``self.featureDict``.
        """
        self.param_path = "/media/zzhuang/00091EA2000FB1D0/iGit/git_projects/libnlp/libNlp/config/newdefault.json"
        self.args = Params(self.param_path).args

        data_cfg = self.args.pipeline.data
        self.processor = RawDataProcessor.from_params(data_cfg)
        self.processor.load_data(self.args.files.train_file)

        # Feature switches are forwarded one-to-one from the data params.
        flags = data_cfg.params
        self.featureDict = FeatureDict(
            self.processor.dataset,
            use_qemb=flags.use_qemb,
            use_in_question=flags.use_in_question,
            use_pos=flags.use_pos,
            use_ner=flags.use_ner,
            use_lemma=flags.use_lemma,
            use_tf=flags.use_tf,
        )
예제 #5
0
 def setUp(self):
     """Load the parameter file and create the raw-data processor fixture."""
     self.param_path = "/media/zzhuang/00091EA2000FB1D0/iGit/git_projects/libnlp/libNlp/config/newdefault.json"
     self.args = Params(self.param_path).args
     data_cfg = self.args.pipeline.data
     self.processor = RawDataProcessor.from_params(data_cfg)
예제 #6
0
def main(args):
    """Run the full training pipeline with progress logging.

    Loads the train and dev splits, lets the ``Reader`` build its feature
    dictionary (train only) and word dictionary (train + dev), optionally
    loads pretrained embeddings and sets up partial embedding tuning, logs the
    configuration, then trains for ``args.runtime.num_epochs`` epochs. After
    every epoch both splits are validated and the model is saved whenever
    ``args.runtime.valid_metric`` on dev exceeds the best value so far.

    Args:
        args: Configuration object exposing the ``pipeline``, ``files`` and
            ``runtime`` sections read below.
    """
    separator = '-' * 100

    # --------------------------------------------------------------------------
    # DATA LOADERS
    # Two datasets: train and dev.
    logger.info(separator)
    logger.info('Make data loaders')
    train_processor = RawDataProcessor.from_params(args.pipeline.dataLoader)
    dev_processor = RawDataProcessor.from_params(args.pipeline.dataLoader)

    # --------------------------------------------------------------------------
    # DATA
    logger.info(separator)
    logger.info('Load data files')
    train_processor.load_data(args.files.train_file)
    logger.info('Num train examples = %d', len(train_processor.dataset))
    dev_processor.load_data(args.files.dev_file)
    logger.info('Num dev examples = %d', len(dev_processor.dataset))

    # --------------------------------------------------------------------------
    # READER
    # Initialize reader
    logger.info(separator)
    start_epoch = 0
    reader = Reader(args.pipeline.reader)

    logger.info(separator)
    logger.info('Generate features')
    reader.build_feature_dict(train_processor.dataset)
    logger.info('Num features = %d', len(reader.feature_dict))
    logger.info(reader.feature_dict)

    # Build a dictionary from the data questions + words (train/dev splits)
    logger.info(separator)
    logger.info('Build dictionary')
    reader.build_word_dict(args,
                           train_processor.dataset + dev_processor.dataset)
    logger.info('Num words = %d', len(reader.word_dict))

    reader.set_model()

    # Load pretrained embeddings for words in dictionary
    if args.files.embedding_file:
        reader.load_embeddings(reader.word_dict.tokens(),
                               args.files.embedding_file)

    # Set up partial tuning of embeddings
    if args.pipeline.reader.tune_partial > 0:
        logger.info(separator)
        logger.info('Counting %d most frequent question words',
                    args.pipeline.reader.tune_partial)
        top_words = utils.top_question_words(args, train_processor.dataset,
                                             reader.word_dict)
        # Preview the head and the tail of the list. The tail slice is
        # [-5:] so that the very last entry is included; the previous
        # [-6:-1] slice silently skipped it.
        for word in top_words[:5]:
            logger.info(word)
        logger.info('...')
        for word in top_words[-5:]:
            logger.info(word)
        reader.tune_embeddings([w[0] for w in top_words])

    # Set up optimizer
    reader.init_optimizer()

    # -------------------------------------------------------------------------
    # PRINT CONFIG
    logger.info(separator)
    logger.info('CONFIG:\n%s',
                json.dumps(vars(args), indent=4, sort_keys=True))

    # --------------------------------------------------------------------------
    # TRAIN/VALID LOOP
    logger.info(separator)
    logger.info('Starting training...')
    stats = {'timer': utils.Timer(), 'epoch': 0, 'best_valid': 0}
    for epoch in range(start_epoch, args.runtime.num_epochs):
        stats['epoch'] = epoch

        # Train
        train(args, train_processor, reader, stats)

        # Validate unofficial (train); the result is not used here.
        utils.validate_unofficial(args,
                                  train_processor,
                                  reader,
                                  stats,
                                  mode='train')

        # Validate unofficial (dev)
        result = utils.validate_unofficial(args,
                                           dev_processor,
                                           reader,
                                           stats,
                                           mode='dev')

        # Save best valid
        valid_metric = args.runtime.valid_metric
        valid_score = result[valid_metric]
        if valid_score > stats['best_valid']:
            logger.info('Best valid: %s = %.2f (epoch %d, %d updates)',
                        valid_metric, valid_score, stats['epoch'],
                        reader.updateCount)
            reader.save(args.files.model_file)
            stats['best_valid'] = valid_score