def _infer_latents(self, inputs, observed):
    """Sample latents from the approximate posterior and decode their states.

    The observations are encoded, concatenated with `inputs` and run through
    `self._e_core`; those outputs are passed to `self._f_core` through
    `util.reverse_dynamic_rnn`. The result parameterises the posterior
    `q_zs`, from which the latents are sampled. The prior `p_zs` is a
    diagonal Gaussian with zero mean and unit scale.

    Args:
      inputs: Input features; concatenated first with the encoded
        observations and later with the sampled latents.
      observed: Observations, possibly nested (the batch size is read with
        `util.batch_size_from_nested_tensors`).

    Returns:
      A pair `((d_states, latents), divs)`: the states recorded while running
      `self._d_core`, the sampled latents, and whatever `util.calc_kl`
      returns for the posterior/prior pair.
    """
    n_batch = util.batch_size_from_nested_tensors(observed)

    # Encode the observations.
    encode_observed = snt.BatchApply(self._obs_encoder, n_dims=2)
    encoded = encode_observed(observed)

    # First pass over inputs + encoded observations.
    e_inputs = util.concat_features((inputs, encoded))
    e_initial = self._e_core.initial_state(n_batch)
    e_outputs, _ = tf.nn.dynamic_rnn(
        self._e_core, e_inputs, initial_state=e_initial)

    # Second pass, via the project's reverse RNN helper.
    f_initial = self._f_core.initial_state(n_batch)
    f_outputs, _ = util.reverse_dynamic_rnn(
        self._f_core, e_outputs, initial_state=f_initial)

    # Posterior, sample and unit-Gaussian prior.
    posterior_params = snt.BatchApply(self._q_z, n_dims=2)(f_outputs)
    q_zs = self._q_z.dist(posterior_params, name="q_zs")
    latents = q_zs.sample()
    p_zs = tf.contrib.distributions.MultivariateNormalDiag(
        loc=tf.zeros_like(latents),
        scale_diag=tf.ones_like(latents),
        name="p_zs")
    divs = util.calc_kl(self._hparams, latents, q_zs, p_zs)

    # Decode while recording the core states at every step.
    recording_core = util.state_recording_rnn(self._d_core)
    d_inputs = util.concat_features((inputs, latents))
    d_initial = self._d_core.initial_state(n_batch)
    recorded, _ = tf.nn.dynamic_rnn(
        recording_core, d_inputs, initial_state=d_initial)
    _unused_d_outs, d_states = recorded

    return (d_states, latents), divs
def _build(self, x1, x2, r):
    """Attend over `r` using keys derived from `x1` and queries from `x2`.

    Args:
      x1: Tensor the keys are computed from.
      x2: Tensor the queries are computed from.
      r: Passed as the third argument of `dot_product_attention`
        (presumably the values being attended over -- confirm against that
        helper).

    Returns:
      The output of `dot_product_attention`.

    Raises:
      NameError: If `self._rep` is not 'identity' or 'mlp', or if
        `self._att_type` is not 'dot_product'.
    """
    if self._rep == 'identity':
        k, q = x1, x2
    elif self._rep == 'mlp':
        # Pass keys and queries through the same MLP (shared weights).
        # NOTE: an L2 regularizer used to be constructed here but was never
        # handed to the MLP, so it had no effect; it was removed as dead code.
        initializer = tf.initializers.glorot_uniform(dtype=self._float_dtype)
        module = snt.nets.MLP(
            self._output_sizes,
            activation=self._nonlinearity,
            use_bias=True,
            initializers={"w": initializer},
        )
        k = snt.BatchApply(module, n_dims=1)(x1)
        q = snt.BatchApply(module, n_dims=1)(x2)
    else:
        raise NameError("'rep' not among ['identity','mlp']")

    if self._att_type == 'dot_product':
        rep = dot_product_attention(q, k, r, self._normalise)
    else:
        raise NameError(("'att_type' not among ['dot_product']"))

    return rep
def _build(self, z_tm1, prior_rnn_hidden_state):
    """Compute prior statistics from the previous step's latents.

    Args:
      z_tm1: Sequence whose first three entries are the previous `what`,
        `where` and `presence` values.
      prior_rnn_hidden_state: Hidden state fed to `self._cell`.

    Returns:
      A pair `(prior_stats, prior_rnn_hidden_state)` where `prior_stats` is
      `(where_loc, where_scale, what_loc, what_scale, prop_prob_logit)` and
      the second entry is the updated hidden state.
    """
    prev_what, prev_where, prev_presence = z_tm1[:3]

    # Advance the recurrent cell on the concatenated what/where latents.
    cell_input = tf.concat((prev_what, prev_where), -1)
    batched_cell = snt.BatchApply(self._cell)
    cell_output, next_hidden_state = batched_cell(
        cell_input, prior_rnn_hidden_state)

    # One logit plus a loc and a scale for each of the 4 `where` and
    # `n_what` `what` dimensions.
    n_what = self._n_what
    n_stats = 2 * (4 + n_what) + 1
    projected = snt.BatchApply(snt.Linear(n_stats))(cell_output)
    logit, gaussian_stats = tf.split(projected, [1, n_stats - 1], -1)

    # For presence values of 1 the biased logit is kept; for 0 it becomes
    # -88, i.e. a probability of effectively zero.
    logit = logit + self._prop_logit_bias
    logit = prev_presence * logit + (prev_presence - 1.) * 88.

    all_locs, all_scales = tf.split(gaussian_stats, 2, -1)
    where_loc, what_loc = tf.split(all_locs, [4, n_what], -1)
    raw_where_scale, raw_what_scale = tf.split(all_scales, [4, n_what], -1)

    # Softplus plus a small floor keeps the scales strictly positive.
    where_scale = tf.nn.softplus(raw_where_scale) + 1e-2
    what_scale = tf.nn.softplus(raw_what_scale) + 1e-2

    if self._where_loc_bias is not None:
        where_bias = np.asarray(self._where_loc_bias).reshape((1, 4))
        where_loc = where_loc + where_bias

    prior_stats = (where_loc, where_scale, what_loc, what_scale, logit)
    return prior_stats, next_hidden_state
