def test_take_action_gives_correct_next_states_with_lambda_productions(
        self):
    """Drive a state through a sequence of productions that open and close lambdas.

    Each step applies one action to the current state and compares the
    resulting state's attributes against a hand-built expectation.  That
    expectation then becomes the state the next step starts from.
    """
    # ``take_action`` only forwards these two objects to the new state, so
    # opaque sentinels let the ``__dict__`` comparisons confirm they are
    # passed through untouched.
    valid_actions = object()
    context_actions = object()

    def make_state(nonterminal_stack, lambda_stacks):
        return GrammarStatelet(nonterminal_stack, lambda_stacks, valid_actions,
                               context_actions, is_nonterminal)

    # (action to take, expected nonterminal stack, expected second argument)
    steps = [
        ('<s,d> -> [lambda x, d]', ['t', 'd'], {('s', 'x'): ['d']}),
        ('d -> [<s,r>, d]', ['t', 'd', '<s,r>'],
         {('s', 'x'): ['d', '<s,r>']}),
        ('<s,r> -> [lambda y, r]', ['t', 'd', 'r'],
         {('s', 'x'): ['d', 'r'], ('s', 'y'): ['r']}),
        ('r -> identity', ['t', 'd'], {('s', 'x'): ['d']}),
        ('d -> x', ['t'], {}),
    ]

    current = make_state(['t', '<s,d>'], {})
    for action, expected_stack, expected_lambdas in steps:
        produced = current.take_action(action)
        expected = make_state(expected_stack, expected_lambdas)
        assert produced.__dict__ == expected.__dict__
        current = expected
def test_take_action_crashes_with_mismatched_types(self):
    """``take_action`` is expected to raise ``AssertionError`` when the production's
    left-hand side does not match the nonterminal on top of the stack."""
    # Build the state outside ``pytest.raises`` so that only ``take_action``
    # can satisfy the expected AssertionError; a failure while constructing
    # the state must surface as a test error, not as a spurious pass.
    state = GrammarStatelet(['s'], {}, is_nonterminal)
    with pytest.raises(AssertionError):
        # The stack holds 's', but this production expands 't'.
        state.take_action('t -> identity')
def test_is_finished_just_uses_nonterminal_stack(self):
    """Check that ``is_finished`` tracks whether the nonterminal stack is empty."""
    pending = GrammarStatelet(['s'], {}, is_nonterminal)
    assert not pending.is_finished()

    exhausted = GrammarStatelet([], {}, is_nonterminal)
    assert exhausted.is_finished()
def _create_grammar_state(
        self, possible_actions: List[ProductionRule]) -> GrammarStatelet:
    """
    This method creates the GrammarStatelet object that's used for decoding.  Part of creating
    that is creating the `valid_actions` dictionary, which contains embedded representations of
    all of the valid actions.  So, we create that here as well.

    The inputs to this method are for a `single instance in the batch`; none of the tensors we
    create here are batched.  The input ``ProductionRules`` are grouped by the nonterminal they
    expand.  The global rule ids of each group are embedded with both the input and the output
    action embedders.  Only global actions are supported.

    Parameters
    ----------
    possible_actions : ``List[ProductionRule]``
        From the input to ``forward`` for a single batch instance.  Entries whose ``rule`` is the
        empty string are skipped (presumably padding -- TODO confirm).

    Returns
    -------
    A ``GrammarStatelet`` whose nonterminal stack starts at ``'statement'``.  Its valid actions map
    each nonterminal to ``{'global': (input_embeddings, output_embeddings, action_ids)}``, where
    ``action_ids`` are indices into ``possible_actions``.

    Raises
    ------
    ValueError
        If ``possible_actions`` contains a non-empty rule that is not global.
    """
    device = util.get_device_of(self._action_embedder.weight)

    # Group (rule, index into possible_actions) pairs by the nonterminal each rule expands.
    actions_grouped_by_nonterminal: Dict[str, List[Tuple[ProductionRule, int]]] = defaultdict(list)
    for i, action in enumerate(possible_actions):
        if action.rule == "":
            continue
        if not action.is_global_rule:
            raise ValueError("The sql parser doesn't support non-global actions yet.")
        actions_grouped_by_nonterminal[action.nonterminal].append((action, i))

    # TODO(Mark): This type is pure \(- . ^)/
    translated_valid_actions: Dict[str, Dict[str, Tuple[torch.Tensor, torch.Tensor, List[int]]]] = {}
    for key, production_rules in actions_grouped_by_nonterminal.items():
        # Every group holds at least one rule by construction, so each nonterminal always
        # gets a 'global' entry.
        global_action_tensors = [rule.rule_id for rule, _ in production_rules]
        global_action_ids = [action_index for _, action_index in production_rules]
        global_action_tensor = torch.cat(global_action_tensors, dim=0).long()
        if device >= 0:
            global_action_tensor = global_action_tensor.to(device)
        global_input_embeddings = self._action_embedder(global_action_tensor)
        global_output_embeddings = self._output_action_embedder(global_action_tensor)
        translated_valid_actions[key] = {
            'global': (global_input_embeddings, global_output_embeddings, global_action_ids)
        }

    return GrammarStatelet(['statement'],
                           translated_valid_actions,
                           self.is_nonterminal,
                           reverse_productions=True)
