def model_fn(features, labels, mode, params):
    """A model function called by TPUEstimator.

    Builds the Mesh-TensorFlow graph for `toy_model`, lowers it to plain
    TensorFlow and returns a `TPUEstimatorSpec` for TRAIN or EVAL.

    Args:
      features: input features, passed unchanged to `toy_model`.
      labels: unused.
      mode: a `tf.estimator.ModeKeys` value; only TRAIN and EVAL are supported.
      params: estimator params. On TPU, `params['context']` provides the host
        placement function, replica count and device assignment.

    Returns:
      A `tpu_estimator.TPUEstimatorSpec`.

    Raises:
      NotImplementedError: if `mode` is neither TRAIN nor EVAL.
      ValueError: if `FLAGS.optimizer` is neither 'Adafactor' nor 'SGD'.
    """
    del labels
    # Fail fast instead of building the whole graph and implicitly returning
    # None for modes that have no spec below (e.g. PREDICT).
    if mode not in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
        raise NotImplementedError('Unsupported mode: %s' % mode)

    global_step = tf.train.get_global_step()
    graph = mtf.Graph()
    mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
    layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)

    if FLAGS.use_tpu:
        ctx = params['context']
        num_hosts = ctx.num_hosts
        host_placement_fn = ctx.tpu_host_placement_function
        device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
        tf.logging.info('device_list = %s', device_list)
        # TODO(ylc): Better estimation of replica cache size?
        replica_cache_size = 300 * 1000000  # 300M per replica
        # Worker 0 caches all the TPU binaries.
        worker0_mem = replica_cache_size * ctx.num_replicas
        devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
        var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                      devices_memory_usage)
        mesh_devices = [''] * mesh_shape.size
        mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
            mesh_shape, layout_rules, mesh_devices, ctx.device_assignment)
    else:
        var_placer = None
        mesh_devices = [''] * mesh_shape.size
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
            mesh_shape, layout_rules, mesh_devices)

    mesh = mtf.Mesh(graph, 'my_mesh', var_placer)
    with mtf.utils.outside_all_rewrites():
        logits, loss = toy_model(features, mesh)

    # TRAIN mode
    if mode == tf.estimator.ModeKeys.TRAIN:
        var_grads = mtf.gradients(
            [loss], [v.outputs[0] for v in graph.trainable_variables])
        if FLAGS.optimizer == 'Adafactor':
            optimizer = mtf.optimize.AdafactorOptimizer()
        elif FLAGS.optimizer == 'SGD':
            optimizer = mtf.optimize.SgdOptimizer(learning_rate=FLAGS.lr)
        else:
            # Raise rather than assert so the check survives `python -O`.
            raise ValueError('Unsupported optimizer: %s' % FLAGS.optimizer)
        update_ops = optimizer.apply_grads(var_grads, graph.trainable_variables)
    else:
        # for now, we can only export fully-replicated tensors.
        fully_replicated_logits = mtf.anonymize(logits)

    lowering = mtf.Lowering(graph, {mesh: mesh_impl})

    tf_loss = tf.to_float(lowering.export_to_tf_tensor(loss))

    if mode == tf.estimator.ModeKeys.TRAIN:
        tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
        tf_update_ops.append(tf.assign_add(global_step, 1))
        tf.logging.info('tf_update_ops: %s', tf_update_ops)
        train_op = tf.group(tf_update_ops)
    else:
        tf_logits = lowering.export_to_tf_tensor(fully_replicated_logits)

    with mtf.utils.outside_all_rewrites():
        # Copy master variables to slices. Must be called first.
        restore_hook = mtf.MtfRestoreHook(lowering)
        if mode == tf.estimator.ModeKeys.TRAIN:
            saver = tf.train.Saver(
                tf.global_variables(),
                sharded=True,
                max_to_keep=10,
                keep_checkpoint_every_n_hours=2,
                defer_build=False,
                save_relative_paths=True)
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            saver_listener = mtf.MtfCheckpointSaverListener(lowering)
            saver_hook = tf.train.CheckpointSaverHook(
                FLAGS.model_dir,
                save_steps=1000,
                saver=saver,
                listeners=[saver_listener])

            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.TRAIN,
                loss=tf_loss,
                train_op=train_op,
                training_hooks=[restore_hook, saver_hook])

        # EVAL mode (the only other mode admitted by the check at the top).
        def metric_fn(tf_logits):
            mean_logits = tf.metrics.mean(tf_logits)
            return {'mean_logits': mean_logits}

        eval_metrics = (metric_fn, [tf_logits])

        return tpu_estimator.TPUEstimatorSpec(
            tf.estimator.ModeKeys.EVAL,
            evaluation_hooks=[restore_hook],
            loss=tf_loss,
            eval_metrics=eval_metrics)
def _finish(self, update_ops, name):
    """Applies the outer (wealth) update and the per-variable parameter update.

    Sums the per-variable quantities held in `self.wealth_deltas`,
    `self.betting_fraction_dot_product_deltas` and `self.grad_norms`
    (presumably filled per variable earlier in the same step — confirm), then
    updates the maximum gradient norm, the outer wealth and, for every variable
    in `self.grads`, the variable itself.

    Args:
      update_ops: list of update ops collected so far; appended to in place.
      name: name for the returned grouped op.

    Returns:
      A `tf.group` of all update ops.
    """
    outer_wealth = self._get_non_slot(OUTER_WEALTH)
    betting_domain = self.betting_domain
    maximum_gradient = self._get_non_slot(MAXIMUM_GRADIENT)

    # Totals over all variables for this step.
    wealth_increment = sum(self.wealth_deltas.values())
    betting_fraction_dot_product = sum(
        self.betting_fraction_dot_product_deltas.values())
    grad_norm = sum(self.grad_norms.values())

    # Running maximum of the (summed) gradient norm.
    maximum_gradient_updated = self._assign(
        maximum_gradient, tf.maximum(maximum_gradient, grad_norm))
    update_ops.append(maximum_gradient_updated)

    gradient_scaling = 1.0 / maximum_gradient_updated
    # We will replace gradient with gradient/maximum_gradient_updated in order
    # to ensure ||gradient||_1 \le 1.
    # Since betting_fraction_dot_product and wealth_increment were calculated
    # using the original gradient, we also scale them by the same amount.
    betting_fraction_dot_product = betting_fraction_dot_product * gradient_scaling
    wealth_increment = wealth_increment * gradient_scaling

    outer_wealth_updated = self._assign_add(outer_wealth, wealth_increment)
    update_ops.append(outer_wealth_updated)

    # Factor applied to every inner gradient below.
    inner_grad_scaling = (1.0 - betting_domain) / (
        1.0 - betting_fraction_dot_product)

    if self.output_summaries:
        tf.summary.scalar(self._name + "/total_wealth", outer_wealth_updated)
        tf.summary.scalar(self._name + "/maximum_gradient_norm",
                          maximum_gradient_updated)
        tf.summary.scalar(self._name + "/gradient_L1_norm", grad_norm)

    if self.add_average:
        # Running sum of squared gradient norms, used as averaging weights.
        grad_norm_squared = tf.square(grad_norm)
        sum_grad_norm_squared = self._get_non_slot(SUM_GRAD_NORM_SQUARED)
        sum_grad_norm_squared_updated = self._assign_add(
            sum_grad_norm_squared, grad_norm_squared)

    for var in self.grads:
        grad = self.grads[var]

        if self.inner_optimizer == SCINOL:
            # SCINOL receives the gradient without the 1/max-norm rescaling.
            inner_grad = grad * inner_grad_scaling
        else:
            # Rescale gradient to have L1 norm at most 1.0
            scaled_grad = grad * gradient_scaling
            inner_grad = scaled_grad * inner_grad_scaling

        betting_fraction, inner_update_op = self._compute_inner_update(
            var, inner_grad)
        update_ops.append(inner_update_op)

        if self.output_summaries:
            betting_fraction_summary = tf.reduce_mean(
                tf.abs(betting_fraction))
            tf.summary.scalar(
                self._name + "/mean_abs_betting_fraction/" + var.name,
                betting_fraction_summary)
            max_betting_fraction_summary = tf.reduce_max(
                tf.abs(betting_fraction))
            tf.summary.scalar(
                self._name + "/max_abs_betting_fraction/" + var.name,
                max_betting_fraction_summary)

        # New offset from the initial value: lr * fraction * current wealth.
        next_offset = self.lr * betting_fraction * outer_wealth_updated
        initial_value = self.get_slot(var, INITIAL_VALUE)

        if self.add_average:
            average_offset = self.get_slot(var, AVERAGE_OFFSET)
            # NOTE(review): this reads the variable `sum_grad_norm_squared`,
            # not `sum_grad_norm_squared_updated`. In graph mode the read is
            # not ordered relative to the assign_add above, so the value may
            # be the old or the new sum — confirm which is intended.
            previous_sum_grad_norm_squared = sum_grad_norm_squared - grad_norm_squared

            average_offset_updated = self._assign_add(
                average_offset,
                (previous_sum_grad_norm_squared *
                 (next_offset - average_offset)) /
                (sum_grad_norm_squared_updated))
            update_ops.append(average_offset_updated)

            var_updated = self._assign(
                var, next_offset + average_offset_updated + initial_value)
        else:
            var_updated = self._assign(var, next_offset + initial_value)
        update_ops.append(var_updated)

    return tf.group(*update_ops, name=name)
def build(self, input_shape):
    """Builds the entropy model.

    Creates the variables for the network modeling the densities, creates the
    auxiliary loss estimating the median and tail quantiles of the densities,
    and then uses that to create the probability mass functions and the discrete
    cumulative density functions used by the range coder.

    Arguments:
      input_shape: Shape of the input tensor, used to get the number of
        channels.

    Raises:
      ValueError: if `input_shape` doesn't specify the length of the channel
        dimension.
    """
    input_shape = tf.TensorShape(input_shape)
    channel_axis = self._channel_axis(input_shape.ndims)
    channels = input_shape[channel_axis].value
    if channels is None:
        raise ValueError(
            "The channel dimension of the inputs must be defined.")
    # Pin rank and channel count so later calls must match this build.
    self.input_spec = tf.keras.layers.InputSpec(
        ndim=input_shape.ndims, axes={channel_axis: channels})
    # Layer widths of the per-channel density network, 1 -> filters... -> 1.
    filters = (1, ) + self.filters + (1, )
    # Per-layer share of init_scale so the product over all layers equals it.
    scale = self.init_scale**(1 / (len(self.filters) + 1))

    # Create variables.
    self._matrices = []
    self._biases = []
    self._factors = []
    for i in range(len(self.filters) + 1):
        # Inverse softplus, so that softplus(matrix) starts at 1/scale/width.
        init = np.log(np.expm1(1 / scale / filters[i + 1]))
        matrix = self.add_variable(
            "matrix_{}".format(i), dtype=self.dtype,
            shape=(channels, filters[i + 1], filters[i]),
            initializer=tf.initializers.constant(init))
        # softplus keeps the effective weights positive.
        matrix = tf.nn.softplus(matrix)
        self._matrices.append(matrix)

        bias = self.add_variable(
            "bias_{}".format(i), dtype=self.dtype,
            shape=(channels, filters[i + 1], 1),
            initializer=tf.initializers.random_uniform(-.5, .5))
        self._biases.append(bias)

        # No factor on the last layer.
        if i < len(self.filters):
            factor = self.add_variable("factor_{}".format(i),
                                       dtype=self.dtype,
                                       shape=(channels, filters[i + 1], 1),
                                       initializer=tf.initializers.zeros())
            # tanh bounds the factor to (-1, 1).
            factor = tf.math.tanh(factor)
            self._factors.append(factor)

    # To figure out what range of the densities to sample, we need to compute
    # the quantiles given by `tail_mass / 2` and `1 - tail_mass / 2`. Since we
    # can't take inverses of the cumulative directly, we make it an optimization
    # problem:
    # `quantiles = argmin(|logit(cumulative) - target|)`
    # where `target` is `logit(tail_mass / 2)` or `logit(1 - tail_mass / 2)`.
    # Taking the logit (inverse of sigmoid) of the cumulative makes the
    # representation of the right target more numerically stable.

    # Numerically stable way of computing logits of `tail_mass / 2`
    # and `1 - tail_mass / 2`.
    target = np.log(2 / self.tail_mass - 1)
    # Compute lower and upper tail quantile as well as median.
    target = tf.constant([-target, 0, target], dtype=self.dtype)

    def quantiles_initializer(shape, dtype=None, partition_info=None):
        # Initializes every channel's quantiles to (-init_scale, 0, init_scale).
        del partition_info  # unused
        assert tuple(shape[1:]) == (1, 3)
        init = tf.constant([[[-self.init_scale, 0, self.init_scale]]],
                           dtype=dtype)
        return tf.tile(init, (shape[0], 1, 1))

    quantiles = self.add_variable("quantiles",
                                  shape=(channels, 1, 3),
                                  dtype=self.dtype,
                                  initializer=quantiles_initializer)
    # stop_gradient=True: presumably so this auxiliary loss only trains the
    # quantiles, not the density network (see `_logits_cumulative` — confirm).
    logits = self._logits_cumulative(quantiles, stop_gradient=True)
    loss = tf.math.reduce_sum(abs(logits - target))
    self.add_loss(loss, inputs=None)

    # Quantize such that the median coincides with the center of a bin.
    medians = quantiles[:, 0, 1]
    self._medians = tf.stop_gradient(medians)

    # Largest distance observed between lower tail quantile and median, and
    # between median and upper tail quantile.
    minima = medians - quantiles[:, 0, 0]
    minima = tf.cast(tf.math.ceil(minima), tf.int32)
    minima = tf.math.maximum(minima, 0)
    maxima = quantiles[:, 0, 2] - medians
    maxima = tf.cast(tf.math.ceil(maxima), tf.int32)
    maxima = tf.math.maximum(maxima, 0)

    # PMF starting positions and lengths.
    self._offset = -minima
    pmf_start = medians - tf.cast(minima, self.dtype)
    pmf_length = maxima + minima + 1

    # Sample the densities in the computed ranges, possibly computing more
    # samples than necessary at the upper end.
    max_length = tf.math.reduce_max(pmf_length)
    samples = tf.range(tf.cast(max_length, self.dtype), dtype=self.dtype)
    samples += pmf_start[:, None, None]

    half = tf.constant(.5, dtype=self.dtype)
    # We strip the sigmoid from the end here, so we can use the special rule
    # below to only compute differences in the left tail of the sigmoid.
    # This increases numerical stability (see explanation in `call`).
    lower = self._logits_cumulative(samples - half, stop_gradient=True)
    upper = self._logits_cumulative(samples + half, stop_gradient=True)
    # Flip signs if we can move more towards the left tail of the sigmoid.
    sign = -tf.math.sign(tf.math.add_n([lower, upper]))
    pmf = abs(
        tf.math.sigmoid(sign * upper) - tf.math.sigmoid(sign * lower))
    pmf = pmf[:, 0, :]

    # Compute out-of-range (tail) masses.
    tail_mass = tf.math.add_n([
        tf.math.sigmoid(lower[:, 0, :1]),
        tf.math.sigmoid(-upper[:, 0, -1:]),
    ])

    # Construct a valid CDF initializer, so that we can run the model without
    # error even on the zeroth training step.
    def cdf_initializer(shape, dtype=None, partition_info=None):
        # Quantized CDF of a uniform two-symbol PMF per channel.
        del shape, partition_info  # unused
        assert dtype == tf.int32
        fill = tf.constant(.5, dtype=self.dtype)
        prob = tf.fill((channels, 2), fill)
        cdf = range_coding_ops.pmf_to_quantized_cdf(
            prob, precision=self.range_coder_precision)
        return tf.placeholder_with_default(cdf, shape=(channels, None))

    # We need to supply an initializer without fully defined static shape
    # here, or the variable will return the wrong dynamic shape later. A
    # placeholder with default gets the trick done (see initializer above).
    quantized_cdf = self.add_variable("quantized_cdf",
                                      shape=(channels, None),
                                      dtype=tf.int32,
                                      trainable=False,
                                      initializer=cdf_initializer)
    # Initial value 3 presumably matches the initializer's CDF length
    # (2 symbols + 1) — confirm against `_pmf_to_cdf`.
    cdf_length = self.add_variable("cdf_length",
                                   shape=(channels, ),
                                   dtype=tf.int32,
                                   trainable=False,
                                   initializer=tf.initializers.constant(3))
    # Works around a weird TF issue with reading variables inside a loop.
    self._quantized_cdf = tf.identity(quantized_cdf)
    self._cdf_length = tf.identity(cdf_length)

    # validate_shape=False: the CDF width changes as the quantiles move.
    update_cdf = tf.assign(quantized_cdf,
                           self._pmf_to_cdf(pmf, tail_mass, pmf_length,
                                            max_length),
                           validate_shape=False)
    update_length = tf.assign(cdf_length, pmf_length + 2)
    update_op = tf.group(update_cdf, update_length)
    self.add_update(update_op, inputs=None)

    super(EntropyBottleneck, self).build(input_shape)
