def model_train(self):
    """Train the model for ``hparams.nb_epochs`` epochs.

    Inputs are read from the trainer instance rather than passed as
    arguments:

    * ``self.X_train`` / ``self.Y_train`` -- training inputs and labels.
      They are sliced and indexed with a list of shuffled indices
      (presumably NumPy arrays so fancy indexing works -- confirm callers).
    * ``self.sess`` -- session made the default for the whole run.
    * ``self.rng`` -- shuffles the example order at the start of each epoch.
    * ``self.hparams`` -- ``batch_size``, ``nb_epochs``, ``save_dir``,
      ``save`` and ``save_steps``.

    Each epoch does the following:

    * Shuffles the training set.
    * Runs one step per mini-batch through ``self._run``.
    * Drains the runner until ``self.runner.is_finished()``.
    * Calls ``self.eval()``.

    When ``hparams.save`` is true, ``model.ckpt`` is written to
    ``hparams.save_dir``. This happens every ``hparams.save_steps`` epochs
    and also after the final epoch.

    :raises AssertionError: if ``self.runner`` is not initialized, or if the
        mini-batches of an epoch did not reach the end of the training set.
    """
    assert self.runner is not None, (
        """Runner is not initialized. TrainerSingleGPU or TrainerMultiGPU instantiate a Runner object at initialization time.""")
    hparams = self.hparams
    batch_size = hparams.batch_size
    nb_epochs = hparams.nb_epochs
    train_dir = hparams.save_dir
    filename = 'model.ckpt'
    X_train = self.X_train
    Y_train = self.Y_train
    sess = self.sess

    nb_train = len(X_train)
    # The dataset and batch size are fixed, so the batch count is too.
    nb_batches = int(math.ceil(float(nb_train) / batch_size))
    assert nb_batches * batch_size >= nb_train
    # Build the Saver once, on first use. Constructing a new one per save
    # adds another set of save/restore ops to the graph every time.
    saver = None

    with sess.as_default():
        # Initialize TF state using the first mini-batch.
        X_batch = X_train[:batch_size]
        Y_batch = Y_train[:batch_size]
        self._init_tf(X_batch, Y_batch)

        for epoch in range(nb_epochs):
            logging.info("Epoch %d", epoch)

            # Start each epoch from the identity order, then shuffle it in place.
            index_shuf = list(range(nb_train))
            self.rng.shuffle(index_shuf)

            epoch_start = time.time()
            # Keeps `end` bound when there are no batches (empty training set).
            end = 0
            for batch in range(nb_batches):
                # Compute batch start and end indices.
                start, end = batch_indices(batch, nb_train, batch_size)
                # Perform one training step.
                self._update_learning_params()
                batch_idx = index_shuf[start:end]
                X_batch = X_train[batch_idx]
                Y_batch = Y_train[batch_idx]
                self._run({'x_pre': X_batch, 'y': Y_batch})
                self._sync_params()

            # Drain work still pending in the runner before closing the epoch.
            while not self.runner.is_finished():
                self._run()
            self._sync_params(forced=True)

            assert end >= nb_train, (
                'Not all training examples are used.')
            logging.info("\tEpoch took %s seconds", time.time() - epoch_start)

            self.eval()

            # Save every `save_steps` epochs and always after the last epoch.
            # `epoch` runs 0..nb_epochs-1, so the last epoch is nb_epochs - 1.
            # `hparams.save` is checked first so save_steps is only consulted
            # when saving is enabled.
            is_last_epoch = epoch == nb_epochs - 1
            if hparams.save and (
                    (epoch + 1) % hparams.save_steps == 0 or is_last_epoch):
                save_path = os.path.join(train_dir, filename)
                if saver is None:
                    saver = tf.train.Saver()
                saver.save(sess, save_path)
                logging.info("Model saved at: %s", save_path)
    logging.info("Completed model training.")