class Trainer():
    """Segmentation training driver.

    Builds datasets/loaders, model, RAdam (optionally SWA-wrapped) optimizer
    and loss from a config file, then runs the train/validate loop with
    per-epoch checkpointing, best-mIoU tracking and early stopping.
    """

    def __init__(self, config_path):
        # Three config sections (image / model / run) come from one file.
        self.image_config, self.model_config, self.run_config = LoadConfig(
            config_path=config_path).train_config()
        # BUG FIX: the original tested `torch.cuda.is_available` (the bound
        # function, always truthy) instead of calling it, so the 'cpu'
        # branch was unreachable on CUDA-less machines.
        self.device = torch.device(
            'cuda:%d' % self.run_config['device_ids'][0]
            if torch.cuda.is_available() else 'cpu')
        self.model = getModel(self.model_config)
        os.makedirs(self.run_config['model_save_path'], exist_ok=True)
        # Scale worker count with the number of GPUs used for DataParallel.
        self.run_config['num_workers'] = self.run_config['num_workers'] * len(
            self.run_config['device_ids'])
        self.train_set = Data(root=self.image_config['image_path'],
                              phase='train',
                              data_name=self.image_config['data_name'],
                              img_mode=self.image_config['image_mode'],
                              n_classes=self.model_config['num_classes'],
                              size=self.image_config['image_size'],
                              scale=self.image_config['image_scale'])
        self.valid_set = Data(root=self.image_config['image_path'],
                              phase='valid',
                              data_name=self.image_config['data_name'],
                              img_mode=self.image_config['image_mode'],
                              n_classes=self.model_config['num_classes'],
                              size=self.image_config['image_size'],
                              scale=self.image_config['image_scale'])
        self.className = self.valid_set.className
        self.train_loader = DataLoader(
            self.train_set,
            batch_size=self.run_config['batch_size'],
            shuffle=True,
            num_workers=self.run_config['num_workers'],
            pin_memory=True,
            drop_last=False)
        self.valid_loader = DataLoader(
            self.valid_set,
            batch_size=self.run_config['batch_size'],
            shuffle=True,
            num_workers=self.run_config['num_workers'],
            pin_memory=True,
            drop_last=False)
        train_params = self.model.parameters()
        # NOTE(review): lr / weight_decay are eval'd so the config may hold
        # expressions such as "1e-3/2". `eval` on untrusted config text is
        # unsafe — prefer float() if configs only ever hold plain numbers.
        self.optimizer = RAdam(train_params,
                               lr=eval(self.run_config['lr']),
                               weight_decay=eval(
                                   self.run_config['weight_decay']))
        if self.run_config['swa']:
            self.optimizer = SWA(self.optimizer, swa_start=10,
                                 swa_freq=5, swa_lr=0.005)
        # Configure the learning-rate adjustment strategy.
        self.lr_scheduler = utils.adjustLR.AdjustLr(self.optimizer)
        if self.run_config['use_weight_balance']:
            weight = utils.weight_balance.getWeight(
                self.run_config['weights_file'])
        else:
            weight = None
        self.Criterion = SegmentationLosses(weight=weight, cuda=True,
                                            device=self.device,
                                            batch_average=False)
        self.metric = utils.metrics.MetricMeter(
            self.model_config['num_classes'])

    @logger.catch  # record any uncaught error in the log
    def __call__(self):
        """Run the full training session (logging, resume, epoch loop)."""
        # Set up logging sinks named after the model.
        self.global_name = self.model_config['model_name']
        logger.add(os.path.join(
            self.image_config['image_path'], 'log',
            'log_' + self.global_name + '/train_{time}.log'),
            format="{time} {level} {message}", level="INFO", encoding='utf-8')
        self.writer = SummaryWriter(logdir=os.path.join(
            self.image_config['image_path'], 'run',
            'runs_' + self.global_name))
        logger.info("image_config: {} \n model_config: {} \n run_config: {}",
                    self.image_config, self.model_config, self.run_config)
        # Use data parallelism when more than one GPU is configured.
        if len(self.run_config['device_ids']) > 1:
            self.model = nn.DataParallel(
                self.model, device_ids=self.run_config['device_ids'])
        self.model.to(device=self.device)
        cnt = 0  # epochs since the last best-mIoU improvement
        # Load a pretrained checkpoint when one is configured.
        if self.run_config['pretrain'] != '':
            logger.info("loading pretrain %s" % self.run_config['pretrain'])
            try:
                self.load_checkpoint(use_optimizer=True, use_epoch=True,
                                     use_miou=True)
            # BUG FIX: was a bare `except:` (also swallowed KeyboardInterrupt).
            # A strict load typically fails with RuntimeError/KeyError on a
            # key mismatch; fall back to a partial (filtered) load.
            except Exception:
                print('load model with changed keys!!!!!')
                self.load_checkpoint_with_changed(use_optimizer=False,
                                                  use_epoch=False,
                                                  use_miou=False)
        logger.info("start training")
        for epoch in range(self.run_config['start_epoch'],
                           self.run_config['epoch']):
            lr = self.optimizer.param_groups[0]['lr']
            print('epoch=%d, lr=%.8f' % (epoch, lr))
            self.train_epoch(epoch, lr)
            valid_miou = self.valid_epoch(epoch)
            # Decide which LR schedule to apply this epoch.
            self.lr_scheduler.LambdaLR_(milestone=5,
                                        gamma=0.92).step(epoch=epoch)
            self.save_checkpoint(epoch, valid_miou,
                                 'last_' + self.global_name)
            if valid_miou > self.run_config['best_miou']:
                cnt = 0
                self.save_checkpoint(epoch, valid_miou,
                                     'best_' + self.global_name)
                logger.info("############# %d saved ##############" % epoch)
                self.run_config['best_miou'] = valid_miou
            else:
                cnt += 1
                if cnt == self.run_config['early_stop']:
                    logger.info("early stop")
                    break
        self.writer.close()

    def train_epoch(self, epoch, lr):
        """Train one epoch; logs running loss/mIoU and writes tensorboard."""
        self.metric.reset()
        train_loss = 0.0
        train_miou = 0.0
        tbar = tqdm(self.train_loader)
        self.model.train()
        for i, (image, mask, edge) in enumerate(tbar):
            tbar.set_description('train_miou:%.6f' % train_miou)
            tbar.set_postfix({"train_loss": train_loss})
            image = image.to(self.device)
            mask = mask.to(self.device)
            edge = edge.to(self.device)
            self.optimizer.zero_grad()
            out = self.model(image)
            # Models may return (aux_head, main_head) or a single tensor.
            if isinstance(out, tuple):
                aux_out, final_out = out[0], out[1]
            else:
                aux_out, final_out = None, out
            if self.model_config['model_name'] == 'ocrnet':
                aux_loss = self.Criterion.build_loss(mode='rmi')(aux_out, mask)
                cls_loss = self.Criterion.build_loss(mode='ce')(final_out, mask)
                loss = 0.4 * aux_loss + cls_loss
                loss = loss.mean()
            elif self.model_config['model_name'] == 'hrnet_duc':
                loss_body = self.Criterion.build_loss(
                    mode=self.run_config['loss_type'])(final_out, mask)
                loss_edge = self.Criterion.build_loss(mode='dice')(
                    aux_out.squeeze(), edge)
                loss = loss_body + loss_edge
                loss = loss.mean()
            else:
                loss = self.Criterion.build_loss(
                    mode=self.run_config['loss_type'])(final_out, mask)
            loss.backward()
            self.optimizer.step()
            # NOTE(review): torchcontrib's swap_swa_sgd() swaps the live
            # weights with the SWA average; calling it after EVERY step looks
            # wrong (it is normally done once, before evaluation) — confirm.
            if self.run_config['swa']:
                self.optimizer.swap_swa_sgd()
            with torch.no_grad():
                # Running mean of the loss over batches seen so far.
                train_loss = ((train_loss * i) + loss.item()) / (i + 1)
                _, pred = torch.max(final_out, dim=1)
                self.metric.add(pred.cpu().numpy(), mask.cpu().numpy())
                train_miou, train_ious = self.metric.miou()
        train_fwiou = self.metric.fw_iou()
        train_accu = self.metric.pixel_accuracy()
        train_fwaccu = self.metric.pixel_accuracy_class()
        logger.info(
            "Epoch:%2d\t lr:%.8f\t Train loss:%.4f\t Train FWiou:%.4f\t "
            "Train Miou:%.4f\t Train accu:%.4f\t "
            "Train fwaccu:%.4f" %
            (epoch, lr, train_loss, train_fwiou, train_miou, train_accu,
             train_fwaccu))
        # Build one "class:iou " log line plus a dict for tensorboard.
        cls = ""
        ious = list()
        ious_dict = OrderedDict()
        for i, c in enumerate(self.className):
            ious_dict[c] = train_ious[i]
            ious.append(ious_dict[c])
            cls += "%s:" % c + "%.4f "
        ious = tuple(ious)
        logger.info(cls % ious)
        # tensorboard
        self.writer.add_scalar("lr", lr, epoch)
        self.writer.add_scalar("loss/train_loss", train_loss, epoch)
        self.writer.add_scalar("miou/train_miou", train_miou, epoch)
        self.writer.add_scalar("fwiou/train_fwiou", train_fwiou, epoch)
        self.writer.add_scalar("accuracy/train_accu", train_accu, epoch)
        self.writer.add_scalar("fwaccuracy/train_fwaccu", train_fwaccu, epoch)
        self.writer.add_scalars("ious/train_ious", ious_dict, epoch)

    def valid_epoch(self, epoch):
        """Validate one epoch; returns the validation mIoU."""
        self.metric.reset()
        valid_loss = 0.0
        valid_miou = 0.0
        tbar = tqdm(self.valid_loader)
        self.model.eval()
        with torch.no_grad():
            for i, (image, mask, edge) in enumerate(tbar):
                tbar.set_description('valid_miou:%.6f' % valid_miou)
                tbar.set_postfix({"valid_loss": valid_loss})
                image = image.to(self.device)
                mask = mask.to(self.device)
                edge = edge.to(self.device)
                out = self.model(image)
                if isinstance(out, tuple):
                    aux_out, final_out = out[0], out[1]
                else:
                    aux_out, final_out = None, out
                if self.model_config['model_name'] == 'ocrnet':
                    aux_loss = self.Criterion.build_loss(mode='rmi')(aux_out,
                                                                     mask)
                    cls_loss = self.Criterion.build_loss(mode='ce')(final_out,
                                                                    mask)
                    loss = 0.4 * aux_loss + cls_loss
                    loss = loss.mean()
                elif self.model_config['model_name'] == 'hrnet_duc':
                    loss_body = self.Criterion.build_loss(
                        mode=self.run_config['loss_type'])(final_out, mask)
                    loss_edge = self.Criterion.build_loss(mode='dice')(
                        aux_out.squeeze(), edge)
                    loss = loss_body + loss_edge
                    # loss = loss.mean()
                else:
                    # NOTE(review): train_epoch uses run_config['loss_type']
                    # here but validation hard-codes 'ce' — confirm intended.
                    loss = self.Criterion.build_loss(mode='ce')(final_out,
                                                                mask)
                valid_loss = ((valid_loss * i) + float(loss)) / (i + 1)
                _, pred = torch.max(final_out, dim=1)
                self.metric.add(pred.cpu().numpy(), mask.cpu().numpy())
                valid_miou, valid_ious = self.metric.miou()
        valid_fwiou = self.metric.fw_iou()
        valid_accu = self.metric.pixel_accuracy()
        valid_fwaccu = self.metric.pixel_accuracy_class()
        logger.info(
            "epoch:%d\t valid loss:%.4f\t valid fwiou:%.4f\t valid miou:%.4f "
            "valid accu:%.4f\t "
            "valid fwaccu:%.4f\t" %
            (epoch, valid_loss, valid_fwiou, valid_miou, valid_accu,
             valid_fwaccu))
        ious = list()
        cls = ""
        ious_dict = OrderedDict()
        for i, c in enumerate(self.className):
            ious_dict[c] = valid_ious[i]
            ious.append(ious_dict[c])
            cls += "%s:" % c + "%.4f "
        ious = tuple(ious)
        logger.info(cls % ious)
        self.writer.add_scalar("loss/valid_loss", valid_loss, epoch)
        self.writer.add_scalar("miou/valid_miou", valid_miou, epoch)
        self.writer.add_scalar("fwiou/valid_fwiou", valid_fwiou, epoch)
        self.writer.add_scalar("accuracy/valid_accu", valid_accu, epoch)
        self.writer.add_scalar("fwaccuracy/valid_fwaccu", valid_fwaccu, epoch)
        self.writer.add_scalars("ious/valid_ious", ious_dict, epoch)
        return valid_miou

    def save_checkpoint(self, epoch, best_miou, flag):
        """Persist model/optimizer state plus epoch and best-mIoU bookkeeping."""
        meta = {
            'epoch': epoch,
            'model': self.model.state_dict(),
            'optim': self.optimizer.state_dict(),
            'bmiou': best_miou
        }
        try:
            # Legacy (non-zipfile) format keeps checkpoints loadable by
            # torch < 1.6 readers.
            torch.save(meta,
                       os.path.join(self.run_config['model_save_path'],
                                    '%s.pth' % flag),
                       _use_new_zipfile_serialization=False)
        # BUG FIX: was a bare `except:`. Only an unsupported keyword (old
        # torch without `_use_new_zipfile_serialization`) should trigger the
        # fallback, and that raises TypeError.
        except TypeError:
            torch.save(
                meta,
                os.path.join(self.run_config['model_save_path'],
                             '%s.pth' % flag))

    def load_checkpoint(self, use_optimizer, use_epoch, use_miou):
        """Strict checkpoint load; optionally restore optimizer/epoch/mIoU."""
        state_dict = torch.load(self.run_config['pretrain'],
                                map_location=self.device)
        self.model.load_state_dict(state_dict['model'])
        if use_optimizer:
            self.optimizer.load_state_dict(state_dict['optim'])
        if use_epoch:
            self.run_config['start_epoch'] = state_dict['epoch'] + 1
        if use_miou:
            self.run_config['best_miou'] = state_dict['bmiou']

    def load_checkpoint_with_changed(self, use_optimizer, use_epoch,
                                     use_miou):
        """Partial load: keep only keys present in the current model and not
        belonging to the edge head ('edge' in name), for transfer between
        slightly different architectures."""
        state_dict = torch.load(self.run_config['pretrain'],
                                map_location=self.device)
        pretrain_dict = state_dict['model']
        model_dict = self.model.state_dict()
        pretrain_dict = {
            k: v
            for k, v in pretrain_dict.items()
            if k in model_dict and 'edge' not in k
        }
        model_dict.update(pretrain_dict)
        self.model.load_state_dict(model_dict)
        if use_optimizer:
            self.optimizer.load_state_dict(state_dict['optim'])
        if use_epoch:
            self.run_config['start_epoch'] = state_dict['epoch'] + 1
        if use_miou:
            self.run_config['best_miou'] = state_dict['bmiou']
