def train_module(_opt):
    """Train ``Network`` starting from the weights stored in ``_opt.resume_snapshot``.

    Args:
        _opt: Options namespace. Attributes read directly here are
            ``resume_snapshot``, ``lr``, ``train_path``, ``batchsize``,
            ``trainsize``, ``epoch``, ``decay_rate`` and ``decay_epoch``.
            The whole object is also forwarded to ``trainer``, which may read
            more attributes.

    ``_opt.train_path`` must contain ``Imgs/``, ``GT/`` and ``Edge/``
    sub-directories; their paths are handed to ``get_loader``.
    """
    opt = _opt

    # ---- build models ----
    torch.cuda.set_device(0)
    model = Network(channel=32, n_class=1).cuda()
    # Deserialize the snapshot on the CPU: load_state_dict copies the tensors
    # into the model's CUDA parameters, so resuming also works when the
    # snapshot was saved from a GPU index that does not exist on this machine.
    state_dict = torch.load(opt.resume_snapshot, map_location="cpu")
    model.load_state_dict(state_dict)
    optimizer = torch.optim.Adam(model.parameters(), opt.lr)

    image_root = '{}/Imgs/'.format(opt.train_path)
    gt_root = '{}/GT/'.format(opt.train_path)
    edge_root = '{}/Edge/'.format(opt.train_path)
    train_loader = get_loader(image_root, gt_root, edge_root,
                              batchsize=opt.batchsize, trainsize=opt.trainsize)
    total_step = len(train_loader)

    print("#" * 20, "Start Training", "#" * 20)
    # NOTE(review): range(1, opt.epoch) runs epochs 1 .. opt.epoch - 1, i.e.
    # opt.epoch - 1 epochs. Kept as-is because trainer's checkpoint schedule
    # (not visible here) may depend on this numbering -- confirm before changing.
    for epoch in range(1, opt.epoch):
        adjust_lr(optimizer, opt.lr, epoch, opt.decay_rate, opt.decay_epoch)
        trainer(train_loader=train_loader, model=model, optimizer=optimizer,
                epoch=epoch, opt=opt, total_step=total_step)
# Training-script fragment: builds the optimizer and data loader, prints a
# banner, then runs the epoch loop.
# NOTE(review): `model`, `opt` and `train_save` are not defined in the visible
# source -- presumably they are bound earlier in the same script; confirm.
# Every name bound here is a module-level global that `train` / `adjust_lr`
# may read, so do not rename them without checking those functions.
BCE = torch.nn.BCEWithLogitsLoss()  # not used in this fragment; presumably read by `train` -- verify
params = model.parameters()
optimizer = torch.optim.Adam(params, opt.lr)

# Dataset layout: <opt.train_path>/Imgs/, /GT/ and /Edge/, passed to get_loader.
image_root = '{}/Imgs/'.format(opt.train_path)
gt_root = '{}/GT/'.format(opt.train_path)
edge_root = '{}/Edge/'.format(opt.train_path)
train_loader = get_loader(image_root, gt_root, edge_root, batchsize=opt.batchsize, trainsize=opt.trainsize, num_workers=opt.num_workers)
total_step = len(train_loader)  # not passed to train() below; presumably read as a global -- verify

# ---- start !! -----
# Banner showing the backbone name and the full options object.
# NOTE(review): "[email protected]" looks like an e-mail-obfuscation artifact
# from a web copy of this file; the real address appears to have been lost.
print(
    "#" * 20, "\nStart Training (Inf-Net-{})\n{}\nThis code is written for 'Inf-Net: Automatic COVID-19 Lung "
    "Infection Segmentation from CT Scans', 2020, TMI.\n"
    "----\nPlease cite the paper if you use this code and dataset. "
    "And any questions feel free to contact me "
    "via E-mail ([email protected])\n----\n".format(opt.backbone, opt),
    "#" * 20)

# NOTE(review): range(1, opt.epoch) runs opt.epoch - 1 epochs (1 .. opt.epoch - 1);
# confirm this matches train()'s checkpoint-saving schedule.
for epoch in range(1, opt.epoch):
    adjust_lr(optimizer, opt.lr, epoch, opt.decay_rate, opt.decay_epoch)
    train(train_loader, model, optimizer, epoch, train_save)
def _sorted_paths(root, extensions):
    """Return a sorted NumPy array of ``root + name`` for entries of *root* ending with *extensions*.

    *extensions* is a suffix string or a tuple of suffixes, as accepted by
    ``str.endswith``. *root* is expected to end with a path separator.
    """
    return np.array(sorted(
        root + name for name in os.listdir(root) if name.endswith(extensions)))


