def test_chip(test_set, rebuilder, transform, save_dir):
    """Reconstruct every chip test image and save side-by-side comparisons.

    For each category in ``test_set.test_dict`` every image is read as
    grayscale, passed through ``transform`` and ``rebuilder.inference``, and
    segmented with ``ssim_seg`` followed by ``seg_mask`` (threshold 32).  The
    original, the reconstruction and the mask are concatenated along axis 1
    and written to ``<save_dir>/<category>/<index>.png``.  A trimmed mean of
    the per-image processing time is printed at the end.

    Args:
        test_set: Object exposing ``test_dict``, a mapping from category name
            to a list of image paths.
        rebuilder: Object exposing ``inference(input_tensor)``; element 0 of
            its result is used as the reconstructed image.
        transform: Callable returning ``(ori_img, input_tensor)`` for an image.
        save_dir: Root directory for the output images (created if missing).

    Raises:
        IOError: If ``cv2.imread`` cannot read one of the listed images.
    """
    _t = Timer()
    cost_time = []
    for category in test_set.test_dict:
        img_list = test_set.test_dict[category]
        out_dir = os.path.join(save_dir, category)
        # makedirs also creates a missing save_dir and tolerates reruns.
        os.makedirs(out_dir, exist_ok=True)
        for k, path in enumerate(img_list):
            image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
            if image is None:
                # imread signals failure by returning None instead of raising.
                raise IOError('Failed to read image: {}'.format(path))
            # Only preprocessing, inference and segmentation are timed;
            # disk reads and writes are excluded.
            _t.tic()
            ori_img, input_tensor = transform(image)
            out = rebuilder.inference(input_tensor)
            re_img = out[0]
            s_map = ssim_seg(ori_img, re_img, win_size=11, gaussian_weights=True)
            mask = seg_mask(s_map, threshold=32)
            inference_time = _t.toc()
            # NOTE(review): assumes the three arrays share height and dtype so
            # they can be joined horizontally - confirm against transform.
            cat_img = np.concatenate((ori_img, re_img, mask), axis=1)
            cv2.imwrite(os.path.join(out_dir, '{:d}.png'.format(k)), cat_img)
            cost_time.append(inference_time)
            if (k+1) % 20 == 0:
                print('{}th image, cost time: {:.1f}'.format(k+1, inference_time*1000))
            _t.clear()
    # Mean time over the fastest 90% of samples (drops the slowest outliers).
    if cost_time:
        sorted_times = np.sort(np.asarray(cost_time))
        # Keep at least one sample so a single-image run does not average an
        # empty slice and report nan.
        keep = max(1, int(sorted_times.shape[0] * 0.9))
        mean_time = np.mean(sorted_times[:keep])
        # NOTE(review): the *1000 / "ms" label presumes Timer.toc() returns
        # seconds - confirm.
        print('Mean_time: {:.1f}ms'.format(mean_time*1000))
    else:
        print('Mean_time: n/a (no images were processed)')
