Example #1
def covariance(x, y):
    # center each sample by its mean over the channel and spatial axes
    x_centered = x - reduce_mean(x, axis=(1, 2, 3), keepdim=True)
    y_centered = y - reduce_mean(y, axis=(1, 2, 3), keepdim=True)

    # flatten to [B, C*H*W] and average the pairwise inner products per sample
    x_centered = torch.flatten(x_centered, start_dim=1)
    y_centered = torch.flatten(y_centered, start_dim=1)
    cov = torch.mean(x_centered.mm(torch.transpose(y_centered, -1, -2)), dim=1)
    return cov
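These snippets call a reduce_mean helper that is not shown on this page. A minimal sketch of what such an axis-wise mean could look like, assuming it simply wraps torch.mean with TensorFlow-style axis/keepdim argument names (an assumption, not the actual implementation):

import torch

def reduce_mean(x, axis=None, keepdim=False):
    # hypothetical helper: mean over the given axes, mirroring tf.reduce_mean
    if axis is None:
        return torch.mean(x)
    return torch.mean(x, dim=axis, keepdim=keepdim)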
Example #2
    def update(self, embeddings, labels):
        # skip the update when the module is in evaluation mode
        if not self.training:
            return
        centers = self.centers
        # calculate gradient (in the case of L2 loss) and decay it by momentum factor
        residual = torch.sub(embeddings,
                             centers[labels]).mul(self.update_factor)
        # get classes whose centers should be updated
        labels_unique, labels_count = labels.unique(return_counts=True)
        # preprocess current centers for the subsequent averaging
        mul_inplace(centers, labels_unique, labels_count[:, None])
        # add gradient to centers of corresponding class
        centers.index_add_(0, labels.long(), residual)
        # average over the number of samples in each updated class
        div_inplace(centers, labels_unique, labels_count[:, None])
        # synchronize across all ranks in case of distributed training
        self.centers = reduce_mean(centers)
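The last line averages the updated centers across workers. A minimal sketch of such a synchronization, assuming the usual torch.distributed all-reduce pattern (the world-size handling below is an assumption, not necessarily this codebase's helper):

import torch.distributed as dist

def reduce_mean(tensor):
    # hypothetical cross-rank mean: sum the tensor over all ranks, then divide by world size
    if not (dist.is_available() and dist.is_initialized()):
        return tensor
    reduced = tensor.clone()
    dist.all_reduce(reduced, op=dist.ReduceOp.SUM)
    reduced /= dist.get_world_size()
    return reduced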
Example #3
def train(model, epoch, dataloader, optimizer):
    model.train()
    model.to(device)

    if epoch > 2000:
        for g in optimizer.param_groups:
            g['lr'] = 1e-5

    running_loss = 0.0
    count = 0
    for (in_img, gt_img, train_ids, ratios) in dataloader:
        # Twice as big because model upsamples.
        gt_img = gt_img.view(
            (BATCH_SIZE, 3, ps * 2, ps * 2)).to(device).float()
        in_img = in_img.view((BATCH_SIZE, 4, ps, ps)).to(device).float()

        # Zero gradients
        optimizer.zero_grad()

        # Get model outputs
        out_img = model(in_img)

        # Calculate loss
        loss = reduce_mean(out_img, gt_img)
        running_loss += loss.item()

        # Compute gradients and take step
        loss.backward()
        optimizer.step()

        if count % save_freq == 0:
            save_current_model(model, epoch, out_img[0], gt_img[0],
                               train_ids[0].item(), ratios[0].item())
        count += 1

    return running_loss
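Here reduce_mean is called with two tensors and used directly as the training loss. A plausible reading, assuming an L1-style objective (this two-argument signature is an assumption; the actual helper is not shown on this page):

import torch

def reduce_mean(pred, target):
    # hypothetical two-argument variant: mean absolute error between prediction and target
    return torch.mean(torch.abs(pred - target))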
Example #4
    def forward(self, f, b, mask=None):
        """ Contextual attention layer implementation.
            Contextual attention is first introduced in publication:
            Generative Image Inpainting with Contextual Attention, Yu et al.
        Args:
            f: Input feature to match (foreground).
            b: Input feature for match (background).
            mask: Input mask for b, indicating patches not available.
            ksize: Kernel size for contextual attention.
            stride: Stride for extracting patches from b.
            rate: Dilation for matching.
            softmax_scale: Scaled softmax for attention.
        Returns:
            torch.tensor: output
        """
        # get shapes
        raw_int_fs = list(f.size())   # b*c*h*w
        raw_int_bs = list(b.size())   # b*c*h*w

        # extract patches from background with stride and rate
        kernel = 2 * self.rate
        # raw_w is extracted for reconstruction
        raw_w = utils.extract_image_patches(b, ksizes=[kernel, kernel],
                                      strides=[self.rate*self.stride,
                                               self.rate*self.stride],
                                      rates=[1, 1],
                                      padding='same') # [N, C*k*k, L]
        # raw_shape: [N, C, k, k, L] [4, 192, 4, 4, 1024]
        raw_w = raw_w.view(raw_int_bs[0], raw_int_bs[1], kernel, kernel, -1)
        raw_w = raw_w.permute(0, 4, 1, 2, 3)    # raw_shape: [N, L, C, k, k]
        raw_w_groups = torch.split(raw_w, 1, dim=0)

        # downscaling foreground option: downscaling both foreground and
        # background for matching and use original background for reconstruction.
        f = F.interpolate(f, scale_factor=1./self.rate, mode='nearest')
        b = F.interpolate(b, scale_factor=1./self.rate, mode='nearest')
        int_fs = list(f.size())     # b*c*h*w
        int_bs = list(b.size())
        f_groups = torch.split(f, 1, dim=0)  # split tensors along the batch dimension
        # w shape: [N, C*k*k, L]
        w = utils.extract_image_patches(b, ksizes=[self.ksize, self.ksize],
                                  strides=[self.stride, self.stride],
                                  rates=[1, 1],
                                  padding='same')
        # w shape: [N, C, k, k, L]
        w = w.view(int_bs[0], int_bs[1], self.ksize, self.ksize, -1)
        w = w.permute(0, 4, 1, 2, 3)    # w shape: [N, L, C, k, k]
        w_groups = torch.split(w, 1, dim=0)

        # process mask
        mask = F.interpolate(mask, scale_factor=1./self.rate, mode='nearest')
        int_ms = list(mask.size())
        # m shape: [N, C*k*k, L]
        m = utils.extract_image_patches(mask, ksizes=[self.ksize, self.ksize],
                                  strides=[self.stride, self.stride],
                                  rates=[1, 1],
                                  padding='same')

        # m shape: [N, C, k, k, L]
        m = m.view(int_ms[0], int_ms[1], self.ksize, self.ksize, -1)
        m = m.permute(0, 4, 1, 2, 3)    # m shape: [N, L, C, k, k]
        m = m[0]    # m shape: [L, C, k, k]
        # mm shape: [L, 1, 1, 1]
        mm = (utils.reduce_mean(m, axis=[1, 2, 3], keepdim=True)==0.).to(torch.float32)
        mm = mm.permute(1, 0, 2, 3) # mm shape: [1, L, 1, 1]

        y = []
        offsets = []
        k = self.fuse_k
        scale = self.softmax_scale    # to fit the PyTorch tensor image value range
        fuse_weight = torch.eye(k).view(1, 1, k, k)  # 1*1*k*k
        if self.use_cuda:
            fuse_weight = fuse_weight.cuda()

        for xi, wi, raw_wi in zip(f_groups, w_groups, raw_w_groups):
            '''
            O => output channel as a conv filter
            I => input channel as a conv filter
            xi : separated tensor along batch dimension of front; (B=1, C=128, H=32, W=32)
            wi : separated patch tensor along batch dimension of back; (B=1, O=32*32, I=128, KH=3, KW=3)
            raw_wi : separated tensor along batch dimension of back; (B=1, I=32*32, O=128, KH=4, KW=4)
            '''
            # conv for compare
            escape_NaN = torch.FloatTensor([1e-4])
            if self.use_cuda:
                escape_NaN = escape_NaN.cuda()
            wi = wi[0]  # [L, C, k, k]
            max_wi = torch.sqrt(utils.reduce_sum(torch.pow(wi, 2) + escape_NaN, axis=[1, 2, 3], keepdim=True))
            wi_normed = wi / max_wi
            # xi shape: [1, C, H, W], yi shape: [1, L, H, W]
            xi = utils.same_padding(xi, [self.ksize, self.ksize], [1, 1], [1, 1])  # xi: 1*c*H*W
            yi = F.conv2d(xi, wi_normed, stride=1)   # [1, L, H, W]
            # conv implementation for fuse scores to encourage large patches
            if self.fuse:
                # make all of depth to spatial resolution
                yi = yi.view(1, 1, int_bs[2]*int_bs[3], int_fs[2]*int_fs[3])  # (B=1, I=1, H=32*32, W=32*32)
                yi = utils.same_padding(yi, [k, k], [1, 1], [1, 1])
                yi = F.conv2d(yi, fuse_weight, stride=1)  # (B=1, C=1, H=32*32, W=32*32)
                yi = yi.contiguous().view(1, int_bs[2], int_bs[3], int_fs[2], int_fs[3])  # (B=1, 32, 32, 32, 32)
                yi = yi.permute(0, 2, 1, 4, 3)
                yi = yi.contiguous().view(1, 1, int_bs[2]*int_bs[3], int_fs[2]*int_fs[3])
                yi = utils.same_padding(yi, [k, k], [1, 1], [1, 1])
                yi = F.conv2d(yi, fuse_weight, stride=1)
                yi = yi.contiguous().view(1, int_bs[3], int_bs[2], int_fs[3], int_fs[2])
                yi = yi.permute(0, 2, 1, 4, 3).contiguous()
            yi = yi.view(1, int_bs[2] * int_bs[3], int_fs[2], int_fs[3])  # (B=1, C=32*32, H=32, W=32)
            # softmax to match
            yi = yi * mm
            yi = F.softmax(yi*scale, dim=1)
            yi = yi * mm  # [1, L, H, W]

            offset = torch.argmax(yi, dim=1, keepdim=True)  # 1*1*H*W

            if int_bs != int_fs:
                # Normalize the offset value to match foreground dimension
                times = float(int_fs[2] * int_fs[3]) / float(int_bs[2] * int_bs[3])
                offset = ((offset + 1).float() * times - 1).to(torch.int64)
            offset = torch.cat([offset//int_fs[3], offset%int_fs[3]], dim=1)  # 1*2*H*W

            # deconv for patch pasting
            wi_center = raw_wi[0]
            # yi = F.pad(yi, [0, 1, 0, 1])    # here may need conv_transpose same padding
            yi = F.conv_transpose2d(yi, wi_center, stride=self.rate, padding=1) / 4.  # (B=1, C=128, H=64, W=64)
            y.append(yi)
            offsets.append(offset)

        y = torch.cat(y, dim=0)  # back to the mini-batch
        y = y.contiguous().view(raw_int_fs)

        return y
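This layer and the next one rely on an extract_image_patches utility that returns patches in [N, C*k*k, L] layout. A minimal sketch built on torch.nn.functional.unfold, assuming 'same' padding is emulated with symmetric zero-padding (the padding arithmetic is an assumption):

import torch.nn.functional as F

def extract_image_patches(images, ksizes, strides, rates, padding='same'):
    # hypothetical im2col-style helper returning patches of shape [N, C*k*k, L]
    k_h, k_w = ksizes
    s_h, s_w = strides
    r_h, r_w = rates
    if padding == 'same':
        # pad so the number of patches matches ceil(H / stride), like TF 'SAME'
        n, c, h, w = images.size()
        out_h = (h + s_h - 1) // s_h
        out_w = (w + s_w - 1) // s_w
        pad_h = max((out_h - 1) * s_h + (k_h - 1) * r_h + 1 - h, 0)
        pad_w = max((out_w - 1) * s_w + (k_w - 1) * r_w + 1 - w, 0)
        images = F.pad(images, [pad_w // 2, pad_w - pad_w // 2,
                                pad_h // 2, pad_h - pad_h // 2])
    return F.unfold(images, kernel_size=(k_h, k_w),
                    dilation=(r_h, r_w), stride=(s_h, s_w))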
Example #5
    def forward(self, b, f, mask=None):
        """
        :param b: Input feature for match (background) - known region.
        :param f: Input feature to match (foreground) - missing region.
        :param mask: Input mask for b, indicating patches not available.
        :return:
        """

        # get shapes
        f_shape_raw = list(f.size())  # batch_size * c * h * w
        b_shape_raw = list(b.size())  # batch_size * c * h * w

        kernel_size = 2 * self.rate

        # extract patches from background with stride, padding and dilation
        # raw_w is extracted for reconstruction
        raw_w = self.extract_patches(b,
                                     kernel_size,
                                     self.rate * self.stride,
                                     self.dilation,
                                     padding='valid')  # [batch_size, C*k*k, L]

        raw_w = raw_w.view(b_shape_raw[0], b_shape_raw[1], kernel_size,
                           kernel_size, -1)
        raw_w = raw_w.permute(0, 4, 1, 2,
                              3)  # b_weights shape: [batch_size, L, C, k, k]

        # tuple of tensors with size [L, C, k, k] with len = batch_size
        raw_w_groups = torch.split(raw_w, 1, dim=0)

        # downscaling foreground option: downscaling both foreground and
        # background for matching and use original background for reconstruction.
        f = F.interpolate(f, scale_factor=1. / self.rate, mode='nearest')
        b = F.interpolate(b, scale_factor=1. / self.rate, mode='nearest')

        f_shape = list(f.size())  # b*c*h*w
        b_shape = list(b.size())

        # split tensors along the batch dimension
        # tuple of tensors with size [C, h, w] with len = batch_size
        f_groups = torch.split(
            f, 1, dim=0)  # split tensors along the batch dimension

        # w shape: [N, C*k*k, L]
        w = self.extract_patches(b, self.ksize, self.stride, 1, padding='same')

        # w shape: [N, C, k, k, L]
        w = w.view(b_shape[0], b_shape[1], self.ksize, self.ksize, -1)
        w = w.permute(0, 4, 1, 2, 3)  # w shape: [N, L, C, k, k]
        w_groups = torch.split(w, 1, dim=0)

        if mask is None:
            mask = torch.zeros(f_shape[0], 1, f_shape[2], f_shape[3])
            if self.device is not None:
                mask = mask.to(self.device)
        else:
            mask_scale = mask.size()[3] // f_shape[3]

            # downscale to match f shape
            mask = F.interpolate(mask,
                                 scale_factor=1 / mask_scale,
                                 mode='nearest')
            # mask = F.avg_pool2d(mask, kernel_size=4, stride=mask_scale)

        m_shape = list(mask.size())  # c * h * w
        m = self.extract_patches(mask,
                                 self.ksize,
                                 self.stride,
                                 1,
                                 padding='same')  # [batch_size, k*k, L]

        m = m.view(m_shape[0], m_shape[1], self.ksize, self.ksize,
                   -1)  # [batch_size, 1, k, k, L]
        m = m.permute(0, 4, 1, 2, 3)  # m shape: [batch_size, L, C, k, k]
        # m = m[0]  # m shape: [L, C, k, k]

        # 1. for patches whose mask values are all 1 (fully masked),
        # 0. otherwise
        # mm shape: [batch_size, L, 1, 1, 1]

        mm = (reduce_mean(m, axis=[2, 3, 4],
                          keepdim=True) == 1.).to(torch.float32)
        # mm shape: [batch_size, 1, L, 1, 1]
        mm = mm.permute(0, 2, 1, 3, 4)

        y = []
        offsets = []
        k = self.fuse_k
        scale = self.softmax_scale  # to fit the PyTorch tensor image value range
        # Diagonal matrix with shape k * k
        fuse_weight = torch.eye(k).view(1, 1, k, k)  # 1*1*k*k
        if self.device:
            fuse_weight = fuse_weight.to(self.device)
        EPS = torch.FloatTensor([1e-4]).to(self.device)
        for xi, wi, raw_wi, mi in zip(f_groups, w_groups, raw_w_groups, mm):
            """
            O => output channel as a conv filter
            I => input channel as a conv filter
            xi : separated tensor along batch dimension of front; (B=1, C=128, H=32, W=32)
            wi : separated patch tensor along batch dimension of back; (B=1, O=32*32, I=128, KH=3, KW=3)
            raw_wi : separated tensor along batch dimension of back; (B=1, I=32*32, O=128, KH=4, KW=4)
            """
            # Normalizing weight tensor

            wi = wi.squeeze(0)
            wi_norm = torch.sqrt(
                reduce_sum(torch.pow(wi, 2) + EPS,
                           axis=[1, 2, 3],
                           keepdim=True))
            wi_normed = wi / wi_norm

            # xi shape: [1, C, H, W], yi shape: [1, L, H, W]
            xi_pad = same_padding(xi.shape[0], xi.shape[1],
                                  [self.ksize, self.ksize], [1, 1], [1, 1])
            yi = F.conv2d(xi, wi_normed, stride=1,
                          padding=xi_pad)  # [1, L, H, W]

            # conv implementation for fuse scores to encourage large patches
            if self.fuse:
                # make all of depth to spatial resolution
                # Convolution with diagonal shaped kernel №1
                yi = yi.view(1, 1, b_shape[2] * b_shape[3], f_shape[2] *
                             f_shape[3])  # (B=1, I=1, H=32*32, W=32*32)
                yi_pad = same_padding(yi.shape[0], yi.shape[1], [k, k], [1, 1],
                                      [1, 1])
                yi = F.conv2d(yi, fuse_weight, stride=1,
                              padding=yi_pad)  # (B=1, C=1, H=32*32, W=32*32)

                # Convolution with diagonal shaped kernel №2
                yi = yi.contiguous().view(1, b_shape[2], b_shape[3],
                                          f_shape[2],
                                          f_shape[3])  # (B=1, 32, 32, 32, 32)
                yi = yi.permute(0, 2, 1, 4, 3)
                yi = yi.contiguous().view(1, 1, b_shape[2] * b_shape[3],
                                          f_shape[2] * f_shape[3])
                yi_pad = same_padding(yi.shape[0], yi.shape[1], [k, k], [1, 1],
                                      [1, 1])
                yi = F.conv2d(yi, fuse_weight, stride=1, padding=yi_pad)

                yi = yi.contiguous().view(1, b_shape[3], b_shape[2],
                                          f_shape[3], f_shape[2])
                yi = yi.permute(0, 2, 1, 4, 3).contiguous()

            yi = yi.view(1, b_shape[2] * b_shape[3], f_shape[2],
                         f_shape[3])  # (B=1, C=32*32, H=32, W=32)
            # softmax to match
            yi = yi * mi
            yi = F.softmax(yi * scale, dim=1)
            yi = yi * mi  # [1, L, H, W]
            offset = torch.argmax(yi, dim=1, keepdim=True)  # 1*1*H*W
            if b_shape != f_shape:
                # Normalize the offset value to match foreground dimension
                times = float(f_shape[2] * f_shape[3]) / float(
                    b_shape[2] * b_shape[3])
                offset = ((offset + 1).float() * times - 1).to(torch.int64)
            offset = torch.cat([offset // f_shape[3], offset % f_shape[3]],
                               dim=1)  # 1*2*H*W
            # deconv for patch pasting
            wi_center = raw_wi[0]

            # yi = F.pad(yi, [0, 1, 0, 1])    # here may need conv_transpose same padding
            yi = F.conv_transpose2d(yi, wi_center, stride=self.rate,
                                    padding=1) / 4.  # (B=1, C=128, H=64, W=64)
            y.append(yi)
            offsets.append(offset)

        y = torch.cat(y, dim=0)  # back to the mini-batch
        y = y.contiguous().view(f_shape_raw)

        offsets = torch.cat(offsets, dim=0)
        offsets = offsets.view(f_shape[0], 2, *f_shape[2:])

        # case1: visualize optical flow: minus current position
        h_add = torch.arange(f_shape[2]).view([1, 1, f_shape[2], 1]).expand(
            f_shape[0], -1, -1, f_shape[3])
        w_add = torch.arange(f_shape[3]).view([1, 1, 1, f_shape[3]]).expand(
            f_shape[0], -1, f_shape[2], -1)
        ref_coordinate = torch.cat([h_add, w_add], dim=1)
        ref_coordinate = ref_coordinate.to(self.device)

        offsets = offsets - ref_coordinate
        flow = torch.from_numpy(
            self.flow_to_image(offsets.permute(0, 2, 3,
                                               1).cpu().data.numpy())) / 255.
        flow = flow.permute(0, 3, 1, 2)
        flow = flow.to(self.device)

        if self.rate != 1:
            flow = F.interpolate(flow,
                                 scale_factor=self.rate * 4,
                                 mode='nearest')

        return y, flow
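The patch-weight normalization in these layers also calls a reduce_sum helper. A minimal sketch, assuming it mirrors reduce_mean but sums (name and signature are assumptions):

import torch

def reduce_sum(x, axis=None, keepdim=False):
    # hypothetical helper: sum over the given axes, mirroring tf.reduce_sum
    if axis is None:
        return torch.sum(x)
    return torch.sum(x, dim=axis, keepdim=keepdim)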
Example #6
    def forward(self, cls_preds, reg_preds, cen_preds, gt_bboxes, gt_labels,
                fin_img_shape):
        batch_size = cls_preds[0].size(0)
        scales = [cls_pred.shape[-2:] for cls_pred in cls_preds]
        num_levels = len(scales)

        cls_targets, reg_targets, anchors, batch_valid_flags, num_anchors_per_level = self.compute_targets(
            scales, gt_bboxes, gt_labels, fin_img_shape)

        anchors = [anchors for _ in range(batch_size)]
        anchors = [anchor.split(num_anchors_per_level) for anchor in anchors]
        cls_targets = [
            cls_target.split(num_anchors_per_level)
            for cls_target in cls_targets
        ]
        reg_targets = [
            reg_target.split(num_anchors_per_level)
            for reg_target in reg_targets
        ]
        batch_valid_flags = [
            batch_valid_flag.split(num_anchors_per_level)
            for batch_valid_flag in batch_valid_flags
        ]

        anchors_level_first = []
        cls_targets_level_first = []
        reg_targets_level_first = []
        valid_flags_level_first = []
        for i in range(num_levels):
            anchors_level_first.append(
                torch.cat([anchor[i] for anchor in anchors]))
            cls_targets_level_first.append(
                torch.cat([cls_target[i] for cls_target in cls_targets]))
            reg_targets_level_first.append(
                torch.cat([reg_target[i] for reg_target in reg_targets]))
            valid_flags_level_first.append(
                torch.cat([
                    batch_valid_flag[i]
                    for batch_valid_flag in batch_valid_flags
                ]))

        anchors_level_first = torch.cat(anchors_level_first)
        cls_targets_level_first = torch.cat(cls_targets_level_first)
        reg_targets_level_first = torch.cat(reg_targets_level_first)
        valid_flags_level_first = torch.cat(valid_flags_level_first)

        cls_preds = [
            cls_pred.permute(0, 2, 3, 1).reshape(-1, cfg.num_classes)
            for cls_pred in cls_preds
        ]
        reg_preds = [
            reg_pred.permute(0, 2, 3, 1).reshape(-1, 4)
            for reg_pred in reg_preds
        ]
        cen_preds = [
            cen_pred.permute(0, 2, 3, 1).reshape(-1) for cen_pred in cen_preds
        ]

        cls_preds_all = torch.cat(cls_preds)
        reg_preds_all = torch.cat(reg_preds)
        cen_preds_all = torch.cat(cen_preds)

        pos_inds = torch.nonzero((cls_targets_level_first >= 0) &
                                 (cls_targets_level_first != cfg.num_classes),
                                 as_tuple=False).reshape(-1)
        num_pos = utils.reduce_mean(
            torch.tensor(pos_inds.size(0)).float().cuda()).item()
        cls_loss = self.cls_loss_func(cls_preds_all, cls_targets_level_first,
                                      valid_flags_level_first) / num_pos

        anchors_pos = anchors_level_first[pos_inds]
        reg_preds_pos = reg_preds_all[pos_inds]
        cen_preds_pos = cen_preds_all[pos_inds]
        reg_targets_pos = reg_targets_level_first[pos_inds]

        if num_pos > 0:
            bboxes_predict = utils.reg_decode(anchors_pos, reg_preds_pos)
            bboxes_target = utils.reg_decode(anchors_pos, reg_targets_pos)
            cen_targets_pos = self.compute_centerness_targets(
                anchors_pos, bboxes_target)

            sum_cen = utils.reduce_mean(cen_targets_pos.sum()).item()

            reg_loss = self.reg_loss_func(bboxes_predict, bboxes_target,
                                          cen_targets_pos) / sum_cen
            cen_loss = self.cen_loss_func(cen_preds_pos,
                                          cen_targets_pos) / num_pos
        else:
            reg_loss = cls_loss.new_tensor(0, requires_grad=True)
            cen_loss = cls_loss.new_tensor(0, requires_grad=True)

        return dict(cls_loss=cls_loss, reg_loss=reg_loss, cen_loss=cen_loss)
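One caveat: num_pos is the cross-rank average of the positive-anchor count and can be 0 for batches with no positive samples, in which case the cls_loss division would produce inf/nan. A common defensive variant (an assumption, not necessarily what this codebase does) clamps it before dividing:

        # hypothetical guard: never divide the losses by less than one sample
        num_pos = max(num_pos, 1.0)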
Example #7
    def forward(self,
                f,
                b,
                mask=None,
                ksize=3,
                stride=1,
                rate=1,
                fuse_k=3,
                softmax_scale=10.,
                training=True,
                fuse=True):
        """ Contextual attention layer implementation.

        Contextual attention is first introduced in publication:
        Generative Image Inpainting with Contextual Attention, Yu et al.

        Args:
        f: Input feature to match (foreground).
        b: Input feature for match (background).
        mask: Input mask for b, indicating patches not available.
        ksize: Kernel size for contextual attention.
        stride: Stride for extracting patches from b.
        rate: Dilation for matching.
        softmax_scale: Scaled softmax for attention.
        training: Indicating if current graph is training or inference.

        Returns:
        torch.Tensor: output
        """

        # get shapes of foreground (f) and background (b)
        raw_fs = f.shape
        # print("RAW FS: " + str(raw_fs))
        raw_int_fs = list(f.shape)
        raw_int_bs = list(b.shape)

        # extract 3x3 patches from background with stride and rate
        kernel = 2 * rate
        raw_w = self.extract_image_patches(b, kernel, rate * stride)

        # Reshape raw_w to match pytorch conv weights shape
        raw_w = torch.reshape(
            raw_w, [raw_int_bs[0], -1, raw_int_bs[1], kernel, kernel
                    ])  # b x in_ch (h * w) x out_ch (c) x k x k

        # downscaling foreground option: downscaling both foreground and
        # background for matching and use original background for reconstruction.
        f = F.interpolate(f, scale_factor=1. / rate, mode='nearest')
        b = F.interpolate(
            b,
            size=[int(raw_int_bs[2] / rate),
                  int(raw_int_bs[3] / rate)],
            mode='nearest')

        # get shape of foreground then split on the batch dimension
        fs = f.shape
        int_fs = list(f.shape)
        f_groups = torch.split(f, 1, dim=0)

        # print("F GROUPS: " + str(f_groups[0].shape))

        bs = b.shape
        int_bs = list(b.shape)

        # extract w then reshape to weight shape of functional conv2d of pytorch
        w = self.extract_image_patches(b, ksize, stride)
        # reshape to b x in_ch (h * w) x out_ch (c) x k x k
        # print("INT FS: " + str(int_fs))
        w = torch.reshape(w, [int_fs[0], -1, int_fs[1], ksize, ksize])

        # print("W: " + str(w.shape))
        # process mask
        if mask is None:
            mask = torch.zeros([bs[0], 1, bs[2], bs[3]]).cuda()
        else:
            # print("DOWNSAMPLE MEN")
            mask = F.interpolate(mask, scale_factor=1. / rate, mode='nearest')

        m = self.extract_image_patches(mask, ksize, stride)

        # make mask have the shape of (b x c x hw x k x k)
        # print("m = " + str(mask.shape))
        if mask.shape[0] > 1:
            m = torch.reshape(m, [mask.shape[0], 1, -1, ksize, ksize])
        else:
            m = torch.reshape(m, [1, 1, -1, ksize, ksize])
        # m = m[0]
        # print("MY M: " + str(m.shape))
        # create batch for mm
        mm = []
        for i in range(m.shape[0]):
            mm.append(utils.reduce_mean(m[i], axis=[0, 2, 3], keep_dims=True))

        mm = torch.cat(mm)

        # print("mm: " + str(mm.shape))
        w_groups = torch.split(w, 1, dim=0)
        raw_w_groups = torch.split(raw_w, 1, dim=0)
        y = []
        offsets = []
        k = fuse_k
        scale = softmax_scale
        fuse_weight = utils.to_var(torch.reshape(torch.eye(k), [1, 1, k, k]))

        for xi, wi, raw_wi, mi in zip(f_groups, w_groups, raw_w_groups, mm):
            """
            # Conv per batch
            # VARIABLES:
            # - xi: input to the conv; tensors from foreground (f_groups)
            # - wi: weights for training; image patches from the background (w_groups): 
            # - raw_wi: patches from the background (raw_w_groups)
            """
            # conv for compare
            wi = wi[0]  #

            wi_normed = wi / \
                torch.max(torch.sqrt(utils.reduce_sum(
                    wi ** 2, axis=[0, 2, 3])), torch.FloatTensor([1e-4]).cuda())

            # print("wi_normed: " + str(wi_normed.shape))
            # print("xi:" + str(xi.shape))
            yi = F.conv2d(xi, wi_normed, stride=1, padding=1)
            # print("yi: " + str(yi.shape))
            # wi_normed = wi / torch.max(torch.sqrt(torch.sum(torch.square()))) #l2 norm
            # conv implementation for fuse scores to encourage large patches
            if fuse:
                # b x c x f(hw) x b(hw)
                yi = torch.reshape(yi, [1, 1, fs[2] * fs[3], bs[2] * bs[3]])
                # print("yi: " + str(yi.shape))
                yi = F.conv2d(yi, fuse_weight, stride=1, padding=1)
                yi = torch.reshape(yi, [1, fs[2], fs[3], bs[2], bs[3]])
                yi = yi.permute(0, 2, 1, 4, 3)
                yi = torch.reshape(yi, [1, 1, fs[2] * fs[3], bs[2] * bs[3]])
                # print("yi: " + str(yi.shape))
                yi = F.conv2d(yi, fuse_weight, stride=1, padding=1)
                yi = torch.reshape(yi, [1, fs[3], fs[2], bs[3], bs[2]])
                yi = yi.permute(0, 2, 1, 4, 3)
                # print("yi inside fuse: " + str(yi.shape))
                # print("yi: " + str(yi.shape))

            yi = torch.reshape(yi, [1, bs[2] * bs[3], fs[2], fs[3]])
            # print("yi: " + str(yi.shape))
            # softmax to match
            yi = yi * mi
            # print("hey")
            yi = F.softmax(yi * scale, dim=1)
            yi = yi * mi  # mask

            _, offset = torch.max(yi, dim=1)
            offset = torch.stack([offset // fs[3], offset % fs[3]], dim=-1)

            # deconv for patch pasting
            # 3.1 paste center
            wi_center = raw_wi[0]
            yi = F.conv_transpose2d(yi, wi_center, stride=rate, padding=1) / 4.
            y.append(yi)
            offsets.append(offset)

        y = torch.cat(y, dim=0)

        offsets = torch.cat(offsets, dim=0)
        offsets = torch.reshape(offsets,
                                [int_bs[0]] + [2] + int_bs[2:])  # skip channel

        # case1: visualize optical flow: minus current position
        # height
        h_add = utils.to_var(
            torch.reshape(torch.arange(bs[2]), [1, 1, bs[2], 1]))
        h_add = h_add.expand([bs[0], 1, bs[2], bs[3]])

        # width
        w_add = utils.to_var(
            torch.reshape(torch.arange(bs[3]), [1, 1, 1, bs[3]]))
        w_add = w_add.expand([bs[0], 1, bs[2], bs[3]])

        # concat on channel
        offsets = offsets - torch.cat([h_add, w_add], dim=1)

        # to flow image
        flow = helper.flow_to_image(
            offsets.permute(0, 2, 3, 1).data.cpu().numpy())
        flow = torch.from_numpy(flow).permute(0, 3, 1, 2)

        # case2: visualize which pixels are attended
        # flow = highlight_flow_tf(offsets * tf.cast(mask, tf.int32))
        if rate != 1:
            flow = F.interpolate(flow, scale_factor=rate, mode='nearest')

        out = self.final_layers(y)
        return out, flow
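For orientation, a self-contained sketch of the matching-and-pasting core that these contextual-attention layers implement, simplified to ksize=3, stride=1, rate=1 and without masking or score fusing (all shapes below are illustrative assumptions):

import torch
import torch.nn.functional as F

f = torch.randn(1, 64, 32, 32)   # foreground features (region to fill)
b = torch.randn(1, 64, 32, 32)   # background features (known region)

# background patches reshaped into conv weights: [L, C, k, k]
w = F.unfold(b, kernel_size=3, padding=1).view(64, 3, 3, -1).permute(3, 0, 1, 2)
w_normed = w / torch.sqrt((w ** 2).sum(dim=[1, 2, 3], keepdim=True) + 1e-4)

# similarity of every foreground location to every background patch
scores = F.conv2d(f, w_normed, stride=1, padding=1)   # [1, L, 32, 32]
attn = F.softmax(scores * 10., dim=1)                 # softmax over background patches

# paste background patches back, weighted by the attention map
y = F.conv_transpose2d(attn, w, stride=1, padding=1) / 9.   # [1, 64, 32, 32]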