示例#1
0
文件: train.py 项目: TianyiYe0322/_1
def train():
    """Build a 1-in/1-out U-Net and train it on the membrane dataset.

    Uses BCE loss with an Adam optimizer; device, batch size and the
    input/target transforms come from module-level globals.
    """
    net = Unet(1, 1).to(device)
    dataset = LiverDataset("./data/membrane/train/image",
                           "./data/membrane/train/label",
                           transform=x_transforms,
                           target_transform=y_transforms)
    loader = DataLoader(dataset,
                        batch_size=BATCH_SIZE,
                        shuffle=True,
                        num_workers=0)
    train_model(net, torch.nn.BCELoss(), optim.Adam(net.parameters()), loader)
示例#2
0
文件: train.py 项目: TianyiYe0322/_1
def test():
    """Run the U-Net over the membrane test images and display each
    predicted mask interactively with matplotlib."""
    net = Unet(1, 1)
    # NOTE: checkpoint loading was disabled in the original code:
    # net.load_state_dict(torch.load(args.ckp, map_location='cpu'))
    dataset = LiverDataset("./data/membrane/test/image",
                           "./data/membrane/test/predict",
                           transform=x_transforms,
                           target_transform=y_transforms)
    loader = DataLoader(dataset, batch_size=1)
    net.eval()
    import matplotlib.pyplot as plt
    plt.ion()  # interactive mode so imshow/pause animate
    with torch.no_grad():
        for inputs, _ in loader:
            pred = net(inputs)
            mask = torch.squeeze(pred).numpy()
            plt.imshow(mask)
            plt.pause(0.01)
示例#3
0
def get_resnet34(input_shape=(768, 768, 3),
                 loss=None,
                 n_class=1,
                 optimizer=None,
                 use_loss_weights=False):
    """Build and compile a U-Net with a ResNet-34 encoder.

    Parameters
    ----------
    input_shape : tuple
        Model input shape (H, W, C).
    loss : callable, optional
        Training loss; defaults to ``dice_coef_loss()``.
    n_class : int
        Number of output classes (currently unused by the builder; kept
        for interface compatibility).
    optimizer : optimizer instance, optional
        Defaults to ``Adam(lr=1e-4, decay=0.0)``.
    use_loss_weights : bool
        Forwarded to each metric factory.

    Returns
    -------
    The compiled model.
    """
    # BUG FIX: the original signature used `loss=dice_coef_loss()` and
    # `optimizer=Adam(...)` as defaults. Those are evaluated ONCE at
    # definition time, so every call without an explicit optimizer shared
    # the same stateful Adam instance (slot variables leaking between
    # models). Create the defaults lazily instead.
    if loss is None:
        loss = dice_coef_loss()
    if optimizer is None:
        optimizer = Adam(lr=1e-4, decay=0.0)

    from unet_model import Unet
    model = Unet(backbone_name='resnet34',
                 input_shape=input_shape,
                 encoder_weights='imagenet')

    model.compile(optimizer=optimizer,
                  loss=loss,
                  metrics=[
                      focal_loss(use_loss_weights=use_loss_weights,
                                 ignore_loss=True),
                      bce(use_loss_weights=use_loss_weights),
                      dice_coef_loss(use_loss_weights=use_loss_weights,
                                     ignore_loss=True),
                      dice_coef(use_loss_weights=use_loss_weights,
                                ignore_loss=True)
                  ])

    return model
示例#4
0
def train(args):
    """Train a 1-channel-in / 1-channel-out U-Net on the liver dataset.

    Loss is BCE-with-logits, optimizer is Adam with default settings;
    several alternative models/optimizers from the original experiments
    were left commented out and are omitted here.
    """
    net = Unet(1, 1).to(device)
    optimizer = optim.Adam(net.parameters())
    criterion = nn.BCEWithLogitsLoss()
    dataset = LiverDataset(
        "/home/cvlab04/Desktop/Code/Medical/u_net_liver/data/train/",
        transform=x_transforms,
        target_transform=y_transforms)
    loader = DataLoader(dataset,
                        batch_size=args.batch_size,
                        shuffle=True,
                        num_workers=6)
    train_model(net, criterion, optimizer, loader)