# --- resume / optimizer setup (script fragment; relies on earlier
# definitions of model_path, net, device, args, train_dataloader, writer) ---
checkpoint = torch.load(model_path)
net.load_state_dict(checkpoint['model_state_dict'])
net = net.to(device)
# optimization
# TODO: Choose an optimizer
# optimizer = optim.Adam(filter(lambda x: x.requires_grad, net.parameters()), lr=args.learning_rate)
scheduler = None
if args.use_swa:
    # NOTE(review): `optimizer` is referenced below, but its definition above
    # is still a commented-out TODO — this branch raises NameError until an
    # optimizer is actually chosen.
    # NOTE(review): len(train_dataloader) is normally already the number of
    # batches per epoch; dividing by batch_size again looks suspicious —
    # confirm the intended SWA cadence.
    steps_per_epoch = len(train_dataloader) // args.batch_size
    optimizer = SWA(optimizer, swa_start=20 * steps_per_epoch, swa_freq=steps_per_epoch)
    # The scheduler is attached to the wrapped base optimizer
    # (optimizer.optimizer), not the SWA wrapper itself.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer.optimizer, mode="max", patience=5, factor=0.5)
if args.resume_dir:
    # Resuming: restore optimizer state and training bookkeeping.
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    # best_pesq = checkpoint['pesq']
    best_loss = checkpoint['loss']
