def __init__(self,
                 logger=None,
                 machine_type='AMR',
                 from_sent_act_pairs=None,
                 actions_by_stack_rules=None,
                 no_whitespace_in_actions=False,
                 entity_rules=None):

        assert not no_whitespace_in_actions, \
            '--no-whitespace-in-actions deprecated'

        # Dummy mode: simulate parser from pre-computed pairs of sentences
        # and actions
        self.actions_by_sentence = {
            " ".join(sent): actions
            for sent, actions in from_sent_act_pairs
        }
        self.logger = logger
        self.sent_idx = 0
        self.actions_by_stack_rules = actions_by_stack_rules
        self.no_whitespace_in_actions = no_whitespace_in_actions
        self.machine_type = machine_type
        self.entity_rules = entity_rules
        # initialize here for speed
        self.spacy_lemmatizer = get_spacy_lemmatizer()

        # counters
        self.pred_counts = Counter()
        self.rule_violation = Counter()
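
# Usage sketch for the simulated parser above (the enclosing class name is
# not shown in this snippet; `SimulatedParser` is a hypothetical stand-in):
#
#     pairs = [
#         (['The', 'boy', 'wants'], ['SHIFT', 'SHIFT', 'PRED(want-01)']),
#     ]
#     parser = SimulatedParser(from_sent_act_pairs=pairs)
#     # sentences are keyed by their whitespace-joined token string
#     actions = parser.actions_by_sentence['The boy wants']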
Example #2
def merge_rules(sentences, actions, rule_stats, entity_rules=None):

    # generate rules to restrict action space by stack content
    # (convert the JSON-loaded count dicts back into Counters)
    actions_by_stack_rules = rule_stats['possible_predicates']
    for token, counter in rule_stats['possible_predicates'].items():
        actions_by_stack_rules[token] = Counter(counter)

    spacy_lemmatizer = get_spacy_lemmatizer()

    possible_predicates = defaultdict(Counter)
    for index, sentence_actions in tqdm(enumerate(actions),
                                        desc='merge rules',
                                        total=len(actions)):

        tokens = sentences[index]

        # Initialize machine
        state_machine = AMRStateMachine(
            tokens,
            actions_by_stack_rules=actions_by_stack_rules,
            spacy_lemmatizer=spacy_lemmatizer,
            entity_rules=entity_rules
        )

        for action in sentence_actions:
            # NOTE: In the oracle, possible predicates are collected before
            # the PRED/COPY decision (tryConfirm action), so we have to take
            # all of them into account
            position, mpositions = \
                state_machine.get_top_of_stack(positions=True)
            if action.startswith('PRED'):
                node = action[5:-1]
                possible_predicates[tokens[position]].update([node])
                if mpositions:
                    mtokens = ','.join([tokens[p] for p in mpositions])
                    possible_predicates[mtokens].update([node])

            elif action == 'COPY_LEMMA':
                lemma, _ = state_machine.get_top_of_stack(lemma=True)
                node = lemma
                possible_predicates[tokens[position]].update([node])
                if mpositions:
                    mtokens = ','.join([tokens[p] for p in mpositions])
                    possible_predicates[mtokens].update([node])

            elif action == 'COPY_SENSE01':
                lemma, _ = state_machine.get_top_of_stack(lemma=True)
                node = f'{lemma}-01'
                possible_predicates[tokens[position]].update([node])
                if mpositions:
                    mtokens = ','.join([tokens[p] for p in mpositions])
                    possible_predicates[mtokens].update([node])

            # execute action
            state_machine.applyAction(action)

    out_rule_stats = rule_stats
    new_possible_predicates = merge_both_rules(possible_predicates,
                                               actions_by_stack_rules)
    out_rule_stats['possible_predicates'] = new_possible_predicates

    return out_rule_stats
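
# Usage sketch for merge_rules, assuming rule_stats was produced by the
# oracle and loaded from its JSON output (file name hypothetical):
#
#     with open('train.rules.json') as fid:
#         rule_stats = json.loads(fid.read())
#     rule_stats = merge_rules(sentences, actions, rule_stats)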
Example #3
def machine_generator(actions_by_stack_rules, spacy_lemmatizer=None,
                      entity_rules=None, post_process=False):
    """Return function that itself returns initialized state machines"""

    # initialize spacy lemmatizer
    if spacy_lemmatizer is None:
        spacy_lemmatizer = get_spacy_lemmatizer()

    def get_new_state_machine(sent_tokens, machine_type=None):

        nonlocal actions_by_stack_rules
        nonlocal spacy_lemmatizer
        nonlocal entity_rules

        # automatic determination of machine type if no flag provided
        if sent_tokens[0] in ['<NER>', '<AMR>', '<SRL>']:
            assert machine_type is None, \
                "specify --machine-type OR pre-append <machine-type token>"
            machine_type = sent_tokens[0][1:-1]
        elif machine_type is None:
            raise Exception(
                "needs either --machine-type or appending <machine-type token>"
            )

        # select machine
        if machine_type == 'AMR':
            return AMRStateMachine(
                sent_tokens,
                actions_by_stack_rules=actions_by_stack_rules,
                spacy_lemmatizer=spacy_lemmatizer,
                entity_rules=entity_rules,
                # this is only needed to generate the AMR
                post_process=post_process
            )
        elif machine_type == 'dep-parsing':
            assert sent_tokens[-1] == 'ROOT'
            # sent_tokens.pop()
            # sent_tokens.append('<ROOT>')
            return DepParsingStateMachine(sent_tokens)

        elif machine_type in ['NER', 'SRL']:
            from bio_tags.machine import BIOStateMachine
            return BIOStateMachine(sent_tokens)
        else:
            raise Exception(f'Unknown machine {machine_type}')

    return get_new_state_machine
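
# Usage sketch: the machine type can be given explicitly or inferred from a
# leading <AMR>/<NER>/<SRL> token (sentences here are illustrative):
#
#     get_machine = machine_generator(rule_stats['possible_predicates'])
#     machine = get_machine(['<AMR>', 'The', 'boy', 'wants'])
#     machine = get_machine(['The', 'boy', 'wants'], machine_type='AMR')
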
def main():

    # Argument handling
    args = argument_parser()

    # read rules
    train_rule_stats = read_rule_stats(args.in_rule_stats)
    assert 'action_vocabulary' in train_rule_stats
    assert 'possible_predicates' in train_rule_stats
    action_list = list(sorted(train_rule_stats['action_vocabulary'].keys()))

    # get all actions indexed by action root
    action_by_basic = defaultdict(list)
    for action in train_rule_stats['action_vocabulary'].keys():
        key = action.split('(')[0]
        action_by_basic[key].append(action)

    # open file for reading if provided
    write_out_states = h5_writer(args.out_word_states)
    write_out_valid_actions = h5_writer(args.out_valid_actions)
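    # NOTE: each writer handle is called once per sentence with the data to
    # append and once with no arguments at the end to close the file (see
    # the final write_out_states() / write_out_valid_actions() calls below)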

    # Read content
    # TODO: Point to LDC data
    sentences = readlines(args.in_sentences)
    actions = readlines(args.in_actions)
    assert len(sentences) == len(actions)

    # initialize spacy lemmatizer out of the sentence loop for speed
    spacy_lemmatizer = get_spacy_lemmatizer()

    sent_idx = -1
    stats = {
        'missing_pred_count': Counter(),
        'missing_action_count': Counter(),
        'fan_out_count': Counter()
    }
    for sent_tokens, sent_actions in tqdm(
        zip(sentences, actions),
        desc='extracting oracle masks',
        total=len(actions)
    ):

        # keep count of sentence index
        sent_idx += 1
        if args.offset and sent_idx < args.offset:
            continue

        # Initialize state machine
        amr_state_machine = AMRStateMachine(
            sent_tokens,
            spacy_lemmatizer=spacy_lemmatizer
        )

        # process each action
        word_states_sent = []
        valid_actions_sent = []
        for raw_action in sent_actions:

            # Store states BEFORE ACTION
            # state of each word (buffer B, stack S, reduced X)
            word_states = get_word_states(amr_state_machine, sent_tokens)

            # Get actions valid for this state
            valid_actions = get_valid_actions(
                action_list,
                amr_state_machine,
                train_rule_stats,
                action_by_basic,
                raw_action,
                stats
            )

            # update info
            word_states_sent.append(word_states)
            valid_actions_sent.append(valid_actions)

            # Update machine
            amr_state_machine.applyAction(raw_action)

        # Write states for this sentence
        write_out_states(word_states_sent)
        write_out_valid_actions(valid_actions_sent)

    # inform user about missing predicates
    for miss in ['missing_pred_count', 'missing_action_count']:
        num_missing = len(stats[miss])
        if num_missing:
            alert_str = f'{num_missing} {miss} entries not in rule_stats'
            print(yellow_font(alert_str))

    # inform user about fan-out stats
    mode_fan_out = stats['fan_out_count'].most_common(1)[0][0]
    max_fan_out = max(stats['fan_out_count'].keys())
    alert_str = f'num_actions mode: {mode_fan_out} max: {max_fan_out}'
    print(alert_str)

    # Close file
    write_out_states()
    write_out_valid_actions()
    def runOracle(self,
                  gold_amrs,
                  propbank_args=None,
                  out_oracle=None,
                  out_amr=None,
                  out_sentences=None,
                  out_actions=None,
                  out_rule_stats=None,
                  add_unaligned=0,
                  no_whitespace_in_actions=False,
                  multitask_words=None,
                  copy_lemma_action=False,
                  addnode_count_cutoff=None):

        print_log("oracle", "Parsing data")
        # deep copy of gold AMRs
        self.gold_amrs = [gold_amr.copy() for gold_amr in gold_amrs]

        # report inconsistencies in annotations
        alert_inconsistencies(self.gold_amrs)

        # open all files (if paths provided) and get writers to them
        oracle_write = writer(out_oracle)
        amr_write = writer(out_amr)
        sentence_write = writer(out_sentences, add_return=True)
        actions_write = writer(out_actions, add_return=True)

        # This will store overall stats
        self.stats = {
            'possible_predicates': Counter(),
            'action_vocabulary': Counter(),
            'addnode_counts': Counter()
        }

        # unaligned tokens
        included_unaligned = [
            '-',
            'and',
            'multi-sentence',
            'person',
            'cause-01',
            'you',
            'more',
            'imperative',
            '1',
            'thing',
        ]

        # initialize spacy lemmatizer out of the sentence loop for speed
        spacy_lemmatizer = None
        if copy_lemma_action:
            spacy_lemmatizer = get_spacy_lemmatizer()

        # Loop over gold AMRs
        for sent_idx, gold_amr in tqdm(enumerate(self.gold_amrs),
                                       desc='computing oracle',
                                       total=len(self.gold_amrs)):

            if self.verbose:
                print(f'New Sentence {sent_idx}\n\n\n')

            # TODO: Describe what this pre-processing does
            gold_amr = preprocess_amr(gold_amr, add_unaligned,
                                      included_unaligned)

            # Initialize state machine
            tr = AMRStateMachine(gold_amr.tokens,
                                 verbose=self.verbose,
                                 add_unaligned=add_unaligned,
                                 spacy_lemmatizer=spacy_lemmatizer,
                                 entity_rules=self.entity_rules)
            self.transitions.append(tr)
            self.amrs.append(tr.amr)

            # Loop over potential actions
            while tr.buffer or tr.stack:
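                # Actions are tried in fixed precedence order: node actions
                # (MERGE, ADDNODE, DEPENDENT, PRED/COPY), then INTRODUCE,
                # arc actions (LA, RA) and stack maintenance (REDUCE,
                # UNSHIFT, SHIFT)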

                if self.tryMerge(tr, tr.amr, gold_amr):
                    action = 'MERGE'

                elif self.tryEntity(tr, tr.amr, gold_amr):
                    action = f'ADDNODE({self.entity_type})'

                elif self.tryDependent(tr, tr.amr, gold_amr):
                    edge = self.new_edge[1:] \
                        if self.new_edge.startswith(':') else self.new_edge
                    action = f'DEPENDENT({self.new_node},{edge})'
                    self.dep_id = None

                elif self.tryConfirm(tr, tr.amr, gold_amr):
                    # if --copy-lemma-action, use COPY_LEMMA when the lemma
                    # equals the node name and COPY_SENSE01 when its first
                    # sense does; otherwise fall back to PRED
                    if copy_lemma_action:
                        lemma, _ = tr.get_top_of_stack(lemma=True)
                        if lemma == self.new_node:
                            action = 'COPY_LEMMA'
                        elif f'{lemma}-01' == self.new_node:
                            action = 'COPY_SENSE01'
                        else:
                            action = f'PRED({self.new_node})'
                    else:
                        action = f'PRED({self.new_node})'

                elif self.tryIntroduce(tr, tr.amr, gold_amr):
                    action = 'INTRODUCE'

                elif self.tryLA(tr, tr.amr, gold_amr):
                    if self.new_edge == 'root':
                        action = f'LA({self.new_edge})'
                    else:
                        action = f'LA({self.new_edge[1:]})'

                elif self.tryRA(tr, tr.amr, gold_amr):
                    if self.new_edge == 'root':
                        action = f'RA({self.new_edge})'
                    else:
                        action = f'RA({self.new_edge[1:]})'

                elif self.tryReduce(tr, tr.amr, gold_amr):
                    action = 'REDUCE'

                elif self.trySWAP(tr, tr.amr, gold_amr):
                    action = 'UNSHIFT'

                elif tr.buffer:
                    action = 'SHIFT'

                else:
                    tr.stack = []
                    tr.buffer = []
                    break

                # Store stats
                # get token(s) at the top of the stack
                token, merged_tokens = tr.get_top_of_stack()
                action_label = action.split('(')[0]

                # check the action has no invalid chars and normalize
                # TODO: --no-whitespace-in-actions being deprecated
                if no_whitespace_in_actions and action_label == 'PRED':
                    assert '_' not in action, \
                        "--no-whitespace-in-actions prohibits use of _ in actions"
                    if ' ' in action:
                        action = action.replace(' ', '_')

                # Add prediction to top of the buffer
                if action == 'SHIFT' and multitask_words is not None:
                    action = label_shift(tr, multitask_words)

                # APPLY ACTION
                tr.applyAction(action)

            # Close machine
            tr.CLOSE(training=True,
                     gold_amr=gold_amr,
                     use_addnonde_rules=use_addnode_rules)

            # update files
            if out_oracle:
                # to avoid printing
                oracle_write(str(tr))
            # JAMR format AMR
            amr_write(tr.amr.toJAMRString())
            # Tokens and actions
            # extra tag to be reduced at start
            tokens = tr.amr.tokens
            actions = tr.actions

            # Update action count
            self.stats['action_vocabulary'].update(actions)
            del gold_amr.nodes[-1]
            addnode_actions = [a for a in actions if a.startswith('ADDNODE')]
            self.stats['addnode_counts'].update(addnode_actions)

            # separator
            if no_whitespace_in_actions:
                sep = " "
            else:
                sep = "\t"
            tokens = sep.join(tokens)
            actions = sep.join(actions)
            # Write
            sentence_write(tokens)
            actions_write(actions)

        print_log("oracle", "Done")

        # close files if open
        oracle_write()
        amr_write()
        sentence_write()
        actions_write()

        self.labelsO2idx = {'<pad>': 0}
        self.labelsA2idx = {'<pad>': 0}
        self.pred2idx = {'<pad>': 0}
        self.action2idx = {'<pad>': 0}

        for tr in self.transitions:
            for a in tr.actions:
                a = AMRStateMachine.readAction(a)[0]
                self.action2idx.setdefault(a, len(self.action2idx))
            for p in tr.predicates:
                self.pred2idx.setdefault(p, len(self.pred2idx))
            for l in tr.labels:
                self.labelsO2idx.setdefault(l, len(self.labelsO2idx))
            for l in tr.labelsA:
                self.labelsA2idx.setdefault(l, len(self.labelsA2idx))

        self.stats["action2idx"] = self.action2idx
        self.stats["pred2idx"] = self.pred2idx
        self.stats["labelsO2idx"] = self.labelsO2idx
        self.stats["labelsA2idx"] = self.labelsA2idx

        # Compute the word dictionary

        self.char2idx = {'<unk>': 0}
        self.word2idx = {'<unk>': 0, '<eof>': 1, '<ROOT>': 2, '<unaligned>': 3}
        self.node2idx = {}
        self.word_counter = Counter()

        for amr in self.gold_amrs:
            for tok in amr.tokens:
                self.word_counter[tok] += 1
                self.word2idx.setdefault(tok, len(self.word2idx))
                for ch in tok:
                    self.char2idx.setdefault(ch, len(self.char2idx))
            for n in amr.nodes:
                self.node2idx.setdefault(amr.nodes[n], len(self.node2idx))

        self.stats["char2idx"] = self.char2idx
        self.stats["word2idx"] = self.word2idx
        self.stats["node2idx"] = self.node2idx
        self.stats["word_counter"] = self.word_counter

        self.stats['possible_predicates'] = self.possiblePredicates

        if addnode_count_cutoff:
            self.stats['addnode_blacklist'] = [
                a for a, c in self.stats['addnode_counts'].items()
                if c <= addnode_count_cutoff
            ]
            num_addnode_blackl = len(self.stats['addnode_blacklist'])
            num_addnode = len(self.stats['addnode_counts'])
            print(f'{num_addnode_blackl}/{num_addnode} blacklisted ADDNODES')
            del self.stats['addnode_counts']

        # Write accumulated state machine stats
        if out_rule_stats:
            with open(out_rule_stats, 'w') as fid:
                fid.write(json.dumps(self.stats))
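
# Usage sketch for runOracle (the oracle class name and constructor are
# assumed; gold AMRs would come from an aligned AMR reader not shown here):
#
#     oracle = AMR_Oracle(verbose=False)
#     oracle.runOracle(
#         gold_amrs,
#         out_sentences='train.tokens',
#         out_actions='train.actions',
#         out_rule_stats='train.rules.json',
#         copy_lemma_action=True
#     )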
Example #6
    def __init__(self,
                 oracle_stats,
                 embedding_dim,
                 action_embedding_dim,
                 char_embedding_dim,
                 hidden_dim,
                 char_hidden_dim,
                 rnn_layers,
                 dropout_ratio,
                 pretrained_dim=1024,
                 amrs=None,
                 experiment=None,
                 use_gpu=False,
                 use_chars=False,
                 use_bert=False,
                 use_attention=False,
                 use_function_words=False,
                 use_function_words_rels=False,
                 parse_unaligned=False,
                 weight_inputs=False,
                 attend_inputs=False):
        super(AMRModel, self).__init__()
        self.embedding_dim = embedding_dim
        self.char_embedding_dim = char_embedding_dim
        self.char_hidden_dim = char_hidden_dim
        self.action_embedding_dim = action_embedding_dim
        self.hidden_dim = hidden_dim
        self.exp = experiment
        self.pretrained_dim = pretrained_dim
        self.rnn_layers = rnn_layers
        self.use_bert = use_bert
        self.use_chars = use_chars
        self.use_attention = use_attention
        self.use_function_words_all = use_function_words
        self.use_function_words_rels = use_function_words_rels
        self.use_function_words = use_function_words or use_function_words_rels
        self.parse_unaligned = parse_unaligned
        self.weight_inputs = weight_inputs
        self.attend_inputs = attend_inputs
        # self.tokenizer = self.spacy_tokenizer()

        self.warm_up = False

        self.possible_predicates = oracle_stats["possible_predicates"]

        # Load spacy lemmatizer if needed
        self.lemmatizer = get_spacy_lemmatizer()

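        # The parser state concatenates hidden_dim-sized summaries: stack,
        # buffer and action LSTMs (3 * hidden_dim) plus optional attention
        # and function-word summaries; state_size counts these blocks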
        self.state_dim = 3 * hidden_dim + (hidden_dim if use_attention else 0) \
            + (hidden_dim if self.use_function_words_all else 0)

        self.state_size = self.state_dim // hidden_dim

        if self.weight_inputs or self.attend_inputs:
            self.state_dim = hidden_dim

        self.use_gpu = use_gpu

        # Vocab and indices

        self.char2idx = oracle_stats['char2idx']
        self.word2idx = oracle_stats['word2idx']
        self.node2idx = oracle_stats['node2idx']
        word_counter = oracle_stats['word_counter']

        self.amrs = amrs

        self.singletons = {
            self.word2idx[w]
            for w in word_counter if word_counter[w] == 1
        }
        # the set holds word indices, so discard special tokens by index
        for special in ['<unk>', '<eof>', '<ROOT>', '<unaligned>']:
            self.singletons.discard(self.word2idx[special])

        self.labelsO2idx = oracle_stats["labelsO2idx"]
        self.labelsA2idx = oracle_stats["labelsA2idx"]
        self.pred2idx = oracle_stats["pred2idx"]
        self.action2idx = oracle_stats["action2idx"]

        self.vocab_size = len(self.word2idx)
        self.action_size = len(self.action2idx)

        self.labelA_size = len(self.labelsA2idx)
        self.labelO_size = len(self.labelsO2idx)
        self.pred_size = len(self.pred2idx)

        self.idx2labelO = {v: k for k, v in self.labelsO2idx.items()}
        self.idx2labelA = {v: k for k, v in self.labelsA2idx.items()}
        self.idx2node = {v: k for k, v in self.node2idx.items()}
        self.idx2pred = {v: k for k, v in self.pred2idx.items()}
        self.idx2action = {v: k for k, v in self.action2idx.items()}
        self.idx2word = {v: k for k, v in self.word2idx.items()}
        self.idx2char = {v: k for k, v in self.char2idx.items()}
        # self.ner_map = ner_map
        self.labelsA = list(self.labelsA2idx.values())
        self.labelsO = list(self.labelsO2idx.values())
        self.preds = list(self.pred2idx.values())
        utils.print_log('parser',
                        f'Number of characters: {len(self.char2idx)}')
        utils.print_log('parser', f'Number of words: {len(self.word2idx)}')
        utils.print_log('parser', f'Number of nodes: {len(self.node2idx)}')
        utils.print_log('parser', f'Number of actions: {len(self.action2idx)}')
        for action in self.action2idx:
            print('\t', action)
        utils.print_log('parser', f'Number of labels: {len(self.labelsO2idx)}')
        utils.print_log('parser',
                        f'Number of labelsA: {len(self.labelsA2idx)}')
        utils.print_log('parser',
                        f'Number of predicates: {len(self.pred2idx)}')

        # Parameters
        self.word_embeds = nn.Embedding(self.vocab_size, embedding_dim)
        self.action_embeds = nn.Embedding(self.action_size,
                                          action_embedding_dim)
        self.labelA_embeds = nn.Embedding(self.labelA_size,
                                          action_embedding_dim)
        self.labelO_embeds = nn.Embedding(self.labelO_size,
                                          action_embedding_dim)
        self.pred_embeds = nn.Embedding(self.pred_size, action_embedding_dim)
        self.pred_unk_embed = nn.Parameter(torch.randn(
            1, self.action_embedding_dim),
                                           requires_grad=True)
        self.empty_emb = nn.Parameter(torch.randn(1, hidden_dim),
                                      requires_grad=True)

        # Stack-LSTMs
        self.buffer_lstm = nn.LSTMCell(self.embedding_dim, hidden_dim)
        self.stack_lstm = nn.LSTMCell(self.embedding_dim, hidden_dim)
        self.action_lstm = nn.LSTMCell(action_embedding_dim, hidden_dim)
        self.lstm_initial_1 = utils.xavier_init(self.use_gpu, 1,
                                                self.hidden_dim)
        self.lstm_initial_2 = utils.xavier_init(self.use_gpu, 1,
                                                self.hidden_dim)
        self.lstm_initial = (self.lstm_initial_1, self.lstm_initial_2)

        if self.use_chars:
            self.char_embeds = nn.Embedding(len(self.char2idx),
                                            char_embedding_dim)
            self.unaligned_char_embed = nn.Parameter(torch.randn(
                1, 2 * char_hidden_dim),
                                                     requires_grad=True)
            self.root_char_embed = nn.Parameter(torch.randn(
                1, 2 * char_hidden_dim),
                                                requires_grad=True)
            self.pad_char_embed = nn.Parameter(
                torch.zeros(1, 2 * char_hidden_dim))
            self.char_lstm_forward = nn.LSTM(char_embedding_dim,
                                             char_hidden_dim,
                                             num_layers=rnn_layers,
                                             dropout=dropout_ratio)
            self.char_lstm_backward = nn.LSTM(char_embedding_dim,
                                              char_hidden_dim,
                                              num_layers=rnn_layers,
                                              dropout=dropout_ratio)

            self.tok_2_embed = nn.Linear(
                self.embedding_dim + 2 * char_hidden_dim, self.embedding_dim)

        if self.use_bert:
            # bert embeddings to LSTM input
            self.pretrained_2_embed = nn.Linear(
                self.embedding_dim + self.pretrained_dim, self.embedding_dim)

        if use_attention:
            self.forward_lstm = nn.LSTM(self.embedding_dim,
                                        hidden_dim,
                                        num_layers=rnn_layers,
                                        dropout=dropout_ratio)
            self.backward_lstm = nn.LSTM(self.embedding_dim,
                                         hidden_dim,
                                         num_layers=rnn_layers,
                                         dropout=dropout_ratio)

            self.attention_weights = nn.Parameter(torch.randn(
                2 * hidden_dim, 2 * hidden_dim),
                                                  requires_grad=True)

            self.attention_ff1_1 = nn.Linear(2 * hidden_dim, hidden_dim)

        self.dropout_emb = nn.Dropout(p=dropout_ratio)
        self.dropout = nn.Dropout(p=dropout_ratio)

        self.action_softmax1 = nn.Linear(self.state_dim, hidden_dim)
        self.labelA_softmax1 = nn.Linear(self.state_dim, hidden_dim)
        self.pred_softmax1 = nn.Linear(self.state_dim, hidden_dim)
        if not self.use_function_words_rels:
            self.label_softmax1 = nn.Linear(self.state_dim, hidden_dim)
        else:
            self.label_softmax1 = nn.Linear(self.state_dim + hidden_dim,
                                            hidden_dim)

        self.action_softmax2 = nn.Linear(hidden_dim, len(self.action2idx) + 2)
        self.labelA_softmax2 = nn.Linear(hidden_dim, len(self.labelsA2idx) + 2)
        self.label_softmax2 = nn.Linear(hidden_dim, len(self.labelsO2idx) + 2)
        self.pred_softmax2 = nn.Linear(hidden_dim, len(self.pred2idx) + 2)

        # composition functions
        self.arc_composition_head = nn.Linear(
            2 * self.embedding_dim + self.action_embedding_dim,
            self.embedding_dim)
        self.merge_composition = nn.Linear(2 * self.embedding_dim,
                                           self.embedding_dim)
        self.dep_composition = nn.Linear(
            self.embedding_dim + self.action_embedding_dim, self.embedding_dim)
        self.addnode_composition = nn.Linear(
            self.embedding_dim + self.action_embedding_dim, self.embedding_dim)
        self.pred_composition = nn.Linear(
            self.embedding_dim + self.action_embedding_dim, self.embedding_dim)

        # experiments
        if self.use_function_words:
            self.functionword_lstm = nn.LSTMCell(self.embedding_dim,
                                                 hidden_dim)

        if self.parse_unaligned:
            self.pred_softmax1_unaligned = nn.Linear(self.state_dim,
                                                     hidden_dim)
            self.pred_softmax2_unaligned = nn.Linear(hidden_dim,
                                                     len(self.pred2idx) + 2)

        if self.weight_inputs:
            self.action_attention = nn.Parameter(torch.zeros(self.state_size),
                                                 requires_grad=True)
            self.label_attention = nn.Parameter(torch.zeros(self.state_size),
                                                requires_grad=True)
            self.labelA_attention = nn.Parameter(torch.zeros(self.state_size),
                                                 requires_grad=True)
            self.pred_attention = nn.Parameter(torch.zeros(self.state_size),
                                               requires_grad=True)
            if self.parse_unaligned:
                self.pred_attention_unaligned = nn.Parameter(
                    torch.zeros(self.state_size), requires_grad=True)
        elif self.attend_inputs:
            self.action_attention = torch.nn.Linear(self.state_size * 2,
                                                    self.state_size)
            self.label_attention = torch.nn.Linear(self.state_size * 2,
                                                   self.state_size)
            self.labelA_attention = torch.nn.Linear(self.state_size * 2,
                                                    self.state_size)
            self.pred_attention = torch.nn.Linear(self.state_size * 2,
                                                  self.state_size)
            if self.parse_unaligned:
                self.pred_attention_unaligned = torch.nn.Linear(
                    self.state_size * 2, self.state_size)
            self.prevent_overfitting = torch.nn.Linear(hidden_dim,
                                                       self.state_size * 2)

        # stats and accuracy
        self.action_acc = utils.Accuracy()
        self.label_acc = utils.Accuracy()
        self.labelA_acc = utils.Accuracy()
        self.pred_acc = utils.Accuracy()

        self.action_confusion_matrix = utils.ConfusionMatrix(
            self.action2idx.keys())
        self.label_confusion_matrix = utils.ConfusionMatrix(
            self.labelsO2idx.keys())

        self.action_loss = 0
        self.label_loss = 0
        self.labelA_loss = 0
        self.pred_loss = 0

        self.epoch_loss = 0

        self.rand_init()
        if self.use_gpu:
            for m in self.modules():
                m.cuda()
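
# Instantiation sketch (dimensions are illustrative; oracle_stats is the
# dictionary written by the oracle as out_rule_stats JSON):
#
#     model = AMRModel(
#         oracle_stats,
#         embedding_dim=100,
#         action_embedding_dim=50,
#         char_embedding_dim=30,
#         hidden_dim=256,
#         char_hidden_dim=50,
#         rnn_layers=2,
#         dropout_ratio=0.3,
#         use_gpu=torch.cuda.is_available()
#     )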