コード例 #1
0
    def __init__(self,
                 path='./data/',
                 dataset_prefix='vci_1543_abs_tit_key_apr_1_2019_',
                 test_data_name='',
                 full_meta_data_name='explanations_5panels.csv',
                 label_size=5,
                 fix_length=None,
                 meta_data=None):
        """
        Create the text/label fields and load every TSV table this object uses.

        Loads the train/valid/test splits, the full meta-data table and, when
        ``test_data_name`` is given, an external test set. Every table has the
        same two columns: ``Text`` (bound to ``self.TEXT``) and
        ``Description`` (bound to ``self.LABEL``).

        :param path: directory holding all data files.
        :param dataset_prefix: filename prefix of the splits; they are read
            from ``<prefix>train.csv``, ``<prefix>valid.csv`` and
            ``<prefix>test.csv`` inside ``path`` (parsed as TSV despite the
            ``.csv`` extension).
        :param test_data_name: filename, relative to ``path``, of an external
            test set. Empty (or None) means there is none and
            ``self.external_test`` is set to None.
        :param full_meta_data_name: filename, relative to ``path``, of the
            full meta-data table.
        :param label_size: number of labels, forwarded to ``MultiLabelField``.
        :param fix_length: forwarded as ``fix_length`` to both fields.
        :param meta_data: MetaData class instance. Will be used for vocab building.
        """
        # we will add metalabel here and make iterators
        self.TEXT = ReversibleField(sequential=True,
                                    include_lengths=True,
                                    lower=False,
                                    fix_length=fix_length)
        self.LABEL = MultiLabelField(sequential=True,
                                     use_vocab=False,
                                     label_size=label_size,
                                     tensor_type=torch.FloatTensor,
                                     fix_length=fix_length)

        # Column schema shared by every table loaded below.
        fields = [('Text', self.TEXT), ('Description', self.LABEL)]

        # it's actually this step that will take 5 minutes
        self.train, self.val, self.test = data.TabularDataset.splits(
            path=path,
            train=dataset_prefix + 'train.csv',
            validation=dataset_prefix + 'valid.csv',
            test=dataset_prefix + 'test.csv',
            format='tsv',
            fields=fields)

        self.full_meta_data = data.TabularDataset(
            path=pjoin(path, full_meta_data_name),
            format='tsv',
            fields=fields)

        self.meta_data = meta_data

        # NOTE: the attribute name is misspelled ("bulit"); it is kept as-is
        # because other code may read or set this exact name.
        self.is_vocab_bulit = False
        self.iterators = []

        if test_data_name:
            # Join with pjoin, like the meta-data path above, rather than
            # string concatenation: `path + name` silently yields a wrong
            # path when `path` has no trailing separator (e.g. './data').
            self.external_test = data.TabularDataset(
                path=pjoin(path, test_data_name),
                format='tsv',
                fields=fields)
        else:
            self.external_test = None
