def train_fold(fold, args):
    """Restore the saved model of one fold and set up its validation pipeline.

    Builds the model named by ``args.model``, loads the weights of the
    checkpoint stored for ``fold`` in ``args.ckpt_dir``, and creates the
    validation ``DataLoader``, the optimizer and the ``valid_step`` process
    function.  No trainer engine is created here and nothing is returned.

    Args:
        fold: index of the cross-validation fold; selects the train/valid
            split and the ``fold_<fold>_model_<epoch>.pth`` checkpoint.
        args: parsed options.  Reads ``logging_logger``, ``tb_log``,
            ``tb_logger``, ``problem_type``, ``model``, ``device_ids``,
            ``jaccard_weight``, ``semi``, ``semi_loss_alpha``,
            ``semi_method``, ``train_dir``, ``ckpt_dir``, ``num_workers``,
            ``batch_size``, ``mf``, ``lr`` and ``weight_decay``.

    Raises:
        FileNotFoundError: if ``args.ckpt_dir`` contains no checkpoint for
            ``fold``.
    """
    # loggers
    logging_logger = args.logging_logger
    if args.tb_log:
        tb_logger = args.tb_logger

    num_classes = utils.problem_class[args.problem_type]

    # init model
    # NOTE(review): eval() executes whatever expression args.model holds;
    # make sure the value is restricted to known model names upstream.
    model = eval(args.model)(in_channels=3, num_classes=num_classes, bn=False)
    model = nn.DataParallel(model, device_ids=args.device_ids).cuda()

    # transform for train/valid data
    train_transform, valid_transform = get_transform(args.model)

    # loss function
    loss_func = LossMulti(num_classes, args.jaccard_weight)
    if args.semi:
        loss_func_semi = LossMultiSemi(num_classes, args.jaccard_weight,
                                       args.semi_loss_alpha, args.semi_method)

    # train/valid filenames
    train_filenames, valid_filenames = utils.trainval_split(
        args.train_dir, fold)

    # Locate the checkpoint of this fold: fold_<fold>_model_<epoch>.pth.
    # Path.glob() yields lazily and cannot be indexed, and the regex has to
    # be applied to the bare file name (a str), not to the Path object.
    ckpt_pattern = re.compile(r'fold_%d_model_(\d+)\.pth' % fold)
    candidates = []
    for ckpt_path in Path(args.ckpt_dir).glob(
            'fold_%d_model_[0-9]*.pth' % fold):
        res = ckpt_pattern.match(ckpt_path.name)
        if res is not None:
            candidates.append((int(res.group(1)), ckpt_path))
    if not candidates:
        raise FileNotFoundError(
            'no checkpoint for fold %d found in %s' % (fold, args.ckpt_dir))
    # when several checkpoints exist, restore the one with the highest epoch
    restored_epoch, ckpt_filename = max(candidates, key=lambda c: c[0])

    # load model state dict
    model.load_state_dict(torch.load(str(ckpt_filename)))
    # There is no trainer engine in this function, so the restored epoch is
    # only kept locally for the log message.
    logging_logger.info('restore model [{}] from epoch {}.'.format(
        args.model, restored_epoch))

    # DataLoader and Dataset args
    valid_num_workers = args.num_workers
    valid_batch_size = args.batch_size

    # additional valid dataset kws
    valid_ds_kwargs = {
        'filenames': valid_filenames,
        'problem_type': args.problem_type,
        'transform': valid_transform,
        'model': args.model,
        'mode': 'valid',
    }

    if 'TAPNet' in args.model:
        # in validation, num_workers should be set to 0 for sequences
        valid_num_workers = 0
        # in validation, batch_size should be set to 1 for sequences
        valid_batch_size = 1
        valid_ds_kwargs['mf'] = args.mf

    # valid dataloader
    valid_loader = DataLoader(
        dataset=RobotSegDataset(**valid_ds_kwargs),
        shuffle=False,  # in validation, no need to shuffle
        num_workers=valid_num_workers,
        batch_size=valid_batch_size,
        pin_memory=True)

    # optimizer
    optimizer = optim.Adam(model.parameters(), lr=args.lr,
                           weight_decay=args.weight_decay)

    # ignite validator process function
    def valid_step(engine, batch):
        """Run one validation batch and accumulate its metrics.

        Per-frame IoU/Dice records and the batch confusion matrix are merged
        into ``engine.state.epoch_metrics``, which must already exist.
        Returns a dict with the loss value, logits, argmax predictions,
        targets and the per-batch metric records (plus ``attmap`` for
        TAPNet models).
        """
        with torch.no_grad():
            model.eval()
            inputs = batch['input'].cuda(non_blocking=True)
            targets = batch['target'].cuda(non_blocking=True)

            # additional arguments
            add_params = {}
            # for TAPNet, add attention maps
            if 'TAPNet' in args.model:
                add_params['attmap'] = batch['attmap'].cuda(non_blocking=True)

            # output logits
            outputs = model(inputs, **add_params)
            # loss
            loss = loss_func(outputs, targets)

            output_softmaxs = torch.softmax(outputs, dim=1)
            output_argmaxs = output_softmaxs.argmax(dim=1)
            # output_classes and target_classes: <b, h, w>
            output_classes = output_argmaxs.cpu().numpy()
            target_classes = targets.cpu().numpy()

            # record current batch metrics
            iou_mRecords = MetricRecord()
            dice_mRecords = MetricRecord()

            cm_b = np.zeros((num_classes, num_classes), dtype=np.uint32)

            for output_class, target_class in zip(output_classes,
                                                  target_classes):
                # calculate metrics for each frame from its confusion matrix
                cm = calculate_confusion_matrix_from_arrays(
                    output_class, target_class, num_classes)
                iou_mRecords.update_record(calculate_iou(cm))
                dice_mRecords.update_record(calculate_dice(cm))
                cm_b += cm

            # accumulate batch metrics to engine state
            engine.state.epoch_metrics['confusion_matrix'] += cm_b
            engine.state.epoch_metrics['iou'].merge(iou_mRecords)
            engine.state.epoch_metrics['dice'].merge(dice_mRecords)

            return_dict = {
                'loss': loss.item(),
                'output': outputs,
                'output_argmax': output_argmaxs,
                'target': targets,
                # for monitoring
                'iou': iou_mRecords,
                'dice': dice_mRecords,
            }

            if 'TAPNet' in args.model:
                # for TAPNet, update attention maps after each iteration
                valid_loader.dataset.update_attmaps(
                    output_softmaxs.cpu().numpy(), batch['abs_idx'].numpy())
                # for TAPNet, return extra internal values
                return_dict['attmap'] = add_params['attmap']
                # TODO: for TAPNet, return internal self-learned attention maps

            return return_dict
