Example #1
    def get_fake_pool(self, fool, gt_train, img, layout_boxes, layout_masks,
                      obj_vecs, objs):
        objs_batch_all = []
        objs_repr_all = []
        for b in range(obj_vecs.size(0)):
            mask = remove_dummy_objects(objs[b], self.opt.vocab)
            objs_batch = objs[b][mask]
            objs_vecs_batch = obj_vecs[b][mask]
            layout_boxes_batch = layout_boxes[b][mask]
            layout_masks_batch = layout_masks[b][mask]
            O = objs_vecs_batch.size(0)
            img_exp = img[b].repeat(O, 1, 1, 1)
            if gt_train:
                # encode ground-truth object crops taken from the real image
                crops = crop_bbox(img_exp, layout_boxes_batch, 64)
                obj_repr = self.repr_net(self.image_encoder(crops))
            else:
                # otherwise build the object representation from the layout masks
                obj_repr = self.repr_net(layout_masks_batch)

            objs_repr_all.append(obj_repr)
            objs_batch_all.append(objs_batch)
        objs_batch_all = torch.cat(objs_batch_all, dim=0)
        objs_repr_all = torch.cat(objs_repr_all, dim=0)

        # Query the fake pool with the object labels and representations
        fake_pool = None
        if fool:
            fake_pool = self.fake_pool.query(objs_batch_all, objs_repr_all)
        return fake_pool, objs_repr_all
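Every example on this page first filters padded object slots with remove_dummy_objects(objs, vocab), but that helper is not shown. A minimal sketch of what it might look like, assuming padded slots reuse a dummy '__image__' index from the vocabulary (both that key's dummy entry and the padding convention are assumptions):

import torch

def remove_dummy_objects(objs, vocab):
    # Hypothetical sketch: return a boolean mask over one sample's object slots
    # that is True for real objects and False for padded "dummy" slots,
    # assuming padding reuses the '__image__' index from the vocabulary.
    dummy_idx = vocab['object_name_to_idx']['__image__']
    labels = objs.view(objs.size(0), -1)[:, 0]  # first column = category id
    return labels != dummy_idx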
Example #2
def crop_bbox_batch(imgs,
                    objs,
                    bbox,
                    HH,
                    WW=None,
                    vocab=None,
                    backend='cudnn'):
    """
    Inputs:
    - imgs: FloatTensor of shape (N, C, H, W)
    - bbox: FloatTensor of shape (B, 4) giving bounding box coordinates
    - bbox_to_feats: LongTensor of shape (B,) mapping boxes to feature maps;
      each element is in the range [0, N) and bbox_to_feats[b] = i means that
      bbox[b] will be cropped from feats[i].
    - HH, WW: Size of the output crops

    Returns:
    - crops: FloatTensor of shape (B, C, HH, WW) where crops[i] uses bbox[i] to
      crop from feats[bbox_to_feats[i]].
    """
    if backend == 'cudnn':
        return crop_bbox_batch_cudnn(imgs, objs, bbox, HH, WW, vocab=vocab)
    N, C, H, W = imgs.size()
    if WW is None:
        WW = HH
    crops = []
    for i in range(N):
        # Keep only the real (non-dummy) objects of image i.
        mask = remove_dummy_objects(objs[i], vocab)
        cur_bbox = bbox[i][mask]
        n = cur_bbox.size(0)
        # Replicate image i once per remaining object so the crops can be batched.
        cur_feats = imgs[i].view(1, C, H, W).expand(n, C, H, W).contiguous()
        cur_crops = crop_bbox(cur_feats, cur_bbox, HH, WW)
        crops.append(cur_crops)
    # Flatten to (total_objects, C, HH, WW), matching the cudnn backend.
    return torch.cat(crops, dim=0)
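A hedged usage sketch of crop_bbox_batch with toy shapes; the vocabulary, the normalized box format, and the dummy index are assumptions, and crop_bbox itself is assumed to be available from the same module:

import torch

vocab = {'object_name_to_idx': {'__image__': 0, 'person': 1, 'car': 2}}  # toy vocab
imgs = torch.randn(2, 3, 128, 128)              # batch of 2 RGB images
objs = torch.tensor([[1, 2, 0, 0],              # image 0: two real objects
                     [2, 0, 0, 0]])             # image 1: one real object
bbox = torch.rand(2, 4, 4)                      # assumed normalized (x0, y0, x1, y1)

# Any backend other than 'cudnn' takes the per-image loop above.
crops = crop_bbox_batch(imgs, objs, bbox, HH=64, vocab=vocab, backend='loop')
print(crops.shape)  # expected: torch.Size([3, 3, 64, 64]), one crop per real object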
Example #3
    def forward(self, objs, layout_boxes, layout_masks, test_mode=False):
        obj_vecs = self.attribute_embedding.forward(objs)  # [B, N, d']
        seg_batches = []
        for b in range(obj_vecs.size(0)):
            mask = remove_dummy_objects(objs[b], self.opt.vocab)
            objs_vecs_batch = obj_vecs[b][mask]
            layout_boxes_batch = layout_boxes[b][mask]
            # Masks Layout
            if layout_masks is not None:
                layout_masks_batch = layout_masks[b][mask]
                seg = masks_to_layout(objs_vecs_batch,
                                      layout_boxes_batch,
                                      layout_masks_batch,
                                      self.opt.image_size[0],
                                      self.opt.image_size[0],
                                      test_mode=test_mode)
            else:
                # Boxes Layout
                seg = boxes_to_layout(objs_vecs_batch, layout_boxes_batch,
                                      self.opt.image_size[0],
                                      self.opt.image_size[0])
            seg_batches.append(seg)
        seg = torch.cat(seg_batches, dim=0)

        # downsample the segmentation map and run the first convolution
        x = F.interpolate(seg, size=(self.sh, self.sw))
        x = self.fc(x)
        x = self.head_0(x, seg)
        x = self.up(x)
        x = self.G_middle_0(x, seg)

        if self.opt.num_upsampling_layers in ('more', 'most'):
            x = self.up(x)

        x = self.G_middle_1(x, seg)

        x = self.up(x)
        x = self.up_0(x, seg)
        x = self.up(x)
        x = self.up_1(x, seg)
        x = self.up(x)
        x = self.up_2(x, seg)
        x = self.up(x)
        x = self.up_3(x, seg)

        if self.opt.num_upsampling_layers == 'most':
            x = self.up(x)
            x = self.up_4(x, seg)

        x = self.conv_img(F.leaky_relu(x, 2e-1))
        x = torch.tanh(x)

        return x
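The starting resolution (self.sh, self.sw) that the segmentation map is resized to is not set in this snippet. In SPADE-style generators it is usually derived from the output size and the number of 2x upsamplings applied in forward(); a hedged sketch of such a helper (the name compute_latent_size and the exact rule are assumptions):

def compute_latent_size(opt):
    # 'normal' -> 5, 'more' -> 6, 'most' -> 7 upsamplings, matching the
    # branches in the forward pass above.
    num_up = {'normal': 5, 'more': 6, 'most': 7}[opt.num_upsampling_layers]
    sw = opt.image_size[0] // (2 ** num_up)
    sh = sw  # square output assumed, since forward() uses image_size[0] for both sides
    return sh, sw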
Example #4
    def forward(self,
                img,
                objs,
                layout_boxes,
                layout_masks=None,
                gt_train=True,
                fool=False):
        obj_vecs = self.attribute_embedding.forward(objs)  # [B, N, d']

        # Masks Layout
        seg_batches = []
        for b in range(obj_vecs.size(0)):
            mask = remove_dummy_objects(objs[b], self.opt.vocab)
            objs_vecs_batch = obj_vecs[b][mask]
            layout_boxes_batch = layout_boxes[b][mask]

            # Masks Layout
            if layout_masks is not None:
                layout_masks_batch = layout_masks[b][mask]
                seg = masks_to_layout(
                    objs_vecs_batch,
                    layout_boxes_batch,
                    layout_masks_batch,
                    self.opt.image_size[0],
                    self.opt.image_size[0],
                    test_mode=False)  # test mode always false in disc.
            else:
                # Boxes Layout
                seg = boxes_to_layout(objs_vecs_batch, layout_boxes_batch,
                                      self.opt.image_size[0],
                                      self.opt.image_size[0])
            seg_batches.append(seg)

        seg = torch.cat(seg_batches, dim=0)
        input = torch.cat([img, seg], dim=1)

        result = []
        get_intermediate_features = not self.opt.no_ganFeat_loss
        for name, D in self.named_children():
            if name.startswith('discriminator'):
                out = D(input)
                if not get_intermediate_features:
                    out = [out]
                result.append(out)
                input = self.downsample(input)
        return result
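The loop feeds the same image-plus-layout input to each sub-discriminator, downsampling it between scales, so the result is one output list per scale. A hedged usage sketch (the netD instance and tensor contents are assumptions):

# netD is assumed to be an instance of the multi-scale discriminator above.
outputs = netD(img, objs, layout_boxes, layout_masks)
# outputs[k] comes from the k-th child named 'discriminator_k'. With GAN
# feature matching enabled (no_ganFeat_loss == False) each entry is a list of
# intermediate feature maps whose last element is that scale's patch-wise
# prediction; otherwise it is wrapped as a single-element list.
patch_preds = [out[-1] for out in outputs]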
Example #5
    def forward(self, imgs, objs, boxes):
        crops = crop_bbox_batch(imgs,
                                objs,
                                boxes,
                                self.object_size,
                                vocab=self.vocab)

        N = objs.size(0)
        new_objs = []
        for i in range(N):
            mask = remove_dummy_objects(objs[i], self.vocab)

            curr_objs = objs[i][mask]
            new_objs.append(curr_objs)

        objs = torch.cat(new_objs, dim=0).squeeze(1)  # [N]
        real_scores, ac_loss = self.discriminator(crops, objs)
        return real_scores, ac_loss, crops
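Because crop_bbox_batch flattens objects in the same per-image order as this loop, the concatenated labels line up one-to-one with the returned crops. A tiny check of that alignment, reusing the toy vocab and the remove_dummy_objects sketch from earlier (shapes and labels are made up):

import torch

objs = torch.tensor([[1, 2, 0, 0],
                     [2, 0, 0, 0]])
labels = torch.cat([objs[i][remove_dummy_objects(objs[i], vocab)]
                    for i in range(objs.size(0))], dim=0)
print(labels)  # tensor([1, 2, 2]) -> one label per crop, in crop order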
Example #6
def crop_bbox_batch_cudnn(imgs, objs, bbox, HH, WW=None, vocab=None):
    N, C, H, W = imgs.size()
    if WW is None:
        WW = HH

    feats_flat, bbox_flat = [], []
    for i in range(N):
        mask = remove_dummy_objects(objs[i], vocab)
        cur_bbox = bbox[i][mask]
        n = cur_bbox.size(0)
        cur_feats = imgs[i].view(1, C, H, W).expand(n, C, H, W).contiguous()

        feats_flat.append(cur_feats)
        bbox_flat.append(cur_bbox)

    feats_flat = torch.cat(feats_flat, dim=0)
    bbox_flat = torch.cat(bbox_flat, dim=0)
    crops = crop_bbox(feats_flat, bbox_flat, HH, WW, backend='cudnn')
    return crops
Example #7
    def forward(self, objs, layout_masks, gt_train=True):
        layout_batches = []
        for b in range(layout_masks.size(0)):
            mask = remove_dummy_objects(objs[b], self.opt.vocab)
            objs_batch = objs[b][mask]
            # Masks Layout
            layout_masks_batch = layout_masks[b][mask]  # [N, 32, 32]
            new_layout_masks = layout_masks_batch.unsqueeze(1)  # [N, 1, 32, 32]
            O = objs_batch.size(0)
            M = layout_masks_batch.size(1)
            # create one-hot vector for label map
            one_hot_size = (
                O, max(self.opt.vocab['object_name_to_idx'].values()) + 1)
            one_hot_obj = torch.zeros(one_hot_size,
                                      dtype=layout_masks_batch.dtype,
                                      device=layout_masks_batch.device)
            one_hot_obj = one_hot_obj.scatter_(1,
                                               objs_batch.view(-1, 1).long(),
                                               1.0)
            one_hot_obj = one_hot_obj.view(O, -1, 1, 1).expand(-1, -1, M, M)

            layout_vecs = torch.cat([one_hot_obj, new_layout_masks], dim=1)
            layout_batches.append(layout_vecs)

        input = torch.cat(layout_batches, dim=0).float()  # [N, d', M, M]

        result = []
        get_intermediate_features = not self.opt.no_ganFeat_loss
        for name, D in self.named_children():
            if name.startswith('discriminator'):
                out = D(input)
                if not get_intermediate_features:
                    out = [out]
                result.append(out)
                input = self.downsample(input)
        return result
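Each per-object input assembled in this loop is a one-hot category plane broadcast over the mask grid, concatenated with the mask itself, i.e. num_categories + 1 channels. A tiny standalone illustration with made-up toy sizes:

import torch

# Toy case: 3 object categories, one object of class 2 with a 32x32 mask.
O, num_cats, M = 1, 3, 32
objs_batch = torch.tensor([2])
layout_masks_batch = torch.rand(O, M, M)

one_hot_obj = torch.zeros(O, num_cats).scatter_(1, objs_batch.view(-1, 1), 1.0)
one_hot_obj = one_hot_obj.view(O, -1, 1, 1).expand(-1, -1, M, M)
layout_vecs = torch.cat([one_hot_obj, layout_masks_batch.unsqueeze(1)], dim=1)
print(layout_vecs.shape)  # torch.Size([1, 4, 32, 32]) -> num_cats + 1 channels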