def interpolate(network, interval, experiment_directory, checkpoint, split_file, epoch, resolution, uniform_grid):
    """Extract surfaces along a straight line between two fitted latent codes.

    The first two shapes of the split in ``split_file`` each get a latent code
    fitted with ``optimize_latent`` (800 iterations, lr 5e-3). For ``interval``
    evenly spaced values ``alpha`` in [0, 1], the blended code
    ``(1 - alpha) * latent_1 + alpha * latent_2`` is meshed with
    ``plt.plot_surface``. Output goes to
    ``<experiment_directory>/interpolate/<checkpoint>/<name_1>_and_<name_2>``
    with ``shapename=str(alpha)``.

    Args:
        network: decoder taking ``[latent, xyz]`` rows; switched to eval mode here.
        interval: number of interpolation steps (including both endpoints).
        experiment_directory: experiment root under which ``interpolate/`` is created.
        checkpoint: label used for the output sub-folder.
        split_file: JSON split passed to the dataset class.
        epoch, resolution, uniform_grid: forwarded to ``plt.plot_surface``.

    NOTE(review): dataset settings are read from a module-level ``conf``,
    not a parameter. Confirm it is defined before this is called.
    """
    with open(split_file, "r") as f:
        split = json.load(f)

    ds = utils.get_class(conf.get_string('train.dataset'))(
        split=split,
        dataset_path=conf.get_string('train.dataset_path'),
        with_normals=True)

    points_1, normals_1, _ = ds[0]
    points_2, normals_2, _ = ds[1]

    pnts = torch.cat([points_1, points_2], dim=0).cuda()

    # The folder name identifies both endpoints: shape 0 and shape 1.
    name_1 = '_'.join(ds.get_info(0))
    name_2 = '_'.join(ds.get_info(1))
    name = name_1 + '_and_' + name_2

    interpolate_root = os.path.join(experiment_directory, 'interpolate')
    checkpoint_dir = os.path.join(interpolate_root, str(checkpoint))
    my_path = os.path.join(checkpoint_dir, name)
    # Create every level explicitly (mkdir_ifnotexists may not create parents).
    for directory in (interpolate_root, checkpoint_dir, my_path):
        utils.mkdir_ifnotexists(directory)

    latent_1 = optimize_latent(points_1.cuda(), normals_1.cuda(), conf, 800, network, 5e-3)
    latent_2 = optimize_latent(points_2.cuda(), normals_2.cuda(), conf, 800, network, 5e-3)

    # Prefix each point with latent_1 so the rows match the decoder's input layout.
    pnts = torch.cat([latent_1.repeat(pnts.shape[0], 1), pnts], dim=-1)

    with torch.no_grad():
        network.eval()
        for alpha in np.linspace(0, 1, interval):
            latent = (latent_1 * (1 - alpha)) + (latent_2 * alpha)

            plt.plot_surface(with_points=False,
                             points=pnts,
                             decoder=network,
                             latent=latent,
                             path=my_path,
                             epoch=epoch,
                             shapename=str(alpha),
                             resolution=resolution,
                             mc_value=0,
                             is_uniform_grid=uniform_grid,
                             verbose=True,
                             save_html=False,
                             save_ply=True,
                             overwrite=True,
                             connected=True)
def evaluate(network, experiment_directory, conf, checkpoint, split_file, epoch, resolution, uniform_grid):
    """Fit a latent code to every shape of a split and export its surface.

    Shapes are visited in shuffled order, one at a time. For each shape:
    a latent code is optimised with ``optimize_latent`` (800 iterations,
    lr 5e-3) while ``network`` is in train mode; then, without gradients and
    in eval mode, ``plt.plot_surface`` writes the reconstruction (HTML and PLY)
    to ``<experiment_directory>/evaluation/<checkpoint>``.
    """
    evaluation_root = os.path.join(experiment_directory, 'evaluation')
    out_dir = os.path.join(evaluation_root, str(checkpoint))
    for directory in (evaluation_root, out_dir):
        utils.mkdir_ifnotexists(directory)

    with open(split_file, "r") as split_fh:
        split = json.load(split_fh)

    dataset_cls = utils.get_class(conf.get_string('train.dataset'))
    ds = dataset_cls(split=split,
                     dataset_path=conf.get_string('train.dataset_path'),
                     with_normals=True)

    print("total files : {0}".format(len(ds)))

    loader = torch.utils.data.DataLoader(ds,
                                         batch_size=1,
                                         shuffle=True,
                                         num_workers=1,
                                         drop_last=False,
                                         pin_memory=True)

    for counter, (input_pc, normals, index) in enumerate(loader):
        # Drop the batch dimension of size 1 added by the loader.
        input_pc = input_pc.cuda().squeeze()
        normals = normals.cuda().squeeze()

        print(counter)

        network.train()
        latent = optimize_latent(input_pc, normals, conf, 800, network, lr=5e-3)
        points = torch.cat([latent.repeat(input_pc.shape[0], 1), input_pc], dim=-1)

        shapename = '_'.join(ds.get_info(index))

        with torch.no_grad():
            network.eval()
            plt.plot_surface(with_points=True,
                             points=points,
                             decoder=network,
                             latent=latent,
                             path=out_dir,
                             epoch=epoch,
                             shapename=shapename,
                             resolution=resolution,
                             mc_value=0,
                             is_uniform_grid=uniform_grid,
                             verbose=True,
                             save_html=True,
                             save_ply=True,
                             overwrite=True,
                             connected=True)
