def post_build(self, sess, rng):
    """Restore trained weights from the module-level checkpoint ``ck_name``.

    Prints the restorable variables and the model variables for debugging,
    then restores every restorable variable except the first two.

    Args:
        self: unused.
        sess: session in which the checkpoint assignment runs.
        rng: unused.
    """
    print("get_variables_to_restore():")
    candidates = get_variables_to_restore()
    print(candidates)
    print("get_model_variables():")
    print(get_model_variables())
    # NOTE(review): the first two entries are dropped, presumably because they
    # are not present in the checkpoint — TODO confirm; this positional slice
    # is fragile if variable creation order changes.
    loader = assign_from_checkpoint_fn(ck_name, get_variables_to_restore()[2:])
    loader(sess)
def post_build(self, sess, rng):
    """Restore the network's variables from ``ck_name`` after graph build.

    This step is critical. Several restore strategies did not work:
    ``get_model_variables()`` returns an empty list, and an unfiltered
    ``get_variables_to_restore()`` yields a few spurious variables that are
    missing from the checkpoint (and are of no use). The variables are
    therefore selected by the scope list ``flat_cifar.scopes``, which the
    network definition exports.

    Args:
        self: unused.
        sess: session in which the checkpoint assignment runs.
        rng: unused.
    """
    wanted_scopes = flat_cifar.scopes
    print("get_variables_to_restore():")
    print(get_variables_to_restore())
    selected = get_variables_to_restore(include=wanted_scopes)
    print("get_variables_to_restore( include=scopes ):")
    print(selected)
    assign_from_checkpoint_fn(ck_name, selected)(sess)
def get_model_init_fn(train_logdir,
                      tf_initial_checkpoint,
                      initialize_last_layer,
                      last_layers,
                      ignore_missing_vars=True):
    """Build a callable that warm-starts model variables from a checkpoint.

    Returns ``None`` (no warm start) in three cases: no initial checkpoint is
    given, ``train_logdir`` already holds a checkpoint, or no variable
    survives the exclusion list. Otherwise the returned ``restore_fn(sess)``
    runs the checkpoint-assignment op and then evaluates the global step.

    Args:
        train_logdir: Log directory for training. An existing checkpoint here
            takes precedence over ``tf_initial_checkpoint``.
        tf_initial_checkpoint: TensorFlow checkpoint for initialization.
        initialize_last_layer: If false, the scopes in ``last_layers`` are
            excluded from the restore as well.
        last_layers: Iterable of scope names of the model's last layers.
        ignore_missing_vars: Forwarded to ``assign_from_checkpoint``. It
            controls whether variables absent from the checkpoint are ignored.

    Returns:
        ``restore_fn(sess)`` or ``None``.
    """
    if tf_initial_checkpoint is None:
        tf.logging.info('Not initializing the model from a checkpoint.')
        return None
    if tf.train.latest_checkpoint(train_logdir):
        tf.logging.info('Ignoring initialization; other checkpoint exists')
        return None

    tf.logging.info('Initializing model from path: %s', tf_initial_checkpoint)

    # Scopes kept out of the restore. Alternatives previously used here:
    #   MobilenetV3_Large: ['global_step']
    #   MobilenetV3_Small: ['global_step', 'image_pooling/', 'aspp0/',
    #       'decoder/feature_projection0',
    #       'MobilenetV3/expanded_conv/squeeze_excite/']
    # NOTE: scope filtering uses re.match (prefix match), so the trailing-slash
    # entry below is already covered by the one before it.
    excluded_scopes = [
        'global_step',
        'decoder/feature_projection0',
        'decoder/decoder_conv0_depthwise',
        'decoder/decoder_conv0_pointwise',
        'decoder/decoder_conv0_pointwise/',
    ]
    if not initialize_last_layer:
        excluded_scopes.extend(last_layers)

    restorable = contrib_framework.get_variables_to_restore(
        exclude=excluded_scopes)
    if not restorable:
        return None

    assign_op, assign_feed = contrib_framework.assign_from_checkpoint(
        tf_initial_checkpoint,
        restorable,
        ignore_missing_vars=ignore_missing_vars)
    step_tensor = tf.train.get_or_create_global_step()

    def restore_fn(sess):
        sess.run(assign_op, assign_feed)
        sess.run([step_tensor])

    return restore_fn
