Esempio n. 1
0
def uslm_eval_helper(
    expr,
    memo,
    ctrl,
    data_fraction,
    assume_promising,
    data_view,
    memmap_name_template,
    DataView,
    loss_fn,
    true_loss_fn,
):
    """Evaluate one pipeline configuration and return its result dict.

    The expression graph is evaluated with ``pyll.rec_eval``, a
    ``PrimalVisitor`` is driven through the data view's protocol, and the
    validation result is checkpointed before the (optional) test stage runs.

    Parameters
    ----------
    expr, memo
        Passed to ``pyll.rec_eval``.  The evaluated result is indexed with
        the keys 'pipeline', 'ctrl', 'data_view', 'max_n_features',
        'batched_lmap_speed_thresh', 'batchsize' and 'bagging_fraction'.
    ctrl
        Not used directly; the ``'ctrl'`` entry of the evaluated expression
        is what gets checkpointed.
    data_fraction
        Not used in this function.
    assume_promising
        When true, the ``view2_worth_calculating`` check is skipped and the
        test stage always runs.
    data_view, DataView
        Forwarded to ``use_obj_for_literal_in_memo`` together with ``expr``
        and ``memo`` before evaluation.
    memmap_name_template
        %-format template filled with ``(pid, random int below 10000)``.
    loss_fn
        Called as ``loss_fn(visitor, bagging_fraction)``; its return value
        is stored under ``'loss'``.
    true_loss_fn
        Called as ``true_loss_fn(visitor)`` after the test stage; its return
        value is stored under ``'true_loss'``.

    Returns
    -------
    The result dict produced by the run, or, when a ``USLM_Exception`` is
    raised, the dict carried as the second element of its ``args``.
    """
    use_obj_for_literal_in_memo(expr, data_view, DataView, memo)
    versions = git_versions()
    logger.info('GIT VERSIONS: %s', versions)

    def exception_thrower():
        argdict = pyll.rec_eval(expr, memo=memo, print_node_on_error=False)
        visitor = PrimalVisitor(
            pipeline=argdict['pipeline'],
            ctrl=argdict['ctrl'],
            data_view=argdict['data_view'],
            max_n_features=argdict['max_n_features'],
            # TODO: just pass memmap_name directly
            memmap_name=memmap_name_template %
            (os.getpid(), np.random.randint(10000)),
            thresh_rank=1,
            optimize_l2_reg=True,
            batched_lmap_speed_thresh=argdict['batched_lmap_speed_thresh'],
            badfit_thresh=None,
            batchsize=argdict['batchsize'],
        )

        # -- drive the visitor according to the protocol of the data set
        protocol_iter = argdict['data_view'].protocol_iter(visitor)
        # next() works on both Python 2.6+ and Python 3 iterators.
        msg, model = next(protocol_iter)
        assert msg == 'model validation complete'

        # -- save the loss, but don't save attachments yet.
        rdict = visitor.hyperopt_rval()
        rdict['loss'] = loss_fn(visitor, argdict['bagging_fraction'])
        rdict['in_progress'] = True
        rdict['status'] = hyperopt.STATUS_OK
        argdict['ctrl'].checkpoint(rdict)

        if assume_promising:
            promising = True
        else:
            promising = view2_worth_calculating(loss=rdict['loss'],
                                                ctrl=argdict['ctrl'],
                                                thresh_loss=1.0,
                                                thresh_rank=1)

        logger.info('Promising: %s', promising)
        if promising:
            # The model yielded with this step is not needed here.
            msg, _ = next(protocol_iter)
            assert msg == 'model testing complete'
            # NOTE(review): 'status' and 'in_progress' are not re-set on this
            # fresh hyperopt_rval() result -- confirm it carries what the
            # caller expects.
            rdict = visitor.hyperopt_rval()
            rdict['loss'] = loss_fn(visitor, argdict['bagging_fraction'])
            rdict['true_loss'] = true_loss_fn(visitor)
            visitor.attach_obj_results()
        else:
            logger.warning('Not testing unpromising model %s', model)
            # The checkpointed dict is final for an unpromising model.
            del rdict['in_progress']
        return visitor, rdict

    try:
        visitor, rdict = call_catching_pipeline_errors(exception_thrower)
    except USLM_Exception as e:
        # The exception carries (original error, partial result dict).
        exc, rdict = e.args
        logger.info('job failed: %s: %s', type(e), exc)
    # Hand the result back; without this the computed dict was discarded.
    return rdict
