Example #1
0
 def post_build(self, sess, rng):
     """Restore pretrained weights from the ``ck_name`` checkpoint into ``sess``.

     Prints the candidate variable lists for debugging, then restores every
     variable returned by ``get_variables_to_restore()`` except the first two.

     Args:
         sess: session in which the restored values are assigned.
         rng: unused here; kept for interface compatibility.
     """
     # Function-call ``print`` behaves the same under Python 2 and Python 3;
     # the former print statements were a SyntaxError on Python 3.
     candidates = get_variables_to_restore()
     print("get_variables_to_restore():")
     print(candidates)
     print("get_model_variables():")
     print(get_model_variables())
     # NOTE(review): the first two variables are skipped, presumably because
     # they are not present in the checkpoint -- confirm which ones they are.
     restore = assign_from_checkpoint_fn(ck_name, candidates[2:])
     restore(sess)
Example #2
0
 def post_build(self, sess, rng):
     """Restore the network's variables from the ``ck_name`` checkpoint.

     This step is fragile: several ways of restoring variables from the
     checkpoint did not work. ``get_model_variables()`` returns an empty list,
     so it cannot be used. The unfiltered ``get_variables_to_restore()``
     contains a few spurious variables that are not found in the checkpoint
     (and are of no use). The current solution restricts the restore to the
     scopes listed in ``flat_cifar.scopes``, the scopes used when the network
     was defined.

     Args:
         sess: session in which the restored values are assigned.
         rng: unused here; kept for interface compatibility.
     """
     scopes = flat_cifar.scopes
     # Function-call ``print`` behaves the same under Python 2 and Python 3.
     print("get_variables_to_restore():")
     print(get_variables_to_restore())
     var_lst = get_variables_to_restore(include=scopes)
     print("get_variables_to_restore( include=scopes ):")
     print(var_lst)
     restore = assign_from_checkpoint_fn(ck_name, var_lst)
     restore(sess)
def get_model_init_fn(train_logdir,
                      tf_initial_checkpoint,
                      initialize_last_layer,
                      last_layers,
                      ignore_missing_vars=True,
                      exclude_scopes=None):
  """Gets the function initializing model variables from a checkpoint.

  Args:
    train_logdir: Log directory for training. If it already contains a
      checkpoint, no initialization function is returned.
    tf_initial_checkpoint: TensorFlow checkpoint for initialization.
    initialize_last_layer: Initialize last layer or not.
    last_layers: Last layers of the model. They are excluded from the restore
      when `initialize_last_layer` is False. None is treated as empty.
    ignore_missing_vars: Ignore missing variables in the checkpoint.
    exclude_scopes: Optional list of scope names that must not be restored.
      Defaults to the global step plus the decoder projection / decoder_conv0
      layers. The list is copied and never mutated.

  Returns:
    Initialization function `restore_fn(sess)`, or None when there is nothing
    to initialize.
  """
  if tf_initial_checkpoint is None:
    tf.logging.info('Not initializing the model from a checkpoint.')
    return None

  if tf.train.latest_checkpoint(train_logdir):
    tf.logging.info('Ignoring initialization; other checkpoint exists')
    return None

  tf.logging.info('Initializing model from path: %s', tf_initial_checkpoint)

  if exclude_scopes is None:
    # Variables that will not be restored. Alternatives used previously:
    #   MobilenetV3_Large: ['global_step']
    #   MobilenetV3_Small: ['global_step', 'image_pooling/', 'aspp0/',
    #                       'decoder/feature_projection0',
    #                       'MobilenetV3/expanded_conv/squeeze_excite/']
    # Scopes are matched against variable names with re.match (a prefix
    # match), so the former extra entry 'decoder/decoder_conv0_pointwise/'
    # was already covered by 'decoder/decoder_conv0_pointwise'.
    exclude_scopes = ['global_step',
                      'decoder/feature_projection0',
                      'decoder/decoder_conv0_depthwise',
                      'decoder/decoder_conv0_pointwise']

  # Copy so that extending never mutates a caller-supplied list.
  exclude_list = list(exclude_scopes)
  if not initialize_last_layer:
    exclude_list.extend(last_layers or [])

  variables_to_restore = contrib_framework.get_variables_to_restore(
      exclude=exclude_list)

  if not variables_to_restore:
    return None

  init_op, init_feed_dict = contrib_framework.assign_from_checkpoint(
      tf_initial_checkpoint,
      variables_to_restore,
      ignore_missing_vars=ignore_missing_vars)
  global_step = tf.train.get_or_create_global_step()

  def restore_fn(sess):
    sess.run(init_op, init_feed_dict)
    sess.run([global_step])

  return restore_fn
    def LoadPart(self,
                 loadpath,
                 alpha,
                 flag,
                 graphRevealFlag=False,
                 graphPath='logs/'):
        """Restore the attention layer and penalize drifting away from it.

        The method works in these steps:

        1. Restores ``Attention_Value/kernel:0`` and ``Attention_Value/bias:0``
           from ``loadpath``.
        2. Snapshots them as non-trainable variables under ``Punishment``.
        3. Adds ``alpha`` times the mean absolute distance between the
           trainable attention weights and that snapshot to
           ``self.parameters['Loss']``, giving ``self.parameters['Cost']``.
        4. Builds the RMSProp train op and the CTC beam-search decoder.
        5. Initializes all global variables and restores the attention
           weights again.

        Args:
            loadpath: checkpoint path holding the attention weights.
            alpha: weight of the punishment term.
            flag: distance type; only 'L1' is supported.
            graphRevealFlag: if true, write the session graph to ``graphPath``.
            graphPath: directory the graph summary is written to.

        Raises:
            ValueError: if ``flag`` is not 'L1'. Previously an unknown flag
                left 'PunishmentLoss' unset and failed later with a KeyError.
        """
        if flag != 'L1':
            raise ValueError(
                'Unsupported punishment flag {!r}; only "L1" is implemented.'
                .format(flag))

        # Restore only the attention weights so that their checkpoint values
        # can be read back below.
        restorePart = framework.get_variables_to_restore(
            include=['Attention_Value/kernel:0', 'Attention_Value/bias:0'])
        saver = tensorflow.train.Saver(restorePart)
        saver.restore(self.session, loadpath)

        with tensorflow.variable_scope('Attention_Value',
                                       reuse=tensorflow.AUTO_REUSE):
            self.parameters['Extract_W'] = tensorflow.get_variable('kernel')
            self.parameters['Extract_b'] = tensorflow.get_variable('bias')

        with tensorflow.variable_scope('Punishment'):
            # Frozen copies of the restored attention weights.
            self.parameters['Attention_Origin_W'] = tensorflow.Variable(
                initial_value=self.session.run('Attention_Value/kernel:0'),
                trainable=False,
                name='Attention_Origin_W')
            self.parameters['Attention_Origin_b'] = tensorflow.Variable(
                initial_value=self.session.run('Attention_Value/bias:0'),
                trainable=False,
                name='Attention_Origin_b')

            self.parameters['Distance_W'] = tensorflow.abs(
                self.parameters['Extract_W'] -
                self.parameters['Attention_Origin_W'])
            self.parameters['Distance_b'] = tensorflow.abs(
                self.parameters['Extract_b'] -
                self.parameters['Attention_Origin_b'])

            # Reduce both distances to scalars. Adding the unreduced bias
            # distance broadcast the penalty, and hence Cost, to a vector
            # with one entry per bias element.
            self.parameters['PunishmentLoss'] = alpha * (
                tensorflow.reduce_mean(self.parameters['Distance_W']) +
                tensorflow.reduce_mean(self.parameters['Distance_b']))

        self.parameters['Cost'] = (self.parameters['Loss'] +
                                   self.parameters['PunishmentLoss'])
        self.train = tensorflow.train.RMSPropOptimizer(
            learning_rate=self.learningRate).minimize(self.parameters['Cost'])
        self.decode, self.logProbability = tensorflow.nn.ctc_beam_search_decoder(
            inputs=self.parameters['Logits_TimeMajor'],
            sequence_length=self.seqLenInput,
            merge_repeated=False)
        self.decodeDense = tensorflow.sparse_tensor_to_dense(
            sp_input=self.decode[0])

        # The global initializer also re-runs the initializers of the
        # attention variables, so restore them from the checkpoint again.
        self.session.run(tensorflow.global_variables_initializer())
        saver.restore(self.session, loadpath)

        if graphRevealFlag:
            # Close the writer so the graph event is flushed to disk.
            graph_writer = tensorflow.summary.FileWriter(graphPath,
                                                         self.session.graph)
            graph_writer.close()
