def test_reshape(self):
     """Check that chained ``jt.reshape`` calls preserve the data.

     Builds ``a -> b -> c -> d -> e`` with ``jt.reshape`` (``e`` lets one
     axis be inferred with ``-1``) and asserts that:

     * every reshape yields the requested shape;
     * every reshaped Var has the same mean as ``a``;
     * rebinding the name ``a`` to ``a + 1`` leaves the mean of ``b``
       unchanged while the new ``a`` has a different mean;
     * ``get_info(jt.dump_all_graphs())`` reports equal entries for the
       nodes labelled ``a`` through ``e``.
     """
     a = jt.random([123, 456, 789]).name("a")
     # int(123 / 3) == 41, int(456 / 2) == 228 and 2 * 41 * 3 == 123 * 2, so
     # every target shape below holds exactly 123 * 456 * 789 elements.
     b = jt.reshape(a, [123 * 2, int(789 * 456 / 2)]).name("b")
     c = jt.reshape(b, [123 * 456 * 789]).name("c")
     d = jt.reshape(c, [2, int(123 / 3), 789, int(456 / 2), 3]).name("d")
     # The -1 axis must be inferred as int(456 / 2).
     e = jt.reshape(d, [2, int(123 / 3), 789, -1, 3]).name("e")
     assert b.shape == [123 * 2, int(789 * 456 / 2)]
     assert c.shape == [123 * 456 * 789]
     assert d.shape == [2, int(123 / 3), 789, int(456 / 2), 3]
     assert e.shape == [2, int(123 / 3), 789, int(456 / 2), 3]
     # The assertions below expect every reshaped Var to share a's mean.
     a_mean = a.mean().data
     b_mean = b.mean().data
     c_mean = c.mean().data
     d_mean = d.mean().data
     e_mean = e.mean().data
     # Rebind the Python name ``a`` to a new Var; ``b`` was built from the
     # original ``a`` and is expected to keep its old mean.
     a = (a + 1).name("new_a")
     new_a_mean = a.mean().data
     new_b_mean = b.mean().data
     # Dump the graph only after all values above have been fetched; the
     # result is indexed below by the labels given via ``.name()``.
     node_dict = get_info(jt.dump_all_graphs())
     assert check_equal(a_mean, b_mean), f"{a_mean} != {b_mean}"
     assert check_equal(a_mean, c_mean), f"{a_mean} != {c_mean}"
     assert check_equal(a_mean, d_mean), f"{a_mean} != {d_mean}"
     assert check_equal(a_mean, e_mean), f"{a_mean} != {e_mean}"
     assert check_equal(b_mean, new_b_mean), f"{b_mean} != {new_b_mean}"
     assert not check_equal(a_mean, new_a_mean), f"{a_mean} == {new_a_mean}"
     # NOTE(review): get_info is defined outside this view; these equalities
     # presumably check that a reshape shares storage with its input instead
     # of copying it — confirm against get_info.
     assert node_dict['a'] == node_dict['b']
     assert node_dict['a'] == node_dict['c']
     assert node_dict['a'] == node_dict['d']
     assert node_dict['a'] == node_dict['e']
Ejemplo n.º 2
def channel_shuffle(x, groups):
    """Interleave the channels of an NCHW Var across ``groups`` groups.

    The channel axis is split into ``groups`` groups of
    ``num_channels // groups`` channels. The group and per-group axes are
    swapped, and the result is flattened back to the original 4-D layout
    (the ShuffleNet channel-shuffle operation).

    Args:
        x: 4-D Var, unpacked as ``(batch, channels, height, width)``.
        groups: number of channel groups; must be positive and evenly
            divide ``channels``.

    Returns:
        Var with the same shape as ``x``, channels interleaved across groups.

    Raises:
        ValueError: if ``groups`` is not positive or does not divide
            ``channels``.
    """
    batchsize, num_channels, height, width = x.shape
    if groups <= 0 or num_channels % groups != 0:
        raise ValueError(
            f"channel_shuffle: {num_channels} channels cannot be split into "
            f"{groups} groups")
    channels_per_group = num_channels // groups
    # (N, C, H, W) -> (N, G, C/G, H, W)
    x = jt.reshape(x, [batchsize, groups, channels_per_group, height, width])
    # Swap the group and per-group axes: (N, C/G, G, H, W).
    x = jt.transpose(x, (0, 2, 1, 3, 4))
    # Flatten back to (N, C, H, W); -1 recovers C.
    x = jt.reshape(x, [batchsize, -1, height, width])
    return x