コード例 #2
0
    def __init__(self, path='./data/',
                 weak_train_dataset="",
                 acmg_weak_data_path="",
                 dataset_prefix='vci_1543_abs_tit_key_apr_1_2019_',
                 test_data_name='vci_358_abs_tit_key_may_7_2019_true_test.csv',
                 multi_task_train_dataset="",
                 label_size=5, fix_length=None):
        """
        Create the text/label fields and load the main and auxiliary TSV tables.

        Every table has the same two columns: ``Text`` (bound to ``self.TEXT``)
        and ``Description`` (bound to ``self.LABEL``).

        :param path: directory holding the train/valid/test splits and the
            external test set.
        :param weak_train_dataset: full path (NOT joined with ``path``) of a
            weakly-labelled training table. Empty (or None) disables it and
            sets ``self.weak_train`` to None.
        :param acmg_weak_data_path: full path (NOT joined with ``path``) of
            extra weak data whose examples are appended to ``self.weak_train``.
            Only used when ``weak_train_dataset`` is given; otherwise a warning
            is emitted and it is ignored.
        :param dataset_prefix: filename prefix of the splits; they are read
            from ``<prefix>train.csv``, ``<prefix>valid.csv`` and
            ``<prefix>test.csv`` inside ``path`` (parsed as TSV).
        :param test_data_name: filename, relative to ``path``, of the external
            test set. Empty (or None) means none (``self.external_test`` is None).
        :param multi_task_train_dataset: full path (NOT joined with ``path``)
            of a multi-task training table. Empty (or None) disables it.
        :param label_size: number of labels, forwarded to ``MultiLabelField``.
        :param fix_length: forwarded as ``fix_length`` to both fields.
        """
        import os
        import warnings

        self.TEXT = ReversibleField(sequential=True, include_lengths=True, lower=False, fix_length=fix_length)
        self.LABEL = MultiLabelField(sequential=True, use_vocab=False, label_size=label_size,
                                     tensor_type=torch.FloatTensor, fix_length=fix_length)
        # Column schema shared by every table loaded below.
        fields = [('Text', self.TEXT), ('Description', self.LABEL)]

        if weak_train_dataset:
            self.weak_train = data.TabularDataset(weak_train_dataset, format='tsv', fields=fields)
            if acmg_weak_data_path:
                acmg_weak_data = data.TabularDataset(acmg_weak_data_path, format='tsv', fields=fields)
                # Both tables use the same fields, so their examples can simply
                # be merged into the weak training set.
                self.weak_train.examples.extend(acmg_weak_data.examples)
        else:
            self.weak_train = None
            if acmg_weak_data_path:
                # The extra weak data can only be appended to an existing weak
                # training set; report it instead of dropping it silently.
                warnings.warn(
                    "acmg_weak_data_path is ignored because weak_train_dataset is empty",
                    stacklevel=2)

        if multi_task_train_dataset:
            self.multi_task_train = data.TabularDataset(multi_task_train_dataset, format='tsv',
                                                        fields=fields)
        else:
            self.multi_task_train = None

        # it's actually this step that will take 5 minutes
        self.train, self.val, self.test = data.TabularDataset.splits(
            path=path, train=dataset_prefix + 'train.csv',
            validation=dataset_prefix + 'valid.csv',
            test=dataset_prefix + 'test.csv', format='tsv',
            fields=fields)

        if test_data_name:
            # os.path.join instead of `path + name`, so a `path` without a
            # trailing separator (e.g. './data') still resolves correctly.
            self.external_test = data.TabularDataset(path=os.path.join(path, test_data_name),
                                                     format='tsv',
                                                     fields=fields)
        else:
            self.external_test = None

        # NOTE: the attribute name is misspelled ("bulit"); it is kept as-is
        # because other code may read or set this exact name.
        self.is_vocab_bulit = False
        self.iterators = []
        self.test_iterator = None
        self.weak_train_iterator = None
        self.multi_task_train_iterator = None
コード例 #3
0
    def __init__(self,
                 data_path,
                 batch_size=5,
                 num_meta_labels=5,
                 fix_length=None):
        """
        Store the loader configuration and create the text field.

        :param data_path: location of the data, e.g. "./models/data/". Data
            should be in tsv format, and the last label should be the
            grouping factor.
        :param batch_size: number of explanations to draw, let's say 5
        :param num_meta_labels: number of meta labels (default 5).
        :param fix_length: forwarded as ``fix_length`` to the text field.
        """
        # Configuration, stored verbatim.
        self.data_path = data_path
        self.batch_size = batch_size
        self.fix_length = fix_length
        self.num_meta_labels = num_meta_labels

        # No vocab is built here: it will be shared with the main text
        # field in the main dataset.
        self.TEXT_FIELD = ReversibleField(
            sequential=True,
            include_lengths=True,
            lower=False,
            fix_length=self.fix_length,
        )

        # Both start empty (presumably filled by other methods — confirm).
        self.datasets, self.data_iters = [], []
コード例 #4
0
    # Label set: the strings '1'..'18', each mapped to a 0-based index 0..17.
    labels = {}
    for true_label in range(1, 19):
        labels[str(true_label)] = true_label - 1  # actual label we see -> 0-based index

    # Invert the name -> index dict into an index-ordered list, so that
    # label_list[i] is the name of label i (here: ['1', '2', ..., '18']).
    label_list = [None] * len(labels)
    for k, v in labels.items():
        label_list[v] = k

    # From here on `labels` is the index-ordered list of label names.
    labels = label_list
    logger.info("available labels: ")
    logger.info(labels)

    # Text field using the externally supplied `tokenizer`; case is kept
    # (lower=False) and include_lengths=True is passed through to the field.
    TEXT = ReversibleField(sequential=True,
                           tokenize=tokenizer,
                           include_lengths=True,
                           lower=False)

    # Multi-label target field without a vocab (use_vocab=False).
    # NOTE: label_size=18 must agree with the range(1, 19) label set above;
    # both values are hard-coded independently.
    LABEL = MultiLabelField(sequential=True,
                            use_vocab=False,
                            label_size=18,
                            tensor_type=torch.FloatTensor)

    if args.dataset == 'major':
        # TSV splits under a relative directory (presumably resolved against
        # the current working directory — confirm). Column 'Text' is bound
        # to TEXT and 'Description' to LABEL.
        train, val, test = data.TabularDataset.splits(
            path='../../data/csu/',
            train='maj_label_train.tsv',
            validation='maj_label_valid.tsv',
            test='maj_label_test.tsv',
            format='tsv',
            fields=[('Text', TEXT), ('Description', LABEL)])