Ejemplo n.º 1
0
    def build_model(self):
        """Construct ``self.net``, load its weights and set up the optimizer.

        Reads ``arch``, ``cuda``, ``load``, ``pretrained_model``, ``lr`` and
        ``wd`` from ``self.config``; sets ``self.net``, ``self.lr``,
        ``self.wd`` and ``self.optimizer``, then prints the network.
        """
        cfg = self.config

        self.net = build_model(cfg.arch)
        if cfg.cuda:
            # Move the network onto the GPU.
            self.net = self.net.cuda()

        # Put the network in eval mode ("use_global_stats = True" per the
        # original note); a `self.net.train()` call was commented out here.
        self.net.eval()
        # Initialise the weights before any checkpoint is loaded over them.
        self.net.apply(weights_init)

        if cfg.load != '':
            # Resume from a full state dict of this network.
            self.net.load_state_dict(torch.load(cfg.load))
        else:
            # No checkpoint given: hand the pretrained weights to `net.base`.
            self.net.base.load_pretrained_model(
                torch.load(cfg.pretrained_model))

        # Learning rate and weight decay.
        self.lr, self.wd = cfg.lr, cfg.wd

        # Only parameters that require gradients are optimised.
        trainable = (param for param in self.net.parameters()
                     if param.requires_grad)
        self.optimizer = Adam(trainable, lr=self.lr, weight_decay=self.wd)

        # Print the network structure.
        self.print_network(self.net, 'PoolNet Structure')
Ejemplo n.º 2
0
    def build_model(self):
        """Create the network, initialise it and optionally load pretrained weights.

        Uses ``self.arch`` to build the network, ``self.weights_init`` to
        initialise it and, when ``self.pretrained_model`` is truthy, loads
        that file through ``net.base.load_pretrained_model``.

        Returns:
            The constructed network, moved to the GPU when CUDA is available.
        """
        model = build_model(self.arch)
        model = model.cuda() if torch.cuda.is_available() else model

        # Original note: BN layers have a `use_global_stats` flag that says
        # whether to use Caffe's stored mean/variance. It should be false
        # while training and true while testing, otherwise training may
        # produce NaN or fail to converge.
        model.eval()  # use_global_stats = True

        model.apply(self.weights_init)
        if self.pretrained_model:
            state = torch.load(self.pretrained_model)
            model.base.load_pretrained_model(state)

        self._print_network(model, 'PoolNet Structure')
        return model
Ejemplo n.º 3
0
def main():
    """Measure how often the discriminator identifies each domain correctly.

    Loads the segmentation model from ``args.restore_from`` and the
    discriminator from ``args.D_restore_from``. Both are run over every batch
    from ``get_loader(args)``. A hit is counted when ``pan(...)`` of the
    discriminator output equals the expected domain label: 0 for saliency
    images, 1 for edge images. The running accuracy is printed every 100
    batches and the final accuracy at the end.
    """
    # create the model
    model = build_model()
    model.to(device)
    model.load_state_dict(torch.load(args.restore_from))

    # create the discriminator
    model_D1 = FCDiscriminator(num_classes=1)
    model_D1.to(device)
    model_D1.load_state_dict(torch.load(args.D_restore_from))
    # NOTE(review): neither network is switched to eval(), so any
    # BatchNorm/Dropout layers run in training mode -- confirm intended.

    # domain labels used for adversarial training
    salLabel = 0
    edgeLabel = 1

    picloader = get_loader(args)
    correct = 0
    tot = 0
    for i_iter, data_batch in enumerate(picloader):
        # one saliency + one edge prediction per batch
        tot += 2

        sal_image = data_batch['sal_image'].to(device)
        edge_image = data_batch['edge_image'].to(device)

        # Pure evaluation: no autograd graph is needed.
        with torch.no_grad():
            sal_pred = model(sal_image)
            edge_pred = model(edge_image)
            ss_out = model_D1(sal_pred)
            se_out = model_D1(edge_pred)

        if pan(ss_out) == salLabel:
            correct += 1
        if pan(se_out) == edgeLabel:
            correct += 1

        if i_iter % 100 == 0:
            print('processing %d: %f' % (i_iter, correct / tot))

    # An empty loader would otherwise raise ZeroDivisionError here.
    if tot == 0:
        print('no batches were loaded; accuracy is undefined')
        return
    print(correct / tot)
