Source code for flows.distributions

import jax.numpy as np
from jax import random
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm, multivariate_normal


def Normal():
    """Standard normal distribution with independent dimensions.

    Returns:
        A function mapping ``(rng, input_dim)`` to a ``(params, log_pdf, sample)`` triplet.
        ``params`` is an empty tuple because nothing here is learnable.
    """

    def init_fun(rng, input_dim):
        def log_pdf(params, inputs):
            # Element-wise log-density, summed over axis 1 to give one
            # value per row of ``inputs``.
            per_dim = norm.logpdf(inputs)
            return per_dim.sum(axis=1)

        def sample(rng, params, num_samples=1):
            shape = (num_samples, input_dim)
            return random.normal(rng, shape)

        return (), log_pdf, sample

    return init_fun
def GMM(means, covariances, weights):
    """Gaussian mixture model distribution.

    Args:
        means: sequence of component mean vectors.
        covariances: sequence of component covariance matrices, paired
            element-wise with ``means``.
        weights: mixture probabilities, one per component. Their logarithm
            is added to each component's log-density; they are presumably
            expected to sum to 1 (not validated here).

    Returns:
        A function mapping ``(rng, input_dim)`` to a ``(params, log_pdf, sample)`` triplet.
        ``params`` is an empty tuple; the mixture is fully specified by the
        arguments above.
    """

    def init_fun(rng, input_dim):
        # Shared by ``log_pdf`` and ``sample`` so that both use the same
        # mixing proportions.
        log_weights = np.log(np.asarray(weights))

        def log_pdf(params, inputs):
            cluster_lls = [
                log_weight + multivariate_normal.logpdf(inputs, mean, cov)
                for log_weight, mean, cov in zip(log_weights, means, covariances)
            ]
            # Rows are components; reduce over them to get the mixture density.
            return logsumexp(np.vstack(cluster_lls), axis=0)

        def sample(rng, params, num_samples=1):
            # Draw ``num_samples`` candidates from every component, then pick
            # one component per sample.
            cluster_samples = []
            for mean, cov in zip(means, covariances):
                rng, cluster_rng = random.split(rng)
                cluster_samples.append(
                    random.multivariate_normal(cluster_rng, mean, cov, (num_samples,))
                )
            # Shape: (num_samples, dim, num_components).
            samples = np.dstack(cluster_samples)

            # ``random.categorical`` expects logits (unnormalized
            # log-probabilities), so pass the log-weights rather than the raw
            # probabilities; otherwise components are drawn from
            # softmax(weights) and disagree with ``log_pdf``.
            idx = random.categorical(rng, log_weights, shape=(num_samples, 1, 1))
            chosen = np.take_along_axis(samples, idx, -1)

            # Drop only the component axis so the result is always
            # (num_samples, dim), including when num_samples == 1 or dim == 1.
            return np.squeeze(chosen, axis=-1)

        return (), log_pdf, sample

    return init_fun
def Flow(transformation, prior=Normal()):
    """Build a distribution by pairing a transformation with a prior.

    Args:
        transformation: a function mapping ``(rng, input_dim)`` to a
            ``(params, direct_fun, inverse_fun)`` triplet
        prior: a function mapping ``(rng, input_dim)`` to a
            ``(params, log_pdf, sample)`` triplet

    Returns:
        A function mapping ``(rng, input_dim)`` to a ``(params, log_pdf, sample)`` triplet.

    Examples:
        >>> import flows
        >>> input_dim, rng = 3, random.PRNGKey(0)
        >>> transformation = flows.Serial(
        ...     flows.Reverse(),
        ...     flows.Reverse()
        ... )
        >>> init_fun = flows.Flow(transformation, Normal())
        >>> params, log_pdf, sample = init_fun(rng, input_dim)
    """

    def init_fun(rng, input_dim):
        # Independent keys for initializing the transformation and the prior.
        rng_transform, rng_prior = random.split(rng)
        flow_params, forward, backward = transformation(rng_transform, input_dim)
        base_params, base_log_pdf, base_sample = prior(rng_prior, input_dim)

        def log_pdf(params, inputs):
            # Prior log-density of the transformed inputs plus the
            # log-determinant reported by the transformation.
            latents, log_det = forward(params, inputs)
            return base_log_pdf(base_params, latents) + log_det

        def sample(rng, params, num_samples=1):
            latents = base_sample(rng, base_params, num_samples)
            outputs = backward(params, latents)
            # The inverse returns a tuple; only its first element is the samples.
            return outputs[0]

        return flow_params, log_pdf, sample

    return init_fun