def unflatten(flat_dict):
    """
    Rebuild a nested dict from a flat dict whose keys are dot-separated paths.

    Example: ``{"a.b": 0}`` becomes ``{"a": {"b": 0}}``.

    Raises ``ConfigurationError("flattened dictionary is invalid")`` when the
    keys conflict, i.e. when one key is a dotted prefix of another
    (``{"a": 1, "a.b": 2}`` or ``{"a.b": 1, "a": 2}``).
    """
    nested = {}
    for dotted_key, leaf_value in flat_dict.items():
        *path, leaf = dotted_key.split(".")
        node = nested
        for segment in path:
            if segment not in node:
                node[segment] = {}
            elif not isinstance(node[segment], dict):
                # A leaf value already sits where we need an intermediate dict.
                raise ConfigurationError("flattened dictionary is invalid")
            node = node[segment]
        if not isinstance(node, dict) or leaf in node:
            # The leaf slot is already taken (by a value or a nested dict).
            raise ConfigurationError("flattened dictionary is invalid")
        node[leaf] = leaf_value
    return nested
def create_serialization_dir(params, serialization_dir, reset):
    """
    Create the serialization directory, or validate it when recovering.

    If ``serialization_dir`` does not exist, or exists as an empty directory,
    it is (re)created and training starts fresh. Otherwise we are recovering
    from a prior run: the directory must contain ``CONFIG_NAME``, and the
    flattened configuration stored there must match ``params`` exactly.

    Parameters:
        params (Params): A parameter object specifying an AllenNLP Experiment.
        serialization_dir (str): The directory in which to save results and logs.
        reset (bool): If ``True``, an existing serialization directory is
            deleted first, so training always starts from scratch.

    Raises:
        ConfigurationError: if recovering and the directory lacks
            ``CONFIG_NAME``, or if the stored configuration differs from
            ``params`` (missing keys on either side or differing values).
    """
    if reset and os.path.exists(serialization_dir):
        shutil.rmtree(serialization_dir)

    # An empty directory (e.g. pre-created by a launcher) is not a prior run,
    # so only a non-empty path triggers the recovery checks below.
    is_empty_dir = os.path.isdir(serialization_dir) and not os.listdir(serialization_dir)
    if os.path.exists(serialization_dir) and not is_empty_dir:
        logger.info(f"Recovering from prior training at {serialization_dir}.")
        recovered_config_file = os.path.join(serialization_dir, CONFIG_NAME)
        if not os.path.exists(recovered_config_file):
            raise ConfigurationError(
                "The serialization directory already exists but doesn't "
                f"contain a {CONFIG_NAME}. You probably gave the wrong directory."
            )
        loaded_params = Params.load(recovered_config_file)

        # Any difference between the requested configuration and the one being
        # resumed is fatal: every discrepancy is logged first, then we raise.
        flat_params = params.as_flat_dict()
        flat_loaded = loaded_params.as_flat_dict()
        fail = False
        # sorted() keeps the error log in a deterministic order.
        for key in sorted(flat_params.keys() - flat_loaded.keys()):
            logger.error(
                f"Key '{key}' found in training configuration but not in the "
                f"serialization directory we're recovering from.")
            fail = True
        for key in sorted(flat_loaded.keys() - flat_params.keys()):
            logger.error(
                f"Key '{key}' found in the serialization directory we're recovering "
                f"from but not in the training config.")
            fail = True
        # Compare only keys present on both sides: one-sided keys were reported
        # above, and indexing the missing side would raise KeyError.
        for key in sorted(flat_params.keys() & flat_loaded.keys()):
            if flat_params[key] != flat_loaded[key]:
                logger.error(
                    f"Value for '{key}' in training configuration does not match that "
                    f"the value in the serialization directory we're recovering from: "
                    f"{flat_params[key]} != {flat_loaded[key]}")
                fail = True
        if fail:
            raise ConfigurationError(
                "Training configuration does not match the configuration we're "
                "recovering from.")

    os.makedirs(serialization_dir, exist_ok=True)
