コード例 #1
0
ファイル: textgenrnn.py プロジェクト: Chegashi/bio_generator
    def train_on_texts(self,
                       texts,
                       context_labels=None,
                       batch_size=128,
                       num_epochs=50,
                       verbose=1,
                       new_model=False,
                       gen_epochs=1,
                       train_size=1.0,
                       max_gen_length=300,
                       validation=True,
                       dropout=0.0,
                       via_new_model=False,
                       save_epochs=0,
                       multi_gpu=False,
                       **kwargs):
        """Train the model on ``texts``.

        Args:
            texts: Sequence of training strings. It is not modified; the
                word-level preprocessing works on a new list.
            context_labels: Optional per-text labels, binarized with
                ``LabelBinarizer`` and fed to the data generators.
            batch_size: Batch size; multiplied by the number of visible GPUs
                when ``multi_gpu`` is True.
            num_epochs: Number of epochs; also the horizon of the linear
                learning-rate decay.
            verbose: Verbosity level passed to ``fit``.
            new_model: If True and ``via_new_model`` is False, delegate to
                ``self.train_new_model`` and return. Otherwise, when True,
                models built here get ``weights_path=None``.
            gen_epochs: Forwarded to the ``generate_after_epoch`` callback.
            train_size: Probability that each token sequence is kept for
                training; the others are held out for validation.
            max_gen_length: Forwarded to the ``generate_after_epoch`` callback.
            validation: Build a validation generator from the held-out
                sequences (only when ``train_size < 1.0``).
            dropout: Dropout used when building a context-aware model.
            via_new_model: Presumably set by ``train_new_model`` so that it
                is not delegated to again.
            save_epochs: Forwarded to the ``save_model_weights`` callback.
            multi_gpu: Train a freshly built model under
                ``tensorflow.distribute.MirroredStrategy``.
            **kwargs: Forwarded to ``train_new_model``. ``prop_keep`` is
                accepted as an alias for ``train_size``.

        Raises:
            AssertionError: If fewer training sequences than ``batch_size``
                remain after the train/validation split.
        """
        if new_model and not via_new_model:
            self.train_new_model(texts,
                                 context_labels=context_labels,
                                 num_epochs=num_epochs,
                                 gen_epochs=gen_epochs,
                                 train_size=train_size,
                                 batch_size=batch_size,
                                 dropout=dropout,
                                 validation=validation,
                                 save_epochs=save_epochs,
                                 multi_gpu=multi_gpu,
                                 **kwargs)
            return

        # Explicit None check: a NumPy array / pandas Series of labels has no
        # single truth value, so ``if context_labels:`` would raise for it.
        if context_labels is not None:
            context_labels = LabelBinarizer().fit_transform(context_labels)

        if 'prop_keep' in kwargs:
            # Alias for train_size.
            train_size = kwargs['prop_keep']

        if self.config['word_level']:
            # If training word level, must add spaces around each
            # punctuation. https://stackoverflow.com/a/3645946/9314418
            punct = '!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\\n\\t\'‘’“”’–—…'
            punct_re = re.compile('([{}])'.format(punct))
            multi_space_re = re.compile(' {2,}')
            # Build a new list instead of overwriting the caller's texts.
            texts = [
                text_to_word_sequence(
                    multi_space_re.sub(' ', punct_re.sub(r' \1 ', text)),
                    filters='')
                for text in texts
            ]

        # For every text i, pair i with each token position 0..len(text):
        # meshgrid yields two (len+1, 1) columns and np.block joins them into
        # a (len+1, 2) array. A single concatenate keeps this linear
        # (appending inside a loop is quadratic), and np.block over the
        # whole list hangs when the list is large.
        indices_list = np.concatenate([
            np.block(np.meshgrid(np.array(i), np.arange(len(text) + 1)))
            for i, text in enumerate(texts)
        ])

        # If a single text, there will be 2 extra indices, so remove them
        # Also remove first sequences which use padding
        if self.config['single_text']:
            indices_list = indices_list[self.config['max_length']:-2, :]

        indices_mask = np.random.rand(indices_list.shape[0]) < train_size

        if multi_gpu:
            num_gpus = len(config.get_visible_devices('GPU'))
            batch_size = batch_size * num_gpus

        gen_val = None
        val_steps = None
        if train_size < 1.0 and validation:
            indices_list_val = indices_list[~indices_mask, :]
            gen_val = generate_sequences_from_texts(texts, indices_list_val,
                                                    self, context_labels,
                                                    batch_size)
            val_steps = max(
                int(np.floor(indices_list_val.shape[0] / batch_size)), 1)

        indices_list = indices_list[indices_mask, :]

        num_tokens = indices_list.shape[0]
        assert num_tokens >= batch_size, "Fewer tokens than batch_size."

        level = 'word' if self.config['word_level'] else 'character'
        print("Training on {:,} {} sequences.".format(num_tokens, level))

        steps_per_epoch = max(int(np.floor(num_tokens / batch_size)), 1)

        gen = generate_sequences_from_texts(texts, indices_list, self,
                                            context_labels, batch_size)

        base_lr = 4e-3

        # scheduler function must be defined inline.
        def lr_linear_decay(epoch):
            return (base_lr * (1 - (epoch / num_epochs)))

        # Models built below receive ``weights_path`` (presumably to load
        # initial weights from it), so write the current weights there first
        # whenever such a model may be built.
        if context_labels is not None or multi_gpu:
            if new_model:
                weights_path = None
            else:
                weights_path = "{}_weights.hdf5".format(self.config['name'])
                self.save(weights_path)

        if multi_gpu:
            from tensorflow import distribute
            model_kwargs = {'cfg': self.config, 'weights_path': weights_path}
            if context_labels is not None:
                model_kwargs['dropout'] = dropout
                model_kwargs['context_size'] = context_labels.shape[1]
            # The model must be created and compiled inside strategy.scope().
            strategy = distribute.MirroredStrategy()
            with strategy.scope():
                model_t = textgenrnn_model(self.num_classes, **model_kwargs)
                model_t.compile(loss='categorical_crossentropy',
                                optimizer=Adam(learning_rate=base_lr))
            print("Training on {} GPUs.".format(num_gpus))
        else:
            # NOTE(review): with context_labels this trains ``self.model``
            # as-is; the weights saved above are not reloaded and no
            # context-aware model is built here. Confirm self.model already
            # accepts the context inputs.
            model_t = self.model

        model_t.fit(gen,
                    steps_per_epoch=steps_per_epoch,
                    epochs=num_epochs,
                    callbacks=[
                        LearningRateScheduler(lr_linear_decay),
                        generate_after_epoch(self, gen_epochs, max_gen_length),
                        save_model_weights(self, num_epochs, save_epochs)
                    ],
                    verbose=verbose,
                    max_queue_size=10,
                    validation_data=gen_val,
                    validation_steps=val_steps)

        # Keep the text-only version of the model if using context labels.
        # NOTE(review): with multi_gpu the trained model is ``model_t``, not
        # ``self.model``; confirm self.model is the intended source here.
        if context_labels is not None:
            self.model = Model(inputs=self.model.input[0],
                               outputs=self.model.output[1])
