class Solver(object): def __init__(self, model, exp_name, device, num_class, optim=torch.optim.SGD, optim_args={}, loss_func=losses.DiceLoss(), model_name='OneShotSegmentor', labels=None, num_epochs=10, log_nth=5, lr_scheduler_step_size=5, lr_scheduler_gamma=0.5, use_last_checkpoint=True, exp_dir='experiments', log_dir='logs'): self.device = device self.model = model self.model_name = model_name self.labels = labels self.num_epochs = num_epochs if torch.cuda.is_available(): self.loss_func = loss_func.cuda(device) else: self.loss_func = loss_func self.optim = optim([{ 'params': model.squeeze_conv_bn.parameters() }, { 'params': model.squeeze_conv_d1.parameters() }, { 'params': model.squeeze_conv_d2.parameters() }, { 'params': model.squeeze_conv_d3.parameters() }, { 'params': model.conditioner.parameters(), 'lr': 1e-2, 'momentum': 0.95, 'weight_decay': 0.001 }, { 'params': model.segmentor.parameters(), 'lr': 1e-2, 'momentum': 0.95, 'weight_decay': 0.001 }], **optim_args) self.scheduler = lr_scheduler.StepLR(self.optim, step_size=2, gamma=0.1) exp_dir_path = os.path.join(exp_dir, exp_name) common_utils.create_if_not(exp_dir_path) common_utils.create_if_not(os.path.join(exp_dir_path, CHECKPOINT_DIR)) self.exp_dir_path = exp_dir_path self.log_nth = log_nth self.logWriter = LogWriter(num_class, log_dir, exp_name, use_last_checkpoint, labels) self.use_last_checkpoint = use_last_checkpoint self.start_epoch = 1 self.start_iteration = 1 self.best_ds_mean = 0 self.best_ds_mean_epoch = 0 if use_last_checkpoint: self.load_checkpoint() def train(self, train_loader, test_loader): """ Train a given model with the provided data. Inputs: - train_loader: train data in torch.utils.data.DataLoader - val_loader: val data in torch.utils.data.DataLoader """ model, optim, scheduler = self.model, self.optim, self.scheduler data_loader = {'train': train_loader, 'val': test_loader} if torch.cuda.is_available(): torch.cuda.empty_cache() model.cuda(self.device) self.logWriter.log( 'START TRAINING. : model name = %s, device = %s' % (self.model_name, torch.cuda.get_device_name(self.device))) current_iteration = self.start_iteration warm_up_epoch = 5 val_old = 0 change_model = False current_model = 'seg' for epoch in range(self.start_epoch, self.num_epochs + 1): self.logWriter.log( 'train', "\n==== Epoch [ %d / %d ] START ====" % (epoch, self.num_epochs)) for phase in ['train', 'val']: self.logWriter.log("<<<= Phase: %s =>>>" % phase) loss_arr = [] input_img_list = [] y_list = [] out_list = [] condition_input_img_list = [] condition_y_list = [] if phase == 'train': model.train() scheduler.step() else: model.eval() for i_batch, sampled_batch in enumerate(data_loader[phase]): X = sampled_batch[0].type(torch.FloatTensor) y = sampled_batch[1].type(torch.LongTensor) w = sampled_batch[2].type(torch.FloatTensor) query_label = data_loader[phase].batch_sampler.query_label input1, input2, y1, y2 = split_batch( X, y, int(query_label)) condition_input = torch.mul(input1, y1.unsqueeze(1)) query_input = input2 if model.is_cuda: condition_input, query_input, y2 = condition_input.cuda( self.device, non_blocking=True), query_input.cuda( self.device, non_blocking=True), y2.cuda(self.device, non_blocking=True) output = model(condition_input, query_input) # TODO: add weights loss = self.loss_func(output, y2) optim.zero_grad() loss.backward() if phase == 'train': optim.step() if i_batch % self.log_nth == 0: self.logWriter.loss_per_iter( loss.item(), i_batch, current_iteration) current_iteration += 1 loss_arr.append(loss.item()) # batch_output = output > 0.5 _, batch_output = torch.max(F.softmax(output, dim=1), dim=1) out_list.append(batch_output.cpu()) input_img_list.append(input2.cpu()) y_list.append(y2.cpu()) condition_input_img_list.append(input1.cpu()) condition_y_list.append(y1) del X, y, w, output, batch_output, loss, input1, input2, y2 torch.cuda.empty_cache() if phase == 'val': if i_batch != len(data_loader[phase]) - 1: print("#", end='', flush=True) else: print("100%", flush=True) if phase == 'train': self.logWriter.log('saving checkpoint ....') self.save_checkpoint( { 'epoch': epoch + 1, 'start_iteration': current_iteration + 1, 'arch': self.model_name, 'state_dict': model.state_dict(), 'optimizer': optim.state_dict(), 'scheduler': scheduler.state_dict(), }, os.path.join( self.exp_dir_path, CHECKPOINT_DIR, 'checkpoint_epoch_' + str(epoch) + '.' + CHECKPOINT_EXTENSION)) with torch.no_grad(): input_img_arr = torch.cat(input_img_list) y_arr = torch.cat(y_list) out_arr = torch.cat(out_list) condition_input_img_arr = torch.cat( condition_input_img_list) condition_y_arr = torch.cat(condition_y_list) current_loss = self.logWriter.loss_per_epoch( loss_arr, phase, epoch) if phase == 'val': if epoch > warm_up_epoch: self.logWriter.log("Diff : " + str(current_loss - val_old)) change_model = (current_loss - val_old) > 0.001 if change_model and current_model == 'seg': self.logWriter.log("Setting to con") current_model = 'con' elif change_model and current_model == 'con': self.logWriter.log("Setting to seg") current_model = 'seg' val_old = current_loss index = np.random.choice(len(out_arr), 3, replace=False) self.logWriter.image_per_epoch( out_arr[index], y_arr[index], phase, epoch, additional_image=(input_img_arr[index], condition_input_img_arr[index], condition_y_arr[index])) self.logWriter.dice_score_per_epoch( phase, out_arr, y_arr, epoch) self.logWriter.log("==== Epoch [" + str(epoch) + " / " + str(self.num_epochs) + "] DONE ====") self.logWriter.log('FINISH.') self.logWriter.close() def save_checkpoint(self, state, filename): torch.save(state, filename) def load_checkpoint(self): checkpoint_path = os.path.join(self.exp_dir_path, CHECKPOINT_DIR, '*.' + CHECKPOINT_EXTENSION) list_of_files = glob.glob(checkpoint_path) if len(list_of_files) > 0: latest_file = max(list_of_files, key=os.path.getctime) self.logWriter.log( "=> loading checkpoint '{}'".format(latest_file)) checkpoint = torch.load(latest_file) self.start_epoch = checkpoint['epoch'] self.start_iteration = checkpoint['start_iteration'] self.model.load_state_dict(checkpoint['state_dict']) self.optim.load_state_dict(checkpoint['optimizer']) for state in self.optim.state.values(): for k, v in state.items(): if torch.is_tensor(v): state[k] = v.to(self.device) self.scheduler.load_state_dict(checkpoint['scheduler']) self.logWriter.log("=> loaded checkpoint '{}' (epoch {})".format( latest_file, checkpoint['epoch'])) else: self.logWriter.log("=> no checkpoint found at '{}' folder".format( os.path.join(self.exp_dir_path, CHECKPOINT_DIR)))
class Solver(object): def __init__(self, model, exp_name, device, num_class, optim=torch.optim.SGD, optim_args={}, loss_func=additional_losses.CombinedLoss(), model_name='quicknat', labels=None, num_epochs=10, log_nth=5, lr_scheduler_step_size=5, lr_scheduler_gamma=0.5, use_last_checkpoint=True, exp_dir='experiments', log_dir='logs', arch_file_path=None): self.device = device self.model = model # self.swa_model = torch.optim.swa_utils.AveragedModel(self.model) self.model_name = model_name self.labels = labels self.num_epochs = num_epochs if torch.cuda.is_available(): self.loss_func = loss_func.cuda(device) else: self.loss_func = loss_func self.optim = optim(model.parameters(), **optim_args) # self.scheduler = lr_scheduler.StepLR(self.optim, step_size=lr_scheduler_step_size, # gamma=lr_scheduler_gamma) self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optim, T_max=100) # self.swa_start = -1 #int(np.round(self.num_epochs*0.75)) # print(self.swa_start) # self.swa_scheduler = torch.optim.swa_utils.SWALR(self.optim, swa_lr=0.05) exp_dir_path = os.path.join(exp_dir, exp_name) common_utils.create_if_not(exp_dir_path) common_utils.create_if_not(os.path.join(exp_dir_path, CHECKPOINT_DIR)) self.exp_dir_path = exp_dir_path self.save_architectural_files(arch_file_path) self.log_nth = log_nth self.logWriter = LogWriter(num_class, log_dir, exp_name, use_last_checkpoint, labels) # self.wandb = wandb self.use_last_checkpoint = use_last_checkpoint self.start_epoch = 1 self.start_iteration = 1 self.best_ds_mean = 0 self.best_ds_mean_epoch = 0 if use_last_checkpoint: self.load_checkpoint() print(self.best_ds_mean, self.best_ds_mean_epoch, self.start_epoch) # TODO:Need to correct the CM and dice score calculation. def train(self, train_loader, val_loader): """ Train a given model with the provided data. Inputs: - train_loader: train data in torch.utils.data.DataLoader - val_loader: val data in torch.utils.data.DataLoader """ model, optim, scheduler = self.model, self.optim, self.scheduler # self.wandb.watch(model) # swa_model, swa_scheduler, swa_start = self.swa_model, self.swa_scheduler, self.swa_start dataloaders = { 'train': train_loader, 'val': val_loader } if torch.cuda.is_available(): torch.cuda.empty_cache() model.cuda(self.device) print('START TRAINING. : model name = %s, device = %s' % ( self.model_name, torch.cuda.get_device_name(self.device))) current_iteration = self.start_iteration for epoch in range(self.start_epoch, self.num_epochs + 1): print("\n==== Epoch [ %d / %d ] START ====" % (epoch, self.num_epochs)) for phase in ['train', 'val']: print("<<<= Phase: %s =>>>" % phase) loss_arr = [] out_list = [] y_list = [] if phase == 'train': model.train() scheduler.step() else: model.eval() for i_batch, sample_batched in enumerate(dataloaders[phase]): X = sample_batched[0].type(torch.FloatTensor) y = sample_batched[1].type(torch.LongTensor) w = sample_batched[3].type(torch.FloatTensor) wd = sample_batched[2].type(torch.FloatTensor) if model.is_cuda: X, y, w, wd = X.cuda(self.device, non_blocking=True), y.cuda(self.device, non_blocking=True), \ w.cuda(self.device, non_blocking=True), wd.cuda(self.device, non_blocking=True) output = model(X) if phase == 'val': pass loss = self.loss_func(output, y, wd, None) if phase == 'train': optim.zero_grad() loss.backward() optim.step() #scheduler.step(epoch) # if epoch > swa_start: # swa_model.update_parameters(model) # swa_scheduler.step() # else: # scheduler.step() if i_batch % self.log_nth == 0: self.logWriter.loss_per_iter(loss.item(), i_batch, current_iteration) current_iteration += 1 loss_arr.append(loss.item()) _, batch_output = torch.max(output, dim=1) out_list.append(batch_output.cpu()) y_list.append(y.cpu()) del X, y, output, batch_output, loss, wd, w torch.cuda.empty_cache() if phase == 'val': if i_batch != len(dataloaders[phase]) - 1: print("#", end='', flush=True) else: print("100%", flush=True) with torch.no_grad(): out_arr, y_arr = torch.cat(out_list), torch.cat(y_list) self.logWriter.loss_per_epoch(loss_arr, phase, epoch) index = np.random.choice(len(dataloaders[phase].dataset.X), 3, replace=False) print("index", index) val_imgs, val_labels = dataloaders[phase].dataset.getItem(index) predicted_imgs = model.predict(val_imgs, self.device) if val_imgs.shape[1] > 1: mid_slice = val_imgs.shape[1]//2 val_imgs = val_imgs[:, mid_slice, :, :] self.logWriter.image_per_epoch(val_imgs, predicted_imgs, val_labels, phase, epoch) self.logWriter.cm_per_epoch(phase, out_arr, y_arr, epoch) ds_mean = self.logWriter.dice_score_per_epoch(phase, out_arr, y_arr, epoch) if phase == 'val': if (ds_mean > self.best_ds_mean): self.best_ds_mean = ds_mean self.best_ds_mean_epoch = epoch print(out_arr.shape, epoch, ds_mean, self.best_ds_mean, self.best_ds_mean_epoch) print("==== Epoch [" + str(epoch) + " / " + str(self.num_epochs) + "] DONE ====") self.save_checkpoint({ 'epoch': epoch + 1, 'start_iteration': current_iteration + 1, 'arch': self.model_name, 'best_ds_mean': self.best_ds_mean, 'best_ds_mean_epoch': self.best_ds_mean_epoch, 'state_dict': model.state_dict(), 'optimizer': optim.state_dict(), 'scheduler': scheduler.state_dict() }, os.path.join(self.exp_dir_path, CHECKPOINT_DIR, 'checkpoint_epoch_' + str(epoch) + '.' + CHECKPOINT_EXTENSION)) # torch.optim.swa_utils.update_bn(dataloaders['train'], swa_model) # self.model = swa_model self.model = model print('FINISH.') self.logWriter.close() def save_architectural_files(self, arch_file_paths): if arch_file_paths is not None: arch_file_path, setting_path = arch_file_paths destination = os.path.join(self.exp_dir_path, ARCHITECTURE_DIR) common_utils.create_if_not(destination) arch_base = "/".join(arch_file_path.split('/')[:-1]) print(arch_file_path, arch_base, setting_path, destination+'/model.py') shutil.copy(arch_file_path, destination+'/model.py') shutil.copy(f'{arch_base}/run.py', f'{destination}/run.py') shutil.copy(f'{arch_base}/solver.py', f'{destination}/solver.py') shutil.copy(f'{arch_base}/utils/evaluator.py', f'{destination}/utils-evaluator.py') shutil.copy(f'{arch_base}/nn_common_modules/losses.py', f'{destination}/nn_common_modules-losses.py') shutil.copy(f'{arch_base}/nn_common_modules/modules.py', f'{destination}/nn_common_modules-modules.py') shutil.copy(f'{setting_path}', f'{destination}/settings.ini') else: print('No Architectural file!!!') def save_best_model(self, path): """ Save model with its parameters to the given path. Conventionally the path should end with "*.model". Inputs: - path: path string """ print('Saving model... %s' % path) print('Best Model at Epoch: ' + str(self.best_ds_mean_epoch)) print('Best Model with val Dice Score: ' + str(self.best_ds_mean)) self.load_checkpoint(self.best_ds_mean_epoch) torch.save(self.model, path) def save_checkpoint(self, state, filename): torch.save(state, filename) def load_checkpoint(self, epoch=None): if epoch is not None: checkpoint_path = os.path.join(self.exp_dir_path, CHECKPOINT_DIR, 'checkpoint_epoch_' + str(epoch) + '.' + CHECKPOINT_EXTENSION) self._load_checkpoint_file(checkpoint_path) else: all_files_path = os.path.join(self.exp_dir_path, CHECKPOINT_DIR, '*.' + CHECKPOINT_EXTENSION) list_of_files = glob.glob(all_files_path) if len(list_of_files) > 0: checkpoint_path = max(list_of_files, key=os.path.getctime) self._load_checkpoint_file(checkpoint_path) else: self.logWriter.log( "=> no checkpoint found at '{}' folder".format(os.path.join(self.exp_dir_path, CHECKPOINT_DIR))) def _load_checkpoint_file(self, file_path): self.logWriter.log("=> loading checkpoint '{}'".format(file_path)) checkpoint = torch.load(file_path) self.start_epoch = checkpoint['epoch'] if 'best_ds_mean' in checkpoint.keys(): self.best_ds_mean = checkpoint['best_ds_mean'] self.best_ds_mean_epoch = checkpoint['best_ds_mean_epoch'] self.start_iteration = checkpoint['start_iteration'] self.model.load_state_dict(checkpoint['state_dict']) self.optim.load_state_dict(checkpoint['optimizer']) for state in self.optim.state.values(): for k, v in state.items(): if torch.is_tensor(v): state[k] = v.to(self.device) self.scheduler.load_state_dict(checkpoint['scheduler']) self.logWriter.log("=> loaded checkpoint '{}' (epoch {})".format(file_path, checkpoint['epoch']))
class Solver(object): def __init__(self, model, exp_name, device, num_class, optim=torch.optim.SGD, optim_args={}, loss_func=losses.CombinedLoss(), model_name='segmentor', labels=None, num_epochs=10, log_nth=5, lr_scheduler_step_size=5, lr_scheduler_gamma=0.5, use_last_checkpoint=True, exp_dir='experiments', log_dir='logs'): self.device = device self.model = model self.model_name = model_name self.labels = labels self.num_epochs = num_epochs if torch.cuda.is_available(): self.loss_func = loss_func.cuda(device) else: self.loss_func = loss_func self.optim = optim(model.parameters(), **optim_args) self.scheduler = lr_scheduler.StepLR(self.optim, step_size=lr_scheduler_step_size, gamma=lr_scheduler_gamma) exp_dir_path = os.path.join(exp_dir, exp_name) common_utils.create_if_not(exp_dir_path) common_utils.create_if_not(os.path.join(exp_dir_path, CHECKPOINT_DIR)) self.exp_dir_path = exp_dir_path self.log_nth = log_nth self.logWriter = LogWriter(num_class, log_dir, exp_name, use_last_checkpoint, labels) self.use_last_checkpoint = use_last_checkpoint self.start_epoch = 1 self.start_iteration = 1 self.best_ds_mean = 0 self.best_ds_mean_epoch = 0 if use_last_checkpoint: self.load_checkpoint() # TODO:Need to correct the CM and dice score calculation. def train(self, train_loader, val_loader): """ Train a given model with the provided data. Inputs: - train_loader: train data in torch.utils.data.DataLoader - val_loader: val data in torch.utils.data.DataLoader """ model, optim, scheduler = self.model, self.optim, self.scheduler dataloaders = {'train': train_loader, 'val': val_loader} if torch.cuda.is_available(): torch.cuda.empty_cache() model.cuda(self.device) print('START TRAINING. : model name = %s, device = %s' % (self.model_name, torch.cuda.get_device_name(self.device))) current_iteration = self.start_iteration for epoch in range(self.start_epoch, self.num_epochs + 1): print("\n==== Epoch [ %d / %d ] START ====" % (epoch, self.num_epochs)) for phase in ['train', 'val']: print("<<<= Phase: %s =>>>" % phase) loss_arr = [] out_list = [] y_list = [] if phase == 'train': model.train() scheduler.step() else: model.eval() for i_batch, sample_batched in enumerate(dataloaders[phase]): X = sample_batched[0].type(torch.FloatTensor) y = sample_batched[1].type(torch.LongTensor) w = sample_batched[2].type(torch.FloatTensor) if model.is_cuda: X, y, w = X.cuda( self.device, non_blocking=True), y.cuda( self.device, non_blocking=True), w.cuda(self.device, non_blocking=True) output = model(X) loss = self.loss_func(output, y, w) if phase == 'train': optim.zero_grad() loss.backward() optim.step() if i_batch % self.log_nth == 0: self.logWriter.loss_per_iter( loss.item(), i_batch, current_iteration) current_iteration += 1 loss_arr.append(loss.item()) _, batch_output = torch.max(output, dim=1) out_list.append(batch_output.cpu()) y_list.append(y.cpu()) del X, y, w, output, batch_output, loss torch.cuda.empty_cache() if phase == 'val': if i_batch != len(dataloaders[phase]) - 1: print("#", end='', flush=True) else: print("100%", flush=True) with torch.no_grad(): out_arr, y_arr = torch.cat(out_list), torch.cat(y_list) self.logWriter.loss_per_epoch(loss_arr, phase, epoch) index = np.random.choice(len(dataloaders[phase].dataset.X), 3, replace=False) self.logWriter.image_per_epoch( model.predict(dataloaders[phase].dataset.X[index], self.device), dataloaders[phase].dataset.y[index], phase, epoch, dataloaders[phase].dataset.X[index]) self.logWriter.cm_per_epoch(phase, out_arr, y_arr, epoch) ds_mean = self.logWriter.dice_score_per_epoch_segmentor( phase, out_arr, y_arr, epoch) if phase == 'val': if ds_mean > self.best_ds_mean: self.best_ds_mean = ds_mean self.best_ds_mean_epoch = epoch print("==== Epoch [" + str(epoch) + " / " + str(self.num_epochs) + "] DONE ====") self.save_checkpoint( { 'epoch': epoch + 1, 'start_iteration': current_iteration + 1, 'arch': self.model_name, 'state_dict': model.state_dict(), 'optimizer': optim.state_dict(), 'best_ds_mean_epoch': self.best_ds_mean_epoch, 'scheduler': scheduler.state_dict() }, os.path.join( self.exp_dir_path, CHECKPOINT_DIR, 'checkpoint_epoch_' + str(epoch) + '.' + CHECKPOINT_EXTENSION)) print('FINISH.') self.logWriter.close() def save_best_model(self, path): """ Save model with its parameters to the given path. Conventionally the path should end with "*.model". Inputs: - path: path string """ print('Saving model... %s' % path) print("Best Epoch... " + str(self.best_ds_mean_epoch)) self.load_checkpoint(self.best_ds_mean_epoch) torch.save(self.model, path) def save_checkpoint(self, state, filename): torch.save(state, filename) def load_checkpoint(self, epoch=None): if epoch is not None: checkpoint_path = os.path.join( self.exp_dir_path, CHECKPOINT_DIR, 'checkpoint_epoch_' + str(epoch) + '.' + CHECKPOINT_EXTENSION) self._load_checkpoint_file(checkpoint_path) else: all_files_path = os.path.join(self.exp_dir_path, CHECKPOINT_DIR, '*.' + CHECKPOINT_EXTENSION) list_of_files = glob.glob(all_files_path) if len(list_of_files) > 0: checkpoint_path = max(list_of_files, key=os.path.getctime) self._load_checkpoint_file(checkpoint_path) else: self.logWriter.log( "=> no checkpoint found at '{}' folder".format( os.path.join(self.exp_dir_path, CHECKPOINT_DIR))) def _load_checkpoint_file(self, file_path): self.logWriter.log("=> loading checkpoint '{}'".format(file_path)) # checkpoint = torch.load(file_path) # checkpoint_path = os.path.join(self.exp_dir_path, CHECKPOINT_DIR, '*.' + CHECKPOINT_EXTENSION) # list_of_files = glob.glob(checkpoint_path) # if len(list_of_files) > 0: # latest_file = max(list_of_files, key=os.path.getctime) # self.logWriter.log("=> loading checkpoint '{}'".format(latest_file)) checkpoint = torch.load(file_path) self.start_epoch = checkpoint['epoch'] self.start_iteration = checkpoint['start_iteration'] self.best_ds_mean_epoch = checkpoint['best_ds_mean_epoch'] self.model.load_state_dict(checkpoint['state_dict']) self.optim.load_state_dict(checkpoint['optimizer']) for state in self.optim.state.values(): for k, v in state.items(): if torch.is_tensor(v): state[k] = v.to(self.device) self.scheduler.load_state_dict(checkpoint['scheduler']) self.logWriter.log("=> loaded checkpoint '{}' (epoch {})".format( file_path, checkpoint['epoch']))
class Solver(object): def __init__(self, model, exp_name, device, num_class, optim=torch.optim.SGD, optim_args={}, loss_func=nn.BCELoss(), model_name='OneShotSegmentor', labels=None, num_epochs=10, log_nth=5, lr_scheduler_step_size=5, lr_scheduler_gamma=0.5, use_last_checkpoint=True, exp_dir='experiments', log_dir='logs'): self.device = device self.model = model self.model_name = model_name self.labels = labels self.num_epochs = num_epochs if torch.cuda.is_available(): self.loss_func = loss_func.cuda(device) else: self.loss_func = loss_func # self.optim = optim(model.parameters(), **optim_args) self.optim_c = optim([{ 'params': model.conditioner.parameters(), 'lr': 1e-4, 'momentum': 0.99, 'weight_decay': 0.0001 }], **optim_args) self.optim_s = optim([{ 'params': model.segmentor.parameters(), 'lr': 1e-4, 'momentum': 0.99, 'weight_decay': 0.0001 }], **optim_args) # self.scheduler = lr_scheduler.StepLR(self.optim, step_size=5, # gamma=0.1) self.scheduler_s = lr_scheduler.StepLR(self.optim_s, step_size=10, gamma=0.1) self.scheduler_c = lr_scheduler.StepLR(self.optim_c, step_size=10, gamma=0.001) exp_dir_path = os.path.join(exp_dir, exp_name) common_utils.create_if_not(exp_dir_path) common_utils.create_if_not(os.path.join(exp_dir_path, CHECKPOINT_DIR)) self.exp_dir_path = exp_dir_path self.log_nth = log_nth self.logWriter = LogWriter(num_class, log_dir, exp_name, use_last_checkpoint, labels) self.use_last_checkpoint = use_last_checkpoint self.start_epoch = 1 self.start_iteration = 1 self.best_ds_mean = 0 self.best_ds_mean_epoch = 0 if use_last_checkpoint: self.load_checkpoint() def train(self, train_loader, test_loader): """ Train a given model with the provided data. Inputs: - train_loader: train data in torch.utils.data.DataLoader - val_loader: val data in torch.utils.data.DataLoader """ model, optim_c, optim_s, scheduler_c, scheduler_s = self.model, self.optim_c, self.optim_s, self.scheduler_c, self.scheduler_s data_loader = {'train': train_loader, 'val': test_loader} if torch.cuda.is_available(): torch.cuda.empty_cache() model.cuda(self.device) self.logWriter.log( 'START TRAINING. : model name = %s, device = %s' % (self.model_name, torch.cuda.get_device_name(self.device))) current_iteration = self.start_iteration warm_up_epoch = 15 val_old = 0 change_model = False current_model = 'seg' for epoch in range(self.start_epoch, self.num_epochs + 1): self.logWriter.log( 'train', "\n==== Epoch [ %d / %d ] START ====" % (epoch, self.num_epochs)) if epoch > warm_up_epoch: if current_model == 'seg': self.logWriter.log("Optimizing Segmentor") optim = optim_s elif current_model == 'con': optim = optim_c self.logWriter.log("Optimizing Conditioner") for phase in ['train', 'val']: self.logWriter.log("<<<= Phase: %s =>>>" % phase) loss_arr = [] input_img_list = [] y_list = [] out_list = [] condition_input_img_list = [] condition_y_list = [] if phase == 'train': model.train() scheduler_c.step() scheduler_s.step() else: model.eval() for i_batch, sampled_batch in enumerate(data_loader[phase]): X = sampled_batch[0].type(torch.FloatTensor) y = sampled_batch[1].type(torch.LongTensor) w = sampled_batch[2].type(torch.FloatTensor) query_label = data_loader[phase].batch_sampler.query_label input1, input2, y1, y2 = split_batch( X, y, int(query_label)) condition_input = torch.cat((input1, y1.unsqueeze(1)), dim=1) query_input = input2 y1 = y1.type(torch.LongTensor) # TODO: Only for shaban baseline y2 = y2.type(torch.FloatTensor) if model.is_cuda: condition_input, query_input, y2, y1 = condition_input.cuda( self.device, non_blocking=True), query_input.cuda( self.device, non_blocking=True), y2.cuda( self.device, non_blocking=True), y1.cuda( self.device, non_blocking=True) weights = model.conditioner(condition_input) output = model.segmentor(query_input, weights) # TODO: add weights # loss_weights = (1, 0) if epoch < 5 else (0.5, 0.5) # loss = self.loss_func(F.softmax(output, dim=1), y2) loss = self.loss_func(torch.sigmoid(output), y2) optim_s.zero_grad() optim_c.zero_grad() loss.backward() if phase == 'train': if epoch <= warm_up_epoch: optim_s.step() optim_c.step() elif epoch > warm_up_epoch and change_model: optim.step() # # TODO: value needs to be optimized, Gradient Clipping (Optional) # if epoch > 1: # torch.nn.utils.clip_grad_norm_(model.parameters(), 0.0001) if i_batch % self.log_nth == 0: self.logWriter.loss_per_iter( loss.item(), i_batch, current_iteration) current_iteration += 1 loss_arr.append(loss.item()) batch_output = output.squeeze() > 0.5 # _, batch_output = torch.max(F.softmax(output, dim=1), dim=1) batch_output.cpu() batch_output.type(torch.FloatTensor) out_list.append(batch_output) input_img_list.append(input2.cpu()) y_list.append(y2.cpu()) condition_input_img_list.append(input1.cpu()) condition_y_list.append(y1) del X, y, w, output, batch_output, loss, input1, input2, y2 torch.cuda.empty_cache() if phase == 'val': if i_batch != len(data_loader[phase]) - 1: print("#", end='', flush=True) else: print("100%", flush=True) if phase == 'train': self.logWriter.log('saving checkpoint ....') self.save_checkpoint( { 'epoch': epoch + 1, 'start_iteration': current_iteration + 1, 'arch': self.model_name, 'state_dict': model.state_dict(), 'optimizer_c': optim_c.state_dict(), 'scheduler_c': scheduler_c.state_dict(), 'optimizer_s': optim_s.state_dict(), 'best_ds_mean_epoch': self.best_ds_mean_epoch, 'scheduler_s': scheduler_s.state_dict() }, os.path.join( self.exp_dir_path, CHECKPOINT_DIR, 'checkpoint_epoch_' + str(epoch) + '.' + CHECKPOINT_EXTENSION)) with torch.no_grad(): input_img_arr = torch.cat(input_img_list) y_arr = torch.cat(y_list) out_arr = torch.cat(out_list) condition_input_img_arr = torch.cat( condition_input_img_list) condition_y_arr = torch.cat(condition_y_list) current_loss = self.logWriter.loss_per_epoch( loss_arr, phase, epoch) if phase == 'val': if epoch > warm_up_epoch: self.logWriter.log("Diff : " + str(current_loss - val_old)) change_model = (current_loss - val_old) > 0.001 if change_model and current_model == 'seg': self.logWriter.log("Setting to con") current_model = 'con' elif change_model and current_model == 'con': self.logWriter.log("Setting to seg") current_model = 'seg' val_old = current_loss index = np.random.choice(len(out_arr), 3, replace=False) self.logWriter.image_per_epoch( out_arr[index], y_arr[index], phase, epoch, additional_image=(input_img_arr[index], condition_input_img_arr[index], condition_y_arr[index])) ds_mean = self.logWriter.dice_score_per_epoch( phase, out_arr, y_arr, epoch) if phase == 'val': if ds_mean > self.best_ds_mean: self.best_ds_mean = ds_mean self.best_ds_mean_epoch = epoch self.logWriter.log("==== Epoch [" + str(epoch) + " / " + str(self.num_epochs) + "] DONE ====") self.logWriter.log('FINISH.') self.logWriter.close() def save_checkpoint(self, state, filename): torch.save(state, filename) def save_best_model(self, path): """ Save model with its parameters to the given path. Conventionally the path should end with "*.model". Inputs: - path: path string """ print('Saving model... %s' % path) print("Best Epoch... " + str(self.best_ds_mean_epoch)) self.load_checkpoint(self.best_ds_mean_epoch) torch.save(self.model, path) def load_checkpoint(self, epoch=None): if epoch is not None: checkpoint_path = os.path.join( self.exp_dir_path, CHECKPOINT_DIR, 'checkpoint_epoch_' + str(epoch) + '.' + CHECKPOINT_EXTENSION) self._load_checkpoint_file(checkpoint_path) else: all_files_path = os.path.join(self.exp_dir_path, CHECKPOINT_DIR, '*.' + CHECKPOINT_EXTENSION) list_of_files = glob.glob(all_files_path) if len(list_of_files) > 0: checkpoint_path = max(list_of_files, key=os.path.getctime) self._load_checkpoint_file(checkpoint_path) else: self.logWriter.log( "=> no checkpoint found at '{}' folder".format( os.path.join(self.exp_dir_path, CHECKPOINT_DIR))) # def _load_checkpoint_file(self, file_path): # self.logWriter.log("=> loading checkpoint '{}'".format(file_path)) # checkpoint = torch.load(file_path) # self.start_epoch = checkpoint['epoch'] # self.start_iteration = checkpoint['start_iteration'] # self.model.load_state_dict(checkpoint['state_dict']) # self.optim.load_state_dict(checkpoint['optimizer']) # # for state in self.optim.state.values(): # for k, v in state.items(): # if torch.is_tensor(v): # state[k] = v.to(self.device) # # self.scheduler.load_state_dict(checkpoint['scheduler']) # self.logWriter.log("=> loaded checkpoint '{}' (epoch {})".format(file_path, checkpoint['epoch'])) def _load_checkpoint_file(self, file_path): self.logWriter.log("=> loading checkpoint '{}'".format(file_path)) # checkpoint = torch.load(file_path) # checkpoint_path = os.path.join(self.exp_dir_path, CHECKPOINT_DIR, '*.' + CHECKPOINT_EXTENSION) # list_of_files = glob.glob(checkpoint_path) # if len(list_of_files) > 0: # latest_file = max(list_of_files, key=os.path.getctime) # self.logWriter.log("=> loading checkpoint '{}'".format(latest_file)) checkpoint = torch.load(file_path) self.start_epoch = checkpoint['epoch'] self.start_iteration = checkpoint['start_iteration'] self.best_ds_mean_epoch = checkpoint['best_ds_mean_epoch'] self.model.load_state_dict(checkpoint['state_dict']) self.optim_c.load_state_dict(checkpoint['optimizer_c']) self.optim_s.load_state_dict(checkpoint['optimizer_s']) for state in self.optim_c.state.values(): for k, v in state.items(): if torch.is_tensor(v): state[k] = v.to(self.device) for state in self.optim_s.state.values(): for k, v in state.items(): if torch.is_tensor(v): state[k] = v.to(self.device) self.scheduler_c.load_state_dict(checkpoint['scheduler_c']) self.scheduler_s.load_state_dict(checkpoint['scheduler_s']) self.logWriter.log("=> loaded checkpoint '{}' (epoch {})".format( file_path, checkpoint['epoch']))