Ejemplo n.º 1
0
def tasks_and_vocab_from_params(params: Params, serialization_dir: str) -> Tuple[List[Task], Dictionary]:
  """
  Build every task declared in ``params`` and fit a shared dictionary over their data.

  Every top-level key starting with ``task_`` is popped from ``params``. Its value must hold a
  ``task_description`` (passed to ``Task.from_params``) and ``data_params`` (passed to
  ``Task.setup_data``). The vocabulary instances returned by each task are chained together
  and handed to ``Dictionary.from_params`` along with the optional ``dictionary`` entry.

  Parameters
  ----------
  params: Params
      The experiment parameters. The ``task_*`` entries and ``dictionary`` are popped,
      i.e. ``params`` is mutated in place.
  serialization_dir: str
      Output directory of the experiment. Currently unused (saving the dictionary is not
      implemented yet); kept for interface compatibility.

  Returns
  -------
  Tuple[List[Task], Dictionary]
      The constructed tasks, in the order their keys appear in ``params``, and the dictionary.
  """
  task_list = []
  instances_for_vocab_creation = itertools.chain()
  datasets_for_vocab_creation = {}
  # Snapshot the matching keys first: the loop below pops them from ``params``.
  task_keys = [key for key in params.keys() if key.startswith("task_")]

  for key in task_keys:
    LOGGER.info("Creating task '{}'".format(key))
    task_params = params.pop(key)
    task_description = task_params.pop("task_description")
    task_data_params = task_params.pop("data_params")

    task = Task.from_params(params=task_description)
    task_list.append(task)

    task_instances_for_vocab, task_datasets_for_vocab = task.setup_data(params=task_data_params)
    # Chain lazily so no task's instances are materialised before the dictionary is fitted.
    instances_for_vocab_creation = itertools.chain(instances_for_vocab_creation, task_instances_for_vocab)
    datasets_for_vocab_creation[task.name] = task_datasets_for_vocab

  # Report which datasets of each task feed the dictionary (sorted for a stable log line).
  for task_name, task_dataset_list in datasets_for_vocab_creation.items():
    LOGGER.info("creating dictionary for '{}' from '{}'".format(task_name, ', '.join(sorted(task_dataset_list))))

  LOGGER.info('fitting dictionary from dataset')
  vocab = Dictionary.from_params(params.pop("dictionary", {}), instances_for_vocab_creation)

  # TODO: save the dictionary to ``serialization_dir`` (not implemented yet).

  return task_list, vocab
Ejemplo n.º 2
0
def prepare_environment(params: Params):
  """
  Seed the Python, NumPy and PyTorch random number generators from ``params``.

  This may not work as expected if PyTorch was already imported by the surrounding project
  before this is called. Even when it is called early, complete determinism is hard to reach
  with libraries that do optimized, massively parallel linear algebra, and GPUs make this worse.

  Parameters
  ----------
  params: Params, required.
      A ``Params`` object holding the json parameters. The keys ``random_seed`` (default 13370),
      ``numpy_seed`` (default 1337) and ``pytorch_seed`` (default 133) are popped from it.
      A ``None`` value leaves the corresponding generator unseeded.
  """
  python_seed = params.pop_int("random_seed", 13370)
  np_seed = params.pop_int("numpy_seed", 1337)
  pt_seed = params.pop_int("pytorch_seed", 133)

  if python_seed is not None:
    random.seed(python_seed)

  if np_seed is not None:
    numpy.random.seed(np_seed)

  if pt_seed is None:
    return

  torch.manual_seed(pt_seed)
  # Give every visible GPU the same seed when CUDA is present.
  if torch.cuda.is_available():
    torch.cuda.manual_seed_all(pt_seed)
Ejemplo n.º 3
0
def create_kwargs(cls: Type[T], params: Params, **extras) -> Dict[str, Any]:
    """
    Build the keyword arguments for ``cls``'s constructor from ``params`` and ``extras``.

    Every constructor parameter except ``self`` is resolved by :func:`construct_arg`, using the
    parameter's type annotation (with ``Optional`` removed by ``remove_optional``) and its
    default value. Values supplied in ``extras`` are used as is; an existing ``Vocabulary``
    can be passed this way, for instance.

    When all parameters have been consumed, ``params.assert_empty(cls.__name__)`` is called so
    that leftover configuration keys are not silently ignored.
    """
    constructor_params = inspect.signature(cls.__init__).parameters
    kwargs: Dict[str, Any] = {}

    for arg_name, arg in constructor_params.items():
        # The bound instance is assumed to be called "self". A differently named first
        # parameter would be treated like any other constructor argument.
        if arg_name == "self":
            continue
        arg_type = remove_optional(arg.annotation)
        kwargs[arg_name] = construct_arg(cls, arg_name, arg_type, arg.default,
                                         params, **extras)

    params.assert_empty(cls.__name__)
    return kwargs
