def lstm_model(inputs, reuse=False, num_units=(64, 32), scope="l"):
    """Encode every sequence in ``inputs`` with one shared stacked LSTM.

    The first element of ``inputs`` creates the LSTM variables in variable
    scope ``scope`` (unless ``reuse`` is truthy); every later element reuses
    them, so all inputs share a single set of LSTM weights.

    Args:
        inputs: sequence of tensors. Each one is transposed with perm
            (2, 0, 1) into (time_steps, batch_size, state_size); this assumes
            each input is laid out as (batch_size, state_size, time_steps)
            -- TODO confirm with callers.
        reuse: if truthy, reuse existing variables even for the first input.
        num_units: hidden size of each stacked LSTM layer, bottom to top.
        scope: name of the variable scope holding the LSTM weights.

    Returns:
        list with one tensor per input: the top-layer LSTM output at the last
        time step, shaped (batch_size, num_units[-1]).
    """
    observation_n = []
    for i, x in enumerate(inputs):
        # Only the very first input may create variables; the rest share them.
        layer_reuse = bool(reuse) or i > 0
        with tf.variable_scope(scope, reuse=layer_reuse):
            x = tf.transpose(x, (2, 0, 1))  # (time_steps, batch_size, state_size)
            cells = [
                rnn.LSTMCell(lstm_size, forget_bias=1, state_is_tuple=True)
                for lstm_size in num_units
            ]
            cell = rnn.MultiRNNCell(cells, state_is_tuple=True)
            with tf.variable_scope("Multi_Layer_RNN"):
                # x is time-major (see transpose above). Without time_major=True
                # dynamic_rnn would treat axis 0 as the batch and recur over
                # the batch axis instead of over time.
                cell_outputs, _ = tf.nn.dynamic_rnn(
                    cell, x, dtype=tf.float32, time_major=True)
            # Output of the top layer at the last time step.
            observation_n.append(cell_outputs[-1])
    return observation_n
def baseline_encoder(inputs, input_lengths, hidden_cells=128, layers=1, code_dim=128):
    """
    The baseline encoder for our variational model is a unidirectional LSTM
    whose final output parameterises the mean and std of a normal over our
    latent space. Outputs an edward Normal.

    Args:
        inputs: integer tensor of shape [batch, max_len]; each integer is
            expanded to 16 binary features before projection.
        input_lengths: per-example sequence lengths, passed to dynamic_rnn as
            ``sequence_length``.
        hidden_cells: number of units in each LSTM layer.
        layers: number of stacked LSTM layers.
        code_dim: dimensionality of the latent code.

    Returns:
        ``ed.models.Normal`` whose ``loc`` and ``scale`` have shape
        [batch, code_dim]; ``scale`` is kept positive by a softplus.
    """
    with tf.variable_scope('encoder'):
        # turn the inputs from [batch, max_len] into [batch, max_len, 16]
        with tf.variable_scope('inputs'):
            binary_input = _integer_to_binary(inputs, 16)
            # project the inputs
            projected_inputs = project_sequence(binary_input, 128)
        # run an RNN over the lot
        cells = [rnn_cell.LSTMCell(hidden_cells) for _ in range(layers)]
        cell = rnn_cell.MultiRNNCell(cells) if layers > 1 else cells[0]
        # dynamic_rnn needs a dtype when no initial_state is supplied.
        _, final_state = tf.nn.dynamic_rnn(
            cell, projected_inputs, sequence_length=input_lengths,
            dtype=tf.float32)
        # With sequence_length set, outputs past each example's length are
        # zero, so outputs[:, -1, :] would be all zeros for shorter sequences.
        # The final state carries each example's last valid output in ``h``.
        if layers > 1:
            final_state = final_state[-1]  # top layer's LSTMStateTuple
        final_output = final_state.h
        loc = tf.layers.dense(final_output, code_dim, name='code_mean')
        scale = tf.layers.dense(
            final_output, code_dim, activation=tf.nn.softplus, name='code_std')
        return ed.models.Normal(loc=loc, scale=scale)
def single_cell_fn(unit_type, num_units, dropout, mode, forget_bias=1.0):
    """Build one RNN cell, optionally wrapped with input dropout.

    Args:
        unit_type: "lstm" or "gru"; any other value raises ValueError.
        num_units: hidden size of the cell.
        dropout: input-dropout rate; only applied when ``mode is True``.
        mode: training flag, compared by identity against ``True``.
        forget_bias: forget-gate bias of the LSTM cell (unused for GRU).

    Returns:
        The cell, wrapped in ``DropoutWrapper(input_keep_prob=1 - dropout)``
        when the effective dropout rate is positive.

    Raises:
        ValueError: if ``unit_type`` is not recognised.
    """
    if unit_type == "lstm":
        cell = rnn_cell.LSTMCell(num_units, forget_bias=forget_bias,
                                 state_is_tuple=False)
    elif unit_type == "gru":
        cell = rnn_cell.GRUCell(num_units)
    else:
        raise ValueError("Unknown unit type %s!" % unit_type)

    # Dropout is only active in training mode.
    rate = dropout if mode is True else 0.0
    if rate > 0.0:
        cell = rnn_cell.DropoutWrapper(cell=cell, input_keep_prob=(1.0 - rate))
    return cell