def __call__(self, observations):
    """Evaluate the network at randomly sampled quantile thresholds.

    Args:
      observations: Batch of observations; the leading axis is used as the
        batch size.

    Returns:
      A triple `(q_values, q_dist, tau)`: the mean of the head outputs over
      the sampled thresholds (with gradients stopped), the head outputs with
      their last two axes swapped, and the sampled thresholds.
    """
    # Intermediate representation of the observations.
    features = self._torso(observations)

    # The embedding needs the feature width, which is only known now.
    self._create_embedding(features.shape[-1])

    n_samples = self._num_quantile_samples
    n_latent = self._latent_dim

    # Draw the thresholds uniformly, one row per batch element.
    n_batch = tf.shape(observations)[0]
    sample_shape = tf.stack([n_batch, n_samples])
    tau = tf.random.uniform(sample_shape)
    frequencies = tf.range(1, n_latent + 1, dtype=tf.float32)

    # Cosine features of the thresholds, then the learned embedding.
    tau_grid = tf.tile(tau[:, :, None], (1, 1, n_latent))
    frequency_grid = tf.tile(frequencies[None, None, :],
                             tf.concat([sample_shape, [1]], 0))
    cosine_features = tf.cos(tau_grid * frequency_grid * np.pi)
    embedded_tau = snt.BatchApply(self._embedding)(cosine_features)
    embedded_tau = tf.nn.relu(embedded_tau)

    # Repeat the features for every threshold, gate them with the embedding
    # and apply the head.
    repeated_features = tf.tile(features[:, None, :], (1, n_samples, 1))
    quantile_values = snt.BatchApply(self._head)(
        embedded_tau * repeated_features)

    q_dist = tf.transpose(quantile_values, (0, 2, 1))
    q_values = tf.stop_gradient(tf.reduce_mean(quantile_values, axis=1))
    return q_values, q_dist, tau
def unroll(self, actions, env_outputs, core_state):
    """Run torso, recurrent core and head over a batch of sequences.

    Args:
      actions: Actions, forwarded to the torso and to `self._recurrent`.
      env_outputs: Four-element structure; its third entry (`done`) is
        forwarded to `self._recurrent`.
      core_state: Recurrent state to start from.

    Returns:
      A pair of the head output (applied to the stacked core outputs) and
      the recurrent state returned by `self._recurrent`.
    """
    _, _, episode_done, _ = env_outputs

    batched_torso = snt.BatchApply(self._torso)
    features = batched_torso((actions, env_outputs))

    step_outputs, core_state = self._recurrent(
        features, actions, episode_done, core_state)

    batched_head = snt.BatchApply(self._head)
    head_output = batched_head(tf.stack(step_outputs))
    return head_output, core_state
def unroll(self, actions, env_outputs, core_state):
    """Unroll the torso, a two-layer recurrent core and the head step by step.

    Both layers' recurrent state travels in the single `core_state` tensor:
    index 0 / 1 on axis 1 selects the `c` / `h` part, and the last axis holds
    the first layer in `[:64]` followed by the second layer in `[64:]`.

    Args:
      actions: Actions; unstacked along axis 0 together with the torso
        outputs and `done`, and one-hot encoded per step. `tf.shape(...)[1]`
        is used as the batch size for the zero states.
      env_outputs: Four-element structure whose third entry is `done`.
      core_state: Packed recurrent state, indexed as described above.

    Returns:
      A pair of the head output (applied to the stacked second-layer
      outputs) and the re-packed recurrent state.
    """
    _, _, done, _ = env_outputs
    torso_outputs = snt.BatchApply(self._torso)((actions, env_outputs))

    # Note, in this implementation we can't use CuDNN RNN to speed things up due
    # to the state reset. This can be XLA-compiled (LSTMBlockCell needs to be
    # changed to implement snt.LSTMCell).
    initial_core_state_first = self._core_first.zero_state(tf.shape(actions)[1], tf.float32)
    initial_core_state_second = self._core_second.zero_state(tf.shape(actions)[1], tf.float32)

    # Unpack the combined state into one LSTMStateTuple per layer.
    # NOTE(review): the split point 64 is hard-coded; presumably it is the
    # state width of the first layer -- confirm against the cell definitions.
    core_state_first = tf.contrib.rnn.LSTMStateTuple(c=core_state[:, 0, :64], h=core_state[:, 1, :64])
    core_state_second = tf.contrib.rnn.LSTMStateTuple(c=core_state[:, 0, 64:], h=core_state[:, 1, 64:])

    core_output_list = []
    for input_, d, act in zip(tf.unstack(torso_outputs), tf.unstack(done), tf.unstack(actions)):
        # First layer: where `d` is set, restart from the zero state.
        core_state_first = nest.map_structure(functools.partial(tf.where, d),
                                              initial_core_state_first, core_state_first)
        core_output_first, core_state_first = self._core_first(input_, core_state_first)

        # Second layer: same reset rule. Its input is the torso output without
        # its last column, the first layer's output and the one-hot action.
        core_state_second = nest.map_structure(functools.partial(tf.where, d),
                                               initial_core_state_second, core_state_second)
        one_hot_last_action = tf.one_hot(act, self._num_actions)
        input_second = tf.concat([input_[:,:-1], core_output_first, one_hot_last_action], axis=1)
        core_output, core_state_second = self._core_second(input_second, core_state_second)
        core_output_list.append(core_output)

    # Re-pack both layers' states: concatenate on the last axis, then swap
    # the first two axes.
    core_state = tf.transpose(tf.concat([core_state_first, core_state_second], axis=-1), [1, 0, 2])
    return snt.BatchApply(self._head)(tf.stack(core_output_list)), core_state