def test_chip(test_set, rebuilder, transform, save_dir, configs):
    """Reconstruct every chip test image, save the results and evaluate them.

    For each category in ``test_set.test_dict`` every image is read as
    grayscale, passed through ``transform`` and ``rebuilder.inference``, and
    segmented with ``ssim_seg`` followed by ``seg_mask`` (threshold 128).
    The original, reconstruction and mask are written to the ``ori``, ``gen``
    and ``mask`` sub-directories of ``<save_dir>/<category>``.  Each SSIM map
    is resized to the source image size, flattened to a column and collected;
    the collected list is stored in ``<save_dir>/s_map.pth``.  Finally a
    trimmed mean of the processing time is printed and
    ``test_set.eval(save_dir)`` is called.

    Args:
        test_set: Object exposing ``test_dict`` (category name -> list of
            image paths) and ``eval(save_dir)``.
        rebuilder: Object exposing ``inference(input_tensor)``; element 0 of
            its result is used as the reconstructed image.
        transform: Callable returning ``(ori_img, input_tensor)`` for an image.
        save_dir: Root directory for all outputs (created if missing).
        configs: Mapping; ``configs['db']['resize']`` selects the output
            layout and must equal ``[832, 832]``, ``[768, 768]`` or
            ``[256, 256]``.

    Raises:
        IOError: If ``cv2.imread`` cannot read one of the listed images.
        Exception: If ``configs['db']['resize']`` is not a supported size.
    """
    _t = Timer()
    cost_time = []
    s_map_list = []
    for category in test_set.test_dict:
        img_list = test_set.test_dict[category]
        category_dir = os.path.join(save_dir, category)
        # Create each sub-directory independently: the category directory may
        # already exist without them, and cv2.imwrite fails silently when the
        # target directory is missing.
        for sub_dir in ('ori', 'gen', 'mask', 'ROC_curve'):
            os.makedirs(os.path.join(category_dir, sub_dir), exist_ok=True)
        for k, path in enumerate(img_list):
            name = os.path.basename(path)
            image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
            if image is None:
                # imread signals failure by returning None instead of raising.
                raise IOError('Failed to read image: {}'.format(path))
            # The timed section covers preprocessing, inference, SSIM map
            # resizing and segmentation; disk I/O is excluded.
            _t.tic()
            ori_img, input_tensor = transform(image)
            out = rebuilder.inference(input_tensor)
            re_img = out[0]
            s_map = ssim_seg(ori_img, re_img, win_size=11, gaussian_weights=True)
            _h, _w = image.shape
            # cv2.resize takes the target size as (width, height).
            s_map_save = cv2.resize(s_map, (_w, _h))
            s_map_list.append(s_map_save.reshape(-1, 1))
            mask = seg_mask(s_map, threshold=128)
            inference_time = _t.toc()

            resize = configs['db']['resize']
            if resize == [832, 832]:
                # A 32-pixel border is cropped from every side before saving.
                # NOTE(review): presumably this removes padding added by the
                # transform - confirm.
                file_name = 'mask{:d}.png'.format(k)
                outputs = (('ori', ori_img[32:-32, 32:-32]),
                           ('gen', re_img[32:-32, 32:-32]),
                           ('mask', mask[32:-32, 32:-32]))
            elif resize == [768, 768]:
                file_name = 'mask{:d}.png'.format(k)
                outputs = (('ori', ori_img), ('gen', re_img), ('mask', mask))
            elif resize == [256, 256]:
                # This layout keeps the source file name instead of an index.
                file_name = name
                outputs = (('ori', ori_img), ('gen', re_img), ('mask', mask))
            else:
                raise Exception("invaild image size")
            for sub_dir, img in outputs:
                cv2.imwrite(os.path.join(category_dir, sub_dir, file_name), img)

            cost_time.append(inference_time)
            if (k+1) % 20 == 0:
                print('{}th image, cost time: {:.1f}'.format(k+1, inference_time*1000))
            _t.clear()
    # The list accumulates across categories, so it is written once at the
    # end rather than re-serialised after every category.  Nothing is written
    # when there are no categories at all.
    if test_set.test_dict:
        torch.save(s_map_list, os.path.join(save_dir, 's_map.pth'))
    # Mean time over the fastest 90% of samples (drops the slowest outliers).
    if cost_time:
        sorted_times = np.sort(np.asarray(cost_time))
        # Keep at least one sample so a single-image run does not average an
        # empty slice and report nan.
        keep = max(1, int(sorted_times.shape[0] * 0.9))
        mean_time = np.mean(sorted_times[:keep])
        # NOTE(review): the *1000 / "ms" label presumes Timer.toc() returns
        # seconds - confirm.
        print('Mean_time: {:.1f}ms'.format(mean_time*1000))
    else:
        print('Mean_time: n/a (no images were processed)')
    test_set.eval(save_dir)
                break
        lr = adjust_learning_rate(trainer, learning_rate, decay_rate, epoch,
                                  step_index, iteration, epoch_size)

        # load data
        _t.tic()
        images = next(batch_iterator)
        if configs['model']['type'] == 'AutoEncoder':
            trainer.train(images)

        else:
            raise Exception("Wrong model type!")
        batch_time = _t.toc()

        # print message
        if iteration % 10 == 0:
            _t.clear()
            mes = 'Epoch:' + repr(epoch) + '||epochiter: ' + repr(
                iteration % epoch_size) + '/' + repr(epoch_size)
            mes += '||Totel iter: ' + repr(iteration)
            mes += '||{}'.format(trainer.get_loss_message())
            mes += '||LR: %.8f' % (lr)
            mes += '||Batch time: %.4f sec.' % batch_time
            log.wr_mes(mes)
            print(mes)
    save_name = '{}-{:d}.pth'.format(args.cfg, epoch)
    save_path = os.path.join(save_dir, save_name)
    trainer.save_params(save_path)
    log.close()
    exit(0)