metrics=transformer.get_metric_functions()) print(tf.trainable_variables()) print(adamm.variables()) print(len(tf.trainable_variables())) print(len(adamm.variables())) print(adamm.get_slot_names()) print(len(adamm.get_slot_names())) train_model.summary() print('Start unit testing : New BERTWrapper') sess = K.get_session() init = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()) sess.run(init) test_data = [ ['Hello', 'World'], ['Hello', 'World'], ['Hello', 'World'], ['Hello', 'World'], ] input_vals = itokens.encode(test_data, max_length=16) output_vals = otokens.encode(test_data, max_length=16) print(input_vals) print(output_vals) train_model.fit(x=input_vals, y=output_vals,
def _compute_inner_update_scinol(self, var, grad):
    """Computes the inner betting-fraction update for `var` (SCINOL variant).

    Args:
      var: the variable being optimized; its slots hold the inner state.
      grad: the (already rescaled) inner gradient for `var`.

    Returns:
      A tuple `(clipped_betting_fraction, update_op)`. The first element is the
      new betting fraction clipped to `[-betting_domain, betting_domain]`. The
      second groups every slot update made here.
    """
    update_ops = []

    betting_domain = tf.cast(self.betting_domain, var.dtype.base_dtype)

    reward = self.get_slot(var, INNER_REWARD)
    betting_fraction = self.get_slot(var, OUTER_BETTING_FRACTION)
    sum_grad_squared = self.get_slot(var, INNER_SUM_GRAD_SQUARED)
    sum_grad = self.get_slot(var, INNER_SUM_GRAD)
    inner_maximum_gradient = self.get_slot(var, INNER_MAXIMUM_GRADIENT)

    # clip inner gradient to respect previous inner_maximum_gradient value
    # This introduces at most an additive constant overhead in the regret
    # since the inner betting fraction lies in a bounded domain.
    clipped_grad = tf.clip_by_value(grad, -inner_maximum_gradient,
                                    inner_maximum_gradient)

    # Update the running max only after the clip above has read the old value.
    with tf.control_dependencies([clipped_grad]):
        inner_maximum_gradient_updated = self._assign(
            inner_maximum_gradient,
            tf.maximum(inner_maximum_gradient, tf.abs(grad)))
        update_ops.append(inner_maximum_gradient_updated)

    clipped_old_betting_fraction = tf.clip_by_value(betting_fraction,
                                                    -betting_domain,
                                                    betting_domain)

    # Process grad to respect truncation to [-betting_domain, betting_domain].
    # zeros_like keeps the dtype of `clipped_grad` (tf.zeros would always be
    # float32 and break tf.where for float64/bfloat16 variables).
    truncated_grad = tf.where(
        tf.greater_equal(
            clipped_grad * (betting_fraction - clipped_old_betting_fraction),
            0.0), clipped_grad, tf.zeros_like(clipped_grad))

    reward_delta = -betting_fraction * truncated_grad
    reward_updated = self._assign_add(reward, reward_delta)
    update_ops.append(reward_updated)

    sum_grad_squared_updated = self._assign_add(sum_grad_squared,
                                                tf.square(truncated_grad))
    update_ops.append(sum_grad_squared_updated)

    sum_grad_updated = self._assign_add(sum_grad, truncated_grad)
    update_ops.append(sum_grad_updated)

    # The second term in this minimum, inner_maximum_gradient_updated / self.eta
    # (i.e. self.eta / inner_maximum_gradient_updated), is a hack to force the
    # betting fraction to not be too big at first.
    scaling = tf.minimum(
        tf.rsqrt(sum_grad_squared_updated +
                 tf.square(inner_maximum_gradient_updated)),
        self.eta / inner_maximum_gradient_updated)
    theta = -sum_grad_updated * scaling

    # rescale inner flag is a hack that rescales the epsilon_v by the
    # maximum inner gradient.
    if self.rescale_inner:
        epsilon_scaling = inner_maximum_gradient_updated
    else:
        epsilon_scaling = 1.0

    inner_betting_fraction = tf.sign(theta) * tf.minimum(tf.abs(theta),
                                                         1.0) * scaling / 2.0
    new_betting_fraction = inner_betting_fraction * (
        reward_updated + epsilon_scaling * self.epsilon_v)

    betting_fraction_updated = self._assign(betting_fraction,
                                            new_betting_fraction)
    update_ops.append(betting_fraction_updated)

    clipped_betting_fraction = tf.clip_by_value(betting_fraction_updated,
                                                -betting_domain,
                                                betting_domain)

    if self.output_summaries:
        mean_unclipped_betting_fraction_summary = tf.reduce_mean(
            tf.abs(betting_fraction_updated))
        max_unclipped_betting_fraction_summary = tf.reduce_max(
            tf.abs(betting_fraction_updated))

        mean_clipped_betting_fraction_summary = tf.reduce_mean(
            tf.abs(clipped_betting_fraction))
        max_clipped_betting_fraction_summary = tf.reduce_max(
            tf.abs(clipped_betting_fraction))

        max_abs_gradient = tf.reduce_max(tf.abs(grad))
        max_truncated_grad = tf.reduce_max(tf.abs(truncated_grad))

        tf.summary.scalar(self._name + "/mean_unclipped_bet/" + var.name,
                          mean_unclipped_betting_fraction_summary)
        tf.summary.scalar(self._name + "/max_unclipped_bet/" + var.name,
                          max_unclipped_betting_fraction_summary)
        tf.summary.scalar(self._name + "/mean_clipped_bet/" + var.name,
                          mean_clipped_betting_fraction_summary)
        tf.summary.scalar(self._name + "/max_clipped_bet/" + var.name,
                          max_clipped_betting_fraction_summary)

        tf.summary.scalar(self._name + "/max_abs_inner_grad/" + var.name,
                          max_abs_gradient)
        tf.summary.scalar(
            self._name + "/max_abs_truncated_inner_grad/" + var.name,
            max_truncated_grad)
    return clipped_betting_fraction, tf.group(*update_ops)
def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu,
                     Global_step=None, optimizer="adamw", poly_power=1.0,
                     start_warmup_step=0):
    """Creates an optimizer training op.

    Args:
      loss: scalar loss tensor to minimize.
      init_lr: initial learning rate (the target of warmup and the start of the
        polynomial decay).
      num_train_steps: steps over which the learning rate decays to 0.
      num_warmup_steps: number of linear-warmup steps; falsy disables warmup.
      use_tpu: whether to wrap the optimizer in `CrossShardOptimizer`.
      Global_step: optional global-step variable. When None, the graph's global
        step is fetched or created.
      optimizer: "adamw" or "lamb".
      poly_power: power of the polynomial learning-rate decay.
      start_warmup_step: step at which warmup starts.

    Returns:
      An op that applies clipped gradients and increments the global step.

    Raises:
      ValueError: if `optimizer` is not "adamw" or "lamb".
    """
    # Use an explicit None check: truth-testing a TF variable/tensor is either
    # meaningless or raises in graph mode.
    if Global_step is not None:
        global_step = Global_step
    else:
        global_step = tf.train.get_or_create_global_step()

    learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32)

    # Implements linear decay of the learning rate.
    learning_rate = tf.train.polynomial_decay(learning_rate,
                                              global_step,
                                              num_train_steps,
                                              end_learning_rate=0.0,
                                              power=poly_power,
                                              cycle=False)

    # Implements linear warmup. I.e., if global_step - start_warmup_step <
    # num_warmup_steps, the learning rate will be
    # `(global_step - start_warmup_step)/num_warmup_steps * init_lr`.
    if num_warmup_steps:
        tf.logging.info("++++++ warmup starts at step %s, for %s steps ++++++",
                        start_warmup_step, num_warmup_steps)
        global_steps_int = tf.cast(global_step, tf.int32)
        start_warm_int = tf.constant(start_warmup_step, dtype=tf.int32)
        global_steps_int = global_steps_int - start_warm_int
        warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32)

        global_steps_float = tf.cast(global_steps_int, tf.float32)
        warmup_steps_float = tf.cast(warmup_steps_int, tf.float32)

        warmup_percent_done = global_steps_float / warmup_steps_float
        warmup_learning_rate = init_lr * warmup_percent_done

        is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32)
        learning_rate = ((1.0 - is_warmup) * learning_rate +
                         is_warmup * warmup_learning_rate)

    # It is OK that you use this optimizer for finetuning, since this
    # is how the model was trained (note that the Adam m/v variables are NOT
    # loaded from init_checkpoint.)
    # It is OK to use AdamW in the finetuning even if the model was trained by
    # LAMB. As reported in the BERT public GitHub, the learning rate for SQuAD
    # 1.1 finetuning is 3e-5, 4e-5 or 5e-5. For LAMB, the users can use 3e-4,
    # 4e-4, or 5e-4 for a batch size of 64 in the finetuning.
    if optimizer == "adamw":
        tf.logging.info("using adamw")
        opt = AdamWeightDecayOptimizer(
            learning_rate=learning_rate,
            weight_decay_rate=0.01,
            beta_1=0.9,
            beta_2=0.999,
            epsilon=1e-6,
            exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])
    elif optimizer == "lamb":
        tf.logging.info("using lamb")
        opt = lamb_optimizer.LAMBOptimizer(
            learning_rate=learning_rate,
            weight_decay_rate=0.01,
            beta_1=0.9,
            beta_2=0.999,
            epsilon=1e-6,
            exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])
    else:
        raise ValueError("Not supported optimizer: %s" % optimizer)

    if use_tpu:
        opt = contrib_tpu.CrossShardOptimizer(opt)

    tvars = tf.trainable_variables()
    grads = tf.gradients(loss, tvars)

    # This is how the model was pre-trained.
    (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)

    train_op = opt.apply_gradients(list(zip(grads, tvars)),
                                   global_step=global_step)

    # Normally the global step update is done inside of `apply_gradients`.
    # However, neither `AdamWeightDecayOptimizer` nor `LAMBOptimizer` do this.
    # But if you use a different optimizer, you should probably take this line
    # out.
    new_global_step = global_step + 1
    train_op = tf.group(train_op, [global_step.assign(new_global_step)])
    return train_op
