def import_state_tuples(state_tuples, name, num_replicas):
    """Rebuild LSTM state tuples from a flat graph collection.

    The collection ``name`` is read as interleaved ``c``/``h`` pairs
    (``[c0, h0, c1, h1, ...]``); one pair is consumed for every entry of
    ``state_tuples`` on every replica.

    Args:
      state_tuples: sequence whose length gives the number of state tuples per
        replica (only its length is used).
      name: name of the graph collection holding the flattened tensors.
      num_replicas: number of replicas whose states were exported.

    Returns:
      A tuple of ``tf.contrib.rnn.LSTMStateTuple`` objects in collection order.
    """
    # Fetch the collection once instead of twice per state tuple.
    collection = tf.get_collection_ref(name)
    num_tuples = len(state_tuples) * num_replicas
    return tuple(
        tf.contrib.rnn.LSTMStateTuple(collection[2 * i], collection[2 * i + 1])
        for i in range(num_tuples))
def save(self, model_file, global_step=None):
    """ save.

    Save a Tensorflow model. The "summary_tags", GRAPH_CONFIG, DATA_PREP,
    DATA_AUG and EXCL_RESTORE_VARS collections are emptied while the saver
    runs and refilled afterwards, even when saving fails.

    Arguments:
        model_file: `str`. Saving path of tensorflow model
        global_step: `float`. The training step to append to the
            model file name (optional).

    """
    # Temp workaround for tensorflow 0.7.0 dict proto serialization issue:
    # these collections are stripped for the duration of the save.
    try:
        # Try latest api
        summary_ref = tf.get_collection_ref("summary_tags")
        config_ref = tf.get_collection_ref(tf.GraphKeys.GRAPH_CONFIG)
    except Exception:
        summary_ref = tf.get_collection("summary_tags")
        config_ref = tf.get_collection(tf.GraphKeys.GRAPH_CONFIG)
    summary_backup = list(summary_ref)
    config_backup = list(config_ref)
    del summary_ref[:]
    del config_ref[:]

    try:
        # Try latest api
        prep_ref = tf.get_collection_ref(tf.GraphKeys.DATA_PREP)
        aug_ref = tf.get_collection_ref(tf.GraphKeys.DATA_AUG)
    except Exception:
        prep_ref = tf.get_collection(tf.GraphKeys.DATA_PREP)
        aug_ref = tf.get_collection(tf.GraphKeys.DATA_AUG)
    prep_backup = list(prep_ref)
    aug_backup = list(aug_ref)
    del prep_ref[:]
    del aug_ref[:]

    try:
        # Do not save exclude variables
        excl_ref = tf.get_collection_ref(tf.GraphKeys.EXCL_RESTORE_VARS)
    except Exception:
        excl_ref = tf.get_collection(tf.GraphKeys.EXCL_RESTORE_VARS)
    excl_backup = list(excl_ref)
    del excl_ref[:]

    try:
        # Temp workaround for tensorflow 0.7.0 relative path issue
        if not model_file.startswith(('/', '~')):
            model_file = './' + model_file
        self.saver.save(self.session, model_file, global_step=global_step)
    finally:
        # 0.7 workaround, restore values. Done in `finally` so a failed save
        # does not leave these graph collections empty.
        for t in summary_backup:
            tf.add_to_collection("summary_tags", t)
        for t in config_backup:
            tf.add_to_collection(tf.GraphKeys.GRAPH_CONFIG, t)
        for t in prep_backup:
            tf.add_to_collection(tf.GraphKeys.DATA_PREP, t)
        for t in aug_backup:
            tf.add_to_collection(tf.GraphKeys.DATA_AUG, t)
        for t in excl_backup:
            tf.add_to_collection(tf.GraphKeys.EXCL_RESTORE_VARS, t)
def restore_collection(backup):
    """Overwrite graph collections with previously backed-up contents.

    Args:
        backup (dict): maps a collection key to the items that collection
            should hold. Each listed collection is emptied and then refilled
            with those items, in iteration order.
    """
    for key, items in six.iteritems(backup):
        target = tf.get_collection_ref(key)
        del target[:]
        target.extend(items)
def save(self, model_file, global_step=None):
    """ save.

    Save a Tensorflow model. The "summary_tags" collection is emptied while
    the saver runs and refilled afterwards, even when saving fails.

    Arguments:
        model_file: `str`. Saving path of tensorflow model
        global_step: `float`. The training step to append to the
            model file name (optional).

    """
    # Temp workaround for tensorflow 0.7.0 dict proto serialization issue
    try:
        # Try latest api
        summary_ref = tf.get_collection_ref("summary_tags")
    except Exception:
        summary_ref = tf.get_collection("summary_tags")
    summary_backup = list(summary_ref)
    del summary_ref[:]

    try:
        # Temp workaround for tensorflow 0.7.0 relative path issue
        if not model_file.startswith(('/', '~')):
            model_file = './' + model_file
        self.saver.save(self.session, model_file, global_step=global_step)
    finally:
        # 0.7 workaround, restore values. Done in `finally` so a failed save
        # does not leave the summary tags collection empty.
        for t in summary_backup:
            tf.add_to_collection("summary_tags", t)
def filter_trainable_variables(trainable_scopes):
    """Keep only trainable variables which are prefixed with given scopes.

    Args:
      trainable_scopes: either list of trainable scopes or string with comma
        separated list of trainable scopes.

    This function removes all variables which are not prefixed with given
    trainable_scopes from collection of trainable variables.
    Useful during network fine tuning, when you only need to train subset of
    variables.
    """
    if not trainable_scopes:
        return
    if isinstance(trainable_scopes, six.string_types):
        trainable_scopes = [scope.strip() for scope in trainable_scopes.split(',')]
    # Drop empty entries (e.g. from "a,,b" or a trailing comma); str.startswith
    # accepts a tuple of prefixes, so keep them as one.
    prefixes = tuple({scope for scope in trainable_scopes if scope})
    if not prefixes:
        return

    trainable_collection = tf.get_collection_ref(
        tf.GraphKeys.TRAINABLE_VARIABLES)
    # Filter in place (slice assignment keeps the same list object, which is
    # the graph's collection) in a single O(n) pass instead of repeated
    # list.remove calls.
    trainable_collection[:] = [
        v for v in trainable_collection if v.op.name.startswith(prefixes)
    ]
def inference(reader, train_dir, data_pattern, out_file_location, batch_size,
              top_k):
    """Run a saved model over the input data and write top-k predictions.

    Restores the meta graph and weights from ``train_dir`` (the latest
    checkpoint, or ``model.ckpt-<FLAGS.check_point>`` when that flag is
    non-negative), then streams batches through the "predictions" tensor and
    writes one CSV section per batch via ``format_lines`` until the input
    queue is exhausted.

    Args:
      reader: reader object passed to ``get_input_data_tensors``.
      train_dir: directory holding the checkpoints.
      data_pattern: file pattern for the input data.
      out_file_location: path of the CSV file to write.
      batch_size: number of examples per batch.
      top_k: number of predictions per example passed to ``format_lines``.

    Raises:
      Exception: if no checkpoint can be found in ``train_dir``.
    """
    with tf.Session() as sess, gfile.Open(out_file_location, "w+") as out_file:
        video_id_batch, video_batch, num_frames_batch = get_input_data_tensors(
            reader, data_pattern, batch_size)
        latest_checkpoint = tf.train.latest_checkpoint(train_dir)
        if latest_checkpoint is None:
            raise Exception(
                "unable to find a checkpoint at location: %s" % train_dir)
        if FLAGS.check_point >= 0:
            # A specific checkpoint was requested; take it from the same
            # directory that was searched above.
            latest_checkpoint = (
                train_dir + "/model.ckpt-" + str(FLAGS.check_point))
        meta_graph_location = latest_checkpoint + ".meta"

        logging.info("loading meta-graph: " + meta_graph_location)
        saver = tf.train.import_meta_graph(meta_graph_location,
                                           clear_devices=True)
        logging.info("restoring variables from " + latest_checkpoint)
        saver.restore(sess, latest_checkpoint)
        input_tensor = tf.get_collection("input_batch_raw")[0]
        num_frames_tensor = tf.get_collection("num_frames")[0]
        predictions_tensor = tf.get_collection("predictions")[0]

        # Workaround for num_epochs issue.
        def set_up_init_ops(variables):
            # Variables whose name contains "train_input" are assigned 1
            # instead of being initialized. NOTE: they are also removed from
            # `variables`, which here is the graph's LOCAL_VARIABLES collection.
            init_op_list = []
            for variable in list(variables):
                if "train_input" in variable.name:
                    init_op_list.append(tf.assign(variable, 1))
                    variables.remove(variable)
            init_op_list.append(tf.variables_initializer(variables))
            return init_op_list

        sess.run(set_up_init_ops(tf.get_collection_ref(
            tf.GraphKeys.LOCAL_VARIABLES)))

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        num_examples_processed = 0
        start_time = time.time()
        out_file.write("VideoId,LabelConfidencePairs\n")

        try:
            while not coord.should_stop():
                video_id_batch_val, video_batch_val, num_frames_batch_val = sess.run(
                    [video_id_batch, video_batch, num_frames_batch])
                predictions_val, = sess.run(
                    [predictions_tensor],
                    feed_dict={input_tensor: video_batch_val,
                               num_frames_tensor: num_frames_batch_val})
                now = time.time()
                num_examples_processed += len(video_batch_val)
                logging.info("num examples processed: " +
                             str(num_examples_processed) +
                             " elapsed seconds: " +
                             "{0:.2f}".format(now - start_time))
                for line in format_lines(video_id_batch_val, predictions_val,
                                         top_k):
                    out_file.write(line)
                out_file.flush()
        except tf.errors.OutOfRangeError:
            # Raised by the input queue once all data has been consumed.
            logging.info('Done with inference. The output file was written to '
                         + out_file_location)
        finally:
            # The session itself is closed by the enclosing `with`.
            coord.request_stop()
            coord.join(threads)
def fix_saver(collection_lists=None):
    """Strip or restore graph collections around a Saver call.

    Called without arguments, empties the "summary_tags", DATA_PREP, DATA_AUG,
    EXCL_RESTORE_VARS and GRAPH_CONFIG collections and returns their former
    contents. Called with that returned list, appends the saved items back to
    their collections.

    Args:
        collection_lists: ``None`` to strip, or the list previously returned by
            this function to restore.

    Returns:
        When stripping: ``[summary_tags, data_prep, data_aug,
        excl_restore_vars, graph_config]`` item lists. When restoring: None.
    """
    # Workaround to prevent serialization warning by removing objects

    def _empty_group(*keys):
        # Fetch a group of collections with the newest API (falling back for
        # the whole group), snapshot their contents, then clear them.
        try:
            # Try latest api
            refs = [tf.get_collection_ref(key) for key in keys]
        except Exception:
            refs = [tf.get_collection(key) for key in keys]
        snapshots = [list(ref) for ref in refs]
        for ref in refs:
            del ref[:]
        return snapshots

    if collection_lists is not None:
        # 0.7+ workaround, restore values
        restore_plan = (
            (0, "summary_tags"),
            (4, tf.GraphKeys.GRAPH_CONFIG),
            (1, tf.GraphKeys.DATA_PREP),
            (2, tf.GraphKeys.DATA_AUG),
            (3, tf.GraphKeys.EXCL_RESTORE_VARS),
        )
        for position, key in restore_plan:
            for item in collection_lists[position]:
                tf.add_to_collection(key, item)
        return None

    summary_tags, graph_config = _empty_group(
        "summary_tags", tf.GraphKeys.GRAPH_CONFIG)
    data_prep, data_aug = _empty_group(
        tf.GraphKeys.DATA_PREP, tf.GraphKeys.DATA_AUG)
    # Do not save exclude variables
    excluded, = _empty_group(tf.GraphKeys.EXCL_RESTORE_VARS)
    return [summary_tags, data_prep, data_aug, excluded, graph_config]
def _variable_tracking_custom_getter(getter, *args, **kwargs):
    """Custom getter that records the variables it causes to be created.

    Any variable created while `getter` runs is added to the `_all_variables`
    attribute of the `AbstractModule` currently on top of the default graph's
    module call stack (the graph-dependent stack tracking sonnet module call
    order).

    Detection works by comparing the lengths of the
    `tf.GraphKeys.LOCAL_VARIABLES` and `tf.GraphKeys.GLOBAL_VARIABLES`
    collections before and after the call. This relies on two assumptions:
    new variables are appended to the end of those collections (which is what
    `tf.add_to_collection()` does, and `tf.Variable` registers itself through
    `tf.add_to_collections()`), and every variable lands in at least one of
    those two collections.

    Args:
      getter: The true getter or another custom getter.
      *args: See positional arguments for `tf.get_variable()`.
      **kwargs: See keyword arguments for `tf.get_variable()`.

    Returns:
      See docstring for `tf.get_variable()`.
    """
    # Module currently calling `tf.get_variable()`.
    owner = _MODULE_STACKS[tf.get_default_graph()][-1]

    # Live references (not copies), so variables created by `getter` appear
    # in them afterwards.
    watched = [
        tf.get_collection_ref(key)
        for key in (tf.GraphKeys.LOCAL_VARIABLES,
                    tf.GraphKeys.GLOBAL_VARIABLES)
    ]
    sizes_before = [len(coll) for coll in watched]

    result = getter(*args, **kwargs)

    # pylint: disable=protected-access
    for coll, size in zip(watched, sizes_before):
        owner._all_variables.update(coll[size:])
    # pylint: enable=protected-access
    return result
