def __init__(self, input_modalites, output_channels, base_channel, is_pretrain=False):
    super(CenterLSTMDecoder, self).__init__(input_modalites, output_channels, base_channel)

    if is_pretrain:
        # load the pretrained U-Net backbone and freeze its weights
        backbone = init_U_Net(input_modalites, output_channels, base_channel, softmax=False)
        ckp_path = 'best_newdata/UNet-p64-b4-newdata-oriinput_best_model.pth.tar'
        backbone = WrappedModel(backbone)
        checkpoint = torch.load(ckp_path, map_location=torch.device('cpu'))
        for param in backbone.parameters():
            param.requires_grad = False
        backbone.load_state_dict(checkpoint['model_state_dict'])

        # reuse the decoder (upsampling) path of the pretrained backbone
        self.up_sample_1 = backbone.module.up_sample_1
        self.up_sample_2 = backbone.module.up_sample_2
        self.up_sample_3 = backbone.module.up_sample_3
        self.up_conv1 = backbone.module.up_conv1
        self.up_conv2 = backbone.module.up_conv2
        self.up_conv3 = backbone.module.up_conv3
        self.out = backbone.module.out

        # only the output layer is re-initialized and kept trainable
        for param in self.out.parameters():
            param.requires_grad = True
        nn.init.kaiming_normal_(self.out.weight, mode='fan_out', nonlinearity='leaky_relu')
def __init__(self, input_modalites, output_channels, base_channel, is_pretrain=True):
    super(CenterLSTMEncoder, self).__init__(input_modalites, output_channels, base_channel)

    if is_pretrain:
        # load the pretrained U-Net backbone and freeze its weights
        backbone = init_U_Net(input_modalites, output_channels, base_channel, softmax=False)
        ckp_path = 'best_newdata/UNet-p64-b4-newdata-oriinput_best_model.pth.tar'
        backbone = WrappedModel(backbone)
        checkpoint = torch.load(ckp_path, map_location=torch.device('cpu'))
        for param in backbone.parameters():
            param.requires_grad = False
        backbone.load_state_dict(checkpoint['model_state_dict'])

        # reuse the encoder (downsampling) path of the pretrained backbone
        self.down_conv1 = backbone.module.down_conv1
        self.down_conv2 = backbone.module.down_conv2
        self.down_conv3 = backbone.module.down_conv3
        self.down_sample_1 = backbone.module.down_sample_1
        self.down_sample_2 = backbone.module.down_sample_2
        self.down_sample_3 = backbone.module.down_sample_3
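# Hedged refactoring sketch (not from the original repo): the decoder and encoder constructors
# above, and ShortcutLSTMBody below, all repeat the same "load a frozen pretrained U-Net" steps.
# A small helper such as this one could centralize that logic; the function name and the default
# checkpoint path argument are assumptions for illustration only.
def load_frozen_unet_backbone(input_modalites, output_channels, base_channel,
                              ckp_path='best_newdata/UNet-p64-b4-newdata-oriinput_best_model.pth.tar'):
    backbone = init_U_Net(input_modalites, output_channels, base_channel, softmax=False)
    backbone = WrappedModel(backbone)
    checkpoint = torch.load(ckp_path, map_location=torch.device('cpu'))
    backbone.load_state_dict(checkpoint['model_state_dict'])
    for param in backbone.parameters():
        param.requires_grad = False
    # callers pick the layers they need from the unwrapped module, e.g. backbone.down_conv1
    return backbone.module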
def predict_use(args):
    model_name = args.model_name
    patient_path = args.patient_path

    config_file = 'config.yaml'
    cfg = load_config(config_file)
    input_modalites = int(cfg['PARAMETERS']['input_modalites'])
    output_channels = int(cfg['PARAMETERS']['output_channels'])
    base_channels = int(cfg['PARAMETERS']['base_channels'])
    patience = int(cfg['PARAMETERS']['patience'])

    ROOT = cfg['PATH']['root']
    best_dir = cfg['PATH']['best_model_path']
    best_model_dir = os.path.join(ROOT, best_dir)

    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    # load best trained model
    net = init_U_Net(input_modalites, output_channels, base_channels)
    net.to(device)
    optimizer = optim.Adam(net.parameters())
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, verbose=True, patience=patience)
    ckp_path = os.path.join(best_model_dir, model_name + '_best_model.pth.tar')
    net, _, _, _, _, _ = load_ckp(ckp_path, net, optimizer, scheduler)

    # predict
    predict(net, model_name, patient_path, ROOT, save_mask=True)
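# Hedged usage sketch (not from the original repo): predict_use() only reads args.model_name and
# args.patient_path, so it can also be driven programmatically with a simple namespace object.
# The model name below matches the checkpoint filename seen earlier in this file; the patient
# directory is a placeholder, and the helper name is an assumption.
from argparse import Namespace

def _predict_example():
    args = Namespace(model_name='UNet-p64-b4-newdata-oriinput',
                     patient_path='path/to/patient_folder')
    predict_use(args)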
def __init__(self, input_modalites, output_channels, base_channel, num_layers, num_connects,
             pad_method='pad', conv_type='plain', softmax=True, is_pretrain=True):
    super(ShortcutLSTMBody, self).__init__(input_modalites, output_channels, base_channel, pad_method, softmax)

    self.input_modalites = input_modalites
    self.output_channels = output_channels
    self.base_channel = base_channel
    self.pad_method = pad_method
    self.softmax = softmax
    self.num_layers = num_layers
    self.conv_type = conv_type
    self.num_connects = num_connects

    if is_pretrain:
        # load the pretrained U-Net backbone and freeze its weights
        backbone = init_U_Net(input_modalites, output_channels, base_channel, softmax=False)
        ckp_path = 'best_newdata/UNet-p64-b4-newdata-oriinput_best_model.pth.tar'
        backbone = WrappedModel(backbone)
        checkpoint = torch.load(ckp_path, map_location=torch.device('cpu'))
        for param in backbone.parameters():
            param.requires_grad = False
        backbone.load_state_dict(checkpoint['model_state_dict'])

        # reuse both the encoder and decoder paths of the pretrained backbone
        self.down_conv1 = backbone.module.down_conv1
        self.down_conv2 = backbone.module.down_conv2
        self.down_conv3 = backbone.module.down_conv3
        self.down_sample_1 = backbone.module.down_sample_1
        self.down_sample_2 = backbone.module.down_sample_2
        self.down_sample_3 = backbone.module.down_sample_3
        self.bridge = backbone.module.bridge
        self.up_sample_1 = backbone.module.up_sample_1
        self.up_sample_2 = backbone.module.up_sample_2
        self.up_sample_3 = backbone.module.up_sample_3
        self.up_conv1 = backbone.module.up_conv1
        self.up_conv2 = backbone.module.up_conv2
        self.up_conv3 = backbone.module.up_conv3
        # self.out = backbone.module.out

        # keep only the first three layers of each pretrained up-convolution block
        self.up_conv1 = nn.Sequential(*list(self.up_conv1.block)[:3])
        self.up_conv2 = nn.Sequential(*list(self.up_conv2.block)[:3])
        self.up_conv3 = nn.Sequential(*list(self.up_conv3.block)[:3])
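# Hedged inspection sketch (not part of the original class): the constructor above keeps only the
# first three layers of each pretrained up-convolution block, so it can be worth printing what
# survives the slice. The channel and layer counts passed here are illustrative assumptions, the
# helper name is invented, and the sketch assumes the pretrained checkpoint file is available.
def _inspect_truncated_up_convs():
    body = ShortcutLSTMBody(input_modalites=4, output_channels=4, base_channel=4,
                            num_layers=1, num_connects=3, is_pretrain=True)
    for name in ('up_conv1', 'up_conv2', 'up_conv3'):
        block = getattr(body, name)
        # each attribute should now be an nn.Sequential holding three layers from the backbone
        print(name, len(block), block)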
def train(args):
    torch.cuda.manual_seed(1)
    torch.manual_seed(1)

    # user defined
    model_name = args.model_name
    model_loss_fn = args.loss_fn

    config_file = 'config.yaml'
    config = load_config(config_file)
    data_root = config['PATH']['data_root']
    labels = config['PARAMETERS']['labels']
    root_path = config['PATH']['root']
    model_dir = config['PATH']['model_path']
    best_dir = config['PATH']['best_model_path']
    data_class = config['PATH']['data_class']

    input_modalites = int(config['PARAMETERS']['input_modalites'])
    output_channels = int(config['PARAMETERS']['output_channels'])
    base_channel = int(config['PARAMETERS']['base_channels'])
    crop_size = int(config['PARAMETERS']['crop_size'])
    batch_size = int(config['PARAMETERS']['batch_size'])
    epochs = int(config['PARAMETERS']['epoch'])
    is_best = bool(config['PARAMETERS']['is_best'])
    is_resume = bool(config['PARAMETERS']['resume'])
    patience = int(config['PARAMETERS']['patience'])
    ignore_idx = int(config['PARAMETERS']['ignore_index'])
    early_stop_patience = int(config['PARAMETERS']['early_stop_patience'])

    # build up dirs
    model_path = os.path.join(root_path, model_dir)
    best_path = os.path.join(root_path, best_dir)
    intermidiate_data_save = os.path.join(root_path, 'train_data', model_name)
    train_info_file = os.path.join(intermidiate_data_save, '{}_train_info.json'.format(model_name))
    log_path = os.path.join(root_path, 'logfiles')

    if not os.path.exists(model_path):
        os.mkdir(model_path)
    if not os.path.exists(best_path):
        os.mkdir(best_path)
    if not os.path.exists(intermidiate_data_save):
        os.makedirs(intermidiate_data_save)
    if not os.path.exists(log_path):
        os.mkdir(log_path)

    log_name = model_name + '_' + config['PATH']['log_file']
    logger = logfile(os.path.join(log_path, log_name))
    logger.info('Dataset is loading ...')

    # split dataset
    dir_ = os.path.join(data_root, data_class)
    data_content = train_split(dir_)

    # load training set and validation set
    train_set = data_loader(data_content=data_content,
                            key='train',
                            form='LGG',
                            crop_size=crop_size,
                            batch_size=batch_size,
                            num_works=8)
    n_train = len(train_set)
    train_loader = train_set.load()

    val_set = data_loader(data_content=data_content,
                          key='val',
                          form='LGG',
                          crop_size=crop_size,
                          batch_size=batch_size,
                          num_works=8)
    logger.info('Dataset loading finished!')

    n_val = len(val_set)
    nb_batches = np.ceil(n_train / batch_size)
    n_total = n_train + n_val
    logger.info('{} images will be used in total, {} for training and {} for validation'.format(n_total, n_train, n_val))

    net = init_U_Net(input_modalites, output_channels, base_channel)
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    if torch.cuda.device_count() > 1:
        logger.info('{} GPUs available.'.format(torch.cuda.device_count()))
        net = nn.DataParallel(net)
    net.to(device)

    if model_loss_fn == 'Dice':
        criterion = DiceLoss(labels=labels, ignore_index=ignore_idx)
    elif model_loss_fn == 'CrossEntropy':
        criterion = CrossEntropyLoss(labels=labels, ignore_index=ignore_idx)
    elif model_loss_fn == 'FocalLoss':
        criterion = FocalLoss(labels=labels, ignore_index=ignore_idx)
    elif model_loss_fn == 'Dice_CE':
        criterion = Dice_CE(labels=labels, ignore_index=ignore_idx)
    elif model_loss_fn == 'Dice_FL':
        criterion = Dice_FL(labels=labels, ignore_index=ignore_idx)
    else:
        raise NotImplementedError()

    optimizer = optim.Adam(net.parameters(), lr=1e-4, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, verbose=True, patience=patience)
    net, optimizer = amp.initialize(net, optimizer, opt_level='O1')

    min_loss = float('Inf')
    early_stop_count = 0
    global_step = 0
    start_epoch = 0
    start_loss = 0
    train_info = {
        'train_loss': [],
        'val_loss': [],
        'BG_acc': [],
        'NET_acc': [],
        'ED_acc': [],
        'ET_acc': []
    }

    if is_resume:
        try:
            # checkpoints are saved under model_path (root-joined), so resume from there as well
            ckp_path = os.path.join(model_path, '{}_model_ckp.pth.tar'.format(model_name))
            net, optimizer, scheduler, start_epoch, min_loss, start_loss = load_ckp(ckp_path, net, optimizer, scheduler)

            # open previous training records
            with open(train_info_file) as f:
                train_info = json.load(f)

            logger.info('Training loss from last time is {}'.format(start_loss) + '\n' +
                        'Minimum training loss from last time is {}'.format(min_loss))
        except Exception:
            logger.warning('No checkpoint available, start training from scratch')

    # start training
    for epoch in range(start_epoch, epochs):
        # switch to train mode
        net.train()

        running_loss = 0
        dice_coeff_bg = 0
        dice_coeff_net = 0
        dice_coeff_ed = 0
        dice_coeff_et = 0

        logger.info('Training epoch {} will begin'.format(epoch + 1))

        with tqdm(total=n_train, desc=f'Epoch {epoch+1}/{epochs}', unit='patch') as pbar:
            for i, data in enumerate(train_loader, 0):
                images, segs = data['image'].to(device), data['seg'].to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                outputs = net(images)
                loss = criterion(outputs, segs)
                # loss.backward()
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                optimizer.step()

                # save the output at the beginning of each epoch to visualize it
                if i == 0:
                    in_images = images.detach().cpu().numpy()[:, 0, ...]
                    in_segs = segs.detach().cpu().numpy()
                    in_pred = outputs.detach().cpu().numpy()
                    heatmap_plot(image=in_images,
                                 mask=in_segs,
                                 pred=in_pred,
                                 name=model_name,
                                 epoch=epoch + 1)

                running_loss += loss.detach().item()
                dice_score = dice_coe(outputs.detach().cpu(), segs.detach().cpu())
                dice_coeff_bg += dice_score['BG']
                dice_coeff_ed += dice_score['ED']
                dice_coeff_et += dice_score['ET']
                dice_coeff_net += dice_score['NET']

                # show progress bar
                pbar.set_postfix(**{
                    'Training loss': loss.detach().item(),
                    'Training (avg) accuracy': dice_score['avg']
                })
                pbar.update(images.shape[0])

                global_step += 1
                if global_step % nb_batches == 0:
                    # validate
                    net.eval()
                    val_loss, val_acc = validation(net, val_set, criterion, device, batch_size)

                    train_info['train_loss'].append(running_loss / nb_batches)
                    train_info['val_loss'].append(val_loss)
                    train_info['BG_acc'].append(dice_coeff_bg / nb_batches)
                    train_info['NET_acc'].append(dice_coeff_net / nb_batches)
                    train_info['ED_acc'].append(dice_coeff_ed / nb_batches)
                    train_info['ET_acc'].append(dice_coeff_et / nb_batches)

                    # save best trained model
                    scheduler.step(running_loss / nb_batches)
                    if min_loss > val_loss:
                        min_loss = val_loss
                        is_best = True
                        early_stop_count = 0
                    else:
                        is_best = False
                        early_stop_count += 1

                    state = {
                        'epoch': epoch + 1,
                        'model_state_dict': net.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'scheduler_state_dict': scheduler.state_dict(),
                        'loss': running_loss / nb_batches,
                        'min_loss': min_loss
                    }
                    verbose = save_ckp(state,
                                       is_best,
                                       early_stop_count=early_stop_count,
                                       early_stop_patience=early_stop_patience,
                                       save_model_dir=model_path,
                                       best_dir=best_path,
                                       name=model_name)

        logger.info('The average training loss for this epoch is {}'.format(running_loss / (np.ceil(n_train / batch_size))))
        logger.info('Validation dice loss: {}; Validation (avg) accuracy: {}'.format(val_loss, val_acc))
        logger.info('The best validation loss till now is {}'.format(min_loss))

        # save the training info every epoch
        logger.info('Writing the training info into file ...')
        with open(train_info_file, 'w') as fp:
            json.dump(train_info, fp)
        loss_plot(train_info_file, name=model_name)

        if verbose:
            logger.info('The validation loss has not improved for {} epochs, training will stop here.'.format(early_stop_patience))
            break

    logger.info('finish training!')
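# Hedged entry-point sketch: the original snippet does not show how train() is launched. Assuming
# an argparse-style CLI (train() reads args.model_name and args.loss_fn), a minimal launcher could
# look like this; the flag names are assumptions, while the loss choices mirror the if/elif chain
# in train() above.
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(description='Train the 3D U-Net segmentation model')
    parser.add_argument('--model_name', type=str, required=True)
    parser.add_argument('--loss_fn', type=str, required=True,
                        choices=['Dice', 'CrossEntropy', 'FocalLoss', 'Dice_CE', 'Dice_FL'])
    args = parser.parse_args()
    train(args)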