def safety_critic_loss(tf_agent,
                       safety_critic,
                       time_steps,
                       actions,
                       next_time_steps,
                       safety_rewards,
                       weights=None):
  """Returns a critic loss with safety."""
  next_actions, next_log_pis = tf_agent._actions_and_log_probs(  # pylint: disable=protected-access
      next_time_steps)
  del next_log_pis
  target_input = (next_time_steps.observation[0], next_actions[0])
  target_q_values, unused_network_state1 = safety_critic(
      target_input, next_time_steps.step_type[0])
  target_q_values = tf.nn.sigmoid(target_q_values)
  safety_rewards = tf.to_float(safety_rewards)
  td_targets = tf.stop_gradient(safety_rewards + (1 - safety_rewards) *
                                next_time_steps.discount * target_q_values)
  td_targets = tf.squeeze(td_targets)
  pred_input = (time_steps.observation[0], actions[0])
  pred_td_targets, unused_network_state1 = safety_critic(
      pred_input, time_steps.step_type[0])
  loss = tf.losses.sigmoid_cross_entropy(td_targets, pred_td_targets)
  if weights is not None:
    loss *= tf.to_float(tf.squeeze(weights))
  # Take the mean across the batch.
  loss = tf.reduce_mean(input_tensor=loss)
  return loss
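
# A minimal sketch (not part of the original module) of the TD target used in
# safety_critic_loss above: safety_rewards are binary failure indicators, so
# the target is exactly 1 at failure steps and otherwise bootstraps the
# discounted, sigmoid-squashed next-state Q-value. All values below are
# illustrative.
def _example_safety_td_target():
  import tensorflow.compat.v1 as tf  # assumed TF1.x-style import, as in this module
  safety_rewards = tf.constant([1.0, 0.0])            # [fail, safe]
  discount = tf.constant([0.99, 0.99])
  target_q = tf.nn.sigmoid(tf.constant([2.0, -1.0]))  # squashed critic output
  td_targets = tf.stop_gradient(
      safety_rewards + (1 - safety_rewards) * discount * target_q)
  # First entry is exactly 1.0 (failure); second is 0.99 * sigmoid(-1.0).
  return td_targets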
def ctrl_rewards(states,
                 actions,
                 rewards,
                 next_states,
                 contexts,
                 reward_scales=1.0):
  """Returns the negative control cost.

  Args:
    states: A [batch_size, num_state_dims] Tensor representing a batch
      of states.
    actions: A [batch_size, num_action_dims] Tensor representing a batch
      of actions.
    rewards: A [batch_size] Tensor representing a batch of rewards.
    next_states: A [batch_size, num_state_dims] Tensor representing a batch
      of next states.
    contexts: A list of [batch_size, num_context_dims] Tensor representing
      a batch of contexts.
    reward_scales: multiplicative scale for rewards. A scalar or 1D tensor,
      must be broadcastable to number of reward dimensions.

  Returns:
    A new tf.float32 [batch_size] rewards Tensor, and
      tf.float32 [batch_size] discounts tensor.
  """
  del states, rewards, contexts  # Unused
  if actions is None:
    rewards = tf.to_float(tf.zeros(shape=next_states.shape[:1]))
  else:
    rewards = -tf.reduce_sum(tf.square(actions), axis=1)
    rewards *= reward_scales
  rewards = tf.to_float(rewards)
  return rewards, tf.ones_like(rewards)
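
# Hypothetical usage sketch for ctrl_rewards: only `actions` matters; the
# reward is the negative squared action magnitude per batch element.
def _example_ctrl_rewards():
  import tensorflow.compat.v1 as tf  # assumed TF1.x-style import
  actions = tf.constant([[1.0, 2.0], [0.0, 0.0]])
  next_states = tf.zeros([2, 3])
  rewards, discounts = ctrl_rewards(
      states=None, actions=actions, rewards=None,
      next_states=next_states, contexts=None)
  # rewards == [-5.0, 0.0]; discounts == [1.0, 1.0]
  return rewards, discounts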
def potential_rewards(states,
                      actions,
                      rewards,
                      next_states,
                      contexts,
                      gamma=1.0,
                      reward_fn=None):
  """Return the potential-based rewards.

  Args:
    states: A [batch_size, num_state_dims] Tensor representing a batch
      of states.
    actions: A [batch_size, num_action_dims] Tensor representing a batch
      of actions.
    rewards: A [batch_size] Tensor representing a batch of rewards.
    next_states: A [batch_size, num_state_dims] Tensor representing a batch
      of next states.
    contexts: A list of [batch_size, num_context_dims] Tensor representing
      a batch of contexts.
    gamma: Reward discount.
    reward_fn: A reward function.

  Returns:
    A new tf.float32 [batch_size] rewards Tensor, and
      tf.float32 [batch_size] discounts tensor.
  """
  del actions  # unused args
  gamma = tf.to_float(gamma)
  rewards_tp1, discounts = reward_fn(None, None, rewards, next_states,
                                     contexts)
  rewards, _ = reward_fn(None, None, rewards, states, contexts)
  return -rewards + gamma * rewards_tp1, discounts
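
# Illustrative check (assuming reward_fn follows this module's
# (states, actions, rewards, next_states, contexts) -> (rewards, discounts)
# convention) that potential_rewards computes the shaping term
# gamma * phi(s') - phi(s), where phi is whatever reward_fn assigns to a state.
def _example_potential_rewards():
  import tensorflow.compat.v1 as tf  # assumed TF1.x-style import

  def neg_goal_dist(states, actions, rewards, next_states, contexts):
    del states, actions, rewards  # Unused, mirrors the module's convention.
    d = tf.norm(next_states - contexts[0], axis=-1)
    return -d, tf.ones_like(d)

  states = tf.constant([[0.0, 0.0]])
  next_states = tf.constant([[1.0, 0.0]])
  contexts = [tf.constant([[2.0, 0.0]])]
  shaped, discounts = potential_rewards(
      states, None, None, next_states, contexts,
      gamma=0.9, reward_fn=neg_goal_dist)
  # phi(s) = -2, phi(s') = -1, so shaped == 0.9 * (-1) - (-2) = 1.1
  return shaped, discounts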
def state_rewards(states,
                  actions,
                  rewards,
                  next_states,
                  contexts,
                  weight_index=None,
                  state_indices=None,
                  weight_vector=1.0,
                  offset_vector=0.0,
                  summarize=False):
  """Returns the rewards that are a linear mapping of next_states.

  Args:
    states: A [batch_size, num_state_dims] Tensor representing a batch
      of states.
    actions: A [batch_size, num_action_dims] Tensor representing a batch
      of actions.
    rewards: A [batch_size] Tensor representing a batch of rewards.
    next_states: A [batch_size, num_state_dims] Tensor representing a batch
      of next states.
    contexts: A list of [batch_size, num_context_dims] Tensor representing
      a batch of contexts.
    weight_index: (integer) Index of contexts lists that specify weighting.
    state_indices: (a list of Numpy integer array) Indices of states dimensions
      to be mapped.
    weight_vector: (a number or a list or Numpy array) The weighting vector,
      broadcastable to `next_states`.
    offset_vector: (a number or a list or Numpy array) The offset vector.
    summarize: (boolean) enable summary ops.

  Returns:
    A new tf.float32 [batch_size] rewards Tensor, and
      tf.float32 [batch_size] discounts tensor.
  """
  del states, actions, rewards  # unused args
  stats = {}
  record_tensor(next_states, state_indices, stats)
  next_states = index_states(next_states, state_indices)
  weight = tf.constant(
      weight_vector, dtype=next_states.dtype, shape=next_states[0].shape)
  weights = tf.expand_dims(weight, 0)
  offset = tf.constant(
      offset_vector, dtype=next_states.dtype, shape=next_states[0].shape)
  offsets = tf.expand_dims(offset, 0)
  if weight_index is not None:
    weights *= contexts[weight_index]
  rewards = tf.to_float(
      tf.reduce_sum(weights * (next_states + offsets), axis=1))
  if summarize:
    with tf.name_scope('RewardFn/'):
      summarize_stats(stats)
  return rewards, tf.ones_like(rewards)
def binary_indicator(states,
                     actions,
                     rewards,
                     next_states,
                     contexts,
                     termination_epsilon=1e-4,
                     offset=0,
                     epsilon=1e-10,
                     state_indices=None,
                     summarize=False):
  """Returns 0/1 by checking if next_states and contexts overlap.

  Args:
    states: A [batch_size, num_state_dims] Tensor representing a batch
      of states.
    actions: A [batch_size, num_action_dims] Tensor representing a batch
      of actions.
    rewards: A [batch_size] Tensor representing a batch of rewards.
    next_states: A [batch_size, num_state_dims] Tensor representing a batch
      of next states.
    contexts: A list of [batch_size, num_context_dims] Tensor representing
      a batch of contexts.
    termination_epsilon: terminate if dist is less than this quantity.
    offset: Offset the rewards.
    epsilon: small offset to ensure non-negative/zero distance.
    state_indices: (a list of integers) list of state indices to select.
    summarize: (boolean) enable summary ops.

  Returns:
    A new tf.float32 [batch_size] rewards Tensor, and
      tf.float32 [batch_size] discounts tensor.
  """
  del states, actions  # unused args
  next_states = index_states(next_states, state_indices)
  dist = tf.reduce_sum(tf.squared_difference(next_states, contexts[0]), -1)
  dist = tf.sqrt(dist + epsilon)
  discounts = dist > termination_epsilon
  rewards = tf.logical_not(discounts)
  rewards = tf.to_float(rewards) + offset
  return tf.to_float(rewards), tf.ones_like(tf.to_float(discounts))  # tf.to_float(discounts)
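
# Hypothetical usage of binary_indicator, assuming (as elsewhere in this
# module) that index_states(x, None) returns x unchanged. The first sample
# sits on the goal, so it earns reward 1; the second does not.
def _example_binary_indicator():
  import tensorflow.compat.v1 as tf  # assumed TF1.x-style import
  next_states = tf.constant([[1.0, 1.0], [3.0, 0.0]])
  contexts = [tf.constant([[1.0, 1.0], [0.0, 0.0]])]
  rewards, discounts = binary_indicator(
      None, None, None, next_states, contexts)
  # rewards ~= [1.0, 0.0] (up to the 1e-10 epsilon inside the distance).
  return rewards, discounts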
def timed_rewards(states,
                  actions,
                  rewards,
                  next_states,
                  contexts,
                  reward_fn=None,
                  dense=False,
                  timer_index=-1):
  """Return the timed rewards.

  Args:
    states: A [batch_size, num_state_dims] Tensor representing a batch
      of states.
    actions: A [batch_size, num_action_dims] Tensor representing a batch
      of actions.
    rewards: A [batch_size] Tensor representing a batch of rewards.
    next_states: A [batch_size, num_state_dims] Tensor representing a batch
      of next states.
    contexts: A list of [batch_size, num_context_dims] Tensor representing
      a batch of contexts.
    reward_fn: A reward function.
    dense: (boolean) Provide dense rewards or sparse rewards at time = 0.
    timer_index: (integer) The context list index that specifies timer.

  Returns:
    A new tf.float32 [batch_size] rewards Tensor, and
      tf.float32 [batch_size] discounts tensor.
  """
  assert contexts[timer_index].get_shape().as_list()[1] == 1
  timers = contexts[timer_index][:, 0]
  rewards, discounts = reward_fn(states, actions, rewards, next_states,
                                 contexts)
  terminates = tf.to_float(timers <= 0)  # if terminate set 1, else set 0
  for _ in range(rewards.shape.ndims - 1):
    terminates = tf.expand_dims(terminates, axis=-1)
  if not dense:
    rewards *= terminates  # if terminate, return rewards, else return 0
  discounts *= (tf.to_float(1.0) - terminates)
  return rewards, discounts
def diff_rewards(states,
                 actions,
                 rewards,
                 next_states,
                 contexts,
                 state_indices=None,
                 goal_index=0):
  """Returns (next_states - goals) as a batched vector reward."""
  del states, rewards, actions  # Unused
  if state_indices is not None:
    next_states = index_states(next_states, state_indices)
  rewards = tf.to_float(next_states - contexts[goal_index])
  return rewards, tf.ones_like(rewards)
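
# Hypothetical usage of diff_rewards: with state_indices=None the function
# simply returns the per-dimension gap between next_states and the goal
# context as a vector-valued reward.
def _example_diff_rewards():
  import tensorflow.compat.v1 as tf  # assumed TF1.x-style import
  next_states = tf.constant([[1.0, 2.0]])
  contexts = [tf.constant([[0.5, 0.5]])]
  rewards, discounts = diff_rewards(None, None, None, next_states, contexts)
  # rewards == [[0.5, 1.5]]
  return rewards, discounts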
def tanh_similarity(states,
                    actions,
                    rewards,
                    next_states,
                    contexts,
                    mse_scale=1.0,
                    state_scales=1.0,
                    goal_scales=1.0,
                    summarize=False):
  """Returns the similarity between next_states and contexts using tanh and mse.

  Args:
    states: A [batch_size, num_state_dims] Tensor representing a batch
      of states.
    actions: A [batch_size, num_action_dims] Tensor representing a batch
      of actions.
    rewards: A [batch_size] Tensor representing a batch of rewards.
    next_states: A [batch_size, num_state_dims] Tensor representing a batch
      of next states.
    contexts: A list of [batch_size, num_context_dims] Tensor representing
      a batch of contexts.
    mse_scale: A float, to scale mse before tanh.
    state_scales: multiplicative scale for (next) states. A scalar or 1D
      tensor, must be broadcastable to number of state dimensions.
    goal_scales: multiplicative scale for contexts. A scalar or 1D tensor,
      must be broadcastable to number of goal dimensions.
    summarize: (boolean) enable summary ops.

  Returns:
    A new tf.float32 [batch_size] rewards Tensor, and
      tf.float32 [batch_size] discounts tensor.
  """
  del states, actions, rewards  # Unused
  mse = tf.reduce_mean(
      tf.squared_difference(next_states * state_scales,
                            contexts[0] * goal_scales), -1)
  tanh = tf.tanh(mse_scale * mse)
  if summarize:
    with tf.name_scope('RewardFn/'):
      tf.summary.scalar('mean_mse', tf.reduce_mean(mse))
      tf.summary.histogram('mse', mse)
      tf.summary.scalar('mean_tanh', tf.reduce_mean(tanh))
      tf.summary.histogram('tanh', tanh)
  rewards = tf.to_float(1 - tanh)
  return rewards, tf.ones_like(rewards)
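
# Hypothetical usage of tanh_similarity: identical states and goals give a
# reward of 1 - tanh(0) = 1; large mismatches saturate toward 0.
def _example_tanh_similarity():
  import tensorflow.compat.v1 as tf  # assumed TF1.x-style import
  next_states = tf.constant([[1.0, 1.0], [10.0, -10.0]])
  contexts = [tf.constant([[1.0, 1.0], [0.0, 0.0]])]
  rewards, discounts = tanh_similarity(None, None, None, next_states, contexts)
  # rewards ~= [1.0, 1 - tanh(100)] ~= [1.0, 0.0]
  return rewards, discounts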
def diff_distance(states,
                  actions,
                  rewards,
                  next_states,
                  contexts,
                  state_scales=1.0,
                  goal_scales=1.0,
                  reward_scales=1.0,
                  weight_index=None,
                  weight_vector=None,
                  summarize=False,
                  termination_epsilon=1e-4,
                  state_indices=None,
                  goal_indices=None,
                  norm='L2',
                  epsilon=1e-10):
  """Returns the difference in Euclidean distance between states/next_states and contexts.

  Args:
    states: A [batch_size, num_state_dims] Tensor representing a batch
      of states.
    actions: A [batch_size, num_action_dims] Tensor representing a batch
      of actions.
    rewards: A [batch_size] Tensor representing a batch of rewards.
    next_states: A [batch_size, num_state_dims] Tensor representing a batch
      of next states.
    contexts: A list of [batch_size, num_context_dims] Tensor representing
      a batch of contexts.
    state_scales: multiplicative scale for (next) states. A scalar or 1D
      tensor, must be broadcastable to number of state dimensions.
    goal_scales: multiplicative scale for goals. A scalar or 1D tensor,
      must be broadcastable to number of goal dimensions.
    reward_scales: multiplicative scale for rewards. A scalar or 1D tensor,
      must be broadcastable to number of reward dimensions.
    weight_index: (integer) The context list index that specifies weight.
    weight_vector: (a number or a list or Numpy array) The weighting vector,
      broadcastable to `next_states`.
    summarize: (boolean) enable summary ops.
    termination_epsilon: terminate if dist is less than this quantity.
    state_indices: (a list of integers) list of state indices to select.
    goal_indices: (a list of integers) list of goal indices to select.
    norm: L1 or L2.
    epsilon: small offset to ensure non-negative/zero distance.

  Returns:
    A new tf.float32 [batch_size] rewards Tensor, and
      tf.float32 [batch_size] discounts tensor.
  """
  del actions, rewards  # Unused
  stats = {}
  record_tensor(next_states, state_indices, stats, 'next_states')
  next_states = index_states(next_states, state_indices)
  states = index_states(states, state_indices)
  goals = index_states(contexts[0], goal_indices)
  next_sq_dists = tf.squared_difference(next_states * state_scales,
                                        goals * goal_scales)
  sq_dists = tf.squared_difference(states * state_scales,
                                   goals * goal_scales)
  record_tensor(sq_dists, None, stats, 'sq_dists')
  if weight_vector is not None:
    next_sq_dists *= tf.convert_to_tensor(weight_vector,
                                          dtype=next_states.dtype)
    sq_dists *= tf.convert_to_tensor(weight_vector, dtype=next_states.dtype)
  if weight_index is not None:
    next_sq_dists *= contexts[weight_index]
    sq_dists *= contexts[weight_index]
  if norm == 'L1':
    next_dist = tf.sqrt(next_sq_dists + epsilon)
    dist = tf.sqrt(sq_dists + epsilon)
    next_dist = tf.reduce_sum(next_dist, -1)
    dist = tf.reduce_sum(dist, -1)
  elif norm == 'L2':
    next_dist = tf.reduce_sum(next_sq_dists, -1)
    next_dist = tf.sqrt(next_dist + epsilon)  # tf.gradients fails when tf.sqrt(-0.0)
    dist = tf.reduce_sum(sq_dists, -1)
    dist = tf.sqrt(dist + epsilon)  # tf.gradients fails when tf.sqrt(-0.0)
  else:
    raise NotImplementedError(norm)
  discounts = next_dist > termination_epsilon
  if summarize:
    with tf.name_scope('RewardFn/'):
      tf.summary.scalar('mean_dist', tf.reduce_mean(dist))
      tf.summary.histogram('dist', dist)
      summarize_stats(stats)
  diff = dist - next_dist
  diff *= reward_scales
  return tf.to_float(diff), tf.to_float(discounts)
def cosine_similarity(states,
                      starting_states,
                      actions,
                      rewards,
                      next_states,
                      contexts,
                      state_scales=1.0,
                      goal_scales=1.0,
                      reward_scales=1.0,
                      normalize_states=True,
                      normalize_goals=True,
                      weight_index=None,
                      weight_vector=None,
                      summarize=False,
                      state_indices=None,
                      goal_indices=None,
                      offset=0.0):
  """Returns the cosine similarity between next_states - states and contexts.

  Args:
    states: A [batch_size, num_state_dims] Tensor representing a batch
      of states.
    starting_states: A [batch_size, num_state_dims] Tensor representing a
      batch of starting states (unused).
    actions: A [batch_size, num_action_dims] Tensor representing a batch
      of actions.
    rewards: A [batch_size] Tensor representing a batch of rewards.
    next_states: A [batch_size, num_state_dims] Tensor representing a batch
      of next states.
    contexts: A list of [batch_size, num_context_dims] Tensor representing
      a batch of contexts.
    state_scales: multiplicative scale for (next) states. A scalar or 1D
      tensor, must be broadcastable to number of state dimensions.
    goal_scales: multiplicative scale for goals. A scalar or 1D tensor,
      must be broadcastable to number of goal dimensions.
    reward_scales: multiplicative scale for rewards. A scalar or 1D tensor,
      must be broadcastable to number of reward dimensions.
    normalize_states: (boolean) l2-normalize the direction vector
      next_states - states.
    normalize_goals: (boolean) l2-normalize the goal vector.
    weight_index: (integer) The context list index that specifies weight.
    weight_vector: (a number or a list or Numpy array) The weighting vector,
      broadcastable to `next_states`.
    summarize: (boolean) enable summary ops.
    state_indices: (a list of integers) list of state indices to select.
    goal_indices: (a list of integers) list of goal indices to select.
    offset: (float) additive offset applied to the similarity.

  Returns:
    A new tf.float32 [batch_size] rewards Tensor, and
      tf.float32 [batch_size] discounts tensor.
  """
  del starting_states, actions, rewards  # Unused
  stats = {}
  record_tensor(next_states, state_indices, stats, 'next_states')
  states = index_states(states, state_indices)
  next_states = index_states(next_states, state_indices)
  goals = index_states(contexts[0], goal_indices)
  if weight_vector is not None:
    goals *= tf.convert_to_tensor(weight_vector, dtype=next_states.dtype)
  if weight_index is not None:
    weights = tf.abs(index_states(contexts[0], weight_index))
    goals *= weights
  direction_vec = next_states - states
  if normalize_states:
    direction_vec = tf.nn.l2_normalize(direction_vec, -1)
  goal_vec = goals
  if normalize_goals:
    goal_vec = tf.nn.l2_normalize(goal_vec, -1)
  similarity = tf.reduce_sum(goal_vec * direction_vec, -1)
  discounts = tf.ones_like(similarity)
  return offset + tf.to_float(similarity), tf.to_float(discounts)
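
# Hypothetical usage of cosine_similarity (again assuming index_states is the
# identity when the index arguments are None): moving exactly toward the goal
# direction yields similarity 1; moving against it would yield -1.
def _example_cosine_similarity():
  import tensorflow.compat.v1 as tf  # assumed TF1.x-style import
  states = tf.constant([[0.0, 0.0]])
  next_states = tf.constant([[1.0, 0.0]])
  contexts = [tf.constant([[5.0, 0.0]])]  # goal direction along +x
  rewards, discounts = cosine_similarity(
      states, None, None, None, next_states, contexts)
  # rewards == [1.0]
  return rewards, discounts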
def projection_distance(states,
                        starting_states,
                        actions,
                        rewards,
                        next_states,
                        contexts,
                        alpha=0,
                        state_scales=1.0,
                        goal_scales=1.0,
                        reward_scales=1.0,
                        weight_index=None,
                        weight_vector=None,
                        summarize=False,
                        termination_epsilon=1e-4,
                        state_indices=None,
                        goal_indices=None,
                        vectorize=False,
                        relative_context=False,
                        diff=False,
                        norm='L2',
                        epsilon=1e-10,
                        bonus_epsilon=0.,
                        offset=0.0):
  """Returns a projection-based distance between next_states and the goals.

  Args:
    states: A [batch_size, num_state_dims] Tensor representing a batch
      of states.
    starting_states: A [batch_size, num_state_dims] Tensor representing a
      batch of starting states.
    actions: A [batch_size, num_action_dims] Tensor representing a batch
      of actions.
    rewards: A [batch_size] Tensor representing a batch of rewards.
    next_states: A [batch_size, num_state_dims] Tensor representing a batch
      of next states.
    contexts: A list of [batch_size, num_context_dims] Tensor representing
      a batch of contexts.
    alpha: (float) currently unused.
    state_scales: multiplicative scale for (next) states. A scalar or 1D
      tensor, must be broadcastable to number of state dimensions.
    goal_scales: multiplicative scale for goals. A scalar or 1D tensor,
      must be broadcastable to number of goal dimensions.
    reward_scales: multiplicative scale for rewards. A scalar or 1D tensor,
      must be broadcastable to number of reward dimensions.
    weight_index: (integer) The context list index that specifies weight.
    weight_vector: (a number or a list or Numpy array) The weighting vector,
      broadcastable to `next_states`.
    summarize: (boolean) enable summary ops.
    termination_epsilon: terminate if dist is less than this quantity.
    state_indices: (a list of integers) list of state indices to select.
    goal_indices: (a list of integers) list of goal indices to select.
    vectorize: Return a vectorized form.
    relative_context: (boolean) if True, interpret the goal context as an
      offset from the current state.
    diff: (boolean) currently unused.
    norm: L1 or L2.
    epsilon: small offset to ensure non-negative/zero distance.
    bonus_epsilon: (float) currently unused.
    offset: (float) currently unused.

  Returns:
    A new tf.float32 [batch_size] rewards Tensor, and
      tf.float32 [batch_size] discounts tensor.
  """
  del actions, rewards  # Unused
  stats = {}
  record_tensor(next_states, state_indices, stats, 'next_states')
  states = index_states(states, state_indices)
  starting_states = index_states(starting_states, state_indices)
  next_states = index_states(next_states, state_indices)
  goals = index_states(contexts[0], goal_indices)
  if relative_context:
    goals = states + goals
  sq_dists = tf.squared_difference(next_states * state_scales,
                                   goals * goal_scales)
  dist = tf.reduce_sum(sq_dists, -1)

  def projection_dist(states):
    # Scalar projection of (states - starting_states) onto the
    # (goals - starting_states) direction, minus the displacement norm.
    inner = tf.multiply(states - starting_states, goals - starting_states)
    upper = tf.reduce_sum(inner, -1)
    # Per-sample norms; without axis=-1, tf.norm would reduce over the
    # entire batch.
    result = tf.math.divide(
        upper, tf.norm(goals - starting_states, ord=2, axis=-1))
    term_1 = tf.norm(states - starting_states, ord=2, axis=-1)
    return -1 * term_1 + result

  dist_s = projection_dist(states)
  dist_s = tf.sqrt(tf.square(dist_s) + epsilon)
  dist_ns = projection_dist(next_states)
  return dist_ns, tf.to_float(dist > termination_epsilon)
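
# A standalone geometric sketch (not the module's code) of the scalar
# projection that projection_dist computes: the component of the displacement
# (s - s0) along the unit start-to-goal direction (g - s0) / ||g - s0||.
def _example_scalar_projection():
  import tensorflow.compat.v1 as tf  # assumed TF1.x-style import
  s0 = tf.constant([[0.0, 0.0]])   # starting states
  g = tf.constant([[4.0, 0.0]])    # goals
  s = tf.constant([[1.0, 1.0]])    # current states
  upper = tf.reduce_sum((s - s0) * (g - s0), -1)
  proj = upper / tf.norm(g - s0, ord=2, axis=-1)
  # proj == [1.0]: only the x-component of s counts toward the goal.
  return proj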
def my_model_fn(features, labels, mode, params=None, config=None):
  """Estimator model function.

  Args:
    features: input features dictionary
    labels: ignored
    mode: a tf.estimator.ModeKeys
    params: optional parameter dictionary (contains the TPU `context` when
      running with TPUEstimator)
    config: optional RunConfig (unused)

  Returns:
    a TPUEstimatorSpec or EstimatorSpec
  """
  del labels, config
  global_step = tf.train.get_global_step()
  if use_tpu:
    ctx = params["context"]
    num_hosts = ctx.num_hosts
    host_placement_fn = ctx.tpu_host_placement_function
    device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
    # TODO(ylc): Better estimation of replica cache size?
    replica_cache_size = 300 * 1000000  # 300M per replica
    # Worker 0 caches all the TPU binaries.
    worker0_mem = replica_cache_size * ctx.num_replicas
    devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
    var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                  devices_memory_usage)
    mesh_devices = [""] * mesh_shape.size
    physical_shape = list(
        params["context"].device_assignment.topology.mesh_shape)
    logical_to_physical = _logical_to_physical(physical_shape, mesh_shape)
    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
        mesh_shape,
        layout_rules,
        mesh_devices,
        ctx.device_assignment,
        logical_to_physical=logical_to_physical)
  else:
    var_placer = None
    mesh_devices = [""] * mesh_shape.size
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, mesh_devices)
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh", var_placer)

  def _import_feature(key, allow_missing=False):
    """Import a feature from the features dictionary into a mtf.Tensor.

    Args:
      key: a string
      allow_missing: a boolean

    Returns:
      a mtf.Tensor with dtype int32 and shape [batch_dim, length_dim]
    """
    outer_batch_dim = mtf.Dimension("outer_batch", outer_batch_size)
    batch_dim = mtf.Dimension("batch", batch_size // outer_batch_size)
    length_dim = mtf.Dimension("length", sequence_length)
    mtf_shape = mtf.Shape([outer_batch_dim, batch_dim, length_dim])
    if key not in features:
      if allow_missing:
        return None
      else:
        raise ValueError("feature not found: %s (features: %s)" %
                         (key, features))
    tf.logging.info("Import feature %s: %s" % (key, features[key]))
    x = tf.to_int32(features[key])
    x = tf.reshape(x, [outer_batch_size, batch_size // outer_batch_size, -1])
    if not use_tpu:
      x = tf.Print(x, [x], "import feature %s" % key, summarize=1000,
                   first_n=1)
    return mtf.import_fully_replicated(mesh, x, mtf_shape, name=key)

  if mode == tf.estimator.ModeKeys.PREDICT:
    inputs = _import_feature("inputs")
    inputs = mtf.reshape(
        inputs,
        mtf.Shape([
            mtf.Dimension("batch", batch_size),
            mtf.Dimension("length", sequence_length)
        ]))
    if isinstance(transformer_model, transformer.Unitransformer):
      mtf_samples = transformer_model.sample_autoregressive(
          inputs, variable_dtype=get_variable_dtype())
    elif isinstance(transformer_model, transformer.Bitransformer):
      mtf_samples = transformer_model.decode(
          inputs, variable_dtype=get_variable_dtype())
    else:
      raise ValueError("unrecognized class")
    mtf_samples = mtf.anonymize(mtf_samples)
    lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)
    outputs = lowering.export_to_tf_tensor(mtf_samples)
    predictions = {"outputs": outputs}
    return tpu_estimator.TPUEstimatorSpec(
        mode=tf.estimator.ModeKeys.PREDICT,
        predictions=predictions,
        prediction_hooks=[mtf.MtfRestoreHook(lowering)])

  targets = _import_feature("targets")
  anon_targets = mtf.anonymize(targets)
  if model_type == "lm":
    # Imported features have shape [outer_batch, batch, length].
    _, _, length_dim = targets.shape
    inputs = mtf.shift(targets, offset=1, dim=length_dim, wrap=False)
  else:
    inputs = _import_feature("inputs")

  if mode == tf.estimator.ModeKeys.EVAL:
    if isinstance(transformer_model, transformer.Unitransformer):
      mtf_samples = transformer_model.sample_autoregressive(
          inputs, variable_dtype=get_variable_dtype())
    elif isinstance(transformer_model, transformer.Bitransformer):
      mtf_samples = transformer_model.decode(
          inputs, variable_dtype=get_variable_dtype())
    else:
      raise ValueError("unrecognized class")
    mtf_samples = mtf.anonymize(mtf_samples)
    lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)
    outputs = lowering.export_to_tf_tensor(mtf_samples)
    labels = lowering.export_to_tf_tensor(anon_targets)
    restore_hook = mtf.MtfRestoreHook(lowering)

    # metric_names becomes locally scoped if we simply assign
    # ["padded_neg_log_perplexity"] to it conditioned on it being None.
    local_metric_names = metric_names or ["token_accuracy"]

    def metric_fn(labels, outputs):
      return get_metric_fns(local_metric_names, labels, outputs)

    eval_metrics = (metric_fn, [labels, outputs])
    return tpu_estimator.TPUEstimatorSpec(
        tf.estimator.ModeKeys.EVAL,
        # Unfortunately TPUEstimatorSpec requires us to provide a value for
        # loss when in EVAL mode. Since we are sampling or decoding from the
        # model, we don't have a loss to report.
        loss=tf.constant(0.),
        evaluation_hooks=[restore_hook],
        eval_metrics=eval_metrics)

  if isinstance(transformer_model, transformer.Unitransformer):
    position_kwargs = dict(
        sequence_id=_import_feature("targets_segmentation", True),
        position=_import_feature("targets_position", True),
    )
  elif isinstance(transformer_model, transformer.Bitransformer):
    position_kwargs = dict(
        encoder_sequence_id=_import_feature("inputs_segmentation", True),
        decoder_sequence_id=_import_feature("targets_segmentation", True),
        encoder_position=_import_feature("inputs_position", True),
        decoder_position=_import_feature("targets_position", True),
    )
  else:
    raise ValueError("unrecognized class")

  logits, loss = transformer_model.call_simple(
      inputs=inputs,
      targets=targets,
      compute_loss=True,
      mode=mode,
      variable_dtype=get_variable_dtype(),
      **position_kwargs)

  if use_tpu and logits is not None:
    logits = mtf.anonymize(logits)

  # TRAIN mode
  if mode == tf.estimator.ModeKeys.TRAIN:
    var_grads = mtf.gradients(
        [loss], [v.outputs[0] for v in graph.trainable_variables])
    optimizer = mtf.optimize.AdafactorOptimizer(learning_rate=learning_rate)
    update_ops = optimizer.apply_grads(var_grads, graph.trainable_variables)

  lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)

  tf_loss = lowering.export_to_tf_tensor(loss)
  tf_loss = tf.to_float(tf_loss)
  if not use_tpu:
    tf_loss = tf.Print(tf_loss, [tf_loss, tf.train.get_global_step()],
                       "step, tf_loss")

  if mode == tf.estimator.ModeKeys.TRAIN:
    tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
    tf_update_ops.append(tf.assign_add(global_step, 1))
    train_op = tf.group(tf_update_ops)

  with mtf.utils.outside_all_rewrites():
    # Copy master variables to slices. Must be called first.
    restore_hook = mtf.MtfRestoreHook(lowering)
    saver = tf.train.Saver(
        tf.global_variables(),
        sharded=True,
        max_to_keep=checkpoints_to_keep,
        keep_checkpoint_every_n_hours=2,
        defer_build=False,
        save_relative_paths=True)
    tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
    saver_listener = mtf.MtfCheckpointSaverListener(lowering)
    saver_hook = tf.train.CheckpointSaverHook(
        model_dir,
        save_steps=save_steps,
        saver=saver,
        listeners=[saver_listener])
    gin_config_saver_hook = gin.tf.GinConfigSaverHook(
        model_dir, summarize_config=True)

    if mode == tf.estimator.ModeKeys.TRAIN:
      if use_tpu:
        return tpu_estimator.TPUEstimatorSpec(
            mode=tf.estimator.ModeKeys.TRAIN,
            loss=tf_loss,
            train_op=train_op,
            training_hooks=[
                restore_hook,
                saver_hook,
                gin_config_saver_hook,
            ])
      else:
        return tf.estimator.EstimatorSpec(
            tf.estimator.ModeKeys.TRAIN,
            loss=tf_loss,
            train_op=train_op,
            training_chief_hooks=[
                restore_hook,
                saver_hook,
                gin_config_saver_hook,
            ])
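
# A minimal sketch (illustrative only) of the reshape that _import_feature
# performs before handing features to Mesh TensorFlow: a [batch, length] host
# tensor becomes [outer_batch, batch // outer_batch, length] so it can be
# laid out over the mesh dimensions.
def _example_outer_batch_reshape():
  import tensorflow.compat.v1 as tf  # assumed TF1.x-style import
  batch, outer_batch, length = 8, 2, 16
  x = tf.zeros([batch, length], dtype=tf.int32)
  x = tf.reshape(x, [outer_batch, batch // outer_batch, length])
  # x.shape == (2, 4, 16), matching
  # mtf.Shape([outer_batch_dim, batch_dim, length_dim]) in _import_feature.
  return x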
def my_model_fn(features, labels, mode, params=None, config=None):
  """Estimator model function.

  Args:
    features: input features dictionary
    labels: ignored
    mode: a tf.estimator.ModeKeys
    params: optional parameter dictionary (contains the TPU `context` when
      running with TPUEstimator)
    config: optional RunConfig (unused)

  Returns:
    a TPUEstimatorSpec or EstimatorSpec
  """
  del labels, config
  global_step = tf.train.get_global_step()
  if use_tpu:
    ctx = params["context"]
    num_hosts = ctx.num_hosts
    host_placement_fn = ctx.tpu_host_placement_function
    device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
    # TODO(ylc): Better estimation of replica cache size?
    replica_cache_size = 300 * 1000000  # 300M per replica
    # Worker 0 caches all the TPU binaries.
    worker0_mem = replica_cache_size * ctx.num_replicas
    devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
    var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                  devices_memory_usage)
    mesh_devices = [""] * mesh_shape.size
    physical_shape = list(
        params["context"].device_assignment.topology.mesh_shape)
    logical_to_physical = _logical_to_physical(physical_shape, mesh_shape)
    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
        mesh_shape,
        layout_rules,
        mesh_devices,
        ctx.device_assignment,
        logical_to_physical=logical_to_physical)
  else:
    var_placer = None
    mesh_devices = [""] * mesh_shape.size
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, mesh_devices)

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh", var_placer)

  outer_batch_dim = mtf.Dimension("outer_batch", outer_batch_size)
  batch_dim = mtf.Dimension("batch", batch_size // outer_batch_size)
  length_dim = mtf.Dimension("length", sequence_length)
  feature_shape = mtf.Shape([outer_batch_dim, batch_dim, length_dim])

  mtf_features = {}
  for key, x in features.items():
    x = tf.to_int32(features[key])
    x = tf.reshape(
        x, [outer_batch_size, batch_size // outer_batch_size, sequence_length])
    if not use_tpu:
      x = tf.Print(x, [x], "import feature %s" % key, summarize=1000,
                   first_n=1)
    mtf_features[key] = mtf.import_fully_replicated(
        mesh, x, feature_shape, name=key)

  if mode == tf.estimator.ModeKeys.PREDICT:
    inputs = mtf_features["inputs"]
    inputs = mtf.reshape(
        inputs,
        mtf.Shape([
            mtf.Dimension("batch", batch_size),
            mtf.Dimension("length", sequence_length)
        ]))
    if isinstance(transformer_model, transformer.Unitransformer):
      mtf_samples = transformer_model.sample_autoregressive(
          inputs, variable_dtype=get_variable_dtype())
    elif isinstance(
        transformer_model,
        (transformer.Bitransformer, transformer.StudentTeacher)):
      mtf_samples = transformer_model.decode(
          inputs, variable_dtype=get_variable_dtype())
    else:
      raise ValueError("unrecognized class")
    mtf_samples = mtf.anonymize(mtf_samples)
    lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)
    outputs = lowering.export_to_tf_tensor(mtf_samples)
    predictions = {"outputs": outputs}
    return tpu_estimator.TPUEstimatorSpec(
        mode=tf.estimator.ModeKeys.PREDICT,
        predictions=predictions,
        prediction_hooks=[mtf.MtfRestoreHook(lowering)])

  elif mode == tf.estimator.ModeKeys.EVAL:
    raise NotImplementedError("We don't expect to use mode == eval.")

  else:
    assert mode == tf.estimator.ModeKeys.TRAIN
    num_microbatches = serialize_num_microbatches(
        batch_dim, length_dim, mesh_shape, layout_rules)

    def model_fn(mtf_features):
      """The kind of function we need for mtf.serialize_training_step.

      Args:
        mtf_features: a dictionary

      Returns:
        a dictionary
      """
      targets = mtf_features["targets"]
      if model_type == "lm":
        _, _, length_dim = targets.shape
        inputs = mtf.shift(targets, offset=1, dim=length_dim, wrap=False)
      else:
        inputs = mtf_features["inputs"]

      if isinstance(transformer_model, transformer.Unitransformer):
        position_kwargs = dict(
            sequence_id=mtf_features.get("targets_segmentation", None),
            position=mtf_features.get("targets_position", None),
        )
      elif isinstance(
          transformer_model,
          transformer.Bitransformer) or model_type == "bi_student_teacher":
        position_kwargs = dict(
            encoder_sequence_id=mtf_features.get("inputs_segmentation", None),
            decoder_sequence_id=mtf_features.get("targets_segmentation",
                                                 None),
            encoder_position=mtf_features.get("inputs_position", None),
            decoder_position=mtf_features.get("targets_position", None),
        )
      else:
        raise ValueError("unrecognized class")

      logits, loss = transformer_model.call_simple(
          inputs=inputs,
          targets=targets,
          compute_loss=True,
          mode=mode,
          variable_dtype=get_variable_dtype(),
          **position_kwargs)
      if num_microbatches > 1:
        loss /= float(num_microbatches)
      del logits
      return {"loss": loss}

    if num_microbatches > 1:
      var_grads, loss_dict = mtf.serialize_training_step(
          mtf_features, model_fn, batch_dim, num_microbatches)
    else:
      loss_dict = model_fn(mtf_features)
      var_grads = mtf.gradients(
          [loss_dict["loss"]],
          [v.outputs[0] for v in graph.trainable_variables])

    loss = loss_dict["loss"]

    if callable(learning_rate_schedule):
      # The following happens on CPU since TPU can't handle summaries.
      with mtf.utils.outside_all_rewrites():
        learning_rate = learning_rate_schedule(
            step=tf.train.get_global_step())
        tf.summary.scalar("learning_rate", learning_rate)
    else:
      learning_rate = learning_rate_schedule

    update_ops = optimizer(learning_rate=learning_rate).apply_grads(
        var_grads, graph.trainable_variables)

    lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)

    tf_loss = lowering.export_to_tf_tensor(loss)
    tf_loss = tf.to_float(tf_loss)
    if not use_tpu:
      tf_loss = tf.Print(tf_loss, [tf_loss, tf.train.get_global_step()],
                         "step, tf_loss")

    tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
    tf_update_ops.append(tf.assign_add(global_step, 1))
    train_op = tf.group(tf_update_ops)

    if hasattr(transformer_model, "initialize"):
      with mtf.utils.outside_all_rewrites():
        transformer_model.initialize()

    with mtf.utils.outside_all_rewrites():
      # Copy master variables to slices. Must be called first.
      restore_hook = mtf.MtfRestoreHook(lowering)
      saver = tf.train.Saver(
          tf.global_variables(),
          sharded=True,
          max_to_keep=keep_checkpoint_max,
          keep_checkpoint_every_n_hours=2,
          defer_build=False,
          save_relative_paths=True)
      tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
      saver_listener = mtf.MtfCheckpointSaverListener(lowering)
      saver_hook = tf.train.CheckpointSaverHook(
          model_dir,
          save_steps=save_checkpoints_steps,
          saver=saver,
          listeners=[saver_listener])
      gin_config_saver_hook = gin.tf.GinConfigSaverHook(
          model_dir, summarize_config=True)

      if use_tpu:
        if tpu_summaries:
          tf.summary.scalar("loss", tf_loss)
          host_call = mtf.utils.create_host_call(model_dir)
          mtf.utils.remove_summaries()
        else:
          host_call = None
        return tpu_estimator.TPUEstimatorSpec(
            mode=tf.estimator.ModeKeys.TRAIN,
            loss=tf_loss,
            train_op=train_op,
            host_call=host_call,
            training_hooks=[
                restore_hook,
                saver_hook,
                gin_config_saver_hook,
            ])
      else:
        return tf.estimator.EstimatorSpec(
            tf.estimator.ModeKeys.TRAIN,
            loss=tf_loss,
            train_op=train_op,
            training_chief_hooks=[
                restore_hook,
                saver_hook,
                gin_config_saver_hook,
            ])
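
# A small numeric sketch (illustrative only) of why model_fn divides the loss
# by num_microbatches: when mtf.serialize_training_step accumulates gradients
# over k microbatches, scaling each microbatch loss by 1/k makes the summed
# gradient match the gradient of the mean full-batch loss.
def _example_microbatch_loss_scaling():
  num_microbatches = 4
  per_microbatch_losses = [2.0, 1.0, 3.0, 2.0]
  scaled = [l / float(num_microbatches) for l in per_microbatch_losses]
  accumulated = sum(scaled)
  # accumulated == 2.0 == mean of the per-microbatch losses.
  return accumulated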