コード例 #1
0
ファイル: test.py プロジェクト: yhn280385395/DSN
def test(epoch, name):
    """Evaluate a saved DSN checkpoint on the MNIST or MNIST-M test split.

    Loads ``model/dsn_mnist_mnistm_epoch_<epoch>.pth`` and counts how many
    test images are classified correctly. The prediction is the argmax of
    ``result[3]`` from a ``mode='source', rec_scheme='share'`` forward pass.
    For the second-to-last batch, it saves reconstruction grids for the
    'all', 'share' and 'private' schemes as ``<name>_rec_image_<scheme>.png``
    in the working directory. Finally it prints the accuracy.

    Args:
        epoch: epoch number used to select the checkpoint file.
        name: ``'mnist'`` (reconstructed in ``'source'`` mode) or
            ``'mnist_m'`` (reconstructed in ``'target'`` mode).

    Returns:
        The accuracy ``n_correct / n_total`` (the same value that is printed).

    Raises:
        ValueError: if ``name`` is not one of the supported datasets.
    """

    ###################
    # params          #
    ###################
    # Use the GPU when present, otherwise fall back to the CPU instead of
    # crashing on machines without CUDA.
    cuda = torch.cuda.is_available()
    cudnn.benchmark = True
    batch_size = 64
    image_size = 28

    ###################
    # load data       #
    ###################

    img_transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])

    model_root = 'model'
    if name == 'mnist':
        mode = 'source'
        image_root = os.path.join('dataset', 'mnist')
        dataset = datasets.MNIST(root=image_root,
                                 train=False,
                                 transform=img_transform)
    elif name == 'mnist_m':
        mode = 'target'
        image_root = os.path.join('dataset', 'mnist_m', 'mnist_m_test')
        test_list = os.path.join('dataset', 'mnist_m',
                                 'mnist_m_test_labels.txt')
        dataset = GetLoader(data_root=image_root,
                            data_list=test_list,
                            transform=img_transform)
    else:
        # Without a dataset there is nothing to evaluate; fail fast with a
        # clear error rather than a NameError further down.
        raise ValueError('error dataset name: %r' % (name,))

    dataloader = torch.utils.data.DataLoader(dataset=dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             num_workers=8)

    ####################
    # load model       #
    ####################

    my_net = DSN()
    # map_location lets a GPU-saved checkpoint load when CUDA is unavailable.
    checkpoint = torch.load(
        os.path.join(model_root,
                     'dsn_mnist_mnistm_epoch_' + str(epoch) + '.pth'),
        map_location=None if cuda else 'cpu')
    my_net.load_state_dict(checkpoint)
    my_net.eval()

    if cuda:
        my_net = my_net.cuda()

    def tr_image(img):
        # Inverse of Normalize(mean=0.5, std=0.5): maps [-1, 1] to [0, 1]
        # so the reconstructions can be written as images.
        return (img + 1) / 2

    ####################
    # evaluate         #
    ####################

    len_dataloader = len(dataloader)
    n_total = 0
    n_correct = 0

    for i, (img, label) in enumerate(dataloader):
        batch_size = len(label)

        input_img = torch.FloatTensor(batch_size, 3, image_size, image_size)
        class_label = torch.LongTensor(batch_size)

        if cuda:
            img = img.cuda()
            label = label.cuda()
            input_img = input_img.cuda()
            class_label = class_label.cuda()

        # resize_as_(input_img) keeps the 3-channel buffer shape and copy_
        # broadcasts img into it (presumably so that 1-channel MNIST batches
        # become 3-channel network input).
        input_img.resize_as_(input_img).copy_(img)
        class_label.resize_as_(label).copy_(label)
        inputv_img = Variable(input_img)
        classv_label = Variable(class_label)

        result = my_net(input_data=inputv_img,
                        mode='source',
                        rec_scheme='share')
        pred = result[3].data.max(1, keepdim=True)[1]

        # Reconstructions are only written for the second-to-last batch
        # (presumably because the last one may be partial), so the three
        # extra forward passes are only run there.
        if i == len_dataloader - 2:
            for scheme in ('all', 'share', 'private'):
                result = my_net(input_data=inputv_img,
                                mode=mode,
                                rec_scheme=scheme)
                vutils.save_image(tr_image(result[-1].data),
                                  name + '_rec_image_' + scheme + '.png',
                                  nrow=8)

        n_correct += pred.eq(classv_label.data.view_as(pred)).cpu().sum()
        n_total += batch_size

    accu = n_correct * 1.0 / n_total

    print('epoch: %d, accuracy of the %s dataset: %f' % (epoch, name, accu))
    return accu
