コード例 #1
0
    def cross_validate(self, learner, x, labels):
        """Score ``learner`` with ``self.k``-fold cross validation.

        For every fold, ``x`` and ``labels`` are split with
        ``self._partition``; the learner is re-fitted on the training part
        (``fit`` receives the training samples only) and then scored on both
        the training and the validation part, using class label 1 for
        precision / recall / F-1.

        Args:
            learner: object that ``has_func`` reports as providing both
                ``fit`` and ``predict``.
            x: samples, in whatever form ``self._partition`` accepts.
            labels: ground-truth labels, partitioned the same way as ``x``.

        Returns:
            A 4-tuple ``(train_summary, val_summary, train_scores,
            val_scores)``. The summaries are the results of
            ``self._aggregate_scores`` applied to the running totals and
            ``self.k``; the last two items are lists holding one score dict
            per fold (keys ``accuracy``, ``precision``, ``recall``, ``f-1``).

        Raises:
            ValueError: if the learner lacks ``fit`` or ``predict``.
        """
        if not (has_func(learner, "fit") and has_func(learner, "predict")):
            raise ValueError("Learner doesn't have fit(x) or predict(x) functions implemented")

        metric_names = ("accuracy", "precision", "recall", "f-1")
        train_totals = dict.fromkeys(metric_names, 0.0)
        val_totals = dict.fromkeys(metric_names, 0.0)
        train_per_fold = []
        val_per_fold = []

        def _score(expected, predicted):
            # Accuracy is computed before the precision/recall/F-1 triple.
            acc = accuracy(expected, predicted)
            prec, rec, f_one = get_metrics(expected, predicted, class_label=1)
            return {"accuracy": acc, "precision": prec, "recall": rec, "f-1": f_one}

        for fold_index in range(self.k):
            fit_part, held_out = self._partition(x, fold_index)
            fit_labels, held_out_labels = self._partition(labels, fold_index)

            learner.fit(fit_part)
            fit_predicted = learner.predict(fit_part)
            held_out_predicted = learner.predict(held_out)

            fold_train = _score(fit_labels, fit_predicted)
            train_per_fold.append(fold_train)
            self._update_scores(fold_train, train_totals)

            fold_val = _score(held_out_labels, held_out_predicted)
            val_per_fold.append(fold_val)
            self._update_scores(fold_val, val_totals)

        train_summary = self._aggregate_scores(train_totals, self.k)
        val_summary = self._aggregate_scores(val_totals, self.k)
        return train_summary, val_summary, train_per_fold, val_per_fold
コード例 #2
0
ファイル: model.py プロジェクト: chmod644/tgs_salt_model
def compile_model(model,
                  optimizer='adam',
                  loss='bce-dice',
                  threshold=0.5,
                  dice=False,
                  weight_decay=0.0,
                  exclude_bn=True,
                  deep_supervised=False):
    """Compile ``model`` with the selected loss, optimizer and metrics.

    Args:
        model: object exposing a Keras-style ``compile`` method.
        optimizer: forwarded to ``model.compile``. The string ``'msgd'`` is
            replaced by ``optimizers.SGD(momentum=0.9)``.
        loss: name of the loss; one of ``'bce'``, ``'bce-dice'``,
            ``'lovasz'``, ``'lovasz-dice'``, ``'lovasz-inv'`` or
            ``'lovasz-double'``.
        threshold: forwarded to ``get_metrics``.
        dice: not used by this function; kept so existing callers keep
            working.
        weight_decay: when non-zero, ``l2_loss(weight_decay, exclude_bn)`` is
            added to the selected loss.
        exclude_bn: forwarded to ``l2_loss``.
        deep_supervised: when true, compile with separate losses, loss
            weights and metrics for the outputs named ``output_final``,
            ``output_pixel`` and ``output_image``.

    Returns:
        The same ``model`` object, after ``compile`` has been called on it.

    Raises:
        ValueError: if ``loss`` is not one of the supported names.
    """
    if loss == 'bce':
        _loss = weighted_binary_crossentropy
    elif loss == 'bce-dice':
        _loss = weighted_bce_dice_loss
    elif loss == 'lovasz':
        _loss = weighted_lovasz_hinge
    elif loss == 'lovasz-dice':
        _loss = weighted_lovasz_dice_loss
    elif loss == 'lovasz-inv':
        _loss = weighted_lovasz_hinge_inversed
    elif loss == 'lovasz-double':
        _loss = weighted_lovasz_hinge_double
    else:
        # Fail here with a clear message instead of hitting an unbound
        # ``_loss`` a few lines further down.
        raise ValueError(
            "Unknown loss {!r}; expected one of 'bce', 'bce-dice', 'lovasz', "
            "'lovasz-dice', 'lovasz-inv', 'lovasz-double'".format(loss))

    if weight_decay != 0.0:
        _l2_loss = l2_loss(weight_decay, exclude_bn)
        loss = lambda true, pred: _loss(true, pred) + _l2_loss
    else:
        loss = _loss

    if optimizer == 'msgd':
        optimizer = optimizers.SGD(momentum=0.9)

    if not deep_supervised:
        model.compile(optimizer=optimizer,
                      loss=loss,
                      metrics=get_metrics(threshold))
    else:
        # The pixel head uses the loss wrapped by ``loss_noempty``.
        loss_pixel = loss_noempty(loss)
        losses = {
            'output_final': loss,
            'output_pixel': loss_pixel,
            'output_image': bce_with_logits
        }
        loss_weights = {
            'output_final': 1.0,
            'output_pixel': 0.5,
            'output_image': 0.1
        }
        metrics = {
            'output_final': get_metrics(threshold),
            'output_pixel': get_metrics(threshold),
            'output_image': accuracy_with_logits
        }
        model.compile(optimizer=optimizer,
                      loss=losses,
                      loss_weights=loss_weights,
                      metrics=metrics)
    return model
コード例 #3
0
ファイル: run_al.py プロジェクト: beckdaniel/uncertainty_qe
def train_al_and_report(model_name, kernel, warp, ard):
    """Run the active-learning loop and save the resulting metrics list.

    For each fold (currently only fold 0) the first 50 training rows seed the
    training set and the remaining rows form the pool. A GP is retrained
    after every instance returned by ``util.query`` is moved from the pool to
    the training set. The loop ends once the training set holds 500 rows or
    the pool is exhausted. Test metrics are recorded after every retraining
    and once more for the final training set, then written to
    ``<output_dir>/metrics``.

    Args:
        model_name: not used by this function.
        kernel: forwarded to ``util.train_gp_model``.
        warp: forwarded to ``util.train_gp_model``.
        ard: forwarded to ``util.train_gp_model``. When truthy, a params file
            path is derived from the output directory by replacing ``'True'``
            with ``'False'`` and is passed along as well.

    Raises:
        OSError: if an output directory cannot be created and does not
            already exist.
    """

    def _ensure_dir(path):
        # Tolerate an already-existing directory, but do not hide real
        # failures such as missing permissions.
        try:
            os.makedirs(path)
        except OSError:
            if not os.path.isdir(path):
                raise
            print("skipping output folder")

    dataset_dir = os.path.join(MODEL_DIR, DATASET)
    _ensure_dir(dataset_dir)
    for fold in range(1):
        fold_dir = os.path.join(SPLIT_DIR, DATASET, str(fold))
        train_data = np.loadtxt(os.path.join(fold_dir, 'train'))
        test_data = np.loadtxt(os.path.join(fold_dir, 'test'))
        params_file = None
        output_dir = os.path.join(dataset_dir, str(fold))
        _ensure_dir(output_dir)
        if ard:
            # NOTE(review): relies on the directory name containing 'True' /
            # 'False' to locate the matching non-ARD run -- confirm.
            iso_dir = output_dir.replace('True', 'False')
            params_file = os.path.join(iso_dir, 'params')

        # Split train into train and pool
        pool_data = train_data[50:]
        train_data = train_data[:50]

        metrics_list = []
        # Stop as soon as there is nothing left to query, so a data set with
        # fewer than 500 rows ends cleanly instead of querying an empty pool.
        while pool_data.shape[0] > 0:
            # Train gp
            gp = util.train_gp_model(train_data, kernel, warp, ard, params_file)

            # Get metrics on test
            metrics = util.get_metrics(gp, test_data)
            metrics_list.append([metrics[0], metrics[1], metrics[2][0],
                                 metrics[2][1], metrics[3]])

            # Ask util.query for the next pool instance and its index
            new_instance, new_i = util.query(gp, pool_data)

            # Update train and pool
            train_data = np.append(train_data, [new_instance], axis=0)
            pool_data = np.delete(pool_data, (new_i), axis=0)

            if train_data.shape[0] >= 500:
                break

            print(pool_data.shape)
            gc.collect(2)

        # Final metrics on full train set (sanity check)
        gp = util.train_gp_model(train_data, kernel, warp, ard, params_file)
        metrics = util.get_metrics(gp, test_data)
        metrics_list.append([metrics[0], metrics[1], metrics[2][0],
                             metrics[2][1], metrics[3]])

        util.save_metrics_list(metrics_list, os.path.join(output_dir, 'metrics'))
コード例 #4
0
def main():
    """Fit a decision tree on the input file and print its scores.

    The tree is trained and evaluated on the full data set ("naive"
    results). When the ``validate`` command-line flag is set, a fresh tree is
    additionally evaluated with 10-fold cross validation.
    """
    args = setup_argparser().parse_args()

    source_path = args.file
    tree_options = {"max_depth": args.max_depth, "min_size": args.min_size}
    run_cv = args.validate

    data, labels = import_file(source_path)

    classifier = DecisionTree(**tree_options)
    fitted_tree = classifier.fit(data)
    TreeNode.show_tree(fitted_tree)

    predicted = classifier.predict(data)
    precision, recall, f_one = get_metrics(labels, predicted, class_label=1)
    acc = accuracy(labels, predicted)

    print("Naive results")
    naive_template = "Accuracy: {}, Precision: {}, Recall: {}, F-1: {}"
    print(naive_template.format(acc, precision, recall, f_one))

    if not run_cv:
        return

    # Cross validation uses a new, unfitted tree with the same settings.
    ten_cv = CrossValidation(k=10)
    fresh_classifier = DecisionTree(**tree_options)
    train_scores, val_scores, *_ = ten_cv.cross_validate(
        fresh_classifier, data, labels)

    print("10-fold cross validation")
    cv_template = "Training scores: {0}\nValidation scores: {1}"
    print(cv_template.format(train_scores, val_scores))
    return
コード例 #5
0
ファイル: pipeline.py プロジェクト: beckdaniel/uncertainty_qe
def train_and_report(model_name, kernel, warp, ard):
    """Train a GP on each of the 10 folds and save everything it produces.

    For every fold the train/test splits are loaded, a GP is trained via
    ``util.train_gp_model`` and its parameters, gradients, metrics, cautious
    curves, predictions and asymmetric metrics are written below
    ``<MODEL_DIR>/<DATASET>/<fold>/``.

    Args:
        model_name: not used by this function.
        kernel: forwarded to ``util.train_gp_model``.
        warp: forwarded to ``util.train_gp_model``.
        ard: forwarded to ``util.train_gp_model``. When truthy, a params file
            path is derived from the output directory by replacing ``'True'``
            with ``'False'`` and is passed along as well.

    Raises:
        OSError: if an output directory cannot be created and does not
            already exist.
    """

    def _ensure_dir(path):
        # Tolerate an already-existing directory, but do not hide real
        # failures such as missing permissions.
        try:
            os.makedirs(path)
        except OSError:
            if not os.path.isdir(path):
                raise
            print("skipping output folder")

    dataset_dir = os.path.join(MODEL_DIR, DATASET)
    _ensure_dir(dataset_dir)
    for fold in range(10):
        fold_dir = os.path.join(SPLIT_DIR, DATASET, str(fold))
        train_data = np.loadtxt(os.path.join(fold_dir, 'train'))
        test_data = np.loadtxt(os.path.join(fold_dir, 'test'))
        params_file = None
        output_dir = os.path.join(dataset_dir, str(fold))
        _ensure_dir(output_dir)

        if ard:
            # NOTE(review): relies on the directory name containing 'True' /
            # 'False' to locate the matching non-ARD run -- confirm.
            iso_dir = output_dir.replace('True', 'False')
            params_file = os.path.join(iso_dir, 'params')
        gp = util.train_gp_model(train_data, kernel, warp, ard, params_file)
        util.save_parameters(gp, os.path.join(output_dir, 'params'))
        util.save_gradients(gp, os.path.join(output_dir, 'grads'))
        metrics = util.get_metrics(gp, test_data)

        util.save_metrics(metrics, os.path.join(output_dir, 'metrics'))
        util.save_cautious_curves(gp, test_data, os.path.join(output_dir, 'curves'))
        util.save_predictions(gp, test_data, os.path.join(output_dir, 'preds'))

        asym_metrics = util.get_asym_metrics(gp, test_data)
        util.save_asym_metrics(asym_metrics, os.path.join(output_dir, 'asym_metrics'))
        gc.collect(2)  # buggy GPy has allocation cycles...
コード例 #6
0
ファイル: pipeline.py プロジェクト: beckdaniel/uncertainty_qe
def train_and_report(model_name, kernel, warp, ard, likelihood='gaussian'):
    """Train a GP on each of the 10 folds and save its outputs.

    For every fold the train/test splits are loaded, a GP is trained via
    ``util.train_gp_model`` and its parameters, metrics, cautious curves and
    predictions are written below ``<MODEL_DIR>/<DATASET>/<fold>/``. The
    per-fold output directory is only created after training and evaluation
    have finished.

    Args:
        model_name: not used by this function.
        kernel: forwarded to ``util.train_gp_model``.
        warp: forwarded to ``util.train_gp_model``.
        ard: forwarded to ``util.train_gp_model``. When truthy, a params file
            path is derived from the output directory by replacing ``'True'``
            with ``'False'`` and is passed along as well.
        likelihood: forwarded to ``util.train_gp_model`` as a keyword
            argument.

    Raises:
        OSError: if an output directory cannot be created and does not
            already exist.
    """

    def _ensure_dir(path):
        # Tolerate an already-existing directory, but do not hide real
        # failures such as missing permissions.
        try:
            os.makedirs(path)
        except OSError:
            if not os.path.isdir(path):
                raise
            print("skipping output folder")

    dataset_dir = os.path.join(MODEL_DIR, DATASET)
    _ensure_dir(dataset_dir)
    for fold in range(10):
        fold_dir = os.path.join(SPLIT_DIR, DATASET, str(fold))
        train_data = np.loadtxt(os.path.join(fold_dir, 'train'))
        test_data = np.loadtxt(os.path.join(fold_dir, 'test'))
        output_dir = os.path.join(dataset_dir, str(fold))
        params_file = None
        if ard:
            # NOTE(review): relies on the directory name containing 'True' /
            # 'False' to locate the matching non-ARD run -- confirm.
            iso_dir = output_dir.replace('True', 'False')
            params_file = os.path.join(iso_dir, 'params')
        gp = util.train_gp_model(train_data, kernel, warp, ard, params_file, likelihood=likelihood)
        metrics = util.get_metrics(gp, test_data)

        _ensure_dir(output_dir)
        util.save_parameters(gp, os.path.join(output_dir, 'params'))
        util.save_metrics(metrics, os.path.join(output_dir, 'metrics'))
        util.save_cautious_curves(gp, test_data, os.path.join(output_dir, 'curves'))
        util.save_predictions(gp, test_data, os.path.join(output_dir, 'preds'))
コード例 #7
0
ファイル: main.py プロジェクト: zsun1029/Image-Caption
def test(epoch, valtest_loader, tag):
    """Caption every image of ``valtest_loader`` and log the resulting metrics.

    The module-level ``adaptive`` model is switched to eval mode, one caption
    per batch is sampled, the captions are dumped to
    ``<opt.save>epoch_<epoch>_<tag><base_name>.json`` and the path without the
    ``.json`` suffix is handed to ``get_metrics``.

    Args:
        epoch: epoch number; used in the output file name and log messages.
        valtest_loader: iterable yielding ``(image, img_id)`` batches. Only
            the first id and first caption of each batch are stored.
        tag: label used in the file name, the progress output and the log
            line.
    """
    # The timestamp goes through the configured logger.
    logging.info(time.strftime('time:%m.%d_%H:%M', time.localtime(time.time())))
    adaptive.eval()
    use_cuda = torch.cuda.is_available()
    prediction = []
    for batch_idx, (image, img_id) in enumerate(valtest_loader, 1):
        # Batches whose image field is a tuple are skipped.
        if isinstance(image, tuple):
            continue
        start_time = time.time()
        source = Variable(image)
        # Fall back to the CPU tensor when CUDA is unavailable; previously
        # ``src`` stayed undefined in that case.
        src = source.cuda() if use_cuda else source
        pred = adaptive.sampler(opt, src)
        # Map every predicted index sequence to a space-separated caption.
        pred = [' '.join(vocab.idx2word[x] for x in p) for p in pred]
        prediction.append({'image_id': img_id[0], 'caption': pred[0]})

        elapsed = time.time() - start_time
        if batch_idx % 1000 == 0:
            print(tag + ' Epoch: [{}/{} ({:.0f}%)]\ts/batch: {:.2f}'.format(
                batch_idx, len(valtest_loader), 100. * batch_idx / len(valtest_loader), elapsed) )

    # Name of the file generated for this test/val run.
    name = ('epoch_%d_%s' % (epoch, tag)) + base_name

    predictfile = '%s%s'%(opt.save, name)
    # Close the file before get_metrics is given its path.
    with open(predictfile+'.json', 'w') as json_file:
        json.dump(prediction, json_file)
    metrics = get_metrics(tag, predictfile)
    logging.info("%s epoch %s metrics %s" %(tag, str(epoch), str(metrics)))
    print("%s epoch %s metrics %s" %(tag, str(epoch), str(metrics)))
コード例 #8
0
def main():
    """Evaluate the loaded model on the test mixtures and save the metrics.

    Spectra and bands are used from index 104 onwards. Every mixture whose
    category passes ``isValidMixture`` is run through the model; its metrics
    are collected and its summed squared abundance error is accumulated. The
    mean of those per-mixture sums is printed (``nan`` when no mixture was
    valid) and the metrics are written with ``save_list_of_dicts``.
    """
    result_path, today = create_result_directory()
    run_id = get_next_index(result_path + today + 'InfraRender_metrics')
    mixtures = MixedSpectra(ascii_spectra="input/test/mtes_kimmurray_rocks_full_tab.txt",
                            meta_csv="input/test/mtes_kimmurray_rocks_full_meta.csv")

    spectra = torch.from_numpy(mixtures.spectra[104:, :].T).type(dtype).to(device)
    abundances = torch.from_numpy(mixtures.abundances).type(dtype).to(device)
    wavenumbers = torch.from_numpy(mixtures.bands[104:]).type(dtype).to(device)

    model = load_model(model_params=param_path, dispersion_params=dispersion_params, wavenumbers=wavenumbers)
    metrics = []
    squared_error = 0
    count = 0
    for j, (cur_spectra, cur_abundances) in enumerate(zip(spectra, abundances)):
        # The category is looked up by the mixture's position in the data.
        if not isValidMixture(mixtures.category[j]):
            continue
        pred_spectra, pred_abundances = model.forward(cur_spectra.unsqueeze(0))
        pred_abundances = consolidate_feely(pred_abundances[0])
        metrics.append(get_metrics(modelSpectra=pred_spectra, trueSpectra=cur_spectra,
                                   modelAbundances=pred_abundances, trueAbundances=cur_abundances))
        squared_error += ((pred_abundances - cur_abundances) ** 2).sum().item()
        count += 1

    # Avoid a ZeroDivisionError when no mixture was valid.
    average_error = squared_error / count if count else float('nan')
    print('average error', average_error)
    save_list_of_dicts(metrics, "InfraRender_metrics", result_path, today, run_id)
コード例 #9
0
def main():
    """Fit a random forest on the input file and print its scores.

    The forest is first trained and evaluated on the full data set ("naive"
    results); a second, identically configured forest is then evaluated with
    10-fold cross validation.
    """
    args = setup_argparser().parse_args()

    source_path = args.file
    # Both forests are built from the same settings.
    forest_options = {
        "num_trees": args.num_trees,
        "sampling_ratio": args.sampling_ratio,
        "max_depth": args.max_depth,
        "min_size": args.min_size,
        "features_ratio": args.features_ratio,
    }

    data, labels = import_file(source_path)

    forest = RandomForest(**forest_options)
    forest.fit(data)
    predicted = forest.predict(data)
    precision, recall, f_one = get_metrics(labels, predicted, class_label=1)
    acc = accuracy(labels, predicted)

    print("Naive results")
    naive_template = "Accuracy: {}, Precision: {}, Recall: {}, F-1: {}"
    print(naive_template.format(acc, precision, recall, f_one))

    # Cross validation uses a new, unfitted forest.
    ten_cv = CrossValidation(k=10)
    fresh_forest = RandomForest(**forest_options)
    train_scores, val_scores, *_ = ten_cv.cross_validate(
        fresh_forest, data, labels)

    print("10-fold cross validation")
    cv_template = "Training scores: {0}\nValidation scores: {1}"
    print(cv_template.format(train_scores, val_scores))
    return
コード例 #10
0
def run_test(
    dataset,
    clf_type,
    epochs,
    true_rh1,
    downsample_ratio,
    ordered_models_keys,
    list_of_images=range(10),
    suppress_error=False,
    verbose=False,
    pi1=0.0,
    one_vs_rest=True,
    cv_n_folds=3,
    early_stopping=True,
    pulearning=None,
):
    """Fit and evaluate every requested model on every image class.

    For each class in ``list_of_images`` the data set is relabelled into a
    binary problem, a noisy label vector ``s`` is derived from the true
    labels, and every model named in ``ordered_models_keys`` is fitted and
    scored on the test split.

    Args:
        dataset: forwarded to ``get_dataset`` and, for CNNs, to ``CNN``.
        clf_type: ``"logreg"`` or ``"cnn"``; selects the base classifier.
        epochs: forwarded to ``CNN`` (only used when ``clf_type == "cnn"``).
        true_rh1: cast to float; ``1 - true_rh1`` is the fraction of true
            positives that keep their positive label in ``s``.
        downsample_ratio: cast to float and forwarded to ``downsample``.
        ordered_models_keys: model names, processed in this order.
        list_of_images: class labels to evaluate.
        suppress_error: when true, a failing model is logged and skipped;
            otherwise the exception is re-raised after logging.
        verbose: when true, print the metrics of every fitted model.
        pi1: cast to float; fraction of mislabeled negatives in the labeled
            set.
        one_vs_rest: when true, use all data and label the current class 1;
            otherwise restrict the data to the current class and one other
            class (4, or 7 when the current class is 4).
        cv_n_folds: forwarded to ``fit`` of the rank-pruning / Liu16 models.
        early_stopping: forwarded to ``CNN``.
        pulearning: forwarded to ``fit`` of the models that accept it.

    Returns:
        A dict with the keys ``"metrics"`` (model key -> list of metric
        dicts, one per image), ``"calculated"`` (``(key, image)`` -> True for
        every successful fit) and ``"errors"`` (left empty by this function).

    Raises:
        ValueError: if ``clf_type`` is neither ``"logreg"`` nor ``"cnn"``.
        Exception: whatever a model raises, unless ``suppress_error`` is set.
    """

    # Cast types to ensure consistency for 1 and 1.0, 0 and 0.0
    true_rh1 = float(true_rh1)
    downsample_ratio = float(downsample_ratio)
    pi1 = float(pi1)

    # Load MNIST or CIFAR data
    (X_train_original,
     y_train_original), (X_test_original,
                         y_test_original) = get_dataset(dataset=dataset)
    X_train_original, y_train_original = downsample(X_train_original,
                                                    y_train_original,
                                                    downsample_ratio)

    # Initialize models and result storage
    metrics = {key: [] for key in ordered_models_keys}
    data_all = {"metrics": metrics, "calculated": {}, "errors": {}}
    start_time = dt.now()

    # Run through the ten images class of 0, 1, ..., 9
    for image in list_of_images:
        if one_vs_rest:
            # X_train and X_test will not be modified. All data will be used. Adjust pointers.
            X_train = X_train_original
            X_test = X_test_original

            # Relabel the image data. Make label 1 only for given image.
            y_train = np.array(y_train_original == image, dtype=int)
            y_test = np.array(y_test_original == image, dtype=int)
        else:  # one_vs_other
            # Reduce the dataset to the current image and one other image
            # (4, or 7 when the current image is 4).
            other_image = 4 if image != 4 else 7
            X_train = X_train_original[(y_train_original == image) |
                                       (y_train_original == other_image)]
            y_train = y_train_original[(y_train_original == image) |
                                       (y_train_original == other_image)]
            X_test = X_test_original[(y_test_original == image) |
                                     (y_test_original == other_image)]
            y_test = y_test_original[(y_test_original == image) |
                                     (y_test_original == other_image)]

            # Relabel the data. Make label 1 only for given image.
            y_train = np.array(y_train == image, dtype=int)
            y_test = np.array(y_test == image, dtype=int)

        print()
        print("Evaluating image:", image)
        print("Number of positives in y:", sum(y_train))
        print()
        sys.stdout.flush()

        # Noisy labels: only the leading (1 - true_rh1) share of the true
        # positives (in data order) stays labeled 1; everything else is 0.
        s = y_train * (np.cumsum(y_train) < (1 - true_rh1) * sum(y_train))
        # In the presence of mislabeled negative (negative incorrectly labeled positive):
        # pi1 is the fraction of mislabeled negative in the labeled set:
        num_mislabeled = int(sum(y_train) * (1 - true_rh1) * pi1 / (1 - pi1))
        if num_mislabeled > 0:
            negative_set = s[y_train == 0]
            mislabeled = np.random.choice(len(negative_set),
                                          num_mislabeled,
                                          replace=False)
            negative_set[mislabeled] = 1
            s[y_train == 0] = negative_set

        print("image = {0}".format(image))
        print(
            "Training set: total = {0}, positives = {1}, negatives = {2}, P_noisy = {3}, N_noisy = {4}"
            .format(len(X_train), sum(y_train),
                    len(y_train) - sum(y_train), sum(s),
                    len(s) - sum(s)))
        print("Testing set:  total = {0}, positives = {1}, negatives = {2}".
              format(len(X_test), sum(y_test),
                     len(y_test) - sum(y_test)))

        # Fit different models for PU learning
        for key in ordered_models_keys:
            fit_start_time = dt.now()
            print("\n\nFitting {0} classifier. Default classifier is {1}.".
                  format(key, clf_type))

            if clf_type == "logreg":
                clf = LogisticRegression()
            elif clf_type == "cnn":
                from classifier_cnn import CNN
                from keras import backend as K
                K.clear_session()
                clf = CNN(
                    dataset_name=dataset,
                    num_category=2,
                    epochs=epochs,
                    early_stopping=early_stopping,
                    verbose=1,
                )
            else:
                raise ValueError(
                    "clf_type must be either logreg or cnn for this testing file."
                )

            ps1 = sum(s) / float(len(s))
            py1 = sum(y_train) / float(len(y_train))
            true_rh0 = pi1 * ps1 / float(1 - py1)

            model = get_model(
                key=key,
                rh1=true_rh1,
                rh0=true_rh0,
                clf=clf,
            )

            try:
                # Only the "True Classifier" sees the clean labels; all other
                # models are trained on the noisy labels ``s``.
                if key == "True Classifier":
                    model.fit(X_train, y_train)
                elif key in [
                        "Rank Pruning", "Rank Pruning (noise rates given)",
                        "Liu16 (noise rates given)"
                ]:
                    model.fit(X_train,
                              s,
                              pulearning=pulearning,
                              cv_n_folds=cv_n_folds)
                elif key in ["Nat13 (noise rates given)"]:
                    model.fit(X_train, s, pulearning=pulearning)
                else:  # Elk08, Baseline
                    model.fit(X_train, s)

                pred = model.predict(X_test)
                # Produces only P(y=1|x) for pulearning models because they are binary
                pred_prob = model.predict_proba(X_test)
                pred_prob = pred_prob[:,
                                      1] if key == "True Classifier" else pred_prob

                # Compute metrics
                metrics_dict = get_metrics(pred, pred_prob, y_test)
                elapsed = (dt.now() - fit_start_time).total_seconds()

                if verbose:
                    print(
                        "\n{0} Model Performance at image {1}:\n=================\n"
                        .format(key, image))
                    print("Time Required", elapsed)
                    print("AUC:", metrics_dict["AUC"])
                    print("Error:", metrics_dict["Error"])
                    print("Precision:", metrics_dict["Precision"])
                    print("Recall:", metrics_dict["Recall"])
                    print("F1 score:", metrics_dict["F1 score"])
                    print("rh1:", model.rh1 if hasattr(model, 'rh1') else None)
                    print("rh0:", model.rh0 if hasattr(model, 'rh0') else None)
                    print()

                metrics_dict["image"] = image
                metrics_dict["time_seconds"] = elapsed
                metrics_dict["rh1"] = model.rh1 if hasattr(model,
                                                           'rh1') else None
                metrics_dict["rh0"] = model.rh0 if hasattr(model,
                                                           'rh0') else None

                # Append dictionary of error and loss metrics
                if key not in data_all["metrics"]:
                    data_all["metrics"][key] = [metrics_dict]
                else:
                    data_all["metrics"][key].append(metrics_dict)
                data_all["calculated"][(key, image)] = True

            except Exception as e:
                msg = "Error in {0}, image {1}, rh1 {2}, m {3}: {4}\n".format(
                    key, image, true_rh1, pi1, e)
                print(msg)
                make_sure_path_exists("failed_models/")
                # Text mode: ``msg`` is a str, and writing it to a binary
                # handle raises TypeError on Python 3, which would replace
                # the error being reported here.
                with open("failed_models/" + key + ".txt", "a") as f:
                    f.write(msg)
                if suppress_error:
                    continue
                else:
                    raise
    return data_all
コード例 #11
0
# Fit one model per lab mixture and collect its metrics. ``j`` is the
# position of the current mixture; it indexes ``mixtures.category`` and names
# the result group written for that mixture.
for trueSpectra, trueAbundances in zip(spectra, mixtures.abundances):
    if mixtures.category[j] != 'invalid':
        trueSpectra = torch.from_numpy(trueSpectra).type(dtype).to(device)
        trueAbundances = torch.from_numpy(trueAbundances).type(dtype).to(device)
        model = AnalysisBySynthesis(paramFile=model_file, p=p, lam=p_lambda,
                            wavenumbers=wavenumbers, dtype=dtype, device=device)
        for name, endmember in model.endmemberModels.items():
            endmember.set_constraint_tolerance(freq_tolerance=freq_tol,
                                               gamma_tolerance=gamma_tol,
                                               epsilon_tolerance=eps_tol,
                                               rho_tolerance=rho_tol,
                                               mode_weight_tolerance=mode_weight_tol)
        model.fit(trueSpectra, epochs=epochs, learningRate=lr, betas=betas)
        modelAbundances = consolidate_feely(model.abundances)
        metrics.append(get_metrics(modelSpectra=model.predictedSpectra, trueSpectra=trueSpectra, thresh=min_endmemb,
                                   modelAbundances=modelAbundances, trueAbundances=trueAbundances))
        model.write_results(experiment_path + today + '_labmix_results%d.hdf' % run_id, group_name=str(j))
    # Advance for every mixture, valid or not. Incrementing only in the valid
    # branch left ``j`` stuck on the first 'invalid' entry, so every later
    # mixture was skipped.
    j += 1


# save metrics as a csv file
save_list_of_dicts(metrics, "labmix", experiment_path, today, run_id)

# convert metrics to a pandas dataframe and display average results
metrics =  pd.DataFrame(metrics)
print('squared error:')
print(metrics["squared_error"].mean())
print('\n')
コード例 #12
0
    # Step 3: perform the evaluation

    # Register signal_handler to run at interpreter exit.
    # NOTE(review): presumably it saves the metrics computed so far if a
    # problem occurs -- confirm against signal_handler.
    atexit.register(signal_handler)
    current_writing_file = None

    # Evaluate every (documents, reference, candidate, metrics) entry; the
    # progress bar is disabled unless args.verbose is set.
    for d, r, c, m in tqdm(entries, disable=not args.verbose):
        # Names are derived from the file names: the documents name drops the
        # last two dot-separated parts, the reference and model names drop
        # the first and the last part.
        documents_name = '.'.join(os.path.basename(d).split('.')[:-2])
        reference_name = '.'.join(os.path.basename(r).split('.')[1:-1])
        model_name = '.'.join(os.path.basename(c).split('.')[1:-1])

        logging.info(documents_name + reference_name + model_name)

        # Entries whose metrics cannot be computed are skipped silently.
        # NOTE(review): the bare except also swallows KeyboardInterrupt and
        # SystemExit -- consider narrowing it.
        try:
            metrics = get_metrics(d, r, c, needed_metrics=m)
        except:
            continue

        # One output row per (metric name, value) pair.
        # NOTE(review): output looks like a pandas DataFrame (append with
        # ignore_index) -- confirm.
        for m_name, m_val in metrics:
            # TODO: Remove old entry if existing
            output = output.append(
                {
                    'model': model_name,
                    'dataset': documents_name,
                    'reference': reference_name,
                    'metric': m_name,
                    'value': round(m_val, 4),
                    'time': int(time.time())
                },
                ignore_index=True)
コード例 #13
0
def experiment(model_dir='.',  # pylint: disable=dangerous-default-value
               imagenet_subset_dir=None,
               dataset='cifar10',
               batch_size=256,
               eval_batch_size=1024,
               num_epochs=200,
               learning_rate=0.1,
               aug_imagenet_apply_colour_jitter=False,
               aug_imagenet_greyscale_prob=0.0,
               sgd_momentum=0.9,
               sgd_nesterov=True,
               lr_schedule='stepped',
               lr_sched_steps=[[60, 0.2], [120, 0.04], [160, 0.008]],
               lr_sched_halfcoslength=400.0,
               lr_sched_warmup=5.0,
               l2_reg=0.0005,
               weight_decay=0.0,
               architecture='wrn22_10',
               n_val=5000,
               n_sup=1000,
               teacher_alpha=0.999,
               anneal_teacher_alpha=False,
               unsupervised_regularizer='none',
               cons_weight=1.0,
               conf_thresh=0.97,
               conf_avg=False,
               cut_backg_noise=1.0,
               cut_prob=1.0,
               box_reg_scale_mode='fixed',
               box_reg_scale=0.25,
               box_reg_random_aspect_ratio=False,
               cow_sigma_range=(4.0, 8.0),
               cow_prop_range=(0.25, 1.0),
               mix_regularizer='none',
               mix_aug_separately=False,
               mix_logits=True,
               mix_weight=1.0,
               mix_conf_thresh=0.97,
               mix_conf_avg=True,
               mix_conf_mode='mix_prob',
               ict_alpha=0.1,
               mix_box_reg_scale_mode='fixed',
               mix_box_reg_scale=0.25,
               mix_box_reg_random_aspect_ratio=False,
               mix_cow_sigma_range=(4.0, 8.0),
               mix_cow_prop_range=(0.0, 1.0),
               subset_seed=12345,
               val_seed=131,
               run_seed=None,
               log_fn=print,
               checkpoints='on',
               debug=False):
  """Run experiment."""
  if checkpoints not in {'none', 'on', 'retain'}:
    raise ValueError('checkpoints should be one of (none|on|retain)')

  if checkpoints != 'none':
    checkpoint_path = os.path.join(model_dir, 'checkpoint.pkl')
    checkpoint_new_path = os.path.join(model_dir, 'checkpoint.pkl.new')
  else:
    checkpoint_path = None
    checkpoint_new_path = None

  if dataset not in {'svhn', 'cifar10', 'cifar100', 'imagenet'}:
    raise ValueError('Unknown dataset \'{}\''.format(dataset))

  if architecture not in {'wrn20_10', 'wrn26_10', 'wrn26_2',
                          'wrn20_6_shakeshake', 'wrn26_6_shakeshake',
                          'wrn26_2_shakeshake', 'pyramid',
                          'resnet50', 'resnet101', 'resnet152',
                          'resnet50x2', 'resnet101x2', 'resnet152x2',
                          'resnet50x4', 'resnet101x4', 'resnet152x4',
                          'resnext50_32x4d', 'resnext101_32x8d',
                          'resnext152_32x4d'}:
    raise ValueError('Unknown architecture \'{}\''.format(architecture))

  if lr_schedule not in {'constant', 'stepped', 'cosine'}:
    raise ValueError('Unknown LR schedule \'{}\''.format(lr_schedule))

  if mix_conf_mode not in {'mix_prob', 'mix_conf'}:
    raise ValueError('Unknown mix_conf_mode \'{}\''.format(mix_conf_mode))

  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(model_dir)
  else:
    summary_writer = None

  unsup_reg, augment_twice = build_pert_reg(
      unsupervised_regularizer, cut_backg_noise=cut_backg_noise,
      cut_prob=cut_prob, box_reg_scale_mode=box_reg_scale_mode,
      box_reg_scale=box_reg_scale,
      box_reg_random_aspect_ratio=box_reg_random_aspect_ratio,
      cow_sigma_range=cow_sigma_range, cow_prop_range=cow_prop_range)

  mix_reg = build_mix_reg(
      mix_regularizer, ict_alpha=ict_alpha,
      box_reg_scale_mode=mix_box_reg_scale_mode,
      box_reg_scale=mix_box_reg_scale,
      box_reg_random_aspect_ratio=mix_box_reg_random_aspect_ratio,
      cow_sigma_range=mix_cow_sigma_range, cow_prop_range=mix_cow_prop_range)

  if run_seed is None:
    run_seed = subset_seed << 32 | n_val
  train_rng = jax.random.PRNGKey(run_seed)
  init_rng, train_rng = jax.random.split(train_rng)

  if batch_size % jax.device_count() > 0:
    raise ValueError('Train batch size must be divisible by the number of '
                     'devices')
  if eval_batch_size % jax.device_count() > 0:
    raise ValueError('Eval batch size must be divisible by the number of '
                     'devices')
  local_batch_size = batch_size // jax.host_count()
  local_eval_batch_size = eval_batch_size // jax.host_count()
  device_batch_size = batch_size // jax.device_count()

  if dataset == 'svhn':
    image_size = 32
    top5_err_required = False
    data_source = small_image_data_source.SVHNDataSource(
        n_val=n_val, n_sup=n_sup, train_batch_size=local_batch_size,
        eval_batch_size=local_eval_batch_size,
        augment_twice=augment_twice, subset_seed=subset_seed,
        val_seed=val_seed)
  elif dataset == 'cifar10':
    image_size = 32
    top5_err_required = False
    data_source = small_image_data_source.CIFAR10DataSource(
        n_val=n_val, n_sup=n_sup, train_batch_size=local_batch_size,
        eval_batch_size=local_eval_batch_size, augment_twice=augment_twice,
        subset_seed=subset_seed, val_seed=val_seed)
  elif dataset == 'cifar100':
    image_size = 32
    top5_err_required = False
    data_source = small_image_data_source.CIFAR100DataSource(
        n_val=n_val, n_sup=n_sup, train_batch_size=local_batch_size,
        eval_batch_size=local_eval_batch_size, augment_twice=augment_twice,
        subset_seed=subset_seed, val_seed=val_seed)
  elif dataset == 'imagenet':
    image_size = 224
    top5_err_required = True
    if imagenet_subset_dir is None:
      raise ValueError('Please provide a directory to the imagenet_subset_dir '
                       'command line arg to specify where the ImageNet '
                       'subsets are stored')
    data_source = imagenet_data_source.ImageNetDataSource(
        imagenet_subset_dir, n_val, n_sup, local_batch_size,
        local_eval_batch_size, augment_twice,
        apply_colour_jitter=aug_imagenet_apply_colour_jitter,
        greyscale_prob=aug_imagenet_greyscale_prob,
        load_test_set=(n_val == 0), image_size=image_size,
        subset_seed=subset_seed, val_seed=val_seed)
  else:
    raise RuntimeError

  n_train = data_source.n_train
  train_ds = data_source.train_semisup_ds

  if n_val == 0:
    eval_ds = data_source.test_ds
    n_eval = data_source.n_test
  else:
    eval_ds = data_source.val_ds
    n_eval = data_source.n_val

  log_fn('DATA: |train|={}, |sup|={}, |eval|={}, (|val|={}, |test|={})'.format(
      data_source.n_train, data_source.n_sup, n_eval, data_source.n_val,
      data_source.n_test))

  log_fn('Loaded dataset')

  steps_per_epoch = n_train // batch_size
  steps_per_eval = n_eval // eval_batch_size
  if n_eval % eval_batch_size > 0:
    steps_per_eval += 1
  num_steps = steps_per_epoch * num_epochs

  # Create model
  model_stu, state_stu = create_model(
      init_rng, architecture, device_batch_size, image_size,
      data_source.n_classes)
  state_stu = jax_utils.replicate(state_stu)
  log_fn('Built model')

  # Create optimizer
  optimizer_def = optim.Momentum(learning_rate=learning_rate,
                                 beta=sgd_momentum,
                                 nesterov=sgd_nesterov)

  optimizer_stu = optimizer_def.create(model_stu)
  optimizer_stu = optimizer_stu.replicate()
  del model_stu  # don't keep a copy of the initial model

  # Create learning rate function
  base_learning_rate = learning_rate * batch_size / 256.
  if lr_schedule == 'constant':
    learning_rate_fn = create_constant_learning_rate_fn(base_learning_rate)
  elif lr_schedule == 'stepped':
    learning_rate_fn = create_stepped_learning_rate_fn(
        base_learning_rate, steps_per_epoch, lr_sched_steps=lr_sched_steps,
        warmup_length=lr_sched_warmup)
  elif lr_schedule == 'cosine':
    learning_rate_fn = create_cosine_learning_rate_fn(
        base_learning_rate, steps_per_epoch,
        halfcoslength_epochs=lr_sched_halfcoslength,
        warmup_length=lr_sched_warmup)
  else:
    raise RuntimeError

  if anneal_teacher_alpha:
    if lr_schedule == 'constant':
      one_minus_alpha_fn = create_constant_learning_rate_fn(1.0 - teacher_alpha)
    elif lr_schedule == 'stepped':
      one_minus_alpha_fn = create_stepped_learning_rate_fn(
          1.0 - teacher_alpha, steps_per_epoch, lr_sched_steps=lr_sched_steps)
    elif lr_schedule == 'cosine':
      one_minus_alpha_fn = create_cosine_learning_rate_fn(
          1.0 - teacher_alpha, steps_per_epoch,
          halfcoslength_epochs=lr_sched_halfcoslength)
    else:
      raise RuntimeError
    teacher_alpha_fn = lambda step: 1.0 - one_minus_alpha_fn(step)
  else:
    teacher_alpha_fn = lambda step: teacher_alpha

  log_fn('Built optimizer')

  # The teacher starts out as the student model itself: no copy is needed
  # here, because the model is duplicated whenever it is modified
  model_tea = optimizer_stu.target
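  # Give the teacher its own batch stats tree: the identity tree_map rebuilds
  # the tree around the leaves of state_stu, which is already replicated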
  state_tea = jax.tree_map(lambda x: x, state_stu)

  # Set up epoch and step counter
  # Load existing checkpoint if available
  epoch = 1
  step = 0

  if checkpoints != 'none':
    if tf.io.gfile.exists(checkpoint_path):
      with tf.io.gfile.GFile(checkpoint_path, 'rb') as f_in:
        check = pickle.load(f_in)

        # Student optimizer and batch stats
        optimizer_stu = util.restore_state_list(
            optimizer_stu, check['optimizer_stu'])

        state_stu = util.restore_state_list(
            state_stu, check['state_stu'])

        # Teacher model and batch stats
        model_tea = util.restore_state_list(
            model_tea, check['model_tea'])

        state_tea = util.restore_state_list(
            state_tea, check['state_tea'])

        epoch = check['epoch']
        step = check['step']

        log_fn('Loaded checkpoint from {}'.format(checkpoint_path))

  #
  # Training and evaluation step functions
  #
  p_train_step = jax.pmap(
      functools.partial(train_step, learning_rate_fn=learning_rate_fn,
                        l2_reg=l2_reg, weight_decay=weight_decay,
                        teacher_alpha_fn=teacher_alpha_fn,
                        unsup_reg=unsup_reg, cons_weight=cons_weight,
                        conf_thresh=conf_thresh,
                        conf_avg=conf_avg,
                        mix_reg=mix_reg, mix_aug_separately=mix_aug_separately,
                        mix_logits=mix_logits, mix_weight=mix_weight,
                        mix_conf_thresh=mix_conf_thresh,
                        mix_conf_avg=mix_conf_avg,
                        mix_conf_mode=mix_conf_mode),
      axis_name='batch')
  p_eval_step = jax.pmap(
      functools.partial(eval_step, eval_top_5=top5_err_required),
      axis_name='batch')

  # Create dataset batch iterators
  train_iter = iter(train_ds)
  eval_iter = iter(eval_ds)

  #
  # Training loop
  #

  log_fn('Training...')
  epoch_metrics_stu = []
  t1 = time.time()
  while step < num_steps:
    train_rng, iter_rng = jax.random.split(train_rng)
    batch = next(train_iter)
    batch = jax.tree_map(lambda x: x._numpy(), batch)  # pylint: disable=protected-access
    batch = shard(batch, iter_rng)

    optimizer_stu, state_stu, metrics_stu, model_tea, state_tea = p_train_step(
        optimizer_stu, state_stu, model_tea, state_tea, batch)

    if debug:
      log_fn('Step {} time {}'.format(step, time.time()-t1))

    epoch_metrics_stu.append(metrics_stu)
    if (step + 1) % steps_per_epoch == 0:
      epoch_metrics_stu = util.get_metrics(epoch_metrics_stu)
      train_epoch_metrics = jax.tree_map(lambda x: x.mean(), epoch_metrics_stu)
      if summary_writer is not None:
        for key, vals in epoch_metrics_stu.items():
          tag = 'train_%s' % key
          for i, val in enumerate(vals):
            summary_writer.scalar(tag, val, step - len(vals) + i + 1)

      epoch_metrics_stu = []
      eval_stu_metrics = []
      eval_tea_metrics = []
      for _ in range(steps_per_eval):
        eval_batch = next(eval_iter)
        # TF to NumPy
        eval_batch = jax.tree_map(lambda x: x._numpy(), eval_batch)  # pylint: disable=protected-access
        # Pad short batches
        eval_batch = util.pad_classification_batch(
            eval_batch, local_eval_batch_size)
        # Shard across local devices
        eval_batch = shard(eval_batch)
        metrics_stu = p_eval_step(optimizer_stu.target, state_stu, eval_batch)
        metrics_tea = p_eval_step(model_tea, state_tea, eval_batch)
        eval_stu_metrics.append(metrics_stu)
        eval_tea_metrics.append(metrics_tea)
      eval_stu_metrics = util.get_metrics(eval_stu_metrics)
      eval_tea_metrics = util.get_metrics(eval_tea_metrics)
      eval_stu_epoch_metrics = jax.tree_map(lambda x: x.sum(), eval_stu_metrics)
      eval_tea_epoch_metrics = jax.tree_map(lambda x: x.sum(), eval_tea_metrics)
      eval_stu_epoch_metrics = avg_eval_metrics(eval_stu_epoch_metrics)
      eval_tea_epoch_metrics = avg_eval_metrics(eval_tea_epoch_metrics)

      t2 = time.time()

      if top5_err_required:
        log_fn('EPOCH {} (took {:.3f}s): Train loss={:.6f}, err={:.3%}, '
               'cons loss={:.6f}, conf rate={:.3%}, mix loss={:.6f}, '
               'mix conf rate={:.3%}; STU Eval loss={:.6f}, err={:.3%}, '
               'top-5-err={:.3%}, TEA Eval loss={:.6f}, err={:.3%}, '
               'top-5-err={:.3%}'.format(
                   epoch, t2 - t1, train_epoch_metrics['loss'],
                   train_epoch_metrics['error_rate'],
                   train_epoch_metrics['cons_loss'],
                   train_epoch_metrics['conf_rate'],
                   train_epoch_metrics['mix_loss'],
                   train_epoch_metrics['mix_conf_rate'],
                   eval_stu_epoch_metrics['loss'],
                   eval_stu_epoch_metrics['error_rate'],
                   eval_stu_epoch_metrics['top5_error_rate'],
                   eval_tea_epoch_metrics['loss'],
                   eval_tea_epoch_metrics['error_rate'],
                   eval_tea_epoch_metrics['top5_error_rate'],))
      else:
        log_fn('EPOCH {} (took {:.3f}s): Train loss={:.6f}, err={:.3%}, '
               'cons loss={:.6f}, conf rate={:.3%}, mix loss={:.6f}, '
               'mix conf rate={:.3%}; STU Eval loss={:.6f}, err={:.3%}, '
               'TEA Eval loss={:.6f}, err={:.3%}'.format(
                   epoch, t2 - t1, train_epoch_metrics['loss'],
                   train_epoch_metrics['error_rate'],
                   train_epoch_metrics['cons_loss'],
                   train_epoch_metrics['conf_rate'],
                   train_epoch_metrics['mix_loss'],
                   train_epoch_metrics['mix_conf_rate'],
                   eval_stu_epoch_metrics['loss'],
                   eval_stu_epoch_metrics['error_rate'],
                   eval_tea_epoch_metrics['loss'],
                   eval_tea_epoch_metrics['error_rate'],))

      t1 = t2

      if summary_writer is not None:
        summary_writer.scalar(
            'eval_stu_loss', eval_stu_epoch_metrics['loss'], epoch)
        summary_writer.scalar(
            'eval_stu_error_rate', eval_stu_epoch_metrics['error_rate'], epoch)
        summary_writer.scalar(
            'eval_tea_loss', eval_tea_epoch_metrics['loss'], epoch)
        summary_writer.scalar(
            'eval_tea_error_rate', eval_tea_epoch_metrics['error_rate'], epoch)
        if top5_err_required:
          summary_writer.scalar(
              'eval_stu_top5_error_rate',
              eval_stu_epoch_metrics['top5_error_rate'], epoch)
          summary_writer.scalar(
              'eval_tea_top5_error_rate',
              eval_tea_epoch_metrics['top5_error_rate'], epoch)
        summary_writer.flush()

        epoch += 1

        if checkpoints != 'none':
          if jax.host_id() == 0:
            # Write to new checkpoint file so that we don't immediately
            # overwrite the old one
            with tf.io.gfile.GFile(checkpoint_new_path, 'wb') as f_out:
              check = dict(
                  optimizer_stu=util.to_state_list(optimizer_stu),
                  state_stu=util.to_state_list(state_stu),
                  model_tea=util.to_state_list(model_tea),
                  state_tea=util.to_state_list(state_tea),
                  epoch=epoch,
                  step=step + 1,
              )
              pickle.dump(check, f_out)
              del check
            # Remove old checkpoint and rename
            if tf.io.gfile.exists(checkpoint_path):
              tf.io.gfile.remove(checkpoint_path)
            tf.io.gfile.rename(checkpoint_new_path, checkpoint_path)

    step += 1

  if checkpoints == 'on':
    if jax.host_id() == 0:
      if tf.io.gfile.exists(checkpoint_path):
        tf.io.gfile.remove(checkpoint_path)