示例#5
0
def check(args):
    """Load a U-Net checkpoint, segment one fixed image, binarize the
    sigmoid output at 0.3, save the mask as a PNG and display it."""
    net = Unet(1, 1)
    net.load_state_dict(torch.load(args.ckpt, map_location='cpu'))
    net.eval()
    import PIL.Image as Image
    img = Image.open(
        '/home/cvlab04/Desktop/Code/Medical/u_net_liver/check/train/A001-2_instance-47.jpeg'
    )
    img = x_transforms(img)
    # Add batch and channel dimensions expected by the network.
    img = img.view(1, 1, 512, 512)
    import matplotlib.pyplot as plt
    plt.ion()
    with torch.no_grad():
        pred = net(img)
        pred = torch.sigmoid(pred)
        pred = pred.squeeze(0)
        print(pred)
        # NOTE(review): this transform is built but never applied in the
        # original code; kept for parity.
        tf = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(512),
            transforms.ToTensor()
        ])

        mask = pred.squeeze().cpu().numpy()
        print(mask)
        # 0.3 / 0.01 for unet; 3e-4 for r2attunet; 3e-1 for unet-trans
        mask = (mask > 0.3).astype(np.uint8)
        print(mask)
        out = Image.fromarray((mask * 255).astype(np.uint8))
        out.save(
            "/home/cvlab04/Desktop/Code/Medical/u_net_liver/check/result/Threshold03_U_Net_transpose_25_epoch_A001-2_instance-47.png"
        )

        plt.imshow(mask, plt.cm.gray)
        plt.pause(2)
        plt.show()
示例#6
0
def test(args):
    """Evaluate a U-Net checkpoint on the training split, printing a Dice
    score per batch and the mean over all batches.

    Returns
    -------
    float
        Mean Dice score across all batches.
    """
    model = Unet(1, 1)
    model.load_state_dict(torch.load(args.ckpt, map_location='cpu'))
    liver_dataset = LiverDataset(
        "/home/cvlab04/Desktop/Code/Medical/u_net_liver/data/train/",
        transform=x_transforms,
        target_transform=y_transforms)
    dataloaders = DataLoader(liver_dataset, batch_size=6)
    model.eval()
    count = 0
    count_sum = 0.
    with torch.no_grad():
        for x, labels in dataloaders:
            count += 1
            print("batch:", count)
            y = model(x)
            # NOTE(review): the raw network output is thresholded directly;
            # assumes the model's last layer already emits probabilities.
            img_y = torch.squeeze(y).numpy()
            img_y = (img_y > 0.3).astype(np.uint8)
            img_y = img_y.flatten()
            count_predict = np.count_nonzero(img_y > 0)
            true = torch.squeeze(labels).numpy()
            true = true.flatten()
            count_true = np.count_nonzero(true > 0)
            # Intersection of predicted and true foreground pixels.
            ans = np.count_nonzero(img_y * true > 0)
            # Dice coefficient (historically named "dice_loss" here); the
            # epsilon guards against both masks being empty.
            dice_loss = (2 * ans + 0.0001) / (count_predict + count_true +
                                              0.0001)
            print("dice_loss:", dice_loss)

            count_sum += (dice_loss)

        mean_dice = count_sum / count
        print("Final_Dice_Loss:", mean_dice)
        # BUG FIX: the original ended with `return sigmoid_out`, a name
        # never defined in this function, which raised NameError after all
        # batches were processed. Return the mean Dice score instead.
        return mean_dice

###############################################################################

# Real/fake target labels for the GAN discriminators; placed on the GPU
# when CUDA mode is enabled.
ones_label = Variable(torch.ones(BATCH_SIZE).cuda() if is_gpu_mode
                      else torch.ones(BATCH_SIZE))
zeros_label = Variable(torch.zeros(BATCH_SIZE).cuda() if is_gpu_mode
                       else torch.zeros(BATCH_SIZE))

if __name__ == "__main__":
    # BUG FIX: `print 'main'` is Python-2-only syntax and a SyntaxError
    # under Python 3; the call form behaves identically on both.
    print('main')

    # Two generators and two discriminators (CycleGAN-style A<->B setup).
    gen_model_a = Unet()
    gen_model_b = Unet()
    disc_model_a = Discriminator()
    disc_model_b = Discriminator()

    if is_gpu_mode:
        gen_model_a.cuda()
        gen_model_b.cuda()
        disc_model_a.cuda()
        disc_model_b.cuda()
        # gen_model = torch.nn.DataParallel(gen_model).cuda()
        # disc_model = torch.nn.DataParallel(disc_model).cuda()

    if ENABLE_TRANSFER_LEARNING:
        # load the saved checkpoints for hair semantic segmentation
        gen_model_a.load_state_dict(torch.load('/home1/irteamsu/rklee/TheIllusionsLibraries/PyTorch-practice/tiramisu-fcdensenet103/models/tiramisu_lfw_added_zero_centr_lr_0_0002_iter_1870000.pt'))
                "./checkpoint/Unet_epoch{}_loss{:.4f}_retina.model".format(
                    str(epoch + 1).zfill(5), epoch_avg_loss))
    return net


