def set_model_config(args, tokenizer):
    """Build the sentence-level and document-level BertConfig pair.

    Args:
        args: parsed CLI namespace carrying the per-level architecture knobs
            (num_layers1/2, hidden_size1/2, attention_heads1/2, block_length,
            max_blocks).
        tokenizer: tokenizer providing ``get_vocab_size()``; both configs
            share its vocabulary size.

    Returns:
        Tuple ``(sentence_config, document_config)``.
    """
    vocab = tokenizer.get_vocab_size()

    sentence_config = BertConfig()
    for attr, value in (
            ('vocab_size', vocab),
            ('num_hidden_layers', args.num_layers1),
            ('hidden_size', args.hidden_size1),
            ('num_attention_heads', args.attention_heads1),
            ('max_position_embeddings', args.block_length)):
        setattr(sentence_config, attr, value)

    document_config = BertConfig()
    for attr, value in (
            ('vocab_size', vocab),
            ('num_hidden_layers', args.num_layers2),
            ('hidden_size', args.hidden_size2),
            ('num_attention_heads', args.attention_heads2),
            # custom field consumed downstream; not a stock BertConfig attr
            ('num_masked_blocks', args.max_blocks),
            ('max_position_embeddings', args.max_blocks)):
        setattr(document_config, attr, value)

    return sentence_config, document_config
def create_model(model_class: BertPreTrainedModel, encoder_config: BertConfig,
                 tokenizer: BertTokenizer, encoder_path=None,
                 entity_types: dict = None, relation_types: dict = None,
                 prop_drop: float = 0.1, meta_embedding_size: int = 25,
                 size_embeddings_count: int = 10, ed_embeddings_count: int = 300,
                 token_dist_embeddings_count: int = 700,
                 sentence_dist_embeddings_count: int = 50,
                 mention_threshold: float = 0.5, coref_threshold: float = 0.5,
                 rel_threshold: float = 0.5, position_embeddings_count: int = 700,
                 cache_path=None):
    """Instantiate a JEREX model and, if necessary, widen its position
    embedding table up to ``position_embeddings_count`` rows.

    Returns:
        The constructed (optionally pretrained-initialized) model.
    """
    model_kwargs = dict(
        config=encoder_config,
        # JEREX model parameters
        cls_token=tokenizer.convert_tokens_to_ids('[CLS]'),
        entity_types=len(entity_types),
        relation_types=len(relation_types),
        prop_drop=prop_drop,
        meta_embedding_size=meta_embedding_size,
        size_embeddings_count=size_embeddings_count,
        ed_embeddings_count=ed_embeddings_count,
        token_dist_embeddings_count=token_dist_embeddings_count,
        sentence_dist_embeddings_count=sentence_dist_embeddings_count,
        mention_threshold=mention_threshold,
        coref_threshold=coref_threshold,
        rel_threshold=rel_threshold,
        tokenizer=tokenizer,
        cache_dir=cache_path,
    )

    if encoder_path is None:
        model = model_class(**model_kwargs)
    else:
        model = model_class.from_pretrained(encoder_path, **model_kwargs)

    # Conditionally grow the position embedding table: copy the trained rows
    # into a larger table and refresh the cached position_ids buffer.
    current_count = encoder_config.max_position_embeddings
    if current_count < position_embeddings_count:
        old_embeddings = model.bert.embeddings.position_embeddings
        grown = nn.Embedding(position_embeddings_count, encoder_config.hidden_size)
        grown.weight.data[:current_count, :] = old_embeddings.weight.data
        model.bert.embeddings.position_embeddings = grown
        model.bert.embeddings.register_buffer(
            "position_ids",
            torch.arange(position_embeddings_count).expand((1, -1)))
        encoder_config.max_position_embeddings = position_embeddings_count

    return model
