Example no. 1
0
 def run(self):
     """
     Vacuum a checkpoint: strip optimizer and LR-scheduler state so only
     the model weights and remaining metadata stay on disk.

     The original file is renamed to ``<model_file>.unvacuumed`` first so
     the operation can be undone by hand.
     """
     self.opt.log()
     model_file = self.opt['path']
     if not model_file:
         raise RuntimeError('--model-file argument is required')
     if not os.path.isfile(model_file):
         raise RuntimeError(f"'{model_file}' does not exist")
     logging.info(f"Loading {model_file}")
     with PathManager.open(model_file, 'rb') as f:
         states = torch.load(
             f, map_location=lambda cpu, _: cpu, pickle_module=parlai.utils.pickle
         )
     # Keep the original around in case the vacuum needs to be reverted.
     logging.info(f"Backing up {model_file} to {model_file}.unvacuumed")
     os.rename(model_file, model_file + ".unvacuumed")
     removable = (
         'optimizer',
         'optimizer_type',
         'lr_scheduler',
         'lr_scheduler_type',
         'warmup_scheduler',
         'number_training_updates',
     )
     for key in removable:
         if key in states:
             logging.info(f"Deleting key {key}")
             del states[key]
     keys = ", ".join(states.keys())
     logging.info(f"Remaining keys: {keys}")
     logging.info(f"Saving to {model_file}")
     atomic_save(states, model_file)
Example no. 2
0
 def _setup_cands(self):
     """
     Override for different call to model.

     Load fixed candidates from ``self.fcp`` and obtain their encodings,
     either from a cached ``<fcp>.cands_enc`` file or by encoding them in
     batches of 50 with the model's text encoder (caching the result).
     """
     self.fixed_cands = None
     self.fixed_cands_enc = None
     if self.fcp is not None:
         with PathManager.open(self.fcp) as f:
             self.fixed_cands = [c.replace("\n", "") for c in f.readlines()]
         cands_enc_file = "{}.cands_enc".format(self.fcp)
         print("loading saved cand encodings")
         if PathManager.exists(cands_enc_file):
             with PathManager.open(cands_enc_file, 'rb') as f:
                 self.fixed_cands_enc = torch.load(
                     f, map_location=lambda cpu, _: cpu)
         else:
             print("Extracting cand encodings")
             self.model.eval()
             pbar = tqdm.tqdm(
                 total=len(self.fixed_cands),
                 unit="cand",
                 unit_scale=True,
                 desc="Extracting candidate encodings",
             )
             fixed_cands_enc = []
             # BUG FIX: the previous ``range(0, len - 50, 50)`` stopped
             # early and silently dropped the last (up to 100) candidates,
             # and produced no batches at all (torch.cat of an empty list
             # crashes) when there were <= 50 candidates. Iterate over the
             # full list, including the final partial batch.
             for start in range(0, len(self.fixed_cands), 50):
                 batch = self.fixed_cands[start:start + 50]
                 embedding = self.model.forward_text_encoder(batch).detach()
                 fixed_cands_enc.append(embedding)
                 # advance by the true batch size so the bar ends at total
                 pbar.update(len(batch))
             self.fixed_cands_enc = torch.cat(fixed_cands_enc, 0)
             torch_utils.atomic_save(self.fixed_cands_enc, cands_enc_file)
Example no. 3
0
 def _setup_cands(self):
     """
     Load fixed candidates from ``self.fcp`` and obtain their encodings,
     either from a cached ``<fcp>.cands_enc`` file or by running them
     through the model in batches of 50 (caching the result).
     """
     self.fixed_cands = None
     self.fixed_cands_enc = None
     if self.fcp is not None:
         with PathManager.open(self.fcp) as f:
             self.fixed_cands = [c.replace('\n', '') for c in f.readlines()]
         cands_enc_file = '{}.cands_enc'.format(self.fcp)
         print('loading saved cand encodings')
         if PathManager.exists(cands_enc_file):
             with PathManager.open(cands_enc_file, 'rb') as f:
                 self.fixed_cands_enc = torch.load(
                     f, map_location=lambda cpu, _: cpu
                 )
         else:
             print('Extracting cand encodings')
             self.model.eval()
             pbar = tqdm.tqdm(
                 total=len(self.fixed_cands),
                 unit='cand',
                 unit_scale=True,
                 desc='Extracting candidate encodings',
             )
             fixed_cands_enc = []
             # BUG FIX: the previous ``range(0, len - 50, 50)`` stopped
             # early and silently dropped the last (up to 100) candidates,
             # and produced no batches at all (torch.cat of an empty list
             # crashes) when there were <= 50 candidates. Iterate over the
             # full list, including the final partial batch.
             for start in range(0, len(self.fixed_cands), 50):
                 batch = self.fixed_cands[start : start + 50]
                 embedding = self.model(None, None, batch)[1].detach()
                 fixed_cands_enc.append(embedding)
                 # advance by the true batch size so the bar ends at total
                 pbar.update(len(batch))
             self.fixed_cands_enc = torch.cat(fixed_cands_enc, 0)
             torch_utils.atomic_save(self.fixed_cands_enc, cands_enc_file)
