# Only `torch` is a third-party import here; the remaining helpers (LOGGER,
# tasks, utils, load_checkpoint_to_cpu, recursive_contractuser,
# recursive_expanduser) are project-internal, and their exact module paths
# depend on the codebase.
import torch


def load_state(model_path):
    """Load a checkpoint and rebuild its task and model for inference."""
    state = load_checkpoint_to_cpu(model_path, arg_overrides={})
    args = state["args"]
    args = recursive_contractuser(args)
    args = recursive_expanduser(args)
    task = tasks.setup_task(args)  # load src/tgt dicts
    model = task.build_model(args)
    model.load_state_dict(state["model"])
    use_cuda = torch.cuda.is_available() and not args['common']['cpu']
    if args['common']['fp16'] and use_cuda:
        model.half()
    if use_cuda:
        torch.cuda.empty_cache()
        # run on the last visible GPU
        torch.cuda.set_device(torch.cuda.device_count() - 1)
        model.cuda()
    model.eval()
    del state
    return args, task, model, use_cuda
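
# A minimal usage sketch for load_state(); the checkpoint path is a
# placeholder, not a file shipped with the project.
def _demo_load_state():
    args, task, model, use_cuda = load_state('checkpoints/checkpoint_best.pt')
    print(type(model).__name__, 'cuda' if use_cuda else 'cpu')
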
def main(model_path, input):
    LOGGER.info('Load model from {}'.format(model_path))
    state = load_checkpoint_to_cpu(model_path, arg_overrides={})
    args = state["args"]
    task = tasks.setup_task(args)  # load src/tgt dicts
    model = task.build_model(args)
    model.load_state_dict(state["model"])
    use_cuda = torch.cuda.is_available() and not args['common']['cpu']
    if use_cuda:
        torch.cuda.empty_cache()
        torch.cuda.set_device(torch.cuda.device_count() - 1)
        model.cuda()
    model.eval()
    if args['common']['fp16'] and use_cuda:
        model.half()

    # TODO: source tensor should be handled in the corresponding task scripts;
    # here we only use the seq2seq pipeline as an example.
    input_ids = task.target_dictionary.encode_string(input, line_tokenizer=None, add_if_not_exist=False)
    src_input_ids = input_ids.long().unsqueeze(dim=0)
    sample = {
        'net_input': {
            'src_tokens': src_input_ids,
        },
    }
    sample = utils.move_to_cuda(sample) if use_cuda else sample
    generator = task.sequence_completor
    net_output = generator.complete(models=[model], sample=sample)

    # next-token distribution at the last position of the (single) sequence
    pred_prob = torch.softmax(net_output[0][0, -1, :], dim=-1)
    # fetch the top-10 candidates, then keep the 5 best; the headroom leaves
    # room to filter special tokens (unk/eos/bos/pad)
    topk_prob, topk_idx = pred_prob.topk(k=10, dim=-1)
    topk_info = [(round(prob.item(), 6), idx.item()) for prob, idx in zip(topk_prob, topk_idx)][:5]
    topk_info = [(task.target_dictionary[idx], prob) for prob, idx in topk_info]
    pred_sentence = [
        (input[:-1] + [topk_token], topk_prob)
        for topk_token, topk_prob in topk_info
    ]
    return topk_info, pred_sentence
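
# Hypothetical example of the completion entry point above; the checkpoint
# path and token list are placeholders, and `input` is assumed to be a
# pre-tokenized list of token strings.
def _demo_completion():
    tokens = ['def', 'add', '(', 'a', ',', 'b', ')', ':']
    topk_info, pred_sentence = main('checkpoints/completion.pt', tokens)
    for token, prob in topk_info:
        print('{:<15}{:.6f}'.format(token, prob))
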
def main(model_path, input):
    state = load_checkpoint_to_cpu(model_path, arg_overrides={})
    args = state["args"]
    task = tasks.setup_task(args)  # load src/tgt dicts
    model = task.build_model(args)
    model.load_state_dict(state["model"])
    use_cuda = torch.cuda.is_available() and not args['common']['cpu']
    if use_cuda:
        torch.cuda.empty_cache()
        torch.cuda.set_device(torch.cuda.device_count() - 1)
        model.cuda()
    if args['common']['fp16'] and use_cuda:
        model.half()
    model.eval()

    # TODO: source tensor should be handled in the corresponding task scripts;
    # here we only use the seq2seq pipeline as an example.
    src_input_ids = task.src_dict.encode_line(input, line_tokenizer=None, add_if_not_exist=False)
    # truncate to max_source_positions and terminate the sequence with </s>
    src_input_ids = torch.cat([
        src_input_ids[:args['task']['max_source_positions'] - 1],
        torch.Tensor([task.src_dict.eos()]).long(),
    ])
    # right-pad up to max_source_positions
    padding_size = args['task']['max_source_positions'] - len(src_input_ids)
    if padding_size > 0:
        src_input_ids = torch.cat([
            src_input_ids,
            torch.Tensor([task.src_dict.pad()] * padding_size).long(),
        ])
    # add the batch dimension on both CPU and GPU paths; move_to_cuda below
    # takes care of the device transfer
    src_input_ids = src_input_ids.unsqueeze(dim=0)
    sample = {
        'net_input': {
            'src_tokens': src_input_ids,
            'src_lengths': torch.LongTensor([s.numel() for s in src_input_ids]),
        },
    }
    sample = utils.move_to_cuda(sample) if use_cuda else sample
    generator = task.build_generator(args)
    pred_sentence_ids = generator.generate(models=[model], sample=sample)
    pred_sentence = task.tgt_dict.string(pred_sentence_ids[0][0]['tokens'])
    return pred_sentence
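
# Hypothetical example of the generation variant; the checkpoint path and
# source tokens are placeholders, and we assume encode_line accepts a
# pre-tokenized list when line_tokenizer=None.
def _demo_generation():
    src = ['public', 'static', 'int', 'max', '(', 'int', 'a', ',', 'int', 'b', ')']
    print(main('checkpoints/summarization.pt', src))
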
def main(model_path, input):
    LOGGER.info('Load model from {}'.format(model_path))
    state = load_checkpoint_to_cpu(model_path, arg_overrides={})
    args = state["args"]
    args = recursive_contractuser(args, old_cache_name='.ncc')
    args = recursive_expanduser(args)
    task = tasks.setup_task(args)  # load src/tgt dicts
    model = task.build_model(args)
    model.load_state_dict(state["model"])
    use_cuda = torch.cuda.is_available() and not args['common']['cpu']
    if use_cuda:
        torch.cuda.empty_cache()
        torch.cuda.set_device(torch.cuda.device_count() - 1)
        model.cuda()
    model.eval()
    if args['common']['fp16'] and use_cuda:
        model.half()

    sample = task.encode_input(input)
    sample = utils.move_to_cuda(sample) if use_cuda else sample
    generator = task.sequence_completor
    net_output = generator.complete(models=[model], sample=sample)
    out = task.decode_output(net_output)
    return out
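
# Because this variant delegates input encoding and output decoding to the
# task, the call site stays minimal. Hypothetical example; the checkpoint path
# and raw input are placeholders.
def _demo_task_io():
    out = main('checkpoints/completion.pt', 'def add ( a , b ) :')
    print(out)
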
def load_checkpoint(
    self,
    filename,
    reset_optimizer=False,
    reset_lr_scheduler=False,
    optimizer_overrides=None,
    reset_meters=False,
):
    """Load all training state from a checkpoint file."""
    extra_state, self._optim_history, last_optim_state = None, [], None

    bexists = PathManager.isfile(filename)
    if bexists:
        state = checkpoint_utils.load_checkpoint_to_cpu(filename)

        # load model parameters
        try:
            self.get_model().load_state_dict(
                state["model"], strict=True, args=self.args
            )
            if utils.has_parameters(self.get_criterion()):
                self.get_criterion().load_state_dict(
                    state["criterion"], strict=True
                )
        except Exception:
            raise Exception(
                "Cannot load model parameters from checkpoint {}; "
                "please ensure that the architectures match.".format(filename)
            )

        extra_state = state["extra_state"]
        self._optim_history = state["optimizer_history"]
        last_optim_state = state.get("last_optimizer_state", None)

    if last_optim_state is not None and not reset_optimizer:
        # rebuild optimizer after loading model, since params may have changed
        self._build_optimizer()

        # only reload optimizer and lr_scheduler if they match
        last_optim = self._optim_history[-1]
        assert (
            last_optim["criterion_name"] == self.get_criterion().__class__.__name__
        ), "Criterion does not match; please reset the optimizer (--reset-optimizer)."
        assert (
            last_optim["optimizer_name"] == self.optimizer.__class__.__name__
        ), "Optimizer does not match; please reset the optimizer (--reset-optimizer)."

        if not reset_lr_scheduler:
            self.lr_scheduler.load_state_dict(last_optim["lr_scheduler_state"])
        self.optimizer.load_state_dict(last_optim_state, optimizer_overrides)

        self.set_num_updates(last_optim["num_updates"])

    if extra_state is not None:
        epoch = extra_state["train_iterator"]["epoch"]
        logger.info(
            "loaded checkpoint {} (epoch {} @ {} updates)".format(
                filename, epoch, self.get_num_updates()
            )
        )

        self.lr_step(epoch)

        if "metrics" in extra_state and not reset_meters:
            # if loading a pre-trained checkpoint, convert its metrics to cuda
            self.metrics2cuda(extra_state["metrics"])
            metrics.load_state_dict(extra_state["metrics"])

            # reset TimeMeters, since their start times don't make sense anymore
            for meter in metrics.get_meters("default"):
                if isinstance(meter, meters.TimeMeter):
                    meter.reset()
    else:
        logger.info("no existing checkpoint found {}".format(filename))

    return extra_state
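
# Hypothetical call site for resuming training, assuming `trainer` is an
# instance of the class that defines load_checkpoint() above; the checkpoint
# path is a placeholder.
def resume_training(trainer):
    extra_state = trainer.load_checkpoint(
        'checkpoints/checkpoint_last.pt',
        reset_optimizer=False,
        reset_lr_scheduler=False,
        optimizer_overrides=None,
        reset_meters=False,
    )
    # load_checkpoint returns None when no checkpoint file exists yet
    if extra_state is not None:
        print('resumed at epoch', extra_state['train_iterator']['epoch'])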