# Example 1
def get_autoencoder_config(config: configure_pretraining.PretrainingConfig,
                           bert_config: modeling.BertConfig):
    """Build a scaled-down BertConfig for the autoencoder network.

    Starts from a copy of `bert_config` and shrinks the hidden size and
    layer count by the fractions given in the pretraining config, keeping
    the standard BERT ratios (FFN = 4x hidden, one attention head per 64
    hidden dims, minimum one head).
    """
    scaled = modeling.BertConfig.from_dict(bert_config.to_dict())
    scaled.hidden_size = int(round(
        bert_config.hidden_size * config.autoencoder_hidden_size))
    scaled.num_hidden_layers = int(round(
        bert_config.num_hidden_layers * config.autoencoder_layers))
    scaled.intermediate_size = scaled.hidden_size * 4
    scaled.num_attention_heads = max(scaled.hidden_size // 64, 1)
    return scaled
# Example 2
def get_generator_config(config: configure_pretraining.PretrainingConfig,
                         bert_config: modeling.BertConfig):
    """Build a scaled-down BertConfig for the generator network.

    Copies `bert_config`, then shrinks hidden size and layer count by the
    generator fractions from the pretraining config while preserving the
    usual BERT ratios (FFN = 4x hidden, one head per 64 hidden dims, at
    least one head).
    """
    shrunk = modeling.BertConfig.from_dict(bert_config.to_dict())
    shrunk.hidden_size = int(round(
        bert_config.hidden_size * config.generator_hidden_size))
    shrunk.num_hidden_layers = int(round(
        bert_config.num_hidden_layers * config.generator_layers))
    shrunk.intermediate_size = shrunk.hidden_size * 4
    shrunk.num_attention_heads = max(shrunk.hidden_size // 64, 1)
    return shrunk
# Example 3
def main():
    """Run BERT masked-LM pretraining over .hdf5 shards.

    Parses CLI arguments, sets up single-GPU, multi-GPU (DataParallel) or
    distributed (DDP/NCCL) training with optional fp16 via apex/amp, then
    iterates over the shard files in --input_dir until --max_steps optimizer
    updates have been taken.  Every --num_steps_per_checkpoint steps the
    model is evaluated on the dev shard; the best-dev checkpoint and a
    rolling window of --save_total_limit recent checkpoints are kept in
    --output_dir.
    """

    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument(
        "--input_dir",
        default=None,
        type=str,
        required=True,
        help="The input data dir. Should contain .hdf5 files  for the task.")

    parser.add_argument(
        "--bert_model",
        default="bert-large-uncased",
        type=str,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese."
    )

    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the model checkpoints will be written."
    )

    ## Other parameters
    parser.add_argument("--config_file",
                        default=None,
                        type=str,
                        help="The BERT model config")
    parser.add_argument("--ckpt", default="", type=str)
    parser.add_argument(
        "--max_seq_length",
        default=512,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument(
        "--max_predictions_per_seq",
        default=80,
        type=int,
        help="The maximum total of masked tokens in input sequence")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--max_steps",
                        default=1000,
                        type=float,
                        help="Total number of training steps to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.01,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumualte before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        default=False,
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0.0,
        help=
        'Loss scaling, positive power of 2 values can improve fp16 convergence.'
    )
    parser.add_argument('--log_freq',
                        type=float,
                        default=500,
                        help='frequency of logging loss.')
    parser.add_argument('--checkpoint_activations',
                        default=False,
                        action='store_true',
                        help="Whether to use gradient checkpointing")
    parser.add_argument("--resume_from_checkpoint",
                        default=False,
                        action='store_true',
                        help="Whether to resume training from checkpoint.")
    parser.add_argument('--resume_step',
                        type=int,
                        default=-1,
                        help="Step to resume training from.")
    parser.add_argument(
        '--num_steps_per_checkpoint',
        type=int,
        default=2000,
        help="Number of update steps until a model checkpoint is saved to disk."
    )
    parser.add_argument('--dev_data_file', type=str, default="dev/dev.hdf5")
    parser.add_argument('--dev_batch_size', type=int, default=16)
    parser.add_argument("--save_total_limit", type=int, default=10)

    args = parser.parse_args()

    # Seed all RNGs for reproducibility.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    min_dev_loss = 1000000
    best_step = 0

    assert (torch.cuda.is_available())
    print(args.local_rank)
    if args.local_rank == -1:
        device = torch.device("cuda")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')

    logger.info("device %s n_gpu %d distributed training %r", device, n_gpu,
                bool(args.local_rank != -1))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))
    if args.train_batch_size % args.gradient_accumulation_steps != 0:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, batch size {} should be divisible"
            .format(args.gradient_accumulation_steps, args.train_batch_size))

    # From here on, train_batch_size is the per-forward-pass micro-batch size.
    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    if not args.resume_from_checkpoint and os.path.exists(
            args.output_dir) and (os.listdir(args.output_dir) and os.listdir(
                args.output_dir) != ['logfile.txt']):
        logger.warning(
            "Output directory ({}) already exists and is not empty.".format(
                args.output_dir))

    if not args.resume_from_checkpoint:
        os.makedirs(args.output_dir, exist_ok=True)

    # Prepare model
    if args.config_file:
        config = BertConfig.from_json_file(args.config_file)

    if args.bert_model:
        model = BertForMaskedLM.from_pretrained(args.bert_model)
    else:
        model = BertForMaskedLM(config)

    print(args.ckpt)
    if args.ckpt:
        print("load from", args.ckpt)
        ckpt = torch.load(args.ckpt, map_location='cpu')
        # BUGFIX: the original tested `model in ckpt`, which looks up the
        # model *object* as a dict key and is always False for our
        # checkpoints; we want the string key 'model'.
        if 'model' in ckpt:
            ckpt = ckpt['model']
        model.load_state_dict(ckpt, strict=False)

    # Snapshot the initial weights so the run is reproducible from disk.
    pretrained_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
    torch.save(model.state_dict(), pretrained_model_file)

    if not args.resume_from_checkpoint:
        global_step = 0
    else:
        if args.resume_step == -1:
            # Pick the newest ckpt_<step>.pt in the output dir.
            model_names = [
                f for f in os.listdir(args.output_dir) if f.endswith(".pt")
            ]
            args.resume_step = max([
                int(x.split('.pt')[0].split('_')[1].strip())
                for x in model_names
            ])

        global_step = args.resume_step

        checkpoint = torch.load(os.path.join(args.output_dir,
                                             "ckpt_{}.pt".format(global_step)),
                                map_location="cpu")
        model.load_state_dict(checkpoint['model'], strict=False)

        print("resume step from ", args.resume_step)

    model.to(device)

    # Prepare optimizer: no weight decay on biases and LayerNorm parameters.
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]

    if args.fp16:
        # FusedAdam does not apply warmup itself, so an external scheduler
        # is attached below; BertAdam (fp32 path) schedules internally.
        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              weight_decay=0.01)

        if args.loss_scale == 0:
            model, optimizer = amp.initialize(model,
                                              optimizer,
                                              opt_level="O2",
                                              keep_batchnorm_fp32=False,
                                              loss_scale="dynamic")
        else:
            model, optimizer = amp.initialize(model,
                                              optimizer,
                                              opt_level="O2",
                                              keep_batchnorm_fp32=False,
                                              loss_scale=args.loss_scale)

        scheduler = LinearWarmUpScheduler(optimizer,
                                          warmup=args.warmup_proportion,
                                          total_steps=args.max_steps)

    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=args.max_steps)

    if args.resume_from_checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer'])

    if args.local_rank != -1:
        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    files = [
        os.path.join(args.input_dir, f) for f in os.listdir(args.input_dir)
        if os.path.isfile(os.path.join(args.input_dir, f))
    ]
    files.sort()

    num_files = len(files)

    logger.info("***** Loading Dev Data *****")
    dev_data = pretraining_dataset(
        input_file=os.path.join(args.input_dir, args.dev_data_file),
        max_pred_length=args.max_predictions_per_seq)
    if args.local_rank == -1:
        dev_sampler = RandomSampler(dev_data)
        dev_dataloader = DataLoader(dev_data,
                                    sampler=dev_sampler,
                                    batch_size=args.dev_batch_size * n_gpu,
                                    num_workers=4,
                                    pin_memory=True)
    else:
        dev_sampler = DistributedSampler(dev_data)
        dev_dataloader = DataLoader(dev_data,
                                    sampler=dev_sampler,
                                    batch_size=args.dev_batch_size,
                                    num_workers=4,
                                    pin_memory=True)

    logger.info("***** Running training *****")
    logger.info("  Batch size = {}".format(args.train_batch_size))
    logger.info("  LR = {}".format(args.learning_rate))

    model.train()
    logger.info(" Training. . .")

    most_recent_ckpts_paths = []

    tr_loss = 0.0  # total added training loss
    average_loss = 0.0  # averaged loss every args.log_freq steps
    epoch = 0
    training_steps = 0
    while True:
        if not args.resume_from_checkpoint:
            random.shuffle(files)
            f_start_id = 0
        else:
            # Resume the exact shard ordering saved in the checkpoint, then
            # fall through to normal shuffling on subsequent epochs.
            f_start_id = checkpoint['files'][0]
            files = checkpoint['files'][1:]
            args.resume_from_checkpoint = False
        for f_id in range(f_start_id, len(files)):
            data_file = files[f_id]
            logger.info("file no {} file {}".format(f_id, data_file))
            train_data = pretraining_dataset(
                input_file=data_file,
                max_pred_length=args.max_predictions_per_seq)

            if args.local_rank == -1:
                train_sampler = RandomSampler(train_data)
                train_dataloader = DataLoader(
                    train_data,
                    sampler=train_sampler,
                    batch_size=args.train_batch_size * n_gpu,
                    num_workers=4,
                    pin_memory=True)
            else:
                train_sampler = DistributedSampler(train_data)
                train_dataloader = DataLoader(train_data,
                                              sampler=train_sampler,
                                              batch_size=args.train_batch_size,
                                              num_workers=4,
                                              pin_memory=True)

            for step, batch in enumerate(
                    tqdm(train_dataloader, desc="File Iteration")):
                model.train()
                training_steps += 1
                batch = [t.to(device) for t in batch]
                input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels = batch  #\
                loss = model(
                    input_ids=input_ids,
                    token_type_ids=segment_ids,
                    attention_mask=input_mask,
                    masked_lm_labels=masked_lm_labels,
                    checkpoint_activations=args.checkpoint_activations)
                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.

                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()
                tr_loss += loss.item()
                average_loss += loss.item()

                if training_steps % args.gradient_accumulation_steps == 0:
                    # BUGFIX: under amp the fp32 master params must be
                    # clipped, and `scheduler` only exists in the fp16
                    # branch (BertAdam applies warmup internally) — the
                    # original unconditionally called scheduler.step(),
                    # which raises NameError in the fp32 path.
                    if args.fp16:
                        torch.nn.utils.clip_grad_norm_(
                            amp.master_params(optimizer), 1.0)
                        scheduler.step()
                    else:
                        torch.nn.utils.clip_grad_norm_(
                            model.parameters(), 1.0)
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

                if training_steps == 1 * args.gradient_accumulation_steps:
                    logger.info(
                        "Global Step:{} Average Loss = {} Step Loss = {} LR {}"
                        .format(global_step, average_loss, loss.item(),
                                optimizer.param_groups[0]['lr']))

                if training_steps % (args.log_freq *
                                     args.gradient_accumulation_steps) == 0:
                    logger.info(
                        "Global Step:{} Average Loss = {} Step Loss = {} LR {}"
                        .format(global_step, average_loss / args.log_freq,
                                loss.item(), optimizer.param_groups[0]['lr']))
                    average_loss = 0

                if training_steps % (args.num_steps_per_checkpoint *
                                     args.gradient_accumulation_steps) == 0:
                    logger.info("Begin Eval")
                    model.eval()
                    with torch.no_grad():
                        dev_global_step = 0
                        dev_final_loss = 0.0
                        for dev_step, dev_batch in enumerate(
                                tqdm(dev_dataloader, desc="Evaluating")):
                            # BUGFIX: the original moved/unpacked the stale
                            # training `batch` here, so every eval step
                            # scored the same training batch instead of the
                            # dev data.
                            dev_batch = [t.to(device) for t in dev_batch]
                            dev_input_ids, dev_segment_ids, dev_input_mask, dev_masked_lm_labels, dev_next_sentence_labels = dev_batch
                            loss = model(input_ids=dev_input_ids,
                                         token_type_ids=dev_segment_ids,
                                         attention_mask=dev_input_mask,
                                         masked_lm_labels=dev_masked_lm_labels)
                            dev_final_loss += loss
                            dev_global_step += 1
                        dev_final_loss /= dev_global_step
                        if (torch.distributed.is_initialized()):
                            # all_reduce sums per-rank averages; dividing by
                            # world size first yields the global mean.
                            dev_final_loss /= torch.distributed.get_world_size(
                            )
                            torch.distributed.all_reduce(dev_final_loss)
                        logger.info("Dev Loss: {}".format(
                            dev_final_loss.item()))
                        if dev_final_loss < min_dev_loss:
                            best_step = global_step
                            min_dev_loss = dev_final_loss
                            # Only rank 0 writes checkpoints.
                            if (not torch.distributed.is_initialized() or
                                (torch.distributed.is_initialized()
                                 and torch.distributed.get_rank() == 0)):
                                logger.info(
                                    "** ** * Saving best dev loss model ** ** * at step {}"
                                    .format(best_step))
                                dev_model_to_save = model.module if hasattr(
                                    model, 'module') else model
                                output_save_file = os.path.join(
                                    args.output_dir, "best_ckpt.pt")
                                torch.save(
                                    {
                                        'model':
                                        dev_model_to_save.state_dict(),
                                        'optimizer': optimizer.state_dict(),
                                        'files': [f_id] + files
                                    }, output_save_file)

                    if (not torch.distributed.is_initialized()
                            or (torch.distributed.is_initialized()
                                and torch.distributed.get_rank() == 0)):
                        # Save a trained model
                        logger.info(
                            "** ** * Saving fine - tuned model ** ** * ")
                        model_to_save = model.module if hasattr(
                            model,
                            'module') else model  # Only save the model it-self
                        output_save_file = os.path.join(
                            args.output_dir, "ckpt_{}.pt".format(global_step))

                        torch.save(
                            {
                                'model': model_to_save.state_dict(),
                                'optimizer': optimizer.state_dict(),
                                'files': [f_id] + files
                            }, output_save_file)

                        # Keep only the newest --save_total_limit checkpoints.
                        most_recent_ckpts_paths.append(output_save_file)
                        if len(most_recent_ckpts_paths
                               ) > args.save_total_limit:
                            ckpt_to_be_removed = most_recent_ckpts_paths.pop(0)
                            os.remove(ckpt_to_be_removed)

                    if global_step >= args.max_steps:
                        # Convert the accumulated micro-batch loss back to a
                        # per-update average.
                        tr_loss = tr_loss * args.gradient_accumulation_steps / training_steps
                        if (torch.distributed.is_initialized()):
                            # BUGFIX: the original all_reduce'd a fresh
                            # tensor and discarded the result, so the logged
                            # final loss was never actually reduced.
                            loss_tensor = torch.tensor(tr_loss).cuda()
                            torch.distributed.all_reduce(loss_tensor)
                            tr_loss = loss_tensor.item(
                            ) / torch.distributed.get_world_size()
                            print(tr_loss)
                        logger.info("Total Steps:{} Final Loss = {}".format(
                            training_steps, tr_loss))

                        with open(
                                os.path.join(args.output_dir,
                                             "valid_results.txt"), "w") as f:
                            f.write("Min dev loss: {}\nBest step: {}\n".format(
                                min_dev_loss, best_step))

                        return
            # Free the shard before loading the next one.
            del train_dataloader
            del train_sampler
            del train_data

            torch.cuda.empty_cache()
        epoch += 1