def evaluate(**kwargs):
    """Evaluate a trained displacement/appearance model and export the results.

    Steps:
    1. Restore the weights saved for ``kwargs['checkpoint']``.
    2. Export the two meshes returned by ``plt.get_displacement_mesh``.
    3. Optionally (``eval_animation``) save the two arrays returned by
       ``plt.get_displacement_animation``.
    4. Optionally (``eval_rendering``) render every view, saving RGB, diffuse,
       specular and albedo PNGs, then print the mean/std of the masked PSNR.

    Keys read from ``kwargs``: ``conf``, ``exps_folder_name``,
    ``evals_folder_name``, ``eval_rendering``, ``eval_animation``, ``expname``,
    ``scan_id``, ``timestamp``, ``checkpoint``.
    """
    torch.set_default_dtype(torch.float32)

    conf = ConfigFactory.parse_file(kwargs['conf'])
    exps_folder_name = kwargs['exps_folder_name']
    evals_folder_name = kwargs['evals_folder_name']
    eval_rendering = kwargs['eval_rendering']
    eval_animation = kwargs['eval_animation']

    expname = conf.get_string('train.expname') + kwargs['expname']
    # A scan id of -1 means "use the one configured under dataset.scan_id".
    cli_scan_id = kwargs['scan_id']
    scan_id = cli_scan_id if cli_scan_id != -1 else conf.get_int('dataset.scan_id', default=-1)
    if scan_id != -1:
        expname += '_{0}'.format(scan_id)

    timestamp = kwargs['timestamp']
    if timestamp == 'latest':
        runs_dir = os.path.join('../', exps_folder_name, expname)
        candidates = os.listdir(runs_dir) if os.path.exists(runs_dir) else []
        if not candidates:
            print('WRONG EXP FOLDER')
            exit()
        # Presumably run folders are timestamp-named, so the largest name is the newest.
        timestamp = max(candidates)

    utils.mkdir_ifnotexists(os.path.join('../', evals_folder_name))
    expdir = os.path.join('../', exps_folder_name, expname)
    evaldir = os.path.join('../', evals_folder_name, expname)
    utils.mkdir_ifnotexists(evaldir)

    dataset_conf = conf.get_config('dataset')
    model_cls = utils.get_class(conf.get_string('train.model_class'))
    model = model_cls(conf=conf.get_config('model'), id=scan_id, datadir=dataset_conf['data_dir'])
    if torch.cuda.is_available():
        model.cuda()

    if cli_scan_id != -1:
        dataset_conf['scan_id'] = cli_scan_id
    dataset_cls = utils.get_class(conf.get_string('train.dataset_class'))
    eval_dataset = dataset_cls(False, **dataset_conf)

    if eval_rendering:
        eval_dataloader = torch.utils.data.DataLoader(eval_dataset,
                                                      batch_size=1,
                                                      shuffle=False,
                                                      collate_fn=eval_dataset.collate_fn)
        total_pixels = eval_dataset.total_pixels
        img_res = eval_dataset.img_res

    checkpoint_file = os.path.join(expdir, timestamp, 'checkpoints', 'ModelParameters',
                                   str(kwargs['checkpoint']) + ".pth")
    saved_model_state = torch.load(checkpoint_file)
    model.load_state_dict(saved_model_state["model_state_dict"])
    epoch = saved_model_state['epoch']

    ####################################################################################################################
    print("evaluating...")

    model.eval()

    detail_mesh, subdivided_mesh = plt.get_displacement_mesh(model)
    detail_mesh.export('{0}/Detailed_3dmm_{1}.obj'.format(evaldir, epoch), 'obj')
    subdivided_mesh.export('{0}/Subdivide_full_{1}.obj'.format(evaldir, epoch), 'obj')

    if eval_animation:
        detailed_sdf, subdivided_sdf = plt.get_displacement_animation(model)
        np.save('{0}/Cropped_Detailed_sdf_{1}.npy'.format(evaldir, epoch), detailed_sdf)
        np.save('{0}/Cropped_Subdivide_full_{1}.npy'.format(evaldir, epoch), subdivided_sdf)

    if not eval_rendering:
        return

    images_dir = '{0}/rendering'.format(evaldir)
    utils.mkdir_ifnotexists(images_dir)

    output_keys = ('rgb_values', 'diffuse_values', 'specular_values', 'albedo_values')

    def _save_channel(values, n_batch, path_template, frame_tag):
        # Map values from [-1, 1] to [0, 1], turn the first batch item into a
        # channel-last image, write it as an 8-bit PNG and return the float array.
        image = (values.reshape(n_batch, total_pixels, 3) + 1.) / 2.
        image = plt.lin2img(image, img_res).detach().cpu().numpy()[0].transpose(1, 2, 0)
        Image.fromarray((image * 255).astype(np.uint8)).save(path_template.format(images_dir, frame_tag))
        return image

    psnrs = []
    for indices, model_input, ground_truth in eval_dataloader:
        for input_key in ("intrinsics", "uv", "object_mask", "pose"):
            model_input[input_key] = model_input[input_key].cuda()

        # Render in pixel chunks and keep detached copies of every output channel.
        res = []
        for chunk in utils.split_input(model_input, total_pixels):
            out = model(chunk)
            res.append({out_key: out[out_key].detach() for out_key in output_keys})

        batch_size = ground_truth['rgb'].shape[0]
        model_outputs = utils.merge_output(res, total_pixels, batch_size)

        frame_tag = '%03d' % indices[0]
        rgb_eval = _save_channel(model_outputs['rgb_values'], batch_size, '{0}/eval_{1}.png', frame_tag)
        _save_channel(model_outputs['diffuse_values'], batch_size, '{0}/eval_{1}_diffuse.png', frame_tag)
        _save_channel(model_outputs['specular_values'], batch_size, '{0}/eval_{1}_specular.png', frame_tag)
        _save_channel(model_outputs['albedo_values'], batch_size, '{0}/eval_{1}_albedo.png', frame_tag)

        rgb_gt = plt.lin2img((ground_truth['rgb'] + 1.) / 2., img_res).numpy()[0].transpose(1, 2, 0)
        mask = plt.lin2img(model_input['object_mask'].unsqueeze(-1), img_res).cpu().numpy()[0].transpose(1, 2, 0)

        # PSNR is computed only over the object mask.
        psnrs.append(calculate_psnr(rgb_eval * mask, rgb_gt * mask, mask))

    psnrs = np.array(psnrs).astype(np.float64)
    print("RENDERING EVALUATION {2}: psnr mean = {0} ; psnr std = {1}".format(
        "%.2f" % psnrs.mean(), "%.2f" % psnrs.std(), scan_id))
def __init__(self, **kwargs):
    """Set up a single-point-cloud reconstruction run.

    Steps:
    1. Load the config: a filename relative to ``./reconstruction/``, or an
       already-parsed config object.
    2. Resolve whether to resume. Resuming happens with ``is_continue`` or
       ``eval``; with ``timestamp == 'latest'`` the newest folder is used.
    3. Load the input point cloud.
    4. Compute a per-point local sigma: the distance to the 50th nearest
       neighbour, via a cKDTree query of k=51 that includes the point itself.
    5. Create the run folders, network, sampler and Adam optimiser.
    6. When resuming, restore the model and optimiser state from
       ``kwargs['checkpoint']``.
    """
    self.home_dir = os.path.abspath(os.pardir)

    # config setting
    if isinstance(kwargs['conf'], str):
        self.conf_filename = './reconstruction/' + kwargs['conf']
        self.conf = ConfigFactory.parse_file(self.conf_filename)
    else:
        self.conf = kwargs['conf']
    self.expname = kwargs['expname']

    # GPU settings
    self.GPU_INDEX = kwargs['gpu_index']
    if not self.GPU_INDEX == 'ignore':
        os.environ["CUDA_VISIBLE_DEVICES"] = '{0}'.format(self.GPU_INDEX)
    self.num_of_gpus = torch.cuda.device_count()

    self.eval = kwargs['eval']

    # settings for loading an existing experiment
    if (kwargs['is_continue'] or self.eval) and kwargs['timestamp'] == 'latest':
        if os.path.exists(os.path.join(self.home_dir, 'exps', self.expname)):
            timestamps = os.listdir(os.path.join(self.home_dir, 'exps', self.expname))
            if (len(timestamps)) == 0:
                is_continue = False
                timestamp = None
            else:
                timestamp = sorted(timestamps)[-1]
                is_continue = True
        else:
            is_continue = False
            timestamp = None
    else:
        timestamp = kwargs['timestamp']
        is_continue = kwargs['is_continue'] or self.eval

    self.exps_folder_name = 'exps'
    utils.mkdir_ifnotexists(utils.concat_home_dir(os.path.join(self.home_dir, self.exps_folder_name)))

    self.input_file = self.conf.get_string('train.input_path')
    self.data = utils.load_point_cloud_by_file_extension(self.input_file)

    # Local sigma per point: distance to its 50th neighbour (k=51 includes the
    # point itself). Queried in 100 chunks to bound memory.
    sigma_set = []
    ptree = cKDTree(self.data)
    for p in np.array_split(self.data, 100, axis=0):
        d = ptree.query(p, 50 + 1)
        sigma_set.append(d[0][:, -1])
    sigmas = np.concatenate(sigma_set)
    self.local_sigma = torch.from_numpy(sigmas).float().cuda()

    self.expdir = utils.concat_home_dir(os.path.join(self.home_dir, self.exps_folder_name, self.expname))
    utils.mkdir_ifnotexists(self.expdir)

    if is_continue:
        self.timestamp = timestamp
    else:
        self.timestamp = '{:%Y_%m_%d_%H_%M_%S}'.format(datetime.now())

    self.cur_exp_dir = os.path.join(self.expdir, self.timestamp)
    utils.mkdir_ifnotexists(self.cur_exp_dir)

    self.plots_dir = os.path.join(self.cur_exp_dir, 'plots')
    utils.mkdir_ifnotexists(self.plots_dir)

    self.checkpoints_path = os.path.join(self.cur_exp_dir, 'checkpoints')
    utils.mkdir_ifnotexists(self.checkpoints_path)

    self.model_params_subdir = "ModelParameters"
    self.optimizer_params_subdir = "OptimizerParameters"

    utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path, self.model_params_subdir))
    utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path, self.optimizer_params_subdir))

    self.nepochs = kwargs['nepochs']

    self.points_batch = kwargs['points_batch']

    self.global_sigma = self.conf.get_float('network.sampler.properties.global_sigma')
    self.sampler = Sampler.get_sampler(self.conf.get_string('network.sampler.sampler_type'))(
        self.global_sigma, self.local_sigma)
    self.grad_lambda = self.conf.get_float('network.loss.lambda')
    self.normals_lambda = self.conf.get_float('network.loss.normals_lambda')

    # Normals are only loaded/used when their loss term is active.
    self.with_normals = self.normals_lambda > 0

    self.d_in = self.conf.get_int('train.d_in')

    self.network = utils.get_class(self.conf.get_string('train.network_class'))(
        d_in=self.d_in, **self.conf.get_config('network.inputs'))

    if torch.cuda.is_available():
        self.network.cuda()

    self.lr_schedules = self.get_learning_rate_schedules(self.conf.get_list('train.learning_rate_schedule'))
    self.weight_decay = self.conf.get_float('train.weight_decay')

    self.startepoch = 0

    self.optimizer = torch.optim.Adam(
        [
            {
                "params": self.network.parameters(),
                "lr": self.lr_schedules[0].get_learning_rate(0),
                "weight_decay": self.weight_decay
            },
        ])

    # if continue load checkpoints
    if is_continue:
        old_checkpnts_dir = os.path.join(self.expdir, timestamp, 'checkpoints')

        saved_model_state = torch.load(
            os.path.join(old_checkpnts_dir, self.model_params_subdir, str(kwargs['checkpoint']) + ".pth"))
        self.network.load_state_dict(saved_model_state["model_state_dict"])

        data = torch.load(
            os.path.join(old_checkpnts_dir, self.optimizer_params_subdir, str(kwargs['checkpoint']) + ".pth"))
        self.optimizer.load_state_dict(data["optimizer_state_dict"])
        self.startepoch = saved_model_state['epoch']
