  def testInputTargetBatch(self):
    """Tests batching of the dataset with both inputs and targets."""
    vocabs = input_pipeline.create_vocabs(self._filename)

    attributes_input = [input_pipeline.CoNLLAttributes.FORM]
    attributes_target = [input_pipeline.CoNLLAttributes.XPOS]
    sentence_dataset = input_pipeline.sentence_dataset_dict(
        self._filename,
        vocabs,
        attributes_input,
        attributes_target,
        batch_size=2,
        bucket_size=10,
        repeat=1)
    sentence_dataset_iter = iter(sentence_dataset)

    batch = next(sentence_dataset_iter)
    inputs = batch['inputs'].numpy().tolist()
    self.assertSameStructure(
        inputs, [[2., 3., 4., 5., 6., 0., 0., 0., 0., 0.],
                 [2., 3., 4., 5., 6., 0., 0., 0., 0., 0.]])
    targets = batch['targets'].numpy().tolist()
    self.assertSameStructure(
        targets, [[2., 4., 5., 3., 6., 0., 0., 0., 0., 0.],
                  [2., 4., 5., 3., 6., 0., 0., 0., 0., 0.]])
  def test_vocab_creation(self):
    """Tests the creation of the vocab."""
    vocabs = input_pipeline.create_vocabs(self._filename)
    self.assertEqual(
        vocabs['forms'], {
            '<p>': 0,
            '<u>': 1,
            '<r>': 2,
            'They': 3,
            'buy': 4,
            'books': 5,
            '.': 6,
            'NY': 7,
        })
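  # Note on the reserved symbols above: presumably '<p>' is padding, '<u>' is
  # the unknown token, and '<r>' is a sentence-root marker prepended by the
  # input pipeline (the input batches in the batching tests start with id 2
  # and pad with id 0).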
  def testInputBatch(self):
    """Tests batching when targets are empty (tagging unlabeled data)."""
    vocabs = input_pipeline.create_vocabs(self._filename)

    attributes_input = [input_pipeline.CoNLLAttributes.FORM]
    attributes_target = []  # Empty target for tagging of unlabeled data.
    sentence_dataset = input_pipeline.sentence_dataset_dict(
        self._filename,
        vocabs,
        attributes_input,
        attributes_target,
        batch_size=2,
        bucket_size=10,
        repeat=1)
    sentence_dataset_iter = iter(sentence_dataset)

    batch = next(sentence_dataset_iter)
    inputs = batch['inputs'].numpy().tolist()
    self.assertSameStructure(
        inputs, [[2., 3., 4., 5., 6., 0., 0., 0., 0., 0.],
                 [2., 3., 4., 5., 6., 0., 0., 0., 0., 0.]])
    self.assertLen(batch, 1)  # Make sure the target is not included.
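
  # The tests above read a small CoNLL-U fixture from `self._filename`,
  # created in the test's setUp. A minimal sketch of such a fixture (the
  # tokens and XPOS tags below are hypothetical, chosen only to be consistent
  # with the forms vocab asserted in test_vocab_creation; the real test data
  # may differ):
  def setUp(self):
    super().setUp()
    conllu_data = (
        # 10-column CoNLL-U rows: FORM is column 2, XPOS is column 5.
        '1\tThey\t_\t_\tPRP\t_\t_\t_\t_\t_\n'
        '2\tbuy\t_\t_\tVBP\t_\t_\t_\t_\t_\n'
        '3\tbooks\t_\t_\tNNS\t_\t_\t_\t_\t_\n'
        '4\t.\t_\t_\t.\t_\t_\t_\t_\t_\n'
        '\n'
        '1\tNY\t_\t_\tNNP\t_\t_\t_\t_\t_\n'
        '\n')
    self._filename = self.create_tempdir().create_file(
        'test.conllu', content=conllu_data).full_path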
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  tf.enable_v2_behavior()

  batch_size = FLAGS.batch_size
  learning_rate = FLAGS.learning_rate
  num_train_steps = FLAGS.num_train_steps
  num_eval_steps = FLAGS.num_eval_steps
  eval_freq = FLAGS.eval_frequency
  max_length = FLAGS.max_length
  random_seed = FLAGS.random_seed

  if not FLAGS.dev:
    raise app.UsageError('Please provide path to dev set.')
  if not FLAGS.train:
    raise app.UsageError('Please provide path to training set.')

  parameter_path = os.path.join(FLAGS.model_dir, FLAGS.experiment + '.params')

  if jax.host_id() == 0:
    train_summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.model_dir, FLAGS.experiment + '_train'))
    eval_summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.model_dir, FLAGS.experiment + '_eval'))

  if batch_size % jax.device_count() > 0:
    raise ValueError('Batch size must be divisible by the number of devices')
  device_batch_size = batch_size // jax.device_count()

  # create the training and development dataset
  vocabs = input_pipeline.create_vocabs(FLAGS.train)

  attributes_input = [input_pipeline.CoNLLAttributes.FORM]
  attributes_target = [input_pipeline.CoNLLAttributes.XPOS]
  train_ds = input_pipeline.sentence_dataset_dict(
      FLAGS.train,
      vocabs,
      attributes_input,
      attributes_target,
      batch_size=batch_size,
      bucket_size=max_length)
  eval_ds = input_pipeline.sentence_dataset_dict(
      FLAGS.dev,
      vocabs,
      attributes_input,
      attributes_target,
      batch_size=batch_size,
      bucket_size=max_length,
      repeat=1)
  train_iter = iter(train_ds)

  bs = device_batch_size * jax.device_count()
  rng = random.PRNGKey(random_seed)
  rng, init_rng = random.split(rng)
  input_shape = (bs, max_length)

  transformer_kwargs = {
      'vocab_size': len(vocabs['forms']),
      'output_vocab_size': len(vocabs['xpos']),
      'emb_dim': 512,
      'num_heads': 8,
      'num_layers': 6,
      'qkv_dim': 512,
      'mlp_dim': 2048,
      'max_len': max_length,
  }

  model = create_model(init_rng, tuple(input_shape), transformer_kwargs)
  optimizer = create_optimizer(model, learning_rate)
  del model  # don't keep a copy of the initial model

  learning_rate_fn = create_learning_rate_scheduler(
      base_learning_rate=learning_rate)

  p_train_step = jax.pmap(
      functools.partial(train_step, learning_rate_fn=learning_rate_fn),
      axis_name='batch')
  p_eval_step = jax.pmap(eval_step, axis_name='batch')

  # We init the first set of dropout PRNG keys, but update it afterwards
  # inside the main pmap'd training update for performance.
  dropout_rngs = random.split(rng, jax.local_device_count())

  metrics_all = []
  tick = time.time()
  best_dev_score = 0
  for step, batch in zip(range(num_train_steps), train_iter):
    batch = common_utils.shard(jax.tree_map(lambda x: x._numpy(), batch))  # pylint: disable=protected-access
    optimizer, metrics, dropout_rngs = p_train_step(
        optimizer, batch, dropout_rng=dropout_rngs)
    metrics_all.append(metrics)

    if (step + 1) % eval_freq == 0:
      metrics_all = common_utils.get_metrics(metrics_all)
      lr = metrics_all.pop('learning_rate').mean()
      metrics_sums = jax.tree_map(jnp.sum, metrics_all)
      denominator = metrics_sums.pop('denominator')
      summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
      summary['learning_rate'] = lr
      # Calculate (clipped) perplexity after averaging log-perplexities:
      summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4)
      logging.info('train in step: %d, loss: %.4f', step, summary['loss'])
      if jax.host_id() == 0:
        tock = time.time()
        steps_per_sec = eval_freq / (tock - tick)
        tick = tock
        train_summary_writer.scalar('steps per second', steps_per_sec, step)
        for key, val in summary.items():
          train_summary_writer.scalar(key, val, step)
        train_summary_writer.flush()

      # reset metric accumulation for next evaluation cycle.
      metrics_all = []

      eval_metrics = []
      eval_iter = iter(eval_ds)
      if num_eval_steps == -1:
        num_iter = itertools.repeat(1)
      else:
        num_iter = range(num_eval_steps)
      for _, eval_batch in zip(num_iter, eval_iter):
        eval_batch = jax.tree_map(lambda x: x._numpy(), eval_batch)  # pylint: disable=protected-access
        # Handle final odd-sized batch by padding instead of dropping it.
        cur_pred_batch_size = eval_batch['inputs'].shape[0]
        if cur_pred_batch_size != batch_size:
          logging.info('Uneven batch size %d.', cur_pred_batch_size)
          eval_batch = jax.tree_map(
              lambda x: pad_examples(x, batch_size), eval_batch)
        eval_batch = common_utils.shard(eval_batch)
        metrics = p_eval_step(optimizer.target, eval_batch)
        eval_metrics.append(metrics)

      eval_metrics = common_utils.get_metrics(eval_metrics)
      eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
      eval_denominator = eval_metrics_sums.pop('denominator')
      eval_summary = jax.tree_map(
          lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
          eval_metrics_sums)
      # Calculate (clipped) perplexity after averaging log-perplexities:
      eval_summary['perplexity'] = jnp.clip(
          jnp.exp(eval_summary['loss']), a_max=1.0e4)
      logging.info('eval in step: %d, loss: %.4f, accuracy: %.4f', step,
                   eval_summary['loss'], eval_summary['accuracy'])
      if best_dev_score < eval_summary['accuracy']:
        best_dev_score = eval_summary['accuracy']
        # TODO: save model.
      eval_summary['best_dev_score'] = best_dev_score
      logging.info('best development model score %.4f', best_dev_score)
      if jax.host_id() == 0:
        for key, val in eval_summary.items():
          eval_summary_writer.scalar(key, val, step)
        eval_summary_writer.flush()
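# `pad_examples` is called by the evaluation loops in both versions of main()
# but is defined elsewhere in the training script. A minimal sketch consistent
# with its call sites, assuming 2-D [batch, length] arrays as produced by the
# input pipeline: a short final batch is padded up to the full batch size by
# repeating its last row, so the array can still be sharded evenly across
# devices. (The `import` would normally live at module top.)
import numpy as np


def pad_examples(x, desired_batch_size):
  """Expands a batch to the desired size by tiling its last example."""
  batch_pad = desired_batch_size - x.shape[0]
  return np.concatenate([x, np.tile(x[-1], (batch_pad, 1))], axis=0)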
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  # Make sure tf does not allocate gpu memory.
  tf.config.experimental.set_visible_devices([], 'GPU')

  batch_size = FLAGS.batch_size
  learning_rate = FLAGS.learning_rate
  num_train_steps = FLAGS.num_train_steps
  eval_freq = FLAGS.eval_frequency
  random_seed = FLAGS.random_seed

  if not FLAGS.dev:
    raise app.UsageError('Please provide path to dev set.')
  if not FLAGS.train:
    raise app.UsageError('Please provide path to training set.')
  if batch_size % jax.device_count() > 0:
    raise ValueError('Batch size must be divisible by the number of devices')
  device_batch_size = batch_size // jax.device_count()

  if jax.process_index() == 0:
    train_summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.model_dir, FLAGS.experiment + '_train'))
    eval_summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.model_dir, FLAGS.experiment + '_eval'))

  # create the training and development dataset
  vocabs = input_pipeline.create_vocabs(FLAGS.train)
  config = models.TransformerConfig(
      vocab_size=len(vocabs['forms']),
      output_vocab_size=len(vocabs['xpos']),
      max_len=FLAGS.max_length)

  attributes_input = [input_pipeline.CoNLLAttributes.FORM]
  attributes_target = [input_pipeline.CoNLLAttributes.XPOS]
  train_ds = input_pipeline.sentence_dataset_dict(
      FLAGS.train,
      vocabs,
      attributes_input,
      attributes_target,
      batch_size=batch_size,
      bucket_size=config.max_len)
  train_iter = iter(train_ds)

  eval_ds = input_pipeline.sentence_dataset_dict(
      FLAGS.dev,
      vocabs,
      attributes_input,
      attributes_target,
      batch_size=batch_size,
      bucket_size=config.max_len,
      repeat=1)

  model = models.Transformer(config)

  rng = random.PRNGKey(random_seed)
  rng, init_rng = random.split(rng)

  # call a jitted initialization function to get the initial parameter tree
  @jax.jit
  def initialize_variables(init_rng):
    init_batch = jnp.ones((config.max_len, 1), jnp.float32)
    init_variables = model.init(init_rng, inputs=init_batch, train=False)
    return init_variables

  init_variables = initialize_variables(init_rng)

  optimizer_def = optim.Adam(
      learning_rate, beta1=0.9, beta2=0.98, eps=1e-9, weight_decay=1e-1)
  optimizer = optimizer_def.create(init_variables['params'])
  optimizer = jax_utils.replicate(optimizer)

  learning_rate_fn = create_learning_rate_scheduler(
      base_learning_rate=learning_rate)

  p_train_step = jax.pmap(
      functools.partial(
          train_step, model=model, learning_rate_fn=learning_rate_fn),
      axis_name='batch')

  def eval_step(params, batch):
    """Calculate evaluation metrics on a batch."""
    inputs, targets = batch['inputs'], batch['targets']
    weights = jnp.where(targets > 0, 1.0, 0.0)
    logits = model.apply({'params': params}, inputs=inputs, train=False)
    return compute_metrics(logits, targets, weights)

  p_eval_step = jax.pmap(eval_step, axis_name='batch')

  # We init the first set of dropout PRNG keys, but update it afterwards
  # inside the main pmap'd training update for performance.
  dropout_rngs = random.split(rng, jax.local_device_count())

  metrics_all = []
  tick = time.time()
  best_dev_score = 0
  for step, batch in zip(range(num_train_steps), train_iter):
    batch = common_utils.shard(jax.tree_map(lambda x: x._numpy(), batch))  # pylint: disable=protected-access
    optimizer, metrics, dropout_rngs = p_train_step(
        optimizer, batch, dropout_rng=dropout_rngs)
    metrics_all.append(metrics)

    if (step + 1) % eval_freq == 0:
      metrics_all = common_utils.get_metrics(metrics_all)
      lr = metrics_all.pop('learning_rate').mean()
      metrics_sums = jax.tree_map(jnp.sum, metrics_all)
      denominator = metrics_sums.pop('denominator')
      summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
      summary['learning_rate'] = lr
      logging.info('train in step: %d, loss: %.4f', step, summary['loss'])
      if jax.process_index() == 0:
        tock = time.time()
        steps_per_sec = eval_freq / (tock - tick)
        tick = tock
        train_summary_writer.scalar('steps per second', steps_per_sec, step)
        for key, val in summary.items():
          train_summary_writer.scalar(key, val, step)
        train_summary_writer.flush()
      metrics_all = []  # reset metric accumulation for next evaluation cycle.

      eval_metrics = []
      eval_iter = iter(eval_ds)
      for eval_batch in eval_iter:
        eval_batch = jax.tree_map(lambda x: x._numpy(), eval_batch)  # pylint: disable=protected-access
        # Handle final odd-sized batch by padding instead of dropping it.
        cur_pred_batch_size = eval_batch['inputs'].shape[0]
        if cur_pred_batch_size != batch_size:
          # pad up to batch size
          eval_batch = jax.tree_map(
              lambda x: pad_examples(x, batch_size), eval_batch)
        eval_batch = common_utils.shard(eval_batch)
        metrics = p_eval_step(optimizer.target, eval_batch)
        eval_metrics.append(metrics)

      eval_metrics = common_utils.get_metrics(eval_metrics)
      eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
      eval_denominator = eval_metrics_sums.pop('denominator')
      eval_summary = jax.tree_map(
          lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
          eval_metrics_sums)
      logging.info('eval in step: %d, loss: %.4f, accuracy: %.4f', step,
                   eval_summary['loss'], eval_summary['accuracy'])
      if best_dev_score < eval_summary['accuracy']:
        best_dev_score = eval_summary['accuracy']
        # TODO: save model.
      eval_summary['best_dev_score'] = best_dev_score
      logging.info('best development model score %.4f', best_dev_score)
      if jax.process_index() == 0:
        for key, val in eval_summary.items():
          eval_summary_writer.scalar(key, val, step)
        eval_summary_writer.flush()
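# `compute_metrics` is referenced by `eval_step` above (and, via train_step,
# by the training loop) but is not shown in this section. A minimal sketch of
# a compatible implementation, assuming per-token weights: it returns *summed*
# loss and accuracy plus a token-count 'denominator' (which main() divides out
# after aggregation) and reduces across the pmap 'batch' axis with psum.
def compute_metrics(logits, labels, weights):
  """Sketch: per-shard summed cross-entropy, accuracy, and denominator."""
  vocab_size = logits.shape[-1]
  log_probs = jax.nn.log_softmax(logits)
  onehot_labels = common_utils.onehot(labels, vocab_size)
  token_loss = -jnp.sum(onehot_labels * log_probs, axis=-1)
  metrics = {
      'loss': jnp.sum(token_loss * weights),
      'accuracy': jnp.sum((jnp.argmax(logits, axis=-1) == labels) * weights),
      'denominator': jnp.sum(weights),
  }
  # Sum across devices so every replica sees the same totals.
  return jax.lax.psum(metrics, axis_name='batch')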