Example #1
def main():
    net = DPNet().cuda()
    # net = nn.DataParallel(net, device_ids=[0])

    print('load snapshot \'%s\' for testing' % args['snapshot'])
    net.load_state_dict(
        torch.load(os.path.join(ckpt_path, exp_name,
                                args['snapshot'] + '.pth')))
    net.eval()

    with torch.no_grad():

        results = {}

        for name, root in to_test.items():

            precision_record = [AvgMeter() for _ in range(256)]
            recall_record = [AvgMeter() for _ in range(256)]
            mae_record = AvgMeter()
            time_record = AvgMeter()

            img_list = [
                os.path.splitext(f)[0] for f in os.listdir(root)
                if f.endswith('.jpg')
            ]

            for idx, img_name in enumerate(img_list):
                print('predicting for %s: %d / %d' % (name, idx + 1,
                                                      len(img_list)))
                check_mkdir(
                    os.path.join(
                        ckpt_path, exp_name,
                        '(%s) %s_%s' % (exp_name, name, args['snapshot'])))

                start = time.time()
                img = Image.open(os.path.join(root, img_name +
                                              '.jpg')).convert('RGB')
                # no Variable/volatile needed under torch.no_grad()
                img_var = img_transform(img).unsqueeze(0).cuda()
                prediction = net(img_var)
                W, H = img.size
                prediction = F.interpolate(prediction, size=(H, W),
                                           mode='bilinear', align_corners=True)
                prediction = np.array(to_pil(prediction.data.squeeze(0).cpu()))

                if args['crf_refine']:
                    prediction = crf_refine(np.array(img), prediction)

                end = time.time()
                time_record.update(end - start)

                # ground-truth masks are assumed to be stored as .png,
                # as in the other examples here
                gt = np.array(
                    Image.open(os.path.join(root,
                                            img_name + '.png')).convert('L'))
                precision, recall, mae = cal_precision_recall_mae(
                    prediction, gt)
                for pidx, pdata in enumerate(zip(precision, recall)):
                    p, r = pdata
                    precision_record[pidx].update(p)
                    recall_record[pidx].update(r)
                mae_record.update(mae)

                if args['save_results']:
                    Image.fromarray(prediction).save(
                        os.path.join(
                            ckpt_path, exp_name,
                            '(%s) %s_%s' % (exp_name, name, args['snapshot']),
                            img_name + '.jpg'))

            max_fmeasure, mean_fmeasure = cal_fmeasure_both(
                [precord.avg for precord in precision_record],
                [rrecord.avg for rrecord in recall_record])

            results[name] = {
                'max_fmeasure': max_fmeasure,
                'mae': mae_record.avg,
                'mean_fmeasure': mean_fmeasure
            }

        print('test results:')
        print(results)
        print('Running time %.6f \n' % time_record.avg)

        with open('dpnet_result', 'a') as f:
            f.write('\n%s \n %s: \n' % (exp_name, exp_predict))
            f.write('Running time %.6f \n' % time_record.avg)
            for name, value in results.items():
                f.write(
                    '%s: max_fmeasure: %.10f, mae: %.10f, mean_fmeasure: %.10f\n'
                    % (name, value['max_fmeasure'], value['mae'],
                       value['mean_fmeasure']))
Example #2
def train(net, optimizer):
    curr_iter = 1
    base_lr = args['lr']  # fallback so the log line below is defined without poly_train

    for epoch in range(args['last_epoch'] + 1,
                       args['last_epoch'] + 1 + args['epoch_num']):
        loss_4_record, loss_3_record, loss_2_record, loss_1_record = \
            AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter()
        loss_c_record, loss_b_record, loss_o_record, loss_record = \
            AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter()

        train_iterator = tqdm(train_loader, total=len(train_loader))
        for data in train_iterator:
            if args['poly_train']:
                base_lr = args['lr'] * (
                    1 - float(curr_iter) /
                    (args['epoch_num'] * len(train_loader)))**args['lr_decay']
                optimizer.param_groups[0]['lr'] = 2 * base_lr
                optimizer.param_groups[1]['lr'] = 1 * base_lr

            inputs, labels, edges = data
            batch_size = inputs.size(0)
            inputs = inputs.cuda(device_ids[0])
            labels = labels.cuda(device_ids[0])
            edges = edges.cuda(device_ids[0])

            optimizer.zero_grad()

            predict_4, predict_3, predict_2, predict_1, predict_c, predict_b, predict_o = net(
                inputs)

            loss_4 = L.lovasz_hinge(predict_4, labels)
            loss_3 = L.lovasz_hinge(predict_3, labels)
            loss_2 = L.lovasz_hinge(predict_2, labels)
            loss_1 = L.lovasz_hinge(predict_1, labels)
            loss_c = L.lovasz_hinge(predict_c, labels)
            loss_b = bce(predict_b, edges)
            loss_o = 2 * L.lovasz_hinge(predict_o, labels)

            loss = loss_4 + loss_3 + loss_2 + loss_1 + loss_c + loss_b + loss_o

            loss.backward()

            optimizer.step()

            loss_record.update(loss.item(), batch_size)
            loss_4_record.update(loss_4.item(), batch_size)
            loss_3_record.update(loss_3.item(), batch_size)
            loss_2_record.update(loss_2.item(), batch_size)
            loss_1_record.update(loss_1.item(), batch_size)
            loss_c_record.update(loss_c.item(), batch_size)
            loss_b_record.update(loss_b.item(), batch_size)
            loss_o_record.update(loss_o.item(), batch_size)

            if curr_iter % 50 == 0:
                writer.add_scalar('loss', loss, curr_iter)
                writer.add_scalar('loss_4', loss_4, curr_iter)
                writer.add_scalar('loss_3', loss_3, curr_iter)
                writer.add_scalar('loss_2', loss_2, curr_iter)
                writer.add_scalar('loss_1', loss_1, curr_iter)
                writer.add_scalar('loss_c', loss_c, curr_iter)
                writer.add_scalar('loss_b', loss_b, curr_iter)
                writer.add_scalar('loss_o', loss_o, curr_iter)

            log = '[%3d], [%5d], [%.6f], [%.5f], [L4: %.5f], [L3: %.5f], ' \
                  '[L2: %.5f], [L1: %.5f], [Lc: %.5f], [Lb: %.5f], [Lo: %.5f]' % \
                  (epoch, curr_iter, base_lr, loss_record.avg, loss_4_record.avg, loss_3_record.avg, loss_2_record.avg,
                   loss_1_record.avg, loss_c_record.avg, loss_b_record.avg, loss_o_record.avg)
            train_iterator.set_description(log)
            open(log_path, 'a').write(log + '\n')

            curr_iter += 1

        if epoch in args['save_point']:
            net.cpu()
            torch.save(net.module.state_dict(),
                       os.path.join(ckpt_path, exp_name, '%d.pth' % epoch))
            net.cuda(device_ids[0])

        if epoch >= args['epoch_num']:
            net.cpu()
            torch.save(net.module.state_dict(),
                       os.path.join(ckpt_path, exp_name, '%d.pth' % epoch))
            print("Optimization Have Done!")
            return
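
The 'poly' branch above decays the base learning rate polynomially with training progress, with biases (param group 0) at twice the rate of weights. The same formula as a standalone sketch (names are illustrative, not from the repository):

def poly_lr(base_lr, curr_iter, max_iter, power):
    # lr = base_lr * (1 - progress) ** power, matching the poly_train branch
    return base_lr * (1 - float(curr_iter) / max_iter) ** power

# usage, mirroring the loop above:
# lr = poly_lr(args['lr'], curr_iter, args['epoch_num'] * len(train_loader), args['lr_decay'])
# optimizer.param_groups[0]['lr'] = 2 * lr
# optimizer.param_groups[1]['lr'] = lr
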
Example #3
def main():
    net = R3Net(motion='', se_layer=False, attention=True, dilation=True, basic_model='resnet50')

    print('load snapshot \'%s\' for testing' % args['snapshot'])
    net.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'] + '.pth'), map_location='cuda:2'))
    net.eval()
    net.cuda()
    results = {}

    with torch.no_grad():

        for name, root in to_test.items():

            precision_record = [AvgMeter() for _ in range(256)]
            recall_record = [AvgMeter() for _ in range(256)]
            mae_record = AvgMeter()

            if args['save_results']:
                check_mkdir(os.path.join(ckpt_path, exp_name, '(%s) %s_%s' % (exp_name, name, args['snapshot'])))
            img_list = [i_id.strip() for i_id in open(imgs_path)]
            # img_list = [os.path.splitext(f)[0] for f in os.listdir(root) if f.endswith('.jpg')]
            for idx, img_name in enumerate(img_list):
                print('predicting for %s: %d / %d' % (name, idx + 1, len(img_list)))

                if name == 'VOS':
                    img = Image.open(os.path.join(root, img_name + '.png')).convert('RGB')
                else:
                    img = Image.open(os.path.join(root, img_name + '.jpg')).convert('RGB')
                shape = img.size
                img = img.resize(args['input_size'])
                img_var = img_transform(img).unsqueeze(0).cuda()  # no Variable/volatile under torch.no_grad()
                start = time.time()
                prediction = net(img_var)
                end = time.time()
                print('running time:', (end - start))
                pred_pil = to_pil(prediction.squeeze(0).cpu())
                pred_pil = pred_pil.resize(shape)
                prediction = np.array(pred_pil)
                prediction = prediction.astype('float')
                prediction = MaxMinNormalization(prediction, prediction.max(), prediction.min()) * 255.0
                prediction = prediction.astype('uint8')
                if args['crf_refine']:
                    prediction = crf_refine(np.array(img), prediction)

                gt = np.array(Image.open(os.path.join(gt_root, img_name + '.png')).convert('L'))
                precision, recall, mae = cal_precision_recall_mae(prediction, gt)
                for pidx, pdata in enumerate(zip(precision, recall)):
                    p, r = pdata
                    precision_record[pidx].update(p)
                    recall_record[pidx].update(r)
                mae_record.update(mae)

                if args['save_results']:
                    folder, sub_name = os.path.split(img_name)
                    save_path = os.path.join(ckpt_path, exp_name, '(%s) %s_%s' % (exp_name, name, args['snapshot']), folder)
                    if not os.path.exists(save_path):
                        os.makedirs(save_path)
                    Image.fromarray(prediction).save(os.path.join(save_path, sub_name + '.png'))

            fmeasure = cal_fmeasure([precord.avg for precord in precision_record],
                                    [rrecord.avg for rrecord in recall_record])

            results[name] = {'fmeasure': fmeasure, 'mae': mae_record.avg}

    print('test results:')
    print(results)
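
MaxMinNormalization is not shown in this snippet; given how it is called (array, max, min) and then scaled by 255, a plausible sketch is plain min-max scaling into [0, 1] (the epsilon guarding against a constant prediction is an assumption):

def MaxMinNormalization(x, x_max, x_min):
    # map x into [0, 1]; the caller multiplies by 255 and casts to uint8
    return (x - x_min) / (x_max - x_min + 1e-8)
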
Example #4
def train():
    g = Generator(scale_factor=train_args['scale_factor']).cuda().train()
    g = nn.DataParallel(g, device_ids=[0, 1])
    if len(train_args['g_snapshot']) > 0:
        print('load generator snapshot ' + train_args['g_snapshot'])
        g.load_state_dict(
            torch.load(
                os.path.join(train_args['ckpt_path'],
                             train_args['g_snapshot'])))

    mse_criterion = nn.MSELoss().cuda()
    tv_criterion = TotalVariationLoss().cuda()
    g_mse_loss_record, g_tv_loss_record, g_loss_record, psnr_record = \
        AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter()

    iter_nums = len(train_loader)

    if g_pretrain_args['pretrain']:
        g_optimizer = optim.Adam(g.parameters(), lr=g_pretrain_args['lr'])
        scheduler = optim.lr_scheduler.MultiStepLR(
            g_optimizer, milestones=[10, 20, 30, 40, 50], gamma=0.5)
        for epoch in range(g_pretrain_args['epoch_num']):
            start = time.time()

            for i, data in enumerate(train_loader):
                hr_imgs, _ = data
                batch_size = hr_imgs.size(0)
                lr_imgs = torch.stack(
                    [train_lr_transform(img) for img in hr_imgs], 0).cuda()
                hr_imgs = hr_imgs.cuda()

                g.zero_grad()
                gen_hr_imgs = g(lr_imgs)

                g_mse_loss = mse_criterion(gen_hr_imgs, hr_imgs)
                # g_tv_loss = tv_criterion(gen_hr_imgs)
                g_tv_loss = 0
                g_loss = g_mse_loss + 2e-8 * g_tv_loss
                g_loss.backward()
                g_optimizer.step()

                g_mse_loss_record.update(g_mse_loss.item(), batch_size)
                # g_tv_loss_record.update(g_tv_loss.item(), batch_size)
                g_loss_record.update(g_loss.item(), batch_size)
                psnr_record.update(10 * np.log10(1 / g_mse_loss.item()),
                                   batch_size)

                print(
                    '[pretrain]: [epoch %d], [iter %d / %d], [loss %.5f], [psnr %.5f]'
                    % (epoch + 1, i + 1, iter_nums, g_loss_record.avg,
                       psnr_record.avg))

                writer.add_scalar('pretrain_g_loss', g_loss_record.avg,
                                  epoch * iter_nums + i + 1)
                writer.add_scalar('pretrain_psnr', psnr_record.avg,
                                  epoch * iter_nums + i + 1)

            torch.save(
                g.state_dict(),
                os.path.join(
                    train_args['ckpt_path'],
                    'pretrain_g_epoch_%d_loss_%.5f_psnr_%.5f.pth' %
                    (epoch + 1, g_loss_record.avg, psnr_record.avg)))

            end = time.time()

            print(
                '[time for last epoch: %.5f] [pretrain]: [epoch %d], [iter %d / %d], [loss %.5f], [psnr %.5f]'
                % (end - start, epoch + 1, i + 1, iter_nums, g_loss_record.avg,
                   psnr_record.avg))

            g_mse_loss_record.reset()
            psnr_record.reset()

            # step the epoch-level LR scheduler after the epoch's updates
            scheduler.step()

            validate(g, epoch)

    d = Discriminator().cuda().train()
    d = nn.DataParallel(d, device_ids=[0, 1])
    if len(train_args['d_snapshot']) > 0:
        print('load discriminator snapshot ' + train_args['d_snapshot'])
        d.load_state_dict(
            torch.load(
                os.path.join(train_args['ckpt_path'],
                             train_args['d_snapshot'])))

    g_optimizer = optim.Adam(g.parameters(), lr=train_args['g_lr'])
    d_optimizer = optim.Adam(d.parameters(), lr=train_args['d_lr'])
    g_scheduler = optim.lr_scheduler.MultiStepLR(g_optimizer,
                                                 milestones=[10, 20, 30, 40],
                                                 gamma=0.5)
    d_scheduler = optim.lr_scheduler.MultiStepLR(d_optimizer,
                                                 milestones=[10, 20, 30, 40],
                                                 gamma=0.5)
    perceptual_criterion, tv_criterion = PerceptualLoss().cuda(), TotalVariationLoss().cuda()

    g_mse_loss_record, g_perceptual_loss_record, g_tv_loss_record = \
        AvgMeter(), AvgMeter(), AvgMeter()
    psnr_record, g_ad_loss_record, g_loss_record, d_loss_record = \
        AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter()

    for epoch in range(train_args['start_epoch'] - 1, train_args['epoch_num']):
        start = time.time()

        for i, data in enumerate(train_loader):
            hr_imgs, _ = data
            batch_size = hr_imgs.size(0)
            lr_imgs = torch.stack(
                [train_lr_transform(img) for img in hr_imgs], 0).cuda()
            hr_imgs = hr_imgs.cuda()
            gen_hr_imgs = g(lr_imgs)

            # update d
            d.zero_grad()

            # gen_hr_imgs.detach() because we don't want to update the gradients for g when d is being updated
            # d_ad_loss = - torch.log10(1 - d(gen_hr_imgs.detach())).mean() - torch.log10(d(hr_imgs)).mean()
            d_ad_loss = d(gen_hr_imgs.detach()).mean() - d(hr_imgs).mean()
            d_ad_loss.backward()
            d_optimizer.step()

            d_loss_record.update(d_ad_loss.item(), batch_size)

            for p in d.parameters():
                p.data.clamp_(-train_args['c'], train_args['c'])

            # update g
            g.zero_grad()
            g_mse_loss = mse_criterion(gen_hr_imgs, hr_imgs)
            g_perceptual_loss = perceptual_criterion(gen_hr_imgs, hr_imgs)
            g_tv_loss = tv_criterion(gen_hr_imgs)
            # g_ad_loss = -torch.log10(d(gen_hr_imgs)).mean()
            g_ad_loss = -d(gen_hr_imgs).mean()
            g_loss = g_mse_loss + 0.006 * g_perceptual_loss + 0.001 * g_ad_loss + 2e-8 * g_tv_loss
            g_loss.backward()
            g_optimizer.step()

            g_mse_loss_record.update(g_mse_loss.item(), batch_size)
            g_perceptual_loss_record.update(g_perceptual_loss.item(),
                                            batch_size)
            g_tv_loss_record.update(g_tv_loss.item(), batch_size)
            psnr_record.update(10 * np.log10(1 / g_mse_loss.item()),
                               batch_size)
            g_ad_loss_record.update(g_ad_loss.item(), batch_size)
            g_loss_record.update(g_loss.item(), batch_size)

            print('[train]: [epoch %d], [iter %d / %d], [d_ad_loss %.5f], [g_ad_loss %.5f], [psnr %.5f], ' \
                  '[g_mse_loss %.5f], [g_perceptual_loss %.5f], [g_tv_loss %.5f] [g_loss %.5f]' % \
                  (epoch + 1, i + 1, iter_nums, d_loss_record.avg, g_ad_loss_record.avg, psnr_record.avg,
                   g_mse_loss_record.avg, g_perceptual_loss_record.avg, g_tv_loss_record.avg, g_loss_record.avg))

            writer.add_scalar('d_loss', d_loss_record.avg,
                              epoch * iter_nums + i + 1)
            writer.add_scalar('g_mse_loss', g_mse_loss_record.avg,
                              epoch * iter_nums + i + 1)
            writer.add_scalar('g_perceptual_loss',
                              g_perceptual_loss_record.avg,
                              epoch * iter_nums + i + 1)
            writer.add_scalar('g_tv_loss', g_tv_loss_record.avg,
                              epoch * iter_nums + i + 1)
            writer.add_scalar('psnr', psnr_record.avg,
                              epoch * iter_nums + i + 1)
            writer.add_scalar('g_ad_loss', g_ad_loss_record.avg,
                              epoch * iter_nums + i + 1)
            writer.add_scalar('g_loss', g_loss_record.avg,
                              epoch * iter_nums + i + 1)

        end = time.time()

        print('[time for last epoch: %.5f][train]: [epoch %d], [iter %d / %d], [d_ad_loss %.5f], [g_ad_loss %.5f], [psnr %.5f], ' \
              '[g_mse_loss %.5f], [g_perceptual_loss %.5f], [g_tv_loss %.5f] [g_loss %.5f]' % \
              (end - start, epoch + 1, i + 1, iter_nums, d_loss_record.avg, g_ad_loss_record.avg, psnr_record.avg,
               g_mse_loss_record.avg, g_perceptual_loss_record.avg, g_tv_loss_record.avg, g_loss_record.avg))

        d_loss_record.reset()
        g_mse_loss_record.reset()
        g_perceptual_loss_record.reset()
        g_tv_loss_record.reset()
        psnr_record.reset()
        g_ad_loss_record.reset()
        g_loss_record.reset()

        # step both epoch-level LR schedulers after the epoch's updates
        g_scheduler.step()
        d_scheduler.step()

        validate(g, epoch, d)
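
The PSNR logged during both phases is derived directly from the MSE, which assumes image tensors normalized to [0, 1] (so the peak signal value is 1). As a standalone helper under that assumption:

import numpy as np

def psnr_from_mse(mse, peak=1.0):
    # PSNR = 10 * log10(peak^2 / MSE); with peak=1 this reduces to
    # the 10 * np.log10(1 / mse) expression used above
    return 10 * np.log10(peak ** 2 / mse)
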
Example #5
def train_online(net, seq_name='breakdance'):
    online_args = {
        'iter_num': 100,
        'train_batch_size': 5,
        'lr': 1e-8,
        'lr_decay': 0.95,
        'weight_decay': 5e-4,
        'momentum': 0.95,
    }

    joint_transform = joint_transforms.Compose([
        joint_transforms.ImageResize(473),
        # joint_transforms.RandomCrop(473),
        # joint_transforms.RandomHorizontallyFlip(),
        # joint_transforms.RandomRotate(10)
    ])
    target_transform = transforms.ToTensor()
    train_set = VideoFirstImageFolder(to_test['davis'], gt_root, seq_name,
                                      online_args['train_batch_size'],
                                      joint_transform, img_transform,
                                      target_transform)
    online_train_loader = DataLoader(
        train_set,
        batch_size=online_args['train_batch_size'],
        num_workers=1,
        shuffle=False)

    optimizer = optim.SGD([
        {'params': [param for name, param in net.named_parameters()
                    if name[-4:] == 'bias'],
         'lr': 2 * online_args['lr']},
        {'params': [param for name, param in net.named_parameters()
                    if name[-4:] != 'bias'],
         'lr': online_args['lr'],
         'weight_decay': online_args['weight_decay']},
    ], momentum=online_args['momentum'])

    criterion = nn.BCEWithLogitsLoss().cuda()
    net.train().cuda()
    fix_parameters(net.named_parameters())
    for curr_iter in range(0, online_args['iter_num']):
        total_loss_record, loss0_record, loss1_record = AvgMeter(), AvgMeter(), AvgMeter()
        loss2_record, loss3_record, loss4_record = AvgMeter(), AvgMeter(), AvgMeter()

        for i, data in enumerate(online_train_loader):
            optimizer.param_groups[0]['lr'] = 2 * online_args['lr'] * (
                1 - float(curr_iter) /
                online_args['iter_num'])**online_args['lr_decay']
            optimizer.param_groups[1]['lr'] = online_args['lr'] * (
                1 - float(curr_iter) /
                online_args['iter_num'])**online_args['lr_decay']
            inputs, labels = data
            batch_size = inputs.size(0)
            inputs = inputs.cuda()
            labels = labels.cuda()

            optimizer.zero_grad()
            outputs0, outputs1, outputs2, outputs3, outputs4 = net(inputs)
            loss0 = criterion(outputs0, labels)
            loss1 = criterion(outputs1, labels.narrow(0, 1, 4))
            loss2 = criterion(outputs2, labels.narrow(0, 2, 3))
            loss3 = criterion(outputs3, labels.narrow(0, 3, 2))
            loss4 = criterion(outputs4, labels.narrow(0, 4, 1))

            total_loss = loss0 + loss1 + loss2 + loss3 + loss4
            total_loss.backward()
            optimizer.step()

            total_loss_record.update(total_loss.item(), batch_size)
            loss0_record.update(loss0.item(), batch_size)
            loss1_record.update(loss1.item(), batch_size)
            loss2_record.update(loss2.item(), batch_size)
            loss3_record.update(loss3.item(), batch_size)
            loss4_record.update(loss4.item(), batch_size)

            log = '[iter %d], [total loss %.5f], [loss0 %.5f], [loss1 %.5f], [loss2 %.5f], [loss3 %.5f], ' \
                  '[loss4 %.5f], [lr %.13f]' % \
                  (curr_iter, total_loss_record.avg, loss0_record.avg, loss1_record.avg, loss2_record.avg,
                   loss3_record.avg, loss4_record.avg, optimizer.param_groups[1]['lr'])
            print(log)

    return net
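
This snippet (and several others here) builds two optimizer parameter groups so that biases train at twice the base learning rate and without weight decay. A reusable sketch of that split (the helper name is illustrative, not from the repository):

def split_bias_params(model, lr, weight_decay):
    named = list(model.named_parameters())
    biases = [p for n, p in named if n.endswith('bias')]
    weights = [p for n, p in named if not n.endswith('bias')]
    return [
        {'params': biases, 'lr': 2 * lr},  # biases: double lr, no decay
        {'params': weights, 'lr': lr, 'weight_decay': weight_decay},
    ]

# optimizer = optim.SGD(split_bias_params(net, online_args['lr'],
#                                         online_args['weight_decay']),
#                       momentum=online_args['momentum'])
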
Example #6
def main():
    net = R3Net(motion='')

    print('load snapshot \'%s\' for testing' % args['snapshot'])
    net.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'] + '.pth'), map_location='cuda:0'))
    net.eval()
    net.cuda()
    results = {}

    with torch.no_grad():

        for name, root in to_test.items():

            precision_record = [AvgMeter() for _ in range(256)]
            recall_record = [AvgMeter() for _ in range(256)]
            mae_record = AvgMeter()

            if args['save_results']:
                check_mkdir(os.path.join(ckpt_path, exp_name, '(%s) %s_%s' % (exp_name, name, args['snapshot'])))
            img_list = [i_id.strip() for i_id in open(imgs_path)]
            # img_list = [os.path.splitext(f)[0] for f in os.listdir(root) if f.endswith('.jpg')]
            for idx, img_names in enumerate(img_list):
                print('predicting for %s: %d / %d' % (name, idx + 1, len(img_list)))
                img_seq = img_names.split(',')
                img_var = []
                for img_name in img_seq:
                    img = Image.open(os.path.join(root, img_name + '.jpg')).convert('RGB')
                    shape = img.size
                    img = img.resize(args['input_size'])
                    img_var.append(img_transform(img).unsqueeze(0).cuda())  # no Variable/volatile under torch.no_grad()

                img_var = torch.cat(img_var, dim=0)
                prediction = net(img_var)
                pred_pil = to_pil(prediction[-1].cpu())
                pred_pil = pred_pil.resize(shape)
                prediction = np.array(pred_pil)

                if args['crf_refine']:
                    prediction = crf_refine(np.array(img), prediction)

                gt = np.array(Image.open(os.path.join(gt_root, img_seq[-1] + '.png')).convert('L'))
                precision, recall, mae = cal_precision_recall_mae(prediction, gt)
                for pidx, pdata in enumerate(zip(precision, recall)):
                    p, r = pdata
                    precision_record[pidx].update(p)
                    recall_record[pidx].update(r)
                mae_record.update(mae)

                if args['save_results']:
                    folder, sub_name = os.path.split(img_name)
                    save_path = os.path.join(ckpt_path, exp_name, '(%s) %s_%s' % (exp_name, name, args['snapshot']), folder)
                    if not os.path.exists(save_path):
                        os.makedirs(save_path)
                    Image.fromarray(prediction).save(os.path.join(save_path, sub_name + '.png'))

            fmeasure = cal_fmeasure([precord.avg for precord in precision_record],
                                    [rrecord.avg for rrecord in recall_record])

            results[name] = {'fmeasure': fmeasure, 'mae': mae_record.avg}

    print('test results:')
    print(results)
Example #7
def train(net, optimizer):
    global total_epoch
    curr_iter = 1
    base_lr = args['lr']  # fallback so the log line is defined if no schedule flag is set
    start_time = time.time()

    for epoch in range(args['last_epoch'] + 1, args['last_epoch'] + 1 + args['epoch_num']):
        loss_record, loss_4_record, loss_3_record = AvgMeter(), AvgMeter(), AvgMeter()
        loss_2_record, loss_1_record = AvgMeter(), AvgMeter()

        train_iterator = tqdm(train_loader, total=len(train_loader))
        for data in train_iterator:
            if args['poly_train']:
                base_lr = args['lr'] * (1 - float(curr_iter) / float(total_epoch)) ** args['lr_decay']
                optimizer.param_groups[0]['lr'] = 2 * base_lr
                optimizer.param_groups[1]['lr'] = 1 * base_lr

            if args['poly_warmup']:
                if curr_iter < args['warmup_epoch']:
                    base_lr = 1 / args['warmup_epoch'] * (1 + curr_iter)
                else:
                    # use local counters so the global iteration count is not
                    # shifted on every pass through the loop
                    t = curr_iter - args['warmup_epoch'] + 1
                    T = total_epoch - args['warmup_epoch'] + 1
                    base_lr = args['lr'] * (1 - float(t) / float(T)) ** args['lr_decay']
                optimizer.param_groups[0]['lr'] = 2 * base_lr
                optimizer.param_groups[1]['lr'] = 1 * base_lr

            if args['cosine_warmup']:
                if curr_iter < args['warmup_epoch']:
                    base_lr = 1 / args['warmup_epoch'] * (1 + curr_iter)
                else:
                    t = curr_iter - args['warmup_epoch'] + 1
                    T = total_epoch - args['warmup_epoch'] + 1
                    base_lr = args['lr'] * (1 + np.cos(np.pi * float(t) / float(T))) / 2
                optimizer.param_groups[0]['lr'] = 2 * base_lr
                optimizer.param_groups[1]['lr'] = 1 * base_lr

            if args["f3_sche"]:
                base_lr = args['lr'] * (1 - abs((curr_iter + 1) / (total_epoch + 1) * 2 - 1))
                optimizer.param_groups[0]['lr'] = 2 * base_lr
                optimizer.param_groups[1]['lr'] = 1 * base_lr

            inputs, labels = data
            batch_size = inputs.size(0)
            inputs = inputs.cuda(device_ids[0])
            labels = labels.cuda(device_ids[0])

            optimizer.zero_grad()

            predict_4, predict_3, predict_2, predict_1 = net(inputs)

            loss_4 = bce_iou_edge_loss(predict_4, labels)
            loss_3 = bce_iou_edge_loss(predict_3, labels)
            loss_2 = bce_iou_edge_loss(predict_2, labels)
            loss_1 = bce_iou_edge_loss(predict_1, labels)

            loss = args['w2'][0] * loss_4 + args['w2'][1] * loss_3 + args['w2'][2] * loss_2 + args['w2'][3] * loss_1

            loss.backward()

            optimizer.step()

            loss_record.update(loss.item(), batch_size)
            loss_4_record.update(loss_4.item(), batch_size)
            loss_3_record.update(loss_3.item(), batch_size)
            loss_2_record.update(loss_2.item(), batch_size)
            loss_1_record.update(loss_1.item(), batch_size)

            if curr_iter % 50 == 0:
                writer.add_scalar('loss', loss, curr_iter)
                writer.add_scalar('loss_4', loss_4, curr_iter)
                writer.add_scalar('loss_3', loss_3, curr_iter)
                writer.add_scalar('loss_2', loss_2, curr_iter)
                writer.add_scalar('loss_1', loss_1, curr_iter)

            log = '[%3d], [%6d], [%.6f], [%.5f], [%.5f], [%.5f], [%.5f], [%.5f]' % \
                  (epoch, curr_iter, base_lr, loss_record.avg, loss_4_record.avg, loss_3_record.avg,
                   loss_2_record.avg, loss_1_record.avg)
            train_iterator.set_description(log)
            open(log_path, 'a').write(log + '\n')

            curr_iter += 1

        if epoch in args['save_point']:
            net.cpu()
            torch.save(net.state_dict(), os.path.join(ckpt_path, exp_name, '%d.pth' % epoch))
            net.cuda(device_ids[0])

        if epoch >= args['epoch_num']:
            net.cpu()
            torch.save(net.state_dict(), os.path.join(ckpt_path, exp_name, '%d.pth' % epoch))
            print("Total Training Time: {}".format(str(datetime.timedelta(seconds=int(time.time() - start_time)))))
            print("Optimization Have Done!")
            return
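
The warmup branches above splice a short linear ramp in front of either a poly or a cosine decay. A self-contained sketch of the cosine variant (illustrative names; note the snippet's own ramp targets 1.0 rather than the base rate):

import numpy as np

def warmup_cosine_lr(base_lr, curr_iter, total_iter, warmup_iter):
    if curr_iter < warmup_iter:
        # linear ramp up to base_lr over the warmup period
        return base_lr * float(curr_iter + 1) / warmup_iter
    t = curr_iter - warmup_iter
    T = total_iter - warmup_iter
    # half-cosine decay from base_lr down to 0
    return base_lr * (1 + np.cos(np.pi * float(t) / T)) / 2
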
Example #8
def train(net, optimizer):
    curr_iter = args['last_iter']
    while True:
        total_loss_record, loss0_record, loss1_record, loss2_record = \
            AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter()
        # loss3_record = AvgMeter()

        for i, data in enumerate(train_loader):

            optimizer.param_groups[0]['lr'] = 2 * args['lr'] * (
                1 - float(curr_iter) / args['iter_num'])**args['lr_decay']
            optimizer.param_groups[1]['lr'] = args['lr'] * (
                1 - float(curr_iter) / args['iter_num'])**args['lr_decay']

            inputs, labels = data
            if args['train_loader'] == 'video_sequence':
                inputs = inputs.squeeze(0)
                labels = labels.squeeze(0)
            batch_size = inputs.size(0)
            inputs = inputs.cuda()
            labels = labels.cuda()

            optimizer.zero_grad()
            outputs0, outputs1, outputs2, _, _ = net(inputs)
            loss0 = criterion(outputs0, labels)
            loss1 = criterion(outputs1, labels)
            loss2 = criterion(outputs2, labels)
            # loss3 = criterion(outputs3, labels)
            # loss4 = criterion(outputs4, labels)

            if args['distillation']:
                loss02 = criterion(outputs0, torch.sigmoid(outputs2))
                loss12 = criterion(outputs1, torch.sigmoid(outputs2))
                total_loss = loss0 + loss1 + loss2 + 0.5 * loss02 + 0.5 * loss12
            else:
                total_loss = loss0 + loss1 + loss2

            total_loss.backward()
            optimizer.step()

            total_loss_record.update(total_loss.item(), batch_size)
            loss0_record.update(loss0.item(), batch_size)
            loss1_record.update(loss1.item(), batch_size)
            loss2_record.update(loss2.item(), batch_size)
            # loss3_record.update(loss3.data, batch_size)
            # loss4_record.update(loss4.data, batch_size)

            curr_iter += 1

            log = '[iter %d], [total loss %.5f], [loss0 %.5f], [loss1 %.5f], [loss2 %.5f] ' \
                  '[lr %.13f]' % \
                  (curr_iter, total_loss_record.avg, loss0_record.avg, loss1_record.avg, loss2_record.avg,
                   optimizer.param_groups[1]['lr'])
            print(log)
            open(log_path, 'a').write(log + '\n')

            if curr_iter % args['iter_save'] == 0:
                print('taking snapshot ...')
                torch.save(
                    net.state_dict(),
                    os.path.join(ckpt_path, exp_name, '%d.pth' % curr_iter))
                torch.save(
                    optimizer.state_dict(),
                    os.path.join(ckpt_path, exp_name,
                                 '%d_optim.pth' % curr_iter))

            if curr_iter == args['iter_num']:
                torch.save(
                    net.state_dict(),
                    os.path.join(ckpt_path, exp_name, '%d.pth' % curr_iter))
                torch.save(
                    optimizer.state_dict(),
                    os.path.join(ckpt_path, exp_name,
                                 '%d_optim.pth' % curr_iter))
                return
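
The distillation branch treats the deepest output as a soft target: sigmoid turns outputs2 into pseudo-labels, and the same criterion supervises the shallower outputs against them. A compact sketch of that term, assuming criterion is nn.BCEWithLogitsLoss as in the other examples; the detach() is an addition here, since the snippet lets gradients flow into the teacher branch:

import torch
import torch.nn as nn

criterion = nn.BCEWithLogitsLoss()

def distillation_term(student_logits, teacher_logits, weight=0.5):
    # the teacher's sigmoid output becomes a soft target for the student
    soft_target = torch.sigmoid(teacher_logits).detach()
    return weight * criterion(student_logits, soft_target)
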
Example #9
def train(net, optimizer):
    curr_iter = 1
    base_lr = args['lr']  # fallback so the log line below is defined without poly_train

    for epoch in range(args['last_epoch'] + 1,
                       args['last_epoch'] + 1 + args['epoch_num']):
        loss_f4_record, loss_f3_record, loss_f2_record, loss_f1_record = \
            AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter()
        loss_b4_record, loss_b3_record, loss_b2_record, loss_b1_record = \
            AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter()
        loss_e_record, loss_fb_record, loss_record = AvgMeter(), AvgMeter(), AvgMeter()

        train_iterator = tqdm(train_loader, total=len(train_loader))
        for data in train_iterator:
            if args['poly_train']:
                base_lr = args['lr'] * (
                    1 - float(curr_iter) /
                    (args['epoch_num'] * len(train_loader)))**args['lr_decay']
                optimizer.param_groups[0]['lr'] = 2 * base_lr
                optimizer.param_groups[1]['lr'] = 1 * base_lr

            inputs, labels, edges = data
            batch_size = inputs.size(0)
            inputs = inputs.cuda(device_ids[0])
            labels = labels.cuda(device_ids[0])
            edges = edges.cuda(device_ids[0])

            optimizer.zero_grad()

            predict_f4, predict_f3, predict_f2, predict_f1, \
            predict_b4, predict_b3, predict_b2, predict_b1, predict_e, predict_fb = net(inputs)

            loss_f4 = wl(predict_f4, labels)
            loss_f3 = wl(predict_f3, labels)
            loss_f2 = wl(predict_f2, labels)
            loss_f1 = wl(predict_f1, labels)

            # loss_b4 = wl(1 - torch.sigmoid(predict_b4), labels)
            # loss_b3 = wl(1 - torch.sigmoid(predict_b3), labels)
            # loss_b2 = wl(1 - torch.sigmoid(predict_b2), labels)
            # loss_b1 = wl(1 - torch.sigmoid(predict_b1), labels)

            loss_b4 = wl(1 - predict_b4, labels)
            loss_b3 = wl(1 - predict_b3, labels)
            loss_b2 = wl(1 - predict_b2, labels)
            loss_b1 = wl(1 - predict_b1, labels)

            loss_e = el(predict_e, edges)

            loss_fb = wl(predict_fb, labels)

            loss = loss_f4 + loss_f3 + loss_f2 + loss_f1 + \
                   loss_b4 + loss_b3 + loss_b2 + loss_b1 + loss_e + 8 * loss_fb

            loss.backward()

            optimizer.step()

            loss_record.update(loss.item(), batch_size)
            loss_f4_record.update(loss_f4.item(), batch_size)
            loss_f3_record.update(loss_f3.item(), batch_size)
            loss_f2_record.update(loss_f2.item(), batch_size)
            loss_f1_record.update(loss_f1.item(), batch_size)
            loss_b4_record.update(loss_b4.item(), batch_size)
            loss_b3_record.update(loss_b3.item(), batch_size)
            loss_b2_record.update(loss_b2.item(), batch_size)
            loss_b1_record.update(loss_b1.item(), batch_size)
            loss_e_record.update(loss_e.item(), batch_size)
            loss_fb_record.update(loss_fb.item(), batch_size)

            if curr_iter % 50 == 0:
                writer.add_scalar('Total loss', loss, curr_iter)
                writer.add_scalar('f4 loss', loss_f4, curr_iter)
                writer.add_scalar('f3 loss', loss_f3, curr_iter)
                writer.add_scalar('f2 loss', loss_f2, curr_iter)
                writer.add_scalar('f1 loss', loss_f1, curr_iter)
                writer.add_scalar('b4 loss', loss_b4, curr_iter)
                writer.add_scalar('b3 loss', loss_b3, curr_iter)
                writer.add_scalar('b2 loss', loss_b2, curr_iter)
                writer.add_scalar('b1 loss', loss_b1, curr_iter)
                writer.add_scalar('e loss', loss_e, curr_iter)
                writer.add_scalar('fb loss', loss_fb, curr_iter)

            log = '[%3d], [f4 %.5f], [f3 %.5f], [f2 %.5f], [f1 %.5f] ' \
                  '[b4 %.5f], [b3 %.5f], [b2 %.5f], [b1 %.5f], [e %.5f], [fb %.5f], [lr %.6f]' % \
                  (epoch,
                   loss_f4_record.avg, loss_f3_record.avg, loss_f2_record.avg, loss_f1_record.avg,
                   loss_b4_record.avg, loss_b3_record.avg, loss_b2_record.avg, loss_b1_record.avg,
                   loss_e_record.avg, loss_fb_record.avg, base_lr)
            train_iterator.set_description(log)
            open(log_path, 'a').write(log + '\n')

            curr_iter += 1

        if epoch in args['save_point']:
            net.cpu()
            torch.save(net.module.state_dict(),
                       os.path.join(ckpt_path, exp_name, '%d.pth' % epoch))
            net.cuda(device_ids[0])

        if epoch >= args['epoch_num']:
            net.cpu()
            torch.save(net.module.state_dict(),
                       os.path.join(ckpt_path, exp_name, '%d.pth' % epoch))
            print("Optimization Have Done!")
            return
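
The background branches are supervised by inverting their predictions, so the same foreground criterion wl applies; judging by the commented-out lines, the network already emits sigmoid probabilities here. In miniature (wl stands in for the repository's weighted loss, whose definition is not shown):

def background_loss(wl, predict_b, labels):
    # if predict_b is P(background), then 1 - predict_b is P(foreground)
    # and can be scored against the foreground ground truth directly
    return wl(1 - predict_b, labels)
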
Example #10
def main():
    net = R3Net().cuda()

    print('load snapshot \'%s\' for testing' % args['snapshot'])
    net.load_state_dict(
        torch.load(os.path.join(ckpt_path, exp_name,
                                args['snapshot'] + '.pth')))
    net.eval()

    results = {}

    with torch.no_grad():

        for name, root in to_test.items():

            precision_record = [AvgMeter() for _ in range(256)]
            recall_record = [AvgMeter() for _ in range(256)]
            mae_record = AvgMeter()

            if args['save_results']:
                check_mkdir(
                    os.path.join(
                        ckpt_path, exp_name,
                        '(%s) %s_%s' % (exp_name, name, args['snapshot'])))

            img_list = [
                os.path.splitext(f)[0] for f in os.listdir(root)
                if f.endswith('.jpg')
            ]
            for idx, img_name in enumerate(img_list):
                print('predicting for %s: %d / %d' %
                      (name, idx + 1, len(img_list)))

                img = Image.open(os.path.join(root, img_name +
                                              '.jpg')).convert('RGB')
                img_var = img_transform(img).unsqueeze(0).cuda()  # no Variable/volatile under torch.no_grad()
                prediction = net(img_var)
                prediction = np.array(to_pil(prediction.data.squeeze(0).cpu()))

                if args['crf_refine']:
                    prediction = crf_refine(np.array(img), prediction)

                gt = np.array(
                    Image.open(os.path.join(root,
                                            img_name + '.png')).convert('L'))
                precision, recall, mae = cal_precision_recall_mae(
                    prediction, gt)
                for pidx, pdata in enumerate(zip(precision, recall)):
                    p, r = pdata
                    precision_record[pidx].update(p)
                    recall_record[pidx].update(r)
                mae_record.update(mae)

                if args['save_results']:
                    Image.fromarray(prediction).save(
                        os.path.join(
                            ckpt_path, exp_name,
                            '(%s) %s_%s' % (exp_name, name, args['snapshot']),
                            img_name + '.png'))

            fmeasure = cal_fmeasure(
                [precord.avg for precord in precision_record],
                [rrecord.avg for rrecord in recall_record])

            results[name] = {'fmeasure': fmeasure, 'mae': mae_record.avg}

    print('test results:')
    print(results)
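
cal_fmeasure comes from the repository's utilities; in salient-object-detection code it conventionally takes the 256 threshold-wise precision/recall averages and returns the maximum F-beta score with beta^2 = 0.3. A sketch under that assumption:

def cal_fmeasure(precision, recall, beta2=0.3):
    # max F-beta over the 256 binarization thresholds; beta2 = 0.3 is the
    # customary choice in salient object detection (an assumption here)
    return max((1 + beta2) * p * r / (beta2 * p + r + 1e-12)
               for p, r in zip(precision, recall))
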
Example #11
def main():
    net = SDCNet(num_classes=5).cuda()

    print('load snapshot \'%s\' for testing, mode:\'%s\'' %
          (args['snapshot'], args['test_mode']))
    print(exp_name)
    net.load_state_dict(
        torch.load(os.path.join(ckpt_path, exp_name,
                                args['snapshot'] + '.pth')))
    net.eval()

    results = {}

    with torch.no_grad():

        for name, root in to_test.items():
            print('load snapshot \'%s\' for testing %s' %
                  (args['snapshot'], name))

            test_data = pd.read_csv(root)
            test_set = TestFolder_joint(test_data, joint_transform,
                                        img_transform, target_transform)
            test_loader = DataLoader(test_set,
                                     batch_size=1,
                                     num_workers=0,
                                     shuffle=False)

            precision0_record = [AvgMeter() for _ in range(256)]
            recall0_record = [AvgMeter() for _ in range(256)]
            precision1_record = [AvgMeter() for _ in range(256)]
            recall1_record = [AvgMeter() for _ in range(256)]
            precision2_record = [AvgMeter() for _ in range(256)]
            recall2_record = [AvgMeter() for _ in range(256)]
            precision3_record = [AvgMeter() for _ in range(256)]
            recall3_record = [AvgMeter() for _ in range(256)]
            precision4_record = [AvgMeter() for _ in range(256)]
            recall4_record = [AvgMeter() for _ in range(256)]
            precision5_record = [AvgMeter() for _ in range(256)]
            recall5_record = [AvgMeter() for _ in range(256)]
            precision6_record = [AvgMeter() for _ in range(256)]
            recall6_record = [AvgMeter() for _ in range(256)]

            mae0_record = AvgMeter()
            mae1_record = AvgMeter()
            mae2_record = AvgMeter()
            mae3_record = AvgMeter()
            mae4_record = AvgMeter()
            mae5_record = AvgMeter()
            mae6_record = AvgMeter()

            n0, n1, n2, n3, n4, n5 = 0, 0, 0, 0, 0, 0

            if args['save_results']:
                check_mkdir(
                    os.path.join(ckpt_path, exp_name,
                                 '%s_%s' % (name, args['snapshot'])))

            for i, (inputs, gt, labels,
                    img_path) in enumerate(tqdm(test_loader)):

                shape = gt.size()[2:]

                img_var = inputs.cuda()
                img = np.array(to_pil(img_var.squeeze(0).cpu()))

                gt = np.array(to_pil(gt.squeeze(0).cpu()))
                sizec = labels.numpy()
                pred2021 = net(img_var, sizec)

                pred2021 = F.interpolate(pred2021,
                                         size=shape,
                                         mode='bilinear',
                                         align_corners=True)
                pred2021 = np.array(to_pil(pred2021.data.squeeze(0).cpu()))

                if labels == 0:
                    precision1, recall1, mae1 = cal_precision_recall_mae(
                        pred2021, gt)
                    for pidx, pdata in enumerate(zip(precision1, recall1)):
                        p, r = pdata
                        precision1_record[pidx].update(p)
                        # print('Precision:', p, 'Recall:', r)
                        recall1_record[pidx].update(r)
                    mae1_record.update(mae1)
                    n1 += 1

                elif labels == 1:
                    precision2, recall2, mae2 = cal_precision_recall_mae(
                        pred2021, gt)
                    for pidx, pdata in enumerate(zip(precision2, recall2)):
                        p, r = pdata
                        precision2_record[pidx].update(p)
                        # print('Precision:', p, 'Recall:', r)
                        recall2_record[pidx].update(r)
                    mae2_record.update(mae2)
                    n2 += 1

                elif labels == 2:
                    precision3, recall3, mae3 = cal_precision_recall_mae(
                        pred2021, gt)
                    for pidx, pdata in enumerate(zip(precision3, recall3)):
                        p, r = pdata
                        precision3_record[pidx].update(p)
                        # print('Precision:', p, 'Recall:', r)
                        recall3_record[pidx].update(r)
                    mae3_record.update(mae3)
                    n3 += 1

                elif labels == 3:
                    precision4, recall4, mae4 = cal_precision_recall_mae(
                        pred2021, gt)
                    for pidx, pdata in enumerate(zip(precision4, recall4)):
                        p, r = pdata
                        precision4_record[pidx].update(p)
                        # print('Precision:', p, 'Recall:', r)
                        recall4_record[pidx].update(r)
                    mae4_record.update(mae4)
                    n4 += 1

                elif labels == 4:
                    precision5, recall5, mae5 = cal_precision_recall_mae(
                        pred2021, gt)
                    for pidx, pdata in enumerate(zip(precision5, recall5)):
                        p, r = pdata
                        precision5_record[pidx].update(p)
                        # print('Precision:', p, 'Recall:', r)
                        recall5_record[pidx].update(r)
                    mae5_record.update(mae5)
                    n5 += 1

                precision6, recall6, mae6 = cal_precision_recall_mae(
                    pred2021, gt)
                for pidx, pdata in enumerate(zip(precision6, recall6)):
                    p, r = pdata
                    precision6_record[pidx].update(p)
                    recall6_record[pidx].update(r)
                mae6_record.update(mae6)

                img_name = os.path.split(img_path[0])[1]  # img_path is a length-1 batch of paths
                img_name = os.path.splitext(img_name)[0]
                n0 += 1

                if args['save_results']:
                    Image.fromarray(pred2021).save(
                        os.path.join(ckpt_path, exp_name,
                                     '%s_%s' % (name, args['snapshot']),
                                     img_name + '_2021.png'))

            fmeasure1 = cal_fmeasure(
                [precord.avg for precord in precision1_record],
                [rrecord.avg for rrecord in recall1_record])
            fmeasure2 = cal_fmeasure(
                [precord.avg for precord in precision2_record],
                [rrecord.avg for rrecord in recall2_record])
            fmeasure3 = cal_fmeasure(
                [precord.avg for precord in precision3_record],
                [rrecord.avg for rrecord in recall3_record])
            fmeasure4 = cal_fmeasure(
                [precord.avg for precord in precision4_record],
                [rrecord.avg for rrecord in recall4_record])
            fmeasure5 = cal_fmeasure(
                [precord.avg for precord in precision5_record],
                [rrecord.avg for rrecord in recall5_record])
            fmeasure6 = cal_fmeasure(
                [precord.avg for precord in precision6_record],
                [rrecord.avg for rrecord in recall6_record])
            results[name] = {
                'fmeasure1': fmeasure1,
                'mae1': mae1_record.avg,
                'fmeasure2': fmeasure2,
                'mae2': mae2_record.avg,
                'fmeasure3': fmeasure3,
                'mae3': mae3_record.avg,
                'fmeasure4': fmeasure4,
                'mae4': mae4_record.avg,
                'fmeasure5': fmeasure5,
                'mae5': mae5_record.avg,
                'fmeasure6': fmeasure6,
                'mae6': mae6_record.avg
            }

            print('test results:')
            print('[fmeasure1 %.3f], [mae1 %.4f], [class1 %.0f]\n'\
                  '[fmeasure2 %.3f], [mae2 %.4f], [class2 %.0f]\n'\
                  '[fmeasure3 %.3f], [mae3 %.4f], [class3 %.0f]\n'\
                  '[fmeasure4 %.3f], [mae4 %.4f], [class4 %.0f]\n'\
                  '[fmeasure5 %.3f], [mae5 %.4f], [class5 %.0f]\n'\
                  '[fmeasure6 %.3f], [mae6 %.4f], [all %.0f]\n'%\
                  (fmeasure1, mae1_record.avg, n1, fmeasure2, mae2_record.avg, n2, fmeasure3, mae3_record.avg, n3, fmeasure4, mae4_record.avg, n4, fmeasure5, mae5_record.avg, n5, fmeasure6, mae6_record.avg, n0))
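
cal_precision_recall_mae is likewise defined elsewhere; its calling convention (256 precision/recall pairs plus a scalar MAE per image, on uint8 arrays) suggests a sweep over all binarization thresholds. A sketch under those assumptions:

import numpy as np

def cal_precision_recall_mae(prediction, gt):
    # prediction and gt are uint8 grayscale arrays in [0, 255]
    eps = 1e-12
    mae = np.abs(prediction / 255.0 - gt / 255.0).mean()
    binary_gt = gt >= 128
    precision, recall = [], []
    for t in range(256):  # one precision/recall pair per threshold
        binary_pred = prediction >= t
        tp = np.logical_and(binary_pred, binary_gt).sum()
        precision.append(tp / (binary_pred.sum() + eps))
        recall.append(tp / (binary_gt.sum() + eps))
    return precision, recall, mae
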
Example #12
def train(exp_name):

    net = AADFNet().cuda().train()
    net = nn.DataParallel(net, device_ids=[0, 1])

    optimizer = optim.SGD([
        {'params': [param for name, param in net.named_parameters()
                    if name[-4:] == 'bias'],
         'lr': 2 * args['lr']},
        {'params': [param for name, param in net.named_parameters()
                    if name[-4:] != 'bias'],
         'lr': args['lr'],
         'weight_decay': args['weight_decay']},
    ], momentum=args['momentum'])

    if len(args['snapshot']) > 0:
        print('training resumes from ' + args['snapshot'])
        net.load_state_dict(
            torch.load(
                os.path.join(ckpt_path, exp_name, args['snapshot'] + '.pth')))
        optimizer.load_state_dict(
            torch.load(
                os.path.join(ckpt_path, exp_name,
                             args['snapshot'] + '_optim.pth')))
        optimizer.param_groups[0]['lr'] = 2 * args['lr']
        optimizer.param_groups[1]['lr'] = args['lr']

    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    log_path = os.path.join(ckpt_path, exp_name,
                            str(datetime.datetime.now()) + '.txt')
    open(log_path, 'w').write(str(args) + '\n\n')
    print('start to train')

    curr_iter = args['last_iter']
    while True:
        total_loss_record, loss1_record, loss2_record = AvgMeter(), AvgMeter(), AvgMeter()
        loss3_record, loss4_record = AvgMeter(), AvgMeter()
        loss2_2_record, loss3_2_record, loss4_2_record = AvgMeter(), AvgMeter(), AvgMeter()

        loss44_record, loss43_record, loss42_record, loss41_record = \
            AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter()
        loss34_record, loss33_record, loss32_record, loss31_record = \
            AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter()
        loss24_record, loss23_record, loss22_record, loss21_record = \
            AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter()
        loss14_record, loss13_record, loss12_record, loss11_record = \
            AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter()

        for i, data in enumerate(train_loader):
            optimizer.param_groups[0]['lr'] = 2 * args['lr'] * (
                1 - float(curr_iter) / args['iter_num'])**args['lr_decay']
            optimizer.param_groups[1]['lr'] = args['lr'] * (
                1 - float(curr_iter) / args['iter_num'])**args['lr_decay']

            inputs, labels = data
            batch_size = inputs.size(0)
            inputs = inputs.cuda()
            labels = labels.cuda()

            optimizer.zero_grad()

            outputs4_2, outputs3_2, outputs2_2, outputs1, outputs2, outputs3, outputs4, \
                    predict41, predict42, predict43, predict44, \
                    predict31, predict32, predict33, predict34, \
                    predict21, predict22, predict23, predict24, \
                    predict11, predict12, predict13, predict14 = net(inputs)

            loss1 = criterion(outputs1, labels)
            loss2 = criterion(outputs2, labels)
            loss3 = criterion(outputs3, labels)
            loss4 = criterion(outputs4, labels)

            loss2_2 = criterion(outputs2_2, labels)
            loss3_2 = criterion(outputs3_2, labels)
            loss4_2 = criterion(outputs4_2, labels)

            loss44 = criterion(predict44, labels)
            loss43 = criterion(predict43, labels)
            loss42 = criterion(predict42, labels)
            loss41 = criterion(predict41, labels)

            loss34 = criterion(predict34, labels)
            loss33 = criterion(predict33, labels)
            loss32 = criterion(predict32, labels)
            loss31 = criterion(predict31, labels)

            loss24 = criterion(predict24, labels)
            loss23 = criterion(predict23, labels)
            loss22 = criterion(predict22, labels)
            loss21 = criterion(predict21, labels)

            loss14 = criterion(predict14, labels)
            loss13 = criterion(predict13, labels)
            loss12 = criterion(predict12, labels)
            loss11 = criterion(predict11, labels)

            total_loss = loss1 + loss2 + loss3 + loss4 + loss2_2 + loss3_2 + loss4_2 \
                         + (loss44 + loss43 + loss42 + loss41) / 10 \
                         + (loss34 + loss33 + loss32 + loss31) / 10 \
                         + (loss24 + loss23 + loss22 + loss21) / 10 \
                         + (loss14 + loss13 + loss12 + loss11) / 10

            total_loss.backward()
            optimizer.step()

            total_loss_record.update(total_loss.item(), batch_size)
            loss1_record.update(loss1.item(), batch_size)
            loss2_record.update(loss2.item(), batch_size)
            loss3_record.update(loss3.item(), batch_size)
            loss4_record.update(loss4.item(), batch_size)

            loss2_2_record.update(loss2_2.item(), batch_size)
            loss3_2_record.update(loss3_2.item(), batch_size)
            loss4_2_record.update(loss4_2.item(), batch_size)

            loss44_record.update(loss44.item(), batch_size)
            loss43_record.update(loss43.item(), batch_size)
            loss42_record.update(loss42.item(), batch_size)
            loss41_record.update(loss41.item(), batch_size)

            loss34_record.update(loss34.item(), batch_size)
            loss33_record.update(loss33.item(), batch_size)
            loss32_record.update(loss32.item(), batch_size)
            loss31_record.update(loss31.item(), batch_size)

            loss24_record.update(loss24.item(), batch_size)
            loss23_record.update(loss23.item(), batch_size)
            loss22_record.update(loss22.item(), batch_size)
            loss21_record.update(loss21.item(), batch_size)

            loss14_record.update(loss14.item(), batch_size)
            loss13_record.update(loss13.item(), batch_size)
            loss12_record.update(loss12.item(), batch_size)
            loss11_record.update(loss11.item(), batch_size)

            curr_iter += 1

            log = '[iter %d], [total loss %.5f], ' \
                  '[loss4_2 %.5f], [loss3_2 %.5f], [loss2_2 %.5f], [loss1 %.5f], ' \
                  '[loss2 %.5f], [loss3 %.5f], [loss4 %.5f], ' \
                  '[loss44 %.5f], [loss43 %.5f], [loss42 %.5f], [loss41 %.5f], ' \
                  '[loss34 %.5f], [loss33 %.5f], [loss32 %.5f], [loss31 %.5f], ' \
                  '[loss24 %.5f], [loss23 %.5f], [loss22 %.5f], [loss21 %.5f], ' \
                  '[loss14 %.5f], [loss13 %.5f], [loss12 %.5f], [loss11 %.5f], ' \
                  '[lr %.13f]' % \
                  (curr_iter, total_loss_record.avg,
                   loss4_2_record.avg, loss3_2_record.avg,
                   loss2_2_record.avg, loss1_record.avg, loss2_record.avg,
                   loss3_record.avg, loss4_record.avg,
                   loss44_record.avg, loss43_record.avg, loss42_record.avg, loss41_record.avg,
                   loss34_record.avg, loss33_record.avg, loss32_record.avg, loss31_record.avg,
                   loss24_record.avg, loss23_record.avg, loss22_record.avg, loss21_record.avg,
                   loss14_record.avg, loss13_record.avg, loss12_record.avg, loss11_record.avg,
                   optimizer.param_groups[1]['lr'])

            print(log)
            open(log_path, 'a').write(log + '\n')

            if curr_iter == args['iter_num']:
                torch.save(
                    net.state_dict(),
                    os.path.join(ckpt_path, exp_name, '%d.pth' % curr_iter))
                torch.save(
                    optimizer.state_dict(),
                    os.path.join(ckpt_path, exp_name,
                                 '%d_optim.pth' % curr_iter))
                return
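The snippet above spells out the same weighted deep-supervision sum term by term. A generic sketch of the pattern, assuming lists of head outputs instead of the numbered locals used here (names illustrative):

def deep_supervision_loss(criterion, main_preds, side_preds, labels, side_weight=0.1):
    # Main heads at full weight; auxiliary side outputs down-weighted
    # (1/10 in the example above).
    loss = sum(criterion(p, labels) for p in main_preds)
    return loss + side_weight * sum(criterion(p, labels) for p in side_preds)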
Example #13
def train(net, optimizer):
    curr_iter = args['last_iter']
    while True:
        train_loss_record = AvgMeter()
        train_net_loss_record = AvgMeter()

        for i, data in enumerate(train_loader):
            optimizer.param_groups[0]['lr'] = 2 * args['lr'] * (
                1 - float(curr_iter) / args['iter_num'])**args['lr_decay']
            optimizer.param_groups[1]['lr'] = args['lr'] * (
                1 - float(curr_iter) / args['iter_num'])**args['lr_decay']

            inputs, gts, dps = data
            batch_size = inputs.size(0)
            inputs = Variable(inputs).cuda()
            gts = Variable(gts).cuda()
            dps = Variable(dps).cuda()

            optimizer.zero_grad()

            result = net(inputs)

            loss_net = criterion(result, gts)

            loss = loss_net

            loss.backward()

            optimizer.step()

            # for n, p in net.named_parameters():
            #     if n[-5:] == 'alpha':
            #         print(p.grad.data)
            #         print(p.data)

            train_loss_record.update(loss.data, batch_size)
            train_net_loss_record.update(loss_net.data, batch_size)

            curr_iter += 1

            log = '[iter %d], [train loss %.5f], [lr %.13f], [loss_net %.5f]' % \
                  (curr_iter, train_loss_record.avg, optimizer.param_groups[1]['lr'],
                   train_net_loss_record.avg)
            print(log)
            open(log_path, 'a').write(log + '\n')

            if (curr_iter + 1) % args['val_freq'] == 0:
                validate(net, curr_iter, optimizer)

            if (curr_iter + 1) % args['snapshot_epochs'] == 0:
                torch.save(
                    net.state_dict(),
                    os.path.join(ckpt_path, exp_name,
                                 ('%d.pth' % (curr_iter + 1))))
                torch.save(
                    optimizer.state_dict(),
                    os.path.join(ckpt_path, exp_name,
                                 ('%d_optim.pth' % (curr_iter + 1))))

            if curr_iter > args['iter_num']:
                return
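Most of the training loops in these examples update the learning rate with the same "poly" schedule, holding parameter group 0 at twice the base rate. Factored out as a helper (a sketch; the 2x convention for group 0 is taken from the snippets themselves):

def poly_lr(optimizer, base_lr, curr_iter, max_iter, power):
    # lr decays polynomially from base_lr towards 0 over max_iter iterations.
    lr = base_lr * (1 - float(curr_iter) / max_iter) ** power
    optimizer.param_groups[0]['lr'] = 2 * lr
    optimizer.param_groups[1]['lr'] = lr
    return lr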
Example #14
def train(net, optimizer):
    curr_iter = args['last_iter']
    for e in range(args["epoch"]):
        total_loss_record, loss0_record, loss1_record, loss2_record = \
            AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter()
        loss3_record, loss4_record, loss5_record, loss6_record = \
            AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter()
        print "epoch", e
        for i, data in enumerate(train_loader):
            # Poly lr schedule, disabled in this example:
            # optimizer.param_groups[0]['lr'] = 2 * args['lr'] * (
            #     1 - float(curr_iter) / args['iter_num']) ** args['lr_decay']
            # optimizer.param_groups[1]['lr'] = args['lr'] * (
            #     1 - float(curr_iter) / args['iter_num']) ** args['lr_decay']

            inputs, labels = data
            batch_size = inputs.size(0)
            inputs = Variable(inputs).cuda()
            labels = Variable(labels).cuda()
            print(inputs.size())
            optimizer.zero_grad()
            outputs0, outputs1, outputs2, outputs3, outputs4, outputs5, outputs6 = net(
                inputs)
            loss0 = criterion(outputs0, labels)
            loss1 = criterion(outputs1, labels)
            loss2 = criterion(outputs2, labels)
            loss3 = criterion(outputs3, labels)
            loss4 = criterion(outputs4, labels)
            loss5 = criterion(outputs5, labels)
            loss6 = criterion(outputs6, labels)

            total_loss = loss0 + loss1 + loss2 + loss3 + loss4 + loss5 + loss6
            total_loss.backward()
            optimizer.step()

            total_loss_record.update(total_loss.item(), batch_size)
            loss0_record.update(loss0.item(), batch_size)
            loss1_record.update(loss1.item(), batch_size)
            loss2_record.update(loss2.item(), batch_size)
            loss3_record.update(loss3.item(), batch_size)
            loss4_record.update(loss4.item(), batch_size)
            loss5_record.update(loss5.item(), batch_size)
            loss6_record.update(loss6.item(), batch_size)

            curr_iter += 1

            log = '[iter %d], [total loss %.5f], [loss0 %.5f], [loss1 %.5f], [loss2 %.5f], [loss3 %.5f], ' \
                  '[loss4 %.5f], [loss5 %.5f], [loss6 %.5f], [lr %.13f]' % \
                  (curr_iter, total_loss_record.avg, loss0_record.avg, loss1_record.avg, loss2_record.avg,
                   loss3_record.avg, loss4_record.avg, loss5_record.avg, loss6_record.avg,
                   optimizer.param_groups[1]['lr'])
            print(log)
            open(log_path, 'a').write(log + '\n')

            if curr_iter % args['iter_num'] == 0:
                torch.save(
                    net.state_dict(),
                    os.path.join(ckpt_path, exp_name, '%d.pth' % curr_iter))
                torch.save(
                    optimizer.state_dict(),
                    os.path.join(ckpt_path, exp_name,
                                 '%d_optim.pth' % curr_iter))
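AvgMeter is imported from the repository's utilities and never shown in these listings. A minimal sketch consistent with how it is used throughout (update(value, n=1), .avg, .reset()):

class AvgMeter(object):
    # Running average of a scalar, optionally weighted by batch size.
    def __init__(self):
        self.reset()

    def reset(self):
        self.sum, self.count, self.avg = 0.0, 0, 0.0

    def update(self, value, n=1):
        self.sum += value * n
        self.count += n
        self.avg = self.sum / self.count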
Example #15
def train(net, optimizer):
    curr_iter = args['last_iter']
    while True:
        total_loss_record, loss0_record, loss1_record = \
            AvgMeter(), AvgMeter(), AvgMeter()
        loss2_record, loss3_record, loss4_record, loss5_record = \
            AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter()
        if args['isTriplet']:
            loss_triplet_record = AvgMeter()

        for i, data in enumerate(train_loader):
            optimizer.param_groups[0]['lr'] = 2 * args['lr'] * (
                1 - float(curr_iter) / args['iter_num'])**args['lr_decay']
            optimizer.param_groups[1]['lr'] = args['lr'] * (
                1 - float(curr_iter) / args['iter_num'])**args['lr_decay']

            inputs, labels = data
            if args['train_loader'] == 'video_sequence':
                inputs = inputs.squeeze(0)
                labels = labels.squeeze(0)
            batch_size = inputs.size(0)
            inputs = Variable(inputs).cuda()
            labels = Variable(labels).cuda()

            optimizer.zero_grad()
            if args['isTriplet']:
                outputs0, outputs1, outputs2, outputs3, outputs4, outputs5, outputs_triplet = net(
                    inputs)
            else:
                outputs0, outputs1, outputs2, outputs3, outputs4, outputs5 = net(
                    inputs)
            loss0 = criterion(outputs0, labels)
            loss1 = criterion(outputs1, labels.narrow(0, 1, 5))
            loss2 = criterion(outputs2, labels.narrow(0, 2, 4))
            loss3 = criterion(outputs3, labels.narrow(0, 3, 3))
            loss4 = criterion(outputs4, labels.narrow(0, 4, 2))
            loss5 = criterion(outputs5, labels.narrow(0, 5, 1))

            if args['L2']:
                # The narrow() windows must match those used for the losses
                # above, or the shapes no longer agree with the outputs.
                loss0 = loss0 + 0.1 * criterion_l2(
                    torch.relu(outputs0) / torch.max(outputs0), labels)
                loss1 = loss1 + 0.1 * criterion_l2(
                    torch.relu(outputs1) / torch.max(outputs1),
                    labels.narrow(0, 1, 5))
                loss2 = loss2 + 0.1 * criterion_l2(
                    torch.relu(outputs2) / torch.max(outputs2),
                    labels.narrow(0, 2, 4))
                loss3 = loss3 + 0.1 * criterion_l2(
                    torch.relu(outputs3) / torch.max(outputs3),
                    labels.narrow(0, 3, 3))
                loss4 = loss4 + 0.1 * criterion_l2(
                    torch.relu(outputs4) / torch.max(outputs4),
                    labels.narrow(0, 4, 2))

            if args['dice']:
                loss0 = loss0 + 0.5 * criterion_dice(outputs0, labels)
                loss1 = loss1 + 0.5 * criterion_dice(outputs1,
                                                     labels.narrow(0, 1, 5))
                loss2 = loss2 + 0.5 * criterion_dice(outputs2,
                                                     labels.narrow(0, 2, 4))
                loss3 = loss3 + 0.5 * criterion_dice(outputs3,
                                                     labels.narrow(0, 3, 3))
                loss4 = loss4 + 0.5 * criterion_dice(outputs4,
                                                     labels.narrow(0, 4, 2))
                loss5 = loss5 + 0.5 * criterion_dice(outputs5,
                                                     labels.narrow(0, 5, 1))

            if args['isTriplet']:
                loss_triplet = criterion_triplet(outputs_triplet[0],
                                                 outputs_triplet[1],
                                                 outputs_triplet[2])
                total_loss = loss0 + loss1 + loss2 + loss3 + loss4 + 0.2 * loss_triplet
                total_loss.backward()
                optimizer.step()

                total_loss_record.update(total_loss.data, batch_size)
                loss0_record.update(loss0.data, batch_size)
                loss1_record.update(loss1.data, batch_size)
                loss2_record.update(loss2.data, batch_size)
                loss3_record.update(loss3.data, batch_size)
                loss4_record.update(loss4.data, batch_size)
                loss_triplet_record.update(loss_triplet.data, batch_size)

            else:
                total_loss = loss0 + loss1 + loss2 + loss3 + loss4 + loss5
                total_loss.backward()
                optimizer.step()

                total_loss_record.update(total_loss.data, batch_size)
                loss0_record.update(loss0.data, batch_size)
                loss1_record.update(loss1.data, batch_size)
                loss2_record.update(loss2.data, batch_size)
                loss3_record.update(loss3.data, batch_size)
                loss4_record.update(loss4.data, batch_size)
                loss5_record.update(loss5.data, batch_size)

            curr_iter += 1
            if args['isTriplet']:
                log = '[iter %d], [total loss %.5f], [loss0 %.5f], [loss1 %.5f], [loss2 %.5f], [loss3 %.5f], ' \
                      '[loss4 %.5f], [loss_triplet %.5f], [lr %.13f] ' % \
                      (curr_iter, total_loss_record.avg, loss0_record.avg, loss1_record.avg, loss2_record.avg,
                       loss3_record.avg, loss4_record.avg, loss_triplet_record.avg, optimizer.param_groups[1]['lr'])
            else:
                log = '[iter %d], [total loss %.5f], [loss0 %.5f], [loss1 %.5f], [loss2 %.5f], [loss3 %.5f], ' \
                  '[loss4 %.5f], [loss5 %.5f], [lr %.13f]' % \
                  (curr_iter, total_loss_record.avg, loss0_record.avg, loss1_record.avg, loss2_record.avg,
                   loss3_record.avg, loss4_record.avg, loss5_record.avg, optimizer.param_groups[1]['lr'])
            print(log)
            open(log_path, 'a').write(log + '\n')

            if curr_iter % args['iter_save'] == 0:
                print('taking snapshot ...')
                torch.save(
                    net.state_dict(),
                    os.path.join(ckpt_path, exp_name, '%d.pth' % curr_iter))
                torch.save(
                    optimizer.state_dict(),
                    os.path.join(ckpt_path, exp_name,
                                 '%d_optim.pth' % curr_iter))

            if curr_iter == args['iter_num']:
                torch.save(
                    net.state_dict(),
                    os.path.join(ckpt_path, exp_name, '%d.pth' % curr_iter))
                torch.save(
                    optimizer.state_dict(),
                    os.path.join(ckpt_path, exp_name,
                                 '%d_optim.pth' % curr_iter))
                return
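The labels.narrow(0, start, length) calls above align each prediction head with the frames it covers: narrow takes `length` rows of dimension 0 starting at `start`, so for a 6-frame clip outputs1 is matched against frames 1-5, outputs2 against frames 2-5, and so on. Equivalent plain slicing, for reference:

import torch

labels = torch.randn(6, 1, 8, 8)  # a 6-frame clip of masks
assert torch.equal(labels.narrow(0, 1, 5), labels[1:6])
assert torch.equal(labels.narrow(0, 2, 4), labels[2:6])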
Example #16
def train(net, optimizer):
    global best_ber
    curr_iter = 1
    start_time = time.time()

    for epoch in range(args['last_epoch'] + 1,
                       args['last_epoch'] + 1 + args['epoch_num']):
        loss_4_record, loss_3_record, loss_2_record, loss_1_record, \
        loss_record = AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter()

        train_iterator = tqdm(train_loader, total=len(train_loader))
        for data in train_iterator:
            if args['poly_train']:
                base_lr = args['lr'] * (1 - float(curr_iter) /
                                        float(total_epoch)) ** args['lr_decay']
                optimizer.param_groups[0]['lr'] = 2 * base_lr
                optimizer.param_groups[1]['lr'] = 1 * base_lr
            else:
                base_lr = optimizer.param_groups[1]['lr']  # keeps the log line defined

            inputs, labels = data
            batch_size = inputs.size(0)
            inputs = Variable(inputs).cuda(device_ids[0])
            labels = Variable(labels).cuda(device_ids[0])

            optimizer.zero_grad()

            predict_4, predict_3, predict_2, predict_1 = net(inputs)

            loss_4 = L.lovasz_hinge(predict_4, labels)
            loss_3 = L.lovasz_hinge(predict_3, labels)
            loss_2 = L.lovasz_hinge(predict_2, labels)
            loss_1 = L.lovasz_hinge(predict_1, labels)

            loss = loss_4 + loss_3 + loss_2 + loss_1

            loss.backward()

            optimizer.step()

            loss_record.update(loss.data, batch_size)
            loss_4_record.update(loss_4.data, batch_size)
            loss_3_record.update(loss_3.data, batch_size)
            loss_2_record.update(loss_2.data, batch_size)
            loss_1_record.update(loss_1.data, batch_size)

            if curr_iter % 50 == 0:
                writer.add_scalar('loss', loss, curr_iter)
                writer.add_scalar('loss_4', loss_4, curr_iter)
                writer.add_scalar('loss_3', loss_3, curr_iter)
                writer.add_scalar('loss_2', loss_2, curr_iter)
                writer.add_scalar('loss_1', loss_1, curr_iter)

            log = '[%3d], [%6d], [%.6f], [%.5f], [L4: %.5f], [L3: %.5f], [L2: %.5f], [L1: %.5f]' % \
                  (epoch, curr_iter, base_lr, loss_record.avg, loss_4_record.avg, loss_3_record.avg, loss_2_record.avg,
                   loss_1_record.avg)
            train_iterator.set_description(log)
            open(log_path, 'a').write(log + '\n')

            curr_iter += 1

        if epoch in args['save_point']:
            net.cpu()
            torch.save(net.state_dict(),
                       os.path.join(ckpt_path, exp_name, '%d.pth' % epoch))
            net.cuda(device_ids[0])

        if epoch >= args['epoch_thres'] and epoch % 5 == 0:
            ber = test(net)
            print("mean ber of %d epoch is %.5f" % (epoch, ber))
            if ber < best_ber:
                net.cpu()
                torch.save(
                    net.state_dict(),
                    os.path.join(ckpt_path, exp_name,
                                 'epoch_%d_ber_%.2f.pth' % (epoch, ber)))
                print("The optimized epoch is %04d" % epoch)
            net = net.cuda(device_ids[0]).train()

        if epoch >= args['epoch_num']:
            net.cpu()
            torch.save(net.state_dict(),
                       os.path.join(ckpt_path, exp_name, '%d.pth' % epoch))
            print("Total Training Time: {}".format(
                str(datetime.timedelta(seconds=int(time.time() -
                                                   start_time)))))
            print(exp_name)
            print("Optimization Have Done!")
            return
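test(net) returns a mean BER (balanced error rate), the usual metric for this kind of binary segmentation; its implementation is not part of this listing. A hedged sketch for a single binary prediction/ground-truth pair:

import numpy as np

def ber(pred, gt):
    # pred, gt: binary arrays; BER = 0.5 * (FN/Npos + FP/Nneg) * 100.
    pos, neg = gt == 1, gt == 0
    fn = np.logical_and(pred == 0, pos).sum()
    fp = np.logical_and(pred == 1, neg).sum()
    return 50.0 * (fn / max(pos.sum(), 1) + fp / max(neg.sum(), 1))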
Example #17
def main():
    # net = R3Net(motion='', se_layer=False, dilation=False, basic_model='resnet50')

    net = SNet(cfg=None)

    print('load snapshot \'%s\' for testing' % args['snapshot'])
    # net.load_state_dict(torch.load('pretrained/R2Net.pth', map_location='cuda:2'))
    # net = load_part_of_model2(net, 'pretrained/R2Net.pth', device_id=2)
    net.load_state_dict(
        torch.load(os.path.join(ckpt_path, exp_name,
                                args['snapshot'] + '.pth'),
                   map_location='cuda:2'))
    net.eval()
    net.cuda()
    results = {}

    with torch.no_grad():

        for name, root in to_test.items():

            precision_record = [AvgMeter() for _ in range(256)]
            recall_record = [AvgMeter() for _ in range(256)]
            mae_record = AvgMeter()

            if args['save_results']:
                check_mkdir(
                    os.path.join(
                        ckpt_path, exp_name,
                        '(%s) %s_%s' % (exp_name, name, args['snapshot'])))
            img_list = [i_id.strip() for i_id in open(imgs_path)]
            video = ''
            pre_predict = None
            for idx, img_name in enumerate(img_list):
                print('predicting for %s: %d / %d' %
                      (name, idx + 1, len(img_list)))
                print(img_name)
                if video != img_name.split('/')[0]:
                    video = img_name.split('/')[0]
                    if name == 'VOS' or name == 'DAVSOD':
                        img = Image.open(os.path.join(root, img_name +
                                                      '.png')).convert('RGB')
                    else:
                        img = Image.open(os.path.join(root, img_name +
                                                      '.jpg')).convert('RGB')
                    shape = img.size
                    img = img.resize(args['input_size'])
                    # inference under torch.no_grad(); volatile= is obsolete
                    img_var = img_transform(img).unsqueeze(0).cuda()
                    start = time.time()
                    if args['model'] == 'BASNet':
                        prediction, _, prediction2, _, _, _, _, _ = net(
                            img_var)
                        prediction = torch.sigmoid(prediction)
                    elif args['model'] == 'R3Net':
                        prediction = net(img_var)
                    elif args['model'] == 'DSSNet':
                        select = [1, 2, 3, 6]
                        prediction = net(img_var)
                        prediction = torch.mean(torch.cat(
                            [torch.sigmoid(prediction[i]) for i in select],
                            dim=1),
                                                dim=1,
                                                keepdim=True)
                    elif args['model'] == 'CPD':
                        prediction2, prediction = net(img_var)
                        prediction = torch.sigmoid(prediction)
                    elif args['model'] == 'RAS':
                        prediction, _, _, _, _ = net(img_var)
                        prediction = torch.sigmoid(prediction)
                    elif args['model'] == 'PoolNet':
                        prediction = net(img_var)
                        prediction = torch.sigmoid(prediction)
                    elif args['model'] == 'F3Net':
                        prediction2, prediction, _, _, _, _ = net(img_var)
                        prediction = torch.sigmoid(prediction)
                    elif args['model'] == 'R2Net':
                        _, _, _, _, _, prediction = net(img_var)
                        prediction = torch.sigmoid(prediction)
                    end = time.time()
                    pre_predict = prediction
                    print('running time:', (end - start))
                else:
                    if name == 'VOS' or name == 'DAVSOD':
                        img = Image.open(os.path.join(root, img_name +
                                                      '.png')).convert('RGB')
                    else:
                        img = Image.open(os.path.join(root, img_name +
                                                      '.jpg')).convert('RGB')
                    shape = img.size
                    img = img.resize(args['input_size'])
                    img_var = img_transform(img).unsqueeze(0).cuda()
                    start = time.time()

                    _, prediction, _, _, _, _ = net(img_var)

                    end = time.time()
                    print('running time:', (end - start))
                    pre_predict = prediction
                # e = Erosion2d(1, 1, 5, soft_max=False).cuda()
                # prediction2 = e(prediction)
                #
                # precision2 = to_pil(prediction2.data.squeeze(0).cpu())
                # precision2 = prediction2.data.squeeze(0).cpu().numpy()
                # precision2 = precision2.resize(shape)
                # prediction2 = np.array(precision2)
                # prediction2 = prediction2.astype('float')

                precision = to_pil(prediction.data.squeeze(0).cpu())
                precision = precision.resize(shape)
                prediction = np.array(precision)
                prediction = prediction.astype('float')

                # plt.style.use('classic')
                # plt.subplot(1, 2, 1)
                # plt.imshow(prediction)
                # plt.subplot(1, 2, 2)
                # plt.imshow(precision2[0])
                # plt.show()

                prediction = MaxMinNormalization(prediction, prediction.max(),
                                                 prediction.min()) * 255.0
                prediction = prediction.astype('uint8')
                # if args['crf_refine']:
                #     prediction = crf_refine(np.array(img), prediction)

                gt = np.array(
                    Image.open(os.path.join(gt_root,
                                            img_name + '.png')).convert('L'))
                precision, recall, mae = cal_precision_recall_mae(
                    prediction, gt)
                for pidx, pdata in enumerate(zip(precision, recall)):
                    p, r = pdata
                    precision_record[pidx].update(p)
                    recall_record[pidx].update(r)
                mae_record.update(mae)

                if args['save_results']:
                    folder, sub_name = os.path.split(img_name)
                    save_path = os.path.join(
                        ckpt_path, exp_name,
                        '(%s) %s_%s' % (exp_name, name, args['snapshot']),
                        folder)
                    if not os.path.exists(save_path):
                        os.makedirs(save_path)
                    Image.fromarray(prediction).save(
                        os.path.join(save_path, sub_name + '.png'))

            fmeasure = cal_fmeasure(
                [precord.avg for precord in precision_record],
                [rrecord.avg for rrecord in recall_record])

            results[name] = {'fmeasure': fmeasure, 'mae': mae_record.avg}

    print('test results:')
    print(results)
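MaxMinNormalization is another repository utility not shown here; judging from its call site (value, max, min) it is presumably a plain min-max rescale into [0, 1], e.g.:

def MaxMinNormalization(x, x_max, x_min):
    # Rescale to [0, 1]; the caller multiplies by 255 and casts to uint8.
    return (x - x_min) / (x_max - x_min + 1e-8)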
Example #18
def train(net, optimizer):
    curr_iter = 1

    for epoch in range(args['last_epoch'] + 1, args['last_epoch'] + 1 + args['epoch_num']):
        loss_record, loss_f_record, loss_b_record, loss_o_record = AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter()

        train_iterator = tqdm(train_loader, total=len(train_loader))
        for data in train_iterator:
            if args['poly_train']:
                base_lr = args['lr'] * (1 - float(curr_iter) / (args['epoch_num'] * len(train_loader))) ** args[
                    'lr_decay']
                optimizer.param_groups[0]['lr'] = 2 * base_lr
                optimizer.param_groups[1]['lr'] = 1 * base_lr

            inputs, labels = data
            batch_size = inputs.size(0)
            inputs = Variable(inputs).cuda(device_ids[0])
            labels = Variable(labels).cuda(device_ids[0])

            optimizer.zero_grad()

            predict_f, predict_b, predict_o = net(inputs)

            loss_f = L.lovasz_hinge(predict_f, labels)
            loss_b = L.lovasz_hinge(predict_b, 1 - labels)
            loss_o = 2 * L.lovasz_hinge(predict_o, labels)

            loss = loss_f + loss_b + loss_o

            loss.backward()

            optimizer.step()

            loss_record.update(loss.data, batch_size)
            loss_f_record.update(loss_f.data, batch_size)
            loss_b_record.update(loss_b.data, batch_size)
            loss_o_record.update(loss_o.data, batch_size)

            if curr_iter % 50 == 0:
                writer.add_scalar('loss', loss, curr_iter)
                writer.add_scalar('loss_f', loss_f, curr_iter)
                writer.add_scalar('loss_b', loss_b, curr_iter)
                writer.add_scalar('loss_o', loss_o, curr_iter)

            log = '[Epoch: %2d], [Iter: %5d], [%.7f], [Sum: %.5f], [Lf: %.5f], [Lb: %.5f], [Lo: %.5f]' % \
                  (epoch, curr_iter, base_lr, loss_record.avg, loss_f_record.avg, loss_b_record.avg, loss_o_record.avg)
            train_iterator.set_description(log)
            open(log_path, 'a').write(log + '\n')

            curr_iter += 1

        if epoch in args['save_point']:
            net.cpu()
            torch.save(net.module.state_dict(), os.path.join(ckpt_path, exp_name, '%d.pth' % epoch))
            net.cuda(device_ids[0])

        if epoch >= args['epoch_num']:
            net.cpu()
            torch.save(net.module.state_dict(), os.path.join(ckpt_path, exp_name, '%d.pth' % epoch))
            print("Optimization Have Done!")
            return
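This example saves net.module.state_dict(), which only exists when net is wrapped in nn.DataParallel (as the device_ids list suggests). A guard that works either way:

import torch.nn as nn

state = net.module.state_dict() if isinstance(net, nn.DataParallel) \
    else net.state_dict()
torch.save(state, os.path.join(ckpt_path, exp_name, '%d.pth' % epoch))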
Example #19
def train(net, optimizer):
    net.print_network()
    curr_iter = args['last_iter']
    while True:
        train_loss_record, loss_fuse_record, loss1_h2l_record = \
            AvgMeter(), AvgMeter(), AvgMeter()
        loss2_h2l_record, loss3_h2l_record, loss4_h2l_record = \
            AvgMeter(), AvgMeter(), AvgMeter()
        loss1_l2h_record, loss2_l2h_record, loss3_l2h_record = \
            AvgMeter(), AvgMeter(), AvgMeter()
        loss4_l2h_record = AvgMeter()

        for i, data in enumerate(train_loader):
            optimizer.param_groups[0]['lr'] = 2 * args['lr'] * (
                1 - float(curr_iter) / args['iter_num'])**args['lr_decay']
            optimizer.param_groups[1]['lr'] = args['lr'] * (
                1 - float(curr_iter) / args['iter_num'])**args['lr_decay']

            inputs, labels = data
            batch_size = inputs.size(0)
            inputs = Variable(inputs).cuda()
            labels = Variable(labels).cuda()

            optimizer.zero_grad()

            fuse_predict, predict1_h2l, predict2_h2l, predict3_h2l, predict4_h2l, \
            predict1_l2h, predict2_l2h, predict3_l2h, predict4_l2h = net(inputs)

            loss_fuse = bce_logit(fuse_predict, labels)
            loss1_h2l = bce_logit(predict1_h2l, labels)
            loss2_h2l = bce_logit(predict2_h2l, labels)
            loss3_h2l = bce_logit(predict3_h2l, labels)
            loss4_h2l = bce_logit(predict4_h2l, labels)
            loss1_l2h = bce_logit(predict1_l2h, labels)
            loss2_l2h = bce_logit(predict2_l2h, labels)
            loss3_l2h = bce_logit(predict3_l2h, labels)
            loss4_l2h = bce_logit(predict4_l2h, labels)

            loss = loss_fuse + loss1_h2l + loss2_h2l + loss3_h2l + loss4_h2l + loss1_l2h + \
                   loss2_l2h + loss3_l2h + loss4_l2h
            loss.backward()

            optimizer.step()

            train_loss_record.update(loss.data, batch_size)
            loss_fuse_record.update(loss_fuse.data, batch_size)
            loss1_h2l_record.update(loss1_h2l.data, batch_size)
            loss2_h2l_record.update(loss2_h2l.data, batch_size)
            loss3_h2l_record.update(loss3_h2l.data, batch_size)
            loss4_h2l_record.update(loss4_h2l.data, batch_size)
            loss1_l2h_record.update(loss1_l2h.data, batch_size)
            loss2_l2h_record.update(loss2_l2h.data, batch_size)
            loss3_l2h_record.update(loss3_l2h.data, batch_size)
            loss4_l2h_record.update(loss4_l2h.data, batch_size)

            curr_iter += 1

            log = '[iter %d], [train loss %.5f], [loss_fuse %.5f], [loss1_h2l %.5f], [loss2_h2l %.5f], ' \
                  '[loss3_h2l %.5f], [loss4_h2l %.5f], [loss1_l2h %.5f], [loss2_l2h %.5f], [loss3_l2h %.5f], ' \
                  '[loss4_l2h %.5f], [lr %.13f]' % \
                  (curr_iter, train_loss_record.avg, loss_fuse_record.avg, loss1_h2l_record.avg, loss2_h2l_record.avg,
                   loss3_h2l_record.avg, loss4_h2l_record.avg, loss1_l2h_record.avg, loss2_l2h_record.avg,
                   loss3_l2h_record.avg, loss4_l2h_record.avg, optimizer.param_groups[1]['lr'])
            print(log)
            open(log_path, 'a').write(log + '\n')

            if curr_iter > args['iter_num']:
                torch.save(
                    net.state_dict(),
                    os.path.join(ckpt_path, exp_name, '%d.pth' % curr_iter))
                return
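bce_logit is used as the criterion but never defined in these listings; given its name and the raw (un-sigmoided) predictions it receives, it is presumably binary cross-entropy on logits:

import torch.nn as nn

# Assumption: matches the bce_logit used above.
bce_logit = nn.BCEWithLogitsLoss()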
Example #20
def main():
    net = R3Net_prior(motion='GRU', se_layer=False, st_fuse=False)

    print('load snapshot \'%s\' for testing' % args['snapshot'])
    net.load_state_dict(
        torch.load(os.path.join(ckpt_path, exp_name,
                                args['snapshot'] + '.pth'),
                   map_location='cuda:0'))
    # net = train_online(net)
    results = {}

    for name, root in to_test.items():

        precision_record = [AvgMeter() for _ in range(256)]
        recall_record = [AvgMeter() for _ in range(256)]
        mae_record = AvgMeter()

        if args['save_results']:
            check_mkdir(
                os.path.join(ckpt_path, exp_name, '(%s) %s_%s' %
                             (exp_name, name, args['snapshot'])))

        folders = os.listdir(root)
        folders.sort()
        for folder in folders:
            net = train_online(net, seq_name=folder)
            with torch.no_grad():

                net.eval()
                net.cuda()
                imgs = os.listdir(os.path.join(root, folder))
                imgs.sort()
                for i in range(1, len(imgs) - args['batch_size'] + 1):
                    print(imgs[i])
                    img_var = []
                    img_names = []
                    for j in range(0, args['batch_size']):
                        img = Image.open(
                            os.path.join(root, folder,
                                         imgs[i + j])).convert('RGB')
                        img_names.append(imgs[i + j])
                        shape = img.size
                        img = img.resize(args['input_size'])
                        img_var.append(
                            img_transform(img).unsqueeze(0).cuda())

                    img_var = torch.cat(img_var, dim=0)
                    prediction = net(img_var)
                    precision = to_pil(prediction.data.squeeze(0).cpu())
                    precision = precision.resize(shape)
                    prediction = np.array(precision)

                    if args['crf_refine']:
                        prediction = crf_refine(np.array(img), prediction)
                    gt = np.array(
                        Image.open(
                            os.path.join(gt_root, folder, img_names[-1][:-4] +
                                         '.png')).convert('L'))
                    precision, recall, mae = cal_precision_recall_mae(
                        prediction, gt)
                    for pidx, pdata in enumerate(zip(precision, recall)):
                        p, r = pdata
                        precision_record[pidx].update(p)
                        recall_record[pidx].update(r)
                    mae_record.update(mae)

                    if args['save_results']:
                        # folder, sub_name = os.path.split(img_names[-1])
                        save_path = os.path.join(
                            ckpt_path, exp_name,
                            '(%s) %s_%s' % (exp_name, name, args['snapshot']),
                            folder)
                        if not os.path.exists(save_path):
                            os.makedirs(save_path)
                        Image.fromarray(prediction).save(
                            os.path.join(save_path,
                                         img_names[-1][:-4] + '.png'))

        fmeasure = cal_fmeasure([precord.avg for precord in precision_record],
                                [rrecord.avg for rrecord in recall_record])

        results[name] = {'fmeasure': fmeasure, 'mae': mae_record.avg}

    print('test results:')
    print(results)
Example #21
        # [Truncated listing: the opening of this example (the loops over
        #  `folder` and `img`, and the `image` variable) was lost; the
        #  dangling call below was loading the prediction map.]
        pred = Image.open(
            os.path.join(root_inference, folder,
                         img[:-4] + '.png')).convert('L')
        gt = Image.open(os.path.join(gt_root, folder,
                                     img[:-4] + '.png')).convert('L')
        gt = gt.resize(pred.size)
        image = image.resize(pred.size)
        gt = np.array(gt)
        pred = np.array(pred)

        precision, recall, mae = cal_precision_recall_mae(pred, gt)

        for pidx, pdata in enumerate(zip(precision, recall)):
            p, r = pdata
            precision_record[pidx].update(p)
            recall_record[pidx].update(r)
        mae_record.update(mae)

fmeasure = cal_fmeasure([precord.avg for precord in precision_record],
                        [rrecord.avg for rrecord in recall_record])

results[name] = {'fmeasure': fmeasure, 'mae': mae_record.avg}

print('test results:')
print(results)

# {'davis': {'mae': 0.041576569176772944, 'fmeasure': 0.8341383096984007}}
# {'MSST_davis': {'fmeasure': 0.8175943834081874, 'mae': 0.04597473876855389}}
# {'Amulet_davis': {'mae': 0.08374974551689243, 'fmeasure': 0.7234079968968813}}
# {'CG_davis': {'fmeasure': 0.6278087775523111, 'mae': 0.09568971798828023}}
# {'CS_davis': {'fmeasure': 0.387371123540425, 'mae': 0.11592338609834756}}
# {'DCL_davis': {'fmeasure': 0.7555328232313439, 'mae': 0.1325773794024856}}
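cal_fmeasure (and cal_fmeasure_both earlier) aggregate the 256 per-threshold precision/recall meters into an F-measure. A hedged sketch of the usual saliency convention, the maximum F-beta with beta^2 = 0.3:

def cal_fmeasure(precisions, recalls, beta2=0.3):
    # Assumption: mirrors the repository utility; max F-beta over the
    # 256 binarisation thresholds.
    return max((1 + beta2) * p * r / (beta2 * p + r + 1e-8)
               for p, r in zip(precisions, recalls))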
Example #22
def validate(g, curr_epoch, d=None):
    g.eval()

    mse_criterion = nn.MSELoss()
    g_mse_loss_record, psnr_record = AvgMeter(), AvgMeter()

    for name, loader in val_loader.items():

        val_visual = []
        # note that the batch size is 1
        for i, data in enumerate(loader):
            hr_img, _ = data

            lr_img, hr_interpolated_img = val_lr_transform(hr_img.squeeze(0))

            lr_img = lr_img.unsqueeze(0).cuda()
            hr_img = hr_img.cuda()

            # Inference only: volatile= is gone from PyTorch; no_grad() gives
            # the same behaviour.
            with torch.no_grad():
                gen_hr_img = g(lr_img)
                g_mse_loss = mse_criterion(gen_hr_img, hr_img)

            g_mse_loss_record.update(g_mse_loss.item())
            psnr_record.update(10 * np.log10(1 / g_mse_loss.item()))

            val_visual.extend([
                val_display_transform(hr_interpolated_img),
                val_display_transform(hr_img.cpu().data.squeeze(0)),
                val_display_transform(gen_hr_img.cpu().data.squeeze(0))
            ])

        val_visual = torch.stack(val_visual, 0)
        val_visual = vutils.make_grid(val_visual, nrow=3, padding=5)

        snapshot_name = 'epoch_%d_%s_g_mse_loss_%.5f_psnr_%.5f' % (
            curr_epoch + 1, name, g_mse_loss_record.avg, psnr_record.avg)

        if d is None:
            snapshot_name = 'pretrain_' + snapshot_name
            writer.add_scalar('pretrain_validate_%s_psnr' % name,
                              psnr_record.avg, curr_epoch + 1)
            writer.add_scalar('pretrain_validate_%s_g_mse_loss' % name,
                              g_mse_loss_record.avg, curr_epoch + 1)

            print(
                '[pretrain validate %s]: [epoch %d], [g_mse_loss %.5f], [psnr %.5f]'
                %
                (name, curr_epoch + 1, g_mse_loss_record.avg, psnr_record.avg))
        else:
            writer.add_scalar('validate_%s_psnr' % name, psnr_record.avg,
                              curr_epoch + 1)
            writer.add_scalar('validate_%s_g_mse_loss' % name,
                              g_mse_loss_record.avg, curr_epoch + 1)

            print(
                '[validate %s]: [epoch %d], [g_mse_loss %.5f], [psnr %.5f]' %
                (name, curr_epoch + 1, g_mse_loss_record.avg, psnr_record.avg))

            torch.save(
                d.state_dict(),
                os.path.join(train_args['ckpt_path'],
                             snapshot_name + '_d.pth'))

        torch.save(
            g.state_dict(),
            os.path.join(train_args['ckpt_path'], snapshot_name + '_g.pth'))

        writer.add_image(snapshot_name, val_visual)

        g_mse_loss_record.reset()
        psnr_record.reset()

    g.train()
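The PSNR logged above, 10 * log10(1 / mse), assumes images scaled to [0, 1], i.e. a peak value of 1. The general form, for reference:

import numpy as np

def psnr(mse, peak=1.0):
    # PSNR in dB; with peak=1 this reduces to 10*log10(1/mse) as above.
    return 10 * np.log10(peak ** 2 / mse)

assert abs(psnr(0.01) - 20.0) < 1e-9  # mse of 0.01 on [0,1] images -> 20 dB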