def get_train_op(flags, total_loss, ema=None, tvars=None):
    """Generates the training operation.

    Args:
      flags: config object. Fields read: warmup_steps, learning_rate,
        decay_method ("poly" or "cos"), train_steps, min_lr_ratio,
        weight_decay, adam_epsilon, num_core_per_host, use_tpu, clip.
      total_loss: scalar loss tensor to minimize.
      ema: optional object with an `apply(var_list)` method. When given, it is
        applied to all trainable variables after each train step.
      tvars: optional list of variables to train; defaults to all trainable
        variables.

    Returns:
      A tuple `(train_op, learning_rate, gnorm)`. `gnorm` is the global gradient
      norm before clipping.

    Raises:
      ValueError: on an unknown `decay_method`, a negative `weight_decay`, or
        `weight_decay > 0` with more than one core per host.
    """
    global_step = tf.train.get_or_create_global_step()

    # increase the learning rate linearly
    if flags.warmup_steps > 0:
        warmup_lr = (tf.cast(global_step, tf.float32) /
                     tf.cast(flags.warmup_steps, tf.float32) *
                     flags.learning_rate)
    else:
        warmup_lr = 0.0

    # decay the learning rate
    if flags.decay_method == "poly":
        decay_lr = tf.train.polynomial_decay(
            flags.learning_rate,
            global_step=global_step - flags.warmup_steps,
            decay_steps=flags.train_steps - flags.warmup_steps,
            end_learning_rate=flags.learning_rate * flags.min_lr_ratio)
    elif flags.decay_method == "cos":
        decay_lr = tf.train.cosine_decay(
            flags.learning_rate,
            global_step=global_step - flags.warmup_steps,
            decay_steps=flags.train_steps - flags.warmup_steps,
            alpha=flags.min_lr_ratio)
    else:
        raise ValueError(flags.decay_method)

    learning_rate = tf.where(global_step < flags.warmup_steps, warmup_lr,
                             decay_lr)

    if flags.weight_decay == 0:
        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                           epsilon=flags.adam_epsilon)
    elif flags.weight_decay > 0 and flags.num_core_per_host == 1:
        optimizer = AdamWeightDecayOptimizer(
            learning_rate=learning_rate,
            epsilon=flags.adam_epsilon,
            exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"],
            weight_decay_rate=flags.weight_decay)
    elif flags.weight_decay < 0:
        raise ValueError("`weight_decay` must be non-negative, got %r." %
                         (flags.weight_decay,))
    else:
        raise ValueError("Do not support `weight_decay > 0` with multi-gpu "
                         "training so far.")

    # AdamWeightDecayOptimizer needs the manual global_step increment below.
    # Decide this BEFORE any TPU wrapping: a CrossShardOptimizer is not an
    # AdamWeightDecayOptimizer, so checking the wrapped object would skip the
    # increment on TPU and leave the learning rate stuck in warmup.
    increment_step_manually = isinstance(optimizer, AdamWeightDecayOptimizer)

    if flags.use_tpu:
        optimizer = tf.tpu.CrossShardOptimizer(optimizer)

    if tvars is None:
        grads_and_vars = optimizer.compute_gradients(total_loss)
    else:
        grads_and_vars = optimizer.compute_gradients(total_loss, var_list=tvars)
    gradients, variables = zip(*grads_and_vars)
    clipped, gnorm = tf.clip_by_global_norm(gradients, flags.clip)

    train_op = optimizer.apply_gradients(list(zip(clipped, variables)),
                                         global_step=global_step)

    # Manually increment `global_step` for AdamWeightDecayOptimizer
    if increment_step_manually:
        new_global_step = global_step + 1
        train_op = tf.group(train_op, [global_step.assign(new_global_step)])

    if ema is not None:
        # Update the variables with the EMA after the train op.
        with tf.control_dependencies([train_op]):
            train_op = ema.apply(tf.trainable_variables())

    return train_op, learning_rate, gnorm
def _model_fn(features, labels, mode, params, model):
    """Model definition for the SSD model based on ResNet-50.

    Args:
      features: the input image tensor with shape [batch_size, height, width, 3].
        The height and width are fixed and equal. In PREDICT mode `features` is
        instead a dict whose 'image' entry is the image tensor; the remaining
        entries are passed through into the predictions (the dict is mutated:
        'image' is popped from it).
      labels: the input labels in a dictionary. The labels include class targets
        and box targets which are dense label maps. The labels are generated from
        get_input_fn function in data/dataloader.py
      mode: the mode of TPUEstimator including TRAIN, EVAL, and PREDICT.
      params: the dictionary defines hyperparameters of model. The default
        settings are in default_hparams function in this file.
      model: the SSD model outputs class logits and box regression outputs.

    Returns:
      spec: the EstimatorSpec or TPUEstimatorSpec to run training, evaluation,
        or prediction.

    Raises:
      NotImplementedError: in EVAL mode.
    """
    if mode == tf.estimator.ModeKeys.PREDICT:
        labels = features
        features = labels.pop('image')

    # Normalize the images per channel.
    features -= tf.constant(constants.NORMALIZATION_MEAN,
                            shape=[1, 1, 3],
                            dtype=features.dtype)
    coef_std = 1.0 / tf.constant(constants.NORMALIZATION_STD,
                                 shape=[1, 1, 3],
                                 dtype=features.dtype)
    features *= coef_std

    def _model_outputs():
        return model(features,
                     params,
                     is_training_bn=(mode == tf.estimator.ModeKeys.TRAIN))

    if params['dtype'] == 'bf16':
        with tf.compat.v1.tpu.bfloat16_scope():
            cls_outputs, box_outputs = _model_outputs()
            # Cast outputs back to float32 for the loss / post-processing.
            for level in cls_outputs.keys():
                cls_outputs[level] = tf.cast(cls_outputs[level], tf.float32)
                box_outputs[level] = tf.cast(box_outputs[level], tf.float32)
    else:
        cls_outputs, box_outputs = _model_outputs()

    # First check if it is in PREDICT mode.
    if mode == tf.estimator.ModeKeys.PREDICT:
        flattened_cls, flattened_box = concat_outputs(cls_outputs, box_outputs,
                                                      True)
        ssd_box_coder = faster_rcnn_box_coder.FasterRcnnBoxCoder(
            scale_factors=constants.BOX_CODER_SCALES)

        anchors = box_list.BoxList(
            tf.convert_to_tensor(dataloader.DefaultBoxes()('ltrb')))

        decoded_boxes = box_coder.batch_decode(encoded_boxes=flattened_box,
                                               box_coder=ssd_box_coder,
                                               anchors=anchors)

        pred_scores = tf.nn.softmax(flattened_cls, axis=2)

        pred_scores, indices = select_top_k_scores(
            pred_scores, constants.MAX_NUM_EVAL_BOXES)

        predictions = dict(
            labels,
            indices=indices,
            pred_scores=pred_scores,
            pred_box=decoded_boxes,
        )

        if params['visualize_dataloader']:
            # this is for inference visualization.
            predictions['image'] = features

        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    # Load pretrained model from checkpoint.
    if params['resnet_checkpoint'] and mode == tf.estimator.ModeKeys.TRAIN:

        def scaffold_fn():
            """Loads pretrained model through scaffold function."""
            tf.train.init_from_checkpoint(params['resnet_checkpoint'], {
                '/': 'resnet%s/' % constants.RESNET_DEPTH,
            })
            return tf.train.Scaffold()
    else:
        scaffold_fn = None

    # Set up training loss and learning rate.
    update_learning_rate_schedule_parameters(params)
    global_step = tf.train.get_or_create_global_step()
    learning_rate = learning_rate_schedule(params, global_step)

    # cls_loss and box_loss are for logging. only total_loss is optimized.
    loss, cls_loss, box_loss = detection_loss(cls_outputs, box_outputs, labels)

    total_loss = loss + params['weight_decay'] * tf.add_n(
        [tf.nn.l2_loss(v) for v in tf.trainable_variables()])

    if mode == tf.estimator.ModeKeys.TRAIN:
        optimizer = tf.train.MomentumOptimizer(learning_rate,
                                               momentum=constants.MOMENTUM)
        if params['distributed_optimizer']:
            optimizer = params['distributed_optimizer'](optimizer)

        # Batch norm requires update_ops to be added as a train_op dependency.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        train_op = tf.group(optimizer.minimize(total_loss, global_step),
                            update_ops)

        # scaffold_fn is None when no resnet_checkpoint is configured; calling
        # it unconditionally would raise TypeError.
        scaffold = scaffold_fn() if scaffold_fn is not None else None
        return model_fn_lib.EstimatorSpec(mode=mode,
                                          loss=loss,
                                          train_op=train_op,
                                          scaffold=scaffold)

    if mode == tf.estimator.ModeKeys.EVAL:
        raise NotImplementedError
def _set_up_cache(self):
    """Replaces both offsets with cached versions and returns a refresh op.

    `self._lower_offset` and `self._upper_offset` are each passed through
    `self._cache_with_update_op`. The returned cached values overwrite the
    attributes.

    Returns:
      A `tf.group` of the two update ops returned by `_cache_with_update_op`.
    """
    cached_lower, refresh_lower = self._cache_with_update_op(self._lower_offset)
    self._lower_offset = cached_lower

    cached_upper, refresh_upper = self._cache_with_update_op(self._upper_offset)
    self._upper_offset = cached_upper

    refresh_ops = [refresh_lower, refresh_upper]
    return tf.group(refresh_ops)
def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu):
    """Creates an optimizer training op that only updates non-BERT variables.

    Every trainable variable whose name contains "bert" is excluded from the
    update. This keeps the BERT encoder frozen while the remaining (e.g. head)
    variables train.

    Args:
      loss: scalar loss tensor to minimize.
      init_lr: initial learning rate (the target of warmup and the start of the
        linear decay).
      num_train_steps: steps over which the learning rate decays to 0.
      num_warmup_steps: number of linear-warmup steps; falsy disables warmup.
      use_tpu: whether to wrap the optimizer in `CrossShardOptimizer`.

    Returns:
      An op that applies clipped gradients and increments the global step.
    """
    global_step = tf.train.get_or_create_global_step()

    learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32)

    # Implements linear decay of the learning rate.
    learning_rate = tf.train.polynomial_decay(
        learning_rate,
        global_step,
        num_train_steps,
        end_learning_rate=0.0,
        power=1.0,
        cycle=False)

    # Implements linear warmup. I.e., if global_step < num_warmup_steps, the
    # learning rate will be `global_step/num_warmup_steps * init_lr`.
    if num_warmup_steps:
        global_steps_int = tf.cast(global_step, tf.int32)
        warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32)

        global_steps_float = tf.cast(global_steps_int, tf.float32)
        warmup_steps_float = tf.cast(warmup_steps_int, tf.float32)

        warmup_percent_done = global_steps_float / warmup_steps_float
        warmup_learning_rate = init_lr * warmup_percent_done

        is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32)
        learning_rate = (
            (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate)

    # It is recommended that you use this optimizer for fine tuning, since this
    # is how the model was trained (note that the Adam m/v variables are NOT
    # loaded from init_checkpoint.)
    optimizer = optimization.AdamWeightDecayOptimizer(
        learning_rate=learning_rate,
        weight_decay_rate=0.01,
        beta_1=0.9,
        beta_2=0.999,
        epsilon=1e-6,
        exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])

    if use_tpu:
        optimizer = contrib_tpu.CrossShardOptimizer(optimizer)

    tvars = tf.trainable_variables()
    # Log through tf.logging rather than print() so the output respects the
    # configured verbosity.
    tf.logging.info("All trainable variables: %s", tvars)
    tvars = [v for v in tvars if "bert" not in v.name]
    tf.logging.info("Trainable variables excluding BERT: %s", tvars)
    grads = tf.gradients(loss, tvars)

    # This is how the model was pre-trained.
    (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)

    train_op = optimizer.apply_gradients(
        list(zip(grads, tvars)), global_step=global_step)

    # Normally the global step update is done inside of `apply_gradients`.
    # However, `AdamWeightDecayOptimizer` doesn't do this. But if you use
    # a different optimizer, you should probably take this line out.
    new_global_step = global_step + 1
    train_op = tf.group(train_op, [global_step.assign(new_global_step)])
    return train_op
def test_train(args):
    """Trains the model.

    Builds a hyperprior compression model whose transforms are conditioned on
    a log-lambda value. That value is drawn from its own repeating dataset
    rather than fixed. The function then runs the training loop in a
    MonitoredTrainingSession until `args.last_step`.

    Args:
      args: parsed command-line namespace. Fields read here: verbose,
        train_glob, preprocess_threads, patchsize, batchsize, num_filters,
        last_step, checkpoint_dir.

    Raises:
      RuntimeError: if `args.train_glob` matches no files.
    """

    if args.verbose:
        tf.logging.set_verbosity(tf.logging.INFO)

    # Create input data pipeline.
    with tf.device("/cpu:0"):
        train_files = glob.glob(args.train_glob)
        if not train_files:
            raise RuntimeError(
                "No training images found with glob '{}'.".format(
                    args.train_glob))
        train_dataset = tf.data.Dataset.from_tensor_slices(train_files)
        train_dataset = train_dataset.shuffle(
            buffer_size=len(train_files)).repeat()
        train_dataset = train_dataset.map(
            read_png, num_parallel_calls=args.preprocess_threads)
        train_dataset = train_dataset.map(
            lambda x: tf.random_crop(x, (args.patchsize, args.patchsize, 3)))
        train_dataset = train_dataset.batch(args.batchsize)
        train_dataset = train_dataset.prefetch(32)

    # Pixels per batch, used to normalize the bit count to bits per pixel.
    num_pixels = args.batchsize * args.patchsize**2

    # Get training patch from dataset.
    x = train_dataset.make_one_shot_iterator().get_next()

    # log2-domain lambda schedule: steps up from 0 toward 7 in 0.01
    # increments, then back down from 7 toward 0, repeated forever.
    lmbda_log_dist = np.hstack((np.arange(0, 7, 0.01), np.arange(7, 0, -0.01)))
    lmbda_log_dist = tf.constant(lmbda_log_dist, dtype=tf.float32)
    s = tf.data.Dataset.from_tensor_slices(lmbda_log_dist).repeat()
    lmbda_log = s.make_one_shot_iterator().get_next()  # levels
    lmbda = 0.1 * tf.pow(2.0, lmbda_log - 6.0)  # true value

    # Instantiate model. All transforms are conditioned on lmbda_log.
    analysis_transform = AnalysisTransform(args.num_filters, lmbda_log)
    synthesis_transform = SynthesisTransform(args.num_filters, lmbda_log)
    hyper_analysis_transform = HyperAnalysisTransform(args.num_filters,
                                                      lmbda_log)
    hyper_synthesis_transform = HyperSynthesisTransform(
        args.num_filters, lmbda_log)
    entropy_bottleneck = tfc.EntropyBottleneck()

    # Build autoencoder and hyperprior.
    y = analysis_transform(x)
    z = hyper_analysis_transform(abs(y))
    z_tilde, z_likelihoods = entropy_bottleneck(z, training=True)
    sigma = hyper_synthesis_transform(z_tilde)
    # Log-spaced table of scales for the Gaussian conditional model.
    scale_table = np.exp(
        np.linspace(np.log(SCALES_MIN), np.log(SCALES_MAX), SCALES_LEVELS))
    conditional_bottleneck = tfc.GaussianConditional(sigma, scale_table)
    y_tilde, y_likelihoods = conditional_bottleneck(y, training=True)
    x_tilde = synthesis_transform(y_tilde)

    # Total number of bits divided by number of pixels.
    train_bpp = (tf.reduce_sum(tf.log(y_likelihoods)) + tf.reduce_sum(
        tf.log(z_likelihoods))) / (-np.log(2) * num_pixels)

    # Mean squared error across pixels.
    train_mse = tf.reduce_mean(tf.squared_difference(x, x_tilde))
    # Multiply by 255^2 to correct for rescaling.
    train_mse *= 255**2

    # The rate-distortion cost.
    train_loss = lmbda * train_mse + train_bpp

    # Minimize loss and auxiliary loss, and execute update op.
    step = tf.train.create_global_step()
    main_optimizer = tf.train.AdamOptimizer(learning_rate=1e-4)
    main_step = main_optimizer.minimize(train_loss, global_step=step)

    # Auxiliary optimizer for the entropy bottleneck's own loss (presumably
    # the quantile loss it registers via add_loss — confirm).
    aux_optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)
    aux_step = aux_optimizer.minimize(entropy_bottleneck.losses[0])

    train_op = tf.group(main_step, aux_step, entropy_bottleneck.updates[0])

    tf.summary.scalar("loss", train_loss)
    tf.summary.scalar("bpp", train_bpp)
    tf.summary.scalar("mse", train_mse)

    tf.summary.image("original", quantize_image(x))
    tf.summary.image("reconstruction", quantize_image(x_tilde))

    hooks = [
        tf.train.StopAtStepHook(last_step=args.last_step),
        tf.train.NanTensorHook(train_loss),
    ]
    with tf.train.MonitoredTrainingSession(hooks=hooks,
                                           checkpoint_dir=args.checkpoint_dir,
                                           save_checkpoint_secs=300,
                                           save_summaries_secs=60) as sess:
        while not sess.should_stop():
            sess.run(train_op)