if __name__ == "__main__":

    # Make sure the checkpoint and dataset directories exist (same
    # creation order as before so nested paths are valid).
    for dir_path in ("./checkpoint", "./datasets", "./datasets/training"):
        if not os.path.exists(dir_path):
            os.mkdir(dir_path)

    # set parameters for training
    LR = 0.1
    EPOCHS = 500
    BATCH_SIZE = 8
    SAVE_EVERY = 20
    EVAL_EVERY = 30

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    unet_ins = Unet(img_ch=3, isDeconv=True, isBN=True)
    unet_ins.to(device)
    trained_unet = model_train(unet_ins,
                               batch_size=BATCH_SIZE,
                               lr=LR,
                               epochs=EPOCHS,
                               save_every=SAVE_EVERY,
                               eval_every=EVAL_EVERY)
            start_id = ite * batch_size
            bat_img = torch.Tensor(x_tensor[start_id : , :, :, :])
            bat_label = torch.Tensor(y_tensor[start_id : , 0: 1, :, :])
            #bat_mask_2ch = torch.Tensor(m_tensor[start_id : end_id, :, :, :])
            bat_mask = torch.Tensor(m_tensor[start_id : , 0: 1, :, :])
        bat_pred = net(bat_img)
        bat_pred_class = (bat_pred > 0.5).float() * bat_mask
        eval_print_metrics(bat_label, bat_pred, bat_mask)
        # plt.imshow(bat_pred[0,0,:,:].detach().numpy(), cmap='jet')#, vmin=0, vmax=1)
        # plt.colorbar()
        # plt.show()
        #bat_pred_class = bat_pred.detach() * bat_mask
        paste_and_save(bat_img, bat_label, bat_pred_class, batch_size, ite + 1)

    return


if __name__ == "__main__":
    # Ensure output directories exist before writing predictions.
    if not os.path.exists("./pred_imgs"):
        os.mkdir("./pred_imgs")
    if not os.path.exists("./datasets/test"):
        os.mkdir("./datasets/test")

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # BUG FIX: glob() returns paths in unspecified order, so taking [-1]
    # did not reliably pick the latest checkpoint. Epoch numbers in the
    # filenames are zero-padded (zfill), so a lexicographic sort selects
    # the highest epoch deterministically.
    selected_model = sorted(glob("./checkpoint/Unet_epoch*.model"))[-1]
    print("[*] Selected model for testing: {} ".format(selected_model))
    unet_ins = Unet(img_ch=3, isDeconv=True, isBN=True)
    unet_ins.load_state_dict(torch.load(selected_model))
    unet_ins.to(device)
    model_test(unet_ins, batch_size=2)
    
