Code example #1
File: __init__.py Project: imito/odin
def initialize_all_variables(vars=None):
  """ This function will automatically check if the variables
  are initialized, and only perform initialization for
  un-initialized variables.

  Note
  ----
  Re-initializing an initialized variable will give it random values
  """
  if vars is None:
    vars = get_all_variables()
  else:
    vars = [v for v in as_tuple(vars)
            if is_variable(v)]
  # ====== check if variable not initialized ====== #
  init_info = eval([tf.is_variable_initialized(v) for v in vars])
  vars = [v for v, inited in zip(vars, init_info)
          if not inited]
  # ====== build mapping graph -> list of vars ====== #
  graph = defaultdict(list)
  for v in vars:
    graph[v.graph].append(v)
  # ====== run the initialization ====== #
  for g, v in graph.items():
    get_session(graph=g).run([i.initializer for i in v])
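A minimal usage sketch (an assumption: the helpers above are in scope, as they would be inside odin's backend module; the variables are illustrative):

# usage sketch -- assumes initialize_all_variables is in scope
import tensorflow as tf

w = tf.Variable(tf.random_normal([3, 4]), name='w')
b = tf.Variable(tf.zeros([4]), name='b')
initialize_all_variables()        # initializes w and b (both un-initialized)
initialize_all_variables([w, b])  # no-op now: both are already initialized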
Code example #2
def _as_variable(x, name, roles=None):
  # nothing to do
  if x is None:
    return None
  # create variable
  if not is_tensor(x):
    x = tf.Variable(x, dtype=floatX, name=name)
    get_session().run(x.initializer)
  return add_roles(x, roles)
Code example #3
File: __init__.py Project: imito/odin
def restore_variables(path, session=None):
  if session is None:
    session = get_session()
  # ====== load and check var meta ====== #
  with open(path + '.collections', 'rb') as f:
    collections, var_meta = cPickle.load(f)
  var_list = []
  allvars = {v.name.split(':')[0]: v for v in get_all_variables()}
  for name, dtype, shape in var_meta:
    if name in allvars: # found predefined variable
      var_list.append(allvars[name])
    else: # create new variable
      if tf.get_variable_scope().name:
        raise RuntimeError("The current variable scope is: %s, you can "
            "only restore variables from default scope."
            % tf.get_variable_scope().name)
      var_list.append(tf.get_variable(
          shape=shape, name=name, dtype=dtype))
  # ====== restore the variables ====== #
  name = '|'.join(sorted([v.name for v in var_list]))
  if name in _saver:
    saver = _saver[name]
  else:
    saver = tf.train.Saver(var_list=var_list, restore_sequentially=False,
                           allow_empty=False)
    _saver[name] = saver  # cache the Saver; the lookup above is dead otherwise
  saver.restore(session, path)
  # ====== restore the collections ====== #
  for v in var_list:
    role.add_roles(v, collections[v.name])
Code example #4
File: __init__.py Project: imito/odin
def save_variables(var_list, path, session=None):
  """ This function only apply for trainable parameters """
  if session is None:
    session = get_session()
  var_list = [v for v in set(as_tuple(var_list)) if is_variable(v)]
  name = '|'.join(sorted([v.name for v in var_list]))
  if name in _saver:
    saver = _saver[name]
  else:
    saver = tf.train.Saver(var_list=var_list, restore_sequentially=False,
                           allow_empty=False)
    _saver[name] = saver  # cache the Saver; the lookup above is dead otherwise
  # ====== save the variables ====== #
  checkpoint = saver.save(session, path, global_step=None,
      write_meta_graph=False, write_state=False)
  # ====== save meta-info for recreate variable ====== #
  var_meta = []
  for v in var_list:
    name = v.name.split(':')[0]
    dtype = v.dtype.base_dtype.name
    shape = v.shape.as_list()
    var_meta.append((name, dtype, shape))
  # ====== save the collections ====== #
  collections = {var.name: role.get_roles(var, return_string=True)
                 for var in var_list}
  with open(path + '.collections', 'wb') as f:
    cPickle.dump([collections, var_meta], f,
                 protocol=cPickle.HIGHEST_PROTOCOL)
  return checkpoint
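Taken together, save_variables and restore_variables form a checkpoint round-trip: save writes the TF checkpoint plus a '.collections' pickle of variable meta-info and roles, and restore recreates any missing variables from that meta-info before loading. A hedged sketch (the path and variable are illustrative):

# illustrative round-trip; '/tmp/odin_ckpt' is an arbitrary path
w = tf.get_variable('w', shape=[3, 4], dtype='float32')
initialize_all_variables([w])
ckpt = save_variables([w], '/tmp/odin_ckpt')  # also writes '/tmp/odin_ckpt.collections'
# later, possibly in a fresh default graph: 'w' is recreated from the saved
# meta-info if it is not already defined, then its value is restored
restore_variables('/tmp/odin_ckpt')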
Code example #5
File: __init__.py Project: imito/odin
def save_graph(path, graph=None):
  g = tf.summary.FileWriter(path)
  if graph is None:
    graph = get_session().graph
  elif isinstance(graph, tf.Session):
    graph = graph.graph
  g.add_graph(graph)
  g.flush()
  g.close()
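A usage sketch: dump the current graph so it can be inspected in TensorBoard (the log directory is arbitrary):

save_graph('/tmp/tf_logs')  # writes an events file containing the graph definition
# inspect it with: tensorboard --logdir=/tmp/tf_logs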
Code example #6
File: helpers.py Project: professorlust/odin-ai
def set_value(x, value, return_ops=False, name='SetValue'):
  '''Sets the value of a tensor variable from a Numpy array.

  Parameters
  ----------
  x: `Tensor`
  value: real value
  return_ops: bool
      if True, return assign Op and feed_dict instead of running
      the Op directly
  '''
  if isinstance(value, np.ndarray):
    value = value.astype(x.dtype.as_numpy_dtype)
  elif is_tensor(value):
    value = tf.cast(value, dtype=x.dtype)
  assign_op = tf.assign(x, value, name=name)
  if return_ops:
    return assign_op
  get_session().run(assign_op)
  return x
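A usage sketch covering both modes (immediate assignment and deferred Op):

import numpy as np

x = tf.Variable(np.zeros((2, 2), dtype='float32'))
get_session().run(x.initializer)
set_value(x, np.ones((2, 2)))                         # assign runs immediately
op = set_value(x, np.zeros((2, 2)), return_ops=True)  # returns the assign Op
get_session().run(op)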
Code example #7
File: helpers.py Project: professorlust/odin-ai
def is_training(graph=None):
  if graph is None:
    graph = get_session().graph
  training_var = get_all_variables(scope=None, name='IsTraining__',
                                   graph=graph)
  if len(training_var) == 0:
    raise RuntimeError("Cannot find variable with name='IsTraining' scope='' "
                       "within graph=%s" % str(graph))
  elif len(training_var) > 1:
    raise RuntimeError("Found multiple 'IsTraining__' flag: %s" %
      str(training_var))
  return training_var[0]
Code example #8
File: helpers.py Project: professorlust/odin-ai
def get_operationID(op, graph=None):
  """The operation ID is the unique ID of an Op; it reflects
  the order in which Ops were created."""
  if graph is None:
    graph = get_session().graph
  ops = graph.get_operations()
  # update OpID (the loop variable must not shadow the `op` argument,
  # otherwise the ID of the last iterated Op would be returned)
  if len(_ops_ID) != len(ops):
    for ID, operation in graph._nodes_by_id.items():
      if operation not in _ops_ID:
        _ops_ID[operation] = ID
  return _ops_ID[op]
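A sketch of the ordering property (the constants are illustrative):

a = tf.constant(1, name='a')
b = tf.constant(2, name='b')
# IDs follow creation order, so a's Op received the smaller ID
assert get_operationID(a.op) < get_operationID(b.op)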
Code example #9
File: __init__.py Project: imito/odin
def eval(x, feed_dict=None,
         update_before=None, update_after=None,
         options=None, run_metadata=None):
  ''' Generalized version of code evaluation; it can
  evaluate both Python and TensorFlow expressions.

  Parameters
  ----------
  x : list, tuple, dictionary, `Tensor`
      tensorflow `Tensor` for evaluation
  feed_dict : dict
      Input dictionary, mapping placeholder -> values
  update_before: {None, list, or dict}
      mapping from `Tensor` to its new value (a `Tensor` or
      real value); the updates are run before evaluating
  update_after: {None, list, or dict}
      same as `update_before`, but the updates are run after
      evaluating `x`
  options: tensorflow.RunOptions
      the options allow controlling the behavior of
      this particular step (e.g. turning tracing on).
  run_metadata: tensorflow.RunMetadata
      When appropriate, the non-Tensor output of this
      step will be collected there. For example,
      when users turn on tracing in options, the
      profiled info will be collected into
      this argument and passed back.

  Example
  -------
  >>> import tensorflow as tf
  >>> from odin import backend as K
  >>> run_metadata = tf.RunMetadata()
  >>> K.eval(...,
  ...        options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE,
  ...                              output_partition_graphs=True),
  ...        run_metadata=run_metadata)
  >>> with open('log_path', 'w') as f:
  >>>   f.write(str(run_metadata))

  Note
  ----
  If "Couldn't open CUDA library libcupti.so.8.0" appears when you
  add RunOptions, try adding "/usr/local/cuda/extras/CUPTI/lib64/"
  to your LD_LIBRARY_PATH
  '''
  results = ()
  update_before = _validate_updates(update_before)
  update_after = _validate_updates(update_after)
  # ====== run updates before ====== #
  if update_before is not None:
    get_session(update_before.graph).run(update_before, feed_dict=feed_dict,
                                         options=options,
                                         run_metadata=run_metadata)
  # ====== list of Tensor or string ====== #
  if isinstance(x, (tuple, list)):
    string_eval = []
    tensor_eval = []
    tensor_idx = []
    # evaluate string expression
    for i, j in enumerate(x):
      if is_string(j):
        string_eval.append(builtins.eval(j))
      else:
        tensor_eval.append(j)
        tensor_idx.append(i)
    # evaluate tensor
    if len(tensor_eval) > 0:
      graph = [i.graph for i in tensor_eval]
      if len(set(graph)) > 1:
        raise RuntimeError("Cannot evaluate multiple `Tensor` come from "
                           "different `Graph`.")
      tensor_eval = get_session(graph[0]).run(tensor_eval,
                                              feed_dict=feed_dict,
                                              options=options,
                                              run_metadata=run_metadata)
    results = tuple([tensor_eval.pop(0) if i in tensor_idx else string_eval.pop(0)
                     for i in range(len(x))])
  # ====== mapping ====== #
  elif isinstance(x, Mapping):
    results = {}
    tensor_eval_key = []
    tensor_eval_value = []
    for k, v in x.items():
      if is_string(v):
        results[k] = builtins.eval(v)
      else:
        tensor_eval_key.append(k)
        tensor_eval_value.append(v)
    # evaluate tensor
    if len(tensor_eval_value) > 0:
      graph = [i.graph for i in tensor_eval_value]
      if len(set(graph)) > 1:
        raise RuntimeError("Cannot evaluate multiple `Tensor` come from "
                           "different `Graph`.")
      tensor_eval_value = get_session(graph[0]).run(tensor_eval_value,
                                                    feed_dict=feed_dict,
                                                    options=options,
                                                    run_metadata=run_metadata)
    # update results
    for k, v in zip(tensor_eval_key, tensor_eval_value):
      results[k] = v
  # ====== just a string ====== #
  elif is_string(x):
    results = builtins.eval(x)
  # ====== just a Tensorflow object ====== #
  elif isinstance(x, tf.Operation) or \
  is_tensor(x, inc_distribution=True, inc_variable=True):
    results = get_session(x.graph).run(x, feed_dict=feed_dict,
                                       options=options,
                                       run_metadata=run_metadata)
  # ====== exception ====== #
  else:
    raise RuntimeError("Cannot evaluate object of type: %s" % type(x))
  # ====== run updates after ====== #
  if update_after is not None:
    get_session(update_after.graph).run(update_after, feed_dict=feed_dict,
                                        options=options,
                                        run_metadata=run_metadata)
  return results
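A sketch of the mixed dispatch: strings go through Python's builtins.eval, tensors through Session.run, and the results keep their original positions:

x = tf.placeholder('float32', name='x')
out = eval([x * 2, '1 + 2'], feed_dict={x: 3.0})
# out == (6.0, 3): the Tensor result at index 0, the Python result at index 1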
Code example #10
File: helpers.py Project: professorlust/odin-ai
def set_training(training, graph=None, return_ops=False):
  # the flag argument is named `training` so it does not shadow the
  # `is_training()` helper that looks up the 'IsTraining__' variable
  if graph is None:
    graph = get_session().graph
  return set_value(is_training(graph), bool(training),
                   return_ops=return_ops)
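A sketch of the flag pair in use (an assumption: the graph already contains the 'IsTraining__' variable, which odin creates elsewhere):

set_training(True)                          # flips the flag through set_value
get_value(is_training())                    # -> True
op = set_training(False, return_ops=True)   # deferred: returns the assign Op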
Code example #11
File: helpers.py Project: professorlust/odin-ai
def __call__(self, *inputs, **kwargs):
   show_progress = kwargs.pop('show_progress', False)
   # dictionary as inputs
   if len(kwargs) == len(self.inputs_name):
     inputs = [kwargs[i] for i in self.inputs_name]
   # ====== delete un-matched inputs ====== #
   inputs_new = []
   tmp = list(inputs)
   shapes = list(self._input_shape)
   # this process iteratively remove inputs with mismatch shape
   # to current given input
   for s in shapes:
     for i in tuple(tmp):
       if len(i.shape) != len(s) or \
       any(a is not None and a > 0 and a != b
               for a, b in zip(s, i.shape)): # different ndim, or shape
         tmp.remove(i)
       else:
         inputs_new.append(i)
         tmp.remove(i)
         break
   if len(inputs_new) != len(self.inputs):
     raise ValueError("Given inputs have shape: %s, cannot match the shape of "
                      "defined inputs: %s" %
                      ('; '.join([str(i.shape) for i in inputs]),
                       '; '.join([str(i) for i in self.input_shape])))
   if not self._strict:
     inputs = inputs_new
   # ====== create feed_dict ====== #
   feed_dict = {}
   inputs = flatten_list(inputs, level=None)
   for tensor, value in zip(self.inputs, inputs):
     feed_dict[tensor] = value
   feed_dict.update(self.defaults)
   # check if modifying training mode
   if self.training is None:
     pass
   elif self.training:
     feed_dict.update({is_training(): True})
   else:
     feed_dict.update({is_training(): False})
   session = get_session()
   outputs = None
   # ====== mini-batches ====== #
   if self.batch_size is not None:
     batch_vars = ([i for i in feed_dict.keys() if is_tensor(i)]
                   if len(self.batch_vars) == 0 else self.batch_vars)
     batch_vars = [i for i in batch_vars
                   if i in feed_dict and hasattr(feed_dict[i], 'shape')]
     n_samples = list(set(feed_dict[i].shape[0] for i in batch_vars))
     assert len(n_samples) == 1, \
     "Data have multiple batching dimension: %s" % str(n_samples)
     n_samples = n_samples[0]
     # only continue if we have more samples than `batch_size`
     if n_samples > self.batch_size:
       n_output = len(self.outputs)
       outputs = []
       all_batches = []
       # (optional) showing progress
       if show_progress:
         prog = Progbar(target=n_samples,
                        print_report=False, print_summary=False,
                        name='')
       for s, e in batching(batch_size=int(self.batch_size),
                            n=n_samples):
         if show_progress:
           prog.add(e - s)
         all_batches.append(e - s)
         feed_dict_minibatch = OrderedDict([(k, v[s:e])
                                            if k in batch_vars else (k, v)
                                            for k, v in feed_dict.items()])
         updated = session.run(self.outputs + [self.updates_ops],
                               feed_dict=feed_dict_minibatch)
         updated = updated[:n_output]
         if not self._return_list:
           updated = updated[0]
         outputs.append(updated)
       ## concatenate all outputs
       if not self._return_list:
         o_ndim = outputs[0].ndim
         if o_ndim == 0: # returned scalars
           outputs = np.array(outputs)
         else: # returned array
           for o_axis in range(o_ndim):
             all_n = [o.shape[o_axis] for o in outputs]
             if all_n == all_batches:
               break
           outputs = np.concatenate(outputs, axis=o_axis)
       ## returning a list of outputs
       else:
         new_outputs = []
         for output_idx in range(len(outputs[0])):
           o = [x[output_idx] for x in outputs]
           o_ndim = o[0].ndim
           if o_ndim == 0: # returned scalars
             o = np.array(o)
           else: # returned array
             for o_axis in range(o[0].ndim):
               all_n = [val.shape[o_axis] for val in o]
               if all_n == all_batches:
                 break
             o = np.concatenate(o, axis=o_axis)
           new_outputs.append(o)
         outputs = new_outputs
   # ====== single batch ====== #
   if outputs is None:
     updated = session.run(self.outputs + [self.updates_ops],
                           feed_dict=feed_dict)
     outputs = updated[:len(self.outputs)]
     if not self._return_list:
       outputs = outputs[0]
   # ====== return final output ====== #
   return outputs
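A hedged usage sketch, assuming this __call__ belongs to odin's compiled-function wrapper and that it is built by a constructor along the lines of K.function (the constructor name and signature are assumptions, not the confirmed API):

# hypothetical constructor -- the exact API is an assumption
f = K.function(inputs=[X], outputs=[loss, acc], batch_size=256)
loss_val, acc_val = f(X_data, show_progress=True)  # evaluated in mini-batches of 256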
Code example #12
File: helpers.py Project: professorlust/odin-ai
def get_value(x):
  if isinstance(x, (tuple, list)):
    return get_session().run(x)
  return x.eval(session=get_session())
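Usage sketch:

v = tf.Variable(3.0)
get_session().run(v.initializer)
get_value(v)             # -> 3.0
get_value([v, v + 1.0])  # -> [3.0, 4.0], fetched in a single Session.run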
Code example #13
File: helpers.py Project: professorlust/odin-ai
def get_all_operations(otype=None, device=None, sort=False, scope=None,
                       footprint=None, graph=None, beginning_scope=True):
  """ Return list of all operations in default graph
  The follow attributes can be access within the operation:
   * name : string
   * otype : string, operation type (e.g. `"MatMul"`).
   * device:  string name of the device to which this op has been assigned
   * _inputs : list of `Tensor`
   * _outputs : list of `Tensor`
   * _control_inputs : Before this op is executed, the operations in
       `control_inputs` have finished executing.
   * graph : `Graph` that contains this operation
   * node_def : serialized `NodeDef` representation of this operation.
   * op_def : `OpDef` proto that represents the type of this op.
   * traceback : call stack from when this operation was constructed.

  Some important op types:
   * "Placeholder"
   * "VariableV2"
   * "Const"
   * "Assign"

  Parameters
  ----------
  beginning_scope : bool (default: True)
    if True, the provided scope must be the beginning scope;
    otherwise, it may fall in the middle of multiple scopes
  """
  if graph is None:
    graph = get_session().graph
  ops = graph.get_operations()
  # update OpID
  if len(_ops_ID) != len(ops):
    for ID, op in graph._nodes_by_id.items():
      if op not in _ops_ID:
        _ops_ID[op] = ID
  # filter out some op
  if otype is not None:
    ops = [o for o in ops
           if _filter_string(otype, o.type)]
  if device is not None:
    ops = [o for o in ops
           if _filter_string(device, o.device)]
  # ====== filter by scope ====== #
  if scope is not None:
    scope = str(scope)
    if len(scope) == 0:
      ops = [o for o in ops
             if '/' not in o.name]
    else:
      scope_name_pattern = _TF_SCOPE_PATTERN(scope, beginning_scope)
      ops = [o for o in ops
             if len(scope_name_pattern.findall(o.name))]
  # ====== filter by unique footprint ====== #
  if footprint is not None:
    ops = [o for o in ops
           if get_operation_footprint(o) == footprint]
  # sorted by OpID
  if sort and len(ops) > 1:
    ops = sorted(ops, key=lambda x: _ops_ID[x])
  return ops
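Two illustrative queries (the scope name is arbitrary):

# all placeholders on the default graph, sorted by creation order
placeholders = get_all_operations(otype='Placeholder', sort=True)
# ops created under a scope named 'encoder' (matching goes through the
# internal _TF_SCOPE_PATTERN helper)
encoder_ops = get_all_operations(scope='encoder')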
Code example #14
File: keras_helpers.py Project: imito/odin
def tied_session():
  """ Tied the tensorflow Session and keras Session together
  """
  from odin.config import get_session
  from tensorflow.python.keras.backend import set_session
  set_session(get_session())
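Usage note: call tied_session() once, after odin's session exists but before any Keras layers are built, so that Keras ops land in the same graph and session. A sketch:

tied_session()
from tensorflow.python.keras.layers import Dense
layer = Dense(10)  # Keras now builds on odin's session/graph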