Code Example #1
# Build a BertForPreTraining model wrapped in an ONNX Runtime ORTTrainer, with
# mixed precision, recompute, gradient accumulation and distributed settings
# taken from args.
def prepare_model(args, device):
    config = BertConfig.from_pretrained(args.bert_model, cache_dir=args.cache_dir)

    # config.num_hidden_layers = 12
    if args.force_num_hidden_layers:
        logger.info("Modifying model config with num_hidden_layers to %d", args.force_num_hidden_layers)
        config.num_hidden_layers = args.force_num_hidden_layers

    model = BertForPreTraining(config)
    if args.init_state_dict is not None:
        model.load_state_dict(args.init_state_dict)
    model_desc = bert_model_description(config)

    lr_scheduler = LinearWarmupLRScheduler(total_steps=int(args.max_steps), warmup=args.warmup_proportion)

    loss_scaler = amp.DynamicLossScaler() if args.fp16 else None

    options = orttrainer.ORTTrainerOptions({'batch': {
                                                'gradient_accumulation_steps': args.gradient_accumulation_steps},
                                            'device': {'id': str(device)},
                                            'mixed_precision': {
                                                'enabled': args.fp16,
                                                'loss_scaler': loss_scaler},
                                            'graph_transformer': {
                                                'attn_dropout_recompute': args.attn_dropout_recompute,
                                                'gelu_recompute': args.gelu_recompute,
                                                'transformer_layer_recompute': args.transformer_layer_recompute,
                                            },
                                            'debug': {'deterministic_compute': True, },
                                            'utils': {
                                                'grad_norm_clip': True},
                                            'distributed': {
                                                'world_rank': max(0, args.local_rank),
                                                'world_size': args.world_size,
                                                'local_rank': max(0, args.local_rank),
                                                'allreduce_post_accumulation': args.allreduce_post_accumulation,
                                                'deepspeed_zero_optimization': {'stage': args.deepspeed_zero_stage},
                                                'enable_adasum': False},
                                            'lr_scheduler': lr_scheduler
                                            })

    param_optimizer = list(model.named_parameters())
    no_decay_keys = ["bias", "gamma", "beta", "LayerNorm"]
    params = [{
        'params': [n for n, p in param_optimizer if any(no_decay_key in n for no_decay_key in no_decay_keys)],
        "alpha": 0.9, "beta": 0.999, "lambda": 0.0, "epsilon": 1e-6}, {
        'params': [n for n, p in param_optimizer if not any(no_decay_key in n for no_decay_key in no_decay_keys)],
        "alpha": 0.9, "beta": 0.999, "lambda": 0.0, "epsilon": 1e-6}]

    optim_config = optim.AdamConfig(params=params, lr=2e-5, do_bias_correction=True)
    model = orttrainer.ORTTrainer(model, model_desc, optim_config, options=options)

    return model
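
A minimal usage sketch (not part of the original example) showing how the ORTTrainer returned by prepare_model might be driven. The get_pretraining_dataloader helper is hypothetical, and the argument order assumes bert_model_description declares the inputs as (input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels).

def run_pretraining(args, device):
    # Hedged sketch: `get_pretraining_dataloader` is a hypothetical helper, not
    # part of the original source.
    trainer = prepare_model(args, device)
    for step, batch in enumerate(get_pretraining_dataloader(args)):
        input_ids, segment_ids, input_mask, mlm_labels, nsp_labels = batch
        # ORTTrainer.train_step runs forward, backward and the optimizer update
        # in a single call, using the learning rate from the configured scheduler.
        loss = trainer.train_step(input_ids, segment_ids, input_mask, mlm_labels, nsp_labels)
        if step + 1 >= args.max_steps:
            break
    return trainer
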
Code Example #2
# Rename TF-style LayerNorm parameters (gamma/beta) to the PyTorch names
# (weight/bias), load them into BertForPreTraining to validate the checkpoint,
# and return the encoder-only (model.bert) state dict.
def convert_ckpt_compatible(ckpt_path, config_path):
    ckpt = torch.load(ckpt_path, map_location='cpu')
    keys = list(ckpt.keys())
    for key in keys:
        if 'LayerNorm' in key:
            if 'gamma' in key:
                ckpt[key.replace('gamma', 'weight')] = ckpt.pop(key)
            else:
                ckpt[key.replace('beta', 'bias')] = ckpt.pop(key)

    model_config = BertConfig.from_json_file(config_path)
    model = BertForPreTraining(model_config)
    model.load_state_dict(ckpt)
    new_ckpt = model.bert.state_dict()

    return new_ckpt
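
A short usage sketch with placeholder paths (not from the original source): convert a checkpoint whose LayerNorm parameters still carry the TF names gamma/beta, then save the resulting encoder-only state dict.

import torch

# Hedged usage sketch; the file paths below are placeholders.
old_ckpt_path = "checkpoints/bert_old_style.pt"   # state dict with gamma/beta keys
config_path = "configs/bert_config.json"

bert_state_dict = convert_ckpt_compatible(old_ckpt_path, config_path)
# The returned dict holds only the encoder (model.bert) weights, so it can be
# loaded directly into a plain BertModel.
torch.save(bert_state_dict, "checkpoints/bert_encoder_converted.pt")
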
Code Example #3
def convert_multibert_checkpoint_to_pytorch(tf_checkpoint_path, config_path, save_path):
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info(f"Converting TensorFlow checkpoint from {tf_path}")

    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    config = BertConfig.from_pretrained(config_path)
    model = BertForPreTraining(config)

    layer_nums = []
    for full_name, shape in init_vars:
        array = tf.train.load_variable(tf_path, full_name)
        names.append(full_name)
        split_names = full_name.split("/")
        for name in split_names:
            if name.startswith("layer_"):
                layer_nums.append(int(name.split("_")[-1]))

        arrays.append(array)
    logger.info(f"Read a total of {len(arrays):,} layers")

    name_to_array = dict(zip(names, arrays))

    # Check that number of layers match
    assert config.num_hidden_layers == len(list(set(layer_nums)))

    state_dict = model.state_dict()

    # Need to do this explicitly as it is a buffer
    position_ids = state_dict["bert.embeddings.position_ids"]
    new_state_dict = {"bert.embeddings.position_ids": position_ids}

    # Encoder Layers
    for weight_name in names:
        pt_weight_name = weight_name.replace("kernel", "weight").replace("gamma", "weight").replace("beta", "bias")
        name_split = pt_weight_name.split("/")
        for name_idx, name in enumerate(name_split):
            if name.startswith("layer_"):
                name_split[name_idx] = name.replace("_", ".")

        if name_split[-1].endswith("embeddings"):
            name_split.append("weight")

        if name_split[0] == "cls":
            if name_split[-1] == "output_bias":
                name_split[-1] = "bias"
            if name_split[-1] == "output_weights":
                name_split[-1] = "weight"

        if name_split[-1] == "weight" and name_split[-2] == "dense":
            name_to_array[weight_name] = name_to_array[weight_name].T

        pt_weight_name = ".".join(name_split)

        new_state_dict[pt_weight_name] = torch.from_numpy(name_to_array[weight_name])

    new_state_dict["cls.predictions.decoder.weight"] = new_state_dict["bert.embeddings.word_embeddings.weight"].clone()
    new_state_dict["cls.predictions.decoder.bias"] = new_state_dict["cls.predictions.bias"].clone().T
    # Load State Dict
    model.load_state_dict(new_state_dict)

    # Save PreTrained
    logger.info(f"Saving pretrained model to {save_path}")
    model.save_pretrained(save_path)

    return model
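
A usage sketch with placeholder paths (not from the original source): convert one MultiBERTs-style TensorFlow checkpoint and reload the result through the Transformers API.

# Hedged usage sketch; paths are placeholders.
tf_ckpt = "multiberts/seed_0/bert.ckpt"                  # TF checkpoint prefix
bert_config_json = "multiberts/seed_0/bert_config.json"
output_dir = "converted/multiberts_seed_0"

convert_multibert_checkpoint_to_pytorch(tf_ckpt, bert_config_json, output_dir)
# save_pretrained writes the converted weights and config.json into output_dir,
# so the model can later be reloaded with BertForPreTraining.from_pretrained(output_dir).
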
Code Example #4
# Masked-language-model pretraining loop for a small BERT over tokenized source code.
def train(args):

    if not os.path.exists(args.save_dir):
        os.mkdir(args.save_dir)

    if args.gpu != '-1' and torch.cuda.is_available():
        device = torch.device('cuda')
        torch.cuda.set_rng_state(torch.cuda.get_rng_state())
        torch.backends.cudnn.deterministic = True
    else:
        device = torch.device('cpu')

    config = {
        'train': {
            'unchanged_variable_weight': 0.1,
            'buffer_size': 5000
        },
        'encoder': {
            'type': 'SequentialEncoder'
        },
        'data': {
            'vocab_file': 'data/vocab.bpe10000/vocab'
        }
    }

    train_set = Dataset('data/preprocessed_data/train-shard-*.tar')
    dev_set = Dataset('data/preprocessed_data/dev.tar')

    vocab = Vocab.load('data/vocab.bpe10000/vocab')

    if args.decoder:
        vocab_size = len(vocab.all_subtokens) + 1
    else:
        vocab_size = len(vocab.source_tokens) + 1

    max_iters = args.max_iters
    lr = args.lr
    warm_up = args.warm_up

    batch_size = 4096
    effective_batch_size = args.batch_size

    max_embeds = 1000 if args.decoder else 512

    bert_config = BertConfig(vocab_size=vocab_size,
                             max_position_embeddings=max_embeds,
                             num_hidden_layers=6,
                             hidden_size=256,
                             num_attention_heads=4)
    model = BertForPreTraining(bert_config)

    if args.restore:
        state_dict = torch.load(os.path.join(args.save_dir, args.res_name))
        model.load_state_dict(state_dict['model'])
        batch_count = state_dict['step']
        epoch = state_dict['epoch']

    model.train()
    model.to(device)

    if len(args.gpu) > 1 and device == torch.device('cuda'):
        model = nn.DataParallel(model)

    def lr_func(step):
        if step > warm_up:
            return (max_iters - step) / (max_iters - warm_up)
        else:
            return (step / warm_up)

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=lr,
                                 eps=1e-6,
                                 weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                  lr_lambda=lr_func,
                                                  last_epoch=-1)
    loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-100, reduction='none')

    if args.restore:
        optimizer.load_state_dict(state_dict['optim'])
        scheduler.load_state_dict(state_dict['scheduler'])

    # Keep the restored step/epoch counters when resuming from a checkpoint.
    if not args.restore:
        batch_count = 0
        epoch = 0
    cum_loss = 0.0

    while True:
        # load training dataset, which is a collection of ASTs and maps of gold-standard renamings
        train_set_iter = train_set.batch_iterator(
            batch_size=batch_size,
            return_examples=False,
            config=config,
            progress=True,
            train=True,
            max_seq_len=512,
            num_readers=args.num_readers,
            num_batchers=args.num_batchers)
        epoch += 1
        print("Epoch {}".format(epoch))

        loss = 0
        num_seq = 0

        optimizer.zero_grad()

        for batch in train_set_iter:
            if args.decoder:
                input_ids = batch.tensor_dict['prediction_target'][
                    'src_with_true_var_names']
            else:
                input_ids = batch.tensor_dict['src_code_tokens']

            # Token id 0 is padding; mask it out of the attention computation.
            attention_mask = torch.ones_like(input_ids)
            attention_mask[input_ids == 0] = 0

            assert torch.max(input_ids) < vocab_size
            assert torch.min(input_ids) >= 0

            # Skip batches whose sequence length exceeds the position-embedding limit.
            if input_ids.shape[1] > max_embeds:
                print(
                    "Warning - length {} is greater than max length {}. Skipping."
                    .format(input_ids.shape[1], max_embeds))
                continue

            input_ids, labels = mask_tokens(inputs=input_ids,
                                            mask_token_id=vocab_size - 1,
                                            vocab_size=vocab_size,
                                            mlm_probability=0.15)

            input_ids[attention_mask == 0] = 0
            labels[attention_mask == 0] = -100

            if torch.cuda.is_available():
                input_ids = input_ids.cuda()
                labels = labels.cuda()
                attention_mask = attention_mask.cuda()

            outputs = model(input_ids=input_ids,
                            attention_mask=attention_mask,
                            masked_lm_labels=labels)

            unreduced_loss = loss_fn(
                outputs[0].view(-1, bert_config.vocab_size),
                labels.view(-1)).reshape(labels.shape) / (
                    torch.sum(labels != -100, axis=1).unsqueeze(1) + 1e-7)
            loss += unreduced_loss.sum()
            num_seq += input_ids.shape[0]

            if num_seq > effective_batch_size:
                batch_count += 1
                loss /= num_seq
                cum_loss += loss.item()

                if batch_count % 20 == 0:
                    print("{} batches, Loss : {:.4}, LR : {:.6}".format(
                        batch_count, cum_loss / 20,
                        scheduler.get_lr()[0]))
                    cum_loss = 0.0

                if batch_count % 10000 == 0:
                    fname1 = os.path.join(
                        args.save_dir, 'bert_{}_step_{}.pth'.format(
                            ('decoder' if args.decoder else 'encoder'),
                            batch_count))
                    fname2 = os.path.join(
                        args.save_dir, 'bert_{}.pth'.format(
                            'decoder' if args.decoder else 'encoder'))

                    state = {
                        'epoch': epoch,
                        'step': batch_count,
                        'model': (model.module.state_dict()
                                  if isinstance(model, nn.DataParallel)
                                  else model.state_dict()),
                        'optim': optimizer.state_dict(),
                        'scheduler': scheduler.state_dict()
                    }

                    torch.save(state, fname1)
                    torch.save(state, fname2)

                    print("Saved file to path {}".format(fname1))
                    print("Saved file to path {}".format(fname2))

                loss.backward()
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

                loss = 0
                num_seq = 0

            if batch_count == max_iters:
                print(f'[Learner] Reached max iters', file=sys.stderr)
                exit()

        print("Max_len = {}".format(max_len))
        break
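
train() depends on a mask_tokens helper that is not shown in this example. Below is a minimal sketch of a standard BERT-style MLM masking routine with the same signature, assuming the usual 80/10/10 replacement split; the actual helper in the source project may differ.

# Hedged sketch of the mask_tokens helper referenced above (assumed, not taken
# from the original source): standard BERT masked-language-model corruption.
def mask_tokens(inputs, mask_token_id, vocab_size, mlm_probability=0.15):
    labels = inputs.clone()

    # Choose positions to corrupt with probability mlm_probability.
    masked_indices = torch.bernoulli(
        torch.full(labels.shape, mlm_probability)).bool()
    labels[~masked_indices] = -100  # loss is only computed on masked positions

    corrupted = inputs.clone()
    # 80% of the chosen positions become the [MASK] token.
    indices_replaced = torch.bernoulli(
        torch.full(labels.shape, 0.8)).bool() & masked_indices
    corrupted[indices_replaced] = mask_token_id

    # 10% become a random token; the remaining 10% are left unchanged.
    indices_random = (torch.bernoulli(torch.full(labels.shape, 0.5)).bool()
                      & masked_indices & ~indices_replaced)
    random_words = torch.randint(vocab_size, labels.shape, dtype=corrupted.dtype)
    corrupted[indices_random] = random_words[indices_random]

    return corrupted, labels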