示例#1
0
    def create_loss(self, coarse, fine, gt, alpha):
        """Build the coarse/fine reconstruction loss and register its summaries.

        The coarse term is ``earth_mover`` between ``coarse`` and the first
        ``coarse.shape[1]`` entries along dimension 1 of ``gt``. The fine term
        is ``chamfer`` between ``fine`` and the whole of ``gt``. The total is
        ``coarse_term + alpha * fine_term``.

        For each of the three values a train summary is added, and the value
        returned by ``add_valid_summary`` is collected.

        Args:
            coarse: coarse prediction; ``coarse.shape[1]`` sets how much of
                ``gt`` the coarse term compares against.
            fine: fine prediction, compared against all of ``gt``.
            gt: ground truth, sliced as ``gt[:, :n, :]`` for the coarse term.
            alpha: weight applied to the fine term.

        Returns:
            ``(total_loss, [coarse_update, fine_update, total_update])``.
        """
        def summarize(train_tag, valid_tag, value):
            # Train summary first, then the validation one whose result the
            # caller keeps.
            add_train_summary(train_tag, value)
            return add_valid_summary(valid_tag, value)

        coarse_target = gt[:, :coarse.shape[1], :]

        coarse_term = earth_mover(coarse, coarse_target)
        coarse_update = summarize('train/coarse_loss', 'valid/coarse_loss',
                                  coarse_term)

        fine_term = chamfer(fine, gt)
        fine_update = summarize('train/fine_loss', 'valid/fine_loss',
                                fine_term)

        total = coarse_term + alpha * fine_term
        total_update = summarize('train/loss', 'valid/loss', total)

        return total, [coarse_update, fine_update, total_update]
示例#2
0
    def create_loss(self, coarse, fine, gt, alpha):
        """Build the combined geometric + semantic loss and its summaries.

        ``coarse``, ``fine`` and ``gt`` are rank-3 tensors (assumed
        ``[batch, points, channels]``). Channels ``0:3`` are used as xyz for
        the geometric terms. Channels ``3:`` of the predictions are
        softmax-ed as class scores. Channel 3 of ``gt`` carries an encoded
        label (decoded inside ``add_semantic_loss``).

        - coarse loss = ``10 * earth_mover`` on xyz against the first
          ``coarse.shape[1]`` points of ``gt``, plus the semantic term.
        - fine loss = ``10 * chamfer`` on xyz against all of ``gt``, plus the
          semantic term.
        - total = coarse loss + ``alpha`` * fine loss.

        Args:
            coarse: coarse prediction (xyz in ``[..., 0:3]``, scores in
                ``[..., 3:]``).
            fine: fine prediction, same channel layout as ``coarse``.
            gt: ground truth (xyz in ``[..., 0:3]``, label code in channel 3).
            alpha: weight of the fine loss in the total.

        Returns:
            ``(loss, [update_coarse, update_fine, update_loss])``, where the
            updates are the values returned by ``add_valid_summary``.
        """
        def add_semantic_loss(loss, pred, target):
            """Return ``loss`` plus one semantic term per sample of ``pred``.

            The terms are added onto ``loss`` one sample at a time, in order.
            """
            # The second output of nn_distance gives, for every predicted point,
            # an index into target's points (it must have one entry per
            # predicted point for the element-wise ops below to line up).
            _, pred_to_gt, _, _ = tf_nndistance.nn_distance(
                pred[:, :, 0:3], target[:, :, 0:3])
            # Requires a statically known size for dimension 0.
            for i in range(np.shape(target)[0]):
                index = tf.expand_dims(pred_to_gt[i], -1)
                sem_feat = tf.nn.softmax(pred[i, :, 3:], -1)
                # Channel 3 is turned into an integer class id by `* 80 * 12`
                # and truncation, then one-hot encoded over 12 classes.
                # NOTE(review): this encoding is inferred from this line only;
                # confirm against the data pipeline.
                gt_labels = tf.cast(target[i, :, 3] * 80 * 12, tf.int32)
                sem_gt = tf.cast(
                    tf.one_hot(tf.gather_nd(gt_labels, index), 12), tf.float32)
                # Weighted binary cross-entropy over classes: positive labels
                # weighted 0.9, negatives 0.1. The 1e-6 keeps log() finite.
                loss_sem = tf.reduce_mean(-tf.reduce_sum(
                    0.9 * sem_gt * tf.log(1e-6 + sem_feat) + (1 - 0.9) *
                    (1 - sem_gt) * tf.log(1e-6 + 1 - sem_feat), [1]))
                loss += loss_sem
            return loss

        gt_ds = gt[:, :coarse.shape[1], :]
        loss_coarse = 10 * earth_mover(coarse[:, :, 0:3], gt_ds[:, :, 0:3])
        loss_coarse = add_semantic_loss(loss_coarse, coarse, gt_ds)
        add_train_summary('train/coarse_loss', loss_coarse)
        update_coarse = add_valid_summary('valid/coarse_loss', loss_coarse)

        loss_fine = 10 * chamfer(fine[:, :, 0:3], gt[:, :, 0:3])
        loss_fine = add_semantic_loss(loss_fine, fine, gt)
        add_train_summary('train/fine_loss', loss_fine)
        update_fine = add_valid_summary('valid/fine_loss', loss_fine)

        loss = loss_coarse + alpha * loss_fine
        add_train_summary('train/loss', loss)
        update_loss = add_valid_summary('valid/loss', loss)

        return loss, [update_coarse, update_fine, update_loss]