Esempio n. 2
0
def slm_visitor_lfw(expr, memo, ctrl,
    maybe_test_view2=True,
    max_n_per_class=None,
    comparison_names=('mult', 'absdiff', 'sqrtabsdiff', 'sqdiff'),
    assume_promising=False,
    foobar_trace=True,
    foobar_trace_target=None,
    ):
    """Evaluate one pipeline configuration on LFW and return its result dict.

    Builds an ``lfw.view.Aligned`` data view, evaluates ``expr`` with
    ``pyll.rec_eval``, drives an ``ESVC_SLM_Visitor`` through the view's
    protocol, checkpoints the validation result, and optionally runs the
    test stage.

    Parameters
    ----------
    expr, memo
        Passed to ``pyll.rec_eval``.  The evaluated result is indexed with
        the keys 'pipeline', 'ctrl', 'data_view', 'max_n_features',
        'batched_lmap_speed_thresh', 'batchsize' and 'bagging_fraction'.
    ctrl
        Not used directly; the ``'ctrl'`` entry of the evaluated expression
        is what gets checkpointed.
    maybe_test_view2
        When false, the test stage is skipped regardless of the loss.
    max_n_per_class
        Forwarded to ``lfw.view.Aligned``.
    comparison_names
        Forwarded to ``ESVC_SLM_Visitor``.
    assume_promising
        When true, the ``view2_worth_calculating`` check is skipped.
    foobar_trace, foobar_trace_target
        Tracing switches: ``foobar.trace_enabled`` is set to
        ``foobar_trace``; a truthy target additionally turns on
        ``foobar.trace_verify`` and is passed to ``foobar.set_trace_target``.

    Returns
    -------
    The result dict produced by the run, or, when a ``USLM_Exception`` is
    raised, the dict carried as the second element of its ``args``.
    """
    # -- possibly enable computation tracing
    foobar.reset_trace()
    foobar.trace_enabled = foobar_trace
    if foobar_trace_target:
        foobar.trace_verify = True
        foobar.set_trace_target(foobar_trace_target)
    # NOTE(review): `dbname` is neither a parameter nor a local here; it must
    # come from module scope -- confirm it is defined there.
    slm_visitor_esvc._curdb = dbname # XXX tids are only unique within db

    versions = git_versions()
    info('GIT VERSIONS: %s' % str(versions))

    data_view = lfw.view.Aligned(
            x_dtype='uint8',
            max_n_per_class=max_n_per_class,
            )

    use_obj_for_literal_in_memo(expr, data_view, DataViewPlaceHolder, memo)

    def loss_fn(s, rdct, bagging_fraction):
        """
        Store the search criterion and an OK status in `rdct` (in place).

        bagging_fraction - float
            If the function measures the loss within the ensemble (loss)
            as well as the loss without the ensemble (loss_last_member) then
            this value interpolates between boosting (0.0) and bagging (1.0).

        """
        # -- this is the criterion we minimize during model search
        task_name = 'devTrain'
        norm_key = s.norm_key(task_name)
        dct = s._results['train_image_match_indexed'][norm_key][task_name]
        loss = (bagging_fraction * dct['valid_error_no_ensemble']
                + (1 - bagging_fraction) * dct['valid_error'])
        rdct['loss'] = loss
        rdct['status'] = STATUS_OK

    def run_protocol():
        argdict = pyll.rec_eval(expr, memo=memo, print_node_on_error=False)
        visitor = ESVC_SLM_Visitor(pipeline=argdict['pipeline'],
                    ctrl=argdict['ctrl'],
                    data_view=argdict['data_view'],
                    max_n_features=argdict['max_n_features'],
                    memmap_name='%s_%i' % (__name__, os.getpid()),
                    svm_crossvalid_max_evals=50,
                    optimize_l2_reg=True,
                    batched_lmap_speed_thresh=argdict[
                        'batched_lmap_speed_thresh'],
                    comparison_names=comparison_names,
                    batchsize=argdict['batchsize'],
                    )
        # -- drive the visitor according to the protocol of the data set
        protocol_iter = argdict['data_view'].protocol_iter(visitor)
        # next() works on both Python 2.6+ and Python 3 iterators.
        msg, model = next(protocol_iter)
        assert msg == 'model validation complete'

        # -- save the loss, but don't save attachments yet.
        rdict = visitor.hyperopt_rval(save_grams=False)
        rdict['in_progress'] = True
        loss_fn(visitor, rdict, argdict['bagging_fraction'])
        argdict['ctrl'].checkpoint(rdict)

        if assume_promising:
            promising = True
        else:
            promising = view2_worth_calculating(
                loss=rdict['loss'],
                ctrl=argdict['ctrl'],
                thresh_loss=1.0,
                thresh_rank=1)

        info('Promising: %s' % promising)

        if maybe_test_view2:
            if promising:
                info('Disabling trace verification for view2')
                foobar.trace_verify = False
                step = next(protocol_iter)
                # The first step is unpacked as (msg, model) above.  If this
                # step is a tuple too, comparing the whole tuple against a
                # string could never succeed, so compare only its message;
                # a bare string step is still accepted as before.
                msg = step[0] if isinstance(step, tuple) else step
                assert msg == 'model testing complete'
            else:
                warn('Not testing unpromising model %s' % str(model))
        else:
            warn('Skipping view2 stuff for model %s' % str(model))
        # Rebuild the result after the final stage; grams are requested only
        # for promising models.
        rdict = visitor.hyperopt_rval(save_grams=promising)
        loss_fn(visitor, rdict, argdict['bagging_fraction'])
        return visitor, rdict

    try:
        visitor, rdict = call_catching_pipeline_errors(run_protocol)
    except USLM_Exception as e:
        # The exception carries (original error, partial result dict).
        exc, rdict = e.args
        print('job failed: %s: %s' % (type(e), exc))
    # Hand the result back; without this the computed dict was discarded.
    return rdict
