def exact_result3():
    preds = []
    labels = []
    imp_indexes = []
    metrics = ['group_auc']  # mirrors the metric list used in test()
    flag = ''
    count = 0
    for num in [20, 40, 60, 80, 100, 120, 140, 160, 180, 200, 220, 240, 260, 280, 300]:
        f1 = open('/home/dihe/cudnn_file/recommender_shuqi/MIND_data/res_' + str(num) + '_2.txt', 'r').readlines()
        for line in f1:
            line = line.strip().split(' ')
            logit = float(line[3])
            imp_index = int(line[1])
            label = int(float(line[5]))
            labels.append(label)
            preds.append(logit)
            imp_indexes.append(imp_index)
            # count distinct impression ids (the result files are grouped by impression)
            if imp_index != flag:
                flag = imp_index
                count += 1
    print('x: ', count, len(labels))
    group_labels, group_preds = group_labels_func(labels, preds, imp_indexes)
    res = cal_metric(group_labels, group_preds, metrics)
    print(res)
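# group_labels_func is used throughout this file but defined elsewhere in the
# repo; a minimal sketch of the grouping it is assumed to perform (bucketing
# labels and predictions by impression id), for reference only:
def group_labels_func_sketch(labels, preds, group_keys):
    """Collect labels/preds that share an impression id into per-group lists."""
    group_labels = {}
    group_preds = {}
    for label, pred, key in zip(labels, preds, group_keys):
        group_labels.setdefault(key, []).append(label)
        group_preds.setdefault(key, []).append(pred)
    all_keys = list(group_labels.keys())
    return [group_labels[k] for k in all_keys], [group_preds[k] for k in all_keys]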
def test(model, args):
    preds = []
    labels = []
    imp_indexes = []
    metrics = ['group_auc']
    test_file = os.path.join(args.data_dir, args.test_data_file)
    feature_file = os.path.join(args.data_dir, args.feature_file)
    history_file = os.path.join(args.data_dir, args.history_file)
    if 'last' in args.field:
        abs_file = os.path.join(args.data_dir, args.abs_file)
    else:
        abs_file = ''
    iterator = NewsIterator(batch_size=1, npratio=-1, feature_file=feature_file,
                            history_file=history_file, abs_file=abs_file,
                            field=args.field, fp16=True)
    print('test...')
    cudaid = 0
    step = 0
    with torch.no_grad():
        data_batch = iterator.load_test_data_from_file(test_file, None)
        batch_t = 0
        for imp_index, user_index, his_id, candidate_id, label, _ in data_batch:
            batch_t += len(candidate_id)
            his_id = his_id.cuda(cudaid)
            candidate_id = candidate_id.cuda(cudaid)
            logit = model(his_id, candidate_id, None, mode='validation')
            logit = list(np.reshape(np.array(logit.cpu()), -1))
            label = list(np.reshape(np.array(label), -1))
            imp_index = list(np.reshape(np.array(imp_index), -1))
            # each batch holds exactly one impression; broadcast its id over all candidates
            assert len(imp_index) == 1
            imp_index = imp_index * len(logit)
            assert len(logit) == len(label)
            assert len(logit) == len(imp_index)
            assert np.sum(np.array(label)) != 0
            labels.extend(label)
            preds.extend(logit)
            imp_indexes.extend(imp_index)
            step += 1
            if step % 100 == 0:
                print('all data: ', len(labels))
    group_labels, group_preds = group_labels_func(labels, preds, imp_indexes)
    res = cal_metric(group_labels, group_preds, metrics)
    return res['group_auc']
def test(model, args):
    # batched variant of the evaluation loop; note it shadows the
    # per-impression test() above when both are defined in this module
    preds = []
    labels = []
    imp_indexes = []
    metrics = ['group_auc']
    test_file = os.path.join(args.data_dir, args.test_data_file)
    feature_file = os.path.join(args.data_dir, args.feature_file)
    iterator = NewsIterator(batch_size=900, npratio=-1, feature_file=feature_file, field=args.field)
    print('test...')
    cudaid = 0
    with torch.no_grad():
        data_batch = iterator.load_data_from_file(test_file)
        batch_t = 0
        for imp_index, user_index, his_id, candidate_id, label in data_batch:
            batch_t += len(candidate_id)
            his_id = his_id.cuda(cudaid)
            candidate_id = candidate_id.cuda(cudaid)
            logit = model(his_id, candidate_id, None, mode='validation')
            logit = list(np.reshape(np.array(logit.cpu()), -1))
            label = list(np.reshape(np.array(label), -1))
            imp_index = list(np.reshape(np.array(imp_index), -1))
            labels.extend(label)
            preds.extend(logit)
            imp_indexes.extend(imp_index)
            print('all data: ', len(labels))
    group_labels, group_preds = group_labels_func(labels, preds, imp_indexes)
    res = cal_metric(group_labels, group_preds, metrics)
    return res['group_auc']
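# cal_metric is imported from the repo's utils; for 'group_auc' it is assumed
# to average per-impression AUC, roughly as sketched below (a sketch under
# that assumption, not the repo's actual implementation):
from sklearn.metrics import roc_auc_score

def cal_metric_sketch(group_labels, group_preds, metrics):
    res = {}
    if 'group_auc' in metrics:
        # skip impressions whose labels are all one class, where AUC is undefined
        aucs = [roc_auc_score(y, p) for y, p in zip(group_labels, group_preds)
                if len(set(y)) > 1]
        res['group_auc'] = round(float(np.mean(aucs)), 4)
    return res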
def exact_result(flag):
    preds = []
    labels = []
    imp_indexes = []
    metrics = ['group_auc']  # mirrors the metric list used in test()
    # one (logit, imp_offset, label) fallback per result shard, used for the
    # corrupted 'imp_index' lines that contain a '/home/shuqilu/' path instead
    # of a prediction; values are the hard-coded recoveries from the original
    fallbacks = [
        (0.9787183, 94112, 0),
        (0.9924859, 94112, 0),
        (0.9541026, 94116, 0),
        (0.9569805, 94117, 0),
    ]
    x = 1
    for part, (fb_logit, fb_offset, fb_label) in enumerate(fallbacks):
        f1 = open('../data/res_roberta_dot' + str(flag) + str(part) + '.txt', 'r').readlines()
        for line in f1:
            if line[:9] == 'imp_index':
                if '/home/shuqilu/' in line:
                    preds.append(fb_logit)
                    labels.append(fb_label)
                    imp_indexes.append(fb_offset + x)
                else:
                    line = line.strip().split(' ')
                    preds.append(float(line[3]))
                    labels.append(int(float(line[5])))
                    imp_indexes.append(int(line[1]) + x)
        # offset the next shard's impression ids past this shard's last id
        x = imp_indexes[-1] + 1
    group_labels, group_preds = group_labels_func(labels, preds, imp_indexes)
    res = cal_metric(group_labels, group_preds, metrics)
    print(res)
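# adjust_learning_rate, T_warm, all_iteration and lr are defined elsewhere in
# the repo; train() below assumes a warmup-style schedule along these lines
# (linear warmup for T_warm iterations, then linear decay to all_iteration).
# This is a sketch under that assumption, not necessarily the exact schedule:
def adjust_learning_rate_sketch(optimizer, iteration):
    if iteration <= T_warm:
        new_lr = lr * iteration / T_warm
    else:
        new_lr = lr * max(0.0, (all_iteration - iteration) / (all_iteration - T_warm))
    for param_group in optimizer.param_groups:
        param_group['lr'] = new_lr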
def train(cudaid, args, model):
    dist.init_process_group(backend='nccl', init_method='env://',
                            world_size=args.size, rank=cudaid)
    random.seed(1)
    np.random.seed(1)
    torch.manual_seed(1)
    torch.cuda.manual_seed(1)
    print('params: ', ' T_warm: ', T_warm, ' all_iteration: ', all_iteration, ' lr: ', lr)
    print('rank: ', cudaid)
    torch.cuda.set_device(cudaid)
    model.cuda(cudaid)
    # accumulate gradients so that size * gpu_size * accumulation_steps == batch_size
    accumulation_steps = int(args.batch_size / args.size / args.gpu_size)
    optimizer = apex.optimizers.FusedLAMB(model.parameters(), lr=lr, betas=(0.9, 0.98),
                                          eps=1e-6, weight_decay=0.0, max_grad_norm=1.0)
    model, optimizer = amp.initialize(model, optimizer, opt_level='O2')
    model = DDP(model)
    accum_batch_loss = 0
    iterator = NewsIterator(batch_size=args.gpu_size, npratio=4,
                            feature_file=os.path.join(args.data_dir, args.feature_file),
                            field=args.field)
    train_file = os.path.join(args.data_dir, args.data_file)
    print('train...', args.field)
    if cudaid == 0:
        writer = SummaryWriter(os.path.join(args.save_dir, args.log_file))
    model.train()
    batch_t = 0
    iteration = 0
    step = 0
    best_score = -1
    for epoch in range(0, 200):
        all_loss = 0
        all_batch = 0
        data_batch = iterator.load_data_from_file(train_file, cudaid, args.size)
        for imp_index, user_index, his_id, candidate_id, label in data_batch:
            batch_t += 1
            assert candidate_id.shape[1] == 2
            his_id = his_id.cuda(cudaid)
            candidate_id = candidate_id.cuda(cudaid)
            label = label.cuda(cudaid)
            loss = model(his_id, candidate_id, label)
            sample_size = candidate_id.shape[0]
            # per-sample loss, converted to bits (log base 2)
            loss = loss.sum() / sample_size / math.log(2)
            accum_batch_loss += float(loss)
            all_loss += float(loss)
            all_batch += 1
            loss = loss / accumulation_steps
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            if batch_t % accumulation_steps == 0:
                iteration += 1
                adjust_learning_rate(optimizer, iteration)
                optimizer.step()
                optimizer.zero_grad()
                if cudaid == 0:
                    print(' batch_t: ', batch_t, ' iteration: ', iteration,
                          ' epoch: ', epoch,
                          ' accum_batch_loss: ', accum_batch_loss / accumulation_steps,
                          ' lr: ', optimizer.param_groups[0]['lr'])
                    writer.add_scalar('Loss/train', accum_batch_loss / accumulation_steps, iteration)
                    writer.add_scalar('Lr/train', optimizer.param_groups[0]['lr'], iteration)
                accum_batch_loss = 0
                if iteration % 1000 == 0:
                    torch.cuda.empty_cache()
                    model.eval()
                    # expects a distributed variant of test() that returns this
                    # rank's raw (labels, preds, imp_indexes) rather than an AUC
                    labels, preds, imp_indexes = test(model, args, cudaid)
                    pred_pkl = {'labels': labels, 'preds': preds, 'imp_indexes': imp_indexes}
                    all_preds = all_gather(pred_pkl)
                    if cudaid == 0:
                        labels = np.concatenate([ele['labels'] for ele in all_preds], axis=0)
                        preds = np.concatenate([ele['preds'] for ele in all_preds], axis=0)
                        imp_indexes = np.concatenate([ele['imp_indexes'] for ele in all_preds], axis=0)
                        print('valid labels: ', len(labels))
                        group_labels, group_preds = group_labels_func(labels, preds, imp_indexes)
                        res = cal_metric(group_labels, group_preds, ['group_auc'])
                        auc = res['group_auc']
                        print('valid auc: ', auc)
                        writer.add_scalar('valid/auc', auc, step)
                        step += 1
                        if auc > best_score:
                            torch.save(model.state_dict(),
                                       os.path.join(args.save_dir, 'Plain_robert_dot_best.pkl'))
                            best_score = auc
                        print('best score: ', best_score)
                        torch.save(model.state_dict(),
                                   os.path.join(args.save_dir,
                                                'Plain_robert_dot_' + str(iteration) + '.pkl'))
                    torch.cuda.empty_cache()
                    model.train()
            if iteration >= all_iteration:
                break
        if iteration >= all_iteration:
            break
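# all_gather above is assumed to be a small helper that collects one picklable
# object per rank; torch.distributed.all_gather_object provides exactly this
# on recent PyTorch versions. A sketch under that assumption:
def all_gather_sketch(obj):
    output = [None for _ in range(dist.get_world_size())]
    dist.all_gather_object(output, obj)
    return output

# A hedged launcher sketch: train() takes the rank as its first argument and
# uses init_method='env://', so it is presumably spawned once per GPU with
# MASTER_ADDR/MASTER_PORT set. parse_args/build_model are hypothetical
# stand-ins for however the repo actually constructs args and the model.
# if __name__ == '__main__':
#     os.environ.setdefault('MASTER_ADDR', 'localhost')
#     os.environ.setdefault('MASTER_PORT', '12355')
#     args = parse_args()        # hypothetical
#     model = build_model(args)  # hypothetical
#     torch.multiprocessing.spawn(train, args=(args, model), nprocs=args.size)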