Example #1
    def __init__(self, lstm, update_rate):
        self.lstm = lstm
        self.update_rate = update_rate

        self.num_emb = self.lstm.num_emb
        self.batch_size = self.lstm.batch_size
        self.emb_dim = self.lstm.emb_dim
        self.hidden_dim = self.lstm.hidden_dim
        self.sequence_length = self.lstm.sequence_length
        self.start_token = tf.identity(self.lstm.start_token)
        self.learning_rate = self.lstm.learning_rate

        self.g_embeddings = tf.identity(self.lstm.g_embeddings)
        # maps h_tm1 to h_t for generator
        self.g_recurrent_unit = self.create_recurrent_unit()
        # maps h_t to o_t (output token logits)
        self.g_output_unit = self.create_output_unit()

        #######################################################################
        # placeholder definition
        # self.x: sequence of token indices produced by the generator,
        # not including the start token
        self.x = tf.placeholder(tf.int32,
                                shape=[self.batch_size, self.sequence_length])
        # number of leading tokens of self.x to keep during the rollout
        self.given_num = tf.placeholder(tf.int32)

        # processed for batch
        with tf.device("/cpu:0"):
            inputs = tf.split(axis=1,
                              num_or_size_splits=self.sequence_length,
                              value=tf.nn.embedding_lookup(
                                  self.g_embeddings, self.x))
            self.processed_x = tf.stack([
                tf.squeeze(input_, [1]) for input_ in inputs
            ])  # seq_length x batch_size x emb_dim

        ta_emb_x = tensor_array_ops.TensorArray(dtype=tf.float32,
                                                size=self.sequence_length)
        ta_emb_x = ta_emb_x.unstack(self.processed_x)

        ta_x = tensor_array_ops.TensorArray(dtype=tf.int32,
                                            size=self.sequence_length)
        ta_x = ta_x.unstack(tf.transpose(self.x, perm=[1, 0]))
        #######################################################################

        self.h0 = tf.zeros([self.batch_size, self.hidden_dim])
        self.h0 = tf.stack([self.h0, self.h0])

        gen_x = tensor_array_ops.TensorArray(dtype=tf.int32,
                                             size=self.sequence_length,
                                             dynamic_size=False,
                                             infer_shape=True)

        def _g_recurrence_1(i, x_t, h_tm1, given_num, gen_x):
            h_t = self.g_recurrent_unit(x_t, h_tm1)  # hidden_memory_tuple
            x_tp1 = ta_emb_x.read(i)
            gen_x = gen_x.write(i, ta_x.read(i))
            return i + 1, x_tp1, h_t, given_num, gen_x

        def _g_recurrence_2(i, x_t, h_tm1, given_num, gen_x):
            h_t = self.g_recurrent_unit(x_t, h_tm1)  # hidden_memory_tuple
            o_t = self.g_output_unit(h_t)  # batch x vocab , logits not prob
            log_prob = tf.log(tf.nn.softmax(o_t))
            next_token = tf.cast(
                tf.reshape(tf.multinomial(log_prob, 1), [self.batch_size]),
                tf.int32)
            x_tp1 = tf.nn.embedding_lookup(self.g_embeddings,
                                           next_token)  # batch x emb_dim
            gen_x = gen_x.write(i, next_token)  # indices, batch_size
            return i + 1, x_tp1, h_t, given_num, gen_x

        i, x_t, h_tm1, given_num, self.gen_x = control_flow_ops.while_loop(
            cond=lambda i, _1, _2, given_num, _4: i < given_num,
            body=_g_recurrence_1,
            loop_vars=(tf.constant(0, dtype=tf.int32),
                       tf.nn.embedding_lookup(self.g_embeddings,
                                              self.start_token), self.h0,
                       self.given_num, gen_x))

        _, _, _, _, self.gen_x = control_flow_ops.while_loop(
            cond=lambda i, _1, _2, _3, _4: i < self.sequence_length,
            body=_g_recurrence_2,
            loop_vars=(i, x_t, h_tm1, given_num, self.gen_x))

        self.gen_x = self.gen_x.stack()  # seq_length x batch_size
        # batch_size x seq_length
        self.gen_x = tf.transpose(self.gen_x, perm=[1, 0])
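
The two `while_loop`s above implement the Monte-Carlo rollout used for SeqGAN-style reward estimation: the first loop replays the first `given_num` tokens of `self.x` while threading the hidden state, and the second loop samples the remaining tokens from the generator. A minimal NumPy sketch of the same control flow is shown below; `step` and `embed` are hypothetical stand-ins for the recurrent unit, output unit, and embedding lookup.

```python
import numpy as np

def rollout_complete(tokens, given_num, step, embed, h0, seq_len,
                     start_id=0, rng=np.random):
    """Keep the first `given_num` tokens of `tokens`, sample the rest.

    step:  hypothetical callable (x_emb, h) -> (logits [batch, vocab], h_next)
    embed: hypothetical callable token_ids [batch] -> embeddings [batch, emb_dim]
    """
    batch = tokens.shape[0]
    out = np.zeros((batch, seq_len), dtype=np.int64)
    h = h0
    x = embed(np.full(batch, start_id, dtype=np.int64))  # start-token embedding
    for i in range(seq_len):
        logits, h = step(x, h)
        if i < given_num:
            out[:, i] = tokens[:, i]          # phase 1: copy the given prefix
        else:
            p = np.exp(logits - logits.max(axis=1, keepdims=True))
            p /= p.sum(axis=1, keepdims=True)
            out[:, i] = [rng.choice(len(row), p=row) for row in p]  # phase 2: sample
        x = embed(out[:, i])                  # feed the chosen token back in
    return out
```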
Example #2
def hessians(ys,
             xs,
             name="hessians",
             colocate_gradients_with_ops=False,
             gate_gradients=False,
             aggregation_method=None):
    """Constructs the Hessian of sum of `ys` with respect to `x` in `xs`.

  `hessians()` adds ops to the graph to output the Hessian matrix of `ys`
  with respect to `xs`.  It returns a list of `Tensor` of length `len(xs)`
  where each tensor is the Hessian of `sum(ys)`. This function currently
  only supports evaluating the Hessian with respect to (a list of) one-
  dimensional tensors.

  The Hessian is a matrix of second-order partial derivatives of a scalar
  tensor (see https://en.wikipedia.org/wiki/Hessian_matrix for more details).

  Args:
    ys: A `Tensor` or list of tensors to be differentiated.
    xs: A `Tensor` or list of tensors to be used for differentiation.
    name: Optional name to use for grouping all the gradient ops together.
      Defaults to 'hessians'.
    colocate_gradients_with_ops: See `gradients()` documentation for details.
    gate_gradients: See `gradients()` documentation for details.
    aggregation_method: See `gradients()` documentation for details.

  Returns:
    A list of Hessian matrices of `sum(ys)` for each `x` in `xs`.

  Raises:
    LookupError: if one of the operations between `xs` and `ys` does not
      have a registered gradient function.
  """
    xs = _AsList(xs)
    kwargs = {
        'colocate_gradients_with_ops': colocate_gradients_with_ops,
        'gate_gradients': gate_gradients,
        'aggregation_method': aggregation_method
    }
    # Compute first-order derivatives and iterate for each x in xs.
    hessians = []
    _gradients = gradients(ys, xs, **kwargs)
    for i, (_gradient, x) in enumerate(zip(_gradients, xs)):
        # Ensure that x is a vector.
        check_rank = check_ops.assert_rank(
            x,
            1,
            message='Cannot compute Hessian because element %d of `xs` does '
            'not have rank one.' % i)
        with ops.control_dependencies([check_rank]):
            # Declare an iterator and tensor array loop variables for the gradients.
            n = array_ops.size(x)
            loop_vars = [
                array_ops.constant(0, dtypes.int32),
                tensor_array_ops.TensorArray(x.dtype, n)
            ]
            # Iterate over all elements of the gradient and compute second order
            # derivatives.
            _, hessian = control_flow_ops.while_loop(
                lambda j, _: j < n, lambda j, result:
                (j + 1, result.write(j,
                                     gradients(_gradient[j], x)[0])),
                loop_vars)

            hessians.append(hessian.stack())
    return hessians
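
As a quick sanity check of the API described in the docstring, the Hessian of the quadratic form y = xᵀAx is A + Aᵀ. A minimal graph-mode sketch, assuming TF 1.x and that this `hessians` matches the public `tf.hessians`:

```python
import numpy as np
import tensorflow as tf  # assumes TF 1.x graph mode

A = np.array([[1.0, 2.0], [0.0, 3.0]], dtype=np.float32)
x = tf.constant([1.0, -1.0])
y = tf.tensordot(x, tf.tensordot(tf.constant(A), x, axes=1), axes=1)  # x^T A x

hess = tf.hessians(y, x)[0]  # expected: A + A^T

with tf.Session() as sess:
    print(sess.run(hess))    # [[2. 2.]
                             #  [2. 6.]]
```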
Example #3
    def __init__(self,
                 sequence_length,
                 num_classes,
                 vocab_size,
                 emb_dim,
                 dis_emb_dim,
                 filter_sizes,
                 num_filters,
                 batch_size,
                 hidden_dim,
                 start_token,
                 goal_out_size,
                 goal_size,
                 step_size,
                 D_model,
                 LSTMlayer_num=1,
                 l2_reg_lambda=0.0,
                 learning_rate=0.001):
        self.sequence_length = sequence_length
        self.num_classes = num_classes
        self.vocab_size = vocab_size
        self.emb_dim = emb_dim
        self.dis_emb_dim = dis_emb_dim
        self.filter_sizes = filter_sizes
        self.num_filters = num_filters
        self.batch_size = batch_size
        self.hidden_dim = hidden_dim
        self.start_token = tf.constant([start_token] * self.batch_size,
                                       dtype=tf.int32)
        self.LSTMlayer_num = LSTMlayer_num
        self.l2_reg_lambda = l2_reg_lambda
        self.learning_rate = learning_rate
        self.num_filters_total = sum(self.num_filters)
        self.grad_clip = 5.0
        self.goal_out_size = goal_out_size
        self.goal_size = goal_size
        self.step_size = step_size
        self.D_model = D_model
        self.FeatureExtractor_unit = self.D_model.FeatureExtractor_unit

        self.scope = self.D_model.feature_scope
        self.worker_params = []
        self.manager_params = []

        self.epis = 0.65
        self.tem = 0.9  # logit scaling applied when sampling during training (see _g_recurrence)
        with tf.variable_scope('place_holder'):
            self.x = tf.placeholder(
                tf.int32,
                shape=[self.batch_size, self.sequence_length
                       ])  # sequence of tokens generated by generator
            self.reward = tf.placeholder(
                tf.float32,
                shape=[self.batch_size,
                       self.sequence_length // self.step_size
                       ])  # rollout rewards, one per Manager step
            self.given_num = tf.placeholder(tf.int32)
            self.drop_out = tf.placeholder(tf.float32,
                                           name="dropout_keep_prob")
            self.train = tf.placeholder(tf.int32, None, name="train")

        with tf.variable_scope('Worker'):
            self.g_embeddings = tf.Variable(
                tf.random_normal([self.vocab_size, self.emb_dim], stddev=0.1))
            self.worker_params.append(self.g_embeddings)
            self.g_worker_recurrent_unit = self.create_Worker_recurrent_unit(
                self.worker_params)  # maps h_tm1 to h_t for generator
            self.g_worker_output_unit = self.create_Worker_output_unit(
                self.worker_params)  # maps h_t to o_t (output token logits)
            self.W_workerOut_change = tf.Variable(
                tf.random_normal([self.vocab_size, self.goal_size],
                                 stddev=0.1))

            self.g_change = tf.Variable(
                tf.random_normal([self.goal_out_size, self.goal_size],
                                 stddev=0.1))
            self.worker_params.extend([self.W_workerOut_change, self.g_change])

            self.h0_worker = tf.zeros([self.batch_size, self.hidden_dim])
            self.h0_worker = tf.stack([self.h0_worker, self.h0_worker])

        with tf.variable_scope('Manager'):
            self.g_manager_recurrent_unit = self.create_Manager_recurrent_unit(
                self.manager_params)  # maps h_tm1 to h_t for generator
            self.g_manager_output_unit = self.create_Manager_output_unit(
                self.manager_params)  # maps h_t to o_t (output token logits)
            self.h0_manager = tf.zeros([self.batch_size, self.hidden_dim])
            self.h0_manager = tf.stack([self.h0_manager, self.h0_manager])

            self.goal_init = tf.get_variable(
                "goal_init",
                initializer=tf.truncated_normal(
                    [self.batch_size, self.goal_out_size], stddev=0.1))
            self.manager_params.extend([self.goal_init])

        self.padding_array = tf.constant(
            -1, shape=[self.batch_size, self.sequence_length], dtype=tf.int32)

        with tf.name_scope("roll_out"):
            self.gen_for_reward = self.rollout(self.x, self.given_num)

        # processed for batch
        with tf.device("/cpu:0"):
            self.processed_x = tf.transpose(
                tf.nn.embedding_lookup(self.g_embeddings, self.x),
                perm=[1, 0, 2])  # seq_length x batch_size x emb_dim

        gen_o = tensor_array_ops.TensorArray(dtype=tf.float32,
                                             size=self.sequence_length,
                                             dynamic_size=False,
                                             infer_shape=True)
        gen_x = tensor_array_ops.TensorArray(dtype=tf.int32,
                                             size=1,
                                             dynamic_size=True,
                                             infer_shape=True,
                                             clear_after_read=False)

        goal = tensor_array_ops.TensorArray(dtype=tf.float32,
                                            size=self.sequence_length,
                                            dynamic_size=False,
                                            infer_shape=True,
                                            clear_after_read=False)

        feature_array = tensor_array_ops.TensorArray(
            dtype=tf.float32,
            size=self.sequence_length + 1,
            dynamic_size=False,
            infer_shape=True,
            clear_after_read=False)
        real_goal_array = tensor_array_ops.TensorArray(
            dtype=tf.float32,
            size=self.sequence_length // self.step_size,
            dynamic_size=False,
            infer_shape=True,
            clear_after_read=False)

        gen_real_goal_array = tensor_array_ops.TensorArray(
            dtype=tf.float32,
            size=self.sequence_length,
            dynamic_size=False,
            infer_shape=True,
            clear_after_read=False)

        gen_o_worker_array = tensor_array_ops.TensorArray(
            dtype=tf.float32,
            size=self.sequence_length // self.step_size,
            dynamic_size=False,
            infer_shape=True,
            clear_after_read=False)

        def _g_recurrence(i, x_t, h_tm1, h_tm1_manager, gen_o, gen_x, goal,
                          last_goal, real_goal, step_size, gen_real_goal_array,
                          gen_o_worker_array):
            # pad the generated prefix with -1 up to sequence_length
            cur_sen = tf.cond(
                i > 0,
                lambda: tf.split(
                    tf.concat([tf.transpose(gen_x.stack(), perm=[1, 0]),
                               self.padding_array], 1),
                    [self.sequence_length, i], 1)[0],
                lambda: self.padding_array)
            with tf.variable_scope(self.scope):
                feature = self.FeatureExtractor_unit(cur_sen, self.drop_out)
            h_t_Worker = self.g_worker_recurrent_unit(
                x_t, h_tm1)  # hidden_memory_tuple
            o_t_Worker = self.g_worker_output_unit(
                h_t_Worker)  # batch x vocab , logits not prob
            o_t_Worker = tf.reshape(
                o_t_Worker, [self.batch_size, self.vocab_size, self.goal_size])

            h_t_manager = self.g_manager_recurrent_unit(feature, h_tm1_manager)
            sub_goal = self.g_manager_output_unit(h_t_manager)
            sub_goal = tf.nn.l2_normalize(sub_goal, 1)
            goal = goal.write(i, sub_goal)

            real_sub_goal = tf.add(last_goal, sub_goal)

            w_g = tf.matmul(real_goal, self.g_change)  #batch x goal_size
            w_g = tf.nn.l2_normalize(w_g, 1)
            gen_real_goal_array = gen_real_goal_array.write(i, real_goal)

            w_g = tf.expand_dims(w_g, 2)  #batch x goal_size x 1

            gen_o_worker_array = gen_o_worker_array.write(i, o_t_Worker)

            x_logits = tf.matmul(o_t_Worker, w_g)
            x_logits = tf.squeeze(x_logits)

            log_prob = tf.log(
                tf.nn.softmax(
                    tf.cond(
                        i > 1,
                        lambda: tf.cond(self.train > 0,
                                        lambda: self.tem, lambda: 1.5),
                        lambda: 1.5) * x_logits))
            next_token = tf.cast(
                tf.reshape(tf.multinomial(log_prob, 1), [self.batch_size]),
                tf.int32)
            x_tp1 = tf.nn.embedding_lookup(self.g_embeddings,
                                           next_token)  # batch x emb_dim
            with tf.control_dependencies([cur_sen]):
                gen_x = gen_x.write(i, next_token)  # indices, batch_size
            gen_o = gen_o.write(i,
                                tf.reduce_sum(
                                    tf.multiply(
                                        tf.one_hot(next_token, self.vocab_size,
                                                   1.0, 0.0),
                                        tf.nn.softmax(x_logits)),
                                    1))  # [batch_size] , prob
            return (i + 1, x_tp1, h_t_Worker, h_t_manager, gen_o, gen_x, goal,
                    tf.cond(((i + 1) % step_size) > 0, lambda: real_sub_goal,
                            lambda: tf.constant(
                                0.0,
                                shape=[self.batch_size, self.goal_out_size])),
                    tf.cond(((i + 1) % step_size) > 0, lambda: real_goal,
                            lambda: real_sub_goal), step_size,
                    gen_real_goal_array, gen_o_worker_array)

        _, _, _, _, self.gen_o, self.gen_x, _, _, _, _, self.gen_real_goal_array, self.gen_o_worker_array = control_flow_ops.while_loop(
            cond=lambda i, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11: i <
            self.sequence_length,
            body=_g_recurrence,
            loop_vars=(tf.constant(0, dtype=tf.int32),
                       tf.nn.embedding_lookup(self.g_embeddings,
                                              self.start_token),
                       self.h0_worker, self.h0_manager, gen_o, gen_x, goal,
                       tf.zeros([self.batch_size,
                                 self.goal_out_size]), self.goal_init,
                       step_size, gen_real_goal_array, gen_o_worker_array),
            parallel_iterations=1)

        self.gen_x = self.gen_x.stack()  # seq_length x batch_size

        self.gen_x = tf.transpose(self.gen_x,
                                  perm=[1, 0])  # batch_size x seq_length

        self.gen_real_goal_array = self.gen_real_goal_array.stack(
        )  # seq_length x batch_size x goal

        self.gen_real_goal_array = tf.transpose(
            self.gen_real_goal_array,
            perm=[1, 0, 2])  # batch_size x seq_length x goal

        self.gen_o_worker_array = self.gen_o_worker_array.stack(
        )  # seq_length x batch_size* vocab*goal

        self.gen_o_worker_array = tf.transpose(
            self.gen_o_worker_array,
            perm=[1, 0, 2, 3])  # batch_size x seq_length * vocab*goal

        sub_feature = tensor_array_ops.TensorArray(
            dtype=tf.float32,
            size=self.sequence_length // self.step_size,
            dynamic_size=False,
            infer_shape=True,
            clear_after_read=False)

        all_sub_features = tensor_array_ops.TensorArray(
            dtype=tf.float32,
            size=self.sequence_length,
            dynamic_size=False,
            infer_shape=True,
            clear_after_read=False)
        all_sub_goals = tensor_array_ops.TensorArray(dtype=tf.float32,
                                                     size=self.sequence_length,
                                                     dynamic_size=False,
                                                     infer_shape=True,
                                                     clear_after_read=False)

        # supervised pretraining for generator
        g_predictions = tensor_array_ops.TensorArray(dtype=tf.float32,
                                                     size=self.sequence_length,
                                                     dynamic_size=False,
                                                     infer_shape=True)
        ta_emb_x = tensor_array_ops.TensorArray(dtype=tf.float32,
                                                size=self.sequence_length)
        ta_emb_x = ta_emb_x.unstack(self.processed_x)

        def preTrain(i, x_t, g_predictions, h_tm1, input_x, h_tm1_manager,
                     last_goal, real_goal, feature_array, real_goal_array,
                     sub_feature, all_sub_features, all_sub_goals):
            # pad the first i real tokens with -1 up to sequence_length
            cur_sen = tf.split(
                tf.concat([
                    tf.split(input_x, [i, self.sequence_length - i], 1)[0],
                    self.padding_array
                ], 1), [self.sequence_length, i], 1)[0]  #padding sentence
            with tf.variable_scope(self.scope):
                feature = self.FeatureExtractor_unit(cur_sen, self.drop_out)
            feature_array = feature_array.write(i, feature)

            real_goal_array = tf.cond(
                i > 0, lambda: real_goal_array,
                lambda: real_goal_array.write(0, self.goal_init))
            h_t_manager = self.g_manager_recurrent_unit(feature, h_tm1_manager)
            sub_goal = self.g_manager_output_unit(h_t_manager)
            sub_goal = tf.nn.l2_normalize(sub_goal, 1)

            h_t_Worker = tf.cond(
                i > 0, lambda: self.g_worker_recurrent_unit(x_t, h_tm1),
                lambda: h_tm1)  # hidden_memory_tuple
            o_t_Worker = self.g_worker_output_unit(
                h_t_Worker)  # batch x vocab , logits not prob
            o_t_Worker = tf.reshape(
                o_t_Worker, [self.batch_size, self.vocab_size, self.goal_size])

            real_sub_goal = tf.cond(i > 0, lambda: tf.add(last_goal, sub_goal),
                                    lambda: real_goal)
            all_sub_goals = tf.cond(
                i > 0, lambda: all_sub_goals.write(i - 1, real_goal),
                lambda: all_sub_goals)

            w_g = tf.matmul(real_goal, self.g_change)  # batch x goal_size
            w_g = tf.nn.l2_normalize(w_g, 1)
            w_g = tf.expand_dims(w_g, 2)  # batch x goal_size x 1

            x_logits = tf.matmul(o_t_Worker, w_g)
            x_logits = tf.squeeze(x_logits)

            g_predictions = tf.cond(
                i > 0,
                lambda: g_predictions.write(i - 1, tf.nn.softmax(x_logits)),
                lambda: g_predictions)

            # At every step_size boundary (i > 0), record the feature change
            # over the last goal window.
            sub_feature = tf.cond(
                (i % step_size) > 0, lambda: sub_feature,
                lambda: tf.cond(
                    i > 0, lambda: sub_feature.write(
                        i // step_size - 1,
                        tf.subtract(feature,
                                    feature_array.read(i - step_size))),
                    lambda: sub_feature))

            all_sub_features = tf.cond(
                i > 0, lambda: tf.cond(
                    (i % step_size) > 0, lambda: all_sub_features.write(
                        i - 1,
                        tf.subtract(feature,
                                    feature_array.read(i - i % step_size))),
                    lambda: all_sub_features.write(
                        i - 1,
                        tf.subtract(feature,
                                    feature_array.read(i - step_size)))),
                lambda: all_sub_features)

            real_goal_array = tf.cond(
                (i % step_size) > 0, lambda: real_goal_array,
                lambda: tf.cond(
                    i // step_size < self.sequence_length // step_size,
                    lambda: tf.cond(
                        i > 0, lambda: real_goal_array.write(
                            i // step_size, real_sub_goal),
                        lambda: real_goal_array),
                    lambda: real_goal_array))
            x_tp1 = tf.cond(i > 0, lambda: ta_emb_x.read(i - 1), lambda: x_t)

            return (i + 1, x_tp1, g_predictions, h_t_Worker, input_x,
                    h_t_manager,
                    tf.cond((i % step_size) > 0, lambda: real_sub_goal,
                            lambda: tf.constant(
                                0.0,
                                shape=[self.batch_size, self.goal_out_size])),
                    tf.cond((i % step_size) > 0, lambda: real_goal,
                            lambda: real_sub_goal), feature_array,
                    real_goal_array, sub_feature, all_sub_features,
                    all_sub_goals)

        _, _, self.g_predictions, _, _, _, _, _, self.feature_array, self.real_goal_array, self.sub_feature, self.all_sub_features, self.all_sub_goals = control_flow_ops.while_loop(
            cond=lambda i, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12: i
            < self.sequence_length + 1,
            body=preTrain,
            loop_vars=(tf.constant(0, dtype=tf.int32),
                       tf.nn.embedding_lookup(self.g_embeddings,
                                              self.start_token), g_predictions,
                       self.h0_worker, self.x, self.h0_manager,
                       tf.zeros([self.batch_size, self.goal_out_size]),
                       self.goal_init, feature_array, real_goal_array,
                       sub_feature, all_sub_features, all_sub_goals),
            parallel_iterations=1)

        self.sub_feature = self.sub_feature.stack(
        )  # seq_length x batch_size x num_filter
        self.sub_feature = tf.transpose(self.sub_feature, perm=[1, 0, 2])

        self.real_goal_array = self.real_goal_array.stack()
        self.real_goal_array = tf.transpose(self.real_goal_array,
                                            perm=[1, 0, 2])
        print(self.real_goal_array.shape)
        print(self.sub_feature.shape)
        self.pretrain_goal_loss = -tf.reduce_sum(1 - tf.losses.cosine_distance(
            tf.nn.l2_normalize(self.sub_feature, 2),
            tf.nn.l2_normalize(self.real_goal_array, 2), 2)) / (
                self.sequence_length * self.batch_size / self.step_size)

        with tf.name_scope("Manager_PreTrain_update"):
            pretrain_manager_opt = tf.train.AdamOptimizer(self.learning_rate)

            self.pretrain_manager_grad, _ = tf.clip_by_global_norm(
                tf.gradients(self.pretrain_goal_loss, self.manager_params),
                self.grad_clip)
            self.pretrain_manager_updates = pretrain_manager_opt.apply_gradients(
                list(zip(self.pretrain_manager_grad, self.manager_params)))
        # self.real_goal_array = self.real_goal_array.stack()

        self.g_predictions = tf.transpose(
            self.g_predictions.stack(),
            perm=[1, 0, 2])  # batch_size x seq_length x vocab_size
        self.cross_entropy = tf.reduce_sum(self.g_predictions * tf.log(
            tf.clip_by_value(self.g_predictions, 1e-20, 1.0))) / (
                self.batch_size * self.sequence_length * self.vocab_size)

        self.pretrain_worker_loss = -tf.reduce_sum(
            tf.one_hot(tf.to_int32(tf.reshape(
                self.x, [-1])), self.vocab_size, 1.0, 0.0) * tf.log(
                    tf.clip_by_value(
                        tf.reshape(self.g_predictions, [-1, self.vocab_size]),
                        1e-20, 1.0))) / (self.sequence_length *
                                         self.batch_size)

        with tf.name_scope("Worker_PreTrain_update"):
            # training updates
            pretrain_worker_opt = tf.train.AdamOptimizer(self.learning_rate)

            self.pretrain_worker_grad, _ = tf.clip_by_global_norm(
                tf.gradients(self.pretrain_worker_loss, self.worker_params),
                self.grad_clip)
            self.pretrain_worker_updates = pretrain_worker_opt.apply_gradients(
                list(zip(self.pretrain_worker_grad, self.worker_params)))

        self.goal_loss = -tf.reduce_sum(
            tf.multiply(
                self.reward, 1 - tf.losses.cosine_distance(
                    tf.nn.l2_normalize(self.sub_feature, 2),
                    tf.nn.l2_normalize(self.real_goal_array, 2), 2))) / (
                        self.sequence_length * self.batch_size /
                        self.step_size)

        with tf.name_scope("Manager_update"):
            manager_opt = tf.train.AdamOptimizer(self.learning_rate)

            self.manager_grad, _ = tf.clip_by_global_norm(
                tf.gradients(self.goal_loss, self.manager_params),
                self.grad_clip)
            self.manager_updates = manager_opt.apply_gradients(
                list(zip(self.manager_grad, self.manager_params)))

        self.all_sub_features = self.all_sub_features.stack()
        self.all_sub_features = tf.transpose(self.all_sub_features,
                                             perm=[1, 0, 2])

        self.all_sub_goals = self.all_sub_goals.stack()
        self.all_sub_goals = tf.transpose(self.all_sub_goals, perm=[1, 0, 2])
        # self.all_sub_features = tf.nn.l2_normalize(self.all_sub_features, 2)
        self.Worker_Reward = 1 - tf.losses.cosine_distance(
            tf.nn.l2_normalize(self.all_sub_features, 2),
            tf.nn.l2_normalize(self.all_sub_goals, 2), 2)
        # print self.Worker_Reward.shape
        self.worker_loss = -tf.reduce_sum(
            tf.multiply(
                self.Worker_Reward,
                tf.one_hot(tf.to_int32(tf.reshape(
                    self.x, [-1])), self.vocab_size, 1.0, 0.0) * tf.log(
                        tf.clip_by_value(
                            tf.reshape(self.g_predictions,
                                       [-1, self.vocab_size]), 1e-20,
                            1.0)))) / (self.sequence_length * self.batch_size)
        with tf.name_scope("Worker_update"):
            # training updates
            worker_opt = tf.train.AdamOptimizer(self.learning_rate)
            self.worker_grad, _ = tf.clip_by_global_norm(
                tf.gradients(self.worker_loss, self.worker_params),
                self.grad_clip)
            self.worker_updates = worker_opt.apply_gradients(
                list(zip(self.worker_grad, self.worker_params)))
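
The pair of `tf.cond`s returned from `_g_recurrence` implements the Manager's goal bookkeeping: `last_goal` accumulates sub-goals within a `step_size` window and is reset to zero at each window boundary, at which point `real_goal` is replaced by the accumulated sum. A pure-Python sketch of that schedule (function and variable names are illustrative only):

```python
def goal_schedule(sub_goals, goal_init, step_size):
    """Mimic the last_goal / real_goal updates in _g_recurrence.

    sub_goals: per-step Manager outputs (e.g. NumPy vectors).
    Returns the real_goal that is in effect at each step.
    """
    last_goal = 0 * goal_init     # starts at zeros
    real_goal = goal_init         # starts from the learned goal_init
    history = []
    for i, sub_goal in enumerate(sub_goals):
        history.append(real_goal)            # what gen_real_goal_array records
        real_sub_goal = last_goal + sub_goal
        if (i + 1) % step_size > 0:          # inside a window: keep accumulating
            last_goal = real_sub_goal
        else:                                # window boundary: reset and commit
            last_goal, real_goal = 0 * goal_init, real_sub_goal
    return history
```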
Example #4
def _known_len_tf_for_stmt(iter_,
                           extra_test,
                           body,
                           get_state,
                           set_state,
                           init_vars,
                           basic_symbol_names,
                           composite_symbol_names,
                           opts):
  """Overload of for_stmt that iterates over TF entities that admit a length."""
  _disallow_undefs_into_loop(*init_vars)

  n = py_builtins.len_(iter_)
  # TODO(b/117628877): Revisit performance once XLA has the necessary support.
  # Note: using a TensorArray creates an extra copy, but can calculate
  # gradients more efficiently than StridedSlice.
  ta = tensor_array_ops.TensorArray(iter_.dtype, size=n)
  iter_ = ta.unstack(iter_)

  def while_body(iterate_index, *loop_vars):
    """Main loop body."""
    iterate = iter_.read(iterate_index)
    new_vars = body(iterate, *loop_vars)

    loop_vars = (iterate_index + 1,)
    if new_vars:
      loop_vars += new_vars

    return loop_vars

  def while_cond(iterate_index, *loop_vars):
    if extra_test is not None:
      return control_flow_ops.cond(iterate_index < n,
                                   lambda: extra_test(*loop_vars),
                                   lambda: False)
    return iterate_index < n

  opts['maximum_iterations'] = n

  results = _tf_while_stmt(
      while_cond,
      while_body,
      get_state,
      set_state,
      (array_ops.zeros_like(n),) + init_vars,
      ('<internal iterate>',) + basic_symbol_names,
      composite_symbol_names,
      opts,
  )

  # Note: the iteration index is not returned by the while loop, however
  # if a symbol with the same name exists outside the loop, it will be captured
  # by the loop variables and ultimately updated correctly.
  if isinstance(results, (tuple, list)):
    assert len(results) >= 1  # Has at least the iterate.
    if len(results) > 1:
      results = results[1:]
  else:
    results = ()

  return results
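
`_known_len_tf_for_stmt` is AutoGraph's lowering of a Python `for` loop over a TF entity with a known length into a `while_loop` over a `TensorArray`. Ignoring the `get_state`/`set_state` hooks and the symbol-name bookkeeping, the loop it builds behaves roughly like the plain-Python reference below (names are illustrative):

```python
def for_stmt_reference(iter_values, extra_test, body, init_vars):
    """Iterate while the index is in range and extra_test (if any) holds,
    threading loop_vars through body, as the lowered while_loop does."""
    loop_vars = init_vars
    for iterate in iter_values:
        if extra_test is not None and not extra_test(*loop_vars):
            break
        new_vars = body(iterate, *loop_vars)
        loop_vars = new_vars if new_vars else loop_vars
    return loop_vars

# Example: sum elements but stop once the accumulator reaches 15.
total = for_stmt_reference(range(10),
                           extra_test=lambda acc: acc < 15,
                           body=lambda x, acc: (acc + x,),
                           init_vars=(0,))
# total == (15,)
```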
Example #5
def generator(x_real, temperature, vocab_size, batch_size, seq_len,
              gen_emb_dim, mem_slots, head_size, num_heads, hidden_dim,
              start_token):
    start_tokens = tf.constant([start_token] * batch_size, dtype=tf.int32)
    output_size = mem_slots * head_size * num_heads

    g_embeddings = tf.get_variable(
        'g_emb',
        shape=[vocab_size, gen_emb_dim],
        initializer=create_linear_initializer(vocab_size))
    gen_mem = RelationalMemory(mem_slots=mem_slots,
                               head_size=head_size,
                               num_heads=num_heads)
    g_output_unit = create_output_unit(output_size, vocab_size)

    # initial states
    init_states = gen_mem.initial_state(batch_size)

    # ---------- generate tokens and approximated one-hot results (Adversarial) ---------
    gen_o = tensor_array_ops.TensorArray(dtype=tf.float32,
                                         size=seq_len,
                                         dynamic_size=False,
                                         infer_shape=True)
    gen_x = tensor_array_ops.TensorArray(dtype=tf.int32,
                                         size=seq_len,
                                         dynamic_size=False,
                                         infer_shape=True)
    gen_x_onehot_adv = tensor_array_ops.TensorArray(dtype=tf.float32,
                                                    size=seq_len,
                                                    dynamic_size=False,
                                                    infer_shape=True)

    def _gen_recurrence(i, x_t, h_tm1, gen_o, gen_x, gen_x_onehot_adv):
        mem_o_t, h_t = gen_mem(x_t, h_tm1)  # hidden_memory_tuple
        o_t = g_output_unit(mem_o_t)  # batch x vocab, logits not prob
        gumbel_t = add_gumbel(o_t)
        next_token = tf.cast(tf.argmax(gumbel_t, axis=1), tf.int32)
        x_onehot_appr = tf.nn.softmax(tf.multiply(
            gumbel_t, temperature))  # one-hot-like, [batch_size x vocab_size]
        # x_tp1 = tf.matmul(x_onehot_appr, g_embeddings)  # approximated embeddings, [batch_size x emb_dim]
        x_tp1 = tf.nn.embedding_lookup(
            g_embeddings, next_token)  # embeddings, [batch_size x emb_dim]
        gen_o = gen_o.write(
            i,
            tf.reduce_sum(
                tf.multiply(tf.one_hot(next_token, vocab_size, 1.0, 0.0),
                            tf.nn.softmax(o_t)), 1))  # [batch_size] , prob
        gen_x = gen_x.write(i, next_token)  # indices, [batch_size]
        gen_x_onehot_adv = gen_x_onehot_adv.write(i, x_onehot_appr)
        return i + 1, x_tp1, h_t, gen_o, gen_x, gen_x_onehot_adv

    _, _, _, gen_o, gen_x, gen_x_onehot_adv = control_flow_ops.while_loop(
        cond=lambda i, _1, _2, _3, _4, _5: i < seq_len,
        body=_gen_recurrence,
        loop_vars=(tf.constant(0, dtype=tf.int32),
                   tf.nn.embedding_lookup(g_embeddings, start_tokens),
                   init_states, gen_o, gen_x, gen_x_onehot_adv))

    gen_x = gen_x.stack()  # seq_len x batch_size
    gen_x = tf.transpose(gen_x, perm=[1, 0])  # batch_size x seq_len

    gen_x_onehot_adv = gen_x_onehot_adv.stack()
    gen_x_onehot_adv = tf.transpose(
        gen_x_onehot_adv, perm=[1, 0, 2])  # batch_size x seq_len x vocab_size

    # ----------- pre-training for generator -----------------
    x_emb = tf.transpose(tf.nn.embedding_lookup(g_embeddings, x_real),
                         perm=[1, 0, 2])  # seq_len x batch_size x emb_dim
    g_predictions = tensor_array_ops.TensorArray(dtype=tf.float32,
                                                 size=seq_len,
                                                 dynamic_size=False,
                                                 infer_shape=True)

    ta_emb_x = tensor_array_ops.TensorArray(dtype=tf.float32, size=seq_len)
    ta_emb_x = ta_emb_x.unstack(x_emb)

    def _pretrain_recurrence(i, x_t, h_tm1, g_predictions):
        mem_o_t, h_t = gen_mem(x_t, h_tm1)
        o_t = g_output_unit(mem_o_t)
        g_predictions = g_predictions.write(
            i, tf.nn.softmax(o_t))  # batch_size x vocab_size
        x_tp1 = ta_emb_x.read(i)
        return i + 1, x_tp1, h_t, g_predictions

    _, _, _, g_predictions = control_flow_ops.while_loop(
        cond=lambda i, _1, _2, _3: i < seq_len,
        body=_pretrain_recurrence,
        loop_vars=(tf.constant(0, dtype=tf.int32),
                   tf.nn.embedding_lookup(g_embeddings, start_tokens),
                   init_states, g_predictions))

    g_predictions = tf.transpose(
        g_predictions.stack(),
        perm=[1, 0, 2])  # batch_size x seq_length x vocab_size

    # pretraining loss
    pretrain_loss = -tf.reduce_sum(
        tf.one_hot(tf.to_int32(tf.reshape(x_real, [-1])), vocab_size, 1.0, 0.0)
        * tf.log(
            tf.clip_by_value(tf.reshape(g_predictions, [-1, vocab_size]),
                             1e-20, 1.0))) / (seq_len * batch_size)

    return gen_x_onehot_adv, gen_x, pretrain_loss
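
`add_gumbel` is not defined in this snippet. In Gumbel-softmax generators of this kind it usually perturbs the logits with Gumbel(0, 1) noise, so that `argmax` over `gumbel_t` draws a sample from the softmax distribution while `tf.nn.softmax(gumbel_t * temperature)` is its differentiable one-hot relaxation. A plausible TF 1.x definition (an assumption, not necessarily the one used by this code):

```python
import tensorflow as tf  # assumes TF 1.x

def add_gumbel(o_t, eps=1e-10):
    """Perturb logits with Gumbel(0, 1) noise (a common definition; the original
    add_gumbel used by this generator may differ)."""
    u = tf.random_uniform(tf.shape(o_t), minval=0.0, maxval=1.0)
    gumbel_noise = -tf.log(-tf.log(u + eps) + eps)
    return o_t + gumbel_noise
```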
Example #6
def scan(fn,
         elems,
         initializer=None,
         parallel_iterations=10,
         back_prop=True,
         swap_memory=False,
         infer_shape=True,
         reverse=False,
         name=None):
    """scan on the list of tensors unpacked from `elems` on dimension 0.

  The simplest version of `scan` repeatedly applies the callable `fn` to a
  sequence of elements from first to last. The elements are made of the tensors
  unpacked from `elems` on dimension 0. The callable fn takes two tensors as
  arguments. The first argument is the accumulated value computed from the
  preceding invocation of fn. If `initializer` is None, `elems` must contain
  at least one element, and its first element is used as the initializer.

  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
  of the result tensor is `[len(values)] + fn(initializer, values[0]).shape`.
  If `reverse=True`, it is `[len(values)] + fn(initializer, values[-1]).shape`.

  This method also allows multi-arity `elems` and accumulator.  If `elems`
  is a (possibly nested) list or tuple of tensors, then each of these tensors
  must have a matching first (unpack) dimension.  The second argument of
  `fn` must match the structure of `elems`.

  If no `initializer` is provided, the output structure and dtypes of `fn`
  are assumed to be the same as its input; and in this case, the first
  argument of `fn` must match the structure of `elems`.

  If an `initializer` is provided, then the output of `fn` must have the same
  structure as `initializer`; and the first argument of `fn` must match
  this structure.

  For example, if `elems` is `(t1, [t2, t3])` and `initializer` is
  `[i1, i2]` then an appropriate signature for `fn` in `python2` is:
  `fn = lambda (acc_p1, acc_p2), (t1, [t2, t3]):` and `fn` must return a list,
  `[acc_n1, acc_n2]`.  An alternative correct signature for `fn`, and the
   one that works in `python3`, is:
  `fn = lambda a, t:`, where `a` and `t` correspond to the input tuples.

  Args:
    fn: The callable to be performed.  It accepts two arguments.  The first
      will have the same structure as `initializer` if one is provided,
      otherwise it will have the same structure as `elems`.  The second
      will have the same (possibly nested) structure as `elems`.  Its output
      must have the same structure as `initializer` if one is provided,
      otherwise it must have the same structure as `elems`.
    elems: A tensor or (possibly nested) sequence of tensors, each of which
      will be unpacked along their first dimension.  The nested sequence
      of the resulting slices will be the first argument to `fn`.
    initializer: (optional) A tensor or (possibly nested) sequence of tensors,
      initial value for the accumulator, and the expected output type of `fn`.
    parallel_iterations: (optional) The number of iterations allowed to run
      in parallel.
    back_prop: (optional) True enables support for back propagation.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    infer_shape: (optional) False disables tests for consistent output shapes.
    reverse: (optional) True scans the tensor last to first (instead of first
      to last).
    name: (optional) Name prefix for the returned tensors.

  Returns:
    A tensor or (possibly nested) sequence of tensors.  Each tensor packs the
    results of applying `fn` to tensors unpacked from `elems` along the first
    dimension, and the previous accumulator value(s), from first to last (or
    last to first, if `reverse=True`).

  Raises:
    TypeError: if `fn` is not callable or the structure of the output of
      `fn` and `initializer` do not match.
    ValueError: if the lengths of the output of `fn` and `initializer`
      do not match.

  Examples:
    ```python
    elems = np.array([1, 2, 3, 4, 5, 6])
    sum = scan(lambda a, x: a + x, elems)
    # sum == [1, 3, 6, 10, 15, 21]
    sum = scan(lambda a, x: a + x, elems, reverse=True)
    # sum == [21, 20, 18, 15, 11, 6]
    ```

    ```python
    elems = np.array([1, 2, 3, 4, 5, 6])
    initializer = np.array(0)
    sum_one = scan(
        lambda a, x: x[0] - x[1] + a, (elems + 1, elems), initializer)
    # sum_one == [1, 2, 3, 4, 5, 6]
    ```

    ```python
    elems = np.array([1, 0, 0, 0, 0, 0])
    initializer = (np.array(0), np.array(1))
    fibonaccis = scan(lambda a, _: (a[1], a[0] + a[1]), elems, initializer)
    # fibonaccis == ([1, 1, 2, 3, 5, 8], [1, 2, 3, 5, 8, 13])
    ```
  """
    if not callable(fn):
        raise TypeError("fn must be callable.")

    input_is_sequence = nest.is_sequence(elems)
    input_flatten = lambda x: nest.flatten(x) if input_is_sequence else [x]

    def input_pack(x):
        return nest.pack_sequence_as(elems, x) if input_is_sequence else x[0]

    if initializer is None:
        output_is_sequence = input_is_sequence
        output_flatten = input_flatten
        output_pack = input_pack
    else:
        output_is_sequence = nest.is_sequence(initializer)
        output_flatten = lambda x: nest.flatten(
            x) if output_is_sequence else [x]

        def output_pack(x):
            return (nest.pack_sequence_as(initializer, x)
                    if output_is_sequence else x[0])

    elems_flat = input_flatten(elems)

    in_graph_mode = not context.executing_eagerly()
    with ops.name_scope(name, "scan", elems_flat):
        # TODO(akshayka): Remove the in_graph_mode check once caching devices are
        # supported in Eager
        if in_graph_mode:
            # Any get_variable calls in fn will cache the first call locally
            # and not issue repeated network I/O requests for each iteration.
            varscope = vs.get_variable_scope()
            varscope_caching_device_was_none = False
            if varscope.caching_device is None:
                # TODO(ebrevdo): Change to using colocate_with here and in other
                # methods.
                varscope.set_caching_device(lambda op: op.device)
                varscope_caching_device_was_none = True

        # Convert elems to tensor array.
        elems_flat = [
            ops.convert_to_tensor(elem, name="elem") for elem in elems_flat
        ]

        # Convert elems to tensor array. n may be known statically.
        n = (tensor_shape.dimension_value(elems_flat[0].shape[0])
             or array_ops.shape(elems_flat[0])[0])

        # TensorArrays are always flat
        elems_ta = [
            tensor_array_ops.TensorArray(dtype=elem.dtype,
                                         size=n,
                                         dynamic_size=False,
                                         infer_shape=True)
            for elem in elems_flat
        ]
        # Unpack elements
        elems_ta = [
            elem_ta.unstack(elem)
            for elem_ta, elem in zip(elems_ta, elems_flat)
        ]

        if initializer is None:
            a_flat = [elem.read(n - 1 if reverse else 0) for elem in elems_ta]
            i = constant_op.constant(1)
        else:
            initializer_flat = output_flatten(initializer)
            a_flat = [ops.convert_to_tensor(init) for init in initializer_flat]
            i = constant_op.constant(0)

        # Create a tensor array to store the intermediate values.
        accs_ta = [
            tensor_array_ops.TensorArray(
                dtype=init.dtype,
                size=n,
                element_shape=init.shape if infer_shape else None,
                dynamic_size=False,
                infer_shape=infer_shape) for init in a_flat
        ]

        if initializer is None:
            accs_ta = [
                acc_ta.write(n - 1 if reverse else 0, a)
                for (acc_ta, a) in zip(accs_ta, a_flat)
            ]

        def compute(i, a_flat, tas):
            """The loop body of scan.

      Args:
        i: the loop counter.
        a_flat: the accumulator value(s), flattened.
        tas: the output accumulator TensorArray(s), flattened.

      Returns:
        [i + 1, a_flat, tas]: the updated counter + new accumulator values +
          updated TensorArrays

      Raises:
        TypeError: if initializer and fn() output structure do not match
        ValueType: if initializer and fn() output lengths do not match
      """
            packed_elems = input_pack(
                [elem_ta.read(i) for elem_ta in elems_ta])
            packed_a = output_pack(a_flat)
            a_out = fn(packed_a, packed_elems)
            nest.assert_same_structure(
                elems if initializer is None else initializer, a_out)
            flat_a_out = output_flatten(a_out)
            tas = [ta.write(i, value) for (ta, value) in zip(tas, flat_a_out)]
            if reverse:
                next_i = i - 1
            else:
                next_i = i + 1
            return (next_i, flat_a_out, tas)

        if reverse:
            initial_i = n - 1 - i
            condition = lambda i, _1, _2: i >= 0
        else:
            initial_i = i
            condition = lambda i, _1, _2: i < n
        _, _, r_a = control_flow_ops.while_loop(
            condition,
            compute, (initial_i, a_flat, accs_ta),
            parallel_iterations=parallel_iterations,
            back_prop=back_prop,
            swap_memory=swap_memory,
            maximum_iterations=n)

        results_flat = [r.stack() for r in r_a]

        n_static = tensor_shape.Dimension(
            tensor_shape.dimension_value(
                elems_flat[0].get_shape().with_rank_at_least(1)[0]))
        for elem in elems_flat[1:]:
            n_static.merge_with(
                tensor_shape.Dimension(
                    tensor_shape.dimension_value(
                        elem.get_shape().with_rank_at_least(1)[0])))
        for r in results_flat:
            r.set_shape(
                tensor_shape.TensorShape(n_static).concatenate(
                    r.get_shape()[1:]))

        # TODO(akshayka): Remove the in_graph_mode check once caching devices are
        # supported in Eager
        if in_graph_mode and varscope_caching_device_was_none:
            varscope.set_caching_device(None)

        return output_pack(results_flat)
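
For intuition, the single-tensor, non-nested case of `scan` is equivalent to the plain-Python loop below, which also reproduces the cumulative-sum examples from the docstring (including the `reverse=True` result):

```python
def scan_reference(fn, elems, initializer=None, reverse=False):
    """Pure-Python reference for the single-tensor case of scan."""
    xs = list(elems)[::-1] if reverse else list(elems)
    if initializer is None:
        acc, xs, outs = xs[0], xs[1:], [xs[0]]
    else:
        acc, outs = initializer, []
    for x in xs:
        acc = fn(acc, x)
        outs.append(acc)
    return outs[::-1] if reverse else outs

assert scan_reference(lambda a, x: a + x, [1, 2, 3, 4, 5, 6]) == [1, 3, 6, 10, 15, 21]
assert scan_reference(lambda a, x: a + x, [1, 2, 3, 4, 5, 6],
                      reverse=True) == [21, 20, 18, 15, 11, 6]
```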
Example #7
class StructureTest(test_base.DatasetTestBase, parameterized.TestCase,
                    test_util.TensorFlowTestCase):

  # pylint: disable=g-long-lambda,protected-access
  @parameterized.named_parameters(
      ("Tensor", lambda: constant_op.constant(37.0), tensor_spec.TensorSpec,
       [dtypes.float32], [[]]),
      ("TensorArray", lambda: tensor_array_ops.TensorArray(
          dtype=dtypes.float32, element_shape=(3,), size=0),
       tensor_array_ops.TensorArraySpec, [dtypes.variant], [[]]),
      ("SparseTensor", lambda: sparse_tensor.SparseTensor(
          indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
       sparse_tensor.SparseTensorSpec, [dtypes.variant], [None]),
      ("RaggedTensor", lambda: ragged_factory_ops.constant([[1, 2], [], [4]]),
       ragged_tensor.RaggedTensorSpec, [dtypes.variant], [None]),
      ("Nested_0",
       lambda: (constant_op.constant(37.0), constant_op.constant([1, 2, 3])),
       tuple, [dtypes.float32, dtypes.int32], [[], [3]]),
      ("Nested_1", lambda: {
          "a": constant_op.constant(37.0),
          "b": constant_op.constant([1, 2, 3])
      }, dict, [dtypes.float32, dtypes.int32], [[], [3]]),
      ("Nested_2", lambda: {
          "a":
              constant_op.constant(37.0),
          "b": (sparse_tensor.SparseTensor(
              indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
                sparse_tensor.SparseTensor(
                    indices=[[3, 4]], values=[-1], dense_shape=[4, 5]))
      }, dict, [dtypes.float32, dtypes.variant, dtypes.variant], [[], None, None
                                                                 ]),
  )
  def testFlatStructure(self, value_fn, expected_structure, expected_types,
                        expected_shapes):
    value = value_fn()
    s = structure.type_spec_from_value(value)
    self.assertIsInstance(s, expected_structure)
    flat_types = structure.get_flat_tensor_types(s)
    self.assertEqual(expected_types, flat_types)
    flat_shapes = structure.get_flat_tensor_shapes(s)
    self.assertLen(flat_shapes, len(expected_shapes))
    for expected, actual in zip(expected_shapes, flat_shapes):
      if expected is None:
        self.assertEqual(actual.ndims, None)
      else:
        self.assertEqual(actual.as_list(), expected)

  @parameterized.named_parameters(
      ("Tensor", lambda: constant_op.constant(37.0), lambda: [
          constant_op.constant(38.0),
          array_ops.placeholder(dtypes.float32),
          variables.Variable(100.0), 42.0,
          np.array(42.0, dtype=np.float32)
      ], lambda: [constant_op.constant([1.0, 2.0]),
                  constant_op.constant(37)]),
      ("TensorArray", lambda: tensor_array_ops.TensorArray(
          dtype=dtypes.float32, element_shape=(3,), size=0), lambda: [
              tensor_array_ops.TensorArray(
                  dtype=dtypes.float32, element_shape=(3,), size=0),
              tensor_array_ops.TensorArray(
                  dtype=dtypes.float32, element_shape=(3,), size=10)
          ], lambda: [
              tensor_array_ops.TensorArray(
                  dtype=dtypes.int32, element_shape=(3,), size=0),
              tensor_array_ops.TensorArray(
                  dtype=dtypes.float32, element_shape=(), size=0)
          ]),
      ("SparseTensor", lambda: sparse_tensor.SparseTensor(
          indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
       lambda: [
           sparse_tensor.SparseTensor(
               indices=[[1, 1], [3, 4]], values=[10, -1], dense_shape=[4, 5]),
           sparse_tensor.SparseTensorValue(
               indices=[[1, 1], [3, 4]], values=[10, -1], dense_shape=[4, 5]),
           array_ops.sparse_placeholder(dtype=dtypes.int32),
           array_ops.sparse_placeholder(dtype=dtypes.int32, shape=[None, None])
       ], lambda: [
           constant_op.constant(37, shape=[4, 5]),
           sparse_tensor.SparseTensor(
               indices=[[3, 4]], values=[-1], dense_shape=[5, 6]),
           array_ops.sparse_placeholder(
               dtype=dtypes.int32, shape=[None, None, None]),
           sparse_tensor.SparseTensor(
               indices=[[3, 4]], values=[-1.0], dense_shape=[4, 5])
       ]),
      ("RaggedTensor", lambda: ragged_factory_ops.constant([[1, 2], [], [3]]),
       lambda: [
           ragged_factory_ops.constant([[1, 2], [3, 4], []]),
           ragged_factory_ops.constant([[1], [2, 3, 4], [5]]),
       ], lambda: [
           ragged_factory_ops.constant(1),
           ragged_factory_ops.constant([1, 2]),
           ragged_factory_ops.constant([[1], [2]]),
           ragged_factory_ops.constant([["a", "b"]]),
       ]),
      ("Nested", lambda: {
          "a": constant_op.constant(37.0),
          "b": constant_op.constant([1, 2, 3])
      }, lambda: [{
          "a": constant_op.constant(15.0),
          "b": constant_op.constant([4, 5, 6])
      }], lambda: [{
          "a": constant_op.constant(15.0),
          "b": constant_op.constant([4, 5, 6, 7])
      }, {
          "a": constant_op.constant(15),
          "b": constant_op.constant([4, 5, 6])
      }, {
          "a":
              constant_op.constant(15),
          "b":
              sparse_tensor.SparseTensor(
                  indices=[[0], [1], [2]], values=[4, 5, 6], dense_shape=[3])
      }, (constant_op.constant(15.0), constant_op.constant([4, 5, 6]))]),
  )
  @test_util.run_deprecated_v1
  def testIsCompatibleWithStructure(self, original_value_fn,
                                    compatible_values_fn,
                                    incompatible_values_fn):
    original_value = original_value_fn()
    compatible_values = compatible_values_fn()
    incompatible_values = incompatible_values_fn()
    s = structure.type_spec_from_value(original_value)
    for compatible_value in compatible_values:
      self.assertTrue(
          structure.are_compatible(
              s, structure.type_spec_from_value(compatible_value)))
    for incompatible_value in incompatible_values:
      self.assertFalse(
          structure.are_compatible(
              s, structure.type_spec_from_value(incompatible_value)))

  @parameterized.named_parameters(
      ("Tensor",
       lambda: constant_op.constant(37.0),
       lambda: constant_op.constant(42.0),
       lambda: constant_op.constant([5])),
      ("TensorArray",
       lambda: tensor_array_ops.TensorArray(
           dtype=dtypes.float32, element_shape=(3,), size=0),
       lambda: tensor_array_ops.TensorArray(
           dtype=dtypes.float32, element_shape=(3,), size=0),
       lambda: tensor_array_ops.TensorArray(
           dtype=dtypes.int32, element_shape=(), size=0)),
      ("SparseTensor",
       lambda: sparse_tensor.SparseTensor(
           indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
       lambda: sparse_tensor.SparseTensor(
           indices=[[1, 2]], values=[42], dense_shape=[4, 5]),
       lambda: sparse_tensor.SparseTensor(
           indices=[[3]], values=[-1], dense_shape=[5]),
       lambda: sparse_tensor.SparseTensor(
           indices=[[3, 4]], values=[1.0], dense_shape=[4, 5])),
      ("RaggedTensor",
       lambda: ragged_factory_ops.constant([[[1, 2]], [[3]]]),
       lambda: ragged_factory_ops.constant([[[5]], [[8], [3, 2]]]),
       lambda: ragged_factory_ops.constant([[[1]], [[2], [3]]],
                                           ragged_rank=1),
       lambda: ragged_factory_ops.constant([[[1.0, 2.0]], [[3.0]]]),
       lambda: ragged_factory_ops.constant([[[1]], [[2]], [[3]]])),
      ("Nested",
       lambda: {
           "a": constant_op.constant(37.0),
           "b": constant_op.constant([1, 2, 3])},
       lambda: {
           "a": constant_op.constant(42.0),
           "b": constant_op.constant([4, 5, 6])},
       lambda: {
           "a": constant_op.constant([1, 2, 3]),
           "b": constant_op.constant(37.0)
       }),
  )  # pyformat: disable
  def testStructureFromValueEquality(self, value1_fn, value2_fn,
                                     *not_equal_value_fns):
    # pylint: disable=g-generic-assert
    s1 = structure.type_spec_from_value(value1_fn())
    s2 = structure.type_spec_from_value(value2_fn())
    self.assertEqual(s1, s1)  # check __eq__ operator.
    self.assertEqual(s1, s2)  # check __eq__ operator.
    self.assertFalse(s1 != s1)  # check __ne__ operator.
    self.assertFalse(s1 != s2)  # check __ne__ operator.
    for c1, c2 in zip(nest.flatten(s1), nest.flatten(s2)):
      self.assertEqual(hash(c1), hash(c1))
      self.assertEqual(hash(c1), hash(c2))
    for value_fn in not_equal_value_fns:
      s3 = structure.type_spec_from_value(value_fn())
      self.assertNotEqual(s1, s3)  # check __ne__ operator.
      self.assertNotEqual(s2, s3)  # check __ne__ operator.
      self.assertFalse(s1 == s3)  # check __eq_ operator.
      self.assertFalse(s2 == s3)  # check __eq_ operator.

  @parameterized.named_parameters(
      ("RaggedTensor_RaggedRank",
       ragged_tensor.RaggedTensorSpec(None, dtypes.int32, 1),
       ragged_tensor.RaggedTensorSpec(None, dtypes.int32, 2)),
      ("RaggedTensor_Shape",
       ragged_tensor.RaggedTensorSpec([3, None], dtypes.int32, 1),
       ragged_tensor.RaggedTensorSpec([5, None], dtypes.int32, 1)),
      ("RaggedTensor_DType",
       ragged_tensor.RaggedTensorSpec(None, dtypes.int32, 1),
       ragged_tensor.RaggedTensorSpec(None, dtypes.float32, 1)),
  )
  def testRaggedStructureInequality(self, s1, s2):
    # pylint: disable=g-generic-assert
    self.assertNotEqual(s1, s2)  # check __ne__ operator.
    self.assertFalse(s1 == s2)  # check __eq__ operator.

  @parameterized.named_parameters(
      ("Tensor", lambda: constant_op.constant(37.0),
       lambda: constant_op.constant(42.0), lambda: constant_op.constant([5])),
      ("TensorArray", lambda: tensor_array_ops.TensorArray(
          dtype=dtypes.float32, element_shape=(3,), size=0),
       lambda: tensor_array_ops.TensorArray(
           dtype=dtypes.float32, element_shape=(3,), size=0),
       lambda: tensor_array_ops.TensorArray(
           dtype=dtypes.int32, element_shape=(), size=0)),
      ("SparseTensor", lambda: sparse_tensor.SparseTensor(
          indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
       lambda: sparse_tensor.SparseTensor(
           indices=[[1, 2]], values=[42], dense_shape=[4, 5]), lambda:
       sparse_tensor.SparseTensor(indices=[[3]], values=[-1], dense_shape=[5])),
      ("Nested", lambda: {
          "a": constant_op.constant(37.0),
          "b": constant_op.constant([1, 2, 3])
      }, lambda: {
          "a": constant_op.constant(42.0),
          "b": constant_op.constant([4, 5, 6])
      }, lambda: {
          "a": constant_op.constant([1, 2, 3]),
          "b": constant_op.constant(37.0)
      }),
  )
  def testHash(self, value1_fn, value2_fn, value3_fn):
    s1 = structure.type_spec_from_value(value1_fn())
    s2 = structure.type_spec_from_value(value2_fn())
    s3 = structure.type_spec_from_value(value3_fn())
    for c1, c2, c3 in zip(nest.flatten(s1), nest.flatten(s2), nest.flatten(s3)):
      self.assertEqual(hash(c1), hash(c1))
      self.assertEqual(hash(c1), hash(c2))
      self.assertNotEqual(hash(c1), hash(c3))
      self.assertNotEqual(hash(c2), hash(c3))

  @parameterized.named_parameters(
      (
          "Tensor",
          lambda: constant_op.constant(37.0),
      ),
      (
          "SparseTensor",
          lambda: sparse_tensor.SparseTensor(
              indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
      ),
      ("TensorArray", lambda: tensor_array_ops.TensorArray(
          dtype=dtypes.float32, element_shape=(), size=1).write(0, 7)),
      (
          "RaggedTensor",
          lambda: ragged_factory_ops.constant([[1, 2], [], [3]]),
      ),
      (
          "Nested_0",
          lambda: {
              "a": constant_op.constant(37.0),
              "b": constant_op.constant([1, 2, 3])
          },
      ),
      (
          "Nested_1",
          lambda: {
              "a":
                  constant_op.constant(37.0),
              "b": (sparse_tensor.SparseTensor(
                  indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
                    sparse_tensor.SparseTensor(
                        indices=[[3, 4]], values=[-1], dense_shape=[4, 5]))
          },
      ),
  )
  def testRoundTripConversion(self, value_fn):
    value = value_fn()
    s = structure.type_spec_from_value(value)

    def maybe_stack_ta(v):
      if isinstance(v, tensor_array_ops.TensorArray):
        return v.stack()
      else:
        return v

    before = self.evaluate(maybe_stack_ta(value))
    after = self.evaluate(
        maybe_stack_ta(
            structure.from_tensor_list(s, structure.to_tensor_list(s, value))))

    flat_before = nest.flatten(before)
    flat_after = nest.flatten(after)
    for b, a in zip(flat_before, flat_after):
      if isinstance(b, sparse_tensor.SparseTensorValue):
        self.assertAllEqual(b.indices, a.indices)
        self.assertAllEqual(b.values, a.values)
        self.assertAllEqual(b.dense_shape, a.dense_shape)
      elif isinstance(
          b,
          (ragged_tensor.RaggedTensor, ragged_tensor_value.RaggedTensorValue)):
        self.assertAllEqual(b, a)
      else:
        self.assertAllEqual(b, a)

  # pylint: enable=g-long-lambda

  def testPreserveStaticShape(self):
    rt = ragged_factory_ops.constant([[1, 2], [], [3]])
    rt_s = structure.type_spec_from_value(rt)
    rt_after = structure.from_tensor_list(rt_s,
                                          structure.to_tensor_list(rt_s, rt))
    self.assertEqual(rt_after.row_splits.shape.as_list(),
                     rt.row_splits.shape.as_list())
    self.assertEqual(rt_after.values.shape.as_list(), [None])

    st = sparse_tensor.SparseTensor(
        indices=[[3, 4]], values=[-1], dense_shape=[4, 5])
    st_s = structure.type_spec_from_value(st)
    st_after = structure.from_tensor_list(st_s,
                                          structure.to_tensor_list(st_s, st))
    self.assertEqual(st_after.indices.shape.as_list(), [None, 2])
    self.assertEqual(st_after.values.shape.as_list(), [None])
    self.assertEqual(st_after.dense_shape.shape.as_list(),
                     st.dense_shape.shape.as_list())

  def testPreserveTensorArrayShape(self):
    ta = tensor_array_ops.TensorArray(
        dtype=dtypes.int32, size=1, element_shape=(3,))
    ta_s = structure.type_spec_from_value(ta)
    ta_after = structure.from_tensor_list(ta_s,
                                          structure.to_tensor_list(ta_s, ta))
    self.assertEqual(ta_after.element_shape.as_list(), [3])

  def testPreserveInferredTensorArrayShape(self):
    ta = tensor_array_ops.TensorArray(dtype=dtypes.int32, size=1)
    # Shape is inferred from the write.
    ta = ta.write(0, [1, 2, 3])
    ta_s = structure.type_spec_from_value(ta)
    ta_after = structure.from_tensor_list(ta_s,
                                          structure.to_tensor_list(ta_s, ta))
    self.assertEqual(ta_after.element_shape.as_list(), [3])

  def testIncompatibleStructure(self):
    # Define three mutually incompatible values/structures, and assert that:
    # 1. Using one structure to flatten a value with an incompatible structure
    #    fails.
    # 2. Using one structure to restructure a flattened value with an
    #    incompatible structure fails.
    value_tensor = constant_op.constant(42.0)
    s_tensor = structure.type_spec_from_value(value_tensor)
    flat_tensor = structure.to_tensor_list(s_tensor, value_tensor)

    value_sparse_tensor = sparse_tensor.SparseTensor(
        indices=[[0, 0]], values=[1], dense_shape=[1, 1])
    s_sparse_tensor = structure.type_spec_from_value(value_sparse_tensor)
    flat_sparse_tensor = structure.to_tensor_list(s_sparse_tensor,
                                                  value_sparse_tensor)

    value_nest = {
        "a": constant_op.constant(37.0),
        "b": constant_op.constant([1, 2, 3])
    }
    s_nest = structure.type_spec_from_value(value_nest)
    flat_nest = structure.to_tensor_list(s_nest, value_nest)

    with self.assertRaisesRegexp(
        ValueError, r"SparseTensor.* is not convertible to a tensor with "
        r"dtype.*float32.* and shape \(\)"):
      structure.to_tensor_list(s_tensor, value_sparse_tensor)
    with self.assertRaisesRegexp(
        ValueError, "The two structures don't have the same nested structure."):
      structure.to_tensor_list(s_tensor, value_nest)

    with self.assertRaisesRegexp(
        TypeError, "Neither a SparseTensor nor SparseTensorValue"):
      structure.to_tensor_list(s_sparse_tensor, value_tensor)

    with self.assertRaisesRegexp(
        ValueError, "The two structures don't have the same nested structure."):
      structure.to_tensor_list(s_sparse_tensor, value_nest)

    with self.assertRaisesRegexp(
        ValueError, "The two structures don't have the same nested structure."):
      structure.to_tensor_list(s_nest, value_tensor)

    with self.assertRaisesRegexp(
        ValueError, "The two structures don't have the same nested structure."):
      structure.to_tensor_list(s_nest, value_sparse_tensor)

    with self.assertRaisesRegexp(ValueError, r"Incompatible input:"):
      structure.from_tensor_list(s_tensor, flat_sparse_tensor)

    with self.assertRaisesRegexp(ValueError, "Expected 1 tensors but got 2."):
      structure.from_tensor_list(s_tensor, flat_nest)

    with self.assertRaisesRegexp(ValueError, "Incompatible input: "):
      structure.from_tensor_list(s_sparse_tensor, flat_tensor)

    with self.assertRaisesRegexp(ValueError, "Expected 1 tensors but got 2."):
      structure.from_tensor_list(s_sparse_tensor, flat_nest)

    with self.assertRaisesRegexp(ValueError, "Expected 2 tensors but got 1."):
      structure.from_tensor_list(s_nest, flat_tensor)

    with self.assertRaisesRegexp(ValueError, "Expected 2 tensors but got 1."):
      structure.from_tensor_list(s_nest, flat_sparse_tensor)

  def testIncompatibleNestedStructure(self):
    # Define three mutually incompatible nested values/structures, and assert
    # that:
    # 1. Using one structure to flatten a value with an incompatible structure
    #    fails.
    # 2. Using one structure to restructure a flattened value with an
    #    incompatible structure fails.

    value_0 = {
        "a": constant_op.constant(37.0),
        "b": constant_op.constant([1, 2, 3])
    }
    s_0 = structure.type_spec_from_value(value_0)
    flat_s_0 = structure.to_tensor_list(s_0, value_0)

    # `value_1` has compatible nested structure with `value_0`, but different
    # classes.
    value_1 = {
        "a":
            constant_op.constant(37.0),
        "b":
            sparse_tensor.SparseTensor(
                indices=[[0, 0]], values=[1], dense_shape=[1, 1])
    }
    s_1 = structure.type_spec_from_value(value_1)
    flat_s_1 = structure.to_tensor_list(s_1, value_1)

    # `value_2` has incompatible nested structure with `value_0` and `value_1`.
    value_2 = {
        "a":
            constant_op.constant(37.0),
        "b": (sparse_tensor.SparseTensor(
            indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
              sparse_tensor.SparseTensor(
                  indices=[[3, 4]], values=[-1], dense_shape=[4, 5]))
    }
    s_2 = structure.type_spec_from_value(value_2)
    flat_s_2 = structure.to_tensor_list(s_2, value_2)

    with self.assertRaisesRegexp(
        ValueError, r"SparseTensor.* is not convertible to a tensor with "
        r"dtype.*int32.* and shape \(3,\)"):
      structure.to_tensor_list(s_0, value_1)

    with self.assertRaisesRegexp(
        ValueError, "The two structures don't have the same nested structure."):
      structure.to_tensor_list(s_0, value_2)

    with self.assertRaisesRegexp(
        TypeError, "Neither a SparseTensor nor SparseTensorValue"):
      structure.to_tensor_list(s_1, value_0)

    with self.assertRaisesRegexp(
        ValueError, "The two structures don't have the same nested structure."):
      structure.to_tensor_list(s_1, value_2)

    # NOTE(mrry): The repr of the dictionaries is not sorted, so the regexp
    # needs to account for "a" coming before or after "b". It might be worth
    # adding a deterministic repr for these error messages (among other
    # improvements).
    with self.assertRaisesRegexp(
        ValueError, "The two structures don't have the same nested structure."):
      structure.to_tensor_list(s_2, value_0)

    with self.assertRaisesRegexp(
        ValueError, "The two structures don't have the same nested structure."):
      structure.to_tensor_list(s_2, value_1)

    with self.assertRaisesRegexp(ValueError, r"Incompatible input:"):
      structure.from_tensor_list(s_0, flat_s_1)

    with self.assertRaisesRegexp(ValueError, "Expected 2 tensors but got 3."):
      structure.from_tensor_list(s_0, flat_s_2)

    with self.assertRaisesRegexp(ValueError, "Incompatible input: "):
      structure.from_tensor_list(s_1, flat_s_0)

    with self.assertRaisesRegexp(ValueError, "Expected 2 tensors but got 3."):
      structure.from_tensor_list(s_1, flat_s_2)

    with self.assertRaisesRegexp(ValueError, "Expected 3 tensors but got 2."):
      structure.from_tensor_list(s_2, flat_s_0)

    with self.assertRaisesRegexp(ValueError, "Expected 3 tensors but got 2."):
      structure.from_tensor_list(s_2, flat_s_1)

  @parameterized.named_parameters(
      ("Tensor", dtypes.float32, tensor_shape.TensorShape(
          []), ops.Tensor, tensor_spec.TensorSpec([], dtypes.float32)),
      ("SparseTensor", dtypes.int32, tensor_shape.TensorShape(
          [2, 2]), sparse_tensor.SparseTensor,
       sparse_tensor.SparseTensorSpec([2, 2], dtypes.int32)),
      ("TensorArray_0", dtypes.int32,
       tensor_shape.TensorShape([None, True, 2, 2
                                ]), tensor_array_ops.TensorArray,
       tensor_array_ops.TensorArraySpec(
           [2, 2], dtypes.int32, dynamic_size=None, infer_shape=True)),
      ("TensorArray_1", dtypes.int32,
       tensor_shape.TensorShape([True, None, 2, 2
                                ]), tensor_array_ops.TensorArray,
       tensor_array_ops.TensorArraySpec(
           [2, 2], dtypes.int32, dynamic_size=True, infer_shape=None)),
      ("TensorArray_2", dtypes.int32,
       tensor_shape.TensorShape([True, False, 2, 2
                                ]), tensor_array_ops.TensorArray,
       tensor_array_ops.TensorArraySpec(
           [2, 2], dtypes.int32, dynamic_size=True, infer_shape=False)),
      ("RaggedTensor", dtypes.int32, tensor_shape.TensorShape([2, None]),
       ragged_tensor.RaggedTensorSpec([2, None], dtypes.int32, 1),
       ragged_tensor.RaggedTensorSpec([2, None], dtypes.int32, 1)),
      ("Nested", {
          "a": dtypes.float32,
          "b": (dtypes.int32, dtypes.string)
      }, {
          "a": tensor_shape.TensorShape([]),
          "b": (tensor_shape.TensorShape([2, 2]), tensor_shape.TensorShape([]))
      }, {
          "a": ops.Tensor,
          "b": (sparse_tensor.SparseTensor, ops.Tensor)
      }, {
          "a":
              tensor_spec.TensorSpec([], dtypes.float32),
          "b": (sparse_tensor.SparseTensorSpec(
              [2, 2], dtypes.int32), tensor_spec.TensorSpec([], dtypes.string))
      }),
  )
  def testConvertLegacyStructure(self, output_types, output_shapes,
                                 output_classes, expected_structure):
    actual_structure = structure.convert_legacy_structure(
        output_types, output_shapes, output_classes)
    self.assertEqual(actual_structure, expected_structure)

  def testNestedNestedStructure(self):
    s = (tensor_spec.TensorSpec([], dtypes.int64),
         (tensor_spec.TensorSpec([], dtypes.float32),
          tensor_spec.TensorSpec([], dtypes.string)))

    int64_t = constant_op.constant(37, dtype=dtypes.int64)
    float32_t = constant_op.constant(42.0)
    string_t = constant_op.constant("Foo")

    nested_tensors = (int64_t, (float32_t, string_t))

    tensor_list = structure.to_tensor_list(s, nested_tensors)
    for expected, actual in zip([int64_t, float32_t, string_t], tensor_list):
      self.assertIs(expected, actual)

    (actual_int64_t,
     (actual_float32_t,
      actual_string_t)) = structure.from_tensor_list(s, tensor_list)
    self.assertIs(int64_t, actual_int64_t)
    self.assertIs(float32_t, actual_float32_t)
    self.assertIs(string_t, actual_string_t)

    (actual_int64_t, (actual_float32_t, actual_string_t)) = (
        structure.from_compatible_tensor_list(s, tensor_list))
    self.assertIs(int64_t, actual_int64_t)
    self.assertIs(float32_t, actual_float32_t)
    self.assertIs(string_t, actual_string_t)

  @parameterized.named_parameters(
      ("Tensor", tensor_spec.TensorSpec([], dtypes.float32), 32,
       tensor_spec.TensorSpec([32], dtypes.float32)),
      ("TensorUnknown", tensor_spec.TensorSpec([], dtypes.float32), None,
       tensor_spec.TensorSpec([None], dtypes.float32)),
      ("SparseTensor", sparse_tensor.SparseTensorSpec([None], dtypes.float32),
       32, sparse_tensor.SparseTensorSpec([32, None], dtypes.float32)),
      ("SparseTensorUnknown",
       sparse_tensor.SparseTensorSpec([4], dtypes.float32), None,
       sparse_tensor.SparseTensorSpec([None, 4], dtypes.float32)),
      ("RaggedTensor",
       ragged_tensor.RaggedTensorSpec([2, None], dtypes.float32, 1), 32,
       ragged_tensor.RaggedTensorSpec([32, 2, None], dtypes.float32, 2)),
      ("RaggedTensorUnknown",
       ragged_tensor.RaggedTensorSpec([4, None], dtypes.float32, 1), None,
       ragged_tensor.RaggedTensorSpec([None, 4, None], dtypes.float32, 2)),
      ("Nested", {
          "a":
              tensor_spec.TensorSpec([], dtypes.float32),
          "b": (sparse_tensor.SparseTensorSpec([2, 2], dtypes.int32),
                tensor_spec.TensorSpec([], dtypes.string))
      }, 128, {
          "a":
              tensor_spec.TensorSpec([128], dtypes.float32),
          "b": (sparse_tensor.SparseTensorSpec([128, 2, 2], dtypes.int32),
                tensor_spec.TensorSpec([128], dtypes.string))
      }),
  )
  def testBatch(self, element_structure, batch_size,
                expected_batched_structure):
    batched_structure = nest.map_structure(
        lambda component_spec: component_spec._batch(batch_size),
        element_structure)
    self.assertEqual(batched_structure, expected_batched_structure)

  @parameterized.named_parameters(
      ("Tensor", tensor_spec.TensorSpec(
          [32], dtypes.float32), tensor_spec.TensorSpec([], dtypes.float32)),
      ("TensorUnknown", tensor_spec.TensorSpec(
          [None], dtypes.float32), tensor_spec.TensorSpec([], dtypes.float32)),
      ("SparseTensor", sparse_tensor.SparseTensorSpec([32, None],
                                                      dtypes.float32),
       sparse_tensor.SparseTensorSpec([None], dtypes.float32)),
      ("SparseTensorUnknown",
       sparse_tensor.SparseTensorSpec([None, 4], dtypes.float32),
       sparse_tensor.SparseTensorSpec([4], dtypes.float32)),
      ("RaggedTensor",
       ragged_tensor.RaggedTensorSpec([32, None, None], dtypes.float32, 2),
       ragged_tensor.RaggedTensorSpec([None, None], dtypes.float32, 1)),
      ("RaggedTensorUnknown",
       ragged_tensor.RaggedTensorSpec([None, None, None], dtypes.float32, 2),
       ragged_tensor.RaggedTensorSpec([None, None], dtypes.float32, 1)),
      ("Nested", {
          "a":
              tensor_spec.TensorSpec([128], dtypes.float32),
          "b": (sparse_tensor.SparseTensorSpec([128, 2, 2], dtypes.int32),
                tensor_spec.TensorSpec([None], dtypes.string))
      }, {
          "a":
              tensor_spec.TensorSpec([], dtypes.float32),
          "b": (sparse_tensor.SparseTensorSpec(
              [2, 2], dtypes.int32), tensor_spec.TensorSpec([], dtypes.string))
      }),
  )
  def testUnbatch(self, element_structure, expected_unbatched_structure):
    unbatched_structure = nest.map_structure(
        lambda component_spec: component_spec._unbatch(), element_structure)
    self.assertEqual(unbatched_structure, expected_unbatched_structure)

  # pylint: disable=g-long-lambda
  @parameterized.named_parameters(
      ("Tensor", lambda: constant_op.constant([[1.0, 2.0], [3.0, 4.0]]),
       lambda: constant_op.constant([1.0, 2.0])),
      ("SparseTensor", lambda: sparse_tensor.SparseTensor(
          indices=[[0, 0], [1, 1]], values=[13, 27], dense_shape=[2, 2]),
       lambda: sparse_tensor.SparseTensor(
           indices=[[0]], values=[13], dense_shape=[2])),
      ("RaggedTensor", lambda: ragged_factory_ops.constant([[[1]], [[2]]]),
       lambda: ragged_factory_ops.constant([[1]])),
      ("Nest", lambda:
       (constant_op.constant([[1.0, 2.0], [3.0, 4.0]]),
        sparse_tensor.SparseTensor(
            indices=[[0, 0], [1, 1]], values=[13, 27], dense_shape=[2, 2])),
       lambda: (constant_op.constant([1.0, 2.0]),
                sparse_tensor.SparseTensor(
                    indices=[[0]], values=[13], dense_shape=[2]))),
  )
  def testToBatchedTensorList(self, value_fn, element_0_fn):
    batched_value = value_fn()
    s = structure.type_spec_from_value(batched_value)
    batched_tensor_list = structure.to_batched_tensor_list(s, batched_value)

    # The batch dimension is 2 for all of the test cases.
    # NOTE(mrry): `tf.shape()` does not currently work for the DT_VARIANT
    # tensors in which we store sparse tensors.
    for t in batched_tensor_list:
      if t.dtype != dtypes.variant:
        self.assertEqual(2, self.evaluate(array_ops.shape(t)[0]))

    # Test that the 0th element from the unbatched tensor is equal to the
    # expected value.
    expected_element_0 = self.evaluate(element_0_fn())
    unbatched_s = nest.map_structure(
        lambda component_spec: component_spec._unbatch(), s)
    actual_element_0 = structure.from_tensor_list(
        unbatched_s, [t[0] for t in batched_tensor_list])

    for expected, actual in zip(
        nest.flatten(expected_element_0), nest.flatten(actual_element_0)):
      self.assertValuesEqual(expected, actual)

  # pylint: enable=g-long-lambda

  def testDatasetSpecConstructor(self):
    rt_spec = ragged_tensor.RaggedTensorSpec([10, None], dtypes.int32)
    st_spec = sparse_tensor.SparseTensorSpec([10, 20], dtypes.float32)
    t_spec = tensor_spec.TensorSpec([10, 8], dtypes.string)
    element_spec = {"rt": rt_spec, "st": st_spec, "t": t_spec}
    ds_struct = dataset_ops.DatasetSpec(element_spec, [5])
    self.assertEqual(ds_struct._element_spec, element_spec)
    # Note: shape was automatically converted from a list to a TensorShape.
    self.assertEqual(ds_struct._dataset_shape, tensor_shape.TensorShape([5]))

  def testCustomMapping(self):
    elem = CustomMap(foo=constant_op.constant(37.))
    spec = structure.type_spec_from_value(elem)
    self.assertIsInstance(spec, CustomMap)
    self.assertEqual(spec["foo"], tensor_spec.TensorSpec([], dtypes.float32))

  def testObjectProxy(self):
    nt_type = collections.namedtuple("A", ["x", "y"])
    proxied = wrapt.ObjectProxy(nt_type(1, 2))
    proxied_spec = structure.type_spec_from_value(proxied)
    self.assertEqual(structure.type_spec_from_value(nt_type(1, 2)),
                     proxied_spec)
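# Hedged sketch (not part of the test class above): the round trip that
# testRoundTripConversion exercises, spelled out for a single SparseTensor.
# Assumes the same imports the tests use (sparse_tensor, structure).
st = sparse_tensor.SparseTensor(indices=[[3, 4]], values=[-1], dense_shape=[4, 5])
st_spec = structure.type_spec_from_value(st)          # SparseTensorSpec([4, 5], int32)
flat = structure.to_tensor_list(st_spec, st)          # a single DT_VARIANT tensor
st_after = structure.from_tensor_list(st_spec, flat)  # SparseTensor equal to `st`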
Example #8
def raw_rnn(cell, loop_fn, parallel_iterations=None, swap_memory=False, scope=None):
    """
    raw_rnn adapted from the original tensorflow implementation
    (https://github.com/tensorflow/tensorflow/blob/r1.4/tensorflow/python/ops/rnn.py)
    to emit arbitrarily nested states for each time step (concatenated along the
    time axis), in addition to the outputs at each time step and the final state.

    returns (
        states for all timesteps,
        outputs for all timesteps,
        final cell state,
    )
    """
    assert_like_rnncell("Raw rnn cell", cell)

    if not callable(loop_fn):
        raise TypeError("loop_fn must be a callable")

    parallel_iterations = parallel_iterations or 32

    # Create a new scope in which the caching device is either
    # determined by the parent scope, or is set to place the cached
    # Variable using the same placement as for the rest of the RNN.
    with vs.variable_scope(scope or "rnn") as varscope:
        if is_in_graph_mode.IS_IN_GRAPH_MODE():
            if varscope.caching_device is None:
                varscope.set_caching_device(lambda op: op.device)

        time = constant_op.constant(0, dtype=dtypes.int32)
        (elements_finished, next_input, initial_state, emit_structure,
         init_loop_state) = loop_fn(time, None, None, None)
        flat_input = nest.flatten(next_input)

        # Need a surrogate loop state for the while_loop if none is available.
        loop_state = (init_loop_state if init_loop_state is not None
                      else constant_op.constant(0, dtype=dtypes.int32))

        input_shape = [input_.get_shape() for input_ in flat_input]
        static_batch_size = input_shape[0][0]

        for input_shape_i in input_shape:
            # Static verification that batch sizes all match
            static_batch_size.merge_with(input_shape_i[0])

        batch_size = static_batch_size.value
        const_batch_size = batch_size
        if batch_size is None:
            batch_size = array_ops.shape(flat_input[0])[0]

        nest.assert_same_structure(initial_state, cell.state_size)
        state = initial_state
        flat_state = nest.flatten(state)
        flat_state = [ops.convert_to_tensor(s) for s in flat_state]
        state = nest.pack_sequence_as(structure=state,
                                      flat_sequence=flat_state)

        if emit_structure is not None:
            flat_emit_structure = nest.flatten(emit_structure)
            flat_emit_size = [emit.shape if emit.shape.is_fully_defined() else
                              array_ops.shape(emit) for emit in flat_emit_structure]
            flat_emit_dtypes = [emit.dtype for emit in flat_emit_structure]
        else:
            emit_structure = cell.output_size
            flat_emit_size = nest.flatten(emit_structure)
            flat_emit_dtypes = [flat_state[0].dtype] * len(flat_emit_size)

        flat_state_size = [s.shape if s.shape.is_fully_defined() else
                           array_ops.shape(s) for s in flat_state]
        flat_state_dtypes = [s.dtype for s in flat_state]

        flat_emit_ta = [
            tensor_array_ops.TensorArray(
                dtype=dtype_i,
                dynamic_size=True,
                element_shape=(tensor_shape.TensorShape([const_batch_size])
                               .concatenate(_maybe_tensor_shape_from_tensor(size_i))),
                size=0,
                name="rnn_output_%d" % i
            )
            for i, (dtype_i, size_i) in enumerate(zip(flat_emit_dtypes, flat_emit_size))
        ]
        emit_ta = nest.pack_sequence_as(structure=emit_structure, flat_sequence=flat_emit_ta)
        flat_zero_emit = [
            array_ops.zeros(_concat(batch_size, size_i), dtype_i)
            for size_i, dtype_i in zip(flat_emit_size, flat_emit_dtypes)]

        zero_emit = nest.pack_sequence_as(structure=emit_structure, flat_sequence=flat_zero_emit)

        flat_state_ta = [
            tensor_array_ops.TensorArray(
                dtype=dtype_i,
                dynamic_size=True,
                element_shape=(tensor_shape.TensorShape([const_batch_size])
                               .concatenate(_maybe_tensor_shape_from_tensor(size_i))),
                size=0,
                name="rnn_state_%d" % i
            )
            for i, (dtype_i, size_i) in enumerate(zip(flat_state_dtypes, flat_state_size))
        ]
        state_ta = nest.pack_sequence_as(structure=state, flat_sequence=flat_state_ta)

        def condition(unused_time, elements_finished, *_):
            return math_ops.logical_not(math_ops.reduce_all(elements_finished))

        def body(time, elements_finished, current_input, state_ta, emit_ta, state, loop_state):
            (next_output, cell_state) = cell(current_input, state)

            nest.assert_same_structure(state, cell_state)
            nest.assert_same_structure(cell.output_size, next_output)

            next_time = time + 1
            (next_finished, next_input, next_state, emit_output,
             next_loop_state) = loop_fn(next_time, next_output, cell_state, loop_state)

            nest.assert_same_structure(state, next_state)
            nest.assert_same_structure(current_input, next_input)
            nest.assert_same_structure(emit_ta, emit_output)

            # If loop_fn returns None for next_loop_state, just reuse the previous one.
            loop_state = loop_state if next_loop_state is None else next_loop_state

            def _copy_some_through(current, candidate):
                """Copy some tensors through via array_ops.where."""

                def copy_fn(cur_i, cand_i):
                    # TensorArray and scalar get passed through.
                    if isinstance(cur_i, tensor_array_ops.TensorArray):
                        return cand_i
                    if cur_i.shape.ndims == 0:
                        return cand_i
                    # Otherwise propagate the old or the new value.
                    with ops.colocate_with(cand_i):
                        return array_ops.where(elements_finished, cur_i, cand_i)

                return nest.map_structure(copy_fn, current, candidate)

            emit_output = _copy_some_through(zero_emit, emit_output)
            next_state = _copy_some_through(state, next_state)

            emit_ta = nest.map_structure(lambda ta, emit: ta.write(time, emit), emit_ta, emit_output)
            state_ta = nest.map_structure(lambda ta, state: ta.write(time, state), state_ta, next_state)

            elements_finished = math_ops.logical_or(elements_finished, next_finished)

            return (next_time, elements_finished, next_input, state_ta,
                    emit_ta, next_state, loop_state)

        returned = control_flow_ops.while_loop(
            condition, body, loop_vars=[
                time, elements_finished, next_input, state_ta,
                emit_ta, state, loop_state],
            parallel_iterations=parallel_iterations,
            swap_memory=swap_memory
        )

        (state_ta, emit_ta, final_state, final_loop_state) = returned[-4:]

        flat_states = nest.flatten(state_ta)
        flat_states = [array_ops.transpose(ta.stack(), (1, 0, 2)) for ta in flat_states]
        states = nest.pack_sequence_as(structure=state_ta, flat_sequence=flat_states)

        flat_outputs = nest.flatten(emit_ta)
        flat_outputs = [array_ops.transpose(ta.stack(), (1, 0, 2)) for ta in flat_outputs]
        outputs = nest.pack_sequence_as(structure=emit_ta, flat_sequence=flat_outputs)

        return (states, outputs, final_state)
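# Hedged usage sketch for the adapted raw_rnn above (illustrative names only;
# assumes TF1-style graph behaviour and the helper imports the fragment itself
# relies on).  The loop_fn contract follows the standard tf.nn.raw_rnn one: on
# the first call cell_output is None and loop_fn returns the initial state;
# returning None for emit_output makes raw_rnn fall back to cell.output_size.
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

max_time, batch_size, depth, num_units = 7, 4, 8, 16
inputs = tf.random.normal([max_time, batch_size, depth])   # time-major inputs
sequence_length = tf.fill([batch_size], max_time)
inputs_ta = tf.TensorArray(dtype=tf.float32, size=max_time).unstack(inputs)
cell = tf.nn.rnn_cell.LSTMCell(num_units)

def loop_fn(time, cell_output, cell_state, loop_state):
    if cell_output is None:                  # time == 0: initialisation call
        next_state = cell.zero_state(batch_size, tf.float32)
    else:
        next_state = cell_state
    elements_finished = time >= sequence_length             # [batch] bools
    next_input = tf.cond(tf.reduce_all(elements_finished),
                         lambda: tf.zeros([batch_size, depth], tf.float32),
                         lambda: inputs_ta.read(time))
    return elements_finished, next_input, next_state, cell_output, loop_state

states, outputs, final_state = raw_rnn(cell, loop_fn)
# outputs: [batch, max_time, num_units]; states: LSTMStateTuple whose c and h
# entries are each [batch, max_time, num_units]; final_state: last cell state.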
Example #9
 def _create_ta(s, d):
     return tensor_array_ops.TensorArray(dtype=d,
                                         size=0,
                                         dynamic_size=True,
                                         element_shape=_shape(
                                             decoder.batch_size, s))
Example #10
def _unstack_ta(inp):
    return tensor_array_ops.TensorArray(
        dtype=inp.dtype,
        size=array_ops.shape(inp)[0],
        element_shape=inp.get_shape()[1:]).unstack(inp)
 def _create_ta(name, dtype):
     return tensor_array_ops.TensorArray(dtype=dtype,
                                         size=time_steps,
                                         tensor_array_name=base_name + name)
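# Hedged usage sketch for _unstack_ta above (illustrative values only; the
# helper itself needs the TF-internal array_ops / tensor_array_ops imports
# already assumed by the fragment).
import tensorflow as tf

inp = tf.zeros([5, 2, 3])     # e.g. a time-major batch: [time, batch, depth]
ta = _unstack_ta(inp)         # TensorArray holding 5 elements of shape [2, 3]
step0 = ta.read(0)            # Tensor of shape [2, 3]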
Example #12
def map_fn(fn,
           elems,
           dtype=None,
           parallel_iterations=None,
           back_prop=True,
           swap_memory=False,
           infer_shape=True,
           name=None,
           fn_output_signature=None):
  """Transforms `elems` by applying `fn` to each element unstacked on axis 0.

  `map_fn` unstacks `elems` on axis 0 to obtain a sequence of elements;
  calls `fn` to transform each element; and then stacks the transformed
  values back together.

  #### Mapping functions with single-Tensor inputs and outputs

  If `elems` is a single tensor and `fn`'s signature is `tf.Tensor->tf.Tensor`,
  then `map_fn(fn, elems)` is equivalent to
  `tf.stack([fn(elem) for elem in tf.unstack(elems)])`.  E.g.:

  >>> tf.map_fn(fn=lambda t: tf.range(t, t + 3), elems=tf.constant([3, 5, 2]))
  <tf.Tensor: shape=(3, 3), dtype=int32, numpy=
    array([[3, 4, 5],
           [5, 6, 7],
           [2, 3, 4]], dtype=int32)>

  `map_fn(fn, elems).shape = [elems.shape[0]] + fn(elems[0]).shape`.

  #### Mapping functions with multi-arity inputs and outputs

  `map_fn` also supports functions with multi-arity inputs and outputs:

  * If `elems` is a tuple (or nested structure) of tensors, then those tensors
    must all have the same outer-dimension size (`num_elems`); and `fn` is
    used to transform each tuple (or structure) of corresponding slices from
    `elems`.  E.g., if `elems` is a tuple `(t1, t2, t3)`, then `fn` is used to
    transform each tuple of slices `(t1[i], t2[i], t3[i])`
    (where `0 <= i < num_elems`).

  * If `fn` returns a tuple (or nested structure) of tensors, then the
    result is formed by stacking corresponding elements from those structures.

  #### Specifying `fn`'s output signature

  If `fn`'s input and output signatures are different, then the output
  signature must be specified using `fn_output_signature`.  (The input and
  output signatures differ if their structures, dtypes, or tensor types do
  not match).  E.g.:

  >>> tf.map_fn(fn=tf.strings.length,  # input & output have different dtypes
  ...           elems=tf.constant(["hello", "moon"]),
  ...           fn_output_signature=tf.int32)
  <tf.Tensor: shape=(2,), dtype=int32, numpy=array([5, 4], dtype=int32)>
  >>> tf.map_fn(fn=tf.strings.join,  # input & output have different structures
  ...           elems=[tf.constant(['The', 'A']), tf.constant(['Dog', 'Cat'])],
  ...           fn_output_signature=tf.string)
  <tf.Tensor: shape=(2,), dtype=string,
   numpy=array([b'TheDog', b'ACat'], dtype=object)>

  `fn_output_signature` can be specified using any of the following:

    * A `tf.DType` or `tf.TensorSpec` (to describe a `tf.Tensor`)
    * A `tf.RaggedTensorSpec` (to describe a `tf.RaggedTensor`)
    * A `tf.SparseTensorSpec` (to describe a `tf.SparseTensor`)
    * A (possibly nested) tuple, list, or dict containing the above types.

  #### RaggedTensors

  `map_fn` supports `tf.RaggedTensor` inputs and outputs.  In particular:

    * If `elems` is a `RaggedTensor`, then `fn` will be called with each
      row of that ragged tensor.

      * If `elems` has only one ragged dimension, then the values passed to
        `fn` will be `tf.Tensor`s.
      * If `elems` has multiple ragged dimensions, then the values passed to
        `fn` will be `tf.RaggedTensor`s with one fewer ragged dimension.

    * If the result of `map_fn` should be a `RaggedTensor`, then use a
      `tf.RaggedTensorSpec` to specify `fn_output_signature`.

      * If `fn` returns `tf.Tensor`s with varying sizes, then use a
        `tf.RaggedTensorSpec` with `ragged_rank=0` to combine them into a
        single ragged tensor (which will have ragged_rank=1).
      * If `fn` returns `tf.RaggedTensor`s, then use a `tf.RaggedTensorSpec`
        with the same `ragged_rank`.

  >>> # Example: RaggedTensor input
  >>> rt = tf.ragged.constant([[1, 2, 3], [], [4, 5], [6]])
  >>> tf.map_fn(tf.reduce_sum, rt, fn_output_signature=tf.int32)
  <tf.Tensor: shape=(4,), dtype=int32, numpy=array([6, 0, 9, 6], dtype=int32)>

  >>> # Example: RaggedTensor output
  >>> elems = tf.constant([3, 5, 0, 2])
  >>> tf.map_fn(tf.range, elems,
  ...           fn_output_signature=tf.RaggedTensorSpec(shape=[None],
  ...                                                   dtype=tf.int32))
  <tf.RaggedTensor [[0, 1, 2], [0, 1, 2, 3, 4], [], [0, 1]]>

  Note: `map_fn` should only be used if you need to map a function over the
  *rows* of a `RaggedTensor`.  If you wish to map a function over the
  individual values, then you should use:

    * `tf.ragged.map_flat_values(fn, rt)`
      (if fn is expressible as TensorFlow ops)
    * `rt.with_flat_values(map_fn(fn, rt.flat_values))`
      (otherwise)

  E.g.:

  >>> rt = tf.ragged.constant([[1, 2, 3], [], [4, 5], [6]])
  >>> tf.ragged.map_flat_values(lambda x: x + 2, rt)
  <tf.RaggedTensor [[3, 4, 5], [], [6, 7], [8]]>

  #### SparseTensors

  `map_fn` supports `tf.SparseTensor` inputs and outputs.  In particular:

    * If `elems` is a `SparseTensor`, then `fn` will be called with each row
      of that sparse tensor. In particular, the value passed to `fn` will be a
      `tf.SparseTensor` with one fewer dimension than `elems`.

    * If the result of `map_fn` should be a `SparseTensor`, then use a
      `tf.SparseTensorSpec` to specify `fn_output_signature`.  The individual
      `SparseTensor`s returned by `fn` will be stacked into a single
      `SparseTensor` with one more dimension.

  >>> # Example: SparseTensor input
  >>> st = tf.SparseTensor([[0, 0], [2, 0], [2, 1]], [2, 3, 4], [4, 4])
  >>> tf.map_fn(tf.sparse.reduce_sum, st, fn_output_signature=tf.int32)
  <tf.Tensor: shape=(4,), dtype=int32, numpy=array([2, 0, 7, 0], dtype=int32)>

  >>> # Example: SparseTensor output
  >>> tf.sparse.to_dense(
  ...     tf.map_fn(tf.sparse.eye, tf.constant([2, 3]),
  ...               fn_output_signature=tf.SparseTensorSpec(None, tf.float32)))
  <tf.Tensor: shape=(2, 3, 3), dtype=float32, numpy=
    array([[[1., 0., 0.],
            [0., 1., 0.],
            [0., 0., 0.]],
           [[1., 0., 0.],
            [0., 1., 0.],
            [0., 0., 1.]]], dtype=float32)>

  Note: `map_fn` should only be used if you need to map a function over the
  *rows* of a `SparseTensor`.  If you wish to map a function over the nonzero
  values, then you should use:

    * `tf.SparseTensor(st.indices, fn(st.values), st.dense_shape)`
      (if the function is expressible as TensorFlow ops)
    * `tf.SparseTensor(st.indices, tf.map_fn(fn, st.values), st.dense_shape)`
      (otherwise).

  #### `map_fn` vs. vectorized operations

  `map_fn` will apply the operations used by `fn` to each element of `elems`,
  resulting in `O(elems.shape[0])` total operations.  This is somewhat
  mitigated by the fact that `map_fn` can process elements in parallel.
  However, a transform expressed using `map_fn` is still typically less
  efficient than an equivalent transform expressed using vectorized operations.

  `map_fn` should typically only be used if one of the following is true:

    * It is difficult or expensive to express the desired transform with
      vectorized operations.
    * `fn` creates large intermediate values, so an equivalent vectorized
      transform would take too much memory.
    * Processing elements in parallel is more efficient than an equivalent
      vectorized transform.
    * Efficiency of the transform is not critical, and using `map_fn` is
      more readable.

  E.g., the example given above that maps `fn=lambda t: tf.range(t, t + 3)`
  across `elems` could be rewritten more efficiently using vectorized ops:

  >>> elems = tf.constant([3, 5, 2])
  >>> tf.range(3) + tf.expand_dims(elems, 1)
  <tf.Tensor: shape=(3, 3), dtype=int32, numpy=
    array([[3, 4, 5],
           [5, 6, 7],
           [2, 3, 4]], dtype=int32)>

  In some cases, `tf.vectorized_map` can be used to automatically convert a
  function to a vectorized equivalent.

  #### Eager execution

  When executing eagerly, `map_fn` does not execute in parallel even if
  `parallel_iterations` is set to a value > 1. You can still get the
  performance benefits of running a function in parallel by using the
  `tf.function` decorator:

  >>> fn=lambda t: tf.range(t, t + 3)
  >>> @tf.function
  ... def func(elems):
  ...   return tf.map_fn(fn, elems, parallel_iterations=3)
  >>> func(tf.constant([3, 5, 2]))
  <tf.Tensor: shape=(3, 3), dtype=int32, numpy=
    array([[3, 4, 5],
           [5, 6, 7],
           [2, 3, 4]], dtype=int32)>


  Note that if you use the `tf.function` decorator, any non-TensorFlow Python
  code that you may have written in your function won't get executed. See
  `tf.function` for more details. The recommendation is to debug without
  `tf.function`, then switch to it to get the performance benefits of running
  `map_fn` in parallel.

  Args:
    fn: The callable to be performed.  It accepts one argument, which will have
      the same (possibly nested) structure as `elems`.  Its output must have the
      same structure as `fn_output_signature` if one is provided; otherwise it
      must have the same structure as `elems`.
    elems: A tensor or (possibly nested) sequence of tensors, each of which will
      be unstacked along their first dimension.  `fn` will be applied to the
      nested sequence of the resulting slices.  `elems` may include ragged and
      sparse tensors.
    dtype: Deprecated: Equivalent to `fn_output_signature`.
    parallel_iterations: (optional) The number of iterations allowed to run in
      parallel. When graph building, the default value is 10. While executing
      eagerly, the default value is set to 1.
    back_prop: (optional) False disables support for back propagation.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    infer_shape: (optional) False disables tests for consistent output shapes.
    name: (optional) Name prefix for the returned tensors.
    fn_output_signature: The output signature of `fn`. Must be specified if
      `fn`'s input and output signatures are different (i.e., if their
      structures, dtypes, or tensor types do not match).
      `fn_output_signature` can be specified using any of the following:

      * A `tf.DType` or `tf.TensorSpec` (to describe a `tf.Tensor`)
      * A `tf.RaggedTensorSpec` (to describe a `tf.RaggedTensor`)
      * A `tf.SparseTensorSpec` (to describe a `tf.SparseTensor`)
      * A (possibly nested) tuple, list, or dict containing the above types.

  Returns:
    A tensor or (possibly nested) sequence of tensors.  Each tensor stacks the
    results of applying `fn` to tensors unstacked from `elems` along the first
    dimension, from first to last.  The result may include ragged and sparse
    tensors.

  Raises:
    TypeError: if `fn` is not callable or the structure of the output of
      `fn` and `fn_output_signature` do not match.
    ValueError: if the lengths of the output of `fn` and `fn_output_signature`
      do not match.

  Examples:

    >>> elems = np.array([1, 2, 3, 4, 5, 6])
    >>> tf.map_fn(lambda x: x * x, elems)
    <tf.Tensor: shape=(6,), dtype=int64, numpy=array([ 1,  4,  9, 16, 25, 36])>

    >>> elems = (np.array([1, 2, 3]), np.array([-1, 1, -1]))
    >>> tf.map_fn(lambda x: x[0] * x[1], elems, fn_output_signature=tf.int64)
    <tf.Tensor: shape=(3,), dtype=int64, numpy=array([-1,  2, -3])>

    >>> elems = np.array([1, 2, 3])
    >>> tf.map_fn(lambda x: (x, -x), elems,
    ...          fn_output_signature=(tf.int64, tf.int64))
    (<tf.Tensor: shape=(3,), dtype=int64, numpy=array([1, 2, 3])>,
     <tf.Tensor: shape=(3,), dtype=int64, numpy=array([-1, -2, -3])>)
  """
  # This function uses a `while_loop` to call `fn` on each value of the input
  # tensor(s) (unstacked on dimension 0).  The following sequence of variables
  # are used to transform the input tensor(s) (`elems`) into the output
  # tensor(s) (`result`):
  #
  #   - Preparing and unstacking input values for the while_loop:
  #     - elems: The input tensor(s) to map_fn. May include composite tensors.
  #     - elems_flat: Flattened list of tensors from elems (using nest.flatten)
  #                   May include composite tensors.
  #     - elems_batchable: Concatenation of "batchable tensor lists" for each
  #                        tensor in elems_flat.  This "boxes" composite tensors
  #                        into sliceable tf.Tensor objects.  For more info see:
  #                        TensorSpec._to_batched_tensor_list
  #     - elems_batchable_ta: List of TensorArrays used to unstack each Tensor
  #                           in elems_batchable into elems_value_batchable.
  #
  #   - Calling `fn` on each unstacked value in the body of the while_loop:
  #     - elems_value_batchable: Single unstacked value from elems_batchable.
  #     - elems_value_flat: Single unstacked value from elems_flat,
  #                         constructed from elems_value_batchable (using
  #                         TensorSpec._from_tensor_list).
  #     - elems_value: Single unstacked value from elems (the input to fn).
  #     - result_value: Result of calling `fn(elems_value)`.  May contain
  #                     composite tensors.
  #     - result_value_flat: Flattened list of tensors from result_value.
  #                          May contain composite tensors.
  #     - result_value_batchable: Concatenation of batchable tensor lists for
  #                               each tensor in result_value_flat
  #                               (using TensorSpec._to_tensor_list).
  #
  #   - Collecting and stacking output values from the while_loop:
  #     - result_batchable_ta: List of TensorArrays used to stack each tensor
  #                            in result_value_batchable into result_batchable.
  #     - result_batchable: Stacked tensors from result_batchable_ta.
  #     - result_flat: Flat list of tensors for the result, constructed from
  #                    result_batchable (using TensorSpec._from_tensor_list).
  #     - result: Structured result value packed from result_flat
  #               (using nest.pack_sequence_as).

  if fn_output_signature is None:
    fn_output_signature = dtype

  if not callable(fn):
    raise TypeError("fn must be callable.")

  in_graph_mode = not context.executing_eagerly()
  # Set the default number of parallel_iterations depending on graph/eager mode.
  if in_graph_mode and not parallel_iterations:
    parallel_iterations = 10
  elif not in_graph_mode and not parallel_iterations:
    parallel_iterations = 1
  elif not in_graph_mode and parallel_iterations > 1:
    logging.log_first_n(
        logging.WARN, "Setting parallel_iterations > 1 has no "
        "effect when executing eagerly. Consider calling map_fn"
        " with tf.function to execute fn in "
        "parallel.", 1)
    parallel_iterations = 1

  # Flatten the input tensors, and get the TypeSpec for each one.
  elems_flat = nest.flatten(elems)
  elems_flat_signature = [type_spec.type_spec_from_value(e) for e in elems_flat]
  elems_unflatten = lambda x: nest.pack_sequence_as(elems, x)

  # Flatten fn's output signature.
  if fn_output_signature is None:
    # If fn_output_signature was not specified, then assume that it matches the
    # input signature.
    result_flat_signature = [
        _most_general_compatible_type(s)._unbatch()  # pylint: disable=protected-access
        for s in elems_flat_signature
    ]
    result_unflatten = elems_unflatten
  else:
    result_flat_signature = [
        _dtype_to_spec(d) for d in nest.flatten(fn_output_signature)
    ]
    result_unflatten = lambda x: nest.pack_sequence_as(fn_output_signature, x)

  with ops.name_scope(name, "map", elems_flat):
    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
    # supported in Eager
    if in_graph_mode:
      # Any get_variable calls in fn will cache the first call locally
      # and not issue repeated network I/O requests for each iteration.
      varscope = vs.get_variable_scope()
      varscope_caching_device_was_none = False
      if varscope.caching_device is None:
        # TODO(ebrevdo): Change to using colocate_with here and in other
        # methods.
        varscope.set_caching_device(lambda op: op.device)
        varscope_caching_device_was_none = True

    elems_flat = [
        ops.convert_to_tensor_or_composite(t, name="elem") for t in elems_flat
    ]

    # Check that inputs are not scalars.
    elems_static_shape = elems_flat[0].shape
    if elems_static_shape.ndims is not None and elems_static_shape.ndims < 1:
      if len(elems_flat) == 1:
        raise ValueError("elems must be a 1+ dimensional Tensor, not a scalar")
      else:
        raise ValueError(
            "elements in elems must be 1+ dimensional Tensors, not scalars"
        )

    # Box any composite tensors into tensor lists.
    elems_batchable = _elems_flat_to_batchable(elems_flat)

    # Find the number of iterations, n.  (may be known statically.)
    n_static = tensor_shape.Dimension(
        tensor_shape.dimension_value(
            elems_batchable[0].get_shape().with_rank_at_least(1)[0]))
    for tensor in elems_batchable[1:]:
      n_static.merge_with(
          tensor_shape.Dimension(
              tensor_shape.dimension_value(
                  tensor.get_shape().with_rank_at_least(1)[0])))
    n = n_static.value or array_ops.shape(elems_batchable[0])[0]

    # Convert elems to tensor array.
    # TODO(edloper): Should we set infer_shape=False for composite tensors?
    elems_batchable_ta = [
        tensor_array_ops.TensorArray(
            dtype=t.dtype, size=n, dynamic_size=False, infer_shape=True)
        for t in elems_batchable
    ]
    # Unpack elements
    elems_batchable_ta = [
        ta.unstack(t) for (ta, t) in zip(elems_batchable_ta, elems_batchable)
    ]

    i = constant_op.constant(0)

    # Prepare result tensor array.
    # TODO(edloper): Should we set infer_shape=False for composite tensors?
    result_batchable_dtype = _result_flat_signature_to_batchable_dtype(
        result_flat_signature)
    result_batchable_ta = [
        tensor_array_ops.TensorArray(
            dtype=dt, size=n, dynamic_size=False, infer_shape=infer_shape)
        for dt in result_batchable_dtype
    ]

    def compute(i, tas):
      """The loop body of map_fn.

      Args:
        i: the loop counter
        tas: the flat TensorArray accumulator list

      Returns:
        (i + 1, tas): the updated counter + updated TensorArrays

      Raises:
        TypeError: if fn_output_signature and result_value structure don't match
        ValueError: if fn_output_signature and result_value lengths don't match
      """
      elems_value_batchable = [ta.read(i) for ta in elems_batchable_ta]
      elems_value_flat = _elems_value_batchable_to_flat(elems_value_batchable,
                                                        elems_flat_signature)
      elems_value = elems_unflatten(elems_value_flat)
      result_value = fn(elems_value)
      nest.assert_same_structure(fn_output_signature or elems, result_value)
      result_value_flat = nest.flatten(result_value)
      result_value_batchable = _result_value_flat_to_batchable(
          result_value_flat, result_flat_signature)
      tas = [
          ta.write(i, value) for (ta, value) in zip(tas, result_value_batchable)
      ]
      return (i + 1, tas)

    _, r_a = control_flow_ops.while_loop(
        lambda i, _: i < n,
        compute, (i, result_batchable_ta),
        parallel_iterations=parallel_iterations,
        back_prop=back_prop,
        swap_memory=swap_memory,
        maximum_iterations=n)
    result_batchable = [r.stack() for r in r_a]

    # Update each output tensor w/ static shape info about the outer dimension.
    for r in result_batchable:
      r.set_shape(tensor_shape.TensorShape(n_static).concatenate(
          r.get_shape()[1:]))

    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
    # supported in Eager
    if in_graph_mode and varscope_caching_device_was_none:
      varscope.set_caching_device(None)

    result_flat = _result_batchable_to_flat(result_batchable,
                                            result_flat_signature)
    result = result_unflatten(result_flat)
    return result
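# Hedged sketch (not part of map_fn itself): exercising the nested-structure
# path the comment block above walks through -- `elems` is a dict, `fn`
# receives one dict of slices per iteration, and `fn_output_signature`
# describes the (different) output structure.  Uses the public tf.map_fn
# alias that the doctests above already rely on.
import tensorflow as tf

elems = {"ids": tf.constant([1, 2, 3]),
         "weights": tf.constant([0.5, 1.0, 2.0])}

result = tf.map_fn(
    lambda d: tf.cast(d["ids"], tf.float32) * d["weights"],
    elems,
    fn_output_signature=tf.float32)
# result: <tf.Tensor: shape=(3,), dtype=float32, numpy=array([0.5, 2., 6.], ...)>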
Example #13
 def zero_state(self, batch_size, dtype):
   return (array_ops.zeros([], dtype=dtypes.int32),
           tensor_array_ops.TensorArray(
               dtype=dtype, size=0, dynamic_size=True))
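# Hedged note (illustrative only): a state made of a step counter plus a
# dynamically sized TensorArray, as returned above, is meant to be threaded
# through a step function or while_loop body, e.g.:
#   time, ta = self.zero_state(batch_size, dtypes.float32)
#   ta = ta.write(time, some_batch_of_values)   # grows because dynamic_size=True
#   time += 1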
Example #14
    def __init__(self,
                 num_emb,
                 batch_size,
                 emb_dim,
                 hidden_dim,
                 sequence_length,
                 start_token,
                 learning_rate=0.01,
                 reward_gamma=0.95):

        self.num_emb = num_emb
        self.batch_size = batch_size
        self.emb_dim = emb_dim
        self.hidden_dim = hidden_dim
        self.sequence_length = sequence_length
        self.start_token = tf.constant([start_token] * self.batch_size,
                                       dtype=tf.int32)
        self.learning_rate = tf.Variable(float(learning_rate), trainable=False)
        self.reward_gamma = reward_gamma
        self.g_params = []
        self.d_params = []
        self.temperature = 1.0
        self.grad_clip = 5.0
        self.expected_reward = tf.Variable(tf.zeros([self.sequence_length]))

        with tf.variable_scope('generator'):
            self.g_embeddings = tf.Variable(
                self.init_matrix([self.num_emb, self.emb_dim]))
            self.g_params.append(self.g_embeddings)
            self.g_recurrent_unit = self.create_recurrent_unit(
                self.g_params)  # maps h_tm1 to h_t for generator
            self.g_output_unit = self.create_output_unit(
                self.g_params)  # maps h_t to o_t (output token logits)

        # placeholder definition
        self.x = tf.placeholder(tf.int32,
                                shape=[self.batch_size, self.sequence_length])
        # sequence of indices of true data, not including start token

        self.rewards = tf.placeholder(
            tf.float32, shape=[self.batch_size, self.sequence_length])
        # get from rollout policy and discriminator

        # processed for batch
        with tf.device("/cpu:0"):
            inputs = tf.split(axis=1,
                              num_or_size_splits=self.sequence_length,
                              value=tf.nn.embedding_lookup(
                                  self.g_embeddings, self.x))
            self.processed_x = tf.stack([
                tf.squeeze(input_, [1]) for input_ in inputs
            ])  # seq_length x batch_size x emb_dim

        self.h0 = tf.zeros([self.batch_size, self.hidden_dim])
        self.h0 = tf.stack([self.h0, self.h0])

        gen_o = tensor_array_ops.TensorArray(dtype=tf.float32,
                                             size=self.sequence_length,
                                             dynamic_size=False,
                                             infer_shape=True)
        gen_x = tensor_array_ops.TensorArray(dtype=tf.int32,
                                             size=self.sequence_length,
                                             dynamic_size=False,
                                             infer_shape=True)

        def _g_recurrence(i, x_t, h_tm1, gen_o, gen_x):
            h_t = self.g_recurrent_unit(x_t, h_tm1)  # hidden_memory_tuple
            o_t = self.g_output_unit(h_t)  # batch x vocab , logits not prob
            log_prob = tf.log(tf.nn.softmax(o_t))
            next_token = tf.cast(
                tf.reshape(tf.multinomial(log_prob, 1), [self.batch_size]),
                tf.int32)
            x_tp1 = tf.nn.embedding_lookup(self.g_embeddings,
                                           next_token)  # batch x emb_dim
            gen_o = gen_o.write(
                i,
                tf.reduce_sum(
                    tf.multiply(tf.one_hot(next_token, self.num_emb, 1.0, 0.0),
                                tf.nn.softmax(o_t)), 1))  # [batch_size] , prob
            gen_x = gen_x.write(i, next_token)  # indices, batch_size
            return i + 1, x_tp1, h_t, gen_o, gen_x

        _, _, _, self.gen_o, self.gen_x = control_flow_ops.while_loop(
            cond=lambda i, _1, _2, _3, _4: i < self.sequence_length,
            body=_g_recurrence,
            loop_vars=(tf.constant(0, dtype=tf.int32),
                       tf.nn.embedding_lookup(self.g_embeddings,
                                              self.start_token), self.h0,
                       gen_o, gen_x))

        self.gen_x = self.gen_x.stack()  # seq_length x batch_size
        # batch_size x seq_length
        self.gen_x = tf.transpose(self.gen_x, perm=[1, 0])

        # supervised pretraining for generator
        g_predictions = tensor_array_ops.TensorArray(dtype=tf.float32,
                                                     size=self.sequence_length,
                                                     dynamic_size=False,
                                                     infer_shape=True)

        g_logits = tensor_array_ops.TensorArray(dtype=tf.float32,
                                                size=self.sequence_length,
                                                dynamic_size=False,
                                                infer_shape=True)

        ta_emb_x = tensor_array_ops.TensorArray(dtype=tf.float32,
                                                size=self.sequence_length)
        ta_emb_x = ta_emb_x.unstack(self.processed_x)

        def _pretrain_recurrence(i, x_t, h_tm1, g_predictions, g_logits):
            h_t = self.g_recurrent_unit(x_t, h_tm1)
            o_t = self.g_output_unit(h_t)
            g_predictions = g_predictions.write(
                i, tf.nn.softmax(o_t))  # batch x vocab_size
            g_logits = g_logits.write(i, o_t)  # batch x vocab_size
            x_tp1 = ta_emb_x.read(i)
            return i + 1, x_tp1, h_t, g_predictions, g_logits

        _, _, _, self.g_predictions, self.g_logits = control_flow_ops.while_loop(
            cond=lambda i, _1, _2, _3, _4: i < self.sequence_length,
            body=_pretrain_recurrence,
            loop_vars=(tf.constant(0, dtype=tf.int32),
                       tf.nn.embedding_lookup(self.g_embeddings,
                                              self.start_token), self.h0,
                       g_predictions, g_logits))

        self.g_predictions = tf.transpose(
            self.g_predictions.stack(),
            perm=[1, 0, 2])  # batch_size x seq_length x vocab_size

        self.g_logits = tf.transpose(
            self.g_logits.stack(),
            perm=[1, 0, 2])  # batch_size x seq_length x vocab_size
        # pretraining loss
        self.pretrain_loss = -tf.reduce_sum(
            tf.one_hot(tf.to_int32(tf.reshape(
                self.x, [-1])), self.num_emb, 1.0, 0.0) * tf.log(
                    tf.clip_by_value(
                        tf.reshape(self.g_predictions, [-1, self.num_emb]),
                        1e-20, 1.0))) / (self.sequence_length *
                                         self.batch_size)

        # training updates
        pretrain_opt = self.g_optimizer(self.learning_rate)

        self.pretrain_grad, _ = tf.clip_by_global_norm(
            tf.gradients(self.pretrain_loss, self.g_params), self.grad_clip)
        self.pretrain_updates = pretrain_opt.apply_gradients(
            zip(self.pretrain_grad, self.g_params))

        #######################################################################
        #  Unsupervised Training
        #######################################################################
        self.g_loss = -tf.reduce_sum(
            tf.reduce_sum(
                tf.one_hot(tf.to_int32(tf.reshape(
                    self.x, [-1])), self.num_emb, 1.0, 0.0) * tf.log(
                        tf.clip_by_value(
                            tf.reshape(self.g_predictions, [-1, self.num_emb]),
                            1e-20, 1.0)), 1) * tf.reshape(self.rewards, [-1]))

        g_opt = self.g_optimizer(self.learning_rate)

        self.g_grad, _ = tf.clip_by_global_norm(
            tf.gradients(self.g_loss, self.g_params), self.grad_clip)
        self.g_updates = g_opt.apply_gradients(zip(self.g_grad, self.g_params))
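
The block above wires two losses over the same `g_predictions`: a maximum-likelihood pretraining loss and a reward-weighted (REINFORCE-style) loss. A minimal NumPy sketch of that arithmetic with illustrative toy shapes (none of the names below come from the class itself):

```python
import numpy as np

# Illustrative toy shapes; these are assumptions for the sketch only.
batch_size, seq_length, vocab = 2, 3, 4
g_predictions = np.random.dirichlet(np.ones(vocab), size=(batch_size, seq_length))
x = np.random.randint(0, vocab, size=(batch_size, seq_length))   # target tokens
rewards = np.random.rand(batch_size, seq_length)                 # per-token rewards

# log-probability of each target token, clipped as in the graph above
token_prob = np.take_along_axis(g_predictions, x[..., None], axis=2).squeeze(-1)
log_prob = np.log(np.clip(token_prob, 1e-20, 1.0))

pretrain_loss = -np.sum(log_prob) / (seq_length * batch_size)  # MLE pretraining
g_loss = -np.sum(log_prob * rewards)                           # reward-weighted loss
print(pretrain_loss, g_loss)
```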
Example #15
0
 def create_ta(elem):
     return tensor_array_ops.TensorArray(dtype=elem.dtype,
                                         size=n,
                                         dynamic_size=False,
                                         infer_shape=True).unstack(elem)
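
`create_ta` unstacks a tensor along dimension 0 into a TensorArray of size `n` (a free variable from its enclosing scope). A small round-trip sketch using the public `tf.TensorArray` API instead of the internal `tensor_array_ops` module, assuming TF 2 eager execution:

```python
import tensorflow as tf

elem = tf.constant([[1., 2.], [3., 4.], [5., 6.]])   # n = 3 rows
ta = tf.TensorArray(dtype=elem.dtype, size=3,
                    dynamic_size=False, infer_shape=True).unstack(elem)
print(ta.read(1))   # tf.Tensor([3. 4.], ...)
```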
Example #16
0
 def fn():
     ta = tensor_array_ops.TensorArray(dtype=dtypes.float32,
                                       size=0,
                                       infer_shape=False)
     return ta.stack()
Example #17
0
def map_fn(fn,
           elems,
           dtype=None,
           parallel_iterations=None,
           back_prop=True,
           swap_memory=False,
           infer_shape=True,
           name=None):
    """map on the list of tensors unpacked from `elems` on dimension 0.

  The simplest version of `map_fn` repeatedly applies the callable `fn` to a
  sequence of elements from first to last. The elements are made of the
  tensors unpacked from `elems`. `dtype` is the data type of the return
  value of `fn`. Users must provide `dtype` if it is different from
  the data type of `elems`.

  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
  of the result tensor is `[values.shape[0]] + fn(values[0]).shape`.

  This method also allows multi-arity `elems` and output of `fn`.  If `elems`
  is a (possibly nested) list or tuple of tensors, then each of these tensors
  must have a matching first (unpack) dimension.  The signature of `fn` may
  match the structure of `elems`.  That is, if `elems` is
  `(t1, [t2, t3, [t4, t5]])`, then an appropriate signature for `fn` is:
  `fn = lambda (t1, [t2, t3, [t4, t5]]):`.

  Furthermore, `fn` may emit a different structure than its input.  For example,
  `fn` may look like: `fn = lambda t1: return (t1 + 1, t1 - 1)`.  In this case,
  the `dtype` parameter is not optional: `dtype` must be a type or (possibly
  nested) tuple of types matching the output of `fn`.

  To apply a functional operation to the nonzero elements of a SparseTensor
  one of the following methods is recommended. First, if the function is
  expressible as TensorFlow ops, use

  ```python
    result = SparseTensor(input.indices, fn(input.values), input.dense_shape)
  ```

  If, however, the function is not expressible as a TensorFlow op, then use

  ```python
  result = SparseTensor(
    input.indices, map_fn(fn, input.values), input.dense_shape)
  ```

  instead.

  When executing eagerly, map_fn does not execute in parallel even if
  `parallel_iterations` is set to a value > 1. You can still get the
  performance benefits of running a function in parallel by using the
  `tf.contrib.eager.defun` decorator,

  ```python
  # Assume the function being used in map_fn is fn.
  # To ensure map_fn calls fn in parallel, use the defun decorator.
  @tf.contrib.eager.defun
  def func(tensor):
    return tf.map_fn(fn, tensor)
  ```

  Note that if you use the defun decorator, any non-TensorFlow Python code
  that you may have written in your function won't get executed. See
  `tf.contrib.eager.defun` for more details. The recommendation would be to
  debug without defun but switch to defun to get performance benefits of
  running map_fn in parallel.

  Args:
    fn: The callable to be performed.  It accepts one argument, which will
      have the same (possibly nested) structure as `elems`.  Its output
      must have the same structure as `dtype` if one is provided, otherwise
      it must have the same structure as `elems`.
    elems: A tensor or (possibly nested) sequence of tensors, each of which
      will be unpacked along their first dimension.  The nested sequence
      of the resulting slices will be applied to `fn`.
    dtype: (optional) The output type(s) of `fn`.  If `fn` returns a structure
      of Tensors differing from the structure of `elems`, then `dtype` is not
      optional and must have the same structure as the output of `fn`.
    parallel_iterations: (optional) The number of iterations allowed to run
      in parallel. When graph building, the default value is 10. While executing
      eagerly, the default value is set to 1.
    back_prop: (optional) True enables support for back propagation.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    infer_shape: (optional) False disables tests for consistent output shapes.
    name: (optional) Name prefix for the returned tensors.

  Returns:
    A tensor or (possibly nested) sequence of tensors.  Each tensor packs the
    results of applying `fn` to tensors unpacked from `elems` along the first
    dimension, from first to last.

  Raises:
    TypeError: if `fn` is not callable or the structure of the output of
      `fn` and `dtype` do not match, or if elems is a SparseTensor.
    ValueError: if the lengths of the output of `fn` and `dtype` do not match.

  Examples:
    ```python
    elems = np.array([1, 2, 3, 4, 5, 6])
    squares = map_fn(lambda x: x * x, elems)
    # squares == [1, 4, 9, 16, 25, 36]
    ```

    ```python
    elems = (np.array([1, 2, 3]), np.array([-1, 1, -1]))
    alternate = map_fn(lambda x: x[0] * x[1], elems, dtype=tf.int64)
    # alternate == [-1, 2, -3]
    ```

    ```python
    elems = np.array([1, 2, 3])
    alternates = map_fn(lambda x: (x, -x), elems, dtype=(tf.int64, tf.int64))
    # alternates[0] == [1, 2, 3]
    # alternates[1] == [-1, -2, -3]
    ```
  """
    if not callable(fn):
        raise TypeError("fn must be callable.")

    if isinstance(elems, sparse_tensor.SparseTensor):
        raise TypeError(
            "To perform a map on the values of a sparse tensor use either "
            " SparseTensor(input.indices, fn(input.values), input.dense_shape) or "
            " SparseTensor(input.indices, map_fn(fn, input.values), "
            "input.dense_shape)")

    in_graph_mode = not context.executing_eagerly()
    # Set the default number of parallel_iterations depending on graph/eager mode.
    if in_graph_mode and not parallel_iterations:
        parallel_iterations = 10
    elif not in_graph_mode and not parallel_iterations:
        parallel_iterations = 1

    if not in_graph_mode and parallel_iterations > 1:
        logging.log_first_n(
            logging.WARN, "Setting parallel_iterations > 1 has no "
            "effect when executing eagerly. Consider calling map_fn"
            " with tf.contrib.eager.defun to execute fn in "
            "parallel.", 1)
        parallel_iterations = 1

    input_is_sequence = nest.is_sequence(elems)
    input_flatten = lambda x: nest.flatten(x) if input_is_sequence else [x]

    def input_pack(x):
        return nest.pack_sequence_as(elems, x) if input_is_sequence else x[0]

    if dtype is None:
        output_is_sequence = input_is_sequence
        output_flatten = input_flatten
        output_pack = input_pack
    else:
        output_is_sequence = nest.is_sequence(dtype)
        output_flatten = lambda x: nest.flatten(
            x) if output_is_sequence else [x]

        def output_pack(x):
            return (nest.pack_sequence_as(dtype, x)
                    if output_is_sequence else x[0])

    elems_flat = input_flatten(elems)

    with ops.name_scope(name, "map", elems_flat):
        # TODO(akshayka): Remove the in_graph_mode check once caching devices are
        # supported in Eager
        if in_graph_mode:
            # Any get_variable calls in fn will cache the first call locally
            # and not issue repeated network I/O requests for each iteration.
            varscope = vs.get_variable_scope()
            varscope_caching_device_was_none = False
            if varscope.caching_device is None:
                # TODO(ebrevdo): Change to using colocate_with here and in other
                # methods.
                varscope.set_caching_device(lambda op: op.device)
                varscope_caching_device_was_none = True

        elems_flat = [
            ops.convert_to_tensor(elem, name="elem") for elem in elems_flat
        ]

        dtype = dtype or input_pack([elem.dtype for elem in elems_flat])
        dtype_flat = output_flatten(dtype)

        # Convert elems to tensor array. n may be known statically.
        static_shape = elems_flat[0].shape
        if static_shape.ndims is not None and static_shape.ndims < 1:
            if len(elems_flat) == 1:
                raise ValueError(
                    "elems must be a 1+ dimensional Tensor, not a scalar")
            else:
                raise ValueError(
                    "elements in elems must be 1+ dimensional Tensors, not scalars"
                )
        n = (tensor_shape.dimension_value(static_shape[0])
             or array_ops.shape(elems_flat[0])[0])

        # TensorArrays are always flat
        elems_ta = [
            tensor_array_ops.TensorArray(dtype=elem.dtype,
                                         size=n,
                                         dynamic_size=False,
                                         infer_shape=True)
            for elem in elems_flat
        ]
        # Unpack elements
        elems_ta = [
            elem_ta.unstack(elem)
            for elem_ta, elem in zip(elems_ta, elems_flat)
        ]

        i = constant_op.constant(0)

        accs_ta = [
            tensor_array_ops.TensorArray(dtype=dt,
                                         size=n,
                                         dynamic_size=False,
                                         infer_shape=infer_shape)
            for dt in dtype_flat
        ]

        def compute(i, tas):
            """The loop body of map_fn.

      Args:
        i: the loop counter
        tas: the flat TensorArray accumulator list

      Returns:
        (i + 1, tas): the updated counter + updated TensorArrays

      Raises:
        TypeError: if dtype and packed_fn_values structure do not match
        ValueError: if dtype and packed_fn_values lengths do not match
      """
            packed_values = input_pack(
                [elem_ta.read(i) for elem_ta in elems_ta])
            packed_fn_values = fn(packed_values)
            nest.assert_same_structure(dtype or elems, packed_fn_values)
            flat_fn_values = output_flatten(packed_fn_values)
            tas = [
                ta.write(i, value) for (ta, value) in zip(tas, flat_fn_values)
            ]
            return (i + 1, tas)

        _, r_a = control_flow_ops.while_loop(
            lambda i, _: i < n,
            compute, (i, accs_ta),
            parallel_iterations=parallel_iterations,
            back_prop=back_prop,
            swap_memory=swap_memory,
            maximum_iterations=n)
        results_flat = [r.stack() for r in r_a]

        n_static = tensor_shape.Dimension(
            tensor_shape.dimension_value(
                elems_flat[0].get_shape().with_rank_at_least(1)[0]))
        for elem in elems_flat[1:]:
            n_static.merge_with(
                tensor_shape.Dimension(
                    tensor_shape.dimension_value(
                        elem.get_shape().with_rank_at_least(1)[0])))
        for r in results_flat:
            r.set_shape(
                tensor_shape.TensorShape(n_static).concatenate(
                    r.get_shape()[1:]))

        # TODO(akshayka): Remove the in_graph_mode check once caching devices are
        # supported in Eager
        if in_graph_mode and varscope_caching_device_was_none:
            varscope.set_caching_device(None)

        return output_pack(results_flat)
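
The docstring above recommends two workarounds for `SparseTensor` inputs. A short sketch of both, written against the public TF 2 API (`tf.sparse.SparseTensor`, `tf.map_fn`) rather than the internal modules used in the source:

```python
import tensorflow as tf

sp = tf.sparse.SparseTensor(indices=[[0, 0], [1, 2]],
                            values=tf.constant([3.0, 4.0]),
                            dense_shape=[2, 3])

# fn expressible as TF ops: wrap the transformed values directly
squared = tf.sparse.SparseTensor(sp.indices, tf.square(sp.values), sp.dense_shape)

# otherwise, map over the nonzero values with map_fn
squared_map = tf.sparse.SparseTensor(sp.indices,
                                     tf.map_fn(lambda v: v * v, sp.values),
                                     sp.dense_shape)
```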
Example #18
0
 def fn():
     ta = tensor_array_ops.TensorArray(dtype=dtypes.float32,
                                       size=0,
                                       infer_shape=True)
     ta = ta.unstack(array_ops.zeros([0, 3, 5]))
     return ta.concat()
Example #19
0
 def call(self, inputs):
   samples = tensor_array_ops.TensorArray(
       dtype=dtypes.float32, size=array_ops.shape(inputs)[0])
   for idx, sample in enumerate(inputs):
     samples = samples.write(idx, math_ops.square(sample))
   return samples.stack()
Example #20
0
 def fn():
     ta = tensor_array_ops.TensorArray(dtype=dtypes.float32,
                                       tensor_array_name="foo",
                                       size=3)
     return ta.write(-1, constant_op.constant(7)).flow
Example #21
0
    def __init__(self, num_emb, emb_dim, hidden_dim,
                 sequence_length, start_token,
                 learning_rate=0.01, reward_gamma=0.9):
        self.num_emb = num_emb
        self.emb_dim = emb_dim
        self.hidden_dim = hidden_dim
        self.sequence_length = sequence_length
        self.start_token = tf.constant(start_token, dtype=tf.int32)
        self.learning_rate = tf.Variable(float(learning_rate), trainable=False)
        self.reward_gamma = reward_gamma
        self.g_params = []
        self.d_params = []

        self.expected_reward = tf.Variable(tf.zeros([self.sequence_length]))

        with tf.variable_scope('generator'):
            self.g_embeddings = tf.Variable(self.init_matrix([self.num_emb, self.emb_dim]))
            self.g_params.append(self.g_embeddings)
            self.g_recurrent_unit = self.create_recurrent_unit(self.g_params)  # maps h_tm1 to h_t for generator
            self.g_output_unit = self.create_output_unit(self.g_params, self.g_embeddings)  # maps h_t to o_t (output token logits)

        with tf.variable_scope('discriminator'):
            self.d_embeddings = tf.Variable(self.init_matrix([self.num_emb, self.emb_dim]))
            self.d_params.append(self.d_embeddings)
            self.d_recurrent_unit = self.create_recurrent_unit(self.d_params)  # maps h_tm1 to h_t for discriminator
            self.d_classifier_unit = self.create_classifier_unit(self.d_params)  # maps h_t to class prediction logits
            self.d_h0 = tf.Variable(self.init_vector([self.hidden_dim]))
            self.d_params.append(self.d_h0)

        self.h0 = tf.placeholder(tf.float32, shape=[self.hidden_dim])  # initial random vector for generator
        self.x = tf.placeholder(tf.int32, shape=[self.sequence_length])  # sequence of indices of true data, not including start token
        self.samples = tf.placeholder(tf.float32, shape=[self.sequence_length])  # random samples from [0, 1]

        # generator on initial randomness
        gen_o = tensor_array_ops.TensorArray(dtype=tf.float32, size=self.sequence_length,
                                             dynamic_size=False, infer_shape=True)
        gen_x = tensor_array_ops.TensorArray(dtype=tf.int32, size=self.sequence_length,
                                             dynamic_size=False, infer_shape=True)
        samples = tensor_array_ops.TensorArray(
            dtype=tf.float32, size=self.sequence_length)
        samples = samples.unpack(self.samples)
        def _g_recurrence(i, x_t, h_tm1, gen_o, gen_x):
            h_t = self.g_recurrent_unit(x_t, h_tm1)
            o_t = self.g_output_unit(h_t)
            sample = samples.read(i)
            o_cumsum = _cumsum(o_t, self.num_emb)  # prepare for sampling
            next_token = tf.to_int32(tf.reduce_min(tf.where(sample < o_cumsum)))  # sample
            x_tp1 = tf.gather(self.g_embeddings, next_token)
            gen_o = gen_o.write(i, tf.gather(o_t, next_token))  # we only need the sampled token's probability
            gen_x = gen_x.write(i, next_token)  # indices, not embeddings
            return i + 1, x_tp1, h_t, gen_o, gen_x

        _, _, _, self.gen_o, self.gen_x = control_flow_ops.while_loop(
            cond=lambda i, _1, _2, _3, _4: i < self.sequence_length,
            body=_g_recurrence,
            loop_vars=(tf.constant(0, dtype=tf.int32),
                       tf.gather(self.g_embeddings, self.start_token),
                       self.h0, gen_o, gen_x))

        # discriminator on generated and real data
        d_gen_predictions = tensor_array_ops.TensorArray(
            dtype=tf.float32, size=self.sequence_length,
            dynamic_size=False, infer_shape=True)
        d_real_predictions = tensor_array_ops.TensorArray(
            dtype=tf.float32, size=self.sequence_length,
            dynamic_size=False, infer_shape=True)

        self.gen_x = self.gen_x.pack()
        emb_gen_x = tf.gather(self.d_embeddings, self.gen_x)
        ta_emb_gen_x = tensor_array_ops.TensorArray(
            dtype=tf.float32, size=self.sequence_length)
        ta_emb_gen_x = ta_emb_gen_x.unpack(emb_gen_x)

        emb_real_x = tf.gather(self.d_embeddings, self.x)
        ta_emb_real_x = tensor_array_ops.TensorArray(
            dtype=tf.float32, size=self.sequence_length)
        ta_emb_real_x = ta_emb_real_x.unpack(emb_real_x)

        def _d_recurrence(i, inputs, h_tm1, pred):
            x_t = inputs.read(i)
            h_t = self.d_recurrent_unit(x_t, h_tm1)
            y_t = self.d_classifier_unit(h_t)
            pred = pred.write(i, y_t)
            return i + 1, inputs, h_t, pred

        _, _, _, self.d_gen_predictions = control_flow_ops.while_loop(
            cond=lambda i, _1, _2, _3: i < self.sequence_length,
            body=_d_recurrence,
            loop_vars=(tf.constant(0, dtype=tf.int32),
                       ta_emb_gen_x,
                       self.d_h0,
                       d_gen_predictions))
        self.d_gen_predictions = tf.reshape(
                self.d_gen_predictions.pack(),
                [self.sequence_length])

        _, _, _, self.d_real_predictions = control_flow_ops.while_loop(
            cond=lambda i, _1, _2, _3: i < self.sequence_length,
            body=_d_recurrence,
            loop_vars=(tf.constant(0, dtype=tf.int32),
                       ta_emb_real_x,
                       self.d_h0,
                       d_real_predictions))
        self.d_real_predictions = tf.reshape(
                self.d_real_predictions.pack(),
                [self.sequence_length])

        # supervised pretraining for generator
        g_predictions = tensor_array_ops.TensorArray(
            dtype=tf.float32, size=self.sequence_length,
            dynamic_size=False, infer_shape=True)

        emb_x = tf.gather(self.g_embeddings, self.x)
        ta_emb_x = tensor_array_ops.TensorArray(
            dtype=tf.float32, size=self.sequence_length)
        ta_emb_x = ta_emb_x.unpack(emb_x)

        def _pretrain_recurrence(i, x_t, h_tm1, g_predictions):
            h_t = self.g_recurrent_unit(x_t, h_tm1)
            o_t = self.g_output_unit(h_t)
            g_predictions = g_predictions.write(i, o_t)
            x_tp1 = ta_emb_x.read(i)
            return i + 1, x_tp1, h_t, g_predictions

        _, _, _, self.g_predictions = control_flow_ops.while_loop(
            cond=lambda i, _1, _2, _3: i < self.sequence_length,
            body=_pretrain_recurrence,
            loop_vars=(tf.constant(0, dtype=tf.int32),
                       tf.gather(self.g_embeddings, self.start_token),
                       self.h0, g_predictions))

        self.g_predictions = tf.reshape(
                self.g_predictions.pack(),
                [self.sequence_length, self.num_emb])

        # calculate discriminator loss
        self.d_gen_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                self.d_gen_predictions, tf.zeros([self.sequence_length])))
        self.d_real_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                self.d_real_predictions, tf.ones([self.sequence_length])))

        # calculate generator rewards and loss
        decays = tf.exp(tf.log(self.reward_gamma) * tf.to_float(tf.range(self.sequence_length)))
        rewards = _backwards_cumsum(decays * tf.sigmoid(self.d_gen_predictions),
                                    self.sequence_length)
        normalized_rewards = \
            rewards / _backwards_cumsum(decays, self.sequence_length) - self.expected_reward

        self.reward_loss = tf.reduce_mean(normalized_rewards ** 2)
        self.g_loss = \
            -tf.reduce_mean(tf.log(self.gen_o.pack()) * normalized_rewards)

        # pretraining loss
        self.pretrain_loss = \
            (-tf.reduce_sum(
                tf.one_hot(tf.to_int64(self.x),
                           self.num_emb, 1.0, 0.0) * tf.log(self.g_predictions))
             / self.sequence_length)

        # training updates
        d_opt = self.d_optimizer(self.learning_rate)
        g_opt = self.g_optimizer(self.learning_rate)
        pretrain_opt = self.g_optimizer(self.learning_rate)
        reward_opt = tf.train.GradientDescentOptimizer(self.learning_rate)

        self.d_gen_grad = tf.gradients(self.d_gen_loss, self.d_params)
        self.d_real_grad = tf.gradients(self.d_real_loss, self.d_params)
        self.d_gen_updates = d_opt.apply_gradients(zip(self.d_gen_grad, self.d_params))
        self.d_real_updates = d_opt.apply_gradients(zip(self.d_real_grad, self.d_params))

        self.reward_grad = tf.gradients(self.reward_loss, [self.expected_reward])
        self.reward_updates = reward_opt.apply_gradients(zip(self.reward_grad, [self.expected_reward]))

        self.g_grad = tf.gradients(self.g_loss, self.g_params)
        self.g_updates = g_opt.apply_gradients(zip(self.g_grad, self.g_params))

        self.pretrain_grad = tf.gradients(self.pretrain_loss, self.g_params)
        self.pretrain_updates = pretrain_opt.apply_gradients(zip(self.pretrain_grad, self.g_params))
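
The constructor above calls `_cumsum` and `_backwards_cumsum` helpers that are not included in the snippet. As an assumption about their behaviour (not the original helpers), a backwards cumulative sum over a reward vector might look like this in NumPy:

```python
import numpy as np

def backwards_cumsum(x):
    # backwards_cumsum([r0, r1, r2]) -> [r0+r1+r2, r1+r2, r2]
    return np.cumsum(x[::-1])[::-1]

print(backwards_cumsum(np.array([1.0, 2.0, 3.0])))   # [6. 5. 3.]
```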
Example #22
0
 def fn():
     ta = tensor_array_ops.TensorArray(dtype=dtypes.float32,
                                       tensor_array_name="foo",
                                       size=3,
                                       infer_shape=False)
     return ta.split(1.0, [1]).flow
Example #23
0
    def __init__(self, lstm, update_rate, word_embedding_matrix):
        self.lstm = lstm
        self.update_rate = update_rate

        self.num_emb = self.lstm.num_emb
        self.batch_size = self.lstm.batch_size
        self.emb_dim = self.lstm.emb_dim
        self.hidden_dim = self.lstm.hidden_dim
        self.sequence_length = self.lstm.sequence_length
        self.start_token = tf.identity(self.lstm.start_token)
        self.learning_rate = self.lstm.learning_rate
        self.type_size = self.lstm.type_size

        self.g_embeddings = word_embedding_matrix
        self.g_recurrent_unit = self.create_recurrent_unit(
        )  # maps h_tm1 to h_t for generator
        self.g_output_unit = self.create_output_unit(
        )  # maps h_t to o_t (output token logits)

        #####################################################################################################
        # placeholder definition
        self.x = tf.placeholder(tf.int32,
                                shape=[
                                    self.batch_size, self.sequence_length
                                ])  # sequence of tokens generated by generator
        self.given_num = tf.placeholder(tf.int32)
        self.type_index = tf.placeholder(dtype=tf.int32,
                                         shape=[self.batch_size])

        # append the per-example type vector to x
        # (tile each example's type id across the time dimension)
        x_type_index = tf.tile(tf.expand_dims(self.type_index, 1),
                               [1, self.sequence_length])
        self.x_type_onehot = tf.one_hot(x_type_index, self.type_size)

        # append the type vector to the start_token embedding
        self.type_onehot = tf.one_hot(self.type_index, self.type_size)

        # processed for batch
        with tf.device("/cpu:0"):
            embedding_input = tf.nn.embedding_lookup(self.g_embeddings, self.x)
            embedding_input = tf.concat([embedding_input, self.x_type_onehot],
                                        axis=2)
            self.processed_x = tf.transpose(
                embedding_input, perm=[1, 0,
                                       2])  # seq_length x batch_size x emb_dim

        ta_emb_x = tensor_array_ops.TensorArray(dtype=tf.float32,
                                                size=self.sequence_length)
        ta_emb_x = ta_emb_x.unstack(self.processed_x)

        ta_x = tensor_array_ops.TensorArray(dtype=tf.int32,
                                            size=self.sequence_length)
        ta_x = ta_x.unstack(tf.transpose(self.x, perm=[1, 0]))
        #####################################################################################################

        self.h0 = tf.zeros([self.batch_size, self.hidden_dim])
        self.h0 = tf.stack([self.h0, self.h0])

        gen_x = tensor_array_ops.TensorArray(dtype=tf.int32,
                                             size=self.sequence_length,
                                             dynamic_size=False,
                                             infer_shape=True)

        # When current index i < given_num, use the provided tokens as the input at each time step
        def _g_recurrence_1(i, x_t, h_tm1, given_num, gen_x):
            h_t = self.g_recurrent_unit(x_t, h_tm1)  # hidden_memory_tuple
            x_tp1 = ta_emb_x.read(i)
            gen_x = gen_x.write(i, ta_x.read(i))
            return i + 1, x_tp1, h_t, given_num, gen_x

        # When current index i >= given_num, start roll-out, use the output at time step t as the input at time step t+1
        def _g_recurrence_2(i, x_t, h_tm1, given_num, gen_x):
            h_t = self.g_recurrent_unit(x_t, h_tm1)  # hidden_memory_tuple
            o_t = self.g_output_unit(h_t)  # batch x vocab , logits not prob
            log_prob = tf.log(tf.nn.softmax(o_t))
            next_token = tf.cast(
                tf.reshape(tf.multinomial(log_prob, 1), [self.batch_size]),
                tf.int32)
            x_tp1 = tf.concat([
                tf.nn.embedding_lookup(self.g_embeddings, next_token),
                self.type_onehot
            ],
                              axis=1)  # batch x emb_dim
            gen_x = gen_x.write(i, next_token)  # indices, batch_size
            return i + 1, x_tp1, h_t, given_num, gen_x

        i, x_t, h_tm1, given_num, self.gen_x = control_flow_ops.while_loop(
            cond=lambda i, _1, _2, given_num, _4: i < given_num,
            body=_g_recurrence_1,
            loop_vars=(tf.constant(0, dtype=tf.int32),
                       tf.concat([
                           tf.nn.embedding_lookup(self.g_embeddings,
                                                  self.start_token),
                           self.type_onehot
                       ],
                                 axis=1), self.h0, self.given_num, gen_x))

        _, _, _, _, self.gen_x = control_flow_ops.while_loop(
            cond=lambda i, _1, _2, _3, _4: i < self.sequence_length,
            body=_g_recurrence_2,
            loop_vars=(i, x_t, h_tm1, given_num, self.gen_x))

        self.gen_x = self.gen_x.stack()  # seq_length x batch_size
        self.gen_x = tf.transpose(self.gen_x,
                                  perm=[1, 0])  # batch_size x seq_length
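
The class above conditions the generator by concatenating a one-hot type vector onto every token embedding. A small TF sketch of that conditioning with illustrative sizes; tiling the per-example type across the time dimension is one straightforward way to build the broadcast used here:

```python
import tensorflow as tf

batch, seq, emb_dim, type_size = 2, 4, 3, 2          # illustrative sizes only
embeddings = tf.random.uniform([10, emb_dim])         # vocab of 10 tokens
x = tf.random.uniform([batch, seq], maxval=10, dtype=tf.int32)
type_index = tf.constant([0, 1])                      # one type id per batch row

emb = tf.nn.embedding_lookup(embeddings, x)            # batch x seq x emb_dim
type_onehot = tf.one_hot(type_index, type_size)        # batch x type_size
type_tiled = tf.tile(type_onehot[:, None, :], [1, seq, 1])
conditioned = tf.concat([emb, type_tiled], axis=2)     # batch x seq x (emb_dim+type_size)
print(conditioned.shape)                               # (2, 4, 5)
```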
Example #24
0
 def fn():
     ta = tensor_array_ops.TensorArray(dtype=dtypes.float32,
                                       tensor_array_name="foo",
                                       size=3)
     with ops.control_dependencies([ta.close()]):
         return 1.0
Example #25
0
def dynamic_rnn_decoder(cell,
                        decoder_fn,
                        inputs=None,
                        sequence_length=None,
                        parallel_iterations=None,
                        swap_memory=False,
                        time_major=False,
                        scope=None,
                        name=None):
    """ Dynamic RNN decoder for a sequence-to-sequence model specified by
  RNNCell and decoder function.

  The `dynamic_rnn_decoder` is similar to `tf.python.ops.rnn.dynamic_rnn` in
  that the decoder makes no assumptions about the sequence length or batch
  size of the input.

  The `dynamic_rnn_decoder` has two modes, training and inference, and expects
  the user to create separate functions for each.

  Under both training and inference, both `cell` and `decoder_fn` are expected,
  where `cell` performs computation at every timestep using `raw_rnn`, and
  `decoder_fn` allows modeling of early stopping, output, state, and next
  input and context.

  When training the user is expected to supply `inputs`. At every time step a
  slice of the supplied input is fed to the `decoder_fn`, which modifies and
  returns the input for the next time step.

  `sequence_length` is needed at training time, i.e., when `inputs` is not
  None, for dynamic unrolling. At test time, when `inputs` is None,
  `sequence_length` is not needed.

  Under inference `inputs` is expected to be `None` and the input is inferred
  solely from the `decoder_fn`.

  Args:
    cell: An instance of RNNCell.
    decoder_fn: A function that takes time, cell state, cell input,
      cell output and context state. It returns an early stopping vector,
      cell state, next input, cell output and context state.
      Examples of decoder_fn can be found in the decoder_fn.py file.
    inputs: The inputs for decoding (embedded format).

      If `time_major == False` (default), this must be a `Tensor` of shape:
        `[batch_size, max_time, ...]`.

      If `time_major == True`, this must be a `Tensor` of shape:
        `[max_time, batch_size, ...]`.

      The input to `cell` at each time step will be a `Tensor` with dimensions
        `[batch_size, ...]`.

    sequence_length: (optional) An int32/int64 vector sized `[batch_size]`.
      If `inputs` is not None and `sequence_length` is None, it is inferred
      from `inputs` as the maximal possible sequence length.
    parallel_iterations: (Default: 32).  The number of iterations to run in
      parallel.  Those operations which do not have any temporal dependency
      and can be run in parallel, will be.  This parameter trades off
      time for space.  Values >> 1 use more memory but take less time,
      while smaller values use less memory but computations take longer.
    swap_memory: Transparently swap the tensors produced in forward inference
      but needed for back prop from GPU to CPU.  This allows training RNNs
      which would typically not fit on a single GPU, with very minimal (or no)
      performance penalty.
    time_major: The shape format of the `inputs` and `outputs` Tensors.
      If true, these `Tensors` must be shaped `[max_time, batch_size, depth]`.
      If false, these `Tensors` must be shaped `[batch_size, max_time, depth]`.
      Using `time_major = True` is a bit more efficient because it avoids
      transposes at the beginning and end of the RNN calculation.  However,
      most TensorFlow data is batch-major, so by default this function
      accepts input and emits output in batch-major form.
    scope: VariableScope for the `raw_rnn`;
      defaults to None.
    name: NameScope for the decoder;
      defaults to "dynamic_rnn_decoder"

  Returns:
    A tuple (outputs, final_state, final_context_state) where:

      outputs: the RNN output 'Tensor'.

        If time_major == False (default), this will be a `Tensor` shaped:
          `[batch_size, max_time, cell.output_size]`.

        If time_major == True, this will be a `Tensor` shaped:
          `[max_time, batch_size, cell.output_size]`.

      final_state: The final state, shaped
        `[batch_size, cell.state_size]`.

      final_context_state: The context state returned by the final call
        to decoder_fn. This is useful if the context state maintains internal
        data which is required after the graph is run.
        For example, one way to diversify the inference output is to use
        a stochastic decoder_fn, in which case one would want to store the
        decoded outputs, not just the RNN outputs. This can be done by
        maintaining a TensorArray in context_state and storing the decoded
        output of each iteration therein.

  Raises:
    ValueError: if inputs is not None and has less than three dimensions.
  """
    with ops.name_scope(name, "dynamic_rnn_decoder", [
            cell, decoder_fn, inputs, sequence_length, parallel_iterations,
            swap_memory, time_major, scope
    ]):
        if inputs is not None:
            # Convert to tensor
            inputs = ops.convert_to_tensor(inputs)

            # Test input dimensions
            if inputs.get_shape().ndims is not None and (
                    inputs.get_shape().ndims < 2):
                raise ValueError("Inputs must have at least two dimensions")
            # Setup of RNN (dimensions, sizes, length, initial state, dtype)
            if not time_major:
                # [batch, seq, features] -> [seq, batch, features]
                inputs = array_ops.transpose(inputs, perm=[1, 0, 2])

            dtype = inputs.dtype
            # Get data input information
            input_depth = int(inputs.get_shape()[2])
            batch_depth = inputs.get_shape()[1].value
            max_time = inputs.get_shape()[0].value
            if max_time is None:
                max_time = array_ops.shape(inputs)[0]
            # Setup decoder inputs as TensorArray
            inputs_ta = tensor_array_ops.TensorArray(dtype, size=max_time)
            inputs_ta = inputs_ta.unstack(inputs)

        def loop_fn(time, cell_output, cell_state, loop_state):
            if cell_state is None:  # first call, before while loop (in raw_rnn)
                if cell_output is not None:
                    raise ValueError(
                        "Expected cell_output to be None when cell_state "
                        "is None, but saw: %s" % cell_output)
                if loop_state is not None:
                    raise ValueError(
                        "Expected loop_state to be None when cell_state "
                        "is None, but saw: %s" % loop_state)
                context_state = None
            else:  # subsequent calls, inside while loop, after cell execution
                if isinstance(loop_state, tuple):
                    (done, context_state) = loop_state
                else:
                    done = loop_state
                    context_state = None

            # call decoder function
            if inputs is not None:  # training
                # get next_cell_input
                if cell_state is None:
                    next_cell_input = inputs_ta.read(0)
                else:
                    if batch_depth is not None:
                        batch_size = batch_depth
                    else:
                        batch_size = array_ops.shape(done)[0]
                    next_cell_input = control_flow_ops.cond(
                        math_ops.equal(time, max_time),
                        lambda: array_ops.zeros([batch_size, input_depth],
                                                dtype=dtype),
                        lambda: inputs_ta.read(time))
                (next_done, next_cell_state, next_cell_input, emit_output,
                 next_context_state) = decoder_fn(time, cell_state,
                                                  next_cell_input, cell_output,
                                                  context_state)
            else:  # inference
                # next_cell_input is obtained through decoder_fn
                (next_done, next_cell_state, next_cell_input, emit_output,
                 next_context_state) = decoder_fn(time, cell_state, None,
                                                  cell_output, context_state)

            # check if we are done
            if next_done is None:  # training
                next_done = time >= sequence_length

            # build next_loop_state
            if next_context_state is None:
                next_loop_state = next_done
            else:
                next_loop_state = (next_done, next_context_state)

            return (next_done, next_cell_input, next_cell_state, emit_output,
                    next_loop_state)

        # Run raw_rnn function
        outputs_ta, final_state, final_loop_state = rnn.raw_rnn(
            cell,
            loop_fn,
            parallel_iterations=parallel_iterations,
            swap_memory=swap_memory,
            scope=scope)
        outputs = outputs_ta.stack()

        # Get final context_state, if generated by user
        if isinstance(final_loop_state, tuple):
            final_context_state = final_loop_state[1]
        else:
            final_context_state = None

        if not time_major:
            # [seq, batch, features] -> [batch, seq, features]
            outputs = array_ops.transpose(outputs, perm=[1, 0, 2])
        return outputs, final_state, final_context_state
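
A hypothetical training-mode `decoder_fn` matching the contract described in the docstring, i.e. `(time, cell_state, cell_input, cell_output, context_state) -> (done, next_state, next_input, emit_output, next_context_state)`. `encoder_state`, `cell`, `decoder_inputs`, and `seq_lens` below are assumed stand-ins for the caller's own tensors, not names from the source:

```python
# Hypothetical training-mode decoder_fn; `encoder_state` is an assumed encoder
# final state used to initialise the decoder on the first call.
def simple_decoder_fn_train(encoder_state):
    def decoder_fn(time, cell_state, cell_input, cell_output, context_state):
        if cell_state is None:        # first call, before the while loop
            cell_state = encoder_state
        # Returning done=None lets dynamic_rnn_decoder derive `done` from
        # `sequence_length` in training mode.
        return (None, cell_state, cell_input, cell_output, context_state)
    return decoder_fn

# Assumed call site (cell, decoder_inputs, seq_lens are the caller's tensors):
# outputs, final_state, _ = dynamic_rnn_decoder(
#     cell=cell,
#     decoder_fn=simple_decoder_fn_train(encoder_state),
#     inputs=decoder_inputs,
#     sequence_length=seq_lens)
```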
Example #26
0
 def fn():
     ta = tensor_array_ops.TensorArray(dtype=dtypes.float32,
                                       tensor_array_name="foo",
                                       size=3)
     return ta.size()
Example #27
0
    def rollout(self, input_x, given_num):
        with tf.device("/cpu:0"):
            processed_x = tf.transpose(
                tf.nn.embedding_lookup(self.g_embeddings, input_x),
                perm=[1, 0, 2])  # seq_length x batch_size x emb_dim
        ta_emb_x = tensor_array_ops.TensorArray(dtype=tf.float32,
                                                size=self.sequence_length)
        ta_emb_x = ta_emb_x.unstack(processed_x)

        # Next is rollout
        gen_for_reward = tensor_array_ops.TensorArray(dtype=tf.int32,
                                                      size=1,
                                                      dynamic_size=True,
                                                      infer_shape=True,
                                                      clear_after_read=False)
        ta_x = tensor_array_ops.TensorArray(dtype=tf.int32,
                                            size=self.sequence_length)
        ta_x = ta_x.unstack(tf.transpose(input_x, perm=[1, 0]))

        # When current index i < given_num, use the provided tokens as the input at each time step
        def _g_recurrence_1(i, x_t, input_x, gen_x, h_tm1, h_tm1_manager,
                            last_goal, real_goal, give_num):

            cur_sen = tf.split(
                tf.concat([
                    tf.split(input_x, [i, self.sequence_length - i], 1)[0],
                    self.padding_array
                ], 1), [self.sequence_length, i], 1)[0]
            with tf.variable_scope(self.scope):
                feature = self.FeatureExtractor_unit(cur_sen, self.drop_out)

            h_t_manager = self.g_manager_recurrent_unit(feature, h_tm1_manager)
            sub_goal = self.g_manager_output_unit(h_t_manager)
            sub_goal = tf.nn.l2_normalize(sub_goal, 1)

            h_t_Worker = tf.cond(
                i > 0, lambda: self.g_worker_recurrent_unit(x_t, h_tm1),
                lambda: h_tm1)  # hidden_memory_tuple

            real_sub_goal = tf.cond(i > 0, lambda: tf.add(last_goal, sub_goal),
                                    lambda: real_goal)
            # real_goal_array = real_goal_array.write(i, real_sub_goal)

            x_tp1 = tf.cond(i > 0, lambda: ta_emb_x.read(i - 1), lambda: x_t)

            # hidden_memory_tuple
            with tf.control_dependencies([cur_sen]):
                gen_x = tf.cond(i > 0,
                                lambda: gen_x.write(i - 1, ta_x.read(i - 1)),
                                lambda: gen_x)
            return i + 1, x_tp1, input_x, gen_x, h_t_Worker, h_t_manager, \
                   tf.cond((i % self.step_size) > 0, lambda: real_sub_goal,
                           lambda: tf.constant(0.0, shape=[self.batch_size, self.goal_out_size])), \
                   tf.cond((i % self.step_size) > 0, lambda: real_goal, lambda: real_sub_goal), give_num

        # When current index i >= given_num, start roll-out, use the output at time step t as the input at time step t+1
        def _g_recurrence_2(i, x_t, gen_x, h_tm1, h_tm1_manager, last_goal,
                            real_goal):
            # with tf.device('/cpu:0'):
            cur_sen = tf.cond(
                i > 0, lambda: tf.split(
                    tf.concat([
                        tf.transpose(gen_x.stack(), perm=[1, 0]), self.
                        padding_array
                    ], 1), [self.sequence_length, i - 1], 1)[0],
                lambda: self.padding_array)
            with tf.variable_scope(self.scope):
                feature = self.FeatureExtractor_unit(cur_sen, self.drop_out)
            h_t_Worker = self.g_worker_recurrent_unit(
                x_t, h_tm1)  # hidden_memory_tuple
            o_t_Worker = self.g_worker_output_unit(
                h_t_Worker)  # batch x vocab , logits not prob

            o_t_Worker = tf.reshape(
                o_t_Worker, [self.batch_size, self.vocab_size, self.goal_size])
            o_t_Worker = tf.nn.softmax(o_t_Worker)
            # o_t_Worker = tf.expand_dims(o_t_Worker,2)   # batch x vocab x 1
            # o_t_Worker = tf.multiply(o_t_Worker,tf.nn.softmax(self.W_workerOut_change) ) #batch x vocab x goal_size

            h_t_manager = self.g_manager_recurrent_unit(feature, h_tm1_manager)
            sub_goal = self.g_manager_output_unit(h_t_manager)
            sub_goal = tf.nn.l2_normalize(sub_goal, 1)

            real_sub_goal = tf.add(last_goal, sub_goal)
            w_g = tf.matmul(real_goal, self.g_change)  #batch x goal_size
            w_g = tf.nn.l2_normalize(w_g, 1)
            w_g = tf.expand_dims(w_g, 2)  #batch x goal_size x 1

            x_logits = tf.matmul(o_t_Worker, w_g)
            x_logits = tf.squeeze(x_logits)

            log_prob = tf.log(tf.nn.softmax(x_logits))
            next_token = tf.cast(
                tf.reshape(tf.multinomial(log_prob, 1), [self.batch_size]),
                tf.int32)
            x_tp1 = tf.nn.embedding_lookup(self.g_embeddings,
                                           next_token)  # batch x emb_dim
            with tf.control_dependencies([cur_sen]):
                gen_x = gen_x.write(i - 1, next_token)  # indices, batch_size
            return i + 1, x_tp1, gen_x, h_t_Worker, h_t_manager, \
                   tf.cond((i % self.step_size) > 0, lambda: real_sub_goal,
                           lambda: tf.constant(0.0, shape=[self.batch_size, self.goal_out_size])), \
                   tf.cond((i % self.step_size) > 0, lambda: real_goal, lambda: real_sub_goal)

        i, x_t, _, gen_for_reward, h_worker, h_manager, self.last_goal_for_reward, self.real_goal_for_reward, given_num = control_flow_ops.while_loop(
            cond=lambda i, _1, _2, _3, _4, _5, _6, _7, given_num: i < given_num
            + 1,
            body=_g_recurrence_1,
            loop_vars=(tf.constant(0, dtype=tf.int32),
                       tf.nn.embedding_lookup(self.g_embeddings,
                                              self.start_token), self.x,
                       gen_for_reward, self.h0_worker, self.h0_manager,
                       tf.zeros([self.batch_size, self.goal_out_size
                                 ]), self.goal_init, given_num),
            parallel_iterations=1)  # input ground-truth

        _, _, gen_for_reward, _, _, _, _ = control_flow_ops.while_loop(
            cond=lambda i, _1, _2, _3, _4, _5, _6: i < self.sequence_length +
            1,
            body=_g_recurrence_2,
            loop_vars=(i, x_t, gen_for_reward, h_worker, h_manager,
                       self.last_goal_for_reward, self.real_goal_for_reward),
            parallel_iterations=1)  # rollout by the original policy

        gen_for_reward = gen_for_reward.stack()  # seq_length x batch_size

        gen_for_reward = tf.transpose(gen_for_reward,
                                      perm=[1, 0])  # batch_size x seq_length

        return gen_for_reward
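
The roll-out above implements Monte Carlo reward estimation: keep the first `given_num` tokens, let the policy finish the sequence, and score the completion with a discriminator. A pure-Python sketch of that idea; `policy_sample` and `discriminator_score` are stand-in callables, not methods of this class:

```python
import random

def rollout_reward(prefix, seq_length, policy_sample, discriminator_score,
                   n_rollouts=16):
    """Average discriminator score over Monte Carlo completions of `prefix`."""
    total = 0.0
    for _ in range(n_rollouts):
        seq = list(prefix)
        while len(seq) < seq_length:
            seq.append(policy_sample(seq))    # sample next token given the prefix
        total += discriminator_score(seq)     # reward = score of the full sequence
    return total / n_rollouts

# Toy usage with a random policy and a dummy discriminator.
reward = rollout_reward([1, 2, 3], 8,
                        policy_sample=lambda seq: random.randrange(10),
                        discriminator_score=lambda seq: seq[-1] / 10.0)
print(reward)
```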
Example #28
0
def foldl(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
          swap_memory=False, name=None):
  """foldl on the list of tensors unpacked from `elems` on dimension 0.

  This foldl operator repeatedly applies the callable `fn` to a sequence
  of elements from first to last. The elements are made of the tensors
  unpacked from `elems` on dimension 0. The callable fn takes two tensors as
  arguments. The first argument is the accumulated value computed from the
  preceding invocation of fn. If `initializer` is None, `elems` must contain
  at least one element, and its first element is used as the initializer.

  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
  of the result tensor is `fn(initializer, values[0]).shape`.

  Args:
    fn: The callable to be performed.
    elems: A tensor to be unpacked on dimension 0.
    initializer: (optional) The initial value for the accumulator.
    parallel_iterations: (optional) The number of iterations allowed to run
      in parallel.
    back_prop: (optional) True enables support for back propagation.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    name: (optional) Name prefix for the returned tensors.

  Returns:
    A tensor resulting from applying `fn` consecutively to the list of tensors
    unpacked from `elems`, from first to last.

  Raises:
    TypeError: if `fn` is not callable.

  Example:
    ```python
    elems = [1, 2, 3, 4, 5, 6]
    sum = foldl(lambda a, x: a + x, elems)
    # sum == 21
    ```
  """
  if not callable(fn):
    raise TypeError("fn must be callable.")

  with ops.name_scope(name, "foldl", [elems]):
    # Any get_variable calls in fn will cache the first call locally
    # and not issue repeated network I/O requests for each iteration.
    varscope = vs.get_variable_scope()
    varscope_caching_device_was_none = False
    if varscope.caching_device is None:
      # TODO(ebrevdo): Change to using colocate_with here and in other methods.
      varscope.set_caching_device(lambda op: op.device)
      varscope_caching_device_was_none = True

    # Convert elems to tensor array.
    elems = ops.convert_to_tensor(elems, name="elems")
    n = array_ops.shape(elems)[0]
    elems_ta = tensor_array_ops.TensorArray(dtype=elems.dtype, size=n,
                                            dynamic_size=False,
                                            infer_shape=True)
    elems_ta = elems_ta.unstack(elems)

    if initializer is None:
      a = elems_ta.read(0)
      i = constant_op.constant(1)
    else:
      a = ops.convert_to_tensor(initializer)
      i = constant_op.constant(0)

    def compute(i, a):
      a = fn(a, elems_ta.read(i))
      return [i + 1, a]
    _, r_a = control_flow_ops.while_loop(
        lambda i, a: i < n, compute, [i, a],
        parallel_iterations=parallel_iterations,
        back_prop=back_prop,
        swap_memory=swap_memory)

    if varscope_caching_device_was_none:
      varscope.set_caching_device(None)
    return r_a
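
A usage sketch of the same fold pattern via the public `tf.foldl` API (TF 2), accumulating a running maximum over the unpacked elements:

```python
import tensorflow as tf

elems = tf.constant([3, 1, 4, 1, 5, 9])
# With no initializer the first element seeds the accumulator, as above.
running_max = tf.foldl(lambda acc, x: tf.maximum(acc, x), elems)
print(running_max)   # tf.Tensor(9, ...)
```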
Example #29
0
 def empty():
     return tensor_array_ops.TensorArray(size=0,
                                         element_shape=[],
                                         dtype=dtypes.int64,
                                         dynamic_size=True)
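
A small sketch of appending to a dynamically sized TensorArray via the public `tf.TensorArray` API, matching the configuration returned by `empty()` above (TF 2 eager assumed):

```python
import tensorflow as tf

ta = tf.TensorArray(dtype=tf.int64, size=0, element_shape=[], dynamic_size=True)
for i in range(3):
    ta = ta.write(i, tf.constant(i, dtype=tf.int64))   # grows as needed
print(ta.stack())   # tf.Tensor([0 1 2], ...)
```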
Example #30
0
    def __init__(self,
                 num_vocabulary,
                 batch_size,
                 emb_dim,
                 hidden_dim,
                 sequence_length,
                 start_token,
                 discriminator=None,
                 g_embeddings=None,
                 learning_rate=0.01,
                 reward_gamma=0.95):
        self.num_vocabulary = num_vocabulary
        self.batch_size = batch_size
        self.emb_dim = emb_dim
        self.hidden_dim = hidden_dim
        self.sequence_length = sequence_length
        self.start_token = tf.constant([start_token] * self.batch_size,
                                       dtype=tf.int32)
        self.learning_rate = tf.Variable(float(learning_rate), trainable=False)
        self.reward_gamma = reward_gamma
        self.g_params = []
        self.d_params = []
        self.discriminator = discriminator
        self.temperature = 1.0
        self.grad_clip = 5.0
        self.expected_reward = tf.Variable(tf.zeros([self.sequence_length]))

        with tf.compat.v1.variable_scope('generator'):
            # self.g_embeddings = tf.Variable(self.init_matrix([self.num_vocabulary, self.emb_dim]))
            self.g_embeddings = g_embeddings
            self.g_params.append(self.g_embeddings)
            self.g_recurrent_unit = self.create_recurrent_unit(
                self.g_params)  # maps h_tm1 to h_t for generator
            self.g_output_unit = self.create_output_unit(
                self.g_params)  # maps h_t to o_t (output token logits)

        # placeholder definition

        self.x = tf.compat.v1.placeholder(
            tf.int32, shape=[self.batch_size, self.sequence_length
                             ])  # sequence of tokens generated by generator
        self.y = tf.compat.v1.placeholder(tf.int32,
                                          shape=[
                                              self.batch_size,
                                              self.sequence_length
                                          ])  # sequence of tokens of real data

        # processed for batch
        with tf.device("/cpu:0"):
            self.processed_x = tf.transpose(
                a=tf.nn.embedding_lookup(params=self.g_embeddings, ids=self.x),
                perm=[1, 0, 2])  # seq_length x batch_size x emb_dim

        # Initial states
        self.h_0 = tf.compat.v1.placeholder(tf.float32,
                                            shape=[batch_size, emb_dim])
        self.c_0 = tf.compat.v1.placeholder(tf.float32,
                                            shape=[batch_size, emb_dim])
        self.h0 = tf.stack([self.h_0, self.c_0])

        gen_o = tensor_array_ops.TensorArray(dtype=tf.float32,
                                             size=self.sequence_length,
                                             dynamic_size=False,
                                             infer_shape=True)
        gen_x = tensor_array_ops.TensorArray(dtype=tf.int32,
                                             size=self.sequence_length,
                                             dynamic_size=False,
                                             infer_shape=True)
        gen_ot = tensor_array_ops.TensorArray(dtype=tf.float32,
                                              size=self.sequence_length,
                                              dynamic_size=False,
                                              infer_shape=True)

        def _g_recurrence(i, x_t, h_tm1, gen_o, gen_x, gen_ot):
            h_t = self.g_recurrent_unit(x_t, h_tm1)  # hidden_memory_tuple
            o_t = self.g_output_unit(h_t)  # batch x vocab , logits not prob
            next_token = tf.cast(tf.argmax(input=o_t, axis=1), tf.int32)
            x_tp1 = tf.matmul(tf.nn.softmax(tf.multiply(o_t, 1e3)),
                              self.g_embeddings)
            gen_o = gen_o.write(i,
                                tf.reduce_sum(input_tensor=tf.multiply(
                                    tf.one_hot(next_token, self.num_vocabulary,
                                               1.0, 0.0), tf.nn.softmax(o_t)),
                                              axis=1))  # [batch_size] , prob
            gen_x = gen_x.write(i, next_token)  # indices, batch_size
            gen_ot = gen_ot.write(i, x_tp1)
            return i + 1, x_tp1, h_t, gen_o, gen_x, gen_ot

        _, _, _, self.gen_o, self.gen_x, self.gen_ot = control_flow_ops.while_loop(
            cond=lambda i, _1, _2, _3, _4, _5: i < self.sequence_length,
            body=_g_recurrence,
            loop_vars=(tf.constant(0, dtype=tf.int32),
                       tf.nn.embedding_lookup(params=self.g_embeddings,
                                              ids=self.start_token), self.h0,
                       gen_o, gen_x, gen_ot))

        self.gen_x = self.gen_x.stack()  # seq_length x batch_size
        self.gen_x = tf.transpose(a=self.gen_x,
                                  perm=[1, 0])  # batch_size x seq_length

        self.gen_ot = self.gen_ot.stack()
        self.gen_ot = tf.slice(self.gen_ot,
                               begin=[0, 0, 0],
                               size=[sequence_length, batch_size, emb_dim])
        self.gen_ot = tf.transpose(
            a=self.gen_ot, perm=[1, 0,
                                 2])  # batch_size x seq_length x g_emb_dim

        # supervised pretraining for generator
        g_predictions = tensor_array_ops.TensorArray(dtype=tf.float32,
                                                     size=self.sequence_length,
                                                     dynamic_size=False,
                                                     infer_shape=True)

        ta_emb_x = tensor_array_ops.TensorArray(dtype=tf.float32,
                                                size=self.sequence_length)
        ta_emb_x = ta_emb_x.unstack(self.processed_x)

        def _pretrain_recurrence(i, x_t, h_tm1, g_predictions):
            h_t = self.g_recurrent_unit(x_t, h_tm1)
            o_t = self.g_output_unit(h_t)
            g_predictions = g_predictions.write(
                i, tf.nn.softmax(o_t))  # batch x vocab_size
            x_tp1 = ta_emb_x.read(i)
            return i + 1, x_tp1, h_t, g_predictions

        _, _, _, self.g_predictions = control_flow_ops.while_loop(
            cond=lambda i, _1, _2, _3: i < self.sequence_length,
            body=_pretrain_recurrence,
            loop_vars=(tf.constant(0, dtype=tf.int32),
                       tf.nn.embedding_lookup(params=self.g_embeddings,
                                              ids=self.start_token), self.h0,
                       g_predictions))

        self.g_predictions = tf.transpose(
            a=self.g_predictions.stack(),
            perm=[1, 0, 2])  # batch_size x seq_length x vocab_size

        # pretraining loss
        self.pretrain_loss = -tf.reduce_sum(input_tensor=tf.one_hot(
            tf.cast(tf.reshape(self.x, [-1]), dtype=tf.int32),
            self.num_vocabulary, 1.0, 0.0) * tf.math.log(
                tf.clip_by_value(
                    tf.reshape(self.g_predictions, [-1, self.num_vocabulary]),
                    1e-20, 1.0))) / (self.sequence_length * self.batch_size)

        # training updates
        pretrain_opt = self.g_optimizer(self.learning_rate)

        self.pretrain_grad, _ = tf.clip_by_global_norm(
            tf.gradients(ys=self.pretrain_loss, xs=self.g_params),
            self.grad_clip)
        self.pretrain_updates = pretrain_opt.apply_gradients(
            zip(self.pretrain_grad, self.g_params))

        #######################################################################################################
        #  Unsupervised Training
        #######################################################################################################

        def get_feature(input_x, name=''):
            return self.discriminator.feature(input_x=input_x, name=name)

        def compute_pairwise_distances(x, y):
            """Computes the squared pairwise Euclidean distances between x and y.
            Args:
              x: a tensor of shape [num_x_samples, num_features]
              y: a tensor of shape [num_y_samples, num_features]
            Returns:
              a distance matrix of dimensions [num_x_samples, num_y_samples].
            Raises:
              ValueError: if the inputs do not match the specified dimensions.
            """

            if not len(x.get_shape()) == len(y.get_shape()) == 2:
                raise ValueError('Both inputs should be matrices.')

            if x.get_shape().as_list()[1] != y.get_shape().as_list()[1]:
                raise ValueError('The number of features should be the same.')

            norm = lambda x: tf.reduce_sum(input_tensor=tf.square(x), axis=1)

            # By making the `inner' dimensions of the two matrices equal to 1
            # using broadcasting, we are essentially subtracting every pair of
            # rows of x and y.
            # x will be num_samples x num_features x 1,
            # and y will be 1 x num_features x num_samples (after broadcasting).
            # After the subtraction we get a
            # num_x_samples x num_features x num_y_samples matrix.
            # The resulting dist will be of shape num_y_samples x num_x_samples,
            # and thus we need to transpose it again.
            return tf.transpose(
                a=norm(tf.expand_dims(x, 2) - tf.transpose(a=y)))

        def gaussian_kernel_matrix(x, y, sigmas=None):
            r"""Computes a Guassian Radial Basis Kernel between the samples of x and y.
            We create a sum of multiple gaussian kernels each having a width sigma_i.
            Args:
              x: a tensor of shape [num_samples, num_features]
              y: a tensor of shape [num_samples, num_features]
              sigmas: a tensor of floats which denote the widths of each of the
                gaussians in the kernel.
            Returns:
              A tensor of shape [num_samples{x}, num_samples{y}] with the RBF kernel.
            """
            if sigmas is None:
                sigmas = [
                    1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1, 5, 10, 15, 20, 25,
                    30, 35, 100, 1e3, 1e4, 1e5, 1e6
                ]
            beta = 1. / (2. * (tf.expand_dims(sigmas, 1)))

            dist = compute_pairwise_distances(x, y)

            s = tf.matmul(beta, tf.reshape(dist, (1, -1)))

            return tf.reshape(tf.reduce_sum(input_tensor=tf.exp(-s), axis=0),
                              tf.shape(input=dist))

        def calc_mmd(x, y):
            cost = tf.reduce_mean(input_tensor=gaussian_kernel_matrix(x, x))
            cost += tf.reduce_mean(input_tensor=gaussian_kernel_matrix(y, y))
            cost -= 2 * tf.reduce_mean(
                input_tensor=gaussian_kernel_matrix(x, y))

            # We do not allow the loss to become negative.
            cost = tf.compat.v1.where(cost > 0, cost, 0, name='value')

            return cost

        x_feature = get_feature(input_x=self.gen_ot, name='gx')
        y_feature = get_feature(input_x=self.y, name='gy')
        self.mmd = calc_mmd(x_feature, y_feature)
        g_opt = self.g_optimizer(self.learning_rate)

        self.g_grad, _ = tf.clip_by_global_norm(
            tf.gradients(ys=self.mmd, xs=self.g_params), self.grad_clip)
        self.g_updates = g_opt.apply_gradients(zip(self.g_grad, self.g_params))
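
The unsupervised objective above is an MMD between discriminator features of generated and real data, built from a sum of RBF kernels over several bandwidths and clipped at zero. A NumPy sketch of that computation; shapes and bandwidths here are illustrative only:

```python
import numpy as np

def rbf_kernel_matrix(x, y, sigmas=(0.1, 1.0, 10.0)):
    # pairwise squared Euclidean distances, then a sum of RBF kernels
    d2 = np.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    return sum(np.exp(-d2 / (2.0 * s)) for s in sigmas)

def mmd(x, y):
    cost = (rbf_kernel_matrix(x, x).mean() + rbf_kernel_matrix(y, y).mean()
            - 2.0 * rbf_kernel_matrix(x, y).mean())
    return max(cost, 0.0)   # never allow the loss to become negative

x = np.random.randn(5, 4)        # e.g. features of generated samples
y = np.random.randn(7, 4) + 1.0  # e.g. features of real samples
print(mmd(x, y))
```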