Example #1
    def __init__(self, opt, shared=None):
        init_model, is_finetune = self._get_init_model(opt, shared)
        super().__init__(opt, shared)

        self.beam_dot_log = opt.get('beam_dot_log', False)
        self.beam_size = opt.get('beam_size', 1)
        self.beam_min_n_best = opt.get('beam_min_n_best', 3)
        self.beam_min_length = opt.get('beam_min_length', 3)
        self.beam_block_ngram = opt.get('beam_block_ngram', 0)

        if shared:
            # set up shared properties
            states = shared.get('states', {})
        else:
            # Note: we cannot change the type of metrics ahead of time, so you
            # should correctly initialize to floats or ints here
            self.metrics['nll_loss'] = 0.0
            self.metrics['loss'] = 0.0
            self.metrics['correct_tokens'] = 0
            self.metrics['total_skipped_batches'] = 0

            # this is not a shared instance of this class, so do full init
            if self.beam_dot_log:
                self.beam_dot_dir = tempfile.mkdtemp(
                    prefix='{}-beamdot-beamsize-{}-'.format(
                        os.path.basename(opt.get('model_file')),
                        self.beam_size))
                print('[ Saving dot beam logs in {} ]'.format(
                    self.beam_dot_dir))

            self.criterion = self.build_criterion()
            self.model = self.build_model()
            if self.model is None or self.criterion is None:
                raise AttributeError(
                    'build_model() and build_criterion() must return the model and criterion'
                )
            if self.use_cuda:
                self.model.cuda()
                self.criterion.cuda()

            # ensure all distributed copies will always be in sync
            check_synced_parameters(self.model)
            print("Total parameters: {}".format(self._total_parameters()))
            print("Trainable parameters:  {}".format(
                self._trainable_parameters()))

            if self.fp16:
                self.model = self.model.half()

            if init_model is not None:
                # load model parameters if available
                print('[ Loading existing model params from {} ]'.format(
                    init_model))
                states = self.load(init_model)
            else:
                states = {}

        if (
            # only build an optimizer if we're training
            'train' in opt.get('datatype', '')
            # and this is the main model, or on every fork if doing hogwild
            and (shared is None or self.opt.get('numthreads', 1) > 1)
        ):
            # initialize the optimizer, restoring any saved optimizer state
            self.init_optim(
                [p for p in self.model.parameters() if p.requires_grad],
                optim_states=states.get('optimizer'),
                saved_optim_type=states.get('optimizer_type'),
            )
            self.build_lr_scheduler(states, hard_reset=is_finetune)

        if shared is None and is_distributed():
            self.model = torch.nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[self.opt['gpu']],
                broadcast_buffers=False)

        self.reset()
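
This constructor follows ParlAI's shared-initialization pattern: the first instance performs the full build (criterion, model, optimizer, LR scheduler), while forked copies receive a `shared` dict and reuse its states instead of rebuilding. The snippet assumes the surrounding module imports `os`, `tempfile`, and `torch`, plus ParlAI's distributed helpers (`check_synced_parameters`, `is_distributed`). A minimal usage sketch, with `MyGeneratorAgent` standing in as a hypothetical concrete subclass that implements `build_model()` and `build_criterion()`:

    # MyGeneratorAgent is illustrative only; it is not defined in the example above
    agent = MyGeneratorAgent(opt)           # full init: builds model, criterion, optimizer
    shared = agent.share()                  # bundle the state to hand to forked copies
    clone = MyGeneratorAgent(opt, shared)   # lightweight fork; skips model/criterion build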

Example #2

    def __init__(self, opt, shared=None):
        init_model, is_finetune = self._get_init_model(opt, shared)
        super().__init__(opt, shared)

        self.beam_size = opt.get('beam_size', 1)
        self.beam_min_n_best = opt.get('beam_min_n_best', 3)
        self.beam_min_length = opt.get('beam_min_length', 3)

        if opt.get('beam_block_ngram'):
            # check for old opts where we might have used beam blocking.
            # this was a super rare option, so I don't expect this to be used.
            raise RuntimeError('Beam ngram blocking is no longer supported.')

        if shared:
            # set up shared properties
            states = shared.get('states', {})
        else:
            # Note: we cannot change the type of metrics ahead of time, so you
            # should correctly initialize to floats or ints here
            self.metrics['nll_loss'] = 0.0
            self.metrics['loss'] = 0.0
            self.metrics['correct_tokens'] = 0
            self.metrics['total_skipped_batches'] = 0

            # this is not a shared instance of this class, so do full init
            self.criterion = self.build_criterion()
            self.model = self.build_model()

            if self.model is None or self.criterion is None:
                raise AttributeError(
                    'build_model() and build_criterion() must return the model and criterion'
                )
            if self.use_cuda:
                self.model.cuda()
                self.criterion.cuda()

            # ensure all distributed copies will always be in sync
            check_synced_parameters(self.model)
            print("Total parameters: {}".format(self._total_parameters()))
            print("Trainable parameters:  {}".format(self._trainable_parameters()))

            if self.fp16:
                self.model = self.model.half()

            if init_model is not None:
                # load model parameters if available
                print('[ Loading existing model params from {} ]'.format(init_model))
                states = self.load(init_model)
            else:
                states = {}

        if (
            # only build an optimizer if we're training
            'train' in opt.get('datatype', '')
            # and this is the main model, or on every fork if doing hogwild
            and (shared is None or self.opt.get('numthreads', 1) > 1)
        ):
            # initialize the optimizer, restoring any saved optimizer state
            self.init_optim(
                [p for p in self.model.parameters() if p.requires_grad],
                optim_states=states.get('optimizer'),
                saved_optim_type=states.get('optimizer_type'),
            )
            self.build_lr_scheduler(states, hard_reset=is_finetune)

        if shared is None and is_distributed():
            self.model = torch.nn.parallel.DistributedDataParallel(
                self.model, device_ids=[self.opt['gpu']], broadcast_buffers=False
            )

        self.reset()
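
Compared with Example #1, this revision drops the `beam_dot_log` graphviz logging branch and instead rejects the removed `beam_block_ngram` option up front. A short sketch of that failure mode, reusing the hypothetical `MyGeneratorAgent` from above:

    # an old opt file that still sets the removed flag now fails fast at construction
    opt['beam_block_ngram'] = 2
    agent = MyGeneratorAgent(opt)  # raises RuntimeError: Beam ngram blocking is no longer supported.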