Example #1
def load_state(model_path):
    state = load_checkpoint_to_cpu(model_path, arg_overrides={})
    args = state["args"]
    args = recursive_contractuser(args)
    args = recursive_expanduser(args)
    task = tasks.setup_task(args)  # load src/tgt dicts
    model = task.build_model(args)
    model.load_state_dict(state["model"])
    use_cuda = torch.cuda.is_available() and not args['common']['cpu']
    if args['common']['fp16'] and use_cuda:
        model.half()
    if use_cuda:
        torch.cuda.empty_cache()
        torch.cuda.set_device(torch.cuda.device_count() - 1)
        model.cuda()
    model.eval()
    del state
    return args, task, model, use_cuda
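A minimal usage sketch for load_state above; the checkpoint path is a placeholder, and the NaturalCC-style imports used in the example (torch, tasks, load_checkpoint_to_cpu) are assumed to already be in scope.

# Hypothetical checkpoint path; returns the parsed args, the task (with its dictionaries),
# the restored model in eval mode, and the CUDA flag.
args, task, model, use_cuda = load_state('checkpoints/checkpoint_best.pt')
print(type(model).__name__, 'running on', 'cuda' if use_cuda else 'cpu')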
Example #2
def main(model_path, input):
    LOGGER.info('Load model from {}'.format(model_path))
    state = load_checkpoint_to_cpu(model_path, arg_overrides={})
    args = state["args"]
    task = tasks.setup_task(args)  # load src/tgt dicts
    model = task.build_model(args)
    model.load_state_dict(state["model"])
    use_cuda = torch.cuda.is_available() and not args['common']['cpu']
    use_cuda = False  # force CPU inference, overriding the CUDA availability check above
    if use_cuda:
        torch.cuda.empty_cache()
        torch.cuda.set_device(torch.cuda.device_count() - 1)
        model.cuda()
    model.eval()
    if args['common']['fp16'] and use_cuda:
        model.half()

    # TODO: the source tensor should be built in the corresponding task scripts; the seq2seq pipeline is used here only as an example.

    input_ids = task.target_dictionary.encode_string(input, line_tokenizer=None, add_if_not_exist=False)
    src_input_ids = input_ids.long().unsqueeze(dim=0)

    sample = {
        'net_input': {
            'src_tokens': src_input_ids,
        },
    }
    sample = utils.move_to_cuda(sample) if use_cuda else sample
    generator = task.sequence_completor
    net_output = generator.complete(models=[model], sample=sample)

    # from ipdb import set_trace
    # set_trace()

    pred_prob = torch.softmax(net_output[0][0, -1, :], dim=-1)
    topk_prob, topk_idx = pred_prob.topk(k=10, dim=-1)
    # take the top-10 candidates and keep the first 5 (unk/eos/bos/pad are not actually filtered here)
    topk_info = [(round(prob.item(), 6), idx.item()) for prob, idx in zip(topk_prob, topk_idx)][:5]
    topk_info = [(task.target_dictionary[idx], prob) for prob, idx in topk_info]
    pred_sentence = [
        (input[:-1] + [topk_token], topk_prob)
        for topk_token, topk_prob in topk_info
    ]
    return topk_info, pred_sentence
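A hedged usage sketch for the completion entry point above: the checkpoint path is a placeholder, and the input is assumed to be a pre-tokenized list of code tokens (the function slices input[:-1] when assembling the predicted sentences).

# Placeholder path and tokens; topk_info pairs each candidate token with its probability.
tokens = ['def', 'load', '(', 'path', ')', ':']
topk_info, pred_sentence = main('checkpoints/checkpoint_best.pt', tokens)
for token, prob in topk_info:
    print(token, prob)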
Example #3
def main(model_path, input):
    state = load_checkpoint_to_cpu(model_path, arg_overrides={})
    args = state["args"]
    task = tasks.setup_task(args)  # load src/tgt dicts
    model = task.build_model(args)
    model.load_state_dict(state["model"])
    use_cuda = torch.cuda.is_available() and not args['common']['cpu']
    if use_cuda:
        torch.cuda.empty_cache()
        torch.cuda.set_device(torch.cuda.device_count() - 1)
        model.cuda()
    if args['common']['fp16'] and use_cuda:
        model.half()
    model.eval()

    # TODO: the source tensor should be built in the corresponding task scripts; the seq2seq pipeline is used here only as an example.
    src_input_ids = task.src_dict.encode_line(input,
                                              line_tokenizer=None,
                                              add_if_not_exist=False)
    src_input_ids = torch.cat([
        src_input_ids[:args['task']['max_source_positions'] - 1],
        torch.Tensor([task.src_dict.eos()]).long()
    ])
    padding_size = args['task']['max_source_positions'] - len(src_input_ids)
    if padding_size > 0:
        src_input_ids = torch.cat([
            src_input_ids,
            torch.Tensor([task.src_dict.pad()] * padding_size).long()
        ])
    src_input_ids = src_input_ids.unsqueeze(dim=0)  # add the batch dimension
    if use_cuda:
        src_input_ids = src_input_ids.cuda()
    sample = {
        'net_input': {
            'src_tokens': src_input_ids,
            'src_lengths': torch.LongTensor([s.numel() for s in src_input_ids]),
        },
    }
    sample = utils.move_to_cuda(sample) if use_cuda else sample
    generator = task.build_generator(args)
    pred_sentence_ids = generator.generate(models=[model], sample=sample)
    pred_sentence = task.tgt_dict.string(pred_sentence_ids[0][0]['tokens'])
    return pred_sentence
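A hedged usage sketch for the seq2seq variant above: the checkpoint path and the whitespace-tokenized source string are placeholders.

# Placeholder checkpoint and source; returns the detokenized target-side prediction.
pred = main('checkpoints/checkpoint_best.pt', 'def add ( a , b ) : return a + b')
print(pred)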
Example #4
def main(model_path, input):
    LOGGER.info('Load model from {}'.format(model_path))
    state = load_checkpoint_to_cpu(model_path, arg_overrides={})
    args = state["args"]
    args = recursive_contractuser(args, old_cache_name='.ncc')
    args = recursive_expanduser(args)
    task = tasks.setup_task(args)  # load src/tgt dicts
    model = task.build_model(args)
    model.load_state_dict(state["model"])
    use_cuda = torch.cuda.is_available() and not args['common']['cpu']
    if use_cuda:
        torch.cuda.empty_cache()
        torch.cuda.set_device(torch.cuda.device_count() - 1)
        model.cuda()
    model.eval()
    if args['common']['fp16'] and use_cuda:
        model.half()

    sample = task.encode_input(input)
    sample = utils.move_to_cuda(sample) if use_cuda else sample
    generator = task.sequence_completor
    net_output = generator.complete(models=[model], sample=sample)
    out = task.decode_output(net_output)
    return out
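A hedged usage sketch for the task-agnostic variant above, where encode_input and decode_output encapsulate the pre- and post-processing; both arguments are placeholders.

# Placeholder checkpoint and raw input; the task object handles tensorization and decoding.
out = main('checkpoints/checkpoint_best.pt', 'def load_state ( model_path ) :')
print(out)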
Example #5
    def load_checkpoint(
        self,
        filename,
        reset_optimizer=False,
        reset_lr_scheduler=False,
        optimizer_overrides=None,
        reset_meters=False,
    ):
        """Load all training state from a checkpoint file."""
        extra_state, self._optim_history, last_optim_state = None, [], None

        bexists = PathManager.isfile(filename)
        if bexists:
            state = checkpoint_utils.load_checkpoint_to_cpu(filename)

            # load model parameters
            try:
                self.get_model().load_state_dict(
                    state["model"], strict=True, args=self.args
                )
                if utils.has_parameters(self.get_criterion()):
                    self.get_criterion().load_state_dict(
                        state["criterion"], strict=True
                    )
            except Exception:
                raise Exception(
                    "Cannot load model parameters from checkpoint {}; "
                    "please ensure that the architectures match.".format(filename)
                )

            extra_state = state["extra_state"]
            self._optim_history = state["optimizer_history"]
            last_optim_state = state.get("last_optimizer_state", None)

        if last_optim_state is not None and not reset_optimizer:
            # rebuild optimizer after loading model, since params may have changed
            self._build_optimizer()

            # only reload optimizer and lr_scheduler if they match
            last_optim = self._optim_history[-1]
            assert (
                last_optim["criterion_name"] == self.get_criterion().__class__.__name__
            ), "Criterion does not match; please reset the optimizer (--reset-optimizer)."
            assert (
                last_optim["optimizer_name"] == self.optimizer.__class__.__name__
            ), "Optimizer does not match; please reset the optimizer (--reset-optimizer)."

            if not reset_lr_scheduler:
                self.lr_scheduler.load_state_dict(last_optim["lr_scheduler_state"])
            self.optimizer.load_state_dict(last_optim_state, optimizer_overrides)

            self.set_num_updates(last_optim["num_updates"])

        if extra_state is not None:
            epoch = extra_state["train_iterator"]["epoch"]
            logger.info(
                "loaded checkpoint {} (epoch {} @ {} updates)".format(
                    filename, epoch, self.get_num_updates()
                )
            )

            self.lr_step(epoch)

            if "metrics" in extra_state and not reset_meters:
                # when loading a pre-trained checkpoint, move the stored metrics to CUDA first
                self.metrics2cuda(extra_state["metrics"])
                metrics.load_state_dict(extra_state["metrics"])

                # reset TimeMeters, since their start times don't make sense anymore
                for meter in metrics.get_meters("default"):
                    if isinstance(meter, meters.TimeMeter):
                        meter.reset()
        else:
            logger.info("no existing checkpoint found {}".format(filename))

        return extra_state
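A hedged usage sketch for load_checkpoint: trainer is assumed to be an instance of the fairseq-style Trainer class that defines this method, and the checkpoint path is a placeholder.

# Resume training state if the checkpoint file exists; extra_state is None when it does not.
extra_state = trainer.load_checkpoint('checkpoints/checkpoint_last.pt', reset_optimizer=False)
if extra_state is not None:
    print('resumed at epoch', extra_state['train_iterator']['epoch'])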