Esempio n. 1
0
class PSGNPreprocessor:
    ''' Point Set Generation Networks (PSGN) preprocessor class.

    Builds a PSGN model from a config file, restores its weights from a
    checkpoint and, when called, predicts a point cloud for a batch of
    inputs which is then randomly resampled to ``pointcloud_n`` points.

    Args:
        cfg_path (str): path to config file
        pointcloud_n (int): number of output points
        dataset (dataset): dataset
        device (device): pytorch device
        model_file (str): model file; defaults to ``cfg['test']['model_file']``
        debug_export_path (str, optional): if given, the first predicted
            point cloud of every call is exported to this path with
            ``export_pointcloud`` (debugging aid, disabled by default)
    '''
    def __init__(self,
                 cfg_path,
                 pointcloud_n,
                 dataset=None,
                 device=None,
                 model_file=None,
                 debug_export_path=None):
        self.cfg = config.load_config(cfg_path, 'configs/default.yaml')
        self.pointcloud_n = pointcloud_n
        self.device = device
        self.dataset = dataset
        self.debug_export_path = debug_export_path
        self.model = config.get_model(self.cfg, device, dataset)

        # Output directory of psgn model
        out_dir = self.cfg['training']['out_dir']
        # If model_file not specified, use the one from psgn model
        if model_file is None:
            model_file = self.cfg['test']['model_file']
        # Load model
        self.checkpoint_io = CheckpointIO(out_dir, model=self.model)
        self.checkpoint_io.load(model_file)

    def __call__(self, inputs):
        '''Predict points for ``inputs`` and resample them to ``pointcloud_n``.

        Args:
            inputs: model input batch, passed unchanged to ``self.model``

        Returns:
            Tensor indexed as (batch, point, coord) with ``pointcloud_n``
            points per batch element. When the model already predicts that
            many points the prediction is returned unchanged; otherwise
            points are drawn uniformly at random with replacement.
        '''
        self.model.eval()
        with torch.no_grad():
            points = self.model(inputs)

        # Optional debug dump; getattr keeps instances that were created
        # without __init__ (or pickled before this attribute existed) working.
        export_path = getattr(self, 'debug_export_path', None)
        if export_path is not None:
            export_pointcloud(points[0].cpu().numpy(), export_path)

        batch_size = points.size(0)
        T = points.size(1)

        # Subsample points if necessary
        if T != self.pointcloud_n:
            # Create the indices on the same device as the prediction so
            # torch.gather works even when self.device was left as None.
            idx = torch.randint(low=0,
                                high=T,
                                size=(batch_size, self.pointcloud_n),
                                device=points.device)
            # torch.gather needs an index per output element, so repeat the
            # point index across the coordinate dimension.
            idx = idx[:, :, None].expand(batch_size, self.pointcloud_n,
                                         points.size(2))

            points = torch.gather(points, dim=1, index=idx)

        return points
Esempio n. 2
0
class PSGNPreprocessor:
    ''' Point Set Generation Networks (PSGN) preprocessor class.

    TensorFlow variant: builds a PSGN model from a config file, restores its
    weights from a checkpoint and, when called, predicts a point cloud for a
    batch of inputs, randomly resampled to ``pointcloud_n`` points.

    Args:
        cfg_path (str): path to config file
        pointcloud_n (int): number of output points
        dataset (dataset): dataset
        model_file (str): model file; defaults to ``cfg['test']['model_file']``
    '''
    def __init__(self, cfg_path, pointcloud_n, dataset=None, model_file=None):
        self.cfg = config.load_config(cfg_path, 'configs/default.yaml')
        self.pointcloud_n = pointcloud_n
        self.dataset = dataset
        self.model = config.get_model(self.cfg, dataset)

        # Output directory of psgn model
        out_dir = self.cfg['training']['out_dir']
        # If model_file not specified, use the one from psgn model
        if model_file is None:
            model_file = self.cfg['test']['model_file']
        # Load model: restore the weights into the model built above.
        self.checkpoint_io = CheckpointIO(model=self.model,
                                          checkpoint_dir=out_dir)
        self.checkpoint_io.load(model_file)

    def __call__(self, inputs):
        '''Predict points for ``inputs`` and resample them to ``pointcloud_n``.

        Args:
            inputs: model input batch, passed to ``self.model`` with
                ``training=False``

        Returns:
            Tensor indexed as (batch, point, coord) with ``pointcloud_n``
            points per batch element; unchanged when the model already
            predicts that many points, otherwise drawn uniformly at random
            with replacement.
        '''
        points = self.model(inputs, training=False)

        batch_size = points.shape[0]
        t = points.shape[1]

        # Subsample points if necessary
        if t != self.pointcloud_n:
            idx = np.random.randint(low=0,
                                    high=t,
                                    size=(batch_size, self.pointcloud_n))
            idx = tf.convert_to_tensor(idx)
            # With batch_dims=1, tf.gather picks whole points along axis 1,
            # so the indices must be (batch, n). Broadcasting them to
            # (batch, n, 3) as torch.gather requires would instead produce a
            # (batch, n, 3, 3) result.
            points = tf.gather(points, indices=idx, batch_dims=1)

        return points
                                         collate_fn=data.collate_remove_none,
                                         worker_init_fn=data.worker_init_fn)
