def eq_cifar_fn(x, output_dim=10, trainable=True):
    """Build a three-layer group-convolutional network with a dense head.

    Three `gconv2d` layers are stacked (Z2->C4, C4->C4, C4->C4). The output
    of the last one is averaged over an axis of size 4 that is split out of
    its channel axis, max-pooled, flattened and fed to a dense layer.

    Side effects: creates the variables 'w1', 'w2' and 'w3' in the current
    variable scope and adds tensors to the 'conv_output1' and 'conv_output2'
    graph collections.

    Args:
        x: input tensor handed to the first group convolution, which is
            configured for 3 input channels.
            NOTE(review): presumably NHWC images (strides are given as
            [1, s, s, 1]) -- confirm against callers.
        output_dim: number of units of the final dense layer.
        trainable: forwarded to the final dense layer only; the filters
            created with `tf.get_variable` do not receive it.

    Returns:
        The output tensor of the final (ReLU-activated) dense layer.
    """
    # Layer 1: Z2 -> C4, spatial stride 2.
    gconv_indices, gconv_shape_info, w_shape = gconv2d_util(
        h_input='Z2', h_output='C4', in_channels=3, out_channels=8, ksize=3)
    w = tf.get_variable('w1', shape=w_shape)
    conv1 = gconv2d(input=x, filter=w, strides=[1, 2, 2, 1], padding='SAME',
                    gconv_indices=gconv_indices,
                    gconv_shape_info=gconv_shape_info)
    tf.add_to_collection('conv_output1', conv1)
    # NOTE(review): pool1 is never consumed (conv2 below reads conv1). It is
    # kept so the set of ops added to the graph is unchanged; confirm whether
    # conv2 was meant to read pool1 instead.
    pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2)

    # Layer 2: C4 -> C4, spatial stride 2.
    # NOTE(review): in_channels=16 differs from layer 1's out_channels=8 --
    # confirm the intended widths against the gconv2d_util conventions.
    gconv_indices, gconv_shape_info, w_shape = gconv2d_util(
        h_input='C4', h_output='C4', in_channels=16, out_channels=32, ksize=5)
    w = tf.get_variable('w2', shape=w_shape)
    conv2 = gconv2d(input=conv1, filter=w, strides=[1, 2, 2, 1],
                    padding='SAME', gconv_indices=gconv_indices,
                    gconv_shape_info=gconv_shape_info)
    # NOTE(review): pool2 is likewise unused (conv3 below reads conv2).
    pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2)

    # Layer 3: C4 -> C4, stride 1.
    # `out_channels` used to be read by the reshape below without ever being
    # assigned in this function; bind it here so that the reshape always
    # agrees with the width requested from gconv2d_util.
    # NOTE(review): in_channels=8 differs from layer 2's out_channels=32.
    out_channels = 2
    gconv_indices, gconv_shape_info, w_shape = gconv2d_util(
        h_input='C4', h_output='C4', in_channels=8,
        out_channels=out_channels, ksize=5)
    w = tf.get_variable('w3', shape=w_shape)
    conv3 = gconv2d(input=conv2, filter=w, strides=[1, 1, 1, 1],
                    padding='SAME', gconv_indices=gconv_indices,
                    gconv_shape_info=gconv_shape_info)

    # Split the channel axis into [4, out_channels] and average over the
    # axis of size 4 (presumably the four C4 rotations -- TODO confirm the
    # channel layout produced by gconv2d). The batch axis is given as -1 so
    # that a statically unknown batch size does not put None in the shape.
    spatial_dims = conv3.get_shape().as_list()[1:3]
    conv3 = tf.reshape(conv3, [-1] + spatial_dims + [4, out_channels])
    conv3 = tf.reduce_mean(conv3, axis=3)
    pool3 = tf.layers.max_pooling2d(inputs=conv3, pool_size=[2, 2], strides=2)

    pool3_flat = tf.layers.flatten(pool3)
    u = tf.layers.dense(inputs=pool3_flat, units=output_dim,
                        activation=tf.nn.relu, trainable=trainable)
    tf.add_to_collection('conv_output2', conv2)
    return u
