def eval(model, filenames, num_classes):
    model.eval()
    # recover h and w
    factor = config.problem_factor[params.problem_type]
    original_height, original_width = config.original_height, config.original_width
    h_start, w_start = config.h_start, config.w_start

    transform = Compose([
        Resize(height=params.train_height, width=params.train_width, p=1),
        Normalize(p=1)
    ], p=1)

    dataloader = DataLoader(
        dataset=RobotSegDataset(filenames, transform=transform, mode='eval',
                                problem_type=params.problem_type),
        shuffle=False,
        num_workers=params.num_workers,
        batch_size=params.batch_size,
        pin_memory=True)

    with torch.no_grad():
        # init progress bar for each epoch
        tq = tqdm.tqdm(total=len(dataloader.dataset))
        tq.set_description("Predict [{}]".format(params.model.__name__))

        for batch_num, (filenames, inputs) in enumerate(dataloader):
            # no grad for targets
            inputs = inputs.cuda(non_blocking=True)
            outputs = model(inputs)

            for i, filename in enumerate(filenames):
                # binary
                if num_classes == 2:
                    t_mask = ((outputs[i, 0] > 0).data.cpu().numpy() * factor).astype(np.uint8)
                    # t_mask = (torch.sigmoid(outputs[i, 0]).data.cpu().numpy() * factor).astype(np.uint8)
                else:
                    t_mask = (outputs[i].data.cpu().numpy().argmax(axis=0) * factor).astype(np.uint8)

                t_mask = cv2.resize(t_mask,
                                    dsize=(config.cropped_width, config.cropped_height),
                                    interpolation=cv2.INTER_AREA)
                # generate mask
                h, w = t_mask.shape
                # recover to original shape
                full_mask = np.zeros((original_height, original_width))
                full_mask[h_start:h_start + h, w_start:w_start + w] = t_mask
                # not recover
                # full_mask = t_mask[0]

                prediction_folder = Path(filenames[i]).parent.parent / \
                    'prediction' / params.problem_type
                prediction_folder.mkdir(exist_ok=True, parents=True)
                cv2.imwrite(str(prediction_folder / (Path(filenames[i]).stem + '.png')),
                            full_mask)

            tq.update(params.batch_size)
        tq.close()
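#### illustrative sketch (not part of the original pipeline) ####
# The crop-recovery logic used in eval() above, isolated into a helper for clarity.
# `recover_full_mask` is a hypothetical name and the geometry arguments stand in for
# the config values; this is a minimal sketch of the idea, not the project's API.
import cv2
import numpy as np

def recover_full_mask(t_mask, cropped_hw, original_hw, start_hw):
    """Resize a predicted mask to the cropped size and paste it into the full frame."""
    cropped_h, cropped_w = cropped_hw
    original_h, original_w = original_hw
    h_start, w_start = start_hw
    # resize the prediction back to the cropped region size (cv2 dsize is (w, h))
    t_mask = cv2.resize(t_mask, dsize=(cropped_w, cropped_h),
                        interpolation=cv2.INTER_AREA)
    # paste the cropped prediction into a zero canvas at the original frame size
    full_mask = np.zeros((original_h, original_w), dtype=np.uint8)
    full_mask[h_start:h_start + cropped_h, w_start:w_start + cropped_w] = t_mask
    return full_mask

# usage (placeholder numbers, not the project's actual geometry):
# full = recover_full_mask(t_mask, (1024, 1280), (1080, 1920), (28, 320))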
def train_fold(fold, args): # loggers logging_logger = args.logging_logger if args.tb_log: tb_logger = args.tb_logger num_classes = utils.problem_class[args.problem_type] # init model model = eval(args.model)(in_channels=3, num_classes=num_classes, bn=False) model = nn.DataParallel(model, device_ids=args.device_ids).cuda() # transform for train/valid data train_transform, valid_transform = get_transform(args.model) # loss function loss_func = LossMulti(num_classes, args.jaccard_weight) if args.semi: loss_func_semi = LossMultiSemi(num_classes, args.jaccard_weight, args.semi_loss_alpha, args.semi_method) # train/valid filenames train_filenames, valid_filenames = utils.trainval_split( args.train_dir, fold) ckpt_dir = Path(args.ckpt_dir) ckpt_filename = ckpt_dir.glob('fold_%d_model_[0-9]*.pth' % fold)[0] res = re.match(r'fold_%d_model_(\d+).pth' % fold, ckpt_filename) # restore epoch engine.state.epoch = int(res.groups()[0]) # load model state dict model.load_state_dict(torch.load(str(ckpt_filename))) logging_logger.info('restore model [{}] from epoch {}.'.format( args.model, engine.state.epoch)) # DataLoader and Dataset args # train_shuffle = True # train_ds_kwargs = { # 'filenames': train_filenames, # 'problem_type': args.problem_type, # 'transform': train_transform, # 'model': args.model, # 'mode': 'train', # 'semi': args.semi, # } valid_num_workers = args.num_workers valid_batch_size = args.batch_size # if 'TAPNet' in args.model: # # for TAPNet, cancel default shuffle, use self-defined shuffle in torch.Dataset instead # train_shuffle = False # train_ds_kwargs['batch_size'] = args.batch_size # train_ds_kwargs['mf'] = args.mf # if args.semi == True: # train_ds_kwargs['semi_method'] = args.semi_method # train_ds_kwargs['semi_percentage'] = args.semi_percentage # additional valid dataset kws valid_ds_kwargs = { 'filenames': valid_filenames, 'problem_type': args.problem_type, 'transform': valid_transform, 'model': args.model, 'mode': 'valid', } if 'TAPNet' in args.model: # in validation, num_workers should be set to 0 for sequences valid_num_workers = 0 # in validation, batch_size should be set to 1 for sequences valid_batch_size = 1 valid_ds_kwargs['mf'] = args.mf # # train dataloader # train_loader = DataLoader( # dataset=RobotSegDataset(**train_ds_kwargs), # shuffle=train_shuffle, # set to False to disable pytorch dataset shuffle # num_workers=args.num_workers, # batch_size=args.batch_size, # pin_memory=True # ) # valid dataloader valid_loader = DataLoader( dataset=RobotSegDataset(**valid_ds_kwargs), shuffle=False, # in validation, no need to shuffle num_workers=valid_num_workers, batch_size= valid_batch_size, # in valid time. 
have to use one image by one pin_memory=True) # optimizer optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) # optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, # weight_decay=args.weight_decay, nesterov=True) # # ignite trainer process function # def train_step(engine, batch): # # set model to train # model.train() # # clear gradients # optimizer.zero_grad() # # additional params to feed into model # add_params = {} # inputs = batch['input'].cuda(non_blocking=True) # with torch.no_grad(): # targets = batch['target'].cuda(non_blocking=True) # # for TAPNet, add attention maps # if 'TAPNet' in args.model: # add_params['attmap'] = batch['attmap'].cuda(non_blocking=True) # outputs = model(inputs, **add_params) # loss_kwargs = {} # if args.semi: # loss_kwargs['labeled'] = batch['labeled'] # if args.semi_method == 'rev_flow': # loss_kwargs['optflow'] = batch['optflow'] # loss = loss_func_semi(outputs, targets, **loss_kwargs) # else: # loss = loss_func(outputs, targets, **loss_kwargs) # loss.backward() # optimizer.step() # return_dict = { # 'output': outputs, # 'target': targets, # 'loss_kwargs': loss_kwargs, # 'loss': loss.item(), # } # # for TAPNet, update attention maps after each iteration # if 'TAPNet' in args.model: # # output_classes and target_classes: <b, h, w> # output_softmax_np = torch.softmax(outputs, dim=1).detach().cpu().numpy() # # update attention maps # train_loader.dataset.update_attmaps(output_softmax_np, batch['abs_idx'].numpy()) # return_dict['attmap'] = add_params['attmap'] # return return_dict # # init trainer # trainer = engine.Engine(train_step) # # lr scheduler and handler # # cyc_scheduler = optim.lr_scheduler.CyclicLR(optimizer, args.lr / 100, args.lr) # # lr_scheduler = c_handlers.param_scheduler.LRScheduler(cyc_scheduler) # # trainer.add_event_handler(engine.Events.ITERATION_COMPLETED, lr_scheduler) # step_scheduler = optim.lr_scheduler.StepLR(optimizer, # step_size=args.lr_decay_epochs, gamma=args.lr_decay) # lr_scheduler = c_handlers.param_scheduler.LRScheduler(step_scheduler) # trainer.add_event_handler(engine.Events.EPOCH_STARTED, lr_scheduler) # @trainer.on(engine.Events.STARTED) # def trainer_start_callback(engine): # logging_logger.info('training fold {}, {} train / {} valid files'. 
\ # format(fold, len(train_filenames), len(valid_filenames))) # # resume training # if args.resume: # # ckpt for current fold fold_<fold>_model_<epoch>.pth # ckpt_dir = Path(args.ckpt_dir) # ckpt_filename = ckpt_dir.glob('fold_%d_model_[0-9]*.pth' % fold)[0] # res = re.match(r'fold_%d_model_(\d+).pth' % fold, ckpt_filename) # # restore epoch # engine.state.epoch = int(res.groups()[0]) # # load model state dict # model.load_state_dict(torch.load(str(ckpt_filename))) # logging_logger.info('restore model [{}] from epoch {}.'.format(args.model, engine.state.epoch)) # else: # logging_logger.info('train model [{}] from scratch'.format(args.model)) # # record metrics history every epoch # engine.state.metrics_records = {} # @trainer.on(engine.Events.EPOCH_STARTED) # def trainer_epoch_start_callback(engine): # # log learning rate on pbar # train_pbar.log_message('model: %s, problem type: %s, fold: %d, lr: %.5f, batch size: %d' % \ # (args.model, args.problem_type, fold, lr_scheduler.get_param(), args.batch_size)) # # for TAPNet, change dataset schedule to random after the first epoch # if 'TAPNet' in args.model and engine.state.epoch > 1: # train_loader.dataset.set_dataset_schedule("shuffle") # @trainer.on(engine.Events.ITERATION_COMPLETED) # def trainer_iter_comp_callback(engine): # # logging_logger.info(engine.state.metrics) # pass # # monitor loss # # running average loss # train_ra_loss = imetrics.RunningAverage(output_transform= # lambda x: x['loss'], alpha=0.98) # train_ra_loss.attach(trainer, 'train_ra_loss') # # monitor train loss over epoch # if args.semi: # train_loss = imetrics.Loss(loss_func_semi, output_transform=lambda x: (x['output'], x['target'], x['loss_kwargs'])) # else: # train_loss = imetrics.Loss(loss_func, output_transform=lambda x: (x['output'], x['target'])) # train_loss.attach(trainer, 'train_loss') # # progress bar # train_pbar = c_handlers.ProgressBar(persist=True, dynamic_ncols=True) # train_metric_names = ['train_ra_loss'] # train_pbar.attach(trainer, metric_names=train_metric_names) # # tensorboardX: log train info # if args.tb_log: # tb_logger.attach(trainer, log_handler=OptimizerParamsHandler(optimizer, 'lr'), # event_name=engine.Events.EPOCH_STARTED) # tb_logger.attach(trainer, log_handler=OutputHandler('train_iter', train_metric_names), # event_name=engine.Events.ITERATION_COMPLETED) # tb_logger.attach(trainer, log_handler=OutputHandler('train_epoch', ['train_loss']), # event_name=engine.Events.EPOCH_COMPLETED) # tb_logger.attach(trainer, # log_handler=WeightsScalarHandler(model, reduction=torch.norm), # event_name=engine.Events.ITERATION_COMPLETED) # tb_logger.attach(trainer, log_handler=tb_log_train_vars, # event_name=engine.Events.ITERATION_COMPLETED) # ignite validator process function def valid_step(engine, batch): with torch.no_grad(): model.eval() inputs = batch['input'].cuda(non_blocking=True) targets = batch['target'].cuda(non_blocking=True) # additional arguments add_params = {} # for TAPNet, add attention maps if 'TAPNet' in args.model: add_params['attmap'] = batch['attmap'].cuda(non_blocking=True) # output logits outputs = model(inputs, **add_params) # loss loss = loss_func(outputs, targets) output_softmaxs = torch.softmax(outputs, dim=1) output_argmaxs = output_softmaxs.argmax(dim=1) # output_classes and target_classes: <b, h, w> output_classes = output_argmaxs.cpu().numpy() target_classes = targets.cpu().numpy() # record current batch metrics iou_mRecords = MetricRecord() dice_mRecords = MetricRecord() cm_b = np.zeros((num_classes, num_classes), 
dtype=np.uint32) for output_class, target_class in zip(output_classes, target_classes): # calculate metrics for each frame # calculate using confusion matrix or dirctly using definition cm = calculate_confusion_matrix_from_arrays( output_class, target_class, num_classes) iou_mRecords.update_record(calculate_iou(cm)) dice_mRecords.update_record(calculate_dice(cm)) cm_b += cm ######## calculate directly using definition ########## # iou_mRecords.update_record(iou_multi_np(target_class, output_class)) # dice_mRecords.update_record(dice_multi_np(target_class, output_class)) # accumulate batch metrics to engine state engine.state.epoch_metrics['confusion_matrix'] += cm_b engine.state.epoch_metrics['iou'].merge(iou_mRecords) engine.state.epoch_metrics['dice'].merge(dice_mRecords) return_dict = { 'loss': loss.item(), 'output': outputs, 'output_argmax': output_argmaxs, 'target': targets, # for monitoring 'iou': iou_mRecords, 'dice': dice_mRecords, } if 'TAPNet' in args.model: # for TAPNet, update attention maps after each iteration valid_loader.dataset.update_attmaps( output_softmaxs.cpu().numpy(), batch['abs_idx'].numpy()) # for TAPNet, return extra internal values return_dict['attmap'] = add_params['attmap'] # TODO: for TAPNet, return internal self-learned attention maps return return_dict
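#### hedged sketches of the metric helpers referenced in valid_step above ####
# calculate_confusion_matrix_from_arrays, calculate_iou and calculate_dice are not
# defined in this file. The versions below show one plausible implementation that is
# consistent with how they are called (per-class dicts, since the code later takes
# list(ious.values())); the project's actual definitions may differ.
import numpy as np

def confusion_matrix_sketch(prediction, ground_truth, num_classes):
    # 2-D histogram over (gt, pred) label pairs: rows = ground truth, cols = prediction
    pairs = np.vstack((ground_truth.flatten(), prediction.flatten())).T
    cm, _ = np.histogramdd(pairs, bins=(num_classes, num_classes),
                           range=[(0, num_classes), (0, num_classes)])
    return cm.astype(np.uint32)

def iou_from_cm_sketch(cm):
    ious = {}
    for k in range(cm.shape[0]):
        tp = cm[k, k]
        fp = cm[:, k].sum() - tp   # predicted class k, labelled otherwise
        fn = cm[k, :].sum() - tp   # labelled class k, predicted otherwise
        denom = tp + fp + fn
        ious[k] = float(tp) / denom if denom > 0 else 0.0
    return ious

def dice_from_cm_sketch(cm):
    dices = {}
    for k in range(cm.shape[0]):
        tp = cm[k, k]
        fp = cm[:, k].sum() - tp
        fn = cm[k, :].sum() - tp
        denom = 2 * tp + fp + fn
        dices[k] = 2.0 * float(tp) / denom if denom > 0 else 0.0
    return dices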
def process_fold(fold, args): num_classes = utils.problem_class[args.problem_type] factor = utils.problem_factor[args.problem_type] # inputs are RGB images (3 * h * w) # outputs are 2d multilabel segmentation maps (h * w) model = eval(args.model)(in_channels=3, num_classes=num_classes) # data parallel for multi-GPU model = nn.DataParallel(model, device_ids=args.device_ids).cuda() ckpt_dir = Path(args.ckpt_dir) #p = pathlib.Path(ckpt_dir) # ckpt for this fold fold_<fold>_model_<epoch>.pth print("ckpt_dir--> ", ckpt_dir) filenames = glob.glob(args.ckpt_dir + 'fold_%d_model_[0-99]*.pth' % fold) #filenames = glob.glob(args.ckpt_dir+'fold_%d_model_[0-99]*.pth') #filenames = ckpt_dir.glob(args.ckpt_dir+'fold_%d_model_[0-9]*.pth'%fold) print("Filename--> ", filenames) # if len(filenames) != 1: # raise ValueError('invalid model ckpt name. correct ckpt name should be \ # fold_<fold>_model_<epoch>.pth') ckpt_filename = filenames[0] # load state dict model.load_state_dict(torch.load(str(ckpt_filename))) logging.info('Restored model [{}] fold {}.'.format(args.model, fold)) # segmentation mask save directory mask_save_dir = Path(args.mask_save_dir) / ckpt_dir.name mask_save_dir.mkdir(exist_ok=True, parents=True) #print("mask_save_dir", mask_save_dir) eval_transform = Compose( [ Normalize(p=1), PadIfNeeded( min_height=args.input_height, min_width=args.input_width, p=1), # optional Resize(height=args.input_height, width=args.input_width, p=1), # CenterCrop(height=args.input_height, width=args.input_width, p=1) ], p=1) # train/valid filenames, # we evaluate and generate masks on validation set _, eval_filenames = utils.trainval_split(args.train_dir, fold) eval_num_workers = args.num_workers eval_batch_size = args.batch_size # additional ds args if 'TAPNet' in args.model: # in eval, num_workers should be set to 0 for sequences eval_num_workers = 0 # in eval, batch_size should be set to 1 for sequences eval_batch_size = 1 # additional eval dataset kws eval_ds_kwargs = { 'filenames': eval_filenames, 'problem_type': args.problem_type, 'transform': eval_transform, 'model': args.model, 'mode': 'eval', } # valid dataloader eval_loader = DataLoader( dataset=RobotSegDataset(**eval_ds_kwargs), shuffle=False, # in eval, no need to shuffle num_workers=eval_num_workers, batch_size= eval_batch_size, # in valid time. have to use one image by one pin_memory=True) # process function for ignite engine def eval_step(engine, batch): with torch.no_grad(): model.eval() #print("batch Keys-->", batch.keys()) inputs = batch['input'].cuda(non_blocking=True) #targets = batch['target'].cuda(non_blocking=True) # additional arguments add_params = {} # for TAPNet, add attention maps if 'TAPNet' in args.model: add_params['attmap'] = batch['attmap'].cuda(non_blocking=True) outputs = model(inputs, **add_params) output_logsoftmax_np = torch.softmax(outputs, dim=1).cpu().numpy() # output_classes and target_classes: <b, h, w> output_classes = output_logsoftmax_np.argmax(axis=1) masks = (output_classes * factor).astype(np.uint8) #print(size(masks)) return_dict = { 'input_filename': batch['input_filename'], 'mask': masks } if 'TAPNet' in args.model: # for TAPNet, update attention maps after each iteration eval_loader.dataset.update_attmaps(output_logsoftmax_np, batch['idx'].numpy()) # for TAPNet, return extra internal values return_dict['attmap'] = add_params['attmap'] return return_dict
def train_fold(fold, args): # loggers logging_logger = args.logging_logger if args.tb_log: tb_logger = args.tb_logger num_classes = utils.problem_class[args.problem_type] # init model model = eval(args.model)(in_channels=3, num_classes=num_classes, bn=False) model = nn.DataParallel(model, device_ids=args.device_ids).cuda() # transform for train/valid data train_transform, valid_transform = get_transform(args.model) # loss function loss_func = LossMulti(num_classes, args.jaccard_weight) if args.semi: loss_func_semi = LossMultiSemi(num_classes, args.jaccard_weight, args.semi_loss_alpha, args.semi_method) # train/valid filenames train_filenames, valid_filenames = utils.trainval_split(args.train_dir, fold) # DataLoader and Dataset args train_shuffle = True train_ds_kwargs = { 'filenames': train_filenames, 'problem_type': args.problem_type, 'transform': train_transform, 'model': args.model, 'mode': 'train', 'semi': args.semi, } valid_num_workers = args.num_workers valid_batch_size = args.batch_size if 'TAPNet' in args.model: # for TAPNet, cancel default shuffle, use self-defined shuffle in torch.Dataset instead train_shuffle = False train_ds_kwargs['batch_size'] = args.batch_size train_ds_kwargs['mf'] = args.mf if args.semi == True: train_ds_kwargs['semi_method'] = args.semi_method train_ds_kwargs['semi_percentage'] = args.semi_percentage # additional valid dataset kws valid_ds_kwargs = { 'filenames': valid_filenames, 'problem_type': args.problem_type, 'transform': valid_transform, 'model': args.model, 'mode': 'valid', } if 'TAPNet' in args.model: # in validation, num_workers should be set to 0 for sequences valid_num_workers = 0 # in validation, batch_size should be set to 1 for sequences valid_batch_size = 1 valid_ds_kwargs['mf'] = args.mf # train dataloader train_loader = DataLoader( dataset=RobotSegDataset(**train_ds_kwargs), shuffle=train_shuffle, # set to False to disable pytorch dataset shuffle num_workers=args.num_workers, batch_size=args.batch_size, pin_memory=True ) # valid dataloader valid_loader = DataLoader( dataset=RobotSegDataset(**valid_ds_kwargs), shuffle=False, # in validation, no need to shuffle num_workers=valid_num_workers, batch_size=valid_batch_size, # in valid time. 
have to use one image by one pin_memory=True ) # optimizer optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) # optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, # weight_decay=args.weight_decay, nesterov=True) # ignite trainer process function def train_step(engine, batch): # set model to train model.train() # clear gradients optimizer.zero_grad() # additional params to feed into model add_params = {} inputs = batch['input'].cuda(non_blocking=True) with torch.no_grad(): targets = batch['target'].cuda(non_blocking=True) # for TAPNet, add attention maps if 'TAPNet' in args.model: add_params['attmap'] = batch['attmap'].cuda(non_blocking=True) outputs = model(inputs, **add_params) loss_kwargs = {} if args.semi: loss_kwargs['labeled'] = batch['labeled'] if args.semi_method == 'rev_flow': loss_kwargs['optflow'] = batch['optflow'] loss = loss_func_semi(outputs, targets, **loss_kwargs) else: loss = loss_func(outputs, targets, **loss_kwargs) loss.backward() optimizer.step() return_dict = { 'output': outputs, 'target': targets, 'loss_kwargs': loss_kwargs, 'loss': loss.item(), } # for TAPNet, update attention maps after each iteration if 'TAPNet' in args.model: # output_classes and target_classes: <b, h, w> output_softmax_np = torch.softmax(outputs, dim=1).detach().cpu().numpy() # update attention maps train_loader.dataset.update_attmaps(output_softmax_np, batch['abs_idx'].numpy()) return_dict['attmap'] = add_params['attmap'] return return_dict # init trainer trainer = engine.Engine(train_step) # lr scheduler and handler # cyc_scheduler = optim.lr_scheduler.CyclicLR(optimizer, args.lr / 100, args.lr) # lr_scheduler = c_handlers.param_scheduler.LRScheduler(cyc_scheduler) # trainer.add_event_handler(engine.Events.ITERATION_COMPLETED, lr_scheduler) step_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_decay_epochs, gamma=args.lr_decay) lr_scheduler = c_handlers.param_scheduler.LRScheduler(step_scheduler) trainer.add_event_handler(engine.Events.EPOCH_STARTED, lr_scheduler) @trainer.on(engine.Events.STARTED) def trainer_start_callback(engine): logging_logger.info('training fold {}, {} train / {} valid files'. 
            format(fold, len(train_filenames), len(valid_filenames)))
        # resume training
        if args.resume:
            # ckpt for current fold: fold_<fold>_model_<epoch>.pth
            ckpt_dir = Path(args.ckpt_dir)
            # Path.glob returns a generator, so materialize it before indexing
            ckpt_filename = sorted(ckpt_dir.glob('fold_%d_model_[0-9]*.pth' % fold))[0]
            # match against the file name only, not the full path
            res = re.match(r'fold_%d_model_(\d+)\.pth' % fold, ckpt_filename.name)
            # restore epoch
            engine.state.epoch = int(res.groups()[0])
            # load model state dict
            model.load_state_dict(torch.load(str(ckpt_filename)))
            logging_logger.info('restore model [{}] from epoch {}.'.format(
                args.model, engine.state.epoch))
        else:
            logging_logger.info('train model [{}] from scratch'.format(args.model))
        # record metrics history every epoch
        engine.state.metrics_records = {}

    @trainer.on(engine.Events.EPOCH_STARTED)
    def trainer_epoch_start_callback(engine):
        # log learning rate on pbar
        train_pbar.log_message(
            'model: %s, problem type: %s, fold: %d, lr: %.5f, batch size: %d' %
            (args.model, args.problem_type, fold, lr_scheduler.get_param(), args.batch_size))
        # for TAPNet, change dataset schedule to random after the first epoch
        if 'TAPNet' in args.model and engine.state.epoch > 1:
            train_loader.dataset.set_dataset_schedule("shuffle")

    @trainer.on(engine.Events.ITERATION_COMPLETED)
    def trainer_iter_comp_callback(engine):
        # logging_logger.info(engine.state.metrics)
        pass

    # monitor loss: running average loss
    train_ra_loss = imetrics.RunningAverage(
        output_transform=lambda x: x['loss'], alpha=0.98)
    train_ra_loss.attach(trainer, 'train_ra_loss')

    # monitor train loss over epoch
    if args.semi:
        train_loss = imetrics.Loss(
            loss_func_semi,
            output_transform=lambda x: (x['output'], x['target'], x['loss_kwargs']))
    else:
        train_loss = imetrics.Loss(
            loss_func, output_transform=lambda x: (x['output'], x['target']))
    train_loss.attach(trainer, 'train_loss')

    # progress bar
    train_pbar = c_handlers.ProgressBar(persist=True, dynamic_ncols=True)
    train_metric_names = ['train_ra_loss']
    train_pbar.attach(trainer, metric_names=train_metric_names)

    # tensorboardX: log train info
    if args.tb_log:
        tb_logger.attach(trainer, log_handler=OptimizerParamsHandler(optimizer, 'lr'),
                         event_name=engine.Events.EPOCH_STARTED)
        tb_logger.attach(trainer, log_handler=OutputHandler('train_iter', train_metric_names),
                         event_name=engine.Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer, log_handler=OutputHandler('train_epoch', ['train_loss']),
                         event_name=engine.Events.EPOCH_COMPLETED)
        tb_logger.attach(trainer,
                         log_handler=WeightsScalarHandler(model, reduction=torch.norm),
                         event_name=engine.Events.ITERATION_COMPLETED)
        # tb_logger.attach(trainer, log_handler=tb_log_train_vars,
        #                  event_name=engine.Events.ITERATION_COMPLETED)

    # ignite validator process function
    def valid_step(engine, batch):
        with torch.no_grad():
            model.eval()
            inputs = batch['input'].cuda(non_blocking=True)
            targets = batch['target'].cuda(non_blocking=True)
            # additional arguments
            add_params = {}
            # for TAPNet, add attention maps
            if 'TAPNet' in args.model:
                add_params['attmap'] = batch['attmap'].cuda(non_blocking=True)
            # output logits
            outputs = model(inputs, **add_params)
            # loss
            loss = loss_func(outputs, targets)

            output_softmaxs = torch.softmax(outputs, dim=1)
            output_argmaxs = output_softmaxs.argmax(dim=1)
            # output_classes and target_classes: <b, h, w>
            output_classes = output_argmaxs.cpu().numpy()
            target_classes = targets.cpu().numpy()

            # record current batch metrics
            iou_mRecords = MetricRecord()
            dice_mRecords = MetricRecord()
            cm_b = np.zeros((num_classes, num_classes), dtype=np.uint32)

            for output_class, target_class in zip(output_classes, target_classes):
                # calculate metrics for
each frame # calculate using confusion matrix or dirctly using definition cm = calculate_confusion_matrix_from_arrays(output_class, target_class, num_classes) iou_mRecords.update_record(calculate_iou(cm)) dice_mRecords.update_record(calculate_dice(cm)) cm_b += cm ######## calculate directly using definition ########## # iou_mRecords.update_record(iou_multi_np(target_class, output_class)) # dice_mRecords.update_record(dice_multi_np(target_class, output_class)) # accumulate batch metrics to engine state engine.state.epoch_metrics['confusion_matrix'] += cm_b engine.state.epoch_metrics['iou'].merge(iou_mRecords) engine.state.epoch_metrics['dice'].merge(dice_mRecords) return_dict = { 'loss': loss.item(), 'output': outputs, 'output_argmax': output_argmaxs, 'target': targets, # for monitoring 'iou': iou_mRecords, 'dice': dice_mRecords, } if 'TAPNet' in args.model: # for TAPNet, update attention maps after each iteration valid_loader.dataset.update_attmaps(output_softmaxs.cpu().numpy(), batch['abs_idx'].numpy()) # for TAPNet, return extra internal values return_dict['attmap'] = add_params['attmap'] # TODO: for TAPNet, return internal self-learned attention maps return return_dict # validator engine validator = engine.Engine(valid_step) # monitor loss valid_ra_loss = imetrics.RunningAverage(output_transform= lambda x: x['loss'], alpha=0.98) valid_ra_loss.attach(validator, 'valid_ra_loss') # monitor validation loss over epoch valid_loss = imetrics.Loss(loss_func, output_transform=lambda x: (x['output'], x['target'])) valid_loss.attach(validator, 'valid_loss') # monitor <data> mean metrics valid_data_miou = imetrics.RunningAverage(output_transform= lambda x: x['iou'].data_mean()['mean'], alpha=0.98) valid_data_miou.attach(validator, 'mIoU') valid_data_mdice = imetrics.RunningAverage(output_transform= lambda x: x['dice'].data_mean()['mean'], alpha=0.98) valid_data_mdice.attach(validator, 'mDice') # show metrics on progress bar (after every iteration) valid_pbar = c_handlers.ProgressBar(persist=True, dynamic_ncols=True) valid_metric_names = ['valid_ra_loss', 'mIoU', 'mDice'] valid_pbar.attach(validator, metric_names=valid_metric_names) # ## monitor ignite IoU (the same as iou we are using) ### # cm = imetrics.ConfusionMatrix(num_classes, # output_transform=lambda x: (x['output'], x['target'])) # imetrics.IoU(cm, # ignore_index=0 # ).attach(validator, 'iou') # # monitor ignite mean iou (over all classes even not exist in gt) # mean_iou = imetrics.mIoU(cm, # ignore_index=0 # ).attach(validator, 'mean_iou') @validator.on(engine.Events.STARTED) def validator_start_callback(engine): pass @validator.on(engine.Events.EPOCH_STARTED) def validator_epoch_start_callback(engine): engine.state.epoch_metrics = { # directly use definition to calculate 'iou': MetricRecord(), 'dice': MetricRecord(), 'confusion_matrix': np.zeros((num_classes, num_classes), dtype=np.uint32), } # evaluate after iter finish @validator.on(engine.Events.ITERATION_COMPLETED) def validator_iter_comp_callback(engine): pass # evaluate after epoch finish @validator.on(engine.Events.EPOCH_COMPLETED) def validator_epoch_comp_callback(engine): # log ignite metrics # logging_logger.info(engine.state.metrics) # ious = engine.state.metrics['iou'] # msg = 'IoU: ' # for ins_id, iou in enumerate(ious): # msg += '{:d}: {:.3f}, '.format(ins_id + 1, iou) # logging_logger.info(msg) # logging_logger.info('nonzero mean IoU for all data: {:.3f}'.format(ious[ious > 0].mean())) # log monitored epoch metrics epoch_metrics = engine.state.epoch_metrics ######### 
        ### NOTICE: Two metrics are available but different ##########
        ### 1. mean metrics for all data calculated by confusion matrix ####
        '''
        compared with using confusion_matrix[1:, 1:] in original code,
        we use the full confusion matrix and only present non-background result
        '''
        confusion_matrix = epoch_metrics['confusion_matrix']  # [1:, 1:]
        ious = calculate_iou(confusion_matrix)
        dices = calculate_dice(confusion_matrix)
        mean_ious = np.mean(list(ious.values()))
        mean_dices = np.mean(list(dices.values()))
        std_ious = np.std(list(ious.values()))
        std_dices = np.std(list(dices.values()))
        logging_logger.info('mean IoU: %.3f, std: %.3f, for each class: %s' %
                            (mean_ious, std_ious, ious))
        logging_logger.info('mean Dice: %.3f, std: %.3f, for each class: %s' %
                            (mean_dices, std_dices, dices))

        ### 2. mean metrics for all data calculated by definition ###
        iou_data_mean = epoch_metrics['iou'].data_mean()
        dice_data_mean = epoch_metrics['dice'].data_mean()
        logging_logger.info('data (%d) mean IoU: %.3f, std: %.3f' %
                            (len(iou_data_mean['items']), iou_data_mean['mean'], iou_data_mean['std']))
        logging_logger.info('data (%d) mean Dice: %.3f, std: %.3f' %
                            (len(dice_data_mean['items']), dice_data_mean['mean'], dice_data_mean['std']))

        # record metrics in trainer every epoch
        # trainer.state.metrics_records[trainer.state.epoch] = \
        #     {'miou': mean_ious, 'std_miou': std_ious,
        #      'mdice': mean_dices, 'std_mdice': std_dices}
        trainer.state.metrics_records[trainer.state.epoch] = \
            {'miou': iou_data_mean['mean'], 'std_miou': iou_data_mean['std'],
             'mdice': dice_data_mean['mean'], 'std_mdice': dice_data_mean['std']}

    # log internal variables (attention maps, outputs, etc.) on validation
    def tb_log_valid_iter_vars(engine, logger, event_name):
        log_tag = 'valid_iter'
        output = engine.state.output
        batch_size = output['output'].shape[0]
        res_grid = tvutils.make_grid(torch.cat([
            output['output_argmax'].unsqueeze(1),
            output['target'].unsqueeze(1),
        ]), padding=2, normalize=False,  # show origin image
            nrow=batch_size).cpu()
        logger.writer.add_image(tag='%s (outputs, targets)' % (log_tag),
                                img_tensor=res_grid)

        if 'TAPNet' in args.model:
            # log attention maps and other internal values
            inter_vals_grid = tvutils.make_grid(torch.cat([
                output['attmap'],
            ]), padding=2, normalize=True, nrow=batch_size).cpu()
            logger.writer.add_image(tag='%s internal vals' % (log_tag),
                                    img_tensor=inter_vals_grid)

    def tb_log_valid_epoch_vars(engine, logger, event_name):
        log_tag = 'valid_iter'
        # log monitored epoch metrics
        epoch_metrics = engine.state.epoch_metrics
        confusion_matrix = epoch_metrics['confusion_matrix']  # [1:, 1:]
        ious = calculate_iou(confusion_matrix)
        dices = calculate_dice(confusion_matrix)
        mean_ious = np.mean(list(ious.values()))
        mean_dices = np.mean(list(dices.values()))
        logger.writer.add_scalar('mIoU', mean_ious, engine.state.epoch)
        logger.writer.add_scalar('mDice', mean_dices, engine.state.epoch)

    if args.tb_log:
        # log internal values
        tb_logger.attach(validator, log_handler=tb_log_valid_iter_vars,
                         event_name=engine.Events.ITERATION_COMPLETED)
        tb_logger.attach(validator, log_handler=tb_log_valid_epoch_vars,
                         event_name=engine.Events.EPOCH_COMPLETED)
        # tb_logger.attach(validator, log_handler=OutputHandler('valid_iter', valid_metric_names),
        #                  event_name=engine.Events.ITERATION_COMPLETED)
        tb_logger.attach(validator, log_handler=OutputHandler('valid_epoch', ['valid_loss']),
                         event_name=engine.Events.EPOCH_COMPLETED)

    # score function for model saving
    ckpt_score_function = lambda engine: \
        np.mean(list(calculate_iou(engine.state.epoch_metrics['confusion_matrix']).values()))
# ckpt_score_function = lambda engine: engine.state.epoch_metrics['iou'].data_mean()['mean'] ckpt_filename_prefix = 'fold_%d' % fold # model saving handler model_ckpt_handler = handlers.ModelCheckpoint( dirname=args.model_save_dir, filename_prefix=ckpt_filename_prefix, score_function=ckpt_score_function, create_dir=True, require_empty=False, save_as_state_dict=True, atomic=True) validator.add_event_handler(event_name=engine.Events.EPOCH_COMPLETED, handler=model_ckpt_handler, to_save={ 'model': model, }) # early stop # trainer=trainer, but should be handled by validator early_stopping = handlers.EarlyStopping(patience=args.es_patience, score_function=ckpt_score_function, trainer=trainer ) validator.add_event_handler(event_name=engine.Events.EPOCH_COMPLETED, handler=early_stopping) # evaluate after epoch finish @trainer.on(engine.Events.EPOCH_COMPLETED) def trainer_epoch_comp_callback(engine): validator.run(valid_loader) trainer.run(train_loader, max_epochs=args.max_epochs) if args.tb_log: # close tb_logger tb_logger.close() return trainer.state.metrics_records
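#### hedged sketch of the MetricRecord container used in valid_step ####
# MetricRecord (update_record / merge / data_mean) is not defined in this file.
# The class below is an assumption about its behaviour inferred from how it is
# used, not the project's actual implementation: each record holds one per-class
# metric dict per frame, and data_mean() reports statistics over the per-frame means.
import numpy as np

class MetricRecordSketch:
    def __init__(self):
        self.records = []  # one {class_idx: value} dict per frame

    def update_record(self, record):
        self.records.append(record)

    def merge(self, other):
        self.records.extend(other.records)

    def data_mean(self):
        items = [float(np.mean(list(r.values()))) for r in self.records]
        return {'items': items,
                'mean': float(np.mean(items)) if items else 0.0,
                'std': float(np.std(items)) if items else 0.0}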
def process_fold(fold, args): num_classes = utils.problem_class[args.problem_type] factor = utils.problem_factor[args.problem_type] # inputs are RGB images (3 * h * w) # outputs are 2d multilabel segmentation maps (h * w) model = eval(args.model)(in_channels=3, num_classes=num_classes) # data parallel for multi-GPU model = nn.DataParallel(model, device_ids=args.device_ids).cuda() ckpt_dir = Path(args.ckpt_dir) #p = pathlib.Path(ckpt_dir) # ckpt for this fold fold_<fold>_model_<epoch>.pth print("ckpt_dir--> ", ckpt_dir) filenames = glob.glob(args.ckpt_dir + 'fold_%d_model_[0-99]*.pth' % fold) #filenames = glob.glob(args.ckpt_dir+'fold_%d_model_[0-99]*.pth') #filenames = ckpt_dir.glob(args.ckpt_dir+'fold_%d_model_[0-9]*.pth'%fold) print("Filename--> ", filenames) # if len(filenames) != 1: # raise ValueError('invalid model ckpt name. correct ckpt name should be \ # fold_<fold>_model_<epoch>.pth') ckpt_filename = filenames[0] # load state dict model.load_state_dict(torch.load(str(ckpt_filename))) logging.info('Restored model [{}] fold {}.'.format(args.model, fold)) # segmentation mask save directory mask_save_dir = Path(args.mask_save_dir) / ckpt_dir.name mask_save_dir.mkdir(exist_ok=True, parents=True) #print("mask_save_dir", mask_save_dir) eval_transform = Compose( [ Normalize(p=1), PadIfNeeded( min_height=args.input_height, min_width=args.input_width, p=1), # optional Resize(height=args.input_height, width=args.input_width, p=1), # CenterCrop(height=args.input_height, width=args.input_width, p=1) ], p=1) # train/valid filenames, # we evaluate and generate masks on validation set train_filenames, valid_filenames = utils.trainval_split( args.train_dir, fold) eval_num_workers = args.num_workers eval_batch_size = args.batch_size # additional ds args if 'TAPNet' in args.model: # in eval, num_workers should be set to 0 for sequences eval_num_workers = 0 # in eval, batch_size should be set to 1 for sequences eval_batch_size = 1 # additional eval dataset kws eval_ds_kwargs = { 'filenames': train_filenames, 'problem_type': args.problem_type, 'transform': eval_transform, 'model': args.model, 'mode': 'eval', } # valid dataloader eval_loader = DataLoader( dataset=RobotSegDataset(**eval_ds_kwargs), shuffle=False, # in eval, no need to shuffle num_workers=eval_num_workers, batch_size= eval_batch_size, # in valid time. 
        #   have to use one image at a time for sequences
        pin_memory=True)

    # process function for ignite engine
    def eval_step(engine, batch):
        with torch.no_grad():
            model.eval()
            inputs = batch['input'].cuda(non_blocking=True)
            # targets = batch['target'].cuda(non_blocking=True)
            # additional arguments
            add_params = {}
            # for TAPNet, add attention maps
            if 'TAPNet' in args.model:
                add_params['attmap'] = batch['attmap'].cuda(non_blocking=True)
            outputs = model(inputs, **add_params)
            output_softmax_np = torch.softmax(outputs, dim=1).cpu().numpy()
            # output_classes: <b, h, w>
            output_classes = output_softmax_np.argmax(axis=1)
            masks = (output_classes * factor).astype(np.uint8)
            return_dict = {
                'input_filename': batch['input_filename'],
                'mask': masks,
            }
            if 'TAPNet' in args.model:
                # for TAPNet, update attention maps after each iteration
                eval_loader.dataset.update_attmaps(output_softmax_np, batch['idx'].numpy())
                # for TAPNet, return extra internal values
                return_dict['attmap'] = add_params['attmap']
            return return_dict

    # eval engine
    evaluator = engine.Engine(eval_step)
    eval_pbar = c_handlers.ProgressBar(persist=True, dynamic_ncols=True)
    eval_pbar.attach(evaluator)

    # save masks and accumulate IoU after each iteration finishes
    @evaluator.on(engine.Events.ITERATION_COMPLETED)
    def evaluator_epoch_comp_callback(engine):
        global Average_batch_IoU
        # save masks for each batch
        batch_output = engine.state.output
        input_filenames = batch_output['input_filename']
        masks = batch_output['mask']
        iou = []
        for i, input_filename in enumerate(input_filenames):
            mask = cv2.resize(masks[i],
                              dsize=(utils.cropped_width, utils.cropped_height),
                              interpolation=cv2.INTER_AREA)
            # if pad:
            #     h_start, w_start = utils.h_start, utils.w_start
            #     h, w = mask.shape
            #     # recover to original shape
            #     full_mask = np.zeros((original_height, original_width))
            #     full_mask[h_start:h_start + h, w_start:w_start + w] = t_mask
            #     mask = full_mask
            instrument_folder_name = os.path.basename(
                os.path.dirname(os.path.dirname(input_filename)))
            binary_mask = Path(args.type_mask)
            gt_folder = Path(os.path.dirname(
                os.path.dirname(input_filename))) / binary_mask
            gt_filename = gt_folder / os.path.basename(input_filename)
            # mask_folder/instrument_dataset_x/problem_type_masks/framexxx.png
            mask_folder = mask_save_dir / instrument_folder_name / \
                utils.mask_folder[args.problem_type]
            mask_folder.mkdir(exist_ok=True, parents=True)
            mask_filename = mask_folder / os.path.basename(input_filename)
            gt_mask = cv2.imread(str(gt_filename), cv2.IMREAD_GRAYSCALE)
            cv2.imwrite(str(mask_filename), mask)
            assert mask.shape == gt_mask.shape
            image_iou = get_iou(mask, gt_mask)
            if not math.isnan(image_iou):
                iou.append(image_iou)
            if 'TAPNet' in args.model:
                attmap = batch_output['attmap'][i]
                attmap_folder = mask_save_dir / instrument_folder_name / \
                    '_'.join([args.problem_type, 'attmaps'])
                attmap_folder.mkdir(exist_ok=True, parents=True)
                attmap_filename = attmap_folder / os.path.basename(input_filename)
                cv2.imwrite(str(attmap_filename), attmap)
        # Average_batch_IoU.append(np.mean(iou))
        # Average_batch_IoU = list(np.mean(iou))
        Average_batch_IoU.append(np.nanmean(iou))

    # run the evaluator over the validation loader and report the mean IoU
    evaluator.run(eval_loader)
    print("Average_batch_IoU-->", np.nanmean(Average_batch_IoU))
    f.write(str(np.nanmean(Average_batch_IoU)))
    f.write('\n')
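#### hedged sketch of the per-image IoU used in the evaluator callback above ####
# get_iou is not defined in this file; the real implementation may differ. This
# version treats any non-zero pixel as foreground and returns NaN for an empty
# union, which is consistent with the math.isnan() guard in the loop.
import numpy as np

def get_iou_sketch(pred_mask, gt_mask):
    pred = pred_mask > 0
    gt = gt_mask > 0
    union = np.logical_or(pred, gt).sum()
    if union == 0:
        return float('nan')
    intersection = np.logical_and(pred, gt).sum()
    return float(intersection) / float(union)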
def main(fold):
    # check cuda available
    assert torch.cuda.is_available()

    # when the input dimension does not change, add this flag to speed up
    cudnn.benchmark = True

    num_classes = config.problem_class[params.problem_type]

    # inputs are RGB images of size 3 * h * w
    # outputs are segmentation maps of size h * w
    model = params.model(in_channels=3, num_classes=num_classes)
    # data parallel
    model = nn.DataParallel(model, device_ids=params.device_ids).cuda()

    # loss function
    if num_classes == 2:
        loss = LossBinary(jaccard_weight=params.jaccard_weight)
        valid_metric = validation_binary
    else:
        loss = LossMulti(num_classes=num_classes, jaccard_weight=params.jaccard_weight)
        valid_metric = validation_multi

    # trainset transform
    train_transform = Compose([
        Resize(height=params.train_height, width=params.train_width, p=1),
        Normalize(p=1)
    ], p=1)

    # validset transform
    valid_transform = Compose([
        Resize(height=params.valid_height, width=params.valid_width, p=1),
        Normalize(p=1)
    ], p=1)

    # train/valid filenames
    train_filenames, valid_filenames = trainval_split(fold)
    print('num of train / validation files = {} / {}'.format(
        len(train_filenames), len(valid_filenames)))

    # train dataloader
    train_loader = DataLoader(
        dataset=RobotSegDataset(train_filenames, transform=train_transform),
        shuffle=True,
        num_workers=params.num_workers,
        batch_size=params.batch_size,
        pin_memory=True
    )
    # valid dataloader
    valid_loader = DataLoader(
        dataset=RobotSegDataset(valid_filenames, transform=valid_transform),
        shuffle=True,
        num_workers=params.num_workers,
        batch_size=len(params.device_ids),  # at validation time, one image per device
        pin_memory=True
    )

    train(
        model=model,
        loss_func=loss,
        train_loader=train_loader,
        valid_loader=valid_loader,
        valid_metric=valid_metric,
        fold=fold,
        num_classes=num_classes
    )
def main(fold):
    # check cuda available
    assert torch.cuda.is_available()

    # when the input dimension does not change, add this flag to speed up
    cudnn.benchmark = True

    num_classes = config.problem_class[params.problem_type]

    # inputs are RGB images of size 3 * h * w
    # outputs are segmentation maps of size h * w
    model = params.model(in_channels=3, num_classes=num_classes)
    # data parallel
    model = nn.DataParallel(model, device_ids=params.device_ids).cuda()

    # loss function
    if num_classes == 2:
        loss = LossBinary(jaccard_weight=params.jaccard_weight)
        valid_metric = validation_binary
    else:
        loss = LossMulti(num_classes=num_classes, jaccard_weight=params.jaccard_weight)
        valid_metric = validation_multi

    # trainset transform
    train_transform = Compose([
        Resize(height=params.train_height, width=params.train_width, p=1),
        Normalize(p=1),
        PadIfNeeded(min_height=params.train_height, min_width=params.train_width, p=1),
    ], p=1)

    # validset transform
    valid_transform = Compose([
        PadIfNeeded(min_height=params.valid_height, min_width=params.valid_width, p=1),
        Resize(height=params.train_height, width=params.train_width, p=1),
        Normalize(p=1)
    ], p=1)

    # train/valid filenames
    train_filenames, valid_filenames = trainval_split(fold)
    print('fold {}, {} train / {} validation files'.format(
        fold, len(train_filenames), len(valid_filenames)))

    # train dataloader
    train_loader = DataLoader(
        dataset=RobotSegDataset(train_filenames, transform=train_transform,
                                schedule="ordered", batch_size=params.batch_size,
                                problem_type=params.problem_type,
                                semi_percentage=params.semi_percentage),
        shuffle=False,  # set to False to disable pytorch dataset shuffle
        num_workers=params.num_workers,
        batch_size=params.batch_size,
        pin_memory=True
    )
    # valid dataloader
    valid_loader = DataLoader(
        dataset=RobotSegDataset(valid_filenames, transform=valid_transform,
                                problem_type=params.problem_type, mode='valid'),
        shuffle=False,  # set to False to disable pytorch dataset shuffle
        num_workers=0,  # params.num_workers,
        batch_size=1,   # at validation time, process one image at a time
        pin_memory=True)

    train(model=model,
          loss_func=loss,
          train_loader=train_loader,
          valid_loader=valid_loader,
          valid_metric=valid_metric,
          fold=fold,
          num_classes=num_classes)
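#### hedged sketch of what trainval_split(fold) is assumed to do ####
# trainval_split is not defined in this file. The sketch below partitions the
# instrument_dataset_* sequence folders into train/valid file lists by fold; the
# folder layout and fold assignments are placeholders, not the project's verified split.
from pathlib import Path

FOLD_SPLIT_SKETCH = {0: [1, 3], 1: [2, 5], 2: [4, 8], 3: [6, 7]}  # placeholder assignments

def trainval_split_sketch(data_dir, fold):
    train_files, valid_files = [], []
    for seq_dir in sorted(Path(data_dir).glob('instrument_dataset_*')):
        seq_id = int(seq_dir.name.split('_')[-1])
        files = sorted((seq_dir / 'images').glob('*.png'))
        if seq_id in FOLD_SPLIT_SKETCH[fold]:
            valid_files.extend(files)
        else:
            train_files.extend(files)
    return train_files, valid_files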
def eval(model, filenames, num_classes): # should set batch_size = 1 batch_size = 1 model.eval() # recover h and w factor = config.problem_factor[params.problem_type] original_height, original_width = config.original_height, config.original_width h_start, w_start = config.h_start, config.w_start transform = Compose([ Resize(height=params.train_height, width=params.train_width, p=1), Normalize(p=1), PadIfNeeded( min_height=params.train_height, min_width=params.train_width, p=1) ], p=1) dataloader = DataLoader(dataset=RobotSegDataset( filenames, transform=transform, mode='eval', problem_type=params.problem_type), shuffle=False, num_workers=0, batch_size=batch_size, pin_memory=True) with torch.no_grad(): # init progress bar for each epoch tq = tqdm.tqdm(total=len(dataloader.dataset)) tq.set_description("Predict [{}]".format(params.model.__name__)) for batch_num, (idxs, filenames, inputs, attmaps) in enumerate(dataloader): # no grad for targets inputs = inputs.cuda(non_blocking=True) attmaps = attmaps.cuda(non_blocking=True) outputs, am, am5, am4, am3, am2, am1 = model(inputs, attmaps) # update attention maps using prediction dataloader.dataset.update_attmaps(outputs.cpu(), idxs) for i, filename in enumerate(filenames): # binary if num_classes == 2: t_mask = ((outputs[i, 0] > 0).data.cpu().numpy() * factor).astype(np.uint8) else: t_mask = (outputs[i].data.cpu().numpy().argmax(axis=0) * factor).astype(np.uint8) t_mask = cv2.resize(t_mask, dsize=(config.cropped_width, config.cropped_height), interpolation=cv2.INTER_AREA) # generate mask h, w = t_mask.shape # recover to original shape full_mask = np.zeros((original_height, original_width)) full_mask[h_start:h_start + h, w_start:w_start + w] = t_mask # not recover # full_mask = t_mask[0] prediction_folder = Path(filenames[i]).parent.parent / \ 'prediction' / params.problem_type prediction_folder.mkdir(exist_ok=True, parents=True) cv2.imwrite( str(prediction_folder / (Path(filenames[i]).stem + '.png')), full_mask) attmaps_folder = Path(filenames[i]).parent.parent / \ 'attmaps' / params.problem_type if num_classes == 2: sigmoid_output = ( torch.sigmoid(outputs[i, 0]).data.cpu().numpy() * 255).astype(np.uint8) else: sigmoid_output = ( (1 - outputs[i, 0].exp()).data.cpu().numpy() * 255).astype(np.uint8) attmaps_folder.mkdir(exist_ok=True, parents=True) cv2.imwrite( str(attmaps_folder / (Path(filenames[i]).stem + '_sig.png')), sigmoid_output) cv2.imwrite( str(attmaps_folder / (Path(filenames[i]).stem + '_am.png')), norm_attmap(attmaps[i].cpu().numpy().squeeze())) cv2.imwrite( str(attmaps_folder / (Path(filenames[i]).stem + '_am_.png')), norm_attmap(am[i].cpu().numpy().squeeze())) cv2.imwrite( str(attmaps_folder / (Path(filenames[i]).stem + '_am5.png')), norm_attmap(am5[i].cpu().numpy().squeeze())) cv2.imwrite( str(attmaps_folder / (Path(filenames[i]).stem + '_am4.png')), norm_attmap(am4[i].cpu().numpy().squeeze())) cv2.imwrite( str(attmaps_folder / (Path(filenames[i]).stem + '_am3.png')), norm_attmap(am3[i].cpu().numpy().squeeze())) cv2.imwrite( str(attmaps_folder / (Path(filenames[i]).stem + '_am2.png')), norm_attmap(am2[i].cpu().numpy().squeeze())) cv2.imwrite( str(attmaps_folder / (Path(filenames[i]).stem + '_am1.png')), norm_attmap(am1[i].cpu().numpy().squeeze())) # cv2.imwrite(str(attmaps_folder / (Path(filenames[i]).stem + '_am4.png')), cv2.equalizeHist((am4[i].cpu().numpy().squeeze() * 255).astype(np.uint8))) # cv2.imwrite(str(attmaps_folder / (Path(filenames[i]).stem + '_am3.png')), cv2.equalizeHist((am3[i].cpu().numpy().squeeze() * 
                #     255).astype(np.uint8)))
                # cv2.imwrite(str(attmaps_folder / (Path(filenames[i]).stem + '_am2.png')),
                #             cv2.equalizeHist((am2[i].cpu().numpy().squeeze() * 255).astype(np.uint8)))
                # cv2.imwrite(str(attmaps_folder / (Path(filenames[i]).stem + '_am1.png')),
                #             cv2.equalizeHist((am1[i].cpu().numpy().squeeze() * 255).astype(np.uint8)))
            tq.update(batch_size)
        tq.close()
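#### hedged sketch of norm_attmap as used when saving attention maps above ####
# norm_attmap is not defined in this file; this version min-max normalises a float
# attention map to an 8-bit image so cv2.imwrite can store it. The project's actual
# implementation may differ (e.g. the commented-out cv2.equalizeHist variants above).
import numpy as np

def norm_attmap_sketch(attmap):
    attmap = attmap.astype(np.float32)
    rng = float(attmap.max() - attmap.min())
    if rng == 0.0:
        return np.zeros_like(attmap, dtype=np.uint8)
    return ((attmap - attmap.min()) / rng * 255.0).astype(np.uint8)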