else:
    # Fresh run: best loss starts effectively at +inf.
    start_epoch = 0
    best_loss = 1e8
    # best_pesq = 0.0
# scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=5)
# add graph to tensorboard
if args.add_graph:
    # TODO: Create a dummy input for your model
    # dummy = torch.randn(16, 1, args.hop_length * 16).to(device)
    # NOTE(review): `dummy` is undefined while the line above stays
    # commented out — add_graph will raise NameError here.
    writer.add_graph(net, dummy)
optimizer = optimizer_dict[optimizer_name](model.dmg_model.parameters(), lr=lr) else: optimizer = optimizer_dict[optimizer_name](model.parameters(), lr=lr) # Call print("Starting model training....") n_epochs = setting_dict['epochs'] lr_patience = setting_dict['optimizer']['sheduler']['patience'] lr_factor = setting_dict['optimizer']['sheduler']['factor'] if weight_path is None: best_epoch = train(model,dataloaders,objective,optimizer,n_epochs,Path_list[1],Path_list[2], lr_patience=lr_patience,lr_factor=lr_factor, dice = False,seperate_loss=False, adabn = setting_dict["data"]["adabn_train"], own_sheduler = (not setting_dict["optimizer"]["longshedule"])) else: optimizer.load_state_dict(torch.load(weight_path)["optimizer"]) best_epoch = train(model,dataloaders,objective,optimizer,n_epochs-torch.load(weight_path)["epoch"],Path_list[1],Path_list[2],start_epoch = torch.load(weight_path)["epoch"]+1, loss_dict=torch.load(weight_path)["loss_dict"], lr_patience=lr_patience,lr_factor=lr_factor, dice = False,seperate_loss=False, adabn = setting_dict["data"]["adabn_train"], own_sheduler = (not setting_dict["optimizer"]["longshedule"])) print("model training finished! yey!") if optimizer_name == "SWA": print ("Updating batch norm pars for SWA") train_dataset.dataset.SWA = True SWA_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=cpu_count) optimizer.swap_swa_sgd() optimizer.bn_update(SWA_loader, model, device='cuda') state = { 'epoch': n_epochs, 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict(), 'loss_dict': {}
class Optimizer:
    """Convenience wrapper bundling a torch optimizer with gradient clipping
    and optional Stochastic Weight Averaging (torchcontrib SWA).

    Subclasses/users configure `optimizer_cls`; `set_parameters` must be
    called before any other optimizer method is used.
    """

    optimizer_cls = None
    optimizer = None
    parameters = None

    def __init__(self, gradient_clipping, swa_start=None, swa_freq=None,
                 swa_lr=None, **kwargs):
        # All remaining keyword arguments are forwarded verbatim to
        # `optimizer_cls` when the optimizer is materialized.
        self.gradient_clipping = gradient_clipping
        self.optimizer_kwargs = kwargs
        self.swa_start = swa_start
        self.swa_freq = swa_freq
        self.swa_lr = swa_lr

    def set_parameters(self, parameters):
        """Instantiate the underlying optimizer over `parameters`
        (optionally wrapped in SWA)."""
        self.parameters = tuple(parameters)
        self.optimizer = self.optimizer_cls(self.parameters,
                                            **self.optimizer_kwargs)
        if self.swa_start is None:
            return
        # SWA was requested: wrap the freshly created optimizer.
        from torchcontrib.optim import SWA
        assert self.swa_freq is not None, self.swa_freq
        assert self.swa_lr is not None, self.swa_lr
        self.optimizer = SWA(self.optimizer, swa_start=self.swa_start,
                             swa_freq=self.swa_freq, swa_lr=self.swa_lr)

    def check_if_set(self):
        """Guard: fail loudly when used before `set_parameters`."""
        assert self.optimizer is not None, \
            'The optimizer is not initialized, call set_parameter before' \
            ' using any of the optimizer functions'

    def zero_grad(self):
        self.check_if_set()
        return self.optimizer.zero_grad()

    def step(self):
        self.check_if_set()
        return self.optimizer.step()

    def swap_swa_sgd(self):
        """Swap live weights with the SWA running average (SWA mode only)."""
        self.check_if_set()
        from torchcontrib.optim import SWA
        assert isinstance(self.optimizer, SWA), self.optimizer
        return self.optimizer.swap_swa_sgd()

    def clip_grad(self):
        """Clip the global grad norm to `gradient_clipping`; returns the
        pre-clip norm."""
        self.check_if_set()
        # Todo: report clipped and unclipped
        # Todo: allow clip=None but still report grad_norm
        max_norm = self.gradient_clipping
        return torch.nn.utils.clip_grad_norm_(self.parameters, max_norm)

    def to(self, device):
        """Move every tensor in the optimizer state (e.g. momentum buffers)
        to `device`; no-op when device is None."""
        if device is None:
            return
        self.check_if_set()
        for state in self.optimizer.state.values():
            for key, value in state.items():
                if torch.is_tensor(value):
                    state[key] = value.to(device)

    def cpu(self):
        return self.to('cpu')

    def cuda(self, device=None):
        # `device` may be a CUDA ordinal or None (default CUDA device).
        assert device is None or isinstance(device, int), device
        return self.to(torch.device('cuda') if device is None else device)

    def load_state_dict(self, state_dict):
        self.check_if_set()
        return self.optimizer.load_state_dict(state_dict)

    def state_dict(self):
        self.check_if_set()
        return self.optimizer.state_dict()