Ejemplo n.º 4
0
def datasets_from_params(
  params: Params
) -> Dict[str, Iterable[Instance]]:
  """
  Read the train, validation and test datasets described by ``params``.

  ``dataset_reader`` and ``train_data_path`` are required. ``validation_data_path`` and
  ``test_data_path`` are optional. All four keys are popped from ``params``.

  Returns a dict with the key ``"train"``, plus ``"validation"`` and ``"test"`` when their
  paths are given. Each key maps to the result of ``DatasetReader.read`` on that path.
  """
  reader = DatasetReader.from_params(params.pop("dataset_reader"))

  train_path = params.pop("train_data_path")
  LOGGER.info(f"reading training data from path '{train_path}'")
  datasets: Dict[str, Iterable[Instance]] = {"train": reader.read(train_path)}

  # The remaining splits are optional and are read with the same reader.
  validation_path = params.pop("validation_data_path", None)
  if validation_path is not None:
    LOGGER.info(f"reading validation data from path '{validation_path}'")
    datasets["validation"] = reader.read(validation_path)

  test_path = params.pop("test_data_path", None)
  if test_path is not None:
    LOGGER.info(f"reading test data from path '{test_path}'")
    datasets["test"] = reader.read(test_path)

  return datasets
Ejemplo n.º 5
0
    def from_params(cls, params: Params, instances: Iterable['adi.Instance'] = None):
        """
        Construct a dictionary from ``params`` and/or ``instances``.

        Three modes are supported:

        * ``type`` given: construction is delegated to the registered subclass of that name.
        * ``directory_path`` given: the dictionary is loaded with ``cls.from_files``. When
          ``extend`` is true, it is then extended with ``instances``.
        * otherwise: a new dictionary is built from ``instances`` with ``cls.from_instances``.

        Parameters
        ----------
        params : Params
            Dictionary configuration. Consumed keys are popped from it.
        instances : Iterable[adi.Instance], optional
            Instances to build the dictionary from, or to extend a loaded one with.

        Raises
        ------
        ConfigurationError
            If neither ``directory_path`` nor ``instances`` is given, or if ``extend`` is set
            without ``instances`` or without ``directory_path``.
        """
        dictionary_type = params.pop("type", None)
        if dictionary_type is not None:
            return cls.by_name(dictionary_type).from_params(params=params, instances=instances)

        # Whether a dictionary loaded from ``directory_path`` should be extended with ``instances``.
        extend = params.pop("extend", False)
        dictionary_path = params.pop("directory_path", None)

        # NOTE: ``not instances`` only detects ``None`` or empty containers. A lazy iterable
        # (e.g. an ``itertools.chain``) is always truthy, even if it yields nothing.
        if not dictionary_path and not instances:
            raise ConfigurationError("you must either provide a directory_path inside the parameters or a dataset to build a dictionary from")

        if extend and not instances:
            raise ConfigurationError("'extend' is activated, but there are no instances to pass through")
        if extend and not dictionary_path:
            raise ConfigurationError("'extend' is activated, but there is no 'directory_path' to extend from.")

        if dictionary_path:
            if instances:
                if extend:
                    LOGGER.info("loading the dictionary from files and extending it with a dataset.")
                else:
                    LOGGER.info("loading the dictionary from files instead of a dataset")

            vocab = cls.from_files(dictionary_path)
            if extend:
                vocab.extend_from_instances(params, instances=instances)
            return vocab

        # No dictionary path was given (so ``extend`` is necessarily off): build the
        # vocabulary from the dataset.
        min_count = params.pop("min_count", None, keep_as_dict=True)
        max_vocab_size = pop_max_vocab_size(params)
        non_padded_namespaces = params.pop("non_padded_namespaces", DEFAULT_NON_PADDED_NAMESPACES)
        pretrained_files = params.pop("pretrained_files", {}, keep_as_dict=True)
        min_pretrained_embeddings = params.pop("min_pretrained_embeddings", None)
        only_include_pretrained_words = params.pop_bool("only_include_pretrained_words", False)
        tokens_to_add = params.pop("tokens_to_add", None)

        return cls.from_instances(
            instances=instances,
            min_count=min_count,
            max_vocab_size=max_vocab_size,
            non_padded_namespaces=non_padded_namespaces,
            pretrained_files=pretrained_files,
            only_include_pretrained_words=only_include_pretrained_words,
            tokens_to_add=tokens_to_add,
            min_pretrained_embeddings=min_pretrained_embeddings
        )
