# Fast online moments estimates on the GPU

06/08/2021 · 3 minutes

Estimating moments is an important step of any statistical analysis of data. The mean, variance, skewness and kurtosis of a dataset can already tell a lot about the distribution of our data.

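However, some datasets don’t fit in memory. If you have a dataset of N samples and C features where N is much larger than C, online algorithms help a lot: they consume the data batch by batch and keep only a few running sums per feature, from which the moments are recovered at the end. For instance, the mean and the unbiased variance only need the sums of $x_i$ and $x_i^2$: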

$$\bar x = \frac 1N\sum_{i=1}^N x_i$$

$$\sigma^2 = \frac N{N-1}\left(\frac 1N\sum_{i=1}^N x_i^2 - \bar x^2\right)$$
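The same idea extends to the two higher moments. As a sketch, write $m_k$ for the raw moments and $\mu_k$ for the central moments (this notation is introduced here for illustration), so that $\bar x = m_1$. Expanding $(x_i - \bar x)^k$ gives:

$$m_k = \frac 1N\sum_{i=1}^N x_i^k$$

$$\mu_2 = m_2 - m_1^2$$

$$\mu_3 = m_3 - 3 m_1 m_2 + 2 m_1^3 = m_3 - 3 m_1 \mu_2 - m_1^3$$

$$\mu_4 = m_4 - 4 m_1 m_3 + 6 m_1^2 m_2 - 3 m_1^4$$

The skewness is then $\mu_3 / \mu_2^{3/2}$ and the excess kurtosis is $\mu_4 / \mu_2^2 - 3$, so four running sums per feature are enough for all four statistics.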

If you are using PyTorch on a GPU, it is easy to compute those moments. But since CUDA and multi-threading do not play well together, the speed has to come from batching: the estimator below consumes whole batches of samples with vectorized tensor operations.

```python
import torch
from torch import Tensor


class Moments:
    """
    Online estimator of moments up to order 4.

    >>> m = Moments(5)

    >>> m.fit(torch.zeros(5,))
    1

    >>> m.fit(torch.ones(5,))
    2

    >>> m.mean()
    tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000])

    >>> m.var(corrected=True)
    tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000])

    >>> m.skewness()
    tensor([0., 0., 0., 0., 0.])

    >>> m.kurtosis()
    tensor([-2., -2., -2., -2., -2.])

    """

    def __init__(self, size: int, device="cpu") -> None:
        self.n = 0  # number of samples seen so far
        # m[:, k] holds the running sum of x^(k+1) for each feature
        self.m = torch.zeros((size, 4), device=device)

    def fit(self, new_obs: Tensor) -> int:
        if len(new_obs.shape) == 1:  # single obs
            new_obs = new_obs.unsqueeze(0)

        self.n += new_obs.size(0)

        # accumulate the sums of x, x^2, x^3 and x^4 over the batch
        y = new_obs.clone()
        for i in range(4):
            self.m[:, i] += y.sum(dim=0)
            y = y * new_obs

        return self.n

    def merge(self, other):
        # sums and counts are additive, so two estimators combine exactly
        self.m += other.m
        self.n += other.n
        return self

    def mean(self) -> Tensor:
        return self.m[:, 0] / self.n

    def var(self, corrected: bool = False) -> Tensor:
        var = self.m[:, 1] / self.n - self.mean().pow(2)
        if corrected:
            # Bessel's correction for the unbiased estimate
            var = var * self.n / (self.n - 1)
        return var

    def skewness(self) -> Tensor:
        var = self.var()
        mean = self.mean()
        return (self.m[:, 2] / self.n - 3.0 * mean * var - mean.pow(3)) / var.pow(1.5)

    def kurtosis(self) -> Tensor:
        # excess kurtosis (0 for a normal distribution)
        m1, m2, m3, m4 = tuple(self.m[:, i] / self.n for i in range(4))
        return (m4 - 4.0 * m1 * m3 + 6.0 * m1.pow(2) * m2 - 3.0 * m1.pow(4)) / self.var().pow(2) - 3.0
```
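Because the estimator only stores sums, estimators fitted on different chunks of the data can be combined by adding their sums, which is what `merge` relies on. Here is a minimal plain-Python sketch of that property for a single feature, without PyTorch. The helpers `power_sums`, `merge_sums` and `mean_and_var` are hypothetical names used for illustration only; they are not part of the class.

```python
import statistics


def power_sums(xs):
    """Return (n, [sum x, sum x^2, sum x^3, sum x^4]) for one chunk."""
    sums = [0.0] * 4
    for x in xs:
        p = 1.0
        for k in range(4):
            p *= x
            sums[k] += p
    return len(xs), sums


def merge_sums(a, b):
    """Combine two chunks by adding counts and sums element-wise."""
    return a[0] + b[0], [s + t for s, t in zip(a[1], b[1])]


def mean_and_var(state, corrected=True):
    """Recover the mean and variance from the first two power sums."""
    n, (s1, s2, _, _) = state
    mean = s1 / n
    var = s2 / n - mean ** 2
    if corrected:
        var *= n / (n - 1)
    return mean, var


data = [0.5, 1.5, 2.0, 4.0, 4.5, 7.0]
# Fit two chunks separately, then merge them.
merged = merge_sums(power_sums(data[:2]), power_sums(data[2:]))
mean, var = mean_and_var(merged)

# The merged result matches a single pass over the whole list.
assert abs(mean - statistics.mean(data)) < 1e-9
assert abs(var - statistics.variance(data)) < 1e-9
```

The same additivity is what would let several estimators run on separate shards of a dataset and be reduced at the end.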


Example usage:

```python
device = "cuda" if torch.cuda.is_available() else "cpu"
m = Moments(100, device)

def embed(imgs: Tensor) -> Tensor:
    # transform input images into embedding vectors
    embeddings = ...
    return embeddings

# imgs is a batch of images loaded elsewhere
embeddings = embed(imgs.to(device))

m.fit(embeddings)

# You can then get the moments
m.mean()
m.var(corrected=True)
m.skewness()
m.kurtosis()
```


Since we are embedding images, the bottleneck is likely to be the disk reads that load the image files. Make sure to use a `torch.utils.data.DataLoader` with `num_workers` set to a high enough value to leverage asynchronous loading of files.
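As a rough sketch of that loop, the snippet below feeds batches from a `DataLoader` into running sums. It rests on several assumptions: `RandomVectors` is a hypothetical stand-in dataset that returns random vectors instead of images, `fit_power_sums` is an illustrative helper that tracks only the first two sums, and there is no embedding model in between.

```python
import torch
from torch import Tensor
from torch.utils.data import DataLoader, Dataset


class RandomVectors(Dataset):
    """Hypothetical stand-in for an image dataset: item i is a fixed random vector."""

    def __init__(self, n_items: int, size: int) -> None:
        gen = torch.Generator().manual_seed(0)
        self.data = torch.randn(n_items, size, generator=gen)

    def __len__(self) -> int:
        return self.data.size(0)

    def __getitem__(self, idx: int) -> Tensor:
        return self.data[idx]


def fit_power_sums(loader: DataLoader, size: int, device: str):
    """Accumulate the sums of x and x^2 per feature, one batch at a time."""
    n = 0
    sums = torch.zeros((size, 2), device=device)
    for batch in loader:
        batch = batch.to(device)
        n += batch.size(0)
        sums[:, 0] += batch.sum(dim=0)
        sums[:, 1] += batch.pow(2).sum(dim=0)
    return n, sums


device = "cuda" if torch.cuda.is_available() else "cpu"
dataset = RandomVectors(n_items=1000, size=8)
# num_workers=0 loads in the main process; raise it when items come from disk.
loader = DataLoader(dataset, batch_size=128, num_workers=0)

n, sums = fit_power_sums(loader, size=8, device=device)
mean = sums[:, 0] / n
var = (sums[:, 1] / n - mean.pow(2)) * n / (n - 1)  # unbiased variance
```

With `num_workers` greater than zero, the loading runs in worker processes. On platforms that spawn rather than fork, the script body usually needs to sit under an `if __name__ == "__main__":` guard.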

With enough `num_workers`, fitting 127,000 images, each of which produces 1024 embedding vectors, takes around 3 minutes. A more naive approach could take far longer, and keeping all the vectors in memory would not even be feasible (127,000 × 1024 × 100 values × 32 bits ≈ 416 Gbit, or about 52 GB).
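The memory estimate can be checked with a few lines of arithmetic. This sketch assumes 32-bit floats (4 bytes per value) and vectors of size 100, as in the usage example; the variable names are illustrative.

```python
n_images = 127_000
vectors_per_image = 1024
vector_size = 100      # matches Moments(100) in the usage example
bytes_per_value = 4    # assumption: float32

n_values = n_images * vectors_per_image * vector_size
total_bytes = n_values * bytes_per_value

print(f"{n_values:,} values")               # 13,004,800,000 values
print(f"{total_bytes * 8 / 1e9:.1f} Gbit")  # 416.2 Gbit
print(f"{total_bytes / 1e9:.1f} GB")        # 52.0 GB

# The online estimator instead keeps 4 running sums per feature.
estimator_bytes = vector_size * 4 * bytes_per_value
print(f"{estimator_bytes} bytes")           # 1600 bytes
```

So storing every vector would take about 52 GB, while the estimator's state stays at a couple of kilobytes however many samples are fitted.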

• OnlineStats.jl, a cool Julia package that implements a lot of online estimators.