Exemple #1
0
    def get_from_generator_with_offsets(self,
                                        generator,
                                        centroid_indices,
                                        adjust_last=False):
        """Pick samples at the given stream offsets and pair each with response parameters.

        Batches are pulled from ``generator`` one at a time. ``centroid_indices``
        holds offsets counted over the concatenated stream of samples (the first
        sample of the first batch is offset 0). For every offset, the matching
        row of ``x[0]`` is stored together with a freshly drawn response mean
        and response standard deviation.

        Args:
            generator: Iterator yielding ``(x, y)`` pairs where ``x[0]`` is the
                indexable batch of samples.
            centroid_indices: Mutable list of offsets in ascending order.
                It is consumed in place and is empty when the method returns.
            adjust_last: When True, the final centroid draws its response mean
                around ``1 - self.response_mean_of_mean`` instead of
                ``self.response_mean_of_mean``.

        Returns:
            A list of ``(sample, response_mean, response_std)`` tuples, one per
            requested offset, in the order the offsets were given.

        Raises:
            ValueError: If an offset refers to a batch that was already consumed
                (i.e. the offsets are not ascending).
        """
        centroids_tmp, current_idx = [], 0
        while len(centroid_indices) != 0:
            x, _ = next(generator)
            batch_end = current_idx + len(x[0])

            # Strict bound: an offset equal to batch_end is the first sample of
            # the NEXT batch, not a member of this one.
            while len(centroid_indices) != 0 and \
                    centroid_indices[0] < batch_end:
                next_index = centroid_indices.pop(0)
                if 0 <= next_index < current_idx:
                    raise ValueError(
                        "centroid_indices must be sorted in ascending order.")

                is_last_treatment = len(centroid_indices) == 0
                if is_last_treatment and adjust_last:
                    # Last treatment is control = worse expected outcomes.
                    response_mean_of_mean = 1 - self.response_mean_of_mean
                else:
                    response_mean_of_mean = self.response_mean_of_mean

                # NOTE(review): clip_percentage presumably clamps to a valid
                # percentage range - confirm against its definition.
                response_mean = clip_percentage(
                    self.random_generator.normal(response_mean_of_mean,
                                                 self.response_std_of_mean))
                # The constant offset keeps the drawn std strictly positive.
                response_std = clip_percentage(
                    self.random_generator.normal(
                        self.response_mean_of_std,
                        self.response_std_of_std)) + 0.025

                # Translate the stream-wide offset into a position inside the
                # current batch before indexing.
                centroids_tmp.append(
                    (x[0][next_index - current_idx], response_mean,
                     response_std))
            current_idx = batch_end
        return centroids_tmp
