Example #1
    def merge_data(self, pos, neg, device):
        # FIXME: maybe just Field?
        label_field = RawField(postprocessing=lambda x: torch.cuda.LongTensor(
            x, device=device))
        label_field.is_target = True
        examples = [self._attach_label(ex, POS_LABEL) for ex in pos] +\
            [self._attach_label(ex, NEG_LABEL) for ex in neg]
        return Dataset(examples, [('sent', self.sent_field),
                                  ('label', label_field)])

    def _get_datafields(self,
                        lower=True,
                        max_input_sent_len=450,
                        preprocessed=True):
        # Parenthesized so the conditional picks the tokenizer itself, rather
        # than being evaluated inside the lambda body on every call.
        tokenize = (lambda x: x.split()[:max_input_sent_len]) \
            if preprocessed else 'spacy'

        self.ENTITY = Field(sequential=True, batch_first=True, lower=lower)
        self.NER = Field(sequential=True, batch_first=True, lower=lower)
        self.REL = Field(sequential=True, batch_first=True)
        self.ORDERED_REL = Field(sequential=True, batch_first=True)
        self.NUM = Field(sequential=True, batch_first=True, use_vocab=False)
        self.STR = RawField()
        self.SHOW_INP = RawField()
        self.TGT = Field(sequential=True,
                         batch_first=True,
                         init_token="<bos>",
                         eos_token="<eos>",
                         include_lengths=True,
                         lower=lower,
                         tokenize=tokenize)
        self.TGT_NON_TEMPL = Field(sequential=True,
                                   batch_first=True,
                                   init_token=ENT1_END,
                                   eos_token=ENT0_END,
                                   lower=lower,
                                   tokenize=tokenize)
        self.TGT_TXT = Field(sequential=True,
                             batch_first=True,
                             init_token="<bos>",
                             eos_token="<eos>",
                             include_lengths=True,
                             lower=lower)
        self.ADJ_SEQ = Field(sequential=True,
                             batch_first=True,
                             init_token="<bos>",
                             eos_token="<eos>",
                             include_lengths=True,
                             lower=lower)

        self.fields = [
            ("triples", None),
            ("tgt", self.TGT),
            ("tgt_non_templ", self.TGT_NON_TEMPL),
            ("tgt_txt", self.TGT_TXT),
            ("ents", self.ENTITY),
            ("ners", self.NER),
            ("rels", self.REL),
            ("ordered_rels", self.ORDERED_REL),
            ("ordered_ents", self.STR),
            ("ent_lens", self.STR),
            ("ner2ent", self.STR),
            ("adj", self.STR),
            ("adj_seq", self.ADJ_SEQ),
            ("show_inp", self.SHOW_INP),
        ]
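A minimal sketch of how the label_field above behaves at batch time, assuming a legacy torchtext release where RawField is importable from torchtext.data; the POS_LABEL/NEG_LABEL values and the CPU tensor are hypothetical stand-ins for the module's own constants and for torch.cuda.LongTensor:

import torch
from torchtext.data import RawField

POS_LABEL, NEG_LABEL = 1, 0  # hypothetical values; the real constants are defined elsewhere

label_field = RawField(postprocessing=lambda xs: torch.tensor(xs, dtype=torch.long))
label_field.is_target = True

# RawField.process hands the whole minibatch list to the postprocessing callable.
print(label_field.process([POS_LABEL, NEG_LABEL, POS_LABEL]))  # tensor([1, 0, 1])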
Example #3
    def _add_output_classes(corpus: Corpus) -> None:
        """ Set the the pronouns for each sentence. """
        corpus.fields["token"] = RawField()
        corpus.fields["counter_token"] = RawField()

        corpus.fields["token"].is_target = False
        corpus.fields["counter_token"].is_target = False

        for ex in corpus:
            setattr(ex, "token", "he")
            setattr(ex, "counter_token", "she")
Example #4
    def _add_output_classes(corpus: Corpus) -> None:
        """ Set the correct and incorrect verb for each sentence. """
        corpus.fields["token"] = RawField()
        corpus.fields["wrong_token"] = RawField()

        corpus.fields["token"].is_target = False
        corpus.fields["wrong_token"].is_target = False

        for ex in corpus:
            setattr(ex, "token", ["he"])
            setattr(ex, "wrong_token", ["she"])
Example #5
    def __init__(self, config):
        self.config = config
        self.tokenizer = Tokenizer(self.config.vocab_path)

        self.num_labels, self.id2label, self.label2id = get_data_info(
            self.config.event_schema_path)
        from torchtext.data import RawField
        from src.dataloader.utils import sequence_padding
        self.fields = [
            ("input_ids", RawField(postprocessing=sequence_padding)),
            ("token_type_ids", RawField(postprocessing=sequence_padding)),
            ("attention_mask", RawField(postprocessing=sequence_padding)),
            ("label_ids", RawField(postprocessing=sequence_padding)),
            ("seq_len", RawField())
        ]
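sequence_padding comes from the project's own src.dataloader.utils, so its exact behavior is an assumption; the plausible stand-in below shows the intent, namely that a RawField's postprocessing receives the whole batch and can pad it into one tensor:

import torch
from torchtext.data import RawField

def sequence_padding(batch, pad_id=0):
    # Hypothetical stand-in for src.dataloader.utils.sequence_padding.
    max_len = max(len(ids) for ids in batch)
    return torch.tensor([ids + [pad_id] * (max_len - len(ids)) for ids in batch])

input_ids = RawField(postprocessing=sequence_padding)
print(input_ids.process([[101, 7592, 102], [101, 102]]).shape)  # torch.Size([2, 3])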
Example #6
 def init_fields(self) -> None:
     self.WORDS = Field(pad_token=None, lower=self.lower)
     self.POS_TAGS = Field(pad_token=None)
     self.NONTERMS = Field(pad_token=None)
     self.ACTIONS = ActionRuleField(self.NONTERMS, self.productions)
     self.RAWS = RawField()
     self.SEQ = RawField()
     # self.ACTIONS = ActionField(self.NONTERMS)
     # self.RAWS = Field(lower=self.lower, pad_token=None)
     self.fields = [
         ('raw_seq', self.SEQ),
         ('actions', self.ACTIONS), ('nonterms', self.NONTERMS),
         ('pos_tags', self.POS_TAGS), ('words', self.WORDS),
         ('raws', self.RAWS),
     ]
Example #7
    def full_split(cls,
                   root_dir,
                   val_size=1000,
                   load_processed=True,
                   save_processed=True):
        '''Generates the full train/val/test split'''
        spd = os.path.join(root_dir, 'imdb', 'processed/')
        train_path = os.path.join(spd, 'train.pkl')
        val_path = os.path.join(spd, 'val.pkl')
        test_path = os.path.join(spd, 'test.pkl')
        if (load_processed and os.path.exists(train_path)
                and os.path.exists(val_path) and os.path.exists(test_path)):
            print(" [*] Loading pre-processed IMDB objects.")
            with open(train_path,
                      'rb') as train_f, open(val_path, 'rb') as val_f, open(
                          test_path, 'rb') as test_f:
                return pickle.load(train_f), pickle.load(val_f), pickle.load(
                    test_f)

        # This means we're not loading from pickle
        itrain, itest = IMDB.splits(RawField(), RawField(), root=root_dir)

        vocab = Vocabulary([x.text for x in itrain] + [x.text for x in itest],
                           f_min=100)

        # For val we take middle val_size values as this is where pos/neg switch occurs
        mid = len(itrain) // 2
        grab = val_size // 2
        train = cls([[x.text, x.label] for x in itrain[:mid - grab]] +
                    [[x.text, x.label] for x in itrain[mid + grab:]], vocab)
        val = cls([[x.text, x.label] for x in itrain[mid - grab:mid + grab]],
                  vocab)
        test = cls([[x.text, x.label] for x in itest], vocab)

        if save_processed:
            if not os.path.exists(spd):
                os.makedirs(spd)

            with open(train_path, 'wb') as f:
                pickle.dump(train, f)

            with open(val_path, 'wb') as f:
                pickle.dump(val, f)

            with open(test_path, 'wb') as f:
                pickle.dump(test, f)

        return train, val, test
Example #8
def prepare_dataset(dataset):
    context, query, label, start, end = list(zip(*dataset))

    dataset = list(
        zip(context, deepcopy(context), query, deepcopy(query), start, end,
            label))
    TEXT = Field(lower=True, include_lengths=False, batch_first=True)
    CHAR = RawField()
    LABEL = Field(sequential=False, tensor_type=torch.LongTensor)

    examples = []
    for i, d in enumerate(dataset):
        if i % 100 == 0: print('[%d/%d]' % (i, len(dataset)))
        examples.append(
            Example.fromlist(d, [('context', TEXT), ('context_c', CHAR),
                                 ('query', TEXT), ('query_c', CHAR),
                                 ('start', LABEL), ('end', LABEL),
                                 ('label', TEXT)]))

    dataset = Dataset(examples, [('context', TEXT), ('context_c', CHAR),
                                 ('query', TEXT), ('query_c', CHAR),
                                 ('start', LABEL), ('end', LABEL),
                                 ('label', TEXT)])
    TEXT.build_vocab(dataset, min_freq=2)
    #CHAR.build_vocab(dataset)

    return dataset, TEXT, CHAR
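tensor_type in the LABEL field above is the pre-0.3 torchtext spelling; later legacy releases renamed it to dtype, so an equivalent field there would look like this (a hedged sketch, depending on the installed torchtext version):

import torch
from torchtext.data import Field

LABEL = Field(sequential=False, dtype=torch.long)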
Example #9
    def _attach_sen_ids(self):
        """ Adds a sentence index field to the Corpus. """
        self.fields["sen_idx"] = RawField()
        self.fields["sen_idx"].is_target = False

        for sen_idx, item in enumerate(self.examples):
            setattr(item, "sen_idx", sen_idx)
Example #10
    def iters(cls,
              batch_size=64,
              device=-1,
              shuffle=True,
              vectors='glove.840B.300d'):
        cls.TEXT = Field(sequential=True,
                         tokenize='spacy',
                         lower=True,
                         batch_first=True)
        cls.LABEL = Field(sequential=False,
                          use_vocab=False,
                          batch_first=True,
                          tensor_type=torch.FloatTensor,
                          postprocessing=Pipeline(get_class_probs))
        cls.ID = RawField()

        train, val, test = cls.splits(cls.TEXT, cls.LABEL, cls.ID)

        cls.TEXT.build_vocab(train, vectors=vectors)

        return BucketIterator.splits((train, val, test),
                                     batch_size=batch_size,
                                     shuffle=shuffle,
                                     repeat=False,
                                     device=device)
Example #11
    def __call__(self, docs, progress=True, parallel=True):
        texts = [
            ' '.join([tok.lemma_ for tok in doc if not tok.is_stop])
            for doc in docs
        ]
        fields = [('index', RawField()),
                  ('context', SpacyBertField(self.tokenizer))]

        if parallel:
            with mp.Pool() as pool:
                examples = pool.map(Examplifier(fields),
                                    enumerate(tqdm(texts)))
        else:
            f = Examplifier(fields)
            examples = [f((i, t)) for (i, t) in enumerate(tqdm(texts))]

        ds = Dataset(examples, fields)
        buckets = BucketIterator(dataset=ds,
                                 batch_size=24,
                                 device=self.device,
                                 shuffle=False,
                                 sort=True,
                                 sort_key=lambda ex: -len(ex.context))

        embeds = np.zeros((len(texts), REDUCTION_DIMS), dtype=np.float32)
        for b in tqdm(buckets):
            with torch.no_grad():
                output = self.model.bert.embeddings(b.context)
                embeds[b.index] = reduce_embeds(b.context, output).cpu()

        return embeds
Example #12
    def create_fields(
        header: List[str],
        to_lower: bool = False,
        sen_column: str = "sen",
        tokenize_columns: Optional[List[str]] = None,
        convert_numerical: bool = False,
        tokenizer: Optional[PreTrainedTokenizer] = None,
    ) -> List[Tuple[str, Field]]:
        tokenize_columns = tokenize_columns or [sen_column]

        pipeline = None
        if convert_numerical:

            def preprocess_field(s: Union[str, int]) -> Union[str, int]:
                return int(s) if (isinstance(s, str) and s.isdigit()) else s

            pipeline = Pipeline(convert_token=preprocess_field)

        fields = []

        for column in header:
            if column in tokenize_columns:
                field = Field(batch_first=True, include_lengths=True, lower=to_lower)
                if tokenizer is not None:
                    attach_tokenizer(field, tokenizer)
            else:
                field = RawField(preprocessing=pipeline)
                field.is_target = False

            fields.append((column, field))

        return fields
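A hedged usage sketch for create_fields: the corpus.tsv path and the two-column header are hypothetical, and it assumes the legacy torchtext TabularDataset used elsewhere in these examples:

from torchtext.data import TabularDataset

# "sen" is tokenized into a Field; every other column (here "labels") becomes a
# RawField whose values pass through, optionally converted to int.
fields = create_fields(["sen", "labels"], to_lower=True, convert_numerical=True)
corpus = TabularDataset(path="corpus.tsv", format="tsv", fields=fields)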
Example #13
def predict(model, texts, vocabulary, device):
    src_field = TranslationField()
    index_field = RawField()
    examples = [
        Example.fromlist([x, i], [('src', src_field), ('index', index_field)])
        for i, x in enumerate(texts)
    ]
    dataset = Dataset(examples=examples,
                      fields=[('src', src_field), ('index', index_field)])
    src_field.vocab = vocabulary
    iterator = Iterator(dataset=dataset,
                        batch_size=2048,
                        sort=False,
                        sort_within_batch=True,
                        sort_key=lambda x: len(x.src),
                        device=device,
                        repeat=False,
                        shuffle=False)

    texts = []
    indices = []
    for data in tqdm(iterator):
        texts.extend(
            translate(model=model,
                      vocabulary=vocabulary,
                      data=data,
                      max_seq_len=100,
                      device=device))
        indices.extend(data.index)
    prediction = pd.DataFrame([texts, indices]).T.rename(columns={
        0: 'fullname_prediction',
        1: 'index'
    })
    prediction = prediction.sort_values('index')
    return prediction
Example #14
    def __init__(self, path, format, fields, skip_header=True, **kwargs):
        super(IEDB, self).__init__(path, format, fields, skip_header, **kwargs)

        # keep a raw copy of the sentence for debugging
        RAW_TEXT_FIELD = RawField()
        for ex in self.examples:
            raw_peptide, raw_mhc_amino_acid = ex.peptide[:], ex.mhc_amino_acid[:]
            setattr(ex, "raw_peptide", raw_peptide) 
            setattr(ex, "raw_mhc_amino_acid", raw_mhc_amino_acid)
        self.fields["raw_peptide"] = RAW_TEXT_FIELD
        self.fields["raw_mhc_amino_acid"] = RAW_TEXT_FIELD
Example #15
    def __init__(self, path, format, fields, skip_header=True, **kwargs):
        super(WikiQA, self).__init__(path, format, fields, skip_header, **kwargs)

        # We want to keep a raw copy of the sentence for some models and for debugging
        RAW_TEXT_FIELD = RawField()
        for ex in self.examples:
            raw_sentence_a, raw_sentence_b = ex.sentence_a[:], ex.sentence_b[:]
            setattr(ex, 'raw_sentence_a', raw_sentence_a)
            setattr(ex, 'raw_sentence_b', raw_sentence_b)

        self.fields['raw_sentence_a'] = RAW_TEXT_FIELD
        self.fields['raw_sentence_b'] = RAW_TEXT_FIELD
Example #16
    def _create_pos_tags(self):
        """ Attaches nltk POS tags to the corpus for each sentence. """
        import nltk

        nltk.download("averaged_perceptron_tagger")

        self.fields["pos_tags"] = RawField()
        self.fields["pos_tags"].is_target = False

        print("Tagging corpus...")
        for item in self.examples:
            sen = getattr(item, self.sen_column)
            setattr(item, "pos_tags", [t[1] for t in nltk.pos_tag(sen)])
Example #17
 def build_field(self, tokenizer, preprocessing):
     """Use custom defined TransformerField which is an extension of torchtext.Field"""
     ID = RawField()
     TWEET = TransformersField(tokenizer,
                               include_lengths=True,
                               use_vocab=False,
                               batch_first=True,
                               preprocessing=preprocessing,
                               tokenize=tokenizer.tokenize,
                               pad_token=tokenizer.pad_token_id)  # id
     LABEL = Field(sequential=False, unk_token=None, pad_token=None)
     fields = [('id', ID), ('tweet', TWEET), ('label', LABEL)]
     return fields
Example #18
    def __init__(self, df, question_field, answer_field, right_answer_col):
        # print(right_answer_col)
        # df.to_excel("Test.xlsx")
        fields = [('context_id', RawField()), ('left', question_field), ('right', question_field), ('right_item', answer_field)]
        examples = []
        for i, row in df.iterrows():
            left = row.left
            right = row.right
            right_answer = row[right_answer_col]
            # if right_answer not in answer_field.vocab.stoi:
            #   print(right_answer)
            examples.append(Example.fromlist([i, left, right, right_answer], fields))

        super().__init__(examples, fields)
Example #19
def create_dataset(config: Config,
                   device: torch.device,
                   vocab: Vocab,
                   rics: List[str],
                   seqtypes: List[SeqType]) -> Iterator:

    fields = dict()
    fields[SeqType.ArticleID.value] = (SeqType.ArticleID.value, RawField())

    time_field = Field(use_vocab=False, batch_first=True, sequential=False)
    fields['jst_hour'] = (SeqType.Time.value, time_field)

    token_field = \
        Field(use_vocab=True,
              init_token=SpecialToken.BOS.value,
              eos_token=SpecialToken.EOS.value,
              pad_token=SpecialToken.Padding.value,
              unk_token=SpecialToken.Unknown.value)

    fields['processed_tokens'] = (SeqType.Token.value, token_field)

    tensor_type = torch.FloatTensor if device.type == 'cpu' else torch.cuda.FloatTensor

    for (ric, seqtype) in itertools.product(rics, seqtypes):
        n = N_LONG_TERM if seqtype.value.endswith('long') else N_SHORT_TERM
        price_field = Field(use_vocab=False,
                            fix_length=n,
                            batch_first=True,
                            pad_token=0.0,
                            preprocessing=lambda xs: [float(x) for x in xs],
                            tensor_type=tensor_type)
        key = stringify_ric_seqtype(ric, seqtype)
        fields[key] = (key, price_field)

    # load an alignment for prediction
    predict = TabularDataset(path='output/alignment-predict.json',
                             format='json',
                             fields=fields)

    token_field.vocab = vocab

    # Make an iterator for prediction
    return Iterator(predict,
                    batch_size=1,
                    device=-1 if device.type == 'cpu' else device,
                    repeat=False,
                    sort=False)
Example #20
    def __init__(self, data_dir='./data', train_fname='train.csv', valid_fname='valid.csv', test_fname='test.csv',
                 vocab_fname='vocab.json'):

        stop_words = get_stop_words()

        tokenize = lambda x: x.split()
        INPUT = Field(sequential=True, batch_first=True, tokenize=tokenize, lower=True)
        ENT = Field(sequential=False, batch_first=True, lower=True)
        TGT = Field(sequential=True, batch_first=True)
        SHOW_INP = RawField()
        fields = [('tgt', TGT), ('input', INPUT), ('show_inp', SHOW_INP), ('ent1', ENT), ('ent2', ENT)]

        datasets = TabularDataset.splits(
            fields=fields,
            path=data_dir,
            format=train_fname.rsplit('.')[-1],
            train=train_fname,
            validation=valid_fname,
            test=test_fname,
            skip_header=True,
        )

        INPUT.build_vocab(*datasets, max_size=100000,
                          vectors=GloVe(name='6B', dim=100),
                          unk_init=torch.Tensor.normal_, )
        TGT.build_vocab(*datasets)
        ENT.build_vocab(*datasets)

        self.INPUT = INPUT
        self.ENT = ENT
        self.TGT = TGT
        self.train_ds, self.valid_ds, self.test_ds = datasets

        if vocab_fname:
            writeout = {
                'tgt_vocab': {
                    'itos': TGT.vocab.itos, 'stoi': TGT.vocab.stoi,
                },
                'input_vocab': {
                    'itos': INPUT.vocab.itos, 'stoi': INPUT.vocab.stoi,
                },
                'ent_vocab': {
                    'itos': ENT.vocab.itos, 'stoi': ENT.vocab.stoi,
                },
            }
            fwrite(json.dumps(writeout, indent=4), vocab_fname)
Example #21
def get_fields(bos='<s>', eos='</s>', inflection=False,
               language=False, share_vocab=False, use_bpe=False):
    """
    language: whether to also create a field for the language of the sample
        (as for multilingual training)
    share_vocab: determines whether to create separate fields for src and
        tgt sequences or use one in common
    returns: A dictionary. The keys are minimally 'src' and 'tgt'. If
        inflection, language, or use_bpe is True, an 'inflection',
        'language', or 'word_split' key will also be present, respectively.
    """
    fields = {'src': [], 'tgt': []}

    assert not share_vocab or inflection

    if not share_vocab:
        if inflection:
            src = Field(tokenize=list, include_lengths=True)
        else:
            # total kludge here in order to get the baseline to work
            src = Field(tokenize=str.split, include_lengths=True)
        tgt = Field(init_token=bos, eos_token=eos, tokenize=list)
    else:
        src = Field(init_token=bos, eos_token=eos,
                    tokenize=list, include_lengths=True)
        tgt = src

    fields['src'].append(('src', src))
    fields['tgt'].append(('tgt', tgt))

    if inflection:
        tokenize_infl = partial(str.split, sep=';')
        infl = Field(tokenize=tokenize_infl, include_lengths=True)

        fields['inflection'] = [('inflection', infl)]

    if language:
        lang = Field(sequential=False)
        fields['language'] = [('language', lang)]

    if use_bpe:
        bpe = RawField()
        #bpe = Field(sequential=False, use_vocab=False)
        fields['word_split'] = [('word_split', bpe)]

    return fields
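A quick check of the dictionary shape get_fields returns; it needs no data and only assumes the imports the snippet itself relies on (Field and RawField from torchtext.data, partial from functools):

fields = get_fields(inflection=True, language=True, use_bpe=True)
print(sorted(fields))  # ['inflection', 'language', 'src', 'tgt', 'word_split']

name, src_field = fields['src'][0]
print(name, src_field.include_lengths)  # src True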
Example #22
def get_fields(src_data_type,
               n_src_feats,
               n_tgt_feats,
               pad='<blank>',
               bos='<s>',
               eos='</s>',
               dynamic_dict=False,
               with_align=False,
               src_truncate=None,
               tgt_truncate=None):
    """
    Args:
        src_data_type: type of the source input. Options are [text|img|audio].
        n_src_feats (int): the number of source features (not counting tokens)
            to create a :class:`torchtext.data.Field` for. (If
            ``src_data_type=="text"``, these fields are stored together
            as a ``TextMultiField``).
        n_tgt_feats (int): See above.
        pad (str): Special pad symbol. Used on src and tgt side.
        bos (str): Special beginning of sequence symbol. Only relevant
            for tgt.
        eos (str): Special end of sequence symbol. Only relevant
            for tgt.
        dynamic_dict (bool): Whether or not to include source map and
            alignment fields.
        with_align (bool): Whether or not to include word align.
        src_truncate: Cut off src sequences beyond this (passed to
            ``src_data_type``'s data reader - see there for more details).
        tgt_truncate: Cut off tgt sequences beyond this (passed to
            :class:`TextDataReader` - see there for more details).
    Returns:
        A dict mapping names to fields. These names need to match
        the dataset example attributes.
    """

    assert src_data_type in ['text', 'img', 'audio', 'vec'], \
        "Data type not implemented"
    assert not dynamic_dict or src_data_type == 'text', \
        'it is not possible to use dynamic_dict with non-text input'
    fields = {}

    fields_getters = {
        "text": text_fields,
        "img": image_fields,
        "audio": audio_fields,
        "vec": vec_fields
    }

    src_field_kwargs = {
        "n_feats": n_src_feats,
        "include_lengths": True,
        "pad": pad,
        "bos": None,
        "eos": None,
        "truncate": src_truncate,
        "base_name": "src"
    }
    fields["src"] = fields_getters[src_data_type](**src_field_kwargs)

    tgt_field_kwargs = {
        "n_feats": n_tgt_feats,
        "include_lengths": False,
        "pad": pad,
        "bos": bos,
        "eos": eos,
        "truncate": tgt_truncate,
        "base_name": "tgt"
    }
    fields["tgt"] = fields_getters["text"](**tgt_field_kwargs)

    indices = Field(use_vocab=False, dtype=torch.long, sequential=False)
    fields["indices"] = indices

    corpus_ids = Field(use_vocab=True, sequential=False)
    fields["corpus_id"] = corpus_ids

    if dynamic_dict:
        src_map = Field(use_vocab=False,
                        dtype=torch.float,
                        postprocessing=make_src,
                        sequential=False)
        fields["src_map"] = src_map

        src_ex_vocab = RawField()
        fields["src_ex_vocab"] = src_ex_vocab

        align = Field(use_vocab=False,
                      dtype=torch.long,
                      postprocessing=make_tgt,
                      sequential=False)
        fields["alignment"] = align

    if with_align:
        word_align = AlignField()
        fields["align"] = word_align

    return fields
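This is the OpenNMT-py-style get_fields; a hedged usage sketch, assuming the helpers it calls (text_fields, make_src, make_tgt, AlignField) are in scope, e.g. imported from OpenNMT-py's inputters module:

# Plain-text source, no extra feature columns, copy-attention fields enabled.
fields = get_fields('text', n_src_feats=0, n_tgt_feats=0, dynamic_dict=True)
print(sorted(fields))
# ['alignment', 'corpus_id', 'indices', 'src', 'src_ex_vocab', 'src_map', 'tgt']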
Example #23
    def __init__(self, yaml_path):
        config_file = yaml_path
        config = yaml.load(open(config_file), Loader=yaml.FullLoader)
        args = config["training"]
        SEED = args["seed"]
        DATASET = args["dataset"]  # Multi30k or IWSLT
        MODEL = args["model"]  # gru**2, gru_attn**2, transformer, gcn_gru, gcngru_gru, gcngruattn_gru, gcnattn_gru
        REVERSE = args["reverse"]
        BATCH_SIZE = args["batch_size"]
        ENC_EMB_DIM = args["encoder_embed_dim"]
        DEC_EMB_DIM = args["decoder_embed_dim"]
        ENC_HID_DIM = args["encoder_hidden_dim"]
        DEC_HID_DIM = args["decoder_hidden_dim"]
        ENC_DROPOUT = args["encoder_dropout"]
        DEC_DROPOUT = args["decoder_dropout"]
        NLAYERS = args["num_layers"]
        N_EPOCHS = args["num_epochs"]
        CLIP = args["grad_clip"]
        LR = args["lr"]
        LR_DECAY_RATIO = args["lr_decay_ratio"]
        ID = args["id"]
        PATIENCE = args["patience"]
        DIR = 'checkpoints/{}-{}-{}/'.format(DATASET, MODEL, ID)
        MODEL_PATH = DIR
        LOG_PATH = '{}test-log.log'.format(DIR)
        set_seed(SEED)
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.config = args
        self.device = device

        if 'transformer' in MODEL:
            ENC_HEADS = args["encoder_heads"]
            DEC_HEADS = args["decoder_heads"]
            ENC_PF_DIM = args["encoder_pf_dim"]
            DEC_PF_DIM = args["decoder_pf_dim"]
            MAX_LEN = args["max_len"]
            
        SRC = Field(tokenize = lambda text: tokenize_de(text, REVERSE), 
                    init_token = '<sos>', 
                    eos_token = '<eos>', 
                    lower = True)
        TGT = Field(tokenize = tokenize_en, 
                    init_token = '<sos>', 
                    eos_token = '<eos>', 
                    lower = True)
        GRH = RawField(postprocessing=batch_graph)
        data_fields = [('src', SRC), ('trg', TGT), ('grh', GRH)]
        
        train_data = Dataset(torch.load("data/Multi30k/train_data.pt"), data_fields)
        valid_data = Dataset(torch.load("data/Multi30k/valid_data.pt"), data_fields)
        test_data = Dataset(torch.load("data/Multi30k/test_data.pt"), data_fields)
        self.train_data, self.valid_data, self.test_data = train_data, valid_data, test_data
        
        train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
            (train_data, valid_data, test_data), 
            batch_size = BATCH_SIZE, 
            sort_key = lambda x: len(x.src),
            sort_within_batch=False,
            device = device)
        self.train_iterator, self.valid_iterator, self.test_iterator = train_iterator, valid_iterator, test_iterator
        
        SRC.build_vocab(train_data, min_freq = 2)
        TGT.build_vocab(train_data, min_freq = 2)
        self.SRC, self.TGT, self.GRH = SRC, TGT, GRH

        print(f"Number of training examples: {len(train_data.examples)}")
        print(f"Number of validation examples: {len(valid_data.examples)}")
        print(f"Number of testing examples: {len(test_data.examples)}")
        print(f"Unique tokens in source (de) vocabulary: {len(SRC.vocab)}")
        print(f"Unique tokens in target (en) vocabulary: {len(TGT.vocab)}")

        src_c, tgt_c = get_sentence_lengths(train_data)
        src_lengths = counter2array(src_c)
        tgt_lengths = counter2array(tgt_c)

        print("maximum src, tgt sent lengths: ")
        np.quantile(src_lengths, 1), np.quantile(tgt_lengths, 1)

        # Get models and corresponding training scripts

        INPUT_DIM = len(SRC.vocab)
        OUTPUT_DIM = len(TGT.vocab)
        SRC_PAD_IDX = SRC.vocab.stoi[SRC.pad_token]
        TGT_PAD_IDX = TGT.vocab.stoi[TGT.pad_token]
        self.SRC_PAD_IDX = SRC_PAD_IDX
        self.TGT_PAD_IDX = TGT_PAD_IDX

        if MODEL == "gru**2":  # gru**2, gru_attn**2, transformer, gcn_gru
            from models.gru_seq2seq import GRUEncoder, GRUDecoder, Seq2Seq
            enc = GRUEncoder(INPUT_DIM, ENC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, NLAYERS, ENC_DROPOUT)
            dec = GRUDecoder(OUTPUT_DIM, DEC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, NLAYERS, DEC_DROPOUT)
            model = Seq2Seq(enc, dec, device).to(device)

            from src.train import train_epoch_gru, evaluate_gru, epoch_time
            train_epoch = train_epoch_gru
            evaluate = evaluate_gru
            
            self.enc, self.dec, self.model, self.train_epoch, self.evaluate = enc, dec, model, train_epoch, evaluate
            
        elif MODEL == "gru_attn**2":
            from models.gru_attn import GRUEncoder, GRUDecoder, Seq2Seq, Attention
            attn = Attention(ENC_HID_DIM, DEC_HID_DIM)
            enc = GRUEncoder(INPUT_DIM, ENC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, NLAYERS, ENC_DROPOUT)
            dec = GRUDecoder(OUTPUT_DIM, DEC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, NLAYERS, DEC_DROPOUT, attn)
            model = Seq2Seq(enc, dec, device).to(device)

            from src.train import train_epoch_gru_attn, evaluate_gru_attn, epoch_time
            train_epoch = train_epoch_gru_attn
            evaluate = evaluate_gru_attn
            
            self.enc, self.dec, self.model, self.train_epoch, self.evaluate, self.attn = enc, dec, model, train_epoch, evaluate, attn

        elif MODEL == "transformer":
            from models.transformer import Encoder, Decoder, Seq2Seq
            enc = Encoder(INPUT_DIM, ENC_HID_DIM, NLAYERS, ENC_HEADS, 
                          ENC_PF_DIM, ENC_DROPOUT, device, MAX_LEN)
            dec = Decoder(OUTPUT_DIM, DEC_HID_DIM, NLAYERS, DEC_HEADS, 
                          DEC_PF_DIM, DEC_DROPOUT, device, MAX_LEN)
            model = Seq2Seq(enc, dec, SRC_PAD_IDX, TGT_PAD_IDX, device).to(device)

            from src.train import train_epoch_tfmr, evaluate_tfmr, epoch_time
            train_epoch = train_epoch_tfmr
            evaluate = evaluate_tfmr

            self.enc, self.dec, self.model, self.train_epoch, self.evaluate = enc, dec, model, train_epoch, evaluate
            
        elif MODEL == "gcn_gru":
            from models.gru_seq2seq import GCNEncoder, GRUDecoder, GCN2Seq
            enc = GCNEncoder(INPUT_DIM, ENC_EMB_DIM, ENC_HID_DIM, NLAYERS, ENC_DROPOUT)
            dec = GRUDecoder(OUTPUT_DIM, DEC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, NLAYERS, DEC_DROPOUT)
            model = GCN2Seq(enc, dec, device).to(device)

            from src.train import train_epoch_gcn_gru, evaluate_gcn_gru, epoch_time
            train_epoch = train_epoch_gcn_gru
            evaluate = evaluate_gcn_gru

            self.enc, self.dec, self.model, self.train_epoch, self.evaluate = enc, dec, model, train_epoch, evaluate
            
        elif MODEL == "gcngru_gru":
            from models.gru_seq2seq import GCNGRUEncoder, GRUDecoder, GCN2Seq
            enc = GCNGRUEncoder(INPUT_DIM, ENC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, NLAYERS, ENC_DROPOUT, device)
            dec = GRUDecoder(OUTPUT_DIM, DEC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, NLAYERS, DEC_DROPOUT)
            model = GCN2Seq(enc, dec, device).to(device)

            from src.train import train_epoch_gcn_gru, evaluate_gcn_gru, epoch_time
            train_epoch = train_epoch_gcn_gru
            evaluate = evaluate_gcn_gru

            self.enc, self.dec, self.model, self.train_epoch, self.evaluate = enc, dec, model, train_epoch, evaluate
            
        elif MODEL == "gcnattn_gru":
            from models.gru_attn import GCNEncoder, GRUDecoder, GCN2Seq, Attention
            attn = Attention(ENC_HID_DIM, DEC_HID_DIM)
            enc = GCNEncoder(INPUT_DIM, ENC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, NLAYERS, ENC_DROPOUT)
            dec = GRUDecoder(OUTPUT_DIM, DEC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, NLAYERS, DEC_DROPOUT, attn)
            model = GCN2Seq(enc, dec, device).to(device)

            from src.train import train_epoch_gcnattn_gru, evaluate_gcnattn_gru, epoch_time
            train_epoch = train_epoch_gcnattn_gru
            evaluate = evaluate_gcnattn_gru
            
            self.enc, self.dec, self.model, self.train_epoch, self.evaluate, self.attn = enc, dec, model, train_epoch, evaluate, attn

        elif MODEL == "gcngruattn_gru":
            from models.gru_attn import GCNGRUEncoder, GRUDecoder, GCN2Seq, Attention
            attn = Attention(ENC_HID_DIM, DEC_HID_DIM)
            enc = GCNGRUEncoder(INPUT_DIM, ENC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, NLAYERS, ENC_DROPOUT, device)
            dec = GRUDecoder(OUTPUT_DIM, DEC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, NLAYERS, DEC_DROPOUT, attn)
            model = GCN2Seq(enc, dec, device).to(device)

            from src.train import train_epoch_gcnattn_gru, evaluate_gcnattn_gru, epoch_time
            train_epoch = train_epoch_gcnattn_gru
            evaluate = evaluate_gcnattn_gru
            
            self.enc, self.dec, self.model, self.train_epoch, self.evaluate, self.attn = enc, dec, model, train_epoch, evaluate, attn

        else:
            raise ValueError("Wrong model choice")

        if 'gcn' in MODEL:
            from src.utils import init_weights_uniform as init_weights
        else: 
            from src.utils import init_weights_xavier as init_weights

        model.apply(init_weights)
        n_params = count_parameters(model)
        print("Model initialized...{} params".format(n_params))
        
        self.criterion = nn.CrossEntropyLoss(ignore_index=TGT_PAD_IDX)
        
        print(os.path.join(MODEL_PATH, "checkpoint.pt"))
#         try:
#             state_dict = torch.load(os.path.join(MODEL_PATH, "checkpoint.pt"), map_location=device)['model_state_dict']
#         except:
#             state_dict = torch.load(os.path.join(MODEL_PATH, "checkpoint.pt"), map_location=device)
        state_dict = torch.load(os.path.join(MODEL_PATH, "checkpoint.pt"), map_location=device)
        if 'model_state_dict' in state_dict:
            state_dict = state_dict['model_state_dict']
        model.load_state_dict(state_dict)
        self.model = model
Example #24
def main():
    title='dump-trace'
    argParser = config.get_arg_parser(title)
    args = argParser.parse_args()
    max_len_trg = 0
    max_len_src = 0
    sys.modules['Tree'] = Tree

    with open(args.golden_c_path,'rb') as file_c:
        trg = pickle.load(file_c)


    SEED=1234
    torch.manual_seed(SEED)
    torch.backends.cudnn.deterministic = True

    exp_list = []
    SRC = Field(
                init_token = '<sos>',
                eos_token = '<eos>')
    TRG = RawField()
    ID  = RawField()
    DICT_INFO = RawField()

    cache_dir = args.cache_path
    src_g = np.load(args.input_g_path, allow_pickle=True)
    src_f = np.load(args.input_f_path, allow_pickle=True)

    for i in range(0,args.gen_num):
        src_elem = src_f[i]
        dict_info = 0
        trg_elem = trg[i]['tree']
        exp = Example.fromlist([src_elem, trg_elem, i, dict_info],
                               fields=[('src', SRC), ('trg', TRG),
                                       ('id', ID), ('dict_info', DICT_INFO)])
        exp_list.append(exp)

        len_elem_src = len(src_elem)
        len_elem_trg = trg[i]['treelen']

        if len_elem_src + 2 >= max_len_src:
            max_len_src = len_elem_src  + 2
        if len_elem_trg >= max_len_trg:
            max_len_trg = len_elem_trg + 2
    data_sets = Dataset(exp_list, fields=[('src', SRC), ('trg', TRG),
                                          ('id', ID), ('dict_info', DICT_INFO)])
    trn, vld = data_sets.split([0.8, 0.2, 0.0])
    SRC.build_vocab(trn, min_freq = 2)

    print("Number of training examples: %d" % (len(trn.examples)))
    print("Number of validation examples: %d" % (len(vld.examples)))
    print("Unique tokens in source assembly vocabulary: %d "%(len(SRC.vocab)))
    print("Max input length : %d" % (max_len_src))
    print("Max output length : %d" % (max_len_trg))
    del trg, src_f, src_g

    BATCH_SIZE = 1

    train_iterator, valid_iterator = BucketIterator.splits(
        (trn, vld),
        batch_size = BATCH_SIZE,
        sort_key=lambda x: len(x.trg),
        sort_within_batch=False,
        sort=False)

    processing_data(cache_dir, [train_iterator, valid_iterator])
Example #25
    agg_test_df = pd.read_csv(agg_test_filepath)

    # We can use a different batch size for generation
    batch_size = g_conf["batch_size"]

    # Load dataset
    # ================================================================
    fields = {
        "eid": ("eid", ID),
        "rid": ("rid", ID),
        "review": ("out_text", OUT_TEXT),
        "input_text": ("in_text", IN_TEXT)
    }

    TEST_EID = RawField()
    agg_fields = {"eid": ("eid", TEST_EID), "input_text": ("in_text", IN_TEXT)}

    train = TabularDataset(path=train_filepath, format="csv", fields=fields)
    valid = TabularDataset(path=valid_filepath, format="csv", fields=fields)
    test = TabularDataset(path=test_filepath, format="csv", fields=fields)
    agg_test = TabularDataset(path=agg_test_filepath,
                              format="csv",
                              fields=agg_fields)

    train_iterator = BucketIterator(train,
                                    batch_size=batch_size,
                                    device=device,
                                    sort=False,
                                    sort_within_batch=False)
    valid_iterator = BucketIterator(valid,
Example #26
    DEC_HEADS = args["decoder_heads"]
    ENC_PF_DIM = args["encoder_pf_dim"]
    DEC_PF_DIM = args["decoder_pf_dim"]
    MAX_LEN = args["max_len"]

# dataset

SRC = Field(tokenize=lambda text: tokenize_de(text, REVERSE),
            init_token='<sos>',
            eos_token='<eos>',
            lower=True)
TGT = Field(tokenize=tokenize_en,
            init_token='<sos>',
            eos_token='<eos>',
            lower=True)
GRH = RawField(postprocessing=batch_graph)
data_fields = [('src', SRC), ('trg', TGT), ('grh', GRH)]

train_data = Dataset(torch.load("data/Multi30k/train_data.pt"), data_fields)
valid_data = Dataset(torch.load("data/Multi30k/valid_data.pt"), data_fields)
test_data = Dataset(torch.load("data/Multi30k/test_data.pt"), data_fields)

# dataloader
train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
    (train_data, valid_data, test_data),
    batch_size=BATCH_SIZE,
    sort_key=lambda x: len(x.src),
    sort_within_batch=False,
    device=device)

# build vocab and print basic stats
def main():
    title = 'trf-tree'
    sys.modules['Tree'] = Tree
    argParser = config.get_arg_parser(title)
    args = argParser.parse_args()
    args.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    args.summary = TrainingSummaryWriter(args.log_dir)
    logging = get_logger(log_path=os.path.join(
        args.log_dir, "log" + time.strftime('%Y%m%d-%H%M%S') + '.txt'),
                         print_=True,
                         log_=True)

    max_len_trg, max_len_src = 0, 0
    if not os.path.exists(args.checkpoint_path):
        os.makedirs(args.checkpoint_path)
    with open(args.golden_c_path, 'rb') as file_c:
        trg = pickle.load(file_c)

    src_g = np.load(args.input_g_path, allow_pickle=True)
    src_f = np.load(args.input_f_path, allow_pickle=True)

    graphs_asm = load_graphs(args, src_f, src_g)

    SEED = 1234
    torch.manual_seed(SEED)
    torch.backends.cudnn.deterministic = True

    exp_list = []
    SRC = Field(init_token='<sos>', eos_token='<eos>')
    TRG = RawField()
    ID = RawField()
    DICT_INFO = RawField()
    GRAPHS_ASM = RawField()
    NODE_NUM = RawField()
    # args.gen_num = 500

    for i in range(0, args.gen_num):
        src_elem = src_f[i]
        broken_file_flag = 0
        # edge_elem = src_g[i]
        if args.dump_trace:
            dict_info = {}
            for path in glob.glob(os.path.join(args.cache_path,
                                               str(i) + '/*')):
                if os.path.getsize(path) > 0:
                    with open(path, 'rb') as f:
                        dict_info[path] = pickle.load(f)
                else:
                    print("broken file!" + path)
                    broken_file_flag = 1
                    break

        if broken_file_flag == 1:
            continue

        if dict_info == {}:
            print(dict_info)
            continue
        trg_elem = trg[i]['tree']
        len_elem_src = graphs_asm[i].number_of_nodes()
        exp = Example.fromlist([src_elem,trg_elem,i, dict_info, graphs_asm[i], len_elem_src], \
            fields =[('src', SRC), ('trg', TRG), ('id', ID), ('dict_info', DICT_INFO), ('graphs_asm', GRAPHS_ASM), ('src_len', NODE_NUM)] )
        exp_list.append(exp)
        len_elem_trg = trg[i]['treelen']

        if len_elem_src >= max_len_src:
            max_len_src = len_elem_src + 2
        if len_elem_trg >= max_len_trg:
            max_len_trg = len_elem_trg + 2

    data_sets = Dataset(exp_list,
                        fields=[('src', SRC), ('trg', TRG), ('id', ID),
                                ('dict_info', DICT_INFO),
                                ('graphs_asm', GRAPHS_ASM),
                                ('src_len', NODE_NUM)])
    trn, vld, tst = data_sets.split([0.8, 0.05, 0.15])
    SRC.build_vocab(trn, min_freq=2)

    logging("Number of training examples: %d" % (len(trn.examples)))
    logging("Number of validation examples: %d" % (len(vld.examples)))
    logging("Number of testing examples: %d" % (len(tst.examples)))
    logging("Unique tokens in source assembly vocabulary: %d " %
            (len(SRC.vocab)))
    logging("Max input length : %d" % (max_len_src))
    logging("Max output length : %d" % (max_len_trg))
    print(args.device)

    num_workers = 0

    collate = text_data_collator(trn)
    train_iterator = DataLoader(trn,
                                batch_size=args.bsz,
                                collate_fn=collate,
                                num_workers=num_workers,
                                shuffle=False)
    collate = text_data_collator(vld)
    valid_iterator = DataLoader(vld,
                                batch_size=args.bsz,
                                collate_fn=collate,
                                num_workers=num_workers,
                                shuffle=False)
    collate = text_data_collator(tst)
    test_iterator = DataLoader(tst,
                               batch_size=args.bsz,
                               collate_fn=collate,
                               num_workers=num_workers,
                               shuffle=False)

    best_valid_loss = float('inf')
    INPUT_DIM = len(SRC.vocab)

    gnn_asm = Graph_NN(annotation_size=len(SRC.vocab),
                       out_feats=args.hid_dim,
                       n_steps=args.n_gnn_layers,
                       device=args.device)

    gnn_ast = Graph_NN(annotation_size=None,
                       out_feats=args.hid_dim,
                       n_steps=args.n_gnn_layers,
                       device=args.device)

    enc = Encoder(INPUT_DIM,
                  args.hid_dim,
                  args.n_layers,
                  args.n_heads,
                  args.pf_dim,
                  args.dropout,
                  args.device,
                  args.mem_dim,
                  embedding_flag=args.embedding_flag,
                  max_length=max_len_src)

    dec = Decoder_AST(args.output_dim,
                      args.hid_dim,
                      args.n_layers,
                      args.n_heads,
                      args.pf_dim,
                      args.dropout,
                      args.device,
                      max_length=max_len_trg)

    SRC_PAD_IDX = 0
    TRG_PAD_IDX = 0

    if args.parallel_gpu:
        enc = torch.nn.DataParallel(enc)
        dec = torch.nn.DataParallel(dec)

    model = Transformer(enc,
                        dec,
                        SRC_PAD_IDX,
                        TRG_PAD_IDX,
                        args.device,
                        gnn=gnn_ast,
                        gnn_asm=gnn_asm).to(args.device)

    model.apply(initialize_weights)
    criterion = nn.CrossEntropyLoss(ignore_index=TRG_PAD_IDX)
    optimizer = NoamOpt(args.hid_dim, args.lr_ratio, args.warmup, \
                torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))

    if args.training and not args.eval:
        for epoch in range(args.n_epoch):
            start_time = time.time()
            train_loss, train_acc = train_eval_tree(args, model, train_iterator, optimizer\
                                , args.device, criterion, max_len_trg, train_flag=True)

            valid_loss, valid_acc = train_eval_tree(args, model, valid_iterator, None\
                                , args.device, criterion, max_len_trg, train_flag=False)
            end_time = time.time()

            epoch_mins, epoch_secs = epoch_time(start_time, end_time)

            if valid_loss < best_valid_loss and (args.checkpoint_path
                                                 is not None):
                best_valid_loss = valid_loss
                torch.save(model.state_dict(),
                           os.path.join(args.checkpoint_path, 'model.pt'))

            logging('Epoch: %d | Time: %dm %ds | learning rate %.3f' %
                    (epoch, epoch_mins, epoch_secs, optimizer._rate * 10000))
            print_performances('Training',
                               train_loss,
                               train_acc,
                               start_time,
                               logging=logging)
            print_performances('Validation',
                               valid_loss,
                               valid_acc,
                               start_time,
                               logging=logging)
            args.summary.add_scalar('train/acc', train_acc)
            args.summary.add_scalar('valid/acc', valid_acc)

    start_time = time.time()
    model.load_state_dict(
        torch.load(os.path.join(args.checkpoint_path, 'model.pt'),
                   map_location=args.device))
    test_loss, test_acc = test_tree(args, model, test_iterator, TRG_PAD_IDX,
                                    args.device, args.label_smoothing,
                                    criterion, args.clip)
    print_performances('Test',
                       test_loss,
                       test_acc,
                       start_time,
                       logging=logging)
Example #28
    def __init__(
        self,
        proc_id=0,
        data_dir='tmp/',
        train_fname='train.csv',
        preprocessed=True,
        lower=True,
        vocab_max_size=100000,
        emb_dim=100,
        save_vocab_fname='vocab.json',
        verbose=True,
    ):
        self.verbose = verbose and (proc_id == 0)
        # Parenthesized so the conditional picks the tokenizer itself, rather
        # than being evaluated inside the lambda body on every call.
        tokenize = (lambda x: x.split()) if preprocessed else 'spacy'

        INPUT = Field(
            sequential=True,
            batch_first=True,
            tokenize=tokenize,
            lower=lower,
            # include_lengths=True,
        )
        # TGT = Field(sequential=False, dtype=torch.long, batch_first=True,
        #             use_vocab=False)
        TGT = Field(sequential=True, batch_first=True)
        SHOW_INP = RawField()
        fields = [
            ('tgt', TGT),
            ('input', INPUT),
            ('show_inp', SHOW_INP),
        ]

        if self.verbose:
            show_time("[Info] Start building TabularDataset from: {}{}".format(
                data_dir, 'train.csv'))
        datasets = TabularDataset.splits(
            fields=fields,
            path=data_dir,
            format=train_fname.rsplit('.')[-1],
            train=train_fname,
            validation=train_fname.replace('train', 'valid'),
            test=train_fname.replace('train', 'test'),
            skip_header=True,
        )
        INPUT.build_vocab(
            *datasets,
            max_size=vocab_max_size,
            vectors=GloVe(name='6B', dim=emb_dim),
            unk_init=torch.Tensor.normal_,
        )
        # load_vocab(hard_dosk) like opennmt
        # emb_dim = {50, 100}
        # Elmo
        TGT.build_vocab(*datasets)

        self.INPUT = INPUT
        self.TGT = TGT
        self.train_ds, self.valid_ds, self.test_ds = datasets

        if save_vocab_fname and self.verbose:
            writeout = {
                'tgt_vocab': {
                    'itos': TGT.vocab.itos,
                    'stoi': TGT.vocab.stoi,
                },
                'input_vocab': {
                    'itos': INPUT.vocab.itos,
                    'stoi': INPUT.vocab.stoi,
                },
            }
            fwrite(json.dumps(writeout, indent=4), save_vocab_fname)

        if self.verbose:
            msg = "[Info] Finished building vocab: {} INPUT, {} TGT" \
                .format(len(INPUT.vocab), len(TGT.vocab))
            show_time(msg)
Example #29
config = {'bert': 'bert-base-cased', 'H': 768, 'dropout': 0.2}

# parse conll dependency data
model_class, tokenizer_class, pretrained_weights = BertModel, BertTokenizer, config['bert']
tokenizer = tokenizer_class.from_pretrained(pretrained_weights)

def batch_num(nums):
    lengths = torch.tensor([len(n) for n in nums]).long()
    n = lengths.max()
    out = torch.zeros(len(nums), n).long()
    for b, n in enumerate(nums):
        out[b, :len(n)] = torch.tensor(n)
    return out, lengths

HEAD = RawField(preprocessing=lambda x: [int(i) for i in x],
        postprocessing=batch_num)
HEAD.is_target = True
WORD = SubTokenizedField(tokenizer)

def len_filt(x): return 5 < len(x.word[0]) < 40

train = ConllXDataset('wsj.train.conllx', (('word', WORD), ('head', HEAD)),
        filter_pred=len_filt)
train_iter = TokenBucket(train, 750)
val = ConllXDataset('wsj.dev.conllx', (('word', WORD), ('head', HEAD)),
        filter_pred=len_filt)
val_iter = BucketIterator(val, batch_size=20, device='cuda:0')

# make bert model to compute potentials
H = config['H']
class Model(nn.Module):
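The snippet above is cut off at the model definition, but the HEAD field is worth a standalone check: a RawField applies preprocessing per example and postprocessing per batch, so batch_num receives the full list of head sequences. A minimal sketch reusing only the definitions above (no .conllx files needed):

heads = [HEAD.preprocess(["2", "0", "2"]), HEAD.preprocess(["0", "1"])]
padded, lengths = HEAD.process(heads)
print(padded.shape, lengths)  # torch.Size([2, 3]) tensor([3, 2])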
Example #30
def import_corpus(
    path: str,
    header: Optional[List[str]] = None,
    header_from_first_line: bool = False,
    to_lower: bool = False,
    vocab_path: Optional[str] = None,
    vocab_from_corpus: bool = False,
    sen_column: str = "sen",
) -> TabularDataset:

    """ Imports a corpus from a path.

    The corpus can either be a raw string or a pickled dictionary.
    Outputs a `Corpus` type, that is used throughout the library.

    The raw sentences are assumed to be stored under a `sen` or `sent` column.
    Sentences may optionally carry labels, which are assumed to be stored
    under a `labels` column.

    Parameters
    ----------
    path : str
        Path to corpus file
    header : List[str], optional
        Optional list of attribute names for each column. If not provided,
        all lines are considered to be sentences, with the attribute
        name "sen".
    to_lower : bool, optional
        Transform entire corpus to lower case, defaults to False.
    header_from_first_line : bool, optional
        Use the first line of the corpus as the attribute names of the
        corpus.
    vocab_path : str, optional
        Path to the model vocabulary, which should be a file containing
        one vocab entry per line.
    vocab_from_corpus : bool, optional
        Create a new vocabulary from the tokens of the corpus itself.
        If set to True `vocab_path` does not need to be provided.
        Defaults to False.
    sen_column : str, optional
        Name of the corpus column containing the raw sentences.
        Defaults to `sen`.

    Returns
    -------
    corpus : TabularDataset
        A TabularDataset containing the parsed sentences and optional labels
    """

    if header is None:
        if header_from_first_line:
            with open(path) as f:
                header = f.readline().strip().split("\t")
        else:
            header = ["sen"]

    assert sen_column in header, f"`{sen_column}` should be part of the corpus header!"

    def preprocess(s: str) -> Union[str, int]:
        return int(s) if s.isdigit() else s

    pipeline = Pipeline(convert_token=preprocess)
    fields = {}
    for field in header:
        if field == sen_column:
            fields[field] = Field(
                batch_first=True, include_lengths=True, lower=to_lower
            )
        elif field == "labels":
            fields[field] = Field(
                use_vocab=False, tokenize=lambda s: list(map(int, s.split()))
            )
        else:
            fields[field] = RawField(preprocessing=pipeline)
            fields[field].is_target = False

    corpus = TabularDataset(
        fields=fields.items(),
        format="tsv",
        path=path,
        skip_header=header_from_first_line,
        csv_reader_params={"quotechar": None},
    )

    # The current torchtext Vocab does not allow a fixed vocab order
    if vocab_path is not None or vocab_from_corpus:
        attach_vocab(corpus, vocab_path or path, sen_column=sen_column)

    return corpus
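A quick check of the digit-preserving pipeline that the non-sentence columns receive in import_corpus, assuming the legacy torchtext Pipeline and RawField used throughout:

from torchtext.data import Pipeline, RawField

pipeline = Pipeline(convert_token=lambda s: int(s) if s.isdigit() else s)
extra_column = RawField(preprocessing=pipeline)
print(extra_column.preprocess("42"), extra_column.preprocess("label_a"))  # 42 label_a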