Beispiel #1
0
    def start_path(self, ui, region):
        """Start a new trajectory path at *ui*.

        A fresh direction is drawn with ``generate_direction`` (scaled by
        ``self.scale``) and wrapped in a new ``ContourSamplingPath``. If
        ``self.grad_function`` is set, it is installed as the path gradient.
        The per-path bookkeeping (walking direction, current index, cache,
        dead ends) is then reset.

        Parameters
        ----------
        ui: array
            starting point; must lie strictly inside the unit cube and
            inside *region*.
        region:
            region constraining the path.
        """
        direction = self.generate_direction(ui, region, scale=self.scale)
        assert (direction**2).sum() > 0, (direction, self.scale)
        ui_row = ui.reshape((1, -1))
        assert region.inside(ui_row).all(), ui

        self.path = ContourSamplingPath(SamplingPath(ui, direction, 0.0), region)
        if self.grad_function is not None:
            self.path.gradient = self.grad_function

        inside_cube = (ui > 0).all() and (ui < 1).all()
        if not (inside_cube and region.inside(ui_row)):
            assert False, ui

        # walk forward first, starting from path index 0
        self.direction = +1
        self.lasti = 0
        # index 0 is the starting point, stored as accepted with the last known likelihood
        self.cache = {0: (True, ui, self.last[1])}
        self.deadends = set()
        if self.log:
            print()
            print("starting new direction", direction, 'from', ui)
Beispiel #2
0
    def start_direction(self, region):
        """Choose a new random direction and build a fresh sampler on it.

        A random direction is drawn at the last point ``self.last``, a new
        ``ContourSamplingPath`` is created along it, and ``self.sampler`` /
        ``self.stepper`` are replaced according to ``self.samplername``.
        A new sampler is built on every call, even if one already exists.

        Parameters
        ----------
        region:
            region constraining the new sampling path.

        Raises
        ------
        ValueError
            if ``self.samplername`` is not 'steps', 'bisect' or 'nuts'.
        """
        if self.log:
            print("choosing random direction")
        ui, Li = self.last
        v = generate_random_direction(ui, region, scale=self.scale)

        self.nrestarts += 1

        samplingpath = SamplingPath(ui, v, Li)
        contourpath = ContourSamplingPath(samplingpath, region)
        if self.samplername == 'steps':
            self.sampler = ClockedStepSampler(contourpath, log=self.log)
            self.stepper = DirectJumper(self.sampler, self.nsteps, log=self.log)
        elif self.samplername == 'bisect':
            self.sampler = ClockedBisectSampler(contourpath, log=self.log)
            self.stepper = DirectJumper(self.sampler, self.nsteps, log=self.log)
        elif self.samplername == 'nuts':
            self.sampler = ClockedNUTSSampler(contourpath, log=self.log)
            self.stepper = IntervalJumper(self.sampler, self.nsteps, log=self.log)
        else:
            # an explicit error, unlike `assert False`, is not stripped by -O
            # (which would silently keep a stale sampler)
            raise ValueError(
                "unknown samplername %r; expected 'steps', 'bisect' or 'nuts'"
                % (self.samplername,))
