def __init__(self, cfg, word_vocab_size, pos_vocab_size, dep_vocab_size):
    super().__init__()
    self.transformer = models.Transformer(cfg)
    # logits_pos
    self.fc1 = nn.Linear(cfg.dim, cfg.dim)
    self.activ1 = models.gelu
    self.norm1 = models.LayerNorm(cfg)
    self.decoder1 = nn.Linear(cfg.dim, pos_vocab_size)
    # logits_vocab
    self.fc2 = nn.Linear(cfg.dim, cfg.dim)
    self.activ2 = models.gelu
    self.norm2 = models.LayerNorm(cfg)
    self.decoder2 = nn.Linear(cfg.dim, dep_vocab_size)
    # logits_word_vocab_size
    self.fc3 = nn.Linear(cfg.dim, cfg.dim)
    self.activ3 = models.gelu
    self.norm3 = models.LayerNorm(cfg)
    embed_weight = self.transformer.embed.tok_embed.weight
    n_vocab, n_dim = embed_weight.size()
    self.decoder3 = nn.Linear(n_dim, n_vocab, bias=False)
    self.decoder3.weight = embed_weight
    self.decoder3_bias = nn.Parameter(torch.zeros(n_vocab))
def __init__(self, cfg):
    super().__init__()
    self.transformer = models.Transformer(cfg)
    self.fc = nn.Linear(cfg.hidden, cfg.hidden)
    self.activ1 = nn.Tanh()
    self.linear = nn.Linear(cfg.hidden, cfg.hidden)
    self.activ2 = models.gelu
    self.norm = models.LayerNorm(cfg)
    self.classifier = nn.Linear(cfg.hidden, 2)
    # decoder is shared with embedding layer
    ## project hidden layer to embedding layer
    embed_weight2 = self.transformer.embed.tok_embed2.weight
    n_hidden, n_embedding = embed_weight2.size()
    self.decoder1 = nn.Linear(n_hidden, n_embedding, bias=False)
    self.decoder1.weight.data = embed_weight2.data.t()
    ## project embedding layer to vocabulary layer
    embed_weight1 = self.transformer.embed.tok_embed1.weight
    n_vocab, n_embedding = embed_weight1.size()
    self.decoder2 = nn.Linear(n_embedding, n_vocab, bias=False)
    self.decoder2.weight = embed_weight1
    # self.tok_embed1 = nn.Embedding(cfg.vocab_size, cfg.embedding)
    # self.tok_embed2 = nn.Linear(cfg.embedding, cfg.hidden)
    self.decoder_bias = nn.Parameter(torch.zeros(n_vocab))  # output bias for the vocab logits (decoder2 is created with bias=False)
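# The two shared-weight decoders above invert the factorized embedding
# (hidden -> embedding -> vocab). The sketch below shows how a forward pass
# might use them for masked-LM plus sentence classification; it is a minimal
# illustration only, and the `forward` signature (input_ids, segment_ids,
# input_mask, masked_pos) and the transformer call are assumptions, not taken
# from the source.
def forward(self, input_ids, segment_ids, input_mask, masked_pos):
    h = self.transformer(input_ids, segment_ids, input_mask)      # [B, S, hidden] (assumed call signature)
    # sentence-level classification from the first ([CLS]) position
    pooled = self.activ1(self.fc(h[:, 0]))
    logits_clsf = self.classifier(pooled)                          # [B, 2]
    # gather the masked positions and run the LM head
    masked_pos = masked_pos[:, :, None].expand(-1, -1, h.size(-1))
    h_masked = torch.gather(h, 1, masked_pos)                      # [B, n_mask, hidden]
    h_masked = self.norm(self.activ2(self.linear(h_masked)))
    # hidden -> embedding -> vocab, reusing the tied embedding weights
    logits_lm = self.decoder2(self.decoder1(h_masked)) + self.decoder_bias
    return logits_lm, logits_clsf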
def predict_step(inputs, params, cache, eos_id, max_decode_len, config, beam_size=4):
    """Predict translation with fast decoding beam search on a batch."""
    batch_size = inputs.shape[0]
    # Prepare transformer fast-decoder call for beam search: for beam search, we
    # need to set up our decoder model to handle a batch size equal to
    # batch_size * beam_size, where each batch item's data is expanded in-place
    # rather than tiled.
    # i.e. if we denote each batch element subtensor as el[n]:
    # [el0, el1, el2] --> beamsize=2 --> [el0, el0, el1, el1, el2, el2]
    src_padding_mask = decode.flat_batch_beam_expand((inputs > 0)[..., None], beam_size)
    tgt_padding_mask = decode.flat_batch_beam_expand(
        jnp.ones((batch_size, 1, 1)), beam_size)
    encoded_inputs = decode.flat_batch_beam_expand(
        models.Transformer(config).apply({'params': params},
                                         inputs,
                                         method=models.Transformer.encode),
        beam_size)

    def tokens_ids_to_logits(flat_ids, flat_cache):
        """Token slice to logits from decoder model."""
        # --> [batch * beam, 1, vocab]
        flat_logits, new_vars = models.Transformer(config).apply(
            {
                'params': params,
                'cache': flat_cache
            },
            encoded_inputs,
            src_padding_mask,
            flat_ids,
            tgt_padding_mask=tgt_padding_mask,
            mutable=['cache'],
            method=models.Transformer.decode)
        new_flat_cache = new_vars['cache']
        # Remove singleton sequence-length dimension:
        # [batch * beam, 1, vocab] --> [batch * beam, vocab]
        flat_logits = flat_logits.squeeze(axis=1)
        return flat_logits, new_flat_cache

    # Using the above-defined single-step decoder function, run a
    # beam search over possible sequences given input encoding.
    beam_seqs, _ = decode.beam_search(inputs,
                                      cache,
                                      tokens_ids_to_logits,
                                      beam_size=beam_size,
                                      alpha=0.6,
                                      eos_id=eos_id,
                                      max_decode_len=max_decode_len)

    # Beam search returns [n_batch, n_beam, n_length + 1] with beam dimension
    # sorted in increasing order of log-probability.
    # Return the highest scoring beam sequence, drop first dummy 0 token.
    return beam_seqs[:, -1, 1:]
def eval_step(params, batch, config, label_smoothing=0.0):
    """Calculate evaluation metrics on a batch."""
    inputs, targets = batch["inputs"], batch["targets"]
    weights = jnp.where(targets > 0, 1.0, 0.0)
    logits = models.Transformer(config).apply({"params": params}, inputs, targets)
    return compute_metrics(logits, targets, weights, label_smoothing)
def initialize_cache(inputs, max_decode_len, config):
    """Initialize a cache for a given input shape and max decode length."""
    target_shape = (inputs.shape[0], max_decode_len) + inputs.shape[2:]
    initial_variables = models.Transformer(config).init(
        jax.random.PRNGKey(0),
        jnp.ones(inputs.shape, config.dtype),
        jnp.ones(target_shape, config.dtype))
    return initial_variables['cache']
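# For context, a minimal non-pmapped sketch of how this cache is consumed by the
# beam-search predict_step above (the training loops further down wrap both in
# jax.pmap). `pred_batch`, `params`, `eos_id`, `max_predict_length`, and the
# decode-mode `predict_config` are assumed to come from the surrounding script.
cache = initialize_cache(pred_batch["inputs"],
                         max_decode_len=max_predict_length,
                         config=predict_config)
predicted_ids = predict_step(pred_batch["inputs"], params, cache, eos_id,
                             max_predict_length, predict_config, beam_size=4)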
def __init__(self, cfg, n_labels):
    super().__init__()
    self.transformer = models.Transformer(cfg)
    self.fc = nn.Linear(cfg.dim, cfg.dim)
    self.activ = nn.Tanh()
    self.drop = nn.Dropout(cfg.p_drop_hidden)
    self.classifier = nn.Linear(cfg.dim, n_labels)
def __init__(self, cfg):
    super().__init__()
    self.transformer = models.Transformer(cfg)
    # logits_sentence_clsf
    self.fc = nn.Linear(cfg.dim, cfg.dim)
    self.activ1 = nn.Tanh()
    self.classifier = nn.Linear(cfg.dim, 2)
    # logits_paragraph_clsf
    '''
    self.fc = nn.Linear(cfg.dim, 2)
    self.activ1 = nn.Tanh()
    self.norm1 = models.LayerNorm(cfg)
    self.drop = nn.Dropout(cfg.p_drop_hidden)
    self.classifier = nn.Linear(cfg.max_len * 2, 2)
    '''
    # logits_lm
    self.linear = nn.Linear(cfg.dim, cfg.dim)
    self.activ2 = models.gelu
    self.norm2 = models.LayerNorm(cfg)
    # decoder is shared with embedding layer
    embed_weight = self.transformer.embed.tok_embed.weight
    n_vocab, n_dim = embed_weight.size()
    self.decoder = nn.Linear(n_dim, n_vocab, bias=False)
    self.decoder.weight = embed_weight
    self.decoder_bias = nn.Parameter(torch.zeros(n_vocab))
    # logits_same
    self.linear2 = nn.Linear(cfg.dim, cfg.vocab_size)
def __init__(self, cfg, n_labels):
    super().__init__()
    self.transformer = models.Transformer(cfg)
    self.fc = nn.Linear(cfg.hidden, cfg.hidden)
    self.activ = nn.ReLU()
    self.drop = nn.Dropout(cfg.p_drop_hidden)
    self.pool = nn.AdaptiveMaxPool1d(1)
    self.classifier = nn.Linear(cfg.hidden, n_labels)
def single_translate(english_text):
    """Translate a single sentence, i.e. batch_size == 1."""
    conf = configuration.Config()
    tokenizer = tokenization.FullTokenizer(
        en_vocab_file=os.path.join(conf.file_config.data_path, conf.file_config.en_vocab),
        zh_vocab_file=os.path.join(conf.file_config.data_path, conf.file_config.zh_vocab))
    conf.model_config.src_vocab_size = tokenizer.get_en_vocab_size() + 2
    conf.model_config.trg_vocab_size = tokenizer.get_zh_vocab_size() + 2
    model = models.Transformer(conf)
    model.to(device)

    translate_dataset = datasets.Translate(mode='single_translate', config=conf,
                                           tokenizer=tokenizer, texts=english_text)
    # The encoder input is simply the raw English token ids.
    en_ids, _ = translate_dataset[0]
    en_ids = torch.tensor(en_ids).unsqueeze(dim=0)
    # The decoder starts from the <BOS> token.
    decoder_input = torch.tensor([tokenizer.get_zh_vocab_size()]).view(1, 1)  # [1, 1]

    model.load_state_dict(torch.load(os.path.join(conf.train_config.model_dir,
                                                  conf.train_config.model_name + '.pth'),
                                     map_location=device))
    # checkpoint = torch.load(os.path.join(conf.train_config.model_dir,
    #                                      conf.train_config.model_name + '_epoch_{}.tar'.format(50)),
    #                         map_location=device)
    # model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    for i in range(51):
        if torch.cuda.is_available():
            prediction_logits = model(en_ids.cuda(), decoder_input.cuda())  # [1, i+1, vocab_size]
        else:
            prediction_logits = model(en_ids, decoder_input)
        # Take the distribution of the last position and argmax it to get the next token.
        predictions = prediction_logits[:, -1, :]  # [batch_size, vocab_size]
        predictions = F.softmax(predictions, dim=-1)
        predict_zh_ids = torch.argmax(predictions, dim=-1)  # [batch_size]
        # If the prediction is <EOS>, stop.
        if predict_zh_ids.data == tokenizer.get_zh_vocab_size() + 1:
            break
        # Otherwise, append the prediction to the decoder input and keep looping.
        else:
            if torch.cuda.is_available():
                decoder_input = torch.cat([decoder_input.cuda(), predict_zh_ids.view(1, 1)], dim=1)
            else:
                decoder_input = torch.cat([decoder_input, predict_zh_ids.view(1, 1)], dim=1)

    # Convert the generated Chinese ids back to Chinese text.
    translated_text = tokenizer.convert_zh_ids_to_text(
        list(decoder_input.cpu().detach().numpy()[0])[1:51])
    print('Source:', english_text)
    print('Translation:', translated_text)
    return translated_text
def main():
    conf = configuration.Config()
    tokenizer = tokenization.FullTokenizer(
        en_vocab_file=os.path.join(conf.file_config.data_path, conf.file_config.en_vocab),
        zh_vocab_file=os.path.join(conf.file_config.data_path, conf.file_config.zh_vocab))
    logging.info('Using Device: {}'.format(device))

    conf.model_config.src_vocab_size = tokenizer.get_en_vocab_size() + 2
    conf.model_config.trg_vocab_size = tokenizer.get_zh_vocab_size() + 2
    model = models.Transformer(conf)
    model = model.to(device)

    if args.train:
        train_dataset = datasets.Translate(mode='train', config=conf, tokenizer=tokenizer,
                                           auto_padding=conf.train_config.auto_padding,
                                           do_filter=False)
        logging.info("***** Running training *****")
        logging.info("  Num examples = %d", len(train_dataset))
        logging.info("  Total training steps: {}".format(train_dataset.num_steps))
        train_dataloader = DataLoader(train_dataset,
                                      batch_size=conf.train_config.train_batch_size,
                                      shuffle=True,
                                      collate_fn=collate_fn)
        run(config=conf, dataloader=train_dataloader, model=model, mode='train',
            start_epoch=0, total_steps=train_dataset.num_steps)

    if args.eval:
        eval_dataset = datasets.Translate(mode='eval', config=conf, tokenizer=tokenizer,
                                          auto_padding=conf.train_config.auto_padding,
                                          do_filter=False)
        logging.info("***** Running validating *****")
        logging.info("  Num examples = %d", len(eval_dataset))
        logging.info("  Total validating steps: {}".format(eval_dataset.num_steps))
        eval_dataloader = DataLoader(eval_dataset,
                                     batch_size=conf.train_config.eval_batch_size,
                                     collate_fn=collate_fn)
        run(config=conf, dataloader=eval_dataloader, model=model, mode='eval',
            start_epoch=0, total_steps=eval_dataset.num_steps)
def __init__(self, cfg):
    super().__init__()
    self.transformer = models.Transformer(cfg)
    self.fc = nn.Linear(cfg.hidden, cfg.hidden)
    self.activ1 = nn.Tanh()
    self.linear = nn.Linear(cfg.hidden, cfg.hidden)
    self.activ2 = models.gelu
    self.norm = models.LayerNorm(cfg)
    self.classifier = nn.Linear(cfg.hidden, 2)
    # decoder is shared with embedding layer
    ## project hidden layer to embedding layer
    self.discriminator = nn.Linear(cfg.hidden, 1, bias=False)
def __init__(self, cfg, n_labels, local_pretrained=False):
    super().__init__()
    if local_pretrained:
        self.transformer = models.Transformer(cfg)
    else:
        self.transformer = BertModel.from_pretrained('bert-base-uncased')
    self.fc = nn.Linear(cfg.dim, cfg.dim)
    self.activ = nn.Tanh()
    self.drop = nn.Dropout(cfg.p_drop_hidden)
    # self.classifier = nn.Linear(cfg.dim, n_labels)
    self.local_pretrained = local_pretrained
def __init__(self, cfg):
    super().__init__()
    self.transformer = models.Transformer(cfg)
    # logits_word_vocab_size
    self.fc3 = nn.Linear(cfg.dim, cfg.dim)
    self.activ3 = models.gelu
    self.norm3 = models.LayerNorm(cfg)
    embed_weight3 = self.transformer.embed.tok_embed.weight
    n_vocab3, n_dim3 = embed_weight3.size()
    self.decoder3 = nn.Linear(n_dim3, n_vocab3, bias=False)
    self.decoder3.weight = embed_weight3
    self.decoder3_bias = nn.Parameter(torch.zeros(n_vocab3))
def __init__(self, cfg):
    super().__init__()
    self.transformer = models.Transformer(cfg)
    self.fc = nn.Linear(cfg.dim, cfg.dim)
    self.activ1 = nn.Tanh()
    self.linear = nn.Linear(cfg.dim, cfg.dim)
    self.activ2 = models.gelu
    self.norm = models.LayerNorm(cfg)
    self.classifier = nn.Linear(cfg.dim, 2)
    # decoder is shared with embedding layer
    embed_weight = self.transformer.embed.tok_embed.weight
    n_vocab, n_dim = embed_weight.size()
    self.decoder = nn.Linear(n_dim, n_vocab, bias=False)
    self.decoder.weight = embed_weight
    self.decoder_bias = nn.Parameter(torch.zeros(n_vocab))
def loss_fn(params):
    """loss function used for training."""
    logits = models.Transformer(config).apply(
        {'params': params},
        inputs,
        targets,
        inputs_positions=inputs_positions,
        targets_positions=targets_positions,
        inputs_segmentation=inputs_segmentation,
        targets_segmentation=targets_segmentation,
        rngs={'dropout': dropout_rng})

    loss, weight_sum = compute_weighted_cross_entropy(logits, targets, weights,
                                                      label_smoothing)
    mean_loss = loss / weight_sum
    return mean_loss, logits
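# Hedged sketch of the enclosing train_step pattern that consumes loss_fn
# (flax.optim style, matching the p_train_step / pmap usage in the loops below);
# `learning_rate_fn` and `optimizer` are assumed to come from the surrounding code.
lr = learning_rate_fn(optimizer.state.step)
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
(_, logits), grads = grad_fn(optimizer.target)
grads = jax.lax.pmean(grads, axis_name="batch")  # average gradients across devices
new_optimizer = optimizer.apply_gradient(grads, learning_rate=lr)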
def tokens_ids_to_logits(flat_ids, flat_cache):
    """Token slice to logits from decoder model."""
    # --> [batch * beam, 1, vocab]
    flat_logits, new_vars = models.Transformer(config).apply(
        {
            'params': params,
            'cache': flat_cache
        },
        encoded_inputs,
        raw_inputs,  # only needed for input padding mask
        flat_ids,
        mutable=['cache'],
        method=models.Transformer.decode)
    new_flat_cache = new_vars['cache']
    # Remove singleton sequence-length dimension:
    # [batch * beam, 1, vocab] --> [batch * beam, vocab]
    flat_logits = flat_logits.squeeze(axis=1)
    return flat_logits, new_flat_cache
def __init__(self, cfg, n_labels):
    super().__init__()
    self.transformer = models.Transformer(cfg)
    self.fc = nn.Linear(cfg.dim, cfg.dim)
    self.activ = nn.Tanh()
    self.drop = nn.Dropout(cfg.p_drop_hidden)
    self.classifier = nn.Linear(cfg.dim, n_labels)
    '''
    self.n_labels = n_labels
    self.transformer = models.Transformer(cfg)
    self.fc = nn.Linear(cfg.dim, n_labels)
    self.norm = nn.BatchNorm1d(cfg.dim)
    self.activ = nn.Sigmoid()
    self.drop = nn.Dropout(cfg.p_drop_hidden)
    self.classifier = nn.Linear(cfg.max_len * n_labels, n_labels)
    '''
def build_model(model_type, model_part, learned, seq_len, feature_num, kmer, clayer_num,
                filters, layer_sizes, embedding_dim, activation, output_activation,
                transfer, transfer_dim, dropout=0.1):
    if model_type == "ConvModel":
        convlayers = [{"kernel_size": kmer, "filters": filters, "activation": "ReLU"}
                      for _ in range(clayer_num)]
        model = models.ConvModel(model_part, seq_len, feature_num, convlayers, layer_sizes,
                                 learned, embedding_dim, activation, output_activation,
                                 transfer, transfer_dim, dropout, posembed=False)
    elif model_type == "SpannyConvModel":
        convlayers = [{"kernel_size": kmer, "filters": filters, "activation": "ReLU"}
                      for _ in range(clayer_num)]
        global_kernel = {"kernel_size": kmer, "filters": filters, "activation": "ReLU"}
        model = models.SpannyConvModel(model_part, seq_len, feature_num, global_kernel,
                                       convlayers, layer_sizes, learned, embedding_dim,
                                       activation, output_activation, transfer, transfer_dim,
                                       dropout, posembed=False)
    elif model_type == "MHCflurry":
        locally_connected_layers = [{"kernel_size": kmer, "filters": filters, "activation": "Tanh"}
                                    for _ in range(clayer_num)]
        model = models.MHCflurry(model_part, seq_len, feature_num, locally_connected_layers,
                                 layer_sizes, learned, embedding_dim, activation,
                                 output_activation, transfer, transfer_dim, dropout)
    elif model_type == "Transformer":
        # d_model :
        model = models.Transformer(model_part, seq_len, feature_num, feature_num * 2, filters,
                                   int(feature_num / filters), int(feature_num / filters),
                                   layer_sizes, learned, embedding_dim, activation,
                                   output_activation, transfer, transfer_dim)
    else:
        raise ValueError("Unsupported model type : " + model_type)

    if torch.cuda.is_available():
        model.cuda()
    return model
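# Purely illustrative call for the "Transformer" branch of build_model. Only the
# parameter names come from the signature above; every value (and the semantics
# of model_part/learned/transfer) is an assumption for the example.
model = build_model(model_type="Transformer",
                    model_part="full",          # assumed value
                    learned=True,               # assumed value
                    seq_len=15,
                    feature_num=64,
                    kmer=3,
                    clayer_num=2,
                    filters=8,                  # feature_num should be divisible by filters here
                    layer_sizes=[256, 64],
                    embedding_dim=32,
                    activation="ReLU",
                    output_activation="Sigmoid",
                    transfer=False,
                    transfer_dim=0)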
def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str):
    """Runs a training and evaluation loop.

    Args:
      config: Configuration to use.
      workdir: Working directory for checkpoints and TF summaries. If this
        contains a checkpoint, training will be resumed from the latest checkpoint.
    """
    tf.io.gfile.makedirs(workdir)

    # Number of local devices for this host.
    n_devices = jax.local_device_count()

    if jax.host_id() == 0:
        tf.io.gfile.makedirs(workdir)
        train_summary_writer = tensorboard.SummaryWriter(
            os.path.join(workdir, "train"))
        eval_summary_writer = tensorboard.SummaryWriter(
            os.path.join(workdir, "eval"))

    if config.batch_size % n_devices:
        raise ValueError("Batch size must be divisible by the number of devices")

    vocab_path = config.vocab_path
    if vocab_path is None:
        vocab_path = os.path.join(workdir, "sentencepiece_model")
    tf.io.gfile.makedirs(os.path.split(vocab_path)[0])

    # Load Dataset
    # ---------------------------------------------------------------------------
    logging.info("Initializing dataset.")
    train_ds, eval_ds, predict_ds, encoder = input_pipeline.get_wmt_datasets(
        n_devices=n_devices,
        dataset_name=config.dataset_name,
        eval_dataset_name=config.eval_dataset_name,
        shard_idx=jax.host_id(),
        shard_count=jax.host_count(),
        vocab_path=vocab_path,
        target_vocab_size=config.vocab_size,
        batch_size=config.batch_size,
        max_corpus_chars=config.max_corpus_chars,
        max_length=config.max_target_length,
        max_eval_length=config.max_eval_target_length)
    train_iter = iter(train_ds)
    vocab_size = int(encoder.vocab_size())
    eos_id = decode.EOS_ID  # Default Sentencepiece EOS token.

    def decode_tokens(toks):
        valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32)
        return encoder.detokenize(valid_toks).numpy().decode("utf-8")

    if config.num_predict_steps > 0:
        predict_ds = predict_ds.take(config.num_predict_steps)

    logging.info("Initializing model, optimizer, and step functions.")

    # Build Model and Optimizer
    # ---------------------------------------------------------------------------
    train_config = models.TransformerConfig(
        vocab_size=vocab_size,
        output_vocab_size=vocab_size,
        share_embeddings=config.share_embeddings,
        logits_via_embedding=config.logits_via_embedding,
        dtype=jnp.bfloat16 if config.use_bfloat16 else jnp.float32,
        emb_dim=config.emb_dim,
        num_heads=config.num_heads,
        num_layers=config.num_layers,
        qkv_dim=config.qkv_dim,
        mlp_dim=config.mlp_dim,
        max_len=max(config.max_target_length, config.max_eval_target_length),
        dropout_rate=config.dropout_rate,
        attention_dropout_rate=config.attention_dropout_rate,
        deterministic=False,
        decode=False,
        kernel_init=nn.initializers.xavier_uniform(),
        bias_init=nn.initializers.normal(stddev=1e-6))
    eval_config = train_config.replace(deterministic=True)
    predict_config = train_config.replace(deterministic=True, decode=True)

    start_step = 0
    rng = random.PRNGKey(config.seed)
    rng, init_rng = random.split(rng)
    input_shape = (config.batch_size, config.max_target_length)
    target_shape = (config.batch_size, config.max_target_length)

    m = models.Transformer(eval_config)
    initial_variables = jax.jit(m.init)(init_rng,
                                        jnp.ones(input_shape, jnp.float32),
                                        jnp.ones(target_shape, jnp.float32))

    # apply an optimizer to this tree
    optimizer_def = optim.Adam(
        config.learning_rate,
        beta1=0.9,
        beta2=0.98,
        eps=1e-9,
        weight_decay=config.weight_decay)
    optimizer = optimizer_def.create(initial_variables["params"])
    # We access model params only from optimizer below via optimizer.target.
    del initial_variables

    if config.restore_checkpoints:
        # Restore unreplicated optimizer + model state from last checkpoint.
        optimizer = checkpoints.restore_checkpoint(workdir, optimizer)
        # Grab last step.
        start_step = int(optimizer.state.step)

    # Replicate optimizer.
    optimizer = jax_utils.replicate(optimizer)

    learning_rate_fn = create_learning_rate_scheduler(
        base_learning_rate=config.learning_rate, warmup_steps=config.warmup_steps)

    # compile multidevice versions of train/eval/predict step and cache init fn.
    p_train_step = jax.pmap(
        functools.partial(
            train_step,
            config=train_config,
            learning_rate_fn=learning_rate_fn,
            label_smoothing=config.label_smoothing),
        axis_name="batch",
        donate_argnums=(0,))  # pytype: disable=wrong-arg-types
    p_eval_step = jax.pmap(
        functools.partial(
            eval_step,
            config=eval_config,
            label_smoothing=config.label_smoothing),
        axis_name="batch")
    p_init_cache = jax.pmap(
        functools.partial(
            initialize_cache,
            max_decode_len=config.max_predict_length,
            config=predict_config),
        axis_name="batch")
    p_pred_step = jax.pmap(
        functools.partial(
            predict_step, config=predict_config, beam_size=config.beam_size),
        axis_name="batch",
        static_broadcasted_argnums=(3, 4))  # eos token, max_length are constant

    # Main Train Loop
    # ---------------------------------------------------------------------------

    # We init the first set of dropout PRNG keys, but update it afterwards inside
    # the main pmap'd training update for performance.
    dropout_rngs = random.split(rng, n_devices)

    logging.info("Starting training loop.")
    metrics_all = []
    t_loop_start = time.time()
    for step, batch in zip(range(start_step, config.num_train_steps), train_iter):
        # Shard data to devices and do a training step.
        batch = common_utils.shard(jax.tree_map(lambda x: x._numpy(), batch))  # pylint: disable=protected-access
        optimizer, metrics, dropout_rngs = p_train_step(
            optimizer, batch, dropout_rng=dropout_rngs)
        metrics_all.append(metrics)

        # Quick indication that training is happening.
        logging.log_first_n(logging.INFO, "Finished training step %d.", 5, step)

        # Save a checkpoint on one host after every checkpoint_freq steps.
        if (config.save_checkpoints and step % config.checkpoint_freq == 0 and
                step > 0 and jax.host_id() == 0):
            checkpoints.save_checkpoint(workdir, jax_utils.unreplicate(optimizer),
                                        step)

        # Periodic metric handling.
        if step % config.eval_frequency != 0 and step > 0:
            continue

        # Training Metrics
        logging.info("Gathering training metrics.")
        metrics_all = common_utils.get_metrics(metrics_all)
        lr = metrics_all.pop("learning_rate").mean()
        metrics_sums = jax.tree_map(jnp.sum, metrics_all)
        denominator = metrics_sums.pop("denominator")
        summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
        summary["learning_rate"] = lr
        steps_per_eval = config.eval_frequency if step != 0 else 1
        steps_per_sec = steps_per_eval / (time.time() - t_loop_start)
        t_loop_start = time.time()
        if jax.host_id() == 0:
            train_summary_writer.scalar("steps per second", steps_per_sec, step)
            for key, val in summary.items():
                train_summary_writer.scalar(key, val, step)
            train_summary_writer.flush()
        metrics_all = []
        logging.info("train in step: %d, loss: %.4f", step, summary["loss"])

        # Eval Metrics
        logging.info("Gathering evaluation metrics.")
        t_eval_start = time.time()
        eval_metrics = []
        eval_iter = iter(eval_ds)
        for _, eval_batch in zip(range(config.num_eval_steps), eval_iter):
            eval_batch = jax.tree_map(lambda x: x._numpy(), eval_batch)  # pylint: disable=protected-access
            eval_batch = common_utils.shard(eval_batch)
            metrics = p_eval_step(optimizer.target, eval_batch)
            eval_metrics.append(metrics)
        eval_metrics = common_utils.get_metrics(eval_metrics)
        eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
        eval_denominator = eval_metrics_sums.pop("denominator")
        eval_summary = jax.tree_map(
            lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
            eval_metrics_sums)
        if jax.host_id() == 0:
            for key, val in eval_summary.items():
                eval_summary_writer.scalar(key, val, step)
            eval_summary_writer.flush()
        logging.info("eval in step: %d, loss: %.4f", step, eval_summary["loss"])
        logging.info("eval time: %.4f s step %d", time.time() - t_eval_start, step)

        # Translation and BLEU Score.
        logging.info("Translating evaluation dataset.")
        t_inference_start = time.time()
        sources, references, predictions = [], [], []
        for pred_batch in predict_ds:
            pred_batch = jax.tree_map(lambda x: x._numpy(), pred_batch)  # pylint: disable=protected-access
            # Handle final odd-sized batch by padding instead of dropping it.
            cur_pred_batch_size = pred_batch["inputs"].shape[0]
            if cur_pred_batch_size % n_devices:
                padded_size = int(
                    np.ceil(cur_pred_batch_size / n_devices) * n_devices)
                pred_batch = jax.tree_map(
                    lambda x: pad_examples(x, padded_size), pred_batch)  # pylint: disable=cell-var-from-loop
            pred_batch = common_utils.shard(pred_batch)
            cache = p_init_cache(pred_batch["inputs"])
            predicted = p_pred_step(pred_batch["inputs"], optimizer.target, cache,
                                    eos_id, config.max_predict_length)
            predicted = tohost(predicted)
            inputs = tohost(pred_batch["inputs"])
            targets = tohost(pred_batch["targets"])
            # Iterate through non-padding examples of batch.
            for i, s in enumerate(predicted[:cur_pred_batch_size]):
                sources.append(decode_tokens(inputs[i]))
                references.append(decode_tokens(targets[i]))
                predictions.append(decode_tokens(s))
        logging.info("Translation: %d predictions %d references %d sources.",
                     len(predictions), len(references), len(sources))
        logging.info("Translation time: %.4f s step %d.",
                     time.time() - t_inference_start, step)

        # Calculate BLEU score for translated eval corpus against reference.
        bleu_matches = bleu.bleu_partial(references, predictions)
        all_bleu_matches = per_host_sum_pmap(bleu_matches)
        bleu_score = bleu.complete_bleu(*all_bleu_matches)

        # Save translation samples for tensorboard.
        exemplars = ""
        for n in np.random.choice(np.arange(len(predictions)), 8):
            exemplars += f"{sources[n]}\n\n{references[n]}\n\n{predictions[n]}\n\n"
        if jax.host_id() == 0:
            eval_summary_writer.scalar("bleu", bleu_score, step)
            eval_summary_writer.text("samples", exemplars, step)
            eval_summary_writer.flush()
        logging.info("Translation BLEU Score %.4f", bleu_score)
def initialize_variables(rng):
    return models.Transformer(eval_config).init(
        rng,
        jnp.ones(input_shape, jnp.float32),
        jnp.ones(target_shape, jnp.float32))
def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str):
    """Runs a training and evaluation loop.

    Args:
      config: Configuration to use.
      workdir: Working directory for checkpoints and TF summaries. If this
        contains a checkpoint, training will be resumed from the latest checkpoint.
    """
    tf.io.gfile.makedirs(workdir)

    vocab_path = config.vocab_path
    if vocab_path is None:
        vocab_path = os.path.join(workdir, "sentencepiece_model")
        config.vocab_path = vocab_path
    tf.io.gfile.makedirs(os.path.split(vocab_path)[0])

    # Load Dataset
    # ---------------------------------------------------------------------------
    logging.info("Initializing dataset.")
    train_ds, eval_ds, predict_ds, encoder = input_pipeline.get_wmt_datasets(
        n_devices=jax.local_device_count(),
        config=config,
        reverse_translation=config.reverse_translation,
        vocab_path=vocab_path)

    train_iter = iter(train_ds)
    vocab_size = int(encoder.vocab_size())
    eos_id = decode.EOS_ID  # Default Sentencepiece EOS token.

    def decode_tokens(toks):
        valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32)
        return encoder.detokenize(valid_toks).numpy().decode("utf-8")

    if config.num_predict_steps > 0:
        predict_ds = predict_ds.take(config.num_predict_steps)

    logging.info("Initializing model, optimizer, and step functions.")

    # Build Model and Optimizer
    # ---------------------------------------------------------------------------
    train_config = models.TransformerConfig(
        vocab_size=vocab_size,
        output_vocab_size=vocab_size,
        share_embeddings=config.share_embeddings,
        logits_via_embedding=config.logits_via_embedding,
        dtype=jnp.bfloat16 if config.use_bfloat16 else jnp.float32,
        emb_dim=config.emb_dim,
        num_heads=config.num_heads,
        num_layers=config.num_layers,
        qkv_dim=config.qkv_dim,
        mlp_dim=config.mlp_dim,
        max_len=max(config.max_target_length, config.max_eval_target_length),
        dropout_rate=config.dropout_rate,
        attention_dropout_rate=config.attention_dropout_rate,
        deterministic=False,
        decode=False,
        kernel_init=nn.initializers.xavier_uniform(),
        bias_init=nn.initializers.normal(stddev=1e-6))
    eval_config = train_config.replace(deterministic=True)
    predict_config = train_config.replace(deterministic=True, decode=True)

    start_step = 0
    rng = jax.random.PRNGKey(config.seed)
    rng, init_rng = jax.random.split(rng)
    input_shape = (config.per_device_batch_size, config.max_target_length)
    target_shape = (config.per_device_batch_size, config.max_target_length)

    m = models.Transformer(eval_config)
    initial_variables = jax.jit(m.init)(init_rng,
                                        jnp.ones(input_shape, jnp.float32),
                                        jnp.ones(target_shape, jnp.float32))

    # apply an optimizer to this tree
    optimizer_def = optim.Adam(
        config.learning_rate,
        beta1=0.9,
        beta2=0.98,
        eps=1e-9,
        weight_decay=config.weight_decay)
    optimizer = optimizer_def.create(initial_variables["params"])
    # We access model params only from optimizer below via optimizer.target.
    del initial_variables

    if config.restore_checkpoints:
        # Restore unreplicated optimizer + model state from last checkpoint.
        optimizer = checkpoints.restore_checkpoint(workdir, optimizer)
        # Grab last step.
        start_step = int(optimizer.state.step)

    writer = metric_writers.create_default_writer(
        workdir, just_logging=jax.host_id() > 0)
    if start_step == 0:
        writer.write_hparams(dict(config))

    # Replicate optimizer.
    optimizer = jax_utils.replicate(optimizer)

    learning_rate_fn = create_learning_rate_scheduler(
        base_learning_rate=config.learning_rate, warmup_steps=config.warmup_steps)

    # compile multidevice versions of train/eval/predict step and cache init fn.
    p_train_step = jax.pmap(
        functools.partial(
            train_step,
            config=train_config,
            learning_rate_fn=learning_rate_fn,
            label_smoothing=config.label_smoothing),
        axis_name="batch",
        donate_argnums=(0,))  # pytype: disable=wrong-arg-types
    p_eval_step = jax.pmap(
        functools.partial(
            eval_step, config=eval_config),
        axis_name="batch")
    p_init_cache = jax.pmap(
        functools.partial(
            initialize_cache,
            max_decode_len=config.max_predict_length,
            config=predict_config),
        axis_name="batch")
    p_pred_step = jax.pmap(
        functools.partial(
            predict_step, config=predict_config, beam_size=config.beam_size),
        axis_name="batch",
        static_broadcasted_argnums=(3, 4))  # eos token, max_length are constant

    # Main Train Loop
    # ---------------------------------------------------------------------------

    # We init the first set of dropout PRNG keys, but update it afterwards inside
    # the main pmap'd training update for performance.
    dropout_rngs = jax.random.split(rng, jax.local_device_count())
    del rng

    logging.info("Starting training loop.")
    hooks = []
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=config.num_train_steps, writer=writer)
    if jax.host_id() == 0:
        hooks += [
            report_progress,
            periodic_actions.Profile(logdir=workdir, num_profile_steps=5)
        ]
    train_metrics = []
    with metric_writers.ensure_flushes(writer):
        for step in range(start_step, config.num_train_steps):
            is_last_step = step == config.num_train_steps - 1

            # Shard data to devices and do a training step.
            with jax.profiler.StepTraceAnnotation("train", step_num=step):
                batch = common_utils.shard(jax.tree_map(np.asarray, next(train_iter)))
                optimizer, metrics = p_train_step(
                    optimizer, batch, dropout_rng=dropout_rngs)
                train_metrics.append(metrics)

            # Quick indication that training is happening.
            logging.log_first_n(logging.INFO, "Finished training step %d.", 5, step)
            for h in hooks:
                h(step)

            # Periodic metric handling.
            if step % config.eval_every_steps == 0 or is_last_step:
                with report_progress.timed("training_metrics"):
                    logging.info("Gathering training metrics.")
                    train_metrics = common_utils.get_metrics(train_metrics)
                    lr = train_metrics.pop("learning_rate").mean()
                    metrics_sums = jax.tree_map(jnp.sum, train_metrics)
                    denominator = metrics_sums.pop("denominator")
                    summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
                    summary["learning_rate"] = lr
                    summary = {"train_" + k: v for k, v in summary.items()}
                    writer.write_scalars(step, summary)
                    train_metrics = []

                with report_progress.timed("eval"):
                    eval_results = evaluate(
                        p_eval_step=p_eval_step,
                        target=optimizer.target,
                        eval_ds=eval_ds,
                        num_eval_steps=config.num_eval_steps)
                    writer.write_scalars(
                        step, {"eval_" + k: v for k, v in eval_results.items()})

                with report_progress.timed("translate_and_bleu"):
                    exemplars, bleu_score = translate_and_calculate_bleu(
                        p_pred_step=p_pred_step,
                        p_init_cache=p_init_cache,
                        target=optimizer.target,
                        predict_ds=predict_ds,
                        decode_tokens=decode_tokens,
                        max_predict_length=config.max_predict_length)
                    writer.write_scalars(step, {"bleu": bleu_score})
                    writer.write_texts(step, {"samples": exemplars})

            # Save a checkpoint on one host after every checkpoint_freq steps.
            save_checkpoint = (step % config.checkpoint_every_steps == 0 or
                               is_last_step)
            if config.save_checkpoints and save_checkpoint and jax.host_id() == 0:
                with report_progress.timed("checkpoint"):
                    checkpoints.save_checkpoint(workdir, jax_utils.unreplicate(optimizer),
                                                step)
if args.dataset == 'speechcoco':
    audio_root_path_train = os.path.join(args.data_dir, 'train2014/wav/')
    image_root_path_train = os.path.join(args.data_dir, 'train2014/imgs/')
    segment_file_train = os.path.join(args.data_dir, 'train2014/mscoco_train_word_segments.txt')
    bbox_file_train = os.path.join(args.data_dir, 'train2014/mscoco_train_rcnn_feature.npz')
    audio_root_path_test = os.path.join(args.data_dir, 'val2014/wav/')
    image_root_path_test = os.path.join(args.data_dir, 'val2014/imgs/')
    segment_file_test = os.path.join(args.data_dir, 'val2014/mscoco_val_word_segments.txt')
    bbox_file_test = os.path.join(args.data_dir, 'val2014/mscoco_val_rcnn_feature.npz')
    split_file = os.path.join(args.data_dir, 'val2014/mscoco_val_split.txt')

    if args.audio_model == 'davenet':
        audio_model = nn.DataParallel(models.Davenet(embedding_dim=1024))
    elif args.audio_model == 'transformer':
        audio_model = nn.DataParallel(models.Transformer(embedding_dim=1024))

    train_set = OnlineImageAudioCaptionDataset(audio_root_path_train,
                                               image_root_path_train,
                                               segment_file_train,
                                               bbox_file_train,
                                               configs={'return_boundary': True})
    test_set = OnlineImageAudioCaptionDataset(audio_root_path_test,
                                              image_root_path_test,
                                              segment_file_test,
                                              bbox_file_test,
                                              keep_index_file=split_file,
                                              configs={'return_boundary': True})
elif args.dataset == 'mscoco':
    segment_file_train = os.path.join(args.data_dir, 'train2014/mscoco_train_word_phone_segments.txt')
    segment_file_test = os.path.join(args.data_dir, 'val2014/mscoco_val_word_phone_segments.txt')
    # TODO
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=1,
    pin_memory=True)

# Initialize the image and audio models
if args.audio_model == 'tdnn':
    if args.supervision_level == 'text':
        audio_model = models.SmallDavenetEncoder(audio_encoder_configs)
    else:  # TODO
        audio_model = models.DavenetEncoder(audio_encoder_configs)
elif args.audio_model == 'lstm':
    audio_model = models.BLSTMEncoder(audio_encoder_configs)
elif args.audio_model == 'transformer':  # TODO
    pretrained_model_file = '/ws/ifp-53_1/hasegawa/tools/espnet/egs/discophone/ifp_lwang114/dump/mscoco/eval/deltafalse/split1utt/data_encoder.pth'
    audio_model = models.Transformer(
        n_class=49,
        pretrained_model_file=pretrained_model_file)

if args.image_model == 'res34':
    image_model = models.ResnetEncoder(image_encoder_configs)
elif args.image_model == 'linear' or args.image_model == 'rcnn':
    image_model = models.LinearEncoder(image_encoder_configs)

if args.supervision_level == 'audio':
    audio_segment_model = models.NoopSegmenter(audio_segmenter_configs)
else:
    audio_segment_model = models.FixedTextSegmenter(audio_segmenter_configs)
image_segment_model = models.NoopSegmenter(image_segmenter_configs)

if args.alignment_model == 'mixture_aligner':
    if args.translate_direction == 'sp2im':
if not os.path.exists(args.exp_dir):
    os.makedirs("%s/models" % args.exp_dir)

if args.precompute_acoustic_feature:
    audio_model = models.NoOpEncoder(embedding_dim=1000)
    image_model = models.LinearTrans(input_dim=2048, embedding_dim=1000)
    attention_model = models.DotProductAttention(in_size=1000)
    if not args.only_eval:
        train_attention(audio_model, image_model, attention_model, train_loader, val_loader, args)
    else:
        evaluation_attention(audio_model, image_model, attention_model, val_loader, args)
else:
    audio_model = models.Transformer(embedding_dim=1024)
    image_model = models.LinearTrans(input_dim=2048, embedding_dim=1024)
    attention_model = models.DotProductAttention(in_size=1024)
    if not args.only_eval:
        train_attention(audio_model, image_model, attention_model, train_loader, val_loader, args)
    else:
        loader_for_alignment = torch.utils.data.DataLoader(
            dataloaders.OnlineImageAudioCaptionDataset(
                audio_root_path_test,
                image_root_path_test,
                segment_file_test,
                bbox_file_test,
                keep_index_file=split_file,
                configs={'return_boundary': True}),
            batch_size=args.batch_size,
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    # Make sure tf does not allocate gpu memory.
    tf.config.experimental.set_visible_devices([], 'GPU')

    batch_size = FLAGS.batch_size
    learning_rate = FLAGS.learning_rate
    num_train_steps = FLAGS.num_train_steps
    eval_freq = FLAGS.eval_frequency
    random_seed = FLAGS.random_seed

    if not FLAGS.dev:
        raise app.UsageError('Please provide path to dev set.')
    if not FLAGS.train:
        raise app.UsageError('Please provide path to training set.')
    if batch_size % jax.device_count() > 0:
        raise ValueError('Batch size must be divisible by the number of devices')
    device_batch_size = batch_size // jax.device_count()

    if jax.process_index() == 0:
        train_summary_writer = tensorboard.SummaryWriter(
            os.path.join(FLAGS.model_dir, FLAGS.experiment + '_train'))
        eval_summary_writer = tensorboard.SummaryWriter(
            os.path.join(FLAGS.model_dir, FLAGS.experiment + '_eval'))

    # create the training and development dataset
    vocabs = input_pipeline.create_vocabs(FLAGS.train)
    config = models.TransformerConfig(
        vocab_size=len(vocabs['forms']),
        output_vocab_size=len(vocabs['xpos']),
        max_len=FLAGS.max_length)

    attributes_input = [input_pipeline.CoNLLAttributes.FORM]
    attributes_target = [input_pipeline.CoNLLAttributes.XPOS]
    train_ds = input_pipeline.sentence_dataset_dict(
        FLAGS.train,
        vocabs,
        attributes_input,
        attributes_target,
        batch_size=batch_size,
        bucket_size=config.max_len)
    train_iter = iter(train_ds)

    eval_ds = input_pipeline.sentence_dataset_dict(
        FLAGS.dev,
        vocabs,
        attributes_input,
        attributes_target,
        batch_size=batch_size,
        bucket_size=config.max_len,
        repeat=1)

    model = models.Transformer(config)

    rng = random.PRNGKey(random_seed)
    rng, init_rng = random.split(rng)

    # call a jitted initialization function to get the initial parameter tree
    @jax.jit
    def initialize_variables(init_rng):
        init_batch = jnp.ones((config.max_len, 1), jnp.float32)
        init_variables = model.init(init_rng, inputs=init_batch, train=False)
        return init_variables

    init_variables = initialize_variables(init_rng)

    optimizer_def = optim.Adam(learning_rate, beta1=0.9, beta2=0.98, eps=1e-9,
                               weight_decay=1e-1)
    optimizer = optimizer_def.create(init_variables['params'])
    optimizer = jax_utils.replicate(optimizer)

    learning_rate_fn = create_learning_rate_scheduler(
        base_learning_rate=learning_rate)

    p_train_step = jax.pmap(
        functools.partial(train_step, model=model, learning_rate_fn=learning_rate_fn),
        axis_name='batch')

    def eval_step(params, batch):
        """Calculate evaluation metrics on a batch."""
        inputs, targets = batch['inputs'], batch['targets']
        weights = jnp.where(targets > 0, 1.0, 0.0)
        logits = model.apply({'params': params}, inputs=inputs, train=False)
        return compute_metrics(logits, targets, weights)

    p_eval_step = jax.pmap(eval_step, axis_name='batch')

    # We init the first set of dropout PRNG keys, but update it afterwards inside
    # the main pmap'd training update for performance.
    dropout_rngs = random.split(rng, jax.local_device_count())

    metrics_all = []
    tick = time.time()
    best_dev_score = 0
    for step, batch in zip(range(num_train_steps), train_iter):
        batch = common_utils.shard(jax.tree_map(lambda x: x._numpy(), batch))  # pylint: disable=protected-access
        optimizer, metrics, dropout_rngs = p_train_step(optimizer, batch,
                                                        dropout_rng=dropout_rngs)
        metrics_all.append(metrics)

        if (step + 1) % eval_freq == 0:
            metrics_all = common_utils.get_metrics(metrics_all)
            lr = metrics_all.pop('learning_rate').mean()
            metrics_sums = jax.tree_map(jnp.sum, metrics_all)
            denominator = metrics_sums.pop('denominator')
            summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
            summary['learning_rate'] = lr
            logging.info('train in step: %d, loss: %.4f', step, summary['loss'])
            if jax.process_index() == 0:
                tock = time.time()
                steps_per_sec = eval_freq / (tock - tick)
                tick = tock
                train_summary_writer.scalar('steps per second', steps_per_sec, step)
                for key, val in summary.items():
                    train_summary_writer.scalar(key, val, step)
                train_summary_writer.flush()

            metrics_all = []  # reset metric accumulation for next evaluation cycle.

            eval_metrics = []
            eval_iter = iter(eval_ds)
            for eval_batch in eval_iter:
                eval_batch = jax.tree_map(lambda x: x._numpy(), eval_batch)  # pylint: disable=protected-access
                # Handle final odd-sized batch by padding instead of dropping it.
                cur_pred_batch_size = eval_batch['inputs'].shape[0]
                if cur_pred_batch_size != batch_size:
                    # pad up to batch size
                    eval_batch = jax.tree_map(
                        lambda x: pad_examples(x, batch_size), eval_batch)
                eval_batch = common_utils.shard(eval_batch)
                metrics = p_eval_step(optimizer.target, eval_batch)
                eval_metrics.append(metrics)
            eval_metrics = common_utils.get_metrics(eval_metrics)
            eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
            eval_denominator = eval_metrics_sums.pop('denominator')
            eval_summary = jax.tree_map(
                lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
                eval_metrics_sums)

            logging.info('eval in step: %d, loss: %.4f, accuracy: %.4f', step,
                         eval_summary['loss'], eval_summary['accuracy'])

            if best_dev_score < eval_summary['accuracy']:
                best_dev_score = eval_summary['accuracy']
                # TODO: save model.
            eval_summary['best_dev_score'] = best_dev_score
            logging.info('best development model score %.4f', best_dev_score)
            if jax.process_index() == 0:
                for key, val in eval_summary.items():
                    eval_summary_writer.scalar(key, val, step)
                eval_summary_writer.flush()
from torch.utils.data import DataLoader

import models

conf = configuration.Config()
tokenizer = tokenization.FullTokenizer(
    en_vocab_file=os.path.join(conf.file_config.data_path, conf.file_config.en_vocab),
    zh_vocab_file=os.path.join(conf.file_config.data_path, conf.file_config.zh_vocab))
conf.model_config.src_vocab_size = tokenizer.get_en_vocab_size() + 2
conf.model_config.trg_vocab_size = tokenizer.get_zh_vocab_size() + 2

dataset = Translate(mode='demo', config=conf, tokenizer=tokenizer, do_filter=False)
transformer = models.Transformer(conf)


def collate_fn(batches):
    """Pad within each batch rather than padding every sample to one global length."""
    batch_en_ids = [torch.tensor(batch[0]) for batch in batches]
    batch_en_ids = rnn_utils.pad_sequence(batch_en_ids, batch_first=True, padding_value=0)
    if batches[0][1] is not None:
        batch_zh_ids = [torch.tensor(batch[1]) for batch in batches]
        batch_zh_ids = rnn_utils.pad_sequence(batch_zh_ids, batch_first=True, padding_value=0)
    else:
        batch_zh_ids = None
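# Minimal sketch of wiring collate_fn into a DataLoader, mirroring the DataLoader
# construction in main() above. The batch size is an arbitrary example value, and
# this assumes collate_fn returns the padded (en_ids, zh_ids) pair.
demo_loader = DataLoader(dataset, batch_size=8, shuffle=False, collate_fn=collate_fn)
first_batch = next(iter(demo_loader))  # one per-batch-padded (en_ids, zh_ids) pair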