def __init__(self, train_loader, val_loader, test_loader, config):
    """Set up the NLDF solver: loaders, device, network and output files.

    Constructor body of a solver class whose ``class`` statement is not part
    of this snippet.

    Args:
        train_loader: loader used for training batches.
        val_loader: loader used for validation batches.
        test_loader: loader used for test batches.
        config: options object; reads ``cuda``, ``visdom``, ``pre_trained``,
            ``mode`` and, depending on mode, ``save_fold`` (train) or
            ``model`` / ``test_fold`` (any other mode).
    """
    self.train_loader = train_loader
    self.val_loader = val_loader
    self.test_loader = test_loader
    self.config = config
    # Shaped (3, 1, 1) so it broadcasts over a CHW image; the values look like
    # the VGG/ImageNet RGB means, rescaled from 0-255 to 0-1.
    self.mean = torch.Tensor([123.68, 116.779, 103.939]).view(3, 1, 1) / 255
    # NOTE(review): if this is squared inside an F-measure, 0.3 yields
    # beta ** 2 = 0.09 rather than the customary 0.3 -- confirm at the use site.
    self.beta = 0.3
    self.device = torch.device('cpu')
    if self.config.cuda:
        cudnn.benchmark = True
        self.device = torch.device('cuda:0')
    if config.visdom:
        self.visual = Viz_visdom("NLDF", 1)
    # build_model() is expected to create self.net (used throughout below).
    self.build_model()
    if self.config.pre_trained:
        # Load into self.net -- the network every other load in this method
        # targets.  Nothing here creates a ``self.generator`` attribute.
        self.net.load_state_dict(torch.load(self.config.pre_trained))
    if config.mode == 'train':
        self.log_output = open("%s/logs/log.txt" % config.save_fold, 'w')
    else:
        self.net.load_state_dict(torch.load(self.config.model))
        self.net.eval()
        self.test_output = open("%s/test.txt" % config.test_fold, 'w')
        self.test_output02 = open(
            "%s/pre_and_recall.txt" % config.test_fold, 'w')
def __init__(self, train_loader, val_loader, test_dataset, config):
    """Initialise the DSS solver: loaders, metrics settings, device and network.

    Constructor body of a solver class whose ``class`` statement is not part
    of this snippet.  ``config`` supplies ``cuda``, ``visdom``,
    ``pre_trained`` and ``mode``; training mode also reads ``save_fold``,
    any other mode reads ``model`` and ``test_fold``.
    """
    cfg = config
    self.train_loader, self.val_loader = train_loader, val_loader
    self.test_dataset = test_dataset
    self.config = cfg

    # Stored as sqrt(0.3) so that beta ** 2 == 0.3 for the F-beta metric.
    self.beta = math.sqrt(0.3)
    # Indices of the side outputs that are averaged at inference time.
    self.select = [1, 2, 3, 6]
    self.device = torch.device('cpu')
    # Normalisation statistics shaped (3, 1, 1) to broadcast over CHW images.
    self.mean = torch.Tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    self.std = torch.Tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    if self.config.cuda:
        cudnn.benchmark = True
        self.device = torch.device('cuda:0')
    if cfg.visdom:
        self.visual = Viz_visdom("DSS 12-6-19", 1)

    self.build_model()
    pretrained_path = self.config.pre_trained
    if pretrained_path:
        self.net.load_state_dict(torch.load(pretrained_path))

    if cfg.mode != 'train':
        # Evaluation: restore the chosen checkpoint and prepare test outputs.
        self.net.load_state_dict(torch.load(self.config.model))
        self.net.eval()
        self.test_output = open("%s/test.txt" % cfg.test_fold, 'w')
        self.transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])
    else:
        self.log_output = open("%s/logs/log.txt" % cfg.save_fold, 'w')
def __init__(self, train_loader, val_loader, test_dataset, config):
    """Initialise the "camu" solver: loaders, device, network and outputs.

    Constructor body of a solver class whose ``class`` statement is not part
    of this snippet.  Reads ``pre_map``, ``cuda``, ``visdom``,
    ``pre_trained`` and ``mode`` from ``config``; training mode also reads
    ``save_fold``, any other mode reads ``model`` and ``test_fold``.
    """
    self.config = config
    self.train_loader = train_loader
    self.val_loader = val_loader
    self.test_dataset = test_dataset
    self.beta = math.sqrt(0.3)

    # Per-channel statistics; view(3, 1, 1) reshapes the 3-vector so it
    # broadcasts over a CHW image tensor.
    self.mean = torch.Tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    self.std = torch.Tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    # Output folder taken from config.pre_map.
    self.visual_save_fold = config.pre_map

    use_gpu = self.config.cuda
    self.device = torch.device('cuda:0' if use_gpu else 'cpu')
    if use_gpu:
        cudnn.benchmark = True
    if config.visdom:
        self.visual = Viz_visdom("camu", 1)

    self.build_model()
    if self.config.pre_trained:
        self.net.load_state_dict(torch.load(self.config.pre_trained))

    if config.mode == 'train':
        self.log_output = open("%s/logs/log.txt" % config.save_fold, 'w')
        return
    self.net.load_state_dict(torch.load(self.config.model))
    self.net.eval()
    self.test_output = open("%s/test.txt" % config.test_fold, 'w')
    self.transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
def __init__(self, train_loader, val_loader, test_dataset, config, mode):
    """Initialise the DSS solver variant that also writes loss/MAE logs.

    Constructor body of a solver class whose ``class`` statement is not part
    of this snippet.

    Args:
        train_loader: loader used for training.
        val_loader: loader used for validation.
        test_dataset: dataset iterated at test time.
        config: options; reads ``cuda``, ``visdom``, ``pre_trained``,
            ``mode`` and ``save_fold`` (train) or ``model`` / ``test_fold``
            (any other mode).
        mode: stored as ``self.mode``; the branches below test
            ``config.mode`` instead.
    """
    self.train_loader = train_loader
    self.val_loader = val_loader
    self.test_dataset = test_dataset
    self.config = config
    # NOTE(review): if squared inside an F-measure this gives 0.09, not the
    # customary beta ** 2 = 0.3 -- confirm at the use site.
    self.beta = 0.3
    # inference: side outputs that are averaged (see paper)
    self.select = [1, 2, 3, 6]
    # Default to the CPU; the config.cuda check below switches to the GPU.
    # This previously defaulted to 'cuda:0' unconditionally, so cuda=False
    # could never select the CPU device.
    self.device = torch.device('cpu')
    self.mean = torch.Tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    self.std = torch.Tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    self.mode = mode
    if self.config.mode == "train":
        self.lossfile = open("%s/logs/loss.txt" % config.save_fold, 'w')
        self.maefile = open("%s/logs/mae.txt" % config.save_fold, 'w')
    if self.config.cuda:
        cudnn.benchmark = True
        self.device = torch.device('cuda:0')
    if config.visdom:
        self.visual = Viz_visdom("DSS", 1)
    self.build_model()
    if self.config.pre_trained:
        self.net.load_state_dict(torch.load(self.config.pre_trained))
    if config.mode == 'train':
        self.log_output = open("%s/logs/log.txt" % config.save_fold, 'w')
    else:
        self.net.load_state_dict(torch.load(self.config.model))
        self.net.eval()
        self.test_output = open("%s/test.txt" % config.test_fold, 'w')
        self.transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])
def __init__(self, train_loader, val_loader, test_loader, config):
    """Prepare the NLFD solver.

    Constructor body of a solver class whose ``class`` statement is not part
    of this snippet.  Stores the loaders and ``config``, builds the network
    via ``self.build_model()``, optionally loads ``config.pre_trained``
    weights, then opens the training log (train mode) or the test result
    file (any other mode).
    """
    self.train_loader, self.val_loader, self.test_loader = (
        train_loader, val_loader, test_loader)
    self.config = config
    rgb_mean = torch.Tensor([123.68, 116.779, 103.939])
    # (3, 1, 1) for CHW broadcasting, rescaled from 0-255 to 0-1.
    self.mean = rgb_mean.view(3, 1, 1) / 255
    self.beta = 0.3
    if config.visdom:
        self.visual = Viz_visdom("NLFD", 1)

    self.build_model()
    weights_path = self.config.pre_trained
    if weights_path:
        self.net.load_state_dict(torch.load(weights_path))

    if config.mode != 'train':
        self.net.load_state_dict(torch.load(self.config.model))
        self.net.eval()
        self.test_output = open("%s/test.txt" % config.test_fold, 'w')
        return
    self.log_output = open("%s/logs/log.txt" % config.save_fold, 'w')
def __init__(self, train_loader, val_loader, test_loader, config):
    """Set up the DSS solver (original loader-based variant).

    Constructor body of a solver class whose ``class`` statement is not part
    of this snippet.  Reads ``visdom``, ``pre_trained`` and ``mode`` from
    ``config``; training mode also reads ``save_fold``, any other mode reads
    ``model`` and ``test_fold``.
    """
    self.config = config
    self.train_loader = train_loader
    self.val_loader = val_loader
    self.test_loader = test_loader
    self.beta = 0.3  # used by the max F_beta metric
    # Channel means for CHW broadcasting, rescaled from 0-255 to 0-1.
    self.mean = torch.Tensor([123.68, 116.779, 103.939]).view(3, 1, 1) / 255
    # Side outputs averaged at inference time (see paper).
    self.select = [1, 2, 3, 6]
    if config.visdom:
        self.visual = Viz_visdom("DSS", 1)

    self.build_model()
    if self.config.pre_trained:
        self.net.load_state_dict(torch.load(self.config.pre_trained))

    is_training = config.mode == 'train'
    if not is_training:
        self.net.load_state_dict(torch.load(self.config.model))
        self.net.eval()
        self.test_output = open("%s/test.txt" % config.test_fold, 'w')
    else:
        self.log_output = open("%s/logs/log.txt" % config.save_fold, 'w')
def __init__(self, train_loader, test_dataset, config):
    """Set up the DSS solver variant with TensorBoard logging.

    Constructor body of a solver class whose ``class`` statement is not part
    of this snippet.  Reads ``update``, ``step``, ``save_fold``, ``cuda``,
    ``visdom``, ``pre_trained`` and ``mode`` from ``config`` (plus ``model``
    outside train mode).  ``self.net`` is accessed through ``.module``, so
    ``build_model()`` presumably wraps the network in ``DataParallel``.
    """
    self.config = config
    self.train_loader = train_loader
    self.test_dataset = test_dataset
    self.beta = 0.3  # used by the max F_beta metric
    # Side outputs averaged at inference time (see paper).
    self.select = [1, 2, 3, 6]
    self.mean = torch.Tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    self.std = torch.Tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    self.update, self.step = config.update, config.step

    # TensorBoard summary rooted at <save_fold>/logs/.
    self.summary = TensorboardSummary("%s/logs/" % config.save_fold)
    self.writer = self.summary.create_summary()
    self.visual_save_fold = config.save_fold

    if self.config.cuda:
        cudnn.benchmark = True
    if config.visdom:
        self.visual = Viz_visdom("DSS", 1)

    self.build_model()
    if self.config.pre_trained:
        self.net.module.load_state_dict(torch.load(self.config.pre_trained))

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    self.transform = transforms.Compose(
        [transforms.Resize((256, 256)), transforms.ToTensor(), normalize])
    # Label transform: tensor conversion, then each value is rounded to the
    # nearest integer (a hard 0/1 mask for inputs in [0, 1]).
    self.t_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: torch.round(x)),
    ])

    if config.mode == 'train':
        self.log_output = open("%s/logs/log.txt" % config.save_fold, 'w')
    else:
        # The test checkpoint is a dict that holds the weights under "state_dict".
        self.net.module.load_state_dict(torch.load(self.config.model)["state_dict"])
        self.net.eval()