Beispiel #3
0
def test_reversible_gradient(plot=False):
    """Check that reflections at the region boundary are reversible.

    For each seed, up to 100 uniformly drawn points with ``loglike > Lmin``
    define an MLFriends region. At a reflection point ``reflpoint`` with
    direction ``v`` the test checks that:

    * the normal chosen for the forward reflection is also flagged as
      forward for the reflected direction,
    * reflecting twice at the same normal restores ``v``, and
    * ``ContourSamplingPath.gradient`` returns the same normal when queried
      twice at the same point, and reflecting twice with it restores ``v``.

    Parameters
    ----------
    plot: bool
        if True, write a diagnostic PNG for each seed that gets that far.
    """
    def loglike(x):
        x, y = x.transpose()
        return -0.5 * (x**2 + ((y - 0.5) / 0.2)**2)

    Lmin = -0.5

    # seed 84 is checked first; seeds 84 and 4 use hand-picked configurations
    for i in [84] + list(range(1, 100)):
        print("setting seed = %d" % i)
        np.random.seed(i)
        points = np.random.uniform(size=(10000, 2))
        L = loglike(points)
        mask = L > Lmin
        points = points[mask, :][:100, :]
        active_u = points
        active_values = L[mask][:100]

        transformLayer = AffineLayer(wrapped_dims=[])
        transformLayer.optimize(points, points)
        region = MLFriends(points, transformLayer)
        region.maxradiussq, region.enlarge = region.compute_enlargement(
            nbootstraps=30)
        region.create_ellipsoid()
        nclusters = transformLayer.nclusters
        assert nclusters == 1
        assert np.allclose(region.unormed,
                           region.transformLayer.transform(
                               points)), "transform should be reproducible"
        assert region.inside(
            points).all(), "live points should lie near live points"

        if i == 84:
            v = np.array([0.03477044, -0.01977415])
            reflpoint = np.array([0.09304075, 0.29114574])
        elif i == 4:
            v = np.array([0.03949306, -0.00634806])
            reflpoint = np.array([0.9934771, 0.55358031])
        else:
            # random direction of length 0.04, from a point near a random live point
            v = np.random.normal(size=2)
            v /= (v**2).sum()**0.5
            v *= 0.04
            j = np.random.randint(len(active_u))
            reflpoint = np.random.normal(active_u[j, :], 0.04)
            # skip reflection points that fall outside the unit cube
            if not (reflpoint < 1).all() or not (reflpoint > 0).all():
                continue

        # one normal per live point, from get_sphere_tangents in transformed
        # space, mapped back to unit-cube space and normalised
        bpts = region.transformLayer.transform(reflpoint).reshape((1, -1))
        tt = get_sphere_tangents(region.unormed, bpts)
        t = region.transformLayer.untransform(tt * 1e-3 +
                                              region.unormed) - region.u
        normal = t / norm(t, axis=1).reshape((-1, 1))
        print("reflecting at  ", reflpoint, "with direction", v)
        mask_forward1, angles, anglesnew = get_reflection_angles(normal, v)
        if mask_forward1.any():
            # among forward-facing normals, use the one whose live point is
            # closest to the reflection point (in transformed space)
            j = np.argmin(
                ((region.unormed[mask_forward1, :] - bpts)**2).sum(axis=1))
            k = np.arange(len(normal))[mask_forward1][j]
            angles_used = angles[k]
            normal_used = normal[k, :]
            print("chose normal", normal_used, k)
            vnew = -(v - 2 * angles_used * normal_used)
            assert vnew.shape == v.shape

            # the same normal must also be forward-facing for the new direction
            mask_forward2, _, _ = get_reflection_angles(normal, vnew)
            assert mask_forward2[k]

            if plot:
                plt.figure(figsize=(5, 5))
                plt.title('%d' % mask_forward1.sum())
                plt.plot((reflpoint + v)[0], (reflpoint + v)[1],
                         '^',
                         color='orange')
                # vnew is 1-d like v (asserted above), so index it like v
                plt.plot((reflpoint + vnew)[0], (reflpoint + vnew)[1],
                         '^ ',
                         color='lime')
                plt.plot(reflpoint[0], reflpoint[1], '^ ', color='r')
                plt.plot(region.u[:, 0], region.u[:, 1], 'x ', ms=2, color='k')
                plt.plot(region.u[mask_forward1, 0],
                         region.u[mask_forward1, 1],
                         'o ',
                         ms=6,
                         mfc='None',
                         mec='b')
                plt.plot(region.u[mask_forward2, 0],
                         region.u[mask_forward2, 1],
                         's ',
                         ms=8,
                         mfc='None',
                         mec='g')
                plt.xlim(0, 1)
                plt.ylim(0, 1)
                plt.savefig('test_flatnuts_reversible_gradient_%d.png' % i,
                            bbox_inches='tight')
                plt.close()
            assert mask_forward1[k] == mask_forward2[k], (mask_forward1[k],
                                                          mask_forward2[k])

            print("reflecting at  ", reflpoint, "with direction", v)
            # make that step, then try to go back
            j = np.arange(len(normal))[mask_forward1][0]
            normal_j = normal[j, :]
            v2 = v - 2 * angle(normal_j, v) * normal_j

            print("reflecting with", normal_j, "new direction", v2)

            v3 = v2 - 2 * angle(normal_j, v2) * normal_j

            print("re-reflecting gives direction", v3)
            assert_allclose(v3, v)

            print()
            print("FORWARD:", v, reflpoint)
            samplingpath = SamplingPath(reflpoint - v, v, active_values[0])
            contourpath = ContourSamplingPath(samplingpath, region)
            grad_normal = contourpath.gradient(reflpoint)
            if grad_normal is not None:
                assert grad_normal.shape == v.shape, (grad_normal.shape, v.shape)

                print("BACKWARD:", v, reflpoint)
                v2 = -(v - 2 * angle(grad_normal, v) * grad_normal)
                # the gradient must be reproducible at the same point
                normal2 = contourpath.gradient(reflpoint)
                assert_allclose(grad_normal, normal2)
                v3 = -(v2 - 2 * angle(grad_normal, v2) * grad_normal)
                assert_allclose(v3, v)