Ejemplo n.º 6
0
    def from_params(cls, model: BaseFairseqModel, task_list: List[Task],
                    serialization_dir: str, params: Params) -> 'Trainer':
        """
        Build a trainer for ``model`` and ``task_list`` from ``params``.

        Scalar hyper-parameters are popped with these defaults: ``cuda_device`` -1,
        ``grad_clipping`` 0.1, ``grad_norm`` 5.0, ``min_lr`` 1e-7, ``num_epochs`` 100 and
        ``patience`` 5. The ``optimizer`` sub-config is given to ``Optimizer.from_params``
        together with the ``(name, parameter)`` pairs of the model that require gradients.
        """
        # A dict display evaluates its entries left to right, so the pop order is unchanged.
        scalar_kwargs = {
            "cuda_device": params.pop_int("cuda_device", -1),
            "grad_clipping": params.pop_float("grad_clipping", 0.1),
            "grad_norm": params.pop_float("grad_norm", 5.0),
            "min_lr": params.pop_float("min_lr", 1e-7),
            "num_epochs": params.pop_int("num_epochs", 100),
            "patience": params.pop_int("patience", 5),
        }

        optimizer_config = params.pop("optimizer", None)
        trainable = [
            (param_name, tensor)
            for param_name, tensor in model.named_parameters()
            if tensor.requires_grad
        ]
        optimizer = Optimizer.from_params(
            model_parameters=trainable, params=optimizer_config)

        return cls(model=model,
                   task_list=task_list,
                   serialization_dir=serialization_dir,
                   optimizer=optimizer,
                   **scalar_kwargs)
Ejemplo n.º 7
0
def main():
  """
  Command line entry point: build tasks, dictionary, model and trainer from a parameter file, then train.

  Command line arguments
  ----------------------
  --serialization_dir : directory for the experiment outputs (``config.json`` etc.). Required
      unless running inside the polyaxon cluster, where it is replaced by ``get_outputs_path()``.
  --params : path to the parameter file describing the tasks to train (required).
  --seed : seed for numpy and PyTorch (default 1).
  --recover : flag for recovering a previous experiment (parsed, but not used below yet).
  """
  parser = argparse.ArgumentParser()
  parser.add_argument('--serialization_dir', type=str,
                      help='The directory where to save trained models, etc.')
  # Nothing can be built without a parameter file, so argparse rejects a missing one up-front
  # instead of letting ``Params.from_file(None)`` fail later.
  parser.add_argument('--params', type=str, required=True,
                      help='path to the parameter file describing the tasks to train.')
  parser.add_argument('--seed', type=int, default=1,
                      help='The random seed to use for the initialization of PyTorch and numpy.')
  parser.add_argument('--recover', action='store_true',
                      help='Recover from a previous experiment?')
  args = parser.parse_args()

  # Import user defined modules
  utils.import_user_module(args)

  # Inside the polyaxon cluster the outputs go to the platform-provided directory.
  if IN_CLUSTER:
    args.serialization_dir = get_outputs_path()
  if args.serialization_dir is None:
    parser.error('--serialization_dir is required when not running inside the cluster')

  # Seed numpy and PyTorch (Python's ``random`` module is not seeded here).
  np.random.seed(args.seed)
  torch.manual_seed(args.seed)

  # Read the parameter file
  params = Params.from_file(args.params)
  serialization_dir = args.serialization_dir

  # Create the serialization directory
  create_serialization_dir(serialization_dir)

  # Save a copy of the full configuration before the builders below pop entries from ``params``.
  with open(os.path.join(serialization_dir, 'config.json'), 'w', encoding='utf-8') as fout:
    json.dump(deepcopy(params).as_dict(quiet=True), fout, indent=2)

  # Builds every ``task_*`` entry and the shared dictionary, popping those keys from ``params``.
  tasks, vocab = tasks_and_vocab_from_params(params=params, serialization_dir=serialization_dir)

  # TODO: load the data iterator for all tasks (not implemented here yet).

  # Create the model
  model_params = params.pop("model")
  model = BaseFairseqModel.from_params(vocab=vocab, params=model_params)

  LOGGER.info("created model")
  print("created model: {}".format(model))

  # Finally, create an instance of the required trainer
  trainer_params = params.pop("trainer")
  # TODO(naetherm): Dependent on the trainer type ...
  trainer = BaseTrainer.from_params(model=model, task_list=tasks, serialization_dir=serialization_dir, params=trainer_params)

  # Everything is set up, start the training
  train(trainer)
