Example 1
 def test_local_variable(self):
   with self.test_session() as sess:
     self.assertEquals([], tf.local_variables())
     value0 = 42
     tf.contrib.framework.local_variable(value0)
     value1 = 43
     tf.contrib.framework.local_variable(value1)
     variables = tf.local_variables()
     self.assertEquals(2, len(variables))
     self.assertRaises(tf.OpError, sess.run, variables)
     tf.variables_initializer(variables).run()
     self.assertAllEqual(set([value0, value1]), set(sess.run(variables)))
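tf.contrib.framework.local_variable is, as far as I can tell, just a convenience wrapper that creates a non-trainable variable in the LOCAL_VARIABLES collection; a minimal sketch of the equivalent with plain tf.Variable (assuming TF 1.x; compare Example 27 below, which builds its local variables the same way):

import tensorflow as tf

value0 = 42
v = tf.Variable(value0, trainable=False, name='my_local',
                collections=[tf.GraphKeys.LOCAL_VARIABLES])
assert v in tf.local_variables()       # registered as a local variable
assert v not in tf.global_variables()  # and kept out of the global collection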
Example 2
def count_variables_by_type(variables=None):
  """Returns a dict mapping dtypes to number of variables and scalars.

  Args:
    variables: iterable of `tf.Variable`s, or None. If None is passed, then all
      global and local variables in the current graph are used.

  Returns:
    A dict mapping tf.dtype keys to a dict containing the keys 'num_scalars' and
      'num_variables'.
  """
  if variables is None:
    variables = tf.global_variables() + tf.local_variables()
  unique_types = set(v.dtype.base_dtype for v in variables)
  results_dict = {}
  for dtype in unique_types:
    if dtype == tf.string:
      tf.logging.warning(
          "NB: string Variables present. The memory usage for these  Variables "
          "will not be accurately computed as it depends on the exact strings "
          "stored in a particular session.")
    vars_of_type = [v for v in variables if v.dtype.base_dtype == dtype]
    num_scalars = sum(v.shape.num_elements() for v in vars_of_type)
    results_dict[dtype] = {
        "num_variables": len(vars_of_type),
        "num_scalars": num_scalars
    }
  return results_dict
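A minimal usage sketch for count_variables_by_type, assuming TF 1.x; the variable names below are purely illustrative:

import tensorflow as tf

tf.reset_default_graph()
tf.Variable(tf.zeros([3, 4]), name='weights')              # 12 float32 scalars
tf.Variable(tf.zeros([7], dtype=tf.int32), name='counts')  # 7 int32 scalars

for dtype, stats in count_variables_by_type().items():
    # e.g. float32 -> {'num_variables': 1, 'num_scalars': 12}
    print(dtype.name, stats)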
Example 3
 def get_post_init_ops():
     """
     Copy values of variables on GPU 0 to other GPUs.
     """
     # literally all variables, because it's better to sync optimizer-internal variables as well
     all_vars = tf.global_variables() + tf.local_variables()
     var_by_name = dict([(v.name, v) for v in all_vars])
     post_init_ops = []
     for v in all_vars:
         if not v.name.startswith('tower'):
             continue
         if v.name.startswith('tower0'):
             logger.warn("[SyncMultiGPUReplicatedBuilder] variable "
                         "{} has prefix 'tower0', this is unexpected.".format(v.name))
             continue        # TODO some vars (EMA) may still startswith tower0
         # in this trainer, the master name doesn't have the towerx/ prefix
         split_name = v.name.split('/')
         prefix = split_name[0]
         realname = '/'.join(split_name[1:])
         if prefix in realname:
             logger.error("[SyncMultiGPUReplicatedBuilder] variable "
                          "{} has its prefix {} appears multiple times in its name!".format(v.name, prefix))
         copy_from = var_by_name.get(realname)
         assert copy_from is not None, var_by_name.keys()
         post_init_ops.append(v.assign(copy_from.read_value()))
     logger.info(
         "'sync_variables_from_main_tower' includes {} operations.".format(len(post_init_ops)))
     return tf.group(*post_init_ops, name='sync_variables_from_main_tower')
Example 4
 def testNotInLocalVariables(self):
   with self.test_session():
     with tf.variable_scope('A'):
       a = tf.contrib.framework.model_variable('a', [5])
       self.assertTrue(a in tf.global_variables())
       self.assertTrue(a in tf.get_collection(tf.GraphKeys.MODEL_VARIABLES))
       self.assertFalse(a in tf.local_variables())
Example 5
    def testVariables(self):
        with self.test_session() as s:

            # make some variables
            _ = [tf.Variable([1, 2, 3], dtype=tf.float32), tf.Variable([1, 2, 3], dtype=tf.int32)]
            s.run(tf.initialize_all_variables())
            _ = [v.name for v in tf.all_variables()]
            _ = [v.name for v in tf.local_variables()]
Example 6
def cnn_train(config, data_len, embed, pf_r1, pf_r2):
    config.data_len = data_len
    tf.reset_default_graph()
    with tf.Session() as session:
        # build model
        with tf.variable_scope("cnn_ch", reuse=None):
            m_train = ch_model(config)
        with tf.variable_scope("cnn_ch", reuse=True):
            m_valid = ch_model(config)

        doc_datas, pf_r1s, pf_r2s, labels = read_batch(config.csv_file, config, True)
        doc_datas_v, pf_r1s_V, pf_r2s_v, labels_v = read_batch(config.csv_file, config, False)


        for item in tf.all_variables():
            print "var: ", item
        for item in tf.local_variables():
            print "local:", item

        loss, _ = m_train.inference(doc_datas, pf_r1s, pf_r2s, labels)
        loss_v, acc_v = m_valid.inference(doc_datas_v, pf_r1s_V, pf_r2s_v, labels_v)
        train_op = m_train.train(loss)

        tf.initialize_all_variables().run()
        tf.initialize_local_variables().run()
        m_train.assign_word_embed(session, embed)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord, sess=session)
        
        epoch = 0
        step = 0
        min_cost = sys.maxint
        try:
            while not coord.should_stop():
                _, f_l = session.run([train_op, loss])
                step += 1
                if step == config.data_len // config.batch_size:
                    cost = 0.0
                    acc = 0.0
                    for i in range(step):
                        v_l, acc_l = session.run([loss_v, acc_v])
                        cost += v_l
                        acc += acc_l
                    cost /= step
                    acc /= step
                    if cost < min_cost:
                        min_cost = cost
                        print "save model as cost:", cost
                        m_train.saver.save(session, config.model_path)
                    print "epoch: ", epoch, "loss: ", cost, "acc: ", acc, "step:", step
                    step = 0
                    epoch += 1
        except tf.errors.OutOfRangeError:
            print("Done training")
        finally:
            coord.request_stop()
        coord.join(threads)
Example 7
 def setup_graph(self):
     """ Will setup the assign operator for that variable. """
     all_vars = tf.global_variables() + tf.local_variables()
     for v in all_vars:
         if v.name == self.var_name:
             self.var = v
             break
     else:
         raise ValueError("{} is not a variable in the graph!".format(self.var_name))
Example 8
 def testCreateVariable(self):
   with self.test_session():
     with tf.variable_scope('A'):
       a = tf.contrib.framework.variable('a', [5])
       self.assertEquals(a.op.name, 'A/a')
       self.assertListEqual(a.get_shape().as_list(), [5])
       self.assertTrue(a in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))
       self.assertFalse(a in tf.get_collection(tf.GraphKeys.MODEL_VARIABLES))
       self.assertFalse(a in tf.local_variables())
Example 9
  def testUsage(self, custom_getter_fn):
    # Create a module with no custom getters.
    linear = snt.Linear(10)

    # Create a module within the scope of an 'override args' custom getter.
    local_custom_getter = custom_getter_fn(
        collections=[tf.GraphKeys.LOCAL_VARIABLES])
    with tf.variable_scope("", custom_getter=local_custom_getter):
      local_linear = snt.Linear(10)

    # Connect both modules to the graph, creating their variables.
    inputs = tf.placeholder(dtype=tf.float32, shape=(7, 11))
    linear(inputs)
    local_linear(inputs)

    self.assertIn(linear.w, tf.global_variables())
    self.assertNotIn(linear.w, tf.local_variables())
    self.assertIn(local_linear.w, tf.local_variables())
    self.assertNotIn(local_linear.w, tf.global_variables())
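snt.custom_getters.override_args (used explicitly in the next example) conceptually forces the given keyword arguments onto every tf.get_variable call made inside the scope; a rough sketch of the idea, not Sonnet's actual implementation:

def override_args(**overrides):
  def custom_getter(getter, *args, **kwargs):
    # Force e.g. collections=[tf.GraphKeys.LOCAL_VARIABLES], even when the
    # caller passed an explicit value (see testExplicitArgOverridden below).
    kwargs.update(overrides)
    return getter(*args, **kwargs)
  return custom_getter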
Example 10
  def testExplicitArgOverridden(self):
    # Create a variable within the scope of an 'override args' custom getter.
    local_custom_getter = snt.custom_getters.override_args(
        collections=[tf.GraphKeys.LOCAL_VARIABLES])
    with tf.variable_scope("", custom_getter=local_custom_getter):
      # Explicitly specify an arg that disagrees with the custom getter.
      v = tf.get_variable("v", (), collections=[tf.GraphKeys.GLOBAL_VARIABLES])

    # The custom getter should win.
    self.assertIn(v, tf.local_variables())
    self.assertNotIn(v, tf.global_variables())
Example 11
def _initialize_variables():
    """Utility to initialize uninitialized variables on the fly.
    """
    variables = tf.local_variables()
    uninitialized_variables = []
    for v in variables:
        if not hasattr(v, '_keras_initialized') or not v._keras_initialized:
            uninitialized_variables.append(v)
            v._keras_initialized = True
    if uninitialized_variables:
        sess = K.get_session()
        sess.run(tf.variables_initializer(uninitialized_variables))
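A sketch of where this helper is typically useful, assuming Keras on the TF 1.x backend (K = keras.backend): tf.metrics.* creates its accumulators as local variables, which _initialize_variables then initializes on the fly:

import tensorflow as tf
from keras import backend as K  # assumption: Keras with the TF 1.x backend

labels = tf.constant([1, 0, 1])
predictions = tf.constant([1, 1, 1])
# tf.metrics.* keeps its accumulators ('total', 'count') as *local* variables.
acc_value, acc_update = tf.metrics.accuracy(labels, predictions)

_initialize_variables()     # initializes only the not-yet-initialized locals
sess = K.get_session()
sess.run(acc_update)
print(sess.run(acc_value))  # 2/3 of the predictions match the labels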
Example 12
def log_variables(variables=None):
  """Logs variable information.

  This function logs the name, shape, type, collections, and device for either
  all variables or a given iterable of variables.

  Args:
    variables: iterable of variables; if not provided, then all variables
        (in the default graph) are logged.
  """
  if variables is None:
    variables = tf.global_variables() + tf.local_variables()
  for row in format_variables(variables, join_lines=False):
    tf.logging.info(row)
Example 13
  def testWithNested(self, custom_getter_fn):
    # Create a module with a custom getter, within the scope of an
    # 'override args' custom getter.
    local_custom_getter = custom_getter_fn(
        collections=[tf.GraphKeys.LOCAL_VARIABLES])
    with tf.variable_scope("", custom_getter=local_custom_getter):
      local_linear = snt.Linear(10, custom_getter=_suffix_custom_getter)

    # Connect the module to the graph, creating its variables.
    inputs = tf.placeholder(dtype=tf.float32, shape=(7, 11))
    local_linear(inputs)

    # Both custom getters should be effective.
    self.assertIn(local_linear.w, tf.local_variables())
    self.assertNotIn(local_linear.w, tf.global_variables())
    self.assertEqual("linear/w_test", local_linear.w.op.name)
Example 14
def log_variables(variables=None):
  """Logs variable information.

  This function logs the name, shape, type, collections, and device for either
  all variables or a given iterable of variables. In the "Device" columns,
  the nature of the variable (legacy or resource (for ResourceVariables)) is
  also specified in parentheses.

  Args:
    variables: iterable of variables; if not provided, then all variables
        (in the default graph) are logged.
  """
  if variables is None:
    variables = tf.global_variables() + tf.local_variables()
  for row in format_variables(variables, join_lines=False):
    tf.logging.info(row)
Example 15
def guarantee_initialized_variables(session, variables=None):
    """Guarantee that all the specified variables are initialized.

    If a variable is already initialized, leave it alone. Otherwise, initialize it.

    If no variables are specified, checks all variables in the default graph.

    Args:
        variables (list[tf.Variable])
    """
    name_to_var = {v.op.name: v for v in tf.global_variables() + tf.local_variables()}
    uninitialized_variables = list(name_to_var[name] for name in
                                   session.run(tf.report_uninitialized_variables(variables)))
    init_op = tf.variables_initializer(uninitialized_variables)
    session.run(init_op)
    return uninitialized_variables
Example 16
  def get_post_init_ops(self):
    # Copy the initialized values of variables on the parameter server
    # to the local copies of those variables.

    local_vars = tf.local_variables()
    local_var_by_name = dict(
        [(self._strip_port(v.name), v) for v in local_vars])
    post_init_ops = []
    for v in tf.global_variables():
      if v.name.startswith(PS_SHADOW_VAR_PREFIX + '/v0/'):
        prefix = self._strip_port(
            v.name[len(PS_SHADOW_VAR_PREFIX + '/v0'):])
        for i in range(self.benchmark_cnn.num_gpus):
          name = 'v%s%s' % (i, prefix)
          if name in local_var_by_name:
            copy_to = local_var_by_name[name]
            post_init_ops.append(copy_to.assign(v.read_value()))
    return post_init_ops
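This example and the next both reference PS_SHADOW_VAR_PREFIX, a module-level constant from the same benchmark code; roughly (the exact value is an assumption here):

# Shadow copies of the tower variables kept on the parameter servers are
# assumed to live under names like '<PS_SHADOW_VAR_PREFIX>/v0/...'.
PS_SHADOW_VAR_PREFIX = 'ps_var'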
Example 17
 def savable_variables(self):
   """Returns a list/dict of savable variables to pass to tf.train.Saver."""
   params = {}
   for v in tf.global_variables():
     assert (v.name.startswith(PS_SHADOW_VAR_PREFIX + '/v0/') or
             v.name == 'global_step:0')
     # We store variables in the checkpoint with the shadow variable prefix
     # removed so we can evaluate checkpoints in non-distributed replicated
     # mode. The checkpoints can also be loaded for training in
     # distributed_replicated mode.
     name = self._strip_port(self._remove_shadow_var_prefix_if_present(v.name))
     params[name] = v
   for v in tf.local_variables():
     # Non-trainable variables, such as batch norm moving averages, do not have
     # corresponding global shadow variables, so we add them here. Trainable
     # local variables have corresponding global shadow variables, which were
     # added in the global variable loop above.
     if v.name.startswith('v0/') and v not in tf.trainable_variables():
       params[self._strip_port(v.name)] = v
   return params
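_strip_port and _remove_shadow_var_prefix_if_present are not shown above; minimal sketches of what they presumably do (compare strip_port in Example 22):

def _strip_port(self, name):
  # 'v0/cg/conv0/kernel:0' -> 'v0/cg/conv0/kernel'
  return name[:-2] if name.endswith(':0') else name

def _remove_shadow_var_prefix_if_present(self, name):
  # '<PS_SHADOW_VAR_PREFIX>/v0/cg/conv0/kernel:0' -> 'v0/cg/conv0/kernel:0'
  prefix = PS_SHADOW_VAR_PREFIX + '/'
  return name[len(prefix):] if name.startswith(prefix) else name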
Example 18
  def testGetAllLocalVariables(self):
    def local_custom_getter(getter, *args, **kwargs):
      kwargs["trainable"] = False
      if "collections" in kwargs and kwargs["collections"] is not None:
        kwargs["collections"] += [tf.GraphKeys.LOCAL_VARIABLES]
      else:
        kwargs["collections"] = [tf.GraphKeys.LOCAL_VARIABLES]
      return getter(*args, **kwargs)

    inputs = tf.placeholder(tf.float32, [10, 10])
    # Create a new ModuleWithSubmodules that uses all local variables
    with tf.variable_scope("", custom_getter=local_custom_getter):
      submodule_a = SimpleModule(name="simple_submodule")
      submodule_b = ComplexModule(name="complex_submodule")
      local_module = ModuleWithSubmodules(
          submodule_a=submodule_a, submodule_b=submodule_b)
    local_module(inputs)  # pylint: disable=not-callable

    self.assertEqual(
        0,
        len(local_module.get_all_variables()))
    self.assertEqual(0, len(tf.all_variables()))
    self.assertEqual(12, len(tf.local_variables()))

    all_variables = local_module.get_all_variables(
        collection=tf.GraphKeys.LOCAL_VARIABLES)
    all_variable_names = sorted([str(v.name) for v in all_variables])
    self.assertEqual([
        "complex_submodule/linear_1/b:0",
        "complex_submodule/linear_1/w:0",
        "complex_submodule/linear_2/b:0",
        "complex_submodule/linear_2/w:0",
        "module_with_submodules/complex_build/linear_1/b:0",
        "module_with_submodules/complex_build/linear_1/w:0",
        "module_with_submodules/complex_build/linear_2/b:0",
        "module_with_submodules/complex_build/linear_2/w:0",
        "module_with_submodules/simple_build/b:0",
        "module_with_submodules/simple_build/w:0",
        "simple_submodule/b:0",
        "simple_submodule/w:0",
    ], all_variable_names)
Example 19
def cnn_test(config, data_len):
    config.data_len = data_len
    tf.reset_default_graph()
    config.max_epoch = 1
    with tf.Session() as session:
        # build model
        with tf.variable_scope("cnn_ch", reuse=None):
            m_valid = ch_model(config)

        doc_datas_v, pf_r1s_V, pf_r2s_v, labels_v = read_batch(config.csv_file, config, False)


        loss_v, acc_v = m_valid.inference(doc_datas_v, pf_r1s_V, pf_r2s_v, labels_v)
        m_valid.saver.restore(session, config.model_path)

        for item in tf.all_variables():
            print "var:", item
        for item in tf.local_variables():
            print "local:", item
        tf.initialize_local_variables().run()


        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord, sess=session)
        step = 0
        cost = 0.0
        acc = 0.0
        try:
            while not coord.should_stop():
                v_l, acc_l = session.run([loss_v, acc_v])
                cost += v_l
                acc += acc_l
                step += 1
        except tf.errors.OutOfRangeError:
            cost /= step
            acc /= step
            print "loss: ", cost, "acc: ", acc
            print("Done testing")
        finally:
            coord.request_stop()
        coord.join(threads)
Example 20
    def get_post_init_ops():
        """
        Copy values of variables on GPU 0 to other GPUs.
        """
        # literally all variables, because it's better to sync optimizer-internal variables as well
        all_vars = tf.global_variables() + tf.local_variables()
        var_by_name = dict([(v.name, v) for v in all_vars])
        trainable_names = set([x.name for x in tf.trainable_variables()])
        post_init_ops = []

        def log_failure(name, reason):
            if name in trainable_names:
                msg = "This variable is trainable, so this is probably a fatal error."
            else:
                msg = "This variable is non-trainable. Ignore this warning if you know it's OK to leave it out-of-sync."
            logger.warn("[ReplicatedTrainer] Do not know how to sync variable '{}' across GPUs. "
                        "Reason: {} ".format(name, reason) + msg)

        for v in all_vars:
            if not v.name.startswith('tower'):
                continue
            if v.name.startswith('tower0'):
                # in this trainer, the master name doesn't have the towerx/ prefix
                log_failure(v.name, "Name should not have prefix 'tower0' in this trainer!")
                continue        # TODO some vars (EMA) may still startswith tower0

            split_name = v.name.split('/')
            prefix = split_name[0]
            realname = '/'.join(split_name[1:])
            if prefix in realname:
                log_failure(v.name, "Prefix {} appears multiple times in its name!".format(prefix))
                continue
            copy_from = var_by_name.get(realname)
            if copy_from is not None:
                post_init_ops.append(v.assign(copy_from.read_value()))
            else:
                log_failure(v.name, "Cannot find {} in the graph!".format(realname))
        logger.info(
            "'sync_variables_from_main_tower' includes {} operations.".format(len(post_init_ops)))
        return tf.group(*post_init_ops, name='sync_variables_from_main_tower')
Example 21
def create_global_variables(local_optimizer_vars = []):
	"""Creates global variables for local variables on the graph.
	Skips local variables that are created for
	local optimization.

	Returns dictionaries for local-to-global and global-to-local
	variable mappings.
	"""
	local_to_global = {}
	global_to_local = {}
	with tf.device('/job:ps/task:0'):
		for v in tf.local_variables():
			if v not in local_optimizer_vars:
				v_g = tf.get_variable('g/'+v.op.name,
					shape = v.shape,
					dtype = v.dtype,
					trainable=True,
					collections=[tf.GraphKeys.GLOBAL_VARIABLES,
								tf.GraphKeys.TRAINABLE_VARIABLES])
				local_to_global[v] = v_g
				global_to_local[v_g] = v
	return local_to_global,global_to_local
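Example 27 below calls assign_global_to_local and assign_local_to_global, which are not shown in this listing; a minimal sketch of what they might look like, given the mappings returned above (the exact semantics are an assumption):

def assign_global_to_local(global_to_local):
    """Group of ops that copy each global (ps) variable into its local mirror."""
    return tf.group(*[local_v.assign(global_v)
                      for global_v, local_v in global_to_local.items()])

def assign_local_to_global(local_to_global):
    """Group of ops that copy each local variable into its global (ps) mirror."""
    return tf.group(*[global_v.assign(local_v)
                      for local_v, global_v in local_to_global.items()])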
Example 22
 def _get_initial_sync_op(self):
     """
      Get the op to copy-initialize all local variables from PS.
     """
     def strip_port(s):
         if s.endswith(':0'):
             return s[:-2]
         return s
     local_vars = tf.local_variables()
     local_var_by_name = dict([(strip_port(v.name), v) for v in local_vars])
     ops = []
     nr_shadow_vars = len(self._shadow_vars)
     for v in self._shadow_vars:
         vname = strip_port(v.name)
         for i in range(self.nr_gpu):
             name = 'tower%s/%s' % (i, vname)
             assert name in local_var_by_name, \
                 "Shadow variable {} doesn't match a corresponding local variable!".format(v.name)
             copy_to = local_var_by_name[name]
             # logger.info("{} -> {}".format(v.name, copy_to.name))
             ops.append(copy_to.assign(v.read_value()))
     return tf.group(*ops, name='sync_{}_variables_from_ps'.format(nr_shadow_vars))
Example 23
    def train(self, mnist, expected_steps=1000):
        if self.train_mode is False:
            raise Exception("Sorry I can't train it...")
        if self.finetune is True:
            to_do_var_list = list()
            var_list = tf.global_variables()+tf.local_variables()
            for one_var in var_list:
                if self.sess.run(tf.is_variable_initialized(one_var)):
                    pass
                else:
                    to_do_var_list.append(one_var)
            self.sess.run(tf.variables_initializer(to_do_var_list))
        else:
            self.sess.run(tf.global_variables_initializer())

        for i in range(expected_steps):
            batch = mnist.train.next_batch(50)
            if i % 100 == 0:
                train_accuracy = self.accuracy.eval(session=self.sess, feed_dict={self.x: batch[0], self.y_: batch[1], self.keep_prob: 1.0})
                print("step %d, training accuracy %g" % (i, train_accuracy))
            self.train_step.run(session=self.sess, feed_dict={self.x: batch[0], self.y_: batch[1], self.keep_prob: 0.5})

        print("test accuracy %g" % self.accuracy.eval(session=self.sess, feed_dict={self.x: mnist.test.images, self.y_: mnist.test.labels, self.keep_prob: 1.0}))
Example 24
def main():
    MODEL_DICT = {}
    MODEL_DICT['config'] = Config()
    for name in MODEL_DICT['config'].model_names:
        MODEL_DICT[name] = Model(config=MODEL_DICT['config'], name=name)

    aug_info_pl = tf.placeholder(dtype=tf.string, name='aug_info_pl')
    aug_info_summary = tf.summary.text('aug_info_summary', aug_info_pl)

    os.environ['CUDA_VISIBLE_DEVICES'] = str(MODEL_DICT['config'].gpu_id)
    with tf.Session(config=MODEL_DICT['config'].gpu_config) as sess:
        # summary writer
        summary_writer_dict = {}
        for model_name in MODEL_DICT['config'].model_names:
            summary_writer_dict[model_name] = tf.summary.FileWriter(
                os.path.join(MODEL_DICT['config'].tb_dir, model_name))

        aug_info = []
        aug_info.append('stft -> learnable filter - > hvqt + dense')

        if MODEL_DICT['config'].train_or_inference.inference is not None:
            aug_info.append('inference with {}'.format(
                MODEL_DICT['config'].train_or_inference.inference))
        elif MODEL_DICT['config'].train_or_inference.from_saved is not None:
            aug_info.append('continue training from {}'.format(
                MODEL_DICT['config'].train_or_inference.from_saved))

        if MODEL_DICT['config'].train_or_inference.inference is None:
            _model_prefix = MODEL_DICT[
                'config'].train_or_inference.model_prefix
            if _model_prefix is not None:
                aug_info.append('model prefix {}'.format(_model_prefix))
            else:
                aug_info.append('model will not be saved')

        aug_info.append('tb dir - {}'.format(MODEL_DICT['config'].tb_dir))
        aug_info.append('debug mode - {}'.format(
            MODEL_DICT['config'].debug_mode))
        aug_info.append('snippet length - {}'.format(
            MODEL_DICT['config'].snippet_len))
        aug_info.append('batch size - 1')
        aug_info.append('num of batches per epoch - {}'.format(
            MODEL_DICT['config'].batches_per_epoch))
        aug_info.append('num of epochs - {}'.format(
            MODEL_DICT['config'].num_epochs))
        aug_info.append('training start time - {}'.format(
            datetime.datetime.now()))
        aug_info = '\n\n'.join(aug_info)
        logging.info(aug_info)
        summary_writer_dict[MODEL_DICT['config'].model_names[0]].add_summary(
            sess.run(aug_info_summary, feed_dict={aug_info_pl: aug_info}))

        logging.info('local vars -')
        for idx, var in enumerate(tf.local_variables()):
            logging.info('{}\t{}'.format(idx, var.op.name))

        logging.info('trainable vars -')
        for idx, var in enumerate(
                tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)):
            logging.info('{}\t{}\t{}'.format(idx, var.op.name, var.shape))

        OP_DICT = {}
        for model_name in MODEL_DICT['config'].model_names:
            m = MODEL_DICT[model_name]

            if m.is_training:
                tmp = dict(batch=[m.training_op, m.stats['update_op']],
                           epoch=m.tb_proto)
            else:
                tmp = dict(batch=dict(rec_idx=m.batch['rec_idx'],
                                      update_op=m.stats['update_op']),
                           epoch=m.tb_proto)

            OP_DICT[model_name] = tmp

        def test_or_validate_fn(valid_or_test, global_step=None):
            assert valid_or_test in MODEL_DICT[
                'config'].model_names and 'training' not in valid_or_test

            ops_per_batch = OP_DICT[valid_or_test]['batch']
            ops_per_epoch = OP_DICT[valid_or_test]['epoch']

            batch_idx = 0
            _dataset_test = MODEL_DICT[valid_or_test].dataset
            total_num_snippets = sum(
                len(rec_dict['split_list']) for rec_dict in _dataset_test)
            num_recs = len(_dataset_test)

            for rec_idx in xrange(num_recs):
                rec_dict = _dataset_test[rec_idx]
                split_list = rec_dict['split_list']
                num_snippets = len(split_list)
                num_frames = len(rec_dict['sg'])
                assert num_frames == MODEL_DICT[
                    valid_or_test].num_frames_vector[rec_idx]
                for snippet_idx in xrange(num_snippets):
                    logging.debug('batch {}/{}'.format(batch_idx + 1,
                                                       total_num_snippets))
                    tmp = sess.run(ops_per_batch)
                    _rec_idx = tmp['rec_idx'][0]
                    assert _rec_idx == rec_idx
                    batch_idx += 1
            summary_writer_dict[valid_or_test].add_summary(
                sess.run(ops_per_epoch), global_step)

        def check_all_global_vars_initialized_fn():
            tmp = sess.run(
                tf.report_uninitialized_variables(tf.global_variables()))
            assert tmp.size == 0

        if MODEL_DICT['config'].train_or_inference.inference is not None:
            save_path = MODEL_DICT['config'].train_or_inference.inference
            tf.train.Saver().restore(sess, save_path)
            check_all_global_vars_initialized_fn()

            sess.run(tf.initializers.variables(tf.local_variables()))
            for model_name in MODEL_DICT['config'].model_names:
                if 'test' in model_name:
                    sess.run(MODEL_DICT[model_name].
                             reinitializable_iter_for_dataset.initializer)
                    logging.info('do inference on {}'.format(model_name))
                    test_or_validate_fn(model_name)

        elif MODEL_DICT['config'].train_or_inference.from_saved is not None:
            save_path = MODEL_DICT['config'].train_or_inference.from_saved
            tf.train.Saver().restore(sess, save_path)
            check_all_global_vars_initialized_fn()

            logging.info('reproduce results ...')
            sess.run(tf.initializers.variables(tf.local_variables()))
            for model_name in MODEL_DICT['config'].model_names:
                if 'training' not in model_name:
                    sess.run(MODEL_DICT[model_name].
                             reinitializable_iter_for_dataset.initializer)

            for model_name in MODEL_DICT['config'].model_names:
                if 'training' not in model_name:
                    logging.info(model_name)
                    test_or_validate_fn(model_name, 0)

        else:  # neither inference or from saved
            logging.info('train from scratch')
            sess.run(tf.initializers.variables(tf.global_variables()))
            check_all_global_vars_initialized_fn()

        if MODEL_DICT['config'].train_or_inference.inference is None:
            check_all_global_vars_initialized_fn()
            if MODEL_DICT[
                    'config'].train_or_inference.model_prefix is not None:
                assert 'model_saver' not in MODEL_DICT
                MODEL_DICT['model_saver'] = tf.train.Saver(max_to_keep=200)

            for training_valid_test_epoch_idx in xrange(
                    MODEL_DICT['config'].num_epochs):
                logging.info('\n\ncycle - {}/{}'.format(
                    training_valid_test_epoch_idx + 1,
                    MODEL_DICT['config'].num_epochs))

                sess.run(tf.initializers.variables(tf.local_variables()))

                # to enable prefetch
                for model_name in MODEL_DICT['config'].model_names:
                    if 'training' not in model_name:
                        sess.run(MODEL_DICT[model_name].
                                 reinitializable_iter_for_dataset.initializer)

                for training_valid_or_test in MODEL_DICT['config'].model_names:
                    logging.info(training_valid_or_test)

                    if 'training' in training_valid_or_test:
                        ops_per_batch = OP_DICT[training_valid_or_test][
                            'batch']
                        ops_per_epoch = OP_DICT[training_valid_or_test][
                            'epoch']
                        for batch_idx in xrange(
                                MODEL_DICT['config'].batches_per_epoch):
                            sess.run(ops_per_batch)
                            logging.debug('batch - {}/{}'.format(
                                batch_idx + 1,
                                MODEL_DICT['config'].batches_per_epoch))
                        summary_writer_dict[
                            training_valid_or_test].add_summary(
                                sess.run(ops_per_epoch),
                                training_valid_test_epoch_idx + 1)

                        if MODEL_DICT[
                                'config'].train_or_inference.model_prefix is not None:
                            save_path = MODEL_DICT['config'].train_or_inference.model_prefix + \
                                        '_' + 'epoch_{}_of_{}'.format(training_valid_test_epoch_idx + 1,
                                                                      MODEL_DICT['config'].num_epochs)
                            save_path = os.path.join('saved_model', save_path)
                            save_path = MODEL_DICT['model_saver'].save(
                                sess=sess,
                                save_path=save_path,
                                global_step=None,
                                write_meta_graph=False)
                            logging.info('model saved to {}'.format(save_path))

                    else:
                        test_or_validate_fn(training_valid_or_test,
                                            training_valid_test_epoch_idx + 1)

        msg = 'training end time - {}'.format(datetime.datetime.now())
        logging.info(msg)
        summary_writer_dict[MODEL_DICT['config'].model_names[0]].add_summary(
            sess.run(aug_info_summary, feed_dict={aug_info_pl: msg}))

        for training_valid_or_test in MODEL_DICT['config'].model_names:
            summary_writer_dict[training_valid_or_test].close()
Example 25
    def eval_metrics_host_call_fn(policy_output, value_output, pi_tensor, policy_cost,
                                  value_cost, l2_cost, combined_cost, step,
                                  est_mode=tf.estimator.ModeKeys.TRAIN):
        policy_entropy = -tf.reduce_mean(tf.reduce_sum(
            policy_output * tf.log(policy_output), axis=1))
        # pi_tensor is one_hot when generated from sgfs (for supervised learning)
        # and soft-max when using self-play records. argmax normalizes the two.
        policy_target_top_1 = tf.argmax(pi_tensor, axis=1)

        policy_output_in_top1 = tf.to_float(
            tf.nn.in_top_k(policy_output, policy_target_top_1, k=1))
        policy_output_in_top3 = tf.to_float(
            tf.nn.in_top_k(policy_output, policy_target_top_1, k=3))

        policy_top_1_confidence = tf.reduce_max(policy_output, axis=1)
        policy_target_top_1_confidence = tf.boolean_mask(
            policy_output,
            tf.one_hot(policy_target_top_1, tf.shape(policy_output)[1]))

        value_cost_normalized = value_cost / params['value_cost_weight']

        with tf.variable_scope("metrics"):
            metric_ops = {
                'policy_cost': tf.metrics.mean(policy_cost),
                'value_cost': tf.metrics.mean(value_cost),
                'value_cost_normalized': tf.metrics.mean(value_cost_normalized),
                'l2_cost': tf.metrics.mean(l2_cost),
                'policy_entropy': tf.metrics.mean(policy_entropy),
                'combined_cost': tf.metrics.mean(combined_cost),

                'policy_accuracy_top_1': tf.metrics.mean(policy_output_in_top1),
                'policy_accuracy_top_3': tf.metrics.mean(policy_output_in_top3),
                'policy_top_1_confidence': tf.metrics.mean(policy_top_1_confidence),
                'policy_target_top_1_confidence': tf.metrics.mean(
                    policy_target_top_1_confidence),
                'value_confidence': tf.metrics.mean(tf.abs(value_output)),
            }

        if est_mode == tf.estimator.ModeKeys.EVAL:
            return metric_ops

        # NOTE: global_step is rounded to a multiple of FLAGS.summary_steps.
        eval_step = tf.reduce_min(step)

        # Create summary ops so that they show up in SUMMARIES collection
        # That way, they get logged automatically during training
        summary_writer = summary.create_file_writer(FLAGS.work_dir)
        with summary_writer.as_default(), \
                summary.record_summaries_every_n_global_steps(
                    params['summary_steps'], eval_step):
            for metric_name, metric_op in metric_ops.items():
                summary.scalar(metric_name, metric_op[1], step=eval_step)

        # Reset metrics occasionally so that they are mean of recent batches.
        reset_op = tf.variables_initializer(tf.local_variables("metrics"))
        cond_reset_op = tf.cond(
            tf.equal(eval_step % params['summary_steps'], tf.to_int64(1)),
            lambda: reset_op,
            lambda: tf.no_op())

        return summary.all_summary_ops() + [cond_reset_op]
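The reset_op above works because tf.metrics.* keeps its accumulators in the LOCAL_VARIABLES collection under the enclosing scope, so tf.local_variables("metrics") returns exactly those accumulators; a standalone sketch of the same reset pattern (names are illustrative):

import tensorflow as tf

labels = tf.placeholder(tf.int64, [None])
predictions = tf.placeholder(tf.int64, [None])

with tf.variable_scope("metrics"):
    acc_value, acc_update = tf.metrics.accuracy(labels, predictions)

# Re-initializing the scoped local variables resets the running mean.
reset_op = tf.variables_initializer(tf.local_variables("metrics"))

with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    sess.run(acc_update, {labels: [1, 0], predictions: [1, 1]})
    print(sess.run(acc_value))  # accuracy over everything seen so far
    sess.run(reset_op)          # accumulators back to zero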
Example 26
def main(device, input_path_test, downsampling_fact, downsampling_mode, channels, data_format, label_id, weights, image_dir, checkpoint_dir, output_graph_file, tst_sz, loss_type, model, decoder, fs_type, batch, batchnorm, dtype, scale_factor, have_imsave):
    #init horovod
    comm_rank = 0
    comm_local_rank = 0
    comm_size = 1
    comm_local_size = 1

    #downsampling? recompute image dimensions
    image_height =  image_height_orig // downsampling_fact
    image_width = image_width_orig // downsampling_fact

    #session config
    sess_config=tf.ConfigProto(inter_op_parallelism_threads=2, #1
                               intra_op_parallelism_threads=33, #6
                               log_device_placement=False,
                               allow_soft_placement=True)
    sess_config.gpu_options.visible_device_list = str(comm_local_rank)
    sess_config.gpu_options.force_gpu_compatible = True

    #get data
    test_graph = tf.Graph()
    if comm_rank == 0:
        print("Loading data...")
    tst_data = load_data(input_path_test, shuffle=False, max_files=tst_sz, use_horovod=False)
    if comm_rank == 0:
        print("Shape of tst_data is {}".format(tst_data.shape[0]))
        print("done.")

    #print some stats
    if comm_rank==0:
        print("Num workers: {}".format(comm_size))
        print("Local batch size: {}".format(batch))
        if dtype == tf.float32:
            print("Precision: {}".format("FP32"))
        else:
            print("Precision: {}".format("FP16"))
        print("Decoder: {}".format(decoder))
        print("Batch normalization: {}".format(batchnorm))
        print("Channels: {}".format(channels))
        print("Loss type: {}".format(loss_type))
        print("Loss weights: {}".format(weights))
        print("Loss scale factor: {}".format(scale_factor))
        print("Num test samples: {}".format(tst_data.shape[0]))

    #compute epochs and stuff:
    if fs_type == "local":
        num_samples = tst_data.shape[0] // comm_local_size
    else:
        num_samples = tst_data.shape[0] // comm_size

    with test_graph.as_default():
        #create readers
        tst_reader = h5_input_reader(input_path_test, channels, weights, dtype, normalization_file="stats.h5", update_on_read=False, data_format=data_format, label_id=label_id)
        #create datasets
        if fs_type == "local":
            tst_dataset = create_dataset(tst_reader, tst_data, batch, 1, comm_local_size, comm_local_rank, dtype, shuffle=False)
        else:
            tst_dataset = create_dataset(tst_reader, tst_data, batch, 1, comm_size, comm_rank, dtype, shuffle=False)

        #create iterators
        handle = tf.placeholder(tf.string, shape=[], name="iterator-placeholder")
        iterator = tf.data.Iterator.from_string_handle(handle, (dtype, tf.int32, dtype, tf.string),
                                                       ((batch, len(channels), image_height_orig, image_width_orig) if data_format=="channels_first" else (batch, image_height_orig, image_width_orig, len(channels)),
                                                        (batch, image_height_orig, image_width_orig),
                                                        (batch, image_height_orig, image_width_orig),
                                                        (batch))
                                                       )
        next_elem = iterator.get_next()

        #if downsampling, do some preprocessing
        if downsampling_fact != 1:
            if downsampling_mode == "scale":
                rand_select = tf.cast(tf.one_hot(tf.random_uniform((batch, image_height, image_width), minval=0, maxval=downsampling_fact*downsampling_fact, dtype=tf.int32), depth=downsampling_fact*downsampling_fact, axis=-1), dtype=tf.int32)
                next_elem = (tf.layers.average_pooling2d(next_elem[0], downsampling_fact, downsampling_fact, 'valid', data_format), \
                            tf.reduce_max(tf.multiply(tf.image.extract_image_patches(tf.expand_dims(next_elem[1], axis=-1), \
                                                                              [1, downsampling_fact, downsampling_fact, 1], \
                                                                              [1, downsampling_fact, downsampling_fact, 1], \
                                                                              [1,1,1,1], 'VALID'), rand_select), axis=-1), \
                            tf.squeeze(tf.layers.average_pooling2d(tf.expand_dims(next_elem[2], axis=-1), downsampling_fact, downsampling_fact, 'valid', "channels_last"), axis=-1), \
                            next_elem[3])
        
            elif downsampling_mode == "center-crop":
                #some parameters
                length = 1./float(downsampling_fact)
                offset = length/2.
                boxes = [[ offset, offset, offset+length, offset+length ]]*batch
                box_ind = list(range(0,batch))
                crop_size = [image_height, image_width]
                
                #be careful with data order
                if data_format=="channels_first":
                    next_elem[0] = tf.transpose(next_elem[0], perm=[0,2,3,1])
                    
                #crop
                next_elem = (tf.image.crop_and_resize(next_elem[0], boxes, box_ind, crop_size, method='bilinear', extrapolation_value=0, name="data_cropping"), \
                             ensure_type(tf.squeeze(tf.image.crop_and_resize(tf.expand_dims(next_elem[1],axis=-1), boxes, box_ind, crop_size, method='nearest', extrapolation_value=0, name="label_cropping"), axis=-1), tf.int32), \
                             tf.squeeze(tf.image.crop_and_resize(tf.expand_dims(next_elem[2],axis=-1), boxes, box_ind, crop_size, method='bilinear', extrapolation_value=0, name="weight_cropping"), axis=-1), \
                             next_elem[3])
                
                #be careful with data order
                if data_format=="channels_first":
                    next_elem[0] = tf.transpose(next_elem[0], perm=[0,3,1,2])
                    
            else:
                raise ValueError("Error, downsampling mode {} not supported. Supported are [center-crop, scale]".format(downsampling_mode))

        #create init handles
        #tst
        tst_iterator = tst_dataset.make_initializable_iterator()
        tst_handle_string = tst_iterator.string_handle()
        tst_init_op = iterator.make_initializer(tst_dataset)

        #compute the input filter number based on number of channels used
        num_channels = len(channels)
        #set up model
        model = deeplab_v3_plus_generator(num_classes=3, output_stride=8,
                                          base_architecture=model,
                                          decoder=decoder,
                                          batchnorm=batchnorm,
                                          pre_trained_model=None,
                                          batch_norm_decay=None,
                                          data_format=data_format)

        logit, prediction = model(next_elem[0], True, dtype)

        #set up loss
        loss = None

        #cast the logits to fp32
        logit = ensure_type(logit, tf.float32)
        if loss_type == "weighted":
            #cast weights to FP32
            w_cast = ensure_type(next_elem[2], tf.float32)
            loss = tf.losses.sparse_softmax_cross_entropy(labels=next_elem[1],
                                                          logits=logit,
                                                          weights=w_cast,
                                                          reduction=tf.losses.Reduction.SUM)
            if scale_factor != 1.0:
                loss *= scale_factor

        elif loss_type == "weighted_mean":
            #cast weights to FP32
            w_cast = ensure_type(next_elem[2], tf.float32)
            loss = tf.losses.sparse_softmax_cross_entropy(labels=next_elem[1],
                                                          logits=logit,
                                                          weights=w_cast,
                                                          reduction=tf.losses.Reduction.SUM_BY_NONZERO_WEIGHTS)
            if scale_factor != 1.0:
                loss *= scale_factor

        elif loss_type == "focal":
            #one-hot-encode
            labels_one_hot = tf.contrib.layers.one_hot_encoding(next_elem[1], 3)
            #cast to FP32
            labels_one_hot = ensure_type(labels_one_hot, tf.float32)
            loss = focal_loss(onehot_labels=labels_one_hot, logits=logit, alpha=1., gamma=2.)

        else:
            raise ValueError("Error, loss type {} not supported.",format(loss_type))

        #set up streaming metrics
        iou_op, iou_update_op = tf.metrics.mean_iou(labels=next_elem[1],
                                                    predictions=tf.argmax(prediction, axis=3),
                                                    num_classes=3,
                                                    weights=None,
                                                    metrics_collections=None,
                                                    updates_collections=None,
                                                    name="iou_score")
        iou_reset_op = tf.variables_initializer([ i for i in tf.local_variables() if i.name.startswith('iou_score/') ])

        #initializers:
        init_op =  tf.global_variables_initializer()
        init_local_op = tf.local_variables_initializer()

        #create image dir if not exists
        if not os.path.isdir(image_dir):
            os.makedirs(image_dir)


        #start session
        with tf.Session(config=sess_config) as sess:
            #initialize
            sess.run([init_op, init_local_op])
            #restore from checkpoint:
            load_model(sess, tf.train.Saver(), checkpoint_dir)
            #create iterator handles
            tst_handle = sess.run(tst_handle_string)
            #init iterators
            sess.run(tst_init_op, feed_dict={handle: tst_handle})

            #remove training nodes
            if output_graph_file:
                print("Storing inference graph to {}.".format(output_graph_file))
                inference_graph_def = tf.graph_util.remove_training_nodes(sess.graph_def, protected_nodes=None)
                #save the inference graph
                with open(output_graph_file, 'wb') as ogf:
                    ogf.write(inference_graph_def.SerializeToString())

            #start inference
            eval_loss = 0.
            eval_steps = 0
            print("Starting evaluation on test set")
            while True:
                try:
                    #construct feed dict
                    _, tmp_loss, tst_model_predictions, tst_model_labels, tst_model_filenames = sess.run([iou_update_op,
                                                                                                          loss,
                                                                                                          prediction,
                                                                                                          next_elem[1],
                                                                                                          next_elem[3]],
                                                                                                          feed_dict={handle: tst_handle})
                    #print some images
                    if have_imsave:
                        imsave(image_dir+'/test_pred_estep'
                               +str(eval_steps)+'_rank'+str(comm_rank)+'.png', np.argmax(tst_model_predictions[0,...],axis=-1)*100)
                        imsave(image_dir+'/test_label_estep'
                               +str(eval_steps)+'_rank'+str(comm_rank)+'.png', tst_model_labels[0,...]*100)
                        imsave(image_dir+'/test_combined_estep'
                               +str(eval_steps)+'_rank'+str(comm_rank)+'.png', plot_colormap[tst_model_labels[0,...],np.argmax(tst_model_predictions[0,...],axis=-1)])
                    else:
                        np.savez(image_dir+'/test_estep'
                                 +str(eval_steps)+'_rank'+str(comm_rank)+'.npz', prediction=np.argmax(tst_model_predictions[...],axis=-1)*100,
                                                                                                 label=tst_model_labels[...]*100, filename=tst_model_filenames)

                    #update loss
                    eval_loss += tmp_loss
                    eval_steps += 1

                except tf.errors.OutOfRangeError:
                    eval_steps = np.max([eval_steps,1])
                    eval_loss /= eval_steps
                    print("COMPLETED: evaluation loss is {}".format(eval_loss))
                    iou_score = sess.run(iou_op)
                    print("COMPLETED: evaluation IoU is {}".format(iou_score))
                    break
Example 27
def main():
    # Configure
    config = tf.ConfigProto(log_device_placement=False)

    # Server Setup
    cluster_spec = {
        'ps': ['localhost:2222'],
        'worker': ['localhost:2223', 'localhost:2224']
    }
    n_pss = len(cluster_spec['ps'])  #the number of parameter servers
    n_workers = len(cluster_spec['worker'])  #the number of worker nodes
    cluster = tf.train.ClusterSpec(
        cluster_spec)  #allows this node know about all other nodes

    if FLAGS.job_name == 'ps':  #checks if parameter server
        server = tf.train.Server(cluster,
                                 job_name="ps",
                                 task_index=FLAGS.task_index,
                                 config=config)
        server.join()
    else:  #it must be a worker server
        is_chief = (FLAGS.task_index == 0)  #checks if this is the chief node
        server = tf.train.Server(cluster,
                                 job_name="worker",
                                 task_index=FLAGS.task_index,
                                 config=config)

        # Graph
        # Local operations
        with tf.device("/job:worker/replica:0/task:%d" % FLAGS.task_index):
            a = tf.Variable(tf.constant(0., shape=[2]),
                            dtype=tf.float32,
                            collections=[tf.GraphKeys.LOCAL_VARIABLES])
            b = tf.Variable(tf.constant(0., shape=[2]),
                            dtype=tf.float32,
                            collections=[tf.GraphKeys.LOCAL_VARIABLES])
            c = a + b

            local_step = tf.Variable(0,
                                     dtype=tf.int32,
                                     trainable=False,
                                     name='local_step',
                                     collections=['local_non_trainable'])
            lr = .0001
            loptimizer = tf.train.GradientDescentOptimizer(
                lr * FLAGS.task_index)  #local optimizer

            target = tf.constant(100., shape=[2], dtype=tf.float32)
            loss = tf.reduce_mean(tf.square(c - target))

            # DOWNPOUR
            update_window = 3  # T: communication window
            grad_list = [
            ]  # array to store the gradients through the communication window
            for t in range(update_window):
                if t != 0:
                    with tf.control_dependencies([
                            opt_local
                    ]):  #compute gradients only if the local opt was run
                        grads, varss = zip(*loptimizer.compute_gradients(
                            loss, var_list=tf.local_variables()))
                else:
                    grads, varss = zip(*loptimizer.compute_gradients(
                        loss, var_list=tf.local_variables()))
                grad_list.append(grads)  #add gradients to the list
                opt_local = loptimizer.apply_gradients(
                    zip(grads, varss),
                    global_step=local_step)  #update local parameters

            grads = tf.reduce_sum(
                grad_list, axis=0)  #sum updates before applying globally
            grads = tuple([grads[i] for i in range(len(varss))])

        with tf.device(
                tf.train.replica_device_setter(
                    ps_tasks=n_pss,
                    worker_device="/job:%s/task:%d" %
                    (FLAGS.job_name, FLAGS.task_index))):

            global_step = tf.Variable(0,
                                      dtype=tf.int32,
                                      trainable=False,
                                      name='global_step')

            # all workers use the same learning rate and it is decided on by the task 0
            # or maybe from the graph of the chief worker
            optimizer = tf.train.AdagradOptimizer(lr)  #global optimizer

            # create global variables and/or references
            local_to_global, global_to_local = create_global_variables()
            opt = optimizer.apply_gradients(
                zip(grads, [local_to_global[v] for v in varss]),
                global_step=global_step
            )  #apply the gradients to variables on ps

            # Pull params from global server
            with tf.control_dependencies([opt]):
                assign_locals = assign_global_to_local(global_to_local)

            # Grab global state before training so all workers have same initialization
            grab_global_init = assign_global_to_local(global_to_local)

            # Assigns local values to global ones for chief to execute
            assign_global = assign_local_to_global(local_to_global)

            # Init ops
            init = tf.global_variables_initializer()  # for global variables
            init_local = tf.variables_initializer(tf.local_variables() \
               +tf.get_collection('local_non_trainable'))#for local variables

        # Session
        stop_hook = tf.train.StopAtStepHook(last_step=60)
        hooks = [stop_hook]
        scaff = tf.train.Scaffold(init_op=init, local_init_op=[init_local])

        # Monitored Training Session
        sess = tf.train.MonitoredTrainingSession(master=server.target,
                                                 is_chief=is_chief,
                                                 config=config,
                                                 scaffold=scaff,
                                                 hooks=hooks,
                                                 save_checkpoint_secs=1,
                                                 checkpoint_dir='logdir')

        if is_chief:
            sess.run(assign_global)  #Assigns chief's initial values to ps
            time.sleep(
                10
            )  #grace period to wait on other workers before starting training

        # Train until hook stops session
        print('Starting training on worker %d' % FLAGS.task_index)
        sess.run(grab_global_init)
        while not sess.should_stop():
            _, _, r, gs, ls = sess.run(
                [opt, assign_locals, c, global_step, local_step])

            print(r, "global step: " + str(gs),
                  "worker: " + str(FLAGS.task_index), "local step: " + str(ls))

            time.sleep(1)  # so we can observe training
        print('Done', FLAGS.task_index)

        time.sleep(10)  #grace period to wait before closing session
        sess.close()
        print('Session from worker %d closed cleanly' % FLAGS.task_index)
Example 28
    def __init__(self, sess, args):
        ''' the main neural network model class '''
        #self.config = vars(args)
        self.x = tf.placeholder(tf.float32, [None, feature_dim], name="input")
        self.y_ = tf.placeholder(tf.float32, [None, output_dim], name="output")
        self.is_training = tf.placeholder(tf.bool)
        ## for the augmented data
        self.x1 = tf.placeholder(tf.float32, [None, feature_dim], name="input")
        self.class_label = tf.placeholder(tf.float32, [None],
                                          name="condition_checking")
        self.layer_sizes = [args.net1_h1, args.net1_h2]
        ## build the model
        self.y = am_util.build_model(
            self.x, self.layer_sizes, self.is_training, output_dim,
            None)  # reuse none so that the variables are created
        self.prob = tf.nn.softmax(self.y, name='prob')
        self.pred = tf.arg_max(self.prob, 1, name='pred')

        ##accuarcy
        self.correct_predictions = tf.equal(self.pred, tf.argmax(self.y_, 1))
        self.accuracy = tf.reduce_mean(tf.cast(self.correct_predictions,
                                               "float"),
                                       name="accuracy")
        #loss and optimizer
        self.loss_f = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(logits=self.y,
                                                    labels=self.y_))
        self.optimizer_f = tf.train.AdamOptimizer(
            args.lr, name="opt1").minimize(self.loss_f)

        ## build all the summaries and writers
        self.summaries = tf.summary.merge(self.get_summaries())
        self.train_summary_writer = tf.summary.FileWriter("%s/logs/train" %
                                                          output_dir,
                                                          sess.graph,
                                                          flush_secs=60)
        self.val_summary_writer = tf.summary.FileWriter("%s/logs/val" %
                                                        output_dir,
                                                        sess.graph,
                                                        flush_secs=60)
        self.test_summary_writer = tf.summary.FileWriter("%s/logs/test" %
                                                         output_dir,
                                                         sess.graph,
                                                         flush_secs=60)
        # build a saver object
        self.saver = tf.train.Saver(tf.global_variables() +
                                    tf.local_variables())

        # for augmented data
        self.y1 = am_util.build_model(
            self.x1, self.layer_sizes, self.is_training, output_dim, True
        )  # reuse=True so that the variables are shared with the previously built network
        self.cond = tf.reduce_sum(tf.squared_difference(self.y1, self.y), 1)
        self.row_index = tf.where(self.class_label > 0)
        self.y1_filterred = tf.squeeze(tf.gather(self.y1, self.row_index))
        self.y_filtered = tf.squeeze(tf.gather(self.y, self.row_index))
        self.is_empty = tf.equal(tf.size(self.row_index), 0)
        self.loss_y_y1 = output_dim * tf.reduce_mean(tf.squared_difference(
            self.y_filtered, self.y1_filterred),
                                                     name="loss_f_G")
        self.loss_y_y1_filtered = tf.cond(
            tf.cast(self.is_empty,
                    tf.bool), lambda: tf.constant(0, tf.float32), lambda: self.
            loss_y_y1)  # if no rows survive the filter, the loss is zero; this avoids NaN
        self.y__filtered = tf.squeeze(tf.gather(self.y_, self.row_index))
        self.loss_fx_y = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(logits=self.y1_filterred,
                                                    labels=self.y__filtered),
            name="filtered_reg")
        self.loss_fx_y_filtered = tf.cond(
            tf.cast(self.is_empty,
                    tf.bool), lambda: tf.constant(0, tf.float32), lambda: self.
            loss_fx_y)  # if no rows survive the filter, the loss is zero; this avoids NaN

        #final loss
        self.final_reg = tf.add(args.reg_param1 * self.loss_y_y1_filtered,
                                args.reg_param2 * self.loss_fx_y_filtered,
                                name="loss_final")
        self.loss_final = tf.add(self.final_reg,
                                 self.loss_f,
                                 name="loss_final")
        self.optimizer_final = tf.train.AdamOptimizer(
            args.lr, name="opt2").minimize(self.loss_final)
Example n. 29
0
 def testNotInLocalVariables(self):
   with self.test_session():
     with tf.variable_scope('A'):
       a = tf.contrib.framework.model_variable('a', [5])
       self.assertTrue(a in tf.global_variables())
       self.assertFalse(a in tf.local_variables())
def train(model_name, num_classes, input_shape, train_data, train_label,
          test_data, test_label):
    placeholders, ops = model(num_classes, input_shape)
    saver = tf.train.Saver()

    val_accs, tst_accs = [], []
    val_losses, tst_losses = [], []

    with tf.Session() as sess:
        sess.run(tf.local_variables_initializer())
        sess.run(tf.global_variables_initializer())

        num_iter = train_label.shape[0] * args.epochs

        for i in range(num_iter):

            # train op
            idx = np.random.randint(0, len(train_label), size=args.batch_size)
            sess.run(ops["train"],
                     feed_dict={
                         placeholders["input"]: train_data[idx],
                         placeholders["output"]: train_label[idx]
                     })

            if (i % 100 == 0):
                print("Iteration: {:6d}/{:6d}".format(i, num_iter))
                # reset the metrics
                stream_vars_valid = [
                    v for v in tf.local_variables() if 'metrics/' in v.name
                ]
                sess.run(tf.variables_initializer(stream_vars_valid))

                #validation accuracy
                val_acc, val_loss = sess.run(
                    [ops["cumulative_accuracy"], ops["loss"]],
                    feed_dict={
                        placeholders["input"]: train_data[idx],
                        placeholders["output"]: train_label[idx]
                    })
                print("Validation - "),
                print("accuracy: {:.6f}".format(val_acc)),
                print(", loss: {0}".format(val_loss))

                val_accs.append(val_acc)
                tst_accs.append(val_loss)

                #test accuracy
                idx = np.random.randint(0,
                                        len(test_label),
                                        size=args.batch_size)
                tst_acc, tst_loss = sess.run(
                    [ops["cumulative_accuracy"], ops["loss"]],
                    feed_dict={
                        placeholders["input"]: test_data[idx],
                        placeholders["output"]: test_label[idx]
                    })
                print("Test - "),
                print("accuracy: {:.6f}".format(tst_acc)),
                print(", loss: {0}".format(tst_loss))

                val_losses.append(tst_acc)
                tst_losses.append(tst_loss)

        # plot learning curves
        plt.subplot(211)
        plt.plot(np.array(val_accs), label='validation')
        plt.plot(np.array(tst_accs), label='test')
        plt.title('Accuracies')
        plt.legend()

        plt.subplot(212)
        plt.plot(np.array(val_losses), label='validation')
        plt.plot(np.array(tst_losses), label='test')
        plt.title('Losses')
        plt.legend()

        #plt.show()
        plt.savefig("accs_losses.png")

        # save model
        print("Finished training")
        saver.save(sess, model_name + "/model")
        print("Model saved")
    def build_graph(self, inputs):
        BaseVideoPredictionModel.build_graph(self, inputs)

        global_step = tf.train.get_or_create_global_step()
        # Capture the variables created from here until the train_op for the
        # saveable_variables. Note that if variables are being reused (e.g.
        # they were created by a previously built model), those variables won't
        # be captured here.
        original_global_variables = tf.global_variables()

        if self.num_gpus <= 1:  # cpu or 1 gpu
            outputs_tuple, losses_tuple, loss_tuple, metrics_tuple = self.tower_fn(self.inputs)
            self.outputs, self.eval_outputs = outputs_tuple
            self.d_losses, self.g_losses, g_losses_post = losses_tuple
            self.d_loss, self.g_loss, g_loss_post = loss_tuple
            self.metrics, self.eval_metrics = metrics_tuple

            self.d_vars = tf.trainable_variables(self.discriminator_scope)
            self.g_vars = tf.trainable_variables(self.generator_scope)
            g_optimizer = tf.train.AdamOptimizer(self.learning_rate, self.hparams.beta1, self.hparams.beta2)
            d_optimizer = tf.train.AdamOptimizer(self.learning_rate, self.hparams.beta1, self.hparams.beta2)

            if self.mode == 'train' and (self.d_losses or self.g_losses):
                with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
                    if self.d_losses:
                        with tf.name_scope('d_compute_gradients'):
                            d_gradvars = d_optimizer.compute_gradients(self.d_loss, var_list=self.d_vars)
                        with tf.name_scope('d_apply_gradients'):
                            d_train_op = d_optimizer.apply_gradients(d_gradvars)
                    else:
                        d_train_op = tf.no_op()
                with tf.control_dependencies([d_train_op] if not self.hparams.joint_gan_optimization else []):
                    if g_losses_post:
                        if not self.hparams.joint_gan_optimization:
                            replace_read_ops(g_loss_post, self.d_vars)
                        with tf.name_scope('g_compute_gradients'):
                            g_gradvars = g_optimizer.compute_gradients(g_loss_post, var_list=self.g_vars)
                        with tf.name_scope('g_apply_gradients'):
                            g_train_op = g_optimizer.apply_gradients(g_gradvars)
                    else:
                        g_train_op = tf.no_op()
                with tf.control_dependencies([g_train_op]):
                    train_op = tf.assign_add(global_step, 1)
                self.train_op = train_op
            else:
                self.train_op = None

            global_variables = [var for var in tf.global_variables() if var not in original_global_variables]
            self.saveable_variables = [global_step] + global_variables
            self.post_init_ops = []
        else:
            if tf.get_variable_scope().name:
                # Variable scopes don't handle empty scope names well when not at the
                # root scope (they cause repeated forward slashes), hence this restriction.
                raise NotImplementedError('Unable to handle multi-gpu model created within a non-root variable scope.')

            tower_inputs = [OrderedDict() for _ in range(self.num_gpus)]
            for name, input in self.inputs.items():
                input_splits = tf.split(input, self.num_gpus)  # assumes batch_size is divisible by num_gpus
                for i in range(self.num_gpus):
                    tower_inputs[i][name] = input_splits[i]

            tower_outputs_tuple = []
            tower_d_losses = []
            tower_g_losses = []
            tower_g_losses_post = []
            tower_d_loss = []
            tower_g_loss = []
            tower_g_loss_post = []
            tower_metrics_tuple = []
            for i in range(self.num_gpus):
                worker_device = '/gpu:%d' % i
                if self.aggregate_nccl:
                    scope_name = '' if i == 0 else 'v%d' % i
                    scope_reuse = False
                    device_setter = worker_device
                else:
                    scope_name = ''
                    scope_reuse = i > 0
                    device_setter = local_device_setter(worker_device=worker_device)
                with tf.variable_scope(scope_name, reuse=scope_reuse):
                    with tf.device(device_setter):
                        outputs_tuple, losses_tuple, loss_tuple, metrics_tuple = self.tower_fn(tower_inputs[i])
                        tower_outputs_tuple.append(outputs_tuple)
                        d_losses, g_losses, g_losses_post = losses_tuple
                        tower_d_losses.append(d_losses)
                        tower_g_losses.append(g_losses)
                        tower_g_losses_post.append(g_losses_post)
                        d_loss, g_loss, g_loss_post = loss_tuple
                        tower_d_loss.append(d_loss)
                        tower_g_loss.append(g_loss)
                        tower_g_loss_post.append(g_loss_post)
                        tower_metrics_tuple.append(metrics_tuple)
            self.d_vars = tf.trainable_variables(self.discriminator_scope)
            self.g_vars = tf.trainable_variables(self.generator_scope)

            if self.aggregate_nccl:
                scope_replica = lambda scope, i: ('' if i == 0 else 'v%d/' % i) + scope
                tower_d_vars = [tf.trainable_variables(
                    scope_replica(self.discriminator_scope, i)) for i in range(self.num_gpus)]
                tower_g_vars = [tf.trainable_variables(
                    scope_replica(self.generator_scope, i)) for i in range(self.num_gpus)]
                assert self.d_vars == tower_d_vars[0]
                assert self.g_vars == tower_g_vars[0]
                tower_d_optimizer = [tf.train.AdamOptimizer(
                    self.learning_rate, self.hparams.beta1, self.hparams.beta2) for _ in range(self.num_gpus)]
                tower_g_optimizer = [tf.train.AdamOptimizer(
                    self.learning_rate, self.hparams.beta1, self.hparams.beta2) for _ in range(self.num_gpus)]

                if self.mode == 'train' and (any(tower_d_losses) or any(tower_g_losses)):
                    tower_d_gradvars = []
                    tower_g_gradvars = []
                    tower_d_train_op = []
                    tower_g_train_op = []
                    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
                        if any(tower_d_losses):
                            for i in range(self.num_gpus):
                                with tf.device('/gpu:%d' % i):
                                    with tf.name_scope(scope_replica('d_compute_gradients', i)):
                                        d_gradvars = tower_d_optimizer[i].compute_gradients(
                                            tower_d_loss[i], var_list=tower_d_vars[i])
                                        tower_d_gradvars.append(d_gradvars)

                            all_d_grads, all_d_vars = tf_utils.split_grad_list(tower_d_gradvars)
                            all_d_grads = tf_utils.allreduce_grads(all_d_grads, average=True)
                            tower_d_gradvars = tf_utils.merge_grad_list(all_d_grads, all_d_vars)

                            for i in range(self.num_gpus):
                                with tf.device('/gpu:%d' % i):
                                    with tf.name_scope(scope_replica('d_apply_gradients', i)):
                                        d_train_op = tower_d_optimizer[i].apply_gradients(tower_d_gradvars[i])
                                        tower_d_train_op.append(d_train_op)
                            d_train_op = tf.group(*tower_d_train_op)
                        else:
                            d_train_op = tf.no_op()
                    with tf.control_dependencies([d_train_op] if not self.hparams.joint_gan_optimization else []):
                        if any(tower_g_losses_post):
                            for i in range(self.num_gpus):
                                with tf.device('/gpu:%d' % i):
                                    if not self.hparams.joint_gan_optimization:
                                        replace_read_ops(tower_g_loss_post[i], tower_d_vars[i])

                                    with tf.name_scope(scope_replica('g_compute_gradients', i)):
                                        g_gradvars = tower_g_optimizer[i].compute_gradients(
                                            tower_g_loss_post[i], var_list=tower_g_vars[i])
                                        tower_g_gradvars.append(g_gradvars)

                            all_g_grads, all_g_vars = tf_utils.split_grad_list(tower_g_gradvars)
                            all_g_grads = tf_utils.allreduce_grads(all_g_grads, average=True)
                            tower_g_gradvars = tf_utils.merge_grad_list(all_g_grads, all_g_vars)

                            for i, g_gradvars in enumerate(tower_g_gradvars):
                                with tf.device('/gpu:%d' % i):
                                    with tf.name_scope(scope_replica('g_apply_gradients', i)):
                                        g_train_op = tower_g_optimizer[i].apply_gradients(g_gradvars)
                                        tower_g_train_op.append(g_train_op)
                            g_train_op = tf.group(*tower_g_train_op)
                        else:
                            g_train_op = tf.no_op()
                    with tf.control_dependencies([g_train_op]):
                        train_op = tf.assign_add(global_step, 1)
                    self.train_op = train_op
                else:
                    self.train_op = None

                global_variables = [var for var in tf.global_variables() if var not in original_global_variables]
                tower_saveable_vars = [[] for _ in range(self.num_gpus)]
                for var in global_variables:
                    m = re.match(r'v(\d+)/.*', var.name)
                    i = int(m.group(1)) if m else 0
                    tower_saveable_vars[i].append(var)
                self.saveable_variables = [global_step] + tower_saveable_vars[0]

                post_init_ops = []
                for i, saveable_vars in enumerate(tower_saveable_vars[1:], 1):
                    assert len(saveable_vars) == len(tower_saveable_vars[0])
                    for var, var0 in zip(saveable_vars, tower_saveable_vars[0]):
                        assert var.name == 'v%d/%s' % (i, var0.name)
                        post_init_ops.append(var.assign(var0.read_value()))
                self.post_init_ops = post_init_ops
            else:  # not self.aggregate_nccl (i.e. aggregation in cpu)
                g_optimizer = tf.train.AdamOptimizer(self.learning_rate, self.hparams.beta1, self.hparams.beta2)
                d_optimizer = tf.train.AdamOptimizer(self.learning_rate, self.hparams.beta1, self.hparams.beta2)

                if self.mode == 'train' and (any(tower_d_losses) or any(tower_g_losses)):
                    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
                        if any(tower_d_losses):
                            with tf.name_scope('d_compute_gradients'):
                                d_gradvars = compute_averaged_gradients(
                                    d_optimizer, tower_d_loss, var_list=self.d_vars)
                            with tf.name_scope('d_apply_gradients'):
                                d_train_op = d_optimizer.apply_gradients(d_gradvars)
                        else:
                            d_train_op = tf.no_op()
                    with tf.control_dependencies([d_train_op] if not self.hparams.joint_gan_optimization else []):
                        if any(tower_g_losses_post):
                            for g_loss_post in tower_g_loss_post:
                                if not self.hparams.joint_gan_optimization:
                                    replace_read_ops(g_loss_post, self.d_vars)
                            with tf.name_scope('g_compute_gradients'):
                                g_gradvars = compute_averaged_gradients(
                                    g_optimizer, tower_g_loss_post, var_list=self.g_vars)
                            with tf.name_scope('g_apply_gradients'):
                                g_train_op = g_optimizer.apply_gradients(g_gradvars)
                        else:
                            g_train_op = tf.no_op()
                    with tf.control_dependencies([g_train_op]):
                        train_op = tf.assign_add(global_step, 1)
                    self.train_op = train_op
                else:
                    self.train_op = None

                global_variables = [var for var in tf.global_variables() if var not in original_global_variables]
                self.saveable_variables = [global_step] + global_variables
                self.post_init_ops = []

            # Device that runs the ops to apply global gradient updates.
            consolidation_device = '/cpu:0'
            with tf.device(consolidation_device):
                with tf.name_scope('consolidation'):
                    self.outputs, self.eval_outputs = reduce_tensors(tower_outputs_tuple)
                    self.d_losses = reduce_tensors(tower_d_losses, shallow=True)
                    self.g_losses = reduce_tensors(tower_g_losses, shallow=True)
                    self.metrics, self.eval_metrics = reduce_tensors(tower_metrics_tuple)
                    self.d_loss = reduce_tensors(tower_d_loss)
                    self.g_loss = reduce_tensors(tower_g_loss)

        original_local_variables = set(tf.local_variables())
        self.accum_eval_metrics = OrderedDict()
        for name, eval_metric in self.eval_metrics.items():
            _, self.accum_eval_metrics['accum_' + name] = tf.metrics.mean_tensor(eval_metric)
        local_variables = set(tf.local_variables()) - original_local_variables
        self.accum_eval_metrics_reset_op = tf.group([tf.assign(v, tf.zeros_like(v)) for v in local_variables])

        original_summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
        add_summaries(self.inputs)
        add_summaries(self.outputs)
        add_scalar_summaries(self.d_losses)
        add_scalar_summaries(self.g_losses)
        add_scalar_summaries(self.metrics)
        if self.d_losses:
            add_scalar_summaries({'d_loss': self.d_loss})
        if self.g_losses:
            add_scalar_summaries({'g_loss': self.g_loss})
        if self.d_losses and self.g_losses:
            add_scalar_summaries({'loss': self.d_loss + self.g_loss})
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES)) - original_summaries
        # split summaries into non-image summaries and image summaries
        self.summary_op = tf.summary.merge(list(summaries - set(tf.get_collection(tf_utils.IMAGE_SUMMARIES))))
        self.image_summary_op = tf.summary.merge(list(summaries & set(tf.get_collection(tf_utils.IMAGE_SUMMARIES))))

        original_summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
        add_gif_summaries(self.eval_outputs)
        add_plot_and_scalar_summaries(
            {name: tf.reduce_mean(metric, axis=0) for name, metric in self.eval_metrics.items()},
            x_offset=self.hparams.context_frames + 1)
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES)) - original_summaries
        self.eval_summary_op = tf.summary.merge(list(summaries))

        original_summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
        add_plot_and_scalar_summaries(
            {name: tf.reduce_mean(metric, axis=0) for name, metric in self.accum_eval_metrics.items()},
            x_offset=self.hparams.context_frames + 1)
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES)) - original_summaries
        self.accum_eval_summary_op = tf.summary.merge(list(summaries))
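Both the saveable-variable capture and the accumulator reset above use the same snapshot-and-diff idea: record the collection before building a sub-graph, then diff afterwards. A minimal sketch, assuming a single made-up tf.metrics.mean_tensor accumulator:

import tensorflow as tf

before = set(tf.local_variables())

# tf.metrics.mean_tensor stores its running totals in LOCAL_VARIABLES
values = tf.placeholder(tf.float32, [2])
mean_value, update_mean = tf.metrics.mean_tensor(values)

created = set(tf.local_variables()) - before
reset_op = tf.group(*[tf.assign(v, tf.zeros_like(v)) for v in created])

with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    sess.run(update_mean, {values: [2.0, 4.0]})
    print(sess.run(mean_value))  # [2. 4.]
    sess.run(reset_op)           # accumulators back to zero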
Example n. 32
0
 def testLocalVariableNotInAllVariables(self):
   with self.test_session():
     with tf.variable_scope('A'):
       a = tf.contrib.framework.local_variable(0)
       self.assertFalse(a in tf.global_variables())
       self.assertTrue(a in tf.local_variables())
Example n. 33
0
    def run(self):
        """Main method for training PulseNet model."""

        # Assign config variable to local variable
        config = self.config

        preprocess_train = PreProcess(index=config.INDEX.TRAIN,
                                      **config.PIPELINE)
        preprocess_val = PreProcess(index=config.INDEX.VAL, **config.PIPELINE)

        # Prep train and eval data
        train_dataset = preprocess_train.prep(config.DATA.TRAIN)
        val_dataset = preprocess_val.prep(config.DATA.VAL)

        # Dataset iterator
        iterator = tf.data.Iterator.from_structure(train_dataset.output_types,
                                                   train_dataset.output_shapes)
        train_data_X, train_data_y = train_dataset.make_one_shot_iterator(
        ).get_next()
        val_data_X, val_data_y = val_dataset.make_one_shot_iterator().get_next(
        )
        data_X, data_y = iterator.get_next()

        # Initialize with required Datasets
        train_iterator = iterator.make_initializer(train_dataset)
        val_iterator = iterator.make_initializer(val_dataset)

        # Length of wavelength
        X_length = preprocess_train.stop_index - preprocess_train.start_index

        placeholder_X = tf.placeholder(tf.float32, [None, X_length, 1])
        placeholder_y = tf.placeholder(tf.int64, [None, 2])

        # Custom summaries
        train_epoch_acc_summary = tf.Summary()
        train_epoch_auc_summary = tf.Summary()
        val_epoch_acc_summary = tf.Summary()
        val_epoch_auc_summary = tf.Summary()

        loss_arr = []
        acc_arr = []
        epoch_arr = []

        time_string = datetime.datetime.now().isoformat()
        config.NAME += '_{}'.format(time_string)

        # Instantiate model
        model = PulseNet(data_X,
                         data_y,
                         hparams=config.HPARAMS,
                         run_dir=config.LOGS,
                         learning_rate=config.HPARAMS.learning_rate,
                         experiment_name=config.NAME,
                         causal=True)

        # Prints info before training starts
        self.print_info()

        # Store loss history and the early-stopping patience counter
        hist_loss = []
        patience_cnt = 0

        # Run session
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())

            # Write logs to Tensorboard
            train_writer = tf.summary.FileWriter(
                config.LOGS + 'train/' + config.NAME, sess.graph)
            val_writer = tf.summary.FileWriter(
                config.LOGS + 'val/' + config.NAME, sess.graph)

            n_train_samples = 0
            n_val_samples = 0

            for epoch_no in range(config.HPARAMS.epochs):

                train_loss, train_accuracy = 0, 0
                val_loss, val_accuracy = 0, 0

                X_train, y_train = sess.run((train_data_X, train_data_y))
                X_val, y_val = sess.run((val_data_X, val_data_y))

                # Initialize iterator with training data
                sess.run(train_iterator,
                         feed_dict={
                             placeholder_X: X_train,
                             placeholder_y: y_train
                         })

                # Set model to training mode
                model.is_training = True

                i_batch = 0
                try:
                    with tqdm(total=len(y_train)) as pbar:
                        while i_batch <= config.HPARAMS.n_train_batches:
                            _, loss, acc_update_op, summary = sess.run([
                                model.optimizer, model.loss,
                                model.accuracy_update_op, model.summaries
                            ])
                            train_loss += loss
                            n_train_samples += config.PIPELINE.batch
                            pbar.update(config.PIPELINE.batch)

                            if i_batch % config.PIPELINE.train_sample_rate == 0:
                                print(
                                    '\nEpoch {}: batch = {}, loss = {:.4f}, accuracy = {:.4f}'
                                    .format(epoch_no, i_batch + 1,
                                            train_loss / (i_batch + 1),
                                            acc_update_op))
                                # Write logs at every iteration
                                train_writer.add_summary(
                                    summary, n_train_samples)

                            i_batch += 1
                except tf.errors.OutOfRangeError:
                    print('End of dataset')

                # After every epoch, calculate the accuracy of the last seen training batch
                acc_train = sess.run(model.accuracy)

                # Add logs at end of train epoch
                train_epoch_acc_summary.value.add(tag="epoch_accuracy",
                                                  simple_value=acc_train)
                train_writer.add_summary(train_epoch_acc_summary, epoch_no)

                print(
                    "Epoch {}: training loss = {:.3f}, training accuracy: {:.2f}%"
                    .format(epoch_no, train_loss / (i_batch), acc_train * 100))

                # Early stopping
                hist_loss.append(train_loss / i_batch)
                if epoch_no > 0:
                    if hist_loss[epoch_no - 1] - hist_loss[
                            epoch_no] > config.HPARAMS.min_delta:
                        patience_cnt = 0
                    else:
                        patience_cnt += 1

                if patience_cnt > config.HPARAMS.patience:
                    print("\nEarly stopping...")
                    # Save model
                    print('\nSaving model...')
                    model.saver.save(
                        sess, 'logs/checkpoints/' + config.NAME + '/model')
                    print('ok\n')
                    break

                # Save model
                print('\nSaving model...')
                model.saver.save(sess,
                                 'logs/checkpoints/' + config.NAME + '/model')
                print('ok\n')

                # Initialize iterator with validation data
                sess.run(val_iterator,
                         feed_dict={
                             placeholder_X: X_val,
                             placeholder_y: y_val
                         })

                # Set model to validation mode
                model.is_training = False

                i_batch = 1
                try:
                    with tqdm(total=len(y_val)) as pbar:
                        while i_batch <= config.HPARAMS.n_val_batches:
                            loss, val_acc_update_op, summary, auc, auc_update_op = sess.run(
                                [
                                    model.loss, model.accuracy_update_op,
                                    model.summaries, model.auc,
                                    model.auc_update_op
                                ])
                            val_loss += loss
                            n_val_samples += config.PIPELINE.batch
                            pbar.update(config.PIPELINE.batch)

                            # Write logs at every iteration
                            val_writer.add_summary(summary, n_val_samples)

                            i_batch += 1
                except tf.errors.OutOfRangeError:
                    print('End of dataset')

                # After each epoch, calculate the accuracy of the test data
                acc_val, auc_val = sess.run([model.accuracy, model.auc])
                print('AUC =', auc_val)

                auc_local_variables = [
                    str(i.name) for i in tf.local_variables()
                    if str(i.name).startswith('auc')
                ]
                roc_dict = {
                    vn.split('/')[-1].split(':')[0]:
                    sess.run(tf.get_default_graph().get_tensor_by_name(vn))
                    for vn in auc_local_variables
                }

                # Add logs at end of train epoch
                val_epoch_acc_summary.value.add(tag="epoch_accuracy",
                                                simple_value=acc_val)
                val_epoch_auc_summary.value.add(tag="epoch_AUC",
                                                simple_value=auc_val)
                val_writer.add_summary(val_epoch_acc_summary, epoch_no)
                val_writer.add_summary(val_epoch_auc_summary, epoch_no)

                print(
                    "Average validation set accuracy and loss over {} batch iterations are {:.2f}% and {:.2f}"
                    .format(i_batch, acc_val * 100, val_loss / i_batch))

            # Calculate and save True Positive Rate (TPR) and False Positive Rate (FPR)
            tpr = roc_dict['true_positives'] / (roc_dict['true_positives'] +
                                                roc_dict['false_negatives'])
            tnr = roc_dict['true_negatives'] / (roc_dict['true_negatives'] +
                                                roc_dict['false_positives'])
            fpr = 1 - tnr

            roc_dict['auc'] = auc_val
            roc_dict['tpr'] = tpr
            roc_dict['tnr'] = tnr
            roc_dict['fpr'] = fpr

            log_data_path = config.LOGS + 'checkpoints' + '/' + config.NAME + '/'

            # Dump AUC ROC data to json
            with open(log_data_path + 'roc_auc.json', 'w') as fp:
                roc_dict_as_list = {
                    key: val.tolist()
                    for key, val in roc_dict.items()
                }
                json.dump(roc_dict_as_list, fp)

            # Dump configuration to json
            with open(log_data_path + 'config.json', 'w') as fp:
                json.dump(
                    {
                        'hparams': config.HPARAMS,
                        'pipeline_settings': config.PIPELINE
                    }, fp)

            train_writer.close()
            val_writer.close()

            print('\nTraining finished.')
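The ROC bookkeeping at the end of the run relies on the fact that tf.metrics.auc keeps its per-threshold confusion counts (true_positives, false_negatives, ...) in local variables. A minimal sketch of reading them back, with an illustrative threshold count and data:

import tensorflow as tf

labels = tf.placeholder(tf.float32, [None])
scores = tf.placeholder(tf.float32, [None])

auc, auc_update = tf.metrics.auc(labels, scores, num_thresholds=3, name='auc')

with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    sess.run(auc_update, {labels: [0., 1., 1.], scores: [0.2, 0.8, 0.6]})
    counts = {
        v.name.split('/')[-1].split(':')[0]: sess.run(v)
        for v in tf.local_variables() if v.name.startswith('auc')
    }
    # maps 'true_positives', 'false_negatives', ... to per-threshold count arrays
    print(sorted(counts.keys()))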
def serialize_py_fn_as_tf_computation(target, parameter_type, context_stack):
    """Serializes the 'target' as a TF computation with a given parameter type.

  Args:
    target: The entity to convert into and serialize as a TF computation. This
      can currently only be a Python function. In the future, we will add here
      support for serializing the various kinds of non-eager and eager defuns,
      and eventually aim at full support for and compliance with TF 2.0. This
      function is currently required to declare either zero parameters if
      `parameter_type` is `None`, or exactly one parameter if it's not `None`.
      The nested structure of this parameter must correspond to the structure of
      the 'parameter_type'. In the future, we may support targets with multiple
      args/keyword args (to be documented in the API and referenced from here).
    parameter_type: The parameter type specification if the target accepts a
      parameter, or `None` if the target doesn't declare any parameters. Either
      an instance of `types.Type`, or something that's convertible to it by
      `types.to_type()`.
    context_stack: The context stack to use.

  Returns:
    The constructed `pb.Computation` instance with the `pb.TensorFlow` variant
      set.

  Raises:
    TypeError: If the arguments are of the wrong types.
    ValueError: If the signature of the target is not compatible with the given
      parameter type.
  """
    # TODO(b/113112108): Support a greater variety of target type signatures,
    # with keyword args or multiple args corresponding to elements of a tuple.
    # Document all accepted forms with examples in the API, and point to there
    # from here.

    py_typecheck.check_type(target, types.FunctionType)
    py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
    parameter_type = computation_types.to_type(parameter_type)
    argspec = inspect.getargspec(target)

    with tf.Graph().as_default() as graph:
        args = []
        if parameter_type:
            if len(argspec.args) != 1:
                raise ValueError(
                    'Expected the target to declare exactly one parameter, '
                    'found {}.'.format(repr(argspec.args)))
            parameter_name = argspec.args[0]
            parameter_value, parameter_binding = graph_utils.stamp_parameter_in_graph(
                parameter_name, parameter_type, graph)
            args.append(parameter_value)
        else:
            if argspec.args:
                raise ValueError(
                    'Expected the target to declare no parameters, found {}.'.
                    format(repr(argspec.args)))
            parameter_binding = None
        context = tf_computation_context.TensorFlowComputationContext(graph)
        with context_stack.install(context):
            result = target(*args)

            # TODO(b/122081673): This needs to change for TF 2.0. We may also
            # want to allow the person creating a tff.tf_computation to specify
            # a different initializer; e.g., if it is known that certain
            # variables will be assigned immediately to arguments of the function,
            # then it is wasteful to initialize them before this.
            #
            # The following is a bit of a work around: the collections below may
            # contain variables more than once, hence we throw into a set. TFF needs
            # to ensure all variables are initialized, but not all variables are
            # always in the collections we expect. tff.learning._KerasModel tries to
            # pull Keras variables (that may or may not be in GLOBAL_VARIABLES) into
            # TFF_MODEL_VARIABLES for now.
            all_variables = set(tf.global_variables() + tf.local_variables() +
                                tf.get_collection(graph_keys.GraphKeys.
                                                  VARS_FOR_TFF_TO_INITIALIZE))
            if all_variables:
                # Use a readable but not-too-long name for the init_op.
                name = 'init_op_for_' + '_'.join(
                    [v.name.replace(':0', '') for v in all_variables])
                if len(name) > 50:
                    name = 'init_op_for_{}_variables'.format(
                        len(all_variables))
                with tf.control_dependencies(context.init_ops):
                    # Before running the main new init op, run any initializers for sub-
                    # computations from context.init_ops. Variables from import_graph_def
                    # will not make it into the global collections, and so will not be
                    # initialized without this code path.
                    init_op_name = tf.initializers.variables(all_variables,
                                                             name=name).name
            elif context.init_ops:
                init_op_name = tf.group(*context.init_ops,
                                        name='subcomputation_init_ops').name
            else:
                init_op_name = None

        result_type, result_binding = graph_utils.capture_result_from_graph(
            result, graph)

    return pb.Computation(type=pb.Type(function=pb.FunctionType(
        parameter=type_serialization.serialize_type(parameter_type),
        result=type_serialization.serialize_type(result_type))),
                          tensorflow=pb.TensorFlow(
                              graph_def=serialization_utils.pack_graph_def(
                                  graph.as_graph_def()),
                              parameter=parameter_binding,
                              result=result_binding,
                              initialize_op=init_op_name))
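The initializer construction in the middle of the function (one named init op covering every global and local variable) can be reduced to a small sketch; the two throwaway variables and the op name are assumptions:

import tensorflow as tf

with tf.Graph().as_default() as g:
    a = tf.Variable(1.0, name='a')                 # global variable
    b = tf.contrib.framework.local_variable(2.0)   # local variable

    all_variables = set(tf.global_variables() + tf.local_variables())
    init_op = tf.variables_initializer(list(all_variables),
                                       name='init_op_for_all_vars')

with tf.Session(graph=g) as sess:
    sess.run(init_op)
    print(sess.run([a, b]))  # [1.0, 2.0]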
Example n. 35
0
def create_train_op(loss, params):
    tf.logging.info("create_train_op(loss=%s, params=%s)", loss, pps(params))
    lr = params["lr"]
    global_step = tf.train.get_global_step()
    assert global_step is not None
    if "warmup_steps" in params.keys():
        tf.logging.info(
            'create_train_op: lr = cosine_decay_with_warmup(%s, %s, %s, warmup_steps=%s)',
            global_step, lr, params["max_steps"], params["warmup_steps"])
        lr = cosine_decay_with_warmup(global_step,
                                      lr,
                                      params["max_steps"],
                                      warmup_steps=params["warmup_steps"])

    if params["opt_name"] == "adam":
        if not "weight_decay" in params.keys():
            optimizer = tf.train.AdamOptimizer(learning_rate=lr,
                                               beta1=params["beta1"],
                                               beta2=params["beta2"],
                                               epsilon=params["epsilon"])
            tf.logging.info(
                'create_train_op: optimizer = tf.train.AdamOptimizer(learning_rate=%s, beta1=%s, beta2=%s, epsilon=%s)',
                lr, params["beta1"], params["beta2"], params["epsilon"])

        else:
            optimizer = tf.contrib.opt.AdamWOptimizer(
                learning_rate=lr,
                weight_decay=lr * params["weight_decay"],
                beta1=params["beta1"],
                beta2=params["beta2"],
                epsilon=params["epsilon"])
            tf.logging.info(
                'create_train_op: optimizer = tf.contrib.opt.AdamWOptimizer(learning_rate=%s, weight_decay=lr*%s, beta1=%s, beta2=%s, epsilon=%s)',
                lr, params["weight_decay"], params["beta1"], params["beta2"],
                params["epsilon"])

    elif params["opt_name"] == "adafactor":
        if params["decay_type"] == "adam":
            decay_rate = adafactor_decay_rate_adam(params["beta2"])
        elif params["decay_type"] == "pow":
            decay_rate = adafactor_decay_rate_pow(params["decay_exponent"])
        elif params["decay_type"] == "none":
            decay_rate = None
        else:
            raise ValueError("unknown optimizer_adafactor_decay_type")

        if not "weight_decay" in params.keys():
            optimizer = AdafactorOptimizer(learning_rate=lr,
                                           decay_rate=decay_rate,
                                           beta1=params["beta1"],
                                           name="Adafactor")
            tf.logging.info(
                'create_train_op: optimizer = AdafactorOptimizer(learning_rate=%s, decay_rate=%s, beta1=%s)',
                lr, decay_rate, params["beta1"])

        else:
            AdafactorWOptimizer = tf.contrib.opt.extend_with_decoupled_weight_decay(
                AdafactorOptimizer)

            optimizer = AdafactorWOptimizer(
                weight_decay=params["weight_decay"] * lr,
                learning_rate=lr,
                decay_rate=decay_rate,
                beta1=params["beta1"],
                name="AdafactorW")
            tf.logging.info(
                'create_train_op: optimizer = AdafactorWOptimizer(weight_decay=lr*%s, learning_rate=%s, decay_rate=%s, beta1=%s)',
                params["weight_decay"], lr, decay_rate, params["beta1"])

    else:
        raise ValueError("Unknown optimizer type!")

    if params["use_tpu"]:
        optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)

    update_ops = tf.get_collection(
        tf.GraphKeys.UPDATE_OPS)  # To update batchnorm, if present
    only_train_transformer_layers = params.get('only_train_transformer_layers', False)

    def should_train_variable(v):
        if only_train_transformer_layers:
            if '/h' not in v.name and '/ln_f' not in v.name:
                tf.logging.info("NOT training variable: %s", v)
                return False
            #for i in range(1):
            #  if ('/h%01d/' % i) in v.name:
            #    return False
            #  if ('/h%02d/' % i) in v.name:
            #    return False
        tf.logging.info("    training variable: %s", v)
        return True

    train_vars = [
        v for v in tf.trainable_variables() if should_train_variable(v)
    ]
    non_train_vars = [
        v for v in tf.trainable_variables() if not should_train_variable(v)
    ]
    other_vars = [
        v for v in tf.global_variables()
        if v not in train_vars and v not in non_train_vars
    ]
    local_vars = [v for v in tf.local_variables()]

    paramcount = lambda vs: sum([np.prod(v.shape.as_list()) for v in vs])

    def logvars(variables, label, print_variables=False):
        if print_variables:
            tf.logging.info("%s (%s parameters): %s", label,
                            paramcount(variables), pps(variables))
        else:
            tf.logging.info("%s (%s parameters)", label, paramcount(variables))
        return variables

    tf.logging.info(
        "Training %d parameters (%.2fM) out of %d parameters (%.2fM)" % (
            paramcount(train_vars),
            paramcount(train_vars) / (1024.0 * 1024.0),
            paramcount(tf.trainable_variables()),
            paramcount(tf.trainable_variables()) / (1024.0 * 1024.0),
        ))

    tf.logging.info("---------")
    tf.logging.info("Variable details:")
    logvars(train_vars, "trainable variables", print_variables=True)
    logvars(non_train_vars, "non-trainable variables", print_variables=True)
    logvars(other_vars, "other global variables", print_variables=True)
    logvars(local_vars, "other local variables", print_variables=True)

    tf.logging.info("---------")
    tf.logging.info("Variable summary:")
    logvars(train_vars, "trainable variables")
    logvars(non_train_vars, "non-trainable variables")
    logvars(other_vars, "other global variables")
    logvars(local_vars, "other local variables")

    tf.logging.info("---------")
    tf.logging.info("Gradient options:")
    #use_memory_saving_gradients=True
    use_memory_saving_gradients = params.get('memory_saving_gradients', False)
    colocate_gradients_with_ops = params.get('colocate_gradients', True)
    gate_gradients = None
    tf.logging.info("use_memory_saving_gradients=%s",
                    use_memory_saving_gradients)
    tf.logging.info("colocate_gradients_with_ops=%s",
                    colocate_gradients_with_ops)
    tf.logging.info("gate_gradients=%s", gate_gradients)
    if use_memory_saving_gradients:
        #grads = memory_saving_gradients.gradients(loss, train_vars, colocate_gradients_with_ops=colocate_gradients_with_ops, checkpoints='memory')
        #grads = memory_saving_gradients.gradients_memory if i == 0 else memory_saving_gradients.gradients_speed
        #grads = memory_saving_gradients.gradients_speed if i == 0 else memory_saving_gradients.gradients_speed
        grads = memory_saving_gradients.gradients
        grads = grads(loss,
                      train_vars,
                      colocate_gradients_with_ops=colocate_gradients_with_ops,
                      gate_gradients=gate_gradients)
    else:
        grads = gradients.gradients(
            loss,
            train_vars,
            colocate_gradients_with_ops=colocate_gradients_with_ops,
            gate_gradients=gate_gradients)

    grads = list(zip(grads, train_vars))
    disconnected_grads = [v for g, v in grads if g is None]
    grads = [(g, v) if g is not None else (tf.zeros_like(v), v)
             for g, v in grads]  # replace disconnected gradients with zeros

    tf.logging.info("---------")
    tf.logging.info("Gradient details:")
    tf.logging.info("%s", pps(grads))
    tf.logging.info("Disconnected gradients:")
    tf.logging.info("%s", pps(disconnected_grads))
    tf.logging.info("---------")

    #train_op = optimizer.minimize(loss, global_step=global_step)
    train_op = optimizer.apply_gradients(grads, global_step=global_step)
    train_op = tf.group([train_op, update_ops], name="train_op")

    return train_op
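The gradient post-processing near the end (substituting zeros for disconnected gradients so apply_gradients never receives None) is worth isolating; the loss and variables below are invented:

import tensorflow as tf

x = tf.Variable(3.0, name='x')
unused = tf.Variable(5.0, name='unused')  # does not appear in the loss
loss = tf.square(x)

grads = tf.gradients(loss, [x, unused])   # -> [2*x, None]
grads = list(zip(grads, [x, unused]))
disconnected = [v for g, v in grads if g is None]
grads = [(g, v) if g is not None else (tf.zeros_like(v), v) for g, v in grads]

opt = tf.train.GradientDescentOptimizer(0.1)
train_op = opt.apply_gradients(grads)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(train_op)
    print(sess.run([x, unused]))  # x moves to 2.4, unused stays at 5.0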
Example n. 36
0
def simple_model(X_train,
                 Y_train,
                 X_test,
                 Y_test,
                 learning_rate=0.009,
                 num_epochs=100,
                 minibatch_size=64,
                 print_cost=True,
                 restore_file=None):
    """
    Implements a three-layer ConvNet in Tensorflow:
    CONV2D -> RELU -> MAXPOOL -> CONV2D -> RELU -> MAXPOOL -> FLATTEN -> FULLYCONNECTED

    Arguments:
    X_train -- training set images, of shape (None, 64, 64, 3)
    Y_train -- training set labels, of shape (None, n_y = 6)
    X_test -- test set images, of shape (None, 64, 64, 3)
    Y_test -- test set labels, of shape (None, n_y = 6)
    learning_rate -- learning rate of the optimization
    num_epochs -- number of epochs of the optimization loop
    minibatch_size -- size of a minibatch
    print_cost -- True to print the cost every 5 epochs
    restore_file -- file path to restore variables from a previous session

    Returns:
    train_accuracy -- real number, accuracy on the train set (X_train)
    test_accuracy -- real number, testing accuracy on the test set (X_test)
    parameters -- parameters learnt by the model. They can then be used to predict.
    """

    ops.reset_default_graph(
    )  # to be able to rerun the model without overwriting tf variables
    tf.set_random_seed(1)  # to keep results consistent (tensorflow seed)
    seed = 3  # to keep results consistent (numpy seed)
    (m, n_H0, n_W0, n_C0) = X_train.shape
    n_y = Y_train.shape[1]
    costs = []  # To keep track of the cost

    # Create Placeholders of the correct shape
    X = tf.placeholder(tf.float32, [None, n_H0, n_W0, n_C0])
    Y = tf.placeholder(tf.float32, [None, n_y])

    # Initialize parameters
    with tf.variable_scope("conv_weights"):
        W1 = tf.get_variable(
            "W1", [3, 3, 1, 16],
            initializer=tf.contrib.layers.xavier_initializer())
        W2 = tf.get_variable(
            "W2", [5, 5, 16, 32],
            initializer=tf.contrib.layers.xavier_initializer())
        W3 = tf.get_variable(
            "W3", [3, 3, 32, 64],
            initializer=tf.contrib.layers.xavier_initializer())
        W4 = tf.get_variable(
            "W4", [5, 5, 64, 128],
            initializer=tf.contrib.layers.xavier_initializer())

    parameters = {"W1": W1, "W2": W2, "W3": W3, "W4": W4}

    # Forward propagation: Build the forward propagation in the tensorflow graph
    Z5 = forward_propagation(X, parameters)

    # Cost function: Add cost function to tensorflow graph
    cost = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(logits=Z5, labels=Y))

    # Backpropagation: Define the tensorflow optimizer. Use an AdamOptimizer that minimizes the cost.
    optimizer = tf.train.AdamOptimizer(learning_rate).minimize(cost)

    # Accuracy
    acc = tf.equal(tf.argmax(Z5, 1), tf.argmax(Y, 1))
    acc = tf.reduce_mean(tf.cast(acc, tf.float32))

    # Initialize all the variables globally
    init = tf.global_variables_initializer()

    # Allow saving
    saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=.25)

    # Start the session to compute the tensorflow graph
    print('Starting TensorFlow session...')
    with tf.Session() as sess:

        if restore_file is None:
            # Run the initialization
            print("Initializing parameters...")
            sess.run(init)
        else:
            # Restore variables from disk
            print("Restoring parameters...")
            saver.restore(sess, restore_file)

        # Define saving folder
        folder_name = datetime.datetime.now().strftime("%y%m%d_%H%M") \
                      + '_lr' + str(learning_rate) \
                      + '_ep' + str(num_epochs) \
                      + '_mb' + str(minibatch_size)
        progress_path = os.path.join('.', 'saver', folder_name)
        print("Progress will be saved under %s" % progress_path)

        # Tensorboard summaries
        # op to write logs to Tensorboard
        logs_path = os.path.join(progress_path, 'summaries')
        summary_writer = tf.summary.FileWriter(logs_path,
                                               graph=tf.get_default_graph())
        # create a summary to see the graph
        # summary_writer.add_graph(sess.graph) #--> already defined, not necessary.
        # Create a summary to show input images
        tf.summary.image('input', X, 1)
        # Create a summary to monitor cost and accuracy tensors
        tf.summary.scalar("loss", cost)
        tf.summary.scalar("accuracy", acc)
        # Merge all summaries into a single op
        merged_summary = tf.summary.merge_all()

        # Do the training loop
        print("\n --- TRAINING --- ")
        step = 0
        for epoch in range(num_epochs):

            minibatch_cost = 0.
            num_minibatches = int(
                m / minibatch_size
            )  # number of minibatches of size minibatch_size in the train set
            seed = seed + 1
            minibatches = random_mini_batches(X_train, Y_train, minibatch_size,
                                              seed)
            # mb = 0
            # firstmini = True
            for minibatch in minibatches:
                # Select a minibatch
                (minibatch_X, minibatch_Y) = minibatch
                # IMPORTANT: The line that runs the graph on a minibatch.
                # Run the session to execute the optimizer and the cost. feedict should contain a minibatch for (X,Y).
                # + Run optimization op (backprop), cost op (loss) and summary nodes
                _, temp_cost, summary = sess.run(
                    [optimizer, cost, merged_summary], {
                        X: minibatch_X,
                        Y: minibatch_Y
                    })

                # Compute average loss
                minibatch_cost += temp_cost / num_minibatches

                # Write logs at every iteration
                # if firstmini is True:
                summary_writer.add_summary(summary,
                                           step)  # * len(minibatches) + mb)
                # mb += 1
                # else:
                #     firstmini = False

            # Print the cost every epoch
            if print_cost == True and epoch % 5 == 0:
                print("Cost after epoch %i: %f" % (epoch, minibatch_cost))
            if print_cost == True and epoch % 1 == 0:
                costs.append(minibatch_cost)

            # Save the variables to disk every epoch.
            # file_name = "epoch" #+ str(epoch).zfill(4) #+ ".ckpt"
            checkpoint = os.path.join(progress_path, 'epoch')
            save_path = saver.save(sess, checkpoint, global_step=epoch)
            # saver = tf.train.Saver(var_list=None)
            # saver.save(sess, file)
            # print("Epoch " + str(epoch) + " saved in file: %s" % save_path)

        print('\nFINAL COST after %i epochs: ' % num_epochs, costs[-1])

        # plot the cost
        plt.plot(np.squeeze(costs))
        plt.ylabel('cost')
        plt.xlabel('iterations (per tens)')
        plt.title("Learning rate =" + str(learning_rate))
        plt.show()

        # Calculate the correct predictions
        predict_op = tf.argmax(Z5, 1)
        correct_prediction = tf.equal(predict_op, tf.argmax(Y, 1))

        # Calculate accuracy on the test set
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
        # print(accuracy)
        train_accuracy = accuracy.eval({X: X_train, Y: Y_train})
        test_accuracy = accuracy.eval({X: X_test, Y: Y_test})
        print("Train Accuracy:", train_accuracy)
        print("Test Accuracy:", test_accuracy)
        print("  --- DONE ---  \n")

        print('These are the variables used in the network:')
        print("GLOBAL:")
        for i in tf.global_variables():
            print(i)
        print("TRAINABLE:")
        for i in tf.trainable_variables():
            print(i)
        print("LOCAL:")
        for i in tf.local_variables():
            print(i)
        print("MODEL:")
        for i in tf.model_variables():
            print(i)

        # Close the method to save the operations for TensorBoard
        summary_writer.close()

        return train_accuracy, test_accuracy, parameters
Example n. 37
0
def get_unique_variable_by_name_without_creating(name):
    variables = [v for v in tf.global_variables() + tf.local_variables() if name == v.op.name]
    assert len(variables) == 1, f"Found {len(variables)} variables for name {name}."
    return variables[0]
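A quick usage sketch for the helper above, assuming the function is in scope; the scope and variable names are illustrative:

import tensorflow as tf

with tf.variable_scope('net'):
    w = tf.get_variable('w', shape=[3])

# Looks the variable up by its op name without creating anything new.
found = get_unique_variable_by_name_without_creating('net/w')
assert found is w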
Example n. 38
0
def batch_train(FLAGS):
    print("Initializing...........................")
    model_dir = FLAGS.model_dir + str(FLAGS.output_strides) + "_" + str(FLAGS.num_epochs)
    init(model_dir)
    print("Initializing completed.................\n")

    print("Preparing training meta data...........")
    list_train = os.listdir(FLAGS.images_dir_train)
    list_valid = os.listdir(FLAGS.images_dir_valid)


    # list of train images and labels
    list_images_train = [os.path.join(FLAGS.images_dir_train, x) for x in list_train]
    list_labels_train = [
        os.path.join(FLAGS.labels_dir_train, x.replace("leftImg8bit", "label")) for x in list_train
    ]

    # list of validation images and labels
    list_images_valid = [os.path.join(FLAGS.images_dir_valid, x) for x in list_valid]
    list_labels_valid = [
        os.path.join(FLAGS.labels_dir_valid, x.replace("leftImg8bit", "label")) for x in list_valid
    ]

    # num of train batches
    num_samples_train = len(list_images_train)
    num_batches_train = int(math.ceil(num_samples_train / float(FLAGS.batch_size)))

    # num of validation batches
    num_samples_valid = len(list_images_valid)
    num_batches_valid = int(math.ceil(num_samples_valid / float(FLAGS.batch_size)))
    print("Preparing training meta data completed.\n")

    print(f"Learning rate : {FLAGS.learning_rate}")
    print(f"Number of epochs to train : {FLAGS.num_epochs}")
    print(f"Batch size : {FLAGS.batch_size}")
    print(f"Number of train samples : {num_samples_train}")
    print(f"Number of train batches : {num_batches_train}")
    print(f"Number of validation samples : {num_samples_valid}")
    print(f"Number of validation batches : {num_batches_valid}\n")

    print("Building the model.....................")
    axis = -1
    if FLAGS.data_format == "channels_first":
        axis = 1

    # create train and validation dataset
    dataset_train = get_tf_dataset(
        list_images_train, list_labels_train, FLAGS.num_epochs, FLAGS.batch_size)
    dataset_valid = get_tf_dataset(
        list_images_valid, list_labels_valid, FLAGS.num_epochs, FLAGS.batch_size)

    # create iterator for the dataset
    iterator = tf.data.Iterator.from_structure(
        dataset_train.output_types, dataset_train.output_shapes)
    features, labels = iterator.get_next()

    # create initializer for train and validation datasets
    init_op_train = iterator.make_initializer(dataset_train)
    init_op_valid = iterator.make_initializer(dataset_valid)

    # create training placeholder to control behavior of batchnorm and dropout
    is_training = tf.placeholder(tf.bool)
    deep_lab_3 = DeepLab3(FLAGS.pretrained_weights, is_training,
        FLAGS.data_format, FLAGS.num_classes, FLAGS.output_strides)
    deep_lab_3.resnet50_encoder(features)
    print("resnet-50 encoder built")
    deep_lab_3.deeplabv3()
    print(f"deeplabv3 built with strides : {FLAGS.output_strides}")

    logits = deep_lab_3.logits

    # get all trainable variables to apply l2 loss
    train_var_list = [v for v in tf.trainable_variables()]

    loss_1 = compute_loss(labels, logits, axis=axis)
    loss_2 = FLAGS.weight_decay * tf.add_n([tf.nn.l2_loss(v) for v in train_var_list])
    loss = loss_1 + loss_2

    acc_value, acc_op, iou_value, iou_op = compute_metrics(
        labels, logits, axis=axis, num_classes=FLAGS.num_classes)

    global_step = tf.placeholder(tf.int32)
    optimizer_op = get_optimizer(FLAGS.learning_rate, loss, global_step)
    extra_update_op = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    print("Building the model completed...........\n")

    print("Training the model.....................")
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    ss = tf.Session(config=tf.ConfigProto(device_count={"GPU": 1}))
    ss.run(tf.global_variables_initializer())
    ss.run(tf.local_variables_initializer())

    train_loss_per_epoch = list()
    valid_loss_per_epoch = list()
    valid_acc_per_epoch = list()
    valid_iou_per_epoch = list()

    for epoch in range(FLAGS.num_epochs):
        ti = time.time()
        temp_train_loss_per_epoch = 0
        temp_valid_loss_per_epoch = 0
        temp_valid_acc_per_epoch = 0
        temp_valid_iou_per_epoch = 0

        ss.run(init_op_train)
        for batch_id in range(num_batches_train):
            _, _, loss_per_batch = ss.run(
                [extra_update_op, optimizer_op, loss],
                feed_dict={is_training: True, global_step: epoch})
            temp_train_loss_per_epoch += loss_per_batch

        ss.run(init_op_valid)
        for batch_id in range(num_batches_valid):
            loss_per_batch, _, _ = ss.run(
                [loss_1, acc_op, iou_op], feed_dict={is_training: False})
            temp_valid_loss_per_epoch += loss_per_batch

        acc_valid, iou_valid = ss.run([acc_value, iou_value])
        temp_valid_acc_per_epoch = acc_valid
        temp_valid_iou_per_epoch = iou_valid

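        # tf.metrics.* keep their running totals/counts in local variables; re-initialize
        # the validation metric variables here (assuming compute_metrics created them under
        # a "valid_metrics" scope) so the next epoch's accuracy/IoU starts from zero.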
        stream_vars_valid = [v for v in tf.local_variables() if "valid_metrics" in v.name]
        ss.run(tf.variables_initializer(stream_vars_valid))

        ti = time.time() - ti
        train_loss_per_epoch.append(temp_train_loss_per_epoch)
        valid_loss_per_epoch.append(temp_valid_loss_per_epoch)
        valid_acc_per_epoch.append(temp_valid_acc_per_epoch)
        valid_iou_per_epoch.append(temp_valid_iou_per_epoch)

        print(f"Epoch : {epoch+1} / {FLAGS.num_epochs}, time taken : {ti:.2f} sec.")
        print(f"training loss : {temp_train_loss_per_epoch/num_batches_train:.4f}, \
            validation loss : {temp_valid_loss_per_epoch/num_batches_valid:.4f}, \
            validation accuracy : {acc_valid:.4f}, validation mean iou : {iou_valid:.4f}")

        if (epoch + 1) % FLAGS.checkpoint_epoch == 0:
            save_model(ss, model_dir, FLAGS.model_file, epoch)
    print("Training the model completed...........\n")

    print("Saving the model.......................")
    save_model(ss, model_dir, FLAGS.model_file, epoch)
    train_loss_per_epoch = np.array(train_loss_per_epoch)
    valid_loss_per_epoch = np.array(valid_loss_per_epoch)
    valid_acc_per_epoch = np.array(valid_acc_per_epoch)
    valid_iou_per_epoch = np.array(valid_iou_per_epoch)

    train_loss_per_epoch = np.true_divide(train_loss_per_epoch, num_batches_train)
    valid_loss_per_epoch = np.true_divide(valid_loss_per_epoch, num_batches_valid)

    losses_dict = dict()
    losses_dict["train_loss"] = train_loss_per_epoch
    losses_dict["valid_loss"] = valid_loss_per_epoch
    losses_dict["valid_acc"] = valid_acc_per_epoch
    losses_dict["valid_iou"] = valid_iou_per_epoch

    np.save(os.path.join(os.getcwd(), model_dir, FLAGS.model_metrics), (losses_dict))
    print("Saving the model completed.............\n")
    ss.close()
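
# A minimal, self-contained sketch (assuming TF 1.x) of the streaming-metric pattern
# used above: tf.metrics.accuracy keeps its totals in local variables, which are created
# by tf.local_variables_initializer() and can be reset between evaluation rounds with
# tf.variables_initializer() over the matching local variables.
import tensorflow as tf

labels = tf.placeholder(tf.int64, [None])
preds = tf.placeholder(tf.int64, [None])
with tf.variable_scope("valid_metrics"):
    acc_value, acc_update = tf.metrics.accuracy(labels, preds)

metric_vars = [v for v in tf.local_variables() if "valid_metrics" in v.name]
reset_metrics = tf.variables_initializer(metric_vars)

with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    sess.run(acc_update, {labels: [0, 1, 1], preds: [0, 1, 0]})
    print(sess.run(acc_value))   # ~0.667
    sess.run(reset_metrics)      # start a fresh evaluation round
    sess.run(acc_update, {labels: [1], preds: [1]})
    print(sess.run(acc_value))   # 1.0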
Esempio n. 39
0
def main():
    logging.getLogger().setLevel(logging.INFO)

    logging.info(os.environ["CLUSTER_SPEC"])
    logging.info(job_name)
    logging.info(task_index)

    # Server Setup
    if job_name == 'ps':  # checks if parameter server
        server = tf.train.Server(cluster, job_name="ps", task_index=task_index)
        server.join()
    else:  # it must be a worker server
        logging.info("Loading data from worker index = %d" % task_index)

        # import data
        context = load_data(os.environ["TRAINING_DATA_PATH"])
        total_batch = int(len(context["train_data"]) / BATCH_SIZE)
        # split data into batch
        input_batch = np.array_split(context["train_data"], total_batch)
        target_batch = np.array_split(context["train_target"], total_batch)
        train_sample_weight_batch = np.array_split(
            context["train_data_sample_weight"], total_batch)

        is_chief = (task_index == 0)  # checks if this is the chief node
        server = tf.train.Server(cluster,
                                 job_name="worker",
                                 task_index=task_index)

        # Graph
        with tf.device("/job:worker/replica:0/task:%d" % task_index):
            input_placeholder = tf.placeholder(dtype=tf.float32,
                                               shape=(None, FEATURE_COUNT),
                                               name="shifu_input_0")
            label_placeholder = tf.placeholder(dtype=tf.int32, shape=(None, 1))
            sample_weight_placeholder = tf.placeholder(dtype=tf.float32,
                                                       shape=(None, 1))

            loptimizer, loss, local_step = model(input_placeholder,
                                                 label_placeholder,
                                                 sample_weight_placeholder)

            # Hidden Layer
            weight = tf.Variable(tf.constant(
                0., shape=[context["feature_count"], 20]),
                                 dtype=tf.float32,
                                 collections=[tf.GraphKeys.LOCAL_VARIABLES])
            bias = tf.Variable(tf.constant(0., shape=[20]),
                               dtype=tf.float32,
                               collections=[tf.GraphKeys.LOCAL_VARIABLES])
            output = tf.matmul(input_placeholder, weight)
            output = tf.add(output, bias)
            output = tf.nn.leaky_relu(output)
            # Output layer
            weight = tf.Variable(tf.constant(0., shape=[20, 1]),
                                 dtype=tf.float32,
                                 collections=[tf.GraphKeys.LOCAL_VARIABLES])
            bias = tf.Variable(tf.constant(0., shape=[1]),
                               dtype=tf.float32,
                               collections=[tf.GraphKeys.LOCAL_VARIABLES])
            output = tf.matmul(output, weight)
            output = tf.add(output, bias)

            prediction = tf.nn.sigmoid(output, name="shifu_output_0")

            loss = tf.losses.mean_squared_error(
                predictions=prediction,
                labels=label_placeholder,
                weights=sample_weight_placeholder)

            local_step = tf.Variable(0,
                                     dtype=tf.int32,
                                     trainable=False,
                                     name='local_step',
                                     collections=['local_non_trainable'])

            loptimizer = tf.train.AdamOptimizer(LEARNING_RATE)
            # loptimizer = tf.train.GradientDescentOptimizer(base_lr)

            # SDAG (simplest case since all batches are the same)
            update_window = 5  # T: communication window
            grad_list = None  # the array to store the gradients through the communication window
            for t in range(update_window):
                if t != 0:
                    # compute gradients only if the local opt was run
                    with tf.control_dependencies([opt_local]):
                        grads, varss = zip(*loptimizer.compute_gradients( \
                            loss, var_list=tf.local_variables()))
                else:
                    grads, varss = zip(*loptimizer.compute_gradients( \
                        loss, var_list=tf.local_variables()))

                # add gradients to the list
                if grad_list:
                    for i in range(len(grads)):
                        grad_list[i].append(grads[i])
                else:
                    # Init grad list
                    grad_list = []
                    for grad in grads:
                        grad_list.append([grad])

                # update local parameters
                opt_local = loptimizer.apply_gradients(zip(grads, varss),
                                                       global_step=local_step)

            # averages updates before applying globally
            grad_tuple = []
            for grad in grad_list:
                grad_tuple.append(tf.reduce_mean(grad, axis=0))

            grads = tuple(grad_tuple)

            # add these variables created by local optimizer to local collection
            lopt_vars = add_global_variables_to_local_collection()

            # delete the variables from the global collection
            clear_global_collection()

        with tf.device(
                tf.train.replica_device_setter(
                    ps_tasks=n_pss,
                    worker_device="/job:%s/task:%d" % (job_name, task_index))):

            global_step = tf.Variable(0,
                                      dtype=tf.int32,
                                      trainable=False,
                                      name='global_step')

            # create global variables and/or references
            local_to_global, global_to_local = create_global_variables(
                lopt_vars)

            optimizer = tf.train.AdamOptimizer(LEARNING_RATE)
            # optimizer = tf.train.GradientDescentOptimizer(base_lr)
            optimizer1 = tf.train.SyncReplicasOptimizer(
                optimizer,
                replicas_to_aggregate=int(n_workers * 0.9),
                total_num_replicas=n_workers)

            # apply the gradients to variables on ps
            opt = optimizer1.apply_gradients(zip(
                grads, [local_to_global[v] for v in varss]),
                                             global_step=global_step)

            with tf.control_dependencies([opt]):
                assign_locals = assign_global_to_local(global_to_local)

            # Grab global state before training so all workers have same initialization
            grab_global_init = assign_global_to_local(global_to_local)

            # Assigns local values to global ones for chief to execute
            assign_global = assign_local_to_global(local_to_global)

            # Initialized global step tokens
            init_tokens_op = optimizer1.get_init_tokens_op()

            # Init ops
            # gets step token
            local_init = optimizer1.local_step_init_op
            if is_chief:
                # fills token queue and gets token
                local_init = optimizer1.chief_init_op

            # indicates if variables are initialized
            ready_for_local_init = optimizer1.ready_for_local_init_op

            with tf.control_dependencies([local_init]):
                init_local = tf.variables_initializer(tf.local_variables() \
                                                      + tf.get_collection('local_non_trainable'))  # for local variables

            init = tf.global_variables_initializer(
            )  # must come after other init ops

        # Session
        sync_replicas_hook = optimizer1.make_session_run_hook(is_chief)
        stop_hook = tf.train.StopAtStepHook(last_step=1000)  # could be epoch * total_batch; one step per variable update
        chief_hooks = [sync_replicas_hook, stop_hook]
        scaff = tf.train.Scaffold(init_op=init,
                                  local_init_op=init_local,
                                  ready_for_local_init_op=ready_for_local_init)

        # Monitored Training Session
        sess = tf.train.MonitoredTrainingSession(master=server.target,
                                                 is_chief=is_chief,
                                                 config=config,
                                                 scaffold=scaff,
                                                 hooks=chief_hooks,
                                                 stop_grace_period_secs=10)

        if is_chief:
            sess.run(assign_global)  # Assigns chief's initial values to ps
            time.sleep(40)  # grace period to wait on other workers before starting training

        # Train until hook stops session
        print('Starting training on worker %d' % task_index)
        sess.run(grab_global_init)

        cur_epoch = 1
        while not sess.should_stop():
            sum_train_error = 0.0
            for i in range(total_batch):
                _, _, l, gs, ls = sess.run(
                    [opt, assign_locals, loss, global_step, local_step],
                    feed_dict={
                        input_placeholder: input_batch[i],
                        label_placeholder: target_batch[i],
                        sample_weight_placeholder:
                        train_sample_weight_batch[i],
                    })
                sum_train_error += l
            # _,r,gs=sess.run([opt,c,global_step])
            print("Epoch ", cur_epoch, sum_train_error, gs, task_index)
            if is_chief: time.sleep(1)
            time.sleep(1)
            cur_epoch += 1
        print('Done', task_index)

        time.sleep(10)  # grace period to wait before closing session
        sess.close()
        print('Session from worker %d closed cleanly' % task_index)
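
# A minimal sketch (assuming TF 1.x) of the collections trick used above: passing
# collections=[tf.GraphKeys.LOCAL_VARIABLES] keeps a variable out of the global
# collection, so it is not saved by a default tf.train.Saver and must be initialized
# through the local-variable initializer.
import tensorflow as tf

local_w = tf.Variable(tf.zeros([4, 2]), name="worker_local_w",
                      collections=[tf.GraphKeys.LOCAL_VARIABLES])

print(local_w in tf.local_variables())   # True
print(local_w in tf.global_variables())  # False

with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    print(sess.run(local_w).shape)        # (4, 2)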
Esempio n. 40
0
	main_dict["h1_e"]     = lrelu(tf.matmul(x, main_dict["W1_e"]) + main_dict["b1_e"], alpha=0.2)

# Decoder
with tf.device('/gpu:0'):
	main_dict["W1_d"] 	  = weight_variable([1, n_feat])
	main_dict["b1_d"]     = bias_variable([n_feat])

	main_dict["y1_d"]     = tf.matmul(main_dict["h1_e"], main_dict["W1_d"]) + main_dict["b1_d"]

cost           = tf.sqrt(tf.reduce_mean(tf.square( x - main_dict["y1_d"] ))) / (tf.reduce_max(x) - tf.reduce_min(x)) # NRMSEcost
accuracy       = cost
coded_feat     = main_dict["h1_e"]

# Initialize all variables
train_step     = tf.train.AdamOptimizer(slr).minimize(cost) 
sess.run(tf.group(tf.global_variables_initializer(), tf.variables_initializer(tf.local_variables())))
saver          = tf.train.Saver(max_to_keep = None)

# Validation checks
check          = 10 # Batch size during accuracy checking
train_checks   = int(math.ceil(train_feat.shape[0] / float(check) )) #n_samples/check
n_batches      = train_feat.shape[0] // batch_size # n_samples/batch_size

best_valid     = np.Inf
patience       = 0 # Epoch-based
max_patience   = 30 # Epoch-based
max_training   = 3000 # Maximum number of epochs to train for
count_batch    = 0 # To get tabs of batch number
count          = 0 # Number of epochs

# Train
def train():
    """Train model.

    Returns:
        best validation error. Save best model"""

    best_validation_error_value = float('inf')

    with tf.Graph().as_default(), tf.device(TRAIN_DEVICE):
        global_step = tf.Variable(0, trainable=False, name="global_step")

        # Get images and labels for CIFAR-10.
        images = DATASET.distorted_inputs(BATCH_SIZE)
        print("images  array is :", np.array(images))
        # Build a Graph that computes the reconstructions predictions from the
        # inference model.
        is_training_, reconstructions = MODEL.get(images,
                                                  train_phase=True,
                                                  l2_penalty=L2_PENALTY)

        # display original images next to reconstructed images
        with tf.variable_scope("visualization"):
            grid_side = math.floor(math.sqrt(BATCH_SIZE))
            inputs = put_kernels_on_grid(
                tf.transpose(images, perm=(1, 2, 3, 0))[:, :, :,
                                                        0:grid_side**2],
                grid_side)
            inputs = tf.pad(inputs, [[0, 0], [0, 0], [0, 10], [0, 0]])
            outputs = put_kernels_on_grid(
                tf.transpose(reconstructions,
                             perm=(1, 2, 3, 0))[:, :, :, 0:grid_side**2],
                grid_side)
        tf_log(
            tf.summary.image('input_output',
                             tf.concat([inputs, outputs], axis=2),
                             max_outputs=1))

        # Calculate loss.
        loss = MODEL.loss(reconstructions, images)
        # reconstruction error
        error_ = tf.placeholder(tf.float32, shape=())
        error = tf.summary.scalar('error', error_)

        if LR_DECAY:
            # Decay the learning rate exponentially based on the number of steps.
            learning_rate = tf.train.exponential_decay(INITIAL_LR,
                                                       global_step,
                                                       STEPS_PER_DECAY,
                                                       LR_DECAY_FACTOR,
                                                       staircase=True)
        else:
            learning_rate = tf.constant(INITIAL_LR)

        tf_log(tf.summary.scalar('learning_rate', learning_rate))
        train_op = OPTIMIZER.minimize(loss, global_step=global_step)

        # Create the train saver.
        variables = variables_to_save([global_step])
        train_saver = tf.train.Saver(variables, max_to_keep=2)
        # Create the best model saver
        best_saver = tf.train.Saver(variables, max_to_keep=1)

        # read collection after that every op added its own
        # summaries in the train_summaries collection
        train_summaries = tf.summary.merge(
            tf.get_collection_ref(MODEL_SUMMARIES))

        # Build an initialization operation to run below.
        init = tf.variables_initializer(tf.global_variables() +
                                        tf.local_variables())

        # Start running operations on the Graph.
        with tf.Session(config=tf.ConfigProto(
                allow_soft_placement=True)) as sess:
            sess.run(init)

            # Start the queue runners with a coordinator
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            if not RESTART:  # continue from the saved checkpoint
                # restore previous session if exists
                checkpoint = tf.train.latest_checkpoint(LOG_DIR)
                if checkpoint:
                    train_saver.restore(sess, checkpoint)
                    print("has restored.")
                else:
                    print("[I] Unable to restore from checkpoint")

            train_log = tf.summary.FileWriter(os.path.join(
                LOG_DIR, str(InputType.train)),
                                              graph=sess.graph)
            validation_log = tf.summary.FileWriter(os.path.join(
                LOG_DIR, str(InputType.validation)),
                                                   graph=sess.graph)

            # Extract previous global step value
            old_gs = sess.run(global_step)

            # Restart from where we were
            for step in range(old_gs, MAX_STEPS):
                start_time = time.time()
                _, loss_value = sess.run([train_op, loss],
                                         feed_dict={is_training_: True})
                duration = time.time() - start_time

                if np.isnan(loss_value):
                    print('Model diverged with loss = NaN')
                    break

                # update logs every 10 iterations
                if step % 10 == 0:
                    num_examples_per_step = BATCH_SIZE
                    examples_per_sec = num_examples_per_step / duration
                    sec_per_batch = float(duration)

                    format_str = ('{}: step {}, loss = {:.4f} '
                                  '({:.1f} examples/sec; {:.3f} sec/batch)')
                    print(
                        format_str.format(datetime.now(), step, loss_value,
                                          examples_per_sec, sec_per_batch))
                    # log train error and summaries
                    train_error_summary_line, train_summary_line = sess.run(
                        [error, train_summaries],
                        feed_dict={
                            error_: loss_value,
                            is_training_: True
                        })
                    train_log.add_summary(train_error_summary_line,
                                          global_step=step)
                    train_log.add_summary(train_summary_line, global_step=step)

                # Save the model checkpoint at the end of every epoch
                # evaluate train and validation performance
                if (step > 0 and step % STEPS_PER_EPOCH
                        == 0) or (step + 1) == MAX_STEPS:
                    checkpoint_path = os.path.join(LOG_DIR, 'model.ckpt')
                    print("model has been kept.")
                    train_saver.save(sess, checkpoint_path, global_step=step)

                    # validation error
                    validation_error_value = evaluate.error(
                        LOG_DIR,
                        MODEL,
                        DATASET,
                        InputType.validation,
                        device=EVAL_DEVICE)

                    summary_line = sess.run(
                        error, feed_dict={error_: validation_error_value})
                    validation_log.add_summary(summary_line, global_step=step)

                    print('{} ({}): train error = {} validation error = {}'.
                          format(datetime.now(), int(step / STEPS_PER_EPOCH),
                                 loss_value, validation_error_value))
                    if validation_error_value < best_validation_error_value:
                        best_validation_error_value = validation_error_value
                        best_saver.save(sess,
                                        os.path.join(BEST_MODEL_DIR,
                                                     'model.ckpt'),
                                        global_step=step)
            # end of for

            validation_log.close()
            train_log.close()

            # When done, ask the threads to stop.
            coord.request_stop()
            # Wait for threads to finish.
            coord.join(threads)
    return best_validation_error_value
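
# train() above relies on an external variables_to_save() helper. A minimal sketch of
# what such a helper could look like (an assumption, not the original implementation):
# checkpoint the global variables plus any extras such as global_step, de-duplicated,
# while purely local variables (queue runners, streaming metrics) stay out of the saver.
import tensorflow as tf

def variables_to_save_sketch(add_list=None):
    seen, result = set(), []
    for v in tf.global_variables() + list(add_list or []):
        if v.name not in seen:
            seen.add(v.name)
            result.append(v)
    return result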
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--filelist',
                        '-t',
                        help='Path to training set ground truth (.txt)',
                        required=True)
    parser.add_argument('--filelist_val',
                        '-v',
                        help='Path to validation set ground truth (.txt)',
                        required=True)
    parser.add_argument('--load_ckpt',
                        '-l',
                        help='Path to a check point file for load')
    parser.add_argument(
        '--save_folder',
        '-s',
        help='Path to folder for saving check points and summary',
        required=True)
    parser.add_argument('--model', '-m', help='Model to use', required=True)
    parser.add_argument('--setting',
                        '-x',
                        help='Setting to use',
                        required=True)

    parser.add_argument('--startpoint',
                        '-b',
                        help='Setting to use',
                        required=True)

    args = parser.parse_args()

    time_string = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    root_folder = os.path.join(
        args.save_folder,
        '%s_%s_%s_%d' % (args.model, args.setting, time_string, os.getpid()))
    if not os.path.exists(root_folder):
        os.makedirs(root_folder)

    #sys.stdout = open(os.path.join(root_folder, 'log.txt'), 'w')

    print('PID:', os.getpid())

    print(args)

    model = importlib.import_module(args.model)
    setting_path = os.path.join(os.path.dirname(__file__), args.model)
    sys.path.append(setting_path)
    setting = importlib.import_module(args.setting)

    num_epochs = setting.num_epochs
    batch_size = setting.batch_size
    sample_num = setting.sample_num
    step_val = setting.step_val
    label_weights_list = setting.label_weights
    label_weights_val = [1.0] * 1 + [1.0] * (setting.num_class - 1)

    rotation_range = setting.rotation_range
    rotation_range_val = setting.rotation_range_val
    scaling_range = setting.scaling_range
    scaling_range_val = setting.scaling_range_val
    jitter = setting.jitter
    jitter_val = setting.jitter_val

    # Prepare inputs
    print('{}-Preparing datasets...'.format(datetime.now()))
    is_list_of_h5_list = not data_utils.is_h5_list(args.filelist)
    if is_list_of_h5_list:
        seg_list = data_utils.load_seg_list(args.filelist)
        seg_list_idx = 0
        filelist_train = seg_list[seg_list_idx]
        seg_list_idx = seg_list_idx + 1
    else:
        filelist_train = args.filelist
    data_train, _, data_num_train, label_train, _ = data_utils.load_seg(
        filelist_train)
    data_val, _, data_num_val, label_val, _ = data_utils.load_seg(
        args.filelist_val)

    # shuffle
    data_train, data_num_train, label_train = \
        data_utils.grouped_shuffle([data_train, data_num_train, label_train])

    num_train = data_train.shape[0]
    point_num = data_train.shape[1]
    num_val = data_val.shape[0]
    print('{}-{:d}/{:d} training/validation samples.'.format(
        datetime.now(), num_train, num_val))
    batch_num = (num_train * num_epochs + batch_size - 1) // batch_size
    print('{}-{:d} training batches.'.format(datetime.now(), batch_num))
    batch_num_val = int(math.ceil(num_val / batch_size))
    print('{}-{:d} testing batches per test.'.format(datetime.now(),
                                                     batch_num_val))

    ######################

    folder_summary = os.path.join(root_folder, 'summary')
    if not os.path.exists(folder_summary):
        os.makedirs(folder_summary)

    ######################################################################
    # Placeholders
    indices = tf.placeholder(tf.int32, shape=(None, None, 2), name="indices")
    xforms = tf.placeholder(tf.float32, shape=(None, 3, 3), name="xforms")
    rotations = tf.placeholder(tf.float32,
                               shape=(None, 3, 3),
                               name="rotations")
    jitter_range = tf.placeholder(tf.float32, shape=(1), name="jitter_range")
    global_step = tf.Variable(0, trainable=False, name='global_step')
    is_training = tf.placeholder(tf.bool, name='is_training')

    pts_fts = tf.placeholder(tf.float32,
                             shape=(None, point_num, setting.data_dim),
                             name='pts_fts')
    labels_seg = tf.placeholder(tf.int64,
                                shape=(None, point_num),
                                name='labels_seg')
    labels_weights = tf.placeholder(tf.float32,
                                    shape=(None, point_num),
                                    name='labels_weights')

    ######################################################################
    pts_fts_sampled = tf.gather_nd(pts_fts,
                                   indices=indices,
                                   name='pts_fts_sampled')
    features_augmented = None
    if setting.data_dim > 3:
        points_sampled, features_sampled = tf.split(
            pts_fts_sampled, [3, setting.data_dim - 3],
            axis=-1,
            name='split_points_features')
        if setting.use_extra_features:
            if setting.with_normal_feature:
                if setting.data_dim < 6:
                    print('Only 3D normals are supported!')
                    exit()
                elif setting.data_dim == 6:
                    features_augmented = pf.augment(features_sampled,
                                                    rotations)
                else:
                    normals, rest = tf.split(features_sampled,
                                             [3, setting.data_dim - 6],
                                             axis=-1)
                    normals_augmented = pf.augment(normals, rotations)
                    features_augmented = tf.concat([normals_augmented, rest],
                                                   axis=-1)
            else:
                features_augmented = features_sampled
    else:
        points_sampled = pts_fts_sampled
    points_augmented = pf.augment(points_sampled, xforms, jitter_range)

    labels_sampled = tf.gather_nd(labels_seg,
                                  indices=indices,
                                  name='labels_sampled')
    labels_weights_sampled = tf.gather_nd(labels_weights,
                                          indices=indices,
                                          name='labels_weight_sampled')

    net = model.Net(points_augmented, features_augmented, is_training, setting)
    logits = net.logits
    probs = tf.nn.softmax(logits, name='probs')
    predictions = tf.argmax(probs, axis=-1, name='predictions')
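    # Since softmax is monotonic, argmax over `probs` yields the same predictions as
    # argmax over `logits`; `probs` is kept mainly for export and inspection.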

    loss_op = tf.losses.sparse_softmax_cross_entropy(
        labels=labels_sampled, logits=logits, weights=labels_weights_sampled)

    with tf.name_scope('metrics'):
        loss_mean_op, loss_mean_update_op = tf.metrics.mean(loss_op)
        t_1_acc_op, t_1_acc_update_op = tf.metrics.accuracy(
            labels_sampled, predictions, weights=labels_weights_sampled)
        t_1_per_class_acc_op, t_1_per_class_acc_update_op = \
            tf.metrics.mean_per_class_accuracy(labels_sampled, predictions, setting.num_class,
                                               weights=labels_weights_sampled)
        t_1_per_mean_iou_op, t_1_per_mean_iou_op_update_op = \
            tf.metrics.mean_iou(labels_sampled, predictions, setting.num_class,
                                               weights=labels_weights_sampled)

    reset_metrics_op = tf.variables_initializer([
        var for var in tf.local_variables()
        if var.name.split('/')[0] == 'metrics'
    ])

    _ = tf.summary.scalar('loss/train',
                          tensor=loss_mean_op,
                          collections=['train'])
    _ = tf.summary.scalar('t_1_acc/train',
                          tensor=t_1_acc_op,
                          collections=['train'])
    _ = tf.summary.scalar('t_1_per_class_acc/train',
                          tensor=t_1_per_class_acc_op,
                          collections=['train'])

    _ = tf.summary.scalar('loss/val', tensor=loss_mean_op, collections=['val'])
    _ = tf.summary.scalar('t_1_acc/val',
                          tensor=t_1_acc_op,
                          collections=['val'])
    _ = tf.summary.scalar('t_1_per_class_acc/val',
                          tensor=t_1_per_class_acc_op,
                          collections=['val'])
    _ = tf.summary.scalar('t_1_mean_iou/val',
                          tensor=t_1_per_mean_iou_op,
                          collections=['val'])

    #_ = tf.summary.histogram('summary/Add_F2', Add_F2, collections=['summary_values'])
    #_ = tf.summary.histogram('summary/Add_F3', Add_F3, collections=['summary_values'])
    #_ = tf.summary.histogram('summary/Z', Z, collections=['summary_values'])

    lr_exp_op = tf.train.exponential_decay(setting.learning_rate_base,
                                           global_step,
                                           setting.decay_steps,
                                           setting.decay_rate,
                                           staircase=True)
    lr_clip_op = tf.maximum(lr_exp_op, setting.learning_rate_min)
    _ = tf.summary.scalar('learning_rate',
                          tensor=lr_clip_op,
                          collections=['train'])
    reg_loss = setting.weight_decay * tf.losses.get_regularization_loss()
    if setting.optimizer == 'adam':
        optimizer = tf.train.AdamOptimizer(learning_rate=lr_clip_op,
                                           epsilon=setting.epsilon)
    elif setting.optimizer == 'momentum':
        optimizer = tf.train.MomentumOptimizer(learning_rate=lr_clip_op,
                                               momentum=setting.momentum,
                                               use_nesterov=True)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = optimizer.minimize(loss_op + reg_loss,
                                      global_step=global_step)

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())

    saver = tf.train.Saver(max_to_keep=5)

    saver_best = tf.train.Saver(max_to_keep=5)
    # backup all code
    code_folder = os.path.abspath(os.path.dirname(__file__))
    shutil.copytree(code_folder,
                    os.path.join(root_folder, os.path.basename(code_folder)))

    folder_ckpt = os.path.join(root_folder, 'ckpts')
    if not os.path.exists(folder_ckpt):
        os.makedirs(folder_ckpt)

    folder_ckpt_best = os.path.join(root_folder, 'ckpt-best')
    if not os.path.exists(folder_ckpt_best):
        os.makedirs(folder_ckpt_best)

    folder_summary = os.path.join(root_folder, 'summary')
    if not os.path.exists(folder_summary):
        os.makedirs(folder_summary)

    folder_cm_matrix = os.path.join(root_folder, 'cm-matrix')
    if not os.path.exists(folder_cm_matrix):
        os.makedirs(folder_cm_matrix)

    parameter_num = np.sum(
        [np.prod(v.shape.as_list()) for v in tf.trainable_variables()])
    print('{}-Parameter number: {:d}.'.format(datetime.now(), parameter_num))

    _highest_val = 0.0
    max_val = 10
    with tf.Session() as sess:
        summaries_op = tf.summary.merge_all('train')
        summaries_val_op = tf.summary.merge_all('val')
        summary_values_op = tf.summary.merge_all('summary_values')
        summary_writer = tf.summary.FileWriter(folder_summary, sess.graph)

        #img_d_summary_writer = tf.summary.FileWriter(folder_cm_matrix, sess.graph)

        sess.run(init_op)

        # Load the model
        if args.load_ckpt is not None:
            saver.restore(sess, args.load_ckpt)
            print('{}-Checkpoint loaded from {}!'.format(
                datetime.now(), args.load_ckpt))
        batch_num = 200000
        for batch_idx_train in range(int(args.startpoint), batch_num):
            if (batch_idx_train % step_val == 0 and (batch_idx_train != 0 or args.load_ckpt is not None)) \
                    or batch_idx_train == batch_num - 1:
                ######################################################################
                # Validation
                filename_ckpt = os.path.join(folder_ckpt, 'iter')
                saver.save(sess, filename_ckpt, global_step=global_step)
                print('{}-Checkpoint saved to {}!'.format(
                    datetime.now(), filename_ckpt))

                sess.run(reset_metrics_op)
                summary_hist = None
                _idxVal = np.arange(num_val)
                np.random.shuffle(_idxVal)

                _pred = []
                _label = []
                for batch_val_idx in tqdm(range(batch_num_val)):
                    start_idx = batch_size * batch_val_idx
                    end_idx = min(start_idx + batch_size, num_val)
                    batch_size_val = end_idx - start_idx

                    points_batch = data_val[_idxVal[start_idx:end_idx], ...]
                    points_num_batch = data_num_val[_idxVal[start_idx:end_idx],
                                                    ...]
                    labels_batch = label_val[_idxVal[start_idx:end_idx], ...]

                    weights_batch = np.array(label_weights_val)[labels_batch]

                    xforms_np, rotations_np = pf.get_xforms(
                        batch_size_val,
                        rotation_range=rotation_range_val,
                        scaling_range=scaling_range_val,
                        order=setting.rotation_order)

                    _labels_sampled, _predictions, _, _, _, _ = sess.run(
                        [
                            labels_sampled, predictions, loss_mean_update_op,
                            t_1_acc_update_op, t_1_per_class_acc_update_op,
                            t_1_per_mean_iou_op_update_op
                        ],
                        feed_dict={
                            pts_fts:
                            points_batch,
                            indices:
                            pf.get_indices(batch_size_val, sample_num,
                                           points_num_batch),
                            xforms:
                            xforms_np,
                            rotations:
                            rotations_np,
                            jitter_range:
                            np.array([jitter_val]),
                            labels_seg:
                            labels_batch,
                            labels_weights:
                            weights_batch,
                            is_training:
                            False,
                        })
                    _pred.append(_predictions.flatten())
                    _label.append(_labels_sampled.flatten())

                loss_val, t_1_acc_val, t_1_per_class_acc_val, t1__mean_iou, summaries_val = sess.run(
                    [
                        loss_mean_op, t_1_acc_op, t_1_per_class_acc_op,
                        t_1_per_mean_iou_op, summaries_val_op
                    ])

                img_d_summary = pf.plot_confusion_matrix(
                    _label,
                    _pred, ["Environment", "Pedestrian", "Car", "Cyclist"],
                    tensor_name='confusion_matrix')

                max_val = max_val - 1

                if (t_1_per_class_acc_val > _highest_val):

                    max_val = 10

                    _highest_val = t_1_per_class_acc_val

                    filename_ckpt = os.path.join(folder_ckpt_best,
                                                 str(_highest_val) + "-iter-")
                    saver_best.save(sess,
                                    filename_ckpt,
                                    global_step=global_step)

                if (max_val < 0):
                    sys.exit(0)

                summary_writer.add_summary(img_d_summary, batch_idx_train)
                summary_writer.add_summary(summaries_val, batch_idx_train)

                print(
                    '{}-[Val  ]-Average:      Loss: {:.4f}  T-1 Acc: {:.4f}  Diff-Best: {:.4f} T-1 mAcc: {:.4f}   T-1 mIoU: {:.4f}'
                    .format(datetime.now(), loss_val, t_1_acc_val,
                            _highest_val - t_1_per_class_acc_val,
                            t_1_per_class_acc_val, t1__mean_iou))
                sys.stdout.flush()

                ######################################################################

            ######################################################################
            # Training
            start_idx = (batch_size * batch_idx_train) % num_train
            end_idx = min(start_idx + batch_size, num_train)
            batch_size_train = end_idx - start_idx
            points_batch = data_train[start_idx:end_idx, ...]
            points_num_batch = data_num_train[start_idx:end_idx, ...]
            labels_batch = label_train[start_idx:end_idx, ...]
            weights_batch = np.array(label_weights_list)[labels_batch]

            if start_idx + batch_size_train == num_train:
                if is_list_of_h5_list:
                    filelist_train_prev = seg_list[(seg_list_idx - 1) %
                                                   len(seg_list)]
                    filelist_train = seg_list[seg_list_idx % len(seg_list)]
                    if filelist_train != filelist_train_prev:
                        data_train, _, data_num_train, label_train, _ = data_utils.load_seg(
                            filelist_train)
                        num_train = data_train.shape[0]
                    seg_list_idx = seg_list_idx + 1
                data_train, data_num_train, label_train = \
                    data_utils.grouped_shuffle([data_train, data_num_train, label_train])

            offset = int(
                random.gauss(0, sample_num * setting.sample_num_variance))
            offset = max(offset, -sample_num * setting.sample_num_clip)
            offset = min(offset, sample_num * setting.sample_num_clip)
            sample_num_train = sample_num + offset
            xforms_np, rotations_np = pf.get_xforms(
                batch_size_train,
                rotation_range=rotation_range,
                scaling_range=scaling_range,
                order=setting.rotation_order)
            sess.run(reset_metrics_op)
            sess.run(
                [
                    train_op, loss_mean_update_op, t_1_acc_update_op,
                    t_1_per_class_acc_update_op, t_1_per_mean_iou_op_update_op
                ],
                feed_dict={
                    pts_fts:
                    points_batch,
                    indices:
                    pf.get_indices(batch_size_train, sample_num_train,
                                   points_num_batch),
                    xforms:
                    xforms_np,
                    rotations:
                    rotations_np,
                    jitter_range:
                    np.array([jitter]),
                    labels_seg:
                    labels_batch,
                    labels_weights:
                    weights_batch,
                    is_training:
                    True,
                })
            if batch_idx_train % 100 == 0:
                loss, t_1_acc, t_1_per_class_acc, t_1__mean_iou, summaries = sess.run(
                    [
                        loss_mean_op, t_1_acc_op, t_1_per_class_acc_op,
                        t_1_per_mean_iou_op, summaries_op
                    ],
                    feed_dict={
                        pts_fts:
                        points_batch,
                        indices:
                        pf.get_indices(batch_size_train, sample_num_train,
                                       points_num_batch),
                        xforms:
                        xforms_np,
                        rotations:
                        rotations_np,
                        jitter_range:
                        np.array([jitter]),
                        labels_seg:
                        labels_batch,
                        labels_weights:
                        weights_batch,
                        is_training:
                        True,
                    })
                summary_writer.add_summary(summaries, batch_idx_train)
                print(
                    '{}-[Train]-Iter: {:06d}  Loss: {:.4f}  T-1 Acc: {:.4f}  T-1 mAcc: {:.4f}   T-1 mIoU: {:.4f}'
                    .format(datetime.now(), batch_idx_train, loss, t_1_acc,
                            t_1_per_class_acc, t_1__mean_iou))
                sys.stdout.flush()
            ######################################################################
        print('{}-Done!'.format(datetime.now()))
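
# A minimal sketch (assuming TF 1.x) of the learning-rate schedule built above:
# exponential decay driven by global_step, clipped from below with tf.maximum so the
# rate never falls under a chosen minimum. The constants here are placeholders, not
# the original setting values.
import tensorflow as tf

global_step = tf.Variable(0, trainable=False, name="global_step")
lr = tf.train.exponential_decay(learning_rate=0.01, global_step=global_step,
                                decay_steps=10000, decay_rate=0.7, staircase=True)
lr = tf.maximum(lr, 1e-5)  # floor, like lr_clip_op above
optimizer = tf.train.MomentumOptimizer(lr, momentum=0.9, use_nesterov=True)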
Esempio n. 43
0
    def run(self):
        """Worker runtime body.
        """
        # Logging:
        StreamHandler(sys.stdout).push_application()
        self.log = Logger('Worker_{}'.format(self.task), level=self.log_level)

        tf.reset_default_graph()

        if self.test_mode:
            import gym

        # Define cluster:
        cluster = tf.train.ClusterSpec(self.cluster_spec).as_cluster_def()

        # Start tf.server:
        if self.job_name == 'ps':
            server = tf.train.Server(
                cluster,
                job_name=self.job_name,
                task_index=self.task,
                config=tf.ConfigProto(device_filters=["/job:ps"]))
            self.log.debug('parameters_server started.')
            # Just block here:
            server.join()

        else:
            server = tf.train.Server(
                cluster,
                job_name='worker',
                task_index=self.task,
                config=tf.ConfigProto(
                    intra_op_parallelism_threads=1,  # original was: 1
                    inter_op_parallelism_threads=2  # original was: 2
                ))
            self.log.debug('tf.server started.')

            self.log.debug('making environments:')
            # Making as many environments as many entries in env_config `port` list:
            # TODO: Hacky-II: only one example over all parallel environments can be data-master [and renderer]
            # TODO: measure data_server lags, maybe launch several instances
            self.env_list = []
            env_kwargs = self.env_kwargs.copy()
            env_kwargs['log_level'] = self.log_level
            port_list = env_kwargs.pop('port')
            data_port_list = env_kwargs.pop('data_port')
            data_master = env_kwargs.pop('data_master')
            render_enabled = env_kwargs.pop('render_enabled')

            render_list = [False for entry in port_list]
            if render_enabled:
                if self.render_last_env:
                    render_list[-1] = True
                else:
                    render_list[0] = True

            data_master_list = [False for entry in port_list]
            if data_master:
                data_master_list[0] = True

            # Parallel envs. numbering:
            if len(port_list) > 1:
                task_id = 0.0
            else:
                task_id = 0

            for port, data_port, is_render, is_master in zip(
                    port_list, data_port_list, render_list, data_master_list):
                if not self.test_mode:
                    # Assume BTgym env. class:
                    self.log.debug(
                        'setting env at port_{} is data_master: {}'.format(
                            port, data_master))
                    self.log.debug('env_kwargs:')
                    for k, v in env_kwargs.items():
                        self.log.debug('{}: {}'.format(k, v))
                    try:
                        self.env_list.append(
                            self.env_class(port=port,
                                           data_port=data_port,
                                           data_master=is_master,
                                           render_enabled=is_render,
                                           task=self.task + task_id,
                                           **env_kwargs))
                        data_master = False
                        self.log.info(
                            'set BTGym environment {} @ port:{}, data_port:{}'.
                            format(self.task + task_id, port, data_port))
                        task_id += 0.01

                    except:
                        self.log.exception(
                            'failed to make BTGym environment at port_{}.'.
                            format(port))
                        raise RuntimeError

                else:
                    # Assume atari testing:
                    try:
                        self.env_list.append(
                            self.env_class(env_kwargs['gym_id']))
                        self.log.debug('set Gym/Atari environment.')

                    except:
                        self.log.exception(
                            'failed to make Gym/Atari environment')
                        raise RuntimeError

            self.log.debug('Defining trainer...')

            # Define trainer:
            trainer = self.trainer_class(
                env=self.env_list,
                task=self.task,
                policy_config=self.policy_config,
                log_level=self.log_level,
                cluster_spec=self.cluster_spec,
                random_seed=self.random_seed,
                **self.trainer_kwargs,
            )

            self.log.debug('trainer ok.')

            # Saver-related:
            variables_to_save = [
                v for v in tf.global_variables() if not 'local' in v.name
            ]
            local_variables = [
                v for v in tf.global_variables() if 'local' in v.name
            ] + tf.local_variables()
            init_op = tf.variables_initializer(variables_to_save)
            local_init_op = tf.variables_initializer(local_variables)
            init_all_op = tf.global_variables_initializer()

            saver = _FastSaver(variables_to_save)

            self.log.debug('VARIABLES TO SAVE:')
            for v in variables_to_save:
                self.log.debug('{}: {}'.format(v.name, v.get_shape()))

            def init_fn(ses):
                self.log.info("initializing all parameters.")
                ses.run(init_all_op)

            config = tf.ConfigProto(device_filters=[
                "/job:ps", "/job:worker/task:{}/cpu:0".format(self.task)
            ])
            logdir = os.path.join(self.log_dir, 'train')
            summary_dir = logdir + "_{}".format(self.task)

            summary_writer = tf.summary.FileWriter(summary_dir)

            self.log.debug('before tf.train.Supervisor... ')

            # TODO: switch to tf.train.MonitoredTrainingSession
            sv = tf.train.Supervisor(
                is_chief=(self.task == 0),
                logdir=logdir,
                saver=saver,
                summary_op=None,
                init_op=init_op,
                local_init_op=local_init_op,
                init_fn=init_fn,
                #ready_op=tf.report_uninitialized_variables(variables_to_save),
                ready_op=tf.report_uninitialized_variables(),
                global_step=trainer.global_step,
                save_model_secs=300,
            )
            self.log.info("connecting to the parameter server... ")

            with sv.managed_session(server.target,
                                    config=config) as sess, sess.as_default():
                #sess.run(trainer.sync)
                trainer.start(sess, summary_writer)

                # Note: `self.global_step` refers to number of environment steps
                # summarized over all environment instances, not to number of policy optimizer train steps.
                global_step = sess.run(trainer.global_step)
                self.log.notice(
                    "started training at step: {}".format(global_step))

                while not sv.should_stop(
                ) and global_step < self.max_env_steps:
                    trainer.process(sess)
                    global_step = sess.run(trainer.global_step)

                # Ask for all the services to stop:
                for env in self.env_list:
                    env.close()

                sv.stop()
            self.log.notice('reached {} steps, exiting.'.format(global_step))
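
# A minimal sketch (assuming TF 1.x, and a graph that already defines its variables) of
# the save/init split used above: variables whose names mark them as worker-local are
# excluded from the saver and get their own initializer together with tf.local_variables().
import tensorflow as tf

def split_variables():
    to_save = [v for v in tf.global_variables() if 'local' not in v.name]
    local_only = [v for v in tf.global_variables() if 'local' in v.name] + tf.local_variables()
    init_op = tf.variables_initializer(to_save)
    local_init_op = tf.variables_initializer(local_only)
    saver = tf.train.Saver(var_list=to_save)
    return init_op, local_init_op, saver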
Esempio n. 44
0
    def train_segmenter(self, restored_model, display_step=1):

        print "Start training the segmenter ..."

        self.optimizer_overall = self._get_optimizers()
        self._init_tfboard()

        init_glb = tf.global_variables_initializer()
        init_loc = tf.variables_initializer(tf.local_variables())
        config = tf.ConfigProto(log_device_placement=False)
        config.gpu_options.allow_growth = True

        with open(os.path.join(self.output_path, 'eva.txt'), 'w') as f:
            f.write("Record the test performance on the fly as training ...\n")

        with tf.Session(config=config) as sess:
            sess.run([init_glb, init_loc])
            coord = tf.train.Coordinator()

            train_summary_writer = tf.summary.FileWriter(
                self.output_path + "/train_log_" + self.opt_kwargs["prefix"],
                graph=sess.graph)
            val_summary_writer = tf.summary.FileWriter(
                self.output_path + "/val_log_" + self.opt_kwargs["prefix"],
                graph=sess.graph)

            self.restore_model(sess, restored_model)

            source_train_feed, source_train_feed_fid = self.next_batch(
                self.source_train_queue)
            source_val_feed, source_val_feed_fid = self.next_batch(
                self.source_val_queue)
            target_train_feed, target_train_feed_fid = self.next_batch(
                self.target_train_queue)
            target_val_feed, target_val_feed_fid = self.next_batch(
                self.target_val_queue)
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            best_avg = 0
            best_model_performance = []
            best_model_save_path = None

            for epoch in xrange(self.num_epochs):
                for step in xrange((epoch * self.training_iters),
                                   ((epoch + 1) * self.training_iters)):
                    logging.info("Running step %s epoch %s ..." %
                                 (str(step), str(epoch)))
                    start = time.time()

                    source_train_batch, source_train_fid = sess.run(
                        [source_train_feed, source_train_feed_fid])
                    source_train_batch_x = source_train_batch[:, :, :, 0:3]
                    source_train_batch_y = _label_decomp(
                        source_train_batch[:, :, :, 3], self.net.n_class)

                    target_train_batch, target_train_fid = sess.run(
                        [target_train_feed, target_train_feed_fid])
                    target_train_batch_x = target_train_batch[:, :, :, 0:3]
                    target_train_batch_y = _label_decomp(
                        target_train_batch[:, :, :, 3], self.net.n_class)


                    _, source_loss, target_loss, kd_loss, source_prob, target_prob, lr = sess.run(\
                              (self.optimizer_overall, self.net.source_seg_dice_loss, self.net.target_seg_dice_loss,\
                               self.net.kd_loss, self.net.source_prob, self.net.target_prob, self.learning_rate_node),\
                                  feed_dict={self.net.source: source_train_batch_x,
                                             self.net.source_y: source_train_batch_y,
                                             self.net.target: target_train_batch_x,
                                             self.net.target_y: target_train_batch_y,
                                             self.net.keep_prob: 0.75,})

                    logging.info(
                        "Training at global step %s epoch %s, source loss is %0.4f, target loss is %0.4f"
                        % (str(self.global_step.eval()), str(epoch),
                           source_loss, target_loss))
                    logging.info("Knowledge Distilling loss: %0.4f" % kd_loss)
                    print "source prob:", source_prob
                    print "target prob:", target_prob
                    logging.info("Current learning rate %0.8f" % lr)
                    logging.info("Time elapsed %s seconds" %
                                 (str(time.time() - start)))

                    if step % (display_step * 20) == 0:
                        print 'update the tensorboard for training ...'
                        self.minibatch_stats_segmenter(sess,
                                                       train_summary_writer,
                                                       step,
                                                       source_train_batch_x,
                                                       source_train_batch_y,
                                                       target_train_batch_x,
                                                       target_train_batch_y,
                                                       section="train")

                    if step % (display_step * 20) == 0:
                        print 'update the tensorboard for validation ...'
                        source_val_batch = source_val_feed.eval()
                        source_val_batch_x = source_val_batch[:, :, :, 0:3]
                        source_val_batch_y = _label_decomp(
                            source_val_batch[:, :, :, 3], self.net.n_class)

                        target_val_batch = target_val_feed.eval()
                        target_val_batch_x = target_val_batch[:, :, :, 0:3]
                        target_val_batch_y = _label_decomp(
                            target_val_batch[:, :, :, 3], self.net.n_class)

                        self.minibatch_stats_segmenter(sess,
                                                       val_summary_writer,
                                                       step,
                                                       source_val_batch_x,
                                                       source_val_batch_y,
                                                       target_val_batch_x,
                                                       target_val_batch_y,
                                                       section="val")

                    ## Learning rate decay: scale the learning rate by 0.95 every 1000 global steps
                    if self.global_step.eval() % 1000 == 0:
                        _pre_lr = sess.run(self.learning_rate_node)
                        sess.run(
                            tf.assign(self.learning_rate_node, _pre_lr * 0.95))

                    # save the model periodically
                    if self.global_step.eval() % (self.checkpoint_space) == 0:
                        saver = tf.train.Saver()
                        saved_model_name = self.opt_kwargs["prefix"] + \
                            "_itr%d_model.cpkt" % self.global_step.eval()
                        save_path = saver.save(
                            sess,
                            os.path.join(self.output_path, saved_model_name),
                            global_step=self.global_step.eval())
                        logging.info(
                            "Model saved as step %d, save path is %s" %
                            (self.global_step.eval(), save_path))

            logging.info("Modeling training Finished!")
            coord.request_stop()
            coord.join(threads)
            return 0
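
The learning-rate schedule above is applied by hand with tf.assign rather than tf.train.exponential_decay. A minimal standalone sketch of the same pattern (the variable names here are illustrative, not the trainer's own):

import tensorflow as tf

# Minimal sketch: a non-trainable learning-rate variable that is
# multiplied by 0.95 every 1000 steps, mirroring the loop above.
learning_rate_node = tf.Variable(1e-3, trainable=False, name="learning_rate")

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for step in range(1, 3001):
        # ... run the training op here ...
        if step % 1000 == 0:
            prev_lr = sess.run(learning_rate_node)
            sess.run(tf.assign(learning_rate_node, prev_lr * 0.95))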
Esempio n. 45
0
   def train(self):
      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
      adam_ep = 1e-4 if self.use_fp16 else 1e-8
      self.gen_optimizer = tf.train.AdamOptimizer(learning_rate=args.gen_lr, beta1=0.5, beta2=0.9, epsilon=adam_ep)
      self.dis_optimizer = tf.train.AdamOptimizer(learning_rate=args.dis_lr, beta1=0.5, beta2=0.9, epsilon=adam_ep)
      self.enc_optimizer = tf.train.AdamOptimizer(learning_rate=args.enc_lr, beta1=0.5, beta2=0.9, epsilon=adam_ep)
      with tf.control_dependencies(update_ops):
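         # check_numerics makes the run fail fast with an InvalidArgumentError
         # if any of the losses becomes NaN/Inf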
         with tf.control_dependencies([tf.check_numerics(self.gen_loss, 'nan')]):
            self.gen_loss = tf.identity(self.gen_loss)
         with tf.control_dependencies([tf.check_numerics(self.dis_loss, 'nan')]):
            self.dis_loss = tf.identity(self.dis_loss)
         with tf.control_dependencies([tf.check_numerics(self.enc_loss, 'nan')]):
            self.enc_loss = tf.identity(self.enc_loss)
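         # fp16 loss scaling: gradients are computed on the scaled losses and
         # divided by the same factor again below, before being applied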
         self.gen_grad = self.gen_optimizer.compute_gradients(self.gen_loss * self.fp16_scale, var_list=self.gen_vars)
         self.dis_grad = self.dis_optimizer.compute_gradients(self.dis_loss * self.fp16_scale, var_list=self.dis_vars)
         self.enc_grad = self.enc_optimizer.compute_gradients(self.enc_loss * self.fp16_scale, var_list=self.enc_vars)
#         with tf.control_dependencies([tf.check_numerics(x[0], 'nan') for x in self.gen_grad]):
#            self.gen_grad[0] = (tf.identity(self.gen_grad[0][0]), self.gen_grad[0][1])
#         with tf.control_dependencies([tf.check_numerics(x[0], 'nan') for x in self.dis_grad]):
#            self.dis_grad[0] = (tf.identity(self.dis_grad[0][0]), self.dis_grad[0][1])
#         with tf.control_dependencies([tf.check_numerics(x[0], 'nan') for x in self.enc_grad]):
#            self.enc_grad[0] = (tf.identity(self.enc_grad[0][0]), self.enc_grad[0][1])
#         self.gen_grad = [(tf.clip_by_value(grad, -1., 1.), var) for grad, var in self.gen_grad if grad is not None]
#         self.dis_grad = [(tf.clip_by_value(grad, -1., 1.), var) for grad, var in self.dis_grad if grad is not None]
#         self.end_grad = [(tf.clip_by_value(grad, -1., 1.), var) for grad, var in self.enc_grad if grad is not None]
         self.gen_grad_norm = tf.global_norm([grad for grad, var in self.gen_grad]) / self.numpara_gen * args.gen_lr
         self.enc_grad_norm = tf.global_norm([grad for grad, var in self.enc_grad]) / self.numpara_enc * args.enc_lr
         self.dis_grad_norm = tf.global_norm([grad for grad, var in self.dis_grad]) / self.numpara_dis * args.dis_lr
#         self.stand_grad_norm = tf.maximum(self.dis_grad_norm, 1e-13)
#         self.norm_ratio = self.stand_grad_norm / (self.dis_grad_norm + 1e-80) * 1
#         self.dis_grad = [(self.norm_ratio * grad, var) for grad, var in self.dis_grad if grad is not None]
#         self.norm_ratio = self.dis_grad_norm / (self.gen_grad_norm + 1e-80) * 1
#         self.gen_grad = [(self.norm_ratio * grad, var) for grad, var in self.gen_grad if grad is not None]
#         self.norm_ratio = self.dis_grad_norm / (self.enc_grad_norm + 1e-80) * 1
#         self.enc_grad = [(self.norm_ratio * grad, var) for grad, var in self.enc_grad if grad is not None]
#         self.gen_grad_norm = tf.global_norm([grad for grad, var in self.gen_grad]) / self.numpara_gen * args.gen_lr
#         self.enc_grad_norm = tf.global_norm([grad for grad, var in self.enc_grad]) / self.numpara_enc * args.enc_lr
#         self.dis_grad_norm = tf.global_norm([grad for grad, var in self.dis_grad]) / self.numpara_dis * args.dis_lr
         self.gen_grad = [(grad / self.fp16_scale, var) for grad, var in self.gen_grad if grad is not None]
         self.dis_grad = [(grad / self.fp16_scale, var) for grad, var in self.dis_grad if grad is not None]
         self.enc_grad = [(grad / self.fp16_scale, var) for grad, var in self.enc_grad if grad is not None]
         self.gen_op = self.gen_optimizer.apply_gradients(self.gen_grad, global_step=self.global_step_gen)
         self.dis_op = self.dis_optimizer.apply_gradients(self.dis_grad, global_step=self.global_step_dis)
         self.enc_op = self.enc_optimizer.apply_gradients(self.enc_grad, global_step=self.global_step_enc)
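      # checkpoint both global and local variables so the full graph state can be restored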
      varset = list(set(tf.global_variables()) | set(tf.local_variables()))
      self.saver = tf.train.Saver(var_list=varset, max_to_keep=8)
      num_batch = int(args.total_examples / self.batch_size)
      do_initialize = True
      if args.loading_path:
         if self.load(args.loading_path):
            start_epoch = int(self.global_step_gen.eval() / num_batch)
            do_initialize = False
      if do_initialize:
         init_op = tf.global_variables_initializer()
         start_epoch = 0
         self.sess.run(init_op)
      self.writer = tf.summary.FileWriter(args.summary_dir, None)
      with tf.name_scope("summaries"):
         self.s_gen_loss = tf.summary.scalar('generator_loss', self.gen_loss)
         self.s_enc_loss = tf.summary.scalar('encoder_loss', self.enc_loss)
         self.s_dis_loss = tf.summary.scalar('discriminator_loss', self.dis_loss)
         self.s_gen_grad = tf.summary.scalar('generator_grad', self.gen_grad_norm)
         self.s_enc_grad = tf.summary.scalar('encoder_grad', self.enc_grad_norm)
         self.s_dis_grad = tf.summary.scalar('discriminator_grad', self.dis_grad_norm)
         self.gen_merged = tf.summary.merge([self.s_gen_loss, self.s_gen_grad])
         self.enc_merged = tf.summary.merge([self.s_enc_loss, self.s_enc_grad])
         self.dis_merged = tf.summary.merge([self.s_dis_loss, self.s_dis_grad])

      self.sample(args.sampling_path, start_epoch, self.sample_num)
      try:
         for epoch in range(start_epoch, args.epoch):
            loss_names = ["Generator Loss",
                          "DisLoss True",
                          "DisLoss False",
                          "Distance Loss"]
            buffers = buff(loss_names)
            for batch in tqdm(range(num_batch)):
               for i in range(args.num_train_gen):
                  input_data = self.data_generator.dequeue()
                  feed_dict = {a: b for a, b in zip(self.placeholders, input_data)}
                  _, gen_loss, gen_sum, gen_step = self.sess.run([self.gen_op,
                                                   self.gen_loss,
                                                   self.gen_merged,
                                                   self.global_step_gen],
                                                   feed_dict=feed_dict)
                  self.gate_add_summary(gen_sum, gen_step)
                  buffers.put([gen_loss], [0])
               for i in range(args.num_train_dis):
                  input_data = self.data_generator.dequeue()
                  feed_dict = {a: b for a, b in zip(self.placeholders, input_data)}
                  _, dis_loss_t, dis_loss_f, dis_sum, dis_step = self.sess.run([self.dis_op,
                                                   self.D_logit_loss_T,
                                                   self.D_logit_loss_F,
                                                   self.dis_merged,
                                                   self.global_step_dis],
                                                   feed_dict=feed_dict)
                  self.gate_add_summary(dis_sum, dis_step)
                  buffers.put([dis_loss_t, dis_loss_f], [1, 2])
               for i in range(args.num_train_enc):
                  input_data = self.data_generator.dequeue()
                  feed_dict = {a: b for a, b in zip(self.placeholders, input_data)}
                  _, dl2, enc_sum, enc_step = self.sess.run([self.enc_op,
                                          self.enc_loss,
                                          self.enc_merged,
                                          self.global_step_enc],
                                          feed_dict=feed_dict)
                  self.gate_add_summary(enc_sum, enc_step)
                  buffers.put([dl2], [3])
               if (batch + 1) % args.display_step == 0:
                  buffers.printout([epoch + 1, batch + 1, num_batch])
                  fd = True
            if (epoch + 1) % args.saving_epoch == 0 and args.saving_path:
               try:
                  self.save(args.saving_path, epoch + 1)
               except:
                  print ("Failed saving model, maybe no space left...")
            if (epoch + 1) % args.sample_epoch == 0 and args.sampling_path:
               self.sample(args.sampling_path, epoch + 1, self.sample_num)
      except KeyboardInterrupt:
         print ("KeyboardInterrupt")
      finally:
         for x in self.procs:
            x.terminate()
Esempio n. 46
0
    def test(self, test_model, part, test_list_fid, test_nii_list_fid):

        test_list = _read_lists(test_list_fid)
        test_nii_list = _read_lists(test_nii_list_fid)
        test_pair_list = zip(test_list, test_nii_list)

        init_glb = tf.global_variables_initializer()
        init_loc = tf.variables_initializer(tf.local_variables())
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True

        with tf.Session(config=config) as sess:
            sess.run([init_glb, init_loc])
            coord = tf.train.Coordinator()

            saver = tf.train.Saver()
            saver.restore(sess, test_model)
            logging.info("test segmenter, model is %s" % test_model)

            dice = []
            for idx_file, pair in enumerate(test_pair_list):
                fid = pair[0]  # this is npz data
                _npz_dict = np.load(fid)
                raw = np.flip(np.flip(_npz_dict['arr_0'], axis=0), axis=1)
                gt_y = np.flip(np.flip(_npz_dict['arr_1'], axis=0), axis=1)
                pred_y = np.zeros(gt_y.shape)

                frame_list = [kk for kk in range(1, raw.shape[2] - 1)]
                np.random.shuffle(frame_list)
                for ii in xrange(
                        int(np.floor(raw.shape[2] // self.net.batch_size))):
                    vol = np.zeros([
                        self.net.batch_size, raw_size[0], raw_size[1],
                        raw_size[2]
                    ])
                    for idx, jj in enumerate(
                            frame_list[ii * self.net.batch_size:(ii + 1) *
                                       self.net.batch_size]):
                        vol[idx, ...] = raw[..., jj - 1:jj + 2].copy()

                    if part == "source":
                        pred = sess.run(self.net.source_pred_compact,
                                        feed_dict={
                                            self.net.source: vol,
                                            self.net.keep_prob: 1.0,
                                            self.net.training_mode_source:
                                            False,
                                            self.net.training_mode_target:
                                            False,
                                        })
                    elif part == "target":
                        pred = sess.run(self.net.target_pred_compact,
                                        feed_dict={
                                            self.net.target: vol,
                                            self.net.keep_prob: 1.0,
                                            self.net.training_mode_source:
                                            False,
                                            self.net.training_mode_target:
                                            False,
                                        })

                    for idx, jj in enumerate(
                            frame_list[ii * self.net.batch_size:(ii + 1) *
                                       self.net.batch_size]):
                        pred_y[..., jj] = pred[idx, ...].copy()

                dice_subject = _eval_dice(gt_y, pred_y)
                dice.append(dice_subject)
                _save_nii(pred_y, gt_y, pair[1], self.output_path)

            print dice

            dice_avg = np.mean(dice, axis=0).tolist()
            dice_std = np.std(dice, axis=0).tolist()

            for cls in xrange(1, self.net.n_class):
                logging.info("%s avg dice is %.4f, std is %.4f" %
                             (class_map[str(cls)], dice_avg[cls - 1],
                              dice_std[cls - 1]))
            logging.info("average dice is: %f" % np.mean(dice_avg))

        return dice_avg
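
_eval_dice and _save_nii are helpers defined elsewhere in this project. As a rough illustration only (an assumption; the real helper may differ), a per-class Dice computation could look like:

import numpy as np

def _eval_dice_sketch(gt_y, pred_y, n_class):
    # Illustrative only: per-class Dice over integer label volumes,
    # skipping the background class 0.
    dice = []
    for cls in range(1, n_class):
        gt = (gt_y == cls).astype(np.float32)
        pr = (pred_y == cls).astype(np.float32)
        dice.append(2.0 * np.sum(gt * pr) / (np.sum(gt) + np.sum(pr) + 1e-7))
    return dice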
Esempio n. 47
0
    'Car': (154, 205, 50),
    'Truck': (255, 215, 0),
    'Van': (255, 20, 147),
    'Tram': (250, 128, 114),
    'Misc': (128, 0, 128),
    'Cyclist': (255, 165, 0),
}
# running network =============================================================
time_dict = {}
saver = tf.train.Saver()
graph = tf.get_default_graph()
gpu_options = tf.GPUOptions(allow_growth=True)
with tf.Session(graph=graph,
                config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
    sess.run(tf.variables_initializer(tf.global_variables()))
    sess.run(tf.variables_initializer(tf.local_variables()))
    model_path = tf.train.latest_checkpoint(CHECKPOINT_PATH)
    print('Restore from checkpoint %s' % model_path)
    saver.restore(sess, model_path)
    previous_step = sess.run(global_step)
    for frame_idx in tqdm(range(0, NUM_TEST_SAMPLE)):
        start_time = time.time()
        if VISUALIZATION_LEVEL == 2:
            pcd = open3d.PointCloud()
            line_set = open3d.LineSet()
            graph_line_set = open3d.LineSet()
        # provide input ======================================================
        cam_rgb_points = dataset.get_cam_points_in_image_with_rgb(
            frame_idx, config['downsample_by_voxel_size'])
        calib = dataset.get_calib(frame_idx)
        image = dataset.get_image(frame_idx)
Esempio n. 48
0
def main(device, input_path_train, input_path_validation, dummy_data,
         downsampling_fact, downsampling_mode, channels, data_format, label_id,
         weights, image_dir, checkpoint_dir, trn_sz, val_sz, loss_type, model,
         decoder, fs_type, optimizer, batch, batchnorm, num_epochs, dtype,
         disable_checkpoints, disable_imsave, tracing, trace_dir,
         output_sampling, scale_factor, intra_threads, inter_threads):
    #init horovod
    comm_rank = 0
    comm_local_rank = 0
    comm_size = 1
    comm_local_size = 1
    if horovod:
        hvd.init()
        comm_rank = hvd.rank()
        comm_local_rank = hvd.local_rank()
        comm_size = hvd.size()
        #not all horovod versions have that implemented
        try:
            comm_local_size = hvd.local_size()
        except:
            comm_local_size = 1
        if comm_rank == 0:
            print("Using distributed computation with Horovod: {} total ranks".
                  format(comm_size, comm_rank))

    #downsampling? recompute image dims
    image_height = image_height_orig // downsampling_fact
    image_width = image_width_orig // downsampling_fact

    #parameters
    per_rank_output = False
    loss_print_interval = 1

    #session config
    sess_config = tf.ConfigProto(
        inter_op_parallelism_threads=inter_threads,  #6
        intra_op_parallelism_threads=intra_threads,  #1
        log_device_placement=False,
        allow_soft_placement=True)
    sess_config.gpu_options.visible_device_list = str(comm_local_rank)
    sess_config.gpu_options.force_gpu_compatible = True

    #get data
    training_graph = tf.Graph()
    if comm_rank == 0:
        print("Loading data...")
    train_files = load_data(input_path_train, True, trn_sz, horovod)
    valid_files = load_data(input_path_validation, False, val_sz, horovod)

    #print some stats
    if comm_rank == 0:
        print("Num workers: {}".format(comm_size))
        print("Local batch size: {}".format(batch))
        if dtype == tf.float32:
            print("Precision: {}".format("FP32"))
        else:
            print("Precision: {}".format("FP16"))
        print("Decoder: {}".format(decoder))
        print("Batch normalization: {}".format(batchnorm))
        print("Channels: {}".format(channels))
        print("Loss type: {}".format(loss_type))
        print("Loss weights: {}".format(weights))
        print("Loss scale factor: {}".format(scale_factor))
        print("Output sampling target: {}".format(output_sampling))
        #print optimizer parameters
        for k, v in optimizer.items():
            print("Solver Parameters: {k}: {v}".format(k=k, v=v))
        print("Num training samples: {}".format(train_files.shape[0]))
        print("Num validation samples: {}".format(valid_files.shape[0]))
        if dummy_data:
            print("Using synthetic dummy data")
        print("Disable checkpoints: {}".format(disable_checkpoints))
        print("Disable image save: {}".format(disable_imsave))

    #compute epochs and stuff:
    if fs_type == "local":
        num_samples = train_files.shape[0] // comm_local_size
    else:
        num_samples = train_files.shape[0] // comm_size
    num_steps_per_epoch = num_samples // batch
    num_steps = num_epochs * num_steps_per_epoch
    if per_rank_output:
        print("Rank {} does {} steps per epoch".format(comm_rank,
                                                       num_steps_per_epoch))

    with training_graph.as_default():

        if dummy_data:
            trn_dataset = create_dummy_dataset(n_samples=trn_sz,
                                               batchsize=batch,
                                               num_epochs=num_epochs,
                                               dtype=dtype)
            val_dataset = create_dummy_dataset(n_samples=val_sz,
                                               batchsize=batch,
                                               num_epochs=1,
                                               dtype=dtype)
        else:
            #create readers
            trn_reader = h5_input_reader(input_path_train,
                                         channels,
                                         weights,
                                         dtype,
                                         normalization_file="stats.h5",
                                         update_on_read=False,
                                         data_format=data_format,
                                         label_id=label_id,
                                         sample_target=output_sampling)
            val_reader = h5_input_reader(input_path_validation,
                                         channels,
                                         weights,
                                         dtype,
                                         normalization_file="stats.h5",
                                         update_on_read=False,
                                         data_format=data_format,
                                         label_id=label_id)
            #create datasets
            if fs_type == "local":
                trn_dataset = create_dataset(trn_reader,
                                             train_files,
                                             batch,
                                             num_epochs,
                                             comm_local_size,
                                             comm_local_rank,
                                             dtype,
                                             shuffle=True)
                val_dataset = create_dataset(val_reader,
                                             valid_files,
                                             batch,
                                             1,
                                             comm_local_size,
                                             comm_local_rank,
                                             dtype,
                                             shuffle=False)
            else:
                trn_dataset = create_dataset(trn_reader,
                                             train_files,
                                             batch,
                                             num_epochs,
                                             comm_size,
                                             comm_rank,
                                             dtype,
                                             shuffle=True)
                val_dataset = create_dataset(val_reader,
                                             valid_files,
                                             batch,
                                             1,
                                             comm_size,
                                             comm_rank,
                                             dtype,
                                             shuffle=False)

        #create a feedable iterator so the same graph can switch between the
        #training and validation datasets through the handle placeholder
        handle = tf.placeholder(tf.string,
                                shape=[],
                                name="iterator-placeholder")
        iterator = tf.data.Iterator.from_string_handle(
            handle, (dtype, tf.int32, dtype, tf.string),
            ((batch, len(channels), image_height_orig,
              image_width_orig) if data_format == "channels_first" else
             (batch, image_height_orig, image_width_orig, len(channels)),
             (batch, image_height_orig, image_width_orig),
             (batch, image_height_orig, image_width_orig), (batch)))
        next_elem = iterator.get_next()

        #if downsampling, do some preprocessing
        if downsampling_fact != 1:
            if downsampling_mode == "scale":
                #scale mode: average-pool the image and weight channels; for the
                #integer labels, randomly pick one pixel per downsampling window
                #(stochastic label pooling)
                rand_select = tf.cast(tf.one_hot(tf.random_uniform(
                    (batch, image_height, image_width),
                    minval=0,
                    maxval=downsampling_fact * downsampling_fact,
                    dtype=tf.int32),
                                                 depth=downsampling_fact *
                                                 downsampling_fact,
                                                 axis=-1),
                                      dtype=tf.int32)
                next_elem = (tf.layers.average_pooling2d(next_elem[0], downsampling_fact, downsampling_fact, 'valid', data_format), \
                             tf.reduce_max(tf.multiply(tf.image.extract_image_patches(tf.expand_dims(next_elem[1], axis=-1), \
                                                                                 [1, downsampling_fact, downsampling_fact, 1], \
                                                                                 [1, downsampling_fact, downsampling_fact, 1], \
                                                                                 [1,1,1,1], 'VALID'), rand_select), axis=-1), \
                             tf.squeeze(tf.layers.average_pooling2d(tf.expand_dims(next_elem[2], axis=-1), downsampling_fact, downsampling_fact, 'valid', "channels_last"), axis=-1), \
                             next_elem[3])
            elif downsampling_mode == "center-crop":
                #some parameters
                length = 1. / float(downsampling_fact)
                offset = length / 2.
                boxes = [[offset, offset, offset + length, offset + length]
                         ] * batch
                box_ind = list(range(0, batch))
                crop_size = [image_height, image_width]

                #be careful with data order
                if data_format == "channels_first":
                    next_elem[0] = tf.transpose(next_elem[0],
                                                perm=[0, 2, 3, 1])

                #crop
                next_elem = (tf.image.crop_and_resize(next_elem[0], boxes, box_ind, crop_size, method='bilinear', extrapolation_value=0, name="data_cropping"), \
                             ensure_type(tf.squeeze(tf.image.crop_and_resize(tf.expand_dims(next_elem[1],axis=-1), boxes, box_ind, crop_size, method='nearest', extrapolation_value=0, name="label_cropping"), axis=-1), tf.int32), \
                             tf.squeeze(tf.image.crop_and_resize(tf.expand_dims(next_elem[2],axis=-1), boxes, box_ind, crop_size, method='bilinear', extrapolation_value=0, name="weight_cropping"), axis=-1), \
                             next_elem[3])

                #be careful with data order
                if data_format == "channels_first":
                    next_elem[0] = tf.transpose(next_elem[0],
                                                perm=[0, 3, 1, 2])

            else:
                raise ValueError(
                    "Error, downsampling mode {} not supported. Supported are [center-crop, scale]"
                    .format(downsampling_mode))

        #create init handles
        #trn
        trn_iterator = trn_dataset.make_initializable_iterator()
        trn_handle_string = trn_iterator.string_handle()
        trn_init_op = iterator.make_initializer(trn_dataset)
        #val
        val_iterator = val_dataset.make_initializable_iterator()
        val_handle_string = val_iterator.string_handle()
        val_init_op = iterator.make_initializer(val_dataset)

        #compute the input filter number based on number of channels used
        num_channels = len(channels)
        #set up model
        model = deeplab_v3_plus_generator(num_classes=3,
                                          output_stride=8,
                                          base_architecture=model,
                                          decoder=decoder,
                                          batchnorm=batchnorm,
                                          pre_trained_model=None,
                                          batch_norm_decay=None,
                                          data_format=data_format)

        logit, prediction = model(next_elem[0], True, dtype)

        #set up loss
        loss = None

        #cast the logits to fp32
        logit = ensure_type(logit, tf.float32)
        if loss_type == "weighted":
            #cast weights to FP32
            w_cast = ensure_type(next_elem[2], tf.float32)
            loss = tf.losses.sparse_softmax_cross_entropy(
                labels=next_elem[1],
                logits=logit,
                weights=w_cast,
                reduction=tf.losses.Reduction.SUM)
            if scale_factor != 1.0:
                loss *= scale_factor

        elif loss_type == "weighted_mean":
            #cast weights to FP32
            w_cast = ensure_type(next_elem[2], tf.float32)
            loss = tf.losses.sparse_softmax_cross_entropy(
                labels=next_elem[1],
                logits=logit,
                weights=w_cast,
                reduction=tf.losses.Reduction.SUM_BY_NONZERO_WEIGHTS)
            if scale_factor != 1.0:
                loss *= scale_factor

        elif loss_type == "focal":
            #one-hot-encode
            labels_one_hot = tf.contrib.layers.one_hot_encoding(
                next_elem[1], 3)
            #cast to FP32
            labels_one_hot = ensure_type(labels_one_hot, tf.float32)
            loss = focal_loss(onehot_labels=labels_one_hot,
                              logits=logit,
                              alpha=1.,
                              gamma=2.)

        else:
            raise ValueError("Error, loss type {} not supported.",
                             format(loss_type))

        #determine flops
        flops = graph_flops.graph_flops(
            format="NHWC" if data_format == "channels_last" else "NCHW",
            verbose=False,
            batch=batch,
            sess_config=sess_config)
        flops *= comm_size
        if comm_rank == 0:
            print('training flops: {:.3f} TF/step'.format(flops * 1e-12))

        #number of trainable parameters
        if comm_rank == 0:
            num_params = get_number_of_trainable_parameters()
            print('number of trainable parameters: {} ({} MB)'.format(
                num_params,
                num_params * (4 if dtype == tf.float32 else 2) * (2**-20)))

        if horovod:
            loss_avg = hvd.allreduce(ensure_type(loss, tf.float32))
        else:
            loss_avg = tf.identity(loss)

        #set up global step - keep on CPU
        with tf.device('/device:CPU:0'):
            global_step = tf.train.get_or_create_global_step()

        #set up optimizer
        if optimizer['opt_type'].startswith("LARC"):
            if comm_rank == 0:
                print("Enabling LARC")
            train_op, lr = get_larc_optimizer(optimizer, loss, global_step,
                                              num_steps_per_epoch, horovod)
        else:
            train_op, lr = get_optimizer(optimizer, loss, global_step,
                                         num_steps_per_epoch, horovod)

        #set up streaming metrics
        iou_op, iou_update_op = tf.metrics.mean_iou(labels=next_elem[1],
                                                    predictions=tf.argmax(
                                                        prediction, axis=3),
                                                    num_classes=3,
                                                    weights=None,
                                                    metrics_collections=None,
                                                    updates_collections=None,
                                                    name="iou_score")
        iou_reset_op = tf.variables_initializer([
            i for i in tf.local_variables() if i.name.startswith('iou_score/')
        ])

        if horovod:
            iou_avg = hvd.allreduce(iou_op)
        else:
            iou_avg = tf.identity(iou_op)

        if "gpu" in device.lower():
            with tf.device(device):
                mem_usage_ops = [
                    tf.contrib.memory_stats.MaxBytesInUse(),
                    tf.contrib.memory_stats.BytesLimit()
                ]
        #hooks
        #these hooks are essential; pad the stop hook with one extra step so the final training step can finish
        hooks = [tf.train.StopAtStepHook(last_step=num_steps + 1)]
        #bcast init for bcasting the model after start
        if horovod:
            init_bcast = hvd.broadcast_global_variables(0)
        #initializers:
        init_op = tf.global_variables_initializer()
        init_local_op = tf.local_variables_initializer()

        #checkpointing
        if comm_rank == 0:
            checkpoint_save_freq = 5 * num_steps_per_epoch
            checkpoint_saver = tf.train.Saver(max_to_keep=1000)
            if (not disable_checkpoints):
                hooks.append(
                    tf.train.CheckpointSaverHook(
                        checkpoint_dir=checkpoint_dir,
                        save_steps=checkpoint_save_freq,
                        saver=checkpoint_saver))
            #create image dir if not exists
            if not os.path.isdir(image_dir):
                os.makedirs(image_dir)

        #tracing
        if tracing is not None:
            import tracehook
            tracing_hook = tracehook.TraceHook(steps_to_trace=tracing,
                                               cache_traces=True,
                                               trace_dir=trace_dir)
            hooks.append(tracing_hook)

        # instead of averaging losses over an entire epoch, use a moving
        #  window average
        recent_losses = []
        loss_window_size = 10
        #start session
        with tf.train.MonitoredTrainingSession(config=sess_config,
                                               hooks=hooks) as sess:
            #initialize
            sess.run([init_op, init_local_op])
            #restore from checkpoint:
            if comm_rank == 0 and not disable_checkpoints:
                load_model(sess, checkpoint_saver, checkpoint_dir)
            #broadcast loaded model variables
            if horovod:
                sess.run(init_bcast)
            #create iterator handles
            trn_handle, val_handle = sess.run(
                [trn_handle_string, val_handle_string])
            #init iterators
            sess.run(trn_init_op, feed_dict={handle: trn_handle})
            sess.run(val_init_op, feed_dict={handle: val_handle})

            # figure out what step we're on (it won't be 0 if we are
            #  restoring from a checkpoint) so we can count from there
            train_steps = sess.run([global_step])[0]

            #do the training
            epoch = 1
            step = 1

            prev_mem_usage = 0
            t_sustained_start = time.time()
            r_peak = 0

            #start training
            start_time = time.time()
            print('Begin training loop')
            while not sess.should_stop():

                #training loop
                try:
                    #construct feed dict
                    t_inst_start = time.time()
                    _, tmp_loss, cur_lr = sess.run(
                        [
                            train_op,
                            (loss if per_rank_output else loss_avg), lr
                        ],
                        feed_dict={handle: trn_handle})
                    t_inst_end = time.time()
                    if "gpu" in device.lower():
                        mem_used = sess.run(mem_usage_ops)
                    else:
                        mem_used = [0, 0]
                    train_steps += 1
                    train_steps_in_epoch = train_steps % num_steps_per_epoch
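                    # prepend the newest loss and keep only the last
                    # loss_window_size values for the moving-window average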
                    recent_losses = [tmp_loss
                                     ] + recent_losses[0:loss_window_size - 1]
                    train_loss = sum(recent_losses) / len(recent_losses)
                    step += 1

                    r_inst = 1e-12 * flops / (t_inst_end - t_inst_start)
                    r_peak = max(r_peak, r_inst)

                    #print step report
                    if (train_steps % loss_print_interval) == 0:
                        if "gpu" in device.lower():
                            mem_used = sess.run(mem_usage_ops)
                        else:
                            mem_used = [0, 0]
                        if per_rank_output:
                            print(
                                "REPORT: rank {}, training loss for step {} (of {}) is {:.5f}, time {:.3f}"
                                .format(comm_rank, train_steps, num_steps,
                                        train_loss,
                                        time.time() - start_time))
                        else:
                            if comm_rank == 0:
                                if mem_used[0] > prev_mem_usage:
                                    print(
                                        "memory usage: {:.2f} GB / {:.2f} GB".
                                        format(mem_used[0] / 2.0**30,
                                               mem_used[1] / 2.0**30))
                                    prev_mem_usage = mem_used[0]
                                print(
                                    "REPORT: training loss for step {} (of {}) is {}, time {:.3f}, r_inst {:.3f}, r_peak {:.3f}, lr {:.2g}"
                                    .format(train_steps, num_steps, train_loss,
                                            time.time() - start_time, r_inst,
                                            r_peak, cur_lr))

                    #do the validation phase
                    if train_steps_in_epoch == 0:
                        end_time = time.time()
                        #print epoch report
                        if per_rank_output:
                            print(
                                "COMPLETED: rank {}, training loss for epoch {} (of {}) is {:.5f}, time {:.3f}, r_sust {:.3f}"
                                .format(
                                    comm_rank, epoch, num_epochs, train_loss,
                                    time.time() - start_time,
                                    1e-12 * flops * num_steps_per_epoch /
                                    (end_time - t_sustained_start)))
                        else:
                            if comm_rank == 0:
                                print(
                                    "COMPLETED: training loss for epoch {} (of {}) is {:.5f}, time {:.3f}, r_sust {:.3f}"
                                    .format(
                                        epoch, num_epochs, train_loss,
                                        time.time() - start_time,
                                        1e-12 * flops * num_steps_per_epoch /
                                        (end_time - t_sustained_start)))

                        #evaluation loop
                        eval_loss = 0.
                        eval_steps = 0
                        while True:
                            try:
                                #construct feed dict
                                _, tmp_loss, val_model_predictions, val_model_labels, val_model_filenames = sess.run(
                                    [
                                        iou_update_op,
                                        (loss
                                         if per_rank_output else loss_avg),
                                        prediction, next_elem[1], next_elem[3]
                                    ],
                                    feed_dict={handle: val_handle})
                                #print some images
                                if comm_rank == 0 and not disable_imsave:
                                    if have_imsave:
                                        imsave(
                                            image_dir + '/test_pred_epoch' +
                                            str(epoch) + '_estep' +
                                            str(eval_steps) + '_rank' +
                                            str(comm_rank) + '.png',
                                            np.argmax(
                                                val_model_predictions[0, ...],
                                                axis=2) * 100)
                                        imsave(
                                            image_dir + '/test_label_epoch' +
                                            str(epoch) + '_estep' +
                                            str(eval_steps) + '_rank' +
                                            str(comm_rank) + '.png',
                                            val_model_labels[0, ...] * 100)
                                        imsave(
                                            image_dir +
                                            '/test_combined_epoch' +
                                            str(epoch) + '_estep' +
                                            str(eval_steps) + '_rank' +
                                            str(comm_rank) + '.png',
                                            plot_colormap[
                                                val_model_labels[0, ...],
                                                np.argmax(
                                                    val_model_predictions[0,
                                                                          ...],
                                                    axis=2)])
                                    else:
                                        np.savez(
                                            image_dir + '/test_epoch' +
                                            str(epoch) + '_estep' +
                                            str(eval_steps) + '_rank' +
                                            str(comm_rank) + '.npz',
                                            prediction=np.argmax(
                                                val_model_predictions[0, ...],
                                                axis=2) * 100,
                                            label=val_model_labels[0, ...] *
                                            100,
                                            filename=val_model_filenames[0])
                                eval_loss += tmp_loss
                                eval_steps += 1
                            except tf.errors.OutOfRangeError:
                                eval_steps = np.max([eval_steps, 1])
                                eval_loss /= eval_steps
                                if per_rank_output:
                                    print(
                                        "COMPLETED: rank {}, evaluation loss for epoch {} (of {}) is {:.5f}"
                                        .format(comm_rank, epoch, num_epochs,
                                                eval_loss))
                                else:
                                    if comm_rank == 0:
                                        print(
                                            "COMPLETED: evaluation loss for epoch {} (of {}) is {:.5f}"
                                            .format(epoch, num_epochs,
                                                    eval_loss))
                                if per_rank_output:
                                    iou_score = sess.run(iou_op)
                                    print(
                                        "COMPLETED: rank {}, evaluation IoU for epoch {} (of {}) is {:.5f}"
                                        .format(comm_rank, epoch, num_epochs,
                                                iou_score))
                                else:
                                    iou_score = sess.run(iou_avg)
                                    if comm_rank == 0:
                                        print(
                                            "COMPLETED: evaluation IoU for epoch {} (of {}) is {:.5f}"
                                            .format(epoch, num_epochs,
                                                    iou_score))
                                sess.run(iou_reset_op)
                                sess.run(val_init_op,
                                         feed_dict={handle: val_handle})
                                break

                        #reset counters
                        epoch += 1
                        step = 0
                        t_sustained_start = time.time()

                except tf.errors.OutOfRangeError:
                    break

        # write any cached traces to disk
        if tracing is not None:
            tracing_hook.write_traces()

    print('All done')
Esempio n. 49
0
    def run(self, episodes=-1, max_timesteps=-1, episode_finished=None, before_execution=None):
        """
        Runs an environment for the specified number of episodes and time steps per episode.
        
        Args:
            episodes: Number of episodes to execute
            max_timesteps: Max timesteps in a given episode
            episode_finished: Optional termination condition, e.g. a particular mean reward threshold
            before_execution: Optional filter function to apply to action before execution

        Returns:

        """
        if self.cluster_spec is not None:
            assert self.task_index is not None
            # Redirect process output
            # sys.stdout = open('tf_worker_' + str(self.task_index) + '.txt', 'w', 0)
            cluster_def = self.cluster_spec.as_cluster_def()

            if self.task_index == -1:
                server = tf.train.Server(
                    server_or_cluster_def=cluster_def,
                    job_name='ps',
                    task_index=0,
                    config=tf.ConfigProto(device_filters=["/job:ps"])
                )
                # Param server does nothing actively
                server.join()
                return

            # Worker creates runner for execution
            server = tf.train.Server(
                server_or_cluster_def=cluster_def,
                job_name='worker',
                task_index=self.task_index,
                config=tf.ConfigProto(
                    intra_op_parallelism_threads=1,
                    inter_op_parallelism_threads=2,
                    log_device_placement=True
                )
            )

            variables_to_save = [v for v in tf.global_variables() if not v.name.startswith('local')]
            init_op = tf.variables_initializer(variables_to_save)
            local_init_op = tf.variables_initializer(tf.local_variables() + [v for v in tf.global_variables() if v.name.startswith('local')])
            init_all_op = tf.global_variables_initializer()

            def init_fn(sess):
                sess.run(init_all_op)

            config = tf.ConfigProto(device_filters=['/job:ps', '/job:worker/task:{}/cpu:0'.format(self.task_index)])

            supervisor = tf.train.Supervisor(
                is_chief=(self.task_index == 0),
                logdir='/tmp/train_logs',
                global_step=self.agent.model.global_step,
                init_op=init_op,
                local_init_op=local_init_op,
                init_fn=init_fn,
                ready_op=tf.report_uninitialized_variables(variables_to_save),
                saver=self.agent.model.saver)
            # summary_op=tf.summary.merge_all(),
            # summary_writer=worker_agent.model.summary_writer)

            # # Connecting to parameter server
            # self.logger.debug('Connecting to session..')
            # self.logger.info('Server target = ' + str(server.target))

            # with supervisor.managed_session(server.target, config=config) as session, session.as_default():
            # self.logger.info('Established session, starting runner..')
            managed_session = supervisor.managed_session(server.target, config=config)
            session = managed_session.__enter__()
            self.agent.model.session = session
            # session.run(self.agent.model.update_local)

        # save episode reward and length for statistics
        self.episode_rewards = []
        self.episode_lengths = []

        self.episode = 1
        while True:
            state = self.environment.reset()
            self.agent.reset()
            episode_reward = 0

            self.timestep = 1
            while True:
                if self.preprocessor:
                    processed_state = self.preprocessor.process(state)
                else:
                    processed_state = state

                action = self.agent.act(state=processed_state)

                if before_execution:
                    action = before_execution(self, action)

                if self.repeat_actions > 1:
                    reward = 0
                    for repeat in xrange(self.repeat_actions):
                        state, step_reward, terminal = self.environment.execute(action=action)
                        reward += step_reward
                        if terminal:
                            break
                else:
                    state, reward, terminal = self.environment.execute(action=action)

                episode_reward += reward
                self.agent.observe(state=processed_state, action=action, reward=reward, terminal=terminal)

                if terminal or self.timestep == max_timesteps:
                    break
                self.timestep += 1

            self.episode_rewards.append(episode_reward)
            self.episode_lengths.append(self.timestep)

            if self.save_path and self.save_episodes is not None and self.episode % self.save_episodes == 0:
                print("Saving agent after episode {}".format(self.episode))
                self.agent.save_model(self.save_path)

            if episode_finished and not episode_finished(self):
                return
            if self.cluster_spec is None:
                if self.episode >= episodes:
                    return
            elif session.run(self.agent.model.global_episode) >= episodes:
                return
            self.episode += 1

        if self.cluster_spec is not None:
            managed_session.__exit__(None, None, None)
            supervisor.stop()
Esempio n. 50
0
def streaming_auc(target, prediction):
    with tf.variable_scope('auc') as scope:
        name = scope.original_name_scope
        auc = tf.metrics.auc(target, prediction, num_thresholds=1000)
    variables = tf.local_variables(name) + tf.global_variables(name)
    return StreamingTensor(auc[0], auc[1], variables)
Esempio n. 51
0
def main():
	# Configure
	config=tf.ConfigProto(log_device_placement=False)

	#Server Setup
	cluster_spec = {'ps':['localhost:2222'],
				'worker':['localhost:2223','localhost:2224']}
	n_pss = len(cluster_spec['ps']) #the number of parameter servers
	n_workers = len(cluster_spec['worker']) #the number of worker nodes
	cluster = tf.train.ClusterSpec(cluster_spec) #allows this node know about all other nodes

	if FLAGS.job_name == 'ps': #checks if parameter server
		server = tf.train.Server(cluster,
					job_name="ps",
					task_index=FLAGS.task_index,
					config=config)
		server.join()
	else: #it must be a worker server
		is_chief = (FLAGS.task_index == 0) #checks if this is the chief node
		server = tf.train.Server(cluster,
					job_name="worker",
					task_index=FLAGS.task_index,
					config=config)
		
		# Graph
		# Local operations
		with tf.device("/job:worker/replica:0/task:%d" % FLAGS.task_index):
			a = tf.Variable(tf.constant(0.,shape=[2]),dtype=tf.float32,
						collections=[tf.GraphKeys.LOCAL_VARIABLES])
			b = tf.Variable(tf.constant(0.,shape=[2]),dtype=tf.float32,
						collections=[tf.GraphKeys.LOCAL_VARIABLES])
			c=a+b

			local_step = tf.Variable(0,dtype=tf.int32,trainable=False,name='local_step',
						collections=['local_non_trainable'])
			lr = .0001
			
			#loptimizer = tf.train.GradientDescentOptimizer(lr*FLAGS.task_index) #local optimizer
			loptimizer = tf.train.AdagradOptimizer(lr) #local optimizer

			target = tf.constant(100.,shape=[2],dtype=tf.float32)
			loss = tf.reduce_mean(tf.square(c-target))

			# DOWNPOUR
			update_window = 3 # T: communication window
			grad_list = [] # the array to store the gradients through the communication window
			for t in range(update_window):
				if t != 0:
					with tf.control_dependencies([opt_local]): #compute gradients only if the local opt was run
						grads, varss = zip(*loptimizer.compute_gradients( \
									loss,var_list=tf.local_variables()))
				else:
					grads, varss = zip(*loptimizer.compute_gradients( \
								loss,var_list=tf.local_variables()))
				grad_list.append(grads) #add gradients to the list
				opt_local = loptimizer.apply_gradients(zip(grads,varss),
							global_step=local_step) #update local parameters

			grads = tf.reduce_sum(grad_list,axis=0) #sum updates before applying globally
			grads = tuple([grads[i]for i in range(len(varss))])

			# add these variables created by local optimizer to local collection
			lopt_vars = add_global_variables_to_local_collection()

			# delete the variables from the global collection
			clear_global_collection()

		with tf.device(tf.train.replica_device_setter(ps_tasks=n_pss,
          worker_device="/job:%s/task:%d" % (FLAGS.job_name,FLAGS.task_index))):
			global_step = tf.Variable(0,dtype=tf.int32,trainable=False,name='global_step')

			# all workers use the same learning rate and it is decided on by the task 0 
			# or maybe the from the graph of the chief worker
			optimizer = tf.train.AdagradOptimizer(lr) #global optimizer

			# create global variables and/or references
			local_to_global, global_to_local = create_global_variables(lopt_vars)
			opt = optimizer.apply_gradients(
						zip(grads,[local_to_global[v] for v in varss])
						,global_step=global_step) #apply the gradients to variables on ps

			# Pull params from global server
			with tf.control_dependencies([opt]):
				assign_locals = assign_global_to_local(global_to_local)

			# Grab global state before training so all workers have same initialization
			grab_global_init = assign_global_to_local(global_to_local)

			# Assigns local values to global ones for chief to execute
			assign_global = assign_local_to_global(local_to_global)

			# Init ops
			init = tf.global_variables_initializer() # for global variables
			init_local = tf.variables_initializer(tf.local_variables() \
						+tf.get_collection('local_non_trainable')) #for local variables

		# Session
		stop_hook = tf.train.StopAtStepHook(last_step=60)
		hooks = [stop_hook]
		scaff = tf.train.Scaffold(init_op=init,local_init_op=[init_local])

		# Monitored Training Session
		sess = tf.train.MonitoredTrainingSession(master=server.target,
					is_chief=is_chief,
					config=config,
					scaffold=scaff,
					hooks=hooks,
					save_checkpoint_secs=1,
					checkpoint_dir='logdir')
		
		if is_chief:
			sess.run(assign_global) #Assigns chief's initial values to ps
			time.sleep(10) #grace period to wait on other workers before starting training

		# Train until hook stops session
		print('Starting training on worker %d'%FLAGS.task_index)
		sess.run(grab_global_init)
		while not sess.should_stop():
			_,_,r,gs,ls = sess.run([opt,assign_locals,c,global_step,local_step])

			print(r,"global step: "+str(gs),"worker: "+str(FLAGS.task_index),"local step: "+str(ls))

			time.sleep(1) # so we can observe training
		print('Done',FLAGS.task_index)

		time.sleep(10) #grace period to wait before closing session
		sess.close()
		print('Session from worker %d closed cleanly'%FLAGS.task_index)
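
Helpers such as create_global_variables, assign_global_to_local and assign_local_to_global are not shown in this snippet. A minimal sketch of the copy-down helper, assuming global_to_local maps each ps-hosted variable to its local mirror:

import tensorflow as tf

def assign_global_to_local(global_to_local):
    # Illustrative sketch only: copy every ps-hosted variable into its local mirror.
    return tf.group(*[local_v.assign(global_v.read_value())
                      for global_v, local_v in global_to_local.items()])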
Esempio n. 52
0
def main(args, train_set, class_num, train_classifier, pre_ckpt, model_def,
         is_augmenter, anchor_file, image_size, output_size, batch_size,
         rand_seed, max_nrof_epochs, init_learning_rate,
         learning_rate_decay_epochs, learning_rate_decay_factor, obj_weight,
         noobj_weight, obj_thresh, iou_thresh, log_dir):
    g = tf.get_default_graph()
    tf.set_random_seed(rand_seed)
    """ import network """
    network = eval(model_def)
    """ generate the dataset """
    # [(0.57273, 0.677385), (1.87446, 2.06253), (3.33843, 5.47434), (7.88282, 3.52778), (9.77052, 9.16828)]
    helper = Helper('data/{}_img.list'.format(train_set),
                    'data/{}_ann.list'.format(train_set), class_num,
                    anchor_file, image_size, output_size)
    helper.set_dataset(batch_size,
                       rand_seed,
                       is_training=(is_augmenter == 'True'))
    next_img, next_label = helper.get_iter()
    """ define the model """
    batch_image = tf.placeholder_with_default(
        next_img,
        shape=[None, image_size[0], image_size[1], 3],
        name='Input_image')
    batch_label = tf.placeholder_with_default(next_label,
                                              shape=[
                                                  None, output_size[0],
                                                  output_size[1],
                                                  len(helper.anchors),
                                                  5 + class_num
                                              ],
                                              name='Input_label')
    training_control = tf.placeholder_with_default(True,
                                                   shape=[],
                                                   name='training_control')
    true_label = tf.identity(batch_label)
    nets, endpoints = network(batch_image,
                              len(helper.anchors),
                              class_num,
                              phase_train=training_control)
    """ reshape the model output """
    pred_label = tf.reshape(nets, [
        -1, output_size[0], output_size[1],
        len(helper.anchors), 5 + class_num
    ],
                            name='predict')
    """ split the label """
    pred_xy = pred_label[..., 0:2]
    pred_wh = pred_label[..., 2:4]
    pred_confidence = pred_label[..., 4:5]
    pred_cls = pred_label[..., 5:]

    pred_xy = tf.nn.sigmoid(pred_xy)
    pred_wh = tf.exp(pred_wh)
    pred_confidence_sigmoid = tf.nn.sigmoid(pred_confidence)

    true_xy = true_label[..., 0:2]
    true_wh = true_label[..., 2:4]
    true_confidence = true_label[..., 4:5]
    true_cls = true_label[..., 5:]

    obj_mask = true_confidence[..., 0] > obj_thresh
    """ calc the noobj mask ~ """
    if train_classifier == 'True':
        noobj_mask = tf.logical_not(obj_mask)
    else:
        noobj_mask = calc_noobj_mask(true_xy,
                                     true_wh,
                                     pred_xy,
                                     pred_wh,
                                     obj_mask,
                                     iou_thresh=iou_thresh,
                                     helper=helper)
    """ define loss """
    xy_loss = tf.reduce_sum(
        tf.square(
            tf.boolean_mask(true_xy, obj_mask) -
            tf.boolean_mask(pred_xy, obj_mask))) / batch_size
    wh_loss = tf.reduce_sum(
        tf.square(
            tf.boolean_mask(true_wh, obj_mask) -
            tf.boolean_mask(pred_wh, obj_mask))) / batch_size
    obj_loss = obj_weight * tf.reduce_sum(
        tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.boolean_mask(true_confidence, obj_mask),
            logits=tf.boolean_mask(pred_confidence, obj_mask))) / batch_size
    noobj_loss = noobj_weight * tf.reduce_sum(
        tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.boolean_mask(true_confidence, noobj_mask),
            logits=tf.boolean_mask(pred_confidence, noobj_mask))) / batch_size
    cls_loss = tf.reduce_sum(
        tf.nn.softmax_cross_entropy_with_logits_v2(
            labels=tf.boolean_mask(true_cls, obj_mask),
            logits=tf.boolean_mask(pred_cls, obj_mask))) / batch_size

    # xy_loss = tf.losses.mean_squared_error(tf.boolean_mask(true_xy, obj_mask), tf.boolean_mask(pred_xy, obj_mask))
    # wh_loss = tf.losses.mean_squared_error(tf.boolean_mask(true_wh, obj_mask), tf.boolean_mask(pred_wh, obj_mask))
    # obj_loss = tf.losses.sigmoid_cross_entropy(multi_class_labels=tf.boolean_mask(true_confidence, obj_mask), logits=tf.boolean_mask(pred_confidence, obj_mask), weights=5.0)
    # noobj_loss = tf.losses.sigmoid_cross_entropy(multi_class_labels=tf.boolean_mask(true_confidence, noobj_mask), logits=tf.boolean_mask(pred_confidence, noobj_mask), weights=.5)
    # cls_loss = tf.losses.softmax_cross_entropy(onehot_labels=tf.boolean_mask(true_cls, obj_mask), logits=tf.boolean_mask(pred_cls, obj_mask))

    if train_classifier == 'True':
        total_loss = obj_loss + noobj_loss + cls_loss
    else:
        total_loss = obj_loss + noobj_loss + cls_loss + xy_loss + wh_loss
    """ define steps """
    global_steps = tf.train.create_global_step()
    """ define learing rate """
    current_learning_rate = tf.train.exponential_decay(
        init_learning_rate,
        global_steps,
        helper.epoch_step // learning_rate_decay_epochs,
        learning_rate_decay_factor,
        staircase=False)
    """ define train_op """
    train_op = slim.learning.create_train_op(
        total_loss, tf.train.AdamOptimizer(current_learning_rate),
        global_steps)
    """ calc the accuracy """
    precision, prec_op = tf.metrics.precision_at_thresholds(
        true_confidence, pred_confidence_sigmoid, [obj_thresh])
    test_precision, test_prec_op = tf.metrics.precision_at_thresholds(
        true_confidence, pred_confidence_sigmoid, [obj_thresh])
    recall, recall_op = tf.metrics.recall_at_thresholds(
        true_confidence, pred_confidence_sigmoid, [obj_thresh])
    test_recall, test_recall_op = tf.metrics.recall_at_thresholds(
        true_confidence, pred_confidence_sigmoid, [obj_thresh])
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        """ must save the bn paramter! """
        var_list = tf.global_variables() + tf.local_variables(
        )  # list(set(tf.trainable_variables() + [g for g in tf.global_variables() if 'moving_' in g.name]))
        saver = tf.train.Saver(var_list)

        # init the model and restore the pre-train weight
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer()
                 )  # NOTE: the accuracy metrics keep their accumulators in local variables
        restore_ckpt(sess, var_list, pre_ckpt)
        # define the log and saver
        subdir = os.path.join(
            log_dir, datetime.strftime(datetime.now(), '%Y%m%d-%H%M%S'))
        train_writer = tf.summary.FileWriter(subdir, graph=sess.graph)
        write_arguments_to_file(args, os.path.join(subdir, 'arguments.txt'))
        tf.summary.scalar('total_loss', total_loss)
        tf.summary.scalar('obj_loss', obj_loss)
        tf.summary.scalar('noobj_loss', noobj_loss)
        tf.summary.scalar('mse_loss', xy_loss + wh_loss)
        tf.summary.scalar('class_loss', cls_loss)
        tf.summary.scalar('learning_rate', current_learning_rate)
        tf.summary.scalar('precision', precision[0])
        tf.summary.scalar('recall', recall[0])
        merged = tf.summary.merge_all()
        t_prec_summary = tf.summary.scalar('test_precision', test_precision[0])
        t_recall_summary = tf.summary.scalar('test_recall', test_recall[0])

        try:
            for i in range(max_nrof_epochs):
                with tqdm(total=helper.epoch_step,
                          bar_format=
                          '{n_fmt}/{total_fmt} |{bar}| {rate_fmt}{postfix}',
                          unit=' batch',
                          dynamic_ncols=True) as t:
                    for j in range(helper.epoch_step):
                        if j % 30 == 0:
                            summary1, summary2, _, _, step_cnt = sess.run(
                                [
                                    t_prec_summary, t_recall_summary,
                                    test_recall_op, test_prec_op, global_steps
                                ],
                                feed_dict={training_control: False})
                            train_writer.add_summary(summary1, step_cnt)
                            train_writer.add_summary(summary2, step_cnt)
                        else:
                            summary, _, total_l, prec, _, _, lr, step_cnt = sess.run(
                                [
                                    merged, train_op, total_loss, precision,
                                    prec_op, recall_op, current_learning_rate,
                                    global_steps
                                ])
                            t.set_postfix(loss='{:<5.3f}'.format(total_l),
                                          prec='{:<4.2f}%'.format(prec[0] *
                                                                  100),
                                          lr='{:f}'.format(lr))
                            train_writer.add_summary(summary, step_cnt)
                        t.update()
            saver.save(sess,
                       save_path=os.path.join(subdir, 'model.ckpt'),
                       global_step=global_steps)
            print('save over')
        except KeyboardInterrupt as e:
            saver.save(sess,
                       save_path=os.path.join(subdir, 'model.ckpt'),
                       global_step=global_steps)
            print('save over')
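
A side note on the "accuracy must init local variable" remark above: the tf.metrics.* ops keep their accumulators (true positives, counts, ...) in the local-variables collection, so they stay uninitialized until tf.local_variables_initializer() runs. A minimal standalone sketch with made-up labels and predictions:

import tensorflow as tf

labels = tf.constant([1., 0., 1., 1.])
preds = tf.constant([0.9, 0.2, 0.4, 0.8])
precision, prec_op = tf.metrics.precision_at_thresholds(labels, preds, [0.5])

# The accumulators live in tf.local_variables(), not tf.global_variables().
print([v.name for v in tf.local_variables()])

with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())  # required before running the metric
    sess.run(prec_op)
    print(sess.run(precision))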
Esempio n. 53
0
 def testLocalVariableNotInAllVariables(self):
     with self.test_session():
         with tf.variable_scope('A'):
             a = tf.contrib.framework.local_variable(0)
             self.assertFalse(a in tf.all_variables())
             self.assertTrue(a in tf.local_variables())
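
For comparison, the same kind of local variable can be created without tf.contrib by passing the LOCAL_VARIABLES collection explicitly; a tiny sketch with an illustrative name:

import tensorflow as tf

a = tf.Variable(0, collections=[tf.GraphKeys.LOCAL_VARIABLES],
                trainable=False, name='a_local')
print(a in tf.local_variables())   # True
print(a in tf.global_variables())  # False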
Esempio n. 54
0
 def __init__(self,
              experiment,
              aggregator,
              dev_tuples,
              optimizer,
              optimizer_args,
              learning_rate,
              learning_rate_args,
              regularizations=(-1., -1.),
              trace=False):
     """ Full graph (training + evaluation) constructor.
 Args:
   experiment         Experiment instance to use
   aggregator         Aggregator instance to use
   dev_tuples         Tuple of devices (i.e. tuples of strings (job name, task ID, device type, device ID)) for (parameter server, each workers' inference/loss/gradient computation, evaluator)
   optimizer          Optimizer name to use
   optimizer_args     Additional optimizer key-value arguments
   learning_rate      Learning rate name to use
   learning_rate_args Additional learning rate key-value arguments
   regularizations    Pair of (l1, l2) regularization values, non-positive values for no regularization
   trace              Whether to add trace prints for every important step of the computations
 """
     # Tuple extraction and device name reconstruction
     ps_tuple, wk_tuples, ev_tuple = dev_tuples
     ps_device = tools.device_from_tuple(*ps_tuple)
     wk_jobs = {}  # Map job -> taskid -> list of pairs of (devtype, devid)
     for job, taskid, devtype, devid in wk_tuples:
         if job in wk_jobs:
             wk_tasks = wk_jobs[job]
             if taskid in wk_tasks:
                 wk_tasks[taskid].append((devtype, devid))
             else:
                 wk_tasks[taskid] = [(devtype, devid)]
         else:
             wk_jobs[job] = {taskid: [(devtype, devid)]}
     # Graph building
     graph = tf.Graph()
     with graph.as_default():
         with tf.name_scope("ps/"):
             with tf.device(ps_device):
                 # Instantiate global step counter, optimizer and learning rate
                 global_step = tf.train.create_global_step()
                 learning_rate = build(learning_rates,
                                       "learning rate decay",
                                       learning_rate,
                                       learning_rate_args,
                                       global_step=global_step)
                 optimizer = build(optimizers,
                                   "optimizer",
                                   optimizer,
                                   optimizer_args,
                                   learning_rate=learning_rate)
                 tf.summary.scalar("learning_rate", learning_rate)
                 # Create workers' gradient computation
                 totlosses = [
                 ]  # List of losses, for summary (and printing) only
                 gradients = [
                 ]  # List of gradients, one per non-Byzantine worker
                 flatmap = None  # Flat map used to flatten the gradients coherently
                 with tf.name_scope("workers/"):
                     for job, wk_tasks in wk_jobs.items():
                         for taskid, models in wk_tasks.items():
                             device_dataset = tools.device_from_tuple(
                                 job, taskid, "CPU", "*")
                             device_models = [
                                 replica_device_setter(
                                     ps_device,
                                     tools.device_from_tuple(
                                         job, taskid, devtype, devid))
                                 for devtype, devid in models
                             ]
                             # Compute losses
                             losses = experiment.losses(device_dataset,
                                                        device_models,
                                                        trace=trace)
                             totlosses += losses
                             # Compute gradients
                             for i in range(len(device_models)):
                                 with tf.device(device_models[i]):
                                     loss = losses[i]
                                     for norm in [1, 2]:
                                         strength = regularizations[
                                             norm -
                                             1]  # 'norm - 1' is just a basic numbering trick...
                                         if strength > 0.:
                                             loss = loss + strength * regularization(
                                                 norm)
                                     if trace:
                                         loss = tools.trace_graph(
                                             loss, "Worker " +
                                             str(len(gradients)) +
                                             ": loss computation")
                                     grad_vars = optimizer.compute_gradients(
                                         loss)
                                     if flatmap is None:
                                         gradient, flatmap = flatten(
                                             grad_vars)
                                     else:
                                         gradient = flatten(
                                             grad_vars, flatmap)
                                     if trace:
                                         gradient = tools.trace_graph(
                                             gradient, "Worker " +
                                             str(len(gradients)) +
                                             ": gradient computation")
                                     gradients.append(gradient)
                 total_loss = tf.add_n(totlosses, name="total_loss")
                 tools.info(
                     "Created workers' dataset, inference, loss and gradient computation nodes"
                 )
                 # Aggregate and apply the workers' gradients
                 with tf.name_scope("GAR"):
                     time1 = time.time()
                     aggregated = aggregator.aggregate(gradients)
                     time2 = time.time()
                     #print("ms=$$$$$$$$$$$$$$$$$$$$$$",(time2-time1)*1000)
                     if trace:
                         aggregated = tools.trace_graph(
                             aggregated,
                             "Master: aggregated gradient computation")
                 apply_op = optimizer.apply_gradients(
                     inflate(aggregated, mapflat(flatmap)),
                     global_step=global_step)
                 if trace:
                     apply_op = tools.trace_graph(
                         apply_op,
                         "Master: aggregated gradient application")
                 tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, apply_op)
                 tools.info(
                     "Created parameter server's gradient aggregation and application nodes"
                 )
                 # Create accuracy computation
                 with tf.name_scope("eval/"):
                     device_dataset = tools.device_from_tuple(
                         ev_tuple[0], ev_tuple[1], "CPU", "*")
                     device_model = tools.device_from_tuple(*ev_tuple)
                     accuracy_tns = experiment.accuracy(
                         device_dataset,
                         [replica_device_setter(ps_device, device_model)],
                         trace=trace)
                 for key, val in accuracy_tns.items():
                     tf.add_to_collection(
                         tf.GraphKeys.SUMMARIES,
                         tf.summary.scalar("eval-" + key, val))
                 tools.info(
                     "Created evaluator's dataset, inference and accuracy computation nodes"
                 )
                 # Global summary protocol buffer
                 summary_tn = tf.summary.merge(
                     list(set(tf.get_collection(tf.GraphKeys.SUMMARIES))))
                 # Full initialization operation
                 rsrc_init_ops = []
                 for resource in tf.get_collection(tf.GraphKeys.RESOURCES):
                     rsrc_init_ops.append(resource.initializer)
                 for resource in tf.get_collection(
                         tf.GraphKeys.LOCAL_RESOURCES):
                     rsrc_init_ops.append(resource.initializer)
                 init_op = tf.group(
                     tf.variables_initializer(tf.global_variables() +
                                              tf.local_variables()),
                     tf.tables_initializer(), *rsrc_init_ops)
                 # Build the training operation
                 with tf.control_dependencies(
                         tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
                     train_tn = tf.identity(total_loss, name="train_tn")
     # Finalization
     self.graph = graph
     self.step = global_step
     self.rate = learning_rate
     self.optimizer = optimizer
     self.total_loss = total_loss
     self.summary_tn = summary_tn
     self.init_op = init_op
     self.train_tn = train_tn
     self.eval_tns = accuracy_tns
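
The full initialization op assembled near the end of the constructor (variables + tables + registered resources) can be reproduced on its own; a minimal sketch with assumed graph contents:

import tensorflow as tf

v_global = tf.Variable(1.0, name='v_global')
v_local = tf.Variable(2.0, collections=[tf.GraphKeys.LOCAL_VARIABLES], name='v_local')

# Collect per-resource initializers, then group everything into one init op.
rsrc_init_ops = [r.initializer
                 for r in tf.get_collection(tf.GraphKeys.RESOURCES) +
                 tf.get_collection(tf.GraphKeys.LOCAL_RESOURCES)]
init_op = tf.group(
    tf.variables_initializer(tf.global_variables() + tf.local_variables()),
    tf.tables_initializer(),
    *rsrc_init_ops)

with tf.Session() as sess:
    sess.run(init_op)
    print(sess.run([v_global, v_local]))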
Esempio n. 55
0
def train_net(Net,
              training_data,
              base_lr,
              loss_weight,
              train_mode,
              num_epochs=[1, None, None],
              load_model=False,
              load_filename=None,
              save_model=False,
              save_filename=None,
              num_iter_to_save=10000,
              gpu_memory_fraction=1):
    images = []
    labels = []
    tasks = ['cls', 'bbx', 'pts']
    shape = 12
    if Net.__name__ == 'RNet':
        shape = 24
    elif Net.__name__ == 'ONet':
        shape = 48
    for index in range(train_mode):
        image, label = inputs(filename=training_data[index],
                              batch_size=FLAGS.batch_size,
                              num_epochs=num_epochs[index],
                              label_type=tasks[index],
                              shape=shape)
        images.append(image)
        labels.append(label)
    while len(images) != 3:
        images.append(tf.placeholder(tf.float32, [None, shape, shape, 3]))
        labels.append(tf.placeholder(tf.float32))
    net = Net((('cls', images[0]), ('bbx', images[1]), ('pts', images[2])))

    print('all trainable variables:')
    all_vars = tf.get_collection(key=tf.GraphKeys.TRAINABLE_VARIABLES)
    for var in all_vars:
        print(var)

    print('all local variables:')
    local_variables = tf.local_variables()
    for l_v in local_variables:
        print(l_v.name)

    prefix = str(all_vars[0].name[0:5])
    out_put = net.get_all_output()
    cls_output = tf.reshape(out_put[0], [-1, 2])
    bbx_output = tf.reshape(out_put[1], [-1, 4])
    pts_output = tf.reshape(out_put[2], [-1, 10])

    # cls loss
    softmax_loss = loss_weight[0]* \
                   tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=labels[0],
                                                                          logits=cls_output))
    weight_losses_cls = net.get_weight_decay()['cls']
    losses_cls = softmax_loss + tf.add_n(weight_losses_cls)

    # bbx loss
    square_bbx_loss = loss_weight[1]* \
                      tf.reduce_mean(tf.squared_difference(bbx_output, labels[1]))
    weight_losses_bbx = net.get_weight_decay()['bbx']
    losses_bbx = square_bbx_loss + tf.add_n(weight_losses_bbx)

    # pts loss
    square_pts_loss = loss_weight[2]* \
                      tf.reduce_mean(tf.squared_difference(pts_output, labels[2]))
    weight_losses_pts = net.get_weight_decay()['pts']
    losses_pts = square_pts_loss + tf.add_n(weight_losses_pts)

    global_step_cls = tf.Variable(1, name='global_step_cls', trainable=False)
    global_step_bbx = tf.Variable(1, name='global_step_bbx', trainable=False)
    global_step_pts = tf.Variable(1, name='global_step_pts', trainable=False)

    train_cls = tf.train.AdamOptimizer(learning_rate=base_lr).minimize(
        losses_cls, global_step=global_step_cls)
    train_bbx = tf.train.AdamOptimizer(learning_rate=base_lr).minimize(
        losses_bbx, global_step=global_step_bbx)
    train_pts = tf.train.AdamOptimizer(learning_rate=base_lr).minimize(
        losses_pts, global_step=global_step_pts)

    # summary_writer = tf.summary.FileWriter('./tensorflow_logs', graph=tf.get_default_graph())

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())

    config = tf.ConfigProto()
    config.allow_soft_placement = True
    config.gpu_options.per_process_gpu_memory_fraction = gpu_memory_fraction
    config.gpu_options.allow_growth = True

    loss_agg_cls = [0]
    loss_agg_bbx = [0]
    loss_agg_pts = [0]
    step_value = [1, 1, 1]
    with tf.Session(config=config) as sess:
        sess.run(init_op)
        saver = tf.train.Saver(max_to_keep=200000)
        if load_model == 2:
            saver.restore(sess, load_filename)
        elif load_model == 1:
            net.load(load_filename, sess, prefix)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        start_time = time.time()
        try:
            while not coord.should_stop():
                choic = np.random.randint(0, train_mode)
                if choic == 0:
                    _, loss_value_cls, step_value[0] = sess.run(
                        [train_cls, softmax_loss, global_step_cls])
                    loss_agg_cls.append(loss_value_cls)
                elif choic == 1:
                    _, loss_value_bbx, step_value[1] = sess.run(
                        [train_bbx, square_bbx_loss, global_step_bbx])
                    loss_agg_bbx.append(loss_value_bbx)
                else:
                    _, loss_value_pts, step_value[2] = sess.run(
                        [train_pts, square_pts_loss, global_step_pts])
                    loss_agg_pts.append(loss_value_pts)

                if sum(step_value) % (100 * train_mode) == 0:
                    agg_cls = sum(loss_agg_cls) / len(loss_agg_cls)
                    agg_bbx = sum(loss_agg_bbx) / len(loss_agg_bbx)
                    agg_pts = sum(loss_agg_pts) / len(loss_agg_pts)
                    print('Step %d for cls: loss = %.5f' %
                          (step_value[0], agg_cls),
                          end='. ')
                    print('Step %d for bbx: loss = %.5f' %
                          (step_value[1], agg_bbx),
                          end='. ')
                    print('Step %d for pts: loss = %.5f' %
                          (step_value[2], agg_pts))
                    loss_agg_cls = [0]
                    loss_agg_bbx = [0]
                    loss_agg_pts = [0]
                if step_value[0] > 600:
                    break

                if save_model and (step_value[0] % num_iter_to_save == 0):
                    saver.save(sess, save_filename, global_step=step_value[0])

        except tf.errors.OutOfRangeError:
            print('Done training for %d epochs, %d steps.' %
                  (num_epochs[0], step_value[0]))
        finally:
            coord.request_stop()

        coord.join(threads)
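
The input side of the snippet above relies on the classic Coordinator/queue-runner loop; a minimal standalone sketch (a toy in-memory producer instead of the real TFRecord inputs) of starting the runners and exiting on OutOfRangeError once the epoch limit is reached:

import tensorflow as tf

queue = tf.train.input_producer(tf.constant([1., 2., 3.]), num_epochs=1)
value = queue.dequeue()

with tf.Session() as sess:
    # The num_epochs counter is a local variable, hence the local initializer.
    sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        while not coord.should_stop():
            print(sess.run(value))
    except tf.errors.OutOfRangeError:
        print('Done -- epoch limit reached')
    finally:
        coord.request_stop()
    coord.join(threads)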
Esempio n. 56
0
def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)

  processors = {
      "autocomplete": AutoCompleteProcessor,
  }

  tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case,
                                                FLAGS.init_checkpoint)

  if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict and not FLAGS.do_data:
    raise ValueError(
        "At least one of `do_data`, `do_train`, `do_eval` or `do_predict' must be True.")

  bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

  if FLAGS.max_seq_length > bert_config.max_position_embeddings:
    raise ValueError(
        "Cannot use sequence length %d because the BERT model "
        "was only trained up to sequence length %d" %
        (FLAGS.max_seq_length, bert_config.max_position_embeddings))

  tf.gfile.MakeDirs(FLAGS.output_dir)

  task_name = FLAGS.task_name.lower()

  if task_name not in processors:
    raise ValueError("Task not found: %s" % (task_name))

  processor = processors[task_name]()

  label_list = processor.get_labels()

  tokenizer = tokenization.FullTokenizer(
      vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)

  train_examples_sz = processor.train_sz
  num_train_steps = int(
      train_examples_sz / FLAGS.batch_size * FLAGS.num_train_epochs)
  num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

  model_fn = model_fn_builder(
      bert_config=bert_config,
      num_labels=len(label_list),
      init_checkpoint=FLAGS.init_checkpoint,
      learning_rate=FLAGS.learning_rate,
      num_train_steps=num_train_steps,
      num_warmup_steps=num_warmup_steps,
      use_tpu=False,
      use_one_hot_embeddings=False)

  estimator = tf.estimator.Estimator(model_fn=model_fn)
  
  if FLAGS.do_data:
    file_based_convert_examples_to_features(
        processor, label_list, FLAGS.max_seq_length, tokenizer, FLAGS.output_dir)
    reader = tf.TFRecordReader()
    filename_queue = tf.train.string_input_producer([os.path.join(FLAGS.output_dir, "train.tf_record")], num_epochs=1, shuffle=True)
    with tf.Session() as sess:
      sess.run(
          tf.variables_initializer(
              tf.global_variables() + tf.local_variables()
          )
      )
      # Start queue runners
      coord = tf.train.Coordinator()
      threads = tf.train.start_queue_runners(sess=sess, coord=coord)
      key, record_string = reader.read(filename_queue)
      record_string = sess.run(record_string)
      data_record = _extract_examples(record_string, FLAGS.batch_size, FLAGS.max_seq_length)
      examples = sess.run(data_record)
      for k, f in examples.items():
        if hasattr(f, "shape"):
          print("feature {} has content {} with shape {}".format(k, f, f.shape))
        else:
          print("feature {} has content {}".format(k, f))

  if FLAGS.do_train:
    train_file = os.path.join(FLAGS.output_dir, "train.tf_record")
    tf.logging.info("***** Running training *****")
    tf.logging.info("  Num examples = %d", train_examples_sz)
    tf.logging.info("  Batch size = %d", FLAGS.batch_size)
    tf.logging.info("  Num steps = %d", num_train_steps)
    train_input_fn = file_based_input_fn_builder(
        input_file=train_file,
        batch_size=FLAGS.batch_size,
        seq_length=FLAGS.max_seq_length,
        is_training=True,
        drop_remainder=True)
    estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)

  if FLAGS.do_eval:
    num_actual_eval_examples = processor.dev_sz
    eval_file = os.path.join(FLAGS.output_dir, "eval.tf_record")

    tf.logging.info("***** Running evaluation *****")
    tf.logging.info("  Batch size = %d", FLAGS.batch_size)

    # This tells the estimator to run through the entire set.
    eval_steps = None
    # However, if running eval on the TPU, you will need to specify the number of steps.
    eval_drop_remainder = False
    eval_input_fn = file_based_input_fn_builder(
        input_file=eval_file,
        batch_size = FLAGS.batch_size,
        seq_length=FLAGS.max_seq_length,
        is_training=False,
        drop_remainder=eval_drop_remainder)

    result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps)
    results = [res for res in estimator.predict(input_fn=eval_input_fn)]
    processor.evaluate(results, processor.devs)
    output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
    with tf.gfile.GFile(output_eval_file, "w") as writer:
      tf.logging.info("***** Eval results *****")
      for key in sorted(result.keys()):
        tf.logging.info("  %s = %s", key, str(result[key]))
        writer.write("%s = %s\n" % (key, str(result[key])))

  if FLAGS.do_predict:
    predict_examples, num_actual_predict_examples = processor.get_test_examples(FLAGS.data_dir)
    predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record")

    tf.logging.info("***** Running prediction*****")
    tf.logging.info("  Batch size = %d", FLAGS.batch_size)

    predict_drop_remainder = False
    predict_input_fn = file_based_input_fn_builder(
        input_file=predict_file,
        batch_size = FLAGS.batch_size,
        seq_length=FLAGS.max_seq_length,
        is_training=False,
        drop_remainder=predict_drop_remainder)

    result = estimator.predict(input_fn=predict_input_fn)
    processor.evaluate([i for i in result], data = processor.tests)
    output_predict_file = os.path.join(FLAGS.output_dir, "test_results.tsv")
    results = []
    with tf.gfile.GFile(output_predict_file, "w") as writer:
      num_written_lines = 0
      tf.logging.info("***** Predict results *****")
      for (i, prediction) in enumerate(result):
        results.append(prediction)
        probabilities = prediction["probabilities"]
        if i >= num_actual_predict_examples:
          break
        output_line = "\t".join(
            str(class_probability)
            for class_probability in probabilities) + "\n"
        writer.write(output_line)
        num_written_lines += 1
    processor.evaluate(results, processor.tests)
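
In the do_data branch above, the initializer spans tf.local_variables() because string_input_producer(..., num_epochs=1) keeps its epoch counter as a local variable; a stripped-down sketch of the same reader setup (the file path is a placeholder):

import tensorflow as tf

filename_queue = tf.train.string_input_producer(
    ['train.tf_record'], num_epochs=1, shuffle=True)  # placeholder path
reader = tf.TFRecordReader()
key, record_string = reader.read(filename_queue)

with tf.Session() as sess:
    # Without the local-variable init, the epoch counter would stay uninitialized.
    sess.run(tf.variables_initializer(tf.global_variables() + tf.local_variables()))
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    print(sess.run(key))
    coord.request_stop()
    coord.join(threads)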
Esempio n. 57
0
 def testLocalVariableNotInVariablesToRestore(self):
   with self.test_session():
     with tf.variable_scope('A'):
       a = tf.contrib.framework.local_variable(0)
       self.assertFalse(a in tf.contrib.framework.get_variables_to_restore())
       self.assertTrue(a in tf.local_variables())
Esempio n. 58
0
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--filelist',
                        '-t',
                        help='Path to training set ground truth (.txt)',
                        required=True)
    parser.add_argument('--filelist_val',
                        '-v',
                        help='Path to validation set ground truth (.txt)',
                        required=True)
    parser.add_argument('--load_ckpt',
                        '-l',
                        help='Path to a check point file for load')
    parser.add_argument(
        '--save_folder',
        '-s',
        help='Path to folder for saving check points and summary',
        required=True)
    parser.add_argument('--model', '-m', help='Model to use', required=True)
    parser.add_argument('--setting',
                        '-x',
                        help='Setting to use',
                        required=True)
    parser.add_argument(
        '--epochs',
        help='Number of training epochs (default defined in setting)',
        type=int)
    parser.add_argument('--batch_size',
                        help='Batch size (default defined in setting)',
                        type=int)
    parser.add_argument(
        '--log',
        help=
        'Log to FILE in save folder; use - for stdout (default is log.txt)',
        metavar='FILE',
        default='log.txt')
    parser.add_argument('--no_timestamp_folder',
                        help='Do not save to timestamp folder',
                        action='store_true')
    parser.add_argument('--no_code_backup',
                        help='Do not backup code',
                        action='store_true')
    args = parser.parse_args()

    root_folder = args.save_folder
    if not os.path.exists(root_folder):
        os.makedirs(root_folder)

    if args.log != '-':
        sys.stdout = open(os.path.join(root_folder, args.log), 'w')

    print('PID:', os.getpid())

    print(args)

    print(os.path.join(os.path.dirname(__file__), 'option'))
    sys.path.append(os.path.join(os.path.dirname(__file__), 'option'))
    model = importlib.import_module(args.model)
    setting_path = os.path.join(os.path.dirname(__file__), args.model)
    sys.path.append(setting_path)
    setting = importlib.import_module(args.setting)

    num_epochs = args.epochs or setting.num_epochs
    batch_size = args.batch_size or setting.batch_size
    sample_num = setting.sample_num
    sample_num_extra = setting.sample_num_extra
    step_val = setting.step_val
    label_weights_list = setting.label_weights
    rotation_range = setting.rotation_range
    rotation_range_val = setting.rotation_range_val
    scaling_range = setting.scaling_range
    scaling_range_val = setting.scaling_range_val
    jitter = setting.jitter
    jitter_val = setting.jitter_val

    # Prepare inputs
    print('{}-Preparing datasets...'.format(datetime.now()))
    is_list_of_h5_list = not data_utils.is_h5_list(args.filelist)
    if is_list_of_h5_list:
        seg_list = data_utils.load_seg_list(args.filelist)
        seg_list_idx = 0
        filelist_train = seg_list[seg_list_idx]
        seg_list_idx = seg_list_idx + 1
    else:
        filelist_train = args.filelist
    data_train, _, data_num_train, label_train, _, direc_train, \
    data_train_extra, data_num_train_extra, label_train_extra, direc_train_extra = data_utils.load_seg(filelist_train)
    data_val, _, data_num_val, label_val, _, direc_val,\
    data_val_extra, data_num_val_extra, label_val_extra, direc_val_extra = data_utils.load_seg(args.filelist_val)

    # shuffle
    data_train, data_num_train, label_train, direc_train,\
    data_train_extra, data_num_train_extra, label_train_extra, direc_train_extra = \
        data_utils.grouped_shuffle([data_train, data_num_train, label_train, direc_train,
                                    data_train_extra, data_num_train_extra, label_train_extra, direc_train_extra])

    num_train = data_train.shape[0]
    point_num = data_train.shape[1]
    point_num_extra = data_train_extra.shape[1]
    num_val = data_val.shape[0]
    print('{}-{:d}/{:d} training/validation samples.'.format(
        datetime.now(), num_train, num_val))
    batch_num = (num_train * num_epochs + batch_size - 1) // batch_size
    print('{}-{:d} training batches.'.format(datetime.now(), batch_num))
    batch_num_val = math.ceil(num_val / batch_size)
    print('{}-{:d} testing batches per test.'.format(datetime.now(),
                                                     batch_num_val))

    ######################################################################
    # Placeholders
    indices = tf.placeholder(tf.int32, shape=(None, None, 2), name="indices")
    indices_extra = tf.placeholder(tf.int32,
                                   shape=(None, None, 2),
                                   name="indices_extra")
    xforms = tf.placeholder(tf.float32, shape=(None, 3, 3), name="xforms")
    rotations = tf.placeholder(tf.float32,
                               shape=(None, 3, 3),
                               name="rotations")
    jitter_range = tf.placeholder(tf.float32, shape=(1), name="jitter_range")
    global_step = tf.Variable(0, trainable=False, name='global_step')
    is_training = tf.placeholder(tf.bool, name='is_training')

    sampled_indices = tf.placeholder(tf.int32,
                                     shape=(None, None, 2),
                                     name="sampled_indices")

    pts_fts = tf.placeholder(tf.float32,
                             shape=(None, point_num, setting.data_dim),
                             name='pts_fts')
    direc_pl = tf.placeholder(tf.float32,
                              shape=(None, point_num, 6),
                              name='direction')
    labels_seg = tf.placeholder(tf.int64,
                                shape=(None, point_num),
                                name='labels_seg')
    labels_weights = tf.placeholder(tf.float32,
                                    shape=(None, point_num),
                                    name='labels_weights')

    pts_fts_extra = tf.placeholder(tf.float32,
                                   shape=(None, point_num_extra,
                                          setting.data_dim),
                                   name='pts_fts')
    direc_pl_extra = tf.placeholder(tf.float32,
                                    shape=(None, point_num_extra, 6),
                                    name='direction')
    labels_seg_extra = tf.placeholder(tf.int64,
                                      shape=(None, point_num_extra),
                                      name='labels_seg')

    ######################################################################
    pts_fts_sampled = tf.gather_nd(pts_fts,
                                   indices=indices,
                                   name='pts_fts_sampled')
    pts_fts_extra_sampled = tf.gather_nd(pts_fts_extra,
                                         indices=indices_extra,
                                         name='pts_fts_extra_sampled')
    pts_fts_sampled = tf.concat([pts_fts_sampled, pts_fts_extra_sampled],
                                axis=1)
    features_augmented = None
    if setting.data_dim > 3:
        points_sampled, features_sampled = tf.split(
            pts_fts_sampled, [3, setting.data_dim - 3],
            axis=-1,
            name='split_points_features')
        if setting.use_extra_features:
            if setting.with_normal_feature:
                if setting.data_dim < 6:
                    print('Only 3D normals are supported!')
                    exit()
                elif setting.data_dim == 6:
                    features_augmented = pf.augment(features_sampled,
                                                    rotations)
                else:
                    normals, rest = tf.split(features_sampled,
                                             [3, setting.data_dim - 6])
                    normals_augmented = pf.augment(normals, rotations)
                    features_augmented = tf.concat([normals_augmented, rest],
                                                   axis=-1)
            else:
                features_augmented = features_sampled
    else:
        points_sampled = pts_fts_sampled
    points_augmented = pf.augment(points_sampled, xforms, jitter_range)

    dir_sampled = tf.gather_nd(direc_pl,
                               indices=indices,
                               name='direction_sampled')
    dir_sampled_extra = tf.gather_nd(direc_pl_extra,
                                     indices=indices_extra,
                                     name='direction_sampled_extra')
    dir_sampled = tf.concat([dir_sampled, dir_sampled_extra], axis=1)
    dir_augmented = pf.dir_augment(dir_sampled, rotations)

    labels_sampled = tf.gather_nd(labels_seg,
                                  indices=indices,
                                  name='labels_sampled')
    labels_sampled_extra = tf.gather_nd(labels_seg_extra,
                                        indices=indices_extra,
                                        name='labels_sampled_extra')
    labels_weights_sampled = tf.gather_nd(labels_weights,
                                          indices=indices,
                                          name='labels_weight_sampled')

    bn_exp_op = tf.train.exponential_decay(0.5,
                                           global_step,
                                           setting.decay_steps,
                                           0.5,
                                           staircase=True)
    bn_clip_op = tf.minimum(1 - bn_exp_op, 0.99)

    net = model.get_model(tf.concat([points_augmented, features_augmented],
                                    axis=-1),
                          dir_augmented,
                          is_training,
                          bn_decay=bn_clip_op)
    logits = tf.gather_nd(net, indices=sampled_indices)
    probs = tf.nn.softmax(logits, name='probs')
    predictions = tf.argmax(probs, axis=-1, name='predictions')

    loss_op = tf.losses.sparse_softmax_cross_entropy(
        labels=labels_sampled, logits=logits, weights=labels_weights_sampled)

    with tf.name_scope('metrics'):
        loss_mean_op, loss_mean_update_op = tf.metrics.mean(loss_op)
        t_1_acc_op, t_1_acc_update_op = tf.metrics.accuracy(
            labels_sampled, predictions, weights=labels_weights_sampled)
        t_1_per_class_acc_op, t_1_per_class_acc_update_op = \
            tf.metrics.mean_per_class_accuracy(labels_sampled, predictions, setting.num_class,
                                               weights=labels_weights_sampled)
    reset_metrics_op = tf.variables_initializer([
        var for var in tf.local_variables()
        if var.name.split('/')[0] == 'metrics'
    ])

    _ = tf.summary.scalar('loss/train',
                          tensor=loss_mean_op,
                          collections=['train'])
    _ = tf.summary.scalar('t_1_acc/train',
                          tensor=t_1_acc_op,
                          collections=['train'])
    _ = tf.summary.scalar('t_1_per_class_acc/train',
                          tensor=t_1_per_class_acc_op,
                          collections=['train'])

    _ = tf.summary.scalar('loss/val', tensor=loss_mean_op, collections=['val'])
    _ = tf.summary.scalar('t_1_acc/val',
                          tensor=t_1_acc_op,
                          collections=['val'])
    _ = tf.summary.scalar('t_1_per_class_acc/val',
                          tensor=t_1_per_class_acc_op,
                          collections=['val'])

    lr_exp_op = tf.train.exponential_decay(setting.learning_rate_base,
                                           global_step,
                                           setting.decay_steps,
                                           setting.decay_rate,
                                           staircase=True)
    lr_clip_op = tf.maximum(lr_exp_op, setting.learning_rate_min)
    _ = tf.summary.scalar('learning_rate',
                          tensor=lr_clip_op,
                          collections=['train'])
    reg_loss = setting.weight_decay * tf.losses.get_regularization_loss()
    if setting.optimizer == 'adam':
        optimizer = tf.train.AdamOptimizer(learning_rate=lr_clip_op,
                                           epsilon=setting.epsilon)
    elif setting.optimizer == 'momentum':
        optimizer = tf.train.MomentumOptimizer(learning_rate=lr_clip_op,
                                               momentum=setting.momentum,
                                               use_nesterov=True)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = optimizer.minimize(loss_op + reg_loss,
                                      global_step=global_step)

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())

    saver = tf.train.Saver(max_to_keep=None)

    folder_ckpt = os.path.join(root_folder, 'ckpts')
    if not os.path.exists(folder_ckpt):
        os.makedirs(folder_ckpt)

    folder_summary = os.path.join(root_folder, 'summary')
    if not os.path.exists(folder_summary):
        os.makedirs(folder_summary)

    parameter_num = np.sum(
        [np.prod(v.shape.as_list()) for v in tf.trainable_variables()])
    print('{}-Parameter number: {:d}.'.format(datetime.now(), parameter_num))

    with tf.Session() as sess:
        summaries_op = tf.summary.merge_all('train')
        summaries_val_op = tf.summary.merge_all('val')
        summary_writer = tf.summary.FileWriter(folder_summary, sess.graph)

        sess.run(init_op)

        # Load the model
        if args.load_ckpt is not None:
            saver.restore(sess, args.load_ckpt)
            print('{}-Checkpoint loaded from {}!'.format(
                datetime.now(), args.load_ckpt))
        else:
            latest_ckpt = tf.train.latest_checkpoint(folder_ckpt)
            if latest_ckpt:
                print('{}-Found checkpoint {}'.format(datetime.now(),
                                                      latest_ckpt))
                saver.restore(sess, latest_ckpt)
                print('{}-Checkpoint loaded from {} (Iter {})'.format(
                    datetime.now(), latest_ckpt, sess.run(global_step)))

        for batch_idx_train in range(batch_num):
            ######################################################################
            # Validation
            if (batch_idx_train % step_val == 0 and (batch_idx_train != 0 or args.load_ckpt is not None)) \
                    or batch_idx_train == batch_num - 1:
                filename_ckpt = os.path.join(folder_ckpt, 'iter')
                saver.save(sess, filename_ckpt, global_step=global_step)
                print('{}-Checkpoint saved to {}!'.format(
                    datetime.now(), filename_ckpt))

                sess.run(reset_metrics_op)
                for batch_val_idx in range(batch_num_val):
                    start_idx = batch_size * batch_val_idx
                    end_idx = min(start_idx + batch_size, num_val)
                    batch_size_val = end_idx - start_idx
                    points_batch = data_val[start_idx:end_idx, ...]
                    points_num_batch = data_num_val[start_idx:end_idx, ...]
                    labels_batch = label_val[start_idx:end_idx, ...]
                    weights_batch = np.array(label_weights_list)[labels_batch]
                    direc_batch = direc_val[start_idx:end_idx, ...]

                    points_batch_extra = data_val_extra[start_idx:end_idx, ...]
                    points_num_batch_extra = data_num_val_extra[
                        start_idx:end_idx, ...]
                    labels_batch_extra = label_val_extra[start_idx:end_idx,
                                                         ...]
                    direc_batch_extra = direc_val_extra[start_idx:end_idx, ...]

                    ind0 = np.tile(
                        np.reshape(np.arange(batch_size_val), [-1, 1, 1]),
                        [1, sample_num, 1])
                    ind1 = np.tile(
                        np.reshape(np.arange(sample_num), [1, -1, 1]),
                        [batch_size_val, 1, 1])
                    sampled_indices_batch = np.concatenate([ind0, ind1],
                                                           axis=-1)

                    xforms_np, rotations_np = pf.get_xforms(
                        batch_size_val,
                        rotation_range=rotation_range_val,
                        scaling_range=scaling_range_val,
                        order=setting.rotation_order)
                    sess.run(
                        [
                            loss_mean_update_op, t_1_acc_update_op,
                            t_1_per_class_acc_update_op
                        ],
                        feed_dict={
                            pts_fts:
                            points_batch,
                            direc_pl:
                            direc_batch,
                            indices:
                            pf.get_indices(batch_size_val, sample_num,
                                           points_num_batch),
                            pts_fts_extra:
                            points_batch_extra,
                            direc_pl_extra:
                            direc_batch_extra,
                            indices_extra:
                            pf.get_indices(batch_size_val, sample_num_extra,
                                           points_num_batch_extra),
                            xforms:
                            xforms_np,
                            rotations:
                            rotations_np,
                            jitter_range:
                            np.array([jitter_val]),
                            labels_seg:
                            labels_batch,
                            labels_seg_extra:
                            labels_batch_extra,
                            labels_weights:
                            weights_batch,
                            is_training:
                            False,
                            sampled_indices:
                            sampled_indices_batch,
                        })
                loss_val, t_1_acc_val, t_1_per_class_acc_val, summaries_val, step = sess.run(
                    [
                        loss_mean_op, t_1_acc_op, t_1_per_class_acc_op,
                        summaries_val_op, global_step
                    ])
                summary_writer.add_summary(summaries_val, step)
                print(
                    '{}-[Val  ]-Average:      Loss: {:.4f}  T-1 Acc: {:.4f}  T-1 mAcc: {:.4f}'
                    .format(datetime.now(), loss_val, t_1_acc_val,
                            t_1_per_class_acc_val))
                sys.stdout.flush()
            ######################################################################

            ######################################################################
            # Training
            start_idx = (batch_size * batch_idx_train) % num_train
            end_idx = min(start_idx + batch_size, num_train)
            batch_size_train = end_idx - start_idx
            points_batch = data_train[start_idx:end_idx, ...]
            points_num_batch = data_num_train[start_idx:end_idx, ...]
            labels_batch = label_train[start_idx:end_idx, ...]
            weights_batch = np.array(label_weights_list)[labels_batch]
            direc_batch = direc_train[start_idx:end_idx, ...]

            points_batch_extra = data_train_extra[start_idx:end_idx, ...]
            points_num_batch_extra = data_num_train_extra[start_idx:end_idx,
                                                          ...]
            labels_batch_extra = label_train_extra[start_idx:end_idx, ...]
            direc_batch_extra = direc_train_extra[start_idx:end_idx, ...]

            if start_idx + batch_size_train == num_train:
                if is_list_of_h5_list:
                    filelist_train_prev = seg_list[(seg_list_idx - 1) %
                                                   len(seg_list)]
                    filelist_train = seg_list[seg_list_idx % len(seg_list)]
                    if filelist_train != filelist_train_prev:
                        data_train, _, data_num_train, label_train, _, direc_train, \
                        data_train_extra, data_num_train_extra, label_train_extra, direc_train_extra = \
                            data_utils.load_seg(filelist_train)
                        num_train = data_train.shape[0]
                    seg_list_idx = seg_list_idx + 1
                data_train, data_num_train, label_train, direc_train, \
                data_train_extra, data_num_train_extra, label_train_extra, direc_train_extra = \
                    data_utils.grouped_shuffle([data_train, data_num_train, label_train, direc_train,
                                                data_train_extra, data_num_train_extra, label_train_extra,
                                                direc_train_extra])

            offset = int(
                random.gauss(0, sample_num * setting.sample_num_variance))
            offset = max(offset, -sample_num * setting.sample_num_clip)
            offset = min(offset, sample_num * setting.sample_num_clip)
            sample_num_train = sample_num + offset
            sample_num_train_extra = int(sample_num_train / 2)
            xforms_np, rotations_np = pf.get_xforms(
                batch_size_train,
                rotation_range=rotation_range,
                scaling_range=scaling_range,
                order=setting.rotation_order)

            ind0 = np.tile(np.reshape(np.arange(batch_size_train), [-1, 1, 1]),
                           [1, sample_num_train, 1])
            ind1 = np.tile(np.reshape(np.arange(sample_num_train), [1, -1, 1]),
                           [batch_size_train, 1, 1])
            sampled_indices_batch = np.concatenate([ind0, ind1], axis=-1)

            sess.run(reset_metrics_op)
            sess.run(
                [
                    train_op, loss_mean_update_op, t_1_acc_update_op,
                    t_1_per_class_acc_update_op
                ],
                feed_dict={
                    pts_fts:
                    points_batch,
                    direc_pl:
                    direc_batch,
                    indices:
                    pf.get_indices(batch_size_train, sample_num_train,
                                   points_num_batch),
                    pts_fts_extra:
                    points_batch_extra,
                    direc_pl_extra:
                    direc_batch_extra,
                    indices_extra:
                    pf.get_indices(batch_size_train, sample_num_train_extra,
                                   points_num_batch_extra),
                    xforms:
                    xforms_np,
                    rotations:
                    rotations_np,
                    jitter_range:
                    np.array([jitter]),
                    labels_seg:
                    labels_batch,
                    labels_seg_extra:
                    labels_batch_extra,
                    labels_weights:
                    weights_batch,
                    is_training:
                    True,
                    sampled_indices:
                    sampled_indices_batch,
                })
            if batch_idx_train % 10 == 0:
                loss, t_1_acc, t_1_per_class_acc, summaries, step = sess.run([
                    loss_mean_op, t_1_acc_op, t_1_per_class_acc_op,
                    summaries_op, global_step
                ])
                summary_writer.add_summary(summaries, step)
                print(
                    '{}-[Train]-Iter: {:06d}  Loss: {:.4f}  T-1 Acc: {:.4f}  T-1 mAcc: {:.4f}'
                    .format(datetime.now(), step, loss, t_1_acc,
                            t_1_per_class_acc))
                sys.stdout.flush()
                ######################################################################
        print('{}-Done!'.format(datetime.now()))
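
The reset_metrics_op pattern above (re-initializing only the metric accumulators under the 'metrics' scope before each evaluation pass) can be illustrated in isolation; a minimal sketch with an assumed streaming mean:

import tensorflow as tf

x = tf.placeholder(tf.float32, name='x')
with tf.name_scope('metrics'):
    mean_op, mean_update_op = tf.metrics.mean(x)

# Re-initialize only the accumulators that live under the 'metrics' scope.
reset_metrics_op = tf.variables_initializer(
    [v for v in tf.local_variables() if v.name.split('/')[0] == 'metrics'])

with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    sess.run(mean_update_op, feed_dict={x: 2.0})
    sess.run(mean_update_op, feed_dict={x: 4.0})
    print(sess.run(mean_op))    # 3.0
    sess.run(reset_metrics_op)  # clears total/count before the next pass
    sess.run(mean_update_op, feed_dict={x: 10.0})
    print(sess.run(mean_op))    # 10.0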
Esempio n. 59
0
def train(train_op,
          logdir,
          metric_op=None,
          metric_collection_name=None,
          train_step_fn=train_step,
          train_step_kwargs=_USE_DEFAULT,
          log_every_n_steps=1,
          graph=None,
          master='',
          is_chief=True,
          global_step=None,
          number_of_samples=None,
          number_of_steps=None,
          number_of_epochs=None,
          batch_size=None,
          init_op=_USE_DEFAULT,
          init_feed_dict=None,
          local_init_op=_USE_DEFAULT,
          init_fn=None,
          ready_op=_USE_DEFAULT,
          summary_op=_USE_DEFAULT,
          save_summaries_secs=600,
          summary_writer=_USE_DEFAULT,
          startup_delay_steps=0,
          saver=None,
          save_interval_secs=600,
          sync_optimizer=None,
          session_config=None,
          trace_every_n_steps=None):
    """Runs a training loop using a TensorFlow supervisor.

    When sync_optimizer is supplied, gradient updates are applied
    synchronously. Otherwise, gradient updates are applied asynchronously.

    Args:
      train_op: A `Tensor` that, when executed, will apply the gradients and
        return the loss value.
      metric_op: A `Tensor` that, when executed, updates the streaming metric ops.
      metric_collection_name: The name associated with the `metric_op`.
      logdir: The directory where training logs are written. If None, model
        checkpoints and summaries will not be written.
      train_step_fn: The function to call in order to execute a single gradient
        step. The function must take exactly four arguments: the current
        session, the `train_op` `Tensor`, a global step `Tensor` and a dictionary.
      train_step_kwargs: A dictionary which is passed to the `train_step_fn`. By
        default, two boolean scalar ops called "should_stop" and "should_log"
        are provided.
      log_every_n_steps: The frequency, in terms of global steps, at which the
        loss and global step are logged.
      graph: The graph to pass to the supervisor. If no graph is supplied the
        default graph is used.
      master: The address of the tensorflow master.
      is_chief: Specifies whether or not the training is being run by the primary
        replica during replica training.
      global_step: The `Tensor` representing the global step. If left as `None`,
        then slim.variables.get_or_create_global_step() is used.
      number_of_samples: The total number of samples in the training set; used
        together with `batch_size` to compute the number of batches per epoch.
      number_of_steps: The max number of gradient steps to take during training,
        as measured by 'global_step': training will stop if global_step is
        greater than 'number_of_steps'. If the value is left as None, training
        proceeds indefinitely.
      number_of_epochs: The total number of training epochs.
      batch_size: The number of samples in each batch.
      init_op: The initialization operation. If left to its default value, then
        the session is initialized by calling `tf.global_variables_initializer()`.
      init_feed_dict: A feed dictionary to use when executing the `init_op`.
      local_init_op: The local initialization operation. If left to its default
        value, then the session is initialized by calling
        `tf.local_variables_initializer()` and `tf.tables_initializer()`.
      init_fn: An optional callable to be executed after `init_op` is called. The
        callable must accept one argument, the session being initialized.
      ready_op: Operation to check if the model is ready to use. If left to its
        default value, then the session checks for readiness by calling
        `tf.report_uninitialized_variables()`.
      summary_op: The summary operation.
      save_summaries_secs: How often, in seconds, to save summaries.
      summary_writer: `SummaryWriter` to use.  Can be `None`
        to indicate that no summaries should be written. If unset, we
        create a SummaryWriter.
      startup_delay_steps: The number of steps to wait for before beginning. Note
        that this must be 0 if a sync_optimizer is supplied.
      saver: Saver to save checkpoints. If None, a default one will be created
        and used.
      save_interval_secs: How often, in seconds, to save the model to `logdir`.
      sync_optimizer: an instance of tf.train.SyncReplicasOptimizer. If the
        argument is supplied, gradient updates will be synchronous. If left as
        `None`, gradient updates will be asynchronous.
      session_config: An instance of `tf.ConfigProto` that will be used to
        configure the `Session`. If left as `None`, the default will be used.
      trace_every_n_steps: produce and save a `Timeline` in Chrome trace format
        and add it to the summaries every `trace_every_n_steps`. If None, no trace
        information will be produced or saved.

    Returns:
      the value of the loss function after training.

    Raises:
      ValueError: if `train_op` is empty or if `startup_delay_steps` is
        non-zero when `sync_optimizer` is supplied, if `number_of_steps` is
        negative, or if `trace_every_n_steps` is not `None` and no `logdir` is
        provided.
    """

    # Check if the calculation of some metrics is desired.
    if metric_op is not None:

        # Check the necessary requirements
        if metric_collection_name is None:
            raise ValueError('metric_collection_name must be fed and cannot be None')
        if number_of_samples is None:
            raise ValueError('number_of_samples must be fed and cannot be None')
        if number_of_steps is None:
            raise ValueError('number_of_steps must be fed and cannot be None')
        if number_of_epochs is None:
            raise ValueError('number_of_epochs must be fed and cannot be None')
        if batch_size is None:
            raise ValueError('batch_size must be fed and cannot be None')

    if train_op is None:
        raise ValueError('train_op cannot be None.')

    if logdir is None:
        if summary_op != _USE_DEFAULT:
            raise ValueError('Cannot provide summary_op because logdir=None')
        if saver is not None:
            raise ValueError('Cannot provide saver because logdir=None')
        if trace_every_n_steps is not None:
            raise ValueError('Cannot provide trace_every_n_steps because '
                             'logdir=None')

    if sync_optimizer is not None and startup_delay_steps > 0:
        raise ValueError(
            'startup_delay_steps must be zero when sync_optimizer is supplied.')

    if number_of_steps is not None and number_of_steps <= 0:
        raise ValueError(
            '`number_of_steps` must be either None or a positive number.')

    graph = graph or ops.get_default_graph()
    with graph.as_default():
        if global_step is None:
            global_step = variables.get_or_create_global_step()
        saver = saver or tf_saver.Saver()

        with ops.name_scope('init_ops'):
            if init_op == _USE_DEFAULT:
                init_op = tf_variables.global_variables_initializer()

            if ready_op == _USE_DEFAULT:
                ready_op = tf_variables.report_uninitialized_variables()

            if local_init_op == _USE_DEFAULT:
                local_init_op = control_flow_ops.group(
                    tf_variables.local_variables_initializer(),
                    lookup_ops.tables_initializer())

            if sync_optimizer is not None and isinstance(
                    sync_optimizer, sync_replicas_optimizer.SyncReplicasOptimizer):
                with ops.control_dependencies(
                        [local_init_op] if local_init_op is not None else []):
                    if is_chief:
                        local_init_op = sync_optimizer.chief_init_op
                    else:
                        local_init_op = sync_optimizer.local_step_init_op
                ready_for_local_init_op = sync_optimizer.ready_for_local_init_op
            else:
                ready_for_local_init_op = None

        if summary_op == _USE_DEFAULT:
            summary_op = summary.merge_all()

        if summary_writer == _USE_DEFAULT:
            summary_writer = supervisor.Supervisor.USE_DEFAULT

        if is_chief and sync_optimizer is not None:
            if not isinstance(sync_optimizer,
                              (sync_replicas_optimizer.SyncReplicasOptimizer)):
                raise ValueError(
                    '`sync_optimizer` must be a tf.train.SyncReplicasOptimizer.')

            # Need to create these BEFORE the supervisor finalizes the graph:
            init_tokens_op = sync_optimizer.get_init_tokens_op()
            chief_queue_runner = sync_optimizer.get_chief_queue_runner()

        if train_step_kwargs == _USE_DEFAULT:
            with ops.name_scope('train_step'):
                train_step_kwargs = {}

                if number_of_steps:
                    should_stop_op = math_ops.greater_equal(global_step, number_of_steps)
                else:
                    should_stop_op = constant_op.constant(False)
                train_step_kwargs['should_stop'] = should_stop_op
                train_step_kwargs['should_log'] = math_ops.equal(
                    math_ops.mod(global_step, log_every_n_steps), 0)
                if is_chief and trace_every_n_steps is not None:
                    train_step_kwargs['should_trace'] = math_ops.equal(
                        math_ops.mod(global_step, trace_every_n_steps), 0)
                    train_step_kwargs['logdir'] = logdir
                if number_of_samples is not None and batch_size is not None:
                    train_step_kwargs['num_batches_per_epoch'] = int(number_of_samples / float(batch_size))
                if number_of_steps is not None and number_of_epochs is not None:
                    train_step_kwargs['num_steps_per_epoch'] = int(number_of_steps / float(number_of_epochs))

        # If metric calculation is desired.
        if metric_op is not None:
            # The reset_op reinitializes the streaming variables (streaming_accuracy, ...).
            # It is defined here because the supervisor finalizes the graph, after which no new ops can be added.
            # Calling reset_op inside the train_step function resets the total & count variables to zero after each epoch.
            # This yields a per-epoch averaged accuracy, which makes it easier to tell whether training has reached its best accuracy.
            stream_vars = [i for i in tf.local_variables() if i.name.split('/')[1] == metric_collection_name]
            reset_op = tf.variables_initializer(stream_vars)

    sv = supervisor.Supervisor(
        graph=graph,
        is_chief=is_chief,
        logdir=logdir,
        init_op=init_op,
        init_feed_dict=init_feed_dict,
        local_init_op=local_init_op,
        ready_for_local_init_op=ready_for_local_init_op,
        ready_op=ready_op,
        summary_op=summary_op,
        summary_writer=summary_writer,
        global_step=global_step,
        saver=saver,
        save_summaries_secs=save_summaries_secs,
        save_model_secs=save_interval_secs,
        init_fn=init_fn)

    if summary_writer is not None:
        train_step_kwargs['summary_writer'] = sv.summary_writer

    should_retry = True
    while should_retry:
        try:
            should_retry = False

            with sv.managed_session(
                    master, start_standard_services=False, config=session_config) as sess:

                logging.info('Starting Session.')
                if is_chief:
                    if logdir:
                        sv.start_standard_services(sess)
                elif startup_delay_steps > 0:
                    _wait_for_step(sess, global_step,
                                   min(startup_delay_steps, number_of_steps or
                                       sys.maxsize))
                sv.start_queue_runners(sess)
                logging.info('Starting Queues.')
                if is_chief and sync_optimizer is not None:
                    sv.start_queue_runners(sess, [chief_queue_runner])
                    sess.run(init_tokens_op)
                try:
                    while not sv.should_stop():
                        if metric_op is not None:
                            total_loss, should_stop = train_step_fn(
                                sess, train_op, global_step, train_step_kwargs, metric_op, reset_op)
                        else:
                            total_loss, should_stop = train_step_fn(
                                sess, train_op, global_step, train_step_kwargs)
                        if should_stop:
                            logging.info('Stopping Training.')
                            break
                except errors.OutOfRangeError:
                    # OutOfRangeError is thrown when epoch limit per
                    # tf.train.limit_epochs is reached.
                    logging.info('Caught OutOfRangeError. Stopping Training.')
                if logdir and sv.is_chief:
                    logging.info('Finished training! Saving model to disk.')
                    sv.saver.save(sess, sv.save_path, global_step=sv.global_step)

        except errors.AbortedError:
            # Always re-run on AbortedError as it indicates a restart of one of the
            # distributed tensorflow servers.
            logging.info('Retrying training!')
            should_retry = True

    return total_loss
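
# --- Illustrative usage sketch (an assumption, not part of the original snippet) ---
# Shows one plausible way the train() wrapper above could be wired up with a
# streaming metric. The tiny model, the 'tower_0'/'streaming_metrics' names, the
# logdir and the dataset sizes are placeholders chosen for illustration; slim
# refers to tf.contrib.slim.
def _example_train_usage(logdir='/tmp/train_logs'):
    slim = tf.contrib.slim
    labels = tf.constant([0, 1, 1, 0], dtype=tf.int64)
    logits = tf.Variable(tf.zeros([4, 2]))
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
    predictions = tf.argmax(logits, axis=1)
    # Create the streaming metric under an enclosing scope so that
    # name.split('/')[1] matches metric_collection_name in the reset_op filter.
    with tf.variable_scope('tower_0'):
        _, accuracy_update_op = tf.metrics.accuracy(
            labels, predictions, name='streaming_metrics')
    optimizer = tf.train.GradientDescentOptimizer(0.01)
    train_op = slim.learning.create_train_op(loss, optimizer)
    return train(
        train_op,
        logdir,
        metric_op=accuracy_update_op,
        metric_collection_name='streaming_metrics',
        number_of_samples=4,
        number_of_steps=100,
        number_of_epochs=1,
        batch_size=4,
        log_every_n_steps=10)
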
def main():
	# Configure
	config=tf.ConfigProto(log_device_placement=False)

	# Server Setup
	cluster_spec = {
		'ps': ['localhost:2222'],
		'worker': ['localhost:2223', 'localhost:2224']
	}  # lets this node know about all the other nodes
	n_pss = len(cluster_spec['ps'])  # the number of parameter servers
	n_workers = len(cluster_spec['worker'])  # the number of worker nodes
	cluster = tf.train.ClusterSpec(cluster_spec)

	if FLAGS.job_name == 'ps': #checks if parameter server
		server = tf.train.Server(cluster,
					job_name="ps",
					task_index=FLAGS.task_index,
					config=config)
		server.join()
	else: #it must be a worker server
		is_chief = (FLAGS.task_index == 0) #checks if this is the chief node
		server = tf.train.Server(cluster,
					job_name="worker",
					task_index=FLAGS.task_index,
					config=config)
		# Graph
		with tf.device("/job:worker/replica:0/task:%d" % FLAGS.task_index):
			a = tf.Variable(tf.constant(0.,shape=[2]),dtype=tf.float32,
						collections=[tf.GraphKeys.LOCAL_VARIABLES])
			b = tf.Variable(tf.constant(0.,shape=[2]),dtype=tf.float32,
						collections=[tf.GraphKeys.LOCAL_VARIABLES])
			c=a+b

			local_step = tf.Variable(0,dtype=tf.int32,trainable=False,
						name='local_step',collections=['local_non_trainable'])
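			# NOTE: local_step lives in the custom 'local_non_trainable' collection so it is
			# initialized together with the other local variables (see init_local below) but is
			# never synced to the parameter servers or touched by the optimizers.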

			target = tf.constant(100.,shape=[2],dtype=tf.float32)
			loss = tf.reduce_mean(tf.square(c-target))

			base_lr = .0001
			loptimizer = tf.train.AdamOptimizer(base_lr)
			# loptimizer = tf.train.GradientDescentOptimizer(base_lr)

			# SDAG (simplest case since all batches are the same)
			update_window = 5 # T: communication window
			grad_list = [] # the array to store the gradients through the communication window
			for t in range(update_window):
				if t != 0:
					#compute gradients only if the local opt was run
					with tf.control_dependencies([opt_local]): 
						grads, varss = zip(*loptimizer.compute_gradients( \
									loss,var_list=tf.local_variables()))
				else:
					grads, varss = zip(*loptimizer.compute_gradients( \
								loss,var_list=tf.local_variables()))
				#add gradients to the list
				grad_list.append(grads)
				#update local parameters
				opt_local = loptimizer.apply_gradients(zip(grads,varss),
							global_step=local_step)

			# average the gradients collected over the communication window before applying them globally
			grads = tf.reduce_mean(grad_list, axis=0)
			grads = tuple([grads[i] for i in range(len(varss))])
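			# NOTE: tf.reduce_mean over grad_list stacks the per-step gradient tuples into a
			# single tensor, so this assumes every gradient is dense and that each variable's
			# gradient has the same shape at every step of the window.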

			# add these variables created by local optimizer to local collection
			lopt_vars = add_global_variables_to_local_collection()

			# delete the variables from the global collection
			clear_global_collection()

		with tf.device(tf.train.replica_device_setter(ps_tasks=n_pss,
        	worker_device="/job:%s/task:%d" % (FLAGS.job_name,FLAGS.task_index))):

			global_step = tf.Variable(0,dtype=tf.int32,trainable=False,name='global_step')

			#create global variables and/or references
			local_to_global, global_to_local = create_global_variables(lopt_vars)

			optimizer = tf.train.AdamOptimizer(base_lr)
			# optimizer = tf.train.GradientDescentOptimizer(base_lr)
			optimizer1 = tf.train.SyncReplicasOptimizer(optimizer,
						replicas_to_aggregate=2,
						total_num_replicas=2)

			#apply the gradients to variables on ps
			opt = optimizer1.apply_gradients(
						zip(grads,[local_to_global[v] for v in varss])
						,global_step=global_step)

			with tf.control_dependencies([opt]):
				assign_locals = assign_global_to_local(global_to_local)

			# Grab global state before training so all workers have same initialization
			grab_global_init = assign_global_to_local(global_to_local)

			# Assigns local values to global ones for chief to execute
			assign_global = assign_local_to_global(local_to_global)

			# Op that initializes the sync-replica token queue
			init_tokens_op = optimizer1.get_init_tokens_op()

			# Init ops
			# gets step token
			local_init=optimizer1.local_step_init_op
			if is_chief:
				# fills token queue and gets token
				local_init = optimizer1.chief_init_op

			# indicates if variables are initialized
			ready_for_local_init = optimizer1.ready_for_local_init_op

			with tf.control_dependencies([local_init]):
				init_local = tf.variables_initializer(tf.local_variables() \
							+tf.get_collection('local_non_trainable')) #for local variables

			init = tf.global_variables_initializer() # must come after other init ops

		# Session
		sync_replicas_hook = optimizer1.make_session_run_hook(is_chief)
		stop_hook = tf.train.StopAtStepHook(last_step=10)
		chief_hooks = [sync_replicas_hook,stop_hook]
		scaff = tf.train.Scaffold(init_op=init,
					local_init_op=init_local,
					ready_for_local_init_op=ready_for_local_init)

		#Monitored Training Session
		sess = tf.train.MonitoredTrainingSession(master=server.target,
					is_chief=is_chief,
					config=config,
					scaffold=scaff,
					hooks=chief_hooks,
					stop_grace_period_secs=10)

		if is_chief:
			sess.run(assign_global) # Assigns chief's initial values to ps
			time.sleep(40) # grace period to wait on other workers before starting training

		# Pull the current global values into the local copies so every worker
		# starts from the same parameters
		sess.run(grab_global_init)


		# Train until hook stops session
		print('Starting training on worker %d'%FLAGS.task_index)
		while not sess.should_stop():
			_,_,r,gs,ls = sess.run([opt,assign_locals,c,global_step,local_step])
			# _,r,gs=sess.run([opt,c,global_step])
			print(r,gs,FLAGS.task_index)
			if is_chief: time.sleep(1)
			time.sleep(1)
		print('Done',FLAGS.task_index)

		time.sleep(10) #grace period to wait before closing session
		sess.close()
		print('Session from worker %d closed cleanly'%FLAGS.task_index)
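
# --- Illustrative sketch (an assumption, not part of the original snippet) ---
# main() relies on helpers (create_global_variables, assign_global_to_local,
# assign_local_to_global, add_global_variables_to_local_collection,
# clear_global_collection) and on command-line FLAGS that are defined elsewhere
# in the original source. Below is a hedged guess at the two assignment helpers
# and the flag/entry-point wiring, shown only to make the data flow easier to
# follow; the real implementations may differ.

def assign_global_to_local(global_to_local):
    """Return an op copying every ps-hosted variable onto its worker-local copy.

    `global_to_local` is assumed to map global variables to their local counterparts.
    """
    return tf.group(*[local_var.assign(global_var)
                      for global_var, local_var in global_to_local.items()])


def assign_local_to_global(local_to_global):
    """Return an op copying every worker-local variable onto its ps-hosted copy.

    `local_to_global` is assumed to map local variables to their global counterparts.
    """
    return tf.group(*[global_var.assign(local_var)
                      for local_var, global_var in local_to_global.items()])


# Hypothetical flag definitions matching the FLAGS accesses in main().
flags = tf.app.flags
flags.DEFINE_string('job_name', 'worker', "Either 'ps' or 'worker'.")
flags.DEFINE_integer('task_index', 0, 'Index of this task within its job.')
FLAGS = flags.FLAGS

if __name__ == '__main__':
    main()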