def test_mvtec(test_set, rebuilder, transform, save_dir, threshold_seg_dict, val_index):
    """Reconstruct every MVTec test image, save the results and evaluate them.

    ``test_set.test_dict`` maps an item name to a mapping from defect type to
    a list of image paths.  Every image is read in colour, passed through
    ``transform`` and ``rebuilder.inference``, and an SSIM map is computed
    with ``ssim_seg`` and resized to the source image size before being
    thresholded by ``seg_mask``.  Original, reconstruction and mask are
    written to ``<save_dir>/<item>/{ori,gen,mask}/<type>/<id>.png``.  The SSIM
    maps of all images whose type is not ``'good'`` are collected per item and
    handed to ``test_set.eval(save_dir, threshold_dict)``.

    Args:
        test_set: Object exposing ``test_dict`` and
            ``eval(save_dir, threshold_dict)``.
        rebuilder: Object exposing ``inference(input_tensor)``; its result is
            transposed with axes ``(1, 2, 0)`` to obtain the reconstruction.
        transform: Callable returning ``(ori_img, input_tensor)`` for an image.
        save_dir: Root directory for all outputs (created if missing).
        threshold_seg_dict: When ``val_index == 1``, a mapping from item name
            to segmentation threshold; when ``val_index == 0``, the threshold
            value itself, shared by all items.
        val_index: ``1`` or ``0``; selects how ``threshold_seg_dict`` is
            interpreted.

    Raises:
        Exception: If ``val_index`` is neither 0 nor 1.
        IOError: If ``cv2.imread`` cannot read one of the listed images.
    """
    # Fail fast, before any directory is created or inference is run.
    if val_index not in (0, 1):
        raise Exception("Invalid val_index")

    _t = Timer()
    cost_time = []
    threshold_dict = {}
    os.makedirs(os.path.join(save_dir, 'ROC_curve'), exist_ok=True)
    for item in test_set.test_dict:
        threshold_list = []
        item_dict = test_set.test_dict[item]

        for defect_type in item_dict:
            # Create every leaf directory independently; cv2.imwrite fails
            # silently when the target directory is missing.
            for sub_dir in ('ori', 'gen', 'mask'):
                os.makedirs(os.path.join(save_dir, item, sub_dir, defect_type),
                            exist_ok=True)
            _time = []
            img_list = item_dict[defect_type]
            for path in img_list:
                image = cv2.imread(path, cv2.IMREAD_COLOR)
                if image is None:
                    # imread signals failure by returning None, not by raising.
                    raise IOError('Failed to read image: {}'.format(path))
                ori_h, ori_w, _ = image.shape
                # The timed section covers preprocessing, inference and
                # segmentation; disk I/O is excluded.
                _t.tic()
                ori_img, input_tensor = transform(image)
                out = rebuilder.inference(input_tensor)
                # NOTE(review): presumably converts channel-first output to
                # channel-last for OpenCV - confirm.
                re_img = out.transpose((1, 2, 0))
                s_map = ssim_seg(ori_img, re_img, win_size=11, gaussian_weights=True)
                # cv2.resize takes the target size as (width, height).
                s_map = cv2.resize(s_map, (ori_w, ori_h))
                if val_index == 1:
                    threshold = threshold_seg_dict[item]
                else:
                    threshold = threshold_seg_dict
                mask = seg_mask(s_map, threshold=threshold)
                inference_time = _t.toc()

                # Last three characters of the file stem.  Taken from the
                # basename so dots or separators elsewhere in the path cannot
                # leak into the id.
                stem = os.path.splitext(os.path.basename(path))[0]
                file_name = '{}.png'.format(stem[-3:])
                cv2.imwrite(os.path.join(save_dir, item, 'ori', defect_type, file_name), ori_img)
                cv2.imwrite(os.path.join(save_dir, item, 'gen', defect_type, file_name), re_img)
                cv2.imwrite(os.path.join(save_dir, item, 'mask', defect_type, file_name), mask)
                _time.append(inference_time)

                # Only defective samples contribute their SSIM maps.
                if defect_type != 'good':
                    threshold_list.append(s_map)

            cost_time += _time
            if _time:
                mean_time = np.mean(_time)
                print('Evaluate: Item:{}; Type:{}; Mean time:{:.1f}ms'.format(item, defect_type, mean_time*1000))
            else:
                print('Evaluate: Item:{}; Type:{}; no images'.format(item, defect_type))
            _t.clear()
        threshold_dict[item] = threshold_list
    # Mean time over the fastest 90% of samples (drops the slowest outliers).
    if cost_time:
        sorted_times = np.sort(np.asarray(cost_time))
        # Keep at least one sample so a single-image run does not average an
        # empty slice and report nan.
        keep = max(1, int(sorted_times.shape[0] * 0.9))
        mean_time = np.mean(sorted_times[:keep])
        # NOTE(review): the *1000 / "ms" label presumes Timer.toc() returns
        # seconds - confirm.
        print('Mean_time: {:.1f}ms'.format(mean_time*1000))
    else:
        print('Mean_time: n/a (no images were processed)')

    # evaluate results
    print('Evaluating...')
    test_set.eval(save_dir, threshold_dict)
