def get_embedding(vocab, args):
    print("{}, Building embedding".format(
        datetime.datetime.now().strftime('%02y/%02m/%02d %H:%M:%S')))

    # check if loading pre-trained embeddings
    if args.bert:
        ebd = CXTEBD()
    else:
        ebd = WORDEBD(vocab)

    if args.embedding == 'avg':
        model = AVG(ebd, args)
    elif args.embedding in ['idf', 'iwf']:
        model = IDF(ebd, args)
    elif args.embedding in ['meta', 'meta_mlp']:
        model = META(ebd, args)
    elif args.embedding == 'cnn':
        model = CNN(ebd, args)
    else:
        # fail loudly instead of hitting an undefined `model` below
        raise ValueError('Invalid embedding model: {}'.format(args.embedding))

    if args.snapshot != '':
        # load pretrained embedding weights
        print("{}, Loading pretrained embedding from {}".format(
            datetime.datetime.now().strftime('%02y/%02m/%02d %H:%M:%S'),
            args.snapshot + '.ebd'))
        model.load_state_dict(torch.load(args.snapshot + '.ebd'))

    if args.cuda != -1:
        return model.cuda(args.cuda)
    else:
        return model
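# Usage sketch (hypothetical): build an average-pooled word-embedding model on
# CPU. The field names mirror what this function reads (args.bert,
# args.embedding, args.snapshot, args.cuda); the values are illustrative, and
# the chosen model class (AVG here) may read further fields from args.
from argparse import Namespace

args = Namespace(bert=False, embedding='avg', snapshot='', cuda=-1)
ebd_model = get_embedding(vocab, args)  # vocab is assumed to come from load_dataset (see below)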
def get_embedding(vocab, args):
    print("{}, Building embedding".format(datetime.datetime.now()), flush=True)

    ebd = WORDEBD(vocab, args.finetune_ebd)

    modelG = ModelG(ebd, args)
    # modelD = ModelD(ebd, args)

    print("{}, Building embedding".format(datetime.datetime.now()), flush=True)

    if args.cuda != -1:
        modelG = modelG.cuda(args.cuda)
        # modelD = modelD.cuda(args.cuda)
        return modelG  # , modelD
    else:
        return modelG  # , modelD
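# Usage sketch (hypothetical): this variant reads only args.finetune_ebd and
# args.cuda here and returns the generator-style embedding model alone (the
# discriminator path is commented out above). Values are illustrative, and
# ModelG may read further fields from args in its own constructor.
from argparse import Namespace

args = Namespace(finetune_ebd=False, cuda=-1)  # set cuda to a GPU id, e.g. 0, to move the model
modelG = get_embedding(vocab, args)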
def get_embedding(vocab, args):
    print("{}, Building embedding".format(
        datetime.datetime.now().strftime('%02y/%02m/%02d %H:%M:%S')), flush=True)

    # check if loading pre-trained embeddings
    if args.bert:
        print('Embedding type: BERT')
        ebd = CXTEBD(args.pretrained_bert,
                     cache_dir=args.bert_cache_dir,
                     finetune_ebd=args.finetune_ebd,
                     return_seq=(args.embedding != 'ebd'))
    else:
        print('Embedding type: WORDEBD')
        # WORDEBD is a neural network layer that maps word tokens to vectors
        ebd = WORDEBD(vocab, args.finetune_ebd)

    print('Using: ', args.embedding)

    if args.embedding == 'avg':
        model = AVG(ebd, args)
    elif args.embedding in ['idf', 'iwf']:
        model = IDF(ebd, args)
    elif args.embedding in ['meta', 'meta_mlp']:
        model = META(ebd, args)
    elif args.embedding == 'cnn':
        model = CNN(ebd, args)
    elif args.embedding == 'lstmatt':
        model = LSTMAtt(ebd, args)
    elif args.embedding == 'ebd' and args.bert:
        model = ebd  # use the BERT representation directly
    else:
        # fail loudly instead of hitting an undefined `model` below
        raise ValueError('Invalid embedding model: {}'.format(args.embedding))

    print("{}, Building embedding".format(
        datetime.datetime.now().strftime('%02y/%02m/%02d %H:%M:%S')), flush=True)

    if args.snapshot != '':
        # load pretrained embedding weights
        print("{}, Loading pretrained embedding from {}".format(
            datetime.datetime.now().strftime('%02y/%02m/%02d %H:%M:%S'),
            args.snapshot + '.ebd'))
        model.load_state_dict(torch.load(args.snapshot + '.ebd'))

    if args.cuda != -1:
        return model.cuda(args.cuda)
    else:
        return model
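# Usage sketch (hypothetical): use BERT as the encoder and return its
# contextual representation directly (args.embedding == 'ebd'). The field
# names mirror what this version reads; the checkpoint name and cache dir are
# illustrative, and CXTEBD may read further fields internally.
from argparse import Namespace

args = Namespace(bert=True,
                 pretrained_bert='bert-base-uncased',  # illustrative checkpoint name
                 bert_cache_dir='.bert_cache',         # illustrative cache dir
                 finetune_ebd=False,
                 embedding='ebd',
                 snapshot='',
                 cuda=-1)
bert_model = get_embedding(vocab, args)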
def load_dataset(args):
    if args.dataset == '20newsgroup':
        train_classes, val_classes, test_classes = _get_20newsgroup_classes(args)
    elif args.dataset == 'amazon':
        train_classes, val_classes, test_classes = _get_amazon_classes(args)
    elif args.dataset == 'fewrel':
        train_classes, val_classes, test_classes = _get_fewrel_classes(args)
    elif args.dataset == 'huffpost':
        train_classes, val_classes, test_classes = _get_huffpost_classes(args)
    elif args.dataset == 'reuters':
        train_classes, val_classes, test_classes = _get_reuters_classes(args)
    elif args.dataset == 'rcv1':
        train_classes, val_classes, test_classes = _get_rcv1_classes(args)
    else:
        raise ValueError(
            'args.dataset should be one of '
            '[20newsgroup, amazon, fewrel, huffpost, reuters, rcv1]')

    assert len(train_classes) == args.n_train_class
    assert len(val_classes) == args.n_val_class
    assert len(test_classes) == args.n_test_class

    if args.mode == 'finetune':
        # in finetune mode, combine train and val classes to train the base classifier
        train_classes = train_classes + val_classes
        args.n_train_class = args.n_train_class + args.n_val_class
        args.n_val_class = args.n_train_class

    tprint('Loading data from {}'.format(args.data_path))
    all_data = _load_json(args.data_path)

    tprint('Loading word vectors')
    path = os.path.join(args.wv_path, args.word_vector)
    if not os.path.exists(path):
        # download the word vectors and save them locally
        tprint('Downloading word vectors')
        import urllib.request
        urllib.request.urlretrieve(
            'https://dl.fbaipublicfiles.com/fasttext/vectors-wiki/wiki.en.vec',
            path)

    vectors = Vectors(args.word_vector, cache=args.wv_path)
    vocab = Vocab(collections.Counter(_read_words(all_data)),
                  vectors=vectors,
                  specials=['<pad>', '<unk>'],
                  min_freq=5)

    # print word embedding statistics
    wv_size = vocab.vectors.size()
    tprint('Total num. of words: {}, word vector dimension: {}'.format(
        wv_size[0], wv_size[1]))

    num_oov = wv_size[0] - torch.nonzero(
        torch.sum(torch.abs(vocab.vectors), dim=1)).size()[0]
    tprint('Num. of out-of-vocabulary words '
           '(they are initialized to zeros): {}'.format(num_oov))

    # split into meta-train, meta-val, meta-test data
    train_data, val_data, test_data = _meta_split(
        all_data, train_classes, val_classes, test_classes)
    tprint('#train {}, #val {}, #test {}'.format(
        len(train_data), len(val_data), len(test_data)))

    # convert everything into np arrays for fast data loading
    train_data = _data_to_nparray(train_data, vocab, args)
    val_data = _data_to_nparray(val_data, vocab, args)
    test_data = _data_to_nparray(test_data, vocab, args)

    # flag used to distinguish train from val/test when creating the source pool
    train_data['is_train'] = True

    stats.precompute_stats(train_data, val_data, test_data, args)

    if args.meta_w_target:
        # augment the meta model with support-set features
        if args.bert:
            ebd = CXTEBD(args.pretrained_bert,
                         cache_dir=args.bert_cache_dir,
                         finetune_ebd=False,
                         return_seq=True)
        else:
            ebd = WORDEBD(vocab, finetune_ebd=False)

        train_data['avg_ebd'] = AVG(ebd, args)
        if args.cuda != -1:
            train_data['avg_ebd'] = train_data['avg_ebd'].cuda(args.cuda)

        val_data['avg_ebd'] = train_data['avg_ebd']
        test_data['avg_ebd'] = train_data['avg_ebd']

    # in finetune mode, train_classes == val_classes, so sample the train and
    # val splits from train_data
    if args.mode == 'finetune':
        train_data, val_data = _split_dataset(train_data, args.finetune_split)

    return train_data, val_data, test_data, vocab
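# Usage sketch (hypothetical): end-to-end data and embedding setup. The field
# names mirror what load_dataset reads; paths and class counts are
# illustrative and must match the chosen dataset's fixed split, and downstream
# helpers (e.g. stats.precompute_stats and the model constructors) may read
# further fields not shown here.
from argparse import Namespace

args = Namespace(dataset='huffpost',
                 data_path='data/huffpost.json',   # hypothetical path
                 wv_path='.vector_cache',          # hypothetical cache dir
                 word_vector='wiki.en.vec',
                 n_train_class=20, n_val_class=5, n_test_class=16,
                 mode='train',
                 meta_w_target=False,
                 bert=False,
                 finetune_ebd=False,
                 embedding='avg',
                 snapshot='',
                 cuda=-1)

train_data, val_data, test_data, vocab = load_dataset(args)
model = get_embedding(vocab, args)  # any of the get_embedding variants above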