示例#1
0
def add_default_entry(lexicon):
    """
    Adds a default entry (DEFAULT_ENTRY) to the lexicon
    The default entry has a single subcat (DEFAULT_SUBCAT) that unifies the argument info of every complement
    in every subcat of the lexicon. PP1/PP2 are unified under PP, and the control suffixes
    (-OC, -SC, -POC, -NPC, -VC) are removed from the complement types.
    :param lexicon: a dictionary of lexicon entries ({word: entry}), updated in place
    :return: None
    """

    def union(old_values, new_values):
        # Order-preserving union without duplicates.
        # list(set(...)) gives an order that depends on string hashing, which changes between runs
        return list(dict.fromkeys(list(old_values) + list(new_values)))

    # Argument properties that are unified as they are
    merged_properties = [
        ARG_PREFIXES, ARG_POSITIONS, ARG_ROOT_UPOSTAGS, ARG_ROOT_URELATIONS,
        ARG_ROOT_PATTERNS
    ]

    # Constraints that are excluded from the default entry
    ignored_constraints = [
        ARG_CONSTRAINT_PLURAL, ARG_CONSTRAINT_N_N_MOD_NO_OTHER_OBJ,
        ARG_CONSTRAINT_DET_POSS_NO_OTHER_OBJ
    ]

    default_entry = {ENT_VERB_SUBC: {DEFAULT_SUBCAT: {}}}
    default_subcat = default_entry[ENT_VERB_SUBC][DEFAULT_SUBCAT]

    for word_entry in lexicon.values():
        for subcat in word_entry[ENT_VERB_SUBC].values():
            all_complements = difference_list(subcat.keys(), [
                SUBCAT_OPTIONAL, SUBCAT_REQUIRED, SUBCAT_NOT,
                SUBCAT_CONSTRAINTS
            ])

            for complement_type in all_complements:
                complement_info = subcat[complement_type]
                clean_complement_type = re.sub("PP1|PP2", 'PP',
                                               complement_type)
                clean_complement_type = re.sub("-OC|-SC|-POC|-NPC|-VC", '',
                                               clean_complement_type)

                if clean_complement_type not in default_subcat:
                    default_subcat[clean_complement_type] = defaultdict(list)

                default_arg = default_subcat[clean_complement_type]

                for property_name in merged_properties:
                    default_arg[property_name] = union(
                        default_arg[property_name],
                        complement_info.get(property_name, []))

                constraints = difference_list(
                    complement_info.get(ARG_CONSTRAINTS, []),
                    ignored_constraints)
                default_arg[ARG_CONSTRAINTS] = union(
                    default_arg[ARG_CONSTRAINTS], constraints)

    # Every complement of the default subcat is optional
    default_subcat[SUBCAT_OPTIONAL] = list(default_subcat.keys())
    default_entry[ENT_ORTH] = DEFAULT_ENTRY
    lexicon[DEFAULT_ENTRY] = default_entry
    def _is_in_not(self, complement_types: list, matched_positions: list):
        """
		Checks whether the given match between arguments and positions appear in the NOT constraints
		:param complement_types: a list of complement types
		:param matched_positions: a list of the corresponding matched positions for the complements
		:return: True if this match don't appear in the NOT constraints of this subcat, and False otherwise
		"""

        if self.not_constraints == []:
            return False

        for not_constraint in self.not_constraints:
            # The NOT constraint isn't relevant it doesn't intersect with the complement types
            if difference_list(complement_types,
                               not_constraint.keys()) == complement_types:
                continue

            # Check for any violation with the current NOT constraint
            found_violation = False
            for complement_type, matched_position in zip(
                    complement_types, matched_positions):
                if complement_type in not_constraint.keys() and \
                  matched_position not in not_constraint[complement_type]:
                    found_violation = True
                    break

            if found_violation:
                return False

        return True
    def __init__(self, subcat_info: dict, subcat_type, is_verb):
        """
        Builds the subcategorization from its lexicon info
        :param subcat_info: the subcategorization info; every key other than the required, optional, NOT and
                            constraints keys is a complement type
        :param subcat_type: the type of the subcategorization
        :param is_verb: whether the subcat belongs to a verb (otherwise- a nominalization)
        """
        self.subcat_type = subcat_type

        special_keys = [
            SUBCAT_REQUIRED, SUBCAT_OPTIONAL, SUBCAT_NOT, SUBCAT_CONSTRAINTS
        ]

        # Each complement type is represented by its own LexicalArgument
        self.arguments = defaultdict(LexicalArgument)
        for arg_type in difference_list(subcat_info.keys(), special_keys):
            self.arguments[arg_type] = LexicalArgument(subcat_info[arg_type],
                                                       arg_type, is_verb)

        self.requires = subcat_info.get(SUBCAT_REQUIRED, [])
        # Every complement that isn't required is considered optional
        self.optionals = difference_list(self.arguments.keys(), self.requires)
        self.not_constraints = subcat_info.get(SUBCAT_NOT, [])
        self.constraints = subcat_info.get(SUBCAT_CONSTRAINTS, [])
        self.is_verb = is_verb