Ejemplo n.º 3
    def execute(self, x):
        """Grouped 2-D convolution implemented with ``Var.reindex``.

        Args:
            x: input Var, unpacked as ``N, C, H, W = x.shape``.

        Returns:
            Var of shape ``[N, self.out_channels, oh, ow]``. ``oh``/``ow`` are
            computed from the padding, dilation and stride below.
        """
        N, C, H, W = x.shape
        Kh, Kw = self.kernel_size
        oh = (H + self.padding[0] * 2 - Kh * self.dilation[0] +
              self.dilation[0] - 1) // self.stride[0] + 1
        ow = (W + self.padding[1] * 2 - Kw * self.dilation[1] +
              self.dilation[1] - 1) // self.stride[1] + 1

        # Split channels into groups: (N, G, Cin/G, H, W).
        x = jt.reshape(x, [N, self.groups, self.group_channel_in, H, W])
        # Gather every input element each output position sees, producing
        # (N, G, Cout/G, Cin/G, oh, ow, Kh, Kw). The output-channel index i2
        # is not used to address x, so x is repeated along that axis.
        xx = x.reindex(
            [N, self.groups, self.group_channel_out, self.group_channel_in,
             oh, ow, Kh, Kw],
            [
                'i0',  # Nid
                'i1',  # Group
                'i3',  # Cid
                f'i4*{self.stride[0]}-{self.padding[0]}+i6*{self.dilation[0]}',  # Hid+Khid
                f'i5*{self.stride[1]}-{self.padding[1]}+i7*{self.dilation[1]}',  # Wid+KWid
            ])
        # Weight is broadcast over dims 0, 4, 5 (N, oh, ow); presumably it is
        # laid out as (G, Cout/G, Cin/G, Kh, Kw) — confirm against __init__.
        ww = self.weight.broadcast(xx.shape, [0, 4, 5])
        yy = xx * ww
        y = yy.sum([3, 6, 7])  # Kc, Kh, Kw
        y = jt.reshape(y, [N, self.out_channels, oh, ow])
        if self.bias is not None:
            b = self.bias.broadcast(y.shape, [0, 2, 3])
            y = y + b
        return y
Ejemplo n.º 4
    def execute(self, predicted_locs, predicted_scores, boxes, labels):
        """Forward propagation of the multibox loss.

        Args:
            predicted_locs: predicted locations/boxes w.r.t the 8732 prior
                boxes, a tensor of dimensions (N, 8732, 4)
            predicted_scores: class scores for each of the encoded
                locations/boxes, a tensor of dimensions (N, 8732, n_classes)
            boxes: true object bounding boxes in boundary coordinates, a list
                of N tensors
            labels: true object labels, a list of N tensors

        Returns:
            Tuple ``(conf_loss + self.alpha * loc_loss, conf_loss, loc_loss)``.

        Raises:
            AssertionError: if the prior count of ``predicted_locs`` or
                ``predicted_scores`` differs from ``self.priors_cxcy``.
        """
        batch_size = predicted_locs.shape[0]
        n_priors = self.priors_cxcy.shape[0]
        n_classes = predicted_scores.shape[2]
        assert (n_priors == predicted_locs.shape[1] == predicted_scores.shape[1]), (
            "number of priors does not match the predictions")
        true_locs = np.zeros((batch_size, n_priors, 4))
        true_classes = np.zeros((batch_size, n_priors))
        for i in range(batch_size):
            # Step1: Select one object for every prior
            # Step2: Select one prior for every object, and set its iou 1
            # Step3: Set priors as background, whose iou is lower than threshold, eg: 0.5
            n_objects = boxes[i].shape[0]
            # Presumably (n_objects, n_priors) — confirm find_jaccard_overlap.
            overlap = find_jaccard_overlap(boxes[i], self.priors_xy)
            # Best object (and its IoU) for every prior: reduce over objects.
            object_for_each_prior, overlap_for_each_prior = argmax(overlap, axis=0)
            # Best prior for every object: reduce over priors.
            prior_for_each_object, _ = argmax(overlap, axis=1)
            # Every object keeps its best prior, and that prior's IoU is set
            # to 1 so the threshold below cannot turn it into background.
            object_for_each_prior[prior_for_each_object] = range(n_objects)
            overlap_for_each_prior[prior_for_each_object] = 1.0
            label_for_each_prior = labels[i][object_for_each_prior]
            label_for_each_prior[overlap_for_each_prior < self.threshold] = 0
            true_classes[i] = label_for_each_prior
            true_locs[i] = cxcy_to_gcxgcy(
                xy_to_cxcy(boxes[i][object_for_each_prior]), self.priors_cxcy)
        true_classes = jt.array(true_classes).float32().stop_grad()
        true_locs = jt.array(true_locs).float32().stop_grad()
        # Class 0 marks background, so non-zero targets are positives.
        positive_priors = (true_classes != 0)
        # Masking both sides removes negatives from the localization loss.
        loc_loss = self.smooth_l1(
            (predicted_locs * positive_priors.broadcast([1, 1, 4], [2])),
            (true_locs * positive_priors.broadcast([1, 1, 4], [2])))
        loc_loss /= (positive_priors.float32().sum() * 4)
        n_positives = positive_priors.float32().sum(1)
        n_hard_negatives = self.neg_pos_ratio * n_positives
        conf_loss_all = self.cross_entropy(
            jt.reshape(predicted_scores, [-1, n_classes]),
            jt.reshape(true_classes, [-1]))
        conf_loss_all = jt.reshape(conf_loss_all, [batch_size, n_priors])
        conf_loss_pos = conf_loss_all * positive_priors
        conf_loss_neg = conf_loss_all * (1 - positive_priors)
        # Hard negative mining: argsort returns (indices, sorted values); keep
        # the values, sorted descending per image, and retain the
        # neg_pos_ratio * n_positives largest of them.
        _, conf_loss_neg = conf_loss_neg.argsort(dim=1, descending=True)
        hardness_ranks = jt.array(range(n_priors)).broadcast(
            [conf_loss_neg.shape[0], conf_loss_neg.shape[1]], [0])
        hard_negatives = hardness_ranks < n_hard_negatives.broadcast(
            hardness_ranks.shape, [1])
        conf_loss_hard_neg = conf_loss_neg * hard_negatives
        # The confidence loss is normalised by the total number of positives.
        conf_loss = ((conf_loss_hard_neg.sum() + conf_loss_pos.sum()) /
                     positive_priors.float32().sum())
        return (conf_loss + (self.alpha * loc_loss)), conf_loss, loc_loss
