def train(args):
    """Build a 3-channel U-Net plus its data pipeline and hand everything to train_model.

    Uses module-level `device`, `dir_img`, `dir_mask`, `x_transform` and
    `y_transform`; the only thing read from `args` is the batch size.
    """
    batch_size = args.batch_size
    # Dataset and loader for the Verse data (4 worker processes, shuffled).
    verse_data = DatasetVerse(dir_img, dir_mask, transform=x_transform, target_transform=y_transform)
    dataloader = DataLoader(verse_data, batch_size=batch_size, shuffle=True, num_workers=4)
    # 3-in / 3-out U-Net on the globally selected device.
    model = UNet(3, 3).to(device)
    # Logits-based BCE; an alternative DiceLoss was considered here.
    # criterion = DiceLoss()
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters())
    train_model(model, criterion, optimizer, dataloader)
# --- Training setup: single-channel U-Net on the cell-segmentation dataset ---
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)  # show whether the GPU is available
model = UNet(1, 1).to(device)
#model = NestedUNet(1, 1).to(device)
# model.load_state_dict(torch.load("pretrain/weights_80.pth",map_location='cpu'))
# Fix all random seeds so runs are reproducible.
torch.backends.cudnn.deterministic = True
random.seed(1)
torch.manual_seed(1)
torch.cuda.manual_seed(1)
np.random.seed(1)
batch_size = 2
learning_rate = 0.001
# NOTE(review): BCELoss expects sigmoid-activated outputs in [0, 1] —
# confirm the UNet's final layer applies a sigmoid.
criterion = torch.nn.BCELoss()
optimizer = optim.Adam([{
    'params': model.parameters(),
    'initial_lr': learning_rate
}], lr=learning_rate)
# Decay LR by 0.8 every 10 epochs. Because last_epoch != -1, the
# 'initial_lr' key set on the param group above is mandatory.
scheduler = lr_scheduler.StepLR(
    optimizer, step_size=10, gamma=0.8, last_epoch=0)
# Inputs: normalized to [-1, 1]; targets: plain tensor conversion.
x_transform = T.Compose([T.ToTensor(), T.Normalize([0.5], [0.5])])
y_transform = T.ToTensor()
cell_dataset = CellDataset1('dataset/dataset1/train/', 'dataset/dataset1/train_GT/SEG/',
                            transform=x_transform, target_transform=y_transform)
