Example #1
0
    def forward(self, inputs):
        """Run the FPN over a sequence of backbone feature maps.

        Args:
            inputs: sequence of backbone feature maps, one per entry of
                ``self.in_channels``; maps from ``self.start_level`` onward
                feed ``self.lateral_convs``.

        Returns:
            tuple: ``self.num_outs`` output feature maps. The first
            ``len(self.lateral_convs)`` come from the backbone levels; any
            remaining ones are extra levels built either by max pooling or by
            the extra ``self.fpn_convs``.
        """
        assert len(inputs) == len(self.in_channels)

        # Build laterals: one lateral conv per used backbone level.
        laterals = [
            lateral_conv(inputs[i + self.start_level])
            for i, lateral_conv in enumerate(self.lateral_convs)
        ]

        # Build the top-down path. Upsample to the exact spatial size of the
        # finer level rather than by a fixed x2 factor, so odd-sized maps
        # (e.g. 7x7 above 4x4) still line up. For exact 2x pyramids the
        # result is identical to scale_factor=2 nearest upsampling. The sum is
        # out-of-place so the caller's input tensors are never modified.
        used_backbone_levels = len(laterals)
        for i in range(used_backbone_levels - 1, 0, -1):
            prev_shape = laterals[i - 1].shape[2:]
            laterals[i - 1] = laterals[i - 1] + F.interpolate(
                laterals[i], size=prev_shape, mode='nearest')

        # Part 1: outputs from the original levels.
        outs = [
            self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels)
        ]
        # Part 2: add extra levels.
        if self.num_outs > len(outs):
            if not self.add_extra_convs:
                # Max pool (kernel 1, stride 2) on top of the last output
                # (e.g. Faster R-CNN, Mask R-CNN).
                for _ in range(self.num_outs - used_backbone_levels):
                    outs.append(F.max_pool2d(outs[-1], 1, stride=2))
            else:
                # Extra conv layers on top of the original feature maps
                # (RetinaNet).
                orig = inputs[self.backbone_end_level - 1]
                outs.append(self.fpn_convs[used_backbone_levels](orig))
                for i in range(used_backbone_levels + 1, self.num_outs):
                    # NOTE(review): the original author flagged that a ReLU
                    # should precede each extra conv; behavior kept as-is.
                    outs.append(self.fpn_convs[i](outs[-1]))
        return tuple(outs)
Example #2
0
    def _forward(self, level, inp):
        # Upper branch
        up1 = inp
        up1 = self._modules['b1_' + str(level)](up1)

        # Lower branch
        low1 = F.avg_pool2d(inp, 2, stride=2)
        low1 = self._modules['b2_' + str(level)](low1)

        if level > 1:
            low2, klow2 = self._forward(level - 1, low1)
            low3 = low2
            low2 = klow2
        else:
            low2 = low1
            low2 = self._modules['b2_plus_' + str(level)](low2)
            low3 = low2

        low3 = self._modules['b3_' + str(level)](low3)

        #up2 = F.interpolate(low3, scale_factor=2, mode='nearest')
        up2 = F.upsample_nearest(low3, scale_factor=2)

        return up1 + up2, low2
Example #3
0
 def forward(self, x):
     """Fuse ``x`` with ``conv3(x)`` through ``conv1``, then upsample 2x.

     ``x`` and ``conv3(x)`` are concatenated along dim 1 before ``conv1``;
     the result is upsampled with nearest-neighbour interpolation.
     """
     y = self.conv3(x)
     fused = self.conv1(torch.cat([x, y], 1))
     # F.upsample_nearest is deprecated; interpolate(mode='nearest') is the
     # documented drop-in replacement.
     return F.interpolate(fused, scale_factor=2, mode='nearest')
Example #4
0
    def forward(self, x, vars=None, bn_training=True, DEBUG=False):
        """Functional forward pass driven by ``self.config``.

        Each ``(name, param)`` entry of ``self.config`` is applied in order.
        Layer weights are read from ``vars`` (default ``self.vars``) rather
        than from module attributes. Fine-tuning can therefore pass
        task-specific weights without touching the initial parameters.
        BatchNorm running statistics always come from ``self.vars_bn``.
        ``F.batch_norm`` only updates them when ``bn_training`` is true, so
        pass ``bn_training=False`` during fine-tuning to leave them untouched.

        :param x: input batch (the original docs list [b, 1, 28, 28]).
        :param vars: weight/bias tensors consumed in config order; ``None``
            means ``self.vars``.
        :param bn_training: forwarded as ``training`` to ``F.batch_norm``.
        :param DEBUG: when truthy, print the output shape after conv2d,
            convt2d, fc, bn and flatten layers.
        :return: the output of the last layer.
        :raises NotImplementedError: for an unknown layer name in the config.
        """
        if vars is None:
            vars = self.vars

        # Layers whose output shape is printed in DEBUG mode.
        shape_logged = ('conv2d', 'convt2d', 'fc', 'bn', 'flatten')

        idx = 0
        bn_idx = 0
        hidden = x

        for name, param in self.config:
            # Weights. Original note: keep these in sync with
            # forward_encoder and forward_decoder.
            if name == 'conv2d':
                w, b = vars[idx], vars[idx + 1]
                hidden = F.conv2d(hidden, w, b,
                                  stride=param[4], padding=param[5])
                idx += 2

            elif name == 'convt2d':
                w, b = vars[idx], vars[idx + 1]
                hidden = F.conv_transpose2d(hidden, w, b,
                                            stride=param[4], padding=param[5])
                idx += 2

            elif name == 'fc':
                w, b = vars[idx], vars[idx + 1]
                hidden = F.linear(hidden, w, b)
                idx += 2

            elif name == 'bn':
                w, b = vars[idx], vars[idx + 1]
                running_mean = self.vars_bn[bn_idx]
                running_var = self.vars_bn[bn_idx + 1]
                hidden = F.batch_norm(hidden,
                                      running_mean,
                                      running_var,
                                      weight=w,
                                      bias=b,
                                      training=bn_training)
                idx += 2
                bn_idx += 2

            # Activations and shape ops.
            elif name == 'flatten':
                if DEBUG:
                    print(name, param, "Before flatten shape: ", hidden.shape)
                hidden = hidden.view(hidden.size(0), -1)

            elif name == 'reshape':
                hidden = hidden.view(hidden.size(0), *param)

            elif name == 'relu':
                hidden = F.relu(hidden, inplace=param[0])

            # 'reakyrelu' is the historical misspelling; still accepted so
            # existing configs keep working.
            elif name in ('leakyrelu', 'reakyrelu'):
                hidden = F.leaky_relu(hidden,
                                      negative_slope=param[0],
                                      inplace=param[1])

            elif name == 'tanh':
                hidden = torch.tanh(hidden)

            elif name == 'sigmoid':
                hidden = torch.sigmoid(hidden)

            elif name == 'upsample':
                hidden = F.interpolate(hidden, scale_factor=param[0],
                                       mode='nearest')

            elif name == 'max_pool2d':
                hidden = F.max_pool2d(hidden, param[0], param[1], param[2])

            elif name == 'avg_pool2d':
                hidden = F.avg_pool2d(hidden, param[0], param[1], param[2])

            else:
                raise NotImplementedError(
                    'unsupported layer type in config: {!r}'.format(name))

            if DEBUG and name in shape_logged:
                print(name, param, "shape: ", hidden.shape)

        # Make sure every parameter was consumed exactly once.
        assert idx == len(vars)
        assert bn_idx == len(self.vars_bn)

        return hidden
Example #5
0
    def rrn_preprocess(self, rgb_img, initial_masks):
        """ Pad, Crop, and Resize to prepare input to RRN.

            Every label in ``initial_masks`` that is >= OBJECTS_LABEL is
            treated as one object. Its tight box (util_.mask_to_tight_box) is
            padded by ``self.config['padding_percentage']`` of the box size
            on each side and clamped to the image. The RGB image and the
            binary mask are then cropped to that box and resized to 224x224.

            @param rgb_img: a [3 x H x W] torch.FloatTensor
            @param initial_masks: a [H x W] torch tensor of mask labels

            @return: a dictionary: {'rgb' : rgb_crops [N x 3 x 224 x 224],
                                    'initial_masks' : mask_crops [N x 224 x 224]}
                     a dictionary: {mask_id.item() : [x_min, y_min, x_max, y_max]}
                     holding the padded crop box of each mask
        """
        _, H, W = rgb_img.shape
        new_size = (224, 224)  # fixed RRN input resolution
        padding_percentage = self.config['padding_percentage']

        # Dictionary to save crop indices
        crop_indices = {}

        mask_ids = torch.unique(initial_masks)
        mask_ids = mask_ids[mask_ids >= OBJECTS_LABEL]
        num_masks = mask_ids.shape[0]

        rgb_crops = torch.zeros((num_masks, 3) + new_size, device=self.device)
        mask_crops = torch.zeros((num_masks,) + new_size, device=self.device)

        for index, mask_id in enumerate(mask_ids):
            mask = (initial_masks == mask_id).float()  # Shape: [H x W]

            # Tight box around the mask, then pad it proportionally.
            x_min, y_min, x_max, y_max = util_.mask_to_tight_box(mask)
            x_padding = int(
                torch.round((x_max - x_min).float() *
                            padding_percentage).item())
            y_padding = int(
                torch.round((y_max - y_min).float() *
                            padding_percentage).item())

            # Pad and be careful of boundaries
            x_min = max(x_min - x_padding, 0)
            x_max = min(x_max + x_padding, W - 1)
            y_min = max(y_min - y_padding, 0)
            y_max = min(y_max + y_padding, H - 1)
            crop_indices[mask_id.item()] = [x_min, y_min, x_max, y_max]

            # Crop (box corners are inclusive, hence the +1).
            rgb_crop = rgb_img[:, y_min:y_max + 1,
                               x_min:x_max + 1]  # [3 x crop_H x crop_W]
            mask_crop = mask[y_min:y_max + 1,
                             x_min:x_max + 1]  # [crop_H x crop_W]

            # Resize. F.upsample_bilinear / F.upsample_nearest are deprecated;
            # upsample_bilinear is documented as interpolate(mode='bilinear',
            # align_corners=True), so align_corners=True keeps results identical.
            rgb_crops[index] = F.interpolate(
                rgb_crop.unsqueeze(0), size=new_size, mode='bilinear',
                align_corners=True)[0]  # Shape: [3 x new_H x new_W]
            mask_crops[index] = F.interpolate(
                mask_crop.unsqueeze(0).unsqueeze(0), size=new_size,
                mode='nearest')[0, 0]  # Shape: [new_H, new_W]

        batch = {'rgb': rgb_crops, 'initial_masks': mask_crops}
        return batch, crop_indices
Example #6
0
def concat_skip(inputs, skip, scale):
    """Upsample ``skip`` and concatenate it onto ``inputs`` along dim 1.

    ``skip`` is upsampled by ``scale`` with nearest-neighbour interpolation,
    then passed to ``centre_crop`` together with ``inputs.size()``
    (presumably cropping it to match ``inputs``; confirm against
    ``centre_crop``) before channel-wise concatenation.
    """
    # F.upsample_nearest is deprecated; interpolate(mode='nearest') replaces it.
    upscaled = F.interpolate(skip, scale_factor=scale, mode='nearest')
    upscaled = centre_crop(upscaled, inputs.size())

    return torch.cat([inputs, upscaled], 1)
