def forward(self, image, target):
    """Return the summed L1 distance between filtered image and target.

    Both inputs are convolved without padding with ``self.op1`` and with
    ``self.op2``. The mean absolute difference of each filtered pair is
    computed and the two means are added.
    """
    loss_op1 = F.l1_loss(F.conv2d(image, self.op1, padding=0),
                         F.conv2d(target, self.op1, padding=0))
    loss_op2 = F.l1_loss(F.conv2d(image, self.op2, padding=0),
                         F.conv2d(target, self.op2, padding=0))
    return loss_op1 + loss_op2
Example #2
0
 def __init__(self, in_channels, out_channels, kernel_size, alpha_shape, stride=1,
              padding=0, dilation=1, prior='loguni', bias=True):
     """Set up a square-kernel 2-D convolution with a ``log_alpha`` parameter.

     Args:
         in_channels: number of input channels of the kernel.
         out_channels: number of output channels of the kernel.
         kernel_size: side length; the kernel is ``kernel_size x kernel_size``.
         alpha_shape: shape of the ``log_alpha`` parameter tensor.
         stride, padding, dilation: forwarded to ``F.conv2d`` by the ops.
         prior: ``'loguni'`` selects ``metrics.kl_loguni``; any other value
             selects ``metrics.kl_ard``.
         bias: when true a ``(1, out_channels, 1, 1)`` bias is created,
             otherwise ``bias`` is registered as ``None``.
     """
     super(ConvVDO, self).__init__()

     # Plain (non-learnable) hyper-parameters.
     self.in_channels, self.out_channels = in_channels, out_channels
     self.kernel_size = (kernel_size,) * 2
     self.stride, self.padding, self.dilation = stride, padding, dilation
     self.alpha_shape = alpha_shape
     self.groups = 1

     # Learnable tensors are allocated uninitialised here; reset_parameters()
     # is called below, presumably to fill them (defined elsewhere).
     weight_shape = (out_channels, in_channels) + self.kernel_size
     self.weight = Parameter(torch.Tensor(*weight_shape))
     if not bias:
         self.register_parameter('bias', None)
     else:
         self.bias = Parameter(torch.Tensor(1, out_channels, 1, 1))

     # Convolution closures with and without the layer bias.
     self.op_bias = lambda input, kernel: F.conv2d(
         input, kernel, self.bias, self.stride, self.padding, self.dilation, self.groups)
     self.op_nobias = lambda input, kernel: F.conv2d(
         input, kernel, None, self.stride, self.padding, self.dilation, self.groups)

     self.log_alpha = Parameter(torch.Tensor(*alpha_shape))
     self.reset_parameters()

     self.zero_mean = False
     self.permute_sigma = False
     self.prior = prior
     self.kl_fun = metrics.kl_loguni if prior == 'loguni' else metrics.kl_ard
 def block(x, params, base, mode, stride):
     """Apply one residual block: (BN, ReLU, 3x3 conv) twice plus a shortcut.

     Parameters are looked up in ``params`` under keys prefixed with ``base``.
     When ``base + '.convdim'`` is present, the shortcut is that convolution
     applied to the first activated tensor; otherwise ``x`` is added as is.
     """
     def bn_relu(tensor, suffix):
         # Batch-norm via the project helper, then in-place ReLU on its result.
         return F.relu(utils.batch_norm(tensor, params, base + suffix, mode), inplace=True)

     pre = bn_relu(x, '.bn0')
     hidden = F.conv2d(pre, params[base + '.conv0'], stride=stride, padding=1)
     residual = F.conv2d(bn_relu(hidden, '.bn1'), params[base + '.conv1'],
                         stride=1, padding=1)

     shortcut_key = base + '.convdim'
     if shortcut_key not in params:
         return residual + x
     return residual + F.conv2d(pre, params[shortcut_key], stride=stride)
Example #4
0
File: losses.py Project: Daiver/jff
def mk_diff_img(image, channels_first=True):
    """Build an autograd function that samples ``image`` at rounded points.

    The forward pass rounds each ``(row, col)`` point, clamps it to the image
    bounds and returns the image value there. The backward pass substitutes
    a finite-difference gradient: the zero-padded image filtered with two
    3x3 difference kernels, one along each axis.

    Args:
        image: 2-D tensor (asserted). Kernels and buffers are created on the
            image's device and dtype, so CPU and CUDA images both work.
        channels_first: must be truthy (asserted); otherwise unused.

    Returns:
        An instance of the locally defined ``autograd.Function`` subclass.
        NOTE(review): callers presumably invoke it through ``.apply`` --
        confirm against call sites.
    """
    assert channels_first
    assert len(image.shape) == 2

    # Follow the image instead of hard-coding .cuda(), which made the
    # function unusable for CPU tensors.
    tensor_opts = dict(dtype=image.dtype, device=image.device)

    # Difference between the column to the right and the column to the left.
    col_kernel = torch.tensor([
        [-1, 0, 1],
        [-1, 0, 1],
        [-1, 0, 1]], **tensor_opts).view((1, 1, 3, 3))

    # Difference between the row below and the row above.
    row_kernel = torch.tensor([
        [-1, -1, -1],
        [0, 0, 0],
        [1, 1, 1]], **tensor_opts).view((1, 1, 3, 3))

    padded_image = F.pad(image.unsqueeze(0).unsqueeze(0), [1, 1, 1, 1], value=0)

    # Coordinate 0 of a point indexes rows, so it pairs with the row
    # difference image; coordinate 1 pairs with the column difference image.
    diff_img_rows = F.conv2d(padded_image, row_kernel)[0, 0]
    diff_img_cols = F.conv2d(padded_image, col_kernel)[0, 0]

    class LocalFunction(autograd.Function):
        @staticmethod
        def forward(ctx, points_positions):
            assert len(points_positions.shape) == 2
            indices = points_positions.detach().round().long()
            indices[:, 0].clamp_(0, image.shape[0] - 1)
            indices[:, 1].clamp_(0, image.shape[1] - 1)
            ctx.save_for_backward(indices)
            return image[indices[:, 0], indices[:, 1]]

        @staticmethod
        def backward(ctx, grad_outputs):
            indices, = ctx.saved_tensors
            rows, cols = indices[:, 0], indices[:, 1]

            # Allocate next to the incoming gradient rather than on a fixed GPU.
            res = grad_outputs.new_zeros(indices.shape)
            res[:, 0] = grad_outputs * diff_img_rows[rows, cols]
            res[:, 1] = grad_outputs * diff_img_cols[rows, cols]
            return res

    return LocalFunction()
    def blur_frame(self, frame):
        """Box-blur the first three channels of ``frame`` with a 21x21 window.

        Each channel is convolved on its own with a uniform ``K x K`` kernel
        (``K = 21``, zero padding 10, so the spatial size is kept). The three
        results are stacked back together and clamped to at most 1.0.

        Args:
            frame: tensor indexed as ``[B, C, H, W]`` with at least 3 channels.

        Returns:
            Tensor of shape ``[B, 3, H, W]`` on the same device as ``frame``.

        Raises:
            ValueError: if any value of ``frame`` exceeds 1. The original
                code aborted here through an undefined name.
        """
        peak = torch.max(frame)
        if peak > 1.:
            raise ValueError(
                'blur_frame expects values <= 1, got max %s' % float(peak))

        K = 21
        padding = 10
        # Uniform averaging kernel, created next to the input rather than
        # on a hard-coded CUDA device.
        filter_weights = torch.ones(
            1, 1, K, K, dtype=frame.dtype, device=frame.device) / K**2

        blurred_channels = [
            F.conv2d(input=frame[:, c].unsqueeze(1), weight=filter_weights,
                     bias=None, padding=padding, stride=1, dilation=1)
            for c in range(3)
        ]

        # Stack gives [B, 3, 1, H, W]; drop the singleton to get [B, 3, H, W].
        blurred_image = torch.stack(blurred_channels, dim=1).squeeze(dim=2)
        return torch.clamp(blurred_image, max=1.0)
Example #6
0
def f(params, inputs, mode):
    """Run a small functional network on flattened 28x28 inputs.

    ``inputs`` is viewed as ``(batch, 1, 28, 28)``, passed through two
    stride-2 convolutions and two linear layers, with ReLU after every layer
    except the last. All weights and biases are read from ``params``.
    ``mode`` is accepted but not used.
    """
    images = inputs.view(inputs.size(0), 1, 28, 28)

    conv0 = F.relu(F.conv2d(images, params['conv0.weight'],
                            params['conv0.bias'], stride=2))
    conv1 = F.relu(F.conv2d(conv0, params['conv1.weight'],
                            params['conv1.bias'], stride=2))

    flat = conv1.view(conv1.size(0), -1)
    hidden = F.relu(F.linear(flat, params['linear2.weight'],
                             params['linear2.bias']))
    return F.linear(hidden, params['linear3.weight'], params['linear3.bias'])
Example #7
0
 def forward(self, input):
     """Rescale the weight by ``sigma`` and convolve ``input`` with it.

     ``max_singular_value`` receives the weight flattened to
     ``(out_channels, -1)`` together with ``self.u`` and returns ``sigma`` and
     a new ``u``, which is stored back. The division is written into
     ``self.weight.data``, so it persists across calls.
     """
     flat_weight = self.weight.view(self.weight.size(0), -1)
     sigma, self.u = max_singular_value(flat_weight, self.u)
     self.weight.data = self.weight.data / sigma
     return F.conv2d(input, self.weight, bias=self.bias, stride=self.stride,
                     padding=self.padding, dilation=self.dilation,
                     groups=self.groups)
    def forward(self, input):
        """Convolve ``input`` with the filters from ``self.extract_filters()``.

        Bias, stride, padding, dilation and groups are taken from the layer.
        """
        filters = self.extract_filters()
        return F.conv2d(input, filters, bias=self.bias, stride=self.stride,
                        padding=self.padding, dilation=self.dilation,
                        groups=self.groups)
def test_backward_computes_backward_pass():
    """Check ``_EfficientConv2d.backward`` against autograd's ``F.conv2d``.

    The reference gradients come from back-propagating an all-ones upstream
    gradient through ``F.conv2d``. The same upstream gradient is given to
    ``_EfficientConv2d.backward`` and output, input gradient and weight
    gradient are compared. Requires a CUDA device.
    """
    weight = torch.randn(4, 8, 3, 3).cuda()
    input = torch.randn(4, 8, 4, 4).cuda()
    conv_args = dict(stride=1, padding=1, dilation=1, groups=1)

    input_var = Variable(input, requires_grad=True)
    weight_var = Parameter(weight)
    out_var = F.conv2d(input=input_var, weight=weight_var, bias=None, **conv_args)

    # The upstream gradient must have the OUTPUT's shape (4, 4, 4, 4). The
    # original reused the input's shape (4, 8, 4, 4), which backward()
    # rejects as a shape mismatch.
    grad_output = out_var.data.clone().fill_(1)
    out_var.backward(gradient=grad_output)
    out = out_var.data
    input_grad = input_var.grad.data
    weight_grad = weight_var.grad.data

    func = _EfficientConv2d(**conv_args)
    out_efficient = func.forward(weight, None, input)
    weight_grad_efficient, _, input_grad_efficient = func.backward(
            weight, None, input, grad_output.clone())

    assert(almost_equal(out, out_efficient))
    assert(almost_equal(input_grad, input_grad_efficient))
    assert(almost_equal(weight_grad, weight_grad_efficient))
Example #10
0
 def __init__(self, in_channels, out_channels, kernel_size, stride=1,
              padding=0, dilation=1, bias=True):
     """Set up a square-kernel 2-D convolution whose weight is ``self.W``.

     Args:
         in_channels: number of input channels of the kernel.
         out_channels: number of output channels of the kernel.
         kernel_size: side length; the kernel is ``kernel_size x kernel_size``.
         stride, padding, dilation: forwarded to ``F.conv2d`` by the ops.
         bias: when true a ``(1, out_channels, 1, 1)`` bias is created,
             otherwise ``bias`` is registered as ``None``.
     """
     super(ConvVarianceUnif, self).__init__()

     # Plain (non-learnable) hyper-parameters.
     self.in_channels, self.out_channels = in_channels, out_channels
     self.kernel_size = (kernel_size,) * 2
     self.stride, self.padding, self.dilation = stride, padding, dilation
     self.groups = 1

     # Allocated uninitialised; reset_parameters() is called at the end,
     # presumably to fill them (defined elsewhere).
     weight_shape = (out_channels, in_channels) + self.kernel_size
     self.W = Parameter(torch.Tensor(*weight_shape))
     if not bias:
         self.register_parameter('bias', None)
     else:
         self.bias = Parameter(torch.Tensor(1, out_channels, 1, 1))

     # Convolution closures with and without the layer bias.
     self.op_bias = lambda input, kernel: F.conv2d(
         input, kernel, self.bias, self.stride, self.padding, self.dilation, self.groups)
     self.op_nobias = lambda input, kernel: F.conv2d(
         input, kernel, None, self.stride, self.padding, self.dilation, self.groups)

     self.reset_parameters()