Ejemplo n.º 4
0
    def build_model(self):
        """Build the PoolNet network, load its weights and create the optimizer.

        Side effects: sets ``self.net``, ``self.lr``, ``self.wd`` and
        ``self.optimizer`` from ``self.config`` and prints the network.
        """
        conf = self.config
        self.net = build_model(conf.arch)
        self.net = self.net.cuda() if conf.cuda else self.net

        # Keep the net in eval mode ("use_global_stats = True").
        self.net.eval()
        self.net.apply(weights_init)

        load_base_weights = conf.load == ''
        if load_base_weights:
            self.net.base.load_pretrained_model(
                torch.load(conf.pretrained_model))
        else:
            self.net.load_state_dict(torch.load(conf.load))

        self.lr = conf.lr
        self.wd = conf.wd

        params = [p for p in self.net.parameters() if p.requires_grad]
        self.optimizer = Adam(params, lr=self.lr, weight_decay=self.wd)
        self.print_network(self.net, 'PoolNet Structure')
Ejemplo n.º 5
0
def main():
    """Run the model over the test split and dump predictions and labels.

    Predictions are written as ``<i>.jpg`` under
    ``<args.snapshot_dir>/pred/tsy1/1``. Ground-truth labels are written as
    ``<i>.png`` under ``<args.snapshot_dir>/gt/1``. Both are scaled by 255.
    """
    # os.path.join keeps a snapshot_dir without a trailing slash from
    # producing "<snapshot_dir>gt/1" instead of "<snapshot_dir>/gt/1".
    gtdir = os.path.join(args.snapshot_dir, 'gt', '1')
    preddir = os.path.join(args.snapshot_dir, 'pred', 'tsy1', '1')

    # make output dirs (no-op if they already exist)
    os.makedirs(gtdir, exist_ok=True)
    os.makedirs(preddir, exist_ok=True)

    # let cuDNN benchmark and pick the fastest convolution algorithms
    cudnn.enabled = True
    cudnn.benchmark = True

    # create the model
    model = build_model()
    model.to(device)
    # NOTE(review): inference runs with the model in train() mode, as in the
    # original script -- confirm this is intended.
    model.train()
    model.apply(weights_init)
    model.load_state_dict(torch.load(args.restore_from))

    picloader = get_loader(args, mode='test')

    for i_iter, data_batch in enumerate(picloader):
        if i_iter % 50 == 0:
            print(i_iter)
        sal_image, sal_label = data_batch['sal_image'], data_batch['sal_label']

        with torch.no_grad():
            preds = model(sal_image.to(device))
            pred = np.squeeze(torch.sigmoid(preds).cpu().numpy())
            label = np.squeeze(sal_label.cpu().numpy())

        multi_fuse = 255 * pred
        label = 255 * label
        cv2.imwrite(os.path.join(preddir, str(i_iter) + '.jpg'), multi_fuse)
        cv2.imwrite(os.path.join(gtdir, str(i_iter) + '.png'), label)
