Code example #1
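# The original listing omits its imports; a minimal set matching the calls
# below is sketched here. FCDiscriminator, DomainBN, freeze_BN, cfg, args,
# device, build_model, weights_init and get_loader are project symbols
# assumed to be in scope.
import os
import os.path as osp
from datetime import datetime

import torch
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.nn.parallel import DistributedDataParallel
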
def init_net_D(args, state_dict=None):
    net_D = FCDiscriminator(cfg.DATASET.NUM_CLASSES)

    if args.distributed:
        net_D = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net_D)

    if cfg.MODEL.DOMAIN_BN:
        net_D = DomainBN.convert_domain_batchnorm(net_D, num_domains=2)

    if state_dict is not None:
        try:
            net_D.load_state_dict(state_dict)
        except RuntimeError:
            # the checkpoint may come from a DomainBN-converted model;
            # convert and retry
            net_D = DomainBN.convert_domain_batchnorm(net_D, num_domains=2)
            net_D.load_state_dict(state_dict)

    if cfg.TRAIN.FREEZE_BN:
        net_D.apply(freeze_BN)

    if torch.cuda.is_available():
        net_D.cuda()

    if args.distributed:
        net_D = DistributedDataParallel(net_D, device_ids=[args.gpu])
    else:
        net_D = torch.nn.DataParallel(net_D)

    return net_D
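
# Hypothetical usage of init_net_D (the checkpoint path is illustrative only):
# net_D = init_net_D(args, state_dict=torch.load('snapshots/D_latest.pth'))
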
def main():

    # make dir
    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    run = 0
    while os.path.exists("%s/run-%d" % (args.snapshot_dir, run)):
        run += 1
    os.mkdir("%s/run-%d" % (args.snapshot_dir, run))
    os.mkdir("%s/run-%d/models" % (args.snapshot_dir, run))
    args.file_dir = "%s/run-%d/file.txt" % (args.snapshot_dir, run)
    args.snapshot_dir = "%s/run-%d/models" % (args.snapshot_dir, run)

    # cuDNN auto-tuning ("black-magic" speed-up)
    cudnn.enabled = True
    cudnn.benchmark = True

    # create the model
    model = build_model()
    model.to(device)
    model.train()
    model.apply(weights_init)
    model.load_state_dict(torch.load(args.restore_from))
    # model.base.load_pretrained_model(torch.load(args.pretrained_model))

    # create discriminator
    model_D1 = FCDiscriminator(num_classes=1).to(device)
    model_D1.train()
    model_D1.apply(weights_init)
    # model_D1.load_state_dict(torch.load(args.D_restore_from))

    # create optimizer
    optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                  model.parameters()),
                           lr=args.learning_rate,
                           weight_decay=args.weight_decay)  # optimizer for the whole model
    optimizer.zero_grad()
    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    # BCE-with-logits loss for the adversarial terms below
    bce_loss = torch.nn.BCEWithLogitsLoss()

    # start time
    with open(args.file_dir, 'a') as f:
        f.write('start time: ' + str(datetime.now()) + '\n\n')

        f.write('learning rate: ' + str(args.learning_rate) + '\n')
        f.write('learning rate D: ' + str(args.learning_rate_D) + '\n')
        f.write('weight decay: ' + str(args.weight_decay) + '\n')
        f.write('lambda_adv_target2: ' + str(args.lambda_adv_target2) + '\n\n')

        f.write('epoch size: ' + str(args.epotch_size) + '\n')
        f.write('batch size: ' + str(args.batch_size) + '\n')
        f.write('iter size: ' + str(args.iter_size) + '\n')
        f.write('num steps: ' + str(args.num_steps) + '\n\n')

    # labels for adversarial training (markers for the two domains)
    salLabel = 0
    edgeLabel = 1

    picloader = get_loader(args)
    iter_num = len(picloader.dataset) // args.batch_size
    aveGrad = 0

    for i_epotch in range(args.epotch_size):
        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0
        model.zero_grad()

        for i_iter, data_batch in enumerate(picloader):

            sal_image, sal_label, edge_image = data_batch[
                'sal_image'], data_batch['sal_label'], data_batch['edge_image']
            if (sal_image.size(2) != sal_label.size(2)) or (
                    sal_image.size(3) != sal_label.size(3)):
                print('IMAGE ERROR, PASSING')
                with open(args.file_dir, 'a') as f:
                    f.write('IMAGE ERROR, PASSING\n')
                continue
            sal_image = sal_image.to(device)
            sal_label = sal_label.to(device)
            edge_image = edge_image.to(device)

            sal_pred = model(sal_image)
            edge_pred = model(edge_image)

            # train G (with D frozen)
            for param in model_D1.parameters():
                param.requires_grad = False

            sal_loss_fuse = F.binary_cross_entropy_with_logits(sal_pred,
                                                               sal_label,
                                                               reduction='sum')
            sal_loss = sal_loss_fuse / (args.iter_size * args.batch_size)
            loss_seg_value1 += sal_loss.data

            sal_loss.backward()

            sD_out = model_D1(edge_pred)
            # BCE adversarial loss: when training G, the loss is low when the
            # edge-domain output is classified as the source (sal) label
            loss_adv_target1 = bce_loss(
                sD_out,
                # target: a same-sized tensor filled with the "correct" label
                torch.FloatTensor(sD_out.size()).fill_(salLabel).to(device))
            sd_loss = loss_adv_target1 / (args.iter_size * args.batch_size)
            loss_adv_target_value1 += sd_loss.data  # for logging only

            sd_loss = sd_loss * args.lambda_adv_target2
            sd_loss.backward()

            # train D
            for param in model_D1.parameters():
                param.requires_grad = True

            sal_pred = sal_pred.detach()
            edge_pred = edge_pred.detach()

            ss_out = model_D1(sal_pred)
            ss_loss = bce_loss(
                ss_out,
                torch.FloatTensor(
                    ss_out.data.size()).fill_(salLabel).to(device))
            ss_Loss = ss_loss / (args.iter_size * args.batch_size)
            loss_D_value1 += ss_Loss.data

            ss_Loss.backward()

            se_out = model_D1(edge_pred)
            se_loss = bce_loss(
                se_out,
                torch.FloatTensor(
                    se_out.data.size()).fill_(edgeLabel).to(device))
            se_Loss = se_loss / (args.iter_size * args.batch_size)
            loss_D_value1 += se_Loss.data

            se_Loss.backward()

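            # gradient accumulation: G and D step only once every iter_size
            # mini-batches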
            aveGrad += 1
            if aveGrad % args.iter_size == 0:
                optimizer.step()
                optimizer.zero_grad()
                optimizer_D1.step()
                optimizer_D1.zero_grad()
                aveGrad = 0

            if i_iter % (args.show_every // args.batch_size) == 0:
                print(
                    'epoch = {5:2d}/{6:2d}, iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f}, loss_adv1 = {3:.3f}, loss_D1 = {4:.3f}'
                    .format(i_iter, iter_num, loss_seg_value1,
                            loss_adv_target_value1, loss_D_value1, i_epotch,
                            args.epotch_size))
                with open(args.file_dir, 'a') as f:
                    f.write(
                        'epoch = {5:2d}/{6:2d}, iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f}, loss_adv1 = {3:.3f}, loss_D1 = {4:.3f}\n'
                        .format(i_iter, iter_num, loss_seg_value1,
                                loss_adv_target_value1, loss_D_value1,
                                i_epotch, args.epotch_size))

                loss_seg_value1, loss_adv_target_value1, loss_D_value1 = 0, 0, 0

            if i_iter == iter_num - 1 or (i_iter % args.save_pred_every == 0
                                          and i_iter != 0):
                print('taking snapshot ...')
                with open(args.file_dir, 'a') as f:
                    f.write('taking snapshot ...\n')
                torch.save(
                    model.state_dict(),
                    osp.join(
                        args.snapshot_dir,
                        'sal_' + str(i_epotch) + '_' + str(i_iter) + '.pth'))
                torch.save(
                    model_D1.state_dict(),
                    osp.join(
                        args.snapshot_dir, 'sal_' + str(i_epotch) + '_' +
                        str(i_iter) + '_D1.pth'))

        # decay both learning rates 10x after epoch 7 (recreating the
        # optimizers below also resets their internal Adam state)
        if i_epotch == 7:
            args.learning_rate = args.learning_rate * 0.1
            args.learning_rate_D = args.learning_rate_D * 0.1
            optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                          model.parameters()),
                                   lr=args.learning_rate,
                                   weight_decay=args.weight_decay)
            optimizer_D1 = optim.Adam(model_D1.parameters(),
                                      lr=args.learning_rate_D,
                                      betas=(0.9, 0.99))

    # end
    with open(args.file_dir, 'a') as f:
        f.write('end time: ' + str(datetime.now()) + '\n')
Code example #3
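# As in the previous example, the imports are omitted in the original; a
# minimal matching set is sketched below. FCDiscriminator, bce2d, args,
# device, build_model, weights_init and get_loader are project symbols
# assumed to be in scope.
import os
import os.path as osp
from datetime import datetime

import torch
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
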
def main():

    # make dir
    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    run = 0
    while os.path.exists("%s/run-%d" % (args.snapshot_dir, run)):
        run += 1
    os.mkdir("%s/run-%d" % (args.snapshot_dir, run))
    os.mkdir("%s/run-%d/models" % (args.snapshot_dir, run))
    args.file_dir = "%s/run-%d/file.txt" % (args.snapshot_dir, run)
    args.snapshot_dir = "%s/run-%d/models" % (args.snapshot_dir, run)

    # cuDNN auto-tuning ("black-magic" speed-up)
    cudnn.enabled = True
    cudnn.benchmark = True

    # create the model
    model = build_model()
    model.to(device)
    model.train()
    model.apply(weights_init)
    model.load_state_dict(torch.load(args.restore_from))
    # model.base.load_pretrained_model(torch.load(args.pretrained_model))

    # create discriminators
    model_D1 = FCDiscriminator(num_classes=1).to(device)
    model_D2 = FCDiscriminator(num_classes=1).to(device)
    model_D1.train()
    model_D2.train()
    model_D1.apply(weights_init)
    model_D2.apply(weights_init)
    # model_D1.load_state_dict(torch.load(args.D_restore_from))
    # model_D2.load_state_dict(torch.load(args.D_restore_from))

    # create optimizer
    optimizer = optim.RMSprop(filter(lambda p: p.requires_grad,
                                     model.parameters()),
                              lr=args.learning_rate)  # optimizer for the whole model
    optimizer.zero_grad()
    optimizer_D1 = optim.RMSprop(model_D1.parameters(),
                                 lr=args.learning_rate_D)
    optimizer_D1.zero_grad()
    optimizer_D2 = optim.RMSprop(model_D2.parameters(),
                                 lr=args.learning_rate_D)
    optimizer_D2.zero_grad()

    # start time
    with open(args.file_dir, 'a') as f:
        f.write('start time: ' + str(datetime.now()) + '\n\n')

        f.write('learning rate: ' + str(args.learning_rate) + '\n')
        f.write('learning rate D: ' + str(args.learning_rate_D) + '\n')
        f.write('weight decay: ' + str(args.weight_decay) + '\n')
        f.write('lambda_adv_target2: ' + str(args.lambda_adv_target2) + '\n\n')

        f.write('epoch size: ' + str(args.epotch_size) + '\n')
        f.write('batch size: ' + str(args.batch_size) + '\n')
        f.write('iter size: ' + str(args.iter_size) + '\n')
        f.write('num steps: ' + str(args.num_steps) + '\n\n')

    # labels for adversarial training (markers for the two domains)
    salLabel = 0
    edgeLabel = 1

    picloader = get_loader(args)
    iter_num = len(picloader.dataset) // args.batch_size
    aveGrad = 0

    for i_epotch in range(args.epotch_size):
        loss_seg_value1 = 0
        loss_seg_value2 = 0
        loss_adv_target_value1 = 0
        loss_adv_target_value2 = 0
        loss_D_value1 = 0
        loss_D_value2 = 0
        model.zero_grad()

        for i_iter, data_batch in enumerate(picloader):

            sal_image, sal_label, edge_image, edge_label = data_batch[
                'sal_image'], data_batch['sal_label'], data_batch[
                    'edge_image'], data_batch['edge_label']
            if (sal_image.size(2) != sal_label.size(2)) or (
                    sal_image.size(3) != sal_label.size(3)
                    or edge_image.size(2) != edge_label.size(2)) or (
                        edge_image.size(3) != edge_label.size(3)):
                print('IMAGE ERROR, PASSING')
                with open(args.file_dir, 'a') as f:
                    f.write('IMAGE ERROR, PASSING\n')
                continue

            sal_image = sal_image.to(device)
            sal_label = sal_label.to(device)
            edge_image = edge_image.to(device)
            edge_label = edge_label.to(device)

            s_sal_pred = model(sal_image, mode=1)
            s_edge_pred = model(edge_image, mode=1)
            e_sal_pred = model(sal_image, mode=0)
            e_edge_pred = model(edge_image, mode=0)

            # train G (with D frozen)
            for param in model_D1.parameters():
                param.requires_grad = False
            for param in model_D2.parameters():
                param.requires_grad = False

            # sal
            sal_loss_fuse = F.binary_cross_entropy_with_logits(s_sal_pred,
                                                               sal_label,
                                                               reduction='sum')
            sal_loss = sal_loss_fuse / (args.iter_size * args.batch_size)
            loss_seg_value1 += sal_loss.data

            sal_loss.backward()

            # WGAN-style adversarial loss for G: drive the critic score on the
            # edge-image output toward the saliency (source) side
            loss_adv_target1 = torch.mean(model_D1(s_edge_pred))
            sd_loss = loss_adv_target1 / (args.iter_size * args.batch_size)
            loss_adv_target_value1 += sd_loss.data  # for logging only

            sd_loss = sd_loss * args.lambda_adv_target2
            sd_loss.backward()

            # edge
            edge_loss_fuse = bce2d(e_edge_pred[0], edge_label, reduction='sum')
            edge_loss_part = []
            for ix in e_edge_pred[1]:
                edge_loss_part.append(bce2d(ix, edge_label, reduction='sum'))
            edge_loss = (edge_loss_fuse + sum(edge_loss_part)) / (
                args.iter_size * args.batch_size)
            loss_seg_value2 += edge_loss.data

            edge_loss.backward()

            # same WGAN-style loss for G on D2, averaged over the fused output
            # and the side outputs
            loss_adv_target2 = -torch.mean(model_D2(e_sal_pred[0]))
            for ix in e_sal_pred[1]:
                loss_adv_target2 += -torch.mean(model_D2(ix))
            ed_loss = loss_adv_target2 / (args.iter_size * args.batch_size) / (
                len(e_sal_pred[1]) + 1)
            loss_adv_target_value2 += ed_loss.data  # for logging only

            ed_loss = ed_loss * args.lambda_adv_target2
            ed_loss.backward()

            # train D
            for param in model_D1.parameters():
                param.requires_grad = True
            for param in model_D2.parameters():
                param.requires_grad = True

            s_sal_pred = s_sal_pred.detach()
            s_edge_pred = s_edge_pred.detach()
            e_sal_pred = [
                e_sal_pred[0].detach(), [x.detach() for x in e_sal_pred[1]]
            ]
            e_edge_pred = [
                e_edge_pred[0].detach(), [x.detach() for x in e_edge_pred[1]]
            ]

            # sal
            ss_loss = torch.mean(model_D1(s_sal_pred))
            ss_Loss = ss_loss / (args.iter_size * args.batch_size)
            loss_D_value1 += ss_Loss.data

            ss_Loss.backward()

            se_loss = -torch.mean(model_D1(s_edge_pred))
            se_Loss = se_loss / (args.iter_size * args.batch_size)
            loss_D_value1 += se_Loss.data

            se_Loss.backward()

            # edge
            es_loss = torch.mean(model_D2(e_sal_pred[0]))
            for ix in e_sal_pred[1]:
                es_loss += torch.mean(model_D2(ix))
            es_Loss = es_loss / (args.iter_size *
                                 args.batch_size) / (len(e_sal_pred[1]) + 1)
            loss_D_value2 += es_Loss.data

            es_Loss.backward()

            ee_loss = -torch.mean(model_D2(e_edge_pred[0]))
            for ix in e_edge_pred[1]:
                ee_loss += -torch.mean(model_D2(ix))
            ee_Loss = ee_loss / (args.iter_size *
                                 args.batch_size) / (len(e_edge_pred[1]) + 1)
            loss_D_value2 += ee_Loss.data

            ee_Loss.backward()

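            # gradient accumulation for G; the critics are updated every
            # iteration (see below)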
            aveGrad += 1
            if aveGrad % args.iter_size == 0:
                optimizer.step()
                optimizer.zero_grad()
                aveGrad = 0

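            # WGAN-style critic update: step every iteration and clamp the
            # critic weights to [-clip_value, clip_value]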
            optimizer_D1.step()
            for p in model_D1.parameters():
                p.data.clamp_(-args.clip_value, args.clip_value)
            optimizer_D1.zero_grad()
            optimizer_D2.step()
            for p in model_D2.parameters():
                p.data.clamp_(-args.clip_value, args.clip_value)
            optimizer_D2.zero_grad()

            if i_iter % (args.show_every // args.batch_size) == 0:
                print(
                    'epoch = {5:2d}/{6:2d}, iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f}, loss_seg2 = {7:.3f}, loss_adv1 = {3:.3f}, loss_adv2 = {8:.3f}, loss_D1 = {4:.3f}, loss_D2 = {9:.3f}'
                    .format(i_iter, iter_num, loss_seg_value1,
                            loss_adv_target_value1, loss_D_value1, i_epotch,
                            args.epotch_size, loss_seg_value2,
                            loss_adv_target_value2, loss_D_value2))
                with open(args.file_dir, 'a') as f:
                    f.write(
                        'epoch = {5:2d}/{6:2d}, iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f}, loss_seg2 = {7:.3f}, loss_adv1 = {3:.3f}, loss_adv2 = {8:.3f}, loss_D1 = {4:.3f}, loss_D2 = {9:.3f}\n'
                        .format(i_iter, iter_num, loss_seg_value1,
                                loss_adv_target_value1, loss_D_value1,
                                i_epotch, args.epotch_size, loss_seg_value2,
                                loss_adv_target_value2, loss_D_value2))

                loss_seg_value1, loss_adv_target_value1, loss_D_value1, loss_seg_value2, loss_adv_target_value2, loss_D_value2 = 0, 0, 0, 0, 0, 0

            if i_iter == iter_num - 1 or (i_iter % args.save_pred_every == 0
                                          and i_iter != 0):
                print('taking snapshot ...')
                with open(args.file_dir, 'a') as f:
                    f.write('taking snapshot ...\n')
                # note: fixed filenames, so each snapshot overwrites the last
                torch.save(model.state_dict(),
                           osp.join(args.snapshot_dir, 'sal_.pth'))
                torch.save(model_D1.state_dict(),
                           osp.join(args.snapshot_dir, 'sal_D1.pth'))
                torch.save(model_D2.state_dict(),
                           osp.join(args.snapshot_dir, 'sal_D2.pth'))

        # decay both learning rates 10x after epoch 7 (recreating the
        # optimizers below also resets their internal state)
        if i_epotch == 7:
            args.learning_rate = args.learning_rate * 0.1
            args.learning_rate_D = args.learning_rate_D * 0.1
            optimizer = optim.RMSprop(filter(lambda p: p.requires_grad,
                                             model.parameters()),
                                      lr=args.learning_rate,
                                      weight_decay=args.weight_decay)
            optimizer_D1 = optim.RMSprop(model_D1.parameters(),
                                         lr=args.learning_rate_D)
            optimizer_D2 = optim.RMSprop(model_D2.parameters(),
                                         lr=args.learning_rate_D)

    # end
    with open(args.file_dir, 'a') as f:
        f.write('end time: ' + str(datetime.now()) + '\n')