def loadNetwork(path, sess, model_name):
    """Build the network under `model_name`, restore it and return a predictor.

    The graph is built in the session's graph inside
    `tf.variable_scope(model_name)`. Only the variables created under that scope
    are initialized and then restored from `<path>/snapshot`. Checkpoint names
    are looked up with the leading `<model_name>/` scope replaced by
    `my_model/`. Missing checkpoint variables are ignored
    (`ignore_missing_vars=True`). The graph is also written to the
    "TensorBoard" directory.

    Args:
      path: Directory containing the `snapshot` checkpoint.
      sess: Session whose graph receives the network.
      model_name: Variable scope for the network. 'my_model' builds 68 output
        channels; any other name builds 17.

    Returns:
      A function `func(imgs)` that runs the network on `imgs` (fed to a float32
      placeholder of shape (None, None, None, 3)). It returns a dict with
      'det' = the first 17 output channels and 'tag' = the last 17 channels.
    """
    img = tf.placeholder(dtype=tf.float32, shape=(None, None, None, 3))
    with tf.variable_scope(model_name):
        # Why 68: the author explained on GitHub that the model has 68 output
        # channels but only 34 are actually used (the paper mentions 34).
        # NOTE(review): the code below uses the first 17 and the *last* 17
        # channels, not the first 34 -- confirm against the model definition.
        pred = inference(img, 68 if model_name == 'my_model' else 17)

    # Only touch this network's variables. Initializing/restoring every global
    # variable would reset any network previously loaded into the same graph,
    # and its renamed keys could collide in the restore map.
    prefix = model_name + '/'
    model_vars = [v for v in tf.global_variables()
                  if v.op.name.startswith(prefix)]
    sess.run(tf.variables_initializer(model_vars))

    # Write the graph for TensorBoard; close the writer so the event file is
    # actually flushed to disk.
    writer = tf.summary.FileWriter("TensorBoard", sess.graph)
    writer.close()

    var_map = {}
    for var in model_vars:
        if 'global_step' in var.name or 'Adam' in var.name:
            continue  # training-only state, not needed for inference
        # Replace only the leading scope: str.replace would also rewrite
        # inner occurrences of "<model_name>/" (e.g. "a/conv_a/w").
        ckpt_name = 'my_model/' + str(var.op.name)[len(prefix):]
        var_map[ckpt_name] = var
    init_fn = assign_from_checkpoint_fn(os.path.join(path, 'snapshot'), var_map,
                                        ignore_missing_vars=True)
    init_fn(sess)

    def func(imgs):
        output = sess.run(pred, feed_dict={img: imgs})
        print('-------output.shape-------')
        print(output.shape)  # e.g. (2, 200, 200, 68)
        return {
            'det': output[:, :, :, :17],   # first 17 channels: detection scores
            'tag': output[:, :, :, -17:],  # last 17 channels: identity tags
        }

    return func
def post_build(self, sess, rng):
    """Restore pretrained variables from the checkpoint `ck_name` into `sess`.

    Prints the candidate variable lists for debugging, then restores every
    variable returned by `get_variables_to_restore()` except the first two.
    `rng` is unused.
    """
    # print(...) with a single argument prints identically under Python 2 and
    # is valid Python 3 (the bare print statement is a SyntaxError there).
    print("get_variables_to_restore():")
    print(get_variables_to_restore())
    print("get_model_variables():")
    print(get_model_variables())
    # NOTE(review): the first two variables are dropped by position; presumably
    # they are not present in the checkpoint -- confirm, since this depends on
    # variable creation order.
    restore = assign_from_checkpoint_fn(ck_name, get_variables_to_restore()[2:])
    restore(sess)
def convolute_and_save(module_path, signature, export_path, transform_fn,
                       transform_checkpoint_path, new_signature=None):
    """Load a TF-Hub module, append `transform_fn` to it and export the result.

    Args:
      module_path: Path from which the source TF-Hub module is constructed.
      signature: Name of the signature to call on the source module.
      export_path: Directory where the combined TF-Hub module is exported.
      transform_fn: Callable building the graph appended to the source module.
        It receives the source module's output tensors as keyword arguments and
        must return a dict of tensors, which become the new module's outputs.
      transform_checkpoint_path: Checkpoint from which the variables created by
        `transform_fn` are read.
      new_signature: Signature name of the exported module; `signature` is
        reused when this is None.
    """
    export_signature = signature if new_signature is None else new_signature
    scope_prefix = "transform/"

    def _build_module():
        # Exported graph: source module -> transform_fn.
        source = hub.Module(module_path)
        placeholders = _placeholders_from_module(source, signature=signature)
        source_outputs = source(placeholders, signature=signature, as_dict=True)
        # Scope the transform's variables so they can be separated from the
        # source module's variables when restoring.
        with tf.variable_scope("transform"):
            new_outputs = transform_fn(**source_outputs)
        hub.add_signature(name=export_signature, inputs=placeholders,
                          outputs=new_outputs)

    with tf.Graph().as_default():
        combined = hub.Module(hub.create_module_spec(_build_module),
                              trainable=True)

        # The checkpoint stores the transform variables without the scope, so
        # map "<name without 'transform/'>" -> variable.
        checkpoint_var_map = {}
        for var_name, variable in combined.variable_map.items():
            if var_name.startswith(scope_prefix):
                checkpoint_var_map[var_name[len(scope_prefix):]] = variable

        restore_fn = None
        if checkpoint_var_map:
            restore_fn = contrib_framework.assign_from_checkpoint_fn(
                transform_checkpoint_path, checkpoint_var_map)

        with tf.Session() as sess:
            # The initializer also loads the source module's parameters.
            sess.run(tf.global_variables_initializer())
            if restore_fn is not None:
                restore_fn(sess)
            combined.export(export_path, sess)