def train(args):
    """Trains the autoencoder on a distortion-only objective.

    No entropy model is built, so the loss has no rate term. It is either
    the mean squared error or ``1 - MS-SSIM``, selected by ``args.ssim_loss``.
    Scalar summaries are written every 30 s and image summaries every hour.

    Args:
        args: Parsed command-line namespace. Reads ``verbose``,
            ``train_glob``, ``preprocess_threads``, ``patchsize``,
            ``batchsize``, ``num_filters``, ``ssim_loss``, ``last_step`` and
            ``checkpoint_dir``.

    Raises:
        RuntimeError: If ``args.train_glob`` matches no files.
    """
    if args.verbose:
        tf.logging.set_verbosity(tf.logging.INFO)

    # Create input data pipeline.
    with tf.device("/cpu:0"):
        train_files = glob.glob(args.train_glob)
        if not train_files:
            raise RuntimeError(
                "No training images found with glob '{}'.".format(
                    args.train_glob))
        train_dataset = tf.data.Dataset.from_tensor_slices(train_files)
        train_dataset = train_dataset.shuffle(
            buffer_size=len(train_files)).repeat()
        train_dataset = train_dataset.map(
            read_png, num_parallel_calls=args.preprocess_threads)
        train_dataset = train_dataset.map(
            lambda x: tf.random_crop(x, (args.patchsize, args.patchsize, 3)))
        train_dataset = train_dataset.batch(args.batchsize)
        train_dataset = train_dataset.prefetch(32)

    # Get training patch from dataset.
    x = train_dataset.make_one_shot_iterator().get_next()

    # Instantiate model.
    analysis_transform = AnalysisTransform(args.num_filters)
    synthesis_transform = SynthesisTransform(args.num_filters)

    # Build autoencoder (no quantization / entropy bottleneck).
    y = analysis_transform(x)
    x_tilde = synthesis_transform(y)

    # Mean squared error across pixels, in the input's own scale
    # (not multiplied by 255**2).
    train_mse = tf.reduce_mean(tf.squared_difference(x, x_tilde))

    # PSNR / MS-SSIM use max_val=1, i.e. they assume pixels in [0, 1]
    # (presumably how read_png scales them — confirm).
    train_psnr = tf.reduce_mean(tf.image.psnr(x_tilde, x, 1))
    # MS-SSIM is expensive; build it once and derive both the raw value and
    # the loss from the same tensor.
    msssim = tf.image.ssim_multiscale(x_tilde, x, 1)
    train_msssim_value = tf.reduce_mean(msssim)
    # Structural-similarity loss.
    train_ssim = tf.reduce_mean(1 - msssim)

    # Choose distortion metric; it is the whole loss (no rate term).
    distortion = train_ssim if args.ssim_loss else train_mse
    train_loss = distortion

    # Minimize loss and execute update op.
    step = tf.train.create_global_step()
    main_optimizer = tf.train.AdamOptimizer(learning_rate=1e-4)
    main_step = main_optimizer.minimize(train_loss, global_step=step)
    train_op = tf.group(main_step)

    # Log scalar values
    s_loss = tf.summary.scalar("train/loss", train_loss)
    s_mse = tf.summary.scalar("train/mse", train_mse)
    s_psnr = tf.summary.scalar("train/psnr", train_psnr)
    s_msssim_value = tf.summary.scalar("train/multiscale ssim value",
                                       train_msssim_value)
    # NOTE(review): tf.log is the natural log; the usual "MS-SSIM (dB)" is
    # -10 * log10(1 - MS-SSIM). Confirm which convention is intended.
    s_ssim = tf.summary.scalar("train/multiscale ssim",
                               -10 * tf.log(train_ssim))

    # Log training images
    s_original = tf.summary.image("images/original", quantize_image(x))
    s_reconstruction = tf.summary.image("images/reconstruction",
                                        quantize_image(x_tilde))

    # Scalars and images are merged separately so they can be written at
    # different rates by the two SummarySaverHooks below.
    train_summary = tf.summary.merge(
        [s_loss, s_mse, s_psnr, s_msssim_value, s_ssim])
    image_summary = tf.summary.merge([s_original, s_reconstruction])

    hooks = [
        tf.train.StopAtStepHook(last_step=args.last_step),
        tf.train.NanTensorHook(train_loss),
        tf.train.SummarySaverHook(save_secs=30,
                                  output_dir=args.checkpoint_dir,
                                  summary_op=train_summary),
        tf.train.SummarySaverHook(save_secs=3600,
                                  output_dir=args.checkpoint_dir,
                                  summary_op=image_summary),
    ]
    # The session's built-in summary saving is disabled; the hooks above own it.
    with tf.train.MonitoredTrainingSession(
            hooks=hooks,
            checkpoint_dir=args.checkpoint_dir,
            save_checkpoint_secs=300,
            save_summaries_steps=None,
            save_summaries_secs=None) as sess:
        while not sess.should_stop():
            sess.run(train_op)
def model_fn(features, labels, mode, params):
    """The model_fn argument for creating an Estimator.

    Builds the MNIST model as a Mesh-TensorFlow graph, lowers it onto a
    ``PlacementMeshImpl`` (one empty device string per mesh slot), and
    returns an ``EstimatorSpec`` for TRAIN, PREDICT or EVAL.

    Args:
      features: Estimator input features, passed straight to ``mnist_model``.
      labels: Estimator labels. Used for the accuracy metric in TRAIN/EVAL
        (presumably None in PREDICT — confirm mnist_model tolerates that).
      mode: a ``tf.estimator.ModeKeys`` value.
      params: Estimator params; only logged here.

    Returns:
      A ``tf.estimator.EstimatorSpec`` for ``mode``.
    """
    tf.logging.info("features = %s labels = %s mode = %s params=%s" %
                    (features, labels, mode, params))
    global_step = tf.train.get_global_step()
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    logits, loss = mnist_model(features, labels, mesh)
    mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
    layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)
    mesh_size = mesh_shape.size
    mesh_devices = [""] * mesh_size
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, mesh_devices)

    # Gradient and optimizer ops must be added to the mtf graph before it is
    # lowered below.
    if mode == tf.estimator.ModeKeys.TRAIN:
        var_grads = mtf.gradients(
            [loss], [v.outputs[0] for v in graph.trainable_variables])
        optimizer = mtf.optimize.AdafactorOptimizer()
        update_ops = optimizer.apply_grads(var_grads, graph.trainable_variables)

    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    restore_hook = mtf.MtfRestoreHook(lowering)

    tf_logits = lowering.export_to_tf_tensor(logits)
    # A loss is only exported outside PREDICT.
    if mode != tf.estimator.ModeKeys.PREDICT:
        tf_loss = lowering.export_to_tf_tensor(loss)
        tf.summary.scalar("loss", tf_loss)

    if mode == tf.estimator.ModeKeys.TRAIN:
        # The train op runs every lowered update and bumps global_step by 1.
        tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
        tf_update_ops.append(tf.assign_add(global_step, 1))
        train_op = tf.group(tf_update_ops)
        saver = tf.train.Saver(
            tf.global_variables(),
            sharded=True,
            max_to_keep=10,
            keep_checkpoint_every_n_hours=2,
            defer_build=False, save_relative_paths=True)
        tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
        # The mtf listener is attached to the saver hook, which checkpoints
        # every 1000 steps into FLAGS.model_dir.
        saver_listener = mtf.MtfCheckpointSaverListener(lowering)
        saver_hook = tf.train.CheckpointSaverHook(
            FLAGS.model_dir,
            save_steps=1000,
            saver=saver,
            listeners=[saver_listener])

        accuracy = tf.metrics.accuracy(
            labels=labels, predictions=tf.argmax(tf_logits, axis=1))

        # Name tensors to be logged with LoggingTensorHook.
        tf.identity(tf_loss, "cross_entropy")
        tf.identity(accuracy[1], name="train_accuracy")

        # Save accuracy scalar to Tensorboard output.
        tf.summary.scalar("train_accuracy", accuracy[1])

        # restore_hook must come before saver_hook
        return tf.estimator.EstimatorSpec(
            tf.estimator.ModeKeys.TRAIN, loss=tf_loss, train_op=train_op,
            training_chief_hooks=[restore_hook, saver_hook])

    if mode == tf.estimator.ModeKeys.PREDICT:
        # Class ids are the argmax over axis 1 of the exported logits.
        predictions = {
            "classes": tf.argmax(tf_logits, axis=1),
            "probabilities": tf.nn.softmax(tf_logits),
        }
        return tf.estimator.EstimatorSpec(
            mode=tf.estimator.ModeKeys.PREDICT,
            predictions=predictions,
            prediction_hooks=[restore_hook],
            export_outputs={
                "classify": tf.estimator.export.PredictOutput(predictions)
            })
    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(
            mode=tf.estimator.ModeKeys.EVAL,
            loss=tf_loss,
            evaluation_hooks=[restore_hook],
            eval_metric_ops={
                "accuracy":
                tf.metrics.accuracy(
                    labels=labels, predictions=tf.argmax(tf_logits, axis=1)),
            })
def test_multitower_examples_model(self):
    """Ensure graph search runs properly on a multitower setup.

    This test uses linear_model from examples/convnets.

    Two towers share one dense layer (variable scope reused for tower > 0).
    The test checks that graph search registers exactly one
    FullyConnectedKFACBasicFB with both towers. It then runs one covariance
    update, one inverse update and one KfacOptimizer step.
    """
    with tf.Graph().as_default():

        def linear_model(images, labels, num_classes):
            """Creates a linear model.

            Args:
              images: The input image tensors, a tensor of size
                (batch_size x height_in x width_in x channels).
              labels: The sparse target labels, a tensor of size (batch_size x 1).
              num_classes: The number of classes, needed for one-hot encoding (int).

            Returns:
              loss: The total loss for this model (0-D tensor).
              logits: Predictions for this model (batch_size x num_classes).
            """
            # Flatten everything but the batch dimension.
            images = tf.reshape(images, [images.shape[0], -1])
            logits = tf.layers.dense(images, num_classes, name='logits')
            loss = sparse_softmax_cross_entropy(labels, logits, num_classes)
            return loss, logits

        model = linear_model
        layer_collection = lc.LayerCollection()

        # One example per tower.
        num_towers = 2
        batch_size = num_towers
        num_classes = 2

        # Set up data.
        images = tf.random_uniform(shape=[batch_size, 32, 32, 1])
        labels = tf.random_uniform(dtype=tf.int64,
                                   shape=[batch_size, 1],
                                   maxval=num_classes)

        tower_images = tf.split(images, num_towers)
        tower_labels = tf.split(labels, num_towers)

        # Build model.
        losses = []
        logits = []
        for tower_id in range(num_towers):
            tower_name = 'tower%d' % tower_id
            with tf.name_scope(tower_name):
                # Towers after the first reuse the first tower's variables.
                with tf.variable_scope(tf.get_variable_scope(),
                                       reuse=(tower_id > 0)):
                    # The model gets num_classes + 1 outputs, while labels are
                    # drawn from [0, num_classes).
                    current_loss, current_logits = model(
                        tower_images[tower_id], tower_labels[tower_id],
                        num_classes + 1)
                    layer_collection.register_categorical_predictive_distribution(
                        current_logits, name='logits')
                    losses.append(current_loss)
                    logits.append(current_logits)

        # Run the graph scanner.
        with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
            gs.register_layers(layer_collection, tf.trainable_variables())
        # The shared dense layer should yield a single block seen by both towers.
        self.assertEqual(len(layer_collection.fisher_blocks), 1)
        fisher_block = list(layer_collection.fisher_blocks.values())[0]
        self.assertIsInstance(fisher_block, fb.FullyConnectedKFACBasicFB)
        self.assertEqual(fisher_block.num_registered_towers, num_towers)

        global_step = tf.train.get_or_create_global_step()
        opt = optimizer.KfacOptimizer(learning_rate=0.1,
                                      cov_ema_decay=0.1,
                                      damping=0.1,
                                      layer_collection=layer_collection,
                                      momentum=0.1)
        cost = tf.reduce_mean(losses)
        (cov_update_thunks,
         inv_update_thunks) = opt.make_vars_and_create_op_thunks()
        cov_update_op = tf.group(*(thunk() for thunk in cov_update_thunks))
        inv_update_op = tf.group(*(thunk() for thunk in inv_update_thunks))
        train_op = opt.minimize(cost, global_step=global_step)
        init = tf.global_variables_initializer()

        # Run a single training step.
        with self.test_session() as sess:
            sess.run(init)
            sess.run([cov_update_op])
            sess.run([inv_update_op])
            sess.run([train_op])