def LoadPart(self, loadpath, alpha, flag, graphRevealFlag=False,
             graphPath='logs/'):
    """Restore the attention layer and add a penalty for drifting away from it.

    The ``Attention_Value`` kernel and bias are restored from ``loadpath``.
    Their restored values are then copied into frozen (non-trainable)
    variables under the ``Punishment`` scope. A cost is built from
    ``self.parameters['Loss']`` plus a penalty on the absolute difference
    between the live and frozen weights. Finally the RMSProp train op and a
    CTC beam-search decoder are built, all variables are initialized, and the
    checkpoint is restored again.

    Args:
        loadpath: checkpoint path passed to ``Saver.restore``.
        alpha: weight of the penalty term.
        flag: penalty kind. Only ``'L1'`` builds ``PunishmentLoss`` here.
        graphRevealFlag: if true, write the session graph to ``graphPath``.
        graphPath: directory for the graph event file.
    """
    # Restore only the attention kernel/bias so their checkpoint values can be
    # snapshotted below.
    restorePart = framework.get_variables_to_restore(
        include=['Attention_Value/kernel:0', 'Attention_Value/bias:0'])
    saver = tensorflow.train.Saver(restorePart)
    saver.restore(self.session, loadpath)

    with tensorflow.variable_scope('Attention_Value',
                                   reuse=tensorflow.AUTO_REUSE):
        self.parameters['Extract_W'] = tensorflow.get_variable('kernel')
        self.parameters['Extract_b'] = tensorflow.get_variable('bias')

    with tensorflow.variable_scope('Punishment'):
        # Frozen copies initialized from the values just restored.
        self.parameters['Attention_Origin_W'] = tensorflow.Variable(
            initial_value=self.session.run('Attention_Value/kernel:0'),
            trainable=False,
            name='Attention_Origin_W')
        self.parameters['Attention_Origin_b'] = tensorflow.Variable(
            initial_value=self.session.run('Attention_Value/bias:0'),
            trainable=False,
            name='Attention_Origin_b')
        self.parameters['Distance_W'] = tensorflow.abs(
            self.parameters['Extract_W'] - self.parameters['Attention_Origin_W'])
        self.parameters['Distance_b'] = tensorflow.abs(
            self.parameters['Extract_b'] - self.parameters['Attention_Origin_b'])
        if flag == 'L1':
            # Reduce BOTH distances to scalars. Previously the bias distance
            # was added un-reduced, making the cost a vector; minimize() then
            # summed it, scaling the main loss by the bias size.
            self.parameters['PunishmentLoss'] = alpha * (
                tensorflow.reduce_mean(self.parameters['Distance_W']) +
                tensorflow.reduce_mean(self.parameters['Distance_b']))
        # NOTE(review): for any other flag, 'PunishmentLoss' must already exist
        # in self.parameters, otherwise this lookup raises KeyError.
        self.parameters['Cost'] = (self.parameters['Loss'] +
                                   self.parameters['PunishmentLoss'])

    self.train = tensorflow.train.RMSPropOptimizer(
        learning_rate=self.learningRate).minimize(self.parameters['Cost'])
    self.decode, self.logProbability = tensorflow.nn.ctc_beam_search_decoder(
        inputs=self.parameters['Logits_TimeMajor'],
        sequence_length=self.seqLenInput,
        merge_repeated=False)
    self.decodeDense = tensorflow.sparse_tensor_to_dense(
        sp_input=self.decode[0])

    # The global initializer also resets the restored attention weights, so
    # restore them again afterwards.
    self.session.run(tensorflow.global_variables_initializer())
    saver.restore(self.session, loadpath)

    if graphRevealFlag:
        # Close the writer so the graph event is flushed to disk.
        graph_writer = tensorflow.summary.FileWriter(graphPath,
                                                     self.session.graph)
        graph_writer.close()
def get_model_init_fn(train_logdir,
                      tf_initial_checkpoint,
                      initialize_last_layer,
                      last_layers,
                      ignore_missing_vars=False):
    """Build a callable that warm-starts model variables from a checkpoint.

    ``global_step`` and the ``logits`` scope are never restored. The callable
    is skipped (``None`` is returned) when no initial checkpoint is given,
    when ``train_logdir`` already holds a checkpoint, or when nothing is left
    to restore.

    Args:
        train_logdir: Log directory for training. An existing checkpoint here
            takes precedence over ``tf_initial_checkpoint``.
        tf_initial_checkpoint: TensorFlow checkpoint for initialization.
        initialize_last_layer: If false, the scopes in ``last_layers`` are
            excluded as well.
        last_layers: Iterable of scope names of the model's last layers.
        ignore_missing_vars: Forwarded to ``assign_from_checkpoint``.

    Returns:
        ``restore_fn(sess)``, which runs the assignment op and evaluates the
        global step, or ``None``.
    """
    if tf_initial_checkpoint is None:
        tf.compat.v1.logging.info(
            'Not initializing the model from a checkpoint.')
        return None
    if tf.train.latest_checkpoint(train_logdir):
        tf.compat.v1.logging.info(
            'Ignoring initialization; other checkpoint exists')
        return None

    tf.compat.v1.logging.info('Initializing model from path: %s',
                              tf_initial_checkpoint)

    excluded_scopes = ['global_step', 'logits']
    if not initialize_last_layer:
        excluded_scopes.extend(last_layers)

    restorable = contrib_framework.get_variables_to_restore(
        exclude=excluded_scopes)
    if not restorable:
        return None

    assign_op, assign_feed = contrib_framework.assign_from_checkpoint(
        tf_initial_checkpoint,
        restorable,
        ignore_missing_vars=ignore_missing_vars)
    step_tensor = tf.compat.v1.train.get_or_create_global_step()

    def restore_fn(sess):
        sess.run(assign_op, assign_feed)
        sess.run([step_tensor])

    return restore_fn
