################
# Training #####
################
if args.action == "train":
    # Convert the temporal granularity `beta` into a number of frames.
    # NOTE(review): 30 is presumably the video frame rate (fps), which would
    # make `beta` a duration in seconds -- TODO confirm against the dataset.
    frames_per_second = 30
    beta_frame = int(math.ceil(float(args.beta) * frames_per_second))
    # S_enc / S_ant are integer multiples of beta; express them in frames.
    S_enc_frame = (int(args.S_enc) * beta_frame)
    S_ant_frame = (int(args.S_ant) * beta_frame)

    if args.model == "rnn":
        model = ModelRNN(nClasses, args.rnn_size, args.max_seq_sz, args.num_layers)
        batch_gen = RNN_batch_generator(nClasses, args.n_iterations, args.max_seq_sz, actions_dict,
                                        args.alpha, S_enc_frame, S_ant_frame, beta_frame)
    elif args.model == "cnn":
        model = ModelCNN(args.nRows, nClasses)
        batch_gen = CNN_batch_generator(args.nRows, nClasses, actions_dict)
    else:
        # Fail fast: an unknown model type used to leave `model` as None and
        # crash later with an unhelpful AttributeError on `model.train`.
        raise ValueError("Unknown model type: %s" % args.model)

    batch_gen.read_data(list_of_videos)
    with tf.Session() as sess:
        model.train(sess, args.model_save_path, batch_gen, args.nEpochs, args.save_freq, args.batch_size)

##################
# Prediction #####
##################
elif args.action == "predict":
    # Prediction horizons; presumably fractions of the video length -- TODO confirm.
    pred_percentages = [.1, .2, .3, .5]
actions_dict = read_mapping_dict(args.mapping_file)
nClasses = len(actions_dict)

# Read the list of video names; the context manager guarantees the file
# handle is closed (previously it was opened and never closed).
# The [1:-1] slice drops the first line and the last element. The last element
# is presumably the empty string left by a trailing newline, and the first line
# presumably a header -- NOTE(review): confirm against the list-file format.
with open(args.vid_list_file, 'r') as file_ptr:
    list_of_videos = file_ptr.read().split('\n')[1:-1]

################
# Training #####
################
if args.action == "train":
    if args.model == "rnn":
        model = ModelRNN(nClasses, args.rnn_size, args.max_seq_sz, args.num_layers)
        batch_gen = RNN_batch_generator(nClasses, args.n_iterations, args.max_seq_sz, actions_dict, args.alpha)
    elif args.model == "cnn":
        model = ModelCNN(args.nRows, nClasses)
        batch_gen = CNN_batch_generator(args.nRows, nClasses, actions_dict)
    else:
        # Fail fast: an unknown model type used to leave `model` as None and
        # crash later with an unhelpful AttributeError on `model.train`.
        raise ValueError("Unknown model type: %s" % args.model)

    batch_gen.read_data(list_of_videos)
    with tf.Session() as sess:
        model.train(sess, args.model_save_path, batch_gen, args.nEpochs, args.save_freq, args.batch_size)

##################
# Prediction #####
##################
elif args.action == "predict":
    # Prediction horizons and observed portions; presumably fractions of the
    # video length -- TODO confirm.
    pred_percentages = [.1, .2, .3, .5]
    obs_percentages = [.2, .3, .5]
    # Checkpoint written for epoch `args.eval_epoch` under the save directory.
    model_restore_path = args.model_save_path+"/epoch-"+str(args.eval_epoch)+"/model.ckpt"