Example #5
0
def get_model_init_fn(train_logdir,
                      tf_initial_checkpoint,
                      initialize_last_layer,
                      last_layers,
                      ignore_missing_vars=False):
    """Return a callable that warm-starts model variables from a checkpoint.

    Args:
        train_logdir: Training log directory. If it already holds a
            checkpoint, no warm start is set up.
        tf_initial_checkpoint: Checkpoint to initialize from, or None.
        initialize_last_layer: When False, the scopes in ``last_layers`` are
            also left out of the restore.
        last_layers: Scope names of the model's last layers.
        ignore_missing_vars: Ignore variables missing from the checkpoint.

    Returns:
        ``restore_fn(sess)``, which runs the checkpoint assignment, or None
        when nothing should be restored.
    """
    log_info = tf.compat.v1.logging.info

    if tf_initial_checkpoint is None:
        log_info('Not initializing the model from a checkpoint.')
        return None
    if tf.train.latest_checkpoint(train_logdir):
        log_info('Ignoring initialization; other checkpoint exists')
        return None
    log_info('Initializing model from path: %s', tf_initial_checkpoint)

    # The step counter and the 'logits' scope are never restored.
    skipped_scopes = ['global_step', 'logits']
    if not initialize_last_layer:
        skipped_scopes += last_layers

    restorable = contrib_framework.get_variables_to_restore(
        exclude=skipped_scopes)
    if not restorable:
        return None

    assign_op, assign_feed = contrib_framework.assign_from_checkpoint(
        tf_initial_checkpoint,
        restorable,
        ignore_missing_vars=ignore_missing_vars)
    step_var = tf.compat.v1.train.get_or_create_global_step()

    def restore_fn(sess):
        sess.run(assign_op, assign_feed)
        sess.run([step_var])

    return restore_fn