Ejemplo n.º 8
0
    def from_params(cls, model: BaseFairseqModel, task_list: List[Task],
                    serialization_dir: str,
                    params: Params) -> 'MultiTaskTrainer':
        """
        Construct a multi task trainer from the description in ``params``.

        ``type`` is popped (it must be one of ``cls.list_available()``). ``model``,
        ``task_list``, ``serialization_dir`` and the remaining ``params`` are then forwarded
        to that registered subclass's own ``from_params``.
        """
        trainer_type = params.pop_choice("type", cls.list_available())
        trainer_cls = cls.by_name(trainer_type)
        return trainer_cls.from_params(model=model,
                                       task_list=task_list,
                                       serialization_dir=serialization_dir,
                                       params=params)
Ejemplo n.º 9
0
def pop_max_vocab_size(params: Params) -> Union[int, Dict[str, int]]:
    """
    Pop ``max_vocab_size`` from ``params`` and normalise its type.

    ``max_vocab_size`` limits the size of the vocabulary, not counting the @@UNKNOWN@@ token.
    It may be absent, a per-namespace ``Dict[str, int]``, an int, or a string holding an int
    (e.g. after environment variable substitution).

    Returns ``None`` when the key is absent or null, the dict unchanged in the per-namespace
    case, and ``int(value)`` otherwise.
    """
    raw_size = params.pop("max_vocab_size", None, keep_as_dict=True)
    if raw_size is None:
        return None
    if isinstance(raw_size, dict):
        return raw_size
    return int(raw_size)
Ejemplo n.º 10
0
    def from_params(cls, params: Params) -> 'Task':
        """
        Create a task instance from parameters.

        Recognised keys (all optional):

        * ``task_name`` (default "ensec")
        * ``validation_metric_name`` (default None)
        * ``validation_metric_decreases`` (default False)
        * ``evaluate_on_test`` (default False)

        Any other key left in ``params`` is rejected by ``params.assert_empty``.
        """
        name = params.pop("task_name", "ensec")
        metric_name = params.pop("validation_metric_name", None)
        metric_decreases = params.pop_bool("validation_metric_decreases", False)
        test_eval = params.pop_bool("evaluate_on_test", False)

        # Reject unknown configuration keys instead of silently ignoring them.
        params.assert_empty(cls.__name__)

        return cls(
            name=name,
            validation_metric_name=metric_name,
            validation_metric_decreases=metric_decreases,
            evaluate_on_test=test_eval,
        )
Ejemplo n.º 11
0
    def extend_from_instances(
        self,
        params: Params,
        instances: Iterable['adi.Instance'] = ()
    ) -> None:
        """
        Extend the already existing dictionary with the tokens counted in ``instances``.

        The vocabulary options are popped from ``params`` exactly as ``from_params`` pops them:
        ``min_count``, ``max_vocab_size``, ``non_padded_namespaces``, ``pretrained_files``,
        ``min_pretrained_embeddings``, ``only_include_pretrained_words`` and ``tokens_to_add``.
        They are passed to ``self.extend`` together with the per-namespace token counts.

        Parameters
        ----------
        params : Params
            Vocabulary options. Consumed keys are popped from it.
        instances : Iterable[adi.Instance]
            Instances whose ``count_vocab_items`` fill the token counter.
        """
        # ``keep_as_dict=True`` and ``pop_max_vocab_size`` match ``from_params``, so mapping-valued
        # options reach ``self.extend`` as plain dicts rather than ``Params`` objects.
        min_count_ = params.pop("min_count", None, keep_as_dict=True)
        max_vocab_size_ = pop_max_vocab_size(params)
        non_padded_namespaces = params.pop("non_padded_namespaces", DEFAULT_NON_PADDED_NAMESPACES)
        pretrained_files = params.pop("pretrained_files", {}, keep_as_dict=True)
        min_pretrained_embeddings = params.pop("min_pretrained_embeddings", None)
        only_include_pretrained_words = params.pop_bool("only_include_pretrained_words", False)
        tokens_to_add = params.pop("tokens_to_add", None)

        LOGGER.info("Fitting token dictionary from dataset")

        # namespace -> token -> count
        namespace_token_counts: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(int))
        for instance in Tqdm.tqdm(instances):
            instance.count_vocab_items(namespace_token_counts)

        self.extend(
            counter=namespace_token_counts,
            min_count=min_count_,
            max_vocab_size=max_vocab_size_,
            non_padded_namespaces=non_padded_namespaces,
            pretrained_files=pretrained_files,
            only_include_pretrained_words=only_include_pretrained_words,
            tokens_to_add=tokens_to_add,
            min_pretrained_embeddings=min_pretrained_embeddings
        )
