Example #1
def unflatten(flat_dict):
    """ Unflatten flatten dict.

    Given a "flattened" dict with compound keys, e.g.
        {"a.b": 0}
    unflatten it:
        {"a": {"b": 0}}
    """
    unflat_dict = {}

    for compound_key, value in flat_dict.items():
        curr_dict = unflat_dict
        parts = compound_key.split(".")
        for key in parts[:-1]:
            curr_value = curr_dict.get(key)
            if key not in curr_dict:
                curr_dict[key] = {}
                curr_dict = curr_dict[key]
            elif isinstance(curr_value, dict):
                curr_dict = curr_value
            else:
                raise ConfigurationError("flattened dictionary is invalid")
        if not isinstance(curr_dict, dict) or parts[-1] in curr_dict:
            raise ConfigurationError("flattened dictionary is invalid")
        else:
            curr_dict[parts[-1]] = value

    return unflat_dict
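
# A minimal usage sketch (not part of the original listing), showing the round trip
# from compound keys back to a nested dict:
flat = {"a.b": 0, "a.c.d": 1, "e": 2}
nested = unflatten(flat)
assert nested == {"a": {"b": 0, "c": {"d": 1}}, "e": 2}
# A compound key that collides with an already-assigned leaf is rejected, e.g.
# unflatten({"a": 0, "a.b": 1}) raises ConfigurationError.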
Example #2
def create_serialization_dir(params, serialization_dir, reset):
    """
    This function creates the serialization directory if it doesn't exist.  If it already exists
    and is non-empty, then it verifies that we're recovering from a training with an identical
    configuration.

    Parameters:
        params (Params): A parameter object specifying an AllenNLP Experiment.
        serialization_dir (str): The directory in which to save results and logs.
        reset (bool): If ``True``, we will overwrite the serialization directory if it already
            exists.
    """

    if os.path.exists(serialization_dir) and reset:
        shutil.rmtree(serialization_dir)

    if os.path.exists(serialization_dir):
        logger.info(f"Recovering from prior training at {serialization_dir}.")

        recovered_config_file = os.path.join(serialization_dir, CONFIG_NAME)
        if not os.path.exists(recovered_config_file):
            raise ConfigurationError(
                "The serialization directory already exists but doesn't "
                f"contain a {CONFIG_NAME}. You probably gave the wrong directory."
            )

        loaded_params = Params.load(recovered_config_file)

        # Check whether any of the training configuration differs from the configuration we are
        # resuming.  If so, warn the user that training may fail.
        fail = False
        flat_params = params.as_flat_dict()
        flat_loaded = loaded_params.as_flat_dict()
        for key in flat_params.keys() - flat_loaded.keys():
            logger.error(
                f"Key '{key}' found in training configuration but not in the "
                f"serialization directory we're recovering from.")
            fail = True
        for key in flat_loaded.keys() - flat_params.keys():
            logger.error(
                f"Key '{key}' found in the serialization directory we're recovering "
                f"from but not in the training config.")
            fail = True
        for key in flat_params.keys() & flat_loaded.keys():
            if flat_params.get(key, None) != flat_loaded.get(key, None):
                logger.error(
                    f"Value for '{key}' in the training configuration does not match "
                    f"the value in the serialization directory we're recovering from: "
                    f"{flat_params[key]} != {flat_loaded[key]}")
                fail = True

        if fail:
            raise ConfigurationError(
                "Training configuration does not match the configuration we're "
                "recovering from.")

    os.makedirs(serialization_dir, exist_ok=True)
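
# Illustrative sketch of the recovery check above, using plain dicts to stand in for
# the Params.as_flat_dict() output (the compound-key shape shown here is an assumption):
current = {"trainer.num_epochs": 10, "model.type": "las"}
recovered = {"trainer.num_epochs": 20, "model.type": "las"}
only_in_current = current.keys() - recovered.keys()      # each logged as an error
only_in_recovered = recovered.keys() - current.keys()    # each logged as an error
changed = {k for k in current.keys() & recovered.keys()
           if current[k] != recovered[k]}                 # {'trainer.num_epochs'}
# Any non-empty set above makes create_serialization_dir raise ConfigurationError.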
Example #3
def check_for_gpu(device_id):
    from torch import cuda
    if device_id == 'cuda':
        device_id = 0
    elif device_id == 'cpu':
        device_id = -1

    if device_id is not None and (device_id >= 0):
        num_devices_available = cuda.device_count()
        if num_devices_available == 0:
            raise ConfigurationError(
                "Experiment specified a GPU but none are available;"
                " if you want to run on CPU use the override"
                " 'trainer.cuda_device=-1' in the json config file.")
        elif device_id >= num_devices_available:
            raise ConfigurationError(
                f"Experiment specified GPU device {device_id}"
                f" but there are only {num_devices_available} devices "
                f" available.")
Example #4
    def assert_empty(self, class_name):
        """Assert if Params is empty.
        Raises a ``ConfigurationError`` if ``self.params`` is not empty.  We take ``class_name`` as
        an argument so that the error message gives some idea of where an error happened, if there
        was one.  ``class_name`` should be the name of the `calling` class, the one that got extra
        parameters (if there are any).
        """
        if self.params:
            raise ConfigurationError(
                "Extra parameters passed to {}: {}".format(
                    class_name, self.params))

        return True
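
# Hedged call-site sketch (build_encoder is hypothetical): a constructor pops every key
# it understands, then asserts nothing is left so config typos surface immediately.
def build_encoder(params):
    hidden_size = params.pop('hidden_size')
    num_layers = params.pop('num_layers', 1)
    params.assert_empty('Encoder')  # raises ConfigurationError if unknown keys remain
    return hidden_size, num_layers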
Example #5
    def pop(self, key, default=DEFAULT):
        """
        Performs the functionality associated with dict.pop(key), but also checks for
        returned dicts and returns a ``Params`` object in their place with an updated
        history.  If ``key`` is missing and no ``default`` is given, raises a
        ``ConfigurationError`` instead of a ``KeyError``.
        """
        if default is self.DEFAULT:
            try:
                value = self.params.pop(key)
            except KeyError:
                raise ConfigurationError(
                    f"key '{key}' is required at location '{self.history}'")
        else:
            value = self.params.pop(key, default)

        if isinstance(value, str):
            value = recursively_expandvars(value, ext_vars=self.ext_vars)

        if not isinstance(value, dict):
            logger.info(self.history + key + " = " + str(value))

        return self._check_is_dict(key, value)
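
# Minimal usage sketch: pop() consumes the key from the underlying dict, so each
# parameter is read exactly once; without a default, a missing key is a config error.
#   batch_size = params.pop('batch_size', 32)   # optional, falls back to 32
#   model_params = params.pop('model')          # required; raises ConfigurationError if absent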
Example #6
    def get(self, key, default=DEFAULT):
        """
        Performs the functionality associated with dict.get(key), but also checks for returned
        dicts and returns a ``Params`` object in their place with an updated history.
        """
        if default is self.DEFAULT:
            try:
                # dict.get() never raises KeyError, so index directly to detect a
                # missing required key.
                value = self.params[key]
            except KeyError:
                raise ConfigurationError(
                    f"key '{key}' is required at location '{self.history}'")
        else:
            value = self.params.get(key, default)

        if isinstance(value, str):
            value = recursively_expandvars(value, ext_vars=self.ext_vars)

        return self._check_is_dict(key, value)
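
# Minimal usage sketch: unlike pop(), get() leaves the key in place, and a nested dict
# still comes back wrapped (via _check_is_dict) rather than as a plain dict.
#   num_epochs = params.get('num_epochs', 50)   # optional, key is not removed
#   trainer_params = params.get('trainer')      # required; raises ConfigurationError if absent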
Example #7
def datasets_from_params(params):
    """
    Load all the datasets specified by the config.
    """
    sets = {}
    for split in ['train', 'val', 'test']:
        dataset_params = params.pop(f'{split}_dataset', None)

        if dataset_params is None:
            if split == 'train':
                raise ConfigurationError('Must provide train_dataset params.')
            continue

        data_path = dataset_params.get('manifest_filepath', None)
        if data_path is not None:
            check_for_data_path(data_path, 'manifest_filepath')

        sets[split] = datasets.from_params(dataset_params)

    return sets
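
# Illustrative config sketch (key names taken from the pops above; fields other than
# 'manifest_filepath' are assumptions about the dataset params):
#   {
#     "train_dataset": {"manifest_filepath": "data/train.json"},
#     "val_dataset":   {"manifest_filepath": "data/val.json"},
#     "test_dataset":  {"manifest_filepath": "data/test.json"}
#   }
# Only 'train_dataset' is mandatory; missing 'val_dataset'/'test_dataset' splits are skipped.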
Example #8
    def pop_choice(self, key, choices):
        """
        Gets the value of ``key`` in the ``params`` dictionary, ensuring that the value is one of
        the given choices. Note that this `pops` the key from params, modifying the dictionary,
        consistent with how parameters are processed in this codebase.

        Params:
            key:
                Key to get the value from in the param dictionary
            choices:
                A list of valid options for values corresponding to ``key``.  For example,
                if you're specifying the type of encoder to use for some part of your model, the
                choices might be the list of encoder classes we know about and can instantiate.  If
                the value we find in the param dictionary is not in ``choices``, we raise a
                ``ConfigurationError``, because the user specified an invalid value in their
                parameter file.
        """
        value = self.pop(key, self.DEFAULT)
        if value not in choices:
            key_str = self.history + key
            message = f"{value} not in acceptable choices for {key_str}: {choices}"
            raise ConfigurationError(message)
        return value
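
# Minimal usage sketch (key and choices are hypothetical): the popped value must be
# one of the allowed options.
#   activation = params.pop_choice('activation', ['relu', 'tanh', 'gelu'])
#   # any other value in the config raises ConfigurationError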
Example #9
def check_for_data_path(data_path: str, dataset_name: str):
    if not os.path.exists(data_path):
        raise ConfigurationError(f"Experiment specified {dataset_name}, "
                                 f"but {data_path} doesn't exist.")
Example #10
def train_model_from_args(args):

    if args.local_rank == 0 and args.prev_output_dir is not None:
        logger.info('Copying results from {} to {}...'.format(args.prev_output_dir, args.serialization_dir))

        copy_tree(args.prev_output_dir, args.serialization_dir, update=True, verbose=True)

    if not os.path.isfile(args.param_path):
        raise ConfigurationError(f'Parameters file {args.param_path} not found.')

    logger.info(f'Loading experiment from {args.param_path} with overrides `{args.overrides}`.')

    params = Params.load(args.param_path, args.overrides)

    prepare_environment(params)

    logger.info(f'Local rank: {args.local_rank}')
    if args.local_rank == 0:
        create_serialization_dir(params, args.serialization_dir, args.reset)

    if args.distributed:
        logger.info(f'World size: {dist.get_world_size()} | Rank {dist.get_rank()} | ' f'Local Rank {args.local_rank}')
        dist.barrier()

    prepare_global_logging(args.serialization_dir, local_rank=args.local_rank, verbosity=args.verbosity)

    if args.local_rank == 0:
        params.save(os.path.join(args.serialization_dir, CONFIG_NAME))

    loaders = loaders_from_params(params,
                                  distributed=args.distributed,
                                  world_size=args.world_size,
                                  first_epoch=args.first_epoch)

    if os.path.exists(os.path.join(args.serialization_dir, "alphabet")):
        alphabet = Alphabet.from_file(os.path.join(args.serialization_dir, "alphabet"))
    else:
        alphabet = Alphabet.from_params(params.pop("alphabet", {}))

    alphabet.save_to_files(os.path.join(args.serialization_dir, "alphabet"))

    loss = losses.from_params(params.pop('loss'))
    model = models.from_params(alphabet=alphabet, params=params.pop('model'))

    trainer_params = params.pop("trainer")
    if args.fine_tune:
        _, archive_weight_file = models.load_archive(args.fine_tune)

        archive_weights = torch.load(archive_weight_file, map_location=lambda storage, loc: storage)['model']

        # Avoid initializing some weights from the archive
        no_ft_regex = trainer_params.pop("no_ft", ())

        finetune_weights = {}
        random_weights = []
        for name, parameter in archive_weights.items():
            if any(re.search(regex, name) for regex in no_ft_regex):
                random_weights.append(name)
                continue
            finetune_weights[name] = parameter

        logger.info(f'Loading the following weights from archive {args.fine_tune}:')
        logger.info(','.join(finetune_weights.keys()))
        logger.info('The following weights are left randomly initialized:')
        logger.info(','.join(random_weights))

        model.load_state_dict(finetune_weights, strict=False)

    # Freezing some parameters
    freeze_params(model, trainer_params.pop('no_grad', ()))

    trainer = Trainer(args.serialization_dir,
                      trainer_params,
                      model,
                      loss,
                      alphabet,
                      local_rank=args.local_rank,
                      world_size=args.world_size,
                      sync_bn=args.sync_bn,
                      opt_level=args.opt_level,
                      keep_batchnorm_fp32=args.keep_batchnorm_fp32,
                      loss_scale=args.loss_scale)

    try:
        trainer.run(loaders['train'], val_loader=loaders.get('val'), num_epochs=trainer_params['num_epochs'])
    except KeyboardInterrupt:
        # if we have completed an epoch, try to create a model archive.
        if os.path.exists(os.path.join(args.serialization_dir, models.DEFAULT_WEIGHTS)):
            logging.info("Training interrupted by the user. Attempting to create "
                         "a model archive using the current best epoch weights.")
        raise
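
# Hedged sketch of the argument namespace this entry point expects; the flag names come
# from the `args.<name>` attributes read above, but types and defaults are assumptions,
# not the project's actual parser.
#   parser = argparse.ArgumentParser()
#   parser.add_argument('param_path')                 # experiment config file
#   parser.add_argument('serialization_dir')          # where results and logs go
#   parser.add_argument('--overrides', default='')
#   parser.add_argument('--reset', action='store_true')
#   parser.add_argument('--fine_tune', default=None)  # archive to warm-start from
#   parser.add_argument('--local_rank', type=int, default=0)
#   parser.add_argument('--distributed', action='store_true')
#   # plus: prev_output_dir, world_size, first_epoch, verbosity, sync_bn,
#   # opt_level, keep_batchnorm_fp32 and loss_scale, all read as `args.<name>`.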