Ejemplo n.º 6
0
    def test(arch, model_path, test_loader, result_fold):
        """Run a trained model over ``test_loader`` and save saliency maps.

        Args:
            arch: architecture name passed to ``build_model``.
            model_path: path of the state dict to load.
            test_loader: iterable of batches with ``'image'`` and ``'name'``
                entries. Only ``data_batch['name'][0]`` is used, so this
                presumably expects batch_size == 1 -- confirm.
            result_fold: output directory. Each map is written as
                ``<name without extension>.png``, scaled by 255.
        """
        Tools.print('Loading trained model from {}'.format(model_path))
        net = build_model(arch).cuda()
        net.load_state_dict(torch.load(model_path))
        net.eval()

        time_s = time.time()
        # NOTE(review): for a DataLoader this counts batches, so the FPS
        # below is images/s only when batch_size == 1.
        img_num = len(test_loader)
        for i, data_batch in enumerate(test_loader):
            if i % 100 == 0:
                Tools.print("test {} {}".format(i, img_num))
            images, name = data_batch['image'], data_batch['name'][0]
            with torch.no_grad():
                images = torch.Tensor(images).cuda()
                pred = net(images)
                pred = np.squeeze(torch.sigmoid(pred).cpu().data.numpy()) * 255
            # splitext strips any extension length ('.jpeg', '.tiff', ...),
            # unlike slicing off a fixed 4 characters.
            stem = os.path.splitext(name)[0]
            cv2.imwrite(os.path.join(result_fold, stem + '.png'), pred)
        time_e = time.time()
        Tools.print('Speed: %f FPS' % (img_num / (time_e - time_s)))
        Tools.print('Test Done!')
