Example 1
    def evaluate(self, seq_data, head_i=0, loss='poisson'):
        """ Evaluate model on SeqDataset. """
        # choose model
        if self.ensemble is None:
            model = self.models[head_i]
        else:
            model = self.ensemble

        # compile with dense metrics
        num_targets = model.output_shape[-1]

        if loss == 'bce':
            model.compile(optimizer=tf.keras.optimizers.SGD(),
                          loss=loss,
                          metrics=[
                              metrics.SeqAUC(curve='ROC', summarize=False),
                              metrics.SeqAUC(curve='PR', summarize=False)
                          ])
        else:
            model.compile(optimizer=tf.keras.optimizers.SGD(),
                          loss=loss,
                          metrics=[
                              metrics.PearsonR(num_targets, summarize=False),
                              metrics.R2(num_targets, summarize=False)
                          ])

        # evaluate
        return model.evaluate(seq_data.dataset)
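
A minimal usage sketch, assuming a constructed SeqNN-style object named seqnn and a SeqDataset named test_data (both names are illustrative, not taken from the example above):

    # hypothetical call: evaluate the first head with the default Poisson loss;
    # with summarize=False the two metrics return per-target values
    test_loss, test_pearsonr, test_r2 = seqnn.evaluate(test_data, head_i=0, loss='poisson')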
Example 2
 def compile(self, seqnn_model):
   for model in seqnn_model.models:
     num_targets = model.output_shape[-1]
     model.compile(loss=self.loss_fn,
                   optimizer=self.optimizer,
                   metrics=[metrics.PearsonR(num_targets), metrics.R2(num_targets)])
   self.compiled = True
Example 3
 def compile(self, seqnn_model):
   for model in seqnn_model.models:
     if self.loss == 'bce':
       model_metrics = [metrics.SeqAUC(curve='ROC'), metrics.SeqAUC(curve='PR')]
     else:
       num_targets = model.output_shape[-1]
       model_metrics = [metrics.PearsonR(num_targets), metrics.R2(num_targets)]
     
     model.compile(loss=self.loss_fn,
                   optimizer=self.optimizer,
                   metrics=model_metrics)
   self.compiled = True
Example 4
  def fit2(self, seqnn_model):
    if not self.compiled:
      self.compile(seqnn_model)

    assert(len(seqnn_model.models) >= self.num_datasets)

    ################################################################
    # prep

    # metrics
    train_loss, train_r, train_r2 = [], [], []
    for di in range(self.num_datasets):
      num_targets = seqnn_model.models[di].output_shape[-1]
      train_loss.append(tf.keras.metrics.Mean())
      train_r.append(metrics.PearsonR(num_targets))
      train_r2.append(metrics.R2(num_targets))

    # generate decorated train steps
    """
    train_steps = []
    for di in range(self.num_datasets):
      model = seqnn_model.models[di]

      @tf.function
      def train_step(x, y):
        with tf.GradientTape() as tape:
          pred = model(x, training=tf.constant(True))
          loss = self.loss_fn(y, pred) + sum(model.losses)
        train_loss[di](loss)
        train_r[di](y, pred)
        train_r2[di](y, pred)
        gradients = tape.gradient(loss, model.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, model.trainable_variables))

      train_steps.append(train_step)
    """
    @tf.function
    def train_step0(x, y):
      with tf.GradientTape() as tape:
        pred = seqnn_model.models[0](x, training=tf.constant(True))
        loss = self.loss_fn(y, pred) + sum(seqnn_model.models[0].losses)
      train_loss[0](loss)
      train_r[0](y, pred)
      train_r2[0](y, pred)
      gradients = tape.gradient(loss, seqnn_model.models[0].trainable_variables)
      self.optimizer.apply_gradients(zip(gradients, seqnn_model.models[0].trainable_variables))

    if self.num_datasets > 1:
      @tf.function
      def train_step1(x, y):
        with tf.GradientTape() as tape:
          pred = seqnn_model.models[1](x, training=tf.constant(True))
          loss = self.loss_fn(y, pred) + sum(seqnn_model.models[1].losses)
        train_loss[1](loss)
        train_r[1](y, pred)
        train_r2[1](y, pred)
        gradients = tape.gradient(loss, seqnn_model.models[1].trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, seqnn_model.models[1].trainable_variables))

    # improvement variables
    valid_best = [-np.inf]*self.num_datasets
    unimproved = [0]*self.num_datasets

    ################################################################
    # training loop

    for ei in range(self.train_epochs_max):
      if ei >= self.train_epochs_min and np.min(unimproved) > self.patience:
        break
      else:
        # shuffle datasets
        np.random.shuffle(self.dataset_indexes)

        # get iterators
        train_data_iters = [iter(td.dataset) for td in self.train_data]

        # train
        t0 = time.time()
        for di in self.dataset_indexes:
          x, y = next(train_data_iters[di])
          # train_steps[di](x, y)
          if di == 0:
            train_step0(x, y)
          else:
            train_step1(x, y)

        print('Epoch %d - %ds' % (ei, (time.time()-t0)))
        for di in range(self.num_datasets):
          print('  Data %d' % di, end='')
          model = seqnn_model.models[di]

          # print training accuracy
          print(' - train_loss: %.4f' % train_loss[di].result().numpy(), end='')
          print(' - train_r: %.4f' % train_r[di].result().numpy(), end='')
          print(' - train_r2: %.4f' % train_r2[di].result().numpy(), end='')

          # print validation accuracy
          valid_stats = model.evaluate(self.eval_data[di].dataset, verbose=0)
          print(' - valid_loss: %.4f' % valid_stats[0], end='')
          print(' - valid_r: %.4f' % valid_stats[1], end='')
          print(' - valid_r2: %.4f' % valid_stats[2], end='')
          early_stop_stat = valid_stats[1]

          # checkpoint
          model.save('%s/model%d_check.h5' % (self.out_dir, di))

          # check best
          if early_stop_stat > valid_best[di]:
            print(' - best!', end='')
            unimproved[di] = 0
            valid_best[di] = early_stop_stat
            model.save('%s/model%d_best.h5' % (self.out_dir, di))
          else:
            unimproved[di] += 1
          print('', flush=True)

          # reset metrics
          train_loss[di].reset_states()
          train_r[di].reset_states()
          train_r2[di].reset_states()
Example 5
  def fit_tape(self, model):
    if not self.compiled:
      self.compile(model)

    # metrics
    num_targets = model.output_shape[-1]
    train_loss = tf.keras.metrics.Mean()
    train_r = metrics.PearsonR(num_targets)
    train_r2 = metrics.R2(num_targets)
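    # note: train_r2 is instantiated here but never updated inside train_step below,
    # so only the loss and PearsonR are tracked per epoch.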
    
    @tf.function
    def train_step(x, y):
      with tf.GradientTape() as tape:
        pred = model(x, training=tf.constant(True))
        loss = self.loss_fn(y, pred) + sum(model.losses)
      train_loss(loss)
      train_r(y, pred)
      gradients = tape.gradient(loss, model.trainable_variables)
      self.optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    # improvement variables
    valid_best = -np.inf
    unimproved = 0

    # training loop
    for ei in range(self.train_epochs_max):
      if ei >= self.train_epochs_min and unimproved > self.patience:
        break
      else:
        # train
        t0 = time.time()
        train_iter = iter(self.train_data.dataset)
        for si in range(self.train_epoch_batches):
          x, y = next(train_iter)
          train_step(x, y)

        # print training accuracy
        train_loss_epoch = train_loss.result().numpy()
        train_r_epoch = train_r.result().numpy()
        print('Epoch %d - %ds - train_loss: %.4f - train_r: %.4f' % (ei, (time.time()-t0), train_loss_epoch, train_r_epoch), end='')

        # checkpoint
        model.save('%s/model_check.h5'%self.out_dir)

        # print validation accuracy
        valid_loss, valid_pr, valid_r2 = model.evaluate(self.eval_data.dataset, verbose=0)
        print(' - valid_loss: %.4f - valid_r: %.4f - valid_r2: %.4f' % (valid_loss, valid_pr, valid_r2), end='')

        # check best
        if valid_pr > valid_best:
          print(' - best!', end='')
          unimproved = 0
          valid_best = valid_pr
          model.save('%s/model_best.h5'%self.out_dir)
        else:
          unimproved += 1
        print('', flush=True)

        # reset metrics
        train_loss.reset_states()
        train_r.reset_states()
Example 6
    def fit_tape(self, seqnn_model):
        if not self.compiled:
            self.compile(seqnn_model)
        model = seqnn_model.model
        print(len(model.trainable_variables))
        self.test_data = list(self.eval_data[0].dataset.as_numpy_iterator())[0]
        num_infractions = np.zeros(164)
        # metrics
        num_targets = model.output_shape[-1]
        train_loss = tf.keras.metrics.Mean()
        train_r = metrics.PearsonR(num_targets)
        train_r2 = metrics.R2(num_targets)
        self.better_optimizer = PCGrad(self.optimizer)
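        # PCGrad wraps the optimizer for multi-task "gradient surgery", but it is only
        # referenced in the commented-out experiments below, not in the active training path.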
        jacobian_variable = tf.constant([1, 2, 3, 4, 5])
        shutdownCount = 0

        #@tf.function
        #def calculate_multitask_gradients(task_loss):
        #    return tf.vectorized_map(lambda x: tf.gradients(x, model.trainable_variables), task_loss)

        def better_train_step(x, y):
            pred = model(x, training=tf.constant(True))
            loss = self.loss_fn(y, pred) + sum(model.losses)
            flattened_pred = tf.unstack(pred, num=164, axis=2)
            task_pred = []
            for elt in flattened_pred:
                task_pred.append(elt)
            flattened_actual = tf.unstack(y, num=164, axis=2)
            task_actual = []
            for elt in flattened_actual:
                task_actual.append(elt)
            task_loss = []
            for i in range(164):
                task_loss.append(
                    self.loss_fn(task_actual[i], task_pred[i]) +
                    (sum(model.losses) / 164))
            task_loss = tf.stack(task_loss)
            #grads_task = tf.vectorized_map(lambda x: tf.gradients(x, model.trainable_variables), task_loss)
            #grads_task = calculate_multitask_gradients(task_loss)
            #self.optimizer.apply_gradients(zip(grads_task[3], model.trainable_variables))

        #@tf.function
        #def train_step(x, y):
        #    with tf.GradientTape(persistent=True) as tape:
        #        pred = model(x, training=tf.constant(True))
        #        loss = self.loss_fn(y, pred) + sum(model.losses)
        #        flattened_pred = tf.unstack(pred, num=164, axis=2)
        #        task_pred = []
        #        for elt in flattened_pred:
        #            task_pred.append(elt)
        #        flattened_actual = tf.unstack(y, num=164, axis=2)
        #        task_actual = []
        #        for elt in flattened_actual:
        #            task_actual.append(elt)
        #        task_loss = []
        #        for i in range(164):
        #            task_loss.append(self.loss_fn(task_actual[i], task_pred[i]) + (sum(model.losses) / 164))
        #        alternate_task_loss = tf.vectorized_map(lambda x: tf.square(x), y - pred)
        #        alternate_task_loss = (y - pred)*(y - pred)
        #        print(task_loss)
        #        task_loss = tf.stack(task_loss)
        #    grads_task = tf.vectorized_map(lambda x: tape.gradient(x, model.trainable_variables), task_loss)
        #    quantity = task_loss[0] # + task_loss[1]
        #    task_loss = tf.stack(task_loss)
        #    gradients = tape.gradient(quantity, model.trainable_variables) #this works for some reason
        #    print(tf.shape(gradients[0]))
        #    self.optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        #    print(alternate_task_loss)
        #    jacobian = tape.jacobian(task_loss, model.trainable_variables, experimental_use_pfor=True)
        #    return jacobian
        #    i = tf.constant(1, dtype=tf.int32)
        #    while tf.less(i, 29):

        #    task_0_grads = []
        #    for grad in jacobian:
        #        task_0_grads.append(tf.unstack(grad)[0])
        #    grads_task = tf.unstack(more_gradients, num=164, axis=1)
        #    print(tf.shape(more_gradients[0]))
        #    task_loss = tf.stack(task_loss)
        #    tape.watch(task_loss)
        #    grads_task = tf.vectorized_map(lambda x: tape.gradient(x, model.trainable_variables), alternate_task_loss)
        #    print(grads_task)
        #    self.optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        #    return more_gradients
        #    grads_and_vars = self.better_optimizer.compute_gradients(task_loss, model.trainable_variables)
        #    self.better_optimizer.apply_gradients(grads_and_vars)

        # improvement variables
        valid_best = -np.inf
        unimproved = 0
        vectorSet = []
        myArr = np.ones(164, dtype=np.float32)
        #myArr[0] = 0
        print(myArr)
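        # train_step_0 below computes one loss per output task (164 cell types), scales each
        # by 1/(2 * w_i^2) before summing (a weighting scheme reminiscent of homoscedastic
        # uncertainty weighting), applies gradients of the weighted sum, and returns the
        # unweighted per-task losses for monitoring.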

        @tf.function
        def train_step_0(x, y, weights):
            with tf.GradientTape(persistent=True) as tape:
                pred = model(x, training=tf.constant(True))
                flattened_pred = tf.unstack(pred, num=164, axis=2)
                task_pred = []
                for elt in flattened_pred:
                    task_pred.append(elt)
                flattened_actual = tf.unstack(y, num=164, axis=2)
                task_actual = []
                for elt in flattened_actual:
                    task_actual.append(elt)
                task_loss = []
                unweighted_task_loss = []
                for i in range(164):
                    task_loss.append(
                        (self.loss_fn(task_actual[i], task_pred[i]) +
                         (sum(model.losses) / 164)) *
                        (1 / (2 * weights[i] * weights[i])))
                    unweighted_task_loss.append(
                        self.loss_fn(task_actual[i], task_pred[i]) +
                        (sum(model.losses) / 164))
                task_loss = tf.stack(task_loss)
                #print(task_loss)
                #print(weights)
                #weighted_sum = tf.tensordot(task_loss, weights)
                weighted_sum = tf.math.reduce_sum(task_loss)
                #weight_gradients = []
                #for i in range(164):
                #    weight_gradients.append(unweighted_task_loss[i] * ( -1 / (weights[i] * weights[i] * weights[i])) + tf.math.log(weights[i]))
            gradients = tape.gradient(weighted_sum, model.trainable_variables)
            #print(gradients)
            #print(None in gradients)
            self.optimizer.apply_gradients(
                zip(gradients, model.trainable_variables))
            return unweighted_task_loss

        # training loop
        for ei in range(self.train_epochs_max):
            if ei >= self.train_epochs_min and unimproved > self.patience:
                break
            else:
                # train
                print("epoch started")
                t0 = time.time()
                train_iter = iter(self.train_data[0].dataset)
                print("num iterations")
                print(self.train_epoch_batches[0])
                for si in range(self.train_epoch_batches[0]):
                    x, y = next(train_iter)
                    #x = x.numpy()
                    #y = y.numpy()
                    #print(type(x))
                    #print(type(y))
                    #myArr = np.ones(164, dtype=np.float32)
                    #weights = tf.convert_to_tensor(myArr)
                    gradient_weights = train_step_0(x, y, myArr)
                    #model.train_on_batch(x, y)
                    #print("train step completed")
                    #print(si)
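                    # every 250 batches: compute per-task MSE on one held-out batch, log it,
                    # and count an "infraction" for any task whose MSE is worse than it was
                    # parameter_value checks earlier; tasks with more than 5 infractions are
                    # zero-weighted, and once 82 of the 164 tasks have been shut down the
                    # weights are reset so that only the most recently shut-down task remains.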
                    if si % 250 == 0:
                        print("250 train steps completed")
                        print("weights")
                        print(myArr)
                        test_x = self.test_data[0]
                        test_y = self.test_data[1]
                        predictions = model.predict(test_x)
                        squared_differences = tf.square(predictions - test_y)
                        sumVector = squared_differences[0]
                        for i in range(1, len(squared_differences)):
                            sumVector = sumVector + squared_differences[i]
                        sumVector = sumVector / len(squared_differences)
                        print(
                            "End of epoch validation MSE for each of the 164 cell types: "
                        )
                        vectorString = ""
                        for entry in sumVector:
                            vectorString = vectorString + str(entry) + " "
                        print(vectorString)
                        #print(sumVector)
                        print(si // 250)
                        outputFile = open("164_cell_validation_error.txt", "a")
                        outputFile.write(vectorString + "\n")
                        outputFile.close()
                        outputFile = open("164_cell_weights.txt", "a")
                        outputFile.write(str(myArr) + "\n")
                        outputFile.close()
                        vectorSet.append(sumVector)
                        parameter_value = 200
                        if len(vectorSet) > parameter_value:
                            vectorIterationIndex = len(vectorSet) - 1
                            checkIndex = vectorIterationIndex - parameter_value
                            for i in range(164):
                                num1 = vectorSet[vectorIterationIndex].numpy(
                                )[0][i]
                                #num1 = int(vectorSet[vectorIterationIndex][i])
                                num2 = vectorSet[checkIndex].numpy()[0][i]
                                if num1 > num2:
                                    num_infractions[i] += 1
                            for i in range(164):
                                num = int(num_infractions[i])
                                if num > 5:
                                    if myArr[i] != 0:
                                        shutdownCount += 1
                                        if shutdownCount == 82:
                                            outputFile = open(
                                                "task_being_analyzed.txt", "a")
                                            outputFile.write(str(i) + "\n")
                                            outputFile.close()
                                            myArr = np.zeros(164,
                                                             dtype=np.float32)
                                            myArr[i] = 1
                                        if shutdownCount < 82:
                                            myArr[i] = 0
                                        outputFile = open(
                                            "early_stopping_stats.txt", "a")
                                        outputFile.write(
                                            str(i) + ": " +
                                            str(len(vectorSet)) + "\n")
                                        outputFile.close()

                # print training accuracy
                outputFile = open("epoch_stats.txt", "a")
                print("training for epoch completed")
                train_loss_epoch = train_loss.result().numpy()
                train_r_epoch = train_r.result().numpy()
                print('Epoch %d - %ds - train_loss: %.4f - train_r: %.4f' %
                      (ei,
                       (time.time() - t0), train_loss_epoch, train_r_epoch),
                      end='')
                outputFile.write("training for epoch completed\n")
                outputString = 'Epoch %d - %ds - train_loss: %.4f - train_r: %.4f' % (
                    ei, (time.time() - t0), train_loss_epoch, train_r_epoch)
                outputFile.write(outputString + "\n")

                # checkpoint
                seqnn_model.save('%s/model_check.h5' % self.out_dir)

                # print validation accuracy
                valid_loss, valid_pr, valid_r2 = model.evaluate(
                    self.eval_data[0].dataset, verbose=0)
                print(' - valid_loss: %.4f - valid_r: %.4f - valid_r2: %.4f' %
                      (valid_loss, valid_pr, valid_r2),
                      end='')
                outputString = ' - valid_loss: %.4f - valid_r: %.4f - valid_r2: %.4f' % (
                    valid_loss, valid_pr, valid_r2)
                outputFile.write(outputString + "\n")
                outputFile.close()

                # check best
                if valid_pr > valid_best:
                    print(' - best!', end='')
                    unimproved = 0
                    valid_best = valid_pr
                    seqnn_model.save('%s/model_best.h5' % self.out_dir)
                else:
                    unimproved += 1
                print('', flush=True)

                # reset metrics
                train_loss.reset_states()
                train_r.reset_states()
Example 7
  def fit_tape(self, seqnn_model):
    if not self.compiled:
      self.compile(seqnn_model)
    model = seqnn_model.model
    
    # metrics
    num_targets = model.output_shape[-1]
    train_loss = tf.keras.metrics.Mean(name='train_loss')
    train_r = metrics.PearsonR(num_targets, name='train_r')
    train_r2 = metrics.R2(num_targets, name='train_r2')
    valid_loss = tf.keras.metrics.Mean(name='valid_loss')
    valid_r = metrics.PearsonR(num_targets, name='valid_r')
    valid_r2 = metrics.R2(num_targets, name='valid_r2')
    print("strategy", self.strategy)
    if self.strategy is None:
      @tf.function
      def train_step(x, y):
        with tf.GradientTape() as tape:
          pred = model(x, training=True)
          loss = self.loss_fn(y, pred) + sum(model.losses)
        train_loss(loss)
        train_r(y, pred)
        train_r2(y, pred)
        gradients = tape.gradient(loss, model.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, model.trainable_variables))

      @tf.function
      def eval_step(x, y):
        pred = model(x, training=False)
        print("pred_valid")
        print(pred)
        loss = self.loss_fn(y, pred) + sum(model.losses)
        valid_loss(loss)
        valid_r(y, pred)
        valid_r2(y, pred)

    else:
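      # distributed path: the per-replica loss is summed over the batch and divided by the
      # global batch size, and the regularization losses are divided by the number of GPUs,
      # so that summing across replicas reproduces the single-device loss scale.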
      def train_step(x, y):
        with tf.GradientTape() as tape:
          pred = model(x, training=True)
          loss_batch_len = self.loss_fn(y, pred)
          loss_batch = tf.reduce_mean(loss_batch_len, axis=-1)
          loss = tf.reduce_sum(loss_batch) / self.batch_size
          loss += sum(model.losses) / self.num_gpu
        train_r(y, pred)
        train_r2(y, pred)
        gradients = tape.gradient(loss, model.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        return loss

      @tf.function
      def train_step_distr(xd, yd):
        replica_losses = self.strategy.run(train_step, args=(xd, yd))
        loss = self.strategy.reduce(tf.distribute.ReduceOp.SUM,
                                    replica_losses, axis=None)
        train_loss(loss)


      def eval_step(x, y):
        pred = model(x, training=False)
        loss = self.loss_fn(y, pred) + sum(model.losses)
        valid_loss(loss)
        valid_r(y, pred)
        valid_r2(y, pred)

      @tf.function
      def eval_step_distr(xd, yd):
        return self.strategy.run(eval_step, args=(xd, yd))


    # improvement variables
    valid_best = -np.inf
    unimproved = 0

    # training loop
    model_stat_file = open(self.out_dir + "/model_stat.txt", "w")
    model_stat_file.write("epoch"+"\t"+ "train_loss_epoch"+"\t"+"train_r_epoch" + "\t"+"train_r2_epoch" + "\t"+
                "valid_loss_epoch"+"\t"+"valid_r_epoch" + "\t"+"valid_r2_epoch"+"\n")
    for ei in range(self.train_epochs_max):
      if ei >= self.train_epochs_min and unimproved > self.patience:
        break
      else:
        # train
        t0 = time.time()
        train_iter = iter(self.train_data[0].dataset)
        for si in range(self.train_epoch_batches[0]):
          x, y = next(train_iter)
          if self.strategy is not None:
            train_step_distr(x, y)
          else:
            train_step(x, y)

        # evaluate
        # eval_iter = iter(self.eval_data[0].dataset)
        # for si in range(self.eval_epoch_batches[0]):
        #   x, y = next(eval_iter)
        for x, y in self.eval_data[0].dataset:
          if self.strategy is not None:
            eval_step_distr(x, y)
          else:
            eval_step(x, y)

        # print training accuracy
        train_loss_epoch = train_loss.result().numpy()
        train_r_epoch = train_r.result().numpy()
        train_r2_epoch = train_r2.result().numpy()
        model_stat_file.write(str(ei)+"\t"+str(train_loss_epoch)+"\t"+str(train_r_epoch)+"\t"+
                              str(train_r2_epoch)+"\t")
        print('Epoch %d - %ds - train_loss: %.4f - train_r: %.4f - train_r2: %.4f' % \
          (ei, (time.time()-t0), train_loss_epoch, train_r_epoch, train_r2_epoch), end='')

        # print validation accuracy
        # valid_loss, valid_pr, valid_r2 = model.evaluate(self.eval_data[0].dataset, verbose=0)
        valid_loss_epoch = valid_loss.result().numpy()
        valid_r_epoch = valid_r.result().numpy()
        valid_r2_epoch = valid_r2.result().numpy()
        model_stat_file.write(str(valid_loss_epoch) + "\t" + str(valid_r_epoch) + "\t" +
                              str(valid_r2_epoch) + "\n")
        print(' - valid_loss: %.4f - valid_r: %.4f - valid_r2: %.4f' % \
          (valid_loss_epoch, valid_r_epoch, valid_r2_epoch), end='')

        # checkpoint
        seqnn_model.save('%s/model_check.h5'%self.out_dir)
        if ei%40 == 0:
          draw_pred_in_train(model=seqnn_model, out_dir=self.out_dir, epoch=ei,
                             loss=valid_loss_epoch, pearson_r=valid_r_epoch)
          if ei%400 == 0:
            seqnn_model.save(self.out_dir+"/model_check_epoch"+str(ei)+".h5")

        # check best
        if valid_r_epoch > valid_best:
          print(' - best!', end='')
          unimproved = 0
          valid_best = valid_r_epoch
          seqnn_model.save('%s/model_best.h5'%self.out_dir)
        else:
          unimproved += 1
        print('', flush=True)

        # reset metrics
        train_loss.reset_states()
        train_r.reset_states()
        train_r2.reset_states()
        valid_loss.reset_states()
        valid_r.reset_states()
        valid_r2.reset_states()
    model_stat_file.close()
Example 8
    def fit_tape(self, seqnn_model):
        if not self.compiled:
            self.compile(seqnn_model)
        model = seqnn_model.model

        # metrics
        num_targets = model.output_shape[-1]
        train_loss = tf.keras.metrics.Mean(name='train_loss')
        train_r = metrics.PearsonR(num_targets, name='train_r')
        train_r2 = metrics.R2(num_targets, name='train_r2')
        valid_loss = tf.keras.metrics.Mean(name='valid_loss')
        valid_r = metrics.PearsonR(num_targets, name='valid_r')
        valid_r2 = metrics.R2(num_targets, name='valid_r2')

        if self.strategy is None:

            @tf.function
            def train_step(x, y):
                with tf.GradientTape() as tape:
                    pred = model(x, training=True)
                    loss = self.loss_fn(y, pred) + sum(model.losses)
                train_loss(loss)
                train_r(y, pred)
                train_r2(y, pred)
                gradients = tape.gradient(loss, model.trainable_variables)
                self.optimizer.apply_gradients(
                    zip(gradients, model.trainable_variables))

            @tf.function
            def eval_step(x, y):
                pred = model(x, training=False)
                loss = self.loss_fn(y, pred) + sum(model.losses)
                valid_loss(loss)
                valid_r(y, pred)
                valid_r2(y, pred)

        else:

            def train_step(x, y):
                with tf.GradientTape() as tape:
                    pred = model(x, training=True)
                    loss_batch_len = self.loss_fn(y, pred)
                    loss_batch = tf.reduce_mean(loss_batch_len, axis=-1)
                    loss = tf.reduce_sum(loss_batch) / self.batch_size
                    loss += sum(model.losses) / self.num_gpu
                train_r(y, pred)
                train_r2(y, pred)
                gradients = tape.gradient(loss, model.trainable_variables)
                self.optimizer.apply_gradients(
                    zip(gradients, model.trainable_variables))
                return loss

            @tf.function
            def train_step_distr(xd, yd):
                replica_losses = self.strategy.run(train_step, args=(xd, yd))
                loss = self.strategy.reduce(tf.distribute.ReduceOp.SUM,
                                            replica_losses,
                                            axis=None)
                train_loss(loss)

            def eval_step(x, y):
                pred = model(x, training=False)
                loss = self.loss_fn(y, pred) + sum(model.losses)
                valid_loss(loss)
                valid_r(y, pred)
                valid_r2(y, pred)

            @tf.function
            def eval_step_distr(xd, yd):
                return self.strategy.run(eval_step, args=(xd, yd))

        # checkpoint manager
        ckpt = tf.train.Checkpoint(model=seqnn_model.model,
                                   optimizer=self.optimizer)
        manager = tf.train.CheckpointManager(ckpt, self.out_dir, max_to_keep=1)
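        # restore the most recent checkpoint, if any, so an interrupted run resumes where it
        # left off; since manager.save() is called once per epoch below, the numeric suffix
        # of the checkpoint name is used as the starting epoch.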
        ckpt.restore(manager.latest_checkpoint)
        if manager.latest_checkpoint:
            ckpt_end = 5 + manager.latest_checkpoint.find('ckpt-')
            epoch_start = int(manager.latest_checkpoint[ckpt_end:])
            print('Checkpoint restored at epoch %d, optimizer iteration %d.' % \
              (epoch_start, self.optimizer.iterations))

        else:
            print('No checkpoints found.')
            epoch_start = 0

        # improvement variables
        valid_best = -np.inf
        unimproved = 0

        # training loop
        for ei in range(epoch_start, self.train_epochs_max):
            if ei >= self.train_epochs_min and unimproved > self.patience:
                break
            else:
                # train
                t0 = time.time()
                train_iter = iter(self.train_data[0].dataset)
                for si in range(self.train_epoch_batches[0]):
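                    # safe_next is assumed to be a module-level helper (not shown here)
                    # that wraps next() on the dataset iterator.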
                    x, y = safe_next(train_iter)
                    if self.strategy is not None:
                        train_step_distr(x, y)
                    else:
                        train_step(x, y)

                # evaluate
                for x, y in self.eval_data[0].dataset:
                    if self.strategy is not None:
                        eval_step_distr(x, y)
                    else:
                        eval_step(x, y)

                # print training accuracy
                train_loss_epoch = train_loss.result().numpy()
                train_r_epoch = train_r.result().numpy()
                train_r2_epoch = train_r2.result().numpy()
                print('Epoch %d - %ds - train_loss: %.4f - train_r: %.4f - train_r2: %.4f' % \
                  (ei, (time.time()-t0), train_loss_epoch, train_r_epoch, train_r2_epoch), end='')

                # print validation accuracy
                valid_loss_epoch = valid_loss.result().numpy()
                valid_r_epoch = valid_r.result().numpy()
                valid_r2_epoch = valid_r2.result().numpy()
                print(' - valid_loss: %.4f - valid_r: %.4f - valid_r2: %.4f' % \
                  (valid_loss_epoch, valid_r_epoch, valid_r2_epoch), end='')

                # checkpoint
                manager.save()
                seqnn_model.save('%s/model_check.h5' % self.out_dir)

                # check best
                if valid_r_epoch > valid_best:
                    print(' - best!', end='')
                    unimproved = 0
                    valid_best = valid_r_epoch
                    seqnn_model.save('%s/model_best.h5' % self.out_dir)
                else:
                    unimproved += 1
                print('', flush=True)

                # reset metrics
                train_loss.reset_states()
                train_r.reset_states()
                train_r2.reset_states()
                valid_loss.reset_states()
                valid_r.reset_states()
                valid_r2.reset_states()