def main():
    """Train a point-cloud classification network and validate it periodically.

    Command-line arguments select the data paths, the model module and the
    setting module (both loaded with ``importlib``), the output folder and
    optional overrides for the epoch count and batch size.  The function
    builds a TF1 graph (feedable ``tf.data`` iterators, sampling,
    augmentation, network, loss, streaming metrics, optimizer), then runs
    the training loop.  Every ``setting.step_val`` iterations, and on the
    last iteration, it saves a checkpoint and evaluates the validation set.

    Raises:
        ValueError: if ``setting.optimizer`` is neither ``'adam'`` nor
            ``'momentum'``.
        SystemExit: on argument errors, or when normal features are
            requested with ``setting.data_dim < 6``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--path', '-t', help='Path to data', required=True)
    parser.add_argument('--path_val', '-v', help='Path to validation data')
    parser.add_argument('--load_ckpt', '-l',
                        help='Path to a check point file for load')
    parser.add_argument(
        '--save_folder', '-s',
        help='Path to folder for saving check points and summary',
        required=True)
    parser.add_argument('--model', '-m', help='Model to use', required=True)
    parser.add_argument('--setting', '-x', help='Setting to use',
                        required=True)
    parser.add_argument(
        '--epochs',
        help='Number of training epochs (default defined in setting)',
        type=int)
    parser.add_argument('--batch_size',
                        help='Batch size (default defined in setting)',
                        type=int)
    parser.add_argument(
        '--log',
        help='Log to FILE in save folder; use - for stdout (default is log.txt)',
        metavar='FILE',
        default='log.txt')
    parser.add_argument('--no_timestamp_folder',
                        help='Dont save to timestamp folder',
                        action='store_true')
    parser.add_argument('--no_code_backup',
                        help='Dont backup code',
                        action='store_true')
    args = parser.parse_args()

    if not args.no_timestamp_folder:
        time_string = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
        root_folder = os.path.join(
            args.save_folder,
            '%s_%s_%s_%d' % (args.model, args.setting, time_string,
                             os.getpid()))
    else:
        root_folder = args.save_folder
    os.makedirs(root_folder, exist_ok=True)

    if args.log != '-':
        # The log file replaces sys.stdout for the rest of the run, so it is
        # not closed here.
        sys.stdout = open(os.path.join(root_folder, args.log), 'w')

    print('PID:', os.getpid())
    print(args)

    model = importlib.import_module(args.model)
    # Setting modules live in a folder named after the model, next to this
    # script; make that folder importable before loading the setting.
    setting_path = os.path.join(os.path.dirname(__file__), args.model)
    sys.path.append(setting_path)
    setting = importlib.import_module(args.setting)

    # Command-line values win; 0/None fall back to the setting's defaults.
    num_epochs = args.epochs or setting.num_epochs
    batch_size = args.batch_size or setting.batch_size
    sample_num = setting.sample_num
    step_val = setting.step_val
    rotation_range = setting.rotation_range
    rotation_range_val = setting.rotation_range_val
    scaling_range = setting.scaling_range
    scaling_range_val = setting.scaling_range_val
    jitter = setting.jitter
    jitter_val = setting.jitter_val
    # Optional settings: None when the setting module does not define them.
    pool_setting_val = getattr(setting, 'pool_setting_val', None)
    pool_setting_train = getattr(setting, 'pool_setting_train', None)

    # Prepare inputs
    print('{}-Preparing datasets...'.format(datetime.now()))
    data_train, label_train, data_val, label_val = setting.load_fn(
        args.path, args.path_val)
    if setting.balance_fn is not None:
        num_train_before_balance = data_train.shape[0]
        repeat_num = setting.balance_fn(label_train)
        data_train = np.repeat(data_train, repeat_num, axis=0)
        label_train = np.repeat(label_train, repeat_num, axis=0)
        data_train, label_train = data_utils.grouped_shuffle(
            [data_train, label_train])
        # Balancing changes the number of samples per epoch; scale the epoch
        # count by the old/new sample ratio to compensate.
        num_epochs = math.floor(
            num_epochs * (num_train_before_balance / data_train.shape[0]))

    if setting.save_ply_fn is not None:
        folder = os.path.join(root_folder, 'pts')
        print('{}-Saving samples as .ply files to {}...'.format(
            datetime.now(), folder))
        sample_num_for_ply = min(512, data_train.shape[0])
        if setting.map_fn is None:
            data_sample = data_train[:sample_num_for_ply]
        else:
            data_sample = np.stack([
                setting.map_fn(data_train[idx], 0)[0]
                for idx in range(sample_num_for_ply)
            ])
        setting.save_ply_fn(data_sample, folder)

    num_train = data_train.shape[0]
    point_num = data_train.shape[1]
    num_val = data_val.shape[0]
    print('{}-{:d}/{:d} training/validation samples.'.format(
        datetime.now(), num_train, num_val))

    ######################################################################
    # Placeholders
    indices = tf.placeholder(tf.int32, shape=(None, None, 2), name="indices")
    xforms = tf.placeholder(tf.float32, shape=(None, 3, 3), name="xforms")
    rotations = tf.placeholder(tf.float32, shape=(None, 3, 3),
                               name="rotations")
    jitter_range = tf.placeholder(tf.float32, shape=(1), name="jitter_range")
    global_step = tf.Variable(0, trainable=False, name='global_step')
    is_training = tf.placeholder(tf.bool, name='is_training')

    data_train_placeholder = tf.placeholder(data_train.dtype,
                                            data_train.shape,
                                            name='data_train')
    label_train_placeholder = tf.placeholder(tf.int64,
                                             label_train.shape,
                                             name='label_train')
    data_val_placeholder = tf.placeholder(data_val.dtype,
                                          data_val.shape,
                                          name='data_val')
    label_val_placeholder = tf.placeholder(tf.int64,
                                           label_val.shape,
                                           name='label_val')
    handle = tf.placeholder(tf.string, shape=[], name='handle')

    ######################################################################
    dataset_train = tf.data.Dataset.from_tensor_slices(
        (data_train_placeholder, label_train_placeholder))
    dataset_train = dataset_train.shuffle(buffer_size=batch_size * 4)

    if setting.map_fn is not None:
        # Original note: the Quick Draw setting uses map_fn.
        dataset_train = dataset_train.map(
            lambda data, label: tuple(
                tf.py_func(setting.map_fn, [data, label],
                           [tf.float32, label.dtype])),
            num_parallel_calls=setting.num_parallel_calls)

    if setting.keep_remainder:
        dataset_train = dataset_train.batch(batch_size)
        batch_num_per_epoch = math.ceil(num_train / batch_size)
    else:
        dataset_train = dataset_train.apply(
            tf.contrib.data.batch_and_drop_remainder(batch_size))
        batch_num_per_epoch = math.floor(num_train / batch_size)
    dataset_train = dataset_train.repeat(num_epochs)
    iterator_train = dataset_train.make_initializable_iterator()
    batch_num = batch_num_per_epoch * num_epochs
    print('{}-{:d} training batches.'.format(datetime.now(), batch_num))

    dataset_val = tf.data.Dataset.from_tensor_slices(
        (data_val_placeholder, label_val_placeholder))
    if setting.map_fn is not None:
        dataset_val = dataset_val.map(
            lambda data, label: tuple(
                tf.py_func(setting.map_fn, [data, label],
                           [tf.float32, label.dtype])),
            num_parallel_calls=setting.num_parallel_calls)

    if setting.keep_remainder:
        dataset_val = dataset_val.batch(batch_size)
        # Number of validation batches, counting the final partial one.
        batch_num_val = math.ceil(num_val / batch_size)
    else:
        dataset_val = dataset_val.apply(
            tf.contrib.data.batch_and_drop_remainder(batch_size))
        batch_num_val = math.floor(num_val / batch_size)
    iterator_val = dataset_val.make_initializable_iterator()
    print('{}-{:d} testing batches per test.'.format(datetime.now(),
                                                     batch_num_val))

    # Feedable iterator: the `handle` placeholder selects the train or the
    # validation iterator at run time.
    iterator = tf.data.Iterator.from_string_handle(
        handle, dataset_train.output_types)
    (pts_fts, labels) = iterator.get_next()

    # Original note: gather_nd is differentiable.
    pts_fts_sampled = tf.gather_nd(pts_fts,
                                   indices=indices,
                                   name='pts_fts_sampled')
    features_augmented = None
    if setting.data_dim > 3:
        # The first three channels are point coordinates, the rest features.
        points_sampled, features_sampled = tf.split(
            pts_fts_sampled, [3, setting.data_dim - 3],
            axis=-1,
            name='split_points_features')
        if setting.use_extra_features:
            if setting.with_normal_feature:
                if setting.data_dim < 6:
                    print('Only 3D normals are supported!')
                    sys.exit(1)
                elif setting.data_dim == 6:
                    features_augmented = pf.augment(features_sampled,
                                                    rotations)
                else:
                    # Split along the channel axis (tf.split defaults to
                    # axis 0), matching the concat on axis=-1 below.
                    normals, rest = tf.split(features_sampled,
                                             [3, setting.data_dim - 6],
                                             axis=-1)
                    normals_augmented = pf.augment(normals, rotations)
                    features_augmented = tf.concat([normals_augmented, rest],
                                                   axis=-1)
            else:
                features_augmented = features_sampled
    else:
        points_sampled = pts_fts_sampled
    # Points are augmented with the fed transforms and the jitter range.
    points_augmented = pf.augment(points_sampled, xforms, jitter_range)

    net = model.Net(points=points_augmented,
                    features=features_augmented,
                    is_training=is_training,
                    setting=setting)
    logits = net.logits
    probs = tf.nn.softmax(logits, name='probs')
    predictions = tf.argmax(probs, axis=-1, name='predictions')

    # Repeat each sample's label along the second dimension of the logits.
    labels_2d = tf.expand_dims(labels, axis=-1, name='labels_2d')
    labels_tile = tf.tile(labels_2d, (1, tf.shape(logits)[1]),
                          name='labels_tile')
    loss_op = tf.losses.sparse_softmax_cross_entropy(labels=labels_tile,
                                                     logits=logits)

    with tf.name_scope('metrics'):
        loss_mean_op, loss_mean_update_op = tf.metrics.mean(loss_op)
        t_1_acc_op, t_1_acc_update_op = tf.metrics.accuracy(
            labels_tile, predictions)
        t_1_per_class_acc_op, t_1_per_class_acc_update_op = \
            tf.metrics.mean_per_class_accuracy(labels_tile, predictions,
                                               setting.num_class)
    # Re-initialising the local variables under 'metrics' resets the
    # streaming metrics.
    reset_metrics_op = tf.variables_initializer([
        var for var in tf.local_variables()
        if var.name.split('/')[0] == 'metrics'
    ])

    _ = tf.summary.scalar('loss/train', tensor=loss_mean_op,
                          collections=['train'])
    _ = tf.summary.scalar('t_1_acc/train', tensor=t_1_acc_op,
                          collections=['train'])
    _ = tf.summary.scalar('t_1_per_class_acc/train',
                          tensor=t_1_per_class_acc_op,
                          collections=['train'])

    _ = tf.summary.scalar('loss/val', tensor=loss_mean_op,
                          collections=['val'])
    _ = tf.summary.scalar('t_1_acc/val', tensor=t_1_acc_op,
                          collections=['val'])
    _ = tf.summary.scalar('t_1_per_class_acc/val',
                          tensor=t_1_per_class_acc_op,
                          collections=['val'])

    lr_exp_op = tf.train.exponential_decay(setting.learning_rate_base,
                                           global_step,
                                           setting.decay_steps,
                                           setting.decay_rate,
                                           staircase=True)
    lr_clip_op = tf.maximum(lr_exp_op, setting.learning_rate_min)
    _ = tf.summary.scalar('learning_rate', tensor=lr_clip_op,
                          collections=['train'])
    reg_loss = setting.weight_decay * tf.losses.get_regularization_loss()
    if setting.optimizer == 'adam':
        optimizer = tf.train.AdamOptimizer(learning_rate=lr_clip_op,
                                           epsilon=setting.epsilon)
    elif setting.optimizer == 'momentum':
        optimizer = tf.train.MomentumOptimizer(learning_rate=lr_clip_op,
                                               momentum=setting.momentum,
                                               use_nesterov=True)
    else:
        # Fail here with a clear message rather than later with an unbound
        # local variable.
        raise ValueError(
            'Unsupported optimizer: {!r}'.format(setting.optimizer))
    # Ops in UPDATE_OPS must run together with each training step.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = optimizer.minimize(loss_op + reg_loss,
                                      global_step=global_step)

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())

    saver = tf.train.Saver(max_to_keep=None)

    # backup all code
    if not args.no_code_backup:
        code_folder = os.path.abspath(os.path.dirname(__file__))
        shutil.copytree(
            code_folder,
            os.path.join(root_folder, os.path.basename(code_folder)))

    folder_ckpt = os.path.join(root_folder, 'ckpts')
    os.makedirs(folder_ckpt, exist_ok=True)

    folder_summary = os.path.join(root_folder, 'summary')
    os.makedirs(folder_summary, exist_ok=True)

    parameter_num = np.sum(
        [np.prod(v.shape.as_list()) for v in tf.trainable_variables()])
    print('{}-Parameter number: {:d}.'.format(datetime.now(), parameter_num))

    with tf.Session() as sess:
        summaries_op = tf.summary.merge_all('train')
        summaries_val_op = tf.summary.merge_all('val')
        summary_writer = tf.summary.FileWriter(folder_summary, sess.graph)

        sess.run(init_op)

        # Load the model: an explicit checkpoint wins; otherwise resume from
        # the newest checkpoint in this run's checkpoint folder, if any.
        if args.load_ckpt is not None:
            saver.restore(sess, args.load_ckpt)
            print('{}-Checkpoint loaded from {}!'.format(
                datetime.now(), args.load_ckpt))
        else:
            latest_ckpt = tf.train.latest_checkpoint(folder_ckpt)
            if latest_ckpt:
                print('{}-Found checkpoint {}'.format(datetime.now(),
                                                      latest_ckpt))
                saver.restore(sess, latest_ckpt)
                print('{}-Checkpoint loaded from {} (Iter {})'.format(
                    datetime.now(), latest_ckpt, sess.run(global_step)))

        handle_train = sess.run(iterator_train.string_handle())
        handle_val = sess.run(iterator_val.string_handle())

        sess.run(iterator_train.initializer,
                 feed_dict={
                     data_train_placeholder: data_train,
                     label_train_placeholder: label_train,
                 })

        for batch_idx_train in range(batch_num):
            ##################################################################
            # Validation
            if (batch_idx_train % step_val == 0 and
                    (batch_idx_train != 0 or args.load_ckpt is not None)) \
                    or batch_idx_train == batch_num - 1:
                sess.run(iterator_val.initializer,
                         feed_dict={
                             data_val_placeholder: data_val,
                             label_val_placeholder: label_val,
                         })
                filename_ckpt = os.path.join(folder_ckpt, 'iter')
                saver.save(sess, filename_ckpt, global_step=global_step)
                print('{}-Checkpoint saved to {}!'.format(
                    datetime.now(), filename_ckpt))

                sess.run(reset_metrics_op)
                for batch_idx_val in range(batch_num_val):
                    # Only the last batch can be smaller, and only when the
                    # remainder is kept.
                    if not setting.keep_remainder \
                            or num_val % batch_size == 0 \
                            or batch_idx_val != batch_num_val - 1:
                        batch_size_val = batch_size
                    else:
                        batch_size_val = num_val % batch_size
                    # One transform per sample in this batch.
                    xforms_np, rotations_np = pf.get_xforms(
                        batch_size_val,
                        rotation_range=rotation_range_val,
                        scaling_range=scaling_range_val,
                        order=setting.rotation_order)
                    sess.run(
                        [
                            loss_mean_update_op, t_1_acc_update_op,
                            t_1_per_class_acc_update_op
                        ],
                        feed_dict={
                            handle: handle_val,
                            # Pass the validation pool setting, mirroring
                            # the training call (None when not configured).
                            indices: pf.get_indices(batch_size_val,
                                                    sample_num, point_num,
                                                    pool_setting_val),
                            xforms: xforms_np,
                            rotations: rotations_np,
                            jitter_range: np.array([jitter_val]),
                            is_training: False,
                        })
                loss_val, t_1_acc_val, t_1_per_class_acc_val, \
                    summaries_val, step = sess.run([
                        loss_mean_op, t_1_acc_op, t_1_per_class_acc_op,
                        summaries_val_op, global_step
                    ])
                summary_writer.add_summary(summaries_val, step)
                print(
                    '{}-[Val ]-Average: Loss: {:.4f} T-1 Acc: {:.4f} T-1 mAcc: {:.4f}'
                    .format(datetime.now(), loss_val, t_1_acc_val,
                            t_1_per_class_acc_val))
                sys.stdout.flush()
            ##################################################################

            ##################################################################
            # Training
            if not setting.keep_remainder \
                    or num_train % batch_size == 0 \
                    or (batch_idx_train % batch_num_per_epoch) != \
                    (batch_num_per_epoch - 1):
                batch_size_train = batch_size
            else:
                batch_size_train = num_train % batch_size

            # Randomise the number of sampled points per step: a Gaussian
            # offset around sample_num, clipped to +/- sample_num *
            # sample_num_clip.  (Original note: the authors found that
            # training with a varying sample count works better.)
            offset = int(
                random.gauss(0, sample_num * setting.sample_num_variance))
            offset = max(offset, -sample_num * setting.sample_num_clip)
            offset = min(offset, sample_num * setting.sample_num_clip)
            sample_num_train = sample_num + offset
            xforms_np, rotations_np = pf.get_xforms(
                batch_size_train,
                rotation_range=rotation_range,
                scaling_range=scaling_range,
                order=setting.rotation_order)
            # Metrics are reset before every step, so training metrics
            # describe the current batch only.
            sess.run(reset_metrics_op)
            sess.run(
                [
                    train_op, loss_mean_update_op, t_1_acc_update_op,
                    t_1_per_class_acc_update_op
                ],
                feed_dict={
                    handle: handle_train,
                    indices: pf.get_indices(batch_size_train,
                                            sample_num_train, point_num,
                                            pool_setting_train),
                    xforms: xforms_np,
                    rotations: rotations_np,
                    jitter_range: np.array([jitter]),
                    is_training: True,
                })
            if batch_idx_train % 10 == 0:
                loss, t_1_acc, t_1_per_class_acc, summaries, step = sess.run([
                    loss_mean_op, t_1_acc_op, t_1_per_class_acc_op,
                    summaries_op, global_step
                ])
                summary_writer.add_summary(summaries, step)
                print(
                    '{}-[Train]-Iter: {:06d} Loss: {:.4f} T-1 Acc: {:.4f} T-1 mAcc: {:.4f}'
                    .format(datetime.now(), step, loss, t_1_acc,
                            t_1_per_class_acc))
                sys.stdout.flush()
            ##################################################################
        # Flush pending summary events to disk before the session ends.
        summary_writer.close()
        print('{}-Done!'.format(datetime.now()))
def main():
    """Train a point-cloud segmentation network and validate it periodically.

    Command-line arguments select the training/validation file lists, the
    model module and the setting module (both loaded with ``importlib``),
    the output folder and optional overrides for the epoch count and batch
    size.  Batches are sliced from in-memory arrays and fed through
    placeholders.  When ``--filelist`` is a list of file lists, the next
    list is loaded each time the current training data has been consumed.
    Every ``setting.step_val`` iterations, and on the last iteration, a
    checkpoint is saved and the validation set is evaluated.

    Raises:
        ValueError: if ``setting.optimizer`` is neither ``'adam'`` nor
            ``'momentum'``.
        SystemExit: on argument errors, or when normal features are
            requested with ``setting.data_dim < 6``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--filelist', '-t',
                        help='Path to training set ground truth (.txt)',
                        required=True)
    parser.add_argument('--filelist_val', '-v',
                        help='Path to validation set ground truth (.txt)',
                        required=True)
    parser.add_argument('--load_ckpt', '-l',
                        help='Path to a check point file for load')
    parser.add_argument(
        '--save_folder', '-s',
        help='Path to folder for saving check points and summary',
        required=True)
    parser.add_argument('--model', '-m', help='Model to use', required=True)
    parser.add_argument('--setting', '-x', help='Setting to use',
                        required=True)
    parser.add_argument(
        '--epochs',
        help='Number of training epochs (default defined in setting)',
        type=int)
    parser.add_argument('--batch_size',
                        help='Batch size (default defined in setting)',
                        type=int)
    parser.add_argument(
        '--log',
        help='Log to FILE in save folder; use - for stdout (default is log.txt)',
        metavar='FILE',
        default='log.txt')
    parser.add_argument('--no_timestamp_folder',
                        help='Dont save to timestamp folder',
                        action='store_true')
    parser.add_argument('--no_code_backup',
                        help='Dont backup code',
                        action='store_true')
    args = parser.parse_args()

    if not args.no_timestamp_folder:
        time_string = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
        root_folder = os.path.join(
            args.save_folder,
            '%s_%s_%s_%d' % (args.model, args.setting, time_string,
                             os.getpid()))
    else:
        root_folder = args.save_folder
    os.makedirs(root_folder, exist_ok=True)

    if args.log != '-':
        # The log file replaces sys.stdout for the rest of the run, so it is
        # not closed here.
        sys.stdout = open(os.path.join(root_folder, args.log), 'w')

    # NOTE: the startup diagnostics (PID, parsed args, dataset sizes, batch
    # counts) are disabled in this script.

    model = importlib.import_module(args.model)
    # Setting modules live in a folder named after the model, next to this
    # script; make that folder importable before loading the setting.
    setting_path = os.path.join(os.path.dirname(__file__), args.model)
    sys.path.append(setting_path)
    setting = importlib.import_module(args.setting)

    # Command-line values win; 0/None fall back to the setting's defaults.
    num_epochs = args.epochs or setting.num_epochs
    batch_size = args.batch_size or setting.batch_size
    sample_num = setting.sample_num
    step_val = setting.step_val
    label_weights_list = setting.label_weights
    # Lookup table indexed by label; built once instead of once per batch.
    label_weights = np.array(label_weights_list)
    rotation_range = setting.rotation_range
    rotation_range_val = setting.rotation_range_val
    scaling_range = setting.scaling_range
    scaling_range_val = setting.scaling_range_val
    jitter = setting.jitter
    jitter_val = setting.jitter_val

    # Prepare inputs
    is_list_of_h5_list = not data_utils.is_h5_list(args.filelist)
    if is_list_of_h5_list:
        seg_list = data_utils.load_seg_list(args.filelist)
        seg_list_idx = 0
        filelist_train = seg_list[seg_list_idx]
        seg_list_idx = seg_list_idx + 1
    else:
        filelist_train = args.filelist
    data_train, _, data_num_train, label_train, _ = data_utils.load_seg(
        filelist_train)
    data_val, _, data_num_val, label_val, _ = data_utils.load_seg(
        args.filelist_val)

    # shuffle
    # Original note: data_num_train is the number of points in each point
    # cloud.
    data_train, data_num_train, label_train = \
        data_utils.grouped_shuffle([data_train, data_num_train, label_train])

    num_train = data_train.shape[0]
    point_num = data_train.shape[1]
    num_val = data_val.shape[0]
    # Ceiling division: total number of training batches over all epochs.
    batch_num = (num_train * num_epochs + batch_size - 1) // batch_size
    batch_num_val = math.ceil(num_val / batch_size)

    ######################################################################
    # Placeholders
    indices = tf.placeholder(tf.int32, shape=(None, None, 2), name="indices")
    xforms = tf.placeholder(tf.float32, shape=(None, 3, 3), name="xforms")
    rotations = tf.placeholder(tf.float32, shape=(None, 3, 3),
                               name="rotations")
    jitter_range = tf.placeholder(tf.float32, shape=(1), name="jitter_range")
    global_step = tf.Variable(0, trainable=False, name='global_step')
    is_training = tf.placeholder(tf.bool, name='is_training')

    pts_fts = tf.placeholder(tf.float32,
                             shape=(None, point_num, setting.data_dim),
                             name='pts_fts')
    labels_seg = tf.placeholder(tf.int64, shape=(None, point_num),
                                name='labels_seg')
    labels_weights = tf.placeholder(tf.float32, shape=(None, point_num),
                                    name='labels_weights')

    ######################################################################
    pts_fts_sampled = tf.gather_nd(pts_fts,
                                   indices=indices,
                                   name='pts_fts_sampled')
    features_augmented = None
    if setting.data_dim > 3:
        # The first three channels are point coordinates, the rest features.
        points_sampled, features_sampled = tf.split(
            pts_fts_sampled, [3, setting.data_dim - 3],
            axis=-1,
            name='split_points_features')
        if setting.use_extra_features:
            if setting.with_normal_feature:
                if setting.data_dim < 6:
                    print('Only 3D normals are supported!')
                    sys.exit(1)
                elif setting.data_dim == 6:
                    features_augmented = pf.augment(features_sampled,
                                                    rotations)
                else:
                    # Split along the channel axis (tf.split defaults to
                    # axis 0), matching the concat on axis=-1 below.
                    normals, rest = tf.split(features_sampled,
                                             [3, setting.data_dim - 6],
                                             axis=-1)
                    normals_augmented = pf.augment(normals, rotations)
                    features_augmented = tf.concat([normals_augmented, rest],
                                                   axis=-1)
            else:
                features_augmented = features_sampled
    else:
        points_sampled = pts_fts_sampled
    points_augmented = pf.augment(points_sampled, xforms, jitter_range)

    # Labels and weights are gathered with the same indices as the points so
    # they stay aligned with the sampled points.
    labels_sampled = tf.gather_nd(labels_seg,
                                  indices=indices,
                                  name='labels_sampled')
    labels_weights_sampled = tf.gather_nd(labels_weights,
                                          indices=indices,
                                          name='labels_weight_sampled')

    net = model.Net(points_augmented, features_augmented, is_training,
                    setting)
    logits = net.logits
    probs = tf.nn.softmax(logits, name='probs')
    predictions = tf.argmax(probs, axis=-1, name='predictions')

    loss_op = tf.losses.sparse_softmax_cross_entropy(
        labels=labels_sampled,
        logits=logits,
        weights=labels_weights_sampled)

    with tf.name_scope('metrics'):
        loss_mean_op, loss_mean_update_op = tf.metrics.mean(loss_op)
        t_1_acc_op, t_1_acc_update_op = tf.metrics.accuracy(
            labels_sampled, predictions, weights=labels_weights_sampled)
        t_1_per_class_acc_op, t_1_per_class_acc_update_op = \
            tf.metrics.mean_per_class_accuracy(
                labels_sampled, predictions, setting.num_class,
                weights=labels_weights_sampled)
    # Re-initialising the local variables under 'metrics' resets the
    # streaming metrics.
    reset_metrics_op = tf.variables_initializer([
        var for var in tf.local_variables()
        if var.name.split('/')[0] == 'metrics'
    ])

    _ = tf.summary.scalar('loss/train', tensor=loss_mean_op,
                          collections=['train'])
    _ = tf.summary.scalar('t_1_acc/train', tensor=t_1_acc_op,
                          collections=['train'])
    _ = tf.summary.scalar('t_1_per_class_acc/train',
                          tensor=t_1_per_class_acc_op,
                          collections=['train'])

    _ = tf.summary.scalar('loss/val', tensor=loss_mean_op,
                          collections=['val'])
    _ = tf.summary.scalar('t_1_acc/val', tensor=t_1_acc_op,
                          collections=['val'])
    _ = tf.summary.scalar('t_1_per_class_acc/val',
                          tensor=t_1_per_class_acc_op,
                          collections=['val'])

    lr_exp_op = tf.train.exponential_decay(setting.learning_rate_base,
                                           global_step,
                                           setting.decay_steps,
                                           setting.decay_rate,
                                           staircase=True)
    lr_clip_op = tf.maximum(lr_exp_op, setting.learning_rate_min)
    _ = tf.summary.scalar('learning_rate', tensor=lr_clip_op,
                          collections=['train'])
    reg_loss = setting.weight_decay * tf.losses.get_regularization_loss()
    if setting.optimizer == 'adam':
        optimizer = tf.train.AdamOptimizer(learning_rate=lr_clip_op,
                                           epsilon=setting.epsilon)
    elif setting.optimizer == 'momentum':
        optimizer = tf.train.MomentumOptimizer(learning_rate=lr_clip_op,
                                               momentum=setting.momentum,
                                               use_nesterov=True)
    else:
        # Fail here with a clear message rather than later with an unbound
        # local variable.
        raise ValueError(
            'Unsupported optimizer: {!r}'.format(setting.optimizer))
    # Ops in UPDATE_OPS must run together with each training step.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = optimizer.minimize(loss_op + reg_loss,
                                      global_step=global_step)

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())

    saver = tf.train.Saver(max_to_keep=None)

    # backup all code
    if not args.no_code_backup:
        code_folder = os.path.abspath(os.path.dirname(__file__))
        shutil.copytree(
            code_folder,
            os.path.join(root_folder, os.path.basename(code_folder)))

    folder_ckpt = os.path.join(root_folder, 'ckpts')
    os.makedirs(folder_ckpt, exist_ok=True)

    folder_summary = os.path.join(root_folder, 'summary')
    os.makedirs(folder_summary, exist_ok=True)

    parameter_num = np.sum(
        [np.prod(v.shape.as_list()) for v in tf.trainable_variables()])
    print('{}-Parameter number: {:d}.'.format(datetime.now(), parameter_num))

    with tf.Session() as sess:
        summaries_op = tf.summary.merge_all('train')
        summaries_val_op = tf.summary.merge_all('val')
        summary_writer = tf.summary.FileWriter(folder_summary, sess.graph)

        sess.run(init_op)

        # Load the model: an explicit checkpoint wins; otherwise resume from
        # the newest checkpoint in this run's checkpoint folder, if any.
        if args.load_ckpt is not None:
            saver.restore(sess, args.load_ckpt)
            print('{}-Checkpoint loaded from {}!'.format(
                datetime.now(), args.load_ckpt))
        else:
            latest_ckpt = tf.train.latest_checkpoint(folder_ckpt)
            if latest_ckpt:
                print('{}-Found checkpoint {}'.format(datetime.now(),
                                                      latest_ckpt))
                saver.restore(sess, latest_ckpt)
                print('{}-Checkpoint loaded from {} (Iter {})'.format(
                    datetime.now(), latest_ckpt, sess.run(global_step)))

        for batch_idx_train in tqdm(range(batch_num)):
            ##################################################################
            # Validation
            if (batch_idx_train % step_val == 0 and
                    (batch_idx_train != 0 or args.load_ckpt is not None)) \
                    or batch_idx_train == batch_num - 1:
                filename_ckpt = os.path.join(folder_ckpt, 'iter')
                saver.save(sess, filename_ckpt, global_step=global_step)
                print('{}-Checkpoint saved to {}!'.format(
                    datetime.now(), filename_ckpt))

                sess.run(reset_metrics_op)
                for batch_val_idx in range(batch_num_val):
                    start_idx = batch_size * batch_val_idx
                    end_idx = min(start_idx + batch_size, num_val)
                    batch_size_val = end_idx - start_idx
                    points_batch = data_val[start_idx:end_idx, ...]
                    points_num_batch = data_num_val[start_idx:end_idx, ...]
                    labels_batch = label_val[start_idx:end_idx, ...]
                    # Per-point loss weight looked up from the point's label.
                    weights_batch = label_weights[labels_batch]

                    xforms_np, rotations_np = pf.get_xforms(
                        batch_size_val,
                        rotation_range=rotation_range_val,
                        scaling_range=scaling_range_val,
                        order=setting.rotation_order)
                    sess.run(
                        [
                            loss_mean_update_op, t_1_acc_update_op,
                            t_1_per_class_acc_update_op
                        ],
                        feed_dict={
                            pts_fts: points_batch,
                            indices: pf.get_indices(batch_size_val,
                                                    sample_num,
                                                    points_num_batch),
                            xforms: xforms_np,
                            rotations: rotations_np,
                            jitter_range: np.array([jitter_val]),
                            labels_seg: labels_batch,
                            labels_weights: weights_batch,
                            is_training: False,
                        })
                loss_val, t_1_acc_val, t_1_per_class_acc_val, \
                    summaries_val, step = sess.run([
                        loss_mean_op, t_1_acc_op, t_1_per_class_acc_op,
                        summaries_val_op, global_step
                    ])
                summary_writer.add_summary(summaries_val, step)
                print(
                    '{}-[Val ]-Average: Loss: {:.4f} T-1 Acc: {:.4f} T-1 mAcc: {:.4f}'
                    .format(datetime.now(), loss_val, t_1_acc_val,
                            t_1_per_class_acc_val))
                sys.stdout.flush()
            ##################################################################

            ##################################################################
            # Training
            start_idx = (batch_size * batch_idx_train) % num_train
            end_idx = min(start_idx + batch_size, num_train)
            batch_size_train = end_idx - start_idx
            points_batch = data_train[start_idx:end_idx, ...]
            points_num_batch = data_num_train[start_idx:end_idx, ...]
            labels_batch = label_train[start_idx:end_idx, ...]
            # Per-point loss weight looked up from the point's label.
            weights_batch = label_weights[labels_batch]

            # This batch reaches the end of the current training data: load
            # the next file list if there is one, then reshuffle for the
            # next pass.  The current batch was already sliced above.
            if start_idx + batch_size_train == num_train:
                if is_list_of_h5_list:
                    filelist_train_prev = seg_list[(seg_list_idx - 1) %
                                                   len(seg_list)]
                    filelist_train = seg_list[seg_list_idx % len(seg_list)]
                    # Skip the reload when the next list is the one already
                    # in memory.
                    if filelist_train != filelist_train_prev:
                        data_train, _, data_num_train, label_train, _ = \
                            data_utils.load_seg(filelist_train)
                        num_train = data_train.shape[0]
                    seg_list_idx = seg_list_idx + 1
                data_train, data_num_train, label_train = \
                    data_utils.grouped_shuffle(
                        [data_train, data_num_train, label_train])

            # Randomise the number of sampled points per step: a Gaussian
            # offset around sample_num, clipped to +/- sample_num *
            # sample_num_clip.
            offset = int(
                random.gauss(0, sample_num * setting.sample_num_variance))
            offset = max(offset, -sample_num * setting.sample_num_clip)
            offset = min(offset, sample_num * setting.sample_num_clip)
            sample_num_train = sample_num + offset
            xforms_np, rotations_np = pf.get_xforms(
                batch_size_train,
                rotation_range=rotation_range,
                scaling_range=scaling_range,
                order=setting.rotation_order)
            # Metrics are reset before every step, so training metrics
            # describe the current batch only.
            sess.run(reset_metrics_op)
            sess.run(
                [
                    train_op, loss_mean_update_op, t_1_acc_update_op,
                    t_1_per_class_acc_update_op
                ],
                feed_dict={
                    pts_fts: points_batch,
                    indices: pf.get_indices(batch_size_train,
                                            sample_num_train,
                                            points_num_batch),
                    xforms: xforms_np,
                    rotations: rotations_np,
                    jitter_range: np.array([jitter]),
                    labels_seg: labels_batch,
                    labels_weights: weights_batch,
                    is_training: True,
                })
            if batch_idx_train % 10 == 0:
                loss, t_1_acc, t_1_per_class_acc, summaries, step = sess.run([
                    loss_mean_op, t_1_acc_op, t_1_per_class_acc_op,
                    summaries_op, global_step
                ])
                summary_writer.add_summary(summaries, step)
                print(
                    '{}-[Train]-Iter: {:06d} Loss: {:.4f} T-1 Acc: {:.4f} T-1 mAcc: {:.4f}'
                    .format(datetime.now(), step, loss, t_1_acc,
                            t_1_per_class_acc))
                sys.stdout.flush()
            ##################################################################
        # Flush pending summary events to disk before the session ends.
        summary_writer.close()
        print('{}-Done!'.format(datetime.now()))