def train_fold(fold, args):
    """Train and validate the model on one cross-validation fold.

    Builds the model, losses, data loaders and optimizer, wires an ignite
    trainer and validator (progress bars, optional tensorboard logging,
    checkpointing and early stopping on the validator) and runs training for
    ``args.max_epochs`` epochs, validating after every training epoch.

    Args:
        fold: index of the cross-validation fold.
        args: parsed options.  Reads ``logging_logger``, ``tb_log``,
            ``tb_logger``, ``problem_type``, ``model``, ``device_ids``,
            ``jaccard_weight``, ``semi``, ``semi_loss_alpha``,
            ``semi_method``, ``semi_percentage``, ``train_dir``,
            ``num_workers``, ``batch_size``, ``mf``, ``lr``,
            ``weight_decay``, ``lr_decay_epochs``, ``lr_decay``, ``resume``,
            ``ckpt_dir``, ``model_save_dir``, ``es_patience`` and
            ``max_epochs``.

    Returns:
        dict mapping each trainer epoch to
        ``{'miou', 'std_miou', 'mdice', 'std_mdice'}`` as recorded at the
        end of the corresponding validation run.

    Raises:
        FileNotFoundError: if ``args.resume`` is set but ``args.ckpt_dir``
            contains no checkpoint for ``fold``.
    """
    # loggers
    logging_logger = args.logging_logger
    if args.tb_log:
        tb_logger = args.tb_logger

    num_classes = utils.problem_class[args.problem_type]

    # init model
    # NOTE(review): eval() executes whatever expression args.model holds;
    # make sure the value is restricted to known model names upstream.
    model = eval(args.model)(in_channels=3, num_classes=num_classes, bn=False)
    model = nn.DataParallel(model, device_ids=args.device_ids).cuda()

    # transform for train/valid data
    train_transform, valid_transform = get_transform(args.model)

    # loss function
    loss_func = LossMulti(num_classes, args.jaccard_weight)
    if args.semi:
        loss_func_semi = LossMultiSemi(num_classes, args.jaccard_weight,
                                       args.semi_loss_alpha, args.semi_method)

    # train/valid filenames
    train_filenames, valid_filenames = utils.trainval_split(args.train_dir, fold)

    # DataLoader and Dataset args
    train_shuffle = True
    train_ds_kwargs = {
        'filenames': train_filenames,
        'problem_type': args.problem_type,
        'transform': train_transform,
        'model': args.model,
        'mode': 'train',
        'semi': args.semi,
    }

    valid_num_workers = args.num_workers
    valid_batch_size = args.batch_size

    if 'TAPNet' in args.model:
        # for TAPNet, cancel default shuffle, use self-defined shuffle in
        # torch.Dataset instead
        train_shuffle = False
        train_ds_kwargs['batch_size'] = args.batch_size
        train_ds_kwargs['mf'] = args.mf
    if args.semi:
        train_ds_kwargs['semi_method'] = args.semi_method
        train_ds_kwargs['semi_percentage'] = args.semi_percentage

    # additional valid dataset kws
    valid_ds_kwargs = {
        'filenames': valid_filenames,
        'problem_type': args.problem_type,
        'transform': valid_transform,
        'model': args.model,
        'mode': 'valid',
    }

    if 'TAPNet' in args.model:
        # in validation, num_workers should be set to 0 for sequences
        valid_num_workers = 0
        # in validation, batch_size should be set to 1 for sequences
        valid_batch_size = 1
        valid_ds_kwargs['mf'] = args.mf

    # train dataloader
    train_loader = DataLoader(
        dataset=RobotSegDataset(**train_ds_kwargs),
        shuffle=train_shuffle,  # set to False to disable pytorch dataset shuffle
        num_workers=args.num_workers,
        batch_size=args.batch_size,
        pin_memory=True
    )
    # valid dataloader
    valid_loader = DataLoader(
        dataset=RobotSegDataset(**valid_ds_kwargs),
        shuffle=False,  # in validation, no need to shuffle
        num_workers=valid_num_workers,
        batch_size=valid_batch_size,  # in valid time. have to use one image by one
        pin_memory=True
    )

    # optimizer
    optimizer = optim.Adam(model.parameters(), lr=args.lr,
                           weight_decay=args.weight_decay)

    # ignite trainer process function
    def train_step(engine, batch):
        """Run one optimisation step and return outputs, targets and loss."""
        # set model to train
        model.train()
        # clear gradients
        optimizer.zero_grad()

        # additional params to feed into model
        add_params = {}
        inputs = batch['input'].cuda(non_blocking=True)
        with torch.no_grad():
            targets = batch['target'].cuda(non_blocking=True)
        # for TAPNet, add attention maps
        if 'TAPNet' in args.model:
            add_params['attmap'] = batch['attmap'].cuda(non_blocking=True)

        outputs = model(inputs, **add_params)

        loss_kwargs = {}

        if args.semi:
            loss_kwargs['labeled'] = batch['labeled']
            if args.semi_method == 'rev_flow':
                loss_kwargs['optflow'] = batch['optflow']
            loss = loss_func_semi(outputs, targets, **loss_kwargs)
        else:
            loss = loss_func(outputs, targets, **loss_kwargs)

        loss.backward()
        optimizer.step()

        return_dict = {
            'output': outputs,
            'target': targets,
            'loss_kwargs': loss_kwargs,
            'loss': loss.item(),
        }

        # for TAPNet, update attention maps after each iteration
        if 'TAPNet' in args.model:
            # output_classes and target_classes: <b, h, w>
            output_softmax_np = torch.softmax(outputs, dim=1).detach().cpu().numpy()
            # update attention maps
            train_loader.dataset.update_attmaps(output_softmax_np,
                                                batch['abs_idx'].numpy())
            return_dict['attmap'] = add_params['attmap']

        return return_dict

    # init trainer
    trainer = engine.Engine(train_step)

    # lr scheduler and handler: step decay applied at the start of each epoch
    step_scheduler = optim.lr_scheduler.StepLR(
        optimizer, step_size=args.lr_decay_epochs, gamma=args.lr_decay)
    lr_scheduler = c_handlers.param_scheduler.LRScheduler(step_scheduler)
    trainer.add_event_handler(engine.Events.EPOCH_STARTED, lr_scheduler)

    @trainer.on(engine.Events.STARTED)
    def trainer_start_callback(engine):
        logging_logger.info('training fold {}, {} train / {} valid files'.format(
            fold, len(train_filenames), len(valid_filenames)))

        # resume training
        if args.resume:
            # ckpt for current fold: fold_<fold>_model_<epoch>.pth
            # Path.glob() yields lazily and cannot be indexed, and the regex
            # has to be applied to the bare file name (a str), not the Path.
            ckpt_pattern = re.compile(r'fold_%d_model_(\d+)\.pth' % fold)
            candidates = []
            for ckpt_path in Path(args.ckpt_dir).glob(
                    'fold_%d_model_[0-9]*.pth' % fold):
                res = ckpt_pattern.match(ckpt_path.name)
                if res is not None:
                    candidates.append((int(res.group(1)), ckpt_path))
            if not candidates:
                raise FileNotFoundError(
                    'no checkpoint for fold %d found in %s'
                    % (fold, args.ckpt_dir))
            # when several checkpoints exist, resume from the highest epoch
            restored_epoch, ckpt_filename = max(candidates, key=lambda c: c[0])

            # restore epoch
            engine.state.epoch = restored_epoch
            # load model state dict
            model.load_state_dict(torch.load(str(ckpt_filename)))
            logging_logger.info('restore model [{}] from epoch {}.'.format(
                args.model, engine.state.epoch))
        else:
            logging_logger.info('train model [{}] from scratch'.format(args.model))

        # record metrics history every epoch
        engine.state.metrics_records = {}

    @trainer.on(engine.Events.EPOCH_STARTED)
    def trainer_epoch_start_callback(engine):
        # log learning rate on pbar
        train_pbar.log_message(
            'model: %s, problem type: %s, fold: %d, lr: %.5f, batch size: %d' %
            (args.model, args.problem_type, fold, lr_scheduler.get_param(),
             args.batch_size))

        # for TAPNet, change dataset schedule to random after the first epoch
        if 'TAPNet' in args.model and engine.state.epoch > 1:
            train_loader.dataset.set_dataset_schedule("shuffle")

    @trainer.on(engine.Events.ITERATION_COMPLETED)
    def trainer_iter_comp_callback(engine):
        pass

    # monitor loss
    # running average loss
    train_ra_loss = imetrics.RunningAverage(
        output_transform=lambda x: x['loss'], alpha=0.98)
    train_ra_loss.attach(trainer, 'train_ra_loss')

    # monitor train loss over epoch
    if args.semi:
        train_loss = imetrics.Loss(
            loss_func_semi,
            output_transform=lambda x: (x['output'], x['target'], x['loss_kwargs']))
    else:
        train_loss = imetrics.Loss(
            loss_func, output_transform=lambda x: (x['output'], x['target']))
    train_loss.attach(trainer, 'train_loss')

    # progress bar
    train_pbar = c_handlers.ProgressBar(persist=True, dynamic_ncols=True)
    train_metric_names = ['train_ra_loss']
    train_pbar.attach(trainer, metric_names=train_metric_names)

    # tensorboardX: log train info
    if args.tb_log:
        tb_logger.attach(trainer,
                         log_handler=OptimizerParamsHandler(optimizer, 'lr'),
                         event_name=engine.Events.EPOCH_STARTED)
        tb_logger.attach(trainer,
                         log_handler=OutputHandler('train_iter', train_metric_names),
                         event_name=engine.Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer,
                         log_handler=OutputHandler('train_epoch', ['train_loss']),
                         event_name=engine.Events.EPOCH_COMPLETED)
        tb_logger.attach(trainer,
                         log_handler=WeightsScalarHandler(model, reduction=torch.norm),
                         event_name=engine.Events.ITERATION_COMPLETED)

    # ignite validator process function
    def valid_step(engine, batch):
        """Run one validation batch and accumulate its metrics.

        Per-frame IoU/Dice records and the batch confusion matrix are merged
        into ``engine.state.epoch_metrics``.
        """
        with torch.no_grad():
            model.eval()
            inputs = batch['input'].cuda(non_blocking=True)
            targets = batch['target'].cuda(non_blocking=True)

            # additional arguments
            add_params = {}
            # for TAPNet, add attention maps
            if 'TAPNet' in args.model:
                add_params['attmap'] = batch['attmap'].cuda(non_blocking=True)

            # output logits
            outputs = model(inputs, **add_params)
            # loss
            loss = loss_func(outputs, targets)

            output_softmaxs = torch.softmax(outputs, dim=1)
            output_argmaxs = output_softmaxs.argmax(dim=1)
            # output_classes and target_classes: <b, h, w>
            output_classes = output_argmaxs.cpu().numpy()
            target_classes = targets.cpu().numpy()

            # record current batch metrics
            iou_mRecords = MetricRecord()
            dice_mRecords = MetricRecord()

            cm_b = np.zeros((num_classes, num_classes), dtype=np.uint32)

            for output_class, target_class in zip(output_classes, target_classes):
                # calculate metrics for each frame from its confusion matrix
                cm = calculate_confusion_matrix_from_arrays(
                    output_class, target_class, num_classes)
                iou_mRecords.update_record(calculate_iou(cm))
                dice_mRecords.update_record(calculate_dice(cm))
                cm_b += cm

            # accumulate batch metrics to engine state
            engine.state.epoch_metrics['confusion_matrix'] += cm_b
            engine.state.epoch_metrics['iou'].merge(iou_mRecords)
            engine.state.epoch_metrics['dice'].merge(dice_mRecords)

            return_dict = {
                'loss': loss.item(),
                'output': outputs,
                'output_argmax': output_argmaxs,
                'target': targets,
                # for monitoring
                'iou': iou_mRecords,
                'dice': dice_mRecords,
            }

            if 'TAPNet' in args.model:
                # for TAPNet, update attention maps after each iteration
                valid_loader.dataset.update_attmaps(
                    output_softmaxs.cpu().numpy(), batch['abs_idx'].numpy())
                # for TAPNet, return extra internal values
                return_dict['attmap'] = add_params['attmap']
                # TODO: for TAPNet, return internal self-learned attention maps

            return return_dict

    # validator engine
    validator = engine.Engine(valid_step)

    # monitor loss
    valid_ra_loss = imetrics.RunningAverage(
        output_transform=lambda x: x['loss'], alpha=0.98)
    valid_ra_loss.attach(validator, 'valid_ra_loss')

    # monitor validation loss over epoch
    valid_loss = imetrics.Loss(
        loss_func, output_transform=lambda x: (x['output'], x['target']))
    valid_loss.attach(validator, 'valid_loss')

    # monitor <data> mean metrics
    valid_data_miou = imetrics.RunningAverage(
        output_transform=lambda x: x['iou'].data_mean()['mean'], alpha=0.98)
    valid_data_miou.attach(validator, 'mIoU')
    valid_data_mdice = imetrics.RunningAverage(
        output_transform=lambda x: x['dice'].data_mean()['mean'], alpha=0.98)
    valid_data_mdice.attach(validator, 'mDice')

    # show metrics on progress bar (after every iteration)
    valid_pbar = c_handlers.ProgressBar(persist=True, dynamic_ncols=True)
    valid_metric_names = ['valid_ra_loss', 'mIoU', 'mDice']
    valid_pbar.attach(validator, metric_names=valid_metric_names)

    @validator.on(engine.Events.STARTED)
    def validator_start_callback(engine):
        pass

    @validator.on(engine.Events.EPOCH_STARTED)
    def validator_epoch_start_callback(engine):
        # fresh accumulators for every validation run
        engine.state.epoch_metrics = {
            # directly use definition to calculate
            'iou': MetricRecord(),
            'dice': MetricRecord(),
            'confusion_matrix': np.zeros((num_classes, num_classes),
                                         dtype=np.uint32),
        }

    # evaluate after iter finish
    @validator.on(engine.Events.ITERATION_COMPLETED)
    def validator_iter_comp_callback(engine):
        pass

    # evaluate after epoch finish
    @validator.on(engine.Events.EPOCH_COMPLETED)
    def validator_epoch_comp_callback(engine):
        # log monitored epoch metrics
        epoch_metrics = engine.state.epoch_metrics

        # NOTICE: two metrics are available and they differ.
        # 1. mean metrics for all data calculated by confusion matrix.
        #    Compared with using confusion_matrix[1:, 1:] in original code,
        #    we use the full confusion matrix and only present
        #    non-background result.
        confusion_matrix = epoch_metrics['confusion_matrix']  # [1:, 1:]
        ious = calculate_iou(confusion_matrix)
        dices = calculate_dice(confusion_matrix)

        mean_ious = np.mean(list(ious.values()))
        mean_dices = np.mean(list(dices.values()))
        std_ious = np.std(list(ious.values()))
        std_dices = np.std(list(dices.values()))

        logging_logger.info('mean IoU: %.3f, std: %.3f, for each class: %s' %
                            (mean_ious, std_ious, ious))
        logging_logger.info('mean Dice: %.3f, std: %.3f, for each class: %s' %
                            (mean_dices, std_dices, dices))

        # 2. mean metrics for all data calculated by definition
        iou_data_mean = epoch_metrics['iou'].data_mean()
        dice_data_mean = epoch_metrics['dice'].data_mean()

        logging_logger.info('data (%d) mean IoU: %.3f, std: %.3f' %
                            (len(iou_data_mean['items']), iou_data_mean['mean'],
                             iou_data_mean['std']))
        logging_logger.info('data (%d) mean Dice: %.3f, std: %.3f' %
                            (len(dice_data_mean['items']), dice_data_mean['mean'],
                             dice_data_mean['std']))

        # record the per-data metrics in trainer every epoch
        trainer.state.metrics_records[trainer.state.epoch] = {
            'miou': iou_data_mean['mean'],
            'std_miou': iou_data_mean['std'],
            'mdice': dice_data_mean['mean'],
            'std_mdice': dice_data_mean['std'],
        }

    # log internal variables (attention maps, outputs, etc.) on validation
    def tb_log_valid_iter_vars(engine, logger, event_name):
        log_tag = 'valid_iter'
        output = engine.state.output
        batch_size = output['output'].shape[0]
        res_grid = tvutils.make_grid(torch.cat([
            output['output_argmax'].unsqueeze(1),
            output['target'].unsqueeze(1),
        ]), padding=2,
            normalize=False,  # show origin image
            nrow=batch_size).cpu()

        logger.writer.add_image(tag='%s (outputs, targets)' % (log_tag),
                                img_tensor=res_grid)

        if 'TAPNet' in args.model:
            # log attention maps and other internal values
            inter_vals_grid = tvutils.make_grid(torch.cat([
                output['attmap'],
            ]), padding=2, normalize=True, nrow=batch_size).cpu()
            logger.writer.add_image(tag='%s internal vals' % (log_tag),
                                    img_tensor=inter_vals_grid)

    def tb_log_valid_epoch_vars(engine, logger, event_name):
        # log monitored epoch metrics
        epoch_metrics = engine.state.epoch_metrics
        confusion_matrix = epoch_metrics['confusion_matrix']  # [1:, 1:]
        ious = calculate_iou(confusion_matrix)
        dices = calculate_dice(confusion_matrix)

        mean_ious = np.mean(list(ious.values()))
        mean_dices = np.mean(list(dices.values()))
        # NOTE(review): the step is the validator's own epoch counter, not
        # the trainer's -- confirm this is the intended x-axis.
        logger.writer.add_scalar('mIoU', mean_ious, engine.state.epoch)
        # Dice gets its own tag so it does not overwrite the IoU curve.
        logger.writer.add_scalar('mDice', mean_dices, engine.state.epoch)

    if args.tb_log:
        # log internal values
        tb_logger.attach(validator, log_handler=tb_log_valid_iter_vars,
                         event_name=engine.Events.ITERATION_COMPLETED)
        tb_logger.attach(validator, log_handler=tb_log_valid_epoch_vars,
                         event_name=engine.Events.EPOCH_COMPLETED)
        tb_logger.attach(validator,
                         log_handler=OutputHandler('valid_epoch', ['valid_loss']),
                         event_name=engine.Events.EPOCH_COMPLETED)

    # score function for model saving: class-mean IoU from the epoch
    # confusion matrix
    def ckpt_score_function(engine):
        return np.mean(list(calculate_iou(
            engine.state.epoch_metrics['confusion_matrix']).values()))

    ckpt_filename_prefix = 'fold_%d' % fold

    # model saving handler
    model_ckpt_handler = handlers.ModelCheckpoint(
        dirname=args.model_save_dir,
        filename_prefix=ckpt_filename_prefix,
        score_function=ckpt_score_function,
        create_dir=True,
        require_empty=False,
        save_as_state_dict=True,
        atomic=True)

    validator.add_event_handler(event_name=engine.Events.EPOCH_COMPLETED,
                                handler=model_ckpt_handler,
                                to_save={
                                    'model': model,
                                })

    # early stop
    # trainer=trainer, but should be handled by validator
    early_stopping = handlers.EarlyStopping(patience=args.es_patience,
                                            score_function=ckpt_score_function,
                                            trainer=trainer)

    validator.add_event_handler(event_name=engine.Events.EPOCH_COMPLETED,
                                handler=early_stopping)

    # evaluate after epoch finish
    @trainer.on(engine.Events.EPOCH_COMPLETED)
    def trainer_epoch_comp_callback(engine):
        validator.run(valid_loader)

    trainer.run(train_loader, max_epochs=args.max_epochs)

    if args.tb_log:
        # close tb_logger
        tb_logger.close()

    return trainer.state.metrics_records
