def fit(self, tfrecord_path=TFRECORD_PATH, model_dir=MODEL_DIR, num_steps=-1,
        input_size=None, batch_size=BATCH_SIZE, label_size=LABEL_SIZE,
        learning_rate=LEARNING_RATE, num_test_batches=NUM_TEST_BATCHES,
        display_train_loss_step=DISPLAY_TRAIN_LOSS_STEP,
        display_test_loss_step=DISPLAY_TEST_LOSS_STEP):
    """Fit the CNN model.

    Builds the TF1-style graph, restores any previous checkpoint and
    train info, then runs the training loop: one optimization step per
    iteration, with periodic train-loss logging and periodic validation.
    A checkpoint is saved whenever the mean validation loss improves.

    Args:
        tfrecord_path: Path to the TFRecord file with train/validation data.
        model_dir: Directory where "model.ckpt" checkpoints are written.
        num_steps: Total training steps; -1 (default) trains indefinitely.
        input_size: [frames, height, width, channels] of the input video.
            Defaults to [CNN_FRAME_SIZE, CNN_VIDEO_HEIGHT,
            CNN_VIDEO_WIDTH, 3] when None.
        batch_size: Number of samples per batch.
        label_size: Size of the label vector fed to the model.
        learning_rate: Learning rate passed to the optimizer.
        num_test_batches: Validation batches averaged per evaluation pass.
        display_train_loss_step: Step interval for logging training loss.
        display_test_loss_step: Step interval for running validation.
    """
    # Avoid the mutable-default-argument pitfall: construct the default
    # list per call instead of sharing one list across all calls.
    if input_size is None:
        input_size = [CNN_FRAME_SIZE, CNN_VIDEO_HEIGHT, CNN_VIDEO_WIDTH, 3]

    # Initialize model paths.
    model_path = model_dir + "/model.ckpt"
    self.init_model_paths(model_path)
    self.phase = "train"
    self.batch_size = batch_size

    # Initialize model (fresh graph each fit).
    tf.reset_default_graph()
    self.build_model(input_size, label_size)

    # Create loss.
    with tf.variable_scope("loss_error"):
        loss_function, cross_entropy_classes, cross_entropy_action = \
            self.build_loss()

    # Create optimization function.
    optimizer = self.optimizer(loss_function, learning_rate)

    # Create summaries.
    train_writer, test_writer, loss_summary = self.create_summaries(
        loss_function)

    # Start train session and restore state, if any.
    self.open_session()
    train_info = self.load_train_info()
    self.load_graph()

    # Create batch generators.
    # NOTE(review): assumes build_model stored input_size on
    # self.input_size — confirmed by the reads below.
    train_generator = BatchGenerator(
        "train", self.sess, tfrecord_path, self.input_size[0],
        self.input_size[1], self.input_size[2], batch_size)
    test_generator = BatchGenerator(
        "validation", self.sess, tfrecord_path, self.input_size[0],
        self.input_size[1], self.input_size[2], batch_size)

    while train_info["step"] < num_steps or num_steps == -1:
        # Get train batch.
        forgd_samples, backd_samples, labels = train_generator.get_next()

        if train_info["step"] % display_train_loss_step == 0:
            # Optimization step that also fetches losses and the summary.
            train_loss_s, error_classes, error_action, loss_train_val, \
                _opt_val = self.sess.run(
                    [loss_summary, cross_entropy_classes,
                     cross_entropy_action, loss_function, optimizer],
                    feed_dict={
                        self.input["input_video"]: forgd_samples,
                        self.input["input_background_video"]: backd_samples,
                        self.input["input_label"]: labels})
            train_writer.add_summary(train_loss_s, train_info["step"])
            print('Step %i: train loss: %f,'
                  ' classes loss: %f, action loss: %f'
                  % (train_info["step"], loss_train_val, error_classes,
                     error_action))
        else:
            # Plain optimization step.
            _opt_val, loss_train_val = self.sess.run(
                [optimizer, loss_function],
                feed_dict={
                    self.input["input_video"]: forgd_samples,
                    self.input["input_background_video"]: backd_samples,
                    self.input["input_label"]: labels})

        self.save_train_info(train_info)
        train_writer.flush()

        # Display test loss and input/output images.
        # Guard num_test_batches < 1: the original code would call
        # np.mean([]) (nan + RuntimeWarning) and then hit a NameError on
        # the never-assigned summary below.
        if (train_info["step"] % display_test_loss_step == 0
                and num_test_batches > 0):
            test_loss_list = []
            error_classes_list = []
            error_action_list = []
            batch_index = 0
            while batch_index < num_test_batches:
                forgd_samples, backd_samples, labels = \
                    test_generator.get_next()
                batch_index += 1
                if batch_index < num_test_batches:
                    loss_test_val, error_classes, error_action = \
                        self.sess.run(
                            [loss_function, cross_entropy_classes,
                             cross_entropy_action],
                            feed_dict={
                                self.input["input_video"]: forgd_samples,
                                self.input["input_background_video"]:
                                    backd_samples,
                                self.input["input_label"]: labels})
                else:
                    # Last validation batch also fetches the loss summary
                    # so exactly one summary is written per evaluation.
                    loss_s, loss_test_val, error_classes, error_action = \
                        self.sess.run(
                            [loss_summary, loss_function,
                             cross_entropy_classes, cross_entropy_action],
                            feed_dict={
                                self.input["input_video"]: forgd_samples,
                                self.input["input_background_video"]:
                                    backd_samples,
                                self.input["input_label"]: labels})
                test_loss_list.append(loss_test_val)
                error_classes_list.append(error_classes)
                error_action_list.append(error_action)

            loss_test_val = np.mean(test_loss_list)
            # NOTE(review): "best_test_lost" is a typo for "best_test_loss",
            # but the key is persisted by save_train_info/load_train_info,
            # so renaming it would break previously saved train-info files.
            if loss_test_val < train_info["best_test_lost"]:
                train_info["best_test_lost"] = loss_test_val
                self.saver.save(self.sess, model_path,
                                global_step=train_info["step"])
            print('Step %i: validation loss: %f,'
                  ' best validation loss: %f, classes loss: %f, '
                  'action loss: %f'
                  % (train_info["step"], loss_test_val,
                     train_info["best_test_lost"],
                     np.mean(error_classes_list),
                     np.mean(error_action_list)))
            test_writer.add_summary(loss_s, train_info["step"])
            test_writer.flush()
            self.save_train_info(train_info)

        train_info["step"] += 1

    self.close_session()
args = parser.parse_args() frame_size = args.frame_size batch_size = args.batch_size height = args.height width = args.width tfrecord_path = args.tfrecord_path forgd_video_dir = args.foreground_video_dir labels = get_labels(forgd_video_dir) batch_generator = BatchGenerator("test", None, tfrecord_path, frame_size, height, width, 1) while True: # batch_forgd, batch_backd, batch_labels = \ # batch_generator.get_next() batch_forgd, batch_labels = batch_generator.get_next() action, probs = model.predict(batch_forgd[0]) forgd_frames = batch_forgd[-1] label = batch_labels[-1] if label[0] == 1: label = labels[np.argmax(label[1:])] else: continue # label = "Background" top_classes = get_top_classes(probs, labels) print(label in top_classes, label, top_classes) for i in range(forgd_frames.shape[0]):