Example #6
0
def main(_):
    """Extract features for every video and write one .mat file per video.

    Builds the network named by ``FLAGS.netname`` and loads pretrained
    weights, either a ``.pkl`` file or a TensorFlow checkpoint. It then slides
    a snippet window over each video in ``FLAGS.segmented_dir``, moving by
    ``FLAGS.stride`` frames. The features for each window, together with the
    label of the frame that anchors the window, are written via
    ``make_mat_file``. Videos whose output already exists in
    ``FLAGS.outputdir`` are skipped.
    """
    # Make outputdir
    if not os.path.exists(FLAGS.outputdir):
        os.makedirs(FLAGS.outputdir)

    # Load video list
    vid_lst = sorted(os.listdir(FLAGS.segmented_dir))

    # Load label dictionary; only the number of classes is needed here.
    with open(FLAGS.lbl_dict_pth) as lbl_file:
        n_classes = len(lbl_file.read().splitlines())
    if FLAGS.has_bg_lbl:
        n_classes += 1

    # Use the load_snippet_pths_test in data writer to get frames and labels
    dataset_writer = dataset_factory.get_writer(FLAGS.datasetname)
    writer = dataset_writer()

    use_pkl = '.pkl' in FLAGS.pretrained_model

    # set default graph
    with tf.Graph().as_default():
        # build network
        net = networks_factory.build_net(FLAGS.netname,
                                         n_classes,
                                         FLAGS.snippet_len,
                                         FLAGS.target_height,
                                         FLAGS.target_width,
                                         max_time_gap=FLAGS.max_time_gap,
                                         trainable=False)

        # extract features
        feat = net.get_output(FLAGS.featname)

        # load pretrained weights
        if use_pkl:
            assign_ops = networks_utils.load_pretrained(
                FLAGS.pretrained_model,
                ignore_missing=True,
                extension='pkl',
                initoffset=FLAGS.usemotionloss)
        else:
            variables_to_restore = get_variables_to_restore()
            saver = tf.train.Saver(variables_to_restore)

            def init_fn(sess):
                tf.logging.info('Restoring checkpoint...')
                return saver.restore(sess, FLAGS.pretrained_model)

        # create session
        with tf.Session() as sess:
            # initialization
            sess.run([
                tf.global_variables_initializer(),
                tf.local_variables_initializer()
            ])
            if use_pkl:
                sess.run(assign_ops)
            else:
                init_fn(sess)

            snippet_len = FLAGS.snippet_len + 1  # original snippet_len without flow
            left = snippet_len // 2  # correct, because snippet_len was increased by 1
            right = snippet_len - left

            for vid in vid_lst:
                # skip existing feature files
                output_fname = '{}.avi.mat'.format(vid)
                # NOTE(review): existence is checked under FLAGS.outputdir,
                # but make_mat_file receives only the bare file name. Confirm
                # that make_mat_file prefixes FLAGS.outputdir itself.
                if os.path.exists(os.path.join(FLAGS.outputdir, output_fname)):
                    print('{} already exists'.format(output_fname))
                    continue

                # load all file names and labels
                print('\nExtracting features for ' + vid)
                fname_lst, lbl_lst = writer.load_snippet_pths_test(
                    FLAGS.segmented_dir, [vid], FLAGS.lbl_dict_pth,
                    FLAGS.bg_lbl, FLAGS.ext, FLAGS.frameskip)
                fname_lst = [x[0] for x in fname_lst]

                # prefetch all frames of a video
                frames_all = read_n_compute_flow(fname_lst)

                # go through the video frames in acausal fashion
                n_frames = len(lbl_lst)
                frame_id = left
                feats_per_vid = []
                groundtruths_per_vid = []
                pbar = ProgressBar(max_value=n_frames)
                while frame_id < n_frames - right + 1:
                    # produce inputs
                    snippet_batch = []
                    lbl_batch = []
                    for _ in range(FLAGS.batch_size):
                        if frame_id + right > n_frames:
                            break
                        # ignore 1 last frame because this is flow
                        snippet_batch.append(
                            frames_all[frame_id - left:frame_id + right - 1])
                        # this is in sync with label from appearance stream
                        lbl_batch.append(lbl_lst[frame_id])
                        frame_id += FLAGS.stride

                    feed_dict = {
                        net.data_raw: snippet_batch,
                        net.labels_raw: lbl_batch
                    }

                    # extract features
                    feat_ = sess.run(feat, feed_dict=feed_dict)

                    # append data
                    for i in range(feat_.shape[0]):
                        feats_per_vid.append(feat_[i])
                        groundtruths_per_vid.append(lbl_batch[i])
                    # frame_id can overshoot n_frames when stride > right;
                    # clamp so the progress bar never receives an
                    # out-of-range value.
                    pbar.update(min(frame_id, n_frames))

                # produce mat file for a video
                feats_per_vid = np.array(feats_per_vid, dtype=np.float32)
                groundtruths_per_vid = np.array(groundtruths_per_vid)
                make_mat_file(output_fname,
                              feats_per_vid,
                              groundtruths_per_vid,
                              expected_length=n_frames // FLAGS.stride)
Example #7
0
def fit(dataset_constants,
        lmbda,
        pretrained_model,
        checkpoint_dir,
        msssim_loss=False,
        preprocess_threads=8,
        patchsize=256,
        batchsize=8,
        last_step=200000,
        validate=False):
  """Trains the model, warm-starting from a pretrained checkpoint.

  Args:
    dataset_constants: dict with "train_images_glob". When `validate` is true
      it also needs "val_images_glob" and "image_shape", a list giving the
      validation image shape without the channel axis (3 channels are
      appended).
    lmbda: trade-off weight forwarded to `build_model` (presumably the
      rate-distortion lambda).
    pretrained_model: checkpoint path restored (all variables except
      `global_step`) when `checkpoint_dir` has no checkpoint yet.
    checkpoint_dir: directory for checkpoints and summaries.
    msssim_loss: forwarded to `build_model`.
    preprocess_threads: parallel calls used when decoding training images.
    patchsize: side of the random square crops used for training.
    batchsize: training batch size. Validation uses a quarter of it, but
      never less than 1.
    last_step: global step at which training stops.
    validate: also build a validation pipeline and log its summaries.

  Raises:
    ValueError: if a glob matches no image files.
  """
  train_glob = dataset_constants["train_images_glob"]
  x_train_files = sorted(glob.glob(train_glob))
  if not x_train_files:
    raise ValueError("No training images match {!r}.".format(train_glob))

  train_dataset = tf.data.Dataset.from_tensor_slices(x_train_files)
  train_dataset = train_dataset.shuffle(buffer_size=len(x_train_files)).repeat()
  train_dataset = train_dataset.map(read_png,
                                    num_parallel_calls=preprocess_threads)
  train_dataset = train_dataset.map(
      lambda x: tf.random_crop(x, (patchsize, patchsize, 3)))
  train_dataset = train_dataset.batch(batchsize)
  train_dataset = train_dataset.prefetch(batchsize)

  x_train_unscaled = train_dataset.make_one_shot_iterator().get_next()
  x_train = x_train_unscaled / 255.

  (train_loss, train_bpp, train_mse, train_msssim, x_tilde,
   _, _, _, _, _, _, layers) = build_model(x_train, lmbda, mode='training',
                                           msssim_loss=msssim_loss)
  train_summary = log_all_summaries(x_train_unscaled, x_tilde, train_loss,
                                    train_bpp, train_mse, train_msssim, "train")

  if validate:
    def set_shape(x):
      x.set_shape(dataset_constants["image_shape"] + [3])
      return x

    # batchsize < 4 would otherwise give an empty batch and zero parallel
    # calls, both of which tf.data rejects.
    val_batchsize = max(1, batchsize // 4)
    val_preprocess_threads = min(preprocess_threads, val_batchsize)
    val_glob = dataset_constants["val_images_glob"]
    x_val_files = sorted(glob.glob(val_glob))
    if not x_val_files:
      raise ValueError("No validation images match {!r}.".format(val_glob))
    val_dataset = tf.data.Dataset.from_tensor_slices(x_val_files)
    val_dataset = val_dataset.shuffle(buffer_size=len(x_val_files)).repeat()
    val_dataset = val_dataset.map(read_png,
                                  num_parallel_calls=val_preprocess_threads)
    val_dataset = val_dataset.map(set_shape,
                                  num_parallel_calls=val_preprocess_threads)
    val_dataset = val_dataset.batch(val_batchsize)
    val_dataset = val_dataset.prefetch(val_batchsize)
    x_val_unscaled = val_dataset.make_one_shot_iterator().get_next()
    x_val = x_val_unscaled / 255.

    # Reuse the training layers so validation evaluates the same weights.
    val_loss, val_bpp, val_mse, val_msssim, x_hat, _, _, _, _, _, _, _ = \
        build_model(x_val, lmbda, 'testing', layers, msssim_loss)
    val_summary = log_all_summaries(x_val_unscaled, x_hat, val_loss, val_bpp,
                                    val_mse, val_msssim, "val")

  step = tf.train.get_or_create_global_step()
  main_optimizer = tf.train.AdamOptimizer(learning_rate=1e-4)
  main_step = main_optimizer.minimize(train_loss, global_step=step)

  _, _, entropy_bottleneck, _, _ = layers

  # The entropy bottleneck's auxiliary losses get their own optimizer.
  aux_optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)
  aux_step = aux_optimizer.minimize(sum(entropy_bottleneck.losses))
  train_op = tf.group(main_step, aux_step, entropy_bottleneck.updates)

  hooks = [
      tf.train.StopAtStepHook(last_step=last_step),
      tf.train.NanTensorHook(train_loss),
      tf.train.SummarySaverHook(save_secs=120, output_dir=checkpoint_dir,
                                summary_op=train_summary),
  ]
  if validate:
    hooks += [
        tf.train.SummarySaverHook(save_secs=120, output_dir=checkpoint_dir,
                                  summary_op=val_summary),
    ]
  exclude_list = ['global_step']
  variables_to_restore = get_variables_to_restore(exclude=exclude_list)
  pre_train_saver = tf.train.Saver(variables_to_restore)

  def load_pretrain(scaffold, sess):
    # `scaffold` is required by the Scaffold init_fn signature but unused.
    pre_train_saver.restore(sess, save_path=pretrained_model)

  with tf.train.MonitoredTrainingSession(
      hooks=hooks, checkpoint_dir=checkpoint_dir,
      save_checkpoint_steps=25000, save_summaries_secs=None,
      save_summaries_steps=None,
      scaffold=tf.train.Scaffold(init_fn=load_pretrain,
                                 saver=tf.train.Saver(max_to_keep=11))) as sess:
    while not sess.should_stop():
      sess.run(train_op)
Example #8
0
def eval_one(ckpt_fname, frames, labels, eval_fn):
    """Run evaluation on a specific checkpoint.

    Args:
        ckpt_fname: file name of the checkpoint to restore from
        frames: a list, each item is a sub-list of all frames per video
        labels: a list, each item is a sub-list of corresponding labels
        eval_fn: callable that performs the evaluation. It is called as
            ``eval_fn(net, init_fn, frames, labels, metrics_op, accuracy_vid,
            summary_op, global_step, confusion_vid, ckpt_basename)``, where
            ``init_fn(sess)`` restores ``ckpt_fname``.

    Returns:
        Whatever ``eval_fn`` returns (presumably a dict with the frame- and
        video-level accuracies).
    """
    # retrieve label dictionary; only the number of classes is needed.
    with open(FLAGS.labels_fname) as lbl_file:
        n_classes = len(lbl_file.read().splitlines())
    if FLAGS.has_bg_lbl:
        n_classes += 1

    # set verbosity for info level only
    tf.logging.set_verbosity(tf.logging.INFO)

    # construct graph and build models
    with tf.Graph().as_default():
        tf.logging.info('Evaluating checkpoint %s', ckpt_fname)

        # build network
        net = networks_factory.build_net(FLAGS.netname,
                                         n_classes,
                                         FLAGS.snippet_len,
                                         FLAGS.target_height,
                                         FLAGS.target_width,
                                         max_time_gap=FLAGS.max_time_gap,
                                         trainable=False)

        # restore checkpoint function
        global_step = get_or_create_global_step()
        variables_to_restore = get_variables_to_restore()
        saver = tf.train.Saver(variables_to_restore)

        def init_fn(sess):
            tf.logging.info('Restoring checkpoint...')
            return saver.restore(sess, ckpt_fname)

        # Metrics must be created after the Saver is built (unlike in
        # training); otherwise restoring would reload their old values.
        (accuracy_vid, confusion_vid, metrics_op) = net.create_metrics()
        if isinstance(confusion_vid, list):
            confusion_vid_img = [
                colorize_tensor(x, extend=True) for x in confusion_vid
            ]
        else:
            confusion_vid_img = colorize_tensor(confusion_vid, extend=True)

        # summaries
        if isinstance(accuracy_vid, list):
            for item in accuracy_vid:
                tf.summary.scalar('accuracy/' + item.op.name, item)
        else:
            tf.summary.scalar('accuracy', accuracy_vid)

        if isinstance(confusion_vid_img, list):
            for item in confusion_vid_img:
                tf.summary.image('confusion/' + item.op.name, item)
        else:
            tf.summary.image('confusion', confusion_vid_img)
        summary_op = tf.summary.merge_all()

        # evaluation phase
        results_dict = eval_fn(net, init_fn, frames, labels, metrics_op,
                               accuracy_vid, summary_op, global_step,
                               confusion_vid, os.path.basename(ckpt_fname))
    return results_dict
Example #9
0
                            num_noisy_tags])  # noisy tags ex.) [1 0 0 1 0]
