Example #1
def interpolate(network, interval, experiment_directory, checkpoint,
                split_file, epoch, resolution, uniform_grid):

    with open(split_file, "r") as f:
        split = json.load(f)

    ds = utils.get_class(conf.get_string('train.dataset'))(
        split=split,
        dataset_path=conf.get_string('train.dataset_path'),
        with_normals=True)

    points_1, normals_1, index_1 = ds[0]
    points_2, normals_2, index_2 = ds[1]

    pnts = torch.cat([points_1, points_2], dim=0).cuda()

    name_1 = '_'.join(ds.get_info(0))
    name_2 = '_'.join(ds.get_info(1))

    name = name_1 + '_and_' + name_2

    utils.mkdir_ifnotexists(os.path.join(experiment_directory, 'interpolate'))
    utils.mkdir_ifnotexists(
        os.path.join(experiment_directory, 'interpolate', str(checkpoint)))
    utils.mkdir_ifnotexists(
        os.path.join(experiment_directory, 'interpolate', str(checkpoint),
                     name))

    my_path = os.path.join(experiment_directory, 'interpolate',
                           str(checkpoint), name)

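    # Optimize a latent code for each of the two shapes before blending
    # (`conf` here is assumed to be a module-level configuration object).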
    latent_1 = optimize_latent(points_1.cuda(), normals_1.cuda(), conf, 800,
                               network, 5e-3)
    latent_2 = optimize_latent(points_2.cuda(), normals_2.cuda(), conf, 800,
                               network, 5e-3)

    pnts = torch.cat([latent_1.repeat(pnts.shape[0], 1), pnts], dim=-1)

    with torch.no_grad():
        network.eval()

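        # Sweep alpha from 0 to 1: blend the two latent codes linearly and
        # extract the implicit surface at each step.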
        for alpha in np.linspace(0, 1, interval):

            latent = (latent_1 * (1 - alpha)) + (latent_2 * alpha)

            plt.plot_surface(with_points=False,
                             points=pnts,
                             decoder=network,
                             latent=latent,
                             path=my_path,
                             epoch=epoch,
                             shapename=str(alpha),
                             resolution=resolution,
                             mc_value=0,
                             is_uniform_grid=uniform_grid,
                             verbose=True,
                             save_html=False,
                             save_ply=True,
                             overwrite=True,
                             connected=True)
Example #2
def evaluate(network, experiment_directory, conf, checkpoint, split_file, epoch, resolution, uniform_grid):

    my_path = os.path.join(experiment_directory, 'evaluation', str(checkpoint))

    utils.mkdir_ifnotexists(os.path.join(experiment_directory, 'evaluation'))
    utils.mkdir_ifnotexists(my_path)

    with open(split_file, "r") as f:
        split = json.load(f)

    ds = utils.get_class(conf.get_string('train.dataset'))(split=split, dataset_path=conf.get_string('train.dataset_path'), with_normals=True)

    total_files = len(ds)
    print("total files : {0}".format(total_files))
    counter = 0
    dataloader = torch.utils.data.DataLoader(ds, batch_size=1, shuffle=True, num_workers=1, drop_last=False, pin_memory=True)

    for (input_pc, normals, index) in dataloader:

        input_pc = input_pc.cuda().squeeze()
        normals = normals.cuda().squeeze()

        print(counter)
        counter = counter + 1

        network.train()

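        # Optimize a per-shape latent code against this shape's point cloud
        # (800 iterations); the surface is then decoded from that code below.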
        latent = optimize_latent(input_pc, normals, conf, 800, network, lr=5e-3)

        all_latent = latent.repeat(input_pc.shape[0], 1)

        points = torch.cat([all_latent, input_pc], dim=-1)

        shapename = '_'.join(ds.get_info(index))

        with torch.no_grad():

            network.eval()

            plt.plot_surface(with_points=True,
                             points=points,
                             decoder=network,
                             latent=latent,
                             path=my_path,
                             epoch=epoch,
                             shapename=shapename,
                             resolution=resolution,
                             mc_value=0,
                             is_uniform_grid=uniform_grid,
                             verbose=True,
                             save_html=True,
                             save_ply=True,
                             overwrite=True,
                             connected=True)
Example #3
def evaluate(**kwargs):
    torch.set_default_dtype(torch.float32)

    conf = ConfigFactory.parse_file(kwargs['conf'])
    exps_folder_name = kwargs['exps_folder_name']
    evals_folder_name = kwargs['evals_folder_name']
    eval_rendering = kwargs['eval_rendering']
    eval_animation = kwargs['eval_animation']

    expname = conf.get_string('train.expname') + kwargs['expname']
    scan_id = kwargs['scan_id'] if kwargs['scan_id'] != -1 else conf.get_int(
        'dataset.scan_id', default=-1)
    if scan_id != -1:
        expname = expname + '_{0}'.format(scan_id)

    if kwargs['timestamp'] == 'latest':
        if os.path.exists(
                os.path.join('../', kwargs['exps_folder_name'], expname)):
            timestamps = os.listdir(
                os.path.join('../', kwargs['exps_folder_name'], expname))
            if len(timestamps) == 0:
                print('WRONG EXP FOLDER')
                exit()
            else:
                timestamp = sorted(timestamps)[-1]
        else:
            print('WRONG EXP FOLDER')
            exit()
    else:
        timestamp = kwargs['timestamp']

    utils.mkdir_ifnotexists(os.path.join('../', evals_folder_name))
    expdir = os.path.join('../', exps_folder_name, expname)
    evaldir = os.path.join('../', evals_folder_name, expname)
    utils.mkdir_ifnotexists(evaldir)

    dataset_conf = conf.get_config('dataset')
    model = utils.get_class(conf.get_string('train.model_class'))(
        conf=conf.get_config('model'), id=scan_id,
        datadir=dataset_conf['data_dir'])
    if torch.cuda.is_available():
        model.cuda()

    if kwargs['scan_id'] != -1:
        dataset_conf['scan_id'] = kwargs['scan_id']
    eval_dataset = utils.get_class(conf.get_string('train.dataset_class'))(
        False, **dataset_conf)

    if eval_rendering:
        eval_dataloader = torch.utils.data.DataLoader(
            eval_dataset,
            batch_size=1,
            shuffle=False,
            collate_fn=eval_dataset.collate_fn)
        total_pixels = eval_dataset.total_pixels
        img_res = eval_dataset.img_res

    old_checkpnts_dir = os.path.join(expdir, timestamp, 'checkpoints')

    saved_model_state = torch.load(
        os.path.join(old_checkpnts_dir, 'ModelParameters',
                     str(kwargs['checkpoint']) + ".pth"))
    model.load_state_dict(saved_model_state["model_state_dict"])
    epoch = saved_model_state['epoch']

    ####################################################################################################################
    print("evaluating...")

    model.eval()

    detail_3dmm, detail_3dmm_subdivision_full = plt.get_displacement_mesh(
        model)
    detail_3dmm.export('{0}/Detailed_3dmm_{1}.obj'.format(evaldir, epoch),
                       'obj')
    detail_3dmm_subdivision_full.export(
        '{0}/Subdivide_full_{1}.obj'.format(evaldir, epoch), 'obj')

    if eval_animation:
        sdf_np0, sdf_np1 = plt.get_displacement_animation(model)
        np.save('{0}/Cropped_Detailed_sdf_{1}.npy'.format(evaldir, epoch),
                sdf_np0)
        np.save('{0}/Cropped_Subdivide_full_{1}.npy'.format(evaldir, epoch),
                sdf_np1)

    if eval_rendering:
        images_dir = '{0}/rendering'.format(evaldir)
        utils.mkdir_ifnotexists(images_dir)

        psnrs = []
        for data_index, (indices, model_input,
                         ground_truth) in enumerate(eval_dataloader):
            model_input["intrinsics"] = model_input["intrinsics"].cuda()
            model_input["uv"] = model_input["uv"].cuda()
            model_input["object_mask"] = model_input["object_mask"].cuda()
            model_input['pose'] = model_input['pose'].cuda()

            split = utils.split_input(model_input, total_pixels)
            res = []
            for s in split:
                out = model(s)
                res.append({
                    'rgb_values': out['rgb_values'].detach(),
                    'diffuse_values': out['diffuse_values'].detach(),
                    'specular_values': out['specular_values'].detach(),
                    'albedo_values': out['albedo_values'].detach(),
                })

            batch_size = ground_truth['rgb'].shape[0]
            model_outputs = utils.merge_output(res, total_pixels, batch_size)
            rgb_eval = model_outputs['rgb_values']
            rgb_eval = rgb_eval.reshape(batch_size, total_pixels, 3)
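            # Map network outputs from [-1, 1] to [0, 1] before writing images.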
            rgb_eval = (rgb_eval + 1.) / 2.
            rgb_eval = plt.lin2img(rgb_eval, img_res).detach().cpu().numpy()[0]
            rgb_eval = rgb_eval.transpose(1, 2, 0)
            img = Image.fromarray((rgb_eval * 255).astype(np.uint8))
            img.save('{0}/eval_{1}.png'.format(images_dir,
                                               '%03d' % indices[0]))

            diffuse_eval = model_outputs['diffuse_values']
            diffuse_eval = diffuse_eval.reshape(batch_size, total_pixels, 3)
            diffuse_eval = (diffuse_eval + 1.) / 2.
            diffuse_eval = plt.lin2img(diffuse_eval,
                                       img_res).detach().cpu().numpy()[0]
            diffuse_eval = diffuse_eval.transpose(1, 2, 0)
            img = Image.fromarray((diffuse_eval * 255).astype(np.uint8))
            img.save('{0}/eval_{1}_diffuse.png'.format(images_dir,
                                                       '%03d' % indices[0]))

            specular_eval = model_outputs['specular_values']
            specular_eval = specular_eval.reshape(batch_size, total_pixels, 3)
            specular_eval = (specular_eval + 1.) / 2.
            specular_eval = plt.lin2img(specular_eval,
                                        img_res).detach().cpu().numpy()[0]
            specular_eval = specular_eval.transpose(1, 2, 0)
            img = Image.fromarray((specular_eval * 255).astype(np.uint8))
            img.save('{0}/eval_{1}_specular.png'.format(
                images_dir, '%03d' % indices[0]))

            albedo_eval = model_outputs['albedo_values']
            albedo_eval = albedo_eval.reshape(batch_size, total_pixels, 3)
            albedo_eval = (albedo_eval + 1.) / 2.
            albedo_eval = plt.lin2img(albedo_eval,
                                      img_res).detach().cpu().numpy()[0]
            albedo_eval = albedo_eval.transpose(1, 2, 0)
            img = Image.fromarray((albedo_eval * 255).astype(np.uint8))
            img.save('{0}/eval_{1}_albedo.png'.format(images_dir,
                                                      '%03d' % indices[0]))

            rgb_gt = ground_truth['rgb']
            rgb_gt = (rgb_gt + 1.) / 2.
            rgb_gt = plt.lin2img(rgb_gt, img_res).numpy()[0]
            rgb_gt = rgb_gt.transpose(1, 2, 0)

            mask = model_input['object_mask']
            mask = plt.lin2img(mask.unsqueeze(-1), img_res).cpu().numpy()[0]
            mask = mask.transpose(1, 2, 0)

            rgb_eval_masked = rgb_eval * mask
            rgb_gt_masked = rgb_gt * mask

            psnr = calculate_psnr(rgb_eval_masked, rgb_gt_masked, mask)
            psnrs.append(psnr)

        psnrs = np.array(psnrs).astype(np.float64)
        print("RENDERING EVALUATION {2}: psnr mean = {0} ; psnr std = {1}".
              format("%.2f" % psnrs.mean(), "%.2f" % psnrs.std(), scan_id))
Example #4
    def __init__(self, **kwargs):

        self.home_dir = os.path.abspath(os.pardir)

        # config setting

        if isinstance(kwargs['conf'], str):
            self.conf_filename = './reconstruction/' + kwargs['conf']
            self.conf = ConfigFactory.parse_file(self.conf_filename)
        else:
            self.conf = kwargs['conf']

        self.expname = kwargs['expname']

        # GPU settings

        self.GPU_INDEX = kwargs['gpu_index']

        if self.GPU_INDEX != 'ignore':
            os.environ["CUDA_VISIBLE_DEVICES"] = '{0}'.format(self.GPU_INDEX)

        self.num_of_gpus = torch.cuda.device_count()

        self.eval = kwargs['eval']

        # settings for loading an existing experiment

        if (kwargs['is_continue'] or self.eval) and kwargs['timestamp'] == 'latest':
            if os.path.exists(os.path.join(self.home_dir, 'exps', self.expname)):
                timestamps = os.listdir(os.path.join(self.home_dir, 'exps', self.expname))
                if len(timestamps) == 0:
                    is_continue = False
                    timestamp = None
                else:
                    timestamp = sorted(timestamps)[-1]
                    is_continue = True
            else:
                is_continue = False
                timestamp = None
        else:
            timestamp = kwargs['timestamp']
            is_continue = kwargs['is_continue'] or self.eval

        self.exps_folder_name = 'exps'

        utils.mkdir_ifnotexists(utils.concat_home_dir(os.path.join(self.home_dir, self.exps_folder_name)))

        self.input_file = self.conf.get_string('train.input_path')
        self.data = utils.load_point_cloud_by_file_extension(self.input_file)

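        # Per-point noise scale: the distance from each point to its 50th
        # nearest neighbor, computed in chunks over a k-d tree.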
        sigma_set = []
        ptree = cKDTree(self.data)

        for p in np.array_split(self.data, 100, axis=0):
            d = ptree.query(p, 50 + 1)
            sigma_set.append(d[0][:, -1])

        sigmas = np.concatenate(sigma_set)
        self.local_sigma = torch.from_numpy(sigmas).float().cuda()

        self.expdir = utils.concat_home_dir(os.path.join(self.home_dir, self.exps_folder_name, self.expname))
        utils.mkdir_ifnotexists(self.expdir)

        if is_continue:
            self.timestamp = timestamp
        else:
            self.timestamp = '{:%Y_%m_%d_%H_%M_%S}'.format(datetime.now())

        self.cur_exp_dir = os.path.join(self.expdir, self.timestamp)
        utils.mkdir_ifnotexists(self.cur_exp_dir)

        self.plots_dir = os.path.join(self.cur_exp_dir, 'plots')
        utils.mkdir_ifnotexists(self.plots_dir)

        self.checkpoints_path = os.path.join(self.cur_exp_dir, 'checkpoints')
        utils.mkdir_ifnotexists(self.checkpoints_path)

        self.model_params_subdir = "ModelParameters"
        self.optimizer_params_subdir = "OptimizerParameters"

        utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path, self.model_params_subdir))
        utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path, self.optimizer_params_subdir))

        self.nepochs = kwargs['nepochs']

        self.points_batch = kwargs['points_batch']

        self.global_sigma = self.conf.get_float('network.sampler.properties.global_sigma')
        self.sampler = Sampler.get_sampler(self.conf.get_string('network.sampler.sampler_type'))(self.global_sigma,
                                                                                                 self.local_sigma)
        self.grad_lambda = self.conf.get_float('network.loss.lambda')
        self.normals_lambda = self.conf.get_float('network.loss.normals_lambda')

        self.with_normals = self.normals_lambda > 0

        self.d_in = self.conf.get_int('train.d_in')

        self.network = utils.get_class(self.conf.get_string('train.network_class'))(d_in=self.d_in,
                                                                                    **self.conf.get_config(
                                                                                        'network.inputs'))

        if torch.cuda.is_available():
            self.network.cuda()

        self.lr_schedules = self.get_learning_rate_schedules(self.conf.get_list('train.learning_rate_schedule'))
        self.weight_decay = self.conf.get_float('train.weight_decay')

        self.startepoch = 0

        self.optimizer = torch.optim.Adam(
            [
                {
                    "params": self.network.parameters(),
                    "lr": self.lr_schedules[0].get_learning_rate(0),
                    "weight_decay": self.weight_decay
                },
            ])

        # if continue load checkpoints

        if is_continue:
            old_checkpnts_dir = os.path.join(self.expdir, timestamp, 'checkpoints')

            saved_model_state = torch.load(
                os.path.join(old_checkpnts_dir, 'ModelParameters', str(kwargs['checkpoint']) + ".pth"))
            self.network.load_state_dict(saved_model_state["model_state_dict"])

            data = torch.load(
                os.path.join(old_checkpnts_dir, 'OptimizerParameters', str(kwargs['checkpoint']) + ".pth"))
            self.optimizer.load_state_dict(data["optimizer_state_dict"])
            self.startepoch = saved_model_state['epoch']
Example #5
def evaluate(**kwargs):
    torch.set_default_dtype(torch.float32)

    conf = ConfigFactory.parse_file(kwargs['conf'])
    exps_folder_name = kwargs['exps_folder_name']
    evals_folder_name = kwargs['evals_folder_name']
    timestamp = '2020'
    checkpoint = '2000'

    expname = conf.get_string('train.expname')

    geometry_id = kwargs['geometry_id']
    appearance_id = kwargs['appearance_id']

    utils.mkdir_ifnotexists(os.path.join('../', evals_folder_name))
    expdir_geometry = os.path.join('../', exps_folder_name, expname + '_{0}'.format(geometry_id))
    expdir_appearance = os.path.join('../', exps_folder_name, expname + '_{0}'.format(appearance_id))
    evaldir = os.path.join('../', evals_folder_name, expname + '_{0}_{1}'.format(geometry_id, appearance_id))
    utils.mkdir_ifnotexists(evaldir)

    model = utils.get_class(conf.get_string('train.model_class'))(conf=conf.get_config('model'))
    if torch.cuda.is_available():
        model.cuda()

    # Load geometry network model
    old_checkpnts_dir = os.path.join(expdir_geometry, timestamp, 'checkpoints')
    saved_model_state = torch.load(os.path.join(old_checkpnts_dir, 'ModelParameters', checkpoint + ".pth"))
    model.load_state_dict(saved_model_state["model_state_dict"])

    # Load rendering network model
    model_fake = utils.get_class(conf.get_string('train.model_class'))(conf=conf.get_config('model'))
    if torch.cuda.is_available():
        model_fake.cuda()
    old_checkpnts_dir = os.path.join(expdir_appearance, timestamp, 'checkpoints')
    saved_model_state = torch.load(os.path.join(old_checkpnts_dir, 'ModelParameters', checkpoint + ".pth"))
    model_fake.load_state_dict(saved_model_state["model_state_dict"])

    model.rendering_network = model_fake.rendering_network

    dataset_conf = conf.get_config('dataset')
    dataset_conf['scan_id'] = geometry_id
    eval_dataset = utils.get_class(conf.get_string('train.dataset_class'))(False, **dataset_conf)

    eval_dataloader = torch.utils.data.DataLoader(eval_dataset,
                                                  batch_size=1,
                                                  shuffle=True,
                                                  collate_fn=eval_dataset.collate_fn
                                                  )
    total_pixels = eval_dataset.total_pixels
    img_res = eval_dataset.img_res

    ####################################################################################################################
    print("evaluating...")

    model.eval()

    gt_pose = eval_dataset.get_gt_pose(scaled=True).cuda()
    gt_quat = rend_util.rot_to_quat(gt_pose[:, :3, :3])
    gt_pose_vec = torch.cat([gt_quat, gt_pose[:, :3, 3]], 1)

    indices_all = [11, 16, 34, 28, 11]
    pose = gt_pose_vec[indices_all, :]
    t_in = np.array([0, 2, 3, 5, 6]).astype(np.float32)

    n_inter = 5
    t_out = np.linspace(t_in[0], t_in[-1], n_inter * t_in[-1]).astype(np.float32)

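    # Build a smooth closed camera path: periodic cubic splines interpolate the
    # key-frame scales and orientations (quaternions) over the denser grid t_out.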
    scales = np.array([4.2, 4.2, 3.8, 3.8, 4.2]).astype(np.float32)

    s_new = CubicSpline(t_in, scales, bc_type='periodic')
    s_new = s_new(t_out)

    q_new = CubicSpline(t_in, pose[:, :4].detach().cpu().numpy(), bc_type='periodic')
    q_new = q_new(t_out)
    q_new = q_new / np.linalg.norm(q_new, 2, 1)[:, None]
    q_new = torch.from_numpy(q_new).cuda().float()

    images_dir = '{0}/novel_views_rendering'.format(evaldir)
    utils.mkdir_ifnotexists(images_dir)

    indices, model_input, ground_truth = next(iter(eval_dataloader))

    for i, (new_q, scale) in enumerate(zip(q_new, s_new)):
        torch.cuda.empty_cache()

        new_q = new_q.unsqueeze(0)
        new_t = -rend_util.quat_to_rot(new_q)[:, :, 2] * scale

        new_p = torch.eye(4).float().cuda().unsqueeze(0)
        new_p[:, :3, :3] = rend_util.quat_to_rot(new_q)
        new_p[:, :3, 3] = new_t

        sample = {
            "object_mask": torch.zeros_like(model_input['object_mask']).cuda().bool(),
            "uv": model_input['uv'].cuda(),
            "intrinsics": model_input['intrinsics'].cuda(),
            "pose": new_p
        }

        split = utils.split_input(sample, total_pixels)
        res = []
        for s in split:
            out = model(s)
            res.append({
                'rgb_values': out['rgb_values'].detach(),
            })

        batch_size = 1
        model_outputs = utils.merge_output(res, total_pixels, batch_size)
        rgb_eval = model_outputs['rgb_values']
        rgb_eval = rgb_eval.reshape(batch_size, total_pixels, 3)

        rgb_eval = (rgb_eval + 1.) / 2.
        rgb_eval = plt.lin2img(rgb_eval, img_res).detach().cpu().numpy()[0]
        rgb_eval = rgb_eval.transpose(1, 2, 0)
        img = Image.fromarray((rgb_eval * 255).astype(np.uint8))
        img.save('{0}/eval_{1}.png'.format(images_dir,'%03d' % i))
Example #6
    def __init__(self, **kwargs):
        torch.set_default_dtype(torch.float32)
        torch.set_num_threads(1)

        self.conf = ConfigFactory.parse_file(kwargs['conf'])
        self.batch_size = kwargs['batch_size']
        self.nepochs = kwargs['nepochs']
        self.exps_folder_name = kwargs['exps_folder_name']
        self.GPU_INDEX = kwargs['gpu_index']
        self.train_cameras = kwargs['train_cameras']

        self.expname = self.conf.get_string(
            'train.expname') + kwargs['expname']
        scan_id = kwargs[
            'scan_id'] if kwargs['scan_id'] != -1 else self.conf.get_int(
                'dataset.scan_id', default=-1)
        if scan_id != -1:
            self.expname = self.expname + '_{0}'.format(scan_id)

        if kwargs['is_continue'] and kwargs['timestamp'] == 'latest':
            if os.path.exists(
                    os.path.join('../', kwargs['exps_folder_name'],
                                 self.expname)):
                timestamps = os.listdir(
                    os.path.join('../', kwargs['exps_folder_name'],
                                 self.expname))
                if len(timestamps) == 0:
                    is_continue = False
                    timestamp = None
                else:
                    timestamp = sorted(timestamps)[-1]
                    is_continue = True
            else:
                is_continue = False
                timestamp = None
        else:
            timestamp = kwargs['timestamp']
            is_continue = kwargs['is_continue']

        utils.mkdir_ifnotexists(os.path.join('../', self.exps_folder_name))
        self.expdir = os.path.join('../', self.exps_folder_name, self.expname)
        utils.mkdir_ifnotexists(self.expdir)
        self.timestamp = '{:%Y_%m_%d_%H_%M_%S}'.format(datetime.now())
        utils.mkdir_ifnotexists(os.path.join(self.expdir, self.timestamp))

        self.plots_dir = os.path.join(self.expdir, self.timestamp, 'plots')
        utils.mkdir_ifnotexists(self.plots_dir)

        # create checkpoints dirs
        self.checkpoints_path = os.path.join(self.expdir, self.timestamp,
                                             'checkpoints')
        utils.mkdir_ifnotexists(self.checkpoints_path)
        self.model_params_subdir = "ModelParameters"
        self.optimizer_params_subdir = "OptimizerParameters"
        self.scheduler_params_subdir = "SchedulerParameters"

        utils.mkdir_ifnotexists(
            os.path.join(self.checkpoints_path, self.model_params_subdir))
        utils.mkdir_ifnotexists(
            os.path.join(self.checkpoints_path, self.optimizer_params_subdir))
        utils.mkdir_ifnotexists(
            os.path.join(self.checkpoints_path, self.scheduler_params_subdir))

        if self.train_cameras:
            self.optimizer_cam_params_subdir = "OptimizerCamParameters"
            self.cam_params_subdir = "CamParameters"

            utils.mkdir_ifnotexists(
                os.path.join(self.checkpoints_path,
                             self.optimizer_cam_params_subdir))
            utils.mkdir_ifnotexists(
                os.path.join(self.checkpoints_path, self.cam_params_subdir))

        os.system("""cp -r {0} "{1}" """.format(
            kwargs['conf'],
            os.path.join(self.expdir, self.timestamp, 'runconf.conf')))

        if self.GPU_INDEX != 'ignore':
            os.environ["CUDA_VISIBLE_DEVICES"] = '{0}'.format(self.GPU_INDEX)

        print('shell command : {0}'.format(' '.join(sys.argv)))

        print('Loading data ...')

        dataset_conf = self.conf.get_config('dataset')
        if kwargs['scan_id'] != -1:
            dataset_conf['scan_id'] = kwargs['scan_id']

        # use an explicit train_set from the config instead of training on all images
        try:
            dataset_conf['train_set'] = self.conf.get_list('train.train_set')
        except Exception:
            dataset_conf['train_set'] = None

        try:
            dataset_conf['test_set'] = self.conf.get_list('train.test_set')
        except Exception:
            dataset_conf['test_set'] = None

        self.train_dataset = utils.get_class(
            self.conf.get_string('train.dataset_class'))(self.train_cameras,
                                                         **dataset_conf)

        print('Finish loading data ...')

        self.train_dataloader = torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            collate_fn=self.train_dataset.collate_fn)
        self.plot_dataloader = torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=self.conf.get_int('plot.plot_nimgs'),
            shuffle=True,
            collate_fn=self.train_dataset.collate_fn)

        self.model = utils.get_class(
            self.conf.get_string('train.model_class'))(
                conf=self.conf.get_config('model'))
        if torch.cuda.is_available():
            self.model.cuda()

        self.loss = utils.get_class(self.conf.get_string('train.loss_class'))(
            **self.conf.get_config('loss'))

        self.lr = self.conf.get_float('train.learning_rate')
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        self.sched_milestones = self.conf.get_list('train.sched_milestones',
                                                   default=[])
        self.sched_factor = self.conf.get_float('train.sched_factor',
                                                default=0.0)
        self.scheduler = torch.optim.lr_scheduler.MultiStepLR(
            self.optimizer, self.sched_milestones, gamma=self.sched_factor)

        # settings for camera optimization
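        # Each pose is a 7-vector (unit quaternion + translation) held in a
        # sparse embedding and optimized with SparseAdam.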
        if self.train_cameras:
            num_images = len(self.train_dataset)
            self.pose_vecs = torch.nn.Embedding(num_images, 7,
                                                sparse=True).cuda()
            self.pose_vecs.weight.data.copy_(
                self.train_dataset.get_pose_init())

            self.optimizer_cam = torch.optim.SparseAdam(
                self.pose_vecs.parameters(),
                self.conf.get_float('train.learning_rate_cam'))

        self.start_epoch = 0
        if is_continue:
            old_checkpnts_dir = os.path.join(self.expdir, timestamp,
                                             'checkpoints')

            saved_model_state = torch.load(
                os.path.join(old_checkpnts_dir, 'ModelParameters',
                             str(kwargs['checkpoint']) + ".pth"))
            self.model.load_state_dict(saved_model_state["model_state_dict"])
            self.start_epoch = saved_model_state['epoch']

            data = torch.load(
                os.path.join(old_checkpnts_dir, 'OptimizerParameters',
                             str(kwargs['checkpoint']) + ".pth"))
            self.optimizer.load_state_dict(data["optimizer_state_dict"])

            data = torch.load(
                os.path.join(old_checkpnts_dir, self.scheduler_params_subdir,
                             str(kwargs['checkpoint']) + ".pth"))
            self.scheduler.load_state_dict(data["scheduler_state_dict"])

            if self.train_cameras:
                data = torch.load(
                    os.path.join(old_checkpnts_dir,
                                 self.optimizer_cam_params_subdir,
                                 str(kwargs['checkpoint']) + ".pth"))
                self.optimizer_cam.load_state_dict(
                    data["optimizer_cam_state_dict"])

                data = torch.load(
                    os.path.join(old_checkpnts_dir, self.cam_params_subdir,
                                 str(kwargs['checkpoint']) + ".pth"))
                self.pose_vecs.load_state_dict(data["pose_vecs_state_dict"])

        self.num_pixels = self.conf.get_int('train.num_pixels')
        self.total_pixels = self.train_dataset.total_pixels
        self.img_res = self.train_dataset.img_res
        self.n_batches = len(self.train_dataloader)
        self.plot_freq = self.conf.get_int('train.plot_freq')
        self.plot_conf = self.conf.get_config('plot')

        self.alpha_milestones = self.conf.get_list('train.alpha_milestones',
                                                   default=[])
        self.alpha_factor = self.conf.get_float('train.alpha_factor',
                                                default=0.0)
        for acc in self.alpha_milestones:
            if self.start_epoch > acc:
                self.loss.alpha = self.loss.alpha * self.alpha_factor
Example #7
    conf = ConfigFactory.parse_file(os.path.join(code_path, 'shapespace', args.conf))

    experiment_directory = os.path.join(exps_path, args.exp_name)

    if args.timestamp == 'latest':
        timestamps = os.listdir(experiment_directory)
        timestamp = sorted(timestamps)[-1]
    else:
        timestamp = args.timestamp

    experiment_directory = os.path.join(experiment_directory, timestamp)
    saved_model_state = torch.load(os.path.join(experiment_directory, 'checkpoints', 'ModelParameters', args.epoch + ".pth"))
    saved_model_epoch = saved_model_state["epoch"]
    with_normals = conf.get_float('network.loss.normals_lambda') > 0
    network = utils.get_class(conf.get_string('train.network_class'))(
        d_in=conf.get_int('train.latent_size') + conf.get_int('train.d_in'),
        **conf.get_config('network.inputs'))

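    # Remove the 'module.' prefix that nn.DataParallel adds to state-dict keys.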
    network.load_state_dict({k.replace('module.', ''): v for k, v in saved_model_state["model_state_dict"].items()})
    split_file = os.path.join(code_path, 'splits', args.split)

    interpolate(
        network=network.cuda(),
        interval=args.interval,
        experiment_directory=experiment_directory,
        checkpoint=saved_model_epoch,
        split_file=split_file,
        epoch=saved_model_epoch,
        resolution=args.resolution,
        uniform_grid=args.uniform_grid
    )
Example #8
    def __init__(self, **kwargs):

        # config setting

        self.home_dir = os.path.abspath(os.pardir)

        if isinstance(kwargs['conf'], str):
            self.conf_filename = os.path.abspath(kwargs['conf'])
            self.conf = ConfigFactory.parse_file(self.conf_filename)
        else:
            self.conf = kwargs['conf']

        self.expname = kwargs['expname']

        # GPU settings

        self.GPU_INDEX = kwargs['gpu_index']

        if self.GPU_INDEX != 'ignore':
            os.environ["CUDA_VISIBLE_DEVICES"] = '{0}'.format(self.GPU_INDEX)

        self.num_of_gpus = torch.cuda.device_count()

        # settings for loading an existing experiment

        if kwargs['is_continue'] and kwargs['timestamp'] == 'latest':
            if os.path.exists(os.path.join(self.home_dir, 'exps',
                                           self.expname)):
                timestamps = os.listdir(
                    os.path.join(self.home_dir, 'exps', self.expname))
                if len(timestamps) == 0:
                    is_continue = False
                    timestamp = None
                else:
                    timestamp = sorted(timestamps)[-1]
                    is_continue = True
            else:
                is_continue = False
                timestamp = None
        else:
            timestamp = kwargs['timestamp']
            is_continue = kwargs['is_continue']

        self.exps_folder_name = 'exps'

        utils.mkdir_ifnotexists(
            utils.concat_home_dir(
                os.path.join(self.home_dir, self.exps_folder_name)))

        self.expdir = utils.concat_home_dir(
            os.path.join(self.home_dir, self.exps_folder_name, self.expname))
        utils.mkdir_ifnotexists(self.expdir)

        if is_continue:
            self.timestamp = timestamp
        else:
            self.timestamp = '{:%Y_%m_%d_%H_%M_%S}'.format(datetime.now())

        self.cur_exp_dir = self.timestamp
        utils.mkdir_ifnotexists(os.path.join(self.expdir, self.cur_exp_dir))

        self.plots_dir = os.path.join(self.expdir, self.cur_exp_dir, 'plots')
        utils.mkdir_ifnotexists(self.plots_dir)

        self.checkpoints_path = os.path.join(self.expdir, self.cur_exp_dir,
                                             'checkpoints')
        utils.mkdir_ifnotexists(self.checkpoints_path)

        self.model_params_subdir = "ModelParameters"
        self.optimizer_params_subdir = "OptimizerParameters"
        self.latent_codes_subdir = "LatentCodes"

        utils.mkdir_ifnotexists(
            os.path.join(self.checkpoints_path, self.model_params_subdir))
        utils.mkdir_ifnotexists(
            os.path.join(self.checkpoints_path, self.optimizer_params_subdir))
        utils.mkdir_ifnotexists(
            os.path.join(self.checkpoints_path, self.latent_codes_subdir))

        self.nepochs = kwargs['nepochs']

        self.batch_size = kwargs['batch_size']

        if self.num_of_gpus > 0:
            self.batch_size *= self.num_of_gpus

        self.parallel = self.num_of_gpus > 1

        self.global_sigma = self.conf.get_float(
            'network.sampler.properties.global_sigma')
        self.local_sigma = self.conf.get_float(
            'network.sampler.properties.local_sigma')
        self.sampler = Sampler.get_sampler(
            self.conf.get_string('network.sampler.sampler_type'))(
                self.global_sigma, self.local_sigma)

        train_split_file = os.path.abspath(kwargs['split_file'])
        print(f'Loading split file {train_split_file}')
        with open(train_split_file, "r") as f:
            train_split = json.load(f)
        print(f'Size of the split: {len(train_split)} samples')

        self.d_in = self.conf.get_int('train.d_in')

        # latent preprocessing

        self.latent_size = self.conf.get_int('train.latent_size')

        self.latent_lambda = self.conf.get_float('network.loss.latent_lambda')
        self.grad_lambda = self.conf.get_float('network.loss.lambda')
        self.normals_lambda = self.conf.get_float(
            'network.loss.normals_lambda')

        self.with_normals = self.normals_lambda > 0

        self.ds = utils.get_class(self.conf.get_string('train.dataset'))(
            split=train_split,
            with_normals=self.with_normals,
            dataset_path=self.conf.get_string('train.dataset_path'),
            points_batch=kwargs['points_batch'],
        )

        self.num_scenes = len(self.ds)

        self.train_dataloader = torch.utils.data.DataLoader(
            self.ds,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=kwargs['threads'],
            drop_last=True,
            pin_memory=True)
        self.eval_dataloader = torch.utils.data.DataLoader(self.ds,
                                                           batch_size=1,
                                                           shuffle=False,
                                                           num_workers=0,
                                                           drop_last=True)

        self.network = utils.get_class(
            self.conf.get_string('train.network_class'))(
                d_in=(self.d_in + self.latent_size),
                **self.conf.get_config('network.inputs'))

        if self.parallel:
            self.network = torch.nn.DataParallel(self.network)

        if torch.cuda.is_available():
            self.network.cuda()

        self.lr_schedules = self.get_learning_rate_schedules(
            self.conf.get_list('train.learning_rate_schedule'))
        self.weight_decay = self.conf.get_float('train.weight_decay')

        # optimizer and latent settings

        self.startepoch = 0

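        # Auto-decoder setup: one learnable latent code per training shape,
        # optimized jointly with the network weights.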
        self.lat_vecs = utils.to_cuda(
            torch.zeros(self.num_scenes, self.latent_size))
        self.lat_vecs.requires_grad_()

        self.optimizer = torch.optim.Adam([
            {
                "params": self.network.parameters(),
                "lr": self.lr_schedules[0].get_learning_rate(0),
                "weight_decay": self.weight_decay
            },
            {
                "params": self.lat_vecs,
                "lr": self.lr_schedules[1].get_learning_rate(0)
            },
        ])

        # if continue load checkpoints

        if is_continue:
            old_checkpnts_dir = os.path.join(self.expdir, timestamp,
                                             'checkpoints')

            data = torch.load(
                os.path.join(old_checkpnts_dir, self.latent_codes_subdir,
                             str(kwargs['checkpoint']) + '.pth'))

            self.lat_vecs = utils.to_cuda(data["latent_codes"])

            saved_model_state = torch.load(
                os.path.join(old_checkpnts_dir, 'ModelParameters',
                             str(kwargs['checkpoint']) + ".pth"))
            self.network.load_state_dict(saved_model_state["model_state_dict"])

            data = torch.load(
                os.path.join(old_checkpnts_dir, 'OptimizerParameters',
                             str(kwargs['checkpoint']) + ".pth"))
            self.optimizer.load_state_dict(data["optimizer_state_dict"])
            self.startepoch = saved_model_state['epoch']
Example #9
def evaluate(**kwargs):
    torch.set_default_dtype(torch.float32)

    conf = ConfigFactory.parse_file(kwargs['conf'])
    exps_folder_name = kwargs['exps_folder_name']
    evals_folder_name = kwargs['evals_folder_name']
    eval_cameras = kwargs['eval_cameras']
    eval_rendering = kwargs['eval_rendering']

    expname = conf.get_string('train.expname') + kwargs['expname']
    scan_id = kwargs['scan_id'] if kwargs['scan_id'] != -1 else conf.get_int(
        'dataset.scan_id', default=-1)
    if scan_id != -1:
        expname = expname + '_{0}'.format(scan_id)

    if kwargs['timestamp'] == 'latest':
        if os.path.exists(
                os.path.join('../', kwargs['exps_folder_name'], expname)):
            timestamps = os.listdir(
                os.path.join('../', kwargs['exps_folder_name'], expname))
            if len(timestamps) == 0:
                print('WRONG EXP FOLDER')
                exit()
            else:
                timestamp = sorted(timestamps)[-1]
        else:
            print('WRONG EXP FOLDER')
            exit()
    else:
        timestamp = kwargs['timestamp']

    utils.mkdir_ifnotexists(os.path.join('../', evals_folder_name))
    expdir = os.path.join('../', exps_folder_name, expname)
    evaldir = os.path.join('../', evals_folder_name, expname)
    utils.mkdir_ifnotexists(evaldir)

    model = utils.get_class(
        conf.get_string('train.model_class'))(conf=conf.get_config('model'))
    if torch.cuda.is_available():
        model.cuda()

    dataset_conf = conf.get_config('dataset')
    if kwargs['scan_id'] != -1:
        dataset_conf['scan_id'] = kwargs['scan_id']
    eval_dataset = utils.get_class(conf.get_string('train.dataset_class'))(
        eval_cameras, **dataset_conf)

    # settings for camera optimization
    scale_mat = eval_dataset.get_scale_mat()
    if eval_cameras:
        num_images = len(eval_dataset)
        pose_vecs = torch.nn.Embedding(num_images, 7, sparse=True).cuda()
        pose_vecs.weight.data.copy_(eval_dataset.get_pose_init())

        gt_pose = eval_dataset.get_gt_pose()

    if eval_rendering:
        eval_dataloader = torch.utils.data.DataLoader(
            eval_dataset,
            batch_size=1,
            shuffle=False,
            collate_fn=eval_dataset.collate_fn)
        total_pixels = eval_dataset.total_pixels
        img_res = eval_dataset.img_res

    old_checkpnts_dir = os.path.join(expdir, timestamp, 'checkpoints')

    saved_model_state = torch.load(
        os.path.join(old_checkpnts_dir, 'ModelParameters',
                     str(kwargs['checkpoint']) + ".pth"))
    model.load_state_dict(saved_model_state["model_state_dict"])
    epoch = saved_model_state['epoch']

    if eval_cameras:
        data = torch.load(
            os.path.join(old_checkpnts_dir, 'CamParameters',
                         str(kwargs['checkpoint']) + ".pth"))
        pose_vecs.load_state_dict(data["pose_vecs_state_dict"])

    ####################################################################################################################
    print("evaluating...")

    model.eval()
    if eval_cameras:
        pose_vecs.eval()

    with torch.no_grad():
        if eval_cameras:
            gt_Rs = gt_pose[:, :3, :3].double()
            gt_ts = gt_pose[:, :3, 3].double()

            pred_Rs = rend_util.quat_to_rot(
                pose_vecs.weight.data[:, :4]).cpu().double()
            pred_ts = pose_vecs.weight.data[:, 4:].cpu().double()

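            # Align predicted cameras to ground truth with a best-fit
            # similarity transform (rotation R, translation t, scale c).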
            R_opt, t_opt, c_opt, R_fixed, t_fixed = get_cameras_accuracy(
                pred_Rs, gt_Rs, pred_ts, gt_ts)

            cams_transformation = np.eye(4, dtype=np.double)
            cams_transformation[:3, :3] = c_opt * R_opt
            cams_transformation[:3, 3] = t_opt

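        # Extract the zero level set of the learned implicit network as a mesh.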
        mesh = plt.get_surface_high_res_mesh(
            sdf=lambda x: model.implicit_network(x)[:, 0],
            resolution=kwargs['resolution'])

        # Transform to world coordinates
        if eval_cameras:
            mesh.apply_transform(cams_transformation)
        else:
            mesh.apply_transform(scale_mat)

        # Taking the biggest connected component
        components = mesh.split(only_watertight=False)
        areas = np.array([c.area for c in components], dtype=float)
        mesh_clean = components[areas.argmax()]
        mesh_clean.export(
            '{0}/surface_world_coordinates_{1}.ply'.format(evaldir, epoch),
            'ply')

    if eval_rendering:
        images_dir = '{0}/rendering'.format(evaldir)
        utils.mkdir_ifnotexists(images_dir)

        psnrs = []
        for data_index, (indices, model_input,
                         ground_truth) in enumerate(eval_dataloader):
            model_input["intrinsics"] = model_input["intrinsics"].cuda()
            model_input["uv"] = model_input["uv"].cuda()
            model_input["object_mask"] = model_input["object_mask"].cuda()

            if eval_cameras:
                pose_input = pose_vecs(indices.cuda())
                model_input['pose'] = pose_input
            else:
                model_input['pose'] = model_input['pose'].cuda()

            split = utils.split_input(model_input, total_pixels)
            res = []
            for s in split:
                out = model(s)
                res.append({
                    'rgb_values': out['rgb_values'].detach(),
                })

            batch_size = ground_truth['rgb'].shape[0]
            model_outputs = utils.merge_output(res, total_pixels, batch_size)
            rgb_eval = model_outputs['rgb_values']
            rgb_eval = rgb_eval.reshape(batch_size, total_pixels, 3)

            rgb_eval = (rgb_eval + 1.) / 2.
            rgb_eval = plt.lin2img(rgb_eval, img_res).detach().cpu().numpy()[0]
            rgb_eval = rgb_eval.transpose(1, 2, 0)
            img = Image.fromarray((rgb_eval * 255).astype(np.uint8))
            img.save('{0}/eval_{1}.png'.format(images_dir,
                                               '%03d' % indices[0]))

            rgb_gt = ground_truth['rgb']
            rgb_gt = (rgb_gt + 1.) / 2.
            rgb_gt = plt.lin2img(rgb_gt, img_res).numpy()[0]
            rgb_gt = rgb_gt.transpose(1, 2, 0)

            mask = model_input['object_mask']
            mask = plt.lin2img(mask.unsqueeze(-1), img_res).cpu().numpy()[0]
            mask = mask.transpose(1, 2, 0)

            rgb_eval_masked = rgb_eval * mask
            rgb_gt_masked = rgb_gt * mask

            psnr = calculate_psnr(rgb_eval_masked, rgb_gt_masked, mask)
            psnrs.append(psnr)

        psnrs = np.array(psnrs).astype(np.float64)
        print("RENDERING EVALUATION {2}: psnr mean = {0} ; psnr std = {1}".
              format("%.2f" % psnrs.mean(), "%.2f" % psnrs.std(), scan_id))
Example #10
    def __init__(
        self,
        latent_size,
        dims,
        dropout=None,
        dropout_prob=0.0,
        norm_layers=(),
        latent_in=(),
        weight_norm=False,
        xyz_in_all=None,
        activation=None,
        latent_dropout=False,
    ):
        super().__init__()

        dims = [latent_size + 3] + dims + [1]

        self.num_layers = len(dims)
        self.norm_layers = norm_layers
        self.latent_in = latent_in
        self.latent_dropout = latent_dropout
        if self.latent_dropout:
            self.lat_dp = nn.Dropout(0.2)

        self.xyz_in_all = xyz_in_all
        self.weight_norm = weight_norm

        for l in range(0, self.num_layers - 1):
            if l + 1 in latent_in:
                out_dim = dims[l + 1] - dims[0]
            else:
                out_dim = dims[l + 1]
                if self.xyz_in_all and l != self.num_layers - 2:
                    out_dim -= 3
            lin = nn.Linear(dims[l], out_dim)

            if dropout is not None and l in dropout:
                p = 1 - dropout_prob
            else:
                p = 1.0

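            # Geometric initialization in the style of SAL: the last layer gets
            # weights near 2*sqrt(pi)/sqrt(p*dims[l]) and bias -1.0, so the raw
            # network starts close to the signed distance field of a sphere.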
            if l == self.num_layers - 2:
                torch.nn.init.normal_(lin.weight,
                                      mean=2 * np.sqrt(np.pi) /
                                      np.sqrt(p * dims[l]),
                                      std=0.000001)
                torch.nn.init.constant_(lin.bias, -1.0)
            else:
                torch.nn.init.constant_(lin.bias, 0.0)
                torch.nn.init.normal_(lin.weight, 0.0,
                                      np.sqrt(2) / np.sqrt(p * out_dim))

            if weight_norm and l in self.norm_layers:
                lin = nn.utils.weight_norm(lin)

            setattr(self, "lin" + str(l), lin)
        self.use_activation = activation is not None and activation != 'None'

        if self.use_activation:
            self.last_activation = utils.get_class(activation)()
        self.relu = nn.ReLU()

        self.dropout_prob = dropout_prob
        self.dropout = dropout
Example #11
def optimize_latent(latent, ds, itemindex, decoder, path, epoch, resolution,
                    conf):
    latent.detach_()
    latent.requires_grad_()
    lr = 1.0e-3
    optimizer = torch.optim.Adam([latent], lr=lr)
    loss_func = utils.get_class(conf.get_string('network.loss.loss_type'))(
        **conf.get_config('network.loss.properties'))

    num_iterations = 800

    decreased_by = 10
    adjust_lr_every = int(num_iterations / 2)
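    # Test-time optimization: only the latent code receives gradients here; the
    # decoder weights are left untouched.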
    for e in range(num_iterations):
        input_pc, sample_nonmnfld, _ = ds[itemindex]
        input_pc = utils.get_cuda_ifavailable(input_pc).unsqueeze(0)
        sample_nonmnfld = utils.get_cuda_ifavailable(
            sample_nonmnfld).unsqueeze(0)

        non_mnfld_pnts = sample_nonmnfld[:, :, :3]
        dist_nonmnfld = sample_nonmnfld[:, :, 3].reshape(-1)

        adjust_learning_rate(lr, optimizer, e, decreased_by, adjust_lr_every)

        optimizer.zero_grad()
        non_mnfld_pnts_with_latent = torch.cat(
            [latent.unsqueeze(1).repeat(1, non_mnfld_pnts.shape[1], 1),
             non_mnfld_pnts],
            dim=-1)
        nonmanifold_pnts_pred = decoder(
            non_mnfld_pnts_with_latent.view(
                -1, non_mnfld_pnts_with_latent.shape[-1]))

        loss_res = loss_func(manifold_pnts_pred=None,
                             nonmanifold_pnts_pred=nonmanifold_pnts_pred,
                             nonmanifold_gt=dist_nonmnfld,
                             weight=None)
        loss = loss_res["loss"]

        loss.backward()
        optimizer.step()
        print("iteration : {0} , loss {1}".format(e, loss.item()))
        print("mean {0} , std {1}".format(latent.mean().item(),
                                          latent.std().item()))

    with torch.no_grad():
        reconstruction = plt.plot_surface(
            with_points=False,
            points=torch.cat([
                latent.unsqueeze(1).repeat(1, input_pc.shape[1], 1), input_pc
            ],
                             dim=-1)[0],
            decoder=decoder,
            latent=latent,
            path=path,
            epoch=epoch,
            in_epoch=ds.npyfiles_mnfld[itemindex].split('/')[-3] + '_' +
            ds.npyfiles_mnfld[itemindex].split('/')[-1].split('.npy')[0] +
            '_after',
            shapefile=ds.npyfiles_mnfld[itemindex],
            resolution=resolution,
            mc_value=0,
            is_uniform_grid=True,
            verbose=True,
            save_html=False,
            save_ply=True,
            overwrite=True)
        return reconstruction
Example #12
                                        includeNan=False,
                                        excludeID=[],
                                        excludeUUID=[])
        gpu = deviceIDs[0]
    else:
        gpu = args.gpu

    os.environ["CUDA_VISIBLE_DEVICES"] = '{0}'.format(gpu)
    base_dir = os.path.join('../', args.exps_dir, args.exp_name, timestamp)
    saved_model_state = torch.load(
        os.path.join(base_dir, 'checkpoints', 'ModelParameters',
                     args.checkpoint + ".pth"))
    saved_model_epoch = saved_model_state["epoch"]
    conf = ConfigFactory.parse_file(args.conf)
    network = utils.get_class(conf.get_string('train.network_class'))(
        conf=conf.get_config('network'),
        latent_size=conf.get_int('train.latent_size'))

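    # Checkpoints saved from an nn.DataParallel model prefix every state-dict
    # key with 'module.'; strip that prefix before loading a single-GPU model.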
    if args.parallel:
        network.load_state_dict({
            '.'.join(k.split('.')[1:]): v
            for k, v in saved_model_state["model_state_dict"].items()
        })
    else:
        network.load_state_dict(saved_model_state["model_state_dict"])

    evaluate(network=network.cuda(),
             exps_dir=args.exps_dir,
             experiment_name=args.exp_name,
             timestamp=timestamp,
             split_filename=args.split,
Example #13
    def get_sampler(sampler_type):

        return utils.get_class("model.sample.{0}".format(sampler_type))
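
A minimal usage sketch (hedged: the sampler name NormalPerPoint and the get_points call are assumptions about the model.sample interface; only the two-step construction is confirmed by Examples 4 and 8):

    # hypothetical sampler name and method, shown for illustration only
    sampler = Sampler.get_sampler('NormalPerPoint')(global_sigma, local_sigma)
    off_surface_points = sampler.get_points(surface_points)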
Example #14
    def __init__(self, **kwargs):

        if isinstance(kwargs['conf'], str):
            self.conf = ConfigFactory.parse_file(kwargs['conf'])
            self.conf_filename = kwargs['conf']
        else:
            self.conf = kwargs['conf']
        self.batch_size = kwargs['batch_size']
        self.nepochs = kwargs['nepochs']
        self.expnameraw = self.conf.get_string('train.expname')
        self.expname = self.conf.get_string('train.expname') +  kwargs['expname']

        if kwargs['is_continue'] and kwargs['timestamp'] == 'latest':
            if os.path.exists(os.path.join('../',kwargs['exps_folder_name'],self.expname)):
                timestamps = os.listdir(os.path.join('../',kwargs['exps_folder_name'],self.expname))
                if len(timestamps) == 0:
                    is_continue = False
                    timestamp = None
                else:
                    timestamp = sorted(timestamps)[-1]
                    is_continue = True
            else:
                is_continue = False
                timestamp = None
        else:
            timestamp = kwargs['timestamp']
            is_continue = kwargs['is_continue']

        self.adjust_lr = self.conf.get_bool('train.adjust_lr')
        self.GPU_INDEX = kwargs['gpu_index']
        self.exps_folder_name = kwargs['exps_folder_name']

        utils.mkdir_ifnotexists(os.path.join('../',self.exps_folder_name))

        self.expdir = os.path.join('../', self.exps_folder_name, self.expname)
        utils.mkdir_ifnotexists(self.expdir)
        self.timestamp = '{:%Y_%m_%d_%H_%M_%S}'.format(datetime.now())
        utils.mkdir_ifnotexists(os.path.join(self.expdir, self.timestamp))
        log_dir = os.path.join(self.expdir, self.timestamp, 'log')
        self.log_dir = log_dir
        utils.mkdir_ifnotexists(log_dir)
        utils.configure_logging(kwargs['debug'],kwargs['quiet'],os.path.join(self.log_dir,'log.txt'))

        self.plots_dir = os.path.join(self.expdir, self.timestamp, 'plots')
        utils.mkdir_ifnotexists(self.plots_dir)

        self.checkpoints_path = os.path.join(self.expdir, self.timestamp, 'checkpoints')
        utils.mkdir_ifnotexists(self.checkpoints_path)

        self.model_params_subdir = "ModelParameters"
        self.optimizer_params_subdir = "OptimizerParameters"

        utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path,self.model_params_subdir))
        utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path, self.optimizer_params_subdir))

        if self.GPU_INDEX != 'all':
            os.environ["CUDA_VISIBLE_DEVICES"] = '{0}'.format(self.GPU_INDEX)

        # Backup code
        self.code_path = os.path.join(self.expdir, self.timestamp, 'code')
        utils.mkdir_ifnotexists(self.code_path)
        for folder in ['training','preprocess','utils','model','datasets','confs']:
            utils.mkdir_ifnotexists(os.path.join(self.code_path, folder))
            os.system("""cp -r ./{0}/* "{1}" """.format(folder,os.path.join(self.code_path, folder)))

        os.system("""cp -r {0} "{1}" """.format(kwargs['conf'], os.path.join(self.code_path, 'confs/runconf.conf')))

        logging.info('shell command : {0}'.format(' '.join(sys.argv)))

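        # With data_split == 'none' the entire dataset directory is used;
        # otherwise only the shapes listed in the JSON split file are loaded.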
        if (self.conf.get_string('train.data_split') == 'none'):
            self.ds = utils.get_class(self.conf.get_string('train.dataset'))(split=None, dataset_path=self.conf.get_string('train.dataset_path'), dist_file_name=None)
        else:
            train_split_file = './confs/splits/{0}'.format(self.conf.get_string('train.data_split'))

            with open(train_split_file, "r") as f:
                train_split = json.load(f)

            self.ds = utils.get_class(self.conf.get_string('train.dataset'))(split=train_split,
                                                                             dataset_path=self.conf.get_string('train.dataset_path'),
                                                                             dist_file_name=self.conf.get_string('train.dist_file_name'))

        self.dataloader = torch.utils.data.DataLoader(self.ds,
                                                      batch_size=self.batch_size,
                                                      shuffle=True,
                                                      num_workers=kwargs['workers'],drop_last=True,pin_memory=True)
        self.eval_dataloader = torch.utils.data.DataLoader(self.ds,
                                                           batch_size=1,
                                                           shuffle=True,
                                                           num_workers=0, drop_last=True)

        self.latent_size = self.conf.get_int('train.latent_size')

        self.network = utils.get_class(self.conf.get_string('train.network_class'))(conf=self.conf.get_config('network'),
                                                                                    latent_size=self.latent_size)
        if kwargs['parallel']:
            self.network = torch.nn.DataParallel(self.network)

        if torch.cuda.is_available():
            self.network.cuda()

        self.parallel = kwargs['parallel']
        self.loss = utils.get_class(self.conf.get_string('network.loss.loss_type'))(**self.conf.get_config('network.loss.properties'))
        self.lr_schedules = BaseTrainRunner.get_learning_rate_schedules(self.conf.get_list('train.learning_rate_schedule'))

        self.optimizer = torch.optim.Adam(
        [
            {
                "params": self.network.parameters(),
                "lr": self.lr_schedules[0].get_learning_rate(0),
            }
        ])

        self.start_epoch = 0
        if is_continue:
            old_checkpnts_dir = os.path.join(self.expdir, timestamp, 'checkpoints')

            saved_model_state = torch.load(os.path.join(old_checkpnts_dir, 'ModelParameters', str(kwargs['checkpoint']) + ".pth"))
            self.network.load_state_dict(saved_model_state["model_state_dict"])

            data = torch.load(os.path.join(old_checkpnts_dir, 'OptimizerParameters', str(kwargs['checkpoint']) + ".pth"))
            self.optimizer.load_state_dict(data["optimizer_state_dict"])
            self.start_epoch = saved_model_state['epoch']