Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
# -*- coding: utf-8 -*- r""" Discrete Gaussian Samplers over the Integers
This class realizes oracles which returns integers proportionally to `\exp(-(x-c)^2/(2σ^2))`. All oracles are implemented using rejection sampling. See :func:`DiscreteGaussianDistributionIntegerSampler.__init__` for which algorithms are available.
AUTHORS:
- Martin Albrecht (2014-06-28): initial version
EXAMPLES:
We construct a sampler for the distribution `D_{3,c}` with width `σ=3` and center `c=0`::
sage: from sage.stats.distributions.discrete_gaussian_integer import DiscreteGaussianDistributionIntegerSampler sage: sigma = 3.0 sage: D = DiscreteGaussianDistributionIntegerSampler(sigma=sigma)
We ask for 100000 samples::
sage: from six.moves import range sage: n=100000; l = [D() for _ in range(n)]
These are sampled with a probability proportional to `\exp(-x^2/18)`. More precisely we have to normalise by dividing by the overall probability over all integers. We use the fact that hitting anything more than 6 standard deviations away is very unlikely and compute::
sage: norm_factor = sum([exp(-x^2/(2*sigma^2)) for x in range(-6*sigma,sigma*6+1)]) sage: norm_factor 7.519...
With this normalisation factor, we can now test if our samples follow the expected distribution::
sage: x=0; l.count(x), ZZ(round(n*exp(-x^2/(2*sigma^2))/norm_factor)) (13355, 13298) sage: x=4; l.count(x), ZZ(round(n*exp(-x^2/(2*sigma^2))/norm_factor)) (5479, 5467) sage: x=-10; l.count(x), ZZ(round(n*exp(-x^2/(2*sigma^2))/norm_factor)) (53, 51)
We construct an instance with a larger width::
sage: from sage.stats.distributions.discrete_gaussian_integer import DiscreteGaussianDistributionIntegerSampler sage: sigma = 127 sage: D = DiscreteGaussianDistributionIntegerSampler(sigma=sigma, algorithm='uniform+online')
ask for 100000 samples::
sage: from six.moves import range sage: n=100000; l = [D() for _ in range(n)] # long time
and check if the proportions fit::
sage: x=0; y=1; float(l.count(x))/l.count(y), exp(-x^2/(2*sigma^2))/exp(-y^2/(2*sigma^2)).n() # long time (1.0, 1.00...) sage: x=0; y=-100; float(l.count(x))/l.count(y), exp(-x^2/(2*sigma^2))/exp(-y^2/(2*sigma^2)).n() # long time (1.32..., 1.36...)
We construct a sampler with `c\%1 != 0`::
sage: from sage.stats.distributions.discrete_gaussian_integer import DiscreteGaussianDistributionIntegerSampler sage: sigma = 3 sage: D = DiscreteGaussianDistributionIntegerSampler(sigma=sigma, c=1/2) sage: from six.moves import range sage: n=100000; l = [D() for _ in range(n)] # long time sage: mean(l).n() # long time 0.486650000000000
REFERENCES:
- [DDLL2013]_
""" #****************************************************************************** # # DGS - Discrete Gaussian Samplers # # Copyright (c) 2014, Martin Albrecht <martinralbrecht+dgs@googlemail.com> # All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are met: # # 1. Redistributions of source code must retain the above copyright notice, this # list of conditions and the following disclaimer. # 2. Redistributions in binary form must reproduce the above copyright notice, # this list of conditions and the following disclaimer in the documentation # and/or other materials provided with the distribution. # # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # # The views and conclusions contained in the software and documentation are # those of the authors and should not be interpreted as representing official # policies, either expressed or implied, of the FreeBSD Project. #*****************************************************************************/
from __future__ import print_function
from cysignals.signals cimport sig_on, sig_off
from sage.rings.real_mpfr cimport RealNumber, RealField from sage.libs.mpfr cimport mpfr_set, MPFR_RNDN from sage.rings.integer cimport Integer from sage.misc.randstate cimport randstate, current_randstate
from .dgs cimport dgs_disc_gauss_mp_init, dgs_disc_gauss_mp_clear, dgs_disc_gauss_mp_flush_cache from .dgs cimport dgs_disc_gauss_dp_init, dgs_disc_gauss_dp_clear, dgs_disc_gauss_dp_flush_cache from .dgs cimport DGS_DISC_GAUSS_UNIFORM_TABLE, DGS_DISC_GAUSS_UNIFORM_ONLINE, DGS_DISC_GAUSS_UNIFORM_LOGTABLE, DGS_DISC_GAUSS_SIGMA2_LOGTABLE
cdef class DiscreteGaussianDistributionIntegerSampler(SageObject): r""" A Discrete Gaussian Sampler using rejection sampling.
.. automethod:: __init__ .. automethod:: __call__ """
# We use tables for σt ≤ table_cutoff
def __init__(self, sigma, c=0, tau=6, algorithm=None, precision="mp"): r""" Construct a new sampler for a discrete Gaussian distribution.
INPUT:
- ``sigma`` - samples `x` are accepted with probability proportional to `\exp(-(x-c)²/(2σ²))`
- ``c`` - the mean of the distribution. The value of ``c`` does not have to be an integer. However, some algorithms only support integer-valued ``c`` (default: ``0``)
- ``tau`` - samples outside the range `(⌊c⌉-⌈στ⌉,...,⌊c⌉+⌈στ⌉)` are considered to have probability zero. This bound applies to algorithms which sample from the uniform distribution (default: ``6``)
- ``algorithm`` - see list below (default: ``"uniform+table"`` for `σt` bounded by ``DiscreteGaussianDistributionIntegerSampler.table_cutoff`` and ``"uniform+online"`` for bigger `στ`)
- ``precision`` - either ``"mp"`` for multi-precision where the actual precision used is taken from sigma or ``"dp"`` for double precision. In the latter case results are not reproducible. (default: ``"mp"``)
ALGORITHMS:
- ``"uniform+table"`` - classical rejection sampling, sampling from the uniform distribution and accepted with probability proportional to `\exp(-(x-c)²/(2σ²))` where `\exp(-(x-c)²/(2σ²))` is precomputed and stored in a table. Any real-valued `c` is supported.
- ``"uniform+logtable"`` - samples are drawn from a uniform distribution and accepted with probability proportional to `\exp(-(x-c)²/(2σ²))` where `\exp(-(x-c)²/(2σ²))` is computed using logarithmically many calls to Bernoulli distributions. See [DDLL2013]_ for details. Only integer-valued `c` are supported.
- ``"uniform+online"`` - samples are drawn from a uniform distribution and accepted with probability proportional to `\exp(-(x-c)²/(2σ²))` where `\exp(-(x-c)²/(2σ²))` is computed in each invocation. Typically this is very slow. See [DDLL2013]_ for details. Any real-valued `c` is accepted.
- ``"sigma2+logtable"`` - samples are drawn from an easily samplable distribution with `σ = k·σ_2` with `σ_2 = \sqrt{1/(2\log 2)}` and accepted with probability proportional to `\exp(-(x-c)²/(2σ²))` where `\exp(-(x-c)²/(2σ²))` is computed using logarithmically many calls to Bernoulli distributions (but no calls to `\exp`). See [DDLL2013]_ for details. Note that this sampler adjusts `σ` to match `k·σ_2` for some integer `k`. Only integer-valued `c` are supported.
EXAMPLES::
sage: from sage.stats.distributions.discrete_gaussian_integer import DiscreteGaussianDistributionIntegerSampler sage: DiscreteGaussianDistributionIntegerSampler(3.0, algorithm="uniform+online") Discrete Gaussian sampler over the Integers with sigma = 3.000000 and c = 0 sage: DiscreteGaussianDistributionIntegerSampler(3.0, algorithm="uniform+table") Discrete Gaussian sampler over the Integers with sigma = 3.000000 and c = 0 sage: DiscreteGaussianDistributionIntegerSampler(3.0, algorithm="uniform+logtable") Discrete Gaussian sampler over the Integers with sigma = 3.000000 and c = 0
Note that ``"sigma2+logtable"`` adjusts `σ`::
sage: DiscreteGaussianDistributionIntegerSampler(3.0, algorithm="sigma2+logtable") Discrete Gaussian sampler over the Integers with sigma = 3.397287 and c = 0
TESTS:
We are testing invalid inputs::
sage: from sage.stats.distributions.discrete_gaussian_integer import DiscreteGaussianDistributionIntegerSampler sage: DiscreteGaussianDistributionIntegerSampler(-3.0) Traceback (most recent call last): ... ValueError: sigma must be > 0.0 but got -3.000000
sage: DiscreteGaussianDistributionIntegerSampler(3.0, tau=-1) Traceback (most recent call last): ... ValueError: tau must be >= 1 but got -1
sage: DiscreteGaussianDistributionIntegerSampler(3.0, tau=2, algorithm="superfastalgorithmyouneverheardof") Traceback (most recent call last): ... ValueError: Algorithm 'superfastalgorithmyouneverheardof' not supported by class 'DiscreteGaussianDistributionIntegerSampler'
sage: DiscreteGaussianDistributionIntegerSampler(3.0, c=1.5, algorithm="sigma2+logtable") Traceback (most recent call last): ... ValueError: algorithm 'uniform+logtable' requires c%1 == 0
We are testing correctness for multi-precision::
sage: from sage.stats.distributions.discrete_gaussian_integer import DiscreteGaussianDistributionIntegerSampler sage: D = DiscreteGaussianDistributionIntegerSampler(1.0, c=0, tau=2) sage: from six.moves import range sage: l = [D() for _ in range(2^16)] sage: min(l) == 0-2*1.0, max(l) == 0+2*1.0, abs(mean(l)) < 0.01 (True, True, True)
sage: from sage.stats.distributions.discrete_gaussian_integer import DiscreteGaussianDistributionIntegerSampler sage: D = DiscreteGaussianDistributionIntegerSampler(1.0, c=2.5, tau=2) sage: from six.moves import range sage: l = [D() for _ in range(2^18)] sage: min(l)==2-2*1.0, max(l)==2+2*1.0, mean(l).n() (True, True, 2.45...)
sage: from sage.stats.distributions.discrete_gaussian_integer import DiscreteGaussianDistributionIntegerSampler sage: D = DiscreteGaussianDistributionIntegerSampler(1.0, c=2.5, tau=6) sage: from six.moves import range sage: l = [D() for _ in range(2^18)] sage: min(l), max(l), abs(mean(l)-2.5) < 0.01 (-2, 7, True)
We are testing correctness for double precision::
sage: from sage.stats.distributions.discrete_gaussian_integer import DiscreteGaussianDistributionIntegerSampler sage: D = DiscreteGaussianDistributionIntegerSampler(1.0, c=0, tau=2, precision="dp") sage: from six.moves import range sage: l = [D() for _ in range(2^16)] sage: min(l) == 0-2*1.0, max(l) == 0+2*1.0, abs(mean(l)) < 0.05 (True, True, True)
sage: from sage.stats.distributions.discrete_gaussian_integer import DiscreteGaussianDistributionIntegerSampler sage: D = DiscreteGaussianDistributionIntegerSampler(1.0, c=2.5, tau=2, precision="dp") sage: from six.moves import range sage: l = [D() for _ in range(2^18)] sage: min(l)==2-2*1.0, max(l)==2+2*1.0, mean(l).n() (True, True, 2.4...)
sage: from sage.stats.distributions.discrete_gaussian_integer import DiscreteGaussianDistributionIntegerSampler sage: D = DiscreteGaussianDistributionIntegerSampler(1.0, c=2.5, tau=6, precision="dp") sage: from six.moves import range sage: l = [D() for _ in range(2^18)] sage: min(l)<=-1, max(l)>=6, abs(mean(l)-2.5) < 0.1 (True, True, True) sage: tuple(l.count(i) for i in range(-2,8)) # output random (7, 242, 4519, 34120, 92714, 91700, 33925, 4666, 246, 5)
We plot a histogram::
sage: from sage.stats.distributions.discrete_gaussian_integer import DiscreteGaussianDistributionIntegerSampler sage: D = DiscreteGaussianDistributionIntegerSampler(17.0) sage: S = [D() for _ in range(2^16)] sage: list_plot([(v,S.count(v)) for v in set(S)]) # long time Graphics object consisting of 1 graphics primitive
These generators cache random bits for performance reasons. Hence, resetting the seed of the PRNG might not have the expected outcome. You can flush this cache with ``_flush_cache()``::
sage: from sage.stats.distributions.discrete_gaussian_integer import DiscreteGaussianDistributionIntegerSampler sage: D = DiscreteGaussianDistributionIntegerSampler(3.0) sage: sage.misc.randstate.set_random_seed(0); D() 3 sage: sage.misc.randstate.set_random_seed(0); D() 3 sage: sage.misc.randstate.set_random_seed(0); D._flush_cache(); D() 3
sage: D = DiscreteGaussianDistributionIntegerSampler(3.0) sage: sage.misc.randstate.set_random_seed(0); D() 3 sage: sage.misc.randstate.set_random_seed(0); D() 3 sage: sage.misc.randstate.set_random_seed(0); D() -3 """
else: algorithm = "uniform+online"
raise ValueError("algorithm 'uniform+logtable' requires c%1 == 0") else:
sigma = RR(sigma) else: raise ValueError("Parameter precision '%s' not supported."%precision)
def _flush_cache(self): r""" Flush the internal cache of random bits.
EXAMPLES::
sage: from sage.stats.distributions.discrete_gaussian_integer import DiscreteGaussianDistributionIntegerSampler
sage: f = lambda: sage.misc.randstate.set_random_seed(0)
sage: f() sage: D = DiscreteGaussianDistributionIntegerSampler(30.0) sage: [D() for k in range(16)] [21, 23, 37, 6, -64, 29, 8, -22, -3, -10, 7, -43, 1, -29, 25, 38]
sage: f() sage: D = DiscreteGaussianDistributionIntegerSampler(30.0) sage: l = [] sage: for i in range(16): ....: f(); l.append(D()) sage: l [21, 21, 21, 21, -21, 21, 21, -21, -21, -21, 21, -21, 21, -21, 21, 21]
sage: f() sage: D = DiscreteGaussianDistributionIntegerSampler(30.0) sage: l = [] sage: for i in range(16): ....: f(); D._flush_cache(); l.append(D()) sage: l [21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21] """ dgs_disc_gauss_dp_flush_cache(self._gen_dp)
def __dealloc__(self): r""" TESTS::
sage: from sage.stats.distributions.discrete_gaussian_integer import DiscreteGaussianDistributionIntegerSampler sage: D = DiscreteGaussianDistributionIntegerSampler(3.0, algorithm="uniform+online") sage: del D """
def __call__(self): r""" Return a new sample.
EXAMPLES::
sage: from sage.stats.distributions.discrete_gaussian_integer import DiscreteGaussianDistributionIntegerSampler sage: DiscreteGaussianDistributionIntegerSampler(3.0, algorithm="uniform+online")() -3 sage: DiscreteGaussianDistributionIntegerSampler(3.0, algorithm="uniform+table")() 3
TESTS::
sage: from sage.stats.distributions.discrete_gaussian_integer import DiscreteGaussianDistributionIntegerSampler sage: DiscreteGaussianDistributionIntegerSampler(3.0, algorithm="uniform+logtable", precision="dp")() # random output 13 """ cdef randstate rstate cdef Integer rop else:
def _repr_(self): r""" TESTS::
sage: from sage.stats.distributions.discrete_gaussian_integer import DiscreteGaussianDistributionIntegerSampler sage: repr(DiscreteGaussianDistributionIntegerSampler(3.0, 2)) 'Discrete Gaussian sampler over the Integers with sigma = 3.000000 and c = 2' """ |