Ejemplo n.º 5
 def execute(self, x):
     """Encode ``x``, project it through two FC layers and decode it.

     The encoder output is flattened to ``(batch, self.feat_dim)``. After
     ``fc1 -> relu -> fc2`` it is reshaped to ``(batch, 512, self.rh)`` and
     handed to ``self.decoder``.
     """
     encoded = self.encoder(x)
     flat = jt.reshape(encoded, [encoded.shape[0], self.feat_dim])
     hidden = self.relu(self.fc1(flat))
     projected = self.fc2(hidden)
     # NOTE(review): the target is 3-D; if self.decoder expects NCHW input a
     # trailing width axis may be missing here — confirm.
     grid = jt.reshape(projected, [projected.shape[0], 512, self.rh])
     return self.decoder(grid)
Ejemplo n.º 6
 def execute(self, in_feat):
     """Split the model output into a continuous and a categorical latent.

     The output of ``self.model`` is flattened per sample. The first
     ``self.latent_dim`` columns form ``zn``; the remaining columns are
     logits that are softmax-normalised along dim 1 into ``zc``.

     Returns:
         Tuple ``(zn, zc, zc_logits)``.
     """
     encoded = self.model(in_feat)
     flat = jt.reshape(encoded, [encoded.shape[0], -1])
     split = self.latent_dim
     continuous = flat[:, 0:split]
     logits = flat[:, split:]
     categorical = nn.softmax(logits, dim=1)
     return continuous, categorical, logits
Ejemplo n.º 7
 def execute(self, img):
     """Encode images into a latent sample via ``reparameterization``.

     Args:
         img: batch of inputs; flattened to ``(batch, -1)`` before
             ``self.model``.

     Returns:
         ``reparameterization(mu, logvar)`` computed from the mean head
         ``self.mu`` and the log-variance head ``self.logvar``.
     """
     img_flat = jt.reshape(img, [img.shape[0], -1])
     x = self.model(img_flat)
     # Both heads read the same hidden features.
     mu = self.mu(x)
     logvar = self.logvar(x)
     z = reparameterization(mu, logvar)
     return z
Ejemplo n.º 8
 def execute(self, img):
     """Encode ``img`` into the pair ``(mu, logvar)``.

     Features from ``self.feature_extractor`` are pooled, flattened per
     sample, and fed to the two linear heads ``fc_mu`` and ``fc_logvar``.
     """
     pooled = self.pooling(self.feature_extractor(img))
     flat = jt.reshape(pooled, [pooled.shape[0], -1])
     return self.fc_mu(flat), self.fc_logvar(flat)
