from utils import futils, Visualizer
import glob
import os

visual_dir = '/home/xueyan/antialias-cnn/data/output/resnet101_ori/visual'
# Only files exactly one directory below visual_dir: <visual_dir>/<class>/<file>
files = glob.glob(os.path.join(visual_dir, '*/*.*'))


def group_by_class(paths):
    """Group file paths by the name of their immediate parent directory.

    ``'<root>/<class>/<file>'`` is grouped under ``'<class>'`` regardless of
    how deep ``<root>`` is (the previous hard-coded ``split('/')[8]`` only
    worked for one specific absolute root).  Keys and each value list keep
    the order in which paths appear in ``paths``.

    Args:
        paths: iterable of file path strings.

    Returns:
        dict mapping class-folder name -> list of paths in that folder.
    """
    groups = {}
    for path in paths:
        class_name = os.path.basename(os.path.dirname(path))
        groups.setdefault(class_name, []).append(path)
    return groups


class_dict = group_by_class(files)

# Write one index.html per class folder.  The page sits next to its images,
# so each entry links to (and is captioned with) the bare file name.
for key, class_files in class_dict.items():
    visualizer = Visualizer(os.path.join(visual_dir, key),
                            demo_name='index.html')
    for pth in class_files:
        visual_name = os.path.basename(pth)
        visualizer.insert(visual_name, visual_name)
    visualizer.write()