def eq_cnn_fn(x, output_dim=10, trainable=True, group='C4', num_filters=2):
    """Build a two-layer group-convolutional network with a dense head.

    Two `gconv2d` layers are stacked (Z2->C4, then C4->C4). The output of
    the second one is averaged over an axis of size 4 that is split out of
    its channel axis, max-pooled, flattened and fed to a dense layer.

    Side effects: creates the variables 'w1' and 'w2' in the current variable
    scope, adds tensors to the 'conv_output1' and 'conv_output2' graph
    collections and prints the flattened feature shape to stdout.

    Args:
        x: input tensor; its dimension at index 3 is used as the number of
            input channels of the first group convolution.
        output_dim: number of units of the final dense layer.
        trainable: forwarded to the final dense layer only; the filters
            created with `tf.get_variable` do not receive it.
        group: unused; the groups are fixed to 'Z2' and 'C4' below.
        num_filters: unused; both layers request 2 output channels.

    Returns:
        The output tensor of the final (ReLU-activated) dense layer.
    """
    nchannels = x.shape[3]

    # Layer 1: Z2 -> C4.
    gconv_indices, gconv_shape_info, w_shape = gconv2d_util(
        h_input='Z2', h_output='C4', in_channels=nchannels, out_channels=2,
        ksize=5)
    w = tf.get_variable('w1', shape=w_shape)
    conv1 = gconv2d(input=x, filter=w, strides=[1, 1, 1, 1], padding='SAME',
                    gconv_indices=gconv_indices,
                    gconv_shape_info=gconv_shape_info)
    tf.add_to_collection('conv_output1', conv1)
    # NOTE(review): pool1 is never consumed (conv2 below reads conv1). It is
    # kept so the set of ops added to the graph is unchanged; confirm whether
    # conv2 was meant to read pool1 instead.
    pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2)
    # pool1 = layers.Dropout(0.25)(pool1)

    # Layer 2: C4 -> C4.
    out_channels = 2
    gconv_indices, gconv_shape_info, w_shape = gconv2d_util(
        h_input='C4', h_output='C4', in_channels=2,
        out_channels=out_channels, ksize=5)
    w = tf.get_variable('w2', shape=w_shape)
    conv2 = gconv2d(input=conv1, filter=w, strides=[1, 1, 1, 1],
                    padding='SAME', gconv_indices=gconv_indices,
                    gconv_shape_info=gconv_shape_info)

    # Split the channel axis into [4, out_channels] and average over the
    # axis of size 4 (presumably the four C4 rotations -- TODO confirm the
    # channel layout produced by gconv2d). The batch axis is given as -1 so
    # that a statically unknown batch size does not put None in the shape.
    # The mean already yields [batch, height, width, out_channels], so no
    # further reshape is needed.
    spatial_dims = conv2.get_shape().as_list()[1:3]
    conv2 = tf.reshape(conv2, [-1] + spatial_dims + [4, out_channels])
    conv2 = tf.reduce_mean(conv2, axis=3)
    pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2)

    pool2_flat = tf.layers.flatten(pool2)
    u = pool2_flat
    print(u.shape)
    u = tf.layers.dense(inputs=pool2_flat, units=output_dim,
                        activation=tf.nn.relu, trainable=trainable)
    tf.add_to_collection('conv_output2', conv2)
    return u
def cnn_fn(x, output_dim, trainable=True, group=None, mnist=True,
           num_filters=64):
    """Build a small two-convolution CNN with a two-layer dense head.

    Adapted from https://www.tensorflow.org/tutorials/layers

    The network is conv(5x5) -> max-pool -> conv(5x5) -> max-pool ->
    flatten -> dense(60, ReLU) -> dense(output_dim). The convolutions and
    the last dense layer have no activation. The two convolution outputs are
    added to the 'conv_output1' and 'conv_output2' graph collections.

    Args:
        x: input tensor, passed to the first convolution as-is (it is not
            reshaped).
        output_dim: number of units of the final dense layer.
        trainable: forwarded to every convolution and dense layer.
        group: unused.
        mnist: unused. It used to select an input shape that was overwritten
            and never read; the parameter is kept so existing callers keep
            working.
        num_filters: number of filters of both convolutions.

    Returns:
        The output tensor of the final dense layer (no activation).
    """
    conv1 = tf.layers.conv2d(inputs=x, filters=num_filters,
                             kernel_size=[5, 5], padding="same",
                             activation=None, trainable=trainable)
    tf.add_to_collection('conv_output1', conv1)
    pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2)

    conv2 = tf.layers.conv2d(inputs=pool1, filters=num_filters,
                             kernel_size=[5, 5], padding="same",
                             activation=None, trainable=trainable)
    tf.add_to_collection('conv_output2', conv2)
    pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2)

    pool2_flat = tf.layers.flatten(pool2)
    u = tf.layers.dense(inputs=pool2_flat, units=60, activation=tf.nn.relu,
                        trainable=trainable)
    u = tf.layers.dense(inputs=u, units=output_dim, activation=None,
                        trainable=trainable)
    return u