def benchmark_model(self,
                    warmup_runs,
                    bm_runs,
                    num_threads,
                    trace_filename=None):
    """Benchmark model.

    Builds the inference graph, freezes its variables into constants, then
    re-imports the frozen graph. It runs ``warmup_runs`` untimed passes
    followed by ``bm_runs`` timed passes on a random input. The per-batch
    time and FPS (``self.batch_size / time``) are printed.

    Args:
      warmup_runs: number of warm-up runs (each one is logged with its time).
      bm_runs: number of timed runs; the per-batch time is their average.
      num_threads: if > 0, sets intra-op threads to this value and inter-op
        threads to 1; otherwise the default session config is used.
      trace_filename: optional path. If set, one extra run is traced and a
        Chrome-trace JSON is written there.
    """
    if self.tensorrt:
        print('Using tensorrt ', self.tensorrt)
        graphdef = self.freeze_model()

    if num_threads > 0:
        print('num_threads for benchmarking: {}'.format(num_threads))
        sess_config = tf.ConfigProto(
            intra_op_parallelism_threads=num_threads,
            inter_op_parallelism_threads=1)
    else:
        sess_config = tf.ConfigProto()

    # rewriter_config_pb2.RewriterConfig.OFF
    sess_config.graph_options.rewrite_options.dependency_optimization = 2
    if self.use_xla:
        sess_config.graph_options.optimizer_options.global_jit_level = (
            tf.OptimizerOptions.ON_2)

    with tf.Graph().as_default(), tf.Session(config=sess_config) as sess:
        inputs = tf.placeholder(tf.float32, name='input', shape=self.inputs_shape)
        output = self.build_model(inputs, is_training=False)

        img = np.random.uniform(size=self.inputs_shape)

        sess.run(tf.global_variables_initializer())
        if self.tensorrt:
            fetches = [inputs.name] + [i.name for i in output]
            goutput = self.convert_tr(graphdef, fetches)
            inputs, output = goutput[0], goutput[1:]

        if not self.use_xla:
            # Don't use tf.group because XLA removes the whole graph for tf.group.
            output = tf.group(*output)
        else:
            output = tf.add_n([tf.reduce_sum(x) for x in output])

        # Fetch names: an Operation name for tf.group, a tensor name such as
        # "AddN:0" for add_n. convert_variables_to_constants needs *node*
        # names, so strip any ":<index>" suffix for the freeze step only.
        output_name = [output.name]
        output_node_names = [name.split(':')[0] for name in output_name]
        input_name = inputs.name
        graphdef = tf.graph_util.convert_variables_to_constants(
            sess, sess.graph_def, output_node_names)

    with tf.Graph().as_default(), tf.Session(config=sess_config) as sess:
        tf.import_graph_def(graphdef, name='')

        for i in range(warmup_runs):
            start_time = time.time()
            sess.run(output_name, feed_dict={input_name: img})
            logging.info('Warm up: {} {:.4f}s'.format(i, time.time() - start_time))

        print('Start benchmark runs total={}'.format(bm_runs))
        start = time.perf_counter()
        for i in range(bm_runs):
            sess.run(output_name, feed_dict={input_name: img})
        end = time.perf_counter()
        inference_time = (end - start) / bm_runs
        print('Per batch inference time: ', inference_time)
        print('FPS: ', self.batch_size / inference_time)

        if trace_filename:
            run_options = tf.RunOptions()
            run_options.trace_level = tf.RunOptions.FULL_TRACE
            run_metadata = tf.RunMetadata()
            sess.run(
                output_name,
                feed_dict={input_name: img},
                options=run_options,
                run_metadata=run_metadata)
            logging.info('Dumping trace to %s', trace_filename)
            trace_dir = os.path.dirname(trace_filename)
            # A bare filename has no directory component; makedirs('') fails.
            if trace_dir and not tf.io.gfile.exists(trace_dir):
                tf.io.gfile.makedirs(trace_dir)
            with tf.io.gfile.GFile(trace_filename, 'w') as trace_file:
                trace = timeline.Timeline(step_stats=run_metadata.step_stats)
                trace_file.write(
                    trace.generate_chrome_trace_format(show_memory=True))
def make_update_op(update_thunks):
    """Call each thunk in order and bundle the ops it returns into one op."""
    built_ops = []
    for make_op in update_thunks:
        built_ops.append(make_op())
    return tf.group(*built_ops)
def build_all_reduce_iterations(all_device_tensors, tower_devices,
                                variable_mgr, num_iters):
    """Chain ``num_iters`` all-reduce rounds over per-tower tensors.

    Each round feeds on the previous round's output, so later rounds reduce
    values that are already reduced. That is harmless because the result
    only exists for benchmarking. Rounds are serialized per device through
    control dependencies. No cross-device dependencies are added, so
    consecutive rounds may overlap slightly.

    Args:
      all_device_tensors: List of lists of tensors; ``all_device_tensors[t][i]``
        is tensor ``i`` on tower ``t``.
      tower_devices: Device strings; ``tower_devices[t]`` hosts
        ``all_device_tensors[t]``.
      variable_mgr: The VariableMgr used to perform the all-reduce.
      num_iters: How many all-reduce rounds to chain.

    Returns:
      A grouped op that, when run, drives every all-reduce round.
    """
    per_tower = all_device_tensors
    for round_idx in range(num_iters):
        with tf.name_scope('iteration_%d' % round_idx):
            # Step 1: the aggregation itself.
            with tf.name_scope('tensor_aggregation'):
                per_tower = all_reduce(per_tower, variable_mgr)

            # Step 2: identity ops on each tower pull the reduced values home.
            homed = []
            for device, tensors in zip(tower_devices, per_tower):
                with tf.device(device):
                    homed.append([
                        tf.identity(tensor, name='identity_after_allreduce')
                        for tensor in tensors
                    ])
            per_tower = homed

            # Step 3: every tensor on a tower waits for all of that tower's
            # tensors, which holds back the next round (per device only).
            per_tower = [[
                control_flow_ops.with_dependencies(
                    tensors, tensor, name='identity_after_dependencies')
                for tensor in tensors
            ] for tensors in per_tower]

    # Assign the final values into throwaway variables so the dependency
    # optimizer cannot prune the ops built above.
    sink_ops = []
    for device, tensors in zip(tower_devices, per_tower):
        with tf.device(device):
            for tensor in tensors:
                # The placeholder initial value is never run.
                sink = tf.Variable(tf.placeholder(tf.float32, tensor.shape),
                                   collections=[])
                sink_ops.append(sink.assign(tensor))
    return tf.group(*sink_ops)
def make_batch_executed_op(update_thunks, batch_size=1):
    """Group the ops returned by ``kfac.utils.batch_execute``.

    ``global_step`` is taken from the enclosing scope and forwarded along
    with ``update_thunks`` and ``batch_size``.
    """
    scheduled_ops = kfac.utils.batch_execute(
        global_step, update_thunks, batch_size=batch_size)
    return tf.group(*scheduled_ops)
def __init__(self, config):
    """Build a prediction graph and session for a detection model.

    Args:
        config: Configuration object. Reads ``dataset.dir``, ``dataset.type``,
            ``model.type``, ``model.network``, ``train.job_dir``,
            ``train.run_name`` and ``train.debug``. It is mutated:
            ``dataset.data_augmentation`` is set to ``None``.

    Raises:
        ValueError: If ``train.job_dir`` is set but contains no checkpoint,
            or if ``model.type`` is neither ``'ssd'`` nor ``'fasterrcnn'``.
    """
    # Always define class_labels. Before, it was left unset (AttributeError
    # on later access) when config.dataset.dir was empty.
    self.class_labels = None
    if config.dataset.dir:
        # Gets the names of the classes
        classes_file = os.path.join(config.dataset.dir, 'classes.json')
        if tf.gfile.Exists(classes_file):
            # Context manager so the file handle is closed after reading.
            with tf.gfile.GFile(classes_file) as classes_fp:
                self.class_labels = json.load(classes_fp)

    # Don't use data augmentation in predictions
    config.dataset.data_augmentation = None

    dataset_class = get_dataset(config.dataset.type)
    model_class = get_model(config.model.type)
    dataset = dataset_class(config)
    model = model_class(config)

    graph = tf.Graph()
    self.session = tf.Session(graph=graph)

    with graph.as_default():
        # A single image of any height/width with 3 channels.
        self.image_placeholder = tf.placeholder(tf.float32, (None, None, 3))
        image_tf, _, process_meta = dataset.preprocess(self.image_placeholder)
        pred_dict = model(image_tf)

        # Restore checkpoint
        if config.train.job_dir:
            job_dir = config.train.job_dir
            if config.train.run_name:
                job_dir = os.path.join(job_dir, config.train.run_name)
            ckpt = tf.train.get_checkpoint_state(job_dir)
            if not ckpt or not ckpt.all_model_checkpoint_paths:
                raise ValueError(
                    'Could not find checkpoint in {}.'.format(job_dir))
            # Use the last entry in the checkpoint list.
            ckpt = ckpt.all_model_checkpoint_paths[-1]
            saver = tf.train.Saver(sharded=True, allow_empty=True)
            saver.restore(self.session, ckpt)
            tf.logging.info('Loaded checkpoint.')
        else:
            # A prediction without checkpoint is just used for testing
            tf.logging.warning(
                'Could not load checkpoint. Using initialized model.')
            init_op = tf.group(tf.global_variables_initializer(),
                               tf.local_variables_initializer())
            self.session.run(init_op)

        if config.model.type == 'ssd':
            cls_prediction = pred_dict['classification_prediction']
            objects_tf = cls_prediction['objects']
            objects_labels_tf = cls_prediction['labels']
            objects_labels_prob_tf = cls_prediction['probs']
        elif config.model.type == 'fasterrcnn':
            if config.model.network.get('with_rcnn', False):
                cls_prediction = pred_dict['classification_prediction']
                objects_tf = cls_prediction['objects']
                objects_labels_tf = cls_prediction['labels']
                objects_labels_prob_tf = cls_prediction['probs']
            else:
                rpn_prediction = pred_dict['rpn_prediction']
                objects_tf = rpn_prediction['proposals']
                objects_labels_prob_tf = rpn_prediction['scores']
                # All labels without RCNN are zero
                objects_labels_tf = tf.zeros(
                    tf.shape(objects_labels_prob_tf), dtype=tf.int32)
        else:
            raise ValueError("Model type '{}' not supported".format(
                config.model.type))

        self.fetches = {
            'objects': objects_tf,
            'labels': objects_labels_tf,
            'probs': objects_labels_prob_tf,
            'scale_factor': process_meta['scale_factor']
        }

        # If in debug mode, return the full prediction dictionary.
        if config.train.debug:
            self.fetches['_debug'] = pred_dict
def update_op_if_nan_or_inf():
    """Update loss_scale and discard gradients if nans/infs occurred.

    Halves ``loss_scale`` and resets ``loss_scale_normal_steps`` to 0. No
    gradients are applied on this path.
    """
    halve_scale = tf.assign(loss_scale, loss_scale / 2.)
    reset_counter = tf.assign(loss_scale_normal_steps, 0)
    return tf.group(halve_scale, reset_counter)