示例#10
0
def train_model(args, train_loader, val_loader, epochs, device, tol):
    """Adversarially train a U-Net reconstructor (R) against a
    perturbation generator (G).

    Per batch: G is updated first to craft an additive input perturbation
    that hurts the reconstructor's SSIM, subject to an L2 hinge budget;
    then R is updated on both clean and perturbed inputs. After each
    epoch the model is validated; training stops early once the
    validation loss has not decreased for `tol` consecutive epochs.

    Args:
        args: namespace providing learning_rate, alpha_1, alpha_2,
            epsilon and batch_size.
        train_loader: iterable of (input, target) sample pairs; a channel
            dimension is added here via unsqueeze(1).
        val_loader: same format, used for validation.
        epochs: maximum number of epochs.
        device: torch device to place the models on.
        tol: early-stopping patience, in epochs.

    Returns:
        Tuple (tr_ssims, tr_psnrs, tr_nmses, val_ssims, val_psnrs,
        val_nmses, tr_losses_G, tr_losses_R): per-epoch metric lists plus
        per-step G/R loss histories stored as Python floats.
    """
    # Number of channels in the training images. For color images this is 3
    nc = 1
    # Size of z latent vector (i.e. size of generator input)
    nz = 100
    # Size of feature maps in generator
    ngf = 64
    # create model, optimizer and criterion
    # R: the U-Net reconstructor
    model_R = Unet(in_chans=1,
                   out_chans=1,
                   chans=64,
                   num_pool_layers=5,
                   drop_prob=0,
                  )
    optimizer_R = torch.optim.AdamW(model_R.parameters(), lr=args.learning_rate)
    # G: the perturbation generator
    model_G = Generator(nc, nz, ngf)
    optimizer_G = torch.optim.AdamW(model_G.parameters(), lr=1e-4)
    # use multiple GPUs
    if torch.cuda.device_count() >= 1:
        print("Let's use", torch.cuda.device_count(), "GPUs")
        model_R = nn.DataParallel(model_R)
        model_G = nn.DataParallel(model_G)
    model_R.to(device)
    model_G.to(device)
    # set objects for storing metrics
    tr_losses_R = []
    tr_losses_G = []
    val_losses = []
    tr_ssims = []
    val_ssims = []
    tr_psnrs = []
    val_psnrs = []
    tr_nmses = []
    val_nmses = []
    alpha_1 = args.alpha_1
    alpha_2 = args.alpha_2
    # `tol` is the early-stopping patience (the original had a no-op
    # `tol = tol` self-assignment here, now removed).
    # Train model
    for epoch in range(1, epochs+1):
        # training
        loop = tqdm(enumerate(train_loader), total=len(train_loader), leave=False)
        print(f'Epoch {epoch}:')
        print('Train:')
        tr_loss, nb_tr_steps, ssim_val, tot_ssim, tot_psnr, tot_nmse, tot_loss_G, tot_loss_R = \
            0, 0, 0, 0, 0, 0, 0, 0
        model_R.train()
        model_G.train()
        for batch_idx, sample in loop:
            data, target = sample[0].unsqueeze(1).to(device), sample[1].unsqueeze(1).to(device)
            ### Generator
            # Create batch of latent vectors that we will use to visualize
            # the progression of the generator
            optimizer_G.zero_grad()
            latent_noise = torch.randn(data.shape[0], nz, 1, 1, device=device) * 1e-3
            perturbation = model_G(latent_noise)
            output_p = model_R(data + perturbation)
            # calculate regularization hinge loss
            epsilon = args.epsilon * args.batch_size
            # hinge loss: penalize perturbations whose squared L2 norm
            # exceeds the budget epsilon
            hinge_loss = torch.clamp(torch.norm(perturbation)**2 - epsilon, min=0)
            # get best perturbation (G maximizes reconstruction damage)
            loss_G = -(ssim_criterion(output_p, target) * alpha_1 + hinge_loss * alpha_2)
            loss_G.backward()
            optimizer_G.step()
            ### End
            ### Reconstructor
            optimizer_R.zero_grad()
            output = model_R(data)
            perturbation = model_G(latent_noise)
            output_p = model_R(data + perturbation.detach())
            # calculate regular loss
            loss_1 = ssim_criterion(output, target)
            # calculate loss with perturbation
            loss_2 = ssim_criterion(output_p, target)
            loss_R = loss_1 + loss_2 * alpha_1
            # calculate gradients
            loss_R.backward()
            # optimizer/scheduler "step"
            optimizer_R.step()
            ### End
            pred = output
            # calculate performance metrics
            ssim_tr = ssim(pred.squeeze(1).detach().cpu().numpy(), target.squeeze(1).detach().cpu().numpy())
            psnr_tr = psnr(pred.squeeze(1).detach().cpu().numpy(), target.squeeze(1).detach().cpu().numpy())
            nmse_tr = nmse(pred.squeeze(1).detach().cpu().numpy(), target.squeeze(1).detach().cpu().numpy())
            tot_ssim += ssim_tr
            tot_psnr += psnr_tr
            tot_nmse += nmse_tr
            # BUG FIX: store scalar copies. Appending the live loss
            # tensors retained each step's autograd graph (and GPU
            # memory) for the entire training run.
            tr_losses_G.append(loss_G.item())
            tr_losses_R.append(loss_R.item())
            nb_tr_steps += 1
            # update progress bar
            loop.set_description(f'Epoch [{epoch}/{epochs}]')
            loop.set_postfix(loss_G = loss_G.item(), loss_R = loss_R.item(), ssim = ssim_tr, pert = \
                             (torch.norm(perturbation)**2).item() / args.batch_size, psnr = psnr_tr, nmse = nmse_tr)
        tr_ssim = tot_ssim / nb_tr_steps
        tr_psnr = tot_psnr / nb_tr_steps
        tr_nmse = tot_nmse / nb_tr_steps
        tr_ssims.append(tr_ssim)
        tr_psnrs.append(tr_psnr)
        tr_nmses.append(tr_nmse)
        print(f'Train SSIM: {tr_ssim}')
        print(f'Train PSNR: {tr_psnr}')
        print(f'Train NMSE: {tr_nmse}')
        # validation (clean inputs only; G is not evaluated here)
        model_R.eval()
        val_loss, nb_val_steps, ssim_val, tot_ssim, tot_psnr, tot_nmse = 0, 0, 0, 0, 0, 0
        print('Validation:')
        with torch.no_grad():
            for sample in val_loader:
                data, target = sample[0].unsqueeze(1).to(device), sample[1].unsqueeze(1).to(device)
                output = model_R(data)
                loss = ssim_criterion(output, target)
                pred = output
                # calculate performance metrics
                ssim_val = ssim(pred.squeeze(1).detach().cpu().numpy(), target.squeeze(1).detach().cpu().numpy())
                psnr_val = psnr(pred.squeeze(1).detach().cpu().numpy(), target.squeeze(1).detach().cpu().numpy())
                nmse_val = nmse(pred.squeeze(1).detach().cpu().numpy(), target.squeeze(1).detach().cpu().numpy())
                tot_ssim += ssim_val
                tot_psnr += psnr_val
                tot_nmse += nmse_val
                val_loss += loss.item()
                nb_val_steps += 1
        val_ssim = tot_ssim / nb_val_steps
        val_psnr = tot_psnr / nb_val_steps
        val_nmse = tot_nmse / nb_val_steps
        val_loss = val_loss / nb_val_steps
        val_losses.append(val_loss)
        val_ssims.append(val_ssim)
        val_nmses.append(val_nmse)
        val_psnrs.append(val_psnr)
        print(f'Validation SSIM: {val_ssim}')
        print(f'Validation PSNR: {val_psnr}')
        print(f'Validation NMSE: {val_nmse}')
        print(f'Validation Loss: {val_loss}')
        # check validation loss history for early stopping
        if len(val_losses) > tol:
            losses_diff_hist = []
            tracked_loss = val_losses[len(val_losses)-tol-1]
            # get last 'tol' tolerance index and calculate loss difference history
            for i in range(1, tol+1):
                losses_diff_hist.append(val_losses[len(val_losses)-i] - tracked_loss)
            print(losses_diff_hist)
            # if all histories are larger than or equal previous tracked loss, stop training
            # larger than 0 means the losses are not decreasing
            if sum([loss_diff >= 0 for loss_diff in losses_diff_hist]) == tol:
                print(sum([loss_diff >= 0 for loss_diff in losses_diff_hist]))
                break
    # save model
    torch.save(model_R.state_dict(), "unet_model_R.pt")
    torch.save(model_G.state_dict(), "unet_model_G.pt")
    return tr_ssims, tr_psnrs, tr_nmses, val_ssims, val_psnrs, val_nmses, tr_losses_G, tr_losses_R
