Source code for flows.bijections.bijections

import jax.numpy as np
import jax.scipy.special as spys
from jax import random, scipy
from jax.nn.initializers import orthogonal
from jax.scipy import linalg

# Each layer constructor function returns an init_fun where...
#
# init_fun: a function taking an rng key and an input dimension and returning...
#     params: a pytree of trainable parameters
#     direct_fun: a function taking params and inputs and returning...
#         outputs: the inputs mapped through the layer
#         log_det_jacobian: the log of the absolute determinant of the layer's jacobian
#     inverse_fun: a function taking params and inputs and returning...
#         outputs: the inputs mapped through the inverse of the layer
#         log_det_jacobian: the log of the absolute determinant of the inverse's jacobian
#
# A minimal usage sketch of this contract follows below.
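
# A minimal sketch (not part of the library) of the contract described above, using the
# Reverse layer defined later in this module; the variable names are illustrative.
def _example_layer_contract():
    rng = random.PRNGKey(0)
    inputs = random.uniform(rng, (16, 4))

    params, direct_fun, inverse_fun = Reverse()(rng, 4)
    outputs, log_det_jacobian = direct_fun(params, inputs)  # shapes (16, 4) and (16,)
    reconstructed, _ = inverse_fun(params, outputs)
    assert np.allclose(inputs, reconstructed)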


def ActNorm():
    """An implementation of an actnorm layer from `Glow: Generative Flow with Invertible 1x1
    Convolutions` (https://arxiv.org/abs/1807.03039).

    Returns:
        An ``init_fun`` mapping ``(rng, input_dim)`` to a ``(params, direct_fun, inverse_fun)`` triplet.
    """

    def init_fun(rng, input_dim, **kwargs):
        init_inputs = kwargs.pop("init_inputs", None)
        if init_inputs is not None:
            log_weight = np.log(1.0 / (init_inputs.std(0) + 1e-6))
            bias = init_inputs.mean(0)
        else:
            log_weight = np.zeros(input_dim)
            bias = np.zeros(input_dim)

        def direct_fun(params, inputs, **kwargs):
            log_weight, bias = params
            outputs = (inputs - bias) * np.exp(log_weight)
            log_det_jacobian = np.full(inputs.shape[:1], log_weight.sum())
            return outputs, log_det_jacobian

        def inverse_fun(params, inputs, **kwargs):
            log_weight, bias = params
            outputs = inputs * np.exp(-log_weight) + bias
            log_det_jacobian = np.full(inputs.shape[:1], -log_weight.sum())
            return outputs, log_det_jacobian

        return (log_weight, bias), direct_fun, inverse_fun

    return init_fun
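

# Usage sketch (illustrative, not part of the library): data-dependent initialization of
# ActNorm via the ``init_inputs`` keyword, followed by a round-trip check.
def _example_actnorm():
    rng = random.PRNGKey(0)
    init_batch = random.normal(rng, (128, 3)) * 5.0 + 2.0

    params, direct_fun, inverse_fun = ActNorm()(rng, 3, init_inputs=init_batch)
    outputs, _ = direct_fun(params, init_batch)
    # After data-dependent initialization the outputs are approximately zero-centered.
    assert np.allclose(outputs.mean(0), 0.0, atol=1e-3)

    reconstructed, _ = inverse_fun(params, outputs)
    assert np.allclose(init_batch, reconstructed, atol=1e-4)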


def AffineCouplingSplit(scale, translate):
    """An implementation of a coupling layer from `Density Estimation Using RealNVP`
    (https://arxiv.org/abs/1605.08803).

    Args:
        scale: A constructor mapping ``(rng, input_dim, output_dim)`` to a ``(params, apply_fun)`` pair
            characterizing a trainable scaling function
        translate: A constructor mapping ``(rng, input_dim, output_dim)`` to a ``(params, apply_fun)`` pair
            characterizing a trainable translation function

    Returns:
        An ``init_fun`` mapping ``(rng, input_dim)`` to a ``(params, direct_fun, inverse_fun)`` triplet.
    """

    def init_fun(rng, input_dim, **kwargs):
        cutoff = input_dim // 2

        scale_rng, rng = random.split(rng)
        scale_params, scale_apply_fun = scale(scale_rng, cutoff, input_dim - cutoff)

        translate_rng, rng = random.split(rng)
        translate_params, translate_apply_fun = translate(translate_rng, cutoff, input_dim - cutoff)

        def direct_fun(params, inputs, **kwargs):
            scale_params, translate_params = params
            lower, upper = inputs[:, :cutoff], inputs[:, cutoff:]

            log_weight = scale_apply_fun(scale_params, lower)
            bias = translate_apply_fun(translate_params, lower)
            upper = upper * np.exp(log_weight) + bias

            outputs = np.concatenate([lower, upper], axis=1)
            log_det_jacobian = log_weight.sum(-1)
            return outputs, log_det_jacobian

        def inverse_fun(params, inputs, **kwargs):
            scale_params, translate_params = params
            lower, upper = inputs[:, :cutoff], inputs[:, cutoff:]

            log_weight = scale_apply_fun(scale_params, lower)
            bias = translate_apply_fun(translate_params, lower)
            upper = (upper - bias) * np.exp(-log_weight)

            outputs = np.concatenate([lower, upper], axis=1)
            # The inverse map scales by exp(-log_weight), so its log-determinant is negated.
            log_det_jacobian = -log_weight.sum(-1)
            return outputs, log_det_jacobian

        return (scale_params, translate_params), direct_fun, inverse_fun

    return init_fun
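

# Usage sketch (illustrative, not part of the library): a hypothetical linear network with the
# ``(rng, input_dim, output_dim) -> (params, apply_fun)`` signature expected by the ``scale``
# and ``translate`` arguments, used for a round-trip check.
def _example_affine_coupling_split():
    def linear(rng, input_dim, output_dim):
        W = 0.01 * random.normal(rng, (input_dim, output_dim))
        return W, lambda params, inputs: inputs @ params

    rng = random.PRNGKey(0)
    inputs = random.normal(rng, (32, 4))

    params, direct_fun, inverse_fun = AffineCouplingSplit(linear, linear)(rng, 4)
    outputs, log_det = direct_fun(params, inputs)
    reconstructed, _ = inverse_fun(params, outputs)
    assert np.allclose(inputs, reconstructed, atol=1e-4)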


def AffineCoupling(transform):
    """An implementation of a coupling layer from `Density Estimation Using RealNVP`
    (https://arxiv.org/abs/1605.08803).

    Args:
        transform: A constructor mapping ``(rng, input_dim, output_dim)`` to a ``(params, apply_fun)`` pair
            characterizing a trainable function whose output is the concatenated log-scale and translation

    Returns:
        An ``init_fun`` mapping ``(rng, input_dim)`` to a ``(params, direct_fun, inverse_fun)`` triplet.
    """

    def init_fun(rng, input_dim, **kwargs):
        cutoff = input_dim // 2

        params, apply_fun = transform(rng, cutoff, 2 * (input_dim - cutoff))

        def direct_fun(params, inputs, **kwargs):
            lower, upper = inputs[:, :cutoff], inputs[:, cutoff:]

            log_weight, bias = apply_fun(params, lower).split(2, axis=1)
            upper = upper * np.exp(log_weight) + bias

            outputs = np.concatenate([lower, upper], axis=1)
            log_det_jacobian = log_weight.sum(-1)
            return outputs, log_det_jacobian

        def inverse_fun(params, inputs, **kwargs):
            lower, upper = inputs[:, :cutoff], inputs[:, cutoff:]

            log_weight, bias = apply_fun(params, lower).split(2, axis=1)
            upper = (upper - bias) * np.exp(-log_weight)

            outputs = np.concatenate([lower, upper], axis=1)
            # The inverse map scales by exp(-log_weight), so its log-determinant is negated.
            log_det_jacobian = -log_weight.sum(-1)
            return outputs, log_det_jacobian

        return params, direct_fun, inverse_fun

    return init_fun
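

# Usage sketch (illustrative, not part of the library): a hypothetical linear ``transform``
# producing the concatenated log-scale and translation, used for a round-trip check.
def _example_affine_coupling():
    def linear(rng, input_dim, output_dim):
        W = 0.01 * random.normal(rng, (input_dim, output_dim))
        return W, lambda params, inputs: inputs @ params

    rng = random.PRNGKey(0)
    inputs = random.normal(rng, (32, 4))

    params, direct_fun, inverse_fun = AffineCoupling(linear)(rng, 4)
    outputs, log_det = direct_fun(params, inputs)
    reconstructed, _ = inverse_fun(params, outputs)
    assert np.allclose(inputs, reconstructed, atol=1e-4)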


def BatchNorm(momentum=0.9):
    """An implementation of a batch normalization layer from `Density Estimation Using RealNVP`
    (https://arxiv.org/abs/1605.08803).

    Returns:
        An ``init_fun`` mapping ``(rng, input_dim)`` to a ``(params, direct_fun, inverse_fun)`` triplet.
    """

    def init_fun(rng, input_dim, **kwargs):
        log_weight = np.zeros(input_dim)
        bias = np.zeros(input_dim)
        eps = 1e-5

        # Which is better, keeping track of state as an edge case for batchnorm or changing
        # the entire library's interface to return a layer state and keep functions pure?
        state = {}

        def direct_fun(params, inputs, **kwargs):
            evaluation = kwargs.pop("evaluation", None)
            log_weight, bias = params

            if "running_mean" not in state:
                state["running_mean"] = np.zeros(input_dim)
                state["running_var"] = np.ones(input_dim)
            running_mean, running_var = state["running_mean"], state["running_var"]

            if evaluation:
                mean = running_mean
                var = running_var
            else:
                batch_mean = inputs.mean(0)
                batch_var = ((inputs - batch_mean) ** 2.0).mean(0) + eps

                state["batch_mean"] = batch_mean
                state["batch_var"] = batch_var

                # Store the updated running statistics so they are available at evaluation time.
                state["running_mean"] = (running_mean * momentum) + (batch_mean * (1 - momentum))
                state["running_var"] = (running_var * momentum) + (batch_var * (1 - momentum))

                mean = batch_mean
                var = batch_var

            x_hat = (inputs - mean) / np.sqrt(var)
            outputs = x_hat * np.exp(log_weight) + bias
            log_det_jacobian = np.full((inputs.shape[0],), (log_weight - 0.5 * np.log(var)).sum())
            return outputs, log_det_jacobian

        def inverse_fun(params, inputs, **kwargs):
            evaluation = kwargs.pop("evaluation", None)
            log_weight, bias = params

            if "running_mean" not in state:
                state["running_mean"] = np.zeros(input_dim)
                state["running_var"] = np.ones(input_dim)
            running_mean, running_var = state["running_mean"], state["running_var"]

            if evaluation:
                mean = running_mean
                var = running_var
            else:
                mean = state["batch_mean"]
                var = state["batch_var"]

            x_hat = (inputs - bias) * np.exp(-log_weight)
            outputs = x_hat * np.sqrt(var) + mean
            log_det_jacobian = np.full((inputs.shape[0],), (-log_weight + 0.5 * np.log(var)).sum())
            return outputs, log_det_jacobian

        return (log_weight, bias), direct_fun, inverse_fun

    return init_fun
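

# Usage sketch (illustrative, not part of the library): train-time and evaluation-time behaviour
# of BatchNorm, controlled by the ``evaluation`` keyword read above.
def _example_batchnorm():
    rng = random.PRNGKey(0)
    inputs = random.normal(rng, (64, 3)) * 2.0 + 1.0

    params, direct_fun, inverse_fun = BatchNorm()(rng, 3)

    # Training mode: normalizes with the batch statistics and updates the running averages.
    outputs, _ = direct_fun(params, inputs)
    reconstructed, _ = inverse_fun(params, outputs)
    assert np.allclose(inputs, reconstructed, atol=1e-4)

    # Evaluation mode: normalizes with the stored running statistics instead.
    eval_outputs, _ = direct_fun(params, inputs, evaluation=True)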


def Invert(bijection):
    """Inverts a transformation so that its ``direct_fun`` is its ``inverse_fun`` and vice versa.

    Returns:
        An ``init_fun`` mapping ``(rng, input_dim)`` to a ``(params, direct_fun, inverse_fun)`` triplet.
    """

    def init_fun(rng, input_dim, **kwargs):
        params, direct_fun, inverse_fun = bijection(rng, input_dim)
        return params, inverse_fun, direct_fun

    return init_fun
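

# A minimal sketch (not part of the library): Invert swaps ``direct_fun`` and ``inverse_fun``,
# so the forward pass of Invert(Sigmoid()) applies the logit.
def _example_invert():
    rng = random.PRNGKey(0)
    inputs = random.uniform(rng, (8, 3), minval=0.1, maxval=0.9)

    _, sigmoid_direct, sigmoid_inverse = Sigmoid()(rng, 3)
    params, direct_fun, inverse_fun = Invert(Sigmoid())(rng, 3)

    assert np.allclose(direct_fun(params, inputs)[0], sigmoid_inverse((), inputs)[0])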


def FixedInvertibleLinear():
    """An implementation of an invertible linear layer from `Glow: Generative Flow with Invertible 1x1
    Convolutions` (https://arxiv.org/abs/1807.03039).

    Returns:
        An ``init_fun`` mapping ``(rng, input_dim)`` to a ``(params, direct_fun, inverse_fun)`` triplet.
    """

    def init_fun(rng, input_dim, **kwargs):
        W = orthogonal()(rng, (input_dim, input_dim))
        W_inv = linalg.inv(W)
        W_log_det = np.linalg.slogdet(W)[-1]

        def direct_fun(params, inputs, **kwargs):
            outputs = inputs @ W
            log_det_jacobian = np.full(inputs.shape[:1], W_log_det)
            return outputs, log_det_jacobian

        def inverse_fun(params, inputs, **kwargs):
            outputs = inputs @ W_inv
            log_det_jacobian = np.full(inputs.shape[:1], -W_log_det)
            return outputs, log_det_jacobian

        return (), direct_fun, inverse_fun

    return init_fun
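

# A minimal sketch (not part of the library): round-trip and log-determinant check for the fixed
# (non-trainable) invertible linear layer.
def _example_fixed_invertible_linear():
    rng = random.PRNGKey(0)
    inputs = random.normal(rng, (16, 5))

    params, direct_fun, inverse_fun = FixedInvertibleLinear()(rng, 5)
    outputs, log_det = direct_fun(params, inputs)
    reconstructed, _ = inverse_fun(params, outputs)

    assert np.allclose(inputs, reconstructed, atol=1e-4)
    # The orthogonal initialization has unit |det|, so the log-determinant starts at zero.
    assert np.allclose(log_det, 0.0, atol=1e-4)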


def InvertibleLinear():
    """An implementation of an invertible linear layer from `Glow: Generative Flow with Invertible 1x1
    Convolutions` (https://arxiv.org/abs/1807.03039).

    Returns:
        An ``init_fun`` mapping ``(rng, input_dim)`` to a ``(params, direct_fun, inverse_fun)`` triplet.
    """

    def init_fun(rng, input_dim, **kwargs):
        W = orthogonal()(rng, (input_dim, input_dim))
        P, L, U = scipy.linalg.lu(W)
        S = np.diag(U)
        U = np.triu(U, 1)
        identity = np.eye(input_dim)

        def direct_fun(params, inputs, **kwargs):
            L, U, S = params
            L = np.tril(L, -1) + identity
            U = np.triu(U, 1)
            W = P @ L @ (U + np.diag(S))

            outputs = inputs @ W
            log_det_jacobian = np.full(inputs.shape[:1], np.log(np.abs(S)).sum())
            return outputs, log_det_jacobian

        def inverse_fun(params, inputs, **kwargs):
            L, U, S = params
            L = np.tril(L, -1) + identity
            U = np.triu(U, 1)
            W = P @ L @ (U + np.diag(S))

            outputs = inputs @ linalg.inv(W)
            log_det_jacobian = np.full(inputs.shape[:1], -np.log(np.abs(S)).sum())
            return outputs, log_det_jacobian

        return (L, U, S), direct_fun, inverse_fun

    return init_fun
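

# A minimal sketch (not part of the library): round-trip check for the trainable, PLU-parameterized
# invertible linear layer.
def _example_invertible_linear():
    rng = random.PRNGKey(0)
    inputs = random.normal(rng, (16, 5))

    params, direct_fun, inverse_fun = InvertibleLinear()(rng, 5)
    outputs, log_det = direct_fun(params, inputs)
    reconstructed, _ = inverse_fun(params, outputs)
    assert np.allclose(inputs, reconstructed, atol=1e-4)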


def Logit(clip_before_logit=True):
    """Computes the logit function on a set of inputs, with the sigmoid function being its inverse.

    Important note: Values passed through this layer are clipped to be within a range computable using 32 bits.
    This was done in "Cubic-Spline Flows" by Durkan et al. Technically this breaks invertibility, but it avoids
    inevitable NaNs.

    Args:
        clip_before_logit: Whether to clip values to the range [1e-5, 1 - 1e-5] before being passed through logit.

    Returns:
        An ``init_fun`` mapping ``(rng, input_dim)`` to a ``(params, direct_fun, inverse_fun)`` triplet.
    """
    return Invert(Sigmoid(clip_before_logit))
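

# A minimal sketch (not part of the library): Logit maps values in (0, 1) to the reals, and its
# inverse (the sigmoid) recovers them.
def _example_logit():
    rng = random.PRNGKey(0)
    inputs = random.uniform(rng, (16, 3), minval=0.05, maxval=0.95)

    params, direct_fun, inverse_fun = Logit()(rng, 3)
    outputs, _ = direct_fun(params, inputs)
    reconstructed, _ = inverse_fun(params, outputs)
    assert np.allclose(inputs, reconstructed, atol=1e-4)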


def Reverse():
    """An implementation of a reversing layer from `Density Estimation Using RealNVP`
    (https://arxiv.org/abs/1605.08803).

    Returns:
        An ``init_fun`` mapping ``(rng, input_dim)`` to a ``(params, direct_fun, inverse_fun)`` triplet.

    Examples:
        >>> num_examples, input_dim, tol = 20, 3, 1e-4
        >>> layer_rng, input_rng = random.split(random.PRNGKey(0))
        >>> inputs = random.uniform(input_rng, (num_examples, input_dim))
        >>> init_fun = Reverse()
        >>> params, direct_fun, inverse_fun = init_fun(layer_rng, input_dim)
        >>> mapped_inputs = direct_fun(params, inputs)[0]
        >>> reconstructed_inputs = inverse_fun(params, mapped_inputs)[0]
        >>> np.allclose(inputs, reconstructed_inputs).item()
        True
    """

    def init_fun(rng, input_dim, **kwargs):
        perm = np.arange(input_dim)[::-1]

        def direct_fun(params, inputs, **kwargs):
            return inputs[:, perm], np.zeros(inputs.shape[:1])

        def inverse_fun(params, inputs, **kwargs):
            return inputs[:, perm], np.zeros(inputs.shape[:1])

        return (), direct_fun, inverse_fun

    return init_fun


def Shuffle():
    """An implementation of a shuffling layer from `Density Estimation Using RealNVP`
    (https://arxiv.org/abs/1605.08803).

    Returns:
        An ``init_fun`` mapping ``(rng, input_dim)`` to a ``(params, direct_fun, inverse_fun)`` triplet.

    Examples:
        >>> num_examples, input_dim, tol = 20, 3, 1e-4
        >>> layer_rng, input_rng = random.split(random.PRNGKey(0))
        >>> inputs = random.uniform(input_rng, (num_examples, input_dim))
        >>> init_fun = Shuffle()
        >>> params, direct_fun, inverse_fun = init_fun(layer_rng, input_dim)
        >>> mapped_inputs = direct_fun(params, inputs)[0]
        >>> reconstructed_inputs = inverse_fun(params, mapped_inputs)[0]
        >>> np.allclose(inputs, reconstructed_inputs).item()
        True
    """

    def init_fun(rng, input_dim, **kwargs):
        perm = random.permutation(rng, np.arange(input_dim))
        inv_perm = np.argsort(perm)

        def direct_fun(params, inputs, **kwargs):
            return inputs[:, perm], np.zeros(inputs.shape[:1])

        def inverse_fun(params, inputs, **kwargs):
            return inputs[:, inv_perm], np.zeros(inputs.shape[:1])

        return (), direct_fun, inverse_fun

    return init_fun


def Sigmoid(clip_before_logit=True):
    """Computes the sigmoid function on a set of inputs, with the logit function being its inverse.

    Important note: Values passed through this layer are clipped to be within a range computable using 32 bits.
    This was done in "Cubic-Spline Flows" by Durkan et al. Technically this breaks invertibility, but it avoids
    inevitable NaNs.

    Args:
        clip_before_logit: Whether to clip values to the range [1e-5, 1 - 1e-5] before being passed through logit.

    Returns:
        An ``init_fun`` mapping ``(rng, input_dim)`` to a ``(params, direct_fun, inverse_fun)`` triplet.
    """

    def init_fun(rng, input_dim, **kwargs):
        def direct_fun(params, inputs, **kwargs):
            outputs = spys.expit(inputs)
            log_det_jacobian = np.log(outputs * (1 - outputs)).sum(-1)
            return outputs, log_det_jacobian

        def inverse_fun(params, inputs, **kwargs):
            if clip_before_logit:
                inputs = np.clip(inputs, 1e-5, 1 - 1e-5)

            outputs = spys.logit(inputs)
            log_det_jacobian = -np.log(inputs - np.square(inputs)).sum(-1)
            return outputs, log_det_jacobian

        return (), direct_fun, inverse_fun

    return init_fun
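

# A minimal sketch (not part of the library): the forward and inverse log-det-jacobians of Sigmoid
# cancel on a round trip (up to the clipping discussed in the docstring).
def _example_sigmoid():
    rng = random.PRNGKey(0)
    inputs = random.normal(rng, (16, 3))

    params, direct_fun, inverse_fun = Sigmoid()(rng, 3)
    outputs, forward_log_det = direct_fun(params, inputs)
    reconstructed, inverse_log_det = inverse_fun(params, outputs)

    assert np.allclose(inputs, reconstructed, atol=1e-4)
    assert np.allclose(forward_log_det + inverse_log_det, 0.0, atol=1e-4)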


def Serial(*init_funs):
    """Chains multiple bijections so they are applied (and inverted) in sequence.

    Args:
        *init_funs: Multiple bijections in sequence

    Returns:
        An ``init_fun`` mapping ``(rng, input_dim)`` to a ``(params, direct_fun, inverse_fun)`` triplet.

    Examples:
        >>> num_examples, input_dim, tol = 20, 3, 1e-4
        >>> layer_rng, input_rng = random.split(random.PRNGKey(0))
        >>> inputs = random.uniform(input_rng, (num_examples, input_dim))
        >>> init_fun = Serial(Shuffle(), Shuffle())
        >>> params, direct_fun, inverse_fun = init_fun(layer_rng, input_dim)
        >>> mapped_inputs = direct_fun(params, inputs)[0]
        >>> reconstructed_inputs = inverse_fun(params, mapped_inputs)[0]
        >>> np.allclose(inputs, reconstructed_inputs).item()
        True
    """

    def init_fun(rng, input_dim, **kwargs):
        init_inputs = kwargs.pop("init_inputs", None)

        all_params, direct_funs, inverse_funs = [], [], []
        for init_fun in init_funs:
            rng, layer_rng = random.split(rng)
            param, direct_fun, inverse_fun = init_fun(layer_rng, input_dim, init_inputs=init_inputs)

            all_params.append(param)
            direct_funs.append(direct_fun)
            inverse_funs.append(inverse_fun)

            if init_inputs is not None:
                init_inputs = direct_fun(param, init_inputs)[0]

        # Forward call-time keyword arguments (e.g. ``evaluation``) to every layer.
        def feed_forward(params, apply_funs, inputs, **kwargs):
            log_det_jacobians = np.zeros(inputs.shape[:1])
            for apply_fun, param in zip(apply_funs, params):
                inputs, log_det_jacobian = apply_fun(param, inputs, **kwargs)
                log_det_jacobians += log_det_jacobian
            return inputs, log_det_jacobians

        def direct_fun(params, inputs, **kwargs):
            return feed_forward(params, direct_funs, inputs, **kwargs)

        def inverse_fun(params, inputs, **kwargs):
            return feed_forward(reversed(params), reversed(inverse_funs), inputs, **kwargs)

        return all_params, direct_fun, inverse_fun

    return init_fun
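

# A minimal sketch (not part of the library): composing layers with Serial, using ``init_inputs``
# to propagate a batch through the chain for data-dependent initialization of the ActNorm layers.
def _example_serial():
    rng = random.PRNGKey(0)
    inputs = random.normal(rng, (64, 4)) * 3.0 - 1.0

    init_fun = Serial(ActNorm(), Reverse(), ActNorm())
    params, direct_fun, inverse_fun = init_fun(rng, 4, init_inputs=inputs)

    outputs, log_det = direct_fun(params, inputs)
    reconstructed, _ = inverse_fun(params, outputs)
    assert np.allclose(inputs, reconstructed, atol=1e-4)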