# First batch of the visualization loader (vis_loader is built above this
# view).
data_vis = next(iter(vis_loader))

# Model
model = config.get_model(cfg, device=device, dataset=train_dataset)

# Initialize training
# NOTE(review): npoints is not used in the visible code.
npoints = 1000
optimizer = optim.Adam(model.parameters(), lr=1e-4)
# optimizer = optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
trainer = config.get_trainer(model, optimizer, cfg, device=device)

# Resume model and optimizer from out_dir/model.pt. CheckpointIO.load
# presumably raises FileExistsError when no checkpoint exists (verify in
# CheckpointIO); in that case training starts with an empty load_dict.
checkpoint_io = CheckpointIO(out_dir, model=model, optimizer=optimizer)
try:
    load_dict = checkpoint_io.load('model.pt')
except FileExistsError:
    load_dict = dict()
# Counters default to -1 without a checkpoint. The best validation metric
# defaults to -model_selection_sign * inf, i.e. the worst value for the
# selection direction (assuming model_selection_sign is +1 or -1).
epoch_it = load_dict.get('epoch_it', -1)
it = load_dict.get('it', -1)
metric_val_best = load_dict.get('loss_val_best',
                                -model_selection_sign * np.inf)

# Hack because of previous bug in code: any stored infinite best value is
# reset to the default for the current selection direction.
# TODO: remove, because shouldn't be necessary
if metric_val_best == np.inf or metric_val_best == -np.inf:
    metric_val_best = -model_selection_sign * np.inf

# TODO: remove this switch
# metric_val_best = -model_selection_sign * np.inf
# The statements below set up generation: they re-create `model` and
# `checkpoint_io` for a test-mode dataset and load cfg['test']['model_file'],
# replacing the training objects created above.
input_type = cfg['data']['input_type']
vis_n_outputs = cfg['generation']['vis_n_outputs']
mesh_extension = cfg['generation']['mesh_extension']

# Dataset
# This is for DTU when we parallelise over images
# we do not want to treat different images from same object as
# different objects
cfg['data']['split_model_for_images'] = False
dataset = config.get_dataset(cfg, mode='test', return_idx=True)

# Model
model = config.get_model(cfg, device=device, len_dataset=len(dataset))

checkpoint_io = CheckpointIO(out_dir, model=model)
checkpoint_io.load(cfg['test']['model_file'])

# Generator
generator = config.get_generator(model, cfg, device=device)

# Seed torch's global RNG before building the shuffled loader so the test
# order is reproducible.
torch.manual_seed(0)
# Loader
test_loader = torch.utils.data.DataLoader(dataset,
                                          batch_size=1,
                                          num_workers=0,
                                          shuffle=True)

# Statistics (filled by code outside this view)
time_dicts = []
vis_file_dict = {}
Esempio n. 5
0
vis_loader = vis_dataset.loader()

# First visualization batch.
data_vis = next(iter(vis_loader))

# Model
model = config.get_model(cfg, dataset=train_dataset)

# Initialize training
# NOTE(review): npoints is not used in the visible code.
npoints = 1000
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4, epsilon=1e-08)
# optimizer = tf.keras.optimizers.SGD(learning_rate=1e-4, momentum=0.9)

# TensorFlow variant: the checkpoint object carries the training counters
# itself, so epoch_it / it / metric_val_best are read from checkpoint_io.ckpt
# below whether or not a checkpoint was found. Their from-scratch defaults
# are set inside CheckpointIO (not visible here).
checkpoint_io = CheckpointIO(model, optimizer, model_selection_sign, out_dir)

try:
    checkpoint_io.load('model')
except FileExistsError:
    print("start from scratch")

