コード例 #1
0
    def load_data(args):
        """Read the pickled dataset at ``args.data_bin_path`` and bundle it.

        The pickle is expected to provide the keys ``trn_features``,
        ``val_features``, ``tst_features``, ``C`` and ``val`` / ``tst``
        entries that each carry a ``cseq`` item.  The three split sizes,
        ``C.shape[0]`` (logged as NUM_LABEL) and ``C.shape[1]`` (logged as
        NUM_CLUSTER) are reported through ``logger``.

        NOTE: ``pickle.load`` can execute arbitrary code while loading, so
        this should only be pointed at trusted files.

        Args:
            args: object exposing a ``data_bin_path`` attribute.

        Returns:
            dict holding the three feature sets, the label / cluster counts
            (the cluster count is stored under both ``num_clusters`` and
            ``NUM_CLUSTER``), a fixed ``vocab_size`` and the ``C_val`` /
            ``C_tst`` results of ``data_utils.Ylist_to_Ysparse``.
        """
        with open(args.data_bin_path, 'rb') as handle:
            payload = pickle.load(handle)

        bundle = {
            'trn_features': payload['trn_features'],
            'val_features': payload['val_features'],
            'tst_features': payload['tst_features'],
        }

        c_matrix = payload['C']
        n_labels = c_matrix.shape[0]
        n_clusters = c_matrix.shape[1]

        split_sizes = [
            len(bundle[key])
            for key in ('trn_features', 'val_features', 'tst_features')
        ]
        logger.info('TRN {} VAL {} TST {}'.format(*split_sizes))
        logger.info('NUM_LABEL {}'.format(n_labels))
        logger.info('NUM_CLUSTER {}'.format(n_clusters))

        # Convert the per-split ``cseq`` lists (validation first, then test)
        # into the "Y csr matrix" form produced by Ylist_to_Ysparse.
        sparse_targets = {
            split: data_utils.Ylist_to_Ysparse(payload[split]['cseq'],
                                               L=n_clusters)
            for split in ('val', 'tst')
        }

        bundle.update({
            'num_clusters': n_clusters,
            # Hard-coded; presumably the tokenizer vocabulary size -- TODO confirm.
            'vocab_size': 30522,
            'NUM_LABEL': n_labels,
            'NUM_CLUSTER': n_clusters,
            'C_val': sparse_targets['val'],
            'C_tst': sparse_targets['tst'],
        })
        return bundle
コード例 #2
0
ファイル: attention.py プロジェクト: wuyi0614/X-BERT
    def load_data(args):
        """Load the pickled dataset and build one data iterator per split.

        Reads ``args.data_bin_path`` with ``pickle`` and pulls out the
        ``trn`` / ``val`` / ``tst`` entries, the size of ``stoi`` (returned as
        ``vocab_size``) and the two dimensions of ``C``.  Split sizes are
        taken from each split's ``xseq`` item and logged together with
        ``C.shape[0]`` (NUM_LABEL) and ``C.shape[1]`` (NUM_CLUSTER).

        NOTE: ``pickle.load`` can execute arbitrary code while loading, so
        this should only be pointed at trusted files.

        Args:
            args: object exposing ``data_bin_path``, ``train_batch_size`` and
                ``eval_batch_size``.

        Returns:
            dict with the raw splits, a ``DataLoader`` per split, the counts
            (cluster count under both ``num_clusters`` and ``NUM_CLUSTER``)
            and the ``C_val`` / ``C_tst`` results of
            ``data_utils.Ylist_to_Ysparse``.
        """
        global device

        def build_iter(split, batch_size):
            # Every loader wraps the same unpickled payload; only the split
            # name and batch size differ.
            return DataLoader(payload, set_option=split,
                              batch_size=batch_size, device=device)

        with open(args.data_bin_path, 'rb') as handle:
            payload = pickle.load(handle)

        train_split = payload['trn']
        valid_split = payload['val']
        test_split = payload['tst']
        n_tokens = len(payload['stoi'])
        c_matrix = payload['C']
        n_labels = c_matrix.shape[0]
        n_clusters = c_matrix.shape[1]

        split_sizes = [
            len(part['xseq'])
            for part in (train_split, valid_split, test_split)
        ]
        logger.info('TRN {} VAL {} TST {}'.format(*split_sizes))
        logger.info('NUM_LABEL {}'.format(n_labels))
        logger.info('NUM_CLUSTER {}'.format(n_clusters))

        # Convert the ``cseq`` lists (validation first, then test) into the
        # "Y csr matrix" form produced by Ylist_to_Ysparse.
        c_val, c_tst = (
            data_utils.Ylist_to_Ysparse(payload[split]['cseq'], L=n_clusters)
            for split in ('val', 'tst')
        )

        # Training uses its own batch size; both evaluation splits share one.
        train_iter = build_iter('trn', args.train_batch_size)
        valid_iter = build_iter('val', args.eval_batch_size)
        test_iter = build_iter('tst', args.eval_batch_size)

        return {
            'trn_features': train_split,
            'val_features': valid_split,
            'tst_features': test_split,
            'trn_iter': train_iter,
            'val_iter': valid_iter,
            'tst_iter': test_iter,
            'num_clusters': n_clusters,
            'vocab_size': n_tokens,
            'NUM_LABEL': n_labels,
            'NUM_CLUSTER': n_clusters,
            'C_val': c_val,
            'C_tst': c_tst,
        }