示例#4
0
def change_types(subcat, types_dict):
    """
    Translates the types of the complements according to the types dictionary
    Both the complement info (the subcat key) and the required/optional lists are renamed
    :param subcat: a dictionary of the subcategorization info (updated in place)
    :param types_dict: a dictionary of types ({old_type: new_type}); a new type of IGNORE_COMP drops the complement info
    :return: None
    """

    # Moving over all the needed to be changed arguments
    for complement_type, new_complement_type in types_dict.items():
        # Track the complement currently being processed in the shared curr_specs dictionary
        curr_specs["comp"] = complement_type

        if complement_type == new_complement_type:
            continue

        # If the argument appears in the subcat info, change its name
        # NOTE(review): types are translated one at a time in types_dict order, so a swap ({A: B, B: A})
        # copies A's info into B and then copies it back into A, losing B's original info - confirm swaps never occur
        if complement_type in subcat.keys():
            if new_complement_type != IGNORE_COMP:
                subcat[new_complement_type] = deepcopy(subcat[complement_type])

                # Replacing the complement type on the required list
                # (skipped when the old type is itself a target type of the translation, so it must stay listed)
                if complement_type in difference_list(subcat[SUBCAT_REQUIRED],
                                                      types_dict.values()):
                    subcat[SUBCAT_REQUIRED].remove(complement_type)

                    # Don't list the new type as both required and optional
                    if new_complement_type not in subcat[SUBCAT_OPTIONAL]:
                        subcat[SUBCAT_REQUIRED].append(new_complement_type)

                # Replacing the complement type on the optionals list
                if complement_type in difference_list(subcat[SUBCAT_OPTIONAL],
                                                      types_dict.values()):
                    subcat[SUBCAT_OPTIONAL].remove(complement_type)

                    if new_complement_type not in subcat[SUBCAT_REQUIRED]:
                        subcat[SUBCAT_OPTIONAL].append(new_complement_type)

            # NOTE(review): for IGNORE_COMP the info is deleted here, but the old name is left in the
            # required/optional lists - confirm it is removed elsewhere
            del subcat[complement_type]

    # Remove duplicates (the resulting order is arbitrary)
    subcat[SUBCAT_REQUIRED] = list(set(subcat[SUBCAT_REQUIRED]))
    subcat[SUBCAT_OPTIONAL] = list(set(subcat[SUBCAT_OPTIONAL]))
    curr_specs["comp"] = None
示例#5
0
def add_extensions(subcat, is_verb=False):
    """
    Extends the options of certain complements with constant words that follow their prefixes
    For example, P-WH-S can get "whether" after a preposition (like "of")
    :param subcat: a dictionary of the subcategorization info (updated in place)
    :param is_verb: whether or not the given subcat is for verb rearranging (otherwise- nominalization)
    :return: None
    """

    def extend(option, additions):
        # An option that is already one of the additions is kept as is
        if option in additions:
            return [option]
        return [option + " " + addition for addition in additions]

    def additions_for(complement_type):
        # None means that the complement doesn't get any extension
        if complement_type in [COMP_WH_S, COMP_P_WH_S]:
            return WH_VERB_OPTIONS if is_verb else WH_NOM_OPTIONS
        if complement_type == COMP_WHERE_WHEN_S:
            return WHERE_WHEN_OPTIONS
        if complement_type == COMP_HOW_S:
            return HOW_OPTIONS
        if complement_type == COMP_HOW_TO_INF:
            return HOW_TO_OPTIONS
        return None

    complement_types = difference_list(
        subcat.keys(),
        [SUBCAT_CONSTRAINTS, SUBCAT_REQUIRED, SUBCAT_OPTIONAL, SUBCAT_NOT])

    for complement_type in complement_types:
        curr_specs["comp"] = complement_type

        # Only complements that are represented as a list of options (excluding NOT for example)
        if type(subcat[complement_type]) != list:
            continue

        additions = additions_for(complement_type)
        extended_options = []
        for option in subcat[complement_type]:
            if additions is None:
                extended_options += [option]
            else:
                extended_options += extend(option, additions)

        subcat[complement_type] = extended_options

    curr_specs["comp"] = None
