Example no. 1
def set_model_config(args, tokenizer):
    sentence_config = BertConfig()
    sentence_config.vocab_size = tokenizer.get_vocab_size()
    sentence_config.num_hidden_layers = args.num_layers1
    sentence_config.hidden_size = args.hidden_size1
    sentence_config.num_attention_heads = args.attention_heads1
    sentence_config.max_position_embeddings = args.block_length

    document_config = BertConfig()
    document_config.vocab_size = tokenizer.get_vocab_size()
    document_config.num_hidden_layers = args.num_layers2
    document_config.hidden_size = args.hidden_size2
    document_config.num_attention_heads = args.attention_heads2
    document_config.num_masked_blocks = args.max_blocks
    document_config.max_position_embeddings = args.max_blocks

    return sentence_config, document_config
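A minimal sketch of how these two configs might be consumed in a hierarchical (block-then-document) setup; `args` and the `_Tok` stand-in are hypothetical, with attribute names taken from the function above.

from argparse import Namespace
from transformers import BertConfig, BertModel  # BertConfig is used inside set_model_config

class _Tok:  # stand-in for a trained tokenizer exposing get_vocab_size()
    def get_vocab_size(self):
        return 30000

args = Namespace(num_layers1=6, hidden_size1=256, attention_heads1=4,
                 block_length=32,
                 num_layers2=3, hidden_size2=256, attention_heads2=4,
                 max_blocks=48)

sentence_config, document_config = set_model_config(args, _Tok())
sentence_encoder = BertModel(sentence_config)  # encodes tokens inside a block
document_encoder = BertModel(document_config)  # encodes the sequence of blocks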
Example no. 2
def create_model(model_class: BertPreTrainedModel,
                 encoder_config: BertConfig,
                 tokenizer: BertTokenizer,
                 encoder_path=None,
                 entity_types: dict = None,
                 relation_types: dict = None,
                 prop_drop: float = 0.1,
                 meta_embedding_size: int = 25,
                 size_embeddings_count: int = 10,
                 ed_embeddings_count: int = 300,
                 token_dist_embeddings_count: int = 700,
                 sentence_dist_embeddings_count: int = 50,
                 mention_threshold: float = 0.5,
                 coref_threshold: float = 0.5,
                 rel_threshold: float = 0.5,
                 position_embeddings_count: int = 700,
                 cache_path=None):
    params = dict(
        config=encoder_config,
        # JEREX model parameters
        cls_token=tokenizer.convert_tokens_to_ids('[CLS]'),
        entity_types=len(entity_types),
        relation_types=len(relation_types),
        prop_drop=prop_drop,
        meta_embedding_size=meta_embedding_size,
        size_embeddings_count=size_embeddings_count,
        ed_embeddings_count=ed_embeddings_count,
        token_dist_embeddings_count=token_dist_embeddings_count,
        sentence_dist_embeddings_count=sentence_dist_embeddings_count,
        mention_threshold=mention_threshold,
        coref_threshold=coref_threshold,
        rel_threshold=rel_threshold,
        tokenizer=tokenizer,
        cache_dir=cache_path,
    )

    if encoder_path is not None:
        model = model_class.from_pretrained(encoder_path, **params)
    else:
        model = model_class(**params)

    # conditionally increase position embedding count
    if encoder_config.max_position_embeddings < position_embeddings_count:
        old = model.bert.embeddings.position_embeddings

        new = nn.Embedding(position_embeddings_count,
                           encoder_config.hidden_size)
        new.weight.data[:encoder_config.max_position_embeddings, :] = old.weight.data
        model.bert.embeddings.position_embeddings = new
        model.bert.embeddings.register_buffer(
            "position_ids",
            torch.arange(position_embeddings_count).expand((1, -1)))

        encoder_config.max_position_embeddings = position_embeddings_count

    return model
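The position-embedding extension above can be shown in isolation; a minimal sketch with a plain BertModel, assuming a target length of 1024 (buffer handling may differ slightly across transformers versions).

import torch
from torch import nn
from transformers import BertModel

model = BertModel.from_pretrained('bert-base-uncased')  # checkpoint has 512 positions
new_len = 1024

old = model.embeddings.position_embeddings
new = nn.Embedding(new_len, model.config.hidden_size)
# keep the pretrained rows; rows beyond the original length stay randomly initialized
new.weight.data[:old.num_embeddings, :] = old.weight.data

model.embeddings.position_embeddings = new
model.embeddings.register_buffer(
    "position_ids", torch.arange(new_len).expand((1, -1)))
model.config.max_position_embeddings = new_len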
Example no. 3
    def build(self):

        # to be further set
        self.image_feature_module = build_image_encoder(
            self.config.image_feature_processor, direct_features=True
        )
        if self.config.concate_trace:
            self.trace_feature_module = build_encoder(self.config.trace_feature_encoder)

        if self.config.base_model_name == "bert-base-uncased":
            self.encoderdecoder = EncoderDecoderModel.from_encoder_decoder_pretrained(
                "bert-base-uncased", "bert-base-uncased"
            )
        elif self.config.base_model_name == "2layer-base":
            config_encoder = BertConfig()
            config_decoder = BertConfig()
            config_encoder.max_position_embeddings = 1090
            config_encoder.num_hidden_layers = 2
            config_decoder.num_hidden_layers = 2
            self.codec_config = EncoderDecoderConfig.from_encoder_decoder_configs(
                config_encoder, config_decoder
            )
            self.encoderdecoder = EncoderDecoderModel(config=self.codec_config)
        elif self.config.base_model_name == "3layer-base":
            config_encoder = BertConfig()
            config_decoder = BertConfig()
            config_encoder.num_hidden_layers = 3
            config_decoder.num_hidden_layers = 3
            self.codec_config = EncoderDecoderConfig.from_encoder_decoder_configs(
                config_encoder, config_decoder
            )
            self.encoderdecoder = EncoderDecoderModel(config=self.codec_config)
        if self.config.loop_contrastive:
            self.trace_caption_contrastive = TraceCaptionContrastiveModel(
                self.config.tc_contrastive_aggregate_method
            )
        if (
            hasattr(self.config, "pretrans_attention")
            and self.config.pretrans_attention
        ):

            tempconf = self.encoderdecoder.config.encoder
            num_heads = tempconf.num_attention_heads
            num_layers = tempconf.num_hidden_layers
            self.attention_trans = AttentionTransform(num_layers, num_heads, 100)
        self.BOS_ID = 101
        self.vae = OpenAIDiscreteVAE()
        image_code_dim = 768
        image_fmap_size = self.vae.image_size // (2 ** self.vae.num_layers)
        self.image_seq_len = image_fmap_size ** 2
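        # With dalle-pytorch's OpenAIDiscreteVAE defaults (256x256 images,
        # 3 downsampling stages) image_fmap_size is 32, so image_seq_len is 1024.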
        self.image_emb = torch.nn.Embedding(self.vae.num_tokens, image_code_dim)
        self.image_pos_emb = AxialPositionalEmbedding(
            image_code_dim, axial_shape=(image_fmap_size, image_fmap_size)
        )
Example no. 4
pretrain = True
sentence_block_length = 32
max_sentence_blocks = 48
hidden_size = 256
batch_size = 4
shuffle = True
drop_last = True

sentence_block_vector = torch.normal(mean=0.0, std=1.0, size=[hidden_size])

sentence_config = BertConfig()
sentence_config.vocab_size = tokenizer.get_vocab_size()
sentence_config.num_hidden_layers = 6
sentence_config.hidden_size = 256
sentence_config.num_attention_heads = 4
sentence_config.max_position_embeddings = sentence_block_length  # sentence_block_length

document_config = BertConfig()
document_config.vocab_size = tokenizer.get_vocab_size()
document_config.num_hidden_layers = 3
document_config.hidden_size = 256
document_config.num_attention_heads = 4
document_config.max_position_embeddings = max_sentence_blocks

dataset = Dataset(file_path,
                  tokenizer,
                  sentence_block_length,
                  max_sentence_blocks,
                  mask=True)
dataloader = DataLoader(dataset,
                        batch_size=batch_size,
                        shuffle=shuffle,        # assumed from the settings above
                        drop_last=drop_last)    # assumed from the settings above
Example no. 5
def main():
    parser = ArgumentParser()
    parser.add_argument('--config_file', type=str, required=True)
    args = parser.parse_args()

    # settings
    config_path = Path(args.config_file)
    config = Config.load(config_path)

    warnings.filterwarnings('ignore')
    set_seed(config.seed)
    start_time = time.time()

    with timer('load data'):
        DATA_DIR = './input/riiid-test-answer-prediction/'
        usecols = [
            'row_id',
            'timestamp',
            'user_id',
            'content_id',
            'content_type_id',
            'answered_correctly',
            'prior_question_elapsed_time',
        ]
        dtype = {
            'row_id': 'int64',
            'timestamp': 'int64',
            'user_id': 'int32',
            'content_id': 'int16',
            'content_type_id': 'int8',
            'answered_correctly': 'int8',
            'prior_question_elapsed_time': 'float32'
        }

        train_df = pd.read_csv(DATA_DIR + 'train.csv',
                               usecols=usecols,
                               dtype=dtype)
        question_df = pd.read_csv(DATA_DIR + 'questions.csv',
                                  usecols=['question_id', 'part'])

    train_df = train_df[train_df['content_type_id'] == 0].reset_index(
        drop=True)

    question_df['part'] += 1  # 0: padding id, 1: start id
    train_df['content_id'] += 2  # 0: padding id, 1: start id
    question_df['question_id'] += 2
    train_df = train_df.merge(question_df,
                              how='left',
                              left_on='content_id',
                              right_on='question_id')

    with timer('validation split'):
        train_idx, valid_idx, epoch_valid_idx = virtual_time_split(
            train_df,
            valid_size=config.valid_size,
            epoch_valid_size=config.epoch_valid_size)
        valid_y = train_df.iloc[valid_idx]['answered_correctly'].values
        epoch_valid_y = train_df.iloc[epoch_valid_idx][
            'answered_correctly'].values

    print('-' * 20)
    print(f'train size: {len(train_idx)}')
    print(f'valid size: {len(valid_idx)}')

    with timer('prepare data loader'):
        train_user_seqs = get_user_sequences(train_df.iloc[train_idx])
        valid_user_seqs = get_user_sequences(train_df.iloc[valid_idx])

        train_dataset = TrainDataset(train_user_seqs,
                                     window_size=config.window_size,
                                     stride_size=config.stride_size)
        valid_dataset = ValidDataset(train_df,
                                     train_user_seqs,
                                     valid_user_seqs,
                                     valid_idx,
                                     window_size=config.window_size)

        train_loader = DataLoader(train_dataset, **config.train_loader_params)
        valid_loader = DataLoader(valid_dataset, **config.valid_loader_params)

        # valid loader for epoch validation
        epoch_valid_user_seqs = get_user_sequences(
            train_df.iloc[epoch_valid_idx])
        epoch_valid_dataset = ValidDataset(train_df,
                                           train_user_seqs,
                                           epoch_valid_user_seqs,
                                           epoch_valid_idx,
                                           window_size=config.window_size)
        epoch_valid_loader = DataLoader(epoch_valid_dataset,
                                        **config.valid_loader_params)

    with timer('train'):
        if config.model == 'akt':
            content_encoder_config = BertConfig(
                **config.content_encoder_config)
            knowledge_encoder_config = BertConfig(
                **config.knowledge_encoder_config)
            decoder_config = BertConfig(**config.decoder_config)

            content_encoder_config.max_position_embeddings = config.window_size + 1
            knowledge_encoder_config.max_position_embeddings = config.window_size
            decoder_config.max_position_embeddings = config.window_size + 1

            model = AktEncoderDecoderModel(content_encoder_config,
                                           knowledge_encoder_config,
                                           decoder_config)

        elif config.model == 'saint':
            encoder_config = BertConfig(**config.encoder_config)
            decoder_config = BertConfig(**config.decoder_config)

            encoder_config.max_position_embeddings = config.window_size
            decoder_config.max_position_embeddings = config.window_size

            model = SaintEncoderDecoderModel(encoder_config, decoder_config)

        else:
            raise ValueError(f'Unknown model: {config.model}')

        model.to(config.device)
        model.zero_grad()

        optimizer = optim.Adam(model.parameters(), **config.optimizer_params)
        scheduler = NoamLR(optimizer, warmup_steps=config.warmup_steps)
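        # NoamLR is assumed to implement the Transformer ("Noam") schedule:
        #   lr(t) ∝ d_model**-0.5 * min(t**-0.5, t * warmup_steps**-1.5),
        # i.e. linear warmup followed by inverse-square-root decay.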
        loss_ema = None

        for epoch in range(config.n_epochs):
            epoch_start_time = time.time()
            model.train()

            progress = tqdm(train_loader,
                            desc=f'epoch {epoch + 1}',
                            leave=False)
            for i, (x_batch, w_batch, y_batch) in enumerate(progress):
                y_pred = model(**x_batch.to(config.device).to_dict())
                loss = nn.BCEWithLogitsLoss(weight=w_batch.to(config.device))(
                    y_pred, y_batch.to(config.device))
                loss.backward()

                if (config.gradient_accumulation_steps is None
                        or (i + 1) % config.gradient_accumulation_steps == 0):
                    optimizer.step()
                    optimizer.zero_grad()
                    scheduler.step()

                # exponential moving average of the loss for the progress bar
                loss_ema = (loss_ema * 0.9 + loss.item() * 0.1
                            if loss_ema is not None else loss.item())
                progress.set_postfix(loss=loss_ema)

            valid_preds = predict(model,
                                  epoch_valid_loader,
                                  device=config.device)
            valid_score = roc_auc_score(epoch_valid_y, valid_preds)

            elapsed_time = time.time() - epoch_start_time
            print(
                f'Epoch {epoch + 1}/{config.n_epochs} \t valid score: {valid_score:.5f} \t time: {elapsed_time / 60:.1f} min'
            )

    with timer('predict'):
        valid_preds = predict(model, valid_loader, device=config.device)
        valid_score = roc_auc_score(valid_y, valid_preds)

    print(f'valid score: {valid_score:.5f}')

    output_dir = Path(f'./output/{config_path.stem}/')
    output_dir.mkdir(parents=True, exist_ok=True)

    torch.save(model.state_dict(), output_dir / 'model.pt')
    torch.save(optimizer.state_dict(), output_dir / 'optimizer.pt')

    elapsed_time = time.time() - start_time
    print(f'all processes done in {elapsed_time / 60:.1f} min.')
Example no. 6
def get_model(enable_model_name, is_pretraining, pretrained_path):
    # tile(37), menzen(2), reach_state(2), n_reach(3),
    # reach_ippatsu(2), dans(21), rates(19), oya(4),
    # scores(13), n_honba(3), n_round(12), sanma_or_yonma(2),
    # han_or_ton(2), aka_ari(2), kui_ari(2), special_token(4)
    # vocab_size = 37 + 2 + 2 + 3 + 2 + 21 + 19 + 4 + 13 + 3 + 12 + 2 + 2 + 2 + 2 + 4 + 2 + 4 + 6 + 8 # 130 + shanten_diff(2) + who(4) + sum_discards(6) + shanten(8)
    vocab_size = 37 + 2 + 2 + 3 + 2 + 21 + 19 + 4 + 13 + 3 + 12 + 2 + 2 + 2 + 2 + 4 + 4 + 6 + 8  # 130 + who(4) + sum_discards(6) + shanten(8)
    # hidden_size = 1024
    # num_attention_heads = 16
    hidden_size = 768
    num_attention_heads = 12
    max_position_embeddings = 239  # base + who(1) + sum_discards(1) + shanten(1)
    # intermediate_size = 64
    # intermediate_size = 3072
    # max_position_embeddings = 239 # base + pad(1) + who(1) + pad(1) + sum_discards(1) + pad(1) + shanten(1)
    # max_position_embeddings = 281 # 260 + pad(1) + shanten_diff(14) + pad(1) + who(1) + pad(1) + sum_discards(1) + pad(1) + shanten(1)

    if is_pretraining:
        config = BertConfig()
        config.vocab_size = vocab_size
        config.hidden_size = hidden_size
        config.num_attention_heads = num_attention_heads
        config.max_position_embeddings = max_position_embeddings
        config.num_hidden_layers = 12
        return MahjongPretrainingModel(config)

    model = None
    if enable_model_name == 'discard':
        discard_config = BertConfig()
        discard_config.vocab_size = vocab_size
        discard_config.hidden_size = hidden_size
        discard_config.num_attention_heads = num_attention_heads
        discard_config.max_position_embeddings = max_position_embeddings
        discard_config.num_hidden_layers = 12
        # discard_config.intermediate_size = intermediate_size
        # discard_config.num_hidden_layers = 24
        # discard_config.num_hidden_layers = 12
        model = MahjongDiscardModel(discard_config)
    elif enable_model_name == 'reach':
        reach_config = BertConfig()
        reach_config.vocab_size = vocab_size
        reach_config.hidden_size = hidden_size
        reach_config.num_attention_heads = num_attention_heads
        reach_config.max_position_embeddings = max_position_embeddings
        reach_config.num_hidden_layers = 24
        model = MahjongReachChowPongKongModel(reach_config)
    elif enable_model_name == 'chow':
        chow_config = BertConfig()
        chow_config.vocab_size = vocab_size
        chow_config.hidden_size = hidden_size
        chow_config.num_attention_heads = num_attention_heads
        chow_config.max_position_embeddings = max_position_embeddings
        chow_config.num_hidden_layers = 24
        model = MahjongReachChowPongKongModel(chow_config)
    elif enable_model_name == 'pong':
        pong_config = BertConfig()
        pong_config.vocab_size = vocab_size
        pong_config.hidden_size = hidden_size
        pong_config.num_attention_heads = num_attention_heads
        pong_config.max_position_embeddings = max_position_embeddings
        pong_config.num_hidden_layers = 24
        model = MahjongReachChowPongKongModel(pong_config)
    elif enable_model_name == 'kong':
        kong_config = BertConfig()
        kong_config.vocab_size = vocab_size
        kong_config.hidden_size = hidden_size
        kong_config.num_attention_heads = num_attention_heads
        kong_config.max_position_embeddings = max_position_embeddings
        kong_config.num_hidden_layers = 24
        model = MahjongReachChowPongKongModel(kong_config)
    else:
        raise ValueError(f'unknown model name: {enable_model_name}')

    if pretrained_path != '':
        checkpoint = torch.load(pretrained_path,
                                map_location=catalyst.utils.get_device())
        # print(checkpoint['model_state_dict'].keys())
        model.load_state_dict(checkpoint['model_state_dict'], strict=False)

    return model
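A minimal call sketch for the helper above; the checkpoint path is hypothetical and the Mahjong* model classes come from the surrounding project.

# Pretraining from scratch: the model name and checkpoint path are unused.
pretrain_model = get_model('', is_pretraining=True, pretrained_path='')

# Discard model warm-started from a pretraining checkpoint (hypothetical path).
discard_model = get_model('discard',
                          is_pretraining=False,
                          pretrained_path='./checkpoints/pretrain.pth')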
Example no. 7
    output_base_dir.mkdir(exist_ok=True)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    logger.info("device: {}".format(device))

    df_train = pd.read_csv(TRAIN_PATH)
    df_train.columns = ["doc_id", "sents"]
    sents_list = df_train["sents"].values.tolist()
    logger.info("len(sents_list): {}".format(len(sents_list)))

    config = BertConfig()
    config.num_hidden_layers = 3
    config.num_attention_heads = 12
    config.hidden_size = 768
    config.intermediate_size = 3072
    config.max_position_embeddings = 512
    config.vocab_size = 32000

    logger.info("USE_NSP: {}".format(USE_NSP))
    if USE_NSP:
        model = BertForPreTraining(config)
    else:
        model = BertForPreTrainingWithoutNSP(config)
    model.to(device)

    logger.info(config)
    logger.info(model)

    optimizer = AdamW(model.parameters(), lr=2e-5)
    model.train()
    train_losses = []
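The listing cuts off as the training loop begins; below is a self-contained sketch of a single BertForPreTraining step on a dummy batch with the same config values (the project's real batches would come from its tokenized corpus).

import torch
from torch.optim import AdamW
from transformers import BertConfig, BertForPreTraining

config = BertConfig(num_hidden_layers=3, vocab_size=32000)
model = BertForPreTraining(config)
optimizer = AdamW(model.parameters(), lr=2e-5)

# dummy batch of 2 sequences of length 16
input_ids = torch.randint(0, config.vocab_size, (2, 16))
attention_mask = torch.ones_like(input_ids)
mlm_labels = input_ids.clone()                  # positions set to -100 are ignored in the loss
nsp_labels = torch.zeros(2, dtype=torch.long)   # next-sentence labels

outputs = model(input_ids=input_ids,
                attention_mask=attention_mask,
                labels=mlm_labels,
                next_sentence_label=nsp_labels)
outputs.loss.backward()
optimizer.step()
optimizer.zero_grad()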