Ejemplo n.º 12
0
    def setup_data(self, params: Params):
        """
        Read this task's datasets from ``params`` and record them on the task.

        The datasets are read with ``datasets_from_params``, which yields the keys "train" and,
        when configured, "validation" and "test". For every split present, ``<split>_data`` and
        ``<split>_instances`` (the instance count) are set on ``self``.

        Returns
        -------
        tuple
            ``(instances_for_vocab_creation, datasets_for_vocab_creation)``. The first element
            is currently always an empty tuple (not implemented yet). The second is the set of
            dataset names selected for dictionary creation (default: every dataset read).

        Raises
        ------
        ConfigurationError
            If ``datasets_for_vocab_creation`` names an unknown dataset, or if
            ``evaluate_on_test`` is set but no test data was read.
        """
        all_datasets = datasets_from_params(params)
        # The default is the dict itself; ``set()`` over a dict collects its keys.
        datasets_for_vocab_creation = set(
            params.pop("datasets_for_vocab_creation", all_datasets))

        LOGGER.info(
            f"datasets_for_vocab_creation: {datasets_for_vocab_creation}")

        for dataset in datasets_for_vocab_creation:
            if dataset not in all_datasets:
                raise ConfigurationError(
                    f"the dataset {dataset} is not known in 'all_datasets'")

        # TODO(naetherm): Implement me!
        instances_for_vocab_creation = ()

        self.instances_for_vocab_creation = instances_for_vocab_creation
        self.datasets_for_vocab_creation = datasets_for_vocab_creation

        # Counting iterates each split once in full.
        if "train" in all_datasets:
            self.train_data = all_datasets["train"]
            self.train_instances = sum(1 for _ in self.train_data)
        if "validation" in all_datasets:
            self.validation_data = all_datasets["validation"]
            self.validation_instances = sum(1 for _ in self.validation_data)
        if "test" in all_datasets:
            self.test_data = all_datasets["test"]
            self.test_instances = sum(1 for _ in self.test_data)

        # Evaluating on the test split requires test data. An explicit error is used instead
        # of ``assert``, which ``python -O`` strips. ``getattr`` also covers the case where
        # ``test_data`` was never assigned.
        if self.evaluate_on_test and getattr(self, "test_data", None) is None:
            raise ConfigurationError(
                "'evaluate_on_test' is activated, but no test data was given ('test_data_path').")

        return self.instances_for_vocab_creation, self.datasets_for_vocab_creation
Ejemplo n.º 13
0
    def from_params(cls: Type[T], params: Params, **extras) -> T:
        """
        Default ``from_params`` shared by every ``FromParams`` / ``Registrable`` subclass.

        Resolution order:

        1. ``params`` of ``None`` yields ``None``. A plain string is treated as ``{"type": <string>}``.
        2. If ``cls`` has registered subclasses, ``type`` is popped (falling back to the first
           choice when ``cls`` declares a default implementation). Construction is delegated to
           the chosen subclass: through its own ``from_params`` when it has one, otherwise by
           calling it with ``params`` and the matching ``extras`` as keyword arguments.
        3. Otherwise ``cls`` itself is called with the kwargs built by ``create_kwargs``, or with
           no arguments when it does not define its own ``__init__``.

        Override this when a class needs more than "pop each constructor argument by name".
        """
        # pylint: disable=protected-access
        from fairseq.common.registrable import Registrable  # import here to avoid circular imports

        logger.info(
            f"instantiating class {cls} from params {getattr(params, 'params', params)} "
            f"and extras {set(extras.keys())}")

        if params is None:
            return None

        if isinstance(params, str):
            params = Params({"type": params})

        subclasses = Registrable._registry.get(cls)

        if subclasses is None:
            # Not a registrable base class: build ``cls`` directly.
            if cls.__init__ == object.__init__:
                # Without this, create_kwargs would inspect object.__init__ and look for its
                # *args / **kwargs, so a constructor-less class gets no kwargs at all.
                kwargs: Dict[str, Any] = {}
            else:
                kwargs = create_kwargs(cls, params, **extras)
            return cls(**kwargs)  # type: ignore

        # ``cls`` is a Registrable base class; the cast only informs mypy.
        # pylint: disable=no-member
        registrable_cls = cast(Type[Registrable], cls)
        use_default_choice = registrable_cls.default_implementation is not None
        chosen_name = params.pop_choice(
            "type",
            choices=registrable_cls.list_available(),
            default_to_first_choice=use_default_choice)
        chosen_cls = subclasses[chosen_name]
        chosen_extras = create_extras(chosen_cls, extras)

        if hasattr(chosen_cls, 'from_params'):
            return chosen_cls.from_params(params=params, **chosen_extras)

        # Rare case: a registered subclass without ``from_params`` (e.g. PyTorch modules that
        # are registered directly, such as Activations). Trust that ``params`` already holds
        # plain constructor arguments with nothing to construct recursively.
        return chosen_cls(**{**params, **chosen_extras})