Example #11
0
 def forward(self, x):
     """Normalise, activate, optionally drop, then apply the masked conv.

     The convolution weight is multiplied element-wise by ``self.mask``
     before use. The convolution runs without bias and with ``groups=1``;
     stride, padding and dilation come from ``self.conv``.
     """
     self._check_drop()
     activated = self.relu(self.norm(x))
     if self.dropout_rate > 0:
         activated = self.drop(activated)

     conv = self.conv
     masked_weight = conv.weight * self.mask
     return F.conv2d(activated, masked_weight, bias=None, stride=conv.stride,
                     padding=conv.padding, dilation=conv.dilation, groups=1)
 def f(input, params, mode):
     """Functional forward pass: stem conv, three groups, BN+ReLU, pool, FC.

     The groups are applied with strides 1, 2 and 2. The result is
     average-pooled with an 8x8 window, flattened and fed to the linear
     layer stored under ``fc.weight`` / ``fc.bias``.
     """
     feats = F.conv2d(input, params['conv0'], padding=1)
     for name, stride in (('group0', 1), ('group1', 2), ('group2', 2)):
         feats = group(feats, params, name, mode, stride)

     feats = F.relu(utils.batch_norm(feats, params, 'bn', mode))
     pooled = F.avg_pool2d(feats, 8, 1, 0)
     flat = pooled.view(pooled.size(0), -1)
     return F.linear(flat, params['fc.weight'], params['fc.bias'])
Example #13
0
def SSIM(img1, img2):
	"""Return the mean SSIM between two image batches.

	Local means, variances and the covariance are estimated with a grouped
	convolution (one group per channel) using the window returned by
	``create_window(11, channel)``.

	Args:
		img1, img2: 4-D tensors; the channel count is read from ``img1``.

	Returns:
		Scalar tensor: the mean of the per-pixel SSIM map.
	"""
	(_, channel, _, _) = img1.size()
	window_size = 11
	window = create_window(window_size, channel)
	# conv2d needs an integer padding; ``window_size/2`` is the float 5.5
	# on Python 3, so use floor division.
	half_window = window_size // 2

	def local_mean(tensor):
		return F.conv2d(tensor, window, padding = half_window, groups = channel)

	mu1 = local_mean(img1)
	mu2 = local_mean(img2)

	mu1_sq = mu1.pow(2)
	mu2_sq = mu2.pow(2)
	mu1_mu2 = mu1*mu2

	sigma1_sq = local_mean(img1*img1) - mu1_sq
	sigma2_sq = local_mean(img2*img2) - mu2_sq
	sigma12 = local_mean(img1*img2) - mu1_mu2

	# Constants added to numerator and denominator of the SSIM ratio.
	C1 = 0.01**2
	C2 = 0.03**2

	ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
	return ssim_map.mean()
    def forward(self, x):
        """Forward pass of the layer.

        In deterministic mode (evaluation only) a plain convolution with
        ``post_weight_mu`` and ``bias_mu`` is returned. Otherwise the mean and
        variance of the activations are computed by two convolutions, scaled
        by a per-sample ``z`` and passed through ``reparametrize``.
        """
        if self.deterministic:
            assert not self.training, "Flag deterministic is True. This should not be used in training."
            return F.conv2d(x, self.post_weight_mu, self.bias_mu)

        n_samples = x.size()[0]
        conv_opts = (self.stride, self.padding, self.dilation, self.groups)

        # apply local reparametrisation trick see [1] Eq. (6)
        # to the parametrisation given in [3] Eq. (6)
        act_mean = F.conv2d(x, self.weight_mu, self.bias_mu, *conv_opts)
        act_var = F.conv2d(x.pow(2), self.weight_logvar.exp(),
                           self.bias_logvar.exp(), *conv_opts)

        # compute z
        # note that we reparametrise according to [2] Eq. (11) (not [1])
        z_mu = self.z_mu.repeat(n_samples, 1)
        z_logvar = self.z_logvar.repeat(n_samples, 1)
        z = reparametrize(z_mu, z_logvar, sampling=self.training, cuda=self.cuda)
        # Append two singleton axes so z broadcasts over the spatial dims.
        z = z[:, :, None, None]

        scaled_mean = act_mean * z
        scaled_logvar = (act_var * z.pow(2)).log()
        return reparametrize(scaled_mean, scaled_logvar, sampling=self.training,
                             cuda=self.cuda)
Example #15
0
def _ssim(img1, img2, window, window_size, channel, size_average = True):
    """Compute SSIM between ``img1`` and ``img2`` using ``window``.

    Local statistics are obtained by grouped convolution (``groups=channel``)
    with padding ``window_size // 2``.

    Returns:
        The mean of the SSIM map when ``size_average`` is true; otherwise the
        map averaged over dimension 1 three times in a row.
    """
    half = window_size // 2

    def smooth(tensor):
        return F.conv2d(tensor, window, padding = half, groups = channel)

    mean1, mean2 = smooth(img1), smooth(img2)
    mean1_sq = mean1.pow(2)
    mean2_sq = mean2.pow(2)
    mean12 = mean1*mean2

    var1 = smooth(img1*img1) - mean1_sq
    var2 = smooth(img2*img2) - mean2_sq
    cov12 = smooth(img1*img2) - mean12

    C1 = 0.01**2
    C2 = 0.03**2

    numerator = (2*mean12 + C1)*(2*cov12 + C2)
    denominator = (mean1_sq + mean2_sq + C1)*(var1 + var2 + C2)
    ssim_map = numerator/denominator

    if not size_average:
        return ssim_map.mean(1).mean(1).mean(1)
    return ssim_map.mean()
 def f(input, params, pooling_classif=True):
     """Functional forward pass: stem conv + max-pool, four groups, classifier.

     The groups run with strides 1, 2, 2, 2 and block counts ``blocks[0..3]``.
     With ``pooling_classif`` true, the last group's output is average-pooled
     (7x7), flattened and passed through the ``fc`` linear layer.

     NOTE(review): with ``pooling_classif`` false the value returned is the
     max-pooled stem output, not the last group's output -- confirm this is
     intended.
     """
     stem = F.conv2d(input, params['conv0.weight'], params['conv0.bias'], 2, 3)
     o = F.max_pool2d(F.relu(stem), 3, 2, 1)

     feats = o
     group_specs = (('group0', 1), ('group1', 2), ('group2', 2), ('group3', 2))
     for idx, (name, stride) in enumerate(group_specs):
         feats = group(feats, params, name, stride, blocks[idx])

     if not pooling_classif:
         return o

     pooled = F.avg_pool2d(feats, 7, 1, 0)
     flat = pooled.view(pooled.size(0), -1)
     return F.linear(flat, params['fc.weight'], params['fc.bias'])
Example #17
0
    def test_torch_F_conv2d_on_remote_var(self):
        """F.conv2d on operands sent to a virtual worker gives the expected map.

        Input, weight and bias are each sent to a ``VirtualWorker``; the
        convolution result is fetched back with ``get()`` and compared to a
        hard-coded 2x2 result.
        """
        hook = TorchHook(verbose=False)
        local_worker = hook.local_worker
        remote_worker = VirtualWorker(id=2, hook=hook)
        local_worker.add_worker(remote_worker)

        image = Var(torch.FloatTensor([[[[1, -1, 2], [-1, 0, 1], [1, 0, -2]]]]))
        image.send(remote_worker)

        kernel = torch.nn.Parameter(torch.FloatTensor([[[[1, -1], [-1, 1]]]]))
        offset = torch.nn.Parameter(torch.FloatTensor([0]))
        kernel.send(remote_worker)
        offset.send(remote_worker)

        result = F.conv2d(image, kernel, offset, stride=(1, 1))
        result.get()

        expected = Var(torch.FloatTensor([[[[3, -2], [-2, -3]]]]))
        assert torch.equal(result, expected)
