示例#1
0
    def plot_params_over_time(self):
        """Plot beta, gamma and delta (and optionally alpha) over time.

        beta is plotted with one value per index of its first dimension.
        gamma and delta are each repeated ``size`` times, so they appear as
        flat lines. When ``self.use_alpha`` is set, alpha is evaluated on
        ``self.get_policy_code(t)`` for every ``t`` in ``range(size)``. Both
        alpha and the element-wise product alpha * beta are then added.

        :return: whatever ``generic_plot`` returns for the collected curves.
        """
        # Imported locally: the original referenced ``numpy.np``, which is not
        # a numpy attribute and raised AttributeError whenever use_alpha was on.
        import numpy as np

        # BETA, GAMMA, DELTA
        size = self.beta.shape[0]
        pl_x = list(range(size))
        beta = self.beta.detach().numpy()
        beta_pl = Curve(pl_x, beta, '-g', "$\\beta$")
        # gamma/delta are repeated once per x value so they plot as constants
        gamma_pl = Curve(pl_x, [self.gamma.detach().numpy()] * size, '-r',
                         "$\\gamma$")
        delta_pl = Curve(pl_x, [self.delta.detach().numpy()] * size, '-b',
                         "$\\delta$")
        params_curves = [beta_pl, gamma_pl, delta_pl]

        if self.use_alpha:
            # each alpha output is flattened to shape (1,) and then joined
            alpha = np.concatenate([
                self.alpha(self.get_policy_code(t)).detach().numpy().reshape(1)
                for t in range(size)
            ], axis=0)
            alpha_pl = Curve(pl_x, alpha, '-', "$\\alpha$")
            beta_alpha_pl = Curve(pl_x, alpha * beta, '-',
                                  "$\\alpha \\cdot \\beta$")
            params_curves.extend([alpha_pl, beta_alpha_pl])

        bgd_pl_title = "beta, gamma, delta"
        return generic_plot(params_curves,
                            bgd_pl_title,
                            None,
                            formatter=format_xtick)