def lstm_model(common_obs, inputs, reuse=tf.AUTO_REUSE, num_units=(64, 32), scope="l"):
    """Encode each agent's observation sequence with a shared stacked LSTM.

    Every element of ``inputs`` is made time-major, concatenated with
    ``common_obs`` along the feature axis (axis 2), and run through a
    multi-layer LSTM in variable scope ``scope``. With the default
    ``reuse=tf.AUTO_REUSE`` all agents share the same LSTM weights.

    Args:
        common_obs: tensor concatenated in front of every agent's features;
            must already be time-major (time_steps, batch_size, features)
            -- TODO confirm with callers.
        inputs: sequence of per-agent tensors, each transposed with perm
            (1, 0, 2); presumably (batch_size, time_steps, features) -- confirm.
        reuse: ``reuse`` argument for the variable scope.
        num_units: hidden size of each stacked LSTM layer, bottom to top.
        scope: name of the variable scope holding the LSTM weights.

    Returns:
        list with one tensor per agent: the top-layer LSTM output at the last
        time step, shaped (batch_size, num_units[-1]).
    """
    observation_n = []
    for agent_obs in inputs:
        x = tf.transpose(agent_obs, (1, 0, 2))  # -> (time_steps, batch_size, features)
        x = tf.concat((common_obs, x), 2)
        with tf.variable_scope(scope, reuse=reuse):
            cells = [rnn.LSTMCell(lstm_size, forget_bias=1, state_is_tuple=True)
                     for lstm_size in num_units]
            cell = rnn.MultiRNNCell(cells, state_is_tuple=True)
            with tf.variable_scope("Multi_Layer_RNN"):
                # x is time-major; without time_major=True dynamic_rnn would
                # treat axis 0 as the batch and recur over the batch axis.
                cell_outputs, _ = tf.nn.dynamic_rnn(
                    cell, x, dtype=tf.float32, time_major=True)
            # Output of the top layer at the last time step.
            observation_n.append(cell_outputs[-1])
    return observation_n
def lstm(x, batch_size, output_size=10, lstm_size=12):
    """Run a single-layer LSTM over ``x`` and project its last output.

    Args:
        x: input tensor, transposed with perm (1, 0, 2) to time-major
            (time_steps, batch_size, state_size); assumes ``x`` is
            (batch_size, time_steps, state_size) -- TODO confirm.
        batch_size: unused; kept so existing callers keep working.
        output_size: width of the final ReLU dense layer (previously
            hard-coded to 10).
        lstm_size: hidden state and output size of the LSTM (previously
            hard-coded to 12).

    Returns:
        ``relu(dense(last_output))`` of shape (batch_size, output_size).
    """
    x = tf.transpose(x, (1, 0, 2))  # (time_steps, batch_size, state_size)
    cell = rnn.LSTMCell(lstm_size, forget_bias=1, state_is_tuple=True)
    outputs, _ = tf.nn.dynamic_rnn(cell, x, dtype=tf.float32, time_major=True)
    # Time-major outputs: index 0 is time, so [-1] is the last step.
    last_output = outputs[-1]
    return tf.layers.dense(last_output, output_size,
                           activation=tf.nn.relu, use_bias=True)
def model(input_placeholder, weights, biases):
    """Predict the next symbol from a window of ``n_input`` symbols.

    Relies on module-level ``n_input`` (window length) and ``n_hidden``
    (LSTM units), which must be defined elsewhere in the file.

    Args:
        input_placeholder: tensor reshaped to [-1, n_input], i.e. one row of
            n_input symbol ids per example
            (e.g. [had] [a] [general] -> [20] [6] [33]).
        weights: output-projection weight, either the tensor itself or a
            dict holding it under 'out' (assumed [n_hidden, vocab] -- confirm).
        biases: output-projection bias, tensor or dict with key 'out'.

    Returns:
        Logits obtained by projecting the LSTM output at the last step.
    """
    # Reshape to [batch, n_input].
    input_placeholder = tf.reshape(input_placeholder, [-1, n_input])
    # Generate an n_input-element sequence of [batch, 1] inputs for static_rnn.
    input_placeholder = tf.split(input_placeholder, n_input, 1)

    # 1-layer LSTM with n_hidden units (noted: ~90.60% average accuracy at
    # 50k iter). A 2-layer MultiRNNCell of two LSTMCell(n_hidden) was noted
    # at ~95.20% average accuracy at 50k iter.
    cell = rnn_cell.LSTMCell(n_hidden, reuse=tf.AUTO_REUSE)

    # generate prediction
    outputs, _ = rnn.static_rnn(cell, input_placeholder, dtype=tf.float32)

    # There are n_input outputs but we only want the last one.
    out_w = weights['out'] if isinstance(weights, dict) else weights
    out_b = biases['out'] if isinstance(biases, dict) else biases
    return tf.matmul(outputs[-1], out_w) + out_b