def my_model_fn(features, labels, mode, params=None, config=None):
    """Estimator model function for a Mesh TensorFlow transformer.

    Builds the mesh and its implementation, imports the required features
    as mtf.Tensors and returns the estimator spec for `mode`.

    The function reads several names that are bound outside of it, among
    them `use_tpu`, `mesh_shape`, `layout_rules`, `batch_size`,
    `outer_batch_size`, `sequence_length`, `transformer_model`,
    `model_type`, `autostack`, `metric_names`, `learning_rate`, `model_dir`,
    `save_steps` and `checkpoints_to_keep`.

    Args:
        features: input features dictionary mapping feature names (for
            example "inputs" and "targets") to tensors.
        labels: ignored.
        mode: a tf.estimator.ModeKeys.
        params: a dictionary; when `use_tpu` is true, params["context"] is
            read for host placement and device assignment.
        config: ignored.

    Returns:
        A tpu_estimator.TPUEstimatorSpec in PREDICT and EVAL mode and in
        TRAIN mode on TPU; a tf.estimator.EstimatorSpec in TRAIN mode
        without TPU.

    Raises:
        ValueError: if a required feature is missing, or if
            `transformer_model` is neither a transformer.Unitransformer nor
            a transformer.Bitransformer.
    """
    del labels, config
    global_step = tf.train.get_global_step()
    if use_tpu:
        ctx = params["context"]
        num_hosts = ctx.num_hosts
        host_placement_fn = ctx.tpu_host_placement_function
        device_list = [
            host_placement_fn(host_id=t) for t in range(num_hosts)
        ]
        # TODO(ylc): Better estimation of replica cache size?
        replica_cache_size = 300 * 1000000  # 300M per replica
        # Worker 0 caches all the TPU binaries.
        worker0_mem = replica_cache_size * ctx.num_replicas
        devices_memeory_usage = [worker0_mem] + [0] * (num_hosts - 1)
        var_placer = mtf.utils.BalancedVariablePlacer(
            device_list, devices_memeory_usage)
        mesh_devices = [""] * mesh_shape.size
        physical_shape = list(
            params["context"].device_assignment.topology.mesh_shape)
        logical_to_physical = _logical_to_physical(physical_shape, mesh_shape)
        mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
            mesh_shape, layout_rules, mesh_devices, ctx.device_assignment,
            logical_to_physical=logical_to_physical)
    else:
        var_placer = None
        mesh_devices = [""] * mesh_shape.size
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
            mesh_shape, layout_rules, mesh_devices)

    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh", var_placer)

    def _import_feature(key, allow_missing=False):
        """Import a feature from the features dictionary into a mtf.Tensor.

        Args:
            key: a string
            allow_missing: a boolean

        Returns:
            a mtf.Tensor with dtype int32 and shape
            [outer_batch_dim, batch_dim, length_dim], or None if the feature
            is missing and allow_missing is true.

        Raises:
            ValueError: if the feature is missing and allow_missing is false.
        """
        outer_batch_dim = mtf.Dimension("outer_batch", outer_batch_size)
        batch_dim = mtf.Dimension("batch", batch_size // outer_batch_size)
        length_dim = mtf.Dimension("length", sequence_length)

        mtf_shape = mtf.Shape([outer_batch_dim, batch_dim, length_dim])
        if key not in features:
            if allow_missing:
                return None
            else:
                raise ValueError("feature not found %s - features %s = " %
                                 (key, features))
        tf.logging.info("Import feature %s: %s" % (key, features[key]))

        x = tf.to_int32(features[key])
        x = tf.reshape(
            x, [outer_batch_size, batch_size // outer_batch_size, -1])

        if not use_tpu:
            x = tf.Print(x, [x], "import feature %s" % key, summarize=1000,
                         first_n=1)
        return mtf.import_fully_replicated(mesh, x, mtf_shape, name=key)

    if mode == tf.estimator.ModeKeys.PREDICT:
        inputs = _import_feature("inputs")
        # Merge the two batch dimensions back into a single one.
        inputs = mtf.reshape(
            inputs,
            mtf.Shape([
                mtf.Dimension("batch", batch_size),
                mtf.Dimension("length", sequence_length)
            ]))
        if isinstance(transformer_model, transformer.Unitransformer):
            mtf_samples = transformer_model.sample_autoregressive(
                inputs, variable_dtype=get_variable_dtype())
        elif isinstance(transformer_model, transformer.Bitransformer):
            mtf_samples = transformer_model.decode(
                inputs, variable_dtype=get_variable_dtype())
        else:
            raise ValueError("unrecognized class")
        mtf_samples = mtf.anonymize(mtf_samples)
        lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)
        outputs = lowering.export_to_tf_tensor(mtf_samples)
        predictions = {"outputs": outputs}
        return tpu_estimator.TPUEstimatorSpec(
            mode=tf.estimator.ModeKeys.PREDICT,
            predictions=predictions,
            prediction_hooks=[mtf.MtfRestoreHook(lowering)])

    targets = _import_feature("targets")
    anon_targets = mtf.anonymize(targets)
    if model_type == "lm":
        # Imported features are 3-d ([outer_batch, batch, length]); a 2-way
        # unpack of the shape raises ValueError.
        _, _, length_dim = targets.shape
        inputs = mtf.shift(targets, offset=1, dim=length_dim, wrap=False)
    else:
        inputs = _import_feature("inputs")

    if mode == tf.estimator.ModeKeys.EVAL:
        if isinstance(transformer_model, transformer.Unitransformer):
            mtf_samples = transformer_model.sample_autoregressive(
                inputs, variable_dtype=get_variable_dtype())
        elif isinstance(transformer_model, transformer.Bitransformer):
            mtf_samples = transformer_model.decode(
                inputs, variable_dtype=get_variable_dtype())
        else:
            raise ValueError("unrecognized class")
        mtf_samples = mtf.anonymize(mtf_samples)
        lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)
        outputs = lowering.export_to_tf_tensor(mtf_samples)
        labels = lowering.export_to_tf_tensor(anon_targets)
        restore_hook = mtf.MtfRestoreHook(lowering)

        # Assigning a default to metric_names here would make that name
        # local to this function, so bind the effective list to a new name.
        local_metric_names = metric_names or ["token_accuracy"]

        def metric_fn(labels, outputs):
            return get_metric_fns(local_metric_names, labels, outputs)

        eval_metrics = (metric_fn, [labels, outputs])
        return tpu_estimator.TPUEstimatorSpec(
            tf.estimator.ModeKeys.EVAL,
            # Unfortunately TPUEstimatorSpec requires us to provide a value
            # for loss when in EVAL mode. Since we are sampling or decoding
            # from the model, we don't have a loss to report.
            loss=tf.constant(0.),
            evaluation_hooks=[restore_hook],
            eval_metrics=eval_metrics)

    # Optional packing features; a missing one is passed on as None.
    if isinstance(transformer_model, transformer.Unitransformer):
        position_kwargs = dict(
            sequence_id=_import_feature("targets_segmentation", True),
            position=_import_feature("targets_position", True),
        )
    elif isinstance(transformer_model, transformer.Bitransformer):
        position_kwargs = dict(
            encoder_sequence_id=_import_feature("inputs_segmentation", True),
            decoder_sequence_id=_import_feature("targets_segmentation", True),
            encoder_position=_import_feature("inputs_position", True),
            decoder_position=_import_feature("targets_position", True),
        )
    else:
        raise ValueError("unrecognized class")

    logits, loss = transformer_model.call_simple(
        inputs=inputs,
        targets=targets,
        compute_loss=True,
        mode=mode,
        variable_dtype=get_variable_dtype(),
        **position_kwargs)

    if use_tpu and logits is not None:
        logits = mtf.anonymize(logits)

    # TRAIN mode
    if mode == tf.estimator.ModeKeys.TRAIN:
        var_grads = mtf.gradients(
            [loss], [v.outputs[0] for v in graph.trainable_variables])
        optimizer = mtf.optimize.AdafactorOptimizer(
            learning_rate=learning_rate)
        update_ops = optimizer.apply_grads(var_grads,
                                           graph.trainable_variables)

    lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)

    tf_loss = lowering.export_to_tf_tensor(loss)
    tf_loss = tf.to_float(tf_loss)
    if not use_tpu:
        tf_loss = tf.Print(tf_loss, [tf_loss, tf.train.get_global_step()],
                           "step, tf_loss")

    if mode == tf.estimator.ModeKeys.TRAIN:
        tf_update_ops = [
            lowering.lowered_operation(op) for op in update_ops
        ]
        tf_update_ops.append(tf.assign_add(global_step, 1))
        train_op = tf.group(tf_update_ops)

    with mtf.utils.outside_all_rewrites():
        # Copy master variables to slices. Must be called first.
        restore_hook = mtf.MtfRestoreHook(lowering)
        saver = tf.train.Saver(
            tf.global_variables(),
            sharded=True,
            max_to_keep=checkpoints_to_keep,
            keep_checkpoint_every_n_hours=2,
            defer_build=False,
            save_relative_paths=True)
        tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
        saver_listener = mtf.MtfCheckpointSaverListener(lowering)
        saver_hook = tf.train.CheckpointSaverHook(
            model_dir,
            save_steps=save_steps,
            saver=saver,
            listeners=[saver_listener])
        gin_config_saver_hook = gin.tf.GinConfigSaverHook(
            model_dir, summarize_config=True)

    if mode == tf.estimator.ModeKeys.TRAIN:
        if use_tpu:
            return tpu_estimator.TPUEstimatorSpec(
                mode=tf.estimator.ModeKeys.TRAIN,
                loss=tf_loss,
                train_op=train_op,
                training_hooks=[
                    restore_hook,
                    saver_hook,
                    gin_config_saver_hook,
                ])
        else:
            return tf.estimator.EstimatorSpec(
                tf.estimator.ModeKeys.TRAIN,
                loss=tf_loss,
                train_op=train_op,
                training_chief_hooks=[
                    restore_hook,
                    saver_hook,
                    gin_config_saver_hook,
                ])