示例#3
0
def test(args):
    """Evaluate a point-cloud completion checkpoint on a list of models.

    Each line of ``args.list_path`` must have the form
    ``<synset_id>/<model_id>``. For each model:

    - The partial cloud ``<data_dir>/partial/<line>.pcd`` is completed by
      the model.
    - The completion is scored against ``<data_dir>/complete/<line>.pcd``
      with ``chamfer`` and ``earth_mover``.
    - Per-model scores go to ``<results_dir>/results.csv``.

    Every ``args.plot_freq``-th model is plotted to
    ``<results_dir>/plots/<synset_id>/<model_id>.png``. When
    ``args.save_pcd`` is set, completions are saved to
    ``<results_dir>/pcds/<synset_id>/<model_id>.pcd``. Overall and
    per-synset averages are printed at the end.

    Args:
        args: parsed options; reads ``num_gt_points``, ``model_type``,
            ``checkpoint``, ``results_dir``, ``list_path``, ``data_dir``,
            ``plot_freq`` and ``save_pcd``.
    """
    inputs = tf.placeholder(tf.float32, (1, None, 3))
    gt = tf.placeholder(tf.float32, (1, args.num_gt_points, 3))
    model_module = importlib.import_module('.%s' % args.model_type, 'models')
    model = model_module.Model(inputs, gt, tf.constant(1.0))

    # Metrics run in a second pass: the completion is fed back in here.
    output = tf.placeholder(tf.float32, (1, args.num_gt_points, 3))
    cd_op = chamfer(output, gt)
    emd_op = earth_mover(output, gt)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    sess = tf.Session(config=config)

    total_time = 0
    total_cd = 0
    total_emd = 0
    cd_per_cat = {}
    emd_per_cat = {}
    try:
        saver = tf.train.Saver()
        saver.restore(sess, args.checkpoint)

        os.makedirs(args.results_dir, exist_ok=True)
        # newline='' is what the csv module requires for file objects.
        with open(os.path.join(args.results_dir, 'results.csv'), 'w',
                  newline='') as csv_file:
            writer = csv.writer(csv_file)
            writer.writerow(['id', 'cd', 'emd'])

            with open(args.list_path) as list_file:
                model_list = list_file.read().splitlines()

            for i, full_id in enumerate(model_list):
                partial = read_pcd(
                    os.path.join(args.data_dir, 'partial', '%s.pcd' % full_id))
                complete = read_pcd(
                    os.path.join(args.data_dir, 'complete',
                                 '%s.pcd' % full_id))
                start = time.time()
                completion = sess.run(model.outputs,
                                      feed_dict={inputs: [partial]})
                total_time += time.time() - start
                cd, emd = sess.run([cd_op, emd_op],
                                   feed_dict={
                                       output: completion,
                                       gt: [complete]
                                   })
                total_cd += cd
                total_emd += emd
                writer.writerow([full_id, cd, emd])

                synset_id, model_id = full_id.split('/')
                cd_per_cat.setdefault(synset_id, []).append(cd)
                emd_per_cat.setdefault(synset_id, []).append(emd)

                if i % args.plot_freq == 0:
                    plot_dir = os.path.join(args.results_dir, 'plots',
                                            synset_id)
                    os.makedirs(plot_dir, exist_ok=True)
                    plot_pcd_three_views(
                        os.path.join(plot_dir, '%s.png' % model_id),
                        [partial, completion[0], complete],
                        ['input', 'output', 'ground truth'],
                        'CD %.4f  EMD %.4f' % (cd, emd), [5, 0.5, 0.5])
                if args.save_pcd:
                    # Save inside the per-synset directory created here, so
                    # model ids shared across synsets do not overwrite each
                    # other.
                    pcd_dir = os.path.join(args.results_dir, 'pcds',
                                           synset_id)
                    os.makedirs(pcd_dir, exist_ok=True)
                    save_pcd(os.path.join(pcd_dir, '%s.pcd' % model_id),
                             completion[0])
    finally:
        sess.close()

    num_models = len(model_list)
    if num_models == 0:
        print('No models listed in %s' % args.list_path)
        return

    print('Average time: %f' % (total_time / num_models))
    print('Average Chamfer distance: %f' % (total_cd / num_models))
    print('Average Earth mover distance: %f' % (total_emd / num_models))
    print('Chamfer distance per category')
    for synset_id, values in cd_per_cat.items():
        print(synset_id, '%f' % np.mean(values))
    print('Earth mover distance per category')
    for synset_id, values in emd_per_cat.items():
        print(synset_id, '%f' % np.mean(values))