def lstm_cell(self, dropout_keep_prob):
    """Return a fresh LSTM cell of ``self.num_lstm_units`` units with output dropout.

    Args:
        dropout_keep_prob: keep probability applied to the cell's outputs
            (a float or a scalar tensor).

    Returns:
        ``rnn.DropoutWrapper`` around a new ``rnn.LSTMCell``.
    """
    cell = rnn.LSTMCell(self.num_lstm_units)
    # The keyword is ``output_keep_prob``; the former misspelling
    # ``ouput_keep_prob`` made every call raise TypeError.
    return rnn.DropoutWrapper(cell, output_keep_prob=dropout_keep_prob)
def __init__(self, args, infer=False):
    """
    Initialisation function for the class Model.

    Builds the per-pedestrian recurrent trajectory model: placeholders, the
    (optionally dropout-wrapped) multi-layer LSTM/GRU, a per-frame /
    per-pedestrian unrolled graph, the masked mean loss plus L2 penalty, and
    two training ops (one on the computed, clipped gradients and one on
    externally fed gradients).

    Params:
        args: Contains arguments required for the Model creation. Fields
            read here: batch_size, obs_length, pred_length (overwritten when
            ``infer``), maxNumPeds, num_layers, model ("lstm" or "gru"),
            rnn_size, keep_prob, decay_rate, lambda_param, optimizer
            ("RMSprop" or "AdamOpt") and grad_clip, plus whatever
            define_embedding_and_output_layers reads.
        infer: if True, build a one-step graph for sampling (batch_size,
            obs_length and pred_length on ``args`` are set to 1) and skip
            dropout.

    Raises:
        ValueError: if ``args.model`` or ``args.optimizer`` is not recognised.
    """
    # If sampling new trajectories, then infer mode
    if infer:
        # Infer one position at a time
        args.batch_size = 1
        args.obs_length = 1
        args.pred_length = 1

    # Store the arguments
    self.args = args

    # Placeholders for the input data and the target data.
    # A sequence contains an ordered set of consecutive frames; each frame can
    # contain at most args.maxNumPeds pedestrians, each given as (pedID, x, y).
    self.input_data = tf.placeholder(
        tf.float32, [args.obs_length, args.maxNumPeds, 3], name="input_data")
    # Target data has the same format as input_data, one time-step ahead.
    self.target_data = tf.placeholder(
        tf.float32, [args.obs_length, args.maxNumPeds, 3], name="target_data")

    # Learning rate
    self.lr = tf.placeholder(tf.float32, shape=None, name="learning_rate")
    # NOTE: this placeholder is replaced below by the decayed-rate tensor.
    self.final_lr = tf.placeholder(tf.float32, shape=None, name="final_learning_rate")
    self.training_epoch = tf.placeholder(tf.float32, shape=None, name="training_epoch")
    # keep prob
    self.keep_prob = tf.placeholder(tf.float32, name='keep_prob')

    cells = []
    for _ in range(args.num_layers):
        # args.rnn_size is the dimension of the recurrent hidden state.
        if args.model == "lstm":
            with tf.name_scope("LSTM_cell"):
                cell = rnn_cell.LSTMCell(args.rnn_size, state_is_tuple=False)
        elif args.model == "gru":
            with tf.name_scope("GRU_cell"):
                # GRUCell has a single tensor state and takes no
                # state_is_tuple argument (passing one raises TypeError).
                cell = rnn_cell.GRUCell(args.rnn_size)
        else:
            raise ValueError("Unknown model type %s!" % args.model)
        if not infer and args.keep_prob < 1:
            cell = rnn_cell.DropoutWrapper(cell, output_keep_prob=self.keep_prob)
        cells.append(cell)

    # Multi-layer RNN construction, one fresh cell per layer.
    cell = tf.contrib.rnn.MultiRNNCell(cells, state_is_tuple=False)
    # Store the recurrent unit
    self.cell = cell

    # Output size is the set of parameters (mu, sigma, corr)
    self.output_size = 5  # 2 mu, 2 sigma and 1 corr

    with tf.name_scope("learning_rate"):
        self.final_lr = self.lr * (self.args.decay_rate ** self.training_epoch)

    # Presumably defines embedding_w/b and output_w/b used below -- confirm.
    self.define_embedding_and_output_layers(args)

    # Define LSTM states for each pedestrian
    with tf.variable_scope("LSTM_states"):
        self.LSTM_states = tf.zeros(
            [args.maxNumPeds, self.cell.state_size], name="LSTM_states")
        self.initial_states = tf.split(self.LSTM_states, args.maxNumPeds, 0)
        # https://stackoverflow.com/a/41384913/2049763

    # Define hidden output states for each pedestrian
    with tf.variable_scope("Hidden_states"):
        self.output_states = tf.split(
            tf.zeros([args.maxNumPeds, self.cell.output_size]), args.maxNumPeds, 0)

    # List of tensors each of shape args.maxNumPeds x 3 corresponding to each frame in the sequence
    with tf.name_scope("frame_data_tensors"):
        frame_data = [
            tf.squeeze(input_, [0])
            for input_ in tf.split(self.input_data, args.obs_length, 0)
        ]

    with tf.name_scope("frame_target_data_tensors"):
        frame_target_data = [
            tf.squeeze(target_, [0])
            for target_ in tf.split(self.target_data, args.obs_length, 0)
        ]

    # Cost
    with tf.name_scope("Cost_related_stuff"):
        self.cost = tf.constant(0.0, name="cost")
        self.counter = tf.constant(0.0, name="counter")
        self.increment = tf.constant(1.0, name="increment")

    # Containers to store output distribution parameters
    with tf.name_scope("Distribution_parameters_stuff"):
        self.initial_output = tf.split(
            tf.zeros([args.maxNumPeds, self.output_size]), args.maxNumPeds, 0)

    # Tensor to represent non-existent ped
    with tf.name_scope("Non_existent_ped_stuff"):
        nonexistent_ped = tf.constant(0.0, name="zero_ped")

    self.final_result = []
    # Iterate over each frame in the sequence
    for seq, frame in enumerate(frame_data):
        final_result_ped = []
        current_frame_data = frame  # MNP x 3 tensor
        for ped in range(args.maxNumPeds):
            # pedID of the current pedestrian
            pedID = current_frame_data[ped, 0]

            with tf.name_scope("extract_input_ped"):
                # Extract x and y positions of the current ped
                self.spatial_input = tf.slice(
                    current_frame_data, [ped, 1], [1, 2])  # Tensor of shape (1,2)

            with tf.name_scope("embeddings_operations"):
                # Embed the spatial input
                embedded_spatial_input = tf.nn.relu(
                    tf.nn.xw_plus_b(self.spatial_input, self.embedding_w, self.embedding_b))

            # One step of LSTM
            with tf.variable_scope("LSTM") as scope:
                # Only the very first step creates the weights; all others share them.
                if seq > 0 or ped > 0:
                    scope.reuse_variables()
                self.output_states[ped], self.initial_states[ped] = self.cell(
                    embedded_spatial_input, self.initial_states[ped])

            # Apply the linear layer. Output would be a tensor of shape 1 x output_size
            with tf.name_scope("output_linear_layer"):
                self.initial_output[ped] = tf.nn.xw_plus_b(
                    self.output_states[ped], self.output_w, self.output_b)

            with tf.name_scope("extract_target_ped"):
                # Extract x and y coordinates of the target data;
                # x_data and y_data are tensors of shape 1 x 1
                [x_data, y_data] = tf.split(
                    tf.slice(frame_target_data[seq], [ped, 1], [1, 2]), 2, 1)
                target_pedID = frame_target_data[seq][ped, 0]

            with tf.name_scope("get_coef"):
                # Extract coef from output of the linear output layer
                [o_mux, o_muy, o_sx, o_sy, o_corr] = self.get_coef(self.initial_output[ped])
                final_result_ped.append([o_mux, o_muy, o_sx, o_sy, o_corr])

            # Calculate loss for the current ped
            with tf.name_scope("calculate_loss"):
                lossfunc = self.get_lossfunc(o_mux, o_muy, o_sx, o_sy, o_corr, x_data, y_data)

            # A ped that does not exist in this frame, or in the next one,
            # contributes neither to the cost nor to the counter.
            with tf.name_scope("increment_cost"):
                ped_missing = tf.logical_or(
                    tf.equal(pedID, nonexistent_ped),
                    tf.equal(target_pedID, nonexistent_ped))
                self.cost = tf.where(ped_missing, self.cost, tf.add(self.cost, lossfunc))
                self.counter = tf.where(
                    ped_missing, self.counter, tf.add(self.counter, self.increment))

        self.final_result.append(tf.stack(final_result_ped))

    # Compute the cost
    with tf.name_scope("mean_cost"):
        # Mean over contributing peds. The divisor is clamped to 1 so a batch
        # with no valid peds yields 0 (cost is then 0) instead of 0/0 = NaN,
        # which would otherwise propagate into the gradients and weights.
        self.cost = tf.div(self.cost, tf.maximum(self.counter, 1.0))

        # Get trainable_variables
        tvars = tf.trainable_variables()
        # L2 loss
        l2 = args.lambda_param * sum(tf.nn.l2_loss(tvar) for tvar in tvars)
        self.cost = self.cost + l2

    # Get the final LSTM states
    self.final_states = tf.concat(self.initial_states, 0)
    # Get the final distribution parameters
    self.final_output = self.initial_output

    # Initialize the optimizer with the given learning rate.
    # NOTE: Social LSTM suggests RMSprop, whereas Graves (2013) uses Adam.
    if args.optimizer == "RMSprop":
        optimizer = tf.train.RMSPropOptimizer(learning_rate=self.final_lr, momentum=0.9)
    elif args.optimizer == "AdamOpt":
        optimizer = tf.train.AdamOptimizer(self.final_lr)
    else:
        raise ValueError("Unknown optimizer %s!" % args.optimizer)

    # Gradients are clipped as is usual in LSTM implementations; the Social
    # LSTM paper does not mention this. See
    # https://stackoverflow.com/a/43486487/2049763 and
    # https://stackoverflow.com/a/40540396/2049763
    # Calculate gradients of the cost w.r.t all the trainable variables
    self.gradients = tf.gradients(self.cost, tvars)
    # Clip the gradients if they are larger than the value given in args
    self.clipped_gradients, _ = tf.clip_by_global_norm(self.gradients, args.grad_clip)

    # Train operator
    self.train_op = optimizer.apply_gradients(zip(self.clipped_gradients, tvars))

    # One placeholder per trainable variable (same dtype and shape);
    # train_op_2 applies whatever gradients are fed into them.
    self.grad_placeholders = [tf.placeholder(var.dtype, var.shape) for var in tvars]
    # Train operator
    self.train_op_2 = optimizer.apply_gradients(zip(self.grad_placeholders, tvars))
