Example #1
def load_data():
    img_transform = transforms.Compose([
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])
    dataset_source = datasets.MNIST(
        root=IMG_DIR_SRC,
        train=True,
        transform=img_transform,
        download=True
    )
    dataloader_source = torch.utils.data.DataLoader(
        dataset=dataset_source,
        batch_size=BATCH_SIZE,
        shuffle=True,
        drop_last=True,
        num_workers=8)
    train_list = IMG_DIR_TAR + '/mnist_m_train_labels.txt'
    dataset_target = GetLoader(
        data_root=IMG_DIR_TAR + '/mnist_m_train',
        data_list=train_list,
        transform=img_transform
    )
    dataloader_target = torch.utils.data.DataLoader(
        dataset=dataset_target,
        batch_size=BATCH_SIZE,
        shuffle=True,
        drop_last=True,
        num_workers=8)
    return dataloader_source, dataloader_target
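A minimal usage sketch (not from the source project; it assumes the BATCH_SIZE constant and the transforms above): the two loaders are typically iterated in lockstep for domain-adversarial training, which is also why both set drop_last=True.

dataloader_source, dataloader_target = load_data()
for (s_img, s_label), (t_img, _) in zip(dataloader_source, dataloader_target):
    # Source batches are 1-channel MNIST, target batches are 3-channel MNIST-M;
    # both are scaled to [-1, 1] by the shared img_transform.
    assert s_img.shape[0] == t_img.shape[0] == BATCH_SIZE
    break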
Example #2
def load_test_data(dataset_name):
    img_transform = transforms.Compose([
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])
    if dataset_name == 'mnist_m':
        test_list = '../dataset/mnist_m/mnist_m_test_labels.txt'
        dataset = GetLoader(
            data_root='../dataset/mnist_m/mnist_m_test',
            data_list=test_list,
            transform=img_transform
        )
    else:
        dataset = datasets.MNIST(
            root=IMG_DIR_SRC,
            train=False,
            transform=img_transform,
            download=True
        )
    dataloader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=8
    )
    return dataloader
Example #3
def load_data():
    if DATASET_NAME == 'cifar':
        img_transform_cifar = transforms.Compose([
            transforms.Resize(32),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
        ])
        dataset = datasets.CIFAR10(root='CIFAR',
                                   train=True,
                                   transform=img_transform_cifar,
                                   target_transform=None,
                                   download=True)

    elif DATASET_NAME == 'gtsrb':
        img_transform_gtrsb = transforms.Compose([
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize((0.3403, 0.3121, 0.3214),
                                 (0.2724, 0.2608, 0.2669))
        ])
        dataset = gtsrb_dataset.GTSRB(root_dir='./',
                                      train=True,
                                      transform=img_transform_gtrsb)

    elif DATASET_NAME == 'mnist':
        img_transform_mnist = transforms.Compose([
            transforms.Resize(32),
            transforms.ToTensor(),
            transforms.Normalize(mean=0.5, std=0.5)
        ])
        dataset = datasets.MNIST(root='./',
                                 train=True,
                                 transform=img_transform_mnist,
                                 download=True)

    elif DATASET_NAME == 'mnistm':
        img_transform_mnist = transforms.Compose([
            transforms.Resize(32),
            transforms.ToTensor(),
            transforms.Normalize(mean=0.5, std=0.5)
        ])
        train_list = './mnist_m/mnist_m_train_labels.txt'
        dataset = GetLoader(data_root='./mnist_m/mnist_m_train',
                            data_list=train_list,
                            transform=img_transform_mnist)

    else:
        print('Data not found.')
        exit()
    return dataset
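A minimal usage sketch (assuming a BATCH_SIZE constant as in the other examples): unlike Example #1, this load_data() returns a bare Dataset, so the caller wraps it in a DataLoader itself.

dataset = load_data()
dataloader = torch.utils.data.DataLoader(dataset=dataset,
                                         batch_size=BATCH_SIZE,
                                         shuffle=True,
                                         drop_last=True,
                                         num_workers=8)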
Example #4
])

dataset_source = datasets.MNIST(root='dataset',
                                train=True,
                                transform=img_transform_source,
                                download=True)

dataloader_source = torch.utils.data.DataLoader(dataset=dataset_source,
                                                batch_size=batch_size,
                                                shuffle=True,
                                                num_workers=8)

train_list = os.path.join(target_image_root, 'mnist_m_train_labels.txt')

dataset_target = GetLoader(data_root=os.path.join(target_image_root,
                                                  'mnist_m_train'),
                           data_list=train_list,
                           transform=img_transform_target)

dataloader_target = torch.utils.data.DataLoader(dataset=dataset_target,
                                                batch_size=batch_size,
                                                shuffle=True,
                                                num_workers=8)

# load model

my_net = CNNModel()

# setup optimizer

optimizer = optim.Adam(my_net.parameters(), lr=lr)
Example #5
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307, ), std=(0.3081, ))
    ])

    img_transform_target = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])

    if dataset_name == 'mnist_m':
        test_list = os.path.join(image_root, 'mnist_m_test_labels.txt')

        dataset = GetLoader(data_root=os.path.join(image_root, 'mnist_m_test'),
                            data_list=test_list,
                            transform=img_transform_target)
    else:
        dataset = datasets.MNIST(
            root='dataset',
            train=False,
            transform=img_transform_source,
        )

    dataloader = torch.utils.data.DataLoader(dataset=dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             num_workers=8)
    """ test """

    my_net = torch.load(
Example #6
File: test.py Project: yhn280385395/DSN
def test(epoch, name):

    ###################
    # params          #
    ###################
    cuda = True
    cudnn.benchmark = True
    batch_size = 64
    image_size = 28

    ###################
    # load data       #
    ###################

    img_transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])

    model_root = 'model'
    if name == 'mnist':
        mode = 'source'
        image_root = os.path.join('dataset', 'mnist')
        dataset = datasets.MNIST(root=image_root,
                                 train=False,
                                 transform=img_transform)

        dataloader = torch.utils.data.DataLoader(dataset=dataset,
                                                 batch_size=batch_size,
                                                 shuffle=False,
                                                 num_workers=8)

    elif name == 'mnist_m':
        mode = 'target'
        image_root = os.path.join('dataset', 'mnist_m', 'mnist_m_test')
        test_list = os.path.join('dataset', 'mnist_m',
                                 'mnist_m_test_labels.txt')

        dataset = GetLoader(data_root=image_root,
                            data_list=test_list,
                            transform=img_transform)

        dataloader = torch.utils.data.DataLoader(dataset=dataset,
                                                 batch_size=batch_size,
                                                 shuffle=False,
                                                 num_workers=8)

    else:
        print('error dataset name')

    ####################
    # load model       #
    ####################

    my_net = DSN()
    checkpoint = torch.load(
        os.path.join(model_root,
                     'dsn_mnist_mnistm_epoch_' + str(epoch) + '.pth'))
    my_net.load_state_dict(checkpoint)
    my_net.eval()

    if cuda:
        my_net = my_net.cuda()

    ####################
    # transform image  #
    ####################

    def tr_image(img):
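        # Undo the Normalize((0.5, ...), (0.5, ...)) transform: map pixels from [-1, 1] back to [0, 1] for save_image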

        img_new = (img + 1) / 2

        return img_new

    len_dataloader = len(dataloader)
    data_iter = iter(dataloader)

    i = 0
    n_total = 0
    n_correct = 0

    while i < len_dataloader:

        data_input = next(data_iter)
        img, label = data_input

        batch_size = len(label)

        input_img = torch.FloatTensor(batch_size, 3, image_size, image_size)
        class_label = torch.LongTensor(batch_size)

        if cuda:
            img = img.cuda()
            label = label.cuda()
            input_img = input_img.cuda()
            class_label = class_label.cuda()

        input_img.resize_as_(img).copy_(img)
        class_label.resize_as_(label).copy_(label)
        inputv_img = Variable(input_img)
        classv_label = Variable(class_label)

        result = my_net(input_data=inputv_img,
                        mode='source',
                        rec_scheme='share')
        pred = result[3].data.max(1, keepdim=True)[1]

        result = my_net(input_data=inputv_img, mode=mode, rec_scheme='all')
        rec_img_all = tr_image(result[-1].data)

        result = my_net(input_data=inputv_img, mode=mode, rec_scheme='share')
        rec_img_share = tr_image(result[-1].data)

        result = my_net(input_data=inputv_img, mode=mode, rec_scheme='private')
        rec_img_private = tr_image(result[-1].data)

        if i == len_dataloader - 2:
            vutils.save_image(rec_img_all, name + '_rec_image_all.png', nrow=8)
            vutils.save_image(rec_img_share,
                              name + '_rec_image_share.png',
                              nrow=8)
            vutils.save_image(rec_img_private,
                              name + '_rec_image_private.png',
                              nrow=8)

        n_correct += pred.eq(classv_label.data.view_as(pred)).cpu().sum()
        n_total += batch_size

        i += 1

    accu = n_correct * 1.0 / n_total

    print('epoch: %d, accuracy of the %s dataset: %f' % (epoch, name, accu))
Example #7
def test(dataset_name):
    assert dataset_name in ['MNIST', 'mnist_m']

    model_root = 'models'
    image_root = os.path.join('dataset', dataset_name)

    cuda = True
    cudnn.benchmark = True
    batch_size = 128
    image_size = 28
    alpha = 0
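    # alpha is the coefficient of the gradient-reversal layer in the DANN model; it only affects
    # the backward pass, so its value is irrelevant at test time and 0 is used here.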
    """load data"""

    img_transform_source = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307, ), std=(0.3081, ))
    ])

    img_transform_target = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])

    if dataset_name == 'mnist_m':
        test_list = os.path.join(image_root, 'mnist_m_test_labels.txt')

        dataset = GetLoader(data_root=os.path.join(image_root, 'mnist_m_test'),
                            data_list=test_list,
                            transform=img_transform_target)
    else:
        dataset = datasets.MNIST(
            root='dataset',
            train=False,
            transform=img_transform_source,
        )

    dataloader = torch.utils.data.DataLoader(dataset=dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             num_workers=8)
    """ test """

    my_net = torch.load(
        os.path.join(model_root, 'mnist_mnistm_model_epoch_current.pth'))
    my_net = my_net.eval()

    if cuda:
        my_net = my_net.cuda()

    len_dataloader = len(dataloader)
    data_target_iter = iter(dataloader)

    i = 0
    n_total = 0
    n_correct = 0

    while i < len_dataloader:

        # test model using target data
        data_target = next(data_target_iter)
        t_img, t_label = data_target

        batch_size = len(t_label)

        if cuda:
            t_img = t_img.cuda()
            t_label = t_label.cuda()

        class_output, _ = my_net(input_data=t_img, alpha=alpha)
        pred = class_output.data.max(1, keepdim=True)[1]
        n_correct += pred.eq(t_label.data.view_as(pred)).cpu().sum()
        n_total += batch_size

        i += 1

    accu = n_correct.data.numpy() * 1.0 / n_total

    return accu
Example #8
def run(net_str):
    # execute only if run as the entry point into the program
    # Define the source domain and the current target domain
    net_str = os.path.join(
        'D:\study\graduation_project\grdaution_project\instru_identify\dataset18dataset2',
        net_str)
    source_image_root = os.path.join('D:\\', 'study', 'graduation_project',
                                     'grdaution_project', 'instru_identify',
                                     'dataset', 'dataset1')
    target_image_root = os.path.join('D:\\', 'study', 'graduation_project',
                                     'grdaution_project', 'instru_identify',
                                     'dataset', 'dataset2')

    target = 'dataset2'

    # Proportion of historical data to keep
    p = str(8)
    # Model save path
    model_root = 'dataset1' + p + 'dataset2'
    if not os.path.exists(model_root):
        os.makedirs(model_root)

    # Save the training log
    log_path = os.path.join(model_root, 'train.txt')
    sys.stdout = Logger(log_path)

    # print(time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time())))

    # Training hyperparameters
    cuda = False
    cudnn.benchmark = True
    lr = 1e-2
    batch_size = 16
    image_size = 28
    n_epoch = 1
    step_decay_weight = 0.95
    lr_decay_step = 20000
    active_domain_loss_step = 10000
    weight_decay = 1e-6
    alpha_weight = 0.01
    beta_weight = 0.075
    gamma_weight = 0.25
    momentum = 0.9

    manual_seed = random.randint(1, 10000)
    random.seed(manual_seed)
    torch.manual_seed(manual_seed)

    #######################
    # load data           #
    #######################

    img_transform_source = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307, ), std=(0.3081, ))
    ])

    img_transform_target = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])

    # Load source-domain data
    source_list = os.path.join(source_image_root, 'dataset1_train_labels.txt')
    dataset_source = GetLoader(
        data_root=os.path.join(source_image_root, 'dataset1_train'),
        data_list=source_list,
        transform=img_transform_target,
    )

    dataloader_source = torch.utils.data.DataLoader(
        dataset=dataset_source,
        batch_size=batch_size,
        shuffle=True,  # reshuffle the data every epoch
        num_workers=0  # number of worker processes (0 = load in the main process)
    )

    # Load target-domain data
    target_list = os.path.join(target_image_root, 'dataset2_train_labels.txt')
    dataset_target = GetLoader(
        data_root=os.path.join(target_image_root, 'dataset2_train'),
        data_list=target_list,
        transform=img_transform_target,
    )

    dataloader_target = torch.utils.data.DataLoader(
        dataset=dataset_target,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,  # 0 = single-process loading
    )

    #####################
    #  load model       #
    #####################

    my_net = DSN()
    my_net.load_state_dict(torch.load(net_str))

    #####################
    # setup optimizer   #
    #####################

    def exp_lr_scheduler(optimizer,
                         step,
                         init_lr=lr,
                         lr_decay_step=lr_decay_step,
                         step_decay_weight=step_decay_weight):

        # Decay learning rate by a factor of step_decay_weight every lr_decay_step
        current_lr = init_lr * (step_decay_weight**(step / lr_decay_step))
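        # Note: with Python 3 true division, step / lr_decay_step is a float, so the learning
        # rate decays smoothly rather than dropping only every lr_decay_step steps.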

        if step % lr_decay_step == 0:
            print('learning rate is set to %f' % current_lr)

        for param_group in optimizer.param_groups:
            param_group['lr'] = current_lr

        return optimizer

    optimizer = optim.SGD(my_net.parameters(),
                          lr=lr,
                          momentum=momentum,
                          weight_decay=weight_decay)

    # Loss functions
    loss_classification = torch.nn.CrossEntropyLoss()
    loss_recon1 = MSE()
    loss_recon2 = SIMSE()
    loss_diff = DiffLoss_tfTrans()
    loss_similarity = torch.nn.CrossEntropyLoss()

    if cuda:
        my_net = my_net.cuda()
        loss_classification = loss_classification.cuda()
        loss_recon1 = loss_recon1.cuda()
        loss_recon2 = loss_recon2.cuda()
        loss_diff = loss_diff.cuda()
        loss_similarity = loss_similarity.cuda()

    for p in my_net.parameters():
        p.requires_grad = True

    #############################
    # training network          #
    #############################

    # Length of the shorter of the two dataloaders
    len_dataloader = min(len(dataloader_source), len(dataloader_target))
    # Epoch at which the domain (DANN) loss becomes active
    dann_epoch = np.floor(active_domain_loss_step / len_dataloader * 1.0)

    current_step = 0
    # Start training
    accu_total1 = 0  # accumulated accuracy on dataset1
    accu_total2 = 0  # accumulated accuracy on dataset2
    time_total1 = 0  # accumulated evaluation time on dataset1
    time_total2 = 0  # accumulated evaluation time on dataset2
    for epoch in range(n_epoch):

        # 1. Create data iterators
        data_source_iter = iter(dataloader_source)
        data_target_iter = iter(dataloader_target)

        i = 0

        # Stop at the shorter dataloader; going further would raise an error once one iterator is exhausted
        while i < len_dataloader:

            ########################
            # target data training #
            ########################

            # Load a target batch
            data_target = next(data_target_iter)
            t_img, t_label = data_target

            # 1. Zero the gradients
            my_net.zero_grad()
            loss = 0
            batch_size = len(t_label)

            # 2. Allocate input buffers
            input_img = torch.FloatTensor(batch_size, 3, image_size,
                                          image_size)
            class_label = torch.LongTensor(batch_size)
            domain_label = torch.ones(batch_size)
            domain_label = domain_label.long()

            # If CUDA is enabled, move the data to the GPU
            if cuda:
                t_img = t_img.cuda()
                t_label = t_label.cuda()
                input_img = input_img.cuda()
                class_label = class_label.cuda()
                domain_label = domain_label.cuda()

            # Resize the buffers and copy the batch into them
            input_img.resize_as_(t_img).copy_(t_img)
            class_label.resize_as_(t_label).copy_(t_label)
            target_inputv_img = Variable(input_img)
            target_classv_label = Variable(class_label)
            target_domainv_label = Variable(domain_label)

            # Schedule for the adaptation factor p from the paper
            if current_step > active_domain_loss_step:
                p = float(i + (epoch - dann_epoch) * len_dataloader /
                          (n_epoch - dann_epoch) / len_dataloader)
                p = 2. / (1. + np.exp(-10 * p)) - 1

                # active domain loss
                # Feed the target batch through the model and get its outputs
                result = my_net(input_data=target_inputv_img,
                                mode='target',
                                rec_scheme='all',
                                p=p)
                target_private_coda, target_share_coda, target_domain_label, target_rec_code = result  # unpack the model outputs
                target_dann = gamma_weight * loss_similarity(
                    target_domain_label, target_domainv_label)  # 4. compute the domain loss
                loss += target_dann  # accumulate the loss
            else:
                if cuda:
                    target_dann = Variable(torch.zeros(1).float().cuda())  # ?
                else:
                    target_dann = Variable(torch.zeros(1).float())
                # Feed the target batch through the model and get its outputs
                result = my_net(input_data=target_inputv_img,
                                mode='target',
                                rec_scheme='all')
                target_private_coda, target_share_coda, _, target_rec_code = result  # unpack the model outputs

                # Compute the difference and reconstruction losses
                target_diff = beta_weight * loss_diff(
                    target_private_coda, target_share_coda, weight=0.05)
                loss += target_diff
                target_mse = alpha_weight * loss_recon1(
                    target_rec_code, target_inputv_img)
                loss += target_mse
                target_simse = alpha_weight * loss_recon2(
                    target_rec_code, target_inputv_img)
                loss += target_simse

                # 5. Backpropagate
                loss.backward()
                # 6. Update the network weights and biases with the optimizer
                # optimizer = exp_lr_scheduler(optimizer=optimizer,step = current_step)
                optimizer.step()

                #######################
                # source data training#
                #######################

                data_source = next(data_source_iter)
                s_img, s_label = data_source

                my_net.zero_grad()
                batch_size = len(s_label)

                input_img = torch.FloatTensor(batch_size, 3, image_size,
                                              image_size)
                class_label = torch.LongTensor(batch_size)
                domain_label = torch.zeros(batch_size)
                domain_label = domain_label.long()

                loss = 0

                if cuda:
                    s_img = s_img.cuda()
                    s_label = s_label.cuda()
                    input_img = input_img.cuda()
                    class_label = class_label.cuda()
                    domain_label = domain_label.cuda()

                input_img.resize_as_(s_img).copy_(s_img)
                class_label.resize_as_(s_label).copy_(s_label)
                source_inputv_img = Variable(input_img)
                source_classv_label = Variable(class_label)
                source_domainv_label = Variable(domain_label)

                if current_step > active_domain_loss_step:

                    # active domain loss

                    # Feed the source batch through the model
                    result = my_net(input_data=source_inputv_img,
                                    mode='source',
                                    rec_scheme='all',
                                    p=p)
                    source_private_code, source_share_code, source_domain_label, source_class_label, source_rec_code = result
                    source_dann = gamma_weight * loss_similarity(
                        source_domain_label, source_domainv_label)
                    loss += source_dann
                else:
                    if cuda:
                        source_dann = Variable(torch.zeros(1).float().cuda())
                    else:
                        source_dann = Variable(torch.zeros(1).float())
                    result = my_net(input_data=source_inputv_img,
                                    mode='source',
                                    rec_scheme='all')
                    source_private_code, source_share_code, _, source_class_label, source_rec_code = result

                    source_classification = loss_classification(
                        source_class_label, source_classv_label)
                    loss += source_classification

                    source_diff = beta_weight * loss_diff(
                        source_private_code, source_share_code, weight=0.05)
                    loss += source_diff
                    source_mse = alpha_weight * loss_recon1(
                        source_rec_code, source_inputv_img)
                    loss += source_mse
                    source_simse = gamma_weight * loss_recon2(
                        source_rec_code, source_inputv_img)
                    loss += source_simse

                    loss.backward()
                    # optimizer = exp_lr_scheduler(optimizer=optimizer,step=current_step)
                    optimizer.step()

                    ###############
                    # test & save #
                    ###############
                    i += 1
                    current_step += 1
                    # print('source_classification: %f, source_dann: %f, source_diff: %f, '\
                    # 'source_mse: %f, source_simse: %f, target_dann: %f, target_diff: %f, '\
                    # 'target_mse: %f, target_simse: %f' \
                    # % (source_classification.data.cpu().numpy(), source_dann.data.cpu().numpy(),
                    #   source_diff.data.cpu().numpy(),
                    #   source_mse.data.cpu().numpy(), source_simse.data.cpu().numpy(), target_dann.data.cpu().numpy(),
                    #   target_diff.data.cpu().numpy(), target_mse.data.cpu().numpy(), target_simse.data.cpu().numpy()))
                    # Evaluate on dataset1 and accumulate its time and accuracy
                    start1 = time.time()
                    accu1 = test(epoch=epoch, name='dataset1')
                    end1 = time.time()
                    curr1 = end1 - start1
                    time_total1 += curr1
                    accu_total1 += accu1
                    # Evaluate on dataset2 and accumulate its time and accuracy
                    start2 = time.time()
                    accu2 = test(epoch=epoch, name='dataset2')
                    end2 = time.time()
                    curr2 = end2 - start2
                    time_total2 += curr2
                    accu_total2 += accu2
                    # print(time.strftime('%Y-%m-%d %H:%M:%S'), time.localtime(time.time()))

                # Use the average accuracy as the evaluation metric for training
    model_index = epoch
    # Build the model save path
    model_path = 'D:\study\graduation_project\grdaution_project\instru_identify\dataset18dataset2' + '\dsn_epoch_' + str(
        model_index) + '.pth'
    while os.path.exists(model_path):
        model_index = model_index + 1
        model_path = 'D:\study\graduation_project\grdaution_project\instru_identify\dataset18dataset2' + '\dsn_epoch_' + str(
            model_index) + '.pth'
    torch.save(my_net.state_dict(), model_path)  # save the model
    average_accu1 = accu_total1 / (len_dataloader * n_epoch)
    average_accu2 = accu_total2 / (len_dataloader * n_epoch)
    result = [float(average_accu1), float(average_accu2)]
    # All values are stored rounded to three decimal places
    print(round(float(average_accu1), 3))
    print(round(float(average_accu2), 3))
    print(round(float(time_total1), 3))
    print(round(float(time_total2), 3))
    # print('result:',result)
    return result
Example #9
def test(epoch, name):
    cuda = False
    cudnn.benchmark = True
    batch_size = 16
    image_size = 28
    p = str(8)
    model_root = 'dataset1' + p + 'dataset2'

    ################
    #   load data  #
    ################
    # Image transforms
    source_img_transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),  # convert to a tensor and scale pixels to [0, 1]
        transforms.Lambda(lambda x: x.repeat(3, 1, 1)),  # replicate a single channel to 3 channels
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])

    img_transform_source = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307, ), std=(0.3081, ))  #?
    ])

    img_transform_target = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])

    if name == 'dataset1':
        mode = 'source'
        image_root = r'D:\study\graduation_project\grdaution_project\instru_identify\dataset\dataset1\dataset1_test'
        # image_root.replace("\\",'/')
        test_list = r'D:\study\graduation_project\grdaution_project\instru_identify\dataset\dataset1\dataset1_test_labels.txt'

        dataset = GetLoader(
            data_root=image_root,
            data_list=test_list,
            transform=img_transform_source,
        )
        dataloader = torch.utils.data.DataLoader(dataset=dataset,
                                                 batch_size=batch_size,
                                                 shuffle=False,
                                                 num_workers=0)
        # print('success')
    elif name == 'dataset2':
        mode = 'target'
        image_root = os.path.join('D:\\', 'study', 'graduation_project',
                                  'grdaution_project', 'instru_identify',
                                  'dataset', 'dataset2', 'dataset2_test')
        test_list = os.path.join('D:\\', 'study', 'graduation_project',
                                 'grdaution_project', 'instru_identify',
                                 'dataset', 'dataset2',
                                 'dataset2_test_labels.txt')
        dataset = GetLoader(
            data_root=image_root,
            data_list=test_list,
            transform=img_transform_target,
        )

        dataloader = torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=0  #?
        )
    else:
        print('error dataset name')

    ###############
    # load model  #
    ###############
    # print('image_root:', image_root)
    # print('test_list:',test_list)
    my_net = DSN()
    checkpoint = torch.load(
        os.path.join(model_root, 'dsn_epoch_' + str(epoch) + '.pth'))
    my_net.load_state_dict(checkpoint)
    my_net.eval()  #?

    if cuda:
        my_net = my_net  #.cuda()

    ###################
    # transform image #
    ###################

    # Map images from [-1, 1] (after Normalize) back to [0, 1] for saving
    def tr_image(img):
        img_new = (img + 1) / 2
        return img_new

    # print(dataloader)
    len_dataloader = len(dataloader)
    # print('len_dataloader:',len_dataloader)
    data_iter = iter(dataloader)  # get an iterator over the dataloader
    # print('data_iter:',data_iter)

    i = 0
    n_total = 0
    n_correct = 0

    total_accu = 0
    while i < len_dataloader - 1:
        #print(i)
        data_input = next(data_iter)
        #print('data_input:', data_input)
        img, label = data_input
        # print('label:', label)
        batch_size = len(label)  # number of images in this batch

        input_img = torch.FloatTensor(batch_size, 3, image_size, image_size)
        class_label = torch.LongTensor(batch_size)

        if cuda:
            img = img  #.cuda()
            label = label  #.cuda()
            input_img = input_img  #.cuda()
            class_label = class_label  #.cuda()

        input_img.resize_as_(img).copy_(img)
        class_label.resize_as_(label).copy_(label)
        inputv_img = Variable(input_img)  #?
        classv_label = Variable(class_label)

        # Forward pass through the network

        result = my_net(input_data=inputv_img,
                        mode='source',
                        rec_scheme='share')
        pred = result[3].data.max(1, keepdim=True)[1]
        # print('pred:',pred)

        result = my_net(input_data=inputv_img, mode=mode, rec_scheme='all')
        rec_img_all = tr_image(result[-1].data)

        result = my_net(input_data=inputv_img, mode=mode, rec_scheme='share')
        rec_img_share = tr_image(result[-1].data)

        result = my_net(input_data=inputv_img, mode=mode, rec_scheme='private')
        rec_img_private = tr_image(result[-1].data)

        if i == len_dataloader - 2:
            image_save_path = os.path.join(model_root, 'images')
            if not os.path.exists(image_save_path):
                os.mkdir(image_save_path)
            vutils.save_image(rec_img_all,
                              image_save_path + '/' + name +
                              '_rec_image_all.png',
                              nrow=8)
            vutils.save_image(rec_img_share,
                              image_save_path + '/' + name +
                              'rec_image_share.png',
                              nrow=8)
            vutils.save_image(rec_img_private,
                              image_save_path + '/' + name +
                              'rec_image_private.png',
                              nrow=8)

        n_correct += pred.eq(classv_label.data.view_as(pred)).cpu().sum()
        n_total += batch_size

        i += 1
    accu = n_correct * 1.0 / n_total

    # print('n_correct:', n_correct)
    # print('n_total:', n_total)
    # print('epoch: %d,accuracy of the %s dataset: %f' % (epoch, name, accu))
    return accu
Example #10
    transforms.Resize((28,28)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])

# dataset_source = datasets.MNIST(
#     root='dataset',
#     train=True,
#     transform=img_transform_source,
#     download=True
# )

source_list = os.path.join(source_image_root, 'image_label.txt')
dataset_source = GetLoader(
    data_root='/root/Data/source/',
    data_list=source_list,
    transform=img_transform_source
)

dataloader_source = torch.utils.data.DataLoader(
    dataset=dataset_source,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4)

train_list = os.path.join(target_image_root, 'image_label.txt')
dataset_target = GetLoader(
    data_root='/root/Data/target/',
    data_list=train_list,
    transform=img_transform_target,
)
def run(args):
    args.logdir = args.logdir + args.mode
    args.trained = args.trained + args.mode + '/best_model.pt'
    if not os.path.exists(args.logdir):
        os.makedirs(args.logdir)
    logger = get_logger(os.path.join(args.logdir, 'main.log'))
    logger.info(args)

    # data
    # source_transform = transforms.Compose([
    #     # transforms.Grayscale(),
    #     transforms.ToTensor()]
    # )
    # target_transform = transforms.Compose([
    #     transforms.Resize(32),
    #     transforms.ToTensor(),
    #     transforms.Lambda(lambda x: x.repeat(3, 1, 1))
    # ])
    # source_dataset_train = SVHN(
    #     './input', 'train', transform=source_transform, download=True)
    # target_dataset_train = MNIST(
    #     './input', train=True, transform=target_transform, download=True)
    # target_dataset_test = MNIST(
    #     './input', train=False, transform=target_transform, download=True)
    # source_train_loader = DataLoader(
    #     source_dataset_train, args.batch_size, shuffle=True,
    #     drop_last=True,
    #     num_workers=args.n_workers)
    # target_train_loader = DataLoader(
    #     target_dataset_train, args.batch_size, shuffle=True,
    #     drop_last=True,
    #     num_workers=args.n_workers)
    # target_test_loader = DataLoader(
    #     target_dataset_test, args.batch_size, shuffle=False,
    #     num_workers=args.n_workers)
    batch_size = 128
    if args.mode == 'm2mm':
        source_dataset_name = 'MNIST'
        target_dataset_name = 'mnist_m'
        source_image_root = os.path.join('dataset', source_dataset_name)
        target_image_root = os.path.join('dataset', target_dataset_name)
        image_size = 28
        img_transform_source = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.1307, ), std=(0.3081, ))
        ])

        img_transform_target = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
        ])

        dataset_source = datasets.MNIST(root='dataset',
                                        train=True,
                                        transform=img_transform_source,
                                        download=True)

        train_list = os.path.join(target_image_root,
                                  'mnist_m_train_labels.txt')

        dataset_target_train = GetLoader(data_root=os.path.join(
            target_image_root, 'mnist_m_train'),
                                         data_list=train_list,
                                         transform=img_transform_target)

        test_list = os.path.join(target_image_root, 'mnist_m_test_labels.txt')

        dataset_target_test = GetLoader(data_root=os.path.join(
            target_image_root, 'mnist_m_test'),
                                        data_list=test_list,
                                        transform=img_transform_target)
    elif args.mode == 's2u':
        dataset_source = svhn.SVHN('./data/svhn/',
                                   split='train',
                                   download=True,
                                   transform=transforms.Compose([
                                       transforms.Resize(28),
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.5, 0.5, 0.5),
                                                            (0.5, 0.5, 0.5))
                                   ]))

        dataset_target_train = usps.USPS('./data/usps/',
                                         train=True,
                                         download=True,
                                         transform=transforms.Compose([
                                             transforms.ToTensor(),
                                             transforms.Normalize((0.5, ),
                                                                  (0.5, ))
                                         ]))
        dataset_target_test = usps.USPS('./data/usps/',
                                        train=False,
                                        download=True,
                                        transform=transforms.Compose([
                                            transforms.ToTensor(),
                                            transforms.Normalize((0.5, ),
                                                                 (0.5, ))
                                        ]))
        source_dataset_name = 'svhn'
        target_dataset_name = 'usps'

    source_train_loader = torch.utils.data.DataLoader(dataset=dataset_source,
                                                      batch_size=batch_size,
                                                      shuffle=True,
                                                      num_workers=8)

    target_train_loader = torch.utils.data.DataLoader(
        dataset=dataset_target_train,
        batch_size=batch_size,
        shuffle=True,
        num_workers=8)

    target_test_loader = torch.utils.data.DataLoader(
        dataset=dataset_target_test,
        batch_size=batch_size,
        shuffle=False,
        num_workers=8)
    # train source CNN
    source_cnn = CNN(in_channels=args.in_channels).to(args.device)
    if os.path.isfile(args.trained):
        print("load model")
        c = torch.load(args.trained)
        source_cnn.load_state_dict(c['model'])
        logger.info('Loaded `{}`'.format(args.trained))
    else:
        print("not load model")

    # train target CNN
    target_cnn = CNN(in_channels=args.in_channels, target=True).to(args.device)
    target_cnn.load_state_dict(source_cnn.state_dict())
    discriminator = Discriminator(args=args).to(args.device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(target_cnn.encoder.parameters(), lr=args.lr)
    # optimizer = optim.Adam(
    #     target_cnn.encoder.parameters(),
    #     lr=args.lr, betas=args.betas, weight_decay=args.weight_decay)
    d_optimizer = optim.Adam(discriminator.parameters(), lr=args.lr)
    # d_optimizer = optim.Adam(
    #     discriminator.parameters(),
    #     lr=args.lr, betas=args.betas, weight_decay=args.weight_decay)
    train_target_cnn(source_cnn,
                     target_cnn,
                     discriminator,
                     criterion,
                     optimizer,
                     d_optimizer,
                     source_train_loader,
                     target_train_loader,
                     target_test_loader,
                     args=args)
Example #12
def test(dataset_name, epoch):
    assert dataset_name in ['source', 'target']

    model_root = 'models'
    image_root = os.path.join('/root/Data', dataset_name)

    cuda = True
    cudnn.benchmark = True
    batch_size = 128
    image_size = 28
    alpha = 0
    """load data"""

    img_transform_source = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307, ), std=(0.3081, ))
    ])

    img_transform_target = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])

    # if dataset_name == 'mnist_m':
    #     test_list = os.path.join(image_root, 'mnist_m_test_labels.txt')
    #
    #     dataset = GetLoader(
    #         data_root=os.path.join(image_root, 'mnist_m_test'),
    #         data_list=test_list,
    #         transform=img_transform_target
    #     )
    # else:
    #     dataset = datasets.MNIST(
    #         root='dataset',
    #         train=False,
    #         transform=img_transform_source,
    #     )

    target_list = os.path.join(image_root, 'image_label.txt')

    dataset = GetLoader(data_root=image_root,
                        data_list=target_list,
                        transform=img_transform_target)

    dataloader = torch.utils.data.DataLoader(dataset=dataset,
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=2)
    """ training """

    my_net = torch.load(
        os.path.join(model_root,
                     'mnist_mnistm_model_epoch_' + str(epoch) + '.pth'))
    my_net = my_net.eval()

    if cuda:
        my_net = my_net.cuda()

    len_dataloader = len(dataloader)
    data_target_iter = iter(dataloader)

    i = 0
    n_total = 0
    n_correct = 0

    num_class = 15

    acc_class = [0 for _ in range(num_class)]
    count_class = [0 for _ in range(num_class)]

    tsne_results = np.array([])
    tsne_labels = np.array([])

    while i < len_dataloader:

        # test model using target data
        data_target = next(data_target_iter)
        t_img, t_label = data_target

        batch_size = len(t_label)

        input_img = torch.FloatTensor(batch_size, 3, image_size, image_size)
        class_label = torch.LongTensor(batch_size)

        if cuda:
            t_img = t_img.cuda()
            t_label = t_label.cuda()
            input_img = input_img.cuda()
            class_label = class_label.cuda()

        input_img.resize_as_(t_img).copy_(t_img)
        class_label.resize_as_(t_label).copy_(t_label)

        class_output, _ = my_net(input_data=input_img, alpha=alpha)
        pred = class_output.data.max(1, keepdim=True)[1]
        pred1 = class_output.data.max(1)[1]
        n_correct += pred.eq(class_label.data.view_as(pred)).cpu().sum()
        n_total += batch_size

        i += 1
        index_temp = pred1.eq(t_label.data)

        for acc_index in range(batch_size):
            temp_label_index = t_label.data[acc_index]
            count_class[temp_label_index] += 1
            if index_temp[acc_index]:
                acc_class[temp_label_index] += 1

        if len(tsne_labels) == 0:
            tsne_results = class_output.cpu().data.numpy()
            tsne_labels = t_label.cpu().numpy()
        else:
            tsne_results = np.concatenate(
                (tsne_results, class_output.cpu().data.numpy()))
            tsne_labels = np.concatenate((tsne_labels, t_label.cpu().numpy()))

    plot_only = 1000
    tsne_model = TSNE(perplexity=30, n_components=2, init='pca', n_iter=5000)
    tsne_transformed = tsne_model.fit_transform(tsne_results[:plot_only, :])
    tsne_labels = tsne_labels[:plot_only]

    # colors = cm.rainbow(np.linspace(0, 1, num_class))
    for x, y, s in zip(tsne_transformed[:, 0], tsne_transformed[:, 1],
                       tsne_labels):
        c = cm.rainbow(int(255 * s / num_class))
        plt.scatter(x, y, c=c)
    plt.xticks([])
    plt.yticks([])
    plt.savefig('output1.png')

    for print_index in range(len(acc_class)):
        print('Class:{}, Accuracy:{:.2f}%'.format(
            print_index,
            100. * acc_class[print_index] / count_class[print_index]))

    accu = n_correct.data.numpy() * 1.0 / n_total

    print('epoch: %d, accuracy of the %s dataset: %f' %
          (epoch, dataset_name, accu))
    torch.save(
        accu, '/root/Data/dann_result/dann_ep_' + str(epoch) + '_' +
        dataset_name + '_' + str(accu) + '.pt')