def train(data_dir, train_dir, max_steps, log_frequency):
    """Fine-tune Inception-v3 starting from a pretrained checkpoint with slim.

    The Logits/AuxLogits layers and `global_step` are not restored from the
    checkpoint. Logits/AuxLogits are trained from scratch for NUM_CLASS
    classes; the step counter starts at 0.

    Args:
      data_dir: Directory passed to `load_data.read_tfrecord`.
      train_dir: Directory where slim writes checkpoints and summaries.
      max_steps: Total number of training steps (`number_of_steps`).
      log_frequency: Used both as the summary interval in seconds
        (`save_summaries_secs`) and as the logging interval in steps
        (`log_every_n_steps`).
    """
    with tf.Graph().as_default():
        # Set up the data loading:
        images, labels = load_data.read_tfrecord(
            data_dir, image_size=IMAGE_SIZE, is_train=True,
            batch_size=BATCH_SIZE, num_classes=NUM_CLASS)

        # Define the model:
        predictions, end_points = inception_v3.inception_v3(
            inputs=images, num_classes=NUM_CLASS, is_training=True)

        # Specify the loss function:
        slim.losses.softmax_cross_entropy(predictions, labels)

        # Activation summaries. tf.summary.* registers them in the SUMMARIES
        # collection, which slim.learning.train merges by default.
        for end_point, x in end_points.items():
            tf.summary.histogram('activations/' + end_point, x)
            tf.summary.scalar('sparsity/' + end_point, tf.nn.zero_fraction(x))

        total_loss = slim.losses.get_total_loss()
        tf.summary.scalar('losses/total_loss', total_loss)

        # Specify the optimization scheme:
        optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001)

        # create_train_op ensures that when we evaluate it to get the loss,
        # the update_ops are done and the gradient updates are computed.
        train_tensor = slim.learning.create_train_op(total_loss, optimizer)

        # Use the pretrained model.
        checkpoint_path = '../train_log/inceptions_v3/inception_v3.ckpt'
        reader = tf.pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
        var_to_shape_map = reader.get_variable_to_shape_map()
        exclude_var = [key for key in var_to_shape_map
                       if 'InceptionV3/Logits' in key
                       or 'InceptionV3/AuxLogits' in key]
        # create_train_op created global_step. Restoring it would fail if the
        # checkpoint lacks it, or would import the pretraining step count and
        # end training early.
        exclude_var.append('global_step')
        variables_to_restore = slim.get_variables_to_restore(exclude=exclude_var)
        init_fn = assign_from_checkpoint_fn(checkpoint_path, variables_to_restore)

        # Training loop.
        slim.learning.train(train_tensor, train_dir,
                            init_fn=init_fn,
                            number_of_steps=max_steps,
                            save_summaries_secs=log_frequency,
                            save_interval_secs=100,
                            log_every_n_steps=log_frequency)
def get_init_fn_v2(pretrained_model_file: str,
                   exclusion_list: list,
                   logger=logging.getLogger("get_init_fn_v2"),
                   index_add: int = 0,
                   model_root_name: str = None):
    """Build an init function restoring model variables from a checkpoint.

    Model variables from `list_all_var()` whose names start with any entry of
    `exclusion_list` are skipped. Each remaining variable is matched to a
    checkpoint tensor through `TensorSortUtils.get_tensor_in_checkpoint`.
    Variables without a match are skipped and logged.

    Args:
      pretrained_model_file: Checkpoint path, passed to
        `framework.load_checkpoint` and `framework.assign_from_checkpoint_fn`.
      exclusion_list: Name prefixes of model variables that must not be
        restored. None is treated as an empty list.
      logger: Logger used for the skip warnings and info messages.
      index_add: Forwarded to `TensorSortUtils`.
      model_root_name: Forwarded to `TensorSortUtils` as `root_name`.

    Returns:
      The callable returned by `framework.assign_from_checkpoint_fn`, built
      from a dict {checkpoint tensor name: model variable}.
    """
    # str.startswith accepts a tuple; an empty tuple never matches.
    exclusions = tuple(exclusion_list or ())
    checkpoint_reader = framework.load_checkpoint(pretrained_model_file)
    ckpt_variable_shape_map = checkpoint_reader.get_variable_to_shape_map()

    # Variable filtering by the user-given exclusion prefixes.
    model_variables_should_be_restore = {}
    for v in list_all_var():
        var_name = v.name.split(':')[0]
        if v.name.startswith(exclusions):
            logger.warning(
                'Skip init {}(name in model) because it is excluded by user.'.
                format(var_name))
        else:
            model_variables_should_be_restore[var_name] = v

    # Final filter: keep only variables that have a matching tensor in the
    # checkpoint (the matching itself is delegated to TensorSortUtils).
    ckpt_variables_used = {}
    model_var_skip_list = []
    util = TensorSortUtils(model_var_dict=model_variables_should_be_restore,
                           ckpt_var_dict=ckpt_variable_shape_map,
                           index_add=index_add,
                           root_name=model_root_name)
    for var_name, var_tensor in model_variables_should_be_restore.items():
        ckpt_tensor_name, error_info = util.get_tensor_in_checkpoint(var_tensor)
        if ckpt_tensor_name:
            ckpt_variables_used[ckpt_tensor_name] = var_tensor
        else:
            logger.warning(error_info)
            model_var_skip_list.append(var_name)

    if model_var_skip_list:
        # These are model-side names (keys of model_variables_should_be_restore).
        logger.warning(
            "var(name in model) skipped, no match in ckpt file: {}".format(
                model_var_skip_list))
    logger.info("var in ckpt file is : {}".format(
        list(ckpt_variable_shape_map)))
    logger.info("var(name in ckpt file) restored is : {}".format(
        list(ckpt_variables_used)))
    return framework.assign_from_checkpoint_fn(pretrained_model_file,
                                               ckpt_variables_used)
