Example #1
def validate(net, curr_iter, optimizer):
    print('validating...')
    net.eval()

    loss_record1 = AvgMeter()
    iter_num1 = len(test1_loader)

    with torch.no_grad():
        for i, data in enumerate(test1_loader):
            inputs, gts = data
            inputs = Variable(inputs).cuda()
            gts = Variable(gts).cuda()

            res = net(inputs)

            loss = criterion(res, gts)
            loss_record1.update(loss.data, inputs.size(0))

            print('processed test1 %d / %d' % (i + 1, iter_num1))

    # snapshot_name = 'iter_%d_loss1_%.5f_lr_%.6f' % (curr_iter + 1, loss_record1.avg,
    #                                                            optimizer.param_groups[1]['lr'])

    log_val = '[validate]: [iter %d], [loss1 %.5f]' % (curr_iter + 1,
                                                       loss_record1.avg)
    print(log_val)
    open(log_path_val, 'a').write(log_val + '\n')

    # torch.save(net.state_dict(), os.path.join(ckpt_path, exp_name, snapshot_name + '.pth'))
    # torch.save(optimizer.state_dict(), os.path.join(ckpt_path, exp_name, snapshot_name + '_optim.pth'))

    net.train()
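
Every example in this collection leans on an AvgMeter running-average helper imported from the project's misc module (the import is visible in Example #13). Its definition isn't included here; a minimal sketch consistent with the call sites (update(val, n=1), .avg, reset()) would be:

class AvgMeter(object):
    """Running average; a minimal sketch of the misc.AvgMeter helper
    used throughout these examples (interface inferred from call sites)."""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count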
Example #2
def main():
    net = R3Net().cuda()

    print('load snapshot \'%s\' for testing' % args['snapshot'])
    net.load_state_dict(
        torch.load(os.path.join(ckpt_path, exp_name,
                                args['snapshot'] + '.pth')))
    net.eval()

    results = {}

    with torch.no_grad():

        for name, root in to_test.items():

            precision_record, recall_record, = [
                AvgMeter() for _ in range(256)
            ], [AvgMeter() for _ in range(256)]
            mae_record = AvgMeter()

            if args['save_results']:
                check_mkdir(
                    os.path.join(
                        ckpt_path, exp_name,
                        '(%s) %s_%s' % (exp_name, name, args['snapshot'])))

            img_list = os.listdir(root)
            print(img_list)
            for idx, img_name in enumerate(img_list):
                print('predicting for %s: %d / %d' % (name, idx + 1,
                                                      len(img_list)))
                #print img_name
                img_path = os.path.join(root, img_name)
                img = Image.open(img_path).convert('RGB')
                # volatile=True is redundant under torch.no_grad() and was
                # removed in PyTorch 0.4+
                img_var = Variable(img_transform(img).unsqueeze(0)).cuda()
                prediction = net(img_var)
                prediction = np.array(to_pil(prediction.data.squeeze(0).cpu()))

                if args['crf_refine']:
                    prediction = crf_refine(np.array(img), prediction)

                #gt = np.array(Image.open(os.path.join(root+"/masks", img_name )).convert('L'))
                #precision, recall, mae = cal_precision_recall_mae(prediction, gt)
                #for pidx, pdata in enumerate(zip(precision, recall)):
                #    p, r = pdata
                #    precision_record[pidx].update(p)
                #    recall_record[pidx].update(r)
                #mae_record.update(mae)

                if args['save_results']:
                    Image.fromarray(prediction).save(
                        os.path.join(
                            ckpt_path, exp_name, '%s_%s_%s' %
                            (exp_name, name, args['snapshot'] + "kaist_test"),
                            img_name))
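
check_mkdir comes from the same misc module. A plausible equivalent, assuming it simply creates the output directory when it is missing:

import os

def check_mkdir(dir_name):
    # assumed behavior: create the directory if it does not exist yet
    if not os.path.exists(dir_name):
        os.makedirs(dir_name)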
Example #3
File: train.py Project: xw-hu/R3Net
def train(net, optimizer):
    curr_iter = args['last_iter']
    while True:
        total_loss_record, loss0_record, loss1_record, loss2_record = AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter()
        loss3_record, loss4_record, loss5_record, loss6_record = AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter()

        for i, data in enumerate(train_loader):
            optimizer.param_groups[0]['lr'] = 2 * args['lr'] * (1 - float(curr_iter) / args['iter_num']
                                                                ) ** args['lr_decay']
            optimizer.param_groups[1]['lr'] = args['lr'] * (1 - float(curr_iter) / args['iter_num']
                                                            ) ** args['lr_decay']

            inputs, labels = data
            batch_size = inputs.size(0)
            inputs = Variable(inputs).cuda()
            labels = Variable(labels).cuda()

            optimizer.zero_grad()
            outputs0, outputs1, outputs2, outputs3, outputs4, outputs5, outputs6 = net(inputs)
            loss0 = criterion(outputs0, labels)
            loss1 = criterion(outputs1, labels)
            loss2 = criterion(outputs2, labels)
            loss3 = criterion(outputs3, labels)
            loss4 = criterion(outputs4, labels)
            loss5 = criterion(outputs5, labels)
            loss6 = criterion(outputs6, labels)

            total_loss = loss0 + loss1 + loss2 + loss3 + loss4 + loss5 + loss6
            total_loss.backward()
            optimizer.step()

            total_loss_record.update(total_loss.data[0], batch_size)
            loss0_record.update(loss0.data[0], batch_size)
            loss1_record.update(loss1.data[0], batch_size)
            loss2_record.update(loss2.data[0], batch_size)
            loss3_record.update(loss3.data[0], batch_size)
            loss4_record.update(loss4.data[0], batch_size)
            loss5_record.update(loss5.data[0], batch_size)
            loss6_record.update(loss6.data[0], batch_size)

            curr_iter += 1

            log = '[iter %d], [total loss %.5f], [loss0 %.5f], [loss1 %.5f], [loss2 %.5f], [loss3 %.5f], ' \
                  '[loss4 %.5f], [loss5 %.5f], [loss6 %.5f], [lr %.13f]' % \
                  (curr_iter, total_loss_record.avg, loss0_record.avg, loss1_record.avg, loss2_record.avg,
                   loss3_record.avg, loss4_record.avg, loss5_record.avg, loss6_record.avg,
                   optimizer.param_groups[1]['lr'])
            print(log)
            open(log_path, 'a').write(log + '\n')

            if curr_iter == args['iter_num']:
                torch.save(net.state_dict(), os.path.join(ckpt_path, exp_name, '%d.pth' % curr_iter))
                torch.save(optimizer.state_dict(),
                           os.path.join(ckpt_path, exp_name, '%d_optim.pth' % curr_iter))
                return
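
The two param_groups assignments at the top of every train() here implement the "poly" learning-rate schedule, with the bias group (group 0) running at twice the base rate. A self-contained sketch of the formula and its use (the dummy parameter and numbers are illustrative only):

import torch
from torch import nn, optim

def poly_lr(base_lr, curr_iter, max_iter, power):
    # 'poly' decay: the rate falls from base_lr to 0 as curr_iter -> max_iter
    return base_lr * (1 - float(curr_iter) / max_iter) ** power

# toy demonstration on a dummy parameter; the loops above keep two
# param groups and set the bias group (group 0) to twice this value
w = nn.Parameter(torch.zeros(1))
opt = optim.SGD([w], lr=1e-3, momentum=0.9)
for it in (0, 15000, 29999):
    opt.param_groups[0]['lr'] = poly_lr(1e-3, it, 30000, 0.9)
    print(it, opt.param_groups[0]['lr'])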
Example #4
def train(net, optimizer):
    curr_iter = args['last_iter']
    while True:
        total_loss_record, loss0_record, loss1_record, loss2_record = AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter()
        for i, data in enumerate(train_loader):
            optimizer.param_groups[0]['lr'] = 2 * args['lr'] * (1 - float(curr_iter) / args['iter_num']
                                                                ) ** args['lr_decay']
            optimizer.param_groups[1]['lr'] = args['lr'] * (1 - float(curr_iter) / args['iter_num']
                                                            ) ** args['lr_decay']

            inputs, depth, labels, img_name = data
            batch_size = inputs.size(0)
            inputs = Variable(inputs).cuda()
            depth = Variable(depth).cuda()
            labels = Variable(labels).cuda()

            optimizer.zero_grad()
            depthMap, salMap, sideout5, sideout4, sideout3, sideout2 = net(inputs, depth)

            loss0 = criterion(depthMap, labels)
            loss1 = criterion(salMap, labels)
            loss5 = criterion(sideout5, labels)
            loss4 = criterion(sideout4, labels)
            loss3 = criterion(sideout3, labels)
            loss2 = criterion(sideout2, labels)

            total_loss = loss0 + loss1 + loss5 + loss4 + loss3 + loss2
            total_loss.backward()
            optimizer.step()

            total_loss_record.update(total_loss.item(), batch_size)
            loss0_record.update(loss0.item(), batch_size)
            loss1_record.update(loss1.item(), batch_size)

            curr_iter += 1

            log = '[iter %d], [total loss %.5f], [loss0 %.5f],[loss1 %.5f],[lr %.13f]' % \
                  (curr_iter, total_loss_record.avg, loss0_record.avg,  loss1_record.avg,
                   optimizer.param_groups[1]['lr'])
            print(log)
            open(log_path, 'a').write(log + '\n')

            if curr_iter == args['iter_num']:
                torch.save(net.state_dict(), os.path.join(ckpt_path, exp_name, '%d_res18_noCross_Gdept.pth' % curr_iter))
                return
Example #5
def train(net, optimizer):
    curr_iter = args['last_iter']
    while True:
        total_loss_record, loss1_record, loss2_record, loss3_record, loss4_record, loss5_record, loss6_record, loss7_record, loss8_record = AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter()
        for i, data in enumerate(train_loader):
            optimizer.param_groups[0]['lr'] = 2 * args['lr'] * (1 - float(curr_iter) / args['iter_num']
                                                                ) ** args['lr_decay']
            optimizer.param_groups[1]['lr'] = args['lr'] * (1 - float(curr_iter) / args['iter_num']
                                                            ) ** args['lr_decay']
            # data\binarizing\Variable
            inputs, labels = data
            labels[labels > 0.5] = 1
            labels[labels != 1] = 0
            batch_size = inputs.size(0)
            inputs = Variable(inputs).cuda()
            labels = Variable(labels).cuda()
            optimizer.zero_grad()
            output_fpn, output_final = net(inputs)
            ##########loss#############

            loss1 = criterion(output_fpn, labels)
            loss2 = criterion(output_final, labels)

            total_loss = loss1+loss2
            total_loss.backward()
            optimizer.step()
            total_loss_record.update(total_loss.item(), batch_size)
            loss1_record.update(loss1.item(), batch_size)
            loss2_record.update(loss2.item(), batch_size)

            #############log###############
            curr_iter += 1

            log = '[iter %d], [total loss %.5f], [loss1 %.5f], [loss2 %.5f], [lr %.13f]' % \
                  (curr_iter, total_loss_record.avg, loss1_record.avg, loss2_record.avg, optimizer.param_groups[1]['lr'])
            print(log)
            open(log_path, 'a').write(log + '\n')

            if curr_iter == args['iter_num']:
                torch.save(net.state_dict(), os.path.join(ckpt_path, exp_name, '%d.pth' % curr_iter))
                torch.save(optimizer.state_dict(),
                           os.path.join(ckpt_path, exp_name, '%d_optim.pth' % curr_iter))
                return
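
The two masked writes on labels above binarize soft ground-truth masks at 0.5. A compact equivalent, shown on a toy tensor:

import torch

labels = torch.tensor([0.2, 0.5, 0.7])
labels = (labels > 0.5).float()  # same result as the two masked writes above
print(labels)                    # tensor([0., 0., 1.])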
Example #6
def train(epoch):
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss_record, bce_loss_record, dice_loss_record = AvgMeter(
    ), AvgMeter(), AvgMeter()
    for batch_idx, data in enumerate(train_loader):
        if epoch == args['lr_step']:
            optimizer.param_groups[0]['lr'] = 2 * args['lr'] / args['lr_decay']
            optimizer.param_groups[1]['lr'] = args['lr'] / args['lr_decay']

        inputs, labels, counter = data
        batch_size = inputs.size(0)
        inputs = Variable(inputs).cuda()
        labels = Variable(labels).cuda()
        counter = Variable(counter).cuda()
        optimizer.zero_grad()
        outputs, outputs_counter = net(inputs)
        # outputs = net(inputs)

        # BCE loss and dice loss can be used
        criterion_bce = nn.BCELoss()
        criterion_dice = Dice_loss()
        # if not isinstance(fnd_out, list):
        loss_bce = criterion_bce(outputs, labels) + criterion_bce(
            outputs_counter, counter)
        loss_dice = criterion_dice(outputs, labels) + criterion_dice(
            outputs_counter, counter)
        # loss_bce = criterion_bce(outputs, labels)
        # loss_dice = criterion_dice(outputs, labels)
        # else:
        #     loss_bce = criterion_bce(outputs, labels)
        #     loss_dice = criterion_dice(outputs, labels)
        #     for fnd_mask, fpd_mask in zip(fnd_out, fpd_out):
        #         loss_bce += criterion_bce(fnd_mask, fnd) + criterion_bce(fpd_mask, fpd)

        # else:
        #     loss_bce_each = [None] * len(outputs)
        #     loss_dice_each = [None] * len(outputs)
        #     for idx in range(len(outputs)):
        #         loss_bce_each[idx] = criterion_bce(outputs[idx], labels)
        #         loss_dice_each[idx] = criterion_dice(outputs[idx], labels)
        #     loss_bce = sum(loss_bce_each)
        #     loss_dice = sum(loss_dice_each)
        coeff = min(loss_dice.item() / loss_bce.item(), 1)
        # coeff = 1
        loss = coeff * loss_bce + loss_dice
        # loss = loss_bce + loss_dice
        loss.backward()
        optimizer.step()
        train_loss_record.update(loss.item(), batch_size)
        bce_loss_record.update(loss_bce.item(), batch_size)
        dice_loss_record.update(loss_dice.item(), batch_size)
        log = 'iter: %d | [bce loss: %.5f], [dice loss: %.5f],[Total loss: %.5f], [lr: %.8f]' % \
              (epoch, bce_loss_record.avg, dice_loss_record.avg, train_loss_record.avg, optimizer.param_groups[1]['lr'])
        progress_bar(batch_idx, len(train_loader), log)
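
Dice_loss is project-specific and not defined in these snippets; a common soft-Dice formulation consistent with how it is called (probability maps in, scalar loss out) is sketched below. It is an assumption, not the repo's verbatim code.

import torch
from torch import nn

class DiceLoss(nn.Module):
    """Soft Dice loss on probabilities; a sketch of a common formulation,
    assumed to approximate the Dice_loss used above."""
    def __init__(self, eps=1e-6):
        super(DiceLoss, self).__init__()
        self.eps = eps

    def forward(self, pred, target):
        pred = pred.flatten(1)
        target = target.flatten(1)
        inter = (pred * target).sum(dim=1)
        union = pred.sum(dim=1) + target.sum(dim=1)
        return (1.0 - (2.0 * inter + self.eps) / (union + self.eps)).mean()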
Example #7
def main():
    net = R3Net().cuda()

    print('load snapshot \'%s\' for testing' % args['snapshot'])
    net.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'] + '.pth')))
    net.eval()

    results = {}

    for name, root in to_test.items():

        precision_record, recall_record, = [AvgMeter() for _ in range(256)], [AvgMeter() for _ in range(256)]
        mae_record = AvgMeter()

        img_list = [os.path.splitext(f)[0] for f in os.listdir(root) if f.endswith('.jpg')]
        for idx, img_name in enumerate(img_list):
            print('predicting for %s: %d / %d' % (name, idx + 1, len(img_list)))
            check_mkdir(os.path.join(ckpt_path, exp_name, '(%s) %s_%s' % (exp_name, name, args['snapshot'])))

            img = Image.open(os.path.join(root, img_name + '.jpg')).convert('RGB')
            img_var = Variable(img_transform(img).unsqueeze(0), volatile=True).cuda()
            prediction = net(img_var)
            prediction = np.array(to_pil(prediction.data.squeeze(0).cpu()))

            if args['crf_refine']:
                prediction = crf_refine(np.array(img), prediction)

            gt = np.array(Image.open(os.path.join(root, img_name + '.png')).convert('L'))
            precision, recall, mae = cal_precision_recall_mae(prediction, gt)
            for pidx, pdata in enumerate(zip(precision, recall)):
                p, r = pdata
                precision_record[pidx].update(p)
                recall_record[pidx].update(r)
            mae_record.update(mae)

            if args['save_results']:
                Image.fromarray(prediction).save(os.path.join(ckpt_path, exp_name, '(%s) %s_%s' % (
                    exp_name, name, args['snapshot']), img_name + '.png'))

        fmeasure = cal_fmeasure([precord.avg for precord in precision_record],
                                [rrecord.avg for rrecord in recall_record])

        results[name] = {'fmeasure': fmeasure, 'mae': mae_record.avg}

    print('test results:')
    print(results)
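
cal_fmeasure reduces the 256 per-threshold precision/recall averages to a single score; in saliency benchmarks this is typically the maximum F-beta with beta² = 0.3. A sketch under that assumption (the name max_fmeasure is hypothetical):

def max_fmeasure(precisions, recalls, beta2=0.3):
    # max F-beta over the 256 binarization thresholds; beta^2 = 0.3 is
    # the usual convention in saliency benchmarks (an assumption here)
    return max((1 + beta2) * p * r / (beta2 * p + r + 1e-12)
               for p, r in zip(precisions, recalls))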
Example #8
def main():
    net = Ensemble(device_id, pretrained=False)

    print ('load snapshot \'%s\' for testing' % args['snapshot'])
    # net.load_state_dict(torch.load('pretrained/R2Net.pth', map_location='cuda:2'))
    # net = load_part_of_model2(net, 'pretrained/R2Net.pth', device_id=2)
    net.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'] + '.pth'),
                                   map_location='cuda:' + str(device_id)))
    net.eval()
    net.cuda()
    results = {}

    with torch.no_grad():

        for name, root in to_test.items():

            precision_record, recall_record, = [AvgMeter() for _ in range(256)], [AvgMeter() for _ in range(256)]
            mae_record = AvgMeter()

            if args['save_results']:
                check_mkdir(os.path.join(ckpt_path, exp_name, '(%s) %s_%s' % (exp_name, name, args['snapshot'])))
            img_list = [i_id.strip() for i_id in open(imgs_path)]
            for idx, img_name in enumerate(img_list):
                print('predicting for %s: %d / %d' % (name, idx + 1, len(img_list)))
                print(img_name)

                if name == 'VOS' or name == 'DAVSOD':
                    img = Image.open(os.path.join(root, img_name + '.png')).convert('RGB')
                else:
                    img = Image.open(os.path.join(root, img_name + '.jpg')).convert('RGB')
                shape = img.size
                img = img.resize(args['input_size'])
                img_var = Variable(img_transform(img).unsqueeze(0)).cuda()
                start = time.time()
                outputs_a, outputs_c = net(img_var)
                a_out1u, a_out2u, a_out2r, a_out3r, a_out4r, a_out5r = outputs_a  # F3Net
                # b_outputs0, b_outputs1 = outputs_b  # CPD
                c_outputs0, c_outputs1, c_outputs2, c_outputs3, c_outputs4 = outputs_c  # RAS
                prediction = torch.sigmoid(c_outputs0)
                end = time.time()
                print('running time:', (end - start))
                # e = Erosion2d(1, 1, 5, soft_max=False).cuda()
                # prediction2 = e(prediction)
                #
                # precision2 = to_pil(prediction2.data.squeeze(0).cpu())
                # precision2 = prediction2.data.squeeze(0).cpu().numpy()
                # precision2 = precision2.resize(shape)
                # prediction2 = np.array(precision2)
                # prediction2 = prediction2.astype('float')

                precision = to_pil(prediction.data.squeeze(0).cpu())
                precision = precision.resize(shape)
                prediction = np.array(precision)
                prediction = prediction.astype('float')

                # plt.style.use('classic')
                # plt.subplot(1, 2, 1)
                # plt.imshow(prediction)
                # plt.subplot(1, 2, 2)
                # plt.imshow(precision2[0])
                # plt.show()

                prediction = MaxMinNormalization(prediction, prediction.max(), prediction.min()) * 255.0
                prediction = prediction.astype('uint8')
                # if args['crf_refine']:
                #     prediction = crf_refine(np.array(img), prediction)

                gt = np.array(Image.open(os.path.join(gt_root, img_name + '.png')).convert('L'))
                precision, recall, mae = cal_precision_recall_mae(prediction, gt)
                for pidx, pdata in enumerate(zip(precision, recall)):
                    p, r = pdata
                    precision_record[pidx].update(p)
                    recall_record[pidx].update(r)
                mae_record.update(mae)

                if args['save_results']:
                    folder, sub_name = os.path.split(img_name)
                    save_path = os.path.join(ckpt_path, exp_name, '(%s) %s_%s' % (exp_name, name, args['snapshot']), folder)
                    if not os.path.exists(save_path):
                        os.makedirs(save_path)
                    Image.fromarray(prediction).save(os.path.join(save_path, sub_name + '.png'))

            fmeasure = cal_fmeasure([precord.avg for precord in precision_record],
                                    [rrecord.avg for rrecord in recall_record])

            results[name] = {'fmeasure': fmeasure, 'mae': mae_record.avg}

    print ('test results:')
    print (results)
    log_path = 'result_all.txt'
    open(log_path, 'a').write(exp_name + ' ' + args['snapshot'] + '\n')
    open(log_path, 'a').write(str(results) + '\n\n')
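
MaxMinNormalization is not defined in these snippets either; since its output is multiplied by 255 and cast to uint8, it presumably rescales the prediction into [0, 1]. A sketch:

def max_min_normalization(x, x_max, x_min):
    # presumed behavior: rescale x into [0, 1] before the * 255 quantization
    return (x - x_min) / (x_max - x_min + 1e-12)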
Example #9
def train(net, optimizer):
    curr_iter = args['last_iter']
    while True:
        total_loss_record = AvgMeter()
        lossL2H0_record, lossL2H1_record, lossL2H2_record, lossL2H3_record, lossL2H4_record = AvgMeter(
        ), AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter()
        lossH2L0_record, lossH2L1_record, lossH2L2_record, lossH2L3_record, lossH2L4_record = AvgMeter(
        ), AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter()

        for i, data in enumerate(train_loader):
            optimizer.param_groups[0]['lr'] = 2 * args['lr'] * (
                1 - float(curr_iter) / args['iter_num'])**args['lr_decay']
            optimizer.param_groups[1]['lr'] = args['lr'] * (
                1 - float(curr_iter) / args['iter_num'])**args['lr_decay']

            inputs, labels = data
            batch_size = inputs.size(0)
            inputs = Variable(inputs).cuda()
            labels = Variable(labels).cuda()

            optimizer.zero_grad()
            outputsL2H0, outputsL2H1, outputsL2H2, outputsL2H3, outputsL2H4, outputsH2L0, outputsH2L1, outputsH2L2, outputsH2L3, outputsH2L4, outputsFusion = net(
                inputs)
            lossL2H0 = criterion(outputsL2H0, labels)
            lossL2H1 = criterion(outputsL2H1, labels)
            lossL2H2 = criterion(outputsL2H2, labels)
            lossL2H3 = criterion(outputsL2H3, labels)
            lossL2H4 = criterion(outputsL2H4, labels)

            lossH2L0 = criterion(outputsH2L0, labels)
            lossH2L1 = criterion(outputsH2L1, labels)
            lossH2L2 = criterion(outputsH2L2, labels)
            lossH2L3 = criterion(outputsH2L3, labels)
            lossH2L4 = criterion(outputsH2L4, labels)

            lossFusion = criterion(outputsFusion, labels)

            total_loss = lossFusion + lossL2H0 + lossL2H1 + lossL2H2 + lossL2H3 + lossL2H4 + lossH2L0 + lossH2L1 + lossH2L2 + lossH2L3 + lossH2L4
            total_loss.backward()
            optimizer.step()

            total_loss_record.update(total_loss.data[0], batch_size)
            lossL2H0_record.update(lossL2H0.data[0], batch_size)
            lossL2H1_record.update(lossL2H1.data[0], batch_size)
            lossL2H2_record.update(lossL2H2.data[0], batch_size)
            lossL2H3_record.update(lossL2H3.data[0], batch_size)
            lossL2H4_record.update(lossL2H4.data[0], batch_size)

            lossH2L0_record.update(lossH2L0.data[0], batch_size)
            lossH2L1_record.update(lossH2L1.data[0], batch_size)
            lossH2L2_record.update(lossH2L2.data[0], batch_size)
            lossH2L3_record.update(lossH2L3.data[0], batch_size)
            lossH2L4_record.update(lossH2L4.data[0], batch_size)

            curr_iter += 1

            log = '[iter %d], [total loss %.5f], [lossL2H0 %.5f], [lossL2H1 %.5f], [lossL2H2 %.5f], [lossL2H3 %.5f], [lossL2H4 %.5f], ' \
                  '[lossH2L0 %.5f], [lossH2L1 %.5f], [lossH2L2 %.5f], [lossH2L3 %.5f], [lossH2L4 %.5f], [lr %.13f]' % \
                   (curr_iter, total_loss_record.avg, lossL2H0_record.avg, lossL2H1_record.avg, lossL2H2_record.avg,
                   lossL2H3_record.avg, lossL2H4_record.avg, lossH2L0_record.avg, lossH2L1_record.avg, lossH2L2_record.avg,
                   lossH2L3_record.avg, lossH2L4_record.avg,
                   optimizer.param_groups[1]['lr'])
            logWrite = '%d %.5f %.5f %.5f %.5f %.5f %.5f ' \
                       '%.5f %.5f %.5f %.5f %.5f %.13f' % \
                          (curr_iter, total_loss_record.avg, lossL2H0_record.avg, lossL2H1_record.avg, lossL2H2_record.avg,
                          lossL2H3_record.avg, lossL2H4_record.avg, lossH2L0_record.avg, lossH2L1_record.avg, lossH2L2_record.avg,
                          lossH2L3_record.avg, lossH2L4_record.avg,
                          optimizer.param_groups[1]['lr'])
            print(log)
            open(log_path, 'a').write(logWrite + '\n')

            if curr_iter == args['iter_num']:
                torch.save(
                    net.state_dict(),
                    os.path.join(ckpt_path, exp_name, 'CA_%d.pth' % curr_iter))
                torch.save(
                    optimizer.state_dict(),
                    os.path.join(ckpt_path, exp_name,
                                 'CA_%d_optim.pth' % curr_iter))
                return
Example #10
def train(epoch):
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss_record, bce_loss_record, dice_loss_record = AvgMeter(), AvgMeter(), AvgMeter()
    for batch_idx, data in enumerate(train_loader):
        if epoch == args['lr_step']:
            optimizer.param_groups[0]['lr'] = 2 * args['lr'] / args['lr_decay']
            optimizer.param_groups[1]['lr'] = args['lr'] / args['lr_decay']
        inputs_volume, labels_volume = data
        inputs_volume = norm_filter(inputs_volume)  # normalization via mean and std
        sub_batch_len = math.ceil(inputs_volume.shape[1]/args['train_batch_size'])
        print_flag=0
        for sub_batch_idx in range(sub_batch_len):
            #split volume to batchs
            start = sub_batch_idx*args['train_batch_size']
            end = (sub_batch_idx+1)*args['train_batch_size'] if (sub_batch_idx+1)*args['train_batch_size']<inputs_volume.shape[1] else inputs_volume.shape[1]
            inputs = inputs_volume[:, start:end, :, :].permute(1, 0, 2, 3)
            inputs = inputs.expand(torch.Size((inputs.shape[0], 3, inputs.shape[2], inputs.shape[3])))
            labels = labels_volume[:, start:end, :, :].permute(1, 0, 2, 3)
            if torch.max(labels).item() == 0:
                print_flag += 1
                print(print_flag)
                continue
            batch_size = inputs.size(0)
            inputs = Variable(inputs).cuda()
            labels = Variable(labels).cuda()
            optimizer.zero_grad()
            outputs = net(inputs)
            # outputs = net(inputs)

            # BCE loss and dice loss can be used
            criterion_bce = nn.BCELoss()
            criterion_dice = Dice_loss()
            # if not isinstance(fnd_out, list):
            loss_bce = criterion_bce(outputs, labels)
            loss_dice = criterion_dice(outputs, labels)
            # loss_bce = criterion_bce(outputs, labels)
            # loss_dice = criterion_dice(outputs, labels)
            # else:
            #     loss_bce = criterion_bce(outputs, labels)
            #     loss_dice = criterion_dice(outputs, labels)
            #     for fnd_mask, fpd_mask in zip(fnd_out, fpd_out):
            #         loss_bce += criterion_bce(fnd_mask, fnd) + criterion_bce(fpd_mask, fpd)


            # else:
            #     loss_bce_each = [None] * len(outputs)
            #     loss_dice_each = [None] * len(outputs)
            #     for idx in range(len(outputs)):
            #         loss_bce_each[idx] = criterion_bce(outputs[idx], labels)
            #         loss_dice_each[idx] = criterion_dice(outputs[idx], labels)
            #     loss_bce = sum(loss_bce_each)
            #     loss_dice = sum(loss_dice_each)
            # coeff = loss_dice.item()/(loss_bce.item()+1e-5) if loss_dice.item()/(loss_bce.item()+1e-5) < 1 else 1
            coeff = 1
            loss = coeff*loss_bce + loss_dice
            # loss = loss_bce + loss_dice
            loss.backward()
            optimizer.step()
            train_loss_record.update(loss.item(), batch_size)
            bce_loss_record.update(loss_bce.item(), batch_size)
            dice_loss_record.update(loss_dice.item(), batch_size)
            log = 'iter: %d | [bce loss: %.5f], [dice loss: %.5f],[Total loss: %.5f], [lr: %.8f]' % \
                  (epoch, bce_loss_record.avg, dice_loss_record.avg, train_loss_record.avg, optimizer.param_groups[1]['lr'])
            progress_bar(batch_idx, len(train_loader), log)
Example #11
def train_online(net, seq_name='breakdance'):
    online_args = {
        'iter_num': 100,
        'train_batch_size': 1,
        'lr': 1e-10,
        'lr_decay': 0.95,
        'weight_decay': 5e-4,
        'momentum': 0.95,
    }

    joint_transform = joint_transforms.Compose([
        joint_transforms.ImageResize(380),
        # joint_transforms.RandomCrop(473),
        # joint_transforms.RandomHorizontallyFlip(),
        # joint_transforms.RandomRotate(10)
    ])
    target_transform = transforms.ToTensor()
    # train_set = VideoFSImageFolder(to_test['davis'], seq_name, use_first=True, joint_transform=joint_transform, transform=img_transform)
    train_set = VideoFirstImageFolder(to_test['davis'],
                                      gt_root,
                                      seq_name,
                                      joint_transform=joint_transform,
                                      transform=img_transform,
                                      target_transform=target_transform)
    online_train_loader = DataLoader(
        train_set,
        batch_size=online_args['train_batch_size'],
        num_workers=1,
        shuffle=False)

    # criterion = nn.MSELoss().cuda()
    criterion = nn.BCEWithLogitsLoss().cuda()
    erosion = Erosion2d(1, 1, 5, soft_max=False).cuda()
    net.train()
    net.cuda()
    # fix_parameters(net.named_parameters())

    optimizer = optim.SGD([{
        'params': [
            param
            for name, param in net.named_parameters() if name[-4:] == 'bias'
        ],
        'lr':
        2 * online_args['lr']
    }, {
        'params': [
            param
            for name, param in net.named_parameters() if name[-4:] != 'bias'
        ],
        'lr':
        online_args['lr'],
        'weight_decay':
        online_args['weight_decay']
    }],
                          momentum=online_args['momentum'])

    for curr_iter in range(0, online_args['iter_num']):
        total_loss_record, loss0_record, loss1_record = AvgMeter(), AvgMeter(
        ), AvgMeter()
        loss2_record = AvgMeter()

        for i, data in enumerate(online_train_loader):
            optimizer.param_groups[0]['lr'] = 2 * online_args['lr'] * (
                1 - float(curr_iter) /
                online_args['iter_num'])**online_args['lr_decay']
            optimizer.param_groups[1]['lr'] = online_args['lr'] * (
                1 - float(curr_iter) /
                online_args['iter_num'])**online_args['lr_decay']
            inputs, labels = data
            batch_size = inputs.size(0)
            inputs = Variable(inputs).cuda()
            labels = Variable(labels).cuda()

            optimizer.zero_grad()
            if args['model'] == 'BASNet':
                total_loss, loss0, loss1, loss2 = train_BASNet(
                    net, inputs, criterion, erosion, labels)
            elif args['model'] == 'R3Net':
                total_loss, loss0, loss1, loss2 = train_R3Net(
                    net, inputs, criterion, erosion, labels)
            elif args['model'] == 'DSSNet':
                total_loss, loss0, loss1, loss2 = train_DSSNet(
                    net, inputs, criterion, erosion, labels)
            elif args['model'] == 'CPD':
                total_loss, loss0, loss1, loss2 = train_CPD(
                    net, inputs, criterion, erosion, labels)
            elif args['model'] == 'RAS':
                total_loss, loss0, loss1, loss2 = train_RAS(
                    net, inputs, criterion, erosion, labels)
            elif args['model'] == 'PoolNet':
                total_loss, loss0, loss1, loss2 = train_PoolNet(
                    net, inputs, criterion, erosion, labels)
            elif args['model'] == 'F3Net':
                total_loss, loss0, loss1, loss2 = train_F3Net(
                    net, inputs, criterion, erosion, labels)
            elif args['model'] == 'R2Net':
                total_loss, loss0, loss1, loss2 = train_R2Net(
                    net, inputs, criterion, erosion, labels)
            total_loss.backward()
            optimizer.step()

            total_loss_record.update(total_loss.data, batch_size)
            loss0_record.update(loss0.data, batch_size)
            loss1_record.update(loss1.data, batch_size)
            loss2_record.update(loss2.data, batch_size)
            # loss3_record.update(loss3.data, batch_size)
            # loss4_record.update(loss4.data, batch_size)

            log = '[iter %d], [total loss %.5f], [loss0 %.8f], [loss1 %.8f], [loss2 %.8f], [lr %.13f]' % \
                  (curr_iter, total_loss_record.avg, loss0_record.avg, loss1_record.avg, loss2_record.avg,
                   optimizer.param_groups[1]['lr'])
            print(log)

    print('taking snapshot ...')
    torch.save(
        net.state_dict(),
        os.path.join(ckpt_path, exp_name,
                     str(args['snapshot']) + '_' + seq_name + '_online.pth'))
    # torch.save(optimizer.state_dict(),
    #            os.path.join(ckpt_path, exp_name, '%d_optim.pth' % curr_iter))

    return net
Example #12
def train(net, optimizer):
    curr_iter = args['last_iter']
    while True:
        total_loss_record, loss1_record, loss2_record, loss3_record, loss4_record, loss5_record, loss6_record, loss7_record, loss8_record = AvgMeter(
        ), AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter(
        ), AvgMeter(), AvgMeter(), AvgMeter()

        for i, data in enumerate(train_loader):
            optimizer.param_groups[0]['lr'] = 2 * args['lr'] * (
                1 - float(curr_iter) / args['iter_num'])**args['lr_decay']
            optimizer.param_groups[1]['lr'] = args['lr'] * (
                1 - float(curr_iter) / args['iter_num'])**args['lr_decay']
            inputs, depth, labels = data
            labels[labels > 0.5] = 1
            labels[labels != 1] = 0
            batch_size = inputs.size(0)
            inputs = Variable(inputs).cuda()
            depth = Variable(depth).cuda()
            labels = Variable(labels).cuda()
            outputs, outputs_fg, outputs_bg, attention1, attention2, attention3, attention4, attention5 = net(
                inputs, depth)  #hed
            ##########loss#############
            optimizer.zero_grad()
            labels1 = functional.interpolate(labels, size=24, mode='bilinear')
            labels2 = functional.interpolate(labels, size=48, mode='bilinear')
            labels3 = functional.interpolate(labels, size=96, mode='bilinear')
            labels4 = functional.interpolate(labels, size=192, mode='bilinear')
            loss1 = criterion_BCE(attention1, labels1)
            loss2 = criterion_BCE(attention2, labels2)
            loss3 = criterion_BCE(attention3, labels3)
            loss4 = criterion_BCE(attention4, labels4)
            loss5 = criterion_BCE(attention5, labels)
            loss6 = criterion(outputs_fg, labels)
            loss7 = criterion(outputs_bg, (1 - labels))
            loss8 = criterion(outputs, labels)
            total_loss = loss1 + loss2 + loss3 + loss4 + loss5 + loss6 + loss7 + loss8
            total_loss.backward()
            optimizer.step()
            total_loss_record.update(total_loss.item(), batch_size)
            loss1_record.update(loss1.item(), batch_size)
            loss2_record.update(loss2.item(), batch_size)
            loss3_record.update(loss3.item(), batch_size)
            loss4_record.update(loss4.item(), batch_size)
            loss5_record.update(loss5.item(), batch_size)
            loss6_record.update(loss6.item(), batch_size)
            loss7_record.update(loss7.item(), batch_size)
            loss8_record.update(loss8.item(), batch_size)
            curr_iter += 1
            #############log###############
            if curr_iter % 2050 == 0:
                torch.save(
                    net.state_dict(),
                    os.path.join(ckpt_path, exp_name, '%d.pth' % curr_iter))
                torch.save(
                    optimizer.state_dict(),
                    os.path.join(ckpt_path, exp_name,
                                 '%d_optim.pth' % curr_iter))
            log = '[iter %d], [total loss %.5f], [loss1 %.5f], [loss2 %.5f], [loss3 %.5f], [loss4 %.5f], [loss5 %.5f], [loss6 %.5f], [loss7 %.5f], [loss8 %.5f], [lr %.13f]' % \
                  (curr_iter, total_loss_record.avg, loss1_record.avg, loss2_record.avg, loss3_record.avg, loss4_record.avg, loss5_record.avg, loss6_record.avg, loss7_record.avg, loss8_record.avg, optimizer.param_groups[1]['lr'])
            print(log)
            open(log_path, 'a').write(log + '\n')
            if curr_iter == args['iter_num']:
                torch.save(
                    net.state_dict(),
                    os.path.join(ckpt_path, exp_name, '%d.pth' % curr_iter))
                torch.save(
                    optimizer.state_dict(),
                    os.path.join(ckpt_path, exp_name,
                                 '%d_optim.pth' % curr_iter))
                return
Example #13
import os
import numpy as np
from PIL import Image
from misc import check_mkdir, crf_refine, AvgMeter, cal_precision_recall_mae, cal_fmeasure

# root = '/home/qub/data/saliency/davis/davis_test2'
root_inference = '/home/ty/code/tf_saliency_attention/total_result/result_rnn_2018-08-04 11:08:00/'
root = '/home/ty/data/davis/480p/'
name = 'davis'
# gt_root = '/home/qub/data/saliency/davis/GT'
gt_root = '/home/ty/data/davis/GT/'
# gt_root = '/home/qub/data/saliency/VOS/GT'

precision_record, recall_record, = [AvgMeter() for _ in range(256)], [AvgMeter() for _ in range(256)]
mae_record = AvgMeter()
results = {}

# save_path = os.path.join(ckpt_path, exp_name, '(%s) %s_%s' % (exp_name, name, args['snapshot']))
folders = os.listdir(root_inference)
folders.sort()
for folder in folders:
    imgs = os.listdir(os.path.join(root_inference, folder))
    imgs.sort()

    for img in imgs:
        print(os.path.join(folder, img))
        image = Image.open(os.path.join(root, folder, img[:-4] + '.jpg')).convert('RGB')
        gt = np.array(Image.open(os.path.join(gt_root, folder, img)).convert('L'))
        pred = np.array(Image.open(os.path.join(root_inference, folder, img)).convert('L'))

        precision, recall, mae = cal_precision_recall_mae(pred, gt)
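
The snippet ends right after the per-image metrics. Given the imports and the meters set up at the top, the aggregation would presumably continue as in Example #7 (a sketch, not recovered source):

        # mirroring Example #7: accumulate the per-image stats
        for pidx, (p, r) in enumerate(zip(precision, recall)):
            precision_record[pidx].update(p)
            recall_record[pidx].update(r)
        mae_record.update(mae)

fmeasure = cal_fmeasure([precord.avg for precord in precision_record],
                        [rrecord.avg for rrecord in recall_record])
results[name] = {'fmeasure': fmeasure, 'mae': mae_record.avg}
print('test results:', results)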
Example #14
def train():
    g = Generator(scale_factor=train_args['scale_factor']).cuda().train()
    g = nn.DataParallel(g, device_ids=[0])
    if len(train_args['g_snapshot']) > 0:
        print('load generator snapshot ' + train_args['g_snapshot'])
        g.load_state_dict(
            torch.load(
                os.path.join(train_args['ckpt_path'],
                             train_args['g_snapshot'])))

    mse_criterion = nn.MSELoss().cuda()
    g_mse_loss_record, psnr_record = AvgMeter(), AvgMeter()

    iter_nums = len(train_loader)

    if g_pretrain_args['pretrain']:
        g_optimizer = optim.Adam(g.parameters(), lr=g_pretrain_args['lr'])
        for epoch in range(g_pretrain_args['epoch_num']):
            for i, data in enumerate(train_loader):
                hr_imgs, _ = data
                batch_size = hr_imgs.size(0)
                lr_imgs = Variable(
                    torch.stack([train_lr_transform(img) for img in hr_imgs],
                                0)).cuda()
                hr_imgs = Variable(hr_imgs).cuda()

                g.zero_grad()
                gen_hr_imgs = g(lr_imgs)
                mse_loss = mse_criterion(gen_hr_imgs, hr_imgs)
                mse_loss.backward()
                g_optimizer.step()

                g_mse_loss_record.update(mse_loss.data[0], batch_size)
                psnr_record.update(10 * math.log10(1 / mse_loss.data[0]),
                                   batch_size)

                print('[pretrain]: [epoch %d], [iter %d / %d], [loss %.5f], [psnr %.5f]' % (
                    epoch + 1, i + 1, iter_nums, g_mse_loss_record.avg,
                    psnr_record.avg))

                writer.add_scalar('pretrain_g_mse_loss', g_mse_loss_record.avg,
                                  epoch * iter_nums + i + 1)
                writer.add_scalar('pretrain_psnr', psnr_record.avg,
                                  epoch * iter_nums + i + 1)

            torch.save(
                g.state_dict(),
                os.path.join(
                    train_args['ckpt_path'],
                    'pretrain_g_epoch_%d_loss_%.5f_psnr_%.5f.pth' %
                    (epoch + 1, g_mse_loss_record.avg, psnr_record.avg)))

            g_mse_loss_record.reset()
            psnr_record.reset()

            validate(g, epoch)

    d = Discriminator().cuda().train()
    d = nn.DataParallel(d, device_ids=[0])
    if len(train_args['d_snapshot']) > 0:
        print('load discriminator snapshot ' + train_args['d_snapshot'])
        d.load_state_dict(
            torch.load(
                os.path.join(train_args['ckpt_path'],
                             train_args['d_snapshot'])))

    g_optimizer = optim.RMSprop(g.parameters(), lr=train_args['g_lr'])
    d_optimizer = optim.RMSprop(d.parameters(), lr=train_args['d_lr'])

    perceptual_criterion, tv_criterion = PerceptualLoss().cuda(
    ), TotalVariationLoss().cuda()

    g_mse_loss_record, g_perceptual_loss_record, g_tv_loss_record = AvgMeter(
    ), AvgMeter(), AvgMeter()
    psnr_record, g_ad_loss_record, g_loss_record, d_loss_record = AvgMeter(
    ), AvgMeter(), AvgMeter(), AvgMeter()

    for epoch in range(train_args['start_epoch'] - 1, train_args['epoch_num']):
        for i, data in enumerate(train_loader):
            hr_imgs, _ = data
            batch_size = hr_imgs.size(0)
            lr_imgs = Variable(
                torch.stack([train_lr_transform(img) for img in hr_imgs],
                            0)).cuda()
            hr_imgs = Variable(hr_imgs).cuda()
            gen_hr_imgs = g(lr_imgs)

            # update d
            d.zero_grad()
            d_ad_loss = d(gen_hr_imgs.detach()).mean() - d(hr_imgs).mean()
            d_ad_loss.backward()
            d_optimizer.step()

            d_loss_record.update(d_ad_loss.data[0], batch_size)

            for p in d.parameters():
                p.data.clamp_(-train_args['c'], train_args['c'])

            # update g
            g.zero_grad()
            g_mse_loss = mse_criterion(gen_hr_imgs, hr_imgs)
            g_perceptual_loss = perceptual_criterion(gen_hr_imgs, hr_imgs)
            g_tv_loss = tv_criterion(gen_hr_imgs)
            g_ad_loss = -d(gen_hr_imgs).mean()
            g_loss = g_mse_loss + 0.006 * g_perceptual_loss + 2e-8 * g_tv_loss + 0.001 * g_ad_loss
            g_loss.backward()
            g_optimizer.step()

            g_mse_loss_record.update(g_mse_loss.data[0], batch_size)
            g_perceptual_loss_record.update(g_perceptual_loss.data[0],
                                            batch_size)
            g_tv_loss_record.update(g_tv_loss.data[0], batch_size)
            psnr_record.update(10 * math.log10(1 / g_mse_loss.data[0]),
                               batch_size)
            g_ad_loss_record.update(g_ad_loss.data[0], batch_size)
            g_loss_record.update(g_loss.data[0], batch_size)

            print('[train]: [epoch %d], [iter %d / %d], [d_ad_loss %.5f], [g_ad_loss %.5f], [psnr %.5f], '
                  '[g_mse_loss %.5f], [g_perceptual_loss %.5f], [g_tv_loss %.5f], [g_loss %.5f]' %
                  (epoch + 1, i + 1, iter_nums, d_loss_record.avg, g_ad_loss_record.avg, psnr_record.avg,
                   g_mse_loss_record.avg, g_perceptual_loss_record.avg, g_tv_loss_record.avg, g_loss_record.avg))

            writer.add_scalar('d_loss', d_loss_record.avg,
                              epoch * iter_nums + i + 1)
            writer.add_scalar('g_mse_loss', g_mse_loss_record.avg,
                              epoch * iter_nums + i + 1)
            writer.add_scalar('g_perceptual_loss',
                              g_perceptual_loss_record.avg,
                              epoch * iter_nums + i + 1)
            writer.add_scalar('g_tv_loss', g_tv_loss_record.avg,
                              epoch * iter_nums + i + 1)
            writer.add_scalar('psnr', psnr_record.avg,
                              epoch * iter_nums + i + 1)
            writer.add_scalar('g_ad_loss', g_ad_loss_record.avg,
                              epoch * iter_nums + i + 1)
            writer.add_scalar('g_loss', g_loss_record.avg,
                              epoch * iter_nums + i + 1)

        d_loss_record.reset()
        g_mse_loss_record.reset()
        g_perceptual_loss_record.reset()
        g_tv_loss_record.reset()
        psnr_record.reset()
        g_ad_loss_record.reset()
        g_loss_record.reset()

        validate(g, epoch, d)
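
The 10 * math.log10(1 / mse) expressions above compute PSNR for images normalized to [0, 1]; as a general helper:

import math

def psnr_from_mse(mse, max_val=1.0):
    # PSNR in dB; reduces to 10 * log10(1 / mse) when max_val == 1,
    # matching the expressions in the trainer above
    return 10.0 * math.log10(max_val ** 2 / mse)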
Example #15
def train(net, optimizer):
    curr_iter = args['last_iter']
    while True:
        train_loss_record = AvgMeter()
        train_net_loss_record = AvgMeter()
        train_depth_loss_record = AvgMeter()

        for i, data in enumerate(train_loader):
            optimizer.param_groups[0]['lr'] = 2 * args['lr'] * (
                1 - float(curr_iter) / args['iter_num'])**args['lr_decay']
            optimizer.param_groups[1]['lr'] = args['lr'] * (
                1 - float(curr_iter) / args['iter_num'])**args['lr_decay']

            inputs, gts, dps = data
            batch_size = inputs.size(0)
            inputs = Variable(inputs).cuda()
            gts = Variable(gts).cuda()
            dps = Variable(dps).cuda()

            optimizer.zero_grad()

            result, depth_pred = net(inputs)

            loss_net = criterion(result, gts)
            loss_depth = criterion_depth(depth_pred, dps)

            loss = loss_net + loss_depth

            loss.backward()

            optimizer.step()

            # for n, p in net.named_parameters():
            #     if n[-5:] == 'alpha':
            #         print(p.grad.data)
            #         print(p.data)

            train_loss_record.update(loss.data, batch_size)
            train_net_loss_record.update(loss_net.data, batch_size)
            train_depth_loss_record.update(loss_depth.data, batch_size)

            curr_iter += 1

            log = '[iter %d], [train loss %.5f], [lr %.13f], [loss_net %.5f], [loss_depth %.5f]' % \
                  (curr_iter, train_loss_record.avg, optimizer.param_groups[1]['lr'],
                   train_net_loss_record.avg, train_depth_loss_record.avg)
            print(log)
            open(log_path, 'a').write(log + '\n')

            if (curr_iter + 1) % args['val_freq'] == 0:
                validate(net, curr_iter, optimizer)

            if (curr_iter + 1) % args['snapshot_epochs'] == 0:
                torch.save(
                    net.state_dict(),
                    os.path.join(ckpt_path, exp_name,
                                 ('%d.pth' % (curr_iter + 1))))
                torch.save(
                    optimizer.state_dict(),
                    os.path.join(ckpt_path, exp_name,
                                 ('%d_optim.pth' % (curr_iter + 1))))

            if curr_iter > args['iter_num']:
                return
Example #16
def train(net, optimizer):
    start_time = time.time()
    curr_iter = args['last_iter']
    num_class = [0, 0, 0, 0, 0]
    while True:
        total_loss_record, loss0_record, loss1_record, loss2_record = AvgMeter(
        ), AvgMeter(), AvgMeter(), AvgMeter()

        batch_time = AvgMeter()
        end = time.time()
        print('-----beginning the first stage, train_mode==0-----')
        for i, data in enumerate(train_loader):
            optimizer.param_groups[0]['lr'] = 2 * args['lr'] * (
                1 - float(curr_iter) / args['iter_num'])**args['lr_decay']
            optimizer.param_groups[1]['lr'] = args['lr'] * (
                1 - float(curr_iter) / args['iter_num'])**args['lr_decay']

            inputs, gt, labels = data
            print(labels)
            # depends on the num of classes
            cweight = torch.tensor([0.5, 0.75, 1, 1.25, 1.5])
            #weight = torch.ones(size=gt.shape)
            weight = gt.clone().detach()
            sizec = labels.numpy()
            #ta = np.zeros(shape=gt.shape)
            '''
            np.zeros(shape=labels.shape)
            sc = gt.clone().detach()
            for i in range(len(sizec)):
                gta = np.array(to_pil(sc[i,:].data.squeeze(0).cpu()))#
                #print(gta.shape)
                labels[i] = cal_sc(gta)
                sizec[i] = labels[i]
            print(labels)
            '''
            batch_size = inputs.size(0)
            inputs = Variable(inputs).cuda()
            gt = Variable(gt).cuda()
            labels = Variable(labels).cuda()

            #print(sizec.shape)

            optimizer.zero_grad()
            p5, p4, p3, p2, p1, predict1, predict2, predict3, predict4, predict5, predict6, predict7, predict8, predict9, predict10, predict11 = net(
                inputs, sizec)  # mode=1

            criterion = nn.BCEWithLogitsLoss().cuda()
            criterion2 = nn.CrossEntropyLoss().cuda()

            gt2 = gt.long()
            gt2 = gt2.squeeze(1)

            l5 = criterion2(p5, gt2)
            l4 = criterion2(p4, gt2)
            l3 = criterion2(p3, gt2)
            l2 = criterion2(p2, gt2)
            l1 = criterion2(p1, gt2)

            loss0 = criterion(predict11, gt)
            loss10 = criterion(predict10, gt)
            loss9 = criterion(predict9, gt)
            loss8 = criterion(predict8, gt)
            loss7 = criterion(predict7, gt)
            loss6 = criterion(predict6, gt)
            loss5 = criterion(predict5, gt)
            loss4 = criterion(predict4, gt)
            loss3 = criterion(predict3, gt)
            loss2 = criterion(predict2, gt)
            loss1 = criterion(predict1, gt)

            total_loss = l1 + l2 + l3 + l4 + l5 + loss0 + loss1 + loss2 + loss3 + loss4 + loss5 + loss6 + loss7 + loss8 + loss9 + loss10

            total_loss.backward()
            optimizer.step()

            total_loss_record.update(total_loss.item(), batch_size)
            loss1_record.update(l5.item(), batch_size)
            loss0_record.update(loss0.item(), batch_size)

            curr_iter += 1.0
            batch_time.update(time.time() - end)
            end = time.time()

            log = '[iter %d], [R1/Mode0], [total loss %.5f]\n' \
                  '[l5 %.5f], [loss0 %.5f]\n' \
                  '[lr %.13f], [time %.4f]' % \
                  (curr_iter, total_loss_record.avg, loss1_record.avg, loss0_record.avg, optimizer.param_groups[1]['lr'],
                   batch_time.avg)
            print(log)
            print('Num of class:', num_class)
            open(log_path, 'a').write(log + '\n')

            if curr_iter == args['iter_num']:
                torch.save(
                    net.state_dict(),
                    os.path.join(ckpt_path, exp_name, '%d.pth' % curr_iter))
                torch.save(
                    optimizer.state_dict(),
                    os.path.join(ckpt_path, exp_name,
                                 '%d_optim.pth' % curr_iter))
                total_time = time.time() - start_time
                print(total_time)
                return
Example #17
def inst_meter_dict(meter_list, meter_style='avg'):
    result = dict()
    for meter in meter_list:
        if meter_style == 'avg':
            result[meter] = AvgMeter()
    return result
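
A quick usage sketch (the meter names are hypothetical):

# one running average per logged quantity
meters = inst_meter_dict(['total', 'bce', 'dice'])
meters['bce'].update(0.41, n=8)  # batch of 8
print(meters['bce'].avg)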
Example #18
def main():
    # net = R3Net(motion='', se_layer=False, dilation=False, basic_model='resnet50')
    net = Distill(basic_model='resnet50', seq=True)

    print ('load snapshot \'%s\' for testing' % args['snapshot'])
    net.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'] + '.pth'), map_location='cuda:2'))
    net.eval()
    net.cuda()
    results = {}

    with torch.no_grad():

        for name, root in to_test.items():

            precision_record, recall_record, = [AvgMeter() for _ in range(256)], [AvgMeter() for _ in range(256)]
            mae_record = AvgMeter()

            if args['save_results']:
                check_mkdir(os.path.join(ckpt_path, exp_name, '(%s) %s_%s' % (exp_name, name, args['snapshot'])))
            img_list = [i_id.strip() for i_id in open(imgs_path)]
            video = ''
            for idx, img_name in enumerate(img_list):
                print ('predicting for %s: %d / %d' % (name, idx + 1, len(img_list)))
                print(img_name)

                # img_var = []
                if args['seq']:
                    if video != img_name.split('/')[0]:
                        video = img_name.split('/')[0]
                        if name == 'VOS' or name == 'DAVSOD':
                            img = Image.open(os.path.join(root, img_name + '.png')).convert('RGB')
                        else:
                            img = Image.open(os.path.join(root, img_name + '.jpg')).convert('RGB')
                        shape = img.size
                        img = img.resize(args['input_size'])
                        img_var = Variable(img_transform(img).unsqueeze(0)).cuda()
                        start = time.time()
                        _, _, prediction = net(img_var, img_var, flag='seq')
                        end = time.time()
                        print('running time:', (end - start))
                    else:
                        if name == 'VOS' or name == 'DAVSOD':
                            img = Image.open(os.path.join(root, img_name + '.png')).convert('RGB')
                            pre = Image.open(os.path.join(root, img_list[idx - 1] + '.png')).convert('RGB')
                        else:
                            img = Image.open(os.path.join(root, img_name + '.jpg')).convert('RGB')
                            pre = Image.open(os.path.join(root, img_list[idx - 1] + '.jpg')).convert('RGB')
                        shape = img.size
                        img = img.resize(args['input_size'])
                        pre = pre.resize(args['input_size'])
                        img_var = Variable(img_transform(img).unsqueeze(0)).cuda()
                        pre_var = Variable(img_transform(pre).unsqueeze(0)).cuda()
                        start = time.time()
                        _, _, prediction = net(pre_var, img_var, flag='seq')
                        end = time.time()
                        print('running time:', (end - start))
                else:
                    start = time.time()
                    if name == 'VOS' or name == 'DAVSOD':
                        img = Image.open(os.path.join(root, img_name + '.png')).convert('RGB')
                    else:
                        img = Image.open(os.path.join(root, img_name + '.jpg')).convert('RGB')
                    shape = img.size
                    img = img.resize(args['input_size'])
                    img_var = Variable(img_transform(img).unsqueeze(0)).cuda()

                    _, _, prediction, _ = net(img_var, img_var)
                    end = time.time()
                    print('running time:', (end - start))

                precision = to_pil(prediction.data.squeeze(0).cpu())
                precision = precision.resize(shape)
                prediction = np.array(precision)
                prediction = prediction.astype('float')
                prediction = MaxMinNormalization(prediction, prediction.max(), prediction.min()) * 255.0
                prediction = prediction.astype('uint8')
                if args['crf_refine']:
                    prediction = crf_refine(np.array(img), prediction)

                gt = np.array(Image.open(os.path.join(gt_root, img_name + '.png')).convert('L'))
                precision, recall, mae = cal_precision_recall_mae(prediction, gt)
                for pidx, pdata in enumerate(zip(precision, recall)):
                    p, r = pdata
                    precision_record[pidx].update(p)
                    recall_record[pidx].update(r)
                mae_record.update(mae)

                if args['save_results']:
                    folder, sub_name = os.path.split(img_name)
                    save_path = os.path.join(ckpt_path, exp_name, '(%s) %s_%s' % (exp_name, name, args['snapshot']), folder)
                    if not os.path.exists(save_path):
                        os.makedirs(save_path)
                    Image.fromarray(prediction).save(os.path.join(save_path, sub_name + '.png'))

            fmeasure = cal_fmeasure([precord.avg for precord in precision_record],
                                    [rrecord.avg for rrecord in recall_record])

            results[name] = {'fmeasure': fmeasure, 'mae': mae_record.avg}

    print ('test results:')
    print (results)
Example #19
def train(net, optimizer):
    global total_epoch
    curr_iter = 1
    start_time = time.time()

    for epoch in range(args['last_epoch'] + 1, args['last_epoch'] + 1 + args['epoch_num']):
        loss_record, loss_4_record, loss_3_record, loss_2_record, loss_1_record = AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter()

        train_iterator = tqdm(train_loader, total=len(train_loader))
        for data in train_iterator:
            if args['poly_train']:
                base_lr = args['lr'] * (1 - float(curr_iter) / float(total_epoch)) ** args['lr_decay']
                optimizer.param_groups[0]['lr'] = 2 * base_lr
                optimizer.param_groups[1]['lr'] = 1 * base_lr

            if args['poly_warmup']:
                # linear warmup to the base LR, then polynomial decay; use local
                # variables for the shifted schedule (the original subtracted the
                # warmup offset from the global counters on every step, which
                # corrupted curr_iter, and omitted the args['lr'] factor below)
                if curr_iter < args['warmup_epoch']:
                    base_lr = args['lr'] / args['warmup_epoch'] * (1 + curr_iter)
                else:
                    t = curr_iter - args['warmup_epoch'] + 1
                    T = total_epoch - args['warmup_epoch'] + 1
                    base_lr = args['lr'] * (1 - float(t) / float(T)) ** args['lr_decay']
                optimizer.param_groups[0]['lr'] = 2 * base_lr
                optimizer.param_groups[1]['lr'] = 1 * base_lr

            if args['cosine_warmup']:
                # linear warmup to the base LR, then cosine annealing
                if curr_iter < args['warmup_epoch']:
                    base_lr = args['lr'] / args['warmup_epoch'] * (1 + curr_iter)
                else:
                    t = curr_iter - args['warmup_epoch'] + 1
                    T = total_epoch - args['warmup_epoch'] + 1
                    base_lr = args['lr'] * (1 + np.cos(np.pi * float(t) / float(T))) / 2
                optimizer.param_groups[0]['lr'] = 2 * base_lr
                optimizer.param_groups[1]['lr'] = 1 * base_lr

            if args["f3_sche"]:
                # triangular (F3Net-style) schedule: ramp up to args['lr'] at the
                # midpoint of training, then back down
                base_lr = args['lr'] * (1 - abs((curr_iter + 1) / (total_epoch + 1) * 2 - 1))
                optimizer.param_groups[0]['lr'] = 2 * base_lr
                optimizer.param_groups[1]['lr'] = 1 * base_lr

            inputs, labels = data
            batch_size = inputs.size(0)
            inputs = Variable(inputs).cuda(device_ids[0])
            labels = Variable(labels).cuda(device_ids[0])

            optimizer.zero_grad()

            predict_4, predict_3, predict_2, predict_1 = net(inputs)

            loss_4 = bce_iou_edge_loss(predict_4, labels)
            loss_3 = bce_iou_edge_loss(predict_3, labels)
            loss_2 = bce_iou_edge_loss(predict_2, labels)
            loss_1 = bce_iou_edge_loss(predict_1, labels)

            loss = args['w2'][0] * loss_4 + args['w2'][1] * loss_3 + args['w2'][2] * loss_2 + args['w2'][3] * loss_1

            loss.backward()

            optimizer.step()

            loss_record.update(loss.data, batch_size)
            loss_4_record.update(loss_4.data, batch_size)
            loss_3_record.update(loss_3.data, batch_size)
            loss_2_record.update(loss_2.data, batch_size)
            loss_1_record.update(loss_1.data, batch_size)

            if curr_iter % 50 == 0:
                writer.add_scalar('loss', loss, curr_iter)
                writer.add_scalar('loss_4', loss_4, curr_iter)
                writer.add_scalar('loss_3', loss_3, curr_iter)
                writer.add_scalar('loss_2', loss_2, curr_iter)
                writer.add_scalar('loss_1', loss_1, curr_iter)

            log = '[%3d], [%6d], [%.6f], [%.5f], [%.5f], [%.5f], [%.5f], [%.5f]' % \
                  (epoch, curr_iter, base_lr, loss_record.avg, loss_4_record.avg, loss_3_record.avg,
                   loss_2_record.avg, loss_1_record.avg)
            train_iterator.set_description(log)
            open(log_path, 'a').write(log + '\n')

            curr_iter += 1

        if epoch in args['save_point']:
            net.cpu()
            torch.save(net.state_dict(), os.path.join(ckpt_path, exp_name, '%d.pth' % epoch))
            net.cuda(device_ids[0])

        if epoch >= args['epoch_num']:
            net.cpu()
            torch.save(net.state_dict(), os.path.join(ckpt_path, exp_name, '%d.pth' % epoch))
            print("Total Training Time: {}".format(str(datetime.timedelta(seconds=int(time.time() - start_time)))))
            print("Optimization Have Done!")
            return
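
The three schedules in the loop above are easier to audit as pure functions of the optimizer step; a minimal sketch, assuming total_steps counts optimizer steps and warmup ramps linearly up to base_lr (the function names are illustrative, not part of the project):

import math

def poly_lr(base_lr, step, total_steps, power=0.9):
    # polynomial decay from base_lr down to 0
    return base_lr * (1 - step / total_steps) ** power

def warmup_then_poly(base_lr, step, total_steps, warmup_steps, power=0.9):
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    return poly_lr(base_lr, step - warmup_steps, total_steps - warmup_steps, power)

def warmup_then_cosine(base_lr, step, total_steps, warmup_steps):
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    t = step - warmup_steps
    return base_lr * (1 + math.cos(math.pi * t / (total_steps - warmup_steps))) / 2
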
Example no. 20
def train(net, optimizer):
    curr_iter = 1

    for epoch in range(args['last_epoch'] + 1,
                       args['last_epoch'] + 1 + args['epoch_num']):
        loss_4_record, loss_3_record, loss_2_record, loss_1_record, \
        loss_f_record, loss_record = AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter()

        train_iterator = tqdm(train_loader, total=len(train_loader))
        for data in train_iterator:
            if args['poly_train']:
                base_lr = args['lr'] * (
                    1 - float(curr_iter) /
                    (args['epoch_num'] * len(train_loader)))**args['lr_decay']
                optimizer.param_groups[0]['lr'] = 2 * base_lr
                optimizer.param_groups[1]['lr'] = 1 * base_lr

            inputs, labels = data
            batch_size = inputs.size(0)
            inputs = Variable(inputs).cuda(device_ids[0])
            labels = Variable(labels).cuda(device_ids[0])

            optimizer.zero_grad()

            predict_4, predict_3, predict_2, predict_1, predict_f = net(inputs)

            loss_4 = L.lovasz_hinge(predict_4, labels)
            loss_3 = L.lovasz_hinge(predict_3, labels)
            loss_2 = L.lovasz_hinge(predict_2, labels)
            loss_1 = L.lovasz_hinge(predict_1, labels)
            loss_f = L.lovasz_hinge(predict_f, labels)

            loss = loss_4 + loss_3 + loss_2 + loss_1 + loss_f

            loss.backward()

            optimizer.step()

            loss_record.update(loss.data, batch_size)
            loss_4_record.update(loss_4.data, batch_size)
            loss_3_record.update(loss_3.data, batch_size)
            loss_2_record.update(loss_2.data, batch_size)
            loss_1_record.update(loss_1.data, batch_size)
            loss_f_record.update(loss_f.data, batch_size)

            if curr_iter % 50 == 0:
                writer.add_scalar('loss', loss, curr_iter)
                writer.add_scalar('loss_4', loss_4, curr_iter)
                writer.add_scalar('loss_3', loss_3, curr_iter)
                writer.add_scalar('loss_2', loss_2, curr_iter)
                writer.add_scalar('loss_1', loss_1, curr_iter)
                writer.add_scalar('loss_f', loss_f, curr_iter)

            log = '[%3d], [%6d], [%.6f], [%.5f], [L4: %.5f], [L3: %.5f], [L2: %.5f], [L1: %.5f], [Lf: %.5f]' % \
                  (epoch, curr_iter, base_lr, loss_record.avg, loss_4_record.avg, loss_3_record.avg, loss_2_record.avg,
                   loss_1_record.avg, loss_f_record.avg)
            train_iterator.set_description(log)
            open(log_path, 'a').write(log + '\n')

            curr_iter += 1

        if epoch in args['save_point']:
            net.cpu()
            torch.save(net.module.state_dict(),
                       os.path.join(ckpt_path, exp_name, '%d.pth' % epoch))
            net.cuda(device_ids[0])

        if epoch >= args['epoch_num']:
            net.cpu()
            torch.save(net.module.state_dict(),
                       os.path.join(ckpt_path, exp_name, '%d.pth' % epoch))
            print("Optimization Have Done!")
            return
Example no. 21
def validate(g, curr_epoch, d=None):
    g.eval()

    mse_criterion = nn.MSELoss()
    g_mse_loss_record, psnr_record = AvgMeter(), AvgMeter()

    for name, loader in val_loader.items():

        val_visual = []
        # note that the batch size is 1
        for i, data in enumerate(loader):
            hr_img, _ = data

            lr_img, hr_restore_img = val_lr_transform(hr_img.squeeze(0))

            lr_img = lr_img.unsqueeze(0).cuda()
            hr_img = hr_img.cuda()

            # volatile=True is gone in modern PyTorch; run inference under no_grad
            with torch.no_grad():
                gen_hr_img = g(lr_img)
                g_mse_loss = mse_criterion(gen_hr_img, hr_img)

            g_mse_loss_record.update(g_mse_loss.item())
            psnr_record.update(10 * math.log10(1 / g_mse_loss.item()))

            val_visual.extend([
                val_display_transform(hr_restore_img),
                val_display_transform(hr_img.cpu().data.squeeze(0)),
                val_display_transform(gen_hr_img.cpu().data.squeeze(0))
            ])

        val_visual = torch.stack(val_visual, 0)
        val_visual = vutils.make_grid(val_visual, nrow=3, padding=5)

        snapshot_name = 'epoch_%d_%s_g_mse_loss_%.5f_psnr_%.5f' % (
            curr_epoch + 1, name, g_mse_loss_record.avg, psnr_record.avg)

        if d is None:
            snapshot_name = 'pretrain_' + snapshot_name
            writer.add_scalar('pretrain_validate_%s_psnr' % name,
                              psnr_record.avg, curr_epoch + 1)
            writer.add_scalar('pretrain_validate_%s_g_mse_loss' % name,
                              g_mse_loss_record.avg, curr_epoch + 1)

            print('[pretrain validate %s]: [epoch %d], [g_mse_loss %.5f], [psnr %.5f]' % (
                name, curr_epoch + 1, g_mse_loss_record.avg, psnr_record.avg))
        else:
            writer.add_scalar('validate_%s_psnr' % name, psnr_record.avg,
                              curr_epoch + 1)
            writer.add_scalar('validate_%s_g_mse_loss' % name,
                              g_mse_loss_record.avg, curr_epoch + 1)

            print('[validate %s]: [epoch %d], [g_mse_loss %.5f], [psnr %.5f]' % (
                name, curr_epoch + 1, g_mse_loss_record.avg, psnr_record.avg))

            torch.save(
                d.state_dict(),
                os.path.join(train_args['ckpt_path'],
                             snapshot_name + '_d.pth'))

        torch.save(
            g.state_dict(),
            os.path.join(train_args['ckpt_path'], snapshot_name + '_g.pth'))

        writer.add_image(snapshot_name, val_visual)

        g_mse_loss_record.reset()
        psnr_record.reset()

    g.train()
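
The PSNR computed above assumes image intensities in [0, 1], so the peak value is 1 and psnr = 10 * log10(1 / mse). A minimal standalone version of that relationship (peak is a parameter for other intensity ranges):

import math
import torch

def psnr(sr, hr, peak=1.0):
    # peak signal-to-noise ratio in dB between two images scaled to [0, peak]
    mse = torch.mean((sr - hr) ** 2).item()
    return 10 * math.log10(peak ** 2 / mse)
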
Example no. 22
def train(net, optimizer):

    curr_iter = args['last_iter']
    train_loss_record = AvgMeter()
    train_net_loss_record = AvgMeter()
    train_uncertainty_loss_record = AvgMeter()

    while True:
        for i, data in enumerate(train_loader):

            optimizer.param_groups[0]['lr'] = 2 * args['lr'] * (
                1 - float(curr_iter) / args['iter_num'])**args['lr_decay']
            optimizer.param_groups[1]['lr'] = args['lr'] * (
                1 - float(curr_iter) / args['iter_num'])**args['lr_decay']

            inputs, gts = data

            batch_size = inputs.size(0)
            inputs = Variable(inputs).cuda()
            gts = Variable(gts).cuda()

            optimizer.zero_grad()

            result = net(inputs)

            loss_net = criterion(result, gts)
            loss = loss_net

            loss.backward()

            optimizer.step()

            #print(uncertainty)

            train_loss_record.update(loss.data, batch_size)
            train_net_loss_record.update(loss_net.data, batch_size)
            #train_uncertainty_loss_record.update(loss_uncertainty_regular.data, batch_size)

            curr_iter += 1

            # log = '[iter %d], [train loss %.5f], [lr %.8f], [loss_net %.5f], [w1 %.8f], [loss_uncertainty %.5f]' % \
            #       (curr_iter, train_loss_record.avg, optimizer.param_groups[1]['lr'],
            #        train_net_loss_record.avg, w1, train_uncertainty_loss_record.avg)

            log = '[iter %d], [train loss %.5f], [lr %.8f], [loss_net %.5f]' % \
                  (curr_iter, train_loss_record.avg, optimizer.param_groups[1]['lr'],
                   train_net_loss_record.avg)
            print(log)
            open(log_path, 'a').write(log + '\n')

            if (curr_iter + 1) % args['val_freq'] == 0:
                validate(net, curr_iter, optimizer)

            if (curr_iter + 1) % args['snapshot_epochs'] == 0:
                torch.save(
                    net.state_dict(),
                    os.path.join(ckpt_path, exp_name,
                                 ('%d.pth' % (curr_iter + 1))))
                torch.save(
                    optimizer.state_dict(),
                    os.path.join(ckpt_path, exp_name,
                                 ('%d_optim.pth' % (curr_iter + 1))))

            if curr_iter > args['iter_num']:
                return
Example no. 23
def train(net, optimizer):
    curr_iter = 1

    for epoch in range(args['last_epoch'] + 1,
                       args['last_epoch'] + 1 + args['epoch_num']):
        loss_record, loss_b_record, loss_c_record, loss_o_record = \
            AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter()

        train_iterator = tqdm(train_loader, total=len(train_loader))
        for data in train_iterator:
            if args['poly_train']:
                base_lr = args['lr'] * (
                    1 - float(curr_iter) /
                    (args['epoch_num'] * len(train_loader)))**args['lr_decay']
                optimizer.param_groups[0]['lr'] = 2 * base_lr
                optimizer.param_groups[1]['lr'] = 1 * base_lr

            inputs, labels, edges = data
            batch_size = inputs.size(0)
            inputs = Variable(inputs).cuda(device_ids[0])
            labels = Variable(labels).cuda(device_ids[0])
            edges = Variable(edges).cuda(device_ids[0])

            optimizer.zero_grad()

            predict_c, predict_b, predict_o = net(inputs)

            loss_b = bce(predict_b, edges)
            loss_c = L.lovasz_hinge(predict_c, labels)
            loss_o = L.lovasz_hinge(predict_o, labels)

            loss = loss_b + loss_c + loss_o

            loss.backward()

            optimizer.step()

            loss_record.update(loss.data, batch_size)
            loss_b_record.update(loss_b.data, batch_size)
            loss_c_record.update(loss_c.data, batch_size)
            loss_o_record.update(loss_o.data, batch_size)

            if curr_iter % 50 == 0:
                writer.add_scalar('loss', loss, curr_iter)
                writer.add_scalar('loss_b', loss_b, curr_iter)
                writer.add_scalar('loss_c', loss_c, curr_iter)
                writer.add_scalar('loss_o', loss_o, curr_iter)

            log = '[Epoch: %2d], [Iter: %5d], [%.7f], [Sum: %.5f], [Lb: %.5f], [Lc: %.5f], [Lo: %.5f]' % \
                  (epoch, curr_iter, base_lr, loss_record.avg, loss_b_record.avg, loss_c_record.avg, loss_o_record.avg)
            train_iterator.set_description(log)
            open(log_path, 'a').write(log + '\n')

            curr_iter += 1

        if epoch in args['save_point']:
            net.cpu()
            torch.save(net.module.state_dict(),
                       os.path.join(ckpt_path, exp_name, '%d.pth' % epoch))
            net.cuda(device_ids[0])

        if epoch >= args['epoch_num']:
            net.cpu()
            torch.save(net.module.state_dict(),
                       os.path.join(ckpt_path, exp_name, '%d.pth' % epoch))
            print("Optimization Have Done!")
            return
Example no. 24
    train_set = VideoImageFolder(video_train_path, imgs_file, joint_transform, img_transform, target_transform)
else:
    train_set = VideoImage2Folder(video_train_path, imgs_file, video_seq_path + '/DAFB2', video_seq_gt_path + '/DAFB2',
                                  joint_transform, None, input_size, img_transform, target_transform)

train_loader = DataLoader(train_set, batch_size=args['train_batch_size'], num_workers=4, shuffle=True)

criterion = nn.BCEWithLogitsLoss().cuda()
if args['L2']:
    criterion_l2 = nn.MSELoss().cuda()
    # criterion_pair = CriterionPairWise(scale=0.5).cuda()
if args['KL']:
    criterion_kl = CriterionKL3().cuda()
log_path = os.path.join(ckpt_path, exp_name, str(datetime.datetime.now()) + '.txt')

total_loss_record, loss0_record, loss1_record, loss2_record = AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter()

def fix_parameters(parameters):
    for name, parameter in parameters:
        if name.find('motion') >= 0 \
                or name.find('GRU') >= 0 or name.find('predict') >= 0:
            print(name, 'is not fixed')

        else:
            print(name, 'is fixed')
            parameter.requires_grad = False


def main():
    net = Distill(basic_model=args['basic_model'], seq=True).cuda().train()
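
fix_parameters above clears requires_grad on everything outside the motion / GRU / prediction heads, so the optimizer should then be built over the surviving trainable subset only. A hedged usage sketch (build_finetune_optimizer is a hypothetical helper, not part of the project):

import torch.optim as optim

def build_finetune_optimizer(net, lr, momentum=0.9):
    # freeze the backbone via the snippet's fix_parameters, then optimize
    # only the parameters that remain trainable
    fix_parameters(net.named_parameters())
    trainable = [p for p in net.parameters() if p.requires_grad]
    return optim.SGD(trainable, lr=lr, momentum=momentum)
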
Example no. 25
def main():
    net = ConcatNet(backbone='resnet50', embedding=128, batch_mode='old')

    print('load snapshot \'%s\' for testing' % args['snapshot'])
    net.load_state_dict(
        torch.load(os.path.join(ckpt_path, exp_name,
                                args['snapshot'] + '.pth'),
                   map_location='cuda:0'))
    net.eval()
    net.cuda()
    results = {}

    with torch.no_grad():

        precision_record = [AvgMeter() for _ in range(256)]
        recall_record = [AvgMeter() for _ in range(256)]
        mae_record = AvgMeter()

        if args['save_results']:
            check_mkdir(
                os.path.join(
                    ckpt_path, exp_name,
                    '(%s) %s_%s' % (exp_name, 'DAVIS', args['snapshot'])))

        index = 0
        old_seq = ''
        # img_list = [os.path.splitext(f)[0] for f in os.listdir(root) if f.endswith('.jpg')]
        for idx, batch in enumerate(testloader):

            print('%d processed' % idx)
            target = batch['target']
            seq_name = batch['seq_name']
            args['seq_name'] = seq_name[0]
            print('sequence name:', seq_name[0])
            if old_seq == args['seq_name']:
                index = index + 1
            else:
                index = 0
            output_sum = 0
            for i in range(0, args['sample_range']):
                search = batch['search_' + str(i)]
                # volatile=True is obsolete; the surrounding torch.no_grad() suffices
                output = net(target.cuda(), search.cuda())
                output = F.interpolate(output,
                                       size=target.size()[2:],
                                       mode='bilinear',
                                       align_corners=True)
                output = torch.sigmoid(output)
                output_sum = output_sum + output[0].data.cpu().numpy()

            output_final = output_sum / args['sample_range']
            output_final = cv2.resize(output_final[0], (854, 480))
            output_final = (output_final * 255).astype(np.uint8)

            gt = np.array(
                Image.open(
                    os.path.join(args['gt_dir'], args['seq_name'],
                                 str(index).zfill(5) + '.png')).convert('L'))
            precision, recall, mae = cal_precision_recall_mae(output_final, gt)
            for pidx, pdata in enumerate(zip(precision, recall)):
                p, r = pdata
                precision_record[pidx].update(p)
                recall_record[pidx].update(r)
            mae_record.update(mae)

            old_seq = args['seq_name']  # must update regardless of save_results

            if args['save_results']:
                save_path = os.path.join(
                    ckpt_path, exp_name,
                    '(%s) %s_%s' % (exp_name, 'DAVIS', args['snapshot']),
                    args['seq_name'])
                if not os.path.exists(save_path):
                    os.makedirs(save_path)
                Image.fromarray(output_final).save(
                    os.path.join(save_path,
                                 str(index).zfill(5) + '.png'))

        fmeasure = cal_fmeasure([precord.avg for precord in precision_record],
                                [rrecord.avg for rrecord in recall_record])

        results['DAVIS'] = {'fmeasure': fmeasure, 'mae': mae_record.avg}

    print('test results:')
    print(results)
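
The test loop above sums sigmoid outputs over sample_range reference frames and divides by the count, i.e. it takes a plain mean of per-frame probability maps before saving. A compact equivalent, assuming outputs is a list of same-shaped float arrays in [0, 1]:

import numpy as np

def average_predictions(outputs):
    # mean over per-frame probability maps, scaled to a uint8 saliency map
    mean_map = np.mean(np.stack(outputs, axis=0), axis=0)
    return (mean_map * 255).astype(np.uint8)
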
Example no. 26
def train(net, optimizer):
    curr_iter = args['last_iter']
    while True:
        train_loss_record, loss_fuse_record, loss1_h2l_record = \
            AvgMeter(), AvgMeter(), AvgMeter()
        loss2_h2l_record, loss3_h2l_record, loss4_h2l_record = \
            AvgMeter(), AvgMeter(), AvgMeter()
        loss1_l2h_record, loss2_l2h_record, loss3_l2h_record = \
            AvgMeter(), AvgMeter(), AvgMeter()
        loss4_l2h_record = AvgMeter()

        for i, data in enumerate(train_loader):
            optimizer.param_groups[0]['lr'] = 2 * args['lr'] * (
                1 - float(curr_iter) / args['iter_num'])**args['lr_decay']
            optimizer.param_groups[1]['lr'] = args['lr'] * (
                1 - float(curr_iter) / args['iter_num'])**args['lr_decay']

            inputs, labels = data
            batch_size = inputs.size(0)
            inputs = Variable(inputs).cuda()
            labels = Variable(labels).cuda()

            optimizer.zero_grad()

            fuse_predict, predict1_h2l, predict2_h2l, predict3_h2l, predict4_h2l, \
            predict1_l2h, predict2_l2h, predict3_l2h, predict4_l2h = net(inputs)

            loss_fuse = bce_logit(fuse_predict, labels)
            loss1_h2l = bce_logit(predict1_h2l, labels)
            loss2_h2l = bce_logit(predict2_h2l, labels)
            loss3_h2l = bce_logit(predict3_h2l, labels)
            loss4_h2l = bce_logit(predict4_h2l, labels)
            loss1_l2h = bce_logit(predict1_l2h, labels)
            loss2_l2h = bce_logit(predict2_l2h, labels)
            loss3_l2h = bce_logit(predict3_l2h, labels)
            loss4_l2h = bce_logit(predict4_l2h, labels)

            loss = loss_fuse + loss1_h2l + loss2_h2l + loss3_h2l + loss4_h2l + loss1_l2h + \
                   loss2_l2h + loss3_l2h + loss4_l2h
            loss.backward()

            optimizer.step()

            train_loss_record.update(loss.data, batch_size)
            loss_fuse_record.update(loss_fuse.data, batch_size)
            loss1_h2l_record.update(loss1_h2l.data, batch_size)
            loss2_h2l_record.update(loss2_h2l.data, batch_size)
            loss3_h2l_record.update(loss3_h2l.data, batch_size)
            loss4_h2l_record.update(loss4_h2l.data, batch_size)
            loss1_l2h_record.update(loss1_l2h.data, batch_size)
            loss2_l2h_record.update(loss2_l2h.data, batch_size)
            loss3_l2h_record.update(loss3_l2h.data, batch_size)
            loss4_l2h_record.update(loss4_l2h.data, batch_size)

            curr_iter += 1

            log = '[iter %d], [train loss %.5f], [loss_fuse %.5f], [loss1_h2l %.5f], [loss2_h2l %.5f], ' \
                  '[loss3_h2l %.5f], [loss4_h2l %.5f], [loss1_l2h %.5f], [loss2_l2h %.5f], [loss3_l2h %.5f], ' \
                  '[loss4_l2h %.5f], [lr %.13f]' % \
                  (curr_iter, train_loss_record.avg, loss_fuse_record.avg, loss1_h2l_record.avg, loss2_h2l_record.avg,
                   loss3_h2l_record.avg, loss4_h2l_record.avg, loss1_l2h_record.avg, loss2_l2h_record.avg,
                   loss3_l2h_record.avg, loss4_l2h_record.avg, optimizer.param_groups[1]['lr'])
            print(log)
            open(log_path, 'a').write(log + '\n')

            if curr_iter > args['iter_num']:
                torch.save(
                    net.state_dict(),
                    os.path.join(ckpt_path, exp_name, '%d.pth' % curr_iter))
                return
Example no. 27
def train(exp_name):
    log_path = os.path.join(ckpt_path, exp_name,
                            str(datetime.datetime.now()) + '.txt')
    net = DPNet().cuda().train()

    optimizer = optim.SGD([
        {'params': [param for name, param in net.named_parameters()
                    if name[-4:] == 'bias'],
         'lr': 2 * args['lr']},
        {'params': [param for name, param in net.named_parameters()
                    if name[-4:] != 'bias'],
         'lr': args['lr'],
         'weight_decay': args['weight_decay']},
    ], momentum=args['momentum'])

    if len(args['snapshot']) > 0:
        print('training resumes from ' + args['snapshot'])
        net.load_state_dict(
            torch.load(
                os.path.join(ckpt_path, exp_name, args['snapshot'] + '.pth')))
        optimizer.param_groups[0]['lr'] = 2 * args['lr']
        optimizer.param_groups[1]['lr'] = args['lr']

    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    open(log_path, 'w').write(str(args) + '\n\n')
    print('start to train')

    curr_iter = args['last_iter']
    while True:
        total_loss_record = AvgMeter()
        loss1_record, loss2_record, loss3_record, loss4_record = \
            AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter()
        loss_DPM1_record = AvgMeter()

        for i, data in enumerate(train_loader):
            optimizer.param_groups[0]['lr'] = 2 * args['lr'] * (
                1 - float(curr_iter) / args['iter_num'])**args['lr_decay']
            optimizer.param_groups[1]['lr'] = args['lr'] * (
                1 - float(curr_iter) / args['iter_num'])**args['lr_decay']

            inputs, labels = data
            batch_size = inputs.size(0)
            inputs = Variable(inputs).cuda()
            labels = Variable(labels).cuda()

            optimizer.zero_grad()

            predict1, predict2, predict3, predict4, predict_DPM1 = net(inputs)

            loss1 = criterionBCE(predict1, labels)
            loss2 = criterionBCE(predict2, labels)
            loss3 = criterionBCE(predict3, labels)
            loss4 = criterionBCE(predict4, labels)
            loss_DPM1 = criterionBCE(predict_DPM1, labels)

            total_loss = loss1 + loss2 + loss3 + loss4 + loss_DPM1
            total_loss.backward()

            optimizer.step()

            total_loss_record.update(total_loss.item(), batch_size)
            loss1_record.update(loss1.item(), batch_size)
            loss2_record.update(loss2.item(), batch_size)
            loss3_record.update(loss3.item(), batch_size)
            loss4_record.update(loss4.item(), batch_size)
            loss_DPM1_record.update(loss_DPM1.item(), batch_size)

            curr_iter += 1

            log = '[iter %d], [total loss %.5f], [loss1 %.5f], [loss2 %.5f], [loss3 %.5f], ' \
                  '[loss4 %.5f], [loss_DPM1 %.5f], [lr %.13f]' \
                  % (curr_iter, total_loss_record.avg, loss1_record.avg, loss2_record.avg, loss3_record.avg,
                     loss4_record.avg, loss_DPM1_record.avg, optimizer.param_groups[1]['lr'])

            print(log)
            open(log_path, 'a').write(log + '\n')

            if curr_iter == args['iter_num']:
                torch.save(
                    net.state_dict(),
                    os.path.join(ckpt_path, exp_name, '%d.pth' % curr_iter))
                return
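
The optimizer above follows a convention repeated across these examples: biases get twice the base learning rate and no weight decay. A small helper capturing that split (a sketch; it keeps the snippet's name-suffix test for identifying biases):

import torch.optim as optim

def biased_sgd(net, lr, weight_decay, momentum):
    # biases: 2x learning rate, no weight decay; weights: base lr + decay
    biases, weights = [], []
    for name, param in net.named_parameters():
        (biases if name.endswith('bias') else weights).append(param)
    return optim.SGD(
        [{'params': biases, 'lr': 2 * lr},
         {'params': weights, 'lr': lr, 'weight_decay': weight_decay}],
        momentum=momentum)
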
Example no. 28
def main():
    if args['model'] == 'BASNet':
        net = BASNet(3, 1)
    elif args['model'] == 'R3Net':
        net = R3Net()
    elif args['model'] == 'DSSNet':
        net = build_model()
    elif args['model'] == 'CPD':
        net = CPD_ResNet()
    elif args['model'] == 'RAS':
        net = RAS()
    elif args['model'] == 'PiCANet':
        net = Unet()
    elif args['model'] == 'PoolNet':
        net = build_model_poolnet(base_model_cfg='resnet')
    elif args['model'] == 'R2Net':
        net = build_model_r2net(base_model_cfg='resnet')
    elif args['model'] == 'F3Net':
        net = F3Net(cfg=None)
    else:
        raise ValueError('unknown model: %s' % args['model'])

    print('load snapshot \'%s\' for testing' % args['snapshot'])
    net.load_state_dict(
        torch.load(os.path.join(ckpt_path, exp_name,
                                args['snapshot'] + '.pth'),
                   map_location='cuda:0'))
    # net = train_online(net)
    results = {}

    for name, root in to_test.items():

        precision_record = [AvgMeter() for _ in range(256)]
        recall_record = [AvgMeter() for _ in range(256)]
        mae_record = AvgMeter()

        if args['save_results']:
            check_mkdir(
                os.path.join(
                    ckpt_path, exp_name, '(%s) %s_%s' %
                    (exp_name, name, args['snapshot'] + '_online')))

        folders = os.listdir(root)
        folders.sort()
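        # NOTE: this next line discards the listing above and restricts the run
        # to a single sequence; it looks like a debugging leftover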
        folders = ['bmx-trees']
        for folder in folders:
            if args['online_train']:
                net = train_online(net, seq_name=folder)
                net.load_state_dict(
                    torch.load(os.path.join(
                        ckpt_path, exp_name,
                        str(args['snapshot']) + '_' + folder + '_online.pth'),
                               map_location='cuda:0'))
            with torch.no_grad():

                net.eval()
                net.cuda()
                imgs = os.listdir(os.path.join(root, folder))
                imgs.sort()
                for i in range(args['start'], len(imgs)):
                    print(imgs[i])
                    start = time.time()

                    img = Image.open(os.path.join(root, folder,
                                                  imgs[i])).convert('RGB')
                    shape = img.size
                    img = img.resize(args['input_size'])
                    # volatile=True is obsolete; the torch.no_grad() context suffices
                    img_var = img_transform(img).unsqueeze(0).cuda()

                    if args['model'] == 'BASNet':
                        prediction, _, _, _, _, _, _, _ = net(img_var)
                        prediction = torch.sigmoid(prediction)
                    elif args['model'] == 'R3Net':
                        prediction = net(img_var)
                    elif args['model'] == 'DSSNet':
                        select = [1, 2, 3, 6]
                        prediction = net(img_var)
                        prediction = torch.mean(torch.cat(
                            [torch.sigmoid(prediction[i]) for i in select],
                            dim=1),
                                                dim=1,
                                                keepdim=True)
                    elif args['model'] == 'CPD':
                        _, prediction = net(img_var)
                        prediction = torch.sigmoid(prediction)
                    elif args['model'] == 'RAS':
                        prediction, _, _, _, _ = net(img_var)
                        prediction = torch.sigmoid(prediction)
                    elif args['model'] == 'PoolNet':
                        prediction = net(img_var)
                        prediction = torch.sigmoid(prediction)
                    elif args['model'] == 'F3Net':
                        _, prediction, _, _, _, _ = net(img_var)
                        prediction = torch.sigmoid(prediction)
                    elif args['model'] == 'R2Net':
                        _, _, _, _, _, prediction = net(img_var)
                        prediction = torch.sigmoid(prediction)
                    end = time.time()
                    print('running time:', (end - start))

                    pred_pil = to_pil(prediction.data.squeeze(0).cpu())
                    prediction = np.array(pred_pil).astype('float')
                    prediction = MaxMinNormalization(
                        prediction, prediction.max(), prediction.min()) * 255.0
                    prediction = prediction.astype('uint8')

                    # refine on the uint8 map while image and prediction still share
                    # the network input resolution (the original called crf_refine on
                    # the raw tensor), then resize back to the source size
                    if args['crf_refine']:
                        prediction = crf_refine(np.array(img), prediction)
                    prediction = np.array(Image.fromarray(prediction).resize(shape))

                    gt = np.array(
                        Image.open(
                            os.path.join(gt_root, folder,
                                         imgs[i][:-4] + '.png')).convert('L'))
                    precision, recall, mae = cal_precision_recall_mae(
                        prediction, gt)
                    for pidx, pdata in enumerate(zip(precision, recall)):
                        p, r = pdata
                        precision_record[pidx].update(p)
                        recall_record[pidx].update(r)
                    mae_record.update(mae)

                    if args['save_results']:
                        # folder, sub_name = os.path.split(imgs[i])
                        save_path = os.path.join(
                            ckpt_path, exp_name, '(%s) %s_%s' %
                            (exp_name, name, args['snapshot'] + '_online'),
                            folder)
                        if not os.path.exists(save_path):
                            os.makedirs(save_path)
                        Image.fromarray(prediction).save(
                            os.path.join(save_path, imgs[i]))

        fmeasure = cal_fmeasure([precord.avg for precord in precision_record],
                                [rrecord.avg for rrecord in recall_record])

        results[name] = {'fmeasure': fmeasure, 'mae': mae_record.avg}

    print('test results:')
    print(results)
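
The two if/elif chains above (one to construct the network, one to unpack each model's output tuple) can be collapsed into a registry keyed by model name. A hedged sketch of the idea, reusing a subset of the snippet's model classes with hypothetical output-selection lambdas (the tuple positions follow the branches above); MODEL_REGISTRY and build_and_predict are illustrative names:

import torch

# hypothetical registry: constructor plus a function that pulls the saliency
# logit out of each model's differently shaped output
MODEL_REGISTRY = {
    'BASNet': (lambda: BASNet(3, 1), lambda out: torch.sigmoid(out[0])),
    'R3Net': (lambda: R3Net(), lambda out: out),
    'CPD': (lambda: CPD_ResNet(), lambda out: torch.sigmoid(out[1])),
    'RAS': (lambda: RAS(), lambda out: torch.sigmoid(out[0])),
}

def build_and_predict(model_name, img_var):
    ctor, select = MODEL_REGISTRY[model_name]
    net = ctor().cuda().eval()
    return select(net(img_var))
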
Example no. 29
def main():
    net = R3Net_prior(motion='GRU',
                      se_layer=False,
                      attention=False,
                      pre_attention=True,
                      basic_model='resnet50',
                      sta=False,
                      naive_fuse=False)

    print('load snapshot \'%s\' for testing' % args['snapshot'])
    net.load_state_dict(
        torch.load(os.path.join(ckpt_path, exp_name,
                                args['snapshot'] + '.pth'),
                   map_location='cuda:1'))
    net.eval()
    net.cuda()
    results = {}

    with torch.no_grad():

        for name, root in to_test.items():

            precision_record = [AvgMeter() for _ in range(256)]
            recall_record = [AvgMeter() for _ in range(256)]
            mae_record = AvgMeter()

            if args['save_results']:
                check_mkdir(
                    os.path.join(
                        ckpt_path, exp_name,
                        '(%s) %s_%s' % (exp_name, name, args['snapshot'])))
            img_list = [i_id.strip() for i_id in open(imgs_path)]
            # img_list = [os.path.splitext(f)[0] for f in os.listdir(root) if f.endswith('.jpg')]
            for idx, img_names in enumerate(img_list):
                print('predicting for %s: %d / %d' %
                      (name, idx + 1, len(img_list)))
                img_seq = img_names.split(',')

                # one prediction per scale; the accumulator must live outside the
                # ratio loop (the original reset it each iteration, so the final
                # mean covered only the last scale)
                prediction_scale = []
                for ratio in args['scale_ratio']:
                    img_var = []
                    for img_name in img_seq:
                        if name == 'VOS' or name == 'DAVSOD':
                            img = Image.open(
                                os.path.join(root,
                                             img_name + '.png')).convert('RGB')
                        else:
                            img = Image.open(
                                os.path.join(root,
                                             img_name + '.jpg')).convert('RGB')
                        shape = img.size

                        new_dims = (int(img.size[0] * ratio),
                                    int(img.size[1] * ratio))
                        img = img.resize(new_dims, Image.BILINEAR)
                        img = img.resize(args['input_size'])
                        # volatile=True is obsolete inside torch.no_grad()
                        img_var.append(img_transform(img).unsqueeze(0).cuda())

                    img_var = torch.cat(img_var, dim=0)
                    start = time.time()
                    prediction = net(img_var)
                    end = time.time()
                    print('running time:', (end - start))
                    prediction_scale.append(prediction)
                prediction = torch.mean(torch.stack(prediction_scale, dim=0),
                                        0)
                prediction = to_pil(prediction.data.squeeze(0).cpu())
                prediction = prediction.resize(shape)
                prediction = np.array(prediction)
                prediction = prediction.astype('float')
                prediction = MaxMinNormalization(prediction, prediction.max(),
                                                 prediction.min()) * 255.0
                prediction = prediction.astype('uint8')

                if args['crf_refine']:
                    # img was resized for the network; match the prediction's size
                    prediction = crf_refine(np.array(img.resize(shape)), prediction)

                gt = np.array(
                    Image.open(os.path.join(gt_root, img_seq[-1] +
                                            '.png')).convert('L'))
                precision, recall, mae = cal_precision_recall_mae(
                    prediction, gt)
                for pidx, pdata in enumerate(zip(precision, recall)):
                    p, r = pdata
                    precision_record[pidx].update(p)
                    recall_record[pidx].update(r)
                mae_record.update(mae)

                if args['save_results']:
                    folder, sub_name = os.path.split(img_name)
                    save_path = os.path.join(
                        ckpt_path, exp_name,
                        '(%s) %s_%s' % (exp_name, name, args['snapshot']),
                        folder)
                    if not os.path.exists(save_path):
                        os.makedirs(save_path)
                    Image.fromarray(prediction).save(
                        os.path.join(save_path, sub_name + '.png'))

            fmeasure = cal_fmeasure(
                [precord.avg for precord in precision_record],
                [rrecord.avg for rrecord in recall_record])

            results[name] = {'fmeasure': fmeasure, 'mae': mae_record.avg}

    print('test results:')
    print(results)
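
With the accumulator initialized outside the ratio loop (fixed above), multi-scale testing reduces to: run the network once per scale, collect the maps, and average them. A minimal sketch of that pattern, assuming forward is any callable returning same-sized prediction tensors:

import torch

def multiscale_predict(forward, inputs_per_scale):
    # inputs_per_scale: one prepared input batch per tested scale
    maps = [forward(x) for x in inputs_per_scale]
    return torch.mean(torch.stack(maps, dim=0), dim=0)
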
Example no. 30
def main():
    net = AADFNet().cuda()
    net = nn.DataParallel(net, device_ids=[0])

    print(exp_name + ' crf: ' + str(args['crf_refine']))
    print('load snapshot \'%s\' for testing' % args['snapshot'])
    net.load_state_dict(
        torch.load(os.path.join(ckpt_path, exp_name,
                                args['snapshot'] + '.pth')))
    net.eval()

    with torch.no_grad():
        results = {}

        for name, root in to_test.items():

            precision_record = [AvgMeter() for _ in range(256)]
            recall_record = [AvgMeter() for _ in range(256)]
            mae_record = AvgMeter()
            time_record = AvgMeter()

            img_list = [
                os.path.splitext(f)[0] for f in os.listdir(root)
                if f.endswith('.jpg')
            ]

            for idx, img_name in enumerate(img_list):
                print('predicting for %s: %d / %d' % (name, idx + 1,
                                                      len(img_list)))
                check_mkdir(
                    os.path.join(
                        ckpt_path, exp_name,
                        '(%s) %s_%s' % (exp_name, name, args['snapshot'])))

                start = time.time()
                img = Image.open(os.path.join(root, img_name +
                                              '.jpg')).convert('RGB')
                # volatile=True is obsolete; the torch.no_grad() context suffices
                img_var = img_transform(img).unsqueeze(0).cuda()
                prediction = net(img_var)
                W, H = img.size
                # F.upsample_bilinear is deprecated; align_corners=True matches it
                prediction = F.interpolate(prediction, size=(H, W),
                                           mode='bilinear', align_corners=True)
                prediction = np.array(to_pil(prediction.data.squeeze(0).cpu()))

                if args['crf_refine']:
                    prediction = crf_refine(np.array(img), prediction)

                end = time.time()

                gt = np.array(
                    Image.open(os.path.join(root,
                                            img_name + '.png')).convert('L'))
                precision, recall, mae = cal_precision_recall_mae(
                    prediction, gt)
                for pidx, pdata in enumerate(zip(precision, recall)):
                    p, r = pdata
                    precision_record[pidx].update(p)
                    recall_record[pidx].update(r)

                mae_record.update(mae)
                time_record.update(end - start)

                if args['save_results']:
                    Image.fromarray(prediction).save(
                        os.path.join(
                            ckpt_path, exp_name,
                            '(%s) %s_%s' % (exp_name, name, args['snapshot']),
                            img_name + '.png'))

            max_fmeasure, mean_fmeasure = cal_fmeasure_both(
                [precord.avg for precord in precision_record],
                [rrecord.avg for rrecord in recall_record])
            results[name] = {
                'max_fmeasure': max_fmeasure,
                'mae': mae_record.avg,
                'mean_fmeasure': mean_fmeasure
            }

        print('test results:')
        print(results)

        with open('Result', 'a') as f:
            if args['crf_refine']:
                f.write('with CRF\n')

            f.write('Running time %.6f \n' % time_record.avg)
            f.write('\n%s\n  iter:%s\n' % (exp_name, args['snapshot']))
            for name, value in results.items():
                f.write(
                    '%s: mean_fmeasure: %.10f, mae: %.10f, max_fmeasure: %.10f\n'
                    % (name, value['mean_fmeasure'], value['mae'],
                       value['max_fmeasure']))
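
cal_fmeasure_both above reports both the maximum and the mean F-measure over the 256-point precision/recall curve. The usual saliency convention sets beta^2 = 0.3 to weight precision above recall; a sketch of that computation, assuming equal-length precision and recall lists:

def fmeasure_curve(precisions, recalls, beta_sq=0.3):
    # F_beta = (1 + beta^2) * P * R / (beta^2 * P + R) at each threshold
    scores = [(1 + beta_sq) * p * r / (beta_sq * p + r + 1e-8)
              for p, r in zip(precisions, recalls)]
    return max(scores), sum(scores) / len(scores)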