Code Example #1
File: SSSRNet8_main.py  Project: summer1719/SSSRNet
def train(training_data_loader, optimizer, model, model_pretrained, criterion,
          epoch, gpuid):
    avgloss = 0
    lr = adjust_learning_rate(optimizer, epoch - 1)

    for param_group in optimizer.param_groups:
        param_group["lr"] = lr

    print("epoch =", epoch, "lr =", optimizer.param_groups[0]["lr"])
    model.train()
    input_size = (80, 80)

    for iteration, batch in enumerate(training_data_loader, 1):

        input, target, label, hr_label = Variable(batch[0]), \
                        Variable(batch[1], requires_grad=False), batch[2],  batch[3]

        #======= label transform: h*w class map -> 20*h*w one-hot (classes 1-20) =======#
        #label_pro = np.expand_dims(Label_patch, axis=0)
        Label_patch = label.numpy()
        label_pro = np.repeat(Label_patch, 20, axis=1)
        for i in range(1, 21):
            tmp = label_pro[:, i - 1:i]   # channel i-1 carries the mask for class i
            if i == 0:                    # never taken here: i starts at 1 (kept from the original)
                tmp[tmp == 255] = 0
            tmp[tmp != i] = -1            # temporarily mark non-class pixels
            tmp[tmp == i] = 1             # class-i pixels -> 1
            tmp[tmp == -1] = 0            # marked pixels -> 0
        label = Variable(torch.from_numpy(label_pro[:, :, :, :]).float())

        Label_patch = hr_label.numpy()
        label_pro = np.repeat(Label_patch, 20, axis=1)
        for i in range(1, 21):
            tmp = label_pro[:, i - 1:i]
            if i == 0:
                tmp[tmp == 255] = 0
            tmp[tmp != i] = -1
            tmp[tmp == i] = 1
            tmp[tmp == -1] = 0
        hr_label = Variable(torch.from_numpy(label_pro[:, :, :, :]).float())

        #input_ss_r = input[:,0:1,...] * 255.0 -  IMG_MEAN[2]
        #input_ss_g = input[:,1:2,...] * 255.0 -  IMG_MEAN[1]
        #input_ss_b = input[:,2:3,...] * 255.0 -  IMG_MEAN[0]
        #input_ss = torch.cat((input_ss_b, input_ss_g, input_ss_r), 1)

        input = input.cuda(gpuid)
        target = target.cuda(gpuid)
        label = label.cuda(gpuid)
        hr_label = hr_label.cuda(gpuid)

        #=========image mask generation=========#

        output, output_fg = model(input, hr_label)
        #output = model_pretrained(input)

        ##########show results###############
        is_show = False  # must stay False here: input_cls is never built in this function
        if is_show:
            label_show = label_pro[0].transpose((1, 2, 0))
            label_show = np.asarray(np.argmax(label_show, axis=2),
                                    dtype=np.int)
            image = input.cpu().data[0].numpy().transpose((1, 2, 0))
            image_out = output.cpu().data[0].numpy().transpose((1, 2, 0))
            label_heatmap = label.cpu().data[0].view(20, 1,
                                                     input.data[0].size(1),
                                                     input.data[0].size(2))
            label_heatmap = torchvision.utils.make_grid(label_heatmap)
            label_heatmap = label_heatmap.numpy().transpose((1, 2, 0))
            images_cls = input_cls.cpu().data[0].view(20, 3,
                                                      input.data[0].size(1),
                                                      input.data[0].size(2))
            images_cls = torchvision.utils.make_grid(images_cls)
            images_cls = images_cls.numpy().transpose((1, 2, 0))

            show_seg(image, label_show, image_out, label_heatmap, images_cls)
        #####################################

        optimizer.zero_grad()
        loss = criterion(output, target) / opt.batchSize
        avgloss += loss.data[0]
        for i in range(20):
            mask = hr_label[:, i:i + 1, :, :].repeat(1, 3, 1, 1)
            classwise_target = torch.mul(mask, target)
            classwise_output = output_fg[:, 3 * i:3 * (i + 1), :, :]
            classwise_loss = criterion(classwise_output,
                                       classwise_target) / opt.batchSize
            if classwise_loss.data[0] == 0:
                continue
            else:
                loss += classwise_loss * 10

        loss.backward()
        total_norm = 0
        if (loss.data[0] < 10000):
            total_norm = torch.nn.utils.clip_grad_norm(model.parameters(),
                                                       opt.clip)
        optimizer.step()

        if iteration % 50 == 0:

            print("===> Epoch[{}]({}/{}): Total_norm:{:.6f} Loss: {:.10f}".
                  format(epoch, iteration, len(training_data_loader),
                         total_norm, loss.data[0]))
    print("===> Epoch {} Complete: Avg. SR Loss: {:.6f}".format(
        epoch, avgloss / len(training_data_loader)))
    return (avgloss / len(training_data_loader))
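
The per-class label transform at the top of this example (and repeated in the examples below) turns an h*w class-ID map into a 20-channel one-hot tensor with an explicit Python loop. A minimal vectorized sketch of the same idea, assuming class IDs 1-20 with 255 as the ignore label as in the loop above (one_hot_labels is an illustrative name, not a function from the project):

import numpy as np

def one_hot_labels(label, num_classes=20, ignore_index=255):
    # label: (N, 1, H, W) array of class IDs; returns (N, num_classes, H, W)
    # float one-hot masks for classes 1..num_classes (background stays all-zero).
    label = label.copy()
    label[label == ignore_index] = 0
    class_ids = np.arange(1, num_classes + 1).reshape(1, num_classes, 1, 1)
    return (label == class_ids).astype(np.float32)

Under those assumptions, label = Variable(torch.from_numpy(one_hot_labels(batch[2].numpy()))) would replace the loop for the low-resolution labels, and the same call on batch[3] would cover hr_label.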
Code Example #2
def train(training_data_loader, optimizer, deeplab, model, mid, criterion, epoch, gpuid):
    avgloss = 0
    lr = adjust_learning_rate(optimizer, epoch - 1)

    for param_group in optimizer.param_groups:
        param_group["lr"] = lr

    print("epoch =", epoch, "lr =", optimizer.param_groups[0]["lr"])
    model.train()
    input_size = (80, 80)


    for iteration, batch in enumerate(training_data_loader, 1):

        input, target, label = Variable(batch[0]), \
                        Variable(batch[1], requires_grad=False), batch[2]


        input_ss_r = input[:, 0:1, ...] * 255.0 - IMG_MEAN[2]
        input_ss_g = input[:, 1:2, ...] * 255.0 - IMG_MEAN[1]
        input_ss_b = input[:, 2:3, ...] * 255.0 - IMG_MEAN[0]
        input_ss = torch.cat((input_ss_b, input_ss_g, input_ss_r), 1)

        input = input.cuda(gpuid)
        target = target.cuda(gpuid)
        input_ss = input_ss.cuda(gpuid)

        seg_out = deeplab(input_ss)
        size = (input_ss.size()[2], input_ss.size()[3])
        seg_out = mid(seg_out, size)
        label = seg_out.detach()  # use the segmentation output as labels; no gradient flows back into deeplab

        #=======label transform h*w to 21*h*w=======#
        label_argmax = False
        if label_argmax == True:
            Label_patch = label.cpu().data[0:1].numpy()
            Label_patch = np.expand_dims(np.argmax(Label_patch, axis=1), axis=0)
            label_pro = np.repeat(Label_patch, 21, axis=1)
            for i in range(21):
                tmp = label_pro[:, i:i + 1]
                if i == 0:
                    tmp[tmp == 255] = 0
                tmp[tmp != i] = -1
                tmp[tmp == i] = 1
                tmp[tmp == -1] = 0
            label = Variable(torch.from_numpy(label_pro[:, :, :, :]).float())
            transform_test = False
            if transform_test == True:
                Label_patch_test = Label_patch.copy()
                Label_patch_test [Label_patch_test == 255] = 0
                if (np.argmax(label_pro, axis=1).reshape((label.size())) - Label_patch_test).any() != 0:
                    print(">>>>>>Transform Error!")
        #=======label transform h*w to 21*h*w=======#
        label = label.cuda(gpuid)

        #=========image mask generation=========#
        for i in range(21):
            mask = label[:, i:i + 1, :, :].repeat(1, 3, 1, 1)   # broadcast the class-i mask to 3 channels
            mask_selected = torch.mul(mask, input)              # keep only class-i pixels of the input
            if i == 0:
                input_cls = mask_selected
            else:
                input_cls = torch.cat((input_cls, mask_selected), dim=1)
        input_cls = input_cls.cuda(gpuid)
        Blur_SR = model(input_cls)
        ##########show results###############
        is_show = False
        if is_show == True:
            label_show = label.cpu().data[0].numpy().transpose((1, 2, 0))
            label_show = np.asarray(np.argmax(label_show, axis=2), dtype=np.int)

            image_out = input_ss.cpu().data[0].numpy()
            image_out = image_out.transpose((1, 2, 0))
            image_out += IMG_MEAN
            image_out = image_out[:, :, ::-1]  # BRG2RGB
            image = np.asarray(image_out, np.uint8)
            #image = input.cpu().data[0].numpy().transpose((1, 2, 0))
            image_out = Blur_SR.cpu().data[0].numpy().transpose((1, 2, 0))

            label_heatmap = label.cpu().data[0].view(21, 1, input.data[0].size(1), input.data[0].size(2))
            label_heatmap = torchvision.utils.make_grid(label_heatmap)
            label_heatmap = label_heatmap.numpy().transpose((1, 2, 0))
            images_cls = input_cls.cpu().data[0].view(21, 3, input.data[0].size(1), input.data[0].size(2))
            images_cls = torchvision.utils.make_grid(images_cls)
            images_cls = images_cls.numpy().transpose((1, 2, 0))

            show_seg(image, label_show, image_out, label_heatmap, images_cls)
        #####################################


        loss = criterion(Blur_SR, target) / opt.batchSize
        avgloss += loss.data[0]

        if opt.vgg_loss:
            content_input = netContent(Blur_SR)
            content_target = netContent(target)
            content_target = content_target.detach()
            content_loss = criterion(content_input, content_target)

        optimizer.zero_grad()

        if opt.vgg_loss:
            netContent.zero_grad()
            content_loss.backward(retain_graph=True)  # retain_variables was renamed to retain_graph

        loss.backward()
        total_norm = 0
        if(loss.data[0] < 10000):
            total_norm = torch.nn.utils.clip_grad_norm(model.parameters(), opt.clip)
        optimizer.step()

        if iteration % 50 == 0:
            if opt.vgg_loss:
                print("===> Epoch[{}]({}/{}): Total_norm:{:.6f} Loss: {:.10f} Content_loss {:.10f}".format(
                    epoch, iteration, len(training_data_loader),
                    total_norm, loss.data[0], content_loss.data[0]))
            else:
                print("===> Epoch[{}]({}/{}): Total_norm:{:.6f} Loss: {:.10f}".format(
                    epoch, iteration, len(training_data_loader),
                    total_norm, loss.data[0]))
    print("===> Epoch {} Complete: Avg. SR Loss: {:.6f}".format(epoch, avgloss / len(training_data_loader)))
    return (avgloss / len(training_data_loader))
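
The image-mask-generation loop in Example #2 builds a 63-channel class-wise input by concatenating 21 masked copies of the 3-channel input. A broadcasting sketch of the same construction, assuming input is N x 3 x H x W and label is N x 21 x H x W as above (classwise_input is an illustrative name, not a function from the project):

import torch

def classwise_input(img, label):
    # img: (N, 3, H, W), label: (N, K, H, W) per-class masks.
    # Returns (N, K*3, H, W) with channels ordered class0 RGB, class1 RGB, ...
    n, c, h, w = img.size()
    k = label.size(1)
    masked = img.unsqueeze(1) * label.unsqueeze(2)   # (N, K, 3, H, W) via broadcasting
    return masked.contiguous().view(n, k * c, h, w)

The channel ordering matches the torch.cat loop above, so under those assumptions input_cls = classwise_input(input, label) would be a drop-in replacement.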
Code Example #3
def test(testloader, model, deeplab, mid, criterion, gpuid, SR_dir):
    avg_psnr = 0
    interp = torch.nn.Upsample(size=(505, 505), mode='nearest')

    data_list = []
    for iteration, batch in enumerate(testloader, 1):
        input_ss, input, target, label_gt, size, name, label_hr = Variable(batch[0], volatile=True), Variable(batch[1], volatile=True), \
                                                     Variable(batch[2], volatile=True), batch[3], batch[4], batch[5], batch[6]
        input_ss = input_ss.cuda(gpuid)
        seg = deeplab(input_ss)
        size = (input_ss.size()[2], input_ss.size()[3])
        label_deeplab = mid(seg, size)

        #======= label transform: h*w class map -> 20*h*w one-hot (classes 1-20) =======#
        enhanced_input = False
        use_deeplab = True
        if use_deeplab == True:
            label_argmax = False
            if label_argmax == True:
                Label_patch = label_deeplab.cpu().data[0:1].numpy()
                Label_patch = np.expand_dims(np.argmax(Label_patch, axis=1),
                                             axis=0)
                label_pro = np.repeat(Label_patch, 20, axis=1)
                for i in range(1, 21):
                    tmp = label_pro[:, i - 1:i]
                    if i == 0:
                        tmp[tmp == 255] = 0
                    tmp[tmp != i] = -1
                    tmp[tmp == i] = 1
                    tmp[tmp == -1] = 0
                label = Variable(
                    torch.from_numpy(label_pro[:, :, :, :]).float())
                transform_test = False
                if transform_test == True:
                    Label_patch_test = Label_patch.copy()
                    Label_patch_test[Label_patch_test == 255] = 0
                    if (np.argmax(label_pro, axis=1).reshape(
                        (label_deeplab.size())) - Label_patch_test).any() != 0:
                        print(">>>>>>Transform Error!")
            else:
                label = label_deeplab[:, 1:21]
        else:
            label = label_gt
            Label_patch = label.numpy()
            label_pro = np.repeat(Label_patch, 20, axis=1)
            for i in range(1, 21):
                tmp = label_pro[:, i - 1:i]
                if i == 0:
                    tmp[tmp == 255] = 0
                tmp[tmp != i] = -1
                tmp[tmp == i] = 1
                tmp[tmp == -1] = 0
            label = Variable(torch.from_numpy(label_pro[:, :, :, :]).float())
        #=======label transform h*w to 21*h*w=======#
        input = input.cuda(gpuid)
        target = target.cuda(gpuid)
        label = label.cuda(gpuid)

        #=========image mask generation=========#
        for i in range(20):
            mask = label[:, i:i + 1, :, :].repeat(1, 3, 1, 1)
            mask_selected = torch.mul(mask, input)
            if enhanced_input == True:
                mask_selected = torch.add(0.7 * mask_selected, 0.3 * input)
            if i == 0:
                input_cls = mask_selected
            else:
                input_cls = torch.cat((input_cls, mask_selected), dim=1)
        input_cls = input_cls.cuda(gpuid)
        Blur_SR = model(input_cls, input)
        #output = model_pretrained(input)

        im_h = Blur_SR.cpu().data[0].numpy().astype(np.float32)
        im_h[im_h < 0] = 0
        im_h[im_h > 1.] = 1.
        SR = Variable((torch.from_numpy(im_h)).unsqueeze(0)).cuda(gpuid)

        result = transforms.ToPILImage()(SR.cpu().data[0])
        path = join(SR_dir, '{0:04d}.jpg'.format(iteration))
        #result.save(path)

        ##########Per-class evaluation###############

        print("%s: %s.png" % (iteration, name[0]))
        classwise_evaluate(SR, target, label_hr, NUM_CLASSES, classes)

        ##########show results###############
        is_show = False
        if is_show == True:
            label_show = label.cpu().data[0].numpy().transpose((1, 2, 0))
            label_show = np.asarray(np.argmax(label_show, axis=2),
                                    dtype=np.int)

            #image_out = input.cpu().data[0].numpy()
            #image_out = image_out.transpose((1, 2, 0))
            #image_out += IMG_MEAN
            #image_out = image_out[:, :, ::-1]  # BRG2RGB
            #image_out = np.asarray(image_out, np.uint8)
            image = input.cpu().data[0].numpy().transpose((1, 2, 0))
            image_out = SR.cpu().data[0].numpy().transpose((1, 2, 0))

            label_heatmap = label.cpu().data[0].view(20, 1,
                                                     input.data[0].size(1),
                                                     input.data[0].size(2))
            label_heatmap = torchvision.utils.make_grid(label_heatmap)
            label_heatmap = label_heatmap.numpy().transpose((1, 2, 0))
            images_cls = input_cls.cpu().data[0].view(20, 3,
                                                      input.data[0].size(1),
                                                      input.data[0].size(2))
            images_cls = torchvision.utils.make_grid(images_cls)
            images_cls = images_cls.numpy().transpose((1, 2, 0))

            show_seg(image, label_show, image_out, label_heatmap, images_cls)
        #####################################
        #size = (target.size()[2], target.size()[3])
        #gt = np.asarray(label[0].numpy()[:size[0], :size[1]], dtype=np.int)
        #seg_out = torch.nn.Upsample(size, mode='bilinear')(seg)
        #seg_out = seg_out.cpu().data[0].numpy()
        #seg_out = seg_out.transpose(1, 2, 0)
        #seg_out = np.asarray(np.argmax(seg_out, axis=2), dtype=np.int)
        #data_list.append([gt.flatten(), seg_out.flatten()])
    #get_iou(data_list, NUM_CLASSES )
    for key in classes.keys():
        classes[key][0] /= classes[key][1]
    print("===> Avg. SR Per-Class PSNR: \n ")
    print(classes)

    for key in classes.keys():
        print('%s' % key)
        print('{:.4f} dB'.format(classes[key][0]))
        print('{:.1f} pixels'.format(classes[key][2]))
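
classwise_evaluate and the classes dictionary come from elsewhere in the project, so their exact behavior is not shown here; judging from the printout above (a dB value and a pixel count per class), they accumulate per-class PSNR statistics over the high-resolution label map. A self-contained numpy sketch of what a single per-class PSNR measurement looks like, purely for illustration (per_class_psnr is a hypothetical helper, not the project's implementation):

import numpy as np

def per_class_psnr(sr, hr, label, class_id, eps=1e-10):
    # sr, hr: (3, H, W) images in [0, 1]; label: (H, W) class-ID map.
    # Returns (PSNR over the class's pixels, number of such pixels).
    mask = (label == class_id)
    if not mask.any():
        return None, 0
    mse = np.mean((sr - hr)[:, mask] ** 2)
    return 10 * np.log10(1.0 / max(mse, eps)), int(mask.sum())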
Code Example #4
File: test_SSSR5.py  Project: summer1719/SSSRNet
def test(testloader, model, criterion, gpuid, SR_dir):
    avg_psnr = 0
    interp = torch.nn.Upsample(size=(505, 505), mode='nearest')

    data_list = []
    for iteration, batch in enumerate(testloader, 1):
        input_ss, input, target, label, size, name = Variable(batch[0], volatile=True), Variable(batch[1], volatile=True), \
                                                     Variable(batch[2], volatile=True), batch[3], batch[4], batch[5]

        #size = (1,1, input.size()[2], input.size()[3])
        #label = torch.nn.Upsample(size=(input.size()[2], input.size()[3]) , mode='nearest')(label)
        #======= label transform: h*w class map -> 20*h*w one-hot (classes 1-20) =======#
        #label_pro = np.expand_dims(Label_patch, axis=0)
        Label_patch = label.numpy()
        label_pro = np.repeat(Label_patch, 20, axis=1)
        for i in range(1, 21):
            tmp = label_pro[:, i - 1:i]
            if i == 0:
                tmp[tmp == 255] = 0
            tmp[tmp != i] = -1
            tmp[tmp == i] = 1
            tmp[tmp == -1] = 0
        #Label_patch_test = Label_patch.copy()
        #Label_patch_test [Label_patch_test == 255] = 0
        #if (np.argmax(label_pro, axis=1).reshape((label.size())) - Label_patch_test).any() != 0:
        #print(">>>>>>Transform Error!")
        #=======label transform h*w to 21*h*w=======#
        label = Variable(torch.from_numpy(label_pro[:, :, :, :]).float())

        input = input.cuda(gpuid)
        target = target.cuda(gpuid)
        label = label.cuda(gpuid)

        #=========image mask generation=========#
        for i in range(20):
            mask = label[:, i:i + 1, :, :].repeat(1, 3, 1, 1)
            mask_selected = torch.mul(mask, input)
            if i == 0:
                input_cls = mask_selected
            else:
                input_cls = torch.cat((input_cls, mask_selected), dim=1)
        input_cls = input_cls.cuda(gpuid)
        Blur_SR = model(input_cls, input)
        #output = model_pretrained(input)

        im_h = Blur_SR.cpu().data[0].numpy().astype(np.float32)
        im_h[im_h < 0] = 0
        im_h[im_h > 1.] = 1.
        SR = Variable((torch.from_numpy(im_h)).unsqueeze(0)).cuda(gpuid)

        result = transforms.ToPILImage()(SR.cpu().data[0])
        path = join(SR_dir, '{0:04d}.jpg'.format(iteration))
        #result.save(path)
        mse = criterion(SR, target)
        psnr = 10 * log10(1 / mse.data[0])
        avg_psnr += psnr
        print("%s: %s.png" % (iteration, name[0]))
        print('===>psnr: {:.4f} dB'.format(psnr))

        ##########show results###############
        is_show = False
        if is_show == True:
            label_show = label_pro[0].transpose((1, 2, 0))
            label_show = np.asarray(np.argmax(label_show, axis=2),
                                    dtype=np.int)
            image = input.cpu().data[0].numpy().transpose((1, 2, 0))
            image_out = SR.cpu().data[0].numpy().transpose((1, 2, 0))
            label_heatmap = label.cpu().data[0].view(20, 1,
                                                     input.data[0].size(1),
                                                     input.data[0].size(2))
            label_heatmap = torchvision.utils.make_grid(label_heatmap)
            label_heatmap = label_heatmap.numpy().transpose((1, 2, 0))
            images_cls = input_cls.cpu().data[0].view(20, 3,
                                                      input.data[0].size(1),
                                                      input.data[0].size(2))
            images_cls = torchvision.utils.make_grid(images_cls)
            images_cls = images_cls.numpy().transpose((1, 2, 0))

            show_seg(image, label_show, image_out, label_heatmap, images_cls)
        #####################################
        #size = (target.size()[2], target.size()[3])
        #gt = np.asarray(label[0].numpy()[:size[0], :size[1]], dtype=np.int)
        #seg_out = torch.nn.Upsample(size, mode='bilinear')(seg)
        #seg_out = seg_out.cpu().data[0].numpy()
        #seg_out = seg_out.transpose(1, 2, 0)
        #seg_out = np.asarray(np.argmax(seg_out, axis=2), dtype=np.int)
        #data_list.append([gt.flatten(), seg_out.flatten()])
    #get_iou(data_list, NUM_CLASSES )
    print("===> Avg. SR PSNR: {:.4f} dB".format(avg_psnr / iteration))