def __init__(self, args, config, cuda=None):
    """Build the loss, model, optimizer and data loader for VOC training.

    :param args: run options. This method reads ``gpu``, ``data_root_path``,
        ``loss_weight``, ``dataset``, ``output_stride``,
        ``imagenet_pretrained``, ``bn_momentum``, ``freeze_bn`` and ``lr``.
    :param config: static configuration. This method reads ``epoch_num``,
        ``num_classes``, ``classes_weight``, ``momentum`` and
        ``weight_decay``.
    :param cuda: truthy to request CUDA; it is only used when
        ``torch.cuda.is_available()`` is also true.
    """
    self.args = args
    # Must be set before the first CUDA call so the restriction is honoured.
    os.environ["CUDA_VISIBLE_DEVICES"] = self.args.gpu
    self.config = config
    # bool() so that cuda=None yields False rather than None.
    self.cuda = bool(cuda) and torch.cuda.is_available()
    self.device = torch.device('cuda' if self.cuda else 'cpu')

    self.best_MIou = 0
    self.current_epoch = 0
    self.epoch_num = self.config.epoch_num
    self.current_iter = 0

    self.writer = SummaryWriter()

    # path definition
    self.val_list_filepath = os.path.join(
        args.data_root_path, 'VOC2012/ImageSets/Segmentation/val.txt')
    self.gt_filepath = os.path.join(args.data_root_path,
                                    'VOC2012/SegmentationClass/')
    self.pre_filepath = os.path.join(args.data_root_path,
                                     'VOC2012/JPEGImages/')

    # Metric definition
    self.Eval = Eval(self.config.num_classes)

    # loss definition
    if args.loss_weight:
        classes_weights_path = os.path.join(
            self.config.classes_weight,
            self.args.dataset + 'classes_weights_log.npy')
        print(classes_weights_path)
        if not os.path.isfile(classes_weights_path):
            # The weights file is generated on first use.
            logger.info('calculating class weights...')
            calculate_weigths_labels(self.config)
        class_weights = np.load(classes_weights_path)
        pprint.pprint(class_weights)
        weight = torch.from_numpy(class_weights.astype(np.float32))
        logger.info('loading class weights successfully!')
    else:
        weight = None

    self.loss = nn.CrossEntropyLoss(weight=weight, ignore_index=255)
    self.loss.to(self.device)

    # model
    self.model = DeepLab(output_stride=self.args.output_stride,
                         class_num=self.config.num_classes,
                         pretrained=self.args.imagenet_pretrained,
                         bn_momentum=self.args.bn_momentum,
                         freeze_bn=self.args.freeze_bn)
    # Use every GPU left visible by CUDA_VISIBLE_DEVICES instead of a
    # hard-coded four, which fails on machines with fewer devices.
    visible_gpus = torch.cuda.device_count()
    self.model = nn.DataParallel(
        self.model,
        device_ids=list(range(visible_gpus)) if visible_gpus else None)
    patch_replication_callback(self.model)
    self.model.to(self.device)

    # Two parameter groups: the "10x" group trains at ten times the base LR.
    self.optimizer = torch.optim.SGD(
        params=[
            {
                "params": self.get_params(self.model.module, key="1x"),
                "lr": self.args.lr,
            },
            {
                "params": self.get_params(self.model.module, key="10x"),
                "lr": 10 * self.args.lr,
            },
        ],
        momentum=self.config.momentum,
        weight_decay=self.config.weight_decay,
    )

    # dataloader
    self.dataloader = VOCDataLoader(self.args, self.config)