def setUp(self):
    super().setUp()
    self.decoder_step = BasicTransitionFunction(
        encoder_output_dim=2,
        action_embedding_dim=2,
        input_attention=Attention.by_name("dot_product")(),
        add_action_bias=False,
    )

    batch_indices = [0, 1, 0]
    action_history = [[1], [3, 4], []]
    score = [torch.FloatTensor([value]) for value in (0.1, 1.1, 2.2)]

    num_states = len(batch_indices)

    def repeated_index_rows():
        # One [k, k] row per state; every call returns a brand-new tensor.
        return torch.FloatTensor([[k, k] for k in range(num_states)])

    hidden_state = repeated_index_rows()
    memory_cell = repeated_index_rows()
    previous_action_embedding = repeated_index_rows()
    attended_question = repeated_index_rows()

    # Valid actions per nonterminal, grouped by action type.  "global" actions come from the
    # grammar itself; "linked" actions are instance-specific (tied to question attention).
    # Every entry is a tuple of (input representation, output representation, action ids).
    e_actions = {
        "global": (
            torch.FloatTensor([[0, 0], [-1, -1], [-2, -2]]),
            torch.FloatTensor([[-1, -1], [-2, -2], [-3, -3]]),
            [0, 1, 2],
        ),
        "linked": (
            torch.FloatTensor([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]),
            torch.FloatTensor([[3, 3], [4, 4]]),
            [3, 4],
        ),
    }
    d_actions = {
        "global": (torch.FloatTensor([[0, 0]]), torch.FloatTensor([[-1, -1]]), [0]),
        "linked": (
            torch.FloatTensor([[-0.1, -0.2, -0.3], [-0.4, -0.5, -0.6], [-0.7, -0.8, -0.9]]),
            torch.FloatTensor([[5, 5], [6, 6], [7, 7]]),
            [1, 2, 3],
        ),
    }
    valid_actions = {"e": e_actions, "d": d_actions}

    grammar_state = [
        GrammarStatelet([start], valid_actions, is_nonterminal)
        for start in ("e", "d", "e")
    ]

    self.encoder_outputs = torch.FloatTensor([[[1, 2], [3, 4], [5, 6]],
                                              [[10, 11], [12, 13], [14, 15]]])
    self.encoder_output_mask = torch.FloatTensor([[1, 1, 1], [1, 1, 0]])

    e_productions = [
        ("e -> f", False, None),
        ("e -> g", True, None),
        ("e -> h", True, None),
        ("e -> i", True, None),
        ("e -> j", True, None),
    ]
    d_productions = [
        ("d -> q", True, None),
        ("d -> g", True, None),
        ("d -> h", True, None),
        ("d -> i", True, None),
    ]
    self.possible_actions = [e_productions, d_productions]

    rnn_state = [
        RnnStatelet(
            hidden_state[k],
            memory_cell[k],
            previous_action_embedding[k],
            attended_question[k],
            self.encoder_outputs,
            self.encoder_output_mask,
        )
        for k in range(num_states)
    ]

    self.state = GrammarBasedState(
        batch_indices=batch_indices,
        action_history=action_history,
        score=score,
        rnn_state=rnn_state,
        grammar_state=grammar_state,
        possible_actions=self.possible_actions,
    )
