Example #1
0
    def build_model(self):
        """Create ``self.net``, initialise/load its weights and set up Adam.

        Reads from ``self.config``:
          * ``arch`` selects the network passed to ``build_model``.
          * ``cuda``, when truthy, moves the network to the GPU.
          * ``load`` is a checkpoint path; ``''`` means "instead load
            ``pretrained_model`` into ``self.net.base``".
          * ``lr`` and ``wd`` configure the optimizer.

        Side effects: sets ``self.net``, ``self.lr``, ``self.wd`` and
        ``self.optimizer``, and prints the network structure.
        """
        cfg = self.config
        self.net = build_model(cfg.arch)
        if cfg.cuda:
            self.net = self.net.cuda()

        # eval() is deliberate (original note: "use_global_stats = True");
        # presumably meant to keep normalisation layers on stored
        # statistics -- confirm.
        self.net.eval()
        self.net.apply(weights_init)

        # A checkpoint path restores the whole network; otherwise only the
        # ``base`` sub-module receives the pretrained weights.
        if cfg.load != '':
            self.net.load_state_dict(torch.load(cfg.load))
        else:
            self.net.base.load_pretrained_model(
                torch.load(cfg.pretrained_model))

        self.lr, self.wd = cfg.lr, cfg.wd

        # Only parameters that require gradients are handed to Adam.
        trainable = [p for p in self.net.parameters() if p.requires_grad]
        self.optimizer = Adam(trainable, lr=self.lr, weight_decay=self.wd)
        self.print_network(self.net, 'PoolNet Structure')
Example #2
0
 def build_model(self):
     """Build the network and an Adam optimizer with fixed hyperparameters.

     ``self.cuda`` decides whether the network is moved to the GPU
     (NOTE(review): if ``cuda`` is a method rather than a flag, this test is
     always true -- confirm).  Sets ``self.net``, ``self.lr`` (5e-5),
     ``self.wd`` (0.0005) and ``self.optimizer``.
     """
     self.net = build_model()
     if self.cuda:
         self.net = self.net.cuda()

     # eval() is deliberate (original note: "use_global_stats = True").
     self.net.eval()
     self.net.apply(weights_init)
     # NOTE: loading of a pretrained backbone is disabled in this variant.

     self.lr, self.wd = 5e-5, 0.0005
     trainable = [p for p in self.net.parameters() if p.requires_grad]
     self.optimizer = Adam(trainable, lr=self.lr, weight_decay=self.wd)
Example #3
0
    def build_model(self):
        """Build the PoolNet network, load its weights and create Adam.

        Configuration comes from ``self.config`` (``arch``, ``cuda``,
        ``load``, ``pretrained_model``, ``lr``, ``wd``).  When ``load`` is
        the empty string, pretrained weights are loaded into
        ``self.net.base``; otherwise the full state dict at ``load`` is
        restored.  Sets ``self.net``, ``self.lr``, ``self.wd`` and
        ``self.optimizer`` and prints the network structure.
        """
        conf = self.config
        self.net = build_model(conf.arch)
        if conf.cuda:
            self.net = self.net.cuda()

        # Intentionally left in eval mode (original note: use_global_stats = True).
        self.net.eval()
        self.net.apply(weights_init)

        use_pretrained_base = conf.load == ''
        if not use_pretrained_base:
            self.net.load_state_dict(torch.load(conf.load))
        else:
            self.net.base.load_pretrained_model(torch.load(conf.pretrained_model))

        self.lr = conf.lr
        self.wd = conf.wd
        params = (p for p in self.net.parameters() if p.requires_grad)
        self.optimizer = Adam(params, lr=self.lr, weight_decay=self.wd)
        self.print_network(self.net, 'PoolNet Structure')