class Trainer():
    """Train and validate a DeepLab segmentation model on PASCAL VOC 2012.

    The trainer owns the model, loss, SGD optimizer (two parameter groups,
    the second at 10x the base learning rate), the data loader, the metric
    accumulator and a TensorBoard writer.
    """

    def __init__(self, args, config, cuda=None):
        """Build the loss, model, optimizer and data loader.

        :param args: run options (``gpu``, ``data_root_path``,
            ``loss_weight``, ``dataset``, ``output_stride``,
            ``imagenet_pretrained``, ``bn_momentum``, ``freeze_bn``, ``lr``
            are read here).
        :param config: static configuration (``epoch_num``, ``num_classes``,
            ``classes_weight``, ``momentum``, ``weight_decay`` are read here).
        :param cuda: truthy to request CUDA; only used when CUDA is available.
        """
        self.args = args
        # Must be set before the first CUDA call so the restriction applies.
        os.environ["CUDA_VISIBLE_DEVICES"] = self.args.gpu
        self.config = config
        # bool() so that cuda=None yields False rather than None.
        self.cuda = bool(cuda) and torch.cuda.is_available()
        self.device = torch.device('cuda' if self.cuda else 'cpu')

        self.best_MIou = 0
        self.current_epoch = 0
        self.epoch_num = self.config.epoch_num
        self.current_iter = 0

        self.writer = SummaryWriter()

        # path definition
        self.val_list_filepath = os.path.join(
            args.data_root_path, 'VOC2012/ImageSets/Segmentation/val.txt')
        self.gt_filepath = os.path.join(args.data_root_path,
                                        'VOC2012/SegmentationClass/')
        self.pre_filepath = os.path.join(args.data_root_path,
                                         'VOC2012/JPEGImages/')

        # Metric definition
        self.Eval = Eval(self.config.num_classes)

        # loss definition
        if args.loss_weight:
            classes_weights_path = os.path.join(
                self.config.classes_weight,
                self.args.dataset + 'classes_weights_log.npy')
            print(classes_weights_path)
            if not os.path.isfile(classes_weights_path):
                # The weights file is generated on first use.
                logger.info('calculating class weights...')
                calculate_weigths_labels(self.config)
            class_weights = np.load(classes_weights_path)
            pprint.pprint(class_weights)
            weight = torch.from_numpy(class_weights.astype(np.float32))
            logger.info('loading class weights successfully!')
        else:
            weight = None

        self.loss = nn.CrossEntropyLoss(weight=weight, ignore_index=255)
        self.loss.to(self.device)

        # model
        self.model = DeepLab(output_stride=self.args.output_stride,
                             class_num=self.config.num_classes,
                             pretrained=self.args.imagenet_pretrained,
                             bn_momentum=self.args.bn_momentum,
                             freeze_bn=self.args.freeze_bn)
        # Use every GPU left visible by CUDA_VISIBLE_DEVICES instead of a
        # hard-coded four, which fails on machines with fewer devices.
        visible_gpus = torch.cuda.device_count()
        self.model = nn.DataParallel(
            self.model,
            device_ids=list(range(visible_gpus)) if visible_gpus else None)
        patch_replication_callback(self.model)
        self.model.to(self.device)

        self.optimizer = torch.optim.SGD(
            params=[
                {
                    "params": self.get_params(self.model.module, key="1x"),
                    "lr": self.args.lr,
                },
                {
                    "params": self.get_params(self.model.module, key="10x"),
                    "lr": 10 * self.args.lr,
                },
            ],
            momentum=self.config.momentum,
            weight_decay=self.config.weight_decay,
        )

        # dataloader
        self.dataloader = VOCDataLoader(self.args, self.config)

    def main(self):
        """Log the configuration, optionally load a checkpoint, then train."""
        logger.info("Global configuration as follows:")
        pprint.pprint(self.config)
        pprint.pprint(self.args)

        if self.cuda:
            current_device = torch.cuda.current_device()
            logger.info("This model will run on {}".format(
                torch.cuda.get_device_name(current_device)))
        else:
            logger.info("This model will run on CPU")

        if self.args.pretrained:
            self.load_checkpoint(self.args.saved_checkpoint_file)

        self.train()

        self.writer.close()

    def train(self):
        """Run train/validate cycles until ``epoch_num`` or ``iter_max``."""
        for epoch in tqdm(range(self.current_epoch, self.epoch_num),
                          desc="Total {} epochs".format(
                              self.config.epoch_num)):
            # Once the iteration budget is spent no batch would be trained,
            # so stop instead of starting an empty epoch.
            if self.current_iter >= self.args.iter_max:
                logger.info("iteration arrive {}!".format(self.args.iter_max))
                break

            self.current_epoch = epoch
            self.train_one_epoch()

            # validate
            PA, MPA, MIoU, FWIoU = self.validate()
            self.writer.add_scalar('PA', PA, self.current_epoch)
            self.writer.add_scalar('MPA', MPA, self.current_epoch)
            self.writer.add_scalar('MIoU', MIoU, self.current_epoch)
            self.writer.add_scalar('FWIoU', FWIoU, self.current_epoch)

            is_best = MIoU > self.best_MIou
            if is_best:
                self.best_MIou = MIoU
            self.save_checkpoint(is_best, self.args.store_checkpoint_name)

    def train_one_epoch(self):
        """Train over the training loader once and log loss and metrics.

        :raises ValueError: if the loss becomes NaN.
        """
        tqdm_epoch = tqdm(self.dataloader.train_loader,
                          total=self.dataloader.train_iterations,
                          desc="Train Epoch-{}-".format(self.current_epoch + 1))
        logger.info("Training one epoch...")
        self.Eval.reset()

        train_loss = []
        # Training mode matters for batchnorm and dropout layers.
        self.model.train()
        batch_idx = 0
        for x, y, _ in tqdm_epoch:
            self.poly_lr_scheduler(
                optimizer=self.optimizer,
                init_lr=self.args.lr,
                iter=self.current_iter,
                max_iter=self.args.iter_max,
                power=self.config.poly_power,
            )
            if self.current_iter >= self.args.iter_max:
                logger.info("iteration arrive {}!".format(self.args.iter_max))
                break
            self.writer.add_scalar('learning_rate',
                                   self.optimizer.param_groups[0]["lr"],
                                   self.current_iter)
            self.writer.add_scalar('learning_rate_10x',
                                   self.optimizer.param_groups[1]["lr"],
                                   self.current_iter)

            # Always move/cast: the loss needs integer class targets on CPU
            # just as much as on GPU.
            x = x.to(self.device)
            y = y.to(device=self.device, dtype=torch.long)

            self.optimizer.zero_grad()

            pred = self.model(x)
            y = torch.squeeze(y, 1)
            cur_loss = self.loss(pred, y)

            cur_loss.backward()
            self.optimizer.step()

            loss_value = cur_loss.item()
            train_loss.append(loss_value)

            if batch_idx % self.config.batch_save == 0:
                logger.info("The train loss of epoch{}-batch-{}:{}".format(
                    self.current_epoch, batch_idx, loss_value))
            batch_idx += 1
            self.current_iter += 1

            if np.isnan(float(loss_value)):
                raise ValueError('Loss is nan during training...')

            pred = pred.detach().cpu().numpy()
            label = y.cpu().numpy()
            argpred = np.argmax(pred, axis=1)
            self.Eval.add_batch(label, argpred)

        PA = self.Eval.Pixel_Accuracy()
        MPA = self.Eval.Mean_Pixel_Accuracy()
        MIoU = self.Eval.Mean_Intersection_over_Union()
        FWIoU = self.Eval.Frequency_Weighted_Intersection_over_Union()
        logger.info(
            'Epoch:{}, train PA1:{}, MPA1:{}, MIoU1:{}, FWIoU1:{}'.format(
                self.current_epoch, PA, MPA, MIoU, FWIoU))

        # No batch may have run (empty loader or budget already spent).
        if train_loss:
            tr_loss = sum(train_loss) / len(train_loss)
            self.writer.add_scalar('train_loss', tr_loss, self.current_epoch)
            tqdm.write("The average loss of train epoch-{}-:{}".format(
                self.current_epoch, tr_loss))
        tqdm_epoch.close()

    def validate(self):
        """Evaluate on the validation loader.

        :return: tuple ``(PA, MPA, MIoU, FWIoU)`` as reported by ``self.Eval``.
        :raises ValueError: if the loss becomes NaN.
        """
        logger.info('validating one epoch...')
        self.Eval.reset()
        with torch.no_grad():
            tqdm_batch = tqdm(self.dataloader.valid_loader,
                              total=self.dataloader.valid_iterations,
                              desc="Val Epoch-{}-".format(
                                  self.current_epoch + 1))
            val_loss = []
            self.model.eval()
            for x, y, _ in tqdm_batch:
                x = x.to(self.device)
                y = y.to(device=self.device, dtype=torch.long)

                pred = self.model(x)
                y = torch.squeeze(y, 1)

                cur_loss = self.loss(pred, y)
                loss_value = cur_loss.item()
                if np.isnan(float(loss_value)):
                    raise ValueError('Loss is nan during validating...')
                val_loss.append(loss_value)

                pred = pred.detach().cpu().numpy()
                label = y.cpu().numpy()
                argpred = np.argmax(pred, axis=1)
                self.Eval.add_batch(label, argpred)

            PA = self.Eval.Pixel_Accuracy()
            MPA = self.Eval.Mean_Pixel_Accuracy()
            MIoU = self.Eval.Mean_Intersection_over_Union()
            FWIoU = self.Eval.Frequency_Weighted_Intersection_over_Union()
            logger.info(
                'Epoch:{}, validation PA1:{}, MPA1:{}, MIoU1:{}, FWIoU1:{}'.
                format(self.current_epoch, PA, MPA, MIoU, FWIoU))

            # An empty validation loader yields no losses to average.
            if val_loss:
                v_loss = sum(val_loss) / len(val_loss)
                logger.info("The average loss of val loss:{}".format(v_loss))
                self.writer.add_scalar('val_loss', v_loss, self.current_epoch)

            tqdm_batch.close()

        return PA, MPA, MIoU, FWIoU

    def save_checkpoint(self, is_best, filename=None):
        """Write a checkpoint to ``args.checkpoint_dir`` when ``is_best``.

        :param is_best: only save when this is true; otherwise just log.
        :param filename: checkpoint file name inside ``args.checkpoint_dir``.
        """
        filename = os.path.join(self.args.checkpoint_dir, filename)
        state = {
            'epoch': self.current_epoch + 1,
            'iteration': self.current_iter,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'best_MIou': self.best_MIou
        }
        if is_best:
            logger.info("=>saving a new best checkpoint...")
            torch.save(state, filename)
        else:
            logger.info("=> The MIoU of val does't improve.")

    def load_checkpoint(self, filename):
        """Restore model, optimizer and best MIoU from a checkpoint file.

        A missing/unreadable file is logged and skipped. Epoch and iteration
        counters are intentionally not restored.

        :param filename: checkpoint file name inside ``args.checkpoint_dir``.
        """
        filename = os.path.join(self.args.checkpoint_dir, filename)
        logger.info("Loading checkpoint '{}'".format(filename))
        try:
            # map_location lets a checkpoint saved on GPU load on CPU.
            checkpoint = torch.load(filename, map_location=self.device)
        except OSError:
            logger.info("No checkpoint exists from '{}'. Skipping...".format(
                self.args.checkpoint_dir))
            logger.info("**First time to train**")
            return

        self.model.load_state_dict(checkpoint['state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.best_MIou = checkpoint['best_MIou']

        logger.info(
            "Checkpoint loaded successfully from '{}' at (epoch {}) at (iteration {},MIoU:{})\n"
            .format(self.args.checkpoint_dir, checkpoint['epoch'],
                    checkpoint['iteration'], checkpoint['best_MIou']))

    def get_params(self, model, key):
        """Yield Conv2d parameters for one optimizer group.

        :param model: module whose ``named_modules()`` are scanned.
        :param key: ``"1x"`` selects modules whose name contains
            ``Resnet101``; ``"10x"`` selects names containing ``encoder`` or
            ``decoder``. Any other key yields nothing.
        """
        if key == "1x":
            for name, module in model.named_modules():
                if "Resnet101" in name and isinstance(module, nn.Conv2d):
                    for p in module.parameters():
                        yield p
        if key == "10x":
            for name, module in model.named_modules():
                if ("encoder" in name or "decoder" in name) \
                        and isinstance(module, nn.Conv2d):
                    for p in module.parameters():
                        yield p

    def poly_lr_scheduler(self, optimizer, init_lr, iter, max_iter, power):
        """Set both param-group learning rates from a polynomial decay.

        ``lr = init_lr * (1 - iter / max_iter) ** power``; the second group
        gets ten times that value.

        :param optimizer: object exposing ``param_groups`` with two entries.
        :param init_lr: base learning rate at iteration 0.
        :param iter: current iteration.
        :param max_iter: iteration at which the rate reaches zero.
        :param power: exponent of the decay.
        """
        # Clamp at zero: past max_iter the base would be negative and a
        # fractional power would turn the learning rate into a complex number.
        remaining = max(0.0, 1.0 - float(iter) / max_iter)
        new_lr = init_lr * remaining ** power
        optimizer.param_groups[0]["lr"] = new_lr
        optimizer.param_groups[1]["lr"] = 10 * new_lr
class Trainer():
    """Train and validate a DeepLab model for ISIC lesion segmentation.

    Training and validation loops are selected at construction time from
    ``args.input_channels``, ``args.using_bb`` and ``args.store_result``.
    """

    def __init__(self, args, cuda=None):
        """Build the loss, model, optimizer, data loader and loop dispatch.

        :param args: run options; every setting used by this class is read
            from it.
        :param cuda: truthy to request CUDA; only used when CUDA is available.
        """
        self.args = args
        # bool() so that cuda=None yields False rather than None.
        self.cuda = bool(cuda) and torch.cuda.is_available()
        self.device = torch.device('cuda' if self.cuda else 'cpu')

        self.current_MIoU = 0
        self.best_MIou = 0
        self.current_epoch = 0
        self.current_iter = 0
        self.batch_idx = 0

        # set TensorboardX
        self.writer = SummaryWriter()

        # Metric definition
        self.Eval = Eval(self.args.num_classes)

        if self.args.loss == 'tanimoto':
            self.loss = tanimoto_loss()
        else:
            self.loss = nn.BCEWithLogitsLoss()
        self.loss.to(self.device)

        # model; ImageNet weights are skipped when a checkpoint will be loaded
        self.model = DeepLab(output_stride=self.args.output_stride,
                             class_num=self.args.num_classes,
                             num_input_channel=self.args.input_channels,
                             pretrained=self.args.imagenet_pretrained
                             and self.args.pretrained_ckpt_file is None,
                             bn_eps=self.args.bn_eps,
                             bn_momentum=self.args.bn_momentum,
                             freeze_bn=self.args.freeze_bn)

        # self.m always refers to the bare (unwrapped) network.
        if torch.cuda.device_count() > 1:
            self.model = nn.DataParallel(self.model)
            patch_replication_callback(self.model)
            self.m = self.model.module
        else:
            self.m = self.model
        self.model.to(self.device)

        self.optimizer = torch.optim.SGD(
            params=[
                {
                    "params": self.get_params(self.m, key="1x"),
                    "lr": self.args.lr,
                },
                {
                    "params": self.get_params(self.m, key="10x"),
                    "lr": 10 * self.args.lr,
                },
            ],
            momentum=self.args.momentum,
            weight_decay=self.args.weight_decay,
        )

        self.dataloader = ISICDataLoader(self.args)
        self.epoch_num = ceil(self.args.iter_max /
                              self.dataloader.train_iterations)

        if self.args.input_channels == 3:
            self.train_func = self.train_3ch
            if self.args.using_bb != 'none':
                if self.args.store_result:
                    self.validate_func = self.validate_crop_store_result
                else:
                    self.validate_func = self.validate_crop
            else:
                self.validate_func = self.validate_3ch
        else:
            self.train_func = self.train_4ch
            self.validate_func = self.validate_4ch

        # NOTE(review): the nesting of this override could not be recovered
        # from the collapsed source; it is applied unconditionally here.
        if self.args.store_result:
            self.validate_one_epoch = self.validate_one_epoch_store_result

    def main(self):
        """Log the options, optionally load a checkpoint, then validate or train."""
        logger.info("Global configuration as follows:")
        for key, val in vars(self.args).items():
            logger.info("{:16} {}".format(key, val))

        if self.cuda:
            current_device = torch.cuda.current_device()
            logger.info("This model will run on {}".format(
                torch.cuda.get_device_name(current_device)))
        else:
            logger.info("This model will run on CPU")

        if self.args.pretrained_ckpt_file is not None:
            self.load_checkpoint(self.args.pretrained_ckpt_file)

        if self.args.validate:
            self.validate()
        else:
            self.train()

        self.writer.close()

    def train(self):
        """Run all epochs, saving periodic, best and final checkpoints.

        Checkpoint file names are prefixed with ``train_id``, which is not
        defined in this class and must be provided at module level.
        """
        for epoch in tqdm(range(self.current_epoch, self.epoch_num),
                          desc="Total {} epochs".format(self.epoch_num)):
            self.current_epoch = epoch
            tqdm_epoch = tqdm(
                self.dataloader.train_loader,
                total=self.dataloader.train_iterations,
                desc="Train Epoch-{}-".format(self.current_epoch + 1))
            logger.info("Training one epoch...")
            self.Eval.reset()

            self.train_loss = []
            self.model.train()
            if self.args.freeze_bn:
                # Keep batchnorm statistics frozen even in training mode.
                for m in self.model.modules():
                    if isinstance(m, SynchronizedBatchNorm2d):
                        m.eval()

            self.train_func(tqdm_epoch)

            MIoU_single_img, MIoU_thresh = \
                self.Eval.Mean_Intersection_over_Union()
            logger.info('Epoch:{}, train MIoU1:{}'.format(
                self.current_epoch, MIoU_thresh))

            # No batch runs when resuming with the iteration budget spent.
            if self.train_loss:
                tr_loss = sum(self.train_loss) / len(self.train_loss)
                self.writer.add_scalar('train_loss', tr_loss,
                                       self.current_epoch)
                tqdm.write("The average loss of train epoch-{}-:{}".format(
                    self.current_epoch, tr_loss))
            tqdm_epoch.close()

            if self.current_epoch % 10 == 0:
                state = {
                    'state_dict': self.model.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_MIou': self.current_MIoU
                }
                torch.save(state,
                           train_id + '_epoca_' + str(self.current_epoch))

            # validate
            if self.args.validation:
                MIoU, MIoU_thresh = self.validate()
                self.writer.add_scalar('MIoU', MIoU_thresh,
                                       self.current_epoch)

                self.current_MIoU = MIoU_thresh
                is_best = MIoU_thresh > self.best_MIou
                if is_best:
                    self.best_MIou = MIoU_thresh
                self.save_checkpoint(is_best, train_id + 'best.pth')

        state = {
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'best_MIou': self.current_MIoU
        }
        logger.info("=>saving the final checkpoint...")
        torch.save(state, train_id + 'final.pth')

    def _advance_schedule(self):
        """Apply the poly LR for this iteration; return False at ``iter_max``."""
        self.poly_lr_scheduler(
            optimizer=self.optimizer,
            init_lr=self.args.lr,
            iter=self.current_iter,
            max_iter=self.args.iter_max,
            power=self.args.poly_power,
        )
        if self.current_iter >= self.args.iter_max:
            logger.info("iteration arrive {}!".format(self.args.iter_max))
            return False
        self.writer.add_scalar('learning_rate',
                               self.optimizer.param_groups[0]["lr"],
                               self.current_iter)
        self.writer.add_scalar('learning_rate_10x',
                               self.optimizer.param_groups[1]["lr"],
                               self.current_iter)
        return True

    def train_3ch(self, tqdm_epoch):
        """Train on ``(x, y)`` batches until the loader or budget is exhausted."""
        for x, y in tqdm_epoch:
            if not self._advance_schedule():
                break
            self.train_one_epoch(x, y)

    def train_4ch(self, tqdm_epoch):
        """Train on ``(x, y, target)`` batches, appending ``target`` to ``x``."""
        for x, y, target in tqdm_epoch:
            if not self._advance_schedule():
                break
            target = target.float()
            # The extra map becomes an additional input channel.
            x = torch.cat((x, target), dim=1)
            self.train_one_epoch(x, y)

    def train_one_epoch(self, x, y):
        """Run one optimisation step on a single batch.

        Despite the name this handles one batch, not a whole epoch.

        :raises ValueError: if the loss becomes NaN.
        """
        if self.cuda:
            x, y = x.to(self.device), y.to(device=self.device,
                                           dtype=torch.long)

        # Binarise the mask: every positive label becomes 1.
        y[y > 0] = 1.
        self.optimizer.zero_grad()

        pred = self.model(x)

        y = torch.squeeze(y, 1)
        if self.args.num_classes == 1:
            y = y.to(device=self.device, dtype=torch.float)
            pred = pred.squeeze()

        cur_loss = self.loss(pred, y)

        cur_loss.backward()
        self.optimizer.step()

        loss_value = cur_loss.item()
        self.train_loss.append(loss_value)

        if self.batch_idx % 50 == 0:
            logger.info("The train loss of epoch{}-batch-{}:{}".format(
                self.current_epoch, self.batch_idx, loss_value))
        self.batch_idx += 1
        self.current_iter += 1

        if np.isnan(float(loss_value)):
            raise ValueError('Loss is nan during training...')

    def validate(self):
        """Evaluate on the validation loader with the selected loop.

        :return: the pair returned by ``Eval.Mean_Intersection_over_Union()``.
        """
        logger.info('validating one epoch...')
        self.Eval.reset()
        self.iter = 0

        with torch.no_grad():
            tqdm_batch = tqdm(self.dataloader.valid_loader,
                              total=self.dataloader.valid_iterations,
                              desc="Val Epoch-{}-".format(
                                  self.current_epoch + 1))
            self.val_loss = []
            self.model.eval()
            self.validate_func(tqdm_batch)

            MIoU, MIoU_thresh = self.Eval.Mean_Intersection_over_Union()

            logger.info('validation MIoU1:{}'.format(MIoU))
            print('Miou: ' + str(MIoU) + ' MIoU_thresh: ' + str(MIoU_thresh))

            # The store-result validators record no losses, so the list can
            # be empty; skip the average instead of dividing by zero.
            if self.val_loss:
                v_loss = sum(self.val_loss) / len(self.val_loss)
                self.writer.add_scalar('val_loss', v_loss,
                                       self.current_epoch)

            tqdm_batch.close()

        return MIoU, MIoU_thresh

    def validate_3ch(self, tqdm_batch):
        """Validate on ``(x, y, w, h, name)`` batches."""
        for x, y, w, h, name in tqdm_batch:
            self.validate_one_epoch(x, y, w, h, name)

    def validate_4ch(self, tqdm_batch):
        """Validate on ``(x, y, target, w, h, name)`` batches."""
        for x, y, target, w, h, name in tqdm_batch:
            target = target.float()
            x = torch.cat((x, target), dim=1)
            self.validate_one_epoch(x, y, w, h, name)

    def validate_crop(self, tqdm_batch):
        """Validate on cropped batches carrying left/top/right/bottom boxes."""
        for x, y, left, top, right, bottom, w, h, name in tqdm_batch:
            self.validate_one_epoch(x, y, w, h, name, left, top, right,
                                    bottom)

    def _normalised_prediction(self, x):
        """Run the model and map its output to [0, 1] as a numpy array."""
        pred = self.model(x)
        if self.args.loss == 'tanimoto':
            pred = (pred - pred.min()) / (pred.max() - pred.min())
        else:
            pred = nn.Sigmoid()(pred)
        return pred.squeeze().data.cpu().numpy()

    def validate_crop_store_result(self, tqdm_batch):
        """Predict cropped batches, pad back to full size and save as PNG."""
        for x, y, left, top, right, bottom, w, h, name in tqdm_batch:
            if self.cuda:
                x, y = x.to(self.device), y.to(device=self.device,
                                               dtype=torch.long)

            pred = self._normalised_prediction(x)

            for i, single_argpred in enumerate(pred):
                pil = Image.fromarray(single_argpred)
                pil = pil.resize((right[i] - left[i], bottom[i] - top[i]))
                img = np.array(pil)
                # Pad the crop so it sits at its original position in a
                # w x h canvas.
                img_border = cv.copyMakeBorder(
                    img,
                    top[i].numpy(),
                    h[i].numpy() - bottom[i].numpy(),
                    left[i].numpy(),
                    w[i].numpy() - right[i].numpy(),
                    cv.BORDER_CONSTANT,
                    value=[0, 0, 0])

                if self.args.store_result:
                    img_border *= 255
                    pil = Image.fromarray(img_border.astype('uint8'))
                    pil.save(self.args.result_filepath +
                             'ISIC_{}.png'.format(name[i]))

                self.iter += 1

    def validate_one_epoch_store_result(self, x, y, w, h, name):
        """Predict one batch, resize to ``(w, h)`` and save as PNG."""
        if self.cuda:
            x, y = x.to(self.device), y.to(device=self.device,
                                           dtype=torch.long)

        pred = self._normalised_prediction(x)

        for i, single_argpred in enumerate(pred):
            pil = Image.fromarray(single_argpred)
            pil = pil.resize((w[i], h[i]))
            img_border = np.array(pil)

            if self.args.store_result:
                img_border *= 255
                pil = Image.fromarray(img_border.astype('uint8'))
                pil.save(self.args.result_filepath +
                         'ISIC_{}.png'.format(name[i]))

            self.iter += 1

    def validate_one_epoch(self, x, y, w, h, name, *ltrb):
        """Validate a single batch and score each image against ground truth.

        Despite the name this handles one batch, not a whole epoch.

        :param ltrb: optional ``left, top, right, bottom`` crop boxes; when
            given (3-channel crop mode) predictions are padded back to full
            size, otherwise they are resized to ``(w, h)``.
        :raises ValueError: if the loss becomes NaN.
        """
        if self.cuda:
            x, y = x.to(self.device), y.to(device=self.device,
                                           dtype=torch.long)

        pred = self.model(x)
        y = torch.squeeze(y, 1)
        if self.args.num_classes == 1:
            y = y.to(device=self.device, dtype=torch.float)
            pred = pred.squeeze()

        cur_loss = self.loss(pred, y)
        if np.isnan(float(cur_loss.item())):
            raise ValueError('Loss is nan during validating...')
        self.val_loss.append(cur_loss.item())

        pred = pred.data.cpu().numpy()

        pred[pred >= 0.5] = 1
        pred[pred < 0.5] = 0
        print('\n')

        # 'none' is a non-empty string and therefore truthy, so compare
        # explicitly; also require the boxes to have actually been passed.
        use_crop = (bool(ltrb) and self.args.using_bb != 'none'
                    and self.args.input_channels == 3)

        for i, single_pred in enumerate(pred):
            pil = Image.fromarray(single_pred.astype('uint8'))

            if use_crop:
                left, top, right, bottom = ltrb[0], ltrb[1], ltrb[2], ltrb[3]
                pil = pil.resize((right[i] - left[i], bottom[i] - top[i]))
                img = np.array(pil)
                img_border = cv.copyMakeBorder(
                    img,
                    top[i].numpy(),
                    h[i].numpy() - bottom[i].numpy(),
                    left[i].numpy(),
                    w[i].numpy() - right[i].numpy(),
                    cv.BORDER_CONSTANT,
                    value=[0, 0, 0])
            else:
                pil = pil.resize((w[i], h[i]))
                img_border = np.array(pil)

            # Close the ground-truth file as soon as it is converted.
            with Image.open(self.args.data_root_path + "ground_truth/ISIC_" +
                            name[i] + "_segmentation.png") as gt:
                ground_border = np.array(gt)
            ground_border[ground_border == 255] = 1

            iou = self.Eval.IoU_one_class(img_border, ground_border)
            print(name[i] + ' iou: ' + str(iou))

            if self.args.store_result:
                img_border[img_border == 1] = 255
                pil = Image.fromarray(img_border)
                pil.save(self.args.result_filepath +
                         'ISIC_{}.png'.format(name[i]))

            self.iter += 1

    def save_checkpoint(self, is_best, filename=None):
        """Write a checkpoint to ``args.checkpoint_dir`` when ``is_best``.

        :param is_best: only save when this is true; otherwise just log.
        :param filename: checkpoint file name inside ``args.checkpoint_dir``.
        """
        filename = os.path.join(self.args.checkpoint_dir, filename)
        state = {
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'best_MIou': self.best_MIou
        }
        if is_best:
            logger.info("=>saving a new best checkpoint...")
            torch.save(state, filename)
        else:
            logger.info("=> The MIoU of val does't improve.")

    def load_checkpoint(self, filename):
        """Restore weights (and, unless ``freeze_bn``, training state).

        A missing/unreadable file is logged and skipped.

        :param filename: path of the checkpoint file.
        """
        logger.info("Loading checkpoint '{}'".format(filename))
        try:
            # map_location lets a checkpoint saved on GPU load on CPU.
            checkpoint = torch.load(filename, map_location=self.device)
        except OSError:
            logger.info("No checkpoint exists from '{}'. Skipping...".format(
                self.args.checkpoint_dir))
            logger.info("**First time to train**")
            return

        # Drop any DataParallel 'module.' prefix and load into the bare
        # network, so the checkpoint loads whether or not either side was
        # wrapped in DataParallel.
        prefix = 'module.'
        state_dict = collections.OrderedDict(
            (k[len(prefix):] if k.startswith(prefix) else k, v)
            for k, v in checkpoint['state_dict'].items())
        if isinstance(self.model, nn.DataParallel):
            bare_model = self.model.module
        else:
            bare_model = self.model
        bare_model.load_state_dict(state_dict)

        if not self.args.freeze_bn:
            # Checkpoints written by this class carry no epoch/iteration,
            # so fall back to the current counters when they are absent.
            self.current_epoch = checkpoint.get('epoch', self.current_epoch)
            self.current_iter = checkpoint.get('iteration',
                                               self.current_iter)
            self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.best_MIou = checkpoint['best_MIou']

        print(
            "Checkpoint loaded successfully from '{}', MIoU:{})\n".format(
                self.args.checkpoint_dir, checkpoint['best_MIou']))
        logger.info(
            "Checkpoint loaded successfully from '{}', MIoU:{})\n".format(
                self.args.checkpoint_dir, checkpoint['best_MIou']))

    def get_params(self, model, key):
        """Yield Conv2d parameters for one optimizer group.

        :param model: module whose ``named_modules()`` are scanned.
        :param key: ``"1x"`` selects modules whose name contains
            ``Resnet101``; ``"10x"`` selects names containing ``encoder`` or
            ``decoder``. Any other key yields nothing.
        """
        if key == "1x":
            for name, module in model.named_modules():
                if "Resnet101" in name and isinstance(module, nn.Conv2d):
                    for p in module.parameters():
                        yield p
        if key == "10x":
            for name, module in model.named_modules():
                if ("encoder" in name or "decoder" in name) \
                        and isinstance(module, nn.Conv2d):
                    for p in module.parameters():
                        yield p

    def poly_lr_scheduler(self, optimizer, init_lr, iter, max_iter, power):
        """Set both param-group learning rates from a polynomial decay.

        ``lr = init_lr * (1 - iter / max_iter) ** power``; the second group
        gets ten times that value.

        :param optimizer: object exposing ``param_groups`` with two entries.
        :param init_lr: base learning rate at iteration 0.
        :param iter: current iteration.
        :param max_iter: iteration at which the rate reaches zero.
        :param power: exponent of the decay.
        """
        # Clamp at zero: past max_iter the base would be negative and a
        # fractional power would turn the learning rate into a complex number.
        remaining = max(0.0, 1.0 - float(iter) / max_iter)
        new_lr = init_lr * remaining ** power
        optimizer.param_groups[0]["lr"] = new_lr
        optimizer.param_groups[1]["lr"] = 10 * new_lr
def __init__(self, args, cuda=None):
    """Assemble everything the ISIC trainer needs before training starts.

    Creates the TensorBoard writer, metric accumulator, loss, DeepLab model
    (wrapped in DataParallel when more than one GPU is present), the
    two-group SGD optimizer and the data loader, then selects the training
    and validation loops from the options.

    :param args: run options; every setting used here is read from it.
    :param cuda: truthy to request CUDA; only used when CUDA is available.
    """
    self.args = args
    self.cuda = cuda and torch.cuda.is_available()
    device_name = 'cuda' if self.cuda else 'cpu'
    self.device = torch.device(device_name)

    # Progress counters.
    self.current_MIoU = 0
    self.best_MIou = 0
    self.current_epoch = 0
    self.current_iter = 0
    self.batch_idx = 0

    # TensorboardX writer and metric accumulator.
    self.writer = SummaryWriter()
    self.Eval = Eval(self.args.num_classes)

    # Loss: Tanimoto on request, BCE-with-logits otherwise.
    use_tanimoto = self.args.loss == 'tanimoto'
    self.loss = tanimoto_loss() if use_tanimoto else nn.BCEWithLogitsLoss()
    self.loss.to(self.device)

    # ImageNet weights are skipped when a checkpoint file is going to be
    # loaded afterwards.
    load_imagenet = (self.args.imagenet_pretrained
                     and self.args.pretrained_ckpt_file is None)
    self.model = DeepLab(output_stride=self.args.output_stride,
                         class_num=self.args.num_classes,
                         num_input_channel=self.args.input_channels,
                         pretrained=load_imagenet,
                         bn_eps=self.args.bn_eps,
                         bn_momentum=self.args.bn_momentum,
                         freeze_bn=self.args.freeze_bn)

    # self.m always refers to the unwrapped network.
    if torch.cuda.device_count() > 1:
        self.model = nn.DataParallel(self.model)
        patch_replication_callback(self.model)
        self.m = self.model.module
    else:
        self.m = self.model
    self.model.to(self.device)

    # The "10x" group trains at ten times the base learning rate.
    base_group = {
        "params": self.get_params(self.m, key="1x"),
        "lr": self.args.lr,
    }
    fast_group = {
        "params": self.get_params(self.m, key="10x"),
        "lr": 10 * self.args.lr,
    }
    self.optimizer = torch.optim.SGD(
        params=[base_group, fast_group],
        momentum=self.args.momentum,
        weight_decay=self.args.weight_decay,
    )

    self.dataloader = ISICDataLoader(self.args)
    self.epoch_num = ceil(self.args.iter_max /
                          self.dataloader.train_iterations)

    # Pick the loops that match the input layout.
    if self.args.input_channels != 3:
        self.train_func = self.train_4ch
        self.validate_func = self.validate_4ch
    else:
        self.train_func = self.train_3ch
        if args.using_bb == 'none':
            self.validate_func = self.validate_3ch
        elif self.args.store_result:
            self.validate_func = self.validate_crop_store_result
        else:
            self.validate_func = self.validate_crop

    # NOTE(review): the nesting of this override could not be recovered
    # from the collapsed source; it is applied unconditionally here.
    if self.args.store_result:
        self.validate_one_epoch = self.validate_one_epoch_store_result
def __init__(self, args, cuda=None):
    """Build the loss, model, optimizer and data loader for VOC training.

    :param args: run options. This method reads ``gpu``, ``run_name``,
        ``num_classes``, ``loss_weight_file``, ``loss_weights_dir``,
        ``output_stride``, ``imagenet_pretrained``, ``pretrained_ckpt_file``,
        ``bn_momentum``, ``freeze_bn``, ``lr``, ``momentum``,
        ``weight_decay`` and ``iter_max``.
    :param cuda: truthy to request CUDA; only used when CUDA is available.
    """
    self.args = args
    # Must be set before the first CUDA call so the restriction is honoured.
    os.environ["CUDA_VISIBLE_DEVICES"] = self.args.gpu
    # bool() so that cuda=None yields False rather than None.
    self.cuda = bool(cuda) and torch.cuda.is_available()
    self.device = torch.device('cuda' if self.cuda else 'cpu')

    self.current_MIoU = 0
    self.best_MIou = 0
    self.current_epoch = 0
    self.current_iter = 0

    # set TensorboardX
    self.writer = SummaryWriter(log_dir=self.args.run_name)

    # Metric definition
    self.Eval = Eval(self.args.num_classes)

    # loss definition
    if self.args.loss_weight_file is not None:
        classes_weights_path = os.path.join(self.args.loss_weights_dir,
                                            self.args.loss_weight_file)
        print(classes_weights_path)
        if not os.path.isfile(classes_weights_path):
            # The weights file is generated on first use.
            logger.info('calculating class weights...')
            calculate_weigths_labels(self.args)
        class_weights = np.load(classes_weights_path)
        pprint.pprint(class_weights)
        weight = torch.from_numpy(class_weights.astype(np.float32))
        logger.info('loading class weights successfully!')
    else:
        weight = None

    self.loss = nn.CrossEntropyLoss(weight=weight, ignore_index=255)
    self.loss.to(self.device)

    # model; ImageNet weights are skipped when a checkpoint will be loaded
    self.model = DeepLab(output_stride=self.args.output_stride,
                         class_num=self.args.num_classes,
                         pretrained=self.args.imagenet_pretrained
                         and self.args.pretrained_ckpt_file is None,
                         bn_momentum=self.args.bn_momentum,
                         freeze_bn=self.args.freeze_bn)
    # Count the comma-separated GPU ids. Deriving the count from the string
    # length (ceil(len/2)) miscounts multi-digit ids such as "10,11".
    gpu_ids = [g for g in str(self.args.gpu).split(',') if g.strip()]
    self.model = nn.DataParallel(self.model,
                                 device_ids=list(range(len(gpu_ids))))
    patch_replication_callback(self.model)
    self.model.to(self.device)

    # Two parameter groups: the "10x" group trains at ten times the base LR.
    self.optimizer = torch.optim.SGD(
        params=[
            {
                "params": self.get_params(self.model.module, key="1x"),
                "lr": self.args.lr,
            },
            {
                "params": self.get_params(self.model.module, key="10x"),
                "lr": 10 * self.args.lr,
            },
        ],
        momentum=self.args.momentum,
        weight_decay=self.args.weight_decay,
    )

    # dataloader
    self.dataloader = VOCDataLoader(self.args)
    # Enough epochs to cover iter_max iterations.
    self.epoch_num = ceil(self.args.iter_max /
                          self.dataloader.train_iterations)