def evaluate(**kwargs):
    """Render novel views that pair one scan's geometry with another scan's appearance.

    Steps:
    1. Load checkpoint ``'2000'`` from timestamp folder ``'2020'`` of both
       ``<expname>_<geometry_id>`` and ``<expname>_<appearance_id>``.
    2. Graft the second model's ``rendering_network`` onto the first.
    3. Render a camera path built by periodic cubic-spline interpolation of
       five ground-truth poses. Spline inputs are the quaternion and a scale
       used as the offset along the camera's third rotation axis.
    4. Write frames to
       ``../<evals_folder_name>/<expname>_<geometry_id>_<appearance_id>/novel_views_rendering``.
    """
    torch.set_default_dtype(torch.float32)

    conf = ConfigFactory.parse_file(kwargs['conf'])
    exps_folder_name = kwargs['exps_folder_name']
    evals_folder_name = kwargs['evals_folder_name']

    # Fixed run folder and checkpoint shared by both experiments.
    timestamp = '2020'
    checkpoint = '2000'

    expname = conf.get_string('train.expname')

    geometry_id = kwargs['geometry_id']
    appearance_id = kwargs['appearance_id']

    utils.mkdir_ifnotexists(os.path.join('../', evals_folder_name))
    expdir_geometry = os.path.join('../', exps_folder_name, expname + '_{0}'.format(geometry_id))
    expdir_appearance = os.path.join('../', exps_folder_name, expname + '_{0}'.format(appearance_id))
    evaldir = os.path.join('../', evals_folder_name, expname + '_{0}_{1}'.format(geometry_id, appearance_id))
    utils.mkdir_ifnotexists(evaldir)

    model = utils.get_class(conf.get_string('train.model_class'))(conf=conf.get_config('model'))
    if torch.cuda.is_available():
        model.cuda()

    # Load geometry network model
    old_checkpnts_dir = os.path.join(expdir_geometry, timestamp, 'checkpoints')
    saved_model_state = torch.load(os.path.join(old_checkpnts_dir, 'ModelParameters', checkpoint + ".pth"))
    model.load_state_dict(saved_model_state["model_state_dict"])

    # Load rendering network model
    model_fake = utils.get_class(conf.get_string('train.model_class'))(conf=conf.get_config('model'))
    if torch.cuda.is_available():
        model_fake.cuda()
    old_checkpnts_dir = os.path.join(expdir_appearance, timestamp, 'checkpoints')
    saved_model_state = torch.load(os.path.join(old_checkpnts_dir, 'ModelParameters', checkpoint + ".pth"))
    model_fake.load_state_dict(saved_model_state["model_state_dict"])

    model.rendering_network = model_fake.rendering_network

    dataset_conf = conf.get_config('dataset')
    dataset_conf['scan_id'] = geometry_id
    eval_dataset = utils.get_class(conf.get_string('train.dataset_class'))(False, **dataset_conf)

    eval_dataloader = torch.utils.data.DataLoader(eval_dataset,
                                                  batch_size=1,
                                                  shuffle=True,
                                                  collate_fn=eval_dataset.collate_fn
                                                  )
    total_pixels = eval_dataset.total_pixels
    img_res = eval_dataset.img_res

    ####################################################################################################################
    print("evaluating...")

    model.eval()

    gt_pose = eval_dataset.get_gt_pose(scaled=True).cuda()
    gt_quat = rend_util.rot_to_quat(gt_pose[:, :3, :3])
    gt_pose_vec = torch.cat([gt_quat, gt_pose[:, :3, 3]], 1)

    # Key frames; the first and last views are equal so the periodic spline closes the loop.
    indices_all = [11, 16, 34, 28, 11]
    pose = gt_pose_vec[indices_all, :]
    t_in = np.array([0, 2, 3, 5, 6]).astype(np.float32)

    n_inter = 5
    # linspace's ``num`` must be an integer: current NumPy raises TypeError for
    # the float32 product ``n_inter * t_in[-1]``.
    n_frames = int(n_inter * t_in[-1])
    t_out = np.linspace(t_in[0], t_in[-1], n_frames).astype(np.float32)

    scales = np.array([4.2, 4.2, 3.8, 3.8, 4.2]).astype(np.float32)

    s_new = CubicSpline(t_in, scales, bc_type='periodic')
    s_new = s_new(t_out)

    q_new = CubicSpline(t_in, pose[:, :4].detach().cpu().numpy(), bc_type='periodic')
    q_new = q_new(t_out)
    # Splined quaternions are not unit length; renormalise each row.
    q_new = q_new / np.linalg.norm(q_new, 2, 1)[:, None]
    q_new = torch.from_numpy(q_new).cuda().float()

    images_dir = '{0}/novel_views_rendering'.format(evaldir)
    utils.mkdir_ifnotexists(images_dir)

    # Only uv/intrinsics/mask shape are reused from a dataset sample; the pose is replaced.
    indices, model_input, ground_truth = next(iter(eval_dataloader))

    for i, (new_q, scale) in enumerate(zip(q_new, s_new)):
        torch.cuda.empty_cache()

        new_q = new_q.unsqueeze(0)
        new_rot = rend_util.quat_to_rot(new_q)
        # Camera centre sits at -scale along the third rotation axis.
        new_t = -new_rot[:, :, 2] * scale

        new_p = torch.eye(4).float().cuda().unsqueeze(0)
        new_p[:, :3, :3] = new_rot
        new_p[:, :3, 3] = new_t

        sample = {
            "object_mask": torch.zeros_like(model_input['object_mask']).cuda().bool(),
            "uv": model_input['uv'].cuda(),
            "intrinsics": model_input['intrinsics'].cuda(),
            "pose": new_p
        }

        split = utils.split_input(sample, total_pixels)
        res = []
        for s in split:
            out = model(s)
            res.append({
                'rgb_values': out['rgb_values'].detach(),
            })

        batch_size = 1
        model_outputs = utils.merge_output(res, total_pixels, batch_size)
        rgb_eval = model_outputs['rgb_values']
        rgb_eval = rgb_eval.reshape(batch_size, total_pixels, 3)

        rgb_eval = (rgb_eval + 1.) / 2.
        rgb_eval = plt.lin2img(rgb_eval, img_res).detach().cpu().numpy()[0]
        rgb_eval = rgb_eval.transpose(1, 2, 0)
        img = Image.fromarray((rgb_eval * 255).astype(np.uint8))
        img.save('{0}/eval_{1}.png'.format(images_dir, '%03d' % i))
