Example #1
def transformer_eval(
    data_cnf,
    model_cnf,
    data_name,
    model_name,
    model_path,
    tree_id,
    output_suffix,
    dry_run,
):
    logger.info("Loading Test Set")
    mlb = get_mlb(data_cnf["labels_binarizer"])
    num_labels = len(mlb.classes_)
    test_x, _ = get_data(data_cnf["test"]["texts"], None)
    test_atten_mask = test_x["attention_mask"]
    test_x = test_x["input_ids"]

    logger.info(f"Size of Test Set: {len(test_x):,}")

    logger.info("Predicting")
    test_loader = DataLoader(
        MultiLabelDataset(test_x, attention_mask=test_atten_mask),
        model_cnf["predict"]["batch_size"],
        num_workers=4,
    )

    model_cls = MODEL_TYPE[model_cnf["model"]["base"]]

    network = model_cls.from_pretrained(model_cnf["model"]["pretrained"],
                                        num_labels=num_labels)

    model_cnf['model'].pop('load_model', None)
    model = TransformerXML(network,
                           model_path,
                           load_model=True,
                           **data_cnf["model"],
                           **model_cnf["model"])

    scores, labels = model.predict(test_loader,
                                   k=model_cnf["predict"].get("k", 100))
    labels = mlb.classes_[labels]

    logger.info("Finish Predicting")
    score_path, label_path = output_res(
        data_cnf["output"]["res"],
        f"{model_name}-{data_name}{tree_id}",
        scores,
        labels,
        output_suffix,
    )

    log_results(score_path, label_path, dry_run)
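For orientation, transformer_eval reads only a handful of configuration keys. The sketch below shows the minimal shape of the two dictionaries it expects after the YAML files are loaded; the key names are taken directly from the accesses above, while every value is an illustrative placeholder rather than the project's real configuration.

# Minimal config sketch reconstructed from the keys transformer_eval reads.
# Paths and values are placeholders; the real YAML files contain more options.
data_cnf = {
    "labels_binarizer": "data/labels_binarizer",  # used by get_mlb
    "test": {"texts": "data/test_texts.npy"},     # tokenized test inputs
    "output": {"res": "results"},                 # destination for output_res
    "model": {},                                  # extra kwargs for TransformerXML
}
model_cnf = {
    "model": {
        "base": "bert",                       # key into MODEL_TYPE
        "pretrained": "bert-base-uncased",    # passed to from_pretrained
    },
    "predict": {"batch_size": 64, "k": 100},  # batch size and top-k cutoff
}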
Example #2
def random_forest_eval(
    data_cnf, model_cnf, data_name, model_name, model_path, emb_init,
    tree_id, output_suffix, dry_run, num_tree,
):
    mlb_list = []
    logger.info('Loading Test Set')
    mlb = get_mlb(data_cnf['labels_binarizer'])
    labels_num = len(mlb.classes_)
    test_x, _ = get_data(data_cnf['test']['texts'], None)
    logger.info(F'Size of Test Set: {len(test_x):,}')

    logger.info('Predicting')
    if 'cluster' not in model_cnf:
        raise Exception("Random forest mode is not currently supported for AttentionXML")
    else:
        labels_binarizer_path = data_cnf['labels_binarizer']
        for i in range(num_tree):
            filename = f"{labels_binarizer_path}_RF_{i}"
            mlb_tree = get_mlb(filename)
            mlb_list.append(mlb_tree)

        scores_list = []
        labels_list = []

        for i, mlb in enumerate(mlb_list):
            logger.info(f"Predicting RF {i}")
            model = FastAttentionXML(
                        len(mlb.classes_), data_cnf, model_cnf, tree_id,
                        f"{output_suffix}-{i}")
            scores, labels = model.predict(test_x, model_cnf['predict'].get('rf_k', 100 // num_tree))
            scores_list.append(scores)
            labels_list.append(mlb.classes_[labels])
            logger.info(f"Finish Prediting RF {i}")

        scores = np.hstack(scores_list)
        labels = np.hstack(labels_list)

        i = np.arange(len(scores))[:, None]
        j = np.argsort(scores)[:, ::-1]

        scores = scores[i, j]
        labels = labels[i, j]

    logger.info('Finish Predicting')
    score_path, label_path = output_res(data_cnf['output']['res'],
                                        f'{model_name}-{data_name}{tree_id}',
                                        scores, labels, output_suffix)

    log_results(score_path, label_path, dry_run)
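The hstack/argsort block above merges the per-tree predictions: the score matrices from all random-forest members are concatenated column-wise, and each row is then re-sorted in descending score order with the label matrix permuted identically. A small self-contained numpy illustration of that indexing pattern:

import numpy as np

# Two "trees", each returning top-2 (score, label) pairs per sample.
scores_a = np.array([[0.9, 0.3], [0.2, 0.1]])
labels_a = np.array([["L1", "L2"], ["L3", "L4"]])
scores_b = np.array([[0.8, 0.5], [0.7, 0.6]])
labels_b = np.array([["L5", "L6"], ["L7", "L8"]])

scores = np.hstack([scores_a, scores_b])
labels = np.hstack([labels_a, labels_b])

# The column vector of row indices broadcasts against the per-row sort order,
# so each row is re-ranked by descending score and the labels follow along.
i = np.arange(len(scores))[:, None]
j = np.argsort(scores)[:, ::-1]
print(scores[i, j])  # [[0.9 0.8 0.5 0.3] [0.7 0.6 0.2 0.1]]
print(labels[i, j])  # [['L1' 'L5' 'L6' 'L2'] ['L7' 'L8' 'L3' 'L4']]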
Example #3
def default_eval(
    data_cnf, model_cnf, data_name, model_name, model_path, emb_init,
    tree_id, output_suffix, dry_run,
):
    logger.info('Loading Test Set')
    mlb = get_mlb(data_cnf['labels_binarizer'])
    labels_num = len(mlb.classes_)
    test_x, _ = get_data(data_cnf['test']['texts'], None)
    logger.info(F'Size of Test Set: {len(test_x):,}')

    logger.info('Predicting')
    model_cnf['model'].pop('load_model', None)
    if 'cluster' not in model_cnf:
        test_loader = DataLoader(
            MultiLabelDataset(test_x),
            model_cnf['predict']['batch_size'],
            num_workers=4)

        if 'loss' in model_cnf:
            gamma = model_cnf['loss'].get('gamma', 1.0)
            loss_name = model_cnf['loss']['name']
        else:
            gamma = None
            loss_name = 'bce'

        model = Model(
            network=AttentionRNN, labels_num=labels_num,
            model_path=model_path, emb_init=emb_init,
            load_model=True, loss_name=loss_name, gamma=gamma,
            **data_cnf['model'], **model_cnf['model'])

        scores, labels = model.predict(test_loader, k=model_cnf['predict'].get('k', 100))
        labels = mlb.classes_[labels]
    else:
        model = FastAttentionXML(labels_num, data_cnf, model_cnf,
                                 tree_id, output_suffix)

        scores, labels = model.predict(test_x, model_cnf['predict'].get('k', 100))
        labels = mlb.classes_[labels]

    logger.info('Finish Predicting')
    score_path, label_path = output_res(data_cnf['output']['res'],
                                        f'{model_name}-{data_name}{tree_id}',
                                        scores, labels, output_suffix)

    log_results(score_path, label_path, dry_run)
Example #4
def main(data_cnf, model_cnf, mode, reg):
    yaml = YAML(typ='safe')
    data_cnf, model_cnf = yaml.load(Path(data_cnf)), yaml.load(Path(model_cnf))
    model, model_name, data_name = None, model_cnf['name'], data_cnf['name']
    model_path = os.path.join(model_cnf['path'], F'{model_name}-{data_name}')
    emb_init = get_word_emb(data_cnf['embedding']['emb_init'])
    logger.info(F'Model Name: {model_name}')

    if mode is None or mode == 'train':
        logger.info('Loading Training and Validation Set')
        train_x, train_labels = get_data(data_cnf['train']['texts'],
                                         data_cnf['train']['labels'])
        if 'size' in data_cnf['valid']:
            random_state = data_cnf['valid'].get('random_state', 1240)
            train_x, valid_x, train_labels, valid_labels = train_test_split(
                train_x,
                train_labels,
                test_size=data_cnf['valid']['size'],
                random_state=random_state)
        else:
            valid_x, valid_labels = get_data(data_cnf['valid']['texts'],
                                             data_cnf['valid']['labels'])
        mlb = get_mlb(data_cnf['labels_binarizer'],
                      np.hstack((train_labels, valid_labels)))
        train_y, valid_y = mlb.transform(train_labels), mlb.transform(
            valid_labels)
        labels_num = len(mlb.classes_)
        logger.info(F'Number of Labels: {labels_num}')
        logger.info(F'Size of Training Set: {len(train_x)}')
        logger.info(F'Size of Validation Set: {len(valid_x)}')

        edges = set()
        if reg:
            classes = mlb.classes_.tolist()
            with open(data_cnf['hierarchy']) as fin:
                for line in fin:
                    data = line.strip().split()
                    p = data[0]
                    if p not in classes:
                        continue
                    p_id = classes.index(p)
                    for c in data[1:]:
                        if c not in classes:
                            continue
                        c_id = classes.index(c)
                        edges.add((p_id, c_id))
            logger.info(F'Number of Edges: {len(edges)}')

        logger.info('Training')
        train_loader = DataLoader(MultiLabelDataset(train_x, train_y),
                                  model_cnf['train']['batch_size'],
                                  shuffle=True,
                                  num_workers=4)
        valid_loader = DataLoader(MultiLabelDataset(valid_x,
                                                    valid_y,
                                                    training=True),
                                  model_cnf['valid']['batch_size'],
                                  num_workers=4)
        model = Model(network=MATCH,
                      labels_num=labels_num,
                      model_path=model_path,
                      emb_init=emb_init,
                      mode='train',
                      reg=reg,
                      hierarchy=edges,
                      **data_cnf['model'],
                      **model_cnf['model'])
        opt_params = {
            'lr': model_cnf['train']['learning_rate'],
            'betas':
            (model_cnf['train']['beta1'], model_cnf['train']['beta2']),
            'weight_decay': model_cnf['train']['weight_decay']
        }
        model.train(train_loader,
                    valid_loader,
                    opt_params=opt_params,
                    **model_cnf['train'])  # CHANGE: inserted opt_params
        logger.info('Finish Training')

    if mode is None or mode == 'eval':
        logger.info('Loading Test Set')
        mlb = get_mlb(data_cnf['labels_binarizer'])
        labels_num = len(mlb.classes_)
        test_x, _ = get_data(data_cnf['test']['texts'], None)
        logger.info(F'Size of Test Set: {len(test_x)}')

        logger.info('Predicting')
        test_loader = DataLoader(MultiLabelDataset(test_x),
                                 model_cnf['predict']['batch_size'],
                                 num_workers=4)
        if model is None:
            model = Model(network=MATCH,
                          labels_num=labels_num,
                          model_path=model_path,
                          emb_init=emb_init,
                          mode='eval',
                          **data_cnf['model'],
                          **model_cnf['model'])
        scores, labels = model.predict(test_loader,
                                       k=model_cnf['predict'].get('k', 100))
        logger.info('Finish Predicting')
        labels = mlb.classes_[labels]
        output_res(data_cnf['output']['res'], F'{model_name}-{data_name}',
                   scores, labels)
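One note on the hierarchy-loading loop in the training branch above: classes.index(p) rescans the full label list for every edge, which becomes slow for large label sets. The sketch below is an equivalent variant that precomputes a label-to-index map; it assumes the same whitespace-separated "parent child1 child2 ..." line format implied by the loop, and is offered only as an illustration.

# Sketch: same edge extraction as above, with an O(1) lookup table per label.
def load_hierarchy_edges(hierarchy_path, classes):
    class_to_id = {label: idx for idx, label in enumerate(classes)}
    edges = set()
    with open(hierarchy_path) as fin:
        for line in fin:
            parts = line.strip().split()
            if not parts or parts[0] not in class_to_id:
                continue
            p_id = class_to_id[parts[0]]
            for child in parts[1:]:
                if child in class_to_id:
                    edges.add((p_id, class_to_id[child]))
    return edges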
Example #5
File: main.py  Project: yourh/AttentionXML
def main(data_cnf, model_cnf, mode, tree_id):
    tree_id = F'-Tree-{tree_id}' if tree_id is not None else ''
    yaml = YAML(typ='safe')
    data_cnf, model_cnf = yaml.load(Path(data_cnf)), yaml.load(Path(model_cnf))
    model, model_name, data_name = None, model_cnf['name'], data_cnf['name']
    model_path = os.path.join(model_cnf['path'], F'{model_name}-{data_name}{tree_id}')
    emb_init = get_word_emb(data_cnf['embedding']['emb_init'])
    logger.info(F'Model Name: {model_name}')

    if mode is None or mode == 'train':
        logger.info('Loading Training and Validation Set')
        train_x, train_labels = get_data(data_cnf['train']['texts'], data_cnf['train']['labels'])
        if 'size' in data_cnf['valid']:
            random_state = data_cnf['valid'].get('random_state', 1240)
            train_x, valid_x, train_labels, valid_labels = train_test_split(train_x, train_labels,
                                                                            test_size=data_cnf['valid']['size'],
                                                                            random_state=random_state)
        else:
            valid_x, valid_labels = get_data(data_cnf['valid']['texts'], data_cnf['valid']['labels'])
        mlb = get_mlb(data_cnf['labels_binarizer'], np.hstack((train_labels, valid_labels)))
        train_y, valid_y = mlb.transform(train_labels), mlb.transform(valid_labels)
        labels_num = len(mlb.classes_)
        logger.info(F'Number of Labels: {labels_num}')
        logger.info(F'Size of Training Set: {len(train_x)}')
        logger.info(F'Size of Validation Set: {len(valid_x)}')

        logger.info('Training')
        if 'cluster' not in model_cnf:
            train_loader = DataLoader(MultiLabelDataset(train_x, train_y),
                                      model_cnf['train']['batch_size'], shuffle=True, num_workers=4)
            valid_loader = DataLoader(MultiLabelDataset(valid_x, valid_y, training=False),
                                      model_cnf['valid']['batch_size'], num_workers=4)
            model = Model(network=AttentionRNN, labels_num=labels_num, model_path=model_path, emb_init=emb_init,
                          **data_cnf['model'], **model_cnf['model'])
            model.train(train_loader, valid_loader, **model_cnf['train'])
        else:
            model = FastAttentionXML(labels_num, data_cnf, model_cnf, tree_id)
            model.train(train_x, train_y, valid_x, valid_y, mlb)
        logger.info('Finish Training')

    if mode is None or mode == 'eval':
        logger.info('Loading Test Set')
        mlb = get_mlb(data_cnf['labels_binarizer'])
        labels_num = len(mlb.classes_)
        test_x, _ = get_data(data_cnf['test']['texts'], None)
        logger.info(F'Size of Test Set: {len(test_x)}')

        logger.info('Predicting')
        if 'cluster' not in model_cnf:
            test_loader = DataLoader(MultiLabelDataset(test_x), model_cnf['predict']['batch_size'],
                                     num_workers=4)
            if model is None:
                model = Model(network=AttentionRNN, labels_num=labels_num, model_path=model_path, emb_init=emb_init,
                              **data_cnf['model'], **model_cnf['model'])
            scores, labels = model.predict(test_loader, k=model_cnf['predict'].get('k', 100))
        else:
            if model is None:
                model = FastAttentionXML(labels_num, data_cnf, model_cnf, tree_id)
            scores, labels = model.predict(test_x)
        logger.info('Finish Predicting')
        labels = mlb.classes_[labels]
        output_res(data_cnf['output']['res'], F'{model_name}-{data_name}{tree_id}', scores, labels)
Example #6
def load_dataset(data_cnf):
    logger.info('Loading Training and Validation Set')
    train_x, train_labels = get_data(data_cnf['train']['texts'],
                                     data_cnf['train']['labels'])
    return train_x, train_labels
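get_data itself is not part of this excerpt. Judging from how it is called throughout these examples (a texts path plus an optional labels path, returning a pair), a plausible reading is that it simply loads pickled numpy arrays; the sketch below captures that assumption and should not be taken as the repository's actual implementation.

import numpy as np

def get_data_sketch(text_path, label_path=None):
    """Assumed behaviour of get_data: tokenized texts plus optional label lists."""
    texts = np.load(text_path, allow_pickle=True)
    labels = np.load(label_path, allow_pickle=True) if label_path is not None else None
    return texts, labels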
Example #7
def spectral_clustering_train(
    data_cnf, data_cnf_path, model_cnf, model_cnf_path,
    emb_init, model_path, tree_id, output_suffix, dry_run,
):
    train_xs = []
    valid_xs = []
    train_labels_list = []
    valid_labels_list = []
    train_ys = []
    valid_ys = []
    mlb_list = []
    indices_list = []

    n_clusters = model_cnf['spectral_clustering']['num_clusters']
    n_components = model_cnf['spectral_clustering']['n_components']
    alg = model_cnf['spectral_clustering']['alg']
    size_min = model_cnf['spectral_clustering']['size_min']
    size_max = model_cnf['spectral_clustering']['size_max']

    train_x, train_labels = load_dataset(data_cnf)

    if 'cluster' not in model_cnf:
        mlb = get_mlb(data_cnf['labels_binarizer'], train_labels)
        train_y = mlb.transform(train_labels)

        logger.info('Build label adjacency matrix')
        adj = train_y.T @ train_y
        adj.setdiag(0)
        adj.eliminate_zeros()
        logger.info(f"Sparsity: {adj.count_nonzero() / adj.shape[0] ** 2}")
        clustering = MySpectralClustering(n_clusters=n_clusters, affinity='precomputed',
                                          n_components=n_components, n_init=1,
                                          size_min=size_min,
                                          size_max=size_max,
                                          assign_labels=alg, n_jobs=-1)
        logger.info('Start Spectral Clustering')
        clustering.fit(adj)
        logger.info('Finish Spectral Clustering')

        groups = [[] for _ in range(n_clusters)]
        for i, group in enumerate(clustering.labels_):
            groups[group].append(i)

        splitted_labels = []
        for indices in groups:
            splitted_labels.append(mlb.classes_[indices])

        for labels in splitted_labels:
            indices = get_splitted_samples(labels, train_labels)
            indices_list.append(indices)
            train_xs.append(train_x[indices])
            train_labels_list.append(train_labels[indices])

        if 'size' in data_cnf['valid']:
            for i, (train_x, train_labels) in enumerate(zip(train_xs, train_labels_list)):
                valid_size = data_cnf['valid']['size']
                if len(train_x) * 0.8 > len(train_x) - valid_size:
                    valid_size = 0.2
                train_x, valid_x, train_labels, valid_labels = train_test_split(
                    train_x, train_labels, test_size=valid_size,
                )
                train_xs[i] = train_x
                train_labels_list[i] = train_labels
                valid_xs.append(valid_x)
                valid_labels_list.append(valid_labels)

        else:
            raise Exception("Setting valid set explicitly is not "
                            "supported spectral clustering mode.")

        labels_binarizer_path = data_cnf['labels_binarizer']
        suffix = output_suffix.upper().replace('-', '_')
        for i, labels in enumerate(splitted_labels):
            filename = f"{labels_binarizer_path}_{suffix}_{i}"
            mlb_tree = get_mlb(filename, labels[None, ...], force=True)
            mlb_list.append(mlb_tree)
            logger.info(f"Number of labels of cluster {i}: {len(labels):,}")
            logger.info(f"Number of Training Set of cluster {i}: {len(train_xs[i]):,}")
            logger.info(f"Number of Validation Set of cluster {i}: {len(valid_xs[i]):,}")

            with redirect_stderr(None):
                train_y = mlb_tree.transform(train_labels_list[i])
                valid_y = mlb_tree.transform(valid_labels_list[i])

            train_ys.append(train_y)
            valid_ys.append(valid_y)

    else:
        if 'size' in data_cnf['valid']:
            train_x, valid_x, train_labels, valid_labels = train_test_split(
                train_x, train_labels, test_size=data_cnf['valid']['size'],
            )

        else:
            valid_x, valid_labels = get_data(data_cnf['valid']['texts'], data_cnf['valid']['labels'])

        mlb = get_mlb(data_cnf['labels_binarizer'], np.hstack((
            train_labels, valid_labels,
        )))

        train_y, valid_y = mlb.transform(train_labels), mlb.transform(valid_labels)


    logger.info('Training')
    if 'cluster' not in model_cnf:
        for i, (train_x, train_y, valid_x, valid_y) in enumerate(zip(
            train_xs, train_ys, valid_xs, valid_ys,
        )):
            train_loader = DataLoader(
                MultiLabelDataset(train_x, train_y),
                model_cnf['train']['batch_size'], shuffle=True, num_workers=4)
            valid_loader = DataLoader(
                MultiLabelDataset(valid_x, valid_y, training=False),
                model_cnf['valid']['batch_size'], num_workers=4)
            model = Model(
                network=AttentionRNN, labels_num=len(mlb_list[i].classes_),
                model_path=f'{model_path}-{i}', emb_init=emb_init,
                **data_cnf['model'], **model_cnf['model'])

            if not dry_run:
                logger.info(f"Start Training Cluster {i}")
                model.train(train_loader, valid_loader, **model_cnf['train'])
                logger.info(f"Finish Training Cluster {i}")
            else:
                model.save_model()

    else:
        model = FastAttentionXML(
            len(mlb.classes_), data_cnf, model_cnf, tree_id, output_suffix,
        )

        if not dry_run:
            model.train(train_x, train_y, valid_x, valid_y, mlb)

    log_config(data_cnf_path, model_cnf_path, dry_run)
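The 'Build label adjacency matrix' step above turns the binarized label matrix into a label co-occurrence graph: (train_y.T @ train_y)[a, b] counts the training samples tagged with both label a and label b, and zeroing the diagonal drops each label's self-count before spectral clustering. A small self-contained scipy illustration of the same computation:

import numpy as np
from scipy.sparse import csr_matrix

# 4 samples x 3 labels as a multi-hot matrix.
train_y = csr_matrix(np.array([
    [1, 1, 0],
    [1, 0, 1],
    [0, 1, 1],
    [1, 1, 0],
]))

adj = train_y.T @ train_y  # adj[a, b] = number of samples carrying both labels
adj.setdiag(0)             # remove self co-occurrence counts
adj.eliminate_zeros()

print(adj.toarray())
# [[0 2 1]
#  [2 0 1]
#  [1 1 0]]
print("density:", adj.count_nonzero() / adj.shape[0] ** 2)  # 6/9 ~ 0.667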
Example #8
def spectral_clustering_eval(
    data_cnf, model_cnf, data_name, model_name, model_path, emb_init,
    tree_id, output_suffix, dry_run,
):
    mlb_list = []
    n_clusters = model_cnf['spectral_clustering']['num_clusters']
    labels_binarizer_path = data_cnf['labels_binarizer']
    scores_list = []
    labels_list = []

    logger.info('Loading Test Set')
    test_x, _ = get_data(data_cnf['test']['texts'], None)
    logger.info(F'Size of Test Set: {len(test_x):,}')

    logger.info('Predicting')
    if 'cluster' not in model_cnf:
        suffix = output_suffix.upper().replace('-', '_')
        for i in range(n_clusters):
            filename = f"{labels_binarizer_path}_{suffix}_{i}"
            mlb_tree = get_mlb(filename)
            mlb_list.append(mlb_tree)

        test_loader = DataLoader(
            MultiLabelDataset(test_x),
            model_cnf['predict']['batch_size'],
            num_workers=4)

        for i, mlb in enumerate(mlb_list):
            logger.info(f"Predicting Cluster {i}")
            labels_num = len(mlb.classes_)
            k = model_cnf['predict'].get('k', 100) // n_clusters

            model = Model(
                network=AttentionRNN, labels_num=labels_num,
                model_path=f'{model_path}-{i}', emb_init=emb_init,
                load_model=True,
                **data_cnf['model'], **model_cnf['model'])

            scores, labels = model.predict(test_loader, k=k)
            scores_list.append(scores)
            labels_list.append(mlb.classes_[labels])
            logger.info(f"Finish Prediting Cluster {i}")

        scores = np.hstack(scores_list)
        labels = np.hstack(labels_list)

        i = np.arange(len(scores))[:, None]
        j = np.argsort(scores)[:, ::-1]

        scores = scores[i, j]
        labels = labels[i, j]

    else:
        mlb = get_mlb(data_cnf['labels_binarizer'])
        model = FastAttentionXML(len(mlb.classes_), data_cnf, model_cnf,
                                 tree_id, output_suffix)

        scores, labels = model.predict(test_x, model_cnf['predict'].get('k', 100))
        labels = mlb.classes_[labels]

    logger.info('Finish Predicting')
    score_path, label_path = output_res(data_cnf['output']['res'],
                                        f'{model_name}-{data_name}{tree_id}',
                                        scores, labels, output_suffix)

    log_results(score_path, label_path, dry_run)
Example #9
def build_tree_by_level(
    sparse_data_x,
    sparse_data_y,
    train_x: str,
    emb_init: str,
    mlb,
    indices: np.ndarray,
    eps: float,
    max_leaf: int,
    levels: list,
    label_emb: str,
    alg: str,
    groups_path: str,
    n_components: int = None,
    overlap_ratio: float = 0.0,
    head_split_ratio: float = 0.0,
    adj_th: int = None,
    random_state: int = None,
):
    os.makedirs(os.path.split(groups_path)[0], exist_ok=True)
    logger.info('Clustering')
    logger.info('Getting Labels Feature')

    if label_emb == 'tf-idf':
        sparse_x, sparse_labels = get_sparse_feature(sparse_data_x,
                                                     sparse_data_y)

        with redirect_stderr(None):
            sparse_y = mlb.transform(sparse_labels)

        if indices is not None:
            sparse_x = sparse_x[indices]
            sparse_y = sparse_y[indices]

        labels_f = normalize(csr_matrix(sparse_y.T) @ csc_matrix(sparse_x))

    elif label_emb == 'glove':
        emb_init = get_word_emb(emb_init)
        train_x, train_y = get_data(train_x, sparse_data_y)

        with redirect_stderr(None):
            train_y = mlb.transform(train_y)

        if indices is not None:
            train_x = train_x[indices]
            train_y = train_y[indices]

        labels_f = normalize(_get_labels_f(emb_init, train_x, train_y))

    elif label_emb == 'spectral':
        _, sparse_labels = get_sparse_feature(sparse_data_x, sparse_data_y)
        sparse_y = mlb.transform(sparse_labels)

        logger.info('Build label adjacency matrix')

        adj = sparse_y.T @ sparse_y
        adj.setdiag(0)
        adj.eliminate_zeros()

        if adj_th is not None:
            logger.info(f"adj th: {adj_th}")
            ind1 = np.where(adj.data < adj_th)
            ind2 = np.where(adj.data >= adj_th)
            adj.data[ind1] = 0
            adj.data[ind2] = 1
            adj.eliminate_zeros()

        logger.info(
            f"Sparsity: {1 - (adj.count_nonzero() / adj.shape[0] ** 2)}")

        logger.info('Getting spectral embedding')
        labels_f = spectral_embedding(adj,
                                      n_components=n_components,
                                      norm_laplacian=adj_th is None,
                                      eigen_solver='amg',
                                      drop_first=False)
        labels_f = normalize(labels_f)

    else:
        raise ValueError(f"label_emb: {label_emb} is invalid")

    head_labels = None

    if head_split_ratio > 0:
        logger.info(f"head ratio: {head_split_ratio}")
        train_labels = np.load(sparse_data_y, allow_pickle=True)
        train_y = mlb.transform(train_labels)
        counts = np.sum(train_y, axis=0).A1
        cnt_indices = np.argsort(counts)[::-1]
        head_labels = cnt_indices[:int(len(counts) * head_split_ratio)]
        logger.info(f"# of head labels: {len(head_labels)}")
        logger.info(f"# of tail labels: {len(counts) - len(head_labels)}")

    logger.info(F'Start Clustering {levels}')

    levels, q = [2**x for x in levels], None

    for i in range(len(levels) - 1, -1, -1):
        if os.path.exists(F'{groups_path}-Level-{i}.npy'):
            labels_list = np.load(F'{groups_path}-Level-{i}.npy',
                                  allow_pickle=True)
            q = [(labels_i, labels_f[labels_i]) for labels_i in labels_list]
            break
    if q is None:
        q = [(np.arange(labels_f.shape[0]), labels_f)]
    while q:
        labels_list = np.asarray([x[0] for x in q])
        assert len(reduce(lambda a, b: a | set(b), labels_list,
                          set())) == labels_f.shape[0]
        if len(labels_list) in levels:
            level = levels.index(len(labels_list))
            groups = np.asarray(labels_list)
            a = set(groups[0])
            b = set(groups[1])
            n_nodes = [len(set(group)) for group in groups]
            logger.info(F'Finish Clustering Level-{level}')
            logger.info(f'# of node: {len(a)}, # of overlapped: {len(a & b)}')
            logger.info(f'max # of node: {max(n_nodes)}')
            logger.info(f'average # of node: {np.mean(n_nodes)}')

            if head_labels is not None:
                logger.info(f"Getting Cluster Centers")
                if sp.issparse(labels_f):
                    centers = sp.vstack([
                        normalize(csr_matrix(labels_f[idx].mean(axis=0)))
                        for idx in groups
                    ])
                else:
                    centers = np.vstack([
                        normalize(labels_f[idx].mean(axis=0, keepdims=True))
                        for idx in groups
                    ])

                # Find tail groups
                # If all labels in a group are not in head labels,
                # this group is tail group
                tail_groups = []
                for i, group in enumerate(groups):
                    is_tail_group = True
                    for label in group:
                        if label in head_labels:
                            is_tail_group = False
                            break
                    if is_tail_group:
                        tail_groups.append(i)
                tail_groups = np.array(tail_groups)

                nearest_head_labels = np.argmax(
                    centers[tail_groups] @ labels_f[head_labels].T, axis=1)

                if hasattr(nearest_head_labels, 'A1'):
                    nearest_head_labels = nearest_head_labels.A1

                for i, tail_group in enumerate(tail_groups):
                    head_label = head_labels[nearest_head_labels[i]]
                    group = groups[tail_group]
                    groups[tail_group] = np.append(groups[tail_group],
                                                   head_label)

            np.save(F'{groups_path}-Level-{level}.npy', groups)

            if level == len(levels) - 1:
                break
        else:
            logger.info(F'Finish Clustering {len(labels_list)}')
        next_q = []
        for node_i, node_f in q:
            if len(node_i) > max_leaf:
                next_q += list(
                    split_node(node_i, node_f, eps, alg, overlap_ratio,
                               random_state))
        q = next_q
    logger.info('Finish Clustering')
Example #10
def transformer_train(
    data_cnf,
    data_cnf_path,
    model_cnf,
    model_cnf_path,
    model_path,
    dry_run,
):
    train_x, train_labels = load_dataset(data_cnf)
    train_input_ids = train_x["input_ids"]
    train_atten_mask = train_x["attention_mask"]

    if "size" in data_cnf["valid"]:
        (
            train_x,
            valid_x,
            train_atten_mask,
            valid_atten_mask,
            train_labels,
            valid_labels,
        ) = train_test_split(
            train_input_ids,
            train_atten_mask,
            train_labels,
            test_size=data_cnf["valid"]["size"],
        )
    else:
        valid_x, valid_labels = get_data(data_cnf["valid"]["texts"],
                                         data_cnf["valid"]["labels"])
        valid_atten_mask = None

    mlb = get_mlb(data_cnf["labels_binarizer"],
                  np.hstack((
                      train_labels,
                      valid_labels,
                  )))

    train_y, valid_y = mlb.transform(train_labels), mlb.transform(valid_labels)
    num_labels = len(mlb.classes_)
    logger.info(f"Number of Labels: {num_labels}")
    logger.info(f"Size of Training Set: {len(train_x):,}")
    logger.info(f"Size of Validation Set: {len(valid_x):,}")

    logger.info("Training")

    train_loader = DataLoader(
        MultiLabelDataset(train_x, train_y, train_atten_mask),
        model_cnf["train"]["batch_size"],
        shuffle=True,
        num_workers=4,
    )
    valid_loader = DataLoader(
        MultiLabelDataset(valid_x, valid_y, valid_atten_mask, training=False),
        model_cnf["valid"]["batch_size"],
        num_workers=4,
    )

    model_cls = MODEL_TYPE[model_cnf["model"]["base"]]

    network = model_cls.from_pretrained(model_cnf["model"]["pretrained"],
                                        num_labels=num_labels)

    if model_cnf['model'].get('freeze_encoder', False):
        for param in network.base_model.parameters():
            param.requires_grad = False

    model = TransformerXML(network, model_path, **data_cnf["model"],
                           **model_cnf["model"])

    if not dry_run:
        model.train(train_loader, valid_loader, mlb=mlb, **model_cnf["train"])

    log_config(data_cnf_path, model_cnf_path, dry_run)
Example #11
def default_train(
    data_cnf, data_cnf_path, model_cnf, model_cnf_path,
    emb_init, model_path, tree_id, output_suffix, dry_run,
):
    train_x, train_labels = load_dataset(data_cnf)

    if 'size' in data_cnf['valid']:
        train_x, valid_x, train_labels, valid_labels = train_test_split(
            train_x, train_labels, test_size=data_cnf['valid']['size'],
        )

    else:
        valid_x, valid_labels = get_data(data_cnf['valid']['texts'], data_cnf['valid']['labels'])

    mlb = get_mlb(data_cnf['labels_binarizer'], np.hstack((
        train_labels, valid_labels,
    )))
    freq = mlb.transform(np.hstack([train_labels, valid_labels])).sum(axis=0).A1
    train_y, valid_y = mlb.transform(train_labels), mlb.transform(valid_labels)
    labels_num = len(mlb.classes_)
    logger.info(F'Number of Labels: {labels_num}')
    logger.info(F'Size of Training Set: {len(train_x):,}')
    logger.info(F'Size of Validation Set: {len(valid_x):,}')

    logger.info('Training')
    if 'cluster' not in model_cnf:
        if 'propensity' in data_cnf:
            a = data_cnf['propensity']['a']
            b = data_cnf['propensity']['b']
            pos_weight = get_inv_propensity(train_y, a, b)
        else:
            pos_weight = None

        train_loader = DataLoader(
            MultiLabelDataset(train_x, train_y),
            model_cnf['train']['batch_size'], shuffle=True, num_workers=4)
        valid_loader = DataLoader(
            MultiLabelDataset(valid_x, valid_y, training=False),
            model_cnf['valid']['batch_size'], num_workers=4)

        if 'loss' in model_cnf:
            gamma = model_cnf['loss'].get('gamma', 2.0)
            loss_name = model_cnf['loss']['name']
        else:
            gamma = None
            loss_name = 'bce'

        model = Model(
            network=AttentionRNN, labels_num=labels_num, model_path=model_path,
            emb_init=emb_init, pos_weight=pos_weight, loss_name=loss_name, gamma=gamma,
            freq=freq, **data_cnf['model'], **model_cnf['model'])

        if not dry_run:
            model.train(train_loader, valid_loader, mlb=mlb, **model_cnf['train'])
        else:
            model.save_model()

    else:
        model = FastAttentionXML(labels_num, data_cnf, model_cnf, tree_id, output_suffix)

        if not dry_run:
            model.train(train_x, train_y, valid_x, valid_y, mlb)

    log_config(data_cnf_path, model_cnf_path, dry_run)
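When a 'propensity' block is present, default_train weights positive labels by their inverse propensity via get_inv_propensity(train_y, a, b), which is not shown in this excerpt. A plausible sketch, following the standard propensity model of Jain et al. (2016) for extreme classification where rare labels receive larger weights, is given below; treat the exact formula as an assumption about this helper rather than its actual implementation.

import numpy as np

def inv_propensity_sketch(train_y, a=0.55, b=1.5):
    """Hypothetical inverse-propensity weights in the style of Jain et al. (2016).

    train_y: sparse (num_samples x num_labels) multi-hot label matrix.
    Returns one weight per label; rarer labels get larger weights.
    """
    n = train_y.shape[0]
    counts = np.asarray(train_y.sum(axis=0)).ravel()  # per-label frequency
    c = (np.log(n) - 1) * (b + 1) ** a
    propensity = 1.0 / (1.0 + c * np.exp(-a * np.log(counts + b)))
    return 1.0 / propensity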
Example #12
def main(data_cnf, model_cnf, mode):
    model_name = os.path.split(model_cnf)[1].split(".")[0]
    # Set the log file location
    logfile("./logs/logfile_" + model_name + ".log")
    yaml = YAML(typ='safe')
    data_cnf, model_cnf = yaml.load(Path(data_cnf)), yaml.load(Path(model_cnf))
    model, model_name, data_name = None, model_cnf['name'], data_cnf['name']
    # model_path = model_cnf['path'] + "/" + model_cnf['name'] + '.h'
    model_path = r'E:\\PycharmProject\\CorNet\\' + model_name + '.h5'
    emb_init = get_word_emb(data_cnf['embedding']['emb_init'])
    logger.info(F'Model Name: {model_name}')

    # keras log file
    csv_logger = CSVLogger('./logs/' + model_name + '_log.csv', append=True, separator=',')

    if mode is None or mode == 'train':
        logger.info('Loading Training and Validation Set')
        train_x, train_labels = get_data(data_cnf['train']['texts'], data_cnf['train']['labels'])
        if 'size' in data_cnf['valid']:
            random_state = data_cnf['valid'].get('random_state', 1240)
            train_x, valid_x, train_labels, valid_labels = train_test_split(train_x, train_labels,
                                                                            test_size=data_cnf['valid']['size'],
                                                                            random_state=random_state)
        else:
            valid_x, valid_labels = get_data(data_cnf['valid']['texts'], data_cnf['valid']['labels'])
        mlb = get_mlb(data_cnf['labels_binarizer'], np.hstack((train_labels, valid_labels)))
        train_y, valid_y = mlb.transform(train_labels), mlb.transform(valid_labels)
        labels_num = len(mlb.classes_)
        logger.info(F'Number of Labels: {labels_num}')
        logger.info(F'Size of Training Set: {len(train_x)}')
        logger.info(F'Size of Validation Set: {len(valid_x)}')

        vocab_size = emb_init.shape[0]
        emb_size = emb_init.shape[1]

        # Tunable parameters
        data_num = len(train_x)
        ks = 3
        output_channel = model_cnf['model']['num_filters']
        dynamic_pool_length = model_cnf['model']['dynamic_pool_length']
        num_bottleneck_hidden = model_cnf['model']['bottleneck_dim']
        drop_out = model_cnf['model']['dropout']
        cornet_dim = model_cnf['model']['cornet_dim']
        nb_cornet_block = model_cnf['model'].get('nb_cornet_block', 0)
        nb_epochs = model_cnf['train']['nb_epoch']
        batch_size = model_cnf['train']['batch_size']

        max_length = 500

        input_tensor = Input(batch_shape=(batch_size, max_length), name='input')
        emb_data = Embedding(input_dim=vocab_size,
                             output_dim=emb_size,
                             input_length=max_length,
                             weights=[emb_init],
                             trainable=False,
                             name='embedding1')(input_tensor)
        emb_data.trainable = False
        # emd_out_4d = keras.layers.core.RepeatVector(1)(emb_data)
        # unsqueeze_emb_data = tf.keras.layers.Reshape((1, 500, 300), input_shape=(500, 300))(emb_data)
        # emb_data = tf.expand_dims(emb_data, axis=1)
        # emb_data = Lambda(reshape_tensor, arguments={'shape': (1, max_length, 300)}, name='lambda1')(
        #     emb_data)

        conv1_output = Convolution1D(output_channel, 2, padding='same',
                                     kernel_initializer=keras.initializers.glorot_uniform(seed=None),
                                     activation='relu', name='conv1')(emb_data)
        # conv1_output = Lambda(reshape_tensor, arguments={'shape': (batch_size, max_length, output_channel)},
        #                       name='conv1_lambda')(
        #     conv1_output)

        conv2_output = Convolution1D(output_channel, 4, padding='same',
                                     kernel_initializer=keras.initializers.glorot_uniform(seed=None),
                                     activation='relu', name='conv2')(emb_data)
        # conv2_output = Lambda(reshape_tensor, arguments={'shape': (batch_size, max_length, output_channel)},
        #                       name='conv2_lambda')(
        #     conv2_output)

        conv3_output = Convolution1D(output_channel, 8, padding='same',
                                     kernel_initializer=keras.initializers.glorot_uniform(seed=None),
                                     activation='relu', name='conv3')(emb_data)
        # conv3_output = Lambda(reshape_tensor, arguments={'shape': (batch_size, max_length, output_channel)},
        #                       name='conv3_lambda')(
        #     conv3_output)
        # pool1 = adapmaxpooling(conv1_output, dynamic_pool_length)
        pool1 = GlobalMaxPooling1D(name='globalmaxpooling1')(conv1_output)
        pool2 = GlobalMaxPooling1D(name='globalmaxpooling2')(conv2_output)
        pool3 = GlobalMaxPooling1D(name='globalmaxpooling3')(conv3_output)
        output = concatenate([pool1, pool2, pool3], axis=-1)
        # output = Dense(num_bottleneck_hidden, activation='relu',name='bottleneck')(output)
        output = Dropout(drop_out, name='dropout1')(output)
        output = Dense(labels_num, activation='softmax', name='dense_final',
                       kernel_initializer=keras.initializers.glorot_uniform(seed=None))(output)

        if nb_cornet_block > 0:
            for i in range(nb_cornet_block):
                x_shortcut = output
                x = keras.layers.Activation('sigmoid', name='cornet_sigmoid_{0}'.format(i + 1))(output)
                x = Dense(cornet_dim, kernel_initializer='glorot_uniform', name='cornet_1st_dense_{0}'.format(i + 1))(x)

                # x = Dense(cornet_dim, kernel_initializer=keras.initializers.glorot_uniform(seed=None),
                #           activation='sigmoid', name='cornet_1st_dense_{0}'.format(i + 1))(output)

                x = keras.layers.Activation('elu', name='cornet_elu_{0}'.format(i + 1))(x)
                x = Dense(labels_num, kernel_initializer='glorot_uniform', name='cornet_2nd_dense_{0}'.format(i + 1))(x)

                # x = Dense(labels_num, kernel_initializer=keras.initializers.glorot_uniform(seed=None), activation='elu',
                #           name='cornet_2nd_dense_{0}'.format(i + 1))(x)

                output = Add()([x, x_shortcut])

        model = Model(input_tensor, output)
        model.summary()
        model.compile(optimizer='adam', loss='binary_crossentropy', metrics=[tf.keras.metrics.Precision(top_k=5)])
        # model.compile(optimizer='adam', loss='binary_crossentropy', metrics=[tf.keras.metrics.top_k_categorical_accuracy(k=5)])
        model.fit_generator(generator=batch_generator(train_x, train_y, batch_size),
                            steps_per_epoch=data_num // batch_size,
                            validation_data=batch_generator(valid_x, valid_y, batch_size),
                            validation_steps=valid_x.shape[0] // batch_size,
                            epochs=nb_epochs, callbacks=[csv_logger])
        model.save(model_path)
    elif mode is None or mode == 'eval':
        logger.info('Loading Training and Validation Set')
        train_x, train_labels = get_data(data_cnf['train']['texts'], data_cnf['train']['labels'])
        if 'size' in data_cnf['valid']:  # if a validation size is configured, carve the validation set out of the training set
            random_state = data_cnf['valid'].get('random_state', 1240)
            train_x, valid_x, train_labels, valid_labels = train_test_split(train_x, train_labels,
                                                                            test_size=data_cnf['valid']['size'],
                                                                            random_state=random_state)
        else:
            valid_x, valid_labels = get_data(data_cnf['valid']['texts'], data_cnf['valid']['labels'])
        mlb = get_mlb(data_cnf['labels_binarizer'], np.hstack((train_labels, valid_labels)))
        train_y, valid_y = mlb.transform(train_labels), mlb.transform(valid_labels)
        labels_num = len(mlb.classes_)
        ##################################################################################################
        logger.info('Loading Test Set')
        logger.info(f'model path: {model_path}')
        mlb = get_mlb(data_cnf['labels_binarizer'])
        labels_num = len(mlb.classes_)
        test_x, test_label = get_data(data_cnf['test']['texts'], data_cnf['test']['labels'])
        logger.info(F'Size of Test Set: {len(test_x)}')
        test_y = mlb.transform(test_label).toarray()
        model = tf.keras.models.load_model(model_path)
        score = model.predict(test_x)
        print("p5: ", p5(test_y, score))
Example #13
File: main.py  Project: SYLin117/CorNet
def main(data_cnf, model_cnf, mode):
    model_name = os.path.split(model_cnf)[1].split(".")[0]
    yaml = YAML(typ='safe')
    data_cnf, model_cnf = yaml.load(Path(data_cnf)), yaml.load(Path(model_cnf))

    # Set the log file location
    logfile("./logs/logfile_{0}_cornet_{1}_cornet_dim_{2}.log".format(
        model_name, model_cnf['model']['n_cornet_blocks'],
        model_cnf['model']['cornet_dim']))

    model, model_name, data_name = None, model_cnf['name'], data_cnf['name']
    model_path = os.path.join(
        model_cnf['path'],
        F'{model_name}-{data_name}-{model_cnf["model"]["n_cornet_blocks"]}-{model_cnf["model"]["cornet_dim"]}'
    )
    emb_init = get_word_emb(data_cnf['embedding']['emb_init'])
    logger.info(F'Model Name: {model_name}')
    # summary(model_dict[model_name])
    if mode is None or mode == 'train':
        logger.info('Loading Training and Validation Set')
        train_x, train_labels = get_data(data_cnf['train']['texts'],
                                         data_cnf['train']['labels'])
        if 'size' in data_cnf['valid']:
            random_state = data_cnf['valid'].get('random_state', 1240)
            train_x, valid_x, train_labels, valid_labels = train_test_split(
                train_x,
                train_labels,
                test_size=data_cnf['valid']['size'],
                random_state=random_state)
        else:
            valid_x, valid_labels = get_data(data_cnf['valid']['texts'],
                                             data_cnf['valid']['labels'])
        mlb = get_mlb(data_cnf['labels_binarizer'],
                      np.hstack((train_labels, valid_labels)))
        train_y, valid_y = mlb.transform(train_labels), mlb.transform(
            valid_labels)
        labels_num = len(mlb.classes_)
        logger.info(F'Number of Labels: {labels_num}')
        logger.info(F'Size of Training Set: {len(train_x)}')
        logger.info(F'Size of Validation Set: {len(valid_x)}')

        logger.info('Training')
        train_loader = DataLoader(MultiLabelDataset(train_x, train_y),
                                  model_cnf['train']['batch_size'],
                                  shuffle=True,
                                  num_workers=4)
        valid_loader = DataLoader(MultiLabelDataset(valid_x,
                                                    valid_y,
                                                    training=True),
                                  model_cnf['valid']['batch_size'],
                                  num_workers=4)

        if 'gpipe' not in model_cnf:
            model = Model(network=model_dict[model_name],
                          labels_num=labels_num,
                          model_path=model_path,
                          emb_init=emb_init,
                          **data_cnf['model'],
                          **model_cnf['model'])
        else:
            model = GPipeModel(model_name,
                               labels_num=labels_num,
                               model_path=model_path,
                               emb_init=emb_init,
                               **data_cnf['model'],
                               **model_cnf['model'])
        loss, p1, p5 = model.train(train_loader, valid_loader,
                                   **model_cnf['train'])
        np.save(
            model_cnf['np_loss'] + "{0}_cornet_{1}_cornet_dim_{2}.npy".format(
                model_name, model_cnf['model']['n_cornet_blocks'],
                model_cnf['model']['cornet_dim']), loss)
        np.save(
            model_cnf['np_p1'] + "{0}_cornet_{1}_cornet_dim_{2}.npy".format(
                model_name, model_cnf['model']['n_cornet_blocks'],
                model_cnf['model']['cornet_dim']), p1)
        np.save(
            model_cnf['np_p5'] + "{0}_cornet_{1}_cornet_dim_{2}.npy".format(
                model_name, model_cnf['model']['n_cornet_blocks'],
                model_cnf['model']['cornet_dim']), p5)
        logger.info('Finish Training')

    if mode is None or mode == 'eval':
        logger.info('Loading Test Set')
        logger.info(f'model path: {model_path}')
        mlb = get_mlb(data_cnf['labels_binarizer'])
        labels_num = len(mlb.classes_)
        test_x, _ = get_data(data_cnf['test']['texts'], None)
        logger.info(F'Size of Test Set: {len(test_x)}')

        logger.info('Predicting')
        test_loader = DataLoader(MultiLabelDataset(test_x),
                                 model_cnf['predict']['batch_size'],
                                 num_workers=4)
        if 'gpipe' not in model_cnf:
            if model is None:
                model = Model(network=model_dict[model_name],
                              labels_num=labels_num,
                              model_path=model_path,
                              emb_init=emb_init,
                              **data_cnf['model'],
                              **model_cnf['model'])
        else:
            if model is None:
                model = GPipeModel(model_name,
                                   labels_num=labels_num,
                                   model_path=model_path,
                                   emb_init=emb_init,
                                   **data_cnf['model'],
                                   **model_cnf['model'])
        scores, labels = model.predict(test_loader,
                                       k=model_cnf['predict'].get('k', 3801))
        logger.info('Finish Predicting')
        labels = mlb.classes_[labels]
        output_res(data_cnf['output']['res'], F'{model_name}-{data_name}',
                   scores, labels)
Example #14
def splitting_head_tail_train(
    data_cnf,
    data_cnf_path,
    model_cnf,
    model_cnf_path,
    emb_init,
    model_path,
    tree_id,
    output_suffix,
    dry_run,
    split_ratio,
):
    train_x, train_labels = load_dataset(data_cnf)

    logger.info(f'Split head and tail labels: {split_ratio}')
    head_labels, head_labels_i, tail_labels, tail_labels_i = get_head_tail_labels(
        train_labels,
        split_ratio,
    )

    train_h_x = train_x[head_labels_i]
    train_h_labels = train_labels[head_labels_i]

    train_t_x = train_x[tail_labels_i]
    train_t_labels = train_labels[tail_labels_i]

    if 'size' in data_cnf['valid']:
        valid_size = data_cnf['valid']['size']
        train_h_x, valid_h_x, train_h_labels, valid_h_labels = train_test_split(
            train_h_x,
            train_h_labels,
            test_size=valid_size if len(train_h_x) > 2 * valid_size else 0.1,
        )

        train_t_x, valid_t_x, train_t_labels, valid_t_labels = train_test_split(
            train_t_x,
            train_t_labels,
            test_size=valid_size if len(train_t_x) > 2 * valid_size else 0.1,
        )

    else:
        valid_x, valid_labels = get_data(data_cnf['valid']['texts'],
                                         data_cnf['valid']['labels'])
        valid_h_labels_i, valid_t_labels_i = get_head_tail_samples(
            head_labels,
            tail_labels,
            valid_labels,
        )
        valid_h_x = valid_x[valid_h_labels_i]
        valid_t_x = valid_x[valid_t_labels_i]
        valid_h_labels = valid_labels[valid_h_labels_i]
        valid_t_labels = valid_labels[valid_t_labels_i]

    labels_binarizer_path = data_cnf['labels_binarizer']
    mlb_h = get_mlb(f"{labels_binarizer_path}_h_{split_ratio}",
                    head_labels[None, ...])
    mlb_t = get_mlb(f"{labels_binarizer_path}_t_{split_ratio}",
                    tail_labels[None, ...])

    with redirect_stderr(None):
        train_h_y = mlb_h.transform(train_h_labels)
        valid_h_y = mlb_h.transform(valid_h_labels)
        train_t_y = mlb_t.transform(train_t_labels)
        valid_t_y = mlb_t.transform(valid_t_labels)

    logger.info(f'Number of Head Labels: {len(head_labels):,}')
    logger.info(f'Number of Tail Labels: {len(tail_labels):,}')
    logger.info(f'Size of Head Training Set: {len(train_h_x):,}')
    logger.info(f'Size of Head Validation Set: {len(valid_h_x):,}')
    logger.info(f'Size of Tail Training Set: {len(train_t_x):,}')
    logger.info(f'Size of Tail Validation Set: {len(valid_t_x):,}')

    logger.info('Training')
    if 'cluster' not in model_cnf:
        train_h_loader = DataLoader(MultiLabelDataset(train_h_x, train_h_y),
                                    model_cnf['train']['batch_size'],
                                    shuffle=True,
                                    num_workers=4)
        valid_h_loader = DataLoader(MultiLabelDataset(valid_h_x,
                                                      valid_h_y,
                                                      training=False),
                                    model_cnf['valid']['batch_size'],
                                    num_workers=4)
        head_model = Model(network=AttentionRNN,
                           labels_num=len(head_labels),
                           model_path=f'{model_path}-head',
                           emb_init=emb_init,
                           **data_cnf['model'],
                           **model_cnf['model'])

        if not dry_run:
            logger.info('Training Head Model')
            head_model.train(train_h_loader, valid_h_loader,
                             **model_cnf['train'])
            logger.info('Finish Training Head Model')
        else:
            head_model.save_model()

        train_t_loader = DataLoader(MultiLabelDataset(train_t_x, train_t_y),
                                    model_cnf['train']['batch_size'],
                                    shuffle=True,
                                    num_workers=4)
        valid_t_loader = DataLoader(MultiLabelDataset(valid_t_x,
                                                      valid_t_y,
                                                      training=False),
                                    model_cnf['valid']['batch_size'],
                                    num_workers=4)
        tail_model = Model(network=AttentionRNN,
                           labels_num=len(tail_labels),
                           model_path=f'{model_path}-tail',
                           emb_init=emb_init,
                           **data_cnf['model'],
                           **model_cnf['model'])

        if not dry_run:
            logger.info('Training Tail Model')
            tail_model.train(train_t_loader, valid_t_loader,
                             **model_cnf['train'])
            logger.info('Finish Training Tail Model')
        else:
            tail_model.save_model()

    else:
        raise Exception("FastAttention is not currently supported for "
                        "splited head and tail dataset")

    log_config(data_cnf_path, model_cnf_path, dry_run)

    return head_model, tail_model, head_labels, tail_labels
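get_head_tail_labels and get_head_tail_samples are not included in this excerpt. From the way they are used, a plausible reconstruction is: rank labels by training frequency, take the top split_ratio fraction as head labels, and mark a sample as a head (or tail) sample when it carries at least one label from the corresponding set. The sketch below is that assumption made concrete, not the repository's implementation.

import numpy as np
from collections import Counter

def head_tail_labels_sketch(train_labels, split_ratio):
    """Hypothetical split of labels into head (frequent) and tail (rare) sets.

    train_labels: array of per-sample label lists.
    Returns (head_labels, head_sample_idx, tail_labels, tail_sample_idx).
    """
    counts = Counter(label for labels in train_labels for label in labels)
    ranked = np.array([label for label, _ in counts.most_common()])
    n_head = int(len(ranked) * split_ratio)
    head_labels, tail_labels = ranked[:n_head], ranked[n_head:]

    head_set, tail_set = set(head_labels), set(tail_labels)
    head_idx = np.array([i for i, labels in enumerate(train_labels)
                         if any(l in head_set for l in labels)])
    tail_idx = np.array([i for i, labels in enumerate(train_labels)
                         if any(l in tail_set for l in labels)])
    return head_labels, head_idx, tail_labels, tail_idx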
Example #15
def splitting_head_tail_eval(
    data_cnf,
    model_cnf,
    data_name,
    model_name,
    model_path,
    emb_init,
    tree_id,
    output_suffix,
    dry_run,
    split_ratio,
    head_labels,
    tail_labels,
    head_model,
    tail_model,
):
    logger.info('Loading Test Set')
    mlb = get_mlb(data_cnf['labels_binarizer'])
    labels_num = len(mlb.classes_)
    test_x, _ = get_data(data_cnf['test']['texts'], None)
    logger.info(F'Size of Test Set: {len(test_x):,}')

    labels_binarizer_path = data_cnf['labels_binarizer']
    mlb_h = get_mlb(f"{labels_binarizer_path}_h_{split_ratio}")
    mlb_t = get_mlb(f"{labels_binarizer_path}_t_{split_ratio}")

    if head_labels is None:
        train_x, train_labels = get_data(data_cnf['train']['texts'],
                                         data_cnf['train']['labels'])
        head_labels, _, tail_labels, _ = get_head_tail_labels(
            train_labels,
            split_ratio,
        )

    h_labels_i = np.nonzero(mlb.transform(head_labels[None, ...]).toarray())[0]
    t_labels_i = np.nonzero(mlb.transform(tail_labels[None, ...]).toarray())[0]

    logger.info('Predicting')
    if 'cluster' not in model_cnf:
        test_loader = DataLoader(MultiLabelDataset(test_x),
                                 model_cnf['predict']['batch_size'],
                                 num_workers=4)

        if head_model is None:
            head_model = Model(network=AttentionRNN,
                               labels_num=len(head_labels),
                               model_path=f'{model_path}-head',
                               emb_init=emb_init,
                               load_model=True,
                               **data_cnf['model'],
                               **model_cnf['model'])

        logger.info('Predicting Head Model')
        h_k = model_cnf['predict'].get('top_head_k', 30)
        scores_h, labels_h = head_model.predict(test_loader, k=h_k)
        labels_h = mlb_h.classes_[labels_h]
        logger.info('Finish Predicting Head Model')

        if tail_model is None:
            tail_model = Model(network=AttentionRNN,
                               labels_num=len(tail_labels),
                               model_path=f'{model_path}-tail',
                               emb_init=emb_init,
                               load_model=True,
                               **data_cnf['model'],
                               **model_cnf['model'])

        logger.info('Predicting Tail Model')
        t_k = model_cnf['predict'].get('top_tail_k', 70)
        scores_t, labels_t = tail_model.predict(test_loader, k=t_k)
        labels_t = mlb_t.classes_[labels_t]
        logger.info('Finish Predicting Tail Model')

        scores = np.c_[scores_h, scores_t]
        labels = np.c_[labels_h, labels_t]

        i = np.arange(len(scores))[:, None]
        j = np.argsort(scores)[:, ::-1]

        scores = scores[i, j]
        labels = labels[i, j]
    else:
        raise Exception("FastAttention is not currently supported for "
                        "splited head and tail dataset")

    logger.info('Finish Predicting')
    score_path, label_path = output_res(data_cnf['output']['res'],
                                        f'{model_name}-{data_name}{tree_id}',
                                        scores, labels, output_suffix)

    log_results(score_path, label_path, dry_run)