        # Train on all datasets in succession
        # Print settings headers to raw results file
        printF("# " + str(parameters), experimentId, currentIteration)

        # Compute batching variables
        repetition_size = dataset.lengths[dataset.TRAIN]
        next_testing_threshold = parameters['test_interval'] * repetition_size

        dataset_data = None
        label_index = None
        if (parameters['simple_data_loading']):
            dataset_data, label_index = load_data(parameters, processor,
                                                  dataset)

        for r in range(parameters['repetitions']):
            stats = set_up_statistics(dataset.output_dim, model.n_max_digits,
                                      dataset.oneHot.keys())
            total_error = 0.0
            # Print repetition progress and save to raw results file
            printF("Batch %d (repetition %d of %d, dataset 1 of 1) (samples processed after batch: %d)" % \
                    (r+1,r+1,parameters['repetitions'],(r+1)*repetition_size), experimentId, currentIteration)
            currentIteration = r + 1
            currentDataset = 1

            # Train model per minibatch
            k = 0
            printedProgress = -1
            while k < repetition_size:
                profiler.start('train batch')
                profiler.start('get train batch')
                data, target, test_labels, target_expressions, nrSamples, health = \
                    get_batch(0, dataset, model, dataset_data, label_index,
                              debug=parameters['debug'])
                profiler.stop('get train batch')
def test(model,
         dataset,
         parameters,
         max_length,
         print_samples=False,
         sample_size=False,
         returnTestSamples=False):
    # Test
    printF("Testing...", experimentId, currentIteration)

    total = dataset.lengths[dataset.TEST]
    printing_interval = 1000
    if (parameters['max_dataset_size'] is not False):
        printing_interval = 100
    elif (sample_size != False):
        total = sample_size

    # Set up statistics
    stats = set_up_statistics(dataset.output_dim, model.n_max_digits,
                              dataset.oneHot.keys())
    total_labels_used = {}

    # Predict
    printed_samples = False
    totalError = 0.0
    k = 0
    testSamples = []
    while k < total:
        # Get data from batch
        test_data, test_targets, test_labels, test_expressions, \
            nrSamples, health = get_batch(1, dataset, model, dataset_data, label_index, debug=parameters['debug'])

        predictions, other = model.predict(test_data,
                                           test_targets,
                                           nrSamples=nrSamples)
        totalError += other['summed_error']

        profiler.start("test batch stats")
        stats, _, _ = model.batch_statistics(stats,
                                             predictions,
                                             test_labels,
                                             None,
                                             other,
                                             nrSamples,
                                             dataset,
                                             None,
                                             None,
                                             parameters,
                                             data=test_data)

        for j in range(nrSamples):
            if (test_labels[j] not in total_labels_used):
                total_labels_used[test_labels[j]] = True

            # Save predictions to testSamples
            if (returnTestSamples):
                strData = map(
                    lambda x: dataset.findSymbol[x],
                    np.argmax(test_data[j, :, :model.data_dim],
                              len(test_data.shape) - 2))
                strPrediction = dataset.findSymbol[predictions[j]]
                testSamples.append((strData, strPrediction))

        # Print samples
        if (print_samples and not printed_samples):
            for i in range(nrSamples):
                prefix = "# "
                printF(
                    prefix + "Data          1: %s" % "".join(
                        (map(lambda x: dataset.findSymbol[x],
                             np.argmax(test_data[i],
                                       len(test_data.shape) - 2)))),
                    experimentId, currentIteration)
                printF(
                    prefix +
                    "Prediction    1: %s" % dataset.findSymbol[predictions[i]],
                    experimentId, currentIteration)
                printF(
                    prefix +
                    "Used label    1: %s" % dataset.findSymbol[test_labels[i]],
                    experimentId, currentIteration)
            printed_samples = True

        if (stats['prediction_size'] % printing_interval == 0):
            printF("# %d / %d" % (stats['prediction_size'], total),
                   experimentId, currentIteration)
        profiler.stop("test batch stats")

        k += nrSamples

    profiler.profile()

    print("Test: %d" % k)
    printF("Total testing error: %.2f" % totalError, experimentId,
           currentIteration)
    printF("Mean testing error: %.8f" % (totalError / float(k)), experimentId,
           currentIteration)

    stats = model.total_statistics(stats,
                                   dataset,
                                   parameters,
                                   total_labels_used=total_labels_used,
                                   digits=False)
    print_stats(stats, parameters)

    if (returnTestSamples):
        return stats, testSamples
    else:
        return stats
def test(model,
         dataset,
         parameters,
         max_length,
         print_samples=False,
         sample_size=False,
         returnTestSamples=False):
    # Test
    print("Testing...")

    total = dataset.lengths[dataset.TEST]
    printing_interval = 1000
    if (parameters['max_testing_size'] is not False):
        total = parameters['max_testing_size']
        printing_interval = 100
    elif (sample_size != False):
        total = sample_size

    # Set up statistics
    stats = set_up_statistics(dataset.output_dim, model.n_max_digits)
    total_labels_used = {k: 0 for k in range(30)}

    # Predict
    printed_samples = False
    totalError = 0.0
    k = 0
    testSamples = []
    while k < total:
        # Get data from batch
        test_data, test_targets, _, test_expressions, \
            nrSamples = get_batch(False, dataset, model, debug=parameters['debug'])

        predictions, other = model.predict(test_data,
                                           test_targets,
                                           nrSamples=nrSamples)
        totalError += other['error']

        profiler.start("test batch stats")
        stats, _ = model.batch_statistics(
            stats,
            predictions,
            test_expressions,
            other,
            nrSamples,
            dataset,
            testInDataset=parameters['test_in_dataset'])

        for j in range(nrSamples):
            total_labels_used[test_expressions[j]] = True

            # Save predictions to testSamples
            if (returnTestSamples):
                strData = map(
                    lambda x: dataset.findSymbol[x],
                    np.argmax(test_targets[j, :, :model.data_dim],
                              len(test_targets.shape) - 2))
                strPrediction = map(lambda x: dataset.findSymbol[x],
                                    predictions[j])
                testSamples.append((strData, strPrediction))

        # Print samples
        if (print_samples and not printed_samples):
            for i in range(nrSamples):
                prefix = ""
                print(prefix + "Data          1: %s" % "".join((map(
                    lambda x: dataset.findSymbol[x],
                    np.argmax(test_targets[i, :, :model.data_dim],
                              len(test_targets.shape) - 2)))))
                print(prefix + "Prediction    1: %s" % "".join(
                    map(lambda x: dataset.findSymbol[x], predictions[i])))
                print(prefix + "Used label    1: %s" % test_expressions[i])
            printed_samples = True

        if (stats['prediction_size'] % printing_interval == 0):
            print("# %d / %d" % (stats['prediction_size'], total))
        profiler.stop("test batch stats")

        k += nrSamples

    profiler.profile()

    print("Total testing error: %.2f" % totalError)

    stats = model.total_statistics(stats, total_labels_used=total_labels_used)
    print_stats(stats, parameters)

    if (returnTestSamples):
        return stats, testSamples
    else:
        return stats
Example #4
def predictInterventionSample():
    """
    Takes a POST variable 'sample' containing a data sample for this model,
    without the '='-symbol.
    Provide intervention as a string symbol.
    """
    response = {
        'success': False
    }
    if ('sample1' in request.form):
        sample1 = request.form['sample1']
        sample2 = request.form['sample2']
        response['sample1'] = sample1
        response['sample2'] = sample2

        interventionLocations = np.zeros((2, data['rnn'].minibatch_size),
                                         dtype='int32')
        interventionLocations[0, 0] = int(request.form['interventionLocation'])
        interventionLocations[1, 0] = interventionLocations[0, 0] + 1
        intervention = request.form['intervention']

        if (data['dataset'].dataset_type ==
                GeneratedExpressionDataset.DATASET_SEQ2NDMARKOV):
            datasample, _, _, _, _ = data['dataset'].processor(
                ";".join([sample1, sample2, "1"]), [], [], [], [])
        else:
            datasample, _, _, _, _ = data['dataset'].processor(
                ";".join([sample1, sample2]), [], [], [], [])
        response['data'] = datasample[0].tolist()

        datasample = data['dataset'].fill_ndarray(datasample, 1).reshape(
            (1, datasample[0].shape[0], datasample[0].shape[1]))
        label = copy.deepcopy(datasample)
        label[0, interventionLocations[0, 0]] = np.zeros(
            (datasample[0].shape[1]), dtype='float32')
        # Only supports interventions on the first sample
        label[0, interventionLocations[0, 0],
              data['dataset'].oneHot[intervention]] = 1.0
        if (len(datasample) < data['rnn'].minibatch_size):
            missing_datapoints = data['rnn'].minibatch_size - datasample.shape[
                0]
            datasample = np.concatenate(
                (datasample,
                 np.zeros((missing_datapoints, datasample.shape[1],
                           datasample.shape[2]),
                          dtype='float32')),
                axis=0)
            label = np.concatenate(
                (label,
                 np.zeros((missing_datapoints, datasample.shape[1],
                           datasample.shape[2]),
                          dtype='float32')),
                axis=0)
        prediction, _ = data['rnn'].predict(
            datasample,
            label=label,
            interventionLocations=interventionLocations)

        if (not data['rnn'].only_cause_expression):
            response['prediction1'] = prediction[0][0].tolist()
            response['prediction2'] = prediction[1][0].tolist()
        else:
            response['prediction1'] = prediction[0].tolist()

        response['prediction1Pretty'] = ""
        for index in response['prediction1']:
            if (index == data['dataset'].EOS_symbol_index):
                response['prediction1Pretty'] += "_"
            else:
                response['prediction1Pretty'] += data['dataset'].findSymbol[
                    index]

        if (not data['rnn'].only_cause_expression):
            response['prediction2Pretty'] = ""
            for index in response['prediction2']:
                if (index == data['dataset'].EOS_symbol_index):
                    response['prediction2Pretty'] += "_"
                else:
                    response['prediction2Pretty'] += data[
                        'dataset'].findSymbol[index]

        if (data['rnn'].only_cause_expression is not False):
            prediction = [prediction]

        response['success'] = True

        test_n = 1
        stats, _ = data['rnn'].batch_statistics(
            set_up_statistics(data['rnn'].decoding_output_dim,
                              data['rnn'].n_max_digits),
            prediction, [(sample1, sample2)],
            interventionLocations, {},
            test_n,
            data['dataset'],
            labels_to_use=[(sample1, sample2)])
        response['stats'] = {}
        response['stats']['correct'] = stats['correct']
        response['stats']['valid'] = stats['valid']

    return flask.jsonify(response)
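
# The endpoint above reads its inputs from form-encoded POST data. Below is a
# minimal client sketch; the route path '/predictInterventionSample', the host,
# and the example samples are assumptions for illustration, since the Flask
# route decorator is not shown in this snippet.
import requests

def call_intervention_endpoint(sample1, sample2, intervention, location,
                               host="http://localhost:5000"):
    # Mirror the request.form[...] fields used by predictInterventionSample.
    payload = {
        'sample1': sample1,
        'sample2': sample2,
        'intervention': intervention,
        'interventionLocation': str(location),
    }
    response = requests.post(host + "/predictInterventionSample", data=payload)
    return response.json()

# Hypothetical usage:
# result = call_intervention_endpoint("1+2+3", "2+2+3", "4", 2)
# print(result['prediction1Pretty'])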
        else:
            raise ValueError(
                "Loading pretrained model failed: wrong variables supplied!")

    # Train on all datasets in succession
    # Print settings headers to raw results file
    print("# " + str(parameters))

    # Compute batching variables
    repetition_size = dataset.lengths[dataset.TRAIN]
    if (parameters['max_training_size'] is not False):
        repetition_size = min(parameters['max_training_size'], repetition_size)
    next_testing_threshold = parameters['test_interval'] * repetition_size

    for r in range(parameters['repetitions']):
        stats = set_up_statistics(dataset.output_dim, model.n_max_digits)
        total_error = 0.0
        # Print repetition progress and save to raw results file
        print("Batch %d (repetition %d of %d, dataset 1 of 1) (samples processed after batch: %d)" % \
                (r+1,r+1,parameters['repetitions'],(r+1)*repetition_size))

        # Train model per minibatch
        k = 0
        printedProgress = -1
        while k < repetition_size:
            profiler.start('train batch')
            profiler.start('get train batch')
            data, target, _, target_expressions, nrSamples = \
                get_batch(True, dataset, model,
                          debug=parameters['debug'])
            profiler.stop('get train batch')
Example #6
    def testFsubsBatchStats(self):
        # Label, prediction, first_error, recovery_index, no_recovery_index, errors
        samples = [('131', '121', 1, 1, None, 1),
                   ('131', '131', -1, None, None, 0),
                   ('131', '221', 0, 0, None, 2),
                   ('131', '223', 0, None, 0, 3),
                   ('131', '134', 2, None, 2, 1)]

        for i, (label, pred, tFirstError, tRecovIndex, tNoRecovIndex,
                tErrors) in enumerate(samples):
            stats = set_up_statistics(10, 20, [])
            comparison = [p == l for p, l in zip(pred, label)]
            recovery_index = None
            no_recovery_index = None

            # Use a separate index for the digit position so the sample
            # index i from the outer loop is not shadowed.
            errors = 0
            first_error = -1
            correct_after_first_error = False
            for pos, v in enumerate(comparison):
                if (v):
                    stats['digit_2_correct'][pos] += 1.0
                    if (first_error != -1):
                        correct_after_first_error = True
                else:
                    errors += 1
                    if (first_error == -1):
                        first_error = pos
                    elif (correct_after_first_error):
                        correct_after_first_error = False
                stats['digit_2_prediction_size'][pos] += 1

            if (first_error < 8):
                stats['first_error'][first_error] += 1.0
            else:
                stats['first_error'][8] += 1.0

            if (first_error != -1):
                if (correct_after_first_error
                        and first_error < len(comparison) - 1):
                    stats['recovery'][first_error] += 1.0
                    recovery_index = first_error
                else:
                    stats['no_recovery'][first_error] += 1.0
                    no_recovery_index = first_error

            if (errors > 8):
                errors = 8
            stats['error_size'][errors] += 1.0
            stats['prediction_size'] += 1.0

            self.assertEqual(
                tFirstError, first_error,
                "(%d) first_error: is %d, should be %d" %
                (i, first_error, tFirstError))
            self.assertEqual(
                tRecovIndex, recovery_index,
                "(%d) recovery_index: is %s, should be %s" %
                (i, str(recovery_index), str(tRecovIndex)))
            self.assertEqual(
                tNoRecovIndex, no_recovery_index,
                "(%d) no_recovery_index: is %s, should be %s" %
                (i, str(no_recovery_index), str(tNoRecovIndex)))
            self.assertEqual(
                tErrors, errors,
                "(%d) errors: is %d, should be %d" % (i, errors, tErrors))
Example #7
def test(model, dataset, parameters, max_length, print_samples=False, sample_size=False):
    # Test
    print("Testing...");
        
    total = dataset.lengths[dataset.TEST];
    printing_interval = 100;
    if (parameters['max_testing_size'] is not False):
        total = parameters['max_testing_size'];
    elif (sample_size != False):
        total = sample_size;
    
    # Set up statistics
    stats = set_up_statistics(dataset.output_dim);
    
    # Predict
    printed_samples = False;
    batch_range = range(0,total,model.minibatch_size);
    for _ in batch_range:
        # Get data from batch
        test_data, test_targets, test_labels, test_expressions, \
            possibleInterventions, interventionLocation = get_batch(False, dataset, model, 
                                                                    parameters['intervention_range'], 
                                                                    max_length, debug=parameters['debug'],
                                                                    base_offset=parameters['intervention_base_offset']);
        test_n = model.minibatch_size;
        
        test_targets, _, interventionLocation, _ = \
            dataset.insertInterventions(test_targets, test_expressions, 
                                        interventionLocation, 
                                        possibleInterventions);
        
        prediction, other = model.predict(test_data, label=test_targets, 
                                          interventionLocation=interventionLocation);
        
        # Print samples
        if (print_samples and not printed_samples):
            for i in range(prediction.shape[0]):
                print("# Input: %s" % "".join((map(lambda x: dataset.findSymbol[x], np.argmax(test_data[i],len(test_data.shape)-2)))));
                print("# Label: %s" % "".join((map(lambda x: dataset.findSymbol[x], np.argmax(test_targets[i],len(test_data.shape)-2)))));
                print("# Output: %s" % "".join(map(lambda x: dataset.findSymbol[x], prediction[i])));
            printed_samples = True;
        
        stats = model.batch_statistics(stats, prediction, test_labels, 
                                       test_targets, test_expressions, 
                                       other,
                                       test_n, dataset, 
                                       eos_symbol_index=dataset.EOS_symbol_index);
    
        if (stats['prediction_size'] % printing_interval == 0):
            print("# %d / %d" % (stats['prediction_size'], total));
    
    stats = model.total_statistics(stats);
    
    # Print statistics
    stats_str = str_statistics(0, stats['score'], 
                               digit_score=stats['digit_score'], 
                               prediction_size_histogram=\
                                stats['prediction_size_histogram']);
    print(stats_str);
    
    return stats;
def test(model, dataset, parameters, max_length, base_offset, intervention_range, print_samples=False, 
         sample_size=False, homogeneous=False, returnTestSamples=False):
    # Test
    print("Testing...");
        
    total = dataset.lengths[dataset.TEST];
    printing_interval = 1000;
    if (parameters['max_testing_size'] is not False):
        total = parameters['max_testing_size'];
        printing_interval = 100;
    elif (sample_size != False):
        total = sample_size;
    
    # Set up statistics
    stats = set_up_statistics(dataset.data_dim, model.n_max_digits);
    
    # Predict
    printed_samples = False;
    totalError = 0.0;
    k = 0;
    testSamples = [];
    precisions = [];
    digit_precisions = [];
    while k < total:
        # Get data from batch
        test_data, test_expressions = get_batch(False, dataset, model, 
                                                intervention_range, 
                                                max_length, debug=parameters['debug'],
                                                base_offset=base_offset);
        
        predictions, precision, digit_precision, error = model.predict(test_data); 
        precisions.append(precision);
        digit_precisions.append(digit_precision);
        totalError += error;
        
        if (parameters['only_cause_expression']):
            prediction_1 = predictions;
            predictions = [predictions];
        else:
            prediction_1 = predictions[0];
            prediction_2 = predictions[1];
        
        # Print samples
        if (print_samples and not printed_samples):
            for i in range(model.minibatch_size):
                prefix = "# ";
                if (parameters['only_cause_expression'] is not False):
                    print(prefix + "Data      : %s" % "".join((map(lambda x: dataset.findSymbol[x], 
                                                       np.argmax(test_data[i],len(test_data.shape)-2)))));
                    print(prefix + "Prediction: %s" % "".join(map(lambda x: dataset.findSymbol[x], prediction_1[i])));
                else:
                    print(prefix + "Data       1: %s" % "".join((map(lambda x: dataset.findSymbol[x], 
                                                       np.argmax(test_data[i,:,:model.data_dim/2],len(test_data.shape)-2)))));
                    print(prefix + "Prediction 1: %s" % "".join(map(lambda x: dataset.findSymbol[x], prediction_1[i])));
                    print(prefix + "Data       2: %s" % "".join((map(lambda x: dataset.findSymbol[x], 
                                                       np.argmax(test_data[i,:,model.data_dim/2:],len(test_data.shape)-2)))));
                    print(prefix + "Prediction 2: %s" % "".join(map(lambda x: dataset.findSymbol[x], prediction_2[i])));
            printed_samples = True;

        if (k % printing_interval == 0):
            print("# %d / %d" % (stats['prediction_size'], total));
        
        k += model.minibatch_size;
    
    profiler.profile();
    
    print("Total testing error: %.2f" % totalError);
    
    print_stats(np.mean(precisions), np.mean(digit_precisions));
    
    if (returnTestSamples):
        return stats, testSamples;
    else:
        return stats;
 
# Train on all datasets in succession
# Print settings headers to raw results file
print("# " + str(parameters));

# Compute batching variables
repetition_size = dataset.lengths[dataset.TRAIN];
if (parameters['max_training_size'] is not False):
    repetition_size = min(parameters['max_training_size'], repetition_size);
next_testing_threshold = parameters['test_interval'] * repetition_size;

intervention_locations_train = {k: 0 for k in range(model.n_max_digits)};
for r in range(parameters['repetitions']):
    stats = set_up_statistics(dataset.data_dim, model.n_max_digits);
    total_error = 0.0;
    # Print repetition progress and save to raw results file
    print("Batch %d (repetition %d of %d, dataset 1 of 1) (samples processed after batch: %d)" % \
            (r+1,r+1,parameters['repetitions'],(r+1)*repetition_size));

    # Train model per minibatch
    k = 0;
    printedProgress = -1;
    while k < repetition_size:
        profiler.start('train batch');
        profiler.start('get train batch');
        data, target_expressions = \
            get_batch(True, dataset, model,
                      parameters['intervention_range'], model.n_max_digits,
                      debug=parameters['debug'],