def bn(x,
       params=None,
       moments=None,
       backprop_through_moments=True,
       use_ema=False,
       is_training=True,
       ema_epsilon=.9):
    """Apply batch normalization to `x`.

    Intended usage: for support images leave `moments` as None so the
    statistics are computed from `x` itself; for query images pass the
    `moments` returned by the support-set call so that the support-set mean
    and variance are reused.

    Args:
        x: inputs.
        params: None, or a dict holding the offset and scale under the keys
            '<scope>/offset' and '<scope>/scale'. When None, the two are
            created as variables.
        moments: None, or a dict holding the mean and variance under the
            keys '<scope>/mean' and '<scope>/var'.
        backprop_through_moments: whether gradients may flow through the
            moments that were passed in. Only consulted when `moments` is
            given.
        use_ema: whether to keep exponential moving averages of the batch
            statistics; they are updated while training and applied
            otherwise. Passing `moments` takes precedence, in which case
            the averages are neither used nor updated, so episodic learners
            do not update them a second time when processing queries.
        is_training: with use_ema=True, selects between updating the moving
            averages (True) and applying them (False).
        ema_epsilon: weight of the previous moving-average value in each
            update.

    Returns:
        output: the batch-normalized input.
        params: an OrderedDict with the offset and scale that were used.
        moments: an OrderedDict with the mean and variance that were used.
    """
    with tf.variable_scope('batch_norm'):
        scope_name = tf.get_variable_scope().name
        mean_key = scope_name + '/mean'
        var_key = scope_name + '/var'
        offset_key = scope_name + '/offset'
        scale_key = scope_name + '/scale'

        if use_ema:
            num_channels = x.get_shape().as_list()[-1]
            ema_shape = [1, 1, 1, num_channels]
            mean_ema = tf.get_variable(
                'mean_ema',
                shape=ema_shape,
                initializer=tf.initializers.zeros(),
                trainable=False)
            var_ema = tf.get_variable(
                'var_ema',
                shape=ema_shape,
                initializer=tf.initializers.ones(),
                trainable=False)

        if moments is not None:
            mean = moments[mean_key]
            var = moments[var_key]
            if not backprop_through_moments:
                # This variant does not yield good results.
                mean = tf.stop_gradient(mean)
                var = tf.stop_gradient(var)
        elif use_ema and not is_training:
            mean, var = mean_ema, var_ema
        else:
            # Nothing was provided: use the statistics of the current batch.
            replica_ctx = tf.distribute.get_replica_context()
            axes = list(range(len(x.shape) - 1))
            if replica_ctx:
                # from third_party/tensorflow/python/keras/layers/normalization_v2.py
                local_sum = tf.reduce_sum(x, axis=axes, keepdims=True)
                local_squared_sum = tf.reduce_sum(
                    tf.square(x), axis=axes, keepdims=True)
                batch_size = tf.cast(tf.shape(x)[0], tf.float32)
                reduced = replica_ctx.all_reduce(
                    'sum', [local_sum, local_squared_sum, batch_size])
                x_sum, x_squared_sum, global_batch_size = reduced

                # Number of elements per channel: the product of the reduced
                # non-batch dimensions times the global batch size.
                axes_vals = [(tf.shape(x))[i] for i in range(1, len(axes))]
                multiplier = tf.cast(tf.reduce_prod(axes_vals), tf.float32)
                multiplier = multiplier * global_batch_size

                mean = x_sum / multiplier
                x_squared_mean = x_squared_sum / multiplier
                # var = E(x^2) - E(x)^2
                var = x_squared_mean - tf.square(mean)
            else:
                mean, var = tf.nn.moments(x, axes=axes, keep_dims=True)

        # The moving averages are refreshed only while training and only if
        # the moments were computed in this very call. Note: at test time
        # for episodic learners, ema's may be passed from the support set to
        # the query set, even if it's not really needed.
        if use_ema and is_training and moments is None:
            replica_ctx = tf.distribute.get_replica_context()

            def _blend(average, value):
                return average * ema_epsilon + value * (1.0 - ema_epsilon)

            mean_upd = tf.assign(mean_ema, _blend(mean_ema, mean))
            var_upd = tf.assign(var_ema, _blend(var_ema, var))
            updates = tf.group([mean_upd, var_upd])
            if replica_ctx:
                # Run the update on the replica with id 0 only.
                is_first_replica = tf.equal(
                    replica_ctx.replica_id_in_sync_group, 0)
                tf.add_to_collection(
                    tf.GraphKeys.UPDATE_OPS,
                    tf.cond(is_first_replica, lambda: updates, tf.no_op))
            else:
                tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, updates)

        if params is None:
            offset = tf.get_variable(
                'offset',
                shape=mean.get_shape().as_list(),
                initializer=tf.initializers.zeros())
            scale = tf.get_variable(
                'scale',
                shape=var.get_shape().as_list(),
                initializer=tf.initializers.ones())
        else:
            offset = params[offset_key]
            scale = params[scale_key]

        output = tf.nn.batch_normalization(x, mean, var, offset, scale,
                                           0.00001)

        used_params = collections.OrderedDict(
            [(offset_key, offset), (scale_key, scale)])
        used_moments = collections.OrderedDict(
            [(mean_key, mean), (var_key, var)])
        return output, used_params, used_moments
