def train_one_epoch(epoch, model, train_loader, optimizer, tokenizer, params):
    device = params.device
    avg_loss = AverageMeter()
    avg_acc = Accuracy(ignore_index=-1)
    model.train()
    for i, batch in enumerate(train_loader):
        optimizer.zero_grad()
        batch = batch.to(device)
        # segment = create_dummy_segment(batch)
        inputs, labels = mask_tokens(batch, tokenizer, params)
        inputs = inputs.to(device)
        labels = labels.to(device)
        outputs = model(inputs, masked_lm_labels=labels)
        # model outputs are always tuples in transformers (see docs)
        loss, prediction_scores = outputs[:2]
        loss.backward()
        optimizer.step()
        avg_acc.update(prediction_scores.view(-1, params.vocab_size),
                       labels.view(-1))
        avg_loss.update(loss.item())
    logging.info('Train-E-{}: loss: {:.4f}'.format(epoch, avg_loss()))
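# AverageMeter and Accuracy are used throughout these loops but never defined
# here. Below is a minimal sketch of the interface the loops appear to assume
# (a callable running mean, and argmax accuracy with an optional ignore_index,
# plus the name/match/n/reset/print_score/get_score members referenced in the
# later functions); this is an assumption, not the original implementation.
import torch


class AverageMeter:
    """Running mean; calling the instance returns the current average."""

    def __init__(self):
        self.sum, self.count = 0.0, 0

    def update(self, value, n=1):
        self.sum += value * n
        self.count += n

    def __call__(self):
        return self.sum / max(self.count, 1)


class Accuracy:
    """Argmax accuracy; positions equal to ignore_index are skipped."""

    def __init__(self, ignore_index=None):
        self.ignore_index = ignore_index
        self.name = 'accuracy'
        self.reset()

    def reset(self):
        self.match, self.n = 0, 0

    def update(self, logits, labels):
        preds = logits.argmax(dim=-1).view(-1)
        labels = labels.view(-1)
        if self.ignore_index is not None:
            keep = labels != self.ignore_index
            preds, labels = preds[keep], labels[keep]
        self.match += (preds == labels).sum().item()
        self.n += labels.numel()

    def get_score(self):
        return self.match / max(self.n, 1)

    def print_score(self):
        return '{:.4f}'.format(self.get_score())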
def run_epoch(self, epoch, training):
    self.model.train(training)
    if training:
        description = '[Train]'
        dataset = self.trainData
        shuffle = True
    else:
        description = '[Valid]'
        dataset = self.validData
        shuffle = False
    # dataloader for train and valid
    dataloader = DataLoader(
        dataset,
        batch_size=self.batch_size,
        shuffle=shuffle,
        num_workers=8,
        collate_fn=dataset.collate_fn,
    )
    trange = tqdm(enumerate(dataloader), total=len(dataloader), desc=description)
    loss = 0
    accuracy = Accuracy()
    for i, (x, y) in trange:  # (x, y) = 128*128
        o_labels, batch_loss = self.run_iter(x, y)
        if training:
            self.opt.zero_grad()   # reset gradient to 0
            batch_loss.backward()  # calculate gradient
            self.opt.step()        # update parameter by gradient
        loss += batch_loss.item()  # .item() to get python number in Tensor
        accuracy.update(o_labels.cpu(), y)
        trange.set_postfix(accuracy=accuracy.print_score(), loss=loss / (i + 1))
    if training:
        self.history['train'].append({
            'accuracy': accuracy.get_score(),
            'loss': loss / len(trange)
        })
        self.scheduler.step()
    else:
        self.history['valid'].append({
            'accuracy': accuracy.get_score(),
            'loss': loss / len(trange)
        })
        if loss < self.min_loss:
            self.save_best_model(epoch)
            self.min_loss = loss
    self.save_hist()
def run_epoch(self, epoch, source_dataloader, target_dataloader, lamb):
    trange = tqdm(zip(source_dataloader, target_dataloader),
                  total=len(source_dataloader),
                  desc=f'[epoch {epoch}]')
    total_D_loss, total_F_loss = 0.0, 0.0
    acc = Accuracy()
    for i, ((source_data, source_label), (target_data, _)) in enumerate(trange):
        source_data = source_data.to(self.device)
        source_label = source_label.to(self.device)
        target_data = target_data.to(self.device)

        # =========== Preprocess =================
        # The source and target distributions differ, so feed them through the
        # network together so that batch norm sees statistics from both.
        mixed_data = torch.cat([source_data, target_data], dim=0)
        domain_label = torch.zeros(
            [source_data.shape[0] + target_data.shape[0], 1]).to(self.device)
        domain_label[:source_data.shape[0]] = 1  # source label = 1, target label = 0
        feature = self.feature_extractor(mixed_data)

        # =========== Step 1: Train Domain Classifier (feature extractor fixed via feature.detach()) =================
        domain_logits = self.domain_classifier(feature.detach())
        loss = self.domain_criterion(domain_logits, domain_label)
        total_D_loss += loss.item()
        loss.backward()
        self.optimizer_D.step()

        # =========== Step 2: Train Feature Extractor and Label Predictor =================
        class_logits = self.label_predictor(feature[:source_data.shape[0]])
        domain_logits = self.domain_classifier(feature)
        loss = self.class_criterion(class_logits, source_label) \
            - lamb * self.domain_criterion(domain_logits, domain_label)
        total_F_loss += loss.item()
        loss.backward()
        self.optimizer_F.step()
        self.optimizer_C.step()

        self.optimizer_D.zero_grad()
        self.optimizer_F.zero_grad()
        self.optimizer_C.zero_grad()

        acc.update(class_logits, source_label)
        trange.set_postfix(D_loss=total_D_loss / (i + 1),
                           F_loss=total_F_loss / (i + 1),
                           acc=acc.print_score())
    self.history['d_loss'].append(total_D_loss / len(trange))
    self.history['f_loss'].append(total_F_loss / len(trange))
    self.history['acc'].append(acc.get_score())
    self.save_hist()
    self.save_model()
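# The lamb argument weights the adversarial term in the loop above. In the
# DANN paper (Ganin & Lempitsky, 2015) this weight is annealed from 0 toward 1
# as training progresses; a sketch of that schedule and of driving the loop
# with it is below. The names trainer, total_epochs, and the two dataloaders
# are assumptions, not part of the original code.
import numpy as np


def adaptive_lambda(epoch, total_epochs):
    # p in [0, 1] is training progress; lambda ramps smoothly from 0 to ~1
    p = epoch / total_epochs
    return 2.0 / (1.0 + np.exp(-10.0 * p)) - 1.0

# for epoch in range(total_epochs):
#     lamb = adaptive_lambda(epoch, total_epochs)
#     trainer.run_epoch(epoch, source_dataloader, target_dataloader, lamb)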
def run_epoch(self, epoch, training):
    self.model.train(training)
    if training:
        description = 'Train'
        dataset = self.trainData
        shuffle = True
    else:
        description = 'Valid'
        dataset = self.validData
        shuffle = False
    dataloader = DataLoader(dataset=dataset,
                            batch_size=self.batch_size,
                            shuffle=shuffle,
                            collate_fn=dataset.collate_fn,
                            num_workers=4)
    trange = tqdm(enumerate(dataloader),
                  total=len(dataloader),
                  desc=description,
                  ascii=True)
    f_loss = 0
    l_loss = 0
    accuracy = Accuracy()
    for i, (x, missing, y) in trange:
        o_labels, batch_f_loss, batch_l_loss = self.run_iter(x, missing, y)
        batch_loss = batch_f_loss + batch_l_loss
        if training:
            self.opt.zero_grad()
            batch_loss.backward()
            self.opt.step()
        f_loss += batch_f_loss.item()
        l_loss += batch_l_loss.item()
        accuracy.update(o_labels.cpu(), y)
        trange.set_postfix(accuracy=accuracy.print_score(),
                           f_loss=f_loss / (i + 1),
                           l_loss=l_loss / (i + 1))
    if training:
        self.history['train'].append({
            'accuracy': accuracy.get_score(),
            'loss': f_loss / len(trange)
        })
        self.scheduler.step()
    else:
        self.history['valid'].append({
            'accuracy': accuracy.get_score(),
            'loss': f_loss / len(trange)
        })
def training(args, train_loader, valid_loader, model, optimizer, device):
    train_metrics = Accuracy()
    best_valid_acc = 0
    total_iter = 0
    criterion = torch.nn.CrossEntropyLoss()
    for epoch in range(args.epochs):
        train_trange = tqdm(enumerate(train_loader),
                            total=len(train_loader),
                            desc='training')
        train_loss = 0
        train_metrics.reset()
        for i, batch in train_trange:
            model.train()
            prob = run_iter(batch, model, device, training=True)
            answer = batch['label'].to(device)
            loss = criterion(prob, answer)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_iter += 1
            train_loss += loss.item()
            train_metrics.update(prob, answer)
            train_trange.set_postfix(
                loss=train_loss / (i + 1),
                **{train_metrics.name: train_metrics.print_score()})
            if total_iter % args.eval_steps == 0:
                valid_acc = testing(valid_loader, model, device, valid=True)
                if valid_acc > best_valid_acc:
                    best_valid_acc = valid_acc
                    torch.save(
                        model,
                        os.path.join(args.model_dir,
                                     'fine-tuned_bert_{}.pkl'.format(args.seed)))
    # Final validation
    valid_acc = testing(valid_loader, model, device, valid=True)
    if valid_acc > best_valid_acc:
        best_valid_acc = valid_acc
        torch.save(
            model,
            os.path.join(args.model_dir,
                         'fine-tuned_bert_{}.pkl'.format(args.seed)))
    print('Best Valid Accuracy: {}'.format(best_valid_acc))
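# run_iter is called by both training() and testing() but not defined in this
# file. A plausible sketch, assuming each batch is a dict of tensors and the
# model maps token ids (plus an attention mask) to class logits; the
# 'input_ids' and 'attention_mask' keys are guesses at the collate format,
# not taken from the original code.
import torch


def run_iter(batch, model, device, training):
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    if training:
        logits = model(input_ids, attention_mask)
    else:
        with torch.no_grad():  # no gradients needed for evaluation
            logits = model(input_ids, attention_mask)
    return logits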
def testing(dataloader, model, device, valid):
    metrics = Accuracy()
    criterion = torch.nn.CrossEntropyLoss()
    trange = tqdm(enumerate(dataloader),
                  total=len(dataloader),
                  desc='validation' if valid else 'testing')
    model.eval()
    total_loss = 0
    metrics.reset()
    for k, batch in trange:
        prob = run_iter(batch, model, device, training=False)
        answer = batch['label'].to(device)
        loss = criterion(prob, answer)
        total_loss += loss.item()
        metrics.update(prob, answer)
        trange.set_postfix(loss=total_loss / (k + 1),
                           **{metrics.name: metrics.print_score()})
    acc = metrics.match / metrics.n
    return acc
def run_epoch(self, epoch, dataset, training, desc=''):
    self.model.train(training)
    shuffle = training
    dataloader = DataLoader(dataset, self.batch_size, shuffle=shuffle)
    trange = tqdm(enumerate(dataloader), total=len(dataloader), desc=desc)
    loss = 0
    acc = Accuracy()
    for i, (imgs, labels) in trange:  # (b, 3, 128, 128), (b, 1)
        labels = labels.view(-1)  # (b,)
        o_labels, batch_loss = self.run_iters(imgs, labels)
        if training:
            batch_loss /= self.accum_steps  # scale so accumulated grads average out
            batch_loss.backward()
            if (i + 1) % self.accum_steps == 0:
                self.opt.step()
                self.opt.zero_grad()
            batch_loss *= self.accum_steps  # undo the scaling before logging
        loss += batch_loss.item()
        acc.update(o_labels.cpu(), labels)
        trange.set_postfix(loss=loss / (i + 1), acc=acc.print_score())
    if training:
        self.history['train'].append({
            'loss': loss / len(trange),
            'acc': acc.get_score()
        })
        self.scheduler.step()
    else:
        self.history['valid'].append({
            'loss': loss / len(trange),
            'acc': acc.get_score()
        })
        if loss < self.best_score:
            self.save_best()
    self.save_hist()
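# The accum_steps logic above implements gradient accumulation: each batch
# loss is scaled by 1/accum_steps so that the gradients summed over the window
# match one update on a batch accum_steps times larger, and the optimizer
# steps only once per window. A standalone sketch of the same pattern; model,
# loader, and opt are placeholders, not names from the original code.
import torch.nn.functional as F


def train_with_accumulation(model, loader, opt, accum_steps):
    opt.zero_grad()
    for i, (x, y) in enumerate(loader):
        loss = F.cross_entropy(model(x), y)
        (loss / accum_steps).backward()  # gradients add up across the window
        if (i + 1) % accum_steps == 0:
            opt.step()       # one parameter update per accum_steps batches
            opt.zero_grad()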
def run_epoch(self, epoch, dataloader):
    self.feature_extractor.train(True)
    self.label_predictor.train(True)
    trange = tqdm(dataloader, total=len(dataloader), desc=f'[epoch {epoch}]')
    total_loss = 0
    acc = Accuracy()
    for i, (target_data, target_label) in enumerate(trange):  # (b, 1, 32, 32)
        target_data = target_data.to(self.device)
        target_label = target_label.view(-1).to(self.device)  # (b,)
        feature = self.feature_extractor(target_data)  # (b, 512)
        class_logits = self.label_predictor(feature)   # (b, 10)
        loss = self.class_criterion(class_logits, target_label)
        total_loss += loss.item()
        loss.backward()
        self.optimizer_F.step()
        self.optimizer_C.step()
        self.optimizer_F.zero_grad()
        self.optimizer_C.zero_grad()
        acc.update(class_logits, target_label)
        trange.set_postfix(loss=total_loss / (i + 1), acc=acc.print_score())
    self.history['loss'].append(total_loss / len(trange))
    self.history['acc'].append(acc.get_score())
    self.save_hist()
    self.save_model()
def run_epoch(self, epoch, training, stage1):
    if stage1:
        self.model1.train(training)
    else:
        self.model1.train(False)
        self.model2.train(training)
    if training:
        description = '[Stage1 Train]' if stage1 else '[Stage2 Train]'
        dataset = self.trainData
        shuffle = True
    else:
        description = '[Stage1 Valid]' if stage1 else '[Stage2 Valid]'
        dataset = self.validData
        shuffle = False
    # dataloader for train and valid
    dataloader = DataLoader(
        dataset,
        batch_size=self.batch_size,
        shuffle=shuffle,
        num_workers=8,
        collate_fn=dataset.collate_fn,
    )
    trange = tqdm(enumerate(dataloader), total=len(dataloader), desc=description)
    loss = 0
    loss2 = 0
    accuracy = Accuracy()
    if stage1:
        for i, (x, y, miss) in trange:
            o_f1, batch_loss = self.run_iter_stage1(x, miss)
            if training:
                self.opt.zero_grad()   # reset gradient to 0
                batch_loss.backward()  # calculate gradient
                self.opt.step()        # update parameter by gradient
            loss += batch_loss.item()  # .item() to get python number in Tensor
            trange.set_postfix(loss=loss / (i + 1))
    else:
        for i, (x, y, miss) in trange:  # x = (256, 8), y = (256,)
            o_labels, batch_loss, missing_loss = self.run_iter_stage2(x, miss, y)
            loss2 += missing_loss.item()
            if training:
                self.opt.zero_grad()
                batch_loss.backward()
                self.opt.step()
            loss += batch_loss.item()
            accuracy.update(o_labels.cpu(), y)
            trange.set_postfix(accuracy=accuracy.print_score(),
                               loss=loss / (i + 1),
                               missing_loss=loss2 / (i + 1))
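# A sketch of how the two-stage schedule above might be driven: stage 1 trains
# model1 on the missing-feature objective, stage 2 keeps model1 frozen (the
# train(False) call above) and trains model2 on the labels. trainer and the
# epoch counts are assumed names, not part of the original code.
def run_two_stage(trainer, stage1_epochs, stage2_epochs):
    for epoch in range(stage1_epochs):
        trainer.run_epoch(epoch, training=True, stage1=True)
        trainer.run_epoch(epoch, training=False, stage1=True)
    for epoch in range(stage2_epochs):
        trainer.run_epoch(epoch, training=True, stage1=False)
        trainer.run_epoch(epoch, training=False, stage1=False)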
def run_epoch(self, epoch, training):
    self.model.train(training)
    self.generator.train(training)
    self.discriminator.train(training)
    if training:
        description = 'Train'
        dataset = self.trainData
        shuffle = True
    else:
        description = 'Valid'
        dataset = self.validData
        shuffle = False
    dataloader = DataLoader(dataset=dataset,
                            batch_size=self.batch_size,
                            shuffle=shuffle,
                            collate_fn=dataset.collate_fn,
                            num_workers=4)
    trange = tqdm(enumerate(dataloader),
                  total=len(dataloader),
                  desc=description,
                  ascii=True)
    g_loss = 0
    d_loss = 0
    loss = 0
    accuracy = Accuracy()
    for i, (features, real_missing, labels) in trange:
        features = features.to(self.device)          # (batch, 11)
        real_missing = real_missing.to(self.device)  # (batch, 3)
        labels = labels.to(self.device)              # (batch, 1)
        batch_size = features.shape[0]
        if training:
            # data augmentation: add uniform noise scaled by each sample's std
            rand = torch.rand((batch_size, 11)).to(self.device) - 0.5
            std = features.std(dim=1)
            noise = rand * std.unsqueeze(1)
            features += noise

        # Adversarial ground truths
        valid = torch.FloatTensor(batch_size, 1).fill_(1.0).to(self.device)  # (batch, 1)
        fake = torch.FloatTensor(batch_size, 1).fill_(0.0).to(self.device)   # (batch, 1)

        # ---------------------
        #  Train Discriminator
        # ---------------------
        if i % 10 < 5 or not training:
            real_pred = self.discriminator(real_missing)
            d_real_loss = self.criterion(real_pred, valid)
            fake_missing = self.generator(features.detach())
            fake_pred = self.discriminator(fake_missing)
            d_fake_loss = self.criterion(fake_pred, fake)
            batch_d_loss = d_real_loss + d_fake_loss
            if training:
                self.opt_D.zero_grad()
                batch_d_loss.backward()
                self.opt_D.step()
            d_loss += batch_d_loss.item()

        # -----------------
        #  Train Generator
        # -----------------
        if i % 10 >= 5 or not training:
            gen_missing = self.generator(features.detach())
            validity = self.discriminator(gen_missing)
            batch_g_loss = self.criterion(validity, valid)
            if training:
                self.opt_G.zero_grad()
                batch_g_loss.backward()
                self.opt_G.step()
            g_loss += batch_g_loss.item()

        # ------------------
        #  Train Classifier
        # ------------------
        gen_missing = self.generator(features.detach())
        all_features = torch.cat((features, gen_missing), dim=1)
        o_labels = self.model(all_features)
        batch_loss = self.criterion(o_labels, labels)
        if training:
            self.opt.zero_grad()
            batch_loss.backward()
            self.opt.step()
        loss += batch_loss.item()
        accuracy.update(o_labels, labels)
        trange.set_postfix(accuracy=accuracy.print_score(),
                           g_loss=g_loss / (i + 1),
                           d_loss=d_loss / (i + 1),
                           loss=loss / (i + 1))
    if training:
        self.history['train'].append({
            'accuracy': accuracy.get_score(),
            'g_loss': g_loss / len(trange),
            'd_loss': d_loss / len(trange),
            'loss': loss / len(trange)
        })
        self.scheduler.step()
        self.scheduler_G.step()
        self.scheduler_D.step()
    else:
        self.history['valid'].append({
            'accuracy': accuracy.get_score(),
            'g_loss': g_loss / len(trange),
            'd_loss': d_loss / len(trange),
            'loss': loss / len(trange)
        })
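# Module shapes implied by the loop above: the generator imputes the 3 missing
# features from the 11 observed ones, the discriminator scores a 3-feature
# vector as real or generated, and the classifier consumes all 14 features.
# A sketch with assumed hidden sizes; the Sigmoid outputs assume self.criterion
# is nn.BCELoss for both the adversarial and the label objectives, which is a
# guess, not confirmed by the original code.
import torch.nn as nn


class Generator(nn.Module):
    """Observed 11 features -> 3 imputed features."""

    def __init__(self, hidden=64):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(11, hidden), nn.ReLU(),
                                 nn.Linear(hidden, 3))

    def forward(self, x):
        return self.net(x)


class Discriminator(nn.Module):
    """3 (real or imputed) features -> probability the input is real."""

    def __init__(self, hidden=64):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(3, hidden), nn.ReLU(),
                                 nn.Linear(hidden, 1), nn.Sigmoid())

    def forward(self, x):
        return self.net(x)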