Ejemplo n.º 9
    def _forward_impl(self, x):
        """Stem, ``layer1``-``layer4``, average pooling, flatten, ``fc``.

        Returns:
            Output of ``self.fc`` on the per-sample flattened pooled features.
        """
        stages = (
            self.conv1, self.bn1, self.relu, self.maxpool,
            self.layer1, self.layer2, self.layer3, self.layer4,
            self.avgpool,
        )
        for stage in stages:
            x = stage(x)
        flat = jt.reshape(x, (x.shape[0], -1))
        return self.fc(flat)
 def _forward(self, x):
     """Run the Inception-v3 trunk, the optional auxiliary head and the classifier.

     Returns:
         Tuple ``(logits, aux)``. ``aux`` is ``self.AuxLogits`` applied to the
         ``Mixed_6e`` output when ``self.aux_logits`` is truthy, otherwise
         ``None``.
     """
     x = self.Conv2d_1a_3x3(x)
     x = self.Conv2d_2a_3x3(x)
     x = self.Conv2d_2b_3x3(x)
     x = nn.pool(x, 3, "maximum", stride=2)
     x = self.Conv2d_3b_1x1(x)
     x = self.Conv2d_4a_3x3(x)
     x = nn.pool(x, 3, "maximum", stride=2)
     x = self.Mixed_5b(x)
     x = self.Mixed_5c(x)
     x = self.Mixed_5d(x)
     x = self.Mixed_6a(x)
     x = self.Mixed_6b(x)
     x = self.Mixed_6c(x)
     x = self.Mixed_6d(x)
     x = self.Mixed_6e(x)
     # The auxiliary classifier branches off after Mixed_6e. Without the
     # else branch, aux would be unbound when aux_logits is falsy.
     if self.aux_logits:
         aux = self.AuxLogits(x)
     else:
         aux = None
     x = self.Mixed_7a(x)
     x = self.Mixed_7b(x)
     x = self.Mixed_7c(x)
     x = nn.AdaptiveAvgPool2d(1)(x)
     x = nn.Dropout()(x)
     x = jt.reshape(x, (x.shape[0], -1))
     x = self.fc(x)
     return (x, aux)
Ejemplo n.º 11
    def _forward(self, x):
        """Run the GoogLeNet trunk with optional auxiliary classifiers.

        Returns:
            Tuple ``(logits, aux2, aux1)``. ``aux1`` (after ``inception4a``)
            and ``aux2`` (after ``inception4d``) are ``None`` when the
            corresponding head attribute is ``None``.
        """
        # Initialise both aux outputs so the return never hits an unbound name.
        aux1 = None
        aux2 = None
        x = self.conv1(x)
        x = self.maxpool1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.maxpool2(x)
        x = self.inception3a(x)
        x = self.inception3b(x)
        x = self.maxpool3(x)
        x = self.inception4a(x)
        if self.aux1 is not None:
            aux1 = self.aux1(x)
        x = self.inception4b(x)
        x = self.inception4c(x)
        x = self.inception4d(x)
        if self.aux2 is not None:
            aux2 = self.aux2(x)
        x = self.inception4e(x)
        x = self.maxpool4(x)
        x = self.inception5a(x)
        x = self.inception5b(x)
        x = self.avgpool(x)

        x = jt.reshape(x, (x.shape[0], -1))
        x = self.dropout(x)
        x = self.fc(x)
        return (x, aux2, aux1)
Ejemplo n.º 12
 def execute(self, x):
     """Auxiliary classifier head: 4x4 adaptive pool, conv, two FC layers.

     Returns:
         Logits produced by ``self.fc2``.
     """
     pooled = nn.AdaptiveAvgPool2d(4)(x)
     features = self.conv(pooled)
     flat = jt.reshape(features, (features.shape[0], -1))
     hidden = nn.relu(self.fc1(flat))
     hidden = nn.Dropout(0.7)(hidden)
     return self.fc2(hidden)
    def execute(self, x):
        """Auxiliary head: mean-pool (5x5, stride 3), two convs, global pool, ``fc``."""
        out = nn.pool(x, kernel_size=5, op="mean", stride=3)
        for conv in (self.conv0, self.conv1):
            out = conv(out)
        out = nn.AdaptiveAvgPool2d(1)(out)
        flat = jt.reshape(out, (out.shape[0], -1))
        return self.fc(flat)
Ejemplo n.º 14
 def execute(self, feature, input_part):
     """Auto-encode ``input_part`` and return it with its reconstruction loss.

     Args:
         feature: not used by this method (kept for interface compatibility).
         input_part: input passed to ``self.net_encoder``; its detached copy
             is also the reconstruction target.

     Returns:
         Tuple ``(fake_part, self.loss_filter(loss))``. ``loss`` is
         ``10 * self.criterion(fake_part, input_part.detach())`` reshaped to
         shape ``[1]``.
     """
     feature_vector = self.net_encoder(input_part)
     # Stored on the module (presumably so callers can read the latest code).
     self.feature_vector = feature_vector
     fake_part = self.net_decoder(feature_vector)
     # The target is detached so the loss does not backpropagate through it.
     loss = self.criterion(fake_part, input_part.detach()) * 10
     loss = loss.reshape(1)
     return fake_part, self.loss_filter(loss)