def local_state_variables(init_values, return_init_values):
    """Create zero-initialized local variables shadowing `init_values`.

    One local variable is created per init value, named through
    `create_local_state_variable_name` plus a per-name use counter so that
    repeated calls produce distinct names. A boolean tensor selects whether
    the returned values come from the original init values or from the new
    local variables.

    Args:
      init_values: iterable of tensors
      return_init_values: boolean tensor

    Returns:
      local_vars: list of the created local variables.
      vals: if return_init_values is true, then this returns the values of
        init_values. Otherwise it returns the values of the local_vars.
    """
    if not init_values:
        return [], []

    # This generates a harmless warning when saving the metagraph.
    counter_holder = tf.get_collection_ref(_LOCAL_STATE_VARIABLE_COLLECTION)
    if not counter_holder:
        counter_holder.append(collections.defaultdict(int))
    use_counts = counter_holder[0]

    created = []
    with tf.variable_scope(OPTIMIZER_SCOPE):
        # The init_value is not used as the initializer because it may itself
        # depend on other problem variables, which would introduce
        # inter-variable initialization-order dependencies.
        for value in init_values:
            base_name = create_local_state_variable_name(value)
            suffix = use_counts[base_name]
            use_counts[base_name] = suffix + 1
            # Uniquifying by name lets variables be reused between different
            # sessions on the same TensorFlow master without errors, mirroring
            # TensorFlow's own checks (a hack around a broken Session.reset()).
            zeros = tf.zeros(value.get_shape(), dtype=value.dtype)
            created.append(
                tf.get_local_variable(base_name + "_" + str(suffix),
                                      initializer=zeros))

    # Using init_value on the first iteration instead of the variable lets
    # gradients propagate through it and simplifies initialization; the
    # variable is assigned after the first iteration.
    vals = tf.cond(return_init_values, lambda: init_values, lambda: created)
    if len(init_values) == 1:
        # tf.cond extracts elements from singleton lists.
        vals = [vals]
    return created, vals
def body(self, features):
    """Build the teacher network and, in the distill phase, the student.

    In the distill phase the teacher weights are loaded from
    ``hp.teacher_dir`` and removed from training. The student is trained on a
    ``hp.task_balance`` mix of the hard-label cross-entropy and the
    temperature-softened cross-entropy against the teacher's predictions.

    Args:
      features: dict containing "targets_raw" (squeezed on axes 1-3 to
        integer class ids) plus whatever the wrapped models' ``body`` reads.

    Returns:
      (outputs, losses): the active network's logits reshaped to
      ``[-1, 1, 1, 1, num_classes]``, and ``{"training": phase_loss}``.
    """
    hp = self.hparams
    is_distill = hp.distill_phase == "distill"

    targets = features["targets_raw"]
    targets = tf.squeeze(targets, [1, 2, 3])
    one_hot_targets = tf.one_hot(targets, hp.num_classes, dtype=tf.float32)

    # Teacher Network
    with tf.variable_scope("teacher"):
        teacher_outputs = self.teacher_model.body(features)
        tf.logging.info("teacher output shape: %s" % teacher_outputs.get_shape())
        teacher_outputs = tf.reduce_mean(teacher_outputs, axis=[1, 2])
        teacher_logits = tf.layers.dense(teacher_outputs, hp.num_classes)

        teacher_task_xent = tf.nn.softmax_cross_entropy_with_logits_v2(
            labels=one_hot_targets, logits=teacher_logits)
        outputs = teacher_logits

    if is_distill:
        # Load teacher weights
        tf.train.init_from_checkpoint(hp.teacher_dir, {"teacher/": "teacher/"})
        # Do not train the teacher: clear every variable marked trainable so
        # far (the student is built below, after this point).
        trainable_vars = tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)
        del trainable_vars[:]

    # Student Network
    if is_distill:
        with tf.variable_scope("student"):
            student_outputs = self.student_model.body(features)
            tf.logging.info(
                "student output shape: %s" % student_outputs.get_shape())
            student_outputs = tf.reduce_mean(student_outputs, axis=[1, 2])
            student_logits = tf.layers.dense(student_outputs, hp.num_classes)

            student_task_xent = tf.nn.softmax_cross_entropy_with_logits_v2(
                labels=one_hot_targets, logits=student_logits)
            temperature = hp.distill_temperature
            teacher_targets = tf.nn.softmax(teacher_logits / temperature)
            # Soften the student's logits with the same temperature as the
            # teacher's, and rescale by T^2 so the soft-target gradients keep
            # the same magnitude as the hard-target ones.
            student_distill_xent = tf.nn.softmax_cross_entropy_with_logits_v2(
                labels=tf.stop_gradient(teacher_targets),
                logits=student_logits / temperature)
            student_distill_xent *= temperature ** 2
            outputs = student_logits

            # Summaries: the cross-entropy is per-example, so log its mean
            # (tf.summary.scalar needs a single value).
            tf.summary.scalar("distill_xent",
                              tf.reduce_mean(student_distill_xent))

    if not is_distill:
        phase_loss = teacher_task_xent
    else:
        phase_loss = hp.task_balance * student_task_xent
        phase_loss += (1 - hp.task_balance) * student_distill_xent

    losses = {"training": phase_loss}
    outputs = tf.reshape(outputs, [-1, 1, 1, 1, outputs.shape[1]])

    return outputs, losses
def variables_to_save(addlist):
    """Collect the variables that must be checkpointed.

    Args:
      addlist: list of extra variables to append.

    Returns:
      A new list: all trainable variables, then the contents of the
      REQUIRED_NON_TRAINABLES collection, then `addlist`.
    """
    trained = tf.trainable_variables()
    required = tf.get_collection_ref(REQUIRED_NON_TRAINABLES)
    return trained + required + addlist
def load_model(sess, checkpoint_path):
    """Import a checkpoint's meta graph under scope 'm2' and restore it.

    After restoring the weights, the ops returned by `set_up_init_ops` for the
    graph's LOCAL_VARIABLES collection are run.
    """
    meta_path = checkpoint_path + '.meta'
    restorer = tf.train.import_meta_graph(
        meta_path, clear_devices=True, import_scope='m2')
    restorer.restore(sess, checkpoint_path)

    local_vars = tf.get_collection_ref(tf.GraphKeys.LOCAL_VARIABLES)
    sess.run(set_up_init_ops(local_vars))
def import_ops(self, num_gpus=1):
    """Imports ops from collections.

    Each op/tensor is the first element of its graph collection. Training-only
    ops ("train_op", "lr", "new_lr", "lr_update") are fetched only when
    `self._is_training`; the rest use collection names prefixed with
    `self._name`.

    Args:
      num_gpus: number of replicas used for the "Train" model's states.
    """
    def first(key):
        return tf.get_collection_ref(key)[0]

    def first_scoped(key):
        return first(util.with_prefix(self._name, key))

    if self._is_training:
        self._train_op = first("train_op")
        self._lr = first("lr")
        self._new_lr = first("new_lr")
        self._lr_update = first("lr_update")

    self._cost = first_scoped("cost")
    self._kl_loss = first_scoped("kl_div")
    self._input_data = first_scoped("input_data")
    self._output = first_scoped("output")
    self._targets = first_scoped("targets")

    # Only the model named "Train" uses num_gpus replicas; others use one.
    replicas = num_gpus if self._name == "Train" else 1
    self._initial_state = util.import_state_tuples(
        self._initial_state, self._initial_state_name, replicas)
    self._final_state = util.import_state_tuples(
        self._final_state, self._final_state_name, replicas)
def load_prior(config, sess, saver):
    """Restore the prior model and copy its weights into the prior variables.

    After `saver` restores `config.prior_model`, each trainable variable `v`
    is assigned into the variable named ``'loss/prior/' + v.name`` from the
    'prior_variables' collection.

    Raises:
      KeyError: if a trainable variable has no matching prior variable.
    """
    prior_path = os.path.abspath(config.prior_model)
    logging.info('Loading prior model parameters from file ' + prior_path)
    saver.restore(sess, prior_path)

    # fill prior variables with the loaded values
    prior_variables = tf.get_collection_ref('prior_variables')
    prior_by_name = {v.name: v for v in prior_variables}
    assign_tensors = []
    with tf.variable_scope('prior'):
        for v in tf.trainable_variables():
            prior_variable = prior_by_name['loss/prior/' + v.name]
            assign_tensors.append(prior_variable.assign(v))
    # The former `tf.variables_initializer(prior_variables)` call built an op
    # that was never run; it only added dead nodes to the graph, so it is gone.
    sess.run(assign_tensors)
def MarkAsNonTrainable(self):
    """Remove every variable of this block from the trainable collection.

    All variables owned directly or indirectly (through subblocks), as
    returned by `VariableList()`, are removed from
    `tf.GraphKeys.TRAINABLE_VARIABLES`. Together with CheckpointInitOp this
    lets a pretrained model cover only part of the whole graph.

    The block must already have been called (asserted).
    """
    assert self._called
    owned = self.VariableList()
    trainable = tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)
    for var in owned:
        try:
            trainable.remove(var)
        except ValueError:
            # Not (or no longer) in the trainable collection.
            pass
def _mk_training_step(self):
    """Function to construct the training steps and procedure.

    Builds the optimizer (Adagrad, or plain SGD when `self.sgd`), the
    optionally regularised loss, the optional Stiefel-manifold projection of
    gradients for variables in the 'Unitary' collection, and the 'svd'
    re-projection ops. Also creates the pre-svd summaries in `self._pre_svd`.

    Returns:
        The operation that performs one training iteration.
    """
    def _sum_or_zero(terms):
        # tf.add_n rejects an empty list, so use a scalar zero when nothing
        # has been registered in the collection.
        if not terms:
            return tf.constant(0, dtype=tf.float32)
        return tf.add_n(terms)

    # Training methods
    if not self.sgd:
        self._optimizer = tf.train.AdagradOptimizer(
            self.lr, initial_accumulator_value=self.init_value_ada)
    else:
        self._optimizer = tf.train.GradientDescentOptimizer(self.lr)

    if not self.reg_unary:
        _reg = tf.constant(0, dtype=tf.float32)
    else:
        _reg = _sum_or_zero(tf.get_collection_ref("regularisation"))
        _reg /= self.n_layers

    # For parameters A of the layers, use Adagrad in the Stiefel manifold
    grads_and_vars = self._optimizer.compute_gradients(
        self._cost + self.reg_scale * _reg)
    if self.proj_A:
        # Fetch the collection once rather than once per variable.
        unitary = tf.get_collection('Unitary')
        for i, (g, v) in enumerate(grads_and_vars):
            # Variables that do not affect the loss have a None gradient.
            if g is not None and v in unitary:
                # gA = gA - A.gA^T.A
                g = g - tf.matmul(v, tf.matmul(g, v, transpose_a=True))
                grads_and_vars[i] = (g, v)

    _train = self._optimizer.apply_gradients(
        grads_and_vars, global_step=self.global_step)

    _svd = tf.get_collection('svd')
    if self.manifold:
        # Use the dependency to project only once Adagrad has done its step
        with tf.control_dependencies([_train]):
            _train = tf.group(*_svd)
    else:
        self._svd = tf.group(*_svd)

    s1 = tf.summary.scalar("cost (pre svd)",
                           self._cost - self.feed_map['c_val'])
    # summary to track manifold deviation
    s2 = tf.summary.scalar(
        'cost_manifold', _sum_or_zero(tf.get_collection("regularisation")))
    self._pre_svd = tf.summary.merge([s1, s2])
    return _train
def test_tied_weights_untied_bias_registered_bias(self):
    """Tests that ambiguity in graph raises value error.

    Graph search finds several possible registrations for the toy model's
    tensors. Linking only b_1 leaves the other branch of the graph ambiguous,
    so registration must raise AmbiguousRegistrationError.
    """
    with tf.Graph().as_default():
        tensors = _build_model()
        collection = lc.LayerCollection()
        for loss_key in ('out_0', 'out_1'):
            collection.register_squared_error_loss(tensors[loss_key])
        collection.define_linked_parameters(tensors['b_1'])

        with self.assertRaises(gs.AmbiguousRegistrationError):
            gs.register_layers(
                collection,
                tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES))