def uslm_eval_helper(
    expr,
    memo,
    ctrl,
    data_fraction,
    assume_promising,
    data_view,
    memmap_name_template,
    DataView,
    loss_fn,
    true_loss_fn,
    ):
    """Run validation (and, if warranted, testing) for one configuration.

    ``expr`` is evaluated with ``pyll.rec_eval``; the resulting arguments
    configure a ``PrimalVisitor`` which is stepped through the data view's
    protocol.  The validation result is checkpointed through the evaluated
    ``ctrl`` object before deciding whether to run the test stage.

    Parameters
    ----------
    expr, memo
        Passed to ``pyll.rec_eval``.  The evaluated result is indexed with
        the keys 'pipeline', 'ctrl', 'data_view', 'max_n_features',
        'batched_lmap_speed_thresh', 'batchsize' and 'bagging_fraction'.
    ctrl
        Not used directly; ``argdict['ctrl']`` from the evaluated expression
        is used instead.
    data_fraction
        Not used in this function.
    assume_promising
        When true, skip ``view2_worth_calculating`` and always run the test
        stage.
    data_view, DataView
        Forwarded, with ``expr`` and ``memo``, to
        ``use_obj_for_literal_in_memo`` before evaluation.
    memmap_name_template
        %-format template filled with ``(pid, random int below 10000)``.
    loss_fn
        ``loss_fn(visitor, bagging_fraction)``; result stored as ``'loss'``.
    true_loss_fn
        ``true_loss_fn(visitor)``; result stored as ``'true_loss'`` after the
        test stage.

    Returns
    -------
    The result dict of the run.  If ``USLM_Exception`` is raised, the dict
    taken from the second element of the exception's ``args`` is returned.
    """
    use_obj_for_literal_in_memo(expr, data_view, DataView, memo)
    versions = git_versions()
    logger.info('GIT VERSIONS: %s', versions)

    def exception_thrower():
        argdict = pyll.rec_eval(expr, memo=memo, print_node_on_error=False)
        # TODO: just pass memmap_name directly
        memmap_name = memmap_name_template % (os.getpid(),
                                              np.random.randint(10000))
        visitor = PrimalVisitor(
            pipeline=argdict['pipeline'],
            ctrl=argdict['ctrl'],
            data_view=argdict['data_view'],
            max_n_features=argdict['max_n_features'],
            memmap_name=memmap_name,
            thresh_rank=1,
            optimize_l2_reg=True,
            batched_lmap_speed_thresh=argdict[
                'batched_lmap_speed_thresh'],
            badfit_thresh=None,
            batchsize=argdict['batchsize'],
            )

        protocol_iter = argdict['data_view'].protocol_iter(visitor)
        # Built-in next() is portable across Python 2.6+ and Python 3.
        msg, model = next(protocol_iter)
        assert msg == 'model validation complete'

        # -- save the loss, but don't save attachments yet.
        rdict = visitor.hyperopt_rval()
        rdict['loss'] = loss_fn(visitor, argdict['bagging_fraction'])
        rdict['in_progress'] = True
        rdict['status'] = hyperopt.STATUS_OK
        argdict['ctrl'].checkpoint(rdict)

        if assume_promising:
            promising = True
        else:
            promising = view2_worth_calculating(
                loss=rdict['loss'],
                ctrl=argdict['ctrl'],
                thresh_loss=1.0,
                thresh_rank=1)

        logger.info('Promising: %s', promising)
        if promising:
            # Only the message of this step is checked; the model is unused.
            msg, _ = next(protocol_iter)
            assert msg == 'model testing complete'
            # NOTE(review): this fresh hyperopt_rval() result does not get
            # 'status' re-assigned here -- confirm it is carried over.
            rdict = visitor.hyperopt_rval()
            rdict['loss'] = loss_fn(visitor, argdict['bagging_fraction'])
            rdict['true_loss'] = true_loss_fn(visitor)
            visitor.attach_obj_results()
        else:
            logger.warning('Not testing unpromising model %s', model)
            # No further stage will run, so clear the in-progress marker.
            del rdict['in_progress']
        return visitor, rdict

    try:
        visitor, rdict = call_catching_pipeline_errors(exception_thrower)
    except USLM_Exception as e:
        # args is (original error, partial result dict).
        exc, rdict = e.args
        logger.info('job failed: %s: %s', type(e), exc)
    # Return the result; previously it was computed and then dropped.
    return rdict