Example no. 1
def main():
    """
    input: outp1 concat with 4 modalities;
    target: difference between outp1 and the GT
    :return:
    """
    # init or load model
    print("init model with input shape", config["input_shape"])
    model = AttentionVNet(config=config)
    parameters = model.parameters()
    optimizer = optim.Adam(parameters, 
                           lr=config["initial_learning_rate"],
                           weight_decay=config["L2_norm"])
    start_epoch = 1
    if config["VAE_enable"]:
        loss_function = CombinedLoss(combine=config["combine"],
                                     k1=config["loss_k1_weight"], k2=config["loss_k2_weight"])
    else:
        loss_function = SoftDiceLoss(combine=config["combine"])

    with open('valid_list.txt', 'r') as f:
        val_list = f.read().splitlines()
    with open('train_list.txt', 'r') as f:
        tr_list = f.read().splitlines()

    config["training_patients"] = tr_list
    config["validation_patients"] = val_list

    preprocessor = stage2net_preprocessor(config, patch_size=patch_size)

    # data_generator
    print("data generating")
    training_data = PatchDataset(phase="train", config=config, preprocessor=preprocessor)
    validation_data = PatchDataset(phase="validate", config=config, preprocessor=preprocessor)
    train_logger = Logger(model_name=config["model_name"] + '.h5',
                          header=['epoch', 'loss', 'wt-dice', 'tc-dice', 'et-dice', 'lr'])

    if not config["overwrite"] and config["saved_model_file"] is not None:
        if not os.path.exists(config["saved_model_file"]):
            raise Exception("Invalid model path!")
        model, start_epoch, optimizer_resume = load_old_model(model, optimizer, saved_model_path=config["saved_model_file"])
        parameters = model.parameters()
        optimizer = optim.Adam(parameters,
                               lr=optimizer_resume.param_groups[0]["lr"],
                               weight_decay=optimizer_resume.param_groups[0]["weight_decay"])

    if config["cuda_devices"] is not None:
        model = model.cuda()
        model = nn.DataParallel(model)    # multi-gpu training
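        # move any optimizer state tensors onto the GPU so they match the model
        # parameters (a no-op while the optimizer state is still empty)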
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.cuda()

    scheduler = lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=poly_lr_scheduler_multi)
    # scheduler = lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=poly_lr_scheduler)

    max_val_TC_dice = 0.
    max_val_ET_dice = 0.
    max_val_AVG_dice = 0.
    for i in range(start_epoch, config["epochs"]):
        train_epoch(epoch=i, 
                    data_set=training_data, 
                    model=model,
                    criterion=loss_function,
                    optimizer=optimizer,
                    opt=config, 
                    logger=train_logger) 
        
    val_loss, WT_dice, TC_dice, ET_dice = val_epoch(
        epoch=i,
        data_set=validation_data,
        model=model,
        criterion=loss_function,
        opt=config,
        optimizer=optimizer,
        logger=train_logger)

        scheduler.step()
        dices = np.array([WT_dice, TC_dice, ET_dice])
        AVG_dice = dices.mean()
        save_flag = False
        if config["checkpoint"] and TC_dice > max_val_TC_dice:
            max_val_TC_dice = TC_dice
            save_flag = True
        if config["checkpoint"] and ET_dice > max_val_ET_dice:
            max_val_ET_dice = ET_dice
            save_flag = True
        if config["checkpoint"] and AVG_dice > max_val_AVG_dice:
            max_val_AVG_dice = AVG_dice
            save_flag = True

        if save_flag:
            save_dir = config["result_path"]
            if not os.path.exists(save_dir):
                os.makedirs(save_dir)
            save_states_path = os.path.join(
                save_dir,
                'epoch_{0}_val_loss_{1:.4f}_TC_{2:.4f}_ET_{3:.4f}_AVG_{4:.4f}.pth'.format(
                    i, val_loss, TC_dice, ET_dice, AVG_dice))
            if config["cuda_devices"] is not None:
                state_dict = model.module.state_dict()
            else:
                state_dict = model.state_dict()
            states = {
                'epoch': i,
                'state_dict': state_dict,
                'optimizer': optimizer.state_dict(),
            }
            torch.save(states, save_states_path)
            save_model_path = os.path.join(save_dir, "best_model.pth")
            if os.path.exists(save_model_path):
                os.system("rm "+ save_model_path)
            torch.save(model, save_model_path)
        print("batch {0:d} finished, validation loss:{1:.4f}; TC:{2:.4f}, ET:{3:.4f}, AVG:{4:.4f}".format(i, val_loss,
                                                                                            TC_dice, ET_dice, AVG_dice))
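
The LambdaLR scheduler above multiplies the initial learning rate by whatever poly_lr_scheduler_multi(epoch) returns; that function is not part of this listing. A minimal sketch of a polynomial-decay lambda in the same spirit (the exponent and the total epoch count are assumptions, not the project's actual values):

def poly_lr_lambda(epoch, max_epochs=300, power=0.9):
    # LambdaLR expects a multiplier for the base LR, so the value should start
    # near 1.0 and decay towards 0.0 as training progresses
    return (1.0 - float(epoch) / max_epochs) ** power

# scheduler = lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=poly_lr_lambda)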

Example no. 2

def main():
    # init or load model
    print("init model with input shape", config["input_shape"])
    if config["attention"]:
        model = AttentionVNet(config=config)
    else:
        model = NvNet(config=config)
    parameters = model.parameters()
    optimizer = optim.Adam(parameters,
                           lr=config["initial_learning_rate"],
                           weight_decay=config["L2_norm"])
    start_epoch = 1
    if config["VAE_enable"]:
        loss_function = CombinedLoss(new_loss=config["new_SoftDiceLoss"],
                                     k1=config["loss_k1_weight"],
                                     k2=config["loss_k2_weight"],
                                     alpha=config["focal_alpha"],
                                     gamma=config["focal_gamma"],
                                     focal_enable=config["focal_enable"])
    else:
        loss_function = SoftDiceLoss(new_loss=config["new_SoftDiceLoss"])

    with open('valid_list_v2.txt', 'r') as f:
        val_list = f.read().splitlines()
    # with open('train_list.txt', 'r') as f:
    with open('train_list_v2.txt', 'r') as f:
        tr_list = f.read().splitlines()

    config["training_patients"] = tr_list
    config["validation_patients"] = val_list
    # data_generator
    print("data generating")
    training_data = BratsDataset(phase="train", config=config)
    # x = training_data[0] # for test
    validation_data = BratsDataset(phase="validate", config=config)
    train_logger = Logger(
        model_name=config["model_name"] + '.h5',
        header=['epoch', 'loss', 'wt-dice', 'tc-dice', 'et-dice', 'lr'])

    if not config["overwrite"] and config["saved_model_file"] is not None:
        if not os.path.exists(config["saved_model_file"]):
            raise Exception("Invalid model path!")
        model, start_epoch, optimizer_resume = load_old_model(
            model, optimizer, saved_model_path=config["saved_model_file"])
        parameters = model.parameters()
        optimizer = optim.Adam(
            parameters,
            lr=optimizer_resume.param_groups[0]["lr"],
            weight_decay=optimizer_resume.param_groups[0]["weight_decay"])

    if config["cuda_devices"] is not None:
        model = model.cuda()
        loss_function = loss_function.cuda()
        model = nn.DataParallel(model)  # multi-gpu training
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.cuda()

    # scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=config["lr_decay"], patience=config["patience"])
    scheduler = lr_scheduler.LambdaLR(
        optimizer=optimizer,
        lr_lambda=poly_lr_scheduler)  # can't restore lr correctly
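    # note: LambdaLR recomputes the LR from its internal epoch counter, so when
    # resuming from a checkpoint the multiplier restarts from epoch 0 unless the
    # scheduler state (or last_epoch) is restored as well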

    max_val_WT_dice = 0.
    max_val_AVG_dice = 0.
    for i in range(start_epoch, config["epochs"]):
        train_epoch(epoch=i,
                    data_set=training_data,
                    model=model,
                    criterion=loss_function,
                    optimizer=optimizer,
                    opt=config,
                    logger=train_logger)

        val_loss, WT_dice, TC_dice, ET_dice = val_epoch(
            epoch=i,
            data_set=validation_data,
            model=model,
            criterion=loss_function,
            opt=config,
            optimizer=optimizer,
            logger=train_logger)

        scheduler.step()
        # scheduler.step(val_loss)
        dices = np.array([WT_dice, TC_dice, ET_dice])
        AVG_dice = dices.mean()
        if config["checkpoint"] and (WT_dice > max_val_WT_dice
                                     or AVG_dice > max_val_AVG_dice
                                     or WT_dice >= 0.912):
            max_val_WT_dice = WT_dice
            max_val_AVG_dice = AVG_dice
            # save_dir = os.path.join(config["result_path"], config["model_file"].split("/")[-1].split(".h5")[0])
            save_dir = config["result_path"]
            if not os.path.exists(save_dir):
                os.makedirs(save_dir)
            save_states_path = os.path.join(
                save_dir,
                'epoch_{0}_val_loss_{1:.4f}_WTdice_{2:.4f}_AVGdice_{3:.4f}.pth'
                .format(i, val_loss, WT_dice, AVG_dice))
            if config["cuda_devices"] is not None:
                state_dict = model.module.state_dict()
            else:
                state_dict = model.state_dict()
            states = {
                'epoch': i,
                'state_dict': state_dict,
                'optimizer': optimizer.state_dict(),
            }
            torch.save(states, save_states_path)
            save_model_path = os.path.join(save_dir, "best_model.pth")
            if os.path.exists(save_model_path):
                os.system("rm " + save_model_path)
            torch.save(model, save_model_path)
        print(
            "epoch {0:d} finished, validation loss: {1:.4f}; WT dice: {2:.4f}; AVG dice: {3:.4f}"
            .format(i, val_loss, WT_dice, AVG_dice))
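
Both main() variants above resume training through load_old_model, which is defined elsewhere in the repository. A minimal sketch of what it could look like, given the checkpoint dictionaries ('epoch', 'state_dict', 'optimizer') that these loops write; the real helper may differ in detail:

import torch

def load_old_model(model, optimizer, saved_model_path):
    # restore weights, optimizer state and the epoch counter from a
    # checkpoint produced by the training loops above (sketch only)
    checkpoint = torch.load(saved_model_path, map_location="cpu")
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    start_epoch = checkpoint["epoch"] + 1
    return model, start_epoch, optimizer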
Example no. 3
def main():
    # Print config
    print(config)

    # init or load model
    model = NvNet(config=config)

    optimizer = optim.Adam(model.parameters(),
                           lr=config["initial_learning_rate"],
                           weight_decay=config["L2_norm"])
    start_epoch = 1

    if config["VAE_enable"]:
        loss_function = CombinedLoss(k1=config["loss_k1_weight"], k2=config["loss_k2_weight"])
    else:
        loss_function = SoftDiceLoss()
    # data_generator
    print("Loading BraTS dataset...")
    training_data = BratsDataset(phase="train", config=config)
    validation_data = BratsDataset(phase="validate", config=config)

    train_logger = Logger(model_name=config["model_file"], header=['epoch', 'loss', 'acc', 'lr'])

    if config["cuda_devices"] is not None:
        #gpu_list = list(range(0, 2))
        #model = nn.DataParallel(model, gpu_list)  # multi-gpu training
        model = model.cuda()
        loss_function = loss_function.cuda() 
#         model = model.to(device=device)  # move the model parameters to CPU/GPU
#         loss_function = loss_function.to(device=device)
        
    if not config["overwrite"] and config["saved_model_file"] is not None:
        if not os.path.exists(config["saved_model_file"]):
            raise Exception("Invalid model path!")
        model, start_epoch, optimizer = load_old_model(model, optimizer, saved_model_path=config["saved_model_file"])    
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=config["lr_decay"], patience=config["patience"])
    
    #model = torch.load("checkpoint_models/run0/best_model_file_24.pth")
    
    print("training on label:{}".format(config["labels"]))
    max_val_acc = 0.
    for i in range(start_epoch, config["epochs"]):
        train_epoch(epoch=i, 
                    data_set=training_data, 
                    model=model,
                    criterion=loss_function, 
                    optimizer=optimizer, 
                    opt=config, 
                    logger=train_logger) 
        
        val_loss, val_acc = val_epoch(
            epoch=i,
            data_set=validation_data,
            model=model,
            criterion=loss_function,
            opt=config,
            optimizer=optimizer,
            logger=train_logger)
        scheduler.step(val_loss)
        if config["checkpoint"] and val_acc >= max_val_acc - 0.10:
            max_val_acc = val_acc
            save_dir = os.path.join(config["result_path"], config["model_file"].split("/")[-1].split(".h5")[0])
            if not os.path.exists(save_dir):
                os.makedirs(save_dir)
            save_states_path = os.path.join(save_dir,'epoch_{0}_val_loss_{1:.4f}_acc_{2:.4f}.pth'.format(i, val_loss, val_acc))
            states = {
                'epoch': i + 1,
                # 'state_dict': model.state_dict(),
                'encoder': model.encoder.state_dict(),
                'decoder': model.decoder.state_dict(),
                'vae': model.vae.state_dict(),
                'optimizer': optimizer.state_dict(),
            }
            torch.save(states, save_states_path)
            save_model_path = os.path.join(save_dir, "best_model_file_{0}.pth".format(i))
            if os.path.exists(save_model_path):
                os.system("rm "+save_model_path)
            # torch.save(model, save_model_path)
            torch.save(states, save_model_path)
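
Unlike the other examples, this loop saves the encoder, decoder and VAE branch as separate state dicts rather than a single model.state_dict(). Restoring such a checkpoint therefore means loading each piece back into the matching sub-module. A minimal sketch, assuming NvNet exposes encoder, decoder and vae attributes as the save code above implies (the checkpoint path is a placeholder):

import torch

checkpoint = torch.load("path/to/epoch_checkpoint.pth", map_location="cpu")
model = NvNet(config=config)
model.encoder.load_state_dict(checkpoint["encoder"])
model.decoder.load_state_dict(checkpoint["decoder"])
model.vae.load_state_dict(checkpoint["vae"])
start_epoch = checkpoint["epoch"]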
Example no. 4
def main():
    # convert input images into an hdf5 file
    if config["overwrite"] or not os.path.exists(config["data_file"]):
        training_files, subject_ids = fetch_training_data_files(return_subject_ids=True)
        write_data_to_file(training_files, config["data_file"], image_shape=config["image_shape"],
                           subject_ids=subject_ids)

    # init or load model
    print("init model with input shape",config["input_shape"])
    model = NvNet(config=config)
    parameters = model.parameters()
    optimizer = optim.Adam(parameters,
                           lr=config["initial_learning_rate"],
                           weight_decay=config["L2_norm"])
    start_epoch = 1
    if config["VAE_enable"]:
        loss_function = CombinedLoss(k1=config["loss_k1_weight"], k2=config["loss_k2_weight"])
    else:
        loss_function = SoftDiceLoss()
    # data_generator
    print("data generating")
    training_data = BratsDataset(phase="train", config=config)
    validation_data = BratsDataset(phase="validate", config=config)

    
    train_logger = Logger(model_name=config["model_file"], header=['epoch', 'loss', 'acc', 'lr'])

    if config["cuda_devices"] is not None:
        # model = nn.DataParallel(model)  # multi-gpu training
        model = model.cuda()
        loss_function = loss_function.cuda()
        
    if not config["overwrite"] and config["saved_model_file"] is not None:
        if not os.path.exists(config["saved_model_file"]):
            raise Exception("Invalid model path!")
        model, start_epoch, optimizer = load_old_model(model, optimizer, saved_model_path=config["saved_model_file"])    
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=config["lr_decay"], patience=config["patience"])
    
    print("training on label:{}".format(config["labels"]))
    max_val_acc = 0.
    for i in range(start_epoch, config["epochs"]):
        train_epoch(epoch=i, 
                    data_set=training_data, 
                    model=model,
                    criterion=loss_function, 
                    optimizer=optimizer, 
                    opt=config, 
                    logger=train_logger) 
        
        val_loss, val_acc = val_epoch(
            epoch=i,
            data_set=validation_data,
            model=model,
            criterion=loss_function,
            opt=config,
            optimizer=optimizer,
            logger=train_logger)
        scheduler.step(val_loss)
        if config["checkpoint"] and val_acc > max_val_acc:
            max_val_acc = val_acc
            save_dir = os.path.join(config["result_path"], config["model_file"].split("/")[-1].split(".h5")[0])
            if not os.path.exists(save_dir):
                os.makedirs(save_dir)
            save_states_path = os.path.join(save_dir,'epoch_{0}_val_loss_{1:.4f}_acc_{2:.4f}.pth'.format(i, val_loss, val_acc))
            states = {
                'epoch': i + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }
            torch.save(states, save_states_path)
            save_model_path = os.path.join(save_dir, "best_model_file.pth")
            if os.path.exists(save_model_path):
                os.system("rm "+save_model_path)
            torch.save(model, save_model_path)
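
All four training scripts depend on SoftDiceLoss (optionally wrapped with a VAE reconstruction term in CombinedLoss), which is not included in this listing. A minimal, illustrative sketch of a soft Dice loss over sigmoid outputs; the project's actual SoftDiceLoss (and its combine / new_loss variants) will differ in detail:

import torch
import torch.nn as nn

class SoftDiceLossSketch(nn.Module):
    """Illustrative soft Dice loss; not the project's SoftDiceLoss."""
    def __init__(self, smooth=1e-5):
        super().__init__()
        self.smooth = smooth

    def forward(self, logits, targets):
        probs = torch.sigmoid(logits)
        dims = tuple(range(1, probs.dim()))   # sum over every dim except the batch dim
        intersection = (probs * targets).sum(dim=dims)
        denom = probs.sum(dim=dims) + targets.sum(dim=dims)
        dice = (2.0 * intersection + self.smooth) / (denom + self.smooth)
        return 1.0 - dice.mean()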