def __init__(self, **kwargs):
    """Prepare a multi-view training run: folders, data, model, optimisers, resume state.

    Creates ``../<exps_folder_name>/<expname>/<new timestamp>/`` with
    ``plots`` and ``checkpoints`` sub-folders and copies the config there as
    ``runconf.conf``. It then builds:
    - the dataset and its train/plot dataloaders;
    - the model and loss;
    - an Adam optimiser with a MultiStepLR scheduler;
    - when ``train_cameras``, a learnable 7-value pose embedding per image
      with its own SparseAdam optimiser.

    When continuing, model/optimiser/scheduler (and camera) state is loaded
    from ``kwargs['checkpoint']`` of the resolved old timestamp folder.
    New checkpoints still go under the fresh timestamp folder.
    """
    import shutil  # local import: only used to snapshot the config file

    torch.set_default_dtype(torch.float32)
    torch.set_num_threads(1)

    self.conf = ConfigFactory.parse_file(kwargs['conf'])
    self.batch_size = kwargs['batch_size']
    self.nepochs = kwargs['nepochs']
    self.exps_folder_name = kwargs['exps_folder_name']
    self.GPU_INDEX = kwargs['gpu_index']
    self.train_cameras = kwargs['train_cameras']

    self.expname = self.conf.get_string('train.expname') + kwargs['expname']
    # -1 means "fall back to dataset.scan_id from the config".
    scan_id = kwargs['scan_id'] if kwargs['scan_id'] != -1 else self.conf.get_int('dataset.scan_id', default=-1)
    if scan_id != -1:
        self.expname = self.expname + '_{0}'.format(scan_id)

    if kwargs['is_continue'] and kwargs['timestamp'] == 'latest':
        if os.path.exists(os.path.join('../', kwargs['exps_folder_name'], self.expname)):
            timestamps = os.listdir(os.path.join('../', kwargs['exps_folder_name'], self.expname))
            if (len(timestamps)) == 0:
                is_continue = False
                timestamp = None
            else:
                timestamp = sorted(timestamps)[-1]
                is_continue = True
        else:
            is_continue = False
            timestamp = None
    else:
        timestamp = kwargs['timestamp']
        is_continue = kwargs['is_continue']

    utils.mkdir_ifnotexists(os.path.join('../', self.exps_folder_name))
    self.expdir = os.path.join('../', self.exps_folder_name, self.expname)
    utils.mkdir_ifnotexists(self.expdir)
    self.timestamp = '{:%Y_%m_%d_%H_%M_%S}'.format(datetime.now())
    utils.mkdir_ifnotexists(os.path.join(self.expdir, self.timestamp))

    self.plots_dir = os.path.join(self.expdir, self.timestamp, 'plots')
    utils.mkdir_ifnotexists(self.plots_dir)

    # create checkpoints dirs
    self.checkpoints_path = os.path.join(self.expdir, self.timestamp, 'checkpoints')
    utils.mkdir_ifnotexists(self.checkpoints_path)
    self.model_params_subdir = "ModelParameters"
    self.optimizer_params_subdir = "OptimizerParameters"
    self.scheduler_params_subdir = "SchedulerParameters"

    utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path, self.model_params_subdir))
    utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path, self.optimizer_params_subdir))
    utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path, self.scheduler_params_subdir))

    if self.train_cameras:
        self.optimizer_cam_params_subdir = "OptimizerCamParameters"
        self.cam_params_subdir = "CamParameters"

        utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path, self.optimizer_cam_params_subdir))
        utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path, self.cam_params_subdir))

    # Snapshot the config next to the run. Copying without a shell avoids breakage
    # (or injection) when the path contains spaces or shell metacharacters.
    shutil.copyfile(kwargs['conf'], os.path.join(self.expdir, self.timestamp, 'runconf.conf'))

    if (not self.GPU_INDEX == 'ignore'):
        os.environ["CUDA_VISIBLE_DEVICES"] = '{0}'.format(self.GPU_INDEX)

    print('shell command : {0}'.format(' '.join(sys.argv)))

    print('Loading data ...')

    dataset_conf = self.conf.get_config('dataset')
    if kwargs['scan_id'] != -1:
        dataset_conf['scan_id'] = kwargs['scan_id']

    # Optional image subsets; when absent the dataset receives None.
    try:
        dataset_conf['train_set'] = self.conf.get_list('train.train_set')
    except Exception:
        dataset_conf['train_set'] = None
    try:
        dataset_conf['test_set'] = self.conf.get_list('train.test_set')
    except Exception:
        dataset_conf['test_set'] = None

    self.train_dataset = utils.get_class(self.conf.get_string('train.dataset_class'))(
        self.train_cameras, **dataset_conf)

    print('Finish loading data ...')

    self.train_dataloader = torch.utils.data.DataLoader(self.train_dataset,
                                                        batch_size=self.batch_size,
                                                        shuffle=True,
                                                        collate_fn=self.train_dataset.collate_fn)
    self.plot_dataloader = torch.utils.data.DataLoader(self.train_dataset,
                                                       batch_size=self.conf.get_int('plot.plot_nimgs'),
                                                       shuffle=True,
                                                       collate_fn=self.train_dataset.collate_fn)

    self.model = utils.get_class(self.conf.get_string('train.model_class'))(conf=self.conf.get_config('model'))
    if torch.cuda.is_available():
        self.model.cuda()

    self.loss = utils.get_class(self.conf.get_string('train.loss_class'))(**self.conf.get_config('loss'))

    self.lr = self.conf.get_float('train.learning_rate')
    self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
    self.sched_milestones = self.conf.get_list('train.sched_milestones', default=[])
    self.sched_factor = self.conf.get_float('train.sched_factor', default=0.0)
    self.scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer, self.sched_milestones,
                                                          gamma=self.sched_factor)

    # settings for camera optimization
    if self.train_cameras:
        num_images = len(self.train_dataset)
        self.pose_vecs = torch.nn.Embedding(num_images, 7, sparse=True).cuda()
        self.pose_vecs.weight.data.copy_(self.train_dataset.get_pose_init())

        self.optimizer_cam = torch.optim.SparseAdam(self.pose_vecs.parameters(),
                                                    self.conf.get_float('train.learning_rate_cam'))

    self.start_epoch = 0
    if is_continue:
        old_checkpnts_dir = os.path.join(self.expdir, timestamp, 'checkpoints')

        saved_model_state = torch.load(
            os.path.join(old_checkpnts_dir, 'ModelParameters', str(kwargs['checkpoint']) + ".pth"))
        self.model.load_state_dict(saved_model_state["model_state_dict"])
        self.start_epoch = saved_model_state['epoch']

        data = torch.load(
            os.path.join(old_checkpnts_dir, 'OptimizerParameters', str(kwargs['checkpoint']) + ".pth"))
        self.optimizer.load_state_dict(data["optimizer_state_dict"])

        data = torch.load(
            os.path.join(old_checkpnts_dir, self.scheduler_params_subdir, str(kwargs['checkpoint']) + ".pth"))
        self.scheduler.load_state_dict(data["scheduler_state_dict"])

        if self.train_cameras:
            data = torch.load(
                os.path.join(old_checkpnts_dir, self.optimizer_cam_params_subdir,
                             str(kwargs['checkpoint']) + ".pth"))
            self.optimizer_cam.load_state_dict(data["optimizer_cam_state_dict"])

            data = torch.load(
                os.path.join(old_checkpnts_dir, self.cam_params_subdir, str(kwargs['checkpoint']) + ".pth"))
            self.pose_vecs.load_state_dict(data["pose_vecs_state_dict"])

    self.num_pixels = self.conf.get_int('train.num_pixels')
    self.total_pixels = self.train_dataset.total_pixels
    self.img_res = self.train_dataset.img_res
    self.n_batches = len(self.train_dataloader)
    self.plot_freq = self.conf.get_int('train.plot_freq')
    self.plot_conf = self.conf.get_config('plot')

    self.alpha_milestones = self.conf.get_list('train.alpha_milestones', default=[])
    self.alpha_factor = self.conf.get_float('train.alpha_factor', default=0.0)
    # Re-apply the alpha decay for every milestone already passed before resuming.
    for acc in self.alpha_milestones:
        if self.start_epoch > acc:
            self.loss.alpha = self.loss.alpha * self.alpha_factor
# Script entry point: load the experiment config and a saved checkpoint,
# rebuild the decoder, and run latent interpolation over the given split.
# Relies on `args`, `code_path` and `exps_path`, which are defined earlier in the
# script (not visible here). NOTE: `interpolate` reads this module-level `conf`.
conf = ConfigFactory.parse_file(os.path.join(code_path, 'shapespace', args.conf))
experiment_directory = os.path.join(exps_path, args.exp_name)

if args.timestamp == 'latest':
    # Lexicographically largest run folder; presumably timestamp-named so this is
    # the newest run — TODO confirm. Raises IndexError if the folder is empty.
    timestamps = os.listdir(experiment_directory)
    timestamp = sorted(timestamps)[-1]
