def optimize(self, loss, num_async_replicas=1, use_tpu=False):
  """Return a training op minimizing loss.

  Builds the learning-rate schedule, applies weight decay/noise to the
  loss, and delegates the actual update to tf.contrib.layers.optimize_loss.

  Args:
    loss: scalar loss tensor to minimize.
    num_async_replicas: number of asynchronous worker replicas; when > 1
      the learning rate is divided by sqrt(num_async_replicas).
    use_tpu: unused in this body — NOTE(review): presumably kept for
      signature parity with other optimize() variants; confirm callers
      pass it by keyword.

  Returns:
    The training op produced by tf.contrib.layers.optimize_loss.
  """
  hparams = self.hparams
  lr = learning_rate.learning_rate_schedule(hparams)
  if num_async_replicas > 1:
    # Scale the LR down so asynchronous replicas don't compound updates.
    log_info("Dividing learning rate by num_async_replicas: %d",
             num_async_replicas)
    lr /= tf.sqrt(float(num_async_replicas))
  # Regularize the loss, then expose the final scalar under a stable name.
  loss = weight_decay_and_noise(loss, hparams, lr)
  loss = tf.identity(loss, name="total_loss")
  log_variable_sizes(verbose=hparams.summarize_vars)
  opt = ConditionalOptimizer(hparams.optimizer, lr, hparams)
  opt_summaries = ["loss", "learning_rate", "global_gradient_norm"]
  if hparams.clip_grad_norm:
    tf.logging.info("Clipping gradients, norm: %0.5f", hparams.clip_grad_norm)
  if hparams.grad_noise_scale:
    tf.logging.info("Adding noise to gradients, noise scale: %0.5f",
                    hparams.grad_noise_scale)
  tf.summary.scalar("training/learning_rate", lr)
  # optimize_loss handles clipping/noise itself via the kwargs below;
  # falsy hparams values are mapped to None to disable those features.
  return tf.contrib.layers.optimize_loss(
      name="training",
      loss=loss,
      global_step=tf.train.get_or_create_global_step(),
      learning_rate=lr,
      clip_gradients=hparams.clip_grad_norm or None,
      gradient_noise_scale=hparams.grad_noise_scale or None,
      optimizer=opt,
      summaries=opt_summaries,
      colocate_gradients_with_ops=True)
def optimize(self, loss, num_async_replicas=1):
  """Build and return the training op that minimizes `loss`.

  The learning rate comes from the hparams schedule and is scaled down by
  sqrt(num_async_replicas) when several replicas train asynchronously.
  """
  hparams = self.hparams
  log_info("Base learning rate: %f", hparams.learning_rate)
  lr = learning_rate.learning_rate_schedule(hparams)
  if num_async_replicas > 1:
    log_info("Dividing learning rate by num_async_replicas: %d",
             num_async_replicas)
    lr /= math.sqrt(float(num_async_replicas))
  return optimize.optimize(
      loss, lr, hparams, use_tpu=common_layers.is_on_tpu())
def optimize(self, loss, num_async_replicas=1):
  """Return a training op minimizing loss."""
  log_info("Base learning rate: %f", self.hparams.learning_rate)
  lr = learning_rate.learning_rate_schedule(self.hparams)
  if num_async_replicas > 1:
    # Asynchronous replicas: shrink the LR to keep updates comparable.
    log_info("Dividing learning rate by num_async_replicas: %d",
             num_async_replicas)
    lr = lr / math.sqrt(float(num_async_replicas))
  train_op = optimize.optimize(loss,
                               lr,
                               self.hparams,
                               use_tpu=common_layers.is_on_tpu())
  return train_op
def estimator_model_fn(cls,
                       hparams,
                       features,
                       labels,
                       mode,
                       config=None,
                       params=None,
                       decode_hparams=None,
                       use_tpu=False,
                       xla_compile=False):
  """Model fn for tf.estimator that builds a Mesh TensorFlow graph.

  Handles the PREDICT / TRAIN / EVAL modes: builds the mtf graph, lowers
  it to plain TF ops, and returns the matching EstimatorSpec.

  Args:
    cls: the model class to instantiate.
    hparams: HParams; deep-copied before any mutation.
    features: dict of input features passed to the model.
    labels: labels forwarded to the eval spec.
    mode: a tf.estimator.ModeKeys value.
    config: optional RunConfig; supplies data_parallelism off-TPU.
    params: optional dict; on TPU must contain "context".
    decode_hparams: optional decoding HParams merged into hparams for
      PREDICT mode.
    use_tpu: whether to build the SIMD (TPU) mesh implementation.
    xla_compile: unused; deleted immediately.

  Returns:
    A tf.estimator.EstimatorSpec or tpu_estimator.TPUEstimatorSpec.
  """
  del xla_compile
  hparams = copy.deepcopy(hparams)
  hparams.use_tpu = use_tpu
  # Merge decode_hparams into hparams if present.
  if mode == tf.estimator.ModeKeys.PREDICT and decode_hparams is not None:
    for k, v in six.iteritems(decode_hparams.values()):
      if hasattr(hparams, k) and getattr(hparams, k) != v:
        tf.logging.warning("Overriding hparams.%s with %s from decode_hparams"
                           % (k, v))
        setattr(hparams, k, v)
  # Instantiate model.
  data_parallelism = None
  if not use_tpu and config:
    data_parallelism = config.data_parallelism
  model = cls(hparams, mode, data_parallelism=data_parallelism,
              decode_hparams=decode_hparams)
  global_step = tf.train.get_global_step()
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  mesh_shape = mtf.convert_to_shape(hparams.mesh_shape)
  layout_rules = mtf.convert_to_layout_rules(hparams.layout)
  if use_tpu:
    mesh_devices = [""] * mesh_shape.size
    mesh_impl = simd_mesh_impl.SimdMeshImpl(
        mesh_shape, layout_rules, mesh_devices,
        params["context"].device_assignment)
  else:
    # NOTE(review): data_parallelism can still be None here (when `config`
    # is falsy), which would raise AttributeError on ps_devices — confirm
    # callers always pass a config off-TPU.
    if len(data_parallelism.ps_devices) == 1:
      mesh_devices = [""] * mesh_shape.size
    else:
      assert len(data_parallelism.ps_devices) == mesh_shape.size
      mesh_devices = data_parallelism.ps_devices
    mesh_impl = placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, mesh_devices)
  # PREDICT mode.
  if mode == tf.estimator.ModeKeys.PREDICT:
    return model.estimator_spec_predict(features, mesh, mesh_impl, use_tpu)
  logits, loss = model.mtf_model_fn(features, mesh)
  if use_tpu and logits is not None:
    logits = mtf.anonymize(logits)
  # TRAIN mode: build gradients and update ops in mtf space.
  if mode == tf.estimator.ModeKeys.TRAIN:
    var_grads = mtf.gradients(
        [loss], [v.outputs[0] for v in graph.trainable_variables])
    lr = learning_rate.learning_rate_schedule(hparams)
    mtf_lr = mtf.import_tf_tensor(
        mesh, tf.convert_to_tensor(lr, dtype=tf.float32), mtf.Shape([]))
    optimizer = mtf_optimize.make_optimizer(hparams, mtf_lr)
    update_ops = []
    for grad, var in zip(var_grads, graph.trainable_variables):
      update_ops.extend(optimizer.apply_grad(grad, var))
  # Lower the mtf graph to plain TensorFlow ops.
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  tf_loss = lowering.export_to_tf_tensor(loss)
  tf_loss = tf.to_float(tf_loss)
  if logits and mode != tf.estimator.ModeKeys.TRAIN:
    tf_logits = lowering.export_to_tf_tensor(logits)
  if mode == tf.estimator.ModeKeys.TRAIN:
    tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
    # Advance the global step alongside the lowered update ops.
    tf_update_ops.append(tf.assign_add(global_step, 1))
    # tf.logging.info("tf_update_ops: {}".format(tf_update_ops))
    train_op = tf.group(tf_update_ops)
  with mtf_utils.outside_all_rewrites():
    # Copy master variables to slices. Must be called first.
    restore_hook = mtf.MtfRestoreHook(lowering)
    saver = tf.train.Saver(tf.global_variables(),
                           sharded=True,
                           max_to_keep=10,
                           keep_checkpoint_every_n_hours=2,
                           defer_build=False,
                           save_relative_paths=True)
    tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
    saver_listener = mtf.MtfCheckpointSaverListener(lowering)
    saver_hook = tf.train.CheckpointSaverHook(
        hparams.model_dir,
        save_steps=1000,
        saver=saver,
        listeners=[saver_listener])
    # EVAL mode.
    if mode == tf.estimator.ModeKeys.EVAL:
      tf_logits = lowering.export_to_tf_tensor(logits)
      return model.estimator_spec_eval(features, tf_logits, labels, tf_loss,
                                       restore_hook, use_tpu)
    if use_tpu:
      # Summaries are not supported on TPU; strip them before returning.
      _remove_summaries()
      return tpu_estimator.TPUEstimatorSpec(
          mode=tf.estimator.ModeKeys.TRAIN,
          loss=tf_loss,
          train_op=train_op,
          training_hooks=[restore_hook, saver_hook])
    else:
      return tf.estimator.EstimatorSpec(
          tf.estimator.ModeKeys.TRAIN,
          loss=tf_loss,
          train_op=train_op,
          training_chief_hooks=[restore_hook, saver_hook])
