Example #1
def set_model_config(args, tokenizer):
    sentence_config = BertConfig()
    sentence_config.vocab_size = tokenizer.get_vocab_size()
    sentence_config.num_hidden_layers = args.num_layers1
    sentence_config.hidden_size = args.hidden_size1
    sentence_config.num_attention_heads = args.attention_heads1
    sentence_config.max_position_embeddings = args.block_length

    document_config = BertConfig()
    document_config.vocab_size = tokenizer.get_vocab_size()
    document_config.num_hidden_layers = args.num_layers2
    document_config.hidden_size = args.hidden_size2
    document_config.num_attention_heads = args.attention_heads2
    document_config.num_masked_blocks = args.max_blocks
    document_config.max_position_embeddings = args.max_blocks

    return sentence_config, document_config
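A minimal usage sketch (assumed, not part of the original example): set_model_config only reads plain attributes from args, so an argparse-style namespace works; the numeric values, tokenizer file, and imports below are illustrative assumptions.

from argparse import Namespace
from tokenizers import Tokenizer
from transformers import BertConfig  # assumed source of the BertConfig used above

# Hypothetical settings; set_model_config only needs these attributes on args.
args = Namespace(
    num_layers1=6, hidden_size1=256, attention_heads1=4, block_length=32,
    num_layers2=3, hidden_size2=256, attention_heads2=4, max_blocks=48,
)
tokenizer = Tokenizer.from_file("tokenizer.json")  # hypothetical path; any tokenizer exposing get_vocab_size() works
sentence_config, document_config = set_model_config(args, tokenizer)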
Example #2
pretrain = True
sentence_block_length = 32
max_sentence_blocks = 48
hidden_size = 256
batch_size = 4
shuffle = True
drop_last = True

sentence_block_vector = torch.normal(mean=0.0, std=1.0, size=[hidden_size])

sentence_config = BertConfig()
sentence_config.vocab_size = tokenizer.get_vocab_size()
sentence_config.num_hidden_layers = 6
sentence_config.hidden_size = 256
sentence_config.num_attention_heads = 4
sentence_config.max_position_embeddings = sentence_block_length  # sentence_block_length

document_config = BertConfig()
document_config.vocab_size = tokenizer.get_vocab_size()
document_config.num_hidden_layers = 3
document_config.hidden_size = 256
document_config.num_attention_heads = 4
document_config.max_position_embeddings = max_sentence_blocks  # max_sentence_blocks

dataset = Dataset(file_path,
                  tokenizer,
                  sentence_block_length,
                  max_sentence_blocks,
                  mask=True)
dataloader = DataLoader(dataset,
                        batch_size=batch_size,
                        shuffle=shuffle,
                        drop_last=drop_last)  # call truncated in the source; completed with the settings defined above
Example #3
def get_model(enable_model_name, is_pretraining, pretrained_path):
    # tile(37), menzen(2), reach_state(2), n_reach(3),
    # reach_ippatsu(2), dans(21), rates(19), oya(4),
    # scores(13), n_honba(3), n_round(12), sanma_or_yonma(2),
    # han_or_ton(2), aka_ari(2), kui_ari(2), special_token(4)
    # vocab_size = 37 + 2 + 2 + 3 + 2 + 21 + 19 + 4 + 13 + 3 + 12 + 2 + 2 + 2 + 2 + 4 + 2 + 4 + 6 + 8 # 130 + shanten_diff(2) + who(4) + sum_discards(6) + shanten(8)
    vocab_size = 37 + 2 + 2 + 3 + 2 + 21 + 19 + 4 + 13 + 3 + 12 + 2 + 2 + 2 + 2 + 4 + 4 + 6 + 8  # 130 + who(4) + sum_discards(6) + shanten(8)
    # hidden_size = 1024
    # num_attention_heads = 16
    hidden_size = 768
    num_attention_heads = 12
    max_position_embeddings = 239  # base + who(1) + sum_discards(1) + shanten(1)
    # intermediate_size = 64
    # intermediate_size = 3072
    # max_position_embeddings = 239 # base + pad(1) + who(1) + pad(1) + sum_discards(1) + pad(1) + shanten(1)
    # max_position_embeddings = 281 # 260 + pad(1) + shanten_diff(14) + pad(1) + who(1) + pad(1) + sum_discards(1) + pad(1) + shanten(1)

    if is_pretraining:
        config = BertConfig()
        config.vocab_size = vocab_size
        config.hidden_size = hidden_size
        config.num_attention_heads = num_attention_heads
        config.max_position_embeddings = max_position_embeddings
        config.num_hidden_layers = 12
        return MahjongPretrainingModel(config)

    model = None
    if enable_model_name == 'discard':
        discard_config = BertConfig()
        discard_config.vocab_size = vocab_size
        discard_config.hidden_size = hidden_size
        discard_config.num_attention_heads = num_attention_heads
        discard_config.max_position_embeddings = max_position_embeddings
        discard_config.num_hidden_layers = 12
        # discard_config.intermediate_size = intermediate_size
        # discard_config.num_hidden_layers = 24
        # discard_config.num_hidden_layers = 12
        model = MahjongDiscardModel(discard_config)
    elif enable_model_name in ('reach', 'chow', 'pong', 'kong'):
        # The four call decisions share the same configuration and model class.
        call_config = BertConfig()
        call_config.vocab_size = vocab_size
        call_config.hidden_size = hidden_size
        call_config.num_attention_heads = num_attention_heads
        call_config.max_position_embeddings = max_position_embeddings
        call_config.num_hidden_layers = 24
        model = MahjongReachChowPongKongModel(call_config)

    if pretrained_path != '':
        checkpoint = torch.load(pretrained_path,
                                map_location=catalyst.utils.get_device())
        # print(checkpoint['model_state_dict'].keys())
        model.load_state_dict(checkpoint['model_state_dict'], strict=False)

    return model
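A minimal usage sketch (assumed, not part of the original example): selecting the fine-tuning discard model without loading a checkpoint.

# Hypothetical call; 'discard' selects MahjongDiscardModel and '' skips checkpoint loading.
model = get_model('discard', is_pretraining=False, pretrained_path='')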
Example #4
    output_base_dir = PRJ_ROOT / "output" / datetime.now().strftime(
        "train%Y%m%d%H%M%S")
    output_base_dir.mkdir(exist_ok=True)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    logger.info("device: {}".format(device))

    df_train = pd.read_csv(TRAIN_PATH)
    df_train.columns = ["doc_id", "sents"]
    sents_list = df_train["sents"].values.tolist()
    logger.info("len(sents_list): {}".format(len(sents_list)))

    config = BertConfig()
    config.num_hidden_layers = 3
    config.num_attention_heads = 12
    config.hidden_size = 768
    config.intermediate_size = 3072
    config.max_position_embeddings = 512
    config.vocab_size = 32000

    logger.info("USE_NSP: {}".format(USE_NSP))
    if USE_NSP:
        model = BertForPreTraining(config)
    else:
        model = BertForPreTrainingWithoutNSP(config)
    model.to(device)

    logger.info(config)
    logger.info(model)