Welcome to jax-flows’s documentation!

Transformations

flows.ActNorm()

An implementation of an actnorm layer from Glow: Generative Flow with Invertible 1x1 Convolutions (https://arxiv.org/abs/1807.03039).

Returns: An init_fun mapping (rng, input_dim) to a (params, direct_fun, inverse_fun) triplet.
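The actnorm idea can be sketched in plain JAX as a per-dimension affine map whose scale and bias are initialized from a first batch of data, following the data-dependent initialization described in the Glow paper. This is a minimal illustration, not the jax-flows implementation. The helpers actnorm_init, actnorm_forward, and actnorm_inverse are hypothetical, and the eps constant is an assumption.

```python
import jax.numpy as jnp
from jax import random

def actnorm_init(batch, eps=1e-6):
    # Hypothetical helper: data-dependent init as described in Glow.
    # Choose log-scale and bias so the initial batch gets zero mean, unit variance.
    mean = batch.mean(axis=0)
    std = batch.std(axis=0)
    log_scale = -jnp.log(std + eps)
    bias = -mean * jnp.exp(log_scale)
    return log_scale, bias

def actnorm_forward(params, x):
    log_scale, bias = params
    y = x * jnp.exp(log_scale) + bias
    # Elementwise affine map: log|det J| is the sum of the log-scales,
    # identical for every example in the batch.
    log_det = jnp.full(x.shape[:1], log_scale.sum())
    return y, log_det

def actnorm_inverse(params, y):
    log_scale, bias = params
    x = (y - bias) * jnp.exp(-log_scale)
    log_det = jnp.full(y.shape[:1], -log_scale.sum())
    return x, log_det

x = random.normal(random.PRNGKey(0), (128, 3)) * 5.0 + 2.0
params = actnorm_init(x)
y, log_det = actnorm_forward(params, x)
x_rec, _ = actnorm_inverse(params, y)
```

After initialization the parameters would normally be trained like any others; the data-dependent step only sets sensible starting values.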
flows.AffineCoupling(transform)

An implementation of a coupling layer from Density Estimation Using RealNVP (https://arxiv.org/abs/1605.08803).

Parameters: transform – A (params, apply_fun) pair characterizing a trainable network that outputs both the scale and the translation
Returns: An init_fun mapping (rng, input_dim) to a (params, direct_fun, inverse_fun) triplet.
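A minimal affine-coupling sketch in plain JAX, assuming the single-network variant in which one network reads the lower half of the input and outputs both the log-scale and the translation for the upper half. The names toy_net, coupling_forward, and coupling_inverse are hypothetical illustrations, not part of jax-flows.

```python
import jax.numpy as jnp
from jax import random

def toy_net(params, lower):
    # Hypothetical stand-in for the trainable network: one dense layer with tanh,
    # producing 2 * (input_dim - cutoff) outputs (log-scales and shifts).
    W, b = params
    return jnp.tanh(lower @ W + b)

def coupling_forward(params, x, cutoff):
    lower, upper = x[:, :cutoff], x[:, cutoff:]
    log_scale, shift = jnp.split(toy_net(params, lower), 2, axis=1)
    upper = upper * jnp.exp(log_scale) + shift
    # The Jacobian is block-triangular with diagonal entries 1 (lower half)
    # and exp(log_scale) (upper half), so log|det J| sums the log-scales.
    return jnp.concatenate([lower, upper], axis=1), log_scale.sum(axis=1)

def coupling_inverse(params, y, cutoff):
    lower, upper = y[:, :cutoff], y[:, cutoff:]
    # The lower half passed through unchanged, so the same log-scales
    # and shifts can be recomputed from it.
    log_scale, shift = jnp.split(toy_net(params, lower), 2, axis=1)
    upper = (upper - shift) * jnp.exp(-log_scale)
    return jnp.concatenate([lower, upper], axis=1), -log_scale.sum(axis=1)

input_dim = 4
cutoff = input_dim // 2
w_rng, b_rng, x_rng = random.split(random.PRNGKey(0), 3)
out_dim = 2 * (input_dim - cutoff)
params = (random.normal(w_rng, (cutoff, out_dim)), random.normal(b_rng, (out_dim,)))
x = random.normal(x_rng, (8, input_dim))
y, log_det = coupling_forward(params, x, cutoff)
x_rec, inv_log_det = coupling_inverse(params, y, cutoff)
```

The network itself never needs to be inverted, which is why it can be an arbitrary function of the lower half.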
flows.AffineCouplingSplit(scale, translate)

An implementation of a coupling layer from Density Estimation Using RealNVP (https://arxiv.org/abs/1605.08803), in which the scale and the translation are computed by separate networks.

Parameters:
  • scale – A (params, apply_fun) pair characterizing a trainable scaling function
  • translate – A (params, apply_fun) pair characterizing a trainable translation function
Returns:

An init_fun mapping (rng, input_dim) to a (params, direct_fun, inverse_fun) triplet.

flows.BatchNorm(momentum=0.9)

An implementation of a batch normalization layer from Density Estimation Using RealNVP (https://arxiv.org/abs/1605.08803).

Returns: An init_fun mapping (rng, input_dim) to a (params, direct_fun, inverse_fun) triplet.
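A minimal sketch of batch normalization used as a bijection, assuming the RealNVP formulation in which the batch statistics are treated as constants when computing the log-determinant. The helper names, the eps value, and the moving-average convention for momentum are assumptions for illustration, not taken from jax-flows.

```python
import jax.numpy as jnp
from jax import random

def batchnorm_forward(params, x, eps=1e-5):
    # Hypothetical helper: normalize with batch statistics, then apply a learned affine map.
    log_gamma, beta = params
    mean, var = x.mean(axis=0), x.var(axis=0)
    x_hat = (x - mean) / jnp.sqrt(var + eps)
    y = x_hat * jnp.exp(log_gamma) + beta
    # With the statistics held constant, each dimension is scaled by
    # exp(log_gamma) / sqrt(var + eps), so log|det J| sums the log of that.
    log_det = jnp.sum(log_gamma - 0.5 * jnp.log(var + eps))
    return y, jnp.full(x.shape[:1], log_det), (mean, var)

def batchnorm_inverse(params, y, mean, var, eps=1e-5):
    # Inversion needs fixed statistics; here we reuse the ones from the forward pass.
    log_gamma, beta = params
    x_hat = (y - beta) * jnp.exp(-log_gamma)
    return x_hat * jnp.sqrt(var + eps) + mean

def update_running(running, batch_stat, momentum=0.9):
    # Assumed convention: keep a fraction `momentum` of the old running value.
    return momentum * running + (1.0 - momentum) * batch_stat

dim = 3
params = (jnp.zeros(dim), jnp.zeros(dim))  # (log_gamma, beta) at initialization
x = random.normal(random.PRNGKey(1), (256, dim)) * 3.0 + 1.0
y, log_det, (mean, var) = batchnorm_forward(params, x)
x_rec = batchnorm_inverse(params, y, mean, var)
running_mean = update_running(jnp.zeros(dim), mean)
```

At evaluation time the running averages would replace the batch statistics, which makes the layer a fixed, exactly invertible affine map.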
flows.FixedInvertibleLinear()

An implementation of an invertible linear layer from Glow: Generative Flow with Invertible 1x1 Convolutions (https://arxiv.org/abs/1807.03039).

Returns: An init_fun mapping (rng, input_dim) to a (params, direct_fun, inverse_fun) triplet.
flows.Invert(bijection)

Inverts a transformation so that its direct_fun becomes its inverse_fun and vice versa.

Returns: An init_fun mapping (rng, input_dim) to a (params, direct_fun, inverse_fun) triplet.
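The swap can be illustrated with a small, self-contained sketch of the (params, direct_fun, inverse_fun) protocol. Scale and invert are hypothetical names, and the sketch shows the idea rather than how flows.Invert is actually written.

```python
import jax.numpy as jnp

def Scale(factor):
    # Hypothetical toy bijection following the documented protocol:
    # init_fun(rng, input_dim) -> (params, direct_fun, inverse_fun),
    # where each fun returns (outputs, log_det_jacobian). Assumes factor > 0.
    def init_fun(rng, input_dim):
        log_det_value = input_dim * jnp.log(factor)
        def direct_fun(params, inputs):
            return inputs * factor, jnp.full(inputs.shape[:1], log_det_value)
        def inverse_fun(params, inputs):
            return inputs / factor, jnp.full(inputs.shape[:1], -log_det_value)
        return (), direct_fun, inverse_fun
    return init_fun

def invert(bijection):
    # Swap the two directions; each direction keeps its own log-determinant.
    def init_fun(rng, input_dim):
        params, direct_fun, inverse_fun = bijection(rng, input_dim)
        return params, inverse_fun, direct_fun
    return init_fun

params, direct_fun, inverse_fun = invert(Scale(2.0))(None, 3)  # rng unused by the toy layer
x = jnp.ones((4, 3))
y, log_det = direct_fun(params, x)            # now divides by 2
x_back, back_log_det = inverse_fun(params, y)  # now multiplies by 2
```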
flows.InvertibleLinear()

An implementation of an invertible linear layer from Glow: Generative Flow with Invertible 1x1 Convolutions (https://arxiv.org/abs/1807.03039).

Returns: An init_fun mapping (rng, input_dim) to a (params, direct_fun, inverse_fun) triplet.
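At its core, an invertible linear layer multiplies each input by a square matrix W; because every example shares the same Jacobian, the log-determinant is log|det W|. The sketch below assumes a plain dense W with a random orthogonal initialization in the spirit of the Glow paper; it does not attempt to reproduce how jax-flows parameterizes W, and the helper names are hypothetical.

```python
import jax.numpy as jnp
from jax import random

def init_orthogonal(rng, dim):
    # Random orthogonal matrix (a rotation or reflection) via QR of a Gaussian matrix.
    q, _ = jnp.linalg.qr(random.normal(rng, (dim, dim)))
    return q

def linear_forward(W, x):
    # Each row of x is multiplied by the same W, so every example's log|det J| is log|det W|.
    _, log_abs_det = jnp.linalg.slogdet(W)
    return x @ W, jnp.full(x.shape[:1], log_abs_det)

def linear_inverse(W, y):
    _, log_abs_det = jnp.linalg.slogdet(W)
    return y @ jnp.linalg.inv(W), jnp.full(y.shape[:1], -log_abs_det)

w_rng, x_rng = random.split(random.PRNGKey(0))
# Perturb the orthogonal matrix to mimic a W that training has moved away from it.
W = init_orthogonal(w_rng, 3) + 0.1 * jnp.eye(3)
x = random.normal(x_rng, (5, 3))
y, log_det = linear_forward(W, x)
x_rec, inv_log_det = linear_inverse(W, y)
```

An orthogonal starting point has |det W| = 1, so the layer begins as a volume-preserving map.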
flows.Logit(clip_before_logit=True)

Computes the logit function on a set of inputs, with the sigmoid function as its inverse.

Important note: When clip_before_logit is True, values passed through this layer are clipped to a range on which the logit can be computed in 32-bit floating point, as was done in “Cubic-Spline Flows” by Durkan et al. Strictly speaking this breaks invertibility, but it avoids NaNs that would otherwise be inevitable.

Parameters: clip_before_logit – Whether to clip values to the range [1e-5, 1 - 1e-5] before they are passed through the logit.
Returns: An init_fun mapping (rng, input_dim) to a (params, direct_fun, inverse_fun) triplet.
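The effect of clipping can be seen in a short JAX sketch. It assumes the clipping bound of 1e-5 stated above and computes the log-determinant on the clipped values using d/dx logit(x) = 1 / (x(1 - x)); logit_forward is a hypothetical helper, not part of jax-flows.

```python
import jax
import jax.numpy as jnp

EPS = 1e-5  # clipping bound from the documentation above

def logit_forward(x, clip=True):
    # Hypothetical helper illustrating clip_before_logit.
    if clip:
        x = jnp.clip(x, EPS, 1.0 - EPS)
    y = jnp.log(x) - jnp.log1p(-x)
    # log|d logit / dx| = -log(x) - log(1 - x), summed over dimensions.
    log_det = -jnp.sum(jnp.log(x) + jnp.log1p(-x), axis=-1)
    return y, log_det

x = jnp.array([[0.0, 0.25, 1.0]])
y_raw, _ = logit_forward(x, clip=False)   # endpoints map to -inf and +inf
y, log_det = logit_forward(x, clip=True)  # every entry is finite
x_rec = jax.nn.sigmoid(y)  # sigmoid inverts the logit; endpoints return as 1e-5 and 1 - 1e-5
```

The last line shows why clipping technically breaks invertibility: inputs of exactly 0 and 1 are not recovered.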
flows.MADE(transform)

An implementation of MADE: Masked Autoencoder for Distribution Estimation (https://arxiv.org/abs/1502.03509).

Parameters: transform – Maps inputs of dimension num_inputs to outputs of dimension 2 * num_inputs
Returns: An init_fun mapping (rng, input_dim) to a (params, direct_fun, inverse_fun) triplet.
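The masking idea behind MADE can be sketched for a single hidden layer that produces 2 * num_inputs outputs, matching the transform parameter above. The degree-assignment scheme (sequential input degrees, hidden degrees cycled through 1..D-1) is one common choice and an assumption here, as are the helper names made_masks and made_apply. The Jacobian check at the end confirms that output i depends only on inputs before i.

```python
import jax
import jax.numpy as jnp
from jax import random

def made_masks(input_dim, hidden_dim):
    # Degrees: input d gets degree d (1..D); hidden units cycle through 1..D-1.
    # Assumes input_dim >= 2.
    in_deg = jnp.arange(1, input_dim + 1)
    hid_deg = jnp.arange(hidden_dim) % (input_dim - 1) + 1
    # Hidden unit k may see input d only if deg(k) >= d.
    mask_in = (hid_deg[None, :] >= in_deg[:, None]).astype(jnp.float32)  # (D, H)
    # Output d may see hidden unit k only if d > deg(k).
    mask_out = (in_deg[None, :] > hid_deg[:, None]).astype(jnp.float32)  # (H, D)
    return mask_in, mask_out

def made_apply(params, x, masks):
    W1, b1, W2, b2 = params
    mask_in, mask_out = masks
    h = jnp.tanh(x @ (W1 * mask_in) + b1)
    # Repeat the output mask for both halves: 2 * D outputs in total.
    out = h @ (W2 * jnp.tile(mask_out, (1, 2))) + b2
    shift, log_scale = jnp.split(out, 2, axis=1)
    return shift, log_scale

D, H = 4, 16
k1, k2, k3 = random.split(random.PRNGKey(0), 3)
params = (random.normal(k1, (D, H)), jnp.zeros(H),
          random.normal(k2, (H, 2 * D)), jnp.zeros(2 * D))
masks = made_masks(D, H)
x = random.normal(k3, (5, D))
shift, log_scale = made_apply(params, x, masks)

# J[i, j] = d shift_i / d x_j should vanish whenever j >= i.
J = jax.jacfwd(lambda xi: made_apply(params, xi[None], masks)[0][0])(x[0])
```

Because the Jacobian of the outputs is strictly lower triangular, an affine transform driven by these outputs has a triangular Jacobian, so its log-determinant is cheap to compute.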
flows.Reverse()

An implementation of a reversing layer from Density Estimation Using RealNVP (https://arxiv.org/abs/1605.08803).

Returns: An init_fun mapping (rng, input_dim) to a (params, direct_fun, inverse_fun) triplet.

Examples

>>> from jax import random
>>> import jax.numpy as np
>>> from flows import Reverse
>>> num_examples, input_dim, tol = 20, 3, 1e-4
>>> layer_rng, input_rng = random.split(random.PRNGKey(0))
>>> inputs = random.uniform(input_rng, (num_examples, input_dim))
>>> init_fun = Reverse()
>>> params, direct_fun, inverse_fun = init_fun(layer_rng, input_dim)
>>> mapped_inputs = direct_fun(params, inputs)[0]
>>> reconstructed_inputs = inverse_fun(params, mapped_inputs)[0]
>>> np.allclose(inputs, reconstructed_inputs, atol=tol).item()
True
flows.Serial(*init_funs)

Parameters: *init_funs – Multiple bijections in sequence
Returns: An init_fun mapping (rng, input_dim) to a (params, direct_fun, inverse_fun) triplet.

Examples

>>> from jax import random
>>> import jax.numpy as np
>>> from flows import Serial, Shuffle
>>> num_examples, input_dim, tol = 20, 3, 1e-4
>>> layer_rng, input_rng = random.split(random.PRNGKey(0))
>>> inputs = random.uniform(input_rng, (num_examples, input_dim))
>>> init_fun = Serial(Shuffle(), Shuffle())
>>> params, direct_fun, inverse_fun = init_fun(layer_rng, input_dim)
>>> mapped_inputs = direct_fun(params, inputs)[0]
>>> reconstructed_inputs = inverse_fun(params, mapped_inputs)[0]
>>> np.allclose(inputs, reconstructed_inputs, atol=tol).item()
True
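Composition works by chaining the layers and adding their log-determinants. A minimal sketch, assuming the first init_fun is applied first in the direct direction and last in the inverse direction; Affine and serial are hypothetical names, not the library's code.

```python
import jax.numpy as jnp
from jax import random

def Affine(scale, shift):
    # Hypothetical elementwise bijection following the documented protocol:
    # init_fun(rng, input_dim) -> (params, direct_fun, inverse_fun).
    def init_fun(rng, input_dim):
        log_det_value = input_dim * jnp.log(jnp.abs(scale))
        def direct_fun(params, inputs):
            return inputs * scale + shift, jnp.full(inputs.shape[:1], log_det_value)
        def inverse_fun(params, inputs):
            return (inputs - shift) / scale, jnp.full(inputs.shape[:1], -log_det_value)
        return (), direct_fun, inverse_fun
    return init_fun

def serial(*init_funs):
    def init_fun(rng, input_dim):
        # Give every layer its own key.
        rngs = random.split(rng, len(init_funs))
        layers = [f(r, input_dim) for f, r in zip(init_funs, rngs)]
        all_params = tuple(p for p, _, _ in layers)

        def direct_fun(params, inputs):
            log_det = jnp.zeros(inputs.shape[:1])
            for p, (_, fwd, _) in zip(params, layers):
                inputs, ld = fwd(p, inputs)
                log_det = log_det + ld  # log-determinants add under composition
            return inputs, log_det

        def inverse_fun(params, inputs):
            log_det = jnp.zeros(inputs.shape[:1])
            # Undo the layers in reverse order.
            for p, (_, _, inv) in reversed(list(zip(params, layers))):
                inputs, ld = inv(p, inputs)
                log_det = log_det + ld
            return inputs, log_det

        return all_params, direct_fun, inverse_fun
    return init_fun

init_fun = serial(Affine(2.0, 1.0), Affine(3.0, -1.0))
params, direct_fun, inverse_fun = init_fun(random.PRNGKey(0), 2)
x = jnp.array([[0.0, 1.0]])
y, log_det = direct_fun(params, x)         # (x * 2 + 1) * 3 - 1
x_rec, inv_log_det = inverse_fun(params, y)
```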
flows.Shuffle()

An implementation of a shuffling layer from Density Estimation Using RealNVP (https://arxiv.org/abs/1605.08803).

Returns: An init_fun mapping (rng, input_dim) to a (params, direct_fun, inverse_fun) triplet.

Examples

>>> from jax import random
>>> import jax.numpy as np
>>> from flows import Shuffle
>>> num_examples, input_dim, tol = 20, 3, 1e-4
>>> layer_rng, input_rng = random.split(random.PRNGKey(0))
>>> inputs = random.uniform(input_rng, (num_examples, input_dim))
>>> init_fun = Shuffle()
>>> params, direct_fun, inverse_fun = init_fun(layer_rng, input_dim)
>>> mapped_inputs = direct_fun(params, inputs)[0]
>>> reconstructed_inputs = inverse_fun(params, mapped_inputs)[0]
>>> np.allclose(inputs, reconstructed_inputs, atol=tol).item()
True
flows.Sigmoid(clip_before_logit=True)

Computes the sigmoid function on a set of inputs, with the logit function as its inverse.

Important note: When clip_before_logit is True, values passed through this layer are clipped to a range on which the logit can be computed in 32-bit floating point, as was done in “Cubic-Spline Flows” by Durkan et al. Strictly speaking this breaks invertibility, but it avoids NaNs that would otherwise be inevitable.

Parameters: clip_before_logit – Whether to clip values to the range [1e-5, 1 - 1e-5] before they are passed through the logit (this layer's inverse).
Returns: An init_fun mapping (rng, input_dim) to a (params, direct_fun, inverse_fun) triplet.

Distributions

flows.Flow(transformation, prior=Normal())
Parameters:
  • transformation – a function mapping (rng, input_dim) to a (params, direct_fun, inverse_fun) triplet
  • prior – a function mapping (rng, input_dim) to a (params, log_pdf, sample) triplet
Returns:

A function mapping (rng, input_dim) to a (params, log_pdf, sample) triplet.

Examples

>>> from jax import random
>>> import flows
>>> input_dim, rng = 3, random.PRNGKey(0)
>>> transformation = flows.Serial(
...     flows.Reverse(),
...     flows.Reverse()
... )
>>> init_fun = flows.Flow(transformation, flows.Normal())
>>> params, log_pdf, sample = init_fun(rng, input_dim)
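A flow's log_pdf follows from the change-of-variables formula, log p(x) = log p_z(f(x)) + log|det ∂f/∂x|. The sketch below assumes direct_fun maps data into the prior's space (so sampling would draw from the prior and apply inverse_fun). flow_log_pdf and the toy bijection are hypothetical, and jax.scipy.stats.norm stands in for the Normal prior.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm

def std_normal_log_pdf(z):
    # Standard normal prior, summed over dimensions (stand-in for flows.Normal()).
    return norm.logpdf(z).sum(axis=-1)

def flow_log_pdf(direct_fun, prior_log_pdf, params, x):
    # Change of variables: log p(x) = log p_z(f(x)) + log|det df/dx|.
    z, log_det = direct_fun(params, x)
    return prior_log_pdf(z) + log_det

# Toy bijection f(x) = (x - mu) / sigma, whose flow density should be N(mu, sigma^2).
mu, sigma = 1.5, 0.5

def direct_fun(params, x):
    log_det = jnp.full(x.shape[:1], -x.shape[1] * jnp.log(sigma))
    return (x - mu) / sigma, log_det

x = jnp.array([[0.0, 1.0, 2.0]])
log_p = flow_log_pdf(direct_fun, std_normal_log_pdf, (), x)
expected = norm.logpdf(x, loc=mu, scale=sigma).sum(axis=-1)
```

Matching the closed-form Gaussian density confirms that the log-determinant term is what corrects for the change in volume.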
flows.GMM(means, covariances, weights)

flows.Normal()

Returns: A function mapping (rng, input_dim) to a (params, log_pdf, sample) triplet.