def check_for_gpu(device_id):
    """
    Validate that the requested CUDA device exists on this machine.

    ``device_id`` may be ``None`` (no device requested), an int (negative
    means CPU), ``'cpu'``, ``'cuda'`` (device 0) or ``'cuda:N'``.
    CPU requests return immediately without importing torch.

    Raises:
        ConfigurationError: if a GPU is requested but none are available, if
            the requested GPU index is out of range, or if a string device id
            is not one of the forms above.
    """
    if isinstance(device_id, str):
        if device_id == 'cpu':
            device_id = -1
        elif device_id == 'cuda':
            device_id = 0
        elif device_id.startswith('cuda:') and device_id[len('cuda:'):].isdecimal():
            device_id = int(device_id[len('cuda:'):])
        else:
            raise ConfigurationError(
                f"Unrecognized device specification: {device_id!r}")

    if device_id is None or device_id < 0:
        return

    # Imported lazily so CPU-only validation does not depend on torch.
    from torch import cuda
    num_devices_available = cuda.device_count()
    if num_devices_available == 0:
        raise ConfigurationError(
            "Experiment specified a GPU but none are available;"
            " if you want to run on CPU use the override"
            " 'trainer.cuda_device=-1' in the json config file.")
    if device_id >= num_devices_available:
        raise ConfigurationError(
            f"Experiment specified GPU device {device_id}"
            f" but there are only {num_devices_available} devices"
            f" available.")
def assert_empty(self, class_name):
    """
    Fail loudly if any parameters were left unconsumed.

    ``class_name`` should name the calling class, i.e. the one that received
    the leftover parameters, so the error message shows where they came from.

    Returns ``True`` when ``self.params`` is empty (falsy); otherwise raises
    ``ConfigurationError`` listing the remaining parameters.
    """
    leftover = self.params
    if not leftover:
        return True
    raise ConfigurationError(
        "Extra parameters passed to {}: {}".format(class_name, leftover))
def pop(self, key, default=DEFAULT):
    """
    Remove ``key`` from the parameters and return its processed value.

    If ``default`` is omitted the key is mandatory: a missing key raises
    ``ConfigurationError`` naming ``self.history`` as the location. Otherwise
    a missing key yields ``default``.

    String values are expanded with ``recursively_expandvars`` using
    ``self.ext_vars``. Every non-dict value is logged as
    ``<history><key> = <value>``. The result is passed through
    ``self._check_is_dict`` before being returned.
    """
    if default is not self.DEFAULT:
        value = self.params.pop(key, default)
    else:
        try:
            value = self.params.pop(key)
        except KeyError:
            raise ConfigurationError(
                f"key '{key}' is required at location '{self.history}'")

    if isinstance(value, str):
        value = recursively_expandvars(value, ext_vars=self.ext_vars)

    if not isinstance(value, dict):
        setting_path = self.history + key
        logger.info(setting_path + " = " + str(value))

    return self._check_is_dict(key, value)
def get(self, key, default=DEFAULT):
    """
    Look up ``key`` without removing it, mirroring ``dict.get``.

    If ``default`` is omitted, a missing key yields ``None``, exactly as
    ``dict.get(key)`` does. Unlike ``pop``, ``get`` does not treat the key as
    required.

    String values are expanded with ``recursively_expandvars`` using
    ``self.ext_vars``. The result is passed through ``self._check_is_dict``,
    which (per the original contract) returns a ``Params`` object with an
    updated history in place of a dict.
    """
    # dict.get never raises KeyError, so there is no "required key" error path
    # here; an omitted default simply behaves as dict.get's default of None.
    if default is self.DEFAULT:
        default = None
    value = self.params.get(key, default)
    if isinstance(value, str):
        value = recursively_expandvars(value, ext_vars=self.ext_vars)
    return self._check_is_dict(key, value)