gen_cost, var_list=lib.params_with_name('Generator'), colocate_gradients_with_ops=True) disc_train_op = tf.train.RMSPropOptimizer(learning_rate=5e-5).minimize( disc_cost, var_list=lib.params_with_name('Discriminator.'), colocate_gradients_with_ops=True) clip_ops = [] for var in lib.params_with_name('Discriminator'): clip_bounds = [-.01, .01] clip_ops.append( tf.assign( var, tf.clip_by_value(var, clip_bounds[0], clip_bounds[1]))) clip_disc_weights = tf.group(*clip_ops) elif MODE == 'wgan-gp': gen_train_op = tf.train.AdamOptimizer( learning_rate=1e-4, beta1=0., beta2=0.9).minimize(gen_cost, var_list=lib.params_with_name('Generator'), colocate_gradients_with_ops=True) disc_train_op = tf.train.AdamOptimizer( learning_rate=1e-4, beta1=0., beta2=0.9).minimize( disc_cost, var_list=lib.params_with_name('Discriminator.'), colocate_gradients_with_ops=True) elif MODE == 'dcgan': gen_train_op = tf.train.AdamOptimizer(
def update_op_if_no_nan_or_inf():
    """Apply gradients, and update loss scaling.

    The loss-scale update op is built first, then the gradient-application
    ops, and all of them are grouped into one op.
    """
    scale_update_op = get_loss_scale_update_op(
        loss_scale, loss_scale_normal_steps, inc_loss_scale_every_n)
    apply_ops = get_apply_gradients_ops_func()
    return tf.group(scale_update_op, *apply_ops)
def my_model_fn(features,
                labels,
                mode,
                params=None,
                config=None):
    """Estimator model function.

    Args:
      features: dictionary where keys are strings like "inputs" and "targets"
        and the values are the actual values of "inputs". See TPUEstimator's
        docs for more information
      labels: ignored argument
      mode: a tf.estimator.ModeKeys
      params: dictionary containing the key "context"
      config: ignored argument

    Returns:
      a TPUEstimatorSpec (PREDICT, EVAL, and TRAIN on TPU) or an
      EstimatorSpec (TRAIN without TPU)
    """
    del labels, config
    global_step = tf.train.get_global_step()
    if use_tpu and "context" in params:
        ctx = params["context"]
        num_hosts = ctx.num_hosts
        host_placement_fn = ctx.tpu_host_placement_function
        device_list = [
            host_placement_fn(host_id=t) for t in range(num_hosts)
        ]
        # TODO(ylc): Better estimation of replica cache size?
        replica_cache_size = 300 * 1000000  # 300M per replica
        # Worker 0 caches all the TPU binaries.
        worker0_mem = replica_cache_size * ctx.num_replicas
        devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
        var_placer = mtf.utils.BalancedVariablePlacer(
            device_list, devices_memory_usage)
        # deprecated mesh_devices
        mesh_devices = [""] * mesh_shape.size
        physical_shape = list(
            params["context"].device_assignment.topology.mesh_shape)
        logical_to_physical = mtf.simd_mesh_impl.auto_logical_to_physical_tpu(
            mesh_shape.to_integer_list, physical_shape)
        mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
            mesh_shape,
            layout_rules,
            mesh_devices,
            ctx.device_assignment,
            logical_to_physical=logical_to_physical)
    else:
        var_placer = None
        # deprecated mesh_devices
        mesh_devices = [""] * mesh_shape.size
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
            mesh_shape, layout_rules, mesh_devices)

    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh", var_placer)

    # Import every feature as a fully-replicated int32 mtf tensor of shape
    # [ensemble?, outer_batch, batch, length].
    mtf_features = {}
    for key, x in features.items():
        outer_batch_dim = mtf.Dimension("outer_batch", outer_batch_size)
        batch_dim = mtf.Dimension("batch", batch_size // outer_batch_size)
        # Some auxiliary features may have been generated in packing.
        # The names of these new features are of the form
        #   "<original_feature_name>_<suffix>", e.g. "inputs_segmentation".
        #   We look up the lengths based on the original feature name, without
        #   the "_<suffix>".
        feature_length = sequence_length[key.split("_")[0]]
        length_dim = mtf.Dimension("length", feature_length)
        ensemble_dims = ([mtf.Dimension("ensemble", ensemble_inputs)]
                         if ensemble_inputs else [])
        feature_shape = mtf.Shape(ensemble_dims +
                                  [outer_batch_dim, batch_dim, length_dim])
        x = tf.cast(features[key], tf.int32)
        x = tf.reshape(x, feature_shape.to_integer_list)
        if not use_tpu:
            tf.logging.info("feature %s : %s" % (key, x))
            x = tf.Print(x, [x],
                         "import feature %s" % key,
                         summarize=1000,
                         first_n=10)
        mtf_features[key] = mtf.import_fully_replicated(mesh,
                                                        x,
                                                        feature_shape,
                                                        name=key)
        # NOTE(review): whichever of these keys is iterated last wins, and
        # anon_targets becomes the EVAL labels. Confirm that is intended.
        if key == "targets" or key == "codeprefixedtargets" or key == "controlcode":
            anon_targets = mtf.anonymize(mtf_features[key])

    if mode == tf.estimator.ModeKeys.PREDICT:

        def _feature_shape(key):
            feature_length = sequence_length[key.split("_")[0]]
            return mtf.Shape([
                mtf.Dimension("batch", batch_size),
                mtf.Dimension("length", feature_length)
            ])

        mtf_features = {
            k: mtf.reshape(v, _feature_shape(k))
            for k, v in six.iteritems(mtf_features)
        }
        inputs = mtf_features["inputs"]
        if attribute_embedding:
            attributes = mtf_features["attribute"]
        else:
            attributes = None
        if has_partial_sequences:
            controlcodes = mtf_features["controlcode"]
        else:
            controlcodes = None
        if predict_fn:
            mtf_samples = predict_fn(model=transformer_model,
                                     features=mtf_features,
                                     variable_dtype=get_variable_dtype())
        elif isinstance(transformer_model, transformer.Unitransformer):
            # pad so that there is enough room for the targets
            inputs = mtf.pad(inputs, [0, sequence_length["targets"]],
                             length_dim.name)
            mtf_samples = transformer_model.sample_autoregressive(
                inputs,
                variable_dtype=get_variable_dtype(),
                remove_partial_sequences=True)
        elif isinstance(transformer_model, Bitransformer_ll):
            mtf_samples = transformer_model.decode(
                inputs,
                attributes=attributes,
                controlcodes=controlcodes,
                has_partial_sequences=has_partial_sequences,
                remove_partial_sequences=remove_partial_sequences,
                variable_dtype=get_variable_dtype())
        elif isinstance(
                transformer_model,
                (transformer.Bitransformer, transformer.StudentTeacher)):
            mtf_samples = transformer_model.decode(
                inputs, variable_dtype=get_variable_dtype())
        else:
            raise ValueError("unrecognized class")
        mtf_samples = mtf.anonymize(mtf_samples)
        inputs = mtf.anonymize(inputs)
        lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)
        inputs = lowering.export_to_tf_tensor(inputs)
        outputs = lowering.export_to_tf_tensor(mtf_samples)
        predictions = {"inputs": inputs, "outputs": outputs}

        # When exporting a model, we need to communicate to TF-Serving that
        # master variables need to be copied to their slave slice variables.
        # Estimator uses a Scaffold's "local_init_op" for this purpose, so we
        # augment the default "local_init_op" here.
        #
        # The "ready_op" is also constructed here to ensure the variables
        # initialized by "local_init_op" are the same ones checked by "ready_op".
        #
        # WARNING: Any variables created outside of this model_fn()
        # (e.g. tpu_estimator/iterations_per_loop) will NOT be initialized nor
        # checked by these ops.
        def scaffold_fn():
            return tf.train.Scaffold(
                local_init_op=tf.group(
                    tf.train.Scaffold.default_local_init_op(),
                    lowering.copy_masters_to_slices(),
                    name="mtf_local_init_op"),
                ready_op=tf.concat([
                    tf.report_uninitialized_variables(),
                    resources.report_uninitialized_resources()
                ],
                                   axis=0,
                                   name="mtf_ready_op"))

        return tpu_estimator.TPUEstimatorSpec(
            mode=tf.estimator.ModeKeys.PREDICT,
            predictions=predictions,
            scaffold_fn=scaffold_fn,
            prediction_hooks=[mtf.MtfRestoreHook(lowering)])

    assert (mode == tf.estimator.ModeKeys.TRAIN
            or mode == tf.estimator.ModeKeys.EVAL)

    def logits_and_loss(mtf_features, num_microbatches=1):
        """Compute logits and loss.

        Args:
          mtf_features: a dictionary
          num_microbatches: forwarded to ``call_simple`` on the
            non-``Bitransformer_ll`` path. It is an explicit parameter
            (default 1) because the enclosing ``num_microbatches`` is only
            assigned in TRAIN. Reading it as a closure raised NameError in
            EVAL.
        Returns:
          logits: a mtf.Tensor
          loss: a mtf.Tensor
        """
        if model_type == "lm":  # TOTRY Adapt that to our case
            if "inputs" in mtf_features:
                mtf_features = _dynamic_text2self(mtf_features)
            _, _, length_dim = mtf_features["targets"].shape
            inputs = mtf.shift(mtf_features["targets"],
                               offset=1,
                               dim=length_dim,
                               wrap=False)
        else:
            inputs = mtf_features["inputs"]

        if attribute_embedding:
            attributes = mtf_features["attribute"]
        else:
            attributes = None

        if control_codes:
            codeprefixedtargets = mtf_features["codeprefixedtargets"]
        else:
            codeprefixedtargets = None

        if isinstance(transformer_model, transformer.Unitransformer):
            position_kwargs = dict(
                sequence_id=mtf_features.get("targets_segmentation", None),
                position=mtf_features.get("targets_position", None),
            )
        elif isinstance(transformer_model, transformer.Bitransformer
                        ) or model_type == "bi_student_teacher":
            if control_codes:
                position_kwargs = dict(
                    encoder_sequence_id=mtf_features.get(
                        "inputs_segmentation", None),
                    decoder_sequence_id=mtf_features.get(
                        "codeprefixedtargets_segmentation", None),
                    decoder_subsequence_id=mtf_features.get(
                        "codeprefixedtargets_subsegmentation", None),
                    encoder_position=mtf_features.get(
                        "inputs_position", None),
                    decoder_position=mtf_features.get(
                        "codeprefixedtargets_position", None),
                )
            else:
                position_kwargs = dict(
                    encoder_sequence_id=mtf_features.get(
                        "inputs_segmentation", None),
                    decoder_sequence_id=mtf_features.get(
                        "targets_segmentation", None),
                    decoder_subsequence_id=mtf_features.get(
                        "targets_subsegmentation", None),
                    encoder_position=mtf_features.get(
                        "inputs_position", None),
                    decoder_position=mtf_features.get(
                        "targets_position", None),
                )
        else:
            raise ValueError("unrecognized class")

        if isinstance(transformer_model, Bitransformer_ll):
            if cycle_consistency_loss:
                # Auto-encoding pass, then decode, then feed the decoded
                # samples back in as inputs for the cycle pass.
                logits_ae, l_ae = transformer_model.call_simple(
                    inputs=inputs,
                    targets=mtf_features["targets"],
                    compute_loss=True,
                    attributes=attributes,
                    codeprefixedtargets=codeprefixedtargets,
                    mode=mode,
                    variable_dtype=get_variable_dtype(),
                    **position_kwargs)

                if has_partial_sequences:
                    controlcodes = mtf_features["controlcode"]
                else:
                    controlcodes = None

                with gin.config_scope('training'):
                    mtf_samples = transformer_model.decode(
                        inputs,
                        attributes=attributes,
                        controlcodes=controlcodes,
                        has_partial_sequences=has_partial_sequences,
                        remove_partial_sequences=remove_partial_sequences,
                        variable_dtype=get_variable_dtype())
                # mtf_samples = mtf.anonymize(mtf_samples)
                outputs = mtf_samples

                logits_cycle, l_cycle = transformer_model.call_simple(
                    inputs=outputs,
                    targets=mtf_features["targets"],
                    compute_loss=True,
                    attributes=attributes,
                    codeprefixedtargets=codeprefixedtargets,
                    mode=mode,
                    variable_dtype=get_variable_dtype(),
                    **position_kwargs)

                loss_ae_cycle = lambda_ae * l_ae + lambda_cycle * l_cycle
                return logits_cycle, loss_ae_cycle
            else:
                return transformer_model.call_simple(
                    inputs=inputs,
                    targets=mtf_features["targets"],
                    compute_loss=True,
                    attributes=attributes,
                    codeprefixedtargets=codeprefixedtargets,
                    mode=mode,
                    variable_dtype=get_variable_dtype(),
                    **position_kwargs)
        else:
            return transformer_model.call_simple(
                inputs=inputs,
                targets=mtf_features["targets"],
                compute_loss=True,
                mode=mode,
                variable_dtype=get_variable_dtype(),
                num_microbatches=num_microbatches,
                **position_kwargs)

    if mode == tf.estimator.ModeKeys.TRAIN:
        num_microbatches = serialize_num_microbatches(batch_dim,
                                                      sequence_length,
                                                      mesh_shape,
                                                      layout_rules)
        if num_microbatches > 1:

            def serialized_fn(mtf_features):
                return {
                    "loss": (logits_and_loss(mtf_features, num_microbatches)[1] /
                             num_microbatches)
                }

            var_grads, loss_dict = mtf.serialize_training_step(
                mtf_features, serialized_fn, batch_dim, num_microbatches)
            loss = loss_dict["loss"]
        else:
            loss = logits_and_loss(mtf_features, num_microbatches)[1]
            var_grads = mtf.gradients(
                [loss], [v.outputs[0] for v in graph.trainable_variables])

        if tpu_summaries:
            mtf.scalar_summary("loss", loss)

        if callable(learning_rate_schedule):
            # the following happens on CPU since TPU can't handle summaries.
            with mtf.utils.outside_all_rewrites():
                learning_rate = learning_rate_schedule(
                    step=tf.train.get_global_step())
                tf.summary.scalar("learning_rate", learning_rate)
        else:
            learning_rate = learning_rate_schedule

        if isinstance(variable_filter, str):
            pattern = re.compile(variable_filter)
            variable_filter_fn = lambda v: pattern.search(v.name)
        elif variable_filter is None:
            variable_filter_fn = lambda v: True
        elif callable(variable_filter):
            variable_filter_fn = variable_filter
        else:
            raise ValueError(
                "variable_filter must be None, a string, or a callable function"
            )
        trainable_vars = [
            v for v in graph.trainable_variables if variable_filter_fn(v)
        ]
        trainable_var_grads = [
            g for g, v in zip(var_grads, graph.trainable_variables)
            if variable_filter_fn(v)
        ]
        if len(trainable_vars) != len(graph.trainable_variables):
            tf.logging.info("Variables being trained:")
            tf.logging.info([v.name for v in trainable_vars])
            tf.logging.info("Variables not being trained:")
            tf.logging.info([
                v.name for v in graph.trainable_variables
                if not variable_filter_fn(v)
            ])

        update_ops = optimizer(learning_rate=learning_rate).apply_grads(
            trainable_var_grads, trainable_vars)

        lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)

        tf_loss = lowering.export_to_tf_tensor(loss)
        tf_loss = tf.cast(tf_loss, tf.float32)
        if not use_tpu:
            tf_loss = tf.Print(
                tf_loss, [tf_loss, tf.train.get_global_step()],
                "step, tf_loss")

        tf_update_ops = [
            lowering.lowered_operation(op) for op in update_ops
        ]
        tf_update_ops.append(tf.assign_add(global_step, 1))
        train_op = tf.group(tf_update_ops)

        if hasattr(transformer_model, "initialize"):
            with mtf.utils.outside_all_rewrites():
                transformer_model.initialize()

        if tpu_summaries:
            # has to be outside of
            # with mtf.utils.outside_all_rewrites()
            host_call = mtf.utils.create_host_call(model_dir)
            mtf.utils.remove_summaries()
        else:
            host_call = None

        with mtf.utils.outside_all_rewrites():

            if init_checkpoint:
                ckpt_vars = {
                    v for v, _ in tf.train.list_variables(init_checkpoint)
                }
                global_vars = {v.op.name for v in tf.global_variables()}
                restore_vars = ckpt_vars.intersection(global_vars)
                tf.logging.info("Initializing variables from %s:",
                                init_checkpoint)
                tf.logging.debug("\n".join(sorted(restore_vars)))
                tf.logging.info("Variables in %s but not in graph:",
                                init_checkpoint)
                tf.logging.info("\n".join(sorted(ckpt_vars - global_vars)))
                tf.logging.info("Variables in graph but not in %s:",
                                init_checkpoint)
                tf.logging.info("\n".join(sorted(global_vars - ckpt_vars)))
                tf.train.init_from_checkpoint(init_checkpoint,
                                              {v: v for v in restore_vars})

            # Copy master variables to slices. Must be called first.
            restore_hook = mtf.MtfRestoreHook(lowering)
            saver = tf.train.Saver(tf.global_variables(),
                                   sharded=True,
                                   max_to_keep=keep_checkpoint_max,
                                   keep_checkpoint_every_n_hours=2,
                                   defer_build=False,
                                   save_relative_paths=True)
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            saver_listener = mtf.MtfCheckpointSaverListener(lowering)
            saver_hook = tf.train.CheckpointSaverHook(
                model_dir,
                save_steps=save_checkpoints_steps,
                saver=saver,
                listeners=[saver_listener])
            gin_config_saver_hook = gin.tf.GinConfigSaverHook(
                model_dir, summarize_config=True, include_step_in_filename=False)

            if use_tpu:
                return tpu_estimator.TPUEstimatorSpec(
                    mode=tf.estimator.ModeKeys.TRAIN,
                    loss=tf_loss,
                    train_op=train_op,
                    host_call=host_call,
                    training_hooks=[
                        restore_hook,
                        saver_hook,
                        gin_config_saver_hook,
                    ])
            else:
                return tf.estimator.EstimatorSpec(
                    tf.estimator.ModeKeys.TRAIN,
                    loss=tf_loss,
                    train_op=train_op,
                    training_chief_hooks=[
                        restore_hook,
                        saver_hook,
                        gin_config_saver_hook,
                    ])
    elif mode == tf.estimator.ModeKeys.EVAL:
        logits, loss = logits_and_loss(mtf_features)
        anon_logits = mtf.anonymize(logits)
        lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)
        tf_loss = tf.cast(lowering.export_to_tf_tensor(loss), tf.float32)
        tf_logits = tf.cast(lowering.export_to_tf_tensor(anon_logits),
                            tf.float32)

        def simple_metrics(logits, labels):
            """Simple metrics for teacher-forced eval."""
            # Label id 0 is treated as padding and given zero weight.
            weights = tf.cast(tf.not_equal(labels, 0), tf.float32)
            xent = tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=labels, logits=logits)
            predictions = tf.cast(tf.argmax(logits, axis=-1), labels.dtype)
            token_correct = tf.cast(tf.equal(predictions, labels),
                                    tf.float32) * weights
            sequence_correct = tf.to_float(
                tf.equal(tf.reduce_sum(token_correct, -1),
                         tf.reduce_sum(weights, -1)))
            sequence_weights = tf.to_float(
                tf.not_equal(tf.reduce_sum(weights, -1), 0))
            return {
                "neg_log_perplexity":
                tf.metrics.mean(-xent, weights),
                "token_accuracy":
                tf.metrics.mean(token_correct, weights),
                "sequence_accuracy":
                tf.metrics.mean(sequence_correct, sequence_weights)
            }

        labels = lowering.export_to_tf_tensor(anon_targets)
        eval_metrics = (simple_metrics, [tf_logits, labels])
        with mtf.utils.outside_all_rewrites():
            restore_hook = mtf.MtfRestoreHook(lowering)
        return tpu_estimator.TPUEstimatorSpec(
            tf.estimator.ModeKeys.EVAL,
            evaluation_hooks=[restore_hook],
            loss=tf_loss,
            eval_metrics=eval_metrics)