def cifar_fn(x, output_dim=10, trainable=True, group=None, mnist=True):
    """Build a six-convolution CNN (three pooled stages) with a dense head.

    Adapted from https://www.tensorflow.org/tutorials/layers

    Every convolution uses 'same' padding and a ReLU activation. One
    convolution output per stage is added to the graph collections
    'conv_output1', 'conv_output2' and 'conv_output3'.

    Args:
        x: input tensor; it is reshaped to [-1, 28, 28, 1] when `mnist` is
            true and to [-1, 32, 32, 3] otherwise.
        output_dim: number of units of the final dense layer.
        trainable: forwarded to every convolution and to the dense layer.
        group: unused.
        mnist: selects the shape the input is reshaped to.

    Returns:
        The output tensor of the final (ReLU-activated) dense layer.
    """

    def relu_conv(inputs, filters, size):
        """Square 'same'-padded convolution followed by ReLU."""
        return tf.layers.conv2d(inputs=inputs, filters=filters,
                                kernel_size=[size, size], padding="same",
                                activation=tf.nn.relu, trainable=trainable)

    def halve(inputs):
        """2x2 max pooling with stride 2."""
        return tf.layers.max_pooling2d(inputs=inputs, pool_size=[2, 2],
                                       strides=2)

    image_dims = [28, 28, 1] if mnist else [32, 32, 3]
    images = tf.reshape(x, [-1] + image_dims)

    # Stage 1: two 3x3 convolutions with 16 filters, pooling, dropout.
    stage1 = relu_conv(images, 16, 3)
    tf.add_to_collection('conv_output1', stage1)
    stage1 = relu_conv(stage1, 16, 3)
    pooled1 = halve(stage1)
    pooled1 = layers.Dropout(0.25)(pooled1)

    # Stage 2: two 3x3 convolutions with 32 filters, pooling.
    stage2 = relu_conv(pooled1, 32, 3)
    tf.add_to_collection('conv_output2', stage2)
    stage2 = relu_conv(stage2, 32, 3)
    pooled2 = halve(stage2)

    # Stage 3: a 5x5 and a 7x7 convolution with 64 filters, pooling.
    stage3 = relu_conv(pooled2, 64, 5)
    stage3 = relu_conv(stage3, 64, 7)
    tf.add_to_collection('conv_output3', stage3)
    pooled3 = halve(stage3)

    flat = tf.layers.flatten(pooled3)
    return tf.layers.dense(inputs=flat, units=output_dim,
                           activation=tf.nn.relu, trainable=trainable)