def load_fn(folder_npz, ratio, categories=None):
    """Load per-category ``.npz`` stroke files into padded train/val arrays.

    Args:
        folder_npz: Folder containing one ``<category>.npz`` file per
            category and, when ``categories`` is None, a ``categories.txt``
            file with one category name per line.
        ratio: Forwarded unchanged to ``_extract_padded_stokes``.
        categories: Optional list of category names.  When None, the names
            are read from ``categories.txt``.

    Returns:
        A tuple ``(raw_train, label_train, raw_val, label_val)``.  The label
        of a sample is the index of its category in ``categories``.  Both
        splits are shuffled with ``data_utils.grouped_shuffle``.

    Raises:
        ValueError: if the loaded files contain no strokes at all.
    """
    lift_pen_padding = 2.0

    if categories is None:
        # Context manager so the file handle is closed after reading.
        with open(os.path.join(folder_npz, 'categories.txt'), 'r') as f:
            categories = [line.strip() for line in f]

    stoke_len_max = 0
    stoke_len_sum = 0
    stoke_num = 0
    load_data_list = []
    for idx_category, category in enumerate(categories):
        print('{}-Loading category {} ({} of {})...'.format(
            datetime.now(), category, idx_category + 1, len(categories)))
        sys.stdout.flush()
        filename_category = os.path.join(folder_npz, category + '.npz')
        # NOTE(review): if these archives hold pickled object arrays, newer
        # NumPy releases need allow_pickle=True here -- confirm against the
        # data and only enable it for trusted files.
        load_data = np.load(filename_category, encoding='bytes')
        load_data_list.append(load_data)
        for tag in load_data:
            # Index the archive once per tag: each access to an .npz entry
            # reads the array from the file again.
            stokes = load_data[tag]
            for stoke in stokes:
                stoke_len_max = max(stoke_len_max, stoke.shape[0])
                stoke_len_sum += stoke.shape[0]
            stoke_num += len(stokes)
    if stoke_num == 0:
        # Without this guard the average below divides by zero.
        raise ValueError('No strokes found in {}'.format(folder_npz))
    print('{}-Max stoke length: {}, average stoke length: {}.'.format(
        datetime.now(), stoke_len_max, stoke_len_sum / stoke_num))
    sys.stdout.flush()

    # Padding template: stoke_len_max rows of (0, 0, lift_pen_padding).
    stoke_placeholder = np.array([(0.0, 0.0, lift_pen_padding)] *
                                 stoke_len_max).astype(np.float32)

    raw_train_list = []
    label_train_list = []
    raw_val_list = []
    label_val_list = []
    for idx_category, category in enumerate(categories):
        print('{}-Extracting category {} ({} of {})...'.format(
            datetime.now(), category, idx_category + 1, len(categories)))
        sys.stdout.flush()
        load_data = load_data_list[idx_category]
        try:
            raw_train_list.append(
                _extract_padded_stokes(load_data['train'], stoke_len_max,
                                       stoke_placeholder, ratio))
            label_train_list += [idx_category] * len(raw_train_list[-1])
            raw_val_list.append(
                _extract_padded_stokes(load_data['valid'], stoke_len_max,
                                       stoke_placeholder, ratio))
            label_val_list += [idx_category] * len(raw_val_list[-1])
        finally:
            # The archive is no longer needed once both splits are
            # extracted; close it to release the file handle.
            load_data.close()

    raw_train = np.concatenate(raw_train_list, axis=0)
    label_train = np.array(label_train_list)
    raw_val = np.concatenate(raw_val_list, axis=0)
    label_val = np.array(label_val_list)

    print('{}-Shuffling data...'.format(datetime.now()))
    sys.stdout.flush()
    raw_train, label_train = data_utils.grouped_shuffle(
        [raw_train, label_train])
    raw_val, label_val = data_utils.grouped_shuffle([raw_val, label_val])

    print('{}-Quick Draw data loaded!'.format(datetime.now()))
    sys.stdout.flush()

    return raw_train, label_train, raw_val, label_val
