def predict(self, X, y, query_label, device=0, enable_dropout=False):
    """Predict a binary segmentation mask for a query volume.

    Inputs:
    - X: Volume to be predicted
    - y: Label volume used to build the support/query split
    - query_label: class label selecting the query partition
    - device: CUDA device index the inputs are moved to (default 0)
    - enable_dropout: when True, keeps dropout active at test time
      (Monte-Carlo style) via self.enable_test_dropout()

    Returns a squeezed numpy boolean array where the network output
    exceeds the 0.5 threshold.

    NOTE(review): split_batch is unpacked into 3 values here but into 4
    (input1, input2, y1, y2) in the train loops — confirm both call sites
    target the same helper signature.
    """
    self.eval()
    if enable_dropout:
        self.enable_test_dropout()

    support_vol, query_vol, query_lbl = split_batch(X, y, query_label)
    support_vol = to_cuda(support_vol, device)
    query_vol = to_cuda(query_vol, device)
    # query_lbl is moved to the device but not consumed below; kept for parity
    # with the original implementation.
    query_lbl = to_cuda(query_lbl, device)

    with torch.no_grad():
        logits = self.forward(support_vol, query_vol)

    # max_val, idx = torch.max(out, 1)
    mask = (logits > 0.5).data.cpu().numpy()
    prediction = np.squeeze(mask)
    # Drop large intermediates eagerly to release GPU/host memory.
    del X, logits, mask
    return prediction
def train(self, train_loader, test_loader):
    """
    Train a given model with the provided data.

    Inputs:
    - train_loader: train data in torch.utils.data.DataLoader
    - val_loader: val data in torch.utils.data.DataLoader

    Runs self.num_epochs epochs; each epoch does a 'train' pass (with
    optimizer updates and a checkpoint save) and a 'val' pass (metrics
    only). Per-epoch losses, sample images and dice scores are sent to
    self.logWriter.
    """
    model, optim, scheduler = self.model, self.optim, self.scheduler
    data_loader = {'train': train_loader, 'val': test_loader}

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        model.cuda(self.device)

    self.logWriter.log(
        'START TRAINING. : model name = %s, device = %s' %
        (self.model_name, torch.cuda.get_device_name(self.device)))

    current_iteration = self.start_iteration
    warm_up_epoch = 5        # epochs to wait before the seg/con switch logic runs
    val_old = 0              # previous epoch's validation loss
    change_model = False
    # NOTE(review): current_model toggles between 'seg' and 'con' below but is
    # never read to alter optimization in this method — confirm intended use.
    current_model = 'seg'

    for epoch in range(self.start_epoch, self.num_epochs + 1):
        self.logWriter.log(
            'train',
            "\n==== Epoch [ %d / %d ] START ====" % (epoch, self.num_epochs))
        for phase in ['train', 'val']:
            self.logWriter.log("<<<= Phase: %s =>>>" % phase)
            loss_arr = []
            input_img_list = []
            y_list = []
            out_list = []
            condition_input_img_list = []
            condition_y_list = []
            if phase == 'train':
                model.train()
                # NOTE(review): scheduler.step() before any optim.step() in the
                # epoch triggers PyTorch's ordering warning; left unchanged to
                # preserve the existing LR schedule.
                scheduler.step()
            else:
                model.eval()
            for i_batch, sampled_batch in enumerate(data_loader[phase]):
                X = sampled_batch[0].type(torch.FloatTensor)
                y = sampled_batch[1].type(torch.LongTensor)
                w = sampled_batch[2].type(torch.FloatTensor)

                query_label = data_loader[phase].batch_sampler.query_label
                input1, input2, y1, y2 = split_batch(X, y, int(query_label))

                # Conditioner input: support image masked by its label map.
                condition_input = torch.mul(input1, y1.unsqueeze(1))
                query_input = input2
                if model.is_cuda:
                    condition_input, query_input, y2 = \
                        condition_input.cuda(self.device, non_blocking=True), \
                        query_input.cuda(self.device, non_blocking=True), \
                        y2.cuda(self.device, non_blocking=True)

                # BUGFIX: the original ran optim.zero_grad()/loss.backward() in
                # BOTH phases, building and backpropagating the autograd graph
                # during validation for no benefit (wasted time and GPU memory).
                # Gradients are now only tracked and applied in 'train'; weight
                # updates are unchanged because optim.step() was already guarded.
                with torch.set_grad_enabled(phase == 'train'):
                    output = model(condition_input, query_input)
                    # TODO: add weights
                    loss = self.loss_func(output, y2)

                if phase == 'train':
                    optim.zero_grad()
                    loss.backward()
                    optim.step()
                    if i_batch % self.log_nth == 0:
                        self.logWriter.loss_per_iter(
                            loss.item(), i_batch, current_iteration)
                    current_iteration += 1

                loss_arr.append(loss.item())

                # batch_output = output > 0.5
                _, batch_output = torch.max(F.softmax(output, dim=1), dim=1)
                out_list.append(batch_output.cpu())
                input_img_list.append(input2.cpu())
                y_list.append(y2.cpu())
                condition_input_img_list.append(input1.cpu())
                condition_y_list.append(y1)

                # Release per-batch tensors before the next iteration.
                del X, y, w, output, batch_output, loss, input1, input2, y2
                torch.cuda.empty_cache()

                if phase == 'val':
                    # Simple textual progress bar for the validation pass.
                    if i_batch != len(data_loader[phase]) - 1:
                        print("#", end='', flush=True)
                    else:
                        print("100%", flush=True)

            if phase == 'train':
                self.logWriter.log('saving checkpoint ....')
                self.save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'start_iteration': current_iteration + 1,
                        'arch': self.model_name,
                        'state_dict': model.state_dict(),
                        'optimizer': optim.state_dict(),
                        'scheduler': scheduler.state_dict(),
                    },
                    os.path.join(
                        self.exp_dir_path, CHECKPOINT_DIR,
                        'checkpoint_epoch_' + str(epoch) + '.' +
                        CHECKPOINT_EXTENSION))

            with torch.no_grad():
                input_img_arr = torch.cat(input_img_list)
                y_arr = torch.cat(y_list)
                out_arr = torch.cat(out_list)
                condition_input_img_arr = torch.cat(condition_input_img_list)
                condition_y_arr = torch.cat(condition_y_list)

                current_loss = self.logWriter.loss_per_epoch(
                    loss_arr, phase, epoch)

                if phase == 'val':
                    # After warm-up, flip the seg/con flag whenever validation
                    # loss worsens by more than 0.001 vs the previous epoch.
                    if epoch > warm_up_epoch:
                        self.logWriter.log(
                            "Diff : " + str(current_loss - val_old))
                        change_model = (current_loss - val_old) > 0.001
                        if change_model and current_model == 'seg':
                            self.logWriter.log("Setting to con")
                            current_model = 'con'
                        elif change_model and current_model == 'con':
                            self.logWriter.log("Setting to seg")
                            current_model = 'seg'
                    val_old = current_loss

                # Log 3 random samples from this epoch for visual inspection.
                index = np.random.choice(len(out_arr), 3, replace=False)
                self.logWriter.image_per_epoch(
                    out_arr[index], y_arr[index], phase, epoch,
                    additional_image=(input_img_arr[index],
                                      condition_input_img_arr[index],
                                      condition_y_arr[index]))
                self.logWriter.dice_score_per_epoch(
                    phase, out_arr, y_arr, epoch)

        self.logWriter.log("==== Epoch [" + str(epoch) + " / " +
                           str(self.num_epochs) + "] DONE ====")
    self.logWriter.log('FINISH.')
    self.logWriter.close()