示例#4
0
def _load_pair(args, model_id):
    """Read the partial and complete clouds for one list entry.

    For ``'shapenet'``, the entry is split as ``<synset_id>/<model_id>``.
    Files are read from ``partial/<synset_id>/`` and
    ``complete/<synset_id>/``.

    For ``'suncg'``, every entry is grouped under ``'all_rooms'``. Files are
    read from ``pcd_partial/`` and ``pcd_complete/``.

    Returns:
        ``(synset_id, model_id, partial, complete)``.

    Raises:
        ValueError: for any other ``args.experiment``.
    """
    if args.experiment == 'shapenet':
        synset_id, model_id = model_id.split('/')
        partial_path = os.path.join(args.data_dir, 'partial', synset_id,
                                    '%s.pcd' % model_id)
        complete_path = os.path.join(args.data_dir, 'complete', synset_id,
                                     '%s.pcd' % model_id)
    elif args.experiment == 'suncg':
        synset_id = 'all_rooms'
        partial_path = os.path.join(args.data_dir, 'pcd_partial',
                                    '%s.pcd' % model_id)
        complete_path = os.path.join(args.data_dir, 'pcd_complete',
                                     '%s.pcd' % model_id)
    else:
        raise ValueError('Unsupported experiment: %r' % (args.experiment,))
    return synset_id, model_id, read_pcd(partial_path), read_pcd(complete_path)


def _rotate_xz(points, angle, num_cols):
    """Rotate ``points`` by ``angle`` radians in the plane of columns 0 and 2.

    Column 1 is kept unchanged and columns ``3:num_cols`` are appended
    unchanged. Columns at index ``>= num_cols`` are dropped. A new array is
    returned.
    """
    cos, sin = np.cos(angle), np.sin(angle)
    x, z = points[:, 0], points[:, 2]
    columns = [cos * x - sin * z, points[:, 1], sin * x + cos * z]
    columns.extend(points[:, c] for c in range(3, num_cols))
    return np.stack(columns, axis=-1)


def _segmentation_colors(points, num_channel):
    """Color each point by the argmax of its channels ``3:3 + num_channel``.

    Label ``k`` maps to ``Set1((k + 1) / num_channel - 0.5 / num_channel)``.
    Only the RGB part is kept.
    """
    labels = np.argmax(points[:, 3:3 + num_channel], -1)
    return matplotlib.cm.Set1(
        (labels + 1) / num_channel - 0.5 / num_channel)[:, 0:3]


def _write_colored_ply(path, coords, colors):
    """Write ``coords`` with per-point ``colors`` to ``path`` (ASCII mode)."""
    pcd = PointCloud()
    pcd.points = Vector3dVector(coords)
    pcd.colors = Vector3dVector(colors)
    write_point_cloud(path, pcd, write_ascii=True)