示例#6
0
    def _update_unused_candidates(self,
                                  token_candidates: list,
                                  predicate_token: Token,
                                  used_tokens: list,
                                  extraction: dict,
                                  specify_none=False,
                                  trim_arguments=True):
        """
        Adds the noun candidates that weren't used by the extraction as COMP_NONE arguments
        Does nothing unless specify_none is True
        :param token_candidates: the candidate arguments (each provides get_token())
        :param predicate_token: the token of the predicate (never added)
        :param used_tokens: the candidates that are already used in the extraction
        :param extraction: the extraction dictionary ({complement type: list of spans}), updated in place
        :param specify_none: whether to add the NONE complement at all
        :param trim_arguments: passed on as trim_argument to ExtractedArgument.as_span
        :return: None
        """
        if not specify_none:
            return

        extraction[COMP_NONE] = []

        # The prefixes of the default entry's NP/PP arguments are treated as prepositions
        default_arguments = self.entries[DEFAULT_ENTRY].subcats[
            DEFAULT_SUBCAT].arguments
        prepositions = list(
            itertools.chain.from_iterable(
                default_arguments[arg_type].prefixes
                for arg_type in [COMP_PP, COMP_IND_OBJ, COMP_SUBJ, COMP_OBJ]))

        if self.is_verb:
            relevant_links = [
                URELATION_NSUBJ, URELATION_IOBJ, URELATION_DOBJ,
                URELATION_NMOD_POSS, URELATION_NSUBJPASS, URELATION_NMOD
            ]
        else:
            relevant_links = [URELATION_NMOD, URELATION_COMPOUND, URELATION_ACL]

        for candidate in difference_list(token_candidates, used_tokens):
            token = candidate.get_token()

            # Only nouns with a relevant dependency link (excluding the predicate itself)
            if token.dep_ not in relevant_links or token.i == predicate_token.i:
                continue
            if not token.pos_.startswith("N"):
                continue

            # NMOD and ACL candidates must start with one of the prepositions
            # NOTE(review): a prefix such as "of" also matches "offer ..."; the appended space hints that a
            # word-boundary check was intended - confirm the format of the stored prefixes
            if token.dep_ in [URELATION_NMOD, URELATION_ACL]:
                text = token._.subtree_text + " "
                if not any(text.startswith(prefix) for prefix in prepositions):
                    continue

            none_argument = ExtractedArgument(token, COMP_NONE)
            extraction[COMP_NONE].append(
                none_argument.as_span(trim_argument=trim_arguments))
    def _check_if_no_other_object(self, complement_types: list,
                                  matched_positions: list):
        """
		Checks the compatibility of a single object argument of nominalizations (not relevant to verbs)
		This function has any meaning only when there is signle object argument
		:param complement_types: a list of complement types
		:param matched_positions: a list of the matched positions for the complements
		:return: True if a single object (if any) gets an appropriate position
		"""

        if self.is_verb:
            return True

        object_args = []

        # Get all the "objects" that are arguments of the nominalization
        for object_candidate in [COMP_OBJ, COMP_IND_OBJ, COMP_SUBJ]:
            if object_candidate in complement_types:
                object_args.append(object_candidate)

        if len(object_args) != 1:
            return True

        # Only one argument was founded
        complement_index = complement_types.index(object_args[0])
        complement_type = complement_types[complement_index]
        matched_position = matched_positions[complement_index]

        # Check if there isn't any other argument that should get that position when it is the only object
        # Meaning- this argument cannot get this position when it is the only object argument
        for other_complement_type in difference_list(self.arguments.keys(),
                                                     [complement_type]):
            if matched_position == POS_DET_POSS:
                if self.arguments[other_complement_type].is_det_poss_only():
                    return False

            elif matched_position == POS_N_N_MOD:
                if self.arguments[other_complement_type].is_n_n_mod_only():
                    return False

        return True
    def predict(self,
                dependency_tree,
                candidate_start,
                candidate_end,
                predicate_idx,
                suitable_verb,
                tagset_type,
                limited_types=None):
        """
        Predicts the log-distribution over the tagset types for a single candidate argument
        :param dependency_tree: the parsed sentence (an iterable of tokens with an orth_ attribute)
        :param candidate_start: start of the candidate (presumably a token index - confirm)
        :param candidate_end: end of the candidate (presumably a token index - confirm)
        :param predicate_idx: the index of the predicate
        :param suitable_verb: passed on to the model's encoder
        :param tagset_type: passed on to the model's encoder
        :param limited_types: the only types that may be predicted (None means no limitation)
        :return: a 1-D tensor of log-probabilities indexed by the tagset ids (excluded types get probability 0)
        """
        words = [token.orth_ for token in dependency_tree]

        # The last item returned by the encoder isn't a feature of the model
        *features, _ = self.model.encode(words,
                                         candidate_start,
                                         candidate_end,
                                         predicate_idx,
                                         suitable_verb,
                                         tagset_type,
                                         all_sizes=True)

        with torch.no_grad():
            scores = self.model(features).view(-1)

        if limited_types is not None:
            # Avoid impossible predictions, by giving them a -inf score
            impossible_types = difference_list(self.tagset, limited_types)
            scores[[self.tagset[arg_type]
                    for arg_type in impossible_types]] = -np.inf

        return F.log_softmax(scores, dim=0)
