def val_epoch(epoch, data_loader, model, opt, epoch_logger, val_dataset):
    """Run one validation pass, print per-batch metrics and log the epoch summary.

    For every batch the function:

    1. scans ``val_opt['init_scale_num']`` window scales centred on the clip
       midpoint and keeps, per sample, the best-scoring half-window
       ``max_mp`` that satisfies the length constraints;
    2. refines ``max_mp`` with four more forward passes, blending each new
       prediction in with weight ``val_opt['merge_w']``;
    3. seeds ``2 ** val_opt['merge_level']`` tracked windows per sample and
       refines them level by level through ``action_step``;
    4. converts the final window lengths into a rounded count per sample and
       compares it with ``label_counts``.

    Metrics accumulated per sample:
      * MAE  - ``abs(count - label) / label``
      * MAEP / MAEN - the same value, recorded only when the count is above /
        below the label
      * OBOA - 1.0 when ``abs(count - label) <= 1``, else 0.0

    Args:
        epoch: Epoch identifier, used only for printing and logging.
        data_loader: Iterable yielding 5-tuples
            ``(sample_inputs, _, _, label_counts, sample_len)``; ``len()`` is
            called on it for the progress print.
        model: Callable returning four outputs
            ``(pred_cls_1, pred_box_1, pred_cls_2, pred_box_2)``; it is put
            into eval mode.
        opt: Options object. This function reads ``basic_duration``,
            ``sample_size``, ``anchors``, ``n_classes``, ``l_context_ratio``
            and ``r_context_ratio``, and passes ``opt`` on to
            ``update_inputs_2stream`` and ``action_step``.
        epoch_logger: Object with a ``log(dict)`` method.
        val_dataset: Dataset name. ``'ucf_aug'``, ``'yt_seg'`` and ``'quva'``
            are special-cased.

    Returns:
        The running average of the per-sample MAE (``maes.avg``).

    Side effects:
        Writes ``'min_cycles'`` (always) and ``'merge_w'`` (for ``'yt_seg'``)
        into the module-level ``val_opt`` dict. Input tensors are moved to
        the GPU with ``.cuda()``.
    """
    print('eval at epoch {}'.format(epoch))

    # NOTE(review): these writes go to the shared ``val_opt`` dict and persist
    # after the call; in particular 'merge_w' stays at 0.1 for any later call
    # with a different dataset. Confirm that this is intended.
    if val_dataset == 'ucf_aug':
        val_opt['min_cycles'] = 2
    else:
        val_opt['min_cycles'] = 4

    if val_dataset == 'yt_seg':
        val_opt['merge_w'] = 0.1

    model.eval()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    maes = AverageMeter()
    maeps = AverageMeter()
    maens = AverageMeter()
    oboas = AverageMeter()

    end_time = time.time()
    # Diagnostic accumulators; filled below but not consumed in this function.
    counts_oboa = []
    counts_all = []
    maes_all = []
    oboas_all = []

    merge_level = val_opt['merge_level']
    level_pow = pow(2, merge_level)
    last = merge_level - 1
    # These datasets average inverse window lengths instead of lengths.
    inverse_lengths = val_dataset in ('quva', 'yt_seg', 'ucf_aug')

    # Builtin int/float replace the np.int/np.float aliases, which were
    # removed in NumPy 1.24; the resulting dtypes are identical.
    cycle_length_dataset = np.zeros([150, level_pow], dtype=float)
    cycle_length_dataset_ptr = 0

    def _blank_inputs(n):
        """Allocate a zeroed network input batch of ``n`` clips on the GPU."""
        return torch.zeros(
            [n, 3, opt.basic_duration, opt.sample_size, opt.sample_size],
            dtype=torch.float).cuda()

    for i, (sample_inputs, _, _, label_counts,
            sample_len) in enumerate(data_loader):

        if (val_opt['iter_terminal_num'] != -1
                and i > val_opt['iter_terminal_num']):
            break

        data_time.update(time.time() - end_time)
        end_time = time.time()

        batch_size = sample_inputs.size(0)

        # Grow the diagnostic buffer instead of overflowing its initial
        # 150 rows when more samples are evaluated.
        rows_needed = cycle_length_dataset_ptr + batch_size
        rows_have = cycle_length_dataset.shape[0]
        if rows_needed > rows_have:
            extra_rows = max(rows_needed, 2 * rows_have) - rows_have
            cycle_length_dataset = np.concatenate(
                [cycle_length_dataset,
                 np.zeros([extra_rows, level_pow], dtype=float)],
                axis=0)

        # targets init
        label_counts = label_counts.numpy()
        sample_len = sample_len.numpy()

        # track state init: middle point plus left/right boundary estimates,
        # each boundary kept in an "_l" and an "_r" variant that are averaged
        # when the state is read back.
        state_shape = [batch_size, merge_level, level_pow]
        mp = np.zeros(state_shape, dtype=int)
        lp_l = np.zeros(state_shape, dtype=int)
        lp_r = np.zeros(state_shape, dtype=int)
        rp_l = np.zeros(state_shape, dtype=int)
        rp_r = np.zeros(state_shape, dtype=int)

        load_lp = np.zeros(batch_size, dtype=int)
        load_mp = np.zeros(batch_size, dtype=int)
        load_rp = np.zeros(batch_size, dtype=int)
        save_lp = np.zeros(batch_size, dtype=int)
        save_mp = np.zeros(batch_size, dtype=int)
        save_rp = np.zeros(batch_size, dtype=int)

        load_ls = np.zeros(batch_size, dtype=float)
        load_rs = np.zeros(batch_size, dtype=float)
        save_ls = np.zeros(batch_size, dtype=float)
        save_rs = np.zeros(batch_size, dtype=float)

        counts = np.zeros(batch_size, dtype=float)

        # get the first estimation
        max_mp = np.zeros(batch_size, dtype=int)
        max_score = np.full(batch_size, -1e6, dtype=float)

        # Clamp the denominator so a single-scale scan does not divide by zero.
        scale_denom = max(val_opt['init_scale_num'] - 1, 1)
        for k in range(0, val_opt['init_scale_num']):
            powers_level = (val_opt['max_scale'] /
                            val_opt['min_scale'])**(float(k) / scale_denom)
            inputs = _blank_inputs(batch_size)

            for j in range(0, batch_size):
                mp_k = sample_len[j] * val_opt['min_scale'] * powers_level
                mid_pt = sample_len[j] / 2
                inputs[j], _ = update_inputs_2stream(
                    sample_inputs[j],
                    [mid_pt - mp_k, mid_pt, mid_pt + mp_k + 1], sample_len[j],
                    opt)

            # Evaluation only: no autograd graph is needed.
            with torch.no_grad():
                pred_cls, pred_box, _, _ = model(inputs)
            pred_box = torch.clamp(pred_box, min=-0.5, max=0.5)
            # Softmax is computed once per forward pass, not per sample/anchor.
            cls_prob = F.softmax(pred_cls, dim=1)

            for j in range(0, batch_size):
                # Only anchor index 3 is evaluated here.
                for p in range(3, 4):
                    box_exp = math.exp(pred_box[j][p])
                    pred_seg = box_exp * opt.anchors[p]
                    score = cls_prob[j][1][p]
                    mp_k = (sample_len[j] * val_opt['min_scale'] *
                            powers_level * pred_seg)
                    if (score > max_score[j] and mp_k >= 4
                            and mp_k < sample_len[j] / val_opt['min_cycles']):
                        max_score[j], max_mp[j] = score, mp_k

        # Four refinement passes: re-crop around the midpoint with the current
        # estimate and blend the new prediction in with weight 'merge_w'.
        for _refine in range(0, 4):
            inputs = _blank_inputs(batch_size)
            for j in range(0, batch_size):
                mp_k = max_mp[j]
                mid_pt = sample_len[j] / 2
                inputs[j], _ = update_inputs_2stream(
                    sample_inputs[j],
                    [mid_pt - mp_k, mid_pt, mid_pt + mp_k + 1], sample_len[j],
                    opt)

            with torch.no_grad():
                pred_cls, pred_box, _, _ = model(inputs)
            pred_box = torch.clamp(pred_box, min=-0.5, max=0.5)
            cls_prob = F.softmax(pred_cls, dim=1)

            for j in range(0, batch_size):
                max_score[j] = -1e6
                tmp = max_mp[j]
                for p in range(3, 4):
                    box_exp = math.exp(pred_box[j][p])
                    pred_seg = box_exp * opt.anchors[p]
                    score = cls_prob[j][1][p]
                    mp_k = tmp * pred_seg
                    if (score > max_score[j] and mp_k >= 4
                            and mp_k < sample_len[j] / val_opt['min_cycles']):
                        max_score[j], max_mp[j] = score, round(
                            float(max_mp[j] * (1 - val_opt['merge_w'])) +
                            float(mp_k * val_opt['merge_w']))

        # Seed level 0: evenly spaced middle points, each with a window of
        # half-width max_mp on either side.
        for j in range(0, batch_size):
            for l2 in range(0, level_pow):
                mp[j, 0, l2] = int(
                    float(sample_len[j]) / float(level_pow + 1) * (l2 + 0.5))
                lp_l[j, 0, l2] = mp[j, 0, l2] - max_mp[j]
                rp_l[j, 0, l2] = mp[j, 0, l2] + max_mp[j] + 1
                lp_r[j, 0, l2] = lp_l[j, 0, l2]
                rp_r[j, 0, l2] = rp_l[j, 0, l2]

        total_steps = 0
        for l1 in range(1, merge_level):

            steps = pow(2, merge_level - l1 - 1)

            # Number of refinement iterations per node depends on level only.
            if l1 == 1:
                iters = 4
            elif l1 == 2:
                iters = 2
            else:
                iters = 1

            pos = -steps
            for l2 in range(0, pow(2, l1)):
                pos = pos + 2 * steps

                for l3 in range(0, iters):

                    total_steps = total_steps + 1

                    inputs = _blank_inputs(batch_size)
                    # network input initialization
                    for j in range(0, batch_size):

                        if l3 == 0:
                            # First iteration: read the state of the previous
                            # level, averaging the "_l" and "_r" variants.
                            load_mp[j] = mp[j, l1 - 1, pos]
                            load_lp[j] = round(
                                float(lp_l[j, l1 - 1, pos] +
                                      lp_r[j, l1 - 1, pos]) / 2)
                            load_rp[j] = round(
                                float(rp_l[j, l1 - 1, pos] +
                                      rp_r[j, l1 - 1, pos]) / 2)
                        else:
                            # Later iterations: blend the last step's result
                            # into the previously loaded boundaries.
                            load_mp[j] = save_mp[j]
                            load_lp[j] = round(
                                float(save_lp[j]) * val_opt['merge_w'] +
                                float(load_lp[j]) * (1.0 - val_opt['merge_w']))
                            load_rp[j] = round(
                                float(save_rp[j]) * val_opt['merge_w'] +
                                float(load_rp[j]) * (1.0 - val_opt['merge_w']))

                        inputs[j], _ = update_inputs_2stream(
                            sample_inputs[j],
                            [load_lp[j], load_mp[j], load_rp[j]],
                            sample_len[j], opt)

                    # do the forward (Variable wrapping is unnecessary on the
                    # PyTorch versions that accept the dtype= keyword above)
                    with torch.no_grad():
                        pred_cls_1, pred_box_1, pred_cls_2, pred_box_2 = model(
                            inputs)

                    pred_box_1 = torch.clamp(pred_box_1, min=-0.5, max=0.5)
                    pred_box_2 = torch.clamp(pred_box_2, min=-0.5, max=0.5)
                    cls_prob_1 = F.softmax(pred_cls_1, dim=1)
                    cls_prob_2 = F.softmax(pred_cls_2, dim=1)

                    # track state update
                    for j in range(0, batch_size):

                        # Highest-scoring anchor of the first head.
                        best_score, action_1 = -1e6, -1
                        for k in range(0, opt.n_classes):
                            box_exp = math.exp(pred_box_1[j][k])
                            pred_seg = box_exp * opt.anchors[k]
                            score = cls_prob_1[j][1][k]
                            if score > best_score:
                                best_score, action_1 = score, pred_seg
                                save_ls[j] = score

                        # Highest-scoring anchor of the second head.
                        best_score, action_2 = -1e6, -1
                        for k in range(0, opt.n_classes):
                            box_exp = math.exp(pred_box_2[j][k])
                            pred_seg = box_exp * opt.anchors[k]
                            score = cls_prob_2[j][1][k]
                            if score > best_score:
                                best_score, action_2 = score, pred_seg
                                save_rs[j] = score

                        if val_opt['abandon_second_box'] == True:
                            action_2 = action_1
                            save_rs[j] = save_ls[j]

                        new_state, done_flag, fail_flag = action_step(
                            [load_lp[j], load_mp[j], load_rp[j]], action_1,
                            action_2, 0, sample_len[j], opt, val_dataset)
                        save_lp[j], save_mp[j], save_rp[j] = new_state

                        # On failure keep the boundaries that were fed in.
                        if fail_flag:
                            save_lp[j] = load_lp[j]
                            save_rp[j] = load_rp[j]

                # Propagate the refined boundaries of node ``pos`` to the
                # ``steps`` slots on each side of it at level l1.
                for j in range(0, batch_size):

                    l_segments = float(
                        save_lp[j]) * val_opt['merge_w'] + float(
                            load_lp[j]) * (1.0 - val_opt['merge_w'])
                    r_segments = float(
                        save_rp[j]) * val_opt['merge_w'] + float(
                            load_rp[j]) * (1.0 - val_opt['merge_w'])

                    l_offset = l_segments - mp[j, l1 - 1, pos]
                    r_offset = r_segments - mp[j, l1 - 1, pos]

                    for s in range(-steps, 0):
                        slot = pos + s
                        mp[j, l1, slot] = mp[j, l1 - 1, slot]
                        lp_r[j, l1, slot] = mp[j, l1 - 1, slot] + l_offset
                        rp_r[j, l1, slot] = mp[j, l1 - 1, slot] + r_offset

                        if l1 <= 2 or l1 == merge_level - 1 or l2 == 0:
                            lp_l[j, l1, slot] = lp_r[j, l1, slot]
                            rp_l[j, l1, slot] = rp_r[j, l1, slot]
                        else:
                            lp_l[j, l1, slot] = lp_l[j, l1 - 1, slot]
                            rp_l[j, l1, slot] = rp_l[j, l1 - 1, slot]

                    for s in range(0, steps):
                        slot = pos + s
                        mp[j, l1, slot] = mp[j, l1 - 1, slot]
                        lp_l[j, l1, slot] = mp[j, l1 - 1, slot] + l_offset
                        rp_l[j, l1, slot] = mp[j, l1 - 1, slot] + r_offset

                        if (l1 <= 2 or l1 == merge_level - 1
                                or l2 == pow(2, l1) - 1):
                            lp_r[j, l1, slot] = lp_l[j, l1, slot]
                            rp_r[j, l1, slot] = rp_l[j, l1, slot]
                        else:
                            lp_r[j, l1, slot] = lp_r[j, l1 - 1, slot]
                            rp_r[j, l1, slot] = rp_r[j, l1 - 1, slot]

        # Turn the final-level windows into one count per sample.
        for j in range(0, batch_size):
            left_avg = AverageMeter()
            right_avg = AverageMeter()
            for k in range(0, level_pow):
                lp_avg = round(float(lp_l[j, last, k] + lp_r[j, last, k]) / 2)
                rp_avg = round(float(rp_l[j, last, k] + rp_r[j, last, k]) / 2)

                left_len = float(mp[j, last, k] - lp_avg + 1)
                right_len = float(rp_avg - mp[j, last, k])

                pos1 = int(lp_avg -
                           (mp[j, last, k] - lp_avg + 1) * opt.l_context_ratio)
                pos2 = int(rp_avg + (rp_avg - mp[j, last, k] + 0) *
                           (opt.r_context_ratio - 1))
                # Only windows whose context stays inside the clip contribute.
                if pos1 >= 0 and pos2 < sample_len[j]:
                    if inverse_lengths:
                        left_avg.update(1.0 / left_len)
                        right_avg.update(1.0 / right_len)
                    else:
                        left_avg.update(left_len)
                        right_avg.update(right_len)

                # NOTE(review): raises ZeroDivisionError if either length is
                # zero; unchanged from the original behaviour.
                cycle_length_dataset[cycle_length_dataset_ptr + j,
                                     k] = 1.0 / left_len + 1.0 / right_len

            if left_avg.avg == 0 or right_avg.avg == 0:
                # No usable window: fall back to the initial estimate.
                counts[j] = float(sample_len[j]) / float(max_mp[j] + 1)
            else:
                if inverse_lengths:
                    counts[j] = float(
                        sample_len[j]) * float(left_avg.sum * 0.5 +
                                               right_avg.sum * 0.5) / float(
                                                   left_avg.count)
                else:
                    counts[j] = float(sample_len[j] +
                                      1e-6) / float(left_avg.avg * 0.5 +
                                                    right_avg.avg * 0.5)

            counts[j] = float(round(counts[j]))

            counts_all.append(counts[j])

            # Relative absolute error; assumes label_counts[j] is non-zero.
            mae = float(abs(counts[j] - label_counts[j])) / float(
                label_counts[j])
            if mae > 0.33:
                counts_oboa.append(i)

            if abs(counts[j] - label_counts[j]) > 1:
                oboa = 0.0
            else:
                oboa = 1.0

            maes_all.append(mae)
            oboas_all.append(oboa)

            maes.update(mae)
            if counts[j] > label_counts[j]:
                maeps.update(mae)
            elif counts[j] < label_counts[j]:
                maens.update(mae)
            oboas.update(oboa)

        batch_time.update(time.time() - end_time)
        cycle_length_dataset_ptr = cycle_length_dataset_ptr + batch_size

        print('Epoch: [{0}][{1}/{2}]\t'
              'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
              'OBOA {oboa.val:.4f} ({oboa.avg:.4f})\t'
              'MAE {maes.val:.4f} ({maes.avg:.4f})\t'
              'MAEstd {maestd:.4f}\t'
              'MAEP {maeps.val:.4f} ({maeps.avg:.4f})\t'
              'MAEN {maens.val:.4f} ({maens.avg:.4f})\t'
              'total_steps {total_steps: d}\n'.format(epoch,
                                                      i + 1,
                                                      len(data_loader),
                                                      batch_time=batch_time,
                                                      oboa=oboas,
                                                      maes=maes,
                                                      maestd=maes.std(),
                                                      maeps=maeps,
                                                      maens=maens,
                                                      total_steps=total_steps))

    # np.save(val_dataset, cycle_length_dataset)

    epoch_logger.log({
        'epoch': epoch,
        'OBOA': oboas.avg,
        'MAE': maes.avg,
        'MAE_std': maes.std(),
        'MAEP': maeps.avg,
        'MAEN': maens.avg,
    })
    return maes.avg