def build_network(self):
    """Create the network, optionally load weights, and store it on ``self``.

    Refer to ``auxiliary.model`` for the architecture.

    Builds a ``model.AE_AtlasNet_Humans`` configured from ``self.opt`` and
    ``self.dataset_train``. It moves the network to the GPU and initialises its
    weights with ``my_utils.weights_init``. Weights are then loaded as follows:

    * from ``self.opt.model`` when it is set. This is best effort: on failure
      the error is printed and the freshly initialised weights are kept.
    * from ``<self.opt.dir_name>/network.pth`` when ``self.opt.reload`` is
      truthy. This overwrites ``self.opt.model`` with that path, and any
      loading error propagates to the caller.

    Finally the network is stored as ``self.network`` and switched to eval
    mode. It is then asked to write its template via
    ``save_template_png(self.opt.dir_name)``.

    :return: None
    """
    network = model.AE_AtlasNet_Humans(
        point_translation=self.opt.point_translation,
        dim_template=self.opt.dim_template,
        patch_deformation=self.opt.patch_deformation,
        dim_out_patch=self.opt.dim_out_patch,
        start_from=self.opt.start_from,
        dataset_train=self.dataset_train,
    )
    network.cuda()  # put network on GPU
    network.apply(my_utils.weights_init)  # initialization of the weights

    if self.opt.model:
        try:
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            network.load_state_dict(torch.load(self.opt.model))
        except Exception as err:
            # Best effort: keep the initialised weights, but report *why* the
            # load failed instead of silently swallowing every exception
            # (a bare ``except:`` would also trap KeyboardInterrupt).
            print("Failed to reload ", self.opt.model, ":", err)
        else:
            print(" Previous network weights loaded! From ", self.opt.model)

    if self.opt.reload:
        # Resuming: the checkpoint saved in the run directory wins, and its
        # path is recorded back into the options.
        self.opt.model = os.path.join(self.opt.dir_name, "network.pth")
        print(f"reload model from : {self.opt.model}")
        network.load_state_dict(torch.load(self.opt.model))

    self.network = network
    self.network.eval()
    self.network.save_template_png(self.opt.dir_name)
def __init__(self, HR=0, nepoch=3000, model_path='trained_models/sup_human_network_last.pth', num_points=6890, num_angles=100, clean=1, scale=1, project_on_target=0, save_path=None, LR_input=True):
    """Set up the network, reference templates and color tables.

    :param HR: stored as ``self.HR`` (presumably selects a high-resolution
        mode -- TODO confirm against callers).
    :param nepoch: stored as ``self.nepoch``.
    :param model_path: checkpoint whose state dict is loaded into the
        network. An empty string or None skips loading and keeps the
        ``my_utils.weights_init`` initialisation.
    :param num_points: forwarded to ``model.AE_AtlasNet_Humans``.
    :param num_angles: stored as ``self.num_angles``.
    :param clean: stored as ``self.clean``.
    :param scale: stored as ``self.scale``.
    :param project_on_target: stored as ``self.project_on_target``.
    :param save_path: output directory stored as ``self.save_path``. When
        None, it defaults to ``"./results_HR_truth"``.
    :param LR_input: stored as ``self.LR_input``.
    """
    self.LR_input = LR_input
    self.HR = HR
    self.nepoch = nepoch
    self.model_path = model_path
    self.num_points = num_points
    self.num_angles = num_angles
    self.clean = clean
    self.scale = scale
    self.project_on_target = project_on_target
    self.distChamfer = ext.chamferDist()

    # load network
    self.network = model.AE_AtlasNet_Humans(num_points=self.num_points)
    self.network.cuda()
    self.network.apply(my_utils.weights_init)
    # Truthiness check so that None (not only '') skips loading instead of
    # crashing inside torch.load(None).
    if self.model_path:
        print("Reload weights from : ", self.model_path)
        self.network.load_state_dict(torch.load(self.model_path))
    self.network.eval()

    # scikit-learn >= 1.0 makes these parameters keyword-only; positional
    # arguments raise TypeError there.
    self.neigh = NearestNeighbors(n_neighbors=1, radius=0.4)
    self.mesh_ref = trimesh.load("./data/template/template_dense.ply", process=False)
    self.mesh_ref_LR = trimesh.load("./data/template/template.ply", process=False)

    # load colors
    self.red_LR = np.load("./data/template/red_LR.npy").astype("uint8")
    self.green_LR = np.load("./data/template/green_LR.npy").astype("uint8")
    self.blue_LR = np.load("./data/template/blue_LR.npy").astype("uint8")
    self.red_HR = np.load("./data/template/red_HR.npy").astype("uint8")
    self.green_HR = np.load("./data/template/green_HR.npy").astype("uint8")
    self.blue_HR = np.load("./data/template/blue_HR.npy").astype("uint8")

    # Honour the caller's save_path; previously it was silently overridden.
    self.save_path = save_path if save_path is not None else "./results_HR_truth"