def post_build(self, sess, rng):
    """Restore the network's variables from the checkpoint `ck_name`.

    This step is critical: several ways of restoring variables from the
    checkpoint did not work. `get_model_variables()` returns an empty list, so
    it cannot be used. `get_variables_to_restore()` returns a few spurious
    variables that are not in the checkpoint (and are of no use). The current
    solution restores only the scopes listed in `flat_cifar.scopes`, i.e. the
    scopes used when defining the network. `rng` is unused.
    """
    scopes = flat_cifar.scopes
    # print(...) with one argument behaves the same on Python 2 and is valid
    # Python 3 (the print statement is a SyntaxError there).
    print("get_variables_to_restore():")
    print(get_variables_to_restore())
    var_lst = get_variables_to_restore(include=scopes)
    print("get_variables_to_restore( include=scopes ):")
    print(var_lst)
    restore = assign_from_checkpoint_fn(ck_name, var_lst)
    restore(sess)
def loadNetwork(path, sess, model_name):
    """Build the network under `model_name`, restore it and return a predictor.

    Only the variables created inside `tf.variable_scope(model_name)` are
    initialized and then restored from `<path>/snapshot`. Checkpoint names are
    looked up with the leading `<model_name>/` scope replaced by `my_model/`.
    Missing checkpoint variables are ignored.

    Args:
      path: Directory containing the `snapshot` checkpoint.
      sess: Session whose graph receives the network.
      model_name: Variable scope for the network. 'my_model' builds 68 output
        channels; any other name builds 17.

    Returns:
      A function `func(imgs)` that runs the network on `imgs` (fed to a float32
      placeholder of shape (None, None, None, 3)). It returns a dict with
      'det' = the first 17 output channels and 'tag' = the last 17 channels.
    """
    img = tf.placeholder(dtype=tf.float32, shape=(None, None, None, 3))
    with tf.variable_scope(model_name):
        pred = inference(img, 68 if model_name == 'my_model' else 17)

    # Limit init/restore to this network so another network already loaded in
    # the same graph is neither re-initialized nor clobbered in the restore map.
    prefix = model_name + '/'
    model_vars = [v for v in tf.global_variables()
                  if v.op.name.startswith(prefix)]
    sess.run(tf.variables_initializer(model_vars))

    var_map = {}
    for var in model_vars:
        if 'global_step' in var.name or 'Adam' in var.name:
            continue  # training-only state, not needed for inference
        # Replace only the leading scope; str.replace would also rewrite any
        # inner occurrence of "<model_name>/".
        var_map['my_model/' + str(var.op.name)[len(prefix):]] = var
    init_fn = assign_from_checkpoint_fn(os.path.join(path, 'snapshot'), var_map,
                                        ignore_missing_vars=True)
    init_fn(sess)

    def func(imgs):
        output = sess.run(pred, feed_dict={img: imgs})
        return {
            'det': output[:, :, :, :17],
            'tag': output[:, :, :, -17:],
        }

    return func
def begin(self):
    """Prepare `self._load_ema` for loading moving-average weights.

    Builds the EMA name->variable map from
    `tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY)` and wraps it with
    `contrib_framework.assign_from_checkpoint_fn` for the newest checkpoint in
    `self._model_dir`. The stored callable takes a session.
    NOTE(review): `tf.train.latest_checkpoint` returns None when the directory
    has no checkpoint; that case is not handled here.
    """
    averager = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY)
    ema_var_map = averager.variables_to_restore()
    newest_ckpt = tf.train.latest_checkpoint(self._model_dir)
    self._load_ema = contrib_framework.assign_from_checkpoint_fn(
        newest_ckpt, ema_var_map)