def build_model(self, max_length, vocabulary_size, embedding_size, num_hidden,
                num_layers, num_classes, learn_rate):
    """Build the LSTM sequence-classification graph in a new ``tf.Graph``.

    The graph embeds integer token ids, runs a stacked dropout LSTM with
    ``static_rnn``, averages its outputs over each example's true length,
    applies a softmax classifier, and attaches an Adam training op, accuracy,
    summaries and train/test FileWriters (written under ``os.getcwd()``).

    Args:
        max_length: number of time steps per example (width of the data placeholder).
        vocabulary_size: number of rows of the embedding matrix.
        embedding_size: embedding dimension.
        num_hidden: units per LSTM layer.
        num_layers: number of stacked LSTM layers.
        num_classes: number of output classes.
        learn_rate: Adam learning rate.
    """
    self.__graph = tf.Graph()
    with self.__graph.as_default():
        # to track progress
        self.__global_step = tf.Variable(0, trainable=False)

        # input parameters
        self.__dropout = tf.placeholder(tf.float32)
        self.__data = tf.placeholder(
            tf.int32, shape=[self.__batch_size, max_length], name="data")
        self.__labels = tf.placeholder(
            tf.float32, shape=[self.__batch_size, num_classes], name="labels")
        self.__seq_length = tf.placeholder(
            tf.int32, shape=[self.__batch_size], name="seqlength")

        # LSTM definition: one fresh cell per layer. Repeating a single cell
        # object ([cell] * num_layers) makes every layer share one kernel even
        # though layer 1 sees embedding_size inputs and later layers see
        # num_hidden inputs, which fails (or silently ties weights).
        # NOTE(review): the old call LSTMCell(num_hidden, embedding_size) hit
        # the second positional parameter, use_peepholes, so peepholes were
        # on; that is kept explicit here -- confirm it is wanted.
        cells = []
        for _ in range(num_layers):
            layer = rnn_cell.LSTMCell(num_hidden, use_peepholes=True)
            layer = rnn_cell.DropoutWrapper(layer, output_keep_prob=self.__dropout)
            cells.append(layer)
        network = rnn_cell.MultiRNNCell(cells)

        # loaded value from word2vec
        self.__embeddings_lstm = tf.Variable(
            tf.random_uniform([vocabulary_size, embedding_size], -1.0, 1.0),
            name="embeddings_lstm")
        # get the word vectors learned previously
        embed = tf.nn.embedding_lookup(self.__embeddings_lstm, self.__data)
        outputs, states = tf.contrib.rnn.static_rnn(
            network, self.__unpack_sequence(embed), dtype=tf.float32,
            sequence_length=self.__seq_length)

        # Average all outputs over each example's true length. static_rnn
        # zeroes outputs past sequence_length, so the sum only covers valid
        # steps (presumably __unpack_sequence splits along time -- confirm).
        packed_op = tf.stack(outputs)
        summed_op = tf.reduce_sum(packed_op, axis=0)  # [batch, num_hidden]
        # seq_length is [batch]; expand it to [batch, 1] so it broadcasts over
        # the hidden dimension instead of being aligned against it.
        lengths = tf.expand_dims(tf.cast(self.__seq_length, tf.float32), 1)
        averaged_op = tf.div(summed_op, lengths)

        # output classifier
        softmax_weight = tf.Variable(
            tf.truncated_normal([num_hidden, num_classes], stddev=0.1))
        softmax_bias = tf.Variable(tf.constant(0.1, shape=[num_classes]))
        temp = tf.matmul(averaged_op, softmax_weight) + softmax_bias
        prediction = tf.nn.softmax(temp)
        tf.summary.histogram("prediction", prediction)
        self.__prin = tf.Print(prediction, [prediction], message="pred is ")

        # standard cross entropy loss (computed from the logits)
        self.__loss = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(logits=temp, labels=self.__labels))
        tf.summary.scalar("loss-xentropy", self.__loss)
        self.__optimizer = tf.train.AdamOptimizer(learn_rate).minimize(
            self.__loss, global_step=self.__global_step)

        # examine performance
        with tf.variable_scope("accuracy"):
            correct_predictions = tf.equal(
                tf.argmax(prediction, 1), tf.argmax(self.__labels, 1))
            self.__accuracy = tf.reduce_mean(
                tf.cast(correct_predictions, tf.float32), name="accuracy")
        tf.summary.scalar("accuracy", self.__accuracy)

        self.__merged = tf.summary.merge_all()
        self.__train_writer = tf.summary.FileWriter(
            os.path.join(os.getcwd(), 'train'), self.__graph)
        self.__test_writer = tf.summary.FileWriter(os.path.join(os.getcwd(), 'test'))