Exemple #2
0
    def evaluate_counterfactual(model,
                                generator,
                                num_steps,
                                benchmark,
                                set_name="Test set",
                                with_print=True,
                                selected_slice=-1,
                                stateful_benchmark=True,
                                output_directory=""):
        """Score a model's factual and counterfactual predictions on a benchmark.

        For ``num_steps`` batches drawn from ``generator`` the model is queried
        once per treatment option (every sample is evaluated under every
        treatment). The predictions are then split into the factual outcome
        (the treatment actually assigned) and the counterfactual outcomes (all
        other treatments) and scored.

        Args:
            model: Object exposing ``predict(list_of_inputs)``.
            generator: Iterator yielding ``(inputs, labels)`` or
                ``(inputs, labels, sample_weight)``; ``inputs[0]`` are the
                covariates, ``inputs[1]`` the assigned treatments and, when the
                benchmark has exposure, ``inputs[2]`` the per-treatment dosages.
            num_steps: Number of batches to evaluate.
            benchmark: Benchmark providing ``has_exposure()``,
                ``get_num_treatments()``, ``data_access`` and
                ``set_assign_counterfactuals()``.
            set_name: Label used in printed output and passed to the metric
                helpers.
            with_print: Whether to print a summary line to stderr.
            selected_slice: Index applied when labels or model outputs are
                lists.
            stateful_benchmark: When True, counterfactual assignment is enabled
                on the benchmark for the duration of the call and disabled
                again afterwards, even if evaluation raises.
            output_directory: Forwarded to ``calculate_exposure_metrics``.

        Returns:
            The score dictionary for the factual outcomes, updated with the
            counterfactual (``cf_``), combined (``w_``), exposure (if any) and
            PEHE scores.
        """

        def metric_report(score_dict, labelled_keys):
            # Expand (label, key) pairs into print() arguments of the form
            # label, value, "+-", std - with None for missing entries.
            report = []
            for label, key in labelled_keys:
                report += [
                    label,
                    score_dict.get(key), "+-",
                    score_dict.get(key + "_std")
                ]
            return report

        if stateful_benchmark:
            benchmark.set_assign_counterfactuals(True)

        try:
            is_ihdp = isinstance(benchmark, IHDPBenchmark)
            is_jobs = isinstance(benchmark, JobsBenchmark)
            is_news = isinstance(benchmark, NewsBenchmark)
            has_exposure = benchmark.has_exposure()
            num_treatments = benchmark.get_num_treatments()

            all_outputs, all_ids, all_x, all_treatments = [], [], [], []
            all_mu0, all_mu1, all_e, all_z = [], [], [], []
            for _ in range(num_steps):
                generator_outputs = next(generator)
                if len(generator_outputs) == 3:
                    # Sample weights play no role in evaluation.
                    batch_input, labels_batch, _weights = generator_outputs
                else:
                    batch_input, labels_batch = generator_outputs

                if isinstance(labels_batch, list):
                    labels_batch = labels_batch[selected_slice]

                id_set = get_last_id_set()
                all_ids.append(id_set)

                # Fetch the benchmark-specific ground-truth columns for the
                # samples of this batch.
                if is_ihdp:
                    result = np.array(
                        benchmark.data_access.get_rows(
                            id_set, columns="mu0, mu1"))
                    all_mu0.append(result[:, 0])
                    all_mu1.append(result[:, 1])
                elif is_jobs:
                    result = np.array(
                        benchmark.data_access.get_rows(id_set, columns="e"))
                    all_e.append(result[:, 0])
                elif is_news and has_exposure:
                    result = np.array(
                        benchmark.data_access.get_rows(id_set,
                                                       columns="z")[0])
                    all_z.append(result)

                # Labels are indexed per treatment column below, so a flat
                # label vector is promoted to a single-column matrix.
                if labels_batch.ndim == 1:
                    labels_batch = np.expand_dims(labels_batch, axis=-1)

                treatment_outputs = []
                for treatment_idx in range(num_treatments):
                    not_none_indices = np.where(
                        np.not_equal(labels_batch[:, treatment_idx],
                                     None))[0]
                    if len(not_none_indices) == 0:
                        # NOTE(review): skipping a treatment leaves y_pred and
                        # y_true with fewer columns than num_treatments -
                        # confirm this cannot happen for indexed treatments.
                        continue

                    # Same covariates, but every sample is assigned the
                    # treatment currently being evaluated.
                    current_batch_input = [
                        np.copy(batch_input[0]),
                        np.ones_like(batch_input[1]) * treatment_idx
                    ]

                    if has_exposure:
                        # We eval all potential outcomes at the same dosage points for the multiple treatment metrics.
                        current_batch_input += [
                            batch_input[2][:, treatment_idx]
                        ]

                    model_output = model.predict(current_batch_input)
                    if isinstance(model_output, list):
                        model_output = model_output[selected_slice]

                    none_indices = np.where(
                        np.equal(labels_batch[:, treatment_idx], None))[0]

                    if len(none_indices) != 0:
                        # Keep the label vector at full batch length, with
                        # None wherever no label is available.
                        full_length = len(labels_batch)
                        inferred_labels = np.array([None] * full_length)
                        inferred_labels[not_none_indices] = labels_batch[
                            not_none_indices, treatment_idx]
                        result = (model_output, inferred_labels)
                    else:
                        result = (model_output,
                                  labels_batch[not_none_indices,
                                               treatment_idx])
                    treatment_outputs.append(result)

                # NumPy's stacking functions require a real sequence; a lazy
                # map object is rejected on Python 3.
                y_pred = np.column_stack(
                    [pair[0] for pair in treatment_outputs])
                y_true = np.column_stack(
                    [pair[1] for pair in treatment_outputs])

                all_outputs.append((y_pred, y_true))
                all_x.append(batch_input[0])
                all_treatments.append(batch_input[1])

            all_ids = np.concatenate(all_ids, axis=0)
            all_x = np.concatenate(all_x, axis=0)
            all_treatments = np.concatenate(all_treatments, axis=0)

            if is_ihdp:
                all_mu0 = np.concatenate(all_mu0, axis=0)
                all_mu1 = np.concatenate(all_mu1, axis=0)
            elif is_jobs:
                all_e = np.concatenate(all_e, axis=0)
            elif is_news and has_exposure:
                all_z = np.concatenate(all_z, axis=0)

            y_pred, y_true, _ = ModelEvaluation.get_y_from_outputs(
                model,
                all_outputs,
                num_steps,
                selected_slice=-1,
                selected_index=0)

            # Factual outcome = the column of the treatment actually assigned.
            num_samples = len(y_pred)
            y_pred_f = y_pred[np.arange(num_samples), all_treatments]
            y_true_f = y_true[np.arange(len(y_true)), all_treatments]

            # Counterfactual outcomes = all remaining columns, laid out as
            # (num_treatments - 1) consecutive blocks of one value per sample.
            num_factual = len(y_pred_f)
            y_pred_cf = np.zeros((num_factual * (num_treatments - 1)))
            y_true_cf = np.zeros((num_factual * (num_treatments - 1)))

            treatment_range = np.arange(num_treatments)
            for idx in range(num_factual):
                cf_indices = np.delete(treatment_range, all_treatments[idx])
                for treatment in range(num_treatments - 1):
                    slot = idx + num_factual * treatment
                    y_pred_cf[slot] = y_pred[idx, cf_indices[treatment]]
                    y_true_cf[slot] = y_true[idx, cf_indices[treatment]]

            score_dict_f = ModelEvaluation.calculate_statistics_counterfactual(
                y_true_f, y_pred_f, set_name + "_f", with_print, prefix="f_")
            score_dict_cf = ModelEvaluation.calculate_statistics_counterfactual(
                y_true_cf,
                y_pred_cf,
                set_name + "_cf",
                with_print,
                prefix="cf_")

            y_true_w = np.concatenate([y_true_f, y_true_cf], axis=0)
            y_pred_w = np.concatenate([y_pred_f, y_pred_cf], axis=0)
            score_dict_w = ModelEvaluation.calculate_statistics_counterfactual(
                y_true_w, y_pred_w, set_name + "_w", with_print, prefix="w_")

            score_dict_f.update(score_dict_cf)
            score_dict_f.update(score_dict_w)

            if has_exposure:
                from drnet.models.exposure_metrics import calculate_exposure_metrics
                score_dict_exp = calculate_exposure_metrics(
                    model,
                    benchmark,
                    all_ids,
                    all_x,
                    all_z,
                    y_true_f,
                    num_treatments,
                    output_directory=output_directory)
                score_dict_f.update(score_dict_exp)

            if num_treatments == 2:
                if is_ihdp:
                    score_dict_pehe = ModelEvaluation.calculate_pehe(
                        y_true_f,
                        y_pred_f,
                        y_true_cf,
                        y_pred_cf,
                        all_treatments,
                        all_mu0,
                        all_mu1,
                        all_x,
                        set_name=set_name + "_pehe",
                        prefix="cf_",
                        with_print=with_print)
                else:
                    score_dict_pehe = ModelEvaluation.calculate_est_pehe(
                        y_true_f,
                        y_pred_f,
                        y_true_cf,
                        y_pred_cf,
                        all_treatments,
                        all_x,
                        all_e,
                        set_name=set_name + "_pehe",
                        prefix="cf_",
                        with_print=with_print,
                        is_jobs=is_jobs)

                if with_print:
                    print(
                        "INFO: Performance on",
                        set_name,
                        *metric_report(score_dict_f, [
                            ("MISE =", "mise"),
                            ("RMISE =", "rmise"),
                            ("PE =", "pe"),
                            ("DPE =", "dpe"),
                            ("AAMISE =", "aamise"),
                        ]),
                        file=sys.stderr)

                score_dict_f.update(score_dict_pehe)
            else:
                # With more than two treatments, PEHE is estimated for every
                # unordered pair of treatments and then averaged.
                list_score_dicts_pehe = []
                for i in range(num_treatments):
                    for j in range(i):
                        # i = t0, j = t1
                        t1_indices = np.where(
                            all_treatments == i)[0].tolist()
                        t2_indices = np.where(
                            all_treatments == j)[0].tolist()

                        these_x = np.concatenate(
                            [all_x[t1_indices], all_x[t2_indices]], axis=0)
                        y_pred_these_treatments = np.concatenate(
                            [y_pred[t1_indices], y_pred[t2_indices]], axis=0)
                        y_true_these_treatments = np.concatenate(
                            [y_true[t1_indices], y_true[t2_indices]], axis=0)
                        row_range = np.arange(len(y_pred_these_treatments))

                        these_treatments = np.concatenate([
                            np.ones((len(t1_indices), ), dtype=int) * i,
                            np.ones((len(t2_indices), ), dtype=int) * j
                        ],
                                                          axis=0)

                        these_y_pred_f = y_pred_these_treatments[
                            row_range, these_treatments]
                        these_y_true_f = y_true_these_treatments[
                            row_range, these_treatments]

                        # Within a pair, the counterfactual of one treatment
                        # is simply the other treatment of the pair.
                        inverse_treatments = np.concatenate([
                            np.ones((len(t1_indices), ), dtype=int) * j,
                            np.ones((len(t2_indices), ), dtype=int) * i
                        ],
                                                            axis=0)

                        these_y_pred_cf = y_pred_these_treatments[
                            row_range, inverse_treatments]
                        these_y_true_cf = y_true_these_treatments[
                            row_range, inverse_treatments]

                        # Re-encode the pair as a binary 0/1 treatment vector
                        # for the two-treatment PEHE estimator.
                        these_treatments = np.concatenate([
                            np.zeros((len(t1_indices), ), dtype=int),
                            np.ones((len(t2_indices), ), dtype=int)
                        ],
                                                          axis=0)

                        score_dict_pehe = ModelEvaluation.calculate_est_pehe(
                            these_y_true_f,
                            these_y_pred_f,
                            these_y_true_cf,
                            these_y_pred_cf,
                            these_treatments,
                            these_x,
                            all_e,
                            set_name=set_name + "_pehe",
                            prefix="cf_",
                            with_print=False)
                        list_score_dicts_pehe.append(score_dict_pehe)

                score_dict_pehe = {}
                if len(list_score_dicts_pehe) != 0:
                    for key in list_score_dicts_pehe[0].keys():
                        all_values = [
                            pair_scores[key]
                            for pair_scores in list_score_dicts_pehe
                        ]
                        score_dict_pehe[key] = np.mean(all_values)
                        score_dict_pehe[key + "_std"] = np.std(all_values)
                    score_dict_f.update(score_dict_pehe)

                if with_print:
                    report = metric_report(score_dict_pehe, [
                        ("RPEHE =", "cf_pehe"),
                        ("PEHE_NN =", "cf_pehe_nn"),
                        ("ATE =", "cf_ate"),
                    ]) + metric_report(score_dict_f, [
                        ("MISE =", "mise"),
                        ("RMISE =", "rmise"),
                        ("NN_RMISE =", "nn_rmise"),
                        ("PE =", "pe"),
                        ("DPE =", "dpe"),
                        ("AAMISE =", "aamise"),
                    ])
                    print(
                        "INFO: Performance on",
                        set_name,
                        *report,
                        file=sys.stderr)

            return score_dict_f
        finally:
            # Always restore the benchmark, otherwise a failed evaluation would
            # leave counterfactual assignment switched on for later callers.
            if stateful_benchmark:
                benchmark.set_assign_counterfactuals(False)
