コード例 #1
0
ファイル: main.py プロジェクト: manuelminca/ISMI-Frikandel
def main():
    """Restore a trained 3D U-Net from a checkpoint and run it on one sample.

    Opens the train/validation patch files, wraps them in data loaders,
    builds the model, optimizer and class-weighted cross-entropy loss, then
    asks ``UNetTrainer.load_checkpoint`` for a trainer and forwards the first
    validation item through it. Training itself is not started here
    (previously: ``UNetTrainer(...).train()``).
    """
    # Small HDF5 files used for testing. An earlier variant split the groups
    # of a single "50_patches_dataset.h5" file with
    # train_test_split(..., test_size=0.2, shuffle=False); shuffling was kept
    # off so that a resumed run never trains on the former validation set.
    data_dir = "Data/"
    train_path = data_dir + "patches_dataset_test.h5"
    val_path = data_dir + "val250.h5"

    # Each dataset reports its length right after it is opened.
    train_dataset = PatchDataset(train_path, n_classes=3)
    print(len(train_dataset))
    val_dataset = PatchDataset(val_path, n_classes=3)
    print(len(val_dataset))

    # Both loaders share the same settings.
    loader_kwargs = dict(batch_size=2, shuffle=False, num_workers=0)
    loaders = {
        'train': DataLoader(train_dataset, **loader_kwargs),
        'val': DataLoader(val_dataset, **loader_kwargs),
    }

    # Model and training parameters.
    model = Modified3DUNet(in_channels=1, n_classes=3)
    optimizer = optim.Adam(model.parameters())
    max_epochs = 10

    # Class-weight rationale (median share of voxels per patch, in percent):
    #   foreground (classes 1 + 2) = 0.2
    #   cancer     (class 2)       = 0.01
    #   pancreas   (class 1)       = 0.2 - 0.01 = 0.19
    #   background (class 0)       = 100 - 0.2 = 99.8
    # Inverse-frequency weights would therefore be about [1, 525, 9980];
    # the milder [1, 100, 500] is what is actually used below.
    # Alternatives tried before: GeneralizedDiceLoss and
    # WeightedCrossEntropyLoss with weight=torch.tensor([1., 525., 9980.]).
    per_class_weight = torch.FloatTensor([1, 100, 500])
    loss_criterion = nn.CrossEntropyLoss(weight=per_class_weight)

    # Load from last epoch.
    # NOTE(review): the checkpoint name contains "1_10_50" while the weights
    # above are [1, 100, 500] - confirm they belong to the same run.
    checkpoint_trainer = UNetTrainer.load_checkpoint(
        "WCEL_1_10_50_last_model",
        model,
        optimizer,
        loaders,
        max_epochs,
        loss_criterion=loss_criterion,
    )

    # First element of the first validation item (presumably the image
    # patch - confirm against PatchDataset.__getitem__).
    first_input = val_dataset[0][0]
    pred = checkpoint_trainer.single_image_forward(first_input)
コード例 #2
0
def main():
    """Train a ``Modified3DUNet`` on the cell dataset, checkpointing each epoch.

    Configuration comes from the free names ``args`` (gpu, seed, criterion,
    batch_size, workers, opt_params, start_epoch, epochs) and ``ckpts``
    (checkpoint directory), both defined outside this function.

    After every epoch the current state is written to
    ``<ckpts>/model_last.tar``. Whenever the epoch's training loss is the
    lowest seen so far in this run, the same state is also written to
    ``<ckpts>/cell/model_best.tar``.
    """
    # setup environments and seeds
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    # setup networks
    # The network and optimizer are hard-coded here instead of being looked
    # up from args.net / args.opt.
    model = Modified3DUNet(in_channels=1, n_classes=2, base_n_filter=16)
    model = model.cuda()
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=0.001,
                                 weight_decay=0.0001)

    criterion = getattr(criterions, args.criterion)
    msg = '-------------- New training session -----------------'
    msg += '\n' + str(args)
    logging.info(msg)

    # Scale the per-device settings by the number of GPUs listed in args.gpu.
    num_gpus = len(args.gpu.split(','))
    args.batch_size *= num_gpus
    args.workers *= num_gpus
    # NOTE(review): the optimizer above was already built with a fixed lr, so
    # this scaling only has an effect if adjust_learning_rate (or other code)
    # reads args.opt_params['lr'] - confirm.
    args.opt_params['lr'] *= num_gpus

    # create dataloaders
    dset = cell_training('/home/tom/Modified-3D-UNet-Pytorch/PNAS/')
    train_loader = DataLoader(dset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.workers,
                              pin_memory=True)

    file_name_best = os.path.join(ckpts, 'cell/model_best.tar')
    # The best checkpoint lives in a sub-directory that may not exist yet.
    os.makedirs(os.path.dirname(file_name_best), exist_ok=True)
    # Lowest training loss seen in this run. It is not restored when resuming
    # from args.start_epoch > 0, so a resumed run starts tracking afresh.
    best_loss = float('inf')

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)
        # train for one epoch
        train_loss = train(train_loader, model, criterion, optimizer, epoch)

        # Always save the latest state.
        ckpt = {
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optim_dict': optimizer.state_dict(),
            'train_loss': train_loss,
        }
        file_name = os.path.join(ckpts, 'model_last.tar')
        torch.save(ckpt, file_name)

        # Remember the best loss and keep a copy of that checkpoint. There is
        # no validation pass here, so training loss is the only criterion.
        if train_loss < best_loss:
            best_loss = train_loss
            torch.save(ckpt, file_name_best)

        msg = 'Epoch: {:02d} Train loss {:.4f}'.format(epoch + 1, train_loss)
        logging.info(msg)