def main():
    ##########################################################################
    if args.server == 'server_A':
        work_dir = os.path.join('/data1/JM/lung-seg-back-up', args.exp)
        print(work_dir)
    elif args.server == 'server_B':
        work_dir = os.path.join('/data1/workspace/JM_gen/lung-seg-back-up',
                                args.exp)
        print(work_dir)
    elif args.server == 'server_D':
        work_dir = os.path.join(
            '/daintlab/home/woans0104/workspace/'
            'lung-seg-back-up', args.exp)
        print(work_dir)
    ##########################################################################

    if not os.path.exists(work_dir):
        os.makedirs(work_dir)

    # copy this file to work dir to keep training configuration
    shutil.copy(__file__, os.path.join(work_dir, 'main.py'))
    with open(os.path.join(work_dir, 'args.pkl'), 'wb') as f:
        pickle.dump(args, f)

    source_dataset, target_dataset1, target_dataset2 \
        = loader.dataset_condition(args.source_dataset)

    # 1.load_dataset
    # source-domain loaders for training/validation; the two target-domain
    # loaders are only used at test time (see test_data_li below)
    train_loader_source, test_loader_source \
        = loader.get_loader(server=args.server,
                            dataset=source_dataset,
                            train_size=args.train_size,
                            aug_mode=args.aug_mode,
                            aug_range=args.aug_range,
                            batch_size=args.batch_size,
                            work_dir=work_dir)

    train_loader_target1, _ = loader.get_loader(server=args.server,
                                                dataset=target_dataset1,
                                                train_size=1,
                                                aug_mode=False,
                                                aug_range=args.aug_range,
                                                batch_size=1,
                                                work_dir=work_dir)

    train_loader_target2, _ = loader.get_loader(server=args.server,
                                                dataset=target_dataset2,
                                                train_size=1,
                                                aug_mode=False,
                                                aug_range=args.aug_range,
                                                batch_size=1,
                                                work_dir=work_dir)

    test_data_li = [
        test_loader_source, train_loader_target1, train_loader_target2
    ]

    trn_logger = Logger(os.path.join(work_dir, 'train.log'))
    trn_raw_logger = Logger(os.path.join(work_dir, 'train_raw.log'))
    val_logger = Logger(os.path.join(work_dir, 'validation.log'))
    trn_logger_ae = Logger(os.path.join(work_dir, 'ae_train.log'))
    val_logger_ae = Logger(os.path.join(work_dir, 'ae_validation.log'))

    # 2.model_select
    model_seg = Unet2D(in_shape=(1, 256, 256))
    model_seg = model_seg.cuda()
    model_ae = ae_lung(in_shape=(1, 256, 256))
    model_ae = model_ae.cuda()
    cudnn.benchmark = True

    # 3.gpu select
    model_seg = nn.DataParallel(model_seg)
    model_ae = nn.DataParallel(model_ae)

    # 4.optim
    if args.optim == 'adam':
        optimizer_seg = torch.optim.Adam(model_seg.parameters(),
                                         betas=(args.adam_beta1, 0.999),
                                         eps=args.eps,
                                         lr=args.lr,
                                         weight_decay=args.weight_decay)
        optimizer_ae = torch.optim.Adam(model_ae.parameters(),
                                        betas=(args.adam_beta1, 0.999),
                                        eps=args.eps,
                                        lr=args.lr,
                                        weight_decay=args.weight_decay)
    elif args.optim == 'adamp':
        optimizer_seg = AdamP(model_seg.parameters(),
                              betas=(args.adam_beta1, 0.999),
                              eps=args.eps,
                              lr=args.lr,
                              weight_decay=args.weight_decay)
        optimizer_ae = AdamP(model_ae.parameters(),
                             betas=(args.adam_beta1, 0.999),
                             eps=args.eps,
                             lr=args.lr,
                             weight_decay=args.weight_decay)
    elif args.optim == 'sgd':
        optimizer_seg = torch.optim.SGD(model_seg.parameters(),
                                        lr=args.lr,
                                        weight_decay=args.weight_decay)
        optimizer_ae = torch.optim.SGD(model_ae.parameters(),
                                       lr=args.lr,
                                       weight_decay=args.weight_decay)

    # lr decay
    lr_schedule = args.lr_schedule
    lr_scheduler_seg = optim.lr_scheduler.MultiStepLR(
        optimizer_seg, milestones=lr_schedule[:-1], gamma=0.1)
    lr_scheduler_ae = optim.lr_scheduler.MultiStepLR(
        optimizer_ae, milestones=lr_schedule[:-1], gamma=0.1)

    # 5.loss
    criterion_seg = select_loss(args.seg_loss_function)
    criterion_ae = select_loss(args.ae_loss_function)
    criterion_embedding = select_loss(args.embedding_loss_function)

    ##########################################################################
    # train

    best_iou = 0
    try:
        if args.train_mode:
            for epoch in range(lr_schedule[-1]):

                train(model_seg=model_seg,
                      model_ae=model_ae,
                      train_loader=train_loader_source,
                      epoch=epoch,
                      criterion_seg=criterion_seg,
                      criterion_ae=criterion_ae,
                      criterion_embedding=criterion_embedding,
                      optimizer_seg=optimizer_seg,
                      optimizer_ae=optimizer_ae,
                      logger=trn_logger,
                      sublogger=trn_raw_logger,
                      logger_ae=trn_logger_ae)

                iou = validate(model_seg=model_seg,
                               model_ae=model_ae,
                               val_loader=test_loader_source,
                               epoch=epoch,
                               criterion_seg=criterion_seg,
                               criterion_ae=criterion_ae,
                               logger=val_logger,
                               logger_ae=val_logger_ae)
                print('validation result ************************************')

                lr_scheduler_seg.step()
                lr_scheduler_ae.step()

                if args.val_size == 0:
                    is_best = 1
                else:
                    is_best = iou > best_iou
                best_iou = max(iou, best_iou)
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'state_dict': model_seg.state_dict(),
                        # bug fix: save the optimizer state, not the loss module
                        'optimizer': optimizer_seg.state_dict()
                    },
                    is_best,
                    work_dir,
                    filename='checkpoint.pth')

            print("train end")
    except RuntimeError as e:
        print(
            '#jm_private',
            '----------------------------------- error train : '
            'send to message JM '
            '& Please send a kakao talk -------------------------- '
            '\n error message : {}'.format(e))
        import ipdb
        ipdb.set_trace()

    draw_curve(work_dir, trn_logger, val_logger)
    draw_curve(work_dir, trn_logger_ae, val_logger_ae, labelname='ae')

    # here is load model for last pth
    check_best_pth(work_dir)

    # validation
    if args.test_mode:
        print('Test mode ...')
        main_test(model=model_seg, test_loader=test_data_li, args=args)
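# `save_checkpoint` is imported from the project's utility module and is not
# part of this back-up. The guarded fallback below is only a sketch of the
# usual pattern (always write the latest state, copy it to `model_best.pth`
# when `is_best` is set), provided for reference; the project's own
# implementation may differ. It reuses the os/shutil/torch imports that the
# main() functions above already require.
if 'save_checkpoint' not in globals():

    def save_checkpoint(state, is_best, work_dir, filename='checkpoint.pth'):
        checkpoint_path = os.path.join(work_dir, filename)
        # latest epoch, overwritten every time this is called
        torch.save(state, checkpoint_path)
        if is_best:
            # keep a separate copy of the best-scoring checkpoint
            shutil.copyfile(checkpoint_path,
                            os.path.join(work_dir, 'model_best.pth'))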
def main():
    print(args.work_dir, args.exp)
    work_dir = os.path.join(args.work_dir, args.exp)
    if not os.path.exists(work_dir):
        os.makedirs(work_dir)

    # copy this file to work dir to keep training configuration
    shutil.copy(__file__, os.path.join(work_dir, 'main.py'))
    with open(os.path.join(work_dir, 'args.pkl'), 'wb') as f:
        pickle.dump(args, f)

    # transform
    transform1 = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize([0.5], [0.5])])

    # 1.train_dataset
    train_path, test_path = loader.make_dataset(args.train_site,
                                                train_size=args.train_size,
                                                mode='train')
    np.save(
        os.path.join(work_dir, '{}_test_path.npy'.format(args.train_site)),
        test_path)

    train_image_path = train_path[0]
    train_label_path = train_path[1]
    test_image_path = test_path[0]
    test_label_path = test_path[1]

    train_dataset = loader.CustomDataset(train_image_path,
                                         train_label_path,
                                         args.train_site,
                                         args.input_size,
                                         transform1,
                                         arg_mode=args.arg_mode,
                                         arg_thres=args.arg_thres)
    train_loader = data.DataLoader(train_dataset,
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=4)

    val_dataset = loader.CustomDataset(test_image_path,
                                       test_label_path,
                                       args.train_site,
                                       args.input_size,
                                       transform1,
                                       arg_mode=False)
    val_loader = data.DataLoader(val_dataset,
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=4)

    Train_test_dataset = loader.CustomDataset(test_image_path,
                                              test_label_path,
                                              args.train_site,
                                              args.input_size, transform1)
    Train_test_loader = data.DataLoader(Train_test_dataset,
                                        batch_size=1,
                                        shuffle=True,
                                        num_workers=4)

    # loaders handed to main_test(); only the held-out split of the training
    # site is available in this script
    test_data_list = [Train_test_loader]

    trn_logger = Logger(os.path.join(work_dir, 'train.log'))
    trn_raw_logger = Logger(os.path.join(work_dir, 'train_raw.log'))
    val_logger = Logger(os.path.join(work_dir, 'validation.log'))

    # 3.model_select
    my_net, model_name = model_select(args.arch, args.input_size)

    # 4.gpu select
    my_net = nn.DataParallel(my_net).cuda()
    cudnn.benchmark = True

    # 5.optim
    if args.optim == 'adam':
        gen_optimizer = torch.optim.Adam(my_net.parameters(),
                                         lr=args.initial_lr,
                                         eps=args.eps)
    elif args.optim == 'sgd':
        gen_optimizer = torch.optim.SGD(my_net.parameters(),
                                        lr=args.initial_lr,
                                        momentum=0.9,
                                        weight_decay=args.weight_decay)

    # lr decay
    lr_schedule = args.lr_schedule
    lr_scheduler = optim.lr_scheduler.MultiStepLR(gen_optimizer,
                                                  milestones=lr_schedule[:-1],
                                                  gamma=args.gamma)

    # 6.loss
    if args.loss_function == 'bce':
        criterion = nn.BCEWithLogitsLoss(
            pos_weight=torch.Tensor([args.bce_weight])).cuda()
    elif args.loss_function == 'mse':
        criterion = nn.MSELoss().cuda()

    ##########################################################################
    # train

    send_slack_message(args.token, '#jm_private',
                       '{} : starting_training'.format(args.exp))
    best_iou = 0
    try:
        if args.train_mode:
            for epoch in range(lr_schedule[-1]):

                train(my_net, train_loader, gen_optimizer, epoch, criterion,
                      trn_logger, trn_raw_logger)

                iou = validate(val_loader,
                               my_net,
                               criterion,
                               epoch,
                               val_logger,
                               save_fig=False,
                               work_dir_name='jsrt_visualize_per_epoch')
                print('validation_iou *****************************************')

                lr_scheduler.step()

                if args.val_size == 0:
                    is_best = 1
                else:
                    is_best = iou > best_iou
                best_iou = max(iou, best_iou)
                checkpoint_filename = 'model_checkpoint_{:0>3}.pth'.format(
                    epoch + 1)
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'state_dict': my_net.state_dict(),
                        'optimizer': gen_optimizer.state_dict()
                    },
                    is_best,
                    work_dir,
                    filename='checkpoint.pth')

            print("train end")
    except RuntimeError as e:
        send_slack_message(
            args.token, '#jm_private',
            '----------------------------------- error train : '
            'send to message JM & Please send a kakao talk '
            '----------------------------------------- '
            '\n error message : {}'.format(e))
        import ipdb
        ipdb.set_trace()

    draw_curve(work_dir, trn_logger, val_logger)
    send_slack_message(args.token, '#jm_private',
                       '{} : end_training'.format(args.exp))

    if args.test_mode:
        print('Test mode ...')
        main_test(model=my_net, test_loader=test_data_list, args=args)
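# `send_slack_message` is an external helper that is not included in this
# back-up. The guarded sketch below shows one plausible implementation using
# Slack's chat.postMessage Web API via `requests`; the (token, channel, text)
# argument order matches how the helper is called in the script above, but the
# real helper may use a different client or an incoming webhook instead.
if 'send_slack_message' not in globals():
    import requests

    def send_slack_message(token, channel, text):
        # post a plain-text message; failures are only printed so a Slack
        # outage never interrupts training
        try:
            requests.post('https://slack.com/api/chat.postMessage',
                          headers={'Authorization': 'Bearer {}'.format(token)},
                          json={'channel': channel, 'text': text},
                          timeout=5)
        except requests.RequestException as e:
            print('slack notification failed: {}'.format(e))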
def main():
    work_dir = os.path.join(args.work_dir, args.exp)
    if not os.path.exists(work_dir):
        os.makedirs(work_dir)

    # copy this file to work dir to keep training configuration
    shutil.copy(__file__, os.path.join(work_dir, 'main.py'))
    with open(os.path.join(work_dir, 'args.pkl'), 'wb') as f:
        pickle.dump(args, f)

    # train
    trn_image_root = os.path.join(args.trn_root, 'images')
    exam_ids = os.listdir(trn_image_root)
    random.shuffle(exam_ids)
    train_exam_ids = exam_ids
    #train_exam_ids = exam_ids[:int(len(exam_ids)*0.8)]
    #val_exam_ids = exam_ids[int(len(exam_ids) * 0.8):]

    # train_dataset
    trn_dataset = DatasetTrain(args.trn_root,
                               train_exam_ids,
                               options=args,
                               input_stats=[0.5, 0.5])
    trn_loader = torch.utils.data.DataLoader(trn_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=True,
                                             num_workers=args.num_workers)

    # save input stats for later use
    np.save(os.path.join(work_dir, 'input_stats.npy'),
            trn_dataset.input_stats)

    # val
    val_image_root = os.path.join(args.val_root, 'images')
    val_exam = os.listdir(val_image_root)
    random.shuffle(val_exam)
    val_exam_ids = val_exam

    # val_dataset
    val_dataset = DatasetVal(args.val_root,
                             val_exam_ids,
                             options=args,
                             input_stats=trn_dataset.input_stats)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=True,
                                             num_workers=args.num_workers)

    # make logger
    trn_logger = Logger(os.path.join(work_dir, 'train.log'))
    trn_raw_logger = Logger(os.path.join(work_dir, 'train_raw.log'))
    val_logger = Logger(os.path.join(work_dir, 'validation.log'))

    # model_select
    if args.model == 'unet':
        net = UNet3D(1,
                     1,
                     f_maps=args.f_maps,
                     depth_stride=args.depth_stride,
                     conv_layer_order=args.conv_layer_order,
                     num_groups=args.num_groups)
    else:
        raise ValueError('Not supported network.')

    # loss_select
    if args.loss_function == 'bce':
        criterion = nn.BCEWithLogitsLoss(
            pos_weight=torch.Tensor([args.bce_weight])).cuda()
    elif args.loss_function == 'dice':
        criterion = DiceLoss().cuda()
    elif args.loss_function == 'weight_bce':
        criterion = nn.BCEWithLogitsLoss(
            pos_weight=torch.FloatTensor([5])).cuda()
    else:
        raise ValueError('{} loss is not supported yet.'.format(
            args.loss_function))

    # optim_select
    if args.optim == 'sgd':
        optimizer = optim.SGD(net.parameters(),
                              lr=args.lr,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay,
                              nesterov=False)
    elif args.optim == 'adam':
        optimizer = optim.Adam(net.parameters(),
                               lr=args.lr,
                               weight_decay=args.weight_decay)
    else:
        raise ValueError('{} optim is not supported yet.'.format(args.optim))

    net = nn.DataParallel(net).cuda()
    cudnn.benchmark = True

    lr_schedule = args.lr_schedule
    lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                                  milestones=lr_schedule[:-1],
                                                  gamma=0.1)

    best_iou = 0
    for epoch in range(lr_schedule[-1]):

        train(trn_loader, net, criterion, optimizer, epoch, trn_logger,
              trn_raw_logger)
        iou = validate(val_loader, net, criterion, epoch, val_logger)

        lr_scheduler.step()

        # save model parameter
        is_best = iou > best_iou
        best_iou = max(iou, best_iou)
        checkpoint_filename = 'model_checkpoint_{:0>3}.pth'.format(epoch + 1)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': net.state_dict(),
                'optimizer': optimizer.state_dict()
            }, is_best, work_dir, checkpoint_filename)

    # visualize curve
    draw_curve(work_dir, trn_logger, val_logger)

    if args.inplace_test:
        # calc overall performance and save figures
        print('Test mode ...')
        main_test(model=net, args=args)
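# `DiceLoss` is imported from the project's loss module, which is not part of
# this back-up. The guarded class below is a generic soft-Dice sketch
# (sigmoid on the logits, then 1 - Dice over the flattened volume) given only
# as an assumption of what it computes; the project's version may differ in
# smoothing or reduction. It relies on the torch/nn imports already used by
# the surrounding scripts.
if 'DiceLoss' not in globals():

    class DiceLoss(nn.Module):
        def __init__(self, smooth=1.0):
            super(DiceLoss, self).__init__()
            self.smooth = smooth

        def forward(self, logits, target):
            # flatten each sample and compute the soft Dice coefficient
            prob = torch.sigmoid(logits).view(logits.size(0), -1)
            target = target.view(target.size(0), -1)
            intersection = (prob * target).sum(dim=1)
            dice = (2. * intersection + self.smooth) / (
                prob.sum(dim=1) + target.sum(dim=1) + self.smooth)
            return 1. - dice.mean()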
def main():
    if args.server == 'server_A':
        work_dir = os.path.join('/data1/JM/lung_segmentation', args.exp)
        print(work_dir)
    elif args.server == 'server_B':
        work_dir = os.path.join('/data1/workspace/JM_gen/lung_seg', args.exp)
        print(work_dir)

    if not os.path.exists(work_dir):
        os.makedirs(work_dir)

    # copy this file to work dir to keep training configuration
    shutil.copy(__file__, os.path.join(work_dir, 'main.py'))
    with open(os.path.join(work_dir, 'args.pkl'), 'wb') as f:
        pickle.dump(args, f)

    # transform
    transform1 = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize([0.5], [0.5])])

    # 1.train_dataset
    if args.val_size == 0:
        train_path, test_path = loader.make_dataset(
            args.server,
            args.train_dataset + '_dataset',
            train_size=args.train_size)

    np.save(
        os.path.join(work_dir,
                     '{}_test_path.npy'.format(args.train_dataset)),
        test_path)

    train_image_path = train_path[0]
    train_label_path = train_path[1]
    test_image_path = test_path[0]
    test_label_path = test_path[1]

    train_dataset = loader.CustomDataset(train_image_path,
                                         train_label_path,
                                         transform1,
                                         arg_mode=args.arg_mode,
                                         arg_thres=args.arg_thres,
                                         arg_range=args.arg_range,
                                         dataset=args.train_dataset)
    train_loader = data.DataLoader(train_dataset,
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=4)

    # Organize images and labels differently.
    train_dataset_random = loader.CustomDataset(train_image_path,
                                                train_label_path,
                                                transform1,
                                                arg_mode=args.arg_mode,
                                                arg_thres=args.arg_thres,
                                                arg_range=args.arg_range,
                                                dataset=args.train_dataset)
    train_loader_random = data.DataLoader(train_dataset_random,
                                          batch_size=args.batch_size,
                                          shuffle=True,
                                          num_workers=4)

    val_dataset = loader.CustomDataset(test_image_path,
                                       test_label_path,
                                       transform1,
                                       arg_mode=False,
                                       dataset=args.train_dataset)
    val_loader = data.DataLoader(val_dataset,
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=4)

    # 'JSRT' test_dataset
    Train_test_dataset = loader.CustomDataset(test_image_path,
                                              test_label_path,
                                              transform1,
                                              dataset=args.train_dataset)
    Train_test_loader = data.DataLoader(Train_test_dataset,
                                        batch_size=1,
                                        shuffle=True,
                                        num_workers=4)

    # 2.test_dataset_path

    # 'MC' test_dataset
    test_data1_path, _ = loader.make_dataset(args.server,
                                             args.test_dataset1 + '_dataset',
                                             train_size=1)
    test_data1_dataset = loader.CustomDataset(test_data1_path[0],
                                              test_data1_path[1],
                                              transform1,
                                              dataset=args.test_dataset1)
    test_data1_loader = data.DataLoader(test_data1_dataset,
                                        batch_size=1,
                                        shuffle=True,
                                        num_workers=4)

    # 'sh' test_dataset
    test_data2_path, _ = loader.make_dataset(args.server,
                                             args.test_dataset2 + '_dataset',
                                             train_size=1)
    test_data2_dataset = loader.CustomDataset(test_data2_path[0],
                                              test_data2_path[1],
                                              transform1,
                                              dataset=args.test_dataset2)
    test_data2_loader = data.DataLoader(test_data2_dataset,
                                        batch_size=1,
                                        shuffle=True,
                                        num_workers=0)

    test_data_list = [
        Train_test_loader, test_data1_loader, test_data2_loader
    ]

    # np.save(os.path.join(work_dir, 'input_stats.npy'),
    #         train_dataset.input_stats)

    trn_logger = Logger(os.path.join(work_dir, 'train.log'))
    trn_raw_logger = Logger(os.path.join(work_dir, 'train_raw.log'))
    val_logger = Logger(os.path.join(work_dir, 'validation.log'))

    # 3.model_select
    model_seg, model_name = model_select(args.arch_seg)
    model_ae, _ = model_select(args.arch_ae)

    # 4.gpu select
    model_seg = nn.DataParallel(model_seg).cuda()
    model_ae = nn.DataParallel(model_ae).cuda()
    #model_seg = model_seg.cuda()
    #model_ae = model_ae.cuda()
    cudnn.benchmark = True

    # 5.optim
    if args.optim == 'adam':
        optimizer_seg = torch.optim.Adam(model_seg.parameters(),
                                         lr=args.initial_lr)
        optimizer_ae = torch.optim.Adam(model_ae.parameters(),
                                        lr=args.initial_lr)
    elif args.optim == 'sgd':
        optimizer_seg = torch.optim.SGD(model_seg.parameters(),
                                        lr=args.initial_lr,
                                        weight_decay=args.weight_decay)
        optimizer_ae = torch.optim.SGD(model_ae.parameters(),
                                       lr=args.initial_lr,
                                       weight_decay=args.weight_decay)

    # if args.clip_grad:
    #     import torch.nn.utils as torch_utils
    #     max_grad_norm = 1.
    #     torch_utils.clip_grad_norm_(model_seg.parameters(), max_grad_norm)
    #     torch_utils.clip_grad_norm_(model_ae.parameters(), max_grad_norm)

    # lr decay
    lr_schedule = args.lr_schedule
    lr_scheduler_seg = optim.lr_scheduler.MultiStepLR(
        optimizer_seg, milestones=lr_schedule[:-1], gamma=args.gamma)
    lr_scheduler_ae = optim.lr_scheduler.MultiStepLR(
        optimizer_ae, milestones=lr_schedule[:-1], gamma=args.gamma)

    # 6.loss
    criterion_seg = loss_function_select(args.seg_loss_function)
    criterion_ae = loss_function_select(args.ae_loss_function)
    criterion_embedding = loss_function_select(args.embedding_loss_function)

    ##########################################################################
    # train

    send_slack_message('#jm_private',
                       '{} : starting_training'.format(args.exp))
    best_iou = 0
    try:
        if args.train_mode:
            for epoch in range(lr_schedule[-1]):

                train(model_seg=model_seg,
                      model_ae=model_ae,
                      train_loader=train_loader,
                      train_loder_random=train_loader_random,
                      optimizer_seg=optimizer_seg,
                      optimizer_ae=optimizer_ae,
                      criterion_seg=criterion_seg,
                      criterion_ae=criterion_ae,
                      criterion_embedding=criterion_embedding,
                      epoch=epoch,
                      logger=trn_logger,
                      sublogger=trn_raw_logger)

                iou = validate(model=model_seg,
                               val_loader=val_loader,
                               criterion=criterion_seg,
                               epoch=epoch,
                               logger=val_logger,
                               work_dir=work_dir,
                               save_fig=False,
                               work_dir_name='{}_visualize_per_epoch'.format(
                                   args.train_dataset))
                print('validation result **************************************')

                lr_scheduler_seg.step()
                lr_scheduler_ae.step()

                if args.val_size == 0:
                    is_best = 1
                else:
                    is_best = iou > best_iou
                best_iou = max(iou, best_iou)
                checkpoint_filename = 'model_checkpoint_{:0>3}.pth'.format(
                    epoch + 1)
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'state_dict': model_seg.state_dict(),
                        'optimizer': optimizer_seg.state_dict()
                    },
                    is_best,
                    work_dir,
                    filename='checkpoint.pth')

            print("train end")
    except RuntimeError as e:
        send_slack_message(
            '#jm_private',
            '----------------------------------- error train : '
            'send to message JM & Please send a kakao talk '
            '----------------------------------------- '
            '\n error message : {}'.format(e))
        import ipdb
        ipdb.set_trace()

    draw_curve(work_dir, trn_logger, val_logger)
    send_slack_message('#jm_private', '{} : end_training'.format(args.exp))

    # -------------------------------------------------------------------------
    # here is load model for last pth
    load_filename = os.path.join(work_dir, 'model_best.pth')
    checkpoint = torch.load(load_filename)
    ch_epoch = checkpoint['epoch']
    save_check_txt = os.path.join(work_dir, str(ch_epoch))
    f = open("{}_best_checkpoint.txt".format(save_check_txt), 'w')
    f.close()
    # -------------------------------------------------------------------------

    # validation
    if args.test_mode:
        print('Test mode ...')
        main_test(model=model_seg, test_loader=test_data_list, args=args)
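# Note on the commented-out gradient-clipping block in the function above:
# calling `torch.nn.utils.clip_grad_norm_` once at setup time has no effect,
# because it only clips whatever gradients exist at that moment. If clipping
# is wanted, it has to run inside the training loop between backward() and
# step(). A sketch of that placement (the `loss` name is hypothetical and the
# project's actual train() code is not shown here):
#
#     import torch.nn.utils as torch_utils
#
#     loss.backward()
#     torch_utils.clip_grad_norm_(model_seg.parameters(), max_norm=1.0)
#     torch_utils.clip_grad_norm_(model_ae.parameters(), max_norm=1.0)
#     optimizer_seg.step()
#     optimizer_ae.step()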
def main():
    if args.server == 'server_A':
        work_dir = os.path.join('/data1/JM/lung_segmentation', args.exp)
        print(work_dir)
    elif args.server == 'server_B':
        work_dir = os.path.join('/data1/workspace/JM_gen/lung-seg-back-up',
                                args.exp)
        print(work_dir)
    elif args.server == 'server_D':
        work_dir = os.path.join(
            '/daintlab/home/woans0104/workspace/'
            'lung-seg-back-up', args.exp)
        print(work_dir)

    if not os.path.exists(work_dir):
        os.makedirs(work_dir)

    # copy this file to work dir to keep training configuration
    shutil.copy(__file__, os.path.join(work_dir, 'main.py'))
    with open(os.path.join(work_dir, 'args.pkl'), 'wb') as f:
        pickle.dump(args, f)

    source_dataset, target_dataset1, target_dataset2 \
        = loader.dataset_condition(args.source_dataset)

    # 1.load_dataset
    train_loader_source, test_loader_source \
        = loader.get_loader(server=args.server,
                            dataset=source_dataset,
                            train_size=args.train_size,
                            aug_mode=args.aug_mode,
                            aug_range=args.aug_range,
                            batch_size=args.batch_size,
                            work_dir=work_dir)

    train_loader_target1, _ = loader.get_loader(server=args.server,
                                                dataset=target_dataset1,
                                                train_size=1,
                                                aug_mode=False,
                                                aug_range=args.aug_range,
                                                batch_size=1,
                                                work_dir=work_dir)

    train_loader_target2, _ = loader.get_loader(server=args.server,
                                                dataset=target_dataset2,
                                                train_size=1,
                                                aug_mode=False,
                                                aug_range=args.aug_range,
                                                batch_size=1,
                                                work_dir=work_dir)

    test_data_li = [
        test_loader_source, train_loader_target1, train_loader_target2
    ]

    trn_logger = Logger(os.path.join(work_dir, 'train.log'))
    trn_raw_logger = Logger(os.path.join(work_dir, 'train_raw.log'))
    val_logger = Logger(os.path.join(work_dir, 'validation.log'))

    # 2.model_select
    #model_seg = select_model(args.arch)
    if args.arch == 'unet':
        model_seg = Unet2D(in_shape=(1, 256, 256))
    elif args.arch == 'unet_norm':
        model_seg = Unet2D_norm(in_shape=(1, 256, 256),
                                nomalize_con=args.nomalize_con,
                                affine=args.affine,
                                group_channel=args.group_channel,
                                weight_std=args.weight_std)
    else:
        raise ValueError('Not supported network.')
    model_seg = model_seg.cuda()

    # 3.gpu select
    model_seg = nn.DataParallel(model_seg)
    cudnn.benchmark = True

    # 4.optim
    if args.optim == 'adam':
        optimizer_seg = torch.optim.Adam(model_seg.parameters(),
                                         lr=args.lr,
                                         weight_decay=args.weight_decay,
                                         eps=args.eps)
    elif args.optim == 'adamp':
        optimizer_seg = AdamP(model_seg.parameters(),
                              lr=args.lr,
                              weight_decay=args.weight_decay,
                              eps=args.eps)
    elif args.optim == 'sgd':
        optimizer_seg = torch.optim.SGD(model_seg.parameters(),
                                        lr=args.lr,
                                        momentum=0.9,
                                        weight_decay=args.weight_decay)

    # lr decay
    lr_schedule = args.lr_schedule
    lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer_seg,
                                                  milestones=lr_schedule[:-1],
                                                  gamma=0.1)

    # 5.loss
    if args.loss_function == 'bce':
        criterion = nn.BCELoss()
    elif args.loss_function == 'bce_logit':
        criterion = nn.BCEWithLogitsLoss()
    elif args.loss_function == 'dice':
        criterion = DiceLoss()
    elif args.loss_function == 'Cldice':
        bce = nn.BCEWithLogitsLoss().cuda()
        dice = DiceLoss().cuda()
        criterion = ClDice(bce, dice, alpha=1, beta=1)
    criterion = criterion.cuda()

    ##########################################################################
    # train

    best_iou = 0
    try:
        if args.train_mode:
            for epoch in range(lr_schedule[-1]):

                train(model=model_seg,
                      train_loader=train_loader_source,
                      epoch=epoch,
                      criterion=criterion,
                      optimizer=optimizer_seg,
                      logger=trn_logger,
                      sublogger=trn_raw_logger)

                iou = validate(model=model_seg,
                               val_loader=test_loader_source,
                               epoch=epoch,
                               criterion=criterion,
                               logger=val_logger)
                print('validation_result ************************************')

                lr_scheduler.step()

                if args.val_size == 0:
                    is_best = 1
                else:
                    is_best = iou > best_iou
                best_iou = max(iou, best_iou)
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'state_dict': model_seg.state_dict(),
                        'optimizer': optimizer_seg.state_dict()
                    },
                    is_best,
                    work_dir,
                    filename='checkpoint.pth')

            print("train end")
    except RuntimeError as e:
        print('error message : {}'.format(e))
        import ipdb
        ipdb.set_trace()

    draw_curve(work_dir, trn_logger, val_logger)

    # here is load model for last pth
    check_best_pth(work_dir)

    # validation
    if args.test_mode:
        print('Test mode ...')
        main_test(model=model_seg, test_loader=test_data_li, args=args)
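# `check_best_pth(work_dir)` is imported from the project's utilities and is
# not shown in this back-up. Judging from the inline version in one of the
# scripts above (load `model_best.pth`, read its epoch, and leave an empty
# `<epoch>_best_checkpoint.txt` marker in the work dir), it is probably close
# to the guarded sketch below; treat this as an assumption, not the project's
# actual code.
if 'check_best_pth' not in globals():

    def check_best_pth(work_dir):
        load_filename = os.path.join(work_dir, 'model_best.pth')
        checkpoint = torch.load(load_filename)
        ch_epoch = checkpoint['epoch']
        # empty marker file whose name records the best epoch
        marker = os.path.join(work_dir,
                              '{}_best_checkpoint.txt'.format(ch_epoch))
        open(marker, 'w').close()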