def _build(self, z_tm1, prior_rnn_hidden_state): """Applies the op. :param z_tm1: :param prior_rnn_hidden_state: :return: """ #latent variables from the step at time = t - 1 what_tm1, where_tm1, presence_tm1 = z_tm1[:3] #making input for the RNN by concat of latent where and what prior_rnn_inpt = tf.concat((what_tm1, where_tm1), -1) rnn = snt.BatchApply(self._cell) #running RNN and getting the weights and hidden states that we will pass through the # linear NN unit in order to get the values for parameters for propogation prior distribution outputs, prior_rnn_hidden_state = rnn(prior_rnn_inpt, prior_rnn_hidden_state) #specifying the number of output weights for Linear NN Unit n_outputs = 2 * (4 + self._n_what) + 1 #getting the parameters that we will use in order to #specify the parameters of propogation prior distributions for latent variables 'where', 'what' and 'presence' stats = snt.BatchApply(snt.Linear(n_outputs))(outputs) #splitting the outputs from Linear NN Unit into num_images * 1 vector for prop_prob_logit, #which are the parameters for Bernoulli prior distribution for latent 'presence' #and num_images * n_outputs - 1 vector for stats that will be used for # 'what' and 'where' latent variables prop_prob_logit, stats = tf.split(stats, [1, n_outputs - 1], -1) #updating parameters for Bernoulli prior distribution for latent 'presence' #by adding bias(some specidied hyperparameter) prop_prob_logit += self._prop_logit_bias #updating parameters for Bernoulli prior distribution for latent 'presence' #by applying sigma function prop_prob_logit = presence_tm1 * prop_prob_logit + (presence_tm1 - 1.) * 88. 
#splitting stats in order to get parameters (mean or locs, st deviation or scale) #for factorized Gaussian distribution for # latent variables 'where' and 'what' locs, scales = tf.split(stats, 2, -1) #splitting mean or loc parameter into # mean or loc for 'what' and 'where' latent variables separately prior_where_loc, prior_what_loc = tf.split(locs, [4, self._n_what], -1) #splitting scale or standard deviation parameter into #scale or standard deviation for 'what' and 'where' latent variables separately prior_where_scale, prior_what_scale = tf.split(scales, [4, self._n_what], -1) #making sure that standard deviation is positive and not equal to 0 prior_where_scale, prior_what_scale = (tf.nn.softplus(i) + 1e-2 for i in (prior_where_scale, prior_what_scale)) # adding bias for 'where' latent variable mean or loc parameter #for Gaussian distribution if there must exist one if self._where_loc_bias is not None: bias = np.asarray(self._where_loc_bias).reshape((1, 4)) prior_where_loc += bias #putting all parameters for propagation prior distribution together prior_stats = (prior_where_loc, prior_where_scale, prior_what_loc, prior_what_scale, prop_prob_logit) return prior_stats, prior_rnn_hidden_state
def unroll(self, actions, env_outputs, level_name):
    """Apply the torso and then the head to a batch of sequences.

    Args:
      actions: Actions, forwarded to the torso.
      env_outputs: Four-element structure, forwarded to the torso.
      level_name: Forwarded to both the torso and the head.

    Returns:
      The output of the batched head.
    """
    # Unpacking requires `env_outputs` to have exactly four entries; none of
    # them is used on its own here.
    _, _, _unused_done, _ = env_outputs

    # TODO: Cleanup - remove BatchApply since it clutters the code in head.
    batched_torso = snt.BatchApply(self._torso)
    features = batched_torso((actions, env_outputs, level_name))

    batched_head = snt.BatchApply(self._head)
    return batched_head((features, level_name))
def unroll(self, actions, env_outputs, core_state):
    """Unroll torso, core and head manually, one step at a time.

    Args:
      actions: Actions; `tf.shape(actions)[1]` is used as the batch size of
        the initial core state.
      env_outputs: Four-element structure whose third entry is `done`.
      core_state: Recurrent state to start from.

    Returns:
      A pair of the head output and the final recurrent state. The head
      receives four stacked tensors, one for each of the first four entries
      of the per-step core output.
    """
    _, _, done, _ = env_outputs
    torso_outputs = snt.BatchApply(self._torso)((actions, env_outputs))
    tf.logging.info(torso_outputs)
    conv_outputs, actions_and_rewards, goals = torso_outputs

    # The state reset below rules out a fused (CuDNN) RNN; this loop can
    # still be XLA-compiled (LSTMBlockCell would need to be changed to
    # implement snt.LSTMCell).
    fresh_state = self.initial_state(tf.shape(actions)[1])

    # One list per core-output entry that is forwarded to the head.
    per_step = ([], [], [], [])
    steps = zip(tf.unstack(conv_outputs),
                tf.unstack(actions_and_rewards),
                tf.unstack(goals),
                tf.unstack(done))
    for conv_t, action_reward_t, goal_t, done_t in steps:
        # Where the episode ended, restart from the initial state.
        pick_state = functools.partial(tf.where, done_t)
        core_state = nest.map_structure(pick_state, fresh_state, core_state)
        core_output, core_state = self._core(
            (conv_t, action_reward_t, goal_t), core_state)
        for slot, collected in enumerate(per_step):
            collected.append(core_output[slot])

    batched_head = snt.BatchApply(self._head)
    head_output = batched_head(*[tf.stack(collected) for collected in per_step])
    return head_output, core_state
def _build(self, output_size, inputs):
    """Map `inputs` to a softmax distribution over `output_size` entries.

    Args:
      output_size: Width of the final linear layer.
      inputs: Passed to `build_common_network`.

    Returns:
      The softmax of the final linear layer's output. The original
      annotations give the layout as [Time, Batch, output_size].
    """
    # Shared trunk; annotated in the original as [Time, Batch, hidden_size]
    # with outputs in (-1, 1).
    trunk = build_common_network(inputs)

    # Two linear layers with the project's `swich` activation in between.
    hidden = snt.BatchApply(Linear('input_layer', 64))(trunk)
    hidden = swich(hidden)
    logits = snt.BatchApply(Linear('output_layer', output_size))(hidden)

    return tf.nn.softmax(logits)
def _build(self, input_sequence, input_length):
    """Builds the deep LSTM model sub-graph.

    Args:
      input_sequence: A 3D Tensor with padded input sequence data.
      input_length: Actual length of each sequence in the padded input data;
        passed as `sequence_length` to every RNN call.

    Returns:
      Tuple of the Tensor of output logits for the batch and the final state
      of the unrolled core. In the bidirectional case the final state is the
      pair `(final_state_fw, final_state_bw)`. The original docstring gives
      the logits' dimensions as `[truncation_length, batch_size,
      output_size]`.
      NOTE(review): the RNN calls do not set `time_major`, so the layout may
      actually be batch-major -- confirm.
    """
    # Apply the input module to every element of the sequence.
    batch_input_module = snt.BatchApply(self._input_module)
    output_sequence = batch_input_module(input_sequence)

    # Recurrent core: a single RNN, or a stacked bidirectional RNN built
    # from the "fw"/"bw" entries of `self._core`.
    if not self._bidirectional:
        output_sequence, final_state = tf.nn.dynamic_rnn(
            cell=self._core,
            inputs=output_sequence,
            sequence_length=input_length,
            dtype=tf.float32)
    else:
        outputs = tf.contrib.rnn.stack_bidirectional_dynamic_rnn(
            cells_fw=self._unpack_cell(self._core["fw"]),
            cells_bw=self._unpack_cell(self._core["bw"]),
            inputs=output_sequence,
            sequence_length=input_length,
            dtype=tf.float32)
        output_sequence, final_state_fw, final_state_bw = outputs
        final_state = (final_state_fw, final_state_bw)

    # Output stage: recurrent, linear + CNN, or a plain batched module.
    if self._rnn_output:
        output_sequence_logits, _ = tf.nn.dynamic_rnn(
            cell=self._output_module,
            inputs=output_sequence,
            sequence_length=input_length,
            dtype=tf.float32)
    elif self._cnn_output:
        # Symmetrically pad axis 1 one element at a time: only at the end in
        # the unidirectional case, on both sides in the bidirectional case.
        if not self._bidirectional:
            for i in xrange(self._look_ahead - 1):
                output_sequence = tf.pad(output_sequence, [[0, 0], [0, 1], [0, 0]], "SYMMETRIC")
        else:
            for i in xrange(int((self._look_ahead - 1) / 2)):
                output_sequence = tf.pad(output_sequence, [[0, 0], [1, 1], [0, 0]], "SYMMETRIC")
        batch_output_module = snt.BatchApply(self._output_module["linear"])
        output_sequence = batch_output_module(output_sequence)
        # NOTE(review): `tf.squeeze` without `axis` drops every size-1
        # dimension, including a batch dimension of 1 -- confirm this is
        # intended.
        output_sequence_logits = tf.squeeze(self._output_module["cnn"](
            tf.expand_dims(output_sequence, -1)))
    else:
        batch_output_module = snt.BatchApply(self._output_module)
        output_sequence_logits = batch_output_module(output_sequence)

    return output_sequence_logits, final_state