def datasets_from_params(params):
    """
    Load all the datasets specified by the config.

    For each split in ``train``, ``val`` and ``test``, pops ``<split>_dataset``
    from ``params``. Missing ``val``/``test`` sections are skipped. If a
    section has a ``manifest_filepath``, that path must exist.

    Returns:
        dict mapping split name to the object built by ``datasets.from_params``.

    Raises:
        ConfigurationError: if ``train_dataset`` is missing, or if a
            ``manifest_filepath`` does not exist.
    """
    sets = {}
    for split in ('train', 'val', 'test'):
        dataset_params = params.pop(f'{split}_dataset', None)
        if dataset_params is None:
            if split == 'train':
                # A training set is mandatory; the other splits are optional.
                raise ConfigurationError('Must provide train_dataset params.')
            continue

        data_path = dataset_params.get('manifest_filepath', None)
        if data_path is not None:
            check_for_data_path(data_path, 'manifest_filepath')

        sets[split] = datasets.from_params(dataset_params)

    return sets
def pop_choice(self, key, choices):
    """
    Pop a required ``key`` and check that its value is one of ``choices``.

    Like ``pop``, this removes the key from the parameters. The key is
    required (it is popped with ``self.DEFAULT``, so a missing key is an error
    raised by ``pop``).

    Params:
        key: Key to get the value from in the param dictionary.
        choices: Collection of valid values for ``key`` -- for example, the
            names of the encoder classes that can be instantiated.

    Raises ``ConfigurationError`` when the value is not in ``choices``, since
    that means the parameter file specified an invalid value.
    """
    value = self.pop(key, self.DEFAULT)
    if value in choices:
        return value
    key_str = self.history + key
    raise ConfigurationError(
        f"{value!s} not in acceptable choices for {key_str!s}: {choices!s}")
def check_for_data_path(data_path: str, dataset_name: str):
    """
    Make sure the path configured for ``dataset_name`` exists on disk.

    Returns ``None`` when ``data_path`` exists.

    Raises:
        ConfigurationError: if ``data_path`` does not exist.
    """
    if os.path.exists(data_path):
        return
    raise ConfigurationError(
        f"Experiment specified {dataset_name}, but {data_path} doesn't exist.")
def _load_finetune_weights(model, archive_path, no_ft_regex):
    """
    Initialise ``model`` from the weights stored in the archive at ``archive_path``.

    Archive entries whose name matches any regex in ``no_ft_regex`` are not
    copied, so those parameters keep the model's own initialisation. Loading
    uses ``strict=False``, so keys missing on either side are tolerated.
    """
    _, archive_weight_file = models.load_archive(archive_path)
    # NOTE(review): torch.load unpickles the file -- only fine-tune from trusted archives.
    archive_weights = torch.load(archive_weight_file,
                                 map_location=lambda storage, loc: storage)['model']

    # Compile once instead of re-parsing each regex for every weight name.
    skip_patterns = [re.compile(regex) for regex in no_ft_regex]
    finetune_weights = {}
    random_weights = []
    for name, parameter in archive_weights.items():
        if any(pattern.search(name) for pattern in skip_patterns):
            random_weights.append(name)
        else:
            finetune_weights[name] = parameter

    logger.info(f'Loading the following weights from archive {archive_path}:')
    logger.info(','.join(finetune_weights.keys()))
    logger.info('The following weights are at random:')
    logger.info(','.join(random_weights))
    model.load_state_dict(finetune_weights, strict=False)


