コード例 #1
0
ファイル: MAPO_workers.py プロジェクト: Jacobjohansen90/VQA
def MAPO_CPU(args, pg, sample_que, number):
    """Run an endless CPU worker loop producing MAPO samples for queued questions.

    For every question taken from ``sample_que`` the worker:

    1. builds a name from the question's non-zero entries joined by ``-``;
    2. loads the counting bloom filter stored at ``args.bf_load_path + name``
       or, if it cannot be loaded, creates a fresh one;
    3. calls ``reinforce_MAPO_samples`` on the program generator;
    4. exports the returned bloom filter back to the same path and saves each
       returned program as ``<name>/<program_name>_MAPO.pt`` under
       ``args.high_reward_path``.

    Args:
        args: Options object. Reads ``info``, ``bf_load_path``,
            ``high_reward_path``, ``bf_est_ele``, ``bf_false_pos_rate``,
            ``multi_GPU``, ``temperature`` and ``MAPO_sample_argmax``.
        pg: Program generator exposing ``reinforce_MAPO_samples`` (directly, or
            via ``pg.module`` in the multi-GPU case).
        sample_que: Queue-like object whose ``get()`` yields one question at a
            time (passed to ``torch.from_numpy``, so presumably a NumPy array
            -- verify against the producer).
        number: Identifier of this worker, used only in the start-up message.

    The function never returns; it loops until the process is terminated or
    an exception propagates.
    """
    if args.info:
        print('MAPO process %s started' % str(number))
    dtype = torch.FloatTensor
    pg.type(dtype)
    while True:
        question = sample_que.get()
        # The name is derived from the raw sequence before tensor conversion.
        q_name = '-'.join(str(e) for e in question if e != 0)
        question = torch.from_numpy(question).long()
        question = question.unsqueeze(0)
        bf_path = args.bf_load_path + q_name
        directory = args.high_reward_path + q_name + '/'
        # exist_ok avoids a check-then-create race when several workers
        # handle the same question concurrently.
        os.makedirs(directory, exist_ok=True)
        try:
            bf = CBF(filepath=bf_path)
        except Exception:
            # Best-effort load: fall back to an empty filter when the stored
            # one is missing or unreadable. Deliberately not a bare except so
            # KeyboardInterrupt / SystemExit can still stop the worker.
            bf = CBF(est_elements=args.bf_est_ele,
                     false_positive_rate=args.bf_false_pos_rate)

        # With multi-GPU enabled the sampler is reached through ``pg.module``
        # (presumably a DataParallel-style wrapper -- confirm with the caller).
        if args.multi_GPU and torch.cuda.device_count() > 1:
            sampler = pg.module
        else:
            sampler = pg
        program_preds, program_names, bf = sampler.reinforce_MAPO_samples(
            question,
            bf,
            temperature=args.temperature,
            argmax=args.MAPO_sample_argmax)
        bf.export(bf_path)
        for program, program_name in zip(program_preds, program_names):
            torch.save(program, directory + program_name + '_MAPO' + '.pt')
コード例 #2
0
    def test_cbf_load_file(self):
        """Test loading a counting bloom filter from an exported file.

        Exports a filter holding one element, reloads it from disk and checks
        that membership answers survive the round trip.
        """
        filename = "test.cbm"
        blm = CountingBloomFilter(est_elements=10, false_positive_rate=0.05)
        blm.add("this is a test")
        try:
            blm.export(filename)

            blm2 = CountingBloomFilter(filepath=filename)
            self.assertEqual("this is a test" in blm2, True)
            self.assertEqual("this is not a test" in blm2, False)
        finally:
            # Remove the file even when loading or an assertion fails, so a
            # stale test.cbm is not left behind for later tests.
            if os.path.exists(filename):
                os.remove(filename)
コード例 #3
0
    def test_cbf_load_file(self):
        ''' Test loading a counting bloom filter from an exported file.

        Exports a filter holding one element, reloads it from disk and checks
        that membership answers survive the round trip.
        '''
        filename = 'test.cbm'
        blm = CountingBloomFilter(est_elements=10, false_positive_rate=0.05)
        blm.add('this is a test')
        try:
            blm.export(filename)

            blm2 = CountingBloomFilter(filepath=filename)
            self.assertEqual('this is a test' in blm2, True)
            self.assertEqual('this is not a test' in blm2, False)
        finally:
            # Remove the file even when loading or an assertion fails, so a
            # stale test.cbm is not left behind for later tests.
            if os.path.exists(filename):
                os.remove(filename)