def _build(self, input_dict, num_outputs, options):
    """Build the policy/value network and record its recurrent state tensors.

    Sets `self.state_init`, `self.state_in`, `self.state_out`,
    `self._logit_head` and `self.values` as side effects.

    Args:
      input_dict: Mapping read at the keys "obs", "prev_actions" and
        "prev_rewards".
      num_outputs: Output width of the logit head.
      options: Mapping read at "use_lstm", "lstm_cell_size", "fcnet_hiddens"
        and "fcnet_activation".

    Returns:
      A pair `(logits, core_outputs)`.
    """
    # Recurrent state bookkeeping: one state slot when "use_lstm" is set,
    # none otherwise.
    if options.get("use_lstm"):
        cell_size = options.get("lstm_cell_size")
        self.state_init = (
            np.zeros([cell_size], np.float32),
        )
        self.state_in = (
            tf.placeholder(tf.float32, [None, cell_size], name="state_in"),
        )
    else:
        self.state_init = ()
        self.state_in = ()

    # In imitation mode the observation is embedded here and the previous
    # actions are used as given; otherwise the observation is used as is and
    # the previous actions are unstacked along their last axis.
    if self._imitation:
        obs_embed = ssbm_spaces.slippi_conv_list[0].embed(input_dict["obs"])
        prev_actions = input_dict["prev_actions"]
    else:
        obs_embed = input_dict["obs"]
        prev_actions = tf.unstack(input_dict["prev_actions"], axis=-1)

    # Embed each previous-action component with its matching entry of the
    # flattened action config and concatenate the results.
    action_config = nest.flatten(ssbm_actions.repeated_simple_controller_config)
    prev_actions_embed = tf.concat([
        conv.embed(action) for conv, action
        in zip(action_config, prev_actions)], -1)
    prev_rewards_embed = tf.expand_dims(input_dict["prev_rewards"], -1)

    inputs = tf.concat([obs_embed, prev_actions_embed, prev_rewards_embed], -1)

    # Fully connected trunk; the activation is looked up by name in `tf.nn`.
    trunk = snt.nets.MLP(
        output_sizes=options["fcnet_hiddens"],
        activation=getattr(tf.nn, options["fcnet_activation"]),
        activate_final=True)

    # Outside imitation mode a time dimension is added using `self.seq_lens`.
    if not self._imitation:
        inputs = lstm.add_time_dimension(inputs, self.seq_lens)

    trunk_outputs = snt.BatchApply(trunk)(inputs)

    if options.get("use_lstm"):
        # NOTE: the option is called "use_lstm" but the recurrent core that
        # is built is a GRU.
        gru = snt.GRU(cell_size)
        # Imitation mode unrolls time-major without sequence lengths; the
        # other mode is batch-major and uses `self.seq_lens`.
        core_outputs, state_out = tf.nn.dynamic_rnn(
            gru,
            trunk_outputs,
            initial_state=self.state_in[0],
            sequence_length=None if self._imitation else self.seq_lens,
            time_major=self._imitation)
        self.state_out = [state_out]
    else:
        core_outputs = trunk_outputs
        self.state_out = []

    self._logit_head = snt.Linear(num_outputs, name="logit_head")
    logits = snt.BatchApply(self._logit_head)(core_outputs)
    # `self._value_head` is not created in this method; it must already
    # exist. Its trailing axis is squeezed away.
    self.values = tf.squeeze(snt.BatchApply(self._value_head)(core_outputs), -1)

    return logits, core_outputs
def _build(self, inputs):
    """Map `inputs` to one scalar per leading index.

    Args:
      inputs: Passed to `build_common_network`.

    Returns:
      The single-unit output layer, with axis 2 squeezed away and then
      averaged over axis 1. The original annotation gives the result as
      [Time].
    """
    # Shared trunk; annotated in the original as [Time, Batch, hidden_size]
    # with values in (-1, 1).
    trunk = build_common_network(inputs)

    # Two linear layers with the project's `swich` activation in between.
    hidden = snt.BatchApply(Linear('input_layer', 64))(trunk)
    hidden = swich(hidden)
    scalar = snt.BatchApply(Linear('output_layer', 1))(hidden)

    # Drop the unit-width output axis, then average over axis 1.
    # (A tanh on the squeezed output is disabled in the original.)
    squeezed = tf.squeeze(scalar, axis=2)
    return tf.reduce_mean(squeezed, axis=1)
def unroll(
    self,
    inputs: observation_action_reward.OAR,
    state: snt.LSTMState,
    sequence_length: int,
) -> Tuple[QValues, snt.LSTMState]:
    """Unroll the network over a sequence with batched embedding and head.

    Args:
      inputs: Inputs for the whole sequence, embedded in a single batched
        pass.
      state: Recurrent state to start the core from.
      sequence_length: Forwarded to `snt.static_unroll`.

    Returns:
      The action values for every step and the final recurrent state.
    """
    batched_embed = snt.BatchApply(self._embed)
    embedded = batched_embed(inputs)  # [T, B, D+A+1] per the original note.

    core_outputs, new_state = snt.static_unroll(
        self._core, embedded, state, sequence_length)

    batched_head = snt.BatchApply(self._head)
    return batched_head(core_outputs), new_state
def unroll(self, actions, env_outputs):
    """Apply the torso and the head to a batch of sequences (no recurrence).

    This variant has no recurrent core, so nothing is reset between steps.
    The previous implementation unstacked the torso outputs along the
    leading axis and immediately re-stacked them unchanged; that round trip,
    together with the unused `shape`, `done` and loop variables, has been
    removed. As a side benefit the leading dimension no longer has to be
    statically known for `tf.unstack`.

    Args:
      actions: Actions, forwarded to both the torso and the head.
      env_outputs: Environment outputs, forwarded to the torso.

    Returns:
      The output of the batched head applied to `(torso_outputs, actions)`.
    """
    torso_outputs = snt.BatchApply(self._torso)((actions, env_outputs))
    return snt.BatchApply(self._head)((torso_outputs, actions))