def train(self, train_loader, test_loader):
    """
    Train a given model with the provided data.

    Inputs:
    - train_loader: train data in torch.utils.data.DataLoader
    - val_loader: val data in torch.utils.data.DataLoader

    Two-optimizer variant: the conditioner (optim_c) and segmentor
    (optim_s) are updated on an alternating, epoch-indexed schedule.
    Checkpoints store both optimizers and both schedulers.
    """
    model, optim_c, optim_s, scheduler_c, scheduler_s = \
        self.model, self.optim_c, self.optim_s, self.scheduler_c, self.scheduler_s
    data_loader = {
        'train': train_loader,
        'val': test_loader
    }

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        model.cuda(self.device)

    print('START TRAINING. : model name = %s, device = %s' % (
        self.model_name, torch.cuda.get_device_name(self.device)))
    current_iteration = self.start_iteration

    for epoch in range(self.start_epoch, self.num_epochs + 1):
        print("\n==== Epoch [ %d / %d ] START ====" % (epoch, self.num_epochs))
        for phase in ['train', 'val']:
            print("<<<= Phase: %s =>>>" % phase)
            loss_arr = []
            input_img_list = []
            y_list = []
            out_list = []
            condition_input_img_list = []
            condition_y_list = []
            if phase == 'train':
                model.train()
                # NOTE(review): scheduler.step() before any optimizer step in
                # the epoch triggers PyTorch's ordering warning; left unchanged
                # to preserve the existing LR schedules.
                scheduler_c.step()
                scheduler_s.step()
            else:
                model.eval()
            for i_batch, sampled_batch in enumerate(data_loader[phase]):
                X = sampled_batch[0].type(torch.FloatTensor)
                y = sampled_batch[1].type(torch.LongTensor)
                w = sampled_batch[2].type(torch.FloatTensor)

                query_label = data_loader[phase].batch_sampler.query_label
                input1, input2, y1, y2 = split_batch(X, y, int(query_label))

                # Conditioner input: support image masked by its label map.
                condition_input = torch.mul(input1, y1.unsqueeze(1))
                if model.is_cuda:
                    condition_input, input2, y2 = \
                        condition_input.cuda(self.device, non_blocking=True), \
                        input2.cuda(self.device, non_blocking=True), \
                        y2.cuda(self.device, non_blocking=True)

                # BUGFIX: the original ran both zero_grad() calls and
                # loss.backward() in BOTH phases, backpropagating during
                # validation for no benefit (wasted time and GPU memory).
                # Gradients are now only tracked and applied in 'train';
                # weight updates are unchanged because the step() calls were
                # already guarded by phase.
                with torch.set_grad_enabled(phase == 'train'):
                    weights = model.conditioner(condition_input)
                    output = model.segmentor(input2, weights)
                    # TODO: add weights
                    loss = self.loss_func(output, y2)

                if phase == 'train':
                    optim_s.zero_grad()
                    optim_c.zero_grad()
                    loss.backward()
                    # Alternating schedule: both nets in epoch 0-1, then
                    # segmentor-only / conditioner-only epochs.
                    # NOTE(review): for epoch > 10 neither optimizer steps, so
                    # the model is effectively frozen — confirm this is intended.
                    if epoch <= 1:
                        optim_s.step()
                        optim_c.step()
                    elif epoch in [2, 3, 6, 7, 10]:
                        optim_s.step()
                    elif epoch in [4, 5, 8, 9]:
                        optim_c.step()

                    # # TODO: value needs to be optimized, Gradient Clipping (Optional)
                    # if epoch > 1:
                    #     torch.nn.utils.clip_grad_norm_(model.parameters(), 0.0001)

                    if i_batch % self.log_nth == 0:
                        self.logWriter.loss_per_iter(
                            loss.item(), i_batch, current_iteration)
                    current_iteration += 1

                loss_arr.append(loss.item())

                # batch_output = output > 0.5
                _, batch_output = torch.max(output, dim=1)
                out_list.append(batch_output.cpu())
                input_img_list.append(input2.cpu())
                y_list.append(y2.cpu())
                condition_input_img_list.append(input1.cpu())
                condition_y_list.append(y1)

                # Release per-batch tensors before the next iteration.
                del X, y, w, output, batch_output, loss, input1, input2, y2
                torch.cuda.empty_cache()

                if phase == 'val':
                    # Simple textual progress bar for the validation pass.
                    if i_batch != len(data_loader[phase]) - 1:
                        print("#", end='', flush=True)
                    else:
                        print("100%", flush=True)

            if phase == 'train':
                print('saving checkpoint ....')
                self.save_checkpoint({
                    'epoch': epoch + 1,
                    'start_iteration': current_iteration + 1,
                    'arch': self.model_name,
                    'state_dict': model.state_dict(),
                    'optimizer_c': optim_c.state_dict(),
                    'scheduler_c': scheduler_c.state_dict(),
                    'optimizer_s': optim_s.state_dict(),
                    'scheduler_s': scheduler_s.state_dict()
                }, os.path.join(self.exp_dir_path, CHECKPOINT_DIR,
                                'checkpoint_epoch_' + str(epoch) + '.' +
                                CHECKPOINT_EXTENSION))

            with torch.no_grad():
                input_img_arr = torch.cat(input_img_list)
                y_arr = torch.cat(y_list)
                out_arr = torch.cat(out_list)
                condition_input_img_arr = torch.cat(condition_input_img_list)
                condition_y_arr = torch.cat(condition_y_list)

                self.logWriter.loss_per_epoch(loss_arr, phase, epoch)

                # Log 3 random samples from this epoch for visual inspection.
                index = np.random.choice(len(out_arr), 3, replace=False)
                self.logWriter.image_per_epoch(
                    out_arr[index], y_arr[index], phase, epoch,
                    additional_image=(input_img_arr[index],
                                      condition_input_img_arr[index],
                                      condition_y_arr[index]))
                self.logWriter.dice_score_per_epoch(phase, out_arr, y_arr, epoch)

        print("==== Epoch [" + str(epoch) + " / " + str(self.num_epochs) +
              "] DONE ====")
    print('FINISH.')
    self.logWriter.close()