def evaluate(num_votes):
    is_training = False
     
    with tf.device('/gpu:'+str(GPU_INDEX)):
        xforms = tf.placeholder(tf.float32, shape=(None, 3, 3), name="xforms")
        rotations = tf.placeholder(tf.float32, shape=(None, 3, 3), name="rotations")
        jitter_range = tf.placeholder(tf.float32, shape=(1), name="jitter_range")
        global_step = tf.Variable(0, trainable=False, name='global_step')
        is_training_pl = tf.placeholder(tf.bool, name='is_training')

        pointclouds_pl = tf.placeholder(tf.float32, shape=(BATCH_SIZE, NUM_POINT, 3), name='data_train')
        labels_pl = tf.placeholder(tf.int32, shape=(BATCH_SIZE), name='label_train')

        points_augmented = pf.augment(pointclouds_pl, xforms, jitter_range)
        net = MODEL.Net(points=points_augmented, features=None, is_training=is_training_pl, setting=setting)
        # net = MODEL.Net(points=pointclouds_pl, features=None, is_training=is_training_pl, setting=setting)
        logits = net.logits
        probs = tf.nn.softmax(logits, name='probs')
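        # logits carries a per-output-point axis (tf.shape(logits)[1]), so the per-cloud
        # labels are expanded and tiled along that axis before the sparse cross-entropy below.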
        labels_2d = tf.expand_dims(labels_pl, axis=-1, name='labels_2d')
        labels_tile = tf.tile(labels_2d, (1, tf.shape(logits)[1]), name='labels_tile')
        loss_op = tf.reduce_mean(tf.losses.sparse_softmax_cross_entropy(labels=labels_tile, logits=logits))    
        # Add ops to save and restore all the variables.
        saver = tf.train.Saver()
        
    # Create a session
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    config.log_device_placement = False
    sess = tf.Session(config=config)

    # Restore variables from disk.
    saver.restore(sess, MODEL_PATH)
    log_string("Model restored.")

    ops = {'pointclouds_pl': pointclouds_pl,
           'labels_pl': labels_pl,
           'is_training_pl': is_training_pl,
           'pred': probs,
           'loss': loss_op,
           'xforms': xforms,
           'rotations': rotations,
           'jitter_range': jitter_range}

    eval_one_epoch(sess, ops, num_votes)
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--path', '-t', help='Path to data', required=True)
    parser.add_argument('--path_val', '-v', help='Path to validation data')
    parser.add_argument('--load_ckpt',
                        '-l',
                        help='Path to a check point file for load')
    parser.add_argument(
        '--save_folder',
        '-s',
        help='Path to folder for saving check points and summary',
        required=True)
    parser.add_argument('--model', '-m', help='Model to use', required=True)
    parser.add_argument('--setting',
                        '-x',
                        help='Setting to use',
                        required=True)
    args = parser.parse_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"

    time_string = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    root_folder = os.path.join(
        args.save_folder,
        '%s_%s_%s_%d' % (args.model, args.setting, time_string, os.getpid()))
    if not os.path.exists(root_folder):
        os.makedirs(root_folder)

    sys.stdout = open(os.path.join(root_folder, 'log.txt'), 'w')

    print('PID:', os.getpid())

    print(args)

    model = importlib.import_module(args.model)
    setting_path = os.path.join(os.path.dirname(__file__), args.model)
    sys.path.append(setting_path)
    setting = importlib.import_module(args.setting)

    num_epochs = setting.num_epochs
    batch_size = setting.batch_size
    sample_num = setting.sample_num
    point_num = 2048
    rotation_range = setting.rotation_range
    scaling_range = setting.scaling_range
    jitter = setting.jitter
    pool_setting_train = None if not hasattr(
        setting, 'pool_setting_train') else setting.pool_setting_train

    # Prepare inputs
    print('{}-Preparing datasets...'.format(datetime.now()))
    sys.stdout.flush()
    read_path, write_path, num_samples = data_utils.read_path(args.path)

    num_train = num_samples

    print('{}-{:d} training samples.'.format(datetime.now(), num_samples))
    sys.stdout.flush()

    ######################################################################
    # Placeholders
    indices = tf.placeholder(tf.int32, shape=(None, None, 2), name="indices")
    xforms = tf.placeholder(tf.float32, shape=(None, 3, 3), name="xforms")
    rotations = tf.placeholder(tf.float32,
                               shape=(None, 3, 3),
                               name="rotations")
    jitter_range = tf.placeholder(tf.float32, shape=(1), name="jitter_range")
    global_step = tf.Variable(0, trainable=False, name='global_step')
    is_training = tf.placeholder(tf.bool, name='is_training')

    weight_train_placeholder = tf.placeholder(tf.float32,
                                              shape=(batch_size),
                                              name="weight")
    ### add weight
    data_train_placeholder = tf.placeholder(tf.float32,
                                            shape=(batch_size, point_num, 6),
                                            name='data_train')
    label_train_placeholder = tf.placeholder(tf.int64,
                                             shape=(batch_size),
                                             name='label_train')
    ########################################################################
    batch_num_per_epoch = num_train // batch_size

    print('{}-{:d} training batches per_epoch.'.format(datetime.now(),
                                                       batch_num_per_epoch))
    sys.stdout.flush()

    pts_fts_sampled = tf.gather_nd(data_train_placeholder,
                                   indices=indices,
                                   name='pts_fts_sampled')
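    # indices holds (cloud index, point index) pairs of shape (batch_size, sample_num, 2),
    # produced by pf.get_indices, so the gather_nd above draws sample_num points per cloud.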
    features_augmented = None
    if setting.data_dim > 3:
        points_sampled, features_sampled = tf.split(
            pts_fts_sampled, [3, setting.data_dim - 3],
            axis=-1,
            name='split_points_features')
        if setting.use_extra_features:
            if setting.with_normal_feature:
                if setting.data_dim < 6:
                    print('Only 3D normals are supported!')
                    exit()
                elif setting.data_dim == 6:
                    features_augmented = pf.augment(features_sampled,
                                                    rotations)
                else:
                    normals, rest = tf.split(features_sampled,
                                             [3, setting.data_dim - 6],
                                             axis=-1)
                    normals_augmented = pf.augment(normals, rotations)
                    features_augmented = tf.concat([normals_augmented, rest],
                                                   axis=-1)
            else:
                features_augmented = features_sampled
    else:
        points_sampled = pts_fts_sampled
    points_augmented = pf.augment(points_sampled, xforms, jitter_range)

    net = model.Net(points=points_augmented,
                    features=features_augmented,
                    is_training=is_training,
                    setting=setting)
    logits = net.logits
    feature = net.fc_layers[-1]
    probs = tf.nn.softmax(logits, name='probs')
    predictions = tf.argmax(probs,
                            axis=-1,
                            name='predictions',
                            output_type=tf.int32)
    predictions = tf.squeeze(predictions)

    labels_2d = tf.expand_dims(label_train_placeholder,
                               axis=-1,
                               name='labels_2d')
    labels_tile = tf.tile(labels_2d, (1, tf.shape(logits)[1]),
                          name='labels_tile')
    # loss_op = tf.losses.sparse_softmax_cross_entropy(labels=labels_tile, logits=logits)
    weights_2d = tf.expand_dims(weight_train_placeholder,
                                axis=-1,
                                name='weights_2d')
    loss_op = tf.losses.sparse_softmax_cross_entropy(labels=labels_tile,
                                                     logits=logits,
                                                     weights=weights_2d)

    lr_exp_op = tf.train.exponential_decay(setting.learning_rate_base,
                                           global_step,
                                           setting.decay_steps,
                                           setting.decay_rate,
                                           staircase=True)

    lr_clip_op = tf.maximum(lr_exp_op, setting.learning_rate_min)
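    # Staircase exponential decay clipped from below:
    # lr = max(learning_rate_base * decay_rate ** floor(step / decay_steps), learning_rate_min).
    # For example (illustrative values only): base=0.01, decay_rate=0.7, decay_steps=8000
    # gives lr = 0.01 * 0.7**2 = 0.0049 at step 20000.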

    _ = tf.summary.scalar('learning_rate', tensor=lr_clip_op)

    reg_loss = setting.weight_decay * tf.losses.get_regularization_loss()

    if setting.optimizer == 'adam':
        optimizer = tf.train.AdamOptimizer(learning_rate=lr_clip_op,
                                           epsilon=setting.epsilon)
    elif setting.optimizer == 'momentum':
        optimizer = tf.train.MomentumOptimizer(learning_rate=lr_clip_op,
                                               momentum=setting.momentum,
                                               use_nesterov=True)

    train_op = optimizer.minimize(loss_op + reg_loss, global_step=global_step)

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())

    saver = tf.train.Saver(max_to_keep=None)

    folder_ckpt = os.path.join(root_folder, 'ckpts')
    if not os.path.exists(folder_ckpt):
        os.makedirs(folder_ckpt)

    folder_summary = os.path.join(root_folder, 'summary')
    if not os.path.exists(folder_summary):
        os.makedirs(folder_summary)

    parameter_num = np.sum(
        [np.prod(v.shape.as_list()) for v in tf.trainable_variables()])
    print('{}-Parameter number: {:d}.'.format(datetime.now(), parameter_num))
    sys.stdout.flush()
    with tf.Session() as sess:
        summary_writer = tf.summary.FileWriter(folder_summary, sess.graph)
        sess.run(init_op)
        # Load the model
        if args.load_ckpt is not None:
            saver.restore(sess, args.load_ckpt)
            print('{}-Checkpoint loaded from {}!'.format(
                datetime.now(), args.load_ckpt))
        print('total-[Train]-Iter: ', num_epochs)
        sys.stdout.flush()
        cloud_features = []
        for batch_idx_train in range(batch_num_per_epoch):
            dataset_train = []
            for i in range(batch_size):
                k = batch_idx_train * batch_size + i
                count = 0
                data = []
                with open(read_path[k]) as fpts:
                    while 1:
                        line = fpts.readline()
                        if not line:
                            break
                        values = [float(v) for v in line.split(' ')]
                        data.append(np.array(values))
                        count += 1
                    data = np.array(data)
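                    # Recenter XYZ on the bounding-box midpoint and map RGB from [0, 255]
                    # to [-0.5, 0.5] so every cloud is normalized the same way.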
                    trans_x = (min(data[:, 0]) + max(data[:, 0])) / 2
                    trans_y = (min(data[:, 1]) + max(data[:, 1])) / 2
                    trans_z = (min(data[:, 2]) + max(data[:, 2])) / 2
                    data[:, 3] = data[:, 3] / 255
                    data[:, 4] = data[:, 4] / 255
                    data[:, 5] = data[:, 5] / 255
                    data = data - [trans_x, trans_y, trans_z, 0.5, 0.5, 0.5]
                    if (count >= 2048):
                        index = np.random.choice(count,
                                                 size=2048,
                                                 replace=False)
                        # index = random.sample(range(0, count), 2048)
                        dataset = data[index, :]
                    else:
                        # k = random.sample(range(0, count), count)
                        index = np.random.choice(count,
                                                 size=2048,
                                                 replace=True)
                        dataset = data[index, :]
                    dataset_train.append(dataset)
            data_batch = np.array(dataset_train)
            ######################################################################
            # Feature extraction forward pass
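            # Perturb the number of sampled points with a clipped Gaussian offset and draw
            # fresh rotation/scaling transforms for this batch before the forward pass.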
            offset = int(
                random.gauss(0, sample_num * setting.sample_num_variance))
            offset = max(offset, -sample_num * setting.sample_num_clip)
            offset = min(offset, sample_num * setting.sample_num_clip)
            sample_num_train = sample_num + offset
            xforms_np, rotations_np = pf.get_xforms(
                batch_size,
                rotation_range=rotation_range,
                scaling_range=scaling_range,
                order=setting.rotation_order)
            cloud_feature = sess.run(
                [feature],
                feed_dict={
                    data_train_placeholder:
                    data_batch,
                    indices:
                    pf.get_indices(batch_size, sample_num_train, point_num,
                                   pool_setting_train),
                    xforms:
                    xforms_np,
                    rotations:
                    rotations_np,
                    jitter_range:
                    np.array([jitter]),
                    is_training:
                    True,
                })
            cloud_feature = np.array(cloud_feature)
            cloud_feature = cloud_feature.reshape((batch_size, -1))
            for i in range(batch_size):
                wr_index = batch_idx_train * batch_size + i
                np.savetxt(write_path[wr_index],
                           cloud_feature[i],
                           fmt='%.6e',
                           newline=' ')
                print("{} has writed".format(write_path[wr_index]))

        sys.stdout.flush()

        print('{}-Done!'.format(datetime.now()))
def train():
    with tf.Graph().as_default():
        with tf.device('/gpu:'+str(GPU_INDEX)):
            # Placeholders
            xforms = tf.placeholder(tf.float32, shape=(None, 3, 3), name="xforms")
            rotations = tf.placeholder(tf.float32, shape=(None, 3, 3), name="rotations")
            jitter_range = tf.placeholder(tf.float32, shape=(1), name="jitter_range")
            global_step = tf.Variable(0, trainable=False, name='global_step')
            is_training_pl = tf.placeholder(tf.bool, name='is_training')

            pointclouds_pl = tf.placeholder(tf.float32, shape=(batch_size, sample_num, 3), name='data_train')
            labels_pl = tf.placeholder(tf.int32, shape=(batch_size), name='label_train')

            points_augmented = pf.augment(pointclouds_pl, xforms, jitter_range)
            net = model.Net(points=points_augmented, features=None, is_training=is_training_pl, setting=setting)
            logits = net.logits
            probs = tf.nn.softmax(logits, name='probs')
            predictions = tf.argmax(probs, axis=-1, name='predictions')

            labels_2d = tf.expand_dims(labels_pl, axis=-1, name='labels_2d')
            labels_tile = tf.tile(labels_2d, (1, tf.shape(logits)[1]), name='labels_tile')
            loss_op = tf.reduce_mean(tf.losses.sparse_softmax_cross_entropy(labels=labels_tile, logits=logits))

            tf.summary.scalar('loss', loss_op)
            # with tf.name_scope('metrics'):
            #     loss_mean_op, loss_mean_update_op = tf.metrics.mean(loss_op)
            #     t_1_acc_op, t_1_acc_update_op = tf.metrics.accuracy(labels_tile, predictions)
            #     t_1_per_class_acc_op, t_1_per_class_acc_update_op = tf.metrics.mean_per_class_accuracy(labels_tile,
            #                                                                                            predictions,
            #                                                                                            setting.num_class)
            # reset_metrics_op = tf.variables_initializer([var for var in tf.local_variables()
            #                                              if var.name.split('/')[0] == 'metrics'])

            # _ = tf.summary.scalar('loss/train', tensor=loss_mean_op, collections=['train'])
            # _ = tf.summary.scalar('t_1_acc/train', tensor=t_1_acc_op, collections=['train'])
            # _ = tf.summary.scalar('t_1_per_class_acc/train', tensor=t_1_per_class_acc_op, collections=['train'])

            # _ = tf.summary.scalar('loss/val', tensor=loss_mean_op, collections=['val'])
            # _ = tf.summary.scalar('t_1_acc/val', tensor=t_1_acc_op, collections=['val'])
            # _ = tf.summary.scalar('t_1_per_class_acc/val', tensor=t_1_per_class_acc_op, collections=['val'])

            lr_exp_op = tf.train.exponential_decay(setting.learning_rate_base, global_step, setting.decay_steps,
                                                   setting.decay_rate, staircase=True)
            lr_clip_op = tf.maximum(lr_exp_op, setting.learning_rate_min)
            _ = tf.summary.scalar('learning_rate', tensor=lr_clip_op, collections=['train'])
            reg_loss = setting.weight_decay * tf.losses.get_regularization_loss()
            if setting.optimizer == 'adam':
                optimizer = tf.train.AdamOptimizer(learning_rate=lr_clip_op, epsilon=setting.epsilon)
            elif setting.optimizer == 'momentum':
                optimizer = tf.train.MomentumOptimizer(learning_rate=lr_clip_op, momentum=setting.momentum, use_nesterov=True)
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                train_op = optimizer.minimize(loss_op + reg_loss, global_step=global_step)

        # Create a session
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        config.allow_soft_placement = True
        config.log_device_placement = False
        sess = tf.Session(config=config)

        init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())

        saver = tf.train.Saver(max_to_keep=None)

        # backup all code
        # code_folder = os.path.abspath(os.path.dirname(__file__))
        # shutil.copytree(code_folder, os.path.join(root_folder)

        folder_ckpt = root_folder
        # if not os.path.exists(folder_ckpt):
        #     os.makedirs(folder_ckpt)

        folder_summary = os.path.join(root_folder, 'summary')
        if not os.path.exists(folder_summary):
            os.makedirs(folder_summary)          

        parameter_num = np.sum([np.prod(v.shape.as_list()) for v in tf.trainable_variables()])
        print('{}-Parameter number: {:d}.'.format(datetime.now(), parameter_num))

        sess.run(init_op)

        # saver.restore(sess, os.path.join(folder_ckpt, "model.ckpt"))
        # log_string("Model restored.")        

        merged = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(os.path.join(folder_summary, 'train'),
                                  sess.graph)
        test_writer = tf.summary.FileWriter(os.path.join(folder_summary, 'test'))

        ops = {'pointclouds_pl': pointclouds_pl,
               'labels_pl': labels_pl,
               'is_training_pl': is_training_pl,
               'pred': probs,
               'loss': loss_op,
               'train_op': train_op,
               'merged': merged,
               'step': global_step,
               'xforms': xforms,
               'rotations': rotations,
               'jitter_range': jitter_range}

        for epoch in range(num_epochs):
            log_string('**** EPOCH %03d ****' % (epoch))
            sys.stdout.flush()
             
            train_one_epoch(sess, ops, train_writer)
            eval_one_epoch(sess, ops, test_writer)
            
            # Save the variables to disk.
            # if epoch % 10 == 0:
            save_path = saver.save(sess, os.path.join(folder_ckpt, "model.ckpt"))
            log_string("Model saved in file: %s" % save_path)        
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--category',
                        '-c',
                        help='category name',
                        required=True)
    parser.add_argument('--level',
                        '-l',
                        type=int,
                        help='level id',
                        required=True)
    parser.add_argument('--load_ckpt',
                        '-k',
                        help='Path to a check point file for load')
    parser.add_argument('--model', '-m', help='Model to use', required=True)
    parser.add_argument('--setting',
                        '-x',
                        help='Setting to use',
                        required=True)
    args = parser.parse_args()

    args.save_folder = 'exps_results/'
    args.data_folder = '../../data/sem_seg_h5/'

    root_folder = os.path.join(
        args.save_folder,
        '%s_%s_%s_%d' % (args.model, args.setting, args.category, args.level))
    if os.path.exists(root_folder):
        print('ERROR: folder %s exists! Please check and delete it!' % root_folder)
        exit(1)
    os.makedirs(root_folder)

    flog = open(os.path.join(root_folder, 'log.txt'), 'w')

    def printout(d):
        flog.write(str(d) + '\n')
        print(d)

    printout('PID: %s' % os.getpid())

    printout(args)

    model = importlib.import_module(args.model)
    setting_path = os.path.join(os.path.dirname(__file__), args.model)
    sys.path.append(setting_path)
    setting = importlib.import_module(args.setting)

    num_epochs = setting.num_epochs
    batch_size = setting.batch_size
    sample_num = setting.sample_num
    step_val = setting.step_val
    rotation_range = setting.rotation_range
    rotation_range_val = setting.rotation_range_val
    scaling_range = setting.scaling_range
    scaling_range_val = setting.scaling_range_val
    jitter = setting.jitter
    jitter_val = setting.jitter_val

    args.filelist = os.path.join(args.data_folder,
                                 '%s-%d' % (args.category, args.level),
                                 'train_files.txt')
    args.filelist_val = os.path.join(args.data_folder,
                                     '%s-%d' % (args.category, args.level),
                                     'val_files.txt')

    # Load current category + level statistics
    with open(
            '../../stats/after_merging_label_ids/%s-level-%d.txt' %
        (args.category, args.level), 'r') as fin:
        setting.num_class = len(fin.readlines()) + 1  # with "other"
        printout('NUM CLASS: %d' % setting.num_class)

    label_weights_list = [1.0] * setting.num_class
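    # All classes are weighted equally here; weights_batch below is built by indexing this
    # list with each point's label, giving a per-point loss weight.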

    # Prepare inputs
    printout('{}-Preparing datasets...'.format(datetime.now()))
    data_train, data_num_train, label_train = data_utils.load_seg(
        args.filelist)
    data_val, data_num_val, label_val = data_utils.load_seg(args.filelist_val)

    # shuffle
    data_train, data_num_train, label_train = \
        data_utils.grouped_shuffle([data_train, data_num_train, label_train])

    num_train = data_train.shape[0]
    point_num = data_train.shape[1]
    num_val = data_val.shape[0]
    printout('{}-{:d}/{:d} training/validation samples.'.format(
        datetime.now(), num_train, num_val))
    batch_num = (num_train * num_epochs + batch_size - 1) // batch_size
    train_batch = num_train // batch_size
    printout('{}-{:d} training batches.'.format(datetime.now(), batch_num))
    batch_num_val = math.ceil(num_val / batch_size)
    printout('{}-{:d} testing batches per test.'.format(
        datetime.now(), batch_num_val))

    ######################################################################
    # Placeholders
    indices = tf.placeholder(tf.int32, shape=(None, None, 2), name="indices")
    xforms = tf.placeholder(tf.float32, shape=(None, 3, 3), name="xforms")
    rotations = tf.placeholder(tf.float32,
                               shape=(None, 3, 3),
                               name="rotations")
    jitter_range = tf.placeholder(tf.float32, shape=(1), name="jitter_range")
    global_step = tf.Variable(0, trainable=False, name='global_step')
    is_training = tf.placeholder(tf.bool, name='is_training')

    pts_fts = tf.placeholder(tf.float32,
                             shape=(None, point_num, setting.data_dim),
                             name='pts_fts')
    labels_seg = tf.placeholder(tf.int64,
                                shape=(None, point_num),
                                name='labels_seg')
    labels_weights = tf.placeholder(tf.float32,
                                    shape=(None, point_num),
                                    name='labels_weights')

    ######################################################################
    pts_fts_sampled = tf.gather_nd(pts_fts,
                                   indices=indices,
                                   name='pts_fts_sampled')
    features_augmented = None
    if setting.data_dim > 3:
        points_sampled, features_sampled = tf.split(
            pts_fts_sampled, [3, setting.data_dim - 3],
            axis=-1,
            name='split_points_features')
        if setting.use_extra_features:
            if setting.with_normal_feature:
                if setting.data_dim < 6:
                    printout('Only 3D normals are supported!')
                    exit()
                elif setting.data_dim == 6:
                    features_augmented = pf.augment(features_sampled,
                                                    rotations)
                else:
                    normals, rest = tf.split(features_sampled,
                                             [3, setting.data_dim - 6],
                                             axis=-1)
                    normals_augmented = pf.augment(normals, rotations)
                    features_augmented = tf.concat([normals_augmented, rest],
                                                   axis=-1)
            else:
                features_augmented = features_sampled
    else:
        points_sampled = pts_fts_sampled
    points_augmented = pf.augment(points_sampled, xforms, jitter_range)

    labels_sampled = tf.gather_nd(labels_seg,
                                  indices=indices,
                                  name='labels_sampled')
    labels_weights_sampled = tf.gather_nd(labels_weights,
                                          indices=indices,
                                          name='labels_weight_sampled')

    net = model.Net(points_augmented, features_augmented, is_training, setting)
    logits = net.logits
    probs = tf.nn.softmax(logits, name='probs')
    predictions = tf.argmax(probs, axis=-1, name='predictions')

    loss_op = tf.losses.sparse_softmax_cross_entropy(
        labels=labels_sampled, logits=logits, weights=labels_weights_sampled)

    with tf.name_scope('metrics'):
        loss_mean_op, loss_mean_update_op = tf.metrics.mean(loss_op)
        t_1_acc_op, t_1_acc_update_op = tf.metrics.accuracy(
            labels_sampled, predictions, weights=labels_weights_sampled)
        t_1_per_class_acc_op, t_1_per_class_acc_update_op = \
            tf.metrics.mean_per_class_accuracy(labels_sampled, predictions, setting.num_class,
                                               weights=labels_weights_sampled)
    reset_metrics_op = tf.variables_initializer([
        var for var in tf.local_variables()
        if var.name.split('/')[0] == 'metrics'
    ])
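    # The tf.metrics ops are streaming: the *_update_op ops accumulate statistics across
    # batches, and reset_metrics_op reinitializes the 'metrics' local variables so each
    # validation pass (or training log interval) starts from a clean slate.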

    _ = tf.summary.scalar('loss/train',
                          tensor=loss_mean_op,
                          collections=['train'])
    _ = tf.summary.scalar('t_1_acc/train',
                          tensor=t_1_acc_op,
                          collections=['train'])
    _ = tf.summary.scalar('t_1_per_class_acc/train',
                          tensor=t_1_per_class_acc_op,
                          collections=['train'])

    _ = tf.summary.scalar('loss/val', tensor=loss_mean_op, collections=['val'])
    _ = tf.summary.scalar('t_1_acc/val',
                          tensor=t_1_acc_op,
                          collections=['val'])
    _ = tf.summary.scalar('t_1_per_class_acc/val',
                          tensor=t_1_per_class_acc_op,
                          collections=['val'])

    lr_exp_op = tf.train.exponential_decay(setting.learning_rate_base,
                                           global_step,
                                           setting.decay_steps,
                                           setting.decay_rate,
                                           staircase=True)
    lr_clip_op = tf.maximum(lr_exp_op, setting.learning_rate_min)
    _ = tf.summary.scalar('learning_rate',
                          tensor=lr_clip_op,
                          collections=['train'])
    reg_loss = setting.weight_decay * tf.losses.get_regularization_loss()
    if setting.optimizer == 'adam':
        optimizer = tf.train.AdamOptimizer(learning_rate=lr_clip_op,
                                           epsilon=setting.epsilon)
    elif setting.optimizer == 'momentum':
        optimizer = tf.train.MomentumOptimizer(learning_rate=lr_clip_op,
                                               momentum=setting.momentum,
                                               use_nesterov=True)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = optimizer.minimize(loss_op + reg_loss,
                                      global_step=global_step)

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())

    saver = tf.train.Saver(max_to_keep=None)

    folder_ckpt = os.path.join(root_folder, 'ckpts')
    if not os.path.exists(folder_ckpt):
        os.makedirs(folder_ckpt)

    folder_summary = os.path.join(root_folder, 'summary')
    if not os.path.exists(folder_summary):
        os.makedirs(folder_summary)

    parameter_num = np.sum(
        [np.prod(v.shape.as_list()) for v in tf.trainable_variables()])
    printout('{}-Parameter number: {:d}.'.format(datetime.now(),
                                                 parameter_num))

    # Create a session
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    config.log_device_placement = False
    sess = tf.Session(config=config)

    summaries_op = tf.summary.merge_all('train')
    summaries_val_op = tf.summary.merge_all('val')
    summary_writer = tf.summary.FileWriter(folder_summary, sess.graph)

    sess.run(init_op)

    # Load the model
    if args.load_ckpt is not None:
        saver.restore(sess, args.load_ckpt)
        printout('{}-Checkpoint loaded from {}!'.format(
            datetime.now(), args.load_ckpt))

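    # Single loop over all training iterations; a checkpoint plus a full validation pass
    # is run every 10 epochs' worth of batches (and at the final iteration).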
    for batch_idx_train in range(batch_num):
        if (batch_idx_train % (10 * train_batch) == 0 and (batch_idx_train != 0 or args.load_ckpt is not None)) \
                or batch_idx_train == batch_num - 1:
            ######################################################################
            # Validation
            filename_ckpt = os.path.join(folder_ckpt, 'iter')
            saver.save(sess, filename_ckpt, global_step=global_step)
            printout('{}-Checkpoint saved to {}!'.format(
                datetime.now(), filename_ckpt))

            sess.run(reset_metrics_op)
            for batch_val_idx in range(batch_num_val):
                start_idx = batch_size * batch_val_idx
                end_idx = min(start_idx + batch_size, num_val)
                batch_size_val = end_idx - start_idx
                points_batch = data_val[start_idx:end_idx, ...]
                points_num_batch = data_num_val[start_idx:end_idx, ...]
                labels_batch = label_val[start_idx:end_idx, ...]
                weights_batch = np.array(label_weights_list)[labels_batch]

                xforms_np, rotations_np = pf.get_xforms(
                    batch_size_val,
                    rotation_range=rotation_range_val,
                    scaling_range=scaling_range_val,
                    order=setting.rotation_order)
                sess.run(
                    [
                        loss_mean_update_op, t_1_acc_update_op,
                        t_1_per_class_acc_update_op
                    ],
                    feed_dict={
                        pts_fts:
                        points_batch,
                        indices:
                        pf.get_indices(batch_size_val, sample_num,
                                       points_num_batch),
                        xforms:
                        xforms_np,
                        rotations:
                        rotations_np,
                        jitter_range:
                        np.array([jitter_val]),
                        labels_seg:
                        labels_batch,
                        labels_weights:
                        weights_batch,
                        is_training:
                        False,
                    })

            loss_val, t_1_acc_val, t_1_per_class_acc_val, summaries_val = sess.run(
                [
                    loss_mean_op, t_1_acc_op, t_1_per_class_acc_op,
                    summaries_val_op
                ])
            summary_writer.add_summary(summaries_val, batch_idx_train)
            printout(
                '{}-[Val  ]-Average:      Loss: {:.4f}  T-1 Acc: {:.4f}  T-1 mAcc: {:.4f}'
                .format(datetime.now(), loss_val, t_1_acc_val,
                        t_1_per_class_acc_val))
            flog.flush()
            ######################################################################

        ######################################################################
        # Training
        start_idx = (batch_size * batch_idx_train) % num_train
        end_idx = min(start_idx + batch_size, num_train)
        batch_size_train = end_idx - start_idx
        points_batch = data_train[start_idx:end_idx, ...]
        points_num_batch = data_num_train[start_idx:end_idx, ...]
        labels_batch = label_train[start_idx:end_idx, ...]
        weights_batch = np.array(label_weights_list)[labels_batch]

        if start_idx + batch_size_train == num_train:
            data_train, data_num_train, label_train = \
                data_utils.grouped_shuffle([data_train, data_num_train, label_train])

        offset = int(random.gauss(0, sample_num * setting.sample_num_variance))
        offset = max(offset, -sample_num * setting.sample_num_clip)
        offset = min(offset, sample_num * setting.sample_num_clip)
        sample_num_train = sample_num + offset
        xforms_np, rotations_np = pf.get_xforms(batch_size_train,
                                                rotation_range=rotation_range,
                                                scaling_range=scaling_range,
                                                order=setting.rotation_order)
        sess.run(reset_metrics_op)
        sess.run(
            [
                train_op, loss_mean_update_op, t_1_acc_update_op,
                t_1_per_class_acc_update_op
            ],
            feed_dict={
                pts_fts:
                points_batch,
                indices:
                pf.get_indices(batch_size_train, sample_num_train,
                               points_num_batch),
                xforms:
                xforms_np,
                rotations:
                rotations_np,
                jitter_range:
                np.array([jitter]),
                labels_seg:
                labels_batch,
                labels_weights:
                weights_batch,
                is_training:
                True,
            })
        if batch_idx_train % 10 == 0:
            loss, t_1_acc, t_1_per_class_acc, summaries = sess.run(
                [loss_mean_op, t_1_acc_op, t_1_per_class_acc_op, summaries_op])
            summary_writer.add_summary(summaries, batch_idx_train)
            printout(
                '{}-[Train]-Iter: {:06d}  Loss: {:.4f}  T-1 Acc: {:.4f}  T-1 mAcc: {:.4f}'
                .format(datetime.now(), batch_idx_train, loss, t_1_acc,
                        t_1_per_class_acc))
            flog.flush()
        ######################################################################
    printout('{}-Done!'.format(datetime.now()))
    flog.close()
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--load_ckpt', '-l', help='Path to a check point file for load')
    parser.add_argument('--save_folder', '-s', default='log/seg', help='Path to folder for saving check points and summary')
    parser.add_argument('--model', '-m', default='shellconv', help='Model to use')
    parser.add_argument('--setting', '-x', default='seg_s3dis', help='Setting to use')
    parser.add_argument('--log', help='Log to FILE in save folder; use - for stdout (default is log.txt)', metavar='FILE', default='log.txt')
    args = parser.parse_args()

    time_string = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    root_folder = os.path.join(args.save_folder, '%s_%s_%s' % (args.model, args.setting, time_string))
    
    if not os.path.exists(root_folder):
        os.makedirs(root_folder)

    global LOG_FOUT
    if args.log != '-':
        LOG_FOUT = open(os.path.join(root_folder, args.log), 'w')

    model = importlib.import_module(args.model)
    setting_path = os.path.join(os.path.dirname(__file__), args.model)
    sys.path.append(setting_path)
    setting = importlib.import_module(args.setting)

    num_epochs = setting.num_epochs
    batch_size = setting.batch_size
    sample_num = setting.sample_num
    step_val = setting.step_val
    label_weights_list = setting.label_weights
    rotation_range = setting.rotation_range
    rotation_range_val = setting.rotation_range_val
    scaling_range = setting.scaling_range
    scaling_range_val = setting.scaling_range_val
    jitter = setting.jitter
    jitter_val = setting.jitter_val    

    is_list_of_h5_list = data_utils.is_h5_list(setting.filelist)
    if is_list_of_h5_list:
        seg_list = [setting.filelist] # for train
    else:
        seg_list = data_utils.load_seg_list(setting.filelist)  # for train
    data_val, _, data_num_val, label_val, _ = data_utils.load_seg(setting.filelist_val)
    if data_val.shape[-1] > 3:    
        data_val = data_val[:,:,:3]  # only use the xyz coordinates
    point_num = data_val.shape[1]
    num_val = data_val.shape[0]
    batch_num_val = num_val // batch_size
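    # Note: batch_num_val uses floor division, so a trailing partial validation batch is
    # dropped here (unlike the math.ceil variant used elsewhere in this file).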

    ######################################################################
    # Placeholders
    indices = tf.placeholder(tf.int32, shape=(None, sample_num, 2), name="indices")
    xforms = tf.placeholder(tf.float32, shape=(None, 3, 3), name="xforms")
    rotations = tf.placeholder(tf.float32, shape=(None, 3, 3), name="rotations")
    jitter_range = tf.placeholder(tf.float32, shape=(1), name="jitter_range")
    global_step = tf.Variable(0, trainable=False, name='global_step')
    is_training = tf.placeholder(tf.bool, name='is_training')

    pts_fts = tf.placeholder(tf.float32, shape=(None, point_num, setting.data_dim), name='pts_fts')
    labels_seg = tf.placeholder(tf.int64, shape=(None, point_num), name='labels_seg')
    labels_weights = tf.placeholder(tf.float32, shape=(None, point_num), name='labels_weights')

    ######################################################################
    points_sampled = tf.gather_nd(pts_fts, indices=indices, name='pts_fts_sampled')
    points_augmented = pf.augment(points_sampled, xforms, jitter_range)

    labels_sampled = tf.gather_nd(labels_seg, indices=indices, name='labels_sampled')
    labels_weights_sampled = tf.gather_nd(labels_weights, indices=indices, name='labels_weight_sampled')

    bn_decay_exp_op = tf.train.exponential_decay(0.5, global_step, setting.decay_steps,
                                                 0.5, staircase=True)
    bn_decay_op = tf.minimum(0.99, 1 - bn_decay_exp_op)
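    # Batch-norm momentum schedule: bn_decay = min(0.99, 1 - 0.5 * 0.5 ** floor(step / decay_steps)),
    # i.e. it starts at 0.5 and ramps toward 0.99 as training progresses.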

    logits_op = model.get_model(points_augmented, is_training, setting.sconv_params, setting.sdconv_params, setting.fc_params, 
                            sampling=setting.sampling,
                            weight_decay=setting.weight_decay, 
                            bn_decay = bn_decay_op, 
                            part_num=setting.num_class)

    predictions = tf.argmax(logits_op, axis=-1, name='predictions')

    loss_op = tf.losses.sparse_softmax_cross_entropy(labels=labels_sampled, logits=logits_op,
                                                     weights=labels_weights_sampled)

    with tf.name_scope('metrics'):
        loss_mean_op, loss_mean_update_op = tf.metrics.mean(loss_op)
        t_1_acc_op, t_1_acc_update_op = tf.metrics.accuracy(labels_sampled, predictions, weights=labels_weights_sampled)
        t_1_per_class_acc_op, t_1_per_class_acc_update_op = \
            tf.metrics.mean_per_class_accuracy(labels_sampled, predictions, setting.num_class,
                                               weights=labels_weights_sampled)
    reset_metrics_op = tf.variables_initializer([var for var in tf.local_variables()
                                                 if var.name.split('/')[0] == 'metrics'])


    _ = tf.summary.scalar('loss/train', tensor=loss_mean_op, collections=['train'])
    _ = tf.summary.scalar('t_1_acc/train', tensor=t_1_acc_op, collections=['train'])
    _ = tf.summary.scalar('t_1_per_class_acc/train', tensor=t_1_per_class_acc_op, collections=['train'])

    _ = tf.summary.scalar('loss/val', tensor=loss_mean_op, collections=['val'])
    _ = tf.summary.scalar('t_1_acc/val', tensor=t_1_acc_op, collections=['val'])
    _ = tf.summary.scalar('t_1_per_class_acc/val', tensor=t_1_per_class_acc_op, collections=['val'])

    lr_exp_op = tf.train.exponential_decay(setting.learning_rate_base, global_step, setting.decay_steps,
                                           setting.decay_rate, staircase=True)
                                           
    lr_clip_op = tf.maximum(lr_exp_op, setting.learning_rate_min)
    _ = tf.summary.scalar('learning_rate', tensor=lr_clip_op, collections=['train'])

    reg_loss = setting.weight_decay * tf.losses.get_regularization_loss()
    if setting.optimizer == 'adam':
        optimizer = tf.train.AdamOptimizer(learning_rate=lr_clip_op)
    elif setting.optimizer == 'momentum':
        optimizer = tf.train.MomentumOptimizer(learning_rate=lr_clip_op, momentum=setting.momentum, use_nesterov=True)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = optimizer.minimize(loss_op + reg_loss, global_step=global_step)

    saver = tf.train.Saver(max_to_keep=None)

    folder_ckpt = os.path.join(root_folder, 'ckpts')
    if not os.path.exists(folder_ckpt):
        os.makedirs(folder_ckpt)

    folder_summary = os.path.join(root_folder, 'summary')
    if not os.path.exists(folder_summary):
        os.makedirs(folder_summary)

    parameter_num = np.sum([np.prod(v.shape.as_list()) for v in tf.trainable_variables()])
    print('{}-Parameter number: {:d}.'.format(datetime.now(), parameter_num))

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    config.log_device_placement = False
    with tf.Session(config=config) as sess:
        summaries_op = tf.summary.merge_all('train')
        summaries_val_op = tf.summary.merge_all('val')
        summary_writer = tf.summary.FileWriter(folder_summary, sess.graph)

        sess.run(tf.global_variables_initializer())

        # Load the model
        if args.load_ckpt is not None:
            saver.restore(sess, args.load_ckpt)
            print('{}-Checkpoint loaded from {}!'.format(datetime.now(), args.load_ckpt))
        else:
            latest_ckpt = tf.train.latest_checkpoint(folder_ckpt)
            if latest_ckpt:
                print('{}-Found checkpoint {}'.format(datetime.now(), latest_ckpt))
                saver.restore(sess, latest_ckpt)
                print('{}-Checkpoint loaded from {} (Iter {})'.format(
                    datetime.now(), latest_ckpt, sess.run(global_step)))

        best_acc = 0
        best_epoch = 0
        for epoch in range(num_epochs):
            ############################### train #######################################
            # Shuffle train files
            np.random.shuffle(seg_list)
            for file_idx_train in range(len(seg_list)):
                print('----epoch:'+str(epoch) + '--train file:' + str(file_idx_train) + '-----')
                filelist_train = seg_list[file_idx_train]
                data_train, _, data_num_train, label_train, _ = data_utils.load_seg(filelist_train)
                num_train = data_train.shape[0]
                if data_train.shape[-1] > 3:    
                    data_train = data_train[:,:,:3]
                data_train, data_num_train, label_train = \
                    data_utils.grouped_shuffle([data_train, data_num_train, label_train])
                # data_train, label_train, _ = provider.shuffle_data_seg(data_train, label_train) 

                batch_num = (num_train + batch_size - 1) // batch_size

                for batch_idx_train in range(batch_num):
                    # Training
                    start_idx = (batch_size * batch_idx_train) % num_train
                    end_idx = min(start_idx + batch_size, num_train)
                    batch_size_train = end_idx - start_idx
                    points_batch = data_train[start_idx:end_idx, ...]
                    points_num_batch = data_num_train[start_idx:end_idx, ...]
                    labels_batch = label_train[start_idx:end_idx, ...]
                    weights_batch = np.array(label_weights_list)[labels_batch]

                    offset = int(random.gauss(0, sample_num * setting.sample_num_variance))
                    offset = max(offset, -sample_num * setting.sample_num_clip)
                    offset = min(offset, sample_num * setting.sample_num_clip)
                    sample_num_train = sample_num + offset
                    xforms_np, rotations_np = pf.get_xforms(batch_size_train,
                                                            rotation_range=rotation_range,
                                                            scaling_range=scaling_range,
                                                            order=setting.rotation_order)
                    sess.run(reset_metrics_op)
                    sess.run([train_op, loss_mean_update_op, t_1_acc_update_op, t_1_per_class_acc_update_op],
                            feed_dict={
                                pts_fts: points_batch,
                                indices: pf.get_indices(batch_size_train, sample_num_train, points_num_batch),
                                xforms: xforms_np,
                                rotations: rotations_np,
                                jitter_range: np.array([jitter]),
                                labels_seg: labels_batch,
                                labels_weights: weights_batch,
                                is_training: True,
                            })
                
                loss, t_1_acc, t_1_per_class_acc, summaries, step = sess.run([loss_mean_op,
                                                                        t_1_acc_op,
                                                                        t_1_per_class_acc_op,
                                                                        summaries_op,
                                                                        global_step])
                summary_writer.add_summary(summaries, step)
                log_string('{}-[Train]-Iter: {:06d}  Loss: {:.4f}  T-1 Acc: {:.4f}  T-1 mAcc: {:.4f}'
                    .format(datetime.now(), step, loss, t_1_acc, t_1_per_class_acc))
                sys.stdout.flush()
                ######################################################################
        
            filename_ckpt = os.path.join(folder_ckpt, 'epoch')
            saver.save(sess, filename_ckpt, global_step=epoch)
            print('{}-Checkpoint saved to {}!'.format(datetime.now(), filename_ckpt))

            sess.run(reset_metrics_op)
            for batch_val_idx in range(batch_num_val):
                start_idx = batch_size * batch_val_idx
                end_idx = min(start_idx + batch_size, num_val)
                batch_size_val = end_idx - start_idx
                points_batch = data_val[start_idx:end_idx, ...]
                points_num_batch = data_num_val[start_idx:end_idx, ...]
                labels_batch = label_val[start_idx:end_idx, ...]
                weights_batch = np.array(label_weights_list)[labels_batch]

                xforms_np, rotations_np = pf.get_xforms(batch_size_val,
                                                            rotation_range=rotation_range_val,
                                                            scaling_range=scaling_range_val,
                                                            order=setting.rotation_order)
                sess.run([loss_mean_update_op, t_1_acc_update_op, t_1_per_class_acc_update_op],
                        feed_dict={
                            pts_fts: points_batch,
                            indices: pf.get_indices(batch_size_val, sample_num, points_num_batch),
                            xforms: xforms_np,
                            rotations: rotations_np,
                            jitter_range: np.array([jitter_val]),
                            labels_seg: labels_batch,
                            labels_weights: weights_batch,
                            is_training: False,
                        })
            loss_val, t_1_acc_val, t_1_per_class_acc_val, summaries_val, step = sess.run(
                [loss_mean_op, t_1_acc_op, t_1_per_class_acc_op, summaries_val_op, global_step])
            summary_writer.add_summary(summaries_val, step)

            if t_1_per_class_acc_val > best_acc:
                best_acc = t_1_per_class_acc_val
                best_epoch = epoch

            log_string('{}-[Val  ]-Average:      Loss: {:.4f}  T-1 Acc: {:.4f}  T-1 mAcc: {:.4f} best epoch: {} Current epoch: {}'
                  .format(datetime.now(), loss_val, t_1_acc_val, t_1_per_class_acc_val, best_epoch, epoch))
            sys.stdout.flush()
            ######################################################################
            
        print('{}-Done!'.format(datetime.now()))
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--filelist', '-t', help='Path to training set ground truth (.txt)', required=True)
    parser.add_argument('--filelist_val', '-v', help='Path to validation set ground truth (.txt)', required=True)
    parser.add_argument('--load_ckpt', '-l', help='Path to a check point file for load')
    parser.add_argument('--save_folder', '-s', help='Path to folder for saving check points and summary', required=True)
    parser.add_argument('--model', '-m', help='Model to use', required=True)
    parser.add_argument('--setting', '-x', help='Setting to use', required=True)
    args = parser.parse_args()

    time_string = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    root_folder = os.path.join(args.save_folder, '%s_%s_%d_%s' % (args.model, args.setting, os.getpid(), time_string))
    if not os.path.exists(root_folder):
        os.makedirs(root_folder)

    sys.stdout = open(os.path.join(root_folder, 'log.txt'), 'w')

    print('PID:', os.getpid())

    print(args)

    model = importlib.import_module(args.model)
    setting_path = os.path.join(os.path.dirname(__file__), args.model)
    sys.path.append(setting_path)
    print(setting_path)
    setting = importlib.import_module(args.setting)

    num_epochs = setting.num_epochs
    batch_size = setting.batch_size
    sample_num = setting.sample_num
    step_val = 1000
    num_parts = setting.num_parts
    label_weights_list = setting.label_weights
    scaling_range = setting.scaling_range
    scaling_range_val = setting.scaling_range_val
    jitter = setting.jitter
    jitter_val = setting.jitter_val

    # Prepare inputs
    print('{}-Preparing datasets...'.format(datetime.now()))
    
    #data_train, data_num_train, label_train = data_utils.load_seg(args.filelist)
    #data_val, data_num_val, label_val = data_utils.load_seg(args.filelist_val)
    data_train, img_train, data_num_train, label_train, label_hwl_train, label_xyz_train, \
        label_ry_train, label_ry_reg_train = data_utils.load_seg_kitti(args.filelist)
    data_val, img_val, data_num_val, label_val, label_hwl_val, label_xyz_val, \
        label_ry_val, label_ry_reg_val = data_utils.load_seg_kitti(args.filelist_val)

    # shuffle
    #data_train, data_num_train, label_train = \
    #    data_utils.grouped_shuffle([data_train, data_num_train, label_train])
    data_train, img_train, data_num_train, label_train, label_hwl_train, label_xyz_train, \
        label_ry_train, label_ry_reg_train = data_utils.grouped_shuffle(
            [data_train, img_train, data_num_train, label_train, label_hwl_train,
             label_xyz_train, label_ry_train, label_ry_reg_train])
        
    num_train = data_train.shape[0]
    point_num = data_train.shape[1]
    num_val = data_val.shape[0]

    print('{}-{:d}/{:d} training/validation samples.'.format(datetime.now(), num_train, num_val))
    batch_num = (num_train * num_epochs + batch_size - 1) // batch_size
    print('{}-{:d} training batches.'.format(datetime.now(), batch_num))
    batch_num_val = math.ceil(num_val / batch_size)
    print('{}-{:d} testing batches per test.'.format(datetime.now(), batch_num_val))

    ######################################################################
    # Placeholders
    indices = tf.placeholder(tf.int32, shape=(None, None, 2), name="indices")
    xforms = tf.placeholder(tf.float32, shape=(None, 3, 3), name="xforms")
    rotations = tf.placeholder(tf.float32, shape=(None, 3, 3), name="rotations")
    jitter_range = tf.placeholder(tf.float32, shape=(1), name="jitter_range")
    global_step = tf.Variable(0, trainable=False, name='global_step')
    is_training = tf.placeholder(tf.bool, name='is_training')

    pts_fts = tf.placeholder(tf.float32, shape=(None, point_num, setting.data_dim), name='pts_fts')
    labels_seg = tf.placeholder(tf.int32, shape=(None, point_num), name='labels_seg')
    labels_weights = tf.placeholder(tf.float32, shape=(None, point_num), name='labels_weights')

    labels_hwl = tf.placeholder(tf.float32, shape=(None, 3), name='labels_hwl')
    labels_xyz = tf.placeholder(tf.float32, shape=(None, 3), name='labels_xyz')
    labels_ry = tf.placeholder(tf.int32, shape=(None, 1), name='labels_ry')
    labels_ry_reg = tf.placeholder(tf.float32, shape=(None, 1), name='labels_ry_reg')
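    # Extra regression targets for KITTI-style 3D boxes: hwl = box dimensions, xyz = box
    # center, and ry = heading angle encoded as a classification bin (labels_ry) plus an
    # in-bin residual (labels_ry_reg).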
    ######################################################################

    # Set Inputs(points,features_sampled)
    features_sampled = None

    if setting.data_dim > 3:

        points, _, features = tf.split(
            pts_fts,
            [setting.data_format["pts_xyz"], setting.data_format["img_xy"], setting.data_format["extra_features"]],
            axis=-1,
            name='split_points_xy_features')

        if setting.use_extra_features:

            features_sampled = tf.gather_nd(features, indices=indices, name='features_sampled')

    else:
        points = pts_fts

    points_sampled = tf.gather_nd(points, indices=indices, name='points_sampled')
    points_augmented = pf.augment(points_sampled, xforms, jitter_range)

    labels_sampled = tf.gather_nd(labels_seg, indices=indices, name='labels_sampled')
    labels_weights_sampled = tf.gather_nd(labels_weights, indices=indices, name='labels_weight_sampled')
    
    # Build net
    #net = model.Net(points_augmented, features_sampled, None, None, num_parts, is_training, setting)
    net = model.Net(points_augmented, features_sampled, None, None, num_parts, is_training, setting)
    #logits, probs= net.logits, net.probs
    logits, probs, logits_hwl, logits_xyz, logits_ry, probs_ry = \
        net.logits, net.probs, net.logits_hwl, net.logits_xyz, net.logits_ry, net.probs_ry


    # Define Loss Func
    loss_op = tf.losses.sparse_softmax_cross_entropy(labels=labels_sampled, logits=logits, weights=labels_weights_sampled)
    _ = tf.summary.scalar('loss/train_seg', tensor=loss_op, collections=['train'])

    # HWL loss
    # MSE loss is used here as an example
    loss_hwl_op = tf.losses.mean_squared_error(labels_hwl, logits_hwl)  # MSE loss from the tf.losses API
    # Compute RMSE as an evaluation metric
    rmse_hwl_op = tf.sqrt(tf.reduce_mean(tf.reduce_sum(tf.square(labels_hwl - logits_hwl), -1)))

    # XYZ loss
    # MSE loss is used here as an example
    loss_xyz_op = tf.losses.mean_squared_error(labels_xyz, logits_xyz)
    # Compute RMSE as an evaluation metric
    rmse_xyz_op = tf.sqrt(tf.reduce_mean(tf.reduce_sum(tf.square(labels_xyz - logits_xyz), -1)))

    # RY loss
    # Classification uses softmax cross-entropy as an example
    loss_ry_op = tf.losses.sparse_softmax_cross_entropy(labels=labels_ry, logits=logits_ry)
    # Compute classification accuracy as an evaluation metric
    t_1_acc_ry_op = pf.top_1_accuracy(probs_ry, labels_ry)

    # 计算平均角度偏差作为评估指标
    mdis_ry_angle_op = tf.reduce_mean(tf.sqrt(tf.square(pf.ry_getangle(labels_ry,labels_ry_reg)-
                                                        pf.ry_getangle(tf.nn.top_k(probs_ry, 1)[1], tf.zeros_like(labels_ry_reg)))))

    # Log the results to TensorBoard
    _ = tf.summary.scalar('loss/train_hwl', tensor=loss_hwl_op, collections=['train'])
    _ = tf.summary.scalar('loss/train_xyz', tensor=loss_xyz_op, collections=['train'])
    _ = tf.summary.scalar('loss/train_ry', tensor=loss_ry_op, collections=['train'])
    _ = tf.summary.scalar('t_1_acc/train_rmse_hwl', tensor=rmse_hwl_op, collections=['train'])
    _ = tf.summary.scalar('t_1_acc/train_rmse_xyz', tensor=rmse_xyz_op, collections=['train'])
    _ = tf.summary.scalar('t_1_acc/train_ry', tensor=t_1_acc_ry_op, collections=['train'])
    _ = tf.summary.scalar('t_1_acc/mdis_ry_angle', tensor=mdis_ry_angle_op, collections=['train'])

    # top-1 segmentation accuracy for visualization
    t_1_acc_op = pf.top_1_accuracy(probs, labels_sampled)
    _ = tf.summary.scalar('t_1_acc/train_seg', tensor=t_1_acc_op, collections=['train'])
    # instance accuracy (points with label weight above 0.6)
    t_1_acc_instance_op = pf.top_1_accuracy(probs, labels_sampled, labels_weights_sampled, 0.6)
    _ = tf.summary.scalar('t_1_acc/train_seg_instance', tensor=t_1_acc_instance_op, collections=['train'])
    # accuracy on the remaining points (label weight below 0.6)
    t_1_acc_others_op = pf.top_1_accuracy(probs, labels_sampled, labels_weights_sampled, 0.6, "less")
    _ = tf.summary.scalar('t_1_acc/train_seg_others', tensor=t_1_acc_others_op, collections=['train'])

    loss_val_avg = tf.placeholder(tf.float32)
    _ = tf.summary.scalar('loss/val_seg', tensor=loss_val_avg, collections=['val'])

    t_1_acc_val_avg = tf.placeholder(tf.float32)
    _ = tf.summary.scalar('t_1_acc/val_seg', tensor=t_1_acc_val_avg, collections=['val'])
    t_1_acc_val_instance_avg = tf.placeholder(tf.float32)
    _ = tf.summary.scalar('t_1_acc/val_seg_instance', tensor=t_1_acc_val_instance_avg, collections=['val'])
    t_1_acc_val_others_avg = tf.placeholder(tf.float32)
    _ = tf.summary.scalar('t_1_acc/val_seg_others', tensor=t_1_acc_val_others_avg, collections=['val'])

    reg_loss = setting.weight_decay * tf.losses.get_regularization_loss()

    # For validation, average the per-batch errors over the whole validation pass to get the metrics
    loss_val_avg_hwl = tf.placeholder(tf.float32)
    loss_val_avg_xyz = tf.placeholder(tf.float32)
    loss_val_avg_ry = tf.placeholder(tf.float32)
    _ = tf.summary.scalar('loss/val_hwl', tensor=loss_val_avg_hwl, collections=['val'])
    _ = tf.summary.scalar('loss/val_xyz', tensor=loss_val_avg_xyz, collections=['val'])
    _ = tf.summary.scalar('loss/val_ry', tensor=loss_val_avg_ry, collections=['val'])
    rmse_val_hwl_avg = tf.placeholder(tf.float32)
    rmse_val_xyz_avg = tf.placeholder(tf.float32)
    t_1_acc_val_ry_avg = tf.placeholder(tf.float32)
    mdis_val_ry_angle_avg = tf.placeholder(tf.float32)
    _ = tf.summary.scalar('t_1_acc/val_rmse_hwl', tensor=rmse_val_hwl_avg, collections=['val'])
    _ = tf.summary.scalar('t_1_acc/val_rmse_xyz', tensor=rmse_val_xyz_avg, collections=['val'])
    _ = tf.summary.scalar('t_1_acc/val_ry', tensor=t_1_acc_val_ry_avg, collections=['val'])
    _ = tf.summary.scalar('t_1_acc/val_mdis_ry_angle', tensor=mdis_val_ry_angle_avg, collections=['val'])

    # lr decay
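    # Staircase exponential decay: lr = learning_rate_base * decay_rate ** (global_step // decay_steps),
    # clipped from below at learning_rate_min.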
    lr_exp_op = tf.train.exponential_decay(setting.learning_rate_base, global_step,
                                           setting.decay_steps, setting.decay_rate, staircase=True)
    lr_clip_op = tf.maximum(lr_exp_op, setting.learning_rate_min)
    _ = tf.summary.scalar('learning_rate', tensor=lr_clip_op, collections=['train'])

    # Optimizer
    if setting.optimizer == 'adam':
        optimizer = tf.train.AdamOptimizer(learning_rate=lr_clip_op, epsilon=setting.epsilon)
    elif setting.optimizer == 'momentum':
        optimizer = tf.train.MomentumOptimizer(learning_rate=lr_clip_op, momentum=0.9, use_nesterov=True)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

    with tf.control_dependencies(update_ops):

        #train_op = optimizer.minimize(loss_op + reg_loss, global_step=global_step)
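        # Joint objective: segmentation loss + weight-decay regularization
        # + HWL / XYZ regression losses + yaw-bin classification loss.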
        train_op = optimizer.minimize(loss_op + reg_loss + loss_hwl_op + loss_xyz_op + loss_ry_op, global_step=global_step)

    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())

    saver = tf.train.Saver(max_to_keep=None)

    # backup this file, model and setting
    if not os.path.exists(os.path.join(root_folder, args.model)):
        os.makedirs(os.path.join(root_folder, args.model))
    shutil.copy(__file__, os.path.join(root_folder, os.path.basename(__file__)))
    shutil.copy(os.path.join(os.path.dirname(__file__), args.model + '.py'),
                os.path.join(root_folder, args.model + '.py'))
    shutil.copy(os.path.join(os.path.dirname(__file__), args.model.split("_")[0] + "_kitti" + '.py'),
                os.path.join(root_folder, args.model.split("_")[0] + "_kitti" + '.py'))
    shutil.copy(os.path.join(setting_path, args.setting + '.py'),
                os.path.join(root_folder, args.model, args.setting + '.py'))

    folder_ckpt = os.path.join(root_folder, 'ckpts')
    if not os.path.exists(folder_ckpt):
        os.makedirs(folder_ckpt)

    folder_summary = os.path.join(root_folder, 'summary')
    if not os.path.exists(folder_summary):
        os.makedirs(folder_summary)

    parameter_num = np.sum([np.prod(v.shape.as_list()) for v in tf.trainable_variables()])
    print('{}-Parameter number: {:d}.'.format(datetime.now(), parameter_num))

    # Session Run
    with tf.Session() as sess:
        summaries_op = tf.summary.merge_all('train')
        summaries_val_op = tf.summary.merge_all('val')
        summary_writer = tf.summary.FileWriter(folder_summary, sess.graph)

        sess.run(init_op)

        # Load the model
        if args.load_ckpt is not None:
            saver.restore(sess, args.load_ckpt)
            print('{}-Checkpoint loaded from {}!'.format(datetime.now(), args.load_ckpt))

        for batch_idx in range(batch_num):
            if (batch_idx != 0 and batch_idx % step_val == 0) or batch_idx == batch_num - 1:
                ######################################################################
                # Validation
                filename_ckpt = os.path.join(folder_ckpt, 'iter')
                saver.save(sess, filename_ckpt, global_step=global_step)
                print('{}-Checkpoint saved to {}!'.format(datetime.now(), filename_ckpt))

                losses_val = []
                t_1_accs = []
                t_1_accs_instance = []
                t_1_accs_others = []

                # Lists used to accumulate per-batch values for averaging
                losses_val_hwl = []
                losses_val_xyz = []
                losses_val_ry = []
                rmses_hwl = []
                rmses_xyz = []
                t_1_accs_ry = []
                mdises_ry_angle = []


                for batch_val_idx in range(math.ceil(num_val / batch_size)):
                    start_idx = batch_size * batch_val_idx
                    end_idx = min(start_idx + batch_size, num_val)
                    batch_size_val = end_idx - start_idx
                    points_batch = data_val[start_idx:end_idx, ...]
                    points_num_batch = data_num_val[start_idx:end_idx, ...]
                    labels_batch = label_val[start_idx:end_idx, ...]
                    weights_batch = np.array(label_weights_list)[label_val[start_idx:end_idx, ...]]

                    # Slice out the box/pose labels for this batch
                    labels_hwl_batch = label_hwl_val[start_idx:end_idx, ...]
                    labels_xyz_batch = label_xyz_val[start_idx:end_idx, ...]
                    labels_ry_batch = label_ry_val[start_idx:end_idx, ...]
                    labels_ry_reg_batch = label_ry_reg_val[start_idx:end_idx, ...]

                    xforms_np, rotations_np = pf.get_xforms(batch_size_val, scaling_range=scaling_range_val)

                    sess_op_list = [loss_op, t_1_acc_op, t_1_acc_instance_op, t_1_acc_others_op]
                    sess_op_list = sess_op_list + [loss_hwl_op, loss_xyz_op, loss_ry_op, rmse_hwl_op,
                                                   rmse_xyz_op, t_1_acc_ry_op, mdis_ry_angle_op]

                    sess_feed_dict = {pts_fts: points_batch,
                                     indices: pf.get_indices(batch_size_val, sample_num, points_num_batch),
                                     xforms: xforms_np,
                                     rotations: rotations_np,
                                     jitter_range: np.array([jitter_val]),
                                     labels_seg: labels_batch,
                                     labels_weights: weights_batch,
                                     is_training: False}
                    sess_feed_dict[labels_hwl] = labels_hwl_batch
                    sess_feed_dict[labels_xyz] = labels_xyz_batch
                    sess_feed_dict[labels_ry] = labels_ry_batch
                    sess_feed_dict[labels_ry_reg] = labels_ry_reg_batch

                    loss_val, t_1_acc_val, t_1_acc_val_instance, t_1_acc_val_others, \
                        loss_val_hwl, loss_val_xyz, loss_val_ry, rmse_hwl_val, rmse_xyz_val, \
                        t_1_acc_ry_val, mdis_ry_angle_val = sess.run(sess_op_list, feed_dict=sess_feed_dict)

                    print('{}-[Val  ]-Iter: {:06d}  Loss: {:.4f}  Loss_hwl: {:.4f}  Loss_xyz: {:.4f}  '
                          'Loss_ry: {:.4f}  Rmse_hwl: {:.4f}  Rmse_xyz: {:.4f}  T-1 Ry_Acc: {:.4f}  '
                          'Mdis_ry_angle: {:.4f}  T-1 Acc: {:.4f}'.format(
                              datetime.now(), batch_val_idx, loss_val, loss_val_hwl, loss_val_xyz,
                              loss_val_ry, rmse_hwl_val, rmse_xyz_val, t_1_acc_ry_val,
                              mdis_ry_angle_val, t_1_acc_val))

                    sys.stdout.flush()

                    losses_val.append(loss_val * batch_size_val)
                    t_1_accs.append(t_1_acc_val * batch_size_val)
                    t_1_accs_instance.append(t_1_acc_val_instance * batch_size_val)
                    t_1_accs_others.append(t_1_acc_val_others * batch_size_val)

                    # Accumulate per-batch metrics, weighted by the batch size
                    losses_val_hwl.append(loss_val_hwl * batch_size_val)
                    losses_val_xyz.append(loss_val_xyz * batch_size_val)
                    losses_val_ry.append(loss_val_ry * batch_size_val)
                    rmses_hwl.append(rmse_hwl_val * batch_size_val)
                    rmses_xyz.append(rmse_xyz_val * batch_size_val)
                    t_1_accs_ry.append(t_1_acc_ry_val * batch_size_val)
                    mdises_ry_angle.append(mdis_ry_angle_val * batch_size_val)


                loss_avg = sum(losses_val) / num_val
                t_1_acc_avg = sum(t_1_accs) / num_val
                t_1_acc_instance_avg = sum(t_1_accs_instance) / num_val
                t_1_acc_others_avg = sum(t_1_accs_others) / num_val
                loss_avg_hwl = sum(losses_val_hwl) / num_val
                loss_avg_xyz = sum(losses_val_xyz) / num_val
                loss_avg_ry = sum(losses_val_ry) / num_val
                rmse_hwl_avg = sum(rmses_hwl) / num_val
                rmse_xyz_avg = sum(rmses_xyz) / num_val
                t_1_acc_ry_avg = sum(t_1_accs_ry) / num_val
                mdis_ry_angle_avg = sum(mdises_ry_angle) / num_val

                summaries_feed_dict = { loss_val_avg: loss_avg,
                                        t_1_acc_val_avg: t_1_acc_avg,
                                        t_1_acc_val_instance_avg: t_1_acc_instance_avg,
                                        t_1_acc_val_others_avg: t_1_acc_others_avg}
                summaries_feed_dict[loss_val_avg_hwl] = loss_avg_hwl
                summaries_feed_dict[loss_val_avg_xyz] = loss_avg_xyz
                summaries_feed_dict[loss_val_avg_ry] = loss_avg_ry
                summaries_feed_dict[rmse_val_hwl_avg] = rmse_hwl_avg
                summaries_feed_dict[rmse_val_xyz_avg] = rmse_xyz_avg
                summaries_feed_dict[t_1_acc_val_ry_avg] = t_1_acc_ry_avg
                summaries_feed_dict[mdis_val_ry_angle_avg] = mdis_ry_angle_avg

                summaries_val = sess.run(summaries_val_op,feed_dict=summaries_feed_dict)
                summary_writer.add_summary(summaries_val, batch_idx)

                print('{}-[Val  ]-Average:  Loss: {:.4f}  Loss_hwl: {:.4f}  Loss_xyz: {:.4f}  '
                      'Loss_ry: {:.4f}  Rmse_hwl: {:.4f}  Rmse_xyz: {:.4f}  T-1 Ry_Acc: {:.4f}  '
                      'Mdis_ry_angle: {:.4f}  T-1 Acc: {:.4f}'.format(
                          datetime.now(), loss_avg, loss_avg_hwl, loss_avg_xyz, loss_avg_ry,
                          rmse_hwl_avg, rmse_xyz_avg, t_1_acc_ry_avg, mdis_ry_angle_avg, t_1_acc_avg))

                sys.stdout.flush()
                ######################################################################

            ######################################################################
            # Training
            start_idx = (batch_size * batch_idx) % num_train
            end_idx = min(start_idx + batch_size, num_train)
            batch_size_train = end_idx - start_idx
            points_batch = data_train[start_idx:end_idx, ...]
            
            points_num_batch = data_num_train[start_idx:end_idx, ...]
            labels_batch = label_train[start_idx:end_idx, ...]
            weights_batch = np.array(label_weights_list)[labels_batch]

            labels_hwl_batch = label_hwl_train[start_idx:end_idx, ...]
            labels_xyz_batch = label_xyz_train[start_idx:end_idx, ...]
            labels_ry_batch = label_ry_train[start_idx:end_idx, ...]
            labels_ry_reg_batch = label_ry_reg_train[start_idx:end_idx, ...]
                
            if start_idx + batch_size_train == num_train:
                data_train, img_train, data_num_train, label_train, label_hwl_train, label_xyz_train, \
                    label_ry_train, label_ry_reg_train = data_utils.grouped_shuffle(
                        [data_train, img_train, data_num_train, label_train,
                         label_hwl_train, label_xyz_train, label_ry_train, label_ry_reg_train])

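            # Randomly jitter the number of sampled points per training batch
            # (Gaussian offset, clipped to roughly +/- sample_num / 4).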
            offset = int(random.gauss(0, sample_num // 8))
            offset = max(offset, -sample_num // 4)
            offset = min(offset, sample_num // 4)
            sample_num_train = sample_num + offset
            xforms_np, rotations_np = pf.get_xforms(batch_size_train, scaling_range=scaling_range)

            sess_op_list = [train_op, loss_op, t_1_acc_op, t_1_acc_instance_op, t_1_acc_others_op, summaries_op]
            sess_op_list = sess_op_list + [loss_hwl_op, loss_xyz_op, loss_ry_op, rmse_hwl_op,
                                           rmse_xyz_op, t_1_acc_ry_op, mdis_ry_angle_op]

            sess_feed_dict = {pts_fts: points_batch,
                             indices: pf.get_indices(batch_size_train, sample_num_train, points_num_batch),
                             xforms: xforms_np,
                             rotations: rotations_np,
                             jitter_range: np.array([jitter]),
                             labels_seg: labels_batch,
                             labels_weights: weights_batch,
                             is_training: True}

            # Feed the regression/classification labels into their placeholders
            sess_feed_dict[labels_hwl] = labels_hwl_batch
            sess_feed_dict[labels_xyz] = labels_xyz_batch
            sess_feed_dict[labels_ry] = labels_ry_batch
            sess_feed_dict[labels_ry_reg] = labels_ry_reg_batch

            _, loss, t_1_acc, t_1_acc_instance, t_1_acc_others, summaries, \
                loss_hwl, loss_xyz, loss_ry, rmse_hwl, rmse_xyz, \
                t_1_acc_ry, mdis_ry_angle = sess.run(sess_op_list, feed_dict=sess_feed_dict)

            print('{}-[Train]-Iter: {:06d}  Loss_seg: {:.4f}  Loss_hwl: {:.4f}  Loss_xyz: {:.4f}  '
                  'Loss_ry: {:.4f}  Rmse_hwl: {:.4f}  Rmse_xyz: {:.4f}  T-1 Ry_Acc: {:.4f}  '
                  'Mdis_ry_angle: {:.4f}  T-1 Acc: {:.4f}'.format(
                      datetime.now(), batch_idx, loss, loss_hwl, loss_xyz, loss_ry,
                      rmse_hwl, rmse_xyz, t_1_acc_ry, mdis_ry_angle, t_1_acc))


            summary_writer.add_summary(summaries, batch_idx)
            sys.stdout.flush()
            
            ######################################################################
        print('{}-Done!'.format(datetime.now()))
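

# The loops above rely on pf.top_1_accuracy(probs, labels[, weights, threshold, compare])
# for all segmentation metrics. Below is a minimal NumPy sketch of that behaviour,
# assuming probs of shape (B, N, C), labels of shape (B, N) and optional per-point
# weights of shape (B, N); the actual pointfly op may differ in details.
import numpy as np


def top_1_accuracy_sketch(probs, labels, weights=None, threshold=None, compare='greater'):
    """Fraction of points whose arg-max prediction matches the label.

    When weights and a threshold are given, only points whose weight is above
    (compare='greater') or below (compare='less') the threshold are counted.
    """
    preds = np.argmax(probs, axis=-1)
    correct = (preds == labels).astype(np.float64)
    if weights is not None and threshold is not None:
        mask = weights > threshold if compare == 'greater' else weights < threshold
        correct = correct[mask]
    return float(correct.mean()) if correct.size else 0.0
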
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--filelist',
                        '-t',
                        help='Path to training set ground truth (.txt)',
                        required=True)
    parser.add_argument('--filelist_val',
                        '-v',
                        help='Path to validation set ground truth (.txt)',
                        required=True)
    parser.add_argument('--load_ckpt',
                        '-l',
                        help='Path to a check point file for load')
    parser.add_argument(
        '--save_folder',
        '-s',
        help='Path to folder for saving check points and summary',
        required=True)
    parser.add_argument('--model', '-m', help='Model to use', required=True)
    parser.add_argument('--setting',
                        '-x',
                        help='Setting to use',
                        required=True)
    parser.add_argument(
        '--epochs',
        help='Number of training epochs (default defined in setting)',
        type=int)
    parser.add_argument('--batch_size',
                        help='Batch size (default defined in setting)',
                        type=int)
    parser.add_argument(
        '--log',
        help=
        'Log to FILE in save folder; use - for stdout (default is log.txt)',
        metavar='FILE',
        default='log.txt')
    parser.add_argument('--no_timestamp_folder',
                        help='Dont save to timestamp folder',
                        action='store_true')
    parser.add_argument('--no_code_backup',
                        help='Dont backup code',
                        action='store_true')
    args = parser.parse_args()

    if not args.no_timestamp_folder:
        time_string = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
        root_folder = os.path.join(
            args.save_folder, '%s_%s_%s_%d' %
            (args.model, args.setting, time_string, os.getpid()))
    else:
        root_folder = args.save_folder
    if not os.path.exists(root_folder):
        os.makedirs(root_folder)

    if args.log != '-':
        sys.stdout = open(os.path.join(root_folder, args.log), 'w')

    #print('PID:', os.getpid())

    #print(args)

    model = importlib.import_module(args.model)
    setting_path = os.path.join(os.path.dirname(__file__), args.model)
    sys.path.append(setting_path)
    setting = importlib.import_module(args.setting)

    num_epochs = args.epochs or setting.num_epochs
    batch_size = args.batch_size or setting.batch_size
    sample_num = setting.sample_num
    step_val = setting.step_val
    label_weights_list = setting.label_weights
    rotation_range = setting.rotation_range
    rotation_range_val = setting.rotation_range_val
    scaling_range = setting.scaling_range
    scaling_range_val = setting.scaling_range_val
    jitter = setting.jitter
    jitter_val = setting.jitter_val

    # Prepare inputs
    #print('{}-Preparing datasets...'.format(datetime.now()))
    is_list_of_h5_list = not data_utils.is_h5_list(args.filelist)
    if is_list_of_h5_list:
        seg_list = data_utils.load_seg_list(args.filelist)
        seg_list_idx = 0
        filelist_train = seg_list[seg_list_idx]
        seg_list_idx = seg_list_idx + 1
    else:
        filelist_train = args.filelist
    data_train, _, data_num_train, label_train, _ = data_utils.load_seg(
        filelist_train)
    data_val, _, data_num_val, label_val, _ = data_utils.load_seg(
        args.filelist_val)

    # shuffle
    # data_num_train is the number of points in each point cloud
    data_train, data_num_train, label_train = \
        data_utils.grouped_shuffle([data_train, data_num_train, label_train])

    num_train = data_train.shape[0]
    point_num = data_train.shape[1]
    num_val = data_val.shape[0]
    #print('{}-{:d}/{:d} training/validation samples.'.format(datetime.now(), num_train, num_val))
    batch_num = (num_train * num_epochs + batch_size - 1) // batch_size
    #print('{}-{:d} training batches.'.format(datetime.now(), batch_num))
    batch_num_val = math.ceil(num_val / batch_size)
    #print('{}-{:d} testing batches per test.'.format(datetime.now(), batch_num_val))

    ######################################################################
    # Placeholders
    indices = tf.placeholder(tf.int32, shape=(None, None, 2), name="indices")
    xforms = tf.placeholder(tf.float32, shape=(None, 3, 3), name="xforms")
    rotations = tf.placeholder(tf.float32,
                               shape=(None, 3, 3),
                               name="rotations")
    jitter_range = tf.placeholder(tf.float32, shape=(1), name="jitter_range")
    global_step = tf.Variable(0, trainable=False, name='global_step')
    is_training = tf.placeholder(tf.bool, name='is_training')

    pts_fts = tf.placeholder(tf.float32,
                             shape=(None, point_num, setting.data_dim),
                             name='pts_fts')
    labels_seg = tf.placeholder(tf.int64,
                                shape=(None, point_num),
                                name='labels_seg')
    labels_weights = tf.placeholder(tf.float32,
                                    shape=(None, point_num),
                                    name='labels_weights')

    ######################################################################
    pts_fts_sampled = tf.gather_nd(pts_fts,
                                   indices=indices,
                                   name='pts_fts_sampled')
    features_augmented = None
    if setting.data_dim > 3:
        points_sampled, features_sampled = tf.split(
            pts_fts_sampled, [3, setting.data_dim - 3],
            axis=-1,
            name='split_points_features')
        if setting.use_extra_features:
            if setting.with_normal_feature:
                if setting.data_dim < 6:
                    print('Only 3D normals are supported!')
                    exit()
                elif setting.data_dim == 6:
                    features_augmented = pf.augment(features_sampled,
                                                    rotations)
                else:
                    normals, rest = tf.split(features_sampled,
                                             [3, setting.data_dim - 6])
                    normals_augmented = pf.augment(normals, rotations)
                    features_augmented = tf.concat([normals_augmented, rest],
                                                   axis=-1)
            else:
                features_augmented = features_sampled
    else:
        points_sampled = pts_fts_sampled
    points_augmented = pf.augment(points_sampled, xforms, jitter_range)

    labels_sampled = tf.gather_nd(labels_seg,
                                  indices=indices,
                                  name='labels_sampled')
    labels_weights_sampled = tf.gather_nd(labels_weights,
                                          indices=indices,
                                          name='labels_weight_sampled')
    net = model.Net(points_augmented, features_augmented, is_training, setting)
    logits = net.logits
    probs = tf.nn.softmax(logits, name='probs')
    predictions = tf.argmax(probs, axis=-1, name='predictions')

    loss_op = tf.losses.sparse_softmax_cross_entropy(
        labels=labels_sampled, logits=logits, weights=labels_weights_sampled)

    with tf.name_scope('metrics'):
        loss_mean_op, loss_mean_update_op = tf.metrics.mean(loss_op)
        t_1_acc_op, t_1_acc_update_op = tf.metrics.accuracy(
            labels_sampled, predictions, weights=labels_weights_sampled)
        t_1_per_class_acc_op, t_1_per_class_acc_update_op = \
            tf.metrics.mean_per_class_accuracy(labels_sampled, predictions, setting.num_class,
                                               weights=labels_weights_sampled)
    reset_metrics_op = tf.variables_initializer([
        var for var in tf.local_variables()
        if var.name.split('/')[0] == 'metrics'
    ])
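    # tf.metrics.* return (value, update_op) pairs backed by local variables in the
    # 'metrics' scope; running reset_metrics_op re-initializes them so every validation
    # pass (and every logged training window) aggregates its statistics from scratch.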

    _ = tf.summary.scalar('loss/train',
                          tensor=loss_mean_op,
                          collections=['train'])
    _ = tf.summary.scalar('t_1_acc/train',
                          tensor=t_1_acc_op,
                          collections=['train'])
    _ = tf.summary.scalar('t_1_per_class_acc/train',
                          tensor=t_1_per_class_acc_op,
                          collections=['train'])

    _ = tf.summary.scalar('loss/val', tensor=loss_mean_op, collections=['val'])
    _ = tf.summary.scalar('t_1_acc/val',
                          tensor=t_1_acc_op,
                          collections=['val'])
    _ = tf.summary.scalar('t_1_per_class_acc/val',
                          tensor=t_1_per_class_acc_op,
                          collections=['val'])

    lr_exp_op = tf.train.exponential_decay(setting.learning_rate_base,
                                           global_step,
                                           setting.decay_steps,
                                           setting.decay_rate,
                                           staircase=True)
    lr_clip_op = tf.maximum(lr_exp_op, setting.learning_rate_min)
    _ = tf.summary.scalar('learning_rate',
                          tensor=lr_clip_op,
                          collections=['train'])
    reg_loss = setting.weight_decay * tf.losses.get_regularization_loss()
    if setting.optimizer == 'adam':
        optimizer = tf.train.AdamOptimizer(learning_rate=lr_clip_op,
                                           epsilon=setting.epsilon)
    elif setting.optimizer == 'momentum':
        optimizer = tf.train.MomentumOptimizer(learning_rate=lr_clip_op,
                                               momentum=setting.momentum,
                                               use_nesterov=True)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = optimizer.minimize(loss_op + reg_loss,
                                      global_step=global_step)

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())

    saver = tf.train.Saver(max_to_keep=None)

    # backup all code
    if not args.no_code_backup:
        code_folder = os.path.abspath(os.path.dirname(__file__))
        shutil.copytree(
            code_folder,
            os.path.join(root_folder, os.path.basename(code_folder)))

    folder_ckpt = os.path.join(root_folder, 'ckpts')
    if not os.path.exists(folder_ckpt):
        os.makedirs(folder_ckpt)

    folder_summary = os.path.join(root_folder, 'summary')
    if not os.path.exists(folder_summary):
        os.makedirs(folder_summary)

    parameter_num = np.sum(
        [np.prod(v.shape.as_list()) for v in tf.trainable_variables()])
    print('{}-Parameter number: {:d}.'.format(datetime.now(), parameter_num))

    with tf.Session() as sess:
        summaries_op = tf.summary.merge_all('train')
        summaries_val_op = tf.summary.merge_all('val')
        summary_writer = tf.summary.FileWriter(folder_summary, sess.graph)

        sess.run(init_op)

        # Load the model
        if args.load_ckpt is not None:
            saver.restore(sess, args.load_ckpt)
            print('{}-Checkpoint loaded from {}!'.format(
                datetime.now(), args.load_ckpt))
        else:
            latest_ckpt = tf.train.latest_checkpoint(folder_ckpt)
            if latest_ckpt:
                print('{}-Found checkpoint {}'.format(datetime.now(),
                                                      latest_ckpt))
                saver.restore(sess, latest_ckpt)
                print('{}-Checkpoint loaded from {} (Iter {})'.format(
                    datetime.now(), latest_ckpt, sess.run(global_step)))

        for batch_idx_train in tqdm(range(batch_num)):
            ######################################################################
            # Validation

            if (batch_idx_train % step_val == 0 and (batch_idx_train != 0 or args.load_ckpt is not None)) \
                    or batch_idx_train == batch_num - 1:
                filename_ckpt = os.path.join(folder_ckpt, 'iter')
                saver.save(sess, filename_ckpt, global_step=global_step)
                print('{}-Checkpoint saved to {}!'.format(
                    datetime.now(), filename_ckpt))

                sess.run(reset_metrics_op)
                for batch_val_idx in range(batch_num_val):
                    start_idx = batch_size * batch_val_idx
                    end_idx = min(start_idx + batch_size, num_val)
                    batch_size_val = end_idx - start_idx
                    points_batch = data_val[start_idx:end_idx, ...]
                    points_num_batch = data_num_val[start_idx:end_idx, ...]
                    labels_batch = label_val[start_idx:end_idx, ...]
                    weights_batch = np.array(label_weights_list)[labels_batch]

                    xforms_np, rotations_np = pf.get_xforms(
                        batch_size_val,
                        rotation_range=rotation_range_val,
                        scaling_range=scaling_range_val,
                        order=setting.rotation_order)
                    sess.run(
                        [
                            loss_mean_update_op, t_1_acc_update_op,
                            t_1_per_class_acc_update_op
                        ],
                        feed_dict={
                            pts_fts:
                            points_batch,
                            indices:
                            pf.get_indices(batch_size_val, sample_num,
                                           points_num_batch),
                            xforms:
                            xforms_np,
                            rotations:
                            rotations_np,
                            jitter_range:
                            np.array([jitter_val]),
                            labels_seg:
                            labels_batch,
                            labels_weights:
                            weights_batch,
                            is_training:
                            False,
                        })
                loss_val, t_1_acc_val, t_1_per_class_acc_val, summaries_val, step = sess.run(
                    [
                        loss_mean_op, t_1_acc_op, t_1_per_class_acc_op,
                        summaries_val_op, global_step
                    ])
                summary_writer.add_summary(summaries_val, step)
                print(
                    '{}-[Val  ]-Average:      Loss: {:.4f}  T-1 Acc: {:.4f}  T-1 mAcc: {:.4f}'
                    .format(datetime.now(), loss_val, t_1_acc_val,
                            t_1_per_class_acc_val))
                sys.stdout.flush()
            ######################################################################

            ######################################################################
            # Training
            start_idx = (batch_size * batch_idx_train) % num_train
            end_idx = min(start_idx + batch_size, num_train)
            batch_size_train = end_idx - start_idx
            points_batch = data_train[start_idx:end_idx, ...]
            points_num_batch = data_num_train[start_idx:end_idx, ...]
            labels_batch = label_train[start_idx:end_idx, ...]
            weights_batch = np.array(label_weights_list)[labels_batch]

            if start_idx + batch_size_train == num_train:
                if is_list_of_h5_list:
                    filelist_train_prev = seg_list[(seg_list_idx - 1) %
                                                   len(seg_list)]
                    filelist_train = seg_list[seg_list_idx % len(seg_list)]
                    if filelist_train != filelist_train_prev:
                        data_train, _, data_num_train, label_train, _ = data_utils.load_seg(
                            filelist_train)
                        num_train = data_train.shape[0]
                    seg_list_idx = seg_list_idx + 1
                data_train, data_num_train, label_train = \
                    data_utils.grouped_shuffle([data_train, data_num_train, label_train])

            offset = int(
                random.gauss(0, sample_num * setting.sample_num_variance))
            offset = max(offset, -sample_num * setting.sample_num_clip)
            offset = min(offset, sample_num * setting.sample_num_clip)
            sample_num_train = sample_num + offset
            xforms_np, rotations_np = pf.get_xforms(
                batch_size_train,
                rotation_range=rotation_range,
                scaling_range=scaling_range,
                order=setting.rotation_order)
            sess.run(reset_metrics_op)
            sess.run(
                [
                    train_op, loss_mean_update_op, t_1_acc_update_op,
                    t_1_per_class_acc_update_op
                ],
                feed_dict={
                    pts_fts:
                    points_batch,
                    indices:
                    pf.get_indices(batch_size_train, sample_num_train,
                                   points_num_batch),
                    xforms:
                    xforms_np,
                    rotations:
                    rotations_np,
                    jitter_range:
                    np.array([jitter]),
                    labels_seg:
                    labels_batch,
                    labels_weights:
                    weights_batch,
                    is_training:
                    True,
                })
            if batch_idx_train % 10 == 0:
                loss, t_1_acc, t_1_per_class_acc, summaries, step = sess.run([
                    loss_mean_op, t_1_acc_op, t_1_per_class_acc_op,
                    summaries_op, global_step
                ])
                summary_writer.add_summary(summaries, step)
                print(
                    '{}-[Train]-Iter: {:06d}  Loss: {:.4f}  T-1 Acc: {:.4f}  T-1 mAcc: {:.4f}'
                    .format(datetime.now(), step, loss, t_1_acc,
                            t_1_per_class_acc))
                sys.stdout.flush()
            ######################################################################
        print('{}-Done!'.format(datetime.now()))
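

# data_utils.grouped_shuffle, used above, shuffles several parallel arrays with one
# shared permutation so that samples stay aligned with their point counts and labels.
# A minimal NumPy sketch under that assumption (the real helper may differ):
import numpy as np


def grouped_shuffle_sketch(arrays):
    # All arrays are assumed to share the same first (sample) dimension.
    perm = np.random.permutation(arrays[0].shape[0])
    return [a[perm] for a in arrays]
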
Example #8
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--path', '-t', help='Path to data', required=True)
    parser.add_argument('--path_val', '-v', help='Path to validation data')
    parser.add_argument('--load_ckpt',
                        '-l',
                        help='Path to a check point file for load')
    parser.add_argument(
        '--save_folder',
        '-s',
        help='Path to folder for saving check points and summary',
        required=True)
    parser.add_argument('--model', '-m', help='Model to use', required=True)
    parser.add_argument('--setting',
                        '-x',
                        help='Setting to use',
                        required=True)
    parser.add_argument('--gpu', '-gpu', help='GPU index to use', default='0')
    args = parser.parse_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)

    time_string = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    root_folder = os.path.join(
        args.save_folder,
        '%s_%s_%s_%d' % (args.model, args.setting, time_string, os.getpid()))
    if not os.path.exists(root_folder):
        os.makedirs(root_folder)

    sys.stdout = open(os.path.join(root_folder, 'log.txt'), 'w')

    print('PID:', os.getpid())

    print(args)

    model = importlib.import_module(args.model)
    setting_path = os.path.join(os.path.dirname(__file__), args.model)
    sys.path.append(setting_path)
    setting = importlib.import_module(args.setting)

    num_epochs = setting.num_epochs
    batch_size = setting.batch_size
    sample_num = setting.sample_num
    point_num = 2048
    rotation_range = setting.rotation_range
    scaling_range = setting.scaling_range
    jitter = setting.jitter
    pool_setting_train = None if not hasattr(
        setting, 'pool_setting_train') else setting.pool_setting_train

    # Prepare inputs
    print('{}-Preparing datasets...'.format(datetime.now()))
    sys.stdout.flush()
    data_train, label_train, weight_train, box_sizes, num_train = data_utils.load_file(
        args.path)

    print('{}-{:d} training samples.'.format(datetime.now(), num_train))
    sys.stdout.flush()

    ######################################################################
    # Placeholders
    indices = tf.placeholder(tf.int32, shape=(None, None, 2), name="indices")
    xforms = tf.placeholder(tf.float32, shape=(None, 3, 3), name="xforms")
    rotations = tf.placeholder(tf.float32,
                               shape=(None, 3, 3),
                               name="rotations")
    jitter_range = tf.placeholder(tf.float32, shape=(1), name="jitter_range")
    global_step = tf.Variable(0, trainable=False, name='global_step')
    is_training = tf.placeholder(tf.bool, name='is_training')

    weight_train_placeholder = tf.placeholder(tf.float32,
                                              shape=(batch_size),
                                              name="weight")
    ### add weight
    data_train_placeholder = tf.placeholder(tf.float32,
                                            shape=(batch_size, point_num, 6),
                                            name='data_train')
    label_train_placeholder = tf.placeholder(tf.int64,
                                             shape=(batch_size),
                                             name='label_train')
    size_train_placeholder = tf.placeholder(tf.float32,
                                            shape=(batch_size, 1, 3),
                                            name="size")
    ########################################################################
    batch_num_per_epoch = math.floor(num_train / batch_size)

    print('{}-{:d} training batches per_epoch.'.format(datetime.now(),
                                                       batch_num_per_epoch))
    sys.stdout.flush()

    pts_fts_sampled = tf.gather_nd(data_train_placeholder,
                                   indices=indices,
                                   name='pts_fts_sampled')
    features_augmented = None
    if setting.data_dim > 3:
        points_sampled, features_sampled = tf.split(
            pts_fts_sampled, [3, setting.data_dim - 3],
            axis=-1,
            name='split_points_features')
        if setting.use_extra_features:
            if setting.with_normal_feature:
                if setting.data_dim < 6:
                    print('Only 3D normals are supported!')
                    exit()
                elif setting.data_dim == 6:
                    features_augmented = pf.augment(features_sampled,
                                                    rotations)
                else:
                    normals, rest = tf.split(features_sampled,
                                             [3, setting.data_dim - 6])
                    normals_augmented = pf.augment(normals, rotations)
                    features_augmented = tf.concat([normals_augmented, rest],
                                                   axis=-1)
            else:
                features_augmented = features_sampled
    else:
        points_sampled = pts_fts_sampled
    points_augmented = pf.augment(points_sampled, xforms, jitter_range)

    net = model.Net(points=points_augmented,
                    features=features_augmented,
                    is_training=is_training,
                    setting=setting)
    #logits = net.logits
    feature = net.fc_layers[-1]

    ####
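    # Fuse the per-object bounding-box size with the global point-cloud feature:
    # embed the (batch, 1, 3) size with a dense layer, concatenate along the feature
    # axis, and classify from the fused representation.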
    box_size = size_train_placeholder
    #box_size = tf.expand_dims(size_train_placeholder, axis=1, name='box_size')
    box_feature = tf.layers.dense(inputs=box_size, units=20)
    feature_concat = tf.concat((feature, box_feature), 2)
    output = tf.layers.dense(inputs=feature_concat, units=256)
    logits = tf.layers.dense(inputs=output, units=101)  # 101 classes, consistent with the (-1, 101) reshape below
    ####

    probs = tf.nn.softmax(logits, name='probs')
    predictions = tf.argmax(probs,
                            axis=-1,
                            name='predictions',
                            output_type=tf.int32)
    predictions = tf.squeeze(predictions)

    labels_2d = tf.expand_dims(label_train_placeholder,
                               axis=-1,
                               name='labels_2d')
    labels_tile = tf.tile(labels_2d, (1, tf.shape(logits)[1]),
                          name='labels_tile')
    # loss_op = tf.losses.sparse_softmax_cross_entropy(labels=labels_tile, logits=logits)
    weights_2d = tf.expand_dims(weight_train_placeholder,
                                axis=-1,
                                name='weights_2d')
    loss_op = tf.losses.sparse_softmax_cross_entropy(labels=labels_tile,
                                                     logits=logits,
                                                     weights=weights_2d)

    lr_exp_op = tf.train.exponential_decay(setting.learning_rate_base,
                                           global_step,
                                           setting.decay_steps,
                                           setting.decay_rate,
                                           staircase=True)

    lr_clip_op = tf.maximum(lr_exp_op, setting.learning_rate_min)

    _ = tf.summary.scalar('learning_rate', tensor=lr_clip_op)

    reg_loss = setting.weight_decay * tf.losses.get_regularization_loss()

    if setting.optimizer == 'adam':
        optimizer = tf.train.AdamOptimizer(learning_rate=lr_clip_op,
                                           epsilon=setting.epsilon)
    elif setting.optimizer == 'momentum':
        optimizer = tf.train.MomentumOptimizer(learning_rate=lr_clip_op,
                                               momentum=setting.momentum,
                                               use_nesterov=True)

    train_op = optimizer.minimize(loss_op + reg_loss, global_step=global_step)

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())

    saver = tf.train.Saver(max_to_keep=None)

    folder_ckpt = os.path.join(root_folder, 'ckpts')
    if not os.path.exists(folder_ckpt):
        os.makedirs(folder_ckpt)

    folder_summary = os.path.join(root_folder, 'summary')
    if not os.path.exists(folder_summary):
        os.makedirs(folder_summary)

    parameter_num = np.sum(
        [np.prod(v.shape.as_list()) for v in tf.trainable_variables()])
    print('{}-Parameter number: {:d}.'.format(datetime.now(), parameter_num))
    sys.stdout.flush()

    with tf.Session() as sess:
        summary_writer = tf.summary.FileWriter(folder_summary, sess.graph)
        sess.run(init_op)
        # Load the model
        if args.load_ckpt is not None:
            saver.restore(sess, args.load_ckpt)
            print('{}-Checkpoint loaded from {}!'.format(
                datetime.now(), args.load_ckpt))
        print('total-[Train]-Iter: ', num_epochs)
        sys.stdout.flush()

        num_epochs = 1  # test mode
        dataset = 'ScanNet'
        if dataset == 'S3DIS':
            categories = [6, 8, 9, 14, 99]  # chair,board,table,sofa
        elif dataset == 'Matterport':
            categories = [3, 5, 7, 8, 11, 15, 18, 22, 25, 28]
        elif dataset == 'ScanNet':
            categories = [
                3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, 36, 39
            ]
        categories = np.array(categories)

        TP = np.zeros(categories.shape[0])
        FP = np.zeros(categories.shape[0])
        FN = np.zeros(categories.shape[0])
        TN = np.zeros(categories.shape[0])
        recall = np.zeros(categories.shape[0])
        precision = np.zeros(categories.shape[0])

        for epoch_idx_train in range(num_epochs):
            total_correct = 0
            total_seen = 0
            loss_sum = 0
            if epoch_idx_train == num_epochs - 1:
                confidences = []
                cloud_features = []
                for batch_idx_train in range(batch_num_per_epoch):
                    print('batch_idx_train', batch_idx_train)
                    index_ch = np.arange(num_train)
                    # do not shuffle in test mode
                    label = []
                    weight = []
                    size = []
                    dataset_train = []
                    for i in range(batch_size):
                        #print('i',i)
                        k = batch_idx_train * batch_size + i
                        label.append(label_train[index_ch[k]])
                        weight.append(weight_train[index_ch[k]])
                        size.append(box_sizes[index_ch[k]])
                        data = []
                        count = 0
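                        # Read one point cloud from a whitespace-separated text file,
                        # keep the first 6 columns, center xyz at the bounding-box
                        # center and shift the remaining channels by 0.5.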
                        with open(data_train[index_ch[k]]) as fpts:
                            while 1:
                                line = fpts.readline()
                                if not line:
                                    break
                                L = line.split(' ')
                                L = [float(i) for i in L]
                                data.append(np.array(L))
                                count = count + 1
                            data = np.array(data)
                            data = data[:, :6]
                            trans_x = (min(data[:, 0]) + max(data[:, 0])) / 2
                            trans_y = (min(data[:, 1]) + max(data[:, 1])) / 2
                            trans_z = (min(data[:, 2]) + max(data[:, 2])) / 2
                            data = data - [
                                trans_x, trans_y, trans_z, 0.5, 0.5, 0.5
                            ]
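                            # Resample to exactly 2048 points: without replacement when
                            # enough points are available, with replacement otherwise.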
                            if (count >= 2048):
                                index = np.random.choice(count,
                                                         size=2048,
                                                         replace=False)
                                # index = random.sample(range(0, count), 2048)
                                dataset = data[index, :]
                            else:
                                # k = random.sample(range(0, count), count)
                                index = np.random.choice(count,
                                                         size=2048,
                                                         replace=True)
                                dataset = data[index, :]
                            dataset_train.append(dataset)
                    data_batch = np.array(dataset_train)
                    label_batch = np.array(label)
                    weight_batch = np.array(weight)
                    size_batch = np.array(size)
                    ######################################################################
                    # TESting
                    offset = int(
                        random.gauss(0,
                                     sample_num * setting.sample_num_variance))
                    offset = max(offset, -sample_num * setting.sample_num_clip)
                    offset = min(offset, sample_num * setting.sample_num_clip)
                    sample_num_train = sample_num + offset
                    xforms_np, rotations_np = pf.get_xforms(
                        batch_size,
                        rotation_range=rotation_range,
                        scaling_range=scaling_range,
                        order=setting.rotation_order)
                    loss, prediction, confidence, cloud_feature, learningrate = sess.run(
                        [loss_op, predictions, probs, feature, lr_clip_op],
                        feed_dict={
                            data_train_placeholder:
                            data_batch,
                            label_train_placeholder:
                            label_batch,
                            indices:
                            pf.get_indices(batch_size, sample_num_train,
                                           point_num, pool_setting_train),
                            xforms:
                            xforms_np,
                            rotations:
                            rotations_np,
                            jitter_range:
                            np.array([jitter]),
                            is_training:
                            True,
                            weight_train_placeholder:
                            weight_batch,
                            size_train_placeholder:
                            size_batch,
                        })
                    print('confidence.shape', confidence.shape)
                    confidences.append(confidence)
                    cloud_features.append(cloud_feature)
                    correct = np.sum(prediction == label_batch)
                    total_correct += correct
                    total_seen += batch_size
                    loss_sum += loss

                    for i in range(categories.shape[0]):
                        for j in range(label_batch.shape[0]):
                            pred = prediction[j]
                            label = label_batch[j]
                            cat = categories[i]

                            if label == cat and pred == cat:
                                TP[i] += 1
                            elif label == cat and pred != cat:
                                FN[i] += 1
                            elif label != cat and pred == cat:
                                FP[i] += 1
                            elif label != cat and pred != cat:
                                TN[i] += 1

                    for i in range(categories.shape[0]):
                        recall[i] = TP[i] / (TP[i] + FN[i])
                        precision[i] = TP[i] / (TP[i] + FP[i])
                    print('precision', precision)
                    print('recall', recall)

                for i in range(categories.shape[0]):
                    recall[i] = TP[i] / (TP[i] + FN[i])
                    precision[i] = TP[i] / (TP[i] + FP[i])
                print('precision', precision)
                print('recall', recall)

                confidences = np.array(confidences).reshape((-1, 101))
                cloud_features = np.array(cloud_features)
                cloud_features = cloud_features.reshape(
                    (-1, cloud_features.shape[-1]))

                # class num :101
                np.savetxt(os.path.join(folder_summary, 'confidence.txt'),
                           confidences)
                np.savetxt(os.path.join(folder_summary, 'feature.txt'),
                           cloud_features)
                print('confidences and features saved to {}!'.format(
                    folder_summary))
                print('confidences shape is {}!'.format(confidences.shape))
                filename_ckpt = os.path.join(folder_ckpt, 'iter')
                saver.save(sess, filename_ckpt, global_step=global_step)
                print('{}-Checkpoint saved to {}!'.format(
                    datetime.now(), filename_ckpt))
                print(
                    '{}-[test]-done: {:06d}  Loss: {:.4f}   Acc: {:.4f}  lr:{:.4f}'
                    .format(datetime.now(), epoch_idx_train, loss_sum,
                            (total_correct / float(total_seen)), learningrate))
                sys.stdout.flush()
            else:
                for batch_idx_train in range(batch_num_per_epoch):
                    ########################################################################
                    #sample
                    index_ch = np.arange(num_train)
                    np.random.shuffle(index_ch)
                    label = []
                    weight = []
                    dataset_train = []
                    size = []
                    for i in range(batch_size):
                        k = batch_idx_train * batch_size + i
                        label.append(label_train[index_ch[k]])
                        #weight.append(pow(weight_train[index_ch[k]], 2))
                        weight.append(weight_train[index_ch[k]])
                        size.append(box_sizes[index_ch[k]])
                        data = []
                        count = 0
                        with open(data_train[index_ch[k]]) as fpts:
                            while 1:
                                line = fpts.readline()
                                if not line:
                                    break
                                L = line.split(' ')
                                L = [float(i) for i in L]
                                data.append(np.array(L))
                                count = count + 1
                            data = np.array(data)
                            data = data[:, :6]
                            trans_x = (min(data[:, 0]) + max(data[:, 0])) / 2
                            trans_y = (min(data[:, 1]) + max(data[:, 1])) / 2
                            trans_z = (min(data[:, 2]) + max(data[:, 2])) / 2
                            data = data - [
                                trans_x, trans_y, trans_z, 0.5, 0.5, 0.5
                            ]
                            ######################################

                            if (count >= 2048):
                                index = np.random.choice(count,
                                                         size=2048,
                                                         replace=False)
                                dataset = data[index, :]
                            else:
                                # k = random.sample(range(0, count), count)
                                index = np.random.choice(count,
                                                         size=2048,
                                                         replace=True)
                                dataset = data[index, :]
                            dataset_train.append(dataset)
                    data_batch = np.array(dataset_train)
                    label_batch = np.array(label)
                    weight_batch = np.array(weight)
                    size_batch = np.array(size)
                    ######################################################################
                    # Training
                    offset = int(
                        random.gauss(0,
                                     sample_num * setting.sample_num_variance))
                    offset = max(offset, -sample_num * setting.sample_num_clip)
                    offset = min(offset, sample_num * setting.sample_num_clip)
                    sample_num_train = sample_num + offset
                    xforms_np, rotations_np = pf.get_xforms(
                        batch_size,
                        rotation_range=rotation_range,
                        scaling_range=scaling_range,
                        order=setting.rotation_order)
                    _, loss, prediction, learningrate, bs, bf = sess.run(
                        [
                            train_op, loss_op, predictions, lr_clip_op,
                            box_size, box_feature
                        ],
                        feed_dict={
                            data_train_placeholder:
                            data_batch,
                            label_train_placeholder:
                            label_batch,
                            indices:
                            pf.get_indices(batch_size, sample_num_train,
                                           point_num, pool_setting_train),
                            xforms:
                            xforms_np,
                            rotations:
                            rotations_np,
                            jitter_range:
                            np.array([jitter]),
                            is_training:
                            True,
                            weight_train_placeholder:
                            weight_batch,
                            size_train_placeholder:
                            size_batch,
                        })
                    correct = np.sum(prediction == label_batch)
                    total_correct += correct
                    total_seen += batch_size
                    loss_sum += loss
                    if batch_idx_train % 50 == 0 or True:  # 'or True' bypasses the modulo check, so progress is logged every iteration
                        print(
                            '{}-[Train]-Iter:{:06d}   batch_idx:{:06d}  Loss: {:.4f}   Acc: {:.4f}  lr:{:.4f}'
                            .format(datetime.now(), epoch_idx_train,
                                    batch_idx_train, loss,
                                    (total_correct / float(total_seen)),
                                    learningrate))
                        sys.stdout.flush()
                print(
                    '{}-[Train]-Iter: {:06d}  Loss: {:.4f}   Acc: {:.4f}  lr:{:.4f}'
                    .format(datetime.now(), epoch_idx_train, loss_sum,
                            (total_correct / float(total_seen)), learningrate))
                filename_ckpt = os.path.join(folder_ckpt, 'iter')
                saver.save(sess, filename_ckpt, global_step=global_step)
                print('{}-Checkpoint saved to {}!'.format(
                    datetime.now(), filename_ckpt))
                sys.stdout.flush()

                ####################################################################

        print('{}-Done!'.format(datetime.now()))
Example #9
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--load_ckpt', '-l', help='Path to a check point file for load')
    parser.add_argument('--save_folder', '-s',default='models', help='Path to folder for saving check points and summary')
    parser.add_argument('--model', '-m', default='pointcnn_cls',help='Model to use')
    parser.add_argument('--setting', '-x',default='ScanObjectNN_x3_l4', help='Setting to use')
    parser.add_argument('--epochs', help='Number of training epochs (default defined in setting)', type=int)
    parser.add_argument('--batch_size', help='Batch size (default defined in setting)', type=int)
    parser.add_argument('--log', help='Log to FILE in save folder; use - for stdout (default is log.txt)', metavar='FILE', default='log.txt')  # default='log.txt' writes the log to a file // default='-' prints to stdout
    parser.add_argument('--no_timestamp_folder', help="Don't save to timestamp folder", action='store_true')
    parser.add_argument('--no_code_backup', help="Don't backup code", action='store_true', default=True)  # boolean True (the original string 'True' was always truthy); code backup stays disabled by default
    parser.add_argument('--err_data', type=float, default=5696)
    parser.add_argument('--weight', type=float, default=10)
    args = parser.parse_args()
    
    if not args.no_timestamp_folder:
        time_string = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
        root_folder = os.path.join(args.save_folder, '%s_%s_%s_%d' % (args.model, args.setting, time_string, os.getpid()))
    else:
        root_folder = args.save_folder
    if not os.path.exists(root_folder):
        os.makedirs(root_folder)

    if args.log != '-':
        sys.stdout = open(os.path.join(root_folder, args.log), 'w')

    print('PID:', os.getpid())

    print(args)

    model = importlib.import_module(args.model)
    setting_path = os.path.join(os.path.dirname(__file__), args.model)
    sys.path.append(setting_path)
    setting = importlib.import_module(args.setting)

    num_epochs = args.epochs or setting.num_epochs
    batch_size = args.batch_size or setting.batch_size
    sample_num = setting.sample_num  # 1024
    step_val = setting.step_val
    rotation_range = setting.rotation_range
    rotation_range_val = setting.rotation_range_val
    scaling_range = setting.scaling_range
    scaling_range_val = setting.scaling_range_val
    jitter = setting.jitter
    jitter_val = setting.jitter_val
    pool_setting_val = None if not hasattr(setting, 'pool_setting_val') else setting.pool_setting_val
    pool_setting_train = None if not hasattr(setting, 'pool_setting_train') else setting.pool_setting_train
    ERR = args.err_data
    WEIGHT = args.weight

    
    # Prepare inputs
    print('{}-Preparing datasets...'.format(datetime.now()))
    if setting.data_set == 'ModelNet':
        prepare_data.make_ModelNet_data_A_B(ERR,batch_size)
    if setting.data_set == 'ScanObjectNN':
        prepare_data.make_ScanObjectNN_data_batch_A_B(ERR,batch_size)

    data_train, label_train, data_val, label_val = setting.load_fn(setting.path, setting.path_val)
    
    # data_train: [2*11416, 2048, 3], label_train: [11416,], data_val: [2882, 2048, 3], label_val: [2882,]
    if setting.balance_fn is not None:  # None for this setting
        num_train_before_balance = data_train.shape[0]
        repeat_num = setting.balance_fn(label_train)
        data_train = np.repeat(data_train, repeat_num, axis=0)
        label_train = np.repeat(label_train, repeat_num, axis=0)
        data_train, label_train = data_utils.grouped_shuffle([data_train, label_train])
        num_epochs = math.floor(num_epochs * (num_train_before_balance / data_train.shape[0]))

    if setting.save_ply_fn is not None:  # None for this setting
        folder = os.path.join(root_folder, 'pts')
        print('{}-Saving samples as .ply files to {}...'.format(datetime.now(), folder))
        sample_num_for_ply = min(512, data_train.shape[0])
        if setting.map_fn is None:
            data_sample = data_train[:sample_num_for_ply]
        else:
            data_sample_list = []
            for idx in range(sample_num_for_ply):
                data_sample_list.append(setting.map_fn(data_train[idx], 0)[0])
            data_sample = np.stack(data_sample_list)
        setting.save_ply_fn(data_sample, folder)

    num_train = data_train.shape[0]  # 11416
    point_num = data_train.shape[1]  # 2048
    num_val = data_val.shape[0]  # 2882

    ######################################################################
    # Placeholders
    indices = tf.placeholder(tf.int32, shape=(None, None, 2), name="indices")
    xforms = tf.placeholder(tf.float32, shape=(None, 3, 3), name="xforms")
    rotations = tf.placeholder(tf.float32, shape=(None, 3, 3), name="rotations")
    jitter_range = tf.placeholder(tf.float32, shape=(1), name="jitter_range")
    global_step = tf.Variable(0, trainable=False, name='global_step')
    is_training = tf.placeholder(tf.bool, name='is_training')

    data_train_placeholder = tf.placeholder(data_train.dtype, data_train.shape, name='data_train')
    label_train_placeholder = tf.placeholder(tf.int64, label_train.shape, name='label_train')
    data_val_placeholder = tf.placeholder(data_val.dtype, data_val.shape, name='data_val')
    label_val_placeholder = tf.placeholder(tf.int64, label_val.shape, name='label_val')
    handle = tf.placeholder(tf.string, shape=[], name='handle')

    ######################################################################
    dataset_train = tf.data.Dataset.from_tensor_slices((data_train_placeholder, label_train_placeholder))
    #dataset_train = dataset_train.shuffle(buffer_size=batch_size * 4)

    if setting.map_fn is not None:  # None for this setting
        dataset_train = dataset_train.map(lambda data, label:
                                          tuple(tf.py_func(setting.map_fn, [data, label], [tf.float32, label.dtype])),
                                          num_parallel_calls=setting.num_parallel_calls)

    if setting.keep_remainder:  # True
        dataset_train = dataset_train.batch(batch_size)
        batch_num_per_epoch = math.ceil(num_train / batch_size)
    else:
        dataset_train = dataset_train.apply(tf.contrib.data.batch_and_drop_remainder(batch_size))
        batch_num_per_epoch = math.floor(num_train / batch_size)
    dataset_train = dataset_train.repeat(num_epochs)
    iterator_train = dataset_train.make_initializable_iterator()
    batch_num = batch_num_per_epoch * num_epochs
    print('{}-{:d} training batches.'.format(datetime.now(), batch_num))

    dataset_val = tf.data.Dataset.from_tensor_slices((data_val_placeholder, label_val_placeholder))
    if setting.map_fn is not None:  # None for this setting
        dataset_val = dataset_val.map(lambda data, label: tuple(tf.py_func(
            setting.map_fn, [data, label], [tf.float32, label.dtype])), num_parallel_calls=setting.num_parallel_calls)
    if setting.keep_remainder:  # True
        dataset_val = dataset_val.batch(batch_size)
        batch_num_val = math.ceil(num_val / batch_size)
    else:
        dataset_val = dataset_val.apply(tf.contrib.data.batch_and_drop_remainder(batch_size))
        batch_num_val = math.floor(num_val / batch_size)
    iterator_val = dataset_val.make_initializable_iterator()
    print('{}-{:d} testing batches per test.'.format(datetime.now(), batch_num_val))

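    # A single feedable iterator: the string handle fed at sess.run time selects whether
    # batches come from the training or the validation dataset.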
    iterator = tf.data.Iterator.from_string_handle(handle, dataset_train.output_types)
    (pts_fts, labels) = iterator.get_next()

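    # indices carries (batch_index, point_index) pairs from pf.get_indices, so gather_nd
    # draws a fixed-size sample of points from every cloud in the batch.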
    pts_fts_sampled = tf.gather_nd(pts_fts, indices=indices, name='pts_fts_sampled')
    features_augmented = None
    
    points_sampled = pts_fts_sampled
    points_augmented = pf.augment(points_sampled, xforms, jitter_range)
 
    net = model.Net(points=points_augmented, features=features_augmented, is_training=is_training, setting=setting)
    logits = net.logits
    feature_A = net.feature_list_A
    feature_B = net.feature_list_B
    probs = tf.nn.softmax(logits, name='probs')
    predictions = tf.argmax(probs, axis=-1, name='predictions')

    labels_2d = tf.expand_dims(labels, axis=-1, name='labels_2d')
    labels_tile = tf.tile(labels_2d, (1, tf.shape(logits)[1]), name='labels_tile')
        
    
    # --------------------------------------------------------------------
    # compute the Loss of DINet
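    # The batch is built from two halves A and B (see the *_A_B data preparation above).
    # Each pair contributes the squared L2 distance between its two feature vectors:
    # same-label pairs add the distance itself (features pulled together), while
    # different-label pairs add its reciprocal (features pushed apart).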
    different = tf.square(tf.subtract(feature_A, feature_B))
    different = tf.reduce_sum(different, 1)

    label_A = tf.to_float(labels[0:tf.to_int32(batch_size / 2)])
    label_B = tf.to_float(labels[tf.to_int32(batch_size / 2):tf.to_int32(batch_size)])
    f_same = tf.multiply(tf.add(tf.sign(-tf.abs(tf.subtract(label_A, label_B))), 1), different)
    f_diff = tf.divide(tf.sign(tf.abs(tf.subtract(label_A, label_B))), different)
    f_same = tf.reduce_sum(f_same, 0)
    f_diff = tf.reduce_sum(f_diff, 0)
    f_loss = tf.add(f_diff, f_same)
    # --------------------------------------------------------------------
    
    loss_op = tf.losses.sparse_softmax_cross_entropy(labels=labels_tile, logits=logits) + ((f_loss / (61440 * batch_size / 2)) * WEIGHT)  # WEIGHT defaults to 10

    with tf.name_scope('metrics'):
        loss_mean_op, loss_mean_update_op = tf.metrics.mean(loss_op)
        t_1_acc_op, t_1_acc_update_op = tf.metrics.accuracy(labels_tile, predictions)
        t_1_per_class_acc_op, t_1_per_class_acc_update_op = tf.metrics.mean_per_class_accuracy(labels_tile,
                                                                                               predictions,
                                                                                               setting.num_class)
       
    reset_metrics_op = tf.variables_initializer([var for var in tf.local_variables()
                                                 if var.name.split('/')[0] == 'metrics'])

    _ = tf.summary.scalar('loss/train', tensor=loss_mean_op, collections=['train'])
    _ = tf.summary.scalar('t_1_acc/train', tensor=t_1_acc_op, collections=['train'])
    _ = tf.summary.scalar('t_1_per_class_acc/train', tensor=t_1_per_class_acc_op, collections=['train'])

    _ = tf.summary.scalar('loss/val', tensor=loss_mean_op, collections=['val'])
    _ = tf.summary.scalar('t_1_acc/val', tensor=t_1_acc_op, collections=['val'])
    _ = tf.summary.scalar('t_1_per_class_acc/val', tensor=t_1_per_class_acc_op, collections=['val'])

    lr_exp_op = tf.train.exponential_decay(setting.learning_rate_base, global_step, setting.decay_steps,
                                           setting.decay_rate, staircase=True)
    lr_clip_op = tf.maximum(lr_exp_op, setting.learning_rate_min)
    _ = tf.summary.scalar('learning_rate', tensor=lr_clip_op, collections=['train'])
    reg_loss = setting.weight_decay * tf.losses.get_regularization_loss()
    if setting.optimizer == 'adam':
        optimizer = tf.train.AdamOptimizer(learning_rate=lr_clip_op, epsilon=setting.epsilon)
    elif setting.optimizer == 'momentum':
        optimizer = tf.train.MomentumOptimizer(learning_rate=lr_clip_op, momentum=setting.momentum, use_nesterov=True)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = optimizer.minimize(loss_op + reg_loss, global_step=global_step)

    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())

    saver = tf.train.Saver(max_to_keep=None)

    # backup all code
    if not args.no_code_backup:
        code_folder = os.path.abspath(os.path.dirname(__file__))
        shutil.copytree(code_folder, os.path.join(root_folder, os.path.basename(code_folder)))

    folder_ckpt = os.path.join(root_folder, 'ckpts')
    if not os.path.exists(folder_ckpt):
        os.makedirs(folder_ckpt)

    folder_summary = os.path.join(root_folder, 'summary')
    if not os.path.exists(folder_summary):
        os.makedirs(folder_summary)

    parameter_num = np.sum([np.prod(v.shape.as_list()) for v in tf.trainable_variables()])
    print('{}-Parameter number: {:d}.'.format(datetime.now(), parameter_num))

    with tf.Session() as sess:
        summaries_op = tf.summary.merge_all('train')
        summaries_val_op = tf.summary.merge_all('val')
        summary_writer = tf.summary.FileWriter(folder_summary, sess.graph)

        sess.run(init_op)

        # Load the model
        if args.load_ckpt is not None:
            saver.restore(sess, args.load_ckpt)
            print('{}-Checkpoint loaded from {}!'.format(datetime.now(), args.load_ckpt))
        else:
            latest_ckpt = tf.train.latest_checkpoint(folder_ckpt)
            if latest_ckpt:
                print('{}-Found checkpoint {}'.format(datetime.now(), latest_ckpt))
                saver.restore(sess, latest_ckpt)
                print('{}-Checkpoint loaded from {} (Iter {})'.format(
                    datetime.now(), latest_ckpt, sess.run(global_step)))

        handle_train = sess.run(iterator_train.string_handle())
        handle_val = sess.run(iterator_val.string_handle())

        sess.run(iterator_train.initializer, feed_dict={
            data_train_placeholder: data_train,
            label_train_placeholder: label_train,
        })

        for batch_idx_train in range(batch_num):
            ######################################################################
            # Validation
            
            if (batch_idx_train % step_val == 0 and (batch_idx_train != 0 or args.load_ckpt is not None)) \
                    or batch_idx_train == batch_num - 1:
                sess.run(iterator_val.initializer, feed_dict={
                    data_val_placeholder: data_val,
                    label_val_placeholder: label_val,
                })
                filename_ckpt = os.path.join(folder_ckpt, 'iter')
                saver.save(sess, filename_ckpt, global_step=global_step)
                print('{}-Checkpoint saved to {}!'.format(datetime.now(), filename_ckpt))

                sess.run(reset_metrics_op)
                total_seen_class = [0 for _ in range(setting.num_class)]
                total_correct_class = [0 for _ in range(setting.num_class)]
                for batch_idx_val in range(batch_num_val):
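                    # The last validation batch can be smaller than batch_size when
                    # keep_remainder is enabled and num_val is not divisible by batch_size.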
                    if not setting.keep_remainder \
                            or num_val % batch_size == 0 \
                            or batch_idx_val != batch_num_val - 1:
                        batch_size_val = batch_size
                    else:
                        batch_size_val = num_val % batch_size
                    xforms_np, rotations_np = pf.get_xforms(batch_size_val,
                                                            rotation_range=rotation_range_val,
                                                            scaling_range=scaling_range_val,
                                                            order=setting.rotation_order)
                    lll, ppp, _, _ = sess.run([labels, predictions, t_1_acc_update_op, t_1_per_class_acc_update_op],
                             feed_dict={
                                 handle: handle_val,
                                 indices: pf.get_indices(batch_size_val, sample_num, point_num),  # pool_setting_val is not passed here
                                 xforms: xforms_np,
                                 rotations: rotations_np,
                                 jitter_range: np.array([jitter_val]),
                                 is_training: False,
                             })
                    ppp = np.reshape(ppp, [-1])
                    for i in range(len(lll)):
                        l = lll[i]
                        total_seen_class[l] += 1
                        total_correct_class[l] += (ppp[i] == l)
                    
                t_1_acc_val, t_1_per_class_acc_val, summaries_val, step = sess.run(
                    [t_1_acc_op, t_1_per_class_acc_op, summaries_val_op, global_step])
                summary_writer.add_summary(summaries_val, step)
                print('{}-[Val  ]-Average:     T-1 Acc: {:.4f}  T-1 mAcc: {:.4f}'
                      .format(datetime.now(),  t_1_acc_val, t_1_per_class_acc_val))
                print('every class acc:', np.array(total_correct_class) / np.array(total_seen_class, dtype=float))
                sys.stdout.flush()
                
            ######################################################################

            ######################################################################
            # Training
            if not setting.keep_remainder \
                    or num_train % batch_size == 0 \
                    or (batch_idx_train % batch_num_per_epoch) != (batch_num_per_epoch - 1):
                batch_size_train = batch_size
            else:
                batch_size_train = num_train % batch_size

            offset = int(random.gauss(0, sample_num * setting.sample_num_variance))
            offset = max(offset, -sample_num * setting.sample_num_clip)
            offset = min(offset, sample_num * setting.sample_num_clip)
            sample_num_train = sample_num + offset
            xforms_np, rotations_np = pf.get_xforms(batch_size_train,
                                                    rotation_range=rotation_range,
                                                    scaling_range=scaling_range,
                                                    order=setting.rotation_order)
            sess.run(reset_metrics_op)
            sess.run([train_op, loss_mean_update_op, t_1_acc_update_op, t_1_per_class_acc_update_op],
                     feed_dict={
                         handle: handle_train,
                         indices: pf.get_indices(batch_size_train, sample_num_train, point_num, pool_setting_train),
                         xforms: xforms_np,
                         rotations: rotations_np,
                         jitter_range: np.array([jitter]),
                         is_training: True,
                     })
            if batch_idx_train % 10 == 0:
                loss, t_1_acc, t_1_per_class_acc, summaries, step = sess.run([loss_mean_op,
                                                                        t_1_acc_op,
                                                                        t_1_per_class_acc_op,
                                                                        summaries_op,
                                                                        global_step])
                summary_writer.add_summary(summaries, step)
                print('{}-[Train]-Iter: {:06d}  Loss: {:.4f}  T-1 Acc: {:.4f}  T-1 mAcc: {:.4f}'
                      .format(datetime.now(), step, loss, t_1_acc, t_1_per_class_acc))
                sys.stdout.flush()
            ######################################################################
        print('{}-Done!'.format(datetime.now()))
Example #10
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--dir_bin',
                        '-i',
                        help='Path to binary files dir (*.npy)',
                        required=True)
    parser.add_argument('--filelist',
                        '-t',
                        help='Path to training set ground truth (.txt)',
                        required=True)
    parser.add_argument('--filelist_val',
                        '-v',
                        help='Path to validation set ground truth (.txt)',
                        required=False)
    parser.add_argument('--load_ckpt',
                        '-l',
                        help='Path to a check point file for load')
    parser.add_argument(
        '--save_folder',
        '-s',
        help='Path to folder for saving check points and summary',
        required=True)
    parser.add_argument('--model', '-m', help='Model to use', required=True)
    parser.add_argument('--setting',
                        '-x',
                        help='Setting to use',
                        required=True)
    args = parser.parse_args()

    time_string = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    root_folder = os.path.join(
        args.save_folder,
        '%s_%s_%s_%d' % (args.model, args.setting, time_string, os.getpid()))
    if not os.path.exists(root_folder):
        os.makedirs(root_folder)

    # sys.stdout = open(os.path.join(root_folder, 'log.txt'), 'w')

    print('PID:', os.getpid())

    print(args)

    model = importlib.import_module(args.model)
    setting_path = os.path.join(os.path.dirname(__file__), args.model)
    sys.path.append(setting_path)
    setting = importlib.import_module(args.setting)

    num_epochs = setting.num_epochs
    batch_size = setting.batch_size
    sample_num = setting.sample_num
    step_val = setting.step_val
    label_weights_list = setting.label_weights
    rotation_range = setting.rotation_range
    rotation_range_val = setting.rotation_range_val
    scaling_range = setting.scaling_range
    scaling_range_val = setting.scaling_range_val
    jitter = setting.jitter
    jitter_val = setting.jitter_val

    # Prepare inputs
    print('{}-Preparing datasets...'.format(datetime.now()))
    dir_bin = args.dir_bin
    path_filelist_train = args.filelist
    path_filelist_val = args.filelist_val
    list_fru_train, max_point_num_train = data_utils.load_bin_all(
        dir_bin, path_filelist_train)
    if path_filelist_val is None:
        print("train with no val data")
        list_fru_val = list_fru_train[0:100]
        max_point_num_val = 0
    else:
        list_fru_val, max_point_num_val = data_utils.load_bin_all(
            dir_bin, path_filelist_val)
    max_point_num = max(max_point_num_train, max_point_num_val)

    # shuffle
    random.shuffle(list_fru_train)

    num_train = len(list_fru_train)
    num_val = len(list_fru_val)
    print('{}-{:d}/{:d} training/validation samples.'.format(
        datetime.now(), num_train, num_val))
    batch_num = (num_train * num_epochs + batch_size - 1) // batch_size
    print('{}-{:d} training batches.'.format(datetime.now(), batch_num))
    batch_num_val = int(math.ceil(num_val / batch_size))
    print('{}-{:d} testing batches per test.'.format(datetime.now(),
                                                     batch_num_val))

    ######################################################################
    # Placeholders
    xforms = tf.placeholder(tf.float32, shape=(None, 3, 3), name="xforms")
    rotations = tf.placeholder(tf.float32,
                               shape=(None, 3, 3),
                               name="rotations")
    jitter_range = tf.placeholder(tf.float32, shape=(1), name="jitter_range")
    global_step = tf.Variable(0, trainable=False, name='global_step')
    is_training = tf.placeholder(tf.bool, name='is_training')

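    # Placeholders are sized for the largest possible sample count, so the random
    # per-batch offset on sample_num (bounded by sample_num_clip) always fits.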
    max_sample_num = sample_num + sample_num * setting.sample_num_clip
    pts_fts_sampled = tf.placeholder(tf.float32,
                                     shape=(None, max_sample_num,
                                            setting.data_dim),
                                     name='pts_fts')
    labels_sampled = tf.placeholder(tf.int64,
                                    shape=(None, max_sample_num),
                                    name='labels_seg')
    labels_weights_sampled = tf.placeholder(tf.float32,
                                            shape=(None, max_sample_num),
                                            name='labels_weights')

    ######################################################################
    features_augmented = None
    if setting.data_dim > 3:
        points_sampled, features_sampled = tf.split(
            pts_fts_sampled, [3, setting.data_dim - 3],
            axis=-1,
            name='split_points_features')
        if setting.use_extra_features:
            if setting.with_normal_feature:
                if setting.data_dim < 6:
                    print('Only 3D normals are supported!')
                    exit()
                elif setting.data_dim == 6:
                    features_augmented = pf.augment(features_sampled,
                                                    rotations)
                else:
                    normals, rest = tf.split(features_sampled,
                                             [3, setting.data_dim - 6],
                                             axis=-1)  # split along the feature axis (matches the concat below)
                    normals_augmented = pf.augment(normals, rotations)
                    features_augmented = tf.concat([normals_augmented, rest],
                                                   axis=-1)
            else:
                features_augmented = features_sampled
    else:
        points_sampled = pts_fts_sampled
    points_augmented = pf.augment(points_sampled, xforms, jitter_range)
    # points_augmented = points_sampled

    net = model.Net(points_augmented, features_augmented, is_training, setting)
    logits = net.logits
    probs = tf.nn.softmax(logits, name='probs')
    predictions = tf.argmax(probs, axis=-1, name='predictions')

    loss_op = tf.losses.sparse_softmax_cross_entropy(
        labels=labels_sampled, logits=logits, weights=labels_weights_sampled)

    with tf.name_scope('metrics'):
        loss_mean_op, loss_mean_update_op = tf.metrics.mean(loss_op)
        t_1_acc_op, t_1_acc_update_op = tf.metrics.accuracy(
            labels_sampled, predictions, weights=labels_weights_sampled)
        t_1_per_class_acc_op, t_1_per_class_acc_update_op = \
            tf.metrics.mean_per_class_accuracy(labels_sampled, predictions, setting.num_class,
                                               weights=labels_weights_sampled)
        t_1_mean_iou_op, t_1_mean_iou_update_op = \
            tf.metrics.mean_iou(labels_sampled, predictions, setting.num_class, labels_weights_sampled)
    reset_metrics_op = tf.variables_initializer([
        var for var in tf.local_variables()
        if var.name.split('/')[0] == 'metrics'
    ])

    _ = tf.summary.scalar('loss/train',
                          tensor=loss_mean_op,
                          collections=['train'])
    _ = tf.summary.scalar('t_1_acc/train',
                          tensor=t_1_acc_op,
                          collections=['train'])
    _ = tf.summary.scalar('t_1_per_class_acc/train',
                          tensor=t_1_per_class_acc_op,
                          collections=['train'])
    _ = tf.summary.scalar('t_1_mean_iou/train',
                          tensor=t_1_mean_iou_op,
                          collections=['train'])

    _ = tf.summary.scalar('loss/val', tensor=loss_mean_op, collections=['val'])
    _ = tf.summary.scalar('t_1_acc/val',
                          tensor=t_1_acc_op,
                          collections=['val'])
    _ = tf.summary.scalar('t_1_per_class_acc/val',
                          tensor=t_1_per_class_acc_op,
                          collections=['val'])
    _ = tf.summary.scalar('t_1_mean_iou/val',
                          tensor=t_1_mean_iou_op,
                          collections=['val'])

    lr_exp_op = tf.train.exponential_decay(setting.learning_rate_base,
                                           global_step,
                                           setting.decay_steps,
                                           setting.decay_rate,
                                           staircase=True)
    lr_clip_op = tf.maximum(lr_exp_op, setting.learning_rate_min)
    _ = tf.summary.scalar('learning_rate',
                          tensor=lr_clip_op,
                          collections=['train'])
    reg_loss = setting.weight_decay * tf.losses.get_regularization_loss()
    if setting.optimizer == 'adam':
        optimizer = tf.train.AdamOptimizer(learning_rate=lr_clip_op,
                                           epsilon=setting.epsilon)
    elif setting.optimizer == 'momentum':
        optimizer = tf.train.MomentumOptimizer(learning_rate=lr_clip_op,
                                               momentum=setting.momentum,
                                               use_nesterov=True)
    else:
        optimizer = tf.train.AdamOptimizer(learning_rate=lr_clip_op,
                                           epsilon=setting.epsilon)  # adam
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = optimizer.minimize(loss_op + reg_loss,
                                      global_step=global_step)

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())

    saver = tf.train.Saver(max_to_keep=None)

    # backup all code
    code_folder = os.path.abspath(os.path.dirname(__file__))
    shutil.copytree(code_folder,
                    os.path.join(root_folder, os.path.basename(code_folder)))

    folder_ckpt = os.path.join(root_folder, 'ckpts')
    if not os.path.exists(folder_ckpt):
        os.makedirs(folder_ckpt)

    folder_summary = os.path.join(root_folder, 'summary')
    if not os.path.exists(folder_summary):
        os.makedirs(folder_summary)

    parameter_num = np.sum(
        [np.prod(v.shape.as_list()) for v in tf.trainable_variables()])
    print('{}-Parameter number: {:d}.'.format(datetime.now(), parameter_num))

    with tf.Session() as sess:
        summaries_op = tf.summary.merge_all('train')
        summaries_val_op = tf.summary.merge_all('val')
        summary_writer = tf.summary.FileWriter(folder_summary, sess.graph)

        sess.run(init_op)

        # Load the model
        if args.load_ckpt is not None:
            saver.restore(sess, args.load_ckpt)
            print('{}-Checkpoint loaded from {}!'.format(
                datetime.now(), args.load_ckpt))

        for batch_idx_train in range(batch_num):
            if (batch_idx_train % step_val == 0 and (batch_idx_train != 0 or args.load_ckpt is not None)) \
                    or batch_idx_train == batch_num - 1:
                ######################################################################
                # Validation
                filename_ckpt = os.path.join(folder_ckpt, 'iter')
                saver.save(sess, filename_ckpt, global_step=global_step)
                print('{}-Checkpoint saved to {}!'.format(
                    datetime.now(), filename_ckpt))

                sess.run(reset_metrics_op)
                for batch_val_idx in range(batch_num_val):
                    start_idx = batch_size * batch_val_idx
                    end_idx = min(start_idx + batch_size, num_val)
                    batch_size_val = end_idx - start_idx
                    fru_batch = list_fru_train[start_idx:end_idx]

                    points_batch_sampled, labels_batch_sampled, weights_batch_sampled = \
                        df_utils.group_sampling_fru(fru_batch, sample_num, label_weights_list)

                    xforms_np, rotations_np = pf.get_xforms(
                        batch_size_val,
                        rotation_range=rotation_range_val,
                        scaling_range=scaling_range_val,
                        order=setting.rotation_order)
                    sess.run(
                        [
                            loss_mean_update_op, t_1_acc_update_op,
                            t_1_per_class_acc_update_op, t_1_mean_iou_update_op
                        ],
                        feed_dict={
                            pts_fts_sampled: points_batch_sampled,
                            xforms: xforms_np,
                            rotations: rotations_np,
                            jitter_range: np.array([jitter_val]),
                            labels_sampled: labels_batch_sampled,
                            labels_weights_sampled: weights_batch_sampled,
                            is_training: False,
                        })

                loss_val, t_1_acc_val, t_1_per_class_acc_val, t_1_mean_iou_val, summaries_val = sess.run(
                    [
                        loss_mean_op, t_1_acc_op, t_1_per_class_acc_op,
                        t_1_mean_iou_op, summaries_val_op
                    ])
                summary_writer.add_summary(summaries_val, batch_idx_train)
                print(
                    '{}-[Val  ]-Average:      Loss: {:.4f}  T-1 Acc: {:.4f}  T-1 mAcc: {:.4f}  T-1 mIOU: {:.4f}'
                    .format(datetime.now(), loss_val, t_1_acc_val,
                            t_1_per_class_acc_val, t_1_mean_iou_val))
                ######################################################################

            ######################################################################
            # Training
            start_idx = (batch_size * batch_idx_train) % num_train
            end_idx = min(start_idx + batch_size, num_train)
            batch_size_train = end_idx - start_idx
            fru_batch = list_fru_train[start_idx:end_idx]

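            # Reshuffle the training list once the last sample of an epoch has been consumed.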
            if start_idx + batch_size_train == num_train:
                random.shuffle(list_fru_train)

            offset = int(
                random.gauss(0, sample_num * setting.sample_num_variance))
            offset = max(offset, -sample_num * setting.sample_num_clip)
            offset = min(offset, sample_num * setting.sample_num_clip)
            sample_num_train = sample_num + offset
            points_batch_sampled, labels_batch_sampled, weights_batch_sampled = \
                df_utils.group_sampling_fru(fru_batch, sample_num_train, label_weights_list)

            xforms_np, rotations_np = pf.get_xforms(
                batch_size_train,
                rotation_range=rotation_range,
                scaling_range=scaling_range,
                order=setting.rotation_order)
            sess.run(reset_metrics_op)
            sess.run(
                [
                    train_op, loss_mean_update_op, t_1_acc_update_op,
                    t_1_per_class_acc_update_op, t_1_mean_iou_update_op
                ],
                feed_dict={
                    pts_fts_sampled: points_batch_sampled,
                    xforms: xforms_np,
                    rotations: rotations_np,
                    jitter_range: np.array([jitter]),
                    labels_sampled: labels_batch_sampled,
                    labels_weights_sampled: weights_batch_sampled,
                    is_training: True,
                })
            if batch_idx_train % 10 == 0:
                loss, t_1_acc, t_1_per_class_acc, t_1_mean_iou, summaries = sess.run(
                    [
                        loss_mean_op, t_1_acc_op, t_1_per_class_acc_op,
                        t_1_mean_iou_op, summaries_op
                    ])
                summary_writer.add_summary(summaries, batch_idx_train)
                print(
                    '{}-[Train]-Iter: {:06d}  Loss: {:.4f}  T-1 Acc: {:.4f}  T-1 mAcc: {:.4f}  T-1 mIOU: {:.4f}'
                    .format(datetime.now(), batch_idx_train, loss, t_1_acc,
                            t_1_per_class_acc, t_1_mean_iou))
            ######################################################################
        print('{}-Done!'.format(datetime.now()))
Example #11
def main():
    modelname = "pointcnn_cls"
    settingname = "mynet_2048"
    load_ckpt = "../models/ckpts/iter-5007"
    temp_dir = "../haimai/tmp"  # directory for temporary files
    if not os.path.exists(temp_dir):
        os.makedirs(temp_dir)

    temp_frame = os.path.join(temp_dir, "cur_frame.pcd")  # path of the current frame
    temp_isend = os.path.join(temp_dir, "isend")
    temp_c_writing = os.path.join(temp_dir, "c_is_writing")
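    # File-based handshake with the external C++ front end (a sketch of the protocol as
    # inferred from the loop below): temp_frame holds the frame to classify, temp_c_writing
    # acts as a write lock while C++ is writing the PCD, temp_isend marks the end of the
    # stream, and a label_<k> file passes each prediction back to the mapping process.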

    
    setting_path = os.path.join(os.path.dirname(__file__), modelname)
    sys.path.append(setting_path)
    setting = importlib.import_module(settingname)  # load the hyper-parameter setting

    batch_size_val = 1  # predict one point-cloud frame at a time
    data_frame = np.zeros((batch_size_val, setting.sample_num, 3), dtype=np.float32)
    
    ### Placeholders
    xforms = tf.placeholder(tf.float32, shape=(None, 3, 3), name="xforms")
    rotations = tf.placeholder(tf.float32, shape=(None, 3, 3), name="rotations")
    jitter_range = tf.placeholder(tf.float32, shape=(1), name="jitter_range")
    is_training = tf.placeholder(tf.bool, name='is_training')
    data_frame_placeholder = tf.placeholder(data_frame.dtype, data_frame.shape, name="data_frame")    
    
    model = importlib.import_module(modelname)  # load PointCNN
    points_augmented = pf.augment(data_frame_placeholder, xforms, jitter_range)
    net = model.Net(points=points_augmented, features=None, is_training=is_training, setting=setting)  # build the network; extra features are not used for now
    logits = net.logits
    probs = tf.nn.softmax(logits, name='probs')
    predict = tf.argmax(probs, axis=-1, name='predictions')  # network output
    
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    saver = tf.train.Saver(max_to_keep=None)
    with tf.Session() as sess:
        sess.run(init_op)
        saver.restore(sess, load_ckpt) # Load the model
        xforms_np, rotations_np = pf.get_xforms(batch_size_val, rotation_range=setting.rotation_range_val,
                                                                scaling_range=setting.scaling_range_val,
                                                                order=setting.rotation_order)
        ### Loop: predict a label for each incoming frame
        while not os.path.exists(temp_isend):  # not the last frame yet
            if os.path.exists(temp_frame):  # a point-cloud frame is available
                while os.path.exists(temp_c_writing):  # C++ is still writing the PCD
                    time.sleep(0.05)
                data_frame[0, ...] = load_pcd(temp_frame)  # read the point cloud
                label = sess.run(predict, feed_dict={data_frame_placeholder: data_frame,
                                                            xforms: xforms_np,
                                                            rotations: rotations_np,
                                                            jitter_range: np.array([setting.jitter_val]),
                                                            is_training: False,
                                                            })
                ### Pass the predicted label to the mapping process
                label_file = open(temp_dir + "/label_" + str(label[0][0]), 'w')
                label_file.close()

                if not delete_file(temp_frame):
                    print("Delete frame error.")
                    return
    delete_file(temp_isend)  # remove the communication flag file
    return
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--dir_train', '-t', help='Path to dir of train set', required=True)
    parser.add_argument('--dir_val', '-v', help='Path to dir of val set', required=False)
    parser.add_argument('--load_ckpt', '-l', help='Path to a check point file for load')
    parser.add_argument('--save_folder', '-s', help='Path to folder for saving check points and summary', required=True)
    parser.add_argument('--model', '-m', help='Model to use', required=True)
    parser.add_argument('--setting', '-x', help='Setting to use', required=True)
    args = parser.parse_args()

    time_string = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    root_folder = os.path.join(args.save_folder, '%s_%s_%d_%s' % (args.model, args.setting, os.getpid(), time_string))
    if not os.path.exists(root_folder):
        os.makedirs(root_folder)

    print('PID:', os.getpid())

    print(args)

    model = importlib.import_module(args.model)
    setting_path = os.path.join(os.path.dirname(__file__), args.model)
    sys.path.append(setting_path)
    print(setting_path)
    setting = importlib.import_module(args.setting)

    num_epochs = setting.num_epochs
    batch_size = setting.batch_size
    sample_num = setting.sample_num
    step_val = setting.step_val
    num_parts = setting.num_parts
    label_weights_list = setting.label_weights
    scaling_range = setting.scaling_range
    scaling_range_val = setting.scaling_range_val
    jitter = setting.jitter
    jitter_val = setting.jitter_val

    # Prepare inputs
    print('{}-Preparing datasets...'.format(datetime.now()))

    if args.dir_val is None:
        print("only load train")
        index_length_train, points_ele_train, intensities_train, labels_train, _ = data_utils.load_bin(args.dir_train)
        index_length_val = index_length_train
        points_ele_val = points_ele_train
        intensities_val = intensities_train
        labels_val = labels_train
    else:
        print("load train and val")
        index_length_train, points_ele_train, intensities_train, labels_train, _ = data_utils.load_bin(args.dir_train)
        index_length_val, points_ele_val, intensities_val, labels_val, _ = data_utils.load_bin(args.dir_val)

    # shuffle
    index_length_train = data_utils.index_shuffle(index_length_train)
    index_length_val = data_utils.index_shuffle(index_length_val)

    num_train = index_length_train.shape[0]
    point_num = max(np.max(index_length_train[:, 1]), np.max(index_length_val[:, 1]))
    num_val = index_length_val.shape[0]

    print('{}-{:d}/{:d} training/validation samples.'.format(datetime.now(), num_train, num_val))
    batch_num = (num_train * num_epochs + batch_size - 1) // batch_size
    print('{}-{:d} training batches.'.format(datetime.now(), batch_num))
    batch_num_val = math.ceil(num_val / batch_size)
    print('{}-{:d} testing batches per test.'.format(datetime.now(), batch_num_val))

    ######################################################################
    # Placeholders
    indices = tf.placeholder(tf.int32, shape=(None, None, 2), name="indices")
    xforms = tf.placeholder(tf.float32, shape=(None, 3, 3), name="xforms")
    rotations = tf.placeholder(tf.float32, shape=(None, 3, 3), name="rotations")
    jitter_range = tf.placeholder(tf.float32, shape=(1), name="jitter_range")
    global_step = tf.Variable(0, trainable=False, name='global_step')
    is_training = tf.placeholder(tf.bool, name='is_training')

    pts = tf.placeholder(tf.float32, shape=(None, point_num, setting.point_dim), name='pts')
    fts = tf.placeholder(tf.float32, shape=(None, point_num, setting.extra_dim), name='fts')
    labels_seg = tf.placeholder(tf.int32, shape=(None, point_num), name='labels_seg')
    labels_weights = tf.placeholder(tf.float32, shape=(None, point_num), name='labels_weights')

    ######################################################################

    # Set Inputs(points,features_sampled)
    features_sampled = None

    if setting.extra_dim == 1:
        points = pts
        features = fts

        if setting.use_extra_features:
            features_sampled = tf.gather_nd(features, indices=indices, name='features_sampled')

    elif setting.extra_dim == 0:
        points = pts

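    # The same (batch_index, point_index) pairs sample points, labels and per-point
    # weights, keeping the three tensors aligned.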
    points_sampled = tf.gather_nd(points, indices=indices, name='points_sampled')
    points_augmented = pf.augment(points_sampled, xforms, jitter_range)

    labels_sampled = tf.gather_nd(labels_seg, indices=indices, name='labels_sampled')
    labels_weights_sampled = tf.gather_nd(labels_weights, indices=indices, name='labels_weight_sampled')

    # Build net
    net = model.Net(points_augmented, features_sampled, None, None, num_parts, is_training, setting)
    logits, probs = net.logits, net.probs

    # Define Loss Func
    loss_op = tf.losses.sparse_softmax_cross_entropy(labels=labels_sampled, logits=logits,
                                                     weights=labels_weights_sampled)
    _ = tf.summary.scalar('loss/train_seg', tensor=loss_op, collections=['train'])

    # for vis t1 acc
    t_1_acc_op = pf.top_1_accuracy(probs, labels_sampled)
    _ = tf.summary.scalar('t_1_acc/train_seg', tensor=t_1_acc_op, collections=['train'])
    # for vis instance acc
    t_1_acc_instance_op = pf.top_1_accuracy(probs, labels_sampled, labels_weights_sampled, 0.6)
    _ = tf.summary.scalar('t_1_acc/train_seg_instance', tensor=t_1_acc_instance_op, collections=['train'])
    # for vis other acc
    t_1_acc_others_op = pf.top_1_accuracy(probs, labels_sampled, labels_weights_sampled, 0.6, "less")
    _ = tf.summary.scalar('t_1_acc/train_seg_others', tensor=t_1_acc_others_op, collections=['train'])

    loss_val_avg = tf.placeholder(tf.float32)
    _ = tf.summary.scalar('loss/val_seg', tensor=loss_val_avg, collections=['val'])

    t_1_acc_val_avg = tf.placeholder(tf.float32)
    _ = tf.summary.scalar('t_1_acc/val_seg', tensor=t_1_acc_val_avg, collections=['val'])
    t_1_acc_val_instance_avg = tf.placeholder(tf.float32)
    _ = tf.summary.scalar('t_1_acc/val_seg_instance', tensor=t_1_acc_val_instance_avg, collections=['val'])
    t_1_acc_val_others_avg = tf.placeholder(tf.float32)
    _ = tf.summary.scalar('t_1_acc/val_seg_others', tensor=t_1_acc_val_others_avg, collections=['val'])

    reg_loss = setting.weight_decay * tf.losses.get_regularization_loss()

    # lr decay
    lr_exp_op = tf.train.exponential_decay(setting.learning_rate_base, global_step, setting.decay_steps,
                                           setting.decay_rate, staircase=True)
    lr_clip_op = tf.maximum(lr_exp_op, setting.learning_rate_min)
    _ = tf.summary.scalar('learning_rate', tensor=lr_clip_op, collections=['train'])

    # Optimizer
    if setting.optimizer == 'adam':
        optimizer = tf.train.AdamOptimizer(learning_rate=lr_clip_op, epsilon=setting.epsilon)
    elif setting.optimizer == 'momentum':
        optimizer = tf.train.MomentumOptimizer(learning_rate=lr_clip_op, momentum=0.9, use_nesterov=True)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

    with tf.control_dependencies(update_ops):

        train_op = optimizer.minimize(loss_op + reg_loss, global_step=global_step)

    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())

    saver = tf.train.Saver(max_to_keep=None)

    # backup this file, model and setting
    if not os.path.exists(os.path.join(root_folder, args.model)):
        os.makedirs(os.path.join(root_folder, args.model))
    shutil.copy(__file__, os.path.join(root_folder, os.path.basename(__file__)))
    shutil.copy(os.path.join(os.path.dirname(__file__), args.model + '.py'),
                os.path.join(root_folder, args.model + '.py'))
    shutil.copy(os.path.join(os.path.dirname(__file__), args.model.split("_")[0] + "_kitti" + '.py'),
                os.path.join(root_folder, args.model.split("_")[0] + "_kitti" + '.py'))
    shutil.copy(os.path.join(setting_path, args.setting + '.py'),
                os.path.join(root_folder, args.model, args.setting + '.py'))

    folder_ckpt = os.path.join(root_folder, 'ckpts')
    if not os.path.exists(folder_ckpt):
        os.makedirs(folder_ckpt)

    folder_summary = os.path.join(root_folder, 'summary')
    if not os.path.exists(folder_summary):
        os.makedirs(folder_summary)

    parameter_num = np.sum([np.prod(v.shape.as_list()) for v in tf.trainable_variables()])
    print('{}-Parameter number: {:d}.'.format(datetime.now(), parameter_num))

    # Session Run
    with tf.Session() as sess:
        summaries_op = tf.summary.merge_all('train')
        summaries_val_op = tf.summary.merge_all('val')
        summary_writer = tf.summary.FileWriter(folder_summary, sess.graph)

        sess.run(init_op)

        # Load the model
        if args.load_ckpt is not None:
            saver.restore(sess, args.load_ckpt)
            print('{}-Checkpoint loaded from {}!'.format(datetime.now(), args.load_ckpt))

        for batch_idx in range(batch_num):
            if (batch_idx != 0 and batch_idx % step_val == 0) or batch_idx == batch_num - 1:
                ######################################################################
                # Validation
                filename_ckpt = os.path.join(folder_ckpt, 'iter')
                saver.save(sess, filename_ckpt, global_step=global_step)
                print('{}-Checkpoint saved to {}!'.format(datetime.now(), filename_ckpt))

                losses_val = []
                t_1_accs = []
                t_1_accs_instance = []
                t_1_accs_others = []

                for batch_val_idx in range(math.ceil(num_val / batch_size)):
                    start_idx = batch_size * batch_val_idx
                    end_idx = min(start_idx + batch_size, num_val)
                    batch_size_val = end_idx - start_idx
                    index_length_val_batch = index_length_val[start_idx:end_idx]

                    points_batch = np.zeros((batch_size_val, point_num, 3), np.float32)
                    intensity_batch = np.zeros((batch_size_val, point_num, 1), np.float32)
                    points_num_batch = np.zeros(batch_size_val, np.int32)
                    labels_batch = np.zeros((batch_size_val, point_num), np.int32)

                    for i, index_length in enumerate(index_length_val_batch):
                        points_batch[i, 0:index_length[1], :] = \
                            points_ele_val[index_length[0] * 3:
                                           index_length[0] * 3 + index_length[1] * 3].reshape(index_length[1], 3)

                        intensity_batch[i, 0:index_length[1], :] = \
                            intensities_val[index_length[0]:
                                            index_length[0] + index_length[1]].reshape(index_length[1], 1)

                        points_num_batch[i] = index_length[1].astype(np.int32)

                        labels_batch[i, 0:index_length[1]] = \
                            labels_val[index_length[0]:index_length[0] + index_length[1]].astype(np.int32)

                    weights_batch = np.array(label_weights_list)[labels_batch]

                    xforms_np, rotations_np = pf.get_xforms(batch_size_val, scaling_range=scaling_range_val)

                    sess_op_list = [loss_op, t_1_acc_op, t_1_acc_instance_op, t_1_acc_others_op]

                    sess_feed_dict = {pts: points_batch,
                                      fts: intensity_batch,
                                      indices: pf.get_indices(batch_size_val, sample_num, points_num_batch),
                                      xforms: xforms_np,
                                      rotations: rotations_np,
                                      jitter_range: np.array([jitter_val]),
                                      labels_seg: labels_batch,
                                      labels_weights: weights_batch,
                                      is_training: False}

                    loss_val, t_1_acc_val, t_1_acc_val_instance, t_1_acc_val_others = sess.run(sess_op_list,
                                                                                               feed_dict=sess_feed_dict)
                    print('{}-[Val  ]-Iter: {:06d}  Loss: {:.4f} T-1 Acc: {:.4f}'.format(datetime.now(), batch_val_idx,
                                                                                         loss_val, t_1_acc_val))

                    losses_val.append(loss_val * batch_size_val)
                    t_1_accs.append(t_1_acc_val * batch_size_val)
                    t_1_accs_instance.append(t_1_acc_val_instance * batch_size_val)
                    t_1_accs_others.append(t_1_acc_val_others * batch_size_val)

                loss_avg = sum(losses_val) / num_val
                t_1_acc_avg = sum(t_1_accs) / num_val
                t_1_acc_instance_avg = sum(t_1_accs_instance) / num_val
                t_1_acc_others_avg = sum(t_1_accs_others) / num_val

                summaries_feed_dict = {loss_val_avg: loss_avg,
                                       t_1_acc_val_avg: t_1_acc_avg,
                                       t_1_acc_val_instance_avg: t_1_acc_instance_avg,
                                       t_1_acc_val_others_avg: t_1_acc_others_avg}

                summaries_val = sess.run(summaries_val_op, feed_dict=summaries_feed_dict)
                summary_writer.add_summary(summaries_val, batch_idx)

                print('{}-[Val  ]-Average:      Loss: {:.4f} T-1 Acc: {:.4f}'
                      .format(datetime.now(), loss_avg, t_1_acc_avg))

                ######################################################################

            ######################################################################
            # Training
            start_idx = (batch_size * batch_idx) % num_train
            end_idx = min(start_idx + batch_size, num_train)
            batch_size_train = end_idx - start_idx

            index_length_train_batch = index_length_train[start_idx:end_idx]
            points_batch = np.zeros((batch_size_train, point_num, 3), np.float32)
            intensity_batch = np.zeros((batch_size_train, point_num, 1), np.float32)
            points_num_batch = np.zeros(batch_size_train, np.int32)
            labels_batch = np.zeros((batch_size_train, point_num), np.int32)

            for i, index_length in enumerate(index_length_train_batch):
                points_batch[i, 0:index_length[1], :] = \
                    points_ele_train[index_length[0] * 3:
                                     index_length[0] * 3 + index_length[1] * 3].reshape(index_length[1], 3)

                intensity_batch[i, 0:index_length[1], :] = \
                    intensities_train[index_length[0]:
                                      index_length[0] + index_length[1]].reshape(index_length[1], 1)

                points_num_batch[i] = index_length[1].astype(np.int32)

                labels_batch[i, 0:index_length[1]] = \
                    labels_train[index_length[0]:index_length[0] + index_length[1]].astype(np.int32)

            weights_batch = np.array(label_weights_list)[labels_batch]

            if start_idx + batch_size_train == num_train:
                index_length_train = data_utils.index_shuffle(index_length_train)

            offset = int(random.gauss(0, sample_num // 8))
            offset = max(offset, -sample_num // 4)
            offset = min(offset, sample_num // 4)
            sample_num_train = sample_num + offset
            xforms_np, rotations_np = pf.get_xforms(batch_size_train, scaling_range=scaling_range)

            sess_op_list = [train_op, loss_op, t_1_acc_op, t_1_acc_instance_op, t_1_acc_others_op, summaries_op]

            sess_feed_dict = {pts: points_batch,
                              fts: intensity_batch,
                              indices: pf.get_indices(batch_size_train, sample_num_train, points_num_batch),
                              xforms: xforms_np,
                              rotations: rotations_np,
                              jitter_range: np.array([jitter]),
                              labels_seg: labels_batch,
                              labels_weights: weights_batch,
                              is_training: True}

            _, loss, t_1_acc, t_1_acc_instance, t_1_acc_others, summaries = sess.run(sess_op_list,
                                                                                     feed_dict=sess_feed_dict)
            print('{}-[Train]-Iter: {:06d}  Loss_seg: {:.4f} T-1 Acc: {:.4f}'
                  .format(datetime.now(), batch_idx, loss, t_1_acc))

            summary_writer.add_summary(summaries, batch_idx)

            ######################################################################
        print('{}-Done!'.format(datetime.now()))
Example #13
    def load_model(self):
        with tf.Graph().as_default():
            points_pl = tf.compat.v1.placeholder(tf.float32, [
                self.train_config.batch_size, self.data_config.num_points,
                self.input_dim
            ])

            labels_pl = tf.compat.v1.placeholder(
                tf.int32, shape=(self.train_config.batch_size))
            is_training_pl = tf.compat.v1.placeholder(tf.bool)

            one_hot_labels = tf.one_hot(indices=labels_pl,
                                        depth=self.data_config.num_classes)

            global_step = tf.compat.v1.train.get_or_create_global_step()
            batch = tf.Variable(0)

            bn_decay = tf_util.get_bn_decay(
                global_step,  #batch
                self.model_config.bn_init_decay,
                self.train_config.batch_size,
                self.model_config.bn_decay_decay_step,
                self.model_config.bn_decay_decay_rate,
                self.model_config.bn_decay_clip)

            learning_rate = tf_util.get_learning_rate(
                global_step,  #batch
                self.model_config.base_learning_rate,
                self.train_config.batch_size,
                self.model_config.decay_step,
                self.model_config.decay_rate)

            print('**** Model selected  -> {} ****\n'.format(self.model))

            if self.model == "3DmFV":
                w_pl = tf.compat.v1.placeholder(
                    tf.float32, shape=(self.gmm.means_.shape[0]))
                mu_pl = tf.compat.v1.placeholder(
                    tf.float32,
                    shape=(self.gmm.means_.shape[0], self.gmm.means_.shape[1]))
                sigma_pl = tf.compat.v1.placeholder(
                    tf.float32,
                    shape=(self.gmm.means_.shape[0], self.gmm.means_.shape[1]))

                logits, end_points = fv_model.get_model(
                    points_pl,
                    w_pl,
                    mu_pl,
                    sigma_pl,
                    is_training_pl,
                    bn_decay=bn_decay,
                    weigth_decay=self.model_config.weight_decay,
                    add_noise=self.model_config.add_gaussian_noise,
                    num_classes=self.data_config.num_classes)
                total_loss = fv_model.get_loss(logits, labels_pl)

            elif self.model == "DGCNN" or self.model == "DGCNNC":
                logits, end_points = dgcnn.get_model(
                    points_pl,
                    self.data_config.num_classes,
                    is_training_pl,
                    bn_decay=bn_decay,
                    color=True if self.model == "DGCNNC" else False)
                total_loss = dgcnn.get_loss(
                    logits,
                    labels_pl,
                    num_classes=self.data_config.num_classes)

            elif self.model == "SpiderCNN":
                logits, end_points = spidercnn.get_model(
                    points_pl,
                    is_training_pl,
                    bn_decay=bn_decay,
                    num_class=self.data_config.num_classes)
                total_loss = spidercnn.get_loss(logits, labels_pl)

            elif self.model == "PointNet":
                logits, end_points = pointnet.get_model(
                    points_pl,
                    is_training_pl,
                    bn_decay=bn_decay,
                    num_class=self.data_config.num_classes)
                total_loss = pointnet.get_loss(logits, labels_pl, end_points)

            elif self.model == "PointNet2":
                logits, end_points = pointnet2_ssg.get_model(
                    points_pl,
                    is_training_pl,
                    bn_decay=bn_decay,
                    num_class=self.data_config.num_classes)
                total_loss = pointnet2_ssg.get_loss(logits, labels_pl,
                                                    end_points)

            elif self.model == "PointCNN":

                xforms = tf.compat.v1.placeholder(tf.float32,
                                                  shape=(None, 3, 3),
                                                  name="xforms")
                rotations = tf.compat.v1.placeholder(tf.float32,
                                                     shape=(None, 3, 3),
                                                     name="rotations")
                jitter_range = tf.compat.v1.placeholder(tf.float32,
                                                        shape=(1),
                                                        name="jitter_range")
                points_augmented = pf.augment(points_pl, xforms, jitter_range)
                net = Net(points=points_augmented,
                          features=None,
                          is_training=is_training_pl,
                          setting=setting)

            if self.is_train:
                variables_to_restore = slim.get_variables_to_restore(
                    exclude=self.exclude_blocks)
            else:
                variables_to_restore = slim.get_variables_to_restore()

            if self.model == "PointCNN":
                total_loss, end_points = net.get_loss(labels_pl)
                probabilities = end_points['Probabilities']
                predictions = tf.argmax(probabilities,
                                        axis=-1,
                                        name='predictions')
                correct = tf.equal(predictions, tf.to_int64(labels_pl))

                with tf.name_scope('metrics'):
                    loss_mean_op, loss_mean_update_op = tf.compat.v1.metrics.mean(
                        total_loss)
                    accuracy, update_op = tf.compat.v1.metrics.accuracy(
                        end_points['labels_tile'], predictions)
                    t_1_per_class_acc_op, t_1_per_class_acc_update_op = tf.compat.v1.metrics.mean_per_class_accuracy(
                        labels_pl, predictions, self.data_config.num_classes)
                reset_metrics_op = tf.variables_initializer([
                    var for var in tf.local_variables()
                    if var.name.split('/')[0] == 'metrics'
                ])
                metrics_op = tf.group(update_op, probabilities)

                lr_exp_op = tf.compat.v1.train.exponential_decay(
                    setting.learning_rate_base,
                    global_step,
                    setting.decay_steps,
                    setting.decay_rate,
                    staircase=True)

                learning_rate = tf.maximum(lr_exp_op, setting.learning_rate_min)
                tf.summary.scalar('learning_rate',
                                  tensor=learning_rate,
                                  collections=['train'])
                reg_loss = setting.weight_decay * tf.compat.v1.losses.get_regularization_loss(
                )
                optimizer = tf.compat.v1.train.AdamOptimizer(
                    learning_rate=learning_rate, epsilon=setting.epsilon)

                update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
                with tf.control_dependencies(update_ops):
                    train_op = optimizer.minimize(total_loss + reg_loss,
                                                  global_step=global_step)
            else:
                # predictions that are not one-hot encoded.
                probabilities = end_points['Probabilities']
                predictions = tf.argmax(probabilities, 1)
                correct = tf.equal(predictions, tf.to_int64(labels_pl))
                accuracy, update_op = tf.compat.v1.metrics.accuracy(
                    labels_pl, predictions)
                metrics_op = tf.group(update_op, probabilities)
                optimizer = tf.compat.v1.train.AdamOptimizer(
                    learning_rate=learning_rate)
                train_op = optimizer.minimize(total_loss,
                                              global_step=global_step)
                tf.compat.v1.summary.scalar('learning_rate', learning_rate)

            tf.compat.v1.summary.scalar('bn_decay', bn_decay)
            tf.compat.v1.summary.scalar('loss', total_loss)
            tf.compat.v1.summary.scalar('accuracy', accuracy)
            summary_op = tf.compat.v1.summary.merge_all()

            # Add ops to save and restore all the variables.
            saver = tf.compat.v1.train.Saver(variables_to_restore)

            self.ops = {
                'labels_pl':
                labels_pl,
                'points_pl':
                points_pl,
                'w_pl':
                w_pl if self.model == "3DmFV" else tf.compat.v1.placeholder(
                    tf.float16, shape=(1)),
                'mu_pl':
                mu_pl if self.model == "3DmFV" else tf.compat.v1.placeholder(
                    tf.float16, shape=(1)),
                'sigma_pl':
                sigma_pl if self.model == "3DmFV" else
                tf.compat.v1.placeholder(tf.float16, shape=(1)),
                'is_training_pl':
                is_training_pl,
                'loss':
                total_loss,
                'train_op':
                train_op,
                'summary_op':
                summary_op,
                'metrics_op':
                metrics_op,
                'predictions':
                predictions,
                'probabilities':
                probabilities,
                'step':
                batch,
                'global_step':
                global_step,
                'accuracy':
                accuracy,
                'correct':
                correct,
                'xforms':
                xforms if self.model == "PointCNN" else
                tf.compat.v1.placeholder(tf.float16, shape=(1)),
                'rotations':
                rotations if self.model == "PointCNN" else
                tf.compat.v1.placeholder(tf.float16, shape=(1)),
                'jitter_range':
                jitter_range if self.model == "PointCNN" else
                tf.compat.v1.placeholder(tf.float16, shape=(1)),
            }

            def restore_fn(sess):
                if self.is_train:
                    if self.checkpoint_file is not None:
                        return saver.restore(sess, self.checkpoint_file)
                    else:
                        return None
                else:
                    return saver.restore(sess, self.checkpoint_file)

            if self.is_train:
                self.sv = tf.train.Supervisor(logdir=self.train_logdir,
                                              summary_op=None,
                                              init_fn=restore_fn)
            else:
                self.sv = tf.train.Supervisor(logdir=self.test_logdir,
                                              saver=None,
                                              summary_op=None,
                                              init_fn=restore_fn)

            sess_config = tf_util.get_sess_conf(
                self.train_config.gpu_selection,
                limit_gpu=self.train_config.limit_gpu)
            self.sv.PrepareSession(config=sess_config)
            with self.sv.managed_session() as self.sess:
                if self.is_train:
                    self.train()
                else:
                    self.evaluate(export=True)
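load_model above takes its batch-norm decay schedule from tf_util.get_bn_decay. In PointNet-family codebases this helper is commonly an exponential decay of the batch-norm momentum that is then clipped; a sketch along those lines (a reconstruction under that assumption, not necessarily this repository's exact tf_util) is:

import tensorflow as tf

def get_bn_decay_sketch(global_step, bn_init_decay, batch_size,
                        bn_decay_decay_step, bn_decay_decay_rate, bn_decay_clip):
    # Decay the batch-norm momentum as training progresses and cap the
    # resulting decay factor at bn_decay_clip.
    bn_momentum = tf.compat.v1.train.exponential_decay(
        bn_init_decay,
        global_step * batch_size,
        bn_decay_decay_step,
        bn_decay_decay_rate,
        staircase=True)
    return tf.minimum(bn_decay_clip, 1.0 - bn_momentum)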
Example #14
features_augmented = None

# In[8]:

if setting.data_dim > 3:
    points_sampled, features_sampled = tf.split(pts_fts_sampled,
                                                [3, setting.data_dim - 3],
                                                axis=-1,
                                                name='split_points_features')
    if setting.use_extra_features:
        if setting.with_normal_feature:
            if setting.data_dim < 6:
                print('Only 3D normals are supported!')
                exit()
            elif setting.data_dim == 6:
                features_augmented = pf.augment(features_sampled, rotations)
            else:
                normals, rest = tf.split(features_sampled,
                                         [3, setting.data_dim - 6])
                normals_augmented = pf.augment(normals, rotations)
                features_augmented = tf.concat([normals_augmented, rest],
                                               axis=-1)
        else:
            features_augmented = features_sampled
else:
    points_sampled = pts_fts_sampled
points_augmented = pf.augment(points_sampled, xforms, jitter_range)
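# Note on the branch above: only the rotation part of the augmentation is
# applied to normal features (pf.augment(features_sampled, rotations) is called
# without jitter_range), presumably because normals are direction vectors that
# must follow the rotation applied to the points but should not be jittered.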

# In[9]:

labels_sampled = tf.gather_nd(labels_seg,
Example #15
def main():

    # sys.stdout = open("./log_test_mynet.txt", 'w')
    modelname = "pointcnn_cls"
    # Load the PointCNN classification model module
    model = importlib.import_module(modelname)
    # Load the hyperparameter settings
    setting_path = os.path.join(os.path.dirname(__file__), modelname)
    sys.path.append(setting_path)
    setting = importlib.import_module("mynet")
    sample_num = setting.sample_num 
    rotation_range_val = setting.rotation_range_val
    scaling_range_val = setting.scaling_range_val
    jitter_val = setting.jitter_val
    # Input file
    filepath = "../data/mynet/test/1/2018-05-12-12-52-11_Velodyne-HDL-32-Data(1955to2295)_1955.pcd"
    
    batch_size_val = 1  # evaluation batch size (a single frame is classified here)
    data_frame = np.zeros((batch_size_val, sample_num, 3), dtype=np.float32)
    #frame_id = (filepath.split('_')[-1]).split('.')[0]

    #######################################################################
    # Loading PCD: read the file once, skip the 11-line ASCII header and the
    # trailing line, and keep only the x/y/z columns.
    with open(filepath, 'r') as f:
        lines = f.readlines()
    xyz = np.array([[float(value) for value in line.split(' ')[0:3]]
                    for line in lines[11:-1]])
    #######################################################################

    xyz = xyz.astype(np.float32)
    np.random.shuffle(xyz)
    pt_num = xyz.shape[0]
    indices = np.random.choice(pt_num, sample_num, replace=(pt_num < sample_num))
    data_frame[0, ...] = xyz[indices]
    #data_frame[0, ...] = xyz

    ######################################################################
    # Placeholders
    xforms = tf.placeholder(tf.float32, shape=(None, 3, 3), name="xforms")
    rotations = tf.placeholder(tf.float32, shape=(None, 3, 3), name="rotations")
    jitter_range = tf.placeholder(tf.float32, shape=(1), name="jitter_range")
    is_training = tf.placeholder(tf.bool, name='is_training')
    data_frame_placeholder = tf.placeholder(data_frame.dtype, data_frame.shape, name="data_frame")    
    ######################################################################

    
    #######################################################################
    # Network input
    points_augmented = pf.augment(data_frame_placeholder, xforms, jitter_range)
    # Build the network; no extra features are used for now
    net = model.Net(points=points_augmented, features=None, is_training=is_training, setting=setting)
    logits = net.logits
    probs = tf.nn.softmax(logits, name='probs')
    # Network output
    predict = tf.argmax(probs, axis=-1, name='predictions')
    #######################################################################
    
    #######################################################################
    load_ckpt = "/home/elvin/models/mynet/pointcnn_cls_mynet_2019-07-05-20-00-17_31693/ckpts/iter-528"
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    saver = tf.train.Saver(max_to_keep=None)
    with tf.Session() as sess:
        sess.run(init_op)
        saver.restore(sess, load_ckpt) # Load the model
        xforms_np, rotations_np = pf.get_xforms(batch_size_val,
                                                rotation_range=rotation_range_val,
                                                scaling_range=scaling_range_val,
                                                order=setting.rotation_order)
        res = sess.run(predict, feed_dict={data_frame_placeholder: data_frame,
                                           xforms: xforms_np,
                                           rotations: rotations_np,
                                           jitter_range: np.array([jitter_val]),
                                           is_training: False})
        print("res=", res[0][0])    
    ######################################################################    
    #sys.stdout.flush()
    return res[0][0]
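The PCD loader in the main() above assumes the header is exactly 11 lines long. ASCII PCD files terminate their header with a DATA line, so a slightly more defensive sketch (load_ascii_pcd_xyz is illustrative, and it still assumes ASCII data whose first three columns are x/y/z) would be:

import numpy as np

def load_ascii_pcd_xyz(path):
    # Read the file once, skip the header up to and including the DATA line,
    # and keep only the x/y/z columns of every non-empty data line.
    with open(path, 'r') as f:
        lines = f.readlines()
    start = next(i for i, line in enumerate(lines) if line.startswith('DATA')) + 1
    return np.array([[float(v) for v in line.split()[:3]]
                     for line in lines[start:] if line.strip()],
                    dtype=np.float32)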
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--filelist',
                        '-t',
                        help='Path to training set ground truth (.txt)',
                        required=True)
    parser.add_argument('--filelist_val',
                        '-v',
                        help='Path to validation set ground truth (.txt)',
                        required=True)
    parser.add_argument('--load_ckpt',
                        '-l',
                        help='Path to a check point file for load')
    parser.add_argument(
        '--save_folder',
        '-s',
        help='Path to folder for saving check points and summary',
        required=True)
    parser.add_argument('--model', '-m', help='Model to use', required=True)
    parser.add_argument('--setting',
                        '-x',
                        help='Setting to use',
                        required=True)
    args = parser.parse_args()

    time_string = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    root_folder = os.path.join(
        args.save_folder,
        '%s_%s_%d_%s' % (args.model, args.setting, os.getpid(), time_string))
    if not os.path.exists(root_folder):
        os.makedirs(root_folder)

    # Redirect stdout to a log file
    # sys.stdout = open(os.path.join(root_folder, 'log.txt'), 'w')

    print('PID:', os.getpid())

    print(args)

    model = importlib.import_module(args.model)
    setting_path = os.path.join(os.path.dirname(__file__), args.model)
    sys.path.append(setting_path)
    print(setting_path)
    setting = importlib.import_module(args.setting)

    num_epochs = setting.num_epochs
    batch_size = setting.batch_size
    sample_num = setting.sample_num
    step_val = 1000
    num_parts = setting.num_parts
    label_weights_list = setting.label_weights
    scaling_range = setting.scaling_range
    scaling_range_val = setting.scaling_range_val
    jitter = setting.jitter
    jitter_val = setting.jitter_val

    # Prepare inputs
    print('{}-Preparing datasets...'.format(datetime.now()))

    data_train, data_num_train, label_train = data_utils.load_seg(
        args.filelist)
    data_val, data_num_val, label_val = data_utils.load_seg(args.filelist_val)

    # shuffle
    data_train, data_num_train, label_train = \
        data_utils.grouped_shuffle([data_train, data_num_train, label_train])

    num_train = data_train.shape[0]  # 2048
    point_num = data_train.shape[1]  # 6501
    num_val = data_val.shape[0]

    print('{}-{:d}/{:d} training/validation samples.'.format(
        datetime.now(), num_train, num_val))
    batch_num = (num_train * num_epochs + batch_size - 1) // batch_size
    print('{}-{:d} training batches.'.format(datetime.now(), batch_num))
    batch_num_val = math.ceil(num_val / batch_size)
    print('{}-{:d} testing batches per test.'.format(datetime.now(),
                                                     batch_num_val))

    ######################################################################
    # Placeholders
    indices = tf.placeholder(tf.int32, shape=(None, None, 2), name="indices")
    xforms = tf.placeholder(tf.float32, shape=(None, 3, 3),
                            name="xforms")  # scaling transform
    rotations = tf.placeholder(tf.float32,
                               shape=(None, 3, 3),
                               name="rotations")  # rotation transform
    jitter_range = tf.placeholder(tf.float32, shape=(1),
                                  name="jitter_range")  # point jitter
    global_step = tf.Variable(0, trainable=False, name='global_step')
    is_training = tf.placeholder(tf.bool, name='is_training')

    pts_fts = tf.placeholder(tf.float32,
                             shape=(None, point_num, setting.data_dim),
                             name='pts_fts')
    labels_seg = tf.placeholder(tf.int32,
                                shape=(None, point_num),
                                name='labels_seg')
    labels_weights = tf.placeholder(tf.float32,
                                    shape=(None, point_num),
                                    name='labels_weights')
    loss_val_avg = tf.placeholder(tf.float32)
    t_1_acc_val_avg = tf.placeholder(tf.float32)
    t_1_acc_val_instance_avg = tf.placeholder(tf.float32)
    t_1_acc_val_others_avg = tf.placeholder(tf.float32)

    ######################################################################

    # Set Inputs(points,features_sampled)
    features_sampled = None

    if setting.data_dim > 3:
        #   [3 pts_xyz , 2 img_xy , 4 extra_features]
        points, _, features = tf.split(pts_fts, [
            setting.data_format["pts_xyz"], setting.data_format["img_xy"],
            setting.data_format["extra_features"]
        ],
                                       axis=-1,
                                       name='split_points_xy_features')
        # r, g, b, i
        if setting.use_extra_features:
            # gather the extra features
            features_sampled = tf.gather_nd(features,
                                            indices=indices,
                                            name='features_sampled')
    else:
        points = pts_fts

    points_sampled = tf.gather_nd(points,
                                  indices=indices,
                                  name='points_sampled')
    points_augmented = pf.augment(points_sampled, xforms, jitter_range)
    # N * 6501
    labels_sampled = tf.gather_nd(labels_seg,
                                  indices=indices,
                                  name='labels_sampled')
    labels_weights_sampled = tf.gather_nd(labels_weights,
                                          indices=indices,
                                          name='labels_weight_sampled')

    # Build net
    net = model.Net(points_augmented, features_sampled, None, None, num_parts,
                    is_training, setting)
    logits, probs = net.logits, net.probs

    loss_op = tf.losses.sparse_softmax_cross_entropy(
        labels=labels_sampled, logits=logits, weights=labels_weights_sampled)
    t_1_acc_op = pf.top_1_accuracy(probs, labels_sampled)
    t_1_acc_instance_op = pf.top_1_accuracy(probs, labels_sampled,
                                            labels_weights_sampled, 0.6)
    t_1_acc_others_op = pf.top_1_accuracy(probs, labels_sampled,
                                          labels_weights_sampled, 0.6, "less")

    reg_loss = setting.weight_decay * tf.losses.get_regularization_loss()

    # lr decay
    lr_exp_op = tf.train.exponential_decay(setting.learning_rate_base,
                                           global_step,
                                           setting.decay_steps,
                                           setting.decay_rate,
                                           staircase=True)
    lr_clip_op = tf.maximum(lr_exp_op, setting.learning_rate_min)

    # Optimizer
    if setting.optimizer == 'adam':
        optimizer = tf.train.AdamOptimizer(learning_rate=lr_clip_op,
                                           epsilon=setting.epsilon)
    elif setting.optimizer == 'momentum':
        optimizer = tf.train.MomentumOptimizer(learning_rate=lr_clip_op,
                                               momentum=0.9,
                                               use_nesterov=True)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

    with tf.control_dependencies(update_ops):
        train_op = optimizer.minimize(loss_op + reg_loss,
                                      global_step=global_step)

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    saver = tf.train.Saver(max_to_keep=1)

    # backup this file, model and setting
    if not os.path.exists(os.path.join(root_folder, args.model)):
        os.makedirs(os.path.join(root_folder, args.model))
    shutil.copy(__file__, os.path.join(root_folder,
                                       os.path.basename(__file__)))
    shutil.copy(os.path.join(os.path.dirname(__file__), args.model + '.py'),
                os.path.join(root_folder, args.model + '.py'))
    shutil.copy(
        os.path.join(os.path.dirname(__file__),
                     args.model.split("_")[0] + "_kitti" + '.py'),
        os.path.join(root_folder,
                     args.model.split("_")[0] + "_kitti" + '.py'))
    shutil.copy(os.path.join(setting_path, args.setting + '.py'),
                os.path.join(root_folder, args.model, args.setting + '.py'))

    folder_ckpt = os.path.join(root_folder, 'ckpts')
    if not os.path.exists(folder_ckpt):
        os.makedirs(folder_ckpt)

    folder_summary = os.path.join(root_folder, 'summary')
    if not os.path.exists(folder_summary):
        os.makedirs(folder_summary)

    parameter_num = np.sum(
        [np.prod(v.shape.as_list()) for v in tf.trainable_variables()])
    print('{}-Parameter number: {:d}.'.format(datetime.now(), parameter_num))

    _ = tf.summary.scalar('loss/train_seg',
                          tensor=loss_op,
                          collections=['train'])
    _ = tf.summary.scalar('t_1_acc/train_seg',
                          tensor=t_1_acc_op,
                          collections=['train'])
    _ = tf.summary.scalar('t_1_acc/train_seg_instance',
                          tensor=t_1_acc_instance_op,
                          collections=['train'])
    _ = tf.summary.scalar('t_1_acc/train_seg_others',
                          tensor=t_1_acc_others_op,
                          collections=['train'])
    _ = tf.summary.scalar('loss/val_seg',
                          tensor=loss_val_avg,
                          collections=['val'])
    _ = tf.summary.scalar('t_1_acc/val_seg_instance',
                          tensor=t_1_acc_val_instance_avg,
                          collections=['val'])
    _ = tf.summary.scalar('t_1_acc/val_seg',
                          tensor=t_1_acc_val_avg,
                          collections=['val'])
    _ = tf.summary.scalar('t_1_acc/val_seg_others',
                          tensor=t_1_acc_val_others_avg,
                          collections=['val'])
    _ = tf.summary.scalar('learning_rate',
                          tensor=lr_clip_op,
                          collections=['train'])

    # Session Run
    with tf.Session() as sess:
        summaries_op = tf.summary.merge_all('train')
        summaries_val_op = tf.summary.merge_all('val')
        summary_writer = tf.summary.FileWriter(folder_summary, sess.graph)

        sess.run(init_op)

        # Load the model
        if args.load_ckpt is not None:
            saver.restore(sess, args.load_ckpt)
            print('{}-Checkpoint loaded from {}!'.format(
                datetime.now(), args.load_ckpt))

        for batch_idx in range(batch_num):
            if (batch_idx != 0 and batch_idx % step_val
                    == 0) or batch_idx == batch_num - 1:
                ######################################################################
                # Validation
                filename_ckpt = os.path.join(folder_ckpt, 'iter')
                saver.save(sess, filename_ckpt, global_step=global_step)
                print('{}-Checkpoint saved to {}!'.format(
                    datetime.now(), filename_ckpt))

                losses_val = []
                t_1_accs = []
                t_1_accs_instance = []
                t_1_accs_others = []

                for batch_val_idx in range(math.ceil(num_val / batch_size)):
                    start_idx = batch_size * batch_val_idx
                    end_idx = min(start_idx + batch_size, num_val)
                    batch_size_val = end_idx - start_idx
                    points_batch = data_val[start_idx:end_idx, ...]
                    points_num_batch = data_num_val[start_idx:end_idx, ...]
                    labels_batch = label_val[start_idx:end_idx, ...]
                    weights_batch = np.array(label_weights_list)[label_val[
                        start_idx:end_idx, ...]]

                    xforms_np, rotations_np = pf.get_xforms(
                        batch_size_val, scaling_range=scaling_range_val)
                    sess_op_list = [
                        loss_op, t_1_acc_op, t_1_acc_instance_op,
                        t_1_acc_others_op
                    ]
                    sess_feed_dict = {
                        pts_fts:
                        points_batch,
                        indices:
                        pf.get_indices(batch_size_val, sample_num,
                                       points_num_batch),
                        xforms:
                        xforms_np,
                        rotations:
                        rotations_np,
                        jitter_range:
                        np.array([jitter_val]),  # data jitter
                        labels_seg:
                        labels_batch,
                        labels_weights:
                        weights_batch,
                        is_training:
                        False
                    }

                    loss_val, t_1_acc_val, t_1_acc_val_instance, t_1_acc_val_others = sess.run(
                        sess_op_list, feed_dict=sess_feed_dict)
                    print(
                        '{}-[Val  ]-Iter: {:06d}  Loss: {:.4f} T-1 Acc: {:.4f}'
                        .format(datetime.now(), batch_val_idx, loss_val,
                                t_1_acc_val))

                    sys.stdout.flush()
                    losses_val.append(loss_val * batch_size_val)
                    t_1_accs.append(t_1_acc_val * batch_size_val)
                    t_1_accs_instance.append(t_1_acc_val_instance *
                                             batch_size_val)
                    t_1_accs_others.append(t_1_acc_val_others * batch_size_val)

                loss_avg = sum(losses_val) / num_val
                t_1_acc_avg = sum(t_1_accs) / num_val
                t_1_acc_instance_avg = sum(t_1_accs_instance) / num_val
                t_1_acc_others_avg = sum(t_1_accs_others) / num_val

                summaries_feed_dict = {
                    loss_val_avg: loss_avg,
                    t_1_acc_val_avg: t_1_acc_avg,
                    t_1_acc_val_instance_avg: t_1_acc_instance_avg,
                    t_1_acc_val_others_avg: t_1_acc_others_avg
                }

                summaries_val = sess.run(summaries_val_op,
                                         feed_dict=summaries_feed_dict)
                summary_writer.add_summary(summaries_val, batch_idx)

                print('{}-[Val  ]-Average:      Loss: {:.4f} T-1 Acc: {:.4f}'.
                      format(datetime.now(), loss_avg, t_1_acc_avg))

                sys.stdout.flush()
                ######################################################################

            ######################################################################
            # Training
            start_idx = (batch_size * batch_idx) % num_train
            end_idx = min(start_idx + batch_size, num_train)
            batch_size_train = end_idx - start_idx
            points_batch = data_train[start_idx:end_idx, ...]  # 2048 6501 9

            points_num_batch = data_num_train[start_idx:end_idx, ...]  # 2048
            labels_batch = label_train[start_idx:end_idx, ...]  # 2048 6501
            weights_batch = np.array(label_weights_list)[labels_batch]  # per-point weights (values are zero)

            if start_idx + batch_size_train == num_train:
                data_train, data_num_train, label_train = \
                    data_utils.grouped_shuffle([data_train, data_num_train, label_train])

            offset = int(random.gauss(0, sample_num // 8))  # mean, standard deviation
            offset = max(offset, -sample_num // 4)
            offset = min(offset, sample_num // 4)
            sample_num_train = sample_num + offset
            xforms_np, rotations_np = pf.get_xforms(
                batch_size_train, scaling_range=scaling_range)  # data augmentation

            sess_op_list = [
                train_op, loss_op, t_1_acc_op, t_1_acc_instance_op,
                t_1_acc_others_op, summaries_op
            ]

            sess_feed_dict = {
                pts_fts:
                points_batch,
                indices:
                pf.get_indices(batch_size_train, sample_num_train,
                               points_num_batch),
                xforms:
                xforms_np,
                rotations:
                rotations_np,
                jitter_range:
                np.array([jitter]),
                labels_seg:
                labels_batch,
                labels_weights:
                weights_batch,
                is_training:
                True
            }

            _, loss, t_1_acc, t_1_acc_instance, t_1_acc_others, summaries = sess.run(
                sess_op_list, feed_dict=sess_feed_dict)
            print('{}-[Train]-Iter: {:06d}  Loss_seg: {:.4f} T-1 Acc: {:.4f}'.
                  format(datetime.now(), batch_idx, loss, t_1_acc))

            summary_writer.add_summary(summaries, batch_idx)
            sys.stdout.flush()

            ######################################################################
        print('{}-Done!'.format(datetime.now()))
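Both the validation and training loops above build the indices fed to tf.gather_nd through pf.get_indices(batch_size, sample_num, points_num_batch). Whatever the exact pointfly implementation looks like, the tensor it has to return is (batch_size, sample_num, 2) of (batch_index, point_index) pairs, with point_index drawn only from the valid (non-padded) points of each cloud. A simplified NumPy sketch of that contract (get_indices_sketch is illustrative, not the library routine):

import numpy as np

def get_indices_sketch(batch_size, sample_num, points_num):
    # points_num[b] is the number of real (non-padded) points in cloud b.
    indices = np.zeros((batch_size, sample_num, 2), dtype=np.int32)
    for b in range(batch_size):
        n = int(points_num[b])
        chosen = np.random.choice(n, sample_num, replace=(n < sample_num))
        indices[b, :, 0] = b       # which cloud in the batch
        indices[b, :, 1] = chosen  # which point inside that cloud
    return indices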
Example #17
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--filelist',
                        '-t',
                        help='Path to training set ground truth (.txt)',
                        required=True)
    parser.add_argument('--filelist_val',
                        '-v',
                        help='Path to validation set ground truth (.txt)',
                        required=True)
    parser.add_argument('--filelist_unseen',
                        '-u',
                        help='Path to unseen (unlabelled) set file list (.txt)',
                        required=True)
    parser.add_argument('--load_ckpt',
                        '-l',
                        help='Path to a check point file for load')
    parser.add_argument(
        '--save_folder',
        '-s',
        help='Path to folder for saving check points and summary',
        required=True)
    parser.add_argument('--model', '-m', help='Model to use', required=True)
    parser.add_argument('--setting',
                        '-x',
                        help='Setting to use',
                        required=True)
    args = parser.parse_args()

    time_string = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    root_folder = os.path.join(
        args.save_folder,
        '%s_%s_%s_%d' % (args.model, args.setting, time_string, os.getpid()))
    if not os.path.exists(root_folder):
        os.makedirs(root_folder)

    print('PID:', os.getpid())

    print(args)

    model = importlib.import_module(args.model)
    setting_path = os.path.join(os.path.dirname(__file__), args.model)
    sys.path.append(setting_path)
    setting = importlib.import_module(args.setting)

    num_epochs = setting.num_epochs
    batch_size = setting.batch_size
    sample_num = setting.sample_num
    step_val = setting.step_val
    label_weights_list = setting.label_weights
    label_weights_val = [1.0] * 1 + [1.0] * (setting.num_class - 1)

    rotation_range = setting.rotation_range
    rotation_range_val = setting.rotation_range_val
    scaling_range = setting.scaling_range
    scaling_range_val = setting.scaling_range_val
    jitter = setting.jitter
    jitter_val = setting.jitter_val

    # Prepare inputs
    print('{}-Preparing datasets...'.format(datetime.now()))
    is_list_of_h5_list = not data_utils.is_h5_list(args.filelist)
    if is_list_of_h5_list:
        seg_list = data_utils.load_seg_list(args.filelist)
        seg_list_idx = 0
        filelist_train = seg_list[seg_list_idx]
        seg_list_idx = seg_list_idx + 1
    else:
        filelist_train = args.filelist
    data_train, _, data_num_train, label_train, _ = data_utils.load_seg(
        filelist_train)
    data_val, _, data_num_val, label_val, _ = data_utils.load_seg(
        args.filelist_val)

    data_unseen, _, data_num_unseen, label_unseen, _ = data_utils.load_seg(
        args.filelist_unseen)

    # shuffle
    data_unseen, data_num_unseen, label_unseen = \
        data_utils.grouped_shuffle([data_unseen, data_num_unseen, label_unseen])

    # shuffle
    data_train, data_num_train, label_train = \
        data_utils.grouped_shuffle([data_train, data_num_train, label_train])

    num_train = data_train.shape[0]
    point_num = data_train.shape[1]
    num_val = data_val.shape[0]
    num_unseen = data_unseen.shape[0]

    batch_num_unseen = int(math.ceil(num_unseen / batch_size))

    print('{}-{:d}/{:d} training/validation samples.'.format(
        datetime.now(), num_train, num_val))
    batch_num = (num_train * num_epochs + batch_size - 1) // batch_size
    print('{}-{:d} training batches.'.format(datetime.now(), batch_num))
    batch_num_val = int(math.ceil(num_val / batch_size))
    print('{}-{:d} testing batches per test.'.format(datetime.now(),
                                                     batch_num_val))

    ######################################################################
    # Placeholders
    indices = tf.placeholder(tf.int32, shape=(None, None, 2), name="indices")
    xforms = tf.placeholder(tf.float32, shape=(None, 3, 3), name="xforms")
    rotations = tf.placeholder(tf.float32,
                               shape=(None, 3, 3),
                               name="rotations")
    jitter_range = tf.placeholder(tf.float32, shape=(1), name="jitter_range")
    global_step = tf.Variable(0, trainable=False, name='global_step')
    is_training = tf.placeholder(tf.bool, name='is_training')

    is_labelled_data = tf.placeholder(tf.bool, name='is_labelled_data')

    pts_fts = tf.placeholder(tf.float32,
                             shape=(None, point_num, setting.data_dim),
                             name='pts_fts')
    labels_seg = tf.placeholder(tf.int64,
                                shape=(None, point_num),
                                name='labels_seg')
    labels_weights = tf.placeholder(tf.float32,
                                    shape=(None, point_num),
                                    name='labels_weights')

    ######################################################################
    pts_fts_sampled = tf.gather_nd(pts_fts,
                                   indices=indices,
                                   name='pts_fts_sampled')
    features_augmented = None
    if setting.data_dim > 3:
        points_sampled, features_sampled = tf.split(
            pts_fts_sampled, [3, setting.data_dim - 3],
            axis=-1,
            name='split_points_features')
        if setting.use_extra_features:
            if setting.with_normal_feature:
                if setting.data_dim < 6:
                    print('Only 3D normals are supported!')
                    exit()
                elif setting.data_dim == 6:
                    features_augmented = pf.augment(features_sampled,
                                                    rotations)
                else:
                    normals, rest = tf.split(features_sampled,
                                             [3, setting.data_dim - 6])
                    normals_augmented = pf.augment(normals, rotations)
                    features_augmented = tf.concat([normals_augmented, rest],
                                                   axis=-1)
            else:
                features_augmented = features_sampled
    else:
        points_sampled = pts_fts_sampled
    points_augmented = pf.augment(points_sampled, xforms, jitter_range)

    labels_sampled = tf.gather_nd(labels_seg,
                                  indices=indices,
                                  name='labels_sampled')
    labels_weights_sampled = tf.gather_nd(labels_weights,
                                          indices=indices,
                                          name='labels_weight_sampled')

    net = model.Net(points_augmented, features_augmented, is_training, setting)

    logits = net.logits
    probs = tf.nn.softmax(logits, name='prob')
    predictions = tf.argmax(probs, axis=-1, name='predictions')

    with tf.variable_scope('xcrf_ker_weights'):

        _, point_indices = pf.knn_indices_general(points_augmented,
                                                  points_augmented, 1024, True)

        xcrf = learningBlock(num_points=sample_num,
                             num_classes=setting.num_class,
                             theta_alpha=float(5),
                             theta_beta=float(2),
                             theta_gamma=float(1),
                             num_iterations=5,
                             name='xcrf',
                             point_indices=point_indices)

        _logits1 = xcrf.call(net.logits, points_augmented, features_augmented,
                             setting.data_dim)
        _logits2 = xcrf.call(net.logits,
                             points_augmented,
                             features_augmented,
                             setting.data_dim,
                             D=2)
        _logits3 = xcrf.call(net.logits,
                             points_augmented,
                             features_augmented,
                             setting.data_dim,
                             D=3)
        _logits4 = xcrf.call(net.logits,
                             points_augmented,
                             features_augmented,
                             setting.data_dim,
                             D=4)
        _logits5 = xcrf.call(net.logits,
                             points_augmented,
                             features_augmented,
                             setting.data_dim,
                             D=8)
        _logits6 = xcrf.call(net.logits,
                             points_augmented,
                             features_augmented,
                             setting.data_dim,
                             D=16)

        _logits = _logits1 + _logits2 + _logits3 + _logits4 + _logits5 + _logits6
        _probs = tf.nn.softmax(_logits, name='probs_crf')
        _predictions = tf.argmax(_probs, axis=-1, name='predictions_crf')

    logits = tf.cond(
        is_training,
        lambda: tf.cond(is_labelled_data, lambda: _logits, lambda: logits),
        lambda: _logits)

    predictions = tf.cond(
        is_training, lambda: tf.cond(is_labelled_data, lambda: _predictions,
                                     lambda: predictions),
        lambda: _predictions)

    labels_sampled = tf.cond(
        is_training, lambda: tf.cond(is_labelled_data, lambda: labels_sampled,
                                     lambda: _predictions),
        lambda: labels_sampled)

    loss_op = tf.losses.sparse_softmax_cross_entropy(
        labels=labels_sampled, logits=logits, weights=labels_weights_sampled)

    with tf.name_scope('metrics'):
        loss_mean_op, loss_mean_update_op = tf.metrics.mean(loss_op)
        t_1_acc_op, t_1_acc_update_op = tf.metrics.accuracy(
            labels_sampled, predictions, weights=labels_weights_sampled)
        t_1_per_class_acc_op, t_1_per_class_acc_update_op = \
            tf.metrics.mean_per_class_accuracy(labels_sampled, predictions, setting.num_class,
                                               weights=labels_weights_sampled)
        t_1_per_mean_iou_op, t_1_per_mean_iou_op_update_op = \
            tf.metrics.mean_iou(labels_sampled, predictions, setting.num_class,
                                               weights=labels_weights_sampled)

    reset_metrics_op = tf.variables_initializer([
        var for var in tf.local_variables()
        if var.name.split('/')[0] == 'metrics'
    ])

    _ = tf.summary.scalar('loss/train',
                          tensor=loss_mean_op,
                          collections=['train'])
    _ = tf.summary.scalar('t_1_acc/train',
                          tensor=t_1_acc_op,
                          collections=['train'])
    _ = tf.summary.scalar('t_1_per_class_acc/train',
                          tensor=t_1_per_class_acc_op,
                          collections=['train'])

    _ = tf.summary.scalar('loss/val', tensor=loss_mean_op, collections=['val'])
    _ = tf.summary.scalar('t_1_acc/val',
                          tensor=t_1_acc_op,
                          collections=['val'])
    _ = tf.summary.scalar('t_1_per_class_acc/val',
                          tensor=t_1_per_class_acc_op,
                          collections=['val'])
    _ = tf.summary.scalar('t_1_mean_iou/val',
                          tensor=t_1_per_mean_iou_op,
                          collections=['val'])

    all_variable = tf.global_variables()

    lr_exp_op = tf.train.exponential_decay(setting.learning_rate_base,
                                           global_step,
                                           setting.decay_steps,
                                           setting.decay_rate,
                                           staircase=True)
    lr_clip_op = tf.maximum(lr_exp_op, setting.learning_rate_min)
    _ = tf.summary.scalar('learning_rate',
                          tensor=lr_clip_op,
                          collections=['train'])
    reg_loss = setting.weight_decay * tf.losses.get_regularization_loss()
    if setting.optimizer == 'adam':
        optimizer = tf.train.AdamOptimizer(learning_rate=lr_clip_op,
                                           epsilon=setting.epsilon)
    elif setting.optimizer == 'momentum':
        optimizer = tf.train.MomentumOptimizer(learning_rate=lr_clip_op,
                                               momentum=setting.momentum,
                                               use_nesterov=True)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

    xcrf_ker_weights = [
        var for var in tf.global_variables() if 'xcrf_ker_weights' in var.name
    ]
    no_xcrf_ker_weights = [
        var for var in tf.global_variables()
        if 'xcrf_ker_weights' not in var.name
    ]

    #print(restore_values)
    #train_op = optimizer.minimize( loss_op+reg_loss, global_step=global_step)

    with tf.control_dependencies(update_ops):
        train_op_xcrf = optimizer.minimize(loss_op + reg_loss,
                                           var_list=no_xcrf_ker_weights,
                                           global_step=global_step)
        train_op_all = optimizer.minimize(loss_op + reg_loss,
                                          var_list=all_variable,
                                          global_step=global_step)

    train_op = tf.cond(is_labelled_data, lambda: train_op_all,
                       lambda: train_op_xcrf)

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())

    saver = tf.train.Saver(var_list=no_xcrf_ker_weights, max_to_keep=5)

    saver_best = tf.train.Saver(max_to_keep=5)
    # backup all code
    code_folder = os.path.abspath(os.path.dirname(__file__))
    shutil.copytree(code_folder,
                    os.path.join(root_folder, os.path.basename(code_folder)))

    folder_ckpt = os.path.join(root_folder, 'ckpts')
    if not os.path.exists(folder_ckpt):
        os.makedirs(folder_ckpt)

    folder_ckpt_best = os.path.join(root_folder, 'ckpt-best')
    if not os.path.exists(folder_ckpt_best):
        os.makedirs(folder_ckpt_best)

    folder_summary = os.path.join(root_folder, 'summary')
    if not os.path.exists(folder_summary):
        os.makedirs(folder_summary)

    parameter_num = np.sum(
        [np.prod(v.shape.as_list()) for v in tf.trainable_variables()])
    print('{}-Parameter number: {:d}.'.format(datetime.now(),
                                              int(parameter_num)))

    _highest_val = 0.0
    max_val = 10
    with tf.Session() as sess:
        summaries_op = tf.summary.merge_all('train')
        summaries_val_op = tf.summary.merge_all('val')
        summary_values_op = tf.summary.merge_all('summary_values')
        summary_writer = tf.summary.FileWriter(folder_summary, sess.graph)

        sess.run(init_op)

        # Load the model
        if args.load_ckpt is not None:
            saver.restore(sess, args.load_ckpt)
            print('{}-Checkpoint loaded from {}!'.format(
                datetime.now(), args.load_ckpt))
        batch_num = 50000
        for batch_idx_train in range(batch_num):
            if (batch_idx_train % step_val == 0 and (batch_idx_train != 0 or args.load_ckpt is not None)) \
                    or batch_idx_train == batch_num - 1:
                ######################################################################

                # Validation
                filename_ckpt = os.path.join(folder_ckpt, 'iter')
                saver.save(sess, filename_ckpt, global_step=global_step)
                print('{}-Checkpoint saved to {}!'.format(
                    datetime.now(), filename_ckpt))

                sess.run(reset_metrics_op)
                summary_hist = None
                _idxVal = np.arange(num_val)
                np.random.shuffle(_idxVal)

                _pred = []
                _label = []
                for batch_val_idx in tqdm(range(batch_num_val)):
                    start_idx = batch_size * batch_val_idx
                    end_idx = min(start_idx + batch_size, num_val)
                    batch_size_val = end_idx - start_idx

                    points_batch = data_val[_idxVal[start_idx:end_idx], ...]
                    points_num_batch = data_num_val[_idxVal[start_idx:end_idx],
                                                    ...]
                    labels_batch = label_val[_idxVal[start_idx:end_idx], ...]

                    weights_batch = np.array(label_weights_val)[labels_batch]

                    xforms_np, rotations_np = pf.get_xforms(
                        batch_size_val,
                        rotation_range=rotation_range_val,
                        scaling_range=scaling_range_val,
                        order=setting.rotation_order)

                    _labels_sampled, _predictions, _, _, _, _ = sess.run(
                        [
                            labels_sampled, predictions, loss_mean_update_op,
                            t_1_acc_update_op, t_1_per_class_acc_update_op,
                            t_1_per_mean_iou_op_update_op
                        ],
                        feed_dict={
                            pts_fts:
                            points_batch,
                            indices:
                            pf.get_indices(batch_size_val, sample_num,
                                           points_num_batch),
                            xforms:
                            xforms_np,
                            rotations:
                            rotations_np,
                            jitter_range:
                            np.array([jitter_val]),
                            labels_seg:
                            labels_batch,
                            is_labelled_data:
                            True,
                            labels_weights:
                            weights_batch,
                            is_training:
                            False,
                        })
                    _pred.append(_predictions.flatten())
                    _label.append(_labels_sampled.flatten())

                loss_val, t_1_acc_val, t_1_per_class_acc_val, t1__mean_iou, summaries_val = sess.run(
                    [
                        loss_mean_op, t_1_acc_op, t_1_per_class_acc_op,
                        t_1_per_mean_iou_op, summaries_val_op
                    ])

                img_d_summary = pf.plot_confusion_matrix(
                    _label,
                    _pred, ["Environment", "Pedestrian", "Car", "Cyclist"],
                    tensor_name='confusion_matrix')

                max_val = max_val - 1

                if (t_1_per_class_acc_val > _highest_val):

                    max_val = 10

                    _highest_val = t_1_per_class_acc_val

                    filename_ckpt = os.path.join(folder_ckpt_best,
                                                 str(_highest_val) + "-iter-")
                    saver_best.save(sess,
                                    filename_ckpt,
                                    global_step=global_step)

                if (max_val < 0):
                    sys.exit(0)

                summary_writer.add_summary(summaries_val, batch_idx_train)

                summary_writer.add_summary(img_d_summary, batch_idx_train)

                #summary_writer.add_summary(summary_hist, batch_idx_train)
                print(
                    '{}-[Val  ]-Average:      Loss: {:.4f}  T-1 Acc: {:.4f}  Diff-Best: {:.4f} T-1 mAcc: {:.4f}   T-1 mIoU: {:.4f}'
                    .format(datetime.now(), loss_val, t_1_acc_val,
                            _highest_val - t_1_per_class_acc_val,
                            t_1_per_class_acc_val, t1__mean_iou))
                sys.stdout.flush()
                ######################################################################

                # Unseen Data
                sess.run(reset_metrics_op)
                summary_hist = None
                _idxunseen = np.arange(num_unseen)
                np.random.shuffle(_idxunseen)

                _pred = []
                _label = []

                unseenIndices = np.arange(batch_num_unseen)
                np.random.shuffle(unseenIndices)
                for batch_val_idx in tqdm(unseenIndices[:500]):
                    start_idx = batch_size * batch_val_idx
                    end_idx = min(start_idx + batch_size, num_unseen)
                    batch_size_train = end_idx - start_idx

                    points_batch = data_unseen[_idxunseen[start_idx:end_idx],
                                               ...]
                    points_num_batch = data_num_unseen[
                        _idxunseen[start_idx:end_idx], ...]
                    labels_batch = np.zeros(points_batch.shape[0:2],
                                            dtype=np.int32)

                    weights_batch = np.array(label_weights_list)[labels_batch]

                    xforms_np, rotations_np = pf.get_xforms(
                        batch_size_train,
                        rotation_range=rotation_range,
                        scaling_range=scaling_range,
                        order=setting.rotation_order)

                    offset = int(
                        random.gauss(0,
                                     sample_num * setting.sample_num_variance))
                    offset = max(offset, -sample_num * setting.sample_num_clip)
                    offset = min(offset, sample_num * setting.sample_num_clip)
                    sample_num_train = sample_num + offset

                    sess.run(
                        [
                            train_op, loss_mean_update_op, t_1_acc_update_op,
                            t_1_per_class_acc_update_op,
                            t_1_per_mean_iou_op_update_op
                        ],
                        feed_dict={
                            pts_fts:
                            points_batch,
                            is_labelled_data:
                            False,
                            indices:
                            pf.get_indices(batch_size_train, sample_num_train,
                                           points_num_batch),
                            xforms:
                            xforms_np,
                            rotations:
                            rotations_np,
                            jitter_range:
                            np.array([jitter]),
                            labels_seg:
                            labels_batch,
                            labels_weights:
                            weights_batch,
                            is_training:
                            True,
                        })
                    if batch_val_idx % 100 == 0:
                        loss, t_1_acc, t_1_per_class_acc, t_1__mean_iou = sess.run(
                            [
                                loss_mean_op, t_1_acc_op, t_1_per_class_acc_op,
                                t_1_per_mean_iou_op
                            ],
                            feed_dict={
                                pts_fts:
                                points_batch,
                                indices:
                                pf.get_indices(batch_size_train,
                                               sample_num_train,
                                               points_num_batch),
                                xforms:
                                xforms_np,
                                is_labelled_data:
                                False,
                                rotations:
                                rotations_np,
                                jitter_range:
                                np.array([jitter]),
                                labels_seg:
                                labels_batch,
                                labels_weights:
                                weights_batch,
                                is_training:
                                True,
                            })
                        print(
                            '{}-[Train]-Unseen: {:06d}  Loss: {:.4f}  T-1 Acc: {:.4f}  T-1 mAcc: {:.4f}   T-1 mIoU: {:.4f}'
                            .format(datetime.now(), batch_val_idx, loss,
                                    t_1_acc, t_1_per_class_acc, t_1__mean_iou))
                        sys.stdout.flush()

            ######################################################################
            # Training
            start_idx = (batch_size * batch_idx_train) % num_train
            end_idx = min(start_idx + batch_size, num_train)
            batch_size_train = end_idx - start_idx
            points_batch = data_train[start_idx:end_idx, ...]
            points_num_batch = data_num_train[start_idx:end_idx, ...]
            labels_batch = label_train[start_idx:end_idx, ...]
            weights_batch = np.array(label_weights_list)[labels_batch]
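            # At a chunk boundary (this batch ends exactly at num_train), the block below loads the
            # next h5 file list, if any, and re-shuffles data/point-counts/labels together so they stay aligned.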

            if start_idx + batch_size_train == num_train:
                if is_list_of_h5_list:
                    filelist_train_prev = seg_list[(seg_list_idx - 1) %
                                                   len(seg_list)]
                    filelist_train = seg_list[seg_list_idx % len(seg_list)]
                    if filelist_train != filelist_train_prev:
                        data_train, _, data_num_train, label_train, _ = data_utils.load_seg(
                            filelist_train)
                        num_train = data_train.shape[0]
                    seg_list_idx = seg_list_idx + 1
                data_train, data_num_train, label_train = \
                    data_utils.grouped_shuffle([data_train, data_num_train, label_train])

            offset = int(
                random.gauss(0, sample_num * setting.sample_num_variance))
            offset = max(offset, -sample_num * setting.sample_num_clip)
            offset = min(offset, sample_num * setting.sample_num_clip)
            sample_num_train = sample_num + offset
            xforms_np, rotations_np = pf.get_xforms(
                batch_size_train,
                rotation_range=rotation_range,
                scaling_range=scaling_range,
                order=setting.rotation_order)
            sess.run(reset_metrics_op)
            sess.run(
                [
                    train_op, loss_mean_update_op, t_1_acc_update_op,
                    t_1_per_class_acc_update_op, t_1_per_mean_iou_op_update_op
                ],
                feed_dict={
                    pts_fts:
                    points_batch,
                    is_labelled_data:
                    True,
                    indices:
                    pf.get_indices(batch_size_train, sample_num_train,
                                   points_num_batch),
                    xforms:
                    xforms_np,
                    rotations:
                    rotations_np,
                    jitter_range:
                    np.array([jitter]),
                    labels_seg:
                    labels_batch,
                    labels_weights:
                    weights_batch,
                    is_training:
                    True,
                })
            if batch_idx_train % 100 == 0:
                loss, t_1_acc, t_1_per_class_acc, t_1__mean_iou, summaries = sess.run(
                    [
                        loss_mean_op, t_1_acc_op, t_1_per_class_acc_op,
                        t_1_per_mean_iou_op, summaries_op
                    ],
                    feed_dict={
                        pts_fts:
                        points_batch,
                        indices:
                        pf.get_indices(batch_size_train, sample_num_train,
                                       points_num_batch),
                        xforms:
                        xforms_np,
                        is_labelled_data:
                        True,
                        rotations:
                        rotations_np,
                        jitter_range:
                        np.array([jitter]),
                        labels_seg:
                        labels_batch,
                        labels_weights:
                        weights_batch,
                        is_training:
                        True,
                    })
                summary_writer.add_summary(summaries, batch_idx_train)
                print(
                    '{}-[Train]-Iter: {:06d}  Loss: {:.4f}  T-1 Acc: {:.4f}  T-1 mAcc: {:.4f}   T-1 mIoU: {:.4f}'
                    .format(datetime.now(), batch_idx_train, loss, t_1_acc,
                            t_1_per_class_acc, t_1__mean_iou))
                sys.stdout.flush()
            ######################################################################
        print('{}-Done!'.format(datetime.now()))
Example #18
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--path', '-t', help='Path to data', required=True)
    parser.add_argument('--path_val', '-v', help='Path to validation data')
    parser.add_argument('--load_ckpt',
                        '-l',
                        help='Path to a check point file for load')
    parser.add_argument(
        '--save_folder',
        '-s',
        help='Path to folder for saving check points and summary',
        required=True)
    parser.add_argument('--model', '-m', help='Model to use', required=True)
    parser.add_argument('--setting',
                        '-x',
                        help='Setting to use',
                        required=True)
    parser.add_argument(
        '--epochs',
        help='Number of training epochs (default defined in setting)',
        type=int)
    parser.add_argument('--batch_size',
                        help='Batch size (default defined in setting)',
                        type=int)
    parser.add_argument(
        '--log',
        help=
        'Log to FILE in save folder; use - for stdout (default is log.txt)',
        metavar='FILE',
        default='log.txt')
    parser.add_argument('--no_timestamp_folder',
                        help="Don't save to a timestamp folder",
                        action='store_true')
    parser.add_argument('--no_code_backup',
                        help="Don't back up the code",
                        action='store_true')
    args = parser.parse_args()

    if not args.no_timestamp_folder:
        time_string = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
        root_folder = os.path.join(
            args.save_folder, '%s_%s_%s_%d' %
            (args.model, args.setting, time_string, os.getpid()))
    else:
        root_folder = args.save_folder
    if not os.path.exists(root_folder):
        os.makedirs(root_folder)

    if args.log != '-':
        sys.stdout = open(os.path.join(root_folder, args.log), 'w')

    print('PID:', os.getpid())

    print(args)

    model = importlib.import_module(args.model)
    setting_path = os.path.join(os.path.dirname(__file__), args.model)
    sys.path.append(setting_path)
    setting = importlib.import_module(args.setting)

    num_epochs = args.epochs or setting.num_epochs
    batch_size = args.batch_size or setting.batch_size
    sample_num = setting.sample_num
    step_val = setting.step_val
    rotation_range = setting.rotation_range
    rotation_range_val = setting.rotation_range_val
    scaling_range = setting.scaling_range
    scaling_range_val = setting.scaling_range_val
    jitter = setting.jitter
    jitter_val = setting.jitter_val
    pool_setting_val = None if not hasattr(
        setting, 'pool_setting_val') else setting.pool_setting_val
    pool_setting_train = None if not hasattr(
        setting, 'pool_setting_train') else setting.pool_setting_train

    # Prepare inputs
    print('{}-Preparing datasets...'.format(datetime.now()))
    data_train, label_train, data_val, label_val = setting.load_fn(
        args.path, args.path_val)
    if setting.balance_fn is not None:
        num_train_before_balance = data_train.shape[0]
        repeat_num = setting.balance_fn(label_train)
        data_train = np.repeat(data_train, repeat_num, axis=0)
        label_train = np.repeat(label_train, repeat_num, axis=0)
        data_train, label_train = data_utils.grouped_shuffle(
            [data_train, label_train])
        num_epochs = math.floor(
            num_epochs * (num_train_before_balance / data_train.shape[0]))
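        # Repeating samples for class balance enlarges the dataset, so num_epochs is scaled down
        # by the same factor to keep the total number of training steps roughly constant.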

    if setting.save_ply_fn is not None:
        folder = os.path.join(root_folder, 'pts')
        print('{}-Saving samples as .ply files to {}...'.format(
            datetime.now(), folder))
        sample_num_for_ply = min(512, data_train.shape[0])
        if setting.map_fn is None:
            data_sample = data_train[:sample_num_for_ply]
        else:
            data_sample_list = []
            for idx in range(sample_num_for_ply):
                data_sample_list.append(setting.map_fn(data_train[idx], 0)[0])
            data_sample = np.stack(data_sample_list)
        setting.save_ply_fn(data_sample, folder)

    num_train = data_train.shape[0]
    point_num = data_train.shape[1]
    num_val = data_val.shape[0]
    print('{}-{:d}/{:d} training/validation samples.'.format(
        datetime.now(), num_train, num_val))

    ######################################################################
    # Placeholders
    indices = tf.placeholder(tf.int32, shape=(None, None, 2), name="indices")
    xforms = tf.placeholder(tf.float32, shape=(None, 3, 3), name="xforms")
    rotations = tf.placeholder(tf.float32,
                               shape=(None, 3, 3),
                               name="rotations")
    jitter_range = tf.placeholder(tf.float32, shape=(1), name="jitter_range")
    global_step = tf.Variable(0, trainable=False, name='global_step')
    is_training = tf.placeholder(tf.bool, name='is_training')

    data_train_placeholder = tf.placeholder(data_train.dtype,
                                            data_train.shape,
                                            name='data_train')
    label_train_placeholder = tf.placeholder(tf.int64,
                                             label_train.shape,
                                             name='label_train')
    data_val_placeholder = tf.placeholder(data_val.dtype,
                                          data_val.shape,
                                          name='data_val')
    label_val_placeholder = tf.placeholder(tf.int64,
                                           label_val.shape,
                                           name='label_val')
    handle = tf.placeholder(tf.string, shape=[], name='handle')

    ######################################################################
    dataset_train = tf.data.Dataset.from_tensor_slices(
        (data_train_placeholder, label_train_placeholder))
    dataset_train = dataset_train.shuffle(buffer_size=batch_size * 4)

    if setting.map_fn is not None:
        dataset_train = dataset_train.map(
            lambda data, label: tuple(
                tf.py_func(setting.map_fn, [data, label],
                           [tf.float32, label.dtype])),
            num_parallel_calls=setting.num_parallel_calls)
    # map_fn is used by e.g. the Quick Draw dataset
    if setting.keep_remainder:
        dataset_train = dataset_train.batch(batch_size)
        batch_num_per_epoch = math.ceil(num_train / batch_size)
    else:
        dataset_train = dataset_train.apply(
            tf.contrib.data.batch_and_drop_remainder(batch_size))
        batch_num_per_epoch = math.floor(num_train / batch_size)
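    # keep_remainder decides whether the last, smaller batch of each epoch is kept (ceil) or
    # dropped via batch_and_drop_remainder (floor); the per-batch size logic later mirrors this choice.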
    dataset_train = dataset_train.repeat(num_epochs)
    iterator_train = dataset_train.make_initializable_iterator()
    batch_num = batch_num_per_epoch * num_epochs
    print('{}-{:d} training batches.'.format(datetime.now(), batch_num))

    dataset_val = tf.data.Dataset.from_tensor_slices(
        (data_val_placeholder, label_val_placeholder))
    if setting.map_fn is not None:
        dataset_val = dataset_val.map(
            lambda data, label: tuple(
                tf.py_func(setting.map_fn, [data, label],
                           [tf.float32, label.dtype])),
            num_parallel_calls=setting.num_parallel_calls)
    if setting.keep_remainder:
        dataset_val = dataset_val.batch(batch_size)
        batch_num_val = math.ceil(num_val /
                                  batch_size)  # number of validation batches
    else:
        dataset_val = dataset_val.apply(
            tf.contrib.data.batch_and_drop_remainder(batch_size))
        batch_num_val = math.floor(num_val / batch_size)
    iterator_val = dataset_val.make_initializable_iterator()
    print('{}-{:d} testing batches per test.'.format(datetime.now(),
                                                     batch_num_val))

    iterator = tf.data.Iterator.from_string_handle(
        handle, dataset_train.output_types)  # feedable iterator,
    (pts_fts, labels) = iterator.get_next()
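    # Feedable-iterator pattern: one get_next() serves both pipelines; the string handle fed at
    # run time (handle_train or handle_val) selects which dataset the next batch comes from.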

    pts_fts_sampled = tf.gather_nd(
        pts_fts, indices=indices,
        name='pts_fts_sampled')  # gather_nd is differentiable
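    # `indices` has shape (batch, sample_num, 2): (batch_idx, point_idx) pairs from pf.get_indices,
    # so gather_nd downsamples each cloud to sample_num points while keeping gradients.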
    features_augmented = None
    if setting.data_dim > 3:
        points_sampled, features_sampled = tf.split(
            pts_fts_sampled, [3, setting.data_dim - 3],
            axis=-1,
            name='split_points_features')
        if setting.use_extra_features:
            if setting.with_normal_feature:
                if setting.data_dim < 6:
                    print('Only 3D normals are supported!')
                    exit()
                elif setting.data_dim == 6:
                    features_augmented = pf.augment(features_sampled,
                                                    rotations)
                else:
                    normals, rest = tf.split(features_sampled,
                                             [3, setting.data_dim - 6])
                    normals_augmented = pf.augment(normals, rotations)
                    features_augmented = tf.concat([normals_augmented, rest],
                                                   axis=-1)
            else:
                features_augmented = features_sampled
    else:
        points_sampled = pts_fts_sampled
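    # With data_dim > 3, channels beyond XYZ are optional features; if they start with a 3D normal,
    # that normal is rotated via `rotations` so it stays consistent with the augmented points.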
    points_augmented = pf.augment(
        points_sampled, xforms,
        jitter_range)  # does the input carry extra features, and should they be used; points are augmented by rotation/scaling/jitter

    net = model.Net(points=points_augmented,
                    features=features_augmented,
                    is_training=is_training,
                    setting=setting)
    logits = net.logits
    probs = tf.nn.softmax(logits, name='probs')
    predictions = tf.argmax(probs, axis=-1, name='predictions')

    labels_2d = tf.expand_dims(labels, axis=-1, name='labels_2d')
    labels_tile = tf.tile(labels_2d, (1, tf.shape(logits)[1]),
                          name='labels_tile')
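    # logits carries one prediction per representative point (its shape[1] axis), so the single
    # per-cloud label is tiled across that axis before the pointwise softmax cross entropy below.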
    loss_op = tf.losses.sparse_softmax_cross_entropy(labels=labels_tile,
                                                     logits=logits)

    with tf.name_scope('metrics'):
        loss_mean_op, loss_mean_update_op = tf.metrics.mean(loss_op)
        t_1_acc_op, t_1_acc_update_op = tf.metrics.accuracy(
            labels_tile, predictions)
        t_1_per_class_acc_op, t_1_per_class_acc_update_op = tf.metrics.mean_per_class_accuracy(
            labels_tile, predictions, setting.num_class)
    reset_metrics_op = tf.variables_initializer([
        var for var in tf.local_variables()
        if var.name.split('/')[0] == 'metrics'
    ])
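    # tf.metrics.* accumulate running totals in local variables under the 'metrics' scope;
    # reset_metrics_op re-initialises only those, so each reporting window starts its averages from zero.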

    _ = tf.summary.scalar('loss/train',
                          tensor=loss_mean_op,
                          collections=['train'])
    _ = tf.summary.scalar('t_1_acc/train',
                          tensor=t_1_acc_op,
                          collections=['train'])
    _ = tf.summary.scalar('t_1_per_class_acc/train',
                          tensor=t_1_per_class_acc_op,
                          collections=['train'])

    _ = tf.summary.scalar('loss/val', tensor=loss_mean_op, collections=['val'])
    _ = tf.summary.scalar('t_1_acc/val',
                          tensor=t_1_acc_op,
                          collections=['val'])
    _ = tf.summary.scalar('t_1_per_class_acc/val',
                          tensor=t_1_per_class_acc_op,
                          collections=['val'])

    lr_exp_op = tf.train.exponential_decay(setting.learning_rate_base,
                                           global_step,
                                           setting.decay_steps,
                                           setting.decay_rate,
                                           staircase=True)
    lr_clip_op = tf.maximum(lr_exp_op, setting.learning_rate_min)
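    # Staircase exponential decay with a floor:
    # lr = max(learning_rate_base * decay_rate^(global_step // decay_steps), learning_rate_min)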
    _ = tf.summary.scalar('learning_rate',
                          tensor=lr_clip_op,
                          collections=['train'])
    reg_loss = setting.weight_decay * tf.losses.get_regularization_loss()
    if setting.optimizer == 'adam':
        optimizer = tf.train.AdamOptimizer(learning_rate=lr_clip_op,
                                           epsilon=setting.epsilon)
    elif setting.optimizer == 'momentum':
        optimizer = tf.train.MomentumOptimizer(learning_rate=lr_clip_op,
                                               momentum=setting.momentum,
                                               use_nesterov=True)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = optimizer.minimize(loss_op + reg_loss,
                                      global_step=global_step)

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())

    saver = tf.train.Saver(max_to_keep=None)

    # backup all code
    if not args.no_code_backup:
        code_folder = os.path.abspath(os.path.dirname(__file__))
        shutil.copytree(
            code_folder,
            os.path.join(root_folder, os.path.basename(code_folder)))

    folder_ckpt = os.path.join(root_folder, 'ckpts')
    if not os.path.exists(folder_ckpt):
        os.makedirs(folder_ckpt)

    folder_summary = os.path.join(root_folder, 'summary')
    if not os.path.exists(folder_summary):
        os.makedirs(folder_summary)

    parameter_num = np.sum(
        [np.prod(v.shape.as_list()) for v in tf.trainable_variables()])
    print('{}-Parameter number: {:d}.'.format(datetime.now(), parameter_num))

    with tf.Session() as sess:
        summaries_op = tf.summary.merge_all('train')
        summaries_val_op = tf.summary.merge_all('val')
        summary_writer = tf.summary.FileWriter(folder_summary, sess.graph)

        sess.run(init_op)

        # Load the model
        if args.load_ckpt is not None:
            saver.restore(sess, args.load_ckpt)
            print('{}-Checkpoint loaded from {}!'.format(
                datetime.now(), args.load_ckpt))
        else:
            latest_ckpt = tf.train.latest_checkpoint(folder_ckpt)
            if latest_ckpt:
                print('{}-Found checkpoint {}'.format(datetime.now(),
                                                      latest_ckpt))
                saver.restore(sess, latest_ckpt)
                print('{}-Checkpoint loaded from {} (Iter {})'.format(
                    datetime.now(), latest_ckpt, sess.run(global_step)))

        handle_train = sess.run(iterator_train.string_handle())
        handle_val = sess.run(iterator_val.string_handle())

        sess.run(iterator_train.initializer,
                 feed_dict={
                     data_train_placeholder: data_train,
                     label_train_placeholder: label_train,
                 })
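        # The training iterator is initialised once; dataset_train.repeat(num_epochs) supplies
        # batch_num batches in total, so the loop below counts batches rather than epochs.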

        for batch_idx_train in range(batch_num):
            ######################################################################
            # Validation
            if (batch_idx_train % step_val == 0 and (batch_idx_train != 0 or args.load_ckpt is not None)) \
                    or batch_idx_train == batch_num - 1:
                sess.run(iterator_val.initializer,
                         feed_dict={
                             data_val_placeholder: data_val,
                             label_val_placeholder: label_val,
                         })
                filename_ckpt = os.path.join(folder_ckpt, 'iter')
                saver.save(sess, filename_ckpt, global_step=global_step)
                print('{}-Checkpoint saved to {}!'.format(
                    datetime.now(), filename_ckpt))

                sess.run(reset_metrics_op)
                for batch_idx_val in range(
                        batch_num_val):  # index of the current validation batch
                    if not setting.keep_remainder \
                            or num_val % batch_size == 0 \
                            or batch_idx_val != batch_num_val - 1:
                        batch_size_val = batch_size
                    else:
                        batch_size_val = num_val % batch_size  # examples in this final, smaller batch
                    xforms_np, rotations_np = pf.get_xforms(
                        batch_size_val,  # one transform per example, so this depends on batch_size_val
                        rotation_range=rotation_range_val,
                        scaling_range=scaling_range_val,
                        order=setting.rotation_order)
                    sess.run(
                        [
                            loss_mean_update_op, t_1_acc_update_op,
                            t_1_per_class_acc_update_op
                        ],
                        feed_dict={
                            handle:
                            handle_val,
                            indices:
                            pf.get_indices(
                                batch_size_val,
                                sample_num,
                                point_num,
                            ),  # randomly pick sample_num points
                            xforms:
                            xforms_np,
                            rotations:
                            rotations_np,
                            jitter_range:
                            np.array([jitter_val]),
                            is_training:
                            False,
                        })
                loss_val, t_1_acc_val, t_1_per_class_acc_val, summaries_val, step = sess.run(
                    [
                        loss_mean_op, t_1_acc_op, t_1_per_class_acc_op,
                        summaries_val_op, global_step
                    ])
                summary_writer.add_summary(summaries_val, step)
                print(
                    '{}-[Val  ]-Average:      Loss: {:.4f}  T-1 Acc: {:.4f}  T-1 mAcc: {:.4f}'
                    .format(datetime.now(), loss_val, t_1_acc_val,
                            t_1_per_class_acc_val))
                sys.stdout.flush()
            ######################################################################

            ######################################################################
            # Training
            if not setting.keep_remainder \
                    or num_train % batch_size == 0 \
                    or (batch_idx_train % batch_num_per_epoch) != (batch_num_per_epoch - 1):
                batch_size_train = batch_size
            else:
                batch_size_train = num_train % batch_size

            offset = int(
                random.gauss(0, sample_num * setting.sample_num_variance))
            offset = max(offset, -sample_num * setting.sample_num_clip)
            offset = min(offset, sample_num * setting.sample_num_clip)
            sample_num_train = sample_num + offset  # 'To train a model that takes N points as input, N(N, (N/8)^2) points are used for training'; the authors found this sampling scheme works well in practice
            xforms_np, rotations_np = pf.get_xforms(
                batch_size_train,
                rotation_range=rotation_range,
                scaling_range=scaling_range,
                order=setting.rotation_order)
            sess.run(reset_metrics_op)
            sess.run(
                [
                    train_op, loss_mean_update_op, t_1_acc_update_op,
                    t_1_per_class_acc_update_op
                ],
                feed_dict={
                    handle:
                    handle_train,
                    indices:
                    pf.get_indices(batch_size_train, sample_num_train,
                                   point_num, pool_setting_train),
                    xforms:
                    xforms_np,
                    rotations:
                    rotations_np,
                    jitter_range:
                    np.array([jitter]),
                    is_training:
                    True,
                })
            if batch_idx_train % 10 == 0:
                loss, t_1_acc, t_1_per_class_acc, summaries, step = sess.run([
                    loss_mean_op, t_1_acc_op, t_1_per_class_acc_op,
                    summaries_op, global_step
                ])
                summary_writer.add_summary(summaries, step)
                print(
                    '{}-[Train]-Iter: {:06d}  Loss: {:.4f}  T-1 Acc: {:.4f}  T-1 mAcc: {:.4f}'
                    .format(datetime.now(), step, loss, t_1_acc,
                            t_1_per_class_acc))
                sys.stdout.flush()
            ######################################################################
        print('{}-Done!'.format(datetime.now()))
Example #19
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--filelist',
                        '-t',
                        help='Path to training set ground truth (.txt)',
                        required=True)
    parser.add_argument('--filelist_val',
                        '-v',
                        help='Path to validation set ground truth (.txt)',
                        required=True)
    parser.add_argument('--load_ckpt',
                        '-l',
                        help='Path to a check point file for load')
    parser.add_argument(
        '--save_folder',
        '-s',
        help='Path to folder for saving check points and summary',
        required=True)
    parser.add_argument('--model', '-m', help='Model to use', required=True)
    parser.add_argument('--setting',
                        '-x',
                        help='Setting to use',
                        required=True)

    parser.add_argument('--startpoint',
                        '-b',
                        help='Start point to use',
                        required=True)

    args = parser.parse_args()

    time_string = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    root_folder = os.path.join(
        args.save_folder,
        '%s_%s_%s_%d' % (args.model, args.setting, time_string, os.getpid()))
    if not os.path.exists(root_folder):
        os.makedirs(root_folder)

    #sys.stdout = open(os.path.join(root_folder, 'log.txt'), 'w')

    print('PID:', os.getpid())

    print(args)

    model = importlib.import_module(args.model)
    setting_path = os.path.join(os.path.dirname(__file__), args.model)
    sys.path.append(setting_path)
    setting = importlib.import_module(args.setting)

    num_epochs = setting.num_epochs
    batch_size = setting.batch_size
    sample_num = setting.sample_num
    step_val = setting.step_val
    label_weights_list = setting.label_weights
    label_weights_val = [1.0] * 1 + [1.0] * (setting.num_class - 1)
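    # Note: these validation label weights are all 1.0, i.e. every class is weighted equally.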

    rotation_range = setting.rotation_range
    rotation_range_val = setting.rotation_range_val
    scaling_range = setting.scaling_range
    scaling_range_val = setting.scaling_range_val
    jitter = setting.jitter
    jitter_val = setting.jitter_val

    # Prepare inputs
    print('{}-Preparing datasets...'.format(datetime.now()))
    is_list_of_h5_list = not data_utils.is_h5_list(args.filelist)
    if is_list_of_h5_list:
        seg_list = data_utils.load_seg_list(args.filelist)
        seg_list_idx = 0
        filelist_train = seg_list[seg_list_idx]
        seg_list_idx = seg_list_idx + 1
    else:
        filelist_train = args.filelist
    #data_train, _, data_num_train, label_train, _ = data_utils.load_seg(filelist_train)
    data_val, _, data_num_val, label_val, _ = data_utils.load_seg(
        args.filelist_val)

    data_train = data_val
    data_num_train = data_num_val
    label_train = label_val

    # shuffle
    data_train, data_num_train, label_train = \
        data_utils.grouped_shuffle([data_train, data_num_train, label_train])

    num_train = data_train.shape[0]
    point_num = data_train.shape[1]
    num_val = data_val.shape[0]
    print('{}-{:d}/{:d} training/validation samples.'.format(
        datetime.now(), num_train, num_val))
    batch_num = (num_train * num_epochs + batch_size - 1) // batch_size
    print('{}-{:d} training batches.'.format(datetime.now(), batch_num))
    batch_num_val = int(math.ceil(num_val / batch_size))
    print('{}-{:d} testing batches per test.'.format(datetime.now(),
                                                     batch_num_val))

    ######################

    folder_summary = os.path.join(root_folder, 'summary')
    if not os.path.exists(folder_summary):
        os.makedirs(folder_summary)

    ######################################################################
    # Placeholders
    indices = tf.placeholder(tf.int32, shape=(None, None, 2), name="indices")
    xforms = tf.placeholder(tf.float32, shape=(None, 3, 3), name="xforms")
    rotations = tf.placeholder(tf.float32,
                               shape=(None, 3, 3),
                               name="rotations")
    jitter_range = tf.placeholder(tf.float32, shape=(1), name="jitter_range")
    global_step = tf.Variable(0, trainable=False, name='global_step')
    is_training = tf.placeholder(tf.bool, name='is_training')

    pts_fts = tf.placeholder(tf.float32,
                             shape=(None, point_num, setting.data_dim),
                             name='pts_fts')
    labels_seg = tf.placeholder(tf.int64,
                                shape=(None, point_num),
                                name='labels_seg')
    labels_weights = tf.placeholder(tf.float32,
                                    shape=(None, point_num),
                                    name='labels_weights')

    ######################################################################
    pts_fts_sampled = tf.gather_nd(pts_fts,
                                   indices=indices,
                                   name='pts_fts_sampled')
    features_augmented = None
    if setting.data_dim > 3:
        points_sampled, features_sampled = tf.split(
            pts_fts_sampled, [3, setting.data_dim - 3],
            axis=-1,
            name='split_points_features')
        if setting.use_extra_features:
            if setting.with_normal_feature:
                if setting.data_dim < 6:
                    print('Only 3D normals are supported!')
                    exit()
                elif setting.data_dim == 6:
                    features_augmented = pf.augment(features_sampled,
                                                    rotations)
                else:
                    normals, rest = tf.split(features_sampled,
                                             [3, setting.data_dim - 6])
                    normals_augmented = pf.augment(normals, rotations)
                    features_augmented = tf.concat([normals_augmented, rest],
                                                   axis=-1)
            else:
                features_augmented = features_sampled
    else:
        points_sampled = pts_fts_sampled
    points_augmented = pf.augment(points_sampled, xforms, jitter_range)

    labels_sampled = tf.gather_nd(labels_seg,
                                  indices=indices,
                                  name='labels_sampled')
    labels_weights_sampled = tf.gather_nd(labels_weights,
                                          indices=indices,
                                          name='labels_weight_sampled')

    net = model.Net(points_augmented, features_augmented, is_training, setting)
    logits = net.logits
    probs = tf.nn.softmax(logits, name='probs')
    predictions = tf.argmax(probs, axis=-1, name='predictions')

    loss_op = tf.losses.sparse_softmax_cross_entropy(
        labels=labels_sampled, logits=logits, weights=labels_weights_sampled)

    with tf.name_scope('metrics'):
        loss_mean_op, loss_mean_update_op = tf.metrics.mean(loss_op)
        t_1_acc_op, t_1_acc_update_op = tf.metrics.accuracy(
            labels_sampled, predictions, weights=labels_weights_sampled)
        t_1_per_class_acc_op, t_1_per_class_acc_update_op = \
            tf.metrics.mean_per_class_accuracy(labels_sampled, predictions, setting.num_class,
                                               weights=labels_weights_sampled)
        t_1_per_mean_iou_op, t_1_per_mean_iou_op_update_op = \
            tf.metrics.mean_iou(labels_sampled, predictions, setting.num_class,
                                               weights=labels_weights_sampled)

    reset_metrics_op = tf.variables_initializer([
        var for var in tf.local_variables()
        if var.name.split('/')[0] == 'metrics'
    ])

    _ = tf.summary.scalar('loss/train',
                          tensor=loss_mean_op,
                          collections=['train'])
    _ = tf.summary.scalar('t_1_acc/train',
                          tensor=t_1_acc_op,
                          collections=['train'])
    _ = tf.summary.scalar('t_1_per_class_acc/train',
                          tensor=t_1_per_class_acc_op,
                          collections=['train'])

    _ = tf.summary.scalar('loss/val', tensor=loss_mean_op, collections=['val'])
    _ = tf.summary.scalar('t_1_acc/val',
                          tensor=t_1_acc_op,
                          collections=['val'])
    _ = tf.summary.scalar('t_1_per_class_acc/val',
                          tensor=t_1_per_class_acc_op,
                          collections=['val'])
    _ = tf.summary.scalar('t_1_mean_iou/val',
                          tensor=t_1_per_mean_iou_op,
                          collections=['val'])

    #_ = tf.summary.histogram('summary/Add_F2', Add_F2, collections=['summary_values'])
    #_ = tf.summary.histogram('summary/Add_F3', Add_F3, collections=['summary_values'])
    #_ = tf.summary.histogram('summary/Z', Z, collections=['summary_values'])

    lr_exp_op = tf.train.exponential_decay(setting.learning_rate_base,
                                           global_step,
                                           setting.decay_steps,
                                           setting.decay_rate,
                                           staircase=True)
    lr_clip_op = tf.maximum(lr_exp_op, setting.learning_rate_min)
    _ = tf.summary.scalar('learning_rate',
                          tensor=lr_clip_op,
                          collections=['train'])
    reg_loss = setting.weight_decay * tf.losses.get_regularization_loss()
    if setting.optimizer == 'adam':
        optimizer = tf.train.AdamOptimizer(learning_rate=lr_clip_op,
                                           epsilon=setting.epsilon)
    elif setting.optimizer == 'momentum':
        optimizer = tf.train.MomentumOptimizer(learning_rate=lr_clip_op,
                                               momentum=setting.momentum,
                                               use_nesterov=True)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = optimizer.minimize(loss_op + reg_loss,
                                      global_step=global_step)

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())

    saver = tf.train.Saver(max_to_keep=5)

    saver_best = tf.train.Saver(max_to_keep=5)
    # backup all code
    code_folder = os.path.abspath(os.path.dirname(__file__))
    shutil.copytree(code_folder,
                    os.path.join(root_folder, os.path.basename(code_folder)))

    folder_ckpt = os.path.join(root_folder, 'ckpts')
    if not os.path.exists(folder_ckpt):
        os.makedirs(folder_ckpt)

    folder_ckpt_best = os.path.join(root_folder, 'ckpt-best')
    if not os.path.exists(folder_ckpt_best):
        os.makedirs(folder_ckpt_best)

    folder_summary = os.path.join(root_folder, 'summary')
    if not os.path.exists(folder_summary):
        os.makedirs(folder_summary)

    folder_cm_matrix = os.path.join(root_folder, 'cm-matrix')
    if not os.path.exists(folder_cm_matrix):
        os.makedirs(folder_cm_matrix)

    parameter_num = np.sum(
        [np.prod(v.shape.as_list()) for v in tf.trainable_variables()])
    print('{}-Parameter number: {:d}.'.format(datetime.now(), parameter_num))

    _highest_val = 0.0
    max_val = 10
    with tf.Session() as sess:
        summaries_op = tf.summary.merge_all('train')
        summaries_val_op = tf.summary.merge_all('val')
        summary_values_op = tf.summary.merge_all('summary_values')
        summary_writer = tf.summary.FileWriter(folder_summary, sess.graph)

        #img_d_summary_writer = tf.summary.FileWriter(folder_cm_matrix, sess.graph)

        sess.run(init_op)

        # Load the model
        if args.load_ckpt is not None:
            saver.restore(sess, args.load_ckpt)
            print('{}-Checkpoint loaded from {}!'.format(
                datetime.now(), args.load_ckpt))
        batch_num = 1
        batch_idx_train = 1

        ######################################################################
        # Validation
        filename_ckpt = os.path.join(folder_ckpt, 'iter')
        saver.save(sess, filename_ckpt, global_step=global_step)
        print('{}-Checkpoint saved to {}!'.format(datetime.now(),
                                                  filename_ckpt))

        sess.run(reset_metrics_op)
        summary_hist = None
        _idxVal = np.arange(num_val)
        np.random.shuffle(_idxVal)

        _dataX = []
        _dataY = []
        _dataZ = []
        _dataD = []
        _pred = []
        _label = []
        for batch_val_idx in tqdm(range(batch_num_val)):
            start_idx = batch_size * batch_val_idx
            end_idx = min(start_idx + batch_size, num_val)
            batch_size_val = end_idx - start_idx

            points_batch = data_val[_idxVal[start_idx:end_idx], ...]
            points_num_batch = data_num_val[_idxVal[start_idx:end_idx], ...]
            labels_batch = label_val[_idxVal[start_idx:end_idx], ...]

            weights_batch = np.array(label_weights_val)[labels_batch]

            xforms_np, rotations_np = pf.get_xforms(
                batch_size_val,
                rotation_range=rotation_range_val,
                scaling_range=scaling_range_val,
                order=setting.rotation_order)

            _labels_sampled, _predictions, _, _, _, _ = sess.run(
                [
                    labels_sampled, predictions, loss_mean_update_op,
                    t_1_acc_update_op, t_1_per_class_acc_update_op,
                    t_1_per_mean_iou_op_update_op
                ],
                feed_dict={
                    pts_fts:
                    points_batch,
                    indices:
                    pf.get_indices(batch_size_val, sample_num,
                                   points_num_batch),
                    xforms:
                    xforms_np,
                    rotations:
                    rotations_np,
                    jitter_range:
                    np.array([jitter_val]),
                    labels_seg:
                    labels_batch,
                    labels_weights:
                    weights_batch,
                    is_training:
                    False,
                })
            _dataX.append(points_batch[:, :, 0].flatten())
            _dataY.append(points_batch[:, :, 1].flatten())
            _dataZ.append(points_batch[:, :, 2].flatten())
            _pred.append(_predictions.flatten())
            _label.append(_labels_sampled.flatten())

        loss_val, t_1_acc_val, t_1_per_class_acc_val, t1__mean_iou, summaries_val = sess.run(
            [
                loss_mean_op, t_1_acc_op, t_1_per_class_acc_op,
                t_1_per_mean_iou_op, summaries_val_op
            ])

        img_d_summary = pf.plot_confusion_matrix(
            _label,
            _pred, ["Environment", "Pedestrian", "Car", "Cyclist"],
            tensor_name='confusion_matrix',
            normalize=False)

        _dataX = np.concatenate(_dataX, axis=0).flatten()
        _dataY = np.concatenate(_dataY, axis=0).flatten()
        _dataZ = np.concatenate(_dataZ, axis=0).flatten()
        correct_labels = np.concatenate(_label, axis=0).flatten()
        predict_labels = np.concatenate(_pred, axis=0).flatten()
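        # Save the flattened XYZ coordinates together with ground-truth and predicted labels to an
        # .h5 file so the per-point results can be inspected or re-plotted offline.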

        filename_pred = args.filelist_val + '_cm.h5'
        print('{}-Saving {}...'.format(datetime.now(), filename_pred))
        file = h5py.File(filename_pred, 'w')
        file.create_dataset('dataX', data=_dataX)
        file.create_dataset('dataY', data=_dataY)
        file.create_dataset('dataZ', data=_dataZ)
        file.create_dataset('correct_labels', data=correct_labels)
        file.create_dataset('predict_labels', data=predict_labels)
        file.close()

        summary_writer.add_summary(img_d_summary, batch_idx_train)
        summary_writer.add_summary(summaries_val, batch_idx_train)

        y_test = correct_labels
        prediction = predict_labels

        print('Accuracy:', accuracy_score(y_test, prediction))
        print('F1 score:', f1_score(y_test, prediction, average=None))
        print('Recall:', recall_score(y_test, prediction, average=None))
        print('Precision:', precision_score(y_test, prediction, average=None))
        print('\n classification report:\n',
              classification_report(y_test, prediction))
        print('\n confusion matrix:\n', confusion_matrix(y_test, prediction))

        print(
            '{}-[Val  ]-Average:      Loss: {:.4f}  T-1 Acc: {:.4f}  Diff-Best: {:.4f} T-1 mAcc: {:.4f}   T-1 mIoU: {:.4f}'
            .format(datetime.now(), loss_val, t_1_acc_val,
                    _highest_val - t_1_per_class_acc_val,
                    t_1_per_class_acc_val, t1__mean_iou))
        sys.stdout.flush()
Example #20
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--path', '-t', help='Path to data', required=True)
    parser.add_argument('--path_val', '-v', help='Path to validation data')
    parser.add_argument('--load_ckpt', '-l', help='Path to a check point file for load')
    parser.add_argument('--model', '-m', help='Model to use', required=True)
    parser.add_argument('--setting', '-x', help='Setting to use', required=True)
    parser.add_argument('--train_name', '-n', help='train name')
    parser.add_argument('--save_folder_chenzhixing_original', '-s', help='Path to folder for saving check points and summary', required=True)
    args = parser.parse_args()
    save_folder_chenzhixing = args.save_folder_chenzhixing_original

    time_string = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    root_folder = os.path.join(save_folder_chenzhixing, '%s_%s_%s' % (args.model, args.setting, args.train_name))
    if not os.path.exists(root_folder):
        os.makedirs(root_folder)

    sys.stdout = open(os.path.join(root_folder, 'log.txt'), 'w')

    print('PID:', os.getpid())

    print(args)

    model = importlib.import_module(args.model)
    setting_path = os.path.join(os.path.dirname(__file__), args.model)
    sys.path.append(setting_path)
    setting = importlib.import_module(args.setting)

    num_epochs = setting.num_epochs
    batch_size = setting.batch_size
    sample_num = setting.sample_num
    step_val = setting.step_val
    num_class = setting.num_class
    rotation_range = setting.rotation_range
    rotation_range_val = setting.rotation_range_val
    jitter = setting.jitter
    jitter_val = setting.jitter_val

    # Prepare inputs
    print('{}-Preparing datasets...'.format(datetime.now()))
    data_train, label_train, data_val, label_val = setting.load_fn(args.path, args.path_val)

    if setting.save_ply_fn is not None:
        folder = os.path.join(root_folder, 'pts')
        print('{}-Saving samples as .ply files to {}...'.format(datetime.now(), folder))
        sample_num_for_ply = min(512, data_train.shape[0])
        if setting.map_fn is None:
            data_sample = data_train[:sample_num_for_ply]
        else:
            data_sample_list = []
            for idx in range(sample_num_for_ply):
                data_sample_list.append(setting.map_fn(data_train[idx], 0)[0])
            data_sample = np.stack(data_sample_list)
        setting.save_ply_fn(data_sample, folder)

    num_train = data_train.shape[0]
    point_num = data_train.shape[1]
    num_val = data_val.shape[0]
    print('{}-{:d}/{:d} training/validation samples.'.format(datetime.now(), num_train, num_val))

    ######################################################################
    # Placeholders
    indices = tf.placeholder(tf.int32, shape=(None, None, 2), name="indices")
    xforms = tf.placeholder(tf.float32, shape=(None, 3, 3), name="xforms")
    rotations = tf.placeholder(tf.float32, shape=(None, 3, 3), name="rotations")
    jitter_range = tf.placeholder(tf.float32, shape=(1), name="jitter_range")
    global_step = tf.Variable(0, trainable=False, name='global_step')
    is_training = tf.placeholder(tf.bool, name='is_training')

    data_train_placeholder = tf.placeholder(data_train.dtype, data_train.shape)
    label_train_placeholder = tf.placeholder(label_train.dtype, label_train.shape)
    data_val_placeholder = tf.placeholder(data_val.dtype, data_val.shape)
    label_val_placeholder = tf.placeholder(label_val.dtype, label_val.shape)
    handle = tf.placeholder(tf.string, shape=[])

    ######################################################################
    dataset_train = tf.data.Dataset.from_tensor_slices((data_train_placeholder, label_train_placeholder))
    if setting.map_fn is not None:
        dataset_train = dataset_train.map(lambda data, label: tuple(tf.py_func(
            setting.map_fn, [data, label], [tf.float32, label.dtype])), num_parallel_calls=setting.num_parallel_calls)
    dataset_train = dataset_train.shuffle(buffer_size=batch_size * 4)

    if setting.keep_remainder:
        dataset_train = dataset_train.batch(batch_size)
        batch_num_per_epoch = math.ceil(num_train / batch_size)
    else:
        dataset_train = dataset_train.apply(tf.contrib.data.batch_and_drop_remainder(batch_size))
        batch_num_per_epoch = math.floor(num_train / batch_size)
    batch_num = batch_num_per_epoch * num_epochs
    print('{}-{:d} training batches.'.format(datetime.now(), batch_num))

    dataset_train = dataset_train.repeat()
    iterator_train = dataset_train.make_initializable_iterator()

    dataset_val = tf.data.Dataset.from_tensor_slices((data_val_placeholder, label_val_placeholder))
    if setting.map_fn is not None:
        dataset_val = dataset_val.map(lambda data, label: tuple(tf.py_func(
            setting.map_fn, [data, label], [tf.float32, label.dtype])), num_parallel_calls=setting.num_parallel_calls)
    if setting.keep_remainder:
        dataset_val = dataset_val.batch(batch_size)
        batch_num_val = math.ceil(num_val / batch_size)
    else:
        dataset_val = dataset_val.apply(tf.contrib.data.batch_and_drop_remainder(batch_size))
        batch_num_val = math.floor(num_val / batch_size)
    iterator_val = dataset_val.make_initializable_iterator()

    iterator = tf.data.Iterator.from_string_handle(handle, dataset_train.output_types, dataset_train.output_shapes)
    (pts_fts, labels) = iterator.get_next()

    features_augmented = None
    if setting.data_dim > 3:
        points, features = tf.split(pts_fts, [3, setting.data_dim - 3], axis=-1, name='split_points_features')
        if setting.use_extra_features:
            features_sampled = tf.gather_nd(features, indices=indices, name='features_sampled')
            if setting.with_normal_feature:
                features_augmented = pf.augment(features_sampled, rotations)
            else:
                features_augmented = features_sampled
    else:
        points = pts_fts
    points_sampled = tf.gather_nd(points, indices=indices, name='points_sampled')
    points_augmented = pf.augment(points_sampled, xforms, jitter_range)

    net = model.Net(points=points_augmented, features=features_augmented, num_class=num_class,
                    is_training=is_training, setting=setting)
    logits, probs = net.logits, net.probs
    labels_2d = tf.expand_dims(labels, axis=-1, name='labels_2d')
    labels_tile = tf.tile(labels_2d, (1, tf.shape(probs)[1]), name='labels_tile')

    loss_op = tf.losses.sparse_softmax_cross_entropy(labels=labels_tile, logits=logits)
    t_1_acc_op = pf.top_1_accuracy(probs, labels_tile)

    _ = tf.summary.scalar('loss/train', tensor=loss_op, collections=['train'])
    _ = tf.summary.scalar('t_1_acc/train', tensor=t_1_acc_op, collections=['train'])

    loss_val_avg = tf.placeholder(tf.float32)
    t_1_acc_val_avg = tf.placeholder(tf.float32)
    _ = tf.summary.scalar('loss/val', tensor=loss_val_avg, collections=['val'])
    _ = tf.summary.scalar('t_1_acc/val', tensor=t_1_acc_val_avg, collections=['val'])

    lr_exp_op = tf.train.exponential_decay(setting.learning_rate_base, global_step, setting.decay_steps,
                                           setting.decay_rate, staircase=True)
    lr_clip_op = tf.maximum(lr_exp_op, setting.learning_rate_min)
    _ = tf.summary.scalar('learning_rate', tensor=lr_clip_op, collections=['train'])
    reg_loss = setting.weight_decay * tf.losses.get_regularization_loss()
    if setting.optimizer == 'adam':
        optimizer = tf.train.AdamOptimizer(learning_rate=lr_clip_op, epsilon=setting.epsilon)
    elif setting.optimizer == 'momentum':
        optimizer = tf.train.MomentumOptimizer(learning_rate=lr_clip_op, momentum=0.9, use_nesterov=True)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = optimizer.minimize(loss_op + reg_loss, global_step=global_step)

    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())

    saver = tf.train.Saver(max_to_keep=None)

    # backup this file, model and setting
    shutil.copy(__file__, os.path.join(root_folder, os.path.basename(__file__)))
    shutil.copy(os.path.join(os.path.dirname(__file__), args.model + '.py'),
                os.path.join(root_folder, args.model + '.py'))
    if not os.path.exists(os.path.join(root_folder, args.model)):
        os.makedirs(os.path.join(root_folder, args.model))
    shutil.copy(os.path.join(setting_path, args.setting + '.py'),
                os.path.join(root_folder, args.model, args.setting + '.py'))

    folder_ckpt = os.path.join(root_folder, 'ckpts')
    if not os.path.exists(folder_ckpt):
        os.makedirs(folder_ckpt)

    folder_summary = os.path.join(root_folder, 'summary')
    if not os.path.exists(folder_summary):
        os.makedirs(folder_summary)

    parameter_num = np.sum([np.prod(v.shape.as_list()) for v in tf.trainable_variables()])
    print('{}-Parameter number: {:d}.'.format(datetime.now(), parameter_num))

    gpu_options = tf.GPUOptions(allow_growth=True)
    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
        summaries_op = tf.summary.merge_all('train')
        summaries_val_op = tf.summary.merge_all('val')
        summary_writer = tf.summary.FileWriter(folder_summary, sess.graph)

        sess.run(init_op)

        # Load the model
        if args.load_ckpt is not None:
            saver.restore(sess, args.load_ckpt)
            print('{}-Checkpoint loaded from {}!'.format(datetime.now(), args.load_ckpt))

        handle_train = sess.run(iterator_train.string_handle())
        handle_val = sess.run(iterator_val.string_handle())

        sess.run(iterator_train.initializer, feed_dict={
            data_train_placeholder: data_train,
            label_train_placeholder: label_train,
        })

        ######################################################################
        # Predict
        outer_predict_acc = 0
        outer_predict_num = 10
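        # Test-time augmentation by voting: each validation pass is repeated predict_num times under
        # fresh random rotations/jitter, and the softmax outputs are averaged before taking the argmax.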
        for opi in range(outer_predict_num):
            predict_acc = 0
            predict_num = 20
            predict_ave_probs = []
            labels_list = []
            total_time = 0
            for pre_i in range(predict_num):
                sess.run(iterator_val.initializer, feed_dict={
                    data_val_placeholder: data_val,
                    label_val_placeholder: label_val,
                })
    
                losses = []
                t_1_accs = []
                pre_i_probs = []
                for batch_idx_val in range(batch_num_val):
                    if not setting.keep_remainder or num_val % batch_size == 0 or batch_idx_val != batch_num_val - 1:
                        batch_size_val = batch_size
                    else:
                        batch_size_val = num_val % batch_size
                    xforms_np, rotations_np = pf.get_xforms(batch_size_val, rotation_range=rotation_range_val,
                                                                order=setting.order)
                    time1 = time.time()
                    _, loss_val, t_1_acc_val, predict_probs, true_labels = \
                        sess.run([update_ops, loss_op, t_1_acc_op, probs, labels],
                                 feed_dict={
                                     handle: handle_val,
                                     indices: pf.get_indices(batch_size_val, sample_num, point_num),#, False),
                                     xforms: xforms_np,
                                     rotations: rotations_np,
                                     jitter_range: np.array([jitter_val]),
                                     is_training: False,
                                 })
                    time2 = time.time()
                    total_time += time2 - time1
                    losses.append(loss_val * batch_size_val)
                    t_1_accs.append(t_1_acc_val * batch_size_val)
                    pre_i_probs.append(predict_probs)
                    if (pre_i == 0):
                        labels_list.append(true_labels)
                    #print('{}-[Val  ]-Iter: {:06d}  Loss: {:.4f}  T-1 Acc: {:.4f}'.format
                    #      (datetime.now(), batch_idx_val, loss_val, t_1_acc_val))
                    #sys.stdout.flush()
                predict_ave_probs.append(pre_i_probs)
    
                #loss_avg = sum(losses) / num_val
                #t_1_acc_avg = sum(t_1_accs) / num_val
                #print('{}-[Val  ]-Average:      Loss: {:.4f}  T-1 Acc: {:.4f}'
                #      .format(datetime.now(), loss_avg, t_1_acc_avg))
                #sys.stdout.flush()
                #predict_acc += t_1_acc_avg
            #predict_acc /= predict_num
            #print('{}-[Mean ]-Average:      T-1 Acc: {:.4f}'
            #      .format(datetime.now(), predict_acc))
            predict_ave_probs = np.argmax(np.mean(np.squeeze(np.array(predict_ave_probs)), axis=0), axis=1)
            labels_list = np.squeeze(np.array(labels_list))
            print('predict:')
            print(predict_ave_probs)
            print('label:')
            print(labels_list)
            true_num = np.count_nonzero(predict_ave_probs == labels_list)
            total_num = labels_list.shape[0]
            outer_predict_acc += true_num/total_num*100
            print('Acc: %.2f (%d/%d)' % (true_num/total_num*100, true_num, total_num))
            print('average time: %f' % (total_time/(predict_num*batch_num_val)))
            sys.stdout.flush()
        outer_predict_acc /= outer_predict_num
        print('\n Average Acc: %.2f' % outer_predict_acc)
        sys.stdout.flush()
        exit()
        ######################################################################

        best_acc = -np.inf
        for batch_idx_train in range(batch_num):
            ######################################################################
            # Validation
            if (batch_idx_train != 0 and batch_idx_train % step_val == 0) or batch_idx_train == batch_num - 1:

                filename_ckpt = os.path.join(folder_ckpt, 'last_model')
                saver.save(sess, filename_ckpt, global_step=None)
                print('{}-Checkpoint saved to {}!'.format(datetime.now(), filename_ckpt))

                predict_num = 20
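                # Test-time augmentation: run predict_num randomly transformed passes over
                # the validation set and average the predicted probabilities per sample.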
                predict_ave_probs = []
                labels_list = []
                losses = []
                for pre_i in range(predict_num):
                    sess.run(iterator_val.initializer, feed_dict={
                        data_val_placeholder: data_val,
                        label_val_placeholder: label_val,
                    })
        
                    pre_i_probs = []
                    for batch_idx_val in range(batch_num_val):
                        if not setting.keep_remainder or num_val % batch_size == 0 or batch_idx_val != batch_num_val - 1:
                            batch_size_val = batch_size
                        else:
                            batch_size_val = num_val % batch_size
                        xforms_np, rotations_np = pf.get_xforms(batch_size_val, rotation_range=rotation_range_val,
                                                                    order=setting.order)
                        _, loss_val, t_1_acc_val, predict_probs, true_labels = \
                            sess.run([update_ops, loss_op, t_1_acc_op, probs, labels],
                                     feed_dict={
                                         handle: handle_val,
                                         indices: pf.get_indices(batch_size_val, sample_num, point_num),#, False),
                                         xforms: xforms_np,
                                         rotations: rotations_np,
                                         jitter_range: np.array([jitter_val]),
                                         is_training: False,
                                     })
                        losses.append(loss_val * batch_size_val)
                        pre_i_probs.append(predict_probs)
                        if pre_i == 0:
                            labels_list.append(true_labels)
                    predict_ave_probs.append(pre_i_probs)
        
                loss_avg = sum(losses) / (predict_num * num_val)
                predict_ave_probs = np.argmax(np.mean(np.reshape(np.array(predict_ave_probs), (predict_num, num_val, num_class)), axis=0), axis=1)
                labels_list = np.reshape(np.array(labels_list), (num_val,))
                true_num = np.count_nonzero(predict_ave_probs == labels_list)
                t_1_acc_avg = true_num / num_val
                print('{}-[Val  ]-Average:      Loss: {:.4f}  T-1 Acc: {:.4f}'
                      .format(datetime.now(), loss_avg, t_1_acc_avg))
                if t_1_acc_avg > (best_acc - 1e-6):
                    best_acc = t_1_acc_avg
                    print('{}-[Val  ]-best:         Loss: {:.4f}  T-1 Acc: {:.4f}'
                          .format(datetime.now(), loss_avg, t_1_acc_avg))
                    filename_ckpt = os.path.join(folder_ckpt, 'best_model')
                    saver.save(sess, filename_ckpt, global_step=None)
                    print('{}-Checkpoint saved to {}!'.format(datetime.now(), filename_ckpt))
                sys.stdout.flush()

                #############################################################################
                # Original Validation
#                sess.run(iterator_val.initializer, feed_dict={
#                    data_val_placeholder: data_val,
#                    label_val_placeholder: label_val,
#                })
#                filename_ckpt = os.path.join(folder_ckpt, 'last_model')
#                saver.save(sess, filename_ckpt, global_step=None)
#                print('{}-Checkpoint saved to {}!'.format(datetime.now(), filename_ckpt))
#
#                losses = []
#                t_1_accs = []
#                for batch_idx_val in range(batch_num_val):
#                    if not setting.keep_remainder or num_val % batch_size == 0 or batch_idx_val != batch_num_val - 1:
#                        batch_size_val = batch_size
#                    else:
#                        batch_size_val = num_val % batch_size
#                    xforms_np, rotations_np = pf.get_xforms(batch_size_val, rotation_range=rotation_range_val,
#                                                                order=setting.order)
#                    _, loss_val, t_1_acc_val = \
#                        sess.run([update_ops, loss_op, t_1_acc_op],
#                                 feed_dict={
#                                     handle: handle_val,
#                                     indices: pf.get_indices(batch_size_val, sample_num, point_num),#, False),
#                                     xforms: xforms_np,
#                                     rotations: rotations_np,
#                                     jitter_range: np.array([jitter_val]),
#                                     is_training: False,
#                                 })
#                    losses.append(loss_val * batch_size_val)
#                    t_1_accs.append(t_1_acc_val * batch_size_val)
#                    print('{}-[Val  ]-Iter: {:06d}  Loss: {:.4f}  T-1 Acc: {:.4f}'.format
#                          (datetime.now(), batch_idx_val, loss_val, t_1_acc_val))
#                    sys.stdout.flush()
#
#                loss_avg = sum(losses) / num_val
#                t_1_acc_avg = sum(t_1_accs) / num_val
#                summaries_val = sess.run(summaries_val_op,
#                                         feed_dict={
#                                             loss_val_avg: loss_avg,
#                                             t_1_acc_val_avg: t_1_acc_avg,
#                                         })
#                summary_writer.add_summary(summaries_val, batch_idx_train)
#                print('{}-[Val  ]-Average:      Loss: {:.4f}  T-1 Acc: {:.4f}'
#                      .format(datetime.now(), loss_avg, t_1_acc_avg))
#                if t_1_acc_avg > (best_acc-1e-6):
#                  best_acc = t_1_acc_avg
#                  print('{}-[Val  ]-best:         Loss: {:.4f}  T-1 Acc: {:.4f}'
#                      .format(datetime.now(), loss_avg, t_1_acc_avg))
#                  filename_ckpt = os.path.join(folder_ckpt, 'best_model')
#                  saver.save(sess, filename_ckpt, global_step=None)
#                  print('{}-Checkpoint saved to {}!'.format(datetime.now(), filename_ckpt))
#                sys.stdout.flush()
            ######################################################################

            ######################################################################
            # Training
            if not setting.keep_remainder or num_train % batch_size == 0 or (batch_idx_train % batch_num_per_epoch) != (batch_num_per_epoch - 1):
                batch_size_train = batch_size
            else:
                batch_size_train = num_train % batch_size
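            # Jitter the number of sampled points per training batch: a Gaussian offset
            # (std = sample_num // 8) clipped to +/- sample_num // 4.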
            offset = int(random.gauss(0, sample_num // 8))
            offset = max(offset, -sample_num // 4)
            offset = min(offset, sample_num // 4)
            sample_num_train = sample_num + offset
            xforms_np, rotations_np = pf.get_xforms(batch_size_train, rotation_range=rotation_range,
                                                        order=setting.order)
            _, loss, t_1_acc, summaries = \
                sess.run([train_op, loss_op, t_1_acc_op, summaries_op],
                         feed_dict={
                             handle: handle_train,
                             indices: pf.get_indices(batch_size_train, sample_num_train, point_num),
                             xforms: xforms_np,
                             rotations: rotations_np,
                             jitter_range: np.array([jitter]),
                             is_training: True,
                         })
            summary_writer.add_summary(summaries, batch_idx_train)
            print('{}-[Train]-Iter: {:06d}  Loss: {:.4f}  T-1 Acc: {:.4f}'
                  .format(datetime.now(), batch_idx_train, loss, t_1_acc))
            sys.stdout.flush()
            ######################################################################
        print('{}-Done!'.format(datetime.now()))


def main():
    args = AttrDict()
    # Path of training set ground truth file list (.txt)
    args.filelist = os.path.join(SCENENN_DIR, 'train_files.txt')
    # Path of validation set ground truth file list (.txt)
    args.filelist_val = os.path.join(SCENENN_DIR, 'test_files.txt')
    # Path of a check point file to load
    args.load_ckpt = os.path.join(ROOT_DIR, '..', 'models',
                                  'pretrained_scannet', 'ckpts', 'iter-354000')
    # Base directory where model checkpoint and summary files get saved in separate subdirectories
    args.save_folder = os.path.join(ROOT_DIR, '..', 'models')
    # PointCNN model to use
    args.model = 'pointcnn_seg'
    # Model setting to use
    args.setting = 'scenenn_x8_2048_fps'

    time_string = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    model_save_folder = os.path.join(
        args.save_folder,
        '%s_%s_%s_%d' % (args.model, args.setting, time_string, os.getpid()))
    if not os.path.exists(model_save_folder):
        os.makedirs(model_save_folder)

    # sys.stdout = open(os.path.join(model_save_folder, 'log.txt'), 'w')

    print('PID:', os.getpid())
    print(args)

    model = importlib.import_module(args.model)
    setting_path = os.path.join(ROOT_DIR, args.model)
    sys.path.append(setting_path)
    setting = importlib.import_module(args.setting)
    num_epochs = setting.num_epochs
    batch_size = setting.batch_size
    sample_num = setting.sample_num
    step_val = setting.step_val
    label_weights_list = setting.label_weights
    rotation_range = setting.rotation_range
    rotation_range_val = setting.rotation_range_val
    scaling_range = setting.scaling_range
    scaling_range_val = setting.scaling_range_val
    jitter = setting.jitter
    jitter_val = setting.jitter_val

    # Prepare inputs
    print('{}-Preparing datasets...'.format(datetime.now()))
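    # args.filelist may either list .h5 files directly or list per-chunk file lists;
    # in the latter case the chunks are loaded one at a time and cycled through during
    # training (see the chunk-boundary reload inside the training loop).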
    is_list_of_h5_list = not data_utils.is_h5_list(args.filelist)
    if is_list_of_h5_list:
        seg_list = data_utils.load_seg_list(args.filelist)
        seg_list_idx = 0
        filelist_train = seg_list[seg_list_idx]
        seg_list_idx = seg_list_idx + 1
    else:
        filelist_train = args.filelist
    data_train, _, data_num_train, label_train, _ = data_utils.load_seg(
        filelist_train)
    data_val, _, data_num_val, label_val, _ = data_utils.load_seg(
        args.filelist_val)

    # shuffle
    data_train, data_num_train, label_train = \
        data_utils.grouped_shuffle([data_train, data_num_train, label_train])

    num_train = data_train.shape[0]
    point_num = data_train.shape[1]
    num_val = data_val.shape[0]
    print('{}-{:d}/{:d} training/validation samples.'.format(
        datetime.now(), num_train, num_val))
    batch_num = (num_train * num_epochs + batch_size - 1) // batch_size
    print('{}-{:d} training batches.'.format(datetime.now(), batch_num))
    batch_num_val = math.ceil(num_val / batch_size)
    print('{}-{:d} testing batches per test.'.format(datetime.now(),
                                                     batch_num_val))

    ######################################################################
    # Placeholders
    print('{}-Initializing TF-placeholders...'.format(datetime.now()))
    indices = tf.placeholder(tf.int32, shape=(None, None, 2), name="indices")
    xforms = tf.placeholder(tf.float32, shape=(None, 3, 3), name="xforms")
    rotations = tf.placeholder(tf.float32,
                               shape=(None, 3, 3),
                               name="rotations")
    jitter_range = tf.placeholder(tf.float32, shape=(1), name="jitter_range")
    global_step = tf.Variable(0, trainable=False, name='global_step')
    is_training = tf.placeholder(tf.bool, name='is_training')

    pts_fts = tf.placeholder(tf.float32,
                             shape=(None, point_num, setting.data_dim),
                             name='pts_fts')
    labels_seg = tf.placeholder(tf.int64,
                                shape=(None, point_num),
                                name='labels_seg')
    labels_weights = tf.placeholder(tf.float32,
                                    shape=(None, point_num),
                                    name='labels_weights')

    ######################################################################
    pts_fts_sampled = tf.gather_nd(pts_fts,
                                   indices=indices,
                                   name='pts_fts_sampled')
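    # 'indices' (produced by pf.get_indices) selects sample_num points per cloud;
    # gather_nd extracts the corresponding point/feature rows and, below, the matching
    # per-point labels and loss weights.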
    features_augmented = None
    if setting.data_dim > 3:
        points_sampled, features_sampled = tf.split(
            pts_fts_sampled, [3, setting.data_dim - 3],
            axis=-1,
            name='split_points_features')
        if setting.use_extra_features:
            if setting.with_normal_feature:
                if setting.data_dim < 6:
                    print('Only 3D normals are supported!')
                    exit()
                elif setting.data_dim == 6:
                    features_augmented = pf.augment(features_sampled,
                                                    rotations)
                else:
                    normals, rest = tf.split(features_sampled,
                                             [3, setting.data_dim - 6],
                                             axis=-1)
                    normals_augmented = pf.augment(normals, rotations)
                    features_augmented = tf.concat([normals_augmented, rest],
                                                   axis=-1)
            else:
                features_augmented = features_sampled
    else:
        points_sampled = pts_fts_sampled
    points_augmented = pf.augment(points_sampled, xforms, jitter_range)

    labels_sampled = tf.gather_nd(labels_seg,
                                  indices=indices,
                                  name='labels_sampled')
    labels_weights_sampled = tf.gather_nd(labels_weights,
                                          indices=indices,
                                          name='labels_weight_sampled')

    print('{}-Initializing net...'.format(datetime.now()))
    net = model.Net(points_augmented, features_augmented, is_training, setting)
    logits = net.logits
    probs = tf.nn.softmax(logits, name='probs')
    predictions = tf.argmax(probs, axis=-1, name='predictions')

    loss_op = tf.losses.sparse_softmax_cross_entropy(
        labels=labels_sampled, logits=logits, weights=labels_weights_sampled)

    with tf.name_scope('metrics'):
        loss_mean_op, loss_mean_update_op = tf.metrics.mean(loss_op)
        t_1_acc_op, t_1_acc_update_op = tf.metrics.accuracy(
            labels_sampled, predictions, weights=labels_weights_sampled)
        t_1_per_class_acc_op, t_1_per_class_acc_update_op = \
            tf.metrics.mean_per_class_accuracy(labels_sampled, predictions, setting.num_class,
                                               weights=labels_weights_sampled)
    reset_metrics_op = tf.variables_initializer([
        var for var in tf.local_variables()
        if var.name.split('/')[0] == 'metrics'
    ])
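    # tf.metrics.* ops are streaming: the update ops accumulate counters in local
    # variables under the 'metrics' scope, and reset_metrics_op clears them before each
    # training step and before each validation pass.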

    _ = tf.summary.scalar('loss/train',
                          tensor=loss_mean_op,
                          collections=['train'])
    _ = tf.summary.scalar('t_1_acc/train',
                          tensor=t_1_acc_op,
                          collections=['train'])
    _ = tf.summary.scalar('t_1_per_class_acc/train',
                          tensor=t_1_per_class_acc_op,
                          collections=['train'])

    _ = tf.summary.scalar('loss/val', tensor=loss_mean_op, collections=['val'])
    _ = tf.summary.scalar('t_1_acc/val',
                          tensor=t_1_acc_op,
                          collections=['val'])
    _ = tf.summary.scalar('t_1_per_class_acc/val',
                          tensor=t_1_per_class_acc_op,
                          collections=['val'])

    lr_exp_op = tf.train.exponential_decay(setting.learning_rate_base,
                                           global_step,
                                           setting.decay_steps,
                                           setting.decay_rate,
                                           staircase=True)
    lr_clip_op = tf.maximum(lr_exp_op, setting.learning_rate_min)
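    # Exponentially decayed learning rate, clipped from below at setting.learning_rate_min.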
    _ = tf.summary.scalar('learning_rate',
                          tensor=lr_clip_op,
                          collections=['train'])
    reg_loss = setting.weight_decay * tf.losses.get_regularization_loss()
    print('{}-Setting up optimizer...'.format(datetime.now()))
    if setting.optimizer == 'adam':
        optimizer = tf.train.AdamOptimizer(learning_rate=lr_clip_op,
                                           epsilon=setting.epsilon)
    elif setting.optimizer == 'momentum':
        optimizer = tf.train.MomentumOptimizer(learning_rate=lr_clip_op,
                                               momentum=setting.momentum,
                                               use_nesterov=True)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
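    # Only the freshly initialized 'logits' variables are optimized here (last-layer
    # fine-tuning on top of the restored backbone); running the step under update_ops
    # keeps batch-normalization moving statistics updated.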
    with tf.control_dependencies(update_ops):
        last_layer_train_vars = tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES, 'logits')
        last_layer_train_op = optimizer.minimize(
            loss_op + reg_loss,
            global_step=global_step,
            var_list=last_layer_train_vars)

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())

    saver = tf.train.Saver(max_to_keep=None)

    variables_to_restore = get_variables_to_restore(exclude=['logits'])
    restorer = tf.train.Saver(var_list=variables_to_restore, max_to_keep=None)
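    # The restorer loads every variable except the 'logits' head, so the pretrained
    # checkpoint supplies the backbone weights while the new head starts from its
    # random initialization and is trained via last_layer_train_op.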

    # backup all code
    # code_folder = os.path.abspath(os.path.dirname(__file__))
    # shutil.copytree(code_folder, os.path.join(model_save_folder, os.path.basename(code_folder)), symlinks=True)

    folder_ckpt = os.path.join(model_save_folder, 'ckpts')
    if not os.path.exists(folder_ckpt):
        os.makedirs(folder_ckpt)

    folder_summary = os.path.join(model_save_folder, 'summary')
    if not os.path.exists(folder_summary):
        os.makedirs(folder_summary)

    parameter_num = np.sum(
        [np.prod(v.shape.as_list()) for v in tf.trainable_variables()])
    print('{}-Number of model parameters: {:d}.'.format(
        datetime.now(), parameter_num))

    with tf.Session() as sess:
        summaries_op = tf.summary.merge_all('train')
        summaries_val_op = tf.summary.merge_all('val')
        summary_writer = tf.summary.FileWriter(folder_summary, sess.graph)

        print('{}-Initializing variables...'.format(datetime.now()))
        sess.run(init_op)

        # Load the model
        if args.load_ckpt is not None:
            print('{}-Loading checkpoint from {}...'.format(
                datetime.now(), args.load_ckpt))
            restorer.restore(sess, args.load_ckpt)
            print('{}-Checkpoint loaded.'.format(datetime.now()))

        for batch_idx_train in tqdm(range(batch_num), ncols=60):
            if (batch_idx_train % step_val == 0 and (batch_idx_train != 0 or args.load_ckpt is not None)) \
                    or batch_idx_train == batch_num - 1:
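                # Validate every step_val batches (including batch 0 when starting from a
                # loaded checkpoint) and once more on the final batch.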
                ######################################################################
                # Validation
                filename_ckpt = os.path.join(folder_ckpt, 'iter')
                saver.save(sess, filename_ckpt, global_step=global_step)
                tqdm.write('{}-Checkpoint saved to {}!'.format(
                    datetime.now(), filename_ckpt))

                sess.run(reset_metrics_op)
                for batch_val_idx in range(batch_num_val):
                    start_idx = batch_size * batch_val_idx
                    end_idx = min(start_idx + batch_size, num_val)
                    batch_size_val = end_idx - start_idx
                    points_batch = data_val[start_idx:end_idx, ...]
                    points_num_batch = data_num_val[start_idx:end_idx, ...]
                    labels_batch = label_val[start_idx:end_idx, ...]
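                    # Per-point loss weights are looked up from label_weights_list using the
                    # ground-truth label of each point.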
                    weights_batch = np.array(label_weights_list)[labels_batch]

                    xforms_np, rotations_np = pf.get_xforms(
                        batch_size_val,
                        rotation_range=rotation_range_val,
                        scaling_range=scaling_range_val,
                        order=setting.rotation_order)
                    sess.run(
                        [
                            loss_mean_update_op, t_1_acc_update_op,
                            t_1_per_class_acc_update_op
                        ],
                        feed_dict={
                            pts_fts:
                            points_batch,
                            indices:
                            pf.get_indices(batch_size_val, sample_num,
                                           points_num_batch),
                            xforms:
                            xforms_np,
                            rotations:
                            rotations_np,
                            jitter_range:
                            np.array([jitter_val]),
                            labels_seg:
                            labels_batch,
                            labels_weights:
                            weights_batch,
                            is_training:
                            False,
                        })

                loss_val, t_1_acc_val, t_1_per_class_acc_val, summaries_val = sess.run(
                    [
                        loss_mean_op, t_1_acc_op, t_1_per_class_acc_op,
                        summaries_val_op
                    ])
                summary_writer.add_summary(summaries_val, batch_idx_train)
                tqdm.write(
                    '{}-[Val  ]-Average:      Loss: {:.4f}  T-1 Acc: {:.4f}  T-1 mAcc: {:.4f}'
                    .format(datetime.now(), loss_val, t_1_acc_val,
                            t_1_per_class_acc_val))
                sys.stdout.flush()
                ######################################################################

            ######################################################################
            # Training
            start_idx = (batch_size * batch_idx_train) % num_train
            end_idx = min(start_idx + batch_size, num_train)
            batch_size_train = end_idx - start_idx
            points_batch = data_train[start_idx:end_idx, ...]
            points_num_batch = data_num_train[start_idx:end_idx, ...]
            labels_batch = label_train[start_idx:end_idx, ...]
            weights_batch = np.array(label_weights_list)[labels_batch]

            if start_idx + batch_size_train == num_train:
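                # End of the current training chunk: advance to the next file list (if the
                # data came as a list of chunk lists), reload it when it changes, and reshuffle.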
                if is_list_of_h5_list:
                    filelist_train_prev = seg_list[(seg_list_idx - 1) %
                                                   len(seg_list)]
                    filelist_train = seg_list[seg_list_idx % len(seg_list)]
                    if filelist_train != filelist_train_prev:
                        data_train, _, data_num_train, label_train, _ = data_utils.load_seg(
                            filelist_train)
                        num_train = data_train.shape[0]
                    seg_list_idx = seg_list_idx + 1
                data_train, data_num_train, label_train = \
                    data_utils.grouped_shuffle([data_train, data_num_train, label_train])

            offset = int(
                random.gauss(0, sample_num * setting.sample_num_variance))
            offset = max(offset, -sample_num * setting.sample_num_clip)
            offset = min(offset, sample_num * setting.sample_num_clip)
            sample_num_train = sample_num + offset
            xforms_np, rotations_np = pf.get_xforms(
                batch_size_train,
                rotation_range=rotation_range,
                scaling_range=scaling_range,
                order=setting.rotation_order)
            sess.run(reset_metrics_op)
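            # Metrics were just reset, so the loss/accuracy values fetched after this step
            # reflect only the current batch.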
            sess.run(
                [
                    last_layer_train_op, loss_mean_update_op,
                    t_1_acc_update_op, t_1_per_class_acc_update_op
                ],
                feed_dict={
                    pts_fts:
                    points_batch,
                    indices:
                    pf.get_indices(batch_size_train, sample_num_train,
                                   points_num_batch),
                    xforms:
                    xforms_np,
                    rotations:
                    rotations_np,
                    jitter_range:
                    np.array([jitter]),
                    labels_seg:
                    labels_batch,
                    labels_weights:
                    weights_batch,
                    is_training:
                    True,
                })
            if batch_idx_train % 10 == 0:
                loss, t_1_acc, t_1_per_class_acc, summaries = sess.run([
                    loss_mean_op, t_1_acc_op, t_1_per_class_acc_op,
                    summaries_op
                ])
                summary_writer.add_summary(summaries, batch_idx_train)
                # tqdm.write('{}-[Train]-Iter: {:06d}  Loss: {:.4f}  T-1 Acc: {:.4f}  T-1 mAcc: {:.4f}'
                #            .format(datetime.now(), batch_idx_train, loss, t_1_acc, t_1_per_class_acc))
                sys.stdout.flush()
            ######################################################################
        print('{}-Done!'.format(datetime.now()))