def build(self):
    """Assemble sub-modules: image/trace encoders, the BERT encoder-decoder
    backbone, optional contrastive / attention heads, and the discrete VAE
    used to tokenize images."""
    # Image features arrive pre-extracted, so the encoder consumes direct
    # features rather than raw pixels.
    self.image_feature_module = build_image_encoder(
        self.config.image_feature_processor, direct_features=True
    )
    if self.config.concate_trace:
        self.trace_feature_module = build_encoder(self.config.trace_feature_encoder)

    backbone = self.config.base_model_name
    if backbone == "bert-base-uncased":
        self.encoderdecoder = EncoderDecoderModel.from_encoder_decoder_pretrained(
            "bert-base-uncased", "bert-base-uncased"
        )
    elif backbone in ("2layer-base", "3layer-base"):
        depth = 2 if backbone == "2layer-base" else 3
        enc_cfg = BertConfig()
        dec_cfg = BertConfig()
        enc_cfg.num_hidden_layers = depth
        dec_cfg.num_hidden_layers = depth
        if depth == 2:
            # Only the 2-layer variant enlarges the encoder position table.
            enc_cfg.max_position_embeddings = 1090
        self.codec_config = EncoderDecoderConfig.from_encoder_decoder_configs(
            enc_cfg, dec_cfg
        )
        self.encoderdecoder = EncoderDecoderModel(config=self.codec_config)

    if self.config.loop_contrastive:
        self.trace_caption_contrastive = TraceCaptionContrastiveModel(
            self.config.tc_contrastive_aggregate_method
        )
    # pretrans_attention is optional in the config, hence the default.
    if getattr(self.config, "pretrans_attention", False):
        enc_conf = self.encoderdecoder.config.encoder
        self.attention_trans = AttentionTransform(
            enc_conf.num_hidden_layers, enc_conf.num_attention_heads, 100
        )

    self.BOS_ID = 101

    # Discrete VAE maps an image to a square grid of codebook tokens;
    # the grid side is image_size / 2**num_layers.
    self.vae = OpenAIDiscreteVAE()
    image_code_dim = 768
    image_fmap_size = self.vae.image_size // (2 ** self.vae.num_layers)
    self.image_seq_len = image_fmap_size ** 2
    self.image_emb = torch.nn.Embedding(self.vae.num_tokens, image_code_dim)
    self.image_pos_emb = AxialPositionalEmbedding(
        image_code_dim, axial_shape=(image_fmap_size, image_fmap_size)
    )
# Script-level hyper-parameters for the hierarchical (sentence/document)
# BERT pretraining run.
pretrain = True
sentence_block_length = 32  # tokens per sentence block
max_sentence_blocks = 48    # sentence blocks per document
hidden_size = 256
batch_size = 4
shuffle = True
drop_last = True

# Random vector of size hidden_size drawn from N(0, 1);
# presumably a placeholder/sentinel sentence-block embedding — confirm usage downstream.
sentence_block_vector = torch.normal(mean=0.0, std=1.0, size=[hidden_size])

# Sentence-level encoder: sequence length equals one sentence block.
sentence_config = BertConfig()
sentence_config.vocab_size = tokenizer.get_vocab_size()
sentence_config.num_hidden_layers = 6
sentence_config.hidden_size = 256
sentence_config.num_attention_heads = 4
sentence_config.max_position_embeddings = sentence_block_length  # sentence_block_length

# Document-level encoder: one position per sentence block.
document_config = BertConfig()
document_config.vocab_size = tokenizer.get_vocab_size()
document_config.num_hidden_layers = 3
document_config.hidden_size = 256
document_config.num_attention_heads = 4
document_config.max_position_embeddings = max_sentence_blocks  # sentence_block_length

