def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq):
    if not os.path.exists('./checkpoint'):
        os.makedirs('./checkpoint')

    model.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    checkpoint_path = './checkpoint/chkpoint_colab_{}.pt'.format(epoch)

    # warm up the learning rate during the first epoch
    lr_scheduler = None
    if epoch == 0:
        warmup_factor = 1. / 1000
        warmup_iters = min(1000, len(data_loader) - 1)
        lr_scheduler = utils.warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor)

    total_loss = 0.0
    for images, targets in metric_logger.log_every(data_loader, print_freq, header):
        images = list(image.to(device) for image in images)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())
        total_loss += losses.item()

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
        loss_value = losses_reduced.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            print(loss_dict_reduced)
            sys.exit(1)

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        if lr_scheduler is not None:
            lr_scheduler.step()

        metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    tb.add_scalar('Train Loss', total_loss, epoch)

    checkpoint = {
        'epoch': epoch + 1,
        'train_loss_min': total_loss,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
    }
    utils.save_ckp(checkpoint, False, checkpoint_path, None)
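# --- Illustrative helper (not part of the original file) ---
# `utils.save_ckp(checkpoint, is_best, checkpoint_path, best_model_path)` is called above but not
# shown; a minimal sketch consistent with how it is called here and in the detection trainer below:
import shutil
import torch

def save_ckp(state, is_best, checkpoint_path, best_model_path):
    """Save a checkpoint dict; optionally copy it to the best-model path."""
    torch.save(state, checkpoint_path)
    if is_best and best_model_path is not None:
        shutil.copyfile(checkpoint_path, best_model_path)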
def train_model(model, dataloaders, criterion, optimizer, scheduler, dataset_sizes,
                checkpoint_path, num_epochs=25):
    print(f"saving to {checkpoint_path}")
    since = time.time()

    # `scaler` is assumed to be an AMP gradient scaler; the original presumably created it elsewhere
    scaler = torch.cuda.amp.GradScaler()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0
    loss_p = {'train': [], 'val': []}
    acc_p = {'train': [], 'val': []}

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()   # Set model to training mode
            else:
                model.eval()    # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            tk = tqdm(dataloaders[phase], total=len(dataloaders[phase]))
            for inputs, labels in tk:
                inputs = inputs.to(config.DEVICE)
                labels = labels.to(config.DEVICE)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward; track history only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        scaler.scale(loss).backward()
                        scaler.step(optimizer)
                        scaler.update()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]
            loss_p[phase].append(epoch_loss)
            acc_p[phase].append(epoch_acc)

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            # deep copy the model whenever the validation accuracy improves
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                checkpoint = {
                    'epoch': epoch,
                    'valid_acc': best_acc,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }
                print(f"saving to {checkpoint_path}")
                save_ckp(checkpoint, checkpoint_path)

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:.4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    plot(loss_p, acc_p, num_epochs)
    return model, best_acc
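# --- Illustrative note (not part of the original file) ---
# train_model() uses a GradScaler but never wraps the forward pass in autocast, so the forward
# still runs in full precision and the scaling brings no memory or speed benefit. A minimal
# sketch of the usual pairing, assuming mixed precision is the intent:
import torch

def amp_train_step(model, inputs, labels, criterion, optimizer, scaler):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():          # run the forward pass in mixed precision
        outputs = model(inputs)
        loss = criterion(outputs, labels)
    scaler.scale(loss).backward()            # scale the loss to avoid fp16 gradient underflow
    scaler.step(optimizer)                   # unscale gradients, then call optimizer.step()
    scaler.update()                          # adjust the scale factor for the next step
    return loss.item()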
def main(ckp_path=None):
    """ckp_path (str): checkpoint_path
    Train the model from scratch if ckp_path is None,
    else re-train the model from the previous checkpoint.
    """
    cli_args = get_train_args(__author__, __version__)

    # Variables
    data_dir = cli_args.data_dir
    save_dir = cli_args.save_dir
    file_name = cli_args.file_name
    use_gpu = cli_args.use_gpu

    # LOAD DATA
    data_loaders = load_data(data_dir, config.IMG_SIZE, config.BATCH_SIZE)

    # BUILD MODEL
    if ckp_path is None:
        model = initialize_model(model_name=config.MODEL_NAME,
                                 num_classes=config.NO_OF_CLASSES,
                                 feature_extract=True,
                                 use_pretrained=True)
    else:
        model = load_ckp(ckp_path)

    # CHECK WHICH DEVICE IS AVAILABLE
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # IF THE USER WANTS GPU MODE, CHECK WHETHER CUDA IS AVAILABLE
    if use_gpu and not torch.cuda.is_available():
        print("GPU mode is not available, using CPU...")
        use_gpu = False

    # MOVE MODEL TO AVAILABLE DEVICE
    model.to(device)

    # DEFINE OPTIMIZER
    optimizer = optimizer_fn(model_name=config.MODEL_NAME, model=model, lr_rate=config.LR_RATE)

    # DEFINE SCHEDULER
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min",
                                                           patience=5, factor=0.3, verbose=True)

    # DEFINE LOSS FUNCTION
    criterion = loss_fn()

    # KEEP A COPY OF THE BEST MODEL'S WEIGHTS
    best_model_wts = copy.deepcopy(model.state_dict())

    # BEST VALIDATION SCORE
    if ckp_path is None:
        best_score = -1                 # MODEL IS TRAINED FROM SCRATCH
    else:
        best_score = model.best_score   # MODEL IS RE-TRAINED FROM A CHECKPOINT

    # NUMBER OF EPOCHS
    no_epochs = config.EPOCHS

    # KEEP TRACK OF LOSS AND ACCURACY IN EACH EPOCH
    stats = {
        'train_losses': [],
        'valid_losses': [],
        'train_accuracies': [],
        'valid_accuracies': []
    }

    print("Model's training started ...")
    for epoch in range(1, no_epochs + 1):
        train_loss, train_score = train_fn(data_loaders, model, optimizer, criterion,
                                           device, phase='train')
        val_loss, val_score = eval_fn(data_loaders, model, criterion,
                                      device=config.DEVICE, phase='valid')
        scheduler.step(val_loss)

        # SAVE THE MODEL'S WEIGHTS IF THE VALIDATION ACCURACY HAS INCREASED
        if val_score > best_score:
            print('Validation score increased ({:.6f} --> {:.6f}). Saving model ...'
                  .format(best_score, val_score))
            best_score = val_score
            best_model_wts = copy.deepcopy(model.state_dict())  # keep the best model's weights

        # RECORD AVERAGE LOSSES AND ACCURACIES OF EACH EPOCH FOR PLOTTING
        stats['train_losses'].append(train_loss)
        stats['valid_losses'].append(val_loss)
        stats['train_accuracies'].append(train_score)
        stats['valid_accuracies'].append(val_score)

        # PRINT TRAINING AND VALIDATION LOSSES/ACCURACIES AFTER EACH EPOCH
        epoch_len = len(str(no_epochs))
        print_msg = (f'[{epoch:>{epoch_len}}/{no_epochs:>{epoch_len}}] ' + '\t' +
                     f'train_loss: {train_loss:.5f} ' + '\t' +
                     f'train_score: {train_score:.5f} ' + '\t' +
                     f'valid_loss: {val_loss:.5f} ' + '\t' +
                     f'valid_score: {val_score:.5f}')
        print(print_msg)

    # load best model weights
    model.load_state_dict(best_model_wts)

    # create checkpoint variable and add important data
    model.class_to_idx = data_loaders['train'].dataset.class_to_idx
    model.best_score = best_score
    model.model_name = config.MODEL_NAME
    checkpoint = {
        'epoch': no_epochs,
        'lr_rate': config.LR_RATE,
        'model_name': config.MODEL_NAME,
        'batch_size': config.BATCH_SIZE,
        'valid_score': best_score,
        'optimizer': optimizer.state_dict(),
        'state_dict': model.state_dict(),
        'class_to_idx': model.class_to_idx
    }

    # SAVE CHECKPOINT
    save_ckp(checkpoint, save_dir, file_name)
    print("Model's training finished successfully ...")

    return model
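# --- Illustrative helper (not part of the original file) ---
# main() assumes a load_ckp(ckp_path) that rebuilds the model from the checkpoint saved above.
# A minimal sketch matching the keys written into `checkpoint`; initialize_model and config are
# the project's own helpers, and use_pretrained=False is an assumption for re-training:
import torch

def load_ckp(ckp_path):
    checkpoint = torch.load(ckp_path, map_location='cpu')
    model = initialize_model(model_name=checkpoint['model_name'],
                             num_classes=config.NO_OF_CLASSES,
                             feature_extract=True,
                             use_pretrained=False)
    model.load_state_dict(checkpoint['state_dict'])
    model.class_to_idx = checkpoint['class_to_idx']
    model.best_score = checkpoint['valid_score']
    return model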
def train(args):
    torch.cuda.manual_seed(1)
    torch.manual_seed(1)

    # user defined parameters
    model_name = args.model_name
    model_type = args.model_type
    lstm_backbone = args.lstmbase
    unet_backbone = args.unetbase
    layer_num = args.layer_num
    nb_shortcut = args.nb_shortcut
    loss_fn = args.loss_fn
    world_size = args.world_size
    rank = args.rank
    base_channel = args.base_channels
    crop_size = args.crop_size
    ignore_idx = args.ignore_idx
    return_sequence = args.return_sequence
    variant = args.LSTM_variant
    epochs = args.epoch
    is_pretrain = args.is_pretrain

    # system setup parameters
    config_file = 'config.yaml'
    config = load_config(config_file)

    labels = config['PARAMETERS']['labels']
    root_path = config['PATH']['model_root']
    model_dir = config['PATH']['save_ckp']
    best_dir = config['PATH']['save_best_model']

    input_modalites = int(config['PARAMETERS']['input_modalites'])
    output_channels = int(config['PARAMETERS']['output_channels'])
    batch_size = int(config['PARAMETERS']['batch_size'])
    is_best = bool(config['PARAMETERS']['is_best'])
    is_resume = bool(config['PARAMETERS']['resume'])
    patience = int(config['PARAMETERS']['patience'])
    time_step = int(config['PARAMETERS']['time_step'])
    num_workers = int(config['PARAMETERS']['num_workers'])
    early_stop_patience = int(config['PARAMETERS']['early_stop_patience'])
    lr = int(config['PARAMETERS']['lr'])
    optimizer = config['PARAMETERS']['optimizer']
    connect = config['PARAMETERS']['connect']
    conv_type = config['PARAMETERS']['lstm_convtype']

    # build up dirs
    model_path = os.path.join(root_path, model_dir)
    best_path = os.path.join(root_path, best_dir)
    intermidiate_data_save = os.path.join(root_path, 'train_newdata', model_name)
    train_info_file = os.path.join(intermidiate_data_save, '{}_train_info.json'.format(model_name))
    log_path = os.path.join(root_path, 'logfiles')

    if not os.path.exists(model_path):
        os.mkdir(model_path)
    if not os.path.exists(best_path):
        os.mkdir(best_path)
    if not os.path.exists(intermidiate_data_save):
        os.makedirs(intermidiate_data_save)
    if not os.path.exists(log_path):
        os.mkdir(log_path)

    log_name = model_name + '_' + config['PATH']['log_file']
    logger = logfile(os.path.join(log_path, log_name))
    logger.info('labels {} are ignored'.format(ignore_idx))
    logger.info('Dataset is loading ...')
    writer = SummaryWriter('ProcessVisu/%s' % model_name)

    # load training set and validation set
    data_class = data_split()
    train, val, test = data_construction(data_class)
    train_dict = time_parser(train, time_patch=time_step)
    val_dict = time_parser(val, time_patch=time_step)

    # LSTM initialization
    if model_type == 'LSTM':
        net = LSTMSegNet(lstm_backbone=lstm_backbone, input_dim=input_modalites, output_dim=output_channels, hidden_dim=base_channel, kernel_size=3, num_layers=layer_num, conv_type=conv_type, return_sequence=return_sequence)
    elif model_type == 'UNet_LSTM':
        if variant == 'back':
            net = BackLSTM(input_dim=input_modalites, hidden_dim=base_channel, output_dim=output_channels, kernel_size=3, num_layers=layer_num, conv_type=conv_type, lstm_backbone=lstm_backbone, unet_module=unet_backbone, base_channel=base_channel, return_sequence=return_sequence, is_pretrain=is_pretrain)
            logger.info('the pretrained status of backbone is {}'.format(is_pretrain))
        elif variant == 'center':
            net = CenterLSTM(input_modalites=input_modalites, output_channels=output_channels, base_channel=base_channel, num_layers=layer_num, conv_type=conv_type, return_sequence=return_sequence, is_pretrain=is_pretrain)
        elif variant == 'bicenter':
            net = BiCenterLSTM(input_modalites=input_modalites, output_channels=output_channels, base_channel=base_channel, num_layers=layer_num, connect=connect, conv_type=conv_type, return_sequence=return_sequence, is_pretrain=is_pretrain)
        elif variant == 'directcenter':
            net = DirectCenterLSTM(input_modalites=input_modalites, output_channels=output_channels, base_channel=base_channel, num_layers=layer_num, conv_type=conv_type, return_sequence=return_sequence, is_pretrain=is_pretrain)
        elif variant == 'bidirectcenter':
            net = BiDirectCenterLSTM(input_modalites=input_modalites, output_channels=output_channels, base_channel=base_channel, num_layers=layer_num, connect=connect, conv_type=conv_type, return_sequence=return_sequence, is_pretrain=is_pretrain)
        elif variant == 'rescenter':
            net = ResCenterLSTM(input_modalites=input_modalites, output_channels=output_channels, base_channel=base_channel, num_layers=layer_num, conv_type=conv_type, return_sequence=return_sequence, is_pretrain=is_pretrain)
        elif variant == 'birescenter':
            net = BiResCenterLSTM(input_modalites=input_modalites, output_channels=output_channels, base_channel=base_channel, num_layers=layer_num, connect=connect, conv_type=conv_type, return_sequence=return_sequence, is_pretrain=is_pretrain)
        elif variant == 'shortcut':
            net = ShortcutLSTM(input_modalites=input_modalites, output_channels=output_channels, base_channel=base_channel, num_layers=layer_num, num_connects=nb_shortcut, conv_type=conv_type, return_sequence=return_sequence, is_pretrain=is_pretrain)
        else:
            raise NotImplementedError()

    # loss and optimizer setup
    if loss_fn == 'Dice':
        criterion = DiceLoss(labels=labels, ignore_idx=ignore_idx)
    elif loss_fn == 'GDice':
        criterion = GneralizedDiceLoss(labels=labels)
    elif loss_fn == 'WCE':
        criterion = WeightedCrossEntropyLoss(labels=labels)
    else:
        raise NotImplementedError()

    if optimizer == 'adam':
        optimizer = optim.Adam(net.parameters(), lr=0.001)
    elif optimizer == 'sgd':
        optimizer = optim.SGD(net.parameters(), momentum=0.9, lr=lr)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, verbose=True, patience=patience)

    # device setup
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    if torch.cuda.device_count() > 1:
        torch.distributed.init_process_group(backend='nccl', init_method='tcp://127.0.0.1:38366', rank=rank, world_size=world_size)

    if distributed_is_initialized():
        print('distributed is initialized')
        net.to(device)
        net = nn.parallel.DistributedDataParallel(net, find_unused_parameters=True)
    else:
        print('data parallel')
        net = nn.DataParallel(net)
        net.to(device)

    min_loss = float('Inf')
    early_stop_count = 0
    global_step = 0
    start_epoch = 0
    start_loss = 0
    train_info = {
        'train_loss': [],
        'val_loss': [],
        'label_0_acc': [],
        'label_1_acc': [],
        'label_2_acc': [],
        'label_3_acc': [],
        'label_4_acc': []
    }

    if is_resume:
        try:
            # open previous checkpoints
            ckp_path = os.path.join(model_path, '{}_model_ckp.pth.tar'.format(model_name))
            net, optimizer, scheduler, start_epoch, min_loss, start_loss = load_ckp(ckp_path, net, optimizer, scheduler)
            # open previous training records
            with open(train_info_file) as f:
                train_info = json.load(f)

            logger.info('Training loss from last time is {}'.format(start_loss) + '\n'
                        + 'Minimum training loss from last time is {}'.format(min_loss))
            logger.info('Training accuracies from last time are: label 0: {}, label 1: {}, label 2: {}, label 3: {}, label 4: {}'.format(
                train_info['label_0_acc'][-1], train_info['label_1_acc'][-1], train_info['label_2_acc'][-1],
                train_info['label_3_acc'][-1], train_info['label_4_acc'][-1]))
        except Exception:
            logger.warning('No checkpoint available, start training from scratch')

    for epoch in range(start_epoch, epochs):
        train_set = data_loader(train_dict, batch_size=batch_size, key='train', num_works=num_workers, time_step=time_step, patch=crop_size, model_type='RNN')
        n_train = len(train_set)
        val_set = data_loader(val_dict, batch_size=batch_size, key='val', num_works=num_workers, time_step=time_step, patch=crop_size, model_type='CNN')
        n_val = len(val_set)
        logger.info('Dataset loading finished!')

        nb_batches = np.ceil(n_train / batch_size)
        n_total = n_train + n_val
        logger.info('{} images will be used in total, {} for training and {} for validation'.format(n_total, n_train, n_val))

        train_loader = train_set.load()

        # setup to train mode
        net.train()
        running_loss = 0
        dice_score_label_0 = 0
        dice_score_label_1 = 0
        dice_score_label_2 = 0
        dice_score_label_3 = 0
        dice_score_label_4 = 0

        logger.info('Training epoch {} will begin'.format(epoch + 1))

        with tqdm(total=n_train, desc=f'Epoch {epoch+1}/{epochs}', unit='patch') as pbar:
            for i, data in enumerate(train_loader, 0):
                # i: patient index
                images, segs = data['image'].to(device), data['seg'].to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                outputs = net(images)
                loss = criterion(outputs, segs)
                loss.backward()
                optimizer.step()

                running_loss += loss.detach().item()

                # flatten the time dimension before computing the dice scores
                outputs = outputs.view(-1, outputs.shape[-4], outputs.shape[-3], outputs.shape[-2], outputs.shape[-1])
                segs = segs.view(-1, segs.shape[-3], segs.shape[-2], segs.shape[-1])

                _, preds = torch.max(outputs.data, 1)
                dice_score = dice(preds.data.cpu(), segs.data.cpu(), ignore_idx=None)
                dice_score_label_0 += dice_score['bg']
                dice_score_label_1 += dice_score['csf']
                dice_score_label_2 += dice_score['gm']
                dice_score_label_3 += dice_score['wm']
                dice_score_label_4 += dice_score['tm']

                # show progress bar
                pbar.set_postfix(**{'training loss': loss.detach().item(), 'Training accuracy': dice_score['avg']})
                pbar.update(images.shape[0])

                global_step += 1
                if global_step % nb_batches == 0:
                    net.eval()
                    val_loss, val_acc, val_info = validation(net, val_set, criterion, device, batch_size, ignore_idx=None, name=model_name, epoch=epoch + 1)
                    net.train()

        train_info['train_loss'].append(running_loss / nb_batches)
        train_info['val_loss'].append(val_loss)
        train_info['label_0_acc'].append(dice_score_label_0 / nb_batches)
        train_info['label_1_acc'].append(dice_score_label_1 / nb_batches)
        train_info['label_2_acc'].append(dice_score_label_2 / nb_batches)
        train_info['label_3_acc'].append(dice_score_label_3 / nb_batches)
        train_info['label_4_acc'].append(dice_score_label_4 / nb_batches)

        # save best trained model
        scheduler.step(running_loss / nb_batches)
        logger.info('Epoch: {}, LR: {}'.format(epoch + 1, optimizer.param_groups[0]['lr']))

        if min_loss > running_loss / nb_batches:
            min_loss = running_loss / nb_batches
            is_best = True
            early_stop_count = 0
        else:
            is_best = False
            early_stop_count += 1

        state = {
            'epoch': epoch + 1,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'loss': running_loss / nb_batches,
            'min_loss': min_loss
        }
        verbose = save_ckp(state, is_best, early_stop_count=early_stop_count, early_stop_patience=early_stop_patience, save_model_dir=model_path, best_dir=best_path, name=model_name)

        # summarize the training results of this epoch
        logger.info('The average training loss for this epoch is {}'.format(running_loss / nb_batches))
        logger.info('The best training loss till now is {}'.format(min_loss))
        logger.info('Validation dice loss: {}; Validation (avg) accuracy of the last timestep: {}'.format(val_loss, val_acc))

        # save the training info every epoch
        logger.info('Writing the training info into file ...')
        val_info_file = os.path.join(intermidiate_data_save, '{}_val_info.json'.format(model_name))
        with open(train_info_file, 'w') as fp:
            json.dump(train_info, fp)
        with open(val_info_file, 'w') as fp:
            json.dump(val_info, fp)

        for name, layer in net.named_parameters():
            if layer.requires_grad:
                writer.add_histogram(name + '_grad', layer.grad.cpu().data.numpy(), epoch)
                writer.add_histogram(name + '_data', layer.cpu().data.numpy(), epoch)

        if verbose:
            logger.info('The validation loss has not improved for {} epochs, training will stop here.'.format(early_stop_patience))
            break

    loss_plot(train_info_file, name=model_name)
    logger.info('finish training!')
    return
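# --- Illustrative helper (not part of the original file) ---
# The trainers above call save_ckp(state, is_best, ...) and use its return value (`verbose`) to
# trigger early stopping. A minimal sketch of a checkpoint saver with that contract, assuming the
# '{name}_model_ckp.pth.tar' naming used in the resume branches and a hypothetical
# '{name}_best_model.pth.tar' name for the best copy:
import os
import shutil
import torch

def save_ckp(state, is_best, early_stop_count, early_stop_patience, save_model_dir, best_dir, name):
    ckp_path = os.path.join(save_model_dir, '{}_model_ckp.pth.tar'.format(name))
    torch.save(state, ckp_path)
    if is_best:
        shutil.copyfile(ckp_path, os.path.join(best_dir, '{}_best_model.pth.tar'.format(name)))
    # signal the caller to stop when the loss has not improved for `early_stop_patience` epochs
    return early_stop_count >= early_stop_patience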
def train(pretrained=True):
    train_df, val_df = utils.process_csv(train_csv)
    train_set = utils.Wheatset(train_df, train_dir, phase='train')
    val_set = utils.Wheatset(val_df, train_dir, phase='validation')

    # batching: detection targets have variable sizes, so keep each sample as-is
    def collate_fn(batch):
        return tuple(zip(*batch))

    train_data_loader = DataLoader(train_set, batch_size=8, shuffle=False, num_workers=2, collate_fn=collate_fn)
    valid_data_loader = DataLoader(val_set, batch_size=8, shuffle=False, num_workers=2, collate_fn=collate_fn)

    # construct the Faster R-CNN network
    model = models.construct_models()
    if pretrained:
        WEIGHTS_FILE = '/checkpoints/bestmodel_may28.pt'
        weights = torch.load(WEIGHTS_FILE)
        model.load_state_dict(weights['state_dict'])
    model.to(device)

    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)

    # train
    num_epochs = 5
    train_loss_min = 0.9
    total_train_loss = []
    checkpoint_path = '/checkpoints/chkpoint_'
    best_model_path = '/checkpoints/bestmodel_may28.pt'

    for epoch in range(num_epochs):
        print(f'Epoch :{epoch + 1}')
        start_time = time.time()
        train_loss = []
        model.train()

        for images, targets, image_ids in train_data_loader:
            images = list(image.to(device) for image in images)
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

            loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())
            train_loss.append(losses.item())

            optimizer.zero_grad()
            losses.backward()
            optimizer.step()

        epoch_train_loss = np.mean(train_loss)
        total_train_loss.append(epoch_train_loss)
        print(f'Epoch train loss is {epoch_train_loss}')

        # create checkpoint variable and add important data
        checkpoint = {
            'epoch': epoch + 1,
            'train_loss_min': epoch_train_loss,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }

        # save checkpoint
        utils.save_ckp(checkpoint, False, checkpoint_path, best_model_path)

        # save the checkpoint as the best model if the training loss has decreased
        if epoch_train_loss <= train_loss_min:
            print('Train loss decreased ({:.6f} --> {:.6f}). Saving model ...'.format(train_loss_min, epoch_train_loss))
            utils.save_ckp(checkpoint, True, checkpoint_path, best_model_path)
            train_loss_min = epoch_train_loss

        time_elapsed = time.time() - start_time
        print('{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
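# --- Illustrative note (not part of the original file) ---
# The detection DataLoader above needs a custom collate_fn because every image carries a
# different number of boxes, so the default collate cannot stack the targets into one tensor.
# A tiny standalone example of what tuple(zip(*batch)) produces:
def _collate_fn_example():
    batch = [('img1', {'boxes': [1]}, 'id1'), ('img2', {'boxes': [2, 3]}, 'id2')]
    images, targets, image_ids = tuple(zip(*batch))
    assert images == ('img1', 'img2')
    assert image_ids == ('id1', 'id2')
    # targets stays a tuple of per-image dicts, which the torchvision detection API expects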
def train(args):
    torch.cuda.manual_seed(1)
    torch.manual_seed(1)

    # user defined
    model_name = args.model_name
    model_loss_fn = args.loss_fn

    config_file = 'config.yaml'
    config = load_config(config_file)

    data_root = config['PATH']['data_root']
    labels = config['PARAMETERS']['labels']
    root_path = config['PATH']['root']
    model_dir = config['PATH']['model_path']
    best_dir = config['PATH']['best_model_path']
    data_class = config['PATH']['data_class']

    input_modalites = int(config['PARAMETERS']['input_modalites'])
    output_channels = int(config['PARAMETERS']['output_channels'])
    base_channel = int(config['PARAMETERS']['base_channels'])
    crop_size = int(config['PARAMETERS']['crop_size'])
    batch_size = int(config['PARAMETERS']['batch_size'])
    epochs = int(config['PARAMETERS']['epoch'])
    is_best = bool(config['PARAMETERS']['is_best'])
    is_resume = bool(config['PARAMETERS']['resume'])
    patience = int(config['PARAMETERS']['patience'])
    ignore_idx = int(config['PARAMETERS']['ignore_index'])
    early_stop_patience = int(config['PARAMETERS']['early_stop_patience'])

    # build up dirs
    model_path = os.path.join(root_path, model_dir)
    best_path = os.path.join(root_path, best_dir)
    intermidiate_data_save = os.path.join(root_path, 'train_data', model_name)
    train_info_file = os.path.join(intermidiate_data_save, '{}_train_info.json'.format(model_name))
    log_path = os.path.join(root_path, 'logfiles')

    if not os.path.exists(model_path):
        os.mkdir(model_path)
    if not os.path.exists(best_path):
        os.mkdir(best_path)
    if not os.path.exists(intermidiate_data_save):
        os.makedirs(intermidiate_data_save)
    if not os.path.exists(log_path):
        os.mkdir(log_path)

    log_name = model_name + '_' + config['PATH']['log_file']
    logger = logfile(os.path.join(log_path, log_name))
    logger.info('Dataset is loading ...')

    # split dataset
    dir_ = os.path.join(data_root, data_class)
    data_content = train_split(dir_)

    # load training set and validation set
    train_set = data_loader(data_content=data_content, key='train', form='LGG', crop_size=crop_size, batch_size=batch_size, num_works=8)
    n_train = len(train_set)
    train_loader = train_set.load()

    val_set = data_loader(data_content=data_content, key='val', form='LGG', crop_size=crop_size, batch_size=batch_size, num_works=8)
    logger.info('Dataset loading finished!')

    n_val = len(val_set)
    nb_batches = np.ceil(n_train / batch_size)
    n_total = n_train + n_val
    logger.info('{} images will be used in total, {} for training and {} for validation'.format(n_total, n_train, n_val))

    net = init_U_Net(input_modalites, output_channels, base_channel)

    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    if torch.cuda.device_count() > 1:
        logger.info('{} GPUs available.'.format(torch.cuda.device_count()))
        net = nn.DataParallel(net)
    net.to(device)

    if model_loss_fn == 'Dice':
        criterion = DiceLoss(labels=labels, ignore_index=ignore_idx)
    elif model_loss_fn == 'CrossEntropy':
        criterion = CrossEntropyLoss(labels=labels, ignore_index=ignore_idx)
    elif model_loss_fn == 'FocalLoss':
        criterion = FocalLoss(labels=labels, ignore_index=ignore_idx)
    elif model_loss_fn == 'Dice_CE':
        criterion = Dice_CE(labels=labels, ignore_index=ignore_idx)
    elif model_loss_fn == 'Dice_FL':
        criterion = Dice_FL(labels=labels, ignore_index=ignore_idx)
    else:
        raise NotImplementedError()

    optimizer = optim.Adam(net.parameters(), lr=1e-4, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, verbose=True, patience=patience)

    net, optimizer = amp.initialize(net, optimizer, opt_level='O1')

    min_loss = float('Inf')
    early_stop_count = 0
    global_step = 0
    start_epoch = 0
    start_loss = 0
    train_info = {
        'train_loss': [],
        'val_loss': [],
        'BG_acc': [],
        'NET_acc': [],
        'ED_acc': [],
        'ET_acc': []
    }

    if is_resume:
        try:
            ckp_path = os.path.join(model_path, '{}_model_ckp.pth.tar'.format(model_name))
            net, optimizer, scheduler, start_epoch, min_loss, start_loss = load_ckp(ckp_path, net, optimizer, scheduler)
            # open previous training records
            with open(train_info_file) as f:
                train_info = json.load(f)
            logger.info('Training loss from last time is {}'.format(start_loss) + '\n'
                        + 'Minimum training loss from last time is {}'.format(min_loss))
        except Exception:
            logger.warning('No checkpoint available, start training from scratch')

    # start training
    for epoch in range(start_epoch, epochs):
        # setup to train mode
        net.train()
        running_loss = 0
        dice_coeff_bg = 0
        dice_coeff_net = 0
        dice_coeff_ed = 0
        dice_coeff_et = 0

        logger.info('Training epoch {} will begin'.format(epoch + 1))

        with tqdm(total=n_train, desc=f'Epoch {epoch+1}/{epochs}', unit='patch') as pbar:
            for i, data in enumerate(train_loader, 0):
                images, segs = data['image'].to(device), data['seg'].to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                outputs = net(images)
                loss = criterion(outputs, segs)
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                optimizer.step()

                # save the output at the beginning of each epoch to visualize it
                if i == 0:
                    in_images = images.detach().cpu().numpy()[:, 0, ...]
                    in_segs = segs.detach().cpu().numpy()
                    in_pred = outputs.detach().cpu().numpy()
                    heatmap_plot(image=in_images, mask=in_segs, pred=in_pred, name=model_name, epoch=epoch + 1)

                running_loss += loss.detach().item()
                dice_score = dice_coe(outputs.detach().cpu(), segs.detach().cpu())
                dice_coeff_bg += dice_score['BG']
                dice_coeff_ed += dice_score['ED']
                dice_coeff_et += dice_score['ET']
                dice_coeff_net += dice_score['NET']

                # show progress bar
                pbar.set_postfix(**{'Training loss': loss.detach().item(), 'Training (avg) accuracy': dice_score['avg']})
                pbar.update(images.shape[0])

                global_step += 1
                if global_step % nb_batches == 0:
                    # validate
                    net.eval()
                    val_loss, val_acc = validation(net, val_set, criterion, device, batch_size)

        train_info['train_loss'].append(running_loss / nb_batches)
        train_info['val_loss'].append(val_loss)
        train_info['BG_acc'].append(dice_coeff_bg / nb_batches)
        train_info['NET_acc'].append(dice_coeff_net / nb_batches)
        train_info['ED_acc'].append(dice_coeff_ed / nb_batches)
        train_info['ET_acc'].append(dice_coeff_et / nb_batches)

        # save best trained model
        scheduler.step(running_loss / nb_batches)
        if min_loss > val_loss:
            min_loss = val_loss
            is_best = True
            early_stop_count = 0
        else:
            is_best = False
            early_stop_count += 1

        state = {
            'epoch': epoch + 1,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'loss': running_loss / nb_batches,
            'min_loss': min_loss
        }
        verbose = save_ckp(state, is_best, early_stop_count=early_stop_count, early_stop_patience=early_stop_patience, save_model_dir=model_path, best_dir=best_path, name=model_name)

        logger.info('The average training loss for this epoch is {}'.format(running_loss / np.ceil(n_train / batch_size)))
        logger.info('Validation dice loss: {}; Validation (avg) accuracy: {}'.format(val_loss, val_acc))
        logger.info('The best validation loss till now is {}'.format(min_loss))

        # save the training info every epoch
        logger.info('Writing the training info into file ...')
        with open(train_info_file, 'w') as fp:
            json.dump(train_info, fp)
        loss_plot(train_info_file, name=model_name)

        if verbose:
            logger.info('The validation loss has not improved for {} epochs, training will stop here.'.format(early_stop_patience))
            break

    logger.info('finish training!')
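# --- Illustrative helper (not part of the original file) ---
# The resume branches above expect load_ckp(ckp_path, net, optimizer, scheduler) to return
# (net, optimizer, scheduler, start_epoch, min_loss, start_loss). A minimal sketch that mirrors
# the keys written into the `state` dict saved by save_ckp:
import torch

def load_ckp(ckp_path, net, optimizer, scheduler):
    checkpoint = torch.load(ckp_path, map_location='cpu')
    net.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    return (net, optimizer, scheduler,
            checkpoint['epoch'], checkpoint['min_loss'], checkpoint['loss'])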
def train(args):
    torch.cuda.manual_seed(1)
    torch.manual_seed(1)

    # user defined
    model_name = args.model_name
    model_type = args.model_type
    loss_func = args.loss
    world_size = args.world_size
    rank = args.rank
    base_channel = args.base_channels
    crop_size = args.crop_size
    ignore_idx = args.ignore_idx
    epochs = args.epoch

    # system setup
    config_file = 'config.yaml'
    config = load_config(config_file)

    labels = config['PARAMETERS']['labels']
    root_path = config['PATH']['model_root']
    model_dir = config['PATH']['save_ckp']
    best_dir = config['PATH']['save_best_model']

    output_channels = int(config['PARAMETERS']['output_channels'])
    batch_size = int(config['PARAMETERS']['batch_size'])
    is_best = bool(config['PARAMETERS']['is_best'])
    is_resume = bool(config['PARAMETERS']['resume'])
    patience = int(config['PARAMETERS']['patience'])
    time_step = int(config['PARAMETERS']['time_step'])
    num_workers = int(config['PARAMETERS']['num_workers'])
    early_stop_patience = int(config['PARAMETERS']['early_stop_patience'])
    pad_method = config['PARAMETERS']['pad_method']
    lr = int(config['PARAMETERS']['lr'])
    optimizer = config['PARAMETERS']['optimizer']
    softmax = True

    modalities = ['flair', 't1', 't1gd', 't2']
    input_modalites = len(modalities)

    # build up dirs
    model_path = os.path.join(root_path, model_dir)
    best_path = os.path.join(root_path, best_dir)
    intermidiate_data_save = os.path.join(root_path, 'train_newdata', model_name)
    train_info_file = os.path.join(intermidiate_data_save, '{}_train_info.json'.format(model_name))
    log_path = os.path.join(root_path, 'logfiles')

    if not os.path.exists(model_path):
        os.mkdir(model_path)
    if not os.path.exists(best_path):
        os.mkdir(best_path)
    if not os.path.exists(intermidiate_data_save):
        os.makedirs(intermidiate_data_save)
    if not os.path.exists(log_path):
        os.mkdir(log_path)

    log_name = model_name + '_' + config['PATH']['log_file']
    logger = logfile(os.path.join(log_path, log_name))
    logger.info('Dataset is loading ...')
    writer = SummaryWriter('ProcessVisu/%s' % model_name)
    logger.info('patch size: {}'.format(crop_size))

    # load training set and validation set
    data_class = data_split()
    train, val, test = data_construction(data_class)
    train_dict = time_parser(train, time_patch=time_step)
    val_dict = time_parser(val, time_patch=time_step)

    # model selection
    if model_type == 'UNet':
        net = init_U_Net(input_modalites, output_channels, base_channel, pad_method, softmax)
    elif model_type == 'ResUNet':
        net = ResUNet(input_modalites, output_channels, base_channel, pad_method, softmax)
    elif model_type == 'DResUNet':
        net = DResUNet(input_modalites, output_channels, base_channel, pad_method, softmax)
    elif model_type == 'direct_concat':
        net = U_Net_direct_concat(input_modalites, output_channels, base_channel, pad_method, softmax)
    elif model_type == 'Inception':
        net = Inception_UNet(input_modalites, output_channels, base_channel, softmax)
    elif model_type == 'Simple_Inception':
        net = Simplified_Inception_UNet(input_modalites, output_channels, base_channel, softmax)

    # device setup
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    net.to(device)

    # print model structure
    summary(net, input_size=(input_modalites, crop_size, crop_size, crop_size))
    dummy_input = torch.rand(1, input_modalites, crop_size, crop_size, crop_size).to(device)
    writer.add_graph(net, (dummy_input,))

    # loss and optimizer setup
    if loss_func == 'Dice' and softmax:
        criterion = DiceLoss(labels=labels, ignore_idx=ignore_idx)
    elif loss_func == 'GDice' and softmax:
        criterion = GneralizedDiceLoss(labels=labels)
    elif loss_func == 'CrossEntropy':
        criterion = WeightedCrossEntropyLoss(labels=labels)
        if not softmax:
            criterion = nn.CrossEntropyLoss().cuda()
    else:
        raise NotImplementedError()

    if optimizer == 'adam':
        optimizer = optim.Adam(net.parameters())
    elif optimizer == 'sgd':
        optimizer = optim.SGD(net.parameters(), momentum=0.9, lr=lr, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, verbose=True, patience=patience)

    if torch.cuda.device_count() > 1:
        logger.info('{} GPUs available'.format(torch.cuda.device_count()))
        torch.distributed.init_process_group(backend='nccl', init_method='tcp://127.0.0.1:38366', rank=rank, world_size=world_size)

    if distributed_is_initialized():
        logger.info('distributed is initialized')
        net.to(device)
        net = nn.parallel.DistributedDataParallel(net)
    else:
        logger.info('data parallel')
        net = nn.DataParallel(net)
        net.to(device)

    min_loss = float('Inf')
    early_stop_count = 0
    global_step = 0
    start_epoch = 0
    start_loss = 0
    train_info = {
        'train_loss': [],
        'val_loss': [],
        'label_0_acc': [],
        'label_1_acc': [],
        'label_2_acc': [],
        'label_3_acc': [],
        'label_4_acc': []
    }

    if is_resume:
        try:
            # open previous checkpoints
            ckp_path = os.path.join(model_path, '{}_model_ckp.pth.tar'.format(model_name))
            net, optimizer, scheduler, start_epoch, min_loss, start_loss = load_ckp(ckp_path, net, optimizer, scheduler)
            # open previous training records
            with open(train_info_file) as f:
                train_info = json.load(f)
            logger.info('Training loss from last time is {}'.format(start_loss) + '\n'
                        + 'Minimum training loss from last time is {}'.format(min_loss))
            logger.info('Training accuracies from last time are: label 0: {}, label 1: {}, label 2: {}, label 3: {}, label 4: {}'.format(
                train_info['label_0_acc'][-1], train_info['label_1_acc'][-1], train_info['label_2_acc'][-1],
                train_info['label_3_acc'][-1], train_info['label_4_acc'][-1]))
        except Exception:
            logger.warning('No checkpoint available, start training from scratch')

    # start training
    for epoch in range(start_epoch, epochs):
        # every epoch generates a new set of images
        train_set = data_loader(train_dict, batch_size=batch_size, key='train', num_works=num_workers, time_step=time_step, patch=crop_size, modalities=modalities, model_type='CNN')
        n_train = len(train_set)
        train_loader = train_set.load()

        val_set = data_loader(val_dict, batch_size=batch_size, key='val', num_works=num_workers, time_step=time_step, patch=crop_size, modalities=modalities, model_type='CNN')
        n_val = len(val_set)

        nb_batches = np.ceil(n_train / batch_size)
        n_total = n_train + n_val
        logger.info('{} images will be used in total, {} for training and {} for validation'.format(n_total, n_train, n_val))
        logger.info('Dataset loading finished!')

        # setup to train mode
        net.train()
        running_loss = 0
        dice_score_label_0 = 0
        dice_score_label_1 = 0
        dice_score_label_2 = 0
        dice_score_label_3 = 0
        dice_score_label_4 = 0

        logger.info('Training epoch {} will begin'.format(epoch + 1))

        with tqdm(total=n_train, desc=f'Epoch {epoch+1}/{epochs}', unit='patch') as pbar:
            for i, data in enumerate(train_loader, 0):
                images, segs = data['image'].to(device), data['seg'].to(device)
                if model_type == 'SkipDenseSeg' and not softmax:
                    segs = segs.long()

                # combine the batch and time step dimensions
                batch, time, channel, z, y, x = images.shape
                images = images.view(-1, channel, z, y, x)
                segs = segs.view(-1, z, y, x)

                # zero the parameter gradients
                optimizer.zero_grad()

                outputs = net(images)
                loss = criterion(outputs, segs)
                loss.backward()
                optimizer.step()

                running_loss += loss.detach().item()
                _, preds = torch.max(outputs.data, 1)
                dice_score = dice(preds.data.cpu(), segs.data.cpu(), ignore_idx=ignore_idx)
                dice_score_label_0 += dice_score['bg']
                dice_score_label_1 += dice_score['csf']
                dice_score_label_2 += dice_score['gm']
                dice_score_label_3 += dice_score['wm']
                dice_score_label_4 += dice_score['tm']

                # show progress bar
                pbar.set_postfix(**{'Training loss': loss.detach().item(), 'Training accuracy': dice_score['avg']})
                pbar.update(images.shape[0])

                del images, segs

                global_step += 1
                if global_step % nb_batches == 0:
                    net.eval()
                    val_loss, val_acc, val_info = validation(net, val_set, criterion, device, batch_size, model_type=model_type, softmax=softmax, ignore_idx=ignore_idx)

        train_info['train_loss'].append(running_loss / nb_batches)
        train_info['val_loss'].append(val_loss)
        train_info['label_0_acc'].append(dice_score_label_0 / nb_batches)
        train_info['label_1_acc'].append(dice_score_label_1 / nb_batches)
        train_info['label_2_acc'].append(dice_score_label_2 / nb_batches)
        train_info['label_3_acc'].append(dice_score_label_3 / nb_batches)
        train_info['label_4_acc'].append(dice_score_label_4 / nb_batches)

        # save best trained model
        if model_type == 'SkipDenseSeg':
            scheduler.step()
        else:
            scheduler.step(val_loss)

        # debug: log the current learning rate
        for param_group in optimizer.param_groups:
            logger.info('%0.6f | %6d ' % (param_group['lr'], epoch))

        if min_loss > running_loss / nb_batches + 1e-2:
            min_loss = running_loss / nb_batches
            is_best = True
            early_stop_count = 0
        else:
            is_best = False
            early_stop_count += 1

        # save the check point
        state = {
            'epoch': epoch + 1,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'loss': running_loss / nb_batches,
            'min_loss': min_loss
        }
        verbose = save_ckp(state, is_best, early_stop_count=early_stop_count, early_stop_patience=early_stop_patience, save_model_dir=model_path, best_dir=best_path, name=model_name)

        # summarize the training results of this epoch
        logger.info('Average training loss of this epoch is {}'.format(running_loss / nb_batches))
        logger.info('Best training loss till now is {}'.format(min_loss))
        logger.info('Validation dice loss: {}; Validation accuracy: {}'.format(val_loss, val_acc))

        # save the training info every epoch
        logger.info('Writing the training info into file ...')
        val_info_file = os.path.join(intermidiate_data_save, '{}_val_info.json'.format(model_name))
        with open(train_info_file, 'w') as fp:
            json.dump(train_info, fp)
        with open(val_info_file, 'w') as fp:
            json.dump(val_info, fp)
        loss_plot(train_info_file, name=model_name)

        for name, layer in net.named_parameters():
            writer.add_histogram(name + '_grad', layer.grad.cpu().data.numpy(), epoch)
            writer.add_histogram(name + '_data', layer.cpu().data.numpy(), epoch)

        if verbose:
            logger.info('The validation loss has not improved for {} epochs, training will stop here.'.format(early_stop_patience))
            break

    writer.close()
    logger.info('finish training!')
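# --- Illustrative helper (not part of the original file) ---
# distributed_is_initialized() is called in the trainers above but not defined here; a minimal
# sketch of what it presumably wraps:
import torch.distributed as dist

def distributed_is_initialized():
    return dist.is_available() and dist.is_initialized()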
                optimizerI.step()

                print("Epoch" + str(epoch), "Step" + str(step), abs(err.item()), abs(t_c_loss.item()))

                if step % 200 == 0:
                    checkpoint_dynamic = {
                        'epoch': epoch + 1,
                        'state_dict': net_dynamic.state_dict(),
                        'optimizer': optimizerD.state_dict(),
                    }
                    checkpoint_inpainter = {
                        'epoch': epoch + 1,
                        'state_dict': net_impainter.state_dict(),
                        'optimizer': optimizerI.state_dict(),
                    }
                    save_ckp(checkpoint_dynamic, checkpoint_dynamic_path + "checkpoint_" + str(epoch + 1) + ".pt")
                    save_ckp(checkpoint_inpainter, checkpoint_inpainter_path + "checkpoint_" + str(epoch + 1) + ".pt")

                step += 1

            checkpoint_dynamic = {
                'epoch': epoch + 1,
                'state_dict': net_dynamic.state_dict(),
                'optimizer': optimizerD.state_dict(),
            }
            checkpoint_inpainter = {
                'epoch': epoch + 1,
                'state_dict': net_impainter.state_dict(),
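# --- Illustrative helper (not part of the original file) ---
# In this fragment (and in train_model above) save_ckp is called with just a state dict and a
# file path; a minimal sketch of that two-argument variant:
import torch

def save_ckp(state, checkpoint_path):
    torch.save(state, checkpoint_path)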