def estimator_model_fn(cls,
                       hparams,
                       features,
                       labels,
                       mode,
                       config=None,
                       params=None,
                       decode_hparams=None):
  """Model fn for tf.estimator that builds a Mesh TensorFlow graph.

  Handles the PREDICT / TRAIN / EVAL modes: builds the mtf graph, lowers
  it to plain TF ops, and returns the matching EstimatorSpec.

  Args:
    cls: the model class to instantiate.
    hparams: HParams; deep-copied before any mutation.
    features: dict of input features passed to the model.
    labels: labels forwarded to the eval spec.
    mode: a tf.estimator.ModeKeys value.
    config: optional RunConfig; supplies data_parallelism off-TPU.
    params: optional dict; carries the "use_tpu" flag and, on TPU, the
      "context" object with the device assignment.
    decode_hparams: optional decoding HParams merged into hparams for
      PREDICT mode.

  Returns:
    A tf.estimator.EstimatorSpec or tpu_estimator.TPUEstimatorSpec.
  """
  hparams = copy.deepcopy(hparams)
  use_tpu = params and params.get("use_tpu", False)
  hparams.use_tpu = use_tpu
  # Merge decode_hparams into hparams if present.
  if mode == tf.estimator.ModeKeys.PREDICT and decode_hparams is not None:
    for k, v in six.iteritems(decode_hparams.values()):
      if hasattr(hparams, k) and getattr(hparams, k) != v:
        tf.logging.warning("Overriding hparams.%s with %s from decode_hparams"
                           % (k, v))
        setattr(hparams, k, v)
  # Instantiate model.
  data_parallelism = None
  if not use_tpu and config:
    data_parallelism = config.data_parallelism
  model = cls(hparams, mode, data_parallelism=data_parallelism,
              decode_hparams=decode_hparams)
  global_step = tf.train.get_global_step()
  mesh_shape = mtf.convert_to_shape(hparams.mesh_shape)
  layout_rules = mtf.convert_to_layout_rules(hparams.layout)
  if use_tpu:
    ctx = params["context"]
    num_hosts = ctx.num_hosts
    host_placement_fn = ctx.tpu_host_placement_function
    device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
    # TODO(ylc): Better estimation of replica cache size?
    replica_cache_size = 300 * 1000000  # 300M per replica.
    # Worker 0 caches all the TPU binaries.
    worker0_mem = replica_cache_size * ctx.num_replicas
    devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
    var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                  devices_memory_usage)
    mesh_devices = [""] * mesh_shape.size
    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
        mesh_shape, layout_rules, mesh_devices, ctx.device_assignment)
  else:
    var_placer = None
    # Bugfix: data_parallelism is None when `config` was falsy; guard so we
    # fall back to default device placement instead of raising
    # AttributeError on ps_devices (matches the other variant of this fn).
    if data_parallelism is None or len(data_parallelism.ps_devices) == 1:
      mesh_devices = [""] * mesh_shape.size
    else:
      assert len(data_parallelism.ps_devices) == mesh_shape.size
      mesh_devices = data_parallelism.ps_devices
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, mesh_devices)
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh", var_placer)
  # PREDICT mode.
  if mode == tf.estimator.ModeKeys.PREDICT:
    return model.estimator_spec_predict(features, mesh, mesh_impl, use_tpu)
  logits, loss = model.mtf_model_fn(features, mesh)
  if use_tpu and logits is not None:
    logits = mtf.anonymize(logits)
  # TRAIN mode: build gradients and update ops in mtf space.
  if mode == tf.estimator.ModeKeys.TRAIN:
    var_grads = mtf.gradients(
        [loss], [v.outputs[0] for v in graph.trainable_variables])
    lr = learning_rate.learning_rate_schedule(hparams)
    tf.summary.scalar("learning_rate", lr)
    mtf_lr = mtf.import_tf_tensor(
        mesh, tf.convert_to_tensor(lr, dtype=tf.float32), mtf.Shape([]))
    optimizer = mtf.optimize.make_optimizer(hparams, mtf_lr)
    update_ops = []
    for grad, var in zip(var_grads, graph.trainable_variables):
      update_ops.extend(optimizer.apply_grad(grad, var))
  # Lower the mtf graph to plain TensorFlow ops.
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  tf_loss = lowering.export_to_tf_tensor(loss)
  tf_loss = tf.to_float(tf_loss)
  if logits and mode != tf.estimator.ModeKeys.TRAIN:
    tf_logits = lowering.export_to_tf_tensor(logits)
  if mode == tf.estimator.ModeKeys.TRAIN:
    tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
    # Advance the global step alongside the lowered update ops.
    tf_update_ops.append(tf.assign_add(global_step, 1))
    # tf.logging.info("tf_update_ops: {}".format(tf_update_ops))
    train_op = tf.group(tf_update_ops)
  with mtf.utils.outside_all_rewrites():
    # Copy master variables to slices. Must be called first.
    restore_hook = mtf.MtfRestoreHook(lowering)
    saver = tf.train.Saver(tf.global_variables(),
                           sharded=True,
                           max_to_keep=10,
                           keep_checkpoint_every_n_hours=2,
                           defer_build=False,
                           save_relative_paths=True)
    tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
    saver_listener = mtf.MtfCheckpointSaverListener(lowering)
    saver_hook = tf.train.CheckpointSaverHook(
        hparams.model_dir,
        save_steps=1000,
        saver=saver,
        listeners=[saver_listener])
    # EVAL mode.
    if mode == tf.estimator.ModeKeys.EVAL:
      tf_logits = lowering.export_to_tf_tensor(logits)
      return model.estimator_spec_eval(features, tf_logits, labels, tf_loss,
                                       restore_hook, use_tpu)
    if use_tpu:
      # TPU host call. Important: need to be called before remove_summaries()
      if hparams.tpu_enable_host_call:
        host_call = t2t_model.create_host_call(hparams.model_dir)
      else:
        host_call = None
      t2t_model.remove_summaries()
      return tpu_estimator.TPUEstimatorSpec(
          mode=tf.estimator.ModeKeys.TRAIN,
          loss=tf_loss,
          train_op=train_op,
          host_call=host_call,
          training_hooks=[restore_hook, saver_hook])
    else:
      return tf.estimator.EstimatorSpec(
          tf.estimator.ModeKeys.TRAIN,
          loss=tf_loss,
          train_op=train_op,
          training_chief_hooks=[restore_hook, saver_hook])