def cross_validation(train_save, opt, early_stopping_patience=6):
    """Run k-fold cross-validation over the dataset in ``opt.all_path``.

    For every fold produced by ``KFold(opt.folds)``:

    * all RNGs are re-seeded with ``opt.seed``;
    * a fresh model and optimizer are built with ``create_model(opt)``;
    * the model is trained on the training split, stopping early once the
      held-out loss returned by ``train`` has not improved for
      ``early_stopping_patience`` consecutive epochs;
    * the held-out split is evaluated with the project's ``eval`` function;
    * the metric string is appended to
      ``<opt.metric_path>/<opt.train_save>/metrics_<fold_index>.txt``.

    Args:
        train_save: Forwarded unchanged to ``train``.
        opt: Options namespace. Among others it supplies ``all_path``, ``folds``,
            ``seed``, ``trainsize``, ``is_data_augment``, ``random_cutout``,
            ``batchsize``, ``num_workers``, ``epoch``, ``lr``, ``decay_rate``,
            ``decay_epoch``, ``device``, ``eval_threshold``, ``metric_path``
            and ``train_save``.
        early_stopping_patience: Number of consecutive epochs without a lower
            held-out loss after which a fold stops training. The default of 6
            is the previously hard-coded value.

    Raises:
        ValueError: If ``Imgs/``, ``GT/`` and ``Edge/`` contain different
            numbers of matching files. Files are paired only by sort order, so
            a count mismatch would silently pair the wrong files.
    """
    image_root = '{}/Imgs/'.format(opt.all_path)
    gt_root = '{}/GT/'.format(opt.all_path)
    edge_root = '{}/Edge/'.format(opt.all_path)
    images = _sorted_paths(image_root, ('.jpg', '.png'))
    gts = _sorted_paths(gt_root, '.png')
    edges = _sorted_paths(edge_root, '.png')
    if not len(images) == len(gts) == len(edges):
        raise ValueError(
            'Mismatched dataset in {}: {} images, {} GT masks, {} edge maps; '
            'files are paired by sorted name, so the counts must agree.'.format(
                opt.all_path, len(images), len(gts), len(edges)))

    # NOTE(review): metrics go under opt.train_save, while the `train_save`
    # argument is what gets passed to train(); confirm these are meant to differ.
    metric_dir = os.path.join(opt.metric_path, opt.train_save)

    k_folds = KFold(opt.folds)
    for fold_index, (train_index, test_index) in enumerate(k_folds.split(images)):
        best_loss = float('inf')
        epochs_without_improvement = 0

        # Re-seed per fold so every fold starts from the same RNG state.
        # torch.manual_seed seeds the CPU generator and every CUDA device.
        random.seed(opt.seed)
        np.random.seed(opt.seed)
        torch.manual_seed(opt.seed)

        model, optimizer = create_model(opt)

        train_dataset = IndicesDataset(images[train_index], gts[train_index], edges[train_index],
                                       opt.trainsize, opt.is_data_augment, opt.random_cutout)
        # NOTE(review): the augmentation flags are also passed to the held-out
        # set; presumably IndicesDataset ignores them when is_test=True -- confirm.
        test_dataset = IndicesDataset(images[test_index], gts[test_index], None,
                                      opt.trainsize, opt.is_data_augment, opt.random_cutout,
                                      is_test=True)
        train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=opt.batchsize,
                                                   shuffle=True, num_workers=opt.num_workers,
                                                   pin_memory=True, drop_last=False)
        # Evaluation does not need shuffling; a fixed order keeps held-out
        # passes deterministic.
        test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=opt.batchsize,
                                                  shuffle=False, num_workers=opt.num_workers,
                                                  pin_memory=True, drop_last=False)

        # NOTE(review): range(1, opt.epoch) runs opt.epoch - 1 epochs; kept as-is.
        for epoch in range(1, opt.epoch):
            adjust_lr(optimizer, opt.lr, epoch, opt.decay_rate, opt.decay_epoch)
            average_test_loss = train(train_loader, test_loader, model, optimizer, epoch,
                                      train_save, opt.device, opt)
            if average_test_loss < best_loss:
                best_loss = average_test_loss
                epochs_without_improvement = 0
            else:
                epochs_without_improvement += 1
                if epochs_without_improvement >= early_stopping_patience:
                    break

        # NOTE(review): this evaluates the weights of the last epoch run, not
        # those of the best-loss epoch. Restore a best checkpoint first if that
        # is what the metrics should reflect.
        # `eval` here is the project's evaluation function (it shadows the builtin).
        metric_string = eval(test_loader, model, opt.device, None, opt.eval_threshold, opt)

        # write the metrics
        os.makedirs(metric_dir, exist_ok=True)
        filename = os.path.join(metric_dir, f"metrics_{fold_index}.txt")
        with open(filename, 'a', encoding='utf-8') as f:
            f.write(metric_string)