def main():
    """Command-line entry point: build model, loss and loaders, then train.

    Parses the command-line options, creates the run directory ``--root``,
    instantiates the selected network (wrapped in ``nn.DataParallel`` on the
    GPU when CUDA is available), picks the loss and validation routine for
    the problem type, writes the options to ``params.json`` in the run
    directory and hands everything to ``utils.train``.
    """
    parser = argparse.ArgumentParser()
    arg = parser.add_argument
    arg('--jaccard-weight', default=1, type=float)
    arg('--device-ids', type=str, default='0',
        help='For example 0,1 to run on two GPUs')
    arg('--fold', type=int, help='fold', default=0)
    arg('--root', default='runs/debug', help='checkpoint root')
    arg('--batch-size', type=int, default=1)
    arg('--n-epochs', type=int, default=10)
    arg('--lr', type=float, default=0.0002)
    arg('--workers', type=int, default=10)
    arg('--type', type=str, default='binary',
        choices=['binary', 'parts', 'instruments'])
    # 'UNet16' is listed so that the UNet16 branch below can be selected.
    arg('--model', type=str, default='DLinkNet',
        choices=['UNet', 'UNet11', 'UNet16', 'LinkNet34', 'DLinkNet'])

    args = parser.parse_args()

    root = Path(args.root)
    root.mkdir(exist_ok=True, parents=True)

    # number of output channels per problem type
    if args.type == 'parts':
        num_classes = 4
    elif args.type == 'instruments':
        num_classes = 8
    else:
        num_classes = 1

    if args.model == 'UNet':
        model = UNet(num_classes=num_classes)
    elif args.model == 'UNet11':
        model = UNet11(num_classes=num_classes, pretrained='vgg')
    elif args.model == 'UNet16':
        model = UNet16(num_classes=num_classes, pretrained='vgg')
    elif args.model == 'LinkNet34':
        model = LinkNet34(num_classes=num_classes, pretrained=True)
    elif args.model == 'DLinkNet':
        model = D_LinkNet34(num_classes=num_classes, pretrained=True)
    else:
        # fallback kept for safety; argparse's choices normally prevent it
        model = UNet(num_classes=num_classes, input_channels=3)

    if torch.cuda.is_available():
        if args.device_ids:
            device_ids = list(map(int, args.device_ids.split(',')))
        else:
            # None is passed through to DataParallel unchanged
            device_ids = None
        model = nn.DataParallel(model, device_ids=device_ids).cuda()

    if args.type == 'binary':
        loss = LossBCE_DICE()
    else:
        loss = LossMulti(num_classes=num_classes,
                         jaccard_weight=args.jaccard_weight)

    cudnn.benchmark = True

    def make_loader(file_names, shuffle=False, transform=None,
                    problem_type='binary'):
        """Wrap ``file_names`` in a RoboticsDataset-backed DataLoader."""
        return DataLoader(dataset=RoboticsDataset(file_names,
                                                  transform=transform,
                                                  problem_type=problem_type),
                          shuffle=shuffle,
                          num_workers=args.workers,
                          batch_size=args.batch_size,
                          pin_memory=torch.cuda.is_available())

    train_file_names, val_file_names = get_train_val_files()

    print('num train = {}, num_val = {}'.format(len(train_file_names),
                                                len(val_file_names)))

    train_transform = DualCompose(
        [HorizontalFlip(), VerticalFlip(), ImageOnly(Normalize())])

    val_transform = DualCompose([ImageOnly(Normalize())])

    train_loader = make_loader(train_file_names, shuffle=True,
                               transform=train_transform,
                               problem_type=args.type)
    valid_loader = make_loader(val_file_names, transform=val_transform,
                               problem_type=args.type)

    # keep a record of the options used for this run
    root.joinpath('params.json').write_text(
        json.dumps(vars(args), indent=True, sort_keys=True))

    if args.type == 'binary':
        valid = validation_binary
    else:
        valid = validation_multi

    utils.train(init_optimizer=lambda lr: Adam(model.parameters(), lr=lr),
                args=args,
                model=model,
                criterion=loss,
                train_loader=train_loader,
                valid_loader=valid_loader,
                validation=valid,
                fold=args.fold,
                num_classes=num_classes)