dataset = Dataset(file_path, tokenizer, sentence_block_length, max_sentence_blocks, mask=True)
# NOTE(review): the DataLoader call is truncated at this chunk boundary;
# remaining keyword arguments continue beyond the visible source.
dataloader = DataLoader(dataset, batch_size=batch_size,
def main():
    """Train an AKT or SAINT knowledge-tracing model on the Riiid dataset.

    Reads a YAML/JSON-style config (``--config_file``), splits the train
    data by virtual time, trains with optional gradient accumulation, and
    saves model/optimizer state dicts under ``./output/<config stem>/``.
    """
    parser = ArgumentParser()
    parser.add_argument('--config_file', type=str, required=True)
    args = parser.parse_args()

    # settings
    config_path = Path(args.config_file)
    config = Config.load(config_path)
    warnings.filterwarnings('ignore')
    set_seed(config.seed)
    start_time = time.time()

    with timer('load data'):
        DATA_DIR = './input/riiid-test-answer-prediction/'
        usecols = [
            'row_id',
            'timestamp',
            'user_id',
            'content_id',
            'content_type_id',
            'answered_correctly',
            'prior_question_elapsed_time',
        ]
        dtype = {
            'row_id': 'int64',
            'timestamp': 'int64',
            'user_id': 'int32',
            'content_id': 'int16',
            'content_type_id': 'int8',
            'answered_correctly': 'int8',
            'prior_question_elapsed_time': 'float32'
        }
        train_df = pd.read_csv(DATA_DIR + 'train.csv', usecols=usecols, dtype=dtype)
        question_df = pd.read_csv(DATA_DIR + 'questions.csv',
                                  usecols=['question_id', 'part'])
        # content_type_id == 0 keeps question rows only (drops lectures).
        train_df = train_df[train_df['content_type_id'] == 0].reset_index(
            drop=True)
        # Shift ids so 0 can serve as padding and 1 as the start token.
        question_df['part'] += 1  # 0: padding id, 1: start id
        train_df['content_id'] += 2  # 0: padding id, 1: start id
        question_df['question_id'] += 2
        train_df = train_df.merge(question_df, how='left',
                                  left_on='content_id', right_on='question_id')

    with timer('validation split'):
        # Time-based split: final holdout plus a smaller per-epoch set.
        train_idx, valid_idx, epoch_valid_idx = virtual_time_split(
            train_df,
            valid_size=config.valid_size,
            epoch_valid_size=config.epoch_valid_size)
        valid_y = train_df.iloc[valid_idx]['answered_correctly'].values
        epoch_valid_y = train_df.iloc[epoch_valid_idx][
            'answered_correctly'].values

    print('-' * 20)
    print(f'train size: {len(train_idx)}')
    print(f'valid size: {len(valid_idx)}')

    with timer('prepare data loader'):
        train_user_seqs = get_user_sequences(train_df.iloc[train_idx])
        valid_user_seqs = get_user_sequences(train_df.iloc[valid_idx])
        train_dataset = TrainDataset(train_user_seqs,
                                     window_size=config.window_size,
                                     stride_size=config.stride_size)
        valid_dataset = ValidDataset(train_df, train_user_seqs,
                                     valid_user_seqs, valid_idx,
                                     window_size=config.window_size)
        train_loader = DataLoader(train_dataset, **config.train_loader_params)
        valid_loader = DataLoader(valid_dataset, **config.valid_loader_params)

        # valid loader for epoch validation
        epoch_valid_user_seqs = get_user_sequences(
            train_df.iloc[epoch_valid_idx])
        epoch_valid_dataset = ValidDataset(train_df, train_user_seqs,
                                           epoch_valid_user_seqs,
                                           epoch_valid_idx,
                                           window_size=config.window_size)
        epoch_valid_loader = DataLoader(epoch_valid_dataset,
                                        **config.valid_loader_params)

    with timer('train'):
        if config.model == 'akt':
            content_encoder_config = BertConfig(
                **config.content_encoder_config)
            knowledge_encoder_config = BertConfig(
                **config.knowledge_encoder_config)
            decoder_config = BertConfig(**config.decoder_config)
            # +1 positions for the start token on content/decoder streams.
            content_encoder_config.max_position_embeddings = config.window_size + 1
            knowledge_encoder_config.max_position_embeddings = config.window_size
            decoder_config.max_position_embeddings = config.window_size + 1
            model = AktEncoderDecoderModel(content_encoder_config,
                                           knowledge_encoder_config,
                                           decoder_config)
        elif config.model == 'saint':
            encoder_config = BertConfig(**config.encoder_config)
            decoder_config = BertConfig(**config.decoder_config)
            encoder_config.max_position_embeddings = config.window_size
            decoder_config.max_position_embeddings = config.window_size
            model = SaintEncoderDecoderModel(encoder_config, decoder_config)
        else:
            raise ValueError(f'Unknown model: {config.model}')
        model.to(config.device)
        model.zero_grad()

        optimizer = optim.Adam(model.parameters(), **config.optimizer_params)
        scheduler = NoamLR(optimizer, warmup_steps=config.warmup_steps)

        loss_ema = None
        for epoch in range(config.n_epochs):
            epoch_start_time = time.time()
            model.train()
            progress = tqdm(train_loader, desc=f'epoch {epoch + 1}',
                            leave=False)
            for i, (x_batch, w_batch, y_batch) in enumerate(progress):
                y_pred = model(**x_batch.to(config.device).to_dict())
                # w_batch weights each element of the BCE loss.
                loss = nn.BCEWithLogitsLoss(weight=w_batch.to(config.device))(
                    y_pred, y_batch.to(config.device))
                loss.backward()
                # Step only every N mini-batches when accumulating gradients;
                # a None setting means step every batch.
                if (config.gradient_accumulation_steps is None
                        or (i + 1) % config.gradient_accumulation_steps == 0):
                    optimizer.step()
                    optimizer.zero_grad()
                    scheduler.step()
                # Exponential moving average of the loss for the progress bar.
                loss_ema = loss_ema * 0.9 + loss.item(
                ) * 0.1 if loss_ema is not None else loss.item()
                progress.set_postfix(loss=loss_ema)

            valid_preds = predict(model, epoch_valid_loader,
                                  device=config.device)
            valid_score = roc_auc_score(epoch_valid_y, valid_preds)
            elapsed_time = time.time() - epoch_start_time
            print(
                f'Epoch {epoch + 1}/{config.n_epochs} \t valid score: {valid_score:.5f} \t time: {elapsed_time / 60:.1f} min'
            )

    with timer('predict'):
        valid_preds = predict(model, valid_loader, device=config.device)
        valid_score = roc_auc_score(valid_y, valid_preds)
        print(f'valid score: {valid_score:.5f}')

    output_dir = Path(f'./output/{config_path.stem}/')
    output_dir.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), output_dir / 'model.pt')
    torch.save(optimizer.state_dict(), output_dir / 'optimizer.pt')

    elapsed_time = time.time() - start_time
    print(f'all processes done in {elapsed_time / 60:.1f} min.')
