Example No. 1
        loss_D_A_self.backward()

        optimizer_D_A_self.step()

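        # map images from the network's [-1, 1] range back to [0, 1] for display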
        real_A = real_A * 0.5 + 0.5
        real_B = real_B * 0.5 + 0.5
        fake_B = fake_B * 0.5 + 0.5

        if i % 10 == 0:
            # logger.log({'loss_G': loss_G, 'loss_G_GAN': (loss_GAN_A2B ),\
            #             'loss_style':loss_style,'loss_content':(content_loss_A), 'loss_D_self':(loss_D_A_self),
            #              'loss_D': (loss_D_A )},
            #             images={'real_A': real_A, 'real_B': real_B,  'fake_B': fake_B},
            #            heatmaps={'heatmap': diff})
            VIS.img(name="real_A", img_=real_A)
            VIS.img(name="real_B", img_=real_B)
            VIS.img(name="fake", img_=fake_B)

    # Save model checkpoints and update the early-stopping counter

    if loss_GAN_A2B < loss_best:
        loss_best = loss_GAN_A2B
        early_stop = 0
    elif loss_GAN_A2B > loss_best:
        early_stop += 1

    if early_stop >= opt.threshold:
        break

    torch.save(netG_A2B.state_dict(),
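The example breaks off inside the torch.save call. A minimal sketch of how such a generator checkpoint save typically finishes, with a purely hypothetical output path:

    torch.save(netG_A2B.state_dict(), 'output/netG_A2B.pth')  # path is hypothetical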
Example No. 2
def vessel_main():
    SHAPE = (448, 448)
    # ROOT = 'dataset/RIM-ONE/'
    ROOT = './dataset/DRIVE'
    NAME = 'log01_dink34-UNet' + ROOT.split('/')[-1]
    BATCHSIZE_PER_CARD = 8

    # net = UNet(n_channels=3, n_classes=2)

    viz = Visualizer(env="Vessel_Unet_from_scratch")

    solver = MyFrame(UNet, dice_bce_loss, 2e-4)
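    # total batch size scales with the number of visible GPUs (it would be 0 on a CPU-only machine)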
    batchsize = torch.cuda.device_count() * BATCHSIZE_PER_CARD

    dataset = ImageFolder(root_path=ROOT, datasets='DRIVE')
    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=batchsize,
                                              shuffle=True,
                                              num_workers=4)

    mylog = open('logs/' + NAME + '.log', 'w')
    tic = time()
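    # early-stopping bookkeeping: epochs without improvement and the best epoch loss seen so far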
    no_optim = 0
    total_epoch = 300
    train_epoch_best_loss = 10000.
    for epoch in range(1, total_epoch + 1):
        data_loader_iter = iter(data_loader)
        train_epoch_loss = 0

        index = 0

        for img, mask in data_loader_iter:
            solver.set_input(img, mask)

            train_loss, pred = solver.optimize()

            train_epoch_loss += train_loss

            index = index + 1

            # if index % 10 == 0:
            #     # train_epoch_loss /= index
            #     # viz.plot(name='loss', y=train_epoch_loss)
            #     show_image = (img + 1.6) / 3.2 * 255.
            #     viz.img(name='images', img_=show_image[0, :, :, :])
            #     viz.img(name='labels', img_=mask[0, :, :, :])
            #     viz.img(name='prediction', img_=pred[0, :, :, :])

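        # reverse the dataset normalization (roughly [-1.6, 1.6] back to [0, 255]) before sending to Visdom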
        show_image = (img + 1.6) / 3.2 * 255.
        viz.img(name='images', img_=show_image[0, :, :, :])
        viz.img(name='labels', img_=mask[0, :, :, :])
        viz.img(name='prediction', img_=pred[0, :, :, :])

        train_epoch_loss = train_epoch_loss / len(data_loader_iter)
        print('********', file=mylog)
        print('epoch:', epoch, '    time:', int(time() - tic), file=mylog)
        print('train_loss:', train_epoch_loss, file=mylog)
        print('SHAPE:', SHAPE, file=mylog)
        print('********')
        print('epoch:', epoch, '    time:', int(time() - tic))
        print('train_loss:', train_epoch_loss)
        print('SHAPE:', SHAPE)

        if train_epoch_loss >= train_epoch_best_loss:
            no_optim += 1
        else:
            no_optim = 0
            train_epoch_best_loss = train_epoch_loss
            solver.save('./weights/' + NAME + '.th')
        if no_optim > 20:
            print('early stop at %d epoch' % epoch, file=mylog)
            print('early stop at %d epoch' % epoch)
            break
        if no_optim > 15:
            if solver.old_lr < 5e-7:
                break
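            # reload the best checkpoint and lower the learning rate
            # (update_lr with factor=True presumably divides the current lr by 2.0)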
            solver.load('./weights/' + NAME + '.th')
            solver.update_lr(2.0, factor=True, mylog=mylog)
        mylog.flush()

    print('Finish!', file=mylog)
    print('Finish!')
    mylog.close()
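A minimal entry point for this training loop, assuming the names used above (UNet, MyFrame, dice_bce_loss, ImageFolder, Visualizer, time) are already imported:

if __name__ == '__main__':
    vessel_main()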
Example No. 3
def CE_Net_Train():
    NAME = 'CE-Net' + Constants.ROOT.split('/')[-1]

    # run the Visdom
    viz = Visualizer(env=NAME)

    solver = MyFrame(CE_Net_, dice_bce_loss, 2e-4)
    print("count", Constants.BATCHSIZE_PER_CARD)
    batchsize = torch.cuda.device_count() * Constants.BATCHSIZE_PER_CARD
    print("batchsize", batchsize)

    # For different 2D medical image segmentation tasks, specify the dataset you are using here.
    # For example, pass datasets='DRIVE' for retinal vessel detection.

    dataset = ImageFolder(root_path=Constants.ROOT, datasets='Cell')
    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=batchsize,
                                              shuffle=True,
                                              num_workers=4)

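    # a separate loader over a held-out validation set, evaluated once per epoch further below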
    dataset_val = ImageFolder(root_path='./test_data/DRIVE_dot_dash_training',
                              datasets='Cell')
    data_loader_val = torch.utils.data.DataLoader(dataset_val,
                                                  batch_size=8,
                                                  shuffle=True,
                                                  num_workers=4)

    # start the logging files
    mylog = open('logs/' + NAME + '.log', 'w')
    tic = time()

    no_optim = 0
    total_epoch = Constants.TOTAL_EPOCH
    train_epoch_best_loss = Constants.INITAL_EPOCH_LOSS
    for epoch in range(1, total_epoch + 1):
        data_loader_iter = iter(data_loader)
        train_epoch_loss = 0
        index = 0

        for img, mask in data_loader_iter:
            # solver.load('./weights/' + NAME + '.th')
            # print("iterating the dataloader")
            solver.set_input(img, mask)
            train_loss, pred = solver.optimize()
            train_epoch_loss += train_loss
            index = index + 1

        # show the original image, prediction and ground truth on Visdom.
        show_image = (img + 1.6) / 3.2 * 255.
        viz.img(name='images', img_=show_image[0, :, :, :])
        viz.img(name='labels', img_=mask[0, :, :, :])
        viz.img(name='prediction', img_=pred[0, :, :, :])

        torchvision.utils.save_image(img[0, :, :, :],
                                     "images/image_" + str(epoch) + ".jpg",
                                     nrow=1,
                                     padding=2,
                                     normalize=True,
                                     range=None,
                                     scale_each=False,
                                     pad_value=0)
        torchvision.utils.save_image(mask[0, :, :, :],
                                     "images/mask_" + str(epoch) + ".jpg",
                                     nrow=1,
                                     padding=2,
                                     normalize=True,
                                     range=None,
                                     scale_each=False,
                                     pad_value=0)
        torchvision.utils.save_image(pred[0, :, :, :],
                                     "images/pred_" + str(epoch) + ".jpg",
                                     nrow=1,
                                     padding=2,
                                     normalize=True,
                                     range=None,
                                     scale_each=False,
                                     pad_value=0)

        # x = torch.tensor([[1,2,3],[4,5,6]], dtype = torch.uint8)
        # x = show_image[0,:,:,:]
        # print(x.shape)
        # pil_im = transforms.ToPILImage(mode = 'RGB')(x)
        # pil_im.save('/home/videsh/Downloads/Chandan/paper_implementation/CE-Net-master/images/image_' + str(epoch) +  '.jpg')

        # x = mask[0,:,:,:]
        # print(x.shape)
        # pil_im = transforms.ToPILImage(mode = 'L')(x)
        # pil_im.save('/home/videsh/Downloads/Chandan/paper_implementation/CE-Net-master/images/mask_' + str(epoch) +  '.jpg')

        # x = pred[0,:,:,:]
        # print(x.shape)
        # pil_im = transforms.ToPILImage(mode = 'HSV')(x.detach().cpu().numpy())
        # pil_im.save('/home/videsh/Downloads/Chandan/paper_implementation/CE-Net-master/images/prediction_' + str(epoch) +  '.jpg')
        # (x.detach().numpy()).save("/home/videsh/Downloads/Chandan/paper_implementation/CE-Net-master/images/image_" + str(epoch) + ".png")
        # cv2.imwrite('imagename.jpg', x.detach().numpy().astype('uint8')).transpose(2,1,0)
        # x = mask[0,:,:,:]
        # # F.to_pil_image(x.detach().numpy()).save("/home/videsh/Downloads/Chandan/paper_implementation/CE-Net-master/images/mask_" + str(epoch) + ".png")
        # x = pred[0,:,:,:]
        # print(x.shape)
        # cv2.imwrite('imagename2.jpg', x.detach().numpy().astype('uint8'))

        # F.to_pil_image(x.detach().numpy()).save("/home/videsh/Downloads/Chandan/paper_implementation/CE-Net-master/images/prediction_" + str(epoch) + ".png")
        print("saving images")
        print("Train_loss_for_all ", train_epoch_loss)
        print("length of (data_loader_iter) ", len(data_loader_iter))
        train_epoch_loss = train_epoch_loss / len(data_loader_iter)
        print('********', file=mylog)
        print('epoch:', epoch, '    time:', int(time() - tic), file=mylog)
        print('train_loss:', train_epoch_loss, file=mylog)
        print('SHAPE:', Constants.Image_size, file=mylog)
        print('********')
        print('epoch:', epoch, '    time:', int(time() - tic))
        print('train_loss:', train_epoch_loss)
        print('SHAPE:', Constants.Image_size)

        if train_epoch_loss >= train_epoch_best_loss:
            no_optim += 1
        else:
            no_optim = 0
            train_epoch_best_loss = train_epoch_loss
            print("Saving the Weights")
            solver.save('./weights/' + NAME + '.th')
            if epoch % 100 == 0:
                solver.save('./weights/' + NAME + str(epoch) + '.th')
        if no_optim > Constants.NUM_EARLY_STOP:
            print('early stop at %d epoch' % epoch, file=mylog)
            print('early stop at %d epoch' % epoch)
            break
        if no_optim > Constants.NUM_UPDATE_LR:
            if solver.old_lr < 5e-7:
                break
            solver.load('./weights/' + NAME + '.th')
            solver.update_lr(2.0, factor=True, mylog=mylog)
        mylog.flush()

        if (epoch % 1 == 0):
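            # note: epoch % 1 == 0 is always true, so this validation block runs after every epoch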
            # run validation and save sample images
            print('in VALIDATION')
            data_loader_iter_val = iter(data_loader_val)
            train_epoch_loss = 0
            index = 0

            for img, mask in data_loader_iter_val:
                # solver.load('./weights/' + NAME + '.th')
                solver.set_input(img, mask)

                train_loss, pred = solver.optimize_test()
                train_epoch_loss += train_loss
                index = index + 1
                # torchvision.utils.save_image(img[0, :, :, :], "test_data/results2/image_"+str(epoch) + '_' + str(index) + ".jpg", nrow=1, padding=2, normalize=True, range=None, scale_each=False, pad_value=0)
                # torchvision.utils.save_image(mask[0, :, :, :], "test_data/results2/mask_"+str(epoch) + '_' + str(index) + ".jpg", nrow=1, padding=2, normalize=True, range=None, scale_each=False, pad_value=0)
                # torchvision.utils.save_image(pred[0, :, :, :], "test_data/results2/pred_"+str(epoch) + '_' + str(index) + ".jpg", nrow=1, padding=2, normalize=True, range=None, scale_each=False, pad_value=0)
            print("Train_loss_for_all ", train_epoch_loss)
            print("length of (data_loader_iter_val) ",
                  len(data_loader_iter_val))
            print(train_epoch_loss / len(data_loader_iter_val))
            print('++++++++++++++++++++++++++++++++++')
            # show the original image, prediction and ground truth on Visdom.
            # show_image = (img + 1.6) / 3.2 * 255.
            # viz.img(name='images', img_=show_image[0, :, :, :])
            # viz.img(name='labels', img_=mask[0, :, :, :])
            # viz.img(name='prediction', img_=pred[0, :, :, :])

            torchvision.utils.save_image(img[0, :, :, :],
                                         "test_data/results4/image_" +
                                         str(epoch) + ".jpg",
                                         nrow=1,
                                         padding=2,
                                         normalize=True,
                                         range=None,
                                         scale_each=False,
                                         pad_value=0)
            torchvision.utils.save_image(mask[0, :, :, :],
                                         "test_data/results4/mask_" +
                                         str(epoch) + ".jpg",
                                         nrow=1,
                                         padding=2,
                                         normalize=True,
                                         range=None,
                                         scale_each=False,
                                         pad_value=0)
            torchvision.utils.save_image(pred[0, :, :, :],
                                         "test_data/results4/pred_" +
                                         str(epoch) + ".jpg",
                                         nrow=1,
                                         padding=2,
                                         normalize=True,
                                         range=None,
                                         scale_each=False,
                                         pad_value=0)


    print('Finish!', file=mylog)
    print('Finish!')
    mylog.close()
Example No. 4
def CE_Net_Train():
    NAME = 'CE-Net' + Constants.ROOT.split('/')[-1]

    # run the Visdom
    viz = Visualizer(env=NAME)

    solver = MyFrame(CE_Net_, dice_bce_loss, 2e-4)
    batchsize = torch.cuda.device_count() * Constants.BATCHSIZE_PER_CARD

    # For different 2D medical image segmentation tasks, specify the dataset you are using here.
    # For example, pass datasets='DRIVE' for retinal vessel detection.

    dataset = ImageFolder(root_path=Constants.ROOT, datasets='DRIVE')
    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=batchsize,
                                              shuffle=True,
                                              num_workers=4)

    # start the logging files
    mylog = open('logs/' + NAME + '.log', 'w')
    tic = time()

    no_optim = 0
    total_epoch = Constants.TOTAL_EPOCH
    train_epoch_best_loss = Constants.INITAL_EPOCH_LOSS
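    # INITAL_EPOCH_LOSS serves as the initial "best loss"; as long as it is large, the first epoch will always save a checkpoint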
    for epoch in range(1, total_epoch + 1):
        data_loader_iter = iter(data_loader)
        train_epoch_loss = 0
        index = 0

        for img, mask in data_loader_iter:
            solver.set_input(img, mask)
            train_loss, pred = solver.optimize()
            train_epoch_loss += train_loss
            index = index + 1

        # show the original image, prediction and ground truth on Visdom.
        show_image = (img + 1.6) / 3.2 * 255.
        viz.img(name='images', img_=show_image[0, :, :, :])
        viz.img(name='labels', img_=mask[0, :, :, :])
        viz.img(name='prediction', img_=pred[0, :, :, :])

        train_epoch_loss = train_epoch_loss / len(data_loader_iter)
        print('********', file=mylog)
        print('epoch:', epoch, '    time:', int(time() - tic), file=mylog)
        print('train_loss:', train_epoch_loss, file=mylog)
        print('SHAPE:', Constants.Image_size, file=mylog)
        print('********')
        print('epoch:', epoch, '    time:', int(time() - tic))
        print('train_loss:', train_epoch_loss)
        print('SHAPE:', Constants.Image_size)

        if train_epoch_loss >= train_epoch_best_loss:
            no_optim += 1
        else:
            no_optim = 0
            train_epoch_best_loss = train_epoch_loss
            solver.save('./weights/' + NAME + '.th')
        if no_optim > Constants.NUM_EARLY_STOP:
            print('early stop at %d epoch' % epoch, file=mylog)
            print('early stop at %d epoch' % epoch)
            break
        if no_optim > Constants.NUM_UPDATE_LR:
            if solver.old_lr < 5e-7:
                break
            solver.load('./weights/' + NAME + '.th')
            solver.update_lr(2.0, factor=True, mylog=mylog)
        mylog.flush()

    print('Finish!', file=mylog)
    print('Finish!')
    mylog.close()
Example No. 5
            if run[0] % 50 == 0:
                print("run {}:".format(run))
                print('Style Loss : {:4f} Content Loss: {:4f}'.format(
                    style_score.item(), content_score.item()))
                print()

            return style_score + content_score

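        # L-BFGS may call this closure several times per optimizer step to re-evaluate the loss and gradients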
        optimizer.step(closure)

    # a last correction: clamp the optimized image's pixel values back into the valid [0, 1] range
    input_img.data.clamp_(0, 1)

    return input_img


if __name__ == '__main__':
    output = run_style_transfer(cnn, cnn_normalization_mean,
                                cnn_normalization_std, content_img, style_img,
                                input_img)
    # plt.figure()

    # output = Tensor2Image(output)

    # imshow(output, title='Output Image')
    Vis.img(name="output", img_=output)

    # sphinx_gallery_thumbnail_number = 4
    # plt.ioff()
    # plt.show()
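To also keep a copy of the result on disk, a minimal sketch inside the same __main__ block (the file name is illustrative; output is the float tensor clamped to [0, 1] above):

    import torchvision

    # detach so save_image can convert the tensor without tracking gradients
    torchvision.utils.save_image(output.detach(), 'style_transfer_output.png')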