def train_q(dataset,
            policy,
            optimizer=None,
            pack_transition_fn=None,
            q_graph_fn=None,
            log_dir=None,
            master='',
            task=0,
            training_steps=None,
            max_training_steps=100000,
            reuse=False,
            init_checkpoint=None,
            update_target_every_n_steps=50,
            log_every_n_steps=None,
            save_checkpoint_steps=500,
            save_summaries_steps=500):
    """Self-contained learning loop for offline Q-learning.

    Code inspired by OpenAI Baselines' deepq.build_train. This function is
    compatible with discrete Q-learning graphs, continuous Q learning graphs,
    and SARSA.

    Args:
      dataset: tf.data.Dataset providing transitions.
      policy: Instance of TFDQNPolicy class that provides functor for building
        the critic function.
      optimizer: Optional instance of an optimizer. If not specified, creates
        an AdamOptimizer using the default constructor.
      pack_transition_fn: Optional function that performs additional
        processing of the transition. This is a convenience method for ad-hoc
        manipulation of transition data passed to the learning function after
        parsing.
      q_graph_fn: Function used to construct training objectives w.r.t. critic
        outputs. Called as `q_graph_fn(q_func, transition)` and must return
        `(loss, summaries)`.
      log_dir: Where to save model checkpoints and tensorboard summaries.
      master: Optional address of master worker. Specify this when doing
        distributed training.
      task: Optional worker task for distributed training. Defaults to solo
        master task on a single machine.
      training_steps: Optional. When set, the loop also exits as soon as the
        global step reaches a multiple of `training_steps`. Max_training_steps
        remains unchanged - training terminates at max_training_steps whether
        or not training_steps is specified.
      max_training_steps: Maximum number of training iters (a StopAtStepHook
        is installed when this is truthy).
      reuse: If True, reuse existing variables for all declared variables by
        this function.
      init_checkpoint: Optional checkpoint to restore the model variables from
        prior to training. If not provided, variables are initialized using
        global_variables_initializer().
      update_target_every_n_steps: How many global steps (training) between
        copying the Q network weights (scope='q_func') to target network
        (scope='target_q_func').
      log_every_n_steps: How many global steps between logging loss tensors.
      save_checkpoint_steps: How many global steps between saving TF variables
        to a checkpoint file.
      save_summaries_steps: How many global steps between saving TF summaries.

    Returns:
      A tuple `(np_step, done)`. `np_step` (int) is the `global_step` value
      fetched on the last training step run (0 if no step ran). `done` (bool)
      is True iff `max_training_steps` is not None and
      `np_step >= max_training_steps`.

    Raises:
      NOTE(review): nothing in this function raises ValueError directly. The
      previously documented "ValueError if a batch of transitions is empty"
      presumably comes from `q_graph_fn`/`pack_transition_fn` -- confirm.
    """
    data_iterator = dataset.make_one_shot_iterator()
    transition = data_iterator.get_next()
    if pack_transition_fn:
        transition = pack_transition_fn(transition)
    if optimizer is None:
        optimizer = tf.train.AdamOptimizer()
    q_func = policy.get_q_func(is_training=True, reuse=reuse)
    loss, all_summaries = q_graph_fn(q_func, transition)

    q_func_vars = contrib_framework.get_trainable_variables(scope='q_func')
    target_q_func_vars = contrib_framework.get_trainable_variables(
        scope='target_q_func')
    global_step = tf.train.get_or_create_global_step()

    # Only optimize q_func and update its batchnorm params.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='q_func')
    with tf.control_dependencies(update_ops):
        train_op = optimizer.minimize(loss, global_step=global_step,
                                      var_list=q_func_vars)

    chief_hooks = []
    hooks = []
    # Save summaries periodically.
    if save_summaries_steps is not None:
        chief_hooks.append(
            tf.train.SummarySaverHook(save_steps=save_summaries_steps,
                                      output_dir=log_dir,
                                      summary_op=all_summaries))

    # Stop after max_training_steps.
    if max_training_steps:
        hooks.append(tf.train.StopAtStepHook(last_step=max_training_steps))

    # Report if loss tensor is NaN.
    hooks.append(tf.train.NanTensorHook(loss))

    if log_every_n_steps is not None:
        tensor_dict = {'global_step': global_step, 'train loss': loss}
        chief_hooks.append(
            tf.train.LoggingTensorHook(tensor_dict,
                                       every_n_iter=log_every_n_steps))
        # Measure how fast we are training per sec and save to summary.
        chief_hooks.append(
            tf.train.StepCounterHook(every_n_steps=log_every_n_steps,
                                     output_dir=log_dir))

    # If a target network exists, periodically copy the Q network weights into
    # it (frozen target network). This abuses a LoggingTensorHook: evaluating
    # `update_target` forces the assign ops to run.
    if target_q_func_vars and update_target_every_n_steps is not None:
        # Sorting by name pairs each q_func variable with its target twin.
        update_target_expr = []
        for var, var_t in zip(sorted(q_func_vars, key=lambda v: v.name),
                              sorted(target_q_func_vars, key=lambda v: v.name)):
            update_target_expr.append(var_t.assign(var))
        update_target_expr = tf.group(*update_target_expr)
        with tf.control_dependencies([update_target_expr]):
            update_target = tf.constant(0)
        chief_hooks.append(
            tf.train.LoggingTensorHook(
                {'update_target': update_target},
                every_n_iter=update_target_every_n_steps))

    # Save checkpoints periodically, save all of them.
    saver = tf.train.Saver(max_to_keep=None)
    chief_hooks.append(
        tf.train.CheckpointSaverHook(log_dir,
                                     save_steps=save_checkpoint_steps,
                                     saver=saver,
                                     checkpoint_basename='model.ckpt'))

    # Save our experiment params to checkpoint dir.
    chief_hooks.append(
        gin.tf.GinConfigSaverHook(log_dir, summarize_config=True))

    session_config = tf.ConfigProto(log_device_placement=False)

    init_fn = None
    if init_checkpoint:
        assign_fn = contrib_framework.assign_from_checkpoint_fn(
            init_checkpoint, contrib_framework.get_model_variables())

        def init_fn(_, sess):
            # Scaffold calls init_fn(scaffold, session).
            assign_fn(sess)

    scaffold = tf.train.Scaffold(saver=saver, init_fn=init_fn)
    with tf.train.MonitoredTrainingSession(
            master=master,
            is_chief=(task == 0),
            config=session_config,
            checkpoint_dir=log_dir,
            scaffold=scaffold,
            hooks=hooks,
            chief_only_hooks=chief_hooks) as sess:
        np_step = 0
        while not sess.should_stop():
            np_step, _ = sess.run([global_step, train_op])
            if training_steps and np_step % training_steps == 0:
                break
        # max_training_steps=None is accepted above (no StopAtStepHook), so
        # guard the comparison instead of raising TypeError.
        done = max_training_steps is not None and np_step >= max_training_steps
    return np_step, done
