def concat_we(self, x, we, only_we=False, only_grid=False):
    """
    Convenience function to concat we
    Expects x in the form B x C x H x W (one feature map)
    we: B x wdim (the language vector)
    Output: concatenated word embedding and grid centers
    """
    # The two exclusive modes cannot be requested together
    assert not (only_we and only_grid)

    batch = we.size(0)
    feat_h, feat_w = x.size(2), x.size(3)

    # Grid of cell-center coordinates, moved to channels-first layout
    grid = create_grid((feat_h, feat_w), flatten=False).to(self.device)
    grid = grid.permute(2, 0, 1).contiguous()

    # Tile the (shared) grid across the batch dimension
    grid_tile = grid.unsqueeze(0).expand(batch, grid.size(0), grid.size(1), grid.size(2))

    # Grid-only mode: use no image/language information at all
    if only_grid:
        return grid_tile

    # Broadcast the language vector over every spatial location
    word_emb_tile = we.unsqueeze(2).unsqueeze(3).expand(batch, we.size(1), feat_h, feat_w)

    # Image-blind mode: language features only
    if only_we:
        return word_emb_tile

    # Full mode: image features + language + grid centers along channels
    return torch.cat((x, word_emb_tile, grid_tile), dim=1)
def concat_we(self, x, we, append_grid_centers=True):
    """
    Convenience function to concat we
    Expects x in the form B x C x H x W
    we: B x wdim

    The word embedding is L2-normalized per row before being tiled over
    the spatial dimensions. When append_grid_centers is True, grid-center
    coordinate channels are concatenated as well.
    """
    # L2-normalize each embedding row; keepdim=True lets broadcasting do
    # the per-row division (replaces the manual unsqueeze/expand).
    # NOTE(review): an all-zero row yields NaNs here, exactly as in the
    # original code — clamp the norm with an eps if that can occur.
    we = we / we.norm(dim=1, keepdim=True)

    # Tile the embedding over every spatial location of the feature map
    word_emb_tile = we.view(we.size(0), we.size(1), 1, 1).expand(
        we.size(0), we.size(1), x.size(2), x.size(3))

    if append_grid_centers:
        # Channels-first grid of cell centers, tiled across the batch
        grid = create_grid((x.size(2), x.size(3)), flatten=False).to(self.device)
        grid = grid.permute(2, 0, 1).contiguous()
        grid_tile = grid.view(1, grid.size(0), grid.size(1), grid.size(2)).expand(
            we.size(0), grid.size(0), grid.size(1), grid.size(2))
        return torch.cat((x, word_emb_tile, grid_tile), dim=1)

    # Language-only concatenation (no grid channels)
    return torch.cat((x, word_emb_tile), dim=1)