else:
    timestamp = args.timestamp

experiment_directory = os.path.join(experiment_directory, timestamp)

# args.epoch is concatenated with ".pth", so it is expected to be a string.
saved_model_state = torch.load(os.path.join(experiment_directory, 'checkpoints', 'ModelParameters', args.epoch + ".pth"))
saved_model_epoch = saved_model_state["epoch"]
# Not used in the visible part of the script.
with_normals = conf.get_float('network.loss.normals_lambda') > 0
# The decoder input is the latent code concatenated with the point coordinates.
network = utils.get_class(conf.get_string('train.network_class'))(d_in=conf.get_int('train.latent_size')+conf.get_int('train.d_in'), **conf.get_config('network.inputs'))

# Strip the 'module.' key prefix (presumably present because the weights were saved
# from a torch.nn.DataParallel-wrapped network) so they load into a bare network.
network.load_state_dict({k.replace('module.', ''): v for k, v in saved_model_state["model_state_dict"].items()})
split_file = os.path.join(code_path, 'splits', args.split)

# The checkpoint's stored epoch is used both as the output-folder label and as the plot epoch.
interpolate(
    network=network.cuda(),
    interval=args.interval,
    experiment_directory=experiment_directory,
    checkpoint=saved_model_epoch,
    split_file=split_file,
    epoch=saved_model_epoch,
    resolution=args.resolution,
    uniform_grid=args.uniform_grid
)
def __init__(self, **kwargs):
    """Prepare a shape-space (auto-decoder) training run.

    Resolves the config and resume settings, then creates the run folders
    under ``<home>/exps/<expname>/<timestamp>``. It then builds:
    - the dataset from ``kwargs['split_file']`` and its train/eval dataloaders;
    - the decoder, wrapped in DataParallel when more than one GPU is visible;
    - one learnable latent code per scene (``self.lat_vecs``, initialised to zeros);
    - an Adam optimiser with one param group for the network and one for the latents.

    When continuing, latent codes, network weights, optimiser state and the start
    epoch are restored from ``kwargs['checkpoint']``.
    """
    self.home_dir = os.path.abspath(os.pardir)

    # config setting
    if isinstance(kwargs['conf'], str):
        self.conf_filename = os.path.abspath(kwargs['conf'])
        self.conf = ConfigFactory.parse_file(self.conf_filename)
    else:
        self.conf = kwargs['conf']

    self.expname = kwargs['expname']

    # GPU settings
    self.GPU_INDEX = kwargs['gpu_index']
    if not self.GPU_INDEX == 'ignore':
        os.environ["CUDA_VISIBLE_DEVICES"] = '{0}'.format(self.GPU_INDEX)
    self.num_of_gpus = torch.cuda.device_count()

    # settings for loading an existing experiment
    if kwargs['is_continue'] and kwargs['timestamp'] == 'latest':
        if os.path.exists(os.path.join(self.home_dir, 'exps', self.expname)):
            timestamps = os.listdir(os.path.join(self.home_dir, 'exps', self.expname))
            if (len(timestamps)) == 0:
                is_continue = False
                timestamp = None
            else:
                timestamp = sorted(timestamps)[-1]
                is_continue = True
        else:
            is_continue = False
            timestamp = None
    else:
        timestamp = kwargs['timestamp']
        is_continue = kwargs['is_continue']

    self.exps_folder_name = 'exps'
    utils.mkdir_ifnotexists(utils.concat_home_dir(os.path.join(self.home_dir, self.exps_folder_name)))

    self.expdir = utils.concat_home_dir(os.path.join(self.home_dir, self.exps_folder_name, self.expname))
    utils.mkdir_ifnotexists(self.expdir)

    if is_continue:
        self.timestamp = timestamp
    else:
        self.timestamp = '{:%Y_%m_%d_%H_%M_%S}'.format(datetime.now())

    self.cur_exp_dir = self.timestamp
    utils.mkdir_ifnotexists(os.path.join(self.expdir, self.cur_exp_dir))

    self.plots_dir = os.path.join(self.expdir, self.cur_exp_dir, 'plots')
    utils.mkdir_ifnotexists(self.plots_dir)

    self.checkpoints_path = os.path.join(self.expdir, self.cur_exp_dir, 'checkpoints')
    utils.mkdir_ifnotexists(self.checkpoints_path)

    self.model_params_subdir = "ModelParameters"
    self.optimizer_params_subdir = "OptimizerParameters"
    self.latent_codes_subdir = "LatentCodes"

    utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path, self.model_params_subdir))
    utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path, self.optimizer_params_subdir))
    utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path, self.latent_codes_subdir))

    self.nepochs = kwargs['nepochs']

    # The given batch size is per GPU; scale it by the number of visible GPUs.
    self.batch_size = kwargs['batch_size']
    if self.num_of_gpus > 0:
        self.batch_size *= self.num_of_gpus

    self.parallel = self.num_of_gpus > 1

    self.global_sigma = self.conf.get_float('network.sampler.properties.global_sigma')
    self.local_sigma = self.conf.get_float('network.sampler.properties.local_sigma')
    self.sampler = Sampler.get_sampler(self.conf.get_string('network.sampler.sampler_type'))(
        self.global_sigma, self.local_sigma)

    train_split_file = os.path.abspath(kwargs['split_file'])
    print(f'Loading split file {train_split_file}')
    with open(train_split_file, "r") as f:
        train_split = json.load(f)
    print(f'Size of the split: {len(train_split)} samples')

    self.d_in = self.conf.get_int('train.d_in')

    # latent preprocessing
    self.latent_size = self.conf.get_int('train.latent_size')

    self.latent_lambda = self.conf.get_float('network.loss.latent_lambda')
    self.grad_lambda = self.conf.get_float('network.loss.lambda')
    self.normals_lambda = self.conf.get_float('network.loss.normals_lambda')

    self.with_normals = self.normals_lambda > 0

    self.ds = utils.get_class(self.conf.get_string('train.dataset'))(
        split=train_split,
        with_normals=self.with_normals,
        dataset_path=self.conf.get_string('train.dataset_path'),
        points_batch=kwargs['points_batch'],
    )

    self.num_scenes = len(self.ds)

    self.train_dataloader = torch.utils.data.DataLoader(self.ds,
                                                        batch_size=self.batch_size,
                                                        shuffle=True,
                                                        num_workers=kwargs['threads'],
                                                        drop_last=True,
                                                        pin_memory=True)
    self.eval_dataloader = torch.utils.data.DataLoader(self.ds,
                                                       batch_size=1,
                                                       shuffle=False,
                                                       num_workers=0,
                                                       drop_last=True)

    # Decoder input is the latent code concatenated with the point coordinates.
    self.network = utils.get_class(self.conf.get_string('train.network_class'))(
        d_in=(self.d_in + self.latent_size),
        **self.conf.get_config('network.inputs'))

    if self.parallel:
        self.network = torch.nn.DataParallel(self.network)

    if torch.cuda.is_available():
        self.network.cuda()

    self.lr_schedules = self.get_learning_rate_schedules(self.conf.get_list('train.learning_rate_schedule'))
    self.weight_decay = self.conf.get_float('train.weight_decay')

    # optimizer and latent settings
    self.startepoch = 0

    self.lat_vecs = utils.to_cuda(torch.zeros(self.num_scenes, self.latent_size))
    self.lat_vecs.requires_grad_()

    self.optimizer = torch.optim.Adam(
        [
            {
                "params": self.network.parameters(),
                "lr": self.lr_schedules[0].get_learning_rate(0),
                "weight_decay": self.weight_decay
            },
            {
                "params": self.lat_vecs,
                "lr": self.lr_schedules[1].get_learning_rate(0)
            },
        ])

    # if continue load checkpoints
    if is_continue:
        old_checkpnts_dir = os.path.join(self.expdir, timestamp, 'checkpoints')

        data = torch.load(
            os.path.join(old_checkpnts_dir, self.latent_codes_subdir, str(kwargs['checkpoint']) + '.pth'))
        # Restore the codes in place. The optimizer already holds a reference to
        # this tensor; rebinding self.lat_vecs would leave the optimizer updating
        # the stale zero tensor while training reads the restored (never-updated) one.
        with torch.no_grad():
            self.lat_vecs.copy_(data["latent_codes"])

        saved_model_state = torch.load(
            os.path.join(old_checkpnts_dir, self.model_params_subdir, str(kwargs['checkpoint']) + ".pth"))
        self.network.load_state_dict(saved_model_state["model_state_dict"])

        data = torch.load(
            os.path.join(old_checkpnts_dir, self.optimizer_params_subdir, str(kwargs['checkpoint']) + ".pth"))
        self.optimizer.load_state_dict(data["optimizer_state_dict"])
        self.startepoch = saved_model_state['epoch']
