Beispiel #1
0
def save_mask(mask, label, label_prob, max_prob, max_label, save_path, ind,
              tot_iters, im_sz, f_time, model_name, **kwargs):
    """Render ``mask`` as an annotated colour-mapped image and save it as PNG.

    The mask is passed through ``get_blurred_img``, inverted, colour-mapped
    with VIRIDIS, resized to ``im_sz`` x ``im_sz`` and stamped with three
    text lines (target class, top-1 class, iteration index). The result is
    written to::

        <save_path>/evolution_mask_time_<f_time>/<model_name>/
            Model_<model_name>_<ind:03d>_mask_<label>.png

    Args:
        mask: Mask to visualise.
            NOTE(review): presumably a 2-D float array in [0, 1] -- confirm
            against the caller and ``get_blurred_img``.
        label: Ground-truth category; used as a key into the mapping
            returned by ``eutils.imagenet_label_mappings()`` and in the
            output file name.
        label_prob: Probability shown next to the target class name.
        max_prob: Probability shown next to the top-1 class name.
        max_label: Top-1 category; key into the same label mapping.
        save_path: Root directory under which the image is stored.
        ind: Current iteration index (zero-padded to 3 digits in the name).
        tot_iters: Total number of iterations, shown in the index caption.
        im_sz: Side length, in pixels, of the square output image.
        f_time: Run identifier embedded in the output directory name.
        model_name: Model identifier used in the directory and file names.
        **kwargs: Forwarded unchanged to every ``add_text`` call.
            NOTE(review): an earlier inline note hinted at
            ``x_pt=50, scale=1, size=0.35`` -- confirm against ``add_text``.

    Returns:
        The annotated image that was written to disk.
    """
    # label is gt_category
    category_map_dict = eutils.imagenet_label_mappings()

    mask = get_blurred_img(255 * mask, 1)
    mask = 1 - mask  # invert before colour-mapping
    annotated = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_VIRIDIS)
    annotated = cv2.resize(annotated, (im_sz, im_sz))

    # Only the first comma-separated synonym of each class name is displayed.
    annotated = add_text(
        annotated,
        'Target: {} {:.3f}'.format(category_map_dict[label].split(',')[0],
                                   label_prob), **kwargs)
    annotated = add_text(
        annotated,
        'Top-1: {} {:.3f}'.format(category_map_dict[max_label].split(',')[0],
                                  max_prob), **kwargs)
    annotated = add_text(
        annotated, 'Index is: {:3d}/{}'.format(ind, tot_iters), **kwargs)

    temp_path = os.path.join(save_path,
                             f'evolution_mask_time_{f_time}/{model_name}')
    eutils.mkdir_p(temp_path)
    cv2.imwrite(
        os.path.join(
            temp_path,
            "Model_{}_{:03d}_mask_{}.png".format(model_name, ind, label)),
        annotated)

    # The original ended with `return out`, but `out` was never defined in
    # this function; return the image that was actually produced instead.
    return annotated


########################################################################################################################
if __name__ == '__main__':
    base_img_dir = abs_path(settings.imagenet_val_path)
    # base_img_dir = '/home/naman/CS231n/heatmap_tests/images/ILSVRC2012_img_val'
    # text_file = f'/home/naman/CS231n/heatmap_tests/' \
    #             f'Madri/Madri_New/robustness_applications/img_name_files/' \
    #             f'time_15669152608009198_seed_0_' \
    #             f'common_correct_imgs_model_names_madry_ressnet50_googlenet.txt'
    s_time = time.time()
    f_time = ''.join(str(s_time).split('.'))
    args = get_arguments()
    im_label_map = eutils.imagenet_label_mappings()
    eutils.mkdir_p(args.out_path)

    img_filenames = os.listdir(args.input_dir_path)
    img_filenames = [
        i for i in img_filenames if 'ILSVRC2012_val_000' in i
        and int(i.split('_')[-1]) in range(1, 50001)
    ]
    if args.idx_flag == 1:
        img_filenames = img_filenames[0]

    # ## TODO: Changes here
    # incorrect_img_list = np.load('/home/naman/CS231n/heatmap_tests/Madri/Madri_New/'
    #                              'robustness_applications/img_name_files/incorrect_img_names.npy').tolist()

    ##############################################################