def main():
    """Score per-sentence importance with a TF1 BERT regression head.

    Reads a context from --input_file (or --context), splits it into
    sentences, encodes each with the Japanese BPE encoder, runs a masked
    regression head restored from the latest checkpoint under --model, and
    prints one importance value per sentence (or writes them to a CSV when
    --output_file is set).

    NOTE(review): relies on module-level names defined elsewhere in the
    file (`parser`, `sep_txt`, `BertModel`, `BertConfig`,
    `get_masked_lm_output`, `get_next_sentence_output`,
    `get_masked_regression_output`, `BPEEncoder_ja`) — confirm they are in
    scope, and that `args` exposes `input_file`/`output_file`.
    """
    args = parser.parse_args()

    # Model hyperparameters live in <model_dir>/hparams.json.
    if os.path.isfile(args.model + '/hparams.json'):
        with open(args.model + '/hparams.json') as f:
            bert_config_params = json.load(f)
    else:
        raise ValueError('invalid model name.')

    # Exactly one source of text: an input file or an inline --context.
    if not (len(args.input_file) > 0 or len(args.context) > 0):
        raise ValueError('--input_file or --context required.')
    if (not os.path.isfile(args.input_file)) and len(args.context) == 0:
        raise ValueError('invalid input file name.')
    if len(args.input_file) > 0 and os.path.isfile(args.input_file):
        with open(args.input_file) as f:
            args.context = f.read()

    vocab_size = bert_config_params['vocab_size']
    max_seq_length = bert_config_params['max_position_embeddings']
    batch_size = 1
    # Special tokens occupy the last four vocabulary ids.
    EOT_TOKEN = vocab_size - 4
    MASK_TOKEN = vocab_size - 3
    CLS_TOKEN = vocab_size - 2
    SEP_TOKEN = vocab_size - 1

    with open('ja-bpe.txt', encoding='utf-8') as f:
        bpe = f.read().split('\n')

    with open('emoji.json', encoding='utf-8') as f:
        emoji = json.loads(f.read())

    enc = BPEEncoder_ja(bpe, emoji)

    bert_config = BertConfig(**bert_config_params)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.gpu_options.visible_device_list = args.gpu

    with tf.Session(config=config) as sess:
        # Graph placeholders; batch and sequence dims are left dynamic.
        input_ids = tf.placeholder(tf.int32, [None, None])
        input_mask = tf.placeholder(tf.int32, [None, None])
        segment_ids = tf.placeholder(tf.int32, [None, None])
        masked_lm_positions = tf.placeholder(tf.int32, [None, None])
        masked_lm_ids = tf.placeholder(tf.int32, [None, None])
        masked_lm_weights = tf.placeholder(tf.float32, [None, None])
        next_sentence_labels = tf.placeholder(tf.int32, [None])

        model = BertModel(config=bert_config,
                          is_training=False,
                          input_ids=input_ids,
                          input_mask=input_mask,
                          token_type_ids=segment_ids,
                          use_one_hot_embeddings=False)

        # These heads are built only so the checkpoint variables exist in
        # the graph; their outputs are deliberately discarded.
        output = model.get_sequence_output()
        (_, _, _) = get_masked_lm_output(bert_config,
                                         model.get_sequence_output(),
                                         model.get_embedding_table(),
                                         masked_lm_positions, masked_lm_ids,
                                         masked_lm_weights)
        (_, _, _) = get_next_sentence_output(bert_config,
                                             model.get_pooled_output(),
                                             next_sentence_labels)

        # NOTE(review): this Saver is shadowed by the one created below
        # inside the "loss" scope and appears unused.
        saver = tf.train.Saver()

        masked_lm_values = tf.placeholder(tf.float32, [None, None])

        with tf.variable_scope("loss"):
            # The regression head whose per-position outputs we run.
            (_, outputs) = get_masked_regression_output(
                bert_config, model.get_sequence_output(), masked_lm_positions,
                masked_lm_values, masked_lm_weights)

            saver = tf.train.Saver(var_list=tf.trainable_variables())
            ckpt = tf.train.latest_checkpoint(args.model)
            saver.restore(sess, ckpt)

            # Build one flat token sequence: each sentence is prefixed with
            # CLS, and the CLS offsets are recorded as regression positions.
            _input_ids = []
            _lm_positions = []
            tokens = [enc.encode(p.strip()) for p in sep_txt(args.context)]
            tokens = [t for t in tokens if len(t) > 0]
            for t in tokens:
                _lm_positions.append(len(_input_ids))
                _input_ids.extend([CLS_TOKEN] + t)
            _input_ids.append(EOT_TOKEN)
            _input_masks = [1] * len(_input_ids)
            _segments = [1] * len(_input_ids)
            # Truncate then zero-pad everything to max_seq_length.
            _input_ids = _input_ids[:max_seq_length]
            _input_masks = _input_masks[:max_seq_length]
            _segments = _segments[:max_seq_length]
            while len(_segments) < max_seq_length:
                _input_ids.append(0)
                _input_masks.append(0)
                _segments.append(0)
            _lm_positions = [p for p in _lm_positions if p < max_seq_length]
            _lm_positions = _lm_positions[:max_seq_length]
            _lm_lm_weights = [1] * len(_lm_positions)
            while len(_lm_positions) < max_seq_length:
                _lm_positions.append(0)
                _lm_lm_weights.append(0)
            # Dummy ids/values: only the regression outputs are consumed.
            _lm_ids = [0] * len(_lm_positions)
            _lm_vals = [0] * len(_lm_positions)

            regress = sess.run(outputs,
                               feed_dict={
                                   input_ids: [_input_ids],
                                   input_mask: [_input_masks],
                                   segment_ids: [_segments],
                                   masked_lm_positions: [_lm_positions],
                                   masked_lm_ids: [_lm_ids],
                                   masked_lm_weights: [_lm_lm_weights],
                                   next_sentence_labels: [0],
                                   masked_lm_values: [_lm_vals]
                               })
            # One regression value per CLS position; zip truncates to the
            # number of real sentences, ignoring the padded tail.
            regress = regress.reshape((-1, ))
            if args.output_file == '':
                for tok, value in zip(tokens, regress):
                    print(f'{value}\t{enc.decode(tok)}')
            else:
                sent = []
                impt = []
                for tok, value in zip(tokens, regress):
                    sent.append(enc.decode(tok))
                    impt.append(value)
                df = pd.DataFrame({'sentence': sent, 'importance': impt})
                df.to_csv(args.output_file, index=False)