def __init__(self):
    """Build the TF1 graph: placeholders, value/target nets, loss and ops.

    Two `Forward` networks are created under the names "value" and
    "target". The training op only updates variables collected under
    'value'; `self._target_params_swap` copies the "value" variables into
    the "target" ones.
    """
    # Placeholders. Shapes come from the module-level constants.
    self.inputs = tf.placeholder(tf.float32, INPUT_SHAPE, 'input')
    self.actions = tf.placeholder(tf.int32, [None, UNIVERSE_SIZE], 'action')
    self.rewards = tf.placeholder(tf.float32, [None], 'reward')
    self.targets = tf.placeholder(tf.float32, [None], 'targets')

    self.value_net = Forward("value")
    self.target_net = Forward("target")

    # Q value eval: the value net applied over the first two axes of the
    # input.
    self.value_eval = snt.BatchApply(self.value_net, 2)(self.inputs)
    # Squeezing axis 0 requires that axis to have size 1 (presumably a
    # single-step input -- confirm against the caller).
    self.step_value_eval = tf.squeeze(self.value_eval, axis=0)

    # Q_ target eval: the greedy action is chosen by the value net (with
    # gradients stopped) and scored by the target net. Both nets read the
    # same `self.inputs` placeholder.
    next_value = tf.stop_gradient(self.value_eval)
    action_next = tf.one_hot(tf.argmax(next_value, axis=2), ACTION_SIZE, axis=2)
    target_eval = snt.BatchApply(self.target_net, 2)(self.inputs)
    target_eval = tf.reduce_sum(target_eval * action_next, axis=2)
    self.target_eval = tf.reduce_mean(target_eval, axis=1)

    # Loss function: squared difference between the fed targets and the
    # value of the chosen actions, summed over all entries.
    action_choice = tf.one_hot(self.actions, ACTION_SIZE, axis=2)
    action_eval = tf.reduce_sum(self.value_eval * action_choice, axis=2)
    # warning
    # (flagged by the original author: the chosen-action values are averaged
    # over axis 1 before being compared with the targets)
    action_eval = tf.reduce_mean(action_eval, axis=1)
    loss = tf.squared_difference(self.targets, action_eval)
    self._loss = tf.reduce_sum(loss)

    # Train op: clip gradients by global norm and apply them with Nadam.
    trainable_variables = tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES, 'value')
    grads, _ = tf.clip_by_global_norm(
        tf.gradients(self._loss, trainable_variables), MAX_GRAD_NORM)
    optimizer = tf.contrib.opt.NadamOptimizer()
    self._train_op = optimizer.apply_gradients(
        zip(grads, trainable_variables))

    # Update target net params: pairs are formed by position, so both
    # collections are assumed to list their variables in matching order.
    eval_net_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, "value")
    target_net_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, "target")
    self._target_params_swap = \
        [tf.assign(n, q) for n, q in zip(target_net_params, eval_net_params)]

    # Bounded cache for experience replay.
    self.cache = deque(maxlen=MEMORY_SIZE)
def unroll(self,
           actions,
           env_outputs,
           core_state,
           dynamic_unroll=False,
           sequence_length=None):
    """Unroll torso, core and head with either a dynamic or a static loop.

    Args:
      actions: Actions; a flat list is packed with `self._action_specs`.
      env_outputs: Either `(rewards, <ignored>, done, observations)` or
        `(rewards, done, observations)`.
      core_state: Recurrent state to start from. NOTE: it is only used by
        the static loop; the dynamic path starts from a zero state.
      dynamic_unroll: If true, use `tf.nn.dynamic_rnn`; otherwise loop over
        the unstacked steps and reset the state where `done` is set.
      sequence_length: Forwarded to `tf.nn.dynamic_rnn` (dynamic path only).

    Returns:
      A pair of the head output and the final recurrent state.
    """
    if len(env_outputs) != 4:
        rewards, done, observations = env_outputs
    else:
        rewards, _, done, observations = env_outputs

    # Flat lists are packed back into their nested spec structure.
    if isinstance(actions, list):
        actions = tf.nest.pack_sequence_as(self._action_specs, actions)
    if isinstance(observations, list):
        observations = tf.nest.pack_sequence_as(
            self._observation_specs, observations)

    batched_torso = snt.BatchApply(self._torso)
    representation, craft_representation = batched_torso(
        (actions, rewards, done, observations))

    if dynamic_unroll:
        zero_state = self._core.zero_state(tf.shape(done)[1], tf.float32)
        core_output, core_state = tf.nn.dynamic_rnn(
            self._core,
            representation,
            sequence_length=sequence_length,
            initial_state=zero_state,
            time_major=True,
            scope='rnn')
    else:
        # Enter the same variable scope that dynamic_rnn uses so both paths
        # stay compatible.
        with tf.variable_scope('rnn'):
            zero_state = self._core.zero_state(tf.shape(done)[1], tf.float32)
            step_outputs = []
            for step_input, step_done in zip(tf.unstack(representation),
                                             tf.unstack(done)):
                # Where the episode ended, restart from the zero state.
                pick_state = functools.partial(tf.where, step_done)
                core_state = tf.nest.map_structure(
                    pick_state, zero_state, core_state)
                step_output, core_state = self._core(step_input, core_state)
                step_outputs.append(step_output)
            core_output = tf.stack(step_outputs)

    batched_head = snt.BatchApply(self._head)
    return batched_head(core_output, craft_representation), core_state
def infer_z_posterior(self, x, y):
    """Build the posterior over z from `x` and `y`.

    `x` should be rank 4. `y` gains two axes and is tiled with a 5-element
    multiple list, so it has to be rank 3 on this path (the original note
    also allows rank 2 -- TODO confirm).

    Args:
      x: Rank-4 tensor; it is repeated along a new leading axis of size
        `y.shape[0]`.
      y: Tensor that is broadcast over axes 1 and 2 of `x` and appended on
        the last axis.

    Returns:
      `self._z_posterior_fn(loc, scale)` with both arguments computed from
      the flattened encoder output.
    """
    x_dims = util.int_shape(x)
    y_dims = util.int_shape(y)

    # Broadcast `y` over the two middle axes of `x`.
    y_expanded = tf.expand_dims(tf.expand_dims(y, 2), 2)
    y_planes = tf.tile(y_expanded, [1, 1, x_dims[1], x_dims[2], 1])

    # Repeat `x` once per leading entry of `y`, then join on the last axis.
    x_repeated = tf.tile(tf.expand_dims(x, 0), [y_dims[0], 1, 1, 1, 1])
    encoder_input = tf.concat([x_repeated, y_planes], axis=4)

    encoded = snt.BatchApply(self._z_net)(encoder_input)
    flat_encoded = snt.BatchFlatten(preserve_dims=2)(encoded)

    loc_fn = snt.BatchApply(self._loc)
    scale_fn = snt.BatchApply(self._scale)
    return self._z_posterior_fn(loc_fn(flat_encoded), scale_fn(flat_encoded))
