Example No. 1
def test_enum_discrete_parallel_iarange_ok():
    enum_discrete = "defined below"

    def model():
        p2 = torch.ones(2) / 2
        p34 = torch.ones(3, 4) / 4
        p536 = torch.ones(5, 3, 6) / 6

        x2 = pyro.sample("x2", dist.Categorical(p2))
        with pyro.iarange("outer", 3):
            x34 = pyro.sample("x34", dist.Categorical(p34))
            with pyro.iarange("inner", 5):
                x536 = pyro.sample("x536", dist.Categorical(p536))

        if enum_discrete == "sequential":
            # All dimensions are iarange dimensions.
            assert x2.shape == torch.Size([])
            assert x34.shape == torch.Size([3])
            assert x536.shape == torch.Size([5, 3])
        else:
            # Meaning of dimensions:    [ enum dims | iarange dims ]
            assert x2.shape == torch.Size([        2, 1, 1])  # noqa: E201
            assert x34.shape == torch.Size([    4, 1, 1, 3])  # noqa: E201
            assert x536.shape == torch.Size([6, 1, 1, 5, 3])  # noqa: E201

    enum_discrete = "sequential"
    assert_ok(model, config_enumerate(model, "sequential"),
              TraceEnum_ELBO(max_iarange_nesting=2))

    enum_discrete = "parallel"
    assert_ok(model, config_enumerate(model, "parallel"),
              TraceEnum_ELBO(max_iarange_nesting=2))
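Note: config_enumerate itself only tags each discrete sample site with an enumeration strategy; TraceEnum_ELBO performs the actual enumeration. A minimal standalone sketch of the equivalent per-site configuration (toy names, not part of the test above):

import torch
import pyro
import pyro.distributions as dist

def manually_configured_model():
    p2 = torch.ones(2) / 2
    # Equivalent to wrapping the whole model in config_enumerate(model, "parallel"):
    # the infer dict marks this single site for parallel enumeration.
    pyro.sample("x2", dist.Categorical(p2), infer={"enumerate": "parallel"})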
Example No. 2
def test_enum_discrete_parallel_iarange_ok():
    enum_discrete = "defined below"

    def model():
        p2 = torch.ones(2) / 2
        p34 = torch.ones(3, 4) / 4
        p536 = torch.ones(5, 3, 6) / 6

        x2 = pyro.sample("x2", dist.Categorical(p2))
        with pyro.iarange("outer", 3):
            x34 = pyro.sample("x34", dist.Categorical(p34))
            with pyro.iarange("inner", 5):
                x536 = pyro.sample("x536", dist.Categorical(p536))

        if enum_discrete == "sequential":
            # All dimensions are iarange dimensions.
            assert x2.shape == torch.Size([])
            assert x34.shape == torch.Size([3])
            assert x536.shape == torch.Size([5, 3])
        else:
            # Meaning of dimensions:    [ enum dims | iarange dims ]
            assert x2.shape == torch.Size([        2, 1, 1])  # noqa: E201
            assert x34.shape == torch.Size([    4, 1, 1, 3])  # noqa: E201
            assert x536.shape == torch.Size([6, 1, 1, 5, 3])  # noqa: E201

    enum_discrete = "sequential"
    assert_ok(model, config_enumerate(model, "sequential"), TraceEnum_ELBO(max_iarange_nesting=2))

    enum_discrete = "parallel"
    assert_ok(model, config_enumerate(model, "parallel"), TraceEnum_ELBO(max_iarange_nesting=2))
Example No. 3
 def _initialize_model_properties(self):
     if self.max_plate_nesting is None:
         self._guess_max_plate_nesting()
     # Wrap model in `poutine.enum` to enumerate over discrete latent sites.
     # No-op if model does not have any discrete latents.
     self.model = poutine.enum(config_enumerate(self.model),
                               first_available_dim=-1 -
                               self.max_plate_nesting)
     if self._automatic_transform_enabled:
         self.transforms = {}
     trace = poutine.trace(self.model).get_trace(*self._args,
                                                 **self._kwargs)
     for name, node in trace.iter_stochastic_nodes():
         if isinstance(node["fn"], _Subsample):
             continue
         if node["fn"].has_enumerate_support:
             self._has_enumerable_sites = True
             continue
         site_value = node["value"]
         if node["fn"].support is not constraints.real and self._automatic_transform_enabled:
             self.transforms[name] = biject_to(node["fn"].support).inv
             site_value = self.transforms[name](node["value"])
         self._r_shapes[name] = site_value.shape
         self._r_numels[name] = site_value.numel()
     self._trace_prob_evaluator = TraceEinsumEvaluator(
         trace, self._has_enumerable_sites, self.max_plate_nesting)
     mass_matrix_size = sum(self._r_numels.values())
     if self.full_mass:
         initial_mass_matrix = eye_like(site_value, mass_matrix_size)
     else:
         initial_mass_matrix = site_value.new_ones(mass_matrix_size)
     self._adapter.inverse_mass_matrix = initial_mass_matrix
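For orientation, first_available_dim=-1 - self.max_plate_nesting reserves the enumeration dimensions to the left of all plate dimensions, so with max_plate_nesting=2 the first enumerated variable is placed in dim -3. A minimal standalone sketch of that behavior (toy model, not part of the class above):

import torch
import pyro
import pyro.poutine as poutine
import pyro.distributions as dist
from pyro.infer import config_enumerate

max_plate_nesting = 2

def toy_model():
    pyro.sample("x", dist.Categorical(torch.ones(3) / 3))

# Enumerate "x" in parallel, placing its 3 support values in dim -3.
enum_model = poutine.enum(config_enumerate(toy_model, "parallel"),
                          first_available_dim=-1 - max_plate_nesting)
trace = poutine.trace(enum_model).get_trace()
assert trace.nodes["x"]["value"].shape == (3, 1, 1)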
Example No. 4
def test_gmm_iter_discrete_traces(data_size, graph_type, model):
    pyro.clear_param_store()
    data = torch.arange(0, data_size)
    model = config_enumerate(model)
    traces = list(iter_discrete_traces(graph_type, model, data=data, verbose=True))
    # This non-vectorized version is exponential in data_size:
    assert len(traces) == 2**data_size
Example No. 5
 def __init__(self,
              model: Type[torch.nn.Module],
              optimizer: Type[optim.PyroOptim] = None,
              loss: Type[infer.ELBO] = None,
              enumerate_parallel: bool = False,
              seed: int = 1,
              **kwargs: Union[str, float]) -> None:
     """
     Initializes the trainer's parameters
     """
     pyro.clear_param_store()
     set_deterministic_mode(seed)
     self.device = kwargs.get(
         "device", 'cuda' if torch.cuda.is_available() else 'cpu')
     if optimizer is None:
         lr = kwargs.get("lr", 1e-3)
         optimizer = optim.Adam({"lr": lr})
     if loss is None:
         if enumerate_parallel:
             loss = infer.TraceEnum_ELBO(max_plate_nesting=1,
                                         strict_enumeration_warning=False)
         else:
             loss = infer.Trace_ELBO()
     guide = model.guide
     if enumerate_parallel:
         guide = infer.config_enumerate(guide, "parallel", expand=True)
     self.svi = infer.SVI(model.model, guide, optimizer, loss=loss)
     self.loss_history = {"training_loss": [], "test_loss": []}
     self.current_epoch = 0
Example No. 6
def test_gmm_batch_iter_discrete_traces(model, data_size, graph_type):
    pyro.clear_param_store()
    data = torch.arange(0, data_size)
    model = config_enumerate(model)
    traces = list(iter_discrete_traces(graph_type, model, data=data))
    # This vectorized version is independent of data_size:
    assert len(traces) == 2
Example No. 7
    def _setup_prototype(self, *args, **kwargs):
        # run the model so we can inspect its structure
        model = config_enumerate(self.model)
        self.prototype_trace = poutine.block(poutine.trace(model).get_trace)(
            *args, **kwargs)
        self.prototype_trace = prune_subsample_sites(self.prototype_trace)
        if self.master is not None:
            self.master()._check_prototype(self.prototype_trace)

        self._discrete_sites = []
        for name, site in self.prototype_trace.iter_stochastic_nodes():
            if site["infer"].get("enumerate") != "sequential":
                raise NotImplementedError(
                    'Expected sample site "{}" to be discrete and '
                    'configured for sequential enumeration'.format(name))

            # collect discrete sample sites
            fn = site["fn"]
            Dist = type(fn)
            if Dist in (dist.Bernoulli, dist.Categorical,
                        dist.OneHotCategorical):
                params = [("probs", fn.probs.detach().clone(),
                           fn.arg_constraints["probs"])]
            else:
                raise NotImplementedError("{} is not supported".format(
                    Dist.__name__))
            self._discrete_sites.append((site, Dist, params))
Example No. 8
def main(args):
    """
    run inference for CVAE
    :param args: arguments for CVAE
    :return: None
    """
    if args.seed is not None:
        set_seed(args.seed, args.cuda)

    if os.path.exists('cvae.model.pt'):
        print('Loading model %s' % 'cvae.model.pt')
        cvae = torch.load('cvae.model.pt')

    else:

        cvae = CVAE(z_dim=args.z_dim,
                    y_dim=8,
                    x_dim=32612,
                    hidden_dim=args.hidden_dimension,
                    use_cuda=args.cuda)

    print(cvae)

    # setup the optimizer
    adam_params = {
        "lr": args.learning_rate,
        "betas": (args.beta_1, 0.999),
        "clip_norm": 0.5
    }
    optimizer = ClippedAdam(adam_params)
    guide = config_enumerate(cvae.guide, args.enum_discrete)

    # set up the loss for inference.
    loss = SVI(cvae.model,
               guide,
               optimizer,
               loss=TraceEnum_ELBO(max_iarange_nesting=1))

    try:
        # setup the logger if a filename is provided
        logger = open(args.logfile, "w") if args.logfile else None

        data_loaders = setup_data_loaders(NHANES, args.cuda, args.batch_size)
        print(len(data_loaders['prediction']))

        #torch.save(cvae, 'cvae.model.pt')

        mu, sigma, actuals, lods, masks = get_predictions(
            data_loaders["prediction"], cvae.sim_measurements)

        torch.save((mu, sigma, actuals, lods, masks), 'cvae.predictions.pt')

    finally:
        # close the logger file object if we opened it earlier
        if args.logfile:
            logger.close()
Example No. 9
def test_svi_step_smoke(model, guide, enumerate1):
    pyro.clear_param_store()
    data = torch.tensor([0.0, 1.0, 9.0])

    guide = config_enumerate(guide, default=enumerate1)
    optimizer = pyro.optim.Adam({"lr": .001})
    elbo = TraceEnum_ELBO(max_iarange_nesting=1,
                          strict_enumeration_warning=any([enumerate1]))
    inference = SVI(model, guide, optimizer, loss=elbo)
    inference.step(data)
Example No. 10
def get_enum_traces(model, x):
    guide_enum = EnumMessenger(first_available_dim=-2)
    model_enum = EnumMessenger()
    guide_ = guide_enum(
        infer.config_enumerate(model.guide, "parallel", expand=True))
    model_ = model_enum(model.model)
    guide_trace = poutine.trace(guide_, graph_type="flat").get_trace(x)
    model_trace = poutine.trace(pyro.poutine.replay(model_, trace=guide_trace),
                                graph_type="flat").get_trace(x)
    return guide_trace, model_trace
Example No. 11
def test_nonnested_iarange_iarange_ok(Elbo):
    def model():
        p = torch.tensor(0.5, requires_grad=True)
        with pyro.iarange("iarange_0", 10, 5) as ind1:
            pyro.sample("x0", dist.Bernoulli(p).expand_by([len(ind1)]))
        with pyro.iarange("iarange_1", 11, 6) as ind2:
            pyro.sample("x1", dist.Bernoulli(p).expand_by([len(ind2)]))

    guide = config_enumerate(model) if Elbo is TraceEnum_ELBO else model
    assert_ok(model, guide, Elbo())
Example No. 12
def test_no_iarange_enum_discrete_batch_error():
    def model():
        p = torch.tensor(0.5)
        pyro.sample("x", dist.Bernoulli(p).expand_by([5]))

    def guide():
        p = pyro.param("p", torch.tensor(0.5, requires_grad=True))
        pyro.sample("x", dist.Bernoulli(p).expand_by([5]))

    assert_error(model, config_enumerate(guide), TraceEnum_ELBO())
Example No. 13
def test_enum_discrete_single_ok():
    def model():
        p = torch.tensor(0.5)
        pyro.sample("x", dist.Bernoulli(p))

    def guide():
        p = pyro.param("p", torch.tensor(0.5, requires_grad=True))
        pyro.sample("x", dist.Bernoulli(p))

    assert_ok(model, config_enumerate(guide), TraceEnum_ELBO())
Example No. 14
def test_no_iarange_enum_discrete_batch_error():

    def model():
        p = torch.tensor(0.5)
        pyro.sample("x", dist.Bernoulli(p).expand_by([5]))

    def guide():
        p = pyro.param("p", torch.tensor(0.5, requires_grad=True))
        pyro.sample("x", dist.Bernoulli(p).expand_by([5]))

    assert_error(model, config_enumerate(guide), TraceEnum_ELBO())
Example No. 15
def test_enum_discrete_single_ok():

    def model():
        p = torch.tensor(0.5)
        pyro.sample("x", dist.Bernoulli(p))

    def guide():
        p = pyro.param("p", torch.tensor(0.5, requires_grad=True))
        pyro.sample("x", dist.Bernoulli(p))

    assert_ok(model, config_enumerate(guide), TraceEnum_ELBO())
Example No. 16
def test_nonnested_iarange_iarange_ok(Elbo):

    def model():
        p = torch.tensor(0.5, requires_grad=True)
        with pyro.iarange("iarange_0", 10, 5) as ind1:
            pyro.sample("x0", dist.Bernoulli(p).expand_by([len(ind1)]))
        with pyro.iarange("iarange_1", 11, 6) as ind2:
            pyro.sample("x1", dist.Bernoulli(p).expand_by([len(ind2)]))

    guide = config_enumerate(model) if Elbo is TraceEnum_ELBO else model
    assert_ok(model, guide, Elbo())
Example No. 17
def test_enum_discrete_irange_single_ok():
    def model():
        p = torch.tensor(0.5)
        for i in pyro.irange("irange", 10, 5):
            pyro.sample("x_{}".format(i), dist.Bernoulli(p))

    def guide():
        p = pyro.param("p", torch.tensor(0.5, requires_grad=True))
        for i in pyro.irange("irange", 10, 5):
            pyro.sample("x_{}".format(i), dist.Bernoulli(p))

    assert_ok(model, config_enumerate(guide), TraceEnum_ELBO())
Example No. 18
def test_iarange_enum_discrete_batch_ok():
    def model():
        p = torch.tensor(0.5)
        with pyro.iarange("iarange", 10, 5) as ind:
            pyro.sample("x", dist.Bernoulli(p).expand_by([len(ind)]))

    def guide():
        p = pyro.param("p", torch.tensor(0.5, requires_grad=True))
        with pyro.iarange("iarange", 10, 5) as ind:
            pyro.sample("x", dist.Bernoulli(p).expand_by([len(ind)]))

    assert_ok(model, config_enumerate(guide), TraceEnum_ELBO())
Example No. 19
def test_nested_iarange_iarange_dim_error_4(Elbo):

    def model():
        p = torch.tensor([0.5], requires_grad=True)
        with pyro.iarange("iarange_outer", 10, 5) as ind_outer:
            pyro.sample("x", dist.Bernoulli(p).expand_by([len(ind_outer), 1]))
            with pyro.iarange("iarange_inner", 11, 6) as ind_inner:
                pyro.sample("y", dist.Bernoulli(p).expand_by([len(ind_inner)]))
                pyro.sample("z", dist.Bernoulli(p).expand_by([len(ind_outer), len(ind_outer)]))  # error here

    guide = config_enumerate(model) if Elbo is TraceEnum_ELBO else model
    assert_error(model, guide, Elbo())
Example No. 20
def test_subsample_gradient(Elbo, reparameterized, has_rsample, subsample, local_samples, scale):
    pyro.clear_param_store()
    data = torch.tensor([-0.5, 2.0])
    subsample_size = 1 if subsample else len(data)
    precision = 0.06 * scale
    Normal = dist.Normal if reparameterized else fakes.NonreparameterizedNormal

    def model(subsample):
        with pyro.plate("data", len(data), subsample_size, subsample) as ind:
            x = data[ind]
            z = pyro.sample("z", Normal(0, 1))
            pyro.sample("x", Normal(z, 1), obs=x)

    def guide(subsample):
        scale = pyro.param("scale", lambda: torch.tensor([1.0]))
        with pyro.plate("data", len(data), subsample_size, subsample):
            loc = pyro.param("loc", lambda: torch.zeros(len(data)), event_dim=0)
            z_dist = Normal(loc, scale)
            if has_rsample is not None:
                z_dist.has_rsample_(has_rsample)
            pyro.sample("z", z_dist)

    if scale != 1.0:
        model = poutine.scale(model, scale=scale)
        guide = poutine.scale(guide, scale=scale)

    num_particles = 50000
    if local_samples:
        guide = config_enumerate(guide, num_samples=num_particles)
        num_particles = 1

    optim = Adam({"lr": 0.1})
    elbo = Elbo(max_plate_nesting=1,  # set this to ensure rng agrees across runs
                num_particles=num_particles,
                vectorize_particles=True,
                strict_enumeration_warning=False)
    inference = SVI(model, guide, optim, loss=elbo)
    with xfail_if_not_implemented():
        if subsample_size == 1:
            inference.loss_and_grads(model, guide, subsample=torch.tensor([0], dtype=torch.long))
            inference.loss_and_grads(model, guide, subsample=torch.tensor([1], dtype=torch.long))
        else:
            inference.loss_and_grads(model, guide, subsample=torch.tensor([0, 1], dtype=torch.long))
    params = dict(pyro.get_param_store().named_parameters())
    normalizer = 2 if subsample else 1
    actual_grads = {name: param.grad.detach().cpu().numpy() / normalizer for name, param in params.items()}

    expected_grads = {'loc': scale * np.array([0.5, -2.0]), 'scale': scale * np.array([2.0])}
    for name in sorted(params):
        logger.info('expected {} = {}'.format(name, expected_grads[name]))
        logger.info('actual   {} = {}'.format(name, actual_grads[name]))
    assert_equal(actual_grads, expected_grads, prec=precision)
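In the test above, the local_samples branch uses config_enumerate's num_samples argument, which replaces exhaustive enumeration with vectorized Monte Carlo samples drawn in a fresh tensor dimension. A minimal sketch of that call pattern (toy guide, names not from the test):

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import config_enumerate

def toy_guide():
    loc = pyro.param("loc", torch.zeros(1))
    pyro.sample("z", dist.Normal(loc, 1.0))

# Tag "z" to be sampled 1000 times in a new tensor dimension instead of
# enumerated; the ELBO estimate is then averaged over those samples.
sampled_guide = config_enumerate(toy_guide, num_samples=1000)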
Example No. 21
def test_iarange_enum_discrete_batch_ok():

    def model():
        p = torch.tensor(0.5)
        with pyro.iarange("iarange", 10, 5) as ind:
            pyro.sample("x", dist.Bernoulli(p).expand_by([len(ind)]))

    def guide():
        p = pyro.param("p", torch.tensor(0.5, requires_grad=True))
        with pyro.iarange("iarange", 10, 5) as ind:
            pyro.sample("x", dist.Bernoulli(p).expand_by([len(ind)]))

    assert_ok(model, config_enumerate(guide), TraceEnum_ELBO())
Example No. 22
def test_enum_discrete_irange_single_ok():

    def model():
        p = torch.tensor(0.5)
        for i in pyro.irange("irange", 10, 5):
            pyro.sample("x_{}".format(i), dist.Bernoulli(p))

    def guide():
        p = pyro.param("p", torch.tensor(0.5, requires_grad=True))
        for i in pyro.irange("irange", 10, 5):
            pyro.sample("x_{}".format(i), dist.Bernoulli(p))

    assert_ok(model, config_enumerate(guide), TraceEnum_ELBO())
Example No. 23
def test_enum_discrete_parallel_nested_ok(max_iarange_nesting):
    iarange_shape = torch.Size([1] * max_iarange_nesting)

    def model():
        p2 = torch.tensor(torch.ones(2) / 2)
        p3 = torch.tensor(torch.ones(3) / 3)
        x2 = pyro.sample("x2", dist.OneHotCategorical(p2))
        x3 = pyro.sample("x3", dist.OneHotCategorical(p3))
        assert x2.shape == torch.Size([2]) + iarange_shape + p2.shape
        assert x3.shape == torch.Size([3, 1]) + iarange_shape + p3.shape

    assert_ok(model, config_enumerate(model, "parallel"),
              TraceEnum_ELBO(max_iarange_nesting=max_iarange_nesting))
Example No. 24
def test_nested_iarange_iarange_dim_error_3(Elbo):
    def model():
        p = torch.tensor([0.5], requires_grad=True)
        with pyro.iarange("iarange_outer", 10, 5) as ind_outer:
            pyro.sample("x", dist.Bernoulli(p).expand_by([len(ind_outer), 1]))
            with pyro.iarange("iarange_inner", 11, 6) as ind_inner:
                pyro.sample("y", dist.Bernoulli(p).expand_by([len(ind_inner)]))
                pyro.sample("z",
                            dist.Bernoulli(p).expand_by([len(ind_inner),
                                                         1]))  # error here

    guide = config_enumerate(model) if Elbo is TraceEnum_ELBO else model
    assert_error(model, guide, Elbo())
Example No. 25
def test_enum_discrete_parallel_nested_ok(max_iarange_nesting):
    iarange_shape = torch.Size([1] * max_iarange_nesting)

    def model():
        p2 = torch.tensor(torch.ones(2) / 2)
        p3 = torch.tensor(torch.ones(3) / 3)
        x2 = pyro.sample("x2", dist.OneHotCategorical(p2))
        x3 = pyro.sample("x3", dist.OneHotCategorical(p3))
        assert x2.shape == torch.Size([2]) + iarange_shape + p2.shape
        assert x3.shape == torch.Size([3, 1]) + iarange_shape + p3.shape

    assert_ok(model, config_enumerate(model, "parallel"),
              TraceEnum_ELBO(max_iarange_nesting=max_iarange_nesting))
Example No. 26
def test_iarange_no_size_ok(Elbo):
    def model():
        p = torch.tensor(0.5)
        with pyro.iarange("iarange"):
            pyro.sample("x", dist.Bernoulli(p).expand_by([10]))

    def guide():
        p = pyro.param("p", torch.tensor(0.5, requires_grad=True))
        with pyro.iarange("iarange"):
            pyro.sample("x", dist.Bernoulli(p).expand_by([10]))

    if Elbo is TraceEnum_ELBO:
        guide = config_enumerate(guide)

    assert_ok(model, guide, Elbo())
Example No. 27
def test_enum_discrete_parallel_ok(max_iarange_nesting):
    iarange_shape = torch.Size([1] * max_iarange_nesting)

    def model():
        p = torch.tensor(0.5)
        x = pyro.sample("x", dist.Bernoulli(p))
        assert x.shape == torch.Size([2]) + iarange_shape

    def guide():
        p = pyro.param("p", torch.tensor(0.5, requires_grad=True))
        x = pyro.sample("x", dist.Bernoulli(p))
        assert x.shape == torch.Size([2]) + iarange_shape

    assert_ok(model, config_enumerate(guide, "parallel"),
              TraceEnum_ELBO(max_iarange_nesting=max_iarange_nesting))
Example No. 28
def test_iarange_reuse_ok(Elbo):

    def model():
        p = torch.tensor(0.5, requires_grad=True)
        iarange_outer = pyro.iarange("iarange_outer", 10, 5, dim=-1)
        iarange_inner = pyro.iarange("iarange_inner", 11, 6, dim=-2)
        with iarange_outer as ind_outer:
            pyro.sample("x", dist.Bernoulli(p).expand_by([len(ind_outer)]))
        with iarange_inner as ind_inner:
            pyro.sample("y", dist.Bernoulli(p).expand_by([len(ind_inner), 1]))
        with iarange_outer as ind_outer, iarange_inner as ind_inner:
            pyro.sample("z", dist.Bernoulli(p).expand_by([len(ind_inner), len(ind_outer)]))

    guide = config_enumerate(model) if Elbo is TraceEnum_ELBO else model
    assert_ok(model, guide, Elbo())
Example No. 29
def test_irange_in_model_not_guide_ok(subsample_size, Elbo):
    def model():
        p = torch.tensor(0.5)
        for i in pyro.irange("irange", 10, subsample_size):
            pass
        pyro.sample("x", dist.Bernoulli(p))

    def guide():
        p = pyro.param("p", torch.tensor(0.5, requires_grad=True))
        pyro.sample("x", dist.Bernoulli(p))

    if Elbo is TraceEnum_ELBO:
        guide = config_enumerate(guide)

    assert_ok(model, guide, Elbo())
Example No. 30
def test_enum_discrete_parallel_ok(max_iarange_nesting):
    iarange_shape = torch.Size([1] * max_iarange_nesting)

    def model():
        p = torch.tensor(0.5)
        x = pyro.sample("x", dist.Bernoulli(p))
        assert x.shape == torch.Size([2]) + iarange_shape

    def guide():
        p = pyro.param("p", torch.tensor(0.5, requires_grad=True))
        x = pyro.sample("x", dist.Bernoulli(p))
        assert x.shape == torch.Size([2]) + iarange_shape

    assert_ok(model, config_enumerate(guide, "parallel"),
              TraceEnum_ELBO(max_iarange_nesting=max_iarange_nesting))
Example No. 31
def test_iarange_no_size_ok(Elbo):

    def model():
        p = torch.tensor(0.5)
        with pyro.iarange("iarange"):
            pyro.sample("x", dist.Bernoulli(p).expand_by([10]))

    def guide():
        p = pyro.param("p", torch.tensor(0.5, requires_grad=True))
        with pyro.iarange("iarange"):
            pyro.sample("x", dist.Bernoulli(p).expand_by([10]))

    if Elbo is TraceEnum_ELBO:
        guide = config_enumerate(guide)

    assert_ok(model, guide, Elbo())
Example No. 32
def test_irange_in_model_not_guide_ok(subsample_size, Elbo):

    def model():
        p = torch.tensor(0.5)
        for i in pyro.irange("irange", 10, subsample_size):
            pass
        pyro.sample("x", dist.Bernoulli(p))

    def guide():
        p = pyro.param("p", torch.tensor(0.5, requires_grad=True))
        pyro.sample("x", dist.Bernoulli(p))

    if Elbo is TraceEnum_ELBO:
        guide = config_enumerate(guide)

    assert_ok(model, guide, Elbo())
Example No. 33
def test_irange_variable_clash_error(Elbo):
    def model():
        p = torch.tensor(0.5)
        for i in pyro.irange("irange", 2):
            # Each loop iteration should give the sample site a different name.
            pyro.sample("x", dist.Bernoulli(p))

    def guide():
        p = pyro.param("p", torch.tensor(0.5, requires_grad=True))
        for i in pyro.irange("irange", 2):
            # Each loop iteration should give the sample site a different name.
            pyro.sample("x", dist.Bernoulli(p))

    if Elbo is TraceEnum_ELBO:
        guide = config_enumerate(guide)

    assert_error(model, guide, Elbo())
Example No. 34
def test_iarange_reuse_ok(Elbo):
    def model():
        p = torch.tensor(0.5, requires_grad=True)
        iarange_outer = pyro.iarange("iarange_outer", 10, 5, dim=-1)
        iarange_inner = pyro.iarange("iarange_inner", 11, 6, dim=-2)
        with iarange_outer as ind_outer:
            pyro.sample("x", dist.Bernoulli(p).expand_by([len(ind_outer)]))
        with iarange_inner as ind_inner:
            pyro.sample("y", dist.Bernoulli(p).expand_by([len(ind_inner), 1]))
        with iarange_outer as ind_outer, iarange_inner as ind_inner:
            pyro.sample(
                "z",
                dist.Bernoulli(p).expand_by([len(ind_inner),
                                             len(ind_outer)]))

    guide = config_enumerate(model) if Elbo is TraceEnum_ELBO else model
    assert_ok(model, guide, Elbo())
Example No. 35
def test_iarange_irange_ok(Elbo):

    def model():
        p = torch.tensor(0.5)
        with pyro.iarange("iarange", 3, 2) as ind:
            for i in pyro.irange("irange", 3, 2):
                pyro.sample("x_{}".format(i), dist.Bernoulli(p).expand_by([len(ind)]))

    def guide():
        p = pyro.param("p", torch.tensor(0.5, requires_grad=True))
        with pyro.iarange("iarange", 3, 2) as ind:
            for i in pyro.irange("irange", 3, 2):
                pyro.sample("x_{}".format(i), dist.Bernoulli(p).expand_by([len(ind)]))

    if Elbo is TraceEnum_ELBO:
        guide = config_enumerate(guide)

    assert_ok(model, guide, Elbo())
Example No. 36
def test_irange_variable_clash_error(Elbo):

    def model():
        p = torch.tensor(0.5)
        for i in pyro.irange("irange", 2):
            # Each loop iteration should give the sample site a different name.
            pyro.sample("x", dist.Bernoulli(p))

    def guide():
        p = pyro.param("p", torch.tensor(0.5, requires_grad=True))
        for i in pyro.irange("irange", 2):
            # Each loop iteration should give the sample site a different name.
            pyro.sample("x", dist.Bernoulli(p))

    if Elbo is TraceEnum_ELBO:
        guide = config_enumerate(guide)

    assert_error(model, guide, Elbo())
Example No. 37
def test_iarange_irange_ok(Elbo):
    def model():
        p = torch.tensor(0.5)
        with pyro.iarange("iarange", 3, 2) as ind:
            for i in pyro.irange("irange", 3, 2):
                pyro.sample("x_{}".format(i),
                            dist.Bernoulli(p).expand_by([len(ind)]))

    def guide():
        p = pyro.param("p", torch.tensor(0.5, requires_grad=True))
        with pyro.iarange("iarange", 3, 2) as ind:
            for i in pyro.irange("irange", 3, 2):
                pyro.sample("x_{}".format(i),
                            dist.Bernoulli(p).expand_by([len(ind)]))

    if Elbo is TraceEnum_ELBO:
        guide = config_enumerate(guide)

    assert_ok(model, guide, Elbo())
Example No. 38
def test_non_mean_field_bern_bern_elbo_gradient(enumerate1, pi1, pi2):
    pyro.clear_param_store()
    num_particles = 1 if enumerate1 else 20000

    def model():
        with pyro.iarange("particles", num_particles):
            y = pyro.sample("y", dist.Bernoulli(0.33).expand_by([num_particles]))
            pyro.sample("z", dist.Bernoulli(0.55 * y + 0.10))

    def guide():
        q1 = pyro.param("q1", torch.tensor(pi1, requires_grad=True))
        q2 = pyro.param("q2", torch.tensor(pi2, requires_grad=True))
        with pyro.iarange("particles", num_particles):
            y = pyro.sample("y", dist.Bernoulli(q1).expand_by([num_particles]))
            pyro.sample("z", dist.Bernoulli(q2 * y + 0.10))

    logger.info("Computing gradients using surrogate loss")
    elbo = TraceEnum_ELBO(max_iarange_nesting=1,
                          strict_enumeration_warning=any([enumerate1]))
    elbo.loss_and_grads(model, config_enumerate(guide, default=enumerate1))
    actual_grad_q1 = pyro.param('q1').grad / num_particles
    actual_grad_q2 = pyro.param('q2').grad / num_particles

    logger.info("Computing analytic gradients")
    q1 = torch.tensor(pi1, requires_grad=True)
    q2 = torch.tensor(pi2, requires_grad=True)
    elbo = kl_divergence(dist.Bernoulli(q1), dist.Bernoulli(0.33))
    elbo = elbo + q1 * kl_divergence(dist.Bernoulli(q2 + 0.10), dist.Bernoulli(0.65))
    elbo = elbo + (1.0 - q1) * kl_divergence(dist.Bernoulli(0.10), dist.Bernoulli(0.10))
    expected_grad_q1, expected_grad_q2 = grad(elbo, [q1, q2])

    prec = 0.03 if enumerate1 is None else 0.001

    assert_equal(actual_grad_q1, expected_grad_q1, prec=prec, msg="".join([
        "\nq1 expected = {}".format(expected_grad_q1.data.cpu().numpy()),
        "\nq1  actual = {}".format(actual_grad_q1.data.cpu().numpy()),
    ]))
    assert_equal(actual_grad_q2, expected_grad_q2, prec=prec, msg="".join([
        "\nq2 expected = {}".format(expected_grad_q2.data.cpu().numpy()),
        "\nq2   actual = {}".format(actual_grad_q2.data.cpu().numpy()),
    ]))
Example No. 39
def main(args):
    pyro.set_rng_seed(0)
    pyro.clear_param_store()
    K = 2

    data = torch.tensor([0.0, 1.0, 2.0, 20.0, 30.0, 40.0])
    optim = pyro.optim.Adam({'lr': 0.1})
    inference = SVI(model, config_enumerate(guide), optim,
                    loss=TraceEnum_ELBO(max_plate_nesting=1))

    print('Step\tLoss')
    loss = 0.0
    for step in range(args.num_epochs):
        if step and step % 10 == 0:
            print('{}\t{:0.5g}'.format(step, loss))
            loss = 0.0
        loss += inference.step(K, data)

    print('Parameters:')
    for name, value in sorted(pyro.get_param_store().items()):
        print('{} = {}'.format(name, value.detach().cpu().numpy()))
Example No. 40
def test_irange_irange_swap_ok(subsample_size, Elbo):
    def model():
        p = torch.tensor(0.5)
        outer_irange = pyro.irange("irange_0", 3, subsample_size)
        inner_irange = pyro.irange("irange_1", 3, subsample_size)
        for i in outer_irange:
            for j in inner_irange:
                pyro.sample("x_{}_{}".format(i, j), dist.Bernoulli(p))

    def guide():
        p = pyro.param("p", torch.tensor(0.5, requires_grad=True))
        outer_irange = pyro.irange("irange_0", 3, subsample_size)
        inner_irange = pyro.irange("irange_1", 3, subsample_size)
        for j in inner_irange:
            for i in outer_irange:
                pyro.sample("x_{}_{}".format(i, j), dist.Bernoulli(p))

    if Elbo is TraceEnum_ELBO:
        guide = config_enumerate(guide, "parallel")

    assert_ok(model, guide, Elbo(max_iarange_nesting=0))
Example No. 41
    def __init__(self,
                 model: Type[nn.Module],
                 task: str = "classification",
                 optimizer: Type[optim.PyroOptim] = None,
                 seed: int = 1,
                 **kwargs: Union[str, float]) -> None:
        """
        Initializes trainer parameters
        """
        pyro.clear_param_store()
        set_deterministic_mode(seed)
        if task not in ["classification", "regression"]:
            raise ValueError(
                "Choose between 'classification' and 'regression' tasks")
        self.task = task
        self.device = kwargs.get(
            "device", 'cuda' if torch.cuda.is_available() else 'cpu')
        if optimizer is None:
            lr = kwargs.get("lr", 5e-4)
            optimizer = optim.Adam({"lr": lr})
        if self.task == "classification":
            guide = infer.config_enumerate(model.guide,
                                           "parallel",
                                           expand=True)
            loss = pyro.infer.TraceEnum_ELBO(max_plate_nesting=1,
                                             strict_enumeration_warning=False)
        else:
            guide = model.guide
            loss = pyro.infer.Trace_ELBO()

        self.loss_basic = infer.SVI(model.model, guide, optimizer, loss=loss)
        self.loss_aux = infer.SVI(model.model_aux,
                                  model.guide_aux,
                                  optimizer,
                                  loss=pyro.infer.Trace_ELBO())
        self.model = model

        self.history = {"training_loss": [], "test": []}
        self.current_epoch = 0
        self.running_weights = {}
Example No. 42
def test_non_mean_field_normal_bern_elbo_gradient(pi1, pi2, pi3):

    def model(num_particles):
        with pyro.iarange("particles", num_particles):
            q3 = pyro.param("q3", torch.tensor(pi3, requires_grad=True))
            q4 = pyro.param("q4", torch.tensor(0.5 * (pi1 + pi2), requires_grad=True))
            z = pyro.sample("z", dist.Normal(q3, 1.0).expand_by([num_particles]))
            zz = torch.exp(z) / (1.0 + torch.exp(z))
            pyro.sample("y", dist.Bernoulli(q4 * zz))

    def guide(num_particles):
        q1 = pyro.param("q1", torch.tensor(pi1, requires_grad=True))
        q2 = pyro.param("q2", torch.tensor(pi2, requires_grad=True))
        with pyro.iarange("particles", num_particles):
            z = pyro.sample("z", dist.Normal(q2, 1.0).expand_by([num_particles]))
            zz = torch.exp(z) / (1.0 + torch.exp(z))
            pyro.sample("y", dist.Bernoulli(q1 * zz))

    qs = ['q1', 'q2', 'q3', 'q4']
    results = {}

    for ed, num_particles in zip([None, 'parallel', 'sequential'], [30000, 20000, 20000]):
        pyro.clear_param_store()
        elbo = TraceEnum_ELBO(max_iarange_nesting=1,
                              strict_enumeration_warning=any([ed]))
        elbo.loss_and_grads(model, config_enumerate(guide, default=ed), num_particles)
        results[str(ed)] = {}
        for q in qs:
            results[str(ed)]['actual_grad_%s' % q] = pyro.param(q).grad.detach().cpu().numpy() / num_particles

    prec = 0.03
    for ed in ['parallel', 'sequential']:
        logger.info('\n*** {} ***'.format(ed))
        for q in qs:
            logger.info("[{}] actual: {}".format(q, results[ed]['actual_grad_%s' % q]))
            assert_equal(results[ed]['actual_grad_%s' % q], results['None']['actual_grad_%s' % q], prec=prec,
                         msg="".join([
                             "\nexpected (MC estimate) = {}".format(results['None']['actual_grad_%s' % q]),
                             "\n  actual ({} estimate) = {}".format(ed, results[ed]['actual_grad_%s' % q]),
                         ]))
Example No. 43
def test_irange_irange_swap_ok(subsample_size, Elbo):

    def model():
        p = torch.tensor(0.5)
        outer_irange = pyro.irange("irange_0", 3, subsample_size)
        inner_irange = pyro.irange("irange_1", 3, subsample_size)
        for i in outer_irange:
            for j in inner_irange:
                pyro.sample("x_{}_{}".format(i, j), dist.Bernoulli(p))

    def guide():
        p = pyro.param("p", torch.tensor(0.5, requires_grad=True))
        outer_irange = pyro.irange("irange_0", 3, subsample_size)
        inner_irange = pyro.irange("irange_1", 3, subsample_size)
        for j in inner_irange:
            for i in outer_irange:
                pyro.sample("x_{}_{}".format(i, j), dist.Bernoulli(p))

    if Elbo is TraceEnum_ELBO:
        guide = config_enumerate(guide, "parallel")

    assert_ok(model, guide, Elbo(max_iarange_nesting=0))
Example No. 44
    # Save figure
    fig.savefig(figname)


if __name__ == "__main__":
    pyro.enable_validation(True)
    pyro.set_rng_seed(42)

    # Create our model with a fixed number of components
    K = 2

    data = get_samples()

    global_guide = AutoDelta(poutine.block(model, expose=['weights', 'locs', 'scales']))
    global_guide = config_enumerate(global_guide, 'parallel')
    _, svi = initialize(data)

    true_colors = [0] * 100 + [1] * 100
    plot(data, colors=true_colors, figname='pyro_init.png')

    for i in range(151):
        svi.step(data)

        if i % 50 == 0:
            locs = pyro.param('locs')
            scales = pyro.param('scales')
            weights = pyro.param('weights')
            assignment_probs = pyro.param('assignment_probs')

            print("locs: {}".format(locs))
Example No. 45
def main(args):
    # Fix random number seed
    pyro.util.set_rng_seed(args.seed)
    # Enable optional validation warnings
    pyro.enable_validation(True)

    # Load and pre-process data
    dataloader, num_genes, l_mean, l_scale, anndata = get_data(dataset=args.dataset, batch_size=args.batch_size,
                                                               cuda=args.cuda)

    # Instantiate instance of model/guide and various neural networks
    scanvi = SCANVI(num_genes=num_genes, num_labels=4, l_loc=l_mean, l_scale=l_scale,
                    scale_factor=1.0 / (args.batch_size * num_genes))

    if args.cuda:
        scanvi.cuda()

    # Setup an optimizer (Adam) and learning rate scheduler.
    # By default we start with a moderately high learning rate (0.005)
    # and reduce by a factor of 5 after 20 epochs.
    scheduler = MultiStepLR({'optimizer': Adam,
                             'optim_args': {'lr': args.learning_rate},
                             'milestones': [20],
                             'gamma': 0.2})

    # Tell Pyro to enumerate out y when y is unobserved
    guide = config_enumerate(scanvi.guide, "parallel", expand=True)

    # Setup a variational objective for gradient-based learning.
    # Note we use TraceEnum_ELBO in order to leverage Pyro's machinery
    # for automatic enumeration of the discrete latent variable y.
    elbo = TraceEnum_ELBO(strict_enumeration_warning=False)
    svi = SVI(scanvi.model, guide, scheduler, elbo)

    # Training loop
    for epoch in range(args.num_epochs):
        losses = []

        for x, y in dataloader:
            if y is not None:
                y = y.type_as(x)
            loss = svi.step(x, y)
            losses.append(loss)

        # Tell the scheduler we've done one epoch.
        scheduler.step()

        print("[Epoch %04d]  Loss: %.5f" % (epoch, np.mean(losses)))

    # Put neural networks in eval mode (needed for batchnorm)
    scanvi.eval()

    # Now that we're done training we'll inspect the latent representations we've learned
    if args.plot and args.dataset == 'pbmc':
        import scanpy as sc
        # Compute latent representation (z2_loc) for each cell in the dataset
        latent_rep = scanvi.z2l_encoder(dataloader.data_x)[0]

        # Compute inferred cell type probabilities for each cell
        y_logits = scanvi.classifier(latent_rep)
        y_probs = softmax(y_logits, dim=-1).data.cpu().numpy()

        # Use scanpy to compute 2-dimensional UMAP coordinates using our
        # learned 10-dimensional latent representation z2
        anndata.obsm["X_scANVI"] = latent_rep.data.cpu().numpy()
        sc.pp.neighbors(anndata, use_rep="X_scANVI")
        sc.tl.umap(anndata)
        umap1, umap2 = anndata.obsm['X_umap'][:, 0], anndata.obsm['X_umap'][:, 1]

        # Construct plots; all plots are scatterplots depicting the two-dimensional UMAP embedding
        # and only differ in how points are colored

        # The topmost plot depicts the 200 hand-curated seed labels in our dataset
        fig, axes = plt.subplots(3, 2)
        seed_marker_sizes = anndata.obs['seed_marker_sizes']
        axes[0, 0].scatter(umap1, umap2, s=seed_marker_sizes, c=anndata.obs['seed_colors'], marker='.', alpha=0.7)
        axes[0, 0].set_title('Hand-Curated Seed Labels')
        patch1 = Patch(color='lightcoral', label='CD8-Naive')
        patch2 = Patch(color='limegreen', label='CD4-Naive')
        patch3 = Patch(color='deepskyblue', label='CD4-Memory')
        patch4 = Patch(color='mediumorchid', label='CD4-Regulatory')
        axes[0, 1].legend(loc='center left', handles=[patch1, patch2, patch3, patch4])
        axes[0, 1].get_xaxis().set_visible(False)
        axes[0, 1].get_yaxis().set_visible(False)
        axes[0, 1].set_frame_on(False)

        # The remaining plots depict the inferred cell type probability for each of the four cell types
        s10 = axes[1, 0].scatter(umap1, umap2, s=1, c=y_probs[:, 0], marker='.', alpha=0.7)
        axes[1, 0].set_title('Inferred CD8-Naive probability')
        fig.colorbar(s10, ax=axes[1, 0])
        s11 = axes[1, 1].scatter(umap1, umap2, s=1, c=y_probs[:, 1], marker='.', alpha=0.7)
        axes[1, 1].set_title('Inferred CD4-Naive probability')
        fig.colorbar(s11, ax=axes[1, 1])
        s20 = axes[2, 0].scatter(umap1, umap2, s=1, c=y_probs[:, 2], marker='.', alpha=0.7)
        axes[2, 0].set_title('Inferred CD4-Memory probability')
        fig.colorbar(s20, ax=axes[2, 0])
        s21 = axes[2, 1].scatter(umap1, umap2, s=1, c=y_probs[:, 3], marker='.', alpha=0.7)
        axes[2, 1].set_title('Inferred CD4-Regulatory probability')
        fig.colorbar(s21, ax=axes[2, 1])

        fig.tight_layout()
        plt.savefig('scanvi.pdf')
Example No. 46
def main(args):
    ## Do-Re-Mi
    def easyTones():
        training_seq_lengths = torch.tensor([8]*1)
        training_data_sequences = torch.zeros(1,8,88)
        for i in range(1):
            training_data_sequences[i][0][int(70-i*10)  ] = 1
            training_data_sequences[i][1][int(70-i*10)+2] = 1
            training_data_sequences[i][2][int(70-i*10)+4] = 1
            training_data_sequences[i][3][int(70-i*10)+5] = 1
            training_data_sequences[i][4][int(70-i*10)+7] = 1
            training_data_sequences[i][5][int(70-i*10)+9] = 1
            training_data_sequences[i][6][int(70-i*10)+11] = 1
            training_data_sequences[i][7][int(70-i*10)+12] = 1
        return training_seq_lengths, training_data_sequences

    def superEasyTones():
        training_seq_lengths = torch.tensor([8]*10)
        training_data_sequences = torch.zeros(10,8,88)
        for i in range(10):
            for j in range(8):
                training_data_sequences[i][j][int(30+i*5)] = 1
        return training_seq_lengths, training_data_sequences

    ## Do-Do-Do, Do-Do-Do, Do-Do-Do
    def easiestTones():
        training_seq_lengths = torch.tensor([8]*10)
        training_data_sequences = torch.zeros(10,8,88)
        for i in range(10):
            for j in range(8):
                training_data_sequences[i][j][int(70)] = 1
        return training_seq_lengths, training_data_sequences

    # setup logging
    logging.basicConfig(level=logging.DEBUG, format='%(message)s', filename=args.log, filemode='w')
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    logging.getLogger('').addHandler(console)
    logging.info(args)

    data = poly.load_data(poly.JSB_CHORALES)
    training_seq_lengths = data['train']['sequence_lengths']
    training_data_sequences = data['train']['sequences']
    training_seq_lengths, training_data_sequences = easiestTones()
    test_seq_lengths = data['test']['sequence_lengths']
    test_data_sequences = data['test']['sequences']
    test_seq_lengths, test_data_sequences = easiestTones()
    val_seq_lengths = data['valid']['sequence_lengths']
    val_data_sequences = data['valid']['sequences']
    val_seq_lengths, val_data_sequences = easiestTones()
    N_train_data = len(training_seq_lengths)
    N_train_time_slices = float(torch.sum(training_seq_lengths))
    N_mini_batches = int(N_train_data / args.mini_batch_size +
                         int(N_train_data % args.mini_batch_size > 0))

    logging.info("N_train_data: %d     avg. training seq. length: %.2f    N_mini_batches: %d" %
                 (N_train_data, training_seq_lengths.float().mean(), N_mini_batches))

    # how often we do validation/test evaluation during training
    val_test_frequency = 50
    # the number of samples we use to do the evaluation
    n_eval_samples = 1

    # package repeated copies of val/test data for faster evaluation
    # (i.e. set us up for vectorization)
    def rep(x):
        rep_shape = torch.Size([x.size(0) * n_eval_samples]) + x.size()[1:]
        repeat_dims = [1] * len(x.size())
        repeat_dims[0] = n_eval_samples
        return x.repeat(repeat_dims).reshape(n_eval_samples, -1).transpose(1, 0).reshape(rep_shape)

    # get the validation/test data ready for the dmm: pack into sequences, etc.
    val_seq_lengths = rep(val_seq_lengths)
    test_seq_lengths = rep(test_seq_lengths)
    val_batch, val_batch_reversed, val_batch_mask, val_seq_lengths = poly.get_mini_batch(
        torch.arange(n_eval_samples * val_data_sequences.shape[0]), rep(val_data_sequences),
        val_seq_lengths, cuda=args.cuda)
    test_batch, test_batch_reversed, test_batch_mask, test_seq_lengths = poly.get_mini_batch(
        torch.arange(n_eval_samples * test_data_sequences.shape[0]), rep(test_data_sequences),
        test_seq_lengths, cuda=args.cuda)

    # instantiate the dmm
    dmm = DMM(rnn_dropout_rate=args.rnn_dropout_rate, num_iafs=args.num_iafs,
              iaf_dim=args.iaf_dim, use_cuda=args.cuda)

    # setup optimizer
    adam_params = {"lr": args.learning_rate, "betas": (args.beta1, args.beta2),
                   "clip_norm": args.clip_norm, "lrd": args.lr_decay,
                   "weight_decay": args.weight_decay}
    adam = ClippedAdam(adam_params)

    # setup inference algorithm
    if args.tmc:
        if args.jit:
            raise NotImplementedError("no JIT support yet for TMC")
        tmc_loss = TraceTMC_ELBO()
        dmm_guide = config_enumerate(dmm.guide, default="parallel", num_samples=args.tmc_num_samples, expand=False)
        svi = SVI(dmm.model, dmm_guide, adam, loss=tmc_loss)
    elif args.tmcelbo:
        if args.jit:
            raise NotImplementedError("no JIT support yet for TMC ELBO")
        elbo = TraceEnum_ELBO()
        dmm_guide = config_enumerate(dmm.guide, default="parallel", num_samples=args.tmc_num_samples, expand=False)
        svi = SVI(dmm.model, dmm_guide, adam, loss=elbo)
    else:
        elbo = JitTrace_ELBO() if args.jit else Trace_ELBO()
        svi = SVI(dmm.model, dmm.guide, adam, loss=elbo)

    # now we're going to define some functions we need to form the main training loop

    # saves the model and optimizer states to disk
    def save_checkpoint():
        logging.info("saving model to %s..." % args.save_model)
        torch.save(dmm.state_dict(), args.save_model)
        # logging.info("saving optimizer states to %s..." % args.save_opt)
        # adam.save(args.save_opt)
        logging.info("done saving model and optimizer checkpoints to disk.")

    # loads the model and optimizer states from disk
    def load_checkpoint():
        assert exists(args.load_opt) and exists(args.load_model), \
            "--load-model and/or --load-opt misspecified"
        logging.info("loading model from %s..." % args.load_model)
        dmm.load_state_dict(torch.load(args.load_model))
        logging.info("loading optimizer states from %s..." % args.load_opt)
        adam.load(args.load_opt)
        logging.info("done loading model and optimizer states.")

    # prepare a mini-batch and take a gradient step to minimize -elbo
    def process_minibatch(epoch, which_mini_batch, shuffled_indices):
        if args.annealing_epochs > 0 and epoch < args.annealing_epochs:
            # compute the KL annealing factor appropriate for the current mini-batch in the current epoch
            min_af = args.minimum_annealing_factor
            annealing_factor = min_af + (1.0 - min_af) * \
                (float(which_mini_batch + epoch * N_mini_batches + 1) /
                 float(args.annealing_epochs * N_mini_batches))
        else:
            # by default the KL annealing factor is unity
            annealing_factor = 1.0

        # compute which sequences in the training set we should grab
        mini_batch_start = (which_mini_batch * args.mini_batch_size)
        mini_batch_end = np.min([(which_mini_batch + 1) * args.mini_batch_size, N_train_data])
        mini_batch_indices = shuffled_indices[mini_batch_start:mini_batch_end]
        # grab a fully prepped mini-batch using the helper function in the data loader
        mini_batch, mini_batch_reversed, mini_batch_mask, mini_batch_seq_lengths \
            = poly.get_mini_batch(mini_batch_indices, training_data_sequences,
                                  training_seq_lengths, cuda=args.cuda)
        # do an actual gradient step
        loss = svi.step(mini_batch, mini_batch_reversed, mini_batch_mask,
                        mini_batch_seq_lengths, annealing_factor)
        # keep track of the training loss
        return loss

    # helper function for doing evaluation
    def do_evaluation():
        # put the RNN into evaluation mode (i.e. turn off drop-out if applicable)
        dmm.rnn.eval()

        # compute the validation and test loss n_samples many times
        val_nll = svi.evaluate_loss(val_batch, val_batch_reversed, val_batch_mask,
                                    val_seq_lengths) / float(torch.sum(val_seq_lengths))
        test_nll = svi.evaluate_loss(test_batch, test_batch_reversed, test_batch_mask,
                                     test_seq_lengths) / float(torch.sum(test_seq_lengths))

        # put the RNN back into training mode (i.e. turn on drop-out if applicable)
        dmm.rnn.train()
        return val_nll, test_nll

    # if checkpoint files provided, load model and optimizer states from disk before we start training
    if args.load_opt != '' and args.load_model != '':
        load_checkpoint()

    #################
    # TRAINING LOOP #
    #################
    times = [time.time()]
    for epoch in range(args.num_epochs):
        # if specified, save model and optimizer states to disk every checkpoint_freq epochs
        if args.checkpoint_freq > 0 and epoch > 0 and epoch % args.checkpoint_freq == 0:
            save_checkpoint()

        # accumulator for our estimate of the negative log likelihood (or rather -elbo) for this epoch
        epoch_nll = 0.0
        # prepare mini-batch subsampling indices for this epoch
        shuffled_indices = torch.randperm(N_train_data)

        # process each mini-batch; this is where we take gradient steps
        for which_mini_batch in range(N_mini_batches):
            epoch_nll += process_minibatch(epoch, which_mini_batch, shuffled_indices)

        # report training diagnostics
        times.append(time.time())
        epoch_time = times[-1] - times[-2]
        logging.info("[training epoch %04d]  %.4f \t\t\t\t(dt = %.3f sec)" %
                     (epoch, epoch_nll / N_train_time_slices, epoch_time))

        # do evaluation on test and validation data and report results
        if val_test_frequency > 0 and epoch > 0 and epoch % val_test_frequency == 0:
            val_nll, test_nll = do_evaluation()
            logging.info("[val/test epoch %04d]  %.4f  %.4f" % (epoch, val_nll, test_nll))