예제 #1
0
def eval():
    """Evaluate the current checkpoint on the validation set.

    Builds a fresh graph and session, restores the latest checkpoint from
    ``args.model_dir`` (initializing new variables when no checkpoint
    exists), then streams one epoch of the validation pipeline through the
    model, accumulating per-batch cate2 PERR / GAP and total loss.

    Returns:
        float: average cate2 GAP over all validation batches, or 0.0 when
        the validation pipeline produced no batches (the original code
        raised ZeroDivisionError in that case).

    NOTE(review): the name shadows the builtin ``eval``; kept because
    ``train()`` calls it under this name.
    """
    g1 = tf.Graph()
    with g1.as_default():
        # Prefer GPU when available, otherwise fall back to CPU.
        config = tf.ConfigProto(
            allow_soft_placement=True,
            log_device_placement=False)
        config.gpu_options.allow_growth = True
        with tf.Session(config=config) as sess:

            data_object = DataProcessing(args)

            valid_batch_example = data_object.input_frame_data(
                frame_path=args.valid_path, batch_size=160, num_epoch=1)
            model = creat_model(sess, args, isTraining=False)  # build the model graph

            ckpt = tf.train.get_checkpoint_state(args.model_dir)

            # Restore checkpointed weights when available; otherwise start
            # from freshly initialized variables.
            if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
                print("valid step: Reading model parameters from {}".format(
                    ckpt.model_checkpoint_path))
                saver = tf.train.Saver(tf.global_variables())
                # Assign every stored tensor value to its variable.
                saver.restore(sess=sess, save_path=ckpt.model_checkpoint_path)
                print(sess.run(model.global_step))
            else:
                if not os.path.exists(args.model_dir):
                    os.makedirs(args.model_dir)
                print("valid step: Created new model parameters..")
                sess.run(tf.global_variables_initializer())

            valid_cate2_perr_list = []
            valid_cate2_gap_list = []
            valid_total_loss_list = []
            # Defined up-front so the function always returns a value,
            # even when the pipeline yields zero batches.
            valid_cate2_gap_aver_loss = 0.0

            try:
                # The input pipeline signals exhaustion via OutOfRangeError.
                while True:
                    context_parsed, sequence_parsed = sess.run(
                        valid_batch_example)
                    # Multi-hot label rows -> lists of positive label ids.
                    batch_origin_labels = [
                        np.nonzero(row)[0].tolist()
                        for row in context_parsed['labels']
                    ]
                    cate1_multilabel, cate2_multilabel, batch_origin_cate1, batch_origin_cate2 = data_object.get_cate1_cate2_label(
                        batch_origin_labels)
                    batch_vid_name = np.asarray(context_parsed['id'])
                    batch_cate1_label_multiHot = np.asarray(
                        cate1_multilabel)  # (batch, cate1_nums)
                    batch_cate2_label_multiHot = np.asarray(
                        cate2_multilabel)  # (batch, cate2_nums)
                    batch_rgb_fea_float_list = np.asarray(
                        sequence_parsed['rgb'])  # (batch, max_frame, 1024)
                    batch_audio_fea_float_list = np.asarray(
                        sequence_parsed['audio'])  # (batch, max_frame, 128)
                    batch_num_audio_rgb_true_frame = np.asarray(
                        context_parsed['num_audio_rgb_true_frame'])

                    # keep_prob = 1.0: dropout disabled during evaluation.
                    feed = {
                        model.input_video_vidName: batch_vid_name,
                        model.input_cate1_multilabel: batch_cate1_label_multiHot,
                        model.input_cate2_multilabel: batch_cate2_label_multiHot,
                        model.input_video_RGB_feature: batch_rgb_fea_float_list,
                        model.input_video_Audio_feature: batch_audio_fea_float_list,
                        model.input_rgb_audio_true_frame: batch_num_audio_rgb_true_frame,
                        model.dropout_keep_prob: 1.0,
                    }

                    cate2_probs, total_loss = sess.run(
                        [model.cate2_probs, model.total_loss], feed)

                    cate2_perr = eval_util.calculate_precision_at_equal_recall_rate(
                        cate2_probs, batch_cate2_label_multiHot)
                    cate2_gap = eval_util.calculate_gap(
                        cate2_probs, batch_cate2_label_multiHot)

                    valid_cate2_perr_list.append(cate2_perr)
                    valid_cate2_gap_list.append(cate2_gap)
                    valid_total_loss_list.append(total_loss)

            except tf.errors.OutOfRangeError:
                print("end!")

                # Guard against an empty validation set: the original code
                # divided by len(...) of empty lists and crashed.
                if valid_cate2_gap_list:
                    valid_cate2_perr_aver_loss = float(
                        np.mean(valid_cate2_perr_list))
                    valid_cate2_gap_aver_loss = float(
                        np.mean(valid_cate2_gap_list))
                    valid_total_valid_aver_loss = float(
                        np.mean(valid_total_loss_list))

                    print('total valid cate2_perr_aver_loss: %0.4f' %
                          valid_cate2_perr_aver_loss)
                    print('total valid cate2_gap_aver_loss: %0.4f' %
                          valid_cate2_gap_aver_loss)
                    print('***********************')
                    print('total valid total_valid_aver_loss: %0.4f' %
                          valid_total_valid_aver_loss)

            return valid_cate2_gap_aver_loss
