def make_base_grid_5D(theta, N, C, D, H, W, align_corners):
    """Build the homogeneous base sampling grid for a 5-D affine grid.

    The result has shape (N, D, H, W, 4) and the dtype of ``theta``:
      * channel 0 holds the W-axis coordinates (broadcast over N, D, H),
      * channel 1 holds the H-axis coordinates (shaped (H, 1) before broadcast),
      * channel 2 holds the D-axis coordinates (shaped (D, 1, 1) before broadcast),
      * channel 3 is the constant 1 (homogeneous coordinate).

    Coordinates come from ``linspace_from_neg_one`` (defined elsewhere;
    presumably evenly spaced values in [-1, 1] honouring ``align_corners`` --
    TODO confirm). ``C`` is accepted for signature compatibility but unused.
    """
    grid = jt.zeros((N, D, H, W, 4), dtype=theta.dtype)

    along_w = linspace_from_neg_one(theta, W, align_corners)
    grid[..., 0] = along_w

    # Trailing singleton axis so the H coordinates broadcast across W.
    along_h = jt.unsqueeze(linspace_from_neg_one(theta, H, align_corners), -1)
    grid[..., 1] = along_h

    # Two trailing singleton axes so the D coordinates broadcast across H, W.
    along_d = linspace_from_neg_one(theta, D, align_corners)
    along_d = jt.unsqueeze(jt.unsqueeze(along_d, -1), -1)
    grid[..., 2] = along_d

    grid[..., -1] = 1
    return grid
def get_shadow_image(self, mus_mouth, weight, nearnN):
    """Blend decoded neighbour images into a single uint8 "shadow" image.

    For each of the first ``nearnN`` rows of ``mus_mouth`` whose weight is
    positive, the row is decoded with ``self.net_decoder``. The value
    ``(1 - image) / 2`` is then accumulated, scaled by that weight clamped
    to at most 0.5. Finally the sum is subtracted from 1 and the first batch
    element is taken. Its axes are reordered with ``(1, 2, 0)``, and it is
    scaled by 255, clipped to [0, 255] and cast to uint8.

    Args:
        mus_mouth: indexable as ``mus_mouth[[i], :]``; presumably a
            (num_neighbours, latent_dim) jittor Var -- TODO confirm.
        weight: per-neighbour blending weights, indexed by ``i``. Weights
            <= 0 are skipped; weights > 0.5 are clamped to 0.5.
        nearnN: number of leading neighbours to consider.

    Returns:
        np.ndarray (uint8). Assuming the decoder output is (batch, C, H, W),
        this is the (H, W, C) image of batch element 0 -- TODO confirm.

    Raises:
        ValueError: if none of the first ``nearnN`` weights is positive.
            The original code crashed here with an opaque TypeError,
            because it tried to index the integer placeholder 0.
    """
    fakes = None
    for i in range(nearnN):
        w_i = weight[i]
        if w_i <= 0:
            continue  # non-positive weights contribute nothing
        if w_i > 0.5:
            w_i = 0.5  # cap any single neighbour's influence
        mus_vec = jt.unsqueeze(mus_mouth[[i], :], 1)
        fake_image = self.net_decoder(jt.array(mus_vec))
        # Presumably the decoder emits values in [-1, 1], so (1 - x) / 2 is a
        # [0, 1] "darkness" where white maps to 0 -- TODO confirm.
        contribution = (1 - fake_image) / 2 * w_i
        fakes = contribution if fakes is None else fakes + contribution

    if fakes is None:
        raise ValueError(
            "get_shadow_image: no positive weight among the first "
            f"{nearnN} neighbours; nothing to blend")

    fakes = 1 - fakes
    fakes = fakes[0, :, :, :].detach().numpy()
    fakes = np.transpose(fakes, (1, 2, 0)) * 255.0
    fakes = np.clip(fakes, 0, 255)
    return fakes.astype(np.uint8)
