Example #1
    def run_epoch(self, train_set, dev_set, epoch, lr_for_epoch):
        """Performs one epoch on whole trainning set
        Args:
            train_set: tuple of words, tags for training
            dev_set: tuple of words, tags for validation
            epoch : current epoch number
        """

        batch_size = self.config.batch_size
        nbatches = (len(train_set[0]) + batch_size - 1) // batch_size
        prog = Progbar(target=nbatches)

        x_train = train_set[0]
        y_train = train_set[1]
        for i, (sentences_batch, labels_batch) in enumerate(
                get_minibatch((x_train, y_train), batch_size)):
            feed_batch = self.create_feed_dict(
                sentences_batch,
                labels_batch,
                keep_probability=self.config.dropout_rate,
                lr=lr_for_epoch)
            _, train_loss = self.sess.run([self.train_op, self.loss],
                                          feed_dict=feed_batch)

            prog.update(i + 1, [("train loss", train_loss)])

        metrics = self.evaluate(dev_set)

        msg = " - ".join(
            ["{} {:04.2f}".format(k, v) for k, v in metrics.items()])
        self.logger.info(msg)

        return metrics["f1"]
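
Examples #1 and #9 don't show the get_minibatch they import; a minimal sketch consistent with how they call it (a generator over a (words, tags) tuple that yields slices of at most batch_size items, with no shuffling assumed) could look like this:

def get_minibatch(data, batch_size):
    """Yield successive (x_batch, y_batch) slices of at most batch_size items."""
    x, y = data
    for start in range(0, len(x), batch_size):
        yield x[start:start + batch_size], y[start:start + batch_size]
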
Example #2
def learn(session, dataset, main_dqn, target_dqn, batch_size, gamma):
    """
    Args:
        session: A tensorflow sesson object
        replay_memory: A ReplayMemory object
        main_dqn: A DQN object
        target_dqn: A DQN object
        batch_size: Integer, Batch size
        gamma: Float, discount factor for the Bellman equation
    Returns:
        loss: The loss of the minibatch, for tensorboard
    Draws a minibatch from the replay memory, calculates the
    target Q-value that the prediction Q-value is regressed to.
    Then a parameter update is performed on the main DQN.
    """
    # Draw a minibatch of expert transitions from the dataset
    states, actions, _, _, _, weights = utils.get_minibatch(dataset)
    loss, _ = session.run(
        [main_dqn.behavior_cloning_loss, main_dqn.update],
        feed_dict={
            main_dqn.input: states,
            main_dqn.expert_action: actions,
            main_dqn.expert_weights: weights
        })
    return loss
Example #3
    def test(self, test_set, test_batch_size, **kwargs):
        """ Run forward process over the given whole test set.

    Parameters
    ----------
    test_set: dict
      dict of lists, including SGAs, drug types, DEGs, patient barcodes
    test_batch_size: int

    Returns
    -------

    """

        tgts, prds, msks, tmr, amtr = [], [], [], [], []

        for iter_test in range(0, len(self.rng_test), test_batch_size):
            batch_set = get_minibatch(test_set,
                                      self.rng_test,
                                      iter_test,
                                      test_batch_size,
                                      batch_type="test",
                                      use_cuda=self.use_cuda)

            hid_drg = self.forward(batch_set)

            batch_prds = torch.sigmoid(hid_drg)
            batch_tgts = batch_set["tgt"]
            batch_msks = batch_set["msk"]

            if self.use_attention:
                # (batch_size, num_drg, num_omc)
                if self.use_cuda:
                    amtr.append(self.encoder.Amtr.data.cpu().numpy())
                else:
                    amtr.append(self.encoder.Amtr.data.numpy())

            if self.use_cuda:
                tgts.append(batch_tgts.data.cpu().numpy())
                msks.append(batch_msks.data.cpu().numpy())
                prds.append(batch_prds.data.cpu().numpy())
            else:
                tgts.append(batch_tgts.data.numpy())
                msks.append(batch_msks.data.numpy())
                prds.append(batch_prds.data.numpy())
            tmr = tmr + batch_set["tmr"]

        tgts = np.concatenate(tgts, axis=0)
        msks = np.concatenate(msks, axis=0)
        prds = np.concatenate(prds, axis=0)
        if self.use_attention:
            # (sample_size, num_drg, num_omc)
            amtr = np.concatenate(amtr, axis=0)

        return tgts, msks, prds, tmr, amtr
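
Examples #3, #12 and #13 share a six-argument get_minibatch(dataset, rng, start, batch_size, batch_type, use_cuda) whose implementation is not shown. A plausible sketch, assuming dataset is a dict of equal-length lists in which "tmr" holds barcodes and the other fields are numeric, and rng is a (shufflable) list of indices:

import torch

def get_minibatch(dataset, rng, start, batch_size, batch_type="train", use_cuda=False):
    if batch_type == "train":
        # wrap around so training can run for more iterations than one epoch
        idx = [rng[(start + k) % len(rng)] for k in range(batch_size)]
    else:
        idx = rng[start:start + batch_size]  # last test batch may be smaller
    batch = {}
    for key, values in dataset.items():
        if key == "tmr":  # patient/tumor barcodes stay a plain list
            batch[key] = [values[i] for i in idx]
        else:
            t = torch.as_tensor([values[i] for i in idx])
            batch[key] = t.cuda() if use_cuda else t
    return batch
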
Example #4
    def test(self, test_set, test_batch_size, **kwargs):
        """ Run forward process over the given whole test set.

    Parameters
    ----------
    test_set: dict
      dict of lists, including SGAs, cancer types, DEGs, patient barcodes
    test_batch_size: int

    Returns
    -------
    labels: 2D array of 0/1
      ground truth of gene expression
    preds: 2D array of float in [0, 1]
      predicted gene expression
    hid_tmr: 2D array of float
      hidden layer of MLP
    emb_tmr: 2D array of float
      tumor embedding
    emb_sga: 2D array of float
      stratified tumor embedding
    attn_wt: 2D array of float
      attention weights of SGAs
    tmr: list of str
      barcodes of patients/tumors

    """

        labels, preds, hid_tmr, emb_tmr, emb_sga, attn_wt, tmr = [], [], [], [], [], [], []

        for iter_test in range(0, len(test_set["can"]), test_batch_size):
            batch_set = get_minibatch(test_set,
                                      iter_test,
                                      test_batch_size,
                                      batch_type="test")
            batch_preds, batch_hid_tmr, batch_emb_tmr, batch_emb_sga, batch_attn_wt = self.forward(
                batch_set["sga"], batch_set["can"])
            batch_labels = batch_set["deg"]

            labels.append(batch_labels.data.numpy())
            preds.append(batch_preds.data.numpy())
            hid_tmr.append(batch_hid_tmr.data.numpy())
            emb_tmr.append(batch_emb_tmr.data.numpy())
            emb_sga.append(batch_emb_sga.data.numpy())
            attn_wt.append(batch_attn_wt.data.numpy())
            tmr = tmr + batch_set["tmr"]

        labels = np.concatenate(labels, axis=0)
        preds = np.concatenate(preds, axis=0)
        hid_tmr = np.concatenate(hid_tmr, axis=0)
        emb_tmr = np.concatenate(emb_tmr, axis=0)
        emb_sga = np.concatenate(emb_sga, axis=0)
        attn_wt = np.concatenate(attn_wt, axis=0)

        return labels, preds, hid_tmr, emb_tmr, emb_sga, attn_wt, tmr
Example #5
    def train(self,
              train_set,
              test_set,
              batch_size=None,
              test_batch_size=None,
              max_iter=None,
              max_fscore=None,
              test_inc_size=None,
              **kwargs):
        """ Train the model until max_iter or max_fscore reached.

    Parameters
    ----------
    train_set: dict
      dict of lists, including SGAs, cancer types, DEGs, patient barcodes
    test_set: dict
    batch_size: int
    test_batch_size: int
    max_iter: int
      max number of iterations that the training will run
    max_fscore: float
      test F1 score at which training stops early
    test_inc_size: int
      interval (in iterations) between test/evaluation runs

    """

        for iter_train in range(0, max_iter + 1, batch_size):
            batch_set = get_minibatch(train_set,
                                      iter_train,
                                      batch_size,
                                      batch_type="train")
            preds, _, _, _, _ = self.forward(batch_set["sga"],
                                             batch_set["can"])
            labels = batch_set["deg"]

            self.optimizer.zero_grad()
            loss = -torch.log(self.epsilon + 1 -
                              torch.abs(preds - labels)).mean()
            loss.backward()
            self.optimizer.step()

            if test_inc_size and (iter_train % test_inc_size == 0):
                labels, preds, _, _, _, _, _ = self.test(
                    test_set, test_batch_size)
                precision, recall, f1score, accuracy = evaluate(
                    labels, preds, epsilon=self.epsilon)
                print("[%d,%d], f1_score: %.3f, acc: %.3f" %
                      (iter_train // len(train_set["can"]),
                       iter_train % len(train_set["can"]), f1score, accuracy))

                if f1score >= max_fscore:
                    break

        self.save_model(os.path.join(self.output_dir, "trained_model.pth"))
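
The four-argument get_minibatch(dataset, start, batch_size, batch_type) used in Examples #4 and #5 is likewise not shown. A hypothetical sketch, assuming the dataset is a dict of equal-length lists keyed "sga", "can", "deg" and "tmr", with wrap-around indexing so that train() can iterate past one epoch:

import torch

def get_minibatch(dataset, start, batch_size, batch_type="train"):
    n = len(dataset["can"])
    if batch_type == "train":
        # wrap around: iter_train in train() can exceed the dataset size
        idx = [(start + k) % n for k in range(batch_size)]
    else:
        idx = list(range(start, min(start + batch_size, n)))
    return {
        "sga": torch.as_tensor([dataset["sga"][i] for i in idx]),
        "can": torch.as_tensor([dataset["can"][i] for i in idx]),
        "deg": torch.as_tensor([dataset["deg"][i] for i in idx], dtype=torch.float),
        "tmr": [dataset["tmr"][i] for i in idx],
    }
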
Example #6
def sgd(model, x_train, y_train, n_iter=100):
    mini_batches = get_minibatch(x_train, y_train)

    for i in range(n_iter):
        idx = np.random.randint(0, len(mini_batches))
        x_mini, y_mini = mini_batches[idx]

        grad = get_minibatch_grad(model, x_mini, y_mini)
        for layer in grad:
            # get_minibatch_grad returns the update direction, so it is added
            model[layer] += 1e-3 * grad[layer]

    return model
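
Examples #6 and #8 call get_minibatch(X_train, y_train) once and index the result, so here it must return a list of (X_mini, y_mini) pairs rather than a generator. A minimal sketch under that assumption (NumPy arrays, one up-front shuffle, a hypothetical default size of 64):

import numpy as np

def get_minibatch(X, y, minibatch_size=64, shuffle=True):
    idx = np.arange(len(X))
    if shuffle:
        np.random.shuffle(idx)
    return [(X[idx[i:i + minibatch_size]], y[idx[i:i + minibatch_size]])
            for i in range(0, len(X), minibatch_size)]
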
Example #7
def train_bc(session, dataset, replay_dataset, main_dqn, pretrain=False):
    states, actions, _, _, _, weights = utils.get_minibatch(dataset)
    # two minibatches of generated trajectories
    gen_states_1, _, _, _, _ = replay_dataset.get_minibatch()
    gen_states_2, _, _, _, _ = replay_dataset.get_minibatch()
    gen_states = np.concatenate([gen_states_1, gen_states_2], axis=0)
    expert_loss, _ = session.run(
        [main_dqn.behavior_cloning_loss, main_dqn.bc_update],
        feed_dict={
            main_dqn.input: states,
            main_dqn.expert_action: actions,
            main_dqn.expert_weights: weights,
            main_dqn.generated_input: gen_states
        })

    return expert_loss
Example #8
def momentum(model, X_train, y_train):
    velocity = {k: np.zeros_like(v) for k, v in model.items()}
    gamma = 0.9  # momentum coefficient

    minibatches = get_minibatch(X_train, y_train)

    for _ in range(100):
        idx = np.random.randint(0, len(minibatches))
        X_mini, y_mini = minibatches[idx]

        grad = get_minibatch_grad(model, X_mini, y_mini)

        for layer in grad:
            velocity[layer] = gamma * velocity[layer] + 1e-3 * grad[layer]
            model[layer] += velocity[layer]

    return model
Example #9
    def evaluate(self, test, is_test_set=False):

        x_test = test[0]
        y_test = test[1]
        accs = []
        wrong_predictions = []
        lab_c = []
        pred_c = []
        """
        wrong_predictions is a list of tuples of type (sentence, fp_set, fn_set, lab, lab_pred)
        """
        correct_preds, total_correct, total_preds = 0., 0., 0.
        for sentences_batch, labels_batch in get_minibatch(
            (x_test, y_test), self.config.batch_size):
            labels_pred_batch, sequence_lengths_batch = self.predict_batch(
                sentences_batch)
            sentence_index = 0
            for lab, lab_pred, length in zip(labels_batch, labels_pred_batch,
                                             sequence_lengths_batch):
                lab = lab[:length]
                lab_pred = lab_pred[:length]
                accs += [a == b for (a, b) in zip(lab, lab_pred)]

                lab_chunks = set(
                    get_chunks(lab, self.config.vocab_tags, self.config))
                lab_pred_chunks = set(
                    get_chunks(lab_pred, self.config.vocab_tags, self.config))

                correct_preds += len(lab_chunks & lab_pred_chunks)
                total_preds += len(lab_pred_chunks)
                total_correct += len(lab_chunks)

                lab_c.append(lab)
                pred_c.append(lab_pred)

                fp_preds = lab_pred_chunks - lab_chunks
                fn_preds = lab_chunks - lab_pred_chunks

                if is_test_set and (len(fp_preds) != 0 or len(fn_preds) != 0):
                    wrong_pred = (sentences_batch[sentence_index], fp_preds,
                                  fn_preds, lab, lab_pred)
                    wrong_predictions.append(wrong_pred)
                sentence_index += 1

        p = correct_preds / total_preds if correct_preds > 0 else 0
        r = correct_preds / total_correct if correct_preds > 0 else 0
        f1 = 2 * p * r / (p + r) if correct_preds > 0 else 0
        acc = np.mean(accs)
        if is_test_set:
            print("Precision: " + str(p))
            print("Recall: " + str(r))
            print("F1: " + str(f1))

        #if is_test_set:
        #    write_wrong_predictions_to_file(wrong_predictions, self.config)
        pdir = self.config.dir_model
        import pickle
        # pickle requires binary mode
        with open(pdir + 'lab_.pkl', 'wb') as f:
            pickle.dump(lab_c, f, protocol=pickle.HIGHEST_PROTOCOL)
        with open(pdir + 'pred_.pkl', 'wb') as f:
            pickle.dump(pred_c, f, protocol=pickle.HIGHEST_PROTOCOL)

        return {
            "acc": 100 * acc,
            "f1": 100 * f1,
            "Precision": 100 * p,
            "Recall": 100 * r
        }
Example #10
def learn(session, dataset, replay_memory, main_dqn, target_dqn, batch_size,
          gamma, args):
    """
    Args:states
        session: A tensorflow sesson object
        replay_memory: A ReplayMemory object
        main_dqn: A DQN object
        target_dqn: A DQN object
        batch_size: Integer, Batch size
        gamma: Float, discount factor for the Bellman equation
    Returns:
        loss: The loss of the minibatch, for tensorboard
    Draws a minibatch from the replay memory, calculates the
    target Q-value that the prediction Q-value is regressed to.
    Then a parameter update is performed on the main DQN.
    """
    # Draw a minibatch of expert trajectories and a minibatch of
    # generated trajectories
    expert_states, expert_actions, expert_rewards, expert_new_states, \
        expert_terminal_flags, weights = utils.get_minibatch(dataset)
    generated_states, generated_actions, generated_rewards, \
        generated_new_states, generated_terminal_flags = \
        replay_memory.get_minibatch()

    # For every transition in the minibatch, the main network estimates
    # which action is best in the next state s' (new_states is passed!)

    arg_q_max = session.run(main_dqn.best_action,
                            feed_dict={main_dqn.input: generated_new_states})
    q_vals = session.run(target_dqn.q_values,
                         feed_dict={target_dqn.input: generated_new_states})
    double_q = q_vals[range(batch_size), arg_q_max]

    # Bellman equation. Multiplication with (1-terminal_flags) makes sure that
    # if the game is over, targetQ=rewards
    target_q = generated_rewards + (gamma * double_q *
                                    (1 - generated_terminal_flags))

    # Gradient descent step to update the parameters of the main network
    loss, _ = session.run(
        [main_dqn.loss, main_dqn.update],
        feed_dict={
            main_dqn.input: generated_states,
            main_dqn.target_q: target_q,
            main_dqn.action: generated_actions
        })

    expert_loss, _ = session.run(
        [main_dqn.expert_loss, main_dqn.expert_update],
        feed_dict={
            main_dqn.input: expert_states,
            main_dqn.generated_input: generated_states,
            main_dqn.expert_action: expert_actions,
            main_dqn.expert_weights: weights
        })
    return loss, expert_loss
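
The target computation in Example #10 is the standard double-DQN rule: the main network picks the best next action, the target network evaluates it, and terminal transitions keep only the reward. Distilled into a standalone NumPy helper (the names here are illustrative, not from the source):

import numpy as np

def double_q_target(rewards, terminal_flags, target_q_vals, best_actions, gamma):
    # targetQ = r + gamma * Q_target(s', argmax_a Q_main(s', a)); at terminal
    # transitions (1 - terminal_flags) zeroes the bootstrap term, so targetQ = r
    double_q = target_q_vals[np.arange(len(rewards)), best_actions]
    return rewards + gamma * double_q * (1 - terminal_flags)
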
Example #11
def train_or_test(model,
                  data,
                  label,
                  data_show,
                  data_index,
                  batch_size,
                  train_type,
                  optimizer,
                  epoch,
                  lossF=nn.NLLLoss(reduction='sum')):
    start_time = time.time()
    if train_type == "train":
        np.random.shuffle(data_index)
        model.train()
    else:
        model.eval()

    losses = []
    preds = []

    Y = []
    test_probs = []
    for i in range(0, len(data), batch_size):
        # max_len is assumed to be a module-level constant (max sequence length)
        data_batch, label_batch, lens = utils.get_minibatch(
            data, label, data_index, i, batch_size, max_len)
        scores = model(data_batch, lens, None)
        if train_type != 'test':
            loss = lossF(scores, label_batch)
            losses.append(loss.item())

        scores = scores.data.cpu().numpy()
        if train_type == 'test':
            test_probs.extend(scores)
        pred = [np.argmax(s) for s in scores]
        preds.extend(pred)
        Y.extend(label_batch.data.cpu().numpy())  # labels for valid, ids for test

        if train_type == "train":
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(params, 1)
            optimizer.step()

    if train_type == 'valid':
        mean_loss = np.mean(losses)
        P, F1 = utils.getF1(preds, Y)
        #utils.show_result_txt(preds,Y,data_index,data_show,"valid_epoch%d"%(epoch))
        print("Valid epoch:%d time:%.4f loss:%.4f  F1:%.4f  P:%.4f" %
              (epoch, time.time() - start_time, mean_loss, F1, P))
    elif train_type == 'train':
        mean_loss = np.mean(losses)
        P, F1 = utils.getF1(preds, Y)
        print("--Train epoch:%d time:%.4f loss:%.4f  F1:%.4f P:%.4f" %
              (epoch, time.time() - start_time, mean_loss, F1, P))
    else:
        print(np.shape(test_probs))
        #np.save("LSTM_test_scores_%d.npy"%(epoch),test_probs)
        #np.save("test_ids.npy",Y)
        utils.saveResult(Y, preds, 'LSTM_Result_%d' % (epoch))
        return test_probs, Y
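
The six-argument utils.get_minibatch in Example #11 has to return padded id tensors plus the true sequence lengths the model consumes. A hypothetical sketch, assuming data is a list of token-id lists, data_index is the (possibly shuffled) index order, and 0 is the padding id:

import torch

def get_minibatch(data, label, data_index, start, batch_size, max_len):
    idx = data_index[start:start + batch_size]
    seqs = [data[i][:max_len] for i in idx]  # truncate to max_len
    lens = torch.tensor([len(s) for s in seqs])
    batch = torch.zeros(len(seqs), max_len, dtype=torch.long)  # 0 = PAD id (assumed)
    for row, s in enumerate(seqs):
        batch[row, :len(s)] = torch.tensor(s)
    return batch, torch.tensor([label[i] for i in idx]), lens
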
Example #12
    def find_lr(self,
                train_set,
                test_set,
                batch_size=None,
                test_batch_size=None,
                max_iter=None,
                max_fscore=None,
                test_inc_size=None,
                logs=None,
                **kwargs):
        """ Train the model until max_iter or max_fscore reached.

    Parameters
    ----------
    train_set: dict
      dict of lists, including mut, cnv, exp, met, drug sensitivity, patient barcodes
    test_set: dict
    batch_size: int
    test_batch_size: int
    max_iter: int
      max number of iterations that the training will run
    max_fscore: float
      test F1 score at which training stops early
    test_inc_size: int
      interval (in iterations) between test/evaluation runs

    """

        record_epoch = 0

        clr = CLR(max_iter // batch_size)

        running_loss = 0.

        avg_beta = 0.98  # useful in calculating smoothed loss

        for iter_train in range(0, max_iter + 1, batch_size):
            if iter_train // len(self.rng_train) != record_epoch:
                record_epoch = iter_train // len(self.rng_train)
                random.shuffle(self.rng_train)
            batch_set = get_minibatch(train_set,
                                      self.rng_train,
                                      iter_train,
                                      batch_size,
                                      batch_type="train",
                                      use_cuda=self.use_cuda)
            lgt_drg = self.forward(batch_set)
            tgts = batch_set["tgt"]
            msks = batch_set["msk"]

            loss = self.loss_cross_entropy(lgt_drg, tgts, msks)

            if self.use_cuda:
                lc = loss.data.cpu().numpy().tolist()
            else:
                lc = loss.data.numpy().tolist()

            # exponentially smoothed (bias-corrected) loss
            running_loss = avg_beta * running_loss + (1.0 - avg_beta) * lc
            smoothed_loss = running_loss / (
                1.0 - avg_beta**(iter_train // batch_size + 1))

            lr = clr.calc_lr(smoothed_loss)  # learning rate from the CLR schedule

            if lr == -1:  # the stopping criterion
                break
            for pg in self.optimizer.param_groups:  # update learning rate
                pg['lr'] = lr

            # compute gradient and do parameter updates
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            if test_inc_size and (iter_train % test_inc_size == 0):
                print(
                    "[%d,%d] | loss:" % (iter_train // len(self.rng_train),
                                         iter_train % len(self.rng_train)), lc,
                    lr)

        logs['lrs'] = clr.lrs
        logs['losses'] = clr.losses

        return logs
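
The smoothed loss in find_lr is an exponential moving average with the same bias correction Adam uses: running = beta * running + (1 - beta) * loss, then smoothed = running / (1 - beta**t). A quick standalone check of the arithmetic:

avg_beta = 0.98
running_loss = 0.0
for t, lc in enumerate([2.0, 1.8, 1.7], start=1):
    running_loss = avg_beta * running_loss + (1.0 - avg_beta) * lc
    smoothed_loss = running_loss / (1.0 - avg_beta**t)
    print(t, round(smoothed_loss, 4))
# at t=1 this prints 2.0 exactly: the correction removes the bias toward
# the zero initialization of running_loss
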
Example #13
    def train(self,
              train_set,
              test_set,
              batch_size=None,
              test_batch_size=None,
              max_iter=None,
              max_fscore=None,
              test_inc_size=None,
              logs=None,
              **kwargs):
        """ Train the model until max_iter or max_fscore reached.

    Parameters
    ----------
    train_set: dict
      dict of lists, including mut, cnv, exp, met, drug sensitivity, patient barcodes
    test_set: dict
    batch_size: int
    test_batch_size: int
    max_iter: int
      max number of iterations that the training will run
    max_fscore: float
      test F1 score at which training stops early
    test_inc_size: int
      interval (in iterations) between test/evaluation runs

    """

        ocp = OneCycle(max_iter // batch_size, self.learning_rate)

        tgts_train, prds_train, msks_train = [], [], []
        losses, losses_ent = [], []

        record_epoch = 0
        for iter_train in range(0, max_iter + 1, batch_size):
            if iter_train // len(self.rng_train) != record_epoch:
                record_epoch = iter_train // len(self.rng_train)
                random.shuffle(self.rng_train)

            batch_set = get_minibatch(train_set,
                                      self.rng_train,
                                      iter_train,
                                      batch_size,
                                      batch_type="train",
                                      use_cuda=self.use_cuda)

            lgt_drg = self.forward(batch_set)
            tgts = batch_set["tgt"]
            msks = batch_set["msk"]

            lr, mom = ocp.calc()  # learning rate and momentum from the OneCycle schedule

            if lr == -1:  # the stopping criterion
                break
            for pg in self.optimizer.param_groups:  # update learning rate and momentum
                pg['lr'] = lr
                pg['momentum'] = mom

            self.optimizer.zero_grad()

            loss_ent = self.loss_cross_entropy(lgt_drg, tgts, msks)
            loss = loss_ent

            loss.backward()

            self.optimizer.step()

            if self.use_cuda:
                tgts_train.append(tgts.data.cpu().numpy())
                msks_train.append(msks.data.cpu().numpy())
                prds_train.append(torch.sigmoid(lgt_drg).data.cpu().numpy())
                losses.append(loss.data.cpu().numpy().tolist())
                losses_ent.append(loss_ent.data.cpu().numpy().tolist())
            else:
                tgts_train.append(tgts.data.numpy())
                msks_train.append(msks.data.numpy())
                prds_train.append(torch.sigmoid(lgt_drg).data.numpy())
                losses.append(loss.data.numpy().tolist())
                losses_ent.append(loss_ent.data.numpy().tolist())

            if test_inc_size and (iter_train % test_inc_size == 0):

                tgts_train = np.concatenate(tgts_train, axis=0)
                msks_train = np.concatenate(msks_train, axis=0)
                prds_train = np.concatenate(prds_train, axis=0)

                precision_train, recall_train, f1score_train, accuracy_train, auc_train = evaluate(
                    tgts_train, msks_train, prds_train, epsilon=self.epsilon)

                tgts, msks, prds, _, _ = self.test(test_set, test_batch_size)

                precision, recall, f1score, accuracy, auc = evaluate(
                    tgts, msks, prds, epsilon=self.epsilon)

                print(
                    "[%d,%d] | tst f1:%.1f, auc:%.1f | trn f1:%.1f, auc:%.1f, loss:%.3f"
                    %
                    (iter_train // len(self.rng_train), iter_train %
                     len(self.rng_train), 100.0 * f1score, 100.0 * auc, 100.0 *
                     f1score_train, 100.0 * auc_train, np.mean(losses)))

                logs["iter"].append(iter_train)
                logs["precision"].append(precision)
                logs["recall"].append(recall)
                logs["f1score"].append(f1score)
                logs["accuracy"].append(accuracy)
                logs["auc"].append(auc)

                logs["precision_train"].append(precision_train)
                logs["recall_train"].append(recall_train)
                logs["f1score_train"].append(f1score_train)
                logs["accuracy_train"].append(accuracy_train)
                logs["auc_train"].append(auc_train)

                logs['loss'].append(np.mean(losses))

                tgts_train, prds_train, msks_train = [], [], []
                losses, losses_ent = [], []


        return logs
Example #14
def train_bc(session, dataset, replay_dataset, main_dqn, pretrain=False):
    states, actions, _, _, _, weights = utils.get_minibatch(dataset)
Example #15
        patience = 0
        prv_miou = 0
        curr_miou = 0
        max_valid = 0
        print('About to try, patience = %d' % patience)

        try:
            step = sess.run(global_step)
            print(step)
            while (not coord.should_stop() and step < args.max_steps
                   and patience < args.max_patience):

                start_time = time.time()

                mb = get_minibatch(args.bs, t_lumList, t_alphaList, t_betaList,
                                   t_segList)

                my_mask = mb[:, :, :, 3]

                _, train_loss, train_summ = sess.run(
                    [train_op, total_loss, training_summary],
                    feed_dict={
                        keep_prob: 0.5,
                        p_cielab: mb[:, :, :, 0:3],
                        p_mask: my_mask
                    })

                duration = time.time() - start_time

                assert not np.isnan(train_loss), 'Model diverged with loss = NaN'