예제 #2
0
def predict_test():
    """Run inference over the test set and dump top-20 cate2 predictions.

    Writes a CSV file at ``args.predict_out_path`` with the header
    ``VideoId,LabelConfidencePairs``; each subsequent line is a video id
    followed by space-separated ``label_index probability`` pairs.
    """
    # Start from a clean output file carrying only the CSV header.
    if os.path.exists(args.predict_out_path):
        remove_file(args.predict_out_path)

    with open(args.predict_out_path, 'a+') as fout1:
        fout1.write("VideoId,LabelConfidencePairs" + '\n')

    data_object = DataProcessing(args)
    test_batch_example = data_object.input_frame_data(
        frame_path=args.test_path, batch_size=200, num_epoch=1)

    # Prefer GPU when available, otherwise fall back to CPU.
    config = tf.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=False)
    config.gpu_options.allow_growth = True

    with tf.Session(config=config) as sess:
        model = creat_model(sess, args, isTraining=False)  # build the model graph

        step_counter = 0
        try:
            # The input pipeline signals exhaustion via OutOfRangeError.
            while True:
                context_parsed, sequence_parsed = sess.run(test_batch_example)
                # Multi-hot label rows -> lists of positive label ids.
                origin_labels = [
                    np.nonzero(row)[0].tolist()
                    for row in context_parsed['labels']
                ]
                (cate1_multilabel, cate2_multilabel,
                 _origin_cate1, _origin_cate2) = data_object.get_cate1_cate2_label(
                     origin_labels)
                vid_names = np.asarray(context_parsed['id'])
                cate1_multi_hot = np.asarray(cate1_multilabel)  # (batch, cate1_nums)
                cate2_multi_hot = np.asarray(cate2_multilabel)  # (batch, cate2_nums)
                rgb_features = np.asarray(
                    sequence_parsed['rgb'])  # (batch, max_frame, 1024)
                audio_features = np.asarray(
                    sequence_parsed['audio'])  # (batch, max_frame, 128)
                true_frame_counts = np.asarray(
                    context_parsed['num_audio_rgb_true_frame'])

                # keep_prob = 1.0: dropout disabled for inference.
                feed = {
                    model.input_video_vidName: vid_names,
                    model.input_cate1_multilabel: cate1_multi_hot,
                    model.input_cate2_multilabel: cate2_multi_hot,
                    model.input_video_RGB_feature: rgb_features,
                    model.input_video_Audio_feature: audio_features,
                    model.input_rgb_audio_true_frame: true_frame_counts,
                    model.dropout_keep_prob: 1.0,
                }

                top20_index, top20_value = sess.run(
                    [model.cate2_top20_probs_index,
                     model.cate2_top20_probs_value], feed)

                step_counter += 1
                print("validSet step_num: ", step_counter)

                batch_index = [list(row) for row in top20_index]
                batch_value = [list(row) for row in top20_value]
                batch_vid = [list(name) for name in vid_names]

                # Append one CSV line per video in this batch.
                with open(args.predict_out_path, 'a+') as fout2:
                    for vid, index, value in zip(batch_vid, batch_index,
                                                 batch_value):
                        pairs = []
                        for label_idx, prob in zip(index, value):
                            pairs.append(str(label_idx))
                            pairs.append(str("%.6f" % prob))
                        line = (vid[0].decode(encoding='utf-8') + ',' +
                                ' '.join(pairs))
                        fout2.write(line + '\n')

        except tf.errors.OutOfRangeError:
            print("train processing is finished!")
