def load(self, path):
     """Load using fairseq's checkpointing."""
     if self.trainer:
         old_options = self.trainer.load_checkpoint(path, self.args.reset_optimizer)
         self._check_opts_unchanged(old_options, self.opt)
     else:
         load_model_state(path, self.model)
Example #2
def prepare_cycle_kd_models(args, trainer, task):
    teacher_models = {}
    teacher_weights = []
    curr_cycle = trainer.get_cosine_cyle()
    print('Begin to load previous {} sharp ckpts'.format(args.teachers_cnt))

    met_unknown_ckpt = False
    for cycle_idx in range(curr_cycle - args.teachers_cnt, curr_cycle):
        model = task.build_model(args)
        ckpt_name = 'checkpoint_cycle_{}.pt'.format(cycle_idx)
        checkpoint_path = os.path.join(args.save_dir, ckpt_name)
        utils.load_model_state(checkpoint_path, model)
        model.cuda()
        teacher_models[ckpt_name] = model
        if ckpt_name not in trainer.prev_teacher_val_losses:
            met_unknown_ckpt = True
        else:
            teacher_weights.append(-trainer.prev_teacher_val_losses[ckpt_name])
    if met_unknown_ckpt:
        teacher_weights = [
            1.0 / args.teachers_cnt for _ in range(args.teachers_cnt)
        ]
    else:
        teacher_weights = [x - max(teacher_weights) for x in teacher_weights]
        teacher_weights_exp = [math.exp(x) for x in teacher_weights]
        teacher_weights = [
            x / sum(teacher_weights_exp) for x in teacher_weights_exp
        ]
    print(teacher_weights, 'Done')
    sys.stdout.flush()
    return teacher_models, teacher_weights
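The `teacher_weights` computation in Example #2 is a numerically stable softmax over negated validation losses, so teachers with lower validation loss receive larger ensemble weights. A minimal standalone sketch of the same arithmetic (the checkpoint names and loss values are made up for illustration):

import math

# Hypothetical per-teacher validation losses (illustrative values only).
prev_teacher_val_losses = {
    'checkpoint_cycle_3.pt': 2.10,
    'checkpoint_cycle_4.pt': 1.95,
    'checkpoint_cycle_5.pt': 2.40,
}

scores = [-loss for loss in prev_teacher_val_losses.values()]  # lower loss -> higher score
scores = [s - max(scores) for s in scores]                     # shift by max for numerical stability
exp_scores = [math.exp(s) for s in scores]
weights = [e / sum(exp_scores) for e in exp_scores]            # softmax; sums to 1.0
print(dict(zip(prev_teacher_val_losses, weights)))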
Example #3
    def load_checkpoint(self, filename, load_optim=True):
        """Load all training state from a checkpoint file."""
        extra_state, optim_history, last_optim_state = \
            utils.load_model_state(filename, self.model)

        if last_optim_state is not None:
            # rebuild optimizer after loading model, since params may have changed
            self._build_optimizer()

            if load_optim:
                self._optim_history = optim_history
                # only reload optimizer and lr_scheduler if they match
                last_optim = self._optim_history[-1]
                if last_optim['criterion_name'] == self.criterion.__class__.__name__:
                    self.lr_scheduler.load_state_dict(last_optim['lr_scheduler_state'])
                    if last_optim['optimizer_name'] == self.optimizer.__class__.__name__:
                        self.optimizer.load_state_dict(last_optim_state)

                self._num_updates = last_optim['num_updates']

        if extra_state is not None and 'train_meters' in extra_state:
            self.meters = extra_state['train_meters']
            del extra_state['train_meters']

        return extra_state
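All of these examples unpack the same three-value return from `utils.load_model_state`: an `extra_state` dict (epoch, batch offset, training meters, and similar), the optimizer history (a list of dicts carrying `criterion_name`, `optimizer_name`, `lr_scheduler_state`, and `num_updates`), and the last optimizer `state_dict`, or `None` if it is unavailable. A minimal caller sketch under that assumption, targeting the older fairseq releases these snippets appear to use:

from fairseq import utils  # assumes an older fairseq that still exposes utils.load_model_state

def restore_weights_only(checkpoint_path, model):
    """Load model weights from a checkpoint and report whether it was found.

    The optimizer history and optimizer state are returned by the call but
    deliberately ignored here.
    """
    extra_state, optim_history, last_optim_state = utils.load_model_state(
        checkpoint_path, model)
    return extra_state is not None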
Example #4
    def load_checkpoint(self, filename):
        """Load all training state from a checkpoint file."""
        extra_state, self._optim_history, last_optim_state = utils.load_model_state(
            filename, self.model, cuda_device=torch.cuda.current_device())

        if last_optim_state is not None:
            # rebuild optimizer after loading model, since params may have changed
            self.optimizer = optim.build_optimizer(self.args,
                                                   self.model.parameters())
            self.lr_scheduler = lr_scheduler.build_lr_scheduler(
                self.args, self.optimizer)

            # only reload optimizer and lr_scheduler if they match
            last_optim = self._optim_history[-1]
            if last_optim['criterion_name'] == self.criterion.__class__.__name__:
                self.lr_scheduler.load_state_dict(last_optim['lr_scheduler_state'])
                if last_optim['optimizer_name'] == self.optimizer.__class__.__name__:
                    self.optimizer.load_state_dict(last_optim_state)

            self._num_updates = last_optim['num_updates']

        return extra_state
Example #5
    def load_checkpoint(self, filename, reset_optimizer=False, reset_lr_scheduler=False, optimizer_overrides=None):
        """Load all training state from a checkpoint file."""
        extra_state, self._optim_history, last_optim_state = \
            utils.load_model_state(filename, self.get_model())
        if last_optim_state is not None and not reset_optimizer:
            # rebuild optimizer after loading model, since params may have changed
            self._build_optimizer()

            # only reload optimizer and lr_scheduler if they match
            last_optim = self._optim_history[-1]
            assert last_optim['criterion_name'] == self.criterion.__class__.__name__, \
                'criterion does not match; please reset the optimizer (--reset-optimizer)'
            assert last_optim['optimizer_name'] == self.optimizer.__class__.__name__, \
                'optimizer does not match; please reset the optimizer (--reset-optimizer)'

            if not reset_lr_scheduler:
                self.lr_scheduler.load_state_dict(last_optim['lr_scheduler_state'])
            self.optimizer.load_state_dict(last_optim_state, optimizer_overrides)

            self._num_updates = last_optim['num_updates']

        if extra_state is not None and 'train_meters' in extra_state:
            self.meters.update(extra_state['train_meters'])
            del extra_state['train_meters']

            # reset TimeMeters, since their start times don't make sense anymore
            for meter in self.meters.values():
                if isinstance(meter, TimeMeter):
                    meter.reset()

        return extra_state
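A hedged usage sketch for the `load_checkpoint` variant above (Example #5), wiring the reset flags through from argparse the way Example #1 does with `args.reset_optimizer`; the exact attribute names on `args` are assumptions:

import os

def maybe_resume(trainer, checkpoint_path, args):
    """Hypothetical call site: resume from a checkpoint if one exists."""
    if not os.path.isfile(checkpoint_path):
        return None
    return trainer.load_checkpoint(
        checkpoint_path,
        reset_optimizer=getattr(args, 'reset_optimizer', False),        # assumed flag name
        reset_lr_scheduler=getattr(args, 'reset_lr_scheduler', False),  # assumed flag name
    )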
Example #6

    def load_checkpoint(self, filename, load_optim=True):
        """Load all training state from a checkpoint file."""
        extra_state, optim_history, last_optim_state = \
            utils.load_model_state(filename, self.get_model())

        if last_optim_state is not None:
            # rebuild optimizer after loading model, since params may have changed
            #self.optimizer = optim.build_optimizer(self.args, self.model.parameters())
            self.lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer)

            if load_optim:
                self._optim_history = optim_history
                # only reload optimizer and lr_scheduler if they match
                last_optim = self._optim_history[-1]
                if last_optim['criterion_name'] == self.criterion.__class__.__name__:
                    self.lr_scheduler.load_state_dict(last_optim['lr_scheduler_state'])
                    if last_optim['optimizer_name'] == self.optimizer.__class__.__name__:
                        self.optimizer.load_state_dict(last_optim_state)

                self._num_updates = last_optim['num_updates']

        if self.args.amp and extra_state is not None and 'amp_state_dict' in extra_state:
            self.optimizer.optimizer._lazy_init_maybe_master_weights()
            self.optimizer.optimizer._amp_stash.lazy_init_called = True
            self.optimizer.optimizer.load_state_dict(last_optim_state)
            for param, saved_param in zip(amp.master_params(self.optimizer.optimizer), extra_state['amp_master_params']):
                param.data.copy_(saved_param.data)
 
            amp.load_state_dict(extra_state['amp_state_dict'])

        return extra_state
Example #7
File: trainer.py Project: fyabc/fairseq
    def load_checkpoint(self, filename, reset_optimizer=False, reset_lr_scheduler=False, optimizer_overrides=None):
        """Load all training state from a checkpoint file."""
        extra_state, self._optim_history, last_optim_state = \
            utils.load_model_state(filename, self.model)

        if last_optim_state is not None and not reset_optimizer:
            # rebuild optimizer after loading model, since params may have changed
            self._build_optimizer()

            # only reload optimizer and lr_scheduler if they match
            last_optim = self._optim_history[-1]

            assert last_optim['criterion_name'] == self.criterion.__class__.__name__, \
                'criterion does not match; please reset the optimizer (--reset-optimizer)'

            assert last_optim['optimizer_name'] == self.optimizer.__class__.__name__, \
                'optimizer does not match; please reset the optimizer (--reset-optimizer)'

            if not reset_lr_scheduler:
                self.lr_scheduler.load_state_dict(last_optim['lr_scheduler_state'])

            self.optimizer.load_state_dict(last_optim_state, optimizer_overrides)

            self._num_updates = last_optim['num_updates']

        if extra_state is not None and 'train_meters' in extra_state:
            self.meters.update(extra_state['train_meters'])
            del extra_state['train_meters']

            # reset TimeMeters, since their start times don't make sense anymore
            for meter in self.meters.values():
                if isinstance(meter, TimeMeter):
                    meter.reset()

        return extra_state
Example #8
File: train.py Project: ezyang/translate
def load_existing_checkpoint(checkpoint_path, trainer, restore_state=True):
    extra_state = None
    loaded = False
    if restore_state:
        extra_state = trainer.load_checkpoint(checkpoint_path)
        if extra_state is None:
            loaded = False
            print(
                f"Failed to load checkpoint and state from {checkpoint_path}.")
        else:
            loaded = True
            print(
                f"| loaded checkpoint {checkpoint_path} (epoch {extra_state['epoch']})\n"
                f"| extra_state {extra_state}")
            # batch_offset being None denotes this was a checkpoint saved at
            # the end of an epoch (after the last batch).
            if extra_state["batch_offset"] is None:
                trainer.lr_step(extra_state["epoch"])
                extra_state["epoch"] += 1
                extra_state["batch_offset"] = 0

            # check key availability for checkpoint backward compatibility
            if "start_time" not in extra_state:
                extra_state["start_time"] = time.time()

            if "last_bleu_eval" not in extra_state:
                extra_state["last_bleu_eval"] = 0

    else:
        # TODO(weiho): use trainer.load_checkpoint(load_optim=False) after
        # that's been synced to open-source fairseq.
        dummy_state, _, _ = utils.load_model_state(
            checkpoint_path,
            trainer.model,
            cuda_device=torch.cuda.current_device())
        trainer.optimizer = optim.build_optimizer(trainer.args,
                                                  trainer.model.parameters())
        trainer.lr_scheduler = optim.lr_scheduler.build_lr_scheduler(
            trainer.args, trainer.optimizer)
        trainer._optim_history = []

        if dummy_state is None:
            loaded = False
            print(f"Failed to load checkpoint weights from {checkpoint_path}.")
        else:
            loaded = True
            print(f"Loaded checkpoint weights from {checkpoint_path}.")

    if extra_state is None:
        extra_state = {
            "epoch": 1,
            "batch_offset": 0,
            "val_loss": None,
            "start_time": time.time(),
            "last_bleu_eval": 0,
        }

    return loaded, extra_state
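A hedged driver sketch for Example #8's `load_existing_checkpoint`: attempt a full resume, and otherwise fall back to the default training state it returns. The `reset_optimizer` wiring is an assumption, not part of the original code:

import os

def resume_or_start(args, trainer):
    """Hypothetical wrapper around load_existing_checkpoint above."""
    checkpoint_path = os.path.join(args.save_dir, args.restore_file)
    loaded, extra_state = load_existing_checkpoint(
        checkpoint_path,
        trainer,
        restore_state=not getattr(args, 'reset_optimizer', False),  # assumed flag name
    )
    # extra_state always carries 'epoch' and 'batch_offset', whether or not
    # the checkpoint was actually loaded.
    return loaded, extra_state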
Example #9
    def _async_load_checkpoint(self, rank, device_id, filename):
        extra_state, self._optim_history, last_optim_state = utils.load_model_state(
            filename, self.model, cuda_device=device_id)

        if last_optim_state is not None:
            # rebuild optimizer after loading model, since params may have changed
            self.optimizer = self._build_optimizer()
            self.lr_scheduler = self._build_lr_scheduler()

            # only load optimizer and lr_scheduler if they match the checkpoint
            last_optim = self._optim_history[-1]
            if last_optim['criterion_name'] == self.criterion.__class__.__name__:
                self.optimizer.load_state_dict(last_optim_state)
                self.lr_scheduler.best = last_optim['best_loss']

        # override learning rate, momentum, etc. with latest values
        for group in self.optimizer.param_groups:
            group.update(self._override_optim_state)

        return extra_state
Example #10
def setup_training(args):
    """Parse args, load dataset, and load model trainer."""
    if not torch.cuda.is_available():
        raise NotImplementedError("Training on CPU is not supported")
    torch.cuda.set_device(args.device_id)
    torch.manual_seed(args.seed)

    # Load dataset
    splits = [args.train_subset, args.valid_subset]

    validate_and_set_default_args(args)

    train_corpus = pytorch_translate_data.ParallelCorpusConfig(
        source=pytorch_translate_data.CorpusConfig(
            dialect=args.source_lang, data_file=args.train_source_binary_path),
        target=pytorch_translate_data.CorpusConfig(
            dialect=args.target_lang, data_file=args.train_target_binary_path),
        weights_file=args.train_weights_path if hasattr(
            args, "train_weights_path") else None,
    )

    eval_corpus = pytorch_translate_data.ParallelCorpusConfig(
        source=pytorch_translate_data.CorpusConfig(
            dialect=args.source_lang, data_file=args.eval_source_binary_path),
        target=pytorch_translate_data.CorpusConfig(
            dialect=args.target_lang, data_file=args.eval_target_binary_path),
        weights_file=None,
    )

    if args.log_verbose:
        print("Starting to load binarized data files.", flush=True)
    use_char_source = args.arch == "char_source"
    dataset = pytorch_translate_data.load_binarized_dataset(
        train_corpus=train_corpus,
        eval_corpus=eval_corpus,
        train_split=args.train_subset,
        eval_split=args.valid_subset,
        args=args,
        use_char_source=use_char_source,
    )
    if args.log_verbose:
        print("Finished loading dataset", flush=True)
    if args.source_lang is None or args.target_lang is None:
        # record inferred languages in args, so that it's saved in checkpoints
        args.source_lang, args.target_lang = dataset.src, dataset.dst

    print(f"| [{dataset.src}] dictionary: {len(dataset.src_dict)} types")
    print(f"| [{dataset.dst}] dictionary: {len(dataset.dst_dict)} types")

    for split in splits:
        print(f"| {split} {len(dataset.splits[split])} examples")

    # Build model and criterion
    model = models.build_model(args, dataset.src_dict, dataset.dst_dict)
    print("building criterion")
    criterion = criterions.build_criterion(args, dataset.src_dict,
                                           dataset.dst_dict)
    print(f"| model {args.arch}, criterion {criterion.__class__.__name__}")
    print(f"| num. model params: \
        {sum(p.numel() for p in model.parameters())}")

    # Load pretrained model weights if applicable
    if args.pretrained_weights_file:
        utils.load_model_state(args.pretrained_weights_file,
                               model,
                               cuda_device=torch.cuda.current_device())

    # Build trainer
    trainer = Trainer(args, model, criterion)
    print(f"| training on {args.distributed_world_size} GPUs")
    print(
        f"| max tokens per GPU = {args.max_tokens} and \
        max sentences per GPU = {args.max_sentences}",
        flush=True,
    )

    os.makedirs(args.save_dir, exist_ok=True)
    checkpoint_path = os.path.join(args.save_dir, args.restore_file)
    if not os.path.isfile(checkpoint_path) and args.multi_model_restore_files:
        print(
            f"| Restoring individual models from {args.multi_model_restore_files}"
        )
        extra_state = multi_model.import_individual_models(
            args.multi_model_restore_files, trainer)
    else:
        extra_state = load_existing_checkpoint(checkpoint_path, trainer)
    return extra_state, trainer, dataset
Example #11

    def load_teacher_checkpoint(self, filename):
        """Load teacher model weights only; no optimizer state is restored."""
        utils.load_model_state(filename, self.teacher_model)
        return None