Code example #1
File: main_train_BSDR.py  Project: Jimbo51000/bsdr
def test_function(X, Y, network):
    """
    Evaluate the network on the test/validation set.

    Parameters
    ----------
    X : numpy array of input images, shape (B, 3, h, w)
    Y : numpy array of ground-truth density maps, shape (B, 1, h/8, w/8)
    network : BSDR model
    """
    X = torch.from_numpy(X).cuda()  # Variable is a no-op wrapper since PyTorch 0.4
    Y = torch.from_numpy(Y).cuda()

    network = network.cuda()
    network.eval()
    output = network(X)  # predicted density map; must match Y's shape for the MSE below

    loss_criterion = nn.MSELoss(reduction='mean')  # size_average=True is the deprecated spelling
    loss = loss_criterion(output, Y)

    count_error = torch.abs(
        torch.sum(Y.view(Y.size(0), -1), dim=1) -
        torch.sum(output.view(output.size(0), -1), dim=1))
    network.train()
    network = set_batch_norm_to_eval(network)
    return loss.item(), output.cpu().detach().numpy(), count_error.cpu().detach().numpy()
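For orientation, a minimal usage sketch (not part of the project): it assumes a hypothetical iterable of numpy batches shaped as in the docstring and an already-constructed BSDR network, and aggregates the returned loss and per-image count error into a mean loss and MAE.

import numpy as np

def evaluate(network, batches):
    # batches: iterable of (X, Y) numpy pairs, X (B,3,h,w) float32, Y (B,1,h/8,w/8) float32
    losses, count_errors = [], []
    for X, Y in batches:
        loss, _pred, err = test_function(X, Y, network)
        losses.append(loss)
        count_errors.extend(err.tolist())  # one absolute count error per image
    return float(np.mean(losses)), float(np.mean(count_errors))  # (mean MSE, MAE)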
Code example #2
File: stage2_main.py  Project: Xinxinatg/css-ccnn
def test_function(X, Y, network):
    X = torch.from_numpy(X).cuda()
    Y = torch.from_numpy(Y).cuda()

    network = network.cuda()
    network.eval()
    output = network(X)  # (B,1,h,w)

    loss_criterion = get_loss_criterion()
    avg_pool = nn.AvgPool2d(kernel_size=args.kernel_size,
                            stride=args.kernel_size)

    # block-wise sums of the predicted density map (block average * block area)
    output_reshape = avg_pool(output) * (args.kernel_size * args.kernel_size)

    # sampled_GT and args are module-level globals in stage2_main.py
    loss = loss_criterion(output_reshape.view(-1, 1),
                          torch.cuda.FloatTensor(sampled_GT).view(-1, 1)) * 0.01
    count_error = torch.abs(
        torch.sum(Y.view(Y.size(0), -1), dim=1) -
        torch.sum(output.view(output.size(0), -1), dim=1))

    network.train()
    network = set_batch_norm_to_eval(network)
    return loss.item(), output.cpu().detach().numpy(), count_error.cpu().detach().numpy()
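The avg_pool(output) * (kernel_size * kernel_size) step turns block averages back into block sums, so every cell of output_reshape holds the predicted count of one kernel_size x kernel_size patch and the total count is preserved. A self-contained sketch of that identity (the 4x4 toy input and kernel size are made up for illustration):

import torch
import torch.nn as nn

k = 2
x = torch.arange(16, dtype=torch.float32).view(1, 1, 4, 4)  # toy (B,1,h,w) density map

block_sum = nn.AvgPool2d(kernel_size=k, stride=k)(x) * (k * k)

# Each output cell equals the sum over its 2x2 block, so the total count is unchanged.
assert torch.allclose(block_sum.sum(), x.sum())
print(block_sum)  # tensor([[[[10., 18.], [42., 50.]]]])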
Code example #3
File: test_model.py  Project: val-iisc/css-ccnn
def test_function(X, Y, network):
    X = torch.from_numpy(X).cuda()
    Y = torch.from_numpy(Y).cuda()

    network = network.cuda()
    network.eval()
    output = network(X) # (B,1,h,w)
 
    count_error = torch.abs(torch.sum(Y.view(Y.size(0), -1), dim=1) - torch.sum(output.view(output.size(0), -1), dim=1))

    network.train()
    network = set_batch_norm_to_eval(network)
    return output.cpu().detach().numpy(), count_error.cpu().detach().numpy()
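Every variant above calls set_batch_norm_to_eval right after network.train(); the helper is defined elsewhere in these repositories. As an assumption about its behaviour (not the projects' actual code), a minimal version would simply re-freeze the BatchNorm layers so their running statistics are not updated during training:

import torch.nn as nn

def set_batch_norm_to_eval(network):
    # Assumed behaviour: keep BatchNorm layers in eval mode (frozen running
    # statistics) while the rest of the network stays in train() mode.
    for module in network.modules():
        if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            module.eval()
    return network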
Code example #4
def test_function(Xs, Ys, network, set_name=None):
    assert set_name is not None

    X = torch.from_numpy(Xs).cuda()
    Y = torch.from_numpy(Ys).float().cuda()

    network = network.cuda()
    network.eval()
    output = network(X)  # (B,1,h,w)

    loss_criterion = get_loss_criterion()

    avg_pool = nn.AvgPool2d(kernel_size=args.kernel_size,
                            stride=args.kernel_size)

    # block-wise sums of the predicted density map (block average * block area)
    output_reshape_ = avg_pool(output) * (args.kernel_size * args.kernel_size)
    
    if set_name == 'test_valid':
        pseudo_density_maps = create_pseudo_density(Xs)
        pseudo_density_maps = torch.from_numpy(pseudo_density_maps).cuda()

        pseudo_reshape_ = avg_pool(pseudo_density_maps) * (args.kernel_size * args.kernel_size)
        output_reshape = output_reshape_.view(-1, 1)
        pseudo_reshape = pseudo_reshape_.view(-1, 1)

        # split thresholds: smallest value among the top percentile_thresh fraction along dim 0
        pseudo_median = pseudo_reshape.topk(int(args.percentile_thresh * len(pseudo_reshape)), dim=0)[0][-1:][0]
        Y_median = Y.topk(int(args.percentile_thresh * len(Y)), dim=0)[0][-1:][0]
        a_output_indices = pseudo_reshape < pseudo_median
        a_Y_indices = Y < Y_median

        if a_output_indices.sum() > 2:
            # separate losses on the sparse (below-threshold) and dense (above-threshold) blocks
            loss_sparse = loss_criterion(output_reshape[a_output_indices].view(-1, 1), Y[a_Y_indices].view(-1, 1))
            loss_dense = loss_criterion(output_reshape[~a_output_indices].view(-1, 1), Y[~a_Y_indices].view(-1, 1))
            loss = (loss_sparse + loss_dense) * 0.01
        else:
            loss = loss_criterion(output_reshape, Y.view(-1, 1)) * 0.01
            # guard: the 'test_valid' return below also reads loss_sparse and loss_dense
            loss_sparse = loss_dense = loss
    else:
        output_reshape = output_reshape_.view(-1, 1)        
        loss = loss_criterion(output_reshape, Y.view(-1, 1)) * 0.01
        count_error = torch.abs(torch.sum(Y.view(Y.size(0), -1), dim=1) - torch.sum(output.view(output.size(0), -1), dim=1))

    network.train()
    network = set_batch_norm_to_eval(network)

    if set_name == "test_valid":
        return loss.item(), loss_sparse.item(), loss_dense.item(), output.cpu().detach().numpy()
    else:
        return loss.item(), output.cpu().detach().numpy(), count_error.cpu().detach().numpy()
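A short usage sketch of the two set_name branches; the set names 'test_valid' and 'test' are assumptions based on the code above (any value other than 'test_valid' takes the second branch), and Xs/Ys stand for numpy batches:

def run_eval(network, Xs, Ys):
    # Validation-style call: combined loss plus its sparse and dense parts.
    val_loss, sparse_loss, dense_loss, _ = test_function(Xs, Ys, network, set_name='test_valid')

    # Test-style call: loss, predicted density maps, per-image count errors.
    test_loss, pred, count_error = test_function(Xs, Ys, network, set_name='test')
    return val_loss, test_loss, float(count_error.mean())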