コード例 #2
0
from tensorflow import train, convert_to_tensor, round, function, distribute
from keras_bert.layers import Extract

from keras_self_attention import SeqSelfAttention
from tensorflow.python.keras.layers import BatchNormalization

from torch import cuda, tensor, long, float
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, TFBertModel, BertModel
from datetime import datetime
from sklearn.metrics import roc_curve, classification_report, auc
from abc import ABC

logging.basicConfig(level=logging.ERROR)
pd.set_option('display.max_columns', None)
strategy = distribute.MirroredStrategy()

# Load the BERT model hyper-parameters from the YAML config file.
with open('../config/config.yml', 'r', encoding='utf-8') as f:
    bert_config = yaml.safe_load(f)['bert_model']

# Expose each training setting of the 'bert_model' section as a
# module-level constant (looked up in this order).
(EPOCHS, BATCH_SIZE, SHUFFLE, NUM_WORKS,
 LEARNING_RATE, TEXT_LEN, MAX_TO_KEEP) = (
    bert_config[setting]
    for setting in ('epochs', 'batch_size', 'shuffle', 'num_works',
                    'lr', 'text_len', 'max_to_keep'))

# token = BertTokenizer.from_pretrained("bert-base-chinese")
コード例 #3
0
def train_cnn_classifier(paths, params):
    """Train, evaluate and save a CNN classifier for each grid-search architecture.

    For every architecture a timestamped folder is created under
    ``paths['model_folder']``. It receives the best checkpoints (when
    enabled), the final model ``model.h5``, the training history, the
    test-set loss and accuracy, and the hyper-parameters used.

    Parameters
    -----------
    paths : dict
        Needs ``'dataset_folder'`` (passed to ``get_datasets_paths``),
        ``'model_folder'`` (parent folder of the per-run folders) and
        ``'augmentation_folder'`` (passed as ``save_to_dir`` to the training
        generator; may be None).
    params : dict
        Needs ``'rescale_factor'``. It is updated in place with the learning
        parameters of each architecture trained, and those parameters are
        written to ``params.csv``.
    """
    # grid search variables
    cnn_architectures = ['v6']

    # The datasets are the same for every architecture: locate them and
    # derive the class count and input shape once, outside the loop.
    datasets = get_datasets_paths(paths['dataset_folder'])
    num_classes = len(np.unique(np.load(datasets['Y_train'])))
    # Only the shape is needed, so memory-map X_train instead of reading the
    # whole array into RAM.
    input_shape = np.load(datasets['X_train'], mmap_mode='r').shape

    for cnn_architecture in cnn_architectures:

        # model checkpoint and final model save folder
        model_save_folder = os.path.join(paths['model_folder'], get_current_timestamp())
        create_directory(model_save_folder)

        # ---- Learning parameters ----
        params.update({'ARCHITECTURE': cnn_architecture,
                       'NUM_CLASSES': num_classes,
                       'LR': .05,
                       'OPTIMIZER': 'sgd',
                       'TRAIN_SHAPE': input_shape,
                       'INPUT_SHAPE': input_shape[1:],
                       'BATCH_SIZE': 32,
                       'EPOCHS': 100,
                       'ES': True,
                       'ES_PATIENCE': 20,
                       'ES_RESTORE_WEIGHTS': True,
                       'SAVE_CHECKPOINTS': True,
                       'RESCALE': params['rescale_factor'],
                       'ROTATION_RANGE': None,
                       'WIDTH_SHIFT_RANGE': None,
                       'HEIGHT_SHIFT_RANGE': None,
                       'SHEAR_RANGE': None,
                       'ZOOM_RANGE': None,
                       'HORIZONTAL_FLIP': False,
                       'VERTICAL_FLIP': False,
                       'BRIGHTNESS_RANGE': None,
                       })

        # ---- Data generators ----
        # training data (with the augmentation settings above)
        train_generator = create_image_data_generator(
            x=datasets['X_train'], y=datasets['Y_train'],
            batch_size=params['BATCH_SIZE'], rescale=params['RESCALE'],
            rotation_range=params['ROTATION_RANGE'],
            width_shift_range=params['WIDTH_SHIFT_RANGE'],
            height_shift_range=params['HEIGHT_SHIFT_RANGE'],
            shear_range=params['SHEAR_RANGE'],
            zoom_range=params['ZOOM_RANGE'],
            horizontal_flip=params['HORIZONTAL_FLIP'],
            vertical_flip=params['VERTICAL_FLIP'],
            brightness_range=params['BRIGHTNESS_RANGE'],
            save_to_dir=paths['augmentation_folder'])

        # validation data
        val_generator = create_image_data_generator(
            x=datasets['X_val'], y=datasets['Y_val'],
            batch_size=params['BATCH_SIZE'], rescale=params['RESCALE'])

        # test data
        test_generator = create_image_data_generator(
            x=datasets['X_test'], y=datasets['Y_test'],
            batch_size=params['BATCH_SIZE'], rescale=params['RESCALE'])

        # ---- Callbacks ----
        callback_list = []

        # early stopping on validation loss
        if params['ES']:
            callback_list.append(EarlyStopping(
                monitor='val_loss', min_delta=0,
                patience=params['ES_PATIENCE'],
                restore_best_weights=params['ES_RESTORE_WEIGHTS'],
                verbose=1, mode='auto'))

        # keep the best model (by validation loss) as checkpoints
        if params['SAVE_CHECKPOINTS']:
            create_directory(os.path.join(model_save_folder, 'checkpoints'))
            callback_list.append(ModelCheckpoint(
                filepath=os.path.join(model_save_folder, 'checkpoints', 'checkpoint_model.{epoch:02d}_{val_loss:.3f}_{val_accuracy:.3f}.h5'),
                save_weights_only=False, monitor='val_loss', mode='auto',
                save_best_only=True))

        # ---- Train the CNN model ----
        # use multiple GPUs; the model must be built inside the scope
        mirrored_strategy = distribute.MirroredStrategy()

        with mirrored_strategy.scope():

            model = get_cnn_model(cnn_type=params['ARCHITECTURE'], input_shape=params['INPUT_SHAPE'], num_classes=params['NUM_CLASSES'], learning_rate=params['LR'], optimizer_name=params['OPTIMIZER'])

            history = model.fit(train_generator,
                                epochs=params['EPOCHS'],
                                steps_per_epoch=len(train_generator),
                                validation_data=val_generator,
                                validation_steps=len(val_generator),
                                callbacks=callback_list)

            # evaluate on test set
            history_test = model.evaluate(test_generator)

            # save the whole model
            model.save(os.path.join(model_save_folder, 'model.h5'))

            # save history of training
            pd.DataFrame(history.history).to_csv(os.path.join(model_save_folder, 'history_training.csv'))

            # save test results
            pd.DataFrame(history_test, index=['loss', 'accuracy']).to_csv(os.path.join(model_save_folder, 'history_test.csv'))

            # save model hyperparameters
            pd.DataFrame(pd.Series(params)).to_csv(os.path.join(model_save_folder, 'params.csv'))