def main():
    """Command-line entry point: train TernausNet34 on the map dataset.

    Parses the command-line options, creates the run directory ``--root``,
    builds the network (wrapped in ``nn.DataParallel`` on the GPU when CUDA
    is available), the loss, the augmentation pipelines and both data
    loaders, writes the options to ``params.json`` in the run directory and
    hands everything to ``utils.train``.
    """
    parser = argparse.ArgumentParser()
    add_option = parser.add_argument
    add_option('--jaccard-weight', default=1, type=float)
    add_option('--device-ids', type=str, default='0',
               help='For example 0,1 to run on two GPUs')
    add_option('--fold', type=int, help='fold', default=0)
    add_option('--root', default='runs/debug', help='checkpoint root')
    add_option('--batch-size', type=int, default=8)
    add_option('--n-epochs', type=int, default=14)
    add_option('--lr', type=float, default=0.000001)
    add_option('--workers', type=int, default=8)
    add_option('--type', type=str, default='binary',
               choices=['binary', 'parts', 'instruments'])
    add_option('--model', type=str, default='TernausNet',
               choices=['UNet', 'UNet11', 'LinkNet34', 'TernausNet'])
    args = parser.parse_args()

    run_dir = Path(args.root)
    run_dir.mkdir(exist_ok=True, parents=True)

    # output channels per problem type; anything else is single-channel
    num_classes = {'parts': 3, 'instruments': 8}.get(args.type, 1)

    # every --model choice maps to the same architecture
    model = TernausNet34(num_classes=num_classes)

    cuda_available = torch.cuda.is_available()
    if cuda_available:
        device_ids = None
        if args.device_ids:
            device_ids = [int(dev) for dev in args.device_ids.split(',')]
        model = nn.DataParallel(model, device_ids=device_ids).cuda()

    is_binary = args.type == 'binary'
    if is_binary:
        loss = LossBinary(jaccard_weight=args.jaccard_weight)
    else:
        loss = LossMulti(num_classes=num_classes,
                         jaccard_weight=args.jaccard_weight)

    cudnn.benchmark = True

    def make_loader(file_names, shuffle=False, transform=None, mode='train',
                    problem_type='binary'):
        """Wrap ``file_names`` in a MapDataset-backed DataLoader."""
        dataset = MapDataset(file_names, transform=transform,
                             problem_type=problem_type, mode=mode)
        return DataLoader(dataset=dataset,
                          shuffle=shuffle,
                          num_workers=args.workers,
                          batch_size=args.batch_size,
                          pin_memory=torch.cuda.is_available())

    # training augmentation
    geometric_choice = OneOrOther(
        OneOf([
            Distort1(distort_limit=0.05, shift_limit=0.05),
            Distort2(num_steps=2, distort_limit=0.05),
        ]),
        ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.10,
                         rotate_limit=45),
        prob=0.5)
    train_transform = DualCompose([
        geometric_choice,
        RandomRotate90(),
        RandomCrop([256, 256]),
        RandomFlip(prob=0.5),
        Transpose(prob=0.5),
        ImageOnly(RandomContrast(limit=0.2, prob=0.5)),
        ImageOnly(RandomFilter(limit=0.5, prob=0.2)),
        ImageOnly(RandomHueSaturationValue(prob=0.2)),
        ImageOnly(RandomBrightness()),
        ImageOnly(Normalize()),
    ])

    # validation only rescales and normalises
    val_transform = DualCompose([
        Rescale([256, 256]),
        ImageOnly(Normalize()),
    ])

    train_loader = make_loader(TRAIN_ANNOTATIONS_PATH, shuffle=True,
                               transform=train_transform,
                               problem_type=args.type)
    valid_loader = make_loader(VAL_ANNOTATIONS_PATH, transform=val_transform,
                               mode='valid', problem_type=args.type)

    # keep a record of the options used for this run
    params_text = json.dumps(vars(args), indent=True, sort_keys=True)
    run_dir.joinpath('params.json').write_text(params_text)

    valid = validation_binary if is_binary else validation_multi

    utils.train(init_optimizer=lambda lr: Adam(model.parameters(), lr=lr),
                args=args,
                model=model,
                criterion=loss,
                train_loader=train_loader,
                valid_loader=valid_loader,
                validation=valid,
                fold=args.fold,
                num_classes=num_classes)