def _create_grammar_state(self,
                          world: SpiderWorld,
                          possible_actions: List[ProductionRule],
                          linking_scores: torch.Tensor,
                          linked_actions_linking_scores: torch.Tensor,
                          entity_types: torch.Tensor,
                          entity_graph_encoding: torch.Tensor) -> GrammarStatelet:
    """
    Create the ``GrammarStatelet`` used for decoding a single batch instance.  This also builds
    the embedded representations of every valid action of each nonterminal.

    For each nonterminal in ``world.valid_actions`` the productions are split into two kinds.
    Global actions (``possible_actions[i][1]`` is truthy) are embedded from their rule ids.
    Linked actions are represented by linking scores plus an entity representation.

    Parameters
    ----------
    world : ``SpiderWorld``
        Supplies ``valid_actions`` (nonterminal -> production strings) and ``entities_names``.
    possible_actions : ``List[ProductionRule]``
        Indexed here as ``(rule string, is-global flag, rule id tensor, ...)``.
    linking_scores : ``torch.Tensor``
        Its rows are selected by entity position for linked actions.
    linked_actions_linking_scores : ``torch.Tensor``
        May be ``None``.  Its rows are selected by entity position and used only when
        ``self._self_attend`` is set.
    entity_types : ``torch.Tensor``
        Selected by entity position and embedded with ``self._entity_type_decoder_embedding``
        when ``self._decoder_use_graph_entities`` is false.
    entity_graph_encoding : ``torch.Tensor``
        Its rows are selected by entity position when ``self._decoder_use_graph_entities`` is
        true.

    Returns
    -------
    A ``GrammarStatelet`` whose stack starts at ``'statement'``.  For each nonterminal, the
    ``'global'`` entry is ``(input_embeddings, output_embeddings, action_ids)``.  The
    ``'linked'`` entry is ``(linking_scores, entity_embeddings, action_ids)``; when
    ``self._self_attend`` is set, the selected rows of ``linked_actions_linking_scores`` are
    appended as a fourth element.

    Raises
    ------
    ValueError
        If ``self._self_attend`` is set, some nonterminal has linked actions, and
        ``linked_actions_linking_scores`` is ``None``.
    """
    # Rule string -> index into ``possible_actions`` (the last occurrence wins).
    action_map = {}
    for action_index, action in enumerate(possible_actions):
        action_map[action[0]] = action_index

    valid_actions = world.valid_actions
    # Entity name -> row position in the per-entity tensors.
    entity_map = {}
    for entity_index, entity in enumerate(world.entities_names):
        entity_map[entity] = entity_index

    translated_valid_actions: Dict[str, Dict[str, Tuple[torch.Tensor, torch.Tensor, List[int]]]] = {}
    for key, action_strings in valid_actions.items():
        translated_valid_actions[key] = {}
        # `key` here is a non-terminal from the grammar, and `action_strings` are all the valid
        # productions of that non-terminal.  We'll first split those productions by global vs.
        # linked action.
        action_indices = [action_map[action_string] for action_string in action_strings]
        global_actions = []
        linked_actions = []
        for action_index in action_indices:
            production_rule_array = possible_actions[action_index]
            if production_rule_array[1]:
                global_actions.append((production_rule_array[2], action_index))
            else:
                linked_actions.append((production_rule_array[0], action_index))

        if global_actions:
            global_action_tensors, global_action_ids = zip(*global_actions)
            # torch.cat already places the result on the inputs' device.
            global_action_tensor = torch.cat(global_action_tensors, dim=0).long()
            global_input_embeddings = self._action_embedder(global_action_tensor)
            global_output_embeddings = self._output_action_embedder(global_action_tensor)
            translated_valid_actions[key]['global'] = (global_input_embeddings,
                                                       global_output_embeddings,
                                                       list(global_action_ids))

        if linked_actions:
            linked_rules, linked_action_ids = zip(*linked_actions)
            # Take each rule's right-hand side and strip surrounding brackets / double quotes to
            # recover the entity name used as a key in ``entity_map``.
            entities = [rule.split(' -> ')[1].strip('[]\"') for rule in linked_rules]
            entity_ids = [entity_map[entity] for entity in entities]
            entity_linking_scores = linking_scores[entity_ids]

            if not self._decoder_use_graph_entities:
                entity_type_tensor = entity_types[entity_ids]
                entity_type_embeddings = (self._entity_type_decoder_embedding(entity_type_tensor)
                                          .to(entity_types.device)
                                          .float())
            else:
                entity_type_embeddings = entity_graph_encoding.index_select(
                    dim=0,
                    index=torch.tensor(entity_ids, device=entity_graph_encoding.device))

            if self._self_attend:
                # Self-attention needs per-linked-action scores; fail with a clear message
                # instead of reading a variable that was never bound.
                if linked_actions_linking_scores is None:
                    raise ValueError(
                        "linked_actions_linking_scores must be provided when self-attention "
                        "is enabled and nonterminal {!r} has linked actions.".format(key))
                entity_action_linking_scores = linked_actions_linking_scores[entity_ids]
                translated_valid_actions[key]['linked'] = (entity_linking_scores,
                                                           entity_type_embeddings,
                                                           list(linked_action_ids),
                                                           entity_action_linking_scores)
            else:
                translated_valid_actions[key]['linked'] = (entity_linking_scores,
                                                           entity_type_embeddings,
                                                           list(linked_action_ids))

    return GrammarStatelet(['statement'], translated_valid_actions, self.is_nonterminal)