def make_major_table(direction, piece, pin_directions):
    """Build (or fetch from the graph cache) an int32 lookup table for a piece.

    The table has ``Piece.SIZE + 14 * PinDirection.SIZE`` entries: -1 at
    ``Piece.EMPTY``, 1 at ``piece``, and, for each pin direction in
    ``pin_directions``, a 0/1 flag (1 when it equals the normalized
    ``direction``) at ``Piece.SIZE + piece - Piece.BLACK_FU + 14 * pin``.

    The NumPy array is cached in the graph collection
    ``'black_<piece>_<pin direction>'``; later calls with the same key return
    the cached array without rebuilding it.
    """
    normalized_direction = PinDirection[direction.name]
    name = 'black_{}_{}'.format(piece.name, normalized_direction.name)
    collection = tf.get_collection_ref(name)
    if collection:
        return collection[0]

    # Only allocate the table on a cache miss.
    table = np.zeros(Piece.SIZE + 14 * PinDirection.SIZE, dtype=np.int32)
    table[Piece.EMPTY] = -1
    table[piece] = 1
    # Raise the flag for the pin direction that matches the move direction.
    offset = Piece.SIZE + piece - Piece.BLACK_FU
    for pin_direction in pin_directions:
        table[offset + 14 * pin_direction] = pin_direction == normalized_direction
    tf.add_to_collection(name, table)
    return table
def _create_saved_model_impl(inputs, operation, extra_args):
    """Create a SavedModel from a TF Graph.

    The graph's TABLE_INITIALIZERS collection is temporarily replaced by
    `operation.table_initializers` while variables are initialized and the
    saved transform is written. The original contents are restored
    afterwards, even if either step fails.

    Returns:
      `inputs | operation.label >> _BindTensors(...)` bound to the unique
      temporary directory the model was written to.
    """
    unbound_saved_model_dir = common.get_unique_temp_path(
        extra_args.base_temp_dir)
    with extra_args.graph.as_default():
        with tf.Session(graph=extra_args.graph) as session:
            table_initializers_ref = tf.get_collection_ref(
                tf.GraphKeys.TABLE_INITIALIZERS)
            original_table_initializers = list(table_initializers_ref)
            try:
                del table_initializers_ref[:]
                table_initializers_ref.extend(operation.table_initializers)
                # Initialize all variables so they can be saved.
                session.run(tf.global_variables_initializer())
                saved_transform_io.write_saved_transform_from_session(
                    session, extra_args.input_signature,
                    operation.output_signature, unbound_saved_model_dir)
            finally:
                # Never leave the shared graph's collection modified.
                del table_initializers_ref[:]
                table_initializers_ref.extend(original_table_initializers)
    return inputs | operation.label >> _BindTensors(
        extra_args.base_temp_dir, unbound_saved_model_dir, extra_args.pipeline)
def __init__(self, sess, checkpoint_dir, log_dir, training_paths,
             testing_paths, batch_size=1, layers=3, features_root=32,
             conv_size=3, dropout=0.5, loss_type='cross_entropy',
             class_weights=None):
    """Store the configuration, infer patch geometry, build the graph.

    Patch size and channel count are read from patch '0' of the first
    training directory, and the number of patches per image is the number of
    entries in that directory. After `build_model()`, a saver over the
    trainable variables plus the 'bn_collections' collection is created.
    """
    self.sess = sess
    self.checkpoint_dir = checkpoint_dir
    self.log_dir = log_dir
    self.training_paths = training_paths
    self.testing_paths = testing_paths
    self.nclass = 2

    first_train_dir = self.training_paths[0]
    sample_patch, _ = read_patch(os.path.join(first_train_dir, '0'),
                                 self.nclass)
    self.batch_size = batch_size
    self.patch_size = sample_patch.shape[:-1]
    self.patch_stride = 4  # Used in deploy
    self.channel = sample_patch.shape[-1]

    # Network hyper-parameters.
    self.layers = layers
    self.features_root = features_root
    self.conv_size = conv_size
    self.dropout = dropout
    self.loss_type = loss_type
    self.class_weights = class_weights
    self.patches_per_image = len(os.listdir(first_train_dir))

    self.build_model()

    saved_vars = (tf.trainable_variables() +
                  tf.get_collection_ref('bn_collections'))
    self.saver = tf.train.Saver(saved_vars)
def build_graph(self, trainer, test_mode=False):
    """Create summaries (outside test mode) and the checkpoint saver.

    When `trainer.train_extrap` is set, the saver is limited to global
    variables whose op name exists in the checkpoint directory
    (`self.force_load_from_dir` if set, else `self.ckpt_dir`).
    """
    if test_mode:
        self.saver = tf.train.Saver(max_to_keep=3)
        return

    self._build_numerical_summaries(trainer)
    self._build_img_summaries(trainer)

    if not trainer.train_extrap:
        self.saver = tf.train.Saver(max_to_keep=3)
        return

    # Extrap training has an additional Adam optimizer with parameters not
    # existed in the ckpt, so only variables present in the ckpt are saved.
    ckpt_dir = self.force_load_from_dir or self.ckpt_dir
    names_in_ckpt = {entry[0] for entry in tf.train.list_variables(ckpt_dir)}
    global_vars = tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES)
    restore_var = [v for v in global_vars if v.op.name in names_in_ckpt]
    self.saver = tf.train.Saver(max_to_keep=3, var_list=restore_var)
def DISABLED_test_tied_weights_untied_bias_registered_weights(self):
    """Tests that graph search produces right solution on toy model.

    With the shared weight `w` declared as a linked parameter, the searched
    registration must match a manual one: `w` as a multi-use fully-connected
    layer and each bias as a generic block.
    """
    with tf.Graph().as_default():
        tensors = _build_model()

        expected = lc.LayerCollection()
        expected.register_fully_connected_multi(
            tensors['w'],
            (tensors['x'], tensors['y']),
            (tensors['pre_bias_0'], tensors['pre_bias_1']))
        for bias_key in ('b_0', 'b_1'):
            expected.register_generic(tensors[bias_key], batch_size=1)

        searched = lc.LayerCollection()
        searched.define_linked_parameters(tensors['w'])
        gs.register_layers(
            searched,
            tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES),
            batch_size=1)

        assert_fisher_blocks_match(self, searched, expected)
def make_long_kernel(direction, size):
    """Return the float32 long-range kernel constant for direction and size.

    The kernel is built by the builder at index `direction.value` and cached in
    the graph collection 'long_kernel_<direction><size>'; later calls reuse it.
    """
    name = 'long_kernel_{}{}'.format(direction.name, size)
    cached = tf.get_collection_ref(name)
    if cached:
        return cached[0]

    # Not built yet: pick the builder by direction.value.
    builders = (
        make_long_right_up_kernel, make_long_right_kernel,
        make_long_right_down_kernel, make_long_up_kernel,
        make_long_down_kernel, make_long_left_up_kernel,
        make_long_left_kernel, make_long_left_down_kernel,
    )
    values = builders[direction.value](size=size)
    kernel = tf.constant(values, dtype=tf.float32, name=name)
    tf.add_to_collection(name, kernel)
    return kernel
def apply_mask(x, scope='', prune_option='weight'):
  """Apply mask to a given weight tensor.

  Args:
    x: Input weight tensor
    scope: The current variable scope. Defaults to "".
    prune_option: pruning option. Defaults to 'weight'. option =
      'first_order_gradient' means using |weight| * |first order gradient| for
      pruning. option = 'second_order_gradient' means using |weight| * |second
      order gradient| for pruning.

  Returns:
    Tensor representing masked_weights
  """
  gradient_based = prune_option in ('first_order_gradient',
                                    'second_order_gradient')

  mask = pruning_utils.weight_mask_variable(x, scope)
  threshold = pruning_utils.weight_threshold_variable(x, scope)
  # Add masked_weights in the weights namescope so as to make it easier
  # for the quantization library to add quant ops.
  masked_weights = tf.multiply(mask, x, MASKED_WEIGHT_NAME)

  if gradient_based:
    # absolute value of gradients for gradient based pruning
    gradient = pruning_utils.weight_gradient_variable(x, scope)
    old_weight = pruning_utils.old_weight_variable(x, scope)
    old_old_weight = pruning_utils.old_old_weight_variable(x, scope)

  # A mask already in MASK_COLLECTION means this variable was registered
  # before (e.g. RNN weights masked more than once); do not add it again.
  if mask in tf.get_collection_ref(MASK_COLLECTION):
    return masked_weights

  tf.add_to_collection(THRESHOLD_COLLECTION, threshold)
  tf.add_to_collection(MASK_COLLECTION, mask)
  tf.add_to_collection(MASKED_WEIGHT_COLLECTION, masked_weights)
  tf.add_to_collection(WEIGHT_COLLECTION, x)
  if gradient_based:
    tf.add_to_collection(WEIGHT_GRADIENT_COLLECTION, gradient)
    tf.add_to_collection(OLD_WEIGHT_COLLECTION, old_weight)
    tf.add_to_collection(OLD_OLD_WEIGHT_COLLECTION, old_old_weight)
  return masked_weights
def select_black_major_piece(board, data_format, direction, piece):
    """Return the one-hot map for a black major piece moving in `direction`.

    The result is cached in the graph collection '<piece>_<pin direction>'.
    NOTE(review): the cache key does not include `board`, so a second board in
    the same graph would receive the first board's result; confirm only one
    board is used per graph.

    Raises:
        ValueError: if `direction` is not one of `piece`'s move directions.
    """
    if direction not in get_directions(piece=piece):
        raise ValueError(direction)

    # Computed per pin direction: there are four directions, but two suffice.
    normalized_direction = PinDirection[direction.name]
    name = '{}_{}'.format(piece.name, normalized_direction.name)
    cached = tf.get_collection_ref(name)
    if cached:
        return cached[0]

    table = make_major_table(direction=direction, piece=piece,
                             pin_directions=get_pin_directions(piece=piece))
    gathered = tf.gather(table, board)
    one_hot = make_one_hot(index_board=gathered, data_format=data_format)
    tf.add_to_collection(name, one_hot)
    return one_hot
def __call__(self, inputs, input_seq_length, is_training):
    """
    Add the neural net variables and operations to the graph.

    After the first call the model scope is switched to reuse mode, so later
    calls share the same weights. If the model has a falsy `trainable`
    attribute, its variables are removed from the trainable collection.

    Args:
        inputs: the inputs to the neural network, this is a dictionary of
            [batch_size x time x ...] tensors
        input_seq_length: The sequence lengths of the input utterances, this
            is a dictionary of [batch_size] vectors
        is_training: whether or not the network is in training mode

    Returns:
        - output logits, which is a dictionary of [batch_size x time x ...]
          tensors
        - the output logits sequence lengths which is a dictionary of
          [batch_size] vectors
    """
    logits = self._get_outputs(inputs=inputs,
                               input_seq_length=input_seq_length,
                               is_training=is_training)
    self.scope.reuse_variables()

    if getattr(self, 'trainable', True):
        return logits

    # Frozen model: drop its variables from the trainable collection.
    own_variables = self.variables
    trainable = tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)
    for var in own_variables:
        if var in trainable:
            trainable.remove(var)
    return logits
def get_short_effect(board, direction, data_format, use_cudnn):
    """Return the black OU short-range effect in `direction`.

    For the eight basic directions the result is cached in the graph
    collection 'black_ou_short_move_<direction>'. Knight (keima) moves are
    used only once, so they are not cached.
    NOTE(review): the cache key does not include `board`; confirm only one
    board is used per graph.
    """
    cacheable = direction in get_eight_directions()
    cache_name = None
    if cacheable:
        cache_name = 'black_ou_short_move_{}'.format(direction.name)
        cached = tf.get_collection_ref(cache_name)
        if cached:
            return cached[0]

    ou = get_short_ou(board=board)
    layer = ShortEffectLayer(
        direction=direction, data_format=data_format, use_cudnn=use_cudnn,
        name='black_ou_short_{}'.format(direction.name))
    effect = layer(ou)

    if cacheable:
        tf.add_to_collection(cache_name, effect)
    return effect
def test_multitower_multi_loss_function(self):
    """Test multitower setup with multiple loss functions.

    The automatic graph scanner should handle multiple loss functions per
    tower, as long as they're registered in a consistent order.
    """
    with tf.Graph().as_default():
        w_1 = tf.get_variable('w_1', shape=[10, 10])
        b_1 = tf.get_variable('b_1', shape=[10])
        w_2 = tf.get_variable('w_2', shape=[10, 10])
        b_2 = tf.get_variable('b_2', shape=[10])
        searched = lc.LayerCollection()
        expected = lc.LayerCollection()

        for tower_num in range(5):
            x = tf.placeholder(tf.float32, shape=(32, 10))
            logits_1 = tf.matmul(x, w_1) + b_1
            logits_2 = tf.matmul(x, w_2) + b_2
            # Only the first tower creates variables in its scope.
            with tf.variable_scope('tower%d' % tower_num,
                                   reuse=tower_num != 0):
                for collection in (searched, expected):
                    collection.register_categorical_predictive_distribution(
                        logits_1, name='loss_1')
                    collection.register_categorical_predictive_distribution(
                        logits_2, name='loss_2')
                expected.register_fully_connected((w_1, b_1), x, logits_1)
                expected.register_fully_connected((w_2, b_2), x, logits_2)

        gs.register_layers(
            searched, tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES))

        assert_fisher_blocks_match(self, searched, expected)