def _save_plys(args, synset_id, model_id, partial, completion1, completion2,
               complete):
    """Write colored ``.ply`` files for one model.

    The files are:

    - ``input/``: the partial cloud, colored by column 1 via ``cool``.
    - ``output1/`` and ``output2/``: the coarse and fine outputs, colored by
      argmax class.
    - ``regions/``: one file per channel of the fine output, min-max scaled
      and shaded with ``Reds``/``gist_gray``.
    - ``gt/``: the complete cloud.

    Each group goes under its own ``<results_dir>/<group>/<synset_id>/``.
    """
    num_channel = args.num_channel

    def out_path(group, name):
        out_dir = os.path.join(args.results_dir, group, synset_id)
        os.makedirs(out_dir, exist_ok=True)
        return os.path.join(out_dir, name)

    _write_colored_ply(out_path('input', '%s.ply' % model_id),
                       partial[:, 0:3],
                       matplotlib.cm.cool((partial[:, 1]))[:, 0:3])
    for group, completion in (('output1', completion1),
                              ('output2', completion2)):
        _write_colored_ply(out_path(group, '%s.ply' % model_id),
                           completion[:, 0:3],
                           _segmentation_colors(completion, num_channel))

    fine_coords = completion2[:, 0:3]
    for idx in range(3, 3 + num_channel):
        values = completion2[:, idx]
        val_min = np.min(values)
        val_max = np.max(values)
        # NOTE(review): a constant channel divides by zero here; the NaNs
        # map to the colormap's "bad" color.
        scaled = (values - val_min) / (val_max - val_min)
        colors = 0.8 * matplotlib.cm.Reds(scaled)[:, 0:3]
        colors += 0.2 * matplotlib.cm.gist_gray(scaled)[:, 0:3]
        _write_colored_ply(
            out_path('regions', '%s_%s.ply' % (model_id, idx - 3)),
            fine_coords, colors)

    if args.experiment == 'shapenet':
        gt_colors = matplotlib.cm.cool(complete[:, 1])[:, 0:3]
    elif args.experiment == 'suncg':
        gt_colors = matplotlib.cm.Set1(complete[:, 3] -
                                       0.5 / num_channel)[:, 0:3]
    else:
        raise ValueError('Unsupported experiment: %r' % (args.experiment,))
    _write_colored_ply(out_path('gt', '%s.ply' % model_id), complete[:, 0:3],
                       gt_colors)