def validate_shift(val_loader, model, args, visual):
    """Save shifted center crops and top-5 prediction plots per sample.

    NOTE(review): another ``validate_shift`` is defined later in this file
    and replaces this one at import time; this version is only reachable if
    that later definition is removed or renamed.

    For every sample whose class has no folder yet under the hard-coded
    ``src_pth`` directory, this builds an ``args.size`` x ``args.size``
    center crop plus crops shifted vertically and horizontally by each value
    in ``offset``.  It saves each crop as an image and classifies all crops
    in one batch.  It saves a bar plot of each crop's top-5 softmax
    probabilities and registers every file with a per-class ``Visualizer``
    (``index.html`` in ``args.out_dir/visual/<class>``).

    Crop/plot files are named by the shift in pixels (``h_<dh>_w_<dw>``).
    The center is ``h_0_w_0``; the offsets are never 0, so shifted views can
    no longer overwrite the center image as they did when the loop index was
    used.

    Args:
        val_loader: yields ``((input, target), pth)``; batch size must be 1.
        model: classifier; switched to eval mode here.
        args: must provide ``batch_size``, ``size``, ``out_dir`` and ``gpu``.
        visual: 3-tuple ``(visual_folder, top_visualizer,
            bottom_visualizer)``; unpacked for interface compatibility only.

    Returns:
        list: ``JS_Divergence`` of the stacked predictions for each
        processed sample.
    """
    consist = []

    # ImageNet normalization, shaped to broadcast over an (H, W, 3) image.
    mean = np.array([0.485, 0.456, 0.406]).reshape(1, 1, 3)
    std = np.array([0.229, 0.224, 0.225]).reshape(1, 1, 3)

    # Only the 3-tuple contract is checked here; the visualizers are unused.
    _visual_folder, _top_visualizer, _bottom_visualizer = visual

    # Maps a class index (as a string) to a readable class name.
    with open('../../data/imagenet-vid-robust/misc/'
              'imagenet_idx_to_name.json') as idx_file:
        idx_to_cls = json.load(idx_file)

    # All crops of one sample are stacked into a single batch, so the loader
    # itself must yield one image at a time.
    assert args.batch_size == 1

    model.eval()
    visualizer_dict = {}

    # Classes that already have a folder here are skipped.  os.walk yields
    # the root itself first, hence the [1:].
    src_pth = '/home/xueyan/antialias-cnn/data/output/resnet101_ori/visual'
    walked = [entry[0].split('/')[-1] for entry in os.walk(src_pth)]
    class_names = set(walked[1:])

    offset = [-16, -10, -5, -2, -1, 1, 2, 5, 10, 16]
    half_s = args.size // 2
    # Page-relative root used for the links written into index.html.
    web_root = os.path.join('/', args.out_dir.split('/')[-1], 'visual')

    def crop(img, dh, dw):
        """Square crop of side 2*half_s, centered (dh, dw) off the middle."""
        # NOTE(review): assumes each frame is at least args.size + 32 px in
        # both dimensions; otherwise slices come out short (or empty for a
        # negative start) and the torch.cat below fails.
        h_c, w_c = img.shape[2] // 2, img.shape[3] // 2
        return img[:, :, h_c + dh - half_s:h_c + dh + half_s,
                   w_c + dw - half_s:w_c + dw + half_s]

    def save_image(view, path):
        """Undo normalization of ``view[0]`` and write it to ``path``."""
        img = view[0].permute(1, 2, 0).cpu().numpy()
        imsave(path, (img * std + mean) * 255)

    def register(visualizer, class_name, file_name, names, caps):
        """Add ``file_name`` to the class page and record its link/caption."""
        visual_pth = os.path.join(web_root, class_name, file_name)
        names.append(visual_pth)
        caps.append(file_name)
        visualizer.insert(visual_pth, file_name)

    with torch.no_grad():
        for i, ((frame, _target), pth) in enumerate(val_loader):
            # NOTE(review): assumes pth[0] has the class folder and file
            # name at '/'-components 5 and 6 — TODO confirm for the dataset.
            parts = pth[0].split('/')
            class_name, img_name = parts[5], parts[6]
            if i % 100 == 0:
                print('process ',
                      '[' + str(i) + '|' + str(len(val_loader)) + ']')
            if class_name in class_names:
                continue

            stem = os.path.splitext(img_name)[0]
            class_dir = os.path.join(args.out_dir, 'visual', class_name)

            # One Visualizer per class, written once after the loop.
            visualizer = visualizer_dict.get(class_name)
            if visualizer is None:
                visualizer = Visualizer(class_dir, demo_name='index.html')
                visualizer_dict[class_name] = visualizer
            os.makedirs(class_dir, exist_ok=True)

            # Batch order: center, then for each offset the vertical and the
            # horizontal shift.  hist_names follows the same order.
            views = [('h_0_w_0', crop(frame, 0, 0))]
            for off in offset:
                views.append(('h_' + str(off) + '_w_0', crop(frame, off, 0)))
                views.append(('h_0_w_' + str(off), crop(frame, 0, off)))

            img_visual_names, img_visual_cap, hist_names = [], [], []
            for tag, view in views:
                img_file = stem + '_a_' + tag + '.png'
                save_image(view, os.path.join(class_dir, img_file))
                register(visualizer, class_name, img_file, img_visual_names,
                         img_visual_cap)
                hist_names.append(
                    os.path.join(class_dir, stem + '_bar_' + tag + '.png'))

            batch = torch.cat([view for _, view in views], dim=0)
            if args.gpu is not None:
                batch = batch.cuda(args.gpu, non_blocking=True)

            output = torch.nn.Softmax(dim=1)(model(batch))
            consist.append(JS_Divergence(output))

            output_sort, output_argsort = torch.sort(output,
                                                     dim=1,
                                                     descending=True)
            for k, hist_name in enumerate(hist_names):
                top5_arg_name = [
                    idx_to_cls[str(x.item())] for x in output_argsort[k, :5]
                ]
                top5_prob = output_sort[k, :5].cpu().numpy()
                sns.barplot(x=top5_arg_name, y=top5_prob)
                plt.savefig(hist_name)
                plt.clf()
                register(visualizer, class_name, os.path.basename(hist_name),
                         img_visual_names, img_visual_cap)

    for class_visualizer in visualizer_dict.values():
        class_visualizer.write()
    return consist