예제 #3
0
def train():
    '''
        Train the model on the training set, periodically logging averaged
        PERR/GAP metrics and running validation via eval(); saves a
        checkpoint whenever the validation cate2 GAP improves.
        :return: None (the loop ends when the input pipeline is exhausted)
    '''
    # Best validation cate2 GAP seen so far; start far below any real value.
    valid_max_accuracy = -9999
    # Prefer GPU when available, otherwise fall back to CPU.
    config = tf.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=False)
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:

        data_object = DataProcessing(args)
        train_batch_example = data_object.input_frame_data(
            frame_path=args.train_path,
            batch_size=args.batch_size,
            num_epoch=args.epoch)

        model = creat_model(sess, args, isTraining=True)  # build the model graph
        saver = tf.train.Saver(tf.global_variables(),
                               max_to_keep=3)  # keep only the 3 most recent checkpoints
        print("Begin training..")

        # Per-batch metrics accumulated between report_freq logs.
        train_cate1_perr_list = []
        train_cate1_gap_list = []
        train_cate2_perr_list = []
        train_cate2_gap_list = []
        train_total_loss_list = []

        try:
            # The input pipeline signals exhaustion via OutOfRangeError.
            while True:
                context_parsed, sequence_parsed = sess.run(train_batch_example)

                # Multi-hot label rows -> lists of positive label ids.
                batch_origin_labels = [
                    np.nonzero(row)[0].tolist()
                    for row in context_parsed['labels']
                ]
                cate1_multilabel, cate2_multilabel, batch_origin_cate1, batch_origin_cate2 = data_object.get_cate1_cate2_label(
                    batch_origin_labels)
                batch_vid_name = np.asarray(context_parsed['id'])
                batch_num_audio_rgb_true_frame = np.asarray(
                    context_parsed['num_audio_rgb_true_frame'])
                batch_cate1_label_multiHot = np.asarray(
                    cate1_multilabel)  # batch,cate1_nums
                batch_cate2_label_multiHot = np.asarray(
                    cate2_multilabel)  # batch,cate2_nums
                batch_rgb_fea_float_list = np.asarray(
                    sequence_parsed['rgb'])  # batch,max_frame,1024
                batch_audio_fea_float_list = np.asarray(
                    sequence_parsed['audio'])  # batch,max_frame,128? other blocks say 128 — confirm

                # keep_prob = 0.5: dropout enabled during training.
                feed = dict(
                    zip([
                        model.input_video_vidName,
                        model.input_cate1_multilabel,
                        model.input_cate2_multilabel,
                        model.input_video_RGB_feature,
                        model.input_video_Audio_feature,
                        model.input_rgb_audio_true_frame,
                        model.dropout_keep_prob
                    ], [
                        batch_vid_name, batch_cate1_label_multiHot,
                        batch_cate2_label_multiHot, batch_rgb_fea_float_list,
                        batch_audio_fea_float_list,
                        batch_num_audio_rgb_true_frame, 0.5
                    ]))

                # One optimization step; also fetch probs/loss for metrics.
                cate1_probs, cate2_probs, total_loss, _ = sess.run([
                    model.cate1_probs, model.cate2_probs, model.total_loss,
                    model.optimizer
                ], feed)

                train_cate1_perr = eval_util.calculate_precision_at_equal_recall_rate(
                    cate1_probs, batch_cate1_label_multiHot)
                train_cate1_gap = eval_util.calculate_gap(
                    cate1_probs, batch_cate1_label_multiHot)
                train_cate2_perr = eval_util.calculate_precision_at_equal_recall_rate(
                    cate2_probs, batch_cate2_label_multiHot)
                train_cate2_gap = eval_util.calculate_gap(
                    cate2_probs, batch_cate2_label_multiHot)

                train_cate1_perr_list.append(train_cate1_perr)
                train_cate1_gap_list.append(train_cate1_gap)
                train_cate2_perr_list.append(train_cate2_perr)
                train_cate2_gap_list.append(train_cate2_gap)

                train_total_loss_list.append(total_loss)

                # Log window-averaged metrics every report_freq steps, then
                # reset the accumulators.
                if model.global_step.eval() % args.report_freq == 0:
                    print("report_freq: ", args.report_freq)

                    print(
                        'cate1_train: Step:{} ; aver_train_cate1_perr:{} ; aver_train_cate1_gap_list:{} ; aver_total_loss:{}'
                        .format(
                            model.global_step.eval(),
                            1.0 * np.sum(train_cate1_perr_list) /
                            len(train_cate1_perr_list),
                            1.0 * np.sum(train_cate1_gap_list) /
                            len(train_cate1_gap_list),
                            1.0 * np.sum(train_total_loss_list) /
                            len(train_total_loss_list)))

                    print(
                        'cate2_train: Step:{} ; aver_train_cate2_perr:{} ; aver_train_cate2_gap_list:{} ; aver_total_loss:{}'
                        .format(
                            model.global_step.eval(),
                            1.0 * np.sum(train_cate2_perr_list) /
                            len(train_cate2_perr_list),
                            1.0 * np.sum(train_cate2_gap_list) /
                            len(train_cate2_gap_list),
                            1.0 * np.sum(train_total_loss_list) /
                            len(train_total_loss_list)))

                    train_cate1_perr_list = []
                    train_cate1_gap_list = []
                    train_cate2_perr_list = []
                    train_cate2_gap_list = []
                    train_total_loss_list = []

                if model.global_step.eval(
                ) > 1 and model.global_step.eval() % args.valid_freq == 0:
                    # Measure accuracy on the validation set.
                    print("valid infer is process  111!")
                    print('model.global_step.eval(): ',
                          model.global_step.eval())

                    valid_cate2_gap_aver_loss = eval()
                    # Checkpoint only when validation cate2 GAP improves on
                    # the best value seen so far.
                    if valid_cate2_gap_aver_loss > valid_max_accuracy:
                        print("save the model, step= : ",
                              model.global_step.eval())
                        valid_max_accuracy = valid_cate2_gap_aver_loss
                        checkpoint_path = os.path.join(args.model_dir,
                                                       'model.ckpt')
                        saver.save(sess=sess,
                                   save_path=checkpoint_path,
                                   global_step=model.global_step.eval())

        except tf.errors.OutOfRangeError:
            print("train processing is finished!")