def _create_grammar_state(self,
                          world: WikiTablesWorld,
                          possible_actions: List[ProductionRuleArray],
                          linking_scores: torch.Tensor,
                          entity_types: torch.Tensor) -> GrammarStatelet:
    """
    This method creates the GrammarStatelet object that's used for decoding.  Part of creating
    that is creating the `valid_actions` dictionary, which contains embedded representations of
    all of the valid actions.  So, we create that here as well.

    The inputs to this method are for a `single instance in the batch`; none of the tensors we
    create here are batched.  We grab the global action ids from the input
    ``ProductionRuleArrays``, and we use those to embed the valid actions for every
    non-terminal type.  We use the input ``linking_scores`` for non-global actions.

    Parameters
    ----------
    world : ``WikiTablesWorld``
        From the input to ``forward`` for a single batch instance.
    possible_actions : ``List[ProductionRuleArray]``
        From the input to ``forward`` for a single batch instance.
    linking_scores : ``torch.Tensor``
        Assumed to have shape ``(num_entities, num_question_tokens)`` (i.e., there is no batch
        dimension).
    entity_types : ``torch.Tensor``
        Assumed to have shape ``(num_entities,)`` (i.e., there is no batch dimension).

    Returns
    -------
    A ``GrammarStatelet`` starting from ``START_SYMBOL``.  For each nonterminal it holds a
    ``'global'`` entry (only when the nonterminal has global productions) and a ``'linked'``
    entry (only when it has linked productions).  It also holds the context-specific variable
    productions (rules ending in ``" -> x"``).
    """
    action_map = {}
    for action_index, action in enumerate(possible_actions):
        action_string = action[0]
        action_map[action_string] = action_index
    entity_map = {}
    for entity_index, entity in enumerate(world.table_graph.entities):
        entity_map[entity] = entity_index

    valid_actions = world.get_valid_actions()
    translated_valid_actions: Dict[str, Dict[str, Tuple[torch.Tensor, torch.Tensor, List[int]]]] = {}
    for key, action_strings in valid_actions.items():
        translated_valid_actions[key] = {}
        # `key` here is a non-terminal from the grammar, and `action_strings` are all the valid
        # productions of that non-terminal.  We'll first split those productions by global vs.
        # linked action.
        action_indices = [action_map[action_string] for action_string in action_strings]
        production_rule_arrays = [(possible_actions[index], index) for index in action_indices]
        global_actions = []
        linked_actions = []
        for production_rule_array, action_index in production_rule_arrays:
            if production_rule_array[1]:
                global_actions.append((production_rule_array[2], action_index))
            else:
                linked_actions.append((production_rule_array[0], action_index))

        # Then we get the embedded representations of the global actions, if any.  A nonterminal
        # whose productions are all linked has none, and ``zip(*[])`` cannot be unpacked into
        # two names.
        if global_actions:
            global_action_tensors, global_action_ids = zip(*global_actions)
            global_action_tensor = torch.cat(global_action_tensors, dim=0)
            global_input_embeddings = self._action_embedder(global_action_tensor)
            if self._add_action_bias:
                global_action_biases = self._action_biases(global_action_tensor)
                global_input_embeddings = torch.cat([global_input_embeddings, global_action_biases],
                                                    dim=-1)
            global_output_embeddings = self._output_action_embedder(global_action_tensor)
            translated_valid_actions[key]['global'] = (global_input_embeddings,
                                                       global_output_embeddings,
                                                       list(global_action_ids))

        # Then the representations of the linked actions.
        if linked_actions:
            linked_rules, linked_action_ids = zip(*linked_actions)
            entities = [rule.split(' -> ')[1] for rule in linked_rules]
            entity_ids = [entity_map[entity] for entity in entities]
            # (num_linked_actions, num_question_tokens)
            entity_linking_scores = linking_scores[entity_ids]
            # (num_linked_actions,)
            entity_type_tensor = entity_types[entity_ids]
            # (num_linked_actions, entity_type_embedding_dim)
            entity_type_embeddings = self._entity_type_decoder_embedding(entity_type_tensor)
            translated_valid_actions[key]['linked'] = (entity_linking_scores,
                                                       entity_type_embeddings,
                                                       list(linked_action_ids))

    # Lastly, we need to also create embedded representations of context-specific actions.  In
    # this case, those are only variable productions, like "r -> x".  Note that our language
    # only permits one lambda at a time, so we don't need to worry about how nested lambdas
    # might impact this.
    context_actions = {}
    for action_id, action in enumerate(possible_actions):
        if action[0].endswith(" -> x"):
            input_embedding = self._action_embedder(action[2])
            if self._add_action_bias:
                input_bias = self._action_biases(action[2])
                input_embedding = torch.cat([input_embedding, input_bias], dim=-1)
            output_embedding = self._output_action_embedder(action[2])
            context_actions[action[0]] = (input_embedding, output_embedding, action_id)

    return GrammarStatelet([START_SYMBOL],
                           {},
                           translated_valid_actions,
                           context_actions,
                           type_declaration.is_nonterminal)
