Example #1
def get_classifier(ebd_dim, args):
    tprint("Building classifier")

    model = MLP(ebd_dim, args)

    if args.cuda != -1:
        return model.cuda(args.cuda)
    else:
        return model
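A minimal usage sketch (hypothetical: it assumes MLP and tprint are defined in this module and that an argparse-style args object with a cuda field is enough to construct MLP; the embedding dimension 300 is only illustrative):

from argparse import Namespace

# Hypothetical call; cuda == -1 means "stay on CPU" in this codebase.
args = Namespace(cuda=-1)
clf = get_classifier(300, args)  # 300 = assumed embedding dimension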
Example #2
def _load_json(path):
    '''
        load data file
        @param path: str, path to the data file
        @return data: list of examples
    '''
    label = {}
    text_len = []
    with open(path, 'r', errors='ignore') as f:
        data = []
        for line in f:
            row = json.loads(line)

            # count the number of examples per label
            if int(row['label']) not in label:
                label[int(row['label'])] = 1
            else:
                label[int(row['label'])] += 1

            item = {
                'label': int(row['label']),
                'text': row['text'][:500]  # truncate the text to 500 tokens
            }

            text_len.append(len(row['text']))

            keys = ['head', 'tail', 'ebd_id']
            for k in keys:
                if k in row:
                    item[k] = row[k]

            data.append(item)

        tprint('Class balance:')

        print(label)

        tprint('Avg len: {}'.format(sum(text_len) / (len(text_len))))

        return data
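The function expects a JSONL file: one JSON object per line with at least 'label' and 'text' fields, plus optional 'head', 'tail', and 'ebd_id' fields. A small sketch of such an input and call (the file name and field values are made up for illustration):

import json

# Write a two-line toy file in the format _load_json expects; 'text' is
# assumed to be a list of tokens, matching the 500-token truncation above.
with open('toy.json', 'w') as f:
    f.write(json.dumps({'label': 0, 'text': ['a', 'small', 'example']}) + '\n')
    f.write(json.dumps({'label': 1, 'text': ['another', 'one'], 'ebd_id': 7}) + '\n')

data = _load_json('toy.json')  # -> list of {'label': ..., 'text': ..., ...} dicts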
Example #3
def get_classifier(ebd_dim, args):
    tprint("Building classifier")

    if args.classifier == 'nn':
        model = NN(ebd_dim, args)
    elif args.classifier == 'proto':
        model = PROTO(ebd_dim, args)
    elif args.classifier == 'r2d2':
        model = R2D2(ebd_dim, args)
    elif args.classifier == 'lrd2':
        model = LRD2(ebd_dim, args)
    elif args.classifier == 'routing':
        model = ROUTING(ebd_dim, args)
    elif args.classifier == 'mlp':
        # detach top layer from rest of MLP
        if args.mode == 'finetune':
            top_layer = MLP.get_top_layer(args, args.n_train_class)
            model = MLP(ebd_dim, args, top_layer=top_layer)
        # if not finetune, train MLP as a whole
        else:
            model = MLP(ebd_dim, args)
    else:
        raise ValueError('Invalid classifier. '
                         'classifier can only be: nn, proto, r2d2, lrd2, '
                         'routing, mlp.')

    if args.snapshot != '':
        # load pretrained models
        tprint("Loading pretrained classifier from {}".format(
            args.snapshot + '.clf'
            ))
        model.load_state_dict(torch.load(args.snapshot + '.clf'))

    if args.cuda != -1:
        return model.cuda(args.cuda)
    else:
        return model
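As with example #1, a hypothetical call; any args fields beyond classifier, snapshot, mode, and cuda are whatever the chosen classifier's constructor needs, which is not shown here:

from argparse import Namespace

# Illustrative: build a prototypical-network classifier, no pretrained snapshot.
args = Namespace(classifier='proto', snapshot='', mode='train', cuda=-1)
clf = get_classifier(300, args)  # 300 = assumed embedding dimension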
Example #4
def load_dataset(args):
    if args.dataset == '20newsgroup':
        train_classes, val_classes, test_classes = _get_20newsgroup_classes(
            args)
    elif args.dataset == 'amazon':
        train_classes, val_classes, test_classes = _get_amazon_classes(args)
    elif args.dataset == 'fewrel':
        train_classes, val_classes, test_classes = _get_fewrel_classes(args)
    elif args.dataset == 'huffpost':
        train_classes, val_classes, test_classes = _get_huffpost_classes(args)
    elif args.dataset == 'reuters':
        train_classes, val_classes, test_classes = _get_reuters_classes(args)
    elif args.dataset == 'rcv1':
        train_classes, val_classes, test_classes = _get_rcv1_classes(args)
    else:
        raise ValueError(
            'args.dataset should be one of '
            '[20newsgroup, amazon, fewrel, huffpost, reuters, rcv1]')

    assert (len(train_classes) == args.n_train_class)
    assert (len(val_classes) == args.n_val_class)
    assert (len(test_classes) == args.n_test_class)

    if args.mode == 'finetune':
        # in finetune, we combine train and val for training the base classifier
        train_classes = train_classes + val_classes
        args.n_train_class = args.n_train_class + args.n_val_class
        args.n_val_class = args.n_train_class

    tprint('Loading data from {}'.format(args.data_path))
    all_data = _load_json(args.data_path)

    tprint('Loading word vectors')
    path = os.path.join(args.wv_path, args.word_vector)
    if not os.path.exists(path):
        # Download the word vector and save it locally:
        tprint('Downloading word vectors')
        import urllib.request
        urllib.request.urlretrieve(
            'https://dl.fbaipublicfiles.com/fasttext/vectors-wiki/wiki.en.vec',
            path)

    vectors = Vectors(args.word_vector, cache=args.wv_path)
    vocab = Vocab(collections.Counter(_read_words(all_data)),
                  vectors=vectors,
                  specials=['<pad>', '<unk>'],
                  min_freq=5)

    # print word embedding statistics
    wv_size = vocab.vectors.size()
    tprint('Total num. of words: {}, word vector dimension: {}'.format(
        wv_size[0], wv_size[1]))

    num_oov = wv_size[0] - torch.nonzero(
        torch.sum(torch.abs(vocab.vectors), dim=1)).size()[0]
    tprint(('Num. of out-of-vocabulary words '
            '(they are initialized to zeros): {}').format(num_oov))

    # Split into meta-train, meta-val, meta-test data
    train_data, val_data, test_data = _meta_split(all_data, train_classes,
                                                  val_classes, test_classes)
    tprint('#train {}, #val {}, #test {}'.format(len(train_data),
                                                 len(val_data),
                                                 len(test_data)))

    # Convert everything into np array for fast data loading
    train_data = _data_to_nparray(train_data, vocab, args)
    val_data = _data_to_nparray(val_data, vocab, args)
    test_data = _data_to_nparray(test_data, vocab, args)

    train_data['is_train'] = True
    # this tag is used for distinguishing train/val/test when creating source pool

    stats.precompute_stats(train_data, val_data, test_data, args)

    if args.meta_w_target:
        # augment meta model by the support features
        if args.bert:
            ebd = CXTEBD(args.pretrained_bert,
                         cache_dir=args.bert_cache_dir,
                         finetune_ebd=False,
                         return_seq=True)
        else:
            ebd = WORDEBD(vocab, finetune_ebd=False)

        train_data['avg_ebd'] = AVG(ebd, args)
        if args.cuda != -1:
            train_data['avg_ebd'] = train_data['avg_ebd'].cuda(args.cuda)

        val_data['avg_ebd'] = train_data['avg_ebd']
        test_data['avg_ebd'] = train_data['avg_ebd']

    # if finetune, train_classes = val_classes and we sample train and val data
    # from train_data
    if args.mode == 'finetune':
        train_data, val_data = _split_dataset(train_data, args.finetune_split)

    return train_data, val_data, test_data, vocab
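For orientation, a sketch of the argparse-style fields this load_dataset reads; the field names come from the function body, while the values are illustrative placeholders rather than the repository's actual defaults:

from argparse import Namespace

# Illustrative args; only fields referenced by load_dataset are listed.
args = Namespace(
    dataset='huffpost',               # one of the six supported datasets
    data_path='data/huffpost.json',   # JSONL file read by _load_json
    n_train_class=20, n_val_class=5, n_test_class=16,
    mode='train',                     # 'finetune' merges train and val classes
    wv_path='.vector_cache',
    word_vector='wiki.en.vec',        # downloaded from fastText if missing
    meta_w_target=False, bert=False,  # skip the support-feature augmentation
    cuda=-1,                          # -1 keeps everything on CPU
)
train_data, val_data, test_data, vocab = load_dataset(args)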
Example #5
def load_dataset(args):
    if args.dataset == '20newsgroup':
        train_classes, val_classes, test_classes, label_dict = _get_20newsgroup_classes(
            args)
    elif args.dataset == 'amazon':
        train_classes, val_classes, test_classes, label_dict = _get_amazon_classes(
            args)
    elif args.dataset == 'fewrel':
        train_classes, val_classes, test_classes, label_dict = _get_fewrel_classes(
            args)
    elif args.dataset == 'huffpost':
        train_classes, val_classes, test_classes, label_dict = _get_huffpost_classes(
            args)
    elif args.dataset == 'reuters':
        train_classes, val_classes, test_classes, label_dict = _get_reuters_classes(
            args)
    elif args.dataset == 'rcv1':
        train_classes, val_classes, test_classes, label_dict = _get_rcv1_classes(
            args)
    else:
        raise ValueError(
            'args.dataset should be one of '
            '[20newsgroup, amazon, fewrel, huffpost, reuters, rcv1]')

    assert (len(train_classes) == args.n_train_class)
    assert (len(val_classes) == args.n_val_class)
    assert (len(test_classes) == args.n_test_class)

    print("train_classes", train_classes)
    print("val_classes", val_classes)
    print("test_classes", test_classes)

    tprint('Loading data')
    all_data = _load_json(args.data_path)
    class_names = []
    class_name_words = []
    for ld in label_dict:
        class_name_dic = {}
        class_name_dic['label'] = label_dict[ld]
        class_name_dic['text'] = ld.lower().split()
        class_names.append(class_name_dic)
        class_name_words.append(class_name_dic['text'])

    tprint('Loading word vectors')

    vectors = Vectors(args.word_vector, cache=args.wv_path)
    vocab = Vocab(collections.Counter(_read_words(all_data, class_name_words)),
                  vectors=vectors,
                  specials=['<pad>', '<unk>'],
                  min_freq=5)

    # print word embedding statistics
    wv_size = vocab.vectors.size()
    tprint('Total num. of words: {}, word vector dimension: {}'.format(
        wv_size[0], wv_size[1]))

    num_oov = wv_size[0] - torch.nonzero(
        torch.sum(torch.abs(vocab.vectors), dim=1)).size()[0]
    tprint(('Num. of out-of-vocabulary words '
            '(they are initialized to zeros): {}').format(num_oov))

    # Split into meta-train, meta-val, meta-test data
    train_data, val_data, test_data = _meta_split(all_data, train_classes,
                                                  val_classes, test_classes)
    tprint('#train {}, #val {}, #test {}'.format(len(train_data),
                                                 len(val_data),
                                                 len(test_data)))

    # Convert everything into np array for fast data loading
    class_names = _data_to_nparray(class_names, vocab, args)
    train_data = _data_to_nparray(train_data, vocab, args)
    val_data = _data_to_nparray(val_data, vocab, args)
    test_data = _data_to_nparray(test_data, vocab, args)

    train_data['is_train'] = True
    val_data['is_train'] = True
    test_data['is_train'] = True
    # this tag is used for distinguishing train/val/test when creating source pool

    temp_num = np.argsort(class_names['label'])
    class_names['label'] = class_names['label'][temp_num]
    class_names['text'] = class_names['text'][temp_num]
    class_names['text_len'] = class_names['text_len'][temp_num]

    return train_data, val_data, test_data, class_names, vocab
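Note on the class_names dict this variant returns: after _data_to_nparray and the argsort above, its rows are ordered by label id, so row i describes class i. A tiny illustration of the resulting layout (values are made up):

# Hypothetical contents after sorting by 'label':
# class_names['label']     -> array([0, 1, 2, ...])
# class_names['text'][1]   -> word-id sequence for the name of class 1
# class_names['text_len']  -> name lengths, aligned row-by-row with the above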