コード例 #3
0
ファイル: bert.py プロジェクト: jicksonp/Transformer-XMC
    def load_data(args):
        """Unpickle ``args.data_bin_path`` and return its pieces as one dict.

        The file must contain ``trn_features``, ``val_features``,
        ``tst_features``, ``C`` and ``val`` / ``tst`` entries with a ``cseq``
        item.  The number of items per split and both dimensions of ``C``
        (``shape[0]`` as NUM_LABEL, ``shape[1]`` as NUM_CLUSTER) are logged.

        NOTE: ``pickle.load`` can execute arbitrary code while loading, so
        this should only be pointed at trusted files.

        Args:
            args: object exposing a ``data_bin_path`` attribute.

        Returns:
            dict with the feature sets, the counts (cluster count under both
            ``num_clusters`` and ``NUM_CLUSTER``), a fixed ``vocab_size`` and
            the ``C_val`` / ``C_tst`` results of
            ``data_utils.Ylist_to_Ysparse``.
        """
        with open(args.data_bin_path, "rb") as src:
            raw = pickle.load(src)

        train_set = raw["trn_features"]
        valid_set = raw["val_features"]
        test_set = raw["tst_features"]

        c_shape = raw["C"].shape
        label_count, cluster_count = c_shape[0], c_shape[1]

        sizes = map(len, (train_set, valid_set, test_set))
        logger.info("TRN {} VAL {} TST {}".format(*sizes))
        logger.info("NUM_LABEL {}".format(label_count))
        logger.info("NUM_CLUSTER {}".format(cluster_count))

        def to_sparse(split):
            # Produces the "Y csr matrix" for one split from its ``cseq`` list.
            return data_utils.Ylist_to_Ysparse(raw[split]["cseq"],
                                               L=cluster_count)

        return {
            "trn_features": train_set,
            "val_features": valid_set,
            "tst_features": test_set,
            "num_clusters": cluster_count,
            # Hard-coded; presumably the tokenizer vocabulary size -- TODO confirm.
            "vocab_size": 30522,
            "NUM_LABEL": label_count,
            "NUM_CLUSTER": cluster_count,
            "C_val": to_sparse("val"),
            "C_tst": to_sparse("tst"),
        }
コード例 #4
0
    def load_data(args):
        """Load the train/test portions of the pickled dataset.

        Reads ``args.data_bin_path`` with ``pickle`` and extracts
        ``trn_features``, ``tst_features`` and the two dimensions of ``C``.
        Split sizes, ``C.shape[0]`` (NUM_LABEL) and ``C.shape[1]``
        (NUM_CLUSTER) are logged through ``logger``.

        NOTE: ``pickle.load`` can execute arbitrary code while loading, so
        this should only be pointed at trusted files.

        Args:
            args: object exposing a ``data_bin_path`` attribute.

        Returns:
            dict with keys ``trn_features``, ``tst_features``,
            ``num_labels``, ``num_clusters``, ``C_trn`` and ``C_tst``; the
            last two are results of ``data_utils.Ylist_to_Ysparse`` applied
            to the ``cseq`` item of the ``trn`` / ``tst`` entries.
        """
        with open(args.data_bin_path, "rb") as stream:
            blob = pickle.load(stream)

        out = {
            "trn_features": blob["trn_features"],
            "tst_features": blob["tst_features"],
        }
        c_matrix = blob["C"]
        out["num_labels"] = c_matrix.shape[0]
        out["num_clusters"] = n_clusters = c_matrix.shape[1]

        n_train = len(out["trn_features"])
        n_test = len(out["tst_features"])
        logger.info("TRN {} TST {}".format(n_train, n_test))
        logger.info("NUM_LABEL {}".format(out["num_labels"]))
        logger.info("NUM_CLUSTER {}".format(n_clusters))

        # Build the "Y csr matrix" for each split: train first, then test.
        out["C_trn"] = data_utils.Ylist_to_Ysparse(blob["trn"]["cseq"],
                                                   L=n_clusters)
        out["C_tst"] = data_utils.Ylist_to_Ysparse(blob["tst"]["cseq"],
                                                   L=n_clusters)
        return out