def _create_grammar_state(self,
                          world: AtisWorld,
                          possible_actions: List[ProductionRule],
                          linking_scores: torch.Tensor,
                          entity_types: torch.Tensor) -> GrammarStatelet:
    """
    Build the ``GrammarStatelet`` that drives decoding for one batch instance.

    Alongside the statelet we build its ``valid_actions`` table.  For every nonterminal of
    ``world.valid_actions``, the productions are partitioned into two groups.  Global actions
    are embedded from their rule ids with the input and output action embedders.  Linked
    actions are represented by rows of ``linking_scores`` and embedded entity types.  Nothing
    produced here carries a batch dimension.

    Parameters
    ----------
    world : ``AtisWorld``
        From the input to ``forward`` for a single batch instance.  Supplies ``valid_actions``
        and ``entities``.
    possible_actions : ``List[ProductionRule]``
        From the input to ``forward`` for a single batch instance.
    linking_scores : ``torch.Tensor``
        Assumed to have shape ``(num_entities, num_utterance_tokens)`` (no batch dimension).
    entity_types : ``torch.Tensor``
        Assumed to have shape ``(num_entities,)`` (no batch dimension).

    Returns
    -------
    A ``GrammarStatelet`` starting from ``"statement"``.
    """
    # Rule string -> position in ``possible_actions`` (later duplicates overwrite earlier ones).
    action_map = {action[0]: position for position, action in enumerate(possible_actions)}
    valid_actions = world.valid_actions
    # Entity -> row of ``linking_scores`` / ``entity_types``.
    entity_map = {entity: position for position, entity in enumerate(world.entities)}

    translated_valid_actions: Dict[str, Dict[str, Tuple[torch.Tensor, torch.Tensor, List[int]]]] = {}
    for nonterminal, productions in valid_actions.items():
        action_types = {}
        translated_valid_actions[nonterminal] = action_types

        positions = [action_map[production] for production in productions]
        global_rule_ids = []
        global_positions = []
        linked_rule_strings = []
        linked_positions = []
        for position in positions:
            rule = possible_actions[position]
            if rule[1]:
                global_rule_ids.append(rule[2])
                global_positions.append(position)
            else:
                linked_rule_strings.append(rule[0])
                linked_positions.append(position)

        if global_rule_ids:
            rule_id_tensor = torch.cat(global_rule_ids, dim=0).to(entity_types.device).long()
            action_types["global"] = (
                self._action_embedder(rule_id_tensor),
                self._output_action_embedder(rule_id_tensor),
                global_positions,
            )

        if linked_rule_strings:
            # Linked actions are looked up in ``entity_map`` by their full rule string.
            entity_rows = [entity_map[rule_string] for rule_string in linked_rule_strings]
            selected_scores = linking_scores[entity_rows]
            selected_types = entity_types[entity_rows]
            type_embeddings = (self._entity_type_decoder_embedding(selected_types)
                               .to(entity_types.device)
                               .float())
            action_types["linked"] = (selected_scores, type_embeddings, linked_positions)

    return GrammarStatelet(["statement"], translated_valid_actions, self.is_nonterminal)
