def train_model(model, sess):
    model.train_writer.add_graph(sess.graph)
    for epoch_index in range(model.settings.max_epoch):
        train_fetches = [
            model.loss, model.sigmoid_y_pred, model.train_op, model.global_step
        ]
        train_batch_generator = generate_batch_data('train', model.settings)
        prog = Progbar(target=model.settings.train_data_size //
                       model.settings.batch_size)
        for index, batch in enumerate(train_batch_generator):
            feed_dict = model.create_feed_dic(batch)
            loss, y_pred, _, global_step = sess.run(train_fetches, feed_dict)
            if global_step % 100 == 0:
                precision, recall, f1 = evaluate(
                    batch['ground_truth'],
                    get_top_5_id(y_pred, model.settings.batch_size))
                prog.update(index + 1, [("Loss", loss),
                                        ("precision", precision),
                                        ("recall", recall), ("F1", f1)])
                summary = tf.Summary(value=[
                    tf.Summary.Value(tag="Loss", simple_value=loss),
                    tf.Summary.Value(tag="precision", simple_value=precision),
                    tf.Summary.Value(tag="recall", simple_value=recall),
                    tf.Summary.Value(tag="F1", simple_value=f1),
                ])

                model.train_writer.add_summary(summary,
                                               global_step=global_step)
        test_model(model, sess)
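The example above calls two project helpers, evaluate and get_top_5_id, that are not shown. A minimal sketch of get_top_5_id consistent with how it is called here (the top-5 cut-off per row and the argsort implementation are assumptions, not the original code):

import numpy as np

def get_top_5_id(y_pred, batch_size):
    # y_pred: (batch_size, num_labels) array of sigmoid scores.
    # Return, for each example, the ids of its 5 highest-scoring labels.
    return [np.argsort(y_pred[i])[-5:][::-1] for i in range(batch_size)]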
Example #2
def train_population(population, x, y, batch_size, steps,
                     steps_save=100, validation_split=0.3):
    # Split data in train and validation. Set seed to get same splits in
    # consequent calls
    x_train, x_val, y_train, y_val = train_test_split(
        x, y, test_size=validation_split, random_state=42)

    population_size = len(population)
    batch_generator = BatchGenerator(x_train, y_train, batch_size)

    results = defaultdict(list)
    stateful_metrics = ['min_loss', 'max_loss', 'mean_loss']
    for metric, _ in population[0].eval_metrics:
        stateful_metrics.extend(
            [m.format(metric) for m in ['min_{}', 'max_{}', 'mean_{}']])
    progbar = Progbar(steps, stateful_metrics=stateful_metrics)

    for step in range(1, steps + 1):
        x, y = batch_generator.next()
        for idx, member in enumerate(population):
            # One step of optimisation using hyperparameters of 'member'
            member.step_on_batch(x, y)
            # Model evaluation
            loss = member.eval_on_batch(x_val, y_val)
            # If optimised for 'STEPS_READY' steps
            if member.ready():
                # Use the rest of population to find better solutions
                exploited = member.exploit(population)
                # If new weights != old weights
                if exploited:
                    # Produce new hyperparameters for 'member'
                    member.explore()
                    loss = member.eval_on_batch(x_val, y_val)

            if step % steps_save == 0 or step == steps:
                results['model_id'].append(str(member))
                results['step'].append(step)
                results['loss'].append(loss)
                results['loss_smoothed'].append(member.loss_smoothed())
                for metric, value in member.eval_metrics:
                    results[metric].append(value)
                for h, v in member.get_hyperparameter_config().items():
                    results[h].append(v)

        # Get recently added losses to show in the progress bar
        all_losses = results['loss']
        recent_losses = all_losses[-population_size:]
        if recent_losses:
            metrics = _statistics(recent_losses, 'loss')
            for metric, _ in population[0].eval_metrics:
                metrics.extend(
                    _statistics(results[metric][-population_size:], metric))
            progbar.update(step, metrics)

    return pd.DataFrame(results)
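train_population relies on a module-level _statistics helper that is not shown. Judging by the stateful_metrics names (min_loss, max_loss, mean_loss, ...), it reports the min/max/mean of the most recent values; a minimal sketch under that assumption (not the original implementation):

import numpy as np

def _statistics(values, suffix):
    # Build (name, value) pairs such as min_loss / max_loss / mean_loss,
    # matching the stateful_metrics registered with the Progbar above.
    values = np.asarray(values, dtype=float)
    return [('min_{}'.format(suffix), values.min()),
            ('max_{}'.format(suffix), values.max()),
            ('mean_{}'.format(suffix), values.mean())]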
Example #3
def imu_integration(imu_data, x_0_v, track_progress=True):

    # TODO: get a better comparison

    samples = len(imu_data)
    out = np.zeros((samples, 10))

    imu_v, t_diff_v = imu_data[:, :, :6], imu_data[:, :, 6:]

    # Convert time differences from milliseconds to seconds
    t_diff_v = np.squeeze(t_diff_v, axis=2) / 1000

    bar = Progbar(samples)

    for sample in range(samples):
        if track_progress:
            bar.update(sample)

        t_diff = t_diff_v[sample, :]

        # Get initial states (world frame); copy so the caller's array is not modified in place
        x_i = x_0_v[sample, :3].copy()
        v_i = x_0_v[sample, 3:6].copy()
        q_i = Quaternion(x_0_v[sample, 6:]).unit

        for i in range(len(t_diff)):

            dt = t_diff[i]

            # Rotation body -> world
            w_R_b = q_i.inverse

            # Rotate angular velocity to world frame
            w_w = w_R_b.rotate(imu_v[sample, i, :3])
            # Rotate acceleration to world frame
            w_a = w_R_b.rotate(imu_v[sample, i, 3:]) + [0, 0, 9.81]

            # Integrate attitude (world frame)
            q_i.integrate(w_w, dt)

            # Integrate velocity
            v_i += w_a * dt

            # Integrate position
            x_i += v_i * dt + 1 / 2 * w_a * (dt**2)

        out[sample, 0:3] = x_i
        out[sample, 3:6] = v_i
        out[sample, 6:] = q_i.elements

    out[:, 6:] = correct_quaternion_flip(out[:, 6:])
    return out
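A hedged usage sketch for imu_integration: the slicing above implies imu_data has shape (n_samples, seq_len, 7) (3 gyro + 3 acc + dt in milliseconds) and x_0_v has shape (n_samples, 10) (position, velocity, quaternion w-x-y-z). The dummy shapes and values below are assumptions for illustration only, and the same helpers (Quaternion, correct_quaternion_flip) must be importable:

import numpy as np

n_samples, seq_len = 4, 50
imu_data = np.zeros((n_samples, seq_len, 7))
imu_data[:, :, 6] = 5.0                 # 5 ms between consecutive IMU samples
x_0 = np.zeros((n_samples, 10))
x_0[:, 6] = 1.0                         # identity quaternion (w, x, y, z)

states = imu_integration(imu_data, x_0, track_progress=False)
print(states.shape)                     # (4, 10): position, velocity, quaternion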
Example #4
 def train(self, is_record=True):
     self.config.logger.info('the batch_size: {}'.format(self.config.batch_size))
     if not is_record:
         filepaths = glob(
             self.config.train_png_dir + 'original/*.png')
         filepaths = sorted(filepaths)  # Order the list of files
         filepaths_noisy = glob(self.config.train_png_dir + 'noisy/*.png')
         filepaths_noisy = sorted(filepaths_noisy)
         ind = list(range(len(filepaths)))
         num_batch = int(len(filepaths) / self.config.batch_size)
         train_data = Dataset_File(self.sess, self.config.batch_size, filepaths, filepaths_noisy, ind)
         # train_iterator = train_data.get_batch()
     else:
         train_data = DataGenerator(self.config.tf_record_dir, 'train', self.config.batch_size, self.config.logger)
         num_batch = int(train_data.get_sample_size() / self.config.batch_size)
         # train_iterator = train_data.generate()
     start_epoch = self.itr_num // num_batch
     start_time = time.time()
     for epoch in range(start_epoch, self.config.num_epoch):
         batch_id = 0
         prog = Progbar(num_batch)
         for batch_clean, batch_noisy in train_data.generate():
             # batch_noisy = batch_noisy[np.newaxis, ...]
             # batch_clean = batch_clean[np.newaxis, ...]
             batch_clean, batch_noisy = get_patch_batch(batch_clean, batch_noisy, self.config.patch_size)
             feed = {
                 self.X: batch_noisy,
                 self.Y: batch_clean,
             }
             _, loss, _summary = self.sess.run([self.optimizer, self.loss, self.merged], feed_dict=feed)
             self.itr_num += 1
             batch_id += 1
             self.writer.add_summary(summary=_summary, global_step=self.itr_num)
             prog.update(batch_id, [('epoch', int(epoch + 1)), ('loss', loss),
                                    ('global step', self.global_step.eval(self.sess)),
                                    ('itr_num', self.itr_num),
                                    ('time', time.time() - start_time)])
             # do save
             if self.itr_num % self.config.save_itr_size == 0 and self.itr_num != 0:
                 self.save(self.itr_num)
         self.save(self.itr_num)
     self.config.logger.info('train done!')
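get_patch_batch is another project helper that is not shown; from its use above it presumably crops an aligned patch_size x patch_size window from each clean/noisy image pair. A minimal sketch under that assumption (not the original implementation):

import numpy as np

def get_patch_batch(batch_clean, batch_noisy, patch_size):
    # Crop the same random window from each clean/noisy pair so they stay aligned.
    clean_patches, noisy_patches = [], []
    for clean, noisy in zip(batch_clean, batch_noisy):
        h, w = clean.shape[:2]
        top = np.random.randint(0, h - patch_size + 1)
        left = np.random.randint(0, w - patch_size + 1)
        clean_patches.append(clean[top:top + patch_size, left:left + patch_size])
        noisy_patches.append(noisy[top:top + patch_size, left:left + patch_size])
    return np.stack(clean_patches), np.stack(noisy_patches)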
Example #5
def validation_loop(sess,
                    g,
                    n_batches,
                    chars=None,
                    val_gen=None,
                    tb_writer=None):

    Loss = []
    Cer = []
    Wer = []

    progbar = Progbar(target=n_batches, verbose=1, stateful_metrics=['t'])
    print('Starting validation loop')

    for i in range(n_batches):

        x, y = val_gen.next()
        if len(x) == 1: x = x[0]
        if len(y) == 1: y = y[0]

        # -- Autoregressive inference
        preds = np.zeros((config.batch_size, config.maxlen), np.int32)
        tb_sum = None  # only produced when decoding stops early (see below)

        tile_preds = config.test_aug_times
        # -- For train graph feed in the previous step's predictions manually for the next
        if 'infer' not in config.graph_type:
            prev_inp = np.tile(
                preds, [config.test_aug_times, 1]) if tile_preds else preds
            feed_dict = {g.x: x, g.prev: prev_inp, g.y: y}

            enc = sess.run(g.enc, feed_dict)
            if isinstance(enc, list):
                for enc_tens, enc_val in zip(g.enc, enc):
                    feed_dict[enc_tens] = enc_val
            else:
                feed_dict[g.enc] = enc
            for j in range(config.maxlen):
                _preds, loss, cer = sess.run([g.preds, g.mean_loss, g.cer],
                                             feed_dict)
                preds[:, j] = _preds[:, j]
                prev_inp = np.tile(
                    preds, [config.test_aug_times, 1]) if tile_preds else preds
                feed_dict[g.prev] = prev_inp
                # if all samples in batch predict the pad symbol (char_id==0)
                if np.sign(preds[:, j]).sum() == 0:
                    if g.tb_sum is not None:
                        tb_sum = sess.run(g.tb_sum, {
                            g.x: x,
                            g.prev: prev_inp,
                            g.y: y
                        })
                    break

        # -- Autoregression loop is built into the beam search graph
        else:
            feed_dict = {g.x: x, g.y: y}
            enc = sess.run(g.enc, feed_dict)
            if isinstance(enc, list):
                for enc_tens, enc_val in zip(g.enc, enc):
                    feed_dict[enc_tens] = enc_val
            else:
                feed_dict[g.enc] = enc
            _preds, loss, cer = sess.run([g.preds, g.mean_loss, g.cer],
                                         feed_dict)
            preds = _preds

        # use last loss
        gt_sents = [''.join([chars[cid] for cid in prr]).strip() for prr in y]
        gt_words = [sent.split('-') for sent in gt_sents]

        def decode_preds_to_chars(decoding):
            return ''.join([chars[cid] for cid in decoding]).strip()

        pred_sentences = [decode_preds_to_chars(prr) for prr in preds]

        pred_words = [sent.split('-') for sent in pred_sentences]

        edists = [
            rel_edist(gt, dec_str)
            for gt, dec_str in zip(gt_words, pred_words)
        ]
        wer = np.mean(edists)

        # -- Write tb_summaries if any
        if g.tb_sum is not None and tb_sum is not None and wer == 0:
            tb_writer.add_summary(tb_sum, i)

        if config.print_predictions:
            print()
            for gts, prs, wr in zip(gt_sents, pred_sentences, edists):
                f = open("demofile3.txt", "w")
                f.write(prs)
                f.close()
                print('(wer={:.1f}) {} --> {}'.format(wr * 100, gts, prs))

        progbar.update(i + 1, [('cer', cer), ('wer', wer)])
        Wer.append(wer)

        Cer.append(cer)
        Loss.append(loss)

    return np.average(Loss), np.average(Cer), np.average(Wer)
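rel_edist is a project helper that is not shown; since its batch mean is reported as a WER, it is presumably a word-level edit distance normalised by the reference length. A minimal sketch under that assumption (not the original implementation):

def rel_edist(ref, hyp):
    # Word-level Levenshtein distance between reference and hypothesis word
    # lists, normalised by the reference length so the batch mean acts as a WER.
    d = [[0] * (len(hyp) + 1) for _ in range(len(ref) + 1)]
    for i in range(len(ref) + 1):
        d[i][0] = i
    for j in range(len(hyp) + 1):
        d[0][j] = j
    for i in range(1, len(ref) + 1):
        for j in range(1, len(hyp) + 1):
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            d[i][j] = min(d[i - 1][j] + 1,
                          d[i][j - 1] + 1,
                          d[i - 1][j - 1] + cost)
    return d[len(ref)][len(hyp)] / max(len(ref), 1)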
Example #6
    def iterate_model_output(self, datasets, dataset_options,
                             experiment_options, experiment_name):

        gt = {k: [] for k in experiment_options["plot_data"].keys()}
        predictions = {k: [] for k in experiment_options["plot_data"].keys()}
        comparisons = {k: [] for k in experiment_options["plot_data"].keys()}

        predictions_x_axis = None
        comparisons_x_axis = None
        gt_x_axis = None

        max_n_predictions = min(
            [len(datasets[i][0]["imu_input"]) for i in range(len(datasets))])
        n_pred = None
        if "iterations" in experiment_options.keys():
            n_pred = experiment_options["iterations"]
            assert n_pred * self.window_len - 1 < max_n_predictions, \
                "The maximum number of iterations are {0} for the current window length of {1}".format(
                    int(np.floor(max_n_predictions / self.window_len)), self.window_len)
        assert len(experiment_options["plot_data"].keys()) == 1, \
            "Currently this experiment only supports one output. Got {0} instead: {1}".format(
                len(experiment_options["plot_data"].keys()), experiment_options["plot_data"].keys())
        assert "state_in" in experiment_options.keys()
        assert "state_out" in experiment_options.keys()

        output_name = list(experiment_options["plot_data"].keys())[0]
        output_size = self.output_type_vars[experiment_options["plot_data"]
                                            [output_name]["type"]]["shape"]

        for i, dataset in enumerate(datasets):
            d_len = len(dataset[0]["imu_input"])

            if n_pred is None:
                n_predictions_p = int(
                    np.floor((d_len - self.window_len) /
                             (self.window_len - 1)) + 2)
                n_predictions_c = int(
                    np.floor((d_len - self.window_len) / self.window_len) + 2)
            else:
                n_predictions_p = n_pred
                n_predictions_c = n_pred

            state_out_name = experiment_options["state_out"]["name"]
            state_in_name = experiment_options["state_in"]["name"]

            for option in dataset_options[i]:
                if option == "predict":
                    predictions_x_axis = np.zeros(n_predictions_p, dtype=int)
                    model = self.model_loader()
                    model_predictions = np.zeros((n_predictions_p, ) +
                                                 output_size)
                    progress_bar = Progbar(n_predictions_p - 1)
                    model_out = {}
                    ds_i = 0
                    model_predictions[0] = dataset[1][output_name][0]
                    for it in range(n_predictions_p - 1):
                        progress_bar.update(it + 1)
                        model_in = {
                            k: np.expand_dims(dataset[0][k][ds_i], axis=0)
                            for k in dataset[0].keys()
                        }
                        if it > 0:
                            past_pred = model_out[state_out_name]
                            if experiment_options["state_out"]["lie"]:
                                past_pred = np.concatenate(
                                    (past_pred[:, :6],
                                     exp_mapping(past_pred[:, 6:])),
                                    axis=1)
                            model_in[state_in_name] = past_pred
                        model_out = model.predict(model_in, verbose=0)
                        model_out = create_predictions_dict(model_out, model)
                        model_predictions[it + 1, :] = model_out[output_name]
                        ds_i += self.window_len - 1
                        predictions_x_axis[it + 1] = int(ds_i)
                        predictions_x_axis = predictions_x_axis.astype(int)

                    predictions[output_name] = model_predictions

                elif option == "compare_prediction":
                    model_predictions = np.zeros((n_predictions_c, 10))
                    comparisons_x_axis = np.zeros(n_predictions_c, dtype=int)
                    progress_bar = Progbar(n_predictions_c)
                    state_in = np.expand_dims(dataset[0][state_in_name][0],
                                              axis=0)
                    model_predictions[0, :] = state_in
                    ds_i = 0
                    for it in range(n_predictions_c):
                        progress_bar.update(it + 1)
                        model_out = self.alt_prediction_algo(
                            np.squeeze(np.expand_dims(
                                dataset[0]["imu_input"][ds_i], axis=0),
                                       axis=-1), state_in, False)
                        model_predictions[it, :] = model_out
                        state_in = model_out
                        comparisons_x_axis[it] = int(ds_i)
                        comparisons_x_axis = comparisons_x_axis.astype(int)
                        ds_i += self.window_len - (1 if it == 0 else 0)

                    comparisons[output_name] = model_predictions

                elif option == "ground_truth":
                    state_in = np.expand_dims(dataset[0][state_in_name][0],
                                              axis=0)
                    if experiment_options["state_out"]["lie"]:
                        state_in = np.concatenate(
                            (state_in[:, :6], log_mapping(state_in[:, 6:])),
                            axis=1)
                    state_in = np.tile(state_in, (self.window_len - 1, 1))
                    gt = {
                        k: dataset[1][k]
                        for k in experiment_options["plot_data"].keys()
                    }
                    gt[state_out_name] = np.concatenate(
                        (state_in, gt[state_out_name]), axis=0)
                    gt_x_axis = np.arange(0, len(gt[state_out_name]))

        fig = self.draw_predictions(
            ground_truth=gt,
            model_prediction=predictions,
            comp_prediction=comparisons,
            plot_options=experiment_options["plot_data"],
            gt_x=gt_x_axis.astype(int),
            model_x=predictions_x_axis,
            comp_x=comparisons_x_axis)
        self.experiment_plot(fig,
                             experiment_options,
                             experiment_name=experiment_name)
Example #7
    def windowed_imu_preintegration_dataset(self, args):
        """
        :param args: extra arguments for dataset generation

        Generates a dataset that aims at performing IMU integration.
            Input 1: one initial 10-dimensional state consisting of the initial position (x,y,z), velocity (x,y,z) and
            orientation (w,x,y,z)
            Input 2: a window of IMU samples of dimensions <imu_len x 7>, where the 7 columns are the 6 IMU readings
            (3 gyro + 3 acc) plus the time difference between consecutive acquisitions, and the number of rows is the
            number of IMU samples used
            Output 1: the final 9-dimensional state consisting of the final position (x,y,z), velocity (x,y,z) and
            orientation (x,y,z), in so(3) representation
            Output 2: the pre-integrated rotation for each window element, with shape <n_samples, imu_len, 3> in so(3)
            Output 3: the pre-integrated velocity for each window element, with shape <n_samples, imu_len, 3> in R(3)
            Output 4: the pre-integrated position for each window element, with shape <n_samples, imu_len, 3> in R(3)

        """

        window_len = args[0]

        # TODO: get as a parameter of the dataset
        g_val = -9.81

        n_samples = len(self.imu_raw) - window_len - 1

        self.windowed_imu_for_state_prediction(args)

        # Adjustments for the pre-integration dataset
        self.x_ds["imu_input"] = self.x_ds["imu_input"][1:]
        self.x_ds["state_input"] = self.x_ds["state_input"][1:]
        self.y_ds["state_output"] = self.y_ds["state_output"][:-1]

        gt_augmented = self.x_ds["state_input"]
        gt_augmented = np.concatenate(
            (gt_augmented, self.y_ds["state_output"][-window_len:, :]), axis=0)

        imu_window = self.x_ds["imu_input"]

        # Define the pre-integrated rotation, velocity and position vectors
        pre_int_rot = np.zeros((n_samples, window_len, 3))
        pre_int_v = np.zeros((n_samples, window_len, 3))
        pre_int_p = np.zeros((n_samples, window_len, 3))

        print("Generating pre-integration dataset. This may take a while...")
        prog_bar = Progbar(n_samples)

        for i in range(n_samples):
            pi = np.tile(gt_augmented[i, 0:3], [window_len, 1])
            vi = np.tile(gt_augmented[i, 3:6], [window_len, 1])
            qi = np.tile(gt_augmented[i, 6:], [window_len, 1])

            # imu_window[i, :, -1, 0] is a <1, window_len> vector containing all the dt between two consecutive samples
            # of the imu. We compute the cumulative sum to get the total time for every sample in the window since the
            # beginning of the window itself. We divide by 1000 to transform from ms to s
            cum_dt_vec = np.cumsum(imu_window[i, :, -1, 0]) / 1000

            # We calculate the quaternion that rotates q(i) to q(i+t) for all t in [0, window_len], and map it to so(3)
            pre_int_q = np.array([
                q.elements
                for q in quaternion_error(qi, gt_augmented[i:i + window_len,
                                                           6:])
            ])
            pre_int_rot[i, :, :] = log_mapping(
                correct_quaternion_flip(pre_int_q))

            g_contrib = np.expand_dims(cum_dt_vec * g_val, axis=1) * np.array(
                [0, 0, 1])
            pre_int_v[i, :, :] = rotate_vec(
                gt_augmented[i:i + window_len, 3:6] - vi - g_contrib,
                q_inv(qi))

            v_contrib = np.multiply(np.expand_dims(cum_dt_vec, axis=1), vi)
            g_contrib = 1 / 2 * np.expand_dims(cum_dt_vec**2 * g_val,
                                               axis=1) * np.array([0, 0, 1])
            pre_int_p[i, :, :] = rotate_vec(
                gt_augmented[i:i + window_len, 0:3] - pi - v_contrib -
                g_contrib, q_inv(qi))

            prog_bar.update(i + 1)

        self.set_outputs(
            ["pre_integrated_R", "pre_integrated_v", "pre_integrated_p"],
            [pre_int_rot, pre_int_v, pre_int_p])
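For reference, the quantities filled in by the loop above are the standard IMU pre-integration deltas between sample i and sample i+t, written here in compact notation (dt(i, t) is the cumulative time cum_dt_vec and g = (0, 0, -9.81) m/s^2, as in the code):

    pre_int_rot[i, t] = Log( q_i^-1 * q_(i+t) )
    pre_int_v[i, t]   = R(q_i)^-1 ( v_(i+t) - v_i - g * dt(i, t) )
    pre_int_p[i, t]   = R(q_i)^-1 ( p_(i+t) - p_i - v_i * dt(i, t) - 1/2 * g * dt(i, t)^2 )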
Example #8
class ProgressBarHook(session_run_hook.SessionRunHook):
  """Monitors training progress. This hook uses `tf.keras.utils.ProgBar` to
  write messages to stdout.

  Example:
    ```python
    estimator = tf.estimator.DNNClassifier(
      hidden_units=[256, 64], feature_columns=feature_columns)
    estimator.train(
      input_fn=input_fn,
      hooks=[ProgressBarHook(
        epochs=3,
        steps_per_epoch=4,
        tensors_to_log=['loss', 'acc'])])
    ```
    # output
    ```
    Epoch 1/3:
    4/4 [======================] - 13s 3s/step - acc: 0.7 - loss: 0.4124

    Epoch 2/3:
    4/4 [======================] - 1s 175ms/step - acc: 0.7235 - loss: 0.2313

    Epoch 3/3:
    4/4 [======================] - 1s 168ms/step - acc: 0.7814 - loss: 0.1951
    ```

  """
  def __init__(self,
               epochs,
               steps_per_epoch,
               tensors_to_log=None):
    """Initializes `ProgressBarHook` instance

    Args:
      epochs: `int`, total number of expected epochs. It is usually calculated
        by dividing the total number of training steps by `steps_per_epoch`.
      steps_per_epoch: `int`, number of expected iterations per epoch.
      tensors_to_log: optional; either a `dict` mapping string-valued tags to
        tensors/tensor names, or an `iterable` of tensors/tensor names.

    Raises:
      ValueError: If `tensors_to_log` is neither a list nor a dictionary.
    """
    self._epochs = epochs
    self._step_per_epoch = steps_per_epoch

    if tensors_to_log is not None:
      if not isinstance(tensors_to_log, dict):
        self._tag_order = tensors_to_log
        tensors_to_log = {item: item for item in tensors_to_log}
      else:
        self._tag_order = tensors_to_log.keys()
      self._tensors = tensors_to_log
    else:
      self._tensors = None

  def begin(self):
    self._global_step_tensor = training_util._get_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError(
          "Global step should be created to use ProgressBarHook")

    # Convert names to tensors if given
    if self._tensors:
      self._current_tensors = {tag: _as_graph_element(tensor)
                               for (tag, tensor) in self._tensors.items()}

  def after_create_session(self, session, coord):  # pylint: disable=unused-argument
    # Init current_epoch and current_step
    self._curr_step = session.run(self._global_step_tensor)
    if self._curr_step != 0:
      print('Resuming training from global step(s): %s...\n' % self._curr_step)

    self._curr_epoch = int(np.floor(self._curr_step / self._step_per_epoch))
    self._curr_step -= self._curr_epoch * self._step_per_epoch
    self._first_run = True

  def before_run(self, run_context):  # pylint: disable=unused-argument
    if self._first_run:
      self._curr_epoch += 1
      print('Epoch %s/%s:' % (self._curr_epoch, self._epochs))
      self.progbar = Progbar(target=self._step_per_epoch)
      self._first_run = False

    elif self._curr_step % self._step_per_epoch == 0:
      self._curr_epoch += 1
      self._curr_step = 0
      print('Epoch %s/%s:' % (self._curr_epoch, self._epochs))
      self.progbar = Progbar(target=self._step_per_epoch)

    if self._tensors:
      return SessionRunArgs(self._current_tensors)

    return None

  def after_run(self,
                run_context,  # pylint: disable=unused-argument
                run_values):
    if self._tensors:
      values = self._extract_tensors_info(run_values.results)
    else:
      values = None
    self._curr_step += 1
    self.progbar.update(self._curr_step, values=values)

  def _extract_tensors_info(self, tensor_values):
    stats = []
    for tag in self._tag_order:
      stats.append((tag, tensor_values[tag]))
    return stats
Example #9
    def train(self, config):
        """High level train function.
        Args:
            config: Configuration dictionary
            descriptors_database_dict: dictionary of numpy arrays of descriptors, one array per track
            actions_database_dict: dictionary of numpy arrays of actions associated to descriptors, one array per track
            pos_database_dict: dictionary of numpy arrays of positions associated to descriptors, one array per track
        Returns:
            None
        """

        self.config = config
        self.build_train_graph()
        self.collect_summaries()
        with tf.name_scope("parameter_count"):
            parameter_count = tf.reduce_sum([tf.reduce_prod(tf.shape(v)) \
                                             for v in tf.trainable_variables()])
        self.saver = tf.train.Saver([var for var in \
                                     tf.trainable_variables()] + [self.global_step], max_to_keep=20)

        sv = tf.train.Supervisor(logdir=config.checkpoint_dir,
                                 save_summaries_secs=0,
                                 saver=None)

        with sv.managed_session() as sess:
            print("Number of params: {}".format(sess.run(parameter_count)))
            if config.resume_train:
                assert os.path.isdir(self.config.checkpoint_dir)
                print("Resume training from previous checkpoint")
                checkpoint = tf.train.latest_checkpoint(
                    self.config.checkpoint_dir)
                assert checkpoint, "Found no checkpoint in the given dir!"
                print("Restoring checkpoint: ")
                print(checkpoint)
                self.saver.restore(sess, checkpoint)

            progbar = Progbar(target=self.train_steps_per_epoch)

            # What to train?
            trainables = {
                'CNN': True,
                'Mean_Prediction': True,
                'Variance_Prediction': False
            }

            n_epochs = 0

            # (Re-)Initialize the iterator
            sess.run(self.training_init_iter)

            for step in count(start=1):
                if sv.should_stop():
                    break
                start_time = time.time()
                fetches = {
                    "global_step": self.global_step,
                    "incr_global_step": self.incr_global_step
                }
                for key, trainable in trainables.items():
                    # print(key + " is trainable: " + str(trainable))
                    if trainable:
                        fetches[key] = self.train_op[key]

                if step % config.summary_freq == 0:
                    fetches["train_loss"] = self.train_loss
                    fetches["summary"] = self.step_sum_op

                # Runs a series of operations
                results = sess.run(fetches, feed_dict={self.is_training: True})

                progbar.update(step % self.train_steps_per_epoch)

                gs = results["global_step"]

                if step % config.summary_freq == 0:
                    sv.summary_writer.add_summary(results["summary"], gs)
                    self.completed_epochs = int(gs /
                                                self.train_steps_per_epoch)
                    train_step = gs - (self.completed_epochs -
                                       1) * self.train_steps_per_epoch
                    print("Epoch: [%2d] [%5d/%5d] time: %4.4f/it train_loss: %.3f "
                          % (self.completed_epochs, train_step, self.train_steps_per_epoch, \
                             time.time() - start_time, results["train_loss"]))

                if step % self.train_steps_per_epoch == 0:
                    n_epochs += 1
                    self.completed_epochs = int(gs /
                                                self.train_steps_per_epoch)
                    progbar = Progbar(target=self.train_steps_per_epoch)
                    done = self._epoch_end_callback(sess, sv, n_epochs)

                    # (Re-)Initialize the iterator
                    sess.run(self.training_init_iter)

                    if done:
                        break