q = tf.reduce_sum(y, 1) # quantity keep_prob = tf.placeholder(tf.float32) is_training = tf.placeholder(tf.bool) # model # resnet_v1 101 with slim.arg_scope(resnet_v1.resnet_arg_scope()): net, end_points = resnet_v1.resnet_v1_101(img, num_classes, is_training=False) net_logit = tf.squeeze(net) # tensorflow operation for load pretrained weights variables_to_restore = get_variables_to_restore( exclude=['resnet_v1_101/logits', 'resnet_v1_101/AuxLogits']) init_fn = assign_from_checkpoint_fn('resnet_v1_101.ckpt', variables_to_restore) # multiscale resnet_v1 101 visual_features, fusion_logit = multiscale_resnet101(end_points, num_classes, is_training) textual_features, textual_logit = mlp(tag, num_classes, is_training) refined_features = tf.concat([visual_features, textual_features], 1) # score is prediction score, and k is label quantity score = multi_class_classification_model(refined_features, num_classes) k = label_quantity_prediction_model(refined_features, keep_prob) k = tf.reshape(k, shape=[batch_size]) # make trainable variable list var_list0 = [ v for v in tf.trainable_variables()
def model_fn(features, labels, mode, params, config):
    """Estimator model_fn for the AttendTell captioning model.

    TRAIN/EVAL: zips `features` with `labels["raw"]`, `labels["index"]` and the
    caption lengths, then shuffles and pads into batches of size 1 (feature
    extraction runs per image in `_extract_feats`). TRAIN repeats forever and
    restores the pretrained CNN weights from `params.vgg_model_path`. EVAL
    uses its iterator initializer as the scaffold's `init_op`. INFER batches
    `features` by `params.batch_size`. Each iterator initializer is added to
    the "<mode>_initializer" graph collection.

    NOTE(review): `features` and `labels[...]` are assumed to be tf.data
    datasets (they are zipped/mapped as such); `config` is unused.

    Returns:
      An EstimatorSpec. TRAIN/EVAL carry `loss`/`train_op`; INFER carries
      `predictions` (argmax over the last axis of the decoder outputs).
    """
    feat_tensor = caption_tensor = cap_idx_tensor = cap_len_tensor = None
    scaffold = None
    bin_size = 8

    if mode in (ModeKeys.TRAIN, ModeKeys.EVAL):
        cap_lens = labels["index"].map(lambda t: tf.size(t))
        zipped = Dataset.zip(
            (features, labels["raw"], labels["index"], cap_lens))
        # (image HxWx3, raw caption, caption indices, caption length)
        padded_shapes = ((None, None, 3), (), (None,), ())
        # TODO: the input pipeline cannot use the GPU, so train one by one.
        batches = (zipped
                   .shuffle(buffer_size=200 * params.batch_size)
                   .padded_batch(1, padded_shapes))
        if mode == ModeKeys.TRAIN:
            iterator = batches.repeat().make_initializable_iterator()
            initializer_key = "train_initializer"
        else:
            iterator = batches.make_initializable_iterator()
            initializer_key = "val_initializer"
        (feat_tensor, caption_tensor,
         cap_idx_tensor, cap_len_tensor) = iterator.get_next()
        tf.add_to_collection(initializer_key, iterator.initializer)
        if mode == ModeKeys.EVAL:
            scaffold = tf.train.Scaffold(init_op=iterator.initializer)

    if mode == ModeKeys.INFER:
        infer_iterator = features.batch(
            params.batch_size).make_initializable_iterator()
        feat_tensor = infer_iterator.get_next()
        tf.add_to_collection("infer_initializer", infer_iterator.initializer)

    feat_tensor = _extract_feats(bin_size, feat_tensor, mode)

    if mode == ModeKeys.TRAIN:
        # Collected before AttendTell is built, so only the feature
        # extractor's variables are restored from the pretrained checkpoint.
        cnn_vars = slim.get_variables_to_restore(exclude=['global_step'])
        restore_cnn = assign_from_checkpoint_fn(params.vgg_model_path, cnn_vars)
        # Scaffold calls init_fn(scaffold, session).
        scaffold = tf.train.Scaffold(
            init_fn=lambda _, sess: restore_cnn(sess))

    loss_op = train_op = predictions = None
    model = AttendTell(vocab_size=params.vocab_size,
                       selector=params.selector,
                       dropout=params.dropout,
                       ctx2out=params.ctx2out,
                       prev2out=params.prev2out,
                       hard_attention=params.hard_attention,
                       mode=mode)
    if mode == ModeKeys.INFER:
        decoded = model.build_infer(feat_tensor)
        predictions = tf.argmax(decoded, axis=-1)
    else:
        decoded = model.build_train(
            feat_tensor, cap_idx_tensor,
            use_generated_inputs=bool(params.use_sampler))
        loss_op = create_loss(decoded, cap_idx_tensor, cap_len_tensor)
        train_op = _get_train_op(loss_op, params.learning_rate,
                                 params.hard_attention)

    return EstimatorSpec(mode=mode,
                         predictions=predictions,
                         loss=loss_op,
                         train_op=train_op,
                         scaffold=scaffold)