class Solver(object):
    """Training / validation / test driver for the NLDF saliency network.

    Relies on module-level names not visible in this class: ``build_model``,
    ``Loss``, ``weights_init``, ``Viz_visdom``, ``Adam``, ``cudnn``,
    ``utils``, ``F`` and ``OrderedDict``.
    """

    def __init__(self, train_loader, val_loader, test_loader, config):
        """Store loaders/config, choose the device, build the net, open logs.

        ``config`` is read for ``cuda``, ``visdom``, ``pre_trained``,
        ``mode`` and either ``save_fold`` (train) or ``model`` and
        ``test_fold`` (any other mode).
        """
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.config = config
        # Shaped (3, 1, 1) to broadcast over a CHW image; the values look like
        # the VGG/ImageNet RGB means, rescaled from 0-255 to 0-1.
        self.mean = torch.Tensor([123.68, 116.779, 103.939]).view(3, 1, 1) / 255
        # test() weights precision by beta ** 2.  The customary saliency
        # setting is beta ** 2 = 0.3, so beta = sqrt(0.3); the previous value
        # 0.3 gave beta ** 2 = 0.09.
        self.beta = 0.3 ** 0.5
        self.device = torch.device('cpu')
        if self.config.cuda:
            cudnn.benchmark = True
            self.device = torch.device('cuda')
        if config.visdom:
            self.visual = Viz_visdom("NLDF", 1)
        self.build_model()
        if self.config.pre_trained:
            self.net.load_state_dict(torch.load(self.config.pre_trained))
        if config.mode == 'train':
            self.log_output = open("%s/logs/log.txt" % config.save_fold, 'w')
        else:
            self.net.load_state_dict(torch.load(self.config.model))
            self.net.eval()
            self.test_output = open("%s/test.txt" % config.test_fold, 'w')

    def print_network(self, model, name):
        """Print ``name``, the model, and its total parameter count."""
        num_params = 0
        for p in model.parameters():
            num_params += p.numel()
        print(name)
        print(model)
        print("The number of parameters: {}".format(num_params))

    def build_model(self):
        """Create the network (and loss in train mode) on ``self.device``.

        Initialises weights with ``weights_init``, then loads either the VGG
        backbone (``config.vgg``) or a full checkpoint (``config.load``), and
        creates the Adam optimizer.
        """
        self.net = build_model()
        if self.config.mode == 'train':
            self.loss = Loss(self.config.area, self.config.boundary)
        self.net = self.net.to(self.device)
        if self.config.cuda and self.config.mode == 'train':
            # Same device as the network (was a separate .cuda() call).
            self.loss = self.loss.to(self.device)
        self.net.train()
        self.net.apply(weights_init)
        if self.config.load == '':
            self.net.base.load_state_dict(torch.load(self.config.vgg))
        if self.config.load != '':
            self.net.load_state_dict(torch.load(self.config.load))
        self.optimizer = Adam(self.net.parameters(), self.config.lr)
        self.print_network(self.net, 'NLDF')

    def update_lr(self, lr):
        """Set the learning rate of every optimizer parameter group to ``lr``."""
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

    def clip(self, y):
        """Clamp ``y`` into [0, 1]."""
        return torch.clamp(y, 0.0, 1.0)

    def eval_mae(self, y_pred, y):
        """Mean absolute error between prediction and ground truth."""
        return torch.abs(y_pred - y).mean()

    # TODO: write a more efficient version
    def eval_pr(self, y_pred, y, num):
        """Precision and recall of ``y_pred`` against ``y`` at ``num`` thresholds.

        Thresholds come from ``torch.linspace(0, 1 - 1e-10, num)``.  Both
        denominators carry a tiny epsilon, so an all-zero ``y`` yields recall
        0 instead of NaN (a NaN would poison the dataset-wide average).
        """
        prec, recall = torch.zeros(num), torch.zeros(num)
        thlist = torch.linspace(0, 1 - 1e-10, num)
        positives = y.sum()
        for i in range(num):
            y_temp = (y_pred >= thlist[i]).float()
            tp = (y_temp * y).sum()
            prec[i] = tp / (y_temp.sum() + 1e-20)
            recall[i] = tp / (positives + 1e-20)
        return prec, recall

    def validation(self):
        """Return the mean per-batch MAE over ``self.val_loader``."""
        avg_mae = 0.0
        self.net.eval()
        for i, data_batch in enumerate(self.val_loader):
            with torch.no_grad():
                images, labels = data_batch
                images, labels = images.to(self.device), labels.to(self.device)
                prob_pred = self.net(images)
            avg_mae += self.eval_mae(prob_pred, labels).cpu().item()
        self.net.train()
        return avg_mae / len(self.val_loader)

    def test(self, num):
        """Evaluate on ``self.test_loader`` at label resolution.

        Prints per-image MAE, then the average MAE and the maximum F-measure
        over ``num`` thresholds, both to stdout and ``self.test_output``.
        """
        avg_mae, img_num = 0.0, len(self.test_loader)
        avg_prec, avg_recall = torch.zeros(num), torch.zeros(num)
        for i, data_batch in enumerate(self.test_loader):
            with torch.no_grad():
                images, labels = data_batch
                shape = labels.size()[2:]
                images = images.to(self.device)
                prob_pred = F.interpolate(self.net(images), size=shape, mode='bilinear',
                                          align_corners=True).cpu()
            mae = self.eval_mae(prob_pred, labels)
            prec, recall = self.eval_pr(prob_pred, labels, num)
            print("[%d] mae: %.4f" % (i, mae))
            print("[%d] mae: %.4f" % (i, mae), file=self.test_output)
            avg_mae += mae
            avg_prec, avg_recall = avg_prec + prec, avg_recall + recall
        avg_mae, avg_prec, avg_recall = avg_mae / img_num, avg_prec / img_num, avg_recall / img_num
        beta2 = self.beta ** 2
        score = (1 + beta2) * avg_prec * avg_recall / (beta2 * avg_prec + avg_recall)
        score[score != score] = 0  # 0/0 -> NaN where precision and recall are both 0
        print('average mae: %.4f, max fmeasure: %.4f' % (avg_mae, score.max()))
        print('average mae: %.4f, max fmeasure: %.4f' % (avg_mae, score.max()),
              file=self.test_output)

    def train(self):
        """Run the training loop with logging, validation and checkpointing."""
        iter_num = len(self.train_loader.dataset) // self.config.batch_size
        best_mae = 1.0 if self.config.val else None
        for epoch in range(self.config.epoch):
            loss_epoch = 0
            for i, data_batch in enumerate(self.train_loader):
                # Process exactly iter_num batches, skipping any trailing partial one.
                if (i + 1) > iter_num:
                    break
                self.net.zero_grad()
                x, y = data_batch
                x, y = x.to(self.device), y.to(self.device)
                y_pred = self.net(x)
                loss = self.loss(y_pred, y)
                loss.backward()
                utils.clip_grad_norm_(self.net.parameters(), self.config.clip_gradient)
                self.optimizer.step()
                loss_value = loss.item()
                loss_epoch += loss_value
                print('epoch: [%d/%d], iter: [%d/%d], loss: [%.4f]' %
                      (epoch, self.config.epoch, i, iter_num, loss_value))
                if self.config.visdom:
                    error = OrderedDict([('loss:', loss_value)])
                    self.visual.plot_current_errors(epoch, i / iter_num, error)
            if (epoch + 1) % self.config.epoch_show == 0:
                print('epoch: [%d/%d], epoch_loss: [%.4f]' %
                      (epoch, self.config.epoch, loss_epoch / iter_num),
                      file=self.log_output)
                if self.config.visdom:
                    avg_err = OrderedDict([('avg_loss', loss_epoch / iter_num)])
                    self.visual.plot_current_errors(epoch, i / iter_num, avg_err, 1)
                    # detach(): y_pred is still attached to the autograd graph.
                    img = OrderedDict([('origin', self.mean + x.cpu()[0]),
                                       ('label', y.cpu()[0][0]),
                                       ('pred_label', y_pred.detach().cpu()[0][0])])
                    self.visual.plot_current_img(img)
            if self.config.val and (epoch + 1) % self.config.epoch_val == 0:
                mae = self.validation()
                print('--- Best MAE: %.4f, Curr MAE: %.4f ---' % (best_mae, mae))
                print('--- Best MAE: %.4f, Curr MAE: %.4f ---' % (best_mae, mae),
                      file=self.log_output)
                if best_mae > mae:
                    best_mae = mae
                    torch.save(self.net.state_dict(), '%s/models/best.pth' % self.config.save_fold)
            if (epoch + 1) % self.config.epoch_save == 0:
                torch.save(self.net.state_dict(),
                           '%s/models/epoch_%d.pth' % (self.config.save_fold, epoch + 1))
        torch.save(self.net.state_dict(), '%s/models/final.pth' % self.config.save_fold)