Example #7
0
    def loss_ml(self,
                cls_scores,
                mask_preds,
                gt_bboxes,
                gt_labels,
                gt_masks,
                category_targets,
                point_ins,
                img_metas,
                cfg,
                gt_bboxes_ignore=None):
        """Compute the classification loss and the dice mask loss.

        Args:
            cls_scores: per-level class maps, each permuted/reshaped from
                (num_imgs, cls_out_channels, H, W) to (num_imgs, H*W, C).
            mask_preds: per-level mask predictions; ``mask_preds[j][i]`` is
                a (cells, h_j, w_j) tensor for image ``i`` at level ``j``.
            gt_bboxes, gt_labels, img_metas, cfg, gt_bboxes_ignore: unused.
            gt_masks: per-image ground-truth masks of shape (N, H, W). They
                are zero-padded to ``mask_preds[0]``'s size times
                ``self.strides[0]``. Padding is done on copies; the caller's
                list is left untouched.
            category_targets: per-image targets over the grid cells of all
                levels concatenated; non-zero entries are positives.
            point_ins: per-image instance index for every grid cell.

        Returns:
            dict(loss_cls=..., loss_mask=...). ``loss_mask`` is a zero tensor
            (still attached to ``mask_preds[0]``'s graph) when there are no
            positive cells, instead of dividing by zero.
        """
        assert len(cls_scores) == len(mask_preds)
        _, _, b_h, b_w = mask_preds[0].shape

        # bound[j]:bound[j + 1] is the slice of grid cells of level j.
        bound = [0]
        for grid in self.grid_num:
            bound.append(bound[-1] + grid ** 2)

        num_imgs = len(category_targets)
        pad_h = b_h * self.strides[0]
        pad_w = b_w * self.strides[0]
        padded_masks = []
        for i in range(num_imgs):
            _, i_h, i_w = gt_masks[i].shape
            # as_tensor accepts numpy arrays or tensors without a warning copy.
            padded_masks.append(
                F.pad(torch.as_tensor(gt_masks[i]),
                      (0, pad_w - i_w, 0, pad_h - i_h), value=0))

        # Flatten image first, then row-major cells within each level.
        flatten_cls_scores = torch.cat([
            cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1,
                                                  self.cls_out_channels)
            for cls_score in cls_scores
        ], dim=1).reshape(-1, self.cls_out_channels)

        # Mask loss over every (image, level) pair that has positives.
        loss_mask = 0
        iter_all = 0
        for i in range(num_imgs):
            gt_mask_i = padded_masks[i].float().unsqueeze(0)
            for j in range(len(self.grid_num)):
                ind = torch.nonzero(
                    category_targets[i][bound[j]:bound[j + 1]]).squeeze(-1)
                if ind.numel() == 0:
                    continue  # no positives: nothing to compare at this level
                ins_ind = point_ins[i][bound[j]:bound[j + 1]][ind]
                _, b_h_i, b_w_i = mask_preds[j][i].shape
                gt_masks_ = F.interpolate(gt_mask_i, size=(b_h_i, b_w_i),
                                          mode='nearest')[0]
                ins_mask = gt_masks_[ins_ind].to(mask_preds[0].device)

                pred_mask = torch.sigmoid(mask_preds[j][i][ind])
                loss_mask += self.dice_loss(pred_mask, ins_mask)
                iter_all += 1

        if iter_all == 0:
            # No positive cells anywhere: avoid ZeroDivisionError and return
            # a zero that stays connected to the prediction graph.
            loss_mask = mask_preds[0].sum() * 0
        else:
            # NOTE(review): normalised by num_imgs * iter_all although
            # iter_all already counts (image, level) pairs; kept as-is.
            loss_mask = self.dict_weight * (loss_mask / (num_imgs * iter_all))

        category_targets = torch.cat(category_targets)
        num_pos = (category_targets > 0).sum()
        loss_cls = self.loss_cls(flatten_cls_scores,
                                 category_targets,
                                 avg_factor=num_pos + num_imgs)

        return dict(loss_cls=loss_cls, loss_mask=loss_mask)
Example #8
0
 def forward(self, x):
     """Upsample ``x`` by a factor of 2 with nearest-neighbour interpolation."""
     from torch.nn import functional as F
     # F.upsample_nearest is deprecated; F.interpolate is its replacement.
     return F.interpolate(x, scale_factor=2, mode='nearest')
Example #9
0
    def forward(self, x, vars=None, bn_training=False, feature=False):
        """Functional forward pass driven by ``self.config``.

        ``self.config`` is a sequence of ``(name, param, extra_name)``
        triples. Layer weights are read from ``vars`` (default ``self.vars``)
        instead of module attributes. Fine-tuning can therefore pass fast
        weights without touching the initial parameters. BatchNorm running
        statistics come from ``self.vars_bn``; ``F.batch_norm`` only updates
        them when ``bn_training`` is true.

        :param x: input batch (the original docs list [b, 1, 28, 28]).
        :param vars: weight/bias tensors consumed in config order; ``None``
            means ``self.vars``.
        :param bn_training: forwarded as ``training`` to ``F.batch_norm``.
        :param feature: if true, return ``x`` as soon as a ``'rep'`` entry is
            reached (before the parameter-usage checks).
        :return: the network output, or the ``'rep'`` feature.
        :raises NotImplementedError: for an unknown layer name.
        """
        # While cat_var is set, every 'linear' output is collected in
        # cat_list; a 'cat' entry then concatenates them along dim 1.
        cat_var = False
        cat_list = []

        if vars is None:
            vars = self.vars

        idx = 0
        bn_idx = 0

        # Exceptions propagate to the caller. A previous bare ``except:``
        # printed the traceback and then returned a partially computed ``x``
        # (also swallowing KeyboardInterrupt).
        for name, param, extra_name in self.config:
            if name == 'conv2d':
                w, b = vars[idx], vars[idx + 1]
                x = F.conv2d(x, w, b, stride=param[4], padding=param[5])
                idx += 2

            elif name == 'convt2d':
                w, b = vars[idx], vars[idx + 1]
                x = F.conv_transpose2d(x, w, b,
                                       stride=param[4], padding=param[5])
                idx += 2

            elif name == 'linear':
                if extra_name == 'cosine':
                    # Cosine classifier: L2-normalised weights and inputs,
                    # no bias (consumes a single parameter).
                    w = F.normalize(vars[idx])
                    x = F.normalize(x)
                    x = F.linear(x, w)
                    idx += 1
                else:
                    w, b = vars[idx], vars[idx + 1]
                    x = F.linear(x, w, b)
                    idx += 2

                if cat_var:
                    cat_list.append(x)

            elif name == 'rep':
                if feature:
                    return x

            elif name == 'cat_start':
                cat_var = True
                cat_list = []

            elif name == 'cat':
                cat_var = False
                x = torch.cat(cat_list, dim=1)

            elif name == 'bn':
                w, b = vars[idx], vars[idx + 1]
                running_mean = self.vars_bn[bn_idx]
                running_var = self.vars_bn[bn_idx + 1]
                x = F.batch_norm(x,
                                 running_mean,
                                 running_var,
                                 weight=w,
                                 bias=b,
                                 training=bn_training)
                idx += 2
                bn_idx += 2

            elif name == 'flatten':
                x = x.view(x.size(0), -1)

            elif name == 'reshape':
                # e.g. [b, 8] => [b, 2, 2, 2]
                x = x.view(x.size(0), *param)

            elif name == 'relu':
                x = F.relu(x, inplace=param[0])

            elif name == 'leakyrelu':
                x = F.leaky_relu(x,
                                 negative_slope=param[0],
                                 inplace=param[1])

            elif name == 'tanh':
                x = torch.tanh(x)

            elif name == 'sigmoid':
                x = torch.sigmoid(x)

            elif name == 'upsample':
                x = F.interpolate(x, scale_factor=param[0], mode='nearest')

            elif name == 'max_pool2d':
                x = F.max_pool2d(x, param[0], param[1], param[2])

            elif name == 'avg_pool2d':
                x = F.avg_pool2d(x, param[0], param[1], param[2])

            else:
                raise NotImplementedError(
                    'unsupported layer type in config: {!r}'.format(name))

        # Make sure every parameter was consumed exactly once.
        assert idx == len(vars)
        assert bn_idx == len(self.vars_bn)

        return x
Example #10
0
    def forward(self, x, vars=None, bn_training=True, feature=False):
        """Functional forward pass, with an optional neuromodulated path.

        If ``self.Neuromodulation`` is true, each sample of ``x`` is viewed
        as [1, 3, 28, 28] and run through two fixed conv stacks. A
        neuromodulatory network produces a sigmoid mask of 2304 values. That
        mask gates the flattened features of a prediction network before the
        final linear layer. Parameters are read positionally from
        ``vars[0..27]`` and ``self.vars_bn[0..11]``, and BatchNorm always runs
        with ``training=True`` on this path.

        Otherwise the network is driven by ``self.config`` ``(name, param)``
        pairs. Weights are read from ``vars`` in order, and BatchNorm running
        statistics come from ``self.vars_bn`` (updated only when
        ``bn_training`` is true).

        :param x: input batch (the original docs list [b, 1, 28, 28]; the
            neuromodulated path needs 3*28*28 values per sample).
        :param vars: weight/bias tensors; ``None`` means ``self.vars``.
        :param bn_training: ``training`` flag for ``F.batch_norm`` on the
            config-driven path.
        :param feature: config path only; return ``x`` at the first ``'rep'``.
        :return: the network output.
        :raises ValueError: neuromodulated path called with an empty batch.
        :raises NotImplementedError: unknown layer name in ``self.config``.
        """
        if vars is None:
            vars = self.vars

        if self.Neuromodulation:

            def conv_bn_relu(t, v, r, pool, **conv_kwargs):
                # conv2d(vars[v], vars[v+1]) -> batch_norm(weight=vars[v+2],
                # bias=vars[v+3], running stats vars_bn[r], vars_bn[r+1],
                # training=True) -> ReLU -> optional 2x2/stride-2 maxpool.
                # conv2d and maxpool are helpers defined elsewhere in the module.
                t = conv2d(t, vars[v], vars[v + 1], **conv_kwargs)
                t = F.batch_norm(t,
                                 self.vars_bn[r],
                                 self.vars_bn[r + 1],
                                 weight=vars[v + 2],
                                 bias=vars[v + 3],
                                 training=True)
                t = F.relu(t)
                if pool:
                    t = maxpool(t, kernel_size=2, stride=2)
                return t

            predictions = []
            for i in range(x.size(0)):
                sample = x[i].view(1, 3, 28, 28)

                # =========== NEUROMODULATORY NETWORK ===========
                nm_data = conv_bn_relu(sample, 0, 0, pool=True)
                nm_data = conv_bn_relu(nm_data, 4, 2, pool=True)
                nm_data = conv_bn_relu(nm_data, 8, 4, pool=False)
                nm_data = nm_data.view(nm_data.size(0), 1008)
                fc_mask = torch.sigmoid(
                    F.linear(nm_data, vars[12], vars[13])).view(
                        nm_data.size(0), 2304)

                # =========== PREDICTION NETWORK ===========
                data = conv_bn_relu(sample, 14, 6, pool=True)
                data = conv_bn_relu(data, 18, 8, pool=True, stride=1)
                data = conv_bn_relu(data, 22, 10, pool=False, stride=1)
                data = data.view(data.size(0), 2304)
                data = data * fc_mask  # gate features with the NM mask
                predictions.append(F.linear(data, vars[26], vars[27]))

            if not predictions:
                raise ValueError(
                    'Neuromodulated forward needs a non-empty batch')
            # Concatenate once. Growing a tensor per sample inside a bare
            # try/except was quadratic and silently dropped earlier results
            # on any error.
            return torch.cat(predictions, dim=0)

        # While cat_var is set, every 'linear' output is collected in
        # cat_list; a 'cat' entry then concatenates them along dim 1.
        cat_var = False
        cat_list = []
        idx = 0
        bn_idx = 0

        for name, param in self.config:
            if name == 'conv2d':
                w, b = vars[idx], vars[idx + 1]
                x = F.conv2d(x, w, b, stride=param[4], padding=param[5])
                idx += 2
            elif name == 'convt2d':
                w, b = vars[idx], vars[idx + 1]
                x = F.conv_transpose2d(x, w, b,
                                       stride=param[4], padding=param[5])
                idx += 2
            elif name == 'linear':
                w, b = vars[idx], vars[idx + 1]
                x = F.linear(x, w, b)
                if cat_var:
                    cat_list.append(x)
                idx += 2
            elif name == 'rep':
                if feature:
                    return x
            elif name == 'cat_start':
                cat_var = True
                cat_list = []
            elif name == 'cat':
                cat_var = False
                x = torch.cat(cat_list, dim=1)
            elif name == 'bn':
                w, b = vars[idx], vars[idx + 1]
                running_mean = self.vars_bn[bn_idx]
                running_var = self.vars_bn[bn_idx + 1]
                x = F.batch_norm(x,
                                 running_mean,
                                 running_var,
                                 weight=w,
                                 bias=b,
                                 training=bn_training)
                idx += 2
                bn_idx += 2
            elif name == 'flatten':
                x = x.view(x.size(0), -1)
            elif name == 'reshape':
                # e.g. [b, 8] => [b, 2, 2, 2]
                x = x.view(x.size(0), *param)
            elif name == 'relu':
                x = F.relu(x, inplace=param[0])
            elif name == 'leakyrelu':
                x = F.leaky_relu(x,
                                 negative_slope=param[0],
                                 inplace=param[1])
            elif name == 'tanh':
                x = torch.tanh(x)
            elif name == 'sigmoid':
                x = torch.sigmoid(x)
            elif name == 'upsample':
                x = F.interpolate(x, scale_factor=param[0], mode='nearest')
            elif name == 'max_pool2d':
                x = F.max_pool2d(x, param[0], param[1], param[2])
            elif name == 'avg_pool2d':
                x = F.avg_pool2d(x, param[0], param[1], param[2])
            else:
                raise NotImplementedError(
                    'unsupported layer type in config: {!r}'.format(name))

        return x