epoch_it = checkpoint_io.ckpt.epoch_it
it = checkpoint_io.ckpt.it
metric_val_best = checkpoint_io.ckpt.metric_val_best

trainer = config.get_trainer(model, optimizer, cfg)

# Hack because of previous bug in code: any stored infinite best value is
# reset to the default for the current selection direction.
if metric_val_best == np.inf or metric_val_best == -np.inf:
    metric_val_best = -model_selection_sign * np.inf

print('Current best validation metric (%s): %.8f' %
      (model_selection_metric, metric_val_best))
                             device=device,
                             len_dataset=len(train_dataset))

    # Initialize training
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # The generator is handed to the trainer (presumably for visualization or
    # evaluation during training — confirm in config.get_trainer).
    generator = config.get_generator(model, cfg, device=device)

    trainer = config.get_trainer(model,
                                 optimizer,
                                 cfg,
                                 device=device,
                                 generator=generator)
    # Resume model and optimizer from out_dir/model.pt, mapped to `device`;
    # an empty load_dict (fresh start) is used when CheckpointIO.load raises
    # FileExistsError (presumably its "no checkpoint" signal).
    checkpoint_io = CheckpointIO(out_dir, model=model, optimizer=optimizer)
    try:
        load_dict = checkpoint_io.load('model.pt', device=device)
    except FileExistsError:
        load_dict = dict()

    epoch_it = load_dict.get('epoch_it', -1)
    it = load_dict.get('it', -1)
    metric_val_best = load_dict.get('loss_val_best',
                                    -model_selection_sign * np.inf)

    # Reset any stored infinite best value to the default for the current
    # selection direction (works around a previous bug in saved checkpoints).
    if metric_val_best == np.inf or metric_val_best == -np.inf:
        metric_val_best = -model_selection_sign * np.inf

    print('Current best validation metric (%s): %.8f' %
          (model_selection_metric, metric_val_best))
    scheduler = optim.lr_scheduler.MultiStepLR(
    # Generation settings read from the config.
    input_type = cfg['data']['input_type']
    vis_n_outputs = cfg['generation']['vis_n_outputs']
    mesh_extension = cfg['generation']['mesh_extension']

    # Dataset
    # This is for DTU when we parallelise over images
    # we do not want to treat different images from same object as
    # different objects
    cfg['data']['split_model_for_images'] = False
    dataset = config.get_dataset(cfg, mode='test', return_idx=True)

    # Model
    model = config.get_model(cfg, device=device, len_dataset=len(dataset))

    # Load the test checkpoint named in the config, mapped to `device`.
    checkpoint_io = CheckpointIO(out_dir, model=model)
    checkpoint_io.load(cfg['test']['model_file'], device=device)

    # Generator
    generator = config.get_generator(model, cfg, device=device)

    # Seed torch's global RNG before building the shuffled loader so the test
    # order is reproducible.
    torch.manual_seed(0)
    # Loader
    test_loader = torch.utils.data.DataLoader(
        dataset, batch_size=1, num_workers=0, shuffle=True)

    # Statistics (filled by code outside this view)
    time_dicts = []
    vis_file_dict = {}

    # Generate
    model.eval()
Esempio n. 8
0
# A missing (None) vis_n_outputs is replaced by -1; presumably this means
# "no limit" — confirm with the consumer below this view.
vis_n_outputs = cfg['generation']['vis_n_outputs']
if vis_n_outputs is None:
    vis_n_outputs = -1

# Dataset (test split; args.da is forwarded as use_target_domain)
dataset = config.get_dataset('test',
                             cfg,
                             return_idx=True,
                             use_target_domain=args.da)

# Model
model = config.get_model(cfg, device=device, dataset=dataset)

checkpoint_io = CheckpointIO(out_dir, model=model)
# load with 'cuda:0' because we set visible devices earlier
checkpoint_io.load(cfg['test']['model_file'], 'cuda:0')

# Generator
generator = config.get_generator(model, cfg, device=device)

# Determine what to generate. A requested output is switched off (with a
# warning) when the generator does not implement the corresponding method.
generate_mesh = cfg['generation']['generate_mesh']
generate_pointcloud = cfg['generation']['generate_pointcloud']

if generate_mesh and not hasattr(generator, 'generate_mesh'):
    generate_mesh = False
    print('Warning: generator does not support mesh generation.')

if generate_pointcloud and not hasattr(generator, 'generate_pointcloud'):
    generate_pointcloud = False
    print('Warning: generator does not support pointcloud generation.')
Esempio n. 9
0
    return fields

# Dataset over all images of the shapes in dataset_folder (split=None is
# passed, presumably meaning no split filtering — confirm in
# Shapes3dDataset_AllImgs).
fields = get_fields()
train_dataset = data.Shapes3dDataset_AllImgs(dataset_folder, fields, split=None)

# Unshuffled loader so batches follow dataset order.
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=args.batch_size, num_workers=4, shuffle=False,
    collate_fn=data.collate_remove_none,
    worker_init_fn=data.worker_init_fn)