def select_black_ke(board, direction):
    """
    Map that is 1 where a black KE (knight) can move and 0 elsewhere.

    :param board:
    :param direction: Direction.RIGHT_UP_UP or Direction.LEFT_UP_UP
    :return: the result of select_black_piece for Piece.BLACK_KE, cached in
        the graph collection 'black_ke_piece'
    :raises ValueError: for any other direction
    """
    if direction not in (Direction.RIGHT_UP_UP, Direction.LEFT_UP_UP):
        raise ValueError(direction)

    # Both directions return the same value, so one cache entry serves both.
    # NOTE(review): the cache key ignores `board`; confirm one board per graph.
    cache_name = 'black_ke_piece'
    cached = tf.get_collection_ref(cache_name)
    if cached:
        return cached[0]

    ke = select_black_piece(board=board, piece=Piece.BLACK_KE,
                            direction=direction)
    tf.add_to_collection(cache_name, ke)
    return ke
def build_savers(self):
    """Create tf.train.Saver instances.

    One saver is created for each variable-name prefix referenced in the
    'loss_terms_to_optimize' entries of the model's learning schedule. It
    covers the global variables, saveable objects, moving-average variables
    and 'batch_norm_non_trainable' entries whose name starts with
    ``<prefix>/``. Prefixes matching nothing get no saver.
    """
    candidates = (
        tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) +
        tf.get_collection(tf.GraphKeys.SAVEABLE_OBJECTS) +
        tf.get_collection(tf.GraphKeys.MOVING_AVERAGE_VARIABLES) +
        tf.get_collection_ref('batch_norm_non_trainable'))
    # The same object can sit in several of these collections (e.g. a
    # variable tracked by a moving average is typically also a global
    # variable). tf.train.Saver rejects duplicate names, so keep each object
    # once (first occurrence).
    seen_ids = set()
    unique = []
    for obj in candidates:
        if id(obj) not in seen_ids:
            seen_ids.add(id(obj))
            unique.append(obj)
    all_saveable_vars = sorted(unique, key=lambda v: v.name)

    all_prefixes = set()
    for schedule in self._model._learning_schedule:
        for prefixes in schedule['loss_terms_to_optimize'].values():
            all_prefixes.update(prefixes)

    # For each prefix, create saver
    self._savers = {}
    for prefix in all_prefixes:
        vars_to_save = [
            v for v in all_saveable_vars if v.name.startswith(prefix + '/')
        ]
        if vars_to_save:
            self._savers[prefix] = tf.train.Saver(vars_to_save, max_to_keep=5)
def make_black_direction_table(direction):
    """Return the cached float32 table of black moves in `direction`.

    Black piece slots ``[Piece.BLACK_FU, Piece.WHITE_FU)`` get
    ``make_base_table()[direction]``. Except for RIGHT_UP_UP / LEFT_UP_UP, the
    14-entry block at ``Piece.SIZE + 14 * PinDirection[direction.name]`` gets
    the same values. The resulting constant is cached in the graph collection
    'black_naive_direction_table_<direction>'.

    Raises:
        ValueError: for RIGHT_DOWN_DOWN or LEFT_DOWN_DOWN.
    """
    if direction in (Direction.RIGHT_DOWN_DOWN, Direction.LEFT_DOWN_DOWN):
        raise ValueError(direction)

    name = 'black_naive_direction_table_{}'.format(direction.name)
    cached = tf.get_collection_ref(name)
    if cached:
        return cached[0]

    base = make_base_table()
    values = np.zeros(Piece.SIZE + 4 * 14)
    moves = base[direction]
    values[Piece.BLACK_FU:Piece.WHITE_FU] = moves
    has_pin_block = direction not in (Direction.RIGHT_UP_UP,
                                      Direction.LEFT_UP_UP)
    if has_pin_block:
        start = Piece.SIZE + PinDirection[direction.name] * 14
        values[start:start + 14] = moves

    table = tf.constant(values, dtype=tf.float32)
    tf.add_to_collection(name, table)
    return table
def mixed_usage_test(self):
    """Tests that graph search raises error on mixed types usage for tensors.

    Tensors can be reused in various locations in the tensorflow graph. This
    occurs regularly in the case of recurrent models or models with parallel
    graphs. However the tensors must be used for the same operation in each
    location or graph search should raise an error.
    """
    # NOTE(review): this method name lacks the "test" prefix, so the unittest
    # loader will not collect it; confirm that is intentional.
    with tf.Graph().as_default():
        w = tf.get_variable('W', [10, 10])
        x = tf.placeholder(tf.float32, shape=(32, 10))
        y = tf.placeholder(tf.float32, shape=(32, 10, 10))
        # Use `w` once as a matmul operand and once in an elementwise add.
        matmul_use = tf.matmul(x, w)  # pylint: disable=unused-variable
        add_use = y + w  # pylint: disable=unused-variable

        collection = lc.LayerCollection()
        with self.assertRaises(ValueError) as ctx:
            gs.register_layers(
                collection,
                tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES))
        self.assertIn('mixed record types', str(ctx.exception))
def load_pretrained(self, sess, ckpt_file):
    """Restore graph variables that exist in `ckpt_file`, skipping the head.

    When switching from unsupervised to supervised learning, the checkpoint
    has no supervised-layer weights. Only the shared layer weights are
    restored; the listed supervised_fc entries are ignored.

    Args:
        sess: session to restore into.
        ckpt_file: checkpoint path readable by NewCheckpointReader.
    """
    # Checkpoint entries that must never be restored.
    skipped = {
        'supervised_fc/supervised_fc_256/bias',
        'supervised_fc/supervised_fc_256/kernel',
        'supervised_fc/supervised_fc/kernel',
        'supervised_fc/supervised_fc/bias',
    }
    # Key graph variables by checkpoint-style name: `v.name[:-2]` drops the
    # last two characters, which assumes a ":0" output suffix. The first
    # matching graph variable wins.
    graph_vars = {}
    for v in tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES):
        graph_vars.setdefault(v.name[:-2], v)

    reader = pywrap_tensorflow.NewCheckpointReader(ckpt_file)
    common_variables = [
        graph_vars[name]
        for name in reader.get_variable_to_shape_map()
        if name not in skipped and name in graph_vars
    ]
    temp_saver = tf.train.Saver(var_list=common_variables)
    temp_saver.restore(sess, ckpt_file)
def test_multi_time_batch_fold(self):
    """Graph search should fold two uses of one linear layer into a multi block.

    The same weight/bias pair is applied to two inputs. Registering via graph
    search with ``batch_size=16`` must yield the same fisher blocks as a manual
    ``register_fully_connected_multi`` registration with ``num_uses=2``.
    """
    with tf.Graph().as_default():
        w = tf.get_variable('W', [10, 10])
        b_0 = tf.get_variable('b_0', [10])
        first_in = tf.placeholder(tf.float32, shape=(32, 10))
        second_in = tf.placeholder(tf.float32, shape=(32, 10))
        out_0 = tf.matmul(first_in, w) + b_0
        out_1 = tf.matmul(second_in, w) + b_0

        expected = lc.LayerCollection()
        expected.register_squared_error_loss(out_0)
        expected.register_squared_error_loss(out_1)
        expected.register_fully_connected_multi(
            (w, b_0), (first_in, second_in), (out_0, out_1), num_uses=2)

        searched = lc.LayerCollection()
        for output in (out_0, out_1):
            searched.register_squared_error_loss(output)
        gs.register_layers(
            searched,
            tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES),
            batch_size=16)

        assert_fisher_blocks_match(self, searched, expected)
def run_maze(num_episodes=300):
    """Train the module-level ``RL`` agent on ``env``, then print target-net params.

    Learning is triggered every 5th step once more than 200 steps have been
    taken. The step counter is shared across episodes and is not incremented
    on an episode's final step.

    Args:
        num_episodes: number of episodes to play (default 300).
    """
    step = 0
    for _ in range(num_episodes):
        # initial observation
        observation = env.reset()
        while True:
            # fresh env
            env.render()
            # RL choose action based on observation
            action = RL.choose_action(observation)
            # RL take action and get next observation and reward
            observation_, reward, done = env.step(action)
            RL.store_transition(observation, action, reward, observation_)
            if step > 200 and step % 5 == 0:
                RL.learn()
            # swap observation
            observation = observation_
            # break while loop when end of this episode
            if done:
                break
            step += 1

    # end of game
    print('game over')
    for wb in tf.get_collection_ref('target_net_params'):
        print(wb)
        # NOTE(review): this prints the read tensor object, not its numeric
        # value; evaluating it would require running it in a session.
        print(wb.value())
    env.destroy()
def test_subset_weights_manual_registration(self):
    """Graph search with linked weights should match a hand-built registration.

    One matmul weight ``W`` is applied to two inputs, and a bias is added to
    only the first result. With ``W`` declared as linked parameters, graph
    search must produce the same blocks as a manual multi fully-connected
    registration for ``W`` plus a generic registration for the bias.
    """
    with tf.Graph().as_default():
        w = tf.get_variable('W', [10, 10])
        b_0 = tf.get_variable('b_0', [10])
        x = tf.placeholder(tf.float32, shape=(32, 10))
        y = tf.placeholder(tf.float32, shape=(32, 10))
        x_proj = tf.matmul(x, w)
        out_0 = x_proj + b_0
        out_1 = tf.matmul(y, w)

        manual = lc.LayerCollection()
        manual.register_fully_connected_multi(w, (x, y), (x_proj, out_1))
        manual.register_generic(b_0, batch_size=1)

        searched = lc.LayerCollection()
        searched.register_squared_error_loss(out_0)
        searched.register_squared_error_loss(out_1)
        searched.define_linked_parameters(w)
        all_vars = tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES)
        gs.register_layers(searched, all_vars, batch_size=1)

        assert_fisher_blocks_match(self, searched, manual)
def import_ops(self):
    """Re-bind this model's ops and tensors from the graph's collections.

    Training handles (train op, learning-rate tensors and update op) are read
    only when ``self._is_training``. In that case, if the model has a cell and
    the ``rnn_params`` collection is non-empty, an ``RNNParamsSaveable`` is
    registered under ``SAVEABLE_OBJECTS``. The cost, probabilities and the
    initial/final state tuples are always re-imported.
    """
    def first(key):
        # Each exported op lives alone at index 0 of its collection.
        return tf.get_collection_ref(key)[0]

    if self._is_training:
        self._train_op = first("train_op")
        self._lr = first("lr")
        self._new_lr = first("new_lr")
        self._lr_update = first("lr_update")
        rnn_params = tf.get_collection_ref("rnn_params")
        if self._cell and rnn_params:
            saveable = tf.contrib.cudnn_rnn.RNNParamsSaveable(
                self._cell,
                self._cell.params_to_canonical,
                self._cell.canonical_to_params,
                rnn_params,
                base_variable_scope="Model/RNN")
            tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable)

    self._cost = first(util.with_prefix(self._name, "cost"))
    self._probabilities = first(util.with_prefix(self._name, "probabilities"))
    replicas = FLAGS.num_gpus if self._name == "Train" else 1
    self._initial_state = util.import_state_tuples(
        self._initial_state, self._initial_state_name, replicas)
    self._final_state = util.import_state_tuples(
        self._final_state, self._final_state_name, replicas)
def import_ops(self):
    """Import ops from collections.

    Looks up the model's tensors in the graph collections they were exported
    to and stores them back on ``self``. Training handles are fetched only when
    ``self._is_training``. In that case a cuDNN ``RNNParamsSaveable`` is also
    registered when the model has a cell and ``rnn_params`` is non-empty.
    """
    graph_collection = tf.get_collection_ref
    if self._is_training:
        self._train_op = graph_collection('train_op')[0]
        self._lr = graph_collection('lr')[0]
        self._new_lr = graph_collection('new_lr')[0]
        self._lr_update = graph_collection('lr_update')[0]
        cudnn_params = graph_collection('rnn_params')
        if self._cell and cudnn_params:
            tf.add_to_collection(
                tf.GraphKeys.SAVEABLE_OBJECTS,
                tf.contrib.cudnn_rnn.RNNParamsSaveable(
                    self._cell,
                    self._cell.params_to_canonical,
                    self._cell.canonical_to_params,
                    cudnn_params,
                    base_variable_scope='Model/RNN'))

    cost_key = util.with_prefix(self._name, 'cost')
    self._cost = graph_collection(cost_key)[0]

    replica_count = 1
    if self._name == 'Train':
        replica_count = FLAGS.num_gpus
    self._initial_state = util.import_state_tuples(
        self._initial_state, self._initial_state_name, replica_count)
    self._final_state = util.import_state_tuples(
        self._final_state, self._final_state_name, replica_count)