def validate_shift(val_loader, model, args, visual):
    """Measure prediction consistency under small crop shifts and visualize.

    Only every 25th sample of ``val_loader`` is processed.  For each one, an
    ``args.size`` x ``args.size`` center crop plus crops shifted vertically
    and horizontally by each value in ``offset`` are stacked into one batch
    and classified.  Each crop and a bar plot of its top-5 softmax
    probabilities are saved under ``args.out_dir/visual/<class>/`` and
    registered with a per-class ``Visualizer`` (``index.html``).  Pages are
    written once, after all samples.  ``JS_Divergence`` of the stacked
    predictions is recorded per sample.  Up to 100 samples with the highest
    value are added to ``top_visualizer``; up to 100 with the lowest are
    added to ``bottom_visualizer``.

    Crop/plot files are named by the shift in pixels (``h_<dh>_w_<dw>``).
    The center is ``h_0_w_0``; the offsets are never 0, so shifted views can
    no longer overwrite the center image as they did when the loop index was
    used.

    Args:
        val_loader: yields ``((input, target), pth)``; batch size must be 1.
        model: classifier; switched to eval mode here.
        args: must provide ``batch_size``, ``size``, ``out_dir`` and ``gpu``.
        visual: ``(visual_folder, top_visualizer, bottom_visualizer)``;
            ``visual_folder`` is unused.

    Returns:
        list: the ``JS_Divergence`` value of every processed sample.
    """
    batch_time = AverageMeter()
    consist = []

    # ImageNet normalization, shaped to broadcast over an (H, W, 3) image.
    mean = np.array([0.485, 0.456, 0.406]).reshape(1, 1, 3)
    std = np.array([0.229, 0.224, 0.225]).reshape(1, 1, 3)

    _visual_folder, top_visualizer, bottom_visualizer = visual

    # Maps a class index (as a string) to a readable class name.
    with open('../../data/imagenet-vid-robust/misc/'
              'imagenet_idx_to_name.json') as idx_file:
        idx_to_cls = json.load(idx_file)

    # All crops of one sample are stacked into a single batch, so the loader
    # itself must yield one image at a time.
    assert args.batch_size == 1

    model.eval()

    offset = [-16, -10, -5, -2, -1, 1, 2, 5, 10, 16]
    half_s = args.size // 2
    # Page-relative root used for the links written into the HTML pages.
    web_root = os.path.join('/', args.out_dir.split('/')[-1], 'visual')
    num_ranked = 100

    def crop(img, dh, dw):
        """Square crop of side 2*half_s, centered (dh, dw) off the middle."""
        # NOTE(review): assumes each frame is at least args.size + 32 px in
        # both dimensions; otherwise slices come out short (or empty for a
        # negative start) and the torch.cat below fails.
        h_c, w_c = img.shape[2] // 2, img.shape[3] // 2
        return img[:, :, h_c + dh - half_s:h_c + dh + half_s,
                   w_c + dw - half_s:w_c + dw + half_s]

    def save_image(view, path):
        """Undo normalization of ``view[0]`` and write it to ``path``."""
        img = view[0].permute(1, 2, 0).cpu().numpy()
        imsave(path, (img * std + mean) * 255)

    def register(visualizer, class_name, file_name, names, caps):
        """Add ``file_name`` to the class page and record its link/caption."""
        visual_pth = os.path.join(web_root, class_name, file_name)
        names.append(visual_pth)
        caps.append(file_name)
        visualizer.insert(visual_pth, file_name)

    visualizer_dict = {}
    video_pred_var = {}
    video_to_root_visual_pth = {}

    with torch.no_grad():
        end = time.time()

        for i, ((frame, _target), pth) in enumerate(val_loader):
            if i % 25 != 0:
                continue

            # NOTE(review): assumes pth[0] has the class folder and file
            # name at '/'-components 5 and 6 — TODO confirm for the dataset.
            parts = pth[0].split('/')
            class_name, img_name = parts[5], parts[6]
            stem = os.path.splitext(img_name)[0]
            class_dir = os.path.join(args.out_dir, 'visual', class_name)

            # Reuse one Visualizer per class so a class page accumulates all
            # of its sampled frames instead of being rebuilt per frame.
            visualizer = visualizer_dict.get(class_name)
            if visualizer is None:
                visualizer = Visualizer(class_dir, demo_name='index.html')
                visualizer_dict[class_name] = visualizer
            os.makedirs(class_dir, exist_ok=True)

            # Batch order: center, then for each offset the vertical and the
            # horizontal shift.  hist_names follows the same order.
            views = [('h_0_w_0', crop(frame, 0, 0))]
            for off in offset:
                views.append(('h_' + str(off) + '_w_0', crop(frame, off, 0)))
                views.append(('h_0_w_' + str(off), crop(frame, 0, off)))

            img_visual_names, img_visual_cap, hist_names = [], [], []
            for tag, view in views:
                img_file = 'a_' + stem + '_' + tag + '.png'
                save_image(view, os.path.join(class_dir, img_file))
                register(visualizer, class_name, img_file, img_visual_names,
                         img_visual_cap)
                hist_names.append(
                    os.path.join(class_dir, 'bar_' + stem + '_' + tag + '.png'))

            batch = torch.cat([view for _, view in views], dim=0)
            if args.gpu is not None:
                batch = batch.cuda(args.gpu, non_blocking=True)

            output = torch.nn.Softmax(dim=1)(model(batch))
            divergence = JS_Divergence(output)
            consist.append(divergence)

            output_sort, output_argsort = torch.sort(output,
                                                     dim=1,
                                                     descending=True)
            for k, hist_name in enumerate(hist_names):
                top5_arg_name = [
                    idx_to_cls[str(x.item())] for x in output_argsort[k, :5]
                ]
                top5_prob = output_sort[k, :5].cpu().numpy()
                sns.barplot(x=top5_arg_name, y=top5_prob)
                plt.savefig(hist_name)
                plt.clf()
                register(visualizer, class_name, os.path.basename(hist_name),
                         img_visual_names, img_visual_cap)

            # All entries of one sample share the same directory prefix, so
            # sorting links and captions independently keeps them aligned.
            video_to_root_visual_pth[pth[0]] = [
                sorted(img_visual_names),
                sorted(img_visual_cap)
            ]
            video_pred_var[pth[0]] = divergence

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
            print('process ',
                  '[' + str(i) + '|' + str(len(val_loader)) + ']')

        for class_visualizer in visualizer_dict.values():
            class_visualizer.write()

        sorted_var_top = sorted(video_pred_var.items(),
                                key=lambda item: item[1],
                                reverse=True)
        sorted_var_back = sorted(video_pred_var.items(),
                                 key=lambda item: item[1],
                                 reverse=False)

        # Fewer than num_ranked samples may have been processed.
        # NOTE(review): str(var)[0:6] assumes JS_Divergence returns a plain
        # number; a tensor would print as 'tensor' — confirm its return type.
        for i in range(min(num_ranked, len(sorted_var_top))):
            for ranked_visualizer, (vid_name, var) in (
                    (top_visualizer, sorted_var_top[i]),
                    (bottom_visualizer, sorted_var_back[i])):
                video_visual_pths, video_captions = video_to_root_visual_pth[
                    vid_name]
                for visual_pth, caption in zip(video_visual_pths,
                                               video_captions):
                    ranked_visualizer.insert(
                        visual_pth, caption[:-4] + '_' + str(var)[0:6])

        top_visualizer.write(_sorted=False)
        bottom_visualizer.write(_sorted=False)

        if consist:
            print(' * Consistency ', sum(consist) / len(consist))
        else:
            print(' * Consistency  n/a (no samples processed)')
    return consist