def __eq__(self, other): if isinstance(self, other.__class__): return all([ util.tensors_equal(self.terminal_actions, other.terminal_actions), util.tensors_equal(self.checklist_target, other.checklist_target), util.tensors_equal(self.checklist_mask, other.checklist_mask), util.tensors_equal(self.checklist, other.checklist), self.terminal_indices_dict == other.terminal_indices_dict, ]) return NotImplemented
def __eq__(self, other): if isinstance(self, other.__class__): return all([ util.tensors_equal(self.hidden_state, other.hidden_state, tolerance=1e-5), util.tensors_equal(self.memory_cell, other.memory_cell, tolerance=1e-5), util.tensors_equal(self.previous_action_embedding, other.previous_action_embedding, tolerance=1e-5), util.tensors_equal(self.attended_input, other.attended_input, tolerance=1e-5), ]) return NotImplemented
def __eq__(self, other): if isinstance(self, other.__class__): # pylint: disable=protected-access return all([ self._nonterminal_stack == other._nonterminal_stack, self._lambda_stacks == other._lambda_stacks, util.tensors_equal(self._valid_actions, other._valid_actions), util.tensors_equal(self._context_actions, other._context_actions), self._is_nonterminal == other._is_nonterminal, ]) return NotImplemented
def __eq__(self, other): if isinstance(self, other.__class__): return all([ self.batch_indices == other.batch_indices, self.action_history == other.action_history, util.tensors_equal(self.score, other.score, tolerance=1e-3), util.tensors_equal(self.rnn_state, other.rnn_state, tolerance=1e-4), self.grammar_state == other.grammar_state, self.checklist_state == other.checklist_state, self.possible_actions == other.possible_actions, self.extras == other.extras, util.tensors_equal(self.debug_info, other.debug_info, tolerance=1e-6), ]) return NotImplemented
def __eq__(self, other): if isinstance(self, other.__class__): return all([ self._nonterminal_stack == other._nonterminal_stack, util.tensors_equal(self._valid_actions, other._valid_actions), self._is_nonterminal == other._is_nonterminal, self._reverse_productions == other._reverse_productions, ]) return NotImplemented
def __eq__(self, other): if isinstance(self, other.__class__): return all([ util.tensors_equal(self.hidden_state, other.hidden_state, tolerance=1e-5), util.tensors_equal(self.memory_cell, other.memory_cell, tolerance=1e-5), util.tensors_equal(self.previous_action_embedding, other.previous_action_embedding, tolerance=1e-5), util.tensors_equal(self.attended_input, other.attended_input, tolerance=1e-5), util.tensors_equal(self.encoder_outputs, other.encoder_outputs, tolerance=1e-5), util.tensors_equal(self.encoder_output_mask, other.encoder_output_mask, tolerance=1e-5), ]) return NotImplemented