def decoder(self, reprs, output_sizes=(40, 40, 2)):
    """Decode representations into the flat parameters of a predictor MLP.

    Args:
      reprs: Representations to decode.
      output_sizes: Layer widths of the predictor whose parameters are
        generated. Any sequence of ints is accepted; the default is now an
        immutable tuple instead of a shared list.

    Returns:
      `reprs` unchanged when `self._repr_as_inputs` is set (the
      representation is then used as the predictor's input). Otherwise the
      output of an MLP whose last layer has `2 * num_params` units, where
      `num_params` counts the weights and biases of a predictor with layer
      widths `[self.embedding_dim] + output_sizes`.
    """
    with tf.compat.v1.variable_scope("decoder"):
        # _repr_as_inputs = True: Representation is used as inputs to the
        #   predictor.
        # _repr_as_inputs = False: Representation is used to generate the
        #   weights of the predictor.
        if self._repr_as_inputs:
            return reprs

        regularizer = tf.contrib.layers.l2_regularizer(
            self._l2_penalty_weight)
        initializer = tf.initializers.glorot_uniform(
            dtype=self._float_dtype)

        # Count the predictor's parameters: every layer contributes
        # (fan_in + 1) * fan_out, the +1 being the bias.
        layer_sizes = [self.embedding_dim] + list(output_sizes)
        num_params = sum(
            (fan_in + 1) * fan_out
            for fan_in, fan_out in zip(layer_sizes[:-1], layer_sizes[1:]))

        # Decode the representation into the weights of the predictor.
        module = snt.nets.MLP(
            [self._nn_size] * self._nn_layers + [2 * num_params],
            activation=self._nonlinearity,
            use_bias=True,
            regularizers={"w": regularizer},
            initializers={"w": initializer},
        )
        outputs = snt.BatchApply(module, n_dims=1)(reprs)
        return outputs
def icm_unroll(self, actions, env_outputs):
    """Build the inverse and forward outputs of the curiosity (ICM) module.

    Args:
      actions: Actions; one-hot encoded with `self._num_actions` classes.
      env_outputs: Four-element structure, forwarded to the encoder.

    Returns:
      `AgentOutputICM(icm_inverse, icm_forward)`: the inverse model's action
      logits and the forward model's scaled squared prediction error, both
      reshaped with `FLAGS.batch_size` as the second axis.
    """
    _, _, _unused_done, _ = env_outputs
    n_actions = self._num_actions
    n_batch = FLAGS.batch_size

    encoded = snt.BatchApply(self._curiosty)((actions, env_outputs))
    enc_dim = encoded.get_shape()[-1]

    one_hot_actions = tf.one_hot(actions, n_actions)
    last_action = tf.reshape(one_hot_actions, [-1, n_batch, n_actions])

    # ICM module inverse model: predict the action from consecutive
    # encodings.
    inverse_pairs = tf.concat([encoded[:-1], encoded[1:]], axis=-1)
    inverse_input = tf.reshape(inverse_pairs, [-1, enc_dim * 2])
    inverse_hidden = snt.Linear(256, name='icm_inverse')(inverse_input)
    inverse_logits = snt.Linear(n_actions, name='icm_logits')(inverse_hidden)
    icm_inverse = tf.reshape(inverse_logits, [-1, n_batch, n_actions])

    # ICM module forward model (eta=0.01): predict the next encoding from
    # the current encoding and the action.
    forward_pairs = tf.concat([encoded[:-1], last_action[1:]], axis=-1)
    forward_input = tf.reshape(forward_pairs, [-1, enc_dim + n_actions])
    forward_hidden = snt.Linear(256, name='icm_forward')(forward_input)
    predicted_next = snt.Linear(enc_dim, name='icm_output')(forward_hidden)
    actual_next = tf.reshape(encoded[1:], [-1, enc_dim])
    scaled_error = tf.reduce_sum(
        (predicted_next - actual_next)**2, axis=-1) * 0.005
    icm_forward = tf.reshape(scaled_error, [-1, n_batch])

    return AgentOutputICM(icm_inverse, icm_forward)
def _read_inputs(self, inputs):
    """Applies transformations to `inputs` to get control for this module.

    Only the read path is active; every write-related control is commented
    out below.

    Args:
      inputs: Tensor that each linear layer below is applied to.

    Returns:
      A dict with the keys 'read_content_keys' (reshaped to
      `[-1, num_reads, word_size]`), 'read_content_strengths' and
      'read_mode' (reshaped to `[-1, num_reads, 1]` before the softmax).
    """

    def _linear(first_dim, second_dim, name, activation=None):
        """Returns a linear transformation of `inputs`, followed by a reshape."""
        linear = snt.Linear(first_dim * second_dim, name=name)(inputs)
        if activation is not None:
            linear = activation(linear, name=name + '_activation')
        return tf.reshape(linear, [-1, first_dim, second_dim])

    # --- Disabled write path (kept for reference) ---

    # v_t^i - The vectors to write to memory, for each write head `i`.
    # write_vectors = _linear(self._num_writes, self._word_size, 'write_vectors')

    # e_t^i - Amount to erase the memory by before writing, for each write head.
    #erase_vectors = _linear(self._num_writes, self._word_size, 'erase_vectors',
    #                        tf.sigmoid)

    # f_t^j - Amount that the memory at the locations read from at the previous
    # time step can be declared unused, for each read head `j`.
    #free_gate = tf.sigmoid(
    #    snt.Linear(self._num_reads, name='free_gate')(inputs))

    # g_t^{a, i} - Interpolation between writing to unallocated memory and
    # content-based lookup, for each write head `i`. Note: `a` is simply used to
    # identify this gate with allocation vs writing (as defined below).
    # allocation_gate = tf.sigmoid(
    #    snt.Linear(self._num_writes, name='allocation_gate')(inputs))

    # g_t^{w, i} - Overall gating of write amount for each write head.
    #write_gate = tf.sigmoid(
    #    snt.Linear(self._num_writes, name='write_gate')(inputs))

    # --- Active read path ---

    # \pi_t^j - Mixing between "backwards" and "forwards" positions (for
    # each write head), and content-based lookup, for each read head.
    # With the write heads disabled only one mode is left, so the softmax
    # below is taken over an axis of size 1.
    num_read_modes = 1 #+ 2 * self._num_writes
    read_mode = snt.BatchApply(tf.nn.softmax)(_linear(self._num_reads, num_read_modes, name='read_mode'))

    # Parameters for the (read / write) "weights by content matching" modules.
    #write_keys = _linear(self._num_writes, self._word_size, 'write_keys')
    #write_strengths = snt.Linear(self._num_writes, name='write_strengths')(
    #    inputs)

    read_keys = _linear(self._num_reads, self._word_size, 'read_keys')
    read_strengths = snt.Linear(self._num_reads, name='read_strengths')(inputs)

    result = {
        'read_content_keys': read_keys,
        'read_content_strengths': read_strengths,
        #'write_content_keys': write_keys,
        #'write_content_strengths': write_strengths,
        #'write_vectors': write_vectors,
        #'erase_vectors': erase_vectors,
        #'free_gate': free_gate,
        #'allocation_gate': allocation_gate,
        #'write_gate': write_gate,
        'read_mode': read_mode,
    }
    return result
