Example #1
import numpy as np
from torch.utils.data import DataLoader  # assumed to be PyTorch's DataLoader

# Project-local modules assumed importable: Config, Dataset, Embedding,
# Factory, Model.


class Demo(object):
    def __init__(self, config_filename):
        # configuration
        config = Config()
        config_file = "{}/{}".format(config.config_dir, config_filename)
        config.update_config(config_file)

        # word embedding
        print("setting word embedding...")
        word_embedding = Embedding()

        word_embedding_file = "{}/word_embedding_{}.pkl".format(
            config.cache_dir, config.config_name)
        print("loading word embedding from {}...".format(word_embedding_file))
        word_embedding.load_word_embedding(word_embedding_file)

        # demo dataset
        print("setting demo dataset...")
        self.demo_dataset = Dataset(config.data_config)

        self.demo_dataset.set_word_to_index(word_embedding.word2index)

        label_mapping_file = "{}/label_mapping_{}.pkl".format(
            config.cache_dir, config.config_name)
        print("loading label mapping from {}...".format(label_mapping_file))
        self.demo_dataset.load_label_mapping(label_mapping_file)

        # model
        new_model_config = {
            "vocab_size": word_embedding.vocab_size,
            "word_dim": word_embedding.word_dim,
            "document_length": self.demo_dataset.document_length,
            "sentence_length": self.demo_dataset.sentence_length,
            "num_labels": self.demo_dataset.num_labels
        }
        config.update_model_config(new_model_config)

        model = Model(config.model_config)

        # model factory
        self.network = Factory(model)

        self.network.set_test_module()
        print("number of GPUs: {}".format(self.network.num_gpus))
        print("device: {}".format(self.network.device))

        # load model
        model_file = "{}/model_{}.pkl".format(config.cache_dir,
                                              config.config_name)
        print("loading model from {}...".format(model_file))
        self.network.load_model(model_file)

        self.network.model_to_device()
        self.network.eval_mode()

    def predict(self, data_list):
        """
        data_list  : [{"title": str, "content": str}]
        result_list: [{"strategy_ids": [str]}]
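
        Example (illustrative values, not from the original project):
            data_list   = [{"title": "...", "content": "..."}]
            result_list = [{"strategy_ids": ["strategy_3"]}]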
        """
        self.demo_dataset.load_data_from_list(data_list)
        self.demo_dataset.process_data_from_list()

        demo_data_loader = DataLoader(self.demo_dataset,
                                      batch_size=50,
                                      shuffle=False)

        # np.int was removed in NumPy 1.24; use an explicit fixed-width dtype
        demo_preds = np.zeros([0, self.demo_dataset.num_labels], dtype=np.int64)
        for data in demo_data_loader:
            sequences_ttl, sequences_cnt, _labels = data  # labels unused in demo
            preds = self.network.test(sequences_ttl, sequences_cnt)
            demo_preds = np.concatenate((demo_preds, preds), axis=0)

        result_list = []
        for index in range(self.demo_dataset.num_samples):
            strategy_ids = []
            for label in range(self.demo_dataset.num_labels):
                if demo_preds[index, label] == 1:
                    strategy_ids.append(
                        self.demo_dataset.label2strategy[label])
            result = {"strategy_ids": strategy_ids}
            result_list.append(result)
        return result_list
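
A minimal usage sketch for the class above. The config filename and the input
fields are illustrative; only the {"title", "content"} input shape and the
{"strategy_ids"} output shape come from Demo.predict itself.

if __name__ == "__main__":
    demo = Demo("demo_config.json")  # hypothetical config filename
    results = demo.predict([
        {"title": "Quarterly report", "content": "Revenue grew in Q3 ..."},
    ])
    print(results[0]["strategy_ids"])
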
Example #2
import numpy as np
from torch.utils.data import DataLoader  # assumed to be PyTorch's DataLoader

# Project-local modules assumed importable: Config, Dataset, Embedding,
# Factory, Logger, Metric, Model.


def test(config_filename):
    # configuration
    config = Config()
    config_file = "{}/{}".format(config.config_dir, config_filename)
    config.update_config(config_file)

    # logger
    log_file = "{}/test_{}.txt".format(config.log_dir, config.config_name)
    logger = Logger(log_file)

    # word embedding
    logger.info("setting word embedding...")
    word_embedding = Embedding()

    word_embedding_file = "{}/word_embedding_{}.pkl".format(
        config.cache_dir, config.config_name)
    logger.info(
        "loading word embedding from {}...".format(word_embedding_file))
    word_embedding.load_word_embedding(word_embedding_file)

    logger.info("vocab_size: {}".format(word_embedding.vocab_size))
    logger.info("word_dim  : {}".format(word_embedding.word_dim))

    # testing dataset
    logger.info("setting testing dataset...")
    test_dataset = Dataset(config.data_config)

    test_dataset.set_word_to_index(word_embedding.word2index)

    label_mapping_file = "{}/label_mapping_{}.pkl".format(
        config.cache_dir, config.config_name)
    logger.info("loading label mapping from {}...".format(label_mapping_file))
    test_dataset.load_label_mapping(label_mapping_file)

    test_data_file = "{}/{}".format(config.data_dir, config.test_data_file)
    logger.info("loading data from {}...".format(test_data_file))
    test_dataset.load_data_from_file(test_data_file)
    logger.info("number of samples: {}".format(test_dataset.num_samples))

    logger.info("processing data...")
    test_dataset.process_data_from_file()

    # model
    new_model_config = {
        "vocab_size": word_embedding.vocab_size,
        "word_dim": word_embedding.word_dim,
        "document_length": test_dataset.document_length,
        "sentence_length": test_dataset.sentence_length,
        "num_labels": test_dataset.num_labels
    }
    config.update_model_config(new_model_config)

    model = Model(config.model_config)

    # metric
    metric = Metric()

    # test configuration
    logger.info("configuration: {}".format(config))

    # data loader
    test_data_loader = DataLoader(test_dataset,
                                  batch_size=config.batch_size,
                                  shuffle=False)

    # model factory
    network = Factory(model)
    network.set_test_module()
    logger.info("number of GPUs: {}".format(network.num_gpus))
    logger.info("device: {}".format(network.device))

    # load model
    model_file = "{}/model_{}.pkl".format(config.cache_dir, config.config_name)
    logger.info("loading model from {}...".format(model_file))
    network.load_model(model_file)

    network.model_to_device()

    # test
    network.eval_mode()
    # np.int was removed in NumPy 1.24; use an explicit fixed-width dtype
    test_preds = np.zeros([0, test_dataset.num_labels], dtype=np.int64)
    test_labels = np.zeros([0, test_dataset.num_labels], dtype=np.int64)
    for data in test_data_loader:
        sequences_ttl, sequences_cnt, labels = data
        preds = network.test(sequences_ttl, sequences_cnt)
        test_preds = np.concatenate((test_preds, preds), axis=0)
        test_labels = np.concatenate(
            (test_labels, labels.numpy().astype(np.int64)), axis=0)

    # metrics
    ac, mp, mr, mf = metric.all_metrics(test_preds, test_labels)
    logger.info("Acc: {:.4f}".format(ac))
    logger.info("MP : {:.4f}".format(mp))
    logger.info("MR : {:.4f}".format(mr))
    logger.info("MF : {:.4f}".format(mf))
Example #3
import os

import numpy as np
from torch.utils.data import DataLoader  # assumed to be PyTorch's DataLoader

# Project-local modules assumed importable: Config, Dataset, Embedding,
# Factory, Logger, Metric, Model.


def train(config_filename):
    # configuration
    config = Config()
    config_file = "{}/{}".format(config.config_dir, config_filename)
    config.update_config(config_file)

    # logger
    log_file = "{}/train_{}.txt".format(config.log_dir, config.config_name)
    logger = Logger(log_file)

    # word embedding
    logger.info("setting word embedding...")
    word_embedding = Embedding()

    train_data_file = "{}/{}".format(config.data_dir, config.train_data_file)
    word_vector_file = "{}/{}".format(config.src_dir, config.word_vector_file)
    vocab_list_file = "{}/vocab_list_{}.txt".format(config.cache_dir,
                                                    config.config_name)
    word_embedding_file = "{}/word_embedding_{}.pkl".format(
        config.cache_dir, config.config_name)

    if not os.path.exists(word_embedding_file):
        logger.info("building word embedding...")
        word_embedding.build_word_embedding(train_data_file, word_vector_file,
                                            vocab_list_file,
                                            word_embedding_file)

    logger.info(
        "loading word embedding from {}...".format(word_embedding_file))
    word_embedding.load_word_embedding(word_embedding_file)

    logger.info("vocab_size: {}".format(word_embedding.vocab_size))
    logger.info("word_dim  : {}".format(word_embedding.word_dim))

    # training dataset
    logger.info("setting training dataset...")
    train_dataset = Dataset(config.data_config)

    train_dataset.set_word_to_index(word_embedding.word2index)

    logger.info("loading data from {}...".format(train_data_file))
    train_dataset.load_data_from_file(train_data_file)
    logger.info("number of samples: {}".format(train_dataset.num_samples))

    label_list_file = "{}/label_list_{}.txt".format(config.cache_dir,
                                                    config.config_name)
    label_mapping_file = "{}/label_mapping_{}.pkl".format(
        config.cache_dir, config.config_name)
    logger.info("building label mapping...")
    train_dataset.build_label_mapping(label_list_file, label_mapping_file)

    logger.info("processing data...")
    train_dataset.process_data_from_file()

    # validation dataset
    logger.info("setting validation dataset...")
    valid_dataset = Dataset(config.data_config)

    valid_dataset.set_word_to_index(word_embedding.word2index)

    label_mapping_file = "{}/label_mapping_{}.pkl".format(
        config.cache_dir, config.config_name)
    logger.info("loading label mapping from {}...".format(label_mapping_file))
    valid_dataset.load_label_mapping(label_mapping_file)

    valid_data_file = "{}/{}".format(config.data_dir, config.valid_data_file)
    logger.info("loading data from {}...".format(valid_data_file))
    valid_dataset.load_data_from_file(valid_data_file)
    logger.info("number of samples: {}".format(valid_dataset.num_samples))

    logger.info("processing data...")
    valid_dataset.process_data_from_file()

    # model
    new_model_config = {
        "vocab_size": word_embedding.vocab_size,
        "word_dim": word_embedding.word_dim,
        "document_length": train_dataset.document_length,
        "sentence_length": train_dataset.sentence_length,
        "num_labels": train_dataset.num_labels
    }
    config.update_model_config(new_model_config)

    model = Model(config.model_config)

    # metric
    metric = Metric()

    # train configuration
    logger.info("configuration: {}".format(config))

    # data loader
    train_data_loader = DataLoader(train_dataset,
                                   batch_size=config.batch_size,
                                   shuffle=True)
    valid_data_loader = DataLoader(valid_dataset,
                                   batch_size=config.batch_size,
                                   shuffle=False)

    # model factory
    network = Factory(model)
    network.set_train_module()
    logger.info("number of GPUs: {}".format(network.num_gpus))
    logger.info("device: {}".format(network.device))

    # set word embedding
    network.set_word_embedding(word_embedding.matrix)

    network.model_to_device()

    # train and validate
    max_mf = 0
    epoch_count = 0
    for epoch in range(config.num_epochs):
        logger.info("----------------------------------------")

        # train
        network.train_mode()
        for batch, data in enumerate(train_data_loader):
            sequences_ttl, sequences_cnt, labels = data
            loss = network.train(sequences_ttl, sequences_cnt, labels)
            if batch > 0 and batch % config.info_interval == 0:
                logger.info("epoch: {} | batch: {} | loss: {:.6f}".format(
                    epoch, batch, loss))

        # validate
        network.eval_mode()
        # np.int was removed in NumPy 1.24; use an explicit fixed-width dtype
        valid_preds = np.zeros([0, valid_dataset.num_labels], dtype=np.int64)
        valid_labels = np.zeros([0, valid_dataset.num_labels], dtype=np.int64)
        for batch, data in enumerate(valid_data_loader):
            sequences_ttl, sequences_cnt, labels = data
            preds, loss = network.validate(sequences_ttl, sequences_cnt,
                                           labels)
            valid_preds = np.concatenate((valid_preds, preds), axis=0)
            valid_labels = np.concatenate(
                (valid_labels, labels.numpy().astype(np.int64)), axis=0)

        # metrics
        ac, mp, mr, mf = metric.all_metrics(valid_preds, valid_labels)
        logger.info("Acc: {:.4f}".format(ac))
        logger.info("MP : {:.4f}".format(mp))
        logger.info("MR : {:.4f}".format(mr))
        logger.info("MF : {:.4f}".format(mf))

        # early stop
        if mf >= max_mf:
            max_mf = mf
            epoch_count = 0
            model_file = "{}/model_{}.pkl".format(config.cache_dir,
                                                  config.config_name)
            logger.info("saving model to {}...".format(model_file))
            network.save_model(model_file)
        else:
            epoch_count += 1
            if epoch_count == config.early_stop:
                logger.info("stop training process.")
                logger.info("best epoch: {}".format(epoch - epoch_count))
                break
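
A hypothetical command-line wrapper tying train and test together; the argument
names are illustrative and not from the original project.

import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="train or test a model")
    parser.add_argument("mode", choices=["train", "test"])
    parser.add_argument("config_filename",
                        help="config file name under config.config_dir")
    args = parser.parse_args()

    if args.mode == "train":
        train(args.config_filename)
    else:
        test(args.config_filename)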