def import_ops(self):
    """Imports ops from collections.

    Restores training handles (only when ``self._is_training``), optionally
    registers a cuDNN ``RNNParamsSaveable``, and always restores the cost and
    the initial/final state tuples for the appropriate number of replicas.
    """
    training_handles = (
        ("_train_op", "train_op"),
        ("_lr", "lr"),
        ("_new_lr", "new_lr"),
        ("_lr_update", "lr_update"),
    )
    if self._is_training:
        for attr_name, collection_key in training_handles:
            setattr(self, attr_name, tf.get_collection_ref(collection_key)[0])
        params = tf.get_collection_ref("rnn_params")
        if self._cell and params:
            cell = self._cell
            saveable = tf.contrib.cudnn_rnn.RNNParamsSaveable(
                cell,
                cell.params_to_canonical,
                cell.canonical_to_params,
                params,
                base_variable_scope="Model/RNN")
            tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable)

    self._cost = tf.get_collection_ref(
        util.with_prefix(self._name, "cost"))[0]
    is_train_model = self._name == "Train"
    n_replicas = FLAGS.num_gpus if is_train_model else 1
    self._initial_state = util.import_state_tuples(
        self._initial_state, self._initial_state_name, n_replicas)
    self._final_state = util.import_state_tuples(
        self._final_state, self._final_state_name, n_replicas)
def import_ops(self, config):
    """Imports ops from collections.

    Restores the training handles (only when ``self._is_training``) and, if the
    model has a cell and the ``rnn_params`` collection is non-empty, registers a
    ``CudnnLSTMSaveable`` for the cuDNN opaque parameters under
    ``SAVEABLE_OBJECTS``. The cost and the initial/final state tuples are always
    restored.

    Args:
        config: model config; ``num_layers`` and ``hidden_size`` describe the
            cuDNN LSTM (``hidden_size`` is used as both unit count and input size).
    """
    if self._is_training:
        self._train_op = tf.get_collection_ref("train_op")[0]
        self._lr = tf.get_collection_ref("lr")[0]
        self._new_lr = tf.get_collection_ref("new_lr")[0]
        self._lr_update = tf.get_collection_ref("lr_update")[0]
        rnn_params = tf.get_collection_ref("rnn_params")
        if self._cell and rnn_params:
            # The saveable needs the actual opaque cuDNN parameter variable;
            # passing None cannot work. Assumes the collection holds exactly
            # that single variable -- TODO confirm against export_ops.
            params_saveable = tf.contrib.cudnn_rnn.CudnnLSTMSaveable(
                opaque_params=rnn_params[0],
                num_layers=config.num_layers,
                num_units=config.hidden_size,
                input_size=config.hidden_size,
                input_mode=CUDNN_INPUT_LINEAR_MODE,
                direction=CUDNN_RNN_UNIDIRECTION,
                scope="Model/RNN",
                name='cudnn_rnn_saveable')
            tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, params_saveable)
    self._cost = tf.get_collection_ref(util.with_prefix(self._name, "cost"))[0]
    num_replicas = FLAGS.num_gpus if self._name == "Train" else 1
    self._initial_state = util.import_state_tuples(
        self._initial_state, self._initial_state_name, num_replicas)
    self._final_state = util.import_state_tuples(
        self._final_state, self._final_state_name, num_replicas)
def import_ops(self):
    """Restore this model's op handles from the current graph's collections.

    Uses the module-level ``with_prefix`` and ``import_state_tuples`` helpers
    and always imports the state tuples for a single replica.
    """
    lookup = tf.get_collection_ref
    if self._is_training:
        self._train_op = lookup("train_op")[0]
        self._lr = lookup("lr")[0]
        self._new_lr = lookup("new_lr")[0]
        self._lr_update = lookup("lr_update")[0]
        saved_params = lookup("rnn_params")
        if self._cell and saved_params:
            params_saveable = tf.contrib.cudnn_rnn.RNNParamsSaveable(
                self._cell,
                self._cell.params_to_canonical,
                self._cell.canonical_to_params,
                saved_params,
                base_variable_scope="Model/RNN")
            tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, params_saveable)

    self._cost = lookup(with_prefix(self._name, "cost"))[0]
    single_replica = 1
    self._initial_state = import_state_tuples(
        self._initial_state, self._initial_state_name, single_replica)
    self._final_state = import_state_tuples(
        self._final_state, self._final_state_name, single_replica)
def import_ops(self, config):
    """Imports ops from collections.

    When training, restores the training handles and, if the model has a cell
    and ``rnn_params`` is non-empty, registers a ``CudnnLSTMSaveable`` built from
    ``config.num_layers`` / ``config.hidden_size``. The cost and state tuples are
    always restored.
    """
    def import_training_ops():
        self._train_op = tf.get_collection_ref("train_op")[0]
        self._lr = tf.get_collection_ref("lr")[0]
        self._new_lr = tf.get_collection_ref("new_lr")[0]
        self._lr_update = tf.get_collection_ref("lr_update")[0]
        opaque = tf.get_collection_ref("rnn_params")
        if not (self._cell and opaque):
            return
        saveable = tf.contrib.cudnn_rnn.CudnnLSTMSaveable(
            opaque,
            config.num_layers,
            config.hidden_size,
            config.hidden_size,
            scope="Model/RNN")
        tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable)

    if self._is_training:
        import_training_ops()

    self._cost = tf.get_collection_ref(util.with_prefix(self._name, "cost"))[0]
    num_replicas = FLAGS.num_gpus if self._name == "Train" else 1
    self._initial_state = util.import_state_tuples(
        self._initial_state, self._initial_state_name, num_replicas)
    self._final_state = util.import_state_tuples(
        self._final_state, self._final_state_name, num_replicas)
def main(argv):
    """Compile the network from a RETURNN config into a TF graph and export it.

    Args:
        argv: full command line, including the program name at ``argv[0]``.
    """
    argparser = argparse.ArgumentParser(description='Compile some op')
    argparser.add_argument('config', help="filename to config-file")
    argparser.add_argument('--train', type=int, default=0, help='0 disable (default), 1 enable, -1 dynamic')
    argparser.add_argument('--eval', type=int, default=0, help='calculate losses. 0 disable (default), 1 enable')
    argparser.add_argument('--search', type=int, default=0, help='beam search. 0 disable (default), 1 enable')
    argparser.add_argument("--verbosity", default=4, type=int, help="5 for all seqs (default: 4)")
    argparser.add_argument("--summaries_tensor_name", help="create Tensor for tf.summary.merge_all()")
    argparser.add_argument("--rec_step_by_step", help="make step-by-step graph for this rec layer (eg. 'output')")
    argparser.add_argument("--rec_step_by_step_output_file", help="store meta info for rec_step_by_step (JSON)")
    argparser.add_argument("--output_file", help='allowed extensions: pb, pbtxt, meta, metatxt, logdir')
    argparser.add_argument("--output_file_model_params_list", help="line-based, names of model params")
    argparser.add_argument("--output_file_state_vars_list", help="line-based, name of state vars")
    args = argparser.parse_args(argv[1:])

    # Validate CLI values with argparser.error (usage message + exit) rather than
    # assert, which is stripped under `python -O`. Check the output extension now,
    # before the expensive graph construction.
    if args.train not in (0, 1, -1) or args.eval not in (0, 1) or args.search not in (0, 1):
        argparser.error("--train must be 0, 1 or -1; --eval and --search must be 0 or 1")
    output_ext = os.path.splitext(args.output_file)[1] if args.output_file else None
    if output_ext is not None and output_ext not in (".pb", ".pbtxt", ".meta", ".metatxt", ".logdir"):
        argparser.error('filename %r extension invalid' % args.output_file)

    init(config_filename=args.config, log_verbosity=args.verbosity)
    assert 'network' in config.typed_dict
    net_dict = config.typed_dict["network"]
    if args.rec_step_by_step:
        RecStepByStepLayer.prepare_compile(rec_layer_name=args.rec_step_by_step, net_dict=net_dict)

    with tf.Graph().as_default() as graph:
        assert isinstance(graph, tf.Graph)
        print("Create graph...")
        # See :func:`Engine._init_network`.
        tf.set_random_seed(42)
        if args.train < 0:
            from TFUtil import get_global_train_flag_placeholder
            train_flag = get_global_train_flag_placeholder()
        else:
            train_flag = bool(args.train)
        eval_flag = bool(args.eval)
        search_flag = bool(args.search)
        network = create_graph(train_flag=train_flag, eval_flag=eval_flag, search_flag=search_flag, net_dict=net_dict)

        if args.rec_step_by_step:
            RecStepByStepLayer.post_compile(
                rec_layer_name=args.rec_step_by_step, network=network,
                output_file_name=args.rec_step_by_step_output_file)

        from TFNetworkLayer import LayerBase
        for layer in network.layers.values():
            assert isinstance(layer, LayerBase)
            if layer.output.time_dim_axis is None:
                continue
            with layer.cls_layer_scope(layer.name):
                tf.identity(layer.output.get_placeholder_as_batch_major(), name="output_batch_major")

        tf.group(*network.get_post_control_dependencies(), name="post_control_dependencies")

        # Do some cleanup of collections which do not contain tensors or operations,
        # because the tf.train.import_meta_graph code might fail otherwise.
        tf.get_collection_ref(CollectionKeys.RETURNN_LAYERS).clear()

        if args.summaries_tensor_name:
            summaries_tensor = tf.summary.merge_all()
            assert isinstance(summaries_tensor, tf.Tensor), "no summaries in the graph?"
            tf.identity(summaries_tensor, name=args.summaries_tensor_name)

        if output_ext in (".meta", ".metatxt"):
            # https://www.tensorflow.org/api_guides/python/meta_graph
            saver = tf.train.Saver(var_list=network.get_saveable_params_list(), max_to_keep=2 ** 31 - 1)
            graph_def = saver.export_meta_graph()
        else:
            graph_def = graph.as_graph_def(add_shapes=True)

        print("Graph collection keys:", graph.get_all_collection_keys())
        print("Graph num operations:", len(graph.get_operations()))
        print("Graph def size:", Util.human_bytes_size(graph_def.ByteSize()))

        if args.output_file:
            filename = args.output_file
            if output_ext == ".logdir":
                print("Write TF events to logdir:", filename)
                writer = tf.summary.FileWriter(logdir=filename)
                writer.add_graph(graph)
                writer.flush()
                writer.close()
            else:
                print("Write graph to file:", filename)
                graph_io.write_graph(
                    graph_def,
                    logdir=os.path.dirname(filename),
                    name=os.path.basename(filename),
                    as_text=output_ext.endswith("txt"))
        else:
            print("Use --output_file if you want to store the graph.")

        if args.output_file_model_params_list:
            print("Write model param list to:", args.output_file_model_params_list)
            with open(args.output_file_model_params_list, "w") as f:
                for param in network.get_params_list():
                    assert param.name[-2:] == ":0"
                    f.write("%s\n" % param.name[:-2])

        if args.output_file_state_vars_list:
            print("Write state var list to:", args.output_file_state_vars_list)
            with open(args.output_file_state_vars_list, "w") as f:
                for param in tf.get_collection(CollectionKeys.STATE_VARS):
                    assert param.name[-2:] == ":0"
                    f.write("%s\n" % param.name[:-2])
