# Example no. 1
def main():
    """Build the NvNet model, loss, optimizer and data loaders, then run the
    train/validate loop, reducing the LR on validation-loss plateaus.

    Reads all hyper-parameters from the module-level ``config`` dict.
    """
    # Model and optimizer setup.
    print("init model with input shape", config["input_shape"])
    model = NvNet(config=config,
                  input_shape=config["input_shape"],
                  seg_outChans=config["n_labels"])
    optimizer = optim.Adam(model.parameters(),
                           lr=config["initial_learning_rate"],
                           weight_decay=config["L2_norm"])
    start_epoch = 1

    # Combined (dice + VAE) loss when the VAE branch is enabled.
    if config["VAE_enable"]:
        loss_function = CombinedLoss(k1=config["loss_k1_weight"],
                                     k2=config["loss_k2_weight"])
    else:
        loss_function = SoftDiceLoss()

    # Data loaders for both phases (both shuffled, as in the original setup).
    print("data generating")

    def _make_loader(phase):
        # One BratsDataset + DataLoader per phase.
        dataset = BratsDataset(phase=phase, config=config)
        return torch.utils.data.DataLoader(dataset=dataset,
                                           batch_size=config["batch_size"],
                                           shuffle=True,
                                           pin_memory=True)

    train_loader = _make_loader("train")
    validation_loader = _make_loader("validate")

    train_logger = Logger(model_name=config["model_file"],
                          header=['epoch', 'loss', 'acc', 'lr'])

    if config["cuda_devices"] is not None:
        model = model.cuda()
        loss_function = loss_function.cuda()

    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'min',
                                               factor=config["lr_decay"],
                                               patience=config["patience"])

    print("training on label:{}".format(config["labels"]))
    for epoch in range(start_epoch, config["epochs"]):
        train_epoch(epoch=epoch,
                    data_loader=train_loader,
                    model=model,
                    model_name=config["model_file"],
                    criterion=loss_function,
                    optimizer=optimizer,
                    opt=config,
                    epoch_logger=train_logger)

        val_loss = val_epoch(epoch=epoch,
                             data_loader=validation_loader,
                             model=model,
                             criterion=loss_function,
                             opt=config,
                             optimizer=optimizer,
                             logger=train_logger)
        scheduler.step(val_loss)
# Example no. 2
def init_model_from_states(config):
    """Construct an NvNet and restore weights from config['saved_model_path'].

    When config['load_from_data_parallel'] is set, remaps checkpoint keys to
    match a DataParallel-wrapped model ("module."-prefixed parameter names)
    and drops every VAE-branch weight while doing so.

    Returns the model with weights loaded (on GPU if cuda_devices is set).
    """
    print("Init model...")
    model = NvNet(config=config)
    if config["cuda_devices"] is not None:
        # NOTE(review): num_gpu is not defined in this function — presumably a
        # module-level global; confirm it is set before this is called.
        if num_gpu > 0:
            model = torch.nn.DataParallel(model)  # multi-gpu inference
        model = model.cuda()
    # Checkpoint tensors are mapped to CPU; the model itself was already
    # moved to GPU above, so load_state_dict copies across devices.
    checkpoint = torch.load(config['saved_model_path'], map_location='cpu')
    state_dict = checkpoint["state_dict"]
    if not config["load_from_data_parallel"]:
        model.load_state_dict(state_dict)
    else:
        from collections import OrderedDict  # Load state_dict from checkpoint model trained by multi-gpu
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            if not "vae" in k:  # disable the vae path
                # Keys already carrying the DataParallel "module." prefix are
                # kept as-is; bare keys get the prefix added so they match the
                # wrapped model's parameter names.
                if "module." in k:
                    new_state_dict[k] = v
                # name = k[7:]
                else:
                    name = "module." + k  # fix the bug of missing keywords caused by data parallel
                    new_state_dict[name] = v

        model.load_state_dict(new_state_dict)

    return model