def main():
    """Train a point-cloud segmentation network with periodic validation.

    Command-line arguments select the training/validation file lists, the
    model module, the setting module and the output folder.  A run folder
    named ``<model>_<setting>_<pid>_<timestamp>`` is created below
    ``--save_folder``; source files are backed up there and checkpoints and
    TensorBoard summaries are written to its ``ckpts`` / ``summary``
    sub-folders.

    Raises:
        ValueError: If ``setting.optimizer`` is neither ``'adam'`` nor
            ``'momentum'`` (previously an unbound-variable NameError).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--filelist', '-t',
                        help='Path to training set ground truth (.txt)',
                        required=True)
    parser.add_argument('--filelist_val', '-v',
                        help='Path to validation set ground truth (.txt)',
                        required=True)
    parser.add_argument('--load_ckpt', '-l',
                        help='Path to a check point file for load')
    parser.add_argument(
        '--save_folder', '-s',
        help='Path to folder for saving check points and summary',
        required=True)
    parser.add_argument('--model', '-m', help='Model to use', required=True)
    parser.add_argument('--setting', '-x', help='Setting to use',
                        required=True)
    args = parser.parse_args()

    time_string = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    root_folder = os.path.join(
        args.save_folder,
        '%s_%s_%d_%s' % (args.model, args.setting, os.getpid(), time_string))
    if not os.path.exists(root_folder):
        os.makedirs(root_folder)

    # Output redirection to a log file (currently disabled):
    # sys.stdout = open(os.path.join(root_folder, 'log.txt'), 'w')

    print('PID:', os.getpid())
    print(args)

    # The setting module lives in a folder named after the model.
    model = importlib.import_module(args.model)
    setting_path = os.path.join(os.path.dirname(__file__), args.model)
    sys.path.append(setting_path)
    print(setting_path)
    setting = importlib.import_module(args.setting)

    num_epochs = setting.num_epochs
    batch_size = setting.batch_size
    sample_num = setting.sample_num
    step_val = 1000  # validate/checkpoint every 1000 training batches
    num_parts = setting.num_parts
    label_weights_list = setting.label_weights
    scaling_range = setting.scaling_range
    scaling_range_val = setting.scaling_range_val
    jitter = setting.jitter
    jitter_val = setting.jitter_val

    # Per-label weight lookup table, built once instead of once per batch.
    label_weights_table = np.array(label_weights_list)

    # Prepare inputs
    print('{}-Preparing datasets...'.format(datetime.now()))
    data_train, data_num_train, label_train = data_utils.load_seg(
        args.filelist)
    data_val, data_num_val, label_val = data_utils.load_seg(args.filelist_val)

    # shuffle
    data_train, data_num_train, label_train = \
        data_utils.grouped_shuffle([data_train, data_num_train, label_train])

    num_train = data_train.shape[0]  # author's note: 2048
    point_num = data_train.shape[1]  # author's note: 6501
    num_val = data_val.shape[0]
    print('{}-{:d}/{:d} training/validation samples.'.format(
        datetime.now(), num_train, num_val))
    batch_num = (num_train * num_epochs + batch_size - 1) // batch_size
    print('{}-{:d} training batches.'.format(datetime.now(), batch_num))
    batch_num_val = int(math.ceil(num_val / batch_size))
    print('{}-{:d} testing batches per test.'.format(datetime.now(),
                                                     batch_num_val))

    ######################################################################
    # Placeholders
    indices = tf.placeholder(tf.int32, shape=(None, None, 2), name="indices")
    # scaling transform
    xforms = tf.placeholder(tf.float32, shape=(None, 3, 3), name="xforms")
    # rotation transform (fed below, but not consumed by this graph)
    rotations = tf.placeholder(tf.float32, shape=(None, 3, 3),
                               name="rotations")
    # data jitter
    jitter_range = tf.placeholder(tf.float32, shape=(1), name="jitter_range")
    global_step = tf.Variable(0, trainable=False, name='global_step')
    is_training = tf.placeholder(tf.bool, name='is_training')

    pts_fts = tf.placeholder(tf.float32,
                             shape=(None, point_num, setting.data_dim),
                             name='pts_fts')
    labels_seg = tf.placeholder(tf.int32, shape=(None, point_num),
                                name='labels_seg')
    labels_weights = tf.placeholder(tf.float32, shape=(None, point_num),
                                    name='labels_weights')

    # Validation averages are computed in Python and fed in for summaries.
    loss_val_avg = tf.placeholder(tf.float32)
    t_1_acc_val_avg = tf.placeholder(tf.float32)
    t_1_acc_val_instance_avg = tf.placeholder(tf.float32)
    t_1_acc_val_others_avg = tf.placeholder(tf.float32)

    ######################################################################
    # Set Inputs(points,features_sampled)
    features_sampled = None

    if setting.data_dim > 3:
        # [3 pts_xyz , 2 img_xy , 4 extra_features]
        points, _, features = tf.split(
            pts_fts,
            [
                setting.data_format["pts_xyz"],
                setting.data_format["img_xy"],
                setting.data_format["extra_features"]
            ],
            axis=-1,
            name='split_points_xy_features')
        # r, g, b, i
        if setting.use_extra_features:
            # select the extra features
            features_sampled = tf.gather_nd(features, indices=indices,
                                            name='features_sampled')
    else:
        points = pts_fts

    points_sampled = tf.gather_nd(points, indices=indices,
                                  name='points_sampled')
    points_augmented = pf.augment(points_sampled, xforms, jitter_range)

    # author's note: N * 6501
    labels_sampled = tf.gather_nd(labels_seg, indices=indices,
                                  name='labels_sampled')
    labels_weights_sampled = tf.gather_nd(labels_weights, indices=indices,
                                          name='labels_weight_sampled')

    # Build net
    net = model.Net(points_augmented, features_sampled, None, None, num_parts,
                    is_training, setting)
    logits, probs = net.logits, net.probs

    loss_op = tf.losses.sparse_softmax_cross_entropy(
        labels=labels_sampled, logits=logits, weights=labels_weights_sampled)

    t_1_acc_op = pf.top_1_accuracy(probs, labels_sampled)
    t_1_acc_instance_op = pf.top_1_accuracy(probs, labels_sampled,
                                            labels_weights_sampled, 0.6)
    t_1_acc_others_op = pf.top_1_accuracy(probs, labels_sampled,
                                          labels_weights_sampled, 0.6, "less")

    reg_loss = setting.weight_decay * tf.losses.get_regularization_loss()

    # lr decay
    lr_exp_op = tf.train.exponential_decay(setting.learning_rate_base,
                                           global_step,
                                           setting.decay_steps,
                                           setting.decay_rate,
                                           staircase=True)
    lr_clip_op = tf.maximum(lr_exp_op, setting.learning_rate_min)

    # Optimizer
    if setting.optimizer == 'adam':
        optimizer = tf.train.AdamOptimizer(learning_rate=lr_clip_op,
                                           epsilon=setting.epsilon)
    elif setting.optimizer == 'momentum':
        optimizer = tf.train.MomentumOptimizer(learning_rate=lr_clip_op,
                                               momentum=0.9,
                                               use_nesterov=True)
    else:
        # Fail with a clear message instead of a NameError further down.
        raise ValueError(
            "Unsupported optimizer {!r}; expected 'adam' or 'momentum'."
            .format(setting.optimizer))

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = optimizer.minimize(loss_op + reg_loss,
                                      global_step=global_step)

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    saver = tf.train.Saver(max_to_keep=1)

    # backup this file, model and setting
    script_dir = os.path.dirname(__file__)
    if not os.path.exists(os.path.join(root_folder, args.model)):
        os.makedirs(os.path.join(root_folder, args.model))
    shutil.copy(__file__,
                os.path.join(root_folder, os.path.basename(__file__)))
    shutil.copy(os.path.join(script_dir, args.model + '.py'),
                os.path.join(root_folder, args.model + '.py'))
    kitti_module = args.model.split("_")[0] + "_kitti" + '.py'
    shutil.copy(os.path.join(script_dir, kitti_module),
                os.path.join(root_folder, kitti_module))
    shutil.copy(os.path.join(setting_path, args.setting + '.py'),
                os.path.join(root_folder, args.model, args.setting + '.py'))

    folder_ckpt = os.path.join(root_folder, 'ckpts')
    if not os.path.exists(folder_ckpt):
        os.makedirs(folder_ckpt)

    folder_summary = os.path.join(root_folder, 'summary')
    if not os.path.exists(folder_summary):
        os.makedirs(folder_summary)

    parameter_num = np.sum(
        [np.prod(v.shape.as_list()) for v in tf.trainable_variables()])
    print('{}-Parameter number: {:d}.'.format(datetime.now(), parameter_num))

    _ = tf.summary.scalar('loss/train_seg', tensor=loss_op,
                          collections=['train'])
    _ = tf.summary.scalar('t_1_acc/train_seg', tensor=t_1_acc_op,
                          collections=['train'])
    _ = tf.summary.scalar('t_1_acc/train_seg_instance',
                          tensor=t_1_acc_instance_op, collections=['train'])
    _ = tf.summary.scalar('t_1_acc/train_seg_others',
                          tensor=t_1_acc_others_op, collections=['train'])

    _ = tf.summary.scalar('loss/val_seg', tensor=loss_val_avg,
                          collections=['val'])
    _ = tf.summary.scalar('t_1_acc/val_seg_instance',
                          tensor=t_1_acc_val_instance_avg,
                          collections=['val'])
    _ = tf.summary.scalar('t_1_acc/val_seg', tensor=t_1_acc_val_avg,
                          collections=['val'])
    _ = tf.summary.scalar('t_1_acc/val_seg_others',
                          tensor=t_1_acc_val_others_avg, collections=['val'])

    _ = tf.summary.scalar('learning_rate', tensor=lr_clip_op,
                          collections=['train'])

    # Session Run
    with tf.Session() as sess:
        summaries_op = tf.summary.merge_all('train')
        summaries_val_op = tf.summary.merge_all('val')
        summary_writer = tf.summary.FileWriter(folder_summary, sess.graph)

        # Close the writer on every exit path so buffered events reach disk.
        try:
            sess.run(init_op)

            # Load the model
            if args.load_ckpt is not None:
                saver.restore(sess, args.load_ckpt)
                print('{}-Checkpoint loaded from {}!'.format(
                    datetime.now(), args.load_ckpt))

            for batch_idx in range(batch_num):
                if (batch_idx != 0 and batch_idx % step_val == 0) \
                        or batch_idx == batch_num - 1:
                    ##########################################################
                    # Validation
                    filename_ckpt = os.path.join(folder_ckpt, 'iter')
                    saver.save(sess, filename_ckpt, global_step=global_step)
                    print('{}-Checkpoint saved to {}!'.format(
                        datetime.now(), filename_ckpt))

                    losses_val = []
                    t_1_accs = []
                    t_1_accs_instance = []
                    t_1_accs_others = []
                    for batch_val_idx in range(batch_num_val):
                        start_idx = batch_size * batch_val_idx
                        end_idx = min(start_idx + batch_size, num_val)
                        batch_size_val = end_idx - start_idx
                        points_batch = data_val[start_idx:end_idx, ...]
                        points_num_batch = data_num_val[start_idx:end_idx, ...]
                        labels_batch = label_val[start_idx:end_idx, ...]
                        weights_batch = label_weights_table[labels_batch]

                        xforms_np, rotations_np = pf.get_xforms(
                            batch_size_val, scaling_range=scaling_range_val)

                        sess_op_list = [
                            loss_op, t_1_acc_op, t_1_acc_instance_op,
                            t_1_acc_others_op
                        ]
                        sess_feed_dict = {
                            pts_fts: points_batch,
                            indices: pf.get_indices(batch_size_val,
                                                    sample_num,
                                                    points_num_batch),
                            xforms: xforms_np,
                            rotations: rotations_np,
                            jitter_range: np.array([jitter_val]),
                            labels_seg: labels_batch,
                            labels_weights: weights_batch,
                            is_training: False
                        }

                        (loss_val, t_1_acc_val, t_1_acc_val_instance,
                         t_1_acc_val_others) = sess.run(
                             sess_op_list, feed_dict=sess_feed_dict)
                        print(
                            '{}-[Val ]-Iter: {:06d} Loss: {:.4f} T-1 Acc: {:.4f}'
                            .format(datetime.now(), batch_val_idx, loss_val,
                                    t_1_acc_val))
                        sys.stdout.flush()

                        # Weight by batch size so a short final batch does
                        # not skew the per-sample averages below.
                        losses_val.append(loss_val * batch_size_val)
                        t_1_accs.append(t_1_acc_val * batch_size_val)
                        t_1_accs_instance.append(t_1_acc_val_instance *
                                                 batch_size_val)
                        t_1_accs_others.append(t_1_acc_val_others *
                                               batch_size_val)

                    loss_avg = sum(losses_val) / num_val
                    t_1_acc_avg = sum(t_1_accs) / num_val
                    t_1_acc_instance_avg = sum(t_1_accs_instance) / num_val
                    t_1_acc_others_avg = sum(t_1_accs_others) / num_val

                    summaries_feed_dict = {
                        loss_val_avg: loss_avg,
                        t_1_acc_val_avg: t_1_acc_avg,
                        t_1_acc_val_instance_avg: t_1_acc_instance_avg,
                        t_1_acc_val_others_avg: t_1_acc_others_avg
                    }
                    summaries_val = sess.run(summaries_val_op,
                                             feed_dict=summaries_feed_dict)
                    summary_writer.add_summary(summaries_val, batch_idx)
                    print('{}-[Val ]-Average: Loss: {:.4f} T-1 Acc: {:.4f}'.
                          format(datetime.now(), loss_avg, t_1_acc_avg))
                    sys.stdout.flush()
                    ##########################################################

                ##############################################################
                # Training
                start_idx = (batch_size * batch_idx) % num_train
                end_idx = min(start_idx + batch_size, num_train)
                batch_size_train = end_idx - start_idx
                # author's shape notes: 2048 x 6501 x 9 / 2048 / 2048 x 6501
                points_batch = data_train[start_idx:end_idx, ...]
                points_num_batch = data_num_train[start_idx:end_idx, ...]
                labels_batch = label_train[start_idx:end_idx, ...]
                # author's note: "value is zero"
                weights_batch = label_weights_table[labels_batch]

                # Reshuffle once the end of the training set is reached.
                if start_idx + batch_size_train == num_train:
                    data_train, data_num_train, label_train = \
                        data_utils.grouped_shuffle(
                            [data_train, data_num_train, label_train])

                # Randomly vary the sample count: Gaussian (mean 0, sigma
                # sample_num // 8), clipped to roughly +/- sample_num // 4.
                offset = int(random.gauss(0, sample_num // 8))
                offset = max(offset, -sample_num // 4)
                offset = min(offset, sample_num // 4)
                sample_num_train = sample_num + offset

                # data augmentation
                xforms_np, rotations_np = pf.get_xforms(
                    batch_size_train, scaling_range=scaling_range)

                sess_op_list = [
                    train_op, loss_op, t_1_acc_op, t_1_acc_instance_op,
                    t_1_acc_others_op, summaries_op
                ]
                sess_feed_dict = {
                    pts_fts: points_batch,
                    indices: pf.get_indices(batch_size_train,
                                            sample_num_train,
                                            points_num_batch),
                    xforms: xforms_np,
                    rotations: rotations_np,
                    jitter_range: np.array([jitter]),
                    labels_seg: labels_batch,
                    labels_weights: weights_batch,
                    is_training: True
                }

                (_, loss, t_1_acc, t_1_acc_instance, t_1_acc_others,
                 summaries) = sess.run(sess_op_list,
                                       feed_dict=sess_feed_dict)
                print(
                    '{}-[Train]-Iter: {:06d} Loss_seg: {:.4f} T-1 Acc: {:.4f}'.
                    format(datetime.now(), batch_idx, loss, t_1_acc))
                summary_writer.add_summary(summaries, batch_idx)
                sys.stdout.flush()
                ##############################################################
        finally:
            summary_writer.close()

        print('{}-Done!'.format(datetime.now()))
def main():
    """Train a segmentation model epoch by epoch, validating after each one.

    Training data is taken from ``setting.filelist`` (either a single h5 list
    or a list of h5 lists, loaded one file list at a time) and validation
    data from ``setting.filelist_val``.  Only the first three channels of the
    data are used.  A checkpoint is saved after every epoch, and the epoch
    with the best validation mean per-class accuracy is reported in the log.

    Raises:
        ValueError: If ``setting.optimizer`` is neither ``'adam'`` nor
            ``'momentum'`` (previously an unbound-variable NameError).
    """
    global LOG_FOUT

    parser = argparse.ArgumentParser()
    parser.add_argument('--load_ckpt', '-l',
                        help='Path to a check point file for load')
    parser.add_argument(
        '--save_folder', '-s', default='log/seg',
        help='Path to folder for saving check points and summary')
    parser.add_argument('--model', '-m', default='shellconv',
                        help='Model to use')
    parser.add_argument('--setting', '-x', default='seg_s3dis',
                        help='Setting to use')
    parser.add_argument(
        '--log',
        help='Log to FILE in save folder; use - for stdout (default is log.txt)',
        metavar='FILE', default='log.txt')
    args = parser.parse_args()

    time_string = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    root_folder = os.path.join(
        args.save_folder,
        '%s_%s_%s' % (args.model, args.setting, time_string))
    if not os.path.exists(root_folder):
        os.makedirs(root_folder)

    # NOTE(review): LOG_FOUT is a module-level handle, presumably consumed by
    # log_string (defined elsewhere); it is left open for the process
    # lifetime exactly as before.
    if args.log != '-':
        LOG_FOUT = open(os.path.join(root_folder, args.log), 'w')

    # The setting module lives in a folder named after the model.
    model = importlib.import_module(args.model)
    setting_path = os.path.join(os.path.dirname(__file__), args.model)
    sys.path.append(setting_path)
    setting = importlib.import_module(args.setting)

    num_epochs = setting.num_epochs
    batch_size = setting.batch_size
    sample_num = setting.sample_num
    step_val = setting.step_val  # read for parity with the setting; unused
    label_weights_list = setting.label_weights
    rotation_range = setting.rotation_range
    rotation_range_val = setting.rotation_range_val
    scaling_range = setting.scaling_range
    scaling_range_val = setting.scaling_range_val
    jitter = setting.jitter
    jitter_val = setting.jitter_val

    # Per-label weight lookup table, built once instead of once per batch.
    label_weights_table = np.array(label_weights_list)

    is_list_of_h5_list = data_utils.is_h5_list(setting.filelist)
    if is_list_of_h5_list:
        seg_list = [setting.filelist]  # for train
    else:
        seg_list = data_utils.load_seg_list(setting.filelist)  # for train

    data_val, _, data_num_val, label_val, _ = data_utils.load_seg(
        setting.filelist_val)
    if data_val.shape[-1] > 3:
        data_val = data_val[:, :, :3]  # only use the xyz coordinates
    point_num = data_val.shape[1]
    num_val = data_val.shape[0]
    # Round up so the final partial batch is validated too; the validation
    # loop already clamps end_idx with min() for exactly that case.
    batch_num_val = (num_val + batch_size - 1) // batch_size

    ######################################################################
    # Placeholders
    indices = tf.placeholder(tf.int32, shape=(None, sample_num, 2),
                             name="indices")
    xforms = tf.placeholder(tf.float32, shape=(None, 3, 3), name="xforms")
    rotations = tf.placeholder(tf.float32, shape=(None, 3, 3),
                               name="rotations")
    jitter_range = tf.placeholder(tf.float32, shape=(1), name="jitter_range")
    global_step = tf.Variable(0, trainable=False, name='global_step')
    is_training = tf.placeholder(tf.bool, name='is_training')

    pts_fts = tf.placeholder(tf.float32,
                             shape=(None, point_num, setting.data_dim),
                             name='pts_fts')
    labels_seg = tf.placeholder(tf.int64, shape=(None, point_num),
                                name='labels_seg')
    labels_weights = tf.placeholder(tf.float32, shape=(None, point_num),
                                    name='labels_weights')

    ######################################################################
    points_sampled = tf.gather_nd(pts_fts, indices=indices,
                                  name='pts_fts_sampled')
    points_augmented = pf.augment(points_sampled, xforms, jitter_range)

    labels_sampled = tf.gather_nd(labels_seg, indices=indices,
                                  name='labels_sampled')
    labels_weights_sampled = tf.gather_nd(labels_weights, indices=indices,
                                          name='labels_weight_sampled')

    # Batch-norm decay: 1 - 0.5 * 0.5^(step // decay_steps), capped at 0.99.
    bn_decay_exp_op = tf.train.exponential_decay(0.5, global_step,
                                                 setting.decay_steps, 0.5,
                                                 staircase=True)
    bn_decay_op = tf.minimum(0.99, 1 - bn_decay_exp_op)

    logits_op = model.get_model(points_augmented, is_training,
                                setting.sconv_params, setting.sdconv_params,
                                setting.fc_params,
                                sampling=setting.sampling,
                                weight_decay=setting.weight_decay,
                                bn_decay=bn_decay_op,
                                part_num=setting.num_class)

    predictions = tf.argmax(logits_op, axis=-1, name='predictions')

    loss_op = tf.losses.sparse_softmax_cross_entropy(
        labels=labels_sampled, logits=logits_op,
        weights=labels_weights_sampled)

    with tf.name_scope('metrics'):
        loss_mean_op, loss_mean_update_op = tf.metrics.mean(loss_op)
        t_1_acc_op, t_1_acc_update_op = tf.metrics.accuracy(
            labels_sampled, predictions, weights=labels_weights_sampled)
        t_1_per_class_acc_op, t_1_per_class_acc_update_op = \
            tf.metrics.mean_per_class_accuracy(
                labels_sampled, predictions, setting.num_class,
                weights=labels_weights_sampled)

    # Re-initialising the local variables under 'metrics' resets the
    # streaming metrics defined above.
    reset_metrics_op = tf.variables_initializer([
        var for var in tf.local_variables()
        if var.name.split('/')[0] == 'metrics'
    ])

    _ = tf.summary.scalar('loss/train', tensor=loss_mean_op,
                          collections=['train'])
    _ = tf.summary.scalar('t_1_acc/train', tensor=t_1_acc_op,
                          collections=['train'])
    _ = tf.summary.scalar('t_1_per_class_acc/train',
                          tensor=t_1_per_class_acc_op, collections=['train'])

    _ = tf.summary.scalar('loss/val', tensor=loss_mean_op,
                          collections=['val'])
    _ = tf.summary.scalar('t_1_acc/val', tensor=t_1_acc_op,
                          collections=['val'])
    _ = tf.summary.scalar('t_1_per_class_acc/val',
                          tensor=t_1_per_class_acc_op, collections=['val'])

    lr_exp_op = tf.train.exponential_decay(setting.learning_rate_base,
                                           global_step,
                                           setting.decay_steps,
                                           setting.decay_rate,
                                           staircase=True)
    lr_clip_op = tf.maximum(lr_exp_op, setting.learning_rate_min)
    _ = tf.summary.scalar('learning_rate', tensor=lr_clip_op,
                          collections=['train'])
    reg_loss = setting.weight_decay * tf.losses.get_regularization_loss()

    if setting.optimizer == 'adam':
        optimizer = tf.train.AdamOptimizer(learning_rate=lr_clip_op)
    elif setting.optimizer == 'momentum':
        optimizer = tf.train.MomentumOptimizer(learning_rate=lr_clip_op,
                                               momentum=setting.momentum,
                                               use_nesterov=True)
    else:
        # Fail with a clear message instead of a NameError further down.
        raise ValueError(
            "Unsupported optimizer {!r}; expected 'adam' or 'momentum'."
            .format(setting.optimizer))

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = optimizer.minimize(loss_op + reg_loss,
                                      global_step=global_step)

    saver = tf.train.Saver(max_to_keep=None)

    folder_ckpt = os.path.join(root_folder, 'ckpts')
    if not os.path.exists(folder_ckpt):
        os.makedirs(folder_ckpt)

    folder_summary = os.path.join(root_folder, 'summary')
    if not os.path.exists(folder_summary):
        os.makedirs(folder_summary)

    parameter_num = np.sum(
        [np.prod(v.shape.as_list()) for v in tf.trainable_variables()])
    print('{}-Parameter number: {:d}.'.format(datetime.now(), parameter_num))

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    config.log_device_placement = False

    with tf.Session(config=config) as sess:
        summaries_op = tf.summary.merge_all('train')
        summaries_val_op = tf.summary.merge_all('val')
        summary_writer = tf.summary.FileWriter(folder_summary, sess.graph)

        # Close the writer on every exit path so buffered events reach disk.
        try:
            sess.run(tf.global_variables_initializer())

            # Load the model
            if args.load_ckpt is not None:
                saver.restore(sess, args.load_ckpt)
                print('{}-Checkpoint loaded from {}!'.format(
                    datetime.now(), args.load_ckpt))
            else:
                latest_ckpt = tf.train.latest_checkpoint(folder_ckpt)
                if latest_ckpt:
                    print('{}-Found checkpoint {}'.format(
                        datetime.now(), latest_ckpt))
                    saver.restore(sess, latest_ckpt)
                    print('{}-Checkpoint loaded from {} (Iter {})'.format(
                        datetime.now(), latest_ckpt, sess.run(global_step)))

            best_acc = 0
            best_epoch = 0
            for epoch in range(num_epochs):
                ############################ train ###########################
                # Shuffle train files
                np.random.shuffle(seg_list)
                for file_idx_train, filelist_train in enumerate(seg_list):
                    print('----epoch:' + str(epoch) + '--train file:' +
                          str(file_idx_train) + '-----')
                    data_train, _, data_num_train, label_train, _ = \
                        data_utils.load_seg(filelist_train)
                    num_train = data_train.shape[0]
                    if data_train.shape[-1] > 3:
                        data_train = data_train[:, :, :3]

                    data_train, data_num_train, label_train = \
                        data_utils.grouped_shuffle(
                            [data_train, data_num_train, label_train])

                    batch_num = (num_train + batch_size - 1) // batch_size

                    for batch_idx_train in range(batch_num):
                        # Training
                        start_idx = (batch_size * batch_idx_train) % num_train
                        end_idx = min(start_idx + batch_size, num_train)
                        batch_size_train = end_idx - start_idx
                        points_batch = data_train[start_idx:end_idx, ...]
                        points_num_batch = \
                            data_num_train[start_idx:end_idx, ...]
                        labels_batch = label_train[start_idx:end_idx, ...]
                        weights_batch = label_weights_table[labels_batch]

                        offset = int(random.gauss(
                            0, sample_num * setting.sample_num_variance))
                        offset = max(offset,
                                     -sample_num * setting.sample_num_clip)
                        offset = min(offset,
                                     sample_num * setting.sample_num_clip)
                        # The clip bounds are floats when sample_num_clip is
                        # fractional; keep the sample count an integer.
                        sample_num_train = int(sample_num + offset)

                        xforms_np, rotations_np = pf.get_xforms(
                            batch_size_train,
                            rotation_range=rotation_range,
                            scaling_range=scaling_range,
                            order=setting.rotation_order)

                        # Metrics are reset per batch, so the values logged
                        # below describe the current batch only.
                        sess.run(reset_metrics_op)
                        sess.run(
                            [
                                train_op, loss_mean_update_op,
                                t_1_acc_update_op,
                                t_1_per_class_acc_update_op
                            ],
                            feed_dict={
                                pts_fts: points_batch,
                                indices: pf.get_indices(batch_size_train,
                                                        sample_num_train,
                                                        points_num_batch),
                                xforms: xforms_np,
                                rotations: rotations_np,
                                jitter_range: np.array([jitter]),
                                labels_seg: labels_batch,
                                labels_weights: weights_batch,
                                is_training: True,
                            })

                        loss, t_1_acc, t_1_per_class_acc, summaries, step = \
                            sess.run([
                                loss_mean_op, t_1_acc_op,
                                t_1_per_class_acc_op, summaries_op,
                                global_step
                            ])
                        summary_writer.add_summary(summaries, step)
                        log_string(
                            '{}-[Train]-Iter: {:06d} Loss: {:.4f} T-1 Acc: {:.4f} T-1 mAcc: {:.4f}'
                            .format(datetime.now(), step, loss, t_1_acc,
                                    t_1_per_class_acc))
                        sys.stdout.flush()
                ##############################################################

                filename_ckpt = os.path.join(folder_ckpt, 'epoch')
                saver.save(sess, filename_ckpt, global_step=epoch)
                print('{}-Checkpoint saved to {}!'.format(
                    datetime.now(), filename_ckpt))

                # Validation: metrics accumulate over the whole set.
                sess.run(reset_metrics_op)
                for batch_val_idx in range(batch_num_val):
                    start_idx = batch_size * batch_val_idx
                    end_idx = min(start_idx + batch_size, num_val)
                    batch_size_val = end_idx - start_idx
                    points_batch = data_val[start_idx:end_idx, ...]
                    points_num_batch = data_num_val[start_idx:end_idx, ...]
                    labels_batch = label_val[start_idx:end_idx, ...]
                    weights_batch = label_weights_table[labels_batch]

                    xforms_np, rotations_np = pf.get_xforms(
                        batch_size_val,
                        rotation_range=rotation_range_val,
                        scaling_range=scaling_range_val,
                        order=setting.rotation_order)

                    sess.run(
                        [
                            loss_mean_update_op, t_1_acc_update_op,
                            t_1_per_class_acc_update_op
                        ],
                        feed_dict={
                            pts_fts: points_batch,
                            indices: pf.get_indices(batch_size_val,
                                                    sample_num,
                                                    points_num_batch),
                            xforms: xforms_np,
                            rotations: rotations_np,
                            jitter_range: np.array([jitter_val]),
                            labels_seg: labels_batch,
                            labels_weights: weights_batch,
                            is_training: False,
                        })

                (loss_val, t_1_acc_val, t_1_per_class_acc_val, summaries_val,
                 step) = sess.run([
                     loss_mean_op, t_1_acc_op, t_1_per_class_acc_op,
                     summaries_val_op, global_step
                 ])
                summary_writer.add_summary(summaries_val, step)

                if t_1_per_class_acc_val > best_acc:
                    best_acc = t_1_per_class_acc_val
                    best_epoch = epoch

                log_string(
                    '{}-[Val ]-Average: Loss: {:.4f} T-1 Acc: {:.4f} T-1 mAcc: {:.4f} best epoch: {} Current epoch: {}'
                    .format(datetime.now(), loss_val, t_1_acc_val,
                            t_1_per_class_acc_val, best_epoch, epoch))
                sys.stdout.flush()
                ##############################################################
        finally:
            summary_writer.close()

        print('{}-Done!'.format(datetime.now()))
def main():
    """Train a segmentation network for one category/level pair.

    Results go to ``exps_results/<model>_<setting>_<category>_<level>``; the
    run aborts if that folder already exists.  The number of classes is the
    line count of the matching statistics file plus one ("other").
    Validation and checkpointing run every ``10 * (num_train // batch_size)``
    training batches and on the final batch.

    Raises:
        SystemExit: With status 1 if the result folder already exists or the
            normal-feature configuration is unsupported.
        ValueError: If ``setting.optimizer`` is neither ``'adam'`` nor
            ``'momentum'`` (previously an unbound-variable NameError).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--category', '-c', help='category name',
                        required=True)
    parser.add_argument('--level', '-l', type=int, help='level id',
                        required=True)
    parser.add_argument('--load_ckpt', '-k',
                        help='Path to a check point file for load')
    parser.add_argument('--model', '-m', help='Model to use', required=True)
    parser.add_argument('--setting', '-x', help='Setting to use',
                        required=True)
    args = parser.parse_args()

    args.save_folder = 'exps_results/'
    args.data_folder = '../../data/sem_seg_h5/'

    root_folder = os.path.join(
        args.save_folder, '%s_%s_%s_%d' %
        (args.model, args.setting, args.category, args.level))

    if os.path.exists(root_folder):
        print('ERROR: folder %s exist! Please check and delete!' %
              root_folder)
        sys.exit(1)
    os.makedirs(root_folder)

    # The context manager closes the log even when training aborts.
    with open(os.path.join(root_folder, 'log.txt'), 'w') as flog:

        def printout(d):
            """Write ``d`` to the log file and echo it to stdout."""
            flog.write(str(d) + '\n')
            print(d)

        printout('PID: %s' % os.getpid())
        printout(args)

        # The setting module lives in a folder named after the model.
        model = importlib.import_module(args.model)
        setting_path = os.path.join(os.path.dirname(__file__), args.model)
        sys.path.append(setting_path)
        setting = importlib.import_module(args.setting)

        num_epochs = setting.num_epochs
        batch_size = setting.batch_size
        sample_num = setting.sample_num
        step_val = setting.step_val  # read for parity; unused below
        rotation_range = setting.rotation_range
        rotation_range_val = setting.rotation_range_val
        scaling_range = setting.scaling_range
        scaling_range_val = setting.scaling_range_val
        jitter = setting.jitter
        jitter_val = setting.jitter_val

        dataset_name = '%s-%d' % (args.category, args.level)
        args.filelist = os.path.join(args.data_folder, dataset_name,
                                     'train_files.txt')
        args.filelist_val = os.path.join(args.data_folder, dataset_name,
                                         'val_files.txt')

        # Load current category + level statistics
        with open(
                '../../stats/after_merging_label_ids/%s-level-%d.txt' %
                (args.category, args.level), 'r') as fin:
            setting.num_class = len(fin.readlines()) + 1  # with "other"
            printout('NUM CLASS: %d' % setting.num_class)

        # Uniform label weights; the lookup table is built once here
        # instead of once per batch.
        label_weights_list = [1.0] * setting.num_class
        label_weights_table = np.array(label_weights_list)

        # Prepare inputs
        printout('{}-Preparing datasets...'.format(datetime.now()))
        data_train, data_num_train, label_train = data_utils.load_seg(
            args.filelist)
        data_val, data_num_val, label_val = data_utils.load_seg(
            args.filelist_val)

        # shuffle
        data_train, data_num_train, label_train = \
            data_utils.grouped_shuffle(
                [data_train, data_num_train, label_train])

        num_train = data_train.shape[0]
        point_num = data_train.shape[1]
        num_val = data_val.shape[0]
        printout('{}-{:d}/{:d} training/validation samples.'.format(
            datetime.now(), num_train, num_val))
        batch_num = (num_train * num_epochs + batch_size - 1) // batch_size
        # At least 1: with fewer samples than batch_size this used to be 0
        # and the validation check below divided by zero.
        train_batch = max(1, num_train // batch_size)
        printout('{}-{:d} training batches.'.format(datetime.now(),
                                                    batch_num))
        batch_num_val = int(math.ceil(num_val / batch_size))
        printout('{}-{:d} testing batches per test.'.format(
            datetime.now(), batch_num_val))

        ######################################################################
        # Placeholders
        indices = tf.placeholder(tf.int32, shape=(None, None, 2),
                                 name="indices")
        xforms = tf.placeholder(tf.float32, shape=(None, 3, 3),
                                name="xforms")
        rotations = tf.placeholder(tf.float32, shape=(None, 3, 3),
                                   name="rotations")
        jitter_range = tf.placeholder(tf.float32, shape=(1),
                                      name="jitter_range")
        global_step = tf.Variable(0, trainable=False, name='global_step')
        is_training = tf.placeholder(tf.bool, name='is_training')

        pts_fts = tf.placeholder(tf.float32,
                                 shape=(None, point_num, setting.data_dim),
                                 name='pts_fts')
        labels_seg = tf.placeholder(tf.int64, shape=(None, point_num),
                                    name='labels_seg')
        labels_weights = tf.placeholder(tf.float32,
                                        shape=(None, point_num),
                                        name='labels_weights')

        ######################################################################
        pts_fts_sampled = tf.gather_nd(pts_fts, indices=indices,
                                       name='pts_fts_sampled')
        features_augmented = None
        if setting.data_dim > 3:
            points_sampled, features_sampled = tf.split(
                pts_fts_sampled, [3, setting.data_dim - 3],
                axis=-1,
                name='split_points_features')
            if setting.use_extra_features:
                if setting.with_normal_feature:
                    if setting.data_dim < 6:
                        printout('Only 3D normals are supported!')
                        # Configuration error: exit with a failure status.
                        sys.exit(1)
                    elif setting.data_dim == 6:
                        features_augmented = pf.augment(
                            features_sampled, rotations)
                    else:
                        # Split along the channel axis; tf.split defaults
                        # to axis=0, which would cut the batch instead.
                        normals, rest = tf.split(
                            features_sampled, [3, setting.data_dim - 6],
                            axis=-1)
                        normals_augmented = pf.augment(normals, rotations)
                        features_augmented = tf.concat(
                            [normals_augmented, rest], axis=-1)
                else:
                    features_augmented = features_sampled
        else:
            points_sampled = pts_fts_sampled
        points_augmented = pf.augment(points_sampled, xforms, jitter_range)

        labels_sampled = tf.gather_nd(labels_seg, indices=indices,
                                      name='labels_sampled')
        labels_weights_sampled = tf.gather_nd(labels_weights,
                                              indices=indices,
                                              name='labels_weight_sampled')

        net = model.Net(points_augmented, features_augmented, is_training,
                        setting)
        logits = net.logits
        probs = tf.nn.softmax(logits, name='probs')
        predictions = tf.argmax(probs, axis=-1, name='predictions')

        loss_op = tf.losses.sparse_softmax_cross_entropy(
            labels=labels_sampled, logits=logits,
            weights=labels_weights_sampled)

        with tf.name_scope('metrics'):
            loss_mean_op, loss_mean_update_op = tf.metrics.mean(loss_op)
            t_1_acc_op, t_1_acc_update_op = tf.metrics.accuracy(
                labels_sampled, predictions, weights=labels_weights_sampled)
            t_1_per_class_acc_op, t_1_per_class_acc_update_op = \
                tf.metrics.mean_per_class_accuracy(
                    labels_sampled, predictions, setting.num_class,
                    weights=labels_weights_sampled)

        # Re-initialising the local variables under 'metrics' resets the
        # streaming metrics defined above.
        reset_metrics_op = tf.variables_initializer([
            var for var in tf.local_variables()
            if var.name.split('/')[0] == 'metrics'
        ])

        _ = tf.summary.scalar('loss/train', tensor=loss_mean_op,
                              collections=['train'])
        _ = tf.summary.scalar('t_1_acc/train', tensor=t_1_acc_op,
                              collections=['train'])
        _ = tf.summary.scalar('t_1_per_class_acc/train',
                              tensor=t_1_per_class_acc_op,
                              collections=['train'])

        _ = tf.summary.scalar('loss/val', tensor=loss_mean_op,
                              collections=['val'])
        _ = tf.summary.scalar('t_1_acc/val', tensor=t_1_acc_op,
                              collections=['val'])
        _ = tf.summary.scalar('t_1_per_class_acc/val',
                              tensor=t_1_per_class_acc_op,
                              collections=['val'])

        lr_exp_op = tf.train.exponential_decay(setting.learning_rate_base,
                                               global_step,
                                               setting.decay_steps,
                                               setting.decay_rate,
                                               staircase=True)
        lr_clip_op = tf.maximum(lr_exp_op, setting.learning_rate_min)
        _ = tf.summary.scalar('learning_rate', tensor=lr_clip_op,
                              collections=['train'])
        reg_loss = setting.weight_decay * \
            tf.losses.get_regularization_loss()

        if setting.optimizer == 'adam':
            optimizer = tf.train.AdamOptimizer(learning_rate=lr_clip_op,
                                               epsilon=setting.epsilon)
        elif setting.optimizer == 'momentum':
            optimizer = tf.train.MomentumOptimizer(
                learning_rate=lr_clip_op,
                momentum=setting.momentum,
                use_nesterov=True)
        else:
            # Fail with a clear message instead of a NameError below.
            raise ValueError(
                "Unsupported optimizer {!r}; expected 'adam' or 'momentum'."
                .format(setting.optimizer))

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer.minimize(loss_op + reg_loss,
                                          global_step=global_step)

        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())

        saver = tf.train.Saver(max_to_keep=None)

        folder_ckpt = os.path.join(root_folder, 'ckpts')
        if not os.path.exists(folder_ckpt):
            os.makedirs(folder_ckpt)

        folder_summary = os.path.join(root_folder, 'summary')
        if not os.path.exists(folder_summary):
            os.makedirs(folder_summary)

        parameter_num = np.sum(
            [np.prod(v.shape.as_list()) for v in tf.trainable_variables()])
        printout('{}-Parameter number: {:d}.'.format(datetime.now(),
                                                     parameter_num))

        # Create a session
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        config.allow_soft_placement = True
        config.log_device_placement = False

        # The session is closed by the with-block, the writer by finally.
        with tf.Session(config=config) as sess:
            summaries_op = tf.summary.merge_all('train')
            summaries_val_op = tf.summary.merge_all('val')
            summary_writer = tf.summary.FileWriter(folder_summary,
                                                   sess.graph)
            try:
                sess.run(init_op)

                # Load the model
                if args.load_ckpt is not None:
                    saver.restore(sess, args.load_ckpt)
                    printout('{}-Checkpoint loaded from {}!'.format(
                        datetime.now(), args.load_ckpt))

                for batch_idx_train in range(batch_num):
                    # Validate every 10 * train_batch batches (skipping
                    # batch 0 unless resuming) and on the last batch.
                    if (batch_idx_train % (10 * train_batch) == 0 and
                            (batch_idx_train != 0 or
                             args.load_ckpt is not None)) \
                            or batch_idx_train == batch_num - 1:
                        ######################################################
                        # Validation
                        filename_ckpt = os.path.join(folder_ckpt, 'iter')
                        saver.save(sess, filename_ckpt,
                                   global_step=global_step)
                        printout('{}-Checkpoint saved to {}!'.format(
                            datetime.now(), filename_ckpt))

                        sess.run(reset_metrics_op)
                        for batch_val_idx in range(batch_num_val):
                            start_idx = batch_size * batch_val_idx
                            end_idx = min(start_idx + batch_size, num_val)
                            batch_size_val = end_idx - start_idx
                            points_batch = data_val[start_idx:end_idx, ...]
                            points_num_batch = \
                                data_num_val[start_idx:end_idx, ...]
                            labels_batch = label_val[start_idx:end_idx, ...]
                            weights_batch = \
                                label_weights_table[labels_batch]

                            xforms_np, rotations_np = pf.get_xforms(
                                batch_size_val,
                                rotation_range=rotation_range_val,
                                scaling_range=scaling_range_val,
                                order=setting.rotation_order)

                            sess.run(
                                [
                                    loss_mean_update_op, t_1_acc_update_op,
                                    t_1_per_class_acc_update_op
                                ],
                                feed_dict={
                                    pts_fts: points_batch,
                                    indices: pf.get_indices(
                                        batch_size_val, sample_num,
                                        points_num_batch),
                                    xforms: xforms_np,
                                    rotations: rotations_np,
                                    jitter_range: np.array([jitter_val]),
                                    labels_seg: labels_batch,
                                    labels_weights: weights_batch,
                                    is_training: False,
                                })

                        (loss_val, t_1_acc_val, t_1_per_class_acc_val,
                         summaries_val) = sess.run([
                             loss_mean_op, t_1_acc_op,
                             t_1_per_class_acc_op, summaries_val_op
                         ])
                        summary_writer.add_summary(summaries_val,
                                                   batch_idx_train)
                        printout(
                            '{}-[Val ]-Average: Loss: {:.4f} T-1 Acc: {:.4f} T-1 mAcc: {:.4f}'
                            .format(datetime.now(), loss_val, t_1_acc_val,
                                    t_1_per_class_acc_val))
                        flog.flush()
                        ######################################################

                    ##########################################################
                    # Training
                    start_idx = (batch_size * batch_idx_train) % num_train
                    end_idx = min(start_idx + batch_size, num_train)
                    batch_size_train = end_idx - start_idx
                    points_batch = data_train[start_idx:end_idx, ...]
                    points_num_batch = data_num_train[start_idx:end_idx, ...]
                    labels_batch = label_train[start_idx:end_idx, ...]
                    weights_batch = label_weights_table[labels_batch]

                    # Reshuffle once the end of the training set is reached.
                    if start_idx + batch_size_train == num_train:
                        data_train, data_num_train, label_train = \
                            data_utils.grouped_shuffle(
                                [data_train, data_num_train, label_train])

                    offset = int(random.gauss(
                        0, sample_num * setting.sample_num_variance))
                    offset = max(offset,
                                 -sample_num * setting.sample_num_clip)
                    offset = min(offset,
                                 sample_num * setting.sample_num_clip)
                    # The clip bounds are floats when sample_num_clip is
                    # fractional; keep the sample count an integer.
                    sample_num_train = int(sample_num + offset)

                    xforms_np, rotations_np = pf.get_xforms(
                        batch_size_train,
                        rotation_range=rotation_range,
                        scaling_range=scaling_range,
                        order=setting.rotation_order)

                    sess.run(reset_metrics_op)
                    sess.run(
                        [
                            train_op, loss_mean_update_op,
                            t_1_acc_update_op, t_1_per_class_acc_update_op
                        ],
                        feed_dict={
                            pts_fts: points_batch,
                            indices: pf.get_indices(batch_size_train,
                                                    sample_num_train,
                                                    points_num_batch),
                            xforms: xforms_np,
                            rotations: rotations_np,
                            jitter_range: np.array([jitter]),
                            labels_seg: labels_batch,
                            labels_weights: weights_batch,
                            is_training: True,
                        })

                    # Log and summarise every 10th training batch.
                    if batch_idx_train % 10 == 0:
                        loss, t_1_acc, t_1_per_class_acc, summaries = \
                            sess.run([
                                loss_mean_op, t_1_acc_op,
                                t_1_per_class_acc_op, summaries_op
                            ])
                        summary_writer.add_summary(summaries,
                                                   batch_idx_train)
                        printout(
                            '{}-[Train]-Iter: {:06d} Loss: {:.4f} T-1 Acc: {:.4f} T-1 mAcc: {:.4f}'
                            .format(datetime.now(), batch_idx_train, loss,
                                    t_1_acc, t_1_per_class_acc))
                        flog.flush()
                    ##########################################################
            finally:
                summary_writer.close()

        printout('{}-Done!'.format(datetime.now()))
