Пример #1
0
    def _load(cls,
              config: Params,
              serialization_dir: str,
              weights_file: str = None,
              device=None) -> 'Model':
        """
        Rebuild a trained model from its experiment configuration and the
        artifacts stored in ``serialization_dir``.

        Parameters
        ----------
        config : ``Params``
            Experiment configuration; its ``'model'`` entry is handed to
            ``cls.from_params``.
        serialization_dir : ``str``
            Directory holding the ``vocabulary`` sub-directory and, unless
            ``weights_file`` is given, the default weights file.
        weights_file : ``str``, optional
            Path of the weights to load.  When falsy, ``_DEFAULT_WEIGHTS``
            inside ``serialization_dir`` is used.
        device : optional
            Passed unchanged to ``model.to`` once the weights are loaded.

        Returns
        -------
        The model built by ``cls.from_params`` with the stored weights loaded
        and the vocabulary attached.
        """
        if not weights_file:
            weights_file = os.path.join(serialization_dir, _DEFAULT_WEIGHTS)

        # The vocabulary is always read back from the files saved next to the weights.
        vocabulary = Vocabulary.from_files(os.path.join(serialization_dir, 'vocabulary'))

        # The config describes how to _train_ the model, which may include a pointer to
        # pre-trained embeddings.  When _loading_, those embeddings already live in the saved
        # weights, so the pointer is stripped to keep the code from looking for that file.
        model_config = config['model']
        remove_pretrained_embedding_params(model_config)
        model = cls.from_params(vocab=vocabulary, params=model_config)

        # NOTE(review): device_mapping(-1) presumably maps storages to CPU -- confirm.
        state = torch.load(weights_file, map_location=device_mapping(-1))
        if not isinstance(model, torch.nn.DataParallel):
            # Drop a leading "module." from every key so the names line up
            # with a model that is not wrapped in DataParallel.
            unprefixed = {}
            for key, value in state.items():
                unprefixed[re.sub(r'^module\.', '', key)] = value
            state = unprefixed

        model.load_state_dict(state)
        model.set_vocab(vocabulary)
        model.to(device)
        return model
Пример #2
0
def _select_device(environment_params):
    """Return the torch device named by the environment settings.

    When ``environment_params['gpu']`` is truthy the device is
    ``cuda:<cuda_device>`` and it is handed to ``environment.occupy_gpu``;
    otherwise the CPU device is returned.
    """
    if environment_params['gpu']:
        device = torch.device('cuda:{}'.format(environment_params['cuda_device']))
        environment.occupy_gpu(device)
        return device
    return torch.device('cpu')


def _dump_json(directory, filename, payload):
    """Write ``payload`` as indented UTF-8 JSON to ``directory/filename``."""
    with open(os.path.join(directory, filename), "w", encoding='utf-8') as outfile:
        json.dump(payload, outfile, indent=4, default=serialize_sets)


def _freeze_parameters(model, no_grad_regexes):
    """Disable gradients for every parameter whose name matches one of the regexes."""
    for name, parameter in model.named_parameters():
        if any(re.search(regex, name) for regex in no_grad_regexes):
            parameter.requires_grad_(False)


def _log_parameter_summary(model):
    """Log the frozen and tunable parameter names and the tunable element count."""
    frozen_parameter_names, tunable_parameter_names = \
        environment.get_frozen_and_tunable_parameter_names(model)
    logger.info("Following parameters are Frozen  (without gradient):")
    for name in frozen_parameter_names:
        logger.info(name)
    logger.info("Following parameters are Tunable (with gradient):")
    for name in tunable_parameter_names:
        logger.info(name)

    logger.info("Total nr of parameters Tunable (with gradient):")
    logger.info(sum(p.numel() for p in model.parameters() if p.requires_grad))


def _strip_data_parallel_prefix(state_dict):
    """Return a copy of ``state_dict`` with a leading ``module.`` removed from each key."""
    return {re.sub(r'^module\.', '', key): value for key, value in state_dict.items()}


