def __init__(self, model, data, config):
    """Populate trainer state on ``self``: models, data, loss, optimizer, metrics, step counter.

    Args:
        model: Mapping providing ``'train'`` and ``'eval'`` models.
        data: Dataset mapping, stored as-is on ``self.data``.
        config: Settings object; ``learning_rate``, ``train_writer`` and
            ``test_writer`` are read here, and it is also passed to
            ``ClassificationLoss``.

    NOTE(review): this module-level ``__init__`` is a verbatim copy of
    ``ClassificationTrainer.__init__`` and looks like a stray duplicate —
    confirm nothing references it before removing it.
    """
    self.model_train, self.model_test = model['train'], model['eval']
    self.data, self.config = data, config

    self.loss_fn = ClassificationLoss(config)
    self.optimizer = Adam(learning_rate=self.config.learning_rate)

    # Train-side metrics share one writer; the test metric uses its own.
    train_writer = config.train_writer
    self.metric_train = ClassificationMetrics(train_writer, name='train')
    self.metric_train_eval = ClassificationMetrics(train_writer, name='train_eval')
    self.metric_test = ClassificationMetrics(config.test_writer, name='test')

    # Non-trainable int64 counter starting at zero.
    self.global_step = tf.Variable(0, trainable=False, dtype=tf.int64)
class ClassificationTrainer:
    """Train a classifier and evaluate a moving-average copy of its weights.

    ``model['train']`` is optimized with Adam on ``data['train']``. Its weights
    are blended into ``model['eval']`` as an exponential moving average with
    coefficient ``config.moving_average_coefficient``. The averaged model is
    evaluated on ``data['train_eval']`` and ``data['test']``.
    """

    def __init__(self, model, data, config):
        """Store models, data and config, and build loss, optimizer, metrics and step counter.

        Args:
            model: Mapping with ``'train'`` (optimized) and ``'eval'``
                (moving-average) models.
            data: Mapping with ``'train'``, ``'train_eval'`` and ``'test'``
                iterables of ``(samples, targets)`` batches.
            config: Object providing ``learning_rate``, ``clip_grad_norm``,
                ``moving_average_coefficient``, ``num_epochs``, ``eval_freq``,
                ``train_writer`` and ``test_writer``.
        """
        self.model_train = model['train']
        self.model_test = model['eval']
        self.data = data
        self.config = config
        self.loss_fn = ClassificationLoss(config)
        self.optimizer = Adam(learning_rate=self.config.learning_rate)
        self.metric_train = ClassificationMetrics(config.train_writer, name='train')
        self.metric_train_eval = ClassificationMetrics(config.train_writer, name='train_eval')
        self.metric_test = ClassificationMetrics(config.test_writer, name='test')
        # Incremented once per gradient computation; used as the summary step.
        self.global_step = tf.Variable(0, trainable=False, dtype=tf.int64)

    @tf.function
    def compute_grads(self, samples, targets):
        """Run a training-mode forward pass and return clipped gradients.

        Logs the pre-clipping global gradient norm as ``grad_norm`` to
        ``config.train_writer`` and increments ``global_step``.

        Returns:
            ``(gradients, predictions)``; ``gradients`` is aligned with
            ``model_train.trainable_weights``.
        """
        with tf.GradientTape() as tape:
            predictions = self.model_train(samples, training=True)
            # Loss between the provided targets and the model predictions.
            loss = self.loss_fn(targets, predictions)
        gradients = tape.gradient(loss, self.model_train.trainable_weights)
        # clip_by_global_norm also returns the global norm measured *before* clipping.
        gradients, grad_norm = tf.clip_by_global_norm(gradients, self.config.clip_grad_norm)
        with self.config.train_writer.as_default():
            tf.summary.scalar("grad_norm", grad_norm, self.global_step)
        self.global_step.assign_add(1)
        return gradients, predictions

    @tf.function
    def apply_grads(self, gradients):
        """Apply ``gradients`` to ``model_train.trainable_weights`` with the Adam optimizer."""
        self.optimizer.apply_gradients(zip(gradients, self.model_train.trainable_weights))

    def sync_eval_model(self):
        """Blend the training weights into the eval model as an exponential moving average.

        For each weight pair: ``ema = ema * alpha + w * (1 - alpha)`` with
        ``alpha = config.moving_average_coefficient``. The variables are
        updated in place. This avoids the NumPy copies of every weight that
        ``get_weights``/``set_weights`` would create on each call.
        Assumes both models have the same architecture, so their ``.weights``
        line up pairwise.
        """
        alpha = self.config.moving_average_coefficient
        for ma_var, var in zip(self.model_test.weights, self.model_train.weights):
            ma_var.assign(ma_var * alpha + var * (1 - alpha))

    @tf.function
    def train_step(self, samples, targets):
        """Compute clipped gradients, apply them, and return the training predictions."""
        gradients, predictions = self.compute_grads(samples, targets)
        self.apply_grads(gradients)
        return predictions

    @tf.function
    def eval_step(self, samples):
        """Return predictions of the moving-average model in inference mode."""
        predictions = self.model_test(samples, training=False)
        return predictions

    def train_epoch(self, epoch):
        """Run one pass over ``data['train']``, then print and log the train metrics."""
        self.metric_train.reset_states()
        self.model_train.reset_states()
        for samples, targets in self.data['train']:
            predictions = self.train_step(samples, targets)
            self.metric_train.update_state(targets, predictions)
            # Update the moving-average model after every optimizer step.
            # NOTE(review): placement inferred (source indentation was lost) — confirm.
            self.sync_eval_model()
        self.metric_train.print(epoch)
        self.metric_train.log_metrics(epoch)

    def _evaluate(self, metric, dataset, epoch):
        """Evaluate the moving-average model on ``dataset`` and report ``metric`` for ``epoch``."""
        metric.reset_states()
        self.model_test.reset_states()
        for samples, targets in dataset:
            predictions = self.eval_step(samples)
            metric.update_state(targets, predictions)
        metric.print(epoch)
        metric.log_metrics(epoch)

    def evaluate_train(self, epoch):
        """Evaluate the moving-average model on ``data['train_eval']``."""
        self._evaluate(self.metric_train_eval, self.data['train_eval'], epoch)

    def evaluate_test(self, epoch):
        """Evaluate the moving-average model on ``data['test']``."""
        self._evaluate(self.metric_test, self.data['test'], epoch)

    def train(self):
        """Train for ``config.num_epochs`` epochs, evaluating every ``config.eval_freq`` epochs.

        Evaluation also runs after the final epoch, so the fully trained model
        is always reported. Without this, the last evaluation is skipped
        whenever ``num_epochs - 1`` is not a multiple of ``eval_freq``.
        """
        last_epoch = self.config.num_epochs - 1
        for epoch in range(self.config.num_epochs):
            self.train_epoch(epoch)
            if epoch % self.config.eval_freq == 0 or epoch == last_epoch:
                self.evaluate_train(epoch)
                self.evaluate_test(epoch)