def get_model(enable_model_name, is_pretraining, pretrained_path):
    """Build a Mahjong BERT model for pretraining or a specific decision task.

    Args:
        enable_model_name: one of 'discard', 'reach', 'chow', 'pong', 'kong'
            (ignored when ``is_pretraining`` is True).
        is_pretraining: when True, return the pretraining model immediately.
        pretrained_path: path to a checkpoint to warm-start from; '' skips it.

    Returns:
        The constructed model; ``None`` if ``enable_model_name`` is unknown
        and no pretrained path is given.
    """
    # Token inventory: tile(37), menzen(2), reach_state(2), n_reach(3),
    # reach_ippatsu(2), dans(21), rates(19), oya(4), scores(13), n_honba(3),
    # n_round(12), sanma_or_yonma(2), han_or_ton(2), aka_ari(2), kui_ari(2),
    # special_token(4), who(4), sum_discards(6), shanten(8)
    vocab_size = (37 + 2 + 2 + 3 + 2 + 21 + 19 + 4 + 13 + 3 + 12
                  + 2 + 2 + 2 + 2 + 4 + 4 + 6 + 8)
    hidden_size = 768
    num_attention_heads = 12
    max_position_embeddings = 239  # base + who(1) + sum_discards(1) + shanten(1)

    def _make_config(num_hidden_layers):
        # All heads share the same BERT shape; only the depth varies.
        config = BertConfig()
        config.vocab_size = vocab_size
        config.hidden_size = hidden_size
        config.num_attention_heads = num_attention_heads
        config.max_position_embeddings = max_position_embeddings
        config.num_hidden_layers = num_hidden_layers
        return config

    if is_pretraining:
        return MahjongPretrainingModel(_make_config(12))

    model = None
    if enable_model_name == 'discard':
        model = MahjongDiscardModel(_make_config(12))
    elif enable_model_name in ('reach', 'chow', 'pong', 'kong'):
        # reach/chow/pong/kong share one architecture with 24 layers.
        model = MahjongReachChowPongKongModel(_make_config(24))

    if pretrained_path != '':
        checkpoint = torch.load(pretrained_path,
                                map_location=catalyst.utils.get_device())
        # strict=False so checkpoints with extra or missing head weights
        # (presumably pretraining checkpoints — confirm) still load.
        model.load_state_dict(checkpoint['model_state_dict'], strict=False)
    return model
# Script-level setup for BERT pretraining on the sentence CSV.
output_base_dir.mkdir(exist_ok=True)

# Prefer the first CUDA device when available.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
logger.info("device: {}".format(device))

# Training corpus: two columns, document id and its sentences.
df_train = pd.read_csv(TRAIN_PATH)
df_train.columns = ["doc_id", "sents"]
sents_list = df_train["sents"].values.tolist()
logger.info("len(sents_list): {}".format(len(sents_list)))

# Small (3-layer) BERT with otherwise base-sized dimensions.
config = BertConfig()
config.num_hidden_layers = 3
config.num_attention_heads = 12
config.hidden_size = 768
config.intermediate_size = 3072
config.max_position_embeddings = 512
config.vocab_size = 32000

# USE_NSP toggles the next-sentence-prediction head on the pretraining model.
logger.info("USE_NSP: {}".format(USE_NSP))
if USE_NSP:
    model = BertForPreTraining(config)
else:
    model = BertForPreTrainingWithoutNSP(config)
model.to(device)
logger.info(config)
logger.info(model)

optimizer = AdamW(model.parameters(), lr=2e-5)
model.train()
train_losses = []