def evaluate(**kwargs):
    """Export a world-space surface mesh and, optionally, renderings with PSNR.

    Steps:
    1. Restore model weights (and, with ``eval_cameras``, the optimised pose
       embedding) from ``kwargs['checkpoint']``.
    2. Extract a high-resolution mesh from the implicit network.
    3. Transform the mesh to world coordinates. With ``eval_cameras`` this uses
       the similarity returned by ``get_cameras_accuracy`` against the
       ground-truth poses; otherwise it uses the dataset's scale matrix.
    4. Keep the largest-area connected component and save it as PLY.
    5. With ``eval_rendering``, render every view, save PNGs and print the
       mean/std of the masked PSNR.
    """
    torch.set_default_dtype(torch.float32)

    conf = ConfigFactory.parse_file(kwargs['conf'])
    exps_folder_name = kwargs['exps_folder_name']
    evals_folder_name = kwargs['evals_folder_name']
    eval_cameras = kwargs['eval_cameras']
    eval_rendering = kwargs['eval_rendering']

    expname = conf.get_string('train.expname') + kwargs['expname']
    # -1 means "fall back to dataset.scan_id from the config".
    scan_id = kwargs['scan_id'] if kwargs['scan_id'] != -1 else conf.get_int('dataset.scan_id', default=-1)
    if scan_id != -1:
        expname = expname + '_{0}'.format(scan_id)

    if kwargs['timestamp'] == 'latest':
        if os.path.exists(os.path.join('../', kwargs['exps_folder_name'], expname)):
            timestamps = os.listdir(os.path.join('../', kwargs['exps_folder_name'], expname))
            if (len(timestamps)) == 0:
                print('WRONG EXP FOLDER')
                exit()
            else:
                timestamp = sorted(timestamps)[-1]
        else:
            print('WRONG EXP FOLDER')
            exit()
    else:
        timestamp = kwargs['timestamp']

    utils.mkdir_ifnotexists(os.path.join('../', evals_folder_name))
    expdir = os.path.join('../', exps_folder_name, expname)
    evaldir = os.path.join('../', evals_folder_name, expname)
    utils.mkdir_ifnotexists(evaldir)

    model = utils.get_class(conf.get_string('train.model_class'))(conf=conf.get_config('model'))
    if torch.cuda.is_available():
        model.cuda()

    dataset_conf = conf.get_config('dataset')
    if kwargs['scan_id'] != -1:
        dataset_conf['scan_id'] = kwargs['scan_id']
    eval_dataset = utils.get_class(conf.get_string('train.dataset_class'))(eval_cameras, **dataset_conf)

    # settings for camera optimization
    scale_mat = eval_dataset.get_scale_mat()
    if eval_cameras:
        num_images = len(eval_dataset)
        pose_vecs = torch.nn.Embedding(num_images, 7, sparse=True).cuda()
        pose_vecs.weight.data.copy_(eval_dataset.get_pose_init())

        gt_pose = eval_dataset.get_gt_pose()

    if eval_rendering:
        eval_dataloader = torch.utils.data.DataLoader(eval_dataset,
                                                      batch_size=1,
                                                      shuffle=False,
                                                      collate_fn=eval_dataset.collate_fn)
        total_pixels = eval_dataset.total_pixels
        img_res = eval_dataset.img_res

    old_checkpnts_dir = os.path.join(expdir, timestamp, 'checkpoints')

    saved_model_state = torch.load(
        os.path.join(old_checkpnts_dir, 'ModelParameters', str(kwargs['checkpoint']) + ".pth"))
    model.load_state_dict(saved_model_state["model_state_dict"])
    epoch = saved_model_state['epoch']

    if eval_cameras:
        data = torch.load(os.path.join(old_checkpnts_dir, 'CamParameters', str(kwargs['checkpoint']) + ".pth"))
        pose_vecs.load_state_dict(data["pose_vecs_state_dict"])

    ####################################################################################################################
    print("evaluating...")

    model.eval()
    if eval_cameras:
        pose_vecs.eval()

    with torch.no_grad():
        if eval_cameras:
            gt_Rs = gt_pose[:, :3, :3].double()
            gt_ts = gt_pose[:, :3, 3].double()

            # Embedding rows are [quaternion (4), translation (3)].
            pred_Rs = rend_util.quat_to_rot(pose_vecs.weight.data[:, :4]).cpu().double()
            pred_ts = pose_vecs.weight.data[:, 4:].cpu().double()

            R_opt, t_opt, c_opt, _, _ = get_cameras_accuracy(pred_Rs, gt_Rs, pred_ts, gt_ts)

            cams_transformation = np.eye(4, dtype=np.double)
            cams_transformation[:3, :3] = c_opt * R_opt
            cams_transformation[:3, 3] = t_opt

        mesh = plt.get_surface_high_res_mesh(
            sdf=lambda x: model.implicit_network(x)[:, 0],
            resolution=kwargs['resolution']
        )

        # Transform to world coordinates
        if eval_cameras:
            mesh.apply_transform(cams_transformation)
        else:
            mesh.apply_transform(scale_mat)

        # Keep the connected component with the largest surface area.
        # np.float was removed in NumPy 1.24; np.float64 is what it aliased.
        components = mesh.split(only_watertight=False)
        areas = np.array([c.area for c in components], dtype=np.float64)
        mesh_clean = components[areas.argmax()]
        mesh_clean.export('{0}/surface_world_coordinates_{1}.ply'.format(evaldir, epoch), 'ply')

    if eval_rendering:
        images_dir = '{0}/rendering'.format(evaldir)
        utils.mkdir_ifnotexists(images_dir)

        psnrs = []
        for indices, model_input, ground_truth in eval_dataloader:
            model_input["intrinsics"] = model_input["intrinsics"].cuda()
            model_input["uv"] = model_input["uv"].cuda()
            model_input["object_mask"] = model_input["object_mask"].cuda()

            if eval_cameras:
                # Use the optimised camera for this view instead of the dataset pose.
                pose_input = pose_vecs(indices.cuda())
                model_input['pose'] = pose_input
            else:
                model_input['pose'] = model_input['pose'].cuda()

            split = utils.split_input(model_input, total_pixels)
            res = []
            for s in split:
                out = model(s)
                res.append({
                    'rgb_values': out['rgb_values'].detach(),
                })

            batch_size = ground_truth['rgb'].shape[0]
            model_outputs = utils.merge_output(res, total_pixels, batch_size)
            rgb_eval = model_outputs['rgb_values']
            rgb_eval = rgb_eval.reshape(batch_size, total_pixels, 3)

            rgb_eval = (rgb_eval + 1.) / 2.
            rgb_eval = plt.lin2img(rgb_eval, img_res).detach().cpu().numpy()[0]
            rgb_eval = rgb_eval.transpose(1, 2, 0)
            img = Image.fromarray((rgb_eval * 255).astype(np.uint8))
            img.save('{0}/eval_{1}.png'.format(images_dir, '%03d' % indices[0]))

            rgb_gt = ground_truth['rgb']
            rgb_gt = (rgb_gt + 1.) / 2.
            rgb_gt = plt.lin2img(rgb_gt, img_res).numpy()[0]
            rgb_gt = rgb_gt.transpose(1, 2, 0)

            mask = model_input['object_mask']
            mask = plt.lin2img(mask.unsqueeze(-1), img_res).cpu().numpy()[0]
            mask = mask.transpose(1, 2, 0)

            # PSNR is computed only over the object mask.
            rgb_eval_masked = rgb_eval * mask
            rgb_gt_masked = rgb_gt * mask

            psnr = calculate_psnr(rgb_eval_masked, rgb_gt_masked, mask)
            psnrs.append(psnr)

        psnrs = np.array(psnrs).astype(np.float64)
        print("RENDERING EVALUATION {2}: psnr mean = {0} ; psnr std = {1}".format(
            "%.2f" % psnrs.mean(), "%.2f" % psnrs.std(), scan_id))