# Ground-truth labels, one row of num_classes values per example
# (presumably multi-hot -- confirm against the data pipeline).
y = tf.placeholder(tf.float32, shape=[batch_size, num_classes])
# quantity: per-example sum of y, i.e. the number of positive labels if y is
# multi-hot (assumption).
q = tf.reduce_sum(y, 1)  # quantity
keep_prob = tf.placeholder(tf.float32)
is_training = tf.placeholder(tf.bool)

# model
# resnet_v1 101
# The backbone is always built with is_training=False, independent of the
# is_training placeholder above, which only feeds the heads built below.
with slim.arg_scope(resnet_v1.resnet_arg_scope()):
    net, end_points = resnet_v1.resnet_v1_101(img,
                                              num_classes,
                                              is_training=False)
# NOTE(review): tf.squeeze without an axis also drops the batch axis when
# batch_size == 1; if net is (batch, 1, 1, num_classes), consider
# tf.squeeze(net, axis=[1, 2]).
net_logit = tf.squeeze(net)

# tensorflow operation for load pretrained weights
# The classification head is excluded so it is not loaded from the checkpoint.
# NOTE(review): 'AuxLogits' looks like an Inception scope that ResNet may not
# have; the entry appears harmless -- verify.
variables_to_restore = get_variables_to_restore(
    exclude=['resnet_v1_101/logits', 'resnet_v1_101/AuxLogits'])