Ejemplo n.º 15
    def execute(self, ten):
        """Decode a latent vector into a feature map.

        The output of ``self.fc`` is reshaped to
        ``(batch, 512, self.latent_size, self.latent_size)`` and passed
        through ``self.conv``.
        """
        projected = self.fc(ten)
        batch = projected.size()[0]
        grid = jt.reshape(
            projected, (batch, 512, self.latent_size, self.latent_size))
        return self.conv(grid)
Ejemplo n.º 16
 def execute(self, zn, zc):
     """Generate samples from the concatenated latent code ``[zn, zc]``.

     The code runs through ``self.model0`` ... ``self.model4`` and
     ``self.sigmoid``, then is reshaped to ``(batch, *self.x_shape)``.
     """
     out = jt.contrib.concat([zn, zc], dim=1)
     for block in (self.model0, self.model1, self.model2, self.model3,
                   self.model4, self.sigmoid):
         out = block(out)
     return jt.reshape(out, [out.shape[0], *self.x_shape])
Ejemplo n.º 17
 def execute(self, ten):
     """Encode ``ten`` and return only the mean head ``self.fc_mu``.

     ``self.conv`` output is flattened per sample before the linear head;
     no log-variance is computed here.
     """
     features = self.conv(ten)
     flat = jt.reshape(features, [features.size()[0], -1])
     return self.fc_mu(flat)
 def execute(self, x):
     """Stem, ``layer1``-``layer4``, 4x4 mean pool, flatten, two linear layers."""
     out = nn.relu(self.bn1(self.conv1(x)))
     for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
         out = stage(out)
     out = nn.pool(out, size=4, op="mean", padding=0)
     flat = jt.reshape(out, [out.shape[0], -1])
     return self.linear2(self.linear1(flat))
Ejemplo n.º 19
def calc_gradient_penalty(netD, real_data, generated_data):
    """WGAN-GP gradient penalty.

    Draws one mixing coefficient per example with ``jt.random`` and builds
    ``alpha * real + (1 - alpha) * generated``. It then penalises the squared
    deviation from 1 of the L2 norm of the critic's gradient at those points.

    Args:
        netD: critic/discriminator callable.
        real_data: batch of real samples. Presumably 4-D (N, C, H, W), since
            alpha is drawn with shape [N, 1, 1, 1] — confirm.
        generated_data: batch of generated samples, same shape as
            ``real_data``.

    Returns:
        Scalar Var ``10 * mean((||grad||_2 - 1) ** 2)``.
    """
    LAMBDA = 10
    b_size = real_data.shape[0]
    alpha = jt.random([b_size, 1, 1, 1])
    alpha = alpha.broadcast(real_data)
    interpolated = alpha * real_data + (1 - alpha) * generated_data
    prob_interpolated = netD(interpolated)
    gradients = jt.grad(prob_interpolated, interpolated)
    gradients = jt.reshape(gradients, [b_size, -1])
    # The epsilon keeps sqrt differentiable when a gradient is exactly zero.
    gradients_norm = jt.sqrt(jt.sum(gradients ** 2, dim=1) + 1e-12)
    return LAMBDA * ((gradients_norm - 1) ** 2).mean()
Ejemplo n.º 20
def run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn,
                netchunk=1024 * 64):
    """Prepare inputs and apply network ``fn``.

    Args:
        inputs: Var whose last axis holds per-point coordinates; all leading
            axes are flattened before embedding.
        viewdirs: viewing directions, or ``None`` to skip direction
            embedding. ``viewdirs[:, None]`` is expanded to ``inputs.shape``.
        fn: network applied to the embedded points.
        embed_fn: embedding applied to the flattened ``inputs``.
        embeddirs_fn: embedding applied to the flattened directions (only
            used when ``viewdirs`` is not None).
        netchunk: chunk size handed to ``batchify`` together with ``fn``.

    Returns:
        Var of shape ``inputs.shape[:-1] + [last dim of fn's output]``.
    """
    inputs_flat = jt.reshape(inputs, [-1, inputs.shape[-1]])
    embedded = embed_fn(inputs_flat)

    if viewdirs is not None:
        # Insert an axis at position 1 and expand so every point gets a direction.
        input_dirs = viewdirs[:, None].expand(inputs.shape)
        input_dirs_flat = jt.reshape(input_dirs, [-1, input_dirs.shape[-1]])
        embedded_dirs = embeddirs_fn(input_dirs_flat)
        embedded = jt.concat([embedded, embedded_dirs], -1)

    outputs_flat = batchify(fn, netchunk)(embedded)
    # Restore the leading axes of ``inputs``.
    outputs = jt.reshape(outputs_flat,
                         list(inputs.shape[:-1]) + [outputs_flat.shape[-1]])
    return outputs