def main(mode):
    """Extract Inception-v4 features for one split and write sharded TFRecords.

    For every annotation in
    `data/challenger.ai/annotations/caption_<mode>_annotations_20170902.json`,
    the image `data/challenger.ai/image/<mode>/<image_id>` is decoded, scaled
    from [0, 1] to [-1, 1] and run through `inception_v4_base`. The result is
    pooled with `spatial_pyramid_pooling(..., [14], mode='avg')`, reshaped to
    (-1, 196, 1536) and written with the image id, the raw captions and the
    captions' character-index encodings. A new shard file
    `..._inception_v4-<i>.tfrecords` is started every 10000 examples.

    Args:
      mode: Split name interpolated into the annotation, image and output
        paths.
    """
    data_dir = "data/challenger.ai"
    bin_size = 14
    with open(os.path.join(data_dir, 'word_to_idx.pkl'), 'rb') as f:
        word_to_idx = pickle.load(f)
    annotations_path = os.path.join(
        data_dir, "annotations/caption_%s_annotations_20170902.json" % mode)
    with open(annotations_path) as f:
        annotations = json.load(f)
    image_ids = [ann['image_id'] for ann in annotations]
    caps = [ann['caption'] for ann in annotations]

    def my_split(text):
        """Encode a UTF-8 caption as int32 ids wrapped in <START>/<END>."""
        text = text.decode("utf-8")
        # TODO: take care of unknown characters (currently mapped to id 0).
        idx = [word_to_idx.get(ch, 0) for ch in text]
        idx.insert(0, word_to_idx['<START>'])
        idx.append(word_to_idx['<END>'])
        return np.array(idx, dtype=np.int32)

    def parse(img_id, caps):
        filename = os.path.join(data_dir, "image/%s" % mode) + "/" + img_id
        image = tf.image.decode_jpeg(tf.read_file(filename), channels=3)
        image = tf.image.convert_image_dtype(image, dtype=tf.float32)
        # [0, 1] -> [-1, 1]
        image = tf.subtract(image, 0.5)
        image = tf.multiply(image, 2.0)
        splitted_caps = tuple(
            map(lambda c: tf.py_func(my_split, [c], tf.int32, stateful=False),
                tf.unstack(caps)))
        return {
            'img_id': img_id,
            'raw_img': image,
            'raw_caps': caps,
            'cap_idx': splitted_caps,
        }

    it = Dataset.from_tensor_slices(
        (image_ids, caps)).map(parse).make_one_shot_iterator()
    feat_tensor_dict = it.get_next()

    with slim.arg_scope(inception_v4_arg_scope()):
        final_conv_layer, end_points = inception_v4_base(
            tf.expand_dims(feat_tensor_dict['raw_img'], 0))
    feats_tensor = spatial_pyramid_pooling(final_conv_layer, [bin_size],
                                           mode='avg')
    feats_tensor = tf.reshape(feats_tensor,
                              shape=(-1, bin_size * bin_size, 1536))

    # Context managers / finally guarantee the session is closed and the last
    # shard is flushed even if an unexpected error interrupts the loop.
    with tf.Session() as sess:
        variables_to_restore = slim.get_variables_to_restore(
            exclude=['global_step'])
        init_fn = assign_from_checkpoint_fn("data/model/inception_v4.ckpt",
                                            variables_to_restore)
        init_fn(sess)

        tfrecord_filename_base = ('data/challenger.ai/tfrecords/'
                                  '%s_feat_14x14x1536_inception_v4' % mode)
        writer = tf.python_io.TFRecordWriter(tfrecord_filename_base +
                                             "-0.tfrecords")
        try:
            i = 0
            while True:
                try:
                    feature_dict, feats = sess.run(
                        (feat_tensor_dict, feats_tensor))
                except OutOfRangeError as e:
                    # The one-shot iterator is exhausted.
                    print(e)
                    break
                example = tf.train.Example(features=tf.train.Features(
                    feature={
                        'img_id': tf.train.Feature(
                            bytes_list=tf.train.BytesList(
                                value=[feature_dict['img_id']])),
                        'img_feats': tf.train.Feature(
                            bytes_list=tf.train.BytesList(
                                value=[feats.tobytes()])),
                        'raw_caps': tf.train.Feature(
                            bytes_list=tf.train.BytesList(
                                value=feature_dict['raw_caps'])),
                        'cap_idx': tf.train.Feature(
                            bytes_list=tf.train.BytesList(value=[
                                idx.tobytes()
                                for idx in feature_dict['cap_idx']
                            ])),
                    }))
                writer.write(example.SerializeToString())
                print(i)
                i += 1
                if i % 10000 == 0:
                    # Start a new shard every 10000 examples.
                    writer.close()
                    writer = tf.python_io.TFRecordWriter(
                        tfrecord_filename_base + "-%d.tfrecords" % i)
        finally:
            writer.close()