def conv2d_same_padding(input, weight, bias=None, stride=1, padding=1, dilation=1, groups=1):
    """2-D convolution with "same" padding: output size is ``ceil(in / stride)``.

    The total padding needed along each spatial axis is computed from that
    axis' own input size, kernel size, stride and dilation. When the total is
    odd, one extra zero row/column is appended at the bottom/right before
    the symmetric padding is applied by ``F.conv2d``.

    Args:
        input: 4-D input tensor ``(N, C, H, W)``.
        weight: 4-D kernel tensor; sizes 2 and 3 are the kernel height/width.
        bias: optional bias forwarded to ``F.conv2d``.
        stride: int or pair ``(rows, cols)``.
        padding: ignored; kept only for signature compatibility.
        dilation: int or pair ``(rows, cols)``.
        groups: forwarded to ``F.conv2d``.

    Returns:
        The convolution result.
    """
    def _as_pair(value):
        # Accept plain ints (the declared defaults) as well as sequences.
        if isinstance(value, int):
            return (value, value)
        value = tuple(value)
        return value * 2 if len(value) == 1 else value

    def _total_padding(in_size, kernel_size, step, dil):
        out_size = (in_size + step - 1) // step
        return max(0, (out_size - 1) * step + (kernel_size - 1) * dil + 1 - in_size)

    stride = _as_pair(stride)
    dilation = _as_pair(dilation)

    # Rows and columns each use their own sizes. The original reused the row
    # quantities for the columns, which is wrong for non-square
    # inputs/kernels.
    padding_rows = _total_padding(input.size(2), weight.size(2), stride[0], dilation[0])
    padding_cols = _total_padding(input.size(3), weight.size(3), stride[1], dilation[1])
    rows_odd = (padding_rows % 2 != 0)
    cols_odd = (padding_cols % 2 != 0)

    if rows_odd or cols_odd:
        # F.pad order is (left, right, top, bottom) for the last two dims.
        input = F.pad(input, [0, int(cols_odd), 0, int(rows_odd)])

    return F.conv2d(input, weight, bias, stride,
                    padding=(padding_rows // 2, padding_cols // 2),
                    dilation=dilation, groups=groups)
Example #19
0
	def stage1(self, x0):
		"""Run the first stage over ``self.layers`` densely connected convs.

		Layer 0 convolves ``x0`` with ``self.W0[0]``. Each later layer ``i``
		appends the previous output to its input along the channel axis and
		uses ``self.W0[i]`` concatenated with ``self.W[coordinate2idx(j, i)]``
		for every earlier layer ``j``.

		input {X0}
		return {X2, X3, X4, .... Xl}: the last layer's input without its
		first ``nin + filters`` channels, concatenated with the last output.
		"""
		output = None
		# ``range`` and floor division keep this working on Python 3, where
		# ``xrange`` is gone and ``kernel/2`` would be a float padding.
		half_kernel = self.kernel // 2
		for i in range(self.layers):
			if i == 0:
				data = x0
				weight = self.W0[i]
			else:
				data = torch.cat([data, output], dim=1)
				earlier = [self.W[self.coordinate2idx(j, i)] for j in range(i)]
				weight = torch.cat([self.W0[i]] + earlier, dim=1)

			bias = self.b[i]

			conv = F.conv2d(data, weight, bias, stride=1, padding=half_kernel)
			output = self.activates[i](conv)

		return torch.cat([data[:, (self.nin+self.filters):, :, :], output], dim=1)
def test_forward_computes_forward_pass():
    """Check ``_EfficientConv2d.forward`` against ``F.conv2d``.

    Both are run on the same random CUDA weight and input with stride 1,
    padding 1, dilation 1 and a single group; the outputs must agree.
    """
    weight = torch.randn(4, 8, 3, 3).cuda()
    x = torch.randn(4, 8, 4, 4).cuda()
    conv_args = dict(stride=1, padding=1, dilation=1, groups=1)

    reference = F.conv2d(input=Variable(x), weight=Parameter(weight),
                         bias=None, **conv_args).data

    efficient = _EfficientConv2d(**conv_args).forward(weight, None, x)

    assert(almost_equal(reference, efficient))
Example #21
0
	def stage2(self, x):
		"""Run the second stage over ``self.layers`` convs with a sliding window.

		At every step after the first, the oldest ``self.filters`` channels
		are dropped from the input and the newest output is appended. The
		weight for step ``i`` is the channel-wise concatenation of
		``self.W[coordinate2idx(j, i)]`` for the layers currently in the
		window, which advances by one (via ``recurrent_index``) per step.

		input {X2, X3, ... , Xl}
		output {X1', X2',..., Xl'}: the last step's input concatenated with
		its output.
		"""
		output = None
		# A real list is required: the window is sliced and extended below,
		# which a Python 3 ``range`` object does not support.
		from_layers = list(range(1, self.layers))  # from layer index
		half_kernel = self.kernel // 2  # integer padding on Python 2 and 3

		for i in range(self.layers):
			if i == 0:
				data = x
			else:
				data = torch.cat([data[:, self.filters:, :, :], output], dim=1)

			weight = torch.cat([self.W[self.coordinate2idx(j, i)] for j in from_layers], dim=1)
			bias = self.b[self.layers+i]
			from_layers = from_layers[1:] + [self.recurrent_index(from_layers[-1]+1)]

			conv = F.conv2d(data, weight, bias, stride=1, padding=half_kernel)
			output = self.activates[self.layers+i](conv)

		s2 = torch.cat([data, output], dim=1)

		return s2
 def conv2d(input, params, base, stride=1, pad=0):
     """Convolve ``input`` with the weight and bias stored under ``base``.

     The tensors are read from ``params[base + '.weight']`` and
     ``params[base + '.bias']``; ``pad`` is passed as the padding.
     """
     kernel = params[base + '.weight']
     offset = params[base + '.bias']
     return F.conv2d(input, kernel, offset, stride=stride, padding=pad)
Example #23
0

import torch
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision.transforms import ToTensor
# Demo: colour-space conversion expressed as a 1x1 convolution.
# ``lena`` and ``lena_ycbcr`` are images defined outside this snippet.
# print() is called with a single pre-formatted string so the output is the
# same on Python 2 and Python 3 (the original print statements are a syntax
# error on Python 3).

# Image -> tensor, then add a leading batch dimension.
rgb = ToTensor()(lena)
rgb = rgb.view(1, rgb.size(0), rgb.size(1), rgb.size(2))
rgb = Variable(rgb)
# 3x3 colour matrix reshaped to a (3, 3, 1, 1) convolution kernel.
rgb2ycbcr = Variable(torch.FloatTensor([[0.299, 0.587, 0.114], [-0.169, -0.331, 0.5], [0.5, -0.419, -0.081]]).resize_(3,3,1,1))
print(rgb2ycbcr)


print("---- rgb -----")
print(rgb)
ycbcr = F.conv2d(rgb, weight=rgb2ycbcr)

print("first pixel: {}".format(rgb.data[0,0,0,0]*255))
print(lena.getpixel((0,0)))

print("---- ycbcr -----")
print(ycbcr)


print("first pixel: {} {}".format(ycbcr.data[0,0,0,0]*255, (ycbcr.data[0,1,0,0]+0.5)*255))
print(lena_ycbcr.getpixel((0,0)))

# Inverse colour matrix, applied the same way to go back to RGB.
ycbcr2rgb = Variable(torch.FloatTensor([[1, -0.00001, 1.402], [1, -0.34413, -0.71414], [1, 1.772, 0.00004]]).resize_(3,3,1,1))
rgb = F.conv2d(ycbcr, weight=ycbcr2rgb)
print("---- rgb -----")
print(rgb)
Example #24
0
 def forward(self, x):
     """Apply the layer's 2-D convolution (``*2d`` attributes) to ``x``."""
     return F.conv2d(x, self.weight2d, bias=self.bias2d, stride=self.stride2d,
                     padding=self.padding2d, dilation=self.dilation2d,
                     groups=self.groups)
Example #25
0
 def forward(self, input):
     """Convolve ``input`` using the layer's weight, bias and conv settings."""
     conv_settings = dict(stride=self.stride, padding=self.padding,
                          dilation=self.dilation, groups=self.groups)
     return fn.conv2d(input, self.weight, self.bias, **conv_settings)
Example #26
0
    def _compute_weight_kxk(self,
                            update=True,
                            n_iterations=None,
                            atol=None,
                            rtol=None):
        """Return the conv weight divided by ``max(1, sigma / self.coeff)``.

        ``sigma`` is ``u . conv(v, weight)``. When ``update`` is true the
        vectors ``u`` and ``v`` are first refined (without gradients) by
        alternating a transposed convolution and a convolution, each followed
        by normalisation, written in place into ``self.u`` / ``self.v``.

        Args:
            update: whether to refine ``u`` and ``v`` before computing sigma.
            n_iterations: iteration cap; defaults to ``self.n_iterations``,
                and to 20000 when only tolerances are available.
            atol: absolute tolerance; defaults to ``self.atol``.
            rtol: relative tolerance; defaults to ``self.rtol``.

        Returns:
            The rescaled weight tensor. ``self.scale`` is overwritten with
            ``sigma`` as a side effect.

        Raises:
            ValueError: if neither an iteration count nor both tolerances
                are available.
        """
        n_iterations = self.n_iterations if n_iterations is None else n_iterations
        atol = self.atol if atol is None else atol
        # An explicitly passed rtol must be kept; the original substituted
        # atol here, silently discarding the caller's value.
        rtol = self.rtol if rtol is None else rtol

        if n_iterations is None and (atol is None or rtol is None):
            raise ValueError('Need one of n_iteration or (atol, rtol).')

        if n_iterations is None:
            n_iterations = 20000

        u = self.u
        v = self.v
        weight = self.weight
        c, h, w = self.in_channels, int(self.spatial_dims[0].item()), int(
            self.spatial_dims[1].item())
        if update:
            with torch.no_grad():
                itrs_used = 0
                for _ in range(n_iterations):
                    old_u = u.clone()
                    old_v = v.clone()
                    v_s = F.conv_transpose2d(u.view(self.out_shape),
                                             weight,
                                             stride=self.stride,
                                             padding=self.padding,
                                             output_padding=0)
                    # ``out=`` writes the normalised vector into the
                    # existing tensor.
                    v = F.normalize(v_s.view(-1), dim=0, out=v)
                    u_s = F.conv2d(v.view(1, c, h, w),
                                   weight,
                                   stride=self.stride,
                                   padding=self.padding,
                                   bias=None)
                    u = F.normalize(u_s.view(-1), dim=0, out=u)
                    itrs_used = itrs_used + 1
                    if atol is not None and rtol is not None:
                        # RMS change of each vector against its tolerance.
                        err_u = torch.norm(u - old_u) / (u.nelement()**0.5)
                        err_v = torch.norm(v - old_v) / (v.nelement()**0.5)
                        tol_u = atol + rtol * torch.max(u)
                        tol_v = atol + rtol * torch.max(v)
                        if err_u < tol_u and err_v < tol_v:
                            break
                if itrs_used > 0:
                    # Work on copies so later in-place updates of the stored
                    # vectors do not affect the tensors used below.
                    u = u.clone()
                    v = v.clone()

        weight_v = F.conv2d(v.view(1, c, h, w),
                            weight,
                            stride=self.stride,
                            padding=self.padding,
                            bias=None)
        weight_v = weight_v.view(-1)
        sigma = torch.dot(u.view(-1), weight_v)
        with torch.no_grad():
            self.scale.copy_(sigma)
        # soft normalization: only when sigma larger than coeff
        factor = torch.max(torch.ones(1).to(weight.device), sigma / self.coeff)
        weight = weight / factor
        return weight
def compute_reblurred_image_and_subsampled_kernel_loss(
        sharp_image,
        kernels,
        masks,
        gt_kernels,
        gt_masks,
        masks_weights,
        kernels_loss_type,
        n_grid_points=128,
        GPU=0,
        manage_saturated_pixels=True):
    """Re-blur ``sharp_image`` with masked kernels and score sampled kernels.

    For every batch element, each channel is convolved with every predicted
    kernel, weighted by the matching mask and summed over kernels. In
    addition, per-pixel kernels (mask-weighted mixtures) are compared between
    prediction and ground truth on a random grid of pixel positions.

    Args:
        sharp_image: 4-D tensor ``(N, C, H + K - 1, W + K - 1)``.
        kernels: predicted kernels; size 1 is the kernel count, size 2 is K.
        masks: predicted per-kernel masks, indexed ``[n, kernel, ., .]``.
        gt_kernels: ground-truth kernels; shape[1] is their count, shape[2]
            their side length.
        gt_masks: ground-truth masks, indexed like ``masks``.
        masks_weights: per-pixel weights, indexed ``[n, ., .]``.
        kernels_loss_type: ``'L2'`` or ``'L1'``.
        n_grid_points: number of sampled indices per axis.
        GPU: CUDA device index used for the allocated tensors.
        manage_saturated_pixels: when true the result is passed through
            ``apply_saturation_function``; otherwise it is returned as is.

    Returns:
        ``(output_reblurred, kernels_loss)``.

    Raises:
        ValueError: if ``kernels_loss_type`` is neither ``'L2'`` nor ``'L1'``.

    NOTE(review): predicted kernels are flattened with the ground-truth side
    length, so both kernel sets presumably share one size -- confirm.
    NOTE(review): the sampled grid needs at least ``n_grid_points`` valid
    positions along each axis, otherwise the ``view`` calls fail -- confirm.
    """
    n_kernels = kernels.size(1)
    K = kernels.size(2)
    N = sharp_image.size(0)
    C = sharp_image.size(1)
    # Output size of an unpadded convolution with a K x K kernel.
    H = sharp_image.size(2) - K + 1
    W = sharp_image.size(3) - K + 1
    reblurred_images = torch.empty(N, n_kernels, C, H, W).cuda(GPU)

    Kgt = gt_kernels.shape[1]
    Wk = gt_kernels.shape[2]  # kernel side

    # Random subset of positions, shared by every batch element.
    ind_x = np.random.permutation(W)[:n_grid_points]
    ind_y = np.random.permutation(H)[:n_grid_points]
    xx, yy = np.meshgrid(ind_x, ind_y)
    n_samples = n_grid_points * n_grid_points

    kernels_loss = torch.Tensor([0.]).cuda(GPU)
    for n in range(N):
        gt_masks_nn = gt_masks[n, :, xx, yy].view(Kgt, n_samples)
        gt_kernels_nn = gt_kernels[n].view(Kgt, Wk * Wk)
        gt_kernels_per_pixel = torch.mm(gt_kernels_nn.t(), gt_masks_nn)

        predicted_kernels_per_pixel = torch.mm(
            kernels[n].contiguous().view(n_kernels, Wk * Wk).t(),
            masks[n, :, xx, yy].contiguous().view(n_kernels, n_samples))

        kernel_diff = predicted_kernels_per_pixel - gt_kernels_per_pixel
        if kernels_loss_type == 'L2':
            per_pixel_kernel_diff = kernel_diff**2
        elif kernels_loss_type == 'L1':
            per_pixel_kernel_diff = kernel_diff.abs()
        else:
            # Previously this fell through to an unbound-local error.
            raise ValueError(
                "kernels_loss_type must be 'L2' or 'L1', got %r"
                % (kernels_loss_type,))

        kernels_loss += (per_pixel_kernel_diff.sum(0) *
                         masks_weights[n, xx, yy].view(n_samples)).sum() / N

        for c in range(C):
            conv_output = F.conv2d(sharp_image[n:n + 1, c:c + 1, :, :],
                                   kernels[n][:, np.newaxis, :, :])
            reblurred_images[n:n + 1, :,
                             c, :, :] = conv_output * masks[n:n + 1]

    # Sum the contributions of all kernels.
    reblurred_images = torch.sum(reblurred_images, (1))

    if manage_saturated_pixels:
        output_reblurred = apply_saturation_function(reblurred_images)
    else:
        # Without this branch the return statement hit an unbound local.
        output_reblurred = reblurred_images

    return output_reblurred, kernels_loss
Example #28
0
 def forward(self, x):
     """Convolve ``x`` with the kernel returned by ``self.get_kernel()``."""
     return F.conv2d(x, self.get_kernel())
Example #29
0
    def forward(self,
                input,
                excitation,
                inhibition,
                label=None,
                activ=F.softplus,
                testmode=False,
                reset_hidden=False):
        """Run the dales law circuit.

        Args:
            input: feed-forward drive for this step.
            excitation: previous excitatory state, or ``None`` to initialise
                it from ``self.exc_init(label)``.
            inhibition: previous inhibitory state, or ``None`` to initialise
                it from ``self.inh_init(label)``.
            label: passed to the state initialisers.
            activ: activation applied to the initial states and interactions.
            testmode: when true the attention gate is returned as well.
            reset_hidden: accepted but unused.

        Returns:
            ``(excitation, inhibition)``, or
            ``(excitation, inhibition, att_gate)`` in test mode. ``att_gate``
            is ``None`` when attention is disabled.
        """
        if inhibition is None:  #  or reset_hidden:
            inhibition = activ(self.inh_init(label))
        if excitation is None:  #  or reset_hidden:
            excitation = activ(self.exc_init(label))

        # Input and inhibition are passed through ungated in both modes.
        gated_input = input
        gated_inhibition = inhibition
        if self.use_attention:
            att_gate = torch.sigmoid(
                self.a_wu_gate(torch.cat([input, excitation], dim=1)))
            gated_excitation = att_gate * excitation
        else:
            # Without attention the original left these names undefined and
            # crashed below; fall back to the ungated excitation.
            att_gate = None
            gated_excitation = excitation

        # Compute inhibition
        inh_intx = activ(
            F.conv2d(self.bn[0](gated_excitation),
                     self.w_inh,
                     padding=self.h_padding))  # in activ range
        inhibition_hat = activ(input - inh_intx *
                               (self.alpha * gated_inhibition + self.mu))

        # Integrate inhibition
        inh_gate = torch.sigmoid(
            self.i_w_gate(torch.cat([gated_input, gated_inhibition], dim=1)))
        inhibition = (
            1 - inh_gate
        ) * inhibition + inh_gate * inhibition_hat  # In activ range

        # Pass to excitatory neurons; the updated inhibition is scaled by the
        # attention gate only when attention is enabled.
        if att_gate is None:
            attended_inhibition = inhibition
        else:
            attended_inhibition = inhibition * att_gate
        exc_gate = torch.sigmoid(
            self.e_w_gate(
                torch.cat([gated_excitation, attended_inhibition], dim=1)))
        exc_intx = activ(
            F.conv2d(self.bn[1](inhibition),
                     self.w_exc,
                     padding=self.h_padding))  # In activ range
        excitation_hat = activ(
            exc_intx *
            (self.kappa * inhibition +
             self.gamma))  # Skip connection OR add OR add by self-sim
        excitation = (1 - exc_gate) * excitation + exc_gate * excitation_hat
        if testmode:
            return excitation, inhibition, att_gate
        else:
            return excitation, inhibition
Example #30
0
 def conv2d_zeros_pad(self, x: Tensor, weight: Tensor, bias: Tensor):
     """Convolve ``x`` with ``weight``/``bias`` using the layer's conv settings."""
     return conv2d(x, weight, bias, stride=self.stride, padding=self.padding,
                   dilation=self.dilation, groups=self.groups)
Example #31
0
    def forward(self, xset):
        """Convolve every input branch into every output branch and sum.

        ``xset`` is a list with one tensor (or ``None``) per input branch; a
        bare tensor is treated as a one-element list. Input branch ``i`` uses
        the input-channel slice given by ``alpha_in[i]:alpha_in[i + 1]`` and
        output branch ``j`` the output-channel slice given by
        ``alpha_out[j]:alpha_out[j + 1]``. With ``scale = 2 ** (i - j)``, the
        result is upsampled after the convolution when ``scale > 1`` and the
        input is max-pooled before it when ``scale < 1``.

        Returns:
            A list with one entry per output branch: the sum of all
            contributions, or ``None`` when the branch received none.
        """
        if isinstance(xset, torch.Tensor):
            xset = [xset]

        contributions = [[] for _ in range(self.outbranch)]

        for i in range(self.inbranch):
            branch_input = xset[i]
            if branch_input is None:
                continue

            if self.stride == 2:
                x = F.avg_pool2d(branch_input, (2, 2), stride=2)
            else:
                x = branch_input

            begin_x = int(
                round(self.in_channels * self.alpha_in[i] / self.groups))
            end_x = int(
                round(self.in_channels * self.alpha_in[i + 1] / self.groups))
            if begin_x == end_x:
                continue

            for j in range(self.outbranch):
                begin_y = int(round(self.out_channels * self.alpha_out[j]))
                end_y = int(round(self.out_channels * self.alpha_out[j + 1]))
                if begin_y == end_y:
                    continue

                scale_factor = 2**(i - j)
                this_bias = None if self.bias is None else self.bias[begin_y:end_y]
                this_weight = self.weight[begin_y:end_y, begin_x:end_x, :, :]

                conv_input = x
                if scale_factor < 1:
                    # Downsample before convolving.
                    pool = int(round(1.0 / scale_factor))
                    conv_input = F.max_pool2d(x, pool, stride=pool)

                y = F.conv2d(conv_input, this_weight, this_bias, 1,
                             self.padding, self.dilation, self.groups)

                if scale_factor > 1:
                    # Upsample after convolving.
                    y = F.interpolate(y,
                                      scale_factor=scale_factor,
                                      mode=up_kwargs['mode'])
                contributions[j].append(y)

        return [sum(parts) if len(parts) != 0 else None
                for parts in contributions]
Example #32
0
    def forward(self, x, weights=None, distilled_params=None, condition=None):
        """Compute the output :math:`y` of this network given the input
        :math:`x`.

        Args:
            (....): See docstring of method
                :meth:`mnets.mnet_interface.MainNetInterface.forward`. We
                provide some more specific information below.
            weights (list or dict): See argument ``weights`` of method
                :meth:`mnets.mlp.MLP.forward`.
            condition (int, optional): If provided, then this argument will be
                passed as argument ``ckpt_id`` to the method
                :meth:`utils.context_mod_layer.ContextModLayer.forward`.

        Returns:
            (torch.Tensor): The output of the network.
        """
        if ((not self._use_context_mod and self._no_weights) or \
                (self._no_weights or self._context_mod_no_weights)) and \
                weights is None:
            raise Exception('Network was generated without weights. ' +
                            'Hence, "weights" option may not be None.')

        ############################################
        ### Extract which weights should be used ###
        ############################################
        # FIXME Code copied from MLP its `forward` method.
        # I.e., are we using internally maintained weights or externally given
        # ones or are we even mixing between these groups.
        n_cm = self._num_context_mod_shapes()

        if weights is None:
            weights = self.weights

            if self._use_context_mod:
                cm_weights = weights[:n_cm]
                int_weights = weights[n_cm:]
            else:
                cm_weights = None
                int_weights = weights
        else:
            int_weights = None
            cm_weights = None

            if isinstance(weights, dict):
                assert('internal_weights' in weights.keys() or \
                       'mod_weights' in weights.keys())
                if 'internal_weights' in weights.keys():
                    int_weights = weights['internal_weights']
                if 'mod_weights' in weights.keys():
                    cm_weights = weights['mod_weights']
            else:
                if self._use_context_mod and \
                        len(weights) == n_cm:
                    cm_weights = weights
                else:
                    assert len(weights) == len(self.param_shapes)
                    if self._use_context_mod:
                        cm_weights = weights[:n_cm]
                        int_weights = weights[n_cm:]
                    else:
                        int_weights = weights

            if self._use_context_mod and cm_weights is None:
                if self._context_mod_no_weights:
                    raise Exception(
                        'Network was generated without weights ' +
                        'for context-mod layers. Hence, they must be passed ' +
                        'via the "weights" option.')
                cm_weights = self.weights[:n_cm]
            if int_weights is None:
                if self._no_weights:
                    raise Exception(
                        'Network was generated without internal ' +
                        'weights. Hence, they must be passed via the ' +
                        '"weights" option.')
                if self._context_mod_no_weights:
                    int_weights = self.weights
                else:
                    int_weights = self.weights[n_cm:]

            # Note, context-mod weights might have different shapes, as they
            # may be parametrized on a per-sample basis.
            if self._use_context_mod:
                assert len(cm_weights) == n_cm
            int_shapes = self.param_shapes[n_cm:]
            assert len(int_weights) == len(int_shapes)
            for i, s in enumerate(int_shapes):
                assert np.all(np.equal(s, list(int_weights[i].shape)))

        cm_ind = 0
        # Split context-mod weights per context-mod layer.
        if cm_weights is not None:
            cm_weights_layer = []
            cm_start = 0
            for cm_layer in self.context_mod_layers:
                cm_end = cm_start + len(cm_layer.param_shapes)
                cm_weights_layer.append(cm_weights[cm_start:cm_end])
                cm_start = cm_end

        #######################
        ### Parse condition ###
        #######################

        cmod_cond = None

        if condition is not None:
            assert isinstance(condition, int)
            cmod_cond = condition

            # FIXME We always require context-mod weight above, but
            # we can't pass both (a condition and weights) to the
            # context-mod layers.
            # An unelegant solution would be, to just set all
            # context-mod weights to None.
            raise NotImplementedError('CM-conditions not implemented!')
            cm_weights_layer = [None] * len(cm_weights_layer)

        ###########################
        ### Forward Computation ###
        ###########################
        ### Helper function to handle context-mod and non-linearities.
        def modulate_layer(h):
            """Compute context-modulation and non-linearity.

            The order if the following:

            context-mod (if pre-activation) -> non-linearity ->
            context-mod (if post-activation)

            This method increments the index ``cm_ind``.

            Args:
                h: Input activity.

            Returns:
                Output of layer.
            """
            nonlocal cm_ind

            # Context-dependent modulation (pre-activation).
            if self._use_context_mod and \
                    not self._context_mod_post_activation:
                h = self._context_mod_layers[cm_ind].forward(
                    h, weights=cm_weights_layer[cm_ind], ckpt_id=cmod_cond)
                cm_ind += 1

            # Non-linearity
            h = F.relu(h)

            # Context-dependent modulation (post-activation).
            if self._use_context_mod and self._context_mod_post_activation:
                h = self._context_mod_layers[cm_ind].forward(
                    h, weights=cm_weights_layer[cm_ind], ckpt_id=cmod_cond)
                cm_ind += 1

            return h

        x = x.view(-1, *self._in_shape)
        x = x.permute(0, 3, 1, 2)
        h = x

        # Context-dependent modulation of inputs directly.
        if self._use_context_mod and self._context_mod_inputs:
            h = self._context_mod_layers[cm_ind].forward(
                h, weights=cm_weights_layer[cm_ind], ckpt_id=cmod_cond)
            cm_ind += 1

        h = F.conv2d(h, int_weights[0], bias=int_weights[1])
        if self._dropout_rate != -1:
            h = self._drop_conv1(h)
        h = F.max_pool2d(h, 2)
        h = modulate_layer(h)

        h = F.conv2d(h, int_weights[2], bias=int_weights[3])
        if self._dropout_rate != -1:
            h = self._drop_conv2(h)
        h = F.max_pool2d(h, 2)
        h = modulate_layer(h)

        h = h.reshape(-1, int_weights[4].size()[1])

        h = F.linear(h, int_weights[4], bias=int_weights[5])
        h = modulate_layer(h)
        # FIXME Before we applied context-modulation after dropout, since
        # dropout was before the non-linearity and not after.
        if self._dropout_rate != -1:
            h = self._drop_fc1(h)

        h = F.linear(h, int_weights[6], bias=int_weights[7])

        # Context-dependent modulation in output layer.
        if self._use_context_mod and not self._no_last_layer_context_mod:
            h = self._context_mod_layers[cm_ind].forward(
                h, weights=cm_weights_layer[cm_ind], ckpt_id=cmod_cond)

        return h
    def forward(self,
                X,
                sample=False,
                act_drop=False,
                first_layer=False,
                first_sample=False,
                given_alpha=0,
                given_beta=0):
        """Convolve ``X`` with mean, cached or freshly sampled weights.

        The modes are checked in this order:

        1. ``not self.training and not sample``: plain convolution with the
           mean parameters ``W_mu`` / ``b_mu``.
        2. ``first_sample``: same convolution, but ``X`` is cached in
           ``self.input_first`` and the output (entries with magnitude
           ``<= 1e-6`` zeroed) in ``self.output_first``.
        3. Otherwise Gaussian noise is drawn, scaled by
           ``1e-6 + softplus(W_p)`` / ``1e-6 + softplus(b_p)``:

           * ``act_drop`` False: full convolution with the sampled weights
             and bias.
           * ``act_drop`` and ``first_layer``: ``self.output_first`` plus a
             convolution of ``X`` with the weight noise only.
           * ``act_drop`` and not ``first_layer``: ``self.output_first`` plus
             the mean-weight convolution of the thresholded difference
             ``X - self.input_first`` plus the noise convolution of the
             thresholded ``X``.

        The two ``act_drop`` branches require a previous ``first_sample``
        call (they read ``self.output_first`` / ``self.input_first``) and do
        not use the sampled bias.

        Helper tensors are created on the device of the tensors they are
        combined with instead of a hard-coded ``'cuda'`` device.

        Args:
            X: Input of the convolution (4D, as indexed via ``X.shape[3]``).
            sample: If True, sample weights even when not in training mode.
            act_drop: Use the incremental computation based on the cached
                first sample.
            first_layer: Selects the first-layer variant of ``act_drop``.
            first_sample: Compute and cache the mean-weight output.
            given_alpha: Entries of ``X - self.input_first`` whose magnitude
                is below this value are zeroed.
            given_beta: Entries of ``X`` below this value are zeroed.

        Returns:
            tuple: ``(output, drop_rate_alpha, drop_rate_beta)``. The two
            rates are ``0`` except in the non-first-layer ``act_drop`` branch.
        """
        drop_rate_alpha = 0
        drop_rate_beta = 0

        if not self.training and not sample:
            # Outside of training (and without explicit sampling) use the
            # mean weights for a quick deterministic pass.
            output = F.conv2d(X,
                              self.W_mu,
                              bias=self.b_mu,
                              padding=self.padding)
            return output, 0, 0

        if first_sample:
            output = F.conv2d(X,
                              self.W_mu,
                              bias=self.b_mu,
                              padding=self.padding)
            self.input_first = X
            # Remove values that are too small.
            self.output_first = torch.where(
                torch.abs(output) > 1e-6, output, torch.zeros_like(output))
            return output, 0, 0

        # ``Tensor.new`` constructs a tensor of the same data type and device;
        # the same random sample is used for every element in the minibatch.
        eps_W = self.W_mu.data.new(self.W_mu.size()).normal_()
        eps_b = self.b_mu.data.new(self.b_mu.size()).normal_()

        # Standard deviations of the sampled parameters.
        std_w = 1e-6 + F.softplus(self.W_p, beta=1, threshold=20)
        std_b = 1e-6 + F.softplus(self.b_p, beta=1, threshold=20)

        noise_W = std_w * eps_W
        W = self.W_mu + noise_W
        b = self.b_mu + std_b * eps_b

        if not act_drop:
            output = F.conv2d(X.to(self.W_mu.device),
                              W,
                              bias=b,
                              padding=self.padding)
        elif first_layer:
            # In the first layer the alpha value changes, so a fixed
            # threshold of 0 is used (which keeps ``X`` unchanged).
            beta = 0

            X_new = torch.where(
                torch.abs(X) < beta, torch.zeros_like(X), X)
            output2 = F.conv2d(X_new, noise_W, padding=self.padding)
            output = self.output_first + output2
        else:
            # Sampling scheme that reuses the mean-weight output.
            alpha = given_alpha
            beta = given_beta
            self.x_for_save = X

            X_diff = X - self.input_first
            self.diff = torch.abs(X_diff)
            X_diff = torch.where(
                torch.abs(X_diff) < alpha, torch.zeros_like(X_diff), X_diff)
            output1 = F.conv2d(X_diff, self.W_mu, padding=self.padding)

            X_new = torch.where(X < beta, torch.zeros_like(X), X)
            output2 = F.conv2d(X_new, noise_W, padding=self.padding)
            output = self.output_first + output1 + output2

            self.x1_for_save = X_diff
            self.x2_for_save = X_new

            # Fraction of entries below the thresholds.
            # NOTE(review): ``alpha`` is applied to ``X`` and ``beta`` to the
            # already thresholded ``X_diff`` here, i.e. swapped relative to
            # the thresholding above - confirm that this is intended.
            n_elem = X.shape[0] * X.shape[1] * X.shape[2] * X.shape[3]
            drop_rate_alpha = torch.sum(X < alpha) / n_elem
            drop_rate_beta = torch.sum(torch.abs(X_diff) < beta) / n_elem

        return output, drop_rate_alpha, drop_rate_beta
Example #34
0
    def forward(self, x, vars=None, bn_training=True):
        """Apply the layers listed in ``self.config`` to ``x``.

        This function can also be called during fine-tuning. To leave the
        batch-norm running statistics untouched in that case, pass
        ``bn_training=False``; the batch-norm weight/bias still come from
        ``vars`` (e.g. fast weights), so the initial parameters stay clean.

        Layer names are compared by value (``==``). The previous identity
        comparison (``is``) only worked for interned string objects and
        raised ``NotImplementedError`` for equal names built at runtime.

        :param x: input tensor, e.g. [b, 1, 28, 28]
        :param vars: flat sequence of weights/biases, consumed in config
            order (two entries per conv2d/convt2d/linear/bn layer);
            defaults to ``self.vars``
        :param bn_training: passed as ``training`` to ``F.batch_norm``
        :return: output of the last configured layer
        :raises NotImplementedError: for an unknown layer name
        """
        if vars is None:
            vars = self.vars
        idx = 0
        bn_idx = 0

        for name, param in self.config:
            if name == 'conv2d':
                w, b = vars[idx], vars[idx + 1]
                # remember to keep forward_encoder and forward_decoder
                # synchronized!
                x = F.conv2d(x, w, b, stride=param[4], padding=param[5])
                idx += 2
            elif name == 'convt2d':
                w, b = vars[idx], vars[idx + 1]
                # remember to keep forward_encoder and forward_decoder
                # synchronized!
                x = F.conv_transpose2d(x,
                                       w,
                                       b,
                                       stride=param[4],
                                       padding=param[5])
                idx += 2
            elif name == 'linear':
                w, b = vars[idx], vars[idx + 1]
                x = F.linear(x, w, b)
                idx += 2
            elif name == 'bn':
                w, b = vars[idx], vars[idx + 1]
                # Running statistics live in a separate list, two per layer.
                running_mean = self.vars_bn[bn_idx]
                running_var = self.vars_bn[bn_idx + 1]
                x = F.batch_norm(x,
                                 running_mean,
                                 running_var,
                                 weight=w,
                                 bias=b,
                                 training=bn_training)
                idx += 2
                bn_idx += 2
            elif name == 'flatten':
                x = x.view(x.size(0), -1)
            elif name == 'reshape':
                # [b, 8] => [b, 2, 2, 2]
                x = x.view(x.size(0), *param)
            elif name == 'relu':
                x = F.relu(x, inplace=param[0])
            elif name == 'leakyrelu':
                x = F.leaky_relu(x, negative_slope=param[0], inplace=param[1])
            elif name == 'tanh':
                # torch.tanh replaces the deprecated F.tanh.
                x = torch.tanh(x)
            elif name == 'sigmoid':
                x = torch.sigmoid(x)
            elif name == 'upsample':
                # Equivalent to the deprecated F.upsample_nearest.
                x = F.interpolate(x, scale_factor=param[0], mode='nearest')
            elif name == 'max_pool2d':
                x = F.max_pool2d(x, param[0], param[1], param[2])
            elif name == 'avg_pool2d':
                x = F.avg_pool2d(x, param[0], param[1], param[2])
            else:
                raise NotImplementedError

        # make sure every variable has been consumed
        assert idx == len(vars)
        assert bn_idx == len(self.vars_bn)
        return x
def compute_reblurred_image_and_kernel_loss(sharp_image,
                                            kernels,
                                            masks,
                                            gt_kernels,
                                            gt_masks,
                                            masks_weights,
                                            kernels_loss_type,
                                            GPU=0,
                                            manage_saturated_pixels=True,
                                            pairwise_matrix=None):
    """Reblur ``sharp_image`` with per-pixel kernel mixtures and score kernels.

    For every sample the predicted and the ground-truth per-pixel kernels are
    built as ``kernels^T @ masks`` and compared with the selected loss. The
    sharp image is then convolved with every predicted kernel, weighted by
    the corresponding mask and summed over the kernels.

    Args:
        sharp_image: Tensor indexed as (N, C, height, width). The reblurred
            output has spatial size ``height - K + 1`` x ``width - K + 1``.
        kernels: Predicted kernels indexed as (N, n_kernels, K, K). They are
            viewed with the ground-truth kernel side, so ``K`` has to equal
            ``gt_kernels.shape[2]``.
        masks: Predicted mixing masks; ``masks[n]`` is viewed as
            (n_kernels, H * W).
        gt_kernels: Ground-truth kernels indexed as (N, Kgt, Wk, Wk).
        gt_masks: Ground-truth masks; ``gt_masks[n]`` is viewed as
            (Kgt, H * W).
        masks_weights: Per-pixel loss weights; ``masks_weights[n]`` is viewed
            as (H * W,).
        kernels_loss_type: One of ``'L2'``, ``'L1'``, ``'Huber'``, ``'KL'``
            or ``'IND'``. Any other value leaves the kernel loss at zero.
        GPU: CUDA device index on which the buffers are allocated.
        manage_saturated_pixels: If True, the reblurred image is passed
            through ``apply_saturation_function``; otherwise it is returned
            unchanged (previously this case raised ``UnboundLocalError``).
        pairwise_matrix: Matrix used by the ``'IND'`` loss
            (``a^T @ pairwise_matrix @ b``).

    Returns:
        tuple: ``(output_reblurred, kernels_loss)`` where ``kernels_loss`` is
        a one-element CUDA tensor.
    """
    n_kernels = kernels.size(1)
    K = kernels.size(2)
    N = sharp_image.size(0)
    C = sharp_image.size(1)
    H = sharp_image.size(2) - K + 1
    W = sharp_image.size(3) - K + 1
    reblurred_images = torch.empty(N, n_kernels, C, H, W).cuda(GPU)

    Kgt = gt_kernels.shape[1]
    Wk = gt_kernels.shape[2]  # kernel side

    kernels_loss = torch.Tensor([0.]).cuda(GPU)
    # Element-wise losses that share the same weighted reduction.
    elementwise_losses = {
        'L2': torch.nn.MSELoss(reduction='none'),
        'L1': torch.nn.L1Loss(reduction='none'),
        'Huber': torch.nn.SmoothL1Loss(reduction='none'),
    }
    kl_loss = torch.nn.KLDivLoss(reduction='none')

    for n in range(N):
        gt_masks_nn = gt_masks[n].view(Kgt, H * W)
        gt_kernels_nn = gt_kernels[n].view(Kgt, Wk * Wk)
        gt_kernels_per_pixel = torch.mm(gt_kernels_nn.t(), gt_masks_nn)
        masks_weights_nn = masks_weights[n].view(H * W)

        predicted_kernels_per_pixel = torch.mm(
            kernels[n].contiguous().view(n_kernels, Wk * Wk).t(),
            masks[n].contiguous().view(n_kernels, H * W))

        per_pixel_kernel_diff = None
        if kernels_loss_type in elementwise_losses:
            per_pixel_kernel_diff = elementwise_losses[kernels_loss_type](
                predicted_kernels_per_pixel, gt_kernels_per_pixel)
        elif kernels_loss_type == 'KL':
            per_pixel_kernel_diff = kl_loss(
                torch.log(predicted_kernels_per_pixel), gt_kernels_per_pixel)
        elif kernels_loss_type == 'IND':
            start_IND = time.time()
            # The pixels are processed in chunks to bound the size of the
            # (chunk x chunk) product whose diagonal is taken.
            # NOTE(review): if H * W is not divisible by ``divs`` the
            # trailing pixels are ignored.
            divs = 8
            chunk = (H * W) // divs
            for i in range(divs):
                lo, hi = i * chunk, chunk * (i + 1)
                a = gt_kernels_per_pixel[:, lo:hi]  # k_gt
                b = predicted_kernels_per_pixel[:, lo:hi]  # k_p
                w = masks_weights_nn[lo:hi]
                distances = torch.diagonal(
                    torch.mm(a.t(), torch.mm(pairwise_matrix, b)))
                kernels_loss += torch.mean(w * distances) / divs
            stop_IND = time.time()
            print('IND finished in %f seconds' % (stop_IND - start_IND))

        if per_pixel_kernel_diff is not None:
            # Sum over the kernel entries, weight per pixel, average over
            # the batch.
            kernels_loss += (per_pixel_kernel_diff.sum(0) *
                             masks_weights_nn).sum() / N

        for c in range(C):
            # One output channel per predicted kernel.
            conv_output = F.conv2d(sharp_image[n:n + 1, c:c + 1, :, :],
                                   kernels[n][:, np.newaxis, :, :])
            reblurred_images[n:n + 1, :,
                             c, :, :] = conv_output * masks[n:n + 1]

    reblurred_images = torch.sum(reblurred_images, (1))

    if manage_saturated_pixels:
        output_reblurred = apply_saturation_function(reblurred_images)
    else:
        output_reblurred = reblurred_images

    return output_reblurred, kernels_loss
Example #36
0
 def forward(self, input):
     """Convolve ``input`` with the weight returned by ``compute_weight``.

     Bias, stride and padding are taken from the module; dilation and
     groups are fixed to 1.
     """
     return F.conv2d(
         input,
         self.compute_weight(),
         bias=self.bias,
         stride=self.stride,
         padding=self.padding,
         dilation=1,
         groups=1,
     )
    def forward(ctx, input, kernel, kernel_flip):
        """Channel-wise convolution of ``input`` with ``kernel``.

        Both kernels are stored on ``ctx`` via ``save_for_backward``. The
        convolution uses a padding of 1 and one group per input channel.
        """
        ctx.save_for_backward(kernel, kernel_flip)
        n_channels = input.shape[1]
        return F.conv2d(input, kernel, padding=1, groups=n_channels)
    def forward(self, f, b, mask=None):
        """ Contextual attention layer implementation.
        Contextual attention is first introduced in publication:
            Generative Image Inpainting with Contextual Attention, Yu et al.

        Patches extracted from the (downscaled) background are used as
        convolution filters to score every foreground location; the softmax
        of these scores is then used to paste the raw background patches
        back via a transposed convolution.

        The layer is configured through the attributes ``self.ksize``
        (kernel size for matching), ``self.stride`` (stride for extracting
        patches from b), ``self.rate`` (downscaling rate),
        ``self.softmax_scale``, ``self.fuse`` / ``self.fuse_k`` (score
        fusing) and ``self.use_cuda``.

        Args:
            f: Input feature to match (foreground).
            b: Input feature for match (background).
            mask: Input mask for b, indicating patches not available.
                If None, an all-zero mask is used.
        Returns:
            tuple: ``(y, flow)`` where ``y`` is the concatenation of the
            per-sample reconstructions along the batch dimension and
            ``flow`` is the visualization produced by ``flow_to_image``
            from the matched offsets.
        """
        # get shapes
        raw_int_fs = list(f.size())  # b*c*h*w
        raw_int_bs = list(b.size())  # b*c*h*w

        # extract patches from background with stride and rate
        kernel = 2 * self.rate
        # raw_w is extracted for reconstruction
        raw_w = extract_image_patches(b,
                                      ksizes=[kernel, kernel],
                                      strides=[self.rate,
                                               self.rate])  # b*hw*c*k*k
        raw_w_groups = torch.split(raw_w, 1, dim=0)

        # downscaling foreground option: downscaling both foreground and
        # background for matching and use original background for reconstruction.
        f = F.interpolate(f, scale_factor=1 / self.rate, mode='nearest')
        b = F.interpolate(b, scale_factor=1 / self.rate, mode='nearest')
        int_fs = list(f.size())  # b*c*h*w
        int_bs = list(b.size())
        f_groups = torch.split(
            f, 1, dim=0)  # split tensors along the batch dimension

        w = extract_image_patches(b,
                                  ksizes=[self.ksize, self.ksize],
                                  strides=[self.stride,
                                           self.stride])  # b*hw*c*k*k
        w_groups = torch.split(w, 1, dim=0)

        # process mask
        if mask is None:
            # No mask given: every background patch is available.
            mask = torch.zeros([int_bs[0], 1, int_bs[2], int_bs[3]])
            if self.use_cuda:
                mask = mask.cuda()
        else:
            # NOTE(review): the additional factor 4 presumably accounts for
            # the mask being given at a higher resolution than b - confirm.
            mask = F.interpolate(mask,
                                 scale_factor=1. / (4. * self.rate),
                                 mode='nearest')
        # Unlike the other groups this one is not split explicitly; the
        # zip below iterates over its first dimension.
        m_groups = extract_image_patches(mask,
                                         ksizes=[self.ksize, self.ksize],
                                         strides=[self.stride,
                                                  self.stride])  # b*hw*c*k*k

        # m = m[0]  # hw*c*k*k
        # m = reduce_mean(m, axis=[1, 2, 3])  # hw*1*1*1
        # m = m.permute(1, 0, 2, 3).contiguous()  # 1*hw*1*1
        # mm = (m==0).to(torch.float32)   # 1*hw*1*1

        y = []
        offsets = []
        k = self.fuse_k
        scale = self.softmax_scale * 255  # to fit the PyTorch tensor image value range
        # Identity kernel: convolving with it sums the scores along the
        # diagonal of a k*k window.
        fuse_weight = torch.eye(k).view(1, 1, k, k)  # 1*1*k*k
        if self.use_cuda:
            fuse_weight = fuse_weight.cuda()

        # Process one sample of the batch at a time.
        for xi, wi, raw_wi, mi in zip(f_groups, w_groups, raw_w_groups,
                                      m_groups):
            '''
            O => output channel as a conv filter
            I => input channel as a conv filter
            xi : separated tensor along batch dimension of front; (B=1, C=128, H=32, W=32)
            wi : separated patch tensor along batch dimension of back; (B=1, O=32*32, I=128, KH=3, KW=3)
            raw_wi : separated tensor along batch dimension of back; (B=1, I=32*32, O=128, KH=4, KW=4)
            '''
            # conv for compare
            # Lower bound for the patch norms to avoid a division by zero.
            escape_NaN = torch.FloatTensor([1e-4])
            if self.use_cuda:
                escape_NaN = escape_NaN.cuda()
            wi = wi[0]  # hw*c*k*k
            # Normalize every background patch by its L2 norm.
            wi_normed = wi / torch.max(
                torch.sqrt(reduce_sum(torch.pow(wi, 2), axis=[1, 2, 3])),
                escape_NaN)
            # Despite its name, xi_normed is only padded, not normalized.
            xi_normed = same_padding(xi, [self.ksize, self.ksize],
                                     [1, 1])  # xi: 1*c*H*W
            yi = F.conv2d(xi_normed, wi_normed, stride=1)  # 1*hw*H*W

            # conv implementation for fuse scores to encourage large patches
            if self.fuse:
                # make all of depth to spatial resolution
                yi = yi.view(1, 1, int_bs[2] * int_bs[3], int_fs[2] *
                             int_fs[3])  # (B=1, I=1, H=32*32, W=32*32)
                yi = same_padding(yi, [k, k], [1, 1])
                yi = F.conv2d(yi, fuse_weight,
                              stride=1)  # (B=1, C=1, H=32*32, W=32*32)

                # Swap the two spatial axes of both the background and the
                # foreground index, fuse again, then swap back.
                yi = yi.contiguous().view(1, int_bs[2], int_bs[3], int_fs[2],
                                          int_fs[3])  # (B=1, 32, 32, 32, 32)
                yi = yi.permute(0, 2, 1, 4, 3)
                yi = yi.contiguous().view(1, 1, int_bs[2] * int_bs[3],
                                          int_fs[2] * int_fs[3])
                yi = same_padding(yi, [k, k], [1, 1])
                yi = F.conv2d(yi, fuse_weight, stride=1)
                yi = yi.contiguous().view(1, int_bs[3], int_bs[2], int_fs[3],
                                          int_fs[2])
                yi = yi.permute(0, 2, 1, 4, 3)
                yi = yi.contiguous().view(
                    1, int_bs[2] * int_bs[3], int_fs[2],
                    int_fs[3])  # (B=1, C=32*32, H=32, W=32)

            # mi: hw*c*k*k
            mi = reduce_mean(mi, axis=[1, 2, 3])  # hw*1*1*1
            mi = mi.permute(1, 0, 2, 3).contiguous()  # 1*hw*1*1
            # A patch is usable only if its mean mask value is exactly zero.
            mm = (mi == 0).to(torch.float32)  # 1*hw*1*1

            # softmax to match
            # The mask is applied before and after the softmax so that
            # unavailable patches end up with zero attention.
            yi = yi * mm
            yi = F.softmax(yi * scale, dim=1)
            yi = yi * mm  # 1*hw*H*W

            # Index of the best matching background patch per location.
            offset = torch.argmax(yi, dim=1, keepdim=True)  # 1*1*H*W
            if int_bs != int_fs:
                # Normalize the offset value to match foreground dimension
                times = float(int_fs[2] * int_fs[3]) / float(
                    int_bs[2] * int_bs[3])
                offset = ((offset + 1).float() * times - 1).to(torch.int64)
            # Convert the flat index into (row, column).
            offset = torch.cat([offset // int_fs[3], offset % int_fs[3]],
                               dim=1)  # 1*2*H*W

            # deconv for patch pasting
            wi_center = raw_wi[0]
            yi = F.conv_transpose2d(yi, wi_center, stride=self.rate,
                                    padding=1) / 4.  # (B=1, C=128, H=64, W=64)
            y.append(yi)
            offsets.append(offset)

        y = torch.cat(y, dim=0)  # back to the mini-batch
        # NOTE(review): the result of this view is discarded, so the line
        # has no effect on y. Assigning it would raise whenever the shape
        # of y differs from raw_int_fs - verify before changing.
        y.contiguous().view(raw_int_fs)

        offsets = torch.cat(offsets, dim=0)
        offsets = offsets.view(int_fs[0], 2, *int_fs[2:])

        # case1: visualize optical flow: minus current position
        h_add = torch.arange(int_fs[2]).view([1, 1, int_fs[2], 1]).expand(
            int_fs[0], -1, -1, int_fs[3])
        w_add = torch.arange(int_fs[3]).view([1, 1, 1, int_fs[3]]).expand(
            int_fs[0], -1, int_fs[2], -1)
        ref_coordinate = torch.cat([h_add, w_add], dim=1)  # b*2*H*W
        if self.use_cuda:
            ref_coordinate = ref_coordinate.cuda()

        offsets = offsets - ref_coordinate
        # flow = pt_flow_to_image(offsets)

        # flow_to_image works on numpy arrays in channels-last layout; its
        # result is scaled by 1/255 and converted back to channels-first.
        flow = torch.from_numpy(
            flow_to_image(offsets.permute(0, 2, 3,
                                          1).cpu().data.numpy())) / 255.
        flow = flow.permute(0, 3, 1, 2)
        if self.use_cuda:
            flow = flow.cuda()
        # case2: visualize which pixels are attended
        # flow = torch.from_numpy(highlight_flow((offsets * mask.long()).cpu().data.numpy()))

        if self.rate != 1:
            flow = F.interpolate(flow,
                                 scale_factor=self.rate * 4,
                                 mode='nearest')

        return y, flow
 def forward(self, input):
   """Apply the module's 3x3 kernel to every channel of ``input`` separately.

   ``self.weight`` is expanded to one (1, 3, 3) filter per input channel and
   used in a grouped convolution with padding 1.
   """
   n_channels = input.size(1)
   per_channel_kernel = self.weight.expand(n_channels, 1, 3, 3).contiguous()
   return F.conv2d(input, per_channel_kernel, padding=1, groups=n_channels)
 def forward(self, input_):
     """Convolve ``input_`` with the module's parameters.

     The size of the first input seen is remembered in ``self.input_shape``;
     later calls leave it unchanged.
     """
     if self.input_shape is None:
         self.input_shape = input_.size()
     return F.conv2d(
         input_,
         self.weight,
         bias=self.bias,
         stride=self.stride,
         padding=self.padding,
         dilation=self.dilation,
         groups=self.groups,
     )
Example #41
0
def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
    """Run ``F.conv2d`` after moving the parameters next to ``input``.

    ``weight`` and ``bias`` are moved to the device of ``input``. For an
    input on the current CUDA device this matches the former unconditional
    ``.cuda()`` calls, while inputs on other devices now work as well. A
    ``bias`` of ``None`` (the default) is passed through; previously it
    raised ``AttributeError``.

    Args:
        input: Input tensor of the convolution.
        weight: Convolution filters.
        bias: Optional bias tensor.
        stride, padding, dilation, groups: Forwarded to ``F.conv2d``.

    Returns:
        The result of ``F.conv2d``.
    """
    device = input.device
    weight = weight.to(device)
    if bias is not None:
        bias = bias.to(device)
    return F.conv2d(input, weight, bias, stride, padding, dilation, groups)
Example #42
0
    def forward(self, input):
        """Convolve every receptive field with its own, unshared weights.

        Since no weight sharing is used, the usual striding convolution is
        replaced by manually cutting a kernel-sized patch out of the input
        for every output position and convolving it with the kernel that
        belongs to that position.

        Args:
            input: 4D spike wave tensor that was input to the excitatory
                neurons, indexed as (time, in_channels, height, width).

        Returns:
            Integer tensor of shape
            (time, out_channels, self.rows, self.cols) holding the
            convolution result of every receptive field.
        """
        # Get size of time dimension from input:
        time_duration = input.size()[0]

        # Create output tensor
        output = torch.zeros(
            (time_duration, self.out_channels, self.rows, self.cols),
            dtype=torch.int)

        kernel_height = self.kernel_size[0]
        # The width is the second entry; kernel_size[0] was used for both
        # extents before, which cut wrong patches for non-square kernels.
        kernel_width = self.kernel_size[1]

        # Process patches of input in chunks of receptive field size,
        # convolving and saving the result into the output tensor.
        for neural_row in range(self.rows):  # iterating along height in RFs
            for neural_col in range(self.cols):  # iterating along width in RFs

                # Get weights for this receptive field.
                # NOTE(review): squeeze also removes an out/in channel or
                # kernel dimension of size 1 - confirm those never occur.
                weights_patch = torch.squeeze(self.weight[
                    neural_row,
                    neural_col, :, :, :, :]).float()

                # Get the start and end dims of input receptive field (need
                # to take stride into account)
                st_row_in = neural_row * self.stride
                st_col_in = neural_col * self.stride
                nd_row_in = st_row_in + kernel_height
                nd_col_in = st_col_in + kernel_width

                # Given the row and col of the RF we want, we take all time
                # steps and in channels, but only a slice of the input field
                # (height and width)
                input_slice = (input[:, :, st_row_in:nd_row_in,
                                     st_col_in:nd_col_in]
                               ).float()

                # Convolve; the output row and column correspond exactly to
                # the output field row and col indices:
                output_slice = fn.conv2d(input_slice,
                                         weights_patch,
                                         bias=self.bias,
                                         stride=self.stride,
                                         padding=self.padding,
                                         dilation=self.dilation,
                                         groups=self.groups)
                sq_output_slice = torch.squeeze(
                    output_slice).int()  # Squeeze size 1 dimensions

                # Assign to output (one value per time step and out channel)
                output[:, :, neural_row, neural_col] = sq_output_slice

        return output
Example #43
0
    def step(
        self, actions: torch.Tensor
    ) -> (torch.Tensor, torch.Tensor, torch.Tensor, dict):
        """Advance all environments by one action each.

        The head channel of ``self.envs`` is moved according to ``actions``,
        food under a head is removed (and rewarded) and replaced through
        ``self._get_food_addition``, and environments whose head left the
        board are flagged as done. ``self.envs`` is modified in place.

        Args:
            actions: Integer tensor with one action index (used as the
                channel of the orientation filters) per environment.

        Returns:
            The observation from ``self._observe``, the reward and the done
            flags (both with a trailing dimension of size 1) and an info
            dict containing ``'edge_collision'``.

        Raises:
            TypeError: If ``actions`` is not of an integer dtype.
            RuntimeError: If the number of actions differs from
                ``self.num_envs``.
        """
        integer_dtypes = (torch.short, torch.int, torch.long)
        if actions.dtype not in integer_dtypes:
            raise TypeError(
                'actions Tensor must be an integer type i.e. '
                '{torch.ShortTensor, torch.IntTensor, torch.LongTensor}')

        if actions.shape[0] != self.num_envs:
            raise RuntimeError(
                'Must have the same number of actions as environments.')

        reward = torch.zeros((self.num_envs, )).float()
        reward = reward.to(self.device).requires_grad_(False)
        done = torch.zeros((self.num_envs, )).byte()
        done = done.to(self.device).requires_grad_(False)
        info = {}

        head_slice = slice(HEAD_CHANNEL, HEAD_CHANNEL + 1)
        food_slice = slice(FOOD_CHANNEL, FOOD_CHANNEL + 1)

        # --- Move the heads ---
        tic = time()
        # One candidate head delta per orientation filter.
        current_heads = head(self.envs)
        orientation_filters = ORIENTATION_FILTERS.to(self.device)
        candidate_deltas = F.conv2d(current_heads,
                                    orientation_filters,
                                    padding=1)
        # The one-hot encoded action picks the delta of the chosen direction.
        one_hot = torch.FloatTensor(self.num_envs, 4).to(self.device)
        one_hot.zero_()
        one_hot.scatter_(1, actions.unsqueeze(-1), 1)
        chosen_delta = torch.einsum('bchw,bc->bhw',
                                    [candidate_deltas, one_hot]).unsqueeze(1)

        self.envs[:, head_slice, :, :].add_(chosen_delta).round_()
        if self.verbose:
            print(f'Head movement: {time() - tic}s')

        # --- Remove eaten food and give reward ---
        tic = time()
        # `food_removal` is 0 except where a snake head is at the same
        # location as food where it is -1
        food_removal = head(self.envs) * food(self.envs) * -1
        removed_per_env = food_removal.view(self.num_envs, -1).sum(dim=-1)
        reward.sub_(removed_per_env.float())
        self.envs[:, food_slice, :, :] += food_removal
        if self.verbose:
            print(f'Food removal: {time() - tic}s')

        # --- Add new food where some was eaten ---
        if food_removal.sum() < 0:
            tic = time()
            eaten_per_env = (food_removal * -1).view(self.num_envs, -1)
            food_addition_env_indices = eaten_per_env.sum(dim=-1).byte()
            add_food_envs = self.envs[food_addition_env_indices, :, :, :]
            food_addition = self._get_food_addition(add_food_envs)
            self.envs[food_addition_env_indices,
                      food_slice, :, :] += food_addition
            if self.verbose:
                print(
                    f'Food addition ({food_addition_env_indices.sum().item()} envs): {time() - tic}s'
                )

        # --- Boundary check ---
        tic = time()
        # A convolution without padding cuts off a head that sits on the
        # edge, so the sum of the convolved head channel becomes 0.
        cropped_heads = F.conv2d(
            head(self.envs),
            NO_CHANGE_FILTER.to(self.device),
        )
        edge_collision = cropped_heads.view(self.num_envs,
                                            -1).sum(dim=-1) < EPS
        done = done | edge_collision
        info['edge_collision'] = edge_collision
        if self.verbose:
            print(
                f'Edge collision ({edge_collision.sum().item()} envs): {time() - tic}s'
            )

        # Apply rounding to stop numerical errors accumulating
        self.envs.round_()

        self.done = done

        observation = self._observe(self.observation_mode)
        return observation, reward.unsqueeze(-1), done.unsqueeze(-1), info
Example #44
0
def aten_convolution(inputs, attributes, scope):
    """Convert a convolution node for the active backend.

    Depending on the current context the node is lowered to a TensorRT
    layer, a TVM relay op, or executed eagerly with ``torch.nn.functional``.

    Args:
        inputs: positional node inputs, read as
            ``inputs[0:3]`` = input, weight, bias (bias may be ``None``),
            ``inputs[3:6]`` = stride, padding, dilation,
            ``inputs[6:9]`` = transposed flag, output padding, groups.
        attributes: not used by this converter.
        scope: name given to the created TensorRT layer/output and used as
            prefix for the TVM weight/bias variables.

    Returns:
        A single-element list holding the convolution result (TensorRT
        tensor, TVM expression or ``torch.Tensor``).
    """
    inp, weight, bias = inputs[:3]
    stride, pad, dilation = inputs[3:6]
    transposed, output_padding, groups = inputs[6:9]
    ctx = current_context()
    net = ctx.network
    # Transposed weights are laid out (I, O/groups, *k); regular ones are
    # (O, I/groups, *k). Only the output channel count and kernel size are
    # needed below.
    if transposed:
        _, O_groups, *ksize = weight.shape
        O = O_groups * groups
    else:
        O, _, *ksize = weight.shape
    if ctx.is_tensorrt and has_trt_tensor(inputs):
        assert all([e == 0 for e in output_padding
                    ]), "tensor rt don't support out padding"
        ndim = len(ksize)
        if ndim == 1:
            print(
                "WARNING: consider write conv2d because trt don't support conv2d, we need to change input shape (and output shape) and may cause error in following layers."
            )
            # Emulate a 1d convolution with a 2d one of width 1.
            ksize = [ksize[0], 1]
            stride = [stride[0], 1]
            pad = [pad[0], 0]
            dilation = [dilation[0], 1]
        if len(inputs[0].shape) == 2:
            inputs[0] = _trt_reshape(net, inputs[0], [*inputs[0].shape, 1],
                                     scope + "/conv1d_reshape")

        assert ndim <= 2, "tensorrt only support 1d/2d conv"
        # trt weight format: GKCRS: [num_groups, O_groups, I, H, W]
        weight = weight.detach().cpu().numpy()
        if bias is not None:
            trt_bias = bias.detach().cpu().numpy()
        else:
            trt_bias = trt.Weights()
        if transposed:
            layer = net.add_deconvolution(inputs[0], O, tuple(ksize), weight,
                                          trt_bias)
        else:
            layer = net.add_convolution(inputs[0], O, tuple(ksize), weight,
                                        trt_bias)
            layer.dilation = tuple(dilation)
        layer.stride = tuple(stride)
        layer.padding = tuple(pad)
        layer.num_groups = groups
        output = layer.get_output(0)
        output.name = scope
        layer.name = scope
        # Remember which torch parameters back this layer so they can be
        # refitted later.
        ctx.refit_weight_dict[layer.name] = {
            "type": "Convolution",
            "weight": inputs[1].__torch2trt_weight_name,
        }
        if bias is not None:
            ctx.refit_weight_dict[
                layer.name]["bias"] = bias.__torch2trt_weight_name
        return [output]
    elif ctx.is_tvm and has_tvm_tensor(inputs):
        weight = weight.detach().cpu().numpy()
        weight_t = _expr.var(scope + "/weight",
                             shape=weight.shape,
                             dtype="float32")
        ctx.tvm_weight_dict[weight_t] = weight
        ctx.refit_weight_dict[
            weight_t.name_hint] = inputs[1].__torch2trt_weight_name
        if bias is not None:
            # Keep `bias` bound to the torch tensor: the weight-name
            # attribute lives on the tensor, not on the numpy copy.
            bias_np = bias.detach().cpu().numpy()
            bias_t = _expr.var(scope + "/bias",
                               shape=bias_np.shape,
                               dtype="float32")
            ctx.tvm_weight_dict[bias_t] = bias_np
            ctx.refit_weight_dict[
                bias_t.name_hint] = bias.__torch2trt_weight_name
        new_attrs = {}
        new_attrs["channels"] = O
        new_attrs["kernel_size"] = ksize
        new_attrs["strides"] = stride
        new_attrs["padding"] = pad
        new_attrs["dilation"] = dilation
        new_attrs["groups"] = groups
        new_attrs["data_layout"] = "NCHW"
        new_attrs["kernel_layout"] = "OIHW"
        use_bias = bias is not None
        if transposed:
            new_attrs["output_padding"] = output_padding
            res = _op.nn.conv2d_transpose(inputs[0], weight_t, **new_attrs)
        else:
            res = _op.nn.conv2d(inputs[0], weight_t, **new_attrs)
        if use_bias:
            res = _op.nn.bias_add(res, bias_t, axis=1)
        return [res]
    # Eager fallback: dimensionality is taken from the stride argument.
    ndim = len(inputs[3])
    assert ndim <= 2
    if ndim == 1:
        if transposed:
            res = F.conv_transpose1d(inp, weight, bias, stride, pad,
                                     output_padding, groups, dilation)
        else:
            res = F.conv1d(inp, weight, bias, stride, pad, dilation, groups)
    else:
        if transposed:
            res = F.conv_transpose2d(inp, weight, bias, stride, pad,
                                     output_padding, groups, dilation)
        else:
            res = F.conv2d(inp, weight, bias, stride, pad, dilation, groups)

    return [res]
Example #45
0
 def forward(self, input):
     """Convolve ``input`` with the wrapped layer's weight gated by a mask.

     The mask and its penalty both come from ``self._get_mask()``; every
     other convolution setting is read from ``self._origin``.

     Returns:
         ``(conv_output, penalty)``.
     """
     mask, penalty = self._get_mask()
     origin = self._origin
     masked_weight = origin.weight * mask
     out = F.conv2d(
         input,
         masked_weight,
         origin.bias,
         stride=origin.stride,
         padding=origin.padding,
         dilation=origin.dilation,
         groups=origin.groups,
     )
     return out, penalty
    def forward(self, foreground, mask, background="same"):
        """Rebuild ``foreground`` from attention over unmasked background patches.

        For every batch element, patches are cut from the (mask-zeroed)
        background, L2-normalised and used as convolution kernels on the
        foreground features. The resulting scores are soft-maxed over the
        patch axis (patches touching the mask are zeroed) and the patches are
        pasted back with a transposed convolution.

        Args:
            foreground: 4-D feature tensor, unpacked as
                (batch, channels, height, width).
            mask: tensor resized to the foreground's spatial size with
                nearest-neighbour interpolation; the masked area is assumed
                to have value 1.
            background: tensor to take patches from, or the string ``"same"``
                to reuse a copy of ``foreground``.

        Returns:
            Tensor with the per-sample reconstructions concatenated along
            dim 0.
        """
        # size() is (N, C, H, W) and F.interpolate expects size=(H, W);
        # unpacking in this order keeps non-square inputs working.
        bz, nc, h, w = foreground.size()
        half_patch = self.patch_size // 2
        border = [half_patch, half_patch, half_patch, half_patch]

        # Guard with isinstance so a tensor is never compared to a string.
        if isinstance(background, str) and background == "same":
            background = foreground.clone()
        mask = F.interpolate(mask, size=(h, w), mode='nearest')
        background = background * (1 - mask)
        foreground = self.feature_attention(foreground, background, mask)
        background = F.pad(background, border)
        conv_kernels_all = background.unfold(
            2, self.patch_size,
            self.stride).unfold(3, self.patch_size,
                                self.stride).contiguous().view(
                                    bz, nc, -1, self.patch_size,
                                    self.patch_size)

        mask_resized = mask.repeat(1, self.in_dim, 1, 1)
        mask_resized = F.pad(mask_resized, border)
        mask_kernels_all = mask_resized.unfold(
            2, self.patch_size,
            self.stride).unfold(3, self.patch_size,
                                self.stride).contiguous().view(
                                    bz, nc, -1, self.patch_size,
                                    self.patch_size)
        # -> (batch, num_patches, channels, patch, patch)
        conv_kernels_all = conv_kernels_all.transpose(2, 1)
        mask_kernels_all = mask_kernels_all.transpose(2, 1)
        output_tensor = []
        for i in range(bz):
            feature_map = foreground[i:i + 1]

            # form convolutional kernels (epsilon avoids a zero norm below)
            conv_kernels = conv_kernels_all[i] + 0.0000001
            mask_kernels = mask_kernels_all[i]
            conv_kernels = self.patch_attention(conv_kernels, conv_kernels,
                                                mask_kernels)
            norm_factor = torch.sum(conv_kernels**2, [1, 2, 3],
                                    keepdim=True)**0.5
            conv_kernels = conv_kernels / norm_factor

            conv_result = F.conv2d(feature_map,
                                   conv_kernels,
                                   padding=half_patch)
            if self.propagate_size != 1:
                if self.prop_kernels is None:
                    # Depthwise all-ones kernel, created once and cached.
                    self.prop_kernels = torch.ones([
                        conv_result.size(1), 1, self.propagate_size,
                        self.propagate_size
                    ])
                    self.prop_kernels.requires_grad = False
                # Follow the data's device instead of forcing CUDA.
                self.prop_kernels = self.prop_kernels.to(conv_result.device)
                conv_result = F.conv2d(conv_result,
                                       self.prop_kernels,
                                       stride=1,
                                       padding=1,
                                       groups=conv_result.size(1))
            # 1 for patches that contain no masked pixel, 0 otherwise.
            mm = (torch.mean(mask_kernels, dim=[1, 2, 3],
                             keepdim=True) == 0.0).to(torch.float32)
            mm = mm.permute(1, 0, 2, 3).to(conv_result.device)
            conv_result = conv_result * mm
            attention_scores = F.softmax(conv_result, dim=1)
            attention_scores = attention_scores * mm

            # propagate the scores
            recovered_foreground = F.conv_transpose2d(
                attention_scores,
                conv_kernels,
                stride=1,
                padding=half_patch)
            output_tensor.append(recovered_foreground)
        return torch.cat(output_tensor, dim=0)
Example #47
0
 def forward(self, x):
     """Apply a bias-free 2-D convolution with this layer's weight, stride and padding."""
     out = F.conv2d(x, self.weight, bias=None, stride=self.stride,
                    padding=self.padding)
     return out
Example #48
0
def snip_forward_conv2d(self, x):
    """Conv2d forward pass that multiplies the weight by ``self.weight_mask`` first."""
    masked_weight = self.weight * self.weight_mask
    return F.conv2d(
        x,
        masked_weight,
        self.bias,
        self.stride,
        self.padding,
        self.dilation,
        self.groups,
    )
    def fea_gen_forward(self, input, batch_size, warmup_t, pred_t, len_wind):
        '''Predict future feature maps from feature residuals.

        Steps:
            1. Extract intermediate features with ``self.base_model`` and split
               them into one (batch, C, 28, 28) tensor per time step.
            2. Take the difference between consecutive time steps.
            3. For ``pred_t`` steps: let ``self.kernelCNN`` produce 3x3, 5x5
               and 7x7 kernels from the last ``len_wind`` residuals, normalise
               each by its overall 2-norm, convolve one residual with them
               (one group per batch element), average the three results into
               a new residual and add it to the latest generated feature.
               With ``self.coin_flip`` set, the convolved residual is randomly
               either the observed one or the last generated one (scheduled
               sampling).
            4. Return generated and original features/residuals so a loss can
               be computed on them.

        @param input: video input; reshaped to (-1, sample_len, H, W)
        @param batch_size: batch size, also used as the conv group count
        @param warmup_t: number of observed time steps used as warm-up; must
            leave at least one residual, i.e. >= 2
        @param pred_t: number of time steps to predict, >= 1
        @param len_wind: length of the sliding window of residuals fed to
            ``self.kernelCNN``

        @return:
            1. list of generated features;
            2. list of original features;
            3. list of generated residuals;
            4. list of original residuals;
            5. list of ``self.diff_grad`` applied to generated residuals;
            6. list of ``self.diff_grad`` applied to original residuals;
        '''
        channels_per_frame = 3 if self.modality == "RGB" else 2
        sample_len = channels_per_frame * self.new_length
        input = input.view((-1, sample_len) + input.size()[-2:])

        # ops 1: per-time-step features
        features = self.base_model.extract_feature(input)
        if self.modality == 'Flow':
            num_steps = self.new_length * 2
        else:
            num_steps = self.num_segments
        org_fea = list(
            torch.unbind(features.view(batch_size, num_steps, -1, 28, 28), 1))

        # ops 2: residual between consecutive time steps
        fea_diff = [
            later - earlier
            for later, earlier in zip(org_fea[1:], org_fea[:-1])
        ]

        # ops 3: both lists grow by one entry per predicted step
        gen_fea = org_fea[:warmup_t]
        warm_diff = fea_diff[:warmup_t - 1]

        def apply_kernel(residual, kernel, pad):
            # One kernel per batch element: batch is moved to the channel
            # axis and convolved with groups=batch_size.
            return F.conv2d(torch.transpose(residual, 1, 0),
                            torch.unsqueeze(kernel, dim=1),
                            stride=1,
                            padding=pad,
                            groups=batch_size)

        for i in range(pred_t):
            # generate kernels for the motion vector
            k_3x3, k_5x5, k_7x7 = self.kernelCNN(
                torch.cat(warm_diff[-len_wind:], 1))

            # normalise each kernel tensor by its overall 2-norm
            k_3x3 = k_3x3.div(k_3x3.norm(2))
            k_5x5 = k_5x5.div(k_5x5.norm(2))
            k_7x7 = k_7x7.div(k_7x7.norm(2))

            # scheduled sampling: optionally use the observed residual
            use_observed = False
            if self.coin_flip:
                flip_coin = ((pred_t - i) / pred_t) * 0.4
                rand_p = random.random()
                use_observed = rand_p > flip_coin

            if use_observed:
                source = fea_diff[i + warmup_t - 2]
            else:
                source = warm_diff[-1]

            x = apply_kernel(source, k_3x3, 1)
            y = apply_kernel(source, k_5x5, 2)
            z = apply_kernel(source, k_7x7, 3)

            new_diff = (x + y + z) / 3.0
            new_diff = torch.transpose(new_diff, 1, 0)
            new_fea = gen_fea[-1] + new_diff

            # Update
            gen_fea.append(new_fea)
            warm_diff.append(new_diff)

        gen_dif_grad = [self.diff_grad(d) for d in warm_diff]
        org_dif_grad = [self.diff_grad(d) for d in fea_diff]

        # ops 4
        return gen_fea, org_fea, warm_diff, fea_diff, gen_dif_grad, org_dif_grad
Example #50
0
    def forward(self, x, visualize=False):
        """Convolve ``x`` after decorrelating (whitening) its im2col patches.

        Behaviour depends on ``self.mode``:
            * 1: ``self.channel_deconv`` is applied first, then the patch
              deconvolution below.
            * 3: only ``self.channel_deconv`` followed by a regular
              ``F.conv2d``.
            * otherwise: patch deconvolution only.

        Patch deconvolution unfolds the input, subtracts the mean, multiplies
        by the inverse square root of the covariance (batch statistics while
        training, running statistics in eval) and applies the weights as a
        matrix product.

        Args:
            x: 4-D input tensor, unpacked as (N, C, H, W).
            visualize: if true, return the deconvolved input passed through
                an identity (centre-tap) kernel instead of ``self.weight``.

        Returns:
            Output tensor of shape (N, out_channels, out_h, out_w).
        """
        N, _, H, W = x.shape
        # Standard convolution output size, per axis and honouring dilation:
        # floor((in + 2*pad - dil*(k-1) - 1) / stride) + 1. This matches the
        # number of columns F.unfold produces for any stride.
        dilation = self.dilation
        if not isinstance(dilation, (tuple, list)):
            dilation = (dilation, dilation)
        out_h = (H + 2 * self.padding[0] - dilation[0] *
                 (self.kernel_size[0] - 1) - 1) // self.stride[0] + 1
        out_w = (W + 2 * self.padding[1] - dilation[1] *
                 (self.kernel_size[1] - 1) - 1) // self.stride[1] + 1

        if self.mode == 1:
            x = self.channel_deconv(x)

        if self.mode == 3:
            x = self.channel_deconv(x)
            return F.conv2d(x, self.weight, self.bias, self.stride,
                            self.padding, self.dilation, 1)

        #1. im2col, reshape

        # N * cols * pixels
        inp_unf = torch.nn.functional.unfold(x, self.kernel_size,
                                             self.dilation, self.padding,
                                             self.stride)

        #(k*k, C*N*H*W) for pixel deconv
        #(k*k*G, C//G*N*H*W) for grouped pixel deconv
        X = inp_unf.permute(1, 0, 2).contiguous().view(self.num_features, -1)

        #2. subtract mean
        X_mean = X.mean(-1, keepdim=True)

        # No statistics have been accumulated yet: seed them from this batch.
        first_batch = bool(self.num_batches_tracked == 0)

        #track stats for evaluation
        if first_batch:
            self.running_mean.copy_(X_mean.detach())
        if self.training:
            self.running_mean.mul_(1 - self.momentum)
            self.running_mean.add_(X_mean.detach() * self.momentum)
        else:
            X_mean = self.running_mean

        X = X - X_mean

        #3. calculate COV, COV^(-0.5), then deconv
        # Also needed in eval mode on the very first batch, where the running
        # estimate is seeded from the current data.
        if self.training or first_batch:
            Cov = (X / X.shape[1]) @ X.t() + self.eps * torch.eye(
                X.shape[0], dtype=X.dtype, device=X.device)
            deconv = isqrt_newton_schulz_autograd(Cov, self.n_iter)

        #track stats for evaluation
        if first_batch:
            self.running_deconv.copy_(deconv.detach())
        if self.training:
            self.running_deconv.mul_(1 - self.momentum)
            self.running_deconv.add_(deconv.detach() * self.momentum)
        else:
            deconv = self.running_deconv

        #deconv
        X_deconv = deconv @ X

        #reshape to (N, out_h*out_w, features)
        X_deconv = X_deconv.view(-1, N, out_h * out_w).contiguous().permute(
            1, 2, 0)

        #4. convolve

        if visualize:
            # Identity kernel: each channel keeps only its own centre tap.
            w = torch.zeros(self.weight.shape[1],
                            self.weight.shape[1],
                            self.weight.shape[2],
                            self.weight.shape[3],
                            dtype=x.dtype,
                            device=x.device)
            c = self.weight.shape[1]
            w[torch.arange(c).long(),
              torch.arange(c).long(), self.weight.shape[2] // 2,
              self.weight.shape[3] // 2] = 1.
            out_unf = X_deconv.matmul(w.view(w.size(0), -1).t()).transpose(
                1, 2).view(N, -1, out_h, out_w)
            return out_unf

        w = self.weight
        out_unf = X_deconv.matmul(w.view(w.size(0), -1).t()).transpose(
            1, 2).view(N, -1, out_h, out_w)
        if self.bias is not None:
            out_unf = out_unf + self.bias.view(1, -1, 1, 1)

        if self.training:
            self.num_batches_tracked.add_(1)

        return out_unf  #.contiguous()
            """
Example #51
0
 def _conv_forward(self, input, weight):
     """Convolve a symmetrically pre-padded ``input`` with ``weight``.

     Padding is applied explicitly through ``_pad_symmetric`` using
     ``self.pw`` / ``self.ph``, so the convolution itself runs with zero
     padding.
     """
     pad_spec = (self.pw, self.pw, self.ph, self.ph)
     padded = _pad_symmetric(input, pad_spec)
     return F.conv2d(padded, weight, self.bias, self.stride, _pair(0),
                     self.dilation, self.groups)
Example #52
0
 def _apply_sobel(self, channel):
     g_x = F.conv2d(channel, self.kernel_g_x, stride=1, padding=1)
     g_y = F.conv2d(channel, self.kernel_g_y, stride=1, padding=1)
     return torch.sqrt(torch.pow(g_x, 2) + torch.pow(g_y, 2))
 def forward(self, input):
     """Run the convolution, on the sign of ``input`` when ``self.binary`` is set."""
     activations = input.sign() if self.binary else input
     return F.conv2d(
         activations,
         self.weight,
         self.bias,
         self.stride,
         self.padding,
         self.dilation,
         self.groups,
     )
Example #54
0
def conv2d_with_reflection_pad(x, window):
    """Reflection-pad ``x`` for ``window``'s size, then convolve channel-wise.

    The padding amount is decided by ``reflection_pad`` from the window's
    last dimension; the convolution adds no padding of its own and uses one
    group per channel of the padded input.
    """
    padded = reflection_pad(x, window_size=window.size(-1))
    num_channels = padded.size(1)
    return F.conv2d(padded, window, padding=0, groups=num_channels)
 def forward(self, x):
     """Apply ``self.static_padding`` to ``x`` and then the layer's convolution."""
     padded = self.static_padding(x)
     return F.conv2d(
         padded,
         self.weight,
         self.bias,
         self.stride,
         self.padding,
         self.dilation,
         self.groups,
     )
Example #56
0
def ssim(img1,
         img2,
         window_size=11,
         window=None,
         size_average=True,
         full=False,
         val_range=None):
    """Compute the structural similarity (SSIM) between two image batches.

    Args:
        img1, img2: 4-D image tensors; channel count and spatial size are
            read from ``img1``.
        window_size: requested window size, clipped to the image height and
            width when the window is built here.
        window: optional precomputed window; if ``None`` one is created with
            ``create_window`` and moved to ``img1``'s device.
        size_average: return the mean over everything if true, otherwise one
            value per batch element.
        full: also return the contrast-sensitivity term.
        val_range: dynamic range of the data. When ``None`` it is guessed
            from ``img1``: max is 255 if any value exceeds 128, else 1; min
            is -1 if any value is below -0.5, else 0.

    Returns:
        ``ssim`` or, when ``full`` is true, ``(ssim, cs)``.
    """
    # Value range can differ from 255; other common ranges are 1 (sigmoid)
    # and 2 (tanh).
    if val_range is not None:
        L = val_range
    else:
        max_val = 255 if torch.max(img1) > 128 else 1
        min_val = -1 if torch.min(img1) < -0.5 else 0
        L = max_val - min_val

    padd = 0
    _, channel, height, width = img1.size()
    if window is None:
        real_size = min(window_size, height, width)
        window = create_window(real_size, channel=channel).to(img1.device)

    def windowed(tensor):
        # Per-channel weighted local average.
        return F.conv2d(tensor, window, padding=padd, groups=channel)

    mu1 = windowed(img1)
    mu2 = windowed(img2)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    # Local (co)variances: E[xy] - E[x]E[y]
    sigma1_sq = windowed(img1 * img1) - mu1_sq
    sigma2_sq = windowed(img2 * img2) - mu2_sq
    sigma12 = windowed(img1 * img2) - mu1_mu2

    C1 = (0.01 * L)**2
    C2 = (0.03 * L)**2

    v1 = 2.0 * sigma12 + C2
    v2 = sigma1_sq + sigma2_sq + C2
    cs = torch.mean(v1 / v2)  # contrast sensitivity

    ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)

    ret = ssim_map.mean() if size_average else ssim_map.mean(1).mean(1).mean(1)

    return (ret, cs) if full else ret
    def blur_frame(self, frame):
        """Return a heavily blurred, fixed-size copy of ``frame``.

        The frame is average-pooled with a 100x100 window (stride equal to
        the window, no padding) and the pooled map is bilinearly resized to
        480x640.

        Args:
            frame: 4-D image tensor; its spatial size must be at least
                100x100 for the pooling to produce any output.

        Returns:
            Tensor with the batch and channel dimensions of ``frame`` and
            spatial size (480, 640).
        """
        pooled = F.avg_pool2d(frame,
                              kernel_size=100,
                              stride=None,
                              padding=0,
                              ceil_mode=False,
                              count_include_pad=True)
        # F.interpolate is the non-deprecated equivalent of F.upsample.
        blurred_image = F.interpolate(pooled,
                                      size=(480, 640),
                                      mode='bilinear',
                                      align_corners=False)
        return blurred_image