Example #1
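
# The snippets below assume the original scripts' module-level imports; a minimal sketch
# (the data_utils / pointfly module names are project-local assumptions):
import argparse
import importlib
import math
import os
import random
import shutil
import sys
import time
from datetime import datetime

import numpy as np
import tensorflow as tf

import data_utils      # project-local data loading helpers (assumed name)
import pointfly as pf   # project-local point-cloud ops (assumed name)
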
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--path', '-t', help='Path to data', required=True)
    parser.add_argument('--path_val', '-v', help='Path to validation data')
    parser.add_argument('--load_ckpt', '-l', help='Path to a checkpoint file to load')
    parser.add_argument('--model', '-m', help='Model to use', required=True)
    parser.add_argument('--setting', '-x', help='Setting to use', required=True)
    parser.add_argument('--train_name', '-n', help='train name')
    parser.add_argument('--save_folder_chenzhixing_original', '-s', help='Path to folder for saving checkpoints and summary', required=True)
    args = parser.parse_args()
    save_folder_chenzhixing = args.save_folder_chenzhixing_original
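
    # Example invocation (script name and paths are illustrative placeholders):
    #   python train_val_cls.py -t <train_data> -v <val_data> -m <model_module> \
    #       -x <setting_module> -n <run_name> -s <save_folder>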

    time_string = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    root_folder = os.path.join(save_folder_chenzhixing, '%s_%s_%s' % (args.model, args.setting, args.train_name))
    if not os.path.exists(root_folder):
        os.makedirs(root_folder)

    sys.stdout = open(os.path.join(root_folder, 'log.txt'), 'w')

    print('PID:', os.getpid())

    print(args)

    model = importlib.import_module(args.model)
    setting_path = os.path.join(os.path.dirname(__file__), args.model)
    sys.path.append(setting_path)
    setting = importlib.import_module(args.setting)

    num_epochs = setting.num_epochs
    batch_size = setting.batch_size
    sample_num = setting.sample_num
    step_val = setting.step_val
    num_class = setting.num_class
    rotation_range = setting.rotation_range
    rotation_range_val = setting.rotation_range_val
    jitter = setting.jitter
    jitter_val = setting.jitter_val

    # Prepare inputs
    print('{}-Preparing datasets...'.format(datetime.now()))
    data_train, label_train, data_val, label_val = setting.load_fn(args.path, args.path_val)

    if setting.save_ply_fn is not None:
        folder = os.path.join(root_folder, 'pts')
        print('{}-Saving samples as .ply files to {}...'.format(datetime.now(), folder))
        sample_num_for_ply = min(512, data_train.shape[0])
        if setting.map_fn is None:
            data_sample = data_train[:sample_num_for_ply]
        else:
            data_sample_list = []
            for idx in range(sample_num_for_ply):
                data_sample_list.append(setting.map_fn(data_train[idx], 0)[0])
            data_sample = np.stack(data_sample_list)
        setting.save_ply_fn(data_sample, folder)

    num_train = data_train.shape[0]
    point_num = data_train.shape[1]
    num_val = data_val.shape[0]
    print('{}-{:d}/{:d} training/validation samples.'.format(datetime.now(), num_train, num_val))

    ######################################################################
    # Placeholders
    indices = tf.placeholder(tf.int32, shape=(None, None, 2), name="indices")
    xforms = tf.placeholder(tf.float32, shape=(None, 3, 3), name="xforms")
    rotations = tf.placeholder(tf.float32, shape=(None, 3, 3), name="rotations")
    jitter_range = tf.placeholder(tf.float32, shape=(1), name="jitter_range")
    global_step = tf.Variable(0, trainable=False, name='global_step')
    is_training = tf.placeholder(tf.bool, name='is_training')

    data_train_placeholder = tf.placeholder(data_train.dtype, data_train.shape)
    label_train_placeholder = tf.placeholder(label_train.dtype, label_train.shape)
    data_val_placeholder = tf.placeholder(data_val.dtype, data_val.shape)
    label_val_placeholder = tf.placeholder(label_val.dtype, label_val.shape)
    handle = tf.placeholder(tf.string, shape=[])

    ######################################################################
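    # tf.data input pipelines built from in-memory arrays; the optional per-sample
    # setting.map_fn runs inside tf.py_func with setting.num_parallel_calls workers.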
    dataset_train = tf.data.Dataset.from_tensor_slices((data_train_placeholder, label_train_placeholder))
    if setting.map_fn is not None:
        dataset_train = dataset_train.map(lambda data, label: tuple(tf.py_func(
            setting.map_fn, [data, label], [tf.float32, label.dtype])), num_parallel_calls=setting.num_parallel_calls)
    dataset_train = dataset_train.shuffle(buffer_size=batch_size * 4)

    if setting.keep_remainder:
        dataset_train = dataset_train.batch(batch_size)
        batch_num_per_epoch = math.ceil(num_train / batch_size)
    else:
        dataset_train = dataset_train.apply(tf.contrib.data.batch_and_drop_remainder(batch_size))
        batch_num_per_epoch = math.floor(num_train / batch_size)
    batch_num = batch_num_per_epoch * num_epochs
    print('{}-{:d} training batches.'.format(datetime.now(), batch_num))

    dataset_train = dataset_train.repeat()
    iterator_train = dataset_train.make_initializable_iterator()

    dataset_val = tf.data.Dataset.from_tensor_slices((data_val_placeholder, label_val_placeholder))
    if setting.map_fn is not None:
        dataset_val = dataset_val.map(lambda data, label: tuple(tf.py_func(
            setting.map_fn, [data, label], [tf.float32, label.dtype])), num_parallel_calls=setting.num_parallel_calls)
    if setting.keep_remainder:
        dataset_val = dataset_val.batch(batch_size)
        batch_num_val = math.ceil(num_val / batch_size)
    else:
        dataset_val = dataset_val.apply(tf.contrib.data.batch_and_drop_remainder(batch_size))
        batch_num_val = math.floor(num_val / batch_size)
    iterator_val = dataset_val.make_initializable_iterator()

    iterator = tf.data.Iterator.from_string_handle(handle, dataset_train.output_types, dataset_train.output_shapes)
    (pts_fts, labels) = iterator.get_next()
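    # A single feedable iterator: feeding `handle` with handle_train or handle_val
    # (obtained from string_handle() inside the session) switches between the training
    # and validation pipelines without rebuilding the graph.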

    features_augmented = None
    if setting.data_dim > 3:
        points, features = tf.split(pts_fts, [3, setting.data_dim - 3], axis=-1, name='split_points_features')
        if setting.use_extra_features:
            features_sampled = tf.gather_nd(features, indices=indices, name='features_sampled')
            if setting.with_normal_feature:
                features_augmented = pf.augment(features_sampled, rotations)
            else:
                features_augmented = features_sampled
    else:
        points = pts_fts
    points_sampled = tf.gather_nd(points, indices=indices, name='points_sampled')
    points_augmented = pf.augment(points_sampled, xforms, jitter_range)
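    # `indices` selects sample_num points per cloud via gather_nd; pf.augment applies the
    # fed 3x3 transforms (plus jitter for the point coordinates, rotations alone for the
    # normal features above).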

    net = model.Net(points=points_augmented, features=features_augmented, num_class=num_class,
                    is_training=is_training, setting=setting)
    logits, probs = net.logits, net.probs
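    # The per-cloud class label is tiled across the per-point logits so that the
    # cross-entropy loss and top-1 accuracy below average over every output point.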
    labels_2d = tf.expand_dims(labels, axis=-1, name='labels_2d')
    labels_tile = tf.tile(labels_2d, (1, tf.shape(probs)[1]), name='labels_tile')

    loss_op = tf.losses.sparse_softmax_cross_entropy(labels=labels_tile, logits=logits)
    t_1_acc_op = pf.top_1_accuracy(probs, labels_tile)

    _ = tf.summary.scalar('loss/train', tensor=loss_op, collections=['train'])
    _ = tf.summary.scalar('t_1_acc/train', tensor=t_1_acc_op, collections=['train'])

    loss_val_avg = tf.placeholder(tf.float32)
    t_1_acc_val_avg = tf.placeholder(tf.float32)
    _ = tf.summary.scalar('loss/val', tensor=loss_val_avg, collections=['val'])
    _ = tf.summary.scalar('t_1_acc/val', tensor=t_1_acc_val_avg, collections=['val'])

    lr_exp_op = tf.train.exponential_decay(setting.learning_rate_base, global_step, setting.decay_steps,
                                           setting.decay_rate, staircase=True)
    lr_clip_op = tf.maximum(lr_exp_op, setting.learning_rate_min)
    _ = tf.summary.scalar('learning_rate', tensor=lr_clip_op, collections=['train'])
    reg_loss = setting.weight_decay * tf.losses.get_regularization_loss()
    if setting.optimizer == 'adam':
        optimizer = tf.train.AdamOptimizer(learning_rate=lr_clip_op, epsilon=setting.epsilon)
    elif setting.optimizer == 'momentum':
        optimizer = tf.train.MomentumOptimizer(learning_rate=lr_clip_op, momentum=0.9, use_nesterov=True)
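    # UPDATE_OPS typically holds batch-norm moving-average updates; making train_op depend
    # on them keeps those statistics in step with training.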
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = optimizer.minimize(loss_op + reg_loss, global_step=global_step)

    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())

    saver = tf.train.Saver(max_to_keep=None)

    # backup this file, model and setting
    shutil.copy(__file__, os.path.join(root_folder, os.path.basename(__file__)))
    shutil.copy(os.path.join(os.path.dirname(__file__), args.model + '.py'),
                os.path.join(root_folder, args.model + '.py'))
    if not os.path.exists(os.path.join(root_folder, args.model)):
        os.makedirs(os.path.join(root_folder, args.model))
    shutil.copy(os.path.join(setting_path, args.setting + '.py'),
                os.path.join(root_folder, args.model, args.setting + '.py'))

    folder_ckpt = os.path.join(root_folder, 'ckpts')
    if not os.path.exists(folder_ckpt):
        os.makedirs(folder_ckpt)

    folder_summary = os.path.join(root_folder, 'summary')
    if not os.path.exists(folder_summary):
        os.makedirs(folder_summary)

    parameter_num = np.sum([np.prod(v.shape.as_list()) for v in tf.trainable_variables()])
    print('{}-Parameter number: {:d}.'.format(datetime.now(), parameter_num))

    gpu_options = tf.GPUOptions(allow_growth=True)
    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
        summaries_op = tf.summary.merge_all('train')
        summaries_val_op = tf.summary.merge_all('val')
        summary_writer = tf.summary.FileWriter(folder_summary, sess.graph)

        sess.run(init_op)

        # Load the model
        if args.load_ckpt is not None:
            saver.restore(sess, args.load_ckpt)
            print('{}-Checkpoint loaded from {}!'.format(datetime.now(), args.load_ckpt))

        handle_train = sess.run(iterator_train.string_handle())
        handle_val = sess.run(iterator_val.string_handle())

        sess.run(iterator_train.initializer, feed_dict={
            data_train_placeholder: data_train,
            label_train_placeholder: label_train,
        })

        ######################################################################
        # Predict
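        # Test-time augmentation: outer_predict_num rounds; each round makes predict_num
        # randomized passes over the validation set (fresh sampling indices and xforms per
        # pass), averages the softmax probabilities per sample, and reports the voted
        # accuracy and per-batch timing. Note that update_ops is also evaluated in these runs.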
        outer_predict_acc = 0
        outer_predict_num = 10
        for opi in range(outer_predict_num):
            predict_acc = 0
            predict_num = 20
            predict_ave_probs = []
            labels_list = []
            total_time = 0
            for pre_i in range(predict_num):
                sess.run(iterator_val.initializer, feed_dict={
                    data_val_placeholder: data_val,
                    label_val_placeholder: label_val,
                })
    
                losses = []
                t_1_accs = []
                pre_i_probs = []
                for batch_idx_val in range(batch_num_val):
                    if not setting.keep_remainder or num_val % batch_size == 0 or batch_idx_val != batch_num_val - 1:
                        batch_size_val = batch_size
                    else:
                        batch_size_val = num_val % batch_size
                    xforms_np, rotations_np = pf.get_xforms(batch_size_val, rotation_range=rotation_range_val,
                                                                order=setting.order)
                    time1 = time.time()
                    _, loss_val, t_1_acc_val, predict_probs, true_labels = \
                        sess.run([update_ops, loss_op, t_1_acc_op, probs, labels],
                                 feed_dict={
                                     handle: handle_val,
                                     indices: pf.get_indices(batch_size_val, sample_num, point_num),#, False),
                                     xforms: xforms_np,
                                     rotations: rotations_np,
                                     jitter_range: np.array([jitter_val]),
                                     is_training: False,
                                 })
                    time2 = time.time()
                    total_time += time2 - time1
                    losses.append(loss_val * batch_size_val)
                    t_1_accs.append(t_1_acc_val * batch_size_val)
                    pre_i_probs.append(predict_probs)
                    if (pre_i == 0):
                        labels_list.append(true_labels)
                    #print('{}-[Val  ]-Iter: {:06d}  Loss: {:.4f}  T-1 Acc: {:.4f}'.format
                    #      (datetime.now(), batch_idx_val, loss_val, t_1_acc_val))
                    #sys.stdout.flush()
                predict_ave_probs.append(pre_i_probs)
    
                #loss_avg = sum(losses) / num_val
                #t_1_acc_avg = sum(t_1_accs) / num_val
                #print('{}-[Val  ]-Average:      Loss: {:.4f}  T-1 Acc: {:.4f}'
                #      .format(datetime.now(), loss_avg, t_1_acc_avg))
                #sys.stdout.flush()
                #predict_acc += t_1_acc_avg
            #predict_acc /= predict_num
            #print('{}-[Mean ]-Average:      T-1 Acc: {:.4f}'
            #      .format(datetime.now(), predict_acc))
            predict_ave_probs = np.argmax(np.mean(np.squeeze(np.array(predict_ave_probs)), axis=0), axis=1)
            labels_list = np.squeeze(np.array(labels_list))
            print('predict:')
            print(predict_ave_probs)
            print('label:')
            print(labels_list)
            true_num = np.count_nonzero(predict_ave_probs == labels_list)
            total_num = labels_list.shape[0]
            outer_predict_acc += true_num/total_num*100
            print('Acc: %.2f (%d/%d)' % (true_num/total_num*100, true_num, total_num))
            print('average time: %f' % (total_time/(predict_num*batch_num_val)))
            sys.stdout.flush()
        outer_predict_acc /= outer_predict_num
        print('\n Average Acc: %.2f' % outer_predict_acc)
        sys.stdout.flush()
        exit()
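        # exit() above ends the script after the prediction report, so the training loop
        # below is unreachable in this predict-only variant.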
        ######################################################################

        best_acc = -np.inf
        for batch_idx_train in range(batch_num):
            ######################################################################
            # Validation
            if (batch_idx_train != 0 and batch_idx_train % step_val == 0) or batch_idx_train == batch_num - 1:

                filename_ckpt = os.path.join(folder_ckpt, 'last_model')
                saver.save(sess, filename_ckpt, global_step=None)
                print('{}-Checkpoint saved to {}!'.format(datetime.now(), filename_ckpt))

                predict_num = 20
                predict_ave_probs = []
                labels_list = []
                losses = []
                for pre_i in range(predict_num):
                    sess.run(iterator_val.initializer, feed_dict={
                        data_val_placeholder: data_val,
                        label_val_placeholder: label_val,
                    })
        
                    pre_i_probs = []
                    for batch_idx_val in range(batch_num_val):
                        if not setting.keep_remainder or num_val % batch_size == 0 or batch_idx_val != batch_num_val - 1:
                            batch_size_val = batch_size
                        else:
                            batch_size_val = num_val % batch_size
                        xforms_np, rotations_np = pf.get_xforms(batch_size_val, rotation_range=rotation_range_val,
                                                                    order=setting.order)
                        _, loss_val, t_1_acc_val, predict_probs, true_labels = \
                            sess.run([update_ops, loss_op, t_1_acc_op, probs, labels],
                                     feed_dict={
                                         handle: handle_val,
                                         indices: pf.get_indices(batch_size_val, sample_num, point_num),#, False),
                                         xforms: xforms_np,
                                         rotations: rotations_np,
                                         jitter_range: np.array([jitter_val]),
                                         is_training: False,
                                     })
                        losses.append(loss_val * batch_size_val)
                        pre_i_probs.append(predict_probs)
                        if (pre_i == 0):
                            labels_list.append(true_labels)
                    predict_ave_probs.append(pre_i_probs)
        
                loss_avg = sum(losses) / (predict_num * num_val)
                predict_ave_probs = np.argmax(np.mean(np.reshape(np.array(predict_ave_probs), (predict_num, num_val, num_class)), axis=0), axis=1)
                labels_list = np.reshape(np.array(labels_list), (num_val,))
                true_num = np.count_nonzero(predict_ave_probs == labels_list)
                t_1_acc_avg = true_num / num_val
                print('{}-[Val  ]-Average:      Loss: {:.4f}  T-1 Acc: {:.4f}'
                      .format(datetime.now(), loss_avg, t_1_acc_avg))
                if t_1_acc_avg > (best_acc - 1e-6):
                    best_acc = t_1_acc_avg
                    print('{}-[Val  ]-best:         Loss: {:.4f}  T-1 Acc: {:.4f}'
                          .format(datetime.now(), loss_avg, t_1_acc_avg))
                    filename_ckpt = os.path.join(folder_ckpt, 'best_model')
                    saver.save(sess, filename_ckpt, global_step=None)
                    print('{}-Checkpoint saved to {}!'.format(datetime.now(), filename_ckpt))
                sys.stdout.flush()

                #############################################################################
                # Original Validation
#                sess.run(iterator_val.initializer, feed_dict={
#                    data_val_placeholder: data_val,
#                    label_val_placeholder: label_val,
#                })
#                filename_ckpt = os.path.join(folder_ckpt, 'last_model')
#                saver.save(sess, filename_ckpt, global_step=None)
#                print('{}-Checkpoint saved to {}!'.format(datetime.now(), filename_ckpt))
#
#                losses = []
#                t_1_accs = []
#                for batch_idx_val in range(batch_num_val):
#                    if not setting.keep_remainder or num_val % batch_size == 0 or batch_idx_val != batch_num_val - 1:
#                        batch_size_val = batch_size
#                    else:
#                        batch_size_val = num_val % batch_size
#                    xforms_np, rotations_np = pf.get_xforms(batch_size_val, rotation_range=rotation_range_val,
#                                                                order=setting.order)
#                    _, loss_val, t_1_acc_val = \
#                        sess.run([update_ops, loss_op, t_1_acc_op],
#                                 feed_dict={
#                                     handle: handle_val,
#                                     indices: pf.get_indices(batch_size_val, sample_num, point_num),#, False),
#                                     xforms: xforms_np,
#                                     rotations: rotations_np,
#                                     jitter_range: np.array([jitter_val]),
#                                     is_training: False,
#                                 })
#                    losses.append(loss_val * batch_size_val)
#                    t_1_accs.append(t_1_acc_val * batch_size_val)
#                    print('{}-[Val  ]-Iter: {:06d}  Loss: {:.4f}  T-1 Acc: {:.4f}'.format
#                          (datetime.now(), batch_idx_val, loss_val, t_1_acc_val))
#                    sys.stdout.flush()
#
#                loss_avg = sum(losses) / num_val
#                t_1_acc_avg = sum(t_1_accs) / num_val
#                summaries_val = sess.run(summaries_val_op,
#                                         feed_dict={
#                                             loss_val_avg: loss_avg,
#                                             t_1_acc_val_avg: t_1_acc_avg,
#                                         })
#                summary_writer.add_summary(summaries_val, batch_idx_train)
#                print('{}-[Val  ]-Average:      Loss: {:.4f}  T-1 Acc: {:.4f}'
#                      .format(datetime.now(), loss_avg, t_1_acc_avg))
#                if t_1_acc_avg > (best_acc-1e-6):
#                  best_acc = t_1_acc_avg
#                  print('{}-[Val  ]-best:         Loss: {:.4f}  T-1 Acc: {:.4f}'
#                      .format(datetime.now(), loss_avg, t_1_acc_avg))
#                  filename_ckpt = os.path.join(folder_ckpt, 'best_model')
#                  saver.save(sess, filename_ckpt, global_step=None)
#                  print('{}-Checkpoint saved to {}!'.format(datetime.now(), filename_ckpt))
#                sys.stdout.flush()
            ######################################################################

            ######################################################################
            # Training
            if not setting.keep_remainder or num_train % batch_size == 0 or (batch_idx_train % batch_num_per_epoch) != (batch_num_per_epoch - 1):
                batch_size_train = batch_size
            else:
                batch_size_train = num_train % batch_size
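            # Randomize how many points are sampled for this batch: a Gaussian offset
            # clamped to roughly +/- sample_num/4.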
            offset = int(random.gauss(0, sample_num // 8))
            offset = max(offset, -sample_num // 4)
            offset = min(offset, sample_num // 4)
            sample_num_train = sample_num + offset
            xforms_np, rotations_np = pf.get_xforms(batch_size_train, rotation_range=rotation_range,
                                                        order=setting.order)
            _, loss, t_1_acc, summaries = \
                sess.run([train_op, loss_op, t_1_acc_op, summaries_op],
                         feed_dict={
                             handle: handle_train,
                             indices: pf.get_indices(batch_size_train, sample_num_train, point_num),
                             xforms: xforms_np,
                             rotations: rotations_np,
                             jitter_range: np.array([jitter]),
                             is_training: True,
                         })
            summary_writer.add_summary(summaries, batch_idx_train)
            print('{}-[Train]-Iter: {:06d}  Loss: {:.4f}  T-1 Acc: {:.4f}'
                  .format(datetime.now(), batch_idx_train, loss, t_1_acc))
            sys.stdout.flush()
            ######################################################################
        print('{}-Done!'.format(datetime.now()))
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--dir_train', '-t', help='Path to the training set directory', required=True)
    parser.add_argument('--dir_val', '-v', help='Path to the validation set directory', required=False)
    parser.add_argument('--load_ckpt', '-l', help='Path to a checkpoint file to load')
    parser.add_argument('--save_folder', '-s', help='Path to folder for saving checkpoints and summary', required=True)
    parser.add_argument('--model', '-m', help='Model to use', required=True)
    parser.add_argument('--setting', '-x', help='Setting to use', required=True)
    args = parser.parse_args()

    time_string = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    root_folder = os.path.join(args.save_folder, '%s_%s_%d_%s' % (args.model, args.setting, os.getpid(), time_string))
    if not os.path.exists(root_folder):
        os.makedirs(root_folder)

    print('PID:', os.getpid())

    print(args)

    model = importlib.import_module(args.model)
    setting_path = os.path.join(os.path.dirname(__file__), args.model)
    sys.path.append(setting_path)
    print(setting_path)
    setting = importlib.import_module(args.setting)

    num_epochs = setting.num_epochs
    batch_size = setting.batch_size
    sample_num = setting.sample_num
    step_val = setting.step_val
    num_parts = setting.num_parts
    label_weights_list = setting.label_weights
    scaling_range = setting.scaling_range
    scaling_range_val = setting.scaling_range_val
    jitter = setting.jitter
    jitter_val = setting.jitter_val

    # Prepare inputs
    print('{}-Preparing datasets...'.format(datetime.now()))

    if args.dir_val is None:
        print("only load train")
        index_length_train, points_ele_train, intensities_train, labels_train, _ = data_utils.load_bin(args.dir_train)
        index_length_val = index_length_train
        points_ele_val = points_ele_train
        intensities_val = intensities_train
        labels_val = labels_train
    else:
        print("load train and val")
        index_length_train, points_ele_train, intensities_train, labels_train, _ = data_utils.load_bin(args.dir_train)
        index_length_val, points_ele_val, intensities_val, labels_val, _ = data_utils.load_bin(args.dir_val)

    # shuffle
    index_length_train = data_utils.index_shuffle(index_length_train)
    index_length_val = data_utils.index_shuffle(index_length_val)
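    # Each index_length_* row is (start_offset, point_count) into the flat point /
    # intensity / label arrays; clouds are variable-length and are zero-padded to
    # point_num when batches are assembled below.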

    num_train = index_length_train.shape[0]
    point_num = max(np.max(index_length_train[:, 1]), np.max(index_length_val[:, 1]))
    num_val = index_length_val.shape[0]

    print('{}-{:d}/{:d} training/validation samples.'.format(datetime.now(), num_train, num_val))
    batch_num = (num_train * num_epochs + batch_size - 1) // batch_size
    print('{}-{:d} training batches.'.format(datetime.now(), batch_num))
    batch_num_val = math.ceil(num_val / batch_size)
    print('{}-{:d} testing batches per test.'.format(datetime.now(), batch_num_val))

    ######################################################################
    # Placeholders
    indices = tf.placeholder(tf.int32, shape=(None, None, 2), name="indices")
    xforms = tf.placeholder(tf.float32, shape=(None, 3, 3), name="xforms")
    rotations = tf.placeholder(tf.float32, shape=(None, 3, 3), name="rotations")
    jitter_range = tf.placeholder(tf.float32, shape=(1), name="jitter_range")
    global_step = tf.Variable(0, trainable=False, name='global_step')
    is_training = tf.placeholder(tf.bool, name='is_training')

    pts = tf.placeholder(tf.float32, shape=(None, point_num, setting.point_dim), name='pts')
    fts = tf.placeholder(tf.float32, shape=(None, point_num, setting.extra_dim), name='fts')
    labels_seg = tf.placeholder(tf.int32, shape=(None, point_num), name='labels_seg')
    labels_weights = tf.placeholder(tf.float32, shape=(None, point_num), name='labels_weights')

    ######################################################################

    # Set Inputs(points,features_sampled)
    features_sampled = None

    if setting.extra_dim == 1:
        points = pts
        features = fts

        if setting.use_extra_features:
            features_sampled = tf.gather_nd(features, indices=indices, name='features_sampled')

    elif setting.extra_dim == 0:
        points = pts

    points_sampled = tf.gather_nd(points, indices=indices, name='points_sampled')
    points_augmented = pf.augment(points_sampled, xforms, jitter_range)

    labels_sampled = tf.gather_nd(labels_seg, indices=indices, name='labels_sampled')
    labels_weights_sampled = tf.gather_nd(labels_weights, indices=indices, name='labels_weight_sampled')

    # Build net
    net = model.Net(points_augmented, features_sampled, None, None, num_parts, is_training, setting)
    logits, probs = net.logits, net.probs

    # Define Loss Func
    loss_op = tf.losses.sparse_softmax_cross_entropy(labels=labels_sampled, logits=logits,
                                                     weights=labels_weights_sampled)
    _ = tf.summary.scalar('loss/train_seg', tensor=loss_op, collections=['train'])

    # for vis t1 acc
    t_1_acc_op = pf.top_1_accuracy(probs, labels_sampled)
    _ = tf.summary.scalar('t_1_acc/train_seg', tensor=t_1_acc_op, collections=['train'])
    # for vis instance acc
    t_1_acc_instance_op = pf.top_1_accuracy(probs, labels_sampled, labels_weights_sampled, 0.6)
    _ = tf.summary.scalar('t_1_acc/train_seg_instance', tensor=t_1_acc_instance_op, collections=['train'])
    # for vis other acc
    t_1_acc_others_op = pf.top_1_accuracy(probs, labels_sampled, labels_weights_sampled, 0.6, "less")
    _ = tf.summary.scalar('t_1_acc/train_seg_others', tensor=t_1_acc_others_op, collections=['train'])
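    # In the two ops above, the 0.6 threshold on labels_weights_sampled presumably splits
    # accuracy into instance points (weight above 0.6) and the remaining "others" points.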

    loss_val_avg = tf.placeholder(tf.float32)
    _ = tf.summary.scalar('loss/val_seg', tensor=loss_val_avg, collections=['val'])

    t_1_acc_val_avg = tf.placeholder(tf.float32)
    _ = tf.summary.scalar('t_1_acc/val_seg', tensor=t_1_acc_val_avg, collections=['val'])
    t_1_acc_val_instance_avg = tf.placeholder(tf.float32)
    _ = tf.summary.scalar('t_1_acc/val_seg_instance', tensor=t_1_acc_val_instance_avg, collections=['val'])
    t_1_acc_val_others_avg = tf.placeholder(tf.float32)
    _ = tf.summary.scalar('t_1_acc/val_seg_others', tensor=t_1_acc_val_others_avg, collections=['val'])

    reg_loss = setting.weight_decay * tf.losses.get_regularization_loss()

    # lr decay
    lr_exp_op = tf.train.exponential_decay(setting.learning_rate_base, global_step, setting.decay_steps,
                                           setting.decay_rate, staircase=True)
    lr_clip_op = tf.maximum(lr_exp_op, setting.learning_rate_min)
    _ = tf.summary.scalar('learning_rate', tensor=lr_clip_op, collections=['train'])

    # Optimizer
    if setting.optimizer == 'adam':
        optimizer = tf.train.AdamOptimizer(learning_rate=lr_clip_op, epsilon=setting.epsilon)
    elif setting.optimizer == 'momentum':
        optimizer = tf.train.MomentumOptimizer(learning_rate=lr_clip_op, momentum=0.9, use_nesterov=True)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

    with tf.control_dependencies(update_ops):

        train_op = optimizer.minimize(loss_op + reg_loss, global_step=global_step)

    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())

    saver = tf.train.Saver(max_to_keep=None)

    # backup this file, model and setting
    if not os.path.exists(os.path.join(root_folder, args.model)):
        os.makedirs(os.path.join(root_folder, args.model))
    shutil.copy(__file__, os.path.join(root_folder, os.path.basename(__file__)))
    shutil.copy(os.path.join(os.path.dirname(__file__), args.model + '.py'),
                os.path.join(root_folder, args.model + '.py'))
    shutil.copy(os.path.join(os.path.dirname(__file__), args.model.split("_")[0] + "_kitti" + '.py'),
                os.path.join(root_folder, args.model.split("_")[0] + "_kitti" + '.py'))
    shutil.copy(os.path.join(setting_path, args.setting + '.py'),
                os.path.join(root_folder, args.model, args.setting + '.py'))

    folder_ckpt = os.path.join(root_folder, 'ckpts')
    if not os.path.exists(folder_ckpt):
        os.makedirs(folder_ckpt)

    folder_summary = os.path.join(root_folder, 'summary')
    if not os.path.exists(folder_summary):
        os.makedirs(folder_summary)

    parameter_num = np.sum([np.prod(v.shape.as_list()) for v in tf.trainable_variables()])
    print('{}-Parameter number: {:d}.'.format(datetime.now(), parameter_num))

    # Session Run
    with tf.Session() as sess:
        summaries_op = tf.summary.merge_all('train')
        summaries_val_op = tf.summary.merge_all('val')
        summary_writer = tf.summary.FileWriter(folder_summary, sess.graph)

        sess.run(init_op)

        # Load the model
        if args.load_ckpt is not None:
            saver.restore(sess, args.load_ckpt)
            print('{}-Checkpoint loaded from {}!'.format(datetime.now(), args.load_ckpt))

        for batch_idx in range(batch_num):
            if (batch_idx != 0 and batch_idx % step_val == 0) or batch_idx == batch_num - 1:
                ######################################################################
                # Validation
                filename_ckpt = os.path.join(folder_ckpt, 'iter')
                saver.save(sess, filename_ckpt, global_step=global_step)
                print('{}-Checkpoint saved to {}!'.format(datetime.now(), filename_ckpt))

                losses_val = []
                t_1_accs = []
                t_1_accs_instance = []
                t_1_accs_others = []

                for batch_val_idx in range(math.ceil(num_val / batch_size)):
                    start_idx = batch_size * batch_val_idx
                    end_idx = min(start_idx + batch_size, num_val)
                    batch_size_val = end_idx - start_idx
                    index_length_val_batch = index_length_val[start_idx:end_idx]

                    points_batch = np.zeros((batch_size_val, point_num, 3), np.float32)
                    intensity_batch = np.zeros((batch_size_val, point_num, 1), np.float32)
                    points_num_batch = np.zeros(batch_size_val, np.int32)
                    labels_batch = np.zeros((batch_size_val, point_num), np.int32)

                    for i, index_length in enumerate(index_length_val_batch):
                        points_batch[i, 0:index_length[1], :] = \
                            points_ele_val[index_length[0] * 3:
                                           index_length[0] * 3 + index_length[1] * 3].reshape(index_length[1], 3)

                        intensity_batch[i, 0:index_length[1], :] = \
                            intensities_val[index_length[0]:
                                            index_length[0] + index_length[1]].reshape(index_length[1], 1)

                        points_num_batch[i] = index_length[1].astype(np.int32)

                        labels_batch[i, 0:index_length[1]] = \
                            labels_val[index_length[0]:index_length[0] + index_length[1]].astype(np.int32)

                    weights_batch = np.array(label_weights_list)[labels_batch]  # per-point loss weights via fancy indexing

                    xforms_np, rotations_np = pf.get_xforms(batch_size_val, scaling_range=scaling_range_val)

                    sess_op_list = [loss_op, t_1_acc_op, t_1_acc_instance_op, t_1_acc_others_op]

                    sess_feed_dict = {pts: points_batch,
                                      fts: intensity_batch,
                                      indices: pf.get_indices(batch_size_val, sample_num, points_num_batch),
                                      xforms: xforms_np,
                                      rotations: rotations_np,
                                      jitter_range: np.array([jitter_val]),
                                      labels_seg: labels_batch,
                                      labels_weights: weights_batch,
                                      is_training: False}

                    loss_val, t_1_acc_val, t_1_acc_val_instance, t_1_acc_val_others = sess.run(sess_op_list,
                                                                                               feed_dict=sess_feed_dict)
                    print('{}-[Val  ]-Iter: {:06d}  Loss: {:.4f} T-1 Acc: {:.4f}'.format(datetime.now(), batch_val_idx,
                                                                                         loss_val, t_1_acc_val))

                    losses_val.append(loss_val * batch_size_val)
                    t_1_accs.append(t_1_acc_val * batch_size_val)
                    t_1_accs_instance.append(t_1_acc_val_instance * batch_size_val)
                    t_1_accs_others.append(t_1_acc_val_others * batch_size_val)

                loss_avg = sum(losses_val) / num_val
                t_1_acc_avg = sum(t_1_accs) / num_val
                t_1_acc_instance_avg = sum(t_1_accs_instance) / num_val
                t_1_acc_others_avg = sum(t_1_accs_others) / num_val

                summaries_feed_dict = {loss_val_avg: loss_avg,
                                       t_1_acc_val_avg: t_1_acc_avg,
                                       t_1_acc_val_instance_avg: t_1_acc_instance_avg,
                                       t_1_acc_val_others_avg: t_1_acc_others_avg}

                summaries_val = sess.run(summaries_val_op, feed_dict=summaries_feed_dict)
                summary_writer.add_summary(summaries_val, batch_idx)

                print('{}-[Val  ]-Average:      Loss: {:.4f} T-1 Acc: {:.4f}'
                      .format(datetime.now(), loss_avg, t_1_acc_avg))

                ######################################################################

            ######################################################################
            # Training
            start_idx = (batch_size * batch_idx) % num_train
            end_idx = min(start_idx + batch_size, num_train)
            batch_size_train = end_idx - start_idx

            index_length_train_batch = index_length_train[start_idx:end_idx]
            points_batch = np.zeros((batch_size_train, point_num, 3), np.float32)
            intensity_batch = np.zeros((batch_size_train, point_num, 1), np.float32)
            points_num_batch = np.zeros(batch_size_train, np.int32)
            labels_batch = np.zeros((batch_size_train, point_num), np.int32)

            for i, index_length in enumerate(index_length_train_batch):
                points_batch[i, 0:index_length[1], :] = \
                    points_ele_train[index_length[0] * 3:
                                     index_length[0] * 3 + index_length[1] * 3].reshape(index_length[1], 3)

                intensity_batch[i, 0:index_length[1], :] = \
                    intensities_train[index_length[0]:
                                      index_length[0] + index_length[1]].reshape(index_length[1], 1)

                points_num_batch[i] = index_length[1].astype(np.int32)

                labels_batch[i, 0:index_length[1]] = \
                    labels_train[index_length[0]:index_length[0] + index_length[1]].astype(np.int32)

            weights_batch = np.array(label_weights_list)[labels_batch]

            # A full pass over the training set completes here; reshuffle for the next epoch.
            if start_idx + batch_size_train == num_train:
                index_length_train = data_utils.index_shuffle(index_length_train)

            offset = int(random.gauss(0, sample_num // 8))
            offset = max(offset, -sample_num // 4)
            offset = min(offset, sample_num // 4)
            sample_num_train = sample_num + offset
            xforms_np, rotations_np = pf.get_xforms(batch_size_train, scaling_range=scaling_range)

            sess_op_list = [train_op, loss_op, t_1_acc_op, t_1_acc_instance_op, t_1_acc_others_op, summaries_op]

            sess_feed_dict = {pts: points_batch,
                              fts: intensity_batch,
                              indices: pf.get_indices(batch_size_train, sample_num_train, points_num_batch),
                              xforms: xforms_np,
                              rotations: rotations_np,
                              jitter_range: np.array([jitter]),
                              labels_seg: labels_batch,
                              labels_weights: weights_batch,
                              is_training: True}

            _, loss, t_1_acc, t_1_acc_instance, t_1_acc_others, summaries = sess.run(sess_op_list,
                                                                                     feed_dict=sess_feed_dict)
            print('{}-[Train]-Iter: {:06d}  Loss_seg: {:.4f} T-1 Acc: {:.4f}'
                  .format(datetime.now(), batch_idx, loss, t_1_acc))

            summary_writer.add_summary(summaries, batch_idx)

            ######################################################################
        print('{}-Done!'.format(datetime.now()))
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--filelist',
                        '-t',
                        help='Path to training set ground truth (.txt)',
                        required=True)
    parser.add_argument('--filelist_val',
                        '-v',
                        help='Path to validation set ground truth (.txt)',
                        required=True)
    parser.add_argument('--load_ckpt',
                        '-l',
                        help='Path to a checkpoint file to load')
    parser.add_argument(
        '--save_folder',
        '-s',
        help='Path to folder for saving checkpoints and summary',
        required=True)
    parser.add_argument('--model', '-m', help='Model to use', required=True)
    parser.add_argument('--setting',
                        '-x',
                        help='Setting to use',
                        required=True)
    args = parser.parse_args()

    time_string = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    root_folder = os.path.join(
        args.save_folder,
        '%s_%s_%d_%s' % (args.model, args.setting, os.getpid(), time_string))
    if not os.path.exists(root_folder):
        os.makedirs(root_folder)

    # stdout redirection (disabled)
    # sys.stdout = open(os.path.join(root_folder, 'log.txt'), 'w')

    print('PID:', os.getpid())

    print(args)

    model = importlib.import_module(args.model)
    setting_path = os.path.join(os.path.dirname(__file__), args.model)
    sys.path.append(setting_path)
    print(setting_path)
    setting = importlib.import_module(args.setting)

    num_epochs = setting.num_epochs
    batch_size = setting.batch_size
    sample_num = setting.sample_num
    step_val = 1000
    num_parts = setting.num_parts
    label_weights_list = setting.label_weights
    scaling_range = setting.scaling_range
    scaling_range_val = setting.scaling_range_val
    jitter = setting.jitter
    jitter_val = setting.jitter_val

    # Prepare inputs
    print('{}-Preparing datasets...'.format(datetime.now()))

    data_train, data_num_train, label_train = data_utils.load_seg(
        args.filelist)
    data_val, data_num_val, label_val = data_utils.load_seg(args.filelist_val)

    # shuffle
    data_train, data_num_train, label_train = \
        data_utils.grouped_shuffle([data_train, data_num_train, label_train])
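    # grouped_shuffle presumably applies one shared permutation so data, per-cloud point
    # counts, and labels stay aligned.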

    num_train = data_train.shape[0]  # e.g., 2048 samples
    point_num = data_train.shape[1]  # e.g., 6501 points per sample
    num_val = data_val.shape[0]

    print('{}-{:d}/{:d} training/validation samples.'.format(
        datetime.now(), num_train, num_val))
    batch_num = (num_train * num_epochs + batch_size - 1) // batch_size
    print('{}-{:d} training batches.'.format(datetime.now(), batch_num))
    batch_num_val = math.ceil(num_val / batch_size)
    print('{}-{:d} testing batches per test.'.format(datetime.now(),
                                                     batch_num_val))

    ######################################################################
    # Placeholders
    indices = tf.placeholder(tf.int32, shape=(None, None, 2), name="indices")
    xforms = tf.placeholder(tf.float32, shape=(None, 3, 3),
                            name="xforms")  # scaling transform
    rotations = tf.placeholder(tf.float32,
                               shape=(None, 3, 3),
                               name="rotations")  # rotation transform
    jitter_range = tf.placeholder(tf.float32, shape=(1),
                                  name="jitter_range")  # point jitter
    global_step = tf.Variable(0, trainable=False, name='global_step')
    is_training = tf.placeholder(tf.bool, name='is_training')

    pts_fts = tf.placeholder(tf.float32,
                             shape=(None, point_num, setting.data_dim),
                             name='pts_fts')
    labels_seg = tf.placeholder(tf.int32,
                                shape=(None, point_num),
                                name='labels_seg')
    labels_weights = tf.placeholder(tf.float32,
                                    shape=(None, point_num),
                                    name='labels_weights')
    loss_val_avg = tf.placeholder(tf.float32)
    t_1_acc_val_avg = tf.placeholder(tf.float32)
    t_1_acc_val_instance_avg = tf.placeholder(tf.float32)
    t_1_acc_val_others_avg = tf.placeholder(tf.float32)

    ######################################################################

    # Set Inputs(points,features_sampled)
    features_sampled = None

    if setting.data_dim > 3:
        #   [3 pts_xyz , 2 img_xy , 4 extra_features]
        points, _, features = tf.split(pts_fts, [
            setting.data_format["pts_xyz"], setting.data_format["img_xy"],
            setting.data_format["extra_features"]
        ],
                                       axis=-1,
                                       name='split_points_xy_features')
        # r, g, b, i
        if setting.use_extra_features:
            # gather the extra features
            features_sampled = tf.gather_nd(features,
                                            indices=indices,
                                            name='features_sampled')
    else:
        points = pts_fts

    points_sampled = tf.gather_nd(points,
                                  indices=indices,
                                  name='points_sampled')
    points_augmented = pf.augment(points_sampled, xforms, jitter_range)
    # N * 6501
    labels_sampled = tf.gather_nd(labels_seg,
                                  indices=indices,
                                  name='labels_sampled')
    labels_weights_sampled = tf.gather_nd(labels_weights,
                                          indices=indices,
                                          name='labels_weight_sampled')

    # Build net
    net = model.Net(points_augmented, features_sampled, None, None, num_parts,
                    is_training, setting)
    logits, probs = net.logits, net.probs

    loss_op = tf.losses.sparse_softmax_cross_entropy(
        labels=labels_sampled, logits=logits, weights=labels_weights_sampled)
    t_1_acc_op = pf.top_1_accuracy(probs, labels_sampled)
    t_1_acc_instance_op = pf.top_1_accuracy(probs, labels_sampled,
                                            labels_weights_sampled, 0.6)
    t_1_acc_others_op = pf.top_1_accuracy(probs, labels_sampled,
                                          labels_weights_sampled, 0.6, "less")

    reg_loss = setting.weight_decay * tf.losses.get_regularization_loss()

    # lr decay
    lr_exp_op = tf.train.exponential_decay(setting.learning_rate_base,
                                           global_step,
                                           setting.decay_steps,
                                           setting.decay_rate,
                                           staircase=True)
    lr_clip_op = tf.maximum(lr_exp_op, setting.learning_rate_min)

    # Optimizer
    if setting.optimizer == 'adam':
        optimizer = tf.train.AdamOptimizer(learning_rate=lr_clip_op,
                                           epsilon=setting.epsilon)
    elif setting.optimizer == 'momentum':
        optimizer = tf.train.MomentumOptimizer(learning_rate=lr_clip_op,
                                               momentum=0.9,
                                               use_nesterov=True)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

    with tf.control_dependencies(update_ops):
        train_op = optimizer.minimize(loss_op + reg_loss,
                                      global_step=global_step)

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    saver = tf.train.Saver(max_to_keep=1)

    # backup this file, model and setting
    if not os.path.exists(os.path.join(root_folder, args.model)):
        os.makedirs(os.path.join(root_folder, args.model))
    shutil.copy(__file__, os.path.join(root_folder,
                                       os.path.basename(__file__)))
    shutil.copy(os.path.join(os.path.dirname(__file__), args.model + '.py'),
                os.path.join(root_folder, args.model + '.py'))
    shutil.copy(
        os.path.join(os.path.dirname(__file__),
                     args.model.split("_")[0] + "_kitti" + '.py'),
        os.path.join(root_folder,
                     args.model.split("_")[0] + "_kitti" + '.py'))
    shutil.copy(os.path.join(setting_path, args.setting + '.py'),
                os.path.join(root_folder, args.model, args.setting + '.py'))

    folder_ckpt = os.path.join(root_folder, 'ckpts')
    if not os.path.exists(folder_ckpt):
        os.makedirs(folder_ckpt)

    folder_summary = os.path.join(root_folder, 'summary')
    if not os.path.exists(folder_summary):
        os.makedirs(folder_summary)

    parameter_num = np.sum(
        [np.prod(v.shape.as_list()) for v in tf.trainable_variables()])
    print('{}-Parameter number: {:d}.'.format(datetime.now(), parameter_num))

    _ = tf.summary.scalar('loss/train_seg',
                          tensor=loss_op,
                          collections=['train'])
    _ = tf.summary.scalar('t_1_acc/train_seg',
                          tensor=t_1_acc_op,
                          collections=['train'])
    _ = tf.summary.scalar('t_1_acc/train_seg_instance',
                          tensor=t_1_acc_instance_op,
                          collections=['train'])
    _ = tf.summary.scalar('t_1_acc/train_seg_others',
                          tensor=t_1_acc_others_op,
                          collections=['train'])
    _ = tf.summary.scalar('loss/val_seg',
                          tensor=loss_val_avg,
                          collections=['val'])
    _ = tf.summary.scalar('t_1_acc/val_seg_instance',
                          tensor=t_1_acc_val_instance_avg,
                          collections=['val'])
    _ = tf.summary.scalar('t_1_acc/val_seg',
                          tensor=t_1_acc_val_avg,
                          collections=['val'])
    _ = tf.summary.scalar('t_1_acc/val_seg_others',
                          tensor=t_1_acc_val_others_avg,
                          collections=['val'])
    _ = tf.summary.scalar('learning_rate',
                          tensor=lr_clip_op,
                          collections=['train'])

    # Session Run
    with tf.Session() as sess:
        summaries_op = tf.summary.merge_all('train')
        summaries_val_op = tf.summary.merge_all('val')
        summary_writer = tf.summary.FileWriter(folder_summary, sess.graph)

        sess.run(init_op)

        # Load the model
        if args.load_ckpt is not None:
            saver.restore(sess, args.load_ckpt)
            print('{}-Checkpoint loaded from {}!'.format(
                datetime.now(), args.load_ckpt))

        for batch_idx in range(batch_num):
            if (batch_idx != 0 and batch_idx % step_val
                    == 0) or batch_idx == batch_num - 1:
                ######################################################################
                # Validation
                filename_ckpt = os.path.join(folder_ckpt, 'iter')
                saver.save(sess, filename_ckpt, global_step=global_step)
                print('{}-Checkpoint saved to {}!'.format(
                    datetime.now(), filename_ckpt))

                losses_val = []
                t_1_accs = []
                t_1_accs_instance = []
                t_1_accs_others = []

                for batch_val_idx in range(math.ceil(num_val / batch_size)):
                    start_idx = batch_size * batch_val_idx
                    end_idx = min(start_idx + batch_size, num_val)
                    batch_size_val = end_idx - start_idx
                    points_batch = data_val[start_idx:end_idx, ...]
                    points_num_batch = data_num_val[start_idx:end_idx, ...]
                    labels_batch = label_val[start_idx:end_idx, ...]
                    weights_batch = np.array(label_weights_list)[label_val[
                        start_idx:end_idx, ...]]

                    xforms_np, rotations_np = pf.get_xforms(
                        batch_size_val, scaling_range=scaling_range_val)
                    sess_op_list = [
                        loss_op, t_1_acc_op, t_1_acc_instance_op,
                        t_1_acc_others_op
                    ]
                    sess_feed_dict = {
                        pts_fts:
                        points_batch,
                        indices:
                        pf.get_indices(batch_size_val, sample_num,
                                       points_num_batch),
                        xforms:
                        xforms_np,
                        rotations:
                        rotations_np,
                        jitter_range:
                        np.array([jitter_val]),  # data jitter
                        labels_seg:
                        labels_batch,
                        labels_weights:
                        weights_batch,
                        is_training:
                        False
                    }

                    loss_val, t_1_acc_val, t_1_acc_val_instance, t_1_acc_val_others = sess.run(
                        sess_op_list, feed_dict=sess_feed_dict)
                    print(
                        '{}-[Val  ]-Iter: {:06d}  Loss: {:.4f} T-1 Acc: {:.4f}'
                        .format(datetime.now(), batch_val_idx, loss_val,
                                t_1_acc_val))

                    sys.stdout.flush()
                    losses_val.append(loss_val * batch_size_val)
                    t_1_accs.append(t_1_acc_val * batch_size_val)
                    t_1_accs_instance.append(t_1_acc_val_instance *
                                             batch_size_val)
                    t_1_accs_others.append(t_1_acc_val_others * batch_size_val)

                loss_avg = sum(losses_val) / num_val
                t_1_acc_avg = sum(t_1_accs) / num_val
                t_1_acc_instance_avg = sum(t_1_accs_instance) / num_val
                t_1_acc_others_avg = sum(t_1_accs_others) / num_val

                summaries_feed_dict = {
                    loss_val_avg: loss_avg,
                    t_1_acc_val_avg: t_1_acc_avg,
                    t_1_acc_val_instance_avg: t_1_acc_instance_avg,
                    t_1_acc_val_others_avg: t_1_acc_others_avg
                }

                summaries_val = sess.run(summaries_val_op,
                                         feed_dict=summaries_feed_dict)
                summary_writer.add_summary(summaries_val, batch_idx)

                print('{}-[Val  ]-Average:      Loss: {:.4f} T-1 Acc: {:.4f}'.
                      format(datetime.now(), loss_avg, t_1_acc_avg))

                sys.stdout.flush()
                ######################################################################

            ######################################################################
            # Training
            start_idx = (batch_size * batch_idx) % num_train
            end_idx = min(start_idx + batch_size, num_train)
            batch_size_train = end_idx - start_idx
            points_batch = data_train[start_idx:end_idx, ...]  # 2048 6501 9

            points_num_batch = data_num_train[start_idx:end_idx, ...]  #2048
            labels_batch = label_train[start_idx:end_idx, ...]  #2048 6501
            weights_batch = np.array(label_weights_list)[labels_batch]  # zero for zero-weighted labels

            if start_idx + batch_size_train == num_train:
                data_train, data_num_train, label_train = \
                    data_utils.grouped_shuffle([data_train, data_num_train, label_train])

            offset = int(random.gauss(0, sample_num // 8))  # gauss(mean, std)
            offset = max(offset, -sample_num // 4)
            offset = min(offset, sample_num // 4)
            sample_num_train = sample_num + offset
            xforms_np, rotations_np = pf.get_xforms(
                batch_size_train, scaling_range=scaling_range)  # data augmentation

            sess_op_list = [
                train_op, loss_op, t_1_acc_op, t_1_acc_instance_op,
                t_1_acc_others_op, summaries_op
            ]

            sess_feed_dict = {
                pts_fts:
                points_batch,
                indices:
                pf.get_indices(batch_size_train, sample_num_train,
                               points_num_batch),
                xforms:
                xforms_np,
                rotations:
                rotations_np,
                jitter_range:
                np.array([jitter]),
                labels_seg:
                labels_batch,
                labels_weights:
                weights_batch,
                is_training:
                True
            }

            _, loss, t_1_acc, t_1_acc_instance, t_1_acc_others, summaries = sess.run(
                sess_op_list, feed_dict=sess_feed_dict)
            print('{}-[Train]-Iter: {:06d}  Loss_seg: {:.4f} T-1 Acc: {:.4f}'.
                  format(datetime.now(), batch_idx, loss, t_1_acc))

            summary_writer.add_summary(summaries, batch_idx)
            sys.stdout.flush()

            ######################################################################
        print('{}-Done!'.format(datetime.now()))
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--filelist', '-t', help='Path to training set ground truth (.txt)', required=True)
    parser.add_argument('--filelist_val', '-v', help='Path to validation set ground truth (.txt)', required=True)
    parser.add_argument('--load_ckpt', '-l', help='Path to a checkpoint file to load')
    parser.add_argument('--save_folder', '-s', help='Path to folder for saving checkpoints and summary', required=True)
    parser.add_argument('--model', '-m', help='Model to use', required=True)
    parser.add_argument('--setting', '-x', help='Setting to use', required=True)
    args = parser.parse_args()

    time_string = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    root_folder = os.path.join(args.save_folder, '%s_%s_%d_%s' % (args.model, args.setting, os.getpid(), time_string))
    if not os.path.exists(root_folder):
        os.makedirs(root_folder)

    sys.stdout = open(os.path.join(root_folder, 'log.txt'), 'w')

    print('PID:', os.getpid())

    print(args)

    model = importlib.import_module(args.model)
    setting_path = os.path.join(os.path.dirname(__file__), args.model)
    sys.path.append(setting_path)
    print(setting_path)
    setting = importlib.import_module(args.setting)

    num_epochs = setting.num_epochs
    batch_size = setting.batch_size
    sample_num = setting.sample_num
    step_val = 1000
    num_parts = setting.num_parts
    label_weights_list = setting.label_weights
    scaling_range = setting.scaling_range
    scaling_range_val = setting.scaling_range_val
    jitter = setting.jitter
    jitter_val = setting.jitter_val

    # Prepare inputs
    print('{}-Preparing datasets...'.format(datetime.now()))
    
    #data_train, data_num_train, label_train = data_utils.load_seg(args.filelist)
    #data_val, data_num_val, label_val = data_utils.load_seg(args.filelist_val)
    data_train, img_train, data_num_train, label_train, label_hwl_train, label_xyz_train, \
        label_ry_train, label_ry_reg_train = data_utils.load_seg_kitti(args.filelist)
    data_val, img_val, data_num_val, label_val, label_hwl_val, label_xyz_val, \
        label_ry_val, label_ry_reg_val = data_utils.load_seg_kitti(args.filelist_val)

    # shuffle
    #data_train, data_num_train, label_train = \
    #    data_utils.grouped_shuffle([data_train, data_num_train, label_train])
    data_train, img_train, data_num_train, label_train, label_hwl_train, label_xyz_train, \
        label_ry_train, label_ry_reg_train = data_utils.grouped_shuffle(
            [data_train, img_train, data_num_train, label_train, label_hwl_train,
             label_xyz_train, label_ry_train, label_ry_reg_train])
        
    num_train = data_train.shape[0]
    point_num = data_train.shape[1]
    num_val = data_val.shape[0]

    print('{}-{:d}/{:d} training/validation samples.'.format(datetime.now(), num_train, num_val))
    batch_num = (num_train * num_epochs + batch_size - 1) // batch_size
    print('{}-{:d} training batches.'.format(datetime.now(), batch_num))
    batch_num_val = math.ceil(num_val / batch_size)
    print('{}-{:d} testing batches per test.'.format(datetime.now(), batch_num_val))
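    # Worked example with hypothetical numbers: for num_train = 6000, num_epochs = 100 and
    # batch_size = 32, batch_num = (6000 * 100 + 31) // 32 = 18750 training iterations.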

    ######################################################################
    # Placeholders
    indices = tf.placeholder(tf.int32, shape=(None, None, 2), name="indices")
    xforms = tf.placeholder(tf.float32, shape=(None, 3, 3), name="xforms")
    rotations = tf.placeholder(tf.float32, shape=(None, 3, 3), name="rotations")
    jitter_range = tf.placeholder(tf.float32, shape=(1), name="jitter_range")
    global_step = tf.Variable(0, trainable=False, name='global_step')
    is_training = tf.placeholder(tf.bool, name='is_training')

    pts_fts = tf.placeholder(tf.float32, shape=(None, point_num, setting.data_dim), name='pts_fts')
    labels_seg = tf.placeholder(tf.int32, shape=(None, point_num), name='labels_seg')
    labels_weights = tf.placeholder(tf.float32, shape=(None, point_num), name='labels_weights')

    labels_hwl = tf.placeholder(tf.float32, shape=(None, 3), name='labels_hwl')
    labels_xyz = tf.placeholder(tf.float32, shape=(None, 3), name='labels_xyz')
    labels_ry = tf.placeholder(tf.int32, shape=(None, 1), name='labels_ry')
    labels_ry_reg = tf.placeholder(tf.float32, shape=(None, 1), name='labels_ry_reg')
    ######################################################################

    # Set inputs (points, features_sampled)
    features_sampled = None

    if setting.data_dim > 3:
        points, _, features = tf.split(
            pts_fts,
            [setting.data_format["pts_xyz"], setting.data_format["img_xy"], setting.data_format["extra_features"]],
            axis=-1, name='split_points_xy_features')
        if setting.use_extra_features:
            features_sampled = tf.gather_nd(features, indices=indices, name='features_sampled')
    else:
        points = pts_fts
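    # The input channels are assumed to be laid out as [point xyz | projected image xy | extra
    # features] according to setting.data_format; the image xy columns are dropped here and only
    # the extra features are (optionally) gathered and fed to the network.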

    points_sampled = tf.gather_nd(points, indices=indices, name='points_sampled')
    points_augmented = pf.augment(points_sampled, xforms, jitter_range)

    labels_sampled = tf.gather_nd(labels_seg, indices=indices, name='labels_sampled')
    labels_weights_sampled = tf.gather_nd(labels_weights, indices=indices, name='labels_weight_sampled')
    
    # Build net
    net = model.Net(points_augmented, features_sampled, None, None, num_parts, is_training, setting)
    #logits, probs = net.logits, net.probs
    logits, probs, logits_hwl, logits_xyz, logits_ry, probs_ry = \
        net.logits, net.probs, net.logits_hwl, net.logits_xyz, net.logits_ry, net.probs_ry

    # Define Loss Func
    loss_op = tf.losses.sparse_softmax_cross_entropy(labels=labels_sampled, logits=logits, weights=labels_weights_sampled)
    _ = tf.summary.scalar('loss/train_seg', tensor=loss_op, collections=['train'])

    # HWL (box size) loss
    # MSE loss is used here as an example
    loss_hwl_op = tf.losses.mean_squared_error(labels_hwl, logits_hwl)  # MSE loss provided by the TF API
    # RMSE is computed as an evaluation metric
    rmse_hwl_op = tf.sqrt(tf.reduce_mean(tf.reduce_sum(tf.square(labels_hwl - logits_hwl), -1)))
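    # Only the MSE above is optimized; the RMSE (root of the mean per-box squared Euclidean error
    # over the h/w/l components) is just logged as an evaluation metric.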

    # XYZ (box center) loss
    # MSE loss is used here as an example
    loss_xyz_op = tf.losses.mean_squared_error(labels_xyz, logits_xyz)
    # RMSE is computed as an evaluation metric
    rmse_xyz_op = tf.sqrt(tf.reduce_mean(tf.reduce_sum(tf.square(labels_xyz - logits_xyz), -1)))

    # RY (rotation around Y) loss
    # The classification uses softmax cross entropy as an example
    loss_ry_op = tf.losses.sparse_softmax_cross_entropy(labels=labels_ry, logits=logits_ry)
    # Classification accuracy is computed as an evaluation metric
    t_1_acc_ry_op = pf.top_1_accuracy(probs_ry, labels_ry)

    # Mean angular deviation is computed as an evaluation metric
    mdis_ry_angle_op = tf.reduce_mean(tf.sqrt(tf.square(
        pf.ry_getangle(labels_ry, labels_ry_reg) -
        pf.ry_getangle(tf.nn.top_k(probs_ry, 1)[1], tf.zeros_like(labels_ry_reg)))))
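    # pf.ry_getangle presumably maps a rotation bin (plus an in-bin regression residual) back to an
    # angle; since sqrt(square(x)) == |x|, this is the mean absolute angular difference between the
    # ground-truth angle and the angle of the predicted bin taken with a zero residual.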

    # Visualize the results in TensorBoard
    _ = tf.summary.scalar('loss/train_hwl', tensor=loss_hwl_op, collections=['train'])
    _ = tf.summary.scalar('loss/train_xyz', tensor=loss_xyz_op, collections=['train'])
    _ = tf.summary.scalar('loss/train_ry', tensor=loss_ry_op, collections=['train'])
    _ = tf.summary.scalar('t_1_acc/train_rmse_hwl', tensor=rmse_hwl_op, collections=['train'])
    _ = tf.summary.scalar('t_1_acc/train_rmse_xyz', tensor=rmse_xyz_op, collections=['train'])
    _ = tf.summary.scalar('t_1_acc/train_ry', tensor=t_1_acc_ry_op, collections=['train'])
    _ = tf.summary.scalar('t_1_acc/mdis_ry_angle', tensor=mdis_ry_angle_op, collections=['train'])

    # top-1 segmentation accuracy (for TensorBoard)
    t_1_acc_op = pf.top_1_accuracy(probs, labels_sampled)
    _ = tf.summary.scalar('t_1_acc/train_seg', tensor=t_1_acc_op, collections=['train'])
    # top-1 accuracy on instance points
    t_1_acc_instance_op = pf.top_1_accuracy(probs, labels_sampled, labels_weights_sampled, 0.6)
    _ = tf.summary.scalar('t_1_acc/train_seg_instance', tensor=t_1_acc_instance_op, collections=['train'])
    # top-1 accuracy on the remaining (non-instance) points
    t_1_acc_others_op = pf.top_1_accuracy(probs, labels_sampled, labels_weights_sampled, 0.6,"less")
    _ = tf.summary.scalar('t_1_acc/train_seg_others', tensor=t_1_acc_others_op, collections=['train'])
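    # Note: the 0.6 argument to pf.top_1_accuracy presumably thresholds the per-point label weights,
    # so "instance" accuracy covers points with weight >= 0.6 and "others" ("less") the rest.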

    loss_val_avg = tf.placeholder(tf.float32)
    _ = tf.summary.scalar('loss/val_seg', tensor=loss_val_avg, collections=['val'])

    t_1_acc_val_avg = tf.placeholder(tf.float32)
    _ = tf.summary.scalar('t_1_acc/val_seg', tensor=t_1_acc_val_avg, collections=['val'])
    t_1_acc_val_instance_avg = tf.placeholder(tf.float32)
    _ = tf.summary.scalar('t_1_acc/val_seg_instance', tensor=t_1_acc_val_instance_avg, collections=['val'])
    t_1_acc_val_others_avg = tf.placeholder(tf.float32)
    _ = tf.summary.scalar('t_1_acc/val_seg_others', tensor=t_1_acc_val_others_avg, collections=['val'])

    reg_loss = setting.weight_decay * tf.losses.get_regularization_loss()

    # For validation, the average error over the whole validation pass is used as the evaluation metric
    loss_val_avg_hwl = tf.placeholder(tf.float32)
    loss_val_avg_xyz = tf.placeholder(tf.float32)
    loss_val_avg_ry = tf.placeholder(tf.float32)
    _ = tf.summary.scalar('loss/val_hwl', tensor=loss_val_avg_hwl, collections=['val'])
    _ = tf.summary.scalar('loss/val_xyz', tensor=loss_val_avg_xyz, collections=['val'])
    _ = tf.summary.scalar('loss/val_ry', tensor=loss_val_avg_ry, collections=['val'])
    rmse_val_hwl_avg = tf.placeholder(tf.float32)
    rmse_val_xyz_avg = tf.placeholder(tf.float32)
    t_1_acc_val_ry_avg = tf.placeholder(tf.float32)
    mdis_val_ry_angle_avg = tf.placeholder(tf.float32)
    _ = tf.summary.scalar('t_1_acc/val_rmse_hwl', tensor=rmse_val_hwl_avg, collections=['val'])
    _ = tf.summary.scalar('t_1_acc/val_rmse_xyz', tensor=rmse_val_xyz_avg, collections=['val'])
    _ = tf.summary.scalar('t_1_acc/val_ry', tensor=t_1_acc_val_ry_avg, collections=['val'])
    _ = tf.summary.scalar('t_1_acc/val_mdis_ry_angle', tensor=mdis_val_ry_angle_avg, collections=['val'])

    # lr decay
    lr_exp_op = tf.train.exponential_decay(setting.learning_rate_base, global_step,
                                           setting.decay_steps, setting.decay_rate, staircase=True)
    lr_clip_op = tf.maximum(lr_exp_op, setting.learning_rate_min)
    _ = tf.summary.scalar('learning_rate', tensor=lr_clip_op, collections=['train'])
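    # With staircase=True the effective rate is
    #   lr = max(learning_rate_base * decay_rate ** (global_step // decay_steps), learning_rate_min)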

    # Optimizer
    if setting.optimizer == 'adam':
        optimizer = tf.train.AdamOptimizer(learning_rate=lr_clip_op, epsilon=setting.epsilon)
    elif setting.optimizer == 'momentum':
        optimizer = tf.train.MomentumOptimizer(learning_rate=lr_clip_op, momentum=0.9, use_nesterov=True)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

    with tf.control_dependencies(update_ops):

        #train_op = optimizer.minimize(loss_op + reg_loss, global_step=global_step)
        train_op = optimizer.minimize(loss_op + reg_loss + loss_hwl_op + loss_xyz_op + loss_ry_op, global_step=global_step)
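        # The five loss terms are summed with equal (unit) weights; introducing per-task weights
        # would be a possible refinement, but none are used here.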

    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())

    saver = tf.train.Saver(max_to_keep=None)

    # backup this file, model and setting
    if not os.path.exists(os.path.join(root_folder, args.model)):
        os.makedirs(os.path.join(root_folder, args.model))
    shutil.copy(__file__, os.path.join(root_folder, os.path.basename(__file__)))
    shutil.copy(os.path.join(os.path.dirname(__file__), args.model + '.py'),
                os.path.join(root_folder, args.model + '.py'))
    shutil.copy(os.path.join(os.path.dirname(__file__), args.model.split("_")[0] + "_kitti" + '.py'),
                os.path.join(root_folder, args.model.split("_")[0] + "_kitti" + '.py'))
    shutil.copy(os.path.join(setting_path, args.setting + '.py'),
                os.path.join(root_folder, args.model, args.setting + '.py'))

    folder_ckpt = os.path.join(root_folder, 'ckpts')
    if not os.path.exists(folder_ckpt):
        os.makedirs(folder_ckpt)

    folder_summary = os.path.join(root_folder, 'summary')
    if not os.path.exists(folder_summary):
        os.makedirs(folder_summary)

    parameter_num = np.sum([np.prod(v.shape.as_list()) for v in tf.trainable_variables()])
    print('{}-Parameter number: {:d}.'.format(datetime.now(), parameter_num))

    # Session Run
    with tf.Session() as sess:
        summaries_op = tf.summary.merge_all('train')
        summaries_val_op = tf.summary.merge_all('val')
        summary_writer = tf.summary.FileWriter(folder_summary, sess.graph)

        sess.run(init_op)

        # Load the model
        if args.load_ckpt is not None:
            saver.restore(sess, args.load_ckpt)
            print('{}-Checkpoint loaded from {}!'.format(datetime.now(), args.load_ckpt))

        for batch_idx in range(batch_num):
            if (batch_idx != 0 and batch_idx % step_val == 0) or batch_idx == batch_num - 1:
                ######################################################################
                # Validation
                filename_ckpt = os.path.join(folder_ckpt, 'iter')
                saver.save(sess, filename_ckpt, global_step=global_step)
                print('{}-Checkpoint saved to {}!'.format(datetime.now(), filename_ckpt))

                losses_val = []
                t_1_accs = []
                t_1_accs_instance = []
                t_1_accs_others = []

                # Lists used to accumulate the extra metrics for averaging
                losses_val_hwl = []
                losses_val_xyz = []
                losses_val_ry = []
                rmses_hwl = []
                rmses_xyz = []
                t_1_accs_ry = []
                mdises_ry_angle = []


                for batch_val_idx in range(math.ceil(num_val / batch_size)):
                    start_idx = batch_size * batch_val_idx
                    end_idx = min(start_idx + batch_size, num_val)
                    batch_size_val = end_idx - start_idx
                    points_batch = data_val[start_idx:end_idx, ...]
                    points_num_batch = data_num_val[start_idx:end_idx, ...]
                    labels_batch = label_val[start_idx:end_idx, ...]
                    weights_batch = np.array(label_weights_list)[label_val[start_idx:end_idx, ...]]

                    # Slice the extra labels for this batch
                    labels_hwl_batch = label_hwl_val[start_idx:end_idx, ...]
                    labels_xyz_batch = label_xyz_val[start_idx:end_idx, ...]
                    labels_ry_batch = label_ry_val[start_idx:end_idx, ...]
                    labels_ry_reg_batch = label_ry_reg_val[start_idx:end_idx, ...]

                    xforms_np, rotations_np = pf.get_xforms(batch_size_val, scaling_range=scaling_range_val)

                    sess_op_list = [loss_op, t_1_acc_op, t_1_acc_instance_op, t_1_acc_others_op,
                                    loss_hwl_op, loss_xyz_op, loss_ry_op, rmse_hwl_op, rmse_xyz_op,
                                    t_1_acc_ry_op, mdis_ry_angle_op]

                    sess_feed_dict = {pts_fts: points_batch,
                                     indices: pf.get_indices(batch_size_val, sample_num, points_num_batch),
                                     xforms: xforms_np,
                                     rotations: rotations_np,
                                     jitter_range: np.array([jitter_val]),
                                     labels_seg: labels_batch,
                                     labels_weights: weights_batch,
                                     is_training: False}
                    sess_feed_dict[labels_hwl] = labels_hwl_batch
                    sess_feed_dict[labels_xyz] = labels_xyz_batch
                    sess_feed_dict[labels_ry] = labels_ry_batch
                    sess_feed_dict[labels_ry_reg] = labels_ry_reg_batch

                    #loss_val, t_1_acc_val, t_1_acc_val_instance, t_1_acc_val_others = sess.run(sess_op_list,feed_dict=sess_feed_dict)
                    loss_val, t_1_acc_val, t_1_acc_val_instance, t_1_acc_val_others, loss_val_hwl, \
                        loss_val_xyz, loss_val_ry, rmse_hwl_val, rmse_xyz_val, t_1_acc_ry_val, \
                        mdis_ry_angle_val = sess.run(sess_op_list, feed_dict=sess_feed_dict)

                    #print('{}-[Val  ]-Iter: {:06d}  Loss: {:.4f} T-1 Acc: {:.4f}'.format(datetime.now(), batch_val_idx, loss_val, t_1_acc_val))
                    print('{}-[Val  ]-Iter: {:06d}  Loss: {:.4f} Loss_hwl: {:.4f} Loss_xyz: {:.4f} '
                          'Loss_ry: {:.4f} Rmse_hwl: {:.4f} Rmse_xyz: {:.4f} T-1 Ry Acc: {:.4f} '
                          'Mdis_ry_angle: {:.4f} T-1 Acc: {:.4f}'.format(
                              datetime.now(), batch_val_idx, loss_val, loss_val_hwl, loss_val_xyz,
                              loss_val_ry, rmse_hwl_val, rmse_xyz_val, t_1_acc_ry_val,
                              mdis_ry_angle_val, t_1_acc_val))

                    sys.stdout.flush()

                    losses_val.append(loss_val * batch_size_val)
                    t_1_accs.append(t_1_acc_val * batch_size_val)
                    t_1_accs_instance.append(t_1_acc_val_instance * batch_size_val)
                    t_1_accs_others.append(t_1_acc_val_others * batch_size_val)

                    # Accumulate the per-iteration evaluation metrics
                    losses_val_hwl.append(loss_val_hwl * batch_size_val)
                    losses_val_xyz.append(loss_val_xyz * batch_size_val)
                    losses_val_ry.append(loss_val_ry * batch_size_val)
                    rmses_hwl.append(rmse_hwl_val * batch_size_val)
                    rmses_xyz.append(rmse_xyz_val * batch_size_val)
                    t_1_accs_ry.append(t_1_acc_ry_val * batch_size_val)
                    mdises_ry_angle.append(mdis_ry_angle_val * batch_size_val)
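                    # Each metric is weighted by batch_size_val so that dividing the sums by
                    # num_val below yields per-sample averages even when the last batch is smaller.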


                loss_avg = sum(losses_val) / num_val
                t_1_acc_avg = sum(t_1_accs) / num_val
                t_1_acc_instance_avg = sum(t_1_accs_instance) / num_val
                t_1_acc_others_avg = sum(t_1_accs_others) / num_val
                loss_avg_hwl = sum(losses_val_hwl) / num_val
                loss_avg_xyz = sum(losses_val_xyz) / num_val
                loss_avg_ry = sum(losses_val_ry) / num_val
                rmse_hwl_avg = sum(rmses_hwl) / num_val
                rmse_xyz_avg = sum(rmses_xyz) / num_val
                t_1_acc_ry_avg = sum(t_1_accs_ry) / num_val
                mdis_ry_angle_avg = sum(mdises_ry_angle) / num_val

                summaries_feed_dict = { loss_val_avg: loss_avg,
                                        t_1_acc_val_avg: t_1_acc_avg,
                                        t_1_acc_val_instance_avg: t_1_acc_instance_avg,
                                        t_1_acc_val_others_avg: t_1_acc_others_avg}
                summaries_feed_dict[loss_val_avg_hwl] = loss_avg_hwl
                summaries_feed_dict[loss_val_avg_xyz] = loss_avg_xyz
                summaries_feed_dict[loss_val_avg_ry] = loss_avg_ry
                summaries_feed_dict[rmse_val_hwl_avg] = rmse_hwl_avg
                summaries_feed_dict[rmse_val_xyz_avg] = rmse_xyz_avg
                summaries_feed_dict[t_1_acc_val_ry_avg] = t_1_acc_ry_avg
                summaries_feed_dict[mdis_val_ry_angle_avg] = mdis_ry_angle_avg

                summaries_val = sess.run(summaries_val_op,feed_dict=summaries_feed_dict)
                summary_writer.add_summary(summaries_val, batch_idx)

                #print('{}-[Val  ]-Average:      Loss: {:.4f} T-1 Acc: {:.4f}'
                #    .format(datetime.now(), loss_avg, t_1_acc_avg))
                print('{}-[Val  ]-Average:      Loss: {:.4f} Loss_hwl: {:.4f} Loss_xyz: {:.4f} '
                      'Loss_ry: {:.4f} Rmse_hwl: {:.4f} Rmse_xyz: {:.4f} T-1 Ry Acc: {:.4f} '
                      'Mdis_ry_angle: {:.4f} T-1 Acc: {:.4f}'.format(
                          datetime.now(), loss_avg, loss_avg_hwl, loss_avg_xyz, loss_avg_ry,
                          rmse_hwl_avg, rmse_xyz_avg, t_1_acc_ry_avg, mdis_ry_angle_avg, t_1_acc_avg))

                sys.stdout.flush()
                ######################################################################

            ######################################################################
            # Training
            start_idx = (batch_size * batch_idx) % num_train
            end_idx = min(start_idx + batch_size, num_train)
            batch_size_train = end_idx - start_idx
            points_batch = data_train[start_idx:end_idx, ...]
            
            points_num_batch = data_num_train[start_idx:end_idx, ...]
            labels_batch = label_train[start_idx:end_idx, ...]
            weights_batch = np.array(label_weights_list)[labels_batch]

            labels_hwl_batch = label_hwl_train[start_idx:end_idx, ...]
            labels_xyz_batch = label_xyz_train[start_idx:end_idx, ...]
            labels_ry_batch = label_ry_train[start_idx:end_idx, ...]
            labels_ry_reg_batch = label_ry_reg_train[start_idx:end_idx, ...]
                
            if start_idx + batch_size_train == num_train:
                #data_train, data_num_train, label_train = \
                #    data_utils.grouped_shuffle([data_train, data_num_train, label_train])
                data_train, img_train, data_num_train, label_train, label_hwl_train, label_xyz_train, \
                    label_ry_train, label_ry_reg_train = data_utils.grouped_shuffle(
                        [data_train, img_train, data_num_train, label_train, label_hwl_train,
                         label_xyz_train, label_ry_train, label_ry_reg_train])

            offset = int(random.gauss(0, sample_num // 8))
            offset = max(offset, -sample_num // 4)
            offset = min(offset, sample_num // 4)
            sample_num_train = sample_num + offset
            xforms_np, rotations_np = pf.get_xforms(batch_size_train, scaling_range=scaling_range)

            sess_op_list = [train_op, loss_op, t_1_acc_op, t_1_acc_instance_op, t_1_acc_others_op,
                            summaries_op, loss_hwl_op, loss_xyz_op, loss_ry_op, rmse_hwl_op,
                            rmse_xyz_op, t_1_acc_ry_op, mdis_ry_angle_op]

            sess_feed_dict = {pts_fts: points_batch,
                             indices: pf.get_indices(batch_size_train, sample_num_train, points_num_batch),
                             xforms: xforms_np,
                             rotations: rotations_np,
                             jitter_range: np.array([jitter]),
                             labels_seg: labels_batch,
                             labels_weights: weights_batch,
                             is_training: True}

            # Feed the extra labels to their placeholders
            sess_feed_dict[labels_hwl] = labels_hwl_batch
            sess_feed_dict[labels_xyz] = labels_xyz_batch
            sess_feed_dict[labels_ry] = labels_ry_batch
            sess_feed_dict[labels_ry_reg] = labels_ry_reg_batch

            #_, loss, t_1_acc, t_1_acc_instance, t_1_acc_others, summaries = sess.run(sess_op_list,feed_dict=sess_feed_dict)
            _, loss, t_1_acc, t_1_acc_instance, t_1_acc_others, summaries, loss_hwl, loss_xyz, \
                loss_ry, rmse_hwl, rmse_xyz, t_1_acc_ry, mdis_ry_angle = sess.run(
                    sess_op_list, feed_dict=sess_feed_dict)

            #print('{}-[Train]-Iter: {:06d}  Loss_seg: {:.4f} T-1 Acc: {:.4f}'
            #  .format(datetime.now(), batch_idx, loss, t_1_acc))
            print('{}-[Train]-Iter: {:06d}  Loss_seg: {:.4f} Loss_hwl: {:.4f} Loss_xyz: {:.4f} '
                  'Loss_ry: {:.4f} Rmse_hwl: {:.4f} Rmse_xyz: {:.4f} T-1 Ry Acc: {:.4f} '
                  'Mdis_ry_angle: {:.4f} T-1 Acc: {:.4f}'.format(
                      datetime.now(), batch_idx, loss, loss_hwl, loss_xyz, loss_ry, rmse_hwl,
                      rmse_xyz, t_1_acc_ry, mdis_ry_angle, t_1_acc))


            summary_writer.add_summary(summaries, batch_idx)
            sys.stdout.flush()
            
            ######################################################################
        print('{}-Done!'.format(datetime.now()))