def main(): """ Image Classification Prediction """ cli_args = get_test_args(__author__, __version__) # Variables image_path = cli_args.input checkpoint_path = cli_args.checkpoint top_k = cli_args.top_k categories_names = cli_args.categories_names # LOAD THE PRE-TRAINED MODEL model = load_ckp(checkpoint_path, optimizer=None) # PREDICT THE TOP_K PROBABILITY AND ITS CORRESPONDING CLASS FROM WHICH IT IS BELONG probs, classes = predict(image_path, model, top_k) # Check the categories file if not os.path.isfile(categories_names): print(f'Categories file {categories_names} was not found.') exit(1) # Label mapping with open(categories_names, 'r') as f: cat_to_name = json.load(f) class_names = [cat_to_name[idx] for idx in classes] # Display prediction data = pd.DataFrame({' Classes': classes, ' Flower': class_names, 'Probability': probs }) data = data.sort_values('Probability', ascending = False) print('The item identified in the image file is:') print(data)
def predict_use(args):
    model_name = args.model_name
    patient_path = args.patient_path

    config_file = 'config.yaml'
    cfg = load_config(config_file)

    input_modalites = int(cfg['PARAMETERS']['input_modalites'])
    output_channels = int(cfg['PARAMETERS']['output_channels'])
    base_channels = int(cfg['PARAMETERS']['base_channels'])
    patience = int(cfg['PARAMETERS']['patience'])

    ROOT = cfg['PATH']['root']
    best_dir = cfg['PATH']['best_model_path']
    best_model_dir = os.path.join(ROOT, best_dir)

    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    # load best trained model
    net = init_U_Net(input_modalites, output_channels, base_channels)
    net.to(device)
    optimizer = optim.Adam(net.parameters())
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, verbose=True, patience=patience)

    ckp_path = os.path.join(best_model_dir, model_name + '_best_model.pth.tar')
    net, _, _, _, _, _ = load_ckp(ckp_path, net, optimizer, scheduler)

    # predict
    predict(net, model_name, patient_path, ROOT, save_mask=True)
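# The segmentation scripts in this collection unpack six values from load_ckp and save
# checkpoints with the keys 'epoch', 'model_state_dict', 'optimizer_state_dict',
# 'scheduler_state_dict', 'loss' and 'min_loss' (see the training loops below).
# A minimal sketch of a matching helper, under exactly that assumption:
import torch

def load_ckp(ckp_path, net, optimizer=None, scheduler=None):
    """Sketch (assumption): restore network, optimizer and scheduler state."""
    checkpoint = torch.load(ckp_path, map_location='cpu')
    net.load_state_dict(checkpoint['model_state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    if scheduler is not None:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    return (net, optimizer, scheduler,
            checkpoint['epoch'], checkpoint['min_loss'], checkpoint['loss'])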
def main():
    os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpus

    test_image_dataset = image_preprocessing(opt.dataset, 'val')
    data_loader = DataLoader(test_image_dataset, batch_size=opt.batch_size, shuffle=False, num_workers=opt.num_workers)

    G = Generator(ResidualBlock, layer_count=9)
    F = Generator(ResidualBlock, layer_count=9)

    if torch.cuda.is_available():
        G = nn.DataParallel(G)
        F = nn.DataParallel(F)
        G = G.cuda()
        F = F.cuda()

    G, F, _, _, _, _, _, _, _ = load_ckp(opt.model_path, G, F)

    G.eval()
    F.eval()

    if not os.path.exists(opt.save_path):
        os.mkdir(opt.save_path)

    for step, data in enumerate(tqdm(data_loader)):
        real_A = to_variable(data['A'])
        real_B = to_variable(data['B'])

        fake_B = G(real_A)
        fake_A = F(real_B)

        batch_image = torch.cat((torch.cat((real_A, real_B), 3), torch.cat((fake_A, fake_B), 3)), 2)

        for i in range(batch_image.shape[0]):
            torchvision.utils.save_image(
                denorm(batch_image[i]),
                opt.save_path + '{result_name}_{step}.jpg'.format(
                    result_name=opt.result_name, step=step * opt.batch_size + i))
def main(_config, _run):
    logger = _run
    SAVE_NAME = _config['SAVE_NAME']
    LOAD_SAVED_MODEL = _config['LOAD_SAVED_MODEL']
    MODEL_PATH_FINAL = _config['MODEL_PATH_FINAL']
    total_steps = 1000000

    params = common.HYPERPARAMS['gamePlay2']
    params['epsilon_frames'] *= 2

    parser = argparse.ArgumentParser()
    parser.add_argument("--cuda", default=False, action="store_true", help="Enable cuda")
    args = parser.parse_args()

    env = gym.make(params['env_name'], glob_conf=_config, logger=logger)
    # env = ptan.common.wrappers.wrap_dqn(env)

    writer = SummaryWriter(comment="-" + params['run_name'] + "-rainbow-beta200")
    net = RainbowDQN(env.observation_space.shape, env.action_space.n).to(device)
    # create the optimizer before optionally restoring it from a checkpoint
    optimizer = optim.Adam(net.parameters(), lr=params['learning_rate'])

    # net.load_state_dict(torch.load( ))
    name_load = current_path + "/models" + MODEL_PATH_FINAL
    if _config['LOAD_SAVED_MODEL']:
        mdl, opt, lss = load_ckp(MODEL_PATH_FINAL, net, optimizer)
        net = mdl
        optimizer = opt

    tgt_net = ptan.agent.TargetNet(net)
    agent = ptan.agent.DQNAgent(lambda x: net.qvals(x), ptan.actions.ArgmaxActionSelector(), device=device)

    # change steps_count to change multi-step prediction
    exp_source = ptan.experience.ExperienceSourceFirstLast(env, agent, gamma=params['gamma'], steps_count=REWARD_STEPS)
    buffer = ptan.experience.PrioritizedReplayBuffer(exp_source, params['replay_size'], PRIO_REPLAY_ALPHA)

    today = datetime.datetime.now()
    todays_date_full = str(today.year) + "_" + str(today.month) + "_" + str(today.day) + "_"
    todays_date_full += str(today.hour) + "_" + str(today.minute) + "_" + str(today.second)
    folder_name = todays_date_full + "_" + experiment_name
    results_dir = current_path + "/results/" + folder_name
    results_dir_weights = results_dir + "/weights"
    os.makedirs(results_dir)
    os.makedirs(results_dir_weights)

    frame_idx = 0
    beta = BETA_START
    best_mean_reward = 0.0
    eval_states = None

    with common.RewardTracker(writer, params['stop_reward']) as reward_tracker:
        while frame_idx < total_steps:
            frame_idx += 1
            buffer.populate(1)
            beta = min(1.0, BETA_START + frame_idx * (1.0 - BETA_START) / BETA_FRAMES)

            new_rewards = exp_source.pop_total_rewards()
            if new_rewards:
                # start saving the model after actual training begins
                if frame_idx > 100:
                    if best_mean_reward is None or best_mean_reward < reward_tracker.mean_reward:
                        torch.save(net.state_dict(), SAVE_NAME + "-best.dat")
                        if best_mean_reward is not None:
                            print("Best mean reward updated %.3f -> %.3f, model saved" %
                                  (best_mean_reward, reward_tracker.mean_reward))
                        if not reward_tracker.mean_reward == 0:
                            best_mean_reward = reward_tracker.mean_reward
                if reward_tracker.reward(new_rewards[0], frame_idx):
                    break

            if len(buffer) < params['replay_initial']:
                continue

            if eval_states is None:
                eval_states, _, _ = buffer.sample(STATES_TO_EVALUATE, beta)
                eval_states = [np.array(transition.state, copy=False) for transition in eval_states]
                eval_states = np.array(eval_states, copy=False)

            optimizer.zero_grad()
            batch, batch_indices, batch_weights = buffer.sample(params['batch_size'], beta)
            loss_v, sample_prios_v = calc_loss(batch, batch_weights, net, tgt_net.target_model,
                                               params['gamma'] ** REWARD_STEPS, device=device)

            # if frame_idx % 10000 == 0:
            if frame_idx % 5000 == 0:
                checkpoint = {
                    'model': net.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'loss': loss_v,
                    'num_step': frame_idx
                }
                torch.save(checkpoint, results_dir_weights + "/rainbow" + str(frame_idx) + "step.dat")
                # Save network parameters as histograms
                for name, param in net.named_parameters():
                    writer.add_histogram(name, param.clone().cpu().data.numpy(), frame_idx)

            loss_v.backward()
            optimizer.step()
            buffer.update_priorities(batch_indices, sample_prios_v.data.cpu().numpy())

            if frame_idx % params['target_net_sync'] == 0:
                tgt_net.sync()

            if logger:
                logger.log_scalar("loss", loss_v.item())
                logger.log_scalar("mean_reward", reward_tracker.mean_reward)
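# This script expects load_ckp to return (model, optimizer, loss) and saves checkpoints
# with the keys 'model', 'optimizer', 'loss' and 'num_step' (see the dict built above).
# A minimal sketch of a matching helper under that assumption:
import torch

def load_ckp(ckp_path, net, optimizer):
    """Sketch (assumption): restore a checkpoint written by the training loop above."""
    checkpoint = torch.load(ckp_path, map_location='cpu')
    net.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    return net, optimizer, checkpoint['loss']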
def main(ckp_path=None):
    """ckp_path (str): checkpoint_path
    Train the model from scratch if ckp_path is None,
    else re-train the model from the previous checkpoint.
    """
    cli_args = get_train_args(__author__, __version__)

    # Variables
    data_dir = cli_args.data_dir
    save_dir = cli_args.save_dir
    file_name = cli_args.file_name
    use_gpu = cli_args.use_gpu

    # LOAD DATA
    data_loaders = load_data(data_dir, config.IMG_SIZE, config.BATCH_SIZE)

    # BUILD MODEL
    if ckp_path is None:
        model = initialize_model(model_name=config.MODEL_NAME, num_classes=config.NO_OF_CLASSES,
                                 feature_extract=True, use_pretrained=True)
    else:
        model = load_ckp(ckp_path)

    # Device is available or not
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # If the user wants GPU mode, check whether cuda is available
    if use_gpu and not torch.cuda.is_available():
        print("GPU mode is not available, using CPU...")
        use_gpu = False

    # MOVE MODEL TO AVAILABLE DEVICE
    model.to(device)

    # DEFINE OPTIMIZER
    optimizer = optimizer_fn(model_name=config.MODEL_NAME, model=model, lr_rate=config.LR_RATE)

    # DEFINE SCHEDULER
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", patience=5, factor=0.3, verbose=True)

    # DEFINE LOSS FUNCTION
    criterion = loss_fn()

    # LOAD BEST MODEL'S WEIGHTS
    best_model_wts = copy.deepcopy(model.state_dict())

    # BEST VALIDATION SCORE
    if ckp_path is None:
        best_score = -1  # MODEL IS TRAINED FROM SCRATCH
    else:
        best_score = model.best_score  # MODEL IS RE-TRAINED

    # NUMBER OF ITERATIONS
    no_epochs = config.EPOCHS

    # KEEP TRACK OF LOSS AND ACCURACY IN EACH EPOCH
    stats = {'train_losses': [], 'valid_losses': [], 'train_accuracies': [], 'valid_accuracies': []}

    print("Model's Training Start......")
    for epoch in range(1, no_epochs + 1):
        train_loss, train_score = train_fn(data_loaders, model, optimizer, criterion, device, phase='train')
        val_loss, val_score = eval_fn(data_loaders, model, criterion, device=config.DEVICE, phase='valid')
        scheduler.step(val_loss)

        # SAVE MODEL'S WEIGHTS IF THE VALIDATION ACCURACY HAS INCREASED
        if val_score > best_score:
            print('Validation score increased ({:.6f} --> {:.6f}). Saving model ...'.format(best_score, val_score))
            best_score = val_score
            best_model_wts = copy.deepcopy(model.state_dict())  # save the best model's weights

        # RECORD AVERAGE LOSSES AND ACCURACIES IN EACH EPOCH FOR PLOTTING
        stats['train_losses'].append(train_loss)
        stats['valid_losses'].append(val_loss)
        stats['train_accuracies'].append(train_score)
        stats['valid_accuracies'].append(val_score)

        # PRINT TRAINING AND VALIDATION LOSSES/ACCURACIES AFTER EACH EPOCH
        epoch_len = len(str(no_epochs))
        print_msg = (f'[{epoch:>{epoch_len}}/{no_epochs:>{epoch_len}}] ' + '\t' +
                     f'train_loss: {train_loss:.5f} ' + '\t' +
                     f'train_score: {train_score:.5f} ' + '\t' +
                     f'valid_loss: {val_loss:.5f} ' + '\t' +
                     f'valid_score: {val_score:.5f}')
        print(print_msg)

    # load best model weights
    model.load_state_dict(best_model_wts)

    # create checkpoint variable and add important data
    model.class_to_idx = data_loaders['train'].dataset.class_to_idx
    model.best_score = best_score
    model.model_name = config.MODEL_NAME
    checkpoint = {
        'epoch': no_epochs,
        'lr_rate': config.LR_RATE,
        'model_name': config.MODEL_NAME,
        'batch_size': config.BATCH_SIZE,
        'valid_score': best_score,
        'optimizer': optimizer.state_dict(),
        'state_dict': model.state_dict(),
        'class_to_idx': model.class_to_idx
    }

    # SAVE CHECKPOINT
    save_ckp(checkpoint, save_dir, file_name)
    print("Model's Training is Successful......")
    return model
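# save_ckp is not shown in this snippet. For this script it only needs to write the
# checkpoint dict to disk; a minimal sketch, assuming save_dir and file_name together
# form the target path:
import os
import torch

def save_ckp(checkpoint, save_dir, file_name):
    """Sketch (assumption): persist a checkpoint dict with torch.save."""
    os.makedirs(save_dir, exist_ok=True)
    torch.save(checkpoint, os.path.join(save_dir, file_name))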
from modules.models import get_model
from utils import load_ckp
import torch.onnx
import onnx
import torch
import os

OPSET_VERSION = 8

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

model = get_model(cfg.model_name, embeddings_size=cfg.embeddings_size)
model = model.to(device)
model.effnet.set_swish(False)

print(f'Load checkpoint : {cfg.WEIGHTS_LOAD_PATH}')
model, _, _, _, _, _ = load_ckp('../' + cfg.WEIGHTS_LOAD_PATH, model, remove_module=True)
model.eval()

with torch.no_grad():
    # Input to the model
    x = torch.randn(10, 3, 48, 48).to(device).float()
    print('Start onnx conversion')
    print(f'Save model to : {cfg.WEIGHTS_SAVE_PATH}')
    torch.onnx.export(model, x,
                      os.path.join('..', cfg.WEIGHTS_SAVE_PATH, 'model.onnx'),
                      opset_version=OPSET_VERSION,
                      verbose=False,
                      export_params=True,
                      input_names=['input'],
                      )
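# Optional follow-up (not part of the original script): after export, the ONNX file can
# be checked for structural validity with the onnx package already imported above,
# assuming the same save path:
onnx_model = onnx.load(os.path.join('..', cfg.WEIGHTS_SAVE_PATH, 'model.onnx'))
onnx.checker.check_model(onnx_model)  # raises an exception if the exported graph is malformed
print('ONNX model check passed')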
def train(args):
    torch.cuda.manual_seed(1)
    torch.manual_seed(1)

    # user defined parameters
    model_name = args.model_name
    model_type = args.model_type
    lstm_backbone = args.lstmbase
    unet_backbone = args.unetbase
    layer_num = args.layer_num
    nb_shortcut = args.nb_shortcut
    loss_fn = args.loss_fn
    world_size = args.world_size
    rank = args.rank
    base_channel = args.base_channels
    crop_size = args.crop_size
    ignore_idx = args.ignore_idx
    return_sequence = args.return_sequence
    variant = args.LSTM_variant
    epochs = args.epoch
    is_pretrain = args.is_pretrain

    # system setup parameters
    config_file = 'config.yaml'
    config = load_config(config_file)

    labels = config['PARAMETERS']['labels']
    root_path = config['PATH']['model_root']
    model_dir = config['PATH']['save_ckp']
    best_dir = config['PATH']['save_best_model']

    input_modalites = int(config['PARAMETERS']['input_modalites'])
    output_channels = int(config['PARAMETERS']['output_channels'])
    batch_size = int(config['PARAMETERS']['batch_size'])
    is_best = bool(config['PARAMETERS']['is_best'])
    is_resume = bool(config['PARAMETERS']['resume'])
    patience = int(config['PARAMETERS']['patience'])
    time_step = int(config['PARAMETERS']['time_step'])
    num_workers = int(config['PARAMETERS']['num_workers'])
    early_stop_patience = int(config['PARAMETERS']['early_stop_patience'])
    lr = int(config['PARAMETERS']['lr'])
    optimizer = config['PARAMETERS']['optimizer']
    connect = config['PARAMETERS']['connect']
    conv_type = config['PARAMETERS']['lstm_convtype']

    # build up dirs
    model_path = os.path.join(root_path, model_dir)
    best_path = os.path.join(root_path, best_dir)
    intermidiate_data_save = os.path.join(root_path, 'train_newdata', model_name)
    train_info_file = os.path.join(intermidiate_data_save, '{}_train_info.json'.format(model_name))
    log_path = os.path.join(root_path, 'logfiles')

    if not os.path.exists(model_path):
        os.mkdir(model_path)
    if not os.path.exists(best_path):
        os.mkdir(best_path)
    if not os.path.exists(intermidiate_data_save):
        os.makedirs(intermidiate_data_save)
    if not os.path.exists(log_path):
        os.mkdir(log_path)

    log_name = model_name + '_' + config['PATH']['log_file']
    logger = logfile(os.path.join(log_path, log_name))
    logger.info('labels {} are ignored'.format(ignore_idx))
    logger.info('Dataset is loading ...')
    writer = SummaryWriter('ProcessVisu/%s' % model_name)

    # load training set and validation set
    data_class = data_split()
    train, val, test = data_construction(data_class)
    train_dict = time_parser(train, time_patch=time_step)
    val_dict = time_parser(val, time_patch=time_step)

    # LSTM initialization
    if model_type == 'LSTM':
        net = LSTMSegNet(lstm_backbone=lstm_backbone, input_dim=input_modalites, output_dim=output_channels, hidden_dim=base_channel, kernel_size=3, num_layers=layer_num, conv_type=conv_type, return_sequence=return_sequence)
    elif model_type == 'UNet_LSTM':
        if variant == 'back':
            net = BackLSTM(input_dim=input_modalites, hidden_dim=base_channel, output_dim=output_channels, kernel_size=3, num_layers=layer_num, conv_type=conv_type, lstm_backbone=lstm_backbone, unet_module=unet_backbone, base_channel=base_channel, return_sequence=return_sequence, is_pretrain=is_pretrain)
            logger.info('the pretrained status of backbone is {}'.format(is_pretrain))
        elif variant == 'center':
            net = CenterLSTM(input_modalites=input_modalites, output_channels=output_channels, base_channel=base_channel, num_layers=layer_num, conv_type=conv_type, return_sequence=return_sequence, is_pretrain=is_pretrain)
        elif variant == 'bicenter':
            net = BiCenterLSTM(input_modalites=input_modalites, output_channels=output_channels, base_channel=base_channel, num_layers=layer_num, connect=connect, conv_type=conv_type, return_sequence=return_sequence, is_pretrain=is_pretrain)
        elif variant == 'directcenter':
            net = DirectCenterLSTM(input_modalites=input_modalites, output_channels=output_channels, base_channel=base_channel, num_layers=layer_num, conv_type=conv_type, return_sequence=return_sequence, is_pretrain=is_pretrain)
        elif variant == 'bidirectcenter':
            net = BiDirectCenterLSTM(input_modalites=input_modalites, output_channels=output_channels, base_channel=base_channel, num_layers=layer_num, connect=connect, conv_type=conv_type, return_sequence=return_sequence, is_pretrain=is_pretrain)
        elif variant == 'rescenter':
            net = ResCenterLSTM(input_modalites=input_modalites, output_channels=output_channels, base_channel=base_channel, num_layers=layer_num, conv_type=conv_type, return_sequence=return_sequence, is_pretrain=is_pretrain)
        elif variant == 'birescenter':
            net = BiResCenterLSTM(input_modalites=input_modalites, output_channels=output_channels, base_channel=base_channel, num_layers=layer_num, connect=connect, conv_type=conv_type, return_sequence=return_sequence, is_pretrain=is_pretrain)
        elif variant == 'shortcut':
            net = ShortcutLSTM(input_modalites=input_modalites, output_channels=output_channels, base_channel=base_channel, num_layers=layer_num, num_connects=nb_shortcut, conv_type=conv_type, return_sequence=return_sequence, is_pretrain=is_pretrain)
    else:
        raise NotImplementedError()

    # loss and optimizer setup
    if loss_fn == 'Dice':
        criterion = DiceLoss(labels=labels, ignore_idx=ignore_idx)
    elif loss_fn == 'GDice':
        criterion = GneralizedDiceLoss(labels=labels)
    elif loss_fn == 'WCE':
        criterion = WeightedCrossEntropyLoss(labels=labels)
    else:
        raise NotImplementedError()

    if optimizer == 'adam':
        optimizer = optim.Adam(net.parameters(), lr=0.001)
        # optimizer = optim.Adam(net.parameters())
    elif optimizer == 'sgd':
        optimizer = optim.SGD(net.parameters(), momentum=0.9, lr=lr)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, verbose=True, patience=patience)

    # device setup
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    # net, optimizer = amp.initialize(net, optimizer, opt_level="O1")

    if torch.cuda.device_count() > 1:
        torch.distributed.init_process_group(backend='nccl', init_method='tcp://127.0.0.1:38366', rank=rank, world_size=world_size)
        if distributed_is_initialized():
            print('distributed is initialized')
            net.to(device)
            net = nn.parallel.DistributedDataParallel(net, find_unused_parameters=True)
        else:
            print('data parallel')
            net = nn.DataParallel(net)
            net.to(device)

    min_loss = float('Inf')
    early_stop_count = 0
    global_step = 0
    start_epoch = 0
    start_loss = 0
    train_info = {
        'train_loss': [],
        'val_loss': [],
        'label_0_acc': [],
        'label_1_acc': [],
        'label_2_acc': [],
        'label_3_acc': [],
        'label_4_acc': []
    }

    if is_resume:
        try:
            # open previous check points
            ckp_path = os.path.join(model_path, '{}_model_ckp.pth.tar'.format(model_name))
            net, optimizer, scheduler, start_epoch, min_loss, start_loss = load_ckp(ckp_path, net, optimizer, scheduler)
            # open previous training records
            with open(train_info_file) as f:
                train_info = json.load(f)
            logger.info('Training loss from last time is {}'.format(start_loss) + '\n' +
                        'Minimum training loss from last time is {}'.format(min_loss))
            logger.info('Training accuracies from last time are: label 0: {}, label 1: {}, label 2: {}, label 3: {}, label 4: {}'.format(
                train_info['label_0_acc'][-1], train_info['label_1_acc'][-1], train_info['label_2_acc'][-1],
                train_info['label_3_acc'][-1], train_info['label_4_acc'][-1]))
        except:
            logger.warning('No checkpoint available, start training from scratch')

    for epoch in range(start_epoch, epochs):
        train_set = data_loader(train_dict, batch_size=batch_size, key='train', num_works=num_workers, time_step=time_step, patch=crop_size, model_type='RNN')
        n_train = len(train_set)
        val_set = data_loader(val_dict, batch_size=batch_size, key='val', num_works=num_workers, time_step=time_step, patch=crop_size, model_type='CNN')
        n_val = len(val_set)
        logger.info('Dataset loading finished!')

        nb_batches = np.ceil(n_train / batch_size)
        n_total = n_train + n_val
        logger.info('{} images will be used in total, {} for training and {} for validation'.format(n_total, n_train, n_val))

        train_loader = train_set.load()

        # setup to train mode
        net.train()
        running_loss = 0
        dice_score_label_0 = 0
        dice_score_label_1 = 0
        dice_score_label_2 = 0
        dice_score_label_3 = 0
        dice_score_label_4 = 0

        logger.info('Training epoch {} will begin'.format(epoch + 1))

        with tqdm(total=n_train, desc=f'Epoch {epoch+1}/{epochs}', unit='patch') as pbar:
            for i, data in enumerate(train_loader, 0):
                # i : patient
                images, segs = data['image'].to(device), data['seg'].to(device)

                # clear accumulated gradients before the backward pass
                optimizer.zero_grad()
                outputs = net(images)
                loss = criterion(outputs, segs)
                loss.backward()
                optimizer.step()

                # if i == 0:
                #     in_images = images.detach().cpu().numpy()[0]
                #     in_segs = segs.detach().cpu().numpy()[0]
                #     in_pred = outputs.detach().cpu().numpy()[0]
                #     heatmap_plot(image=in_images, mask=in_segs, pred=in_pred, name=model_name, epoch=epoch+1, is_train=True)

                running_loss += loss.detach().item()

                outputs = outputs.view(-1, outputs.shape[-4], outputs.shape[-3], outputs.shape[-2], outputs.shape[-1])
                segs = segs.view(-1, segs.shape[-3], segs.shape[-2], segs.shape[-1])
                _, preds = torch.max(outputs.data, 1)
                dice_score = dice(preds.data.cpu(), segs.data.cpu(), ignore_idx=None)
                dice_score_label_0 += dice_score['bg']
                dice_score_label_1 += dice_score['csf']
                dice_score_label_2 += dice_score['gm']
                dice_score_label_3 += dice_score['wm']
                dice_score_label_4 += dice_score['tm']

                # show progress bar
                pbar.set_postfix(**{'training loss': loss.detach().item(), 'Training accuracy': dice_score['avg']})
                pbar.update(images.shape[0])

                global_step += 1
                if global_step % nb_batches == 0:
                    net.eval()
                    val_loss, val_acc, val_info = validation(net, val_set, criterion, device, batch_size, ignore_idx=None, name=model_name, epoch=epoch + 1)
                    net.train()

                    train_info['train_loss'].append(running_loss / nb_batches)
                    train_info['val_loss'].append(val_loss)
                    train_info['label_0_acc'].append(dice_score_label_0 / nb_batches)
                    train_info['label_1_acc'].append(dice_score_label_1 / nb_batches)
                    train_info['label_2_acc'].append(dice_score_label_2 / nb_batches)
                    train_info['label_3_acc'].append(dice_score_label_3 / nb_batches)
                    train_info['label_4_acc'].append(dice_score_label_4 / nb_batches)

                    # save best trained model
                    scheduler.step(running_loss / nb_batches)
                    logger.info('Epoch: {}, LR: {}'.format(epoch + 1, optimizer.param_groups[0]['lr']))

                    if min_loss > running_loss / nb_batches:
                        min_loss = running_loss / nb_batches
                        is_best = True
                        early_stop_count = 0
                    else:
                        is_best = False
                        early_stop_count += 1

                    state = {
                        'epoch': epoch + 1,
                        'model_state_dict': net.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'scheduler_state_dict': scheduler.state_dict(),
                        'loss': running_loss / nb_batches,
                        'min_loss': min_loss
                    }
                    verbose = save_ckp(state, is_best, early_stop_count=early_stop_count,
                                       early_stop_patience=early_stop_patience,
                                       save_model_dir=model_path, best_dir=best_path, name=model_name)

        # summarize the training results of this epoch
        logger.info('The average training loss for this epoch is {}'.format(running_loss / nb_batches))
        logger.info('The best training loss till now is {}'.format(min_loss))
        logger.info('Validation dice loss: {}; Validation (avg) accuracy of the last timestep: {}'.format(val_loss, val_acc))

        # save the training info every epoch
        logger.info('Writing the training info into file ...')
        val_info_file = os.path.join(intermidiate_data_save, '{}_val_info.json'.format(model_name))
        with open(train_info_file, 'w') as fp:
            json.dump(train_info, fp)
        with open(val_info_file, 'w') as fp:
            json.dump(val_info, fp)

        for name, layer in net.named_parameters():
            if layer.requires_grad:
                writer.add_histogram(name + '_grad', layer.grad.cpu().data.numpy(), epoch)
                writer.add_histogram(name + '_data', layer.cpu().data.numpy(), epoch)

        if verbose:
            logger.info('The validation loss has not improved for {} epochs, training will stop here.'.format(early_stop_patience))
            break

    loss_plot(train_info_file, name=model_name)
    logger.info('finish training!')
    return
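# save_ckp is used by all the segmentation training loops in this collection with the
# signature seen above, and the returned flag later breaks the epoch loop once
# early_stop_count reaches early_stop_patience. A minimal sketch under those
# assumptions (the '{}_best_model.pth.tar' name matches the file loaded by predict_use):
import os
import shutil
import torch

def save_ckp(state, is_best, early_stop_count, early_stop_patience, save_model_dir, best_dir, name):
    """Sketch (assumption): save the latest checkpoint, mirror the best one, signal early stop."""
    ckp_path = os.path.join(save_model_dir, '{}_model_ckp.pth.tar'.format(name))
    torch.save(state, ckp_path)
    if is_best:
        shutil.copyfile(ckp_path, os.path.join(best_dir, '{}_best_model.pth.tar'.format(name)))
    return early_stop_count >= early_stop_patience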
def main():
    os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpus
    start_epoch = 0

    train_image_dataset = image_preprocessing(opt.dataset, 'train')
    data_loader = DataLoader(train_image_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=opt.num_workers)

    criterion = least_squares
    euclidean_l1 = nn.L1Loss()

    G = Generator(ResidualBlock, layer_count=9)
    F = Generator(ResidualBlock, layer_count=9)
    Dx = Discriminator()
    Dy = Discriminator()

    G_optimizer = optim.Adam(G.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
    F_optimizer = optim.Adam(F.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
    Dx_optimizer = optim.Adam(Dx.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
    Dy_optimizer = optim.Adam(Dy.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))

    if torch.cuda.is_available():
        G = nn.DataParallel(G)
        F = nn.DataParallel(F)
        Dx = nn.DataParallel(Dx)
        Dy = nn.DataParallel(Dy)
        G = G.cuda()
        F = F.cuda()
        Dx = Dx.cuda()
        Dy = Dy.cuda()

    if opt.checkpoint is not None:
        G, F, Dx, Dy, G_optimizer, F_optimizer, Dx_optimizer, Dy_optimizer, start_epoch = load_ckp(
            opt.checkpoint, G, F, Dx, Dy, G_optimizer, F_optimizer, Dx_optimizer, Dy_optimizer)

    print('[Start] : Cycle GAN Training')

    logger = Logger(opt.epochs, len(data_loader), image_step=10)

    for epoch in range(opt.epochs):
        epoch = epoch + start_epoch + 1
        print("Epoch[{epoch}] : Start".format(epoch=epoch))

        for step, data in enumerate(data_loader):
            real_A = to_variable(data['A'])
            real_B = to_variable(data['B'])

            fake_B = G(real_A)
            fake_A = F(real_B)

            # Train Dx
            Dx_optimizer.zero_grad()
            Dx_real = Dx(real_A)
            Dx_fake = Dx(fake_A)
            Dx_loss = patch_loss(criterion, Dx_real, True) + patch_loss(criterion, Dx_fake, 0)
            Dx_loss.backward(retain_graph=True)
            Dx_optimizer.step()

            # Train Dy
            Dy_optimizer.zero_grad()
            Dy_real = Dy(real_B)
            Dy_fake = Dy(fake_B)
            Dy_loss = patch_loss(criterion, Dy_real, True) + patch_loss(criterion, Dy_fake, 0)
            Dy_loss.backward(retain_graph=True)
            Dy_optimizer.step()

            # Train G
            G_optimizer.zero_grad()
            Dy_fake = Dy(fake_B)
            G_loss = patch_loss(criterion, Dy_fake, True)

            # Train F
            F_optimizer.zero_grad()
            Dx_fake = Dx(fake_A)
            F_loss = patch_loss(criterion, Dx_fake, True)

            # identity loss
            loss_identity = euclidean_l1(real_A, fake_A) + euclidean_l1(real_B, fake_B)

            # cycle consistency loss
            loss_cycle = euclidean_l1(F(fake_B), real_A) + euclidean_l1(G(fake_A), real_B)

            # Optimize G & F
            loss = G_loss + F_loss + opt.lamda * loss_cycle + opt.lamda * loss_identity * 0.5
            loss.backward()
            G_optimizer.step()
            F_optimizer.step()

            if (step + 1) % opt.save_step == 0:
                print("Epoch[{epoch}]| Step [{now}/{total}]| Dx Loss: {Dx_loss}, Dy_Loss: {Dy_loss}, G_Loss: {G_loss}, F_Loss: {F_loss}".format(
                    epoch=epoch, now=step + 1, total=len(data_loader),
                    Dx_loss=Dx_loss.item(), Dy_loss=Dy_loss.item(), G_loss=G_loss.item(), F_loss=F_loss.item()))

                batch_image = torch.cat((torch.cat((real_A, real_B), 3), torch.cat((fake_A, fake_B), 3)), 2)
                torchvision.utils.save_image(
                    denorm(batch_image[0]),
                    opt.training_result + 'result_{result_name}_ep{epoch}_{step}.jpg'.format(
                        result_name=opt.result_name, epoch=epoch, step=(step + 1) * opt.batch_size))

            # http://localhost:8097
            logger.log(
                losses={
                    'loss_G': G_loss,
                    'loss_F': F_loss,
                    'loss_identity': loss_identity,
                    'loss_cycle': loss_cycle,
                    'total_G_loss': loss,
                    'loss_Dx': Dx_loss,
                    'loss_Dy': Dy_loss,
                    'total_D_loss': (Dx_loss + Dy_loss),
                },
                images={
                    'real_A': real_A,
                    'real_B': real_B,
                    'fake_A': fake_A,
                    'fake_B': fake_B,
                },
            )

        torch.save({
            'epoch': epoch,
            'G_model': G.state_dict(),
            'G_optimizer': G_optimizer.state_dict(),
            'F_model': F.state_dict(),
            'F_optimizer': F_optimizer.state_dict(),
            'Dx_model': Dx.state_dict(),
            'Dx_optimizer': Dx_optimizer.state_dict(),
            'Dy_model': Dy.state_dict(),
            'Dy_optimizer': Dy_optimizer.state_dict(),
        }, opt.save_model + 'model_{result_name}_CycleGAN_ep{epoch}.ckp'.format(result_name=opt.result_name, epoch=epoch))
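# load_ckp for this CycleGAN script has to return nine values and read the keys written
# by the torch.save call above. A minimal sketch under that assumption:
import torch

def load_ckp(checkpoint_path, G, F, Dx, Dy, G_optimizer, F_optimizer, Dx_optimizer, Dy_optimizer):
    """Sketch (assumption): restore all four CycleGAN networks and their optimizers."""
    ckp = torch.load(checkpoint_path, map_location='cpu')
    G.load_state_dict(ckp['G_model'])
    F.load_state_dict(ckp['F_model'])
    Dx.load_state_dict(ckp['Dx_model'])
    Dy.load_state_dict(ckp['Dy_model'])
    G_optimizer.load_state_dict(ckp['G_optimizer'])
    F_optimizer.load_state_dict(ckp['F_optimizer'])
    Dx_optimizer.load_state_dict(ckp['Dx_optimizer'])
    Dy_optimizer.load_state_dict(ckp['Dy_optimizer'])
    return G, F, Dx, Dy, G_optimizer, F_optimizer, Dx_optimizer, Dy_optimizer, ckp['epoch']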
def train(args):
    torch.cuda.manual_seed(1)
    torch.manual_seed(1)

    # user defined
    model_name = args.model_name
    model_loss_fn = args.loss_fn

    config_file = 'config.yaml'
    config = load_config(config_file)

    data_root = config['PATH']['data_root']
    labels = config['PARAMETERS']['labels']
    root_path = config['PATH']['root']
    model_dir = config['PATH']['model_path']
    best_dir = config['PATH']['best_model_path']
    data_class = config['PATH']['data_class']

    input_modalites = int(config['PARAMETERS']['input_modalites'])
    output_channels = int(config['PARAMETERS']['output_channels'])
    base_channel = int(config['PARAMETERS']['base_channels'])
    crop_size = int(config['PARAMETERS']['crop_size'])
    batch_size = int(config['PARAMETERS']['batch_size'])
    epochs = int(config['PARAMETERS']['epoch'])
    is_best = bool(config['PARAMETERS']['is_best'])
    is_resume = bool(config['PARAMETERS']['resume'])
    patience = int(config['PARAMETERS']['patience'])
    ignore_idx = int(config['PARAMETERS']['ignore_index'])
    early_stop_patience = int(config['PARAMETERS']['early_stop_patience'])

    # build up dirs
    model_path = os.path.join(root_path, model_dir)
    best_path = os.path.join(root_path, best_dir)
    intermidiate_data_save = os.path.join(root_path, 'train_data', model_name)
    train_info_file = os.path.join(intermidiate_data_save, '{}_train_info.json'.format(model_name))
    log_path = os.path.join(root_path, 'logfiles')

    if not os.path.exists(model_path):
        os.mkdir(model_path)
    if not os.path.exists(best_path):
        os.mkdir(best_path)
    if not os.path.exists(intermidiate_data_save):
        os.makedirs(intermidiate_data_save)
    if not os.path.exists(log_path):
        os.mkdir(log_path)

    log_name = model_name + '_' + config['PATH']['log_file']
    logger = logfile(os.path.join(log_path, log_name))
    logger.info('Dataset is loading ...')

    # split dataset
    dir_ = os.path.join(data_root, data_class)
    data_content = train_split(dir_)

    # load training set and validation set
    train_set = data_loader(data_content=data_content, key='train', form='LGG', crop_size=crop_size, batch_size=batch_size, num_works=8)
    n_train = len(train_set)
    train_loader = train_set.load()

    val_set = data_loader(data_content=data_content, key='val', form='LGG', crop_size=crop_size, batch_size=batch_size, num_works=8)
    logger.info('Dataset loading finished!')

    n_val = len(val_set)
    nb_batches = np.ceil(n_train / batch_size)
    n_total = n_train + n_val
    logger.info('{} images will be used in total, {} for training and {} for validation'.format(n_total, n_train, n_val))

    net = init_U_Net(input_modalites, output_channels, base_channel)

    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    if torch.cuda.device_count() > 1:
        logger.info('{} GPUs available.'.format(torch.cuda.device_count()))
        net = nn.DataParallel(net)
    net.to(device)

    if model_loss_fn == 'Dice':
        criterion = DiceLoss(labels=labels, ignore_index=ignore_idx)
    elif model_loss_fn == 'CrossEntropy':
        criterion = CrossEntropyLoss(labels=labels, ignore_index=ignore_idx)
    elif model_loss_fn == 'FocalLoss':
        criterion = FocalLoss(labels=labels, ignore_index=ignore_idx)
    elif model_loss_fn == 'Dice_CE':
        criterion = Dice_CE(labels=labels, ignore_index=ignore_idx)
    elif model_loss_fn == 'Dice_FL':
        criterion = Dice_FL(labels=labels, ignore_index=ignore_idx)
    else:
        raise NotImplementedError()

    optimizer = optim.Adam(net.parameters(), lr=1e-4, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, verbose=True, patience=patience)
    net, optimizer = amp.initialize(net, optimizer, opt_level='O1')

    min_loss = float('Inf')
    early_stop_count = 0
    global_step = 0
    start_epoch = 0
    start_loss = 0
    train_info = {
        'train_loss': [],
        'val_loss': [],
        'BG_acc': [],
        'NET_acc': [],
        'ED_acc': [],
        'ET_acc': []
    }

    if is_resume:
        try:
            ckp_path = os.path.join(model_dir, '{}_model_ckp.pth.tar'.format(model_name))
            net, optimizer, scheduler, start_epoch, min_loss, start_loss = load_ckp(ckp_path, net, optimizer, scheduler)
            # open previous training records
            with open(train_info_file) as f:
                train_info = json.load(f)
            logger.info('Training loss from last time is {}'.format(start_loss) + '\n' +
                        'Minimum training loss from last time is {}'.format(min_loss))
        except:
            logger.warning('No checkpoint available, start training from scratch')

    # start training
    for epoch in range(start_epoch, epochs):
        # setup to train mode
        net.train()
        running_loss = 0
        dice_coeff_bg = 0
        dice_coeff_net = 0
        dice_coeff_ed = 0
        dice_coeff_et = 0

        logger.info('Training epoch {} will begin'.format(epoch + 1))

        with tqdm(total=n_train, desc=f'Epoch {epoch+1}/{epochs}', unit='patch') as pbar:
            for i, data in enumerate(train_loader, 0):
                images, segs = data['image'].to(device), data['seg'].to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                outputs = net(images)
                loss = criterion(outputs, segs)
                # loss.backward()
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                optimizer.step()

                # save the output at the beginning of each epoch to visualize it
                if i == 0:
                    in_images = images.detach().cpu().numpy()[:, 0, ...]
                    in_segs = segs.detach().cpu().numpy()
                    in_pred = outputs.detach().cpu().numpy()
                    heatmap_plot(image=in_images, mask=in_segs, pred=in_pred, name=model_name, epoch=epoch + 1)

                running_loss += loss.detach().item()
                dice_score = dice_coe(outputs.detach().cpu(), segs.detach().cpu())
                dice_coeff_bg += dice_score['BG']
                dice_coeff_ed += dice_score['ED']
                dice_coeff_et += dice_score['ET']
                dice_coeff_net += dice_score['NET']

                # show progress bar
                pbar.set_postfix(**{'Training loss': loss.detach().item(), 'Training (avg) accuracy': dice_score['avg']})
                pbar.update(images.shape[0])

                global_step += 1
                if global_step % nb_batches == 0:
                    # validate
                    net.eval()
                    val_loss, val_acc = validation(net, val_set, criterion, device, batch_size)

                    train_info['train_loss'].append(running_loss / nb_batches)
                    train_info['val_loss'].append(val_loss)
                    train_info['BG_acc'].append(dice_coeff_bg / nb_batches)
                    train_info['NET_acc'].append(dice_coeff_net / nb_batches)
                    train_info['ED_acc'].append(dice_coeff_ed / nb_batches)
                    train_info['ET_acc'].append(dice_coeff_et / nb_batches)

                    # save best trained model
                    scheduler.step(running_loss / nb_batches)
                    if min_loss > val_loss:
                        min_loss = val_loss
                        is_best = True
                        early_stop_count = 0
                    else:
                        is_best = False
                        early_stop_count += 1

                    state = {
                        'epoch': epoch + 1,
                        'model_state_dict': net.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'scheduler_state_dict': scheduler.state_dict(),
                        'loss': running_loss / nb_batches,
                        'min_loss': min_loss
                    }
                    verbose = save_ckp(state, is_best, early_stop_count=early_stop_count,
                                       early_stop_patience=early_stop_patience,
                                       save_model_dir=model_path, best_dir=best_path, name=model_name)

        logger.info('The average training loss for this epoch is {}'.format(running_loss / (np.ceil(n_train / batch_size))))
        logger.info('Validation dice loss: {}; Validation (avg) accuracy: {}'.format(val_loss, val_acc))
        logger.info('The best validation loss till now is {}'.format(min_loss))

        # save the training info every epoch
        logger.info('Writing the training info into file ...')
        with open(train_info_file, 'w') as fp:
            json.dump(train_info, fp)
        loss_plot(train_info_file, name=model_name)

        if verbose:
            logger.info('The validation loss has not improved for {} epochs, training will stop here.'.format(early_stop_patience))
            break

    logger.info('finish training!')
def train(args):
    torch.cuda.manual_seed(1)
    torch.manual_seed(1)

    # user defined
    model_name = args.model_name
    model_type = args.model_type
    loss_func = args.loss
    world_size = args.world_size
    rank = args.rank
    base_channel = args.base_channels
    crop_size = args.crop_size
    ignore_idx = args.ignore_idx
    epochs = args.epoch

    # system setup
    config_file = 'config.yaml'
    config = load_config(config_file)

    labels = config['PARAMETERS']['labels']
    root_path = config['PATH']['model_root']
    model_dir = config['PATH']['save_ckp']
    best_dir = config['PATH']['save_best_model']

    output_channels = int(config['PARAMETERS']['output_channels'])
    batch_size = int(config['PARAMETERS']['batch_size'])
    is_best = bool(config['PARAMETERS']['is_best'])
    is_resume = bool(config['PARAMETERS']['resume'])
    patience = int(config['PARAMETERS']['patience'])
    time_step = int(config['PARAMETERS']['time_step'])
    num_workers = int(config['PARAMETERS']['num_workers'])
    early_stop_patience = int(config['PARAMETERS']['early_stop_patience'])
    pad_method = config['PARAMETERS']['pad_method']
    lr = int(config['PARAMETERS']['lr'])
    optimizer = config['PARAMETERS']['optimizer']

    softmax = True
    modalities = ['flair', 't1', 't1gd', 't2']
    input_modalites = len(modalities)

    # build up dirs
    model_path = os.path.join(root_path, model_dir)
    best_path = os.path.join(root_path, best_dir)
    intermidiate_data_save = os.path.join(root_path, 'train_newdata', model_name)
    train_info_file = os.path.join(intermidiate_data_save, '{}_train_info.json'.format(model_name))
    log_path = os.path.join(root_path, 'logfiles')

    if not os.path.exists(model_path):
        os.mkdir(model_path)
    if not os.path.exists(best_path):
        os.mkdir(best_path)
    if not os.path.exists(intermidiate_data_save):
        os.makedirs(intermidiate_data_save)
    if not os.path.exists(log_path):
        os.mkdir(log_path)

    log_name = model_name + '_' + config['PATH']['log_file']
    logger = logfile(os.path.join(log_path, log_name))
    logger.info('Dataset is loading ...')
    writer = SummaryWriter('ProcessVisu/%s' % model_name)
    logger.info('patch size: {}'.format(crop_size))

    # load training set and validation set
    data_class = data_split()
    train, val, test = data_construction(data_class)
    train_dict = time_parser(train, time_patch=time_step)
    val_dict = time_parser(val, time_patch=time_step)

    # groups = 4
    if model_type == 'UNet':
        net = init_U_Net(input_modalites, output_channels, base_channel, pad_method, softmax)
    elif model_type == 'ResUNet':
        net = ResUNet(input_modalites, output_channels, base_channel, pad_method, softmax)
    elif model_type == 'DResUNet':
        net = DResUNet(input_modalites, output_channels, base_channel, pad_method, softmax)
    elif model_type == 'direct_concat':
        net = U_Net_direct_concat(input_modalites, output_channels, base_channel, pad_method, softmax)
    elif model_type == 'Inception':
        net = Inception_UNet(input_modalites, output_channels, base_channel, softmax)
    elif model_type == 'Simple_Inception':
        net = Simplified_Inception_UNet(input_modalites, output_channels, base_channel, softmax)

    # device setup
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    net.to(device)

    # print model structure
    summary(net, input_size=(input_modalites, crop_size, crop_size, crop_size))
    dummy_input = torch.rand(1, input_modalites, crop_size, crop_size, crop_size).to(device)
    writer.add_graph(net, (dummy_input, ))

    # loss and optimizer setup
    if loss_func == 'Dice' and softmax:
        criterion = DiceLoss(labels=labels, ignore_idx=ignore_idx)
    elif loss_func == 'GDice' and softmax:
        criterion = GneralizedDiceLoss(labels=labels)
    elif loss_func == 'CrossEntropy':
        criterion = WeightedCrossEntropyLoss(labels=labels)
        if not softmax:
            criterion = nn.CrossEntropyLoss().cuda()
    else:
        raise NotImplementedError()

    if optimizer == 'adam':
        optimizer = optim.Adam(net.parameters())
    elif optimizer == 'sgd':
        optimizer = optim.SGD(net.parameters(), momentum=0.9, lr=lr, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, verbose=True, patience=patience)

    # net, optimizer = amp.initialize(net, optimizer, opt_level='O1')

    if torch.cuda.device_count() > 1:
        logger.info('{} GPUs available'.format(torch.cuda.device_count()))
        torch.distributed.init_process_group(backend='nccl', init_method='tcp://127.0.0.1:38366', rank=rank, world_size=world_size)
        if distributed_is_initialized():
            logger.info('distributed is initialized')
            net.to(device)
            net = nn.parallel.DistributedDataParallel(net)
        else:
            logger.info('data parallel')
            net = nn.DataParallel(net)
            net.to(device)

    min_loss = float('Inf')
    early_stop_count = 0
    global_step = 0
    start_epoch = 0
    start_loss = 0
    train_info = {
        'train_loss': [],
        'val_loss': [],
        'label_0_acc': [],
        'label_1_acc': [],
        'label_2_acc': [],
        'label_3_acc': [],
        'label_4_acc': []
    }

    if is_resume:
        try:
            # open previous check points
            ckp_path = os.path.join(model_path, '{}_model_ckp.pth.tar'.format(model_name))
            net, optimizer, scheduler, start_epoch, min_loss, start_loss = load_ckp(ckp_path, net, optimizer, scheduler)
            # open previous training records
            with open(train_info_file) as f:
                train_info = json.load(f)
            logger.info('Training loss from last time is {}'.format(start_loss) + '\n' +
                        'Minimum training loss from last time is {}'.format(min_loss))
            logger.info('Training accuracies from last time are: label 0: {}, label 1: {}, label 2: {}, label 3: {}, label 4: {}'.format(
                train_info['label_0_acc'][-1], train_info['label_1_acc'][-1], train_info['label_2_acc'][-1],
                train_info['label_3_acc'][-1], train_info['label_4_acc'][-1]))
            # min_loss = float('Inf')
        except:
            logger.warning('No checkpoint available, start training from scratch')

    # start training
    for epoch in range(start_epoch, epochs):
        # every epoch generates a new set of images
        train_set = data_loader(train_dict, batch_size=batch_size, key='train', num_works=num_workers, time_step=time_step, patch=crop_size, modalities=modalities, model_type='CNN')
        n_train = len(train_set)
        train_loader = train_set.load()

        val_set = data_loader(val_dict, batch_size=batch_size, key='val', num_works=num_workers, time_step=time_step, patch=crop_size, modalities=modalities, model_type='CNN')
        n_val = len(val_set)

        nb_batches = np.ceil(n_train / batch_size)
        n_total = n_train + n_val
        logger.info('{} images will be used in total, {} for training and {} for validation'.format(n_total, n_train, n_val))
        logger.info('Dataset loading finished!')

        # setup to train mode
        net.train()
        running_loss = 0
        dice_score_label_0 = 0
        dice_score_label_1 = 0
        dice_score_label_2 = 0
        dice_score_label_3 = 0
        dice_score_label_4 = 0

        logger.info('Training epoch {} will begin'.format(epoch + 1))

        with tqdm(total=n_train, desc=f'Epoch {epoch+1}/{epochs}', unit='patch') as pbar:
            for i, data in enumerate(train_loader, 0):
                images, segs = data['image'].to(device), data['seg'].to(device)
                if model_type == 'SkipDenseSeg' and not softmax:
                    segs = segs.long()

                # combine the batch and time step dimensions
                batch, time, channel, z, y, x = images.shape
                images = images.view(-1, channel, z, y, x)
                segs = segs.view(-1, z, y, x)

                # zero the parameter gradients
                optimizer.zero_grad()

                outputs = net(images)
                loss = criterion(outputs, segs)
                loss.backward()
                # with amp.scale_loss(loss, optimizer) as scaled_loss:
                #     scaled_loss.backward()
                optimizer.step()

                running_loss += loss.detach().item()

                _, preds = torch.max(outputs.data, 1)
                dice_score = dice(preds.data.cpu(), segs.data.cpu(), ignore_idx=ignore_idx)
                dice_score_label_0 += dice_score['bg']
                dice_score_label_1 += dice_score['csf']
                dice_score_label_2 += dice_score['gm']
                dice_score_label_3 += dice_score['wm']
                dice_score_label_4 += dice_score['tm']

                # show progress bar
                pbar.set_postfix(**{'Training loss': loss.detach().item(), 'Training accuracy': dice_score['avg']})
                pbar.update(images.shape[0])

                del images, segs

                global_step += 1
                if global_step % nb_batches == 0:
                    net.eval()
                    val_loss, val_acc, val_info = validation(net, val_set, criterion, device, batch_size,
                                                             model_type=model_type, softmax=softmax, ignore_idx=ignore_idx)

                    train_info['train_loss'].append(running_loss / nb_batches)
                    train_info['val_loss'].append(val_loss)
                    train_info['label_0_acc'].append(dice_score_label_0 / nb_batches)
                    train_info['label_1_acc'].append(dice_score_label_1 / nb_batches)
                    train_info['label_2_acc'].append(dice_score_label_2 / nb_batches)
                    train_info['label_3_acc'].append(dice_score_label_3 / nb_batches)
                    train_info['label_4_acc'].append(dice_score_label_4 / nb_batches)

                    # save best trained model
                    if model_type == 'SkipDenseSeg':
                        scheduler.step()
                    else:
                        scheduler.step(val_loss)

                    # debug
                    for param_group in optimizer.param_groups:
                        logger.info('%0.6f | %6d ' % (param_group['lr'], epoch))

                    if min_loss > running_loss / nb_batches + 1e-2:
                        min_loss = running_loss / nb_batches
                        is_best = True
                        early_stop_count = 0
                    else:
                        is_best = False
                        early_stop_count += 1

                    # save the check point
                    state = {
                        'epoch': epoch + 1,
                        'model_state_dict': net.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'scheduler_state_dict': scheduler.state_dict(),
                        'loss': running_loss / nb_batches,
                        'min_loss': min_loss
                    }
                    verbose = save_ckp(state, is_best, early_stop_count=early_stop_count,
                                       early_stop_patience=early_stop_patience,
                                       save_model_dir=model_path, best_dir=best_path, name=model_name)

        # summarize the training results of this epoch
        logger.info('Average training loss of this epoch is {}'.format(running_loss / nb_batches))
        logger.info('Best training loss till now is {}'.format(min_loss))
        logger.info('Validation dice loss: {}; Validation accuracy: {}'.format(val_loss, val_acc))

        # save the training info every epoch
        logger.info('Writing the training info into file ...')
        val_info_file = os.path.join(intermidiate_data_save, '{}_val_info.json'.format(model_name))
        with open(train_info_file, 'w') as fp:
            json.dump(train_info, fp)
        with open(val_info_file, 'w') as fp:
            json.dump(val_info, fp)
        loss_plot(train_info_file, name=model_name)

        for name, layer in net.named_parameters():
            writer.add_histogram(name + '_grad', layer.grad.cpu().data.numpy(), epoch)
            writer.add_histogram(name + '_data', layer.cpu().data.numpy(), epoch)

        if verbose:
            logger.info('The validation loss has not improved for {} epochs, training will stop here.'.format(early_stop_patience))
            break

    writer.close()
    logger.info('finish training!')
pretrained_dynamic = a1[0]

# flow_dataset = FlowDataset(transform=transforms.Compose([ToTensor(), Rescale((cnvrt_size, cnvrt_size))]))
flow_dataset = FlowDataset(transform=transforms.Compose([ToTensor()]))
dataloader = DataLoader(flow_dataset, batch_size=batch_size, shuffle=True, num_workers=workers)

net_dynamic = createDeepLabv3().to(device)
net_dynamic.apply(weights_init)
net_impainter = Inpainter(ngpu=1).to(device)
# net_impainter.apply(weights_init)

optimizerD = optim.Adam(net_dynamic.parameters(), lr=lr, betas=(beta1, beta2))
optimizerI = optim.Adam(net_impainter.parameters(), lr=lr, betas=(beta1, beta2))

if pretrained_dynamic is not None:
    net_dynamic, optimizerD, start_epoch = load_ckp(checkpoint_dynamic_path + pretrained_dynamic, net_dynamic, optimizerD)
    print("Loaded pretrained: " + pretrained_dynamic)
if pretrained_inpainter is not None:
    net_impainter, optimizerI, start_epoch = load_ckp(checkpoint_inpainter_path + pretrained_inpainter, net_impainter, optimizerI)
    print("Loaded pretrained: " + pretrained_inpainter)

loss_l1 = nn.L1Loss()
loss_l2 = nn.MSELoss()

I_losses = []
D_losses = []
iters = 0

print("Starting Training Loop... from " + str(start_epoch))
net_dynamic.train()
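# Here load_ckp is unpacked as (model, optimizer, start_epoch). The checkpoint layout is
# not shown in this snippet, so the key names below ('state_dict', 'optimizer', 'epoch')
# are purely illustrative assumptions; a minimal sketch:
import torch

def load_ckp(checkpoint_path, model, optimizer):
    """Sketch (assumption): restore model and optimizer state plus the resume epoch."""
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    return model, optimizer, checkpoint['epoch']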