def loadData(args):
    '''Load the train/valid/test sequence datasets, build the vocabulary,
    optionally load pre-trained embeddings, and attach everything to args.'''
    __SequenceDataset = data.CharSequence if args.chars else data.TokenSequence
    print(__SequenceDataset.__name__)

    index = Index(initwords=['<unk>'], unkindex=0)
    train_ = __SequenceDataset(args.data, subset='train.txt', index=index, seqlen=args.bptt, skip=args.bptt).to(args.device)
    index.freeze(silent=True).tofile(os.path.join(args.data, 'vocab_chars.txt' if args.chars else 'vocab_tokens.txt'))
    test_ = __SequenceDataset(args.data, subset='test.txt', index=index, seqlen=args.bptt, skip=args.bptt).to(args.device)
    valid_ = __SequenceDataset(args.data, subset='valid.txt', index=index, seqlen=args.bptt, skip=args.bptt).to(args.device)

    # load pre-trained embedding weights
    if args.init_weights:
        # determine the type of embedding by checking its suffix
        if args.init_weights.endswith('bin'):
            preemb = FastTextEmbedding(args.init_weights, normalize=True).load()
            if args.emsize != preemb.dim():
                raise ValueError('emsize must match embedding size. Expected %d but got %d.' % (args.emsize, preemb.dim()))
        elif args.init_weights.endswith('txt'):
            preemb = TextEmbedding(args.init_weights, vectordim=args.emsize).load(normalize=True)
        elif args.init_weights.endswith('rand'):
            preemb = RandomEmbedding(vectordim=args.emsize)
        else:
            raise ValueError('Type of embedding cannot be inferred.')
        preemb = Embedding.filteredEmbedding(index.vocabulary(), preemb, fillmissing=True)
        preemb_weights = torch.Tensor(preemb.weights)
    else:
        preemb_weights = None

    eval_batch_size = 10

    __ItemSampler = RandomSampler if args.shuffle_samples else SequentialSampler
    __BatchSampler = BatchSampler if args.sequential_sampling else EvenlyDistributingSampler

    train_loader = torch.utils.data.DataLoader(
        train_,
        batch_sampler=ShufflingBatchSampler(
            __BatchSampler(__ItemSampler(train_), batch_size=args.batch_size, drop_last=True),
            shuffle=args.shuffle_batches,
            seed=args.seed),
        num_workers=0)
    test_loader = torch.utils.data.DataLoader(
        test_,
        batch_sampler=__BatchSampler(__ItemSampler(test_), batch_size=eval_batch_size, drop_last=True),
        num_workers=0)
    valid_loader = torch.utils.data.DataLoader(
        valid_,
        batch_sampler=__BatchSampler(__ItemSampler(valid_), batch_size=eval_batch_size, drop_last=True),
        num_workers=0)

    print(__ItemSampler.__name__)
    print(__BatchSampler.__name__)
    print('Shuffle training batches: ', args.shuffle_batches)

    setattr(args, 'index', index)
    setattr(args, 'ntokens', len(index))
    setattr(args, 'trainloader', train_loader)
    setattr(args, 'testloader', test_loader)
    setattr(args, 'validloader', valid_loader)
    setattr(args, 'preembweights', preemb_weights)
    setattr(args, 'eval_batch_size', eval_batch_size)

    return args
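
# Hedged usage sketch (illustrative, not part of the original training script):
# builds an argparse.Namespace with the attributes loadData() reads and runs it.
# All values below are assumptions inferred from the attribute accesses in
# loadData; the project's real CLI definition may differ.
def _example_load(data_dir='./data', device='cpu'):
    import argparse
    import torch
    args = argparse.Namespace(
        data=data_dir,               # directory containing train.txt / valid.txt / test.txt
        chars=False,                 # False: token-level sequences, True: character-level
        bptt=35,                     # sequence length (and skip) for the datasets
        device=torch.device(device),
        init_weights='',             # '' or a path ending in .bin / .txt / rand
        emsize=300,                  # embedding size; must match pre-trained vectors
        batch_size=20,
        shuffle_samples=False,       # True: RandomSampler, False: SequentialSampler
        sequential_sampling=False,   # True: BatchSampler, False: EvenlyDistributingSampler
        shuffle_batches=False,
        seed=1111,
    )
    return loadData(args)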
class TextEmbedding(Embedding):
    '''Embedding loaded from a plain-text vector file with one word and its
    vector per line, separated by `sep`.'''

    def __init__(self, txtfile, sep=' ', vectordim=300):
        self.file = txtfile
        self.vdim = vectordim
        self.separator = sep

    def load(self, skipheader=True, nlines=sys.maxsize, normalize=False):
        self.index = Index()
        print('Loading embedding from %s' % self.file)
        data_ = []
        with open(self.file, 'r', encoding='utf-8', errors='ignore') as f:
            if skipheader:
                f.readline()
            for i, line in enumerate(f):
                if i >= nlines:
                    break
                try:
                    line = line.strip()
                    splits = line.split(self.separator)
                    word = splits[0]
                    if self.index.hasWord(word):
                        continue  # keep only the first occurrence of a word
                    coefs = np.array(splits[1:self.vdim + 1], dtype=np.float32)
                    if normalize:
                        length = np.linalg.norm(coefs)
                        if length == 0:
                            length += 1e-6
                        coefs = coefs / length
                    if coefs.shape != (self.vdim,):
                        continue  # skip malformed lines
                    idx = self.index.add(word)
                    data_.append(coefs)
                    assert idx == len(data_)
                except Exception as err:
                    print('Error in line %d' % i, sys.exc_info()[0], file=sys.stderr)
                    print(' ', err, file=sys.stderr)
                    continue
        self.data = np.array(data_, dtype=np.float32)
        del data_
        return self

    def getVector(self, word):
        if not self.containsWord(word):
            print("'%s' is unknown." % word, file=sys.stderr)
            v = np.zeros(self.vdim)
            v[0] = 1
            return v
        idx = self.index.getId(word)
        return self.data[idx]

    def search(self, q, topk=4):
        if len(q.shape) == 1:
            q = np.matrix(q)
        if q.shape[1] != self.vdim:
            print('Wrong shape, expected %d dimensions but got %d.' % (self.vdim, q.shape[1]), file=sys.stderr)
            return
        D, I = self.invindex.search(q, topk)  # D = distances, I = indices
        return (I, D)

    def wordForVec(self, v):
        idx, dist = self.search(v, topk=1)
        idx = idx[0, 0]
        dist = dist[0, 0]
        sim = 1. - dist
        word = self.index.getWord(idx)
        return word, sim

    def containsWord(self, word):
        return self.index.hasWord(word)

    def vocabulary(self):
        return self.index.vocabulary()

    def dim(self):
        return self.vdim
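
# Hedged usage sketch (illustrative): loading a text-format vector file with
# TextEmbedding and looking up a single word. 'vectors.txt' and dim=300 are
# placeholder assumptions; search()/wordForVec() additionally rely on
# self.invindex (a nearest-neighbour index over self.data) being built elsewhere.
def _example_text_embedding(path='vectors.txt', dim=300):
    emb = TextEmbedding(path, vectordim=dim).load(normalize=True)
    vec = emb.getVector('the')   # falls back to a fixed unit vector for unknown words
    print(vec.shape)             # -> (dim,)
    return emb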