Example 1
def train(opt):
    # model ("device" is assumed to be defined at module level, e.g. torch.device("cuda"))
    model = cnet(nb_res=opt.resnet_num, num_classes=opt.num_classes)
    model = model.to(device)

    transform_train = transforms.Compose([
        transforms.Resize((opt.input_size, opt.input_size)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    # the data dictionary is generated by voc.py
    dic_data = torch.load('data.pth')

    train_dataset = CTDataset(opt=opt, data=dic_data['train'], transform=transform_train)
    val_dataset = CTDataset(opt=opt, data=dic_data['val'], transform=transform_train)
    train_dl = DataLoader(dataset=train_dataset, batch_size=opt.batch_size, num_workers=opt.num_workers, shuffle=True)
    val_dl = DataLoader(dataset=val_dataset, batch_size=opt.batch_size, num_workers=opt.num_workers)

    criterion_hm = FocalLoss()
    criterion_wh = RegL1Loss()
    criterion_reg = RegL1Loss()

    optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr)

    min_loss, best_epoch = 1e7, 1
    for epoch in range(1, opt.max_epoch + 1):
        train_loss = train_epoch(epoch, model, train_dl, optimizer, criterion_hm, criterion_wh, criterion_reg)
        val_loss = val_epoch(model, val_dl, criterion_hm, criterion_wh, criterion_reg)
        print("Epoch%02d train_loss:%0.3e val_loss:%0.3e min_loss:%0.3e(%02d)" % (
            epoch, train_loss, val_loss, min_loss, best_epoch))
        if min_loss > val_loss:
            min_loss, best_epoch = val_loss, epoch
            torch.save(model.state_dict(), opt.ckpt)
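The train() function above pulls every hyper-parameter from an opt namespace. A minimal driver sketch for building such a namespace with argparse and launching training; only the field names are taken from the code above, while the default values and the backbone depth are placeholders:

import argparse
import torch

def get_opt():
    # Hypothetical argument parser; only the field names come from train() above.
    parser = argparse.ArgumentParser()
    parser.add_argument("--resnet_num", type=int, default=18, help="backbone depth passed to cnet()")
    parser.add_argument("--num_classes", type=int, default=20, help="number of object classes")
    parser.add_argument("--input_size", type=int, default=512, help="square input resolution")
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument("--lr", type=float, default=1.25e-4)
    parser.add_argument("--max_epoch", type=int, default=100)
    parser.add_argument("--ckpt", type=str, default="best.pth", help="path for the best checkpoint")
    return parser.parse_args()

if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train(get_opt())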
Example 2
def main():
    args = get_args()
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)
    device = args.device

    logger = setup_logger("brdnet", args.save_dir, 0)
    logger.info("Using device {}".format(device))
    logger.info(args)

    val_data = CTDataset(data_path=args.val_path,
                         patch_n=None,
                         patch_size=None)
    val_loader = DataLoader(dataset=val_data,
                            batch_size=1,
                            shuffle=False,
                            num_workers=args.num_workers)
    model = BRDNet()
    model.load_state_dict(torch.load(args.pretrained, map_location=device))
    model.to(device)

    loss_func = nn.MSELoss()
    v_loss = val(val_loader, model, loss_func, 0, args, logger)

    logger.info('done')
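The val helper used above is not part of the snippet. A plausible sketch, assuming the loader yields (noisy, clean) image pairs and the helper returns the average MSE over the validation set; the signature matches the call above, everything else is an assumption:

import torch

def val(val_loader, model, loss_func, epoch, args, logger):
    # Hypothetical validation loop: average the loss over the whole loader.
    model.eval()
    total_loss, n_batches = 0.0, 0
    with torch.no_grad():
        for noisy, clean in val_loader:  # assumed (input, target) pairs
            noisy, clean = noisy.to(args.device), clean.to(args.device)
            total_loss += loss_func(model(noisy), clean).item()
            n_batches += 1
    avg_loss = total_loss / max(n_batches, 1)
    logger.info("epoch {} val_loss {:.6f}".format(epoch, avg_loss))
    model.train()
    return avg_loss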
Example 3
def main():
    args = get_args()
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)
    device = args.device

    logger = setup_logger("brdnet", args.save_dir, 0)
    logger.info("Using device {}".format(device))
    logger.info(args)

    train_data = CTDataset(data_path=args.train_path, patch_n=args.patch_n, patch_size=args.patch_size)
    val_data = CTDataset(data_path=args.val_path, patch_n=None, patch_size=None)
    train_loader = DataLoader(dataset=train_data, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
    val_loader = DataLoader(dataset=val_data, batch_size=1, shuffle=False, num_workers=args.num_workers)
    model = BRDNet()
    if args.pretrained != '':
        model.load_state_dict(torch.load(args.pretrained, map_location=device))
    model.to(device)

    optimizer = optim.Adam(model.parameters(), args.lr)
    lr_scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.3, verbose=True, patience=6, min_lr=1E-7)

    loss_func = nn.MSELoss()

    best_loss = float('inf')
    for epoch in range(1, args.num_epochs + 1):
        logger.info('epoch: {}/{}'.format(epoch, args.num_epochs))
        t_loss = train(train_loader, model, loss_func, optimizer, epoch, args, logger)
        v_loss = val(val_loader, model, loss_func, epoch, args, logger)
        lr_scheduler.step(v_loss)

        if v_loss < best_loss:
            best_loss = v_loss
            torch.save(model.state_dict(), '{}/model_best.pth'.format(args.save_dir))

        if epoch % args.save_interval == 0:
            torch.save(model.state_dict(), "{}/model_checkpoint_{}.pth".format(args.save_dir, epoch))
        #if epoch > 10: break

    logger.info('done')
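Likewise, the train helper called in the epoch loop is defined elsewhere. A minimal sketch of one training epoch under the same assumption that the loader yields (noisy, clean) patch pairs:

def train(train_loader, model, loss_func, optimizer, epoch, args, logger):
    # Hypothetical training epoch: one optimizer step per batch, returns the mean loss.
    model.train()
    total_loss, n_batches = 0.0, 0
    for noisy, clean in train_loader:  # assumed (input, target) patch pairs
        noisy, clean = noisy.to(args.device), clean.to(args.device)
        optimizer.zero_grad()
        loss = loss_func(model(noisy), clean)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        n_batches += 1
    avg_loss = total_loss / max(n_batches, 1)
    logger.info("epoch {} train_loss {:.6f}".format(epoch, avg_loss))
    return avg_loss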
Example 4
def train():
    parser = argparse.ArgumentParser()
    parser.add_argument("--epoch",
                        type=int,
                        default=0,
                        help="epoch to start training from")
    parser.add_argument("--n_epochs",
                        type=int,
                        default=200,
                        help="number of epochs of training")
    parser.add_argument("--dataset_name",
                        type=str,
                        default="leftkidney_3d",
                        help="name of the dataset")
    parser.add_argument("--batch_size",
                        type=int,
                        default=1,
                        help="size of the batches")
    parser.add_argument("--glr",
                        type=float,
                        default=0.0002,
                        help="adam: generator learning rate")
    parser.add_argument("--dlr",
                        type=float,
                        default=0.0002,
                        help="adam: discriminator learning rate")
    parser.add_argument("--b1",
                        type=float,
                        default=0.5,
                        help="adam: decay of first order momentum of gradient")
    parser.add_argument("--b2",
                        type=float,
                        default=0.999,
                        help="adam: decay of first order momentum of gradient")
    parser.add_argument("--decay_epoch",
                        type=int,
                        default=100,
                        help="epoch from which to start lr decay")
    parser.add_argument(
        "--n_cpu",
        type=int,
        default=8,
        help="number of cpu threads to use during batch generation")  #8
    parser.add_argument("--img_height",
                        type=int,
                        default=128,
                        help="size of image height")
    parser.add_argument("--img_width",
                        type=int,
                        default=128,
                        help="size of image width")
    parser.add_argument("--img_depth",
                        type=int,
                        default=128,
                        help="size of image depth")
    parser.add_argument("--channels",
                        type=int,
                        default=1,
                        help="number of image channels")
    parser.add_argument("--disc_update",
                        type=int,
                        default=5,
                        help="only update discriminator every n iter")
    parser.add_argument("--d_threshold",
                        type=int,
                        default=.8,
                        help="discriminator threshold")
    parser.add_argument("--threshold",
                        type=int,
                        default=-1,
                        help="threshold during sampling, -1: No thresholding")
    parser.add_argument(
        "--sample_interval",
        type=int,
        default=1,
        help="interval between sampling of images from generators")
    parser.add_argument("--checkpoint_interval",
                        type=int,
                        default=50,
                        help="interval between model checkpoints")  #-1
    opt = parser.parse_args()
    print(opt)

    os.makedirs("images/%s" % opt.dataset_name, exist_ok=True)
    os.makedirs("saved_models/%s" % opt.dataset_name, exist_ok=True)

    cuda = True if torch.cuda.is_available() else False

    # Loss functions
    criterion_GAN = torch.nn.MSELoss()
    criterion_voxelwise = diceloss()

    # Weight of the voxel-wise (Dice) loss relative to the adversarial loss
    lambda_voxel = 100

    # Calculate output of image discriminator (PatchGAN)
    patch = (1, opt.img_height // 2**4, opt.img_width // 2**4,
             opt.img_depth // 2**4)

    # Initialize generator and discriminator
    generator = GeneratorUNet()
    discriminator = Discriminator()

    if cuda:
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        criterion_GAN.cuda()
        criterion_voxelwise.cuda()

    if opt.epoch != 0:
        # Load pretrained models
        generator.load_state_dict(
            torch.load("saved_models/%s/generator_%d.pth" %
                       (opt.dataset_name, opt.epoch)))
        discriminator.load_state_dict(
            torch.load("saved_models/%s/discriminator_%d.pth" %
                       (opt.dataset_name, opt.epoch)))
    else:
        # Initialize weights
        generator.apply(weights_init_normal)
        discriminator.apply(weights_init_normal)

    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(),
                                   lr=opt.glr,
                                   betas=(opt.b1, opt.b2))
    optimizer_D = torch.optim.Adam(discriminator.parameters(),
                                   lr=opt.dlr,
                                   betas=(opt.b1, opt.b2))

    # Configure dataloaders
    transforms_ = transforms.Compose([
        # transforms.Resize((opt.img_height, opt.img_width, opt.img_depth), Image.BICUBIC),
        transforms.ToTensor(),
        # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    dataloader = DataLoader(
        CTDataset("data/%s/train/" % opt.dataset_name,
                  transforms_=transforms_),
        batch_size=opt.batch_size,
        shuffle=True,
        num_workers=opt.n_cpu,
    )

    val_dataloader = DataLoader(
        CTDataset("data/%s/test/" % opt.dataset_name, transforms_=transforms_),
        batch_size=1,
        shuffle=True,
        num_workers=1,
    )

    # Tensor type
    Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

    def sample_voxel_volumes(epoch):
        """Saves a generated sample from the validation set"""
        imgs = next(iter(val_dataloader))
        real_A = Variable(imgs["A"].unsqueeze_(1).type(Tensor))
        real_B = Variable(imgs["B"].unsqueeze_(1).type(Tensor))
        fake_B = generator(real_A)

        # convert to numpy arrays
        real_A = real_A.cpu().detach().numpy()
        real_B = real_B.cpu().detach().numpy()
        fake_B = fake_B.cpu().detach().numpy()

        image_folder = "images/%s/epoch_%s_" % (opt.dataset_name, epoch)

        # write the volumes out as gzip-compressed HDF5 files
        with h5py.File(image_folder + 'real_A.vox', 'w') as hf:
            hf.create_dataset('data', data=real_A, compression='gzip')

        with h5py.File(image_folder + 'real_B.vox', 'w') as hf:
            hf.create_dataset('data', data=real_B, compression='gzip')

        with h5py.File(image_folder + 'fake_B.vox', 'w') as hf:
            hf.create_dataset('data', data=fake_B, compression='gzip')

    # ----------
    #  Training
    # ----------

    prev_time = time.time()
    discriminator_update = 'False'
    for epoch in range(opt.epoch, opt.n_epochs):
        for i, batch in enumerate(dataloader):

            # Model inputs
            real_A = Variable(batch["A"].unsqueeze_(1).type(Tensor))
            real_B = Variable(batch["B"].unsqueeze_(1).type(Tensor))

            # Adversarial ground truths
            valid = Variable(Tensor(np.ones((real_A.size(0), *patch))),
                             requires_grad=False)
            fake = Variable(Tensor(np.zeros((real_A.size(0), *patch))),
                            requires_grad=False)

            # ---------------------
            #  Train Discriminator, only update every disc_update batches
            # ---------------------
            # Real loss
            fake_B = generator(real_A)
            pred_real = discriminator(real_B, real_A)
            loss_real = criterion_GAN(pred_real, valid)

            # Fake loss
            pred_fake = discriminator(fake_B.detach(), real_A)
            loss_fake = criterion_GAN(pred_fake, fake)
            # Total loss
            loss_D = 0.5 * (loss_real + loss_fake)

            d_real_acu = torch.ge(pred_real.squeeze(), 0.5).float()
            d_fake_acu = torch.le(pred_fake.squeeze(), 0.5).float()
            d_total_acu = torch.mean(torch.cat((d_real_acu, d_fake_acu), 0))

            if d_total_acu <= opt.d_threshold:
                optimizer_D.zero_grad()
                loss_D.backward()
                optimizer_D.step()
                discriminator_update = 'True'

            # ------------------
            #  Train Generators
            # ------------------
            optimizer_D.zero_grad()
            optimizer_G.zero_grad()

            # GAN loss
            fake_B = generator(real_A)
            pred_fake = discriminator(fake_B, real_A)
            loss_GAN = criterion_GAN(pred_fake, valid)
            # Voxel-wise loss
            loss_voxel = criterion_voxelwise(fake_B, real_B)

            # Total loss
            loss_G = loss_GAN + lambda_voxel * loss_voxel

            loss_G.backward()

            optimizer_G.step()

            batches_done = epoch * len(dataloader) + i

            # --------------
            #  Log Progress
            # --------------

            # Determine approximate time left
            batches_left = opt.n_epochs * len(dataloader) - batches_done
            time_left = datetime.timedelta(seconds=batches_left *
                                           (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            sys.stdout.write(
                "\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f, D accuracy: %f, D update: %s] [G loss: %f, voxel: %f, adv: %f] ETA: %s"
                % (
                    epoch,
                    opt.n_epochs,
                    i,
                    len(dataloader),
                    loss_D.item(),
                    d_total_acu,
                    discriminator_update,
                    loss_G.item(),
                    loss_voxel.item(),
                    loss_GAN.item(),
                    time_left,
                ))
            # If at sample interval save image
            if batches_done % (opt.sample_interval * len(dataloader)) == 0:
                sample_voxel_volumes(epoch)
                print('*****volumes sampled*****')

            discriminator_update = 'False'

        if opt.checkpoint_interval != -1 and epoch % opt.checkpoint_interval == 0:
            # Save model checkpoints
            torch.save(
                generator.state_dict(),
                "saved_models/%s/generator_%d.pth" % (opt.dataset_name, epoch))
            torch.save(
                discriminator.state_dict(),
                "saved_models/%s/discriminator_%d.pth" %
                (opt.dataset_name, epoch))
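The training loop above indexes every batch with batch["A"] and batch["B"], so the CTDataset used here presumably returns a dict of paired source/target volumes. A rough sketch of such a dataset; the file layout (one HDF5 file per sample) and everything except the "A"/"B" keys are assumptions:

import glob
import os

import h5py
import numpy as np
from torch.utils.data import Dataset

class CTDataset(Dataset):
    # Hypothetical paired-volume dataset consistent with the batch["A"] / batch["B"] usage above.
    def __init__(self, root, transforms_=None):
        self.files = sorted(glob.glob(os.path.join(root, "*.h5")))
        self.transforms_ = transforms_

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        with h5py.File(self.files[index], "r") as f:
            vol_a = np.asarray(f["A"], dtype=np.float32)  # source volume
            vol_b = np.asarray(f["B"], dtype=np.float32)  # target volume
        if self.transforms_ is not None:
            vol_a = self.transforms_(vol_a)
            vol_b = self.transforms_(vol_b)
        return {"A": vol_a, "B": vol_b}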
Example 5
import numpy as np
import torch
from tqdm import tqdm

from dataset import CTDataset, CTDataLoader
from flexible_model import dice_loss
from utils import save_check_point, load_check_point, save_history, load_history
from visualize import plot_3d

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
checkpoint = './model/Best_model_14_0.0798.pth.tar'
batch_size = 1

# Load model checkpoint that is to be evaluated
model, _, _ = load_check_point(checkpoint)
model.double().to(device)
loss_function = dice_loss
model.eval()
# build data loader
data = CTDataset('./data/DCM_Test.json')
test_loader = CTDataLoader(data, 1, batch_size=batch_size, mode="testing")

# evaluate
result = []
area = []
with torch.no_grad():
    for i, (data, target,
            address) in enumerate(tqdm(test_loader, desc='Evaluating')):
        output = model(data.to(device))
        p, t = np.round(np.array(output.cpu())), np.round(
            np.array(target.cpu()))
        loss = loss_function(output.to(device), target.to(device))
        plot_3d(p.reshape(p.shape[2:]), t.reshape(t.shape[2:]), address,
                loss.item())
        # print(np.sum(p),np.sum(t))
Example 6
import torch
import torch.optim as optim

from dataset import CTDataset, CTDataLoader
from flexible_model import dice_loss, Unet, Flex_Unet
from utils import save_check_point, load_check_point, save_history, load_history

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# parameters
best_loss = 100
early_stop = 5
not_improve = 0
batch_size = 1
down_scale = 3
first_ch = 6

# data loader
data = CTDataset('./data/Train.json')
train_loader, val_loader = CTDataLoader(data,
                                        0.75,
                                        batch_size=batch_size,
                                        mode="training")

# model, loss function, optimizer initialization
model = Flex_Unet(down_scale, first_ch)
optimizer = optim.Adam(model.parameters())

# checkpoint_path = ''
# model, optimizer, start_epoch = load_check_point(checkpoint_path)
print("Number of paras:",
      sum(p.numel() for p in model.parameters() if p.requires_grad))
model.double().to(device)
loss_function = dice_loss
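The snippet above stops after initialization; the training loop itself is not shown. A sketch of how the declared best_loss / early_stop / not_improve counters could drive early stopping; the epoch count, the (data, target) batch structure, and the checkpoint path are assumptions (the imported save_check_point helper could be used in place of torch.save):

num_epochs = 100  # placeholder

for epoch in range(num_epochs):
    model.train()
    for data, target in train_loader:  # assumed (volume, mask) pairs
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        loss = loss_function(model(data), target)
        loss.backward()
        optimizer.step()

    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for data, target in val_loader:
            val_loss += loss_function(model(data.to(device)), target.to(device)).item()
    val_loss /= max(len(val_loader), 1)

    if val_loss < best_loss:
        best_loss, not_improve = val_loss, 0
        torch.save(model.state_dict(), './model/best_model.pth')  # placeholder path
    else:
        not_improve += 1
        if not_improve >= early_stop:
            break  # stop once the validation loss has not improved for `early_stop` epochs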
Example 7
def test():
    parser = argparse.ArgumentParser()
    parser.add_argument("--epoch",
                        type=int,
                        default=200,
                        help="epoch to start training from")
    parser.add_argument("--n_epochs",
                        type=int,
                        default=200,
                        help="number of epochs of training")
    parser.add_argument("--dataset_name",
                        type=str,
                        default="leftkidney_3d",
                        help="name of the dataset")
    parser.add_argument("--batch_size",
                        type=int,
                        default=1,
                        help="size of the batches")
    parser.add_argument("--glr",
                        type=float,
                        default=0.0002,
                        help="adam: generator learning rate")
    parser.add_argument("--dlr",
                        type=float,
                        default=0.0002,
                        help="adam: discriminator learning rate")
    parser.add_argument("--b1",
                        type=float,
                        default=0.5,
                        help="adam: decay of first order momentum of gradient")
    parser.add_argument("--b2",
                        type=float,
                        default=0.999,
                        help="adam: decay of first order momentum of gradient")
    parser.add_argument("--decay_epoch",
                        type=int,
                        default=100,
                        help="epoch from which to start lr decay")
    parser.add_argument(
        "--n_cpu",
        type=int,
        default=8,
        help="number of cpu threads to use during batch generation")  #8
    parser.add_argument("--img_height",
                        type=int,
                        default=128,
                        help="size of image height")
    parser.add_argument("--img_width",
                        type=int,
                        default=128,
                        help="size of image width")
    parser.add_argument("--img_depth",
                        type=int,
                        default=128,
                        help="size of image depth")
    parser.add_argument("--channels",
                        type=int,
                        default=1,
                        help="number of image channels")
    parser.add_argument("--disc_update",
                        type=int,
                        default=5,
                        help="only update discriminator every n iter")
    parser.add_argument("--d_threshold",
                        type=int,
                        default=.8,
                        help="discriminator threshold")
    parser.add_argument("--threshold",
                        type=int,
                        default=-1,
                        help="threshold during sampling, -1: No thresholding")
    parser.add_argument(
        "--sample_interval",
        type=int,
        default=1,
        help="interval between sampling of images from generators")
    parser.add_argument("--checkpoint_interval",
                        type=int,
                        default=50,
                        help="interval between model checkpoints")  #-1
    opt = parser.parse_args()
    print(opt)

    #os.makedirs("output/%s" % opt.dataset_name, exist_ok=True)
    #os.makedirs("saved_models/%s" % opt.dataset_name, exist_ok=True)

    cuda = True if torch.cuda.is_available() else False

    # Loss functions
    criterion_GAN = torch.nn.MSELoss()
    criterion_voxelwise = diceloss()

    # Weight of the voxel-wise (Dice) loss relative to the adversarial loss
    lambda_voxel = 100

    # Calculate output of image discriminator (PatchGAN)
    patch = (1, opt.img_height // 2**4, opt.img_width // 2**4,
             opt.img_depth // 2**4)

    # Initialize generator and discriminator
    generator = GeneratorUNet()
    discriminator = Discriminator()

    if cuda:
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        criterion_GAN.cuda()
        criterion_voxelwise.cuda()

    #if opt.epoch != 0:
    # Load pretrained models
    #generator.load_state_dict(torch.load("saved_models/%s/generator_%d.pth" % (opt.dataset_name, opt.epoch)))
    #discriminator.load_state_dict(torch.load("saved_models/%s/discriminator_%d.pth" % (opt.dataset_name, opt.epoch)))
    generator.load_state_dict(
        torch.load("model/generator_" + str(opt.epoch) + ".pth"))
    discriminator.load_state_dict(
        torch.load("model/discriminator_" + str(opt.epoch) + ".pth"))
    #else:
    # Initialize weights
    #generator.apply(weights_init_normal)
    #discriminator.apply(weights_init_normal)

    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(),
                                   lr=opt.glr,
                                   betas=(opt.b1, opt.b2))
    optimizer_D = torch.optim.Adam(discriminator.parameters(),
                                   lr=opt.dlr,
                                   betas=(opt.b1, opt.b2))

    # Configure dataloaders
    transforms_ = transforms.Compose([
        # transforms.Resize((opt.img_height, opt.img_width, opt.img_depth), Image.BICUBIC),
        transforms.ToTensor(),
        # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    '''
    dataloader = DataLoader(
        CTDataset("../../data/%s/train/" % opt.dataset_name, transforms_=transforms_),
        batch_size=opt.batch_size,
        shuffle=True,
        num_workers=opt.n_cpu,
    )
    '''

    val_dataloader = DataLoader(
        CTDataset("input/", transforms_=transforms_, isTest=True),
        batch_size=1,
        shuffle=False,
        num_workers=0,
    )

    # Tensor type
    Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

    def write_binvox(data, path):
        data = np.rint(data).astype(np.uint8)
        dims = (opt.img_width, opt.img_height, opt.img_depth)  #data.shape
        translate = [0, 0, 0]
        scale = 1.0
        axis_order = 'xzy'
        v = binvox_rw.Voxels(data, dims, translate, scale, axis_order)

        with open(path, 'bw') as f:
            v.write(f)

    dataiter = iter(val_dataloader)

    def sample_voxel_volumes(index):
        """Saves a generated sample from the validation set."""
        imgs = next(dataiter)
        real_A = Variable(imgs["A"].unsqueeze_(1).type(Tensor))
        #real_B = Variable(imgs["B"].unsqueeze_(1).type(Tensor))
        fake_B = generator(real_A)

        # convert to numpy arrays
        real_A = real_A.cpu().detach().numpy()
        #real_B = real_B.cpu().detach().numpy()
        fake_B = fake_B.cpu().detach().numpy()

        image_folder = "output"  #/%s_%s_" % (opt.dataset_name, index)

        #write_binvox(real_A, image_folder + 'real_A.binvox')

        write_binvox(
            fake_B,
            os.path.join(image_folder,
                         os.path.basename(imgs["url"][0]) + "_fake.binvox"))

    for i, batch in enumerate(val_dataloader):
        sample_voxel_volumes(i)
        print('*****volume ' + str(i + 1) + '/' + str(len(val_dataloader)) +
              ' sampled*****')