Exemple #3
0
    def get_matched_generator(self, train_generator, train_steps):
        """Build a generator of propensity-matched training batches.

        The first ``train_steps`` batches of ``train_generator`` are loaded
        into memory. The samples of the smallest treatment group are then
        passed to ``self.batch_augmentation`` to be extended with propensity
        matches, and the resulting sample pool is served in shuffled batches
        of ``self.batch_size``.

        Args:
            train_generator: Iterator yielding ``(x, y)`` where ``x[0]`` are
                the covariates, ``x[1]`` the treatments and, when
                ``self.with_exposure`` is set, ``x[2]`` the exposures.
            train_steps: Number of batches to read from ``train_generator``.

        Returns:
            A ``(generator, steps)`` tuple: an endless generator yielding
            ``([inputs...], labels)`` batches, and the number of such batches
            per pass over the matched pool (at least 1).
        """
        all_x, all_y, all_ids = [], [], []
        for _ in range(train_steps):
            x, y = next(train_generator)
            all_x.append(x)
            all_y.append(y)
            all_ids.append(get_last_id_set())

        # np.concatenate needs a real sequence, so the per-batch parts are
        # collected into lists rather than lazy map objects.
        x = np.concatenate([batch[0] for batch in all_x], axis=0)
        t = np.concatenate([batch[1] for batch in all_x], axis=0)
        y = np.concatenate(all_y, axis=0)
        ids = np.concatenate(all_ids, axis=0)

        if self.with_exposure:
            ts = np.concatenate([batch[2] for batch in all_x], axis=0)

        # Sample positions and group size for every treatment option.
        t_indices = [
            np.where(t == t_idx)[0] for t_idx in range(self.num_treatments)
        ]
        t_lens = [len(indices) for indices in t_indices]

        self.batch_augmentation.make_propensity_lists(ids, self.benchmark,
                                                      **self.args)

        # Switch between starting from the smallest (True) or the largest
        # (False) treatment group.
        undersample = True
        base_treatment_idx = np.argmin(t_lens) if undersample else np.argmax(
            t_lens)
        base_indices = t_indices[base_treatment_idx]
        inner_x, inner_t, inner_y = x[base_indices], t[base_indices], y[
            base_indices]

        if self.with_exposure:
            inner_ts = ts[base_indices]
        else:
            inner_ts = None

        outer_x, outer_t, outer_y, outer_ts = \
            self.batch_augmentation.enhance_batch_with_propensity_matches(self.args,
                                                                          self.benchmark,
                                                                          inner_t,
                                                                          inner_x,
                                                                          inner_y,
                                                                          inner_ts,
                                                                          self.propensity_batch_probability,
                                                                          self.num_randomised_neighbours)

        def outer_generator():
            # Endlessly yields single samples, reshuffled on every pass.
            while True:
                indices = np.random.permutation(outer_x.shape[0])
                for idx in indices:
                    if self.with_exposure:
                        yield outer_x[idx], outer_t[idx], outer_ts[
                            idx], outer_y[idx]
                    else:
                        yield outer_x[idx], outer_t[idx], outer_y[idx]

        def inner_generator(wrapped_generator):
            # Groups single samples into batches; the last field of every
            # sample is the label, all preceding fields are model inputs.
            while True:
                samples = [
                    next(wrapped_generator) for _ in range(self.batch_size)
                ]
                # zip() is lazy on Python 3; materialise it so it can be
                # sliced.
                batch_data = list(zip(*samples))
                yield [np.array(data)
                       for data in batch_data[:-1]], np.array(batch_data[-1])

        new_generator = inner_generator(outer_generator())
        train_steps = max(outer_x.shape[0] // self.batch_size, 1)

        return new_generator, train_steps