#cell_dataset = CellDataset2('dataset/dataset2/train/', 'dataset/dataset2/train_GT/SEG/', transform=x_transform,
#                            target_transform=y_transform)  # for dataset 2
train_size = int(0.8 * len(cell_dataset))  # train/validation split, 8:2
def main():
    """End-to-end training driver for U-Net on the miccaiSeg dataset.

    Parses CLI args, builds the train/test transforms and loaders,
    optionally resumes from a checkpoint, then runs the train/validate/
    evaluate loop, checkpointing after every epoch.

    Relies on module-level globals: parser, use_gpu, train, validate,
    save_checkpoint, utils, miccaiSegDataset, UNet.
    """
    global args, best_prec1
    args = parser.parse_args()
    print(args)

    # argparse delivers strings; convert the flag to a real bool.
    if args.saveTest == 'True':
        args.saveTest = True
    elif args.saveTest == 'False':
        args.saveTest = False

    # Check if the save directory exists or not
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    cudnn.benchmark = True

    data_transforms = {
        # Train: resize, then TenCrop — yields a stacked tensor of 10 crops
        # per sample (note the extra leading crop dimension downstream).
        'train': transforms.Compose([
            transforms.Resize((args.imageSize, args.imageSize), interpolation=Image.NEAREST),
            transforms.TenCrop(args.resizedImageSize),
            transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])),
            #transforms.Lambda(lambda normalized: torch.stack([transforms.Normalize([0.295, 0.204, 0.197], [0.221, 0.188, 0.182])(crop) for crop in normalized]))
            #transforms.RandomResizedCrop(224, interpolation=Image.NEAREST),
            #transforms.RandomHorizontalFlip(),
            #transforms.RandomVerticalFlip(),
            #transforms.ToTensor(),
        ]),
        'test': transforms.Compose([
            transforms.Resize((args.imageSize, args.imageSize), interpolation=Image.NEAREST),
            transforms.ToTensor(),
            #transforms.Normalize([0.295, 0.204, 0.197], [0.221, 0.188, 0.182])
        ]),
    }

    # Data Loading
    data_dir = 'datasets/miccaiSegRefined'
    # json path for class definitions
    json_path = 'datasets/miccaiSegClasses.json'

    image_datasets = {x: miccaiSegDataset(os.path.join(data_dir, x), data_transforms[x],
                      json_path) for x in ['train', 'test']}

    dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x],
                                                  batch_size=args.batchSize,
                                                  shuffle=True,
                                                  num_workers=args.workers)
                   for x in ['train', 'test']}
    dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'test']}

    # Get the dictionary for the id and RGB value pairs for the dataset
    classes = image_datasets['train'].classes
    key = utils.disentangleKey(classes)
    num_classes = len(key)

    # Initialize the model
    model = UNet(num_classes)

    # # Optionally resume from a checkpoint
    # if args.resume:
    #     if os.path.isfile(args.resume):
    #         print("=> loading checkpoint '{}'".format(args.resume))
    #         checkpoint = torch.load(args.resume)
    #         #args.start_epoch = checkpoint['epoch']
    #         pretrained_dict = checkpoint['state_dict']
    #         pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model.state_dict()}
    #         model.state_dict().update(pretrained_dict)
    #         model.load_state_dict(model.state_dict())
    #         print("=> loaded checkpoint (epoch {})".format(checkpoint['epoch']))
    #     else:
    #         print("=> no checkpoint found at '{}'".format(args.resume))
    #
    #     # # Freeze the encoder weights
    #     # for param in model.encoder.parameters():
    #     #     param.requires_grad = False
    #
    #     optimizer = optim.Adam(model.parameters(), lr = args.lr, weight_decay = args.wd)
    # else:

    optimizer = optim.Adam(model.parameters(), lr = args.lr, weight_decay = args.wd)

    # Load the saved model
    if os.path.isfile(args.resume):
        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        args.start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint (epoch {})".format(checkpoint['epoch']))
    else:
        print("=> no checkpoint found at '{}'".format(args.resume))

    print(model)

    # Define loss function (criterion)
    criterion = nn.CrossEntropyLoss()

    # Use a learning rate scheduler
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

    if use_gpu:
        model.cuda()
        criterion.cuda()

    # Initialize an evaluation Object
    evaluator = utils.Evaluate(key, use_gpu)

    for epoch in range(args.start_epoch, args.epochs):
        #adjust_learning_rate(optimizer, epoch)

        # Train for one epoch
        print('>>>>>>>>>>>>>>>>>>>>>>>Training<<<<<<<<<<<<<<<<<<<<<<<')
        train(dataloaders['train'], model, criterion, optimizer, scheduler, epoch, key)

        # Evaulate on validation set
        print('>>>>>>>>>>>>>>>>>>>>>>>Testing<<<<<<<<<<<<<<<<<<<<<<<')
        validate(dataloaders['test'], model, criterion, epoch, key, evaluator)

        # Calculate the metrics
        print('>>>>>>>>>>>>>>>>>> Evaluating the Metrics <<<<<<<<<<<<<<<<<')
        IoU = evaluator.getIoU()
        print('Mean IoU: {}, Class-wise IoU: {}'.format(torch.mean(IoU), IoU))
        PRF1 = evaluator.getPRF1()
        precision, recall, F1 = PRF1[0], PRF1[1], PRF1[2]
        print('Mean Precision: {}, Class-wise Precision: {}'.format(torch.mean(precision), precision))
        print('Mean Recall: {}, Class-wise Recall: {}'.format(torch.mean(recall), recall))
        print('Mean F1: {}, Class-wise F1: {}'.format(torch.mean(F1), F1))
        # Metrics are accumulated inside the evaluator; reset per epoch.
        evaluator.reset()

        # Checkpoint every epoch (model + optimizer state).
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, filename=os.path.join(args.save_dir, 'checkpoint_{}.tar'.format(epoch)))
def train(args):
    """Staged training of a U-Net with resume/snapshot support.

    Each stage (0-3) trains at a progressively larger resolution with its
    own batch size and epoch budget; stages > 0 require the checkpoint
    produced by the previous stage.

    Fixes over the previous revision:
      * optimizer.zero_grad() was missing in the batch loop, so gradients
        accumulated across batches;
      * an unknown args.stage fell through with train_shape/batch_size
        undefined (NameError later) — now fails fast with ValueError;
      * the log file was never closed;
      * after Ctrl+C the loop continued to the next epoch despite printing
        'Terminated.' — now returns after saving the snapshot;
      * typos in log strings ("Traning"/"Prepaing").
    """
    print("Training")
    print("Preparing data")
    masks = pd.read_csv(os.path.join(args.dataset_dir, args.train_masks))
    unique_img_ids = get_unique_img_ids(masks, args)
    train_df, valid_df = get_balanced_train_valid(masks, unique_img_ids, args)

    # Per-stage resolution / batch size / number of extra epochs.
    if args.stage == 0:
        train_shape = (256, 256)
        batch_size = args.stage0_batch_size
        extra_epoch = args.stage0_epochs
    elif args.stage == 1:
        train_shape = (384, 384)
        batch_size = args.stage1_batch_size
        extra_epoch = args.stage1_epochs
    elif args.stage == 2:
        train_shape = (512, 512)
        batch_size = args.stage2_batch_size
        extra_epoch = args.stage2_epochs
    elif args.stage == 3:
        train_shape = (768, 768)
        batch_size = args.stage3_batch_size
        extra_epoch = args.stage3_epochs
    else:
        raise ValueError('Unknown stage {}, expected 0-3'.format(args.stage))
    print("Stage {}".format(args.stage))

    # Paired image/mask augmentations for training; validation only resizes.
    train_transform = DualCompose([
        Resize(train_shape),
        HorizontalFlip(),
        VerticalFlip(),
        RandomRotate90(),
        Shift(),
        Transpose(),
        # ImageOnly(RandomBrightness()),
        # ImageOnly(RandomContrast()),
    ])
    val_transform = DualCompose([
        Resize(train_shape),
    ])
    train_dataloader = make_dataloader(train_df, args, batch_size, args.shuffle,
                                       transform=train_transform)
    val_dataloader = make_dataloader(valid_df, args, batch_size // 2, args.shuffle,
                                     transform=val_transform)

    # Build model
    model = UNet()
    optimizer = Adam(model.parameters(), lr=args.lr)
    scheduler = StepLR(optimizer, step_size=args.decay_fr, gamma=0.1)
    if args.gpu and torch.cuda.is_available():
        model = model.cuda()

    # Restore model: stages > 0 must start from the previous stage's weights.
    run_id = 4
    model_path = Path('model_{run_id}.pt'.format(run_id=run_id))
    if not model_path.exists() and args.stage > 0:
        raise ValueError(
            'model_{run_id}.pt does not exist, initial train first.'.format(
                run_id=run_id))
    if model_path.exists():
        state = torch.load(str(model_path))
        last_epoch = state['epoch']
        step = state['step']
        model.load_state_dict(state['model'])
        print('Restore model, epoch {}, step {:,}'.format(last_epoch, step))
    else:
        last_epoch = 1
        step = 0

    log_file = open('train_{run_id}.log'.format(run_id=run_id), 'at', encoding='utf8')
    loss_fn = LossBinary(jaccard_weight=args.iou_weight)
    valid_losses = []
    print("Start training ...")
    try:
        # Fast-forward the scheduler to the resumed epoch.
        for _ in range(last_epoch):
            scheduler.step()
        for epoch in range(last_epoch, last_epoch + extra_epoch):
            scheduler.step()
            model.train()
            random.seed()
            tq = tqdm(total=len(train_dataloader) * batch_size)
            tq.set_description('Run Id {}, Epoch {} of {}, lr {}'.format(
                run_id, epoch, last_epoch + extra_epoch,
                args.lr * (0.1**(epoch // args.decay_fr))))
            losses = []
            try:
                mean_loss = 0.
                for i, (inputs, targets) in enumerate(train_dataloader):
                    inputs, targets = torch.tensor(inputs), torch.tensor(targets)
                    if args.gpu and torch.cuda.is_available():
                        inputs = inputs.cuda()
                        targets = targets.cuda()
                    outputs = model(inputs)
                    loss = loss_fn(outputs, targets)
                    # Clear stale gradients before backprop (was missing:
                    # gradients accumulated across batches).
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    step += 1
                    tq.update(batch_size)
                    losses.append(loss.item())
                    # Smooth the displayed loss over the last log window.
                    mean_loss = np.mean(losses[-args.log_fr:])
                    tq.set_postfix(loss="{:.5f}".format(mean_loss))
                    if i and (i % args.log_fr) == 0:
                        write_event(log_file, step, loss=mean_loss)
                write_event(log_file, step, loss=mean_loss)
                tq.close()
                save_model(model, epoch, step, model_path)
                valid_metrics = validation(args, model, loss_fn, val_dataloader)
                write_event(log_file, step, **valid_metrics)
                valid_loss = valid_metrics['valid_loss']
                valid_losses.append(valid_loss)
            except KeyboardInterrupt:
                tq.close()
                print('Ctrl+C, saving snapshot')
                save_model(model, epoch, step, model_path)
                print('Terminated.')
                return  # stop training entirely, as the message promises
    finally:
        log_file.close()
    print('Done.')
def main(): args = get_args() # set GPU device os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu # default: '0' device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # set model model = UNet(n_channels=1, n_classes=1).to(device) if len(args.gpu) > 1: # if multi-gpu model = torch.nn.DataParallel(model) """set img size - UNet type architecture require input img size be divisible by 2^N, - Where N is the number of the Max Pooling layers (in the Vanila UNet N = 5) """ img_size = args.img_size #default: 512 # set transforms for dataset import torchvision.transforms as transforms from my_transforms import RandomHorizontalFlip, RandomVerticalFlip, ColorJitter, GrayScale, Resize, ToTensor train_transforms = transforms.Compose([ #Data Augmentations RandomHorizontalFlip(), RandomVerticalFlip(), ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5), #shear #rotation #scale #transformations to fit in Network GrayScale(), Resize(img_size), ToTensor(), ]) eval_transforms = transforms.Compose( [GrayScale(), Resize(img_size), ToTensor()]) # set Dataset and DataLoader train_dataset = LungSegDataset(transforms=train_transforms) val_dataset = LungSegDataset(split='val', transforms=eval_transforms) test_dataset = LungSegDataset(split='test', transforms=eval_transforms) from torch.utils.data import DataLoader dataloader = { 'train': DataLoader(dataset=train_dataset, batch_size=args.batch_size, num_workers=args.n_workers, shuffle=True), 'val': DataLoader(dataset=val_dataset, batch_size=args.batch_size, num_workers=args.n_workers), 'test': DataLoader(dataset=test_dataset, batch_size=args.batch_size, num_workers=args.n_workers) } # checkpoint dir checkpoint_dir = os.path.join(os.getcwd(), 'checkpoint') if not os.path.exists(checkpoint_dir): os.mkdir(checkpoint_dir) checkpoint_path = args.load_model # set optimizer optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5) # learning rate scheduler from torch.optim.lr_scheduler import StepLR, 
ReduceLROnPlateau # scheduler = StepLR(optimizer, step_size = 3 , gamma = 0.8) ## option 2. scheduler = ReduceLROnPlateau(optimizer, 'min', patience=3) # # set criterion # if model.n_classes > 1: # criterion = nn.CrossEntropyLoss() # else: # criterion = nn.BCEWithLogitsLoss() criterion = nn.BCEWithLogitsLoss() train_and_validate(net=model, criterion=criterion, optimizer=optimizer, dataloader=dataloader, device=device, epochs=args.epochs, scheduler=scheduler, load_model=checkpoint_path)
def main(args):
    """Train an EVFlowNet-style U-Net to predict optical flow from event voxels.

    Runs 10 passes over the event dataloader, supervising with an
    event-warping loss, and periodically visualizes voxel/flow/compensated
    events via cvshow_all.
    """
    # Event-dataset configuration passed through to EventDataLoader.
    dataset_kwargs = {
        'transforms': {},
        'max_length': None,
        'sensor_resolution': None,
        'preload_events': False,
        'num_bins': 16,
        'voxel_method': {
            'method': 'random_k_events',
            'k': 60000,
            't': 0.5,
            'sliding_window_w': 500,
            'sliding_window_t': 0.1
        }
    }
    unet_kwargs = {
        'base_num_channels': 32,  # written as '64' in EVFlowNet tf code
        'num_encoders': 4,
        'num_residual_blocks': 2,  # transition
        'num_output_channels': 2,  # (x, y) displacement
        'skip_type': 'concat',
        'norm': None,
        'use_upsample_conv': True,
        'kernel_size': 3,
        'channel_multiplier': 2,
        'num_bins': 16
    }

    # Anomaly detection is enabled for debugging; it slows training down.
    torch.autograd.set_detect_anomaly(True)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    ev_loader = EventDataLoader(args.h5_file_path,
                                batch_size=1,
                                num_workers=6,
                                shuffle=True,
                                pin_memory=True,
                                dataset_kwargs=dataset_kwargs)
    H, W = ev_loader.H, ev_loader.W

    model = UNet(unet_kwargs)
    model = model.to(device)
    model.train()
    # Pads inputs so H/W are divisible by 2^4 (4 encoder levels).
    crop = CropParameters(W, H, 4)

    print("=== Let's use", torch.cuda.device_count(), "GPUs!")
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5, betas=(0.9, 0.999))
    # optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, betas=(0.9, 0.999), weight_decay=0.01)
    # raise
    # tmp_voxel = crop.pad(torch.randn(1, 9, H, W).to(device))
    # F, P = profile(model, inputs=(tmp_voxel, ))

    for idx in range(10):  # 10 passes over the dataset
        # for i, item in enumerate(tqdm(ev_loader)):
        for i, item in enumerate(ev_loader):
            events = item['events']
            voxel = item['voxel'].to(device)
            voxel = crop.pad(voxel)
            model.zero_grad()
            optimizer.zero_grad()
            # Raw output scaled by 10, then clamped to a plausible flow range.
            flow = model(voxel) * 10
            flow = torch.clamp(flow, min=-40, max=40)
            loss = compute_loss(events, flow)
            loss.backward()
            # cvshow_voxel_grid(voxel.squeeze()[0:2].cpu().numpy())
            # raise
            optimizer.step()
            if i % 10 == 0:
                # Loss plus flow extrema for both displacement channels.
                print(
                    idx, i, '\t',
                    "{0:.2f}".format(loss.data.item()),
                    "{0:.2f}".format(torch.max(flow[0, 0]).item()),
                    "{0:.2f}".format(torch.min(flow[0, 0]).item()),
                    "{0:.2f}".format(torch.max(flow[0, 1]).item()),
                    "{0:.2f}".format(torch.min(flow[0, 1]).item()),
                )
                xs, ys, ts, ps = events
                print_voxel = voxel[0].sum(axis=0).cpu().numpy()
                print_flow = flow[0].clone().detach().cpu().numpy()
                # Warp only positive-polarity events with the predicted flow
                # to visualize motion compensation.
                print_co = warp_events_with_flow_torch(
                    (xs[0][ps[0] == 1], ys[0][ps[0] == 1],
                     ts[0][ps[0] == 1], ps[0][ps[0] == 1]),
                    flow[0].clone().detach(),
                    sensor_size=(H, W))
                print_co = crop.pad(print_co)
                print_co = print_co.cpu().numpy()
                cvshow_all(idx=idx * 10000 + i,
                           voxel=print_voxel,
                           flow=flow[0].clone().detach().cpu().numpy(),
                           frame=None,
                           compensated=print_co)
# --- Script-level training configuration: 2-class U-Net segmentation ---
import time

num_epoches = 400
batch_size = 12
data_dir = "/userhome/Unet/unet/data/"
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# NOTE(review): train_batch_size is hard-coded to 12 instead of reusing
# batch_size above — confirm this is intentional.
train_dataloader, val_dataloader = create_dataset(data_dir, repeat=1, train_batch_size=12, augment=True)
model = UNet(1, 2).to(device)  # 1 input channel, 2 output classes
criterion = CrossEntropyWithLogits().to(device)
optimizer = Adam(model.parameters(), lr=0.0001, weight_decay=0.0005, eps=1e-08)
save_step = 200  # checkpoint interval (iterations)
##test data load time
# print("get-100-epoch")
# load_s = time.time()
# for i in range(2):
#     for sample in train_dataloader:
#         print(sample["image"].shape)
#         print(sample["mask"].shape)
# load_e = time.time()
# print("load data time: ", load_e - load_s)
# TODO: Initialization the params
val_loss = -1  # sentinel: no validation loss recorded yet
def train():
    """Train a denoising/seal-removal model on the HongZhang dataset.

    Model (unet / srresnet / eesp) and loss (l2 / l1 / ssim) are selected
    from the module-level `args`; checkpoints are saved every
    args.save_per_epoch epochs.

    Fixes over the previous revision:
      * an unrecognized args.loss left `criterion` unbound and crashed with
        NameError at the first batch — now falls back to MSELoss, matching
        the model selection's else-fallback style;
      * scheduler.step() was called at the start of each epoch, i.e. before
        any optimizer.step() — moved to the end of the epoch per the
        PyTorch >= 1.1 ordering contract;
      * the local name `optim` shadowed the torch.optim module — renamed;
      * dropped the deprecated, no-op Variable() wrapper around the target.
    """
    # prepare the dataloader
    device = torch.device(args.devices if torch.cuda.is_available() else "cpu")
    #dataset = Training_Dataset(args.image_dir, (args.image_size, args.image_size), (args.noise,args.noise_param))
    # dataset = HongZhang_Dataset("/data_1/data/Noise2Noise/shenqingbiao/0202", "/data_1/data/Noise2Noise/hongzhang")
    # dataset = HongZhang_Dataset2("/data_1/data/红章图片", (256, 256))
    dataset = HongZhang_Dataset3("/data_1/data/红章图片/6_12", (256, 256))
    dataset_length = len(dataset)
    train_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=4)

    # choose the model
    if args.model == "unet":
        model = UNet(in_channels=args.image_channels, out_channels=args.image_channels)
    elif args.model == "srresnet":
        model = SRResnet(args.image_channels, args.image_channels)
    elif args.model == "eesp":
        model = EESPNet_Seg(args.image_channels, 2)
    else:
        model = UNet(in_channels=args.image_channels, out_channels=args.image_channels)
    model = model.to(device)

    # choose the loss type (MSE fallback keeps criterion always defined)
    if args.loss == "l2":
        criterion = nn.MSELoss()
    elif args.loss == "l1":
        criterion = nn.L1Loss()
    elif args.loss == "ssim":
        criterion = SSIM()
    else:
        criterion = nn.MSELoss()

    # resume the model if needed
    if args.resume_model:
        resume_model(model, args.resume_model)

    optimizer = Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999),
                     eps=1e-8, weight_decay=0, amsgrad=True)
    #scheduler = lr_scheduler.StepLR(optimizer, step_size=args.scheduler_step, gamma=0.5)
    scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[20, 40], gamma=0.1)
    model.train()
    print(model)

    # start to train
    print("Starting Training Loop...")
    since = time.time()
    for epoch in range(args.epochs):
        print('Epoch {}/{}'.format(epoch, args.epochs - 1))
        print('-' * 10)
        running_loss = 0.0
        for batch_idx, (target, source) in enumerate(train_loader):
            source = source.to(device)
            target = target.to(device)
            denoised_source = model(source)
            # SSIM is a similarity in [0, 1]; 1 - SSIM turns it into a loss.
            if args.loss == "ssim":
                loss = 1 - criterion(denoised_source, target)
            else:
                loss = criterion(denoised_source, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * source.size(0)
            if batch_idx % args.steps_show == 0:
                print('{}/{} Current loss {}'.format(batch_idx, len(train_loader), loss.item()))
        # Step the LR schedule after the epoch's optimizer updates.
        scheduler.step()
        epoch_loss = running_loss / dataset_length
        print('{} Loss: {:.4f}'.format('current ' + str(epoch), epoch_loss))
        if (epoch + 1) % args.save_per_epoch == 0:
            save_model(model, epoch + 1)

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
batch_size=batch_size,num_workers=num_workers,pin_memory=True,shuffle=True)
# (continuation of a DataLoader(...) call whose opening lies above this chunk)
# Validation loader: MRBrainS val split, batch size 1, deterministic order.
val_loader = DataLoader(MRBrainSDataset(defualt_path, split='val', is_transform=True, \
    img_norm=True, augmentations=Compose([Scale(224)])), \
    batch_size=1,num_workers=num_workers,pin_memory=True,shuffle=False)