def _step(self, trajectory: sequence.Trajectory):
    """Take one gradient step on the combined actor and critic loss.

    Args:
      trajectory: Tuple of `(observations, actions, rewards, discounts)`.
        The observations hold one more entry than the other fields; the
        last one is only used for bootstrapping.
    """
    observations, actions, rewards, discounts = trajectory

    # Insert a dummy batch dimension of size one.
    rewards = tf.expand_dims(rewards, axis=-1)  # [T, 1]
    discounts = tf.expand_dims(discounts, axis=-1)  # [T, 1]
    observations = tf.expand_dims(observations, axis=1)  # [T+1, 1, ...]

    # Hold back the final observation for the bootstrap value.
    final_observation = observations[-1]
    observations = observations[:-1]

    network = self._network
    with tf.GradientTape() as tape:
        # Policies and values for every step but the last.
        batched_network = snt.BatchApply(network)
        policies, values = batched_network(observations)
        _, bootstrap_value = network(final_observation)

        # Critic loss and advantages via TD(lambda).
        critic_loss, (advantages, _) = trfl.td_lambda(
            state_values=values,
            rewards=rewards,
            pcontinues=self._discount * discounts,
            bootstrap_value=bootstrap_value,
            lambda_=self._td_lambda)

        # Policy-gradient loss; the advantages are treated as constants.
        advantages = tf.squeeze(advantages, axis=-1)  # [T]
        log_probs = policies.log_prob(actions)
        actor_loss = -log_probs * tf.stop_gradient(advantages)
        loss = tf.reduce_sum(actor_loss) + critic_loss

    gradients = tape.gradient(loss, network.trainable_variables)
    self._optimizer.apply(gradients, network.trainable_variables)
def _build(self, x, presence=None):
    """Self-attention followed by an MLP, each with a residual connection.

    Args:
      x: Input tensor; the size of its last axis sets the MLP widths.
      presence: Optional mask. When given, it is cast to float, expanded on
        the last axis and multiplied into the attention branch.

    Returns:
      The output of the MLP branch after its residual connection (and layer
      norm, if enabled).
    """
    width = int(x.shape[-1])
    use_dropout = self._dropout_rate > 0.

    # Attention branch. Note that dropout is applied to the residual input
    # `x`, not to the attention output.
    attended = self._self_attention(x, presence)
    residual = x
    if use_dropout:
        residual = tf.nn.dropout(residual, rate=self._dropout_rate)
    attended = attended + residual

    if presence is not None:
        attended = attended * tf.expand_dims(tf.to_float(presence), -1)

    if self._layer_norm:
        attended = snt.LayerNorm(axis=-1)(attended)

    # Feed-forward branch.
    feed_forward = snt.BatchApply(snt.nets.MLP([2*width, width]))
    out = feed_forward(attended)
    if use_dropout:
        out = tf.nn.dropout(out, rate=self._dropout_rate)
    out = out + attended

    if self._layer_norm:
        out = snt.LayerNorm(axis=-1)(out)

    return out
def delta_matmul(w, delta):
    """Multiply the middle axis of `delta` by the transpose of `w`.

    The last two axes of `delta` are swapped, every row is multiplied by
    `w` transposed through a batched matmul, and the axes are swapped back.

    Args:
      w: Matrix; used as `matmul(rows, w, transpose_b=True)`.
      delta: Rank-3 tensor (it is transposed with the permutation
        `[0, 2, 1]`).

    Returns:
      The transformed tensor, again with axes in the order of `delta`.
    """
    def _times_w_transposed(rows):
        return tf.matmul(rows, w, transpose_b=True)

    # Per the original note: [bs x delta_channels x n_units].
    swapped = tf.transpose(delta, [0, 2, 1])
    product = snt.BatchApply(_times_w_transposed)(swapped)
    return tf.transpose(product, [0, 2, 1])
def _fn(batch):
    """Return the masked mean next-token cross-entropy for `batch`.

    Args:
      batch: Mapping with "text" (integer tokens, indexed as
        [batch, sequence] below) and "mask" (per-token loss weights).

    Returns:
      The sum of the masked per-token losses divided by the sum of the mask.
    """
    tokens = batch["text"]
    mask = batch["mask"]

    # Embed the tokens; the original note gives the resulting shape as
    # [bs, seq len, features].
    embedder = snt.Embed(vocab_size=256, embed_dim=embed_dim)
    embedded = embedder(tokens)

    core = core_fn()
    n_batch = tokens.shape.as_list()[0]
    initial_state = core.initial_state(n_batch, trainable=True)
    outputs, _ = tf.nn.dynamic_rnn(core, embedded, initial_state=initial_state)

    # Every position but the last predicts the token that follows it.
    logits = snt.BatchApply(snt.Linear(256))(outputs[:, :-1])
    targets = tokens[:, 1:]

    # Flatten batch and time into a single axis.
    n_rows = logits.shape[0] * logits.shape[1]
    n_classes = logits.shape[2]
    flat_logits = tf.reshape(logits, [n_rows, n_classes])
    flat_targets = tf.reshape(targets, [n_rows])
    flat_mask = tf.reshape(mask[:, 1:], [n_rows])

    per_token_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=flat_targets, logits=flat_logits)
    masked_total = tf.reduce_sum(flat_mask * per_token_loss)
    return masked_total / tf.reduce_sum(flat_mask)