def train_model_from_args(args):
    """
    Run a training job described by parsed command-line ``args``.

    Steps:
      1. On local rank 0, optionally copy ``args.prev_output_dir`` into
         ``args.serialization_dir``.
      2. Load the experiment ``Params`` from ``args.param_path`` with
         ``args.overrides``.
      3. On rank 0, create/validate the serialization directory and save the
         config there.
      4. Build the loaders, alphabet, loss and model.
      5. Optionally initialise the model from ``args.fine_tune``.
      6. Freeze the parameters listed under ``trainer.no_grad``.
      7. Train for ``trainer.num_epochs`` epochs.

    Raises:
        ConfigurationError: if ``args.param_path`` is not a file, or from
            ``create_serialization_dir`` on a configuration mismatch.
        KeyboardInterrupt: re-raised after logging when the user interrupts.
    """
    if args.local_rank == 0 and args.prev_output_dir is not None:
        logger.info('Copying results from {} to {}...'.format(
            args.prev_output_dir, args.serialization_dir))
        copy_tree(args.prev_output_dir, args.serialization_dir,
                  update=True, verbose=True)

    if not os.path.isfile(args.param_path):
        raise ConfigurationError(f'Parameters file {args.param_path} not found.')

    logger.info(f'Loading experiment from {args.param_path} with overrides `{args.overrides}`.')
    params = Params.load(args.param_path, args.overrides)

    prepare_environment(params)

    logger.info(args.local_rank)
    if args.local_rank == 0:
        create_serialization_dir(params, args.serialization_dir, args.reset)

    if args.distributed:
        logger.info(f'World size: {dist.get_world_size()} | Rank {dist.get_rank()} | '
                    f'Local Rank {args.local_rank}')
        # Synchronise all ranks after rank 0 has prepared the serialization directory.
        dist.barrier()

    prepare_global_logging(args.serialization_dir, local_rank=args.local_rank,
                           verbosity=args.verbosity)

    if args.local_rank == 0:
        params.save(os.path.join(args.serialization_dir, CONFIG_NAME))

    loaders = loaders_from_params(params, distributed=args.distributed,
                                  world_size=args.world_size,
                                  first_epoch=args.first_epoch)

    # Reuse an alphabet already saved in the serialization directory; otherwise
    # build it from the config and save it there.
    alphabet_path = os.path.join(args.serialization_dir, "alphabet")
    if os.path.exists(alphabet_path):
        alphabet = Alphabet.from_file(alphabet_path)
    else:
        alphabet = Alphabet.from_params(params.pop("alphabet", {}))
        # NOTE(review): this save is not restricted to local_rank 0, so every
        # rank writes it in distributed runs -- confirm that is safe.
        alphabet.save_to_files(alphabet_path)

    loss = losses.from_params(params.pop('loss'))
    model = models.from_params(alphabet=alphabet, params=params.pop('model'))

    trainer_params = params.pop("trainer")
    if args.fine_tune:
        # ``no_ft`` holds regexes of archive weights that must NOT be loaded.
        _load_finetune_weights(model, args.fine_tune,
                               trainer_params.pop("no_ft", ()))

    # Freeze the parameters selected by the ``no_grad`` trainer option.
    freeze_params(model, trainer_params.pop('no_grad', ()))

    trainer = Trainer(args.serialization_dir, trainer_params, model, loss, alphabet,
                      local_rank=args.local_rank, world_size=args.world_size,
                      sync_bn=args.sync_bn, opt_level=args.opt_level,
                      keep_batchnorm_fp32=args.keep_batchnorm_fp32,
                      loss_scale=args.loss_scale)
    try:
        trainer.run(loaders['train'], val_loader=loaders.get('val'),
                    num_epochs=trainer_params['num_epochs'])
    except KeyboardInterrupt:
        # If at least one epoch's weights were written, report it before re-raising.
        # NOTE(review): nothing here actually builds an archive; confirm whether
        # the log message below reflects archiving done elsewhere.
        if os.path.exists(os.path.join(args.serialization_dir, models.DEFAULT_WEIGHTS)):
            logger.info("Training interrupted by the user. Attempting to create "
                        "a model archive using the current best epoch weights.")
        raise