class Trainer(object):
    """BERT sequence-classification trainer with focal loss and optional SWA.

    Wraps model construction, the training loop (`train`), evaluation
    (`evaluate`) and checkpointing (`save_model`/`load_model`).
    """

    def __init__(self, args, train_dataloader=None, validate_dataloader=None,
                 test_dataloader=None):
        self.args = args
        self.train_dataloader = train_dataloader
        self.validate_dataloader = validate_dataloader
        self.test_dataloader = test_dataloader
        # Labels are simply the class indices 0..num_classes-1.
        self.label_lst = [i for i in range(self.args.num_classes)]
        self.num_labels = self.args.num_classes
        self.config_class = AutoConfig
        self.model_class = BertForSequenceClassification
        self.config = self.config_class.from_pretrained(
            self.args.bert_model_name,
            num_labels=self.num_labels,
            finetuning_task='nsmc',
            id2label={str(i): label for i, label in enumerate(self.label_lst)},
            label2id={label: i for i, label in enumerate(self.label_lst)})
        self.model = self.model_class.from_pretrained(
            self.args.bert_model_name, config=self.config)
        # Created lazily in train(); load_model only restores them if set.
        self.optimizer = None
        self.scheduler = None
        # GPU or CPU
        self.device = "cuda" if torch.cuda.is_available() and args.cuda else "cpu"
        self.model.to(self.device)

    def train(self, alpha, gamma):
        """Train with FocalLoss(alpha, gamma); returns the best macro-F1 seen.

        With `args.use_swa`, weights are swapped to the SWA average before
        each evaluation/save and swapped back afterwards (from epoch 4 on).
        """
        train_dataloader = self.train_dataloader
        t_total = len(train_dataloader) * self.args.num_epochs
        # Prepare optimizer and schedule (linear warmup and decay)
        no_decay = ['bias', 'LayerNorm.weight']
        # NOTE(review): both groups use weight_decay 0.0, which makes the
        # no_decay split a no-op — presumably the first group was meant to
        # get a non-zero decay; confirm.
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in self.model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay': 0.0
        }, {
            'params': [
                p for n, p in self.model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            'weight_decay': 0.0
        }]
        if self.args.use_swa:
            base_opt = AdamW(optimizer_grouped_parameters,
                             lr=self.args.lr,
                             eps=1e-8)
            self.optimizer = SWA(base_opt,
                                 swa_start=4 * len(train_dataloader),
                                 swa_freq=100,
                                 swa_lr=5e-5)
            # Alias the wrapped optimizer's attributes onto the SWA wrapper
            # so the LR scheduler below can drive it like a plain optimizer.
            self.optimizer.param_groups = self.optimizer.optimizer.param_groups
            self.optimizer.state = self.optimizer.optimizer.state
            self.optimizer.defaults = self.optimizer.optimizer.defaults
        else:
            self.optimizer = optimizer = AdamW(optimizer_grouped_parameters,
                                               lr=self.args.lr,
                                               eps=1e-8)
        self.scheduler = scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=100,
            num_training_steps=self.args.num_epochs * len(train_dataloader))
        self.criterion = FocalLoss(alpha=alpha, gamma=gamma)
        # Train!
        logger.info("***** Running training *****")
        logger.info(" Num examples = %d",
                    len(self.train_dataloader) * self.args.batch_size)
        logger.info(" Num Epochs = %d", self.args.num_epochs)
        logger.info(" Total train batch size = %d", self.args.batch_size)
        logger.info(" Total optimization steps = %d", t_total)
        global_step = 0
        tr_loss = 0.0
        self.model.zero_grad()
        self.optimizer.zero_grad()
        train_iterator = trange(int(self.args.num_epochs), desc="Epoch")
        fin_result = None
        f1_max = 0.0
        self.model.train()
        for epoch in train_iterator:
            epoch_iterator = tqdm(train_dataloader, desc="Iteration")
            for step, batch in enumerate(epoch_iterator):
                batch = tuple(t.to(self.device) for t in batch)  # GPU or CPU
                inputs = {
                    'input_ids': batch[0],
                    'attention_mask': batch[1],
                    'labels': batch[3],
                    'token_type_ids': batch[2]
                }
                # outputs = self.model(**inputs)
                # loss = outputs[0]
                #
                # Custom Loss: replace the model's CE loss with focal loss
                # over sigmoid scores against a one-hot target.
                loss, logits = self.model(**inputs)
                logits = torch.sigmoid(logits)
                labels = torch.zeros(
                    (len(batch[3]), self.num_labels)).to(self.device)
                labels[range(len(batch[3])), batch[3]] = 1
                loss = self.criterion(logits, labels)
                loss.backward()
                self.optimizer.step()
                self.scheduler.step()  # Update learning rate schedule
                self.model.zero_grad()
                self.optimizer.zero_grad()
                tr_loss += loss.item()
                global_step += 1
                logger.info('train loss %f', loss.item())
                logger.info('total train loss %f', tr_loss / global_step)
            # From epoch 4 on, evaluate/save with the SWA-averaged weights,
            # then swap the live training weights back.
            if epoch >= 4 and self.args.use_swa:
                self.optimizer.swap_swa_sgd()
            fin_result = self.evaluate("validate")
            self.save_model(epoch)
            self.model.train()
            if epoch >= 4 and self.args.use_swa:
                self.optimizer.swap_swa_sgd()
            f1_max = max(fin_result['f1_macro'], f1_max)
        # Leave the model holding the SWA-averaged weights at the end.
        if epoch >= 4 and self.args.use_swa:
            self.optimizer.swap_swa_sgd()
        # NOTE(review): 'param_seach.txt' looks like a typo for
        # 'param_search.txt' but is kept as-is (runtime filename).
        with open(os.path.join(self.args.base_dir, self.args.result_dir,
                               self.args.train_id, 'param_seach.txt'),
                  "a",
                  encoding="utf-8") as f:
            f.write('alpha: {}, gamma: {}, f1_macro: {}\n'.format(
                alpha, gamma, f1_max))
        return f1_max

    def evaluate(self, mode='test'):
        """Evaluate on the validate or test split; returns a metrics dict
        (loss, precision/recall/f1_macro, plus compute_metrics results)."""
        if mode == 'test':
            dataloader = self.test_dataloader
        elif mode == 'validate':
            dataloader = self.validate_dataloader
        else:
            raise Exception("Only dev and test dataset available")
        # Eval!
        logger.info("***** Running evaluation on %s dataset *****", mode)
        logger.info(" Num examples = %d",
                    len(dataloader) * self.args.batch_size)
        logger.info(" Batch size = %d", self.args.batch_size)
        eval_loss = 0.0
        nb_eval_steps = 0
        preds = None
        out_label_ids = None
        self.model.eval()
        for batch in tqdm(dataloader, desc="Evaluating"):
            batch = tuple(t.to(self.device) for t in batch)
            with torch.no_grad():
                inputs = {
                    'input_ids': batch[0],
                    'attention_mask': batch[1],
                    'labels': batch[3],
                    'token_type_ids': batch[2]
                }
                outputs = self.model(**inputs)
                tmp_eval_loss, logits = outputs[:2]
                eval_loss += tmp_eval_loss.mean().item()
            nb_eval_steps += 1
            # Accumulate logits and gold labels across batches.
            if preds is None:
                preds = logits.detach().cpu().numpy()
                out_label_ids = inputs['labels'].detach().cpu().numpy()
            else:
                preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
                out_label_ids = np.append(
                    out_label_ids,
                    inputs['labels'].detach().cpu().numpy(),
                    axis=0)
        eval_loss = eval_loss / nb_eval_steps
        results = {"loss": eval_loss}
        preds = np.argmax(preds, axis=1)
        result = compute_metrics(preds, out_label_ids)
        results.update(result)
        p_macro, r_macro, f_macro, support_macro \
            = precision_recall_fscore_support(y_true=out_label_ids,
                                              y_pred=preds,
                                              labels=[i for i in range(self.num_labels)],
                                              average='macro')
        results.update({
            'precision': p_macro,
            'recall': r_macro,
            'f1_macro': f_macro
        })
        # Dump one predicted class index per line.
        with open(self.args.prediction_file, "w", encoding="utf-8") as f:
            for pred in preds:
                f.write("{}\n".format(pred))
        if mode == 'validate':
            logger.info("***** Eval results *****")
            for key in sorted(results.keys()):
                logger.info(" %s = %s", key, str(results[key]))
        return results

    def save_model(self, num=0):
        """Save model/optimizer/scheduler state as epoch_<num>.pth."""
        state = {
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict()
        }
        torch.save(
            state,
            os.path.join(self.args.base_dir, self.args.result_dir,
                         self.args.train_id, 'epoch_' + str(num) + '.pth'))
        logger.info('model saved')

    def load_model(self, model_name):
        """Restore model state; optimizer/scheduler only when already built."""
        state = torch.load(
            os.path.join(self.args.base_dir, self.args.result_dir,
                         self.args.train_id, model_name))
        self.model.load_state_dict(state['model'])
        if self.optimizer is not None:
            self.optimizer.load_state_dict(state['optimizer'])
        if self.scheduler is not None:
            self.scheduler.load_state_dict(state['scheduler'])
        logger.info('model loaded')
