Example #1
0
    def trainLowMemory(self):
        """Train the network one tower frame at a time to reduce peak memory.

        For each batch, every frame of the input sequence is pushed through
        the tower network in its own session call; the per-frame logits are
        then stacked and fed to the rest of the graph in a second call.
        Checkpoints are written every SAVE_EVERY fraction of an epoch.
        """
        print("TRAINING IS STARTING RIGHT NOW!")

        for epoch_num in range(1, MAX_EPOCHS + 1):

            epoch_start_time = time.time()
            currentSaveRate = SAVE_EVERY

            for step in range(self.dataset.num_batches_train):

                start_time = time.time()

                # get data
                batch_x, batch_y = self.sess.run(
                    [self.dataset.train_images, self.dataset.train_labels])

                # Run the tower once per frame so only a single frame's
                # activations are live at any time.
                logits = []
                for sequence_image in range(self.dataset.frames):
                    feed_dict = {self.towerImage: batch_x[:, sequence_image]}
                    logits.append(self.sess.run(self.towerNet, feed_dict))
                # (frames, batch, dim) -> (batch, frames, dim)
                logits = np.transpose(np.array(logits), [1, 0, 2])

                feed_dict = {
                    self.towerLogits: logits,
                    self.Y: batch_y,
                    self.is_training: True
                }
                # Optimize + rest of network; attach summaries periodically.
                # NOTE(review): 5000 is a hard-coded example interval; the
                # plain train() loop uses LOG_EVERY here -- confirm intent.
                write_summary = (step + 1) * BATCH_SIZE % 5000 == 0
                eval_tensors = [self.loss, self.train_op]
                if write_summary:
                    eval_tensors.append(self.merged_summary_op)

                eval_ret = self.sess.run(eval_tensors, feed_dict=feed_dict)
                loss_val = eval_ret[0]
                if write_summary:
                    self.summary_train_train_writer.add_summary(
                        eval_ret[2], self.global_step.eval(session=self.sess))

                duration = time.time() - start_time
                util.log_step(epoch_num, step, duration, loss_val, BATCH_SIZE,
                              self.dataset.num_train_examples, LOG_EVERY)

                # Save another checkpoint each time a further SAVE_EVERY
                # fraction of the epoch has been processed.
                if (step / self.dataset.num_batches_train) >= currentSaveRate:
                    self.saver.save(self.sess, self.ckptPrefix)
                    currentSaveRate += SAVE_EVERY

            epoch_time = time.time() - epoch_start_time
            print("Total epoch time training: {}".format(epoch_time))

            # self.startValidation(epoch_num, epoch_time)

        self.finishTraining()
Example #2
0
    def trainLowMemory(self):
        """Low-memory training loop: tower, combined logits and optimizer
        are each evaluated in separate session calls so that only one stage
        of the graph is active at a time.
        """
        print("TRAINING IS STARTING RIGHT NOW!")

        for epoch_num in range(1, MAX_EPOCHS + 1):

            epoch_start = time.time()
            next_save_fraction = SAVE_EVERY

            for step in range(self.dataset.num_batches_train):

                step_start = time.time()

                # Fetch one training batch.
                batch_x, batch_y = self.sess.run(
                    [self.dataset.train_images, self.dataset.train_labels])

                # Stage 1: run the tower once per frame of the sequence.
                frame_outputs = [
                    self.sess.run(self.towerNet,
                                  {self.towerImage: batch_x[:, frame_idx],
                                   self.is_training: True})
                    for frame_idx in range(self.dataset.frames)
                ]
                tower_logits = np.transpose(np.array(frame_outputs),
                                            [1, 0, 2, 3, 4])

                # Stage 2: combine per-frame tower outputs into logits.
                combined = self.sess.run(
                    self.logits,
                    {self.towerLogits: tower_logits, self.is_training: False})

                # Stage 3: optimization step (summaries on periodic steps).
                opt_feed = {self.optLogits: combined,
                            self.Y: batch_y,
                            self.is_training: True}
                if (step + 1) * BATCH_SIZE % 5000 == 0:
                    loss_val, _, summaries = self.sess.run(
                        [self.loss, self.train_op, self.merged_summary_op],
                        feed_dict=opt_feed)
                    self.summary_train_train_writer.add_summary(
                        summaries, self.global_step.eval(session=self.sess))
                else:
                    loss_val, _ = self.sess.run([self.loss, self.train_op],
                                                opt_feed)

                util.log_step(epoch_num, step, time.time() - step_start,
                              loss_val, BATCH_SIZE,
                              self.dataset.num_train_examples, LOG_EVERY)

                # Periodic checkpoint within the epoch.
                if (step / self.dataset.num_batches_train) >= next_save_fraction:
                    self.saver.save(self.sess, self.ckptPrefix)
                    next_save_fraction += SAVE_EVERY

            epoch_time = time.time() - epoch_start
            print("Total epoch time training: {}".format(epoch_time))

            self.startValidationLowMemory(epoch_num, epoch_time)

        self.finishTraining()