class Solver(object):
    """DSS saliency solver for dict-style batches (``'image'`` / ``'label'``).

    Relies on module-level names not visible in this class: ``build_model``,
    ``Loss``, ``weights_init``, ``Viz_visdom``, ``Adam``, ``cudnn``,
    ``utils``, ``F``, ``transforms``, ``save_image``, ``math`` and
    ``OrderedDict``.
    """

    def __init__(self, train_loader, val_loader, test_dataset, config):
        """Store loaders/config, choose the device, build the net, open outputs."""
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_dataset = test_dataset
        self.config = config
        self.beta = math.sqrt(0.3)  # for max F_beta metric (beta ** 2 = 0.3)
        # inference: choose the side map (see paper)
        self.select = [1, 2, 3, 6]
        self.device = torch.device('cpu')
        self.mean = torch.Tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
        self.std = torch.Tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
        if self.config.cuda:
            cudnn.benchmark = True
            self.device = torch.device('cuda:0')
        if config.visdom:
            self.visual = Viz_visdom("DSS 12-6-19", 1)
        self.build_model()
        if self.config.pre_trained:
            self.net.load_state_dict(torch.load(self.config.pre_trained))
        if config.mode == 'train':
            self.log_output = open("%s/logs/log.txt" % config.save_fold, 'w')
        else:
            self.net.load_state_dict(torch.load(self.config.model))
            self.net.eval()
            self.test_output = open("%s/test.txt" % config.test_fold, 'w')
            self.transform = transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),
            ])

    def print_network(self, model, name):
        """Print ``name``, the model and its number of trainable parameters."""
        num_params = 0
        for p in model.parameters():
            if p.requires_grad:
                num_params += p.numel()
        print(name)
        print(model)
        print("The number of parameters: {}".format(num_params))

    def build_model(self):
        """Create the network (and loss in train mode) on ``self.device``."""
        self.net = build_model().to(self.device)
        if self.config.mode == 'train':
            self.loss = Loss().to(self.device)
        self.net.train()
        self.net.apply(weights_init)
        if self.config.load == '':
            self.net.base.load_state_dict(torch.load(self.config.vgg))
        if self.config.load != '':
            self.net.load_state_dict(torch.load(self.config.load))
        self.optimizer = Adam(self.net.parameters(), self.config.lr)
        self.print_network(self.net, 'DSS')

    def update_lr(self, lr):
        """Set the learning rate of every optimizer parameter group to ``lr``."""
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

    def eval_mae(self, y_pred, y):
        """Mean absolute error (used in the test and validation phases)."""
        return torch.abs(y_pred - y).mean()

    # TODO: write a more efficient version
    def eval_pr(self, y_pred, y, num):
        """Precision and recall at ``num`` thresholds from ``linspace(0, 1 - 1e-10, num)``.

        Both denominators carry a tiny epsilon, so an all-zero ``y`` yields
        recall 0 instead of NaN (a NaN would poison averaged recall).
        """
        prec, recall = torch.zeros(num), torch.zeros(num)
        thlist = torch.linspace(0, 1 - 1e-10, num)
        positives = y.sum()
        for i in range(num):
            y_temp = (y_pred >= thlist[i]).float()
            tp = (y_temp * y).sum()
            prec[i] = tp / (y_temp.sum() + 1e-20)
            recall[i] = tp / (positives + 1e-20)
        return prec, recall

    def _fuse_side_outputs(self, outputs):
        """Average the side outputs listed in ``self.select`` into one channel."""
        return torch.mean(torch.cat([outputs[k] for k in self.select], dim=1),
                          dim=1, keepdim=True)

    def _validation_metrics(self, num=None):
        """Evaluate on ``self.val_loader``; return ``(mae, prec, recall)``.

        All metrics are per-batch averages.  ``prec`` / ``recall`` are length
        ``num`` tensors, or ``None`` when ``num`` is falsy.  Batches are moved
        to ``self.device`` (previously hard-coded CUDA tensor types broke CPU
        runs).
        """
        total_mae = 0.0
        prec_sum = recall_sum = None
        if num:
            prec_sum, recall_sum = torch.zeros(num), torch.zeros(num)
        self.net.eval()
        with torch.no_grad():
            for data_batch in self.val_loader:
                images = data_batch['image'].float().to(self.device)
                labels = data_batch['label'].float().to(self.device)
                prob_pred = self._fuse_side_outputs(self.net(images))
                total_mae += self.eval_mae(prob_pred, labels).item()
                if num:
                    prec, recall = self.eval_pr(prob_pred.cpu(), labels.cpu(), num)
                    prec_sum += prec
                    recall_sum += recall
        self.net.train()
        n_batches = len(self.val_loader)
        avg_mae = total_mae / n_batches
        # Report the actual average (this used to print the running sum).
        print("Average Mae" + str(avg_mae))
        if num:
            return avg_mae, prec_sum / n_batches, recall_sum / n_batches
        return avg_mae, None, None

    def validation(self):
        """Validation on resized images: return only the average MAE."""
        return self._validation_metrics()[0]

    def test(self, num, use_crf=False, result_dir=None):
        """Test at label resolution: save maps, report MAE and max F-beta.

        Args:
            num: number of thresholds for precision/recall.
            use_crf: CRF refinement needs the raw input image, which these
                dict batches do not carry, so it raises NotImplementedError
                (it previously failed with a NameError mid-run).
            result_dir: folder for the predicted maps; defaults to
                ``<config.test_fold>/results`` (was a hard-coded personal path).
        """
        import os

        if use_crf:
            raise NotImplementedError(
                'use_crf needs the raw input image, which the dict-style '
                'test dataset used here does not provide')
        if result_dir is None:
            result_dir = os.path.join(self.config.test_fold, 'results')
        os.makedirs(result_dir, exist_ok=True)
        avg_mae, img_num = 0.0, len(self.test_dataset)
        avg_prec, avg_recall = torch.zeros(num), torch.zeros(num)
        with torch.no_grad():
            for i, data in enumerate(self.test_dataset):
                images = data['image'].float().to(self.device)
                # Labels stay on the CPU: predictions are moved there below and
                # the metrics mix the two (a CUDA/CPU mix raised before).
                labels = data['label'].float()
                shape = labels.size()[2:]
                prob_pred = self._fuse_side_outputs(self.net(images))
                prob_pred = F.interpolate(prob_pred, size=shape, mode='bilinear',
                                          align_corners=True).cpu()
                save_image(prob_pred[0], os.path.join(result_dir, 'result' + str(i) + '.png'))
                mae = self.eval_mae(prob_pred, labels)
                prec, recall = self.eval_pr(prob_pred, labels, num)
                print("[%d] mae: %.4f" % (i, mae))
                print("[%d] mae: %.4f" % (i, mae), file=self.test_output)
                avg_mae += mae
                avg_prec, avg_recall = avg_prec + prec, avg_recall + recall
        avg_mae, avg_prec, avg_recall = avg_mae / img_num, avg_prec / img_num, avg_recall / img_num
        score = (1 + self.beta**2) * avg_prec * avg_recall / (self.beta**2 * avg_prec + avg_recall)
        score[score != score] = 0  # delete the nan
        print('average mae: %.4f, max fmeasure: %.4f' % (avg_mae, score.max()))
        print('average mae: %.4f, max fmeasure: %.4f' % (avg_mae, score.max()),
              file=self.test_output)

    def train(self, num):
        """Train; validation also reports precision/recall at ``num`` thresholds."""
        iter_num = len(self.train_loader.dataset) // self.config.batch_size
        best_mae = 1.0 if self.config.val else None
        for epoch in range(self.config.epoch):
            loss_epoch = 0
            for i, data_batch in enumerate(self.train_loader):
                # Process exactly iter_num batches, skipping any trailing partial one.
                if (i + 1) > iter_num:
                    break
                x = data_batch['image'].float().to(self.device)
                y = data_batch['label'].float().to(self.device)
                self.net.zero_grad()
                y_pred = self.net(x)
                loss = self.loss(y_pred, y)
                loss.backward()
                utils.clip_grad_norm_(self.net.parameters(), self.config.clip_gradient)
                self.optimizer.step()
                loss_epoch += loss.item()
                print('epoch: [%d/%d], iter: [%d/%d], loss: [%.4f]' %
                      (epoch, self.config.epoch, i, iter_num, loss.item()))
                if self.config.visdom:
                    error = OrderedDict([('loss:', loss.item())])
                    self.visual.plot_current_errors('Cross Entropy Loss', epoch, i / iter_num, error)
            if (epoch + 1) % self.config.epoch_show == 0:
                print('epoch: [%d/%d], epoch_loss: [%.4f]' %
                      (epoch, self.config.epoch, loss_epoch / iter_num),
                      file=self.log_output)
                if self.config.visdom:
                    avg_err = OrderedDict([('avg_loss', loss_epoch / iter_num)])
                    self.visual.plot_current_errors('Average Loss per Epoch', epoch,
                                                    i / iter_num, avg_err, 1)
                    # One panel per selected side output.  A separate loop
                    # variable keeps the batch index ``i`` (used for plot
                    # positions below) from being overwritten.
                    for k in self.select:
                        img = OrderedDict([
                            ('origin' + str(epoch) + str(k), x.cpu()[0] * self.std + self.mean),
                            ('label' + str(epoch) + str(k), y.cpu()[0][0]),
                            ('pred_label' + str(epoch) + str(k), y_pred[k].detach().cpu()[0][0]),
                        ])
                        self.visual.plot_current_img(img)
            if self.config.val and (epoch + 1) % self.config.epoch_val == 0:
                # Precision/recall come from the validation pass itself; this
                # previously referenced undefined ``prob_pred`` / ``labels``.
                mae, prec, recall = self._validation_metrics(num)
                score = (1 + self.beta**2) * prec * recall / (self.beta**2 * prec + recall)
                score[score != score] = 0  # delete the nan
                print('--- Best MAE: %.2f, Curr MAE: %.2f ---' % (best_mae, mae))
                print('--- Best MAE: %.2f, Curr MAE: %.2f ---' % (best_mae, mae),
                      file=self.log_output)
                if self.config.visdom:
                    # Plot scalars taken at the threshold with the best F-measure.
                    best = int(score.argmax())
                    error = OrderedDict([('MAE:', mae)])
                    self.visual.plot_current_errors('Mean Absolute Error Graph', epoch,
                                                    i / iter_num, error, 2)
                    prec_graph = OrderedDict([('Precission:', prec[best].item())])
                    self.visual.plot_current_errors('Precission Graph', epoch,
                                                    i / iter_num, prec_graph, 3)
                    recall_graph = OrderedDict([('Recall:', recall[best].item())])
                    self.visual.plot_current_errors('Recall Graph', epoch,
                                                    i / iter_num, recall_graph, 4)
                    fscore_graph = OrderedDict([('F-Measure:', score[best].item())])
                    self.visual.plot_current_errors('F-Measure Graph', epoch,
                                                    i / iter_num, fscore_graph, 5)
                if best_mae > mae:
                    best_mae = mae
                    torch.save(self.net.state_dict(), '%s/models/best.pth' % self.config.save_fold)
            if (epoch + 1) % self.config.epoch_save == 0:
                torch.save(self.net.state_dict(),
                           '%s/models/epoch_%d.pth' % (self.config.save_fold, epoch + 1))
        torch.save(self.net.state_dict(), '%s/models/final.pth' % self.config.save_fold)