def __init__(
    self,
    latent_size,
    dims,
    dropout=None,
    dropout_prob=0.0,
    norm_layers=(),
    latent_in=(),
    weight_norm=False,
    xyz_in_all=None,
    activation=None,
    latent_dropout=False,
):
    """Create the fully connected layers ``lin0`` ... ``lin{num_layers - 2}``.

    Args:
        latent_size: Length of the latent code. The first layer's input width
            is ``latent_size + 3`` (latent code concatenated with a 3-D point).
        dims: Hidden layer widths (a list). The input width is prepended and
            an output width of 1 is appended; the caller's list is not modified.
        dropout: Collection of layer indices whose weight-init scale accounts
            for dropout (uses ``p = 1 - dropout_prob``). ``None`` (the default)
            means no such layers. Stored unchanged as ``self.dropout``.
        dropout_prob: Dropout probability, stored as ``self.dropout_prob``.
        norm_layers: Layer indices wrapped with weight normalisation when
            ``weight_norm`` is true.
        latent_in: Layer indices that presumably (in ``forward``, not shown
            here) receive the network input concatenated again. The layer just
            before each one has its output shrunk by the input width.
        weight_norm: Enable weight normalisation for ``norm_layers``.
        xyz_in_all: If truthy, every non-final layer not feeding a ``latent_in``
            layer has its output shrunk by 3. Presumably ``forward`` appends
            the xyz coordinates there; confirm.
        activation: Dotted class path for an output activation, instantiated
            via ``utils.get_class``. ``None`` or the string ``'None'`` disables it.
        latent_dropout: If true, create ``self.lat_dp = nn.Dropout(0.2)``.
    """
    super().__init__()

    dims = [latent_size + 3] + dims + [1]

    self.num_layers = len(dims)
    self.norm_layers = norm_layers
    self.latent_in = latent_in
    self.latent_dropout = latent_dropout
    if self.latent_dropout:
        self.lat_dp = nn.Dropout(0.2)

    self.xyz_in_all = xyz_in_all
    self.weight_norm = weight_norm

    # A missing dropout collection means "no layer uses dropout scaling".
    dropout_layers = () if dropout is None else dropout

    for l in range(0, self.num_layers - 1):
        if l + 1 in latent_in:
            # Leave room for the input that is re-concatenated at layer l + 1.
            out_dim = dims[l + 1] - dims[0]
        else:
            out_dim = dims[l + 1]
            if self.xyz_in_all and l != self.num_layers - 2:
                out_dim -= 3

        lin = nn.Linear(dims[l], out_dim)

        if l in dropout_layers:
            p = 1 - dropout_prob
        else:
            p = 1.0

        if l == self.num_layers - 2:
            # Output layer: near-constant positive weights and bias -1.0.
            # This looks like SAL-style geometric initialisation (output
            # approximating a sphere's distance field at start); confirm.
            torch.nn.init.normal_(lin.weight, mean=2 * np.sqrt(np.pi) / np.sqrt(p * dims[l]), std=0.000001)
            torch.nn.init.constant_(lin.bias, -1.0)
        else:
            torch.nn.init.constant_(lin.bias, 0.0)
            torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(p * out_dim))

        if weight_norm and l in self.norm_layers:
            lin = nn.utils.weight_norm(lin)

        setattr(self, "lin" + str(l), lin)

    # Both Python ``None`` and the config string ``'None'`` disable the output
    # activation, so the default never reaches ``utils.get_class(None)``.
    self.use_activation = activation is not None and activation != 'None'

    if self.use_activation:
        self.last_activation = utils.get_class(activation)()

    self.relu = nn.ReLU()

    self.dropout_prob = dropout_prob
    self.dropout = dropout
def optimize_latent(latent, ds, itemindex, decoder, path, epoch, resolution, conf):
    """Fit ``latent`` to one dataset item, then extract and save its surface.

    Only ``latent`` is given to the optimizer. ``decoder``'s parameters are
    not stepped, although backward still accumulates their ``.grad``. The code
    runs 800 Adam iterations from lr 1e-3. ``adjust_learning_rate`` (defined
    elsewhere) is called each step with ``decreased_by=10`` and
    ``adjust_lr_every=400``; see that helper for the exact schedule.

    Args:
        latent: Tensor detached and made a leaf requiring grad, then updated
            in place. It is expanded with ``unsqueeze(1).repeat(...)`` and
            concatenated with a batch-of-one point tensor, so it must be 2-D
            with a leading size of 1.
        ds: Dataset where ``ds[itemindex]`` returns
            ``(manifold_points, nonmanifold_samples, _)``. Columns ``[:3]`` of
            the samples are xyz and column ``3`` is the target distance. It
            must also expose ``npyfiles_mnfld`` (a list of file paths).
        itemindex: Index of the item to fit.
        decoder: Callable mapping ``(N, latent_size + 3)`` inputs to predictions.
            It is used for both the optimisation and the final reconstruction.
        path, epoch, resolution: Forwarded to ``plt.plot_surface``.
        conf: Config providing ``network.loss.loss_type`` and
            ``network.loss.properties``.

    Returns:
        Whatever ``plt.plot_surface`` returns for the optimised latent.
    """
    # Turn ``latent`` into a fresh leaf tensor that Adam can update in place.
    latent.detach_()
    latent.requires_grad_()
    lr = 1.0e-3
    optimizer = torch.optim.Adam([latent], lr=lr)
    loss_func = utils.get_class(conf.get_string('network.loss.loss_type'))(
        **conf.get_config('network.loss.properties'))

    num_iterations = 800
    decreased_by = 10
    adjust_lr_every = int(num_iterations / 2)
    for e in range(num_iterations):
        # The item is re-fetched every iteration. Presumably the dataset draws
        # new off-surface samples on each access; verify in the dataset class.
        input_pc, sample_nonmnfld, _ = ds[itemindex]
        input_pc = utils.get_cuda_ifavailable(input_pc).unsqueeze(0)
        sample_nonmnfld = utils.get_cuda_ifavailable(
            sample_nonmnfld).unsqueeze(0)

        non_mnfld_pnts = sample_nonmnfld[:, :, :3]
        dist_nonmnfld = sample_nonmnfld[:, :, 3].reshape(-1)

        adjust_learning_rate(lr, optimizer, e, decreased_by, adjust_lr_every)

        optimizer.zero_grad()
        non_mnfld_pnts_with_latent = torch.cat([
            latent.unsqueeze(1).repeat(1, non_mnfld_pnts.shape[1], 1),
            non_mnfld_pnts
        ], dim=-1)
        nonmanifold_pnts_pred = decoder(
            non_mnfld_pnts_with_latent.view(
                -1, non_mnfld_pnts_with_latent.shape[-1]))

        loss_res = loss_func(manifold_pnts_pred=None,
                             nonmanifold_pnts_pred=nonmanifold_pnts_pred,
                             nonmanifold_gt=dist_nonmnfld,
                             weight=None)
        loss = loss_res["loss"]

        loss.backward()
        optimizer.step()
        print("iteration : {0} , loss {1}".format(e, loss.item()))
        print("mean {0} , std {1}".format(latent.mean().item(),
                                          latent.std().item()))

    # Output name: "<third-from-last path component>_<file stem>_after".
    shapefile = ds.npyfiles_mnfld[itemindex]
    path_parts = shapefile.split('/')
    in_epoch = path_parts[-3] + '_' + path_parts[-1].split('.npy')[0] + '_after'

    with torch.no_grad():
        reconstruction = plt.plot_surface(
            with_points=False,
            points=torch.cat([
                latent.unsqueeze(1).repeat(1, input_pc.shape[1], 1),
                input_pc
            ], dim=-1)[0],
            # Use the same decoder the latent was optimised against.
            decoder=decoder,
            latent=latent,
            path=path,
            epoch=epoch,
            in_epoch=in_epoch,
            shapefile=shapefile,
            resolution=resolution,
            mc_value=0,
            is_uniform_grid=True,
            verbose=True,
            save_html=False,
            save_ply=True,
            overwrite=True)
    return reconstruction