Example #3
0
    def train(self):
        """Standard training loop: one session call per batch, summaries on
        logging steps, and in-epoch checkpoints every SAVE_EVERY fraction.
        """
        print("TRAINING IS STARTING RIGHT NOW!")

        for epoch_num in range(1, MAX_EPOCHS + 1):

            epoch_start = time.time()
            next_save_fraction = SAVE_EVERY

            for step in range(self.dataset.num_batches_train):

                step_start = time.time()

                batch_x, batch_y = self.sess.run(
                    [self.dataset.train_images, self.dataset.train_labels])

                feed = {self.is_training: True,
                        self.X: batch_x,
                        self.Y: batch_y}

                # Attach the summary op on logging steps only.
                log_this_step = (step + 1) * BATCH_SIZE % LOG_EVERY == 0
                if log_this_step:
                    loss_val, _, summaries = self.sess.run(
                        [self.loss, self.train_op, self.merged_summary_op],
                        feed_dict=feed)
                    self.summary_train_train_writer.add_summary(
                        summaries, self.global_step.eval(session=self.sess))
                else:
                    loss_val, _ = self.sess.run([self.loss, self.train_op],
                                                feed_dict=feed)

                util.log_step(epoch_num, step, time.time() - step_start,
                              loss_val, BATCH_SIZE,
                              self.dataset.num_train_examples, LOG_EVERY)

                # Checkpoint once another SAVE_EVERY fraction is complete.
                if (step / self.dataset.num_batches_train) >= next_save_fraction:
                    self.saver.save(self.sess, self.ckptPrefix)
                    next_save_fraction += SAVE_EVERY

            epoch_time = time.time() - epoch_start
            print("Total epoch time training: {}".format(epoch_time))

            self.startValidation(epoch_num, epoch_time)

        self.finishTraining()
Example #4
0
File: ef.py  Project: Tiyanak/lip-reading
    def train(self):
        """Train for MAX_EPOCHS epochs: log every LOG_EVERY examples,
        checkpoint every SAVE_EVERY fraction of an epoch, validate after
        each epoch.
        """
        print("TRAINING IS STARTING RIGHT NOW!")

        for epoch_num in range(1, MAX_EPOCHS + 1):

            epoch_timer = time.time()
            save_threshold = SAVE_EVERY

            for step in range(self.dataset.num_batches_train):

                batch_timer = time.time()

                images, labels = self.sess.run(
                    [self.dataset.train_images, self.dataset.train_labels])

                # Base fetches; the summary op rides along on logging steps.
                fetches = [self.loss, self.train_op]
                examples_seen = (step + 1) * BATCH_SIZE
                if examples_seen % LOG_EVERY == 0:
                    fetches.append(self.merged_summary_op)

                results = self.sess.run(
                    fetches,
                    feed_dict={self.is_training: True,
                               self.X: images,
                               self.Y: labels})
                loss_val = results[0]

                if len(fetches) == 3:
                    self.summary_train_train_writer.add_summary(
                        results[2], self.global_step.eval(session=self.sess))

                util.log_step(epoch_num, step, time.time() - batch_timer,
                              loss_val, BATCH_SIZE,
                              self.dataset.num_train_examples, LOG_EVERY)

                # Mid-epoch checkpointing.
                if step / self.dataset.num_batches_train >= save_threshold:
                    self.saver.save(self.sess, self.ckptPrefix)
                    save_threshold += SAVE_EVERY

            epoch_time = time.time() - epoch_timer
            print("Total epoch time training: {}".format(epoch_time))

            self.startValidation(epoch_num, epoch_time)

        self.finishTraining()