Example #1
def _build_trainer(config, model, vocab, train_data, valid_data):
    optimizer = optim.AdamW(model.parameters(), lr=config.trainer.lr)
    scheduler = None

    if config.embedder.name.endswith('bert') or config.embedder.name == 'both':
        non_bert_params = (param for name, param in model.named_parameters()
                           if not name.startswith('text_field_embedder'))

        optimizer = optim.AdamW([{
            'params': model.text_field_embedder.parameters(),
            'lr': config.trainer.bert_lr
        }, {
            'params': non_bert_params,
            'lr': config.trainer.lr
        }, {
            'params': []
        }])

        scheduler = SlantedTriangular(
            optimizer=optimizer,
            num_epochs=config.trainer.num_epochs,
            num_steps_per_epoch=len(train_data) / config.trainer.batch_size,
            cut_frac=config.trainer.cut_frac,
            gradual_unfreezing=config.trainer.gradual_unfreezing,
            discriminative_fine_tuning=config.trainer.discriminative_fine_tuning)

    logger.info('Trainable params:')
    for name, param in model.named_parameters():
        if param.requires_grad:
            logger.info('\t' + name)

    iterator = BucketIterator(batch_size=config.trainer.batch_size)
    iterator.index_with(vocab)

    if torch.cuda.is_available():
        cuda_device = 0
        model = model.cuda(cuda_device)
        logger.info('Using cuda')
    else:
        cuda_device = -1
        logger.info('Using cpu')

    logger.info('Example batch:')
    _log_batch(next(iterator(train_data)))

    if config.embedder.name.endswith('bert') or config.embedder.name == 'both':
        train_data = _filter_data(train_data, vocab)
        valid_data = _filter_data(valid_data, vocab)

    return Trainer(model=model,
                   optimizer=optimizer,
                   iterator=iterator,
                   train_dataset=train_data,
                   validation_dataset=valid_data,
                   validation_metric='+MeanAcc',
                   patience=config.trainer.patience,
                   num_epochs=config.trainer.num_epochs,
                   cuda_device=cuda_device,
                   grad_clipping=5.,
                   learning_rate_scheduler=scheduler,
                   serialization_dir=os.path.join(config.data.models_dir,
                                                  config.model_name),
                   should_log_parameter_statistics=False,
                   should_log_learning_rate=False,
                   num_gradient_accumulation_steps=(
                       config.trainer.num_gradient_accumulation_steps))
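The three _build_trainer variants shown in Examples #1 to #3 are written against the pre-1.0 AllenNLP API (BucketIterator, Trainer, SlantedTriangular). Below is a minimal sketch of the imports they presumably rely on; the exact module paths are an assumption, and _log_batch / _filter_data are project-level helpers that are not reconstructed here.

# Presumed imports for the _build_trainer examples (AllenNLP 0.x-style API);
# module paths are assumptions and may differ in the original projects.
import logging
import os

import torch
from torch import optim

from allennlp.data.iterators import BucketIterator
from allennlp.training.learning_rate_schedulers import SlantedTriangular
from allennlp.training.trainer import Trainer

logger = logging.getLogger(__name__)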
Example #2
def _build_trainer(config, model, vocab, train_data, valid_data):
    optimizer = optim.AdamW(model.parameters(), lr=config.trainer.lr)
    scheduler = None

    is_bert_based = any(
        model.name.endswith('bert') for model in config.embedder.models)
    is_trainable_elmo_based = any(
        model.name == 'elmo' and model.params['requires_grad']
        for model in config.embedder.models)

    if is_bert_based or is_trainable_elmo_based:
        params_list = []
        non_pretrained_params = []
        if is_bert_based:
            bert_groups = [
                'transformer_model.embeddings.',
                'transformer_model.encoder.layer.0.',
                'transformer_model.encoder.layer.1.',
                'transformer_model.encoder.layer.2.',
                'transformer_model.encoder.layer.3.',
                'transformer_model.encoder.layer.4.',
                'transformer_model.encoder.layer.5.',
                'transformer_model.encoder.layer.6.',
                'transformer_model.encoder.layer.7.',
                'transformer_model.encoder.layer.8.',
                'transformer_model.encoder.layer.9.',
                'transformer_model.encoder.layer.10.',
                'transformer_model.encoder.layer.11.',
                'transformer_model.pooler.'
            ]
            bert_group2params = {bg: [] for bg in bert_groups}
            for name, param in model.named_parameters():
                is_bert_layer = False
                for bg in bert_groups:
                    if bg in name:
                        is_bert_layer = True
                        bert_group2params[bg].append(param)
                        logger.info('Param: %s assigned to %s group', name, bg)
                        break
                if not is_bert_layer:
                    non_pretrained_params.append(param)
                    logger.info('Param: %s assigned to non_pretrained group',
                                name)
            for bg in bert_groups:
                params_list.append({
                    'params': bert_group2params[bg],
                    'lr': config.trainer.bert_lr
                })
            params_list.append({
                'params': non_pretrained_params,
                'lr': config.trainer.lr
            })
            params_list.append({'params': []})
        elif is_trainable_elmo_based:
            pretrained_params = []
            for name, param in model.named_parameters():
                if '_elmo_lstm' in name:
                    logger.info('Pretrained param: %s', name)
                    pretrained_params.append(param)
                else:
                    logger.info('Non-pretrained param: %s', name)
                    non_pretrained_params.append(param)
            params_list = [{
                'params': pretrained_params,
                'lr': config.trainer.bert_lr
            }, {
                'params': non_pretrained_params,
                'lr': config.trainer.lr
            }, {
                'params': []
            }]
        optimizer = optim.AdamW(params_list)
        scheduler = SlantedTriangular(
            optimizer=optimizer,
            num_epochs=config.trainer.num_epochs,
            num_steps_per_epoch=len(train_data) / config.trainer.batch_size,
            cut_frac=config.trainer.cut_frac,
            gradual_unfreezing=config.trainer.gradual_unfreezing,
            discriminative_fine_tuning=config.trainer.discriminative_fine_tuning)

    logger.info('Trainable params:')
    for name, param in model.named_parameters():
        if param.requires_grad:
            logger.info('\t' + name)

    iterator = BucketIterator(batch_size=config.trainer.batch_size)
    iterator.index_with(vocab)

    if torch.cuda.is_available():
        cuda_device = 0
        model = model.cuda(cuda_device)
        logger.info('Using cuda')
    else:
        cuda_device = -1
        logger.info('Using cpu')

    logger.info('Example batch:')
    _log_batch(next(iterator(train_data)))

    if is_bert_based:
        train_data = _filter_data(train_data, vocab)
        valid_data = _filter_data(valid_data, vocab)

    return Trainer(model=model,
                   optimizer=optimizer,
                   iterator=iterator,
                   train_dataset=train_data,
                   validation_dataset=valid_data,
                   validation_metric='+MeanAcc',
                   patience=config.trainer.patience,
                   num_epochs=config.trainer.num_epochs,
                   cuda_device=cuda_device,
                   grad_clipping=5.,
                   learning_rate_scheduler=scheduler,
                   serialization_dir=os.path.join(config.data.models_dir,
                                                  config.model_name),
                   should_log_parameter_statistics=False,
                   should_log_learning_rate=False,
                   num_gradient_accumulation_steps=(
                       config.trainer.num_gradient_accumulation_steps))
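Example #2 splits the BERT weights into one optimizer group per module prefix (embeddings, each encoder layer, the pooler), so that SlantedTriangular's gradual unfreezing and discriminative fine-tuning can act on these finer-grained groups. The hard-coded prefix list can equivalently be generated programmatically; a sketch, assuming the same 12-layer encoder as the explicit list:

# Programmatic equivalent of the hard-coded bert_groups list above
# (assumes a 12-layer encoder, matching the explicit list).
bert_groups = (['transformer_model.embeddings.'] +
               [f'transformer_model.encoder.layer.{i}.' for i in range(12)] +
               ['transformer_model.pooler.'])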
Example #3
def _build_trainer(config, model, vocab, train_data, valid_data):
    optimizer = optim.AdamW(model.parameters(), lr=config.trainer.lr)
    scheduler = None

    is_bert_based = any(
        model.name.endswith('bert') for model in config.embedder.models)
    is_trainable_elmo_based = any(
        model.name == 'elmo' and model.params['requires_grad']
        for model in config.embedder.models)

    if is_bert_based or is_trainable_elmo_based:

        def _is_pretrained_param(name):
            return 'transformer_model' in name or '_elmo_lstm' in name

        pretrained_params, non_pretrained_params = [], []
        for name, param in model.named_parameters():
            if _is_pretrained_param(name):
                logger.info('Pretrained param: %s', name)
                pretrained_params.append(param)
            else:
                logger.info('Non-pretrained param: %s', name)
                non_pretrained_params.append(param)

        optimizer = optim.AdamW([{
            'params': pretrained_params,
            'lr': config.trainer.bert_lr
        }, {
            'params': non_pretrained_params,
            'lr': config.trainer.lr
        }, {
            'params': []
        }])

        scheduler = SlantedTriangular(
            optimizer=optimizer,
            num_epochs=config.trainer.num_epochs,
            num_steps_per_epoch=len(train_data) / config.trainer.batch_size,
            cut_frac=config.trainer.cut_frac,
            gradual_unfreezing=config.trainer.gradual_unfreezing,
            discriminative_fine_tuning=config.trainer.discriminative_fine_tuning)

    logger.info('Trainable params:')
    for name, param in model.named_parameters():
        if param.requires_grad:
            logger.info('\t' + name)

    iterator = BucketIterator(batch_size=config.trainer.batch_size)
    iterator.index_with(vocab)

    if torch.cuda.is_available():
        cuda_device = 0
        model = model.cuda(cuda_device)
        logger.info('Using cuda')
    else:
        cuda_device = -1
        logger.info('Using cpu')

    logger.info('Example batch:')
    _log_batch(next(iterator(train_data)))

    if is_bert_based:
        train_data = _filter_data(train_data, vocab)
        valid_data = _filter_data(valid_data, vocab)

    return Trainer(model=model,
                   optimizer=optimizer,
                   iterator=iterator,
                   train_dataset=train_data,
                   validation_dataset=valid_data,
                   validation_metric='+MeanAcc',
                   patience=config.trainer.patience,
                   num_epochs=config.trainer.num_epochs,
                   cuda_device=cuda_device,
                   grad_clipping=5.,
                   learning_rate_scheduler=scheduler,
                   serialization_dir=os.path.join(config.data.models_dir,
                                                  config.model_name),
                   should_log_parameter_statistics=False,
                   should_log_learning_rate=False,
                   num_gradient_accumulation_steps=(
                       config.trainer.num_gradient_accumulation_steps))
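A hypothetical call site for any of the three variants above; config, model, vocab, and the dataset lists are assumptions taken from the function signature, and in the 0.x API Trainer.train() returns a dict of final metrics.

# Hypothetical usage; config, model, vocab and the datasets come from setup
# code that is not shown in these examples.
trainer = _build_trainer(config, model, vocab, train_data, valid_data)
metrics = trainer.train()
logger.info('Best validation MeanAcc: %s', metrics.get('best_validation_MeanAcc'))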
Example #4
    def train(self):
        if self.config.adjust_point:
            ram_set_flag("adjust_point")
        # ram_write('dist_reg', self.config.dist_reg)
        read_hyper_ = partial(read_hyper, self.config.task_id,
                              self.config.arch)
        num_epochs = int(read_hyper_("num_epochs"))
        batch_size = int(read_hyper_("batch_size"))
        logger.info(f"num_epochs: {num_epochs}, batch_size: {batch_size}")

        if self.config.model_name == 'tmp':
            p = pathlib.Path('saved/models/tmp')
            if p.exists():
                shutil.rmtree(p)

        # Maybe we will do some data augmentation here.
        if self.config.aug_data != '':
            log(f'Augment data from {self.config.aug_data}')
            aug_data = auto_create(
                f"{self.config.task_id}.{self.config.arch}.aug",
                lambda: self.reader.read(self.config.aug_data),
                cache=True)
            self.train_data.instances.extend(aug_data.instances)

        # Set up the adversarial training policy
        if self.config.arch == 'bert':
            model_vocab = embed_util.get_bert_vocab()
        else:
            model_vocab = self.vocab
        # yapf: disable
        adv_field = 'sent2' if is_sentence_pair(self.config.task_id) and self.config.arch != 'bert' else 'sent'
        policy_args = {
            "adv_iteration": self.config.adv_iter,
            "replace_num": self.config.adv_replace_num,
            "searcher": WordIndexSearcher(
                CachedWordSearcher(
                    "external_data/ibp-nbrs.json" if not self.config.big_nbrs else "external_data/euc-top8.json",
                    model_vocab.get_token_to_index_vocabulary("tokens"),
                    second_order=False
                ),
                word2idx=model_vocab.get_token_index,
                idx2word=model_vocab.get_token_from_index,
            ),
            'adv_field': adv_field
        }
        # yapf: enable
        if self.config.adv_policy == 'hot':
            if is_sentence_pair(
                    self.config.task_id) and self.config.arch != 'bert':
                policy_args['forward_order'] = 1
            adv_policy = adv_utils.HotFlipPolicy(**policy_args)
        elif self.config.adv_policy == 'rdm':
            adv_policy = adv_utils.RandomNeighbourPolicy(**policy_args)
        elif self.config.adv_policy == 'diy':
            adv_policy = adv_utils.DoItYourselfPolicy(self.config.adv_iter,
                                                      adv_field,
                                                      self.config.adv_step)
        else:
            adv_policy = adv_utils.NoPolicy

        # A collate_fn applies transformations to an instance before it is fed
        # into the model. To train with transformations such as cropping/DAE,
        # modify the code here, e.g.:
        # collate_fn = partial(transform_collate, self.vocab, self.reader, Crop(0.3))
        collate_fn = allennlp_collate
        train_data_sampler = BucketBatchSampler(
            data_source=self.train_data,
            batch_size=batch_size,
        )
        # Set callbacks

        if self.config.task_id == 'SNLI' and self.config.arch != 'bert':
            epoch_callbacks = []
            if self.config.model_pretrain != "":
                epoch_callbacks = [WarmupCallback(2)]
                if self.config.model_pretrain == 'auto':
                    self.config.model_pretrain = {
                        "biboe": "SNLI-fix-biboe-sum",
                        "datt": "SNLI-fix-datt"
                    }[self.config.arch]
                logger.warning(
                    f"Try loading weights from pretrained model {self.config.model_pretrain}"
                )
                pretrain_ckpter = CheckpointerX(
                    f"saved/models/{self.config.model_pretrain}")
                self.model.load_state_dict(pretrain_ckpter.best_model_state())
        else:
            epoch_callbacks = []
        # epoch_callbacks = []
        batch_callbacks = []

        opt = self.model.get_optimizer()
        if self.config.arch == 'bert':
            scl = SlantedTriangular(opt, num_epochs,
                                    len(self.train_data) // batch_size)
        else:
            scl = None

        trainer = AdvTrainer(
            model=self.model,
            optimizer=opt,
            learning_rate_scheduler=scl,
            validation_metric='+accuracy',
            adv_policy=adv_policy,
            data_loader=DataLoader(
                self.train_data,
                batch_sampler=train_data_sampler,
                collate_fn=collate_fn,
            ),
            validation_data_loader=DataLoader(
                self.dev_data,
                batch_size=batch_size,
            ),
            num_epochs=num_epochs,
            patience=None,
            grad_clipping=1.,
            cuda_device=0,
            epoch_callbacks=epoch_callbacks,
            batch_callbacks=batch_callbacks,
            serialization_dir=f'saved/models/{self.config.model_name}',
            num_serialized_models_to_keep=20)
        trainer.train()
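Example #4 targets the post-1.0 AllenNLP API (batch samplers and data loaders instead of iterators) and mixes in project-specific helpers (AdvTrainer, adv_utils, WordIndexSearcher, CachedWordSearcher, CheckpointerX, WarmupCallback, ram_set_flag, embed_util, read_hyper, ...). A sketch of the library-side imports it presumably needs; the project-specific modules are omitted because their locations are not shown here, and the AllenNLP paths are an assumption based on the 1.x releases.

# Presumed library-side imports for Example #4 (AllenNLP 1.x-style API);
# project-specific helpers (AdvTrainer, adv_utils, CheckpointerX, ...) are omitted.
import pathlib
import shutil
from functools import partial

from allennlp.data import DataLoader, allennlp_collate
from allennlp.data.samplers import BucketBatchSampler
from allennlp.training.learning_rate_schedulers import SlantedTriangular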