def increment_loss_scale_normal_steps_func():
  """Builds a grouped op that adds 1 to `loss_scale_normal_steps`.

  `loss_scale_normal_steps` is not defined here; it is resolved from the
  enclosing scope.

  Returns:
    A `tf.group` wrapping the `assign_add` op.
  """
  step_increment = loss_scale_normal_steps.assign_add(1)
  return tf.group(step_increment)
def _compute_inner_update_onsbet(self, var, grad):
  """Computes the inner ONS-style betting update for a single variable.

  Reads and updates these per-variable slots: the inner maximum gradient,
  the inner wealth, the sum of squared inner gradients, the inner betting
  fraction and the outer betting fraction.

  Args:
    var: the variable being optimized. Its base dtype is used for the
      `eta` and `betting_domain` constants and for the zero tensor used when
      truncating the gradient.
    grad: the gradient for `var`. It must be a dense tensor, because it is
      fed to `tf.where`.

  Returns:
    A tuple `(clipped_betting_fraction, update_op)`:
      clipped_betting_fraction: the freshly assigned outer betting fraction,
        clipped to `[-betting_domain, betting_domain]`.
      update_op: a `tf.group` of every slot assignment made here.
  """
  update_ops = []

  # Cast hyper-parameters to the variable's dtype so the arithmetic below
  # stays in a single dtype.
  eta = tf.cast(self.eta, var.dtype.base_dtype)
  betting_domain = tf.cast(self.betting_domain, var.dtype.base_dtype)

  wealth = self.get_slot(var, INNER_WEALTH)
  betting_fraction = self.get_slot(var, OUTER_BETTING_FRACTION)
  inner_betting_fraction = self.get_slot(var, INNER_BETTING_FRACTION)
  sum_grad_squared = self.get_slot(var, INNER_SUM_GRAD_SQUARED)
  inner_maximum_gradient = self.get_slot(var, INNER_MAXIMUM_GRADIENT)

  # Running elementwise maximum of |grad|.
  inner_maximum_gradient_updated = self._assign(
      inner_maximum_gradient, tf.maximum(inner_maximum_gradient, tf.abs(grad)))
  update_ops.append(inner_maximum_gradient_updated)

  clipped_old_betting_fraction = tf.clip_by_value(
      betting_fraction, -betting_domain, betting_domain)

  # Process grad to respect truncation to [-betting_domain, betting_domain].
  # Inside the domain the overshoot (betting_fraction - clipped) is zero, so
  # the gradient is always kept. Outside it, the gradient is kept only when
  # its sign agrees with the overshoot; otherwise it is zeroed.
  # zeros_like (rather than tf.zeros(tf.shape(grad)), which is always
  # float32) keeps both tf.where branches in grad's dtype, so non-float32
  # variables work.
  truncated_grad = tf.where(
      tf.greater_equal(
          grad * (betting_fraction - clipped_old_betting_fraction), 0),
      grad, tf.zeros_like(grad))

  wealth_delta = -betting_fraction * truncated_grad
  wealth_updated = self._assign_add(wealth, wealth_delta)
  update_ops.append(wealth_updated)

  # This is the gradient with respect to the betting fraction v
  # used by the ONS algorithm - a kind of "inner inner grad".
  # Heuristic: We also scale v_grad down by the inner maximum gradient so as
  # to make it ``unitless''. This is helpful because the learning rate for
  # ONS is proportional to sum v_grad**2, and so the scale of the learning
  # rate and of v_grad are unlikely to be properly matched without this.
  if self.rescale_inner:
    v_grad = truncated_grad / (
        (1.0 - inner_betting_fraction * truncated_grad) *
        inner_maximum_gradient_updated)
  else:
    v_grad = truncated_grad / (
        (1.0 - inner_betting_fraction * truncated_grad))

  sum_grad_squared_updated = self._assign_add(sum_grad_squared,
                                              tf.square(v_grad))
  update_ops.append(sum_grad_squared_updated)

  # Step on the inner betting fraction, scaled by the accumulated squared
  # inner gradients, then project back onto the betting domain.
  new_inner_betting_fraction = inner_betting_fraction - eta * v_grad / (
      sum_grad_squared_updated)
  new_inner_betting_fraction = tf.clip_by_value(
      new_inner_betting_fraction, -betting_domain, betting_domain)
  inner_betting_fraction_updated = self._assign(
      inner_betting_fraction, new_inner_betting_fraction)
  update_ops.append(inner_betting_fraction_updated)

  if self.output_summaries:
    mean_inner_betting_fraction_summary = tf.reduce_mean(
        tf.abs(inner_betting_fraction_updated))
    max_inner_betting_fraction_summary = tf.reduce_max(
        tf.abs(inner_betting_fraction_updated))
    inner_maximum_gradient_summary = tf.reduce_max(
        inner_maximum_gradient_updated)
    tf.summary.scalar(self._name + "/mean_inner_betting/" + var.name,
                      mean_inner_betting_fraction_summary)
    tf.summary.scalar(self._name + "/max_inner_betting/" + var.name,
                      max_inner_betting_fraction_summary)
    tf.summary.scalar(
        self._name + "/inner_maximum_gradient/" + var.name,
        inner_maximum_gradient_summary)

  # The new outer betting fraction is the inner fraction times the updated
  # wealth.
  betting_fraction_updated = self._assign(
      betting_fraction, inner_betting_fraction_updated * wealth_updated)
  update_ops.append(betting_fraction_updated)
  clipped_betting_fraction = tf.clip_by_value(betting_fraction_updated,
                                              -betting_domain, betting_domain)
  return clipped_betting_fraction, tf.group(*update_ops)
def increase_loss_scale_func():
  """Builds a grouped op that resets the normal-step count and doubles the scale.

  `loss_scale_normal_steps` and `loss_scale` are not defined here; they are
  resolved from the enclosing scope. The reset op is created first, then the
  doubling op.

  Returns:
    A `tf.group` wrapping both assignments.
  """
  reset_normal_steps = tf.assign(loss_scale_normal_steps, 0)
  doubled_loss_scale = tf.assign(loss_scale, loss_scale * 2)
  return tf.group(reset_normal_steps, doubled_loss_scale)
def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
  """The `model_fn` for TPUEstimator.

  Builds a Mesh-TensorFlow BERT pretraining graph with a masked-LM head and a
  next-sentence head, lowers it onto a SIMD TPU mesh, and returns a
  `TPUEstimatorSpec`.

  Args:
    features: dict of input tensors. Keys read here: "input_ids",
      "input_mask", "segment_ids", "masked_lm_positions", "masked_lm_ids",
      "masked_lm_weights" and "next_sentence_labels".
    labels: unused.
    mode: a `tf.estimator.ModeKeys` value. Only TRAIN and EVAL produce a spec.
    params: dict. `params["context"]` must supply `num_hosts`,
      `tpu_host_placement_function`, `num_replicas` and `device_assignment`.

  Returns:
    A `tf.estimator.tpu.TPUEstimatorSpec` for TRAIN or EVAL.
    NOTE(review): for any other mode (e.g. PREDICT) the function falls through
    without a return statement and yields None.

  `bert_config`, `learning_rate`, `num_train_steps`, `num_warmup_steps` and
  `FLAGS` are not defined in this function. They are presumably captured from
  an enclosing builder or module scope; confirm at the call site.
  """
  tf.logging.info("*** Features ***")
  for name in sorted(features.keys()):
    tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape))

  # MTF setup.
  graph = mtf.Graph()
  mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
  layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)

  ctx = params["context"]
  num_hosts = ctx.num_hosts
  host_placement_fn = ctx.tpu_host_placement_function
  device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
  tf.logging.info("device_list = %s" % device_list, )
  replica_cache_size = 300 * 1000000  # 300M per replica
  # Worker 0 caches all the TPU binaries.
  worker0_mem = replica_cache_size * ctx.num_replicas
  # Initial per-host memory usage handed to the variable placer. Host 0 is
  # pre-charged with the binary cache estimate; the other hosts start at 0.
  devices_memeory_usage = [worker0_mem] + [0] * (num_hosts - 1)
  var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                devices_memeory_usage)
  mesh_devices = [""] * mesh_shape.size
  # Map the logical MTF mesh onto the physical TPU topology.
  physical_shape = list(ctx.device_assignment.topology.mesh_shape)
  logical_to_physical = mtf.simd_mesh_impl.auto_logical_to_physical_tpu(
      mesh_shape.to_integer_list, physical_shape)
  mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
      mesh_shape,
      layout_rules,
      mesh_devices,
      ctx.device_assignment,
      logical_to_physical=logical_to_physical)
  mesh = mtf.Mesh(graph, "bert_mesh", var_placer)

  input_ids = features["input_ids"]
  input_mask = features["input_mask"]
  segment_ids = features["segment_ids"]
  masked_lm_positions = features["masked_lm_positions"]
  masked_lm_ids = features["masked_lm_ids"]
  masked_lm_weights = features["masked_lm_weights"]
  # Drop axis 1 (presumably a size-1 axis; tf.squeeze fails otherwise).
  next_sentence_labels = tf.squeeze(features["next_sentence_labels"], 1)

  # Static sizes are required to declare the MTF dimensions.
  batch_size = input_ids.get_shape()[0].value
  batch_dim = mtf.Dimension("batch", batch_size)
  seq_length = input_ids.get_shape()[1].value
  seq_dim = mtf.Dimension("seq", seq_length)
  max_predictions_per_seq = masked_lm_positions.get_shape()[1].value
  max_predictions_per_seq_dim = mtf.Dimension("max_pred_seq",
                                              max_predictions_per_seq)

  # Import the TF inputs into the mesh with named dimensions.
  mtf_input_ids = mtf.import_tf_tensor(mesh, input_ids, [batch_dim, seq_dim])
  mtf_input_mask = mtf.import_tf_tensor(mesh, input_mask, [batch_dim, seq_dim])
  mtf_segment_ids = mtf.import_tf_tensor(mesh, segment_ids,
                                         [batch_dim, seq_dim])
  mtf_masked_lm_positions = mtf.import_tf_tensor(
      mesh, masked_lm_positions, [batch_dim, max_predictions_per_seq_dim])
  mtf_masked_lm_ids = mtf.import_tf_tensor(
      mesh, masked_lm_ids, [batch_dim, max_predictions_per_seq_dim])
  mtf_masked_lm_weights = mtf.import_tf_tensor(
      mesh, masked_lm_weights, [batch_dim, max_predictions_per_seq_dim])
  mtf_next_sentence_labels = mtf.import_tf_tensor(
      mesh, next_sentence_labels, [batch_dim])

  is_training = (mode == tf.estimator.ModeKeys.TRAIN)

  model = bert_lib.BertModel(
      config=bert_config,
      is_training=is_training,
      input_ids=mtf_input_ids,
      input_mask=mtf_input_mask,
      token_type_ids=mtf_segment_ids,
      layout=layout_rules,
      mesh_shape=mesh_shape)

  (masked_lm_loss, masked_lm_example_loss,
   masked_lm_logits) = model.get_masked_lm_output(
       mtf_masked_lm_positions, mtf_masked_lm_ids, mtf_masked_lm_weights)

  (next_sentence_loss, next_sentence_example_loss, next_sentence_logits
  ) = model.get_next_sentence_output(mtf_next_sentence_labels)

  # extra_loss is added only to the loss handed to the optimizer below. The
  # reported tf_loss is built from total_loss alone.
  extra_loss = model.get_extra_loss()

  total_loss = masked_lm_loss + next_sentence_loss
  # Anonymize the tensors that are later exported to TF via the lowering.
  # Presumably export requires fully-replicated tensors; confirm against the
  # mtf version in use.
  total_loss = mtf.anonymize(total_loss)
  masked_lm_example_loss = mtf.anonymize(masked_lm_example_loss)
  masked_lm_logits = mtf.anonymize(masked_lm_logits)
  next_sentence_example_loss = mtf.anonymize(next_sentence_example_loss)
  next_sentence_logits = mtf.anonymize(next_sentence_logits)

  # TRAIN mode
  if mode == tf.estimator.ModeKeys.TRAIN:
    # Optimizer update ops must be created in the MTF graph before lowering.
    _, update_ops = optimization_lib.create_optimizer(
        total_loss + extra_loss,
        learning_rate,
        num_train_steps,
        num_warmup_steps,
        optimizer=FLAGS.optimizer,
        clip_gradients=FLAGS.clip_gradients)

  lowering = mtf.Lowering(graph, {mesh: mesh_impl})

  tf_loss = tf.to_float(lowering.export_to_tf_tensor(total_loss))

  if mode == tf.estimator.ModeKeys.TRAIN:
    global_step = tf.train.get_global_step()
    tf_update_ops = [
        lowering.lowered_operation(op) for op in update_ops
    ]
    # The global step is incremented manually alongside the lowered updates.
    tf_update_ops.append(tf.assign_add(global_step, 1))
    tf.logging.info("tf_update_ops: {}".format(tf_update_ops))
    train_op = tf.group(tf_update_ops)
  elif mode == tf.estimator.ModeKeys.EVAL:

    def metric_fn(masked_lm_example_loss, masked_lm_logits, masked_lm_ids,
                  masked_lm_weights, next_sentence_example_loss,
                  next_sentence_logits, next_sentence_labels):
      """Computes the loss and accuracy of the model."""
      # Flatten to 2-D, keeping the last (class) axis that argmax reduces over.
      masked_lm_logits = tf.reshape(masked_lm_logits,
                                    [-1, masked_lm_logits.shape[-1]])
      masked_lm_predictions = tf.argmax(
          masked_lm_logits, axis=-1, output_type=tf.int32)
      masked_lm_example_loss = tf.reshape(masked_lm_example_loss, [-1])
      masked_lm_ids = tf.reshape(masked_lm_ids, [-1])
      masked_lm_weights = tf.reshape(masked_lm_weights, [-1])
      # Masked-LM metrics are weighted by masked_lm_weights.
      masked_lm_accuracy = tf.metrics.accuracy(
          labels=masked_lm_ids,
          predictions=masked_lm_predictions,
          weights=masked_lm_weights)
      masked_lm_mean_loss = tf.metrics.mean(
          values=masked_lm_example_loss, weights=masked_lm_weights)

      next_sentence_logits = tf.reshape(
          next_sentence_logits, [-1, next_sentence_logits.shape[-1]])
      next_sentence_predictions = tf.argmax(
          next_sentence_logits, axis=-1, output_type=tf.int32)
      next_sentence_labels = tf.reshape(next_sentence_labels, [-1])
      next_sentence_accuracy = tf.metrics.accuracy(
          labels=next_sentence_labels, predictions=next_sentence_predictions)
      next_sentence_mean_loss = tf.metrics.mean(
          values=next_sentence_example_loss)

      return {
          "masked_lm_accuracy": masked_lm_accuracy,
          "masked_lm_loss": masked_lm_mean_loss,
          "next_sentence_accuracy": next_sentence_accuracy,
          "next_sentence_loss": next_sentence_mean_loss,
      }

    # Model outputs are exported from MTF. Ids, weights and sentence labels
    # are the original TF tensors, passed through unchanged.
    eval_metrics = (metric_fn, [
        lowering.export_to_tf_tensor(masked_lm_example_loss),
        lowering.export_to_tf_tensor(masked_lm_logits), masked_lm_ids,
        masked_lm_weights,
        lowering.export_to_tf_tensor(next_sentence_example_loss),
        lowering.export_to_tf_tensor(next_sentence_logits), next_sentence_labels
    ])

  with mtf.utils.outside_all_rewrites():
    # Copy master variables to slices. Must be called first.
    restore_hook = mtf.MtfRestoreHook(lowering)
    if mode == tf.estimator.ModeKeys.TRAIN:
      # Checkpoint every 1000 steps to FLAGS.output_dir, keeping the 10 most
      # recent checkpoints plus one every 2 hours.
      saver = tf.train.Saver(
          tf.global_variables(),
          sharded=True,
          max_to_keep=10,
          keep_checkpoint_every_n_hours=2,
          defer_build=False,
          save_relative_paths=True)
      tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
      saver_listener = mtf.MtfCheckpointSaverListener(lowering)
      saver_hook = tf.train.CheckpointSaverHook(
          FLAGS.output_dir,
          save_steps=1000,
          saver=saver,
          listeners=[saver_listener])

      return tf.estimator.tpu.TPUEstimatorSpec(
          tf.estimator.ModeKeys.TRAIN,
          loss=tf_loss,
          train_op=train_op,
          training_hooks=[restore_hook, saver_hook])
    elif mode == tf.estimator.ModeKeys.EVAL:

      return tf.estimator.tpu.TPUEstimatorSpec(
          tf.estimator.ModeKeys.EVAL,
          evaluation_hooks=[restore_hook],
          loss=tf_loss,
          eval_metrics=eval_metrics)