示例#9
0
def sanity_checks(lexicon, is_verb=False):
    """
    Validates the given lexicon and collects statistics about it
    Raises an Exception when a requires/optionals list has duplicates, when they intersect, when a complement
    appears in neither list, when a required complement has no specification, when the manual controlled
    constraints disagree with the automatic ones, or when a complement has no position.
    Statistics are accumulated in collections defined outside this function (entries_by_type,
    complements_per_subcat, noms_with_missing_positions, argument_properties, arguments_collisions).
    :param lexicon: a dictionary of lexicon entries ({word: entry})
    :param is_verb: whether or not the given lexicon is of verbs (otherwise- nominalizations)
    :return: None
    """
    # curr_specs tracks the current word/subcat/complement; get_current_specs() reports it in the error messages
    curr_specs["is_verb"] = is_verb

    for word in lexicon.keys():
        word_entry = lexicon[word]
        curr_specs["word"] = word_entry[ENT_ORTH]

        # is_known presumably validates (or records) values against the known values - confirm
        if not is_verb:
            is_known(without_part(word_entry[ENT_NOM_TYPE][TYPE_OF_NOM]),
                     ["NOM_TYPE"], "NOM-TYPE")

        # Group the subentry names by the type of their (non-None) value
        for subentry in lexicon[word].keys():
            if lexicon[word][subentry] is not None:
                entries_by_type[type(lexicon[word][subentry])].update(
                    [subentry])

        for subcat_type, subcat in lexicon[word][ENT_VERB_SUBC].items():
            curr_specs["subcat"] = subcat_type
            optionals = subcat[SUBCAT_OPTIONAL]
            requires = subcat[SUBCAT_REQUIRED]

            for constraint in subcat[SUBCAT_CONSTRAINTS]:
                is_known(constraint, ["SUBCAT_CONSTRAINT"],
                         "SUBCAT COMPLEMENTS & CONSTRAINTS")

            if len(set(requires)) != len(requires):
                print(requires)
                raise Exception(
                    f"Requires list isn't unique ({get_current_specs()}).")

            if len(set(optionals)) != len(optionals):
                print(optionals)
                raise Exception(
                    f"Optionals list isn't unique ({get_current_specs()}).")

            # Check that the requires and the optionals lists aren't intersecting
            if set(difference_list(optionals, requires)) != set(optionals):
                print(requires)
                print(optionals)
                raise Exception(
                    f"Requires and optionals are intersecting ({get_current_specs()})."
                )

            all_complements = difference_list(subcat.keys(), [
                SUBCAT_OPTIONAL, SUBCAT_REQUIRED, SUBCAT_NOT,
                SUBCAT_CONSTRAINTS
            ])
            complements_per_subcat[subcat_type].update(all_complements)

            # Every specified complement must be either required or optional
            if not (set(optionals + requires) >= set(all_complements)):
                print(set(optionals + requires))
                print(set(all_complements))
                raise Exception(
                    f"Some complements don't appear in required or optional ({get_current_specs()})."
                )

            # Check that all the required are specified in the subcategorization
            if difference_list(requires, subcat.keys()) != []:
                raise Exception(
                    f"There is a required argument without a specification ({get_current_specs()})."
                )

            # The positions and prefixes of each complement, used below to find collisions between complements
            positions_per_complement = defaultdict(list)

            for complement_type in all_complements:
                curr_specs["comp"] = complement_type

                positions_per_complement[
                    complement_type] += subcat[complement_type][
                        ARG_POSITIONS] + subcat[complement_type][ARG_PREFIXES]

                complement_info = subcat[complement_type]
                for constraint in complement_info[ARG_CONSTRAINTS]:
                    is_known(constraint, ["ARG_CONSTRAINT"], "ARG CONSTRAINTS")

                # Record words with a NO-OTHER-OBJ constraint that lack the position the constraint refers to
                if (ARG_CONSTRAINT_DET_POSS_NO_OTHER_OBJ in complement_info[ARG_CONSTRAINTS] and POS_DET_POSS not in complement_info[ARG_POSITIONS]) or \
                   (ARG_CONSTRAINT_N_N_MOD_NO_OTHER_OBJ in complement_info[ARG_CONSTRAINTS] and POS_N_N_MOD not in complement_info[ARG_POSITIONS]):
                    noms_with_missing_positions.append(word)

                argument_properties[ARG_POSITIONS].update(
                    complement_info[ARG_POSITIONS])
                argument_properties[ARG_PREFIXES].update(
                    complement_info[ARG_PREFIXES])
                argument_properties[ARG_ILLEGAL_PREFIXES].update(
                    complement_info.get(ARG_ILLEGAL_PREFIXES, []))

            # Record pairs of different complements that share a position/prefix (sharing only POS_PREFIX is fine)
            # NOTE(review): both (A, B) and (B, A) are visited and map to the same sorted key, so each
            # collision appends the word twice - confirm whether the counts are expected to be doubled
            for complement_type, positions in positions_per_complement.items():
                for other_complement_type, other_positions in positions_per_complement.items(
                ):
                    pos_intersection = set(positions).intersection(
                        other_positions)

                    if complement_type != other_complement_type and len(
                            pos_intersection) != 0 and pos_intersection != {
                                POS_PREFIX
                            }:
                        collided_args = sorted(
                            list({complement_type, other_complement_type}))
                        arguments_collisions[tuple(collided_args)].append(word)

            curr_specs["comp"] = None
            # The extra argument constraints of this subcat type (presumably per verb/nom - confirm get_right_value)
            more_argument_constraints = get_right_value(
                argument_constraints, subcat_type, {}, is_verb)

            for complement_type in more_argument_constraints.keys():
                curr_specs["comp"] = complement_type
                if complement_type not in subcat.keys():
                    continue

                auto_controlled = []

                # Automatic constraints
                # The expected controlled complements, derived from the control suffix of the complement type
                if complement_type.endswith("-POC"):
                    auto_controlled = [COMP_PP]
                elif complement_type.endswith("-NPC"):
                    auto_controlled = [COMP_NP]
                elif complement_type.endswith("-OC"):
                    auto_controlled = [COMP_OBJ]
                elif complement_type.endswith("-SC"):
                    auto_controlled = [COMP_SUBJ]
                elif complement_type.endswith("-VC"):
                    auto_controlled = [COMP_SUBJ, COMP_OBJ]

                    if subcat_type == "NOM-P-NP-TO-INF-VC":
                        auto_controlled = [COMP_SUBJ, COMP_PP]

                # Assure that the manual constraints were added correctly
                if set(auto_controlled) != set(subcat[complement_type].get(
                        ARG_CONTROLLED, [])):
                    print(subcat[complement_type].get(ARG_CONTROLLED, []))
                    print(auto_controlled)
                    raise Exception(
                        f"Manual controlled constraints do not agree with the automatic ones ({get_current_specs()})."
                    )

                if subcat[complement_type][ARG_POSITIONS] == []:
                    print(word, subcat_type, complement_type)
                    raise Exception(
                        f"There is a complement without any position ({get_current_specs()})."
                    )

            curr_specs["comp"] = None