def _create_grammar_state(self,
                          world: WikiTablesLanguage,
                          possible_actions: List[ProductionRuleArray],
                          linking_scores: torch.Tensor,
                          entity_types: torch.Tensor) -> GrammarStatelet:
    """
    Build the ``GrammarStatelet`` used for decoding, plus its table of embedded valid actions.

    The valid expansions of each nonterminal are stored as a dictionary keyed by `action type`
    ("global" or "linked" here).  Each value is a tuple (input tensor, output tensor, action
    ids).  The input tensor is the representation used when `selecting` among actions of that
    type.  The output tensor is the representation fed to the next decoder step; it may equal
    the input tensor.  The action ids index into this instance's list of possible actions.

    Everything here concerns a `single instance in the batch`, so no tensor built here is
    batched.  Global actions are embedded from the rule ids carried by the input
    ``ProductionRuleArrays``.  Linked (non-global) actions use rows of ``linking_scores``.

    Parameters
    ----------
    world : ``WikiTablesLanguage``
        From the input to ``forward`` for a single batch instance.
    possible_actions : ``List[ProductionRuleArray]``
        From the input to ``forward`` for a single batch instance.
    linking_scores : ``torch.Tensor``
        Assumed to have shape ``(num_entities, num_question_tokens)`` (no batch dimension).
    entity_types : ``torch.Tensor``
        Assumed to have shape ``(num_entities,)`` (no batch dimension).

    Returns
    -------
    A ``GrammarStatelet`` starting from ``START_SYMBOL``.
    """
    # TODO(mattg): Move the "valid_actions" construction to another method.
    action_map = {rule[0]: idx for idx, rule in enumerate(possible_actions)}
    entity_map = {entity: idx for idx, entity in enumerate(world.table_graph.entities)}
    valid_actions = world.get_nonterminal_productions()

    translated_valid_actions: Dict[str, Dict[str, Tuple[torch.Tensor, torch.Tensor, List[int]]]] = {}
    for nonterminal, rule_strings in valid_actions.items():
        per_type = {}
        translated_valid_actions[nonterminal] = per_type

        indices = [action_map[rule_string] for rule_string in rule_strings]
        rules_with_indices = [(possible_actions[idx], idx) for idx in indices]
        # Partition by the is-global flag stored at position 1 of each ProductionRuleArray.
        global_actions = [(rule[2], idx) for rule, idx in rules_with_indices if rule[1]]
        linked_actions = [(rule[0], idx) for rule, idx in rules_with_indices if not rule[1]]

        if global_actions:
            rule_id_tensors, global_ids = zip(*global_actions)
            rule_id_tensor = torch.cat(rule_id_tensors, dim=0)
            input_embeddings = self._action_embedder(rule_id_tensor)
            if self._add_action_bias:
                biases = self._action_biases(rule_id_tensor)
                input_embeddings = torch.cat([input_embeddings, biases], dim=-1)
            output_embeddings = self._output_action_embedder(rule_id_tensor)
            per_type['global'] = (input_embeddings, output_embeddings, list(global_ids))

        if linked_actions:
            linked_rules, linked_ids = zip(*linked_actions)
            # The entity name is the right-hand side of each linked rule.
            entity_names = [rule.split(' -> ')[1] for rule in linked_rules]
            entity_rows = [entity_map[name] for name in entity_names]
            per_type['linked'] = (
                # (num_linked_actions, num_question_tokens)
                linking_scores[entity_rows],
                # (num_linked_actions, entity_type_embedding_dim)
                self._entity_type_decoder_embedding(entity_types[entity_rows]),
                list(linked_ids),
            )

    return GrammarStatelet([START_SYMBOL], translated_valid_actions, world.is_nonterminal)