init_fn = assign_from_checkpoint_fn('resnet_v1_101.ckpt', variables_to_restore)

# multiscale resnet_v1 101
visual_features, fusion_logit = multiscale_resnet101(end_points, num_classes,
                                                     is_training)
textual_features, textual_logit = mlp(tag, num_classes, is_training)
# Concatenate visual and textual features along axis 1.
refined_features = tf.concat([visual_features, textual_features], 1)

# score is prediction score, and k is label quantity
score = multi_class_classification_model(refined_features, num_classes)
k = label_quantity_prediction_model(refined_features, keep_prob)
# Flatten k to one value per example.
k = tf.reshape(k, shape=[batch_size])

# make trainable variable list
var_list0 = [
Example #10
0
def train(train_dir,
          config,
          dataset_fn,
          checkpoints_to_keep=5,
          keep_checkpoint_every_n_hours=1,
          num_steps=None,
          master='',
          num_sync_workers=0,
          num_ps_tasks=0,
          task=0):
    """Train loop.

    Builds ``config.model`` under a replica device setter and clips the
    gradients according to ``config.hparams.clip_mode`` ('value' or
    'global_norm'). It warm-starts the variables returned by
    ``_get_restore_vars`` from ``config.pretrained_path`` and then runs
    ``contrib_training.train``.

    Args:
      train_dir: directory for checkpoints and summaries (created if missing).
      config: configuration object. It provides model, hparams,
        data_converter, encoder_train/decoder_train, var_train_pattern,
        pretrained_path and train_examples_path/tfds_name.
      dataset_fn: zero-argument callable whose result is passed to
        ``_get_input_tensors``.
      checkpoints_to_keep: ``max_to_keep`` for the checkpoint Saver.
      keep_checkpoint_every_n_hours: passed to the checkpoint Saver.
      num_steps: if truthy, training stops at this global step.
      master: master address passed to ``contrib_training.train``.
      num_sync_workers: if non-zero, the optimizer is wrapped in
        ``SyncReplicasOptimizer`` with this many replicas to aggregate.
      num_ps_tasks: number of parameter-server tasks for device placement.
      task: task index; task 0 is the chief.

    Raises:
      ValueError: if ``config.hparams.clip_mode`` is not recognized.
    """
    tf.gfile.MakeDirs(train_dir)
    is_chief = (task == 0)
    if is_chief:
        _trial_summary(config.hparams, config.train_examples_path
                       or config.tfds_name, train_dir)

    with tf.Graph().as_default():
        with tf.device(
                tf.train.replica_device_setter(num_ps_tasks,
                                               merge_devices=True)):
            model = config.model
            model.build(config.hparams,
                        config.data_converter.output_depth,
                        encoder_train=config.encoder_train,
                        decoder_train=config.decoder_train)
            optimizer = model.train(**_get_input_tensors(dataset_fn(), config))
            # Both helpers are driven by config.var_train_pattern; presumably
            # they select the variables to warm-start and to train.
            restored_vars = _get_restore_vars(config.var_train_pattern)
            _set_trainable_vars(config.var_train_pattern)

            hooks = []
            if num_sync_workers:
                optimizer = tf.train.SyncReplicasOptimizer(
                    optimizer, num_sync_workers)
                hooks.append(optimizer.make_session_run_hook(is_chief))

            grads, var_list = zip(*optimizer.compute_gradients(model.loss))
            global_norm = tf.global_norm(grads)
            tf.summary.scalar('global_norm', global_norm)

            if config.hparams.clip_mode == 'value':
                # Clip each gradient element to [-grad_clip, grad_clip].
                g = config.hparams.grad_clip
                clipped_grads = [
                    tf.clip_by_value(grad, -g, g) for grad in grads
                ]
            elif config.hparams.clip_mode == 'global_norm':
                # Below grad_norm_clip_to_zero, rescale to a global norm of at
                # most grad_clip; otherwise replace every gradient with zeros.
                clipped_grads = tf.cond(
                    global_norm < config.hparams.grad_norm_clip_to_zero,
                    lambda: tf.clip_by_global_norm(grads,
                                                   config.hparams.grad_clip,
                                                   use_norm=global_norm)[0],
                    lambda: [tf.zeros(tf.shape(g)) for g in grads])
            else:
                raise ValueError('Unknown clip_mode: {}'.format(
                    config.hparams.clip_mode))
            train_op = optimizer.apply_gradients(zip(clipped_grads, var_list),
                                                 global_step=model.global_step,
                                                 name='train_step')

            logging_dict = {
                'global_step': model.global_step,
                'loss': model.loss
            }

            hooks.append(
                tf.train.LoggingTensorHook(logging_dict, every_n_iter=5))
            if num_steps:
                hooks.append(tf.train.StopAtStepHook(last_step=num_steps))

            # Warm start: assign the selected variables from the pretrained
            # checkpoint through the Scaffold's init_fn.
            variables_to_restore = contrib_framework.get_variables_to_restore(
                include=[v.name for v in restored_vars])
            init_assign_op, init_feed_dict = contrib_framework.assign_from_checkpoint(
                config.pretrained_path, variables_to_restore)

            def InitAssignFn(scaffold, sess):
                sess.run(init_assign_op, init_feed_dict)

            scaffold = tf.train.Scaffold(
                init_fn=InitAssignFn,
                saver=tf.train.Saver(
                    max_to_keep=checkpoints_to_keep,
                    keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
                ))
            contrib_training.train(train_op=train_op,
                                   logdir=train_dir,
                                   scaffold=scaffold,
                                   hooks=hooks,
                                   save_checkpoint_secs=60,
                                   master=master,
                                   is_chief=is_chief)
Example #11
0
def get_model_init_fn(train_logdir,
                      tf_initial_checkpoint,
                      initialize_first_layer,
                      initialize_last_layer,
                      last_layers=None,
                      restore_adam=False,
                      ignore_missing_vars=False):
    """Gets the function initializing model variables from a checkpoint.

    Args:
      train_logdir: Log directory for training. If it already contains a
        checkpoint, no initialization function is returned.
      tf_initial_checkpoint: TensorFlow checkpoint for initialization.
      initialize_first_layer: Initialize first layer or not. When False,
        'resnet_v1_50/conv1_1/weights:0' is not restored.
      initialize_last_layer: Initialize last layer or not.
      last_layers: Last layers of the model, excluded from the restore when
        `initialize_last_layer` is False. None is treated as empty.
      restore_adam: Restore Adam optimization parameters or not. When False,
        every variable whose name contains "Adam" is skipped.
      ignore_missing_vars: Ignore missing variables in the checkpoint.

    Returns:
      Initialization function `restore_fn(scaffold, sess)`, or None when
      there is nothing to initialize.
    """
    if tf_initial_checkpoint is None:
        tf.logging.info('Not initializing the model from a checkpoint.')
        return None

    if tf.train.latest_checkpoint(train_logdir):
        tf.logging.info('Ignoring initialization; other checkpoint exists')
        return None

    tf.logging.info('Initializing model from path: %s', tf_initial_checkpoint)

    # Variables that will not be restored.
    exclude_list = ['global_step']
    if not initialize_last_layer:
        # last_layers defaults to None; extending with None raised TypeError.
        exclude_list.extend(last_layers or [])

    if not initialize_first_layer:
        exclude_list.append('resnet_v1_50/conv1_1/weights:0')

    variables_to_restore = contrib_framework.get_variables_to_restore(
        exclude=exclude_list)

    # Restore without Adam parameters
    if not restore_adam:
        variables_to_restore = [
            v for v in variables_to_restore if 'Adam' not in v.name
        ]

    if not variables_to_restore:
        return None

    init_op, init_feed_dict = contrib_framework.assign_from_checkpoint(
        tf_initial_checkpoint,
        variables_to_restore,
        ignore_missing_vars=ignore_missing_vars)

    global_step = tf.train.get_or_create_global_step()

    def restore_fn(scaffold, sess):
        sess.run(init_op, init_feed_dict)
        sess.run([global_step])

    return restore_fn
def main():
    """Fine-tune the ``logits`` layer of a pretrained PointCNN model on SceneNN.

    Builds the TF1 graph (placeholders, on-the-fly point sampling and
    augmentation, the network, streaming metrics and summaries). It restores
    every variable outside the ``logits`` scope from ``args.load_ckpt`` and
    then optimizes only the variables in the ``logits`` scope. A checkpoint is
    saved and a validation pass is run every ``step_val`` batches and on the
    last batch.

    All paths and model choices are hard-coded in ``args`` below. The
    remaining hyper-parameters are read from the setting module
    ``args.setting``, which is imported from ``ROOT_DIR/args.model``.

    Raises:
        ValueError: if ``setting.optimizer`` is neither ``'adam'`` nor
            ``'momentum'``.
    """
    args = AttrDict()
    # Path of training set ground truth file list (.txt)
    args.filelist = os.path.join(SCENENN_DIR, 'train_files.txt')
    # Path of validation set ground truth file list (.txt)
    args.filelist_val = os.path.join(SCENENN_DIR, 'test_files.txt')
    # Path of a check point file to load
    args.load_ckpt = os.path.join(ROOT_DIR, '..', 'models',
                                  'pretrained_scannet', 'ckpts', 'iter-354000')
    # Base directory where model checkpoint and summary files get saved in separate subdirectories
    args.save_folder = os.path.join(ROOT_DIR, '..', 'models')
    # PointCNN model to use
    args.model = 'pointcnn_seg'
    # Model setting to use
    args.setting = 'scenenn_x8_2048_fps'

    time_string = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    model_save_folder = os.path.join(
        args.save_folder,
        '%s_%s_%s_%d' % (args.model, args.setting, time_string, os.getpid()))
    if not os.path.exists(model_save_folder):
        os.makedirs(model_save_folder)

    # sys.stdout = open(os.path.join(model_save_folder, 'log.txt'), 'w')

    print('PID:', os.getpid())
    print(args)

    model = importlib.import_module(args.model)
    # The setting module lives in a sub-directory named after the model.
    setting_path = os.path.join(ROOT_DIR, args.model)
    sys.path.append(setting_path)
    setting = importlib.import_module(args.setting)
    num_epochs = setting.num_epochs
    batch_size = setting.batch_size
    sample_num = setting.sample_num
    step_val = setting.step_val
    label_weights_list = setting.label_weights
    rotation_range = setting.rotation_range
    rotation_range_val = setting.rotation_range_val
    scaling_range = setting.scaling_range
    scaling_range_val = setting.scaling_range_val
    jitter = setting.jitter
    jitter_val = setting.jitter_val
    # Per-class loss weights as an array, so a batch of labels can index it
    # directly. Built once rather than on every batch.
    label_weights_np = np.array(label_weights_list)

    # Prepare inputs
    print('{}-Preparing datasets...'.format(datetime.now()))
    is_list_of_h5_list = not data_utils.is_h5_list(args.filelist)
    if is_list_of_h5_list:
        # The training set is split over several h5 lists; start with the
        # first one and cycle through the rest during training.
        seg_list = data_utils.load_seg_list(args.filelist)
        seg_list_idx = 0
        filelist_train = seg_list[seg_list_idx]
        seg_list_idx = seg_list_idx + 1
    else:
        filelist_train = args.filelist
    data_train, _, data_num_train, label_train, _ = data_utils.load_seg(
        filelist_train)
    data_val, _, data_num_val, label_val, _ = data_utils.load_seg(
        args.filelist_val)

    # shuffle
    data_train, data_num_train, label_train = \
        data_utils.grouped_shuffle([data_train, data_num_train, label_train])

    num_train = data_train.shape[0]
    point_num = data_train.shape[1]
    num_val = data_val.shape[0]
    print('{}-{:d}/{:d} training/validation samples.'.format(
        datetime.now(), num_train, num_val))
    batch_num = (num_train * num_epochs + batch_size - 1) // batch_size
    print('{}-{:d} training batches.'.format(datetime.now(), batch_num))
    batch_num_val = math.ceil(num_val / batch_size)
    print('{}-{:d} testing batches per test.'.format(datetime.now(),
                                                     batch_num_val))

    ######################################################################
    # Placeholders
    print('{}-Initializing TF-placeholders...'.format(datetime.now()))
    indices = tf.placeholder(tf.int32, shape=(None, None, 2), name="indices")
    xforms = tf.placeholder(tf.float32, shape=(None, 3, 3), name="xforms")
    rotations = tf.placeholder(tf.float32,
                               shape=(None, 3, 3),
                               name="rotations")
    jitter_range = tf.placeholder(tf.float32, shape=(1), name="jitter_range")
    global_step = tf.Variable(0, trainable=False, name='global_step')
    is_training = tf.placeholder(tf.bool, name='is_training')

    pts_fts = tf.placeholder(tf.float32,
                             shape=(None, point_num, setting.data_dim),
                             name='pts_fts')
    labels_seg = tf.placeholder(tf.int64,
                                shape=(None, point_num),
                                name='labels_seg')
    labels_weights = tf.placeholder(tf.float32,
                                    shape=(None, point_num),
                                    name='labels_weights')

    ######################################################################
    pts_fts_sampled = tf.gather_nd(pts_fts,
                                   indices=indices,
                                   name='pts_fts_sampled')
    features_augmented = None
    if setting.data_dim > 3:
        # The first 3 channels are xyz; any remaining channels are features.
        points_sampled, features_sampled = tf.split(
            pts_fts_sampled, [3, setting.data_dim - 3],
            axis=-1,
            name='split_points_features')
        if setting.use_extra_features:
            if setting.with_normal_feature:
                if setting.data_dim < 6:
                    print('Only 3D normals are supported!')
                    sys.exit(1)
                elif setting.data_dim == 6:
                    features_augmented = pf.augment(features_sampled,
                                                    rotations)
                else:
                    # Split along the channel axis (the default axis 0
                    # would split the batch). Only the 3 normal channels
                    # are rotated together with the points.
                    normals, rest = tf.split(features_sampled,
                                             [3, setting.data_dim - 6],
                                             axis=-1)
                    normals_augmented = pf.augment(normals, rotations)
                    features_augmented = tf.concat([normals_augmented, rest],
                                                   axis=-1)
            else:
                features_augmented = features_sampled
    else:
        points_sampled = pts_fts_sampled
    points_augmented = pf.augment(points_sampled, xforms, jitter_range)

    labels_sampled = tf.gather_nd(labels_seg,
                                  indices=indices,
                                  name='labels_sampled')
    labels_weights_sampled = tf.gather_nd(labels_weights,
                                          indices=indices,
                                          name='labels_weight_sampled')

    print('{}-Initializing net...'.format(datetime.now()))
    net = model.Net(points_augmented, features_augmented, is_training, setting)
    logits = net.logits
    probs = tf.nn.softmax(logits, name='probs')
    predictions = tf.argmax(probs, axis=-1, name='predictions')

    loss_op = tf.losses.sparse_softmax_cross_entropy(
        labels=labels_sampled, logits=logits, weights=labels_weights_sampled)

    with tf.name_scope('metrics'):
        loss_mean_op, loss_mean_update_op = tf.metrics.mean(loss_op)
        t_1_acc_op, t_1_acc_update_op = tf.metrics.accuracy(
            labels_sampled, predictions, weights=labels_weights_sampled)
        t_1_per_class_acc_op, t_1_per_class_acc_update_op = \
            tf.metrics.mean_per_class_accuracy(labels_sampled, predictions, setting.num_class,
                                               weights=labels_weights_sampled)
    # Re-initialising the local variables under 'metrics' resets the streaming
    # accumulators.
    reset_metrics_op = tf.variables_initializer([
        var for var in tf.local_variables()
        if var.name.split('/')[0] == 'metrics'
    ])

    _ = tf.summary.scalar('loss/train',
                          tensor=loss_mean_op,
                          collections=['train'])
    _ = tf.summary.scalar('t_1_acc/train',
                          tensor=t_1_acc_op,
                          collections=['train'])
    _ = tf.summary.scalar('t_1_per_class_acc/train',
                          tensor=t_1_per_class_acc_op,
                          collections=['train'])

    _ = tf.summary.scalar('loss/val', tensor=loss_mean_op, collections=['val'])
    _ = tf.summary.scalar('t_1_acc/val',
                          tensor=t_1_acc_op,
                          collections=['val'])
    _ = tf.summary.scalar('t_1_per_class_acc/val',
                          tensor=t_1_per_class_acc_op,
                          collections=['val'])

    lr_exp_op = tf.train.exponential_decay(setting.learning_rate_base,
                                           global_step,
                                           setting.decay_steps,
                                           setting.decay_rate,
                                           staircase=True)
    lr_clip_op = tf.maximum(lr_exp_op, setting.learning_rate_min)
    _ = tf.summary.scalar('learning_rate',
                          tensor=lr_clip_op,
                          collections=['train'])
    reg_loss = setting.weight_decay * tf.losses.get_regularization_loss()
    print('{}-Setting up optimizer...'.format(datetime.now()))
    if setting.optimizer == 'adam':
        optimizer = tf.train.AdamOptimizer(learning_rate=lr_clip_op,
                                           epsilon=setting.epsilon)
    elif setting.optimizer == 'momentum':
        optimizer = tf.train.MomentumOptimizer(learning_rate=lr_clip_op,
                                               momentum=setting.momentum,
                                               use_nesterov=True)
    else:
        raise ValueError('Unknown optimizer: {}'.format(setting.optimizer))
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        # Only the variables in the 'logits' scope are optimized; the rest of
        # the network stays at its restored values.
        last_layer_train_vars = tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES, 'logits')
        last_layer_train_op = optimizer.minimize(
            loss_op + reg_loss,
            global_step=global_step,
            var_list=last_layer_train_vars)

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())

    saver = tf.train.Saver(max_to_keep=None)

    # Restore everything outside the 'logits' scope from the pretrained
    # checkpoint; 'logits' keeps its fresh initialization.
    # NOTE(review): 'global_step' is not excluded here, so if it is in the
    # checkpoint, the LR decay and saved checkpoint numbering continue from
    # the pretrained step count -- confirm this is intended.
    variables_to_restore = get_variables_to_restore(exclude=['logits'])
    restorer = tf.train.Saver(var_list=variables_to_restore, max_to_keep=None)

    # backup all code
    # code_folder = os.path.abspath(os.path.dirname(__file__))
    # shutil.copytree(code_folder, os.path.join(model_save_folder, os.path.basename(code_folder)), symlinks=True)

    folder_ckpt = os.path.join(model_save_folder, 'ckpts')
    if not os.path.exists(folder_ckpt):
        os.makedirs(folder_ckpt)

    folder_summary = os.path.join(model_save_folder, 'summary')
    if not os.path.exists(folder_summary):
        os.makedirs(folder_summary)

    parameter_num = np.sum(
        [np.prod(v.shape.as_list()) for v in tf.trainable_variables()])
    print('{}-Number of model parameters: {:d}.'.format(
        datetime.now(), parameter_num))

    with tf.Session() as sess:
        summaries_op = tf.summary.merge_all('train')
        summaries_val_op = tf.summary.merge_all('val')
        summary_writer = tf.summary.FileWriter(folder_summary, sess.graph)

        print('{}-Initializing variables...'.format(datetime.now()))
        sess.run(init_op)

        # Load the model
        if args.load_ckpt is not None:
            print('{}-Loading checkpoint from {}...'.format(
                datetime.now(), args.load_ckpt))
            restorer.restore(sess, args.load_ckpt)
            print('{}-Checkpoint loaded.'.format(datetime.now()))

        for batch_idx_train in tqdm(range(batch_num), ncols=60):
            # Validate every step_val batches (also at batch 0 when a
            # checkpoint was loaded) and once more on the last batch.
            if (batch_idx_train % step_val == 0 and (batch_idx_train != 0 or args.load_ckpt is not None)) \
                    or batch_idx_train == batch_num - 1:
                ######################################################################
                # Validation
                filename_ckpt = os.path.join(folder_ckpt, 'iter')
                saver.save(sess, filename_ckpt, global_step=global_step)
                tqdm.write('{}-Checkpoint saved to {}!'.format(
                    datetime.now(), filename_ckpt))

                sess.run(reset_metrics_op)
                for batch_val_idx in range(batch_num_val):
                    start_idx = batch_size * batch_val_idx
                    end_idx = min(start_idx + batch_size, num_val)
                    batch_size_val = end_idx - start_idx
                    points_batch = data_val[start_idx:end_idx, ...]
                    points_num_batch = data_num_val[start_idx:end_idx, ...]
                    labels_batch = label_val[start_idx:end_idx, ...]
                    weights_batch = label_weights_np[labels_batch]

                    xforms_np, rotations_np = pf.get_xforms(
                        batch_size_val,
                        rotation_range=rotation_range_val,
                        scaling_range=scaling_range_val,
                        order=setting.rotation_order)
                    sess.run(
                        [
                            loss_mean_update_op, t_1_acc_update_op,
                            t_1_per_class_acc_update_op
                        ],
                        feed_dict={
                            pts_fts:
                            points_batch,
                            indices:
                            pf.get_indices(batch_size_val, sample_num,
                                           points_num_batch),
                            xforms:
                            xforms_np,
                            rotations:
                            rotations_np,
                            jitter_range:
                            np.array([jitter_val]),
                            labels_seg:
                            labels_batch,
                            labels_weights:
                            weights_batch,
                            is_training:
                            False,
                        })

                loss_val, t_1_acc_val, t_1_per_class_acc_val, summaries_val = sess.run(
                    [
                        loss_mean_op, t_1_acc_op, t_1_per_class_acc_op,
                        summaries_val_op
                    ])
                summary_writer.add_summary(summaries_val, batch_idx_train)
                tqdm.write(
                    '{}-[Val  ]-Average:      Loss: {:.4f}  T-1 Acc: {:.4f}  T-1 mAcc: {:.4f}'
                    .format(datetime.now(), loss_val, t_1_acc_val,
                            t_1_per_class_acc_val))
                sys.stdout.flush()
                ######################################################################

            ######################################################################
            # Training
            start_idx = (batch_size * batch_idx_train) % num_train
            end_idx = min(start_idx + batch_size, num_train)
            batch_size_train = end_idx - start_idx
            points_batch = data_train[start_idx:end_idx, ...]
            points_num_batch = data_num_train[start_idx:end_idx, ...]
            labels_batch = label_train[start_idx:end_idx, ...]
            weights_batch = label_weights_np[labels_batch]

            if start_idx + batch_size_train == num_train:
                # End of the current training file reached: move to the next
                # h5 list (if any) and reshuffle for the next pass.
                if is_list_of_h5_list:
                    filelist_train_prev = seg_list[(seg_list_idx - 1) %
                                                   len(seg_list)]
                    filelist_train = seg_list[seg_list_idx % len(seg_list)]
                    if filelist_train != filelist_train_prev:
                        data_train, _, data_num_train, label_train, _ = data_utils.load_seg(
                            filelist_train)
                        num_train = data_train.shape[0]
                    seg_list_idx = seg_list_idx + 1
                data_train, data_num_train, label_train = \
                    data_utils.grouped_shuffle([data_train, data_num_train, label_train])

            # Randomly vary the number of sampled points per batch. The clip
            # bounds may be floats, so cast the final count back to an int.
            offset = int(
                random.gauss(0, sample_num * setting.sample_num_variance))
            offset = max(offset, -sample_num * setting.sample_num_clip)
            offset = min(offset, sample_num * setting.sample_num_clip)
            sample_num_train = int(sample_num + offset)
            xforms_np, rotations_np = pf.get_xforms(
                batch_size_train,
                rotation_range=rotation_range,
                scaling_range=scaling_range,
                order=setting.rotation_order)
            # Reset before every step, so the train summaries below reflect
            # this single batch only.
            sess.run(reset_metrics_op)
            sess.run(
                [
                    last_layer_train_op, loss_mean_update_op,
                    t_1_acc_update_op, t_1_per_class_acc_update_op
                ],
                feed_dict={
                    pts_fts:
                    points_batch,
                    indices:
                    pf.get_indices(batch_size_train, sample_num_train,
                                   points_num_batch),
                    xforms:
                    xforms_np,
                    rotations:
                    rotations_np,
                    jitter_range:
                    np.array([jitter]),
                    labels_seg:
                    labels_batch,
                    labels_weights:
                    weights_batch,
                    is_training:
                    True,
                })
            if batch_idx_train % 10 == 0:
                summaries = sess.run(summaries_op)
                summary_writer.add_summary(summaries, batch_idx_train)
                sys.stdout.flush()
            ######################################################################
        # Flush pending events to disk and release the event file.
        summary_writer.close()
        print('{}-Done!'.format(datetime.now()))