Beispiel #4
0
def test_detailed_balance():
    """Check that paths of the clocked samplers can be retraced backwards.

    For each seed in 1..99, up to 400 uniformly drawn points with
    ``loglike > Lmin`` are used as live points to build an MLFriends
    region. A random direction ``v`` of length 0.04 is drawn, and a path is
    started from the first live point:

    * ClockedStepSampler: the path is expanded in both directions. New paths
      are then started from its furthest forward point (reversed direction)
      and its furthest backward point (same direction) and expanded the
      same number of steps. The test asserts that they reach the original
      start point and (reversed) direction.
    * ClockedBisectSampler: as above, but the start point/direction
      comparison is only asserted when ``gap_free_path`` holds for both the
      forward and the backward sampler.
    * ClockedNUTSSampler: only ``get_independent_sample`` is run, with no
      assertions of its own.
    """
    # unnormalised log-density: Gaussian centred at (0, 0.5), widths 1 and 0.2
    def loglike(x):
        x, y = x.transpose()
        return -0.5 * (x**2 + ((y - 0.5) / 0.2)**2)

    def transform(u):
        return u

    Lmin = -0.5
    for i in range(1, 100):
        print()
        print("---- seed=%d ----" % i)
        print()
        np.random.seed(i)
        points = np.random.uniform(size=(10000, 2))
        L = loglike(points)
        mask = L > Lmin
        points = points[mask, :][:400, :]
        active_u = points
        active_values = L[mask][:400]

        transformLayer = AffineLayer(wrapped_dims=[])
        transformLayer.optimize(points, points)
        region = MLFriends(points, transformLayer)
        region.maxradiussq, region.enlarge = region.compute_enlargement(
            nbootstraps=30)
        region.create_ellipsoid()
        nclusters = transformLayer.nclusters
        assert nclusters == 1
        assert np.allclose(region.unormed,
                           region.transformLayer.transform(
                               points)), "transform should be reproducible"
        assert region.inside(
            points).all(), "live points should lie near live points"

        # random direction with length 0.04
        v = np.random.normal(size=2)
        v /= (v**2).sum()**0.5
        v *= 0.04

        print("StepSampler ----")
        print("FORWARD SAMPLING FROM", 0, active_u[0], v, active_values[0])
        samplingpath = SamplingPath(active_u[0], v, active_values[0])
        # keyword arguments shared by all sampler calls below
        problem = dict(loglike=loglike, transform=transform, Lmin=Lmin)
        sampler = ClockedStepSampler(ContourSamplingPath(samplingpath, region))
        # NOTE(review): check_starting_point presumably asserts that the
        # sampler still holds the given start point/likelihood — confirm
        check_starting_point(sampler, active_u[0], active_values[0], **problem)
        sampler.expand_onestep(fwd=True, **problem)
        sampler.expand_onestep(fwd=True, **problem)
        sampler.expand_onestep(fwd=True, **problem)
        sampler.expand_onestep(fwd=True, **problem)
        sampler.expand_onestep(fwd=False, **problem)
        sampler.expand_to_step(4, **problem)
        sampler.expand_to_step(-4, **problem)
        check_starting_point(sampler, active_u[0], active_values[0], **problem)

        # sampler.points entries unpack as (index, x, v, L); max() orders by
        # index first, so this is the furthest forward point
        starti, startx, startv, startL = max(sampler.points)

        print()
        print("BACKWARD SAMPLING FROM", starti, startx, startv, startL)
        # walk back from the furthest forward point with reversed direction
        samplingpath2 = SamplingPath(startx, -startv, startL)
        sampler2 = ClockedStepSampler(
            ContourSamplingPath(samplingpath2, region))
        check_starting_point(sampler2, startx, startL, **problem)
        sampler2.expand_to_step(starti, **problem)
        check_starting_point(sampler2, startx, startL, **problem)

        # after starti steps back we must be at the original start, moving
        # in the reversed original direction
        starti2, startx2, startv2, startL2 = max(sampler2.points)
        assert_allclose(active_u[0], startx2)
        assert_allclose(v, -startv2)

        # furthest backward point (negative index); walk forward from it
        # with the same direction for -starti steps
        starti, startx, startv, startL = min(sampler.points)
        print()
        print("BACKWARD SAMPLING FROM", starti, startx, startv, startL)
        samplingpath3 = SamplingPath(startx, startv, startL)
        sampler3 = ClockedStepSampler(
            ContourSamplingPath(samplingpath3, region))
        check_starting_point(sampler3, startx, startL, **problem)
        sampler3.expand_to_step(-starti, **problem)
        check_starting_point(sampler3, startx, startL, **problem)

        starti3, startx3, startv3, startL3 = max(sampler3.points)
        assert_allclose(active_u[0], startx3)
        assert_allclose(v, startv3)
        print()

        print("BisectSampler ----")
        log = dict(log=True)
        print("FORWARD SAMPLING FROM", 0, active_u[0], v, active_values[0])
        samplingpath = SamplingPath(active_u[0], v, active_values[0])
        sampler = ClockedBisectSampler(
            ContourSamplingPath(samplingpath, region), **log)
        check_starting_point(sampler, active_u[0], active_values[0], **problem)
        sampler.expand_to_step(10, **problem)
        check_starting_point(sampler, active_u[0], active_values[0], **problem)

        starti, startx, startv, startL = max(sampler.points)
        print()
        print("BACKWARD SAMPLING FROM", starti, startx, startv, startL)
        samplingpath2 = SamplingPath(startx, -startv, startL)
        sampler2 = ClockedBisectSampler(
            ContourSamplingPath(samplingpath2, region), **log)
        check_starting_point(sampler2, startx, startL, **problem)
        sampler2.expand_to_step(starti, **problem)
        check_starting_point(sampler2, startx, startL, **problem)

        starti2, startx2, startv2, startL2 = max(sampler2.points)
        # only compare when neither path has gaps (presumably bisection can
        # leave unevaluated steps, making the paths incomparable otherwise)
        if gap_free_path(sampler, 0, starti, **problem) and gap_free_path(
                sampler2, 0, starti2, **problem):
            assert_allclose(active_u[0], startx2)
            assert_allclose(v, -startv2)

        # from the furthest backward point, reverse the direction and expand
        # to the (negative) index starti, then compare at the minimum index
        starti, startx, startv, startL = min(sampler.points)
        print()
        print("BACKWARD SAMPLING FROM", starti, startx, startv, startL)
        samplingpath3 = SamplingPath(startx, -startv, startL)
        sampler3 = ClockedBisectSampler(
            ContourSamplingPath(samplingpath3, region), **log)
        check_starting_point(sampler3, startx, startL, **problem)
        sampler3.expand_to_step(starti, **problem)
        check_starting_point(sampler3, startx, startL, **problem)

        starti3, startx3, startv3, startL3 = min(sampler3.points)
        if gap_free_path(sampler, 0, starti, **problem) and gap_free_path(
                sampler3, 0, starti3, **problem):
            assert_allclose(active_u[0], startx3)
            assert_allclose(v, -startv3)
        print()

        print("NUTSSampler ----")
        print("FORWARD SAMPLING FROM", 0, active_u[0], v, active_values[0])
        samplingpath = SamplingPath(active_u[0], v, active_values[0])
        # reseed so the NUTS draw is reproducible for this seed
        np.random.seed(i)
        sampler = ClockedNUTSSampler(ContourSamplingPath(samplingpath, region))
        sampler.get_independent_sample(**problem)