def create_phases(inputs):
    """Returns a list of `Phase`s describing how to execute the pipeline.

    The default graph is assumed to contain some `Analyzer`s which must be
    executed by doing a full pass over the dataset, and passing the inputs for
    that analyzer into some implementation, then taking the results and replacing
    the `Analyzer`s outputs with constants in the graph containing these results.

    The execution plan is described by a list of `Phase`s. Each phase contains
    a list of `Analyzer`s, which are the `Analyzer`s which are ready to run in
    that phase, together with a list of ops, which are the table initializers
    that are ready to run in that phase.

    An `Analyzer` or op is ready to run when all its dependencies in the graph
    have been computed. Thus if the graph is constructed by

        def preprocessing_fn(inputs):
            x = inputs['x']
            scaled_0 = x - tft.min(x)
            scaled_0_1 = scaled_0 / tft.max(scaled_0)

    Then the first phase will contain the analyzer corresponding to the call to
    `min`, because `x` is an input and so is ready to compute in the first phase,
    while the second phase will contain the analyzer corresponding to the call to
    `max` since `scaled_0` depends on the result of the call to `tft.min` which
    is computed in the first phase.

    More generally, we define a level for each op and each `Analyzer` by walking
    the graph, assigning to each operation the max level of its inputs, to each
    `Tensor` the level of its operation, unless it's the output of an `Analyzer`
    in which case we assign the level of its `Analyzer` plus one.

    The above description omits the role of `FunctionApplication`s. A
    `FunctionApplication` is a hint to create_phases about the control flow of
    the graph. Because control flow ops can introduce circular dependencies (and
    other circumstances such as mutable reference introduce similar problems) we
    allow users to construct a `FunctionApplication` which is a hint that the
    outputs `Tensor`s depend only on the input `Tensor`s. `FunctionApplication`s
    are also needed to collect table initializers to determine which phase a
    table initializer is ready to run in.

    Side effect: analyzer outputs are removed from the ASSET_FILEPATHS
    collection of the default graph (the remaining entries are kept in order).

    Args:
      inputs: A dict whose keys are strings and values are `Tensor` or
          `SparseTensor`s.

    Returns:
      A list of `Phase`s.

    Raises:
      ValueError: if the graph cannot be analyzed, i.e. some analyzers never
          become ready to run.
    """
    feed_tensors = inputs.values()

    remaining_analyzers = tf.get_collection(analyzers.ANALYZER_COLLECTION)
    analyzer_output_ready = {}
    for analyzer in remaining_analyzers:
        for tensor in analyzer.outputs:
            analyzer_output_ready[tensor] = False

    # Construct `AnalyzerInfo`s, removing any tensors that are analyzer outputs
    # from the ASSET_FILEPATHS collection. These tensors will be replaced and
    # the replacements will be added to the ASSET_FILEPATHS. Setting
    # AnalyzerOutputInfo.is_asset instructs the implementation to do this.
    asset_filepaths_collection = tf.get_collection_ref(tf.GraphKeys.ASSET_FILEPATHS)
    asset_filepaths = collections.OrderedDict(
        (tensor, True)
        for tensor in tf.get_collection(tf.GraphKeys.ASSET_FILEPATHS))

    phases = []
    while remaining_analyzers:
        analyzer_inputs = []
        for analyzer in remaining_analyzers:
            analyzer_inputs.extend(analyzer.inputs)
        ready_init_ops, ready_analyzer_inputs = (
            graph_tools.determine_ready_tensors_and_table_initializers(
                tf.get_default_graph(), analyzer_inputs, feed_tensors,
                analyzer_output_ready))
        ready_analyzer_inputs = set(ready_analyzer_inputs)

        new_remaining_analyzers = []
        analyzer_infos = []
        for analyzer in remaining_analyzers:
            if all(tensor in ready_analyzer_inputs for tensor in analyzer.inputs):
                input_tensor_names = [tensor.name for tensor in analyzer.inputs]
                output_infos = [
                    AnalyzerOutputInfo(tensor.name, asset_filepaths.pop(tensor, False))
                    for tensor in analyzer.outputs
                ]
                analyzer_infos.append(AnalyzerInfo(
                    analyzer.name, input_tensor_names, analyzer.spec, output_infos))

                for tensor in analyzer.outputs:
                    analyzer_output_ready[tensor] = True
            else:
                new_remaining_analyzers.append(analyzer)

        # Without progress the loop would never terminate. This must be a real
        # exception (not an assert, which `python -O` strips).
        if len(new_remaining_analyzers) >= len(remaining_analyzers):
            raise ValueError(
                'Cannot analyze graph: the inputs of analyzers %s never become '
                'ready.' % [analyzer.name for analyzer in remaining_analyzers])

        phases.append(Phase(analyzer_infos, ready_init_ops))
        remaining_analyzers = new_remaining_analyzers

    del asset_filepaths_collection[:]
    asset_filepaths_collection.extend(six.iterkeys(asset_filepaths))

    return phases
def loadEmbedding(self, sess):
    """Initialize embeddings with pre-trained word vectors.

    Always removes the input and output embedding variables from the
    ``TRAINABLE_VARIABLES`` collection. Only for a fresh model
    (``self.globStep == 0``), it then:

    1. Reads vectors from ``<rootDir>/data/embeddings/<embeddingSource>``.
    2. Assigns them to both embeddings.

    Words absent from the file keep a uniform(-0.25, 0.25) random
    initialization. When ``self.embeddingSize`` is smaller than the file's
    vector size, the matrix is reduced with a truncated SVD.

    Supported formats, chosen by file extension, are word2vec binary (``bin``)
    and text (``vec``). Both begin with a ``"<vocab_size> <vector_size>"``
    header line.

    Raises:
        Exception: if the file extension is neither ``bin`` nor ``vec``.
        EOFError: if the file ends before all header-announced words are read.
    """
    # Fetch embedding variables from model
    with tf.variable_scope("embedding_rnn_seq2seq/rnn/embedding_wrapper", reuse=True):
        em_in = tf.get_variable("embedding")
    with tf.variable_scope("embedding_rnn_seq2seq/embedding_rnn_decoder", reuse=True):
        em_out = tf.get_variable("embedding")

    # Disable training for embeddings
    variables = tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)
    variables.remove(em_in)
    variables.remove(em_out)

    # If restoring a model, we can leave here
    if self.globStep != 0:
        return

    # New model, we load the pre-trained word2vec data and initialize embeddings
    embeddings_path = os.path.join(self.rootDir, 'data', 'embeddings', self.embeddingSource)
    embeddings_format = os.path.splitext(embeddings_path)[1][1:]
    print("Loading pre-trained word embeddings from %s " % embeddings_path)
    if embeddings_format not in ('bin', 'vec'):
        raise Exception("Unkown format for embeddings: %s " % embeddings_format)

    def read_word(stream):
        """Read bytes up to the next space and decode them; newlines are skipped."""
        chars = []
        while True:
            ch = stream.read(1)
            if ch == b' ':
                return b''.join(chars).decode('utf-8')
            if ch == b'':
                # read(1) returns b'' forever at EOF; fail instead of spinning.
                raise EOFError("Unexpected end of embeddings file %s" % embeddings_path)
            if ch != b'\n':
                chars.append(ch)

    with open(embeddings_path, "rb") as f:
        header = f.readline()
        vocab_size, vector_size = map(int, header.split())
        binary_len = np.dtype('float32').itemsize * vector_size
        initW = np.random.uniform(-0.25, 0.25, (len(self.textData.word2id), vector_size))
        for _ in tqdm(range(vocab_size)):
            word = read_word(f)
            if word in self.textData.word2id:
                if embeddings_format == 'bin':
                    vector = np.frombuffer(f.read(binary_len), dtype='float32')
                else:
                    vector = np.fromstring(f.readline(), sep=' ', dtype='float32')
                initW[self.textData.word2id[word]] = vector
            elif embeddings_format == 'bin':
                f.read(binary_len)  # skip the vector of an unused word
            else:
                f.readline()

    # Reduce dimensionality with a truncated SVD: keep the top `embeddingSize`
    # left singular vectors scaled by their (real) singular values.
    if self.embeddingSize < vector_size:
        U, s, _ = np.linalg.svd(initW, full_matrices=False)
        initW = U[:, :self.embeddingSize] * s[:self.embeddingSize]

    # Initialize input and output embeddings
    sess.run(em_in.assign(initW))
    sess.run(em_out.assign(initW))
def create_module_spec(module_fn, tags_and_args=None, drop_collections=None):
    """Creates a ModuleSpec from a function that builds the module's graph.

    The `module_fn` is called on a new graph (not the current one) to build the
    graph of the module and define its signatures via `hub.add_signature()`.
    Example:

    ```python
    # Define a text embedding module.
    def my_text_module_fn():
      text_input = tf.placeholder(dtype=tf.string, shape=[None])
      embeddings = compute_embedding(text_input)
      hub.add_signature(inputs=text_input, outputs=embeddings)
    ```

    See `add_signature()` for documentation on adding multiple input/output
    signatures.

    NOTE: In anticipation of future TF-versions, `module_fn` is called on a graph
    that uses resource variables by default. If you want old-style variables then
    you can use `with tf.variable_scope("", use_resource=False)` in `module_fn`.

    Multiple graph variants can be defined by using the `tags_and_args` argument.
    For example, the code:

    ```python
    hub.create_module_spec(
        module_fn,
        tags_and_args=[({"train"}, {"is_training":True}),
                       (set(), {"is_training":False})])
    ```

    calls `module_fn` twice, once as `module_fn(is_training=True)` and once as
    `module_fn(is_training=False)` to define the respective graph variants:
    for training with tags {"train"} and for inference with the empty set of tags.
    Using the empty set aligns the inference case with the default in
    Module.__init__().

    Args:
      module_fn: a function to build a graph for the Module.
      tags_and_args: Optional list of tuples (tags, kwargs) of tags and keyword
        args used to define graph variants. If omitted, it is interpreted as
        [(set(), {})], meaning `module_fn` is called once with no args.
      drop_collections: list of collections to drop. Each listed collection is
        emptied in every built graph before it is validated and copied.

    Returns:
      A ModuleSpec.

    Raises:
      ValueError: if it fails to construct the ModuleSpec due to bad or
        unsupported values in the arguments or in the graphs constructed by
        `module_fn`.
    """
    drop_collections = drop_collections or []
    # Tags only appear in colocation error messages when the caller chose them.
    report_tags = bool(tags_and_args)
    if not tags_and_args:
        tags_and_args = [(set(), {})]

    saved_model_handler = saved_model_lib.SavedModelHandler()
    for tags, args in tags_and_args:
        with tf.Graph().as_default() as graph:
            with tf.variable_scope("", use_resource=True):
                module_fn(**args)

            for collection_key in drop_collections:
                del tf.get_collection_ref(collection_key)[:]

        err = find_state_op_colocation_error(graph, tags if report_tags else None)
        if err:
            raise ValueError(err)
        saved_model_handler.add_graph_copy(graph, tags=tags)

    return _ModuleSpec(saved_model_handler, checkpoint_variables_path=None)
def clear_tf_collections(self):
    """Empty, in place, every graph collection named in ``self.tf_collections``."""
    for name in self.tf_collections:
        del tf.get_collection_ref(name)[:]