includeNan=False, excludeID=[], excludeUUID=[]) gpu = deviceIDs[0] else: gpu = args.gpu os.environ["CUDA_VISIBLE_DEVICES"] = '{0}'.format(gpu) base_dir = os.path.join('../', args.exps_dir, args.exp_name, timestamp) saved_model_state = torch.load( os.path.join(base_dir, 'checkpoints', 'ModelParameters', args.checkpoint + ".pth")) saved_model_epoch = saved_model_state["epoch"] conf = ConfigFactory.parse_file(args.conf) network = utils.get_class(conf.get_string('train.network_class'))( conf=conf.get_config('network'), latent_size=conf.get_int('train.latent_size')) if (args.parallel): network.load_state_dict({ '.'.join(k.split('.')[1:]): v for k, v in saved_model_state["model_state_dict"].items() }) else: network.load_state_dict(saved_model_state["model_state_dict"]) evaluate(network=network.cuda(), exps_dir=args.exps_dir, experiment_name=args.exp_name, timestamp=timestamp, split_filename=args.split,
def get_sampler(sampler_type):
    """Resolve ``model.sample.<sampler_type>`` through ``utils.get_class``.

    Args:
        sampler_type: Attribute name inside the ``model.sample`` module.

    Returns:
        Whatever ``utils.get_class`` yields for that dotted path.
    """
    dotted_path = "model.sample." + format(sampler_type)
    return utils.get_class(dotted_path)
def __init__(self, **kwargs):
    """Set up one run: config, output folders, code backup, data, model, optimizer.

    Reads these ``kwargs`` keys:
    - ``conf``: a config-file path or an already-parsed config.
    - ``batch_size``, ``nepochs``, ``expname``.
    - ``is_continue``, ``timestamp``: ``timestamp='latest'`` resumes the
      newest existing run.
    - ``gpu_index``, ``exps_folder_name``, ``debug``, ``quiet``, ``workers``,
      ``parallel``.
    - ``checkpoint``: only when resuming.

    New outputs always go under a fresh
    ``../<exps_folder_name>/<train.expname + expname>/<now>/``, with ``log``,
    ``plots``, ``checkpoints`` and ``code`` sub-folders. When resuming, model
    and optimizer state are loaded from the resumed run's ``checkpoints``.
    """
    conf_is_path = isinstance(kwargs['conf'], str)
    if conf_is_path:
        self.conf = ConfigFactory.parse_file(kwargs['conf'])
        self.conf_filename = kwargs['conf']
    else:
        self.conf = kwargs['conf']
    self.batch_size = kwargs['batch_size']
    self.nepochs = kwargs['nepochs']
    self.expnameraw = self.conf.get_string('train.expname')
    self.expname = self.conf.get_string('train.expname') + kwargs['expname']

    if kwargs['is_continue'] and kwargs['timestamp'] == 'latest':
        # Resume the newest existing run of this experiment, if there is one.
        exp_root = os.path.join('../', kwargs['exps_folder_name'], self.expname)
        timestamps = os.listdir(exp_root) if os.path.exists(exp_root) else []
        if timestamps:
            timestamp = sorted(timestamps)[-1]
            is_continue = True
        else:
            is_continue = False
            timestamp = None
    else:
        timestamp = kwargs['timestamp']
        is_continue = kwargs['is_continue']

    self.adjust_lr = self.conf.get_bool('train.adjust_lr')
    self.GPU_INDEX = kwargs['gpu_index']
    self.exps_folder_name = kwargs['exps_folder_name']

    utils.mkdir_ifnotexists(os.path.join('../', self.exps_folder_name))

    self.expdir = os.path.join('../', self.exps_folder_name, self.expname)
    utils.mkdir_ifnotexists(self.expdir)
    self.timestamp = '{:%Y_%m_%d_%H_%M_%S}'.format(datetime.now())
    utils.mkdir_ifnotexists(os.path.join(self.expdir, self.timestamp))

    log_dir = os.path.join(self.expdir, self.timestamp, 'log')
    self.log_dir = log_dir
    utils.mkdir_ifnotexists(log_dir)
    utils.configure_logging(kwargs['debug'], kwargs['quiet'], os.path.join(self.log_dir, 'log.txt'))

    self.plots_dir = os.path.join(self.expdir, self.timestamp, 'plots')
    utils.mkdir_ifnotexists(self.plots_dir)

    self.checkpoints_path = os.path.join(self.expdir, self.timestamp, 'checkpoints')
    utils.mkdir_ifnotexists(self.checkpoints_path)

    self.model_params_subdir = "ModelParameters"
    self.optimizer_params_subdir = "OptimizerParameters"

    utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path, self.model_params_subdir))
    utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path, self.optimizer_params_subdir))

    if not self.GPU_INDEX == 'all':
        os.environ["CUDA_VISIBLE_DEVICES"] = '{0}'.format(self.GPU_INDEX)

    # Backup code
    self.code_path = os.path.join(self.expdir, self.timestamp, 'code')
    utils.mkdir_ifnotexists(self.code_path)
    for folder in ['training', 'preprocess', 'utils', 'model', 'datasets', 'confs']:
        utils.mkdir_ifnotexists(os.path.join(self.code_path, folder))
        os.system("""cp -r ./{0}/* "{1}" """.format(folder, os.path.join(self.code_path, folder)))
    # Only a config given as a file path can be copied. A pre-parsed config
    # object has no file, and formatting it into the shell command is garbage.
    # NOTE(review): the source path is passed to the shell unquoted; paths
    # with spaces or shell metacharacters will misbehave.
    if conf_is_path:
        os.system("""cp -r {0} "{1}" """.format(kwargs['conf'], os.path.join(self.code_path, 'confs/runconf.conf')))

    logging.info('shell command : {0}'.format(' '.join(sys.argv)))

    if self.conf.get_string('train.data_split') == 'none':
        self.ds = utils.get_class(self.conf.get_string('train.dataset'))(
            split=None,
            dataset_path=self.conf.get_string('train.dataset_path'),
            dist_file_name=None)
    else:
        train_split_file = './confs/splits/{0}'.format(self.conf.get_string('train.data_split'))

        with open(train_split_file, "r") as f:
            train_split = json.load(f)

        self.ds = utils.get_class(self.conf.get_string('train.dataset'))(
            split=train_split,
            dataset_path=self.conf.get_string('train.dataset_path'),
            dist_file_name=self.conf.get_string('train.dist_file_name'))

    self.dataloader = torch.utils.data.DataLoader(self.ds,
                                                  batch_size=self.batch_size,
                                                  shuffle=True,
                                                  num_workers=kwargs['workers'],
                                                  drop_last=True,
                                                  pin_memory=True)
    self.eval_dataloader = torch.utils.data.DataLoader(self.ds,
                                                       batch_size=1,
                                                       shuffle=True,
                                                       num_workers=0,
                                                       drop_last=True)

    self.latent_size = self.conf.get_int('train.latent_size')

    self.network = utils.get_class(self.conf.get_string('train.network_class'))(
        conf=self.conf.get_config('network'),
        latent_size=self.latent_size)

    if kwargs['parallel']:
        self.network = torch.nn.DataParallel(self.network)

    if torch.cuda.is_available():
        self.network.cuda()

    self.parallel = kwargs['parallel']
    self.loss = utils.get_class(self.conf.get_string('network.loss.loss_type'))(
        **self.conf.get_config('network.loss.properties'))
    self.lr_schedules = BaseTrainRunner.get_learning_rate_schedules(
        self.conf.get_list('train.learning_rate_schedule'))

    self.optimizer = torch.optim.Adam(
        [
            {
                "params": self.network.parameters(),
                "lr": self.lr_schedules[0].get_learning_rate(0),
            }
        ])

    self.start_epoch = 0
    if is_continue:
        # Weights and optimizer state come from the resumed run's folder,
        # not from the freshly created ``self.timestamp`` folder.
        old_checkpnts_dir = os.path.join(self.expdir, timestamp, 'checkpoints')

        saved_model_state = torch.load(
            os.path.join(old_checkpnts_dir, 'ModelParameters', str(kwargs['checkpoint']) + ".pth"))
        self.network.load_state_dict(saved_model_state["model_state_dict"])

        data = torch.load(
            os.path.join(old_checkpnts_dir, 'OptimizerParameters', str(kwargs['checkpoint']) + ".pth"))
        self.optimizer.load_state_dict(data["optimizer_state_dict"])
        self.start_epoch = saved_model_state['epoch']