Ejemplo n.º 21
 def _forward_impl(self, x):
     """Stem, ``layer1``-``layer4``, ``avgpool``, flatten, ``fc``."""
     stem = self.maxpool(self.relu(self.bn1(self.conv1(x))))
     body = self.layer4(self.layer3(self.layer2(self.layer1(stem))))
     pooled = self.avgpool(body)
     return self.fc(jt.reshape(pooled, (pooled.shape[0], -1)))
Ejemplo n.º 22
    def execute(self, x):
        """Two conv layers (the second batch-normalised), max pool, two FC layers.

        Returns:
            Output of ``self.fc2``.
        """
        x = self.conv1(x)
        x = self.relu(x)

        x = self.conv2(x)
        # Normalise the second conv's output before its activation.
        x = self.bn(x)
        x = self.relu(x)

        x = self.max_pool(x)
        x = jt.reshape(x, [x.shape[0], -1])
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x
Ejemplo n.º 23
    def execute(self, x):
        """Stem, ``layer1``-``layer4``, ``avgpool``, flatten, ``fc``."""
        out = x
        for module in (self.conv1, self.bn1, self.relu, self.maxpool,
                       self.layer1, self.layer2, self.layer3, self.layer4,
                       self.avgpool):
            out = module(out)
        batch = out.shape[0]
        return self.fc(jt.reshape(out, [batch, -1]))
Ejemplo n.º 24
 def execute(self, x, z):
     """U-Net style generator conditioned on a latent code ``z``.

     ``z`` is projected by ``self.fc``, reshaped to one ``(self.h, self.w)``
     map per sample and concatenated with ``x`` along dim 1. The result goes
     through seven down blocks and six up blocks; each up block also receives
     the mirrored down-block output.

     NOTE(review): this method computes ``u6`` but has no ``return``
     statement, so it returns ``None``. A final output layer applied to
     ``u6`` (``self.final`` in BicycleGAN-style generators) appears to be
     missing — confirm and restore.
     """
     z = self.fc(z)
     # One extra input plane carrying the latent code: (N, 1, h, w).
     z = jt.reshape(z, [z.shape[0], 1, self.h, self.w])
     d1 = self.down1(jt.contrib.concat([x, z], dim=1))
     d2 = self.down2(d1)
     d3 = self.down3(d2)
     d4 = self.down4(d3)
     d5 = self.down5(d4)
     d6 = self.down6(d5)
     d7 = self.down7(d6)
     # Decoder: each up block takes the previous output plus a skip input.
     u1 = self.up1(d7, d6)
     u2 = self.up2(u1, d5)
     u3 = self.up3(u2, d4)
     u4 = self.up4(u3, d3)
     u5 = self.up5(u4, d2)
     u6 = self.up6(u5, d1)
Ejemplo n.º 25
 def execute(self, x):
     """Extract features, flatten them per sample, and classify."""
     feats = self.features(x)
     flat = jt.reshape(feats, [feats.shape[0], -1])
     return self.classifier(flat)