def main(_):
    """Extract per-frame features for every video and save one .mat per video.

    The network named by ``FLAGS.netname`` is built and pretrained weights are
    loaded. A ``.pkl`` file goes through ``networks_utils.load_pretrained``;
    anything else is treated as a checkpoint for ``tf.train.Saver``. A snippet
    window is then slid over each video in ``FLAGS.segmented_dir`` with step
    ``FLAGS.stride``. The extracted features and labels are written via
    ``make_mat_file``. Videos whose output already exists in
    ``FLAGS.outputdir`` are skipped.

    Args:
        _: unused positional argument (argv from the app runner).
    """
    # Make outputdir
    if not os.path.exists(FLAGS.outputdir):
        os.makedirs(FLAGS.outputdir)

    # Load video list in a deterministic order
    vid_lst = sorted(os.listdir(FLAGS.segmented_dir))

    # Load label dictionary; the context manager closes the file handle
    with open(FLAGS.lbl_dict_pth) as lbl_file:
        lbl_list = lbl_file.read().splitlines()
    n_classes = len(lbl_list)
    if FLAGS.has_bg_lbl:
        n_classes += 1

    # Use the load_snippet_pths_test in data writer to get frames and labels
    dataset_writer = dataset_factory.get_writer(FLAGS.datasetname)
    writer = dataset_writer()

    use_pkl_weights = '.pkl' in FLAGS.pretrained_model

    with tf.Graph().as_default():
        net = networks_factory.build_net(FLAGS.netname,
                                         n_classes,
                                         FLAGS.snippet_len,
                                         FLAGS.target_height,
                                         FLAGS.target_width,
                                         max_time_gap=FLAGS.max_time_gap,
                                         trainable=False)
        feat = net.get_output(FLAGS.featname)

        # load pretrained weights
        if use_pkl_weights:
            assign_ops = networks_utils.load_pretrained(
                FLAGS.pretrained_model,
                ignore_missing=True,
                extension='pkl',
                initoffset=FLAGS.usemotionloss)
        else:
            saver = tf.train.Saver(get_variables_to_restore())

        with tf.Session() as sess:
            sess.run([
                tf.global_variables_initializer(),
                tf.local_variables_initializer()
            ])
            if use_pkl_weights:
                sess.run(assign_ops)
            else:
                tf.logging.info('Restoring checkpoint...')
                saver.restore(sess, FLAGS.pretrained_model)

            # original snippet_len without flow
            snippet_len = FLAGS.snippet_len + 1
            # correct, because snippet_len was increased by 1
            left = snippet_len // 2
            right = snippet_len - left

            for vid in vid_lst:
                # skip existing feature files
                output_fname = '{}.avi.mat'.format(vid)
                if os.path.exists(os.path.join(FLAGS.outputdir, output_fname)):
                    print('{} already exists'.format(output_fname))
                    continue

                # load all file names and labels
                print('\nExtracting features for ' + vid)
                fname_lst, lbl_lst = writer.load_snippet_pths_test(
                    FLAGS.segmented_dir, [vid], FLAGS.lbl_dict_pth,
                    FLAGS.bg_lbl, FLAGS.ext, FLAGS.frameskip)
                fname_lst = [x[0] for x in fname_lst]

                # prefetch all frames of a video
                frames_all = read_n_compute_flow(fname_lst)

                n_frames = len(lbl_lst)
                # go through the video frames in acausal fashion
                frame_id = left
                feats_per_vid = []
                groundtruths_per_vid = []
                pbar = ProgressBar(max_value=n_frames)
                while frame_id < n_frames - right + 1:
                    snippet_batch = []
                    lbl_batch = []
                    for _ in range(FLAGS.batch_size):
                        if frame_id + right > n_frames:
                            break
                        # ignore 1 last frame because this is flow
                        snippet_batch.append(
                            frames_all[frame_id - left:frame_id + right - 1])
                        # this is in sync with label from appearance stream
                        lbl_batch.append(lbl_lst[frame_id])
                        frame_id += FLAGS.stride

                    feat_ = sess.run(feat,
                                     feed_dict={
                                         net.data_raw: snippet_batch,
                                         net.labels_raw: lbl_batch
                                     })
                    for i in range(feat_.shape[0]):
                        feats_per_vid.append(feat_[i])
                        groundtruths_per_vid.append(lbl_batch[i])
                    # frame_id can overshoot n_frames when stride > right;
                    # clamp because ProgressBar raises on values > max_value.
                    pbar.update(min(frame_id, n_frames))

                # produce mat file for a video
                feats_per_vid = np.array(feats_per_vid, dtype=np.float32)
                groundtruths_per_vid = np.array(groundtruths_per_vid)
                # NOTE(review): the skip check looks under FLAGS.outputdir but
                # only the bare name is passed here; confirm make_mat_file
                # prefixes FLAGS.outputdir.
                make_mat_file(output_fname,
                              feats_per_vid,
                              groundtruths_per_vid,
                              expected_length=n_frames // FLAGS.stride)
def fit(dataset_constants, lmbda, pretrained_model, checkpoint_dir,
        msssim_loss=False, preprocess_threads=8, patchsize=256, batchsize=8,
        last_step=200000, validate=False):
    """Train the model, optionally warm-starting from a pretrained checkpoint.

    Args:
        dataset_constants: dict with ``"train_images_glob"``. When
            ``validate`` is true it also needs ``"val_images_glob"`` and
            ``"image_shape"``. ``"image_shape"`` holds the leading dims of
            each validation image, to which a channel dim of 3 is appended.
        lmbda: trade-off parameter forwarded to ``build_model``.
        pretrained_model: checkpoint whose variables (except ``global_step``)
            initialize the model through the scaffold ``init_fn``. The
            ``init_fn`` only runs when no checkpoint is restored from
            ``checkpoint_dir``. If falsy, no warm start is performed.
        checkpoint_dir: directory for checkpoints and summaries.
        msssim_loss: forwarded to ``build_model``.
        preprocess_threads: parallel calls used for PNG decoding.
        patchsize: side length of the random training crops.
        batchsize: training batch size. Validation uses a quarter of it,
            but never less than 1.
        last_step: global step at which training stops.
        validate: also build a validation pipeline and write its summaries.

    Raises:
        ValueError: if a file glob matches no images.
    """
    train_glob = dataset_constants["train_images_glob"]
    x_train_files = sorted(glob.glob(train_glob))
    if not x_train_files:
        # shuffle(buffer_size=0) would otherwise fail with an opaque TF error.
        raise ValueError(
            'No training images match {!r}'.format(train_glob))
    train_dataset = tf.data.Dataset.from_tensor_slices(x_train_files)
    train_dataset = train_dataset.shuffle(
        buffer_size=len(x_train_files)).repeat()
    train_dataset = train_dataset.map(read_png,
                                      num_parallel_calls=preprocess_threads)
    train_dataset = train_dataset.map(
        lambda x: tf.random_crop(x, (patchsize, patchsize, 3)))
    train_dataset = train_dataset.batch(batchsize)
    train_dataset = train_dataset.prefetch(batchsize)
    x_train_unscaled = train_dataset.make_one_shot_iterator().get_next()
    x_train = x_train_unscaled / 255.

    (train_loss, train_bpp, train_mse, train_msssim, x_tilde,
     _, _, _, _, _, _, layers) = build_model(x_train, lmbda, mode='training',
                                             msssim_loss=msssim_loss)
    train_summary = log_all_summaries(x_train_unscaled, x_tilde, train_loss,
                                      train_bpp, train_mse, train_msssim,
                                      "train")

    if validate:
        # list() so a tuple image_shape also works.
        val_image_shape = list(dataset_constants["image_shape"]) + [3]

        def set_shape(x):
            x.set_shape(val_image_shape)
            return x

        # batch(0) and num_parallel_calls=0 are invalid, so keep at least 1.
        val_batchsize = max(1, batchsize // 4)
        val_preprocess_threads = min(preprocess_threads, val_batchsize)
        val_glob = dataset_constants["val_images_glob"]
        x_val_files = sorted(glob.glob(val_glob))
        if not x_val_files:
            raise ValueError(
                'No validation images match {!r}'.format(val_glob))
        val_dataset = tf.data.Dataset.from_tensor_slices(x_val_files)
        val_dataset = val_dataset.shuffle(
            buffer_size=len(x_val_files)).repeat()
        val_dataset = val_dataset.map(read_png,
                                      num_parallel_calls=val_preprocess_threads)
        val_dataset = val_dataset.map(set_shape,
                                      num_parallel_calls=val_preprocess_threads)
        val_dataset = val_dataset.batch(val_batchsize)
        val_dataset = val_dataset.prefetch(val_batchsize)
        x_val_unscaled = val_dataset.make_one_shot_iterator().get_next()
        x_val = x_val_unscaled / 255.

        (val_loss, val_bpp, val_mse, val_msssim, x_hat,
         _, _, _, _, _, _, _) = build_model(x_val, lmbda, 'testing', layers,
                                            msssim_loss)
        val_summary = log_all_summaries(x_val_unscaled, x_hat, val_loss,
                                        val_bpp, val_mse, val_msssim, "val")

    step = tf.train.get_or_create_global_step()
    main_optimizer = tf.train.AdamOptimizer(learning_rate=1e-4)
    main_step = main_optimizer.minimize(train_loss, global_step=step)

    _, _, entropy_bottleneck, _, _ = layers
    aux_optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)
    aux_step = aux_optimizer.minimize(sum(entropy_bottleneck.losses))

    train_op = tf.group(main_step, aux_step, entropy_bottleneck.updates)

    hooks = [
        tf.train.StopAtStepHook(last_step=last_step),
        tf.train.NanTensorHook(train_loss),
        tf.train.SummarySaverHook(save_secs=120, output_dir=checkpoint_dir,
                                  summary_op=train_summary),
    ]
    if validate:
        hooks.append(
            tf.train.SummarySaverHook(save_secs=120,
                                      output_dir=checkpoint_dir,
                                      summary_op=val_summary))

    load_pretrain = None
    if pretrained_model:
        variables_to_restore = get_variables_to_restore(
            exclude=['global_step'])
        pre_train_saver = tf.train.Saver(variables_to_restore)

        def load_pretrain(scaffold, sess):
            pre_train_saver.restore(sess, save_path=pretrained_model)

    with tf.train.MonitoredTrainingSession(
            hooks=hooks,
            checkpoint_dir=checkpoint_dir,
            save_checkpoint_steps=25000,
            save_summaries_secs=None,
            save_summaries_steps=None,
            scaffold=tf.train.Scaffold(
                init_fn=load_pretrain,
                saver=tf.train.Saver(max_to_keep=11))) as sess:
        while not sess.should_stop():
            sess.run(train_op)
def eval_one(ckpt_fname, frames, labels, eval_fn):
    """Run evaluation on a specific checkpoint.

    A fresh graph is built with a non-trainable network. The variable restore
    from ``ckpt_fname`` is wrapped in ``init_fn``. Video-level accuracy and
    confusion summaries are added, and the actual loop is delegated to
    ``eval_fn``.

    Args:
        ckpt_fname: file name of the checkpoint to restore from.
        frames: a list, each item is a sub-list of all frames per video.
        labels: a list, each item is a sub-list of corresponding labels.
        eval_fn: callable running the evaluation. It is called as
            ``eval_fn(net, init_fn, frames, labels, metrics_op, accuracy_vid,
            summary_op, global_step, confusion_vid, basename(ckpt_fname))``.

    Returns:
        Whatever ``eval_fn`` returns (presumably a dict of results; verify
        against ``eval_fn``).
    """
    # retrieve label dictionary; the context manager closes the file handle
    with open(FLAGS.labels_fname) as lbl_file:
        lbl_list = lbl_file.read().splitlines()
    n_classes = len(lbl_list)
    if FLAGS.has_bg_lbl:
        n_classes += 1

    # set verbosity for info level only
    tf.logging.set_verbosity(tf.logging.INFO)

    # construct graph and build models
    with tf.Graph().as_default():
        tf.logging.info('Evaluating checkpoint %s', ckpt_fname)

        net = networks_factory.build_net(FLAGS.netname,
                                         n_classes,
                                         FLAGS.snippet_len,
                                         FLAGS.target_height,
                                         FLAGS.target_width,
                                         max_time_gap=FLAGS.max_time_gap,
                                         trainable=False)

        # restore checkpoint function
        global_step = get_or_create_global_step()
        variables_to_restore = get_variables_to_restore()
        saver = tf.train.Saver(variables_to_restore)

        def init_fn(sess):
            tf.logging.info('Restoring checkpoint...')
            return saver.restore(sess, ckpt_fname)

        # Metrics must be created after building the restore saver (unlike in
        # training); otherwise the old metric values would be reloaded.
        (accuracy_vid, confusion_vid, metrics_op) = net.create_metrics()
        if isinstance(confusion_vid, list):
            confusion_vid_img = [
                colorize_tensor(x, extend=True) for x in confusion_vid
            ]
        else:
            confusion_vid_img = colorize_tensor(confusion_vid, extend=True)

        # summaries
        if isinstance(accuracy_vid, list):
            for item in accuracy_vid:
                tf.summary.scalar('accuracy/' + item.op.name, item)
        else:
            tf.summary.scalar('accuracy', accuracy_vid)
        if isinstance(confusion_vid_img, list):
            for item in confusion_vid_img:
                tf.summary.image('confusion/' + item.op.name, item)
        else:
            tf.summary.image('confusion', confusion_vid_img)
        summary_op = tf.summary.merge_all()

        # evaluation phase
        results_dict = eval_fn(net, init_fn, frames, labels, metrics_op,
                               accuracy_vid, summary_op, global_step,
                               confusion_vid, os.path.basename(ckpt_fname))
    return results_dict
num_noisy_tags]) # noisy tags ex.) [1 0 0 1 0] y = tf.placeholder(tf.float32, shape=[batch_size, num_classes]) q = tf.reduce_sum(y, 1) # quantity keep_prob = tf.placeholder(tf.float32) is_training = tf.placeholder(tf.bool) # model # resnet_v1 101 with slim.arg_scope(resnet_v1.resnet_arg_scope()): net, end_points = resnet_v1.resnet_v1_101(img, num_classes, is_training=False) net_logit = tf.squeeze(net) # tensorflow operation for load pretrained weights variables_to_restore = get_variables_to_restore( exclude=['resnet_v1_101/logits', 'resnet_v1_101/AuxLogits']) init_fn = assign_from_checkpoint_fn('resnet_v1_101.ckpt', variables_to_restore) # multiscale resnet_v1 101 visual_features, fusion_logit = multiscale_resnet101(end_points, num_classes, is_training) textual_features, textual_logit = mlp(tag, num_classes, is_training) refined_features = tf.concat([visual_features, textual_features], 1) # score is prediction score, and k is label quantity score = multi_class_classification_model(refined_features, num_classes) k = label_quantity_prediction_model(refined_features, keep_prob) k = tf.reshape(k, shape=[batch_size]) # make trainable variable list var_list0 = [
def train(train_dir,
          config,
          dataset_fn,
          checkpoints_to_keep=5,
          keep_checkpoint_every_n_hours=1,
          num_steps=None,
          master='',
          num_sync_workers=0,
          num_ps_tasks=0,
          task=0):
    """Train loop.

    Builds ``config.model`` and clips gradients according to
    ``config.hparams.clip_mode``, then runs ``contrib_training.train``. When
    ``config.pretrained_path`` is set, the variables returned by
    ``_get_restore_vars(config.var_train_pattern)`` are initialized from that
    checkpoint via the scaffold ``init_fn``. When it is unset or missing,
    training starts from scratch.

    Args:
        train_dir: directory for checkpoints and summaries.
        config: training configuration object.
        dataset_fn: callable returning the input dataset.
        checkpoints_to_keep: ``max_to_keep`` of the checkpoint saver.
        keep_checkpoint_every_n_hours: forwarded to the checkpoint saver.
        num_steps: if set, stop at this global step.
        master: TensorFlow master address.
        num_sync_workers: if non-zero, wrap the optimizer in
            ``SyncReplicasOptimizer`` with this many replicas.
        num_ps_tasks: number of parameter-server tasks for device placement.
        task: task index. Task 0 is the chief.

    Raises:
        ValueError: if ``config.hparams.clip_mode`` is unknown.
    """
    tf.gfile.MakeDirs(train_dir)
    is_chief = (task == 0)
    if is_chief:
        _trial_summary(config.hparams,
                       config.train_examples_path or config.tfds_name,
                       train_dir)
    with tf.Graph().as_default():
        with tf.device(
                tf.train.replica_device_setter(num_ps_tasks,
                                               merge_devices=True)):

            model = config.model
            model.build(config.hparams,
                        config.data_converter.output_depth,
                        encoder_train=config.encoder_train,
                        decoder_train=config.decoder_train)
            optimizer = model.train(**_get_input_tensors(dataset_fn(), config))
            restored_vars = _get_restore_vars(config.var_train_pattern)
            _set_trainable_vars(config.var_train_pattern)

            hooks = []
            if num_sync_workers:
                optimizer = tf.train.SyncReplicasOptimizer(
                    optimizer, num_sync_workers)
                hooks.append(optimizer.make_session_run_hook(is_chief))

            grads, var_list = zip(*optimizer.compute_gradients(model.loss))
            global_norm = tf.global_norm(grads)
            tf.summary.scalar('global_norm', global_norm)

            if config.hparams.clip_mode == 'value':
                g = config.hparams.grad_clip
                clipped_grads = [
                    tf.clip_by_value(grad, -g, g) for grad in grads
                ]
            elif config.hparams.clip_mode == 'global_norm':
                # Zero all gradients when the norm exceeds the hard limit.
                clipped_grads = tf.cond(
                    global_norm < config.hparams.grad_norm_clip_to_zero,
                    lambda: tf.clip_by_global_norm(grads,
                                                   config.hparams.grad_clip,
                                                   use_norm=global_norm)[0],
                    lambda: [tf.zeros(tf.shape(g)) for g in grads])
            else:
                raise ValueError('Unknown clip_mode: {}'.format(
                    config.hparams.clip_mode))
            train_op = optimizer.apply_gradients(zip(clipped_grads, var_list),
                                                 global_step=model.global_step,
                                                 name='train_step')

            logging_dict = {
                'global_step': model.global_step,
                'loss': model.loss
            }

            hooks.append(
                tf.train.LoggingTensorHook(logging_dict, every_n_iter=5))
            if num_steps:
                hooks.append(tf.train.StopAtStepHook(last_step=num_steps))

            # Build the warm-start assignment only when a pretrained
            # checkpoint is configured; assign_from_checkpoint fails on None.
            init_fn = None
            pretrained_path = getattr(config, 'pretrained_path', None)
            if pretrained_path:
                variables_to_restore = contrib_framework.get_variables_to_restore(
                    include=[v.name for v in restored_vars])
                init_assign_op, init_feed_dict = (
                    contrib_framework.assign_from_checkpoint(
                        pretrained_path, variables_to_restore))

                def InitAssignFn(scaffold, sess):
                    sess.run(init_assign_op, init_feed_dict)

                init_fn = InitAssignFn

            scaffold = tf.train.Scaffold(
                init_fn=init_fn,
                saver=tf.train.Saver(
                    max_to_keep=checkpoints_to_keep,
                    keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
                ))
            contrib_training.train(train_op=train_op,
                                   logdir=train_dir,
                                   scaffold=scaffold,
                                   hooks=hooks,
                                   save_checkpoint_secs=60,
                                   master=master,
                                   is_chief=is_chief)