Example #11
0
    def forward(self, x, y, vars=None, bn_training=True):
        """Functional forward pass driven by ``self.config``.

        Each ``(name, param)`` entry of ``self.config`` is applied in order,
        with weights read from ``vars`` (default ``self.vars``) so fine-tuning
        can pass fast weights. BatchNorm running statistics come from
        ``self.vars_bn`` and are only updated when ``bn_training`` is true.

        :param x: input batch (the original docs list [b, 512]).
        :param y: passed through unchanged (kept so callers can later vary the
            number of generated examples via ``y``).
        :param vars: weight/bias tensors consumed in config order; ``None``
            means ``self.vars``.
        :param bn_training: forwarded as ``training`` to ``F.batch_norm``.
        :return: ``(x, y)`` where ``x`` is the network output.
        :raises NotImplementedError: for an unknown layer name.
        """
        # Residual source for 'identity'; reset by 'update_identity'.
        x_orig = x

        if vars is None:
            vars = self.vars

        idx = 0
        bn_idx = 0

        # Layer names are compared with ``==``: ``is`` only worked for
        # interned string literals and failed for names built at runtime.
        for name, param in self.config:
            if name == 'conv2d':
                w, b = vars[idx], vars[idx + 1]
                x = F.conv2d(x, w, b, stride=param[4], padding=param[5])
                idx += 2
            elif name == 'convt2d':
                w, b = vars[idx], vars[idx + 1]
                x = F.conv_transpose2d(x, w, b,
                                       stride=param[4], padding=param[5])
                idx += 2
            elif name == 'linear':
                w, b = vars[idx], vars[idx + 1]
                x = F.linear(x, w, b)
                idx += 2
            elif name == 'bn':
                w, b = vars[idx], vars[idx + 1]
                running_mean = self.vars_bn[bn_idx]
                running_var = self.vars_bn[bn_idx + 1]
                x = F.batch_norm(x, running_mean, running_var,
                                 weight=w, bias=b, training=bn_training)
                idx += 2
                bn_idx += 2
            elif name == 'encode':
                # Flatten, then linear projection.
                x = x.view(x.size(0), -1)
                w, b = vars[idx], vars[idx + 1]
                x = F.linear(x, w, b)
                idx += 2
            elif name == 'decode':
                # Linear projection, then reshape to [b, 64, 28, 28].
                w, b = vars[idx], vars[idx + 1]
                x = F.linear(x, w, b)
                x = x.view(x.size(0), 64, 28, 28)
                idx += 2
            elif name == 'random_proj':
                # Generator-style noise injection, following
                # https://machinelearningmastery.com/how-to-develop-a-conditional-generative-adversarial-network-from-scratch/
                latent_dim, hw_out, rand_ch_out = param
                w_lat, b_lat = vars[idx], vars[idx + 1]
                # Sample Gaussian noise on the projection weights' device and
                # dtype. The legacy torch.cuda.FloatTensor was chosen whenever
                # CUDA existed, even for a model on the CPU.
                rand = torch.randn(x.size(0), latent_dim,
                                   device=w_lat.device, dtype=w_lat.dtype)
                rand = F.linear(rand, w_lat, b_lat)
                rand = F.leaky_relu(rand, 0.2)
                rand = rand.view(rand.size(0), rand_ch_out, hw_out, hw_out)
                # Append the projected noise as extra channels.
                x = torch.cat((x, rand), 1)
                idx += 2
            elif name == 'update_identity':
                x_orig = x
            elif name == 'identity':
                # Out-of-place residual add: ``x += x_orig`` could modify the
                # caller's input (or x_orig itself) in place.
                x = x + x_orig
            elif name == 'flatten':
                x = x.view(x.size(0), -1)
            elif name == 'reshape':
                # e.g. [b, 8] => [b, 2, 2, 2]
                x = x.view(x.size(0), *param)
            elif name == 'relu':
                x = F.relu(x, inplace=param[0])
            elif name == 'leakyrelu':
                x = F.leaky_relu(x, negative_slope=param[0], inplace=param[1])
            elif name == 'tanh':
                x = torch.tanh(x)
            elif name == 'sigmoid':
                x = torch.sigmoid(x)
            elif name == 'upsample':
                x = F.interpolate(x, scale_factor=param[0], mode='nearest')
            elif name == 'max_pool2d':
                x = F.max_pool2d(x, param[0], param[1], param[2])
            elif name == 'avg_pool2d':
                x = F.avg_pool2d(x, param[0], param[1], param[2])
            else:
                raise NotImplementedError(
                    'unsupported layer type in config: {!r}'.format(name))

        # Make sure every parameter was consumed exactly once.
        assert idx == len(vars)
        assert bn_idx == len(self.vars_bn)

        return x, y