def define_ppo_epoch(memory, hparams, action_space, batch_size,
                     distributional_size=1, distributional_subscale=0.04,
                     distributional_threshold=0.0, epoch=-1):
  """Build the ops for one PPO optimization epoch.

  Args:
    memory: tuple (observation, reward, done, action, old_pdf, value_sm)
      of tensors collected from the (simulated) environment.
    hparams: HParams with PPO settings (gae_gamma, gae_lambda, epoch_length,
      optimization_epochs, optimization_batch_size, effective_num_agents).
    action_space: environment action space, forwarded to define_ppo_step.
    batch_size: number of parallel agents in the memory.
    distributional_size: number of value buckets; 1 disables the
      distributional path.
    distributional_subscale: width of one value bucket.
    distributional_threshold: threshold forwarded to
      _distributional_to_value.
    epoch: current training epoch index, forwarded to define_ppo_step.

  Returns:
    A merged summary op whose evaluation runs the PPO update steps (via
    tf.scan) and prints the individual losses.
  """
  observation, reward, done, action, old_pdf, value_sm = memory
  # This is to avoid propagating gradients through simulated environment.
  observation = tf.stop_gradient(observation)
  action = tf.stop_gradient(action)
  reward = tf.stop_gradient(reward)
  if hasattr(hparams, "rewards_preprocessing_fun"):
    reward = hparams.rewards_preprocessing_fun(reward)
  done = tf.stop_gradient(done)
  value_sm = tf.stop_gradient(value_sm)
  old_pdf = tf.stop_gradient(old_pdf)
  value = value_sm
  if distributional_size > 1:
    # Collapse the per-bucket distribution into a scalar value estimate.
    value = _distributional_to_value(
        value_sm, distributional_size, distributional_subscale,
        distributional_threshold)
  advantage = calculate_generalized_advantage_estimator(
      reward, value, done, hparams.gae_gamma, hparams.gae_lambda)
  if distributional_size > 1:
    # Create discounted reward values range.
    half = distributional_size // 2
    value_range = tf.to_float(tf.range(-half, half)) + 0.5  # Mid-bucket value.
    value_range *= distributional_subscale
    # Acquire new discounted rewards by using the above range as end-values.
    end_values = tf.expand_dims(value_range, 0)
    discounted_reward = discounted_rewards(
        reward, done, hparams.gae_gamma, end_values)
    # Re-normalize the discounted rewards to integers, in [0, dist_size] range.
    discounted_reward /= distributional_subscale
    discounted_reward += half
    discounted_reward = tf.maximum(discounted_reward, 0.0)
    discounted_reward = tf.minimum(discounted_reward, distributional_size)
    # Multiply the rewards by 2 for greater fidelity and round to integers.
    discounted_reward = tf.stop_gradient(tf.round(2 * discounted_reward))
    # The probabilities corresponding to the end values from old predictions.
    discounted_reward_prob = tf.stop_gradient(value_sm[-1])
    discounted_reward_prob = tf.nn.softmax(discounted_reward_prob, axis=-1)
  else:
    discounted_reward = tf.stop_gradient(advantage + value[:-1])
    discounted_reward_prob = discounted_reward  # Unused in this case.
  advantage_mean, advantage_variance = tf.nn.moments(advantage, axes=[0, 1],
                                                     keep_dims=True)
  # Normalize advantages; epsilon keeps the division stable.
  advantage_normalized = tf.stop_gradient(
      (advantage - advantage_mean) / (tf.sqrt(advantage_variance) + 1e-8))
  add_lists_elementwise = lambda l1, l2: [x + y for x, y in zip(l1, l2)]
  number_of_batches = ((hparams.epoch_length - 1) * hparams.optimization_epochs
                       // hparams.optimization_batch_size)
  epoch_length = hparams.epoch_length
  if hparams.effective_num_agents is not None:
    number_of_batches *= batch_size
    number_of_batches //= hparams.effective_num_agents
    epoch_length //= hparams.effective_num_agents
  # Fixed typo in the assertion message ("paremeters" -> "parameters").
  assert number_of_batches > 0, "Set the parameters so that number_of_batches>0"
  lr = learning_rate.learning_rate_schedule(hparams)
  # One independent shuffle of time indices per optimization epoch.
  shuffled_indices = [tf.random.shuffle(tf.range(epoch_length - 1))
                      for _ in range(hparams.optimization_epochs)]
  shuffled_indices = tf.concat(shuffled_indices, axis=0)
  shuffled_indices = shuffled_indices[:number_of_batches *
                                      hparams.optimization_batch_size]
  indices_of_batches = tf.reshape(shuffled_indices,
                                  shape=(-1, hparams.optimization_batch_size))
  input_tensors = [observation, action, discounted_reward,
                   discounted_reward_prob, advantage_normalized, old_pdf]
  # Accumulate (policy_loss, value_loss, entropy_loss) over all batches.
  ppo_step_rets = tf.scan(
      lambda a, i: add_lists_elementwise(  # pylint: disable=g-long-lambda
          a, define_ppo_step(
              [tf.gather(t, indices_of_batches[i, :]) for t in input_tensors],
              hparams, action_space, lr,
              epoch=epoch,
              distributional_size=distributional_size,
              distributional_subscale=distributional_subscale
          )),
      tf.range(number_of_batches),
      [0., 0., 0.],
      parallel_iterations=1)
  ppo_summaries = [tf.reduce_mean(ret) / number_of_batches
                   for ret in ppo_step_rets]
  ppo_summaries.append(lr)
  summaries_names = [
      "policy_loss", "value_loss", "entropy_loss", "learning_rate"
  ]
  summaries = [tf.summary.scalar(summary_name, summary)
               for summary_name, summary in zip(summaries_names, ppo_summaries)]
  losses_summary = tf.summary.merge(summaries)
  # Also print each loss when the merged summary is evaluated.
  for summary_name, summary in zip(summaries_names, ppo_summaries):
    losses_summary = tf.Print(losses_summary, [summary], summary_name + ": ")
  return losses_summary
def define_ppo_epoch(memory, hparams, action_space, batch_size):
  """Build the ops for one PPO optimization epoch (tf.data variant).

  Args:
    memory: tuple (observation, reward, done, action, old_pdf, value) of
      tensors collected from the (simulated) environment.
    hparams: HParams with PPO settings (gae_gamma, gae_lambda,
      epoch_length, optimization_epochs, optimization_batch_size,
      effective_num_agents).
    action_space: environment action space, forwarded to define_ppo_step.
    batch_size: number of parallel agents in the memory.

  Returns:
    A merged summary op whose evaluation runs the PPO update steps (via
    tf.scan over a shuffled tf.data pipeline) and prints the losses.
  """
  observation, reward, done, action, old_pdf, value = memory
  # This is to avoid propagating gradients through simulated environment.
  observation = tf.stop_gradient(observation)
  action = tf.stop_gradient(action)
  reward = tf.stop_gradient(reward)
  if hasattr(hparams, "rewards_preprocessing_fun"):
    reward = hparams.rewards_preprocessing_fun(reward)
  done = tf.stop_gradient(done)
  value = tf.stop_gradient(value)
  old_pdf = tf.stop_gradient(old_pdf)
  advantage = calculate_generalized_advantage_estimator(
      reward, value, done, hparams.gae_gamma, hparams.gae_lambda)
  discounted_reward = tf.stop_gradient(advantage + value[:-1])
  advantage_mean, advantage_variance = tf.nn.moments(advantage, axes=[0, 1],
                                                     keep_dims=True)
  # Normalize advantages; epsilon keeps the division stable.
  advantage_normalized = tf.stop_gradient(
      (advantage - advantage_mean) / (tf.sqrt(advantage_variance) + 1e-8))
  add_lists_elementwise = lambda l1, l2: [x + y for x, y in zip(l1, l2)]
  # NOTE(review): true division here yields a float (other variants of this
  # function use // integer division) — confirm tf.range below is intended
  # to receive a float limit.
  number_of_batches = ((hparams.epoch_length - 1) * hparams.optimization_epochs
                       / hparams.optimization_batch_size)
  if hparams.effective_num_agents is not None:
    number_of_batches *= batch_size
    number_of_batches /= hparams.effective_num_agents
  # Shuffled, endlessly-repeating pipeline of fixed-size training batches.
  dataset = tf.data.Dataset.from_tensor_slices(
      (observation[:-1], action[:-1], discounted_reward,
       advantage_normalized, old_pdf[:-1]))
  dataset = dataset.shuffle(buffer_size=hparams.epoch_length - 1,
                            reshuffle_each_iteration=True)
  dataset = dataset.repeat(-1)
  dataset = dataset.batch(hparams.optimization_batch_size,
                          drop_remainder=True)
  iterator = dataset.make_initializable_iterator()
  lr = learning_rate.learning_rate_schedule(hparams)
  # Accumulate (policy_loss, value_loss, entropy_loss) over all batches.
  with tf.control_dependencies([iterator.initializer]):
    ppo_step_rets = tf.scan(
        lambda a, i: add_lists_elementwise(  # pylint: disable=g-long-lambda
            a,
            define_ppo_step(iterator.get_next(), hparams, action_space, lr)
        ),
        tf.range(number_of_batches),
        [0., 0., 0.],
        parallel_iterations=1)
  ppo_summaries = [
      tf.reduce_mean(ret) / number_of_batches for ret in ppo_step_rets
  ]
  ppo_summaries.append(lr)
  summaries_names = [
      "policy_loss", "value_loss", "entropy_loss", "learning_rate"
  ]
  summaries = [
      tf.summary.scalar(summary_name, summary)
      for summary_name, summary in zip(summaries_names, ppo_summaries)
  ]
  losses_summary = tf.summary.merge(summaries)
  # Also print each loss when the merged summary is evaluated.
  for summary_name, summary in zip(summaries_names, ppo_summaries):
    losses_summary = tf.Print(losses_summary, [summary], summary_name + ": ")
  return losses_summary
def get_learning_rate():
  """Return the LR schedule for the base transformer hparams."""
  return learning_rate_schedule(transformer.transformer_base())
def build_model(self):
    """Build the full training/eval graph on `self`.

    Creates the input pipeline placeholders, transformer encoders for
    queries and replies, negative sampling, an MLP prediction head, the
    loss/train op, and an accuracy metric. All graph handles are stored
    as attributes (self.input_queries, self.loss, self.train_step, ...).
    """
    # Vocabulary lookup table mapping tokens to ids (OOV falls back to 0).
    index_table = tf.contrib.lookup.index_table_from_file(
        vocabulary_file=self.config.vocab_list,
        num_oov_buckets=0,
        default_value=0)
    # Data iterator over the input pipeline for the current mode.
    self.data_iterator = self.data.get_data_iterator(index_table,
                                                     mode=self.mode)
    # Inputs: placeholders that default to the next pipeline batch, so data
    # can also be fed explicitly at session-run time.
    with tf.variable_scope("inputs"):
        next_batch = self.data_iterator.get_next()
        self.input_queries = tf.placeholder_with_default(
            next_batch["input_queries"], [None, self.config.max_length],
            name="input_queries")
        self.input_replies = tf.placeholder_with_default(
            next_batch["input_replies"], [None, self.config.max_length],
            name="input_replies")
        self.query_lengths = tf.placeholder_with_default(
            tf.squeeze(next_batch["query_lengths"]), [None],
            name="query_lengths")
        self.reply_lengths = tf.placeholder_with_default(
            tf.squeeze(next_batch["reply_lengths"]), [None],
            name="reply_lengths")
        # Dropout keep-probabilities and sampling count, fed at run time.
        # NOTE(review): embed_dropout_keep_prob is float64 while the other
        # keep-probs are float32 — confirm the asymmetry is intentional.
        self.embed_dropout_keep_prob = tf.placeholder(
            tf.float64, name="embed_dropout_keep_prob")
        self.lstm_dropout_keep_prob = tf.placeholder(
            tf.float32, name="lstm_dropout_keep_prob")
        self.dense_dropout_keep_prob = tf.placeholder(
            tf.float32, name="dense_dropout_keep_prob")
        self.num_negative_samples = tf.placeholder(
            tf.int32, name="num_negative_samples")
    with tf.variable_scope("properties"):
        # Dynamic batch size of the current run.
        cur_batch_length = tf.shape(self.input_queries)[0]
        # Hparams from tensor2tensor.models.transformer, with a custom
        # decay schedule driving the learning rate below.
        hparams = transformer.transformer_small()
        hparams.batch_size = self.config.batch_size
        hparams.learning_rate_decay_steps = 10000
        hparams.learning_rate_minimum = 3e-5
        # Learning rate schedule.
        lr = learning_rate.learning_rate_schedule(hparams)
        self.learning_rate = lr
    # Embedding layer: pretrained embeddings with row-wise dropout
    # (noise_shape drops whole embedding vectors at once).
    with tf.variable_scope("embedding"):
        embeddings = tf.Variable(get_embeddings(
            self.config.vocab_list, self.config.pretrained_embed_dir,
            self.config.vocab_size, self.config.embed_dim),
                                 trainable=True, name="embeddings")
        embeddings = tf.nn.dropout(
            embeddings,
            keep_prob=self.embed_dropout_keep_prob,
            noise_shape=[tf.shape(embeddings)[0], 1])
        queries_embedded = tf.to_float(
            tf.nn.embedding_lookup(embeddings, self.input_queries,
                                   name="queries_embedded"))
        replies_embedded = tf.to_float(
            tf.nn.embedding_lookup(embeddings, self.input_replies,
                                   name="replies_embedded"))
        self.queries_embedded = queries_embedded
        self.replies_embedded = replies_embedded
    # Transformer encoders; outputs are summed over the time axis and
    # squeezed into fixed-size sentence encodings.
    with tf.variable_scope("transformer"):
        queries_expanded = tf.expand_dims(queries_embedded, axis=2,
                                          name="queries_expanded")
        replies_expanded = tf.expand_dims(replies_embedded, axis=2,
                                          name="replies_expanded")
        hparams = transformer.transformer_small()
        hparams.set_hparam("batch_size", self.config.batch_size)
        hparams.set_hparam("hidden_size", self.config.embed_dim)
        encoder = transformer.TransformerEncoder(hparams, mode=self.mode)
        self.queries_encoded = encoder({
            "inputs": queries_expanded,
            "targets": queries_expanded
        })[0]
        self.replies_encoded = encoder({
            "inputs": replies_expanded,
            "targets": replies_expanded
        })[0]
        self.queries_encoded = tf.squeeze(
            tf.reduce_sum(self.queries_encoded, axis=1, keep_dims=True))
        self.replies_encoded = tf.squeeze(
            tf.reduce_sum(self.replies_encoded, axis=1, keep_dims=True))
    with tf.variable_scope("sampling"):
        # Positive pairs sit on the diagonal of the batch x batch score
        # matrix; negatives are selected by make_negative_mask.
        positive_mask = tf.eye(cur_batch_length)
        negative_mask = make_negative_mask(
            tf.zeros([cur_batch_length, cur_batch_length]),
            method=self.config.negative_sampling,
            num_negative_samples=self.num_negative_samples)
        negative_queries_indices, negative_replies_indices = tf.split(
            tf.where(tf.not_equal(negative_mask, 0)), [1, 1], 1)
        # Dot-product scores between every query and every reply encoding.
        self.distances = tf.matmul(self.queries_encoded,
                                   self.replies_encoded,
                                   transpose_b=True)
        self.distances_flattened = tf.reshape(self.distances, [-1])
        self.positive_distances = tf.gather(
            self.distances_flattened,
            tf.where(tf.reshape(positive_mask, [-1])))
        self.negative_distances = tf.gather(
            self.distances_flattened,
            tf.where(tf.reshape(negative_mask, [-1])))
        self.negative_queries_indices = tf.squeeze(negative_queries_indices)
        self.negative_replies_indices = tf.squeeze(negative_replies_indices)
        # Pair features: [query encoding, score, reply encoding].
        self.positive_inputs = tf.concat([
            self.queries_encoded, self.positive_distances,
            self.replies_encoded
        ], 1)
        self.negative_inputs = tf.reshape(
            tf.concat([
                tf.nn.embedding_lookup(self.queries_encoded,
                                       self.negative_queries_indices),
                self.negative_distances,
                tf.nn.embedding_lookup(self.replies_encoded,
                                       self.negative_replies_indices)
            ], 1), [
                tf.shape(negative_queries_indices)[0],
                self.config.embed_dim * 2 + 1
            ])
    with tf.variable_scope("prediction"):
        self.hidden_outputs = tf.layers.dense(tf.concat(
            [self.positive_inputs, self.negative_inputs], 0), 256,
                                              tf.nn.relu,
                                              name="hidden_layer")
        # NOTE(review): the output layer applies ReLU, so logits are
        # clipped at zero before the softmax cross-entropy — confirm this
        # is intended.
        self.logits = tf.layers.dense(self.hidden_outputs, 2, tf.nn.relu,
                                      name="output_layer")
        labels = tf.concat([
            tf.ones([tf.shape(self.positive_inputs)[0]], tf.float64),
            tf.zeros([tf.shape(self.negative_inputs)[0]], tf.float64)
        ], 0)
        self.labels = tf.one_hot(tf.to_int32(labels), 2)
        self.probs = tf.sigmoid(self.logits)
        self.predictions = tf.argmax(self.probs, 1)
    with tf.variable_scope("loss"):
        self.loss = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits_v2(labels=self.labels,
                                                       logits=self.logits))
        # Train op built by the tensor2tensor optimizer helper.
        self.train_step = optimize.optimize(self.loss, lr, hparams,
                                            use_tpu=False)
    with tf.variable_scope("score"):
        correct_predictions = tf.equal(self.predictions,
                                       tf.argmax(self.labels, 1))
        self.accuracy = tf.reduce_mean(tf.cast(correct_predictions,
                                               "float"),
                                       name="accuracy")