# Example no. 3
def init_model_from_states(config):
    """Build an NvNet and load weights from config['best_model_file'].

    When config['load_from_data_parallel'] is set, the checkpoint was written
    by a DataParallel-wrapped model, so the leading "module." (7 characters)
    is stripped from every parameter name before loading.

    Returns the model with weights loaded (on GPU if cuda_devices is set).
    """
    print("Init model...")
    model = NvNet(config=config)

    if config["cuda_devices"] is not None:
        # model = torch.nn.DataParallel(model)   # multi-gpu training
        model = model.cuda()

    checkpoint = torch.load(config["best_model_file"])
    state_dict = checkpoint["state_dict"]

    if not config["load_from_data_parallel"]:
        model.load_state_dict(state_dict)
    else:
        # Drop the "module." prefix (first 7 chars) from every key.
        from collections import OrderedDict
        remapped = OrderedDict(
            (key[7:], weight) for key, weight in state_dict.items())
        model.load_state_dict(remapped)

    return model
# Example no. 4
def main():
    """End-to-end training entry point.

    Converts raw images to an HDF5 file if needed, builds the NvNet model,
    loss and optimizer (optionally resuming from a saved checkpoint), then
    trains with per-epoch validation, LR reduction on validation-loss
    plateaus, and checkpointing on each new best validation accuracy.

    All settings come from the module-level ``config`` dict.
    """
    # Convert input images into an HDF5 file (skipped when it already exists
    # and overwrite is off).
    if config["overwrite"] or not os.path.exists(config["data_file"]):
        training_files, subject_ids = fetch_training_data_files(return_subject_ids=True)
        write_data_to_file(training_files, config["data_file"],
                           image_shape=config["image_shape"],
                           subject_ids=subject_ids)

    # Init model and optimizer.
    print("init model with input shape", config["input_shape"])
    model = NvNet(config=config)
    optimizer = optim.Adam(model.parameters(),
                           lr=config["initial_learning_rate"],
                           weight_decay=config["L2_norm"])
    start_epoch = 1

    # Combined (dice + VAE) loss when the VAE branch is enabled.
    if config["VAE_enable"]:
        loss_function = CombinedLoss(k1=config["loss_k1_weight"],
                                     k2=config["loss_k2_weight"])
    else:
        loss_function = SoftDiceLoss()

    # Datasets for both phases.
    print("data generating")
    training_data = BratsDataset(phase="train", config=config)
    validation_data = BratsDataset(phase="validate", config=config)

    train_logger = Logger(model_name=config["model_file"],
                          header=['epoch', 'loss', 'acc', 'lr'])

    if config["cuda_devices"] is not None:
        # model = nn.DataParallel(model)  # multi-gpu training
        model = model.cuda()
        loss_function = loss_function.cuda()

    # Resume from a previous run when requested.
    if not config["overwrite"] and config["saved_model_file"] is not None:
        if not os.path.exists(config["saved_model_file"]):
            raise Exception("Invalid model path!")
        model, start_epoch, optimizer = load_old_model(
            model, optimizer, saved_model_path=config["saved_model_file"])
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'min',
                                               factor=config["lr_decay"],
                                               patience=config["patience"])

    print("training on label:{}".format(config["labels"]))
    max_val_acc = 0.
    for i in range(start_epoch, config["epochs"]):
        train_epoch(epoch=i,
                    data_set=training_data,
                    model=model,
                    criterion=loss_function,
                    optimizer=optimizer,
                    opt=config,
                    logger=train_logger)

        val_loss, val_acc = val_epoch(epoch=i,
                                      data_set=validation_data,
                                      model=model,
                                      criterion=loss_function,
                                      opt=config,
                                      optimizer=optimizer,
                                      logger=train_logger)
        scheduler.step(val_loss)

        # Checkpoint only on a new best validation accuracy.
        if config["checkpoint"] and val_acc > max_val_acc:
            max_val_acc = val_acc
            save_dir = os.path.join(
                config["result_path"],
                config["model_file"].split("/")[-1].split(".h5")[0])
            os.makedirs(save_dir, exist_ok=True)
            save_states_path = os.path.join(
                save_dir,
                'epoch_{0}_val_loss_{1:.4f}_acc_{2:.4f}.pth'.format(
                    i, val_loss, val_acc))
            states = {
                'epoch': i + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }
            torch.save(states, save_states_path)
            save_model_path = os.path.join(save_dir, "best_model_file.pth")
            # FIX: use os.remove instead of os.system("rm " + path) — the
            # shell call is non-portable and breaks on paths with spaces or
            # shell metacharacters.
            if os.path.exists(save_model_path):
                os.remove(save_model_path)
            torch.save(model, save_model_path)