def _create_grammar_statelet(self,
                             language: DropLanguage,
                             possible_actions: List[ProductionRule]
                             ) -> Tuple[GrammarStatelet, Dict[str, int], List[str]]:
    """
    Make the grammar statelet for one instance in the batch from its global actions.

    Parameters
    ----------
    language : ``DropLanguage``
        The language for this instance.  ``get_nonterminal_productions()`` supplies the valid
        productions of each nonterminal, and ``is_nonterminal`` is handed to the statelet.
    possible_actions : ``List[ProductionRule]``
        All possible actions for this instance, each laid out as
        ``(rule, is_global_rule, rule_id, nonterminal)``.

    Returns
    -------
    grammar_statelet : ``GrammarStatelet``
        Starts from ``START_SYMBOL``.  For every nonterminal its valid actions contain a single
        ``'global'`` entry ``(embeddings, embeddings, action_ids)``.  The same embedding tensor
        serves as both the input and the output representation.
    action2actionidx : ``Dict[str, int]``
        Rule string -> index into ``possible_actions`` (the last occurrence wins on duplicates).
    actionidx2actionstr : ``List[str]``
        Index into ``possible_actions`` -> rule string.

    Raises
    ------
    NotImplementedError
        If a valid production is not a global rule; instance-specific (linked) actions are not
        supported yet.
    """
    action2actionidx: Dict[str, int] = {}
    actionidx2actionstr: List[str] = []
    for action_index, action in enumerate(possible_actions):
        action_string = action[0]
        action2actionidx[action_string] = action_index
        actionidx2actionstr.append(action_string)

    valid_actions = language.get_nonterminal_productions()
    translated_valid_actions: Dict[str, Dict[str, Tuple[torch.Tensor, torch.Tensor, List[int]]]] = {}
    for key, action_strings in valid_actions.items():
        translated_valid_actions[key] = {}
        # `key` is a nonterminal of the grammar and `action_strings` are its valid productions.
        action_indices = [action2actionidx[action_string] for action_string in action_strings]

        # (rule_id tensor, index into possible_actions) for each global production of `key`.
        global_actions = []
        for action_index in action_indices:
            production_rule = possible_actions[action_index]
            if not production_rule[1]:
                raise NotImplementedError(
                    "Non-global production {!r} of nonterminal {!r} is not supported.".format(
                        production_rule[0], key))
            global_actions.append((production_rule[2], action_index))

        if global_actions:
            global_action_tensors, global_action_ids = zip(*global_actions)
            global_action_tensor = torch.cat(global_action_tensors, dim=0)
            # TODO(nitish): Figure out if need action_bias and separate input/output action embeddings
            global_action_embeddings = self._action_embedder(global_action_tensor)
            translated_valid_actions[key]['global'] = (global_action_embeddings,
                                                       global_action_embeddings,
                                                       list(global_action_ids))

    grammar_statelet = GrammarStatelet([START_SYMBOL],
                                       translated_valid_actions,
                                       language.is_nonterminal)
    return grammar_statelet, action2actionidx, actionidx2actionstr