コード例 #4
0
    def test_cbf_load_file(self):
        """Round-trip a counting bloom filter through a temporary file.

        The filter is exported to a named temporary ``.cbm`` file created in
        the current working directory, reloaded from that path, and queried
        for one added and one never-added element.
        """
        tmp_options = {
            "dir": os.getcwd(),
            "suffix": ".cbm",
            "delete": DELETE_TEMP_FILES,
        }
        with NamedTemporaryFile(**tmp_options) as tmp:
            original = CountingBloomFilter(
                est_elements=10, false_positive_rate=0.05
            )
            original.add("this is a test")
            original.export(tmp.name)

            # Build a second filter purely from what was written to disk.
            restored = CountingBloomFilter(filepath=tmp.name)
            self.assertEqual("this is a test" in restored, True)
            self.assertEqual("this is not a test" in restored, False)
コード例 #5
0
    def test_cbf_export_file(self):
        """Test exporting a counting bloom filter to a file.

        Adds a fixed sequence of elements (with repeats), exports the filter
        and compares the MD5 digest of the written file to a known value.
        """
        filename = "test.cbm"
        md5_val = "941b499746dd72d36658399b209d4869"
        blm = CountingBloomFilter(est_elements=10, false_positive_rate=0.01)
        # Same elements, in the same order, as the original explicit calls;
        # the repeats are intentional.
        for element in ("test", "out", "the", "counting", "bloom", "filter",
                        "test", "Test", "out", "test"):
            blm.add(element)
        try:
            blm.export(filename)

            md5_out = calc_file_md5(filename)
            self.assertEqual(md5_out, md5_val)
        finally:
            # Remove the file even when the digest check fails, so a stale
            # test.cbm is not left behind for later tests.
            if os.path.exists(filename):
                os.remove(filename)
コード例 #6
0
    def test_cbf_export_file(self):
        ''' Test exporting a counting bloom filter to a file.

        Adds a fixed sequence of elements (with repeats), exports the filter
        and compares the MD5 digest of the written file to a known value.
        '''
        filename = 'test.cbm'
        md5_val = '941b499746dd72d36658399b209d4869'
        blm = CountingBloomFilter(est_elements=10, false_positive_rate=0.01)
        # Same elements, in the same order, as the original explicit calls;
        # the repeats are intentional.
        for element in ('test', 'out', 'the', 'counting', 'bloom', 'filter',
                        'test', 'Test', 'out', 'test'):
            blm.add(element)
        try:
            blm.export(filename)

            md5_out = calc_file_md5(filename)
            self.assertEqual(md5_out, md5_val)
        finally:
            # Remove the file even when the digest check fails, so a stale
            # test.cbm is not left behind for later tests.
            if os.path.exists(filename):
                os.remove(filename)
コード例 #7
0
    def test_cbf_export_file(self):
        """Check the on-disk bytes of an exported counting bloom filter.

        A fixed sequence of elements (including repeats) is added, the filter
        is exported to a named temporary ``.cbm`` file in the current working
        directory, and the file's MD5 digest is compared to a known value.
        """
        expected_digest = "0b83c837da30e25f768f0527c039d341"
        # Insertion sequence; the repeated entries are part of the fixture.
        elements = (
            "test", "out", "the", "counting", "bloom", "filter",
            "test", "Test", "out", "test",
        )
        tmp_options = {
            "dir": os.getcwd(),
            "suffix": ".cbm",
            "delete": DELETE_TEMP_FILES,
        }
        with NamedTemporaryFile(**tmp_options) as tmp:
            cbf = CountingBloomFilter(
                est_elements=10, false_positive_rate=0.01
            )
            for element in elements:
                cbf.add(element)
            cbf.export(tmp.name)

            actual_digest = calc_file_md5(tmp.name)
            self.assertEqual(actual_digest, expected_digest)