def get_model_init_fn(train_logdir,
                      tf_initial_checkpoint,
                      initialize_first_layer,
                      initialize_last_layer,
                      last_layers=None,
                      restore_adam=False,
                      ignore_missing_vars=False):
    """Gets the function initializing model variables from a checkpoint.

    Args:
        train_logdir: Log directory for training. An existing checkpoint here
            disables the initialization.
        tf_initial_checkpoint: TensorFlow checkpoint for initialization.
        initialize_first_layer: Initialize first layer or not. If false,
            ``resnet_v1_50/conv1_1/weights:0`` is not restored.
        initialize_last_layer: Initialize last layer or not. If false, the
            scopes in ``last_layers`` are not restored.
        last_layers: Last layers of the model (iterable of scope names).
            ``None`` is treated as empty.
        restore_adam: Restore Adam optimization parameters or not. If false,
            every variable whose name contains "Adam" is skipped.
        ignore_missing_vars: Ignore missing variables in the checkpoint.

    Returns:
        ``restore_fn(scaffold, sess)`` suitable for ``tf.train.Scaffold``'s
        ``init_fn``, or ``None`` when there is nothing to initialize.
    """
    if tf_initial_checkpoint is None:
        tf.logging.info('Not initializing the model from a checkpoint.')
        return None

    if tf.train.latest_checkpoint(train_logdir):
        tf.logging.info('Ignoring initialization; other checkpoint exists')
        return None

    tf.logging.info('Initializing model from path: %s', tf_initial_checkpoint)

    # Variables that will not be restored.
    exclude_list = ['global_step']
    if not initialize_last_layer:
        # last_layers defaults to None; extend(None) would raise TypeError.
        exclude_list.extend(last_layers or [])
    if not initialize_first_layer:
        exclude_list.append('resnet_v1_50/conv1_1/weights:0')

    variables_to_restore = contrib_framework.get_variables_to_restore(
        exclude=exclude_list)

    # Restore without Adam parameters
    if not restore_adam:
        variables_to_restore = [
            v for v in variables_to_restore if "Adam" not in v.name
        ]

    if not variables_to_restore:
        return None

    init_op, init_feed_dict = contrib_framework.assign_from_checkpoint(
        tf_initial_checkpoint,
        variables_to_restore,
        ignore_missing_vars=ignore_missing_vars)
    global_step = tf.train.get_or_create_global_step()

    def restore_fn(scaffold, sess):
        sess.run(init_op, init_feed_dict)
        sess.run([global_step])

    return restore_fn
