Example #1
0
def main():
    """
        Image Classification Prediction

    Load a trained checkpoint, predict the ``top_k`` most likely classes for
    the image given on the command line, map the predicted class labels to
    human-readable names through the categories JSON file, and print the
    result as a table sorted by probability (highest first).

    Exits with status 1 if the categories file does not exist.
    """
    import sys

    cli_args = get_test_args(__author__, __version__)

    # Variables
    image_path = cli_args.input
    checkpoint_path = cli_args.checkpoint
    top_k = cli_args.top_k
    categories_names = cli_args.categories_names

    # Check the categories file before the (potentially slow) checkpoint load
    # and inference so that a bad path fails fast.
    if not os.path.isfile(categories_names):
        print(f'Categories file {categories_names} was not found.')
        # sys.exit is always available; the bare `exit` builtin comes from
        # the `site` module and may be missing (e.g. `python -S`, frozen apps).
        sys.exit(1)

    # LOAD THE PRE-TRAINED MODEL
    model = load_ckp(checkpoint_path, optimizer=None)
    # Predict the top_k probabilities and the class labels they belong to.
    probs, classes = predict(image_path, model, top_k)

    # Label mapping: JSON text is UTF-8, so do not rely on the locale encoding.
    with open(categories_names, 'r', encoding='utf-8') as f:
        cat_to_name = json.load(f)

    # Each predicted class label is used as a key into the JSON mapping.
    class_names = [cat_to_name[idx] for idx in classes]

    # Display prediction
    data = pd.DataFrame({' Classes': classes, '  Flower': class_names, 'Probability': probs })
    data = data.sort_values('Probability', ascending = False)
    print('The item identified in the image file is:')
    print(data)
Example #2
0
def predict_use(args):
    """Restore the best checkpoint of ``args.model_name`` and predict one patient.

    Network sizes and the scheduler patience are read from the
    ``PARAMETERS`` section of ``config.yaml``. The checkpoint file is
    ``<root>/<best_model_path>/<model_name>_best_model.pth.tar``, with both
    directory parts taken from the ``PATH`` section. Prediction is delegated
    to ``predict`` with ``save_mask=True``.
    """
    name = args.model_name
    patient = args.patient_path

    cfg = load_config('config.yaml')
    hyper = cfg['PARAMETERS']
    n_in = int(hyper['input_modalites'])
    n_out = int(hyper['output_channels'])
    n_base = int(hyper['base_channels'])
    plateau_patience = int(hyper['patience'])

    root_dir = cfg['PATH']['root']
    best_model_dir = os.path.join(root_dir, cfg['PATH']['best_model_path'])
    if torch.cuda.is_available():
        device = 'cuda:0'
    else:
        device = 'cpu'

    # load_ckp receives the network, an optimizer and a scheduler; only the
    # network it returns is used afterwards.
    net = init_U_Net(n_in, n_out, n_base)
    net.to(device)
    adam = optim.Adam(net.parameters())
    sched = optim.lr_scheduler.ReduceLROnPlateau(adam,
                                                 verbose=True,
                                                 patience=plateau_patience)

    checkpoint_file = os.path.join(best_model_dir,
                                   name + '_best_model.pth.tar')
    restored = load_ckp(checkpoint_file, net, adam, sched)
    net, _, _, _, _, _ = restored  # exactly six values are expected

    predict(net, name, patient, root_dir, save_mask=True)