def transform_test_dataset():
    """Extract Inception-v4 features for the test images into sharded TFRecords.

    Every file matching `data/challenger.ai/image/test/*` is decoded, scaled
    from [0, 1] to [-1, 1] and run through `inception_v4_base`. The result is
    pooled with `spatial_pyramid_pooling(..., [14], mode='avg')`, reshaped to
    (-1, 196, 1536) and written with its file name as 'img_id'. A new shard
    `..._inception_v4-<i>.tfrecords` is started every 10000 examples.
    """
    data_dir = "data/challenger.ai"
    bin_size = 14
    filenames = [
        fn.split('/')[-1]
        for fn in glob.glob(os.path.join(data_dir, "image/test/*"))
    ]

    def parse(img_id):
        filename = os.path.join(data_dir, "image/test/") + img_id
        image = tf.image.decode_jpeg(tf.read_file(filename), channels=3)
        image = tf.image.convert_image_dtype(image, dtype=tf.float32)
        # [0, 1] -> [-1, 1]
        image = tf.subtract(image, 0.5)
        image = tf.multiply(image, 2.0)
        return {
            'img_id': img_id,
            'raw_img': image,
        }

    it = Dataset.from_tensor_slices(filenames).map(
        parse).make_one_shot_iterator()
    feat_tensor_dict = it.get_next()

    with slim.arg_scope(inception_v4_arg_scope()):
        final_conv_layer, end_points = inception_v4_base(
            tf.expand_dims(feat_tensor_dict['raw_img'], 0))
    feats_tensor = spatial_pyramid_pooling(final_conv_layer, [bin_size],
                                           mode='avg')
    feats_tensor = tf.reshape(feats_tensor,
                              shape=(-1, bin_size * bin_size, 1536))

    # Context managers / finally guarantee the session is closed and the last
    # shard is flushed even if an unexpected error interrupts the loop.
    with tf.Session() as sess:
        variables_to_restore = slim.get_variables_to_restore(
            exclude=['global_step'])
        init_fn = assign_from_checkpoint_fn("data/model/inception_v4.ckpt",
                                            variables_to_restore)
        init_fn(sess)

        tfrecord_filename_base = ('data/challenger.ai/tfrecords/'
                                  'test_feat_14x14x1536_inception_v4')
        writer = tf.python_io.TFRecordWriter(tfrecord_filename_base +
                                             "-0.tfrecords")
        try:
            i = 0
            while True:
                try:
                    feature_dict, feats = sess.run(
                        (feat_tensor_dict, feats_tensor))
                except OutOfRangeError as e:
                    # The one-shot iterator is exhausted.
                    print(e)
                    break
                example = tf.train.Example(features=tf.train.Features(
                    feature={
                        'img_id': tf.train.Feature(
                            bytes_list=tf.train.BytesList(
                                value=[feature_dict['img_id']])),
                        'img_feats': tf.train.Feature(
                            bytes_list=tf.train.BytesList(
                                value=[feats.tobytes()])),
                    }))
                writer.write(example.SerializeToString())
                print(i)
                i += 1
                if i % 10000 == 0:
                    # Start a new shard every 10000 examples.
                    writer.close()
                    writer = tf.python_io.TFRecordWriter(
                        tfrecord_filename_base + "-%d.tfrecords" % i)
        finally:
            writer.close()