def model_fn(model,
             features,
             mode,
             hparams,
             problem_names,
             train_steps=100000,
             worker_id=0,
             worker_replicas=1,
             eval_run_autoregressive=False,
             decode_hparams=None):
    """Builds the model for all modes.

    * TRAIN: Constructs loss and train_op
    * EVAL: Constructs the loss and eval metrics
    * PREDICT: Constructs the predictions

    Args:
      model: str, name of model.
      features: dict<feature name, Tensor>. Expected to have keys
        {inputs, targets, problem_choice}.
      mode: tf.estimator.ModeKeys.
      hparams: model HParams.
      problem_names: list of str, names of the problems.
      train_steps: int, total number of training steps. Used to compute learning
        rate decay.
      worker_id: int, id of this worker.
      worker_replicas: int, number of workers.
      eval_run_autoregressive: bool, whether to run evaluation autoregressively.
      decode_hparams: HParams for decode settings. Used when mode == PREDICT.

    Returns:
      tf.estimator.EstimatorSpec
    """
    assert len(problem_names) == len(hparams.problem_instances)
    decode_hp = decode_hparams

    # TODO(rsepassi): This still depends on FLAGS. Rm eventually.
    dp = devices.data_parallelism(hparams)

    tf.get_variable_scope().set_initializer(_get_variable_initializer(hparams))
    is_training = mode == tf.estimator.ModeKeys.TRAIN

    # Non-padding target token count of this batch. It stays None when the
    # features have no "targets" tensor of known rank > 1; TRAIN mode then skips
    # the small-batch reweighting instead of hitting an unbound local.
    targets_nonpadding_tokens = None

    # Add input statistics for incoming features.
    with tf.name_scope("input_stats"):
        for (k, v) in six.iteritems(features):
            if not isinstance(v, tf.Tensor):
                continue
            ndims = v.get_shape().ndims
            # ndims is None for unknown rank; treat that like a scalar.
            if ndims is not None and ndims > 1:
                tf.summary.scalar("%s_batch" % k, tf.shape(v)[0] // dp.n)
                tf.summary.scalar("%s_length" % k, tf.shape(v)[1])
                nonpadding = tf.to_float(tf.not_equal(v, 0))
                nonpadding_tokens = tf.reduce_sum(nonpadding)
                if k == "targets":
                    targets_nonpadding_tokens = nonpadding_tokens
                tf.summary.scalar("%s_nonpadding_tokens" % k, nonpadding_tokens)
                tf.summary.scalar("%s_nonpadding_fraction" % k,
                                  tf.reduce_mean(nonpadding))

    # Get multi-problem logits and loss based on features["problem_choice"].
    loss_variable_names = []

    def nth_model(n):
        """Build the model for the n-th problem, plus some added variables."""
        model_class = registry.model(model)(
            hparams,
            mode,
            hparams.problems[n],
            n,
            dp,
            devices.ps_devices(all_workers=True),
            decode_hparams=decode_hparams)
        if mode == tf.estimator.ModeKeys.PREDICT:
            return model_class.infer(
                features,
                beam_size=decode_hp.beam_size,
                top_beams=(decode_hp.beam_size if decode_hp.return_beams else 1),
                alpha=decode_hp.alpha,
                decode_length=decode_hp.extra_length)
        # In distributed mode, we build graph for problem=0 and problem=worker_id.
        skipping_is_on = hparams.problem_choice == "distributed" and is_training
        problem_worker_id = worker_id % len(hparams.problems)
        skip_this_one = n != 0 and n % worker_replicas != problem_worker_id
        # On worker 0 also build graph for problems <= 1.
        # TODO(lukaszkaiser): why is this hack needed for variables init? Repair.
        skip_this_one = skip_this_one and (worker_id != 0 or n > 1)
        if eval_run_autoregressive and mode == tf.estimator.ModeKeys.EVAL:
            logits, losses_dict = model_class.eval_autoregressive(features)
        else:
            logits, losses_dict = model_class(
                features, skip=(skipping_is_on and skip_this_one))
        with tf.variable_scope("losses_avg"):
            total_loss, ops = 0.0, []
            for loss_key, loss_value in six.iteritems(losses_dict):
                loss_name = "problem_%d/%s_loss" % (n, loss_key)
                loss_moving_avg = tf.get_variable(
                    loss_name, initializer=100.0, trainable=False)
                loss_variable_names.append(loss_name)
                ops.append(
                    loss_moving_avg.assign(loss_moving_avg * 0.9 + loss_value * 0.1))
                total_loss += loss_value
            try:  # Total loss avg might be reused or not, we try both.
                with tf.variable_scope(tf.get_variable_scope(), reuse=True):
                    # Total loss was already constructed on input.
                    loss_moving_avg = tf.get_variable("problem_%d/total_loss" % n)
            except ValueError:
                loss_moving_avg = tf.get_variable(
                    "problem_%d/total_loss" % n, initializer=100.0, trainable=False)
            ops.append(
                loss_moving_avg.assign(loss_moving_avg * 0.9 + total_loss * 0.1))
        with tf.variable_scope("train_stats"):  # Count steps for this problem.
            problem_steps = tf.get_variable(
                "problem_%d_steps" % n, initializer=0, trainable=False)
            ops.append(problem_steps.assign_add(1))
        with tf.control_dependencies(ops):  # Make sure the ops run.
            # Ensure the loss is a scalar here.
            total_loss = tf.reshape(total_loss, [], name="total_loss_control_id")
        return [total_loss, logits]

    model_output = input_fn_builder.cond_on_index(
        nth_model,
        index_tensor=features["problem_choice"],
        max_idx=len(hparams.problems) - 1)

    if mode == tf.estimator.ModeKeys.PREDICT:
        # If beam searching, model_output will be a dict with keys "outputs" and
        # "scores".
        if isinstance(model_output, dict):
            outputs = model_output["outputs"]
            scores = model_output["scores"]
        else:
            outputs = model_output
            scores = None

        batched_problem_choice = (
            features["problem_choice"] * tf.ones(
                (tf.shape(features["inputs"])[0],), dtype=tf.int32))
        predictions = {
            "outputs": outputs,
            "scores": scores,
            "inputs": features.get("inputs", None),
            "targets": features.get("infer_targets", None),
            "problem_choice": batched_problem_choice,
        }
        _del_dict_nones(predictions)

        export_out = {"outputs": predictions["outputs"]}
        if "scores" in predictions:
            export_out["scores"] = predictions["scores"]

        return tf.estimator.EstimatorSpec(
            mode,
            predictions=predictions,
            export_outputs={
                "output": tf.estimator.export.PredictOutput(export_out)
            })

    total_loss, logits = model_output

    if mode == tf.estimator.ModeKeys.EVAL:
        eval_metrics_fns = metrics.create_evaluation_metrics(
            hparams.problem_instances, hparams)

        eval_metrics = {}
        for metric_name, metric_fn in six.iteritems(eval_metrics_fns):
            eval_metrics[metric_name] = metric_fn(logits, features)

        return tf.estimator.EstimatorSpec(
            mode,
            predictions={"predictions": logits},
            eval_metric_ops=eval_metrics,
            loss=total_loss)

    assert mode == tf.estimator.ModeKeys.TRAIN

    # Set learning rate
    learning_rate = hparams.learning_rate * optimize.learning_rate_decay(
        hparams, num_worker_replicas=worker_replicas, num_train_steps=train_steps)
    learning_rate /= math.sqrt(float(worker_replicas))

    # Get global step
    global_step = tf.train.get_or_create_global_step()

    # Some training statistics.
    with tf.name_scope("training_stats"):
        tf.summary.scalar("learning_rate", learning_rate)
        for n in xrange(len(hparams.problems)):
            names_and_vars = []
            with tf.variable_scope("losses_avg", reuse=True):
                total_loss_var = tf.get_variable("problem_%d/total_loss" % n)
                names_and_vars.append(("total_loss", total_loss_var))
            with tf.variable_scope("losses_avg", reuse=True):
                for loss_name in loss_variable_names:
                    if loss_name.startswith("problem_%d/" % n):
                        loss_var = tf.get_variable(loss_name)
                        loss_suffix = loss_name[loss_name.index("/") + 1:]
                        names_and_vars.append((loss_suffix, loss_var))
            for (loss_name, loss_var) in names_and_vars:
                tf.summary.scalar("loss_avg_%d/%s" % (n, loss_name), loss_var)
            with tf.variable_scope("train_stats", reuse=True):
                nth_steps = tf.get_variable("problem_%d_steps" % n, dtype=tf.int32)
            tf.summary.scalar("problem_%d_frequency" % n,
                              tf.to_float(nth_steps) /
                              (tf.to_float(global_step) + 1.0))

    # Add weight decay and noise.
    weight_decay_loss = 0.0
    all_weights = {v.name: v for v in tf.trainable_variables()}
    for v_name in sorted(all_weights):
        v = all_weights[v_name]
        v_size = int(np.prod(np.array(v.shape.as_list())))
        if hparams.weight_decay > 0.0 and len(v.shape.as_list()) > 1:
            # Add weight regularization if set and the weight is not a bias (dim>1).
            with tf.device(v._ref().device):  # pylint: disable=protected-access
                v_loss = tf.nn.l2_loss(v) / v_size
            weight_decay_loss += v_loss
        is_body = len(v_name) > 5 and v_name[:5] == "body/"
        if hparams.weight_noise > 0.0 and is_body:
            # Add weight noise if set in hparams.
            with tf.device(v._ref().device):  # pylint: disable=protected-access
                scale = learning_rate * 0.001
                noise = tf.truncated_normal(v.shape) * hparams.weight_noise * scale
                noise_op = v.assign_add(noise)
            with tf.control_dependencies([noise_op]):
                total_loss = tf.identity(total_loss)
    if hparams.weight_decay > 0.0:
        total_loss += weight_decay_loss * hparams.weight_decay

    # The new data reader occasionally emits very small batches, which
    # cause the examples in those batches to be grossly overweighted.
    # We decrease the loss proportionally to the ratio of the size of this
    # batch to the size of the largest training batch ever.
    # TODO(noam): to be more sophisticated, we could keep separate
    # maxima based on problem choice.
    if targets_nonpadding_tokens is not None:
        max_nonpadding_var = tf.get_variable(
            "max_nonpadding", shape=[],
            initializer=tf.ones_initializer(),
            trainable=False)
        max_nonpadding = tf.maximum(max_nonpadding_var, targets_nonpadding_tokens)
        with tf.control_dependencies([tf.assign(max_nonpadding_var, max_nonpadding)]):
            small_batch_multiplier = targets_nonpadding_tokens / max_nonpadding
        tf.summary.scalar("small_batch_multiplier", small_batch_multiplier)
        total_loss *= small_batch_multiplier

    # Log variable sizes
    _log_variable_sizes(tf.trainable_variables(), "Trainable Variables")
    diet_vars = [
        v for v in tf.global_variables() if v.dtype == dtypes.float16_ref
    ]
    _log_variable_sizes(diet_vars, "Diet Variables")

    # Optimize
    train_op = optimize.optimize(total_loss, learning_rate, hparams)

    # Remove summaries that will fail to run because they are in conditionals.
    # Slice assignment keeps the collection's list object (edits it in place).
    # TODO(cwhipkey): Test with this code removed, later in 2017.
    summaries = tf.get_collection_ref(tf.GraphKeys.SUMMARIES)
    summaries[:] = [s for s in summaries if not s.name.startswith("cond_")]

    tf.logging.info("Global model_fn finished.")
    return tf.estimator.EstimatorSpec(
        mode,
        predictions={"problem_choice": features["problem_choice"]},
        loss=total_loss,
        train_op=train_op)
def inference(reader, train_dir, data_pattern, out_file_location, batch_size,
              top_k):
    """Run a saved inference model over a dataset and write its predictions.

    Loads the meta-graph and weights stored at `<train_dir>/inference_model`,
    feeds every batch returned by `get_input_data_tensors` through the
    restored "predictions" tensor, and writes the lines produced by
    `format_lines` under a `VideoId,LabelConfidencePairs` header. When
    `FLAGS.output_model_tgz` is set, the checkpoint files and
    `model_flags.json` are first packed into that gzipped tarball.

    Args:
        reader: Reader handed to `get_input_data_tensors`.
        train_dir: Directory holding the `inference_model` checkpoint (and
            `model_flags.json` when exporting a tarball).
        data_pattern: File pattern of the input data.
        out_file_location: Path of the output file (opened with mode "w+").
        batch_size: Batch size passed to `get_input_data_tensors`.
        top_k: Number of predictions per video passed to `format_lines`.

    Raises:
        IOError: if `<train_dir>/inference_model.meta` does not exist.
    """
    session_config = tf.ConfigProto(allow_soft_placement=True)
    with tf.Session(config=session_config) as sess, \
            gfile.Open(out_file_location, "w+") as out_file:
        video_id_batch, video_batch, num_frames_batch = get_input_data_tensors(
            reader, data_pattern, batch_size)

        # Use the train_dir argument; the previous code read FLAGS.train_dir
        # here and silently ignored the parameter.
        checkpoint_file = os.path.join(train_dir, "inference_model")
        meta_graph_location = checkpoint_file + ".meta"
        if not gfile.Exists(meta_graph_location):
            raise IOError("Cannot find %s. Did you run eval.py?" % checkpoint_file)
        logging.info("loading meta-graph: " + meta_graph_location)

        if FLAGS.output_model_tgz:
            with tarfile.open(FLAGS.output_model_tgz, "w:gz") as tar:
                for model_file in glob.glob(checkpoint_file + '.*'):
                    tar.add(model_file, arcname=os.path.basename(model_file))
                tar.add(os.path.join(train_dir, "model_flags.json"),
                        arcname="model_flags.json")
            print('Tarred model onto ' + FLAGS.output_model_tgz)

        with tf.device("/cpu:0"):
            saver = tf.train.import_meta_graph(meta_graph_location,
                                               clear_devices=True)
        logging.info("restoring variables from " + checkpoint_file)
        saver.restore(sess, checkpoint_file)
        input_tensor = tf.get_collection("input_batch_raw")[0]
        num_frames_tensor = tf.get_collection("num_frames")[0]
        predictions_tensor = tf.get_collection("predictions")[0]

        # Workaround for num_epochs issue.
        def set_up_init_ops(variables):
            """Return init ops: "train_input" variables are assigned 1, the
            rest go through one `tf.variables_initializer`.

            `variables` is only read, never modified, so passing a graph
            collection reference does not strip entries from the graph.
            """
            init_op_list = []
            remaining = []
            for variable in variables:
                if "train_input" in variable.name:
                    init_op_list.append(tf.assign(variable, 1))
                else:
                    remaining.append(variable)
            init_op_list.append(tf.variables_initializer(remaining))
            return init_op_list

        sess.run(set_up_init_ops(tf.get_collection_ref(
            tf.GraphKeys.LOCAL_VARIABLES)))

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        num_examples_processed = 0
        start_time = time.time()
        out_file.write("VideoId,LabelConfidencePairs\n")

        try:
            while not coord.should_stop():
                video_id_batch_val, video_batch_val, num_frames_batch_val = sess.run(
                    [video_id_batch, video_batch, num_frames_batch])
                predictions_val, = sess.run(
                    [predictions_tensor],
                    feed_dict={input_tensor: video_batch_val,
                               num_frames_tensor: num_frames_batch_val})
                now = time.time()
                num_examples_processed += len(video_batch_val)
                logging.info("num examples processed: %d elapsed seconds: %.2f",
                             num_examples_processed, now - start_time)
                for line in format_lines(video_id_batch_val, predictions_val, top_k):
                    out_file.write(line)
                out_file.flush()
        except tf.errors.OutOfRangeError:
            # The input queue is exhausted: normal end of inference.
            logging.info('Done with inference. The output file was written to '
                         + out_file_location)
        finally:
            coord.request_stop()

        coord.join(threads)
        # The enclosing `with` closes the session; no explicit close needed.
def clear_tf_collections(self):
    """Clear the parent's TF collections, then empty every collection named
    in `self.tf_list_collections`.

    Each collection list is emptied in place via slice deletion. This works
    on both Python 2 and 3; `list.clear()` does not exist on Python 2. It also
    matches how the rest of this file empties collection references.
    """
    super(Advanced, self).clear_tf_collections()
    for key in self.tf_list_collections:
        del tf.get_collection_ref(key)[:]
def custom_getter(getter, *args, **kwargs):
    """Delegate to `getter` and record its result in a graph collection.

    The result is appended to the collection named by
    `self.graph_collection_name` (`self` comes from the enclosing scope,
    not visible here) unless that very object is already in it.

    Returns:
        Whatever `getter(*args, **kwargs)` returned.
    """
    out = getter(*args, **kwargs)
    ref = tf.get_collection_ref(self.graph_collection_name)
    # Membership by identity. `out not in ref` falls back to `==`, which for
    # tf.Variable/tf.Tensor is element-wise when tensor equality is enabled
    # (the TF2 default), and the result cannot be used as a Python bool.
    if not any(existing is out for existing in ref):
        ref.append(out)
    return out
def loadEmbedding(self, sess):
    """Initialize embeddings with pre-trained word2vec vectors.

    Looks up the encoder and decoder "embedding" variables and removes both
    from the TRAINABLE_VARIABLES collection so they are not trained. For a
    new model only (``self.globStep == 0``), both are then filled with
    vectors read from ``<rootDir>/data/embeddings/<embeddingSource>``. Known
    words missing from the file keep a uniform(-0.25, 0.25) random vector.
    When ``self.args.embeddingSize`` is smaller than the file's vector size,
    the matrix is reduced with a truncated SVD.

    Args:
        sess: Session used to run the embedding assignment ops.

    Raises:
        Exception: if the embeddings file extension is neither ``bin`` nor
            ``vec``.
        EOFError: if the file ends before a word's terminating space.
    """
    # Fetch embedding variables from model
    with tf.variable_scope("embedding_rnn_seq2seq/rnn/embedding_wrapper", reuse=True):
        em_in = tf.get_variable("embedding")
    with tf.variable_scope("embedding_rnn_seq2seq/embedding_rnn_decoder", reuse=True):
        em_out = tf.get_variable("embedding")

    # Disable training for embeddings
    variables = tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)
    variables.remove(em_in)
    variables.remove(em_out)

    # If restoring a model, we can leave here
    if self.globStep != 0:
        return

    # New model, we load the pre-trained word2vec data and initialize embeddings
    embeddings_path = os.path.join(self.args.rootDir, 'data', 'embeddings', self.args.embeddingSource)
    embeddings_format = os.path.splitext(embeddings_path)[1][1:]
    # Validate the format once up front instead of on every word.
    if embeddings_format not in ('bin', 'vec'):
        raise Exception("Unkown format for embeddings: %s " % embeddings_format)
    is_binary = embeddings_format == 'bin'

    print("Loading pre-trained word embeddings from %s " % embeddings_path)
    with open(embeddings_path, "rb") as f:
        header = f.readline()
        vocab_size, vector_size = map(int, header.split())
        binary_len = np.dtype('float32').itemsize * vector_size
        initW = np.random.uniform(-0.25, 0.25, (len(self.textData.word2id), vector_size))
        for _ in tqdm(range(vocab_size)):
            # Read the word byte by byte up to the separating space; stray
            # newlines (e.g. after a binary vector) are skipped.
            word = []
            while True:
                ch = f.read(1)
                if ch == b' ':
                    word = b''.join(word).decode('utf-8')
                    break
                if not ch:
                    # Without this a truncated file would loop forever.
                    raise EOFError("Unexpected end of embeddings file: %s" % embeddings_path)
                if ch != b'\n':
                    word.append(ch)

            if word in self.textData.word2id:
                if is_binary:
                    vector = np.frombuffer(f.read(binary_len), dtype='float32')
                else:
                    vector = np.fromstring(f.readline(), sep=' ', dtype='float32')
                initW[self.textData.word2id[word]] = vector
            elif is_binary:
                f.read(binary_len)  # skip the unused vector
            else:
                f.readline()  # skip the unused vector

    # Truncated SVD to reduce word2vec dimensionality
    if self.args.embeddingSize < vector_size:
        U, s, _ = np.linalg.svd(initW, full_matrices=False)
        k = self.args.embeddingSize
        # Same as U[:, :k] @ diag(s[:k]), computed as a real-valued column
        # scaling. The previous version built the diagonal matrix with
        # dtype=complex, so initW became complex before being assigned to the
        # real-valued embedding variables.
        initW = U[:, :k] * s[:k]

    # Initialize input and output embeddings
    sess.run(em_in.assign(initW))
    sess.run(em_out.assign(initW))