def define_ppo_epoch(memory, hparams, action_space, batch_size,
                     distributional_size=1, distributional_subscale=0.04,
                     distributional_threshold=0.0):
  """Build the ops for one PPO optimization epoch.

  Args:
    memory: tuple (observation, reward, done, action, old_pdf, value_sm)
      of tensors collected from the (simulated) environment.
    hparams: HParams with PPO settings (gae_gamma, gae_lambda, epoch_length,
      optimization_epochs, optimization_batch_size, effective_num_agents).
    action_space: environment action space, forwarded to define_ppo_step.
    batch_size: number of parallel agents in the memory.
    distributional_size: number of value buckets; 1 disables the
      distributional path.
    distributional_subscale: width of one value bucket.
    distributional_threshold: threshold forwarded to
      _distributional_to_value.

  Returns:
    A merged summary op whose evaluation runs the PPO update steps (via
    tf.scan) and prints the individual losses.
  """
  observation, reward, done, action, old_pdf, value_sm = memory
  # This is to avoid propagating gradients through simulated environment.
  observation = tf.stop_gradient(observation)
  action = tf.stop_gradient(action)
  reward = tf.stop_gradient(reward)
  if hasattr(hparams, "rewards_preprocessing_fun"):
    reward = hparams.rewards_preprocessing_fun(reward)
  done = tf.stop_gradient(done)
  value_sm = tf.stop_gradient(value_sm)
  old_pdf = tf.stop_gradient(old_pdf)
  value = value_sm
  if distributional_size > 1:
    # Collapse the per-bucket distribution into a scalar value estimate.
    value = _distributional_to_value(value_sm, distributional_size,
                                     distributional_subscale,
                                     distributional_threshold)
  plain_value = value
  # NOTE(review): comparing distributional_threshold (a float, default 0.0)
  # against 1 looks suspicious — confirm `> 0` was not intended.
  if distributional_threshold > 1:
    plain_value = _distributional_to_value(value_sm, distributional_size,
                                           distributional_subscale, 0.0)
  advantage = calculate_generalized_advantage_estimator(
      reward, value, done, hparams.gae_gamma, hparams.gae_lambda)
  discounted_reward = tf.stop_gradient(advantage + value[:-1])
  if distributional_size > 1:
    # Bootstrap discounted rewards from the unthresholded final values.
    end_values = plain_value[-1]
    discounted_reward = tf.stop_gradient(
        discounted_rewards(reward, done, hparams.gae_gamma, end_values))
  advantage_mean, advantage_variance = tf.nn.moments(advantage, axes=[0, 1],
                                                     keep_dims=True)
  # Normalize advantages; epsilon keeps the division stable.
  advantage_normalized = tf.stop_gradient(
      (advantage - advantage_mean) / (tf.sqrt(advantage_variance) + 1e-8))
  add_lists_elementwise = lambda l1, l2: [x + y for x, y in zip(l1, l2)]
  number_of_batches = ((hparams.epoch_length - 1) * hparams.optimization_epochs
                       // hparams.optimization_batch_size)
  epoch_length = hparams.epoch_length
  if hparams.effective_num_agents is not None:
    number_of_batches *= batch_size
    number_of_batches //= hparams.effective_num_agents
    epoch_length //= hparams.effective_num_agents
  # Fixed typo in the assertion message ("paremeters" -> "parameters").
  assert number_of_batches > 0, "Set the parameters so that number_of_batches>0"
  lr = learning_rate.learning_rate_schedule(hparams)
  # One independent shuffle of time indices per optimization epoch.
  shuffled_indices = [
      tf.random.shuffle(tf.range(epoch_length - 1))
      for _ in range(hparams.optimization_epochs)
  ]
  shuffled_indices = tf.concat(shuffled_indices, axis=0)
  shuffled_indices = shuffled_indices[:number_of_batches *
                                      hparams.optimization_batch_size]
  indices_of_batches = tf.reshape(shuffled_indices,
                                  shape=(-1, hparams.optimization_batch_size))
  input_tensors = [
      observation, action, discounted_reward, advantage_normalized, old_pdf
  ]
  # Accumulate (policy_loss, value_loss, entropy_loss) over all batches.
  ppo_step_rets = tf.scan(
      lambda a, i: add_lists_elementwise(  # pylint: disable=g-long-lambda
          a,
          define_ppo_step([
              tf.gather(t, indices_of_batches[i, :]) for t in input_tensors
          ],
                          hparams,
                          action_space,
                          lr,
                          distributional_size=distributional_size,
                          distributional_subscale=distributional_subscale)),
      tf.range(number_of_batches),
      [0., 0., 0.],
      parallel_iterations=1)
  ppo_summaries = [
      tf.reduce_mean(ret) / number_of_batches for ret in ppo_step_rets
  ]
  ppo_summaries.append(lr)
  summaries_names = [
      "policy_loss", "value_loss", "entropy_loss", "learning_rate"
  ]
  summaries = [
      tf.summary.scalar(summary_name, summary)
      for summary_name, summary in zip(summaries_names, ppo_summaries)
  ]
  losses_summary = tf.summary.merge(summaries)
  # Also print each loss when the merged summary is evaluated.
  for summary_name, summary in zip(summaries_names, ppo_summaries):
    losses_summary = tf.Print(losses_summary, [summary], summary_name + ": ")
  return losses_summary
def estimator_model_fn(cls,
                       hparams,
                       features,
                       labels,
                       mode,
                       config=None,
                       params=None,
                       decode_hparams=None,
                       use_tpu=False):
  """Model fn for tf.estimator that builds a Mesh TensorFlow graph.

  Handles the PREDICT / TRAIN / EVAL modes: builds the mtf graph, lowers
  it to plain TF ops, and returns the matching EstimatorSpec.

  Args:
    cls: the model class to instantiate.
    hparams: HParams; deep-copied before any mutation.
    features: dict of input features passed to the model.
    labels: labels forwarded to the eval spec.
    mode: a tf.estimator.ModeKeys value.
    config: optional RunConfig; supplies data_parallelism off-TPU.
    params: optional dict; on TPU must contain "context".
    decode_hparams: optional decoding HParams merged into hparams for
      PREDICT mode.
    use_tpu: whether to build the SIMD (TPU) mesh implementation.

  Returns:
    A tf.estimator.EstimatorSpec or tpu_estimator.TPUEstimatorSpec.
  """
  hparams = copy.deepcopy(hparams)
  hparams.use_tpu = use_tpu
  # Merge decode_hparams into hparams if present.
  if mode == tf.estimator.ModeKeys.PREDICT and decode_hparams is not None:
    for k, v in six.iteritems(decode_hparams.values()):
      if hasattr(hparams, k) and getattr(hparams, k) != v:
        tf.logging.warning("Overriding hparams.%s with %s from decode_hparams"
                           % (k, v))
        setattr(hparams, k, v)
  # Instantiate model.
  data_parallelism = None
  if not use_tpu and config:
    data_parallelism = config.data_parallelism
  model = cls(
      hparams,
      mode,
      data_parallelism=data_parallelism,
      decode_hparams=decode_hparams)
  global_step = tf.train.get_global_step()
  mesh_shape = mtf.convert_to_shape(hparams.mesh_shape)
  layout_rules = mtf.convert_to_layout_rules(hparams.layout)
  if use_tpu:
    ctx = params["context"]
    num_hosts = ctx.num_hosts
    host_placement_fn = ctx.tpu_host_placement_function
    device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
    # TODO(ylc): Better estimation of replica cache size?
    replica_cache_size = 300 * 1000000  # 300M per replica.
    # Worker 0 caches all the TPU binaries.
    worker0_mem = replica_cache_size * ctx.num_replicas
    devices_memeory_usage = [worker0_mem] + [0] * (num_hosts - 1)
    var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                  devices_memeory_usage)
    mesh_devices = [""] * mesh_shape.size
    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
        mesh_shape, layout_rules, mesh_devices, ctx.device_assignment)
  else:
    var_placer = None
    # Without data parallelism (or with a single ps device) let TF place
    # the mesh on default devices.
    if data_parallelism is None or len(data_parallelism.ps_devices) == 1:
      mesh_devices = [""] * mesh_shape.size
    else:
      assert len(data_parallelism.ps_devices) == mesh_shape.size
      mesh_devices = data_parallelism.ps_devices
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, mesh_devices)
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh", var_placer)
  # PREDICT mode.
  if mode == tf.estimator.ModeKeys.PREDICT:
    return model.estimator_spec_predict(features, mesh, mesh_impl, use_tpu)
  logits, loss = model.mtf_model_fn(features, mesh)
  if use_tpu and logits is not None:
    logits = mtf.anonymize(logits)
  # TRAIN mode: build gradients and update ops in mtf space.
  if mode == tf.estimator.ModeKeys.TRAIN:
    var_grads = mtf.gradients(
        [loss], [v.outputs[0] for v in graph.trainable_variables])
    lr = learning_rate.learning_rate_schedule(hparams)
    tf.summary.scalar("learning_rate", lr)
    mtf_lr = mtf.import_tf_tensor(
        mesh, tf.convert_to_tensor(lr, dtype=tf.float32), mtf.Shape([]))
    optimizer = mtf.optimize.make_optimizer(hparams, mtf_lr)
    update_ops = []
    for grad, var in zip(var_grads, graph.trainable_variables):
      update_ops.extend(optimizer.apply_grad(grad, var))
  # Lower the mtf graph to plain TensorFlow ops.
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  tf_loss = lowering.export_to_tf_tensor(loss)
  tf_loss = tf.to_float(tf_loss)
  if logits and mode != tf.estimator.ModeKeys.TRAIN:
    tf_logits = lowering.export_to_tf_tensor(logits)
  if mode == tf.estimator.ModeKeys.TRAIN:
    tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
    # Advance the global step alongside the lowered update ops.
    tf_update_ops.append(tf.assign_add(global_step, 1))
    # tf.logging.info("tf_update_ops: {}".format(tf_update_ops))
    train_op = tf.group(tf_update_ops)
  with mtf.utils.outside_all_rewrites():
    # Copy master variables to slices. Must be called first.
    restore_hook = mtf.MtfRestoreHook(lowering)
    saver = tf.train.Saver(
        tf.global_variables(),
        sharded=True,
        max_to_keep=10,
        keep_checkpoint_every_n_hours=2,
        defer_build=False,
        save_relative_paths=True)
    tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
    saver_listener = mtf.MtfCheckpointSaverListener(lowering)
    saver_hook = tf.train.CheckpointSaverHook(
        hparams.model_dir,
        save_steps=1000,
        saver=saver,
        listeners=[saver_listener])
    # EVAL mode.
    if mode == tf.estimator.ModeKeys.EVAL:
      tf_logits = lowering.export_to_tf_tensor(logits)
      return model.estimator_spec_eval(features, tf_logits, labels, tf_loss,
                                       restore_hook, use_tpu)
    if use_tpu:
      # TPU host call. Important: need to be called before remove_summaries()
      if hparams.tpu_enable_host_call:
        host_call = t2t_model.create_host_call(hparams.model_dir)
      else:
        host_call = None
      t2t_model.remove_summaries()
      return tpu_estimator.TPUEstimatorSpec(
          mode=tf.estimator.ModeKeys.TRAIN,
          loss=tf_loss,
          train_op=train_op,
          host_call=host_call,
          training_hooks=[restore_hook, saver_hook])
    else:
      return tf.estimator.EstimatorSpec(
          tf.estimator.ModeKeys.TRAIN,
          loss=tf_loss,
          train_op=train_op,
          training_chief_hooks=[restore_hook, saver_hook])
