Example #1
0
def test(epoch, name):
    """Evaluate a saved DSN checkpoint on a test split and save reconstructions.

    Loads ``model/dsn_mnist_mnistm_epoch_<epoch>.pth`` into a ``DSN`` network,
    runs it over every batch of the test split selected by ``name`` and prints
    the classification accuracy.  Reconstructions produced with the 'all',
    'share' and 'private' schemes for the second-to-last batch are written to
    ``<name>_rec_image_{all,share,private}.png`` in the working directory.

    Args:
        epoch: epoch number used to select the checkpoint file.
        name: ``'mnist'`` (network run with ``mode='source'``) or
            ``'mnist_m'`` (network run with ``mode='target'``).

    Raises:
        ValueError: if ``name`` is not one of the supported datasets.
    """

    ###################
    # params          #
    ###################
    cuda = True
    cudnn.benchmark = True
    batch_size = 64
    image_size = 28
    model_root = 'model'

    ###################
    # load data       #
    ###################

    # ToTensor yields values in [0, 1]; Normalize(0.5, 0.5) maps them to [-1, 1].
    img_transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])

    if name == 'mnist':
        mode = 'source'
        image_root = os.path.join('dataset', 'mnist')
        dataset = datasets.MNIST(root=image_root,
                                 train=False,
                                 transform=img_transform)
    elif name == 'mnist_m':
        mode = 'target'
        image_root = os.path.join('dataset', 'mnist_m', 'mnist_m_test')
        test_list = os.path.join('dataset', 'mnist_m',
                                 'mnist_m_test_labels.txt')
        dataset = GetLoader(data_root=image_root,
                            data_list=test_list,
                            transform=img_transform)
    else:
        # Fail fast: previously this only printed a message and then crashed
        # later (after loading the checkpoint) with a NameError on `dataloader`.
        raise ValueError('error dataset name: %r' % (name,))

    dataloader = torch.utils.data.DataLoader(dataset=dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             num_workers=8)

    ####################
    # load model       #
    ####################

    my_net = DSN()
    checkpoint = torch.load(
        os.path.join(model_root,
                     'dsn_mnist_mnistm_epoch_' + str(epoch) + '.pth'))
    my_net.load_state_dict(checkpoint)
    my_net.eval()

    if cuda:
        my_net = my_net.cuda()

    ####################
    # transform image  #
    ####################

    def tr_image(img):
        # Invert the [-1, 1] normalization above back to [0, 1] for saving.
        # Assumes reconstructions are in the same normalized range as the
        # input -- TODO confirm against DSN's decoder.
        return (img + 1) / 2

    len_dataloader = len(dataloader)
    n_total = 0
    n_correct = 0

    for i, (img, label) in enumerate(dataloader):
        # The last batch may be smaller than the configured batch size.
        batch_size = len(label)

        input_img = torch.FloatTensor(batch_size, 3, image_size, image_size)
        class_label = torch.LongTensor(batch_size)

        if cuda:
            img = img.cuda()
            label = label.cuda()
            input_img = input_img.cuda()
            class_label = class_label.cuda()

        # Keep the 3-channel buffer shape on purpose: copy_ broadcasts a
        # single-channel batch (presumably MNIST) across the 3 channels.
        input_img.copy_(img)
        class_label.resize_as_(label).copy_(label)
        inputv_img = Variable(input_img)
        classv_label = Variable(class_label)

        # Class predictions always use mode='source' with the shared scheme;
        # result[3] is treated as per-class scores (argmax over dim 1).
        result = my_net(input_data=inputv_img,
                        mode='source',
                        rec_scheme='share')
        pred = result[3].data.max(1, keepdim=True)[1]

        # Reconstructions are only needed for the one batch that is saved,
        # so skip those three extra forward passes on every other batch.
        if i == len_dataloader - 2:
            for scheme in ('all', 'share', 'private'):
                result = my_net(input_data=inputv_img,
                                mode=mode,
                                rec_scheme=scheme)
                vutils.save_image(tr_image(result[-1].data),
                                  name + '_rec_image_' + scheme + '.png',
                                  nrow=8)

        # int() works whether .sum() yields a Python number or a 0-dim tensor,
        # and avoids integer tensor division in the accuracy below.
        n_correct += int(pred.eq(classv_label.data.view_as(pred)).cpu().sum())
        n_total += batch_size

    accu = float(n_correct) / n_total

    print('epoch: %d, accuracy of the %s dataset: %f' % (epoch, name, accu))