def main():
    """Train the saliency model adversarially against a domain discriminator.

    Each run gets a fresh ``<snapshot_dir>/run-<k>`` folder. It holds a log
    file (``file.txt``) and a ``models`` sub-folder for checkpoints.

    Per batch:
    1. The generator gets a summed BCE saliency loss on saliency images.
    2. The generator also gets an adversarial loss that pushes the
       discriminator to label edge-image predictions as the saliency domain.
    3. The discriminator is trained on detached predictions from both
       domains.

    Gradients of both networks are accumulated over ``args.iter_size``
    batches before both optimizers step.
    """
    # make dir
    os.makedirs(args.snapshot_dir, exist_ok=True)

    # pick the first unused run-<k> folder
    run = 0
    while os.path.exists("%s/run-%d" % (args.snapshot_dir, run)):
        run += 1
    os.mkdir("%s/run-%d" % (args.snapshot_dir, run))
    os.mkdir("%s/run-%d/models" % (args.snapshot_dir, run))
    args.file_dir = "%s/run-%d/file.txt" % (args.snapshot_dir, run)
    args.snapshot_dir = "%s/run-%d/models" % (args.snapshot_dir, run)

    # let cuDNN benchmark and pick the fastest convolution algorithms
    cudnn.enabled = True
    cudnn.benchmark = True

    # create the generator (saliency model) from a checkpoint
    model = build_model()
    model.to(device)
    model.train()
    model.apply(weights_init)
    model.load_state_dict(torch.load(args.restore_from))

    # create the discriminator (freshly initialised)
    model_D1 = FCDiscriminator(num_classes=1).to(device)
    model_D1.train()
    model_D1.apply(weights_init)

    # create optimizers: one for the whole generator, one for D
    optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                  model.parameters()),
                           lr=args.learning_rate,
                           weight_decay=args.weight_decay)
    optimizer.zero_grad()
    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    # used for both the adversarial (G) and the discriminator (D) losses
    bce_loss = torch.nn.BCEWithLogitsLoss()

    # start time and hyper-parameters
    with open(args.file_dir, 'a') as f:
        f.write('strat time: ' + str(datetime.now()) + '\n\n')

        f.write('learning rate: ' + str(args.learning_rate) + '\n')
        f.write('learning rate D: ' + str(args.learning_rate_D) + '\n')
        f.write('wight decay: ' + str(args.weight_decay) + '\n')
        f.write('lambda_adv_target2: ' + str(args.lambda_adv_target2) + '\n\n')

        f.write('eptch size: ' + str(args.epotch_size) + '\n')
        f.write('batch size: ' + str(args.batch_size) + '\n')
        f.write('iter size: ' + str(args.iter_size) + '\n')
        f.write('num steps: ' + str(args.num_steps) + '\n\n')

    # domain labels for adversarial training
    salLabel = 0
    edgeLabel = 1

    picloader = get_loader(args)
    iter_num = len(picloader.dataset) // args.batch_size
    # Every loss term is divided by this, so accumulated gradients average
    # over the effective batch of iter_size * batch_size samples.
    norm = args.iter_size * args.batch_size
    # show_every < batch_size would otherwise make the modulus below 0.
    show_interval = max(1, args.show_every // args.batch_size)

    for i_epotch in range(args.epotch_size):
        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0
        # Drop partially accumulated gradients of BOTH networks and restart
        # the accumulation counter. Otherwise the next step would average
        # fewer than iter_size batches for G and stale ones for D.
        model.zero_grad()
        model_D1.zero_grad()
        aveGrad = 0

        for i_iter, data_batch in enumerate(picloader):
            sal_image = data_batch['sal_image']
            sal_label = data_batch['sal_label']
            edge_image = data_batch['edge_image']
            if (sal_image.size(2) != sal_label.size(2)) or (
                    sal_image.size(3) != sal_label.size(3)):
                print('IMAGE ERROR, PASSING```')
                with open(args.file_dir, 'a') as f:
                    f.write('IMAGE ERROR, PASSING```\n')
                continue
            sal_image = sal_image.to(device)
            sal_label = sal_label.to(device)
            edge_image = edge_image.to(device)

            sal_pred = model(sal_image)
            edge_pred = model(edge_image)

            # --- train G: freeze D so the adversarial loss only updates G ---
            for param in model_D1.parameters():
                param.requires_grad = False

            # supervised saliency loss (summed over elements, then normalised)
            sal_loss_fuse = F.binary_cross_entropy_with_logits(
                sal_pred, sal_label, reduction='sum')
            sal_loss = sal_loss_fuse / norm
            loss_seg_value1 += sal_loss.item()
            sal_loss.backward()

            # adversarial loss: low when D labels edge-image predictions as
            # the saliency domain
            sD_out = model_D1(edge_pred)
            loss_adv_target1 = bce_loss(sD_out,
                                        torch.full_like(sD_out, salLabel))
            sd_loss = loss_adv_target1 / norm
            # logged before the lambda weighting
            loss_adv_target_value1 += sd_loss.item()
            (sd_loss * args.lambda_adv_target2).backward()

            # --- train D: unfreeze it and classify detached predictions ---
            for param in model_D1.parameters():
                param.requires_grad = True

            sal_pred = sal_pred.detach()
            edge_pred = edge_pred.detach()

            ss_out = model_D1(sal_pred)
            ss_loss = bce_loss(ss_out, torch.full_like(ss_out, salLabel)) / norm
            loss_D_value1 += ss_loss.item()
            ss_loss.backward()

            se_out = model_D1(edge_pred)
            se_loss = bce_loss(se_out, torch.full_like(se_out, edgeLabel)) / norm
            loss_D_value1 += se_loss.item()
            se_loss.backward()

            # step both optimizers once every iter_size batches
            aveGrad += 1
            if aveGrad % args.iter_size == 0:
                optimizer.step()
                optimizer.zero_grad()
                optimizer_D1.step()
                optimizer_D1.zero_grad()
                aveGrad = 0

            if i_iter % show_interval == 0:
                msg = (
                    'epotch = {5:2d}/{6:2d}, iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f}, loss_adv1 = {3:.3f}, loss_D1 = {4:.3f}'
                    .format(i_iter, iter_num, loss_seg_value1,
                            loss_adv_target_value1, loss_D_value1, i_epotch,
                            args.epotch_size))
                print(msg)
                with open(args.file_dir, 'a') as f:
                    f.write(msg + '\n')

                loss_seg_value1, loss_adv_target_value1, loss_D_value1 = 0, 0, 0

            if i_iter == iter_num - 1 or (i_iter % args.save_pred_every == 0
                                          and i_iter != 0):
                print('taking snapshot ...')
                with open(args.file_dir, 'a') as f:
                    f.write('taking snapshot ...\n')
                prefix = 'sal_' + str(i_epotch) + '_' + str(i_iter)
                torch.save(model.state_dict(),
                           osp.join(args.snapshot_dir, prefix + '.pth'))
                torch.save(model_D1.state_dict(),
                           osp.join(args.snapshot_dir, prefix + '_D1.pth'))

        # After epoch index 7, decay both learning rates by 10x. The
        # optimizers are re-created, which also resets Adam's moment estimates.
        if i_epotch == 7:
            args.learning_rate = args.learning_rate * 0.1
            args.learning_rate_D = args.learning_rate_D * 0.1
            optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                          model.parameters()),
                                   lr=args.learning_rate,
                                   weight_decay=args.weight_decay)
            optimizer_D1 = optim.Adam(model_D1.parameters(),
                                      lr=args.learning_rate_D,
                                      betas=(0.9, 0.99))

    # end
    with open(args.file_dir, 'a') as f:
        f.write('end time: ' + str(datetime.now()) + '\n')