def test_mvtec(test_set, rebuilder, transform, save_dir, threshold_seg_dict, configs):
    """Reconstruct every MVTec test image, save the results and evaluate them.

    ``test_set.test_dict`` maps an item name to a mapping from defect type to
    a list of image paths.  Every image is read in colour, passed through
    ``transform`` and ``rebuilder.inference``; an SSIM map is computed with
    ``ssim_seg_mvtec`` and thresholded by ``seg_mask_mvtec``.  Original,
    reconstruction and mask are written to
    ``<save_dir>/<item>/{ori,gen,mask}/<type>/<id>.png``.  Per item, the
    flattened SSIM maps are stored in ``s_map.pth`` (types other than
    ``'good'``) and ``s_map_good.pth`` (type ``'good'``) inside
    ``<save_dir>/<item>``.  Finally ``test_set.eval(save_dir)`` is called.

    Args:
        test_set: Object exposing ``test_dict`` and ``eval(save_dir)``.
        rebuilder: Object exposing ``inference(input_tensor)``; its result is
            transposed with axes ``(1, 2, 0)`` to obtain the reconstruction.
        transform: Callable returning ``(ori_img, input_tensor)`` for an image.
        save_dir: Root directory for all outputs (created if missing).
        threshold_seg_dict: Mapping from item name to segmentation threshold.
            When empty (falsy), a fixed threshold of 64 is used instead.
        configs: Passed through to ``ssim_seg_mvtec`` and ``seg_mask_mvtec``.

    Raises:
        IOError: If ``cv2.imread`` cannot read one of the listed images.
    """
    _t = Timer()
    cost_time = []
    os.makedirs(os.path.join(save_dir, 'ROC_curve'), exist_ok=True)
    for item in test_set.test_dict:
        s_map_list = []
        s_map_good_list = []
        item_dict = test_set.test_dict[item]
        item_dir = os.path.join(save_dir, item)

        # Create the per-item layout even when the item has no defect types,
        # so the .pth files below always have a directory to land in.
        for sub_dir in ('ori', 'gen', 'mask'):
            os.makedirs(os.path.join(item_dir, sub_dir), exist_ok=True)
        for defect_type in item_dict:
            # Create every leaf directory independently; cv2.imwrite fails
            # silently when the target directory is missing.
            for sub_dir in ('ori', 'gen', 'mask'):
                os.makedirs(os.path.join(item_dir, sub_dir, defect_type),
                            exist_ok=True)
            _time = []
            img_list = item_dict[defect_type]
            for path in img_list:
                image = cv2.imread(path, cv2.IMREAD_COLOR)
                if image is None:
                    # imread signals failure by returning None, not by raising.
                    raise IOError('Failed to read image: {}'.format(path))
                # The timed section covers preprocessing, inference and
                # segmentation; disk I/O is excluded.
                _t.tic()
                ori_img, input_tensor = transform(image)
                out = rebuilder.inference(input_tensor)
                # NOTE(review): presumably converts channel-first output to
                # channel-last for OpenCV - confirm.
                re_img = out.transpose((1, 2, 0))
                s_map = ssim_seg_mvtec(ori_img, re_img, configs, win_size=3, gaussian_weights=True)
                if threshold_seg_dict:  # dict is not empty
                    threshold = threshold_seg_dict[item]
                else:
                    threshold = 64
                mask = seg_mask_mvtec(s_map, threshold, configs)
                inference_time = _t.toc()

                # Last three characters of the file stem.  Taken from the
                # basename so dots or separators elsewhere in the path cannot
                # leak into the id.
                stem = os.path.splitext(os.path.basename(path))[0]
                file_name = '{}.png'.format(stem[-3:])
                cv2.imwrite(os.path.join(item_dir, 'ori', defect_type, file_name), ori_img)
                cv2.imwrite(os.path.join(item_dir, 'gen', defect_type, file_name), re_img)
                cv2.imwrite(os.path.join(item_dir, 'mask', defect_type, file_name), mask)
                _time.append(inference_time)

                # Flattened maps are kept separately for 'good' and all other
                # types.
                flat_map = s_map.reshape(-1, 1)
                if defect_type != 'good':
                    s_map_list.append(flat_map)
                else:
                    s_map_good_list.append(flat_map)
            cost_time += _time
            if _time:
                mean_time = np.mean(_time)
                print('Evaluate: Item:{}; Type:{}; Mean time:{:.1f}ms'.format(item, defect_type, mean_time*1000))
            else:
                print('Evaluate: Item:{}; Type:{}; no images'.format(item, defect_type))
            _t.clear()
        torch.save(s_map_list, os.path.join(item_dir, 's_map.pth'))
        torch.save(s_map_good_list, os.path.join(item_dir, 's_map_good.pth'))

    # Mean time over the fastest 90% of samples (drops the slowest outliers).
    if cost_time:
        sorted_times = np.sort(np.asarray(cost_time))
        # Keep at least one sample so a single-image run does not average an
        # empty slice and report nan.
        keep = max(1, int(sorted_times.shape[0] * 0.9))
        mean_time = np.mean(sorted_times[:keep])
        # NOTE(review): the *1000 / "ms" label presumes Timer.toc() returns
        # seconds - confirm.
        print('Mean_time: {:.1f}ms'.format(mean_time*1000))
    else:
        print('Mean_time: n/a (no images were processed)')

    # evaluate results
    print('Evaluating...')
    test_set.eval(save_dir)