
Introduction

Following my investigation into bootstrap confidence intervals, I set out to run some simulation experiments to observe how the coverage of these confidence intervals approached their nominal levels. There didn't appear to be a popular JAX implementation of this functionality, so I took it upon myself to implement these methods in a package I am calling statax. I think JAX is uniquely well suited to bootstrapping thanks to its ease of vectorisation via vmap and its jit compilation, both of which pay off when repeatedly evaluating a supplied statistic function.

While it is true that scipy.stats.bootstrap allows you to enable vectorisation, this functionality is not as easily utilised, since the burden is on you to pass it an efficiently vectorised implementation of your statistic. Even in the scenario where we do pass scipy.stats.bootstrap a vectorised statistic, there is still utility in using a JAX implementation. This is shown below, where JAX allows us to easily vectorise a coverage experiment built on top of the bootstrap procedure, and as such achieve a >4x speedup relative to the same experiment using scipy.stats.bootstrap. Naturally, this speedup is even more drastic when we are not able to supply an easily pre-vectorised statistic.
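To make that burden concrete: scipy.stats.bootstrap with vectorized=True requires the statistic to accept an axis keyword and reduce along it, so that all resamples can be evaluated in a single call. A minimal sketch, using an interquartile-range statistic purely as an illustration:

import numpy as np

def iqr(data: np.ndarray, axis: int = -1) -> np.ndarray:
    # With vectorized=True, scipy.stats.bootstrap passes the whole
    # (n_resamples, n) batch at once and expects a reduction along `axis`.
    q75, q25 = np.percentile(data, [75, 25], axis=axis)
    return q75 - q25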


import logging
from timeit import timeit

import jax
import jax.numpy as jnp
import numpy as np
from jax import jit, random, vmap
from scipy.stats import bootstrap

# Assumes statax exposes its bootstrappers at the package top level.
from statax import BCaBootstrapper

N_SIM = 1000
N_SAMPLES = 100
N_RESAMPLES = 2_000


def scipy_bootstrap(seed: int = 42):
    rng = np.random.default_rng(seed)
    data = rng.normal(size=N_SAMPLES)

    res = bootstrap(
        (data,),
        np.median,
        confidence_level=0.95,
        n_resamples=N_RESAMPLES,
        vectorized=True,
        method="BCa",
    )

    low, high = res.confidence_interval
    contains = low <= 0 and high >= 0

    return contains


def simulate_scipy():
    res = 0
    for i in range(N_SIM):
        res += scipy_bootstrap(i)
    return res / N_SIM


@jit
@vmap
def statax_bootstrap(key: jax.Array):
    data = random.normal(key, shape=(N_SAMPLES,))

    bootstrapper = BCaBootstrapper(jnp.median)
    bootstrapper.resample(data, n_resamples=N_RESAMPLES)

    ci = bootstrapper.ci(confidence_level=0.95)
    contains = jnp.logical_and(ci[0] <= 0, ci[1] >= 0)

    return contains


def simulate_statax():
    key = random.key(0)
    res = jnp.sum(statax_bootstrap(random.split(key, N_SIM)))
    return res / N_SIM


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    scipy_time = timeit(simulate_scipy, number=100)
    logging.info(f"scipy time: {scipy_time:.2f}")

    statax_time = timeit(simulate_statax, number=100)
    logging.info(f"statax time: {statax_time:.2f}")

    logging.info(f"{scipy_time / statax_time:.2f} factor speedup.")

INFO:root:scipy time: 316.69
INFO:root:statax time: 77.94
INFO:root:4.06 factor speedup.

Coverage Experiment

To test the coverage of the different bootstrap confidence interval methods, I simulated n_sim confidence intervals for each interval construction procedure. I then calculated the percentage of the time these confidence intervals contained the true statistic value. This ratio serves as an estimate of the coverage of that procedure at a given $n$.

from typing import Callable, Protocol

# Assumes statax exposes a common Bootstrapper base class.
from statax import Bootstrapper


class SamplingDistribution(Protocol):
    def __call__(self, key: jax.Array, n: int) -> jax.Array: ...


class StatisticFn(Protocol):
    def __call__(self, data: jax.Array) -> jax.Array: ...


def estimate_coverage(
    bootstrapper_constructor: Callable[..., Bootstrapper],
    sampling_distribution: SamplingDistribution,
    statistic_fn: StatisticFn,
    confidence_level: float = 0.95,
    n_sim: int = 10_000,
    n_samples: int = 100,
    n_boot: int = 2_000,
    batch_size: int = 1000,
    seed: int = 0,
):
    key = random.key(seed)

    key, subkey = random.split(key)
    true_statistic_value = statistic_fn(sampling_distribution(subkey, 1_000_000))

    bootstrapper = bootstrapper_constructor(statistic_fn)

    @jax.vmap
    @jax.jit
    def confidence_interval_simulation(key: jax.Array) -> tuple[jax.Array, jax.Array]:
        key, subkey = random.split(key)
        data = sampling_distribution(subkey, n_samples)

        key, subkey = random.split(key)
        bootstrapper.resample(data=data, n_resamples=n_boot, key=subkey)

        ci_low, ci_high = bootstrapper.ci(confidence_level=confidence_level)
        return (true_statistic_value >= ci_low) & (true_statistic_value <= ci_high), ci_high - ci_low

    covered_count = 0
    total_length = 0
    i = 0
    while i < n_sim:
        logging.debug(f"Batch: {i} / {n_sim}")

        current_batch_size = min(batch_size, n_sim - i)
        key, subkey = random.split(key)

        is_covered, length = confidence_interval_simulation(random.split(subkey, current_batch_size))
        covered_count += jnp.sum(is_covered)
        total_length += jnp.sum(length)

        i += batch_size
    coverage = covered_count / n_sim
    average_length = total_length / n_sim

    return coverage, average_length

Simple Statistic

As an end-to-end test of my different bootstrap confidence interval procedures, my first experiment estimated the median of a normal distribution. As can be seen below, all methods start with a true coverage of $\approx 90\%$, which converges to the nominal level of 95% as $n$ increases.

def normal_distribution(key: jax.Array, n: int) -> jax.Array:
    key, subkey = random.split(key)
    sample = random.normal(subkey, shape=(n,))
    return sample


def median_statistic(data: jax.Array) -> jax.Array:
    return jnp.median(data)
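Putting the pieces together, a hypothetical invocation looks like the following. PercentileBootstrapper is assumed here as a stand-in for whichever statax bootstrapper is under test, sharing the constructor signature of BCaBootstrapper shown earlier.

# Assumed import; any Bootstrapper subclass constructor would work here.
from statax import PercentileBootstrapper

coverage, average_length = estimate_coverage(
    bootstrapper_constructor=PercentileBootstrapper,
    sampling_distribution=normal_distribution,
    statistic_fn=median_statistic,
)
print(f"coverage: {float(coverage):.3f}, average CI length: {float(average_length):.3f}")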

[Figure: coverage versus $n$ for each bootstrap CI method, median of a normal distribution]

Complex Statistic

Next, I looked at a more complex statistic over a more complex distribution. This time the observed coverage for small $n$ fell far short of the nominal rate, with all methods starting at around $\approx 70\%$. Interestingly, the performance of each approach roughly follows the order in which they were presented in my original exploration, with the difference in coverage between the worst and best performing methods being within $\approx 5\%$. Additionally, all methods did eventually converge to the nominal rate as $n \rightarrow \infty$.

def complex_distribution(key: jax.Array, n: int) -> jax.Array:
    # Log-normal distribution (heavily skewed)
    key, subkey = random.split(key)
    log_normal = jnp.exp(random.normal(subkey, shape=(n,)))

    # Add some contamination for extra complexity
    key, subkey = random.split(key)
    contamination = random.exponential(subkey, shape=(n,))

    return 0.9 * log_normal + 0.1 * contamination


def complex_statistic(data: jax.Array) -> jax.Array:
    return jnp.mean(data) / (1 + jnp.median(data))

[Figure: coverage versus $n$ for each bootstrap CI method, complex statistic over the contaminated log-normal distribution]

Adversarial Percentile Setup

Thus far, I had been surprised to see that the percentile method was on par with the other, more complex methods. After all, my original exploration set out to challenge my default habit of reaching for the percentile method. As such, I wanted to devise an adversarial example that the percentile method would struggle on. Specifically, I modified my bootstrap resampling code to add a constant offset (with scale set by the statistic's variance, for the sake of numerical stability) to the estimated statistic, and twice this offset to the replicates. This perturbation ensures that the differences between the estimators and the true statistic are shifted by a constant amount. The percentile method struggles in this scenario, since its estimated percentiles end up off by twice the offset. Meanwhile, the other methods build their intervals from pivotal quantities, differences of statistics in which the offset behaves consistently, and as such are insulated from the perturbation.

\[\hat{\theta}_n - \theta \sim C + F\]

\[\hat{\theta}^* - \hat{\theta} \sim C + F\]
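To make the mechanism explicit, consider the two constructions side by side, using the basic (reverse-percentile) interval as a representative pivotal method; here $q_\gamma(\hat{\theta}^*)$ denotes the $\gamma$-quantile of the bootstrap replicates.

\[\text{Percentile: } \left[\, q_{\alpha/2}(\hat{\theta}^*),\; q_{1-\alpha/2}(\hat{\theta}^*) \,\right]\]

Under the perturbation, both quantiles shift by $2C$ while $\theta$ does not move, so the interval drifts away from its target.

\[\text{Basic: } \left[\, 2\hat{\theta} - q_{1-\alpha/2}(\hat{\theta}^*),\; 2\hat{\theta} - q_{\alpha/2}(\hat{\theta}^*) \,\right]\]

Here $2\hat{\theta}$ shifts by $2C$ and each quantile also shifts by $2C$, so the endpoints are unchanged and coverage is preserved. The modified resampling code implementing the perturbation is below.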
def resample(self, data: jax.Array, n_resamples: int = 2000, key: jax.Array = random.key(42)) -> None:
    key, subkey = random.split(key)

    self._theta_hat = self._statistic(data)

    @jax.vmap
    @jax.jit
    def _generate_bootstrap_replicate(rng_key: jax.Array) -> jax.Array:
        data_resampled = self._resample_data(data, rng_key)
        theta_boot = self._statistic(data_resampled)
        return theta_boot

    self._bootstrap_replicates = _generate_bootstrap_replicate(random.split(subkey, n_resamples))

    # Modification
    bias_factor = self.variance() * 2
    self._theta_hat = self._theta_hat + bias_factor
    self._bootstrap_replicates = self._bootstrap_replicates + 2 * bias_factor
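As a sanity check on the cancellation argument above, here is a minimal sketch of the two interval constructions (my own illustration, not statax's actual implementation):

def percentile_ci(replicates: jax.Array, alpha: float = 0.05) -> tuple[jax.Array, jax.Array]:
    # Both endpoints are replicate quantiles, so both inherit the full 2C shift.
    return (
        jnp.quantile(replicates, alpha / 2),
        jnp.quantile(replicates, 1 - alpha / 2),
    )


def basic_ci(theta_hat: jax.Array, replicates: jax.Array, alpha: float = 0.05) -> tuple[jax.Array, jax.Array]:
    # Reflecting the quantiles about theta_hat cancels the shift:
    # 2 * (theta_hat + C) - (q + 2 * C) = 2 * theta_hat - q.
    return (
        2 * theta_hat - jnp.quantile(replicates, 1 - alpha / 2),
        2 * theta_hat - jnp.quantile(replicates, alpha / 2),
    )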

[Figure: coverage versus $n$ for each bootstrap CI method under the adversarial constant-offset perturbation]

No Free Lunch

People often point to bootstrapping as an “assumption-free” panacea for confidence interval estimation in the age of computer-aided statistics. Hopefully, this analysis has shown that there is indeed a price for using bootstrap confidence interval estimation, namely that the true coverage can be considerably less than the nominal level, especially for $n \leq 100$. Furthermore, the success of different bootstrapping approaches, in terms of rapidly converging to their promised asymptotic results, very much depends on whether the supplied data conforms to that bootstrap method’s optimal convergence assumptions.

However, we should not be too harsh on bootstrapping. Classical statistics often rests on equally dubious data-generating and asymptotic assumptions. In summary, bootstrapping is a powerful tool; however, professional judgment should be exercised before blindly accepting the confidence intervals that these procedures produce. Furthermore, if you have good reason to believe that your data conforms to assumptions that would make it amenable to exact confidence interval calculation, then such procedures should not be overlooked.