Ejemplo n.º 8
0
def main():
    """Fine-tune the saliency model with the supervised BCE loss only.

    Each run gets a fresh ``<snapshot_dir>/run-<k>`` folder. It holds a log
    file (``file.txt``) and a ``models`` sub-folder for checkpoints.

    Gradients are accumulated over ``args.iter_size`` batches per optimizer
    step. There is no discriminator, so the ``loss_adv1`` and ``loss_D1``
    log columns always read 0.
    """
    # make dir
    os.makedirs(args.snapshot_dir, exist_ok=True)

    # pick the first unused run-<k> folder
    run = 0
    while os.path.exists("%s/run-%d" % (args.snapshot_dir, run)):
        run += 1
    os.mkdir("%s/run-%d" % (args.snapshot_dir, run))
    os.mkdir("%s/run-%d/models" % (args.snapshot_dir, run))
    args.file_dir = "%s/run-%d/file.txt" % (args.snapshot_dir, run)
    args.snapshot_dir = "%s/run-%d/models" % (args.snapshot_dir, run)

    # let cuDNN benchmark and pick the fastest convolution algorithms
    cudnn.enabled = True
    cudnn.benchmark = True

    # create the model from a checkpoint
    model = build_model()
    model.to(device)
    model.train()
    model.apply(weights_init)
    model.load_state_dict(torch.load(args.restore_from))

    # create the optimizer for the whole model
    optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                  model.parameters()),
                           lr=args.learning_rate,
                           weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # start time and hyper-parameters
    with open(args.file_dir, 'a') as f:
        f.write('strat time: ' + str(datetime.now()) + '\n\n')

        f.write('learning rate: ' + str(args.learning_rate) + '\n')
        f.write('learning rate D: ' + str(args.learning_rate_D) + '\n')
        f.write('wight decay: ' + str(args.weight_decay) + '\n')
        f.write('lambda_adv_target2: ' + str(args.lambda_adv_target2) + '\n\n')

        f.write('eptch size: ' + str(args.epotch_size) + '\n')
        f.write('batch size: ' + str(args.batch_size) + '\n')
        f.write('iter size: ' + str(args.iter_size) + '\n')
        f.write('num steps: ' + str(args.num_steps) + '\n\n')

    picloader = get_loader(args)
    iter_num = len(picloader.dataset) // args.batch_size
    # The loss is divided by this, so accumulated gradients average over the
    # effective batch of iter_size * batch_size samples.
    norm = args.iter_size * args.batch_size
    # show_every < batch_size would otherwise make the modulus below 0.
    show_interval = max(1, args.show_every // args.batch_size)

    for i_epotch in range(args.epotch_size):
        loss_seg_value1 = 0
        # no discriminator in this script: these stay 0 and are only logged
        loss_adv_target_value1 = 0
        loss_D_value1 = 0
        # Drop partially accumulated gradients AND restart the counter, so
        # the next step averages a full iter_size worth of batches.
        model.zero_grad()
        aveGrad = 0

        for i_iter, data_batch in enumerate(picloader):
            sal_image, sal_label = data_batch['sal_image'], data_batch[
                'sal_label']
            if (sal_image.size(2) != sal_label.size(2)) or (
                    sal_image.size(3) != sal_label.size(3)):
                print('IMAGE ERROR, PASSING```')
                with open(args.file_dir, 'a') as f:
                    f.write('IMAGE ERROR, PASSING```\n')
                continue
            sal_image, sal_label = sal_image.to(device), sal_label.to(device)

            sal_pred = model(sal_image)
            sal_loss_fuse = F.binary_cross_entropy_with_logits(
                sal_pred, sal_label, reduction='sum')
            sal_loss = sal_loss_fuse / norm
            loss_seg_value1 += sal_loss.item()
            sal_loss.backward()

            # step once every iter_size batches
            aveGrad += 1
            if aveGrad % args.iter_size == 0:
                optimizer.step()
                optimizer.zero_grad()
                aveGrad = 0

            if i_iter % show_interval == 0:
                msg = (
                    'epotch = {5:2d}/{6:2d}, iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f}, loss_adv1 = {3:.3f}, loss_D1 = {4:.3f}'
                    .format(i_iter, iter_num, loss_seg_value1,
                            loss_adv_target_value1, loss_D_value1, i_epotch,
                            args.epotch_size))
                print(msg)
                with open(args.file_dir, 'a') as f:
                    f.write(msg + '\n')

                loss_seg_value1, loss_adv_target_value1, loss_D_value1 = 0, 0, 0

            if i_iter == iter_num - 1 or (i_iter % args.save_pred_every == 0
                                          and i_iter != 0):
                print('taking snapshot ...')
                with open(args.file_dir, 'a') as f:
                    f.write('taking snapshot ...\n')
                torch.save(
                    model.state_dict(),
                    osp.join(
                        args.snapshot_dir,
                        'sal_' + str(i_epotch) + '_' + str(i_iter) + '.pth'))

        # After epoch index 7, decay the learning rate by 10x. The optimizer
        # is re-created, which also resets Adam's moment estimates.
        if i_epotch == 7:
            args.learning_rate = args.learning_rate * 0.1
            optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                          model.parameters()),
                                   lr=args.learning_rate,
                                   weight_decay=args.weight_decay)

    # end
    with open(args.file_dir, 'a') as f:
        f.write('end time: ' + str(datetime.now()) + '\n')