Example no. 4
0
def remove_projection_matrices(model_file: str):
    """
    Remove all projection matrices used for distillation from the model and re-save it.

    A backup of the untouched checkpoint is written to ``<model_file>._orig``
    before the pruned state dict is saved back in place.
    """

    print(f'Creating a backup copy of the original model at {model_file}._orig.')
    PathManager.copy(model_file, f'{model_file}._orig')

    print(f"Loading {model_file}.")
    with PathManager.open(model_file, 'rb') as f:
        states = torch.load(f, map_location=lambda cpu, _: cpu, pickle_module=pickle)

    print('Deleting projection matrices.')
    # Drop every weight whose top-level module is one of the distillation
    # projection layers; all other entries are kept untouched.
    projection_modules = {
        'encoder_proj_layer',
        'embedding_proj_layers',
        'hidden_proj_layers',
    }
    orig_num_keys = len(states['model'])
    pruned = {}
    for key, val in states['model'].items():
        if key.split('.')[0] not in projection_modules:
            pruned[key] = val
    states['model'] = pruned
    new_num_keys = len(pruned)
    print(f'{orig_num_keys-new_num_keys:d} model keys removed.')

    print(f"Saving to {model_file}.")
    atomic_save(states, model_file)
Example no. 5
0
 def save(self, path=None):
     """
     Save dictionary tokenizer if available.

     :param path:
         destination for the model; defaults to ``opt['model_file']``.
     """
     # Fall back to the configured model file when no explicit path is given.
     if path is None:
         path = self.opt.get('model_file', None)
     if path:
         self.dictionary.save(path + '.dict')
         torch_utils.atomic_save({'opt': self.opt}, path)
         # Mirror the options as JSON alongside the model for inspection.
         with PathManager.open(path + '.opt', 'w') as handle:
             json.dump(self.opt, handle)
Example no. 6
0
 def save(self, filename):
     """
     Serialize network weights, feature dict, and config to ``filename``.

     Saving is best-effort: failures are logged and swallowed so the
     caller (e.g. a training loop) can continue.
     """
     params = {
         'state_dict': {
             'network': self.network.state_dict()
         },
         'feature_dict': self.feature_dict,
         'config': self.opt,
     }
     try:
         torch_utils.atomic_save(params, filename)
     except Exception:
         # BUG FIX: previously caught BaseException, which also swallowed
         # KeyboardInterrupt/SystemExit; those should propagate.
         logger.warning('[ WARN: Saving failed... continuing anyway. ]')
Example no. 7
0
 def save(self, path=None):
     """
     Save model parameters if model_file is set.

     :param path:
         destination for the checkpoint; defaults to ``opt['model_file']``.
     """
     if path is None:
         path = self.opt.get('model_file', None)
     # Only save when a destination exists and a model has been built.
     if path and hasattr(self, 'model'):
         checkpoint = {
             'model': self.model.state_dict(),
             'optimizer': self.optimizer.state_dict(),
             'opt': self.opt,
         }
         torch_utils.atomic_save(checkpoint, path)
         # Also dump the options as JSON next to the checkpoint.
         with PathManager.open(path + '.opt', 'w') as handle:
             json.dump(self.opt, handle)
Example no. 8
0
    def save(self, path=None):
        """
        Save the model.

        :param path:
            path for saving model; defaults to ``opt['model_file']``.
        """
        path = self.opt.get('model_file', None) if path is None else path
        if not path:
            # BUG FIX: without this guard (which sibling save() methods in
            # this codebase have), a missing model_file crashed with
            # ``TypeError: unsupported operand ... NoneType`` on path + '.dict'.
            return
        self.dict.save(path + '.dict', sort=False)
        print('Saving best model')
        states = {}
        states['model'] = self.model.state_dict()
        torch_utils.atomic_save(states, path)

        # Dump the options as newline-terminated JSON alongside the model.
        with PathManager.open(path + '.opt', 'w') as handle:
            json.dump(self.opt, handle)
            handle.write('\n')
Example no. 9
0
    def extract(self, image, path=None):
        """
        Compute an image feature with the configured CNN.

        :param image:
            input image (format expected by ``self.transform`` or the
            detectron extractor, depending on ``self.image_mode``).
        :param path:
            optional location; when given, the CPU copy of the feature is
            saved there before returning.
        :return:
            the extracted feature tensor.
        """
        if 'faster_r_cnn' in self.image_mode:
            # Detectron-style extractor operates on a batch; take the
            # single result back out.
            feature = self.netCNN.get_detectron_features([image])[0]
        else:
            transform = self.transform(image).unsqueeze(0)
            if self.use_cuda:
                transform = transform.cuda()
            with torch.no_grad():
                feature = self.netCNN(transform)
        if path is not None:
            import parlai.utils.torch as torch_utils

            torch_utils.atomic_save(feature.cpu(), path)
        return feature
        return feature