def my_model_fn(features, labels, mode, params=None, config=None):
    """Estimator model function for a Mesh TensorFlow transformer.

    Builds the mesh and its implementation, imports every entry of
    `features` as a fully replicated mtf.Tensor of shape
    [outer_batch, batch, length], and returns the estimator spec for
    PREDICT or TRAIN mode. EVAL mode is not supported.

    The function reads several names that are bound outside of it, among
    them `use_tpu`, `mesh_shape`, `layout_rules`, `batch_size`,
    `outer_batch_size`, `sequence_length`, `transformer_model`,
    `model_type`, `autostack`, `optimizer`, `learning_rate_schedule`,
    `tpu_summaries`, `model_dir`, `save_checkpoints_steps` and
    `keep_checkpoint_max`.

    Args:
        features: input features dictionary
        labels: ignored
        mode: a tf.estimator.ModeKeys
        params: a dictionary; when `use_tpu` is true, params["context"] is
            read for host placement and device assignment
        config: ignored

    Returns:
        A tpu_estimator.TPUEstimatorSpec in PREDICT mode and in TRAIN mode
        on TPU; a tf.estimator.EstimatorSpec in TRAIN mode without TPU.

    Raises:
        NotImplementedError: if mode is EVAL.
        ValueError: if `transformer_model` is not of a recognized class.
    """
    del labels, config
    global_step = tf.train.get_global_step()
    if use_tpu:
        ctx = params["context"]
        num_hosts = ctx.num_hosts
        host_placement_fn = ctx.tpu_host_placement_function
        device_list = [
            host_placement_fn(host_id=t) for t in range(num_hosts)
        ]
        # TODO(ylc): Better estimation of replica cache size?
        replica_cache_size = 300 * 1000000  # 300M per replica
        # Worker 0 caches all the TPU binaries.
        worker0_mem = replica_cache_size * ctx.num_replicas
        devices_memeory_usage = [worker0_mem] + [0] * (num_hosts - 1)
        var_placer = mtf.utils.BalancedVariablePlacer(
            device_list, devices_memeory_usage)
        mesh_devices = [""] * mesh_shape.size
        physical_shape = list(
            params["context"].device_assignment.topology.mesh_shape)
        logical_to_physical = _logical_to_physical(physical_shape, mesh_shape)
        mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
            mesh_shape, layout_rules, mesh_devices, ctx.device_assignment,
            logical_to_physical=logical_to_physical)
    else:
        var_placer = None
        mesh_devices = [""] * mesh_shape.size
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
            mesh_shape, layout_rules, mesh_devices)

    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh", var_placer)

    # All features share one shape: the batch is split into an outer and an
    # inner batch dimension, followed by the sequence length.
    outer_batch_dim = mtf.Dimension("outer_batch", outer_batch_size)
    batch_dim = mtf.Dimension("batch", batch_size // outer_batch_size)
    length_dim = mtf.Dimension("length", sequence_length)
    feature_shape = mtf.Shape([outer_batch_dim, batch_dim, length_dim])

    # Import every feature as an int32 mtf.Tensor, keyed by feature name.
    mtf_features = {}
    for key, x in features.items():
        x = tf.to_int32(features[key])
        x = tf.reshape(x, [
            outer_batch_size, batch_size // outer_batch_size, sequence_length
        ])
        if not use_tpu:
            x = tf.Print(x, [x], "import feature %s" % key, summarize=1000,
                         first_n=1)
        mtf_features[key] = mtf.import_fully_replicated(mesh, x,
                                                        feature_shape,
                                                        name=key)

    if mode == tf.estimator.ModeKeys.PREDICT:
        inputs = mtf_features["inputs"]
        # Merge the two batch dimensions back into a single one.
        inputs = mtf.reshape(
            inputs,
            mtf.Shape([
                mtf.Dimension("batch", batch_size),
                mtf.Dimension("length", sequence_length)
            ]))
        if isinstance(transformer_model, transformer.Unitransformer):
            mtf_samples = transformer_model.sample_autoregressive(
                inputs, variable_dtype=get_variable_dtype())
        elif isinstance(
                transformer_model,
                (transformer.Bitransformer, transformer.StudentTeacher)):
            mtf_samples = transformer_model.decode(
                inputs, variable_dtype=get_variable_dtype())
        else:
            raise ValueError("unrecognized class")
        mtf_samples = mtf.anonymize(mtf_samples)
        lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)
        outputs = lowering.export_to_tf_tensor(mtf_samples)
        predictions = {"outputs": outputs}
        return tpu_estimator.TPUEstimatorSpec(
            mode=tf.estimator.ModeKeys.PREDICT,
            predictions=predictions,
            prediction_hooks=[mtf.MtfRestoreHook(lowering)])

    elif mode == tf.estimator.ModeKeys.EVAL:
        raise NotImplementedError("We don't expect to use mode == eval.")

    else:
        assert mode == tf.estimator.ModeKeys.TRAIN
        num_microbatches = serialize_num_microbatches(
            batch_dim, length_dim, mesh_shape, layout_rules)

        def model_fn(mtf_features):
            """The kind of function we need for mtf.serialize_training_step.

            Runs the transformer on the given features and returns its loss,
            divided by the number of microbatches when there is more than
            one.

            Args:
                mtf_features: a dictionary mapping feature names to
                    mtf.Tensors; "targets" is required, and "inputs" is
                    required unless model_type is "lm"

            Returns:
                a dictionary with the single key "loss"

            Raises:
                ValueError: if the model class is not recognized.
            """
            targets = mtf_features["targets"]
            if model_type == "lm":
                # Language model: the inputs are the targets shifted by one
                # position along the length dimension.
                _, _, length_dim = targets.shape
                inputs = mtf.shift(targets, offset=1, dim=length_dim,
                                   wrap=False)
            else:
                inputs = mtf_features["inputs"]

            # Optional packing features; a missing one is passed on as None.
            if isinstance(transformer_model, transformer.Unitransformer):
                position_kwargs = dict(
                    sequence_id=mtf_features.get("targets_segmentation",
                                                 None),
                    position=mtf_features.get("targets_position", None),
                )
            elif isinstance(transformer_model, transformer.Bitransformer
                           ) or model_type == "bi_student_teacher":
                position_kwargs = dict(
                    encoder_sequence_id=mtf_features.get(
                        "inputs_segmentation", None),
                    decoder_sequence_id=mtf_features.get(
                        "targets_segmentation", None),
                    encoder_position=mtf_features.get(
                        "inputs_position", None),
                    decoder_position=mtf_features.get(
                        "targets_position", None),
                )
            else:
                raise ValueError("unrecognized class")

            logits, loss = transformer_model.call_simple(
                inputs=inputs,
                targets=targets,
                compute_loss=True,
                mode=mode,
                variable_dtype=get_variable_dtype(),
                **position_kwargs)
            if num_microbatches > 1:
                loss /= float(num_microbatches)
            del logits
            return {"loss": loss}

        if num_microbatches > 1:
            var_grads, loss_dict = mtf.serialize_training_step(
                mtf_features, model_fn, batch_dim, num_microbatches)
        else:
            loss_dict = model_fn(mtf_features)
            var_grads = mtf.gradients(
                [loss_dict["loss"]],
                [v.outputs[0] for v in graph.trainable_variables])

        loss = loss_dict["loss"]

        # learning_rate_schedule is either a callable taking the global step
        # or a value that is used as the learning rate directly.
        if callable(learning_rate_schedule):
            # the following happens on CPU since TPU can't handle summaries.
            with mtf.utils.outside_all_rewrites():
                learning_rate = learning_rate_schedule(
                    step=tf.train.get_global_step())
                tf.summary.scalar("learning_rate", learning_rate)
        else:
            learning_rate = learning_rate_schedule

        update_ops = optimizer(learning_rate=learning_rate).apply_grads(
            var_grads, graph.trainable_variables)

        lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)

        tf_loss = lowering.export_to_tf_tensor(loss)
        tf_loss = tf.to_float(tf_loss)
        if not use_tpu:
            tf_loss = tf.Print(
                tf_loss, [tf_loss, tf.train.get_global_step()],
                "step, tf_loss")

        # The train op runs all lowered update ops and increments the global
        # step.
        tf_update_ops = [
            lowering.lowered_operation(op) for op in update_ops
        ]
        tf_update_ops.append(tf.assign_add(global_step, 1))
        train_op = tf.group(tf_update_ops)

        # Models that define an `initialize` method get it called here.
        if hasattr(transformer_model, "initialize"):
            with mtf.utils.outside_all_rewrites():
                transformer_model.initialize()

        with mtf.utils.outside_all_rewrites():
            # Copy master variables to slices. Must be called first.
            restore_hook = mtf.MtfRestoreHook(lowering)
            saver = tf.train.Saver(
                tf.global_variables(),
                sharded=True,
                max_to_keep=keep_checkpoint_max,
                keep_checkpoint_every_n_hours=2,
                defer_build=False,
                save_relative_paths=True)
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            saver_listener = mtf.MtfCheckpointSaverListener(lowering)
            saver_hook = tf.train.CheckpointSaverHook(
                model_dir,
                save_steps=save_checkpoints_steps,
                saver=saver,
                listeners=[saver_listener])
            gin_config_saver_hook = gin.tf.GinConfigSaverHook(
                model_dir, summarize_config=True)

        if use_tpu:
            if tpu_summaries:
                # Record the loss summary, build the host call, then remove
                # the summaries again.
                tf.summary.scalar("loss", tf_loss)
                host_call = mtf.utils.create_host_call(model_dir)
                mtf.utils.remove_summaries()
            else:
                host_call = None
            return tpu_estimator.TPUEstimatorSpec(
                mode=tf.estimator.ModeKeys.TRAIN,
                loss=tf_loss,
                train_op=train_op,
                host_call=host_call,
                training_hooks=[
                    restore_hook,
                    saver_hook,
                    gin_config_saver_hook,
                ])
        else:
            return tf.estimator.EstimatorSpec(
                tf.estimator.ModeKeys.TRAIN,
                loss=tf_loss,
                train_op=train_op,
                training_chief_hooks=[
                    restore_hook,
                    saver_hook,
                    gin_config_saver_hook,
                ])