def __init__(self, rnn_size, rnn_layer, batch_size, input_embedding_size, dim_image,
             dim_hidden, max_words_q, vocabulary_size, max_words_d, vocabulary_size_d,
             drop_out_rate, num_answers, variation, offline_text):
    """Store hyper-parameters and create the model's variables and RNN cells.

    Every argument is stored on ``self`` unchanged. The method creates
    ``tf.Variable`` embedding/projection matrices (initialised uniformly) and
    ``rnn_cell`` LSTM cells. Which of them exist depends on ``variation`` and
    ``offline_text``.

    Args:
        rnn_size: units per LSTM layer.
        rnn_layer: used only in the row count 2 * rnn_size * rnn_layer of the
            state-embedding matrices. NOTE(review): the stacked LSTMs below
            always have exactly 2 layers, so this looks consistent only when
            rnn_layer == 2 -- confirm.
        batch_size, max_words_q, max_words_d: stored only; not used here.
        input_embedding_size: word-embedding width.
        dim_image: image-feature width (rows of the image-embedding matrices).
        dim_hidden: common hidden width of the state/image embeddings.
        vocabulary_size: question vocabulary size.
        vocabulary_size_d: description vocabulary size.
        drop_out_rate: LSTM output dropout rate (keep prob = 1 - rate).
        num_answers: number of rows of the options-embedding matrix.
        variation: 'isq' or 'sq' builds the description branch; 'isq' or
            'iq' builds the image embeddings.
        offline_text: compared with the *strings* "False" / "True" (not
            booleans). "False" builds a description LSTM encoder; "True" only
            a dim_hidden x dim_hidden description projection.

    NOTE(review): ``options_embedding_size`` and ``num_output`` are not defined
    in this method and must exist at module scope -- confirm.
    NOTE(review): in TF 1.x the second positional parameter of
    ``rnn_cell.LSTMCell`` is ``use_peepholes``, so passing
    input_embedding_size / rnn_size there enables peepholes; this looks like a
    legacy input_size-style call -- confirm intent.
    NOTE(review): the nesting of embed_state_desc_b and of the options/score
    embeddings was reconstructed from whitespace-mangled code -- confirm.
    """
    self.rnn_size = rnn_size
    self.rnn_layer = rnn_layer
    self.batch_size = batch_size
    self.input_embedding_size = input_embedding_size
    self.dim_image = dim_image
    self.dim_hidden = dim_hidden
    self.max_words_q = max_words_q
    self.vocabulary_size = vocabulary_size
    self.max_words_d = max_words_d
    self.vocabulary_size_d = vocabulary_size_d
    self.drop_out_rate = drop_out_rate
    self.num_answers = num_answers
    self.variation = variation
    self.offline_text = offline_text

    # question-embedding: [vocabulary_size, input_embedding_size]
    self.embed_ques_W = tf.Variable(tf.random.uniform(
        [self.vocabulary_size, self.input_embedding_size], -0.08, 0.08), name='embed_ques_W')

    # encoder: RNN body (two dropout-wrapped, non-tuple-state LSTM layers)
    self.lstm_1 = rnn_cell.LSTMCell(rnn_size, input_embedding_size, state_is_tuple=False)
    self.lstm_dropout_1 = rnn_cell.DropoutWrapper(self.lstm_1, output_keep_prob=1 - self.drop_out_rate)
    self.lstm_2 = rnn_cell.LSTMCell(rnn_size, rnn_size, state_is_tuple=False)
    self.lstm_dropout_2 = rnn_cell.DropoutWrapper(self.lstm_2, output_keep_prob=1 - self.drop_out_rate)
    self.stacked_lstm = rnn_cell.MultiRNNCell(
        [self.lstm_dropout_1, self.lstm_dropout_2])

    # state-embedding: projects a flattened LSTM state to dim_hidden
    self.embed_state_W = tf.Variable(tf.random.uniform(
        [2 * rnn_size * rnn_layer, self.dim_hidden], -0.08, 0.08), name='embed_state_W')
    self.embed_state_b = tf.Variable(tf.random.uniform([self.dim_hidden], -0.08, 0.08), name='embed_state_b')

    if self.variation in ['isq', 'sq']:
        if self.offline_text == "False":
            # print("\n\n\noffline_text false\n\n\n")
            # description-embedding
            self.embed_desc_W = tf.Variable(tf.random.uniform(
                [self.vocabulary_size_d, self.input_embedding_size], -0.08, 0.08), name='embed_desc_W')
            # encoder: RNN body for the description (same layout as the question encoder)
            self.lstm_1_d = rnn_cell.LSTMCell(rnn_size, input_embedding_size, state_is_tuple=False)
            self.lstm_dropout_1_d = rnn_cell.DropoutWrapper(
                self.lstm_1_d, output_keep_prob=1 - self.drop_out_rate)
            self.lstm_2_d = rnn_cell.LSTMCell(rnn_size, rnn_size, state_is_tuple=False)
            self.lstm_dropout_2_d = rnn_cell.DropoutWrapper(
                self.lstm_2_d, output_keep_prob=1 - self.drop_out_rate)
            self.stacked_lstm_d = rnn_cell.MultiRNNCell(
                [self.lstm_dropout_1_d, self.lstm_dropout_2_d])
            # description state-embedding
            self.embed_state_desc_W = tf.Variable(
                tf.random.uniform(
                    [2 * rnn_size * rnn_layer, self.dim_hidden], -0.08, 0.08),
                name='embed_state_desc_W')
        elif self.offline_text == "True":
            # print("\n\n\noffline_text true\n\n\n")
            # Offline description features: only a dim_hidden -> dim_hidden projection.
            self.embed_state_desc_W = tf.Variable(
                tf.random.uniform([self.dim_hidden, self.dim_hidden], -0.08, 0.08),
                name='embed_state_desc_W')
        self.embed_state_desc_b = tf.Variable(tf.random.uniform(
            [self.dim_hidden], -0.08, 0.08), name='embed_state_desc_b')

    if self.variation in ['isq', 'iq']:
        # image-embedding 1
        self.embed_image_W = tf.Variable(tf.random.uniform(
            [dim_image, self.dim_hidden], -0.08, 0.08), name='embed_image_W')
        self.embed_image_b = tf.Variable(tf.random.uniform([dim_hidden], -0.08, 0.08), name='embed_image_b')
        # my code
        # image-embedding 2
        self.embed_image2_W = tf.Variable(tf.random.uniform(
            [dim_image, self.dim_hidden], -0.08, 0.08), name='embed_image2_W')
        self.embed_image2_b = tf.Variable(tf.random.uniform([dim_hidden], -0.08, 0.08), name='embed_image2_b')
        # image-embedding 3
        self.embed_image3_W = tf.Variable(tf.random.uniform(
            [dim_image, self.dim_hidden], -0.08, 0.08), name='embed_image3_W')
        self.embed_image3_b = tf.Variable(tf.random.uniform([dim_hidden], -0.08, 0.08), name='embed_image3_b')
        # image-embedding 4
        self.embed_image4_W = tf.Variable(tf.random.uniform(
            [dim_image, self.dim_hidden], -0.08, 0.08), name='embed_image4_W')
        self.embed_image4_b = tf.Variable(tf.random.uniform([dim_hidden], -0.08, 0.08), name='embed_image4_b')
        # image-embedding 5
        self.embed_image5_W = tf.Variable(tf.random.uniform(
            [dim_image, self.dim_hidden], -0.08, 0.08), name='embed_image5_W')
        self.embed_image5_b = tf.Variable(tf.random.uniform([dim_hidden], -0.08, 0.08), name='embed_image5_b')

    # options-embedding: [num_answers, options_embedding_size]
    self.embed_options_W = tf.Variable(tf.random.uniform(
        [self.num_answers, options_embedding_size], -0.1, 0.1), name='embed_options_W')
    # print("\n\nself.embed_options_W: {}\n\n".format(self.embed_options_W))
    # Disabled alternative: an LSTM encoder for the options.
    # self.lstm_o = rnn_cell.LSTMCell(64, state_is_tuple=False, reuse=tf.AUTO_REUSE)
    # self.lstm_1_o = rnn_cell.LSTMCell(rnn_size, input_embedding_size, state_is_tuple=False)
    # self.lstm_dropout_1_o = rnn_cell.DropoutWrapper(self.lstm_1_o, output_keep_prob = 1 - self.drop_out_rate)
    # self.lstm_2_o = rnn_cell.LSTMCell(rnn_size, rnn_size, state_is_tuple=False)
    # self.lstm_dropout_2_o = rnn_cell.DropoutWrapper(self.lstm_2_o, output_keep_prob = 1 - self.drop_out_rate)
    # self.stacked_lstm_o = rnn_cell.MultiRNNCell([self.lstm_dropout_1_o, self.lstm_dropout_2_o])
    # options state-embedding
    # self.embed_options_state_W = tf.Variable(tf.random.uniform([2*rnn_size*rnn_layer, self.dim_hidden], -0.08,0.08),name='embed_options_state_W')
    # self.embed_options_b = tf.Variable(tf.random.uniform([self.dim_hidden], -0.08, 0.08), name='embed_options_b')
    # end my code

    # score-embedding for 5-way
    self.embed_score_W = tf.Variable(tf.random.uniform(
        [dim_hidden, options_embedding_size], -0.08, 0.08), name='embed_score_W')
    self.embed_h_b = tf.Variable(tf.random.uniform(
        [options_embedding_size], -0.08, 0.08), name='embed_h_b')
    self.embed_score_b = tf.Variable(tf.random.uniform([num_output], -0.08, 0.08), name='embed_score_b')