# MODEL_SAVING_DIRECTORY = '/home/illusion/PycharmProjects/TheIllusionsLibraries/PyTorch-practice/GANs/models/'
# RESULT_IMAGE_DIRECTORY = '/home/illusion/PycharmProjects/TheIllusionsLibraries/PyTorch-practice/GANs/generate_imgs_simple_cyclegan/'

###################################################################################

# Real/fake target labels for the GAN discriminators; placed on the GPU
# when CUDA mode is enabled.
ones_label = Variable(torch.ones(BATCH_SIZE).cuda() if is_gpu_mode
                      else torch.ones(BATCH_SIZE))
zeros_label = Variable(torch.zeros(BATCH_SIZE).cuda() if is_gpu_mode
                       else torch.zeros(BATCH_SIZE))

if __name__ == "__main__":
    # BUG FIX: `print 'main'` is Python-2-only syntax and a SyntaxError
    # under Python 3; the call form behaves identically on both.
    print('main')

    # Two generators and two discriminators (CycleGAN-style A<->B setup).
    gen_model_a = Unet()
    gen_model_b = Unet()
    disc_model_a = Discriminator()
    disc_model_b = Discriminator()

    if is_gpu_mode:
        gen_model_a.cuda()
        gen_model_b.cuda()
        disc_model_a.cuda()
        disc_model_b.cuda()
        # DataParallel wrapping kept for reference (was a no-op
        # triple-quoted string in the original):
        # gen_model_a = torch.nn.DataParallel(gen_model_a).cuda()
        # gen_model_b = torch.nn.DataParallel(gen_model_b).cuda()
        # disc_model_a = torch.nn.DataParallel(disc_model_a).cuda()
        # disc_model_b = torch.nn.DataParallel(disc_model_b).cuda()