def get_autoencoder_config(config: configure_pretraining.PretrainingConfig,
                           bert_config: modeling.BertConfig):
    """Get model config for the autoencoder network."""
    ae_config = modeling.BertConfig.from_dict(bert_config.to_dict())
    ae_config.hidden_size = int(
        round(bert_config.hidden_size * config.autoencoder_hidden_size))
    ae_config.num_hidden_layers = int(
        round(bert_config.num_hidden_layers * config.autoencoder_layers))
    ae_config.intermediate_size = 4 * ae_config.hidden_size
    ae_config.num_attention_heads = max(1, ae_config.hidden_size // 64)
    return ae_config
def get_generator_config(config: configure_pretraining.PretrainingConfig,
                         bert_config: modeling.BertConfig):
    """Get model config for the generator network."""
    gen_config = modeling.BertConfig.from_dict(bert_config.to_dict())
    gen_config.hidden_size = int(
        round(bert_config.hidden_size * config.generator_hidden_size))
    gen_config.num_hidden_layers = int(
        round(bert_config.num_hidden_layers * config.generator_layers))
    gen_config.intermediate_size = 4 * gen_config.hidden_size
    gen_config.num_attention_heads = max(1, gen_config.hidden_size // 64)
    return gen_config
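# Illustrative sketch (not part of the original code): how the fractions above
# scale a BERT-Base-shaped config. The fraction values below are hypothetical;
# in practice they come from configure_pretraining.PretrainingConfig.
def scaled_dims(hidden_size, num_layers, hidden_frac, layer_frac):
    h = int(round(hidden_size * hidden_frac))
    return {
        'hidden_size': h,
        'num_hidden_layers': int(round(num_layers * layer_frac)),
        'intermediate_size': 4 * h,              # FFN kept at 4x the hidden size
        'num_attention_heads': max(1, h // 64),  # roughly 64 dims per head
    }

# A quarter-width, third-depth generator derived from BERT-Base (768 hidden, 12 layers):
print(scaled_dims(768, 12, 0.25, 1 / 3))
# {'hidden_size': 192, 'num_hidden_layers': 4, 'intermediate_size': 768, 'num_attention_heads': 3}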
def main(): parser = argparse.ArgumentParser() ## Required parameters parser.add_argument( "--input_dir", default=None, type=str, required=True, help="The input data dir. Should contain .hdf5 files for the task.") parser.add_argument( "--bert_model", default="bert-large-uncased", type=str, help="Bert pre-trained model selected in the list: bert-base-uncased, " "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese." ) parser.add_argument( "--output_dir", default=None, type=str, required=True, help="The output directory where the model checkpoints will be written." ) ## Other parameters parser.add_argument("--config_file", default=None, type=str, help="The BERT model config") parser.add_argument("--ckpt", default="", type=str) parser.add_argument( "--max_seq_length", default=512, type=int, help= "The maximum total input sequence length after WordPiece tokenization. \n" "Sequences longer than this will be truncated, and sequences shorter \n" "than this will be padded.") parser.add_argument( "--max_predictions_per_seq", default=80, type=int, help="The maximum total of masked tokens in input sequence") parser.add_argument("--train_batch_size", default=32, type=int, help="Total batch size for training.") parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.") parser.add_argument("--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform.") parser.add_argument("--max_steps", default=1000, type=float, help="Total number of training steps to perform.") parser.add_argument( "--warmup_proportion", default=0.01, type=float, help= "Proportion of training to perform linear learning rate warmup for. " "E.g., 0.1 = 10%% of training.") parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus") parser.add_argument('--seed', type=int, default=42, help="random seed for initialization") parser.add_argument( '--gradient_accumulation_steps', type=int, default=1, help= "Number of updates steps to accumualte before performing a backward/update pass." ) parser.add_argument( '--fp16', default=False, action='store_true', help="Whether to use 16-bit float precision instead of 32-bit") parser.add_argument( '--loss_scale', type=float, default=0.0, help= 'Loss scaling, positive power of 2 values can improve fp16 convergence.' ) parser.add_argument('--log_freq', type=float, default=500, help='frequency of logging loss.') parser.add_argument('--checkpoint_activations', default=False, action='store_true', help="Whether to use gradient checkpointing") parser.add_argument("--resume_from_checkpoint", default=False, action='store_true', help="Whether to resume training from checkpoint.") parser.add_argument('--resume_step', type=int, default=-1, help="Step to resume training from.") parser.add_argument( '--num_steps_per_checkpoint', type=int, default=2000, help="Number of update steps until a model checkpoint is saved to disk." 
    )
    parser.add_argument('--dev_data_file', type=str, default="dev/dev.hdf5")
    parser.add_argument('--dev_batch_size', type=int, default=16)
    parser.add_argument("--save_total_limit", type=int, default=10)

    args = parser.parse_args()

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    min_dev_loss = 1000000
    best_step = 0

    assert torch.cuda.is_available()
    print(args.local_rank)
    if args.local_rank == -1:
        device = torch.device("cuda")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
    logger.info("device %s n_gpu %d distributed training %r", device, n_gpu,
                bool(args.local_rank != -1))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))
    if args.train_batch_size % args.gradient_accumulation_steps != 0:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, batch size {} should be divisible"
            .format(args.gradient_accumulation_steps, args.train_batch_size))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    if not args.resume_from_checkpoint and os.path.exists(args.output_dir) and (
            os.listdir(args.output_dir)
            and os.listdir(args.output_dir) != ['logfile.txt']):
        logger.warning(
            "Output directory ({}) already exists and is not empty.".format(
                args.output_dir))
        # raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))

    if not args.resume_from_checkpoint:
        os.makedirs(args.output_dir, exist_ok=True)

    # Prepare model
    if args.config_file:
        config = BertConfig.from_json_file(args.config_file)
    if args.bert_model:
        model = BertForMaskedLM.from_pretrained(args.bert_model)
    else:
        model = BertForMaskedLM(config)
    print(args.ckpt)
    if args.ckpt:
        print("load from", args.ckpt)
        ckpt = torch.load(args.ckpt, map_location='cpu')
        if 'model' in ckpt:  # unwrap checkpoints saved as {'model': state_dict, ...}
            ckpt = ckpt['model']
        model.load_state_dict(ckpt, strict=False)

    pretrained_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
    torch.save(model.state_dict(), pretrained_model_file)

    if not args.resume_from_checkpoint:
        global_step = 0
    else:
        if args.resume_step == -1:
            model_names = [
                f for f in os.listdir(args.output_dir) if f.endswith(".pt")
            ]
            args.resume_step = max([
                int(x.split('.pt')[0].split('_')[1].strip())
                for x in model_names
            ])
        global_step = args.resume_step
        checkpoint = torch.load(os.path.join(
            args.output_dir, "ckpt_{}.pt".format(global_step)),
                                map_location="cpu")
        model.load_state_dict(checkpoint['model'], strict=False)
        print("resume step from ", args.resume_step)

    model.to(device)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay': 0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay': 0.0
    }]

    if args.fp16:
        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              weight_decay=0.01)
        if args.loss_scale == 0:
            model, optimizer = amp.initialize(model, optimizer,
                                              opt_level="O2",
                                              keep_batchnorm_fp32=False,
                                              loss_scale="dynamic")
        else:
            model, optimizer = amp.initialize(model, optimizer,
                                              opt_level="O2",
keep_batchnorm_fp32=False, loss_scale=args.loss_scale) scheduler = LinearWarmUpScheduler(optimizer, warmup=args.warmup_proportion, total_steps=args.max_steps) else: optimizer = BertAdam(optimizer_grouped_parameters, lr=args.learning_rate, warmup=args.warmup_proportion, t_total=args.max_steps) if args.resume_from_checkpoint: optimizer.load_state_dict(checkpoint['optimizer']) if args.local_rank != -1: model = DDP(model) elif n_gpu > 1: model = torch.nn.DataParallel(model) files = [ os.path.join(args.input_dir, f) for f in os.listdir(args.input_dir) if os.path.isfile(os.path.join(args.input_dir, f)) ] files.sort() num_files = len(files) logger.info("***** Loading Dev Data *****") dev_data = pretraining_dataset( input_file=os.path.join(args.input_dir, args.dev_data_file), max_pred_length=args.max_predictions_per_seq) if args.local_rank == -1: dev_sampler = RandomSampler(dev_data) dev_dataloader = DataLoader(dev_data, sampler=dev_sampler, batch_size=args.dev_batch_size * n_gpu, num_workers=4, pin_memory=True) else: dev_sampler = DistributedSampler(dev_data) dev_dataloader = DataLoader(dev_data, sampler=dev_sampler, batch_size=args.dev_batch_size, num_workers=4, pin_memory=True) logger.info("***** Running training *****") logger.info(" Batch size = {}".format(args.train_batch_size)) logger.info(" LR = {}".format(args.learning_rate)) model.train() logger.info(" Training. . .") most_recent_ckpts_paths = [] tr_loss = 0.0 # total added training loss average_loss = 0.0 # averaged loss every args.log_freq steps epoch = 0 training_steps = 0 while True: if not args.resume_from_checkpoint: random.shuffle(files) f_start_id = 0 else: f_start_id = checkpoint['files'][0] files = checkpoint['files'][1:] args.resume_from_checkpoint = False for f_id in range(f_start_id, len(files)): data_file = files[f_id] logger.info("file no {} file {}".format(f_id, data_file)) train_data = pretraining_dataset( input_file=data_file, max_pred_length=args.max_predictions_per_seq) if args.local_rank == -1: train_sampler = RandomSampler(train_data) train_dataloader = DataLoader( train_data, sampler=train_sampler, batch_size=args.train_batch_size * n_gpu, num_workers=4, pin_memory=True) else: train_sampler = DistributedSampler(train_data) train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size, num_workers=4, pin_memory=True) for step, batch in enumerate( tqdm(train_dataloader, desc="File Iteration")): model.train() training_steps += 1 batch = [t.to(device) for t in batch] input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels = batch #\ loss = model( input_ids=input_ids, token_type_ids=segment_ids, attention_mask=input_mask, masked_lm_labels=masked_lm_labels, checkpoint_activations=args.checkpoint_activations) if n_gpu > 1: loss = loss.mean() # mean() to average on multi-gpu. 
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()
                tr_loss += loss.item()
                average_loss += loss.item()

                if training_steps % args.gradient_accumulation_steps == 0:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    if args.fp16:
                        # The LinearWarmUpScheduler only exists on the fp16/FusedAdam
                        # path; BertAdam applies its own warmup schedule.
                        scheduler.step()
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

                if training_steps == 1 * args.gradient_accumulation_steps:
                    logger.info(
                        "Global Step:{} Average Loss = {} Step Loss = {} LR {}".format(
                            global_step, average_loss, loss.item(),
                            optimizer.param_groups[0]['lr']))

                if training_steps % (args.log_freq *
                                     args.gradient_accumulation_steps) == 0:
                    logger.info(
                        "Global Step:{} Average Loss = {} Step Loss = {} LR {}".format(
                            global_step, average_loss / args.log_freq,
                            loss.item(), optimizer.param_groups[0]['lr']))
                    average_loss = 0

                if training_steps % (args.num_steps_per_checkpoint *
                                     args.gradient_accumulation_steps) == 0:
                    logger.info("Begin Eval")
                    model.eval()
                    with torch.no_grad():
                        dev_global_step = 0
                        dev_final_loss = 0.0
                        for dev_step, dev_batch in enumerate(
                                tqdm(dev_dataloader, desc="Evaluating")):
                            # Evaluate on the dev batch (the original code reused the
                            # stale training batch here).
                            dev_batch = [t.to(device) for t in dev_batch]
                            (dev_input_ids, dev_segment_ids, dev_input_mask,
                             dev_masked_lm_labels, dev_next_sentence_labels) = dev_batch
                            loss = model(input_ids=dev_input_ids,
                                         token_type_ids=dev_segment_ids,
                                         attention_mask=dev_input_mask,
                                         masked_lm_labels=dev_masked_lm_labels)
                            dev_final_loss += loss
                            dev_global_step += 1
                        dev_final_loss /= dev_global_step
                        if torch.distributed.is_initialized():
                            dev_final_loss /= torch.distributed.get_world_size()
                            torch.distributed.all_reduce(dev_final_loss)
                        logger.info("Dev Loss: {}".format(dev_final_loss.item()))
                        if dev_final_loss < min_dev_loss:
                            best_step = global_step
                            min_dev_loss = dev_final_loss
                            if (not torch.distributed.is_initialized()
                                    or torch.distributed.get_rank() == 0):
                                logger.info(
                                    "** ** * Saving best dev loss model ** ** * at step {}"
                                    .format(best_step))
                                dev_model_to_save = model.module if hasattr(
                                    model, 'module') else model
                                output_save_file = os.path.join(
                                    args.output_dir, "best_ckpt.pt")
                                torch.save(
                                    {
                                        'model': dev_model_to_save.state_dict(),
                                        'optimizer': optimizer.state_dict(),
                                        'files': [f_id] + files
                                    }, output_save_file)

                    if (not torch.distributed.is_initialized()
                            or torch.distributed.get_rank() == 0):
                        # Save a trained model
                        logger.info("** ** * Saving fine - tuned model ** ** * ")
                        model_to_save = model.module if hasattr(
                            model, 'module') else model  # Only save the model it-self
                        output_save_file = os.path.join(
                            args.output_dir, "ckpt_{}.pt".format(global_step))
                        torch.save(
                            {
                                'model': model_to_save.state_dict(),
                                'optimizer': optimizer.state_dict(),
                                'files': [f_id] + files
                            }, output_save_file)
                        most_recent_ckpts_paths.append(output_save_file)
                        if len(most_recent_ckpts_paths) > args.save_total_limit:
                            ckpt_to_be_removed = most_recent_ckpts_paths.pop(0)
                            os.remove(ckpt_to_be_removed)

                if global_step >= args.max_steps:
                    tr_loss = tr_loss * args.gradient_accumulation_steps / training_steps
                    if torch.distributed.is_initialized():
                        tr_loss /= torch.distributed.get_world_size()
                        print(tr_loss)
                        torch.distributed.all_reduce(torch.tensor(tr_loss).cuda())
                    logger.info("Total Steps:{} Final Loss = {}".format(
                        training_steps, tr_loss))
                    with open(os.path.join(args.output_dir, "valid_results.txt"),
                              "w") as f:
                        f.write("Min dev loss: {}\nBest step: {}\n".format(
                            min_dev_loss, best_step))
                    return

            del train_dataloader
            del train_sampler
            del train_data
            torch.cuda.empty_cache()
        epoch += 1
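# Illustrative sketch (hypothetical path): each "ckpt_<step>.pt" written above is a
# plain dict holding the model weights, the optimizer state, and the remaining data
# shard list, which is exactly what the --resume_from_checkpoint branch reads back.
import torch

checkpoint = torch.load("output_dir/ckpt_2000.pt", map_location="cpu")  # hypothetical file
print(sorted(checkpoint.keys()))  # ['files', 'model', 'optimizer']
# model.load_state_dict(checkpoint['model'], strict=False)
# optimizer.load_state_dict(checkpoint['optimizer'])
# remaining_files = checkpoint['files']  # [current file index] + shuffled shard list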
def main(): args = parser.parse_args() if os.path.isfile(args.model + '/hparams.json'): with open(args.model + '/hparams.json') as f: bert_config_params = json.load(f) else: raise ValueError('invalid model name.') if not (len(args.input_file) > 0 or len(args.context) > 0): raise ValueError('--input_file or --context required.') if (not os.path.isfile(args.input_file)) and len(args.context) == 0: raise ValueError('invalid input file name.') if len(args.input_file) > 0 and os.path.isfile(args.input_file): with open(args.input_file) as f: args.context = f.read() vocab_size = bert_config_params['vocab_size'] max_seq_length = bert_config_params['max_position_embeddings'] batch_size = 1 EOT_TOKEN = vocab_size - 4 MASK_TOKEN = vocab_size - 3 CLS_TOKEN = vocab_size - 2 SEP_TOKEN = vocab_size - 1 with open('ja-bpe.txt', encoding='utf-8') as f: bpe = f.read().split('\n') with open('emoji.json', encoding='utf-8') as f: emoji = json.loads(f.read()) enc = BPEEncoder_ja(bpe, emoji) bert_config = BertConfig(**bert_config_params) config = tf.ConfigProto() config.gpu_options.allow_growth = True config.gpu_options.visible_device_list = args.gpu with tf.Session(config=config) as sess: input_ids = tf.placeholder(tf.int32, [None, None]) input_mask = tf.placeholder(tf.int32, [None, None]) segment_ids = tf.placeholder(tf.int32, [None, None]) masked_lm_positions = tf.placeholder(tf.int32, [None, None]) masked_lm_ids = tf.placeholder(tf.int32, [None, None]) masked_lm_weights = tf.placeholder(tf.float32, [None, None]) next_sentence_labels = tf.placeholder(tf.int32, [None]) model = BertModel(config=bert_config, is_training=False, input_ids=input_ids, input_mask=input_mask, token_type_ids=segment_ids, use_one_hot_embeddings=False) output = model.get_sequence_output() (_, _, _) = get_masked_lm_output(bert_config, model.get_sequence_output(), model.get_embedding_table(), masked_lm_positions, masked_lm_ids, masked_lm_weights) (_, _, _) = get_next_sentence_output(bert_config, model.get_pooled_output(), next_sentence_labels) saver = tf.train.Saver() masked_lm_values = tf.placeholder(tf.float32, [None, None]) with tf.variable_scope("loss"): (_, outputs) = get_masked_regression_output( bert_config, model.get_sequence_output(), masked_lm_positions, masked_lm_values, masked_lm_weights) saver = tf.train.Saver(var_list=tf.trainable_variables()) ckpt = tf.train.latest_checkpoint(args.model) saver.restore(sess, ckpt) _input_ids = [] _lm_positions = [] tokens = [enc.encode(p.strip()) for p in sep_txt(args.context)] tokens = [t for t in tokens if len(t) > 0] for t in tokens: _lm_positions.append(len(_input_ids)) _input_ids.extend([CLS_TOKEN] + t) _input_ids.append(EOT_TOKEN) _input_masks = [1] * len(_input_ids) _segments = [1] * len(_input_ids) _input_ids = _input_ids[:max_seq_length] _input_masks = _input_masks[:max_seq_length] _segments = _segments[:max_seq_length] while len(_segments) < max_seq_length: _input_ids.append(0) _input_masks.append(0) _segments.append(0) _lm_positions = [p for p in _lm_positions if p < max_seq_length] _lm_positions = _lm_positions[:max_seq_length] _lm_lm_weights = [1] * len(_lm_positions) while len(_lm_positions) < max_seq_length: _lm_positions.append(0) _lm_lm_weights.append(0) _lm_ids = [0] * len(_lm_positions) _lm_vals = [0] * len(_lm_positions) regress = sess.run(outputs, feed_dict={ input_ids: [_input_ids], input_mask: [_input_masks], segment_ids: [_segments], masked_lm_positions: [_lm_positions], masked_lm_ids: [_lm_ids], masked_lm_weights: [_lm_lm_weights], next_sentence_labels: [0], 
                           masked_lm_values: [_lm_vals]
                       })
        regress = regress.reshape((-1, ))

        if args.output_file == '':
            for tok, value in zip(tokens, regress):
                print(f'{value}\t{enc.decode(tok)}')
        else:
            sent = []
            impt = []
            for tok, value in zip(tokens, regress):
                sent.append(enc.decode(tok))
                impt.append(value)
            df = pd.DataFrame({'sentence': sent, 'importance': impt})
            df.to_csv(args.output_file, index=False)
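# Illustrative sketch (toy ids, not real BPE output): how the regression input above
# is assembled. Each sentence is prefixed with CLS_TOKEN, the position of that CLS is
# recorded in _lm_positions, and the document ends with EOT_TOKEN; the regression head
# then predicts one importance value per recorded CLS position.
CLS_TOKEN, EOT_TOKEN = 20574, 20572   # example ids for a 20576-token vocabulary
sentences = [[11, 12, 13], [21, 22]]  # toy BPE ids for two sentences

input_ids, lm_positions = [], []
for toks in sentences:
    lm_positions.append(len(input_ids))  # index of this sentence's CLS token
    input_ids.extend([CLS_TOKEN] + toks)
input_ids.append(EOT_TOKEN)

print(input_ids)     # [20574, 11, 12, 13, 20574, 21, 22, 20572]
print(lm_positions)  # [0, 4]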
import argparse
import json

import tensorflow as tf

from model.modeling import BertConfig, BertModel
from run_finetune import get_masked_lm_output, get_next_sentence_output
from encode_bpe import BPEEncoder_ja

parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, default='RoBERTa-ja_small')
parser.add_argument('--context', type=str, required=True)
parser.add_argument('--split_tag', type=str, default='')
parser.add_argument('--gpu', default='0', help='visible gpu number.')
parser.add_argument('--output_max', default=False, action='store_true')
args = parser.parse_args()

with open(args.model + '/hparams.json') as f:
    bert_config_params = json.load(f)
bert_config = BertConfig(**bert_config_params)
vocab_size = bert_config_params['vocab_size']
max_seq_length = bert_config_params['max_position_embeddings']

EOT_TOKEN = vocab_size - 4
MASK_TOKEN = vocab_size - 3
CLS_TOKEN = vocab_size - 2
SEP_TOKEN = vocab_size - 1

config = tf.ConfigProto()
config.gpu_options.visible_device_list = args.gpu
with tf.Session(config=config, graph=tf.Graph()) as sess:
    input_ids = tf.placeholder(tf.int32, [None, None])
    input_mask = tf.placeholder(tf.int32, [None, None])
    segment_ids = tf.placeholder(tf.int32, [None, None])
    masked_lm_positions = tf.placeholder(tf.int32, [None, None])
def main(): global EOT_TOKEN, MASK_TOKEN, CLS_TOKEN, SEP_TOKEN, enc args = parser.parse_args() if os.path.isfile(args.model + '/hparams.json'): with open(args.model + '/hparams.json') as f: bert_config_params = json.load(f) else: raise ValueError('invalid model name.') vocab_size = bert_config_params['vocab_size'] max_seq_length = bert_config_params['max_position_embeddings'] batch_size = args.batch_size save_every = args.save_every num_epochs = args.num_epochs EOT_TOKEN = vocab_size - 4 MASK_TOKEN = vocab_size - 3 CLS_TOKEN = vocab_size - 2 SEP_TOKEN = vocab_size - 1 with open('ja-bpe.txt', encoding='utf-8') as f: bpe = f.read().split('\n') with open('emoji.json', encoding='utf-8') as f: emoji = json.loads(f.read()) enc = BPEEncoder_ja(bpe, emoji) fl = [f'{args.input_dir}/{f}' for f in os.listdir(args.input_dir)] with Pool(args.num_encode_process) as pool: imap = pool.imap(encode_one, fl) input_contexts = list(tqdm(imap, total=len(fl))) input_indexs = np.random.permutation(len(input_contexts)) if args.do_eval: eval_num = int(args.eval_rate * len(input_indexs)) eval_input_indexs = input_indexs[:eval_num] input_indexs = input_indexs[eval_num:] bert_config = BertConfig(**bert_config_params) config = tf.ConfigProto() config.gpu_options.allow_growth = True config.gpu_options.visible_device_list = args.gpu with tf.Session(config=config) as sess: input_ids = tf.placeholder(tf.int32, [None, None]) input_mask = tf.placeholder(tf.int32, [None, None]) segment_ids = tf.placeholder(tf.int32, [None, None]) masked_lm_positions = tf.placeholder(tf.int32, [None, None]) masked_lm_ids = tf.placeholder(tf.int32, [None, None]) masked_lm_weights = tf.placeholder(tf.float32, [None, None]) next_sentence_labels = tf.placeholder(tf.int32, [None]) model = BertModel(config=bert_config, is_training=True, input_ids=input_ids, input_mask=input_mask, token_type_ids=segment_ids, use_one_hot_embeddings=False) output = model.get_sequence_output() (_, _, _) = get_masked_lm_output(bert_config, model.get_sequence_output(), model.get_embedding_table(), masked_lm_positions, masked_lm_ids, masked_lm_weights) (_, _, _) = get_next_sentence_output(bert_config, model.get_pooled_output(), next_sentence_labels) saver = tf.train.Saver() ckpt = tf.train.latest_checkpoint(args.model) saver.restore(sess, ckpt) train_vars = tf.trainable_variables() restored_weights = {} for i in range(len(train_vars)): restored_weights[train_vars[i].name] = sess.run(train_vars[i]) labels = tf.placeholder(tf.float32, [ None, ]) output_layer = model.get_pooled_output() if int(tf.__version__[0]) > 1: hidden_size = output_layer.shape[-1] else: hidden_size = output_layer.shape[-1].value masked_lm_values = tf.placeholder(tf.float32, [None, None]) with tf.variable_scope("loss"): (loss, _) = get_masked_regression_output( bert_config, model.get_sequence_output(), masked_lm_positions, masked_lm_values, masked_lm_weights) opt = tf.train.AdamOptimizer(learning_rate=args.learning_rate) train_vars = tf.trainable_variables() opt_grads = tf.gradients(loss, train_vars) opt_grads = list(zip(opt_grads, train_vars)) opt_apply = opt.apply_gradients(opt_grads) summaries = tf.summary.scalar('loss', loss) summary_log = tf.summary.FileWriter( os.path.join(CHECKPOINT_DIR, args.run_name)) counter = 1 counter_path = os.path.join(CHECKPOINT_DIR, args.run_name, 'counter') if os.path.exists(counter_path): # Load the step number if we're resuming a run # Add 1 so we don't immediately try to save again with open(counter_path, 'r') as fp: counter = int(fp.read()) + 1 hparams_path = 
os.path.join(CHECKPOINT_DIR, args.run_name, 'hparams.json') maketree(os.path.join(CHECKPOINT_DIR, args.run_name)) with open(hparams_path, 'w') as fp: fp.write(json.dumps(bert_config_params)) sess.run(tf.global_variables_initializer()) # init output_weights restored = 0 for k, v in restored_weights.items(): for i in range(len(train_vars)): if train_vars[i].name == k: assign_op = train_vars[i].assign(v) sess.run(assign_op) restored += 1 assert restored == len(restored_weights), 'fail to restore model.' saver = tf.train.Saver(var_list=tf.trainable_variables()) def save(): maketree(os.path.join(CHECKPOINT_DIR, args.run_name)) print( 'Saving', os.path.join(CHECKPOINT_DIR, args.run_name, 'model-{}').format(counter)) saver.save(sess, os.path.join(CHECKPOINT_DIR, args.run_name, 'model'), global_step=counter) with open(counter_path, 'w') as fp: fp.write(str(counter) + '\n') avg_loss = (0.0, 0.0) start_time = time.time() def sample_feature(i, eval=False): indexs = eval_input_indexs if eval else input_indexs last = min((i + 1) * batch_size, len(indexs)) _input_ids = [] _input_masks = [] _segments = [] _lm_positions = [] _lm_vals = [] _lm_lm_weights = [] _lm_ids = [] for j in range(i * batch_size, last, 1): (lm_tokens, lm_positions, lm_imprtances) = input_contexts[indexs[j]] ids = copy(lm_tokens)[:max_seq_length] seg = [1] * len(ids) while len(ids) < max_seq_length: ids.append(0) seg.append(0) _input_ids.append(ids) _input_masks.append(seg) _segments.append(seg) pos = copy(lm_positions)[:max_seq_length] val = copy(lm_imprtances)[:max_seq_length] wei = [1] * len(pos) while len(ids) < max_seq_length: pos.append(0) val.append(0) wei.append(0) _lm_positions.append(pos) _lm_ids.append([0] * max_seq_length) _lm_lm_weights.append(wei) _lm_vals.append(val) return { input_ids: _input_ids, input_mask: _input_masks, segment_ids: _segments, masked_lm_positions: _lm_positions, masked_lm_ids: _lm_ids, masked_lm_weights: _lm_lm_weights, next_sentence_labels: [0] * len(_input_ids), masked_lm_values: _lm_vals } try: for ep in range(num_epochs): if ep % args.save_every == 0: save() prog = tqdm(range(0, len(input_indexs) // batch_size, 1)) for i in prog: (_, v_loss, v_summary) = sess.run( (opt_apply, loss, summaries), feed_dict=sample_feature(i)) summary_log.add_summary(v_summary, counter) avg_loss = (avg_loss[0] * 0.99 + v_loss, avg_loss[1] * 0.99 + 1.0) prog.set_description( '[{ep} | {time:2.0f}] loss={loss:.4f} avg={avg:.4f}' .format(ep=ep, time=time.time() - start_time, loss=v_loss, avg=avg_loss[0] / avg_loss[1])) counter += 1 if args.do_eval: eval_losses = [] for i in tqdm( range(0, len(eval_input_indexs) // batch_size, 1)): eval_losses.append( sess.run(loss, feed_dict=sample_feature(i, True))) print("eval loss:", np.mean(eval_losses)) except KeyboardInterrupt: print('interrupted') save() save()
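# Illustrative sketch: the (numerator, denominator) pair shown in the progress bar
# above is a bias-corrected exponential moving average. With decay 0.99 the displayed
# value is sum(0.99**k * loss[t-k]) / sum(0.99**k), so early steps are not dragged
# toward zero. The loss values below are made up.
losses = [4.0, 3.5, 3.2, 3.0]
avg_loss = (0.0, 0.0)
for v_loss in losses:
    avg_loss = (avg_loss[0] * 0.99 + v_loss, avg_loss[1] * 0.99 + 1.0)
    print(round(avg_loss[0] / avg_loss[1], 4))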
def main(): args = parser.parse_args() if os.path.isfile(args.model + '/hparams.json'): with open(args.model + '/hparams.json') as f: bert_config_params = json.load(f) else: raise ValueError('invalid model name.') vocab_size = bert_config_params['vocab_size'] max_seq_length = bert_config_params['max_position_embeddings'] batch_size = args.batch_size save_every = args.save_every num_epochs = args.num_epochs EOT_TOKEN = vocab_size - 4 MASK_TOKEN = vocab_size - 3 CLS_TOKEN = vocab_size - 2 SEP_TOKEN = vocab_size - 1 with open('ja-bpe.txt', encoding='utf-8') as f: bpe = f.read().split('\n') with open('emoji.json', encoding='utf-8') as f: emoji = json.loads(f.read()) enc = BPEEncoder_ja(bpe, emoji) keys = [ f for f in os.listdir(args.input_dir) if os.path.isdir(args.input_dir + '/' + f) ] keys = sorted(keys) num_labels = len(keys) input_contexts = [] input_keys = [] idmapping_dict = {} for i, f in enumerate(keys): n = 0 for t in os.listdir(f'{args.input_dir}/{f}'): if os.path.isfile(f'{args.input_dir}/{f}/{t}'): with open(f'{args.input_dir}/{f}/{t}', encoding='utf-8') as fn: if args.train_by_line: for p in fn.readlines(): tokens = enc.encode(p.strip())[:max_seq_length - 2] tokens = [CLS_TOKEN] + tokens + [SEP_TOKEN] if len(tokens) < max_seq_length: tokens.extend([0] * (max_seq_length - len(tokens))) input_contexts.append(tokens) input_keys.append(i) n += 1 else: p = fn.read() tokens = enc.encode(p.strip())[:max_seq_length - 3] tokens = [CLS_TOKEN] + tokens + [EOT_TOKEN, SEP_TOKEN] if len(tokens) < max_seq_length: tokens.extend([0] * (max_seq_length - len(tokens))) input_contexts.append(tokens) input_keys.append(i) n += 1 print(f'{args.input_dir}/{f} mapped for id_{i}, read {n} contexts.') idmapping_dict[f] = i input_indexs = np.random.permutation(len(input_contexts)) bert_config = BertConfig(**bert_config_params) config = tf.ConfigProto() config.gpu_options.allow_growth = True config.gpu_options.visible_device_list = args.gpu with tf.Session(config=config) as sess: input_ids = tf.placeholder(tf.int32, [None, None]) input_mask = tf.placeholder(tf.int32, [None, None]) segment_ids = tf.placeholder(tf.int32, [None, None]) masked_lm_positions = tf.placeholder(tf.int32, [None, None]) masked_lm_ids = tf.placeholder(tf.int32, [None, None]) masked_lm_weights = tf.placeholder(tf.float32, [None, None]) next_sentence_labels = tf.placeholder(tf.int32, [None]) model = BertModel(config=bert_config, is_training=True, input_ids=input_ids, input_mask=input_mask, token_type_ids=segment_ids, use_one_hot_embeddings=False) output = model.get_sequence_output() (_, _, _) = get_masked_lm_output(bert_config, model.get_sequence_output(), model.get_embedding_table(), masked_lm_positions, masked_lm_ids, masked_lm_weights) (_, _, _) = get_next_sentence_output(bert_config, model.get_pooled_output(), next_sentence_labels) saver = tf.train.Saver() ckpt = tf.train.latest_checkpoint(args.model) saver.restore(sess, ckpt) train_vars = tf.trainable_variables() restored_weights = {} for i in range(len(train_vars)): restored_weights[train_vars[i].name] = sess.run(train_vars[i]) labels = tf.placeholder(tf.int32, [ None, ]) output_layer = model.get_pooled_output() if int(tf.__version__[0]) > 1: hidden_size = output_layer.shape[-1] else: hidden_size = output_layer.shape[-1].value output_weights = tf.get_variable( "output_weights", [num_labels, hidden_size], initializer=tf.truncated_normal_initializer(stddev=0.02)) output_bias = tf.get_variable("output_bias", [num_labels], initializer=tf.zeros_initializer()) with 
tf.variable_scope("loss"): output_layer = tf.nn.dropout(output_layer, keep_prob=0.9) logits = tf.matmul(output_layer, output_weights, transpose_b=True) logits = tf.nn.bias_add(logits, output_bias) probabilities = tf.nn.softmax(logits, axis=-1) log_probs = tf.nn.log_softmax(logits, axis=-1) one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32) per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1) loss = tf.reduce_mean(per_example_loss) opt = tf.train.AdamOptimizer(learning_rate=args.learning_rate) train_vars = tf.trainable_variables() opt_grads = tf.gradients(loss, train_vars) opt_grads = list(zip(opt_grads, train_vars)) opt_apply = opt.apply_gradients(opt_grads) summaries = tf.summary.scalar('loss', loss) summary_log = tf.summary.FileWriter( os.path.join(CHECKPOINT_DIR, args.run_name)) counter = 1 counter_path = os.path.join(CHECKPOINT_DIR, args.run_name, 'counter') if os.path.exists(counter_path): # Load the step number if we're resuming a run # Add 1 so we don't immediately try to save again with open(counter_path, 'r') as fp: counter = int(fp.read()) + 1 hparams_path = os.path.join(CHECKPOINT_DIR, args.run_name, 'hparams.json') maketree(os.path.join(CHECKPOINT_DIR, args.run_name)) with open(hparams_path, 'w') as fp: fp.write(json.dumps(bert_config_params)) idmaps_path = os.path.join(CHECKPOINT_DIR, args.run_name, 'idmaps.json') with open(idmaps_path, 'w') as fp: fp.write(json.dumps(idmapping_dict)) sess.run(tf.global_variables_initializer()) # init output_weights restored = 0 for k, v in restored_weights.items(): for i in range(len(train_vars)): if train_vars[i].name == k: assign_op = train_vars[i].assign(v) sess.run(assign_op) restored += 1 assert restored == len(restored_weights), 'fail to restore model.' saver = tf.train.Saver(var_list=tf.trainable_variables()) def save(): maketree(os.path.join(CHECKPOINT_DIR, args.run_name)) print( 'Saving', os.path.join(CHECKPOINT_DIR, args.run_name, 'model-{}').format(counter)) saver.save(sess, os.path.join(CHECKPOINT_DIR, args.run_name, 'model'), global_step=counter) with open(counter_path, 'w') as fp: fp.write(str(counter) + '\n') avg_loss = (0.0, 0.0) start_time = time.time() def sample_feature(i): last = min((i + 1) * batch_size, len(input_indexs)) _input_ids = [ input_contexts[idx] for idx in input_indexs[i * batch_size:last] ] _input_masks = [[1] * len(input_contexts[idx]) + [0] * (max_seq_length - len(input_contexts[idx])) for idx in input_indexs[i * batch_size:last]] _segments = [[1] * len(input_contexts[idx]) + [0] * (max_seq_length - len(input_contexts[idx])) for idx in input_indexs[i * batch_size:last]] _labels = [ input_keys[idx] for idx in input_indexs[i * batch_size:last] ] return { input_ids: _input_ids, input_mask: _input_masks, segment_ids: _segments, masked_lm_positions: np.zeros((len(_input_ids), 0), dtype=np.int32), masked_lm_ids: np.zeros((len(_input_ids), 0), dtype=np.int32), masked_lm_weights: np.ones((len(_input_ids), 0), dtype=np.float32), next_sentence_labels: np.zeros((len(_input_ids), ), dtype=np.int32), labels: _labels } try: for ep in range(num_epochs): if ep % args.save_every == 0: save() prog = tqdm.tqdm( range(0, len(input_contexts) // batch_size, 1)) for i in prog: (_, v_loss, v_summary) = sess.run( (opt_apply, loss, summaries), feed_dict=sample_feature(i)) summary_log.add_summary(v_summary, counter) avg_loss = (avg_loss[0] * 0.99 + v_loss, avg_loss[1] * 0.99 + 1.0) prog.set_description( '[{ep} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}' .format(ep=ep, time=time.time() - 
start_time, loss=v_loss, avg=avg_loss[0] / avg_loss[1])) counter += 1 except KeyboardInterrupt: print('interrupted') save() save()
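# Illustrative sketch (numpy, toy shapes): the classification head defined above is a
# single dense layer over the pooled [CLS] vector followed by softmax cross-entropy.
# All values below are made up for demonstration.
import numpy as np

rng = np.random.default_rng(0)
batch, hidden, num_labels = 2, 8, 3
pooled = rng.normal(size=(batch, hidden))              # stands in for model.get_pooled_output()
W = rng.normal(scale=0.02, size=(num_labels, hidden))  # "output_weights"
b = np.zeros(num_labels)                               # "output_bias"

logits = pooled @ W.T + b                              # matmul with transpose_b=True, plus bias
log_probs = logits - np.log(np.exp(logits).sum(-1, keepdims=True))
labels = np.array([0, 2])
one_hot = np.eye(num_labels)[labels]
per_example_loss = -(one_hot * log_probs).sum(-1)
print(per_example_loss.mean())                         # the scalar fed to AdamOptimizer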
def main(): args = parser.parse_args() if os.path.isfile(args.model + '/hparams.json'): with open(args.model + '/hparams.json') as f: bert_config_params = json.load(f) else: raise ValueError('invalid model name.') if os.path.isfile(args.model + '/idmaps.json'): with open(args.model + '/idmaps.json') as f: idmapping_dict = json.load(f) else: raise ValueError('invalid model name.') vocab_size = bert_config_params['vocab_size'] max_seq_length = bert_config_params['max_position_embeddings'] batch_size = args.batch_size EOT_TOKEN = vocab_size - 4 MASK_TOKEN = vocab_size - 3 CLS_TOKEN = vocab_size - 2 SEP_TOKEN = vocab_size - 1 with open('ja-bpe.txt', encoding='utf-8') as f: bpe = f.read().split('\n') with open('emoji.json', encoding='utf-8') as f: emoji = json.loads(f.read()) enc = BPEEncoder_ja(bpe, emoji) num_labels = len(idmapping_dict) input_contexts = [] input_keys = [] input_names = [] for f, i in idmapping_dict.items(): n = 0 for t in os.listdir(f'{args.input_dir}/{f}'): if os.path.isfile(f'{args.input_dir}/{f}/{t}'): with open(f'{args.input_dir}/{f}/{t}', encoding='utf-8') as fn: if args.train_by_line: for ln, p in enumerate(fn.readlines()): tokens = enc.encode(p.strip())[:max_seq_length - 3] tokens = [CLS_TOKEN ] + tokens + [EOT_TOKEN, SEP_TOKEN] if len(tokens) < max_seq_length: tokens.extend([0] * (max_seq_length - len(tokens))) input_contexts.append(tokens) input_keys.append(i) input_names.append(f'{f}/{t}#{ln}') n += 1 else: p = fn.read() tokens = enc.encode(p.strip())[:max_seq_length - 2] tokens = [CLS_TOKEN] + tokens + [SEP_TOKEN] if len(tokens) < max_seq_length: tokens.extend([0] * (max_seq_length - len(tokens))) input_contexts.append(tokens) input_keys.append(i) input_names.append(f'{f}/{t}') n += 1 print(f'{args.input_dir}/{f} mapped for id_{i}, read {n} contexts.') input_indexs = np.arange(len(input_contexts)) bert_config = BertConfig(**bert_config_params) config = tf.ConfigProto() config.gpu_options.allow_growth = True config.gpu_options.visible_device_list = args.gpu with tf.Session(config=config) as sess: input_ids = tf.placeholder(tf.int32, [None, None]) input_mask = tf.placeholder(tf.int32, [None, None]) segment_ids = tf.placeholder(tf.int32, [None, None]) masked_lm_positions = tf.placeholder(tf.int32, [None, None]) masked_lm_ids = tf.placeholder(tf.int32, [None, None]) masked_lm_weights = tf.placeholder(tf.float32, [None, None]) next_sentence_labels = tf.placeholder(tf.int32, [None]) model = BertModel(config=bert_config, is_training=False, input_ids=input_ids, input_mask=input_mask, token_type_ids=segment_ids, use_one_hot_embeddings=False) output = model.get_sequence_output() (_, _, _) = get_masked_lm_output(bert_config, model.get_sequence_output(), model.get_embedding_table(), masked_lm_positions, masked_lm_ids, masked_lm_weights) (_, _, _) = get_next_sentence_output(bert_config, model.get_pooled_output(), next_sentence_labels) saver = tf.train.Saver() labels = tf.placeholder(tf.int32, [ batch_size, ]) output_layer = model.get_pooled_output() if int(tf.__version__[0]) > 1: hidden_size = output_layer.shape[-1] else: hidden_size = output_layer.shape[-1].value output_weights = tf.get_variable( "output_weights", [num_labels, hidden_size], initializer=tf.truncated_normal_initializer(stddev=0.02)) output_bias = tf.get_variable("output_bias", [num_labels], initializer=tf.zeros_initializer()) logits = tf.matmul(output_layer, output_weights, transpose_b=True) logits = tf.nn.bias_add(logits, output_bias) probabilities = tf.nn.softmax(logits, axis=-1) saver = 
tf.train.Saver(var_list=tf.trainable_variables()) ckpt = tf.train.latest_checkpoint(args.model) saver.restore(sess, ckpt) def sample_feature(i): last = min((i + 1) * batch_size, len(input_indexs)) _input_ids = [ input_contexts[idx] for idx in input_indexs[i * batch_size:last] ] _input_masks = [[1] * len(input_contexts[idx]) + [0] * (max_seq_length - len(input_contexts[idx])) for idx in input_indexs[i * batch_size:last]] _segments = [[1] * len(input_contexts[idx]) + [0] * (max_seq_length - len(input_contexts[idx])) for idx in input_indexs[i * batch_size:last]] _labels = [ input_keys[idx] for idx in input_indexs[i * batch_size:last] ] return { input_ids: _input_ids, input_mask: _input_masks, segment_ids: _segments, masked_lm_positions: np.zeros((len(_input_ids), 0), dtype=np.int32), masked_lm_ids: np.zeros((len(_input_ids), 0), dtype=np.int32), masked_lm_weights: np.ones((len(_input_ids), 0), dtype=np.float32), next_sentence_labels: np.zeros((len(_input_ids), ), dtype=np.int32), labels: _labels } preds = [] prog = tqdm.tqdm(range(0, len(input_contexts) // batch_size, 1)) for i in prog: prob = sess.run(probabilities, feed_dict=sample_feature(i)) for p in prob: pred = np.argmax(p) preds.append(pred) pd.DataFrame({ 'id': input_names, 'y_true': input_keys, 'y_pred': preds }).to_csv(args.output_file, index=False) r = np.zeros((num_labels, num_labels), dtype=int) for t, p in zip(input_keys, preds): r[t, p] += 1 fig = plt.figure(figsize=(12, 6), dpi=72) ax = plt.matshow(r, interpolation='nearest', aspect=.5, cmap='cool') for (i, j), z in np.ndenumerate(r): if z >= 1000: plt.text(j - .33, i, '{:0.1f}K'.format(z / 1000), ha='left', va='center', size=9, color='black') else: plt.text(j - .33, i, f'{z}', ha='left', va='center', size=9, color='black') pfile = args.output_file if args.output_file.lower().endswith('.csv'): pfile = args.output_file[:-4] plt.savefig(pfile + '_map.png')
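# Illustrative sketch (toy labels): the confusion matrix rendered to *_map.png above
# simply counts (true, predicted) pairs -- row = true label, column = prediction.
import numpy as np

num_labels = 3
y_true = [0, 0, 1, 2, 2, 2]
y_pred = [0, 1, 1, 2, 0, 2]
r = np.zeros((num_labels, num_labels), dtype=int)
for t, p in zip(y_true, y_pred):
    r[t, p] += 1
print(r)
# [[1 1 0]
#  [0 1 0]
#  [1 0 2]]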
def main(): args = parser.parse_args() config = tf.ConfigProto() config.gpu_options.allow_growth = True config.gpu_options.visible_device_list = args.gpu config.graph_options.rewrite_options.layout_optimizer = rewriter_config_pb2.RewriterConfig.OFF vocab_size = 20573 + 3 # [MASK] [CLS] [SEP] EOT_TOKEN = vocab_size - 4 MASK_TOKEN = vocab_size - 3 CLS_TOKEN = vocab_size - 2 SEP_TOKEN = vocab_size - 1 max_predictions_per_seq = args.max_predictions_per_seq batch_size = args.batch_size with tf.Session(config=config) as sess: input_ids = tf.placeholder(tf.int32, [batch_size, None]) input_mask = tf.placeholder(tf.int32, [batch_size, None]) segment_ids = tf.placeholder(tf.int32, [batch_size, None]) masked_lm_positions = tf.placeholder(tf.int32, [batch_size, None]) masked_lm_ids = tf.placeholder(tf.int32, [batch_size, None]) masked_lm_weights = tf.placeholder(tf.float32, [batch_size, None]) next_sentence_labels = tf.placeholder(tf.int32, [None]) if os.path.isfile(args.base_model+'/hparams.json'): with open(args.base_model+'/hparams.json') as f: bert_config_params = json.loads(f.read()) else: raise ValueError('invalid model name.') max_seq_length = bert_config_params['max_position_embeddings'] bert_config = BertConfig(**bert_config_params) model = BertModel( config=bert_config, is_training=True, input_ids=input_ids, input_mask=input_mask, use_one_hot_embeddings=False) (masked_lm_loss,_,_) = get_masked_lm_output( bert_config, model.get_sequence_output(), model.get_embedding_table(), masked_lm_positions, masked_lm_ids, masked_lm_weights) (next_sentence_loss,_,_) = get_next_sentence_output( bert_config, model.get_pooled_output(), next_sentence_labels) loss = masked_lm_loss + next_sentence_loss train_vars = tf.trainable_variables() global_step = tf.Variable(0, trainable=False) if args.warmup_steps > 0: learning_rate = tf.compat.v1.train.polynomial_decay( learning_rate=1e-10, end_learning_rate=args.learning_rate, global_step=global_step, decay_steps=args.warmup_steps ) else: learning_rate = args.learning_rate if args.optim=='adam': opt = tf.train.AdamOptimizer(learning_rate=learning_rate, beta1=0.9, beta2=0.98, epsilon=1e-7) elif args.optim=='adagrad': opt = tf.train.AdagradOptimizer(learning_rate=learning_rate) elif args.optim=='sgd': opt = tf.train.GradientDescentOptimizer(learning_rate=learning_rate) else: raise ValueError('invalid optimizer name.') train_vars = tf.trainable_variables() opt_grads = tf.gradients(loss, train_vars) opt_grads = list(zip(opt_grads, train_vars)) opt_apply = opt.apply_gradients(opt_grads) summaries = tf.summary.scalar('loss', loss) summary_log = tf.summary.FileWriter( os.path.join(CHECKPOINT_DIR, args.run_name)) saver = tf.train.Saver( var_list=train_vars, max_to_keep=5, keep_checkpoint_every_n_hours=2) sess.run(tf.global_variables_initializer()) ckpt = tf.train.latest_checkpoint(args.base_model) saver.restore(sess, ckpt) print('Loading checkpoint', ckpt) print('Loading dataset...') global_chunks = np.load(args.dataset) global_chunk_index = copy(global_chunks.files) global_chunk_step = 0 global_epochs = 0 np.random.shuffle(global_chunk_index) def get_epoch(): return global_epochs + (1 - len(global_chunk_index) / len(global_chunks.files)) def pop_feature(): nonlocal global_chunks,global_chunk_index,global_chunk_step, global_epochs # FULL-SENTENCES token = [np.uint16(CLS_TOKEN)] chunk = global_chunks[global_chunk_index[-1]][global_chunk_step:].astype(np.uint16) if len(chunk) >= max_seq_length-1: token.extend(chunk[:max_seq_length-1].tolist()) global_chunk_step += 
max_seq_length-1 else: if len(chunk) > 0: token.extend(chunk.tolist()) token.append(np.uint16(EOT_TOKEN)) global_chunk_step += len(chunk)+1 while len(token) < max_seq_length: global_chunk_index.pop() global_chunk_step = 0 if len(global_chunk_index) == 0: global_chunk_index = copy(global_chunks.files) np.random.shuffle(global_chunk_index) global_epochs += 1 cur = len(token) chunk = global_chunks[global_chunk_index[-1]].astype(np.uint16) token.extend(chunk[:max_seq_length-cur].tolist()) global_chunk_step += max_seq_length-cur if len(token) < max_seq_length: token.append(np.uint16(EOT_TOKEN)) return token print('Training...') def sample_feature(): nonlocal global_chunks,global_chunk_index,global_chunk_step # Use dynamic mask p_input_ids = [] p_input_mask = [] p_segment_ids = [] p_masked_lm_positions = [] p_masked_lm_ids = [] p_masked_lm_weights = [] p_next_sentence_labels = [0] * batch_size for b in range(batch_size): # FULL-SENTENCES sampled_token = pop_feature() # Make Sequence ids = copy(sampled_token) masks = [1]*len(ids) segments = [1]*len(ids) # Make Masks mask_indexs = [] for i in np.random.permutation(max_seq_length): if ids[i] < EOT_TOKEN: mask_indexs.append(i) if len(mask_indexs) >= max_predictions_per_seq: break lm_positions = [] lm_ids = [] lm_weights = [] for i in sorted(mask_indexs): masked_token = None # 80% of the time, replace with [MASK] if np.random.random() < 0.8: masked_token = MASK_TOKEN # [MASK] else: # 10% of the time, keep original if np.random.random() < 0.5: masked_token = ids[i] # 10% of the time, replace with random word else: masked_token = np.random.randint(EOT_TOKEN-1) lm_positions.append(i) lm_ids.append(ids[i]) lm_weights.append(1.0) # apply mask ids[i] = masked_token while len(lm_positions) < max_predictions_per_seq: lm_positions.append(0) lm_ids.append(0) lm_weights.append(0.0) p_input_ids.append(ids) p_input_mask.append(masks) p_segment_ids.append(segments) p_masked_lm_positions.append(lm_positions) p_masked_lm_ids.append(lm_ids) p_masked_lm_weights.append(lm_weights) return {input_ids:p_input_ids, input_mask:p_input_mask, segment_ids:p_segment_ids, masked_lm_positions:p_masked_lm_positions, masked_lm_ids:p_masked_lm_ids, masked_lm_weights:p_masked_lm_weights, next_sentence_labels:p_next_sentence_labels} counter = 1 counter_path = os.path.join(CHECKPOINT_DIR, args.run_name, 'counter') if os.path.exists(counter_path): # Load the step number if we're resuming a run # Add 1 so we don't immediately try to save again with open(counter_path, 'r') as fp: counter = int(fp.read()) + 1 hparams_path = os.path.join(CHECKPOINT_DIR, args.run_name, 'hparams.json') maketree(os.path.join(CHECKPOINT_DIR, args.run_name)) with open(hparams_path, 'w') as fp: fp.write(json.dumps(bert_config_params)) def save(): maketree(os.path.join(CHECKPOINT_DIR, args.run_name)) print( 'Saving', os.path.join(CHECKPOINT_DIR, args.run_name, 'model-{}').format(counter)) saver.save( sess, os.path.join(CHECKPOINT_DIR, args.run_name, 'model'), global_step=counter) with open(counter_path, 'w') as fp: fp.write(str(counter) + '\n') avg_loss = (0.0, 0.0) start_time = time.time() try: while True: if counter % args.save_every == 0: save() (_, v_loss, v_summary) = sess.run( (opt_apply, loss, summaries), feed_dict=sample_feature()) summary_log.add_summary(v_summary, counter) avg_loss = (avg_loss[0] * 0.99 + v_loss, avg_loss[1] * 0.99 + 1.0) print( '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}' .format( counter=counter, time=time.time() - start_time, loss=v_loss, avg=avg_loss[0] / avg_loss[1])) 
                counter = counter + 1
                if args.warmup_steps > 0:
                    # global_step is a tf.Variable that drives the polynomial warm-up;
                    # rebinding the Python name would not advance it, so update the
                    # variable itself in the session.
                    global_step.load(sess.run(global_step) + 1, session=sess)
        except KeyboardInterrupt:
            print('interrupted')
            save()
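# Illustrative sketch (toy values): the dynamic masking used in sample_feature() above.
# For each selected position the token becomes [MASK] 80% of the time, stays unchanged
# 10% of the time, and is replaced by a random ordinary word id 10% of the time; the
# original id is kept as the prediction target for the masked-LM loss.
import numpy as np

MASK_TOKEN, EOT_TOKEN = 20573, 20572  # ids as derived from vocab_size above

def mask_one(original_id, rng):
    if rng.random() < 0.8:
        return MASK_TOKEN                    # 80%: replace with [MASK]
    if rng.random() < 0.5:
        return original_id                   # 10%: keep the original token
    return int(rng.integers(EOT_TOKEN - 1))  # 10%: random word id below EOT

rng = np.random.default_rng(0)
ids = [11, 12, 13, 14]
position, target_id = 2, ids[2]              # target the LM head must predict
ids[position] = mask_one(ids[position], rng)
print(ids, target_id)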