def get_policy(self, role_id, state, last_cards, lstm_state):
    """Build one policy network per role and scatter their outputs back.

    For each role idx in 1..3 a separate network (variable scope
    ``policy_network_<idx>``) is built. It only sees the batch rows whose
    ``role_id`` equals idx, selected with ``tf.where`` + ``gather_nd``.
    Outputs of roles not in ``ROLE_IDS_TO_TRAIN`` are wrapped in
    ``tf.stop_gradient``. Each role's rows are then scattered back to
    full-batch tensors and summed over the three roles.

    Args:
        role_id: per-example role ids; ``tf.shape(role_id)[0]`` is taken as
            the batch size.
        state: per-example state features; columns [:60], [60:120] and [120:]
            feed three separate conv branches.
        last_cards: per-example features of the last played cards.
        lstm_state: per-example non-tuple LSTM state (1024-unit cell).

    Returns:
        ``[active_logits, passive_logits, new_lstm_state]``. Each has
        ``batch_size`` rows; rows whose role_id is not 1..3 stay zero.
    """
    batch_size = tf.shape(role_id)[0]
    gathered_outputs = []
    indices = []
    # Build a network for every role; which ones receive gradients is decided
    # by ROLE_IDS_TO_TRAIN below.
    for idx in range(1, 4):
        with tf.variable_scope('policy_network_%d' % idx):
            lstm = rnn.LSTMCell(1024, state_is_tuple=False)
            # Batch rows belonging to this role.
            id_idx = tf.where(tf.equal(role_id, idx))
            indices.append(id_idx)
            state_id = tf.gather_nd(state, id_idx)
            last_cards_id = tf.gather_nd(last_cards, id_idx)
            lstm_state_id = tf.gather_nd(lstm_state, id_idx)
            with slim.arg_scope([slim.fully_connected, slim.conv2d],
                                weights_regularizer=slim.l2_regularizer(POLICY_WEIGHT_DECAY)):
                with tf.variable_scope('branch_main'):
                    # Residual-block layout shared by every conv branch.
                    policy_blocks = [[128, 3, 'identity'],
                                     [128, 3, 'identity'],
                                     [128, 3, 'downsampling'],
                                     [128, 3, 'identity'],
                                     [128, 3, 'identity'],
                                     [256, 3, 'downsampling'],
                                     [256, 3, 'identity'],
                                     [256, 3, 'identity']]
                    flattened_1 = policy_conv_block(
                        state_id[:, :60], 32, POLICY_INPUT_DIM // 3, policy_blocks, 'branch_main1')
                    flattened_2 = policy_conv_block(
                        state_id[:, 60:120], 32, POLICY_INPUT_DIM // 3, policy_blocks, 'branch_main2')
                    flattened_3 = policy_conv_block(
                        state_id[:, 120:], 32, POLICY_INPUT_DIM // 3, policy_blocks, 'branch_main3')
                    flattened = tf.concat([flattened_1, flattened_2, flattened_3], axis=1)
                    fc, new_lstm_state = lstm(flattened, lstm_state_id)

                    active_fc = slim.fully_connected(fc, 1024)
                    active_logits = slim.fully_connected(
                        active_fc, len(action_space), activation_fn=None, scope='final_fc')

                    with tf.variable_scope('branch_passive'):
                        flattened_last = policy_conv_block(
                            last_cards_id, 32, POLICY_LAST_INPUT_DIM, policy_blocks, 'last_cards')
                        passive_attention = slim.fully_connected(
                            inputs=flattened_last, num_outputs=1024, activation_fn=tf.nn.sigmoid)
                    passive_fc = passive_attention * active_fc
                    # reuse=True + scope='final_fc' resolves relative to the
                    # current scope; this call sits in 'branch_main', where
                    # 'final_fc' was created above, so the head is shared.
                    passive_logits = slim.fully_connected(
                        passive_fc, len(action_space), activation_fn=None,
                        reuse=True, scope='final_fc')

            gathered_output = [active_logits, passive_logits, new_lstm_state]
            if idx not in ROLE_IDS_TO_TRAIN:
                gathered_output = [tf.stop_gradient(t) for t in gathered_output]
            gathered_outputs.append(gathered_output)

    # 3 outputs, each scattered back to B * ? and summed over the 3 roles.
    outputs = []
    for i in range(3):
        scatter_shape = tf.cast(
            tf.stack([batch_size, gathered_outputs[0][i].shape[1]]), dtype=tf.int64)
        outputs.append(
            tf.add_n([
                tf.scatter_nd(indices[k], gathered_outputs[k][i], scatter_shape)
                for k in range(3)
            ]))
    return outputs