예제 #1
0
    def train_infer(self, model, epoch):

        rawLabelDir = '/data4/mjx/gd/raw_data/amap_traffic_annotations_test.json'
        image_datasets = roadDatasetInfer(self.test_dir)
        dataset_loaders = torch.utils.data.DataLoader(
            image_datasets,
            batch_size=self.test_batch_size,
            shuffle=False,
            num_workers=self.num_workers)
        model.eval()
        pre_result = []
        pre_name = []
        pre_dict = {}
        for data in dataset_loaders:
            inputs, paths = data
            inputs = inputs.cuda()
            outputs = model(inputs)
            _, preds = torch.max(outputs.data, 1)
            pre_result += preds.cpu().numpy().tolist()
            for frame in paths:
                pre_name.append(frame.split('/')[-1])
        assert len(pre_name) == len(pre_result)
        for idx in range(len(pre_result)):
            pre_dict[pre_name[idx]] = pre_result[idx]

        count_result = {'畅通': 0, '缓行': 0, '拥堵': 0}
        with open(rawLabelDir) as f:
            submit = json.load(f)
        submit_annos = submit['annotations']
        submit_result = []
        for i in range(len(submit_annos)):
            submit_anno = submit_annos[i]
            imgId = submit_anno['id']
            frame_name = [
                imgId + '_' + i['frame_name'] for i in submit_anno['frames']
            ]
            status_all = [pre_dict[i] for i in frame_name]
            status = max(status_all, key=status_all.count)
            submit['annotations'][i]['status'] = status

        submit_json = '{}/{}/{}_{}.json'.format(self.sub_dir, self.model_name,
                                                self.model_name, epoch)
        json_data = json.dumps(submit)
        with open(submit_json, 'w') as w:
            w.write(json_data)
        f_class, score, P, R, real_f1 = compare(submit_json, self.real_json)
        count_result = count(submit_json)
        count_result_real = count(self.real_json)
        self.logger.info("{} 第{} epoch 预测结果:{}".format(self.model_name, epoch,
                                                       count_result))
        self.logger.info("{} 预测结果:{}".format(self.real_json,
                                             count_result_real))
        self.logger.info("{} 和 {} 的 f1:{} 加权f1:{} ".format(
            self.model_name, self.real_json, f_class, real_f1))
        self.logger.info("{} 和 {} 的 Acc:{} Precision:{} Recall:{}".format(
            self.model_name, self.real_json, score, P, R))
예제 #2
0
def get_friends(user, red):
    user_data_df = models.load_df(store_name="user_data")
    user_data = red.get_user_content(user)  #wiinkme #red.get_user(user)

    # user_data = pd.DataFrame()
    # user_data['user'] = ['Zmini12']
    # user_data['subs'] = [{'AskReddit': 0.48514851485148514, 'RocketLeague': 0.039603960396039604, 'RocketLeagueExchange': 0.019801980198019802, 'interestingasfuck': 0.0297029702970297, 'gifs': 0.019801980198019802, 'reflex': 0.019801980198019802, 'woahdude': 0.009900990099009901, 'teenagers': 0.16831683168316833, 'GlobalOffensive': 0.019801980198019802, 'CringeAnarchy': 0.04950495049504951, 'videos': 0.04950495049504951, 'Brogress': 0.009900990099009901, 'talesfromsecurity': 0.009900990099009901, 'longboarding': 0.019801980198019802, 'justneckbeardthings': 0.0297029702970297, 'facepalm': 0.019801980198019802}]
    # user_data['categories'] = [{'meme': 0, 'default': 3, 'p**n': 0, 'political': 0, 'cool': 0, 'educational': 0}]

    if user_data is not None and user_data_df is not None:
        # print(dict(pd.DataFrame(user_data)))
        friend, score, common_subreddits, common_categories = compare.compare(
            user_data_df, user_data.iloc[0])

        # print("You should be friends with %s because you guys had a score of %s. Your common subreddits are %s, your common categories are %s"%(friend['user'], score, common_subreddits, common_categories))
        # print(dict(friend))
        # print(dict(user_data))
        return friend, score, common_subreddits, common_categories

    else:
        print('user is bad')
        return None, None, None
import re
import socket
import datetime
# hack to allow import from parent dir
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
os.sys.path.insert(0,parent_dir)
from utils.compare import compare

# run backup script
os.system("python ../backup_to_swift.py -c config.json")

# get swift package, decrypt, and unpack it
os.system("python ../download_decrypt_and_upack_swift_files.py -n test-backups -f 'test_backup_files' -c 'config.json' -d ./decrypted_swift_files/")

# compare extracted fileset to original fileset
comp = compare()
if not comp.are_dirs_equal('./decrypted_swift_files/test_backup_files','./test_backup_files'):
    raise Exception('FAIL: Files downloaded from swift, decrypted, and unpacked, do not match what was uploaded.')

# cleanup test goobers from swift
os.system("python ./cleanup_test_files.py -c config.json")

# cleanup local test goobers
host = socket.gethostname()
date = datetime.datetime.today()
date = date.strftime("%Y-%m-%d")
cleanup_pattern = "backup." + date + "." + host
for f in os.listdir(os.path.dirname(__file__)):
    if re.search(cleanup_pattern, f):
        os.remove(os.path.join(os.path.dirname(__file__), f))
예제 #4
0
 def test_compare_diff(self):
     img1 = cv2.imread(TEST_PHOTO)
     img2 = cv2.imread(TEST_PHOTO2)
     hash1 = compare.calc_hash(img1)
     hash2 = compare.calc_hash(img2)
     self.assertEqual(compare.compare(hash1, hash2), DIFF)