コード例 #5
0
ファイル: attention.py プロジェクト: jicksonp/Transformer-XMC
    def load_data(args):
        """Unpickle the dataset and wrap each split in a ``DataLoader``.

        Reads ``args.data_bin_path`` and pulls out the ``trn`` / ``val`` /
        ``tst`` entries, ``len(stoi)`` (returned as ``vocab_size``) and both
        dimensions of ``C``.  The length of each split's ``xseq`` item is
        logged along with ``C.shape[0]`` (NUM_LABEL) and ``C.shape[1]``
        (NUM_CLUSTER).

        NOTE: ``pickle.load`` can execute arbitrary code while loading, so
        this should only be pointed at trusted files.

        Args:
            args: object exposing ``data_bin_path``, ``train_batch_size`` and
                ``eval_batch_size``.

        Returns:
            dict with the raw splits, one iterator per split, the counts
            (cluster count under both ``num_clusters`` and ``NUM_CLUSTER``)
            and the ``C_val`` / ``C_tst`` results of
            ``data_utils.Ylist_to_Ysparse``.
        """
        global device
        with open(args.data_bin_path, "rb") as src:
            raw = pickle.load(src)

        splits = {
            "trn_features": raw["trn"],
            "val_features": raw["val"],
            "tst_features": raw["tst"],
        }
        vocab = len(raw["stoi"])
        c_shape = raw["C"].shape
        n_labels, n_clusters = c_shape[0], c_shape[1]

        sizes = [len(part["xseq"]) for part in splits.values()]
        logger.info("TRN {} VAL {} TST {}".format(*sizes))
        logger.info("NUM_LABEL {}".format(n_labels))
        logger.info("NUM_CLUSTER {}".format(n_clusters))

        # "Y csr matrix" for the evaluation splits: validation, then test.
        sparse_val = data_utils.Ylist_to_Ysparse(raw["val"]["cseq"], L=n_clusters)
        sparse_tst = data_utils.Ylist_to_Ysparse(raw["tst"]["cseq"], L=n_clusters)

        # One iterator per split; evaluation splits share ``eval_batch_size``.
        iterators = {}
        iterators["trn_iter"] = DataLoader(
            raw, set_option="trn", batch_size=args.train_batch_size, device=device)
        iterators["val_iter"] = DataLoader(
            raw, set_option="val", batch_size=args.eval_batch_size, device=device)
        iterators["tst_iter"] = DataLoader(
            raw, set_option="tst", batch_size=args.eval_batch_size, device=device)

        result = dict(splits)
        result.update(iterators)
        result.update({
            "num_clusters": n_clusters,
            "vocab_size": vocab,
            "NUM_LABEL": n_labels,
            "NUM_CLUSTER": n_clusters,
            "C_val": sparse_val,
            "C_tst": sparse_tst,
        })
        return result
コード例 #6
0
  def load_data(args):
    """Load the train/test portions of the pickled dataset.

    Reads ``args.data_bin_path`` with ``pickle``, extracts ``trn_features``,
    ``tst_features`` and both dimensions of ``C``, and logs the split sizes
    together with ``C.shape[0]`` (NUM_LABEL) and ``C.shape[1]``
    (NUM_CLUSTER).

    NOTE: ``pickle.load`` can execute arbitrary code while loading, so this
    should only be pointed at trusted files.

    Args:
      args: object exposing a ``data_bin_path`` attribute.

    Returns:
      dict with keys ``trn_features``, ``tst_features``, ``num_labels``,
      ``num_clusters``, ``C_trn`` and ``C_tst``; the last two are results of
      ``data_utils.Ylist_to_Ysparse`` applied to the ``cseq`` item of the
      ``trn`` / ``tst`` entries.
    """
    with open(args.data_bin_path, 'rb') as handle:
      payload = pickle.load(handle)

    train_feats = payload['trn_features']
    test_feats = payload['tst_features']
    c_shape = payload['C'].shape
    label_total, cluster_total = c_shape[0], c_shape[1]

    sizes = (len(train_feats), len(test_feats))
    logger.info('TRN {} TST {}'.format(*sizes))
    logger.info('NUM_LABEL {}'.format(label_total))
    logger.info('NUM_CLUSTER {}'.format(cluster_total))

    # Build the "Y csr matrix" for each split: train first, then test.
    c_trn, c_tst = (
        data_utils.Ylist_to_Ysparse(payload[split]['cseq'], L=cluster_total)
        for split in ('trn', 'tst')
    )

    return {
        'trn_features': train_feats,
        'tst_features': test_feats,
        'num_labels': label_total,
        'num_clusters': cluster_total,
        'C_trn': c_trn,
        'C_tst': c_tst,
    }