Example #3
0
def main():
    """Run both CycleGAN generators over the validation split and save grids.

    Settings come from the module-level ``opt`` namespace (gpus, dataset,
    batch_size, num_workers, model_path, save_path, result_name). For every
    sample, a 2x2 image grid is written to ``opt.save_path`` as
    ``<result_name>_<index>.jpg``: the two real inputs on the top row and the
    two generated images on the bottom row.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpus
    test_image_dataset = image_preprocessing(opt.dataset, 'val')
    data_loader = DataLoader(test_image_dataset,
                             batch_size=opt.batch_size,
                             shuffle=False,
                             num_workers=opt.num_workers)

    # G produces fake_B from real_A; F produces fake_A from real_B (see below).
    G = Generator(ResidualBlock, layer_count=9)
    F = Generator(ResidualBlock, layer_count=9)

    if torch.cuda.is_available():
        G = nn.DataParallel(G)
        F = nn.DataParallel(F)

        G = G.cuda()
        F = F.cuda()

    # load_ckp returns a 9-tuple here; only the two generators are needed.
    G, F, _, _, _, _, _, _, _ = load_ckp(opt.model_path, G, F)
    G.eval()
    F.eval()

    # makedirs also creates missing parent directories and tolerates an
    # already existing output directory.
    os.makedirs(opt.save_path, exist_ok=True)

    # Inference only: disable autograd so no graph is kept for the outputs.
    with torch.no_grad():
        for step, data in enumerate(tqdm(data_loader)):
            real_A = to_variable(data['A'])
            real_B = to_variable(data['B'])

            fake_B = G(real_A)
            fake_A = F(real_B)

            # Concatenate along dim 3, then dim 2:
            # [real_A | real_B] above [fake_A | fake_B]
            # (width/height respectively, assuming NCHW batches).
            batch_image = torch.cat((torch.cat(
                (real_A, real_B), 3), torch.cat((fake_A, fake_B), 3)), 2)
            for i in range(batch_image.shape[0]):
                # Running sample index so names stay unique across batches.
                file_name = '{result_name}_{step}.jpg'.format(
                    result_name=opt.result_name,
                    step=step * opt.batch_size + i)
                # Join properly so the file lands inside save_path even when
                # it has no trailing separator.
                torchvision.utils.save_image(
                    denorm(batch_image[i]),
                    os.path.join(opt.save_path, file_name))
def main(_config,_run):
    """Train a Rainbow DQN agent with the ``gamePlay2`` hyper-parameter set.

    Args:
        _config: experiment configuration mapping. It must provide
            ``SAVE_NAME``, ``LOAD_SAVED_MODEL`` and ``MODEL_PATH_FINAL``, and
            is also forwarded to the environment as ``glob_conf``.
        _run: run object used as the scalar logger (``log_scalar``). It is
            also passed to the environment as ``logger``.

    Runs for at most ``total_steps`` frames. It saves the weights with the
    best mean reward to ``<SAVE_NAME>-best.dat``. Every 5000 frames it also
    saves a full checkpoint under
    ``<current_path>/results/<timestamp>_<experiment_name>/weights``.
    """
    logger = _run
    SAVE_NAME = _config['SAVE_NAME']
    LOAD_SAVED_MODEL = _config['LOAD_SAVED_MODEL']
    MODEL_PATH_FINAL = _config['MODEL_PATH_FINAL']
    total_steps = 1000000

    # NOTE(review): this doubles epsilon_frames inside the shared
    # common.HYPERPARAMS entry (no copy is made), so repeated calls compound.
    # epsilon_frames is not read anywhere else in this function.
    params = common.HYPERPARAMS['gamePlay2']
    params['epsilon_frames'] *= 2
    # NOTE(review): `args` is not used below (`device` comes from module
    # scope); parsing still validates the command line.
    parser = argparse.ArgumentParser()
    parser.add_argument("--cuda", default=False, action="store_true", help="Enable cuda")
    args = parser.parse_args()

    env = gym.make(params['env_name'],glob_conf=_config,logger=logger)

    writer = SummaryWriter(comment="-" + params['run_name'] + "-rainbow-beta200")
    net = RainbowDQN(env.observation_space.shape, env.action_space.n).to(device)
    # The optimizer must exist before the optional checkpoint restore, which
    # takes it as an argument. It must not be re-created afterwards, or the
    # restored optimizer state would be discarded.
    optimizer = optim.Adam(net.parameters(), lr=params['learning_rate'])

    if LOAD_SAVED_MODEL:
        net, optimizer, _ = load_ckp(MODEL_PATH_FINAL, net, optimizer)

    # The target network is built after the restore so it starts from the
    # loaded weights.
    tgt_net = ptan.agent.TargetNet(net)
    agent = ptan.agent.DQNAgent(lambda x: net.qvals(x), ptan.actions.ArgmaxActionSelector(), device=device)
    # change the step_counts to change multi step prediction
    exp_source = ptan.experience.ExperienceSourceFirstLast(env, agent, gamma=params['gamma'], steps_count=REWARD_STEPS)
    buffer = ptan.experience.PrioritizedReplayBuffer(exp_source, params['replay_size'], PRIO_REPLAY_ALPHA)

    today = datetime.datetime.now()
    todays_date_full = str(today.year) + "_" + str(today.month) + "_" + str(today.day) + "_"
    todays_date_full += str(today.hour) + "_" + str(today.minute) + "_" + str(today.second)
    folder_name = todays_date_full +"_"+experiment_name
    results_dir = current_path + "/results/" + folder_name
    results_dir_weights = results_dir + "/weights"
    os.makedirs(results_dir)
    os.makedirs(results_dir_weights)

    frame_idx = 0
    beta = BETA_START
    best_mean_reward = 0.0
    eval_states = None
    with common.RewardTracker(writer, params['stop_reward']) as reward_tracker:
        while frame_idx < total_steps:
            frame_idx += 1
            buffer.populate(1)
            # Linearly anneal the importance-sampling exponent towards 1.0.
            beta = min(1.0, BETA_START + frame_idx * (1.0 - BETA_START) / BETA_FRAMES)
            new_rewards = exp_source.pop_total_rewards()
            if new_rewards:
                # start saving the model after actual training begins
                if frame_idx > 100:
                    if best_mean_reward is None or best_mean_reward < reward_tracker.mean_reward:
                        torch.save(net.state_dict(),
                                   SAVE_NAME + "-best.dat")

                        if best_mean_reward is not None:
                            print("Best mean reward updated %.3f -> %.3f, model saved" % \
                                  (best_mean_reward, reward_tracker.mean_reward))
                        if reward_tracker.mean_reward != 0:
                            best_mean_reward = reward_tracker.mean_reward

                if reward_tracker.reward(new_rewards[0], frame_idx):
                    break

            if len(buffer) < params['replay_initial']:
                continue
            if eval_states is None:
                eval_states, _, _ = buffer.sample(STATES_TO_EVALUATE, beta)
                # np.asarray avoids copies where possible and, unlike
                # np.array(copy=False), does not raise under NumPy >= 2 when
                # a copy is unavoidable.
                eval_states = [np.asarray(transition.state) for transition in eval_states]
                eval_states = np.asarray(eval_states)

            optimizer.zero_grad()
            batch, batch_indices, batch_weights = buffer.sample(params['batch_size'], beta)
            loss_v, sample_prios_v = calc_loss(batch, batch_weights, net, tgt_net.target_model,
                                               params['gamma'] ** REWARD_STEPS, device=device)

            if frame_idx % 5000 == 0:
                checkpoint = ({
                    'model': net.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'loss': loss_v,
                    'num_step': frame_idx
                })
                torch.save(checkpoint, results_dir_weights + "/rainbow" + str(frame_idx) + "step.dat")

                # Save network parameters as histogram
                for name, param in net.named_parameters():
                    writer.add_histogram(name, param.clone().cpu().data.numpy(), frame_idx)
            loss_v.backward()
            optimizer.step()
            buffer.update_priorities(batch_indices, sample_prios_v.data.cpu().numpy())

            if frame_idx % params['target_net_sync'] == 0:
                tgt_net.sync()

            if logger:
                logger.log_scalar("loss", loss_v.item())
                logger.log_scalar("mean_reward", reward_tracker.mean_reward)
Example #5
0
def main(ckp_path=None):
    """Train (or re-train) the image classifier and save a checkpoint.

    Args:
        ckp_path (str, optional): checkpoint path. When ``None`` the model is
            built from scratch with ``initialize_model``. Otherwise training
            resumes from the model returned by ``load_ckp(ckp_path)``, and its
            ``best_score`` attribute is the score to beat.

    Returns:
        The model with the weights of the best validation epoch loaded. The
        same weights are saved via ``save_ckp`` together with training
        metadata.
    """
    cli_args = get_train_args(__author__, __version__)

    # Variables
    data_dir = cli_args.data_dir
    save_dir = cli_args.save_dir
    file_name = cli_args.file_name
    use_gpu = cli_args.use_gpu

    # LOAD DATA
    data_loaders = load_data(data_dir, config.IMG_SIZE, config.BATCH_SIZE)

    # BUILD MODEL
    if ckp_path is None:
        model = initialize_model(model_name=config.MODEL_NAME,
                                 num_classes=config.NO_OF_CLASSES,
                                 feature_extract=True,
                                 use_pretrained=True)
    else:
        model = load_ckp(ckp_path)

    # If the user wants the gpu mode, check if cuda is available
    if use_gpu and not torch.cuda.is_available():
        print("GPU mode is not available, using CPU...")
        use_gpu = False

    # Honor the user's choice: CUDA only when requested AND available
    # (use_gpu was reset above when it is not).
    device = torch.device('cuda' if use_gpu else 'cpu')

    # MOVE MODEL TO AVAILABLE DEVICE
    model.to(device)

    # DEFINE OPTIMIZER
    optimizer = optimizer_fn(model_name=config.MODEL_NAME,
                             model=model,
                             lr_rate=config.LR_RATE)

    # DEFINE SCHEDULER: reduce the LR when the validation loss plateaus.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           mode="min",
                                                           patience=5,
                                                           factor=0.3,
                                                           verbose=True)

    # DEFINE LOSS FUNCTION
    criterion = loss_fn()

    # Snapshot of the current weights; replaced whenever validation improves.
    best_model_wts = copy.deepcopy(model.state_dict())

    # BEST VALIDATION SCORE
    if ckp_path is None:
        best_score = -1  # IF MODEL IS TRAIN FROM SCRATCH
    else:
        best_score = model.best_score  # IF MODEL IS RE-TRAIN

    # NO OF ITERATION
    no_epochs = config.EPOCHS
    # KEEP TRACK OF LOSS AND ACCURACY IN EACH EPOCH
    stats = {
        'train_losses': [],
        'valid_losses': [],
        'train_accuracies': [],
        'valid_accuracies': []
    }

    print("Models's Training Start......")

    for epoch in range(1, no_epochs + 1):
        train_loss, train_score = train_fn(data_loaders,
                                           model,
                                           optimizer,
                                           criterion,
                                           device,
                                           phase='train')
        # Validate on the same device the model lives on.
        val_loss, val_score = eval_fn(data_loaders,
                                      model,
                                      criterion,
                                      device=device,
                                      phase='valid')
        scheduler.step(val_loss)

        # SAVE MODEL'S WEIGHTS IF MODEL' VALIDATION ACCURACY IS INCREASED
        if val_score > best_score:
            print(
                'Validation score increased ({:.6f} --> {:.6f}).  Saving model ...'
                .format(best_score, val_score))
            best_score = val_score
            best_model_wts = copy.deepcopy(
                model.state_dict())  # Saving the best model' weights

        # MAKE A RECORD OF AVERAGE LOSSES AND ACCURACY IN EACH EPOCH FOR PLOTTING
        stats['train_losses'].append(train_loss)
        stats['valid_losses'].append(val_loss)
        stats['train_accuracies'].append(train_score)
        stats['valid_accuracies'].append(val_score)

        # PRINT TRAINING AND VALIDATION LOSSES/ACCURACIES AFTER EACH EPOCH
        epoch_len = len(str(no_epochs))
        print_msg = (f'[{epoch:>{epoch_len}}/{no_epochs:>{epoch_len}}] ' +
                     '\t' + f'train_loss: {train_loss:.5f} ' + '\t' +
                     f'train_score: {train_score:.5f} ' + '\t' +
                     f'valid_loss: {val_loss:.5f} ' + '\t' +
                     f'valid_score: {val_score:.5f}')
        print(print_msg)

    # load best model weights
    model.load_state_dict(best_model_wts)

    # create checkpoint variable and add important data
    model.class_to_idx = data_loaders['train'].dataset.class_to_idx
    model.best_score = best_score
    model.model_name = config.MODEL_NAME
    checkpoint = {
        'epoch': no_epochs,
        'lr_rate': config.LR_RATE,
        'model_name': config.MODEL_NAME,
        'batch_size': config.BATCH_SIZE,
        'valid_score': best_score,
        'optimizer': optimizer.state_dict(),
        'state_dict': model.state_dict(),
        'class_to_idx': model.class_to_idx
    }

    # SAVE CHECKPOINT
    save_ckp(checkpoint, save_dir, file_name)

    print("Models's Training is Successfull......")

    return model
from modules.models import get_model
from utils import load_ckp
import torch.onnx
import onnx
import torch

OPSET_VERSION = 8  # ONNX opset version handed to torch.onnx.export below
# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# NOTE(review): `cfg` is neither defined nor imported in this snippet;
# presumably a project config object providing model_name, embeddings_size
# and the weight paths — confirm.
model = get_model(cfg.model_name, embeddings_size=cfg.embeddings_size)
model = model.to(device)
# NOTE(review): presumably switches the EfficientNet backbone to a Swish
# implementation that ONNX export can trace — confirm in modules.models.
model.effnet.set_swish(False)

print(f'Load checkpoint : {cfg.WEIGHTS_LOAD_PATH}')
# The weights path is resolved relative to the parent directory ('../').
# load_ckp returns a 6-tuple here; only the model is kept.
# NOTE(review): remove_module=True presumably strips a DataParallel
# 'module.' prefix from state-dict keys — confirm in utils.load_ckp.
model, _, _, _, _, _ = load_ckp('../' + cfg.WEIGHTS_LOAD_PATH,
                                model,
                                remove_module=True)
# Put the model in eval mode before export.
model.eval()

with torch.no_grad():
    # Input to the model
    x = torch.randn(10, 3, 48, 48).to(device).float()
    print('Start onnx conversion')
    print(f'Save model to : {cfg.WEIGHTS_SAVE_PATH}')
    torch.onnx.export(model,
                      x,
                      os.path.join('..', cfg.WEIGHTS_SAVE_PATH, 'model.onnx'),
                      opset_version=OPSET_VERSION,
                      verbose=False,
                      export_params=True,
                      input_names=['input'],
Example #7
0
def train(args):
    """Train a convolutional-LSTM (optionally U-Net + LSTM) segmentation model.

    Settings come from two places:

    * the ``args`` namespace: model choice, loss, backbones, epochs,
      distributed rank/world size, and so on;
    * the ``PARAMETERS``/``PATH`` sections of ``config.yaml``.

    Each epoch does the following:

    * trains on the training split;
    * runs ``validation`` whenever the global step reaches a multiple of
      ``nb_batches``;
    * steps the plateau scheduler on the mean training loss;
    * saves a checkpoint through ``save_ckp``;
    * writes the train/validation records to JSON.

    Training stops early when ``save_ckp`` returns a truthy value, which is
    logged as "validation loss has not improved".

    Args:
        args: parsed command-line arguments (attributes read below).

    Raises:
        NotImplementedError: for an unknown ``model_type``, ``LSTM_variant``,
            ``loss_fn`` or optimizer name.
    """

    torch.cuda.manual_seed(1)
    torch.manual_seed(1)

    # user defined parameters
    model_name = args.model_name
    model_type = args.model_type
    lstm_backbone = args.lstmbase
    unet_backbone = args.unetbase
    layer_num = args.layer_num
    nb_shortcut = args.nb_shortcut
    loss_fn = args.loss_fn
    world_size = args.world_size
    rank = args.rank
    base_channel = args.base_channels
    crop_size = args.crop_size
    ignore_idx = args.ignore_idx
    return_sequence = args.return_sequence
    variant = args.LSTM_variant
    epochs = args.epoch
    is_pretrain = args.is_pretrain

    # system setup parameters
    config_file = 'config.yaml'
    config = load_config(config_file)
    labels = config['PARAMETERS']['labels']
    root_path = config['PATH']['model_root']
    model_dir = config['PATH']['save_ckp']
    best_dir = config['PATH']['save_best_model']

    input_modalites = int(config['PARAMETERS']['input_modalites'])
    output_channels = int(config['PARAMETERS']['output_channels'])
    batch_size = int(config['PARAMETERS']['batch_size'])
    is_best = bool(config['PARAMETERS']['is_best'])
    is_resume = bool(config['PARAMETERS']['resume'])
    patience = int(config['PARAMETERS']['patience'])
    time_step = int(config['PARAMETERS']['time_step'])
    num_workers = int(config['PARAMETERS']['num_workers'])
    early_stop_patience = int(config['PARAMETERS']['early_stop_patience'])
    # Learning rates are fractional; int() would truncate e.g. 0.01 to 0.
    lr = float(config['PARAMETERS']['lr'])
    optimizer_name = config['PARAMETERS']['optimizer']
    connect = config['PARAMETERS']['connect']
    conv_type = config['PARAMETERS']['lstm_convtype']

    # build up dirs
    model_path = os.path.join(root_path, model_dir)
    best_path = os.path.join(root_path, best_dir)
    intermidiate_data_save = os.path.join(root_path, 'train_newdata',
                                          model_name)
    train_info_file = os.path.join(intermidiate_data_save,
                                   '{}_train_info.json'.format(model_name))
    log_path = os.path.join(root_path, 'logfiles')

    if not os.path.exists(model_path):
        os.mkdir(model_path)
    if not os.path.exists(best_path):
        os.mkdir(best_path)
    if not os.path.exists(intermidiate_data_save):
        os.makedirs(intermidiate_data_save)
    if not os.path.exists(log_path):
        os.mkdir(log_path)

    log_name = model_name + '_' + config['PATH']['log_file']
    logger = logfile(os.path.join(log_path, log_name))
    logger.info('labels {} are ignored'.format(ignore_idx))
    logger.info('Dataset is loading ...')
    writer = SummaryWriter('ProcessVisu/%s' % model_name)

    # load training set and validation set
    data_class = data_split()
    train, val, test = data_construction(data_class)
    train_dict = time_parser(train, time_patch=time_step)
    val_dict = time_parser(val, time_patch=time_step)

    # LSTM initialization

    if model_type == 'LSTM':
        net = LSTMSegNet(lstm_backbone=lstm_backbone,
                         input_dim=input_modalites,
                         output_dim=output_channels,
                         hidden_dim=base_channel,
                         kernel_size=3,
                         num_layers=layer_num,
                         conv_type=conv_type,
                         return_sequence=return_sequence)
    elif model_type == 'UNet_LSTM':
        if variant == 'back':
            net = BackLSTM(input_dim=input_modalites,
                           hidden_dim=base_channel,
                           output_dim=output_channels,
                           kernel_size=3,
                           num_layers=layer_num,
                           conv_type=conv_type,
                           lstm_backbone=lstm_backbone,
                           unet_module=unet_backbone,
                           base_channel=base_channel,
                           return_sequence=return_sequence,
                           is_pretrain=is_pretrain)
            logger.info(
                'the pretrained status of backbone is {}'.format(is_pretrain))
        elif variant == 'center':
            net = CenterLSTM(input_modalites=input_modalites,
                             output_channels=output_channels,
                             base_channel=base_channel,
                             num_layers=layer_num,
                             conv_type=conv_type,
                             return_sequence=return_sequence,
                             is_pretrain=is_pretrain)
        elif variant == 'bicenter':
            net = BiCenterLSTM(input_modalites=input_modalites,
                               output_channels=output_channels,
                               base_channel=base_channel,
                               num_layers=layer_num,
                               connect=connect,
                               conv_type=conv_type,
                               return_sequence=return_sequence,
                               is_pretrain=is_pretrain)
        elif variant == 'directcenter':
            net = DirectCenterLSTM(input_modalites=input_modalites,
                                   output_channels=output_channels,
                                   base_channel=base_channel,
                                   num_layers=layer_num,
                                   conv_type=conv_type,
                                   return_sequence=return_sequence,
                                   is_pretrain=is_pretrain)
        elif variant == 'bidirectcenter':
            net = BiDirectCenterLSTM(input_modalites=input_modalites,
                                     output_channels=output_channels,
                                     base_channel=base_channel,
                                     num_layers=layer_num,
                                     connect=connect,
                                     conv_type=conv_type,
                                     return_sequence=return_sequence,
                                     is_pretrain=is_pretrain)
        elif variant == 'rescenter':
            net = ResCenterLSTM(input_modalites=input_modalites,
                                output_channels=output_channels,
                                base_channel=base_channel,
                                num_layers=layer_num,
                                conv_type=conv_type,
                                return_sequence=return_sequence,
                                is_pretrain=is_pretrain)
        elif variant == 'birescenter':
            net = BiResCenterLSTM(input_modalites=input_modalites,
                                  output_channels=output_channels,
                                  base_channel=base_channel,
                                  num_layers=layer_num,
                                  connect=connect,
                                  conv_type=conv_type,
                                  return_sequence=return_sequence,
                                  is_pretrain=is_pretrain)
        elif variant == 'shortcut':
            net = ShortcutLSTM(input_modalites=input_modalites,
                               output_channels=output_channels,
                               base_channel=base_channel,
                               num_layers=layer_num,
                               num_connects=nb_shortcut,
                               conv_type=conv_type,
                               return_sequence=return_sequence,
                               is_pretrain=is_pretrain)
        else:
            # Without this, `net` would be unbound and fail much later.
            raise NotImplementedError()
    else:
        raise NotImplementedError()

    # loss and optimizer setup
    if loss_fn == 'Dice':
        criterion = DiceLoss(labels=labels, ignore_idx=ignore_idx)
    elif loss_fn == 'GDice':
        criterion = GneralizedDiceLoss(labels=labels)
    elif loss_fn == 'WCE':
        criterion = WeightedCrossEntropyLoss(labels=labels)
    else:
        raise NotImplementedError()

    if optimizer_name == 'adam':
        # NOTE(review): Adam deliberately uses a fixed lr=0.001 rather than
        # the configured `lr`; kept as-is.
        optimizer = optim.Adam(net.parameters(), lr=0.001)
    elif optimizer_name == 'sgd':
        optimizer = optim.SGD(net.parameters(), momentum=0.9, lr=lr)
    else:
        # Previously the config string itself leaked into the scheduler.
        raise NotImplementedError()
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     verbose=True,
                                                     patience=patience)

    # device setup
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    if torch.cuda.device_count() > 1:
        torch.distributed.init_process_group(
            backend='nccl',
            init_method='tcp://127.0.0.1:38366',
            rank=rank,
            world_size=world_size)
    if distributed_is_initialized():
        print('distributed is initialized')
        net.to(device)
        net = nn.parallel.DistributedDataParallel(net,
                                                  find_unused_parameters=True)
    else:
        print('data parallel')
        net = nn.DataParallel(net)
        net.to(device)

    min_loss = float('Inf')
    early_stop_count = 0
    global_step = 0
    start_epoch = 0
    start_loss = 0
    train_info = {
        'train_loss': [],
        'val_loss': [],
        'label_0_acc': [],
        'label_1_acc': [],
        'label_2_acc': [],
        'label_3_acc': [],
        'label_4_acc': []
    }

    if is_resume:
        try:
            # open previous check points
            ckp_path = os.path.join(model_path,
                                    '{}_model_ckp.pth.tar'.format(model_name))
            net, optimizer, scheduler, start_epoch, min_loss, start_loss = load_ckp(
                ckp_path, net, optimizer, scheduler)

            # open previous training records
            with open(train_info_file) as f:
                train_info = json.load(f)

            logger.info(
                'Training loss from last time is {}'.format(start_loss) +
                '\n' +
                'Mininum training loss from last time is {}'.format(min_loss))
            logger.info(
                'Training accuracies from last time are: label 0: {}, label 1: {}, label 2: {}, label 3: {}, label 4: {}'
                .format(train_info['label_0_acc'][-1],
                        train_info['label_1_acc'][-1],
                        train_info['label_2_acc'][-1],
                        train_info['label_3_acc'][-1],
                        train_info['label_4_acc'][-1]))

        except Exception:
            # Best-effort resume: any failure falls back to a fresh start,
            # but Ctrl-C / SystemExit are no longer swallowed.
            logger.warning(
                'No checkpoint available, strat training from scratch')

    for epoch in range(start_epoch, epochs):

        train_set = data_loader(train_dict,
                                batch_size=batch_size,
                                key='train',
                                num_works=num_workers,
                                time_step=time_step,
                                patch=crop_size,
                                model_type='RNN')
        n_train = len(train_set)

        val_set = data_loader(val_dict,
                              batch_size=batch_size,
                              key='val',
                              num_works=num_workers,
                              time_step=time_step,
                              patch=crop_size,
                              model_type='CNN')
        n_val = len(val_set)

        logger.info('Dataset loading finished!')

        nb_batches = np.ceil(n_train / batch_size)
        n_total = n_train + n_val
        logger.info(
            '{} images will be used in total, {} for trainning and {} for validation'
            .format(n_total, n_train, n_val))

        train_loader = train_set.load()

        # setup to train mode
        net.train()
        running_loss = 0
        dice_score_label_0 = 0
        dice_score_label_1 = 0
        dice_score_label_2 = 0
        dice_score_label_3 = 0
        dice_score_label_4 = 0

        logger.info('Training epoch {} will begin'.format(epoch + 1))

        with tqdm(total=n_train,
                  desc=f'Epoch {epoch+1}/{epochs}',
                  unit='patch') as pbar:

            for i, data in enumerate(train_loader, 0):

                # i : patient
                images, segs = data['image'].to(device), data['seg'].to(device)

                # Clear gradients from the previous step; otherwise they
                # accumulate across every iteration of training.
                optimizer.zero_grad()
                outputs = net(images)
                loss = criterion(outputs, segs)
                loss.backward()
                optimizer.step()

                running_loss += loss.detach().item()

                # Fold any leading (batch, time) dims into one so the
                # channel dim sits at index 1 for torch.max below.
                outputs = outputs.view(-1, outputs.shape[-4],
                                       outputs.shape[-3], outputs.shape[-2],
                                       outputs.shape[-1])
                segs = segs.view(-1, segs.shape[-3], segs.shape[-2],
                                 segs.shape[-1])
                _, preds = torch.max(outputs.data, 1)
                dice_score = dice(preds.data.cpu(),
                                  segs.data.cpu(),
                                  ignore_idx=None)

                dice_score_label_0 += dice_score['bg']
                dice_score_label_1 += dice_score['csf']
                dice_score_label_2 += dice_score['gm']
                dice_score_label_3 += dice_score['wm']
                dice_score_label_4 += dice_score['tm']

                # show progress bar
                pbar.set_postfix(
                    **{
                        'training loss': loss.detach().item(),
                        'Training accuracy': dice_score['avg']
                    })
                pbar.update(images.shape[0])

                global_step += 1
                # NOTE(review): validation only runs when global_step hits a
                # multiple of nb_batches; if an epoch yields fewer iterations,
                # val_loss below may be stale or undefined — confirm.
                if global_step % nb_batches == 0:
                    net.eval()
                    val_loss, val_acc, val_info = validation(net,
                                                             val_set,
                                                             criterion,
                                                             device,
                                                             batch_size,
                                                             ignore_idx=None,
                                                             name=model_name,
                                                             epoch=epoch + 1)
                    net.train()

        train_info['train_loss'].append(running_loss / nb_batches)
        train_info['val_loss'].append(val_loss)
        train_info['label_0_acc'].append(dice_score_label_0 / nb_batches)
        train_info['label_1_acc'].append(dice_score_label_1 / nb_batches)
        train_info['label_2_acc'].append(dice_score_label_2 / nb_batches)
        train_info['label_3_acc'].append(dice_score_label_3 / nb_batches)
        train_info['label_4_acc'].append(dice_score_label_4 / nb_batches)

        # Step the plateau scheduler on the mean training loss of this epoch.
        scheduler.step(running_loss / nb_batches)
        logger.info('Epoch: {}, LR: {}'.format(
            epoch + 1, optimizer.param_groups[0]['lr']))

        # Track the best (lowest) mean training loss for checkpointing and
        # the early-stopping counter.
        if min_loss > running_loss / nb_batches:
            min_loss = running_loss / nb_batches
            is_best = True
            early_stop_count = 0
        else:
            is_best = False
            early_stop_count += 1

        state = {
            'epoch': epoch + 1,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'loss': running_loss / nb_batches,
            'min_loss': min_loss
        }
        # save_ckp's return value is used below as the early-stop signal.
        verbose = save_ckp(state,
                           is_best,
                           early_stop_count=early_stop_count,
                           early_stop_patience=early_stop_patience,
                           save_model_dir=model_path,
                           best_dir=best_path,
                           name=model_name)

        # summarize the training results of this epoch
        logger.info('The average training loss for this epoch is {}'.format(
            running_loss / nb_batches))
        logger.info('The best training loss till now is {}'.format(min_loss))
        logger.info(
            'Validation dice loss: {}; Validation (avg) accuracy of the last timestep: {}'
            .format(val_loss, val_acc))

        # save the training info every epoch
        logger.info('Writing the training info into file ...')
        val_info_file = os.path.join(intermidiate_data_save,
                                     '{}_val_info.json'.format(model_name))
        with open(train_info_file, 'w') as fp:
            json.dump(train_info, fp)
        with open(val_info_file, 'w') as fp:
            json.dump(val_info, fp)

        for name, layer in net.named_parameters():
            if layer.requires_grad:
                # Parameters that received no gradient (possible here, since
                # DDP is built with find_unused_parameters=True) have
                # grad None; skip their gradient histogram.
                if layer.grad is not None:
                    writer.add_histogram(name + '_grad',
                                         layer.grad.cpu().data.numpy(), epoch)
                writer.add_histogram(name + '_data',
                                     layer.cpu().data.numpy(), epoch)
        if verbose:
            logger.info(
                'The validation loss has not improved for {} epochs, training will stop here.'
                .format(early_stop_patience))
            break

    loss_plot(train_info_file, name=model_name)
    logger.info('finish training!')

    return
Example #8
0
def main():
    """Train a CycleGAN with generators G (A -> B) and F (B -> A).

    Discriminator Dx scores domain-A images and Dy scores domain-B images.
    All settings are read from the module-level ``opt`` namespace. When
    ``opt.checkpoint`` is set, networks, optimizers and the epoch counter are
    restored with ``load_ckp`` and training continues for ``opt.epochs``
    further epochs. A sample image grid is written every ``opt.save_step``
    steps and a full checkpoint at the end of every epoch.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpus
    start_epoch = 0
    train_image_dataset = image_preprocessing(opt.dataset, 'train')
    data_loader = DataLoader(train_image_dataset, batch_size=opt.batch_size,
                             shuffle=True, num_workers=opt.num_workers)
    criterion = least_squares
    euclidean_l1 = nn.L1Loss()

    G = Generator(ResidualBlock, layer_count=9)  # domain A -> domain B
    F = Generator(ResidualBlock, layer_count=9)  # domain B -> domain A
    Dx = Discriminator()  # real vs. fake for domain A
    Dy = Discriminator()  # real vs. fake for domain B

    G_optimizer = optim.Adam(G.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
    F_optimizer = optim.Adam(F.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
    Dx_optimizer = optim.Adam(Dx.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
    Dy_optimizer = optim.Adam(Dy.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))

    if torch.cuda.is_available():
        G = nn.DataParallel(G).cuda()
        F = nn.DataParallel(F).cuda()
        Dx = nn.DataParallel(Dx).cuda()
        Dy = nn.DataParallel(Dy).cuda()

    if opt.checkpoint is not None:
        (G, F, Dx, Dy, G_optimizer, F_optimizer, Dx_optimizer, Dy_optimizer,
         start_epoch) = load_ckp(opt.checkpoint, G, F, Dx, Dy, G_optimizer,
                                 F_optimizer, Dx_optimizer, Dy_optimizer)

    print('[Start] : Cycle GAN Training')

    logger = Logger(opt.epochs, len(data_loader), image_step=10)

    # Epochs are numbered from start_epoch + 1 so resumed runs keep counting.
    for epoch in range(start_epoch + 1, start_epoch + opt.epochs + 1):
        print("Epoch[{epoch}] : Start".format(epoch=epoch))

        for step, data in enumerate(data_loader):
            real_A = to_variable(data['A'])
            real_B = to_variable(data['B'])

            fake_B = G(real_A)
            fake_A = F(real_B)

            # ---- Discriminator updates ----
            # The fakes are detached: discriminator losses must not push
            # gradients into G/F. Those gradients were discarded by the
            # generator zero_grad() calls below anyway, and detaching means
            # the generator graph no longer has to be retained here.
            Dx_optimizer.zero_grad()
            Dx_loss = (patch_loss(criterion, Dx(real_A), True)
                       + patch_loss(criterion, Dx(fake_A.detach()), 0))
            Dx_loss.backward()
            Dx_optimizer.step()

            Dy_optimizer.zero_grad()
            Dy_loss = (patch_loss(criterion, Dy(real_B), True)
                       + patch_loss(criterion, Dy(fake_B.detach()), 0))
            Dy_loss.backward()
            Dy_optimizer.step()

            # ---- Generator updates (G and F optimised jointly) ----
            G_optimizer.zero_grad()
            F_optimizer.zero_grad()

            # Adversarial terms: each generator wants its fakes scored as real.
            G_loss = patch_loss(criterion, Dy(fake_B), True)
            F_loss = patch_loss(criterion, Dx(fake_A), True)

            # Identity terms: a generator fed an image that is already in its
            # output domain should return it unchanged (G(B) ~ B, F(A) ~ A).
            # The previous version compared real_A with F(real_B), which are
            # unrelated images.
            loss_identity = (euclidean_l1(G(real_B), real_B)
                             + euclidean_l1(F(real_A), real_A))

            # Cycle consistency: A -> B -> A and B -> A -> B reconstruct inputs.
            loss_cycle = euclidean_l1(F(fake_B), real_A) + euclidean_l1(G(fake_A), real_B)

            loss = G_loss + F_loss + opt.lamda * loss_cycle + opt.lamda * loss_identity * (0.5)

            loss.backward()
            G_optimizer.step()
            F_optimizer.step()

            if (step + 1) % opt.save_step == 0:
                print("Epoch[{epoch}]| Step [{now}/{total}]| Dx Loss: {Dx_loss}, Dy_Loss: {Dy_loss}, G_Loss: {G_loss}, F_Loss: {F_loss}".format(
                    epoch=epoch, now=step + 1, total=len(data_loader), Dx_loss=Dx_loss.item(), Dy_loss=Dy_loss.item(),
                    G_loss=G_loss.item(), F_loss=F_loss.item()))
                # 2x2 grid: top row real_A | real_B, bottom row fake_A | fake_B
                # (assumes NCHW tensors, so dim 3 is width and dim 2 is height).
                batch_image = torch.cat((torch.cat((real_A, real_B), 3), torch.cat((fake_A, fake_B), 3)), 2)

                torchvision.utils.save_image(denorm(batch_image[0]), opt.training_result + 'result_{result_name}_ep{epoch}_{step}.jpg'.format(result_name=opt.result_name, epoch=epoch, step=(step + 1) * opt.batch_size))

            # http://localhost:8097
            logger.log(
                losses={
                    'loss_G': G_loss,
                    'loss_F': F_loss,
                    'loss_identity': loss_identity,
                    'loss_cycle': loss_cycle,
                    'total_G_loss': loss,
                    'loss_Dx': Dx_loss,
                    'loss_Dy': Dy_loss,
                    'total_D_loss': (Dx_loss + Dy_loss),
                },
                images={
                    'real_A': real_A,
                    'real_B': real_B,
                    'fake_A': fake_A,
                    'fake_B': fake_B,
                },
            )

        torch.save({
            'epoch': epoch,
            'G_model': G.state_dict(),
            'G_optimizer': G_optimizer.state_dict(),
            'F_model': F.state_dict(),
            'F_optimizer': F_optimizer.state_dict(),
            'Dx_model': Dx.state_dict(),
            'Dx_optimizer': Dx_optimizer.state_dict(),
            'Dy_model': Dy.state_dict(),
            'Dy_optimizer': Dy_optimizer.state_dict(),
        }, opt.save_model + 'model_{result_name}_CycleGAN_ep{epoch}.ckp'.format(result_name=opt.result_name, epoch=epoch))
Example #9
0
def train(args):
    """Train a U-Net segmentation model configured by ``config.yaml``.

    Args:
        args: parsed command-line namespace. ``args.model_name`` names the run
            (checkpoints, log file, training-record JSON) and ``args.loss_fn``
            selects the loss: 'Dice', 'CrossEntropy', 'FocalLoss', 'Dice_CE'
            or 'Dice_FL'.

    Raises:
        NotImplementedError: if ``args.loss_fn`` is not one of the above.

    Each epoch trains on the 'train' split, validates once on the 'val' split,
    steps the LR scheduler on the average training loss, checkpoints through
    ``save_ckp`` (best model chosen by validation loss) and rewrites
    ``<model_name>_train_info.json``. Training stops early when ``save_ckp``
    returns a truthy value.
    """
    torch.cuda.manual_seed(1)
    torch.manual_seed(1)

    # user defined
    model_name = args.model_name
    model_loss_fn = args.loss_fn

    config_file = 'config.yaml'

    config = load_config(config_file)
    data_root = config['PATH']['data_root']
    labels = config['PARAMETERS']['labels']
    root_path = config['PATH']['root']
    model_dir = config['PATH']['model_path']
    best_dir = config['PATH']['best_model_path']

    data_class = config['PATH']['data_class']
    input_modalites = int(config['PARAMETERS']['input_modalites'])
    output_channels = int(config['PARAMETERS']['output_channels'])
    base_channel = int(config['PARAMETERS']['base_channels'])
    crop_size = int(config['PARAMETERS']['crop_size'])
    batch_size = int(config['PARAMETERS']['batch_size'])
    epochs = int(config['PARAMETERS']['epoch'])
    is_resume = bool(config['PARAMETERS']['resume'])
    patience = int(config['PARAMETERS']['patience'])
    ignore_idx = int(config['PARAMETERS']['ignore_index'])
    early_stop_patience = int(config['PARAMETERS']['early_stop_patience'])

    # build up dirs
    model_path = os.path.join(root_path, model_dir)
    best_path = os.path.join(root_path, best_dir)
    intermidiate_data_save = os.path.join(root_path, 'train_data', model_name)
    train_info_file = os.path.join(intermidiate_data_save,
                                   '{}_train_info.json'.format(model_name))
    log_path = os.path.join(root_path, 'logfiles')

    # makedirs also creates missing parents (os.mkdir would fail on them)
    for directory in (model_path, best_path, intermidiate_data_save, log_path):
        os.makedirs(directory, exist_ok=True)

    log_name = model_name + '_' + config['PATH']['log_file']
    logger = logfile(os.path.join(log_path, log_name))
    logger.info('Dataset is loading ...')
    # split dataset
    dir_ = os.path.join(data_root, data_class)
    data_content = train_split(dir_)

    # load training set and validation set
    train_set = data_loader(data_content=data_content,
                            key='train',
                            form='LGG',
                            crop_size=crop_size,
                            batch_size=batch_size,
                            num_works=8)
    n_train = len(train_set)
    train_loader = train_set.load()

    val_set = data_loader(data_content=data_content,
                          key='val',
                          form='LGG',
                          crop_size=crop_size,
                          batch_size=batch_size,
                          num_works=8)

    logger.info('Dataset loading finished!')

    n_val = len(val_set)
    nb_batches = np.ceil(n_train / batch_size)
    n_total = n_train + n_val
    logger.info(
        '{} images will be used in total, {} for trainning and {} for validation'
        .format(n_total, n_train, n_val))

    net = init_U_Net(input_modalites, output_channels, base_channel)

    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    net.to(device)

    if model_loss_fn == 'Dice':
        criterion = DiceLoss(labels=labels, ignore_index=ignore_idx)
    elif model_loss_fn == 'CrossEntropy':
        criterion = CrossEntropyLoss(labels=labels, ignore_index=ignore_idx)
    elif model_loss_fn == 'FocalLoss':
        criterion = FocalLoss(labels=labels, ignore_index=ignore_idx)
    elif model_loss_fn == 'Dice_CE':
        criterion = Dice_CE(labels=labels, ignore_index=ignore_idx)
    elif model_loss_fn == 'Dice_FL':
        criterion = Dice_FL(labels=labels, ignore_index=ignore_idx)
    else:
        raise NotImplementedError()

    optimizer = optim.Adam(net.parameters(), lr=1e-4, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     verbose=True,
                                                     patience=patience)

    # apex requires amp.initialize() on the bare model, before any
    # DataParallel wrapping, so the wrap happens afterwards.
    net, optimizer = amp.initialize(net, optimizer, opt_level='O1')

    if torch.cuda.device_count() > 1:
        logger.info('{} GPUs available.'.format(torch.cuda.device_count()))
        net = nn.DataParallel(net)

    min_loss = float('Inf')
    early_stop_count = 0
    start_epoch = 0
    start_loss = 0
    train_info = {
        'train_loss': [],
        'val_loss': [],
        'BG_acc': [],
        'NET_acc': [],
        'ED_acc': [],
        'ET_acc': []
    }

    if is_resume:
        try:
            # save_ckp below is given save_model_dir=model_path, so look for
            # the checkpoint there; the bare config value model_dir is not
            # joined with root_path.
            ckp_path = os.path.join(model_path,
                                    '{}_model_ckp.pth.tar'.format(model_name))
            net, optimizer, scheduler, start_epoch, min_loss, start_loss = load_ckp(
                ckp_path, net, optimizer, scheduler)

            # open previous training records
            with open(train_info_file) as f:
                train_info = json.load(f)

            logger.info(
                'Training loss from last time is {}'.format(start_loss) +
                '\n' +
                'Mininum training loss from last time is {}'.format(min_loss))

        except Exception as err:
            # Best effort: resuming is optional, but do not swallow
            # KeyboardInterrupt/SystemExit and do record why it failed.
            logger.warning(
                'No checkpoint available, strat training from scratch')
            logger.warning('Resume failed: {}'.format(err))

    # start training
    for epoch in range(start_epoch, epochs):

        # setup to train mode
        net.train()
        running_loss = 0
        dice_coeff_bg = 0
        dice_coeff_net = 0
        dice_coeff_ed = 0
        dice_coeff_et = 0

        logger.info('Training epoch {} will begin'.format(epoch + 1))

        with tqdm(total=n_train,
                  desc=f'Epoch {epoch+1}/{epochs}',
                  unit='patch') as pbar:

            for i, data in enumerate(train_loader, 0):
                images, segs = data['image'].to(device), data['seg'].to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                outputs = net(images)

                loss = criterion(outputs, segs)
                # mixed-precision backward through apex's loss scaling
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()

                optimizer.step()

                # save the output at the beginning of each epoch to visualise it
                if i == 0:
                    in_images = images.detach().cpu().numpy()[:, 0, ...]
                    in_segs = segs.detach().cpu().numpy()
                    in_pred = outputs.detach().cpu().numpy()
                    heatmap_plot(image=in_images,
                                 mask=in_segs,
                                 pred=in_pred,
                                 name=model_name,
                                 epoch=epoch + 1)

                running_loss += loss.detach().item()
                dice_score = dice_coe(outputs.detach().cpu(),
                                      segs.detach().cpu())
                dice_coeff_bg += dice_score['BG']
                dice_coeff_ed += dice_score['ED']
                dice_coeff_et += dice_score['ET']
                dice_coeff_net += dice_score['NET']

                # show progress bar
                pbar.set_postfix(
                    **{
                        'Training loss': loss.detach().item(),
                        'Training (avg) accuracy': dice_score['avg']
                    })
                pbar.update(images.shape[0])

        # Validate once per epoch. This used to be triggered inside the batch
        # loop by a global step counter, which left val_loss unset or stale
        # whenever the loader yielded a different number of batches than
        # nb_batches.
        net.eval()
        val_loss, val_acc = validation(net, val_set, criterion,
                                       device, batch_size)

        avg_train_loss = running_loss / nb_batches
        train_info['train_loss'].append(avg_train_loss)
        train_info['val_loss'].append(val_loss)
        train_info['BG_acc'].append(dice_coeff_bg / nb_batches)
        train_info['NET_acc'].append(dice_coeff_net / nb_batches)
        train_info['ED_acc'].append(dice_coeff_ed / nb_batches)
        train_info['ET_acc'].append(dice_coeff_et / nb_batches)

        scheduler.step(avg_train_loss)

        # best model and early stopping are judged on validation loss
        if min_loss > val_loss:
            min_loss = val_loss
            is_best = True
            early_stop_count = 0
        else:
            is_best = False
            early_stop_count += 1

        state = {
            'epoch': epoch + 1,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'loss': avg_train_loss,
            'min_loss': min_loss
        }
        verbose = save_ckp(state,
                           is_best,
                           early_stop_count=early_stop_count,
                           early_stop_patience=early_stop_patience,
                           save_model_dir=model_path,
                           best_dir=best_path,
                           name=model_name)

        logger.info('The average training loss for this epoch is {}'.format(
            avg_train_loss))
        logger.info(
            'Validation dice loss: {}; Validation (avg) accuracy: {}'.format(
                val_loss, val_acc))
        logger.info('The best validation loss till now is {}'.format(min_loss))

        # save the training info every epoch
        logger.info('Writing the training info into file ...')
        with open(train_info_file, 'w') as fp:
            json.dump(train_info, fp)

        loss_plot(train_info_file, name=model_name)

        if verbose:
            logger.info(
                'The validation loss has not improved for {} epochs, training will stop here.'
                .format(early_stop_patience))
            break

    logger.info('finish training!')
Example #10
0
def train(args):
    """Train a segmentation CNN on time-step patches.

    Settings come from ``args`` (model name and type, loss, distributed rank
    and world size, base channels, crop size, ignore index, number of epochs)
    and from ``config.yaml``. Every epoch draws a fresh set of train/val
    patches, trains, validates once, steps the LR scheduler, checkpoints
    through ``save_ckp`` (best model chosen by average training loss) and
    writes the training/validation records as JSON plus TensorBoard
    histograms.

    Raises:
        NotImplementedError: for an unknown model type, loss or optimizer.
    """
    torch.cuda.manual_seed(1)
    torch.manual_seed(1)

    # user defined
    model_name = args.model_name
    model_type = args.model_type
    loss_func = args.loss
    world_size = args.world_size
    rank = args.rank
    base_channel = args.base_channels
    crop_size = args.crop_size
    ignore_idx = args.ignore_idx
    epochs = args.epoch

    # system setup
    config_file = 'config.yaml'
    config = load_config(config_file)
    labels = config['PARAMETERS']['labels']
    root_path = config['PATH']['model_root']
    model_dir = config['PATH']['save_ckp']
    best_dir = config['PATH']['save_best_model']

    output_channels = int(config['PARAMETERS']['output_channels'])
    batch_size = int(config['PARAMETERS']['batch_size'])
    is_resume = bool(config['PARAMETERS']['resume'])
    patience = int(config['PARAMETERS']['patience'])
    time_step = int(config['PARAMETERS']['time_step'])
    num_workers = int(config['PARAMETERS']['num_workers'])
    early_stop_patience = int(config['PARAMETERS']['early_stop_patience'])
    pad_method = config['PARAMETERS']['pad_method']
    # Learning rates are fractional: int() truncated 0.01 to 0 and raised on
    # strings such as '1e-3'.
    lr = float(config['PARAMETERS']['lr'])
    optimizer_name = config['PARAMETERS']['optimizer']
    softmax = True
    modalities = ['flair', 't1', 't1gd', 't2']
    input_modalites = len(modalities)

    # build up dirs
    model_path = os.path.join(root_path, model_dir)
    best_path = os.path.join(root_path, best_dir)
    intermidiate_data_save = os.path.join(root_path, 'train_newdata',
                                          model_name)
    train_info_file = os.path.join(intermidiate_data_save,
                                   '{}_train_info.json'.format(model_name))
    log_path = os.path.join(root_path, 'logfiles')

    # makedirs also creates missing parents (os.mkdir would fail on them)
    for directory in (model_path, best_path, intermidiate_data_save, log_path):
        os.makedirs(directory, exist_ok=True)

    log_name = model_name + '_' + config['PATH']['log_file']
    logger = logfile(os.path.join(log_path, log_name))
    logger.info('Dataset is loading ...')
    writer = SummaryWriter('ProcessVisu/%s' % model_name)

    logger.info('patch size: {}'.format(crop_size))

    # load training set and validation set
    data_class = data_split()
    train, val, test = data_construction(data_class)
    train_dict = time_parser(train, time_patch=time_step)
    val_dict = time_parser(val, time_patch=time_step)

    if model_type == 'UNet':
        net = init_U_Net(input_modalites, output_channels, base_channel,
                         pad_method, softmax)
    elif model_type == 'ResUNet':
        net = ResUNet(input_modalites, output_channels, base_channel,
                      pad_method, softmax)
    elif model_type == 'DResUNet':
        net = DResUNet(input_modalites, output_channels, base_channel,
                       pad_method, softmax)
    elif model_type == 'direct_concat':
        net = U_Net_direct_concat(input_modalites, output_channels,
                                  base_channel, pad_method, softmax)
    elif model_type == 'Inception':
        net = Inception_UNet(input_modalites, output_channels, base_channel,
                             softmax)
    elif model_type == 'Simple_Inception':
        net = Simplified_Inception_UNet(input_modalites, output_channels,
                                        base_channel, softmax)
    else:
        # previously fell through and failed later with a NameError on `net`
        raise NotImplementedError(
            'Unknown model type: {}'.format(model_type))

    # device setup
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    net.to(device)

    # print model structure
    summary(net, input_size=(input_modalites, crop_size, crop_size, crop_size))
    dummy_input = torch.rand(1, input_modalites, crop_size, crop_size,
                             crop_size).to(device)
    writer.add_graph(net, (dummy_input, ))

    # loss and optimizer setup
    if loss_func == 'Dice' and softmax:
        criterion = DiceLoss(labels=labels, ignore_idx=ignore_idx)
    elif loss_func == 'GDice' and softmax:
        criterion = GneralizedDiceLoss(labels=labels)
    elif loss_func == 'CrossEntropy':
        criterion = WeightedCrossEntropyLoss(labels=labels)
        if not softmax:
            criterion = nn.CrossEntropyLoss().cuda()
    else:
        raise NotImplementedError()

    if optimizer_name == 'adam':
        optimizer = optim.Adam(net.parameters())
    elif optimizer_name == 'sgd':
        optimizer = optim.SGD(net.parameters(),
                              momentum=0.9,
                              lr=lr,
                              weight_decay=1e-5)
    else:
        # previously the config string itself reached the scheduler and
        # failed there with an unrelated-looking TypeError
        raise NotImplementedError(
            'Unknown optimizer: {}'.format(optimizer_name))
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     verbose=True,
                                                     patience=patience)

    if torch.cuda.device_count() > 1:
        logger.info('{} GPUs avaliable'.format(torch.cuda.device_count()))
        torch.distributed.init_process_group(
            backend='nccl',
            init_method='tcp://127.0.0.1:38366',
            rank=rank,
            world_size=world_size)
    if distributed_is_initialized():
        logger.info('distributed is initialized')
        net.to(device)
        net = nn.parallel.DistributedDataParallel(net)
    else:
        logger.info('data parallel')
        net = nn.DataParallel(net)
        net.to(device)

    min_loss = float('Inf')
    early_stop_count = 0
    start_epoch = 0
    start_loss = 0
    train_info = {
        'train_loss': [],
        'val_loss': [],
        'label_0_acc': [],
        'label_1_acc': [],
        'label_2_acc': [],
        'label_3_acc': [],
        'label_4_acc': []
    }

    if is_resume:
        try:
            # open previous check points
            ckp_path = os.path.join(model_path,
                                    '{}_model_ckp.pth.tar'.format(model_name))
            net, optimizer, scheduler, start_epoch, min_loss, start_loss = load_ckp(
                ckp_path, net, optimizer, scheduler)
            # open previous training records
            with open(train_info_file) as f:
                train_info = json.load(f)

            logger.info(
                'Training loss from last time is {}'.format(start_loss) +
                '\n' +
                'Mininum training loss from last time is {}'.format(min_loss))
            logger.info(
                'Training accuracies from last time are: label 0: {}, label 1: {}, label 2: {}, label 3: {}, label 4: {}'
                .format(train_info['label_0_acc'][-1],
                        train_info['label_1_acc'][-1],
                        train_info['label_2_acc'][-1],
                        train_info['label_3_acc'][-1],
                        train_info['label_4_acc'][-1]))

        except Exception as err:
            # Best effort: resuming is optional, but do not swallow
            # KeyboardInterrupt/SystemExit and do record why it failed.
            logger.warning(
                'No checkpoint available, strat training from scratch')
            logger.warning('Resume failed: {}'.format(err))

    # start training
    for epoch in range(start_epoch, epochs):

        # every epoch generate a new set of images
        train_set = data_loader(train_dict,
                                batch_size=batch_size,
                                key='train',
                                num_works=num_workers,
                                time_step=time_step,
                                patch=crop_size,
                                modalities=modalities,
                                model_type='CNN')
        n_train = len(train_set)
        train_loader = train_set.load()

        val_set = data_loader(val_dict,
                              batch_size=batch_size,
                              key='val',
                              num_works=num_workers,
                              time_step=time_step,
                              patch=crop_size,
                              modalities=modalities,
                              model_type='CNN')
        n_val = len(val_set)

        nb_batches = np.ceil(n_train / batch_size)
        n_total = n_train + n_val
        logger.info(
            '{} images will be used in total, {} for trainning and {} for validation'
            .format(n_total, n_train, n_val))
        logger.info('Dataset loading finished!')

        # setup to train mode
        net.train()
        running_loss = 0
        dice_score_label_0 = 0
        dice_score_label_1 = 0
        dice_score_label_2 = 0
        dice_score_label_3 = 0
        dice_score_label_4 = 0

        logger.info('Training epoch {} will begin'.format(epoch + 1))

        with tqdm(total=n_train,
                  desc=f'Epoch {epoch+1}/{epochs}',
                  unit='patch') as pbar:

            for i, data in enumerate(train_loader, 0):
                images, segs = data['image'].to(device), data['seg'].to(device)

                if model_type == 'SkipDenseSeg' and not softmax:
                    segs = segs.long()

                # fold the time axis into the batch axis:
                # (batch, time, channel, z, y, x) -> (batch*time, channel, z, y, x)
                batch, time, channel, z, y, x = images.shape
                images = images.view(-1, channel, z, y, x)
                segs = segs.view(-1, z, y, x)

                # zero the parameter gradients
                optimizer.zero_grad()
                outputs = net(images)

                loss = criterion(outputs, segs)
                loss.backward()
                optimizer.step()

                running_loss += loss.detach().item()
                _, preds = torch.max(outputs.data, 1)
                dice_score = dice(preds.data.cpu(),
                                  segs.data.cpu(),
                                  ignore_idx=ignore_idx)

                dice_score_label_0 += dice_score['bg']
                dice_score_label_1 += dice_score['csf']
                dice_score_label_2 += dice_score['gm']
                dice_score_label_3 += dice_score['wm']
                dice_score_label_4 += dice_score['tm']

                # show progress bar
                pbar.set_postfix(
                    **{
                        'Training loss': loss.detach().item(),
                        'Training accuracy': dice_score['avg']
                    })
                pbar.update(images.shape[0])

                del images, segs

        # Validate once per epoch. This used to be triggered by a cumulative
        # global step counter modulo nb_batches; because n_train (and so
        # nb_batches) changes every epoch, that check could miss an epoch and
        # leave val_loss/val_info undefined or stale.
        net.eval()
        val_loss, val_acc, val_info = validation(
            net,
            val_set,
            criterion,
            device,
            batch_size,
            model_type=model_type,
            softmax=softmax,
            ignore_idx=ignore_idx)

        avg_train_loss = running_loss / nb_batches
        train_info['train_loss'].append(avg_train_loss)
        train_info['val_loss'].append(val_loss)
        train_info['label_0_acc'].append(dice_score_label_0 / nb_batches)
        train_info['label_1_acc'].append(dice_score_label_1 / nb_batches)
        train_info['label_2_acc'].append(dice_score_label_2 / nb_batches)
        train_info['label_3_acc'].append(dice_score_label_3 / nb_batches)
        train_info['label_4_acc'].append(dice_score_label_4 / nb_batches)

        if model_type == 'SkipDenseSeg':
            scheduler.step()
        else:
            scheduler.step(val_loss)
        # debug: current learning rate of every parameter group
        for param_group in optimizer.param_groups:
            logger.info('%0.6f | %6d ' % (param_group['lr'], epoch))

        # best model / early stopping: training loss must improve by > 1e-2
        if min_loss > avg_train_loss + 1e-2:
            min_loss = avg_train_loss
            is_best = True
            early_stop_count = 0
        else:
            is_best = False
            early_stop_count += 1

        # save the check point
        state = {
            'epoch': epoch + 1,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'loss': avg_train_loss,
            'min_loss': min_loss
        }
        verbose = save_ckp(state,
                           is_best,
                           early_stop_count=early_stop_count,
                           early_stop_patience=early_stop_patience,
                           save_model_dir=model_path,
                           best_dir=best_path,
                           name=model_name)

        # summarize the training results of this epoch
        logger.info('Average training loss of this epoch is {}'.format(
            avg_train_loss))
        logger.info('Best training loss till now is {}'.format(min_loss))
        logger.info('Validation dice loss: {}; Validation accuracy: {}'.format(
            val_loss, val_acc))

        # save the training info every epoch
        logger.info('Writing the training info into file ...')
        val_info_file = os.path.join(intermidiate_data_save,
                                     '{}_val_info.json'.format(model_name))
        with open(train_info_file, 'w') as fp:
            json.dump(train_info, fp)
        with open(val_info_file, 'w') as fp:
            json.dump(val_info, fp)

        loss_plot(train_info_file, name=model_name)
        for name, param in net.named_parameters():
            # parameters that received no gradient (frozen or unused) have
            # grad None; calling .cpu() on them used to crash here
            if param.grad is not None:
                writer.add_histogram(name + '_grad',
                                     param.grad.detach().cpu().numpy(), epoch)
            writer.add_histogram(name + '_data',
                                 param.detach().cpu().numpy(), epoch)

        if verbose:
            # early stopping counts epochs without training-loss improvement
            logger.info(
                'The training loss has not improved for {} epochs, training will stop here.'
                .format(early_stop_patience))
            break

    writer.close()
    logger.info('finish training!')
Example #11
0
    pretrained_dynamic = a1[0]
# ---- Data ----
# Alternative with resizing:
# flow_dataset = FlowDataset(transform = transforms.Compose([ToTensor(),Rescale((cnvrt_size,cnvrt_size))]))
flow_dataset = FlowDataset(transform=transforms.Compose([ToTensor()]))

dataloader = DataLoader(flow_dataset, batch_size=batch_size, shuffle=True, num_workers=workers)

# ---- Models and optimizers ----
net_dynamic = createDeepLabv3().to(device)
net_dynamic.apply(weights_init)

net_impainter = Inpainter(ngpu=1).to(device)
# net_impainter.apply(weights_init)
optimizerD = optim.Adam(net_dynamic.parameters(), lr=lr, betas=(beta1, beta2))
optimizerI = optim.Adam(net_impainter.parameters(), lr=lr, betas=(beta1, beta2))

# ---- Optional resume ----
# Checkpoint paths are built by plain string concatenation, so the
# checkpoint_*_path directories presumably must end with a path separator.
# When both checkpoints are given, start_epoch ends up coming from the
# inpainter one (it is assigned last).
# NOTE(review): start_epoch is only assigned here when a checkpoint is
# loaded; confirm it is initialised earlier in the script.
if pretrained_dynamic is not None:
    net_dynamic, optimizerD, start_epoch = load_ckp(checkpoint_dynamic_path + pretrained_dynamic, net_dynamic, optimizerD)
    print("Loaded pretrained: " + pretrained_dynamic)

if pretrained_inpainter is not None:
    net_impainter, optimizerI, start_epoch = load_ckp(checkpoint_inpainter_path + pretrained_inpainter, net_impainter, optimizerI)
    print("Loaded pretrained: " + pretrained_inpainter)

# ---- Losses and bookkeeping ----
loss_l1 = nn.L1Loss()
loss_l2 = nn.MSELoss()

I_losses = []
D_losses = []
iters = 0

print("Starting Training Loop... from" + str(start_epoch))
net_dynamic.train()