示例#2
0
    def plot_params_over_time(self):
        """Plot beta (one value per index) together with constant gamma and delta.

        :return: the result of ``generic_plot`` for the three curves.
        """
        n_points = self.beta.shape[0]
        xs = list(range(n_points))

        def _constant(tensor):
            # repeat the parameter value once per x value so it plots as a line
            return [tensor.detach().numpy()] * n_points

        curves = [
            Curve(xs, self.beta.detach().numpy(), '-g', "$\\beta$"),
            Curve(xs, _constant(self.gamma), '-r', "$\\gamma$"),
            Curve(xs, _constant(self.delta), '-b', "$\\delta$"),
        ]
        title = "beta, gamma, delta"
        return generic_plot(curves, title, None, formatter=format_xtick)
    def step(self, closure=None):
        """Apply one update to every parameter that has a gradient.

        Without momentum the update is plain gradient descent, ``-lr * grad``.
        With ``self.momentum`` set, the first dimension of each parameter is
        indexed by ``t`` and the update is built recursively::

            eta[t] = lr / (1 + a * t)
            mu[t]  = sigmoid(m * t)
            u[0]   = -eta[0] * grad[0]
            u[t]   = -eta[t] * grad[t] + mu[t] * u[t - 1]

        When ``self.summary`` is set and ``self.epoch`` is a multiple of 50,
        a before/after-momentum plot of each update is added to it.
        ``self.epoch`` is incremented once per call.

        :param closure: optional callable that re-evaluates the model and
            returns the loss. It is called once, with grad enabled, before the
            update, following the ``torch.optim.Optimizer.step`` convention.
        :return: the closure's loss, or None when no closure is given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group["lr"]
            param_name = group["name"]
            for parameter in group["params"]:
                if parameter.grad is None:
                    continue

                d_p = parameter.grad
                if self.momentum:
                    times = torch.arange(parameter.shape[0],
                                         dtype=parameter.dtype,
                                         device=parameter.device)
                    mu = torch.sigmoid(self.m * times)
                    eta = lr / (1 + self.a * times)
                    update = [-eta[0] * d_p[0]]
                    for t in range(1, d_p.size(0)):
                        update.append(-eta[t] * d_p[t] + mu[t] * update[t - 1])
                    # torch.stack keeps dtype/device and works for parameters
                    # of any rank. torch.tensor(list_of_tensors) only accepts
                    # one-element tensors, so it failed for rank >= 2 params.
                    update = torch.stack(update)
                else:
                    update = -lr * d_p

                if self.summary is not None and self.epoch % 50 == 0:
                    pl_x = range(0, update.shape[0])
                    before_m_curve = Curve(
                        pl_x, (-lr * d_p).detach().numpy(),
                        '.',
                        label=f"{param_name} before momentum",
                        color=None)
                    after_m_curve = Curve(pl_x,
                                          update.detach().numpy(),
                                          '.',
                                          label=f"{param_name} after momentum",
                                          color=None)
                    fig = generic_plot([before_m_curve, after_m_curve],
                                       f"{param_name} update over time", None)
                    self.summary.add_figure(f"updates/{param_name}",
                                            fig,
                                            global_step=self.epoch)
                parameter.data.add_(update)

        self.epoch += 1
        return loss
示例#4
0
    def plot_r0(self, r0):
        """Plot the estimated R0 curve, plus the reference R0 when available.

        :param r0: estimated R0 values, indexed along the first dimension
            (must support ``.detach().numpy()``).
        :return: a ``(figure, title)`` tuple.
        """
        n = r0.shape[0]
        xs = list(range(0, n))
        curves = [Curve(xs,
                        r0.detach().numpy(),
                        '-',
                        label="Estimated R0",
                        color=None)]

        if self.references is not None:
            # take only the first n reference values to match the estimate
            reference = self.references["r0"][0:n:1]
            curves.append(Curve(xs,
                                reference,
                                "--",
                                label="Reference R0",
                                color=None))

        title = "Estimated R0"
        figure = generic_plot(curves, title, None,
                              formatter=self.format_xtick)
        return figure, title
示例#5
0
    def plot_params_over_time(self, n_days=None):
        """Build one plot per parameter, grouped by ``self.param_groups``.

        Each parameter is first passed through ``self.extend_param`` with
        ``n_days``. When ``self.references`` contains a reference for the
        parameter key, its first ``n_days`` values are drawn dashed.

        :param n_days: number of x values to plot; defaults to
            ``self.beta.shape[0]``.
        :return: list of ``(figure, title)`` tuples, one per parameter.
        """
        if n_days is None:
            n_days = self.beta.shape[0]

        param_plots = []
        for group_name, keys in self.param_groups.items():
            # look up every key of the group before plotting any of them
            selected = {key: self.params[key] for key in keys}
            for key, value in selected.items():
                value = self.extend_param(value, n_days)
                xs = list(range(n_days))
                symbol = f"$\\{key}$"
                title = f"{group_name}/{symbol} over time"
                curves = [Curve(xs,
                                value.detach().numpy(),
                                '-',
                                symbol,
                                color=None)]

                if self.references is not None and key in self.references:
                    curves.append(Curve(xs,
                                        self.references[key][:n_days],
                                        "--",
                                        f"{symbol} reference",
                                        color=None))

                figure = generic_plot(curves,
                                      title,
                                      None,
                                      formatter=self.format_xtick)
                param_plots.append((figure, title))
        return param_plots
示例#6
0
 def plot_sir_fit(self, w_hat, w_target):
     """Plot the estimated deaths curve against the actual deaths points.

     :param w_hat: estimated values (must support ``.detach().numpy()``).
     :param w_target: actual values, plotted as red dots.
     :return: the result of ``generic_plot``.
     """
     title = "Estimated Deaths on fit"
     xs = list(range(0, w_hat.shape[0]))
     curves = [
         Curve(xs, w_hat.detach().numpy(), '-', label="Estimated Deaths"),
         Curve(xs, w_target, '.r', label="Actual Deaths"),
     ]
     return generic_plot(curves, title, None, formatter=format_xtick)
示例#7
0
    def plot_fits(self):
        """Plot fitted curves and fit errors over the train/val/test splits.

        The model is evaluated via ``self.inference`` on ``dataset_size + 1``
        evenly spaced points from 0 to ``dataset_size``, where
        ``dataset_size = len(self.targets["d"])``. Predictions are then sliced
        with a stride of ``1 / self.time_step``. Targets are sliced with a
        stride of 1 into the train, validation and test ranges.

        For every inferred key except ``"sol"``, one figure is built showing
        the three splits (colored r/b/g). A dashed reference curve is added
        when ``self.references`` provides that key. Predictions of every key
        except ``"r0"`` are normalized by ``self.population``. Keys that also
        appear in the targets get a second figure with prediction - target
        errors.

        :return: list of ``(figure, name)`` tuples.
        """
        fit_plots = []
        with torch.no_grad():
            targets = self.targets
            dataset_size = len(targets["d"])
            t_grid = torch.linspace(0, dataset_size, dataset_size + 1)
            inferences = self.inference(t_grid)

            t_inc = self.time_step
            train_size = self.train_size
            val_size = self.val_size + train_size
            train_range = range(0, train_size)
            val_range = range(train_size, val_size)
            test_range = range(val_size, dataset_size)
            dataset_range = range(0, dataset_size)

            def to_hat_index(days):
                # Map a day count to an index on the inference grid. A small
                # tolerance absorbs float error before truncating. For example,
                # 3 / 0.1 == 29.999999999999996, which int() alone turns into
                # 29 instead of 30.
                return int(days / t_inc + 1e-9)

            t_start = 0
            hat_step = to_hat_index(1)
            train_hat_slice = slice(t_start, to_hat_index(train_size),
                                    hat_step)
            val_hat_slice = slice(to_hat_index(train_size),
                                  to_hat_index(val_size), hat_step)
            test_hat_slice = slice(to_hat_index(val_size),
                                   to_hat_index(dataset_size), hat_step)

            train_target_slice = slice(t_start, train_size, 1)
            val_target_slice = slice(train_size, val_size, 1)
            test_target_slice = slice(val_size, dataset_size, 1)
            dataset_target_slice = slice(t_start, dataset_size, 1)

            hat_train = self.slice_values(inferences, train_hat_slice)
            hat_val = self.slice_values(inferences, val_hat_slice)
            hat_test = self.slice_values(inferences, test_hat_slice)

            split_target_train = self.slice_values(targets, train_target_slice)
            split_target_val = self.slice_values(targets, val_target_slice)
            split_target_test = self.slice_values(targets, test_target_slice)

            norm_hat_train = self.normalize_values(hat_train, self.population)
            norm_hat_val = self.normalize_values(hat_val, self.population)
            norm_hat_test = self.normalize_values(hat_test, self.population)

            norm_target_train = self.normalize_values(split_target_train,
                                                      self.population)
            norm_target_val = self.normalize_values(split_target_val,
                                                    self.population)
            norm_target_test = self.normalize_values(split_target_test,
                                                     self.population)

            for key in inferences.keys():
                if key in ["sol"]:
                    continue

                # every key except "r0" is plotted normalized by population
                if key not in ["r0"]:
                    curr_hat_train = norm_hat_train[key]
                    curr_hat_val = norm_hat_val[key]
                    curr_hat_test = norm_hat_test[key]
                else:
                    curr_hat_train = hat_train[key]
                    curr_hat_val = hat_val[key]
                    curr_hat_test = hat_test[key]

                if key in norm_target_train:
                    curr_target_train = norm_target_train[key]
                    curr_target_val = norm_target_val[key]
                    curr_target_test = norm_target_test[key]
                else:
                    curr_target_train = None
                    curr_target_val = None
                    curr_target_test = None

                tot_curves = (
                    self.get_curves(train_range, curr_hat_train,
                                    curr_target_train, key, 'r')
                    + self.get_curves(val_range, curr_hat_val,
                                      curr_target_val, key, 'b')
                    + self.get_curves(test_range, curr_hat_test,
                                      curr_target_test, key, 'g'))

                # skip the reference curve for keys without reference data
                # instead of failing with KeyError
                if self.references is not None and key in self.references:
                    reference_curve = Curve(
                        list(dataset_range),
                        self.references[key][dataset_target_slice],
                        "--",
                        label="Reference (Nature)")
                    tot_curves = tot_curves + [reference_curve]

                pl_title = f"{key.upper()} - train/validation/test/reference"
                fig = generic_plot(tot_curves,
                                   pl_title,
                                   None,
                                   formatter=self.format_xtick)
                fit_plots.append((fig, f"Estimated {key.upper()} on fit"))

                if curr_target_train is not None:
                    # error plots: prediction minus target, per split
                    pl_title = f"{key.upper()} - errors"
                    fig_name = f"Error {key.upper()} on fit"
                    curr_errors_train = (curr_hat_train
                                         - np.array(curr_target_train))
                    curr_errors_val = curr_hat_val - np.array(curr_target_val)
                    curr_errors_test = (curr_hat_test
                                        - np.array(curr_target_test))

                    tot_curves = (
                        self.get_curves(train_range, curr_errors_train, None,
                                        key, 'r')
                        + self.get_curves(val_range, curr_errors_val, None,
                                          key, 'b')
                        + self.get_curves(test_range, curr_errors_test, None,
                                          key, 'g'))

                    fig = generic_plot(tot_curves,
                                       pl_title,
                                       None,
                                       formatter=self.format_xtick)
                    fit_plots.append((fig, fig_name))

        return fit_plots
示例#8
0
    def _plot_final_inferences(self, hat_t, target_t, dataset_target_slice):
        """
        Plot the final inferences over the train/validation/test splits.

        One figure per inferred key (except "sol") is added to
        ``self.summary`` under the tag ``final/<key>_global``. Each figure
        contains:
        - the estimated curve for every split (colored r/b/g);
        - the actual values, when the key is in ``self.dataset.targets``;
        - a dashed reference curve, when ``self.references`` provides the key.

        Every key except "r0" is plotted normalized by the population.

        :param hat_t: a tuple with the train, val and test inferences
        :param target_t: a tuple with the train, val and test targets
        :param dataset_target_slice: slice selecting the reference values
            for the whole dataset range
        :return: None
        """

        hat_train, hat_val, hat_test = hat_t
        target_train, target_val, target_test = target_t

        # get normalized values
        population = self.model_params["population"]
        norm_hat_train = self.normalize_values(hat_train, population)
        norm_hat_val = self.normalize_values(hat_val, population)
        norm_hat_test = self.normalize_values(hat_test, population)
        norm_target_train = self.normalize_values(target_train, population)
        norm_target_val = self.normalize_values(target_val, population)
        norm_target_test = self.normalize_values(target_test, population)

        # ranges for train/val/test
        dataset_size = len(self.dataset.inputs)
        # Validation covers the next val_len days, capped so that the test
        # split keeps at least the last 5 days.
        train_size, val_len = self.dataset.train_size, self.dataset.val_len
        val_size = min(train_size + val_len, dataset_size - 5)

        train_range = range(0, train_size)
        val_range = range(train_size, val_size)
        test_range = range(val_size, dataset_size)
        dataset_range = range(0, dataset_size)

        def get_curves(x_range, hat, target, key, color=None):
            # Estimated curve, plus the actual points when target is given.
            pl_x = list(x_range)
            hat_curve = Curve(pl_x,
                              hat,
                              '-',
                              label=f"Estimated {key.upper()}",
                              color=color)
            if target is None:
                return [hat_curve]
            target_curve = Curve(pl_x,
                                 target,
                                 '.',
                                 label=f"Actual {key.upper()}",
                                 color=color)
            return [hat_curve, target_curve]

        for key in self.inferences.keys():

            # skippable keys
            if key in ["sol"]:
                continue

            # separate keys that should be normalized to 1
            if key not in ["r0"]:
                curr_hat_train = norm_hat_train[key]
                curr_hat_val = norm_hat_val[key]
                curr_hat_test = norm_hat_test[key]
            else:
                curr_hat_train = hat_train[key]
                curr_hat_val = hat_val[key]
                curr_hat_test = hat_test[key]

            if key in self.dataset.targets:
                # plot inferences together with their targets
                curr_target_train = norm_target_train[key]
                curr_target_val = norm_target_val[key]
                curr_target_test = norm_target_test[key]
            else:
                curr_target_train = None
                curr_target_val = None
                curr_target_test = None

            tot_curves = (
                get_curves(train_range, curr_hat_train, curr_target_train,
                           key, 'r')
                + get_curves(val_range, curr_hat_val, curr_target_val, key,
                             'b')
                + get_curves(test_range, curr_hat_test, curr_target_test,
                             key, 'g'))

            # get reference in range of interest; keys without reference
            # data are plotted without it instead of raising KeyError
            if self.references is not None and key in self.references:
                ref_y = self.references[key][dataset_target_slice]
                reference_curve = Curve(list(dataset_range),
                                        ref_y,
                                        "--",
                                        label="Reference (Nature)")
                tot_curves = tot_curves + [reference_curve]

            pl_title = f"{key.upper()} - train/validation/test/reference"
            fig = generic_plot(tot_curves,
                               pl_title,
                               None,
                               formatter=self.model.format_xtick)
            self.summary.add_figure(f"final/{key}_global", fig)
示例#9
0
            param_key = param_keys
            param_hat_legend = f"$\\{param_key}$"


        param = model.extend_param(model.params[param_key], max_len)
        param_hat_curve = Curve(pl_x, param[:max_len].detach().numpy(), '-', param_hat_legend, color) 
        
        param_ref_legend = f"{param_hat_legend} reference"
        param_ref_curve = Curve(pl_x, references[param_key][:max_len].numpy(), '--', param_ref_legend, color) 
        curves = curves + [param_hat_curve, param_ref_curve]
        #print(len(param_hat_curve.x))
        #print(len(param_ref_curve.x))
        
    filename = filename_from_title(title)
    save_path = os.path.join(base_figures_path, filename)
    plot = generic_plot(curves, title, save_path, formatter=format_xtick, yaxis_sci=True, close=False)


# %%
# plot fits: normalize inferences and targets by population, then cut each
# into the train / validation / test splits

normalized_inferences = model.normalize_values(inferences, model.population)
norm_hat_train, norm_hat_val, norm_hat_test = (
    slice_values(normalized_inferences, hat_slice)
    for hat_slice in (train_hat_slice, val_hat_slice, test_hat_slice)
)

normalized_targets = model.normalize_values(targets, model.population)
norm_target_train, norm_target_val, norm_target_test = (
    slice_values(normalized_targets, target_slice)
    for target_slice in (train_target_slice, val_target_slice,
                         test_target_slice)
)