def setUp(self):
    super().setUp()
    attention = Attention.by_name('dot_product')()
    self.decoder_step = BasicTransitionFunction(encoder_output_dim=2,
                                                action_embedding_dim=2,
                                                input_attention=attention,
                                                num_start_types=3,
                                                add_action_bias=False)

    batch_indices = [0, 1, 0]
    action_history = [[1], [3, 4], []]
    score = []
    for value in [.1, 1.1, 2.2]:
        score.append(torch.FloatTensor([value]))

    # Row k is [k, k]; each tensor below is built separately from these rows.
    index_rows = [[k, k] for k in range(len(batch_indices))]
    hidden_state = torch.FloatTensor(index_rows)
    memory_cell = torch.FloatTensor(index_rows)
    previous_action_embedding = torch.FloatTensor(index_rows)
    attended_question = torch.FloatTensor(index_rows)

    # Nonterminal -> action type -> (input representation, output representation, action ids).
    # "global" actions come from the grammar; "linked" actions are instance-specific and are
    # scored through question attention.
    valid_actions = {}
    valid_actions['e'] = {
        'global': (torch.FloatTensor([[0, 0], [-1, -1], [-2, -2]]),
                   torch.FloatTensor([[-1, -1], [-2, -2], [-3, -3]]),
                   [0, 1, 2]),
        'linked': (torch.FloatTensor([[.1, .2, .3], [.4, .5, .6]]),
                   torch.FloatTensor([[3, 3], [4, 4]]),
                   [3, 4]),
    }
    valid_actions['d'] = {
        'global': (torch.FloatTensor([[0, 0]]),
                   torch.FloatTensor([[-1, -1]]),
                   [0]),
        'linked': (torch.FloatTensor([[-.1, -.2, -.3], [-.4, -.5, -.6], [-.7, -.8, -.9]]),
                   torch.FloatTensor([[5, 5], [6, 6], [7, 7]]),
                   [1, 2, 3]),
    }

    grammar_state = []
    for nonterminal in ['e', 'd', 'e']:
        # Each statelet gets its own freshly created empty dicts.
        grammar_state.append(
            GrammarStatelet([nonterminal], {}, valid_actions, {}, is_nonterminal))

    self.encoder_outputs = torch.FloatTensor([[[1, 2], [3, 4], [5, 6]],
                                              [[10, 11], [12, 13], [14, 15]]])
    self.encoder_output_mask = torch.FloatTensor([[1, 1, 1], [1, 1, 0]])

    first_instance_actions = [('e -> f', False, None),
                              ('e -> g', True, None),
                              ('e -> h', True, None),
                              ('e -> i', True, None),
                              ('e -> j', True, None)]
    second_instance_actions = [('d -> q', True, None),
                               ('d -> g', True, None),
                               ('d -> h', True, None),
                               ('d -> i', True, None)]
    self.possible_actions = [first_instance_actions, second_instance_actions]

    rnn_state = [RnnStatelet(hidden_state[k],
                             memory_cell[k],
                             previous_action_embedding[k],
                             attended_question[k],
                             self.encoder_outputs,
                             self.encoder_output_mask)
                 for k in range(len(batch_indices))]

    self.state = GrammarBasedState(batch_indices=batch_indices,
                                   action_history=action_history,
                                   score=score,
                                   rnn_state=rnn_state,
                                   grammar_state=grammar_state,
                                   possible_actions=self.possible_actions)