Beispiel #5
0
def test_directjumper():
    """Exercise DirectJumper (4 steps) on a ClockedBisectSampler in three cases.

    All cases start at (0.5, 0.5) with direction (0.01, 0.01):

    1. flat likelihood: ends at (0.54, 0.54) after 4 accepts. The gradient
       must never be queried (``nocall`` fails if it is).
    2. likelihood wall at ``x[0] >= 0.505`` with an axis-aligned gradient:
       ends at (0.47, 0.55) after 4 accepts. That endpoint implies a
       reflection off the wall.
    3. likelihood below ``Lmin`` everywhere: stays at (0.5, 0.5) after
       4 rejects.
    """
    Lmin = -1.0
    us = np.zeros((100, 2)) + 0.5
    region = make_region(2)

    def transform(x):
        return x

    def gradient(x, plot=False):
        # unit vector along the axis of largest |x - 0.5|:
        # -1 if that coordinate is above 0.5, else +1
        axis = np.argmax(np.abs(x - 0.5))
        grad = np.zeros(len(x))
        grad[axis] = -1 if x[axis] > 0.5 else 1
        return grad

    def nocall(x):
        assert False

    def flat_loglike(x):
        return 0.0

    def walled_loglike(x):
        return 0.0 if x[0] < 0.505 else -100

    def stuck_loglike(x):
        return -100

    def check_jump(loglike_fn, grad_fn, expected_x, expected_counts):
        # build a fresh path/sampler/stepper and make one 4-step jump
        path = ContourSamplingPath(SamplingPath(ui, v, 0.0), region)
        path.gradient = grad_fn
        sampler = ClockedBisectSampler(path)
        stepper = DirectJumper(sampler, 4)
        counts = (stepper.naccepts, stepper.nrejects)
        assert counts == (0, 0), counts
        assert stepper.isteps == 0, stepper.isteps
        x, _ = makejump(stepper, sampler, transform, loglike_fn, Lmin)
        assert_allclose(x, expected_x)
        counts = (stepper.naccepts, stepper.nrejects)
        assert counts == expected_counts, counts
        assert stepper.isteps == 4, stepper.isteps

    ui = us[np.random.randint(len(us)), :]
    v = np.array([0.01, 0.01])

    check_jump(flat_loglike, nocall, [0.54, 0.54], (4, 0))

    print()
    print("make reflect")
    print()
    check_jump(walled_loglike, gradient, [0.47, 0.55], (4, 0))

    print()
    print("make stuck")
    print()
    check_jump(stuck_loglike, gradient, [0.50, 0.50], (0, 4))
Beispiel #6
0
    def move(self, ui, region, ndraw=1, plot=False):
        """Propose a point by slice sampling along a straight path through *ui*.

        If ``self.interval`` is None, a new direction is drawn and the
        integer bracket ``[left, right]`` of path steps is grown. Each side
        is doubled until its extrapolated point leaves the open unit cube or
        the region, or until ``|step * self.scale|`` reaches the
        unit-hypercube diagonal. A short bracket (< 5 steps on both sides)
        shrinks ``self.scale`` by a factor 1.1.

        Otherwise, the previously proposed step ``mid`` was rejected, and
        the bracket side containing it is pulled in to ``mid``.

        A step is then drawn uniformly from the bracket (inclusive). Steps
        whose point is outside the region shrink the bracket and are redrawn.

        Returns
        -------
        array of shape (1, ndim): the proposed point.
        """
        def usable(x):
            # strictly inside the unit cube, and inside the region
            return (x > 0).all() and (x < 1).all() and region.inside(x.reshape((1, -1)))

        if self.interval is not None:
            left, right, mid = self.interval
            # the proposal at mid was rejected: pull in that side
            if mid < 0:
                left = mid
            elif mid > 0:
                right = mid
        else:
            direction = self.generate_direction(ui, region, scale=self.scale)
            self.path = ContourSamplingPath(
                SamplingPath(ui, direction, 0.0), region)

            if not usable(ui):
                assert False, ui

            # the unit hypercube diagonal bounds a sensible path length
            maxlength = len(ui)**0.5

            def grow(step):
                # double *step* until its point is unusable or the length limit is hit
                while abs(step * self.scale) < maxlength:
                    xstep, _ = self.path.extrapolate(step)
                    if not usable(xstep):
                        break
                    step *= 2
                return step

            left = grow(-1)
            right = grow(+1)

            if max(-left, right) < 5:
                self.scale /= 1.1

            assert self.scale > 1e-10, self.scale
            self.interval = (left, right, None)

        # draw steps from the bracket, shrinking it on every miss
        while True:
            mid = np.random.randint(left, right + 1)
            if mid != 0:
                xj, _ = self.path.extrapolate(mid)
            else:
                _, xj, _, _ = self.path.points[0]

            candidate = xj.reshape((1, -1))
            if region.inside(candidate):
                self.interval = (left, right, mid)
                return candidate

            if mid < 0:
                left = mid
            else:
                right = mid
            self.interval = (left, right, mid)