Ejemplo n.º 26
def render_rays(ray_batch,
                network_fn,
                network_query_fn,
                N_samples,
                retraw=False,
                lindisp=False,
                perturb=0.,
                N_importance=0,
                network_fine=None,
                white_bkgd=False,
                raw_noise_std=0.,
                verbose=False):
    """Volumetric rendering.

    Args:
      ray_batch: array of shape [batch_size, ...]. All information necessary
        for sampling along a ray, including: ray origin, ray direction, min
        dist, max dist, and unit-magnitude viewing direction.
      network_fn: function. Model for predicting RGB and density at each point
        in space.
      network_query_fn: function used for passing queries to network_fn.
      N_samples: int. Number of different times to sample along each ray.
      retraw: bool. If True, include model's raw, unprocessed predictions.
      lindisp: bool. If True, sample linearly in inverse depth rather than in depth.
      perturb: float, 0 or 1. If non-zero, each ray is sampled at stratified
        random points in time.
      N_importance: int. Number of additional times to sample along each ray.
        These samples are only passed to network_fine.
      network_fine: "fine" network with same spec as network_fn.
      white_bkgd: bool. If True, assume a white background.
      raw_noise_std: forwarded unchanged to ``integrator``.
      verbose: bool. Not used by this function.

    Returns:
      dict with keys:
      rgb_map: [num_rays, 3]. Estimated RGB color of a ray. Comes from fine model.
      disp_map: [num_rays]. Disparity map. 1 / depth.
      acc_map: [num_rays]. Accumulated opacity along each ray. Comes from fine model.
      raw: [num_rays, num_samples, 4]. Raw predictions from model. Only if retraw.
      rgb0: See rgb_map. Output for coarse model. Only if N_importance > 0.
      disp0: See disp_map. Output for coarse model. Only if N_importance > 0.
      acc0: See acc_map. Output for coarse model. Only if N_importance > 0.
    """
    N_rays = ray_batch.shape[0]
    # Column layout: 0:3 origin, 3:6 direction, 6:8 (near, far); the last
    # three columns are view directions when the batch is wider than 8.
    rays_o, rays_d = ray_batch[:, 0:3], ray_batch[:, 3:6]  # [N_rays, 3] each
    viewdirs = ray_batch[:, -3:] if ray_batch.shape[-1] > 8 else None
    bounds = jt.reshape(ray_batch[..., 6:8], [-1, 1, 2])
    near, far = bounds[..., 0], bounds[..., 1]  # [-1,1]

    # Coarse pass.
    z_vals = sample(N_rays, N_samples, lindisp, perturb, near, far)
    pts = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * z_vals.unsqueeze(
        -1)  # [N_rays, N_samples, 3]

    raw = network_query_fn(pts, viewdirs, network_fn)
    rgb_map, disp_map, acc_map, weights, _depth_map = integrator(
        raw, z_vals, rays_d, raw_noise_std, white_bkgd)

    if N_importance > 0:
        # Keep the coarse results; the fine pass overwrites the maps below.
        rgb_map_0, disp_map_0, acc_map_0 = rgb_map, disp_map, acc_map

        # Importance sampling from the coarse weights.
        z_vals_mid = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
        z_samples = sample_pdf(z_vals_mid,
                               weights[..., 1:-1],
                               det=(perturb == 0.))
        z_samples = z_samples.detach()

        # argsort returns (indices, sorted values); keep the sorted depths.
        _, z_vals = jt.argsort(jt.concat([z_vals, z_samples], -1), -1)
        pts = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * z_vals.unsqueeze(
            -1)  # [N_rays, N_samples + N_importance, 3]

        run_fn = network_fn if network_fine is None else network_fine
        raw = network_query_fn(pts, viewdirs, run_fn)
        rgb_map, disp_map, acc_map, weights, _depth_map = integrator(
            raw, z_vals, rays_d, raw_noise_std, white_bkgd)

    ret = {'rgb_map': rgb_map, 'disp_map': disp_map, 'acc_map': acc_map}
    if retraw:
        ret['raw'] = raw
    if N_importance > 0:
        ret['rgb0'] = rgb_map_0
        ret['disp0'] = disp_map_0
        ret['acc0'] = acc_map_0

    return ret
Ejemplo n.º 27
def render(H,
           W,
           focal,
           chunk=1024 * 32,
           rays=None,
           c2w=None,
           ndc=True,
           near=0.,
           far=1.,
           use_viewdirs=False,
           c2w_staticcam=None,
           intrinsic=None,
           **kwargs):
    """Render rays.

    Args:
      H: int. Height of image in pixels.
      W: int. Width of image in pixels.
      focal: float. Focal length of pinhole camera.
      chunk: int. Maximum number of rays to process simultaneously. Used to
        control maximum memory usage. Does not affect final results.
      rays: array of shape [2, batch_size, 3]. Ray origin and direction for
        each example in batch. Used only when ``c2w`` is None.
      c2w: array of shape [3, 4]. Camera-to-world transformation matrix.
      ndc: bool. If True, represent ray origin, direction in NDC coordinates.
      near: float or array of shape [batch_size]. Nearest distance for a ray.
      far: float or array of shape [batch_size]. Farthest distance for a ray.
      use_viewdirs: bool. If True, use viewing direction of a point in space in model.
      c2w_staticcam: array of shape [3, 4]. If not None, use this transformation matrix for
       camera while using other c2w argument for viewing directions.
      intrinsic: passed to ``pinhole_get_rays`` together with ``c2w``; must
        be None when ``c2w_staticcam`` is used.
      **kwargs: forwarded to ``batchify_rays``.

    Returns:
      list ``[rgb_map, disp_map, acc_map, extras]``:
      rgb_map: [batch_size, 3]. Predicted RGB values for rays.
      disp_map: [batch_size]. Disparity map. Inverse of depth.
      acc_map: [batch_size]. Accumulated opacity (alpha) along a ray.
      extras: dict with everything returned by render_rays().
    """
    if c2w is not None:
        # special case to render full image
        rays_o, rays_d = pinhole_get_rays(H, W, focal, c2w, intrinsic)
    else:
        # use provided ray batch
        rays_o, rays_d = rays

    if use_viewdirs:
        # provide ray directions as input
        viewdirs = rays_d
        if c2w_staticcam is not None:
            assert intrinsic is None
            rays_o, rays_d = pinhole_get_rays(H, W, focal, c2w_staticcam)
        viewdirs = viewdirs / jt.norm(viewdirs, p=2, dim=-1, keepdim=True)
        viewdirs = jt.reshape(viewdirs, [-1, 3]).float()

    sh = rays_d.shape  # [..., 3]
    if ndc:
        # for forward facing scenes
        rays_o, rays_d = ndc_rays(H, W, focal, 1., rays_o, rays_d)

    # Create ray batch
    rays_o = jt.reshape(rays_o, [-1, 3]).float()
    rays_d = jt.reshape(rays_d, [-1, 3]).float()

    near, far = near * jt.ones_like(rays_d[..., :1]), far * jt.ones_like(
        rays_d[..., :1])
    rays = jt.concat([rays_o, rays_d, near, far], -1)
    if use_viewdirs:
        rays = jt.concat([rays, viewdirs], -1)

    # Render, then restore the leading shape of the input rays.
    all_ret = batchify_rays(rays, chunk, **kwargs)
    for k in all_ret:
        k_sh = list(sh[:-1]) + list(all_ret[k].shape[1:])
        all_ret[k] = jt.reshape(all_ret[k], k_sh)

    k_extract = ['rgb_map', 'disp_map', 'acc_map']
    ret_list = [all_ret[k] for k in k_extract]
    ret_dict = {k: all_ret[k] for k in all_ret if k not in k_extract}
    return ret_list + [ret_dict]