def main():
    """Fine-tune the 'logits' layer of a PointCNN segmentation model on SceneNN.

    Loads the train/validation file lists under ``SCENENN_DIR`` and builds the
    ``pointcnn_seg`` network with the ``scenenn_x8_2048_fps`` setting. Every
    variable except those matching the ``logits`` scope is restored from
    ``args.load_ckpt``. Only the ``logits`` trainable variables are then
    optimized. Checkpoints and summaries go to a fresh time-stamped folder
    under ``ROOT_DIR/../models``. Validation runs every ``step_val`` batches
    (including batch 0 when a checkpoint is loaded) and on the final batch.
    """
    args = AttrDict()
    # Path of training set ground truth file list (.txt)
    args.filelist = os.path.join(SCENENN_DIR, 'train_files.txt')
    # Path of validation set ground truth file list (.txt)
    args.filelist_val = os.path.join(SCENENN_DIR, 'test_files.txt')
    # Path of a checkpoint file to load
    args.load_ckpt = os.path.join(ROOT_DIR, '..', 'models',
                                  'pretrained_scannet', 'ckpts', 'iter-354000')
    # Base directory where model checkpoint and summary files get saved in
    # separate subdirectories
    args.save_folder = os.path.join(ROOT_DIR, '..', 'models')
    # PointCNN model to use
    args.model = 'pointcnn_seg'
    # Model setting to use
    args.setting = 'scenenn_x8_2048_fps'

    time_string = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    model_save_folder = os.path.join(
        args.save_folder,
        '%s_%s_%s_%d' % (args.model, args.setting, time_string, os.getpid()))
    if not os.path.exists(model_save_folder):
        os.makedirs(model_save_folder)

    print('PID:', os.getpid())
    print(args)

    model = importlib.import_module(args.model)
    setting_path = os.path.join(ROOT_DIR, args.model)
    sys.path.append(setting_path)
    setting = importlib.import_module(args.setting)

    num_epochs = setting.num_epochs
    batch_size = setting.batch_size
    sample_num = setting.sample_num
    step_val = setting.step_val
    label_weights_list = setting.label_weights
    rotation_range = setting.rotation_range
    rotation_range_val = setting.rotation_range_val
    scaling_range = setting.scaling_range
    scaling_range_val = setting.scaling_range_val
    jitter = setting.jitter
    jitter_val = setting.jitter_val

    # Prepare inputs
    print('{}-Preparing datasets...'.format(datetime.now()))
    # When args.filelist lists several h5-lists, train on them one at a time.
    is_list_of_h5_list = not data_utils.is_h5_list(args.filelist)
    if is_list_of_h5_list:
        seg_list = data_utils.load_seg_list(args.filelist)
        seg_list_idx = 0
        filelist_train = seg_list[seg_list_idx]
        seg_list_idx = seg_list_idx + 1
    else:
        filelist_train = args.filelist
    data_train, _, data_num_train, label_train, _ = data_utils.load_seg(
        filelist_train)
    data_val, _, data_num_val, label_val, _ = data_utils.load_seg(
        args.filelist_val)

    # shuffle
    data_train, data_num_train, label_train = data_utils.grouped_shuffle(
        [data_train, data_num_train, label_train])

    num_train = data_train.shape[0]
    point_num = data_train.shape[1]
    num_val = data_val.shape[0]
    print('{}-{:d}/{:d} training/validation samples.'.format(
        datetime.now(), num_train, num_val))
    batch_num = (num_train * num_epochs + batch_size - 1) // batch_size
    print('{}-{:d} training batches.'.format(datetime.now(), batch_num))
    # Integer ceil-division. math.ceil(num_val / batch_size) floors under
    # Python 2 integer division and returns a float there, breaking {:d}.
    batch_num_val = (num_val + batch_size - 1) // batch_size
    print('{}-{:d} testing batches per test.'.format(datetime.now(),
                                                     batch_num_val))

    ######################################################################
    # Placeholders
    print('{}-Initializing TF-placeholders...'.format(datetime.now()))
    indices = tf.placeholder(tf.int32, shape=(None, None, 2), name="indices")
    xforms = tf.placeholder(tf.float32, shape=(None, 3, 3), name="xforms")
    rotations = tf.placeholder(tf.float32, shape=(None, 3, 3),
                               name="rotations")
    jitter_range = tf.placeholder(tf.float32, shape=(1), name="jitter_range")
    global_step = tf.Variable(0, trainable=False, name='global_step')
    is_training = tf.placeholder(tf.bool, name='is_training')

    pts_fts = tf.placeholder(tf.float32,
                             shape=(None, point_num, setting.data_dim),
                             name='pts_fts')
    labels_seg = tf.placeholder(tf.int64, shape=(None, point_num),
                                name='labels_seg')
    labels_weights = tf.placeholder(tf.float32, shape=(None, point_num),
                                    name='labels_weights')

    ######################################################################
    pts_fts_sampled = tf.gather_nd(pts_fts, indices=indices,
                                   name='pts_fts_sampled')
    features_augmented = None
    if setting.data_dim > 3:
        points_sampled, features_sampled = tf.split(
            pts_fts_sampled, [3, setting.data_dim - 3],
            axis=-1,
            name='split_points_features')
        if setting.use_extra_features:
            if setting.with_normal_feature:
                if setting.data_dim < 6:
                    print('Only 3D normals are supported!')
                    # Non-zero status: this is a configuration error.
                    sys.exit(1)
                elif setting.data_dim == 6:
                    features_augmented = pf.augment(features_sampled,
                                                    rotations)
                else:
                    # Split along the feature axis; tf.split defaults to
                    # axis=0, which would split the batch instead.
                    normals, rest = tf.split(features_sampled,
                                             [3, setting.data_dim - 6],
                                             axis=-1)
                    normals_augmented = pf.augment(normals, rotations)
                    features_augmented = tf.concat([normals_augmented, rest],
                                                   axis=-1)
            else:
                features_augmented = features_sampled
    else:
        points_sampled = pts_fts_sampled
    points_augmented = pf.augment(points_sampled, xforms, jitter_range)

    labels_sampled = tf.gather_nd(labels_seg, indices=indices,
                                  name='labels_sampled')
    labels_weights_sampled = tf.gather_nd(labels_weights, indices=indices,
                                          name='labels_weight_sampled')

    print('{}-Initializing net...'.format(datetime.now()))
    net = model.Net(points_augmented, features_augmented, is_training, setting)
    logits = net.logits
    probs = tf.nn.softmax(logits, name='probs')
    predictions = tf.argmax(probs, axis=-1, name='predictions')

    loss_op = tf.losses.sparse_softmax_cross_entropy(
        labels=labels_sampled, logits=logits, weights=labels_weights_sampled)

    with tf.name_scope('metrics'):
        loss_mean_op, loss_mean_update_op = tf.metrics.mean(loss_op)
        t_1_acc_op, t_1_acc_update_op = tf.metrics.accuracy(
            labels_sampled, predictions, weights=labels_weights_sampled)
        t_1_per_class_acc_op, t_1_per_class_acc_update_op = \
            tf.metrics.mean_per_class_accuracy(labels_sampled, predictions,
                                               setting.num_class,
                                               weights=labels_weights_sampled)
    reset_metrics_op = tf.variables_initializer([
        var for var in tf.local_variables()
        if var.name.split('/')[0] == 'metrics'
    ])

    _ = tf.summary.scalar('loss/train', tensor=loss_mean_op,
                          collections=['train'])
    _ = tf.summary.scalar('t_1_acc/train', tensor=t_1_acc_op,
                          collections=['train'])
    _ = tf.summary.scalar('t_1_per_class_acc/train',
                          tensor=t_1_per_class_acc_op,
                          collections=['train'])

    _ = tf.summary.scalar('loss/val', tensor=loss_mean_op,
                          collections=['val'])
    _ = tf.summary.scalar('t_1_acc/val', tensor=t_1_acc_op,
                          collections=['val'])
    _ = tf.summary.scalar('t_1_per_class_acc/val',
                          tensor=t_1_per_class_acc_op,
                          collections=['val'])

    lr_exp_op = tf.train.exponential_decay(setting.learning_rate_base,
                                           global_step,
                                           setting.decay_steps,
                                           setting.decay_rate,
                                           staircase=True)
    lr_clip_op = tf.maximum(lr_exp_op, setting.learning_rate_min)
    _ = tf.summary.scalar('learning_rate', tensor=lr_clip_op,
                          collections=['train'])
    reg_loss = setting.weight_decay * tf.losses.get_regularization_loss()

    print('{}-Setting up optimizer...'.format(datetime.now()))
    if setting.optimizer == 'adam':
        optimizer = tf.train.AdamOptimizer(learning_rate=lr_clip_op,
                                           epsilon=setting.epsilon)
    elif setting.optimizer == 'momentum':
        optimizer = tf.train.MomentumOptimizer(learning_rate=lr_clip_op,
                                               momentum=setting.momentum,
                                               use_nesterov=True)
    else:
        # Fail here instead of with a NameError on `optimizer` below.
        raise ValueError('Unknown optimizer: {}'.format(setting.optimizer))

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        # Fine-tune only the trainable variables of the 'logits' scope.
        last_layer_train_vars = tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES, 'logits')
        last_layer_train_op = optimizer.minimize(
            loss_op + reg_loss,
            global_step=global_step,
            var_list=last_layer_train_vars)

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())

    saver = tf.train.Saver(max_to_keep=None)
    # Restore everything except the 'logits' scope from the pretrained model.
    variables_to_restore = get_variables_to_restore(exclude=['logits'])
    restorer = tf.train.Saver(var_list=variables_to_restore, max_to_keep=None)

    folder_ckpt = os.path.join(model_save_folder, 'ckpts')
    if not os.path.exists(folder_ckpt):
        os.makedirs(folder_ckpt)

    folder_summary = os.path.join(model_save_folder, 'summary')
    if not os.path.exists(folder_summary):
        os.makedirs(folder_summary)

    parameter_num = np.sum(
        [np.prod(v.shape.as_list()) for v in tf.trainable_variables()])
    print('{}-Number of model parameters: {:d}.'.format(
        datetime.now(), parameter_num))

    with tf.Session() as sess:
        summaries_op = tf.summary.merge_all('train')
        summaries_val_op = tf.summary.merge_all('val')
        summary_writer = tf.summary.FileWriter(folder_summary, sess.graph)

        print('{}-Initializing variables...'.format(datetime.now()))
        sess.run(init_op)

        # Load the model
        if args.load_ckpt is not None:
            print('{}-Loading checkpoint from {}...'.format(
                datetime.now(), args.load_ckpt))
            restorer.restore(sess, args.load_ckpt)
            print('{}-Checkpoint loaded.'.format(datetime.now()))

        for batch_idx_train in tqdm(range(batch_num), ncols=60):
            if (batch_idx_train % step_val == 0 and
                    (batch_idx_train != 0 or args.load_ckpt is not None)) \
                    or batch_idx_train == batch_num - 1:
                ##############################################################
                # Validation
                filename_ckpt = os.path.join(folder_ckpt, 'iter')
                saver.save(sess, filename_ckpt, global_step=global_step)
                tqdm.write('{}-Checkpoint saved to {}!'.format(
                    datetime.now(), filename_ckpt))

                sess.run(reset_metrics_op)
                for batch_val_idx in range(batch_num_val):
                    start_idx = batch_size * batch_val_idx
                    end_idx = min(start_idx + batch_size, num_val)
                    batch_size_val = end_idx - start_idx
                    points_batch = data_val[start_idx:end_idx, ...]
                    points_num_batch = data_num_val[start_idx:end_idx, ...]
                    labels_batch = label_val[start_idx:end_idx, ...]
                    weights_batch = np.array(label_weights_list)[labels_batch]

                    xforms_np, rotations_np = pf.get_xforms(
                        batch_size_val,
                        rotation_range=rotation_range_val,
                        scaling_range=scaling_range_val,
                        order=setting.rotation_order)
                    sess.run(
                        [
                            loss_mean_update_op, t_1_acc_update_op,
                            t_1_per_class_acc_update_op
                        ],
                        feed_dict={
                            pts_fts: points_batch,
                            indices: pf.get_indices(batch_size_val,
                                                    sample_num,
                                                    points_num_batch),
                            xforms: xforms_np,
                            rotations: rotations_np,
                            jitter_range: np.array([jitter_val]),
                            labels_seg: labels_batch,
                            labels_weights: weights_batch,
                            is_training: False,
                        })

                loss_val, t_1_acc_val, t_1_per_class_acc_val, summaries_val = \
                    sess.run([
                        loss_mean_op, t_1_acc_op, t_1_per_class_acc_op,
                        summaries_val_op
                    ])
                summary_writer.add_summary(summaries_val, batch_idx_train)
                tqdm.write(
                    '{}-[Val  ]-Average:      Loss: {:.4f}  T-1 Acc: {:.4f}  T-1 mAcc: {:.4f}'
                    .format(datetime.now(), loss_val, t_1_acc_val,
                            t_1_per_class_acc_val))
                sys.stdout.flush()
                ##############################################################

            ##################################################################
            # Training
            start_idx = (batch_size * batch_idx_train) % num_train
            end_idx = min(start_idx + batch_size, num_train)
            batch_size_train = end_idx - start_idx
            points_batch = data_train[start_idx:end_idx, ...]
            points_num_batch = data_num_train[start_idx:end_idx, ...]
            labels_batch = label_train[start_idx:end_idx, ...]
            weights_batch = np.array(label_weights_list)[labels_batch]

            # End of the current training set: move on to the next h5-list
            # (if any) and reshuffle.
            if start_idx + batch_size_train == num_train:
                if is_list_of_h5_list:
                    filelist_train_prev = seg_list[(seg_list_idx - 1) %
                                                   len(seg_list)]
                    filelist_train = seg_list[seg_list_idx % len(seg_list)]
                    if filelist_train != filelist_train_prev:
                        data_train, _, data_num_train, label_train, _ = \
                            data_utils.load_seg(filelist_train)
                        num_train = data_train.shape[0]
                    seg_list_idx = seg_list_idx + 1
                data_train, data_num_train, label_train = \
                    data_utils.grouped_shuffle(
                        [data_train, data_num_train, label_train])

            offset = int(
                random.gauss(0, sample_num * setting.sample_num_variance))
            offset = max(offset, -sample_num * setting.sample_num_clip)
            offset = min(offset, sample_num * setting.sample_num_clip)
            # The clip bounds may be floats; the sample count must be an int.
            sample_num_train = int(sample_num + offset)
            xforms_np, rotations_np = pf.get_xforms(
                batch_size_train,
                rotation_range=rotation_range,
                scaling_range=scaling_range,
                order=setting.rotation_order)
            # Metrics are reset each step, so train summaries are per batch.
            sess.run(reset_metrics_op)
            sess.run(
                [
                    last_layer_train_op, loss_mean_update_op,
                    t_1_acc_update_op, t_1_per_class_acc_update_op
                ],
                feed_dict={
                    pts_fts: points_batch,
                    indices: pf.get_indices(batch_size_train,
                                            sample_num_train,
                                            points_num_batch),
                    xforms: xforms_np,
                    rotations: rotations_np,
                    jitter_range: np.array([jitter]),
                    labels_seg: labels_batch,
                    labels_weights: weights_batch,
                    is_training: True,
                })
            if batch_idx_train % 10 == 0:
                summaries = sess.run(summaries_op)
                summary_writer.add_summary(summaries, batch_idx_train)
                sys.stdout.flush()
            ##################################################################

        # Flush and release the event file.
        summary_writer.close()
        print('{}-Done!'.format(datetime.now()))