# Model
model = config.get_model(cfg, device=device, dataset=train_dataset)

# Load the best checkpoint with strict key matching; an empty load_dict is
# used when CheckpointIO.load raises FileExistsError (presumably its
# "no checkpoint" signal).
checkpoint_io = CheckpointIO(cfg['training']['out_dir'], model=model)
try:
    load_dict = checkpoint_io.load('model_best.pt', strict=True)
except FileExistsError:
    load_dict = dict()

# Batch counter for the loop below.
it = 0
for batch in tqdm(train_loader):
    it += 1
    model.eval()

    encoder_inputs, _ = compose_inputs(batch, mode='train', device=device, input_type='depth_pointcloud',
                                                depth_pointcloud_transfer=depth_pointcloud_transfer)
    cur_batch_size = encoder_inputs.size(0)
    idxs = batch.get('idx')
    viewids = batch.get('viewid')

    with torch.no_grad():
Esempio n. 10
0
def run(pointcloud_path, out_dir, decoder_type='siren', resume=True, **kwargs):
    """Fit an implicit SDF decoder to an oriented point cloud.

    Originally labelled ``test_implicit_siren_noisy_wNormals``.

    The point cloud is read with ``read_ply``: the first three columns are
    points and the remaining columns normals. It is centred and scaled so
    that its bounding box's longest side is 1.5. A SIREN or SDF MLP is then
    trained for ``total_iter`` iterations with surface SDF, surface normal,
    eikonal and off-surface losses. After ``kwargs['warm_up']`` iterations
    these are complemented by iso-point losses and iso-point based
    re-weighting. Checkpoints, iso-points, debug cuts/meshes and the final
    mesh (``final.ply``, mapped back into the input coordinate frame) are
    written to ``out_dir``.

    Args:
        pointcloud_path (str): ply file with xyz followed by normal columns
        out_dir (str): output directory (created if missing)
        decoder_type (str): 'siren' or 'sdf'
        resume (bool): resume from the lexicographically last ``*.pt`` file
            in ``out_dir``
        **kwargs: hyper-parameters; the keys read are use_off_normal_loss,
            warm_up, lambda_surface_sdf, lambda_surface_normal,
            resample_every, ear, outlier_tolerance, denoise_normal,
            weight_mode, lambda_iso_sdf, lambda_iso_normal, use_sal_loss,
            lambda_inter_sal, lambda_inter_sdf and lambda_eikonal

    Raises:
        ValueError: if ``decoder_type`` is unknown
    """
    device = torch.device('cuda:0')

    os.makedirs(out_dir, exist_ok=True)

    # data: the first three columns are xyz, the remaining ones the normals
    points, normals = np.split(read_ply(pointcloud_path).astype('float32'),
                               (3, ),
                               axis=1)

    # Centre the bounding box at the origin and scale its longest side to
    # 1.5. scale_mat is the same mapping as a homogeneous 4x4 matrix;
    # scale_mat_inv maps normalised coordinates back to the input frame.
    pmax, pmin = points.max(axis=0), points.min(axis=0)
    scale = (pmax - pmin).max()
    pcenter = (pmax + pmin) / 2
    points = (points - pcenter) / scale * 1.5
    scale_mat = np.identity(4)
    scale_mat[[0, 1, 2], [0, 1, 2]] = 1 / scale * 1.5
    scale_mat[[0, 1, 2], [3, 3, 3]] = -pcenter / scale * 1.5
    scale_mat_inv = np.linalg.inv(scale_mat)
    normals = normals @ np.linalg.inv(scale_mat[:3, :3].T)
    object_bounding_sphere = np.linalg.norm(points, axis=1).max()
    pcl = trimesh.Trimesh(vertices=points,
                          vertex_normals=normals,
                          process=False)
    pcl.export(os.path.join(out_dir, "input_pcl.ply"), vertex_normal=True)

    assert (np.abs(points).max() < 1)

    dataset = torch.utils.data.TensorDataset(torch.from_numpy(points),
                                             torch.from_numpy(normals))
    batch_size = 5000
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=1,
        shuffle=True,
        collate_fn=tolerating_collate,
    )
    # Whole point set with a leading batch dimension of 1; normals are
    # normalised to unit length here.
    gt_surface_pts_all = torch.from_numpy(points).unsqueeze(0).float()
    gt_surface_normals_all = torch.from_numpy(normals).unsqueeze(0).float()
    gt_surface_normals_all = F.normalize(gt_surface_normals_all, dim=-1)

    if kwargs['use_off_normal_loss']:
        # subsample from pointset (at most 20000 points) and denoise their
        # normals; used by the off-surface normal loss in the training loop
        sub_idx = torch.randperm(gt_surface_normals_all.shape[1])[:20000]
        gt_surface_pts_sub = torch.index_select(gt_surface_pts_all, 1,
                                                sub_idx).to(device=device)
        gt_surface_normals_sub = torch.index_select(gt_surface_normals_all, 1,
                                                    sub_idx).to(device=device)
        gt_surface_normals_sub = denoise_normals(gt_surface_pts_sub,
                                                 gt_surface_normals_sub,
                                                 neighborhood_size=30)

    if decoder_type == 'siren':
        decoder_params = {
            'dim': 3,
            "out_dims": {
                'sdf': 1
            },
            "c_dim": 0,
            "hidden_size": 256,
            'n_layers': 3,
            "first_omega_0": 30,
            "hidden_omega_0": 30,
            "outermost_linear": True,
        }
        decoder = Siren(**decoder_params)
        # pretrained_model_file = os.path.join('data', 'trained_model', 'siren_l{}_c{}_o{}.pt'.format(
        #                     decoder_params['n_layers'], decoder_params['hidden_size'], decoder_params['first_omega_0']))
        # loaded_state_dict = torch.load(pretrained_model_file)
        # decoder.load_state_dict(loaded_state_dict)
    elif decoder_type == 'sdf':
        decoder_params = {
            'dim': 3,
            "out_dims": {
                'sdf': 1
            },
            "c_dim": 0,
            "hidden_size": 512,
            'n_layers': 8,
            'bias': 1.0,
        }
        decoder = SDF(**decoder_params)
    else:
        raise ValueError('Unknown decoder_type: %r' % (decoder_type, ))
    print(decoder)
    decoder = decoder.to(device)

    # training
    total_iter = 30000
    optimizer = torch.optim.Adam(decoder.parameters(), lr=1e-4)
    # Milestones are in iterations: scheduler.step() runs once per batch.
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [10000, 20000],
                                                     gamma=0.5)

    shape = Shape(gt_surface_pts_all.cuda(),
                  n_points=gt_surface_pts_all.shape[1] // 16,
                  normals=gt_surface_normals_all.cuda())
    checkpoint_io = CheckpointIO(out_dir, model=decoder, optimizer=optimizer)
    load_dict = dict()
    if resume:
        # Resume from the lexicographically last checkpoint in out_dir.
        # NOTE(review): the scheduler is not restored, so its milestone
        # counter restarts from 0 after resuming — confirm this is intended.
        models_avail = [f for f in os.listdir(out_dir) if f.endswith('.pt')]
        if models_avail:
            models_avail.sort()
            load_dict = checkpoint_io.load(models_avail[-1])

    it = load_dict.get('it', 0)
    if it > 0:
        # Best effort: reload the newest '<iter>_iso.ply' saved at or before
        # `it` as the current iso-points.
        try:
            iso_point_files = [
                f for f in os.listdir(out_dir) if f[-7:] == 'iso.ply'
            ]
            iso_point_iters = [
                int(os.path.basename(f[:-len('_iso.ply')]))
                for f in iso_point_files
            ]
            iso_point_iters = np.array(iso_point_iters)
            idx = np.argmax(iso_point_iters[(iso_point_iters - it) <= 0])
            iso_point_file = np.array(iso_point_files)[(iso_point_iters -
                                                        it) <= 0][idx]
            iso_points = torch.from_numpy(
                read_ply(os.path.join(out_dir, iso_point_file))[..., :3])
            shape.points = iso_points.to(device=shape.points.device).view(
                1, -1, 3)
            print('Loaded iso-points from %s' % iso_point_file)
        except Exception as e:
            # Keep going with the initial shape points, but report why the
            # restore failed instead of hiding it.
            print('Could not load iso-points from %s (%s: %s)' %
                  (out_dir, type(e).__name__, e))

    # loss
    eikonal_loss = NormalLengthLoss(reduction='mean')

    # start training
    save_ply(os.path.join(out_dir, 'in_iso_points.ply'),
             shape.points.cpu().view(-1, 3))
    # autograd.set_detect_anomaly(True)
    iso_points = shape.points
    iso_points_normal = None
    while True:
        if it > total_iter:
            # Save the final checkpoint and a high-resolution mesh in the
            # input coordinate frame, then stop.
            checkpoint_io.save('model_{:04d}.pt'.format(it), it=it)
            mesh = get_surface_high_res_mesh(
                lambda x: decoder(x).sdf.squeeze(), resolution=512)
            mesh.apply_transform(scale_mat_inv)
            mesh.export(os.path.join(out_dir, "final.ply"))
            break

        for batch in data_loader:
            # Leave the epoch as soon as the iteration budget is spent, so
            # the check above finalises right after total_iter instead of
            # overshooting by up to one epoch.
            if it > total_iter:
                break

            gt_surface_pts, gt_surface_normals = batch
            gt_surface_pts.unsqueeze_(0)
            gt_surface_normals.unsqueeze_(0)
            gt_surface_pts = gt_surface_pts.to(device=device).detach()
            gt_surface_normals = gt_surface_normals.to(device=device).detach()

            optimizer.zero_grad()
            decoder.train()
            loss = defaultdict(float)

            # Surface weights: fixed during warm-up, from kwargs afterwards.
            lambda_surface_sdf = 1e3
            lambda_surface_normal = 1e2
            if kwargs['warm_up'] >= 0 and it >= kwargs['warm_up']:
                lambda_surface_sdf = kwargs['lambda_surface_sdf']
                lambda_surface_normal = kwargs['lambda_surface_normal']

            # debug: every 1000 iterations (offset by warm_up) dump SDF cuts
            # and a mesh in the input coordinate frame
            if (it - kwargs['warm_up']) % 1000 == 0:
                # generate iso surface
                with torch.autograd.no_grad():
                    box_size = (object_bounding_sphere * 2 + 0.2, ) * 3
                    plot_cuts(
                        lambda x: decoder(x).sdf.squeeze().detach(),
                        box_size=box_size,
                        max_n_eval_pts=10000,
                        thres=0.0,
                        imgs_per_cut=1,
                        save_path=os.path.join(out_dir, '%010d_iso.html' % it))
                    mesh = get_surface_high_res_mesh(
                        lambda x: decoder(x).sdf.squeeze(), resolution=200)
                    mesh.apply_transform(scale_mat_inv)
                    mesh.export(os.path.join(out_dir, '%010d_mesh.ply' % it))

            if it % 2000 == 0:
                checkpoint_io.save('model.pt', it=it)

            pred_surface_grad = gradient(gt_surface_pts.clone(),
                                         lambda x: decoder(x).sdf)

            # every once in a while update shape and points
            # sample points in space and on the shape
            # use iso points to weigh data points loss
            weights = 1.0
            if kwargs['warm_up'] >= 0 and it >= kwargs['warm_up']:
                if it == kwargs['warm_up'] or kwargs['resample_every'] > 0 and (
                        it -
                        kwargs['warm_up']) % kwargs['resample_every'] == 0:
                    # Re-extract iso-points from jittered current ones and
                    # estimate their normals.
                    iso_points = shape.get_iso_points(
                        iso_points + 0.1 * (torch.rand_like(iso_points) - 0.5),
                        decoder,
                        ear=kwargs['ear'],
                        outlier_tolerance=kwargs['outlier_tolerance'])
                    iso_points_normal = estimate_pointcloud_normals(
                        iso_points.view(1, -1, 3), 8, False)
                    if kwargs['denoise_normal']:
                        iso_points_normal = denoise_normals(iso_points,
                                                            iso_points_normal,
                                                            num_points=None)
                        iso_points_normal = iso_points_normal.view_as(
                            iso_points)
                elif iso_points_normal is None:
                    iso_points_normal = estimate_pointcloud_normals(
                        iso_points.view(1, -1, 3), 8, False)

                # TODO: use gradient from network or neighborhood?
                iso_points_g = gradient(iso_points.clone(),
                                        lambda x: decoder(x).sdf)
                if it == kwargs['warm_up'] or kwargs['resample_every'] > 0 and (
                        it -
                        kwargs['warm_up']) % kwargs['resample_every'] == 0:
                    save_ply(os.path.join(out_dir, '%010d_iso.ply' % it),
                             iso_points.cpu().detach().view(-1, 3),
                             normals=iso_points_g.view(-1, 3).detach().cpu())

                # Per-point weights for the surface losses (weight_mode -1 or
                # any other value keeps the scalar 1.0).
                if kwargs['weight_mode'] == 1:
                    weights = get_iso_bilateral_weights(
                        gt_surface_pts, gt_surface_normals, iso_points,
                        iso_points_g).detach()
                elif kwargs['weight_mode'] == 2:
                    weights = get_laplacian_weights(gt_surface_pts,
                                                    gt_surface_normals,
                                                    iso_points,
                                                    iso_points_g).detach()
                elif kwargs['weight_mode'] == 3:
                    weights = get_heat_kernel_weights(gt_surface_pts,
                                                      gt_surface_normals,
                                                      iso_points,
                                                      iso_points_g).detach()

                if (it - kwargs['warm_up']
                    ) % 1000 == 0 and kwargs['weight_mode'] != -1:
                    print("min {:.4g}, max {:.4g}, std {:.4g}, mean {:.4g}".
                          format(weights.min(), weights.max(), weights.std(),
                                 weights.mean()))
                    colors = scaler_to_color(1 -
                                             weights.view(-1).cpu().numpy(),
                                             cmap='Reds')
                    save_ply(
                        os.path.join(out_dir, '%010d_batch_weight.ply' % it),
                        (to_homogen(gt_surface_pts).cpu().detach().numpy()
                         @ scale_mat_inv.T)[..., :3].reshape(-1, 3),
                        colors=colors)

                # Iso-point SDF loss on a random subset of the iso-points.
                sample_idx = torch.randperm(
                    iso_points.shape[1])[:min(batch_size, iso_points.shape[1])]
                iso_points_sampled = iso_points.detach()[:, sample_idx, :]
                iso_points_sdf = decoder(iso_points_sampled.detach()).sdf
                loss_iso_points_sdf = iso_points_sdf.abs().mean(
                ) * kwargs['lambda_iso_sdf'] * iso_points_sdf.nelement() / (
                    iso_points_sdf.nelement() + 8000)
                loss['loss_sdf_iso'] = loss_iso_points_sdf.detach()
                loss['loss'] += loss_iso_points_sdf

                # TODO: predict iso_normals from local_frame
                iso_normals_sampled = iso_points_normal.detach()[:,
                                                                 sample_idx, :]
                iso_g_sampled = iso_points_g[:, sample_idx, :]
                loss_normals = torch.mean(
                    (1 - F.cosine_similarity(
                        iso_normals_sampled, iso_g_sampled, dim=-1).abs())
                ) * kwargs['lambda_iso_normal'] * iso_points_sdf.nelement() / (
                    iso_points_sdf.nelement() + 8000)
                loss['loss_normal_iso'] = loss_normals.detach()
                loss['loss'] += loss_normals

            # Space samples: half uniform in [-1, 1]^3, half Gaussian
            # perturbations (std 0.1) of a random half of the batch points.
            idx = torch.randperm(gt_surface_pts.shape[1]).to(
                device=gt_surface_pts.device)[:(gt_surface_pts.shape[1] // 2)]
            tmp = torch.index_select(gt_surface_pts, 1, idx)
            space_pts = torch.cat([
                torch.rand_like(tmp) * 2 - 1,
                torch.randn_like(tmp, device=tmp.device, dtype=tmp.dtype) * 0.1
                + tmp
            ],
                                  dim=1)

            space_pts.requires_grad_(True)
            pred_space_sdf = decoder(space_pts).sdf
            pred_space_grad = torch.autograd.grad(
                pred_space_sdf, [space_pts], [torch.ones_like(pred_space_sdf)],
                create_graph=True)[0]

            # 1. eikonal term
            loss_eikonal = (
                eikonal_loss(pred_surface_grad) +
                eikonal_loss(pred_space_grad)) * kwargs['lambda_eikonal']
            loss['loss_eikonal'] = loss_eikonal.detach()
            loss['loss'] += loss_eikonal

            # 2. SDF loss
            # loss on iso points
            pred_surface_sdf = decoder(gt_surface_pts).sdf

            loss_sdf = torch.mean(
                weights * pred_surface_sdf.abs()) * lambda_surface_sdf
            if kwargs['warm_up'] >= 0 and it >= kwargs['warm_up'] and kwargs[
                    'lambda_iso_sdf'] != 0:
                loss_sdf = loss_sdf * pred_surface_sdf.nelement() / (
                    pred_surface_sdf.nelement() + iso_points_sdf.nelement())

            if kwargs['use_sal_loss'] and iso_points is not None:
                dists, idxs, _ = knn_points(space_pts.view(1, -1, 3),
                                            iso_points.view(1, -1, 3).detach(),
                                            K=1)
                dists = dists.view_as(pred_space_sdf)
                idxs = idxs.view_as(pred_space_sdf)
                loss_inter = ((eps_sqrt(dists).sqrt() - pred_space_sdf.abs())**
                              2).mean() * kwargs['lambda_inter_sal']
            else:
                alpha = (it / total_iter + 1) * 100
                loss_inter = torch.exp(
                    -alpha *
                    pred_space_sdf.abs()).mean() * kwargs['lambda_inter_sdf']

            loss_sald = torch.tensor(0.0, device=device)
            # prevent wrong closing for open mesh
            if kwargs['use_off_normal_loss'] and it < 1000:
                dists, idxs, _ = knn_points(space_pts.view(1, -1, 3),
                                            gt_surface_pts_sub.view(1, -1,
                                                                    3).cuda(),
                                            K=1)
                knn_normal = knn_gather(
                    gt_surface_normals_sub.cuda().view(1, -1, 3),
                    idxs).view(1, -1, 3)
                direction_correctness = -F.cosine_similarity(
                    knn_normal, pred_space_grad, dim=-1)
                direction_correctness[direction_correctness < 0] = 0
                loss_sald = torch.mean(
                    direction_correctness * torch.exp(-2 * dists)) * 2

            # 3. normal direction
            loss_normals = torch.mean(weights * (1 - F.cosine_similarity(
                gt_surface_normals, pred_surface_grad, dim=-1))
                                      ) * lambda_surface_normal
            if kwargs['warm_up'] >= 0 and it >= kwargs['warm_up'] and kwargs[
                    'lambda_iso_normal'] != 0:
                loss_normals = loss_normals * gt_surface_normals.nelement() / (
                    gt_surface_normals.nelement() +
                    iso_normals_sampled.nelement())

            # Logged terms are detached so the dict does not keep the graph.
            loss['loss_sdf'] = loss_sdf.detach()
            loss['loss_inter'] = loss_inter.detach()
            loss['loss_normals'] = loss_normals.detach()
            loss['loss_sald'] = loss_sald.detach()
            loss['loss'] += loss_sdf
            loss['loss'] += loss_inter
            loss['loss'] += loss_sald
            loss['loss'] += loss_normals

            loss['loss'].backward()
            torch.nn.utils.clip_grad_norm_(decoder.parameters(), max_norm=1.)

            optimizer.step()
            scheduler.step()
            if it % 20 == 0:
                print("iter {:05d} {}".format(
                    it, ', '.join([
                        '{}: {}'.format(k, v.item()) for k, v in loss.items()
                    ])))

            it += 1
Esempio n. 11
0
                                         worker_init_fn=data.worker_init_fn)
# First batch of the visualization loader (vis_loader is built above this
# view).
data_vis = next(iter(vis_loader))

# Model
model = config.get_model(cfg, device=device, dataset=train_dataset)

# Initialize training
# NOTE(review): npoints is not used in the visible code.
npoints = 1000
optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
# optimizer = optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
trainer = config.get_trainer(model, optimizer, cfg, device=device)

# Load out_dir/model.pt. --load_no_strict relaxes key matching and
# --restart skips restoring the optimizer state. An empty load_dict is used
# when CheckpointIO.load raises FileExistsError (presumably its "no
# checkpoint" signal — verify in CheckpointIO).
checkpoint_io = CheckpointIO(out_dir, model=model, optimizer=optimizer)
try:
    load_dict = checkpoint_io.load('model.pt',
                                   strict=not args.load_no_strict,
                                   load_optimizer=not args.restart)
except FileExistsError:
    load_dict = dict()

# On --restart the loaded weights are kept but the counters and the best
# validation metric start over; otherwise they are resumed from the
# checkpoint with the same defaults as a fresh start.
if args.restart:
    epoch_it = -1
    it = -1
    metric_val_best = -model_selection_sign * np.inf
else:
    epoch_it = load_dict.get('epoch_it', -1)
    it = load_dict.get('it', -1)
    metric_val_best = load_dict.get('loss_val_best',
                                    -model_selection_sign * np.inf)

# Hack because of previous bug in code