import torch
from tqdm import tqdm

import data  # repo-local batching / iteration helpers

# BATCH_SIZE, BATCHES_PER_EPOCH, GRAD_ACC_SIZE, LR, BERT_LR and MAX_EPOCH are
# module-level hyperparameters defined elsewhere in this repo.


def train_iteration(model, optimizer, dataset, train_pairs, qrels):
    total = 0
    model.train()
    total_loss = 0.
    with tqdm('training', total=BATCH_SIZE * BATCHES_PER_EPOCH, ncols=80,
              desc='train') as pbar:
        for record in data.iter_train_pairs(model, dataset, train_pairs,
                                            qrels, GRAD_ACC_SIZE):
            scores = model(record['query_tok'], record['query_mask'],
                           record['doc_tok'], record['doc_mask'])
            # Each training example is a (positive, negative) document pair
            # for the same query; reshape so each row holds one pair.
            count = len(record['query_id']) // 2
            scores = scores.reshape(count, 2)
            loss = torch.mean(1. - scores.softmax(dim=1)[:, 0])  # pairwise softmax
            loss.backward()
            total_loss += loss.item()
            total += count
            # Gradient accumulation: step the optimizer once per full batch.
            if total % BATCH_SIZE == 0:
                optimizer.step()
                optimizer.zero_grad()
            pbar.update(count)
            if total >= BATCH_SIZE * BATCHES_PER_EPOCH:
                return total_loss
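# A minimal, self-contained sketch of the pairwise softmax loss used above,
# on dummy scores (values hypothetical). Each row holds the scores of one
# (positive, negative) pair; the loss pushes probability mass toward
# column 0, the positive document. Not called by the training code.
def _pairwise_softmax_demo():
    scores = torch.tensor([[2.0, 0.5],
                           [1.0, 3.0]])  # (count=2, 2): [pos, neg] per pair
    return torch.mean(1. - scores.softmax(dim=1)[:, 0])  # ~0.53 here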
def train_iteration(model, optimizer, dataset, train_pairs, contentid2entity,
                    embed):
    GRAD_ACC_SIZE = 2
    total = 0
    model.train()
    total_loss = 0.
    with tqdm('training', total=BATCH_SIZE * BATCHES_PER_EPOCH, ncols=80,
              desc='train', leave=False) as pbar:
        for record in data.iter_train_pairs(model, dataset, train_pairs,
                                            GRAD_ACC_SIZE, contentid2entity):
            # Look up entity embeddings on CPU, then move them to the GPU.
            # The +1 shift presumably reserves index 0 of the embedding
            # table as a padding row for missing entities.
            query_entity = embed(record['query_entity'].cpu() + 1).cuda()
            doc_entity = embed(record['doc_entity'].cpu() + 1).cuda()
            scores = model(record['query_tok'], record['query_mask'],
                           record['doc_tok'], record['doc_mask'],
                           query_entity, doc_entity)
            count = len(record['query_id']) // 2
            scores = scores.reshape(count, 2)
            loss = torch.mean(1. - scores.softmax(dim=1)[:, 0])  # pairwise softmax
            loss.backward()
            total_loss += loss.item()
            total += count
            if total % BATCH_SIZE == 0:
                optimizer.step()
                optimizer.zero_grad()
            pbar.update(count)
            if total >= BATCH_SIZE * BATCHES_PER_EPOCH:
                return total_loss
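# Hypothetical sketch of the embedding-lookup pattern above. It assumes
# entity ids use -1 for "no entity", so the +1 shift maps them onto an
# nn.Embedding whose index 0 is a zero padding row; table sizes are made up.
def _entity_lookup_demo():
    import torch.nn as nn
    embed = nn.Embedding(num_embeddings=100 + 1, embedding_dim=8,
                         padding_idx=0)
    ids = torch.tensor([[3, -1, 7]])  # -1 marks a missing entity
    return embed(ids + 1)             # -1 -> row 0, which stays all-zero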
def duet_train_iteration(model, optimizer, dataset, train_pairs, qrels,
                         warmup_epoch, epoch):
    BATCH_SIZE = 16
    BATCHES_PER_EPOCH = 64
    GRAD_ACC_SIZE = 2
    total = 0
    model.train()
    total_loss = 0.
    total_vloss = 0.
    total_closs = 0.
    with tqdm('training', total=BATCH_SIZE * BATCHES_PER_EPOCH, ncols=80,
              desc='train', leave=False) as pbar:
        for record in data.iter_train_pairs(model, dataset, train_pairs,
                                            qrels, GRAD_ACC_SIZE):
            # The duet model returns a combined score plus the scores of
            # its two sub-rankers.
            scores, v_scores, c_scores = model(record['query_tok'],
                                               record['query_mask'],
                                               record['doc_tok'],
                                               record['doc_mask'])
            count = len(record['query_id']) // 2
            v_scores = v_scores.reshape(count, 2)
            c_scores = c_scores.reshape(count, 2)
            scores = scores.reshape(count, 2)

            # Independent losses for the two sub-rankers.
            v_loss = torch.mean(1. - v_scores.softmax(dim=1)[:, 0])  # pairwise softmax
            c_loss = torch.mean(1. - c_scores.softmax(dim=1)[:, 0])  # pairwise softmax
            # loss = v_loss + c_loss  # alternative: sum of independent losses
            # Joint loss on the combined score.
            loss = torch.mean(1. - scores.softmax(dim=1)[:, 0])  # pairwise softmax

            if epoch < warmup_epoch:
                # Initial warm-up: train the two sub-rankers independently.
                v_loss.backward()
                c_loss.backward()
            else:
                # After warm-up: optimize the combined score jointly.
                loss.backward()
            total_loss += loss.item()
            total_vloss += v_loss.item()
            total_closs += c_loss.item()
            total += count
            if total % BATCH_SIZE == 0:
                optimizer.step()
                optimizer.zero_grad()
            pbar.update(count)
            if total >= BATCH_SIZE * BATCHES_PER_EPOCH:
                return total_loss, total_vloss, total_closs
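# A minimal sketch of the warm-up switch above (names hypothetical). Note
# that summing the two warm-up losses yields the same accumulated gradients
# as two separate backward() calls, while traversing any shared part of the
# graph only once (avoiding retain_graph issues if the sub-rankers share
# computation).
def _duet_backward_demo(epoch, warmup_epoch, v_loss, c_loss, joint_loss):
    if epoch < warmup_epoch:
        (v_loss + c_loss).backward()  # train sub-rankers independently
    else:
        joint_loss.backward()         # train on the combined score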
def main(model, dataset, train_pairs, qrels, valid_run, qrelf, model_out_dir):
    # `dp`, `devices`, `TrainDataset` and `train_iteration_multi` are
    # defined elsewhere in this repo.
    params = [(k, v) for k, v in model.named_parameters() if v.requires_grad]
    # Two parameter groups: BERT weights train at their own (typically much
    # smaller) learning rate BERT_LR; everything else uses the default LR.
    non_bert_params = {
        'params': [v for k, v in params if not k.startswith('bert.')]
    }
    bert_params = {
        'params': [v for k, v in params if k.startswith('bert.')],
        'lr': BERT_LR
    }
    optimizer = torch.optim.Adam([non_bert_params, bert_params], lr=LR)
    # optimizer = torch.optim.SGD([non_bert_params, bert_params], lr=LR,
    #                             momentum=0.9)

    # model.to(device)
    model_parallel = dp.DataParallel(model, device_ids=devices)

    top_valid_score = None
    for epoch in range(MAX_EPOCH):
        # Single-process alternative:
        # loss = train_iteration(model, optimizer, dataset, train_pairs, qrels)
        # print(f'train epoch={epoch} loss={loss}')
        train_set = TrainDataset(
            it=data.iter_train_pairs(model, dataset, train_pairs, qrels, 1),
            length=BATCH_SIZE * BATCHES_PER_EPOCH)
        train_loader = torch.utils.data.DataLoader(
            train_set,
            batch_size=GRAD_ACC_SIZE,
        )
        model_parallel(train_iteration_multi, train_loader)
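# Self-contained sketch of the two-group optimizer pattern used in main()
# (the toy module and the learning-rate values are hypothetical).
def _param_group_demo():
    import torch.nn as nn

    class Toy(nn.Module):
        def __init__(self):
            super().__init__()
            self.bert = nn.Linear(4, 4)  # stands in for the BERT encoder
            self.rank = nn.Linear(4, 1)  # stands in for the ranking head

    model = Toy()
    params = list(model.named_parameters())
    optimizer = torch.optim.Adam(
        [{'params': [v for k, v in params if not k.startswith('bert.')]},
         {'params': [v for k, v in params if k.startswith('bert.')],
          'lr': 2e-5}],
        lr=1e-3)
    return [g['lr'] for g in optimizer.param_groups]  # [0.001, 2e-05]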
import torch_xla.core.xla_model as xm  # TPU variant; requires an XLA runtime


def train_iteration(model, optimizer, dataset, train_pairs, qrels):
    model.train()
    total = 0
    total_loss = 0.
    with tqdm('training', total=BATCH_SIZE * BATCHES_PER_EPOCH, ncols=80,
              desc='train') as pbar:
        for record in data.iter_train_pairs(model, dataset, train_pairs,
                                            qrels, GRAD_ACC_SIZE):
            scores = model(record['query_tok'], record['query_mask'],
                           record['doc_tok'], record['doc_mask'])
            count = len(record['query_id']) // 2
            scores = scores.reshape(count, 2)
            loss = torch.mean(1. - scores.softmax(dim=1)[:, 0])  # pairwise softmax
            loss.backward()
            total_loss += loss.item()
            total += count
            if total % BATCH_SIZE == 0:
                # On TPU, xm.optimizer_step() replaces optimizer.step();
                # barrier=True forces the lazily-recorded XLA graph to run.
                xm.optimizer_step(optimizer, barrier=True)
                optimizer.zero_grad()
            pbar.update(count)
            if total >= BATCH_SIZE * BATCHES_PER_EPOCH:
                return total_loss
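# Hypothetical sketch of the minimal TPU training step implied above; it
# only runs on an XLA-enabled runtime with torch_xla installed, and the
# model, batch, and learning rate are stand-ins.
def _xla_step_demo(model, batch):
    import torch_xla.core.xla_model as xm
    device = xm.xla_device()      # the TPU device
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss = model(batch.to(device)).mean()
    loss.backward()
    xm.optimizer_step(optimizer, barrier=True)  # step + execute XLA graph
    optimizer.zero_grad()
    return loss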