# Example no. 5
def main():
    """Training entry point for the Stanford dataset variant.

    Builds the NvNet model, loss and optimizer (optionally resuming from a
    saved checkpoint), then trains with per-epoch validation, LR reduction on
    validation-loss plateaus, and checkpointing whenever the validation
    accuracy is within 0.10 of the best seen so far.

    All settings come from the module-level ``config`` dict.
    """
    # Init model and optimizer.
    print("init model with input shape", config["input_shape"])
    model = NvNet(config=config)
    optimizer = optim.Adam(model.parameters(),
                           lr=config["initial_learning_rate"],
                           weight_decay=config["L2_norm"])
    start_epoch = 1

    # Combined (dice + VAE) loss when the VAE branch is enabled.
    if config["VAE_enable"]:
        loss_function = CombinedLoss(k1=config["loss_k1_weight"],
                                     k2=config["loss_k2_weight"])
    else:
        loss_function = SoftDiceLoss()

    # Datasets for both phases.
    print("data generating")
    training_data = StanfordDataset(phase="train", config=config)
    validation_data = StanfordDataset(phase="validate", config=config)

    train_logger = Logger(model_name=config["model_file"],
                          header=['epoch', 'loss', 'acc', 'lr'])

    if config["cuda_devices"] is not None:
        model = model.cuda()
        loss_function = loss_function.cuda()

    # Resume from a previous run when requested.
    if not config["overwrite"] and config["saved_model_file"] is not None:
        if not os.path.exists(config["saved_model_file"]):
            raise Exception("Invalid model path!")
        model, start_epoch, optimizer = load_old_model(
            model, optimizer, saved_model_path=config["saved_model_file"])
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                               'min',
                                               factor=config["lr_decay"],
                                               patience=config["patience"])

    print("training on label:{}".format(config["labels"]))
    max_val_acc = 0.
    for i in range(start_epoch, config["epochs"]):
        train_epoch(epoch=i,
                    data_set=training_data,
                    model=model,
                    criterion=loss_function,
                    optimizer=optimizer,
                    opt=config,
                    logger=train_logger)

        val_loss, val_acc = val_epoch(epoch=i,
                                      data_set=validation_data,
                                      model=model,
                                      criterion=loss_function,
                                      opt=config,
                                      optimizer=optimizer,
                                      logger=train_logger)
        scheduler.step(val_loss)

        # Deliberately loose checkpoint condition: save whenever accuracy is
        # within 0.10 of the running best, not only on strict improvement.
        if config["checkpoint"] and val_acc >= max_val_acc - 0.10:
            max_val_acc = val_acc
            save_dir = os.path.join(
                config["result_path"],
                config["model_file"].split("/")[-1].split(".h5")[0])
            os.makedirs(save_dir, exist_ok=True)
            save_states_path = os.path.join(
                save_dir, 'epoch_{0}_val_loss_{1:.4f}_acc_{2:.4f}.pth'.format(
                    i, val_loss, val_acc))
            states = {
                'epoch': i + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }
            torch.save(states, save_states_path)
            save_model_path = os.path.join(save_dir,
                                           "best_model_file_{0}.pth".format(i))
            # FIX: use os.remove instead of os.system("rm " + path) — the
            # shell call is non-portable and breaks on paths with spaces or
            # shell metacharacters.
            if os.path.exists(save_model_path):
                os.remove(save_model_path)
            torch.save(model, save_model_path)