def execute(self, x):
    """Apply XConv to a batch of local point neighbourhoods.

    :param x: tuple ``(rep_pt, pts, fts)`` where
        - rep_pt: representative points, shape (N, P, dims).
        - pts: regional point cloud, shape (N, P, K, dims); pts[:, p_idx]
          is the neighbourhood of rep_pt[:, p_idx].
        - fts: regional features aligned with ``pts``, shape
          (N, P, K, C_in), or None when there are no input features.
    :return: features aggregated into each representative point. The final
        shape depends on ``self.end_conv`` followed by ``squeeze(dim=2)``
        (presumably (N, P, C_out) -- TODO confirm).
    """
    rep_pt, pts, fts = x

    # Shape sanity checks (note: asserts are stripped under ``python -O``).
    if fts is not None:
        assert rep_pt.size()[0] == pts.size()[0] == fts.size()[0]  # N equal
        assert rep_pt.size()[1] == pts.size()[1] == fts.size()[1]  # P equal
        assert pts.size()[2] == fts.size()[2] == self.K  # K equal
        assert fts.size()[3] == self.cin  # C_in equal
    else:
        assert rep_pt.size()[0] == pts.size()[0]  # N equal
        assert rep_pt.size()[1] == pts.size()[1]  # P equal
        assert pts.size()[2] == self.K  # K equal
    assert rep_pt.size()[2] == pts.size()[3] == self.dims  # dims equal

    N = pts.size()[0]
    P = rep_pt.size()[1]

    # Move pts into the local coordinate system of their representative point.
    p_center = jt.unsqueeze(rep_pt, dim=2)  # (N, P, 1, dims)
    pts_local = pts - p_center.repeat(1, 1, self.K, 1)  # (N, P, K, dims)
    pts_local = pts_local.permute(0, 3, 1, 2)  # (N, dims, P, K)

    # Individually lift each local point into C_mid space.
    fts_lifted = self.dense2(self.dense1(pts_local))  # (N, C_mid, P, K)

    if fts is None:
        fts_cat = fts_lifted
    else:
        # Permute only when features are present: fts may legitimately be
        # None, and calling .permute on it would raise AttributeError.
        fts_cat = concat((fts_lifted, fts.permute(0, 3, 1, 2)), 1)  # (N, C_mid + C_in, P, K)

    # Learn the per-point (K, K) X-transformation matrix from local coords.
    h = self.x_trans_0(pts_local)
    h = self.x_trans_1(h)
    X = self.x_trans_2(h)
    X = X.permute(0, 2, 3, 1).view((N, P, self.K, self.K))  # (N, P, K, K)

    fts_cat = fts_cat.permute(0, 2, 3, 1)  # (N, P, K, C_mid + C_in)
    fts_X = jt.matmul(X, fts_cat)  # (N, P, K, C_mid + C_in)
    return self.end_conv(fts_X).squeeze(dim=2)
def show_jitor(tensor):
    """Show element ``[0][0]`` of ``tensor`` as a grayscale matplotlib image.

    ``tensor.data`` is converted with ``np.array``. The plotted plane is
    presumably batch 0, channel 0 of an (N, C, H, W) image -- TODO confirm.
    """
    first_plane = np.array(tensor.data)[0][0]
    plt.imshow(first_plane, cmap='gray')
    plt.show()


if __name__ == '__main__':
    # Demo: reconstruct one 'nose' sample with a pretrained CE model and
    # display the input, then the generated output; finally print the losses.
    jt.flags.use_cuda = 1
    opt = TrainOptions().parse()
    feature = 'nose'

    data_loader = SingelDataLoader()
    data_loader.initialize('dataset', True, 'nose')

    model = CE_Model(opt, feature)
    model.initialize(opt, feature)
    encoder_ckpt = '/home/loc/face_psych/Params/CE_model/nose_encoder.pkl'
    decoder_ckpt = '/home/loc/face_psych/Params/CE_model/nose_decoder.pkl'
    # Alternative (AE_whole) checkpoints:
    #   /home/loc/Desktop/DeepFaceDrawing-Jittor/Params/AE_whole/latest_net_encoder_nose.pkl
    #   /home/loc/Desktop/DeepFaceDrawing-Jittor/Params/AE_whole/latest_net_decoder_nose_image.pkl
    model.load_networ_from_file(encoder_ckpt, decoder_ckpt)

    # Prepend a leading axis (presumably the batch axis) to the sample.
    sample = jt.unsqueeze(data_loader.get_data(10), 0)
    show_jitor(sample)
    generated, losses = model(feature, sample)
    show_jitor(generated)
    print(losses)
def getonehot(outputs, classes, batch_size):
    """Return a (batch_size, classes) one-hot tensor of ``outputs``' argmax.

    The arg-max is taken along dim 1 of ``outputs``. The ``[0]`` selects the
    first element of ``jt.argmax``'s result, presumably the index part of
    Jittor's (index, value) pair. A 1 is then scattered into a zero tensor
    at those column positions.
    """
    best_class = jt.argmax(outputs, 1)[0]
    column_index = jt.unsqueeze(best_class, 1)  # (batch, 1) scatter index
    result = jt.zeros([batch_size, classes])
    result.scatter_(1, column_index, jt.array(1.))
    return result