import os
import json

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm

# InputData, SkipGramModel, TimestampedSkipGramModel, DataReader, Word2vecDataset,
# TimestampledWord2vecDataset and dump_embedding are expected to come from the
# project's own modules.


class Word2Vec:
    def __init__(
        self,
        input_path,
        output_dir,
        wordsim_path,
        dimension=100,
        batch_size=64,  # assumed default; the original default referenced an undefined name
        window_size=5,
        epoch_count=1,
        initial_lr=1e-6,
        min_count=5,
    ):
        self.data = InputData(input_path, min_count)
        self.output_dir = output_dir
        self.vocabulary_size = len(self.data.id_from_word)
        self.dimension = dimension
        self.batch_size = batch_size
        self.window_size = window_size
        self.epoch_count = epoch_count
        self.initial_lr = initial_lr
        self.model = SkipGramModel(self.vocabulary_size, self.dimension)
        if torch.cuda.is_available():
            self.device = torch.device('cuda')
        else:
            self.device = torch.device('cpu')
        self.model = nn.DataParallel(self.model.to(self.device))
        self.optimizer = optim.SGD(self.model.parameters(), lr=self.initial_lr)

        if wordsim_path:
            # Load (word1, word2, human similarity) triples for periodic evaluation.
            self.wordsim_verification_tuples = []
            with open(wordsim_path, 'r') as f:
                f.readline()  # Discard header
                for line in f:
                    word1, word2, actual_similarity = line.split(',')
                    self.wordsim_verification_tuples.append(
                        (word1, word2, float(actual_similarity))
                    )
        else:
            self.wordsim_verification_tuples = None

    def train(self):
        pair_count = self.data.get_pair_count(self.window_size)
        batch_count = self.epoch_count * pair_count / self.batch_size
        best_rho = float('-inf')
        for i in tqdm(range(int(batch_count)), total=batch_count):
            self.model.train()
            pos_pairs = self.data.get_batch_pairs(
                self.batch_size, self.window_size
            )
            neg_v = self.data.get_neg_v_neg_sampling(pos_pairs, 5)
            pos_u = [pair[0] for pair in pos_pairs]
            pos_v = [pair[1] for pair in pos_pairs]

            pos_u = torch.tensor(pos_u, device=self.device)
            pos_v = torch.tensor(pos_v, device=self.device)
            neg_v = torch.tensor(neg_v, device=self.device)

            self.optimizer.zero_grad()
            loss = self.model(pos_u, pos_v, neg_v)
            loss.backward()
            self.optimizer.step()

            if i % 250 == 0:
                self.model.eval()
                rho = self.model.module.get_wordsim_rho(
                    self.wordsim_verification_tuples,
                    self.data.id_from_word,
                    self.data.word_from_id,
                )
                print(
                    f'Loss: {loss.item()},'
                    f' lr: {self.optimizer.param_groups[0]["lr"]},'
                    f' rho: {rho}'
                )
                dump_embedding(
                    self.model.module.get_embedding(
                        self.data.id_from_word, self.data.word_from_id
                    ),
                    self.model.module.dimension,
                    self.data.word_from_id,
                    os.path.join(self.output_dir, 'latest.txt'),
                )
                if rho > best_rho:
                    dump_embedding(
                        self.model.module.get_embedding(
                            self.data.id_from_word, self.data.word_from_id
                        ),
                        self.model.module.dimension,
                        self.data.word_from_id,
                        os.path.join(self.output_dir, f'{i}_{rho}.txt'),
                    )
                    best_rho = rho

            # Linear warm-up over the first 10k batches, then periodic linear decay.
            if i < 10000:
                lr = self.initial_lr * i / 10000
                for param_group in self.optimizer.param_groups:
                    param_group['lr'] = lr
            elif i * self.batch_size % 100000 == 0:
                lr = self.initial_lr * (1.0 - 1.0 * i / batch_count)
                for param_group in self.optimizer.param_groups:
                    param_group['lr'] = lr
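# Usage sketch (illustrative, not part of the original file). It assumes the Word2Vec
# class above is importable; every path below is a hypothetical placeholder, and the
# wordsim file is expected to have a header line followed by word1,word2,score rows,
# which is the format the constructor parses.
def run_word2vec_example():
    w2v = Word2Vec(
        input_path='data/corpus.txt',     # placeholder training corpus
        output_dir='output',              # placeholder directory for dumped embeddings
        wordsim_path='data/wordsim.csv',  # placeholder similarity benchmark CSV
        dimension=100,
        batch_size=64,
        window_size=5,
        epoch_count=1,
        initial_lr=1e-6,
        min_count=5,
    )
    w2v.train()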
class Word2VecTrainer:
    def __init__(self, args):
        # Legacy signature: input_file, output_file, emb_dimension=100, batch_size=32,
        # window_size=5, iterations=3, initial_lr=0.01, min_count=25, weight_decay=0, time_scale=1
        self.data, self.dataloader = self.load_train(args)

        if "train" in args.text:
            test_filename = args.text.replace("train", "test")
            if os.path.exists(test_filename):
                print("load test dataset: {}".format(test_filename))
                self.test = self.load_train(
                    args, data=self.data, filename=test_filename, is_train=False
                )
            else:
                self.test = None
            dev_filename = args.text.replace("train", "dev")
            if os.path.exists(dev_filename):
                print("load dev dataset: {}".format(dev_filename))
                self.dev = self.load_train(
                    args, data=self.data, filename=dev_filename, is_train=False
                )
            else:
                self.dev = None
        else:
            self.dev, self.test = None, None

        if args.use_time:
            self.output_file_name = "{}/{}".format(args.output, args.time_type)
            if args.add_phase_shift:
                self.output_file_name += "_shift"
        else:
            self.output_file_name = "{}/{}".format(args.output, "word2vec")
        if not os.path.exists(args.output):
            os.mkdir(args.output)
        if not os.path.exists(self.output_file_name):
            os.mkdir(self.output_file_name)

        self.emb_size = len(self.data.word2id)
        self.emb_dimension = args.emb_dimension
        self.batch_size = args.batch_size
        self.iterations = args.iterations
        self.lr = args.lr
        self.time_type = args.time_type
        self.weight_decay = args.weight_decay
        print(args)

        if args.use_time:
            self.skip_gram_model = TimestampedSkipGramModel(
                self.emb_size,
                self.emb_dimension,
                time_type=args.time_type,
                add_phase_shift=args.add_phase_shift,
            )
        else:
            self.skip_gram_model = SkipGramModel(self.emb_size, self.emb_dimension)

        self.use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if self.use_cuda else "cpu")
        if self.use_cuda:
            print("using cuda and GPU ....")
            self.skip_gram_model.cuda()

        if not args.from_scatch and os.path.exists(self.output_file_name):
            print("loading parameters ....")
            self.skip_gram_model.load_embeddings(self.data.id2word, self.output_file_name)

    def load_train(self, args, data=None, filename=None, is_train=True):
        if data is None:
            assert is_train is True, "wrong to load data 1"
            data = DataReader(args.text, args.min_count)
            filename = args.text
        else:
            assert is_train is False, "wrong to load test data 2"
            assert filename is not None, "wrong to load test data 3"
            assert data is not None, "wrong to load test data 4"

        if not args.use_time:
            dataset = Word2vecDataset(data, input_text=filename, window_size=args.window_size)
        else:
            dataset = TimestampledWord2vecDataset(
                data,
                input_text=filename,
                window_size=args.window_size,
                time_scale=args.time_scale,
            )
        # Shuffle only for the training split.
        dataloader = DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=is_train,
            num_workers=0,
            collate_fn=dataset.collate,
        )
        if is_train:
            return data, dataloader
        else:
            return dataloader

    def evaluation_loss(self, logger=None):
        results = []
        self.skip_gram_model.eval()
        print("evaluating ...")
        for index, dataloader in enumerate([self.dev, self.test]):
            if dataloader is None:
                continue
            losses = []
            for i, sample_batched in enumerate(tqdm(dataloader)):
                if len(sample_batched[0]) > 1:
                    pos_u = sample_batched[0].to(self.device)
                    pos_v = sample_batched[1].to(self.device)
                    neg_v = sample_batched[2].to(self.device)
                    if args.use_time:
                        time = sample_batched[3].to(self.device)
                        loss, pos, neg = self.skip_gram_model.forward(pos_u, pos_v, neg_v, time)
                    else:
                        loss, pos, neg = self.skip_gram_model.forward(pos_u, pos_v, neg_v)
                    losses.append(loss.item())
            mean_result = np.array(losses).mean()
            results.append(mean_result)
            print("test{} loss is {}".format(index, mean_result))
            logger.write("Loss in test{}: {} \n".format(index, str(mean_result)))
            logger.flush()
        self.skip_gram_model.train()
        return results

    def train(self):
        print(os.path.join(self.output_file_name, "log.txt"))
        if not os.path.exists(self.output_file_name):
            os.mkdir(self.output_file_name)
        optimizer = optim.Adam(
            self.skip_gram_model.parameters(), lr=self.lr, weight_decay=self.weight_decay
        )
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, len(self.dataloader) * self.iterations
        )

        with open("{}/log.txt".format(self.output_file_name), "w") as f:
            for iteration in range(self.iterations):
                print("\nIteration: " + str(iteration + 1))
                f.write(str(args) + "\n")
                running_loss = 0.0
                for i, sample_batched in enumerate(tqdm(self.dataloader)):
                    if len(sample_batched[0]) > 1:
                        pos_u = sample_batched[0].to(self.device)
                        pos_v = sample_batched[1].to(self.device)
                        neg_v = sample_batched[2].to(self.device)

                        optimizer.zero_grad()
                        if args.use_time:
                            time = sample_batched[3].to(self.device)
                            loss, pos, neg = self.skip_gram_model.forward(pos_u, pos_v, neg_v, time)
                        else:
                            loss, pos, neg = self.skip_gram_model.forward(pos_u, pos_v, neg_v)
                        loss.backward()
                        optimizer.step()
                        scheduler.step()

                        loss, pos, neg = loss.item(), pos.item(), neg.item()
                        if i % args.log_step == 0:
                            f.write("Loss in {} steps: {} {}, {}\n".format(i, str(loss), str(pos), str(neg)))
                            if not torch.cuda.is_available() or i % (args.log_step * 10) == 0:
                                print("Loss in {} steps: {} {}, {}\n".format(i, str(loss), str(pos), str(neg)))

                self.evaluation_loss(logger=f)

                # Save a checkpoint and embeddings for this iteration.
                epoch_path = os.path.join(self.output_file_name, str(iteration))
                if not os.path.exists(epoch_path):
                    os.mkdir(epoch_path)
                torch.save(self.skip_gram_model, os.path.join(epoch_path, "pytorch.bin"))
                self.skip_gram_model.save_embedding(
                    self.data.id2word, os.path.join(self.output_file_name, str(iteration))
                )
                self.skip_gram_model.save_in_text_format(
                    self.data.id2word, os.path.join(self.output_file_name, str(iteration))
                )

        # Final model, embeddings and run configuration.
        self.skip_gram_model.save_in_text_format(self.data.id2word, self.output_file_name)
        torch.save(self.skip_gram_model, os.path.join(self.output_file_name, "pytorch.bin"))
        with open(os.path.join(self.output_file_name, "config.json"), "wt") as f:
            json.dump(vars(args), f, indent=4)
        self.skip_gram_model.save_dict(self.data.id2word, self.output_file_name)
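# Hypothetical command-line entry point (a sketch, not from the original file). It builds
# the `args` namespace that Word2VecTrainer and its methods read as a module-level global.
# Flag names mirror the attributes used above; defaults follow the legacy signature comment
# in __init__ where available, and the remaining defaults (log_step, time_type) are assumptions.
import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--text", required=True, help="training corpus; a path containing 'train' enables dev/test loading")
    parser.add_argument("--output", required=True, help="directory for embeddings, checkpoints and logs")
    parser.add_argument("--min_count", type=int, default=25)
    parser.add_argument("--window_size", type=int, default=5)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--emb_dimension", type=int, default=100)
    parser.add_argument("--iterations", type=int, default=3)
    parser.add_argument("--lr", type=float, default=0.01)
    parser.add_argument("--weight_decay", type=float, default=0.0)
    parser.add_argument("--log_step", type=int, default=1000)    # assumed default
    parser.add_argument("--use_time", action="store_true")
    parser.add_argument("--time_type", type=str, default="sin")  # assumed default
    parser.add_argument("--time_scale", type=float, default=1.0)
    parser.add_argument("--add_phase_shift", action="store_true")
    parser.add_argument("--from_scatch", action="store_true",
                        help="spelling kept to match the attribute read in __init__")
    args = parser.parse_args()

    trainer = Word2VecTrainer(args)
    trainer.train()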