Beispiel #7
0
class SamplingPathSliceSampler(StepSampler):
    """Slice sampler along a straight sampling path, restricted to the region.

    Whenever a new interval is needed, a direction oriented along the
    region's transformLayer axes is drawn through the current point. The
    integer step bracket ``[left, right]`` is grown by doubling on both
    sides. Growth stops when the extrapolated point leaves the unit cube or
    the region, or when the bracket reaches the unit-hypercube diagonal.
    Proposals are drawn uniformly from the bracket, which is shrunk after
    every proposal outside the region or rejected by the caller. Path
    extrapolation itself is delegated to ``ContourSamplingPath``.
    """

    def __init__(self, nsteps):
        """Initialise sampler.

        Parameters
        -----------
        nsteps: int
            number of accepted steps until the sample is considered independent.

        """
        StepSampler.__init__(self, nsteps=nsteps)
        # no path/interval yet: the first move() starts a new one
        self.path = None
        self.interval = None

    def generate_direction(self, ui, region, scale=1):
        """Draw an initial direction oriented along region.transformLayer axes."""
        direction = generate_region_oriented_direction(ui, region, tscale=1, scale=scale)
        return direction

    def adjust_accept(self, accepted, unew, pnew, Lnew, nc):
        """Update state after the proposal *unew* was evaluated (after *nc* calls).

        On acceptance the walker moves to *unew* and the next move() starts
        a fresh interval. On rejection the interval is kept so that move()
        can shrink it.
        """
        if not accepted:
            self.nrejects += 1
        else:
            self.interval = None
            self.last = unew, Lnew
            self.history.append((unew, Lnew))
        self.logstat.append([accepted, self.scale])

    def adjust_outside_region(self):
        """Record a proposal that fell outside the region (logged as not accepted)."""
        entry = [False, self.scale]
        self.logstat.append(entry)

    def move(self, ui, region, ndraw=1, plot=False):
        """Propose a point by slice sampling along a straight path through *ui*.

        If ``self.interval`` is None, a new direction is drawn and the
        integer bracket ``[left, right]`` of path steps is grown. Each side
        is doubled until its extrapolated point leaves the open unit cube or
        the region, or until ``|step * self.scale|`` reaches the
        unit-hypercube diagonal. A short bracket (< 5 steps on both sides)
        shrinks ``self.scale`` by a factor 1.1.

        Otherwise, the previously proposed step ``mid`` was rejected, and
        the bracket side containing it is pulled in to ``mid``.

        A step is then drawn uniformly from the bracket (inclusive). Steps
        whose point is outside the region shrink the bracket and are redrawn.

        Returns
        -------
        array of shape (1, ndim): the proposed point.
        """
        def usable(x):
            # strictly inside the unit cube, and inside the region
            return (x > 0).all() and (x < 1).all() and region.inside(x.reshape((1, -1)))

        if self.interval is not None:
            left, right, mid = self.interval
            # the proposal at mid was rejected: pull in that side
            if mid < 0:
                left = mid
            elif mid > 0:
                right = mid
        else:
            direction = self.generate_direction(ui, region, scale=self.scale)
            self.path = ContourSamplingPath(
                SamplingPath(ui, direction, 0.0), region)

            if not usable(ui):
                assert False, ui

            # the unit hypercube diagonal bounds a sensible path length
            maxlength = len(ui)**0.5

            def grow(step):
                # double *step* until its point is unusable or the length limit is hit
                while abs(step * self.scale) < maxlength:
                    xstep, _ = self.path.extrapolate(step)
                    if not usable(xstep):
                        break
                    step *= 2
                return step

            left = grow(-1)
            right = grow(+1)

            if max(-left, right) < 5:
                self.scale /= 1.1

            assert self.scale > 1e-10, self.scale
            self.interval = (left, right, None)

        # draw steps from the bracket, shrinking it on every miss
        while True:
            mid = np.random.randint(left, right + 1)
            if mid != 0:
                xj, _ = self.path.extrapolate(mid)
            else:
                _, xj, _, _ = self.path.points[0]

            candidate = xj.reshape((1, -1))
            if region.inside(candidate):
                self.interval = (left, right, mid)
                return candidate

            if mid < 0:
                left = mid
            else:
                right = mid
            self.interval = (left, right, mid)
