Example 1
    def __init__(self, opts, device_ids, load_model=False):
        self.opts = opts
        self.epoch_step = 0
        self.model_num = 0
        self.network = Unnet.UNet(opts)
        # self.network = Resnet.resnet152()
        # self.network = framelet.Framelets()
        # self.network = googlenet.GoogLeNet()
        # self.network = densenet.densenet161()

        if torch.cuda.device_count() > 1 and opts.max_gpus > 1:
            if len(device_ids) <= opts.max_gpus:
                self.network = torch.nn.DataParallel(
                    self.network)  #, device_ids=device_ids[0]
            else:
                self.network = torch.nn.DataParallel(
                    self.network,
                    device_ids=device_ids[0:opts.max_gpus])  # use at most max_gpus devices
        self.network.cuda()

        # Loss functions used for training and evaluation
        self.loss_func_l1 = torch.nn.L1Loss()
        self.loss_func_MSE = torch.nn.MSELoss()
        self.MedGanloss = MedGanloss()
        self.mssim_loss = MSSSIM(window_size=9, size_average=True)
        self.loss_func_poss = torch.nn.PoissonNLLLoss()
        self.loss_func_KLDiv = torch.nn.KLDivLoss()
        self.loss_func_Smoothl1 = torch.nn.SmoothL1Loss()
        self.loss_func_part = torch.nn.L1Loss()
        self.test_loss = torch.nn.MSELoss(reduction='none')
        self.averagepool = torch.nn.AvgPool2d(3, stride=2)
        self.optim_count = 0
        #TODO2: Change the load model dict
        if self.opts.load_model or load_model:
            print("Restoring model")
            try:
                if load_model:
                    self.network.load_state_dict(
                        torch.load(
                            '/home/liang/Desktop/output/model/model_dict_0'))

                else:
                    self.network.load_state_dict(
                        torch.load(
                            os.path.join(self.opts.output_path, 'model',
                                         'model_dict')))
            except Exception:
                # The checkpoint was saved without the DataParallel wrapper,
                # so its keys lack the `module.` prefix the wrapped network expects.
                state_dict = torch.load(
                    os.path.join(self.opts.output_path, 'model', 'model_dict'))
                # create a new OrderedDict whose keys carry the `module.` prefix
                from collections import OrderedDict
                new_state_dict = OrderedDict()
                for k, v in state_dict.items():
                    name = 'module.' + k  # add the prefix expected by DataParallel
                    new_state_dict[name] = v
                # load params
                self.network.load_state_dict(new_state_dict)
Example 2
 def __init__(self, shape1, shape2):
     # self.network = Net(shape1, shape2)
     self.network = SETLayer(wl * wl, wl * wl)
     # self.network.cuda()
     self.loss_func = torch.nn.L1Loss(reduction='mean')
     self.mssim_loss = MSSSIM(window_size=11, size_average=True)
     self.loss_func_MSE = torch.nn.MSELoss()
     self.model_num = 0
Example 3
 def __init__(self, shape1, shape2):
     self.network = unet.UNet()  #(shape1, shape2)
     self.network.cuda()
     self.loss_func = torch.nn.L1Loss(reduction='mean')
     self.mssim_loss = MSSSIM(window_size=3, size_average=True)
     self.crossengropy = torch.nn.CrossEntropyLoss()
     self.loss_func_MSE = torch.nn.MSELoss()
     self.model_num = 0
Example 4
 def __init__(self, shape1, shape2):
     # self.network = Net(shape1, shape2)
     # self.network.cuda()
     self.loss_func = torch.nn.L1Loss(reduction='sum')
     self.mssim_loss = MSSSIM(window_size=11, size_average=True)
     self.loss_func_MSE = torch.nn.MSELoss()
     self.model_num = 0
     self.weight = torch.randn(
         shape2[0] * shape2[1], shape1[0] * shape1[1],
         device='cuda').to_sparse().requires_grad_(True)
     self.learning_rate = 1e-3
     self.count = 0
Example 5
def main():
    with torch.cuda.device(1):
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model = AS_Net(in_channels=args.in_channels).to(device)

        batch_time = AverageMeter()
        train_ssim_meter = AverageMeter()
        train_psnr_meter = AverageMeter()
        test_ssim_meter = AverageMeter()
        test_psnr_meter = AverageMeter()

        vis = Visualizer(env=args.vis_env)

        train_dataset = multichanneldata.ReconDataset0526(args.dataset_pathr, train=True)
        test_dataset = multichanneldata.ReconDataset0526(args.dataset_pathr, train=False)

        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
        test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.test_batch, shuffle=False)

        smooth_L1 = nn.SmoothL1Loss()
        msssim = MSSSIM(channel=1)

        optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

        if args.loadcp:
            checkpoint = torch.load(args.save_path + 'latest_' + args.file_name)
            args.start_epoch = checkpoint['epoch']
            print('%s%d' % ('training from epoch:', args.start_epoch))
            model = checkpoint['model']
            optimizer = checkpoint['optimizer']
            args.learning_rate = checkpoint['curr_lr']

        cudnn.benchmark = True
        total_step = len(train_loader)

        best_metric = {'test_epoch': 0, 'test_ssim': 0, 'test_psnr': 0}
        log.info('train image num: {}'.format(len(train_dataset)))
        log.info('val image num: {}'.format(len(test_dataset)))

        end = time.time()
        for epoch in range(args.start_epoch, args.num_epochs):
            for batch_idx, (rawdata, reimage, bfimg) in enumerate(tqdm(train_loader)):
                rawdata = rawdata.to(device)
                reimage = reimage.to(device)
                bfimg = bfimg.to(device)

                fake_img, bf_feature, side = model(rawdata, bfimg)
                loss_pe = smooth_L1(fake_img, reimage)
                bf_loss = smooth_L1(bf_feature, reimage)
                loss = 5 * loss_pe + bf_loss
                # Backward and optimize
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                ssim = compare_ssim(np.array(reimage[0, 0, :, :].cpu().detach()),
                                    np.array(fake_img[0, 0, :, :].cpu().detach()))
                train_ssim_meter.update(ssim)
                psnr = compare_psnr(np.array(reimage[0, 0, :, :].cpu().detach()),
                                    np.array(fake_img[0, 0, :, :].cpu().detach()),
                                    data_range=1)
                train_psnr_meter.update(psnr)

                # visualization and evaluation
                if (batch_idx + 1) % 5 == 0:
                    reimage = reimage.detach()
                    bfimg = bfimg.detach()
                    bf_feature = bf_feature.detach()
                    side = side.detach()
                    fake_img = fake_img.detach()
                    vis.img(name='ground truth', img_=255 * reimage[0])
                    vis.img(name='DAS image', img_=255 * bfimg[0])
                    vis.img(name='textural map', img_=255 * bf_feature[0])
                    vis.img(name='side_output', img_=255 * side[0])
                    vis.img(name='output', img_=255 * fake_img[0])

                batch_time.update(time.time() - end)
                end = time.time()

            log.info(
                'Epoch [{}], Start [{}], Step [{}/{}], Loss: {:.4f}, Time [{batch_time.val:.3f}({batch_time.avg:.3f})]'
                    .format(epoch + 1, args.start_epoch, batch_idx + 1, total_step, loss.item(),
                            batch_time=batch_time))

            vis.plot_multi_win(
                dict(
                    bfloss=bf_loss.item(),
                    loss_mse=loss_pe.item(),
                    total_loss=loss.item(),
                ))

            vis.plot_multi_win(dict(train_ssim=train_ssim_meter.avg, train_psnr=train_psnr_meter.avg))
            log.info('train_ssim: {}, train_psnr: {}'.format(train_ssim_meter.avg, train_psnr_meter.avg))

            # Validate
            if epoch % 5 == 0:
                model.eval()  # switch to evaluation mode for validation
                with torch.no_grad():
                    for batch_idx, (rawdata, reimage, bfimg) in enumerate(tqdm(test_loader)):
                        rawdata = rawdata.to(device)
                        reimage = reimage.to(device)
                        bfimg = bfimg.to(device)
                        outputs, bf_feature, side_test = model(rawdata, bfimg)
                        test_ms_ssim = msssim(outputs, reimage)

                        ssim = compare_ssim(np.array(reimage.cpu().squeeze()), np.array(outputs.cpu().squeeze()))
                        test_ssim_meter.update(ssim)
                        psnr = compare_psnr(np.array(reimage.cpu().squeeze()), np.array(outputs.cpu().squeeze()),
                                            data_range=1)
                        test_psnr_meter.update(psnr)

                        if (batch_idx + 1) % 2 == 0:
                            reimage = reimage.detach()
                            bf_feature = bf_feature.detach()
                            outputs = outputs.detach()
                            side_test = side_test.detach()
                            bfimg = bfimg.detach()
                            vis.img(name='Test: ground truth', img_=255 * reimage[0])
                            vis.img(name='Test: DASimage', img_=255 * bfimg[0])
                            vis.img(name='Test: textural map', img_=255 * bf_feature[0])
                            vis.img(name='Test: output', img_=255 * outputs[0])
                            vis.img(name='Test: side_output', img_=255 * side_test[0])

                    vis.plot_multi_win(dict(
                        test_ssim=test_ssim_meter.avg,
                        test_psnr=test_psnr_meter.avg,
                        test_msssim=test_ms_ssim.item()
                    ))
                    log.info('test_ssim: {}, test_psnr: {}'.format(test_ssim_meter.avg, test_psnr_meter.avg))
                model.train()  # restore training mode after validation

            # Decay learning rate
            if (epoch + 1) % 50 == 0:
                args.learning_rate /= 5
                update_lr(optimizer, args.learning_rate)

            torch.save({'epoch': epoch,
                        'model': model,
                        'optimizer': optimizer,
                        'curr_lr': args.learning_rate,
                        },
                       args.save_path + 'latest_' + args.file_name
                       )

            if best_metric['test_ssim'] < test_ssim_meter.avg:
                torch.save({'epoch': epoch,
                            'model': model,
                            'optimizer': optimizer,
                            'curr_lr': args.learning_rate,
                            },
                           args.save_path + 'best_' + args.file_name
                           )
                best_metric['test_epoch'] = epoch
                best_metric['test_ssim'] = test_ssim_meter.avg
                best_metric['test_psnr'] = test_psnr_meter.avg
            log.info('best_epoch: {}, best_ssim: {}, best_psnr: {}'.format(best_metric['test_epoch'],
                                                                           best_metric['test_ssim'],
                                                                           best_metric['test_psnr']))
Example 6
class Sino_repair_net():
    def __init__(self, opts, device_ids, load_model=False):
        self.opts = opts
        self.epoch_step = 0
        self.model_num = 0
        self.network = Unnet.UNet(opts)
        # self.network = Resnet.resnet152()
        # self.network = framelet.Framelets()
        # self.network = googlenet.GoogLeNet()
        # self.network = densenet.densenet161()

        if torch.cuda.device_count() > 1 and opts.max_gpus > 1:
            if len(device_ids) <= opts.max_gpus:
                self.network = torch.nn.DataParallel(
                    self.network)  #, device_ids=device_ids[0]
            else:
                self.network = torch.nn.DataParallel(
                    self.network,
                    device_ids=device_ids[0:opts.max_gpus])  # use at most max_gpus devices
        self.network.cuda()

        # Loss functions used for training and evaluation
        self.loss_func_l1 = torch.nn.L1Loss()
        self.loss_func_MSE = torch.nn.MSELoss()
        self.MedGanloss = MedGanloss()
        self.mssim_loss = MSSSIM(window_size=9, size_average=True)
        self.loss_func_poss = torch.nn.PoissonNLLLoss()
        self.loss_func_KLDiv = torch.nn.KLDivLoss()
        self.loss_func_Smoothl1 = torch.nn.SmoothL1Loss()
        self.loss_func_part = torch.nn.L1Loss()
        self.test_loss = torch.nn.MSELoss(reduction='none')
        self.averagepool = torch.nn.AvgPool2d(3, stride=2)
        self.optim_count = 0
        #TODO2: Change the load model dict
        if self.opts.load_model or load_model:
            print("Restoring model")
            try:
                if load_model:
                    self.network.load_state_dict(
                        torch.load(
                            '/home/liang/Desktop/output/model/model_dict_0'))

                else:
                    self.network.load_state_dict(
                        torch.load(
                            os.path.join(self.opts.output_path, 'model',
                                         'model_dict')))
            except Exception:
                # The checkpoint was saved without the DataParallel wrapper,
                # so its keys lack the `module.` prefix the wrapped network expects.
                state_dict = torch.load(
                    os.path.join(self.opts.output_path, 'model', 'model_dict'))
                # create a new OrderedDict whose keys carry the `module.` prefix
                from collections import OrderedDict
                new_state_dict = OrderedDict()
                for k, v in state_dict.items():
                    name = 'module.' + k  # add the prefix expected by DataParallel
                    new_state_dict[name] = v
                # load params
                self.network.load_state_dict(new_state_dict)

    def set_optimizer(self, optimizer):
        self.optimizer = optimizer

    def train_batch(self, input_img, target_img, valid=None):

        if valid is None:
            output = self.network(input_img)
            loss, loss2 = self.optimize(output, target_img)
            return output, loss, loss2
        else:

            final = input_img.clone()
            mask = torch.tensor([i for i, n in enumerate(valid)
                                 if n == 1]).cuda()
            if len(mask) > 0:
                traininput = torch.index_select(input_img, 0, mask)
                trainoutput = self.network(traininput)
                loss, loss2 = self.optimize(
                    trainoutput, torch.index_select(target_img, 0, mask))
                final[mask] = trainoutput
            else:
                loss, loss2 = self.loss_func_l1(final, target_img), (
                    1 - self.mssim_loss(final, target_img)) / 2

            return final, loss, loss2

    def test(self, x, y, valid=None):

        if valid is None:
            output = self.network(x)
            loss = self.test_loss(output, y).detach()
            return output, loss
        else:
            final = x.clone()
            mask = torch.tensor([i for i, n in enumerate(valid)
                                 if n == 1]).cuda()
            if len(mask) > 0:
                traininput = torch.index_select(x, 0, mask)
                trainoutput = self.network(traininput)
                final[mask] = trainoutput
            loss = self.test_loss(final, y).detach()
        return final, loss

    def optimize(self, output, target_img):
        #TODO: can add other loss terms if needed
        #TODO: need to step though this code to make sure it works correctly
        input1 = output  #torch.floor(output)  #/ (output.max() + 1e-8)
        input2 = target_img  #/ (target_img.max() + 1e-8)
        # # Including l1 loss
        # mask = ((input_img * -1.0) + 1.0) >= 0.8
        loss1 = self.loss_func_l1(input1, input2)
        l1 = loss1.detach()

        # Including a consistency loss
        loss2 = self.loss_func_MSE(input1, input2)
        l2 = loss2.detach()

        loss3 = 0.0001 * (1 - self.mssim_loss(input1, input2))
        l3 = loss3.detach()
        #
        # loss3 = self.MedGanloss(output, target_img)
        # l3 = loss3.detach()
        # loss3 = self.loss_func_l1(self.averagepool(input1), self.averagepool(input2))
        # loss3 = abs(torch.floor(input1).mean()- input2.mean())

        # loss3 = self.loss_func_MSE(output, target_img)
        #
        # if self.OPT_count == 0:
        #     self.alpha = torch.tensor(0.5).cuda()
        #     self.lossl1 = []
        #     self.lossmssim = []
        #
        # self.lossl1.append(l1.item())
        # self.lossmssim.append(l2.item())
        #
        # if self.OPT_count >= 20:
        #     self.alpha = torch.FloatTensor(self.lossl1).mean().cuda()/(torch.FloatTensor(self.lossl1).mean() + torch.FloatTensor(self.lossmssim).mean()).cuda()
        #     self.OPT_count = 0
        #
        # # loss = self.alpha * loss1+(1-self.alpha) *loss2
        # vx = output - torch.mean(output)
        # vy = target_img - torch.mean(target_img)
        # loss_pearson_correlation = 1 - torch.sum(vx * vy) / (
        #             torch.rsqrt(torch.sum(vx ** 2)) * torch.rsqrt(torch.sum(vy ** 2)))  # use Pearson correlation

        loss = loss1 + loss2 + loss3
        # print(loss1, loss2, loss3)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self.optim_count += 1

        return l1, l2

    def save_network(self):
        print("saving network parameters")
        folder_path = os.path.join(self.opts.output_path, 'model')
        if not os.path.exists(folder_path):
            os.makedirs(folder_path)
        torch.save(
            self.network.state_dict(),
            os.path.join(folder_path, "model_dict_{}".format(self.model_num)))
        self.model_num += 1
        if self.model_num >= 5: self.model_num = 0
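
The except branch in __init__ covers a checkpoint saved without DataParallel being loaded into a wrapped network. The opposite mismatch, a DataParallel checkpoint loaded into a bare model, is handled by stripping the prefix instead; a minimal sketch with assumed names (not part of the original class):

from collections import OrderedDict

import torch


def load_dataparallel_checkpoint(network, path):
    # Load a state dict saved from a DataParallel-wrapped model into an
    # unwrapped network by dropping the leading 'module.' from every key.
    state_dict = torch.load(path, map_location='cpu')
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[len('module.'):] if k.startswith('module.') else k
        new_state_dict[name] = v
    network.load_state_dict(new_state_dict)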
Example 7
def train(netG1, netG2, netD1, netD2, vgg, optG1, optG2, optD1, optD2,
          dataloader, bs, l, savename):

    real_label = torch.ones(bs).long()
    fake_label = torch.zeros(bs).long()

    #     crit_data = nn.L1Loss()
    #     criterion = nn.BCEWithLogitsLoss()

    # Segmentation losses
    crit_discriminator = nn.CrossEntropyLoss(reduction='sum')  #NLLLoss()
    crit_generator = nn.CrossEntropyLoss(reduction='mean')  #reduction='sum')

    # loss weights
    tv_w = 2000
    content_w = 10
    color_w = 0.5
    tv_msssim = 500
    texture_w = 1
    NUM_PATCHES = 50 // bs
    # DPED losses

    color_criterion = nn.MSELoss(reduction='sum')
    texture_criterion = nn.BCELoss(reduction='mean')
    content_criterion = nn.MSELoss(reduction='sum')
    msssim_criterion = MSSSIM()

    tv_criterion = TVLoss()
    blur = GaussianBlur()
    gray = GrayLayer()

    if torch.cuda.is_available():
        real_label = real_label.cuda()
        fake_label = fake_label.cuda()
        tv_criterion = tv_criterion.cuda()
        blur = blur.cuda()
        gray = gray.cuda()

    for epoch in range(40):
        begin = time.time()
        for i, (x, y, r) in enumerate(dataloader):
            if i % 10 == 0:
                print(epoch, i, end='\r')
            if torch.cuda.is_available():
                x = x.cuda()
                y = y.cuda()
                r = r.cuda()
            # Image Segmentation
            # generator 1
            # Generate classes mask image and calculate loss with ground truth image mask
            optG1.zero_grad()
            optG2.zero_grad()
            optD1.zero_grad()
            optD2.zero_grad()

            y_G1 = netG1(x)

            loss_G1_data = crit_generator(y_G1, y)

            _, y_G1 = torch.max(y_G1, 1)
            y_G1 = (y_G1.float() / 2.0 - 0.5) * 2
            x_y_G1 = torch.cat((x, y_G1.float().unsqueeze(1)), 1)
            x_y = torch.cat((x, y.float().unsqueeze(1)), 1)

            loss_G1_adv = crit_discriminator(netD1(x_y_G1), real_label.long())

            loss_G1 = loss_G1_data + loss_G1_adv

            # discriminator 1
            y_D1_fake = netD1(x_y_G1.detach())
            y_D1_real = netD1(x_y)
            loss_D1_fake = crit_discriminator(y_D1_fake, fake_label.long())
            loss_D1_real = crit_discriminator(y_D1_real, real_label.long())
            loss_D1 = l[1] * (loss_D1_real + loss_D1_fake)

            # Image Cropping
            x_y_G1_crop, x_y_crop, r_crop = x_y_G1.detach(), x_y, r
            texture_loss = 0
            content_loss = 0
            color_loss = 0
            tv_loss = 0
            loss_D2_fake = 0
            loss_D2_real = 0
            msssim_loss = 0

            for _ in range(NUM_PATCHES):
                x_y_G1, x_y, r = CropImage(x_y_G1_crop, x_y_crop, r_crop)

                # Image Enhancement
                # DPED
                # train generator 2

                y_G2 = netG2(x_y_G1)
                # texture loss
                y_G2_gray = gray(y_G2)

                x_y_G1_G2 = torch.cat((x_y_G1, y_G2), 1)
                y_G2_pred = netD2(x_y_G1_G2)

                texture_loss += -texture_criterion(y_G2_pred,
                                                   fake_label.float()) * bs

                # content loss

                vgg_y_G2 = vgg(y_G2)
                vgg_r = vgg(r).detach()
                _, c1, h1, w1 = y_G2.size()
                chw1 = c1 * h1 * w1
                content_loss += 1.0 / (2 * bs * chw1) * content_criterion(
                    vgg_y_G2, vgg_r)

                # color loss

                y_G2_blur = blur(y_G2)
                r_blur = blur(r).detach()
                color_loss += color_criterion(y_G2_blur, r_blur) / (2 * bs)
                #                 color_loss += color_criterion(y_G2, r) / (2 * bs)
                #                 msssim_loss += msssim_criterion(r, y_G2)

                # total variation loss

                tv_loss += tv_criterion(y_G2)

                # discriminator 2

                x_y_r = torch.cat((x_y, r), 1)
                r_pred = netD2(x_y_r.detach())
                y_G2_pred = netD2(x_y_G1_G2.detach())

                loss_D2_fake += texture_criterion(y_G2_pred,
                                                  fake_label.float()) * bs
                loss_D2_real += texture_criterion(r_pred,
                                                  real_label.float()) * bs

            texture_loss /= NUM_PATCHES
            content_loss /= NUM_PATCHES
            color_loss /= NUM_PATCHES
            tv_loss /= NUM_PATCHES
            loss_D2_fake /= NUM_PATCHES
            loss_D2_real /= NUM_PATCHES
            #             msssim_loss /= NUM_PATCHES
            #             msssim_loss = 1 - msssim_loss

            # Total losses
            loss_G2 = texture_w * texture_loss + content_w * content_loss + color_w * color_loss + tv_w * tv_loss  #+ tv_msssim * msssim_loss
            loss_D2 = texture_w * (loss_D2_real + loss_D2_fake)

            loss_D1.backward()
            loss_D2.backward()
            optD1.step()
            optD2.step()
            loss_G1.backward()
            loss_G2.backward()
            optG1.step()
            optG2.step()

            if i % 100 == 0:
                losses = {
                    'content': content_loss.item(),
                    'color': color_loss.item(),
                    #                     'msssim': msssim_loss.item(),
                    'tv': tv_loss.item(),
                    'gen_texture_loss': texture_loss.item(),
                    'disc_fake_loss': loss_D2_fake.item() / bs,
                    'disc_real_loss': loss_D2_real.item() / bs,
                }

                torch.save(netG1.state_dict(),
                           './models/netG1' + savename + '.pth')
                torch.save(netG2.state_dict(),
                           './models/netG2' + savename + '.pth')
                torch.save(netD1.state_dict(),
                           './models/netD1' + savename + '.pth')
                torch.save(netD2.state_dict(),
                           './models/netD2' + savename + '.pth')

                print('epoch {:}'.format(epoch),
                      'iter {:}'.format(i),
                      'iter time {:.2f}'.format(time.time() - begin),
                      'loss_D1: {:.4f}'.format(loss_D1.item()),
                      'loss_D2: {:.4f}'.format(loss_D2.item()),
                      'loss_G1: {:.4f}'.format(loss_G1.item()),
                      'loss_G2: {:.4f}'.format(loss_G2.item()),
                      file=open(
                          '/home/jupyter/STGAN/results_stgan' + savename +
                          '.txt', 'a+'))
                for k, v in losses.items():
                    print(k,
                          '{:.3f}'.format(v),
                          end=', ',
                          file=open(
                              '/home/jupyter/STGAN/results_stgan' + savename +
                              '.txt', 'a+'))
                print('',
                      file=open(
                          '/home/jupyter/STGAN/results_stgan' + savename +
                          '.txt', 'a+'))

        torch.save(netG1.state_dict(), './models/netG1' + savename + '.pth')
        torch.save(netG2.state_dict(), './models/netG2' + savename + '.pth')
        torch.save(netD1.state_dict(), './models/netD1' + savename + '.pth')
        torch.save(netD2.state_dict(), './models/netD2' + savename + '.pth')
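
TVLoss in Example 7 is used but not defined in the snippet. A minimal sketch of one common total variation loss formulation (a hypothetical reconstruction, not necessarily the original TVLoss):

import torch.nn as nn


class TVLoss(nn.Module):
    # Total variation loss: penalizes differences between neighbouring pixels,
    # which encourages spatially smooth generator outputs.
    def forward(self, x):
        diff_h = (x[:, :, 1:, :] - x[:, :, :-1, :]).pow(2).sum()
        diff_w = (x[:, :, :, 1:] - x[:, :, :, :-1]).pow(2).sum()
        return (diff_h + diff_w) / x.size(0)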
Example 8
def mssdim(output, target):
    return (1. - MSSSIM(output, target)) / 2.
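
Elsewhere in these examples MSSSIM is an nn.Module that is instantiated once and then called on a pair of tensors. If that is the implementation in use here as well, an equivalent formulation would construct the module up front (a hedged sketch; the instance and function names below are assumptions):

msssim_module = MSSSIM()  # construct once, reuse for every call


def msssim_dissimilarity(output, target):
    # Map the MS-SSIM similarity score to a dissimilarity in [0, 1].
    return (1. - msssim_module(output, target)) / 2.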
Example 9
    D = model.Discriminator(6)
    G = model.VGG_VAE(5)

    D.apply(weights_init)
    G.apply(weights_init)

    D.cuda()
    G.cuda()
    print(D)
    print(G)
    D_criterion = torch.nn.BCEWithLogitsLoss().cuda()
    D_optimizer = torch.optim.SGD(D.parameters(), lr=1e-3)

    G_criterion = torch.nn.BCEWithLogitsLoss().cuda()
    G_l1 = torch.nn.L1Loss().cuda()
    G_msssim = MSSSIM().cuda()
    G_ssim = SSIM().cuda()
    G_optimizer = torch.optim.Adam(G.parameters(), lr=1e-3)

    pathlib.Path(sample_output).mkdir(parents=True, exist_ok=True)
    pathlib.Path(os.path.join(sample_output, "images")).mkdir(parents=True, exist_ok=True)
    d_loss = 0
    g_loss = 0
    
    d_to_g_threshold = 0.5
    g_to_d_threshold = 0.3

    train_d = True
    train_g = True

    conditional_training = False
Example 10
def l1loss(x, y):
    return torch.nn.functional.l1_loss(x, y, reduction='mean')


def l2loss(x, y):
    return torch.pow(x - y, 2).mean()


def psnr(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return (10 * torch.log10(x.shape[-2] * x.shape[-1] /
                             (x - y).pow(2).sum(dim=(2, 3)))).mean(dim=1)


ssim = SSIM(data_range=1.0)
msssim = MSSSIM(data_range=1.0)


def gaussian(x, sigma=1.0):
    return np.exp(-(x**2) / (2 * (sigma**2)))


def build_gauss_kernel(size=5, sigma=1.0, n_channels=1, device=None):
    """Construct the convolution kernel for a gaussian blur
    See https://en.wikipedia.org/wiki/Gaussian_blur for a definition.
    Overall I first generate a NxNx2 matrix of indices, and then use those to
    calculate the gaussian function on each element. The two dimensional
    Gaussian function is then the product along axis=2.
    Also, in_channels == out_channels == n_channels
    """
    if size % 2 != 1:
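        # NOTE: the original body is truncated here in this listing. The lines
        # below are a hypothetical completion that follows the docstring:
        # build an NxNx2 index matrix, evaluate the gaussian per element, and
        # take the product along axis=2 to obtain the 2D kernel.
        raise ValueError("kernel size must be odd")
    # NxNx2 matrix of (row, col) offsets from the kernel centre
    grid = np.mgrid[0:size, 0:size].transpose(1, 2, 0) - size // 2
    kernel = np.prod(gaussian(grid, sigma), axis=2)
    kernel /= kernel.sum()
    # one identical 2D kernel per channel, shaped for a grouped convolution
    kernel = np.tile(kernel, (n_channels, 1, 1))[:, None, :, :]
    kernel = torch.from_numpy(kernel.astype(np.float32))
    if device is not None:
        kernel = kernel.to(device)
    return kernel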
Example 11
    def __init__(self, opts, device_ids, load_model=False):
        self.opts = opts
        self.epoch_step = 0
        self.model_num = 0
        # self.network = UNet(opts)
        # self.network = test2.Net(opts)
        # self.network = VGG.vgg11_bn()
        self.network = GoogleNet.GoogLeNet()
        # self.network = framelet.Framelets()
        # import VGG
        # self.network = VGG.vgg19_bn()
        # self.network = Nets.AlexNet(channelnumber=1, num_classes=64*520)

        # Logic to make training on a GPU cluster easier
        if torch.cuda.device_count() > 1 and opts.max_gpus > 1:
            if len(device_ids) <= opts.max_gpus:
                self.network = torch.nn.DataParallel(
                    self.network)  #, device_ids=device_ids[0]
            else:
                self.network = torch.nn.DataParallel(
                    self.network,
                    device_ids=device_ids[0:opts.max_gpus])  # use at most max_gpus devices
        self.network.cuda()

        # Loss functions used for training and evaluation
        self.loss_func_l1 = torch.nn.L1Loss()
        self.loss_func_MSE = torch.nn.MSELoss()
        self.mssim_loss = MSSSIM(window_size=11, size_average=True)

        self.loss_func_poss = torch.nn.PoissonNLLLoss()
        self.loss_func_KLDiv = torch.nn.KLDivLoss()
        self.loss_func_Smoothl1 = torch.nn.SmoothL1Loss()
        self.loss_func_part = torch.nn.L1Loss()
        self.test_loss = torch.nn.MSELoss(reduction='none')

        self.OPT_count = 0
        #TODO2: Change the load model dict
        if self.opts.load_model or load_model:
            print("Restoring model")
            try:
                if load_model:
                    self.network.load_state_dict(
                        torch.load(
                            '/home/liang/Desktop/output/model/model_dict_0'))

                else:
                    self.network.load_state_dict(
                        torch.load(
                            os.path.join(self.opts.output_path, 'model',
                                         'model_dict')))
            except Exception:
                # The checkpoint was saved without the DataParallel wrapper,
                # so its keys lack the `module.` prefix the wrapped network expects.
                state_dict = torch.load(
                    os.path.join(self.opts.output_path, 'model', 'model_dict'))
                # create a new OrderedDict whose keys carry the `module.` prefix
                from collections import OrderedDict
                new_state_dict = OrderedDict()
                for k, v in state_dict.items():
                    name = 'module.' + k  # add the prefix expected by DataParallel
                    new_state_dict[name] = v
                # load params
                self.network.load_state_dict(new_state_dict)
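
A minimal usage sketch of the class shown in Example 6, wiring in an optimizer and running one training pass per batch (opts, device_ids, and train_loader are assumed to come from the surrounding script; the Adam settings are illustrative only):

net = Sino_repair_net(opts, device_ids)
net.set_optimizer(torch.optim.Adam(net.network.parameters(), lr=1e-4))

for input_img, target_img in train_loader:
    input_img, target_img = input_img.cuda(), target_img.cuda()
    # train_batch runs the forward pass and one optimizer step, returning
    # the network output plus the detached L1 and MSE loss values
    output, l1, mse = net.train_batch(input_img, target_img)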