Ejemplo n.º 28
 def execute(self, x):
     """Features, average pooling, per-sample flatten, classifier."""
     pooled = self.avgpool(self.features(x))
     flat = jt.reshape(pooled, (pooled.shape[0], -1))
     return self.classifier(flat)
Ejemplo n.º 29
 def execute(self, x):
     """Reshape each sample of ``x`` to ``self.shape``, keeping the batch axis."""
     target = [x.shape[0]]
     target.extend(self.shape)
     return jt.reshape(x, target)
Ejemplo n.º 30
    def execute(self, conv4_3_feats, conv7_feats, conv8_2_feats, conv9_2_feats,
                conv10_2_feats, conv11_2_feats):
        """Forward propagation.

        Args:
            conv4_3_feats: conv4_3 feature map, an array of dimensions (N, 512, 38, 38)
            conv7_feats: conv7 feature map, an array of dimensions (N, 1024, 19, 19)
            conv8_2_feats: conv8_2 feature map, an array of dimensions (N, 512, 10, 10)
            conv9_2_feats: conv9_2 feature map, an array of dimensions (N, 256, 5, 5)
            conv10_2_feats: conv10_2 feature map, an array of dimensions (N, 256, 3, 3)
            conv11_2_feats: conv11_2 feature map, an array of dimensions (N, 256, 1, 1)

        Returns:
            Tuple ``(locs, classes_scores)``: 8732 locations and class scores
            (i.e. w.r.t each prior box) for each image, shaped
            ``(N, 8732, 4)`` and ``(N, 8732, self.n_classes)``.
        """
        batch_size = conv4_3_feats.shape[0]

        def _flatten(head, feats, last_dim):
            # Apply a prediction head, move channels last, and group the
            # per-location outputs into rows of ``last_dim`` values.
            out = jt.transpose(head(feats), [0, 2, 3, 1])
            return jt.reshape(out, [batch_size, -1, last_dim])

        feature_maps = (conv4_3_feats, conv7_feats, conv8_2_feats,
                        conv9_2_feats, conv10_2_feats, conv11_2_feats)
        loc_heads = (self.loc_conv4_3, self.loc_conv7, self.loc_conv8_2,
                     self.loc_conv9_2, self.loc_conv10_2, self.loc_conv11_2)
        cls_heads = (self.cl_conv4_3, self.cl_conv7, self.cl_conv8_2,
                     self.cl_conv9_2, self.cl_conv10_2, self.cl_conv11_2)

        # Concatenate the per-map predictions along the prior axis (dim 1).
        locs = jt.contrib.concat(
            [_flatten(head, feats, 4)
             for head, feats in zip(loc_heads, feature_maps)],
            dim=1)
        classes_scores = jt.contrib.concat(
            [_flatten(head, feats, self.n_classes)
             for head, feats in zip(cls_heads, feature_maps)],
            dim=1)
        return (locs, classes_scores)