L0 Regularization

Published

June 2, 2022

This page contains my reading notes on

  • Learning Sparse Neural Networks through L_{0} Regularization
Code
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import interact
plt.style.use("ggplot")

def sigmoid(x):
    return 1 / (1 + np.exp(-x))

def hard_sigmoid(x):
    # Clip x to the range [0, 1]
    return min(1, max(0, x))

def logit_dist():
    # Sample from the standard logistic distribution via inverse transform sampling
    u = np.random.random()
    logit = np.log(u) - np.log(1 - u)

    return logit

def binary_concrete(loc, temp):
    # Sample s = sigmoid((l + loc) / temp), where l is logistic noise
    logit = logit_dist()
    bc = sigmoid((logit + loc) / temp)

    return bc

def stretch_binary_concrete(loc, temp, gamma=-0.1, zeta=1.1):
    # Stretch the binary concrete sample from (0, 1) to (gamma, zeta)
    bc = binary_concrete(loc, temp)
    stretch_bc = bc * (zeta - gamma) + gamma

    return stretch_bc

def hard_concrete(loc, temp, gamma=-0.1, zeta=1.1):
    # Clip the stretched sample back to [0, 1] to get a hard concrete sample
    stretch_bc = stretch_binary_concrete(loc, temp, gamma, zeta)
    hc = hard_sigmoid(stretch_bc)

    return hc

def plot_probability(list_samples, bins=100, **kwargs):
    # Plot normalized histograms of each group of samples on the same axes
    plt.figure(figsize=(16, 8))
    for samples in list_samples:
        weights = np.ones_like(samples) / len(samples)
        plt.hist(samples, weights=weights, bins=bins, alpha=0.5, **kwargs)

Problem formulation

Given a vector x of length n (a matrix can also be seen as a vector by stacking its rows or columns), the common vector norms are listed below; a small numpy check follows the list:

  • L_{0} norm:

    \sum_{i=1}^{n} \mathbb{1}[x_{i} \neq 0]

  • L_{1} norm:

    \sum_{i=1}^{n} \lvert x_{i} \rvert

    which, when used as a penalty, is also called lasso regularization in neural networks.

  • L_{2} norm (squared, as commonly used for regularization):

    \sum_{i=1}^{n} x_{i}^2

    which, when used as a penalty, is also called ridge regularization in neural networks.

  • L_{\infty} norm:

    \max_{i=1}^{n} \lvert x_i \rvert
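
As a quick check, the snippet below evaluates these norms on a small example vector with numpy (the vector is made up for illustration):

Code
x = np.array([0.0, -1.5, 0.0, 2.0, 0.5])

l0_norm = np.sum(x != 0)        # number of non-zero entries
l1_norm = np.sum(np.abs(x))     # sum of absolute values
l2_norm_sq = np.sum(x ** 2)     # sum of squares (squared L_2 norm)
linf_norm = np.max(np.abs(x))   # largest absolute value

print(l0_norm, l1_norm, l2_norm_sq, linf_norm)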

A common way to prune the edges of a neural network is to use L_{1} or L_{2} regularization to drive the weights close to 0 (but not exactly 0), and then set all weights whose magnitudes fall below a threshold to 0 (a minimal sketch of this thresholding step follows the list below).

  • L_{0} regularization is not used for this because counting the number of non-zero entries is not a differentiable operation.

  • However, L_{0} regularization is still desirable because it does not shrink the magnitudes of the weights during the pruning process.
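
Here is a minimal sketch of this magnitude-pruning step (the weights and the threshold are made up for illustration):

Code
# Magnitude pruning: zero out every weight whose absolute value is below a threshold
weights = np.array([0.012, -0.8, 0.003, 1.2, -0.05])
threshold = 0.05

pruned_weights = np.where(np.abs(weights) < threshold, 0.0, weights)
print(pruned_weights)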

General recipe of L_{0} regularization

The loss function used to train a neural network with L_{0} regularization is:

\mathcal{L}(f(x, \theta), y) + \lambda \sum_{i=1}^{\lvert \theta \rvert} \mathbb{1}[\theta_{i} \neq 0]

where

  • \mathcal{L} is a standard loss function (e.g., cross-entropy loss)
  • \theta are the parameters in the network
  • x, y are training instances
  • \lambda is a hyper-parameter that balances the loss and the regularization term.

If we attach a trainable binary random variable z_{i} to each element of the model parameter \theta_{i}, then the weights used in the feed-forward operation of the neural network can be replaced by \theta \odot z. The loss function then becomes:

\mathcal{L}(f(x, \theta \odot z), y) + \lambda \sum_{i=1}^{\lvert \theta \rvert} \mathbb{1}[z_{i} \neq 0]

where

  • z \in \{0, 1\}^{\lvert \theta \rvert} is randomly sampled in each forward propagation according to some distribution.
  • \odot corresponds to the elementwise product.

If we model each z_{i} as a binary random variable that follows a Bernoulli distribution with parameter \pi_{i}, i.e. z_{i} \sim \mathrm{Bern}(\pi_{i}), then the loss function becomes:

\mathbb{E}_{z \sim \mathrm{Bern}(\pi)} \big[ \mathcal{L}(f(x, \theta \odot z), y) \big] + \lambda \sum_{i=1}^{\lvert \theta \rvert} \pi_{i}

where \mathbb{E}_{z \sim \mathrm{Bern}(\pi)} [\cdot] denotes the expectation over z sampled from the Bernoulli distribution.

This reformulation of the loss function can be established because

  1. Since the minimum of a function is upper bounded by its expectation, minimizing \mathbb{E}_{z \sim \mathrm{Bern}(\pi)} \big[ \mathcal{L}(f(x, \theta \odot z), y) \big] minimizes an upper bound on \mathcal{L}(f(x, \theta \odot z), y).

  2. By the definition of the Bernoulli distribution, \pi_{i} is the probability of z_{i} being 1 (non-zero), so \mathbb{E}[\mathbb{1}[z_{i} \neq 0]] = \pi_{i} and the expected regularization term is exactly \sum_{i=1}^{\lvert \theta \rvert} \pi_{i}. Minimizing \pi_{i} therefore increases the probability of z_{i} being 0.

In the equation above,

  • \pi_{i} is a parameter that we want to learn using gradient descent.

  • Thus, the second term \lambda \sum_{i=1}^{\lvert \theta \rvert} \pi_{i} can be directly minimized to regularize \pi because the gradient of this term w.r.t. \pi_{i} is easy to calculate.

  • However, the first term \mathbb{E}_{z \sim \mathrm{Bern}(\pi)} \big[ \mathcal{L}(f(x, \theta \odot z), y) \big] is still problematic because z is a discrete random variable, and sampling it cannot be differentiated with respect to \pi. (A small sampling-based sketch of this objective follows this list.)
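
To make the recipe above concrete, here is a minimal numpy sketch that evaluates this objective for a hypothetical linear model with a squared-error loss (the data, parameters, and hyper-parameters are all made up for illustration); it only estimates the objective by sampling and does not solve the differentiability problem:

Code
# Sketch of the objective E_{z ~ Bern(pi)}[L(f(x, theta * z), y)] + lambda * sum(pi)
# for a made-up linear model with squared-error loss.
rng = np.random.default_rng(0)

x_data = rng.normal(size=(8, 4))        # 8 training instances with 4 features
y_data = rng.normal(size=8)             # made-up targets
theta = rng.normal(size=4)              # model parameters
pi = np.array([0.9, 0.1, 0.5, 0.7])     # Bernoulli parameter for each gate z_i
lam = 0.01                              # regularization strength lambda

def expected_loss(num_samples=1000):
    # Monte Carlo estimate of the expected loss over the Bernoulli gates
    losses = []
    for _ in range(num_samples):
        z = rng.binomial(1, pi)         # sample the binary gates
        y_pred = x_data @ (theta * z)   # forward pass with gated parameters
        losses.append(np.mean((y_pred - y_data) ** 2))
    return np.mean(losses)

objective = expected_loss() + lam * np.sum(pi)
print(objective)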

Hard concrete distribution

Binary Concrete distribution

The binary concrete distribution can be seen as a continuous approximation of the Bernoulli distribution. It has 2 parameters:

  • location \alpha: plays a role analogous to the probability parameter of the Bernoulli distribution.

  • temperature \beta: controls how closely the binary concrete distribution approximates the Bernoulli distribution (the lower the temperature, the closer the approximation).

Using the reparametrization trick, a sample s from the binary concrete distribution can be represented as:

s = \operatorname{sigmoid} \left( \frac{\alpha + l}{\beta} \right)

where l is a sample from the standard logistic distribution.

Code
def plot_logit(size=100000, **kwargs):
    logit_samples = [logit_dist() for _ in range(size)]
    plot_probability([logit_samples])
    
interact(plot_logit);
Code
def plot_bc(size=100000, **kwargs):
    ber_samples = np.random.binomial(1, 0.5, size=size)
    bc_samples = [binary_concrete(kwargs['loc'], kwargs['temp']) for _ in range(size)]
    plot_probability([ber_samples, bc_samples])
    
interact(plot_bc, loc=(-3, 3), temp=(0.001, 1));

From binary concrete to hard concrete

We cannot use s (binary concrete distribution) to directly replace z (Bernoulli distribution):

  • The range of s is (0, 1) and never touches 0 or 1.

  • However, we want z to be either 0 or 1.

A simple trick to solve this problem is

  1. First “stretch” the binary concrete distribution from interval (0, 1) to interval (\gamma, \zeta) with \gamma < 0 and \zeta > 1

    \bar{s} = s(\zeta - \gamma) + \gamma

  2. Then clip the stretched binary concrete distribution into the range [0, 1]

    \bar{z} = \mathrm{clip}(\bar{s}, 0, 1)

    where

    • \mathrm{clip}(x, \mathrm{min}, \mathrm{max}) means to clip x between the range [\mathrm{min}, \mathrm{max}].

    • \bar{z} is a random variable that follows the hard concrete distribution and it can be used to replace z.

Code
def plot_sbc(size=100000, **kwargs):
    bc_samples = [binary_concrete(kwargs['loc'], kwargs['temp']) for _ in range(size)]
    sbc_samples = [stretch_binary_concrete(kwargs['loc'], kwargs['temp']) for _ in range(size)]
    plot_probability([bc_samples, sbc_samples])
    
interact(plot_sbc, loc=(-3, 3), temp=(0.001, 1));
Code
def plot_hc(size=100000, **kwargs):
    hc_samples = [hard_concrete(kwargs['loc'], kwargs['temp']) for _ in range(size)]
    sbc_samples = [stretch_binary_concrete(kwargs['loc'], kwargs['temp']) for _ in range(size)]
    plot_probability([hc_samples, sbc_samples])
    
interact(plot_hc, loc=(-3, 3), temp=(0.001, 1));

Also, we cannot use \alpha (location parameter of the binary concrete distribution) to replace \pi (probability parameter of the Bernoulli distribution) in the regularization term of the loss function:

  • Remember that the regularization term above measures the sum of the probabilities of z (Bernoulli distribution) being non-zero:

    \sum_{i=1}^{\lvert \theta \rvert} \pi_{i}

Measuring the probability of \bar{z} being non-zero is the same as measuring the probability of \bar{s} being positive.

  • This is because all negative values of \bar{s} are clipped to 0.

  • Since the total probability of a random variable over all of its values is 1, the probability of \bar{s} being positive can be written as:

    \mathrm{P}(\bar{s} > 0 \mid \phi) = 1 - \mathrm{Q}(\bar{s} \leq 0 \mid \phi)

    where

    • \mathrm{Q}(\cdot) is the cumulative distribution function (CDF) of \bar{s}.
    • \mathrm{Q}(\bar{s} \leq 0 \mid \phi) gives the probability of \bar{s} being non-positive.
    • \phi denotes the parameters of the distribution (the location \alpha and the temperature \beta).

    A numerical check of this probability is sketched after this list.
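
As a sanity check, 1 - \mathrm{Q}(\bar{s} \leq 0 \mid \phi) can be computed in closed form; my derivation from the logistic CDF (treat it as an assumption here) gives \operatorname{sigmoid}(\alpha - \beta \log(-\gamma / \zeta)). The snippet below compares this closed form against a Monte Carlo estimate using the samplers defined at the top of the page:

Code
def prob_nonzero_mc(loc, temp, gamma=-0.1, zeta=1.1, size=100000):
    # Monte Carlo estimate of P(z_bar != 0) using the hard concrete sampler above
    samples = np.array([hard_concrete(loc, temp, gamma, zeta) for _ in range(size)])
    return np.mean(samples > 0)

def prob_nonzero_cdf(loc, temp, gamma=-0.1, zeta=1.1):
    # Assumed closed form of 1 - Q(s_bar <= 0 | phi), derived from the logistic CDF
    return sigmoid(loc - temp * np.log(-gamma / zeta))

print(prob_nonzero_mc(0.5, 0.5), prob_nonzero_cdf(0.5, 0.5))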

Thus, the loss function above can be rewritten using the hard concrete distribution as:

\mathbb{E}_{\bar{s}} \big[ \mathcal{L}(f(x, \theta \odot \bar{z}), y) \big] + \lambda \sum_{i=1}^{\lvert \theta \rvert} \big( 1 - \mathrm{Q}(\bar{s}_{i} \leq 0 \mid \phi_{i}) \big)

The loss function is now “fully” differentiable with respect to \alpha because:

  • \bar{z} = \mathrm{clip}(\bar{s}, 0, 1), and \bar{s} is a differentiable function of \alpha via the reparametrization trick (a finite-difference check is sketched after this list).

  • \mathrm{Q}(\bar{s} \leq 0 | \phi) can be written as a differentiable function with respect to \alpha.
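
To see this concretely, the sketch below fixes the logistic noise (the reparametrization trick) and estimates, via finite differences, how a sampled \bar{z} and the assumed closed-form penalty from the previous snippet respond to the location parameter (the values are made up for illustration; the derivative of \bar{z} is 0 whenever the sample lands in a clipped region):

Code
# Finite-difference check that a reparametrized sample z_bar and the penalty term
# both respond smoothly to the location parameter. All values are made up.
loc, temp, gamma, zeta = 0.5, 0.5, -0.1, 1.1
eps = 1e-4
noise = logit_dist()                    # fix the logistic noise once

def z_bar(loc_value):
    # Hard concrete sample as a deterministic function of the location, given the fixed noise
    s = sigmoid((noise + loc_value) / temp)
    return hard_sigmoid(s * (zeta - gamma) + gamma)

def penalty(loc_value):
    # Assumed closed form of 1 - Q(s_bar <= 0 | phi), as in the previous snippet
    return sigmoid(loc_value - temp * np.log(-gamma / zeta))

dz_dloc = (z_bar(loc + eps) - z_bar(loc - eps)) / (2 * eps)
dpenalty_dloc = (penalty(loc + eps) - penalty(loc - eps)) / (2 * eps)
print(dz_dloc, dpenalty_dloc)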