def main():
    """Command-line entry point: configure and launch one training run.

    Parses the CLI options, validates the crop sizes, builds the model
    selected by ``--model`` (wrapped in ``nn.DataParallel`` on the GPU),
    picks the loss and validation routine matching ``--type``, creates the
    train/validation loaders for ``--fold``, writes the parsed options to
    ``<root>/params.json`` and hands everything to ``utils.train``.

    Raises:
        SystemError: if no CUDA device is available.
        SystemExit: with status 1 when ``utils.check_crop_size`` rejects
            the training or validation crop size.
    """
    parser = argparse.ArgumentParser()
    arg = parser.add_argument
    arg('--jaccard-weight', default=0.5, type=float)
    arg('--device-ids', type=str, default='0',
        help='For example 0,1 to run on two GPUs')
    arg('--fold', type=int, help='fold', default=0)
    arg('--root', default='runs/debug', help='checkpoint root')
    arg('--batch-size', type=int, default=1)
    arg('--n-epochs', type=int, default=100)
    arg('--lr', type=float, default=0.0001)
    arg('--workers', type=int, default=12)
    arg('--train_crop_height', type=int, default=1024)
    arg('--train_crop_width', type=int, default=1280)
    arg('--val_crop_height', type=int, default=1024)
    arg('--val_crop_width', type=int, default=1280)
    arg('--type', type=str, default='binary',
        choices=['binary', 'parts', 'instruments'])
    arg('--model', type=str, default='UNet', choices=moddel_list.keys())

    args = parser.parse_args()

    root = Path(args.root)
    root.mkdir(exist_ok=True, parents=True)

    # Invalid crop sizes are a usage error, so exit with a non-zero status
    # to let calling scripts notice the failure.
    if not utils.check_crop_size(args.train_crop_height,
                                 args.train_crop_width):
        print('Input image sizes should be divisible by 32, but train '
              'crop sizes ({train_crop_height} and {train_crop_width}) '
              'are not.'.format(train_crop_height=args.train_crop_height,
                                train_crop_width=args.train_crop_width))
        sys.exit(1)

    if not utils.check_crop_size(args.val_crop_height, args.val_crop_width):
        print('Input image sizes should be divisible by 32, but validation '
              'crop sizes ({val_crop_height} and {val_crop_width}) '
              'are not.'.format(val_crop_height=args.val_crop_height,
                                val_crop_width=args.val_crop_width))
        sys.exit(1)

    # Number of output channels per problem type ('binary' falls through).
    if args.type == 'parts':
        num_classes = 4
    elif args.type == 'instruments':
        num_classes = 8
    else:
        num_classes = 1

    if args.model == 'UNet':
        model = UNet(num_classes=num_classes)
    else:
        model_name = moddel_list[args.model]
        model = model_name(num_classes=num_classes, pretrained=True)

    if torch.cuda.is_available():
        if args.device_ids:
            device_ids = list(map(int, args.device_ids.split(',')))
        else:
            device_ids = None
        model = nn.DataParallel(model, device_ids=device_ids).cuda()
    else:
        raise SystemError('GPU device not found')

    # With device_ids=None DataParallel uses every visible GPU, so count
    # those instead of calling len(None).
    num_devices = len(device_ids) if device_ids else torch.cuda.device_count()

    if args.type == 'binary':
        loss = LossBinary(jaccard_weight=args.jaccard_weight)
    else:
        loss = LossMulti(num_classes=num_classes,
                         jaccard_weight=args.jaccard_weight)

    cudnn.benchmark = True

    def make_loader(file_names, shuffle=False, transform=None,
                    problem_type='binary', batch_size=1):
        """Wrap a ``RoboticsDataset`` over ``file_names`` in a DataLoader."""
        return DataLoader(dataset=RoboticsDataset(file_names,
                                                  transform=transform,
                                                  problem_type=problem_type),
                          shuffle=shuffle,
                          num_workers=args.workers,
                          batch_size=batch_size,
                          pin_memory=torch.cuda.is_available())

    train_file_names, val_file_names = get_split(args.fold)

    print('num train = {}, num_val = {}'.format(len(train_file_names),
                                                len(val_file_names)))

    def train_transform(p=1):
        """Pad to the training crop size, random-crop, flip and normalise."""
        return Compose([
            PadIfNeeded(min_height=args.train_crop_height,
                        min_width=args.train_crop_width, p=1),
            RandomCrop(height=args.train_crop_height,
                       width=args.train_crop_width, p=1),
            VerticalFlip(p=0.5),
            HorizontalFlip(p=0.5),
            Normalize(p=1),
        ], p=p)

    def val_transform(p=1):
        """Pad to the validation crop size, centre-crop and normalise."""
        return Compose([
            PadIfNeeded(min_height=args.val_crop_height,
                        min_width=args.val_crop_width, p=1),
            CenterCrop(height=args.val_crop_height,
                       width=args.val_crop_width, p=1),
            Normalize(p=1),
        ], p=p)

    train_loader = make_loader(train_file_names, shuffle=True,
                               transform=train_transform(p=1),
                               problem_type=args.type,
                               batch_size=args.batch_size)
    # Validation batch size equals the number of GPUs in use.
    valid_loader = make_loader(val_file_names,
                               transform=val_transform(p=1),
                               problem_type=args.type,
                               batch_size=num_devices)

    # Keep a record of the options this run was started with.
    root.joinpath('params.json').write_text(
        json.dumps(vars(args), indent=True, sort_keys=True))

    if args.type == 'binary':
        valid = validation_binary
    else:
        valid = validation_multi

    utils.train(init_optimizer=lambda lr: Adam(model.parameters(), lr=lr),
                args=args,
                model=model,
                criterion=loss,
                train_loader=train_loader,
                valid_loader=valid_loader,
                validation=valid,
                fold=args.fold,
                num_classes=num_classes)