def evaluate(self, input_fn, checkpoint_path=None):
    """Evaluates the model over `input_fn` from the latest checkpoint.

    Builds the EVAL graph in a fresh `tf.Graph` and restores variables from
    the latest checkpoint under `checkpoint_path`. It then runs `eval_op`
    once per bridge iteration until the monitored session asks to stop.

    Args:
        input_fn: input function passed to
            `_get_features_and_labels_from_input_fn` in EVAL mode.
        checkpoint_path: directory handed to `tf.train.latest_checkpoint`
            and used as the session creator's `checkpoint_dir`.

    Returns:
        The `final_ops_values` of a `tf.train.FinalOpsHook`: the evaluated
        metric values plus the global step under
        `tf.GraphKeys.GLOBAL_STEP`.

    Raises:
        ValueError: if no checkpoint exists under `checkpoint_path`, if a
            metric is named `global_step`, or if restoring the data
            checkpoint fails.
    """
    if not tf.train.latest_checkpoint(checkpoint_path):
        raise ValueError("Could not find trained model at %s" % checkpoint_path)

    with tf.Graph().as_default():
        features, labels = self._get_features_and_labels_from_input_fn(
            input_fn, ModeKeys.EVAL)
        spec, model = self._get_model_spec(features, labels, ModeKeys.EVAL)

        # Track the average loss in default. Copy first so that adding the
        # default loss metric does not mutate the dict owned by `spec`.
        eval_metric_ops = dict(spec.eval_metric_ops or {})
        if model_fn_lib.LOSS_METRIC_KEY not in eval_metric_ops:
            loss_metric = tf.metrics.mean(spec.loss)
            eval_metric_ops[model_fn_lib.LOSS_METRIC_KEY] = loss_metric

        # Create the real eval op. The model's `_train_ops` are run alongside
        # the metric updates (presumably ops the bridge needs every step;
        # confirm with the model implementation).
        update_ops, eval_dict = _extract_metric_update_ops(eval_metric_ops)
        update_ops.extend(model._train_ops)
        eval_op = tf.group(*update_ops)

        # Also track the global step
        if tf.GraphKeys.GLOBAL_STEP in eval_dict:
            raise ValueError(
                'Metric with name `global_step` is not allowed, because '
                'Estimator already defines a default metric with the '
                'same name.')
        eval_dict[tf.GraphKeys.GLOBAL_STEP] = \
            tf.train.get_or_create_global_step()

        # Prepare the session creator.
        scaffold = tf.train.Scaffold()
        session_creator = tf.train.ChiefSessionCreator(
            scaffold=scaffold, checkpoint_dir=checkpoint_path)

        # Prepare hooks. Apply the `or []` fallback before `list()`:
        # `list(None)` raises TypeError, and `list(...)` is never falsy, so
        # the fallback is dead code if placed after it.
        all_hooks = list(spec.evaluation_hooks or [])
        final_ops_hook = tf.train.FinalOpsHook(eval_dict)
        all_hooks.append(final_ops_hook)

        # Evaluate over dataset
        self._bridge.connect()
        try:
            with tf.train.MonitoredSession(
                    session_creator=session_creator,
                    hooks=all_hooks) as sess:
                if not self._restore_datablock(DATA_CHECKPOINT_INIT_VALUE):
                    raise ValueError("Restore data checkpoint error")
                iter_id = 0
                while not sess.should_stop():
                    self._bridge.start(iter_id)
                    logging.debug('after bridge start.')
                    start_time = time.time()
                    sess.run(eval_op)
                    end_time = time.time()
                    metrics.emit_timer(
                        name="iter_timer",
                        value=end_time - start_time,
                        tags={})
                    logging.debug('after session run.')
                    self._bridge.commit()
                    logging.debug('after bridge commit.')
                    iter_id += 1
        finally:
            # Always release the bridge, even if evaluation raised.
            self._bridge.terminate()

        # Print result
        logging.info('Metrics for iteration %d: %s',
                     iter_id, _dict_to_str(final_ops_hook.final_ops_values))
        return final_ops_hook.final_ops_values
def benchmark_model(self,
                    warmup_runs,
                    bm_runs,
                    num_threads,
                    trace_filename=None):
  """Benchmark model.

  Builds the inference graph and runs `warmup_runs` untimed iterations
  followed by `bm_runs` timed iterations on a random input. It prints the
  mean, standard deviation, min and max latency.

  Args:
    warmup_runs: number of untimed warm-up runs.
    bm_runs: number of timed runs. Must be positive.
    num_threads: if > 0, used as `intra_op_parallelism_threads`, with
      `inter_op_parallelism_threads=1`. Otherwise the default ConfigProto is
      used.
    trace_filename: optional path. If set, the run at index `bm_runs // 2`
      is additionally executed with FULL_TRACE, and a Chrome-trace JSON of it
      is written to this path.

  Raises:
    ValueError: if `bm_runs` is not positive. Previously this only surfaced
      after warm-up, as a ValueError from `np.min` on an empty list.
  """
  if bm_runs <= 0:
    raise ValueError('bm_runs must be positive, got {}'.format(bm_runs))

  if self.tensorrt:
    print('Using tensorrt ', self.tensorrt)
    self.build_and_save_model()
    graphdef = self.freeze_model()

  if num_threads > 0:
    print('num_threads for benchmarking: {}'.format(num_threads))
    sess_config = tf.ConfigProto(
        intra_op_parallelism_threads=num_threads,
        inter_op_parallelism_threads=1)
  else:
    sess_config = tf.ConfigProto()

  # rewriter_config_pb2.RewriterConfig.OFF
  sess_config.graph_options.rewrite_options.dependency_optimization = 2
  if self.use_xla:
    sess_config.graph_options.optimizer_options.global_jit_level = (
        tf.OptimizerOptions.ON_2)

  with tf.Graph().as_default(), tf.Session(config=sess_config) as sess:
    inputs = tf.placeholder(tf.float32, name='input', shape=self.inputs_shape)
    output = self.build_model(inputs, is_training=False)

    img = np.random.uniform(size=self.inputs_shape)

    sess.run(tf.global_variables_initializer())
    if self.tensorrt:
      fetches = [inputs.name] + [i.name for i in output]
      goutput = self.convert_tr(graphdef, fetches)
      inputs, output = goutput[0], goutput[1:]

    if not self.use_xla:
      # Don't use tf.group because XLA removes the whole graph for tf.group.
      # (tf.group is therefore applied only when XLA is off.)
      output = tf.group(*output)
    for i in range(warmup_runs):
      start_time = time.time()
      sess.run(output, feed_dict={inputs: img})
      print('Warm up: {} {:.4f}s'.format(i, time.time() - start_time))

    print('Start benchmark runs total={}'.format(bm_runs))
    timev = []
    for i in range(bm_runs):
      if trace_filename and i == (bm_runs // 2):
        # The traced run is extra; it is not included in the timings.
        run_options = tf.RunOptions()
        run_options.trace_level = tf.RunOptions.FULL_TRACE
        run_metadata = tf.RunMetadata()
        sess.run(output,
                 feed_dict={inputs: img},
                 options=run_options,
                 run_metadata=run_metadata)
        logging.info('Dumping trace to %s', trace_filename)
        trace_dir = os.path.dirname(trace_filename)
        # A bare filename has an empty dirname; there is nothing to create.
        if trace_dir and not tf.io.gfile.exists(trace_dir):
          tf.io.gfile.makedirs(trace_dir)
        with tf.io.gfile.GFile(trace_filename, 'w') as trace_file:
          from tensorflow.python.client import timeline  # pylint: disable=g-direct-tensorflow-import,g-import-not-at-top
          trace = timeline.Timeline(step_stats=run_metadata.step_stats)
          trace_file.write(
              trace.generate_chrome_trace_format(show_memory=True))

      start_time = time.time()
      sess.run(output, feed_dict={inputs: img})
      timev.append(time.time() - start_time)

    timev.sort()
    # Drop the two fastest and two slowest runs as outliers, but only when
    # samples remain afterwards. With <= 4 runs, slicing [2:bm_runs - 2]
    # left an empty list and np.min raised.
    if len(timev) > 4:
      timev = timev[2:-2]
    print(
        '{} {}runs {}threads: mean {:.4f} std {:.4f} min {:.4f} max {:.4f}'
        .format(self.model_name, len(timev), num_threads, np.mean(timev),
                np.std(timev), np.min(timev), np.max(timev)))
def train(self,
          num_iterations=100,
          learning_rate=1.0,
          plot_results=True,
          optimizer=tf.train.GradientDescentOptimizer):
  """Trains the model.

  Args:
    num_iterations: number of training iterations. The loop performs
      `num_iterations + 1` steps (indices 0 through `num_iterations`).
      Metrics are recorded every 10th step and on the final step.
    learning_rate: optimizer learning rate.
    plot_results: whether to plot the results at the end of training.
    optimizer: the optimizer class to use. It is called as
      `optimizer(learning_rate)`. Defaults to GradientDescentOptimizer.

  Returns:
    The metric values evaluated at the last iteration: one dict per entry of
    `self._metrics`, or `({},)` when `self._metrics` is empty or None.
  """
  with self._loss.graph.as_default():
    opt = optimizer(learning_rate)
    train_op = opt.minimize(self._loss)
    # Optimizer slot variables and local variables are (re)initialized on
    # every call.
    local_init_op = tf.group(
        tf.variables_initializer(opt.variables()),
        tf.local_variables_initializer())
    if self._session is None:
      # First call only: create the session and initialize global state.
      self._session = tf.Session()
      with self._session.as_default():
        self._session.run(tf.global_variables_initializer())
        self._session.run(tf.tables_initializer())
        tf.train.start_queue_runners()

  with self._session.as_default():
    local_init_op.run()
    iterations = []
    metrics = self._metrics or ({},)
    # Build one accumulator per entry of `metrics` (not `self._metrics`), so
    # the `({},)` fallback also works when `self._metrics` is None.
    metrics_vals = [collections.defaultdict(list) for _ in metrics]

    # Train and append results.
    for i in range(num_iterations + 1):
      _, results = self._session.run((train_op, metrics))
      if (i % 10 == 0) or i == num_iterations:
        print("\r iteration %d: " % i + ", ".join(
            ["%s=%f" % (k, v) for r in results for k, v in r.items()]),
              end='')
        iterations.append(i)
        for metric_val, result in zip(metrics_vals, results):
          for k, v in result.items():
            metric_val[k].append(v)

    for k, v in self._embedding_vars.items():
      self._embeddings[k] = v.eval()

    if plot_results:
      # Plot the metrics.
      num_subplots = len(metrics) + 1
      fig = plt.figure()
      fig.set_size_inches(num_subplots * 10, 8)
      for i, metric_vals in enumerate(metrics_vals):
        ax = fig.add_subplot(1, num_subplots, i + 1)
        for k, v in metric_vals.items():
          ax.plot(iterations, v, label=k)
        ax.set_xlim([1, num_iterations])
        ax.legend()
      plt.show()
    return results