示例#10
0
def use_nom_type(subcat, nom_type_info, is_verb=False):
    """
    Uses the type of the nominalization and specifies it for the given subcat (in the relevant complement as NOM)
    Sometimes the type of the nom specifies the position of the nom for the verb, and this info should also be included for the verb only
    Only the first complement (by the order of nom_types_to_args_dict) that matches the subcat is changed.
    :param subcat: a dictionary of the subcategorization info (updated in place)
    :param nom_type_info: the type of the nominalization (as a dictionary)
    :param is_verb: whether or not the given subcat is for verb rearranging (otherwise- nominalization)
    :return: None
    """

    # Get the type of complements that appropriate to the given type of nominalization
    type_of_nom = without_part(nom_type_info[TYPE_OF_NOM])
    complement_types = nom_types_to_args_dict.get(type_of_nom, [])

    changed = False

    # Search for the first appropriate complement that the subcat includes
    for complement_type in complement_types:
        curr_specs["comp"] = complement_type

        # For verbs, a relevant complement also gets the position of the nom as a new possible position
        # (the ALTERNATES constraint is checked among the subcat's keys)
        if is_verb:
            if complement_type in subcat[SUBCAT_REQUIRED] + subcat[
                    SUBCAT_OPTIONAL] and SUBCAT_CONSTRAINT_ALTERNATES not in subcat:
                if nom_type_info[TYPE_PP] != []:
                    subcat[type_of_nom] = list(
                        set(
                            subcat.get(complement_type, []) +
                            nom_type_info[TYPE_PP]))

                    # The complement is renamed to the type of the nom, in the same list (required or optional)
                    if type_of_nom != complement_type:
                        subcat.pop(complement_type, None)

                        if complement_type in subcat[SUBCAT_REQUIRED]:
                            subcat[SUBCAT_REQUIRED] = list(
                                set(subcat[SUBCAT_REQUIRED] + [type_of_nom]))
                        else:
                            subcat[SUBCAT_OPTIONAL] = list(
                                set(subcat[SUBCAT_OPTIONAL] + [type_of_nom]))

                # NOTE(review): changed is set even when nom_type_info[TYPE_PP] is empty; then, if
                # complement_type != type_of_nom, the code below removes complement_type from the
                # required/optional lists while its info stays in subcat - confirm this is intended
                changed = True

        # For noms, the only position of the complement is NOM
        # The complement should appear in the required or optional lists
        elif complement_type in list(
                subcat.keys()) + subcat[SUBCAT_REQUIRED] + subcat[
                    SUBCAT_OPTIONAL]:  # or complement_type == COMP_INSTRUMENT
            changed = True

            # Instead of the founded relevant complement, we will write the type of nom as a new complement
            subcat.pop(complement_type, None)
            subcat[type_of_nom] = [POS_NOM]
            subcat[SUBCAT_REQUIRED] = list(
                set(subcat[SUBCAT_REQUIRED] + [type_of_nom])
            )  # NOM must be required for the nominalization
            subcat[SUBCAT_OPTIONAL] = difference_list(subcat[SUBCAT_OPTIONAL],
                                                      [type_of_nom])

        if changed:
            # Remove the old complement type from both required and optional lists
            # The founded type can be different than the searched on only for IND-OBJ
            if complement_type != type_of_nom:
                subcat[SUBCAT_REQUIRED] = list(
                    set(
                        difference_list(subcat[SUBCAT_REQUIRED],
                                        [complement_type])))
                subcat[SUBCAT_OPTIONAL] = list(
                    set(
                        difference_list(subcat[SUBCAT_OPTIONAL],
                                        [complement_type])))

            # If we replaced PP1 with IND-OBJ, then PP2 should actually mean the complement PP
            # NOTE(review): subcat.pop(COMP_PP2) raises KeyError when the subcat has no PP2 info
            if complement_type == COMP_PP1:
                subcat[COMP_PP] = difference_list(
                    subcat.pop(COMP_PP2),
                    [POS_NOM])  # PP2 cannot also be the NOM
                if COMP_PP2 in subcat[SUBCAT_REQUIRED]:
                    subcat[SUBCAT_REQUIRED] = difference_list(
                        subcat[SUBCAT_REQUIRED], [COMP_PP2]) + [COMP_PP]
                else:
                    subcat[SUBCAT_OPTIONAL] = difference_list(
                        subcat[SUBCAT_OPTIONAL], [COMP_PP2]) + [COMP_PP]

            break

    curr_specs["comp"] = None