class Solver(object):
    """DSS saliency solver: build, train, validate and test the network.

    Ported from the PyTorch 0.3 API (``Variable(volatile=True)``,
    ``.data[0]`` on scalar tensors, ``F.upsample``, ``clip_grad_norm``) to
    ``torch.no_grad()``, ``.item()``, ``F.interpolate`` and
    ``clip_grad_norm_``.  Indexing a 0-dim tensor with ``[0]`` raises on
    current PyTorch.

    Relies on module-level names not visible in this class: ``build_model``,
    ``Loss``, ``weights_init``, ``Viz_visdom``, ``Adam``, ``cudnn``,
    ``utils``, ``F`` and ``OrderedDict``.
    """

    def __init__(self, train_loader, val_loader, test_loader, config):
        """Store loaders/config, choose the device, build the net, open logs."""
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.config = config
        # test() weights precision by beta ** 2.  The customary saliency
        # setting is beta ** 2 = 0.3, so beta = sqrt(0.3); the previous value
        # 0.3 gave beta ** 2 = 0.09.
        self.beta = 0.3 ** 0.5
        self.mean = torch.Tensor([123.68, 116.779, 103.939]).view(3, 1, 1) / 255
        # inference: choose the side map (see paper)
        self.select = [1, 2, 3, 6]
        # One device for every tensor instead of scattered .cuda() calls.
        self.device = torch.device('cuda' if self.config.cuda else 'cpu')
        if config.visdom:
            self.visual = Viz_visdom("DSS", 1)
        self.build_model()
        if self.config.pre_trained:
            self.net.load_state_dict(torch.load(self.config.pre_trained))
        if config.mode == 'train':
            self.log_output = open("%s/logs/log.txt" % config.save_fold, 'w')
        else:
            self.net.load_state_dict(torch.load(self.config.model))
            self.net.eval()
            self.test_output = open("%s/test.txt" % config.test_fold, 'w')

    def print_network(self, model, name):
        """Print ``name``, the model and its total parameter count."""
        num_params = 0
        for p in model.parameters():
            num_params += p.numel()
        print(name)
        print(model)
        print("The number of parameters: {}".format(num_params))

    def build_model(self):
        """Create the network (and loss in train mode) on ``self.device``."""
        self.net = build_model().to(self.device)
        if self.config.mode == 'train':
            self.loss = Loss().to(self.device)
        self.net.train()
        self.net.apply(weights_init)
        if self.config.load == '':
            self.net.base.load_state_dict(torch.load(self.config.vgg))
        if self.config.load != '':
            self.net.load_state_dict(torch.load(self.config.load))
        self.optimizer = Adam(self.net.parameters(), self.config.lr)
        self.print_network(self.net, 'DSS')

    def update_lr(self, lr):
        """Set the learning rate of every optimizer parameter group to ``lr``."""
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

    def eval_mae(self, y_pred, y):
        """Mean absolute error (used in the test and validation phases)."""
        return torch.abs(y_pred - y).mean()

    # TODO: write a more efficient version
    def eval_pr(self, y_pred, y, num):
        """Precision and recall at ``num`` thresholds from ``linspace(0, 1 - 1e-10, num)``.

        Both denominators carry a tiny epsilon, so an all-zero ``y`` yields
        recall 0 instead of NaN (a NaN would poison averaged recall).
        """
        prec, recall = torch.zeros(num), torch.zeros(num)
        thlist = torch.linspace(0, 1 - 1e-10, num)
        positives = y.sum()
        for i in range(num):
            y_temp = (y_pred >= thlist[i]).float()
            tp = (y_temp * y).sum()
            prec[i] = tp / (y_temp.sum() + 1e-20)
            recall[i] = tp / (positives + 1e-20)
        return prec, recall

    def _fuse_side_outputs(self, outputs):
        """Average the side outputs listed in ``self.select`` into one channel."""
        return torch.mean(torch.cat([outputs[k] for k in self.select], dim=1),
                          dim=1, keepdim=True)

    def validation(self):
        """Validation on resized images: return the mean per-batch MAE."""
        avg_mae = 0.0
        self.net.eval()
        with torch.no_grad():
            for images, labels in self.val_loader:
                images, labels = images.to(self.device), labels.to(self.device)
                prob_pred = self._fuse_side_outputs(self.net(images))
                avg_mae += self.eval_mae(prob_pred, labels).item()
        self.net.train()
        return avg_mae / len(self.val_loader)

    def test(self, num):
        """Test at label resolution: report per-image MAE, average MAE and max F-beta."""
        avg_mae, img_num = 0.0, len(self.test_loader)
        avg_prec, avg_recall = torch.zeros(num), torch.zeros(num)
        with torch.no_grad():
            for i, (images, labels) in enumerate(self.test_loader):
                shape = labels.size()[2:]
                prob_pred = self._fuse_side_outputs(self.net(images.to(self.device)))
                # align_corners=True reproduces the bilinear behaviour of the
                # old F.upsample era; later versions default to False.
                prob_pred = F.interpolate(prob_pred, size=shape, mode='bilinear',
                                          align_corners=True).cpu()
                mae = self.eval_mae(prob_pred, labels)
                prec, recall = self.eval_pr(prob_pred, labels, num)
                print("[%d] mae: %.4f" % (i, mae))
                print("[%d] mae: %.4f" % (i, mae), file=self.test_output)
                avg_mae += mae
                avg_prec, avg_recall = avg_prec + prec, avg_recall + recall
        avg_mae, avg_prec, avg_recall = avg_mae / img_num, avg_prec / img_num, avg_recall / img_num
        score = (1 + self.beta**2) * avg_prec * avg_recall / (self.beta**2 * avg_prec + avg_recall)
        score[score != score] = 0  # delete the nan
        print('average mae: %.4f, max fmeasure: %.4f' % (avg_mae, score.max()))
        print('average mae: %.4f, max fmeasure: %.4f' % (avg_mae, score.max()),
              file=self.test_output)

    def train(self):
        """Run the training loop with logging, validation and checkpointing."""
        if self.config.cuda:
            cudnn.benchmark = True
        iter_num = len(self.train_loader.dataset) // self.config.batch_size
        best_mae = 1.0 if self.config.val else None
        for epoch in range(self.config.epoch):
            loss_epoch = 0
            for i, data_batch in enumerate(self.train_loader):
                # Process exactly iter_num batches, skipping any trailing partial one.
                if (i + 1) > iter_num:
                    break
                self.net.zero_grad()
                images, labels = data_batch
                # .float() mirrors the FloatTensor buffers the old code copied into.
                x = images.to(self.device).float()
                y = labels.to(self.device).float()
                y_pred = self.net(x)
                loss = self.loss(y_pred, y)
                loss.backward()
                utils.clip_grad_norm_(self.net.parameters(), self.config.clip_gradient)
                self.optimizer.step()
                loss_value = loss.item()
                loss_epoch += loss_value
                print('epoch: [%d/%d], iter: [%d/%d], loss: [%.4f]' %
                      (epoch, self.config.epoch, i, iter_num, loss_value))
                if self.config.visdom:
                    error = OrderedDict([('loss:', loss_value)])
                    self.visual.plot_current_errors(epoch, i / iter_num, error)
            if (epoch + 1) % self.config.epoch_show == 0:
                print('epoch: [%d/%d], epoch_loss: [%.4f]' %
                      (epoch, self.config.epoch, loss_epoch / iter_num),
                      file=self.log_output)
                if self.config.visdom:
                    avg_err = OrderedDict([('avg_loss', loss_epoch / iter_num)])
                    self.visual.plot_current_errors(epoch, i / iter_num, avg_err, 1)
                    y_show = self._fuse_side_outputs(y_pred).detach()
                    img = OrderedDict([('origin', self.mean + images.cpu()[0]),
                                       ('label', labels.cpu()[0][0]),
                                       ('pred_label', y_show.cpu()[0][0])])
                    self.visual.plot_current_img(img)
            if self.config.val and (epoch + 1) % self.config.epoch_val == 0:
                mae = self.validation()
                print('--- Best MAE: %.2f, Curr MAE: %.2f ---' % (best_mae, mae))
                print('--- Best MAE: %.2f, Curr MAE: %.2f ---' % (best_mae, mae),
                      file=self.log_output)
                if best_mae > mae:
                    best_mae = mae
                    torch.save(self.net.state_dict(), '%s/models/best.pth' % self.config.save_fold)
            if (epoch + 1) % self.config.epoch_save == 0:
                torch.save(self.net.state_dict(),
                           '%s/models/epoch_%d.pth' % (self.config.save_fold, epoch + 1))
        torch.save(self.net.state_dict(), '%s/models/final.pth' % self.config.save_fold)
class Solver(object):
    """Solver for the "camu" saliency model: training, validation and testing.

    Relies on module-level names not visible in this class: ``build_model``,
    ``Loss``, ``Viz_visdom``, ``Adam``, ``cudnn``, ``utils``, ``F``,
    ``transforms``, ``cv2``, ``np``, ``os``, ``math`` and ``OrderedDict``.
    """

    def __init__(self, train_loader, val_loader, test_dataset, config):
        """Store loaders/config, choose the device, build the net, open outputs."""
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_dataset = test_dataset
        self.config = config
        self.beta = math.sqrt(0.3)
        self.device = torch.device('cpu')
        # view(3, 1, 1) reshapes the per-channel statistics so they broadcast
        # over a CHW image tensor.
        self.mean = torch.Tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
        self.std = torch.Tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
        # Folder that test() writes the predicted maps into.
        self.visual_save_fold = config.pre_map
        if self.config.cuda:
            cudnn.benchmark = True
            self.device = torch.device('cuda:0')
        if config.visdom:
            self.visual = Viz_visdom("camu", 1)
        self.build_model()
        if self.config.pre_trained:
            self.net.load_state_dict(torch.load(self.config.pre_trained))
        if config.mode == 'train':
            self.log_output = open("%s/logs/log.txt" % config.save_fold, 'w')
        else:
            self.net.load_state_dict(torch.load(self.config.model))
            self.net.eval()
            self.test_output = open("%s/test.txt" % config.test_fold, 'w')
            self.transform = transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),
            ])

    def build_model(self):
        """Create the network (and loss in train mode) plus the Adam optimizer."""
        self.net = build_model().to(self.device)
        if self.config.mode == 'train':
            self.loss = Loss().to(self.device)
        # Leave the net in training mode.  It used to be switched to eval()
        # here unconditionally, so training ran with inference-mode layers
        # (e.g. any BatchNorm/Dropout) until validation() restored train mode.
        # __init__ still calls eval() itself when not training.
        self.net.train()
        self.optimizer = Adam(self.net.parameters(), self.config.lr)

    def eval_mae(self, y_pred, y):
        """Mean absolute error (used in the test and validation phases)."""
        return torch.abs(y_pred - y).mean()

    # TODO: write a more efficient version
    def eval_pr(self, y_pred, y, num):
        """Precision and recall at ``num`` thresholds from ``linspace(0, 1 - 1e-10, num)``.

        Both denominators carry a tiny epsilon, so an all-zero ``y`` yields
        recall 0 instead of NaN (a NaN would poison averaged recall).
        """
        prec, recall = torch.zeros(num), torch.zeros(num)
        thlist = torch.linspace(0, 1 - 1e-10, num)
        positives = y.sum()
        for i in range(num):
            y_temp = (y_pred >= thlist[i]).float()
            tp = (y_temp * y).sum()
            prec[i] = tp / (y_temp.sum() + 1e-20)
            recall[i] = tp / (positives + 1e-20)
        return prec, recall

    def validation(self):
        """Validation on resized images: mean per-batch MAE of the first output."""
        avg_mae = 0.0
        self.net.eval()
        with torch.no_grad():
            for images, labels in self.val_loader:
                images, labels = images.to(self.device), labels.to(self.device)
                prob_pred = self.net(images)
                avg_mae += self.eval_mae(prob_pred[0], labels).item()
        self.net.train()
        return avg_mae / len(self.val_loader)

    def test(self, num, use_crf=False):
        """Test at label resolution: save each predicted map and report MAE.

        Maps are written as ``<visual_save_fold>/<name>.png``.
        """
        if use_crf:
            # NOTE(review): crf is imported but never applied below.
            from tools.crf_process import crf
        avg_mae, img_num = 0.0, len(self.test_dataset)
        avg_prec, avg_recall = torch.zeros(num), torch.zeros(num)
        # Create the output folder once instead of checking for every image.
        os.makedirs(self.visual_save_fold, exist_ok=True)
        with torch.no_grad():
            for i, (img, labels, name) in enumerate(self.test_dataset):
                images = self.transform(img).unsqueeze(0).to(self.device)
                labels = labels.unsqueeze(0)
                shape = labels.size()[2:]
                prob_pred = self.net(images)
                # The network returns several outputs; testing uses the first one.
                prob_pred = F.interpolate(prob_pred[0], size=shape, mode='bilinear',
                                          align_corners=True).cpu()
                img_save = prob_pred.numpy()
                img_save = img_save.reshape(-1, img_save.shape[2], img_save.shape[3]).transpose(1, 2, 0) * 255
                cv2.imwrite('{}/{}.png'.format(self.visual_save_fold, name), img_save.astype(np.uint8))
                mae = self.eval_mae(prob_pred, labels)
                prec, recall = self.eval_pr(prob_pred, labels, num)
                print("[%d] mae: %.4f" % (i, mae))
                print("[%d] mae: %.4f" % (i, mae), file=self.test_output)
                avg_mae += mae
                avg_prec, avg_recall = avg_prec + prec, avg_recall + recall
        avg_mae, avg_prec, avg_recall = avg_mae / img_num, avg_prec / img_num, avg_recall / img_num
        print('average mae: %.4f' % (avg_mae))
        print('average mae: %.4f' % (avg_mae), file=self.test_output)

    def train(self):
        """Run the training loop with logging, validation and checkpointing."""
        # Integer division: iter_num is the number of full batches processed
        # per epoch.  True division made the epoch-loss average divide by a
        # fractional count.
        iter_num = len(self.train_loader.dataset) // self.config.batch_size
        best_mae = 1.0 if self.config.val else None
        self.net.train()
        for epoch in range(self.config.epoch):
            loss_epoch = 0
            for i, data_batch in enumerate(self.train_loader):
                if (i + 1) > iter_num:
                    break
                self.net.zero_grad()
                x, y = data_batch
                x, y = x.to(self.device), y.to(self.device)
                y_pred = self.net(x)
                loss = self.loss(y_pred, y)
                loss.backward()
                utils.clip_grad_norm_(self.net.parameters(), self.config.clip_gradient)
                self.optimizer.step()
                loss_epoch += loss.item()
                print('epoch: [%d/%d], iter: [%d/%d], loss: [%.4f], lr: [%s]' % (
                    epoch, self.config.epoch, i, iter_num, loss.item(), self.config.lr))
                if self.config.visdom:
                    error = OrderedDict([('loss:', loss.item())])
                    self.visual.plot_current_errors(epoch, i / iter_num, error)
            if (epoch + 1) % self.config.epoch_show == 0:
                print('epoch: [%d/%d], epoch_loss: [%.4f], lr: [%s]' %
                      (epoch, self.config.epoch, loss_epoch / iter_num, self.config.lr),
                      file=self.log_output)
            if self.config.val and (epoch + 1) % self.config.epoch_val == 0:
                mae = self.validation()
                print('--- Best MAE: %.5f, Curr MAE: %.5f ---' % (best_mae, mae))
                print('--- Best MAE: %.5f, Curr MAE: %.5f ---' % (best_mae, mae),
                      file=self.log_output)
                if best_mae > mae:
                    best_mae = mae
                    torch.save(self.net.state_dict(), '%s/models/best.pth' % self.config.save_fold)
            if (epoch + 1) % self.config.epoch_save == 0:
                torch.save(self.net.state_dict(),
                           '%s/models/epoch_%d.pth' % (self.config.save_fold, epoch + 1))
        torch.save(self.net.state_dict(), '%s/models/final.pth' % self.config.save_fold)