Ejemplo n.º 14
0
def construct_arg(
        cls: Type[T],  # pylint: disable=inconsistent-return-statements,too-many-return-statements
        param_name: str,
        annotation: Type,
        default: Any,
        params: Params,
        **extras) -> Any:
    """
    Does the work of actually constructing an individual argument for :func:`create_kwargs`.

    Here we're in the inner loop of iterating over the parameters to a particular constructor,
    trying to construct just one of them.  The information we get for that parameter is its name,
    its type annotation, and its default value; we also get the full set of ``Params`` for
    constructing the object (which we may mutate), and any ``extras`` that the constructor might
    need.

    We take the type annotation and default value here separately, instead of using an
    ``inspect.Parameter`` object directly, so that we can handle ``Union`` types using recursion on
    this method, trying the different annotation types in the union in turn.

    Resolution order (the first matching case wins):

    1. ``name`` is in ``extras``: that value is returned unchanged.
    2. ``params[name]`` is a ``Params`` containing ``_pretrained``: the module at ``module_path``
       is extracted from the archive at ``archive_file`` (``freeze`` defaults to True).
    3. ``annotation`` has ``from_params``: the sub-config is built recursively. A string value
       is looked up with ``annotation.by_name`` and instantiated with no arguments.
    4. ``str`` / ``int`` / ``bool`` / ``float``: popped with the matching ``pop*`` helper.
    5. ``Dict[K, X]`` / ``List[X]`` / fixed-length ``Tuple[X, Y]`` / ``Set[X]`` whose element
       types have ``from_params``: every element is built recursively.
    6. ``Union[...]``: each member type is tried in order.
    7. Anything else: the raw value is popped and returned as is.

    Returns
    -------
    The constructed value for ``param_name``. ``params`` is mutated: consumed keys are popped.

    Raises
    ------
    ConfigurationError
        If a required ``from_params``-constructible argument is missing, if a pretrained module
        has an unexpected type, or if no member of a ``Union`` could be constructed. Required
        primitive arguments that are missing fail inside the ``params.pop*`` calls.
    """
    from allennlp.models.archival import load_archive  # import here to avoid circular imports

    # We used `param_name` as the method argument to avoid conflicts with 'name' being a key in
    # `extras`, which isn't _that_ unlikely.  Now that we are inside the method, we can switch back
    # to using `name`.
    name = param_name
    # For typing generics such as Dict[str, int]: origin is the generic, args the type arguments.
    origin = getattr(annotation, '__origin__', None)
    args = getattr(annotation, '__args__', [])

    # The parameter is optional if its default value is not the "no default" sentinel.
    # NOTE(review): sentinels are normally compared with ``is not``; ``!=`` invokes the default's
    # own ``__ne__`` (an array-like default would compare element-wise) — confirm this is safe.
    optional = default != _NO_DEFAULT

    # Some constructors expect extra non-parameter items, e.g. vocab: Vocabulary.
    # We check the provided `extras` for these and just use them if they exist.
    if name in extras:
        return extras[name]
    # Next case is when argument should be loaded from pretrained archive.
    elif name in params and isinstance(
            params.get(name), Params) and "_pretrained" in params.get(name):
        load_module_params = params.pop(name).pop("_pretrained")
        archive_file = load_module_params.pop("archive_file")
        module_path = load_module_params.pop("module_path")
        freeze = load_module_params.pop("freeze", True)
        archive = load_archive(archive_file)
        result = archive.extract_module(module_path, freeze)  # pylint: disable=no-member
        if not isinstance(result, annotation):
            raise ConfigurationError(
                f"The module from model at {archive_file} at path {module_path} "
                f"was expected of type {annotation} but is of type {type(result)}"
            )
        return result
    # The next case is when the parameter type is itself constructible from_params.
    elif hasattr(annotation, 'from_params'):
        if name in params:
            # Our params have an entry for this, so we use that.
            subparams = params.pop(name)

            # Keep only the extras that the annotated class can use.
            subextras = create_extras(annotation, extras)

            # In some cases we allow a string instead of a param dict, so
            # we need to handle that case separately.
            if isinstance(subparams, str):
                return annotation.by_name(subparams)()
            else:
                return annotation.from_params(params=subparams, **subextras)
        elif not optional:
            # Not optional and not supplied, that's an error!
            raise ConfigurationError(f"expected key {name} for {cls.__name__}")
        else:
            return default

    # If the parameter type is a Python primitive, just pop it off
    # using the correct casting pop_xyz operation.
    elif annotation == str:
        return params.pop(name, default) if optional else params.pop(name)
    elif annotation == int:
        return params.pop_int(name,
                              default) if optional else params.pop_int(name)
    elif annotation == bool:
        return params.pop_bool(name,
                               default) if optional else params.pop_bool(name)
    elif annotation == float:
        return params.pop_float(
            name, default) if optional else params.pop_float(name)

    # This is special logic for handling types like Dict[str, TokenIndexer],
    # List[TokenIndexer], Tuple[TokenIndexer, Tokenizer], and Set[TokenIndexer],
    # which it creates by instantiating each value from_params and returning the resulting structure.
    # Note that these containers ignore ``default``: a missing key yields an empty container.
    elif origin in (Dict, dict) and len(args) == 2 and hasattr(
            args[-1], 'from_params'):
        value_cls = annotation.__args__[-1]

        value_dict = {}

        for key, value_params in params.pop(name, Params({})).items():
            subextras = create_extras(value_cls, extras)
            value_dict[key] = value_cls.from_params(params=value_params,
                                                    **subextras)

        return value_dict

    elif origin in (List, list) and len(args) == 1 and hasattr(
            args[0], 'from_params'):
        value_cls = annotation.__args__[0]

        value_list = []

        for value_params in params.pop(name, Params({})):
            subextras = create_extras(value_cls, extras)
            value_list.append(
                value_cls.from_params(params=value_params, **subextras))

        return value_list

    # Fixed-length tuples only: ``Tuple[X, ...]`` fails the check because Ellipsis has no
    # ``from_params``. NOTE(review): with an empty ``args`` the ``all(...)`` test is vacuously
    # true — confirm a bare ``Tuple`` annotation cannot reach this branch.
    elif origin in (Tuple, tuple) and all(
            hasattr(arg, 'from_params') for arg in args):
        value_list = []

        # zip pairs each declared element type with the configured value at the same position.
        for value_cls, value_params in zip(annotation.__args__,
                                           params.pop(name, Params({}))):
            subextras = create_extras(value_cls, extras)
            value_list.append(
                value_cls.from_params(params=value_params, **subextras))

        return tuple(value_list)

    elif origin in (Set, set) and len(args) == 1 and hasattr(
            args[0], 'from_params'):
        value_cls = annotation.__args__[0]

        value_set = set()

        for value_params in params.pop(name, Params({})):
            subextras = create_extras(value_cls, extras)
            value_set.add(
                value_cls.from_params(params=value_params, **subextras))

        return value_set

    elif origin == Union:
        # Storing this so we can recover it later if we need to.
        param_value = params.get(name, Params({}))
        if isinstance(param_value, Params):
            param_value = param_value.duplicate()

        # We'll try each of the given types in the union sequentially, returning the first one that
        # succeeds.
        for arg in args:
            try:
                return construct_arg(cls, name, arg, default, params, **extras)
            except (ValueError, TypeError, ConfigurationError, AttributeError):
                # Our attempt to construct the argument may have popped `params[name]`, so we
                # restore it here.
                # NOTE(review): if `name` was absent, this inserts an empty ``Params`` under it.
                params[name] = param_value
                # Take a fresh copy so the next attempt cannot mutate the stored original.
                if isinstance(param_value, Params):
                    param_value = param_value.duplicate()
                continue

        # If none of them succeeded, we crash.
        raise ConfigurationError(
            f"Failed to construct argument {name} with type {annotation}")
    else:
        # Pass it on as is and hope for the best.   ¯\_(ツ)_/¯
        if optional:
            return params.pop(name, default)
        else:
            return params.pop(name)