Beispiel #8
0
class SamplingPathStepSampler(StepSampler):
    """Step sampler on a sampling path."""

    def __init__(self, nresets, nsteps, scale=1.0, balance=0.01, nudge=1.1, log=False):
        """Initialise sampler.

        Parameters
        ------------
        nresets: int
            after this many iterations, select a new direction
        nsteps: int
            how many steps to make in total
        scale: float
            initial step size
        balance: float
            minimum acceptance rate to target.
            If the acceptance rate naccepts / (naccepts + nrejects) is
            below *balance*, the scale is decreased (divided by *nudge*);
            otherwise it is increased (multiplied by *nudge*).
        nudge: float
            factor for adjusting the scale (must be >=1)
            nudge=1 implies no step size adaptation.
        log: bool
            whether to print diagnostic messages.

        """
        StepSampler.__init__(self, nsteps=nsteps)
        # self.lasti = None
        self.path = None
        self.nresets = nresets
        # initial step scale in transformed space
        self.scale = scale
        # acceptance-rate threshold used by adjust_scale
        self.balance = balance
        # relative change in step scale per adjustment
        self.nudge = nudge
        assert nudge >= 1
        self.log = log
        self.grad_function = None
        self.istep = 0
        self.iresets = 0
        # reset all counters, then close the (empty) initial path
        self.start()
        self.terminate_path()
        self.logstat_labels = ['acceptance rate', 'reflection rate', 'scale', 'nstuck']

    def __str__(self):
        """Get string representation: class name, nsteps, nresets and target acceptance rate.

        The acceptance rate is shown as ``(1 - balance)`` in percent.
        """
        # the class name needs its own %s placeholder; previously it was fed
        # into %d, which raised TypeError on every call
        return '%s(nsteps=%d, nresets=%d, AR=%d%%)' % (
            type(self).__name__, self.nsteps, self.nresets, (1 - self.balance) * 100)

    def start(self):
        """Start sampler, reset all counters.

        If a previous run recorded any accept/reject decisions, its summary
        (acceptance rate, reflection rate, scale, number of stuck paths) is
        first appended to ``self.logstat``.
        """
        if hasattr(self, 'naccepts') and self.nrejects + self.naccepts > 0:
            ntotal = self.nrejects + self.naccepts
            self.logstat.append([
                self.naccepts / ntotal,
                self.nreflects / (self.nreflects + ntotal),
                self.scale, self.nstuck])
        self.nrejects = 0
        self.naccepts = 0
        self.nreflects = 0
        self.nstuck = 0
        self.istep = 0
        self.iresets = 0
        self.noutside_regions = 0
        self.last = None, None
        self.history = []

        self.direction = +1
        self.deadends = set()
        self.path = None

    def start_path(self, ui, region):
        """Start a new trajectory path at *ui*.

        A fresh direction is drawn with ``generate_direction`` (scaled by
        ``self.scale``) and wrapped in a new ``ContourSamplingPath``. If
        ``self.grad_function`` is set, it is installed as the path gradient.
        The per-path bookkeeping (walking direction, current index, cache,
        dead ends) is then reset.

        Parameters
        ----------
        ui: array
            starting point; must lie strictly inside the unit cube and
            inside *region*.
        region: MLFriends
            region constraining the path.
        """
        direction = self.generate_direction(ui, region, scale=self.scale)
        assert (direction**2).sum() > 0, (direction, self.scale)
        ui_row = ui.reshape((1, -1))
        assert region.inside(ui_row).all(), ui

        self.path = ContourSamplingPath(SamplingPath(ui, direction, 0.0), region)
        if self.grad_function is not None:
            self.path.gradient = self.grad_function

        inside_cube = (ui > 0).all() and (ui < 1).all()
        if not (inside_cube and region.inside(ui_row)):
            assert False, ui

        # walk forward first, starting from path index 0
        self.direction = +1
        self.lasti = 0
        # index 0 is the starting point, stored as accepted with the last known likelihood
        self.cache = {0: (True, ui, self.last[1])}
        self.deadends = set()
        if self.log:
            print()
            print("starting new direction", direction, 'from', ui)

    def terminate_path(self):
        """Drop the current path and count the reset.

        If both neighbours of the start index (-1 and +1) were recorded as
        dead ends, the path went nowhere and ``self.nstuck`` is incremented.
        """
        went_nowhere = -1 in self.deadends and +1 in self.deadends
        if went_nowhere:
            self.nstuck += 1

        self.path = None
        self.deadends = set()
        self.direction = +1
        self.iresets += 1
        if self.log:
            print("reset %d" % self.iresets)

    def set_gradient(self, grad_function):
        """Use *grad_function* as the gradient of paths started from now on.

        The function is wrapped so that, when called with ``plot=True``, it
        also draws the evaluation point and a short line in the gradient
        direction on the current matplotlib axes.
        """
        print("set gradient function to %s" % grad_function.__name__)

        def plot_gradient_wrapper(x, plot=False):
            """Evaluate grad_function at x, optionally drawing it."""
            grad = grad_function(x)
            if not plot:
                return grad
            plt.plot(x[0], x[1], '+ ', color='k', ms=10)
            xs = [x[0], grad[0] * 1e-2 + x[0]]
            ys = [x[1], grad[1] * 1e-2 + x[1]]
            plt.plot(xs, ys, color='gray')
            return grad

        self.grad_function = plot_gradient_wrapper

    def generate_direction(self, ui, region, scale):
        """Draw a new direction via generate_region_random_direction.

        Per that helper, this picks a random axis of region.transformLayer,
        scaled by *scale*.
        """
        direction = generate_region_random_direction(ui, region, scale=scale)
        return direction

    def adjust_accept(self, accepted, unew, pnew, Lnew, nc):
        """Record the outcome of the proposal at ``self.nexti`` (made after *nc* calls).

        The result is cached under ``self.nexti`` and appended to the history
        either way. Only an accepted proposal moves the walker (``lasti`` and
        ``last``).
        """
        self.cache[self.nexti] = (accepted, unew, Lnew)
        if not accepted:
            # stay on the current point; self.last is left unchanged
            self.nrejects += 1
            self.history.append((unew, Lnew))
            assert self.scale > 1e-10, (self.scale, self.istep, self.nrejects)
            return

        # next move starts from the accepted point
        self.lasti = self.nexti
        self.last = unew, Lnew
        self.history.append((unew, Lnew))
        self.naccepts += 1

    def adjust_outside_region(self):
        """Count a proposal outside the region; it also counts as a rejection."""
        self.nrejects += 1
        self.noutside_regions += 1

    def adjust_scale(self, maxlength):
        """Adapt the step scale towards the target acceptance rate.

        If the acceptance rate naccepts / (naccepts + nrejects) is below
        ``self.balance``, the scale is divided by ``self.nudge``; otherwise it
        is multiplied by ``self.nudge``.

        Parameters
        ----------
        maxlength: float
            intended upper bound for the scale. NOTE: currently not enforced;
            the scale can grow beyond it.
        """
        assert len(self.history) > 1

        if self.naccepts < (self.nrejects + self.naccepts) * self.balance:
            if self.log:
                print("adjusting scale %f down: istep=%d inside=%d outside=%d region=%d nstuck=%d" % (
                    self.scale, len(self.history), self.naccepts, self.nrejects, self.noutside_regions, self.nstuck))
            self.scale /= self.nudge
        else:
            # maxlength is not used as a cap here (see docstring)
            if self.log:
                print("adjusting scale %f up: istep=%d inside=%d outside=%d region=%d nstuck=%d" % (
                    self.scale, len(self.history), self.naccepts, self.nrejects, self.noutside_regions, self.nstuck))
            self.scale *= self.nudge
        assert self.scale > 1e-10, self.scale

    def movei(self, ui, region, ndraw=1, plot=False):
        """Choose the next path index, store it in ``self.nexti`` and return it.

        Starts a new path at *ui* if none exists. If both neighbours of the
        current index are dead ends, the walker stays put. Otherwise it
        steps one index in ``self.direction``, turning around first if that
        neighbour is a dead end.
        """
        def boxed_in():
            return self.lasti - 1 in self.deadends and self.lasti + 1 in self.deadends

        if self.path is None:
            self.start_path(ui, region)
        elif boxed_in():
            # stuck, cannot go anywhere. Stay.
            self.nexti = self.lasti
            return self.nexti

        assert not boxed_in(), (self.deadends, self.lasti)
        if self.lasti + self.direction in self.deadends:
            self.direction *= -1

        self.nexti = self.lasti + self.direction
        return self.nexti

    def move(self, ui, region, ndraw=1, plot=False):
        """Propose the point at the next path index, as a single-row array."""
        inew = self.movei(ui, region=region, ndraw=ndraw, plot=plot)
        point, _ = self.get_point(inew)
        return point.reshape((1, -1))

    def reflect(self, reflpoint, v, region, plot=False):
        """Return the direction after reflecting *v* at *reflpoint*.

        Uses the path gradient as the mirror normal. If no gradient is
        available (None), the direction is simply reversed.
        """
        normal = self.path.gradient(reflpoint, plot=plot)
        if normal is not None:
            projection = (normal * v).sum()
            return v - 2 * projection * normal
        return -v

    def get_point(self, inew):
        """Get position and direction of the path point with index *inew*.

        If the path already stores a point with that index, its stored
        position and direction are returned. Otherwise the point is
        extrapolated by the path, which also yields ``(position, direction)``.

        Returns
        -------
        tuple (u, v): position and direction.
        """
        # path points are stored as (index, position, direction, loglikelihood);
        # return the direction (not the likelihood) so both branches agree
        for i, u, v, _ in self.path.points:
            if i == inew:
                return u, v
        return self.path.extrapolate(inew)

    def __next__(self, region, Lmin, us, Ls, transform, loglike, ndraw=40, plot=False):
        """Get next point.

        Advances the walker by one step along the current path, reflecting
        at the likelihood contour when the path leaves it. A sample is only
        returned once ``self.iresets`` has reached ``self.nresets``; until
        then the point fields of the result are ``None``.

        Parameters
        ----------
        region: MLFriends
            region.
        Lmin: float
            loglikelihood threshold
        us: array of vectors
            current live points
        Ls: array of floats
            current live point likelihoods
        transform: function
            transform function
        loglike: function
            loglikelihood function
        ndraw: int
            number of draws to attempt simultaneously.
        plot: bool
            whether to produce debug plots.

        Returns
        -------
        u: array or None
            new point in the unit cube, or None if no sample is ready yet.
        p: array or None
            transformed point, or None.
        L: float or None
            loglikelihood of the point, or None.
        nc: int
            number of likelihood evaluations performed in this call.

        """
        # start from the last point, unless it no longer satisfies Lmin
        # or the (possibly updated) region
        ui, Li = self.last
        if Li is not None and not Li >= Lmin:
            if self.log:
                print("wandered out of L constraint; resetting", ui[0])
            ui, Li = None, None

        if Li is not None and not region.inside(ui.reshape((1,-1))):
            # region was updated and we are not inside anymore
            # so reset
            if self.log:
                print("region change; resetting")
            ui, Li = None, None

        if Li is None and self.history:
            # try to resume from the most recent history point above the current contour
            for uj, Lj in self.history[::-1]:
                if Lj >= Lmin and region.inside(uj.reshape((1,-1))):
                    ui, Li = uj, Lj
                    if self.log:
                        print("recovered using history", ui)
                    break

        # select starting point
        if Li is None:
            # choose a new random starting point among the live points inside the region
            mask = region.inside(us)
            assert mask.any(), (
                "None of the live points satisfies the current region!",
                region.maxradiussq, region.u, region.unormed, us)
            i = np.random.randint(mask.sum())
            self.starti = i
            ui = us[mask,:][i]
            if self.log:
                print("starting at", ui)
            assert np.logical_and(ui > 0, ui < 1).all(), ui
            Li = Ls[mask][i]
            self.start()
            self.history.append((ui, Li))
            self.last = (ui, Li)

        inew = self.movei(ui, region, ndraw=ndraw)
        if self.log:
            print("i: %d->%d (step %d)" % (self.lasti, inew, self.istep))

        # the current point is always in the cache
        _, uold, Lold = self.cache[self.lasti]
        if plot:
            plt.plot(uold[0], uold[1], 'd', color='brown', ms=4)

        # default result: stay at the current point
        uret, pret, Lret = uold, transform(uold), Lold

        nc = 0
        if inew != self.lasti:
            # transformed coordinates of unew; only computed when needed
            pnew = None
            if inew not in self.cache:
                unew, _ = self.get_point(inew)
                if plot:
                    plt.plot(unew[0], unew[1], 'x', color='k', ms=4)
                accept = np.logical_and(unew > 0, unew < 1).all() and region.inside(unew.reshape((1, -1)))
                if accept:
                    if plot:
                        plt.plot(unew[0], unew[1], '+', color='orange', ms=4)
                    pnew = transform(unew)
                    # loglike is evaluated on a single row; take its scalar value,
                    # consistent with the reflection branch below
                    Lnew = loglike(pnew.reshape((1, -1)))[0]
                    nc = 1
                else:
                    Lnew = -np.inf
                    if self.log:
                        print("outside region: ", unew, "from", ui)
                    self.deadends.add(inew)
                    self.adjust_outside_region()
            else:
                # index was visited before: reuse the cached point and likelihood
                _, unew, Lnew = self.cache[inew]
                # if plot:
                #    plt.plot(unew[0], unew[1], 's', color='r', ms=2)

            if self.log:
                print("   suggested point:", unew)

            if Lnew >= Lmin:
                if pnew is None:
                    pnew = transform(unew)
                if self.log:
                    print(" -> inside.")
                if plot:
                    plt.plot(unew[0], unew[1], 'o', color='g', ms=4)
                self.adjust_accept(True, unew, pnew, Lnew, nc)
                uret, pret, Lret = unew, pnew, Lnew
            else:
                if plot:
                    plt.plot(unew[0], unew[1], '+', color='k', ms=2, alpha=0.3)
                if self.log:
                    print(" -> outside.")
                jump_successful = False
                if inew not in self.cache and inew not in self.deadends:
                    # first time we try to go beyond
                    # try to reflect:
                    reflpoint, v = self.get_point(inew)
                    if self.log:
                        print("    trying to reflect at", reflpoint)
                    self.nreflects += 1

                    # path directions are stored for positive indices; flip for the
                    # backward half so the reflection is computed in travel direction
                    sign = -1 if inew < 0 else +1
                    vnew = self.reflect(reflpoint, v * sign, region=region) * sign

                    xk, vk = extrapolate_ahead(sign, reflpoint, vnew, contourpath=self.path)

                    if plot:
                        plt.plot([reflpoint[0], (-v + reflpoint)[0]], [reflpoint[1], (-v + reflpoint)[1]], '-', color='k', lw=0.5, alpha=0.5)
                        plt.plot([reflpoint[0], (vnew + reflpoint)[0]], [reflpoint[1], (vnew + reflpoint)[1]], '-', color='k', lw=1)

                    if self.log:
                        print("    trying", xk)
                    accept = np.logical_and(xk > 0, xk < 1).all() and region.inside(xk.reshape((1, -1)))
                    if accept:
                        pk = transform(xk)
                        Lk = loglike(pk.reshape((1, -1)))[0]
                        nc += 1
                        if Lk >= Lmin:
                            jump_successful = True
                            uret, pret, Lret = xk, pk, Lk
                            if self.log:
                                print("successful reflect!")
                            self.path.add(inew, xk, vk, Lk)
                            self.adjust_accept(True, xk, pk, Lk, nc)
                        else:
                            if self.log:
                                print("unsuccessful reflect")
                            self.adjust_accept(False, xk, pk, Lk, nc)
                    else:
                        if self.log:
                            print("unsuccessful reflect out of region")
                        self.adjust_outside_region()

                    if plot:
                        plt.plot(xk[0], xk[1], 'x', color='g' if jump_successful else 'r', ms=8)

                    if not jump_successful:
                        # unsuccessful. mark as deadend
                        self.deadends.add(inew)
                else:
                    self.adjust_accept(False, uret, pret, Lret, nc)

                assert inew in self.cache or inew in self.deadends, (inew in self.cache, inew in self.deadends)
        else:
            # stuck, proposal did not move us
            self.nstuck += 1
            self.adjust_accept(False, uret, pret, Lret, nc)

        # increase step count
        self.istep += 1
        if self.istep == self.nsteps:
            if self.log:
                print("triggering re-orientation")
            # reset path so we go in a new direction
            self.terminate_path()
            self.istep = 0

        # if had enough resets, return final point
        if self.iresets >= self.nresets:
            if self.log:
                print("walked %d paths; returning sample" % self.iresets)
            self.adjust_scale(maxlength=len(uret)**0.5)
            self.start()
            self.last = None, None
            return uret, pret, Lret, nc

        # do not have a independent sample yet
        return None, None, None, nc