Example #2
0
def test(epoch, name, dataset_root=None):
    """Evaluate a saved DSN checkpoint on a local test split and return accuracy.

    Loads ``dataset18dataset2/dsn_epoch_<epoch>.pth`` into a ``DSN`` network,
    runs it over every batch of the test split selected by ``name`` and
    returns the fraction of correctly classified samples.  Reconstructions
    produced with the 'all', 'share' and 'private' schemes for the
    second-to-last batch are saved as
    ``<model_root>/images/<name>_rec_image_{all,share,private}.png``.

    Args:
        epoch: epoch number used to select the checkpoint file.
        name: ``'dataset1'`` (network run with ``mode='source'``) or
            ``'dataset2'`` (network run with ``mode='target'``).
        dataset_root: directory containing the ``dataset1`` and ``dataset2``
            folders. Defaults to the original hard-coded ``D:`` drive location.

    Returns:
        Accuracy as a Python float in [0, 1].

    Raises:
        ValueError: if ``name`` is not ``'dataset1'`` or ``'dataset2'``.
    """
    cuda = False
    cudnn.benchmark = True
    batch_size = 16
    image_size = 28
    # Checkpoint directory: evaluates to 'dataset18dataset2'.
    model_root = 'dataset1' + str(8) + 'dataset2'
    if dataset_root is None:
        dataset_root = os.path.join('D:\\', 'study', 'graduation_project',
                                    'grdaution_project', 'instru_identify',
                                    'dataset')

    ################
    #   load data  #
    ################
    # Source-domain transform: single-value mean/std normalization.
    img_transform_source = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307, ), std=(0.3081, ))
    ])

    # Target-domain transform: ToTensor gives [0, 1], Normalize maps to [-1, 1].
    img_transform_target = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])

    if name == 'dataset1':
        mode = 'source'
        transform = img_transform_source
    elif name == 'dataset2':
        mode = 'target'
        transform = img_transform_target
    else:
        # Fail fast: previously this only printed a message and then crashed
        # later (after loading the checkpoint) with a NameError on `dataloader`.
        raise ValueError('error dataset name: %r' % (name,))

    # Layout: <dataset_root>/<name>/<name>_test and <name>_test_labels.txt
    image_root = os.path.join(dataset_root, name, name + '_test')
    test_list = os.path.join(dataset_root, name, name + '_test_labels.txt')

    dataset = GetLoader(
        data_root=image_root,
        data_list=test_list,
        transform=transform,
    )
    dataloader = torch.utils.data.DataLoader(dataset=dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             num_workers=0)

    ###############
    # load model  #
    ###############
    my_net = DSN()
    # Without CUDA, remap any GPU-saved tensors onto the CPU.
    checkpoint = torch.load(
        os.path.join(model_root, 'dsn_epoch_' + str(epoch) + '.pth'),
        map_location=None if cuda else 'cpu')
    my_net.load_state_dict(checkpoint)
    my_net.eval()

    if cuda:
        my_net = my_net.cuda()

    ###################
    # transform image #
    ###################

    def tr_image(img):
        # Map values from [-1, 1] to [0, 1] for saving as an image.
        # Assumes reconstructions are in that range -- TODO confirm.
        return (img + 1) / 2

    image_save_path = os.path.join(model_root, 'images')
    len_dataloader = len(dataloader)
    n_total = 0
    n_correct = 0

    # Visit every batch. The previous `while i < len_dataloader - 1` left the
    # final batch out of the accuracy, and divided by zero for a single batch.
    for i, (img, label) in enumerate(dataloader):
        # The last batch may be smaller than the configured batch size.
        batch_size = len(label)

        input_img = torch.FloatTensor(batch_size, 3, image_size, image_size)
        class_label = torch.LongTensor(batch_size)

        if cuda:
            img = img.cuda()
            label = label.cuda()
            input_img = input_img.cuda()
            class_label = class_label.cuda()

        # Keep the 3-channel buffer shape on purpose: copy_ broadcasts a
        # single-channel batch (presumably dataset1) across the 3 channels.
        input_img.copy_(img)
        class_label.copy_(label)
        inputv_img = Variable(input_img)
        classv_label = Variable(class_label)

        # Class predictions always use mode='source' with the shared scheme;
        # result[3] is treated as per-class scores (argmax over dim 1).
        result = my_net(input_data=inputv_img,
                        mode='source',
                        rec_scheme='share')
        pred = result[3].data.max(1, keepdim=True)[1]

        # Reconstructions are only needed for the one batch that is saved,
        # so skip those three extra forward passes on every other batch.
        if i == len_dataloader - 2:
            os.makedirs(image_save_path, exist_ok=True)
            for scheme in ('all', 'share', 'private'):
                result = my_net(input_data=inputv_img,
                                mode=mode,
                                rec_scheme=scheme)
                # Consistent '<name>_rec_image_<scheme>.png' naming (the share
                # and private files were previously missing the underscore).
                vutils.save_image(tr_image(result[-1].data),
                                  os.path.join(image_save_path,
                                               name + '_rec_image_' + scheme +
                                               '.png'),
                                  nrow=8)

        # int() avoids integer-tensor division in the accuracy on older
        # PyTorch versions, where `.sum()` returns a LongTensor.
        n_correct += int(pred.eq(classv_label.data.view_as(pred)).cpu().sum())
        n_total += batch_size

    return float(n_correct) / n_total