示例#11
0
def rearrange_requires_and_optionals(subcat,
                                     subcat_type,
                                     default_requires,
                                     other_subcat_types,
                                     is_verb=False):
    """
    Rearranges the requires and optionals for the given subcat entry
    On input, subcat[SUBCAT_REQUIRED] (if any) is a dictionary ({complement: {constraint: ...}});
    on output, both subcat[SUBCAT_REQUIRED] and subcat[SUBCAT_OPTIONAL] are lists of complement types (in arbitrary order)
    :param subcat: a dictionary of the subcategorization info (updated in place)
    :param subcat_type: the type of the subcategorization (for determining whether the object is required)
    :param default_requires: list of arguments and constraints that are required for the given subcat
    :param other_subcat_types: list of the other subcat types in the current lexicon entry
    :param is_verb: whether or not the given subcat is for verb rearranging (otherwise- nominalization)
    :return: None
    """

    requires = []
    optionals = []

    if is_verb and COMP_PART in subcat:
        requires.append(COMP_PART)

    # Updating the list of required complements
    for complement_type in subcat.get(SUBCAT_REQUIRED, {}).keys():
        curr_specs["comp"] = complement_type

        # The required arguments are the ones with no constraints in the required subentry
        if list(subcat[SUBCAT_REQUIRED][complement_type].keys()) == []:
            requires.append(complement_type)
        else:
            # Otherwise, the arguments are required under one of the constraints: DET-POSS-ONLY or N-N-MOD-ONLY
            # Specify those constraints for the subcategorization (Relevant for the NOMLEX-plus only)
            # NOTE(review): assumes subcat already has these constraint keys with list values (KeyError otherwise)

            if "DET-POSS-ONLY" in list(
                    subcat[SUBCAT_REQUIRED][complement_type].keys()):
                subcat[ARG_CONSTRAINT_DET_POSS_NO_OTHER_OBJ] += [
                    complement_type
                ]

            if "N-N-MOD-ONLY" in list(
                    subcat[SUBCAT_REQUIRED][complement_type].keys()):
                subcat[ARG_CONSTRAINT_N_N_MOD_NO_OTHER_OBJ] += [
                    complement_type
                ]

    # Adding complements with possible optional positon to the optional list
    # The keys are taken from a copy, since complements may be deleted from subcat inside the loop
    # (every other key is visited too, including SUBCAT_REQUIRED, whose value is a dictionary)
    tmp_subcat = deepcopy(subcat)
    for complement_type in difference_list(tmp_subcat.keys(), [
            ARG_CONSTRAINT_DET_POSS_NO_OTHER_OBJ,
            ARG_CONSTRAINT_N_N_MOD_NO_OTHER_OBJ
    ]):
        curr_specs["comp"] = complement_type

        if OPT_POS in subcat[complement_type]:
            optionals += [complement_type]
            subcat[complement_type].remove(OPT_POS)

            # Assumption- OPTIONAL-POSITION value (meaning NONE\*NONE*) is preferable to the information in the requires list
            if complement_type in requires:
                requires.remove(complement_type)

        # Delete the complement if it has no possible options
        if subcat[complement_type] == []:
            del subcat[complement_type]

    curr_specs["comp"] = None

    # All the non-optional constraints in the default requires list are also required
    requires += difference_list(default_requires, optionals)

    # OBJECT is optional for NOM-NP-X subcats, only if NOM-X isn't compatible with the current entry, otherwise it is required
    if without_part(subcat_type).startswith("NOM-NP"):
        # Object is required in the next cases
        obj_is_required = False
        if subcat_type == "NOM-NP":
            if not {"NOM-INTRANS", "NOM-INTRANS-RECIP"
                    }.isdisjoint(other_subcat_types):
                obj_is_required = True
        elif subcat_type == "NOM-NP-AS-NP-SC":
            if "NOM-AS-NP" in other_subcat_types:
                obj_is_required = True
        else:
            # The same subcat type without the NP (NOM-NP-X => NOM-X, NOM-PART-NP-X => NOM-PART-X)
            subcat_without_np = subcat_type.replace("NOM-PART-NP", "NOM-PART")
            subcat_without_np = subcat_without_np.replace("NOM-NP-", "NOM-")
            if subcat_without_np in other_subcat_types:
                obj_is_required = True

        if obj_is_required:
            requires.append(COMP_OBJ)
            optionals = difference_list(optionals, [COMP_OBJ])
        elif COMP_OBJ not in requires:
            optionals.append(COMP_OBJ)

    # SUBJECT is optional by default
    if COMP_SUBJ not in requires:
        optionals.append(COMP_SUBJ)

    # Remove duplicates (the resulting order is arbitrary)
    subcat[SUBCAT_REQUIRED] = list(set(requires))
    subcat[SUBCAT_OPTIONAL] = list(set(optionals))
    def determine_args_type(self,
                            candidates_args,
                            predicate: ExtractedArgument,
                            verb,
                            default_subcat=False):
        """
        Determines the most appropriate type of each candidate argument, using the types model
        "Uncertain" types are the model's tagset types (plus PP1 and PP2 when PP is in the tagset).
        Candidates that may get an uncertain type are assigned jointly, such that each uncertain type is given
        to at most one candidate, choosing the assignment with the highest sum of log-probabilities
        (candidates that don't get a type are scored as NONE).
        :param candidates_args: a dictionary of candidate spans and their possible types ({span: types})
        :param predicate: the predicate of the candidates
        :param verb: passed on to get_types_distribution (presumably the verb that suits the predicate - confirm)
        :param default_subcat: passed on to get_types_distribution
        :return: a dictionary of the determined types of every candidate ({span: types})
        """

        uncertain_types = list(self.tagset.keys()) + (
            [COMP_PP1, COMP_PP2] if COMP_PP in self.tagset else [])
        predicate_token = predicate.get_token()
        uncertain_candidates = {}
        determined_dict = {}

        def fill_undetermined():
            # Candidates without a determined type keep only their non-uncertain types
            for arg in difference_list(candidates_args.keys(),
                                       determined_dict.keys()):
                determined_dict[arg] = difference_list(candidates_args[arg],
                                                       uncertain_types)

        # Each candidate should take one appropriate type, determined by the model
        for candidate_span, role_types in candidates_args.items():
            role_types = set(role_types)

            # The predicate itself, and candidates without uncertain types, keep all their types
            if predicate_token.i == candidate_span[
                    0].i or role_types.isdisjoint(uncertain_types):
                determined_dict[candidate_span] = role_types
                continue

            # Pronouns also keep all their types
            if candidate_span.lemma_ in [
                    "i", "he", "she", "it", "they", "we", "-PRON-"
            ]:
                determined_dict[candidate_span] = role_types
                continue

            role_types.add(COMP_NONE)
            logits = self.get_types_distribution(candidate_span, role_types,
                                                 predicate_token, verb,
                                                 default_subcat)

            # Candidates predicted as NONE don't take part in the assignment of the uncertain types
            if logits.argmax().item() != self.tagset[COMP_NONE]:
                uncertain_candidates[candidate_span] = logits

        # The NONE candidates must also get their non-uncertain types here, exactly as in the general case below
        if len(uncertain_candidates) == 0:
            fill_undetermined()
            return determined_dict

        # Pad the candidates with empty slots, so that a type may also remain unassigned
        slots = list(uncertain_candidates.keys())
        slots += [None] * (len(uncertain_types) - 2)

        role_types = difference_list(uncertain_types, [COMP_NONE])

        # Each combination assigns role_types[i] to comb[i] (None = unassigned)
        types_combinations = list(permutations(slots, len(role_types)))
        empty_comb = tuple([None] * len(role_types))
        if empty_comb not in types_combinations:
            types_combinations.append(empty_comb)

        def combination_score(comb):
            # Assigned candidates are scored by their assigned type, the others by NONE
            assigned = sum([
                uncertain_candidates[arg][self.tagset[role_types[i]]].item()
                for i, arg in enumerate(comb) if arg
            ])
            unassigned = sum([
                uncertain_candidates[arg][self.tagset[COMP_NONE]].item()
                for arg in set(slots).difference(comb) if arg
            ])
            return assigned + unassigned

        args_sum_logits = [combination_score(comb) for comb in types_combinations]
        best = types_combinations[int(np.argmax(args_sum_logits))]

        determined_dict.update(
            {arg: [role_types[i]]
             for i, arg in enumerate(best) if arg})

        fill_undetermined()
        return determined_dict