def main():
    """Command-line entry point: configure and launch one training run.

    Parses the CLI options, creates a time-stamped run directory below
    ``--root``, validates the crop sizes, builds the model selected by
    ``--model`` (wrapped in ``nn.DataParallel`` on the GPU), builds the
    loss for ``--type``, creates loaders from ``trainval.txt`` and
    ``test.txt`` inside ``--filepath``, writes the parsed options to
    ``params.json`` in the run directory and calls ``utils.train``.

    Raises:
        SystemError: if no CUDA device is available.
        SystemExit: with status 1 when ``utils.check_crop_size`` rejects
            the training or validation crop size; argparse also exits when
            ``--filepath`` is missing.
    """
    parser = argparse.ArgumentParser()
    arg = parser.add_argument
    arg('--jaccard-weight', default=0.5, type=float)
    arg('--device-ids', type=str, default='0',
        help='For example 0,1 to run on two GPUs')
    # Required: the split files are read from this folder. Without it the
    # run used to fail much later inside os.path.join(None, ...).
    arg('--filepath', type=str, required=True,
        help='folder with images and annotation masks')
    arg('--root', default='runs/debug', help='checkpoint root')
    arg('--batch-size', type=int, default=32)
    arg('--n-epochs', type=int, default=100)
    arg('--lr', type=float, default=0.0001)
    arg('--workers', type=int, default=12)
    arg('--train_crop_height', type=int, default=416)
    arg('--train_crop_width', type=int, default=416)
    arg('--val_crop_height', type=int, default=416)
    arg('--val_crop_width', type=int, default=416)
    arg('--type', type=str, default='binary', choices=['binary', 'multi'])
    arg('--model', type=str, default='UNet', choices=model_list.keys())
    arg('--datatype', type=str, default='buildings',
        choices=['buildings', 'roads', 'combined'])
    arg('--pretrained', action='store_true',
        help='use pretrained network for initialisation')
    arg('--num_classes', type=int, default=1)

    args = parser.parse_args()

    # Each run gets its own time-stamped sub-directory of --root.
    timestr = time.strftime("%Y%m%d-%H%M%S")
    root = Path(args.root) / timestr
    root.mkdir(exist_ok=True, parents=True)

    dataset_type = args.datatype
    print('log', root, dataset_type)

    # Invalid crop sizes are a usage error, so exit with a non-zero status
    # to let calling scripts notice the failure.
    if not utils.check_crop_size(args.train_crop_height,
                                 args.train_crop_width):
        print('Input image sizes should be divisible by 32, but train '
              'crop sizes ({train_crop_height} and {train_crop_width}) '
              'are not.'.format(train_crop_height=args.train_crop_height,
                                train_crop_width=args.train_crop_width))
        sys.exit(1)

    if not utils.check_crop_size(args.val_crop_height, args.val_crop_width):
        print('Input image sizes should be divisible by 32, but validation '
              'crop sizes ({val_crop_height} and {val_crop_width}) '
              'are not.'.format(val_crop_height=args.val_crop_height,
                                val_crop_width=args.val_crop_width))
        sys.exit(1)

    num_classes = args.num_classes

    if args.model == 'UNet':
        model = UNet(num_classes=num_classes)
    else:
        model_name = model_list[args.model]
        model = model_name(num_classes=num_classes,
                           pretrained=args.pretrained)

    if torch.cuda.is_available():
        if args.device_ids:
            device_ids = list(map(int, args.device_ids.split(',')))
        else:
            device_ids = None
        model = nn.DataParallel(model, device_ids=device_ids).cuda()
    else:
        raise SystemError('GPU device not found')

    # With device_ids=None DataParallel uses every visible GPU, so count
    # those instead of calling len(None).
    num_devices = len(device_ids) if device_ids else torch.cuda.device_count()

    if args.type == 'binary':
        loss = LossBinary(jaccard_weight=args.jaccard_weight)
    else:
        # Hard-coded per-class figures, presumably label/pixel counts of
        # the training data -- TODO confirm their origin.
        # NOTE(review): the three-entry table is used for every
        # num_classes other than 2; confirm that is intended.
        if args.num_classes == 2:
            labelweights = [89371542, 7083233]
        else:
            labelweights = [89371542, 29703049, 7083233]
        # weight_c = total / (num_classes * count_c)
        labelweights = np.sum(labelweights) / \
            (np.multiply(num_classes, labelweights))
        loss = LossMulti(num_classes=num_classes,
                         jaccard_weight=args.jaccard_weight,
                         class_weights=labelweights)

    cudnn.benchmark = True

    train_filename = os.path.join(args.filepath, 'trainval.txt')
    val_filename = os.path.join(args.filepath, 'test.txt')

    def train_transform(p=1):
        """Pad to the training crop size, random-crop, flip and normalise."""
        return Compose([
            PadIfNeeded(min_height=args.train_crop_height,
                        min_width=args.train_crop_width, p=1),
            RandomCrop(height=args.train_crop_height,
                       width=args.train_crop_width, p=1),
            VerticalFlip(p=0.5),
            HorizontalFlip(p=0.5),
            Normalize(p=1),
        ], p=p)

    def val_transform(p=1):
        """Pad to the validation crop size, centre-crop and normalise."""
        return Compose([
            PadIfNeeded(min_height=args.val_crop_height,
                        min_width=args.val_crop_width, p=1),
            CenterCrop(height=args.val_crop_height,
                       width=args.val_crop_width, p=1),
            Normalize(p=1),
        ], p=p)

    train_loader = make_loader(train_filename, shuffle=True,
                               transform=train_transform(p=1),
                               problem_type=args.type,
                               batch_size=args.batch_size,
                               datatype=args.datatype)
    # Validation batch size equals the number of GPUs in use.
    valid_loader = make_loader(val_filename,
                               transform=val_transform(p=1),
                               problem_type=args.type,
                               batch_size=num_devices,
                               datatype=args.datatype)

    root.joinpath('params.json').write_text(
        json.dumps(vars(args), indent=True, sort_keys=True))

    # Must stay after the JSON dump above: json.dumps cannot serialise a
    # Path object.
    args.root = root

    if args.type == 'binary':
        valid = validation_binary
    else:
        valid = validation_multi

    utils.train(
        init_optimizer=lambda lr: Adam(model.parameters(), lr=lr),
        args=args,
        model=model,
        criterion=loss,
        train_loader=train_loader,
        valid_loader=valid_loader,
        validation=valid,
        num_classes=num_classes,
        model_name=args.model,
        dataset_type=dataset_type
    )
