Code example #1
def test(ckpt_path,
         model_type='wesup',
         input_size=None,
         scales=(0.5, ),
         device='cpu'):

    ckpt_path = Path(ckpt_path)
    trainer = initialize_trainer(model_type, device=device)
    trainer.load_checkpoint(ckpt_path)

    record_dir = ckpt_path.parent.parent

    if input_size is not None:
        results_dir = record_dir / 'results'
    else:
        results_dir = record_dir / f'results-{len(scales)}scale'

    results_dir.mkdir(exist_ok=True)

    try:
        print('\nTesting on test set A ...')
        data_dir = Path.home() / 'data' / 'GLAS_all' / 'testA'
        output_dir = results_dir / 'testA'
        infer(trainer, data_dir, output_dir, input_size, scales, device=device)

        print('\nTesting on test set B ...')
        data_dir = Path.home() / 'data' / 'GLAS_all' / 'testB'
        output_dir = results_dir / 'testB'
        infer(trainer, data_dir, output_dir, input_size, scales, device=device)
    finally:
        rmtree('models_ckpt', ignore_errors=True)
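A hypothetical invocation of the helper above; the checkpoint path, scales, and device are illustrative assumptions, not values from the source project:

test('records/20200512/checkpoints/ckpt.0100.pth',
     model_type='wesup',
     scales=(0.4, 0.5, 0.6),
     device='cuda')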
Code example #2
File: train.py Project: weifj0212/PointRend-PyTorch
def train(C, save_dir, loader, val_loader, net, optim, device):
    for e in range(C.epochs):
        loss = step(e, loader, net, optim, device)
        if is_main_process() and (e % 10) == 0:
            torch.save(net.state_dict(),
                       f"{save_dir}/epoch_{e:04d}_loss_{loss:.5f}.pth")
        infer(val_loader, net, device)
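Neither PointRend-PyTorch excerpt defines `step`; the sketch below shows one plausible shape for it, assuming the loader yields (image, target) batches and a plain cross-entropy loss. It is an illustration, not the project's actual implementation:

import torch
import torch.nn.functional as F

def step(epoch, loader, net, optim, device):
    # One pass over the loader; returns the mean loss (hypothetical helper).
    net.train()
    total_loss, batches = 0.0, 0
    for images, targets in loader:
        images, targets = images.to(device), targets.to(device)
        optim.zero_grad()
        loss = F.cross_entropy(net(images), targets)
        loss.backward()
        optim.step()
        total_loss += loss.item()
        batches += 1
    return total_loss / max(batches, 1)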
Code example #3
File: train.py Project: labimage/PointRend-PyTorch
def train(C, save_dir, device, loader, val_loader, net, optim):
    for e in range(C.epochs):
        loss = step(e, loader, net, optim)
        if (e % 10) == 0:
            torch.save(net.state_dict(),
                       f"{save_dir}/epoch_{e:04d}_loss_{loss:.5f}.pth")
        infer(device, val_loader, net)
Code example #4
def count_detection_score_fasterrcnn(img_file_dir, bb_json_name, output_dir):
    config = './mmdetection/configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py'
    checkpoint = './models/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth'
    infer(config=config,
          checkpoint=checkpoint,
          img_file_dir=img_file_dir + '/',
          output_dir=output_dir,
          json_name=bb_json_name)
    return
Code example #5
File: main.py Project: GKarmakar/oreilly-pytorch
def main(FLAGS):
    """
    """

    if FLAGS.mode == "train":
        train(FLAGS)
    elif FLAGS.mode == "infer":
        infer(FLAGS)
    else:
        raise Exception("Choose --mode=<train|infer>")
Code example #6
File: main.py Project: winnchow/oreilly-pytorch
def main(FLAGS):
    """
    """

    if FLAGS.mode == "train":
        train(FLAGS)
    elif FLAGS.mode == "infer":
        infer(FLAGS)
    else:
        raise Exception("Choose --mode=<train|infer>")
Code example #7
def main():
    args = parse()
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    model, state = load_model(args)
    if args.command == 'train':
        assert args.val_date is not None or args.val_freq == 0, 'Specify date for validation or set val-freq to 0'
        assert args.band_size % 2, 'Band-size must be 2*k+1 where k is the number of bands'
        train(args, model, state)
    else:
        infer(args, model)
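Note that `assert args.band_size % 2` relies on truthiness: the remainder is 1 (truthy) for odd values and 0 (falsy) for even ones, which is what enforces the 2*k+1 convention. A quick check:

assert 5 % 2       # 5 = 2*2 + 1: remainder 1, assertion passes
# assert 4 % 2     # 4 is even: remainder 0, assertion would fail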
Code example #8
def run(args):
    if not os.path.exists(args.out_dir):
        os.makedirs(args.out_dir)

    logger = logging.getLogger("nmt_zh")
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler(os.path.join(args.out_dir, "log"))
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    default_hparams = create_hparams(args)
    # Load hparams.
    hparams = create_or_load_hparams(default_hparams.out_dir, default_hparams)

    utils.log('Running with hparams : {}'.format(hparams))

    random_seed = hparams.random_seed
    if random_seed is not None and random_seed > 0:
        utils.log('Set random seed to {}'.format(random_seed))
        random.seed(random_seed)
        np.random.seed(random_seed)
        tf.set_random_seed(random_seed)

    if hparams.inference_input_file:
        utils.log('Inferring ...')
        # infer
        trans_file = hparams.inference_output_file
        ckpt = hparams.ckpt
        if not ckpt:
            ckpt = tf.train.latest_checkpoint(hparams.out_dir)
        utils.log('Use checkpoint: {}'.format(ckpt))
        utils.log('Start infer sentence in {}, output saved to {} ...'.format(
            hparams.inference_input_file, trans_file))
        infer.infer(ckpt, hparams.inference_input_file, trans_file, hparams)

        # eval
        ref_file = hparams.inference_ref_file
        if ref_file and os.path.exists(trans_file):
            utils.log(
                'Evaluating infer output with reference in {} ...'.format(
                    ref_file))
            score = evaluation_utils.evaluate(ref_file, trans_file, 'BLEU')
            utils.log("BLEU: %.1f" % (score, ))
    else:
        utils.log('Training ...')
        train.train(hparams)
Code example #9
def main(FLAGS):
    """
    """

    if FLAGS.mode == 'train':

        # Process the data
        train_data, test_data = process_data(
            data_dir=FLAGS.data_dir,
            split_ratio=FLAGS.split_ratio,
        )

        # Sample
        sample(
            data=train_data,
            data_dir=FLAGS.data_dir,
        )

        # Load components
        with open(os.path.join(basedir, FLAGS.data_dir, 'char2index.json'),
                  'r') as f:
            char2index = json.load(f)

        # Training
        train(
            data_dir=FLAGS.data_dir,
            char2index=char2index,
            train_data=train_data,
            test_data=test_data,
            num_epochs=FLAGS.num_epochs,
            batch_size=FLAGS.batch_size,
            num_filters=FLAGS.num_filters,
            learning_rate=FLAGS.lr,
            decay_rate=FLAGS.decay_rate,
            max_grad_norm=FLAGS.max_grad_norm,
            dropout_p=FLAGS.dropout_p,
        )

    elif FLAGS.mode == 'infer':

        # Inference
        infer(
            data_dir=FLAGS.data_dir,
            model_name=FLAGS.model_name,
            sentence=FLAGS.sentence,
        )

    else:
        raise Exception('Choose --mode train|infer')
Code example #10
File: main.py Project: GKarmakar/oreilly-pytorch
def main(FLAGS):
    """
    """

    if FLAGS.mode == 'train':

        # Process the data
        train_data, test_data = process_data(
            data_dir=FLAGS.data_dir,
            split_ratio=FLAGS.split_ratio,
            )

        # Sample
        sample(
            data=train_data,
            data_dir=FLAGS.data_dir,
            )

        # Load components
        with open(os.path.join(basedir, FLAGS.data_dir, 'char2index.json'), 'r') as f:
            char2index = json.load(f)

        # Training
        train(
            data_dir=FLAGS.data_dir,
            char2index=char2index,
            train_data=train_data,
            test_data=test_data,
            num_epochs=FLAGS.num_epochs,
            batch_size=FLAGS.batch_size,
            num_filters=FLAGS.num_filters,
            learning_rate=FLAGS.lr,
            decay_rate=FLAGS.decay_rate,
            max_grad_norm=FLAGS.max_grad_norm,
            dropout_p=FLAGS.dropout_p,
            )

    elif FLAGS.mode == 'infer':

        # Inference
        infer(
            data_dir=FLAGS.data_dir,
            model_name=FLAGS.model_name,
            sentence=FLAGS.sentence,
            )

    else:
        raise Exception('Choose --mode train|infer')
Code example #11
def video_parse(videoPath, exerciseName, idName):
    cap = cv2.VideoCapture(videoPath)
    vidData = []
    fps = cap.get(cv2.CAP_PROP_FPS)  # OpenCV2 version 2 used "CV_CAP_PROP_FPS"
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    framecounter = 0
    duration = frame_count / fps
    vidMeta = {
        'total_frames': frame_count,
        'length': duration,
        'size': (cap.get(cv2.CAP_PROP_FRAME_WIDTH),
                 cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
    }
    print(vidMeta)
    while cap.isOpened():
        ret, frame = cap.read()
        if ret:
            _, key_coor, key_conf = infer(frame)
            vidData.append((framecounter, key_coor, key_conf))
            framecounter += 1
            print("%s/%s" % (framecounter, frame_count))
        else:
            break
    cap.release()
    utils.generateGT(vidData, exerciseName, idName, meta=vidMeta)
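A hypothetical call to the parser above (the path and label arguments are made up for illustration):

video_parse('clips/squat_01.mp4', 'squat', 'user_042')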
Code example #12
    def testAmbg2(self):
        # Note: as described in Issue #5 (https://github.com/jeffreystarr/dateinfer/issues/5), the result
        # should be %d/%m/%Y as the more likely choice. However, at this point, we will allow %m/%d/%Y.
        self.assertIn(
            infer.infer(
                ['04/12/2012', '05/12/2012', '06/12/2012', '07/12/2012']),
            ['%d/%m/%Y', '%m/%d/%Y'])
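For context, dateinfer's `infer` takes a list of example date strings and returns a single strftime-style format string. A minimal usage sketch, with assumed inputs:

import infer  # the dateinfer module exercised by these tests

fmt = infer.infer(['2014-01-13', '2014-01-21', '2014-02-03'])
print(fmt)  # expected: '%Y-%m-%d'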
Code example #13
def evaluate(c):
    with tf.Graph().as_default() as g:
        x = tf.placeholder(tf.float32, (None, ) + c.image_shape,
                           name="input-x")
        y = tf.placeholder(tf.float32, (None, c.classes), name="input-y")
        feed = {x: c.x_valid, y: c.y_valid}
        z = infer.infer(c, x, False, None)
        acc_count = tf.equal(tf.argmax(z, 1), tf.argmax(y, 1))
        acc = tf.reduce_mean(tf.cast(acc_count, tf.float32))
        name_map = tf.train.ExponentialMovingAverage(
            c.moving_average_decay).variables_to_restore()
        for k, v in name_map.items():
            print("\trestore: " + str(k) + " -> " + str(v))
        saver = tf.train.Saver(name_map)  # use moving-average parameters
        # saver = tf.train.Saver()
        while True:
            with tf.Session() as s:
                ckpt = tf.train.get_checkpoint_state(c.model_savepath)
                if ckpt and ckpt.model_checkpoint_path:
                    saver.restore(s, ckpt.model_checkpoint_path)
                    global_step = ckpt.model_checkpoint_path.split(
                        '/')[-1].split('-')[-1]
                    acc_score = s.run(acc, feed_dict=feed)
                    print("epoch[%s] validation acc=%f" %
                          (global_step, acc_score))
                else:
                    print("no checkpoint file found.")
                time.sleep(interval)
Code example #14
File: rest.py Project: norman00/stacks-usecase-OLD
async def pred():
    # validate the request body before reading from it
    if not quart.request.json or "issue" not in quart.request.json:
        quart.abort(400)
    issue = [quart.request.json["issue"]]
    labels = infer(issue)
    return quart.jsonify({"label": labels}), 201
Code example #15
def transcribe_file(rec_path):

    initialize()

    wav_obj = wave.open(rec_path)
    # if not audio_utils.has_speech(wav_obj):
    #     print("no speech")
    #     return ""

    tmp_dir_path = os.path.join(os.getcwd(), "tmp")
    # filter
    # normalize volume
    #audio_wav_volume_normalized_path = rec_path+"_normalized.wav"
    #print("Normalizing volume... %s" % (audio_wav_path))
    #audio_utils.loud_norm(rec_path, audio_wav_volume_normalized_path)

    # correct volume
    audio_wav_volume_corrected_path = rec_path + "_volume_corrected.wav"
    #print("Correcting volume...")
    audio_utils.correct_volume(rec_path, audio_wav_volume_corrected_path)

    # apply bandpass filter
    audio_wav_filtered_path = rec_path + "_filtered.wav"
    #print("Applying bandpass filter...")
    audio_utils.apply_bandpass_filter(audio_wav_volume_corrected_path,
                                      audio_wav_filtered_path)

    return infer.infer(audio_wav_filtered_path, session)
Code example #16
    def execute(self):
        """执行业务逻辑"""
        utils.info(
            'API REQUEST INFO[' + self.request.path + '][' +
            self.request.method + '][' + self.request.remote_ip + '][' +
            str(self.request.arguments) + ']', ApiImageDigit)
        img_file = self.get_argument('img_file', '')
        model_path = self.get_argument('model_path', '')
        if img_file == '':
            return {'code': 2, 'msg': 'img_file must not be empty'}
        if model_path == '':
            model_path = MODEL_PATH
        res = {}

        try:
            ret, msg, res = infer.infer(img_file, model_path)
            if ret != 0:
                utils.error('execute fail [' + img_file + '] ' + msg,
                            ApiImageDigit)
                return {'code': 4, 'msg': 'query failed'}
        except Exception:
            utils.error('execute fail [' + img_file + '] ' + utils.get_trace(),
                        ApiImageDigit)
            return {'code': 5, 'msg': 'query failed'}

        # Assemble the response format
        return {'code': 0, 'msg': 'success', 'data': res}
Code example #17
File: rest.py Project: mak-454/stacks-usecase
def pred():
    # validate the request body before reading from it
    if not flask.request.json or "issue" not in flask.request.json:
        flask.abort(400)
    issue = [flask.request.json["issue"]]
    labels = infer(issue)
    return flask.jsonify({"label": labels}), 201
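A hypothetical client call against the Flask endpoint above. The host, port, and route are assumptions; the handler only requires an "issue" key in the JSON body:

import requests

resp = requests.post('http://localhost:5000/pred',
                     json={'issue': 'crash when loading a corrupt image'})
print(resp.status_code, resp.json())  # expect 201 and {"label": [...]}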
Code example #18
    def infer(self, **kwargs):
        with suppress_stdout():
            args = self._default_args(**kwargs)

            ans = infer(args)

        print(ans)
Code example #19
    def execute(self):
        """执行业务逻辑"""
        logger.info(
            'API REQUEST INFO[' + self.request.path + '][' +
            self.request.method + '][' + self.request.remote_ip + '][' +
            str(self.request.arguments) + ']', ApiImageClassification)
        img_file = self.get_argument('img_file', '')
        if img_file == '':
            return {'code': 2, 'msg': 'img_file must not be empty'}
        res = {}

        try:
            ret, msg, res = classification_infer.infer(img_file)
            if ret != 0:
                logger.error('execute fail [' + img_file + '] ' + msg,
                             ApiImageClassification)
                return {'code': 4, 'msg': 'query failed'}
        except Exception:
            logger.error(
                'execute fail [' + img_file + '] ' + logger.get_trace(),
                ApiImageClassification)
            return {'code': 5, 'msg': 'query failed'}

        # Assemble the response format
        return {'code': 0, 'msg': 'success', 'data': res}
Code example #20
        def transcriber_worker(rec_path):
            wav_obj = wave.open(rec_path)
            if not audio_utils.has_speech(wav_obj):
                # no speech detected; nothing to transcribe
                return

            print("got speech")

            # filter
            # normalize volume
            audio_wav_volume_normalized_path = rec_path + "_normalized.wav"
            audio_utils.loud_norm(rec_path, audio_wav_volume_normalized_path)

            # correct volume
            audio_wav_volume_corrected_path = rec_path + "_volume_corrected.wav"
            audio_utils.correct_volume(audio_wav_volume_normalized_path,
                                       audio_wav_volume_corrected_path)

            # apply bandpass filter
            audio_wav_filtered_path = rec_path + "_filtered.wav"
            audio_utils.apply_bandpass_filter(audio_wav_volume_corrected_path,
                                              audio_wav_filtered_path)

            start_time = time.time()
            print("t: " + infer.infer(audio_wav_filtered_path, session))
            print("infer took: %.2f sec" % (time.time() - start_time))
Code example #21
def handle_input(json):

    try:
        token = json['token']
    except (KeyError, TypeError):
        return emit('receive', '400 : Token Absent')

    if token is not None:
        try:
            token = bytes(token, 'utf-8')
            payload = jwt.decode(token, jwt_secret, algorithms=alg)
            idx = ''.join(
                secrets.choice(string.ascii_uppercase + string.digits)
                for i in range(N))

            inp = 'input/' + idx + '_noisy.wav'
            op_clean = 'output/' + idx + '_clean.wav'
            op_noise = 'output/' + idx + '_noise.wav'
            file_list = [inp, op_clean, op_noise]

            with open(inp, 'wb') as w:
                w.write(json['blob'])

            infer(model, file_list)

            with open(op_clean, 'rb') as r:
                data = r.read()

            emit('receive', data)

            if args['prod']:
                os.remove(inp)
                os.remove(op_clean)
                os.remove(op_noise)

        except jwt.ExpiredSignatureError:
            return emit('receive',
                        '401 : Signature expired. Please log in again.')
        except jwt.InvalidTokenError:
            emit('receive', '402 : Invalid token. Please log in again.')
            disconnect()
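The handler above only decodes tokens. A sketch of issuing one it would accept, assuming PyJWT with an HS256 secret; `jwt_secret` mirrors the excerpt's global, and the payload is illustrative:

import jwt

jwt_secret = 'change-me'  # assumption: the server's shared signing secret
token = jwt.encode({'user': 'alice'}, jwt_secret, algorithm='HS256')
# the client then sends {'token': token, 'blob': <wav bytes>} over the socket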
Code example #22
File: tests.py Project: CENDARI/editorsnotes
        def testFormat(self):
            # verify initial conditions
            self.assertTrue(hasattr(self, 'test_data'), 'testdata field not set on test object')

            expected = self.test_data['format']
            actual = infer.infer(self.test_data['examples'])

            self.assertEqual(expected,
                             actual,
                             '{0}: Inferred `{1}`!=`{2}`'.format(self.test_data['name'], actual, expected))
Code example #23
File: vqa_main.py Project: snuspl/vip_pipeline
    def predict(self, input_data):
        data = self.data
        question = input_data["question"]
        vid = input_data["vid"]

        data[0].question = question
        data[0].vid = vid + '_153'  # for alignment

        answer = infer(data)

        return answer
Code example #24
File: experiment.py Project: mohd1012/WSAT
def case(query, var_map, problem):
    success = 0
    fail = 0
    for i in range(REPEAT):
        problem.seed = os.urandom(4)
        if infer(query, var_map, problem):
            success += 1
        else:
            fail += 1
    print("{}, {}, {}, {}, {}".format(query[1:], problem.p, problem.max_flips,
                                      success, fail))
Code example #25
def main(args):
    input_img_paths = glob.glob(args.input_dir + "/*")
    sub_img_paths = []
    Path(args.output_dir).mkdir(exist_ok=True)

    for input_img_path in input_img_paths:
        sub_img_path = Path(args.output_dir) / Path(input_img_path).name
        sub_img_path = str(sub_img_path)
        sub_img_paths.append(sub_img_path)

        subtract_args = SimpleNamespace(template_dir=args.template_dir,
                                        base_img_path=args.base_img_path,
                                        target_img_path=input_img_path,
                                        sub_path=sub_img_path)
        subtract(subtract_args)
    for sub_img_path in sub_img_paths:
        infer_args = SimpleNamespace(target_img_path=sub_img_path,
                                     output_dir=args.output_dir)
        infer(infer_args)
    print("Processing completed!")
Code example #26
        def testFormat(self):
            # verify initial conditions
            self.assertTrue(hasattr(self, 'test_data'),
                            'testdata field not set on test object')

            expected = self.test_data['format']
            actual = infer.infer(self.test_data['examples'])

            self.assertEqual(
                expected, actual,
                '{0}: Inferred `{1}`!=`{2}`'.format(self.test_data['name'],
                                                    actual, expected))
Code example #27
File: vqa_main.py Project: snuspl/vip_pipeline
    def infer(self, data, question, vid):
        #question="Who are the actors in the Friends?"
        #vid="s01e22_02"
        data[0].question = question
        data[0].vid = vid + '_122'  # for alignment

        #print("Question :", question)
        #print("VideoID :", vid)

        ans = infer(data)

        return ans
Code example #28
def main():

    run_type = sys.argv[1]
    print('*' * 120)
    print('*' * 120)
    print('*' * 120)
    print(run_type)

    if run_type == 'train':
        import train
        # os.system('python3 train.py')
        train.run_train()

    elif run_type == 'test':
        import infer
        infer.infer(sys.argv[2])
        infer.get_activations(sys.argv[2])

    else:
        print(
            "To run this script please enter either: 'train' or 'test <x>.png'"
        )
Code example #29
def image_parse(imagePath, exerciseName, idName):
    img = cv2.imread(imagePath)
    imgData = []
    frame_count = 1
    duration = 1
    imgMeta = {
        'total_frames': frame_count,
        'length': duration,
        'size': (img.shape[1], img.shape[0])
    }
    print(imgMeta)
    _, key_coor, key_conf = infer(img)
    imgData.append((1, key_coor, key_conf))
    utils.generateGT(imgData, "%s_img" % exerciseName, idName, meta=imgMeta)
Code example #30
File: app.py Project: rmori320/gitpages_otameshi
def send():
    if request.method == 'POST':
        img_file = request.files['img_file']
        img = Image.open(img_file)

        img_url = os.path.join(app.config['UPLOAD_FOLDER'], img_file.filename)
        img.save(img_url)

        img_array = np.array(img, dtype=np.float32)
        ans = infer.infer(img_array)
        return render_template('index.html', message=ans, img_url=img_url)

    else:
        return redirect(url_for('index'))
Code example #31
def main():
    args = argparser()
    # ground truth in JSON format
    gt_dict = utils.read_json(args.gt)
    # gt_dict = utils.read_url_list(args.gt)
    infered_dict_list = infer.infer(gt_dict, args.tool, ak=args.ak, sk=args.sk)
    if args.log:
        utils.logs(infered_dict_list, args.log)
    y_true, y_pred = utils.get_true_pred(gt_dict, infered_dict_list)
    metric = utils.Metrics(y_true, y_pred)

    conf_matrix = metric.confusion_matrix()
    acc = metric.accuracy()
    pulp_recall = metric.pulp_recall()
    pulp_precision = metric.pulp_precision()
    sexy_recall = metric.sexy_recall()
    sexy_precision = metric.sexy_precision()
    normal_recall = metric.normal_recall()
    normal_precision = metric.normal_precision()

    print('\n')
    print('[%s] porn-detection test' % __NAME[args.tool])
    print('~' * 50)
    print('Ground Truth: ')
    print('Total samples:     %d' % len(gt_dict))
    print('Valid predictions: %d' % np.sum(conf_matrix))

    print('~' * 50)
    print('Test set distribution: ')
    print('%d pulp samples' % (np.sum(conf_matrix, axis=1)[0]))
    print('%d sexy samples' % (np.sum(conf_matrix, axis=1)[1]))
    print('%d normal samples' % (np.sum(conf_matrix, axis=1)[2]))

    print('~' * 50)
    print('Model metrics: ')
    print('accuracy:         %f ' % acc)
    print('pulp_recall:      %f ' % pulp_recall)
    print('pulp_precision:   %f ' % pulp_precision)
    print('sexy_recall:      %f ' % sexy_recall)
    print('sexy_precision:   %f ' % sexy_precision)
    print('normal_recall:    %f ' % normal_recall)
    print('normal_precision: %f ' % normal_precision)

    print('~' * 50)
    print('Confusion Matrix: ')
    print(conf_matrix)
    print('\n')

    if args.vis:
        metric.plot_confusion_matrix()
Code example #32
File: launch.py Project: divelab/GPT
def infer_single_image(
        gitapp: controller.GetInputTargetAndPredictedParameters):
    """Predicts the labels for a single image."""
    if not gfile.Exists(output_directory()):
        gfile.MakeDirs(output_directory())

    if FLAGS.infer_channel_whitelist is not None:
        infer_channel_whitelist = FLAGS.infer_channel_whitelist.split(',')
    else:
        infer_channel_whitelist = None

    while True:
        infer.infer(
            gitapp=gitapp,
            restore_directory=FLAGS.restore_directory or train_directory(),
            output_directory=output_directory(),
            extract_patch_size=CONCORDANCE_EXTRACT_PATCH_SIZE,
            stitch_stride=CONCORDANCE_STITCH_STRIDE,
            infer_size=FLAGS.infer_size,
            channel_whitelist=infer_channel_whitelist,
            simplify_error_panels=FLAGS.infer_simplify_error_panels,
        )
        if not FLAGS.infer_continuously:
            break
Code example #33
File: func.py Project: msmauck/stacks-usecase
def handler(ctx, data: io.BytesIO = None):
    outputImage = None
    try:
        # Get the input image as base64
        imageAsBase64String = data.getvalue()
        # Turn the base64 input into a file object
        inputImage = io.BytesIO(imageAsBase64String)
        # Run inference on the image
        outputImage = infer(inputImage)

    except Exception as ex:  # ValueError is already an Exception subclass
        print(str(ex))

    # Return the output image as an ascii string encoded in base64
    return response.Response(ctx,
                             response_data=str(outputImage, 'ascii'),
                             headers={"Content-Type": "text/html"})
Code example #34
File: main.py Project: mohd1012/WSAT
def main():
    max_flips = arguments.max_flips if arguments.max_flips else 10000
    seed = os.urandom(4)
    p = arguments.p if arguments.p else 0.5
    (num_lit, clauses, var_map) = parse_cnf(arguments.input)
    problem = Wsat(clauses, num_lit, p, max_flips, seed)
    if arguments.verbose:
        problem.print_info()
    if arguments.infer:
        if problem.solve():
            print("KB is valid.")
            if infer(arguments.infer, var_map, problem):
                print("Query is entailed.")
            else:
                print("Query is not entailed.")
        else:
            print("KB is not valid!")
    else:
        print(problem.solve())
Code example #35
File: infer.py Project: wkentaro/fcn
#!/usr/bin/env python

import os.path as osp
import sys


if __name__ == '__main__':
    here = osp.dirname(osp.abspath(__file__))
    sys.path.insert(0, osp.join(here, '../voc'))

    import infer
    infer.infer(n_class=40)
Code example #36
File: infer.py Project: wkentaro/fcn
#!/usr/bin/env python

import os.path as osp
import sys


if __name__ == '__main__':
    here = osp.dirname(osp.abspath(__file__))
    sys.path.insert(0, osp.join(here, '../voc'))

    import infer
    infer.infer(n_class=26)
Code example #37
File: tests.py Project: CENDARI/editorsnotes
    def testAmbg2(self):
        # Note: as described in Issue #5 (https://github.com/jeffreystarr/dateinfer/issues/5), the result
        # should be %d/%m/%Y as the more likely choice. However, at this point, we will allow %m/%d/%Y.
        self.assertIn(infer.infer(['04/12/2012', '05/12/2012', '06/12/2012', '07/12/2012']),
                      ['%d/%m/%Y', '%m/%d/%Y'])
Code example #38
File: tests.py Project: CENDARI/editorsnotes
    def testAmbg1(self):
        self.assertIn(infer.infer(['1/1/2012']), ['%m/%d/%Y', '%d/%m/%Y'])