Ejemplo n.º 9
0
def main():
    """Create the model and start the (supervised-only) training.

    The model is built with ``build_model()``. Pretrained weights from
    ``args.pretrained_model`` are loaded through
    ``model.base.load_pretrained_model``.

    Each epoch runs ``args.num_steps // args.batch_size // args.iter_size``
    optimizer steps. Each step accumulates gradients over ``args.iter_size``
    batches.

    Logs go to ``<snapshot_dir>/run-<k>/file.txt`` and checkpoints to
    ``<snapshot_dir>/run-<k>/models``. The ``loss_adv1``/``loss_D1`` log
    columns always read 0 because no discriminator is trained here.
    """
    # run on the GPU unless --cpu was given
    device = torch.device("cuda" if not args.cpu else "cpu")
    cudnn.enabled = True

    # create the network and load the pretrained weights
    model = build_model()
    model.to(device)
    model.train()
    model.apply(weights_init)
    model.base.load_pretrained_model(torch.load(args.pretrained_model))

    # let cuDNN benchmark and pick the fastest convolution algorithms
    cudnn.benchmark = True

    # create the folder that holds the run folders
    os.makedirs(args.snapshot_dir, exist_ok=True)

    picloader = get_loader(args)

    # optimizer for the whole model
    optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                  model.parameters()),
                           lr=args.learning_rate,
                           weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # pick the first unused run-<k> folder
    run = 0
    while os.path.exists("%s/run-%d" % (args.snapshot_dir, run)):
        run += 1
    os.mkdir("%s/run-%d" % (args.snapshot_dir, run))
    os.mkdir("%s/run-%d/models" % (args.snapshot_dir, run))
    args.file_dir = "%s/run-%d/file.txt" % (args.snapshot_dir, run)
    args.snapshot_dir = "%s/run-%d/models" % (args.snapshot_dir, run)

    # start time and hyper-parameters
    with open(args.file_dir, 'a') as f:
        f.write('strat time: ' + str(datetime.now()) + '\n\n')

        f.write('learning rate: ' + str(args.learning_rate) + '\n')
        f.write('learning rate D: ' + str(args.learning_rate_D) + '\n')
        f.write('wight decay: ' + str(args.weight_decay) + '\n')
        f.write('lambda_adv_target2: ' + str(args.lambda_adv_target2) + '\n\n')

        f.write('eptch size: ' + str(args.epotch_size) + '\n')
        f.write('batch size: ' + str(args.batch_size) + '\n')
        f.write('iter size: ' + str(args.iter_size) + '\n')
        f.write('num steps: ' + str(args.num_steps) + '\n\n')

    # number of optimizer steps ("large batches") per epoch
    steps_per_epoch = args.num_steps // args.batch_size // args.iter_size

    for i_epotch in range(args.epotch_size):
        # reset the logged loss values
        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        model.zero_grad()
        loader_iter = iter(picloader)

        for i_iter in range(steps_per_epoch):
            optimizer.zero_grad()

            # accumulate gradients over iter_size small batches
            for _ in range(args.iter_size):
                # Restart the loader when it runs out. Otherwise StopIteration
                # aborts training whenever num_steps exceeds one pass.
                try:
                    data_batch = next(loader_iter)
                except StopIteration:
                    loader_iter = iter(picloader)
                    data_batch = next(loader_iter)
                source_images = data_batch['sal_image']
                source_labels = data_batch['sal_label']

                if (source_images.size(2) != source_labels.size(2)) or (
                        source_images.size(3) != source_labels.size(3)):
                    print('IMAGE ERROR, PASSING```')
                    with open(args.file_dir, 'a') as f:
                        f.write('IMAGE ERROR, PASSING```\n')
                    continue

                source_images = source_images.to(device)
                source_labels = source_labels.to(device)
                pred1 = model(source_images)

                # supervised saliency loss, summed and then normalised per
                # sample of the effective batch
                loss_seg1 = F.binary_cross_entropy_with_logits(
                    pred1, source_labels, reduction='sum')
                lossG = loss_seg1 / args.iter_size / args.batch_size
                # logging only; not part of the training signal
                loss_seg_value1 += lossG.item()

                lossG.backward()

            # one parameter update per large batch
            optimizer.step()

            # display
            if i_iter * args.batch_size % SHOW_EVERY == 0:
                print('exp = {}'.format(args.snapshot_dir))
                print(
                    'epotch = {5:2d}/{6:2d}, iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f}, loss_adv1 = {3:.3f}, loss_D1 = {4:.3f}'
                    .format(i_iter, steps_per_epoch, loss_seg_value1,
                            loss_adv_target_value1, loss_D_value1, i_epotch,
                            args.epotch_size))
                with open(args.file_dir, 'a') as f:
                    f.write('exp = {}\n'.format(args.snapshot_dir))
                    f.write(
                        'epotch = {5:2d}/{6:2d}, iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f}, loss_adv1 = {3:.3f}, loss_D1 = {4:.3f}\n'
                        .format(i_iter, steps_per_epoch, loss_seg_value1,
                                loss_adv_target_value1, loss_D_value1,
                                i_epotch, args.epotch_size))

                loss_seg_value1, loss_adv_target_value1, loss_D_value1 = 0, 0, 0

            # early stop: save and leave this epoch
            if i_iter >= args.num_steps_stop - 1:
                print('save model ...')
                with open(args.file_dir, 'a') as f:
                    f.write('save model ...\n')
                torch.save(
                    model.state_dict(),
                    osp.join(
                        args.snapshot_dir, 'sal_' + str(i_epotch) + '_' +
                        str(args.num_steps_stop) + '.pth'))
                break

            if i_iter == steps_per_epoch - 1 or (
                    i_iter * args.batch_size * args.iter_size %
                    args.save_pred_every == 0 and i_iter != 0):
                print('taking snapshot ...')
                with open(args.file_dir, 'a') as f:
                    f.write('taking snapshot ...\n')
                torch.save(
                    model.state_dict(),
                    osp.join(
                        args.snapshot_dir,
                        'sal_' + str(i_epotch) + '_' + str(i_iter) + '.pth'))

    with open(args.file_dir, 'a') as f:
        f.write('end time: ' + str(datetime.now()) + '\n')