コード例 #2
0
ファイル: train.py プロジェクト: HUSTluoqingqing/gitlearn
# Target-domain training set: images under <target_image_root>/mnist_m_train,
# labels taken from train_list, preprocessed with img_transform.
dataset_target = GetLoader(
    data_root=os.path.join(target_image_root, 'mnist_m_train'),
    data_list=train_list,
    transform=img_transform,
)

# Shuffled mini-batches of the target set; num_workers=0 keeps data loading
# in the main process.
dataloader_target = torch.utils.data.DataLoader(
    dataset_target,
    batch_size=batch_size,
    num_workers=0,
    shuffle=True,
)

#####################
#  load model       #
#####################

my_net = DSN()

#####################
# setup optimizer   #
#####################

def exp_lr_scheduler(optimizer,
                     step,
                     init_lr=lr,
                     lr_decay_step=lr_decay_step,
                     step_decay_weight=step_decay_weight):
    """Apply step-wise exponential learning-rate decay to ``optimizer``.

    The learning rate is set to
    ``init_lr * step_decay_weight ** (step // lr_decay_step)``. It is
    multiplied by ``step_decay_weight`` once every ``lr_decay_step`` steps.
    The defaults are bound from the module-level ``lr``, ``lr_decay_step``
    and ``step_decay_weight`` when this function is defined.

    Args:
        optimizer: torch optimizer whose ``param_groups`` are updated in place.
        step: current training step (presumably a non-negative int).
        init_lr: learning rate at step 0.
        lr_decay_step: number of steps between decays.
        step_decay_weight: multiplicative decay factor.

    Returns:
        The same optimizer, so calls can be chained.
    """
    # Floor division makes the decay a staircase, matching "every
    # lr_decay_step steps"; true division would decay continuously.
    current_lr = init_lr * (step_decay_weight ** (step // lr_decay_step))

    if step % lr_decay_step == 0:
        print('learning rate is set to %f' % current_lr)

    for param_group in optimizer.param_groups:
        param_group['lr'] = current_lr

    return optimizer
コード例 #3
0
def run(net_str):
    """Fine-tune a pretrained DSN on dataset1 (source) and dataset2 (target).

    The checkpoint ``net_str`` (a file name inside the hard-coded experiment
    directory) is loaded and trained for ``n_epoch`` epochs. The losses are
    classification, difference, MSE / SIMSE reconstruction and, once
    ``active_domain_loss_step`` steps have passed, the domain-similarity loss.
    The result is saved under the first unused ``dsn_epoch_<k>.pth`` name.

    After every training iteration both test splits are evaluated through
    ``test()``. The mean accuracies and total evaluation times are printed,
    rounded to three decimals.

    Side effects:
        - creates the ``dataset18dataset2`` directory in the working directory;
        - replaces ``sys.stdout`` with a ``Logger`` for
          ``<model_root>/train.txt`` (not restored on return);
        - writes a checkpoint file.

    Args:
        net_str: file name of the starting checkpoint, joined onto the
            experiment directory.

    Returns:
        The output tuple of the last source-domain forward pass.
    """
    # Raw strings: these Windows paths contain backslashes that are not
    # meant as escape sequences.
    experiment_dir = r'D:\study\graduation_project\grdaution_project\instru_identify\dataset18dataset2'
    net_str = os.path.join(experiment_dir, net_str)
    dataset_root = os.path.join('D:\\', 'study', 'graduation_project',
                                'grdaution_project', 'instru_identify',
                                'dataset')
    source_image_root = os.path.join(dataset_root, 'dataset1')
    target_image_root = os.path.join(dataset_root, 'dataset2')

    # Proportion of historical data used, encoded in the directory name.
    ratio = str(8)
    # Output directory, relative to the working directory.
    model_root = 'dataset1' + ratio + 'dataset2'
    os.makedirs(model_root, exist_ok=True)

    # Route all further prints through the project's Logger for train.txt.
    log_path = os.path.join(model_root, 'train.txt')
    sys.stdout = Logger(log_path)

    # Training hyper-parameters.
    cuda = False
    cudnn.benchmark = True
    lr = 1e-2
    batch_size = 16
    image_size = 28
    n_epoch = 1
    step_decay_weight = 0.95
    lr_decay_step = 20000
    active_domain_loss_step = 10000
    weight_decay = 1e-6
    alpha_weight = 0.01
    beta_weight = 0.075
    gamma_weight = 0.25
    momentum = 0.9

    manual_seed = random.randint(1, 10000)
    random.seed(manual_seed)
    torch.manual_seed(manual_seed)

    #######################
    # load data           #
    #######################

    # Both training domains are read with the 3-channel (0.5, 0.5) transform.
    img_transform_target = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])

    # Source domain (dataset1) training split.
    dataset_source = GetLoader(
        data_root=os.path.join(source_image_root, 'dataset1_train'),
        data_list=os.path.join(source_image_root,
                               'dataset1_train_labels.txt'),
        transform=img_transform_target,
    )
    dataloader_source = torch.utils.data.DataLoader(
        dataset=dataset_source,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,  # load in the main process
    )

    # Target domain (dataset2) training split.
    dataset_target = GetLoader(
        data_root=os.path.join(target_image_root, 'dataset2_train'),
        data_list=os.path.join(target_image_root,
                               'dataset2_train_labels.txt'),
        transform=img_transform_target,
    )
    dataloader_target = torch.utils.data.DataLoader(
        dataset=dataset_target,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,  # load in the main process
    )

    #####################
    #  load model       #
    #####################

    my_net = DSN()
    # map_location lets a GPU-saved checkpoint load while CUDA is disabled.
    my_net.load_state_dict(
        torch.load(net_str, map_location=None if cuda else 'cpu'))

    #####################
    # setup optimizer   #
    #####################

    def exp_lr_scheduler(optimizer,
                         step,
                         init_lr=lr,
                         lr_decay_step=lr_decay_step,
                         step_decay_weight=step_decay_weight):
        # Decay the learning rate by step_decay_weight every lr_decay_step
        # steps; floor division keeps the decay a staircase under Python 3.
        current_lr = init_lr * (step_decay_weight**(step // lr_decay_step))

        if step % lr_decay_step == 0:
            print('learning rate is set to %f' % current_lr)

        for param_group in optimizer.param_groups:
            param_group['lr'] = current_lr

        return optimizer

    optimizer = optim.SGD(my_net.parameters(),
                          lr=lr,
                          momentum=momentum,
                          weight_decay=weight_decay)

    # Loss functions.
    loss_classification = torch.nn.CrossEntropyLoss()
    loss_recon1 = MSE()
    loss_recon2 = SIMSE()
    loss_diff = DiffLoss_tfTrans()
    loss_similarity = torch.nn.CrossEntropyLoss()

    if cuda:
        my_net = my_net.cuda()
        loss_classification = loss_classification.cuda()
        loss_recon1 = loss_recon1.cuda()
        loss_recon2 = loss_recon2.cuda()
        loss_diff = loss_diff.cuda()
        loss_similarity = loss_similarity.cuda()

    for param in my_net.parameters():
        param.requires_grad = True

    #############################
    # training network          #
    #############################

    # Iterate only as many batches as the shorter loader provides.
    len_dataloader = min(len(dataloader_source), len(dataloader_target))
    # Epoch at which the domain-similarity loss becomes active.
    dann_epoch = np.floor(active_domain_loss_step / len_dataloader * 1.0)

    current_step = 0
    accu_total1 = 0  # summed dataset1 test accuracy
    accu_total2 = 0  # summed dataset2 test accuracy
    time_total1 = 0  # total wall time spent evaluating dataset1
    time_total2 = 0  # total wall time spent evaluating dataset2
    for epoch in range(n_epoch):

        data_source_iter = iter(dataloader_source)
        data_target_iter = iter(dataloader_target)

        i = 0
        while i < len_dataloader:

            ########################
            # target data training #
            ########################

            t_img, t_label = next(data_target_iter)

            my_net.zero_grad()
            loss = 0
            batch_size = len(t_label)

            input_img = torch.FloatTensor(batch_size, 3, image_size,
                                          image_size)
            # Target samples carry domain label 1.
            domain_label = torch.ones(batch_size).long()

            if cuda:
                t_img = t_img.cuda()
                input_img = input_img.cuda()
                domain_label = domain_label.cuda()

            input_img.resize_as_(t_img).copy_(t_img)
            target_inputv_img = Variable(input_img)
            target_domainv_label = Variable(domain_label)

            if current_step > active_domain_loss_step:
                # DANN ramp: progress is the fraction of the remaining
                # training elapsed since dann_epoch, mapped through
                # 2 / (1 + exp(-10 * progress)) - 1.
                progress = float(i + (epoch - dann_epoch) * len_dataloader) / (
                    (n_epoch - dann_epoch) * len_dataloader)
                p = 2. / (1. + np.exp(-10 * progress)) - 1

                # Active domain loss.
                result = my_net(input_data=target_inputv_img,
                                mode='target',
                                rec_scheme='all',
                                p=p)
                (target_private_code, target_share_code, target_domain_label,
                 target_rec_code) = result
                target_dann = gamma_weight * loss_similarity(
                    target_domain_label, target_domainv_label)
                loss += target_dann
            else:
                result = my_net(input_data=target_inputv_img,
                                mode='target',
                                rec_scheme='all')
                target_private_code, target_share_code, _, target_rec_code = result

            target_diff = beta_weight * loss_diff(
                target_private_code, target_share_code, weight=0.05)
            loss += target_diff
            target_mse = alpha_weight * loss_recon1(target_rec_code,
                                                    target_inputv_img)
            loss += target_mse
            target_simse = alpha_weight * loss_recon2(target_rec_code,
                                                      target_inputv_img)
            loss += target_simse

            loss.backward()
            # NOTE: the schedule is defined but not applied; call
            # exp_lr_scheduler(optimizer=optimizer, step=current_step) here
            # to enable it.
            optimizer.step()

            ########################
            # source data training #
            ########################

            s_img, s_label = next(data_source_iter)

            my_net.zero_grad()
            batch_size = len(s_label)

            input_img = torch.FloatTensor(batch_size, 3, image_size,
                                          image_size)
            class_label = torch.LongTensor(batch_size)
            # Source samples carry domain label 0 (long, as CrossEntropyLoss
            # expects class indices).
            domain_label = torch.zeros(batch_size).long()

            loss = 0

            if cuda:
                s_img = s_img.cuda()
                s_label = s_label.cuda()
                input_img = input_img.cuda()
                class_label = class_label.cuda()
                domain_label = domain_label.cuda()

            # resize_as_(input_img) keeps the 3-channel buffer shape; copy_
            # broadcasts s_img into it.
            input_img.resize_as_(input_img).copy_(s_img)
            class_label.resize_as_(s_label).copy_(s_label)
            source_inputv_img = Variable(input_img)
            source_classv_label = Variable(class_label)
            source_domainv_label = Variable(domain_label)

            if current_step > active_domain_loss_step:
                # Active domain loss; p was computed in the target step of
                # this same iteration (same condition, same current_step).
                result = my_net(input_data=source_inputv_img,
                                mode='source',
                                rec_scheme='all',
                                p=p)
                (source_private_code, source_share_code, source_domain_label,
                 source_class_label, source_rec_code) = result
                source_dann = gamma_weight * loss_similarity(
                    source_domain_label, source_domainv_label)
                loss += source_dann
            else:
                result = my_net(input_data=source_inputv_img,
                                mode='source',
                                rec_scheme='all')
                (source_private_code, source_share_code, _,
                 source_class_label, source_rec_code) = result

            source_classification = loss_classification(
                source_class_label, source_classv_label)
            loss += source_classification
            source_diff = beta_weight * loss_diff(
                source_private_code, source_share_code, weight=0.05)
            loss += source_diff
            source_mse = alpha_weight * loss_recon1(source_rec_code,
                                                    source_inputv_img)
            loss += source_mse
            source_simse = alpha_weight * loss_recon2(source_rec_code,
                                                      source_inputv_img)
            loss += source_simse

            loss.backward()
            optimizer.step()

            i += 1
            current_step += 1

            ##################
            # evaluation     #
            ##################
            # NOTE(review): test() reloads dsn_epoch_<epoch>.pth from disk, so
            # this measures that saved checkpoint, not the in-memory weights.
            start1 = time.time()
            accu1 = test(epoch=epoch, name='dataset1')
            time_total1 += time.time() - start1
            accu_total1 += accu1

            start2 = time.time()
            accu2 = test(epoch=epoch, name='dataset2')
            time_total2 += time.time() - start2
            accu_total2 += accu2

    # Save under the first unused dsn_epoch_<k>.pth index, starting at the
    # last epoch number.
    # NOTE(review): checkpoints go to the absolute experiment_dir while
    # test() reads the relative model_root; they coincide only when the
    # working directory is ...\instru_identify.
    model_index = epoch
    model_path = os.path.join(experiment_dir,
                              'dsn_epoch_' + str(model_index) + '.pth')
    while os.path.exists(model_path):
        model_index += 1
        model_path = os.path.join(experiment_dir,
                                  'dsn_epoch_' + str(model_index) + '.pth')
    torch.save(my_net.state_dict(), model_path)

    # The mean test accuracy over all iterations is the performance metric.
    average_accu1 = accu_total1 / (len_dataloader * n_epoch)
    average_accu2 = accu_total2 / (len_dataloader * n_epoch)
    # Everything is reported rounded to three decimals.
    print(round(float(average_accu1), 3))
    print(round(float(average_accu2), 3))
    print(round(float(time_total1), 3))
    print(round(float(time_total2), 3))
    # NOTE(review): this returns the last source-domain forward-pass output,
    # not the averaged accuracies; confirm what callers rely on before
    # changing it.
    return result
コード例 #4
0
def test(epoch, name):
    """Evaluate a saved DSN checkpoint on the dataset1 or dataset2 test split.

    Loads ``dataset18dataset2/dsn_epoch_<epoch>.pth`` (relative to the
    working directory) and scores every test batch. The prediction is the
    argmax of ``result[3]`` from a ``mode='source', rec_scheme='share'``
    forward pass. Reconstruction grids ('all', 'share', 'private') of the
    second-to-last batch are saved under ``dataset18dataset2/images/`` as
    ``<name>_rec_image_<scheme>.png``.

    Args:
        epoch: epoch number used to select the checkpoint file.
        name: ``'dataset1'`` (reconstructed in ``'source'`` mode) or
            ``'dataset2'`` (reconstructed in ``'target'`` mode).

    Returns:
        The classification accuracy ``n_correct / n_total``.

    Raises:
        ValueError: if ``name`` is unknown or the test split is empty.
    """
    cuda = False
    cudnn.benchmark = True
    batch_size = 16
    image_size = 28
    # Proportion of historical data, encoded in the experiment directory name.
    ratio = str(8)
    model_root = 'dataset1' + ratio + 'dataset2'

    ################
    #   load data  #
    ################
    # NOTE(review): dataset1 is normalized with single-channel MNIST
    # statistics here; confirm this matches the transform used in training.
    img_transform_source = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307, ), std=(0.3081, ))
    ])

    img_transform_target = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])

    if name == 'dataset1':
        mode = 'source'
        image_root = r'D:\study\graduation_project\grdaution_project\instru_identify\dataset\dataset1\dataset1_test'
        test_list = r'D:\study\graduation_project\grdaution_project\instru_identify\dataset\dataset1\dataset1_test_labels.txt'
        transform = img_transform_source
    elif name == 'dataset2':
        mode = 'target'
        image_root = os.path.join('D:\\', 'study', 'graduation_project',
                                  'grdaution_project', 'instru_identify',
                                  'dataset', 'dataset2', 'dataset2_test')
        test_list = os.path.join('D:\\', 'study', 'graduation_project',
                                 'grdaution_project', 'instru_identify',
                                 'dataset', 'dataset2',
                                 'dataset2_test_labels.txt')
        transform = img_transform_target
    else:
        # Without a dataset there is nothing to evaluate; fail fast with a
        # clear error rather than a NameError further down.
        raise ValueError('error dataset name: %r' % (name,))

    dataset = GetLoader(
        data_root=image_root,
        data_list=test_list,
        transform=transform,
    )
    dataloader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,  # load in the main process
    )

    ###############
    # load model  #
    ###############
    my_net = DSN()
    # map_location lets a GPU-saved checkpoint load while CUDA is disabled.
    checkpoint = torch.load(
        os.path.join(model_root, 'dsn_epoch_' + str(epoch) + '.pth'),
        map_location=None if cuda else 'cpu')
    my_net.load_state_dict(checkpoint)
    my_net.eval()

    if cuda:
        my_net = my_net.cuda()

    def tr_image(img):
        # Inverse of Normalize(mean=0.5, std=0.5): maps [-1, 1] to [0, 1]
        # so the reconstructions can be written as images.
        return (img + 1) / 2

    len_dataloader = len(dataloader)
    if len_dataloader == 0:
        raise ValueError('empty test set for dataset %r' % (name,))

    image_save_path = os.path.join(model_root, 'images')
    n_total = 0
    n_correct = 0

    # Every batch is scored, including the final (possibly partial) one.
    for i, (img, label) in enumerate(dataloader):
        batch_size = len(label)  # number of images in this batch

        input_img = torch.FloatTensor(batch_size, 3, image_size, image_size)
        class_label = torch.LongTensor(batch_size)

        if cuda:
            img = img.cuda()
            label = label.cuda()
            input_img = input_img.cuda()
            class_label = class_label.cuda()

        # resize_as_(input_img) keeps the 3-channel buffer shape; copy_
        # broadcasts img into it.
        input_img.resize_as_(input_img).copy_(img)
        class_label.resize_as_(class_label).copy_(label)
        inputv_img = Variable(input_img)
        classv_label = Variable(class_label)

        result = my_net(input_data=inputv_img,
                        mode='source',
                        rec_scheme='share')
        pred = result[3].data.max(1, keepdim=True)[1]

        # Reconstructions are only written for the second-to-last batch, so
        # the three extra forward passes are only run there.
        if i == len_dataloader - 2:
            os.makedirs(image_save_path, exist_ok=True)
            for scheme in ('all', 'share', 'private'):
                result = my_net(input_data=inputv_img,
                                mode=mode,
                                rec_scheme=scheme)
                vutils.save_image(
                    tr_image(result[-1].data),
                    os.path.join(image_save_path,
                                 name + '_rec_image_' + scheme + '.png'),
                    nrow=8)

        n_correct += pred.eq(classv_label.data.view_as(pred)).cpu().sum()
        n_total += batch_size

    accu = n_correct * 1.0 / n_total
    return accu