def _build(batch):
    """Return the masked mean cross-entropy of an RNN run over `batch`.

    Args:
      batch: Mapping with "input" (batch-major sequences), "output" (dense
        targets whose last axis sets the number of classes) and "loss_mask"
        (per-step loss weights).

    Returns:
      The sum of the masked per-step losses divided by the sum of the mask.
    """
    core = core_fn()
    start_state = core.initial_state(batch["input"].shape[0])
    outputs, _ = tf.nn.dynamic_rnn(
        core,
        batch["input"],
        initial_state=start_state,
        dtype=tf.float32,
        time_major=False)

    readout = snt.Linear(batch["output"].shape[2])
    logits = snt.BatchApply(readout)(outputs)

    # Flatten batch and time into a single axis.
    n_rows = logits.shape[0] * logits.shape[1]
    n_classes = logits.shape[2]
    rows_by_classes = [n_rows, n_classes]
    flat_logits = tf.reshape(logits, rows_by_classes)
    flat_targets = tf.reshape(batch["output"], rows_by_classes)
    flat_mask = tf.reshape(batch["loss_mask"], [n_rows])

    per_step_loss = tf.nn.softmax_cross_entropy_with_logits_v2(
        labels=flat_targets, logits=flat_logits)
    masked_total = tf.reduce_sum(flat_mask * per_step_loss)
    return masked_total / tf.reduce_sum(flat_mask)
def unroll(self, actions, env_outputs, core_state):
    """Unroll torso, recurrent core and head over the leading axis.

    Args:
      actions: Actions; `tf.shape(actions)[1]` is used as the batch size of
        the zero state.
      env_outputs: Four-element structure whose third entry is `done`.
      core_state: Recurrent state to start from.

    Returns:
      A pair of the head output (applied to the stacked core outputs) and
      the final recurrent state.
    """
    _, _, done, _ = env_outputs
    features = snt.BatchApply(self._torso)((actions, env_outputs))

    # The per-step state reset rules out a fused (CuDNN) RNN. This loop can
    # still be XLA-compiled (LSTMBlockCell would need to be changed to
    # implement snt.LSTMCell).
    zero_state = self._core.zero_state(tf.shape(actions)[1], tf.float32)

    step_outputs = []
    for step_input, step_done in zip(tf.unstack(features), tf.unstack(done)):
        # Where the episode ended, restart from the zero state.
        pick_state = functools.partial(tf.where, step_done)
        core_state = nest.map_structure(pick_state, zero_state, core_state)
        step_output, core_state = self._core(step_input, core_state)
        step_outputs.append(step_output)

    batched_head = snt.BatchApply(self._head)
    return batched_head(tf.stack(step_outputs)), core_state
def _decode_images(self, input_data, n_frames_input, is_training, params,
                   encoded, predicted, skips, input_images, predict_images):
    """Decode the predicted future sequence into frames.

    With image input the frames come from `self._build_image_decoder`.
    Otherwise coordinates are decoded with `self.conv_decoder` and rendered
    through `self._render_fcn`.

    Args:
      input_data: Unused here.
      n_frames_input: Unused here.
      is_training: Forwarded to the decoder that is used.
      params: Unused here.
      encoded: Unused here.
      predicted: Mapping; only "seq_future" is read.
      skips: Mapping; "predict" is read on the image-input path.
      input_images: Sequence; its last entry is passed as the last input
        frame on the image-input path.
      predict_images: Unused here.

    Returns:
      A dict with "pred_frames" and "pred_coords"; the latter is `None` on
      the image-input path.
    """
    # Decoding phase
    with tf.name_scope("image_decoder"):
        if self._has_image_input:
            future_encoding = tf.expand_dims(
                tf.expand_dims(predicted["seq_future"], axis=-1), axis=-1)
            frames = self._build_image_decoder(
                future_encoding,
                skips["predict"],
                is_training,
                decoder_phase="future",
                last_input_frame=input_images[-1],
                use_recursive_image=self._use_recursive_image)
            coords = None
        else:
            batched_decoder = snt.BatchApply(self.conv_decoder)
            coords = batched_decoder(predicted["seq_future"], is_training)
            rendered = tf.py_func(self._render_fcn, [coords], tf.float32)
            # py_func loses the static shape; restore it from the leading
            # two dimensions of the coordinates plus the render shape.
            frames = tf.reshape(
                rendered, shape(coords)[:2] + self._render_shape)

    decoded = dict()
    decoded["pred_frames"] = frames
    decoded["pred_coords"] = coords
    return decoded
def relation_network(self, inputs):
    """Relate every example to every other one and average the results.

    Args:
      inputs: Codes that can be reshaped to
        `[num_examples_per_class * num_classes, num_latents]`.

    Returns:
      A tensor of shape `[num_classes, num_examples_per_class,
      2 * num_latents]`; the last axis is doubled because it carries the
      means and variances of a Gaussian.
    """
    with tf.variable_scope("relation_network"):
        n_latents = self._num_latents
        n_examples = self.num_examples_per_class * self.num_classes

        weight_regularizer = tf.contrib.layers.l2_regularizer(
            self._l2_penalty_weight)
        weight_initializer = tf.initializers.glorot_uniform(
            dtype=self._float_dtype)
        pair_mlp = snt.nets.MLP(
            [2 * n_latents] * 3,
            use_bias=False,
            regularizers={"w": weight_regularizer},
            initializers={"w": weight_initializer},
        )

        # Form every ordered pair of codes along two new axes.
        codes = tf.reshape(inputs, [n_examples, n_latents])
        as_rows = tf.tile(tf.expand_dims(codes, 1), [1, n_examples, 1])
        as_columns = tf.tile(tf.expand_dims(codes, 0), [n_examples, 1, 1])
        pairs = tf.concat([as_rows, as_columns], axis=-1)

        # Score every pair and average over the second example axis.
        pair_outputs = snt.BatchApply(pair_mlp)(pairs)
        averaged = tf.reduce_mean(pair_outputs, axis=1)

        return tf.reshape(
            averaged,
            [self.num_classes, self.num_examples_per_class, 2 * n_latents])
def testValues(self):
    """weighted_softmax with identity on unit weights equals a plain softmax."""
    batch_size, num_heads, memory_size = 5, 3, 7
    activation_shape = [batch_size, num_heads, memory_size]
    weight_shape = [batch_size, num_heads]

    activations_data = np.random.randn(batch_size, num_heads, memory_size)
    weights_data = np.ones((batch_size, num_heads))

    activations = tf.placeholder(tf.float32, activation_shape)
    weights = tf.placeholder(tf.float32, weight_shape)

    # With tf.identity applied to all-ones weights, the weighted softmax
    # should reduce to a standalone softmax.
    observed_op = addressing.weighted_softmax(activations, weights, tf.identity)
    batch_softmax = snt.BatchApply(module_or_op=tf.nn.softmax,
                                   name='BatchSoftmax')
    expected_op = batch_softmax(activations)

    with self.test_session() as sess:
        observed = sess.run(
            observed_op,
            feed_dict={activations: activations_data, weights: weights_data})
        expected = sess.run(
            expected_op, feed_dict={activations: activations_data})
        self.assertAllClose(observed, expected)