def build_model(self):
    """Build the full training/eval graph on `self`.

    Creates the input pipeline placeholders, transformer encoders pooled
    into fixed-size vectors, a bilinear similarity (query' M reply),
    negative sampling, a sigmoid cross-entropy loss/train op, and an
    accuracy metric. All graph handles are stored as attributes.
    """
    # Vocabulary lookup table mapping tokens to ids (OOV falls back to 0).
    index_table = tf.contrib.lookup.index_table_from_file(
        vocabulary_file=self.config.vocab_list,
        num_oov_buckets=0,
        default_value=0)
    # Data iterator over the input pipeline for the current mode.
    self.data_iterator = self.data.get_data_iterator(index_table,
                                                     mode=self.mode)
    # Inputs: placeholders that default to the next pipeline batch, so data
    # can also be fed explicitly at session-run time.
    with tf.variable_scope("inputs"):
        next_batch = self.data_iterator.get_next()
        self.input_queries = tf.placeholder_with_default(
            next_batch["input_queries"], [None, self.config.max_length],
            name="input_queries")
        self.input_replies = tf.placeholder_with_default(
            next_batch["input_replies"], [None, self.config.max_length],
            name="input_replies")
        self.query_lengths = tf.placeholder_with_default(
            tf.squeeze(next_batch["query_lengths"]), [None],
            name="query_lengths")
        self.reply_lengths = tf.placeholder_with_default(
            tf.squeeze(next_batch["reply_lengths"]), [None],
            name="reply_lengths")
        # Dropout keep-probabilities and sampling count, fed at run time.
        # NOTE(review): embed_dropout_keep_prob is float64 while the other
        # keep-probs are float32 — confirm the asymmetry is intentional.
        self.embed_dropout_keep_prob = tf.placeholder(
            tf.float64, name="embed_dropout_keep_prob")
        self.lstm_dropout_keep_prob = tf.placeholder(
            tf.float32, name="lstm_dropout_keep_prob")
        self.dense_dropout_keep_prob = tf.placeholder(
            tf.float32, name="dense_dropout_keep_prob")
        self.num_negative_samples = tf.placeholder(
            tf.int32, name="num_negative_samples")
    with tf.variable_scope("properties"):
        # Dynamic batch size of the current run.
        cur_batch_length = tf.shape(self.input_queries)[0]
        # Hparams from tensor2tensor.models.transformer drive the learning
        # rate schedule below.
        hparams = transformer.transformer_small()
        hparams.batch_size = self.config.batch_size
        # Learning rate schedule.
        lr = learning_rate.learning_rate_schedule(hparams)
    # Embedding layer: pretrained embeddings with row-wise dropout
    # (noise_shape drops whole embedding vectors at once).
    with tf.variable_scope("embedding"):
        embeddings = tf.Variable(get_embeddings(
            self.config.vocab_list, self.config.pretrained_embed_dir,
            self.config.vocab_size, self.config.embed_dim),
                                 trainable=True, name="embeddings")
        embeddings = tf.nn.dropout(
            embeddings,
            keep_prob=self.embed_dropout_keep_prob,
            noise_shape=[tf.shape(embeddings)[0], 1])
        queries_embedded = tf.to_float(
            tf.nn.embedding_lookup(embeddings, self.input_queries,
                                   name="queries_embedded"))
        replies_embedded = tf.to_float(
            tf.nn.embedding_lookup(embeddings, self.input_replies,
                                   name="replies_embedded"))
        self.queries_embedded = queries_embedded
        self.replies_embedded = replies_embedded
    # Transformer encoders; outputs are max-pooled over the time axis and
    # flattened into fixed-size vectors.
    with tf.variable_scope("transformer"):
        queries_expanded = tf.expand_dims(queries_embedded, axis=2,
                                          name="queries_expanded")
        replies_expanded = tf.expand_dims(replies_embedded, axis=2,
                                          name="replies_expanded")
        hparams = transformer.transformer_small()
        hparams.set_hparam("batch_size", self.config.batch_size)
        hparams.set_hparam("hidden_size", self.config.embed_dim)
        encoder = transformer.TransformerEncoder(hparams, mode=self.mode)
        self.queries_encoded = encoder({
            "inputs": queries_expanded,
            "targets": queries_expanded
        })[0]
        self.replies_encoded = encoder({
            "inputs": replies_expanded,
            "targets": replies_expanded
        })[0]
        self.queries_pooled = tf.nn.max_pool(
            self.queries_encoded,
            ksize=[1, self.config.max_length, 1, 1],
            strides=[1, 1, 1, 1],
            padding='VALID',
            name="queries_pooled")
        self.replies_pooled = tf.nn.max_pool(
            self.replies_encoded,
            ksize=[1, self.config.max_length, 1, 1],
            strides=[1, 1, 1, 1],
            padding='VALID',
            name="replies_pooled")
        self.queries_flattened = tf.reshape(self.queries_pooled,
                                            [cur_batch_length, -1])
        self.replies_flattened = tf.reshape(self.replies_pooled,
                                            [cur_batch_length, -1])
    # Bilinear transform: queries are projected through a learned matrix M
    # (with dropout) before being scored against replies.
    with tf.variable_scope("dense_layer"):
        M = tf.get_variable(
            "M",
            shape=[self.config.embed_dim, self.config.embed_dim],
            initializer=tf.initializers.truncated_normal())
        M = tf.nn.dropout(M, self.dense_dropout_keep_prob)
        self.queries_transformed = tf.matmul(self.queries_flattened, M)
    with tf.variable_scope("sampling"):
        # Positive pairs sit on the diagonal of the batch x batch score
        # matrix; negatives are selected by make_negative_mask.
        self.distances = tf.matmul(self.queries_transformed,
                                   self.replies_flattened,
                                   transpose_b=True)
        positive_mask = tf.reshape(tf.eye(cur_batch_length), [-1])
        negative_mask = tf.reshape(
            make_negative_mask(
                self.distances,
                method=self.config.negative_sampling,
                num_negative_samples=self.num_negative_samples), [-1])
    with tf.variable_scope("prediction"):
        distances_flattened = tf.reshape(self.distances, [-1])
        # NOTE(review): the third positional argument of tf.gather binds to
        # validate_indices in TF1, not axis — confirm `1` here is intended.
        self.positive_logits = tf.gather(distances_flattened,
                                         tf.where(positive_mask), 1)
        self.negative_logits = tf.gather(distances_flattened,
                                         tf.where(negative_mask), 1)
        self.logits = tf.concat(
            [self.positive_logits, self.negative_logits], axis=0)
        self.labels = tf.concat([
            tf.ones_like(self.positive_logits),
            tf.zeros_like(self.negative_logits)
        ], axis=0)
        self.positive_probs = tf.sigmoid(self.positive_logits)
        self.probs = tf.sigmoid(self.logits)
        self.predictions = tf.cast(self.probs > 0.5, dtype=tf.int32)
    with tf.variable_scope("loss"):
        self.loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(labels=self.labels,
                                                    logits=self.logits))
        # Train op built by the tensor2tensor optimizer helper.
        self.train_step = optimize.optimize(self.loss, lr, hparams,
                                            use_tpu=False)
    with tf.variable_scope("score"):
        correct_predictions = tf.equal(self.predictions,
                                       tf.to_int32(self.labels))
        self.accuracy = tf.reduce_mean(tf.cast(correct_predictions,
                                               "float"),
                                       name="accuracy")