def main():
    """Train DeepLabV3-ResNet50 on crop segmentation with focal + Lovász
    losses and SWA weight averaging; requires CUDA."""
    maxIOU = 0.0  # NOTE(review): never updated below — dead bookkeeping?
    assert torch.cuda.is_available()
    torch.backends.cudnn.benchmark = True
    # '{0}' is filled now ('crops'); '%d' is filled per-epoch at save time.
    model_fname = '../data/model_swa_8/deeplabv3_{0}_epoch%d.pth'.format(
        'crops')
    focal_loss = FocalLoss2d()
    train_dataset = CropSegmentation(train=True, crop_size=args.crop_size)
    # test_dataset = CropSegmentation(train=False, crop_size=args.crop_size)
    model = torchvision.models.segmentation.deeplabv3_resnet50(
        pretrained=False, progress=True, num_classes=5, aux_loss=True)
    if args.train:
        # NOTE(review): weight/w/criterion are currently unused (the loss is
        # focal + lovasz below); also np.ones(4) vs num_classes=5 — confirm
        # before re-enabling weighted CE.
        weight = np.ones(4)
        weight[2] = 5
        weight[3] = 5
        w = torch.FloatTensor(weight).cuda()
        criterion = nn.CrossEntropyLoss()  # ignore_index=255 weight=w
        model = nn.DataParallel(model).cuda()
        for param in model.parameters():
            param.requires_grad = True
        optimizer1 = optim.SGD(model.parameters(),
                               lr=config.lr,
                               momentum=0.9,
                               weight_decay=1e-4)
        # BUG FIX: the scheduler must wrap the base SGD optimizer
        # (`optimizer1`); the original passed `optimizer`, which is not
        # assigned until the next statement, raising NameError at runtime.
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer1,
                                                   T_max=(args.epochs // 9) +
                                                   1)
        optimizer = SWA(optimizer1)
        dataset_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=args.train,
            pin_memory=True,
            num_workers=args.workers)
        max_iter = args.epochs * len(dataset_loader)
        losses = AverageMeter()
        start_epoch = 0
        # Optionally resume model + optimizer state from a checkpoint.
        if args.resume:
            if os.path.isfile(args.resume):
                print('=> loading checkpoint {0}'.format(args.resume))
                checkpoint = torch.load(args.resume)
                start_epoch = checkpoint['epoch']
                model.load_state_dict(checkpoint['state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer'])
                print('=> loaded checkpoint {0} (epoch {1})'.format(
                    args.resume, checkpoint['epoch']))
            else:
                print('=> no checkpoint found at {0}'.format(args.resume))
        for epoch in range(start_epoch, args.epochs):
            scheduler.step(epoch)
            model.train()
            for i, (inputs, target) in enumerate(dataset_loader):
                inputs = Variable(inputs.cuda())
                target = Variable(target.cuda())
                outputs = model(inputs)
                # Main + auxiliary heads, each with focal and Lovász terms.
                loss1 = focal_loss(outputs['out'], target)
                loss2 = focal_loss(outputs['aux'], target)
                loss01 = loss1 + 0.1 * loss2
                loss3 = lovasz_softmax(outputs['out'], target)
                loss4 = lovasz_softmax(outputs['aux'], target)
                loss02 = loss3 + 0.1 * loss4
                loss = loss01 + loss02
                # Drop into the debugger on numerical blow-up.
                if np.isnan(loss.item()) or np.isinf(loss.item()):
                    pdb.set_trace()
                losses.update(loss.item(), args.batch_size)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                # Fold current weights into the SWA average every 5 iters
                # after a short warm-up.
                if i > 10 and i % 5 == 0:
                    optimizer.update_swa()
                print('epoch: {0}\t'
                      'iter: {1}/{2}\t'
                      'loss: {loss.val:.4f} ({loss.ema:.4f})'.format(
                          epoch + 1, i + 1, len(dataset_loader), loss=losses))
            if epoch > 5:
                torch.save(
                    {
                        'epoch': epoch + 1,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                    }, model_fname % (epoch + 1))
        # Swap in the averaged weights before the final save (tag 666).
        optimizer.swap_swa_sgd()
        torch.save(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, model_fname % (665 + 1))