Code Example #1
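
These snippets are main() functions extracted from PointCNN-style point-cloud segmentation scripts; their module-level imports are not shown on this page. Below is a minimal, hedged sketch of the imports the code appears to assume (not copied from the original files): data_utils and pointfly (aliased pf) are project-local modules from the PointCNN codebase, and the TensorFlow calls used throughout (tf.placeholder, tf.Session, tf.train.*) are the 1.x API.

# Assumed imports for the snippets below (hedged reconstruction, not from the originals).
import argparse
import importlib
import math
import os
import random
import shutil
import sys
from datetime import datetime

import numpy as np
import tensorflow as tf                 # TensorFlow 1.x API
from tqdm import tqdm                   # used in Code Example #2
from progressbar import ProgressBar     # used in Code Example #3

import data_utils      # project-local: load_seg, grouped_shuffle, save_ply_property, ...
import pointfly as pf  # project-local: augment, get_xforms, get_indices, top_1_accuracy, ...
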
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--filelist',
                        '-t',
                        help='Path to training set ground truth (.txt)',
                        required=True)
    parser.add_argument('--filelist_val',
                        '-v',
                        help='Path to validation set ground truth (.txt)',
                        required=True)
    parser.add_argument('--load_ckpt',
                        '-l',
                        help='Path to a check point file for load')
    parser.add_argument(
        '--save_folder',
        '-s',
        help='Path to folder for saving check points and summary',
        required=True)
    parser.add_argument('--model', '-m', help='Model to use', required=True)
    parser.add_argument('--setting',
                        '-x',
                        help='Setting to use',
                        required=True)
    args = parser.parse_args()

    time_string = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    root_folder = os.path.join(
        args.save_folder,
        '%s_%s_%d_%s' % (args.model, args.setting, os.getpid(), time_string))
    if not os.path.exists(root_folder):
        os.makedirs(root_folder)

    # Redirect stdout to a log file (disabled)
    # sys.stdout = open(os.path.join(root_folder, 'log.txt'), 'w')

    print('PID:', os.getpid())

    print(args)

    model = importlib.import_module(args.model)
    setting_path = os.path.join(os.path.dirname(__file__), args.model)
    sys.path.append(setting_path)
    print(setting_path)
    setting = importlib.import_module(args.setting)

    num_epochs = setting.num_epochs
    batch_size = setting.batch_size
    sample_num = setting.sample_num
    step_val = 1000
    num_parts = setting.num_parts
    label_weights_list = setting.label_weights
    scaling_range = setting.scaling_range
    scaling_range_val = setting.scaling_range_val
    jitter = setting.jitter
    jitter_val = setting.jitter_val

    # Prepare inputs
    print('{}-Preparing datasets...'.format(datetime.now()))

    data_train, data_num_train, label_train = data_utils.load_seg(
        args.filelist)
    data_val, data_num_val, label_val = data_utils.load_seg(args.filelist_val)

    # shuffle
    data_train, data_num_train, label_train = \
        data_utils.grouped_shuffle([data_train, data_num_train, label_train])

    num_train = data_train.shape[0]  # 2048
    point_num = data_train.shape[1]  # 6501
    num_val = data_val.shape[0]

    print('{}-{:d}/{:d} training/validation samples.'.format(
        datetime.now(), num_train, num_val))
    batch_num = (num_train * num_epochs + batch_size - 1) // batch_size
    print('{}-{:d} training batches.'.format(datetime.now(), batch_num))
    batch_num_val = math.ceil(num_val / batch_size)
    print('{}-{:d} testing batches per test.'.format(datetime.now(),
                                                     batch_num_val))

    ######################################################################
    # Placeholders
    indices = tf.placeholder(tf.int32, shape=(None, None, 2), name="indices")
    xforms = tf.placeholder(tf.float32, shape=(None, 3, 3),
                            name="xforms")  # scaling transform
    rotations = tf.placeholder(tf.float32,
                               shape=(None, 3, 3),
                               name="rotations")  # rotation transform
    jitter_range = tf.placeholder(tf.float32, shape=(1),
                                  name="jitter_range")  # data perturbation (jitter)
    global_step = tf.Variable(0, trainable=False, name='global_step')
    is_training = tf.placeholder(tf.bool, name='is_training')

    pts_fts = tf.placeholder(tf.float32,
                             shape=(None, point_num, setting.data_dim),
                             name='pts_fts')
    labels_seg = tf.placeholder(tf.int32,
                                shape=(None, point_num),
                                name='labels_seg')
    labels_weights = tf.placeholder(tf.float32,
                                    shape=(None, point_num),
                                    name='labels_weights')
    loss_val_avg = tf.placeholder(tf.float32)
    t_1_acc_val_avg = tf.placeholder(tf.float32)
    t_1_acc_val_instance_avg = tf.placeholder(tf.float32)
    t_1_acc_val_others_avg = tf.placeholder(tf.float32)

    ######################################################################

    # Set inputs (points, features_sampled)
    features_sampled = None

    if setting.data_dim > 3:
        #   [3 pts_xyz , 2 img_xy , 4 extra_features]
        points, _, features = tf.split(pts_fts, [
            setting.data_format["pts_xyz"], setting.data_format["img_xy"],
            setting.data_format["extra_features"]
        ],
                                       axis=-1,
                                       name='split_points_xy_features')
        # r, g, b, i
        if setting.use_extra_features:
            # select the extra features at the sampled indices
            features_sampled = tf.gather_nd(features,
                                            indices=indices,
                                            name='features_sampled')
    else:
        points = pts_fts

    points_sampled = tf.gather_nd(points,
                                  indices=indices,
                                  name='points_sampled')
    points_augmented = pf.augment(points_sampled, xforms, jitter_range)
    # N * 6501
    labels_sampled = tf.gather_nd(labels_seg,
                                  indices=indices,
                                  name='labels_sampled')
    labels_weights_sampled = tf.gather_nd(labels_weights,
                                          indices=indices,
                                          name='labels_weight_sampled')

    # Build net
    net = model.Net(points_augmented, features_sampled, None, None, num_parts,
                    is_training, setting)
    logits, probs = net.logits, net.probs

    loss_op = tf.losses.sparse_softmax_cross_entropy(
        labels=labels_sampled, logits=logits, weights=labels_weights_sampled)
    t_1_acc_op = pf.top_1_accuracy(probs, labels_sampled)
    t_1_acc_instance_op = pf.top_1_accuracy(probs, labels_sampled,
                                            labels_weights_sampled, 0.6)
    t_1_acc_others_op = pf.top_1_accuracy(probs, labels_sampled,
                                          labels_weights_sampled, 0.6, "less")

    reg_loss = setting.weight_decay * tf.losses.get_regularization_loss()

    # lr decay
    lr_exp_op = tf.train.exponential_decay(setting.learning_rate_base,
                                           global_step,
                                           setting.decay_steps,
                                           setting.decay_rate,
                                           staircase=True)
    lr_clip_op = tf.maximum(lr_exp_op, setting.learning_rate_min)

    # Optimizer
    if setting.optimizer == 'adam':
        optimizer = tf.train.AdamOptimizer(learning_rate=lr_clip_op,
                                           epsilon=setting.epsilon)
    elif setting.optimizer == 'momentum':
        optimizer = tf.train.MomentumOptimizer(learning_rate=lr_clip_op,
                                               momentum=0.9,
                                               use_nesterov=True)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

    with tf.control_dependencies(update_ops):
        train_op = optimizer.minimize(loss_op + reg_loss,
                                      global_step=global_step)

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    saver = tf.train.Saver(max_to_keep=1)

    # backup this file, model and setting
    if not os.path.exists(os.path.join(root_folder, args.model)):
        os.makedirs(os.path.join(root_folder, args.model))
    shutil.copy(__file__, os.path.join(root_folder,
                                       os.path.basename(__file__)))
    shutil.copy(os.path.join(os.path.dirname(__file__), args.model + '.py'),
                os.path.join(root_folder, args.model + '.py'))
    shutil.copy(
        os.path.join(os.path.dirname(__file__),
                     args.model.split("_")[0] + "_kitti" + '.py'),
        os.path.join(root_folder,
                     args.model.split("_")[0] + "_kitti" + '.py'))
    shutil.copy(os.path.join(setting_path, args.setting + '.py'),
                os.path.join(root_folder, args.model, args.setting + '.py'))

    folder_ckpt = os.path.join(root_folder, 'ckpts')
    if not os.path.exists(folder_ckpt):
        os.makedirs(folder_ckpt)

    folder_summary = os.path.join(root_folder, 'summary')
    if not os.path.exists(folder_summary):
        os.makedirs(folder_summary)

    parameter_num = np.sum(
        [np.prod(v.shape.as_list()) for v in tf.trainable_variables()])
    print('{}-Parameter number: {:d}.'.format(datetime.now(), parameter_num))

    _ = tf.summary.scalar('loss/train_seg',
                          tensor=loss_op,
                          collections=['train'])
    _ = tf.summary.scalar('t_1_acc/train_seg',
                          tensor=t_1_acc_op,
                          collections=['train'])
    _ = tf.summary.scalar('t_1_acc/train_seg_instance',
                          tensor=t_1_acc_instance_op,
                          collections=['train'])
    _ = tf.summary.scalar('t_1_acc/train_seg_others',
                          tensor=t_1_acc_others_op,
                          collections=['train'])
    _ = tf.summary.scalar('loss/val_seg',
                          tensor=loss_val_avg,
                          collections=['val'])
    _ = tf.summary.scalar('t_1_acc/val_seg_instance',
                          tensor=t_1_acc_val_instance_avg,
                          collections=['val'])
    _ = tf.summary.scalar('t_1_acc/val_seg',
                          tensor=t_1_acc_val_avg,
                          collections=['val'])
    _ = tf.summary.scalar('t_1_acc/val_seg_others',
                          tensor=t_1_acc_val_others_avg,
                          collections=['val'])
    _ = tf.summary.scalar('learning_rate',
                          tensor=lr_clip_op,
                          collections=['train'])

    # Session Run
    with tf.Session() as sess:
        summaries_op = tf.summary.merge_all('train')
        summaries_val_op = tf.summary.merge_all('val')
        summary_writer = tf.summary.FileWriter(folder_summary, sess.graph)

        sess.run(init_op)

        # Load the model
        if args.load_ckpt is not None:
            saver.restore(sess, args.load_ckpt)
            print('{}-Checkpoint loaded from {}!'.format(
                datetime.now(), args.load_ckpt))

        for batch_idx in range(batch_num):
            if (batch_idx != 0 and batch_idx % step_val
                    == 0) or batch_idx == batch_num - 1:
                ######################################################################
                # Validation
                filename_ckpt = os.path.join(folder_ckpt, 'iter')
                saver.save(sess, filename_ckpt, global_step=global_step)
                print('{}-Checkpoint saved to {}!'.format(
                    datetime.now(), filename_ckpt))

                losses_val = []
                t_1_accs = []
                t_1_accs_instance = []
                t_1_accs_others = []

                for batch_val_idx in range(math.ceil(num_val / batch_size)):
                    start_idx = batch_size * batch_val_idx
                    end_idx = min(start_idx + batch_size, num_val)
                    batch_size_val = end_idx - start_idx
                    points_batch = data_val[start_idx:end_idx, ...]
                    points_num_batch = data_num_val[start_idx:end_idx, ...]
                    labels_batch = label_val[start_idx:end_idx, ...]
                    weights_batch = np.array(label_weights_list)[label_val[
                        start_idx:end_idx, ...]]

                    xforms_np, rotations_np = pf.get_xforms(
                        batch_size_val, scaling_range=scaling_range_val)
                    sess_op_list = [
                        loss_op, t_1_acc_op, t_1_acc_instance_op,
                        t_1_acc_others_op
                    ]
                    sess_feed_dict = {
                        pts_fts: points_batch,
                        indices: pf.get_indices(batch_size_val, sample_num,
                                                points_num_batch),
                        xforms: xforms_np,
                        rotations: rotations_np,
                        jitter_range: np.array([jitter_val]),  # data jitter
                        labels_seg: labels_batch,
                        labels_weights: weights_batch,
                        is_training: False
                    }

                    loss_val, t_1_acc_val, t_1_acc_val_instance, t_1_acc_val_others = sess.run(
                        sess_op_list, feed_dict=sess_feed_dict)
                    print(
                        '{}-[Val  ]-Iter: {:06d}  Loss: {:.4f} T-1 Acc: {:.4f}'
                        .format(datetime.now(), batch_val_idx, loss_val,
                                t_1_acc_val))

                    sys.stdout.flush()
                    losses_val.append(loss_val * batch_size_val)
                    t_1_accs.append(t_1_acc_val * batch_size_val)
                    t_1_accs_instance.append(t_1_acc_val_instance *
                                             batch_size_val)
                    t_1_accs_others.append(t_1_acc_val_others * batch_size_val)

                loss_avg = sum(losses_val) / num_val
                t_1_acc_avg = sum(t_1_accs) / num_val
                t_1_acc_instance_avg = sum(t_1_accs_instance) / num_val
                t_1_acc_others_avg = sum(t_1_accs_others) / num_val

                summaries_feed_dict = {
                    loss_val_avg: loss_avg,
                    t_1_acc_val_avg: t_1_acc_avg,
                    t_1_acc_val_instance_avg: t_1_acc_instance_avg,
                    t_1_acc_val_others_avg: t_1_acc_others_avg
                }

                summaries_val = sess.run(summaries_val_op,
                                         feed_dict=summaries_feed_dict)
                summary_writer.add_summary(summaries_val, batch_idx)

                print('{}-[Val  ]-Average:      Loss: {:.4f} T-1 Acc: {:.4f}'.
                      format(datetime.now(), loss_avg, t_1_acc_avg))

                sys.stdout.flush()
                ######################################################################

            ######################################################################
            # Training
            start_idx = (batch_size * batch_idx) % num_train
            end_idx = min(start_idx + batch_size, num_train)
            batch_size_train = end_idx - start_idx
            points_batch = data_train[start_idx:end_idx, ...]  # 2048 6501 9

            points_num_batch = data_num_train[start_idx:end_idx, ...]  #2048
            labels_batch = label_train[start_idx:end_idx, ...]  #2048 6501
            weights_batch = np.array(label_weights_list)[labels_batch]  # per-point weights (zero weight means the point is ignored in the loss)

            if start_idx + batch_size_train == num_train:
                data_train, data_num_train, label_train = \
                    data_utils.grouped_shuffle([data_train, data_num_train, label_train])

            offset = int(random.gauss(0, sample_num // 8))  # gauss(mean, sigma)
            offset = max(offset, -sample_num // 4)
            offset = min(offset, sample_num // 4)
            sample_num_train = sample_num + offset
            xforms_np, rotations_np = pf.get_xforms(
                batch_size_train, scaling_range=scaling_range)  # data augmentation

            sess_op_list = [
                train_op, loss_op, t_1_acc_op, t_1_acc_instance_op,
                t_1_acc_others_op, summaries_op
            ]

            sess_feed_dict = {
                pts_fts: points_batch,
                indices: pf.get_indices(batch_size_train, sample_num_train,
                                        points_num_batch),
                xforms: xforms_np,
                rotations: rotations_np,
                jitter_range: np.array([jitter]),
                labels_seg: labels_batch,
                labels_weights: weights_batch,
                is_training: True
            }

            _, loss, t_1_acc, t_1_acc_instance, t_1_acc_others, summaries = sess.run(
                sess_op_list, feed_dict=sess_feed_dict)
            print('{}-[Train]-Iter: {:06d}  Loss_seg: {:.4f} T-1 Acc: {:.4f}'.
                  format(datetime.now(), batch_idx, loss, t_1_acc))

            summary_writer.add_summary(summaries, batch_idx)
            sys.stdout.flush()

            ######################################################################
        print('{}-Done!'.format(datetime.now()))
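
The indices placeholder above has shape (N, sample_num, 2): each entry is a (cloud index, point index) pair, and tf.gather_nd uses it to pick sample_num points from every cloud in the batch. A small NumPy sketch of the same indexing follows; pf.get_indices is presumed to produce such random pairs, up to padding details.

# NumPy sketch of the (batch, sample, 2)-index gather used above; illustrative only.
import numpy as np

batch_size, point_num, data_dim, sample_num = 2, 5, 3, 4
points = np.random.rand(batch_size, point_num, data_dim).astype(np.float32)

# (batch_size, sample_num, 2) pairs of (cloud index, point index).
batch_col = np.tile(np.arange(batch_size).reshape(-1, 1, 1), (1, sample_num, 1))
point_col = np.random.randint(0, point_num, size=(batch_size, sample_num, 1))
indices = np.concatenate([batch_col, point_col], axis=2)

# Equivalent of tf.gather_nd(points, indices): one sampled point set per cloud.
points_sampled = points[indices[..., 0], indices[..., 1]]
assert points_sampled.shape == (batch_size, sample_num, data_dim)
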
Code Example #2
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--filelist',
                        '-t',
                        help='Path to training set ground truth (.txt)',
                        required=True)
    parser.add_argument('--filelist_val',
                        '-v',
                        help='Path to validation set ground truth (.txt)',
                        required=True)
    parser.add_argument('--load_ckpt',
                        '-l',
                        help='Path to a check point file for load')
    parser.add_argument(
        '--save_folder',
        '-s',
        help='Path to folder for saving check points and summary',
        required=True)
    parser.add_argument('--model', '-m', help='Model to use', required=True)
    parser.add_argument('--setting',
                        '-x',
                        help='Setting to use',
                        required=True)
    parser.add_argument(
        '--epochs',
        help='Number of training epochs (default defined in setting)',
        type=int)
    parser.add_argument('--batch_size',
                        help='Batch size (default defined in setting)',
                        type=int)
    parser.add_argument(
        '--log',
        help=
        'Log to FILE in save folder; use - for stdout (default is log.txt)',
        metavar='FILE',
        default='log.txt')
    parser.add_argument('--no_timestamp_folder',
                        help="Don't save to timestamp folder",
                        action='store_true')
    parser.add_argument('--no_code_backup',
                        help="Don't back up code",
                        action='store_true')
    args = parser.parse_args()

    if not args.no_timestamp_folder:
        time_string = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
        root_folder = os.path.join(
            args.save_folder, '%s_%s_%s_%d' %
            (args.model, args.setting, time_string, os.getpid()))
    else:
        root_folder = args.save_folder
    if not os.path.exists(root_folder):
        os.makedirs(root_folder)

    if args.log != '-':
        sys.stdout = open(os.path.join(root_folder, args.log), 'w')

    #print('PID:', os.getpid())

    #print(args)

    model = importlib.import_module(args.model)
    setting_path = os.path.join(os.path.dirname(__file__), args.model)
    sys.path.append(setting_path)
    setting = importlib.import_module(args.setting)

    num_epochs = args.epochs or setting.num_epochs
    batch_size = args.batch_size or setting.batch_size
    sample_num = setting.sample_num
    step_val = setting.step_val
    label_weights_list = setting.label_weights
    rotation_range = setting.rotation_range
    rotation_range_val = setting.rotation_range_val
    scaling_range = setting.scaling_range
    scaling_range_val = setting.scaling_range_val
    jitter = setting.jitter
    jitter_val = setting.jitter_val

    # Prepare inputs
    #print('{}-Preparing datasets...'.format(datetime.now()))
    is_list_of_h5_list = not data_utils.is_h5_list(args.filelist)
    if is_list_of_h5_list:
        seg_list = data_utils.load_seg_list(args.filelist)
        seg_list_idx = 0
        filelist_train = seg_list[seg_list_idx]
        seg_list_idx = seg_list_idx + 1
    else:
        filelist_train = args.filelist
    data_train, _, data_num_train, label_train, _ = data_utils.load_seg(
        filelist_train)
    data_val, _, data_num_val, label_val, _ = data_utils.load_seg(
        args.filelist_val)

    # shuffle
    # data_num_train is the number of points in each point cloud
    data_train, data_num_train, label_train = \
        data_utils.grouped_shuffle([data_train, data_num_train, label_train])

    num_train = data_train.shape[0]
    point_num = data_train.shape[1]
    num_val = data_val.shape[0]
    #print('{}-{:d}/{:d} training/validation samples.'.format(datetime.now(), num_train, num_val))
    batch_num = (num_train * num_epochs + batch_size - 1) // batch_size
    #print('{}-{:d} training batches.'.format(datetime.now(), batch_num))
    batch_num_val = math.ceil(num_val / batch_size)
    #print('{}-{:d} testing batches per test.'.format(datetime.now(), batch_num_val))

    ######################################################################
    # Placeholders
    indices = tf.placeholder(tf.int32, shape=(None, None, 2), name="indices")
    xforms = tf.placeholder(tf.float32, shape=(None, 3, 3), name="xforms")
    rotations = tf.placeholder(tf.float32,
                               shape=(None, 3, 3),
                               name="rotations")
    jitter_range = tf.placeholder(tf.float32, shape=(1), name="jitter_range")
    global_step = tf.Variable(0, trainable=False, name='global_step')
    is_training = tf.placeholder(tf.bool, name='is_training')

    pts_fts = tf.placeholder(tf.float32,
                             shape=(None, point_num, setting.data_dim),
                             name='pts_fts')
    labels_seg = tf.placeholder(tf.int64,
                                shape=(None, point_num),
                                name='labels_seg')
    labels_weights = tf.placeholder(tf.float32,
                                    shape=(None, point_num),
                                    name='labels_weights')

    ######################################################################
    pts_fts_sampled = tf.gather_nd(pts_fts,
                                   indices=indices,
                                   name='pts_fts_sampled')
    features_augmented = None
    if setting.data_dim > 3:
        points_sampled, features_sampled = tf.split(
            pts_fts_sampled, [3, setting.data_dim - 3],
            axis=-1,
            name='split_points_features')
        if setting.use_extra_features:
            if setting.with_normal_feature:
                if setting.data_dim < 6:
                    print('Only 3D normals are supported!')
                    exit()
                elif setting.data_dim == 6:
                    features_augmented = pf.augment(features_sampled,
                                                    rotations)
                else:
                    normals, rest = tf.split(features_sampled,
                                             [3, setting.data_dim - 6],
                                             axis=-1)  # split along the feature axis
                    normals_augmented = pf.augment(normals, rotations)
                    features_augmented = tf.concat([normals_augmented, rest],
                                                   axis=-1)
            else:
                features_augmented = features_sampled
    else:
        points_sampled = pts_fts_sampled
    points_augmented = pf.augment(points_sampled, xforms, jitter_range)

    labels_sampled = tf.gather_nd(labels_seg,
                                  indices=indices,
                                  name='labels_sampled')
    labels_weights_sampled = tf.gather_nd(labels_weights,
                                          indices=indices,
                                          name='labels_weight_sampled')
    net = model.Net(points_augmented, features_augmented, is_training, setting)
    logits = net.logits
    probs = tf.nn.softmax(logits, name='probs')
    predictions = tf.argmax(probs, axis=-1, name='predictions')

    loss_op = tf.losses.sparse_softmax_cross_entropy(
        labels=labels_sampled, logits=logits, weights=labels_weights_sampled)

    with tf.name_scope('metrics'):
        loss_mean_op, loss_mean_update_op = tf.metrics.mean(loss_op)
        t_1_acc_op, t_1_acc_update_op = tf.metrics.accuracy(
            labels_sampled, predictions, weights=labels_weights_sampled)
        t_1_per_class_acc_op, t_1_per_class_acc_update_op = \
            tf.metrics.mean_per_class_accuracy(labels_sampled, predictions, setting.num_class,
                                               weights=labels_weights_sampled)
    reset_metrics_op = tf.variables_initializer([
        var for var in tf.local_variables()
        if var.name.split('/')[0] == 'metrics'
    ])

    _ = tf.summary.scalar('loss/train',
                          tensor=loss_mean_op,
                          collections=['train'])
    _ = tf.summary.scalar('t_1_acc/train',
                          tensor=t_1_acc_op,
                          collections=['train'])
    _ = tf.summary.scalar('t_1_per_class_acc/train',
                          tensor=t_1_per_class_acc_op,
                          collections=['train'])

    _ = tf.summary.scalar('loss/val', tensor=loss_mean_op, collections=['val'])
    _ = tf.summary.scalar('t_1_acc/val',
                          tensor=t_1_acc_op,
                          collections=['val'])
    _ = tf.summary.scalar('t_1_per_class_acc/val',
                          tensor=t_1_per_class_acc_op,
                          collections=['val'])

    lr_exp_op = tf.train.exponential_decay(setting.learning_rate_base,
                                           global_step,
                                           setting.decay_steps,
                                           setting.decay_rate,
                                           staircase=True)
    lr_clip_op = tf.maximum(lr_exp_op, setting.learning_rate_min)
    _ = tf.summary.scalar('learning_rate',
                          tensor=lr_clip_op,
                          collections=['train'])
    reg_loss = setting.weight_decay * tf.losses.get_regularization_loss()
    if setting.optimizer == 'adam':
        optimizer = tf.train.AdamOptimizer(learning_rate=lr_clip_op,
                                           epsilon=setting.epsilon)
    elif setting.optimizer == 'momentum':
        optimizer = tf.train.MomentumOptimizer(learning_rate=lr_clip_op,
                                               momentum=setting.momentum,
                                               use_nesterov=True)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = optimizer.minimize(loss_op + reg_loss,
                                      global_step=global_step)

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())

    saver = tf.train.Saver(max_to_keep=None)

    # backup all code
    if not args.no_code_backup:
        code_folder = os.path.abspath(os.path.dirname(__file__))
        shutil.copytree(
            code_folder,
            os.path.join(root_folder, os.path.basename(code_folder)))

    folder_ckpt = os.path.join(root_folder, 'ckpts')
    if not os.path.exists(folder_ckpt):
        os.makedirs(folder_ckpt)

    folder_summary = os.path.join(root_folder, 'summary')
    if not os.path.exists(folder_summary):
        os.makedirs(folder_summary)

    parameter_num = np.sum(
        [np.prod(v.shape.as_list()) for v in tf.trainable_variables()])
    print('{}-Parameter number: {:d}.'.format(datetime.now(), parameter_num))

    with tf.Session() as sess:
        summaries_op = tf.summary.merge_all('train')
        summaries_val_op = tf.summary.merge_all('val')
        summary_writer = tf.summary.FileWriter(folder_summary, sess.graph)

        sess.run(init_op)

        # Load the model
        if args.load_ckpt is not None:
            saver.restore(sess, args.load_ckpt)
            print('{}-Checkpoint loaded from {}!'.format(
                datetime.now(), args.load_ckpt))
        else:
            latest_ckpt = tf.train.latest_checkpoint(folder_ckpt)
            if latest_ckpt:
                print('{}-Found checkpoint {}'.format(datetime.now(),
                                                      latest_ckpt))
                saver.restore(sess, latest_ckpt)
                print('{}-Checkpoint loaded from {} (Iter {})'.format(
                    datetime.now(), latest_ckpt, sess.run(global_step)))

        for batch_idx_train in tqdm(range(batch_num)):
            ######################################################################
            # Validation

            if (batch_idx_train % step_val == 0 and (batch_idx_train != 0 or args.load_ckpt is not None)) \
                    or batch_idx_train == batch_num - 1:
                filename_ckpt = os.path.join(folder_ckpt, 'iter')
                saver.save(sess, filename_ckpt, global_step=global_step)
                print('{}-Checkpoint saved to {}!'.format(
                    datetime.now(), filename_ckpt))

                sess.run(reset_metrics_op)
                for batch_val_idx in range(batch_num_val):
                    start_idx = batch_size * batch_val_idx
                    end_idx = min(start_idx + batch_size, num_val)
                    batch_size_val = end_idx - start_idx
                    points_batch = data_val[start_idx:end_idx, ...]
                    points_num_batch = data_num_val[start_idx:end_idx, ...]
                    labels_batch = label_val[start_idx:end_idx, ...]
                    weights_batch = np.array(label_weights_list)[labels_batch]

                    xforms_np, rotations_np = pf.get_xforms(
                        batch_size_val,
                        rotation_range=rotation_range_val,
                        scaling_range=scaling_range_val,
                        order=setting.rotation_order)
                    sess.run(
                        [
                            loss_mean_update_op, t_1_acc_update_op,
                            t_1_per_class_acc_update_op
                        ],
                        feed_dict={
                            pts_fts: points_batch,
                            indices: pf.get_indices(batch_size_val, sample_num,
                                                    points_num_batch),
                            xforms: xforms_np,
                            rotations: rotations_np,
                            jitter_range: np.array([jitter_val]),
                            labels_seg: labels_batch,
                            labels_weights: weights_batch,
                            is_training: False,
                        })
                loss_val, t_1_acc_val, t_1_per_class_acc_val, summaries_val, step = sess.run(
                    [
                        loss_mean_op, t_1_acc_op, t_1_per_class_acc_op,
                        summaries_val_op, global_step
                    ])
                summary_writer.add_summary(summaries_val, step)
                print(
                    '{}-[Val  ]-Average:      Loss: {:.4f}  T-1 Acc: {:.4f}  T-1 mAcc: {:.4f}'
                    .format(datetime.now(), loss_val, t_1_acc_val,
                            t_1_per_class_acc_val))
                sys.stdout.flush()
            ######################################################################

            ######################################################################
            # Training
            start_idx = (batch_size * batch_idx_train) % num_train
            end_idx = min(start_idx + batch_size, num_train)
            batch_size_train = end_idx - start_idx
            points_batch = data_train[start_idx:end_idx, ...]
            points_num_batch = data_num_train[start_idx:end_idx, ...]
            labels_batch = label_train[start_idx:end_idx, ...]
            weights_batch = np.array(label_weights_list)[labels_batch]

            if start_idx + batch_size_train == num_train:
                if is_list_of_h5_list:
                    filelist_train_prev = seg_list[(seg_list_idx - 1) %
                                                   len(seg_list)]
                    filelist_train = seg_list[seg_list_idx % len(seg_list)]
                    if filelist_train != filelist_train_prev:
                        data_train, _, data_num_train, label_train, _ = data_utils.load_seg(
                            filelist_train)
                        num_train = data_train.shape[0]
                    seg_list_idx = seg_list_idx + 1
                data_train, data_num_train, label_train = \
                    data_utils.grouped_shuffle([data_train, data_num_train, label_train])

            offset = int(
                random.gauss(0, sample_num * setting.sample_num_variance))
            offset = max(offset, -sample_num * setting.sample_num_clip)
            offset = min(offset, sample_num * setting.sample_num_clip)
            sample_num_train = sample_num + offset
            xforms_np, rotations_np = pf.get_xforms(
                batch_size_train,
                rotation_range=rotation_range,
                scaling_range=scaling_range,
                order=setting.rotation_order)
            sess.run(reset_metrics_op)
            sess.run(
                [
                    train_op, loss_mean_update_op, t_1_acc_update_op,
                    t_1_per_class_acc_update_op
                ],
                feed_dict={
                    pts_fts: points_batch,
                    indices: pf.get_indices(batch_size_train, sample_num_train,
                                            points_num_batch),
                    xforms: xforms_np,
                    rotations: rotations_np,
                    jitter_range: np.array([jitter]),
                    labels_seg: labels_batch,
                    labels_weights: weights_batch,
                    is_training: True,
                })
            if batch_idx_train % 10 == 0:
                loss, t_1_acc, t_1_per_class_acc, summaries, step = sess.run([
                    loss_mean_op, t_1_acc_op, t_1_per_class_acc_op,
                    summaries_op, global_step
                ])
                summary_writer.add_summary(summaries, step)
                print(
                    '{}-[Train]-Iter: {:06d}  Loss: {:.4f}  T-1 Acc: {:.4f}  T-1 mAcc: {:.4f}'
                    .format(datetime.now(), step, loss, t_1_acc,
                            t_1_per_class_acc))
                sys.stdout.flush()
            ######################################################################
        print('{}-Done!'.format(datetime.now()))
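
Both training scripts turn the per-class weight list into per-point weights with the expression np.array(label_weights_list)[labels_batch]: indexing a 1-D weight vector by an integer label array of shape (batch, point_num) returns an array of the same shape. A minimal sketch with illustrative weight values:

# Per-point weight lookup via NumPy integer indexing; the weight values here are made up.
import numpy as np

label_weights_list = [0.0, 1.0, 2.5]        # one weight per class id
labels_batch = np.array([[0, 2, 1],
                         [1, 1, 0]])        # (batch, point_num) class ids

weights_batch = np.array(label_weights_list)[labels_batch]
print(weights_batch)
# [[0.  2.5 1. ]
#  [1.  1.  0. ]]   -> same shape as labels_batch, one weight per point
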
Code Example #3
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--category',
                        '-c',
                        help='category name',
                        required=True)
    parser.add_argument('--level',
                        '-l',
                        type=int,
                        help='level id',
                        required=True)
    parser.add_argument('--load_ckpt',
                        '-k',
                        help='Path to a check point file for load',
                        required=True)
    parser.add_argument('--model', '-m', help='Model to use', required=True)
    parser.add_argument('--setting',
                        '-x',
                        help='Setting to use',
                        required=True)
    parser.add_argument('--batch_size',
                        '-b',
                        help='Batch size during testing',
                        default=8,
                        type=int)
    parser.add_argument('--save_ply',
                        '-s',
                        help='Save results as ply',
                        action='store_true')
    parser.add_argument('--save_dir',
                        '-o',
                        help='The output directory',
                        type=str,
                        default=None)
    parser.add_argument('--save_num_shapes',
                        '-u',
                        help='how many shapes to visualize',
                        default=20,
                        type=int)
    args = parser.parse_args()
    print(args)

    if args.save_ply:
        if os.path.exists(args.save_dir):
            print('ERROR: folder %s exists! Please check and delete!' %
                  args.save_dir)
            exit(1)
        os.mkdir(args.save_dir)

    model = importlib.import_module(args.model)
    setting_path = os.path.join(os.path.dirname(__file__), args.model)
    sys.path.append(setting_path)
    setting = importlib.import_module(args.setting)

    sample_num = setting.sample_num
    batch_size = args.batch_size

    args.data_folder = '../../data/sem_seg_h5/'

    # Load all test data
    args.filelist = os.path.join(args.data_folder,
                                 '%s-%d' % (args.category, args.level),
                                 'test_files.txt')
    data_test, _, label_gt = data_utils.load_seg(args.filelist)
    num_shape = data_test.shape[0]
    print('Loaded data: %s shapes in total to test.' % num_shape)

    # Load current category + level statistics
    with open(
            '../../stats/after_merging_label_ids/%s-level-%d.txt' %
        (args.category, args.level), 'r') as fin:
        setting.num_class = len(fin.readlines()) + 1  # with "other"
        print('NUM CLASS: %d' % setting.num_class)

    ######################################################################
    # Placeholders
    is_training = tf.placeholder(tf.bool, name='is_training')
    pts_fts = tf.placeholder(tf.float32,
                             shape=(batch_size, sample_num, setting.data_dim),
                             name='points')
    ######################################################################

    ######################################################################
    pts_fts_sampled = pts_fts
    points_sampled = pts_fts_sampled
    features_sampled = None

    net = model.Net(points_sampled, features_sampled, is_training, setting)
    seg_probs_op = tf.nn.softmax(net.logits, name='seg_probs')

    # saver for restoring the model
    saver = tf.train.Saver()

    parameter_num = np.sum(
        [np.prod(v.shape.as_list()) for v in tf.trainable_variables()])
    print('{}-Parameter number: {:d}.'.format(datetime.now(), parameter_num))

    # Create a session
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    config.log_device_placement = False
    sess = tf.Session(config=config)

    # Load the model
    ckptstate = tf.train.get_checkpoint_state(args.load_ckpt)
    if ckptstate is not None:
        LOAD_MODEL_FILE = os.path.join(
            args.load_ckpt, os.path.basename(ckptstate.model_checkpoint_path))
        saver.restore(sess, LOAD_MODEL_FILE)
        print("Model loaded in file: %s" % LOAD_MODEL_FILE)
    else:
        print("Fail to load modelfile: %s" % args.load_ckpt)

    # Start the testing
    print('{}-Testing...'.format(datetime.now()))

    num_batch = (num_shape - 1) // batch_size + 1
    pts_batch = np.zeros((batch_size, sample_num, 3), dtype=np.float32)

    avg_acc = 0.0
    avg_cnt = 0

    shape_iou_tot = 0.0
    shape_iou_cnt = 0

    part_intersect = np.zeros((setting.num_class), dtype=np.float32)
    part_union = np.zeros((setting.num_class), dtype=np.float32)

    bar = ProgressBar()
    all_seg_probs = []
    for batch_idx in bar(range(num_batch)):
        start_idx = batch_idx * batch_size
        end_idx = min((batch_idx + 1) * batch_size, num_shape)

        pts_batch[:end_idx - start_idx, ...] = data_test[start_idx:end_idx]

        seg_probs = sess.run(seg_probs_op,
                             feed_dict={
                                 pts_fts: pts_batch,
                                 is_training: False
                             })
        seg_probs = seg_probs[:end_idx - start_idx]
        all_seg_probs.append(seg_probs)

        seg_res = np.argmax(seg_probs[:, :, 1:], axis=-1) + 1

        avg_acc += np.sum(
            np.mean((seg_res == label_gt[start_idx:end_idx]) |
                    (label_gt[start_idx:end_idx] == 0),
                    axis=-1))
        avg_cnt += end_idx - start_idx

        seg_gt = label_gt[start_idx:end_idx]
        seg_res[seg_gt == 0] = 0

        for i in range(end_idx - start_idx):
            cur_pred = seg_res[i]
            cur_gt = seg_gt[i]

            cur_shape_iou_tot = 0.0
            cur_shape_iou_cnt = 0
            for j in range(1, setting.num_class):
                cur_gt_mask = (cur_gt == j)
                cur_pred_mask = (cur_pred == j)

                has_gt = (np.sum(cur_gt_mask) > 0)
                has_pred = (np.sum(cur_pred_mask) > 0)

                if has_gt or has_pred:
                    intersect = np.sum(cur_gt_mask & cur_pred_mask)
                    union = np.sum(cur_gt_mask | cur_pred_mask)
                    iou = intersect / union

                    cur_shape_iou_tot += iou
                    cur_shape_iou_cnt += 1

                    part_intersect[j] += intersect
                    part_union[j] += union

            if cur_shape_iou_cnt > 0:
                cur_shape_miou = cur_shape_iou_tot / cur_shape_iou_cnt
                shape_iou_tot += cur_shape_miou
                shape_iou_cnt += 1

        if args.save_ply and start_idx < args.save_num_shapes:
            for i in range(start_idx, min(end_idx, args.save_num_shapes)):
                out_fn = os.path.join(args.save_dir, 'shape-%02d-pred.ply' % i)
                data_utils.save_ply_property(data_test[i],
                                             seg_res[i - start_idx],
                                             setting.num_class, out_fn)
                out_fn = os.path.join(args.save_dir, 'shape-%02d-gt.ply' % i)
                data_utils.save_ply_property(data_test[i], label_gt[i],
                                             setting.num_class, out_fn)

    all_seg_probs = np.vstack(all_seg_probs)
    print('{}-Done!'.format(datetime.now()))

    print('Average Accuracy: %f' % (avg_acc / avg_cnt))
    print('Shape mean IoU: %f' % (shape_iou_tot / shape_iou_cnt))

    part_iou = np.divide(part_intersect[1:], part_union[1:])
    mean_part_iou = np.mean(part_iou)
    print('Category mean IoU: %f, %s' % (mean_part_iou, str(part_iou)))

    out_list = ['%3.1f' % (item * 100) for item in part_iou.tolist()]
    print('%3.1f;%3.1f;%3.1f;%s' %
          (avg_acc * 100 / avg_cnt, shape_iou_tot * 100 / shape_iou_cnt,
           mean_part_iou * 100, '[' + ', '.join(out_list) + ']'))
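
The prediction step above, np.argmax(seg_probs[:, :, 1:], axis=-1) + 1, deliberately excludes class 0 (the "other" label) from the argmax and then shifts the index back into the original class ids, so the network never predicts "other". A tiny sketch on dummy probabilities:

# Why the [:, :, 1:] slice and the +1 shift; dummy probabilities for illustration.
import numpy as np

seg_probs = np.array([[[0.7, 0.2, 0.1],   # (num_shapes, num_points, num_class); class 0 is "other"
                       [0.1, 0.3, 0.6]]])

plain = np.argmax(seg_probs, axis=-1)                   # [[0 2]]  (may predict "other")
seg_res = np.argmax(seg_probs[:, :, 1:], axis=-1) + 1   # [[1 2]]  (never predicts class 0)
print(plain, seg_res)
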
Code Example #4
File: test_general_seg.py  Project: GNETWVS/PointCNN
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--filelist',
                        '-t',
                        help='Path to input .h5 filelist (.txt)',
                        required=True)
    parser.add_argument('--data_folder',
                        '-f',
                        help='Path to *.pts directory',
                        required=True)
    parser.add_argument('--category',
                        '-c',
                        help='Path to category list file (.txt)',
                        required=True)
    parser.add_argument('--load_ckpt',
                        '-l',
                        help='Path to a check point file for load',
                        required=True)
    parser.add_argument('--repeat_num',
                        '-r',
                        help='Repeat number',
                        type=int,
                        default=1)
    parser.add_argument('--model', '-m', help='Model to use', required=True)
    parser.add_argument('--setting',
                        '-x',
                        help='Setting to use',
                        required=True)
    parser.add_argument('--save_ply',
                        '-s',
                        help='Save results as ply',
                        action='store_true')
    args = parser.parse_args()
    print(args)

    model = importlib.import_module(args.model)
    setting_path = os.path.join(os.path.dirname(__file__), args.model)
    sys.path.append(setting_path)
    setting = importlib.import_module(args.setting)

    sample_num = setting.sample_num
    num_class = setting.num_class

    # Prepare output folder
    output_folder = args.data_folder + 'pred_' + str(args.repeat_num)
    category_list = [
        (category, int(label_num))
        for (category,
             label_num) in [line.split() for line in open(args.category, 'r')]
    ]
    for category, _ in category_list:
        folder = os.path.join(output_folder, category)
        if not os.path.exists(folder):
            os.makedirs(folder)

    # prepare input pts path, output seg path, output ply path
    input_filelist = []
    output_filelist = []
    output_ply_filelist = []
    for category in sorted(os.listdir(args.data_folder)):
        data_category_folder = os.path.join(args.data_folder, category)
        for filename in sorted(os.listdir(data_category_folder)):
            input_filelist.append(
                os.path.join(args.data_folder, category, filename))
            output_filelist.append(
                os.path.join(output_folder, category, filename[0:-3] + 'seg'))
            output_ply_filelist.append(
                os.path.join(output_folder + '_ply', category,
                             filename[0:-3] + 'ply'))

    # Prepare inputs
    print('{}-Preparing datasets...'.format(datetime.now()))
    data, _, data_num, _ = data_utils.load_seg(args.filelist)

    batch_num = data.shape[0]
    max_point_num = data.shape[1]  # max number of points per cloud
    batch_size = args.repeat_num * math.ceil(data.shape[1] / sample_num)

    print('{}-{:d} testing batches.'.format(datetime.now(), batch_num))

    ######################################################################
    # Placeholders
    indices = tf.placeholder(tf.int32,
                             shape=(batch_size, None, 2),
                             name="indices")
    is_training = tf.placeholder(tf.bool, name='is_training')
    pts_fts = tf.placeholder(tf.float32,
                             shape=(batch_size, max_point_num,
                                    setting.data_dim),
                             name='points')
    ######################################################################

    ######################################################################
    pts_fts_sampled = tf.gather_nd(pts_fts,
                                   indices=indices,
                                   name='pts_fts_sampled')
    if setting.data_dim > 3:
        points_sampled, features_sampled = tf.split(
            pts_fts_sampled, [3, setting.data_dim - 3],
            axis=-1,
            name='split_points_features')
        if not setting.use_extra_features:
            features_sampled = None
    else:
        points_sampled = pts_fts_sampled
        features_sampled = None

    net = model.Net(points_sampled, features_sampled, num_class, is_training,
                    setting)
    _, seg_probs_op = net.logits, net.probs

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

    # saver for restoring the model
    saver = tf.train.Saver()

    parameter_num = np.sum(
        [np.prod(v.shape.as_list()) for v in tf.trainable_variables()])
    print('{}-Parameter number: {:d}.'.format(datetime.now(), parameter_num))

    with tf.Session() as sess:
        # Load the model
        saver.restore(sess, args.load_ckpt)
        print('{}-Checkpoint loaded from {}!'.format(datetime.now(),
                                                     args.load_ckpt))

        indices_batch_indices = np.tile(
            np.reshape(np.arange(batch_size), (batch_size, 1, 1)),
            (1, sample_num, 1))

        for batch_idx in range(batch_num):

            points_batch = data[[batch_idx] * batch_size, ...]
            point_num = data_num[batch_idx]

            coordinates = [[float(value) for value in xyz.split(' ')]
                           for xyz in open(input_filelist[batch_idx], 'r')
                           if len(xyz.split(' ')) == setting.data_dim]
            assert (point_num == len(coordinates))

            tile_num = math.ceil((sample_num * batch_size) / point_num)
            indices_shuffle = np.tile(np.arange(point_num),
                                      tile_num)[0:sample_num * batch_size]
            np.random.shuffle(indices_shuffle)
            indices_batch_shuffle = np.reshape(indices_shuffle,
                                               (batch_size, sample_num, 1))
            indices_batch = np.concatenate(
                (indices_batch_indices, indices_batch_shuffle), axis=2)

            _, seg_probs = \
                sess.run([update_ops, seg_probs_op],
                         feed_dict={
                             pts_fts: points_batch,
                             indices: indices_batch,
                             is_training: False,
                         })

            seg_probs_2d = np.reshape(seg_probs, (sample_num * batch_size, -1))

            predictions = [(-1, 0.0, [])] * point_num

            for idx in range(sample_num * batch_size):
                point_idx = indices_shuffle[idx]
                point_seg_probs = seg_probs_2d[idx, :]
                prob = np.amax(point_seg_probs)
                seg_idx = np.argmax(point_seg_probs)
                if prob > predictions[point_idx][1]:
                    predictions[point_idx] = [seg_idx, prob, point_seg_probs]

            labels = []
            with open(output_filelist[batch_idx], 'w') as file_seg:
                for seg_idx, prob, probs in predictions:
                    file_seg.write(str(int(seg_idx)))
                    file_seg.write("\n")
                    labels.append(seg_idx)

            if args.save_ply:
                data_utils.save_ply_property(np.array(coordinates),
                                             np.array(labels), 6,
                                             output_ply_filelist[batch_idx])

            print('{}-[Testing]-Iter: {:06d} saved to {}'.format(
                datetime.now(), batch_idx, output_filelist[batch_idx]))
            sys.stdout.flush()
            ######################################################################
        print('{}-Done!'.format(datetime.now()))
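
Because the input cloud is tiled and shuffled to fill sample_num * batch_size slots, each original point can be predicted several times; the loop above keeps, per point, the single prediction with the highest probability. A condensed sketch of that voting rule with hypothetical toy arrays (the real indices_shuffle and seg_probs_2d come from the code above):

# Max-probability voting over repeated samples of the same point; toy data only.
import numpy as np

point_num = 3
indices_shuffle = np.array([0, 2, 0, 1])      # which original point each sample maps to
seg_probs_2d = np.array([[0.2, 0.8],          # a copy of point 0
                         [0.9, 0.1],          # a copy of point 2
                         [0.6, 0.4],          # another copy of point 0 (lower confidence)
                         [0.3, 0.7]])         # a copy of point 1

predictions = [(-1, 0.0, [])] * point_num
for idx in range(len(indices_shuffle)):
    point_idx = indices_shuffle[idx]
    point_probs = seg_probs_2d[idx]
    prob, seg_idx = np.amax(point_probs), np.argmax(point_probs)
    if prob > predictions[point_idx][1]:      # keep the most confident prediction per point
        predictions[point_idx] = [seg_idx, prob, point_probs]

labels = [int(seg_idx) for seg_idx, _, _ in predictions]
print(labels)   # -> [1, 1, 0]
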
Code Example #5
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--filelist',
                        '-f',
                        help='Path to input .h5 filelist (.txt)',
                        required=True)
    parser.add_argument('--data_folder',
                        '-d',
                        help='Path to *.pts directory',
                        required=True)
    parser.add_argument('--load_ckpt',
                        '-l',
                        help='Path to a check point file for load',
                        required=True)
    parser.add_argument('--repeat_num',
                        '-r',
                        help='Repeat number',
                        type=int,
                        default=1)
    parser.add_argument('--sample_num',
                        help='Point sample num',
                        type=int,
                        default=1024)
    parser.add_argument('--model', '-m', help='Model to use', required=True)
    parser.add_argument('--setting',
                        '-x',
                        help='Setting to use',
                        required=True)

    args = parser.parse_args()
    print(args)

    model = importlib.import_module(args.model)
    sys.path.append(os.path.dirname(args.setting))
    setting = importlib.import_module(os.path.basename(args.setting))

    sample_num = setting.sample_num
    num_parts = setting.num_parts

    output_folder = os.path.abspath(os.path.join(
        args.data_folder, "..")) + '/pred_' + str(args.repeat_num)

    output_folder_seg = output_folder + '/seg/'

    # check the path
    if not os.path.exists(output_folder_seg):
        print(output_folder_seg, "does not exist; creating it")
        os.makedirs(output_folder_seg)

    input_filelist = []
    output_seg_filelist = []

    for filename in sorted(os.listdir(args.data_folder)):
        input_filelist.append(os.path.join(args.data_folder, filename))
        output_seg_filelist.append(
            os.path.join(output_folder_seg, filename[0:-3] + 'seg'))

    # Prepare inputs
    print('{}-Preparing datasets...'.format(datetime.now()))

    data, data_num, _label_seg = data_utils.load_seg(args.filelist)

    batch_num = data.shape[0]
    max_point_num = data.shape[1]
    batch_size = args.repeat_num * math.ceil(max_point_num / sample_num)

    print('{}-{:d} testing batches.'.format(datetime.now(), batch_num))

    ######################################################################
    # Placeholders
    indices = tf.placeholder(tf.int32,
                             shape=(batch_size, None, 2),
                             name="indices")
    is_training = tf.placeholder(tf.bool, name='is_training')

    pts_fts = tf.placeholder(tf.float32,
                             shape=(batch_size, max_point_num,
                                    setting.data_dim),
                             name='pts_fts')

    #######################################################################

    features_sampled = None

    if setting.data_dim > 3:
        points, _, features = tf.split(pts_fts, [
            setting.data_format["pts_xyz"], setting.data_format["img_xy"],
            setting.data_format["extra_features"]
        ],
                                       axis=-1,
                                       name='split_points_xy_features')

        if setting.use_extra_features:
            features_sampled = tf.gather_nd(features,
                                            indices=indices,
                                            name='features_sampled')
    else:
        points = pts_fts

    points_sampled = tf.gather_nd(points,
                                  indices=indices,
                                  name='points_sampled')

    net = model.Net(points_sampled, features_sampled, None, None, num_parts,
                    is_training, setting)

    probs_op = net.probs

    saver = tf.train.Saver()

    parameter_num = np.sum(
        [np.prod(v.shape.as_list()) for v in tf.trainable_variables()])
    print('{}-Parameter number: {:d}.'.format(datetime.now(), parameter_num))

    with tf.Session() as sess:
        # Load the model
        saver.restore(sess, args.load_ckpt)
        print('{}-Checkpoint loaded from {}!'.format(datetime.now(),
                                                     args.load_ckpt))

        indices_batch_indices = np.tile(
            np.reshape(np.arange(batch_size), (batch_size, 1, 1)),
            (1, sample_num, 1))

        for batch_idx in range(batch_num):

            points_batch = data[[batch_idx] * batch_size, ...]
            point_num = data_num[batch_idx]

            tile_num = math.ceil((sample_num * batch_size) / point_num)
            indices_shuffle = np.tile(np.arange(point_num),
                                      tile_num)[0:sample_num * batch_size]
            np.random.shuffle(indices_shuffle)
            indices_batch_shuffle = np.reshape(indices_shuffle,
                                               (batch_size, sample_num, 1))
            indices_batch = np.concatenate(
                (indices_batch_indices, indices_batch_shuffle), axis=2)

            sess_op_list = [probs_op]

            sess_feed_dict = {
                pts_fts: points_batch,
                indices: indices_batch,
                is_training: False
            }

            # sess run (note: this ConfigProto is created but unused here; GPU
            # options take effect only when passed to tf.Session(config=...))
            config = tf.ConfigProto()
            config.gpu_options.allow_growth = True
            probs = sess.run(sess_op_list, feed_dict=sess_feed_dict)

            # output seg probs
            probs_2d = np.reshape(probs, (sample_num * batch_size, -1))
            predictions = [(-1, 0.0, [])] * point_num

            for idx in range(sample_num * batch_size):
                point_idx = indices_shuffle[idx]
                point_probs = probs_2d[idx, :]
                prob = np.amax(point_probs)
                seg_idx = np.argmax(point_probs)
                if prob > predictions[point_idx][1]:
                    predictions[point_idx] = [seg_idx, prob, point_probs]

            with open(output_seg_filelist[batch_idx], 'w') as file_seg:
                for seg_idx, prob, probs in predictions:
                    file_seg.write(str(int(seg_idx)) + "\n")

            print('{}-[Testing]-Iter: {:06d} \nseg  saved to {}'.format(
                datetime.now(), batch_idx, output_seg_filelist[batch_idx]))

            sys.stdout.flush()
            ######################################################################
        print('{}-Done!'.format(datetime.now()))
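The loop above implements a simple per-point voting scheme: each original point of the cloud is copied into several of the `sample_num * batch_size` sampled slots, and the class with the highest softmax probability among all copies of a point wins. A minimal NumPy sketch of the same idea follows; `probs_2d` (one row of class probabilities per sampled slot) and `indices_shuffle` (the original point index of each slot) are hypothetical inputs mirroring the variables of the script above.

import numpy as np

def vote_per_point(probs_2d, indices_shuffle, point_num):
    # Keep, for every original point, the most confident prediction among all
    # of its sampled copies (sketch only, not the code of the example above).
    best_prob = np.zeros(point_num)
    best_label = np.full(point_num, -1, dtype=np.int64)
    for slot_idx, point_idx in enumerate(indices_shuffle):
        prob = probs_2d[slot_idx].max()
        if prob > best_prob[point_idx]:
            best_prob[point_idx] = prob
            best_label[point_idx] = probs_2d[slot_idx].argmax()
    return best_label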
Code Example #6
File: test_shapenet_seg.py Project: jtpils/BDCI2018
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--filelist',
                        '-f',
                        help='Path to input .h5 filelist (.txt)',
                        required=True)
    parser.add_argument('--category',
                        '-c',
                        help='Path to category list file (.txt)',
                        required=True)
    parser.add_argument('--data_folder',
                        '-d',
                        help='Path to *.pts directory',
                        required=True)
    parser.add_argument('--load_ckpt',
                        '-l',
                        help='Path to a check point file for load',
                        required=True)
    parser.add_argument('--repeat_num',
                        '-r',
                        help='Repeat number',
                        type=int,
                        default=1)
    parser.add_argument('--sample_num',
                        help='Point sample num',
                        type=int,
                        default=2048)
    parser.add_argument('--model', '-m', help='Model to use', required=True)
    parser.add_argument('--setting',
                        '-x',
                        help='Setting to use',
                        required=True)
    parser.add_argument('--save_ply',
                        '-s',
                        help='Save results as ply',
                        action='store_true')
    args = parser.parse_args()
    print(args)

    model = importlib.import_module(args.model)
    sys.path.append(os.path.dirname(args.setting))
    print(os.path.dirname(args.setting))
    setting = importlib.import_module(os.path.basename(args.setting))

    sample_num = setting.sample_num

    output_folder = args.data_folder + '_pred_nips_' + str(args.repeat_num)
    category_list = [
        (category, int(label_num))
        for (category,
             label_num) in [line.split() for line in open(args.category, 'r')]
    ]
    offset = 0
    category_range = dict()
    for category, category_label_seg_max in category_list:
        category_range[category] = (offset, offset + category_label_seg_max)
        offset = offset + category_label_seg_max
        folder = os.path.join(output_folder, category)
        if not os.path.exists(folder):
            os.makedirs(folder)

    input_filelist = []
    output_filelist = []
    output_ply_filelist = []
    for category in sorted(os.listdir(args.data_folder)):
        data_category_folder = os.path.join(args.data_folder, category)
        for filename in sorted(os.listdir(data_category_folder)):
            input_filelist.append(
                os.path.join(args.data_folder, category, filename))
            output_filelist.append(
                os.path.join(output_folder, category, filename[0:-3] + 'seg'))
            output_ply_filelist.append(
                os.path.join(output_folder + '_ply', category,
                             filename[0:-3] + 'ply'))

    # Prepare inputs
    print('{}-Preparing datasets...'.format(datetime.now()))
    data, label, data_num, _, _ = data_utils.load_seg(args.filelist)

    batch_num = data.shape[0]
    max_point_num = data.shape[1]
    batch_size = args.repeat_num * math.ceil(data.shape[1] / sample_num)

    print('{}-{:d} testing batches.'.format(datetime.now(), batch_num))

    ######################################################################
    # Placeholders
    indices = tf.placeholder(tf.int32,
                             shape=(batch_size, None, 2),
                             name="indices")
    is_training = tf.placeholder(tf.bool, name='is_training')
    pts_fts = tf.placeholder(tf.float32,
                             shape=(None, max_point_num, setting.data_dim),
                             name='pts_fts')
    ######################################################################

    ######################################################################
    pts_fts_sampled = tf.gather_nd(pts_fts,
                                   indices=indices,
                                   name='pts_fts_sampled')
    if setting.data_dim > 3:
        points_sampled, features_sampled = tf.split(
            pts_fts_sampled, [3, setting.data_dim - 3],
            axis=-1,
            name='split_points_features')
        if not setting.use_extra_features:
            features_sampled = None
    else:
        points_sampled = pts_fts_sampled
        features_sampled = None

    net = model.Net(points_sampled, features_sampled, is_training, setting)
    logits = net.logits
    probs_op = tf.nn.softmax(logits, name='probs')

    saver = tf.train.Saver()

    parameter_num = np.sum(
        [np.prod(v.shape.as_list()) for v in tf.trainable_variables()])
    print('{}-Parameter number: {:d}.'.format(datetime.now(), parameter_num))

    with tf.Session() as sess:
        # Load the model
        saver.restore(sess, args.load_ckpt)
        print('{}-Checkpoint loaded from {}!'.format(datetime.now(),
                                                     args.load_ckpt))

        indices_batch_indices = np.tile(
            np.reshape(np.arange(batch_size), (batch_size, 1, 1)),
            (1, sample_num, 1))
        for batch_idx in range(batch_num):
            points_batch = data[[batch_idx] * batch_size, ...]
            object_label = label[batch_idx]
            point_num = data_num[batch_idx]
            category = category_list[object_label][0]
            label_start, label_end = category_range[category]

            tile_num = math.ceil((sample_num * batch_size) / point_num)
            indices_shuffle = np.tile(np.arange(point_num),
                                      tile_num)[0:sample_num * batch_size]
            np.random.shuffle(indices_shuffle)
            indices_batch_shuffle = np.reshape(indices_shuffle,
                                               (batch_size, sample_num, 1))
            indices_batch = np.concatenate(
                (indices_batch_indices, indices_batch_shuffle), axis=2)

            probs = sess.run(
                [probs_op],
                feed_dict={
                    pts_fts: points_batch,
                    indices: indices_batch,
                    is_training: False,
                })
            probs_2d = np.reshape(probs, (sample_num * batch_size, -1))
            predictions = [(-1, 0.0)] * point_num
            for idx in range(sample_num * batch_size):
                point_idx = indices_shuffle[idx]
                probs = probs_2d[idx, label_start:label_end]
                confidence = np.amax(probs)
                seg_idx = np.argmax(probs)
                if confidence > predictions[point_idx][1]:
                    predictions[point_idx] = (seg_idx, confidence)

            labels = []
            with open(output_filelist[batch_idx], 'w') as file_seg:
                for seg_idx, _ in predictions:
                    file_seg.write('%d\n' % (seg_idx))
                    labels.append(seg_idx)

            # read the coordinates from the txt file for verification
            coordinates = [[float(value) for value in xyz.split(' ')]
                           for xyz in open(input_filelist[batch_idx], 'r')
                           if len(xyz.split(' ')) == 3]
            assert (point_num == len(coordinates))
            if args.save_ply:
                data_utils.save_ply_property(np.array(coordinates),
                                             np.array(labels), 6,
                                             output_ply_filelist[batch_idx])

            print('{}-[Testing]-Iter: {:06d} saved to {}'.format(
                datetime.now(), batch_idx, output_filelist[batch_idx]))
            sys.stdout.flush()
            ######################################################################
        print('{}-Done!'.format(datetime.now()))
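Because all ShapeNet categories share a single output layer, the script above never takes the argmax over the full probability vector: it only looks at the slice `probs_2d[idx, label_start:label_end]` that belongs to the object's own category, so a Chair point can never receive an Airplane part label. A small sketch of that restriction, with a hypothetical `category_range` and made-up probabilities:

import numpy as np

# Hypothetical offsets of each category into the global part-label vector.
category_range = {'Airplane': (0, 4), 'Chair': (4, 8)}

def predict_part(point_probs, category):
    # Return (local part index, confidence), restricted to the category's own labels.
    start, end = category_range[category]
    local = point_probs[start:end]
    seg_idx = int(np.argmax(local))
    return seg_idx, float(local[seg_idx])

point_probs = np.array([0.1, 0.2, 0.1, 0.1, 0.05, 0.3, 0.1, 0.05])
print(predict_part(point_probs, 'Chair'))  # -> (1, 0.3)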
Code Example #7
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--category',
                        '-c',
                        help='category name',
                        required=True)
    parser.add_argument('--level',
                        '-l',
                        type=int,
                        help='level id',
                        required=True)
    parser.add_argument('--load_ckpt',
                        '-k',
                        help='Path to a check point file for load')
    parser.add_argument('--model', '-m', help='Model to use', required=True)
    parser.add_argument('--setting',
                        '-x',
                        help='Setting to use',
                        required=True)
    args = parser.parse_args()

    args.save_folder = 'exps_results/'
    args.data_folder = '../../data/sem_seg_h5/'

    root_folder = os.path.join(
        args.save_folder,
        '%s_%s_%s_%d' % (args.model, args.setting, args.category, args.level))
    if os.path.exists(root_folder):
        print('ERROR: folder %s exists! Please check and delete it!' % root_folder)
        exit(1)
    os.makedirs(root_folder)

    flog = open(os.path.join(root_folder, 'log.txt'), 'w')

    def printout(d):
        flog.write(str(d) + '\n')
        print(d)

    printout('PID: %s' % os.getpid())

    printout(args)

    model = importlib.import_module(args.model)
    setting_path = os.path.join(os.path.dirname(__file__), args.model)
    sys.path.append(setting_path)
    setting = importlib.import_module(args.setting)

    num_epochs = setting.num_epochs
    batch_size = setting.batch_size
    sample_num = setting.sample_num
    step_val = setting.step_val
    rotation_range = setting.rotation_range
    rotation_range_val = setting.rotation_range_val
    scaling_range = setting.scaling_range
    scaling_range_val = setting.scaling_range_val
    jitter = setting.jitter
    jitter_val = setting.jitter_val

    args.filelist = os.path.join(args.data_folder,
                                 '%s-%d' % (args.category, args.level),
                                 'train_files.txt')
    args.filelist_val = os.path.join(args.data_folder,
                                     '%s-%d' % (args.category, args.level),
                                     'val_files.txt')

    # Load current category + level statistics
    with open(
            '../../stats/after_merging_label_ids/%s-level-%d.txt' %
        (args.category, args.level), 'r') as fin:
        setting.num_class = len(fin.readlines()) + 1  # with "other"
        printout('NUM CLASS: %d' % setting.num_class)

    label_weights_list = [1.0] * setting.num_class

    # Prepare inputs
    printout('{}-Preparing datasets...'.format(datetime.now()))
    data_train, data_num_train, label_train = data_utils.load_seg(
        args.filelist)
    data_val, data_num_val, label_val = data_utils.load_seg(args.filelist_val)

    # shuffle
    data_train, data_num_train, label_train = \
        data_utils.grouped_shuffle([data_train, data_num_train, label_train])

    num_train = data_train.shape[0]
    point_num = data_train.shape[1]
    num_val = data_val.shape[0]
    printout('{}-{:d}/{:d} training/validation samples.'.format(
        datetime.now(), num_train, num_val))
    batch_num = (num_train * num_epochs + batch_size - 1) // batch_size
    train_batch = num_train // batch_size
    printout('{}-{:d} training batches.'.format(datetime.now(), batch_num))
    batch_num_val = math.ceil(num_val / batch_size)
    printout('{}-{:d} testing batches per test.'.format(
        datetime.now(), batch_num_val))

    ######################################################################
    # Placeholders
    indices = tf.placeholder(tf.int32, shape=(None, None, 2), name="indices")
    xforms = tf.placeholder(tf.float32, shape=(None, 3, 3), name="xforms")
    rotations = tf.placeholder(tf.float32,
                               shape=(None, 3, 3),
                               name="rotations")
    jitter_range = tf.placeholder(tf.float32, shape=(1), name="jitter_range")
    global_step = tf.Variable(0, trainable=False, name='global_step')
    is_training = tf.placeholder(tf.bool, name='is_training')

    pts_fts = tf.placeholder(tf.float32,
                             shape=(None, point_num, setting.data_dim),
                             name='pts_fts')
    labels_seg = tf.placeholder(tf.int64,
                                shape=(None, point_num),
                                name='labels_seg')
    labels_weights = tf.placeholder(tf.float32,
                                    shape=(None, point_num),
                                    name='labels_weights')

    ######################################################################
    pts_fts_sampled = tf.gather_nd(pts_fts,
                                   indices=indices,
                                   name='pts_fts_sampled')
    features_augmented = None
    if setting.data_dim > 3:
        points_sampled, features_sampled = tf.split(
            pts_fts_sampled, [3, setting.data_dim - 3],
            axis=-1,
            name='split_points_features')
        if setting.use_extra_features:
            if setting.with_normal_feature:
                if setting.data_dim < 6:
                    printout('Only 3D normals are supported!')
                    exit()
                elif setting.data_dim == 6:
                    features_augmented = pf.augment(features_sampled,
                                                    rotations)
                else:
                    normals, rest = tf.split(features_sampled,
                                             [3, setting.data_dim - 6],
                                             axis=-1)
                    normals_augmented = pf.augment(normals, rotations)
                    features_augmented = tf.concat([normals_augmented, rest],
                                                   axis=-1)
            else:
                features_augmented = features_sampled
    else:
        points_sampled = pts_fts_sampled
    points_augmented = pf.augment(points_sampled, xforms, jitter_range)

    labels_sampled = tf.gather_nd(labels_seg,
                                  indices=indices,
                                  name='labels_sampled')
    labels_weights_sampled = tf.gather_nd(labels_weights,
                                          indices=indices,
                                          name='labels_weight_sampled')

    net = model.Net(points_augmented, features_augmented, is_training, setting)
    logits = net.logits
    probs = tf.nn.softmax(logits, name='probs')
    predictions = tf.argmax(probs, axis=-1, name='predictions')

    loss_op = tf.losses.sparse_softmax_cross_entropy(
        labels=labels_sampled, logits=logits, weights=labels_weights_sampled)

    with tf.name_scope('metrics'):
        loss_mean_op, loss_mean_update_op = tf.metrics.mean(loss_op)
        t_1_acc_op, t_1_acc_update_op = tf.metrics.accuracy(
            labels_sampled, predictions, weights=labels_weights_sampled)
        t_1_per_class_acc_op, t_1_per_class_acc_update_op = \
            tf.metrics.mean_per_class_accuracy(labels_sampled, predictions, setting.num_class,
                                               weights=labels_weights_sampled)
    reset_metrics_op = tf.variables_initializer([
        var for var in tf.local_variables()
        if var.name.split('/')[0] == 'metrics'
    ])

    _ = tf.summary.scalar('loss/train',
                          tensor=loss_mean_op,
                          collections=['train'])
    _ = tf.summary.scalar('t_1_acc/train',
                          tensor=t_1_acc_op,
                          collections=['train'])
    _ = tf.summary.scalar('t_1_per_class_acc/train',
                          tensor=t_1_per_class_acc_op,
                          collections=['train'])

    _ = tf.summary.scalar('loss/val', tensor=loss_mean_op, collections=['val'])
    _ = tf.summary.scalar('t_1_acc/val',
                          tensor=t_1_acc_op,
                          collections=['val'])
    _ = tf.summary.scalar('t_1_per_class_acc/val',
                          tensor=t_1_per_class_acc_op,
                          collections=['val'])

    lr_exp_op = tf.train.exponential_decay(setting.learning_rate_base,
                                           global_step,
                                           setting.decay_steps,
                                           setting.decay_rate,
                                           staircase=True)
    lr_clip_op = tf.maximum(lr_exp_op, setting.learning_rate_min)
    _ = tf.summary.scalar('learning_rate',
                          tensor=lr_clip_op,
                          collections=['train'])
    reg_loss = setting.weight_decay * tf.losses.get_regularization_loss()
    if setting.optimizer == 'adam':
        optimizer = tf.train.AdamOptimizer(learning_rate=lr_clip_op,
                                           epsilon=setting.epsilon)
    elif setting.optimizer == 'momentum':
        optimizer = tf.train.MomentumOptimizer(learning_rate=lr_clip_op,
                                               momentum=setting.momentum,
                                               use_nesterov=True)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = optimizer.minimize(loss_op + reg_loss,
                                      global_step=global_step)

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())

    saver = tf.train.Saver(max_to_keep=None)

    folder_ckpt = os.path.join(root_folder, 'ckpts')
    if not os.path.exists(folder_ckpt):
        os.makedirs(folder_ckpt)

    folder_summary = os.path.join(root_folder, 'summary')
    if not os.path.exists(folder_summary):
        os.makedirs(folder_summary)

    parameter_num = np.sum(
        [np.prod(v.shape.as_list()) for v in tf.trainable_variables()])
    printout('{}-Parameter number: {:d}.'.format(datetime.now(),
                                                 parameter_num))

    # Create a session
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    config.log_device_placement = False
    sess = tf.Session(config=config)

    summaries_op = tf.summary.merge_all('train')
    summaries_val_op = tf.summary.merge_all('val')
    summary_writer = tf.summary.FileWriter(folder_summary, sess.graph)

    sess.run(init_op)

    # Load the model
    if args.load_ckpt is not None:
        saver.restore(sess, args.load_ckpt)
        printout('{}-Checkpoint loaded from {}!'.format(
            datetime.now(), args.load_ckpt))

    for batch_idx_train in range(batch_num):
        if (batch_idx_train % (10 * train_batch) == 0 and (batch_idx_train != 0 or args.load_ckpt is not None)) \
                or batch_idx_train == batch_num - 1:
            ######################################################################
            # Validation
            filename_ckpt = os.path.join(folder_ckpt, 'iter')
            saver.save(sess, filename_ckpt, global_step=global_step)
            printout('{}-Checkpoint saved to {}!'.format(
                datetime.now(), filename_ckpt))

            sess.run(reset_metrics_op)
            for batch_val_idx in range(batch_num_val):
                start_idx = batch_size * batch_val_idx
                end_idx = min(start_idx + batch_size, num_val)
                batch_size_val = end_idx - start_idx
                points_batch = data_val[start_idx:end_idx, ...]
                points_num_batch = data_num_val[start_idx:end_idx, ...]
                labels_batch = label_val[start_idx:end_idx, ...]
                weights_batch = np.array(label_weights_list)[labels_batch]

                xforms_np, rotations_np = pf.get_xforms(
                    batch_size_val,
                    rotation_range=rotation_range_val,
                    scaling_range=scaling_range_val,
                    order=setting.rotation_order)
                sess.run(
                    [
                        loss_mean_update_op, t_1_acc_update_op,
                        t_1_per_class_acc_update_op
                    ],
                    feed_dict={
                        pts_fts: points_batch,
                        indices: pf.get_indices(batch_size_val, sample_num,
                                                points_num_batch),
                        xforms: xforms_np,
                        rotations: rotations_np,
                        jitter_range: np.array([jitter_val]),
                        labels_seg: labels_batch,
                        labels_weights: weights_batch,
                        is_training: False,
                    })

            loss_val, t_1_acc_val, t_1_per_class_acc_val, summaries_val = sess.run(
                [
                    loss_mean_op, t_1_acc_op, t_1_per_class_acc_op,
                    summaries_val_op
                ])
            summary_writer.add_summary(summaries_val, batch_idx_train)
            printout(
                '{}-[Val  ]-Average:      Loss: {:.4f}  T-1 Acc: {:.4f}  T-1 mAcc: {:.4f}'
                .format(datetime.now(), loss_val, t_1_acc_val,
                        t_1_per_class_acc_val))
            flog.flush()
            ######################################################################

        ######################################################################
        # Training
        start_idx = (batch_size * batch_idx_train) % num_train
        end_idx = min(start_idx + batch_size, num_train)
        batch_size_train = end_idx - start_idx
        points_batch = data_train[start_idx:end_idx, ...]
        points_num_batch = data_num_train[start_idx:end_idx, ...]
        labels_batch = label_train[start_idx:end_idx, ...]
        weights_batch = np.array(label_weights_list)[labels_batch]

        if start_idx + batch_size_train == num_train:
            data_train, data_num_train, label_train = \
                data_utils.grouped_shuffle([data_train, data_num_train, label_train])

        offset = int(random.gauss(0, sample_num * setting.sample_num_variance))
        offset = max(offset, -sample_num * setting.sample_num_clip)
        offset = min(offset, sample_num * setting.sample_num_clip)
        sample_num_train = sample_num + offset
        xforms_np, rotations_np = pf.get_xforms(batch_size_train,
                                                rotation_range=rotation_range,
                                                scaling_range=scaling_range,
                                                order=setting.rotation_order)
        sess.run(reset_metrics_op)
        sess.run(
            [
                train_op, loss_mean_update_op, t_1_acc_update_op,
                t_1_per_class_acc_update_op
            ],
            feed_dict={
                pts_fts: points_batch,
                indices: pf.get_indices(batch_size_train, sample_num_train,
                                        points_num_batch),
                xforms: xforms_np,
                rotations: rotations_np,
                jitter_range: np.array([jitter]),
                labels_seg: labels_batch,
                labels_weights: weights_batch,
                is_training: True,
            })
        if batch_idx_train % 10 == 0:
            loss, t_1_acc, t_1_per_class_acc, summaries = sess.run(
                [loss_mean_op, t_1_acc_op, t_1_per_class_acc_op, summaries_op])
            summary_writer.add_summary(summaries, batch_idx_train)
            printout(
                '{}-[Train]-Iter: {:06d}  Loss: {:.4f}  T-1 Acc: {:.4f}  T-1 mAcc: {:.4f}'
                .format(datetime.now(), batch_idx_train, loss, t_1_acc,
                        t_1_per_class_acc))
            flog.flush()
        ######################################################################
    printout('{}-Done!'.format(datetime.now()))
    flog.close()
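All of the training scripts in this collection build their learning-rate schedule the same way: `tf.train.exponential_decay` with `staircase=True`, clipped from below with `tf.maximum(..., setting.learning_rate_min)`. A pure-Python sketch of the resulting schedule; the parameter names follow the `setting` module, and the default values here are made up:

def clipped_staircase_lr(step,
                         learning_rate_base=0.005,  # setting.learning_rate_base (example value)
                         decay_steps=5000,          # setting.decay_steps (example value)
                         decay_rate=0.8,            # setting.decay_rate (example value)
                         learning_rate_min=1e-5):   # setting.learning_rate_min (example value)
    # staircase=True means the exponent only changes every decay_steps steps.
    lr = learning_rate_base * decay_rate ** (step // decay_steps)
    return max(lr, learning_rate_min)

for step in (0, 5000, 50000, 500000):
    print(step, clipped_staircase_lr(step))  # the last value is clipped to learning_rate_min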
Code Example #8
File: train_val_seg.py Project: zzzzdz/shellnet
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--load_ckpt', '-l', help='Path to a check point file for load')
    parser.add_argument('--save_folder', '-s', default='log/seg', help='Path to folder for saving check points and summary')
    parser.add_argument('--model', '-m', default='shellconv', help='Model to use')
    parser.add_argument('--setting', '-x', default='seg_s3dis', help='Setting to use')
    parser.add_argument('--log', help='Log to FILE in save folder; use - for stdout (default is log.txt)', metavar='FILE', default='log.txt')
    args = parser.parse_args()

    time_string = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    root_folder = os.path.join(args.save_folder, '%s_%s_%s' % (args.model, args.setting, time_string))
    
    if not os.path.exists(root_folder):
        os.makedirs(root_folder)

    global LOG_FOUT
    if args.log != '-':
        LOG_FOUT = open(os.path.join(root_folder, args.log), 'w')

    model = importlib.import_module(args.model)
    setting_path = os.path.join(os.path.dirname(__file__), args.model)
    sys.path.append(setting_path)
    setting = importlib.import_module(args.setting)

    num_epochs = setting.num_epochs
    batch_size = setting.batch_size
    sample_num = setting.sample_num
    step_val = setting.step_val
    label_weights_list = setting.label_weights
    rotation_range = setting.rotation_range
    rotation_range_val = setting.rotation_range_val
    scaling_range = setting.scaling_range
    scaling_range_val = setting.scaling_range_val
    jitter = setting.jitter
    jitter_val = setting.jitter_val    

    is_list_of_h5_list = data_utils.is_h5_list(setting.filelist)
    if is_list_of_h5_list:
        seg_list = [setting.filelist] # for train
    else:
        seg_list = data_utils.load_seg_list(setting.filelist)  # for train
    data_val, _, data_num_val, label_val, _ = data_utils.load_seg(setting.filelist_val)
    if data_val.shape[-1] > 3:    
        data_val = data_val[:,:,:3]  # only use the xyz coordinates
    point_num = data_val.shape[1]
    num_val = data_val.shape[0]
    batch_num_val = num_val // batch_size

    ######################################################################
    # Placeholders
    indices = tf.placeholder(tf.int32, shape=(None, sample_num, 2), name="indices")
    xforms = tf.placeholder(tf.float32, shape=(None, 3, 3), name="xforms")
    rotations = tf.placeholder(tf.float32, shape=(None, 3, 3), name="rotations")
    jitter_range = tf.placeholder(tf.float32, shape=(1), name="jitter_range")
    global_step = tf.Variable(0, trainable=False, name='global_step')
    is_training = tf.placeholder(tf.bool, name='is_training')

    pts_fts = tf.placeholder(tf.float32, shape=(None, point_num, setting.data_dim), name='pts_fts')
    labels_seg = tf.placeholder(tf.int64, shape=(None, point_num), name='labels_seg')
    labels_weights = tf.placeholder(tf.float32, shape=(None, point_num), name='labels_weights')

    ######################################################################
    points_sampled = tf.gather_nd(pts_fts, indices=indices, name='pts_fts_sampled')
    points_augmented = pf.augment(points_sampled, xforms, jitter_range)

    labels_sampled = tf.gather_nd(labels_seg, indices=indices, name='labels_sampled')
    labels_weights_sampled = tf.gather_nd(labels_weights, indices=indices, name='labels_weight_sampled')

    bn_decay_exp_op = tf.train.exponential_decay(0.5, global_step, setting.decay_steps,
                                           0.5, staircase=True)
    bn_decay_op = tf.minimum(0.99, 1 - bn_decay_exp_op)

    logits_op = model.get_model(points_augmented, is_training, setting.sconv_params, setting.sdconv_params, setting.fc_params, 
                            sampling=setting.sampling,
                            weight_decay=setting.weight_decay, 
                            bn_decay = bn_decay_op, 
                            part_num=setting.num_class)

    predictions = tf.argmax(logits_op, axis=-1, name='predictions')

    loss_op = tf.losses.sparse_softmax_cross_entropy(labels=labels_sampled, logits=logits_op,
                                                     weights=labels_weights_sampled)

    with tf.name_scope('metrics'):
        loss_mean_op, loss_mean_update_op = tf.metrics.mean(loss_op)
        t_1_acc_op, t_1_acc_update_op = tf.metrics.accuracy(labels_sampled, predictions, weights=labels_weights_sampled)
        t_1_per_class_acc_op, t_1_per_class_acc_update_op = \
            tf.metrics.mean_per_class_accuracy(labels_sampled, predictions, setting.num_class,
                                               weights=labels_weights_sampled)
    reset_metrics_op = tf.variables_initializer([var for var in tf.local_variables()
                                                 if var.name.split('/')[0] == 'metrics'])


    _ = tf.summary.scalar('loss/train', tensor=loss_mean_op, collections=['train'])
    _ = tf.summary.scalar('t_1_acc/train', tensor=t_1_acc_op, collections=['train'])
    _ = tf.summary.scalar('t_1_per_class_acc/train', tensor=t_1_per_class_acc_op, collections=['train'])

    _ = tf.summary.scalar('loss/val', tensor=loss_mean_op, collections=['val'])
    _ = tf.summary.scalar('t_1_acc/val', tensor=t_1_acc_op, collections=['val'])
    _ = tf.summary.scalar('t_1_per_class_acc/val', tensor=t_1_per_class_acc_op, collections=['val'])

    lr_exp_op = tf.train.exponential_decay(setting.learning_rate_base, global_step, setting.decay_steps,
                                           setting.decay_rate, staircase=True)
                                           
    lr_clip_op = tf.maximum(lr_exp_op, setting.learning_rate_min)
    _ = tf.summary.scalar('learning_rate', tensor=lr_clip_op, collections=['train'])

    reg_loss = setting.weight_decay * tf.losses.get_regularization_loss()
    if setting.optimizer == 'adam':
        optimizer = tf.train.AdamOptimizer(learning_rate=lr_clip_op)
    elif setting.optimizer == 'momentum':
        optimizer = tf.train.MomentumOptimizer(learning_rate=lr_clip_op, momentum=setting.momentum, use_nesterov=True)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = optimizer.minimize(loss_op + reg_loss, global_step=global_step)

    saver = tf.train.Saver(max_to_keep=None)

    folder_ckpt = os.path.join(root_folder, 'ckpts')
    if not os.path.exists(folder_ckpt):
        os.makedirs(folder_ckpt)

    folder_summary = os.path.join(root_folder, 'summary')
    if not os.path.exists(folder_summary):
        os.makedirs(folder_summary)

    parameter_num = np.sum([np.prod(v.shape.as_list()) for v in tf.trainable_variables()])
    print('{}-Parameter number: {:d}.'.format(datetime.now(), parameter_num))

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    config.log_device_placement = False
    with tf.Session(config=config) as sess:
        summaries_op = tf.summary.merge_all('train')
        summaries_val_op = tf.summary.merge_all('val')
        summary_writer = tf.summary.FileWriter(folder_summary, sess.graph)

        sess.run(tf.global_variables_initializer())

        # Load the model
        if args.load_ckpt is not None:
            saver.restore(sess, args.load_ckpt)
            print('{}-Checkpoint loaded from {}!'.format(datetime.now(), args.load_ckpt))
        else:
            latest_ckpt = tf.train.latest_checkpoint(folder_ckpt)
            if latest_ckpt:
                print('{}-Found checkpoint {}'.format(datetime.now(), latest_ckpt))
                saver.restore(sess, latest_ckpt)
                print('{}-Checkpoint loaded from {} (Iter {})'.format(
                    datetime.now(), latest_ckpt, sess.run(global_step)))

        best_acc = 0
        best_epoch = 0
        for epoch in range(num_epochs):
            ############################### train #######################################
            # Shuffle train files
            np.random.shuffle(seg_list)
            for file_idx_train in range(len(seg_list)):
                print('----epoch:'+str(epoch) + '--train file:' + str(file_idx_train) + '-----')
                filelist_train = seg_list[file_idx_train]
                data_train, _, data_num_train, label_train, _ = data_utils.load_seg(filelist_train)
                num_train = data_train.shape[0]
                if data_train.shape[-1] > 3:    
                    data_train = data_train[:,:,:3]
                data_train, data_num_train, label_train = \
                    data_utils.grouped_shuffle([data_train, data_num_train, label_train])
                # data_train, label_train, _ = provider.shuffle_data_seg(data_train, label_train) 

                batch_num = (num_train + batch_size - 1) // batch_size

                for batch_idx_train in range(batch_num):
                    # Training
                    start_idx = (batch_size * batch_idx_train) % num_train
                    end_idx = min(start_idx + batch_size, num_train)
                    batch_size_train = end_idx - start_idx
                    points_batch = data_train[start_idx:end_idx, ...]
                    points_num_batch = data_num_train[start_idx:end_idx, ...]
                    labels_batch = label_train[start_idx:end_idx, ...]
                    weights_batch = np.array(label_weights_list)[labels_batch]

                    offset = int(random.gauss(0, sample_num * setting.sample_num_variance))
                    offset = max(offset, -sample_num * setting.sample_num_clip)
                    offset = min(offset, sample_num * setting.sample_num_clip)
                    sample_num_train = sample_num + offset
                    xforms_np, rotations_np = pf.get_xforms(batch_size_train,
                                                            rotation_range=rotation_range,
                                                            scaling_range=scaling_range,
                                                            order=setting.rotation_order)
                    sess.run(reset_metrics_op)
                    sess.run([train_op, loss_mean_update_op, t_1_acc_update_op, t_1_per_class_acc_update_op],
                            feed_dict={
                                pts_fts: points_batch,
                                indices: pf.get_indices(batch_size_train, sample_num_train, points_num_batch),
                                xforms: xforms_np,
                                rotations: rotations_np,
                                jitter_range: np.array([jitter]),
                                labels_seg: labels_batch,
                                labels_weights: weights_batch,
                                is_training: True,
                            })
                
                loss, t_1_acc, t_1_per_class_acc, summaries, step = sess.run([loss_mean_op,
                                                                        t_1_acc_op,
                                                                        t_1_per_class_acc_op,
                                                                        summaries_op,
                                                                        global_step])
                summary_writer.add_summary(summaries, step)
                log_string('{}-[Train]-Iter: {:06d}  Loss: {:.4f}  T-1 Acc: {:.4f}  T-1 mAcc: {:.4f}'
                    .format(datetime.now(), step, loss, t_1_acc, t_1_per_class_acc))
                sys.stdout.flush()
                ######################################################################
        
            filename_ckpt = os.path.join(folder_ckpt, 'epoch')
            saver.save(sess, filename_ckpt, global_step=epoch)
            print('{}-Checkpoint saved to {}!'.format(datetime.now(), filename_ckpt))

            sess.run(reset_metrics_op)
            for batch_val_idx in range(batch_num_val):
                start_idx = batch_size * batch_val_idx
                end_idx = min(start_idx + batch_size, num_val)
                batch_size_val = end_idx - start_idx
                points_batch = data_val[start_idx:end_idx, ...]
                points_num_batch = data_num_val[start_idx:end_idx, ...]
                labels_batch = label_val[start_idx:end_idx, ...]
                weights_batch = np.array(label_weights_list)[labels_batch]

                xforms_np, rotations_np = pf.get_xforms(batch_size_val,
                                                            rotation_range=rotation_range_val,
                                                            scaling_range=scaling_range_val,
                                                            order=setting.rotation_order)
                sess.run([loss_mean_update_op, t_1_acc_update_op, t_1_per_class_acc_update_op],
                        feed_dict={
                            pts_fts: points_batch,
                            indices: pf.get_indices(batch_size_val, sample_num, points_num_batch),
                            xforms: xforms_np,
                            rotations: rotations_np,
                            jitter_range: np.array([jitter_val]),
                            labels_seg: labels_batch,
                            labels_weights: weights_batch,
                            is_training: False,
                        })
            loss_val, t_1_acc_val, t_1_per_class_acc_val, summaries_val, step = sess.run(
                [loss_mean_op, t_1_acc_op, t_1_per_class_acc_op, summaries_val_op, global_step])
            summary_writer.add_summary(summaries_val, step)

            if t_1_per_class_acc_val > best_acc:
                best_acc = t_1_per_class_acc_val
                best_epoch = epoch

            log_string('{}-[Val  ]-Average:      Loss: {:.4f}  T-1 Acc: {:.4f}  T-1 mAcc: {:.4f} best epoch: {} Current epoch: {}'
                  .format(datetime.now(), loss_val, t_1_acc_val, t_1_per_class_acc_val, best_epoch, epoch))
            sys.stdout.flush()
            ######################################################################
            
        print('{}-Done!'.format(datetime.now()))
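During training, this script and the PointCNN-style scripts above perturb the number of points sampled per batch: a Gaussian offset with standard deviation `sample_num * setting.sample_num_variance` is drawn and then clipped to `±sample_num * setting.sample_num_clip`. A small sketch, with made-up values for the two `setting` fields:

import random

def jittered_sample_num(sample_num,
                        sample_num_variance=0.125,  # setting.sample_num_variance (example value)
                        sample_num_clip=0.25):      # setting.sample_num_clip (example value)
    # Draw a Gaussian offset and clip it, as in the training loops above.
    offset = int(random.gauss(0, sample_num * sample_num_variance))
    offset = max(offset, -int(sample_num * sample_num_clip))
    offset = min(offset, int(sample_num * sample_num_clip))
    return sample_num + offset

print(jittered_sample_num(2048))  # some value in [1536, 2560]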
Code Example #9
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--filelist',
                        '-t',
                        help='Path to training set ground truth (.txt)',
                        required=True)
    parser.add_argument('--filelist_val',
                        '-v',
                        help='Path to validation set ground truth (.txt)',
                        required=True)
    parser.add_argument('--load_ckpt',
                        '-l',
                        help='Path to a check point file for load')
    parser.add_argument(
        '--save_folder',
        '-s',
        help='Path to folder for saving check points and summary',
        required=True)
    parser.add_argument('--model', '-m', help='Model to use', required=True)
    parser.add_argument('--setting',
                        '-x',
                        help='Setting to use',
                        required=True)

    parser.add_argument('--startpoint',
                        '-b',
                        help='Start point',
                        required=True)

    args = parser.parse_args()

    time_string = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    root_folder = os.path.join(
        args.save_folder,
        '%s_%s_%s_%d' % (args.model, args.setting, time_string, os.getpid()))
    if not os.path.exists(root_folder):
        os.makedirs(root_folder)

    #sys.stdout = open(os.path.join(root_folder, 'log.txt'), 'w')

    print('PID:', os.getpid())

    print(args)

    model = importlib.import_module(args.model)
    setting_path = os.path.join(os.path.dirname(__file__), args.model)
    sys.path.append(setting_path)
    setting = importlib.import_module(args.setting)

    num_epochs = setting.num_epochs
    batch_size = setting.batch_size
    sample_num = setting.sample_num
    step_val = setting.step_val
    label_weights_list = setting.label_weights
    label_weights_val = [1.0] * 1 + [1.0] * (setting.num_class - 1)

    rotation_range = setting.rotation_range
    rotation_range_val = setting.rotation_range_val
    scaling_range = setting.scaling_range
    scaling_range_val = setting.scaling_range_val
    jitter = setting.jitter
    jitter_val = setting.jitter_val

    # Prepare inputs
    print('{}-Preparing datasets...'.format(datetime.now()))
    is_list_of_h5_list = not data_utils.is_h5_list(args.filelist)
    if is_list_of_h5_list:
        seg_list = data_utils.load_seg_list(args.filelist)
        seg_list_idx = 0
        filelist_train = seg_list[seg_list_idx]
        seg_list_idx = seg_list_idx + 1
    else:
        filelist_train = args.filelist
    #data_train, _, data_num_train, label_train, _ = data_utils.load_seg(filelist_train)
    data_val, _, data_num_val, label_val, _ = data_utils.load_seg(
        args.filelist_val)

    data_train = data_val
    data_num_train = data_num_val
    label_train = label_val

    # shuffle
    data_train, data_num_train, label_train = \
        data_utils.grouped_shuffle([data_train, data_num_train, label_train])

    num_train = data_train.shape[0]
    point_num = data_train.shape[1]
    num_val = data_val.shape[0]
    print('{}-{:d}/{:d} training/validation samples.'.format(
        datetime.now(), num_train, num_val))
    batch_num = (num_train * num_epochs + batch_size - 1) // batch_size
    print('{}-{:d} training batches.'.format(datetime.now(), batch_num))
    batch_num_val = int(math.ceil(num_val / batch_size))
    print('{}-{:d} testing batches per test.'.format(datetime.now(),
                                                     batch_num_val))

    ######################

    folder_summary = os.path.join(root_folder, 'summary')
    if not os.path.exists(folder_summary):
        os.makedirs(folder_summary)

    ######################################################################
    # Placeholders
    indices = tf.placeholder(tf.int32, shape=(None, None, 2), name="indices")
    xforms = tf.placeholder(tf.float32, shape=(None, 3, 3), name="xforms")
    rotations = tf.placeholder(tf.float32,
                               shape=(None, 3, 3),
                               name="rotations")
    jitter_range = tf.placeholder(tf.float32, shape=(1), name="jitter_range")
    global_step = tf.Variable(0, trainable=False, name='global_step')
    is_training = tf.placeholder(tf.bool, name='is_training')

    pts_fts = tf.placeholder(tf.float32,
                             shape=(None, point_num, setting.data_dim),
                             name='pts_fts')
    labels_seg = tf.placeholder(tf.int64,
                                shape=(None, point_num),
                                name='labels_seg')
    labels_weights = tf.placeholder(tf.float32,
                                    shape=(None, point_num),
                                    name='labels_weights')

    ######################################################################
    pts_fts_sampled = tf.gather_nd(pts_fts,
                                   indices=indices,
                                   name='pts_fts_sampled')
    features_augmented = None
    if setting.data_dim > 3:
        points_sampled, features_sampled = tf.split(
            pts_fts_sampled, [3, setting.data_dim - 3],
            axis=-1,
            name='split_points_features')
        if setting.use_extra_features:
            if setting.with_normal_feature:
                if setting.data_dim < 6:
                    print('Only 3D normals are supported!')
                    exit()
                elif setting.data_dim == 6:
                    features_augmented = pf.augment(features_sampled,
                                                    rotations)
                else:
                    normals, rest = tf.split(features_sampled,
                                             [3, setting.data_dim - 6],
                                             axis=-1)
                    normals_augmented = pf.augment(normals, rotations)
                    features_augmented = tf.concat([normals_augmented, rest],
                                                   axis=-1)
            else:
                features_augmented = features_sampled
    else:
        points_sampled = pts_fts_sampled
    points_augmented = pf.augment(points_sampled, xforms, jitter_range)

    labels_sampled = tf.gather_nd(labels_seg,
                                  indices=indices,
                                  name='labels_sampled')
    labels_weights_sampled = tf.gather_nd(labels_weights,
                                          indices=indices,
                                          name='labels_weight_sampled')

    net = model.Net(points_augmented, features_augmented, is_training, setting)
    logits = net.logits
    probs = tf.nn.softmax(logits, name='probs')
    predictions = tf.argmax(probs, axis=-1, name='predictions')

    loss_op = tf.losses.sparse_softmax_cross_entropy(
        labels=labels_sampled, logits=logits, weights=labels_weights_sampled)

    with tf.name_scope('metrics'):
        loss_mean_op, loss_mean_update_op = tf.metrics.mean(loss_op)
        t_1_acc_op, t_1_acc_update_op = tf.metrics.accuracy(
            labels_sampled, predictions, weights=labels_weights_sampled)
        t_1_per_class_acc_op, t_1_per_class_acc_update_op = \
            tf.metrics.mean_per_class_accuracy(labels_sampled, predictions, setting.num_class,
                                               weights=labels_weights_sampled)
        t_1_per_mean_iou_op, t_1_per_mean_iou_op_update_op = \
            tf.metrics.mean_iou(labels_sampled, predictions, setting.num_class,
                                               weights=labels_weights_sampled)

    reset_metrics_op = tf.variables_initializer([
        var for var in tf.local_variables()
        if var.name.split('/')[0] == 'metrics'
    ])

    _ = tf.summary.scalar('loss/train',
                          tensor=loss_mean_op,
                          collections=['train'])
    _ = tf.summary.scalar('t_1_acc/train',
                          tensor=t_1_acc_op,
                          collections=['train'])
    _ = tf.summary.scalar('t_1_per_class_acc/train',
                          tensor=t_1_per_class_acc_op,
                          collections=['train'])

    _ = tf.summary.scalar('loss/val', tensor=loss_mean_op, collections=['val'])
    _ = tf.summary.scalar('t_1_acc/val',
                          tensor=t_1_acc_op,
                          collections=['val'])
    _ = tf.summary.scalar('t_1_per_class_acc/val',
                          tensor=t_1_per_class_acc_op,
                          collections=['val'])
    _ = tf.summary.scalar('t_1_mean_iou/val',
                          tensor=t_1_per_mean_iou_op,
                          collections=['val'])

    #_ = tf.summary.histogram('summary/Add_F2', Add_F2, collections=['summary_values'])
    #_ = tf.summary.histogram('summary/Add_F3', Add_F3, collections=['summary_values'])
    #_ = tf.summary.histogram('summary/Z', Z, collections=['summary_values'])

    lr_exp_op = tf.train.exponential_decay(setting.learning_rate_base,
                                           global_step,
                                           setting.decay_steps,
                                           setting.decay_rate,
                                           staircase=True)
    lr_clip_op = tf.maximum(lr_exp_op, setting.learning_rate_min)
    _ = tf.summary.scalar('learning_rate',
                          tensor=lr_clip_op,
                          collections=['train'])
    reg_loss = setting.weight_decay * tf.losses.get_regularization_loss()
    if setting.optimizer == 'adam':
        optimizer = tf.train.AdamOptimizer(learning_rate=lr_clip_op,
                                           epsilon=setting.epsilon)
    elif setting.optimizer == 'momentum':
        optimizer = tf.train.MomentumOptimizer(learning_rate=lr_clip_op,
                                               momentum=setting.momentum,
                                               use_nesterov=True)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = optimizer.minimize(loss_op + reg_loss,
                                      global_step=global_step)

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())

    saver = tf.train.Saver(max_to_keep=5)

    saver_best = tf.train.Saver(max_to_keep=5)
    # backup all code
    code_folder = os.path.abspath(os.path.dirname(__file__))
    shutil.copytree(code_folder,
                    os.path.join(root_folder, os.path.basename(code_folder)))

    folder_ckpt = os.path.join(root_folder, 'ckpts')
    if not os.path.exists(folder_ckpt):
        os.makedirs(folder_ckpt)

    folder_ckpt_best = os.path.join(root_folder, 'ckpt-best')
    if not os.path.exists(folder_ckpt_best):
        os.makedirs(folder_ckpt_best)

    folder_summary = os.path.join(root_folder, 'summary')
    if not os.path.exists(folder_summary):
        os.makedirs(folder_summary)

    folder_cm_matrix = os.path.join(root_folder, 'cm-matrix')
    if not os.path.exists(folder_cm_matrix):
        os.makedirs(folder_cm_matrix)

    parameter_num = np.sum(
        [np.prod(v.shape.as_list()) for v in tf.trainable_variables()])
    print('{}-Parameter number: {:d}.'.format(datetime.now(), parameter_num))

    _highest_val = 0.0
    max_val = 10
    with tf.Session() as sess:
        summaries_op = tf.summary.merge_all('train')
        summaries_val_op = tf.summary.merge_all('val')
        summary_values_op = tf.summary.merge_all('summary_values')
        summary_writer = tf.summary.FileWriter(folder_summary, sess.graph)

        #img_d_summary_writer = tf.summary.FileWriter(folder_cm_matrix, sess.graph)

        sess.run(init_op)

        # Load the model
        if args.load_ckpt is not None:
            saver.restore(sess, args.load_ckpt)
            print('{}-Checkpoint loaded from {}!'.format(
                datetime.now(), args.load_ckpt))
        batch_num = 1
        batch_idx_train = 1

        ######################################################################
        # Validation
        filename_ckpt = os.path.join(folder_ckpt, 'iter')
        saver.save(sess, filename_ckpt, global_step=global_step)
        print('{}-Checkpoint saved to {}!'.format(datetime.now(),
                                                  filename_ckpt))

        sess.run(reset_metrics_op)
        summary_hist = None
        _idxVal = np.arange(num_val)
        np.random.shuffle(_idxVal)

        _dataX = []
        _dataY = []
        _dataZ = []
        _dataD = []
        _pred = []
        _label = []
        for batch_val_idx in tqdm(range(batch_num_val)):
            start_idx = batch_size * batch_val_idx
            end_idx = min(start_idx + batch_size, num_val)
            batch_size_val = end_idx - start_idx

            points_batch = data_val[_idxVal[start_idx:end_idx], ...]
            points_num_batch = data_num_val[_idxVal[start_idx:end_idx], ...]
            labels_batch = label_val[_idxVal[start_idx:end_idx], ...]

            weights_batch = np.array(label_weights_val)[labels_batch]

            xforms_np, rotations_np = pf.get_xforms(
                batch_size_val,
                rotation_range=rotation_range_val,
                scaling_range=scaling_range_val,
                order=setting.rotation_order)

            _labels_sampled, _predictions, _, _, _, _ = sess.run(
                [
                    labels_sampled, predictions, loss_mean_update_op,
                    t_1_acc_update_op, t_1_per_class_acc_update_op,
                    t_1_per_mean_iou_op_update_op
                ],
                feed_dict={
                    pts_fts: points_batch,
                    indices: pf.get_indices(batch_size_val, sample_num,
                                            points_num_batch),
                    xforms: xforms_np,
                    rotations: rotations_np,
                    jitter_range: np.array([jitter_val]),
                    labels_seg: labels_batch,
                    labels_weights: weights_batch,
                    is_training: False,
                })
            _dataX.append(points_batch[:, :, 0].flatten())
            _dataY.append(points_batch[:, :, 1].flatten())
            _dataZ.append(points_batch[:, :, 2].flatten())
            _pred.append(_predictions.flatten())
            _label.append(_labels_sampled.flatten())

        loss_val, t_1_acc_val, t_1_per_class_acc_val, t1__mean_iou, summaries_val = sess.run(
            [
                loss_mean_op, t_1_acc_op, t_1_per_class_acc_op,
                t_1_per_mean_iou_op, summaries_val_op
            ])

        img_d_summary = pf.plot_confusion_matrix(
            _label,
            _pred, ["Environment", "Pedestrian", "Car", "Cyclist"],
            tensor_name='confusion_matrix',
            normalize=False)

        _dataX = np.concatenate(_dataX, axis=0).flatten()
        _dataY = np.concatenate(_dataY, axis=0).flatten()
        _dataZ = np.concatenate(_dataZ, axis=0).flatten()
        correct_labels = np.concatenate(_label, axis=0).flatten()
        predict_labels = np.concatenate(_pred, axis=0).flatten()

        filename_pred = args.filelist_val + '_cm.h5'
        print('{}-Saving {}...'.format(datetime.now(), filename_pred))
        file = h5py.File(filename_pred, 'w')
        file.create_dataset('dataX', data=_dataX)
        file.create_dataset('dataY', data=_dataY)
        file.create_dataset('dataZ', data=_dataZ)
        file.create_dataset('correct_labels', data=correct_labels)
        file.create_dataset('predict_labels', data=predict_labels)
        file.close()

        summary_writer.add_summary(img_d_summary, batch_idx_train)
        summary_writer.add_summary(summaries_val, batch_idx_train)

        y_test = correct_labels
        prediction = predict_labels

        print('Accuracy:', accuracy_score(y_test, prediction))
        print('F1 score:', f1_score(y_test, prediction, average=None))
        print('Recall:', recall_score(y_test, prediction, average=None))
        print('Precision:', precision_score(y_test, prediction, average=None))
        print('\n classification report:\n',
              classification_report(y_test, prediction))
        print('\n confusion matrix:\n', confusion_matrix(y_test, prediction))

        print(
            '{}-[Val  ]-Average:      Loss: {:.4f}  T-1 Acc: {:.4f}  Diff-Best: {:.4f} T-1 mAcc: {:.4f}   T-1 mIoU: {:.4f}'
            .format(datetime.now(), loss_val, t_1_acc_val,
                    _highest_val - t_1_per_class_acc_val,
                    t_1_per_class_acc_val, t1__mean_iou))
        sys.stdout.flush()
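Unlike the other training scripts, this one also tracks `tf.metrics.mean_iou` on the validation set; that metric accumulates a `num_class x num_class` confusion matrix across batches and averages the per-class intersection-over-union. A NumPy sketch of the same computation (the arrays `y_true` and `y_pred` stand for the flattened label and prediction vectors gathered above):

import numpy as np

def mean_iou(y_true, y_pred, num_class):
    # Accumulate a confusion matrix, then average per-class IoU
    # (sketch of the idea behind tf.metrics.mean_iou, ignoring sample weights).
    cm = np.zeros((num_class, num_class), dtype=np.int64)
    for t, p in zip(y_true, y_pred):
        cm[t, p] += 1
    intersection = np.diag(cm).astype(np.float64)
    union = cm.sum(axis=0) + cm.sum(axis=1) - np.diag(cm)
    valid = union > 0  # skip classes that never occur
    return float(np.mean(intersection[valid] / union[valid]))

print(mean_iou([0, 0, 1, 1], [0, 1, 1, 1], num_class=2))  # -> 0.5833...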
Code Example #10
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--filelist',
                        '-t',
                        help='Path to training set ground truth (.txt)',
                        required=True)
    parser.add_argument('--filelist_val',
                        '-v',
                        help='Path to validation set ground truth (.txt)',
                        required=True)
    parser.add_argument('--filelist_unseen',
                        '-u',
                        help='Path to unseen set file list (.txt)',
                        required=True)
    parser.add_argument('--load_ckpt',
                        '-l',
                        help='Path to a check point file for load')
    parser.add_argument(
        '--save_folder',
        '-s',
        help='Path to folder for saving check points and summary',
        required=True)
    parser.add_argument('--model', '-m', help='Model to use', required=True)
    parser.add_argument('--setting',
                        '-x',
                        help='Setting to use',
                        required=True)
    args = parser.parse_args()

    time_string = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    root_folder = os.path.join(
        args.save_folder,
        '%s_%s_%s_%d' % (args.model, args.setting, time_string, os.getpid()))
    if not os.path.exists(root_folder):
        os.makedirs(root_folder)

    print('PID:', os.getpid())

    print(args)

    model = importlib.import_module(args.model)
    setting_path = os.path.join(os.path.dirname(__file__), args.model)
    sys.path.append(setting_path)
    setting = importlib.import_module(args.setting)

    num_epochs = setting.num_epochs
    batch_size = setting.batch_size
    sample_num = setting.sample_num
    step_val = setting.step_val
    label_weights_list = setting.label_weights
    label_weights_val = [1.0] * 1 + [1.0] * (setting.num_class - 1)
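    # Uniform per-class weights for validation (the training label_weights from the
    # setting may weight classes differently).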

    rotation_range = setting.rotation_range
    rotation_range_val = setting.rotation_range_val
    scaling_range = setting.scaling_range
    scaling_range_val = setting.scaling_range_val
    jitter = setting.jitter
    jitter_val = setting.jitter_val

    # Prepare inputs
    print('{}-Preparing datasets...'.format(datetime.now()))
    is_list_of_h5_list = not data_utils.is_h5_list(args.filelist)
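    # args.filelist may be a single .h5 file list or a list of such lists; in the
    # latter case training starts from the first chunk and switches to the next one
    # at each epoch boundary (see the reload near the end of the training loop).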
    if is_list_of_h5_list:
        seg_list = data_utils.load_seg_list(args.filelist)
        seg_list_idx = 0
        filelist_train = seg_list[seg_list_idx]
        seg_list_idx = seg_list_idx + 1
    else:
        filelist_train = args.filelist
    data_train, _, data_num_train, label_train, _ = data_utils.load_seg(
        filelist_train)
    data_val, _, data_num_val, label_val, _ = data_utils.load_seg(
        args.filelist_val)

    data_unseen, _, data_num_unseen, label_unseen, _ = data_utils.load_seg(
        args.filelist_unseen)

    # shuffle
    data_unseen, data_num_unseen, label_unseen = \
        data_utils.grouped_shuffle([data_unseen, data_num_unseen, label_unseen])

    # shuffle
    data_train, data_num_train, label_train = \
        data_utils.grouped_shuffle([data_train, data_num_train, label_train])

    num_train = data_train.shape[0]
    point_num = data_train.shape[1]
    num_val = data_val.shape[0]
    num_unseen = data_unseen.shape[0]

    batch_num_unseen = int(math.ceil(num_unseen / batch_size))

    print('{}-{:d}/{:d} training/validation samples.'.format(
        datetime.now(), num_train, num_val))
    batch_num = (num_train * num_epochs + batch_size - 1) // batch_size
    print('{}-{:d} training batches.'.format(datetime.now(), batch_num))
    batch_num_val = int(math.ceil(num_val / batch_size))
    print('{}-{:d} testing batches per test.'.format(datetime.now(),
                                                     batch_num_val))

    ######################################################################
    # Placeholders
    indices = tf.placeholder(tf.int32, shape=(None, None, 2), name="indices")
    xforms = tf.placeholder(tf.float32, shape=(None, 3, 3), name="xforms")
    rotations = tf.placeholder(tf.float32,
                               shape=(None, 3, 3),
                               name="rotations")
    jitter_range = tf.placeholder(tf.float32, shape=(1), name="jitter_range")
    global_step = tf.Variable(0, trainable=False, name='global_step')
    is_training = tf.placeholder(tf.bool, name='is_training')

    is_labelled_data = tf.placeholder(tf.bool, name='is_labelled_data')
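    # is_labelled_data marks whether the current batch carries real ground-truth
    # labels; it drives the tf.cond routing between the base network and the
    # CRF-refined branch defined below.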

    pts_fts = tf.placeholder(tf.float32,
                             shape=(None, point_num, setting.data_dim),
                             name='pts_fts')
    labels_seg = tf.placeholder(tf.int64,
                                shape=(None, point_num),
                                name='labels_seg')
    labels_weights = tf.placeholder(tf.float32,
                                    shape=(None, point_num),
                                    name='labels_weights')

    ######################################################################
    pts_fts_sampled = tf.gather_nd(pts_fts,
                                   indices=indices,
                                   name='pts_fts_sampled')
    features_augmented = None
    if setting.data_dim > 3:
        points_sampled, features_sampled = tf.split(
            pts_fts_sampled, [3, setting.data_dim - 3],
            axis=-1,
            name='split_points_features')
        if setting.use_extra_features:
            if setting.with_normal_feature:
                if setting.data_dim < 6:
                    print('Only 3D normals are supported!')
                    exit()
                elif setting.data_dim == 6:
                    features_augmented = pf.augment(features_sampled,
                                                    rotations)
                else:
                    normals, rest = tf.split(features_sampled,
                                             [3, setting.data_dim - 6],
                                             axis=-1)
                    normals_augmented = pf.augment(normals, rotations)
                    features_augmented = tf.concat([normals_augmented, rest],
                                                   axis=-1)
            else:
                features_augmented = features_sampled
    else:
        points_sampled = pts_fts_sampled
    points_augmented = pf.augment(points_sampled, xforms, jitter_range)

    labels_sampled = tf.gather_nd(labels_seg,
                                  indices=indices,
                                  name='labels_sampled')
    labels_weights_sampled = tf.gather_nd(labels_weights,
                                          indices=indices,
                                          name='labels_weight_sampled')

    net = model.Net(points_augmented, features_augmented, is_training, setting)

    logits = net.logits
    probs = tf.nn.softmax(logits, name='prob')
    predictions = tf.argmax(probs, axis=-1, name='predictions')

    with tf.variable_scope('xcrf_ker_weights'):
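        # learningBlock appears to implement a CRF-style refinement of the network
        # logits: pairwise terms are built over the 1024-nearest-neighbour graph of
        # the augmented points, the block is applied at several dilation factors D,
        # and the refined logits from all scales are summed before the softmax.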

        _, point_indices = pf.knn_indices_general(points_augmented,
                                                  points_augmented, 1024, True)

        xcrf = learningBlock(num_points=sample_num,
                             num_classes=setting.num_class,
                             theta_alpha=float(5),
                             theta_beta=float(2),
                             theta_gamma=float(1),
                             num_iterations=5,
                             name='xcrf',
                             point_indices=point_indices)

        _logits1 = xcrf.call(net.logits, points_augmented, features_augmented,
                             setting.data_dim)
        _logits2 = xcrf.call(net.logits,
                             points_augmented,
                             features_augmented,
                             setting.data_dim,
                             D=2)
        _logits3 = xcrf.call(net.logits,
                             points_augmented,
                             features_augmented,
                             setting.data_dim,
                             D=3)
        _logits4 = xcrf.call(net.logits,
                             points_augmented,
                             features_augmented,
                             setting.data_dim,
                             D=4)
        _logits5 = xcrf.call(net.logits,
                             points_augmented,
                             features_augmented,
                             setting.data_dim,
                             D=8)
        _logits6 = xcrf.call(net.logits,
                             points_augmented,
                             features_augmented,
                             setting.data_dim,
                             D=16)

        _logits = _logits1 + _logits2 + _logits3 + _logits4 + _logits5 + _logits6
        _probs = tf.nn.softmax(_logits, name='probs_crf')
        _predictions = tf.argmax(_probs, axis=-1, name='predictions_crf')

    logits = tf.cond(
        is_training,
        lambda: tf.cond(is_labelled_data, lambda: _logits, lambda: logits),
        lambda: _logits)

    predictions = tf.cond(
        is_training, lambda: tf.cond(is_labelled_data, lambda: _predictions,
                                     lambda: predictions),
        lambda: _predictions)

    labels_sampled = tf.cond(
        is_training, lambda: tf.cond(is_labelled_data, lambda: labels_sampled,
                                     lambda: _predictions),
        lambda: labels_sampled)
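    # The three tf.cond switches above: at evaluation time the CRF-refined outputs
    # are used; during training, labelled batches train the refined logits against
    # the true labels, while unlabelled batches train the base network's logits
    # against the CRF predictions, which act as pseudo-labels.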

    loss_op = tf.losses.sparse_softmax_cross_entropy(
        labels=labels_sampled, logits=logits, weights=labels_weights_sampled)

    with tf.name_scope('metrics'):
        loss_mean_op, loss_mean_update_op = tf.metrics.mean(loss_op)
        t_1_acc_op, t_1_acc_update_op = tf.metrics.accuracy(
            labels_sampled, predictions, weights=labels_weights_sampled)
        t_1_per_class_acc_op, t_1_per_class_acc_update_op = \
            tf.metrics.mean_per_class_accuracy(labels_sampled, predictions, setting.num_class,
                                               weights=labels_weights_sampled)
        t_1_per_mean_iou_op, t_1_per_mean_iou_op_update_op = \
            tf.metrics.mean_iou(labels_sampled, predictions, setting.num_class,
                                               weights=labels_weights_sampled)

    reset_metrics_op = tf.variables_initializer([
        var for var in tf.local_variables()
        if var.name.split('/')[0] == 'metrics'
    ])

    _ = tf.summary.scalar('loss/train',
                          tensor=loss_mean_op,
                          collections=['train'])
    _ = tf.summary.scalar('t_1_acc/train',
                          tensor=t_1_acc_op,
                          collections=['train'])
    _ = tf.summary.scalar('t_1_per_class_acc/train',
                          tensor=t_1_per_class_acc_op,
                          collections=['train'])

    _ = tf.summary.scalar('loss/val', tensor=loss_mean_op, collections=['val'])
    _ = tf.summary.scalar('t_1_acc/val',
                          tensor=t_1_acc_op,
                          collections=['val'])
    _ = tf.summary.scalar('t_1_per_class_acc/val',
                          tensor=t_1_per_class_acc_op,
                          collections=['val'])
    _ = tf.summary.scalar('t_1_mean_iou/val',
                          tensor=t_1_per_mean_iou_op,
                          collections=['val'])

    all_variable = tf.global_variables()

    lr_exp_op = tf.train.exponential_decay(setting.learning_rate_base,
                                           global_step,
                                           setting.decay_steps,
                                           setting.decay_rate,
                                           staircase=True)
    lr_clip_op = tf.maximum(lr_exp_op, setting.learning_rate_min)
    _ = tf.summary.scalar('learning_rate',
                          tensor=lr_clip_op,
                          collections=['train'])
    reg_loss = setting.weight_decay * tf.losses.get_regularization_loss()
    if setting.optimizer == 'adam':
        optimizer = tf.train.AdamOptimizer(learning_rate=lr_clip_op,
                                           epsilon=setting.epsilon)
    elif setting.optimizer == 'momentum':
        optimizer = tf.train.MomentumOptimizer(learning_rate=lr_clip_op,
                                               momentum=setting.momentum,
                                               use_nesterov=True)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

    xcrf_ker_weights = [
        var for var in tf.global_variables() if 'xcrf_ker_weights' in var.name
    ]
    no_xcrf_ker_weights = [
        var for var in tf.global_variables()
        if 'xcrf_ker_weights' not in var.name
    ]

    #print(restore_values)
    #train_op = optimizer.minimize( loss_op+reg_loss, global_step=global_step)

    with tf.control_dependencies(update_ops):
        train_op_xcrf = optimizer.minimize(loss_op + reg_loss,
                                           var_list=no_xcrf_ker_weights,
                                           global_step=global_step)
        train_op_all = optimizer.minimize(loss_op + reg_loss,
                                          var_list=all_variable,
                                          global_step=global_step)

    train_op = tf.cond(is_labelled_data, lambda: train_op_all,
                       lambda: train_op_xcrf)
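    # Labelled batches update all variables (including the xCRF kernel weights);
    # unlabelled batches update only the non-xCRF variables, so the pseudo-label
    # pass does not alter the CRF parameters.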

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())

    saver = tf.train.Saver(var_list=no_xcrf_ker_weights, max_to_keep=5)

    saver_best = tf.train.Saver(max_to_keep=5)
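    # The periodic 'saver' excludes the xCRF kernel weights, presumably so its
    # checkpoints stay loadable by the plain network; 'saver_best' snapshots
    # every variable.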
    # backup all code
    code_folder = os.path.abspath(os.path.dirname(__file__))
    shutil.copytree(code_folder,
                    os.path.join(root_folder, os.path.basename(code_folder)))

    folder_ckpt = os.path.join(root_folder, 'ckpts')
    if not os.path.exists(folder_ckpt):
        os.makedirs(folder_ckpt)

    folder_ckpt_best = os.path.join(root_folder, 'ckpt-best')
    if not os.path.exists(folder_ckpt_best):
        os.makedirs(folder_ckpt_best)

    folder_summary = os.path.join(root_folder, 'summary')
    if not os.path.exists(folder_summary):
        os.makedirs(folder_summary)

    parameter_num = np.sum(
        [np.prod(v.shape.as_list()) for v in tf.trainable_variables()])
    print('{}-Parameter number: {:d}.'.format(datetime.now(),
                                              int(parameter_num)))

    _highest_val = 0.0
    max_val = 10
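    # max_val acts as an early-stopping counter: it is decremented after every
    # validation round, reset to 10 whenever the per-class accuracy improves, and
    # the process exits once it drops below zero.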
    with tf.Session() as sess:
        summaries_op = tf.summary.merge_all('train')
        summaries_val_op = tf.summary.merge_all('val')
        summary_values_op = tf.summary.merge_all('summary_values')
        summary_writer = tf.summary.FileWriter(folder_summary, sess.graph)

        sess.run(init_op)

        # Load the model
        if args.load_ckpt is not None:
            saver.restore(sess, args.load_ckpt)
            print('{}-Checkpoint loaded from {}!'.format(
                datetime.now(), args.load_ckpt))
        batch_num = 50000
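        # Train for a fixed 50000 iterations, overriding the epoch-based batch_num
        # computed above.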
        for batch_idx_train in range(batch_num):
            if (batch_idx_train % step_val == 0 and (batch_idx_train != 0 or args.load_ckpt is not None)) \
                    or batch_idx_train == batch_num - 1:
                ######################################################################

                # Validation
                filename_ckpt = os.path.join(folder_ckpt, 'iter')
                saver.save(sess, filename_ckpt, global_step=global_step)
                print('{}-Checkpoint saved to {}!'.format(
                    datetime.now(), filename_ckpt))

                sess.run(reset_metrics_op)
                summary_hist = None
                _idxVal = np.arange(num_val)
                np.random.shuffle(_idxVal)

                _pred = []
                _label = []
                for batch_val_idx in tqdm(range(batch_num_val)):
                    start_idx = batch_size * batch_val_idx
                    end_idx = min(start_idx + batch_size, num_val)
                    batch_size_val = end_idx - start_idx

                    points_batch = data_val[_idxVal[start_idx:end_idx], ...]
                    points_num_batch = data_num_val[_idxVal[start_idx:end_idx],
                                                    ...]
                    labels_batch = label_val[_idxVal[start_idx:end_idx], ...]

                    weights_batch = np.array(label_weights_val)[labels_batch]

                    xforms_np, rotations_np = pf.get_xforms(
                        batch_size_val,
                        rotation_range=rotation_range_val,
                        scaling_range=scaling_range_val,
                        order=setting.rotation_order)

                    _labels_sampled, _predictions, _, _, _, _ = sess.run(
                        [
                            labels_sampled, predictions, loss_mean_update_op,
                            t_1_acc_update_op, t_1_per_class_acc_update_op,
                            t_1_per_mean_iou_op_update_op
                        ],
                        feed_dict={
                            pts_fts:
                            points_batch,
                            indices:
                            pf.get_indices(batch_size_val, sample_num,
                                           points_num_batch),
                            xforms:
                            xforms_np,
                            rotations:
                            rotations_np,
                            jitter_range:
                            np.array([jitter_val]),
                            labels_seg:
                            labels_batch,
                            is_labelled_data:
                            True,
                            labels_weights:
                            weights_batch,
                            is_training:
                            False,
                        })
                    _pred.append(_predictions.flatten())
                    _label.append(_labels_sampled.flatten())

                loss_val, t_1_acc_val, t_1_per_class_acc_val, t1__mean_iou, summaries_val = sess.run(
                    [
                        loss_mean_op, t_1_acc_op, t_1_per_class_acc_op,
                        t_1_per_mean_iou_op, summaries_val_op
                    ])

                img_d_summary = pf.plot_confusion_matrix(
                    _label,
                    _pred, ["Environment", "Pedestrian", "Car", "Cyclist"],
                    tensor_name='confusion_matrix')

                max_val = max_val - 1

                if (t_1_per_class_acc_val > _highest_val):

                    max_val = 10

                    _highest_val = t_1_per_class_acc_val

                    filename_ckpt = os.path.join(folder_ckpt_best,
                                                 str(_highest_val) + "-iter-")
                    saver_best.save(sess,
                                    filename_ckpt,
                                    global_step=global_step)

                if (max_val < 0):
                    sys.exit(0)

                summary_writer.add_summary(summaries_val, batch_idx_train)

                summary_writer.add_summary(img_d_summary, batch_idx_train)

                #summary_writer.add_summary(summary_hist, batch_idx_train)
                print(
                    '{}-[Val  ]-Average:      Loss: {:.4f}  T-1 Acc: {:.4f}  Diff-Best: {:.4f} T-1 mAcc: {:.4f}   T-1 mIoU: {:.4f}'
                    .format(datetime.now(), loss_val, t_1_acc_val,
                            _highest_val - t_1_per_class_acc_val,
                            t_1_per_class_acc_val, t1__mean_iou))
                sys.stdout.flush()
                ######################################################################

                # Unseen Data
                sess.run(reset_metrics_op)
                summary_hist = None
                _idxunseen = np.arange(num_unseen)
                np.random.shuffle(_idxunseen)

                _pred = []
                _label = []

                unseenIndices = np.arange(batch_num_unseen)
                np.random.shuffle(unseenIndices)
                for batch_val_idx in tqdm(unseenIndices[:500]):
                    start_idx = batch_size * batch_val_idx
                    end_idx = min(start_idx + batch_size, num_unseen)
                    batch_size_train = end_idx - start_idx

                    points_batch = data_unseen[_idxunseen[start_idx:end_idx],
                                               ...]
                    points_num_batch = data_num_unseen[
                        _idxunseen[start_idx:end_idx], ...]
                    labels_batch = np.zeros(points_batch.shape[0:2],
                                            dtype=np.int32)
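                    # All-zero dummy labels for the unlabelled batch: with
                    # is_labelled_data=False the graph substitutes the CRF predictions
                    # as targets, so these zeros only determine the per-point weights
                    # looked up below.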

                    weights_batch = np.array(label_weights_list)[labels_batch]

                    xforms_np, rotations_np = pf.get_xforms(
                        batch_size_train,
                        rotation_range=rotation_range,
                        scaling_range=scaling_range,
                        order=setting.rotation_order)

                    offset = int(
                        random.gauss(0,
                                     sample_num * setting.sample_num_variance))
                    offset = max(offset, -sample_num * setting.sample_num_clip)
                    offset = min(offset, sample_num * setting.sample_num_clip)
                    sample_num_train = sample_num + offset
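                    # Perturb the number of sampled points per batch (Gaussian offset,
                    # clipped) as a light form of augmentation.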

                    sess.run(
                        [
                            train_op, loss_mean_update_op, t_1_acc_update_op,
                            t_1_per_class_acc_update_op,
                            t_1_per_mean_iou_op_update_op
                        ],
                        feed_dict={
                            pts_fts:
                            points_batch,
                            is_labelled_data:
                            False,
                            indices:
                            pf.get_indices(batch_size_train, sample_num_train,
                                           points_num_batch),
                            xforms:
                            xforms_np,
                            rotations:
                            rotations_np,
                            jitter_range:
                            np.array([jitter]),
                            labels_seg:
                            labels_batch,
                            labels_weights:
                            weights_batch,
                            is_training:
                            True,
                        })
                    if batch_val_idx % 100 == 0:
                        loss, t_1_acc, t_1_per_class_acc, t_1__mean_iou = sess.run(
                            [
                                loss_mean_op, t_1_acc_op, t_1_per_class_acc_op,
                                t_1_per_mean_iou_op
                            ],
                            feed_dict={
                                pts_fts:
                                points_batch,
                                indices:
                                pf.get_indices(batch_size_train,
                                               sample_num_train,
                                               points_num_batch),
                                xforms:
                                xforms_np,
                                is_labelled_data:
                                False,
                                rotations:
                                rotations_np,
                                jitter_range:
                                np.array([jitter]),
                                labels_seg:
                                labels_batch,
                                labels_weights:
                                weights_batch,
                                is_training:
                                True,
                            })
                        print(
                            '{}-[Train]-Unseen: {:06d}  Loss: {:.4f}  T-1 Acc: {:.4f}  T-1 mAcc: {:.4f}   T-1 mIoU: {:.4f}'
                            .format(datetime.now(), batch_val_idx, loss,
                                    t_1_acc, t_1_per_class_acc, t_1__mean_iou))
                        sys.stdout.flush()

            ######################################################################
            # Training
            start_idx = (batch_size * batch_idx_train) % num_train
            end_idx = min(start_idx + batch_size, num_train)
            batch_size_train = end_idx - start_idx
            points_batch = data_train[start_idx:end_idx, ...]
            points_num_batch = data_num_train[start_idx:end_idx, ...]
            labels_batch = label_train[start_idx:end_idx, ...]
            weights_batch = np.array(label_weights_list)[labels_batch]

            if start_idx + batch_size_train == num_train:
                if is_list_of_h5_list:
                    filelist_train_prev = seg_list[(seg_list_idx - 1) %
                                                   len(seg_list)]
                    filelist_train = seg_list[seg_list_idx % len(seg_list)]
                    if filelist_train != filelist_train_prev:
                        data_train, _, data_num_train, label_train, _ = data_utils.load_seg(
                            filelist_train)
                        num_train = data_train.shape[0]
                    seg_list_idx = seg_list_idx + 1
                data_train, data_num_train, label_train = \
                    data_utils.grouped_shuffle([data_train, data_num_train, label_train])

            offset = int(
                random.gauss(0, sample_num * setting.sample_num_variance))
            offset = max(offset, -sample_num * setting.sample_num_clip)
            offset = min(offset, sample_num * setting.sample_num_clip)
            sample_num_train = sample_num + offset
            xforms_np, rotations_np = pf.get_xforms(
                batch_size_train,
                rotation_range=rotation_range,
                scaling_range=scaling_range,
                order=setting.rotation_order)
            sess.run(reset_metrics_op)
            sess.run(
                [
                    train_op, loss_mean_update_op, t_1_acc_update_op,
                    t_1_per_class_acc_update_op, t_1_per_mean_iou_op_update_op
                ],
                feed_dict={
                    pts_fts:
                    points_batch,
                    is_labelled_data:
                    True,
                    indices:
                    pf.get_indices(batch_size_train, sample_num_train,
                                   points_num_batch),
                    xforms:
                    xforms_np,
                    rotations:
                    rotations_np,
                    jitter_range:
                    np.array([jitter]),
                    labels_seg:
                    labels_batch,
                    labels_weights:
                    weights_batch,
                    is_training:
                    True,
                })
            if batch_idx_train % 100 == 0:
                loss, t_1_acc, t_1_per_class_acc, t_1__mean_iou, summaries = sess.run(
                    [
                        loss_mean_op, t_1_acc_op, t_1_per_class_acc_op,
                        t_1_per_mean_iou_op, summaries_op
                    ],
                    feed_dict={
                        pts_fts:
                        points_batch,
                        indices:
                        pf.get_indices(batch_size_train, sample_num_train,
                                       points_num_batch),
                        xforms:
                        xforms_np,
                        is_labelled_data:
                        True,
                        rotations:
                        rotations_np,
                        jitter_range:
                        np.array([jitter]),
                        labels_seg:
                        labels_batch,
                        labels_weights:
                        weights_batch,
                        is_training:
                        True,
                    })
                summary_writer.add_summary(summaries, batch_idx_train)
                print(
                    '{}-[Train]-Iter: {:06d}  Loss: {:.4f}  T-1 Acc: {:.4f}  T-1 mAcc: {:.4f}   T-1 mIoU: {:.4f}'
                    .format(datetime.now(), batch_idx_train, loss, t_1_acc,
                            t_1_per_class_acc, t_1__mean_iou))
                sys.stdout.flush()
            ######################################################################
        print('{}-Done!'.format(datetime.now()))
Code example #11
def main():
    args = AttrDict()
    # Path of training set ground truth file list (.txt)
    args.filelist = os.path.join(SCENENN_DIR, 'train_files.txt')
    # Path of validation set ground truth file list (.txt)
    args.filelist_val = os.path.join(SCENENN_DIR, 'test_files.txt')
    # Path of a check point file to load
    args.load_ckpt = os.path.join(ROOT_DIR, '..', 'models',
                                  'pretrained_scannet', 'ckpts', 'iter-354000')
    # Base directory where model checkpoint and summary files get saved in separate subdirectories
    args.save_folder = os.path.join(ROOT_DIR, '..', 'models')
    # PointCNN model to use
    args.model = 'pointcnn_seg'
    # Model setting to use
    args.setting = 'scenenn_x8_2048_fps'

    time_string = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    model_save_folder = os.path.join(
        args.save_folder,
        '%s_%s_%s_%d' % (args.model, args.setting, time_string, os.getpid()))
    if not os.path.exists(model_save_folder):
        os.makedirs(model_save_folder)

    # sys.stdout = open(os.path.join(model_save_folder, 'log.txt'), 'w')

    print('PID:', os.getpid())
    print(args)

    model = importlib.import_module(args.model)
    setting_path = os.path.join(ROOT_DIR, args.model)
    sys.path.append(setting_path)
    setting = importlib.import_module(args.setting)
    num_epochs = setting.num_epochs
    batch_size = setting.batch_size
    sample_num = setting.sample_num
    step_val = setting.step_val
    label_weights_list = setting.label_weights
    rotation_range = setting.rotation_range
    rotation_range_val = setting.rotation_range_val
    scaling_range = setting.scaling_range
    scaling_range_val = setting.scaling_range_val
    jitter = setting.jitter
    jitter_val = setting.jitter_val

    # Prepare inputs
    print('{}-Preparing datasets...'.format(datetime.now()))
    is_list_of_h5_list = not data_utils.is_h5_list(args.filelist)
    if is_list_of_h5_list:
        seg_list = data_utils.load_seg_list(args.filelist)
        seg_list_idx = 0
        filelist_train = seg_list[seg_list_idx]
        seg_list_idx = seg_list_idx + 1
    else:
        filelist_train = args.filelist
    data_train, _, data_num_train, label_train, _ = data_utils.load_seg(
        filelist_train)
    data_val, _, data_num_val, label_val, _ = data_utils.load_seg(
        args.filelist_val)

    # shuffle
    data_train, data_num_train, label_train = \
        data_utils.grouped_shuffle([data_train, data_num_train, label_train])

    num_train = data_train.shape[0]
    point_num = data_train.shape[1]
    num_val = data_val.shape[0]
    print('{}-{:d}/{:d} training/validation samples.'.format(
        datetime.now(), num_train, num_val))
    batch_num = (num_train * num_epochs + batch_size - 1) // batch_size
    print('{}-{:d} training batches.'.format(datetime.now(), batch_num))
    batch_num_val = math.ceil(num_val / batch_size)
    print('{}-{:d} testing batches per test.'.format(datetime.now(),
                                                     batch_num_val))

    ######################################################################
    # Placeholders
    print('{}-Initializing TF-placeholders...'.format(datetime.now()))
    indices = tf.placeholder(tf.int32, shape=(None, None, 2), name="indices")
    xforms = tf.placeholder(tf.float32, shape=(None, 3, 3), name="xforms")
    rotations = tf.placeholder(tf.float32,
                               shape=(None, 3, 3),
                               name="rotations")
    jitter_range = tf.placeholder(tf.float32, shape=(1), name="jitter_range")
    global_step = tf.Variable(0, trainable=False, name='global_step')
    is_training = tf.placeholder(tf.bool, name='is_training')

    pts_fts = tf.placeholder(tf.float32,
                             shape=(None, point_num, setting.data_dim),
                             name='pts_fts')
    labels_seg = tf.placeholder(tf.int64,
                                shape=(None, point_num),
                                name='labels_seg')
    labels_weights = tf.placeholder(tf.float32,
                                    shape=(None, point_num),
                                    name='labels_weights')

    ######################################################################
    pts_fts_sampled = tf.gather_nd(pts_fts,
                                   indices=indices,
                                   name='pts_fts_sampled')
    features_augmented = None
    if setting.data_dim > 3:
        points_sampled, features_sampled = tf.split(
            pts_fts_sampled, [3, setting.data_dim - 3],
            axis=-1,
            name='split_points_features')
        if setting.use_extra_features:
            if setting.with_normal_feature:
                if setting.data_dim < 6:
                    print('Only 3D normals are supported!')
                    exit()
                elif setting.data_dim == 6:
                    features_augmented = pf.augment(features_sampled,
                                                    rotations)
                else:
                    normals, rest = tf.split(features_sampled,
                                             [3, setting.data_dim - 6],
                                             axis=-1)
                    normals_augmented = pf.augment(normals, rotations)
                    features_augmented = tf.concat([normals_augmented, rest],
                                                   axis=-1)
            else:
                features_augmented = features_sampled
    else:
        points_sampled = pts_fts_sampled
    points_augmented = pf.augment(points_sampled, xforms, jitter_range)

    labels_sampled = tf.gather_nd(labels_seg,
                                  indices=indices,
                                  name='labels_sampled')
    labels_weights_sampled = tf.gather_nd(labels_weights,
                                          indices=indices,
                                          name='labels_weight_sampled')

    print('{}-Initializing net...'.format(datetime.now()))
    net = model.Net(points_augmented, features_augmented, is_training, setting)
    logits = net.logits
    probs = tf.nn.softmax(logits, name='probs')
    predictions = tf.argmax(probs, axis=-1, name='predictions')

    loss_op = tf.losses.sparse_softmax_cross_entropy(
        labels=labels_sampled, logits=logits, weights=labels_weights_sampled)

    with tf.name_scope('metrics'):
        loss_mean_op, loss_mean_update_op = tf.metrics.mean(loss_op)
        t_1_acc_op, t_1_acc_update_op = tf.metrics.accuracy(
            labels_sampled, predictions, weights=labels_weights_sampled)
        t_1_per_class_acc_op, t_1_per_class_acc_update_op = \
            tf.metrics.mean_per_class_accuracy(labels_sampled, predictions, setting.num_class,
                                               weights=labels_weights_sampled)
    reset_metrics_op = tf.variables_initializer([
        var for var in tf.local_variables()
        if var.name.split('/')[0] == 'metrics'
    ])

    _ = tf.summary.scalar('loss/train',
                          tensor=loss_mean_op,
                          collections=['train'])
    _ = tf.summary.scalar('t_1_acc/train',
                          tensor=t_1_acc_op,
                          collections=['train'])
    _ = tf.summary.scalar('t_1_per_class_acc/train',
                          tensor=t_1_per_class_acc_op,
                          collections=['train'])

    _ = tf.summary.scalar('loss/val', tensor=loss_mean_op, collections=['val'])
    _ = tf.summary.scalar('t_1_acc/val',
                          tensor=t_1_acc_op,
                          collections=['val'])
    _ = tf.summary.scalar('t_1_per_class_acc/val',
                          tensor=t_1_per_class_acc_op,
                          collections=['val'])

    lr_exp_op = tf.train.exponential_decay(setting.learning_rate_base,
                                           global_step,
                                           setting.decay_steps,
                                           setting.decay_rate,
                                           staircase=True)
    lr_clip_op = tf.maximum(lr_exp_op, setting.learning_rate_min)
    _ = tf.summary.scalar('learning_rate',
                          tensor=lr_clip_op,
                          collections=['train'])
    reg_loss = setting.weight_decay * tf.losses.get_regularization_loss()
    print('{}-Setting up optimizer...'.format(datetime.now()))
    if setting.optimizer == 'adam':
        optimizer = tf.train.AdamOptimizer(learning_rate=lr_clip_op,
                                           epsilon=setting.epsilon)
    elif setting.optimizer == 'momentum':
        optimizer = tf.train.MomentumOptimizer(learning_rate=lr_clip_op,
                                               momentum=setting.momentum,
                                               use_nesterov=True)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        last_layer_train_vars = tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES, 'logits')
        last_layer_train_op = optimizer.minimize(
            loss_op + reg_loss,
            global_step=global_step,
            var_list=last_layer_train_vars)
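        # Only the variables in the 'logits' scope are optimized here: the pretrained
        # backbone stays frozen and just the segmentation head is fine-tuned.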

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())

    saver = tf.train.Saver(max_to_keep=None)

    variables_to_restore = get_variables_to_restore(exclude=['logits'])
    restorer = tf.train.Saver(var_list=variables_to_restore, max_to_keep=None)
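    # The restorer loads every variable except the 'logits' scope from the pretrained
    # checkpoint, so the classification head is re-initialized for the new label set.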

    # backup all code
    # code_folder = os.path.abspath(os.path.dirname(__file__))
    # shutil.copytree(code_folder, os.path.join(model_save_folder, os.path.basename(code_folder)), symlinks=True)

    folder_ckpt = os.path.join(model_save_folder, 'ckpts')
    if not os.path.exists(folder_ckpt):
        os.makedirs(folder_ckpt)

    folder_summary = os.path.join(model_save_folder, 'summary')
    if not os.path.exists(folder_summary):
        os.makedirs(folder_summary)

    parameter_num = np.sum(
        [np.prod(v.shape.as_list()) for v in tf.trainable_variables()])
    print('{}-Number of model parameters: {:d}.'.format(
        datetime.now(), int(parameter_num)))

    with tf.Session() as sess:
        summaries_op = tf.summary.merge_all('train')
        summaries_val_op = tf.summary.merge_all('val')
        summary_writer = tf.summary.FileWriter(folder_summary, sess.graph)

        print('{}-Initializing variables...'.format(datetime.now()))
        sess.run(init_op)

        # Load the model
        if args.load_ckpt is not None:
            print('{}-Loading checkpoint from {}...'.format(
                datetime.now(), args.load_ckpt))
            restorer.restore(sess, args.load_ckpt)
            print('{}-Checkpoint loaded.'.format(datetime.now()))

        for batch_idx_train in tqdm(range(batch_num), ncols=60):
            if (batch_idx_train % step_val == 0 and (batch_idx_train != 0 or args.load_ckpt is not None)) \
                    or batch_idx_train == batch_num - 1:
                ######################################################################
                # Validation
                filename_ckpt = os.path.join(folder_ckpt, 'iter')
                saver.save(sess, filename_ckpt, global_step=global_step)
                tqdm.write('{}-Checkpoint saved to {}!'.format(
                    datetime.now(), filename_ckpt))

                sess.run(reset_metrics_op)
                for batch_val_idx in range(batch_num_val):
                    start_idx = batch_size * batch_val_idx
                    end_idx = min(start_idx + batch_size, num_val)
                    batch_size_val = end_idx - start_idx
                    points_batch = data_val[start_idx:end_idx, ...]
                    points_num_batch = data_num_val[start_idx:end_idx, ...]
                    labels_batch = label_val[start_idx:end_idx, ...]
                    weights_batch = np.array(label_weights_list)[labels_batch]

                    xforms_np, rotations_np = pf.get_xforms(
                        batch_size_val,
                        rotation_range=rotation_range_val,
                        scaling_range=scaling_range_val,
                        order=setting.rotation_order)
                    sess.run(
                        [
                            loss_mean_update_op, t_1_acc_update_op,
                            t_1_per_class_acc_update_op
                        ],
                        feed_dict={
                            pts_fts:
                            points_batch,
                            indices:
                            pf.get_indices(batch_size_val, sample_num,
                                           points_num_batch),
                            xforms:
                            xforms_np,
                            rotations:
                            rotations_np,
                            jitter_range:
                            np.array([jitter_val]),
                            labels_seg:
                            labels_batch,
                            labels_weights:
                            weights_batch,
                            is_training:
                            False,
                        })

                loss_val, t_1_acc_val, t_1_per_class_acc_val, summaries_val = sess.run(
                    [
                        loss_mean_op, t_1_acc_op, t_1_per_class_acc_op,
                        summaries_val_op
                    ])
                summary_writer.add_summary(summaries_val, batch_idx_train)
                tqdm.write(
                    '{}-[Val  ]-Average:      Loss: {:.4f}  T-1 Acc: {:.4f}  T-1 mAcc: {:.4f}'
                    .format(datetime.now(), loss_val, t_1_acc_val,
                            t_1_per_class_acc_val))
                sys.stdout.flush()
                ######################################################################

            ######################################################################
            # Training
            start_idx = (batch_size * batch_idx_train) % num_train
            end_idx = min(start_idx + batch_size, num_train)
            batch_size_train = end_idx - start_idx
            points_batch = data_train[start_idx:end_idx, ...]
            points_num_batch = data_num_train[start_idx:end_idx, ...]
            labels_batch = label_train[start_idx:end_idx, ...]
            weights_batch = np.array(label_weights_list)[labels_batch]

            if start_idx + batch_size_train == num_train:
                if is_list_of_h5_list:
                    filelist_train_prev = seg_list[(seg_list_idx - 1) %
                                                   len(seg_list)]
                    filelist_train = seg_list[seg_list_idx % len(seg_list)]
                    if filelist_train != filelist_train_prev:
                        data_train, _, data_num_train, label_train, _ = data_utils.load_seg(
                            filelist_train)
                        num_train = data_train.shape[0]
                    seg_list_idx = seg_list_idx + 1
                data_train, data_num_train, label_train = \
                    data_utils.grouped_shuffle([data_train, data_num_train, label_train])

            offset = int(
                random.gauss(0, sample_num * setting.sample_num_variance))
            offset = max(offset, -sample_num * setting.sample_num_clip)
            offset = min(offset, sample_num * setting.sample_num_clip)
            sample_num_train = sample_num + offset
            xforms_np, rotations_np = pf.get_xforms(
                batch_size_train,
                rotation_range=rotation_range,
                scaling_range=scaling_range,
                order=setting.rotation_order)
            sess.run(reset_metrics_op)
            sess.run(
                [
                    last_layer_train_op, loss_mean_update_op,
                    t_1_acc_update_op, t_1_per_class_acc_update_op
                ],
                feed_dict={
                    pts_fts:
                    points_batch,
                    indices:
                    pf.get_indices(batch_size_train, sample_num_train,
                                   points_num_batch),
                    xforms:
                    xforms_np,
                    rotations:
                    rotations_np,
                    jitter_range:
                    np.array([jitter]),
                    labels_seg:
                    labels_batch,
                    labels_weights:
                    weights_batch,
                    is_training:
                    True,
                })
            if batch_idx_train % 10 == 0:
                loss, t_1_acc, t_1_per_class_acc, summaries = sess.run([
                    loss_mean_op, t_1_acc_op, t_1_per_class_acc_op,
                    summaries_op
                ])
                summary_writer.add_summary(summaries, batch_idx_train)
                # tqdm.write('{}-[Train]-Iter: {:06d}  Loss: {:.4f}  T-1 Acc: {:.4f}  T-1 mAcc: {:.4f}'
                #            .format(datetime.now(), batch_idx_train, loss, t_1_acc, t_1_per_class_acc))
                sys.stdout.flush()
            ######################################################################
        print('{}-Done!'.format(datetime.now()))