def test(args):
    """Evaluate a completion checkpoint that also predicts per-point channels.

    Per list entry of ``args.list_path``:

    - The partial and complete clouds are loaded; the layout depends on
      ``args.experiment``, ``'shapenet'`` or ``'suncg'``.
    - If ``args.rotate`` is set, both are rotated by the same random angle
      in the plane of columns 0 and 2. The RNG is seeded with
      ``np.random.seed(1)``.
    - The partial cloud is reduced to xyz, and the complete cloud is
      resampled to 16384 points.
    - ``model.outputs1``, ``model.outputs2`` and ``model.gt_can`` are
      evaluated. The fine output (``outputs2``) is scored against the
      complete cloud.

    ``<results_dir>/results.csv`` receives, in order:

    - one ``[id, cd, emd]`` row per model;
    - one ``[avg_time, avg_cd, avg_emd]`` row;
    - ``[group, mean_cd]`` rows, then ``[group, mean_emd]`` rows.

    The same summary is printed.

    By default EMD is NOT evaluated: the Chamfer distance is reported in
    every "emd" slot (CSV column, plot title, printed averages). Set
    ``args.compute_emd`` to a truthy value to evaluate ``earth_mover``.

    Args:
        args: parsed options; reads ``num_gt_points``, ``model_type``,
            ``num_channel``, ``checkpoint``, ``results_dir``, ``list_path``,
            ``experiment``, ``data_dir``, ``rotate``, ``plot_freq``,
            ``save_pcd`` and, optionally, ``compute_emd``.

    Raises:
        ValueError: if ``args.experiment`` is not ``'shapenet'`` or
            ``'suncg'``.
    """
    if args.experiment not in ('shapenet', 'suncg'):
        raise ValueError('Unsupported experiment: %r' % (args.experiment,))
    compute_emd = getattr(args, 'compute_emd', False)

    inputs = tf.placeholder(tf.float32, (1, None, 3))
    npts = tf.placeholder(tf.int32, (1, ))
    gt = tf.placeholder(tf.float32, (1, args.num_gt_points, 6))
    model_module = importlib.import_module('.%s' % args.model_type, 'models')
    model = model_module.Model(inputs, npts, gt, tf.constant(1.0),
                               args.num_channel)

    output = tf.placeholder(tf.float32,
                            (1, args.num_gt_points, 3 + args.num_channel))
    cd_op = chamfer(output[:, :, 0:3], gt[:, :, 0:3])
    emd_op = earth_mover(output[:, :, 0:3], gt[:, :, 0:3])

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    sess = tf.Session(config=config)
    try:
        saver = tf.train.Saver()
        saver.restore(sess, args.checkpoint)

        os.makedirs(args.results_dir, exist_ok=True)
        # newline='' is what the csv module requires for file objects.
        with open(os.path.join(args.results_dir, 'results.csv'), 'w',
                  newline='') as csv_file:
            writer = csv.writer(csv_file)
            writer.writerow(['id', 'cd', 'emd'])

            with open(args.list_path) as list_file:
                model_list = list_file.read().splitlines()
            total_time = 0
            total_cd = 0
            total_emd = 0
            cd_per_cat = {}
            emd_per_cat = {}
            np.random.seed(1)
            for i, list_id in enumerate(model_list):
                synset_id, model_id, partial, complete = _load_pair(
                    args, list_id)
                if args.rotate:
                    angle = np.random.rand(1) * 2 * np.pi
                    partial = _rotate_xz(partial, angle, 3)
                    complete = _rotate_xz(complete, angle, 6)
                partial = partial[:, :3]
                complete = resample_pcd(complete, 16384)
                start = time.time()
                completion1, completion2, mesh_out = sess.run(
                    [model.outputs1, model.outputs2, model.gt_can],
                    feed_dict={
                        inputs: [partial],
                        npts: [partial.shape[0]],
                        gt: [complete]
                    })
                # Zero every channel past xyz + the num_channel block.
                for result in (completion1, completion2, mesh_out):
                    result[0][:, (3 + args.num_channel):] *= 0
                total_time += time.time() - start

                metric_feed = {output: completion2, gt: [complete]}
                if compute_emd:
                    cd, emd = sess.run([cd_op, emd_op], feed_dict=metric_feed)
                else:
                    cd = sess.run(cd_op, feed_dict=metric_feed)
                    emd = cd  # EMD not evaluated; CD reported in its place.
                total_cd += cd
                total_emd += emd
                cd_per_cat.setdefault(synset_id, []).append(cd)
                emd_per_cat.setdefault(synset_id, []).append(emd)
                writer.writerow([model_id, cd, emd])

                if i % args.plot_freq == 0:
                    plot_dir = os.path.join(args.results_dir, 'plots',
                                            synset_id)
                    os.makedirs(plot_dir, exist_ok=True)
                    plot_pcd_three_views(
                        os.path.join(plot_dir, '%s.png' % model_id), [
                            partial, completion1[0], completion2[0],
                            mesh_out[0], complete
                        ], ['input', 'coarse', 'fine', 'mesh', 'ground truth'],
                        'CD %.4f  EMD %.4f' % (cd, emd),
                        [5, 0.5, 0.5, 0.5, 0.5],
                        num_channel=args.num_channel)
                if args.save_pcd:
                    _save_plys(args, synset_id, model_id, partial,
                               completion1[0], completion2[0], complete)

            num_models = len(model_list)
            if num_models == 0:
                print('No models listed in %s' % args.list_path)
                return

            print('Average time: %f' % (total_time / num_models))
            print('Average Chamfer distance: %f' % (total_cd / num_models))
            print('Average Earth mover distance: %f' %
                  (total_emd / num_models))
            writer.writerow([
                total_time / num_models, total_cd / num_models,
                total_emd / num_models
            ])
            print('Chamfer distance per category')
            for synset_id, values in cd_per_cat.items():
                print(synset_id, '%f' % np.mean(values))
                writer.writerow([synset_id, np.mean(values)])
            print('Earth mover distance per category')
            for synset_id, values in emd_per_cat.items():
                print(synset_id, '%f' % np.mean(values))
                writer.writerow([synset_id, np.mean(values)])
    finally:
        sess.close()