def main():
    """Run one validation pass with a segmentation model and export the results.

    The model and setting modules named on the command line are imported, the
    TF graph is built, ``--load_ckpt`` is restored when given, and the whole
    validation set (``--filelist_val``) is evaluated once.  Outputs:

    * a checkpoint under ``<run folder>/ckpts``;
    * validation scalars and a confusion-matrix summary under
      ``<run folder>/summary``;
    * ``<filelist_val>_cm.h5`` with the datasets ``dataX``, ``dataY``,
      ``dataZ``, ``correct_labels`` and ``predict_labels``;
    * accuracy, F1, recall, precision, a classification report and a
      confusion matrix printed to stdout.

    Raises:
        ValueError: if ``setting.optimizer`` is neither ``'adam'`` nor
            ``'momentum'``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--filelist', '-t',
                        help='Path to training set ground truth (.txt)',
                        required=True)
    parser.add_argument('--filelist_val', '-v',
                        help='Path to validation set ground truth (.txt)',
                        required=True)
    parser.add_argument('--load_ckpt', '-l',
                        help='Path to a check point file for load')
    parser.add_argument(
        '--save_folder', '-s',
        help='Path to folder for saving check points and summary',
        required=True)
    parser.add_argument('--model', '-m', help='Model to use', required=True)
    parser.add_argument('--setting', '-x', help='Setting to use', required=True)
    # NOTE(review): --startpoint is required but never read below, and its
    # help text duplicates --setting; confirm whether it can be dropped.
    parser.add_argument('--startpoint', '-b', help='Setting to use',
                        required=True)
    args = parser.parse_args()

    # Every run gets its own folder: <model>_<setting>_<timestamp>_<pid>.
    time_string = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    root_folder = os.path.join(
        args.save_folder,
        '%s_%s_%s_%d' % (args.model, args.setting, time_string, os.getpid()))
    if not os.path.exists(root_folder):
        os.makedirs(root_folder)

    print('PID:', os.getpid())
    print(args)

    # The setting module lives in a folder named after the model.
    model = importlib.import_module(args.model)
    setting_path = os.path.join(os.path.dirname(__file__), args.model)
    sys.path.append(setting_path)
    setting = importlib.import_module(args.setting)

    num_epochs = setting.num_epochs
    batch_size = setting.batch_size
    sample_num = setting.sample_num
    # Every class gets weight 1.0 during validation.
    label_weights_val = [1.0] * 1 + [1.0] * (setting.num_class - 1)

    rotation_range_val = setting.rotation_range_val
    scaling_range_val = setting.scaling_range_val
    jitter_val = setting.jitter_val

    # Prepare inputs
    print('{}-Preparing datasets...'.format(datetime.now()))
    data_val, _, data_num_val, label_val, _ = data_utils.load_seg(
        args.filelist_val)

    # The training arrays are aliases of the validation arrays here: this
    # script never loads the training set, it only needs the shapes.
    data_train = data_val
    data_num_train = data_num_val
    label_train = label_val

    # shuffle
    data_train, data_num_train, label_train = \
        data_utils.grouped_shuffle([data_train, data_num_train, label_train])

    num_train = data_train.shape[0]
    point_num = data_train.shape[1]
    num_val = data_val.shape[0]
    print('{}-{:d}/{:d} training/validation samples.'.format(
        datetime.now(), num_train, num_val))
    batch_num = (num_train * num_epochs + batch_size - 1) // batch_size
    print('{}-{:d} training batches.'.format(datetime.now(), batch_num))
    batch_num_val = int(math.ceil(num_val / batch_size))
    print('{}-{:d} testing batches per test.'.format(datetime.now(),
                                                     batch_num_val))

    folder_summary = os.path.join(root_folder, 'summary')
    if not os.path.exists(folder_summary):
        os.makedirs(folder_summary)

    ######################################################################
    # Placeholders
    indices = tf.placeholder(tf.int32, shape=(None, None, 2), name="indices")
    xforms = tf.placeholder(tf.float32, shape=(None, 3, 3), name="xforms")
    rotations = tf.placeholder(tf.float32, shape=(None, 3, 3),
                               name="rotations")
    jitter_range = tf.placeholder(tf.float32, shape=(1), name="jitter_range")
    global_step = tf.Variable(0, trainable=False, name='global_step')
    is_training = tf.placeholder(tf.bool, name='is_training')

    pts_fts = tf.placeholder(tf.float32,
                             shape=(None, point_num, setting.data_dim),
                             name='pts_fts')
    labels_seg = tf.placeholder(tf.int64, shape=(None, point_num),
                                name='labels_seg')
    labels_weights = tf.placeholder(tf.float32, shape=(None, point_num),
                                    name='labels_weights')

    ######################################################################
    # Sample points through `indices`, then split coordinates from features.
    pts_fts_sampled = tf.gather_nd(pts_fts, indices=indices,
                                   name='pts_fts_sampled')
    features_augmented = None
    if setting.data_dim > 3:
        points_sampled, features_sampled = tf.split(
            pts_fts_sampled, [3, setting.data_dim - 3],
            axis=-1,
            name='split_points_features')
        if setting.use_extra_features:
            if setting.with_normal_feature:
                if setting.data_dim < 6:
                    print('Only 3D normals are supported!')
                    sys.exit()
                elif setting.data_dim == 6:
                    features_augmented = pf.augment(features_sampled,
                                                    rotations)
                else:
                    # Only the first three feature channels are rotated.
                    normals, rest = tf.split(features_sampled,
                                             [3, setting.data_dim - 6])
                    normals_augmented = pf.augment(normals, rotations)
                    features_augmented = tf.concat([normals_augmented, rest],
                                                   axis=-1)
            else:
                features_augmented = features_sampled
    else:
        points_sampled = pts_fts_sampled
    points_augmented = pf.augment(points_sampled, xforms, jitter_range)

    labels_sampled = tf.gather_nd(labels_seg, indices=indices,
                                  name='labels_sampled')
    labels_weights_sampled = tf.gather_nd(labels_weights, indices=indices,
                                          name='labels_weight_sampled')

    net = model.Net(points_augmented, features_augmented, is_training, setting)
    logits = net.logits
    probs = tf.nn.softmax(logits, name='probs')
    predictions = tf.argmax(probs, axis=-1, name='predictions')

    loss_op = tf.losses.sparse_softmax_cross_entropy(
        labels=labels_sampled, logits=logits, weights=labels_weights_sampled)

    with tf.name_scope('metrics'):
        loss_mean_op, loss_mean_update_op = tf.metrics.mean(loss_op)
        t_1_acc_op, t_1_acc_update_op = tf.metrics.accuracy(
            labels_sampled, predictions, weights=labels_weights_sampled)
        t_1_per_class_acc_op, t_1_per_class_acc_update_op = \
            tf.metrics.mean_per_class_accuracy(labels_sampled, predictions,
                                               setting.num_class,
                                               weights=labels_weights_sampled)
        t_1_per_mean_iou_op, t_1_per_mean_iou_op_update_op = \
            tf.metrics.mean_iou(labels_sampled, predictions,
                                setting.num_class,
                                weights=labels_weights_sampled)

    # Re-initialising the local variables under 'metrics' resets the
    # streaming metrics.
    reset_metrics_op = tf.variables_initializer([
        var for var in tf.local_variables()
        if var.name.split('/')[0] == 'metrics'
    ])

    _ = tf.summary.scalar('loss/train', tensor=loss_mean_op,
                          collections=['train'])
    _ = tf.summary.scalar('t_1_acc/train', tensor=t_1_acc_op,
                          collections=['train'])
    _ = tf.summary.scalar('t_1_per_class_acc/train',
                          tensor=t_1_per_class_acc_op, collections=['train'])

    _ = tf.summary.scalar('loss/val', tensor=loss_mean_op,
                          collections=['val'])
    _ = tf.summary.scalar('t_1_acc/val', tensor=t_1_acc_op,
                          collections=['val'])
    _ = tf.summary.scalar('t_1_per_class_acc/val',
                          tensor=t_1_per_class_acc_op, collections=['val'])
    _ = tf.summary.scalar('t_1_mean_iou/val', tensor=t_1_per_mean_iou_op,
                          collections=['val'])

    lr_exp_op = tf.train.exponential_decay(setting.learning_rate_base,
                                           global_step, setting.decay_steps,
                                           setting.decay_rate, staircase=True)
    lr_clip_op = tf.maximum(lr_exp_op, setting.learning_rate_min)
    _ = tf.summary.scalar('learning_rate', tensor=lr_clip_op,
                          collections=['train'])
    reg_loss = setting.weight_decay * tf.losses.get_regularization_loss()
    if setting.optimizer == 'adam':
        optimizer = tf.train.AdamOptimizer(learning_rate=lr_clip_op,
                                           epsilon=setting.epsilon)
    elif setting.optimizer == 'momentum':
        optimizer = tf.train.MomentumOptimizer(learning_rate=lr_clip_op,
                                               momentum=setting.momentum,
                                               use_nesterov=True)
    else:
        # Fail here with a clear message instead of a NameError further down.
        raise ValueError(
            'Unsupported optimizer: {!r}'.format(setting.optimizer))
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        # The train op is never run by this script; it is still built so the
        # optimizer's variables exist in the graph that is saved/restored.
        train_op = optimizer.minimize(loss_op + reg_loss,
                                      global_step=global_step)

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())

    saver = tf.train.Saver(max_to_keep=5)

    # backup all code
    code_folder = os.path.abspath(os.path.dirname(__file__))
    shutil.copytree(code_folder,
                    os.path.join(root_folder, os.path.basename(code_folder)))

    folder_ckpt = os.path.join(root_folder, 'ckpts')
    if not os.path.exists(folder_ckpt):
        os.makedirs(folder_ckpt)

    folder_ckpt_best = os.path.join(root_folder, 'ckpt-best')
    if not os.path.exists(folder_ckpt_best):
        os.makedirs(folder_ckpt_best)

    folder_cm_matrix = os.path.join(root_folder, 'cm-matrix')
    if not os.path.exists(folder_cm_matrix):
        os.makedirs(folder_cm_matrix)

    parameter_num = np.sum(
        [np.prod(v.shape.as_list()) for v in tf.trainable_variables()])
    # np.sum may hand back a float (e.g. for an empty list); '{:d}' needs int.
    print('{}-Parameter number: {:d}.'.format(datetime.now(),
                                              int(parameter_num)))

    # No "best so far" is tracked in this script; kept for the Diff-Best
    # column of the final report.
    _highest_val = 0.0
    with tf.Session() as sess:
        summaries_val_op = tf.summary.merge_all('val')
        summary_writer = tf.summary.FileWriter(folder_summary, sess.graph)

        sess.run(init_op)

        # Load the model
        if args.load_ckpt is not None:
            saver.restore(sess, args.load_ckpt)
            print('{}-Checkpoint loaded from {}!'.format(
                datetime.now(), args.load_ckpt))

        # Step value under which the summaries are recorded.
        batch_idx_train = 1

        ######################################################################
        # Validation
        filename_ckpt = os.path.join(folder_ckpt, 'iter')
        saver.save(sess, filename_ckpt, global_step=global_step)
        print('{}-Checkpoint saved to {}!'.format(datetime.now(),
                                                  filename_ckpt))

        sess.run(reset_metrics_op)

        # Visit the validation samples in random order.
        _idxVal = np.arange(num_val)
        np.random.shuffle(_idxVal)

        _dataX = []
        _dataY = []
        _dataZ = []
        _pred = []
        _label = []
        for batch_val_idx in tqdm(range(batch_num_val)):
            start_idx = batch_size * batch_val_idx
            end_idx = min(start_idx + batch_size, num_val)
            batch_size_val = end_idx - start_idx

            points_batch = data_val[_idxVal[start_idx:end_idx], ...]
            points_num_batch = data_num_val[_idxVal[start_idx:end_idx], ...]
            labels_batch = label_val[_idxVal[start_idx:end_idx], ...]

            weights_batch = np.array(label_weights_val)[labels_batch]

            xforms_np, rotations_np = pf.get_xforms(
                batch_size_val,
                rotation_range=rotation_range_val,
                scaling_range=scaling_range_val,
                order=setting.rotation_order)

            _labels_sampled, _predictions, _, _, _, _ = sess.run(
                [
                    labels_sampled, predictions, loss_mean_update_op,
                    t_1_acc_update_op, t_1_per_class_acc_update_op,
                    t_1_per_mean_iou_op_update_op
                ],
                feed_dict={
                    pts_fts: points_batch,
                    indices: pf.get_indices(batch_size_val, sample_num,
                                            points_num_batch),
                    xforms: xforms_np,
                    rotations: rotations_np,
                    jitter_range: np.array([jitter_val]),
                    labels_seg: labels_batch,
                    labels_weights: weights_batch,
                    is_training: False,
                })

            # NOTE(review): these are the coordinates of every point in the
            # batch, while the labels/predictions below only cover the points
            # picked through `indices`; confirm the exported arrays are meant
            # to line up one-to-one.
            _dataX.append(points_batch[:, :, 0].flatten())
            _dataY.append(points_batch[:, :, 1].flatten())
            _dataZ.append(points_batch[:, :, 2].flatten())
            _pred.append(_predictions.flatten())
            _label.append(_labels_sampled.flatten())

        loss_val, t_1_acc_val, t_1_per_class_acc_val, t1__mean_iou, summaries_val = sess.run(
            [
                loss_mean_op, t_1_acc_op, t_1_per_class_acc_op,
                t_1_per_mean_iou_op, summaries_val_op
            ])

        img_d_summary = pf.plot_confusion_matrix(
            _label,
            _pred, ["Environment", "Pedestrian", "Car", "Cyclist"],
            tensor_name='confusion_matrix',
            normalize=False)

        _dataX = np.concatenate(_dataX, axis=0).flatten()
        _dataY = np.concatenate(_dataY, axis=0).flatten()
        _dataZ = np.concatenate(_dataZ, axis=0).flatten()
        correct_labels = np.concatenate(_label, axis=0).flatten()
        predict_labels = np.concatenate(_pred, axis=0).flatten()

        filename_pred = args.filelist_val + '_cm.h5'
        print('{}-Saving {}...'.format(datetime.now(), filename_pred))
        # The context manager closes the file even if a dataset write fails.
        with h5py.File(filename_pred, 'w') as h5_file:
            h5_file.create_dataset('dataX', data=_dataX)
            h5_file.create_dataset('dataY', data=_dataY)
            h5_file.create_dataset('dataZ', data=_dataZ)
            h5_file.create_dataset('correct_labels', data=correct_labels)
            h5_file.create_dataset('predict_labels', data=predict_labels)

        summary_writer.add_summary(img_d_summary, batch_idx_train)
        summary_writer.add_summary(summaries_val, batch_idx_train)

        y_test = correct_labels
        prediction = predict_labels

        print('Accuracy:', accuracy_score(y_test, prediction))
        print('F1 score:', f1_score(y_test, prediction, average=None))
        print('Recall:', recall_score(y_test, prediction, average=None))
        print('Precision:', precision_score(y_test, prediction, average=None))
        print('\n clasification report:\n',
              classification_report(y_test, prediction))
        print('\n confussion matrix:\n', confusion_matrix(y_test, prediction))

        print(
            '{}-[Val  ]-Average:      Loss: {:.4f}  T-1 Acc: {:.4f}  Diff-Best: {:.4f} T-1 mAcc: {:.4f}   T-1 mIoU: {:.4f}'
            .format(datetime.now(), loss_val, t_1_acc_val,
                    _highest_val - t_1_per_class_acc_val,
                    t_1_per_class_acc_val, t1__mean_iou))
        sys.stdout.flush()
def main():
    """Train a segmentation model on labelled data plus pseudo-labelled data.

    The graph consists of the network from the ``--model`` module followed by
    a CRF-style refinement block (``learningBlock``) whose outputs at six
    dilation settings are summed.  Two boolean placeholders pick what the
    loss sees:

    * labelled training step / evaluation: refined logits against the
      ground-truth labels;
    * unlabelled training step: raw network logits against the refined
      predictions, which act as pseudo labels.

    The loop runs for a fixed 50000 iterations.  Every ``setting.step_val``
    iterations it saves a checkpoint, validates on ``--filelist_val``, keeps
    the best checkpoint by mean per-class accuracy, exits the process after
    more than ten validations in a row without improvement, and then trains
    on up to 500 randomly chosen batches of ``--filelist_unseen``.

    Raises:
        ValueError: if ``setting.optimizer`` is neither ``'adam'`` nor
            ``'momentum'``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--filelist', '-t',
                        help='Path to training set ground truth (.txt)',
                        required=True)
    parser.add_argument('--filelist_val', '-v',
                        help='Path to validation set ground truth (.txt)',
                        required=True)
    parser.add_argument('--filelist_unseen', '-u',
                        help='Path to validation set ground truth (.txt)',
                        required=True)
    parser.add_argument('--load_ckpt', '-l',
                        help='Path to a check point file for load')
    parser.add_argument(
        '--save_folder', '-s',
        help='Path to folder for saving check points and summary',
        required=True)
    parser.add_argument('--model', '-m', help='Model to use', required=True)
    parser.add_argument('--setting', '-x', help='Setting to use', required=True)
    args = parser.parse_args()

    # Every run gets its own folder: <model>_<setting>_<timestamp>_<pid>.
    time_string = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    root_folder = os.path.join(
        args.save_folder,
        '%s_%s_%s_%d' % (args.model, args.setting, time_string, os.getpid()))
    if not os.path.exists(root_folder):
        os.makedirs(root_folder)

    print('PID:', os.getpid())
    print(args)

    # The setting module lives in a folder named after the model.
    model = importlib.import_module(args.model)
    setting_path = os.path.join(os.path.dirname(__file__), args.model)
    sys.path.append(setting_path)
    setting = importlib.import_module(args.setting)

    num_epochs = setting.num_epochs
    batch_size = setting.batch_size
    sample_num = setting.sample_num
    step_val = setting.step_val
    label_weights_list = setting.label_weights
    # Every class gets weight 1.0 during validation.
    label_weights_val = [1.0] * 1 + [1.0] * (setting.num_class - 1)

    rotation_range = setting.rotation_range
    rotation_range_val = setting.rotation_range_val
    scaling_range = setting.scaling_range
    scaling_range_val = setting.scaling_range_val
    jitter = setting.jitter
    jitter_val = setting.jitter_val

    # Prepare inputs
    print('{}-Preparing datasets...'.format(datetime.now()))
    # --filelist is either one list of .h5 files or a list of such lists; in
    # the latter case the lists are cycled through, one per pass.
    is_list_of_h5_list = not data_utils.is_h5_list(args.filelist)
    if is_list_of_h5_list:
        seg_list = data_utils.load_seg_list(args.filelist)
        seg_list_idx = 0
        filelist_train = seg_list[seg_list_idx]
        seg_list_idx = seg_list_idx + 1
    else:
        filelist_train = args.filelist
    data_train, _, data_num_train, label_train, _ = data_utils.load_seg(
        filelist_train)
    data_val, _, data_num_val, label_val, _ = data_utils.load_seg(
        args.filelist_val)
    # The labels of the unseen set are loaded but never used below.
    data_unseen, _, data_num_unseen, label_unseen, _ = data_utils.load_seg(
        args.filelist_unseen)

    # shuffle
    data_unseen, data_num_unseen, label_unseen = \
        data_utils.grouped_shuffle([data_unseen, data_num_unseen, label_unseen])

    # shuffle
    data_train, data_num_train, label_train = \
        data_utils.grouped_shuffle([data_train, data_num_train, label_train])

    num_train = data_train.shape[0]
    point_num = data_train.shape[1]
    num_val = data_val.shape[0]
    num_unseen = data_unseen.shape[0]
    batch_num_unseen = int(math.ceil(num_unseen / batch_size))

    print('{}-{:d}/{:d} training/validation samples.'.format(
        datetime.now(), num_train, num_val))
    batch_num = (num_train * num_epochs + batch_size - 1) // batch_size
    print('{}-{:d} training batches.'.format(datetime.now(), batch_num))
    batch_num_val = int(math.ceil(num_val / batch_size))
    print('{}-{:d} testing batches per test.'.format(datetime.now(),
                                                     batch_num_val))

    ######################################################################
    # Placeholders
    indices = tf.placeholder(tf.int32, shape=(None, None, 2), name="indices")
    xforms = tf.placeholder(tf.float32, shape=(None, 3, 3), name="xforms")
    rotations = tf.placeholder(tf.float32, shape=(None, 3, 3),
                               name="rotations")
    jitter_range = tf.placeholder(tf.float32, shape=(1), name="jitter_range")
    global_step = tf.Variable(0, trainable=False, name='global_step')
    is_training = tf.placeholder(tf.bool, name='is_training')
    is_labelled_data = tf.placeholder(tf.bool, name='is_labelled_data')

    pts_fts = tf.placeholder(tf.float32,
                             shape=(None, point_num, setting.data_dim),
                             name='pts_fts')
    labels_seg = tf.placeholder(tf.int64, shape=(None, point_num),
                                name='labels_seg')
    labels_weights = tf.placeholder(tf.float32, shape=(None, point_num),
                                    name='labels_weights')

    ######################################################################
    # Sample points through `indices`, then split coordinates from features.
    pts_fts_sampled = tf.gather_nd(pts_fts, indices=indices,
                                   name='pts_fts_sampled')
    features_augmented = None
    if setting.data_dim > 3:
        points_sampled, features_sampled = tf.split(
            pts_fts_sampled, [3, setting.data_dim - 3],
            axis=-1,
            name='split_points_features')
        if setting.use_extra_features:
            if setting.with_normal_feature:
                if setting.data_dim < 6:
                    print('Only 3D normals are supported!')
                    sys.exit()
                elif setting.data_dim == 6:
                    features_augmented = pf.augment(features_sampled,
                                                    rotations)
                else:
                    # Only the first three feature channels are rotated.
                    normals, rest = tf.split(features_sampled,
                                             [3, setting.data_dim - 6])
                    normals_augmented = pf.augment(normals, rotations)
                    features_augmented = tf.concat([normals_augmented, rest],
                                                   axis=-1)
            else:
                features_augmented = features_sampled
    else:
        points_sampled = pts_fts_sampled
    points_augmented = pf.augment(points_sampled, xforms, jitter_range)

    labels_sampled = tf.gather_nd(labels_seg, indices=indices,
                                  name='labels_sampled')
    labels_weights_sampled = tf.gather_nd(labels_weights, indices=indices,
                                          name='labels_weight_sampled')

    net = model.Net(points_augmented, features_augmented, is_training, setting)
    logits = net.logits
    probs = tf.nn.softmax(logits, name='prob')
    predictions = tf.argmax(probs, axis=-1, name='predictions')

    # The scope name is what later separates the refinement block's variables
    # from the rest (see xcrf_ker_weights / no_xcrf_ker_weights).
    # NOTE(review): the ops built after this block (probs_crf,
    # predictions_crf, the tf.cond switches) are created outside the scope
    # here; confirm against any code that looks these tensors up by name.
    with tf.variable_scope('xcrf_ker_weights'):
        _, point_indices = pf.knn_indices_general(points_augmented,
                                                  points_augmented, 1024, True)
        xcrf = learningBlock(num_points=sample_num,
                             num_classes=setting.num_class,
                             theta_alpha=float(5),
                             theta_beta=float(2),
                             theta_gamma=float(1),
                             num_iterations=5,
                             name='xcrf',
                             point_indices=point_indices)

        _logits1 = xcrf.call(net.logits, points_augmented, features_augmented,
                             setting.data_dim)
        _logits2 = xcrf.call(net.logits, points_augmented, features_augmented,
                             setting.data_dim, D=2)
        _logits3 = xcrf.call(net.logits, points_augmented, features_augmented,
                             setting.data_dim, D=3)
        _logits4 = xcrf.call(net.logits, points_augmented, features_augmented,
                             setting.data_dim, D=4)
        _logits5 = xcrf.call(net.logits, points_augmented, features_augmented,
                             setting.data_dim, D=8)
        _logits6 = xcrf.call(net.logits, points_augmented, features_augmented,
                             setting.data_dim, D=16)

        _logits = _logits1 + _logits2 + _logits3 + _logits4 + _logits5 + _logits6

    _probs = tf.nn.softmax(_logits, name='probs_crf')
    _predictions = tf.argmax(_probs, axis=-1, name='predictions_crf')

    # The branch functions run while tf.cond builds the graph, so the names
    # `logits`, `predictions` and `labels_sampled` inside the lambdas still
    # refer to the tensors defined above, not to the tf.cond results.
    logits = tf.cond(
        is_training,
        lambda: tf.cond(is_labelled_data, lambda: _logits, lambda: logits),
        lambda: _logits)
    predictions = tf.cond(
        is_training,
        lambda: tf.cond(is_labelled_data, lambda: _predictions,
                        lambda: predictions),
        lambda: _predictions)
    # Unlabelled training steps use the refined predictions as labels.
    labels_sampled = tf.cond(
        is_training,
        lambda: tf.cond(is_labelled_data, lambda: labels_sampled,
                        lambda: _predictions),
        lambda: labels_sampled)

    loss_op = tf.losses.sparse_softmax_cross_entropy(
        labels=labels_sampled, logits=logits, weights=labels_weights_sampled)

    with tf.name_scope('metrics'):
        loss_mean_op, loss_mean_update_op = tf.metrics.mean(loss_op)
        t_1_acc_op, t_1_acc_update_op = tf.metrics.accuracy(
            labels_sampled, predictions, weights=labels_weights_sampled)
        t_1_per_class_acc_op, t_1_per_class_acc_update_op = \
            tf.metrics.mean_per_class_accuracy(labels_sampled, predictions,
                                               setting.num_class,
                                               weights=labels_weights_sampled)
        t_1_per_mean_iou_op, t_1_per_mean_iou_op_update_op = \
            tf.metrics.mean_iou(labels_sampled, predictions,
                                setting.num_class,
                                weights=labels_weights_sampled)

    # Re-initialising the local variables under 'metrics' resets the
    # streaming metrics.
    reset_metrics_op = tf.variables_initializer([
        var for var in tf.local_variables()
        if var.name.split('/')[0] == 'metrics'
    ])

    _ = tf.summary.scalar('loss/train', tensor=loss_mean_op,
                          collections=['train'])
    _ = tf.summary.scalar('t_1_acc/train', tensor=t_1_acc_op,
                          collections=['train'])
    _ = tf.summary.scalar('t_1_per_class_acc/train',
                          tensor=t_1_per_class_acc_op, collections=['train'])

    _ = tf.summary.scalar('loss/val', tensor=loss_mean_op,
                          collections=['val'])
    _ = tf.summary.scalar('t_1_acc/val', tensor=t_1_acc_op,
                          collections=['val'])
    _ = tf.summary.scalar('t_1_per_class_acc/val',
                          tensor=t_1_per_class_acc_op, collections=['val'])
    _ = tf.summary.scalar('t_1_mean_iou/val', tensor=t_1_per_mean_iou_op,
                          collections=['val'])

    # Snapshot taken before the optimizer adds its own variables.
    all_variable = tf.global_variables()

    lr_exp_op = tf.train.exponential_decay(setting.learning_rate_base,
                                           global_step, setting.decay_steps,
                                           setting.decay_rate, staircase=True)
    lr_clip_op = tf.maximum(lr_exp_op, setting.learning_rate_min)
    _ = tf.summary.scalar('learning_rate', tensor=lr_clip_op,
                          collections=['train'])
    reg_loss = setting.weight_decay * tf.losses.get_regularization_loss()
    if setting.optimizer == 'adam':
        optimizer = tf.train.AdamOptimizer(learning_rate=lr_clip_op,
                                           epsilon=setting.epsilon)
    elif setting.optimizer == 'momentum':
        optimizer = tf.train.MomentumOptimizer(learning_rate=lr_clip_op,
                                               momentum=setting.momentum,
                                               use_nesterov=True)
    else:
        # Fail here with a clear message instead of a NameError further down.
        raise ValueError(
            'Unsupported optimizer: {!r}'.format(setting.optimizer))
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

    no_xcrf_ker_weights = [
        var for var in tf.global_variables()
        if 'xcrf_ker_weights' not in var.name
    ]

    with tf.control_dependencies(update_ops):
        # Used for unlabelled batches: leaves the refinement block's
        # variables untouched.
        train_op_xcrf = optimizer.minimize(loss_op + reg_loss,
                                           var_list=no_xcrf_ker_weights,
                                           global_step=global_step)
        # Used for labelled batches.
        train_op_all = optimizer.minimize(loss_op + reg_loss,
                                          var_list=all_variable,
                                          global_step=global_step)
    # The op to run is picked in Python for each sess.run call.  Wrapping the
    # two ops above in tf.cond would not select between them, because tf.cond
    # only makes conditional what is created inside its branch functions.

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())

    # `saver` skips the refinement block's variables; `saver_best` stores
    # everything.
    saver = tf.train.Saver(var_list=no_xcrf_ker_weights, max_to_keep=5)
    saver_best = tf.train.Saver(max_to_keep=5)

    # backup all code
    code_folder = os.path.abspath(os.path.dirname(__file__))
    shutil.copytree(code_folder,
                    os.path.join(root_folder, os.path.basename(code_folder)))

    folder_ckpt = os.path.join(root_folder, 'ckpts')
    if not os.path.exists(folder_ckpt):
        os.makedirs(folder_ckpt)

    folder_ckpt_best = os.path.join(root_folder, 'ckpt-best')
    if not os.path.exists(folder_ckpt_best):
        os.makedirs(folder_ckpt_best)

    folder_summary = os.path.join(root_folder, 'summary')
    if not os.path.exists(folder_summary):
        os.makedirs(folder_summary)

    parameter_num = np.sum(
        [np.prod(v.shape.as_list()) for v in tf.trainable_variables()])
    print('{}-Parameter number: {:d}.'.format(datetime.now(),
                                              int(parameter_num)))

    _highest_val = 0.0  # best mean per-class accuracy seen so far
    max_val = 10  # validations left before giving up without improvement
    with tf.Session() as sess:
        summaries_op = tf.summary.merge_all('train')
        summaries_val_op = tf.summary.merge_all('val')
        summary_writer = tf.summary.FileWriter(folder_summary, sess.graph)

        sess.run(init_op)

        # Load the model
        if args.load_ckpt is not None:
            saver.restore(sess, args.load_ckpt)
            print('{}-Checkpoint loaded from {}!'.format(
                datetime.now(), args.load_ckpt))

        # Fixed iteration budget; replaces the epoch-based count printed
        # above.
        batch_num = 50000
        for batch_idx_train in range(batch_num):
            if (batch_idx_train % step_val == 0 and (batch_idx_train != 0 or args.load_ckpt is not None)) \
                    or batch_idx_train == batch_num - 1:
                ######################################################################
                # Validation
                filename_ckpt = os.path.join(folder_ckpt, 'iter')
                saver.save(sess, filename_ckpt, global_step=global_step)
                print('{}-Checkpoint saved to {}!'.format(
                    datetime.now(), filename_ckpt))

                sess.run(reset_metrics_op)

                # Visit the validation samples in random order.
                _idxVal = np.arange(num_val)
                np.random.shuffle(_idxVal)

                _pred = []
                _label = []
                for batch_val_idx in tqdm(range(batch_num_val)):
                    start_idx = batch_size * batch_val_idx
                    end_idx = min(start_idx + batch_size, num_val)
                    batch_size_val = end_idx - start_idx

                    points_batch = data_val[_idxVal[start_idx:end_idx], ...]
                    points_num_batch = data_num_val[
                        _idxVal[start_idx:end_idx], ...]
                    labels_batch = label_val[_idxVal[start_idx:end_idx], ...]

                    weights_batch = np.array(label_weights_val)[labels_batch]

                    xforms_np, rotations_np = pf.get_xforms(
                        batch_size_val,
                        rotation_range=rotation_range_val,
                        scaling_range=scaling_range_val,
                        order=setting.rotation_order)

                    _labels_sampled, _predictions, _, _, _, _ = sess.run(
                        [
                            labels_sampled, predictions, loss_mean_update_op,
                            t_1_acc_update_op, t_1_per_class_acc_update_op,
                            t_1_per_mean_iou_op_update_op
                        ],
                        feed_dict={
                            pts_fts: points_batch,
                            indices: pf.get_indices(batch_size_val,
                                                    sample_num,
                                                    points_num_batch),
                            xforms: xforms_np,
                            rotations: rotations_np,
                            jitter_range: np.array([jitter_val]),
                            labels_seg: labels_batch,
                            is_labelled_data: True,
                            labels_weights: weights_batch,
                            is_training: False,
                        })
                    _pred.append(_predictions.flatten())
                    _label.append(_labels_sampled.flatten())

                loss_val, t_1_acc_val, t_1_per_class_acc_val, t1__mean_iou, summaries_val = sess.run(
                    [
                        loss_mean_op, t_1_acc_op, t_1_per_class_acc_op,
                        t_1_per_mean_iou_op, summaries_val_op
                    ])

                img_d_summary = pf.plot_confusion_matrix(
                    _label,
                    _pred, ["Environment", "Pedestrian", "Car", "Cyclist"],
                    tensor_name='confusion_matrix')

                # Early stopping: the countdown restarts on every improvement.
                max_val = max_val - 1

                if t_1_per_class_acc_val > _highest_val:
                    max_val = 10
                    _highest_val = t_1_per_class_acc_val
                    filename_ckpt = os.path.join(folder_ckpt_best,
                                                 str(_highest_val) + "-iter-")
                    saver_best.save(sess, filename_ckpt,
                                    global_step=global_step)

                if max_val < 0:
                    sys.exit(0)

                summary_writer.add_summary(summaries_val, batch_idx_train)
                summary_writer.add_summary(img_d_summary, batch_idx_train)

                print(
                    '{}-[Val  ]-Average:      Loss: {:.4f}  T-1 Acc: {:.4f}  Diff-Best: {:.4f} T-1 mAcc: {:.4f}   T-1 mIoU: {:.4f}'
                    .format(datetime.now(), loss_val, t_1_acc_val,
                            _highest_val - t_1_per_class_acc_val,
                            t_1_per_class_acc_val, t1__mean_iou))
                sys.stdout.flush()

                ######################################################################
                # Unseen Data
                # NOTE(review): this pass is placed inside the validation
                # branch, i.e. it runs once per validation; confirm that is
                # the intended schedule.
                sess.run(reset_metrics_op)

                _idxunseen = np.arange(num_unseen)
                np.random.shuffle(_idxunseen)

                # At most 500 randomly chosen batches per pass.
                unseenIndices = np.arange(batch_num_unseen)
                np.random.shuffle(unseenIndices)
                for batch_val_idx in tqdm(unseenIndices[:500]):
                    start_idx = batch_size * batch_val_idx
                    end_idx = min(start_idx + batch_size, num_unseen)
                    batch_size_train = end_idx - start_idx

                    points_batch = data_unseen[
                        _idxunseen[start_idx:end_idx], ...]
                    points_num_batch = data_num_unseen[
                        _idxunseen[start_idx:end_idx], ...]
                    # Placeholder labels: with is_labelled_data False the
                    # graph takes its labels from the refined predictions.
                    labels_batch = np.zeros(points_batch.shape[0:2],
                                            dtype=np.int32)
                    weights_batch = np.array(label_weights_list)[labels_batch]

                    xforms_np, rotations_np = pf.get_xforms(
                        batch_size_train,
                        rotation_range=rotation_range,
                        scaling_range=scaling_range,
                        order=setting.rotation_order)

                    # Randomly vary the number of sampled points, clipped.
                    offset = int(
                        random.gauss(0,
                                     sample_num * setting.sample_num_variance))
                    offset = max(offset, -sample_num * setting.sample_num_clip)
                    offset = min(offset, sample_num * setting.sample_num_clip)
                    sample_num_train = sample_num + offset

                    feed_unseen = {
                        pts_fts: points_batch,
                        is_labelled_data: False,
                        indices: pf.get_indices(batch_size_train,
                                                sample_num_train,
                                                points_num_batch),
                        xforms: xforms_np,
                        rotations: rotations_np,
                        jitter_range: np.array([jitter]),
                        labels_seg: labels_batch,
                        labels_weights: weights_batch,
                        is_training: True,
                    }
                    sess.run(
                        [
                            train_op_xcrf, loss_mean_update_op,
                            t_1_acc_update_op, t_1_per_class_acc_update_op,
                            t_1_per_mean_iou_op_update_op
                        ],
                        feed_dict=feed_unseen)
                    if batch_val_idx % 100 == 0:
                        loss, t_1_acc, t_1_per_class_acc, t_1__mean_iou = sess.run(
                            [
                                loss_mean_op, t_1_acc_op,
                                t_1_per_class_acc_op, t_1_per_mean_iou_op
                            ],
                            feed_dict=feed_unseen)
                        print(
                            '{}-[Train]-Unseen: {:06d}  Loss: {:.4f}  T-1 Acc: {:.4f}  T-1 mAcc: {:.4f}   T-1 mIoU: {:.4f}'
                            .format(datetime.now(), batch_val_idx, loss,
                                    t_1_acc, t_1_per_class_acc, t_1__mean_iou))
                        sys.stdout.flush()

            ######################################################################
            # Training
            start_idx = (batch_size * batch_idx_train) % num_train
            end_idx = min(start_idx + batch_size, num_train)
            batch_size_train = end_idx - start_idx
            points_batch = data_train[start_idx:end_idx, ...]
            points_num_batch = data_num_train[start_idx:end_idx, ...]
            labels_batch = label_train[start_idx:end_idx, ...]
            weights_batch = np.array(label_weights_list)[labels_batch]

            # End of a pass over the current training data: move on to the
            # next file list (if there are several) and reshuffle.
            if start_idx + batch_size_train == num_train:
                if is_list_of_h5_list:
                    filelist_train_prev = seg_list[(seg_list_idx - 1) %
                                                   len(seg_list)]
                    filelist_train = seg_list[seg_list_idx % len(seg_list)]
                    if filelist_train != filelist_train_prev:
                        data_train, _, data_num_train, label_train, _ = data_utils.load_seg(
                            filelist_train)
                        num_train = data_train.shape[0]
                    seg_list_idx = seg_list_idx + 1
                data_train, data_num_train, label_train = \
                    data_utils.grouped_shuffle([data_train, data_num_train, label_train])

            # Randomly vary the number of sampled points, clipped.
            offset = int(
                random.gauss(0, sample_num * setting.sample_num_variance))
            offset = max(offset, -sample_num * setting.sample_num_clip)
            offset = min(offset, sample_num * setting.sample_num_clip)
            sample_num_train = sample_num + offset
            xforms_np, rotations_np = pf.get_xforms(
                batch_size_train,
                rotation_range=rotation_range,
                scaling_range=scaling_range,
                order=setting.rotation_order)

            # Metrics are reset every step, so the values logged below cover
            # the current batch only.
            sess.run(reset_metrics_op)

            feed_train = {
                pts_fts: points_batch,
                is_labelled_data: True,
                indices: pf.get_indices(batch_size_train, sample_num_train,
                                        points_num_batch),
                xforms: xforms_np,
                rotations: rotations_np,
                jitter_range: np.array([jitter]),
                labels_seg: labels_batch,
                labels_weights: weights_batch,
                is_training: True,
            }
            sess.run(
                [
                    train_op_all, loss_mean_update_op, t_1_acc_update_op,
                    t_1_per_class_acc_update_op, t_1_per_mean_iou_op_update_op
                ],
                feed_dict=feed_train)
            if batch_idx_train % 100 == 0:
                loss, t_1_acc, t_1_per_class_acc, t_1__mean_iou, summaries = sess.run(
                    [
                        loss_mean_op, t_1_acc_op, t_1_per_class_acc_op,
                        t_1_per_mean_iou_op, summaries_op
                    ],
                    feed_dict=feed_train)
                summary_writer.add_summary(summaries, batch_idx_train)
                print(
                    '{}-[Train]-Iter: {:06d}  Loss: {:.4f}  T-1 Acc: {:.4f}  T-1 mAcc: {:.4f}   T-1 mIoU: {:.4f}'
                    .format(datetime.now(), batch_idx_train, loss, t_1_acc,
                            t_1_per_class_acc, t_1__mean_iou))
                sys.stdout.flush()
            ######################################################################
        print('{}-Done!'.format(datetime.now()))