Example #4
0
def main():
    """Run the restored model over the test split and save saliency maps.

    Uses the module-level ``args`` (``snapshot_dir``, ``restore_from``) and
    ``device``.  For each batch the ``mode=1`` output is passed through a
    sigmoid, scaled by 255 and written to
    ``args.snapshot_dir + 'pred/wgan/1/<i>.jpg'``.  Finally prints the
    throughput in loader iterations per second.
    """
    gtdir = args.snapshot_dir + 'gt/1/'
    preddir = args.snapshot_dir + 'pred/wgan/1/'

    # Output directories. gtdir is still created, but this function does not
    # write ground-truth maps into it.
    os.makedirs(gtdir, exist_ok=True)
    os.makedirs(preddir, exist_ok=True)

    # Enable cuDNN and its kernel autotuner (original note: "xuan xue you
    # hua", i.e. an empirical speed tweak).
    cudnn.enabled = True
    cudnn.benchmark = True

    model = build_model()
    model.to(device)
    model.apply(weights_init)
    # map_location lets a GPU-saved checkpoint load on whatever device is in
    # use; load_state_dict copies the values into the model's parameters.
    model.load_state_dict(torch.load(args.restore_from, map_location=device))
    # Inference must run in eval mode: in train mode BatchNorm layers (if
    # any) normalise with per-batch statistics and keep updating their
    # running averages even under torch.no_grad(), and Dropout stays active.
    model.eval()

    picloader = get_loader(args, mode='test')
    num_iters = len(picloader)

    start = time.time()
    with torch.no_grad():
        for i_iter, data_batch in enumerate(picloader):
            if i_iter % 50 == 0:
                print(i_iter)
            sal_image = data_batch['sal_image'].to(device)
            preds = model(sal_image, mode=1)
            pred = np.squeeze(torch.sigmoid(preds).cpu().numpy())
            cv2.imwrite(os.path.join(preddir, str(i_iter) + '.jpg'), 255 * pred)
    elapsed = time.time() - start

    # Iterations per second (equal to images/s when each batch holds one
    # image). The previous expression, elapsed / iterations, printed seconds
    # per iteration and divided by zero on an empty loader.
    fps = num_iters / elapsed if elapsed > 0 else 0.0
    print('FPS = ', fps)
Example #5
0
def _append_log(path, text):
    """Append ``text`` verbatim to the text file at ``path``."""
    with open(path, 'a') as f:
        f.write(text)


def _make_run_dirs(snapshot_dir):
    """Create ``<snapshot_dir>/run-<k>/models`` for the first unused ``k``.

    Returns ``(file_dir, models_dir)``: the path of the run's ``file.txt``
    log and the path of its ``models`` directory.
    """
    os.makedirs(snapshot_dir, exist_ok=True)
    run = 0
    while os.path.exists("%s/run-%d" % (snapshot_dir, run)):
        run += 1
    os.mkdir("%s/run-%d" % (snapshot_dir, run))
    os.mkdir("%s/run-%d/models" % (snapshot_dir, run))
    return ("%s/run-%d/file.txt" % (snapshot_dir, run),
            "%s/run-%d/models" % (snapshot_dir, run))


def _set_requires_grad(models, flag):
    """Set ``requires_grad`` to ``flag`` on every parameter of each model."""
    for m in models:
        for param in m.parameters():
            param.requires_grad = flag


def _clip_weights(model, clip_value):
    """Clamp every parameter of ``model`` in place to ``[-clip_value, clip_value]``."""
    for p in model.parameters():
        p.data.clamp_(-clip_value, clip_value)


def _set_lr(optimizer, lr):
    """Set the learning rate of every param group of ``optimizer`` in place."""
    for group in optimizer.param_groups:
        group['lr'] = lr


def _save_snapshot(snapshot_dir, log_path, model, model_D1, model_D2):
    """Save the generator and both critics' state dicts into ``snapshot_dir``."""
    print('taking snapshot ...')
    _append_log(log_path, 'taking snapshot ...\n')
    torch.save(model.state_dict(), osp.join(snapshot_dir, 'sal_.pth'))
    torch.save(model_D1.state_dict(), osp.join(snapshot_dir, 'sal_D1.pth'))
    torch.save(model_D2.state_dict(), osp.join(snapshot_dir, 'sal_D2.pth'))