# Example 5
# Script-level setup for a TF1 BERT inference run: CLI parsing, BPE/config
# loading, and graph placeholder construction.
# NOTE(review): this fragment is truncated — the tf.Session block continues
# beyond what is visible here.
from model.modeling import BertConfig, BertModel
from run_finetune import get_masked_lm_output,get_next_sentence_output

from encode_bpe import BPEEncoder_ja

parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, default='RoBERTa-ja_small')
parser.add_argument('--context', type=str, required=True)
parser.add_argument('--split_tag', type=str, default='')
parser.add_argument('--gpu', default='0', help='visible gpu number.')
parser.add_argument('--output_max', default=False, action='store_true')
args = parser.parse_args()

# Model hyperparameters are stored next to the checkpoint.
with open(args.model+'/hparams.json') as f:
    bert_config_params = json.load(f)
bert_config = BertConfig(**bert_config_params)
vocab_size = bert_config_params['vocab_size']
max_seq_length = bert_config_params['max_position_embeddings']
# Special tokens occupy the last four vocabulary ids.
EOT_TOKEN = vocab_size - 4
MASK_TOKEN = vocab_size - 3
CLS_TOKEN = vocab_size - 2
SEP_TOKEN = vocab_size - 1

config = tf.ConfigProto()
config.gpu_options.visible_device_list = args.gpu

with tf.Session(config=config,graph=tf.Graph()) as sess:
    # Graph placeholders; batch and sequence dims are left dynamic.
    input_ids = tf.placeholder(tf.int32, [None, None])
    input_mask = tf.placeholder(tf.int32, [None, None])
    segment_ids = tf.placeholder(tf.int32, [None, None])
    masked_lm_positions = tf.placeholder(tf.int32, [None, None])