def main():
    """Train a classification network and validate it periodically.

    Command-line arguments select the model and setting modules, the output
    folder, optional epoch / batch-size overrides, and the two knobs of the
    pairwise feature loss (``--err_data`` is forwarded to ``prepare_data``,
    ``--weight`` scales the feature loss).

    The function builds a TF1 graph fed through ``tf.data`` iterators,
    minimises cross-entropy + weighted pairwise feature loss + regularisation,
    saves a checkpoint before every validation pass and writes ``train`` /
    ``val`` summaries into the run folder.

    Raises:
        ValueError: if ``setting.optimizer`` is neither ``'adam'`` nor
            ``'momentum'``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--load_ckpt', '-l', help='Path to a check point file for load')
    parser.add_argument('--save_folder', '-s', default='models',
                        help='Path to folder for saving check points and summary')
    parser.add_argument('--model', '-m', default='pointcnn_cls', help='Model to use')
    parser.add_argument('--setting', '-x', default='ScanObjectNN_x3_l4', help='Setting to use')
    parser.add_argument('--epochs', help='Number of training epochs (default defined in setting)', type=int)
    parser.add_argument('--batch_size', help='Batch size (default defined in setting)', type=int)
    # default='log.txt' stores the output in a file; default='-' shows it on stdout.
    parser.add_argument('--log', help='Log to FILE in save folder; use - for stdout (default is log.txt)',
                        metavar='FILE', default='log.txt')
    parser.add_argument('--no_timestamp_folder', help='Dont save to timestamp folder', action='store_true')
    # NOTE: the default is True, so the code backup further down is skipped
    # whether or not the flag is passed (same effective behaviour as before,
    # but now with a real boolean instead of the string 'True').
    parser.add_argument('--no_code_backup', help='Dont backup code', action='store_true', default=True)
    parser.add_argument('--err_data', type=float, default=5696)
    parser.add_argument('--weight', type=float, default=10)
    args = parser.parse_args()

    if not args.no_timestamp_folder:
        time_string = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
        root_folder = os.path.join(args.save_folder,
                                   '%s_%s_%s_%d' % (args.model, args.setting, time_string, os.getpid()))
    else:
        root_folder = args.save_folder
    os.makedirs(root_folder, exist_ok=True)

    if args.log != '-':
        # From here on every print() goes to the log file inside the run folder.
        sys.stdout = open(os.path.join(root_folder, args.log), 'w')

    print('PID:', os.getpid())
    print(args)

    # The setting module lives in a sub-folder named after the model.
    model = importlib.import_module(args.model)
    setting_path = os.path.join(os.path.dirname(__file__), args.model)
    sys.path.append(setting_path)
    setting = importlib.import_module(args.setting)

    num_epochs = args.epochs or setting.num_epochs
    batch_size = args.batch_size or setting.batch_size
    sample_num = setting.sample_num  # author's note: 1024 in the setting used
    step_val = setting.step_val
    rotation_range = setting.rotation_range
    rotation_range_val = setting.rotation_range_val
    scaling_range = setting.scaling_range
    scaling_range_val = setting.scaling_range_val
    jitter = setting.jitter
    jitter_val = setting.jitter_val
    # NOTE: pool_setting_val is read but not passed to the validation pf.get_indices call below.
    pool_setting_val = getattr(setting, 'pool_setting_val', None)
    pool_setting_train = getattr(setting, 'pool_setting_train', None)
    err_data = args.err_data
    loss_weight = args.weight

    # Prepare inputs
    print('{}-Preparing datasets...'.format(datetime.now()))
    if setting.data_set == 'ModelNet':
        prepare_data.make_ModelNet_data_A_B(err_data, batch_size)
    if setting.data_set == 'ScanObjectNN':
        prepare_data.make_ScanObjectNN_data_batch_A_B(err_data, batch_size)
    # Author's note (unverified): data_train [2*11416, 2048, 3], label_train [11416],
    # data_val [2882, 2048, 3], label_val [2882].
    data_train, label_train, data_val, label_val = setting.load_fn(setting.path, setting.path_val)

    if setting.balance_fn is not None:  # author's note: None in the setting used
        num_train_before_balance = data_train.shape[0]
        repeat_num = setting.balance_fn(label_train)
        data_train = np.repeat(data_train, repeat_num, axis=0)
        label_train = np.repeat(label_train, repeat_num, axis=0)
        data_train, label_train = data_utils.grouped_shuffle([data_train, label_train])
        # Scale the epoch count down so the total number of samples seen stays roughly the same.
        num_epochs = math.floor(num_epochs * (num_train_before_balance / data_train.shape[0]))

    if setting.save_ply_fn is not None:  # author's note: None in the setting used
        folder = os.path.join(root_folder, 'pts')
        print('{}-Saving samples as .ply files to {}...'.format(datetime.now(), folder))
        sample_num_for_ply = min(512, data_train.shape[0])
        if setting.map_fn is None:
            data_sample = data_train[:sample_num_for_ply]
        else:
            data_sample = np.stack([setting.map_fn(data_train[idx], 0)[0]
                                    for idx in range(sample_num_for_ply)])
        setting.save_ply_fn(data_sample, folder)

    num_train = data_train.shape[0]
    point_num = data_train.shape[1]
    num_val = data_val.shape[0]

    ######################################################################
    # Placeholders
    indices = tf.placeholder(tf.int32, shape=(None, None, 2), name="indices")
    xforms = tf.placeholder(tf.float32, shape=(None, 3, 3), name="xforms")
    rotations = tf.placeholder(tf.float32, shape=(None, 3, 3), name="rotations")
    jitter_range = tf.placeholder(tf.float32, shape=(1), name="jitter_range")
    global_step = tf.Variable(0, trainable=False, name='global_step')
    is_training = tf.placeholder(tf.bool, name='is_training')

    data_train_placeholder = tf.placeholder(data_train.dtype, data_train.shape, name='data_train')
    label_train_placeholder = tf.placeholder(tf.int64, label_train.shape, name='label_train')
    data_val_placeholder = tf.placeholder(data_val.dtype, data_val.shape, name='data_val')
    label_val_placeholder = tf.placeholder(tf.int64, label_val.shape, name='label_val')
    handle = tf.placeholder(tf.string, shape=[], name='handle')

    ######################################################################
    # NOTE(review): the training set is deliberately NOT shuffled here (the
    # shuffle call was disabled in the original) -- presumably so the A/B
    # halves prepared by prepare_data stay aligned inside a batch; confirm.
    dataset_train = tf.data.Dataset.from_tensor_slices((data_train_placeholder, label_train_placeholder))
    if setting.map_fn is not None:
        dataset_train = dataset_train.map(
            lambda data, label: tuple(tf.py_func(setting.map_fn, [data, label], [tf.float32, label.dtype])),
            num_parallel_calls=setting.num_parallel_calls)

    if setting.keep_remainder:
        dataset_train = dataset_train.batch(batch_size)
        batch_num_per_epoch = math.ceil(num_train / batch_size)
    else:
        dataset_train = dataset_train.apply(tf.contrib.data.batch_and_drop_remainder(batch_size))
        batch_num_per_epoch = math.floor(num_train / batch_size)
    dataset_train = dataset_train.repeat(num_epochs)
    iterator_train = dataset_train.make_initializable_iterator()
    batch_num = batch_num_per_epoch * num_epochs
    print('{}-{:d} training batches.'.format(datetime.now(), batch_num))

    dataset_val = tf.data.Dataset.from_tensor_slices((data_val_placeholder, label_val_placeholder))
    if setting.map_fn is not None:
        dataset_val = dataset_val.map(
            lambda data, label: tuple(tf.py_func(setting.map_fn, [data, label], [tf.float32, label.dtype])),
            num_parallel_calls=setting.num_parallel_calls)
    if setting.keep_remainder:
        dataset_val = dataset_val.batch(batch_size)
        batch_num_val = math.ceil(num_val / batch_size)
    else:
        dataset_val = dataset_val.apply(tf.contrib.data.batch_and_drop_remainder(batch_size))
        batch_num_val = math.floor(num_val / batch_size)
    iterator_val = dataset_val.make_initializable_iterator()
    print('{}-{:d} testing batches per test.'.format(datetime.now(), batch_num_val))

    # A single feedable iterator; the string handle selects train or val at run time.
    iterator = tf.data.Iterator.from_string_handle(handle, dataset_train.output_types)
    (pts_fts, labels) = iterator.get_next()

    pts_fts_sampled = tf.gather_nd(pts_fts, indices=indices, name='pts_fts_sampled')
    features_augmented = None
    points_sampled = pts_fts_sampled
    points_augmented = pf.augment(points_sampled, xforms, jitter_range)

    net = model.Net(points=points_augmented, features=features_augmented,
                    is_training=is_training, setting=setting)
    logits = net.logits
    feature_A = net.feature_list_A
    feature_B = net.feature_list_B
    probs = tf.nn.softmax(logits, name='probs')
    predictions = tf.argmax(probs, axis=-1, name='predictions')

    labels_2d = tf.expand_dims(labels, axis=-1, name='labels_2d')
    labels_tile = tf.tile(labels_2d, (1, tf.shape(logits)[1]), name='labels_tile')

    # --------------------------------------------------------------------
    # Compute the loss of DINet: a pairwise term on the A/B feature lists.
    # `different` is the squared A-B feature difference summed over axis 1.
    different = tf.square(tf.subtract(feature_A, feature_B))
    different = tf.reduce_sum(different, 1)
    # Labels are split at batch_size / 2 into an "A" and a "B" half.
    # NOTE(review): the split uses the nominal batch_size, not the actual batch
    # length, so a shorter remainder batch would not be split evenly -- confirm
    # that prepare_data always yields full, paired batches.
    label_A = tf.to_float(labels[0:tf.to_int32(batch_size / 2)])
    label_B = tf.to_float(labels[tf.to_int32(batch_size / 2):tf.to_int32(batch_size)])
    label_gap = tf.abs(tf.subtract(label_A, label_B))
    # sign(-|a-b|) + 1 is 1 where the labels are equal and 0 elsewhere.
    f_same = tf.multiply(tf.add(tf.sign(-label_gap), 1), different)
    # sign(|a-b|) is 1 where the labels differ; those pairs contribute 1 / distance.
    # NOTE(review): this divides by `different`, which is not guarded against zero.
    f_diff = tf.divide(tf.sign(label_gap), different)
    f_same = tf.reduce_sum(f_same, 0)
    f_diff = tf.reduce_sum(f_diff, 0)
    f_loss = tf.add(f_diff, f_same)
    # --------------------------------------------------------------------
    # 61440 is a hard-coded normaliser from the original code (meaning not
    # visible here -- TODO confirm); author's note: the weight is 10 by default.
    loss_op = tf.losses.sparse_softmax_cross_entropy(labels=labels_tile, logits=logits) \
        + ((f_loss / (61440 * batch_size / 2)) * loss_weight)

    with tf.name_scope('metrics'):
        loss_mean_op, loss_mean_update_op = tf.metrics.mean(loss_op)
        t_1_acc_op, t_1_acc_update_op = tf.metrics.accuracy(labels_tile, predictions)
        t_1_per_class_acc_op, t_1_per_class_acc_update_op = tf.metrics.mean_per_class_accuracy(
            labels_tile, predictions, setting.num_class)
    # Only the streaming-metric accumulators (local variables under 'metrics') are reset.
    reset_metrics_op = tf.variables_initializer([var for var in tf.local_variables()
                                                 if var.name.split('/')[0] == 'metrics'])

    _ = tf.summary.scalar('loss/train', tensor=loss_mean_op, collections=['train'])
    _ = tf.summary.scalar('t_1_acc/train', tensor=t_1_acc_op, collections=['train'])
    _ = tf.summary.scalar('t_1_per_class_acc/train', tensor=t_1_per_class_acc_op, collections=['train'])

    # NOTE(review): loss_mean_update_op is not run during validation below, so
    # 'loss/val' only reflects the freshly reset accumulator.
    _ = tf.summary.scalar('loss/val', tensor=loss_mean_op, collections=['val'])
    _ = tf.summary.scalar('t_1_acc/val', tensor=t_1_acc_op, collections=['val'])
    _ = tf.summary.scalar('t_1_per_class_acc/val', tensor=t_1_per_class_acc_op, collections=['val'])

    lr_exp_op = tf.train.exponential_decay(setting.learning_rate_base, global_step, setting.decay_steps,
                                           setting.decay_rate, staircase=True)
    lr_clip_op = tf.maximum(lr_exp_op, setting.learning_rate_min)
    _ = tf.summary.scalar('learning_rate', tensor=lr_clip_op, collections=['train'])
    reg_loss = setting.weight_decay * tf.losses.get_regularization_loss()
    if setting.optimizer == 'adam':
        optimizer = tf.train.AdamOptimizer(learning_rate=lr_clip_op, epsilon=setting.epsilon)
    elif setting.optimizer == 'momentum':
        optimizer = tf.train.MomentumOptimizer(learning_rate=lr_clip_op, momentum=setting.momentum,
                                               use_nesterov=True)
    else:
        # Fail with a clear message instead of an UnboundLocalError at minimize().
        raise ValueError('Unsupported optimizer: {!r}'.format(setting.optimizer))
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = optimizer.minimize(loss_op + reg_loss, global_step=global_step)

    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())

    saver = tf.train.Saver(max_to_keep=None)

    # backup all code
    if not args.no_code_backup:
        code_folder = os.path.abspath(os.path.dirname(__file__))
        shutil.copytree(code_folder, os.path.join(root_folder, os.path.basename(code_folder)))

    folder_ckpt = os.path.join(root_folder, 'ckpts')
    os.makedirs(folder_ckpt, exist_ok=True)

    folder_summary = os.path.join(root_folder, 'summary')
    os.makedirs(folder_summary, exist_ok=True)

    parameter_num = np.sum([np.prod(v.shape.as_list()) for v in tf.trainable_variables()])
    print('{}-Parameter number: {:d}.'.format(datetime.now(), parameter_num))

    with tf.Session() as sess:
        summaries_op = tf.summary.merge_all('train')
        summaries_val_op = tf.summary.merge_all('val')
        summary_writer = tf.summary.FileWriter(folder_summary, sess.graph)

        sess.run(init_op)

        # Load the model: an explicit checkpoint wins, otherwise resume from
        # the latest checkpoint in the run folder if there is one.
        if args.load_ckpt is not None:
            saver.restore(sess, args.load_ckpt)
            print('{}-Checkpoint loaded from {}!'.format(datetime.now(), args.load_ckpt))
        else:
            latest_ckpt = tf.train.latest_checkpoint(folder_ckpt)
            if latest_ckpt:
                print('{}-Found checkpoint {}'.format(datetime.now(), latest_ckpt))
                saver.restore(sess, latest_ckpt)
                print('{}-Checkpoint loaded from {} (Iter {})'.format(
                    datetime.now(), latest_ckpt, sess.run(global_step)))

        handle_train = sess.run(iterator_train.string_handle())
        handle_val = sess.run(iterator_val.string_handle())

        sess.run(iterator_train.initializer, feed_dict={
            data_train_placeholder: data_train,
            label_train_placeholder: label_train,
        })

        for batch_idx_train in range(batch_num):
            ######################################################################
            # Validation
            if (batch_idx_train % step_val == 0 and (batch_idx_train != 0 or args.load_ckpt is not None)) \
                    or batch_idx_train == batch_num - 1:
                sess.run(iterator_val.initializer, feed_dict={
                    data_val_placeholder: data_val,
                    label_val_placeholder: label_val,
                })
                filename_ckpt = os.path.join(folder_ckpt, 'iter')
                saver.save(sess, filename_ckpt, global_step=global_step)
                print('{}-Checkpoint saved to {}!'.format(datetime.now(), filename_ckpt))

                sess.run(reset_metrics_op)
                total_seen_class = [0 for _ in range(setting.num_class)]
                total_correct_class = [0 for _ in range(setting.num_class)]
                for batch_idx_val in range(batch_num_val):
                    # Only the last batch can be short, and only when remainders are kept.
                    if not setting.keep_remainder \
                            or num_val % batch_size == 0 \
                            or batch_idx_val != batch_num_val - 1:
                        batch_size_val = batch_size
                    else:
                        batch_size_val = num_val % batch_size
                    xforms_np, rotations_np = pf.get_xforms(batch_size_val,
                                                            rotation_range=rotation_range_val,
                                                            scaling_range=scaling_range_val,
                                                            order=setting.rotation_order)
                    labels_np, predictions_np, _, _ = sess.run(
                        [labels, predictions, t_1_acc_update_op, t_1_per_class_acc_update_op],
                        feed_dict={
                            handle: handle_val,
                            indices: pf.get_indices(batch_size_val, sample_num, point_num),
                            xforms: xforms_np,
                            rotations: rotations_np,
                            jitter_range: np.array([jitter_val]),
                            is_training: False,
                        })
                    predictions_np = np.reshape(predictions_np, [-1])
                    # Tally per-class hit counts for the per-class accuracy print below.
                    for i, label in enumerate(labels_np):
                        total_seen_class[label] += 1
                        total_correct_class[label] += (predictions_np[i] == label)
                t_1_acc_val, t_1_per_class_acc_val, summaries_val, step = sess.run(
                    [t_1_acc_op, t_1_per_class_acc_op, summaries_val_op, global_step])
                summary_writer.add_summary(summaries_val, step)
                print('{}-[Val ]-Average: T-1 Acc: {:.4f} T-1 mAcc: {:.4f}'
                      .format(datetime.now(), t_1_acc_val, t_1_per_class_acc_val))
                # `float` replaces the removed `np.float` alias (same dtype).
                print('every class acc:',
                      np.array(total_correct_class) / np.array(total_seen_class, dtype=float))
                sys.stdout.flush()
            ######################################################################

            ######################################################################
            # Training
            if not setting.keep_remainder \
                    or num_train % batch_size == 0 \
                    or (batch_idx_train % batch_num_per_epoch) != (batch_num_per_epoch - 1):
                batch_size_train = batch_size
            else:
                batch_size_train = num_train % batch_size

            # Randomly perturb the number of sampled points, clipped to +/- sample_num * clip.
            offset = int(random.gauss(0, sample_num * setting.sample_num_variance))
            offset = max(offset, -sample_num * setting.sample_num_clip)
            offset = min(offset, sample_num * setting.sample_num_clip)
            sample_num_train = sample_num + offset
            xforms_np, rotations_np = pf.get_xforms(batch_size_train,
                                                    rotation_range=rotation_range,
                                                    scaling_range=scaling_range,
                                                    order=setting.rotation_order)
            # Metrics are reset every step, so the values logged below cover this batch only.
            sess.run(reset_metrics_op)
            sess.run([train_op, loss_mean_update_op, t_1_acc_update_op, t_1_per_class_acc_update_op],
                     feed_dict={
                         handle: handle_train,
                         indices: pf.get_indices(batch_size_train, sample_num_train, point_num,
                                                 pool_setting_train),
                         xforms: xforms_np,
                         rotations: rotations_np,
                         jitter_range: np.array([jitter]),
                         is_training: True,
                     })
            if batch_idx_train % 10 == 0:
                loss, t_1_acc, t_1_per_class_acc, summaries, step = sess.run(
                    [loss_mean_op, t_1_acc_op, t_1_per_class_acc_op, summaries_op, global_step])
                summary_writer.add_summary(summaries, step)
                print('{}-[Train]-Iter: {:06d} Loss: {:.4f} T-1 Acc: {:.4f} T-1 mAcc: {:.4f}'
                      .format(datetime.now(), step, loss, t_1_acc, t_1_per_class_acc))
                sys.stdout.flush()
            ######################################################################
        print('{}-Done!'.format(datetime.now()))
def main():
    """Train and validate the KITTI multi-task segmentation network.

    Besides per-point segmentation, the network exposes three extra heads
    that are trained jointly: ``hwl`` and ``xyz`` (mean-squared-error
    regression) and ``ry`` (softmax classification).  Data is loaded with
    ``data_utils.load_seg_kitti`` from the file lists given on the command
    line.  Every ``step_val`` iterations (and on the last iteration) a
    checkpoint is saved and a full validation pass is run; its sample-weighted
    averages are written to the ``val`` summaries.  All console output is
    redirected to ``log.txt`` inside the run folder.

    Raises:
        ValueError: if ``setting.optimizer`` is neither ``'adam'`` nor
            ``'momentum'``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--filelist', '-t', help='Path to training set ground truth (.txt)', required=True)
    parser.add_argument('--filelist_val', '-v', help='Path to validation set ground truth (.txt)', required=True)
    parser.add_argument('--load_ckpt', '-l', help='Path to a check point file for load')
    parser.add_argument('--save_folder', '-s', help='Path to folder for saving check points and summary',
                        required=True)
    parser.add_argument('--model', '-m', help='Model to use', required=True)
    parser.add_argument('--setting', '-x', help='Setting to use', required=True)
    args = parser.parse_args()

    time_string = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    root_folder = os.path.join(args.save_folder,
                               '%s_%s_%d_%s' % (args.model, args.setting, os.getpid(), time_string))
    os.makedirs(root_folder, exist_ok=True)

    # From here on every print() goes to the log file inside the run folder.
    sys.stdout = open(os.path.join(root_folder, 'log.txt'), 'w')

    print('PID:', os.getpid())
    print(args)

    # The setting module lives in a sub-folder named after the model.
    model = importlib.import_module(args.model)
    setting_path = os.path.join(os.path.dirname(__file__), args.model)
    sys.path.append(setting_path)
    print(setting_path)
    setting = importlib.import_module(args.setting)

    num_epochs = setting.num_epochs
    batch_size = setting.batch_size
    sample_num = setting.sample_num
    step_val = 1000
    num_parts = setting.num_parts
    label_weights_list = setting.label_weights
    # Built once: used as a lookup table label -> weight for every batch.
    label_weights = np.array(label_weights_list)

    scaling_range = setting.scaling_range
    scaling_range_val = setting.scaling_range_val
    jitter = setting.jitter
    jitter_val = setting.jitter_val

    # Prepare inputs
    print('{}-Preparing datasets...'.format(datetime.now()))
    data_train, img_train, data_num_train, label_train, label_hwl_train, label_xyz_train, \
        label_ry_train, label_ry_reg_train = data_utils.load_seg_kitti(args.filelist)
    data_val, img_val, data_num_val, label_val, label_hwl_val, label_xyz_val, \
        label_ry_val, label_ry_reg_val = data_utils.load_seg_kitti(args.filelist_val)

    # shuffle (all arrays together so samples stay aligned)
    data_train, img_train, data_num_train, label_train, label_hwl_train, label_xyz_train, \
        label_ry_train, label_ry_reg_train = data_utils.grouped_shuffle(
            [data_train, img_train, data_num_train, label_train, label_hwl_train, label_xyz_train,
             label_ry_train, label_ry_reg_train])

    num_train = data_train.shape[0]
    point_num = data_train.shape[1]
    num_val = data_val.shape[0]
    print('{}-{:d}/{:d} training/validation samples.'.format(datetime.now(), num_train, num_val))
    batch_num = (num_train * num_epochs + batch_size - 1) // batch_size
    print('{}-{:d} training batches.'.format(datetime.now(), batch_num))
    batch_num_val = math.ceil(num_val / batch_size)
    print('{}-{:d} testing batches per test.'.format(datetime.now(), batch_num_val))

    ######################################################################
    # Placeholders
    indices = tf.placeholder(tf.int32, shape=(None, None, 2), name="indices")
    xforms = tf.placeholder(tf.float32, shape=(None, 3, 3), name="xforms")
    rotations = tf.placeholder(tf.float32, shape=(None, 3, 3), name="rotations")
    jitter_range = tf.placeholder(tf.float32, shape=(1), name="jitter_range")
    global_step = tf.Variable(0, trainable=False, name='global_step')
    is_training = tf.placeholder(tf.bool, name='is_training')

    pts_fts = tf.placeholder(tf.float32, shape=(None, point_num, setting.data_dim), name='pts_fts')
    labels_seg = tf.placeholder(tf.int32, shape=(None, point_num), name='labels_seg')
    labels_weights = tf.placeholder(tf.float32, shape=(None, point_num), name='labels_weights')

    labels_hwl = tf.placeholder(tf.float32, shape=(None, 3), name='labels_hwl')
    labels_xyz = tf.placeholder(tf.float32, shape=(None, 3), name='labels_xyz')
    labels_ry = tf.placeholder(tf.int32, shape=(None, 1), name='labels_ry')
    labels_ry_reg = tf.placeholder(tf.float32, shape=(None, 1), name='labels_ry_reg')

    ######################################################################
    # Set Inputs(points,features_sampled)
    features_sampled = None

    if setting.data_dim > 3:
        # Split the channel axis according to the layout declared in the setting.
        points, _, features = tf.split(pts_fts,
                                       [setting.data_format["pts_xyz"],
                                        setting.data_format["img_xy"],
                                        setting.data_format["extra_features"]],
                                       axis=-1,
                                       name='split_points_xy_features')
        if setting.use_extra_features:
            features_sampled = tf.gather_nd(features, indices=indices, name='features_sampled')
    else:
        points = pts_fts

    points_sampled = tf.gather_nd(points, indices=indices, name='points_sampled')
    points_augmented = pf.augment(points_sampled, xforms, jitter_range)

    labels_sampled = tf.gather_nd(labels_seg, indices=indices, name='labels_sampled')
    labels_weights_sampled = tf.gather_nd(labels_weights, indices=indices, name='labels_weight_sampled')

    # Build net
    net = model.Net(points_augmented, features_sampled, None, None, num_parts, is_training, setting)
    logits, probs = net.logits, net.probs
    logits_hwl, logits_xyz = net.logits_hwl, net.logits_xyz
    logits_ry, probs_ry = net.logits_ry, net.probs_ry

    # Define Loss Func
    loss_op = tf.losses.sparse_softmax_cross_entropy(labels=labels_sampled, logits=logits,
                                                     weights=labels_weights_sampled)
    _ = tf.summary.scalar('loss/train_seg', tensor=loss_op, collections=['train'])

    # HWL loss: MSE provided by the API is used here as an example.
    loss_hwl_op = tf.losses.mean_squared_error(labels_hwl, logits_hwl)
    # RMSE is computed as an evaluation metric.
    rmse_hwl_op = tf.sqrt(tf.reduce_mean(tf.reduce_sum(tf.square(labels_hwl - logits_hwl), -1)))

    # XYZ loss: MSE is used here as an example.
    loss_xyz_op = tf.losses.mean_squared_error(labels_xyz, logits_xyz)
    # RMSE is computed as an evaluation metric.
    rmse_xyz_op = tf.sqrt(tf.reduce_mean(tf.reduce_sum(tf.square(labels_xyz - logits_xyz), -1)))

    # RY loss: classification, softmax cross-entropy is used as an example.
    loss_ry_op = tf.losses.sparse_softmax_cross_entropy(labels=labels_ry, logits=logits_ry)
    # Classification accuracy is computed as an evaluation metric.
    t_1_acc_ry_op = pf.top_1_accuracy(probs_ry, labels_ry)
    # Mean angular deviation is computed as an evaluation metric
    # (sqrt(square(x)) is the absolute difference of the two angles).
    mdis_ry_angle_op = tf.reduce_mean(tf.sqrt(tf.square(
        pf.ry_getangle(labels_ry, labels_ry_reg)
        - pf.ry_getangle(tf.nn.top_k(probs_ry, 1)[1], tf.zeros_like(labels_ry_reg)))))

    # Visualise the results in TensorBoard.
    _ = tf.summary.scalar('loss/train_hwl', tensor=loss_hwl_op, collections=['train'])
    _ = tf.summary.scalar('loss/train_xyz', tensor=loss_xyz_op, collections=['train'])
    _ = tf.summary.scalar('loss/train_ry', tensor=loss_ry_op, collections=['train'])
    _ = tf.summary.scalar('t_1_acc/train_rmse_hwl', tensor=rmse_hwl_op, collections=['train'])
    _ = tf.summary.scalar('t_1_acc/train_rmse_xyz', tensor=rmse_xyz_op, collections=['train'])
    _ = tf.summary.scalar('t_1_acc/train_ry', tensor=t_1_acc_ry_op, collections=['train'])
    _ = tf.summary.scalar('t_1_acc/mdis_ry_angle', tensor=mdis_ry_angle_op, collections=['train'])

    # for vis t1 acc
    t_1_acc_op = pf.top_1_accuracy(probs, labels_sampled)
    _ = tf.summary.scalar('t_1_acc/train_seg', tensor=t_1_acc_op, collections=['train'])
    # for vis instance acc
    t_1_acc_instance_op = pf.top_1_accuracy(probs, labels_sampled, labels_weights_sampled, 0.6)
    _ = tf.summary.scalar('t_1_acc/train_seg_instance', tensor=t_1_acc_instance_op, collections=['train'])
    # for vis other acc
    t_1_acc_others_op = pf.top_1_accuracy(probs, labels_sampled, labels_weights_sampled, 0.6, "less")
    _ = tf.summary.scalar('t_1_acc/train_seg_others', tensor=t_1_acc_others_op, collections=['train'])

    # Validation summaries are fed through placeholders because the values are
    # averages computed in Python over the whole validation pass.
    loss_val_avg = tf.placeholder(tf.float32)
    _ = tf.summary.scalar('loss/val_seg', tensor=loss_val_avg, collections=['val'])
    t_1_acc_val_avg = tf.placeholder(tf.float32)
    _ = tf.summary.scalar('t_1_acc/val_seg', tensor=t_1_acc_val_avg, collections=['val'])
    t_1_acc_val_instance_avg = tf.placeholder(tf.float32)
    _ = tf.summary.scalar('t_1_acc/val_seg_instance', tensor=t_1_acc_val_instance_avg, collections=['val'])
    t_1_acc_val_others_avg = tf.placeholder(tf.float32)
    _ = tf.summary.scalar('t_1_acc/val_seg_others', tensor=t_1_acc_val_others_avg, collections=['val'])

    reg_loss = setting.weight_decay * tf.losses.get_regularization_loss()

    # For validation, the mean over the whole validation pass is the metric.
    loss_val_avg_hwl = tf.placeholder(tf.float32)
    loss_val_avg_xyz = tf.placeholder(tf.float32)
    loss_val_avg_ry = tf.placeholder(tf.float32)
    _ = tf.summary.scalar('loss/val_hwl', tensor=loss_val_avg_hwl, collections=['val'])
    _ = tf.summary.scalar('loss/val_xyz', tensor=loss_val_avg_xyz, collections=['val'])
    _ = tf.summary.scalar('loss/val_ry', tensor=loss_val_avg_ry, collections=['val'])
    rmse_val_hwl_avg = tf.placeholder(tf.float32)
    rmse_val_xyz_avg = tf.placeholder(tf.float32)
    t_1_acc_val_ry_avg = tf.placeholder(tf.float32)
    mdis_val_ry_angle_avg = tf.placeholder(tf.float32)
    _ = tf.summary.scalar('t_1_acc/val_rmse_hwl', tensor=rmse_val_hwl_avg, collections=['val'])
    _ = tf.summary.scalar('t_1_acc/val_rmse_xyz', tensor=rmse_val_xyz_avg, collections=['val'])
    _ = tf.summary.scalar('t_1_acc/val_ry', tensor=t_1_acc_val_ry_avg, collections=['val'])
    _ = tf.summary.scalar('t_1_acc/val_mdis_ry_angle', tensor=mdis_val_ry_angle_avg, collections=['val'])

    # lr decay
    lr_exp_op = tf.train.exponential_decay(setting.learning_rate_base, global_step, setting.decay_steps,
                                           setting.decay_rate, staircase=True)
    lr_clip_op = tf.maximum(lr_exp_op, setting.learning_rate_min)
    _ = tf.summary.scalar('learning_rate', tensor=lr_clip_op, collections=['train'])

    # Optimizer
    if setting.optimizer == 'adam':
        optimizer = tf.train.AdamOptimizer(learning_rate=lr_clip_op, epsilon=setting.epsilon)
    elif setting.optimizer == 'momentum':
        optimizer = tf.train.MomentumOptimizer(learning_rate=lr_clip_op, momentum=0.9, use_nesterov=True)
    else:
        # Fail with a clear message instead of an UnboundLocalError at minimize().
        raise ValueError('Unsupported optimizer: {!r}'.format(setting.optimizer))
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        # All task losses are summed with equal weight.
        train_op = optimizer.minimize(loss_op + reg_loss + loss_hwl_op + loss_xyz_op + loss_ry_op,
                                      global_step=global_step)

    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())

    saver = tf.train.Saver(max_to_keep=None)

    # backup this file, model and setting
    os.makedirs(os.path.join(root_folder, args.model), exist_ok=True)
    source_dir = os.path.dirname(__file__)
    kitti_module_file = args.model.split("_")[0] + "_kitti" + '.py'
    shutil.copy(__file__, os.path.join(root_folder, os.path.basename(__file__)))
    shutil.copy(os.path.join(source_dir, args.model + '.py'),
                os.path.join(root_folder, args.model + '.py'))
    shutil.copy(os.path.join(source_dir, kitti_module_file),
                os.path.join(root_folder, kitti_module_file))
    shutil.copy(os.path.join(setting_path, args.setting + '.py'),
                os.path.join(root_folder, args.model, args.setting + '.py'))

    folder_ckpt = os.path.join(root_folder, 'ckpts')
    os.makedirs(folder_ckpt, exist_ok=True)

    folder_summary = os.path.join(root_folder, 'summary')
    os.makedirs(folder_summary, exist_ok=True)

    parameter_num = np.sum([np.prod(v.shape.as_list()) for v in tf.trainable_variables()])
    print('{}-Parameter number: {:d}.'.format(datetime.now(), parameter_num))

    # Log line templates (kept byte-identical to the original messages).
    val_iter_fmt = ('{}-[Val ]-Iter: {:06d} Loss: {:.4f} Loss_hwl: {:.4f} Loss_xyz: {:.4f} '
                    'Loss_ry: {:.4f}Rmse_hwl: {:.4f} Rmse_xyz: {:.4f} T - 1Ry_Acc: {:.4f} '
                    'Mdis_ry_angle: {:.4f} T - 1Acc: {:.4f}')
    val_avg_fmt = ('{}-[Val ]-Average: Loss: {:.4f} Loss_hwl: {:.4f} Loss_xyz: {:.4f} '
                   'Loss_ry: {:.4f}Rmse_hwl: {:.4f} Rmse_xyz: {:.4f} T - 1Ry_Acc: {:.4f} '
                   'Mdis_ry_angle: {:.4f} T - 1Acc: {:.4f}')
    train_iter_fmt = ('{}-[Train]-Iter: {:06d} Loss_seg: {:.4f} Loss_hwl: {:.4f} Loss_xyz: {:.4f} '
                      'Loss_ry:{:.4f} Rmse_hwl: {:.4f} Rmse_xyz: {:.4f} T - 1Ry_Acc: {:.4f} '
                      'Mdis_ry_angle: {:.4f} T - 1Acc:{:.4f}')

    # Session Run
    with tf.Session() as sess:
        summaries_op = tf.summary.merge_all('train')
        summaries_val_op = tf.summary.merge_all('val')
        summary_writer = tf.summary.FileWriter(folder_summary, sess.graph)

        sess.run(init_op)

        # Load the model
        if args.load_ckpt is not None:
            saver.restore(sess, args.load_ckpt)
            print('{}-Checkpoint loaded from {}!'.format(datetime.now(), args.load_ckpt))

        for batch_idx in range(batch_num):
            if (batch_idx != 0 and batch_idx % step_val == 0) or batch_idx == batch_num - 1:
                ######################################################################
                # Validation
                filename_ckpt = os.path.join(folder_ckpt, 'iter')
                saver.save(sess, filename_ckpt, global_step=global_step)
                print('{}-Checkpoint saved to {}!'.format(datetime.now(), filename_ckpt))

                # Per-batch values, each pre-multiplied by its batch size so
                # that dividing the sums by num_val gives sample-weighted means.
                losses_val = []
                t_1_accs = []
                t_1_accs_instance = []
                t_1_accs_others = []
                losses_val_hwl = []
                losses_val_xyz = []
                losses_val_ry = []
                rmses_hwl = []
                rmses_xyz = []
                t_1_accs_ry = []
                mdises_ry_angle = []

                for batch_val_idx in range(batch_num_val):
                    start_idx = batch_size * batch_val_idx
                    end_idx = min(start_idx + batch_size, num_val)
                    batch_size_val = end_idx - start_idx
                    points_batch = data_val[start_idx:end_idx, ...]
                    points_num_batch = data_num_val[start_idx:end_idx, ...]
                    labels_batch = label_val[start_idx:end_idx, ...]
                    weights_batch = label_weights[labels_batch]

                    # Slice out the extra labels of this batch.
                    labels_hwl_batch = label_hwl_val[start_idx:end_idx, ...]
                    labels_xyz_batch = label_xyz_val[start_idx:end_idx, ...]
                    labels_ry_batch = label_ry_val[start_idx:end_idx, ...]
                    labels_ry_reg_batch = label_ry_reg_val[start_idx:end_idx, ...]

                    xforms_np, rotations_np = pf.get_xforms(batch_size_val, scaling_range=scaling_range_val)

                    sess_op_list = [loss_op, t_1_acc_op, t_1_acc_instance_op, t_1_acc_others_op,
                                    loss_hwl_op, loss_xyz_op, loss_ry_op, rmse_hwl_op, rmse_xyz_op,
                                    t_1_acc_ry_op, mdis_ry_angle_op]

                    sess_feed_dict = {
                        pts_fts: points_batch,
                        indices: pf.get_indices(batch_size_val, sample_num, points_num_batch),
                        xforms: xforms_np,
                        rotations: rotations_np,
                        jitter_range: np.array([jitter_val]),
                        labels_seg: labels_batch,
                        labels_weights: weights_batch,
                        is_training: False,
                        labels_hwl: labels_hwl_batch,
                        labels_xyz: labels_xyz_batch,
                        labels_ry: labels_ry_batch,
                        labels_ry_reg: labels_ry_reg_batch,
                    }

                    loss_val, t_1_acc_val, t_1_acc_val_instance, t_1_acc_val_others, \
                        loss_val_hwl, loss_val_xyz, loss_val_ry, rmse_hwl_val, rmse_xyz_val, \
                        t_1_acc_ry_val, mdis_ry_angle_val = sess.run(sess_op_list, feed_dict=sess_feed_dict)

                    print(val_iter_fmt.format(datetime.now(), batch_val_idx, loss_val, loss_val_hwl,
                                              loss_val_xyz, loss_val_ry, rmse_hwl_val, rmse_xyz_val,
                                              t_1_acc_ry_val, mdis_ry_angle_val, t_1_acc_val))
                    sys.stdout.flush()

                    losses_val.append(loss_val * batch_size_val)
                    t_1_accs.append(t_1_acc_val * batch_size_val)
                    t_1_accs_instance.append(t_1_acc_val_instance * batch_size_val)
                    t_1_accs_others.append(t_1_acc_val_others * batch_size_val)

                    # Record the metrics of this iteration.
                    losses_val_hwl.append(loss_val_hwl * batch_size_val)
                    losses_val_xyz.append(loss_val_xyz * batch_size_val)
                    losses_val_ry.append(loss_val_ry * batch_size_val)
                    rmses_hwl.append(rmse_hwl_val * batch_size_val)
                    rmses_xyz.append(rmse_xyz_val * batch_size_val)
                    t_1_accs_ry.append(t_1_acc_ry_val * batch_size_val)
                    mdises_ry_angle.append(mdis_ry_angle_val * batch_size_val)

                loss_avg = sum(losses_val) / num_val
                t_1_acc_avg = sum(t_1_accs) / num_val
                t_1_acc_instance_avg = sum(t_1_accs_instance) / num_val
                t_1_acc_others_avg = sum(t_1_accs_others) / num_val

                loss_avg_hwl = sum(losses_val_hwl) / num_val
                loss_avg_xyz = sum(losses_val_xyz) / num_val
                loss_avg_ry = sum(losses_val_ry) / num_val
                rmse_hwl_avg = sum(rmses_hwl) / num_val
                rmse_xyz_avg = sum(rmses_xyz) / num_val
                t_1_acc_ry_avg = sum(t_1_accs_ry) / num_val
                mdis_ry_angle_avg = sum(mdises_ry_angle) / num_val

                summaries_feed_dict = {
                    loss_val_avg: loss_avg,
                    t_1_acc_val_avg: t_1_acc_avg,
                    t_1_acc_val_instance_avg: t_1_acc_instance_avg,
                    t_1_acc_val_others_avg: t_1_acc_others_avg,
                    loss_val_avg_hwl: loss_avg_hwl,
                    loss_val_avg_xyz: loss_avg_xyz,
                    loss_val_avg_ry: loss_avg_ry,
                    rmse_val_hwl_avg: rmse_hwl_avg,
                    rmse_val_xyz_avg: rmse_xyz_avg,
                    t_1_acc_val_ry_avg: t_1_acc_ry_avg,
                    mdis_val_ry_angle_avg: mdis_ry_angle_avg,
                }

                summaries_val = sess.run(summaries_val_op, feed_dict=summaries_feed_dict)
                summary_writer.add_summary(summaries_val, batch_idx)

                print(val_avg_fmt.format(datetime.now(), loss_avg, loss_avg_hwl, loss_avg_xyz, loss_avg_ry,
                                         rmse_hwl_avg, rmse_xyz_avg, t_1_acc_ry_avg, mdis_ry_angle_avg,
                                         t_1_acc_avg))
                sys.stdout.flush()
                ######################################################################

            ######################################################################
            # Training
            start_idx = (batch_size * batch_idx) % num_train
            end_idx = min(start_idx + batch_size, num_train)
            batch_size_train = end_idx - start_idx
            points_batch = data_train[start_idx:end_idx, ...]
            points_num_batch = data_num_train[start_idx:end_idx, ...]
            labels_batch = label_train[start_idx:end_idx, ...]
            weights_batch = label_weights[labels_batch]

            labels_hwl_batch = label_hwl_train[start_idx:end_idx, ...]
            labels_xyz_batch = label_xyz_train[start_idx:end_idx, ...]
            labels_ry_batch = label_ry_train[start_idx:end_idx, ...]
            labels_ry_reg_batch = label_ry_reg_train[start_idx:end_idx, ...]

            # End of an epoch: reshuffle for the next pass (the current batch
            # has already been sliced out above).
            if start_idx + batch_size_train == num_train:
                data_train, img_train, data_num_train, label_train, label_hwl_train, label_xyz_train, \
                    label_ry_train, label_ry_reg_train = data_utils.grouped_shuffle(
                        [data_train, img_train, data_num_train, label_train, label_hwl_train,
                         label_xyz_train, label_ry_train, label_ry_reg_train])

            # Randomly perturb the number of sampled points, clipped to about +/- sample_num / 4.
            offset = int(random.gauss(0, sample_num // 8))
            offset = max(offset, -sample_num // 4)
            offset = min(offset, sample_num // 4)
            sample_num_train = sample_num + offset
            xforms_np, rotations_np = pf.get_xforms(batch_size_train, scaling_range=scaling_range)

            sess_op_list = [train_op, loss_op, t_1_acc_op, t_1_acc_instance_op, t_1_acc_others_op,
                            summaries_op,
                            loss_hwl_op, loss_xyz_op, loss_ry_op, rmse_hwl_op, rmse_xyz_op,
                            t_1_acc_ry_op, mdis_ry_angle_op]

            # Feed the inputs and all labels to their placeholders.
            sess_feed_dict = {
                pts_fts: points_batch,
                indices: pf.get_indices(batch_size_train, sample_num_train, points_num_batch),
                xforms: xforms_np,
                rotations: rotations_np,
                jitter_range: np.array([jitter]),
                labels_seg: labels_batch,
                labels_weights: weights_batch,
                is_training: True,
                labels_hwl: labels_hwl_batch,
                labels_xyz: labels_xyz_batch,
                labels_ry: labels_ry_batch,
                labels_ry_reg: labels_ry_reg_batch,
            }

            _, loss, t_1_acc, t_1_acc_instance, t_1_acc_others, summaries, \
                loss_hwl, loss_xyz, loss_ry, rmse_hwl, rmse_xyz, t_1_acc_ry, mdis_ry_angle = sess.run(
                    sess_op_list, feed_dict=sess_feed_dict)

            print(train_iter_fmt.format(datetime.now(), batch_idx, loss, loss_hwl, loss_xyz, loss_ry,
                                        rmse_hwl, rmse_xyz, t_1_acc_ry, mdis_ry_angle, t_1_acc))
            summary_writer.add_summary(summaries, batch_idx)
            sys.stdout.flush()
            ######################################################################
        print('{}-Done!'.format(datetime.now()))
def main():
    """Fine-tune the 'logits' variables of a segmentation net and validate periodically.

    All configuration is hard-coded below (file lists under SCENENN_DIR,
    checkpoint path, model and setting module names). The function:

    1. Creates a timestamped output folder with 'ckpts' and 'summary' subfolders.
    2. Imports the model and setting modules by name and loads the training and
       validation data through ``data_utils.load_seg``.
    3. Builds the TF graph: sampling, augmentation, net, weighted cross-entropy
       loss, streaming metrics, and an optimizer whose ``var_list`` is limited
       to the trainable variables collected under the 'logits' scope.
    4. Restores variables from ``args.load_ckpt`` using
       ``get_variables_to_restore(exclude=['logits'])`` (presumably everything
       except the 'logits' scope -- confirm against that helper).
    5. Runs the training loop, saving a checkpoint and running validation every
       ``setting.step_val`` iterations and on the last iteration.

    Raises:
        ValueError: if ``setting.optimizer`` is neither 'adam' nor 'momentum'.
        SystemExit: if normals are requested but ``setting.data_dim < 6``.
    """
    args = AttrDict()
    # Path of training set ground truth file list (.txt)
    args.filelist = os.path.join(SCENENN_DIR, 'train_files.txt')
    # Path of validation set ground truth file list (.txt)
    args.filelist_val = os.path.join(SCENENN_DIR, 'test_files.txt')
    # Path of a check point file to load
    args.load_ckpt = os.path.join(ROOT_DIR, '..', 'models', 'pretrained_scannet', 'ckpts', 'iter-354000')
    # Base directory where model checkpoint and summary files get saved in separate subdirectories
    args.save_folder = os.path.join(ROOT_DIR, '..', 'models')
    # PointCNN model to use
    args.model = 'pointcnn_seg'
    # Model setting to use
    args.setting = 'scenenn_x8_2048_fps'

    # One output folder per run: <model>_<setting>_<timestamp>_<pid>
    time_string = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    model_save_folder = os.path.join(
        args.save_folder,
        '%s_%s_%s_%d' % (args.model, args.setting, time_string, os.getpid()))
    if not os.path.exists(model_save_folder):
        os.makedirs(model_save_folder)

    # sys.stdout = open(os.path.join(model_save_folder, 'log.txt'), 'w')

    print('PID:', os.getpid())

    print(args)

    model = importlib.import_module(args.model)
    # The setting module lives in a folder named after the model, so that
    # folder has to be on sys.path before it can be imported by name.
    setting_path = os.path.join(ROOT_DIR, args.model)
    sys.path.append(setting_path)
    setting = importlib.import_module(args.setting)

    num_epochs = setting.num_epochs
    batch_size = setting.batch_size
    sample_num = setting.sample_num
    step_val = setting.step_val
    label_weights_list = setting.label_weights
    rotation_range = setting.rotation_range
    rotation_range_val = setting.rotation_range_val
    scaling_range = setting.scaling_range
    scaling_range_val = setting.scaling_range_val
    jitter = setting.jitter
    jitter_val = setting.jitter_val

    # Built once; indexed with the integer label array of every batch to get
    # per-point loss weights.
    label_weights_np = np.array(label_weights_list)

    # Prepare inputs
    print('{}-Preparing datasets...'.format(datetime.now()))
    # args.filelist is either a list of h5 files, or a list of such lists that
    # is cycled through one entry at a time during training.
    is_list_of_h5_list = not data_utils.is_h5_list(args.filelist)
    if is_list_of_h5_list:
        seg_list = data_utils.load_seg_list(args.filelist)
        seg_list_idx = 0
        filelist_train = seg_list[seg_list_idx]
        seg_list_idx = seg_list_idx + 1
    else:
        filelist_train = args.filelist
    data_train, _, data_num_train, label_train, _ = data_utils.load_seg(
        filelist_train)
    data_val, _, data_num_val, label_val, _ = data_utils.load_seg(
        args.filelist_val)

    # shuffle
    data_train, data_num_train, label_train = \
        data_utils.grouped_shuffle([data_train, data_num_train, label_train])

    num_train = data_train.shape[0]
    point_num = data_train.shape[1]
    num_val = data_val.shape[0]
    print('{}-{:d}/{:d} training/validation samples.'.format(
        datetime.now(), num_train, num_val))
    # Ceiling division: total number of training batches over all epochs.
    batch_num = (num_train * num_epochs + batch_size - 1) // batch_size
    print('{}-{:d} training batches.'.format(datetime.now(), batch_num))
    batch_num_val = math.ceil(num_val / batch_size)
    print('{}-{:d} testing batches per test.'.format(datetime.now(),
                                                     batch_num_val))

    ######################################################################
    # Placeholders
    print('{}-Initializing TF-placeholders...'.format(datetime.now()))
    indices = tf.placeholder(tf.int32, shape=(None, None, 2), name="indices")
    xforms = tf.placeholder(tf.float32, shape=(None, 3, 3), name="xforms")
    rotations = tf.placeholder(tf.float32, shape=(None, 3, 3), name="rotations")
    jitter_range = tf.placeholder(tf.float32, shape=(1), name="jitter_range")
    global_step = tf.Variable(0, trainable=False, name='global_step')
    is_training = tf.placeholder(tf.bool, name='is_training')
    pts_fts = tf.placeholder(tf.float32,
                             shape=(None, point_num, setting.data_dim),
                             name='pts_fts')
    labels_seg = tf.placeholder(tf.int64,
                                shape=(None, point_num),
                                name='labels_seg')
    labels_weights = tf.placeholder(tf.float32,
                                    shape=(None, point_num),
                                    name='labels_weights')

    ######################################################################
    pts_fts_sampled = tf.gather_nd(pts_fts,
                                   indices=indices,
                                   name='pts_fts_sampled')
    features_augmented = None
    if setting.data_dim > 3:
        # First 3 channels are the point coordinates, the rest are features.
        points_sampled, features_sampled = tf.split(
            pts_fts_sampled, [3, setting.data_dim - 3],
            axis=-1,
            name='split_points_features')
        if setting.use_extra_features:
            if setting.with_normal_feature:
                if setting.data_dim < 6:
                    print('Only 3D normals are supported!')
                    # sys.exit with a non-zero status: this is an error path,
                    # and the bare exit() helper is only provided by `site`.
                    sys.exit(1)
                elif setting.data_dim == 6:
                    features_augmented = pf.augment(features_sampled, rotations)
                else:
                    # features_sampled carries data_dim - 3 channels on its
                    # last axis, so the normals/rest split must happen there.
                    # tf.split defaults to axis=0, which would split the wrong
                    # dimension.
                    normals, rest = tf.split(features_sampled,
                                             [3, setting.data_dim - 6],
                                             axis=-1)
                    normals_augmented = pf.augment(normals, rotations)
                    features_augmented = tf.concat([normals_augmented, rest],
                                                   axis=-1)
            else:
                features_augmented = features_sampled
    else:
        points_sampled = pts_fts_sampled
    points_augmented = pf.augment(points_sampled, xforms, jitter_range)

    labels_sampled = tf.gather_nd(labels_seg,
                                  indices=indices,
                                  name='labels_sampled')
    labels_weights_sampled = tf.gather_nd(labels_weights,
                                          indices=indices,
                                          name='labels_weight_sampled')

    print('{}-Initializing net...'.format(datetime.now()))
    net = model.Net(points_augmented, features_augmented, is_training, setting)
    logits = net.logits
    probs = tf.nn.softmax(logits, name='probs')
    predictions = tf.argmax(probs, axis=-1, name='predictions')

    loss_op = tf.losses.sparse_softmax_cross_entropy(
        labels=labels_sampled, logits=logits, weights=labels_weights_sampled)

    # Streaming metrics; their local variables live under the 'metrics' name
    # scope so that reset_metrics_op can re-initialize exactly those.
    with tf.name_scope('metrics'):
        loss_mean_op, loss_mean_update_op = tf.metrics.mean(loss_op)
        t_1_acc_op, t_1_acc_update_op = tf.metrics.accuracy(
            labels_sampled, predictions, weights=labels_weights_sampled)
        t_1_per_class_acc_op, t_1_per_class_acc_update_op = \
            tf.metrics.mean_per_class_accuracy(labels_sampled, predictions, setting.num_class,
                                               weights=labels_weights_sampled)
    reset_metrics_op = tf.variables_initializer([
        var for var in tf.local_variables()
        if var.name.split('/')[0] == 'metrics'
    ])

    # The same metric tensors are exported twice, once per summary collection,
    # so train and validation values get separate tags.
    _ = tf.summary.scalar('loss/train',
                          tensor=loss_mean_op,
                          collections=['train'])
    _ = tf.summary.scalar('t_1_acc/train',
                          tensor=t_1_acc_op,
                          collections=['train'])
    _ = tf.summary.scalar('t_1_per_class_acc/train',
                          tensor=t_1_per_class_acc_op,
                          collections=['train'])

    _ = tf.summary.scalar('loss/val', tensor=loss_mean_op, collections=['val'])
    _ = tf.summary.scalar('t_1_acc/val',
                          tensor=t_1_acc_op,
                          collections=['val'])
    _ = tf.summary.scalar('t_1_per_class_acc/val',
                          tensor=t_1_per_class_acc_op,
                          collections=['val'])

    # Staircase exponential decay, floored at learning_rate_min.
    lr_exp_op = tf.train.exponential_decay(setting.learning_rate_base,
                                           global_step,
                                           setting.decay_steps,
                                           setting.decay_rate,
                                           staircase=True)
    lr_clip_op = tf.maximum(lr_exp_op, setting.learning_rate_min)
    _ = tf.summary.scalar('learning_rate',
                          tensor=lr_clip_op,
                          collections=['train'])
    reg_loss = setting.weight_decay * tf.losses.get_regularization_loss()

    print('{}-Setting up optimizer...'.format(datetime.now()))
    if setting.optimizer == 'adam':
        optimizer = tf.train.AdamOptimizer(learning_rate=lr_clip_op,
                                           epsilon=setting.epsilon)
    elif setting.optimizer == 'momentum':
        optimizer = tf.train.MomentumOptimizer(learning_rate=lr_clip_op,
                                               momentum=setting.momentum,
                                               use_nesterov=True)
    else:
        # Fail here with a clear message instead of a NameError on `optimizer`
        # further down.
        raise ValueError(
            "Unsupported optimizer {!r}; expected 'adam' or 'momentum'.".format(
                setting.optimizer))

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        # Only variables collected under the 'logits' scope are optimized.
        last_layer_train_vars = tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES, 'logits')
        last_layer_train_op = optimizer.minimize(
            loss_op + reg_loss,
            global_step=global_step,
            var_list=last_layer_train_vars)

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())

    # `saver` writes full checkpoints; `restorer` loads only the variables
    # returned by get_variables_to_restore(exclude=['logits']).
    saver = tf.train.Saver(max_to_keep=None)
    variables_to_restore = get_variables_to_restore(exclude=['logits'])
    restorer = tf.train.Saver(var_list=variables_to_restore, max_to_keep=None)

    # backup all code
    # code_folder = os.path.abspath(os.path.dirname(__file__))
    # shutil.copytree(code_folder, os.path.join(model_save_folder, os.path.basename(code_folder)), symlinks=True)

    folder_ckpt = os.path.join(model_save_folder, 'ckpts')
    if not os.path.exists(folder_ckpt):
        os.makedirs(folder_ckpt)

    folder_summary = os.path.join(model_save_folder, 'summary')
    if not os.path.exists(folder_summary):
        os.makedirs(folder_summary)

    parameter_num = np.sum(
        [np.prod(v.shape.as_list()) for v in tf.trainable_variables()])
    print('{}-Number of model parameters: {:d}.'.format(
        datetime.now(), parameter_num))

    with tf.Session() as sess:
        summaries_op = tf.summary.merge_all('train')
        summaries_val_op = tf.summary.merge_all('val')
        summary_writer = tf.summary.FileWriter(folder_summary, sess.graph)

        print('{}-Initializing variables...'.format(datetime.now()))
        sess.run(init_op)

        # Load the model
        if args.load_ckpt is not None:
            print('{}-Loading checkpoint from {}...'.format(
                datetime.now(), args.load_ckpt))
            restorer.restore(sess, args.load_ckpt)
            print('{}-Checkpoint loaded.'.format(datetime.now()))

        for batch_idx_train in tqdm(range(batch_num), ncols=60):
            # Validate every step_val iterations (at iteration 0 only when a
            # checkpoint was loaded) and always on the final iteration.
            if (batch_idx_train % step_val == 0 and (batch_idx_train != 0 or args.load_ckpt is not None)) \
                    or batch_idx_train == batch_num - 1:
                ######################################################################
                # Validation
                filename_ckpt = os.path.join(folder_ckpt, 'iter')
                saver.save(sess, filename_ckpt, global_step=global_step)
                tqdm.write('{}-Checkpoint saved to {}!'.format(
                    datetime.now(), filename_ckpt))

                sess.run(reset_metrics_op)
                for batch_val_idx in range(batch_num_val):
                    start_idx = batch_size * batch_val_idx
                    end_idx = min(start_idx + batch_size, num_val)
                    batch_size_val = end_idx - start_idx
                    points_batch = data_val[start_idx:end_idx, ...]
                    points_num_batch = data_num_val[start_idx:end_idx, ...]
                    labels_batch = label_val[start_idx:end_idx, ...]
                    weights_batch = label_weights_np[labels_batch]

                    xforms_np, rotations_np = pf.get_xforms(
                        batch_size_val,
                        rotation_range=rotation_range_val,
                        scaling_range=scaling_range_val,
                        order=setting.rotation_order)
                    # Only the update ops run here; the accumulated values are
                    # read once after the whole validation set is processed.
                    sess.run(
                        [
                            loss_mean_update_op, t_1_acc_update_op,
                            t_1_per_class_acc_update_op
                        ],
                        feed_dict={
                            pts_fts: points_batch,
                            indices: pf.get_indices(batch_size_val, sample_num,
                                                    points_num_batch),
                            xforms: xforms_np,
                            rotations: rotations_np,
                            jitter_range: np.array([jitter_val]),
                            labels_seg: labels_batch,
                            labels_weights: weights_batch,
                            is_training: False,
                        })

                loss_val, t_1_acc_val, t_1_per_class_acc_val, summaries_val = sess.run(
                    [
                        loss_mean_op, t_1_acc_op, t_1_per_class_acc_op,
                        summaries_val_op
                    ])
                summary_writer.add_summary(summaries_val, batch_idx_train)
                tqdm.write(
                    '{}-[Val  ]-Average:      Loss: {:.4f}  T-1 Acc: {:.4f}  T-1 mAcc: {:.4f}'
                    .format(datetime.now(), loss_val, t_1_acc_val,
                            t_1_per_class_acc_val))
                sys.stdout.flush()
                ######################################################################

            ######################################################################
            # Training
            start_idx = (batch_size * batch_idx_train) % num_train
            end_idx = min(start_idx + batch_size, num_train)
            batch_size_train = end_idx - start_idx
            points_batch = data_train[start_idx:end_idx, ...]
            points_num_batch = data_num_train[start_idx:end_idx, ...]
            labels_batch = label_train[start_idx:end_idx, ...]
            weights_batch = label_weights_np[labels_batch]

            # This batch reaches the end of the current training data: move on
            # to the next file list (if any) and reshuffle for the next pass.
            if start_idx + batch_size_train == num_train:
                if is_list_of_h5_list:
                    filelist_train_prev = seg_list[(seg_list_idx - 1) %
                                                   len(seg_list)]
                    filelist_train = seg_list[seg_list_idx % len(seg_list)]
                    # Skip the reload when the list cycles back to the same file.
                    if filelist_train != filelist_train_prev:
                        data_train, _, data_num_train, label_train, _ = data_utils.load_seg(
                            filelist_train)
                        num_train = data_train.shape[0]
                    seg_list_idx = seg_list_idx + 1
                data_train, data_num_train, label_train = \
                    data_utils.grouped_shuffle([data_train, data_num_train, label_train])

            # Randomly perturb the number of sampled points per batch. The clip
            # bound is truncated to int so sample_num_train stays an integer
            # even when setting.sample_num_clip is fractional.
            offset = int(
                random.gauss(0, sample_num * setting.sample_num_variance))
            offset_clip = int(sample_num * setting.sample_num_clip)
            offset = max(offset, -offset_clip)
            offset = min(offset, offset_clip)
            sample_num_train = sample_num + offset
            xforms_np, rotations_np = pf.get_xforms(
                batch_size_train,
                rotation_range=rotation_range,
                scaling_range=scaling_range,
                order=setting.rotation_order)
            # Metrics are reset before every step, so the values read below
            # reflect the current batch only.
            sess.run(reset_metrics_op)
            sess.run(
                [
                    last_layer_train_op, loss_mean_update_op,
                    t_1_acc_update_op, t_1_per_class_acc_update_op
                ],
                feed_dict={
                    pts_fts: points_batch,
                    indices: pf.get_indices(batch_size_train, sample_num_train,
                                            points_num_batch),
                    xforms: xforms_np,
                    rotations: rotations_np,
                    jitter_range: np.array([jitter]),
                    labels_seg: labels_batch,
                    labels_weights: weights_batch,
                    is_training: True,
                })
            if batch_idx_train % 10 == 0:
                loss, t_1_acc, t_1_per_class_acc, summaries = sess.run([
                    loss_mean_op, t_1_acc_op, t_1_per_class_acc_op,
                    summaries_op
                ])
                summary_writer.add_summary(summaries, batch_idx_train)
                # tqdm.write('{}-[Train]-Iter: {:06d}  Loss: {:.4f}  T-1 Acc: {:.4f}  T-1 mAcc: {:.4f}'
                #       .format(datetime.now(), batch_idx_train, loss, t_1_acc, t_1_per_class_acc))
                sys.stdout.flush()
            ######################################################################

        # Flush and close the event file so the last summaries are not lost.
        summary_writer.close()
        print('{}-Done!'.format(datetime.now()))