def main():
    """Adversarially fine-tune the model against two WGAN-style critics.

    Uses the module-level ``args`` and ``device``.  Each iteration runs the
    model on the saliency batch and the edge batch in both ``mode=1`` and
    ``mode=0``.

    Generator update (critics frozen): summed BCE of the ``mode=1`` output
    against saliency labels, ``bce2d`` of the ``mode=0`` outputs against
    edge labels, and adversarial terms scored by ``model_D1`` (``mode=1``
    outputs) and ``model_D2`` (``mode=0`` outputs).  The adversarial terms
    are weighted by ``args.lambda_adv_target2``.  Gradients are accumulated
    and ``optimizer.step()`` runs every ``args.iter_size`` iterations.

    Critic update: every iteration, on detached predictions, followed by
    weight clipping to ``[-args.clip_value, args.clip_value]``.

    A fresh ``run-<k>`` directory is created under ``args.snapshot_dir``.
    ``args.file_dir`` and ``args.snapshot_dir`` are rebound to its log file
    and ``models`` directory.  All learning rates are multiplied by 0.1
    after epoch index 7.
    """
    args.file_dir, args.snapshot_dir = _make_run_dirs(args.snapshot_dir)

    # Enable cuDNN and its kernel autotuner (original note: "xuan xue you
    # hua", i.e. an empirical speed tweak).
    cudnn.enabled = True
    cudnn.benchmark = True

    # Generator: restored from a checkpoint after default init.
    model = build_model()
    model.to(device)
    model.train()
    model.apply(weights_init)
    model.load_state_dict(torch.load(args.restore_from))

    # Critics: D1 scores mode=1 outputs, D2 scores mode=0 outputs.
    model_D1 = FCDiscriminator(num_classes=1).to(device)
    model_D2 = FCDiscriminator(num_classes=1).to(device)
    for critic in (model_D1, model_D2):
        critic.train()
        critic.apply(weights_init)

    # Weight decay is applied from the start. Previously it only took effect
    # after the epoch-7 LR decay, although it is logged as a run setting.
    optimizer = optim.RMSprop(filter(lambda p: p.requires_grad,
                                     model.parameters()),
                              lr=args.learning_rate,
                              weight_decay=args.weight_decay)
    optimizer.zero_grad()
    optimizer_D1 = optim.RMSprop(model_D1.parameters(),
                                 lr=args.learning_rate_D)
    optimizer_D1.zero_grad()
    optimizer_D2 = optim.RMSprop(model_D2.parameters(),
                                 lr=args.learning_rate_D)
    optimizer_D2.zero_grad()

    # Run header (start time and hyper-parameters).
    header = [
        'strat time: ' + str(datetime.now()) + '\n\n',
        'learning rate: ' + str(args.learning_rate) + '\n',
        'learning rate D: ' + str(args.learning_rate_D) + '\n',
        'wight decay: ' + str(args.weight_decay) + '\n',
        'lambda_adv_target2: ' + str(args.lambda_adv_target2) + '\n\n',
        'eptch size: ' + str(args.epotch_size) + '\n',
        'batch size: ' + str(args.batch_size) + '\n',
        'iter size: ' + str(args.iter_size) + '\n',
        'num steps: ' + str(args.num_steps) + '\n\n',
    ]
    _append_log(args.file_dir, ''.join(header))

    picloader = get_loader(args)
    iter_num = len(picloader.dataset) // args.batch_size
    # Index of the loader's real last batch (also covers a trailing partial batch).
    last_iter = len(picloader) - 1
    # Guard against show_every < batch_size, which made the modulus zero.
    show_every = max(1, args.show_every // args.batch_size)
    # Per-sample normaliser shared by every loss term.
    norm = args.iter_size * args.batch_size
    lr_decay_epoch = 7
    ave_grad = 0

    for i_epotch in range(args.epotch_size):
        loss_seg_value1 = loss_seg_value2 = 0.0
        loss_adv_target_value1 = loss_adv_target_value2 = 0.0
        loss_D_value1 = loss_D_value2 = 0.0
        model.zero_grad()
        # Gradients just discarded must not count toward the next iter_size group.
        ave_grad = 0

        for i_iter, data_batch in enumerate(picloader):
            sal_image = data_batch['sal_image']
            sal_label = data_batch['sal_label']
            edge_image = data_batch['edge_image']
            edge_label = data_batch['edge_label']
            if (sal_image.size(2) != sal_label.size(2)
                    or sal_image.size(3) != sal_label.size(3)
                    or edge_image.size(2) != edge_label.size(2)
                    or edge_image.size(3) != edge_label.size(3)):
                print('IMAGE ERROR, PASSING```')
                _append_log(args.file_dir, 'IMAGE ERROR, PASSING```\n')
                continue

            sal_image = sal_image.to(device)
            sal_label = sal_label.to(device)
            edge_image = edge_image.to(device)
            edge_label = edge_label.to(device)

            s_sal_pred = model(sal_image, mode=1)
            s_edge_pred = model(edge_image, mode=1)
            e_sal_pred = model(sal_image, mode=0)
            e_edge_pred = model(edge_image, mode=0)

            # ---- generator update: critics frozen ----
            _set_requires_grad((model_D1, model_D2), False)

            # mode=1 head: supervised saliency loss.
            sal_loss = F.binary_cross_entropy_with_logits(
                s_sal_pred, sal_label, reduction='sum') / norm
            loss_seg_value1 += sal_loss.item()
            sal_loss.backward()

            # mode=1 head: adversarial term on edge-image predictions (D1).
            sd_loss = torch.mean(model_D1(s_edge_pred)) / norm
            loss_adv_target_value1 += sd_loss.item()
            (sd_loss * args.lambda_adv_target2).backward()

            # mode=0 head: supervised edge loss on the fused and side outputs.
            edge_loss_fuse = bce2d(e_edge_pred[0], edge_label, reduction='sum')
            edge_loss_side = sum(bce2d(ix, edge_label, reduction='sum')
                                 for ix in e_edge_pred[1])
            edge_loss = (edge_loss_fuse + edge_loss_side) / norm
            loss_seg_value2 += edge_loss.item()
            edge_loss.backward()

            # mode=0 head: adversarial term on saliency-image predictions
            # (D2), averaged over the fused and side outputs.
            loss_adv_target2 = -torch.mean(model_D2(e_sal_pred[0]))
            for ix in e_sal_pred[1]:
                loss_adv_target2 += -torch.mean(model_D2(ix))
            ed_loss = loss_adv_target2 / norm / (len(e_sal_pred[1]) + 1)
            loss_adv_target_value2 += ed_loss.item()
            (ed_loss * args.lambda_adv_target2).backward()

            # ---- critic update on detached predictions ----
            _set_requires_grad((model_D1, model_D2), True)

            s_sal_pred = s_sal_pred.detach()
            s_edge_pred = s_edge_pred.detach()
            e_sal_pred = [
                e_sal_pred[0].detach(), [x.detach() for x in e_sal_pred[1]]
            ]
            e_edge_pred = [
                e_edge_pred[0].detach(), [x.detach() for x in e_edge_pred[1]]
            ]

            # D1: mean(D1(sal-image output)) - mean(D1(edge-image output)).
            ss_loss = torch.mean(model_D1(s_sal_pred)) / norm
            loss_D_value1 += ss_loss.item()
            ss_loss.backward()

            se_loss = -torch.mean(model_D1(s_edge_pred)) / norm
            loss_D_value1 += se_loss.item()
            se_loss.backward()

            # D2: same structure over fused + side outputs of the mode=0 head.
            es_loss = torch.mean(model_D2(e_sal_pred[0]))
            for ix in e_sal_pred[1]:
                es_loss += torch.mean(model_D2(ix))
            es_loss = es_loss / norm / (len(e_sal_pred[1]) + 1)
            loss_D_value2 += es_loss.item()
            es_loss.backward()

            ee_loss = -torch.mean(model_D2(e_edge_pred[0]))
            for ix in e_edge_pred[1]:
                ee_loss += -torch.mean(model_D2(ix))
            ee_loss = ee_loss / norm / (len(e_edge_pred[1]) + 1)
            loss_D_value2 += ee_loss.item()
            ee_loss.backward()

            # Generator steps once every iter_size accumulated iterations.
            ave_grad += 1
            if ave_grad % args.iter_size == 0:
                optimizer.step()
                optimizer.zero_grad()
                ave_grad = 0

            # Critics step every iteration, then are weight-clipped.
            optimizer_D1.step()
            _clip_weights(model_D1, args.clip_value)
            optimizer_D1.zero_grad()
            optimizer_D2.step()
            _clip_weights(model_D2, args.clip_value)
            optimizer_D2.zero_grad()

            if i_iter % show_every == 0:
                status = (
                    'epotch = {5:2d}/{6:2d}, iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f}, loss_seg2 = {7:.3f}, loss_adv1 = {3:.3f}, loss_adv2 = {8:.3f}, loss_D1 = {4:.3f}, loss_D2 = {9:.3f}'
                    .format(i_iter, iter_num, loss_seg_value1,
                            loss_adv_target_value1, loss_D_value1, i_epotch,
                            args.epotch_size, loss_seg_value2,
                            loss_adv_target_value2, loss_D_value2))
                print(status)
                _append_log(args.file_dir, status + '\n')

                loss_seg_value1 = loss_seg_value2 = 0.0
                loss_adv_target_value1 = loss_adv_target_value2 = 0.0
                loss_D_value1 = loss_D_value2 = 0.0

            if i_iter == last_iter or (i_iter % args.save_pred_every == 0
                                       and i_iter != 0):
                _save_snapshot(args.snapshot_dir, args.file_dir, model,
                               model_D1, model_D2)

        if i_epotch == lr_decay_epoch:
            args.learning_rate = args.learning_rate * 0.1
            args.learning_rate_D = args.learning_rate_D * 0.1
            # Scale the LRs in place so RMSprop keeps its running averages.
            # Recreating the optimizers reset them to zero, which made the
            # first post-decay steps roughly 10x the learning rate per element.
            _set_lr(optimizer, args.learning_rate)
            _set_lr(optimizer_D1, args.learning_rate_D)
            _set_lr(optimizer_D2, args.learning_rate_D)

    _append_log(args.file_dir, 'end time: ' + str(datetime.now()) + '\n')