# Example 6
def main():
    """Fine-tune a pre-trained BERT model for per-token regression.

    Encodes every file under ``args.input_dir`` with ``encode_one`` (in
    parallel worker processes), restores the pre-trained weights from
    ``args.model``, trains a masked-regression head on top of the sequence
    output, and periodically checkpoints under
    ``CHECKPOINT_DIR/args.run_name``.
    """
    global EOT_TOKEN, MASK_TOKEN, CLS_TOKEN, SEP_TOKEN, enc
    args = parser.parse_args()

    # Model hyper-parameters are shipped next to the checkpoint.
    if os.path.isfile(args.model + '/hparams.json'):
        with open(args.model + '/hparams.json') as f:
            bert_config_params = json.load(f)
    else:
        raise ValueError('invalid model name.')

    vocab_size = bert_config_params['vocab_size']
    max_seq_length = bert_config_params['max_position_embeddings']
    batch_size = args.batch_size
    save_every = args.save_every
    num_epochs = args.num_epochs
    # The last four vocabulary ids are reserved for the special tokens.
    EOT_TOKEN = vocab_size - 4
    MASK_TOKEN = vocab_size - 3
    CLS_TOKEN = vocab_size - 2
    SEP_TOKEN = vocab_size - 1

    with open('ja-bpe.txt', encoding='utf-8') as f:
        bpe = f.read().split('\n')

    with open('emoji.json', encoding='utf-8') as f:
        emoji = json.loads(f.read())

    enc = BPEEncoder_ja(bpe, emoji)

    # Tokenize the whole corpus in parallel, then shuffle the example order.
    fl = [f'{args.input_dir}/{f}' for f in os.listdir(args.input_dir)]
    with Pool(args.num_encode_process) as pool:
        imap = pool.imap(encode_one, fl)
        input_contexts = list(tqdm(imap, total=len(fl)))
    input_indexs = np.random.permutation(len(input_contexts))

    if args.do_eval:
        # Hold out the first ``eval_rate`` fraction of the shuffled indexes.
        eval_num = int(args.eval_rate * len(input_indexs))
        eval_input_indexs = input_indexs[:eval_num]
        input_indexs = input_indexs[eval_num:]

    bert_config = BertConfig(**bert_config_params)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.gpu_options.visible_device_list = args.gpu

    with tf.Session(config=config) as sess:
        # Inputs shared by the pre-training heads and the regression head.
        input_ids = tf.placeholder(tf.int32, [None, None])
        input_mask = tf.placeholder(tf.int32, [None, None])
        segment_ids = tf.placeholder(tf.int32, [None, None])
        masked_lm_positions = tf.placeholder(tf.int32, [None, None])
        masked_lm_ids = tf.placeholder(tf.int32, [None, None])
        masked_lm_weights = tf.placeholder(tf.float32, [None, None])
        next_sentence_labels = tf.placeholder(tf.int32, [None])

        model = BertModel(config=bert_config,
                          is_training=True,
                          input_ids=input_ids,
                          input_mask=input_mask,
                          token_type_ids=segment_ids,
                          use_one_hot_embeddings=False)

        # The pre-training heads are built (and their outputs discarded)
        # only so the variables they own exist for checkpoint restore.
        output = model.get_sequence_output()
        (_, _, _) = get_masked_lm_output(bert_config,
                                         model.get_sequence_output(),
                                         model.get_embedding_table(),
                                         masked_lm_positions, masked_lm_ids,
                                         masked_lm_weights)
        (_, _, _) = get_next_sentence_output(bert_config,
                                             model.get_pooled_output(),
                                             next_sentence_labels)

        # Snapshot the pre-trained weights into host memory; they get
        # re-assigned after the global initializer below wipes them.
        saver = tf.train.Saver()
        ckpt = tf.train.latest_checkpoint(args.model)
        saver.restore(sess, ckpt)
        train_vars = tf.trainable_variables()
        restored_weights = {}
        for i in range(len(train_vars)):
            restored_weights[train_vars[i].name] = sess.run(train_vars[i])

        labels = tf.placeholder(tf.float32, [
            None,
        ])

        output_layer = model.get_pooled_output()

        # TF2 exposes plain ints in shapes; TF1 wraps them in Dimension.
        if int(tf.__version__[0]) > 1:
            hidden_size = output_layer.shape[-1]
        else:
            hidden_size = output_layer.shape[-1].value

        # Per-position regression targets (same layout as masked_lm_ids).
        masked_lm_values = tf.placeholder(tf.float32, [None, None])

        with tf.variable_scope("loss"):
            (loss, _) = get_masked_regression_output(
                bert_config, model.get_sequence_output(), masked_lm_positions,
                masked_lm_values, masked_lm_weights)

            opt = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
            train_vars = tf.trainable_variables()
            opt_grads = tf.gradients(loss, train_vars)
            opt_grads = list(zip(opt_grads, train_vars))
            opt_apply = opt.apply_gradients(opt_grads)
            summaries = tf.summary.scalar('loss', loss)
            summary_log = tf.summary.FileWriter(
                os.path.join(CHECKPOINT_DIR, args.run_name))

            counter = 1
            counter_path = os.path.join(CHECKPOINT_DIR, args.run_name,
                                        'counter')
            if os.path.exists(counter_path):
                # Load the step number if we're resuming a run
                # Add 1 so we don't immediately try to save again
                with open(counter_path, 'r') as fp:
                    counter = int(fp.read()) + 1

            hparams_path = os.path.join(CHECKPOINT_DIR, args.run_name,
                                        'hparams.json')
            maketree(os.path.join(CHECKPOINT_DIR, args.run_name))
            with open(hparams_path, 'w') as fp:
                fp.write(json.dumps(bert_config_params))

            sess.run(tf.global_variables_initializer())  # init output_weights
            # Copy the snapshotted pre-trained weights back on top of the
            # fresh initialization; every snapshot entry must find its var.
            restored = 0
            for k, v in restored_weights.items():
                for i in range(len(train_vars)):
                    if train_vars[i].name == k:
                        assign_op = train_vars[i].assign(v)
                        sess.run(assign_op)
                        restored += 1
            assert restored == len(restored_weights), 'fail to restore model.'
            saver = tf.train.Saver(var_list=tf.trainable_variables())

            def save():
                # Checkpoint the model and persist the current step counter.
                maketree(os.path.join(CHECKPOINT_DIR, args.run_name))
                print(
                    'Saving',
                    os.path.join(CHECKPOINT_DIR, args.run_name,
                                 'model-{}').format(counter))
                saver.save(sess,
                           os.path.join(CHECKPOINT_DIR, args.run_name,
                                        'model'),
                           global_step=counter)
                with open(counter_path, 'w') as fp:
                    fp.write(str(counter) + '\n')

            avg_loss = (0.0, 0.0)
            start_time = time.time()

            def sample_feature(i, is_eval=False):
                """Build the feed dict for the ``i``-th train/eval batch."""
                indexs = eval_input_indexs if is_eval else input_indexs
                last = min((i + 1) * batch_size, len(indexs))
                _input_ids = []
                _input_masks = []
                _segments = []
                _lm_positions = []
                _lm_vals = []
                _lm_lm_weights = []
                _lm_ids = []
                for j in range(i * batch_size, last, 1):
                    (lm_tokens, lm_positions,
                     lm_imprtances) = input_contexts[indexs[j]]
                    # Zero-pad token ids (and the shared mask/segment row)
                    # to the full sequence length.
                    ids = copy(lm_tokens)[:max_seq_length]
                    seg = [1] * len(ids)
                    while len(ids) < max_seq_length:
                        ids.append(0)
                        seg.append(0)
                    _input_ids.append(ids)
                    _input_masks.append(seg)
                    _segments.append(seg)
                    pos = copy(lm_positions)[:max_seq_length]
                    val = copy(lm_imprtances)[:max_seq_length]
                    wei = [1] * len(pos)
                    # Pad positions/values/weights so every batch row has
                    # the same width.  (Bug fix: this loop previously
                    # compared against len(ids), which is already
                    # max_seq_length at this point, so nothing was padded
                    # and the batch rows came out ragged.)
                    while len(pos) < max_seq_length:
                        pos.append(0)
                        val.append(0)
                        wei.append(0)
                    _lm_positions.append(pos)
                    _lm_ids.append([0] * max_seq_length)
                    _lm_lm_weights.append(wei)
                    _lm_vals.append(val)

                return {
                    input_ids: _input_ids,
                    input_mask: _input_masks,
                    segment_ids: _segments,
                    masked_lm_positions: _lm_positions,
                    masked_lm_ids: _lm_ids,
                    masked_lm_weights: _lm_lm_weights,
                    next_sentence_labels: [0] * len(_input_ids),
                    masked_lm_values: _lm_vals
                }

            try:
                for ep in range(num_epochs):
                    if ep % save_every == 0:
                        save()

                    # NOTE(review): integer division drops the final partial
                    # batch of each epoch.
                    prog = tqdm(range(0, len(input_indexs) // batch_size, 1))
                    for i in prog:
                        (_, v_loss, v_summary) = sess.run(
                            (opt_apply, loss, summaries),
                            feed_dict=sample_feature(i))

                        summary_log.add_summary(v_summary, counter)

                        # Exponentially decayed running average of the loss.
                        avg_loss = (avg_loss[0] * 0.99 + v_loss,
                                    avg_loss[1] * 0.99 + 1.0)

                        prog.set_description(
                            '[{ep} | {time:2.0f}] loss={loss:.4f} avg={avg:.4f}'
                            .format(ep=ep,
                                    time=time.time() - start_time,
                                    loss=v_loss,
                                    avg=avg_loss[0] / avg_loss[1]))

                        counter += 1

                    if args.do_eval:
                        eval_losses = []
                        for i in tqdm(
                                range(0,
                                      len(eval_input_indexs) // batch_size,
                                      1)):
                            eval_losses.append(
                                sess.run(loss,
                                         feed_dict=sample_feature(i, True)))
                        print("eval loss:", np.mean(eval_losses))

            except KeyboardInterrupt:
                print('interrupted')
                save()

            save()
Ejemplo n.º 7
0
def main():
    """Fine-tune a pre-trained BERT model as a document classifier.

    Each sub-directory of ``args.input_dir`` is one class label; every file
    (or, with ``--train_by_line``, every line) inside it becomes one example.
    Pre-trained weights are restored from ``args.model`` and a softmax head
    is trained on the pooled output.  Checkpoints, hparams and the
    label-id mapping are written under ``CHECKPOINT_DIR/args.run_name``.
    """
    args = parser.parse_args()

    # Model hyper-parameters are shipped next to the checkpoint.
    if os.path.isfile(args.model + '/hparams.json'):
        with open(args.model + '/hparams.json') as f:
            bert_config_params = json.load(f)
    else:
        raise ValueError('invalid model name.')

    vocab_size = bert_config_params['vocab_size']
    max_seq_length = bert_config_params['max_position_embeddings']
    batch_size = args.batch_size
    save_every = args.save_every
    num_epochs = args.num_epochs
    # The last four vocabulary ids are reserved for the special tokens.
    EOT_TOKEN = vocab_size - 4
    MASK_TOKEN = vocab_size - 3
    CLS_TOKEN = vocab_size - 2
    SEP_TOKEN = vocab_size - 1

    with open('ja-bpe.txt', encoding='utf-8') as f:
        bpe = f.read().split('\n')

    with open('emoji.json', encoding='utf-8') as f:
        emoji = json.loads(f.read())

    enc = BPEEncoder_ja(bpe, emoji)

    # One class per sub-directory; sorted so label ids are deterministic.
    keys = [
        f for f in os.listdir(args.input_dir)
        if os.path.isdir(args.input_dir + '/' + f)
    ]
    keys = sorted(keys)
    num_labels = len(keys)
    input_contexts = []
    input_keys = []
    idmapping_dict = {}
    for i, f in enumerate(keys):
        n = 0
        for t in os.listdir(f'{args.input_dir}/{f}'):
            if os.path.isfile(f'{args.input_dir}/{f}/{t}'):
                with open(f'{args.input_dir}/{f}/{t}', encoding='utf-8') as fn:
                    if args.train_by_line:
                        # One example per line: [CLS] tokens [SEP], then
                        # zero-padded to max_seq_length.
                        for p in fn.readlines():
                            tokens = enc.encode(p.strip())[:max_seq_length - 2]
                            tokens = [CLS_TOKEN] + tokens + [SEP_TOKEN]
                            if len(tokens) < max_seq_length:
                                tokens.extend([0] *
                                              (max_seq_length - len(tokens)))
                            input_contexts.append(tokens)
                            input_keys.append(i)
                            n += 1
                    else:
                        # One example per file: [CLS] tokens [EOT] [SEP].
                        p = fn.read()
                        tokens = enc.encode(p.strip())[:max_seq_length - 3]
                        tokens = [CLS_TOKEN] + tokens + [EOT_TOKEN, SEP_TOKEN]
                        if len(tokens) < max_seq_length:
                            tokens.extend([0] * (max_seq_length - len(tokens)))
                        input_contexts.append(tokens)
                        input_keys.append(i)
                        n += 1
        print(f'{args.input_dir}/{f} mapped for id_{i}, read {n} contexts.')
        idmapping_dict[f] = i
    # Shuffle the example order once; epochs reuse the same permutation.
    input_indexs = np.random.permutation(len(input_contexts))

    bert_config = BertConfig(**bert_config_params)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.gpu_options.visible_device_list = args.gpu

    with tf.Session(config=config) as sess:
        # Inputs; the masked-LM / NSP placeholders are only needed to build
        # the pre-training heads below.
        input_ids = tf.placeholder(tf.int32, [None, None])
        input_mask = tf.placeholder(tf.int32, [None, None])
        segment_ids = tf.placeholder(tf.int32, [None, None])
        masked_lm_positions = tf.placeholder(tf.int32, [None, None])
        masked_lm_ids = tf.placeholder(tf.int32, [None, None])
        masked_lm_weights = tf.placeholder(tf.float32, [None, None])
        next_sentence_labels = tf.placeholder(tf.int32, [None])

        model = BertModel(config=bert_config,
                          is_training=True,
                          input_ids=input_ids,
                          input_mask=input_mask,
                          token_type_ids=segment_ids,
                          use_one_hot_embeddings=False)

        # The pre-training heads are built (and their outputs discarded)
        # only so the variables they own exist for checkpoint restore.
        output = model.get_sequence_output()
        (_, _, _) = get_masked_lm_output(bert_config,
                                         model.get_sequence_output(),
                                         model.get_embedding_table(),
                                         masked_lm_positions, masked_lm_ids,
                                         masked_lm_weights)
        (_, _, _) = get_next_sentence_output(bert_config,
                                             model.get_pooled_output(),
                                             next_sentence_labels)

        # Snapshot the pre-trained weights into host memory; they get
        # re-assigned after the global initializer below wipes them.
        saver = tf.train.Saver()
        ckpt = tf.train.latest_checkpoint(args.model)
        saver.restore(sess, ckpt)
        train_vars = tf.trainable_variables()
        restored_weights = {}
        for i in range(len(train_vars)):
            restored_weights[train_vars[i].name] = sess.run(train_vars[i])

        # Per-example class-id targets for the new classification head.
        labels = tf.placeholder(tf.int32, [
            None,
        ])

        output_layer = model.get_pooled_output()

        # TF2 exposes plain ints in shapes; TF1 wraps them in Dimension.
        if int(tf.__version__[0]) > 1:
            hidden_size = output_layer.shape[-1]
        else:
            hidden_size = output_layer.shape[-1].value

        # New classification head: a single dense projection to num_labels.
        output_weights = tf.get_variable(
            "output_weights", [num_labels, hidden_size],
            initializer=tf.truncated_normal_initializer(stddev=0.02))

        output_bias = tf.get_variable("output_bias", [num_labels],
                                      initializer=tf.zeros_initializer())

        with tf.variable_scope("loss"):
            # Standard softmax cross-entropy over the pooled output.
            output_layer = tf.nn.dropout(output_layer, keep_prob=0.9)
            logits = tf.matmul(output_layer, output_weights, transpose_b=True)
            logits = tf.nn.bias_add(logits, output_bias)
            probabilities = tf.nn.softmax(logits, axis=-1)
            log_probs = tf.nn.log_softmax(logits, axis=-1)

            one_hot_labels = tf.one_hot(labels,
                                        depth=num_labels,
                                        dtype=tf.float32)

            per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs,
                                              axis=-1)
            loss = tf.reduce_mean(per_example_loss)

            opt = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
            train_vars = tf.trainable_variables()
            opt_grads = tf.gradients(loss, train_vars)
            opt_grads = list(zip(opt_grads, train_vars))
            opt_apply = opt.apply_gradients(opt_grads)
            summaries = tf.summary.scalar('loss', loss)
            summary_log = tf.summary.FileWriter(
                os.path.join(CHECKPOINT_DIR, args.run_name))

            counter = 1
            counter_path = os.path.join(CHECKPOINT_DIR, args.run_name,
                                        'counter')
            if os.path.exists(counter_path):
                # Load the step number if we're resuming a run
                # Add 1 so we don't immediately try to save again
                with open(counter_path, 'r') as fp:
                    counter = int(fp.read()) + 1

            # Persist hparams and the directory-name -> label-id mapping so
            # prediction runs can reconstruct the same label space.
            hparams_path = os.path.join(CHECKPOINT_DIR, args.run_name,
                                        'hparams.json')
            maketree(os.path.join(CHECKPOINT_DIR, args.run_name))
            with open(hparams_path, 'w') as fp:
                fp.write(json.dumps(bert_config_params))
            idmaps_path = os.path.join(CHECKPOINT_DIR, args.run_name,
                                       'idmaps.json')
            with open(idmaps_path, 'w') as fp:
                fp.write(json.dumps(idmapping_dict))

            sess.run(tf.global_variables_initializer())  # init output_weights
            # Copy the snapshotted pre-trained weights back on top of the
            # fresh initialization; every snapshot entry must find its var.
            restored = 0
            for k, v in restored_weights.items():
                for i in range(len(train_vars)):
                    if train_vars[i].name == k:
                        assign_op = train_vars[i].assign(v)
                        sess.run(assign_op)
                        restored += 1
            assert restored == len(restored_weights), 'fail to restore model.'
            saver = tf.train.Saver(var_list=tf.trainable_variables())

            def save():
                # Checkpoint the model and persist the current step counter.
                maketree(os.path.join(CHECKPOINT_DIR, args.run_name))
                print(
                    'Saving',
                    os.path.join(CHECKPOINT_DIR, args.run_name,
                                 'model-{}').format(counter))
                saver.save(sess,
                           os.path.join(CHECKPOINT_DIR, args.run_name,
                                        'model'),
                           global_step=counter)
                with open(counter_path, 'w') as fp:
                    fp.write(str(counter) + '\n')

            avg_loss = (0.0, 0.0)
            start_time = time.time()

            def sample_feature(i):
                # Build the feed dict for the i-th batch.  The masked-LM /
                # NSP placeholders are fed empty/zero values since those
                # heads are unused here.
                # NOTE(review): contexts were already zero-padded to
                # max_seq_length above, so the masks/segments below come out
                # all-ones (padding is not masked) — confirm this matches
                # how the base model was pre-trained.
                last = min((i + 1) * batch_size, len(input_indexs))
                _input_ids = [
                    input_contexts[idx]
                    for idx in input_indexs[i * batch_size:last]
                ]
                _input_masks = [[1] * len(input_contexts[idx]) + [0] *
                                (max_seq_length - len(input_contexts[idx]))
                                for idx in input_indexs[i * batch_size:last]]
                _segments = [[1] * len(input_contexts[idx]) + [0] *
                             (max_seq_length - len(input_contexts[idx]))
                             for idx in input_indexs[i * batch_size:last]]
                _labels = [
                    input_keys[idx]
                    for idx in input_indexs[i * batch_size:last]
                ]
                return {
                    input_ids:
                    _input_ids,
                    input_mask:
                    _input_masks,
                    segment_ids:
                    _segments,
                    masked_lm_positions:
                    np.zeros((len(_input_ids), 0), dtype=np.int32),
                    masked_lm_ids:
                    np.zeros((len(_input_ids), 0), dtype=np.int32),
                    masked_lm_weights:
                    np.ones((len(_input_ids), 0), dtype=np.float32),
                    next_sentence_labels:
                    np.zeros((len(_input_ids), ), dtype=np.int32),
                    labels:
                    _labels
                }

            try:
                for ep in range(num_epochs):
                    if ep % args.save_every == 0:
                        save()

                    # NOTE(review): integer division drops the final partial
                    # batch of each epoch.
                    prog = tqdm.tqdm(
                        range(0,
                              len(input_contexts) // batch_size, 1))
                    for i in prog:
                        (_, v_loss, v_summary) = sess.run(
                            (opt_apply, loss, summaries),
                            feed_dict=sample_feature(i))

                        summary_log.add_summary(v_summary, counter)

                        # Exponentially decayed running average of the loss.
                        avg_loss = (avg_loss[0] * 0.99 + v_loss,
                                    avg_loss[1] * 0.99 + 1.0)

                        prog.set_description(
                            '[{ep} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
                            .format(ep=ep,
                                    time=time.time() - start_time,
                                    loss=v_loss,
                                    avg=avg_loss[0] / avg_loss[1]))

                        counter += 1
            except KeyboardInterrupt:
                print('interrupted')
                save()

            save()
def main():
    """Run the fine-tuned classifier over a labelled directory tree.

    Uses the ``idmaps.json`` written at training time to map sub-directory
    names to label ids, predicts every file (or line, with
    ``--train_by_line``), writes an id/y_true/y_pred CSV to
    ``args.output_file``, and saves a confusion-matrix image next to it.
    """
    args = parser.parse_args()

    # Model hyper-parameters and the label-id mapping live next to the
    # checkpoint.
    if os.path.isfile(args.model + '/hparams.json'):
        with open(args.model + '/hparams.json') as f:
            bert_config_params = json.load(f)
    else:
        raise ValueError('invalid model name.')
    if os.path.isfile(args.model + '/idmaps.json'):
        with open(args.model + '/idmaps.json') as f:
            idmapping_dict = json.load(f)
    else:
        raise ValueError('invalid model name.')

    vocab_size = bert_config_params['vocab_size']
    max_seq_length = bert_config_params['max_position_embeddings']
    batch_size = args.batch_size
    # The last four vocabulary ids are reserved for the special tokens.
    EOT_TOKEN = vocab_size - 4
    MASK_TOKEN = vocab_size - 3
    CLS_TOKEN = vocab_size - 2
    SEP_TOKEN = vocab_size - 1

    with open('ja-bpe.txt', encoding='utf-8') as f:
        bpe = f.read().split('\n')

    with open('emoji.json', encoding='utf-8') as f:
        emoji = json.loads(f.read())

    enc = BPEEncoder_ja(bpe, emoji)

    # Re-read the corpus using the training-time label ids.
    num_labels = len(idmapping_dict)
    input_contexts = []
    input_keys = []
    input_names = []
    for f, i in idmapping_dict.items():
        n = 0
        for t in os.listdir(f'{args.input_dir}/{f}'):
            if os.path.isfile(f'{args.input_dir}/{f}/{t}'):
                with open(f'{args.input_dir}/{f}/{t}', encoding='utf-8') as fn:
                    if args.train_by_line:
                        # One example per line: [CLS] tokens [EOT] [SEP].
                        for ln, p in enumerate(fn.readlines()):
                            tokens = enc.encode(p.strip())[:max_seq_length - 3]
                            tokens = [CLS_TOKEN
                                      ] + tokens + [EOT_TOKEN, SEP_TOKEN]
                            if len(tokens) < max_seq_length:
                                tokens.extend([0] *
                                              (max_seq_length - len(tokens)))
                            input_contexts.append(tokens)
                            input_keys.append(i)
                            input_names.append(f'{f}/{t}#{ln}')
                            n += 1
                    else:
                        # One example per file: [CLS] tokens [SEP].
                        p = fn.read()
                        tokens = enc.encode(p.strip())[:max_seq_length - 2]
                        tokens = [CLS_TOKEN] + tokens + [SEP_TOKEN]
                        if len(tokens) < max_seq_length:
                            tokens.extend([0] * (max_seq_length - len(tokens)))
                        input_contexts.append(tokens)
                        input_keys.append(i)
                        input_names.append(f'{f}/{t}')
                        n += 1
        print(f'{args.input_dir}/{f} mapped for id_{i}, read {n} contexts.')
    # Prediction runs in deterministic order (no shuffling).
    input_indexs = np.arange(len(input_contexts))

    bert_config = BertConfig(**bert_config_params)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.gpu_options.visible_device_list = args.gpu

    with tf.Session(config=config) as sess:
        # Inputs; the masked-LM / NSP placeholders are only needed to build
        # the pre-training heads below.
        input_ids = tf.placeholder(tf.int32, [None, None])
        input_mask = tf.placeholder(tf.int32, [None, None])
        segment_ids = tf.placeholder(tf.int32, [None, None])
        masked_lm_positions = tf.placeholder(tf.int32, [None, None])
        masked_lm_ids = tf.placeholder(tf.int32, [None, None])
        masked_lm_weights = tf.placeholder(tf.float32, [None, None])
        next_sentence_labels = tf.placeholder(tf.int32, [None])

        model = BertModel(config=bert_config,
                          is_training=False,
                          input_ids=input_ids,
                          input_mask=input_mask,
                          token_type_ids=segment_ids,
                          use_one_hot_embeddings=False)

        # The pre-training heads are built (and their outputs discarded)
        # only so the checkpoint's variable set matches.
        output = model.get_sequence_output()
        (_, _, _) = get_masked_lm_output(bert_config,
                                         model.get_sequence_output(),
                                         model.get_embedding_table(),
                                         masked_lm_positions, masked_lm_ids,
                                         masked_lm_weights)
        (_, _, _) = get_next_sentence_output(bert_config,
                                             model.get_pooled_output(),
                                             next_sentence_labels)

        saver = tf.train.Saver()

        # Unused by the fetched graph; declared with shape [None] (not
        # [batch_size]) so a short final batch can still be fed.
        labels = tf.placeholder(tf.int32, [
            None,
        ])

        output_layer = model.get_pooled_output()

        # TF2 exposes plain ints in shapes; TF1 wraps them in Dimension.
        if int(tf.__version__[0]) > 1:
            hidden_size = output_layer.shape[-1]
        else:
            hidden_size = output_layer.shape[-1].value

        # Classification head; weights are overwritten by the restore below.
        output_weights = tf.get_variable(
            "output_weights", [num_labels, hidden_size],
            initializer=tf.truncated_normal_initializer(stddev=0.02))

        output_bias = tf.get_variable("output_bias", [num_labels],
                                      initializer=tf.zeros_initializer())

        logits = tf.matmul(output_layer, output_weights, transpose_b=True)
        logits = tf.nn.bias_add(logits, output_bias)
        probabilities = tf.nn.softmax(logits, axis=-1)

        saver = tf.train.Saver(var_list=tf.trainable_variables())
        ckpt = tf.train.latest_checkpoint(args.model)
        saver.restore(sess, ckpt)

        def sample_feature(i):
            # Build the feed dict for the i-th batch; masked-LM / NSP
            # placeholders are fed empty/zero values.
            last = min((i + 1) * batch_size, len(input_indexs))
            _input_ids = [
                input_contexts[idx]
                for idx in input_indexs[i * batch_size:last]
            ]
            _input_masks = [[1] * len(input_contexts[idx]) + [0] *
                            (max_seq_length - len(input_contexts[idx]))
                            for idx in input_indexs[i * batch_size:last]]
            _segments = [[1] * len(input_contexts[idx]) + [0] *
                         (max_seq_length - len(input_contexts[idx]))
                         for idx in input_indexs[i * batch_size:last]]
            _labels = [
                input_keys[idx] for idx in input_indexs[i * batch_size:last]
            ]
            return {
                input_ids: _input_ids,
                input_mask: _input_masks,
                segment_ids: _segments,
                masked_lm_positions: np.zeros((len(_input_ids), 0),
                                              dtype=np.int32),
                masked_lm_ids: np.zeros((len(_input_ids), 0), dtype=np.int32),
                masked_lm_weights: np.ones((len(_input_ids), 0),
                                           dtype=np.float32),
                next_sentence_labels: np.zeros((len(_input_ids), ),
                                               dtype=np.int32),
                labels: _labels
            }

        preds = []
        # Bug fix: ceil-divide so the final partial batch is not dropped.
        # With floor division, ``preds`` came out shorter than
        # ``input_names``/``input_keys`` whenever the corpus size was not a
        # multiple of batch_size, and the DataFrame below raised.
        num_batches = (len(input_contexts) + batch_size - 1) // batch_size
        prog = tqdm.tqdm(range(0, num_batches, 1))
        for i in prog:
            prob = sess.run(probabilities, feed_dict=sample_feature(i))
            for p in prob:
                pred = np.argmax(p)
                preds.append(pred)

        pd.DataFrame({
            'id': input_names,
            'y_true': input_keys,
            'y_pred': preds
        }).to_csv(args.output_file, index=False)

        # Confusion matrix: rows = true label, columns = predicted label.
        r = np.zeros((num_labels, num_labels), dtype=int)
        for t, p in zip(input_keys, preds):
            r[t, p] += 1
        fig = plt.figure(figsize=(12, 6), dpi=72)
        ax = plt.matshow(r, interpolation='nearest', aspect=.5, cmap='cool')
        for (i, j), z in np.ndenumerate(r):
            # Abbreviate large counts so cell annotations stay readable.
            if z >= 1000:
                plt.text(j - .33,
                         i,
                         '{:0.1f}K'.format(z / 1000),
                         ha='left',
                         va='center',
                         size=9,
                         color='black')
            else:
                plt.text(j - .33,
                         i,
                         f'{z}',
                         ha='left',
                         va='center',
                         size=9,
                         color='black')
        pfile = args.output_file
        if args.output_file.lower().endswith('.csv'):
            pfile = args.output_file[:-4]
        plt.savefig(pfile + '_map.png')
Ejemplo n.º 9
0
def main():
    """Pre-train a BERT model (MLM + NSP heads) with a RoBERTa-style
    FULL-SENTENCES data pipeline and dynamic masking.

    Reads hyperparameters from the module-level argparse ``parser``,
    restores the latest checkpoint from ``args.base_model``, then trains
    indefinitely, checkpointing every ``args.save_every`` steps and on
    KeyboardInterrupt.
    """
    args = parser.parse_args()

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.gpu_options.visible_device_list = args.gpu
    # The layout optimizer can silently rewrite tensor layouts; disable it.
    config.graph_options.rewrite_options.layout_optimizer = rewriter_config_pb2.RewriterConfig.OFF

    vocab_size = 20573 + 3 # [MASK] [CLS] [SEP]
    # NOTE(review): EOT is assumed to be the last id of the *base*
    # vocabulary (vocab_size - 4), predating the 3 appended special
    # tokens above — confirm against the tokenizer that built the data.
    EOT_TOKEN = vocab_size - 4
    MASK_TOKEN = vocab_size - 3
    CLS_TOKEN = vocab_size - 2
    SEP_TOKEN = vocab_size - 1
    max_predictions_per_seq = args.max_predictions_per_seq
    batch_size = args.batch_size

    with tf.Session(config=config) as sess:
        # Feed-dict placeholders; sequence length is dynamic (None).
        input_ids = tf.placeholder(tf.int32, [batch_size, None])
        input_mask = tf.placeholder(tf.int32, [batch_size, None])
        segment_ids = tf.placeholder(tf.int32, [batch_size, None])
        masked_lm_positions = tf.placeholder(tf.int32, [batch_size, None])
        masked_lm_ids = tf.placeholder(tf.int32, [batch_size, None])
        masked_lm_weights = tf.placeholder(tf.float32, [batch_size, None])
        next_sentence_labels = tf.placeholder(tf.int32, [None])

        if os.path.isfile(args.base_model+'/hparams.json'):
            with open(args.base_model+'/hparams.json') as f:
                bert_config_params = json.loads(f.read())
        else:
            raise ValueError('invalid model name.')

        max_seq_length = bert_config_params['max_position_embeddings']

        bert_config = BertConfig(**bert_config_params)
        model = BertModel(
            config=bert_config,
            is_training=True,
            input_ids=input_ids,
            input_mask=input_mask,
            use_one_hot_embeddings=False)

        # Only the losses are needed; per-example log-probs are discarded.
        (masked_lm_loss,_,_) = get_masked_lm_output(
             bert_config, model.get_sequence_output(), model.get_embedding_table(),
             masked_lm_positions, masked_lm_ids, masked_lm_weights)
        (next_sentence_loss,_,_) = get_next_sentence_output(
             bert_config, model.get_pooled_output(), next_sentence_labels)

        loss = masked_lm_loss + next_sentence_loss

        global_step = tf.Variable(0, trainable=False)
        # BUG FIX: the original loop did `global_step = global_step + 1`,
        # which only rebinds the Python name to a new, never-executed add
        # node — the variable itself never advanced, so the warmup
        # schedule below stayed frozen at step 0 (and the graph grew one
        # node per iteration). Build the increment op once, run it per step.
        incr_global_step = tf.assign_add(global_step, 1)
        if args.warmup_steps > 0:
            # Ramp from ~0 up to the target rate over warmup_steps.
            learning_rate = tf.compat.v1.train.polynomial_decay(
                    learning_rate=1e-10,
                    end_learning_rate=args.learning_rate,
                    global_step=global_step,
                    decay_steps=args.warmup_steps
                )
        else:
            learning_rate = args.learning_rate

        if args.optim=='adam':
            opt = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                           beta1=0.9,
                                           beta2=0.98,
                                           epsilon=1e-7)
        elif args.optim=='adagrad':
            opt = tf.train.AdagradOptimizer(learning_rate=learning_rate)
        elif args.optim=='sgd':
            opt = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
        else:
            raise ValueError('invalid optimizer name.')

        train_vars = tf.trainable_variables()
        opt_grads = tf.gradients(loss, train_vars)
        opt_grads = list(zip(opt_grads, train_vars))
        opt_apply = opt.apply_gradients(opt_grads)

        summaries = tf.summary.scalar('loss', loss)
        summary_log = tf.summary.FileWriter(
            os.path.join(CHECKPOINT_DIR, args.run_name))

        saver = tf.train.Saver(
            var_list=train_vars,
            max_to_keep=5,
            keep_checkpoint_every_n_hours=2)
        sess.run(tf.global_variables_initializer())

        ckpt = tf.train.latest_checkpoint(args.base_model)
        print('Loading checkpoint', ckpt)
        saver.restore(sess, ckpt)

        print('Loading dataset...')
        # .npz archive of token-id arrays; one entry ("chunk") per document.
        global_chunks = np.load(args.dataset)
        global_chunk_index = copy(global_chunks.files)
        global_chunk_step = 0
        global_epochs = 0
        np.random.shuffle(global_chunk_index)

        def get_epoch():
            # Completed epochs plus the fraction of chunks consumed so far.
            return global_epochs + (1 - len(global_chunk_index) / len(global_chunks.files))

        def pop_feature():
            """Build one max_seq_length token sequence, FULL-SENTENCES style:
            [CLS] then document tokens, crossing document boundaries with an
            EOT separator until the sequence is full."""
            nonlocal global_chunks,global_chunk_index,global_chunk_step, global_epochs
            # FULL-SENTENCES
            token = [np.uint16(CLS_TOKEN)]
            chunk = global_chunks[global_chunk_index[-1]][global_chunk_step:].astype(np.uint16)
            if len(chunk) >= max_seq_length-1:
                # Current document alone fills the sequence.
                token.extend(chunk[:max_seq_length-1].tolist())
                global_chunk_step += max_seq_length-1
            else:
                # Exhaust this document, then keep appending EOT-separated
                # documents until the sequence is full.
                if len(chunk) > 0:
                    token.extend(chunk.tolist())
                    token.append(np.uint16(EOT_TOKEN))
                    global_chunk_step += len(chunk)+1
                while len(token) < max_seq_length:
                    global_chunk_index.pop()
                    global_chunk_step = 0
                    if len(global_chunk_index) == 0:
                        # All chunks consumed: reshuffle and start a new epoch.
                        global_chunk_index = copy(global_chunks.files)
                        np.random.shuffle(global_chunk_index)
                        global_epochs += 1
                    cur = len(token)
                    chunk = global_chunks[global_chunk_index[-1]].astype(np.uint16)
                    token.extend(chunk[:max_seq_length-cur].tolist())
                    global_chunk_step += max_seq_length-cur
                    if len(token) < max_seq_length:
                        token.append(np.uint16(EOT_TOKEN))
            return token

        print('Training...')

        def sample_feature():
            """Assemble one feed_dict batch with fresh (dynamic) masking."""
            nonlocal global_chunks,global_chunk_index,global_chunk_step
            # Use dynamic mask
            p_input_ids = []
            p_input_mask = []
            p_segment_ids = []
            p_masked_lm_positions = []
            p_masked_lm_ids = []
            p_masked_lm_weights = []
            # NSP is unused under FULL-SENTENCES; labels are all zero.
            p_next_sentence_labels = [0] * batch_size

            for b in range(batch_size): # FULL-SENTENCES
                sampled_token = pop_feature()
                # Make Sequence
                ids = copy(sampled_token)
                masks = [1]*len(ids)
                segments = [1]*len(ids)
                # Make Masks: pick up to max_predictions_per_seq random
                # positions, skipping special tokens (ids >= EOT_TOKEN).
                mask_indexs = []
                for i in np.random.permutation(max_seq_length):
                    if ids[i] < EOT_TOKEN:
                        mask_indexs.append(i)
                    if len(mask_indexs) >= max_predictions_per_seq:
                        break

                lm_positions = []
                lm_ids = []
                lm_weights = []
                for i in sorted(mask_indexs):
                    masked_token = None
                    # 80% of the time, replace with [MASK]
                    if np.random.random() < 0.8:
                        masked_token = MASK_TOKEN # [MASK]
                    else:
                        # 10% of the time, keep original
                        if np.random.random() < 0.5:
                            masked_token = ids[i]
                        # 10% of the time, replace with random word
                        else:
                            masked_token = np.random.randint(EOT_TOKEN-1)

                    lm_positions.append(i)
                    lm_ids.append(ids[i])
                    lm_weights.append(1.0)
                    # apply mask
                    ids[i] = masked_token
                # Pad prediction slots to a fixed width; weight 0 = ignored.
                while len(lm_positions) < max_predictions_per_seq:
                    lm_positions.append(0)
                    lm_ids.append(0)
                    lm_weights.append(0.0)

                p_input_ids.append(ids)
                p_input_mask.append(masks)
                p_segment_ids.append(segments)
                p_masked_lm_positions.append(lm_positions)
                p_masked_lm_ids.append(lm_ids)
                p_masked_lm_weights.append(lm_weights)

            return {input_ids:p_input_ids,
                    input_mask:p_input_mask,
                    segment_ids:p_segment_ids,
                    masked_lm_positions:p_masked_lm_positions,
                    masked_lm_ids:p_masked_lm_ids,
                    masked_lm_weights:p_masked_lm_weights,
                    next_sentence_labels:p_next_sentence_labels}

        counter = 1
        counter_path = os.path.join(CHECKPOINT_DIR, args.run_name, 'counter')
        if os.path.exists(counter_path):
            # Load the step number if we're resuming a run
            # Add 1 so we don't immediately try to save again
            with open(counter_path, 'r') as fp:
                counter = int(fp.read()) + 1

        hparams_path = os.path.join(CHECKPOINT_DIR, args.run_name, 'hparams.json')
        maketree(os.path.join(CHECKPOINT_DIR, args.run_name))
        with open(hparams_path, 'w') as fp:
            fp.write(json.dumps(bert_config_params))

        def save():
            """Write a checkpoint and persist the step counter."""
            maketree(os.path.join(CHECKPOINT_DIR, args.run_name))
            print(
                'Saving',
                os.path.join(CHECKPOINT_DIR, args.run_name,
                             'model-{}').format(counter))
            saver.save(
                sess,
                os.path.join(CHECKPOINT_DIR, args.run_name, 'model'),
                global_step=counter)
            with open(counter_path, 'w') as fp:
                fp.write(str(counter) + '\n')

        # Exponentially-decayed running average: (weighted sum, weight sum).
        avg_loss = (0.0, 0.0)
        start_time = time.time()

        try:
            while True:
                if counter % args.save_every == 0:
                    save()

                (_, v_loss, v_summary) = sess.run(
                    (opt_apply, loss, summaries),
                    feed_dict=sample_feature())

                summary_log.add_summary(v_summary, counter)

                avg_loss = (avg_loss[0] * 0.99 + v_loss,
                            avg_loss[1] * 0.99 + 1.0)

                print(
                    '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
                    .format(
                        counter=counter,
                        time=time.time() - start_time,
                        loss=v_loss,
                        avg=avg_loss[0] / avg_loss[1]))

                counter = counter+1
                if args.warmup_steps > 0:
                    # Actually advance the schedule variable (see fix above).
                    sess.run(incr_global_step)
        except KeyboardInterrupt:
            print('interrupted')
            save()