def train_model(params: Params):
    """
    Trains the model specified in the given :class:`Params` object, using the data and training
    parameters also specified in that object, and saves the results.

    Parameters
    ----------
    params : ``Params``
        A parameter object specifying an AllenNLP Experiment.  The sections read here are
        ``environment``, ``data``, ``model``, ``trainer`` and (optionally) ``vocab``.
        ``params['trainer']['device']`` is overwritten with the selected torch device.

    Returns
    -------
    best_model: ``Model``
        The model with the best epoch weights.

    Raises
    ------
    KeyboardInterrupt
        Re-raised when training is interrupted; if a weights file already exists in the
        trainer's serialization directory a model archive is created first.
    """
    # Set up the environment.
    environment_params = params['environment']
    environment.set_seed(environment_params)
    create_serialization_dir(params)
    environment.prepare_global_logging(environment_params)
    environment.check_for_gpu(environment_params)
    device = _select_device(environment_params)
    params['trainer']['device'] = device

    # Load data.
    data_params = params['data']
    model_params = params['model']
    dataset = dataset_from_params(data_params,
                                  universal_postags=model_params.get('universal_postags', False),
                                  generator_source_copy=data_params.get('source_copy', True),
                                  multilingual=model_params.get('multilingual', False),
                                  extra_check=data_params.get('extra_check', False))
    train_data = dataset['train']
    dev_data = dataset.get('dev')
    train_mappings = dataset.get('train_mappings', None)
    train_replacements = dataset.get('train_replacements', None)

    # Vocabulary and iterator are created here.
    vocab_params = params.get('vocab', {})
    if "fixed_vocab" in vocab_params and vocab_params["fixed_vocab"]:
        # The fixed vocabulary is read from a path relative to the current working directory.
        vocab = Vocabulary.from_files("data/vocabulary")
    else:
        vocab = Vocabulary.from_instances(instances=train_data, **vocab_params)

    # NOTE(review): the vocabulary is written out before the model is built.  If model
    # construction extends the vocabulary, the saved files will not contain those
    # additions -- confirm that this ordering is intended.
    output_dir = environment_params['serialization_dir']
    vocab.save_to_files(os.path.join(output_dir, "vocabulary"))
    # Only the train and dev iterators are used; the third (test) iterator is discarded.
    train_iterator, dev_iterator, _ = iterator_from_params(vocab, data_params['iterator'])

    if train_mappings is not None and train_replacements is not None:
        # Persist the lexicalization tables produced while reading the training data.
        _dump_json(output_dir, "trns_lex_missing.json", train_mappings[-1])
        _dump_json(output_dir, "trns_lexicalizations.json", train_mappings[-2])
        _dump_json(output_dir, "trns_rep.json", train_replacements)

    # Build the model.
    model = getattr(Models, model_params['model_type']).from_params(
        vocab, model_params, environment_params['gpu'], train_mappings, train_replacements)
    logger.info(model)

    # Train
    trainer_params = params['trainer']
    _freeze_parameters(model, trainer_params['no_grad'])
    _log_parameter_summary(model)

    trainer = Trainer.from_params(model, train_data, dev_data, train_iterator, dev_iterator, trainer_params)

    serialization_dir = trainer_params['serialization_dir']
    try:
        trainer.train()
    except KeyboardInterrupt:
        # if we have completed an epoch, try to create a model archive.
        if os.path.exists(os.path.join(serialization_dir, _DEFAULT_WEIGHTS)):
            logger.info("Training interrupted by the user. Attempting to create "
                         "a model archive using the current best epoch weights.")
            archive_model(serialization_dir)
        raise

    # Now tar up results
    archive_model(serialization_dir)

    logger.info("Loading the best epoch weights.")
    best_model_state_path = os.path.join(serialization_dir, 'best.th')
    # Deserialize onto the CPU: load_state_dict copies the values into the model's existing
    # parameters, so a second full copy of the weights on the training device is not needed.
    best_model_state = torch.load(best_model_state_path, map_location='cpu')
    if not isinstance(model, torch.nn.DataParallel):
        best_model_state = _strip_data_parallel_prefix(best_model_state)
    model.load_state_dict(best_model_state)

    return model