def train(model_fn, pre_trained_model, train_log_dir, scope, arg_scope,
          train_layer, epochs=1, steps=None, learning_rate=0.001,
          num_classes=2, decay_steps=1000, decay_rate=0.8,
          save_interval_secs=300, image_h=224, image_w=224, batch_size=32,
          batch_threads=10, log_every_n_steps=10,
          train_image_dir="input/train",
          validation_image_dir="input/validation", file_ext_name=".jpg",
          restore_full_layer=False, lock_layer=True):
    """Fine-tune `model_fn` from a pretrained checkpoint with slim.

    A training tower and a weight-sharing validation tower
    (`is_training=False`) are built under `scope`. Variables are restored from
    `pre_trained_model`. Everything except `train_layer` is restored, unless
    `restore_full_layer` is set, in which case everything is restored. Adam is
    used with an exponentially decayed (staircase) learning rate.

    Args:
      model_fn: Network constructor called as
        `model_fn(images, num_classes=..., scope=..., [is_training=False])`,
        returning `(logits, end_points)`.
      pre_trained_model: Checkpoint path to restore from.
      train_log_dir: slim training directory (checkpoints, summaries).
      scope: Variable scope for both towers.
      arg_scope: Optional slim arg_scope applied while building the towers.
      train_layer: Scopes excluded from the restore, and (with `lock_layer`)
        the only layers trained.
      epochs / steps: Training length. `steps`, when given, overrides
        `epochs * steps_per_epoch`.
      learning_rate, decay_steps, decay_rate: Exponential decay schedule.
      log_every_n_steps: Interval of the custom metric log (also passed to
        slim).
      Other arguments configure image size, batching and input locations.

    Raises:
      ValueError: if `lock_layer` is set and `train_layer` selects no
        variables.
    """
    batch_capacity = batch_size * 2
    train_image_path = os.path.join(train_image_dir, "*" + file_ext_name)
    validation_image_path = os.path.join(validation_image_dir,
                                         "*" + file_ext_name)
    tf.logging.info("train_image_path={} validation_image_path={}".format(
        train_image_path, validation_image_path))

    # Count the training images (extension compared case-insensitively).
    train_image_nums = len([
        f for f in os.listdir(train_image_dir)
        if os.path.splitext(f)[1].lower() == file_ext_name.lower()
    ])
    # At least one step per epoch: with fewer images than batch_size the
    # epoch arithmetic (here and in the logging callback) would divide by 0.
    steps_per_epoch = max(1, train_image_nums // batch_size)
    max_step = steps if steps is not None else steps_per_epoch * epochs
    if steps is not None:
        # ceil(max_step / steps_per_epoch)
        epochs_ = max_step // steps_per_epoch + (
            1 if max_step % steps_per_epoch != 0 else 0)
    else:
        epochs_ = epochs

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    with tf.Graph().as_default():
        image_train, label_train, _ = read_image(
            train_image_path, image_h, image_w, epochs=None,
            batch_size=batch_size, batch_threads=batch_threads,
            batch_capacity=batch_capacity, data_augmentation=True)
        image_val, label_val, _ = read_image(
            validation_image_path, image_h, image_w, epochs=None,
            batch_size=batch_size * 5, batch_threads=batch_threads,
            batch_capacity=batch_capacity * 5)

        def _build_towers(var_scope):
            # Training tower, then a weight-sharing validation tower.
            train_logits, _ = model_fn(image_train, num_classes=num_classes,
                                       scope=var_scope)
            var_scope.reuse_variables()
            val_logits, _ = model_fn(image_val, num_classes=num_classes,
                                     scope=var_scope, is_training=False)
            return train_logits, val_logits

        with tf.variable_scope(scope) as scope_:
            if arg_scope:
                with slim.arg_scope(arg_scope):
                    predictions, predictions_val = _build_towers(scope_)
            else:
                predictions, predictions_val = _build_towers(scope_)

        excludes = None if restore_full_layer else train_layer
        variables_to_restore = slim.get_variables_to_restore(exclude=excludes)
        init_fn = assign_from_checkpoint_fn(pre_trained_model,
                                            variables_to_restore)

        # Specify the loss function. total_loss is taken before loss_val is
        # created, so the validation loss is not part of the optimized total.
        loss = tf.losses.softmax_cross_entropy(label_train, predictions)
        total_loss = slim.losses.get_total_loss()
        loss_val = tf.losses.softmax_cross_entropy(label_val, predictions_val)
        accuracy = accuracy_fn(predictions, label_train)
        accuracy_val = accuracy_fn(predictions_val, label_val, "accuracy_val")

        tf.summary.scalar('train/loss', loss)
        tf.summary.scalar('train/accuracy', accuracy)
        tf.summary.image('train/inputs',
                         tf.reshape(image_train, [-1, image_h, image_w, 3]), 5)
        tf.summary.scalar('validation/loss', loss_val)
        tf.summary.scalar('validation/accuracy', accuracy_val)
        tf.summary.image('validation/inputs',
                         tf.reshape(image_val, [-1, image_h, image_w, 3]), 5)

        global_step = slim.get_or_create_global_step()
        lr_tensor = tf.train.exponential_decay(learning_rate, global_step,
                                               decay_steps, decay_rate,
                                               staircase=True)
        # Specify the optimization scheme:
        optimizer = tf.train.AdamOptimizer(learning_rate=lr_tensor)

        variables_to_train = (filter_train_variables_by_layer(train_layer)
                              if lock_layer else None)
        if lock_layer and not variables_to_train:
            raise ValueError("no variables to train, train_layer={}".format(
                train_layer))
        # create_train_op ensures that when we evaluate it to get the loss,
        # the update_ops are done and the gradient updates are computed.
        train_tensor = slim.learning.create_train_op(
            total_loss, optimizer, global_step=global_step,
            variables_to_train=variables_to_train)

        start_time = time.time()

        def train_step(sess, train_op, global_step, train_step_kwargs):
            global_step_value = sess.run(global_step)
            # Evaluate the metrics only when logging. Every run of these
            # tensors dequeues a fresh training batch and a 5x validation
            # batch, so running them each step roughly doubles the input work.
            if (global_step_value % log_every_n_steps == 0
                    or global_step_value >= max_step - 1):
                (accuracy_value, loss_value, accuracy_validation_value,
                 loss_val_value, learning_rate_value) = sess.run(
                     [accuracy, loss, accuracy_val, loss_val, lr_tensor])
                tf.logging.info(
                    "global_step = {}/{} epoch = {}/{} accuracy = {:.5f} loss = {:.5f} accuracy_val = {:.5f} loss_val = {:.5f} learning_rate = {:.5f} time_elipse = {:.2f} s"
                    .format(global_step_value + 1, max_step,
                            global_step_value // steps_per_epoch + 1, epochs_,
                            accuracy_value, loss_value,
                            accuracy_validation_value, loss_val_value,
                            learning_rate_value, time.time() - start_time))
            return slim.learning.train_step(sess, train_op, global_step,
                                            train_step_kwargs)

        # Actually runs training.
        slim.learning.train(train_tensor,
                            train_log_dir,
                            init_fn=init_fn,
                            train_step_fn=train_step,
                            global_step=global_step,
                            log_every_n_steps=log_every_n_steps,
                            save_summaries_secs=60,
                            save_interval_secs=save_interval_secs,
                            number_of_steps=max_step,
                            session_config=config)