Example #1
from typing import List, Optional

from transformers import BertModel


def init_bert_model_with_teacher(
    student: BertModel,
    teacher: BertModel,
    layers_to_transfer: Optional[List[int]] = None,
) -> BertModel:
    """Initialize student model with teacher layers.

    Args:
        student (BertModel): Student model.
        teacher (BertModel): Teacher model.
        layers_to_transfer (List[int], optional): Indices of the teacher layers
            to transfer. If None, the last ``student_layers_num`` layers of the
            teacher are transferred. Defaults to None.

    Returns:
        BertModel: Student model initialized with the selected teacher layers.
    """
    teacher_hidden_size = teacher.config.hidden_size
    student_hidden_size = student.config.hidden_size
    if teacher_hidden_size != student_hidden_size:
        raise ValueError("Teacher and student hidden sizes must be equal")
    teacher_layers_num = teacher.config.num_hidden_layers
    student_layers_num = student.config.num_hidden_layers

    if layers_to_transfer is None:
        layers_to_transfer = list(
            range(teacher_layers_num - student_layers_num, teacher_layers_num))

    prefix_teacher = list(teacher.state_dict().keys())[0].split(".")[0]
    prefix_student = list(student.state_dict().keys())[0].split(".")[0]
    student_sd = _extract_layers(
        teacher_model=teacher,
        layers=layers_to_transfer,
        prefix_teacher=prefix_teacher,
        prefix_student=prefix_student,
    )
    student.load_state_dict(student_sd)
    return student
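
The default layer selection deserves a quick illustration. A minimal sketch, assuming a 12-layer teacher and a 6-layer student (the numbers are placeholders): when layers_to_transfer is None, the last student_layers_num teacher layers are chosen.

teacher_layers_num, student_layers_num = 12, 6
layers_to_transfer = list(
    range(teacher_layers_num - student_layers_num, teacher_layers_num))
print(layers_to_transfer)  # [6, 7, 8, 9, 10, 11]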
Example #2

def convert_tf2_checkpoint_to_pytorch(tf_checkpoint_path, config_path, pytorch_dump_path):
    # Instantiate model
    logger.info(f"Loading model based on config from {config_path}...")
    config = BertConfig.from_json_file(config_path)
    model = BertModel(config)

    # Load weights from checkpoint
    logger.info(f"Loading weights from checkpoint {tf_checkpoint_path}...")
    load_tf2_weights_in_bert(model, tf_checkpoint_path, config)

    # Save PyTorch model
    logger.info(f"Saving PyTorch model to {pytorch_dump_path}...")
    torch.save(model.state_dict(), pytorch_dump_path)
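
A hypothetical invocation is sketched below; the paths are placeholders, and load_tf2_weights_in_bert must be defined in the same module.

convert_tf2_checkpoint_to_pytorch(
    tf_checkpoint_path="path/to/tf2_checkpoint",
    config_path="path/to/bert_config.json",
    pytorch_dump_path="path/to/pytorch_model.bin",
)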
Example #3

def check_compatibility(torch_model: BertModel, tf_model: TFBertModel):
    torch_weights = []
    for k, v in torch_model.state_dict().items():
        if k == "embeddings.position_ids":
            print("im here")
            continue
        if not k.startswith("embeddings.") and k.endswith(".weight"):
            torch_weights.append(v.t().numpy())
        else:
            torch_weights.append(v.numpy())
    torch_weights[1], torch_weights[2] = torch_weights[2], torch_weights[1]

    tf_weights = tf_model.get_weights()

    check = [(torch_weight == tf_weight).all()
             for torch_weight, tf_weight in zip(torch_weights, tf_weights)]
    return all(check)
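
A hedged usage sketch: load the same pretrained weights into both frameworks and compare them (the model name is only an example, and the weight ordering the function assumes may differ across transformers versions).

from transformers import BertModel, TFBertModel

torch_model = BertModel.from_pretrained("bert-base-uncased")
tf_model = TFBertModel.from_pretrained("bert-base-uncased")
print(check_compatibility(torch_model, tf_model))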
Example #4

def train(config, bert_config, train_path, dev_path, rel2id, id2rel,
          tokenizer):
    os.makedirs(config.output_dir, exist_ok=True)
    num_rels = len(rel2id)  # assumed: one relation id per relation type
    if os.path.exists('./data/train_file.pkl'):
        with open('./data/train_file.pkl', mode='rb') as f:
            train_data = pickle.load(f)
    else:
        train_data = data.load_data(train_path, tokenizer, rel2id, num_rels)
        with open('./data/train_file.pkl', mode='wb') as f:
            pickle.dump(train_data, f)
    with open(dev_path) as f:
        dev_data = json.load(f)
    for sent in dev_data:
        data.to_tuple(sent)
    data_manager = data.SPO(train_data)
    train_sampler = RandomSampler(data_manager)
    train_data_loader = DataLoader(data_manager,
                                   sampler=train_sampler,
                                   batch_size=config.batch_size,
                                   drop_last=True)
    num_train_steps = int(
        len(data_manager) / config.batch_size) * config.max_epoch

    if config.bert_pretrained_model is not None:
        logger.info('load bert weight')
        Bert_model = BertModel.from_pretrained(config.bert_pretrained_model,
                                               config=bert_config)
    else:
        logger.info('random initialize bert model')
        # init_weights() returns None, so it must not be chained on the
        # constructor; BertModel already initializes its weights in __init__.
        Bert_model = BertModel(config=bert_config)
    Bert_model.to(device)
    submodel = sub_model(config).to(device)
    objmodel = obj_model(config).to(device)

    loss_func = nn.BCELoss(reduction='none')
    params = list(Bert_model.parameters()) + list(
        submodel.parameters()) + list(objmodel.parameters())
    optimizer = AdamW(params, lr=config.lr)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(data_manager))
    logger.info("  Num Epochs = %d", config.max_epoch)
    logger.info("  Total train batch size = %d", config.batch_size)
    logger.info("  Total optimization steps = %d", num_train_steps)
    logger.info("  Logging steps = %d", config.print_freq)
    logger.info("  Save steps = %d", config.save_freq)

    global_step = 0
    Bert_model.train()
    submodel.train()
    objmodel.train()

    for epoch in range(config.max_epoch):
        optimizer.zero_grad()
        epoch_iterator = tqdm(train_data_loader, disable=None)
        for step, batch in enumerate(epoch_iterator):
            batch = tuple(t.to(device) for t in batch)
            input_ids, segment_ids, input_masks, sub_positions, sub_heads, sub_tails, obj_heads, obj_tails = batch

            bert_output = Bert_model(input_ids,
                                     attention_mask=input_masks,
                                     token_type_ids=segment_ids)[0]
            pred_sub_heads, pred_sub_tails = submodel(
                bert_output)  # [batch_size, seq_len, 1]
            pred_obj_heads, pred_obj_tails = objmodel(bert_output,
                                                      sub_positions)

            # compute the losses
            mask = input_masks.view(-1)

            # loss1
            sub_heads = sub_heads.unsqueeze(-1)  # [batch_size, seq_len, 1]
            sub_tails = sub_tails.unsqueeze(-1)

            loss1_head = loss_func(pred_sub_heads, sub_heads).view(-1)
            loss1_head = torch.sum(loss1_head * mask) / torch.sum(mask)

            loss1_tail = loss_func(pred_sub_tails, sub_tails).view(-1)
            loss1_tail = torch.sum(loss1_tail * mask) / torch.sum(mask)

            loss1 = loss1_head + loss1_tail

            # loss2
            loss2_head = loss_func(pred_obj_heads,
                                   obj_heads).view(-1, obj_heads.shape[-1])
            loss2_head = torch.sum(
                loss2_head * mask.unsqueeze(-1)) / torch.sum(mask)

            loss2_tail = loss_func(pred_obj_tails,
                                   obj_tails).view(-1, obj_tails.shape[-1])
            loss2_tail = torch.sum(
                loss2_tail * mask.unsqueeze(-1)) / torch.sum(mask)

            loss2 = loss2_head + loss2_tail

            # optimize
            loss = loss1 + loss2
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            global_step += 1
            if (global_step + 1) % config.print_freq == 0:
                logger.info(
                    "epoch: {} step: {} #### loss1: {}  loss2: {}".format(
                        epoch, global_step + 1,
                        loss1.item(),
                        loss2.item()))

            if (global_step + 1) % config.eval_freq == 0:
                logger.info("***** Running evaluating *****")
                with torch.no_grad():
                    Bert_model.eval()
                    submodel.eval()
                    objmodel.eval()
                    P, R, F1 = utils.metric(Bert_model, submodel, objmodel,
                                            dev_data, id2rel, tokenizer)
                    logger.info(f'precision:{P}\nrecall:{R}\nF1:{F1}')
                Bert_model.train()
                submodel.train()
                objmodel.train()

            if (global_step + 1) % config.save_freq == 0:
                # Save a trained model
                model_name = "pytorch_model_%d" % (global_step + 1)
                output_model_file = os.path.join(config.output_dir, model_name)
                state = {
                    'bert_state_dict': Bert_model.state_dict(),
                    'subject_state_dict': submodel.state_dict(),
                    'object_state_dict': objmodel.state_dict(),
                }
                torch.save(state, output_model_file)

    model_name = "pytorch_model_last"
    output_model_file = os.path.join(config.output_dir, model_name)
    state = {
        'bert_state_dict': Bert_model.state_dict(),
        'subject_state_dict': submodel.state_dict(),
        'object_state_dict': objmodel.state_dict(),
    }
    torch.save(state, output_model_file)
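
For completeness, a minimal sketch of restoring the three state dicts saved above, assuming the same Bert_model, submodel, and objmodel objects:

checkpoint = torch.load(os.path.join(config.output_dir, "pytorch_model_last"),
                        map_location=device)
Bert_model.load_state_dict(checkpoint['bert_state_dict'])
submodel.load_state_dict(checkpoint['subject_state_dict'])
objmodel.load_state_dict(checkpoint['object_state_dict'])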
Example #5
def convert_pytorch_checkpoint_to_tf(model: BertModel, ckpt_dir: str,
                                     model_name: str):
    """
    :param model: BertModel PyTorch model instance to be converted
    :param ckpt_dir: TensorFlow checkpoint directory
    :param model_name: model name
    :return:

    Currently supported HF models:
        Y BertModel
        N BertForMaskedLM
        N BertForPreTraining
        N BertForMultipleChoice
        N BertForNextSentencePrediction
        N BertForSequenceClassification
        N BertForQuestionAnswering
    """

    tensors_to_transpose = ("dense.weight", "attention.self.query",
                            "attention.self.key", "attention.self.value")

    var_map = (
        ("layer.", "layer_"),
        ("word_embeddings.weight", "word_embeddings"),
        ("position_embeddings.weight", "position_embeddings"),
        ("token_type_embeddings.weight", "token_type_embeddings"),
        (".", "/"),
        ("LayerNorm/weight", "LayerNorm/gamma"),
        ("LayerNorm/bias", "LayerNorm/beta"),
        ("weight", "kernel"),
    )

    if not os.path.isdir(ckpt_dir):
        os.makedirs(ckpt_dir)

    state_dict = model.state_dict()

    def to_tf_var_name(name: str):
        for patt, repl in var_map:
            name = name.replace(patt, repl)
        return "bert/{}".format(name)

    def create_tf_var(tensor: np.ndarray, name: str, session: tf.Session):
        tf_dtype = tf.dtypes.as_dtype(tensor.dtype)
        tf_var = tf.get_variable(dtype=tf_dtype,
                                 shape=tensor.shape,
                                 name=name,
                                 initializer=tf.zeros_initializer())
        session.run(tf.variables_initializer([tf_var]))
        session.run(tf_var)
        return tf_var

    tf.reset_default_graph()
    with tf.Session() as session:
        for var_name in state_dict:
            tf_name = to_tf_var_name(var_name)
            torch_tensor = state_dict[var_name].numpy()
            if any(x in var_name for x in tensors_to_transpose):
                torch_tensor = torch_tensor.T
            tf_var = create_tf_var(tensor=torch_tensor,
                                   name=tf_name,
                                   session=session)
            tf.keras.backend.set_value(tf_var, torch_tensor)
            tf_weight = session.run(tf_var)
            print("Successfully created {}: {}".format(
                tf_name, np.allclose(tf_weight, torch_tensor)))

        saver = tf.train.Saver(tf.trainable_variables())
        saver.save(
            session,
            os.path.join(ckpt_dir,
                         model_name.replace("-", "_") + ".ckpt"))
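
A hypothetical invocation follows. Since the function relies on the TF1-style graph API (tf.Session, tf.get_variable), running it on TensorFlow 2 would typically require importing tensorflow.compat.v1 as tf and calling tf.disable_v2_behavior() first.

from transformers import BertModel

model = BertModel.from_pretrained("bert-base-uncased")
convert_pytorch_checkpoint_to_tf(model=model,
                                 ckpt_dir="tf_ckpt",
                                 model_name="bert-base-uncased")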
Example #6
def main():
    parser = argparse.ArgumentParser(
        description='Train the individual Transformer model')
    parser.add_argument('--dataset_folder', type=str, default='datasets')
    parser.add_argument('--dataset_name', type=str, default='zara1')
    parser.add_argument('--obs', type=int, default=8)
    parser.add_argument('--preds', type=int, default=12)
    parser.add_argument('--emb_size', type=int, default=1024)
    parser.add_argument('--heads', type=int, default=8)
    parser.add_argument('--layers', type=int, default=6)
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--cpu', action='store_true')
    parser.add_argument('--output_folder', type=str, default='Output')
    parser.add_argument('--val_size', type=int, default=50)
    parser.add_argument('--gpu_device', type=str, default="0")
    parser.add_argument('--verbose', action='store_true')
    parser.add_argument('--max_epoch', type=int, default=100)
    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--validation_epoch_start', type=int, default=30)
    parser.add_argument('--resume_train', action='store_true')
    parser.add_argument('--delim', type=str, default='\t')
    parser.add_argument('--name', type=str, default="zara1")

    args = parser.parse_args()
    model_name = args.name

    # os.makedirs with exist_ok=True replaces the chain of bare try/except
    # blocks around os.mkdir and creates intermediate directories as needed.
    os.makedirs(f'models/BERT/{args.name}', exist_ok=True)
    os.makedirs(f'output/BERT/{args.name}', exist_ok=True)

    log = SummaryWriter('logs/BERT_%s' % model_name)

    log.add_scalar('eval/mad', 0, 0)
    log.add_scalar('eval/fad', 0, 0)

    os.makedirs(args.name, exist_ok=True)

    device = torch.device("cuda")
    if args.cpu or not torch.cuda.is_available():
        device = torch.device("cpu")

    args.verbose = True  # forces verbose dataset loading regardless of the CLI flag

    ## creation of the dataloaders for train and validation
    train_dataset, _ = baselineUtils.create_dataset(args.dataset_folder,
                                                    args.dataset_name,
                                                    0,
                                                    args.obs,
                                                    args.preds,
                                                    delim=args.delim,
                                                    train=True,
                                                    verbose=args.verbose)
    val_dataset, _ = baselineUtils.create_dataset(args.dataset_folder,
                                                  args.dataset_name,
                                                  0,
                                                  args.obs,
                                                  args.preds,
                                                  delim=args.delim,
                                                  train=False,
                                                  verbose=args.verbose)
    test_dataset, _ = baselineUtils.create_dataset(args.dataset_folder,
                                                   args.dataset_name,
                                                   0,
                                                   args.obs,
                                                   args.preds,
                                                   delim=args.delim,
                                                   train=False,
                                                   eval=True,
                                                   verbose=args.verbose)

    from transformers import BertTokenizer, BertModel, BertForMaskedLM, BertConfig, AdamW

    config = BertConfig(vocab_size=30522,
                        hidden_size=768,
                        num_hidden_layers=12,
                        num_attention_heads=12,
                        intermediate_size=3072,
                        hidden_act='relu',
                        hidden_dropout_prob=0.1,
                        attention_probs_dropout_prob=0.1,
                        max_position_embeddings=512,
                        type_vocab_size=2,
                        initializer_range=0.02,
                        layer_norm_eps=1e-12)
    model = BertModel(config).to(device)

    from individual_TF import LinearEmbedding as NewEmbed, Generator as GeneratorTS
    # Replace BERT's token-embedding layer with a linear projection of the
    # 3-dimensional trajectory input, so continuous features can be passed
    # through the input_ids argument.
    emb = NewEmbed(3, 768).to(device)
    model.set_input_embeddings(emb)
    generator = GeneratorTS(768, 2).to(device)
    #model.set_output_embeddings(GeneratorTS(1024,2))

    tr_dl = torch.utils.data.DataLoader(train_dataset,
                                        batch_size=args.batch_size,
                                        shuffle=True,
                                        num_workers=0)
    val_dl = torch.utils.data.DataLoader(val_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=0)
    test_dl = torch.utils.data.DataLoader(test_dataset,
                                          batch_size=args.batch_size,
                                          shuffle=False,
                                          num_workers=0)

    #optim = SGD(list(a.parameters())+list(model.parameters())+list(generator.parameters()),lr=0.01)
    #sched=torch.optim.lr_scheduler.StepLR(optim,0.0005)
    optim = NoamOpt(
        768, 0.1, len(tr_dl),
        torch.optim.Adam(list(emb.parameters()) + list(model.parameters()) +
                         list(generator.parameters()),
                         lr=0,
                         betas=(0.9, 0.98),
                         eps=1e-9))
    #optim=Adagrad(list(a.parameters())+list(model.parameters())+list(generator.parameters()),lr=0.01,lr_decay=0.001)
    epoch = 0

    # Normalization is effectively disabled here: mean is forced to 0 and std to 1.
    mean = train_dataset[:]['src'][:, :, 2:4].mean((0, 1)) * 0
    std = train_dataset[:]['src'][:, :, 2:4].std((0, 1)) * 0 + 1

    while epoch < args.max_epoch:
        epoch_loss = 0
        model.train()

        for id_b, batch in enumerate(tr_dl):

            optim.optimizer.zero_grad()
            r = 0  # rotation angle; 0 makes rot_mat the identity (augmentation hook)
            rot_mat = np.array([[np.cos(r), np.sin(r)],
                                [-np.sin(r), np.cos(r)]])

            inp = ((batch['src'][:, :, 2:4] - mean) / std).to(device)
            inp = torch.matmul(inp,
                               torch.from_numpy(rot_mat).float().to(device))
            trg_masked = torch.zeros((inp.shape[0], args.preds, 2)).to(device)
            inp_cls = torch.ones(inp.shape[0], inp.shape[1], 1).to(device)
            trg_cls = torch.zeros(trg_masked.shape[0], trg_masked.shape[1],
                                  1).to(device)
            inp_cat = torch.cat((inp, trg_masked), 1)
            cls_cat = torch.cat((inp_cls, trg_cls), 1)
            net_input = torch.cat((inp_cat, cls_cat), 2)

            position = torch.arange(0, net_input.shape[1]).repeat(
                inp.shape[0], 1).long().to(device)
            token = torch.zeros(
                (inp.shape[0], net_input.shape[1])).long().to(device)
            attention_mask = torch.ones(
                (inp.shape[0], net_input.shape[1])).long().to(device)

            out = model(input_ids=net_input,
                        position_ids=position,
                        token_type_ids=token,
                        attention_mask=attention_mask)

            pred = generator(out[0])

            loss = F.pairwise_distance(
                pred[:, :].contiguous().view(-1, 2),
                torch.matmul(
                    torch.cat(
                        (batch['src'][:, :, 2:4], batch['trg'][:, :, 2:4]),
                        1).contiguous().view(-1, 2).to(device),
                    torch.from_numpy(rot_mat).float().to(device))).mean()
            loss.backward()
            optim.step()
            print("epoch %03i/%03i  frame %04i / %04i loss: %7.4f" %
                  (epoch, args.max_epoch, id_b, len(tr_dl), loss.item()))
            epoch_loss += loss.item()
        #sched.step()
        log.add_scalar('Loss/train', epoch_loss / len(tr_dl), epoch)
        with torch.no_grad():
            model.eval()

            gt = []
            pr = []
            val_loss = 0
            for batch in val_dl:
                inp = ((batch['src'][:, :, 2:4] - mean) / std).to(device)
                trg_masked = torch.zeros(
                    (inp.shape[0], args.preds, 2)).to(device)
                inp_cls = torch.ones(inp.shape[0], inp.shape[1], 1).to(device)
                trg_cls = torch.zeros(trg_masked.shape[0], trg_masked.shape[1],
                                      1).to(device)
                inp_cat = torch.cat((inp, trg_masked), 1)
                cls_cat = torch.cat((inp_cls, trg_cls), 1)
                net_input = torch.cat((inp_cat, cls_cat), 2)

                position = torch.arange(0, net_input.shape[1]).repeat(
                    inp.shape[0], 1).long().to(device)
                token = torch.zeros(
                    (inp.shape[0], net_input.shape[1])).long().to(device)
                # Use the same all-ones mask as in training (an all-zeros mask
                # happens to be equivalent here, since a constant additive bias
                # cancels in the softmax).
                attention_mask = torch.ones(
                    (inp.shape[0], net_input.shape[1])).long().to(device)

                out = model(input_ids=net_input,
                            position_ids=position,
                            token_type_ids=token,
                            attention_mask=attention_mask)

                pred = generator(out[0])

                loss = F.pairwise_distance(
                    pred[:, :].contiguous().view(-1, 2),
                    torch.cat(
                        (batch['src'][:, :, 2:4], batch['trg'][:, :, 2:4]),
                        1).contiguous().view(-1, 2).to(device)).mean()
                val_loss += loss.item()

                gt_b = batch['trg'][:, :, 0:2]
                preds_tr_b = pred[:, args.obs:].cumsum(1).to(
                    'cpu').detach() + batch['src'][:, -1:, 0:2]
                gt.append(gt_b)
                pr.append(preds_tr_b)

            gt = np.concatenate(gt, 0)
            pr = np.concatenate(pr, 0)
            mad, fad, errs = baselineUtils.distance_metrics(gt, pr)
            log.add_scalar('validation/loss', val_loss / len(val_dl), epoch)
            log.add_scalar('validation/mad', mad, epoch)
            log.add_scalar('validation/fad', fad, epoch)

            model.eval()

            gt = []
            pr = []
            for batch in test_dl:
                inp = ((batch['src'][:, :, 2:4] - mean) / std).to(device)
                trg_masked = torch.zeros(
                    (inp.shape[0], args.preds, 2)).to(device)
                inp_cls = torch.ones(inp.shape[0], inp.shape[1], 1).to(device)
                trg_cls = torch.zeros(trg_masked.shape[0], trg_masked.shape[1],
                                      1).to(device)
                inp_cat = torch.cat((inp, trg_masked), 1)
                cls_cat = torch.cat((inp_cls, trg_cls), 1)
                net_input = torch.cat((inp_cat, cls_cat), 2)

                position = torch.arange(0, net_input.shape[1]).repeat(
                    inp.shape[0], 1).long().to(device)
                token = torch.zeros(
                    (inp.shape[0], net_input.shape[1])).long().to(device)
                attention_mask = torch.ones(
                    (inp.shape[0], net_input.shape[1])).long().to(device)

                out = model(input_ids=net_input,
                            position_ids=position,
                            token_type_ids=token,
                            attention_mask=attention_mask)

                pred = generator(out[0])

                gt_b = batch['trg'][:, :, 0:2]
                preds_tr_b = pred[:, args.obs:].cumsum(1).to(
                    'cpu').detach() + batch['src'][:, -1:, 0:2]
                gt.append(gt_b)
                pr.append(preds_tr_b)

            gt = np.concatenate(gt, 0)
            pr = np.concatenate(pr, 0)
            mad, fad, errs = baselineUtils.distance_metrics(gt, pr)

            torch.save(model.state_dict(),
                       "models/BERT/%s/ep_%03i.pth" % (args.name, epoch))
            torch.save(generator.state_dict(),
                       "models/BERT/%s/gen_%03i.pth" % (args.name, epoch))
            torch.save(emb.state_dict(),
                       "models/BERT/%s/emb_%03i.pth" % (args.name, epoch))

            log.add_scalar('eval/mad', mad, epoch)
            log.add_scalar('eval/fad', fad, epoch)

        epoch += 1

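
The script depends on a NoamOpt learning-rate wrapper that is not shown here. Below is a minimal sketch in the style of the Annotated Transformer; the actual class may differ, so the interface (.optimizer attribute, .step() method) is an assumption.

class NoamOpt:
    """Warmup-then-decay learning-rate schedule wrapped around an optimizer."""

    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer
        self._step = 0
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size

    def rate(self, step=None):
        # lr = factor * d_model^-0.5 * min(step^-0.5, step * warmup^-1.5)
        if step is None:
            step = self._step
        return self.factor * (self.model_size ** (-0.5) *
                              min(step ** (-0.5), step * self.warmup ** (-1.5)))

    def step(self):
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p['lr'] = rate
        self.optimizer.step()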
Example #7
class RenamingModelHybrid(nn.Module):
    def __init__(self, vocab, top_k, config, device):
        super().__init__()

        self.vocab = vocab
        self.top_k = top_k
        self.source_vocab_size = len(self.vocab.source_tokens) + 1

        self.graph_encoder = GraphASTEncoder.build(
            config['encoder']['graph_encoder'])
        self.graph_emb_size = config['encoder']['graph_encoder']['gnn'][
            'hidden_size']
        self.emb_size = 256

        state_dict = torch.load(
            'saved_checkpoints/bert_2604/bert_pretrained_epoch_23_batch_140000.pth',
            map_location=device)

        keys_to_delete = [
            "cls.predictions.bias", "cls.predictions.transform.dense.weight",
            "cls.predictions.transform.dense.bias",
            "cls.predictions.transform.LayerNorm.weight",
            "cls.predictions.transform.LayerNorm.bias",
            "cls.predictions.decoder.weight", "cls.predictions.decoder.bias",
            "cls.seq_relationship.weight", "cls.seq_relationship.bias"
        ]

        from collections import OrderedDict
        new_state_dict = OrderedDict()
        for k, v in state_dict['model'].items():
            if k in keys_to_delete: continue
            name = k[5:]  # remove `bert.`
            new_state_dict[name] = v

        bert_config = BertConfig(vocab_size=self.source_vocab_size,
                                 max_position_embeddings=512,
                                 num_hidden_layers=6,
                                 hidden_size=self.emb_size,
                                 num_attention_heads=4)
        self.bert_encoder = BertModel(bert_config)
        self.bert_encoder.load_state_dict(new_state_dict)

        self.target_vocab_size = len(self.vocab.all_subtokens) + 1

        bert_config = BertConfig(vocab_size=self.target_vocab_size,
                                 max_position_embeddings=1000,
                                 num_hidden_layers=6,
                                 hidden_size=self.emb_size,
                                 num_attention_heads=4,
                                 is_decoder=True)
        self.bert_decoder = BertModel(bert_config)

        state_dict = torch.load(
            'saved_checkpoints/bert_0905/bert_decoder_epoch_19_batch_220000.pth',
            map_location=device)

        new_state_dict = OrderedDict()  # OrderedDict was already imported above
        for k, v in state_dict['model'].items():
            if k in keys_to_delete: continue
            if 'crossattention' in k: continue
            name = k[5:]  # remove `bert.`
            new_state_dict[name] = v

        for key in new_state_dict:
            self.bert_decoder.state_dict()[key].copy_(new_state_dict[key])

        self.enc_graph_map = nn.Linear(self.emb_size + self.graph_emb_size,
                                       self.emb_size)
        self.fc_final = nn.Linear(self.emb_size, self.target_vocab_size)

        self.fc_final.weight.data = state_dict['model'][
            'cls.predictions.decoder.weight']

    def forward(self, src_tokens, src_mask, variable_ids, target_tokens,
                graph_input):
        encoder_attention_mask = torch.ones_like(src_tokens).float().to(
            src_tokens.device)
        encoder_attention_mask[src_tokens == PAD_ID] = 0.0

        assert torch.max(src_tokens) < self.source_vocab_size
        assert torch.min(src_tokens) >= 0
        assert torch.max(target_tokens) < self.target_vocab_size
        assert torch.min(target_tokens) >= 0

        encoder_output = self.bert_encoder(
            input_ids=src_tokens, attention_mask=encoder_attention_mask)[0]

        graph_output = self.graph_encoder(graph_input)
        variable_emb = graph_output['variable_encoding']

        graph_embedding = torch.gather(
            variable_emb, 1,
            variable_ids.unsqueeze(2).repeat(
                1, 1, variable_emb.shape[2])) * src_mask.unsqueeze(2)

        full_enc_output = self.enc_graph_map(
            torch.cat((encoder_output, graph_embedding), dim=2))

        decoder_attention_mask = torch.ones_like(target_tokens).float().to(
            target_tokens.device)
        decoder_attention_mask[target_tokens == PAD_ID] = 0.0

        decoder_output = self.bert_decoder(
            input_ids=target_tokens,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=full_enc_output,
            encoder_attention_mask=encoder_attention_mask)[0]

        predictions = self.fc_final(decoder_output)

        return predictions

    def predict(self,
                src_tokens,
                src_mask,
                variable_ids,
                graph_input,
                approx=False):
        end_token = self.vocab.all_subtokens.word2id['</s>']
        start_token = self.vocab.all_subtokens.word2id['<s>']
        batch_size = src_tokens.shape[0]

        encoder_attention_mask = torch.ones_like(src_tokens).float().to(
            src_tokens.device)
        encoder_attention_mask[src_tokens == PAD_ID] = 0.0

        assert torch.max(src_tokens) < self.source_vocab_size
        assert torch.min(src_tokens) >= 0

        encoder_output = self.bert_encoder(
            input_ids=src_tokens, attention_mask=encoder_attention_mask)[0]

        graph_output = self.graph_encoder(graph_input)
        variable_emb = graph_output['variable_encoding']

        graph_embedding = torch.gather(
            variable_emb, 1,
            variable_ids.unsqueeze(2).repeat(
                1, 1, variable_emb.shape[2])) * src_mask.unsqueeze(2)

        full_enc_output = self.enc_graph_map(
            torch.cat((encoder_output, graph_embedding), dim=2))

        source_vocab_to_target = {
            self.vocab.source_tokens.word2id[t]:
            self.vocab.all_subtokens.word2id[t]
            for t in self.vocab.source_tokens.word2id.keys()
        }
        src_target_maps = []
        confidences = []

        for i in range(batch_size):

            if src_tokens[i][0] != start_token:
                # long dtype so token.item() yields ints for the vocab lookup
                input_sequence = torch.zeros(
                    src_tokens.shape[1] + 1).long().to(src_tokens.device)
                input_mask = torch.zeros(src_mask.shape[1] + 1).to(
                    src_mask.device)
                input_sequence[1:] = src_tokens[i]
                input_mask[1:] = src_mask[i]
            else:
                input_sequence = src_tokens[i]
                input_mask = src_mask[i]

            num_vars = int(input_mask.sum())
            seq_len = torch.sum((input_sequence != PAD_ID).long())
            generated_seqs = torch.zeros(1, min(
                seq_len + 10 * num_vars, 1000)).long().to(src_tokens.device)

            source_marker = 0
            gen_markers = torch.LongTensor([0]).to(generated_seqs.device)
            prior_probs = torch.FloatTensor([0]).to(generated_seqs.device)

            candidate_maps = [{}]

            for _ in range(num_vars):
                # Filling up the known (non-identifier) tokens
                while source_marker < seq_len and input_mask[
                        source_marker] != 1:
                    token = input_sequence[source_marker]
                    values = source_vocab_to_target[token.item(
                    )] * torch.ones_like(gen_markers).to(generated_seqs.device)

                    generated_seqs = torch.scatter(generated_seqs, 1,
                                                   gen_markers.unsqueeze(1),
                                                   values.unsqueeze(1))

                    source_marker += 1
                    gen_markers += 1

                if source_marker >= seq_len: break

                curr_var = input_sequence[source_marker].item()

                if curr_var in candidate_maps[0]:
                    if approx is True:
                        source_marker += 1
                        continue
                    # If we've seen this variable before, just use the previous predictions and update the scores
                    # Note - it's enough to check candidate_maps[0] because if it is in the first map, it is in all of them

                    orig_markers = gen_markers.clone()

                    for j in range(len(candidate_maps)):
                        pred = candidate_maps[j][curr_var]
                        generated_seqs[j][gen_markers[j]:gen_markers[j] +
                                          len(pred)] = torch.LongTensor(
                                              pred).to(generated_seqs.device)
                        gen_markers[j] += len(pred)

                    decoder_attention_mask = torch.ones_like(
                        generated_seqs).float().to(generated_seqs.device)
                    decoder_attention_mask[generated_seqs == PAD_ID] = 0.0

                    decoder_output = self.bert_decoder(
                        input_ids=generated_seqs,
                        attention_mask=decoder_attention_mask,
                        encoder_hidden_states=full_enc_output[i].unsqueeze(0),
                        encoder_attention_mask=encoder_attention_mask[i].
                        unsqueeze(0))[0]

                    probabilities = F.log_softmax(
                        self.fc_final(decoder_output), dim=-1)

                    # Add up the scores of the token at the __next__ time step

                    scores = torch.zeros(generated_seqs.shape[0]).to(
                        generated_seqs.device)
                    active = torch.ones(generated_seqs.shape[0]).long().to(
                        generated_seqs.device)
                    temp_markers = orig_markers

                    while torch.sum(active) != 0:
                        position_probs = torch.gather(
                            probabilities, 1,
                            (temp_markers - 1).reshape(-1, 1, 1).repeat(
                                1, 1, probabilities.shape[2])).squeeze(1)
                        curr_tokens = torch.gather(generated_seqs, 1,
                                                   temp_markers.unsqueeze(1))
                        tok_probs = torch.gather(position_probs, 1,
                                                 curr_tokens).squeeze(1)

                        tok_probs *= active
                        scores += tok_probs

                        active *= (temp_markers != (gen_markers - 1)).long()
                        temp_markers += active

                    # Update the prior probabilities
                    prior_probs = prior_probs + scores

                else:
                    # You encounter a new variable which hasn't been seen before
                    # Generate <beam_width> possibilities for its name
                    generated_seqs, gen_markers, prior_probs, candidate_maps = self.beam_search(
                        generated_seqs,
                        gen_markers,
                        prior_probs,
                        candidate_maps,
                        curr_var,
                        full_enc_output[i].unsqueeze(0),
                        encoder_attention_mask[i].unsqueeze(0),
                        beam_width=5,
                        top_k=self.top_k)

                source_marker += 1

            final_ind = torch.argmax(prior_probs)
            confidence = torch.max(prior_probs).item()
            src_target_map = candidate_maps[final_ind]

            src_target_maps.append(src_target_map)
            confidences.append(confidence)

        return src_target_maps, confidences

    def beam_search(self,
                    generated_seqs,
                    gen_markers,
                    prior_probs,
                    candidate_maps,
                    curr_var,
                    full_enc_output,
                    encoder_attention_mask,
                    beam_width=5,
                    top_k=10):

        # Make sure expansion yields at least top_k candidates to keep.
        if generated_seqs.shape[0] * beam_width < top_k:
            beam_width = top_k

        active = torch.ones_like(gen_markers).to(gen_markers.device)
        beam_alpha = 0.7
        end_token = self.vocab.all_subtokens.word2id['</s>']

        orig_markers = gen_markers.clone()

        for _ in range(10):  # Predict at most 10 subtokens
            decoder_attention_mask = torch.ones_like(
                generated_seqs).float().to(generated_seqs.device)
            decoder_attention_mask[generated_seqs == PAD_ID] = 0.0

            decoder_output = self.bert_decoder(
                input_ids=generated_seqs,
                attention_mask=decoder_attention_mask,
                encoder_hidden_states=full_enc_output,
                encoder_attention_mask=encoder_attention_mask)[0]
            probabilities = F.log_softmax(self.fc_final(decoder_output),
                                          dim=-1)
            # Gather the predictions at the current markers
            # (gen_marker - 1) because prediction happens one step ahead
            probabilities = torch.gather(
                probabilities, 1, (gen_markers - 1).reshape(-1, 1, 1).repeat(
                    1, 1, probabilities.shape[2])).squeeze(1)

            probs, preds = probabilities.sort(dim=-1, descending=True)

            probs *= active.unsqueeze(
                1)  # Set log prob of non-active ones to 0
            preds[
                active ==
                0] = end_token  # Set preds of non-active ones to the end token (ie, remain unchanged)

            # Repeat active ones only once. Repeat the rest beam_width no. of times.
            filter_mask = torch.ones(
                (preds.shape[0], beam_width)).long().to(preds.device)
            filter_mask *= active.unsqueeze(1)
            filter_mask[:, 0][active == 0] = 1
            filter_mask = filter_mask.reshape(-1)

            preds = preds[:, :beam_width].reshape(-1)[filter_mask == 1]
            probs = probs[:, :beam_width].reshape(-1)[filter_mask == 1]

            generated_seqs = torch.repeat_interleave(generated_seqs,
                                                     beam_width,
                                                     dim=0)[filter_mask == 1]
            orig_markers = torch.repeat_interleave(orig_markers,
                                                   beam_width,
                                                   dim=0)[filter_mask == 1]
            gen_markers = torch.repeat_interleave(gen_markers,
                                                  beam_width,
                                                  dim=0)[filter_mask == 1]
            active = torch.repeat_interleave(active, beam_width,
                                             dim=0)[filter_mask == 1]
            prior_probs = torch.repeat_interleave(prior_probs,
                                                  beam_width,
                                                  dim=0)[filter_mask == 1]

            candidate_maps = [
                item.copy() for item in candidate_maps
                for _ in range(beam_width)
            ]
            candidate_maps = [
                candidate_maps[i] for i in range(len(candidate_maps))
                if filter_mask[i] == 1
            ]

            generated_seqs.scatter_(1, gen_markers.unsqueeze(1),
                                    preds.unsqueeze(1))

            # lengths       = (gen_markers - gen_marker + 1).float()
            # penalties     = torch.pow(5 + lengths, beam_alpha) / math.pow(6, beam_alpha)
            penalties = torch.ones_like(probs).to(probs.device)

            updated_probs = probs + prior_probs

            sort_inds = (updated_probs / penalties).argsort(descending=True)
            updated_probs = updated_probs[sort_inds]

            prior_probs = updated_probs[:top_k]

            new_preds = preds[sort_inds[:top_k]]
            generated_seqs = generated_seqs[sort_inds[:top_k]]
            gen_markers = gen_markers[sort_inds[:top_k]]
            active = active[sort_inds[:top_k]]
            orig_markers = orig_markers[sort_inds[:top_k]]

            candidate_maps = [
                candidate_maps[ind.item()] for ind in sort_inds[:top_k]
            ]

            active = active * (new_preds != end_token).long()
            gen_markers += active

            if torch.sum(active) == 0: break

        # gen_markers are pointing at the end_token. Move them one ahead
        gen_markers += 1

        assert generated_seqs.shape[0] == top_k

        for i in range(top_k):
            candidate_maps[i][curr_var] = generated_seqs[i][
                orig_markers[i]:gen_markers[i]].cpu().tolist()

        return generated_seqs, gen_markers, prior_probs, candidate_maps
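
The beam bookkeeping above (torch.repeat_interleave plus a filter mask) is easy to misread. The toy sketch below, on made-up tensors, shows the invariant: every active hypothesis is expanded beam_width times, while a finished hypothesis keeps exactly one copy.

import torch

beams = torch.tensor([[7, 0, 0],   # active hypothesis
                      [3, 9, 0]])  # finished hypothesis
active = torch.tensor([1, 0])
beam_width = 2

filter_mask = torch.ones((beams.shape[0], beam_width)).long()
filter_mask *= active.unsqueeze(1)
filter_mask[:, 0][active == 0] = 1  # keep a single copy of finished rows
filter_mask = filter_mask.reshape(-1)

expanded = torch.repeat_interleave(beams, beam_width, dim=0)[filter_mask == 1]
print(expanded.shape)  # torch.Size([3, 3]): two copies of row 0, one of row 1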
Example #8

model_info = pytorch_kobert
model_path = download(model_info['url'],
                      model_info['fname'],
                      model_info['chksum'],
                      cachedir=cachedir)
# download vocab
vocab_info = tokenizer
vocab_path = download(vocab_info['url'],
                      vocab_info['fname'],
                      vocab_info['chksum'],
                      cachedir=cachedir)
#################################################################################################
print('Declaring the BERT model')

bertmodel = BertModel(config=BertConfig.from_dict(bert_config))
# load_state_dict loads the weights; calling state_dict(...) does not.
bertmodel.load_state_dict(torch.load(model_path))

print("GPU 디바이스 세팅")
device = torch.device(ctx)
bertmodel.to(device)
bertmodel.train()
vocab = nlp.vocab.BERTVocab.from_sentencepiece(vocab_path,
                                               padding_token='[PAD]')

#################################################################################################
# parameter setup
tokenizer = get_tokenizer()
tok = nlp.data.BERTSPTokenizer(tokenizer, vocab, lower=False)

max_len = 64
batch_size = 64
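
A short, hedged sketch of applying the tokenizer configured above, assuming gluonnlp's BERTSentenceTransform API:

import gluonnlp as nlp

transform = nlp.data.BERTSentenceTransform(tok, max_seq_length=max_len,
                                           pad=True, pair=False)
token_ids, valid_length, segment_ids = transform(['안녕하세요'])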
Example #9
def _extract_layers(
    teacher_model: BertModel,
    layers: List[int],
    prefix_teacher="bert",
    prefix_student="bert",
    encoder_name="encoder",
):
    state_dict = teacher_model.state_dict()
    compressed_sd = {}

    # extract embeddings (token_type_embeddings are not copied here)
    for w in ["word_embeddings", "position_embeddings"]:
        compressed_sd[f"{prefix_student}.embeddings.{w}.weight"] = state_dict[
            f"{prefix_teacher}.embeddings.{w}.weight"]
    for w in ["weight", "bias"]:
        compressed_sd[
            f"{prefix_student}.embeddings.LayerNorm.{w}"] = state_dict[
                f"{prefix_teacher}.embeddings.LayerNorm.{w}"]
    # extract encoder layers; note that the student keys use DistilBERT-style
    # parameter names (q_lin, k_lin, v_lin, ffn.lin1, ...)
    for std_idx, teacher_idx in enumerate(layers):
        for w in ["weight", "bias"]:
            compressed_sd[
                f"{prefix_student}.encoder.layer.{std_idx}.attention.q_lin.{w}"  # noqa: E501
            ] = state_dict[
                f"{prefix_teacher}.encoder.layer.{teacher_idx}.attention.self.query.{w}"  # noqa: E501
            ]
            compressed_sd[
                f"{prefix_student}.encoder.layer.{std_idx}.attention.k_lin.{w}"  # noqa: E501
            ] = state_dict[
                f"{prefix_teacher}.encoder.layer.{teacher_idx}.attention.self.key.{w}"  # noqa: E501
            ]
            compressed_sd[
                f"{prefix_student}.encoder.layer.{std_idx}.attention.v_lin.{w}"  # noqa: E501
            ] = state_dict[
                f"{prefix_teacher}.encoder.layer.{teacher_idx}.attention.self.value.{w}"  # noqa: E501
            ]

            compressed_sd[
                f"{prefix_student}.encoder.layer.{std_idx}.attention.out_lin.{w}"  # noqa: E501
            ] = state_dict[
                f"{prefix_teacher}.encoder.layer.{teacher_idx}.attention.output.dense.{w}"  # noqa: E501
            ]
            compressed_sd[
                f"{prefix_student}.encoder.layer.{std_idx}.sa_layer_norm.{w}"  # noqa: E501
            ] = state_dict[
                f"{prefix_teacher}.encoder.layer.{teacher_idx}.attention.output.LayerNorm.{w}"  # noqa: E501
            ]

            compressed_sd[
                f"{prefix_student}.encoder.layer.{std_idx}.ffn.lin1.{w}"  # noqa: E501
            ] = state_dict[
                f"{prefix_teacher}.encoder.layer.{teacher_idx}.intermediate.dense.{w}"  # noqa: E501
            ]
            compressed_sd[
                f"{prefix_student}.encoder.layer.{std_idx}.ffn.lin2.{w}"  # noqa: E501
            ] = state_dict[
                f"{prefix_teacher}.encoder.layer.{teacher_idx}.output.dense.{w}"  # noqa: E501
            ]
            compressed_sd[
                f"{prefix_student}.encoder.layer.{std_idx}.output_layer_norm.{w}"  # noqa: E501
            ] = state_dict[
                f"{prefix_teacher}.encoder.layer.{teacher_idx}.output.LayerNorm.{w}"  # noqa: E501
            ]

    # extract vocab / MLM head (these cls.* keys exist only on models with an
    # MLM head, e.g. BertForMaskedLM, not on a bare BertModel)
    compressed_sd["cls.predictions.decoder.weight"] = state_dict[
        "cls.predictions.decoder.weight"]
    compressed_sd["cls.predictions.bias"] = state_dict["cls.predictions.bias"]

    for w in ["weight", "bias"]:
        compressed_sd[f"vocab_transform.{w}"] = state_dict[
            f"cls.predictions.transform.dense.{w}"]
        compressed_sd[f"vocab_layer_norm.{w}"] = state_dict[
            f"cls.predictions.transform.LayerNorm.{w}"]

    return compressed_sd
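
A hedged usage sketch: extract a 6-layer student state dict from a 12-layer MLM teacher. Because of the cls.* keys noted above, BertForMaskedLM is assumed here rather than a bare BertModel, and the layer choice is only an example.

from transformers import BertForMaskedLM

teacher = BertForMaskedLM.from_pretrained("bert-base-uncased")
student_sd = _extract_layers(
    teacher_model=teacher,
    layers=[1, 3, 5, 7, 9, 11],  # every other teacher layer
)
print(len(student_sd), "tensors extracted")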