def create_selection_weights(name,
                             type_,
                             shape,
                             inv_t=1,
                             initializer=tf.zeros_initializer(),
                             regularizer=None,
                             names=None):
    """Create a SelectionWeights tuple.

    Args:
        name: Name for the underlying variable containing the unnormalized
            weights.
        type_: "softmax" or "sigmoid" or ("softmax_topk", k) where k is an int.
            For "softmax_topk", k indices are sampled (with replacement) using
            the untempered variable as logits. The softmax over the tempered
            variable is then masked so unsampled entries get ~0 weight.
        shape: Shape for the variable. Must have rank 1 for "softmax_topk".
        inv_t: Inverse of the temperature to use in normalization. May be a
            callable, in which case it is called with the variable.
        initializer: Initializer for the variable, passed to `tf.get_variable`.
        regularizer: Regularizer for the variable. A callable which accepts
            `tempered_var` and `normalized`. A non-None return value is added
            to the REGULARIZATION_LOSSES collection.
        names: Name of each selection. Extended (flattened if an ndarray) into
            the "selection_weight_names/<var name>" collection. Also added as
            a constant to "selection_weight_names_tensor/<var name>".

    Returns:
        The created SelectionWeights tuple.

    Raises:
        ValueError: if type_ is not in the supported range, or if
            "softmax_topk" is requested for a shape that is not rank 1.
    """
    var = tf.get_variable(name, shape, initializer=initializer)

    if callable(inv_t):
        inv_t = inv_t(var)
    if inv_t == 1:
        tempered_var = var
    else:
        tempered_var = var * inv_t

    if type_ == "softmax":
        weights = tf.nn.softmax(tempered_var)
    elif type_ == "sigmoid":
        weights = tf.nn.sigmoid(tempered_var)
    elif (isinstance(type_, (list, tuple)) and len(type_) == 2 and
          type_[0] == "softmax_topk"):
        if len(shape) != 1:
            raise ValueError("softmax_topk requires a rank-1 shape, got %s" % (shape,))
        k = type_[1]
        # TODO(rshin): Change this to select without replacement?
        # Draw the requested k samples (previously hard-coded to 4).
        selection = tf.multinomial(tf.expand_dims(var, axis=0), k)
        selection = tf.squeeze(selection, axis=0)  # [k] selected classes.
        to_run = tf.one_hot(selection, shape[0])  # [k x nmodules] one-hot.
        # [nmodules], 0=not run, 1=run.
        to_run = tf.minimum(tf.reduce_sum(to_run, axis=0), 1)
        # Push unselected logits to -1e9 so they get ~0 softmax weight.
        weights = tf.nn.softmax(tempered_var - 1e9 * (1.0 - to_run))
    else:
        # Wrap in a 1-tuple: a tuple `type_` would otherwise be unpacked by
        # `%` and raise TypeError instead of the documented ValueError.
        raise ValueError("Unknown type: %s" % (type_,))

    if regularizer is not None:
        loss = regularizer(tempered_var, weights)
        if loss is not None:
            tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, loss)

    if names is not None:
        tf.get_collection_ref("selection_weight_names/" + var.name).extend(
            names.flatten() if isinstance(names, np.ndarray) else names)
        tf.add_to_collection("selection_weight_names_tensor/" + var.name,
                             tf.constant(names))

    return SelectionWeights(
        var=var, tempered_var=tempered_var, inv_t=inv_t, normalized=weights)
def _clear_trainable_variables():
    """Empty the TRAINABLE_VARIABLES collection of the default graph.

    The collection's list object is emptied in place rather than replaced,
    so every holder of the same collection reference observes the change.
    """
    trainable = tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)
    trainable[:] = []
def __init__(self, train_ops, graph=None, clip_gradients=5.0,
             tensorboard_dir="/tmp/tflearn_logs/",
             tensorboard_verbose=0, checkpoint_path=None,
             best_checkpoint_path=None, max_checkpoints=None,
             keep_checkpoint_every_n_hours=10000.0, random_seed=None,
             session=None, best_val_accuracy=0.0):
    """Build training state, session, savers and (for new sessions) init vars.

    Args:
        train_ops: A train op or list of train ops. Each one gets
            `initialize_training_ops(i, session, tensorboard_verbose,
            clip_gradients)` called on it.
        graph: Graph to build in; defaults to the current default graph.
        clip_gradients: Forwarded to each train op.
        tensorboard_dir: Summary directory, passed through `check_dir_name`.
        tensorboard_verbose: Forwarded to each train op.
        checkpoint_path: Stored as `self.checkpoint_path`.
        best_checkpoint_path: If set, an extra `self.val_saver` keeping a
            single checkpoint is created.
        max_checkpoints: `max_to_keep` for `self.saver` and the restorers.
        keep_checkpoint_every_n_hours: Forwarded to every `tf.train.Saver`.
        random_seed: Graph-level random seed. It is applied whenever it is
            not None, so a seed of 0 is honoured.
        session: Existing session to reuse. When given, the trainer is
            marked as restored and variables are not initialized here.
        best_val_accuracy: Initial value of `self.best_val_accuracy`.
    """
    self.graph = tf.get_default_graph()
    self.summ_writer = None
    if graph:
        self.graph = graph

    with self.graph.as_default():

        init_training_mode()

        train_ops = to_list(train_ops)
        duplicate_identical_ops(train_ops)

        # `is not None` so that random_seed=0 still seeds the graph.
        if random_seed is not None:
            tf.set_random_seed(random_seed)
        self.restored = False
        self.tensorboard_dir = check_dir_name(tensorboard_dir)
        self.training_state = TrainingState()

        self.train_ops = to_list(train_ops)
        self.validate_trainop_names()

        self.global_step = tf.Variable(0., name='Global_Step',
                                       trainable=False)
        self.incr_global_step = tf.assign(self.global_step,
                                          tf.add(self.global_step, 1))

        self.best_val_accuracy = best_val_accuracy
        self.best_checkpoint_path = best_checkpoint_path

        # Session config, if one was registered in the GRAPH_CONFIG collection.
        config = None
        tflearn_conf = tf.get_collection(tf.GraphKeys.GRAPH_CONFIG)
        if tflearn_conf:
            config = tflearn_conf[0]

        if not session:
            self.session = tf.Session(config=config)
        else:
            self.session = session
            self.restored = True

        self.coord = tf.train.Coordinator()

        # For display simplicity in Tensorboard, if only one optimizer,
        # we don't display its name.
        single_optimizer = len(train_ops) == 1
        for i, train_op in enumerate(self.train_ops):
            if single_optimizer:
                train_op.scope_name = ""
            train_op.initialize_training_ops(i, self.session,
                                             tensorboard_verbose,
                                             clip_gradients)

        # Saver for saving a model
        self.saver = tf.train.Saver(
            max_to_keep=max_checkpoints,
            keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
            allow_empty=True)
        # Saver for saving a best validation accuracy model
        if self.best_checkpoint_path:
            self.val_saver = tf.train.Saver(
                max_to_keep=1,
                keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
                allow_empty=True)
        # Saver for restoring a model: all variables that pass
        # check_restore_tensor against the EXCL_RESTORE_VARS collection.
        all_vars = variables.get_all_variables()
        excl_vars = tf.get_collection(tf.GraphKeys.EXCL_RESTORE_VARS)
        to_restore = [item for item in all_vars
                      if check_restore_tensor(item, excl_vars)]
        self.restorer = tf.train.Saver(
            var_list=to_restore,
            max_to_keep=max_checkpoints,
            keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
            allow_empty=True)
        # A second Saver, that only restores trainable variables
        to_restore_trainvars = [item for item in tf.trainable_variables()
                                if check_restore_tensor(item, excl_vars)]
        self.restorer_trainvars = tf.train.Saver(
            var_list=to_restore_trainvars,
            max_to_keep=max_checkpoints,
            keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
            allow_empty=True)

        self.to_restore = to_restore
        self.to_restore_trainvars = to_restore_trainvars
        self.checkpoint_path = checkpoint_path

        if not self.restored:
            # TF 0.12 fix: fall back to the legacy initializer API if the
            # newer calls fail.
            try:
                init = tf.group(tf.global_variables_initializer(),
                                tf.local_variables_initializer())
                self.session.run(tf.variables_initializer(
                    tf.get_collection_ref('is_training')))
            except Exception:
                init = tf.initialize_all_variables()
            self.session.run(init)