# Setup Model and summary
model = UNet().to(device)
summary(model, (3, 224, 224), batch_size)  # print a layer-by-layer parameter summary
# model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))

# Parameters to be learned
# base_learning_list = list(filter(lambda p: p.requires_grad, model.base_net.parameters()))
# learning_list = model.parameters()

# Optimizer and learning-rate setup
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate,
                            momentum=momentum, weight_decay=weight_decay)
# LR schedule: decay hard (gamma=0.01) at 20% / 60% / 90% of training.
scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[
    int(0.2 * end_epoch),
    int(0.6 * end_epoch),
    int(0.9 * end_epoch)
], gamma=0.01)
# scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',patience=10, verbose=True)
criterion = cross_entropy2d
# criterion = BCEDiceLoss()
class BaseModel:
    """Wrapper that owns a segmentation network, its loss/optimizer, and the
    train/val/test loops (TGS-salt-style 101x101 masks, TensorBoard logging
    via write_scalars).

    NOTE(review): these are CLASS-level mutable dicts — item assignments like
    self.losses[flag] = [] mutate shared class state, so all instances share
    these accumulators. Move into __init__ if more than one instance is used.
    """
    losses = {'train': [], 'val': []}   # per-phase running loss history
    acces = {'train': [], 'val': []}    # per-phase running accuracy history
    scores = {'train': [], 'val': []}   # reserved (only ever cleared)
    pred = {'train': [], 'val': []}     # per-phase accumulated predictions
    true = {'train': [], 'val': []}     # per-phase accumulated ground truth

    def __init__(self, args):
        """Build the network, criterion and optimizer from parsed CLI args."""
        self.args = args
        self.net = None
        print(args.model_name)
        # --- network selection -------------------------------------------
        if args.model_name == 'UNet':
            self.net = UNet(args.in_channels, args.num_classes)
            self.net.apply(weights_init)
        elif args.model_name == 'UNetResNet34':
            self.net = UNetResNet34(args.num_classes, dropout_2d=0.2)
        elif args.model_name == 'UNetResNet152':
            self.net = UNetResNet152(args.num_classes, dropout_2d=0.2)
        elif args.model_name == 'UNet11':
            self.net = UNet11(args.num_classes, pretrained=True)
        elif args.model_name == 'UNetVGG16':
            self.net = UNetVGG16(args.num_classes, pretrained=True,
                                 dropout_2d=0.0, is_deconv=True)
        elif args.model_name == 'deeplab50_v2':
            if args.ms:
                # NOTE(review): should be `raise NotImplementedError` —
                # raising the NotImplemented constant yields a TypeError.
                raise NotImplemented
            else:
                self.net = deeplab50_v2(args.num_classes, pretrained=args.pretrained)
        elif args.model_name == 'deeplab_v2':
            if args.ms:
                self.net = ms_deeplab_v2(args.num_classes, pretrained=args.pretrained,
                                         scales=args.ms_scales)
            else:
                self.net = deeplab_v2(args.num_classes, pretrained=args.pretrained)
        elif args.model_name == 'deeplab_v3':
            if args.ms:
                self.net = ms_deeplab_v3(args.num_classes, out_stride=args.out_stride,
                                         pretrained=args.pretrained, scales=args.ms_scales)
            else:
                self.net = deeplab_v3(args.num_classes, out_stride=args.out_stride,
                                      pretrained=args.pretrained)
        elif args.model_name == 'deeplab_v3_plus':
            if args.ms:
                self.net = ms_deeplab_v3_plus(args.num_classes, out_stride=args.out_stride,
                                              pretrained=args.pretrained, scales=args.ms_scales)
            else:
                self.net = deeplab_v3_plus(args.num_classes, out_stride=args.out_stride,
                                           pretrained=args.pretrained)
        # Used to upsample network outputs back to label resolution.
        self.interp = nn.Upsample(size=args.size, mode='bilinear')
        self.iterations = args.epochs
        self.lr_current = args.lr
        self.cuda = args.cuda
        self.phase = args.phase
        self.lr_policy = args.lr_policy
        self.cyclic_m = args.cyclic_m  # number of cycles for cyclic LR
        if self.lr_policy == 'cyclic':
            print('using cyclic')
            assert self.iterations % self.cyclic_m == 0
        # --- loss selection ----------------------------------------------
        if args.loss == 'CELoss':
            self.criterion = nn.CrossEntropyLoss(size_average=True)
        elif args.loss == 'DiceLoss':
            self.criterion = DiceLoss(num_classes=args.num_classes)
        elif args.loss == 'MixLoss':
            self.criterion = MixLoss(args.num_classes, weights=args.loss_weights)
        elif args.loss == 'LovaszLoss':
            self.criterion = LovaszSoftmax(per_image=args.loss_per_img)
        elif args.loss == 'FocalLoss':
            self.criterion = FocalLoss(args.num_classes, alpha=None, gamma=2)
        else:
            raise RuntimeError('must define loss')
        # --- optimizer: deeplab backbones get a 1x/10x two-group LR ------
        if 'deeplab' in args.model_name:
            self.optimizer = optim.SGD(
                [{
                    'params': get_1x_lr_params_NOscale(self.net),
                    'lr': args.lr
                }, {
                    'params': get_10x_lr_params(self.net),
                    'lr': 10 * args.lr
                }],
                lr=args.lr,
                momentum=args.momentum,
                weight_decay=args.weight_decay)
        else:
            self.optimizer = optim.SGD(filter(lambda p: p.requires_grad,
                                              self.net.parameters()),
                                       lr=args.lr,
                                       momentum=args.momentum,
                                       weight_decay=args.weight_decay)
        self.iters = 0        # global iteration counter (for LR schedules)
        self.best_val = 0.0   # best validation IoU-threshold metric so far
        self.count = 0        # epochs without improvement ('step' policy)

    def init_model(self):
        """Optionally resume weights (skipping 'layer5'/'decoder' heads),
        then move the network to DataParallel/CUDA as configured."""
        if self.args.resume_model:
            saved_state_dict = torch.load(
                self.args.resume_model,
                map_location=lambda storage, loc: storage)
            if self.args.ms:
                new_params = self.net.Scale.state_dict().copy()
                for i in saved_state_dict:
                    # Scale.layer5.conv2d_list.3.weight
                    i_parts = i.split('.')
                    # print i_parts
                    # NOTE(review): `not (not x == y)` keeps ONLY 'layer5'
                    # keys here, the opposite filter of the branch below —
                    # confirm which is intended.
                    if not (not i_parts[0] == 'layer5') and (not i_parts[0] == 'decoder'):
                        new_params[i] = saved_state_dict[i]
                self.net.Scale.load_state_dict(new_params)
            else:
                new_params = self.net.state_dict().copy()
                for i in saved_state_dict:
                    # Scale.layer5.conv2d_list.3.weight
                    i_parts = i.split('.')
                    # print i_parts
                    if (not i_parts[0] == 'layer5') and (not i_parts[0] == 'decoder'):
                        # if not i_parts[0] == 'layer5':
                        new_params[i] = saved_state_dict[i]
                self.net.load_state_dict(new_params)
            print('Resuming training, image net loading {}...'.format(
                self.args.resume_model))
            # self.load_weights(self.net, self.args.resume_model)
        if self.args.mGPUs:
            self.net = nn.DataParallel(self.net)
        if self.args.cuda:
            self.net = self.net.cuda()
            cudnn.benchmark = True

    def _adjust_learning_rate(self, epoch):
        """Sets the learning rate to the initial LR decayed by 10 at every specified step
        # Adapted from PyTorch Imagenet example:
        # https://github.com/pytorch/examples/blob/master/imagenet/main.py

        NOTE(review): writes param_groups[1] — only valid for the two-group
        deeplab optimizer; would raise IndexError for the single-group case.
        """
        if epoch < int(self.iterations * 0.5):
            self.lr_current = max(self.lr_current * self.args.gamma, 1e-4)
        elif epoch < int(self.iterations * 0.85):
            self.lr_current = max(self.lr_current * self.args.gamma, 1e-5)
        else:
            self.lr_current = max(self.lr_current * self.args.gamma, 1e-6)
        self.optimizer.param_groups[0]['lr'] = self.lr_current
        self.optimizer.param_groups[1]['lr'] = self.lr_current * 10

    def save_network(self, net, net_name, epoch, label=''):
        """Save net's state_dict as '<epoch>_<net_name>_<label>.pth' under
        save_folder/exp_name."""
        save_fname = '%s_%s_%s.pth' % (epoch, net_name, label)
        save_path = os.path.join(self.args.save_folder, self.args.exp_name,
                                 save_fname)
        torch.save(net.state_dict(), save_path)

    def load_weights(self, net, base_file):
        """Load a .pth/.pkl state_dict into net (CPU-mapped)."""
        other, ext = os.path.splitext(base_file)
        # NOTE(review): `ext == '.pkl' or '.pth'` is always truthy (the
        # non-empty string '.pth'); should be `ext in ('.pkl', '.pth')`,
        # so the else branch below is unreachable.
        if ext == '.pkl' or '.pth':
            print('Loading weights into state dict...')
            net.load_state_dict(
                torch.load(base_file,
                           map_location=lambda storage, loc: storage))
            print('Finished!')
        else:
            print('Sorry only .pth and .pkl files supported.')

    def load_trained_model(self):
        """Load the configured trained checkpoint into the (sub)network."""
        path = os.path.join(self.args.save_folder, self.args.exp_name,
                            self.args.trained_model)
        print('eval cls, image net loading {}...'.format(path))
        if self.args.ms:
            self.load_weights(self.net.Scale, path)
        else:
            self.load_weights(self.net, path)

    def eval(self, dataloader):
        """Run inference over an unlabeled dataloader; returns an array of
        per-image foreground-probability maps resized to 101x101.

        NOTE(review): uses the legacy `volatile=True` Variable API
        (pre-0.4 PyTorch); modern code would use torch.no_grad().
        """
        assert self.phase == 'test', "Command arg phase should be 'test'. "
        from tqdm import tqdm
        self.net.eval()
        output = []
        for i, image in tqdm(enumerate(dataloader)):
            if self.cuda:
                image = Variable(image.cuda(), volatile=True)
            else:
                image = Variable(image, volatile=True)
            # cls forward
            out = self.net(image)
            if isinstance(out, list):
                out_max = out[-1]  # deep-supervision nets: keep last scale
                if out_max.size(2) != image.size(2):
                    out = self.interp(out_max)
            else:
                if out.size(2) != image.size(2):
                    out = self.interp(out)
            # out [bs * num_tta, c, h, w]
            if self.args.use_tta:
                num_tta = len(tta_config)
                # out = F.softmax(out, dim=1)
                # Undo the TTA transforms, then average over the TTA axis.
                out = detta_score(
                    out.view(num_tta, -1, self.args.num_classes, out.size(2),
                             out.size(3)))
                # [num_tta, bs, nclass, H, W]
                out = out.mean(dim=0)
                # [bs, nclass, H, W]
            out = F.softmax(out)
            # pred[1] = probability of the foreground class.
            output.extend([
                resize(pred[1].data.cpu().numpy(), (101, 101)) for pred in out
            ])
        return np.array(output)

    def tta(self, dataloaders):
        """Sum eval() outputs over several TTA dataloaders; return argmax labels."""
        results = np.zeros(shape=(len(dataloaders[0].dataset),
                                  self.args.num_classes))
        for dataloader in dataloaders:
            output = self.eval(dataloader)
            results += output
        return np.argmax(results, 1)

    def tta_output(self, dataloaders):
        """Same accumulation as tta(), but return the raw summed scores."""
        results = np.zeros(shape=(len(dataloaders[0].dataset),
                                  self.args.num_classes))
        for dataloader in dataloaders:
            output = self.eval(dataloader)
            results += output
        return results

    def test_val(self, dataloader):
        """Evaluate on a labeled loader, then sweep binarization thresholds
        0.05..0.50 printing mIoU per threshold; returns (predict, true)."""
        assert self.phase == 'test', "Command arg phase should be 'test'. "
        from tqdm import tqdm
        self.net.eval()
        predict = []
        true = []
        t1 = time.time()
        for i, (image, mask) in tqdm(enumerate(dataloader)):
            if self.cuda:
                image = Variable(image.cuda(), volatile=True)
                label_image = Variable(mask.cuda(), volatile=True)
            else:
                image = Variable(image, volatile=True)
                label_image = Variable(mask, volatile=True)
            # cls forward
            out = self.net(image)
            if isinstance(out, list):
                out_max = out[-1]
                if out_max.size(2) != label_image.size(2):
                    out = self.interp(out_max)
            else:
                if out.size(2) != image.size(2):
                    out = self.interp(out)
            # out [bs * num_tta, c, h, w]
            if self.args.use_tta:
                num_tta = len(tta_config)
                # out = F.softmax(out, dim=1)
                out = detta_score(
                    out.view(num_tta, -1, self.args.num_classes, out.size(2),
                             out.size(3)))
                # [num_tta, bs, nclass, H, W]
                out = out.mean(dim=0)
                # [bs, nclass, H, W]
            out = F.softmax(out)
            if self.args.aug == 'heng':
                # Crop away the 11px reflective padding of this augmentation.
                out = out[:, :, 11:11 + 202, 11:11 + 202]
            predict.extend([
                resize(pred[1].data.cpu().numpy(), (101, 101)) for pred in out
            ])
            # predict.extend([pred[1, :101, :101].data.cpu().numpy() for pred in out])
            # pred.extend(out.data.cpu().numpy())
            true.extend(label_image.data.cpu().numpy())
        # pred_all = np.argmax(np.array(pred), 1)
        for t in np.arange(0.05, 0.51, 0.01):
            pred_all = np.array(predict) > t
            # NOTE(review): np.int is removed in NumPy >= 1.24; use int.
            true_all = np.array(true).astype(np.int)
            # new_iou = intersection_over_union(true_all, pred_all)
            # new_iou_t = intersection_over_union_thresholds(true_all, pred_all)
            mean_iou, iou_t = mIoU(true_all, pred_all)
            print('threshold : {:.4f}'.format(t))
            print('mean IoU : {:.4f}, IoU threshold : {:.4f}'.format(
                mean_iou, iou_t))
        return predict, true

    def run_epoch(self, dataloader, writer, epoch, train=True, metrics=True):
        """Run one training or validation epoch; optionally log metrics,
        save the best checkpoint and drive the 'step' LR policy.

        NOTE(review): uses legacy APIs — Variable(volatile=...) and
        loss.data[0] — valid only on pre-0.4 PyTorch.
        """
        if train:
            self.net.train()
            flag = 'train'
        else:
            self.net.eval()
            flag = 'val'
        t2 = time.time()
        for image, mask in dataloader:
            # Non-'step' policies adjust the LR every iteration.
            if train and self.lr_policy != 'step':
                adjust_learning_rate(self.args.lr, self.optimizer, self.iters,
                                     self.iterations * len(dataloader), 0.9,
                                     self.cyclic_m, self.lr_policy)
            self.iters += 1
            if self.cuda:
                image = Variable(image.cuda(), volatile=(not train))
                label_image = Variable(mask.cuda(), volatile=(not train))
            else:
                image = Variable(image, volatile=(not train))
                label_image = Variable(mask, volatile=(not train))
            # cls forward
            out = self.net(image)
            if isinstance(out, list):
                # Deep supervision: sum the loss over all scales, keep the
                # last (finest) output for metrics.
                out_max = None
                loss = 0.0
                for i, out_scale in enumerate(out):
                    if out_scale.size(2) != label_image.size(2):
                        out_scale = self.interp(out_scale)
                    if i == (len(out) - 1):
                        out_max = out_scale
                    loss += self.criterion(out_scale, label_image)
                label_image_np = label_image.data.cpu().numpy()
                sig_out_np = out_max.data.cpu().numpy()
                acc = accuracy(label_image_np, np.argmax(sig_out_np, 1))
                self.pred[flag].extend(sig_out_np)
                self.true[flag].extend(label_image_np)
                self.losses[flag].append(loss.data[0])
                self.acces[flag].append(acc)
            else:
                if out.size(-1) != label_image.size(-1):
                    out = self.interp(out)
                loss = self.criterion(out, label_image)
                label_image_np = label_image.data.cpu().numpy()
                sig_out_np = out.data.cpu().numpy()
                acc = accuracy(label_image_np, np.argmax(sig_out_np, 1))
                self.pred[flag].extend(sig_out_np)
                self.true[flag].extend(label_image_np)
                self.losses[flag].append(loss.data[0])
                self.acces[flag].append(acc)
            if train:
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
        if metrics:
            n = len(self.losses[flag])
            loss = sum(self.losses[flag]) / n
            scalars = [
                loss,
            ]
            names = [
                'loss',
            ]
            write_scalars(writer, scalars, names, epoch, tag=flag + '_loss')
            all_acc = sum(self.acces[flag]) / n
            scalars = [
                all_acc,
            ]
            names = [
                'all_acc',
            ]
            write_scalars(writer, scalars, names, epoch, tag=flag + '_acc')
            # all_score = sum(self.scores[flag]) / n
            # scalars = [all_score, ]
            # names = ['all_score', ]
            # write_scalars(writer, scalars, names, epoch, tag=flag + '_score')
            pred_all = np.argmax(np.array(self.pred[flag]), 1)
            true_all = np.array(self.true[flag]).astype(np.int)
            mean_iou, iou_t = mIoU(true_all, pred_all)
            # new_iou = intersection_over_union(true_all, pred_all)
            # new_iou_t = intersection_over_union_thresholds(true_all, pred_all)
            scalars = [
                mean_iou,
                iou_t,
            ]
            names = [
                'mIoU',
                'mIoU_threshold',
            ]
            write_scalars(writer, scalars, names, epoch, tag=flag + '_IoU')
            scalars = [
                self.optimizer.param_groups[0]['lr'],
            ]
            names = [
                'learning_rate',
            ]
            write_scalars(writer, scalars, names, epoch, tag=flag + '_lr')
            print(
                '{} loss: {:.4f} | acc: {:.4f} | mIoU: {:.4f} | mIoU_threshold: {:.4f} | n_iter: {} | learning_rate: {} | time: {:.2f}'
                .format(flag, loss, all_acc, mean_iou, iou_t, epoch,
                        self.optimizer.param_groups[0]['lr'],
                        time.time() - t2))
            # Clear the per-phase accumulators for the next epoch.
            self.losses[flag] = []
            self.pred[flag] = []
            self.true[flag] = []
            self.acces[flag] = []
            self.scores[flag] = []
            # Best-checkpoint logic (validation epochs only).
            if (not train) and (iou_t >= self.best_val):
                if self.args.ms:
                    if self.args.mGPUs:
                        self.save_network(self.net.module.Scale,
                                          self.args.model_name,
                                          epoch=epoch,
                                          label='best')
                    else:
                        self.save_network(self.net.Scale,
                                          self.args.model_name,
                                          epoch=epoch,
                                          label='best')
                else:
                    if self.args.mGPUs:
                        self.save_network(self.net.module,
                                          self.args.model_name,
                                          epoch=epoch,
                                          label='best')
                    else:
                        self.save_network(self.net,
                                          self.args.model_name,
                                          epoch=epoch,
                                          label='best')
                print(
                    'val improve from {:.4f} to {:.4f} saving in best val_iteration {}'
                    .format(self.best_val, iou_t, epoch))
                self.best_val = iou_t
                self.count = 0
            # 'step' policy: after 10 clearly-worse epochs, decay the LR.
            if (not train) and (self.best_val - iou_t > 0.003) and (
                    self.count < 10) and (self.lr_policy == 'step'):
                self.count += 1
            if (not train) and (self.count >= 10) and (self.lr_policy == 'step'):
                self._adjust_learning_rate(epoch)
                self.count = 0

    def train_val(self, dataloader_train, dataloader_val, writer):
        """Alternate train and validation epochs for self.iterations epochs,
        resetting best_val at every cyclic-LR cycle boundary and saving a
        checkpoint every save_freq epochs."""
        val_epoch = 0
        for epoch in range(self.iterations):
            if (self.lr_policy == 'cyclic') and (
                    epoch % int(self.iterations / self.cyclic_m) == 0):
                print('-------start cycle {}------------'.format(
                    epoch // int(self.iterations / self.cyclic_m)))
                self.best_val = 0.0
            self.run_epoch(dataloader_train, writer, epoch, train=True,
                           metrics=True)
            self.run_epoch(dataloader_val, writer, val_epoch, train=False,
                           metrics=True)
            val_epoch += 1
            if (epoch + 1) % self.args.save_freq == 0:
                if self.args.ms:
                    if self.args.mGPUs:
                        self.save_network(
                            self.net.module.Scale,
                            self.args.model_name,
                            epoch=val_epoch,
                        )
                    else:
                        self.save_network(
                            self.net.Scale,
                            self.args.model_name,
                            epoch=val_epoch,
                        )
                else:
                    if self.args.mGPUs:
                        self.save_network(
                            self.net.module,
                            self.args.model_name,
                            epoch=val_epoch,
                        )
                    else:
                        self.save_network(
                            self.net,
                            self.args.model_name,
                            epoch=val_epoch,
                        )
                print('saving in val_iteration {}'.format(val_epoch))
def train(datafile):
    """Train a 3-channel / 2-class UNet on the GIANA data pipeline.

    Runs EPOCHS epochs of cross-entropy training, logging images and the
    loss curve to a Visdom server, then scores the model on the validation
    loader after every epoch.

    Args:
        datafile: path/handle passed straight to ``giana_data_pipeline``.
    """
    # model = ResUNet(n_classes=2)
    model = UNet(n_channels=3, n_classes=2)
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        model.cuda()
    # criterion = SoftDiceLoss(batch_dice=True)
    criterion_CE = nn.CrossEntropyLoss()
    # kept for the (commented-out) combined CE + soft-Dice objective below
    criterion_SD = SoftDiceLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE,
                                momentum=0.9)
    vis = PytorchVisdomLogger(name="GIANA", port=8080)
    giana_transform, giana_train_loader, giana_valid_loader = giana_data_pipeline(
        datafile)

    for epoch in range(EPOCHS):
        for iteration, (images, labels) in enumerate(giana_train_loader):
            # print('TRAIN', images.shape, labels.shape)
            images, labels = giana_transform.apply_transform([images, labels])
            labels_onehot = make_one_hot(labels, 2)
            # FIX: squeeze only the channel dim. A bare .squeeze() also drops
            # the batch dim when batch size is 1, breaking CrossEntropyLoss.
            # (assumes labels is (N, 1, H, W) — TODO confirm against loader)
            targets = labels.squeeze(1).long()
            # for images, labels in giana_pool.imap_unordered(giana_transform.apply_transform, giana_iter):
            if use_cuda:
                images = images.cuda()
                labels_onehot = labels_onehot.cuda()
                # FIX: target was moved with an unconditional .cuda(), which
                # crashed on CPU-only machines; move it under the same guard.
                targets = targets.cuda()
            optimizer.zero_grad()
            model.train()
            predictions = model(images)
            predictions_softmax = F.softmax(predictions, dim=1)
            # loss = 0.75 * criterion_CE(predictions, targets) + 0.25 * criterion_SD(predictions_softmax, labels_onehot)
            loss = criterion_CE(predictions, targets)
            # loss = criterion_SD(predictions_softmax, labels_onehot)
            loss.backward()
            optimizer.step()

            if iteration % PRINT_AFTER_ITERATIONS == 0:
                # print('Epoch: {0}, Iteration: {1}, Loss: {2}, Valid dice score: {3}'.format(epoch, iteration, loss, score))
                print('Epoch: {0}, Iteration: {1}, Loss: {2}'.format(
                    epoch, iteration, loss))
                image_args = {'normalize': True, 'range': (0, 1)}
                # viz.show_image_grid(images=images.cpu()[:, 0, ].unsqueeze(1), name='Images_train', image_args=image_args)
                vis.show_image_grid(
                    images=predictions_softmax.cpu()[:, 0, ].unsqueeze(1),
                    name='Predictions_1', image_args=image_args)
                vis.show_image_grid(
                    images=predictions_softmax.cpu()[:, 1, ].unsqueeze(1),
                    name='Predictions_2', image_args=image_args)
                vis.show_image_grid(images=labels.cpu(), name='Ground truth')
                vis.show_value(value=loss.item(), name='Train_Loss',
                               label='Loss',
                               counter=epoch + (iteration / MAX_ITERATIONS))
            if iteration == MAX_ITERATIONS:
                break

        # end-of-epoch validation score (semantics of predict() defined by the
        # project model class — presumably a Dice score; verify there)
        score = model.predict(giana_valid_loader, SCORE_TYPE,
                              MAX_VALIDATION_ITERATIONS, vis)
        vis.show_value(value=np.asarray([score]), name='TestDiceScore',
                       label='Dice', counter=epoch)
        print(
            '\n--------------------------------------------------\nEpoch: {0}, Score: {1}, Loss: {2}\n--------------------------------------------------\n'
            .format(epoch, score, loss))
def train(train_sources, eval_source):
    """Domain-adversarial training of a UNet segmentator.

    Phase schedule (by epoch): < 50 trains the segmentator only; >= 50 also
    trains a vendor discriminator on detached outputs; >= 100 additionally
    runs an adversarial step on the encoder (with a ramping weight ``lmbd``)
    and lowers the learning rates. Saves weights and plots losses/dices.

    Args:
        train_sources: vendor names for the annotated training domains.
        eval_source: single held-out vendor used for unannotated evaluation.
    """
    path = sys.argv[1]
    dr = DataReader(path, train_sources)
    dr.read()
    print(len(dr.train.x))
    batch_size = 8
    device = torch.device('cpu')
    if torch.cuda.is_available():
        device = torch.device('cuda')

    dataset_s_train = MultiDomainDataset(dr.train.x, dr.train.y,
                                         dr.train.vendor, device,
                                         DomainAugmentation())
    dataset_s_dev = MultiDomainDataset(dr.dev.x, dr.dev.y, dr.dev.vendor,
                                       device)
    dataset_s_test = MultiDomainDataset(dr.test.x, dr.test.y, dr.test.vendor,
                                        device)
    loader_s_train = DataLoader(dataset_s_train, batch_size, shuffle=True)

    dr_eval = DataReader(path, [eval_source])
    dr_eval.read()
    dataset_eval_dev = MultiDomainDataset(dr_eval.dev.x, dr_eval.dev.y,
                                          dr_eval.dev.vendor, device)
    dataset_eval_test = MultiDomainDataset(dr_eval.test.x, dr_eval.test.y,
                                           dr_eval.test.vendor, device)

    # domain-adaptation loader mixes annotated and evaluation-domain images
    dataset_da_train = MultiDomainDataset(
        dr.train.x + dr_eval.train.x,
        dr.train.y + dr_eval.train.y,
        dr.train.vendor + dr_eval.train.vendor,
        device, DomainAugmentation())
    loader_da_train = DataLoader(dataset_da_train, batch_size, shuffle=True)

    segmentator = UNet()
    discriminator = Discriminator(n_domains=len(train_sources))
    discriminator.to(device)
    segmentator.to(device)
    sigmoid = nn.Sigmoid()
    selector = Selector()
    s_criterion = nn.BCELoss()
    d_criterion = nn.CrossEntropyLoss()
    s_optimizer = optim.AdamW(segmentator.parameters(), lr=0.0001,
                              weight_decay=0.01)
    d_optimizer = optim.AdamW(discriminator.parameters(), lr=0.001,
                              weight_decay=0.01)
    a_optimizer = optim.AdamW(segmentator.encoder.parameters(), lr=0.001,
                              weight_decay=0.01)
    lmbd = 1 / 150

    s_train_losses = []
    s_dev_losses = []
    d_train_losses = []
    eval_domain_losses = []
    train_dices = []
    dev_dices = []
    eval_dices = []
    epochs = 3

    da_loader_iter = iter(loader_da_train)
    for epoch in tqdm(range(epochs)):
        s_train_loss = 0.0
        d_train_loss = 0.0
        for index, sample in enumerate(loader_s_train):
            img = sample['image']
            target_mask = sample['target']
            da_sample = next(da_loader_iter, None)
            if epoch == 100:
                # FIX: assigning to optimizer.defaults['lr'] has no effect on
                # param groups that already exist — the intended LR drop was a
                # no-op. Update the live param groups instead.
                for group in s_optimizer.param_groups:
                    group['lr'] = 0.001
                for group in d_optimizer.param_groups:
                    group['lr'] = 0.0001
            if da_sample is None:
                # DA loader exhausted: restart it
                da_loader_iter = iter(loader_da_train)
                da_sample = next(da_loader_iter, None)

            if epoch < 50 or epoch >= 100:
                # Training step of segmentator
                predicted_activations, inner_repr = segmentator(img)
                predicted_mask = sigmoid(predicted_activations)
                s_loss = s_criterion(predicted_mask, target_mask)
                s_optimizer.zero_grad()
                s_loss.backward()
                s_optimizer.step()
                s_train_loss += s_loss.cpu().detach().numpy()

            if epoch >= 50:
                # Training step of discriminator (segmentator outputs detached
                # so only the discriminator updates)
                predicted_activations, inner_repr = segmentator(
                    da_sample['image'])
                predicted_activations = predicted_activations.clone().detach()
                inner_repr = inner_repr.clone().detach()
                predicted_vendor = discriminator(predicted_activations,
                                                 inner_repr)
                d_loss = d_criterion(predicted_vendor, da_sample['vendor'])
                d_optimizer.zero_grad()
                d_loss.backward()
                d_optimizer.step()
                d_train_loss += d_loss.cpu().detach().numpy()

            if epoch >= 100:
                # adversarial training step: encoder maximizes discriminator
                # loss (negated objective), weight lmbd ramps every batch
                predicted_mask, inner_repr = segmentator(da_sample['image'])
                predicted_vendor = discriminator(predicted_mask, inner_repr)
                a_loss = -1 * lmbd * d_criterion(predicted_vendor,
                                                 da_sample['vendor'])
                a_optimizer.zero_grad()
                a_loss.backward()
                a_optimizer.step()
                lmbd += 1 / 150

        # end-of-epoch bookkeeping on an eval-mode inference pipeline
        inference_model = nn.Sequential(segmentator, selector, sigmoid)
        inference_model.to(device)
        inference_model.eval()
        d_train_losses.append(d_train_loss / len(loader_s_train))
        s_train_losses.append(s_train_loss / len(loader_s_train))
        s_dev_losses.append(calculate_loss(dataset_s_dev, inference_model,
                                           s_criterion, batch_size))
        eval_domain_losses.append(calculate_loss(dataset_eval_dev,
                                                 inference_model,
                                                 s_criterion, batch_size))
        train_dices.append(calculate_dice(inference_model, dataset_s_train))
        dev_dices.append(calculate_dice(inference_model, dataset_s_dev))
        eval_dices.append(calculate_dice(inference_model, dataset_eval_dev))
        segmentator.train()

    date_time = datetime.now().strftime("%m%d%Y_%H%M%S")
    model_path = os.path.join(pathlib.Path(__file__).parent.absolute(),
                              "model", "weights",
                              "segmentator" + str(date_time) + ".pth")
    torch.save(segmentator.state_dict(), model_path)
    util.plot_data([(s_train_losses, 'train_losses'),
                    (s_dev_losses, 'dev_losses'),
                    (d_train_losses, 'discriminator_losses'),
                    (eval_domain_losses, 'eval_domain_losses')],
                   'losses.png')
    util.plot_dice([(train_dices, 'train_dice'), (dev_dices, 'dev_dice'),
                    (eval_dices, 'eval_dice')], 'dices.png')

    inference_model = nn.Sequential(segmentator, selector, sigmoid)
    inference_model.to(device)
    inference_model.eval()
    print('Dice on annotated: ',
          calculate_dice(inference_model, dataset_s_test))
    print('Dice on unannotated: ',
          calculate_dice(inference_model, dataset_eval_test))
def train(args):
    """Train a UNet or DeepLabV3 with MSE loss on MyDataset.

    Saves a loss curve, a per-epoch visualization of one fixed sample,
    periodic checkpoints, and the best-loss checkpoint under
    ``result/<model>/``.

    Args:
        args: namespace with model, label_type, batchsize, num_workers,
              lr and epochs attributes.
    """
    result_path = 'result/%s/' % args.model
    if not os.path.exists(result_path):
        os.makedirs(result_path)
        os.makedirs('%simage' % result_path)
        os.makedirs('%scheckpoint' % result_path)

    train_set = MyDataset('train', args.label_type, 512)
    train_loader = DataLoader(
        train_set, batch_size=args.batchsize, shuffle=True,
        num_workers=args.num_workers)
    # device = 'cuda:0'
    device = 'cuda:6' if torch.cuda.device_count() > 1 else 'cuda:0'

    out_channels = 1 if args.label_type == 'msk' else 2
    print(out_channels)
    if args.model == 'unet':
        print('using unet as model!')
        model = UNet(out_channels=out_channels)
    elif args.model == 'deeplab':
        print('using deeplab as model!')
        model = torch.hub.load('pytorch/vision:v0.9.0',
                               'deeplabv3_resnet101', pretrained=False)
    else:
        # FIX: previously only printed 'no model!' and then crashed with a
        # NameError at model.to(device); fail fast with a clear error instead.
        raise ValueError('no model!')
    model = model.to(device)
    # model = nn.DataParallel(model)

    # fixed sample used for the per-epoch visualization
    img_show = train_set.__getitem__(0)['x']
    img_show = torch.tensor(img_show).to(device).float()[None, :]

    model.train()
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    loss_list = []
    loss_best = 10
    for epo in tqdm(range(1, args.epochs + 1), ascii=True):
        epo_loss = []
        for idx, item in enumerate(train_loader):
            x = item['x'].to(device, dtype=torch.float)
            y = item['y'].to(device, dtype=torch.float)
            optimizer.zero_grad()
            if args.model == 'unet':
                pred = model(x)
            elif args.model == 'deeplab':
                pred = model(x)['out'][:, 0][:, None]
            loss = criterion(pred, y)
            epo_loss.append(loss.data.item())
            loss.backward()
            optimizer.step()

        epo_loss_mean = np.array(epo_loss).mean()
        loss_list.append(epo_loss_mean)
        plot_loss(loss_list, '%simage/loss.png' % result_path)

        # FIX: run the visualization forward pass in eval mode so a 1-image
        # batch does not update BatchNorm running statistics mid-training
        # (no-op for models without BN/dropout); restore train mode after.
        model.eval()
        with torch.no_grad():
            if args.model == 'unet':
                pred = model(img_show.clone())
            elif args.model == 'deeplab':
                pred = model(img_show.clone())['out'][:, 0][:, None]
        model.train()

        if args.label_type == 'msk':
            x = img_show[0].cpu().detach().numpy().transpose((1, 2, 0))
            y = pred[0, 0].cpu().detach().numpy()
        elif args.label_type == 'flow':
            x = img_show[0].cpu().detach().numpy().transpose((1, 2, 0))
            y = pred[0].cpu().detach().numpy().transpose((1, 2, 0))
        plt.subplot(121)
        plt.imshow(x * 255)
        plt.subplot(122)
        # FIX: for 'msk' y is 2-D (H, W), so y[:, :, 0] raised IndexError;
        # only index the channel axis when one exists ('flow').
        plt.imshow(y if y.ndim == 2 else y[:, :, 0])
        plt.savefig('%simage/%d.png' % (result_path, epo))
        plt.clf()

        if epo % 3 == 0:
            torch.save(model, '%scheckpoint/%d.pt' % (result_path, epo))
        if epo_loss_mean < loss_best:
            loss_best = epo_loss_mean
            torch.save(model, '%scheckpoint/best.pt' % (result_path))
    np.save('%sloss.npy' % result_path, np.array(loss_list))
import torch
import torch.nn as nn

from model.unet import UNet

if __name__ == '__main__':
    # ---- hyper-parameters for liver segmentation on the LITS dataset ----
    device = torch.device('cuda:0')

    LEARNING_RATE = 1e-3
    LR_DECAY_STEP = 2
    LR_DECAY_FACTOR = 0.5
    WEIGHT_DECAY = 5e-4
    BATCH_SIZE = 4
    MAX_EPOCHS = 30

    # 1 input channel (CT slice), 2 output classes
    MODEL = UNet(1, 2).to(device)
    OPTIMIZER = torch.optim.Adam(MODEL.parameters(),
                                 lr=LEARNING_RATE,
                                 weight_decay=WEIGHT_DECAY)
    LR_SCHEDULER = torch.optim.lr_scheduler.StepLR(OPTIMIZER,
                                                   step_size=LR_DECAY_STEP,
                                                   gamma=LR_DECAY_FACTOR)
    CRITERION = nn.CrossEntropyLoss().to(device)

    # ---- data locations and checkpointing ----
    tr_path_raw = 'data/tr/raw'
    tr_path_label = 'data/tr/label'
    ts_path_raw = 'data/ts/raw'
    ts_path_label = 'data/ts/label'
    checkpoints_dir = 'checkpoints'
    checkpoint_frequency = 1000

    dataloaders = make_dataloaders(tr_path_raw, tr_path_label,
                                   ts_path_raw, ts_path_label,
                                   BATCH_SIZE, n_workers=4)
    comment = 'liver_segmentation_U-Net_on_LITS_dataset_'
    # logging verbosity (iterations between reports)
    verbose_train = 1
    verbose_val = 500
def main():
    """Entry point: train a 1-channel, 1-class UNet on the CHN and MCU
    chest X-ray datasets with BCE-with-logits loss and plateau LR decay."""
    args = get_args()

    # set GPU device
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu  # default: '0'
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # set model
    model = UNet(n_channels=1, n_classes=1).to(device)
    if len(args.gpu) > 1:  # if multi-gpu
        model = torch.nn.DataParallel(model)

    img_size = args.img_size  # default: 512

    # set transforms for dataset (local imports kept at function scope, as in
    # the original, so the module imports without these optional deps)
    import torchvision.transforms as transforms
    from my_transforms import GrayScale, Resize, ToTensor, histogram_equalize, gamma_correction
    from torch.utils.data import DataLoader
    from torch.optim.lr_scheduler import ReduceLROnPlateau

    custom_transforms = transforms.Compose([
        GrayScale(),
        Resize(img_size),
        histogram_equalize(),
        gamma_correction(0.5),
        ToTensor(),
    ])

    # set Dataset and DataLoader for both domains
    chn_train = chn_dataset(split='train', transforms=custom_transforms)
    chn_val = chn_dataset(split='val', transforms=custom_transforms)
    mcu_train = mcu_dataset(split='train', transforms=custom_transforms)
    mcu_val = mcu_dataset(split='val', transforms=custom_transforms)

    dataloader = {
        'train': {
            'chn': DataLoader(dataset=chn_train, batch_size=args.batch_size,
                              num_workers=args.n_workers, shuffle=True),
            'mcu': DataLoader(dataset=mcu_train, batch_size=args.batch_size,
                              num_workers=args.n_workers, shuffle=True)
        },
        'val': {
            'chn': DataLoader(dataset=chn_val, batch_size=args.batch_size,
                              num_workers=args.n_workers),
            'mcu': DataLoader(dataset=mcu_val, batch_size=args.batch_size,
                              num_workers=args.n_workers)
        }
    }

    # checkpoint dir — FIX: exist_ok avoids the check-then-create race and
    # os.mkdir's failure when intermediate directories are missing
    checkpoint_dir = os.path.join(os.getcwd(), 'checkpoint')
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_path = args.load_model

    # set optimizer
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)

    # learning rate scheduler
    # option 1 (disabled): StepLR(optimizer, step_size=3, gamma=0.8)
    # option 2: reduce on validation-loss plateau
    scheduler = ReduceLROnPlateau(optimizer, 'min', verbose=True, patience=5)

    criterion = nn.BCEWithLogitsLoss()

    train_and_validate(net=model, criterion=criterion, optimizer=optimizer,
                       dataloader=dataloader, device=device,
                       epochs=args.epochs, scheduler=scheduler,
                       load_model=checkpoint_path)