Ejemplo n.º 15
0
    def from_params(cls, model_parameters: List, params: Params):
        """
        Build an optimizer over ``model_parameters`` from ``params``.

        ``params`` is either the registered name of an optimizer (a string) or a ``Params``
        object whose ``type`` key names it. An optional ``parameter_groups`` entry assigns
        parameters to groups with their own options, for example::

            "parameter_groups": [
                [["regex1", "regex2"], {"lr": 1e-3}],
                [["regex3"], {"lr": 1e-4}]
            ]

        A parameter joins the group whose regexes match its name (``re.search``). Parameters
        matched by no regex go to a trailing default group that uses only the top-level options.
        Note that the config files require double quotes; single quotes may fail, sometimes
        silently. See https://pytorch.org/docs/0.3.0/optim.html?#per-parameter-options

        Parameters
        ----------
        model_parameters : List
            ``(name, parameter)`` pairs, e.g. from ``model.named_parameters()``.
        params : Params or str
            Optimizer configuration, or just the optimizer's registered name.

        Raises
        ------
        ValueError
            If a parameter name matches regexes of two different groups.
        """
        if isinstance(params, str):
            optimizer = params
            params = Params({})
        else:
            optimizer = params.pop_choice("type", Optimizer.list_available())

        # make the parameter groups if needed
        groups = params.pop("parameter_groups", None)
        if groups:
            # Typed as Any since the dict values other than 'params' are passed to the optimizer
            # constructor and can be any type it accepts. One extra, last entry holds the
            # default group for parameters that no group regex matches.
            parameter_groups: Any = [{'params': []} for _ in range(len(groups) + 1)]
            # Add the group-specific options; zip stops before the trailing default group.
            for group, group_spec in zip(parameter_groups, groups):
                group.update(group_spec[1].as_dict())

            regex_use_counts: Dict[str, int] = {}
            parameter_group_names: List[set] = [set() for _ in range(len(groups) + 1)]
            for name, param in model_parameters:
                # Determine the group for this parameter.
                group_index = None
                for k, group_regexes in enumerate(groups):
                    for regex in group_regexes[0]:
                        regex_use_counts.setdefault(regex, 0)
                        if re.search(regex, name):
                            if group_index is not None and group_index != k:
                                raise ValueError(
                                    "{} was specified in two separate parameter groups"
                                    .format(name))
                            group_index = k
                            regex_use_counts[regex] += 1

                if group_index is None:
                    # No regex matched: use the trailing default group.
                    group_index = -1
                parameter_groups[group_index]['params'].append(param)
                parameter_group_names[group_index].add(name)

            # log the parameter groups
            LOGGER.info("Done constructing parameter groups.")
            for k, (group, group_names) in enumerate(zip(parameter_groups, parameter_group_names)):
                group_options = {key: val for key, val in group.items() if key != 'params'}
                LOGGER.info("Group %s: %s, %s", k, list(group_names), group_options)
            # check for unused regex
            for regex, count in regex_use_counts.items():
                if count == 0:
                    LOGGER.warning(
                        "When constructing parameter groups, %s does not match any parameter name",
                        regex)

        else:
            parameter_groups = [param for _, param in model_parameters]

        # Log the number of parameters to optimize. Entries are group dicts or bare tensors.
        num_parameters = 0
        for parameter_group in parameter_groups:
            if isinstance(parameter_group, dict):
                num_parameters += sum(
                    parameter.numel() for parameter in parameter_group["params"])
            else:
                num_parameters += parameter_group.numel()
        LOGGER.info("Number of trainable parameters: %s", num_parameters)

        # By default we cast things that e.g. look like floats to floats before handing them
        # to the Optimizer constructor, but if you want to disable that behavior you could add a
        #       "infer_type_and_cast": false
        # key to your "trainer.optimizer" config.
        infer_type_and_cast = params.pop_bool("infer_type_and_cast", True)
        params_as_dict = params.as_dict(
            infer_type_and_cast=infer_type_and_cast)
        subclass = Optimizer.by_name(optimizer)

        # If the optimizer subclass has a from_params, use it.
        if hasattr(subclass, 'from_params'):
            return subclass.from_params(parameter_groups, params=params)
        else:
            return subclass(parameter_groups, **params_as_dict)  # type: ignore