save = lambda ep: torch.save({ 'model': model.state_dict(), 'epoch': ep, 'step': step, }, str(model_path)) report_each = 10 valid_each = 4 log = root.joinpath('train_{fold}.log'.format(fold=fold)).open('at', encoding='utf8') valid_losses = [] if(add_log == False): criterion = MultiDiceLoss(num_classes=11) else: criterion = LossMulti(num_classes=11, jaccard_weight=0.5) class_color_table = read_json(json_file_name) first_time = True for epoch in range(epoch, n_epochs + 1): model.train() random.seed() tq = tqdm.tqdm(total=(len(train_loader) * batch_size)) tq.set_description('Epoch {}, lr {}'.format(epoch, lr)) losses = [] try: mean_loss = 0 for i, (inputs, targets) in enumerate(train_loader): # images = inputs.data.cpu().numpy() # targets = targets.data.cpu().numpy() # print(targets.shape) # images = np.moveaxis(images, [0, 1, 2, 3], [0, 3, 1, 2])
def main():
    """Command-line entry point: configure and launch one training run.

    Parses the CLI options, creates the checkpoint directory, builds the
    model selected by ``--model`` (wrapped in ``nn.DataParallel`` when CUDA
    is available), picks the loss and validation routine matching
    ``--type``, creates the train/validation loaders from ``get_split()``,
    writes the parsed options to ``<root>/params.json`` and calls
    ``utils.train``.
    """
    parser = argparse.ArgumentParser()
    arg = parser.add_argument
    arg('--jaccard-weight', default=0.5, type=float)
    arg('--device-ids', type=str, default='0',
        help='For example 0,1 to run on two GPUs')
    arg('--fold', type=int, help='fold', default=0)
    arg('--root', default='runs/debug', help='checkpoint root')
    arg('--batch-size', type=int, default=1)
    arg('--n-epochs', type=int, default=100)
    arg('--lr', type=float, default=0.0001)
    arg('--workers', type=int, default=12)
    arg('--type', type=str, default='binary',
        choices=['binary', 'parts', 'instruments'])
    # 'UNet16' is listed so that its branch below can actually be selected.
    arg('--model', type=str, default='UNet',
        choices=['UNet', 'UNet11', 'UNet16', 'LinkNet34', 'AlbuNet'])

    args = parser.parse_args()

    root = Path(args.root)
    root.mkdir(exist_ok=True, parents=True)

    # Number of output channels per problem type ('binary' falls through).
    if args.type == 'parts':
        num_classes = 4
    elif args.type == 'instruments':
        num_classes = 8
    else:
        num_classes = 1

    if args.model == 'UNet11':
        model = UNet11(num_classes=num_classes, pretrained=True)
    elif args.model == 'UNet16':
        model = UNet16(num_classes=num_classes, pretrained=True)
    elif args.model == 'LinkNet34':
        model = LinkNet34(num_classes=num_classes, pretrained=True)
    elif args.model == 'AlbuNet':
        model = AlbuNet(num_classes=num_classes, pretrained=True)
    else:
        # 'UNet' is the only choice left once argparse has validated --model.
        model = UNet(num_classes=num_classes)

    # Defined up front so the validation batch size below can be computed
    # even when CUDA is unavailable or --device-ids is empty.
    device_ids = None
    if torch.cuda.is_available():
        if args.device_ids:
            device_ids = list(map(int, args.device_ids.split(',')))
        model = nn.DataParallel(model, device_ids=device_ids).cuda()

    # One validation sample per device in use. With device_ids=None
    # DataParallel uses every visible GPU; on a CPU-only host use 1.
    if device_ids:
        valid_batch_size = len(device_ids)
    else:
        valid_batch_size = max(torch.cuda.device_count(), 1)

    if args.type == 'binary':
        loss = LossBinary(jaccard_weight=args.jaccard_weight)
    else:
        loss = LossMulti(num_classes=num_classes,
                         jaccard_weight=args.jaccard_weight)

    cudnn.benchmark = True

    def make_loader(file_names, shuffle=False, transform=None,
                    problem_type='binary', batch_size=1):
        """Wrap a ``CustomDataset`` over ``file_names`` in a DataLoader.

        ``problem_type`` is accepted for call-site compatibility but is not
        forwarded to ``CustomDataset``.
        """
        return DataLoader(dataset=CustomDataset(file_names,
                                                transform=transform),
                          shuffle=shuffle,
                          num_workers=args.workers,
                          batch_size=batch_size,
                          pin_memory=torch.cuda.is_available())

    train_file_names, val_file_names = get_split()

    print('num train = {}, num_val = {}'.format(len(train_file_names),
                                                len(val_file_names)))

    def train_transform(p=1):
        """Random crop plus photometric/noise/flip augmentation."""
        return Compose([
            RandomCrop(SIZE),
            RandomBrightness(0.2),
            OneOf([
                IAAAdditiveGaussianNoise(),
                GaussNoise(),
            ], p=0.15),
            HueSaturationValue(p=0.15),
            HorizontalFlip(p=0.5),
            Normalize(p=1),
        ], p=p)

    def val_transform(p=1):
        """Random crop and normalise.

        NOTE(review): validation uses a *random* crop, so validation
        inputs differ between epochs -- confirm this is intended.
        """
        return Compose([
            RandomCrop(SIZE),
            Normalize(p=1),
        ], p=p)

    train_loader = make_loader(train_file_names, shuffle=True,
                               transform=train_transform(p=1),
                               problem_type=args.type,
                               batch_size=args.batch_size)
    valid_loader = make_loader(val_file_names,
                               transform=val_transform(p=1),
                               problem_type=args.type,
                               batch_size=valid_batch_size)

    # Keep a record of the options this run was started with.
    root.joinpath('params.json').write_text(
        json.dumps(vars(args), indent=True, sort_keys=True))

    if args.type == 'binary':
        valid = validation_binary
    else:
        valid = validation_multi

    utils.train(init_optimizer=lambda lr: Adam(model.parameters(), lr=lr),
                args=args,
                model=model,
                criterion=loss,
                train_loader=train_loader,
                valid_loader=valid_loader,
                validation=valid,
                fold=args.fold,
                num_classes=num_classes)