def get_masks_for_training(
        mask_shapes: List[Tuple] =
        [(64, 128, 128), (128, 64, 64), (256, 32, 32), (512, 16, 16), (512, 8, 8), (4096,), (365,)],
        device: str = 'cpu', add_batch_size: bool = False,
        p_random_mask: float = 0.3) -> List[torch.Tensor]:
    '''
    Method returns random masks similar to 3.2. of the paper.

    The shapes are walked from the last entry of ``mask_shapes`` towards the first. One position is
    drawn uniformly at random:
    - every entry after it in walking order (earlier positions are visited first) is fully masked
      out (zeros);
    - the drawn entry is left unmasked (ones);
    - every entry visited after it gets either one shared random-shape mask (resized with
      nearest-neighbour interpolation to each resolution, drawn with probability ``p_random_mask``)
      or no mask (ones).
    A 1-tuple shape visited before any spatial mask was drawn gets a random binary mask instead.

    :param mask_shapes: (List[Tuple]) Shapes of the features generated by the vgg16 model. For shapes with
        more than one entry, ``mask_shape[1]`` and ``mask_shape[2]`` are used as the spatial size.
    :param device: (str) Device to store tensor masks
    :param add_batch_size: (bool) If true a batch size is added to each mask
    :param p_random_mask: (float) Probability that a random mask is generated else no mask is utilized
    :return: (List[torch.Tensor]) Generated masks for each feature tensor, in the same order as
        ``mask_shapes``. Spatial masks have shape (1, H, W) (the leading 1 is presumably broadcast over
        channels); 1-tuple shapes keep their shape. Empty ``mask_shapes`` gives an empty list.
    '''
    if len(mask_shapes) == 0:
        return []
    # Select the layer where no masking is used. Draw over the real number of layers: a hard-coded
    # range(7) would produce all-zero masks (or never select some layers) for custom mask_shapes.
    selected_layer = np.random.choice(range(len(mask_shapes)))
    masks = []
    random_mask = None
    random_mask_used = False
    for index, mask_shape in enumerate(reversed(mask_shapes)):
        if index < selected_layer:
            # Full mask case: this layer is zeroed out completely
            if len(mask_shape) > 1:
                masks.append(torch.zeros((1, mask_shape[1], mask_shape[2]), dtype=torch.float32, device=device))
            else:
                masks.append(torch.zeros(mask_shape, dtype=torch.float32, device=device))
        elif index == selected_layer:
            # No mask case: the selected layer passes through unchanged
            if len(mask_shape) > 1:
                masks.append(torch.ones((1, mask_shape[1], mask_shape[2]), dtype=torch.float32, device=device))
            else:
                masks.append(torch.ones(mask_shape, dtype=torch.float32, device=device))
        elif random_mask is None:
            # First layers after the selected one: draw the (shared) random mask once
            if len(mask_shape) > 2:
                if np.random.rand() < p_random_mask:
                    random_mask_used = True
                    # random_shapes is assumed to be skimage.draw.random_shapes; take one channel of the
                    # returned image, whose background is 255
                    shapes_image = random_shapes(mask_shape[1:],
                                                 min_shapes=1,
                                                 max_shapes=4,
                                                 min_size=min(8, mask_shape[1] // 2),
                                                 allow_overlap=True)[0][:, :, 0]
                    shapes_image = torch.tensor(shapes_image, dtype=torch.float32, device=device)[None, :, :]
                    # Background (255) -> 1 (kept), drawn shapes -> 0 (masked)
                    random_mask = (shapes_image == 255.0).float()
                else:
                    # Make no mask
                    random_mask = torch.ones(mask_shape[1:], dtype=torch.float32, device=device)[None, :, :]
                masks.append(random_mask)
            else:
                # Flat feature before any spatial mask exists: random binary mask
                masks.append(torch.randint(low=0, high=2, size=mask_shape, dtype=torch.float32, device=device))
        elif random_mask_used:
            # Reuse the random mask, resized (nearest neighbour) to this layer's resolution. random_mask is
            # already float32 on `device`, so no extra conversion is needed.
            masks.append(F.interpolate(random_mask[None, :, :, :], size=tuple(mask_shape[1:]), mode='nearest')[0])
        else:
            masks.append(torch.ones(mask_shape[1:], dtype=torch.float32, device=device)[None, :, :])
    # Add batch size dimension
    if add_batch_size:
        masks = [mask.unsqueeze(dim=0) for mask in masks]
    # Reverse order of masks to match the features of the vgg16 model
    masks.reverse()
    return masks
Example #13
0
 def forward(self, x):
     """Nearest-neighbour upsample ``x`` to 12x12, then by a further factor of 2 (24x24 output).

     Uses ``F.interpolate(mode='nearest')``, the non-deprecated equivalent of ``F.upsample_nearest``.
     """
     x = F.interpolate(x, size=(12, 12), mode='nearest')
     x = F.interpolate(x, scale_factor=2, mode='nearest')
     return x
Example #14
0
    def forward_decoder(self, h, vars=None):
        """
        Run the decoder half of the network, i.e. every config entry after the 'hidden' layer.

        Parameters are consumed from ``vars`` starting at ``self.hidden_var_idx``, and batch-norm
        running statistics from ``self.vars_bn`` starting at ``self.hidden_var_bn_idx``. Batch norm
        always runs with ``training=False`` here.

        :param h: hidden representation (presumably the output of ``forward_encoder``)
        :param vars: list of parameter tensors; defaults to ``self.vars``
        :return: decoder output, passed through a sigmoid when ``self.use_logits`` is set
        :raises NotImplementedError: for an unknown layer name or a 'hidden' entry inside the decoder
        :raises AssertionError: if not every parameter / running statistic was consumed
        """
        if vars is None:
            vars = self.vars

        # get decoder network config: everything after the 'hidden' entry
        decoder_config = self.config[self.hidden_config_idx + 1:]
        idx = self.hidden_var_idx
        bn_idx = self.hidden_var_bn_idx

        x = h
        for name, param in decoder_config:
            # Compare layer names by value: `name is 'conv2d'` tests object identity, which only
            # matches when both strings happen to be interned (and is a SyntaxWarning on 3.8+).
            if name == 'conv2d':
                w, b = vars[idx:(idx + 2)]
                x = F.conv2d(x, w, b, stride=param[4], padding=param[5])
                idx += 2
            elif name == 'convt2d':
                w, b = vars[idx:(idx + 2)]
                # remember to keep forward_encoder and forward_decoder synchronized!
                x = F.conv_transpose2d(x,
                                       w,
                                       b,
                                       stride=param[4],
                                       padding=param[5])
                idx += 2
            elif name == 'linear':
                w, b = vars[idx:(idx + 2)]
                x = F.linear(x, w, b)
                idx += 2
            elif name == 'bn':
                w, b = vars[idx:(idx + 2)]
                running_mean, running_var = self.vars_bn[bn_idx], self.vars_bn[bn_idx + 1]
                x = F.batch_norm(x,
                                 running_mean,
                                 running_var,
                                 weight=w,
                                 bias=b,
                                 training=False)
                idx += 2
                bn_idx += 2
            elif name == 'usigma_layer':
                # two parallel linear heads, concatenated along dim 1
                w1, b1 = vars[idx:(idx + 2)]
                w2, b2 = vars[idx + 2:(idx + 4)]
                x1 = F.linear(x, w1, b1)
                x2 = F.linear(x, w2, b2)
                x = torch.cat([x1, x2], dim=1)
                idx += 4
            elif name == 'flatten':
                x = x.view(x.size(0), -1)
            elif name == 'reshape':
                # e.g. [b, 8] => [b, 2, 2, 2]
                x = x.view(x.size(0), *param)
            elif name == 'tanh':
                x = torch.tanh(x)
            elif name == 'sigmoid':
                x = torch.sigmoid(x)
            elif name == 'relu':
                x = F.relu(x, inplace=param[0])
            elif name == 'leakyrelu':
                x = F.leaky_relu(x, negative_slope=param[0], inplace=param[1])
            elif name == 'upsample':
                x = F.interpolate(x, scale_factor=param[0], mode='nearest')
            elif name == 'max_pool2d':
                x = F.max_pool2d(x, param[0], param[1], param[2])
            elif name == 'avg_pool2d':
                x = F.avg_pool2d(x, param[0], param[1], param[2])
            elif name == 'use_logits':
                # flag entry only; the sigmoid is applied after the loop
                continue
            else:
                # unknown names, and a 'hidden' entry inside the decoder, are rejected
                raise NotImplementedError

        assert idx == len(vars)
        assert bn_idx == len(self.vars_bn)

        if self.use_logits:
            x = torch.sigmoid(x)

        return x
Example #15
0
    def forward_encoder(self, x, vars=None):
        """
        Run the encoder half of the network: config entries up to and including the 'hidden' layer.

        At the 'hidden' layer, if ``self.is_vae`` is set, ``x`` (expected 2-D) is split along dim 1 into
        ``q_mu`` and ``q_sigma``, and ``q_mu + q_sigma * eps`` with ``eps ~ N(0, I)`` is returned
        (reparameterization trick). Batch norm always runs with ``training=False`` here.

        :param x: input batch
        :param vars: list of parameter tensors; defaults to ``self.vars``
        :return: hidden representation
        :raises NotImplementedError: for an unknown layer name or a 'use_logits' entry before 'hidden'
        :raises AssertionError: if the consumed parameter counts do not match
            ``self.hidden_var_idx`` / ``self.hidden_var_bn_idx``
        """
        if vars is None:
            vars = self.vars

        idx = 0
        bn_idx = 0

        for name, param in self.config:
            # Compare layer names by value: `is` tests object identity and only matches interned strings.
            if name == 'conv2d':
                w, b = vars[idx:(idx + 2)]
                x = F.conv2d(x, w, b, stride=param[4], padding=param[5])
                idx += 2
            elif name == 'convt2d':
                w, b = vars[idx:(idx + 2)]
                # remember to keep forward_encoder and forward_decoder synchronized!
                x = F.conv_transpose2d(x,
                                       w,
                                       b,
                                       stride=param[4],
                                       padding=param[5])
                idx += 2
            elif name == 'linear':
                w, b = vars[idx:(idx + 2)]
                x = F.linear(x, w, b)
                idx += 2
            elif name == 'bn':
                w, b = vars[idx:(idx + 2)]
                running_mean, running_var = self.vars_bn[bn_idx], self.vars_bn[bn_idx + 1]
                x = F.batch_norm(x,
                                 running_mean,
                                 running_var,
                                 weight=w,
                                 bias=b,
                                 training=False)
                idx += 2
                bn_idx += 2
            elif name == 'usigma_layer':
                # two parallel linear heads, concatenated along dim 1
                w1, b1 = vars[idx:(idx + 2)]
                w2, b2 = vars[idx + 2:(idx + 4)]
                x1 = F.linear(x, w1, b1)
                x2 = F.linear(x, w2, b2)
                x = torch.cat([x1, x2], dim=1)
                idx += 4
            elif name == 'flatten':
                x = x.view(x.size(0), -1)
            elif name == 'reshape':
                # e.g. [b, 8] => [b, 2, 2, 2]
                x = x.view(x.size(0), *param)
            elif name == 'relu':
                x = F.relu(x, inplace=param[0])
            elif name == 'leakyrelu':
                x = F.leaky_relu(x, negative_slope=param[0], inplace=param[1])
            elif name == 'tanh':
                x = torch.tanh(x)
            elif name == 'sigmoid':
                x = torch.sigmoid(x)
            elif name == 'hidden':
                if self.is_vae:
                    # convert from h to q_h: split x into mu and sigma halves
                    assert len(x.shape) == 2
                    q_mu, q_sigma = x.chunk(2, dim=1)
                    # reparameterization trick
                    x = q_mu + q_sigma * torch.randn_like(q_sigma)
                # the encoder ends at the hidden layer
                break
            elif name == 'upsample':
                x = F.interpolate(x, scale_factor=param[0], mode='nearest')
            elif name == 'max_pool2d':
                x = F.max_pool2d(x, param[0], param[1], param[2])
            elif name == 'avg_pool2d':
                x = F.avg_pool2d(x, param[0], param[1], param[2])
            elif name == 'use_logits':
                raise NotImplementedError
            else:
                raise NotImplementedError

        assert idx == self.hidden_var_idx
        assert bn_idx == self.hidden_var_bn_idx

        return x
Example #16
0
    def forward(self, x, vars=None, update_bn_statistics=True):
        """
        Full forward pass (encoder + decoder) returning the reconstruction and its loss.

        This can be called during finetuning, where the batch-norm running_mean/running_var should not
        be updated: pass ``update_bn_statistics=False``. The bn weight/bias may still be trained, but they
        are read from ``vars`` (fast weights), so the initial parameters are not touched.

        :param x: input batch, e.g. [b, 1, 28, 28]; it is also the reconstruction target
        :param vars: list of parameter tensors; defaults to ``self.vars``
        :param update_bn_statistics: passed to ``F.batch_norm`` as ``training``
        :return: (x, loss, likelihood, kld); likelihood and kld are None when ``self.is_vae`` is False
        :raises ValueError: if ``self.is_vae`` is set but no 'hidden' layer was run
        :raises NotImplementedError: for an unknown layer name
        :raises AssertionError: if not every parameter / running statistic was consumed
        """
        if vars is None:
            vars = self.vars

        idx = 0
        bn_idx = 0
        x_in = x
        # set by the 'hidden' layer when self.is_vae; needed for the KL term below
        q_mu = q_sigma = None

        for name, param in self.config:
            # Compare layer names by value: `is` tests object identity and only matches interned strings.
            if name == 'conv2d':
                w, b = vars[idx:(idx + 2)]
                # remember to keep forward_encoder and forward_decoder synchronized!
                x = F.conv2d(x, w, b, stride=param[4], padding=param[5])
                idx += 2
            elif name == 'convt2d':
                w, b = vars[idx:(idx + 2)]
                # remember to keep forward_encoder and forward_decoder synchronized!
                x = F.conv_transpose2d(x,
                                       w,
                                       b,
                                       stride=param[4],
                                       padding=param[5])
                idx += 2
            elif name == 'linear':
                w, b = vars[idx:(idx + 2)]
                x = F.linear(x, w, b)
                idx += 2
            elif name == 'bn':
                w, b = vars[idx:(idx + 2)]
                running_mean, running_var = self.vars_bn[bn_idx:(bn_idx + 2)]
                x = F.batch_norm(x,
                                 running_mean,
                                 running_var,
                                 weight=w,
                                 bias=b,
                                 training=update_bn_statistics)
                idx += 2
                bn_idx += 2
            elif name == 'usigma_layer':
                # two parallel linear heads, concatenated along dim 1
                w1, b1 = vars[idx:(idx + 2)]
                w2, b2 = vars[idx + 2:(idx + 4)]
                x1 = F.linear(x, w1, b1)
                x2 = F.linear(x, w2, b2)
                x = torch.cat([x1, x2], dim=1)
                idx += 4
            elif name == 'flatten':
                x = x.view(x.size(0), -1)
            elif name == 'reshape':
                # e.g. [b, 8] => [b, 2, 2, 2]
                x = x.view(x.size(0), *param)
            elif name == 'relu':
                x = F.relu(x, inplace=param[0])
            elif name == 'leakyrelu':
                x = F.leaky_relu(x, negative_slope=param[0], inplace=param[1])
            elif name == 'tanh':
                x = torch.tanh(x)
            elif name == 'sigmoid':
                x = torch.sigmoid(x)
            elif name == 'hidden':
                if self.is_vae:
                    # convert from h to q_h: split x into mu and sigma halves
                    assert len(x.shape) == 2
                    q_mu, q_sigma = x.chunk(2, dim=1)
                    # reparameterization trick
                    x = q_mu + q_sigma * torch.randn_like(q_sigma)
            elif name == 'upsample':
                x = F.interpolate(x, scale_factor=param[0], mode='nearest')
            elif name == 'max_pool2d':
                x = F.max_pool2d(x, param[0], param[1], param[2])
            elif name == 'avg_pool2d':
                x = F.avg_pool2d(x, param[0], param[1], param[2])
            elif name == 'use_logits':
                # flag entry only; the sigmoid is applied below
                continue
            else:
                raise NotImplementedError

        # make sure variable is used properly
        assert idx == len(vars)
        assert bn_idx == len(self.vars_bn)

        if self.is_vae:
            if q_mu is None:
                raise ValueError("is_vae is set but the config has no 'hidden' layer")

            # likelihood is the negative loss.
            likelihood = -self.criteon(x, x_in)

            # see Appendix B from VAE paper:
            # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
            # https://arxiv.org/abs/1312.6114
            # KL(q || N(0, I)) = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2), here normalised by the
            # number of input elements; 1e-8 guards against log(0).
            kld = 0.5 * torch.sum(
                torch.pow(q_mu, 2) + torch.pow(q_sigma, 2) -
                torch.log(1e-8 + torch.pow(q_sigma, 2)) - 1) / np.prod(
                    x_in.shape)
            kld = self.beta * kld
            elbo = likelihood - kld
            loss = -elbo

            if self.use_logits:
                x = torch.sigmoid(x)

            return x, loss, likelihood, kld

        loss = self.criteon(x, x_in)

        if self.use_logits:
            x = torch.sigmoid(x)

        return x, loss, None, None
Example #17
0
    def forward(self, x, w):
        """Nearest-neighbour upsample both inputs.

        ``x`` is resized to 12x12 and then doubled (24x24). ``w`` is scaled by 2.976744 in both
        spatial dimensions. Uses ``F.interpolate(mode='nearest')``, the non-deprecated equivalent of
        ``F.upsample_nearest``.
        """
        x = F.interpolate(x, size=(12, 12), mode='nearest')
        x = F.interpolate(x, scale_factor=2, mode='nearest')

        w = F.interpolate(w, scale_factor=(2.976744, 2.976744), mode='nearest')
        return x, w
Example #18
0
    def get_embedded_vector(self, x, vars=None, bn_training=True, DEBUG=False):
        '''
        Compute the embedding of ``x``: apply the config layers in order, stopping at the first 'fc' entry.

        :param x: input batch
        :param vars: list of parameter tensors; defaults to ``self.vars``
        :param bn_training: passed to ``F.batch_norm`` as ``training``
        :param DEBUG: if truthy, print each layer's name, params and output shape
        :return: embedded vectors, e.g. (n_way * k_shot, 800) in the original setup
        :raises NotImplementedError: for an unknown layer name
        '''
        if vars is None:
            vars = self.vars

        idx = 0
        bn_idx = 0

        hidden = x

        for name, param in self.config:
            # Weights
            if name == 'conv2d':
                w, b = vars[idx], vars[idx + 1]
                # remember to keep forward_encoder and forward_decoder synchronized
                hidden = F.conv2d(hidden,
                                  w,
                                  b,
                                  stride=param[4],
                                  padding=param[5])
                idx += 2

                if DEBUG:
                    print(name, param, "shape: ", hidden.shape)

            elif name == 'convt2d':
                w, b = vars[idx], vars[idx + 1]
                # remember to keep forward_encoder and forward_decoder synchronized
                hidden = F.conv_transpose2d(hidden,
                                            w,
                                            b,
                                            stride=param[4],
                                            padding=param[5])
                idx += 2

                if DEBUG:
                    print(name, param, "shape: ", hidden.shape)

            elif name == 'fc':
                # the embedding is the input to the first fully-connected layer
                break

            elif name == 'bn':
                w, b = vars[idx], vars[idx + 1]
                running_mean, running_var = self.vars_bn[bn_idx], self.vars_bn[bn_idx + 1]

                hidden = F.batch_norm(hidden,
                                      running_mean,
                                      running_var,
                                      weight=w,
                                      bias=b,
                                      training=bn_training)
                idx += 2
                bn_idx += 2

                if DEBUG:
                    print(name, param, "shape: ", hidden.shape)

            # Activations
            elif name == 'flatten':
                if DEBUG:
                    print(name, param, "Before flatten shape: ", hidden.shape)

                # keep the first (data point) dimension, flatten the rest
                hidden = hidden.view(hidden.size(0), -1)

                if DEBUG:
                    print(name, param, "shape: ", hidden.shape)

            elif name == 'reshape':
                hidden = hidden.view(hidden.size(0), *param)

            elif name == 'relu':
                hidden = F.relu(hidden, inplace=param[0])

            elif name in ('leakyrelu', 'reakyrelu'):
                # 'leakyrelu' is the name every other forward in this file uses; the misspelled
                # 'reakyrelu' is still accepted so existing configs keep working.
                hidden = F.leaky_relu(hidden,
                                      negative_slope=param[0],
                                      inplace=param[1])

            elif name == 'tanh':
                hidden = torch.tanh(hidden)

            elif name == 'sigmoid':
                hidden = torch.sigmoid(hidden)

            elif name == 'upsample':
                hidden = F.interpolate(hidden, scale_factor=param[0], mode='nearest')

            elif name == 'max_pool2d':
                hidden = F.max_pool2d(hidden, param[0], param[1], param[2])

            elif name == 'avg_pool2d':
                hidden = F.avg_pool2d(hidden, param[0], param[1], param[2])

            else:
                raise NotImplementedError

        return hidden
Example #19
0
 def forward(self, x):
     """Apply bn1 -> relu -> conv1, optional dropout, then 2x nearest-neighbour upsampling.

     Dropout is applied only when ``self.droprate > 0``, and it is active only when
     ``self.training`` is set. ``F.interpolate(mode='nearest')`` replaces the deprecated
     ``F.upsample_nearest`` with identical results.
     """
     out = self.conv1(self.relu(self.bn1(x)))
     if self.droprate > 0:
         out = F.dropout(out, p=self.droprate, inplace=False, training=self.training)
     return F.interpolate(out, scale_factor=2, mode='nearest')
Example #20
0
    def forward(self, x, vars=None, bn_training=True):
        """
        Functional forward pass over ``self.config``, with optional SVDO layers.

        SVDO layers ('Conv2d_SVDO' / 'Linear_SVDO', presumably sparse variational dropout) read three
        tensors from ``vars``: the weight mean ``mu``, the bias ``b`` and ``log_sigma``.
        - ``log_alpha = 2*log_sigma - 2*log(1e-8 + |mu|)`` is computed and clamped to [-10, 10].
        - ``self.kl_reg_term(log_alpha)`` is added to ``self.kl_reg``.
        - In training mode the output is sampled with the local reparameterization trick (LRT).
        - In eval mode, weights whose ``log_alpha`` is not below ``self.threshold`` are zeroed.

        Side effects: resets ``self.kl_reg``, ``self.conv_idx``, ``self.lin_idx``,
        ``self.sp_layer_names``, ``self.sp_vals`` and ``self.dropout_vals``, then fills them.
        ``sp_vals`` gets the fraction of ``log_alpha > self.threshold`` per SVDO layer. The layer names
        and dropout values are recorded only in training mode.

        :param x: input batch
        :param vars: list of parameter tensors; defaults to ``self.vars``
        :param bn_training: passed to ``F.batch_norm`` as ``training``
        :return: network output
        :raises NotImplementedError: for an unknown layer name (the name is also printed)
        :raises AssertionError: if not every parameter / running statistic was consumed
        """
        if vars is None:
            vars = self.vars
        idx = 0
        bn_idx = 0
        self.kl_reg = 0.0
        self.conv_idx, self.lin_idx = 0, 0
        self.sp_layer_names = []
        self.sp_vals = []
        self.dropout_vals = {}

        for ii, (name, param) in enumerate(self.config):
            # Compare layer names by value: `is` tests object identity and only matches interned strings.
            if name == 'Conv2d':
                w, b = vars[idx], vars[idx + 1]
                # remember to keep forward_encoder and forward_decoder synchronized!
                x = F.conv2d(x, w, b, stride=param[4], padding=param[5])
                idx += 2
            elif name == 'convt2d':
                w, b = vars[idx], vars[idx + 1]
                # remember to keep forward_encoder and forward_decoder synchronized!
                x = F.conv_transpose2d(x,
                                       w,
                                       b,
                                       stride=param[4],
                                       padding=param[5])
                idx += 2

            elif name in ('Conv2d_SVDO', 'Linear_SVDO'):
                mu, b, log_sigma = vars[idx], vars[idx + 1], vars[idx + 2]

                log_alpha = log_sigma * 2.0 - 2.0 * torch.log(1e-8 +
                                                              torch.abs(mu))
                log_alpha = torch.clamp(log_alpha, -10, 10)

                # NaN diagnostics
                if torch.isnan(mu).any():
                    print('mu', name)

                if torch.isnan(log_sigma).any():
                    print('log_sigma', name)

                if torch.isnan(log_alpha).any():
                    print('Log Alpha', name)
                if torch.isnan(x).any():
                    print('Name', name)
                    print(x)
                    # NOTE(review): this layer's parameters are not consumed after the break, so the
                    # idx assertion after the loop is expected to fail.
                    break

                self.kl_reg += self.kl_reg_term(log_alpha)

                sp_val = (log_alpha.cpu().data.numpy() > self.threshold).mean()
                self.sp_vals.append(sp_val)
                if self.training:
                    if name == 'Conv2d_SVDO':
                        layer_name = 'Conv-' + str(self.conv_idx + 1)
                        self.sp_layer_names.append(layer_name)
                        self.dropout_vals[layer_name] = log_alpha.cpu(
                        ).data.numpy()

                        self.conv_idx += 1
                        # NOTE(review): b is added by broadcasting to the conv output; assumes b is
                        # shaped to broadcast over channels — confirm against where vars are created.
                        lrt_mean = F.conv2d(x,
                                            mu,
                                            bias=None,
                                            stride=param[4],
                                            padding=param[5]) + b
                        lrt_std = torch.sqrt(
                            F.conv2d(x * x,
                                     torch.exp(log_sigma * 2.0) + 1e-8,
                                     bias=None,
                                     stride=param[4],
                                     padding=param[5]))
                    else:
                        layer_name = 'Linear-' + str(self.lin_idx + 1)
                        self.sp_layer_names.append(layer_name)
                        self.dropout_vals[layer_name] = log_alpha.cpu(
                        ).data.numpy()

                        self.lin_idx += 1
                        lrt_mean = F.linear(
                            x, mu) + b  # compute mean activation using LRT
                        lrt_std = torch.sqrt(
                            F.linear(x * x,
                                     torch.exp(log_sigma * 2.0) + 1e-18))

                    # sample the activations: mean + std * N(0, 1)
                    eps = lrt_std.data.new(lrt_std.size()).normal_()
                    x = lrt_mean + lrt_std * eps
                else:
                    # eval: deterministic forward with weights above the threshold zeroed
                    if name == 'Conv2d_SVDO':
                        x = F.conv2d(x,
                                     mu * (log_alpha < self.threshold).float(),
                                     bias=None,
                                     stride=param[4],
                                     padding=param[5]) + b
                    else:
                        x = F.linear(x,
                                     mu *
                                     (log_alpha < self.threshold).float()) + b
                idx += 3

            elif name == 'Linear':
                w, b = vars[idx], vars[idx + 1]
                x = F.linear(x, w, b)
                idx += 2
            elif name == 'bn':
                w, b = vars[idx], vars[idx + 1]
                running_mean, running_var = self.vars_bn[bn_idx], self.vars_bn[bn_idx + 1]
                x = F.batch_norm(x,
                                 running_mean,
                                 running_var,
                                 weight=w,
                                 bias=b,
                                 training=bn_training)
                idx += 2
                bn_idx += 2

            elif name == 'flatten':
                x = x.view(x.size(0), -1)
            elif name == 'reshape':
                # e.g. [b, 8] => [b, 2, 2, 2]
                x = x.view(x.size(0), *param)
            elif name == 'relu':
                x = F.relu(x, inplace=param[0])
            elif name == 'leakyrelu':
                x = F.leaky_relu(x, negative_slope=param[0], inplace=param[1])
            elif name == 'tanh':
                x = torch.tanh(x)
            elif name == 'sigmoid':
                x = torch.sigmoid(x)
            elif name == 'upsample':
                x = F.interpolate(x, scale_factor=param[0], mode='nearest')
            elif name == 'MaxPool2d':
                x = F.max_pool2d(x, param[0], param[1], param[2])
            elif name == 'avg_pool2d':
                x = F.avg_pool2d(x, param[0], param[1], param[2])

            else:
                print(name)
                raise NotImplementedError(name)

        # make sure variable is used properly
        assert idx == len(vars)
        assert bn_idx == len(self.vars_bn)

        return x
Example #21
0
    def run_on_batch(self, batch):
        """ Run algorithm on batch of images in eval mode

            @param batch: a dictionary with the following keys:
                            - rgb: a [N x 3 x H x W] torch.FloatTensor
                            - xyz: a [N x 3 x H x W] torch.FloatTensor
            @param final_close_morphology: If True, then run open/close morphology after refining mask.
                                           This typically helps a synthetically-trained RRN
        """
        N, _, H, W = batch['rgb'].shape

        # Run the Depth Seeding Network. Note: this will send "batch" to device (e.g. GPU)
        fg_masks, direction_predictions, object_centers, initial_masks = self.dsn.run_on_batch(
            batch)
        # fg_masks: a [N x H x W] torch.LongTensor with values in {0, 1, 2}
        # direction_predictions: a [N x 2 x H x W] torch.FloatTensor
        # object_centers: a list of [2 x num_objects] torch.IntTensor. This list has length N
        # initial_masks: a [N x H x W] torch.IntTensor. Note: Initial masks has values in [0, 2, 3, ...]. No table

        initial_masks = self.process_initial_masks(batch, initial_masks,
                                                   object_centers)

        # Data structure to hold everything at end
        refined_masks = torch.zeros_like(initial_masks)
        for i in range(N):

            # Dictionary to save crop indices
            crop_indices = {}

            mask_ids = torch.unique(initial_masks[i])
            if mask_ids[0] == 0:
                mask_ids = mask_ids[1:]
            rgb_crops = torch.zeros((mask_ids.shape[0], 3, 224, 224),
                                    device=self.device)
            mask_crops = torch.zeros((mask_ids.shape[0], 224, 224),
                                     device=self.device)

            for index, mask_id in enumerate(mask_ids):
                mask = (initial_masks[i] == mask_id).float()  # Shape: [H x W]

                # crop the masks/rgb to 224x224 with some padding, save it as "initial_masks"
                x_min, y_min, x_max, y_max = util_.mask_to_tight_box(mask)
                x_padding = int(
                    torch.round((x_max - x_min).float() *
                                self.params['padding_percentage']).item())
                y_padding = int(
                    torch.round((y_max - y_min).float() *
                                self.params['padding_percentage']).item())

                # Pad and be careful of boundaries
                x_min = max(x_min - x_padding, 0)
                x_max = min(x_max + x_padding, W - 1)
                y_min = max(y_min - y_padding, 0)
                y_max = min(y_max + y_padding, H - 1)
                crop_indices[mask_id.item()] = [x_min, y_min, x_max,
                                                y_max]  # save crop indices

                # Crop
                rgb_crop = batch['rgb'][i, :, y_min:y_max + 1, x_min:x_max +
                                        1]  # [3 x crop_H x crop_W]
                mask_crop = mask[y_min:y_max + 1,
                                 x_min:x_max + 1]  # [crop_H x crop_W]

                # Resize
                new_size = (224, 224)
                rgb_crop = F.upsample_bilinear(
                    rgb_crop.unsqueeze(0),
                    new_size)[0]  # Shape: [3 x new_H x new_W]
                rgb_crops[index] = rgb_crop
                mask_crop = F.upsample_nearest(
                    mask_crop.unsqueeze(0).unsqueeze(0),
                    new_size)[0, 0]  # Shape: [new_H, new_W]
                mask_crops[index] = mask_crop

            # Run the RGB Refinement Network
            if mask_ids.shape[
                    0] > 0:  # only run if you actually have masks to refine...

                new_batch = {'rgb': rgb_crops, 'initial_masks': mask_crops}
                refined_crops = self.rrn.run_on_batch(
                    new_batch)  # Shape: [num_masks x new_H x new_W]

            # resize the results to the original size. Order this by average depth (highest to lowest)
            sorted_mask_ids = []
            for index, mask_id in enumerate(mask_ids):

                # Resize back to original size
                x_min, y_min, x_max, y_max = crop_indices[mask_id.item()]
                orig_H = y_max - y_min + 1
                orig_W = x_max - x_min + 1
                mask = refined_crops[index].unsqueeze(0).unsqueeze(0).float()
                resized_mask = F.upsample_nearest(mask, (orig_H, orig_W))[0, 0]

                # Calculate average depth
                h_idx, w_idx = torch.nonzero(resized_mask).t()
                avg_depth = torch.mean(batch['xyz'][i, 2, y_min:y_max + 1,
                                                    x_min:x_max + 1][h_idx,
                                                                     w_idx])
                sorted_mask_ids.append((index, mask_id, avg_depth))

            sorted_mask_ids = sorted(sorted_mask_ids,
                                     key=lambda x: x[2],
                                     reverse=True)
            sorted_mask_ids = [x[:2] for x in sorted_mask_ids
                               ]  # list of tuples: (index, mask_id)

            for index, mask_id in sorted_mask_ids:

                # Resize back to original size
                x_min, y_min, x_max, y_max = crop_indices[mask_id.item()]
                orig_H = y_max - y_min + 1
                orig_W = x_max - x_min + 1
                mask = refined_crops[index].unsqueeze(0).unsqueeze(0).float()
                resized_mask = F.upsample_nearest(mask, (orig_H, orig_W))[0, 0]

                # Set refined mask
                h_idx, w_idx = torch.nonzero(resized_mask).t()
                refined_masks[i, y_min:y_max + 1,
                              x_min:x_max + 1][h_idx, w_idx] = mask_id

        # Open/close morphology stuff, for synthetically-trained RRN
        if self.params['final_close_morphology']:
            refined_masks = refined_masks.cpu().numpy()  # to CPU

            for i in range(N):

                # Get object ids. Remove background (0)
                obj_ids = np.unique(refined_masks[i])
                if obj_ids[0] == 0:
                    obj_ids = obj_ids[1:]

                # For each object id, open/close the masks
                for obj_id in obj_ids:
                    mask = (refined_masks[i] == obj_id)  # Shape: [H x W]

                    ksize = self.params['open_close_morphology_ksize']  # 9
                    opened_mask = cv2.morphologyEx(
                        mask.astype(np.uint8), cv2.MORPH_OPEN,
                        cv2.getStructuringElement(cv2.MORPH_ELLIPSE,
                                                  (ksize, ksize)))
                    opened_closed_mask = cv2.morphologyEx(
                        opened_mask, cv2.MORPH_CLOSE,
                        cv2.getStructuringElement(cv2.MORPH_ELLIPSE,
                                                  (ksize, ksize)))

                    h_idx, w_idx = np.nonzero(mask)
                    refined_masks[i, h_idx, w_idx] = 0
                    h_idx, w_idx = np.nonzero(opened_closed_mask)
                    refined_masks[i, h_idx, w_idx] = obj_id

            refined_masks = torch.from_numpy(refined_masks).to(
                self.device)  # back to GPU

        return fg_masks, direction_predictions, initial_masks, refined_masks
 def forward(self, x):
     """Fuse two branches: nearest-upsample ``branch0(x)`` by 2 and add ``branch1(x)``.

     ``branch0``'s output must be half the spatial size of ``branch1``'s
     output for the sum to broadcast.
     """
     upsampled = F.interpolate(self.branch0(x), scale_factor=2, mode='nearest')
     return upsampled + self.branch1(x)
Example #23
0
    def forwardaux(self, x):
        """Normalise ``x`` and run the encoder/decoder, also returning gradient-door outputs.

        ``x`` is divided by 255, then each of its first three channels ``c``
        becomes ``(x[:, c] - transform[0][c]) / transform[1][c]``. Four
        max-pool encoder stages follow, plus a bottleneck stage, and then a
        decoder that nearest-upsamples by 2 and concatenates the matching
        encoder feature at every level. Every conv except the last is
        followed by ``leaky_relu``.

        :return: ``(conv11d output, (grad_2, grad_4, grad_8, grad_16))`` where
            ``grad_k`` is ``self.gradientdoor<k>`` applied to the decoder (or
            bottleneck, for 16) feature at that level.
        """
        mean, std = self.transform[0], self.transform[1]
        x = x / 255
        x = torch.stack(
            [(x[:, c, :, :] - mean[c]) / std[c] for c in range(3)], dim=1)

        act = F.leaky_relu

        def stage(feat, *convs):
            # conv -> leaky_relu for each conv, in order
            for conv in convs:
                feat = act(conv(feat))
            return feat

        def up(feat):
            return F.interpolate(feat, scale_factor=2, mode='nearest')

        def down(feat):
            return F.max_pool2d(feat, kernel_size=2, stride=2)

        # encoder
        enc1 = stage(x, self.conv11, self.conv12)
        enc2 = stage(down(enc1), self.conv21, self.conv22)
        enc3 = stage(down(enc2), self.conv31, self.conv32, self.conv33)
        enc4 = stage(down(enc3), self.conv41, self.conv42, self.conv43)
        bottleneck = stage(down(enc4), self.conv51, self.conv52, self.conv53)

        # decoder, tapping a gradient door at every scale
        grad_16 = self.gradientdoor16(bottleneck)
        dec4 = stage(torch.cat((up(bottleneck), enc4), 1),
                     self.conv43d, self.conv42d, self.conv41d)
        grad_8 = self.gradientdoor8(dec4)
        dec3 = stage(torch.cat((up(dec4), enc3), 1),
                     self.conv33d, self.conv32d, self.conv31d)
        grad_4 = self.gradientdoor4(dec3)
        dec2 = stage(torch.cat((up(dec3), enc2), 1), self.conv22d, self.conv21d)
        grad_2 = self.gradientdoor2(dec2)
        dec1 = stage(torch.cat((up(dec2), enc1), 1), self.conv12d)

        return self.conv11d(dec1), (grad_2, grad_4, grad_8, grad_16)
def match_label_crop(initial_masks, labels_crop, out_label_crop, rois,
                     depth_crop):
    """Merge per-crop labels back into one full-size label map.

    1. In every crop ``i``, a local segment whose pixels overlap
       ``out_label_crop[i]`` by less than 50% is relabelled ``-1``. This
       mutates ``labels_crop`` in place.
    2. Crops are ordered for painting. With ``depth_crop``, the order is by
       descending mean positive value of channel 2, presumably depth. Without
       it, the order is by descending ROI area. Later crops overwrite earlier
       ones, so the nearest/smallest crop ends up on top.
    3. Each crop's surviving segments get consecutive ids starting at 1. The
       crop is resized (nearest) to its ROI and its non-zero pixels are
       written into ``refined_masks[0]``.

    :param initial_masks: tensor whose ``zeros_like`` becomes the output;
        only batch index 0 is written (presumably shape [1, H, W]).
    :param labels_crop: per-crop local label maps, indexed ``labels_crop[i]``.
    :param out_label_crop: per-crop maps multiplied with each segment mask to
        measure overlap.
    :param rois: per-crop boxes; columns 0-3 are read as
        ``x_min, y_min, x_max, y_max`` (inclusive).
    :param depth_crop: per-crop tensors whose channel 2 is used for ordering,
        or ``None`` to order by ROI area.
    :return: ``(refined_masks, labels_crop)``.
    """
    num = labels_crop.shape[0]

    # 1) drop local segments that mostly fall outside the crop's label
    for i in range(num):
        for mask_id in torch.unique(labels_crop[i]):
            mask = (labels_crop[i] == mask_id).float()
            percentage = torch.sum(mask * out_label_crop[i]) / torch.sum(mask)
            if percentage < 0.5:
                labels_crop[i][labels_crop[i] == mask_id] = -1

    # 2) painting order (first painted = bottom)
    sort_keys = []
    for i in range(num):
        if depth_crop is not None:
            valid = labels_crop[i] > -1
            roi_depth = depth_crop[i, 2][valid] if torch.any(valid) else depth_crop[i, 2]
            positive = roi_depth[roi_depth > 0]
            # The mean of an empty tensor is NaN, and NaN sort keys make the
            # order undefined. Crops without any valid depth are treated as
            # farthest, so they are painted first (underneath).
            key = positive.mean().item() if positive.numel() > 0 else float('inf')
        else:
            orig_H = rois[i, 3] - rois[i, 1] + 1
            orig_W = rois[i, 2] - rois[i, 0] + 1
            key = (orig_H * orig_W).item()
        sort_keys.append((i, key))

    order = [i for i, _ in sorted(sort_keys, key=lambda item: item[1], reverse=True)]

    # 3) relabel each crop with global ids and paste it into place
    refined_masks = torch.zeros_like(initial_masks).float()
    count = 0
    for index in order:
        mask_ids = torch.unique(labels_crop[index])
        if mask_ids[0] == -1:  # unique() is sorted, so -1 comes first
            mask_ids = mask_ids[1:]

        label_crop = torch.zeros_like(labels_crop[index])
        for mask_id in mask_ids:
            count += 1
            label_crop[labels_crop[index] == mask_id] = count

        x_min, y_min, x_max, y_max = (int(v) for v in rois[index, :4].tolist())
        orig_H = y_max - y_min + 1
        orig_W = x_max - x_min + 1
        mask = label_crop.unsqueeze(0).unsqueeze(0).float()
        resized_mask = F.interpolate(mask, size=(orig_H, orig_W), mode='nearest')[0, 0]

        # write only labelled (non-zero) pixels so earlier crops show through
        region = refined_masks[0, y_min:y_max + 1, x_min:x_max + 1]
        h_idx, w_idx = torch.nonzero(resized_mask, as_tuple=True)
        region[h_idx, w_idx] = resized_mask[h_idx, w_idx].cpu()

    return refined_masks, labels_crop
Example #25
0
    def forward(self, boxes, labels, masks):
        """Build grid targets (category label and instance index) for a batch.

        For each image, an instance is assigned to every level ``i`` whose
        ``self.scale_ranges[i]`` contains ``sqrt(box_h * box_w)``. On that
        level's ``fpn_size[i] x fpn_size[i]`` grid, two sets of cells receive
        the instance's label and its index (the instance's position within the
        image):
        - the cells hit by the instance mask, cropped to a box of half-size
          ``sigma * box`` around the mask centre and nearest-downsampled;
        - the cells of a smaller centre box using ``sigma / 2``.

        :param boxes: per-image tensors; columns 0-3 are read as ``x1, y1, x2, y2``.
        :param labels: per-image tensors of category labels, one per box.
        :param masks: per-image mask tensors; ``masks[0].shape`` is unpacked as
            ``(num_instances, img_h, img_w)``.
        :return: ``(category_targets_batch, point_ins_batch)``. Each holds, per
            image, one flattened int64 tensor with all levels concatenated in
            order. Unassigned cells hold 0 (category) and -1 (instance index).
        """
        n = len(boxes)
        _, img_h, img_w = masks[0].shape
        device = boxes[0].device
        # Keep sigma in locals: mutating self.sigma here would shrink it on
        # every image and every call.
        sigma = self.sigma
        inner_sigma = self.sigma / 2.0
        category_targets_batch, point_ins_batch = [], []

        for bn in range(n):
            boxes_current = boxes[bn]
            label_current = labels[bn]
            mask_current = masks[bn]

            # one target grid per FPN level; 0 / -1 mean "unassigned"
            category_targets_current = [
                torch.zeros((size, size), dtype=torch.int64) for size in self.fpn_size
            ]
            point_ins_current = [
                torch.full((size, size), -1.0) for size in self.fpn_size
            ]

            obj_num = boxes_current.shape[0]
            x1, y1, x2, y2 = boxes_current[:, :4].unbind(dim=1)
            hl = (y2 - y1) + 1
            wl = (x2 - x1) + 1
            gt_areas = torch.sqrt(hl * wl)

            # Instance centre = mean (row, col) of its mask pixels. If the
            # mask is empty, fall back to the box centre.
            masks_center = torch.zeros((mask_current.shape[0], 2)).to(device)
            for k in range(mask_current.shape[0]):
                rows, cols = torch.nonzero(mask_current[k], as_tuple=True)
                if rows.numel() > 0:
                    masks_center[k, 0] = rows.float().mean()
                    masks_center[k, 1] = cols.float().mean()
                else:
                    masks_center[k, 0] = 0.5 * (y1[k] + y2[k])
                    masks_center[k, 1] = 0.5 * (x1[k] + x2[k])

            x_mean = masks_center[:, 1]
            y_mean = masks_center[:, 0]

            # outer region (sigma): used to crop each instance mask
            left_raw_l = (x_mean - sigma * wl).clamp(0, img_w - 1)
            right_raw_l = (x_mean + sigma * wl).clamp(0, img_w - 1)
            top_raw_l = (y_mean - sigma * hl).clamp(0, img_h - 1)
            bottom_raw_l = (y_mean + sigma * hl).clamp(0, img_h - 1)

            # inner region (sigma / 2): centre box that is always marked
            left_raw = (x_mean - inner_sigma * wl).clamp(0, img_w - 1)
            right_raw = (x_mean + inner_sigma * wl).clamp(0, img_w - 1)
            top_raw = (y_mean - inner_sigma * hl).clamp(0, img_h - 1)
            bottom_raw = (y_mean + inner_sigma * hl).clamp(0, img_h - 1)

            ins_list = torch.arange(obj_num, dtype=torch.float32)

            for i, scale_range in enumerate(self.scale_ranges):
                hit_indices = ((gt_areas >= scale_range[0]) &
                               (gt_areas <= scale_range[1])).nonzero()
                if len(hit_indices) == 0:
                    continue

                hit_indices = hit_indices[:, 0]
                # Largest instances first, so smaller ones are written last (on top).
                hit_indices = hit_indices[torch.sort(-gt_areas[hit_indices])[1]]
                grid = self.fpn_size[i]
                h, w = img_h / grid, img_w / grid

                pos_category = label_current[hit_indices]
                pos_mask = mask_current[hit_indices].float()
                # Row j of pos_mask is instance hit_indices[j], so the crop
                # bounds must be gathered with hit_indices too.
                top_l = top_raw_l[hit_indices].long()
                bottom_l = bottom_raw_l[hit_indices].long()
                left_l = left_raw_l[hit_indices].long()
                right_l = right_raw_l[hit_indices].long()
                for j in range(len(hit_indices)):
                    pos_mask[j][:top_l[j]] = 0
                    pos_mask[j][bottom_l[j]:] = 0
                    pos_mask[j][:, :left_l[j]] = 0
                    pos_mask[j][:, right_l[j]:] = 0
                pos_mask2cat = F.interpolate(pos_mask.unsqueeze(0),
                                             size=(grid, grid),
                                             mode='nearest')[0].long()

                pos_instance = ins_list[hit_indices].tolist()

                # centre box in grid-cell coordinates (inclusive bounds)
                pos_left = torch.floor(left_raw[hit_indices] / w).clamp(
                    0, grid - 1).type(torch.int)
                pos_right = torch.floor(right_raw[hit_indices] / w).clamp(
                    0, grid - 1).type(torch.int)
                pos_top = torch.floor(top_raw[hit_indices] / h).clamp(
                    0, grid - 1).type(torch.int)
                pos_bottom = torch.floor(bottom_raw[hit_indices] / h).clamp(
                    0, grid - 1).type(torch.int)

                for j in range(len(hit_indices)):
                    mask_vindex = torch.nonzero(pos_mask2cat[j], as_tuple=True)
                    rows = slice(int(pos_top[j]), int(pos_bottom[j]) + 1)
                    cols = slice(int(pos_left[j]), int(pos_right[j]) + 1)
                    try:
                        category_targets_current[i][mask_vindex] = pos_category[j]
                        # in case small object vanishes after downsampling
                        category_targets_current[i][rows, cols] = pos_category[j]
                    except (IndexError, RuntimeError):
                        # best effort: report and continue with the next instance
                        print(masks_center)
                    point_ins_current[i][mask_vindex] = pos_instance[j]
                    point_ins_current[i][rows, cols] = pos_instance[j]

            category_targets_batch.append(
                torch.cat([t.flatten() for t in category_targets_current],
                          dim=0).type(torch.int64).to(device))
            point_ins_batch.append(
                torch.cat([t.flatten() for t in point_ins_current],
                          dim=0).type(torch.int64).to(device))

        return category_targets_batch, point_ins_batch
def crop_rois(rgb, initial_masks, depth):
    """Cut a padded window around every object of ``initial_masks[0]`` and resize it.

    For each non-zero id in ``initial_masks[0]``, the tight box from
    ``util_.mask_to_tight_box`` is grown by 25% of its width/height on each
    side and clipped to the image. The region of ``rgb[0]``, the object's
    binary mask and (if given) ``depth[0]`` inside that box are then resized
    to ``cfg.TRAIN.SYN_CROP_SIZE`` squared: bilinear with
    ``align_corners=True`` for rgb/depth, nearest for the mask.

    :return: ``(rgb_crops, mask_crops, rois, depth_crops)``. ``rois[k]`` holds
        ``x_min, y_min, x_max, y_max`` (inclusive), and ``depth_crops`` is
        ``None`` when ``depth`` is ``None``. All allocated on ``cfg.device``.
    """
    _, H, W = initial_masks.shape
    crop_size = cfg.TRAIN.SYN_CROP_SIZE
    padding_percentage = 0.25
    out_size = (crop_size, crop_size)

    object_ids = torch.unique(initial_masks[0])
    if object_ids[0] == 0:  # background
        object_ids = object_ids[1:]
    num = object_ids.shape[0]

    device = cfg.device
    rgb_crops = torch.zeros((num, 3, crop_size, crop_size), device=device)
    rois = torch.zeros((num, 4), device=device)
    mask_crops = torch.zeros((num, crop_size, crop_size), device=device)
    depth_crops = (torch.zeros((num, 3, crop_size, crop_size), device=device)
                   if depth is not None else None)

    for k, obj_id in enumerate(object_ids):
        obj_mask = (initial_masks[0] == obj_id).float()  # [H x W]
        x_min, y_min, x_max, y_max = util_.mask_to_tight_box(obj_mask)
        pad_x = int(torch.round((x_max - x_min).float() * padding_percentage).item())
        pad_y = int(torch.round((y_max - y_min).float() * padding_percentage).item())

        # grow the box, staying inside the image
        x_min = max(x_min - pad_x, 0)
        x_max = min(x_max + pad_x, W - 1)
        y_min = max(y_min - pad_y, 0)
        y_max = min(y_max + pad_y, H - 1)
        for col, value in enumerate((x_min, y_min, x_max, y_max)):
            rois[k, col] = value

        rows = slice(y_min, y_max + 1)
        cols = slice(x_min, x_max + 1)

        rgb_window = rgb[0, :, rows, cols].unsqueeze(0)  # [1 x 3 x crop_H x crop_W]
        rgb_crops[k] = F.interpolate(rgb_window, size=out_size, mode='bilinear',
                                     align_corners=True)[0]
        mask_window = obj_mask[rows, cols].unsqueeze(0).unsqueeze(0)
        mask_crops[k] = F.interpolate(mask_window, size=out_size, mode='nearest')[0, 0]
        if depth is not None:
            depth_window = depth[0, :, rows, cols].unsqueeze(0)
            depth_crops[k] = F.interpolate(depth_window, size=out_size, mode='bilinear',
                                           align_corners=True)[0]

    return rgb_crops, mask_crops, rois, depth_crops
    def forward(self, x, vars=None, bn_training=True):
        """Run ``x`` through the layers listed in ``self.config`` with explicit weights.

        ``self.config`` is a sequence of ``(name, param)`` pairs. Layers with
        weights (conv2d, convt2d, linear, bn) consume the next two entries of
        ``vars`` as ``(weight, bias)``. ``bn`` additionally consumes two
        entries of ``self.vars_bn`` as ``(running_mean, running_var)``.
        conv2d/convt2d read ``stride=param[4]`` and ``padding=param[5]``.

        :param x: input tensor.
        :param vars: parameter list to use instead of ``self.vars``.
        :param bn_training: forwarded as ``training`` to ``F.batch_norm``.
            When True, batch statistics are used and the running buffers are
            updated in place.
        :return: the output tensor.
        :raises NotImplementedError: for ``'lstm'`` (``torch.nn.functional``
            has no ``lstm``) or any unknown layer name.
        """
        if vars is None:
            vars = self.vars

        idx = 0
        bn_idx = 0

        for name, param in self.config:
            # Compare with ==: `is` tests object identity and only works by
            # accident for interned string literals.
            if name == 'conv2d':
                w, b = vars[idx], vars[idx + 1]
                # remember to keep synchronized with forward_encoder and forward_decoder!
                x = F.conv2d(x, w, b, stride=param[4], padding=param[5])
                idx += 2
            elif name == 'lstm':
                # torch.nn.functional provides no lstm; F.lstm() cannot run.
                raise NotImplementedError(
                    "'lstm' layers are not supported by this functional forward")
            elif name == 'convt2d':
                w, b = vars[idx], vars[idx + 1]
                # remember to keep synchronized with forward_encoder and forward_decoder!
                x = F.conv_transpose2d(x,
                                       w,
                                       b,
                                       stride=param[4],
                                       padding=param[5])
                idx += 2
            elif name == 'linear':
                w, b = vars[idx], vars[idx + 1]
                x = F.linear(x, w, b)
                idx += 2
            elif name == 'bn':
                w, b = vars[idx], vars[idx + 1]
                running_mean, running_var = self.vars_bn[bn_idx], self.vars_bn[
                    bn_idx + 1]
                x = F.batch_norm(x,
                                 running_mean,
                                 running_var,
                                 weight=w,
                                 bias=b,
                                 training=bn_training)
                idx += 2
                bn_idx += 2
            elif name == 'flatten':
                x = x.view(x.size(0), -1)
            elif name == 'reshape':
                # e.g. [b, 8] => [b, 2, 2, 2]
                x = x.view(x.size(0), *param)
            elif name == 'relu':
                x = F.relu(x, inplace=param[0])
            elif name == 'leakyrelu':
                x = F.leaky_relu(x, negative_slope=param[0], inplace=param[1])
            elif name == 'tanh':
                x = torch.tanh(x)
            elif name == 'sigmoid':
                x = torch.sigmoid(x)
            elif name == 'upsample':
                # same as the deprecated F.upsample_nearest
                x = F.interpolate(x, scale_factor=param[0], mode='nearest')
            elif name == 'max_pool2d':
                x = F.max_pool2d(x, param[0], param[1], param[2])
            elif name == 'avg_pool2d':
                x = F.avg_pool2d(x, param[0], param[1], param[2])
            else:
                raise NotImplementedError

        # make sure every parameter was consumed
        assert idx == len(vars)
        assert bn_idx == len(self.vars_bn)

        return x
Example #28
0
    def forward(self, x, vars=None, bn_training=True):
        """Run ``x`` through the layers listed in ``self.config`` with explicit weights.

        ``self.config`` is a sequence of ``(name, param)`` pairs. Layers with
        weights (conv2d, convt2d, linear, bn) consume the next two entries of
        ``vars`` as ``(weight, bias)``. ``bn`` additionally consumes two
        entries of ``self.vars_bn`` as ``(running_mean, running_var)``.
        conv2d/convt2d read ``stride=param[4]`` and ``padding=param[5]``.

        :param x: input tensor.
        :param vars: parameter list to use instead of ``self.vars``.
        :param bn_training: forwarded as ``training`` to ``F.batch_norm``.
        :return: the output tensor.
        :raises NotImplementedError: for an unknown layer name.
        """
        if vars is None:
            vars = self.vars

        idx = 0
        bn_idx = 0

        for name, param in self.config:
            # Compare with ==: `is` tests object identity and only works by
            # accident for interned string literals.
            if name == 'conv2d':
                w, b = vars[idx], vars[idx + 1]
                x = F.conv2d(x, w, b, stride=param[4], padding=param[5])
                idx += 2
            elif name == 'convt2d':
                w, b = vars[idx], vars[idx + 1]
                x = F.conv_transpose2d(x,
                                       w,
                                       b,
                                       stride=param[4],
                                       padding=param[5])
                idx += 2
            elif name == 'linear':
                w, b = vars[idx], vars[idx + 1]
                x = F.linear(x, w, b)
                idx += 2
            elif name == 'bn':
                w, b = vars[idx], vars[idx + 1]
                running_mean, running_var = self.vars_bn[bn_idx], self.vars_bn[
                    bn_idx + 1]
                x = F.batch_norm(x,
                                 running_mean,
                                 running_var,
                                 weight=w,
                                 bias=b,
                                 training=bn_training)
                idx += 2
                bn_idx += 2
            elif name == 'flatten':
                x = x.view(x.size(0), -1)
            elif name == 'reshape':
                x = x.view(x.size(0), *param)
            elif name == 'relu':
                x = F.relu(x, inplace=param[0])
            elif name == 'leakyrelu':
                x = F.leaky_relu(x, negative_slope=param[0], inplace=param[1])
            elif name == 'tanh':
                x = torch.tanh(x)
            elif name == 'sigmoid':
                x = torch.sigmoid(x)
            elif name == 'upsample':
                # same as the deprecated F.upsample_nearest
                x = F.interpolate(x, scale_factor=param[0], mode='nearest')
            elif name == 'max_pool2d':
                x = F.max_pool2d(x, param[0], param[1], param[2])
            elif name == 'avg_pool2d':
                x = F.avg_pool2d(x, param[0], param[1], param[2])
            else:
                raise NotImplementedError

        # make sure every parameter was consumed
        assert idx == len(vars)
        assert bn_idx == len(self.vars_bn)

        return x
Example #29
0
    def forward(self, x, vars=None, bn_training=True):
        """Run ``x`` through the layers listed in ``self.config`` with explicit weights.

        Passing ``vars`` lets fine-tuning run with separate "fast weights"
        without touching ``self.vars``. Batch-norm running statistics live in
        ``self.vars_bn`` and are updated in place only when ``bn_training`` is
        True. Pass ``bn_training=False`` to leave them untouched.

        ``self.config`` is a sequence of ``(name, param)`` pairs. Layers with
        weights (conv2d, convt2d, linear, bn) consume the next two entries of
        ``vars`` as ``(weight, bias)``. conv2d/convt2d read
        ``stride=param[4]`` and ``padding=param[5]``.

        :param x: input tensor (the original note said [b, 1, 28, 28] -- confirm
            against the configured layers).
        :param vars: parameter list to use instead of ``self.vars``.
        :param bn_training: forwarded as ``training`` to ``F.batch_norm``.
        :return: the output tensor ``x``.
        :raises NotImplementedError: for an unknown layer name.
        """

        if vars is None:
            vars = self.vars

        idx = 0
        bn_idx = 0

        for name, param in self.config:
            # Compare with ==: `is` tests object identity and only works by
            # accident for interned string literals.
            if name == 'conv2d':
                w, b = vars[idx], vars[idx + 1]
                # remember to keep synchronized with forward_encoder and forward_decoder!
                x = F.conv2d(x, w, b, stride=param[4], padding=param[5])
                idx += 2
            elif name == 'convt2d':
                w, b = vars[idx], vars[idx + 1]
                # remember to keep synchronized with forward_encoder and forward_decoder!
                x = F.conv_transpose2d(x,
                                       w,
                                       b,
                                       stride=param[4],
                                       padding=param[5])
                idx += 2
            elif name == 'linear':
                w, b = vars[idx], vars[idx + 1]
                x = F.linear(x, w, b)
                idx += 2
            elif name == 'bn':
                w, b = vars[idx], vars[idx + 1]
                running_mean, running_var = self.vars_bn[bn_idx], self.vars_bn[
                    bn_idx + 1]
                x = F.batch_norm(x,
                                 running_mean,
                                 running_var,
                                 weight=w,
                                 bias=b,
                                 training=bn_training)
                idx += 2
                bn_idx += 2
            elif name == 'flatten':
                x = x.view(x.size(0), -1)
            elif name == 'reshape':
                # e.g. [b, 8] => [b, 2, 2, 2]
                x = x.view(x.size(0), *param)
            elif name == 'relu':
                x = F.relu(x, inplace=param[0])
            elif name == 'leakyrelu':
                x = F.leaky_relu(x, negative_slope=param[0], inplace=param[1])
            elif name == 'tanh':
                x = torch.tanh(x)
            elif name == 'sigmoid':
                x = torch.sigmoid(x)
            elif name == 'upsample':
                # same as the deprecated F.upsample_nearest
                x = F.interpolate(x, scale_factor=param[0], mode='nearest')
            elif name == 'max_pool2d':
                x = F.max_pool2d(x, param[0], param[1], param[2])
            elif name == 'avg_pool2d':
                x = F.avg_pool2d(x, param[0], param[1], param[2])
            else:
                raise NotImplementedError

        # make sure every parameter was consumed
        assert idx == len(vars)
        assert bn_idx == len(self.vars_bn)
        return x
Example #30
0
    def forward(self, x):
        """Predict a density map by confidence-weighted fusion of five decoder scales.

        ``features1``..``features5`` form the encoder. The decoder starts
        from ``de_pred5(x5)``. At each stage it bilinearly upsamples
        (``align_corners=True``) to the skip feature's size, concatenates
        ``[skip, upsampled]`` and applies ``de_pred4``..``de_pred1``.

        Every decoder output gets a density head, plus a confidence head
        applied to features average-pooled to a
        ``(H // block_size, W // block_size)`` grid. All maps are
        nearest-upsampled to ``x1``'s spatial size. The confidences are passed
        through a sigmoid and then a softmax across the five scales, and the
        densities are summed with those weights.

        :param x: input tensor; its last two dims are read as (H, W).
        :return: single-channel density tensor (``keepdim=True`` sum over scales).
        """
        size = x.size()
        x1 = self.features1(x)
        x2 = self.features2(x1)
        x3 = self.features3(x2)
        x4 = self.features4(x3)
        x5 = self.features5(x4)

        # decoding; decoder_outs is ordered coarse -> fine (x5_out ... x1_out)
        out = self.de_pred5(x5)
        decoder_outs = [out]
        for skip, de_pred in ((x4, self.de_pred4), (x3, self.de_pred3),
                              (x2, self.de_pred2), (x1, self.de_pred1)):
            # same as the deprecated F.upsample_bilinear
            out = F.interpolate(out, size=skip.size()[2:], mode='bilinear',
                                align_corners=True)
            out = de_pred(torch.cat([skip, out], 1))
            decoder_outs.append(out)

        density_heads = (self.density_head5, self.density_head4,
                         self.density_head3, self.density_head2,
                         self.density_head1)
        confidence_heads = (self.confidence_head5, self.confidence_head4,
                            self.confidence_head3, self.confidence_head2,
                            self.confidence_head1)

        # density prediction
        densities = [head(feat) for head, feat in zip(density_heads, decoder_outs)]
        # confidence prediction on patch-pooled features
        patch_grid = (size[-2] // self.block_size, size[-1] // self.block_size)
        pooled = [F.adaptive_avg_pool2d(feat, output_size=patch_grid)
                  for feat in decoder_outs]
        confidences = [head(p) for head, p in zip(confidence_heads, pooled)]

        # bring every map to the finest encoder resolution
        out_size = x1.size()[2:]
        density_map = torch.cat(
            [F.interpolate(d, size=out_size, mode='nearest') for d in densities], 1)
        confidence_map = torch.cat(
            [F.interpolate(c, size=out_size, mode='nearest') for c in confidences], 1)

        # soft selection across scales: sigmoid, then softmax over the scale axis
        confidence_map = torch.sigmoid(confidence_map)
        confidence_map = F.softmax(confidence_map, 1)

        # out-of-place product; avoids mutating the concatenated density map
        density = torch.sum(density_map * confidence_map, 1, keepdim=True)
        return density