コード例 #4
0
 def train_model(self, params):
     """Build, train, evaluate and save the model described by ``params``.

     Args:
         params: Hyper-parameter dict, mutated in place. A SHA-256 hash of
             ``dencode(params)`` is stored under ``'hash'`` and names the
             log and weight files. After hashing, list values are replaced
             by their first element. Optional keys: ``batch_size``
             (default 1), ``epochs`` (default 1), ``learning_schedule``
             (dotted path ``module.func`` of a ``func(epoch, lr)``
             schedule) and ``gpu`` (number of GPUs). The whole dict is
             passed to ``self.cm``. ``time`` and ``acc`` are added after
             training.

     Returns:
         ``str(params)`` after the run.
     """
     # Compute dictionary hash (before any value is unwrapped below).
     ident = dencode(params)
     chk = sha256(ident.encode("utf-8")).hexdigest()
     params['hash'] = chk
     # Unwrap list-valued parameters to their first element. The value must
     # be tested: dict keys can never be lists. Empty lists are left as-is.
     for key, value in params.items():
         if isinstance(value, list) and value:
             params[key] = value[0]
     # Basic defaults
     batch_size = params.get('batch_size', 1)
     epochs = params.get('epochs', 1)
     if "learning_schedule" in params:
         # "myfile.fun" names a function fun(epoch, lr) defined in myfile.py.
         # NOTE: this imports and runs code named in params; only use it with
         # trusted parameter files.
         import importlib
         module_name, _, func_name = params["learning_schedule"].rpartition(".")
         schedule_module = importlib.import_module(module_name)
         schedule = LearningRateScheduler(getattr(schedule_module, func_name))
     else:
         # Identity schedule: keep the optimizer's current learning rate.
         schedule = LearningRateScheduler(lambda epoch, lr: lr)
     # Just pass everything to create, the function may need it.
     model = self.cm(**params)
     if "gpu" in params and self.agpu:
         num_gpus = params['gpu']
         if num_gpus > 1 and tfversion < twoOh:
             opt = model.optimizer
             model = multi_gpu_model(model, num_gpus)
             model.compile(opt, model.loss, metrics=['accuracy'])
         elif num_gpus > 1 and tfversion >= twoOh:
             vis = ["/gpu:{}".format(x) for x in self.vis]
             # The default (NCCL) all-reduce was crashing; use hierarchical
             # copy all-reduce instead.
             strat = D.MirroredStrategy(
                 devices=vis,
                 cross_device_ops=tf.distribute.HierarchicalCopyAllReduce())
             # Rebuild and compile the model inside the strategy scope.
             with strat.scope():
                 model = self.cm(**params)
                 opt = model.optimizer
                 model.compile(opt, model.loss, metrics=['accuracy'])
     if self.augmentation:
         # Create a primitive augmentation object
         datagen = ImageDataGenerator(rotation_range=5,
                                      width_shift_range=0,
                                      height_shift_range=0,
                                      shear_range=0,
                                      zoom_range=0,
                                      horizontal_flip=True,
                                      fill_mode='nearest')
         self.aprint("Training a network {}...".format(datetime.now()))
         start_time = datetime.now()
         # NOTE(review): shuffle, coeff and threads are not defined in this
         # method; presumably module-level settings — confirm.
         model.fit_generator(datagen.flow(self.idata,
                                          self.odata,
                                          batch_size=batch_size,
                                          shuffle=shuffle),
                             steps_per_epoch=coeff *
                             self.idata.shape[0] // batch_size,
                             epochs=epochs,
                             verbose=0,
                             use_multiprocessing=True,
                             workers=threads)
         end_time = datetime.now()
         # Test accuracy
         loss, accuracy = model.evaluate(x=self.idatae,
                                         y=self.odatae,
                                         batch_size=batch_size,
                                         verbose=0)
     else:
         self.aprint("Training network {} :: {}...".format(
             chk, datetime.now()))
         if not isdir("./training_logs/"):
             try:
                 mkdir("./training_logs/")
             except OSError:
                 # Best effort: filesystem errors here are ignored.
                 pass
         csv = CSVLogger("./training_logs/{}.log".format(chk), append=False)
         start_time = datetime.now()
         model.fit(
             x=self.idata,
             y=self.odata,
             batch_size=batch_size,
             validation_split=0.2,
             epochs=epochs,
             verbose=0,
             shuffle=True,
             callbacks=[TimeHistory(), schedule, csv])
         end_time = datetime.now()
         # Test accuracy
         loss, accuracy = model.evaluate(
             x=self.idatae,
             y=self.odatae,
             batch_size=batch_size,
             verbose=0)
     # Compute timing metrics
     run_time = deltaToString(end_time - start_time)
     params['time'] = str(run_time)
     params['acc'] = str(accuracy)
     if not isdir("./weights/"):
         try:
             mkdir("./weights/")
         except OSError:
             # Best effort: filesystem errors here are ignored.
             pass
     results = str(params)
     model.save("./weights/{}".format(chk))
     self.aprint("Trained! {}".format(results))
     return results
