Fast online estimates on the GPU

06/08/2021 · 4 minutes

Estimating moments is an important step in any statistical analysis. The mean, variance, skewness and kurtosis of a dataset already tell you a lot about how the data is distributed.

However, some datasets don’t quite fit in memory. If you have a dataset of N samples and C features where N is much larger than C, online algorithms let you accumulate these statistics batch by batch instead of loading everything at once.

$$ \bar x = \frac 1N\sum_{i=1}^N x_i $$

$$ \sigma^2 = \frac 1{N-1}\left(\sum_{i=1}^N x_i^2 - N\bar x^2\right) $$
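Both estimates depend only on the running sums $\sum_i x_i$ and $\sum_i x_i^2$, which is what makes the online computation possible. A quick sanity check of the formulas on random data (the variable names are only illustrative):

import torch

x = torch.randn(1000)
n = x.numel()
s1 = x.sum()        # running sum of x
s2 = (x * x).sum()  # running sum of x^2

mean = s1 / n
var = (s2 - n * mean**2) / (n - 1)  # corrected sample variance

assert torch.allclose(mean, x.mean())
assert torch.allclose(var, x.var())  # torch.var is unbiased by default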

If you are using PyTorch on a GPU, it is easy to compute those moments. But since CUDA and multi-threading do not play well together, we have to optimize the batching of our samples.

import torch
from torch import Tensor

class Moments:
    """
    Online estimator of moments up to order 4.

    >>> m = Moments(5)

    >>> m.fit(torch.zeros(5,))
    1

    >>> m.fit(torch.ones(5,))
    2

    >>> m.mean()
    tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000])

    >>> m.std()
    tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000])

    >>> m.var(corrected=False)
    tensor([0.2500, 0.2500, 0.2500, 0.2500, 0.2500])

    >>> m.var(corrected=True)
    tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000])

    >>> m.skewness()
    tensor([0., 0., 0., 0., 0.])

    >>> m.kurtosis()
    tensor([-2., -2., -2., -2., -2.])

    """

    def __init__(self, n_channels: int, device="cpu") -> None:
        self.n = 0
        self.m = torch.zeros((n_channels, 4), device=device)  # raw power sums: sum(x), sum(x^2), sum(x^3), sum(x^4)

    @torch.no_grad()  # running sums never need autograd tracking
    def fit(self, new_obs: Tensor) -> int:
        if len(new_obs.shape) == 1:  # single obs
            new_obs = new_obs.unsqueeze(0)

        self.n += new_obs.size(0)

        # accumulate the four raw power sums batch by batch
        y = new_obs.clone()
        for i in range(4):
            self.m[:, i] += y.sum(dim=0)
            y = y * new_obs

        return self.n

    def merge(self, other):
        self.m += other.m
        self.n += other.n
        return self

    def mean(self) -> Tensor:
        return self.m[:, 0] / self.n

    def std(self) -> Tensor:
        return self.var(corrected=False).sqrt()

    def var(self, corrected: bool = False) -> Tensor:
        var = self.m[:,1]/self.n - self.mean().pow(2)
        if corrected:
            var = var * self.n / (self.n - 1)
        return var
    
    def skewness(self) -> Tensor:
        var = self.var()
        mean = self.mean()
        return (self.m[:,2] / self.n - 3.0 * mean * var - mean.pow(3)) / var.pow(1.5)
    
    def kurtosis(self) -> Tensor:
        m1, m2, m3, m4 = tuple(self.m[:,i] / self.n for i in range(4))
        return (m4 - 4.0 * m1 * m3 + 6.0 * m1.pow(2) * m2 - 3.0 * m1.pow(4)) / self.var().pow(2) - 3.0

Example usage:

device = "cuda" if torch.cuda.is_available() else "cpu"
m = Moments(100, device)

def embed(imgs: Tensor) -> Tensor:
    # transform input images into embedding vectors
    embeddings = ...
    return embeddings

for imgs in tqdm(dataloader):
    embeddings = embed(imgs.to(device))

    m.fit(embeddings)

# You can then get the moments
m.mean()
m.var(corrected=True)
m.skewness()
m.kurtosis()

Since we are embedding images, the bottleneck is likely to be the disk reads that load the image files. Make sure to use a torch.utils.data.DataLoader with num_workers set to a high enough value to leverage asynchronous loading of files.
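For example, a loader configured roughly like this (the dataset name, batch size and worker count are illustrative and should be tuned to your disk and GPU):

from torch.utils.data import DataLoader

dataloader = DataLoader(
    image_dataset,    # any torch.utils.data.Dataset yielding image batches
    batch_size=64,    # keep the GPU busy with reasonably large batches
    num_workers=8,    # several workers to hide disk read latency
    pin_memory=True,  # speeds up host-to-GPU transfers
)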

With enough num_workers, fitting 127,000 images, each producing 1024 embedding vectors, takes around 3 minutes. A more naive approach could take much longer, and keeping all the vectors in memory would hardly be feasible (127,000 × 1024 vectors × 100 features × 32 bits ≈ 52 GB).
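The back-of-the-envelope calculation behind that figure (float32 values take 4 bytes each):

n_images, vectors_per_image, n_features = 127_000, 1024, 100
size_gb = n_images * vectors_per_image * n_features * 4 / 1e9
print(size_gb)  # ≈ 52.0 GB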

Computing the CIFAR10 mean and standard deviation

Let’s first instantiate the CIFAR10 dataset provided by torchvision:

import torch
from torchvision.datasets import CIFAR10
import torchvision.transforms as T
from tqdm import tqdm

cifar10 = CIFAR10(
    "./data/cifar10", train=True, download=True, transform=T.ToTensor()
)
loader = torch.utils.data.DataLoader(cifar10, batch_size=32, num_workers=16,)

Then we can use the Moments class to compute the mean and standard deviation for each color channel:

device = "cuda" if torch.cuda.is_available() else "cpu"
moments = Moments(n_channels=3, device=device)
for x, _ in tqdm(loader):
    moments.fit(x.to(device).permute(0, 2, 3, 1).reshape(-1, 3))

mean = moments.mean()
std = moments.std()

print(f">> mean = {mean},\nstd = {std}")

This prints the following mean and std:

mean = tensor([0.4914, 0.4822, 0.4465], device='cuda:0'),
std = tensor([0.2470, 0.2435, 0.2616], device='cuda:0')
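These values can then be plugged into a normalization transform, for instance (using the rounded values above and the T alias for torchvision.transforms from the snippet before):

normalize = T.Normalize(
    mean=(0.4914, 0.4822, 0.4465),
    std=(0.2470, 0.2435, 0.2616),
)
transform = T.Compose([T.ToTensor(), normalize])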

Helper Function

from tqdm import tqdm
import torch

def get_mean_and_std(dataset: torch.utils.data.Dataset):
    loader = torch.utils.data.DataLoader(dataset, batch_size = 32, num_workers = 4)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    moments = Moments(n_channels=3, device=device)
    for x, _ in tqdm(loader):
        moments.fit(x.to(device).permute(0, 2, 3, 1).reshape(-1, 3))

    mean = moments.mean()
    std = moments.std()

    return mean, std
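For example, calling it on the cifar10 dataset instantiated earlier should reproduce the values from the previous section:

mean, std = get_mean_and_std(cifar10)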

An online multi-variate Gaussian estimator

$$ \mu = \frac 1N\sum_{i = 1}^Nx_i $$

$$ \Sigma = \frac 1{N-1}\left[\sum_{i=1}^N(x_ix_i^T) -N\mu\mu^T\right] $$

from typing import Tuple

import torch
from torch import Tensor

class OnlineGaussian:
    """
    Estimates the mean and covariance matrix of a multi-variate Gaussian 
    of `dim` variables.

    >>> gauss = OnlineGaussian(3)

    >>> gauss.fit(torch.ones((12, 3,)))
    12

    >>> gauss.value()
    (tensor([1., 1., 1.]), tensor([[0., 0., 0.],
            [0., 0., 0.],
            [0., 0., 0.]]))
    """

    def __init__(self, dim: int, device: str="cpu"):
        self.dim = dim
        self.sum1 = torch.zeros((dim,), device=device)
        self.sum2 = torch.zeros((dim, dim), device=device)
        self.N = 0

    @torch.no_grad()
    def fit(self, x: Tensor) -> int:
        x = x.view(-1, self.dim)
        assert x.size(-1) == self.dim

        self.N += x.size(0)
        self.sum1 += x.sum(0)
        self.sum2 += torch.einsum("ni,nj->ij", x, x)
        return self.N

    @torch.no_grad()
    def value(self, corrected: bool=True) -> Tuple[Tensor, Tensor]:
        means = self.sum1 / self.N
        covs = (self.sum2 - self.N * torch.outer(means, means)) / (self.N - corrected)
        return means, covs
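Usage mirrors the Moments class; here is a sketch that reuses the hypothetical embed function and dataloader from the embedding example above:

device = "cuda" if torch.cuda.is_available() else "cpu"
gauss = OnlineGaussian(dim=100, device=device)

for imgs in tqdm(dataloader):
    gauss.fit(embed(imgs.to(device)))

means, covs = gauss.value(corrected=True)  # (100,) means and (100, 100) covariance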

References

  • OnlineStats.jl, a cool Julia package that implements a lot of online estimators.