Example #1
def build_default_optimizer(
    model: torch.nn.Module, optimizer_kwargs: hf_parse.OptimizerKwargs
) -> Union[hf_opt.Adafactor, hf_opt.AdamW]:
    """
    This follows the function in transformer's Trainer to construct the optimizer.

    Args:
        model: model whose parameters will be updated by the optimizer
        weight_decay: weight_decay factor to apply to weights
        optimizer_kwargs: see OptimizerKwargs in _config_parser.py for expected fields
    Returns:
        optimizer configured accordingly
    """
    optimizer_grouped_parameters = group_parameters_for_optimizer(
        model, optimizer_kwargs.weight_decay)
    if optimizer_kwargs.adafactor:
        return hf_opt.Adafactor(
            optimizer_grouped_parameters,
            lr=optimizer_kwargs.learning_rate,
            scale_parameter=optimizer_kwargs.scale_parameter,
            relative_step=optimizer_kwargs.relative_step,
        )
    return hf_opt.AdamW(
        optimizer_grouped_parameters,
        lr=optimizer_kwargs.learning_rate,
        betas=(optimizer_kwargs.adam_beta1, optimizer_kwargs.adam_beta2),
        eps=optimizer_kwargs.adam_epsilon,
    )
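
A minimal usage sketch (not from the original source); it assumes OptimizerKwargs can be constructed with exactly the fields read above, which may not match the real _config_parser.py:

import torch

model = torch.nn.Linear(768, 2)  # stand-in model for illustration
optimizer_kwargs = hf_parse.OptimizerKwargs(
    weight_decay=0.01,
    learning_rate=5e-5,
    adafactor=False,
    adam_beta1=0.9,
    adam_beta2=0.999,
    adam_epsilon=1e-8,
    scale_parameter=False,
    relative_step=False,
)
optimizer = build_default_optimizer(model, optimizer_kwargs)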
Example #2
    def __init__(self, dataset, batch_size=32):
        """Creates a new model for sentiment analysis using BERT."""
        # The pretrained weights to use.
        pretrained_weights = 'bert-base-uncased'

        # Create a transformer to convert text into indexed tokens.
        transformer = BertTransform(62, pretrained_weights)

        # Set up the training data loader
        train_dataset = dataset('./',
                                train=True,
                                transforms=DataToTensor(),
                                vectorizer=transformer,
                                download=True)
        self.train_loader = DataLoader(train_dataset,
                                       batch_size=batch_size,
                                       shuffle=False)

        # Set up the validation data loader
        val_dataset = dataset('./',
                              train=False,
                              transforms=DataToTensor(),
                              vectorizer=transformer,
                              download=True)
        self.val_loader = DataLoader(val_dataset,
                                     batch_size=batch_size,
                                     shuffle=False)

        # Use the CUDA device if available, otherwise fall back to the CPU
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        device_name = torch.cuda.get_device_name(
            0) if torch.cuda.is_available() else "CPU"
        print("Training on:", device_name)

        # Load the pretrained BERT model with a classification head
        self.model = BertForSequenceClassification.from_pretrained(
            pretrained_weights, num_labels=2)
        self.model.to(self.device)

        # Set the learning rate
        self.lr = 1e-5

        # Set the optimizer and scheduler
        training_steps = (len(train_dataset) + batch_size - 1) // batch_size  # integer step count
        self.optimizer = optim.AdamW(self.model.parameters(),
                                     lr=self.lr,
                                     correct_bias=False)
        self.scheduler = optim.get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=int(0.1 * training_steps),  # warm up for 10% of training
            num_training_steps=training_steps)

        # Maximum gradient norm (used for gradient clipping)
        self.max_grad_norm = 1.0
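
A hedged sketch (not part of the original class) of the training step these attributes imply; it assumes a transformers version whose model outputs expose a .loss attribute when labels are passed:

    def train_step(self, input_ids, attention_mask, labels):
        """Hypothetical single-batch update using the attributes set up above."""
        self.model.train()
        self.optimizer.zero_grad()
        outputs = self.model(input_ids.to(self.device),
                             attention_mask=attention_mask.to(self.device),
                             labels=labels.to(self.device))
        loss = outputs.loss  # cross-entropy loss computed by the model
        loss.backward()
        # Clip gradients before the optimizer step, as self.max_grad_norm suggests.
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
        self.optimizer.step()
        self.scheduler.step()
        return loss.item()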
Example #3
    def __init__(self, args):
        self.opt = args
        self.model = MyBert(args).to(args.device)
        self.criterion = nn.CrossEntropyLoss()
        self._params = filter(lambda p: p.requires_grad, self.model.parameters())
        # self.optimizer = optim.Adam(self._params)
        self.optimizer = optimization.AdamW(self._params, lr=2e-5, correct_bias=False)
        # Total optimizer steps: ceil(num_example / bs) batches per epoch, times n_epoch.
        steps = (args.num_example + args.bs - 1) // args.bs
        total = steps * args.n_epoch
        warmup_step = int(total * args.warmup_rate)
        self.scheduler = optimization.get_linear_schedule_with_warmup(
            self.optimizer, num_warmup_steps=warmup_step, num_training_steps=total)
        self.total_loss = 0
        self.mse = nn.MSELoss()
        self.lamda = args.lamda
        self.max_grad_norm = 1.0
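
A worked example of the schedule arithmetic above, with illustrative values (not from the original source):

num_example, bs, n_epoch, warmup_rate = 10_000, 16, 3, 0.1
steps = (num_example + bs - 1) // bs    # 625 optimizer steps per epoch (ceil division)
total = steps * n_epoch                 # 1875 training steps overall
warmup_step = int(total * warmup_rate)  # 187 steps of linear warmup, then linear decay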
Example #4
def create_optimizer(
        model, lr,
        num_train_steps,
        weight_decay=0.0,
        warmup_steps=0,
        warmup_proportion=0.1,
        layerwise_lr_decay_power=0.8,
        transformer_preffix="transformer",
        n_transformer_layers=12,
        get_layer_lrs=get_layer_lrs,
        get_layer_lrs_kwargs=None,
        lr_scheduler=get_polynomial_decay_schedule_with_warmup,
        lr_scheduler_kwargs=None,
):
    """

    Args:
        model:
        lr: 3e-4 for Small, 1e-4 for Base, 5e-5 for Large
        num_train_steps:
        weight_decay: 0
        warmup_steps: 0
        warmup_proportion: 0.1
        lr_decay_power: 1.0
        layerwise_lr_decay_power: 0.8 for Base/Small, 0.9 for Large

    Returns:

    """
    if lr_scheduler_kwargs is None:
        lr_scheduler_kwargs = {}
    if get_layer_lrs_kwargs is None:
        get_layer_lrs_kwargs = {}
    if layerwise_lr_decay_power > 0:
        parameters = get_layer_lrs(
            named_parameters=list(model.named_parameters()),
            transformer_preffix=transformer_preffix,
            learning_rate=lr,
            layer_decay=layerwise_lr_decay_power,
            n_layers=n_transformer_layers,
            **get_layer_lrs_kwargs
        )
    else:
        parameters = model.parameters()

    optimizer = optimization.AdamW(
        parameters,
        lr=lr,
        weight_decay=weight_decay,
        betas=(0.9, 0.999),
        eps=1e-6,
        correct_bias=False,
    )

    warmup_steps = int(max(num_train_steps * warmup_proportion, warmup_steps))
    scheduler = lr_scheduler(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=num_train_steps,
        **lr_scheduler_kwargs
    )
    return optimizer, scheduler
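
A hypothetical call with illustrative values; it assumes `model` is a 12-layer encoder whose parameter names start with the default "transformer" prefix, so the layer-wise grouping in get_layer_lrs can match them:

optimizer, scheduler = create_optimizer(
    model,
    lr=1e-4,                       # Base-sized model
    num_train_steps=10_000,
    weight_decay=0.01,
    warmup_proportion=0.1,         # -> 1,000 warmup steps
    layerwise_lr_decay_power=0.8,
)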
Example #5
def BertAdamW4CRF(named_params,
                  lr=1e-5,
                  weight_decay=0.01,
                  bert_lr=1e-5,
                  bert_weight_decay=None,
                  crf_lr=1e-3,
                  crf_weight_decay=None,
                  **kwargs):
    if bert_lr is None:
        bert_lr = lr
    if bert_weight_decay is None:
        bert_weight_decay = weight_decay
    if crf_weight_decay is None:
        crf_weight_decay = weight_decay

    param_optimizer = list(named_params)
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    params = {
        'in_pretrained': {
            'decay': [],
            'no_decay': []
        },
        'out_pretrained': {
            'crf': [],
            'no_crf': []
        }
    }
    # Route each parameter into one of four groups: pretrained params with/without
    # weight decay, and non-pretrained params split into CRF vs. everything else.
    for n, p in param_optimizer:
        is_in_pretrained = 'in_pretrained' if 'pretrained' in n else 'out_pretrained'
        is_no_decay = 'no_decay' if any(nd in n
                                        for nd in no_decay) else 'decay'
        is_crf = 'crf' if 'crf' in n else 'no_crf'

        if is_in_pretrained == 'in_pretrained':
            params[is_in_pretrained][is_no_decay].append(p)
        else:
            params[is_in_pretrained][is_crf].append(p)

    optimizer_grouped_parameters = [
        {
            'params': params['in_pretrained']['decay'],
            'weight_decay': bert_weight_decay,
            'lr': bert_lr
        },
        {
            'params': params['in_pretrained']['no_decay'],
            'weight_decay': 0.0,
            'lr': bert_lr
        },
        {
            'params': params['out_pretrained']['crf'],
            'weight_decay': crf_weight_decay,
            'lr': crf_lr
        },
        {
            'params': params['out_pretrained']['no_crf'],
            'weight_decay': weight_decay,
            'lr': lr
        },
    ]

    from transformers import optimization
    return optimization.AdamW(optimizer_grouped_parameters,
                              lr=lr,
                              weight_decay=weight_decay,
                              **kwargs)
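
A hypothetical usage sketch; the name-based grouping above expects the BERT encoder to sit under an attribute whose name contains "pretrained" and the CRF layer under one containing "crf". The tagger class, the torchcrf import, and the hyperparameters below are illustrative assumptions:

from torch import nn
from torchcrf import CRF            # assumed CRF implementation (pytorch-crf)
from transformers import BertModel

class BertCrfTagger(nn.Module):
    """Skeleton model; only the attribute names matter for the parameter grouping."""
    def __init__(self, num_tags=9):
        super().__init__()
        self.pretrained = BertModel.from_pretrained('bert-base-uncased')
        self.classifier = nn.Linear(self.pretrained.config.hidden_size, num_tags)
        self.crf = CRF(num_tags, batch_first=True)

model = BertCrfTagger()
optimizer = BertAdamW4CRF(model.named_parameters(),
                          lr=1e-5, bert_lr=2e-5, crf_lr=1e-3,
                          weight_decay=0.01)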