def test_reshape(self):
    a = jt.random([123, 456, 789]).name("a")
    b = jt.reshape(a, [123 * 2, int(789 * 456 / 2)]).name("b")
    c = jt.reshape(b, [123 * 456 * 789]).name("c")
    d = jt.reshape(c, [2, int(123 / 3), 789, int(456 / 2), 3]).name("d")
    e = jt.reshape(d, [2, int(123 / 3), 789, -1, 3]).name("e")
    assert b.shape == [123 * 2, int(789 * 456 / 2)]
    assert c.shape == [123 * 456 * 789]
    assert d.shape == [2, int(123 / 3), 789, int(456 / 2), 3]
    assert e.shape == [2, int(123 / 3), 789, int(456 / 2), 3]
    a_mean = a.mean().data
    b_mean = b.mean().data
    c_mean = c.mean().data
    d_mean = d.mean().data
    e_mean = e.mean().data
    a = (a + 1).name("new_a")
    new_a_mean = a.mean().data
    new_b_mean = b.mean().data
    node_dict = get_info(jt.dump_all_graphs())
    assert check_equal(a_mean, b_mean), f"{a_mean} != {b_mean}"
    assert check_equal(a_mean, c_mean), f"{a_mean} != {c_mean}"
    assert check_equal(a_mean, d_mean), f"{a_mean} != {d_mean}"
    assert check_equal(a_mean, e_mean), f"{a_mean} != {e_mean}"
    assert check_equal(b_mean, new_b_mean), f"{b_mean} != {new_b_mean}"
    assert not check_equal(a_mean, new_a_mean), f"{a_mean} == {new_a_mean}"
    assert node_dict['a'] == node_dict['b']
    assert node_dict['a'] == node_dict['c']
    assert node_dict['a'] == node_dict['d']
    assert node_dict['a'] == node_dict['e']
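# A short, self-contained illustration (an addition, not part of the test
# suite above) of the `-1` inference the test exercises at `e`: jt.reshape
# infers at most one dimension from the total element count.
import jittor as jt

t = jt.random([4, 6])
assert jt.reshape(t, [2, -1]).shape == [2, 12]  # 4 * 6 == 2 * 12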
def channel_shuffle(x, groups):
    batchsize, num_channels, height, width = x.data.shape
    channels_per_group = num_channels // groups
    x = jt.reshape(x, [batchsize, groups, channels_per_group, height, width])
    x = jt.transpose(x, (0, 2, 1, 3, 4))
    x = jt.reshape(x, [batchsize, -1, height, width])
    return x
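# A minimal usage sketch of channel_shuffle (assumes jittor is installed).
# The shuffle is shape-preserving; with 16 channels and groups=4, output
# channel i reads from input channel (i % 4) * 4 + i // 4.
import jittor as jt

x = jt.random([2, 16, 8, 8])
y = channel_shuffle(x, groups=4)
assert y.shape == [2, 16, 8, 8]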
def execute(self, x):
    N, C, H, W = x.shape
    Kh, Kw = self.kernel_size
    oh = (H + self.padding[0] * 2 - Kh * self.dilation[0] +
          self.dilation[0] - 1) // self.stride[0] + 1
    ow = (W + self.padding[1] * 2 - Kw * self.dilation[1] +
          self.dilation[1] - 1) // self.stride[1] + 1
    x = jt.reshape(x, [N, self.groups, self.group_channel_in, H, W])
    xx = x.reindex(
        [N, self.groups, self.group_channel_out, self.group_channel_in,
         oh, ow, Kh, Kw],
        [
            'i0',  # Nid
            'i1',  # Group
            'i3',  # Cid
            f'i4*{self.stride[0]}-{self.padding[0]}+i6*{self.dilation[0]}',  # Hid+Khid
            f'i5*{self.stride[1]}-{self.padding[1]}+i7*{self.dilation[1]}',  # Wid+Kwid
        ])
    ww = self.weight.broadcast(xx.shape, [0, 4, 5])
    yy = xx * ww
    y = yy.sum([3, 6, 7])  # Kc, Kh, Kw
    y = jt.reshape(y, [N, self.out_channels, oh, ow])
    if self.bias is not None:
        b = self.bias.broadcast(y.shape, [0, 2, 3])
        y = y + b
    return y
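# Worked example of the output-size formula above (standard dilated-conv
# arithmetic): with H = 32, Kh = 3, stride = 1, padding = 1, dilation = 1,
#   oh = (32 + 1*2 - 3*1 + 1 - 1) // 1 + 1 = 31 + 1 = 32,
# i.e. a "same"-size convolution, as expected.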
def execute(self, predicted_locs, predicted_scores, boxes, labels):
    """
    Forward propagation.

    Args:
        predicted_locs: predicted locations/boxes w.r.t. the 8732 prior boxes,
            a tensor of dimensions (N, 8732, 4)
        predicted_scores: class scores for each of the encoded
            locations/boxes, a tensor of dimensions (N, 8732, n_classes)
        boxes: true object bounding boxes in boundary coordinates,
            a list of N tensors
        labels: true object labels, a list of N tensors

    Return:
        multibox loss (a scalar), plus its confidence and localization
        components
    """
    batch_size = predicted_locs.shape[0]
    n_priors = self.priors_cxcy.shape[0]
    n_classes = predicted_scores.shape[2]
    assert n_priors == predicted_locs.shape[1] == predicted_scores.shape[1]
    true_locs = np.zeros((batch_size, n_priors, 4))
    true_classes = np.zeros((batch_size, n_priors))
    for i in range(batch_size):
        # Step 1: Select one object for every prior.
        # Step 2: Select one prior for every object, and set its IoU to 1.
        # Step 3: Mark priors whose IoU is below the threshold (e.g. 0.5)
        #         as background.
        n_objects = boxes[i].shape[0]
        overlap = find_jaccard_overlap(boxes[i], self.priors_xy)
        object_for_each_prior, overlap_for_each_prior = argmax(overlap, axis=0)
        prior_for_each_object, _ = argmax(overlap, axis=1)
        object_for_each_prior[prior_for_each_object] = range(n_objects)
        overlap_for_each_prior[prior_for_each_object] = 1.0
        label_for_each_prior = labels[i][object_for_each_prior]
        label_for_each_prior[overlap_for_each_prior < self.threshold] = 0
        true_classes[i] = label_for_each_prior
        true_locs[i] = cxcy_to_gcxgcy(
            xy_to_cxcy(boxes[i][object_for_each_prior]), self.priors_cxcy)
    true_classes = jt.array(true_classes).float32().stop_grad()
    true_locs = jt.array(true_locs).float32().stop_grad()
    positive_priors = (true_classes != 0)

    # Localization loss, averaged over positive priors only.
    loc_loss = self.smooth_l1(
        predicted_locs * positive_priors.broadcast([1, 1, 4], [2]),
        true_locs * positive_priors.broadcast([1, 1, 4], [2]))
    loc_loss /= positive_priors.float32().sum() * 4

    # Confidence loss with hard negative mining.
    n_positives = positive_priors.float32().sum(1)
    n_hard_negatives = self.neg_pos_ratio * n_positives
    conf_loss_all = self.cross_entropy(
        jt.reshape(predicted_scores, [-1, n_classes]),
        jt.reshape(true_classes, [-1]))
    conf_loss_all = jt.reshape(conf_loss_all, [batch_size, n_priors])
    conf_loss_pos = conf_loss_all * positive_priors
    conf_loss_neg = conf_loss_all * (1 - positive_priors)
    _, conf_loss_neg = conf_loss_neg.argsort(dim=1, descending=True)
    hardness_ranks = jt.array(range(n_priors)).broadcast(
        [conf_loss_neg.shape[0], conf_loss_neg.shape[1]], [0])
    hard_negatives = hardness_ranks < n_hard_negatives.broadcast(
        hardness_ranks.shape, [1])
    conf_loss_hard_neg = conf_loss_neg * hard_negatives
    conf_loss = (conf_loss_hard_neg.sum() +
                 conf_loss_pos.sum()) / n_positives.float32().sum()

    return conf_loss + self.alpha * loc_loss, conf_loss, loc_loss
def execute(self, x):
    ax = self.encoder(x)
    ax = jt.reshape(ax, [ax.shape[0], self.feat_dim])
    f1 = self.fc1(ax)
    f1 = self.relu(f1)
    f2 = self.fc2(f1)
    f2 = jt.reshape(f2, [f2.shape[0], 512, self.rh, self.rw])
    y = self.decoder(f2)
    return y
def execute(self, in_feat):
    z_img = self.model(in_feat)
    z = jt.reshape(z_img, [z_img.shape[0], -1])
    zn = z[:, 0:self.latent_dim]
    zc_logits = z[:, self.latent_dim:]
    zc = nn.softmax(zc_logits, dim=1)
    return zn, zc, zc_logits
def execute(self, img):
    img_flat = jt.reshape(img, [img.shape[0], -1])
    x = self.model(img_flat)
    mu = self.mu(x)
    logvar = self.logvar(x)
    z = reparameterization(mu, logvar)
    return z
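# `reparameterization` is defined elsewhere in the repo; below is a minimal
# sketch of the standard VAE reparameterization trick it presumably
# implements (the exact signature is an assumption):
import jittor as jt

def reparameterization(mu, logvar):
    std = jt.exp(0.5 * logvar)                       # logvar = log(sigma^2)
    eps = jt.random(std.shape, "float32", "normal")  # eps ~ N(0, I)
    return mu + eps * std                            # z = mu + sigma * eps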
def execute(self, img):
    out = self.feature_extractor(img)
    out = self.pooling(out)
    out = jt.reshape(out, [out.shape[0], -1])
    mu = self.fc_mu(out)
    logvar = self.fc_logvar(out)
    return mu, logvar
def _forward_impl(self, x):
    x = self.conv1(x)
    x = self.bn1(x)
    x = self.relu(x)
    x = self.maxpool(x)
    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)
    x = self.layer4(x)
    # Disabled attention branch:
    # out = self.se(out)
    # x = self.conv2(x)
    # origin = x.size()
    # x = x.flatten(2).transpose(0, 2, 1)
    # x, out_weights = self.at(x)
    # x = x.transpose(0, 2, 1).reshape(origin)
    # x = self.conv3(x)
    x = self.avgpool(x)
    x = jt.reshape(x, (x.shape[0], -1))
    x = self.fc(x)
    return x
def _forward(self, x):
    x = self.Conv2d_1a_3x3(x)
    x = self.Conv2d_2a_3x3(x)
    x = self.Conv2d_2b_3x3(x)
    x = nn.pool(x, 3, "maximum", stride=2)
    x = self.Conv2d_3b_1x1(x)
    x = self.Conv2d_4a_3x3(x)
    x = nn.pool(x, 3, "maximum", stride=2)
    x = self.Mixed_5b(x)
    x = self.Mixed_5c(x)
    x = self.Mixed_5d(x)
    x = self.Mixed_6a(x)
    x = self.Mixed_6b(x)
    x = self.Mixed_6c(x)
    x = self.Mixed_6d(x)
    x = self.Mixed_6e(x)
    aux_defined = self.aux_logits
    if aux_defined:
        aux = self.AuxLogits(x)
    else:
        aux = None
    x = self.Mixed_7a(x)
    x = self.Mixed_7b(x)
    x = self.Mixed_7c(x)
    x = nn.AdaptiveAvgPool2d(1)(x)
    x = nn.Dropout()(x)
    x = jt.reshape(x, (x.shape[0], -1))
    x = self.fc(x)
    return x, aux
def _forward(self, x):
    x = self.conv1(x)
    x = self.maxpool1(x)
    x = self.conv2(x)
    x = self.conv3(x)
    x = self.maxpool2(x)
    x = self.inception3a(x)
    x = self.inception3b(x)
    x = self.maxpool3(x)
    x = self.inception4a(x)
    aux1 = None  # stays None when the aux classifier is disabled
    if self.aux1 is not None:
        aux1 = self.aux1(x)
    x = self.inception4b(x)
    x = self.inception4c(x)
    x = self.inception4d(x)
    aux2 = None
    if self.aux2 is not None:
        aux2 = self.aux2(x)
    x = self.inception4e(x)
    x = self.maxpool4(x)
    x = self.inception5a(x)
    x = self.inception5b(x)
    x = self.avgpool(x)
    x = jt.reshape(x, (x.shape[0], -1))
    x = self.dropout(x)
    x = self.fc(x)
    return x, aux2, aux1
def execute(self, x):
    x = nn.AdaptiveAvgPool2d(4)(x)
    x = self.conv(x)
    x = jt.reshape(x, (x.shape[0], -1))
    x = nn.relu(self.fc1(x))
    x = nn.Dropout(0.7)(x)
    x = self.fc2(x)
    return x
def execute(self, x):
    x = nn.pool(x, kernel_size=5, op="mean", stride=3)
    x = self.conv0(x)
    x = self.conv1(x)
    x = nn.AdaptiveAvgPool2d(1)(x)
    x = jt.reshape(x, (x.shape[0], -1))
    x = self.fc(x)
    return x
def execute(self, feature, input_part):
    # input2vector
    feature_vector = self.net_encoder(input_part)
    self.feature_vector = feature_vector
    fake_part = self.net_decoder(feature_vector)
    print(jt.reshape(self.feature_vector, (1, 512)))
    loss = self.criterion(fake_part, input_part.detach()) * 10
    loss = loss.reshape(1)
    return fake_part, self.loss_filter(loss)
def execute(self, ten):
    ten = self.fc(ten)
    ten = jt.reshape(ten, (ten.size()[0], 512,
                           self.latent_size, self.latent_size))
    ten = self.conv(ten)
    return ten
def execute(self, zn, zc):
    z = jt.contrib.concat([zn, zc], dim=1)
    x_gen = self.model0(z)
    x_gen = self.model1(x_gen)
    x_gen = self.model2(x_gen)
    x_gen = self.model3(x_gen)
    x_gen = self.model4(x_gen)
    x_gen = self.sigmoid(x_gen)
    x_gen = jt.reshape(x_gen, [x_gen.shape[0], *self.x_shape])
    return x_gen
def execute(self, ten):
    ten = self.conv(ten)
    ten = jt.reshape(ten, [ten.size()[0], -1])
    mu = self.fc_mu(ten)
    # Variance head currently disabled:
    # logvar = self.fc_var(ten)
    return mu  # , logvar
def execute(self, x):
    out = nn.relu(self.bn1(self.conv1(x)))
    out = self.layer1(out)
    out = self.layer2(out)
    out = self.layer3(out)
    out = self.layer4(out)
    out = nn.pool(out, size=4, op="mean", padding=0)
    out = jt.reshape(out, [out.shape[0], -1])
    out = self.linear1(out)
    out = self.linear2(out)
    return out
def calc_gradient_penalty(netD, real_data, generated_data):
    LAMBDA = 10
    b_size = real_data.shape[0]
    # Sample one interpolation coefficient per example and mix real/fake.
    alpha = jt.random([b_size, 1, 1, 1])
    alpha = alpha.broadcast(real_data)
    interpolated = alpha * real_data.data + (1 - alpha) * generated_data.data
    prob_interpolated = netD(interpolated)
    # Penalize deviations of the critic's gradient norm from 1 (WGAN-GP).
    gradients = jt.grad(prob_interpolated, interpolated)
    gradients = jt.reshape(gradients, [b_size, -1])
    gradients_norm = jt.sqrt(jt.sum(gradients ** 2, dim=1) + 1e-12)
    return LAMBDA * ((gradients_norm - 1) ** 2).mean()
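# Hedged usage sketch with a stand-in critic (real code would pass the
# project's discriminator as `netD`); the returned scalar is added to the
# critic's loss:
import jittor as jt

netD = lambda v: jt.reshape(v, [v.shape[0], -1]).sum(1)  # toy critic
real = jt.random([4, 3, 32, 32])
fake = jt.random([4, 3, 32, 32])
gp = calc_gradient_penalty(netD, real, fake)  # scalar, weighted by LAMBDA=10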
def run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn,
                netchunk=1024 * 64):
    """Prepares inputs and applies network 'fn'."""
    inputs_flat = jt.reshape(inputs, [-1, inputs.shape[-1]])
    embedded = embed_fn(inputs_flat)
    if viewdirs is not None:
        input_dirs = viewdirs[:, None].expand(inputs.shape)
        input_dirs_flat = jt.reshape(input_dirs, [-1, input_dirs.shape[-1]])
        embedded_dirs = embeddirs_fn(input_dirs_flat)
        embedded = jt.concat([embedded, embedded_dirs], -1)
    outputs_flat = batchify(fn, netchunk)(embedded)
    outputs = jt.reshape(outputs_flat,
                         list(inputs.shape[:-1]) + [outputs_flat.shape[-1]])
    return outputs
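# `batchify` is not shown in this section; below is a sketch matching the
# common NeRF reference implementation (treat this as an assumption about the
# repo's actual helper): it wraps `fn` so inputs are processed in
# `chunk`-sized minibatches, keeping peak memory bounded.
def batchify(fn, chunk):
    if chunk is None:
        return fn
    def ret(inputs):
        return jt.concat([fn(inputs[i:i + chunk])
                          for i in range(0, inputs.shape[0], chunk)], 0)
    return ret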
def _forward_impl(self, x):
    x = self.conv1(x)
    x = self.bn1(x)
    x = self.relu(x)
    x = self.maxpool(x)
    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)
    x = self.layer4(x)
    x = self.avgpool(x)
    x = jt.reshape(x, (x.shape[0], -1))
    x = self.fc(x)
    return x
def execute(self, x):
    x = self.conv1(x)
    x = self.relu(x)
    x = self.conv2(x)
    x = self.bn(x)
    x = self.relu(x)
    x = self.max_pool(x)
    x = jt.reshape(x, [x.shape[0], -1])
    x = self.fc1(x)
    x = self.relu(x)
    x = self.fc2(x)
    return x
def execute(self, x):
    x = self.conv1(x)
    x = self.bn1(x)
    x = self.relu(x)
    x = self.maxpool(x)
    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)
    x = self.layer4(x)
    x = self.avgpool(x)
    x = jt.reshape(x, [x.shape[0], -1])
    x = self.fc(x)
    return x
def execute(self, x, z):
    z = self.fc(z)
    z = jt.reshape(z, [z.shape[0], 1, self.h, self.w])
    d1 = self.down1(jt.contrib.concat([x, z], dim=1))
    d2 = self.down2(d1)
    d3 = self.down3(d2)
    d4 = self.down4(d3)
    d5 = self.down5(d4)
    d6 = self.down6(d5)
    d7 = self.down7(d6)
    u1 = self.up1(d7, d6)
    u2 = self.up2(u1, d5)
    u3 = self.up3(u2, d4)
    u4 = self.up4(u3, d3)
    u5 = self.up5(u4, d2)
    u6 = self.up6(u5, d1)
    return self.final(u6)
def execute(self, x):
    x = self.features(x)
    x = jt.reshape(x, [x.shape[0], -1])
    x = self.classifier(x)
    return x
def render_rays(ray_batch,
                network_fn,
                network_query_fn,
                N_samples,
                retraw=False,
                lindisp=False,
                perturb=0.,
                N_importance=0,
                network_fine=None,
                white_bkgd=False,
                raw_noise_std=0.,
                verbose=False):
    """Volumetric rendering.

    Args:
        ray_batch: array of shape [batch_size, ...]. All information necessary
            for sampling along a ray, including: ray origin, ray direction,
            min dist, max dist, and unit-magnitude viewing direction.
        network_fn: function. Model for predicting RGB and density at each
            point in space.
        network_query_fn: function used for passing queries to network_fn.
        N_samples: int. Number of different times to sample along each ray.
        retraw: bool. If True, include model's raw, unprocessed predictions.
        lindisp: bool. If True, sample linearly in inverse depth rather than
            in depth.
        perturb: float, 0 or 1. If non-zero, each ray is sampled at stratified
            random points in time.
        N_importance: int. Number of additional times to sample along each
            ray. These samples are only passed to network_fine.
        network_fine: "fine" network with same spec as network_fn.
        white_bkgd: bool. If True, assume a white background.
        raw_noise_std: ...
        verbose: bool. If True, print more debugging info.

    Returns:
        rgb_map: [num_rays, 3]. Estimated RGB color of a ray. Comes from the
            fine model.
        disp_map: [num_rays]. Disparity map. 1 / depth.
        acc_map: [num_rays]. Accumulated opacity along each ray. Comes from
            the fine model.
        raw: [num_rays, num_samples, 4]. Raw predictions from model.
        rgb0: See rgb_map. Output for coarse model.
        disp0: See disp_map. Output for coarse model.
        acc0: See acc_map. Output for coarse model.
        z_std: [num_rays]. Standard deviation of distances along ray for each
            sample.
    """
    N_rays = ray_batch.shape[0]
    rays_o, rays_d = ray_batch[:, 0:3], ray_batch[:, 3:6]  # [N_rays, 3] each
    viewdirs = ray_batch[:, -3:] if ray_batch.shape[-1] > 8 else None
    bounds = jt.reshape(ray_batch[..., 6:8], [-1, 1, 2])
    near, far = bounds[..., 0], bounds[..., 1]  # [-1, 1]

    z_vals = sample(N_rays, N_samples, lindisp, perturb, near, far)
    pts = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * z_vals.unsqueeze(-1)
    # [N_rays, N_samples, 3]

    raw = network_query_fn(pts, viewdirs, network_fn)
    rgb_map, disp_map, acc_map, weights, depth_map = integrator(
        raw, z_vals, rays_d, raw_noise_std, white_bkgd)
    rgb_map_0, disp_map_0, acc_map_0 = rgb_map, disp_map, acc_map
    # usefulRayIndex = jt.nonzero(acc_map > 0.1)

    if N_importance > 0:  # importance sampling
        z_vals_mid = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
        z_samples = sample_pdf(z_vals_mid,
                               weights[..., 1:-1],
                               N_importance,
                               det=(perturb == 0.))
        z_samples = z_samples.detach()
        _, z_vals = jt.argsort(jt.concat([z_vals, z_samples], -1), -1)
        pts = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * z_vals.unsqueeze(-1)
        # [N_rays, N_samples + N_importance, 3]
        run_fn = network_fn if network_fine is None else network_fine
        raw = network_query_fn(pts, viewdirs, run_fn)
        rgb_map, disp_map, acc_map, weights, depth_map = integrator(
            raw, z_vals, rays_d, raw_noise_std, white_bkgd)

    ret = {'rgb_map': rgb_map, 'disp_map': disp_map, 'acc_map': acc_map}
    if retraw:
        ret['raw'] = raw
    if N_importance > 0:
        ret['rgb0'] = rgb_map_0
        ret['disp0'] = disp_map_0
        ret['acc0'] = acc_map_0
    return ret
def render(H,
           W,
           focal,
           chunk=1024 * 32,
           rays=None,
           c2w=None,
           intrinsic=None,
           ndc=True,
           near=0.,
           far=1.,
           use_viewdirs=False,
           c2w_staticcam=None,
           **kwargs):
    """Render rays.

    Args:
        H: int. Height of image in pixels.
        W: int. Width of image in pixels.
        focal: float. Focal length of pinhole camera.
        chunk: int. Maximum number of rays to process simultaneously. Used to
            control maximum memory usage. Does not affect final results.
        rays: array of shape [2, batch_size, 3]. Ray origin and direction for
            each example in batch.
        c2w: array of shape [3, 4]. Camera-to-world transformation matrix.
        ndc: bool. If True, represent ray origin, direction in NDC coordinates.
        near: float or array of shape [batch_size]. Nearest distance for a ray.
        far: float or array of shape [batch_size]. Farthest distance for a ray.
        use_viewdirs: bool. If True, use viewing direction of a point in space
            in model.
        c2w_staticcam: array of shape [3, 4]. If not None, use this
            transformation matrix for camera while using other c2w argument
            for viewing directions.

    Returns:
        rgb_map: [batch_size, 3]. Predicted RGB values for rays.
        disp_map: [batch_size]. Disparity map. Inverse of depth.
        acc_map: [batch_size]. Accumulated opacity (alpha) along a ray.
        extras: dict with everything returned by render_rays().
    """
    if c2w is not None:
        # Special case: render the full image.
        rays_o, rays_d = pinhole_get_rays(H, W, focal, c2w, intrinsic)
    else:
        # Use the provided ray batch.
        rays_o, rays_d = rays

    if use_viewdirs:
        # Provide ray directions as input.
        viewdirs = rays_d
        if c2w_staticcam is not None:
            assert intrinsic is None
            rays_o, rays_d = pinhole_get_rays(H, W, focal, c2w_staticcam)
        viewdirs = viewdirs / jt.norm(viewdirs, p=2, dim=-1, keepdim=True)
        viewdirs = jt.reshape(viewdirs, [-1, 3]).float()

    sh = rays_d.shape  # [..., 3]
    if ndc:
        # For forward-facing scenes.
        rays_o, rays_d = ndc_rays(H, W, focal, 1., rays_o, rays_d)

    # Create ray batch.
    rays_o = jt.reshape(rays_o, [-1, 3]).float()
    rays_d = jt.reshape(rays_d, [-1, 3]).float()
    near = near * jt.ones_like(rays_d[..., :1])
    far = far * jt.ones_like(rays_d[..., :1])
    rays = jt.concat([rays_o, rays_d, near, far], -1)
    if use_viewdirs:
        rays = jt.concat([rays, viewdirs], -1)

    # Render and reshape.
    all_ret = batchify_rays(rays, chunk, **kwargs)
    for k in all_ret:
        k_sh = list(sh[:-1]) + list(all_ret[k].shape[1:])
        all_ret[k] = jt.reshape(all_ret[k], k_sh)

    k_extract = ['rgb_map', 'disp_map', 'acc_map']
    ret_list = [all_ret[k] for k in k_extract]
    ret_dict = {k: all_ret[k] for k in all_ret if k not in k_extract}
    return ret_list + [ret_dict]
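# `batchify_rays` is likewise defined elsewhere; a hedged sketch following the
# common NeRF reference implementation (names here are assumptions): render
# rays in `chunk`-sized slices via render_rays and concatenate the per-key
# outputs.
def batchify_rays(rays_flat, chunk=1024 * 32, **kwargs):
    all_ret = {}
    for i in range(0, rays_flat.shape[0], chunk):
        ret = render_rays(rays_flat[i:i + chunk], **kwargs)
        for k in ret:
            all_ret.setdefault(k, []).append(ret[k])
    return {k: jt.concat(all_ret[k], 0) for k in all_ret}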
def execute(self, x):
    x = self.features(x)
    x = self.avgpool(x)
    x = jt.reshape(x, (x.shape[0], -1))
    x = self.classifier(x)
    return x
def execute(self, x):
    return jt.reshape(x, [x.shape[0], *self.shape])
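# Hedged usage sketch: this reads as a Reshape layer whose constructor stores
# the per-example target shape in `self.shape` (constructor not shown; the
# class name and call below are assumptions):
# layer = Reshape((512, 4, 4))
# y = layer(x)   # x: [N, 8192] -> y: [N, 512, 4, 4]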
def execute(self, conv4_3_feats, conv7_feats, conv8_2_feats, conv9_2_feats,
            conv10_2_feats, conv11_2_feats):
    """
    Forward propagation.

    Args:
        conv4_3_feats: conv4_3 feature map, an array of dimensions (N, 512, 38, 38)
        conv7_feats: conv7 feature map, an array of dimensions (N, 1024, 19, 19)
        conv8_2_feats: conv8_2 feature map, an array of dimensions (N, 512, 10, 10)
        conv9_2_feats: conv9_2 feature map, an array of dimensions (N, 256, 5, 5)
        conv10_2_feats: conv10_2 feature map, an array of dimensions (N, 256, 3, 3)
        conv11_2_feats: conv11_2 feature map, an array of dimensions (N, 256, 1, 1)

    Return:
        8732 locations and class scores (i.e. w.r.t. each prior box) for each
        image
    """
    batch_size = conv4_3_feats.shape[0]

    # Localization predictions: (N, C, H, W) -> (N, H, W, C) -> (N, -1, 4).
    l_conv4_3 = self.loc_conv4_3(conv4_3_feats)
    l_conv4_3 = jt.transpose(l_conv4_3, [0, 2, 3, 1])
    l_conv4_3 = jt.reshape(l_conv4_3, [batch_size, -1, 4])
    l_conv7 = self.loc_conv7(conv7_feats)
    l_conv7 = jt.transpose(l_conv7, [0, 2, 3, 1])
    l_conv7 = jt.reshape(l_conv7, [batch_size, -1, 4])
    l_conv8_2 = self.loc_conv8_2(conv8_2_feats)
    l_conv8_2 = jt.transpose(l_conv8_2, [0, 2, 3, 1])
    l_conv8_2 = jt.reshape(l_conv8_2, [batch_size, -1, 4])
    l_conv9_2 = self.loc_conv9_2(conv9_2_feats)
    l_conv9_2 = jt.transpose(l_conv9_2, [0, 2, 3, 1])
    l_conv9_2 = jt.reshape(l_conv9_2, [batch_size, -1, 4])
    l_conv10_2 = self.loc_conv10_2(conv10_2_feats)
    l_conv10_2 = jt.transpose(l_conv10_2, [0, 2, 3, 1])
    l_conv10_2 = jt.reshape(l_conv10_2, [batch_size, -1, 4])
    l_conv11_2 = self.loc_conv11_2(conv11_2_feats)
    l_conv11_2 = jt.transpose(l_conv11_2, [0, 2, 3, 1])
    l_conv11_2 = jt.reshape(l_conv11_2, [batch_size, -1, 4])

    # Class predictions: (N, C, H, W) -> (N, H, W, C) -> (N, -1, n_classes).
    c_conv4_3 = self.cl_conv4_3(conv4_3_feats)
    c_conv4_3 = jt.transpose(c_conv4_3, [0, 2, 3, 1])
    c_conv4_3 = jt.reshape(c_conv4_3, [batch_size, -1, self.n_classes])
    c_conv7 = self.cl_conv7(conv7_feats)
    c_conv7 = jt.transpose(c_conv7, [0, 2, 3, 1])
    c_conv7 = jt.reshape(c_conv7, [batch_size, -1, self.n_classes])
    c_conv8_2 = self.cl_conv8_2(conv8_2_feats)
    c_conv8_2 = jt.transpose(c_conv8_2, [0, 2, 3, 1])
    c_conv8_2 = jt.reshape(c_conv8_2, [batch_size, -1, self.n_classes])
    c_conv9_2 = self.cl_conv9_2(conv9_2_feats)
    c_conv9_2 = jt.transpose(c_conv9_2, [0, 2, 3, 1])
    c_conv9_2 = jt.reshape(c_conv9_2, [batch_size, -1, self.n_classes])
    c_conv10_2 = self.cl_conv10_2(conv10_2_feats)
    c_conv10_2 = jt.transpose(c_conv10_2, [0, 2, 3, 1])
    c_conv10_2 = jt.reshape(c_conv10_2, [batch_size, -1, self.n_classes])
    c_conv11_2 = self.cl_conv11_2(conv11_2_feats)
    c_conv11_2 = jt.transpose(c_conv11_2, [0, 2, 3, 1])
    c_conv11_2 = jt.reshape(c_conv11_2, [batch_size, -1, self.n_classes])

    locs = jt.contrib.concat(
        [l_conv4_3, l_conv7, l_conv8_2, l_conv9_2, l_conv10_2, l_conv11_2],
        dim=1)
    classes_scores = jt.contrib.concat(
        [c_conv4_3, c_conv7, c_conv8_2, c_conv9_2, c_conv10_2, c_conv11_2],
        dim=1)
    return locs, classes_scores