class Solver(object):
    """Train and evaluate a DataParallel DSS-style saliency network.

    Training progress goes to TensorBoard (through ``TensorboardSummary``),
    optionally to visdom, and to ``<save_fold>/logs/log.txt``. Checkpoints are
    written to ``<save_fold>/models/``: ``best.pth`` is a dict holding
    epoch / state_dict / optimizer / best_mae, while ``final.pth`` holds the
    bare state dict of the wrapped module.
    """

    def __init__(self, train_loader, test_dataset, config):
        self.train_loader = train_loader
        self.test_dataset = test_dataset
        self.config = config
        # The F-measure formula uses ``self.beta ** 2`` and the usual saliency
        # setting is beta**2 == 0.3, so beta itself is sqrt(0.3).
        self.beta = 0.3 ** 0.5  # for max F_beta metric
        # inference: choose the side map (see paper)
        self.select = [1, 2, 3, 6]
        # ImageNet statistics; used to un-normalise images for visdom display.
        self.mean = torch.Tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
        self.std = torch.Tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
        self.update = config.update
        self.step = config.step
        # modified by hanqi
        self.summary = TensorboardSummary("%s/logs/" % config.save_fold)
        self.writer = self.summary.create_summary()
        self.visual_save_fold = config.save_fold
        if self.config.cuda:
            cudnn.benchmark = True
        if config.visdom:
            self.visual = Viz_visdom("DSS", 1)
        self.build_model()
        if self.config.pre_trained:
            self.net.module.load_state_dict(
                self._extract_state_dict(torch.load(self.config.pre_trained)))
        self.transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        # Label transform: original resolution; values are rounded to {0, 1}.
        self.t_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: torch.round(x))
        ])
        if config.mode == 'train':
            self.log_output = open("%s/logs/log.txt" % config.save_fold, 'w')
        else:
            # Accept both ``best.pth`` (dict with 'state_dict') and a bare
            # state dict such as ``final.pth``.
            self.net.module.load_state_dict(
                self._extract_state_dict(torch.load(self.config.model)))
            self.net.eval()

    @staticmethod
    def _extract_state_dict(checkpoint):
        """Return model weights from a full checkpoint dict or a bare state dict."""
        if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
            return checkpoint['state_dict']
        return checkpoint

    # print the network information and parameter numbers
    def print_network(self, model, name):
        """Print ``name``, the module tree and its trainable parameter count."""
        num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(name)
        print(model)
        print("The number of parameters: {}".format(num_params))

    # build the network
    def build_model(self):
        """Create the DataParallel network on CUDA, the loss and the Adam optimizer.

        After ``weights_init`` the VGG backbone is loaded from ``config.vgg``,
        unless ``config.load`` names full-model weights to load instead.
        """
        self.net = torch.nn.DataParallel(build_model()).cuda()
        if self.config.mode == 'train':
            self.loss = Loss().cuda()
        self.net.train()
        self.net.apply(weights_init)
        if self.config.load == '':
            self.net.module.base.load_state_dict(torch.load(self.config.vgg))
        if self.config.load != '':
            self.net.module.load_state_dict(torch.load(self.config.load))
        self.optimizer = Adam(self.net.parameters(), self.config.lr)
        self.print_network(self.net, 'DSS')

    # update the learning rate
    def update_lr(self):
        """Divide the learning rate of every optimizer param group by 10."""
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = param_group['lr'] / 10.0

    # evaluate MAE (for test or validation phase)
    def eval_mae(self, y_pred, y):
        """Mean absolute error between prediction and label tensors."""
        return torch.abs(y_pred - y).mean()

    # get precisions and recalls: threshold---divided [0, 1] to num values
    def eval_pr(self, y_pred, y, num):
        """Precision/recall of ``y_pred`` against ``y`` at ``num`` thresholds in [0, 1).

        Both denominators carry a tiny epsilon so an empty prediction or an
        all-background label yields 0 instead of NaN.
        """
        prec, recall = torch.zeros(num), torch.zeros(num)
        thlist = torch.linspace(0, 1 - 1e-10, num)
        for i in range(num):
            y_temp = (y_pred >= thlist[i]).float()
            tp = (y_temp * y).sum()
            prec[i], recall[i] = tp / (y_temp.sum() + 1e-20), tp / (y.sum() + 1e-20)
        return prec, recall

    def _save_map(self, kind, arr, epoch, name):
        """Write map ``arr`` (values in [0, 1]) as an 8-bit image.

        The target is ``<save_fold>/visualize_<kind><epoch>/<name>``; the
        directory is created on demand. ``arr`` is indexed as 4-D
        (N, C, H, W), and the leading dims are flattened into image channels.
        """
        fold = '{}/visualize_{}{}'.format(self.visual_save_fold, kind, epoch)
        os.makedirs(fold, exist_ok=True)
        img_save = arr.reshape(-1, arr.shape[2], arr.shape[3]).transpose(1, 2, 0) * 255
        cv2.imwrite('{}/{}'.format(fold, name), img_save.astype(np.uint8))

    # test phase: using origin image size, evaluate MAE and max F_beta metrics
    def test(self, num, use_crf=False, epoch=None):
        """Evaluate MAE on ``self.test_dataset`` at the original label size.

        Saves prediction, background and foreground maps under
        ``<save_fold>/visualize_{pred,bg,fg}<epoch>/``.

        Args:
            num: number of PR thresholds (unused while the F-measure is disabled).
            use_crf: refine predictions with ``tools.crf_process.crf``.
            epoch: suffix of the visualisation directories.

        Returns:
            ``(avg_mae, 1.0)``. The second value is a placeholder for the max
            F-measure. ``avg_mae`` is NaN if no image produced a finite MAE.
        """
        if use_crf:
            from tools.crf_process import crf
        avg_mae, img_num = 0.0, 0.0
        with torch.no_grad():
            for img, labels, bg, fg, name in self.test_dataset:
                images = self.transform(img).unsqueeze(0).cuda()
                labels = self.t_transform(labels).unsqueeze(0)
                shape = labels.size()[2:]
                prob_pred = self.net(images, mode='test')
                # Outputs at index k + 7 are averaged into a binary background
                # map (assumed test-mode output layout -- TODO confirm).
                bg_pred = torch.mean(torch.cat([prob_pred[k + 7] for k in self.select], dim=1),
                                     dim=1, keepdim=True)
                bg_pred = (bg_pred > 0.5).float()
                prob_pred = torch.mean(torch.cat([prob_pred[k] for k in self.select], dim=1),
                                       dim=1, keepdim=True)
                prob_pred = F.interpolate(prob_pred, size=shape, mode='bilinear',
                                          align_corners=True).cpu().data
                bg_pred = F.interpolate(bg_pred, size=shape, mode='nearest').cpu().data.numpy()
                fork_bg, fork_fg = Bwdist(bg_pred)
                if use_crf:
                    prob_pred = crf(img, prob_pred.numpy(), to_tensor=True)
                self._save_map('pred', prob_pred.numpy(), epoch, name)
                self._save_map('bg', fork_bg, epoch, name)
                self._save_map('fg', fork_fg, epoch, name)
                mae = self.eval_mae(prob_pred, labels)
                if mae == mae:  # NaN != NaN: skip images whose MAE is NaN
                    avg_mae += mae
                    img_num += 1.0
        avg_mae = avg_mae / img_num if img_num else float('nan')
        print('average mae: %.4f' % (avg_mae))
        return avg_mae, 1.0  # score.max()

    def _resume(self):
        """Restore network and optimizer state from ``config.resume``.

        Returns:
            ``(start_epoch, best_mae)`` stored in the checkpoint.

        Raises:
            RuntimeError: if ``config.resume`` is not an existing file.
        """
        if not os.path.isfile(self.config.resume):
            raise RuntimeError("=> no checkpoint found at '{}'".format(self.config.resume))
        checkpoint = torch.load(self.config.resume)
        # ``self.net`` is always DataParallel (see build_model) and checkpoints
        # store ``self.net.module`` weights, so always load into ``.module``.
        self.net.module.load_state_dict(checkpoint['state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {})".format(self.config.resume, checkpoint['epoch']))
        return checkpoint['epoch'], checkpoint['best_mae']

    # training phase
    def train(self):
        """Run training, periodically evaluating with ``test`` and checkpointing.

        Resumes from ``config.resume`` when it is set. Saves ``best.pth``
        whenever the test MAE improves and ``final.pth`` at the end.
        """
        start_epoch = 0
        best_mae = 1.0 if self.config.val else None
        if self.config.resume is not None:
            start_epoch, best_mae = self._resume()
        # Only full batches are used: iteration stops after iter_num batches.
        iter_num = len(self.train_loader.dataset) // self.config.batch_size
        for epoch in range(start_epoch, self.config.epoch):
            loss_epoch = 0
            tbar = tqdm(self.train_loader)
            for i, data_batch in enumerate(tbar):
                if (i + 1) > iter_num:
                    break
                self.net.zero_grad()
                x, y, bg, fg = data_batch
                x, y, bg, fg = x.cuda(), y.cuda(), bg.cuda(), fg.cuda()
                y_pred = self.net(x, bg=bg, fg=fg)
                loss = self.loss(y_pred, y)
                loss.backward()
                utils.clip_grad_norm_(self.net.parameters(), self.config.clip_gradient)
                self.optimizer.step()
                loss_epoch += loss.item()
                self.writer.add_scalar('train/total_loss_iter', loss.item(), epoch * iter_num + i)
                tbar.set_description('epoch:[%d/%d],loss:[%.4f]' % (
                    epoch, self.config.epoch, loss.item()))
                if self.config.visdom:
                    error = OrderedDict([('loss:', loss.item())])
                    self.visual.plot_current_errors(epoch, i / iter_num, error)
            self.writer.add_scalar('train/total_loss_epoch', loss_epoch / iter_num, epoch)
            if (epoch + 1) % self.config.epoch_show == 0:
                print('epoch: [%d/%d], epoch_loss: [%.4f]' % (epoch, self.config.epoch, loss_epoch / iter_num),
                      file=self.log_output)
                if self.config.visdom:
                    avg_err = OrderedDict([('avg_loss', loss_epoch / iter_num)])
                    self.visual.plot_current_errors(epoch, i / iter_num, avg_err, 1)
                    y_show = torch.mean(torch.cat([y_pred[k] for k in self.select], dim=1), dim=1, keepdim=True)
                    img = OrderedDict([('origin', x.cpu()[0] * self.std + self.mean),
                                       ('label', y.cpu()[0][0]),
                                       ('pred_label', y_show.cpu().data[0][0])])
                    self.visual.plot_current_img(img)
            if self.config.val and (epoch + 1) % self.config.epoch_val == 0:
                mae, fscore = self.test(100, epoch=epoch + 1)
                self.writer.add_scalar('test/MAE', mae, epoch)
                self.writer.add_scalar('test/F-Score', fscore, epoch)
                print('--- Best MAE: %.2f, Curr MAE: %.2f ---' % (best_mae, mae))
                print('--- Best MAE: %.2f, Curr MAE: %.2f ---' % (best_mae, mae), file=self.log_output)
                if best_mae > mae:
                    best_mae = mae
                    torch.save({
                        'epoch': epoch + 1,
                        'state_dict': self.net.module.state_dict(),
                        'optimizer': self.optimizer.state_dict(),
                        'best_mae': mae
                    }, '%s/models/best.pth' % self.config.save_fold)
            # The log file stays open for the whole run; flush so it survives a crash.
            self.log_output.flush()
        torch.save(self.net.module.state_dict(), '%s/models/final.pth' % self.config.save_fold)
class Solver(object):
    """Train / evaluate a DSS-style saliency network fed with four extra inputs.

    With ``mode == 1`` the network comes from ``build_model()``. Its output is
    treated as a sequence of side maps, and those indexed by ``self.select``
    are averaged into one map. Any other ``mode`` uses ``build_modelv2()``,
    whose output is used directly as the saliency map.
    """

    def __init__(self, train_loader, val_loader, test_dataset, config, mode):
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_dataset = test_dataset
        self.config = config
        # The F-measure in ``test`` uses ``self.beta ** 2`` and the standard
        # saliency setting is beta**2 == 0.3, so beta itself is sqrt(0.3).
        self.beta = 0.3 ** 0.5
        # side maps averaged into the final prediction (mode 1)
        self.select = [1, 2, 3, 6]
        # Default to CPU; switched to CUDA below only when ``config.cuda`` is set.
        self.device = torch.device('cpu')
        self.mean = torch.Tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
        self.std = torch.Tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
        self.mode = mode
        if self.config.mode == "train":
            self.lossfile = open("%s/logs/loss.txt" % config.save_fold, 'w')
            self.maefile = open("%s/logs/mae.txt" % config.save_fold, 'w')
        if self.config.cuda:
            cudnn.benchmark = True
            self.device = torch.device('cuda:0')
        if config.visdom:
            self.visual = Viz_visdom("DSS", 1)
        self.build_model()
        if self.config.pre_trained:
            self.net.load_state_dict(torch.load(self.config.pre_trained, map_location=self.device))
        if config.mode == 'train':
            self.log_output = open("%s/logs/log.txt" % config.save_fold, 'w')
        else:
            self.net.load_state_dict(torch.load(self.config.model, map_location=self.device))
            self.net.eval()
            self.test_output = open("%s/test.txt" % config.test_fold, 'w')
        self.transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def print_network(self, model, name):
        """Print ``name``, the module tree and its trainable parameter count."""
        num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(name)
        print(model)
        print("The number of parameters: {}".format(num_params))

    def build_model(self):
        """Create the network selected by ``self.mode``, the loss and the optimizer.

        After ``weights_init`` either the VGG backbone (``config.vgg``) or full
        model weights (``config.load``) are loaded onto ``self.device``.
        """
        if self.mode == 1:
            self.net = build_model().to(self.device)
        else:
            self.net = build_modelv2().to(self.device)
        if self.config.mode == 'train':
            self.loss = Loss().to(self.device)
        self.net.train()
        self.net.apply(weights_init)
        if self.config.load == '':
            self.net.base.load_state_dict(torch.load(self.config.vgg, map_location=self.device))
        else:
            self.net.load_state_dict(torch.load(self.config.load, map_location=self.device))
        self.optimizer = Adam(self.net.parameters(), self.config.lr)
        self.print_network(self.net, 'DSS')

    # update the learning rate
    def update_lr(self, lr):
        """Set the learning rate of every optimizer param group to ``lr``."""
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

    # evaluate MAE (for test or validation phase)
    def eval_mae(self, y_pred, y):
        """Mean absolute error between prediction and label tensors."""
        return torch.abs(y_pred - y).mean()

    def eval_pr(self, y_pred, y, num):
        """Precision/recall of ``y_pred`` against ``y`` at ``num`` thresholds in [0, 1).

        Both denominators carry a tiny epsilon so an empty prediction or an
        all-background label yields 0 instead of NaN.
        """
        prec, recall = torch.zeros(num), torch.zeros(num)
        thlist = torch.linspace(0, 1 - 1e-10, num)
        for i in range(num):
            y_temp = (y_pred >= thlist[i]).float()
            tp = (y_temp * y).sum()
            prec[i], recall[i] = tp / (y_temp.sum() + 1e-20), tp / (y.sum() + 1e-20)
        return prec, recall

    def _fuse(self, pred):
        """Reduce the network output to a single saliency map.

        In mode 1 the side maps indexed by ``self.select`` are concatenated
        along dim 1 and averaged. Otherwise ``pred`` is returned unchanged.
        """
        if self.mode == 1:
            return torch.mean(torch.cat([pred[k] for k in self.select], dim=1), dim=1, keepdim=True)
        return pred

    # validation: using resize image, and only evaluate the MAE metric
    def validation(self):
        """Return the mean per-batch MAE of the fused prediction over ``val_loader``."""
        avg_mae = 0.0
        self.net.eval()
        with torch.no_grad():
            for images, y1, y2, y3, y4, labels in self.val_loader:
                images, labels = images.to(self.device), labels.to(self.device)
                y1, y2 = y1.to(self.device), y2.to(self.device)
                y3, y4 = y3.to(self.device), y4.to(self.device)
                prob_pred = self._fuse(self.net(images, y1, y2, y3, y4))
                avg_mae += self.eval_mae(prob_pred, labels).item()
        self.net.train()
        return avg_mae / len(self.val_loader)

    # test phase: using origin image size, evaluate MAE and max F_beta metrics
    def test(self, num, use_crf=False):
        """Report per-image MAE plus average MAE and max F-measure on ``test_dataset``.

        Images whose transformed tensor is not 1x3x256x256 are skipped. They
        are excluded from the averages, which divide by the number of images
        actually evaluated.
        """
        if use_crf:
            from tools.crf_process import crf
        avg_mae, img_num = 0.0, 0
        avg_prec, avg_recall = torch.zeros(num), torch.zeros(num)
        with torch.no_grad():
            for i, (img, y1, y2, y3, y4, labels) in enumerate(self.test_dataset):
                images = self.transform(img).unsqueeze(0)
                y1 = self.transform(y1).unsqueeze(0)
                y2 = self.transform(y2).unsqueeze(0)
                y3 = self.transform(y3).unsqueeze(0)
                y4 = self.transform(y4).unsqueeze(0)
                if images.shape != torch.Size([1, 3, 256, 256]):
                    continue
                labels = labels.unsqueeze(0)
                shape = labels.size()[2:]
                images = images.to(self.device)
                y1, y2 = y1.to(self.device), y2.to(self.device)
                y3, y4 = y3.to(self.device), y4.to(self.device)
                prob_pred = self._fuse(self.net(images, y1, y2, y3, y4))
                prob_pred = F.interpolate(prob_pred, size=shape, mode='bilinear',
                                          align_corners=True).cpu().data
                if use_crf:
                    prob_pred = crf(img, prob_pred.numpy(), to_tensor=True)
                mae = self.eval_mae(prob_pred, labels)
                prec, recall = self.eval_pr(prob_pred, labels, num)
                print("[%d] mae: %.4f" % (i, mae))
                print("[%d] mae: %.4f" % (i, mae), file=self.test_output)
                avg_mae += mae
                avg_prec, avg_recall = avg_prec + prec, avg_recall + recall
                img_num += 1
        if img_num == 0:
            print('no test image could be evaluated')
            print('no test image could be evaluated', file=self.test_output)
            return
        avg_mae, avg_prec, avg_recall = avg_mae / img_num, avg_prec / img_num, avg_recall / img_num
        score = (1 + self.beta ** 2) * avg_prec * avg_recall / (self.beta ** 2 * avg_prec + avg_recall)
        score[score != score] = 0  # replace NaN (0/0) with 0
        print('average mae: %.4f, max fmeasure: %.4f' % (avg_mae, score.max()))
        print('average mae: %.4f, max fmeasure: %.4f' % (avg_mae, score.max()), file=self.test_output)

    # training phase
    def train(self):
        """Train for ``config.epoch`` epochs, logging loss/MAE and saving checkpoints."""
        iter_num = len(self.train_loader.dataset) // self.config.batch_size
        best_mae = 1.0 if self.config.val else None
        self.lossfile.write("epoch\tavg_loss\n")
        self.maefile.write("epoch\tavg_mae\n")
        for epoch in range(self.config.epoch):
            loss_epoch = 0
            # Learning-rate decay: drop to a tenth of the initial rate at epoch 30.
            # Re-setting it to config.lr, as before, left the rate unchanged.
            if epoch == 30:
                self.update_lr(self.config.lr / 10.0)
            mae = 0
            for i, data_batch in enumerate(self.train_loader):
                if (i + 1) > iter_num:
                    break
                self.net.zero_grad()
                x, y1, y2, y3, y4, y = data_batch
                x, y1, y2, y3, y4, y = x.to(self.device), y1.to(self.device), y2.to(self.device), \
                    y3.to(self.device), y4.to(self.device), y.to(self.device)
                y_pred = self.net(x, y1, y2, y3, y4)
                loss = self.loss(y_pred, y)
                loss.backward()
                utils.clip_grad_norm_(self.net.parameters(), self.config.clip_gradient)
                self.optimizer.step()
                loss_epoch += float(loss.item())
                # Mode 1 returns several side maps; fuse them before measuring MAE.
                tmp_mae = self.eval_mae(self._fuse(y_pred).detach(), y).item()
                mae += tmp_mae
                print('epoch: [%d/%d], iter: [%d/%d], loss: [%.4f], mae: [%.4f]' % (
                    epoch, self.config.epoch, i, iter_num, loss.item(), tmp_mae))
                if self.config.visdom:
                    error = OrderedDict([('loss:', loss.item())])
                    self.visual.plot_current_errors(epoch, i / iter_num, error)
            avg_loss = loss_epoch / iter_num
            self.lossfile.write("%d\t%.4f\n" % (epoch, avg_loss))
            self.lossfile.flush()
            avg_mae = mae / iter_num
            self.maefile.write("%d\t%.4f\n" % (epoch, avg_mae))
            self.maefile.flush()
            if (epoch + 1) % self.config.epoch_show == 0:
                print('epoch: [%d/%d], epoch_loss: [%.4f]' % (epoch, self.config.epoch, loss_epoch / iter_num),
                      file=self.log_output)
                if self.config.visdom:
                    avg_err = OrderedDict([('avg_loss', loss_epoch / iter_num)])
                    self.visual.plot_current_errors(epoch, i / iter_num, avg_err, 1)
                    y_show = self._fuse(y_pred)
                    img = OrderedDict([('origin', x.cpu()[0] * self.std + self.mean),
                                       ('label', y.cpu()[0][0]),
                                       ('pred_label', y_show.cpu().data[0][0])])
                    self.visual.plot_current_img(img)
            if self.config.val and (epoch + 1) % self.config.epoch_val == 0:
                mae = self.validation()
                print('--- Best MAE: %.2f, Curr MAE: %.2f ---' % (best_mae, mae))
                print('--- Best MAE: %.2f, Curr MAE: %.2f ---' % (best_mae, mae), file=self.log_output)
                if best_mae > mae:
                    best_mae = mae
                    torch.save(self.net.state_dict(), '%s/models/mybest.pth' % self.config.save_fold)
            if (epoch + 1) % self.config.epoch_save == 0:
                torch.save(self.net.state_dict(), '%s/models/epoch_%d.pth' % (self.config.save_fold, epoch + 1))
        torch.save(self.net.state_dict(), '%s/models/final.pth' % self.config.save_fold)
class Solver(object): def __init__(self, train_loader, target_loader, val_loader, test_dataset, config): self.train_loader = train_loader self.val_loader = val_loader self.test_dataset = test_dataset self.target_loader = target_loader self.config = config self.beta = math.sqrt(0.3) # for max F_beta metric # inference: choose the side map (see paper) self.select = [1, 2, 3, 6] self.device = torch.device('cpu') self.mean = torch.Tensor([0.485, 0.456, 0.406]).view(3, 1, 1) self.std = torch.Tensor([0.229, 0.224, 0.225]).view(3, 1, 1) self.TENSORBOARD_LOGDIR = f'{config.save_fold}/tensorboards' self.TENSORBOARD_VIZRATE = 100 if self.config.cuda: cudnn.benchmark = True self.device = torch.device('cuda:0') if config.visdom: self.visual = Viz_visdom("DSS", 1) self.build_model() if self.config.pre_trained: self.net.load_state_dict(torch.load(self.config.pre_trained)) if config.mode == 'train': self.log_output = open("%s/logs/log.txt" % config.save_fold, 'w') self.val_output = open("%s/logs/val.txt" % config.save_fold, 'w') else: self.net.load_state_dict(torch.load(self.config.model)) self.net.eval() self.test_output = open("%s/test.txt" % config.test_fold, 'w') self.test_maeid = open("%s/mae_id.txt" % config.test_fold, 'w') self.test_outmap = config.test_map_fold self.transform = transforms.Compose([ transforms.Resize((256, 256)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) # print the network information and parameter numbers def print_network(self, model, name): num_params = 0 for p in model.parameters(): if p.requires_grad: num_params += p.numel() print(name) print(model) print("The number of parameters: {}".format(num_params)) # # build the network # def build_model(self): # self.net = build_model().to(self.device) # if self.config.mode == 'train': self.loss = Loss().to(self.device) # self.net.train() # self.net.apply(weights_init) # if self.config.load == '': 
self.net.base.load_state_dict(torch.load(self.config.vgg)) # if self.config.load != '': self.net.load_state_dict(torch.load(self.config.load)) # self.optimizer = Adam(self.net.parameters(), self.config.lr) # self.print_network(self.net, 'DSS') # # build the network --new def build_model(self): if self.config.mode == 'train': self.loss = Loss().to(self.device) self.l2loss = nn.MSELoss().to(self.device) self.iouloss = IoULoss().to(self.device) self.net = build_model().to(self.device) self.net.train() self.net.apply(weights_init) if self.config.load == '': self.net.base.load_state_dict(torch.load(self.config.vgg)) if self.config.load != '': self.net.load_state_dict(torch.load(self.config.load)) self.optimizer = Adam(self.net.parameters(), self.config.lr) self.net2 = build_model().to(self.device) self.net2.train() self.net2.apply(weights_init) if self.config.load == '': self.net2.base.load_state_dict(torch.load(self.config.vgg)) if self.config.load != '': self.net2.load_state_dict(torch.load(self.config.load)) self.optimizer2 = Adam(self.net2.parameters(), self.config.lr) # self.print_network(self.net, 'DSS') # update the learning rate def update_lr(self, lr): for param_group in self.optimizer.param_groups: param_group['lr'] = lr # evaluate MAE (for test or validation phase) def eval_mae(self, y_pred, y): return torch.abs(y_pred - y).mean() # TODO: write a more efficient version # get precisions and recalls: threshold---divided [0, 1] to num values def eval_pr(self, y_pred, y, num): prec, recall = torch.zeros(num), torch.zeros(num) thlist = torch.linspace(0, 1 - 1e-10, num) for i in range(num): y_temp = (y_pred >= thlist[i]).float() tp = (y_temp * y).sum() # prec[i], recall[i] = tp / (y_temp.sum() + 1e-20), tp / y.sum() prec[i], recall[i] = tp / (y_temp.sum() + 1e-20), tp / (y.sum() + 1e-20) return prec, recall # validation: using resize image, and only evaluate the MAE metric def validation(self): avg_mae, avg_loss = 0.0, 0.0 self.net.eval() with torch.no_grad(): for 
i, data_batch in enumerate(self.val_loader): images, labels = data_batch shape = labels.size()[2:] images, labels = images.to(self.device), labels.to(self.device) _, prob_pred = self.net(images) # for side_num in range(len(prob_pred)): # tmp = torch.sigmoid(prob_pred[side_num])[0] # tmp = tmp.cpu().data # img = ToPILImage()(tmp) # img.save(self.config.val_fold_sub + '/' + self.val_loader.dataset.label_path[i][36:-4] +'_side_' + str(side_num) + '.png') # prob_pred1 = torch.mean(torch.cat([prob_pred[i] for i in self.select], dim=1), dim=1, keepdim=True) # prob_pred1 = F.interpolate(prob_pred, size=shape, mode='bilinear', align_corners=True) # prob_pred2 = torch.mean(torch.cat([torch.sigmoid(prob_pred[i]) for i in self.select], dim=1), dim=1, keepdim=True) # prob_pred2 = F.interpolate(prob_pred2, size=shape, mode='bilinear', align_corners=True) prob_pred2 = F.interpolate(torch.sigmoid(prob_pred), size=shape, mode='bilinear', align_corners=True) # avg_loss += self.loss(prob_pred2, labels).item() avg_mae += self.eval_mae(prob_pred2, labels).item() self.net.train() return avg_mae / len(self.val_loader), avg_loss / len(self.val_loader) # test phase: using origin image size, evaluate MAE and max F_beta metrics def test(self, num, use_crf=False): if use_crf: from tools.crf_process import crf dic = {} avg_mae, img_num = 0.0, len(self.test_dataset) avg_prec, avg_recall = torch.zeros(num), torch.zeros(num) with torch.no_grad(): for i, (img, labels) in enumerate(self.test_dataset): images = self.transform(img).unsqueeze(0) labels = labels.unsqueeze(0) shape = labels.size()[2:] images = images.to(self.device) _, prob_pred = self.net(images) # prob_pred = torch.mean(torch.cat([torch.sigmoid(prob_pred[i]) for i in self.select], dim=1), dim=1, keepdim=True) prob_pred = F.interpolate(torch.sigmoid(prob_pred), size=shape, mode='bilinear', align_corners=True).cpu().data # prob_pred = F.interpolate(prob_pred, size=shape, mode='bilinear', align_corners=True).cpu().data if use_crf: 
prob_pred = crf(img, prob_pred.numpy(), to_tensor=True) mae = self.eval_mae(prob_pred, labels) # dic.update({self.test_dataset.label_path[i][self.config.test_map_save_pos:-4] : mae}) prec, recall = self.eval_pr(prob_pred, labels, num) tmp = prob_pred[0] imgpred = ToPILImage()(tmp) imgpred.save(self.test_outmap + '/' + self.test_dataset.label_path[i] [self.config.test_map_save_pos:]) print("[%d] mae: %.4f" % (i, mae)) print("[%d] mae: %.4f" % (i, mae), file=self.test_output) avg_mae += mae avg_prec, avg_recall = avg_prec + prec, avg_recall + recall avg_mae, avg_prec, avg_recall = avg_mae / img_num, avg_prec / img_num, avg_recall / img_num score = (1 + self.beta**2) * avg_prec * avg_recall / ( self.beta**2 * avg_prec + avg_recall) score[score != score] = 0 # delete the nan print('average mae: %.4f, max fmeasure: %.4f' % (avg_mae, score.max())) print('average mae: %.4f, max fmeasure: %.4f' % (avg_mae, score.max()), file=self.test_output) # dic_sorted = sorted(dic.items(), key = lambda kv:(kv[1], kv[0]),reverse=True) # # file1 = open('/data1/liumengmeng/CG4_id_mae/HKU-IS.txt','w') # for i in range(int(len(dic_sorted)*0.1)): # print(dic_sorted[i][0] ,file=self.test_maeid) def test_bg(self): dic = {} with torch.no_grad(): for i, img in enumerate(self.test_dataset): print(self.test_dataset.image_path[i] [self.config.test_map_save_pos:-4]) try: images = self.transform(img).unsqueeze(0) images = images.to(self.device) prob_pred = self.net(images) prob_pred = torch.mean(torch.cat( [torch.sigmoid(prob_pred[i]) for i in self.select], dim=1), dim=1, keepdim=True) prob_pred = prob_pred.cpu().data tmp = prob_pred[0] probarray = tmp.numpy() num_1 = len(np.argwhere(probarray > 0.5)) ratio = num_1 / (tmp.shape[1] * tmp.shape[2]) dic.update({ self.test_dataset.image_path[i][self.config.test_map_save_pos:-4]: ratio }) print(ratio) except TypeError as tycode: print(self.test_dataset.image_path[i] [self.config.test_map_save_pos:-4], file=filebad_id) dic_sorted = sorted(dic.items(), 
key=lambda kv: (kv[1], kv[0])) for i in dic_sorted: print(f'{i[0]} : {i[1]}', file=self.test_output) print(i[0], file=self.test_bg_id) def train(self): num_classes = 1 viz_tensorboard = os.path.exists(self.TENSORBOARD_LOGDIR) if viz_tensorboard: writer = SummaryWriter(log_dir=self.TENSORBOARD_LOGDIR) # # DISCRIMINATOR NETWORK # d_main = get_fc_discriminator(num_classes=num_classes) # d_main.train() # d_main.to(self.device) # # # OPTIMIZERS # # # discriminators' optimizers # optimizer_d_main = optim.Adam(d_main.parameters(), lr=self.config.lr_d, # betas=(0.9, 0.99)) # # LABELS for adversarial training------------------------------------------------------- # source_label = 0 # target_label = 1 trainloader_iter = enumerate(self.train_loader) targetloader_iter = enumerate(self.target_loader) best_mae = 1.0 if self.config.val else None for i_iter in tqdm(range(self.config.early_stop)): # if i_iter >= 3000: # self.update_lr(1e-5) # # reset optimizers self.optimizer.zero_grad() self.optimizer2.zero_grad() # optimizer_d_main.zero_grad() # # adapt LR if needed # adjust_learning_rate(self.optimizer, i_iter, cfg) # adjust_learning_rate_discriminator(optimizer_d_aux, i_iter, cfg) # adjust_learning_rate_discriminator(optimizer_d_main, i_iter, cfg) # # UDA Training-------------------------------------------------------------------------- # # only train segnet. 
Don't accumulate grads in disciminators # for param in d_main.parameters(): # param.requires_grad = False # # train on source with seg loss # _, batch = trainloader_iter.__next__() # imgs_src, labels_src = batch # imgs_src, labels_src = imgs_src.to(self.device), labels_src.to(self.device) # pred_src_main = self.net(imgs_src) # # loss_seg_src = self.loss(pred_src_main[0], labels_src) #side output 1 # loss_seg_src = self.loss(pred_src_main, labels_src) #side output 1 - 6 with fusion # loss = loss_seg_src # loss.backward() # utils.clip_grad_norm_(self.net.parameters(), self.config.clip_gradient) # # train on target with seg loss # _, batch1 = targetloader_iter.__next__() # imgs_trg, labels_trg = batch1 # imgs_trg, labels_trg = imgs_trg.to(self.device), labels_trg.to(self.device) # pred_trg = self.net(imgs_trg) # loss_seg_trg = self.loss(pred_trg[5], labels_trg) # side output 6 # loss = loss_seg_trg # loss.backward() # utils.clip_grad_norm_(self.net.parameters(), self.config.clip_gradient) #----train on source branch------------------------------ _, batch = trainloader_iter.__next__() imgs_src, labels_src = batch imgs_src, labels_src = imgs_src.to(self.device), labels_src.to( self.device) smap, pred_src = self.net(imgs_src) stmap, _ = self.net2(imgs_src) loss_seg_src = self.loss(pred_src, labels_src) #sigmoid BCE loss loss_fc_src = self.l2loss(smap, stmap) #L2 loss -> self attention maps loss = loss_seg_src + loss_fc_src loss.backward() # utils.clip_grad_norm_(self.net.parameters(), self.config.clip_gradient) #----train on target branch----------------------------- _, batch1 = targetloader_iter.__next__() imgs_trg, labels_trg = batch1 imgs_trg, labels_trg = imgs_trg.to(self.device), labels_trg.to( self.device) tmap, pred_trg = self.net2(imgs_trg) tsmap, _ = self.net(imgs_trg) loss_ctr_trg = self.iouloss(pred_trg[-1], labels_trg) # IoU loss: dns6 loss_fc_trg = self.l2loss(tmap, tsmap) #L2 loss -> self attention maps loss = loss_ctr_trg + loss_fc_trg loss.backward() 
utils.clip_grad_norm_(self.net.parameters(), self.config.clip_gradient) utils.clip_grad_norm_(self.net2.parameters(), self.config.clip_gradient) current_losses = { 'loss_seg_src': loss_seg_src, 'loss_fc_src': loss_fc_src, 'loss_ctr_trg': loss_ctr_trg, 'loss_fc_trg': loss_fc_trg } #-------add ADVENT-------------------------------------------------------------------------------- if self.config.add_adv: # adversarial training ot fool the discriminator _, batch = targetloader_iter.__next__() images = batch images = images.to(self.device) pred_trg_main = self.net(images) # d_out_main = d_main(torch.sigmoid(pred_trg_main)) # loss_adv_trg_main = bce_loss(d_out_main, source_label) # loss_adv_trg = self.config.LAMBDA_ADV_MAIN * loss_adv_trg_main # loss = loss_adv_trg # loss.backward() # d_out_main = d_main(prob_2_entropy(pred_trg_main[0])) d_out_main = d_main(torch.sigmoid(pred_trg_main[0])) loss_adv_trg_main = bce_loss(d_out_main, source_label) loss_adv_trg = self.config.LAMBDA_ADV_MAIN * loss_adv_trg_main for i in range(len(pred_trg_main) - 1): # d_out_main = d_main(prob_2_entropy(pred_trg_main[i+1])) d_out_main = d_main(torch.sigmoid(pred_trg_main[i + 1])) loss_adv_trg_main = bce_loss(d_out_main, source_label) loss_adv_trg += self.config.LAMBDA_ADV_MAIN * loss_adv_trg_main loss = loss_adv_trg loss.backward() # Train discriminator networks-------------------- # enable training mode on discriminator networks for param in d_main.parameters(): param.requires_grad = True # # train with source # pred_src_main = pred_src_main.detach() # d_out_main = d_main(torch.sigmoid(pred_src_main)) # loss_d_main = bce_loss(d_out_main, source_label) # loss_d_src = loss_d_main / 2 # loss_d = loss_d_src # loss_d.backward() # # train with target # pred_trg_main = pred_trg_main.detach() # d_out_main = d_main(torch.sigmoid(pred_trg_main)) # loss_d_main = bce_loss(d_out_main, target_label) # loss_d_trg = loss_d_main / 2 # loss_d = loss_d_trg # loss_d.backward() # train with source pred_src_main[0] 
= pred_src_main[0].detach() # d_out_main = d_main(prob_2_entropy(pred_src_main[0])) d_out_main = d_main(torch.sigmoid(pred_src_main[0])) loss_d_main = bce_loss(d_out_main, source_label) loss_d_src = loss_d_main / 2 for i in range(len(pred_src_main) - 1): pred_src_main[i + 1] = pred_src_main[i + 1].detach() # d_out_main = d_main(prob_2_entropy(pred_src_main[i+1])) d_out_main = d_main(torch.sigmoid(pred_src_main[i + 1])) loss_d_main = bce_loss(d_out_main, source_label) loss_d_src += loss_d_main / 2 loss_d = loss_d_src loss_d.backward() # train with target pred_trg_main[0] = pred_trg_main[0].detach() # d_out_main = d_main(prob_2_entropy(pred_trg_main[0])) d_out_main = d_main(torch.sigmoid(pred_trg_main[0])) loss_d_main = bce_loss(d_out_main, target_label) loss_d_trg = loss_d_main / 2 for i in range(len(pred_trg_main) - 1): pred_trg_main[i + 1] = pred_trg_main[i + 1].detach() # d_out_main = d_main(prob_2_entropy(pred_trg_main[i+1])) d_out_main = d_main(torch.sigmoid(pred_trg_main[i + 1])) loss_d_main = bce_loss(d_out_main, target_label) loss_d_trg += loss_d_main / 2 loss_d = loss_d_trg loss_d.backward() current_losses = { 'loss_seg_src': loss_seg_src, 'loss_adv_trg': loss_adv_trg, 'loss_d_src': loss_d_src, 'loss_d_trg': loss_d_trg } # # optimizer.step()------------------------------------------------------------------------------ self.optimizer.step() self.optimizer2.step() # optimizer_d_main.step() # current_losses = { # 'loss_seg_src': loss_seg_src} # # 'loss_adv_trg': loss_adv_trg, # # 'loss_d_src': loss_d_src, # # 'loss_d_trg': loss_d_trg} print_losses(current_losses, i_iter, self.log_output) if self.config.val and (i_iter + 1) % self.config.iter_val == 0: # val = i_iter + 1 # os.mkdir("%s/val-%d" % (self.config.val_fold, val)) # self.config.val_fold_sub = "%s/val-%d" % (self.config.val_fold, val) mae, loss_val = self.validation() log_vals_tensorboard(writer, best_mae, mae, loss_val, i_iter + 1) tqdm.write('%d:--- Best MAE: %.4f, Curr MAE: %.4f ---' % ((i_iter + 
1), best_mae, mae)) print(' %d:--- Best MAE: %.4f, Curr MAE: %.4f ---' % ((i_iter + 1), best_mae, mae), file=self.log_output) print(' %d:--- Best MAE: %.4f, Curr MAE: %.4f ---' % ((i_iter + 1), best_mae, mae), file=self.val_output) if best_mae > mae: best_mae = mae torch.save(self.net.state_dict(), '%s/models/best.pth' % self.config.save_fold) if (i_iter + 1) % self.config.iter_save == 0 and i_iter != 0: # tqdm.write('taking snapshot ...') # torch.save(self.net.state_dict(), '%s/models/iter_%d.pth' % (self.config.save_fold, i_iter + 1)) # torch.save(d_main.state_dict(), '%s/models/iter_Discriminator_%d.pth' % (self.config.save_fold, i_iter + 1)) if i_iter >= self.config.early_stop - 1: break sys.stdout.flush() if viz_tensorboard: log_losses_tensorboard(writer, current_losses, i_iter) # if i_iter % self.TENSORBOARD_VIZRATE == self.TENSORBOARD_VIZRATE - 1: # draw_in_tensorboard(writer, images, i_iter, pred_trg_main, num_classes, 'T') # draw_in_tensorboard(writer, images_source, i_iter, pred_src_main, num_classes, 'S') # torch.save(self.net.state_dict(), '%s/models/final.pth' % self.config.save_fold) def train_old(self): print(len(self.train_loader.dataset)) iter_num = len(self.train_loader.dataset) // self.config.batch_size best_mae = 1.0 if self.config.val else None for epoch in range(self.config.epoch): loss_epoch = 0 for i, data_batch in enumerate(self.train_loader): if (i + 1) > iter_num: break self.net.zero_grad() x, y = data_batch x, y = x.to(self.device), y.to(self.device) y_pred = self.net(x) loss = self.loss(y_pred, y) loss.backward() utils.clip_grad_norm_(self.net.parameters(), self.config.clip_gradient) # utils.clip_grad_norm(self.loss.parameters(), self.config.clip_gradient) self.optimizer.step() loss_epoch += loss.item() print('epoch: [%d/%d], iter: [%d/%d], loss: [%.4f]' % (epoch, self.config.epoch, i, iter_num, loss.item())) if self.config.visdom: error = OrderedDict([('loss:', loss.item())]) self.visual.plot_current_errors(epoch, i / iter_num, error) 
if (epoch + 1) % self.config.epoch_show == 0: print('epoch: [%d/%d], epoch_loss: [%.4f]' % (epoch, self.config.epoch, loss_epoch / iter_num), file=self.log_output) if self.config.visdom: avg_err = OrderedDict([('avg_loss', loss_epoch / iter_num) ]) self.visual.plot_current_errors(epoch, i / iter_num, avg_err, 1) y_show = torch.mean(torch.cat( [y_pred[i] for i in self.select], dim=1), dim=1, keepdim=True) img = OrderedDict([ ('origin', x.cpu()[0] * self.std + self.mean), ('label', y.cpu()[0][0]), ('pred_label', y_show.cpu().data[0][0]) ]) self.visual.plot_current_img(img) # if self.config.val and (epoch + 1) % self.config.epoch_val == 0: # mae = self.validation() # print('--- Best MAE: %.2f, Curr MAE: %.2f ---' % (best_mae, mae)) # print('--- Best MAE: %.2f, Curr MAE: %.2f ---' % (best_mae, mae), file=self.log_output) # if best_mae > mae: # best_mae = mae # torch.save(self.net.state_dict(), '%s/models/best.pth' % self.config.save_fold) if (epoch + 1) % self.config.epoch_save == 0: torch.save( self.net.state_dict(), '%s/models/epoch_%d.pth' % (self.config.save_fold, epoch + 1)) torch.save(self.net.state_dict(), '%s/models/final.pth' % self.config.save_fold)