コード例 #5
0
def main(_):
    """Pre-train an ALBERT model on the masked-LM and next-sentence tasks.

    Everything is configured through ``FLAGS``. Training runs in a custom
    loop under the selected distribution strategy and resumes from the
    latest checkpoint in ``FLAGS.model_dir`` if there is one. It writes
    TensorBoard summaries. Whenever the step count reaches a multiple of
    ``FLAGS.steps_per_epoch``, it saves a checkpoint and evaluates on the
    evaluation files.

    Args:
        _: Unused (presumably ``argv`` passed by an ``app.run`` launcher —
            confirm).

    Raises:
        ValueError: If ``FLAGS.strategy`` is neither ``'mirror'`` nor
            ``'tpu'``.
    """
    config = ALBERT_CONFIG[FLAGS.model_type]
    _init(config['max_seq_length'], config['max_mask_length'])

    # Make strategy
    assert FLAGS.strategy, 'Strategy can not be empty'
    if FLAGS.strategy == 'mirror':
        strategy = distribute.MirroredStrategy()
    elif FLAGS.strategy == 'tpu':
        cluster_resolver = init_tpu(FLAGS.tpu_addr)
        strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
    else:
        raise ValueError(
            'The distribution strategy type is not supported: %s' %
            FLAGS.strategy)

    # Prepare training dataset (shuffled, repeated indefinitely)
    file_list = tf.data.Dataset.list_files(FLAGS.train_files)
    train_dataset = tf.data.TFRecordDataset(filenames=file_list)
    train_dataset = train_dataset.shuffle(
        buffer_size=FLAGS.shuffle_buffer_size)
    train_dataset = train_dataset.map(
        dump_example,
        num_parallel_calls=tf.data.experimental.AUTOTUNE).cache()
    train_dataset = train_dataset.shuffle(100)
    train_dataset = train_dataset.repeat()  # loop
    train_dataset = train_dataset.batch(FLAGS.train_batch_size)
    train_dataset = train_dataset.prefetch(FLAGS.train_batch_size)

    # Prepare evaluation dataset
    file_list = tf.data.Dataset.list_files(FLAGS.eval_files)
    eval_dataset = tf.data.TFRecordDataset(filenames=file_list)
    eval_dataset = eval_dataset.map(
        dump_example,
        num_parallel_calls=tf.data.experimental.AUTOTUNE).cache()
    eval_dataset = eval_dataset.batch(FLAGS.eval_batch_size)
    eval_dataset = eval_dataset.prefetch(FLAGS.eval_batch_size)

    def make_iter_dataset(dataset):
        """Return an iterator over ``dataset`` distributed by ``strategy``."""
        return iter(strategy.experimental_distribute_dataset(dataset))

    # Training
    with strategy.scope():
        # Build Albert model
        logging.info('Build albert: config: %s', config)
        classifier_model, albert_model, optimizer = model_fn(config)

        if FLAGS.init_checkpoint:
            logging.info('Restore albert_model from initial checkpoint: %s',
                         FLAGS.init_checkpoint)
            checkpoint = tf.train.Checkpoint(albert_model=albert_model)
            checkpoint.restore(FLAGS.init_checkpoint)

        # Make metric functions
        train_loss_metric = keras.metrics.Mean('training_loss',
                                               dtype=tf.float32)
        eval_metrics = [
            keras.metrics.SparseCategoricalAccuracy('test_accuracy',
                                                    dtype=tf.float32),
        ]
        train_metrics = [
            keras.metrics.SparseCategoricalAccuracy('train_accuracy',
                                                    dtype=tf.float32),
        ]

        # Make summary writers
        summary_dir = os.path.join(FLAGS.model_dir, 'summaries')
        eval_summary_writer = tf.summary.create_file_writer(
            os.path.join(summary_dir, 'eval'))
        train_summary_writer = tf.summary.create_file_writer(
            os.path.join(summary_dir, 'train'))

        def step_fn(batch):
            """Run one optimisation step on a per-replica batch."""
            input_ids = batch['input_ids']
            input_mask = batch['input_mask']
            segment_ids = batch['segment_ids']
            mask_positions = batch['mask_positions']
            mask_label_ids = batch['mask_label_ids']
            mask_weights = batch['mask_weights']
            next_id = batch['next_id']
            attention_mask = create_attention_mask(input_mask)

            train_inputs = {
                'input_ids': input_ids,
                'attention_mask': attention_mask,
                'segment_ids': segment_ids,
            }

            with tf.GradientTape() as tape:
                train_outputs = classifier_model(train_inputs)
                lm_logits = gather_indexes(train_outputs[0], mask_positions)
                loss_lm = sparse_categorical_crossentropy(
                    mask_label_ids,
                    lm_logits,
                    weights=mask_weights,
                    from_logits=True,
                )
                next_logits = train_outputs[1]
                loss_next = sparse_categorical_crossentropy(
                    next_id,
                    next_logits,
                    from_logits=True,
                )
                loss = loss_lm + loss_next

            grads = tape.gradient(loss, classifier_model.trainable_weights)
            optimizer.apply_gradients(
                list(zip(grads, classifier_model.trainable_weights)))

            # metric
            train_loss_metric.update_state(loss)
            # mask_label_ids = [batch_size * max_mask_length, 1]
            flat_mask_label_ids = tf.reshape(mask_label_ids, [-1, 1])
            for metric in train_metrics:
                metric.update_state(flat_mask_label_ids, lm_logits)
                metric.update_state(next_id, next_logits)

            return loss

        @tf.function
        def train_steps(iterator, steps):
            for step in tf.range(steps):
                strategy.experimental_run_v2(step_fn, args=(next(iterator), ))

        @tf.function
        def test_step(iterator):
            def test_step_fn(batch):
                input_ids = batch['input_ids']
                input_mask = batch['input_mask']
                segment_ids = batch['segment_ids']
                mask_positions = batch['mask_positions']
                mask_label_ids = batch['mask_label_ids']
                next_id = batch['next_id']
                attention_mask = create_attention_mask(input_mask)

                eval_inputs = {
                    'input_ids': input_ids,
                    'attention_mask': attention_mask,
                    'segment_ids': segment_ids,
                }
                eval_outputs = classifier_model(eval_inputs)
                lm_logits = gather_indexes(eval_outputs[0], mask_positions)
                next_logits = eval_outputs[1]

                # mask_label_ids = [batch_size * max_mask_length, 1]
                flat_mask_label_ids = tf.reshape(mask_label_ids, [-1, 1])
                for metric in eval_metrics:
                    metric.update_state(flat_mask_label_ids, lm_logits)
                    metric.update_state(next_id, next_logits)

            strategy.experimental_run_v2(test_step_fn, args=(next(iterator), ))

        def _run_evaluation(current_step, iterator):
            # Reset so each evaluation reports only its own batches.
            for metric in eval_metrics:
                metric.reset_states()
            for _ in range(FLAGS.eval_steps):
                test_step(iterator)

            log = f'eval step: {current_step}, '
            with eval_summary_writer.as_default():
                for metric in eval_metrics:
                    metric_value = _float_metric_value(metric)
                    tf.summary.scalar(metric.name,
                                      metric_value,
                                      step=current_step)
                    log += f'metric: {metric.name} = {metric_value}, '
                eval_summary_writer.flush()
            logging.info(log)

        # Restore classifier_model
        checkpoint = tf.train.Checkpoint(model=classifier_model,
                                         optimizer=optimizer)
        latest_checkpoint_file = tf.train.latest_checkpoint(FLAGS.model_dir)
        if latest_checkpoint_file:
            logging.info(
                'Restore classifier_model from the latest checkpoint file: %s',
                latest_checkpoint_file)
            checkpoint.restore(latest_checkpoint_file)

        train_iter_dataset = make_iter_dataset(train_dataset)
        total_training_steps = FLAGS.epochs * FLAGS.steps_per_epoch
        current_step = optimizer.iterations.numpy()
        while current_step < total_training_steps:
            steps = steps_to_run(current_step, FLAGS.steps_per_epoch,
                                 FLAGS.steps_per_loop)
            # Reset so the logged training metrics cover only this loop.
            for metric in train_metrics + [train_loss_metric]:
                metric.reset_states()
            # Converts steps to a Tensor to avoid tf.function retracing.
            train_steps(train_iter_dataset,
                        tf.convert_to_tensor(steps, dtype=tf.int32))
            current_step += steps
            logging.info('Step: %s', current_step)

            log = f'training step: {current_step}, '
            with train_summary_writer.as_default():
                for metric in train_metrics + [train_loss_metric]:
                    metric_value = _float_metric_value(metric)
                    tf.summary.scalar(metric.name,
                                      metric_value,
                                      step=current_step)
                    log += f'metric: {metric.name} = {metric_value}, '
                train_summary_writer.flush()
            logging.info(log)

            # Checkpoint and evaluate when an epoch boundary is reached.
            if current_step % FLAGS.steps_per_epoch == 0:
                checkpoint_path = os.path.join(
                    FLAGS.model_dir, f'model_step_{current_step}.ckpt')
                checkpoint.save(checkpoint_path)
                logging.info('Save model to %s', checkpoint_path)

                eval_iter_dataset = make_iter_dataset(eval_dataset)
                _run_evaluation(current_step, eval_iter_dataset)