disentanglement

view markdown


VAEs

Some good disentangled VAE implementations are here and more general VAE implementations are here. Tensorflow implementations available here

The goal is to obtain a nice latent representation $\mathbf z$ for our inputs $\mathbf x$. To do this, we learn parameters $\phi$ for the encoder $p_\phi( \mathbf z\vert \mathbf x)$ and $\theta$ for the decoder $q_{\mathbf \theta} ( \mathbf x\vert \mathbf z)$. We do this with the standard vae setup, whereby a code $z$ is sampled, using the output of the encoder (intro to VAEs here).

disentangled vae losses

reconstruction loss compactness prior loss total correlation loss
encourages accurate reconstruction of the input encourages points to be compactly placed in space encourages latent variables to be independent
  • summarizing the losses
    • reconstruction loss - measures the quality of the reconstruction, the form of the loss changes based on the assumed distribution of the likelihood of each pixel
      • binary cross entropy loss - corresopnds to bernoulli distr., most common - doesn’t penalize (0.1, 0.2) and (0.4, 0.5) the same way, which might be problematic
      • mse loss - gaussian distr. - tends to focus on a fex pixels that are very wrong
      • l1 loss - laplace distr.
    • compactness prior loss
      • doesn’t use the extra injected latent noise
      • tries to push all the points to the same place
      • emphasises smoothness of z, using as few dimensions of z as possible, and the main axes of z to capture most of the data variability
      • usually assume prior is standard normal, resulting in pushing the code means to 0 and code variance to 1
      • we can again split this term $\sum_i \underbrace{\text{KL} \left(p_\phi( \mathbf z_i\vert \mathbf x):\vert\vert:prior(\mathbf z_i) \right)}{\text{compactness prior loss}} = \underbrace{\sum_i I(x; z)}{\text{mutual info}} + \underbrace{\text{KL} \left(p_\phi( \mathbf z_i):\vert\vert:prior(\mathbf z_i) \right)}_{\text{factorial prior loss}}$
    • total correlation loss - encourages factors to be independent
      • measures dependence between marginals of the latent vars
      • intractable (requires pass through the whole dset)
      • instead sample $dec_\phi(\mathbf z\vert \mathbf x)$ and create $\prod_j dec_\phi( \mathbf z_i\vert \mathbf x) $ by permuting across the batch dimension
        • now, calculate the kl with the density-ratio trick - train a classifier to approximate the ratio from these terms

disentangled vae in code

# Reconstruction + KL divergence losses summed over all elements and batch
def loss_function(x_reconstructed, x, mu, logvar, beta=1):
  '''
  Params
  ------
  x_reconstructed: torch.Tensor
		Reconstructed input, with values between 0-1
	x: torch.Tensor
		input, values unrestricted
  '''
  
  # reconstruction loss (assuming bernoulli distr.)
  # BCE = sum_i [x_rec_i * log(x_i) + (1 - x_rec_i) * log(1-x_i)]
	rec_loss = F.binary_cross_entropy(x_reconstructed, x, reduction='sum')

  # compactness prior loss
	# 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
	KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
  
  # total correlation loss (calculate tc-vae way)
	z_sample = mu + torch.randn_like(exp(0.5 * logvar))
	log_pz, log_qz, log_prod_qzi, log_q_zCx = func(z_sample, mu, logvar)
  # I[z;x] = KL[q(z,x)\vert\vertq(x)q(z)] = E_x[KL[q(z\vertx)\vert\vertq(z)]]
  mi_loss = (log_q_zCx - log_qz).mean()
  # TC[z] = KL[q(z)\vert\vert\prod_i z_i]
  tc_loss = (log_qz - log_prod_qzi).mean()
  # dw_kl_loss is KL[q(z)\vert\vertp(z)] instead of usual KL[q(z\vertx)\vert\vertp(z))]
  dw_kl_loss = (log_prod_qzi - log_pz).mean()
  
	return rec_loss + beta * KLD

vaes for interpretation

  • icam (bass et al. 2020) - learn disentangled repr using vae with adv loss to make repr class-relevant
    • va-gan (baumgartner et al. 2018) - interpret features in GAN space

various vaes

  • vae (kingma & welling, 2013)
  • beta-vae (higgins et al. 2017) - add hyperparameter $\beta$ to weight the compactness prior term
  • beta-vae H (burgess et al. 2018) - add parameter $C$ to control the contribution of the compactness prior term
    • $\overbrace{\mathbb E_{p_\phi(\mathbf z\vert \mathbf x)}}^{\text{samples}} [ \underbrace{-\log q_{\mathbf \theta} ( \mathbf x\vert \mathbf z)}{\text{reconstruction loss}} ] + \textcolor{teal}{\beta}\; \vert\sum_i \underbrace{\text{KL} \left(p\phi( \mathbf z_i\vert \mathbf x):\vert\vert:prior(\mathbf z_i) \right)}_{\text{compactness prior loss}} -C\vert$
    • C is gradually increased from zero (allowing for a larger compactness prior loss) until good quality reconstruction is achieved
  • factor-vae (kim & minh 2018) - adds total correlation loss term
    • computes total correlation loss term using discriminator (can we discriminate between the samples when we shuffle over the batch dimension or not?)
    • beta-TC-VAE = beta-total-correlation VAE (chen et al. 2018) - same objective but computed without need for discriminator
      • use minibatch-weighted sampling to compute each of the 3 terms that make up the original VAE compactness prior loss
      • main idea is to better approximate $q(z)$ by weighting samples appropriately - biased, but easier to compute
    • Interpretable VAEs for nonlinear group factor analysis
  • Wasserstein Auto-Encoders (tolstikhin et al.) - removes the mutual info part of the loss
    • wasserstein distance = earth-movers distance, how far apart are 2 distrs
    • minimizes wasserstein distance + penalty which is similar to auto-encoding penalty, without the mutual info term
    • another intuition: rather than map each point to a ball (since VAE adds noise to each latent repr), we only constraint the overall distr of Z, potentially making reconstructions less blurry (but potentially making latent space less smooth)
    • wae
  • Adversarial Latent Autoencoder (pidhorskyi et al. 2020)
    • improve quality of generated VAE reconstructions by using a different setup which allows for using a GAN loss
    • alae
  • Variational Autoencoders Pursue PCA Directions (by Accident)
    • local orthogonality of the embedding transformation
    • prior $p(z)$ is standard normal, so encoder is assumed to be Gaussian with a certain mean, and diagonal covariance
    • disentanglement is sensitive to rotations of the latent embeddings but reconstruction err doesn’t care
    • for linear autoencoder w/ square-error as reconstruction loss, we recover PCA decomp
  • Disentangling Disentanglement in Variational Autoencoders (2019)
    • independence can be too simplistic, instead 2 things:
      • the latent encodings of data having an appropriate level of overlap
        • keeps encodings from just being a lookup table
        • when encoder is unimodal, $I(x; z)$ gives us a good handle on this
      • prior structure on the latents (e.g. independence, sparsity)
    • to trade these off, can penalize divergence between $q_\phi(z)$ and $p(z)$
    • nonisotropic priors - isotropic priors are only good up to rotation in the latent space
      • by chossing a nonisotropic prior (e.g. nonisotropic gaussian), can learn certain directions more easily
    • sparse prior - can help do clustering
  • VAE-SNE: a deep generative model for simultaneous dimensionality reduction and clustering (graving & couzin 2020) - reduce dims + cluster without specifying number of clusters
    • Screen Shot 2020-09-10 at 11.40.10 PM
    • stochastic neighbor regularizer that optimizes pairwise similarity kernels between original and latent distrs. to strengthen local neighborhood preservation
      • can use different neighbor kernels, e.g. t-SNE similarity (van der Maaten & Hinton, 2008) or Gaussian SNE kernel (Hinton & Roweis, 2003)
      • perplexity annealing technique (Kobak and Berens, 2019) - decay the size of local neighborhoods during training (helps to preserve structure at multiple scales)
    • Gaussian mixture prior for learning latent distr. (with very large number of clusters)
    • extensive evaluation - test several datasets / methods and evaluate how well the first 2 dimensions preserve the following:
      • global - correlation between pairwise distances in orig/latent spaces
      • local - both metric (distance- or radius-based) and topological (neighbor-based) neighborhoods which are 1% of total embedding size
      • fine-scale - neighborhoods which are <1% of total embedding size
      • temporal info (for time-series data only) - correlation between latent and original temporal derivatives
      • likelihood on out-of-sample data
    • further advancements
      • embed into polar coordinates (rather than Euclidean) helps a lot
      • convolutional VAE-SNE - extract features from images using some pre-trained net and then run VAE-SNE on these features
    • background: earlier works also used SNE objective for regularization - starts with van der Maaten 2009 (parametric t-SNE)
    • future work: density-preserving versions of t-SNE, modeling hierarchical structure in vae, conditional t-SNE kernel
  • A Survey of Inductive Biases for Factorial Representation-Learning (ridgeway 2016)
    • desiderata
      • compact
      • faithful - preserve info required for task
      • explicitly represent the attributes required for the task at hand
      • interpretable by humans
    • factorial representation - attributes are statistically independent and can provide a userful bias for learning
      • “compete” - factors are more orthogonal
      • “cooperate” - factors are more similar
      • bias on distribution of factors
        • PCA - minimize reconstruction err. subject to orthogonal weights
        • ICA - maximize non-Gaussianity (can also have sparse ICA)
      • bias on factors being invariant to certain types of changes
        • ISA (independent subspace analysis) - 2 layer model where first layer is linear, 2nd layer pools first layer (not maxpool, more like avgpool), sparsity at second layer
          • i.e. 1st layer cooperates, 2nd layer competes
        • VQ - vector quantizer - like ISA but first layer filters now compete and 2nd layer cooperates
        • SOM - encourages topographic map by enforcing nearby filters to be similar
      • bias in how factors are combined
        • linear combination - PCA/ICA
        • multilinear models - multiplicative interactions between factors (e.g on top of ISA)
        • functional parts - factor components are combined to construct the output
          • ex. NMF - parts can only add, not substract to total output
          • ex. have each pixel in the output be represented by only one factor in a VQ
        • hierarchical layers
          • ex. R-ICA - recursive ICA - run ICA on coefficients from previous layer (after some transformation)
    • supervision bias
      • constraints on some examples
        • e.g. some groups have same value for a factor
        • e.g. some examples have similar distances (basis for MDS = multidimensional scaling)
        • e.g. analogies between examples
        • can do all of these things with auto-encoders
  • more papers

GANs

model-based (disentangle during training)

post-hoc (disentangle after training)

misc

  • Learning Diverse and Discriminative Representations via the Principle of Maximal Coding Rate Reduction (yu, …, & ma, 2020)
    • goal: learn low-dimensional structure from high-dim (labeled or unlabeled) data
    • approach: instead of cross-entropy loss, use maximal coding rate reduction = MCR loss function to learn linear feature space where:
      • inter-class discriminative - features of samples from different classes/clusters are uncorrelated + different low-dim linear subspaces
      • intra-class compressible - features of samples from same class/cluster are correlated (i.e. belong to low-dim linear subspace)
      • maximally diverse - dimension (or variance) of features for each class/cluster should be as large as possible as long as uncorrelated from other classes/clusters
    • related to nonlinear generalized PCA
    • given random variable $z$ and precision $\epsilon$, rate distortion $R(z, \epsilon)$ is minimal number of bits to encode $z$ such that expected decoding err is less than $\epsilon$
      • can compute from finite samples
      • can compute for each class (diagonal matrices represent class/cluster membership in loss function)
      • MCR maximizes (rate distortion for all features) - (rate distortion for all data separated into classes)
        • like a generalization of information gain
    • evaluation
      • with label corruption performs better
  • Learned Equivariant Rendering without Transformation Supervision - separate foreground / background using video

(semi)-supervised disentanglement

these papers use some form of supervision for the latent space when disentangling

evaluating disentanglement

  • Challenging Common Assumptions in the Unsupervised Learning of Disentangled Representations (locatello et al. 2019)
    • state of disentanglement is very poor…depends a lot on architecture/hyperparameters
    • good way to evaluate: make explicit inductive biases, investigate benefits of this disentanglement
    • defining disentanglement - compact, interpretable, independent, helpful for downstream tasks, causal inference
      • a change in one factor of variation should lead to a change in a single factor in the learned repr.
    • unsupervised learning of disentangled reprs. is impossible without inductive biases
  • note - vae’s come with reconstruction loss + compactness prior loss which can be looked at on their own
  • data

  • beta-vae disentanglement metric score = higgins metric - see if we can capture known disentangled repr. using pairs of things where only one thing changes
    • start with a known generative model that has an observed set of independent and interpretable factors (e.g. scale, color, etc.) that can be used to simulate data.
    • create a dataset comprised of pairs of generated data for which a single factor is held constant (e.g. a pair of images which have objects with the same color).
    • use the inference network to map each pair of images to a pair of latent variables.
    • train a linear classifier to predict which interpretable factor was held constant based on the latent representations. The accuracy of this predictor is the disentanglement metric score.
  • Evaluating Disentangled Representations (sepliarskaia et al. 2019)
    • defn 1 (Higgins et al., 2017; Kim and Mnih, 2018; Eastwood and Williams, 2018) = factorVAE metric: A disentangled representation is a representation where a change in one latent dimension corresponds to a change in one generative factor while being relatively invariant to changes in other generative factors.
    • defn 2 (Locatello et al., 2018; Kumar et al., 2017): A disentangled representation is a representation where a change in a single generative factor leads to a change in a single factor in the learned representation.
    • metrics
      • DCI: Eastwood and Williams (2018) - informativeness based on predicting gt factors using latent factors
      • SAP: Kumar et al. (2017) - how much does top latent factor match gt more than 2nd latent factor
      • mutual info gap MIG: Chen et al. 2018 - mutual info to compute the same thing
      • modularity (ridgeway & mozer, 2018) - if each dimension of r(x) depends on at most a factor of variation using their mutual info

non-deep methods

  • unifying vae and nonlinear ica (khemakhem et al. 2020)
    • ICA
      • maximize non-gaussianity of $z$ - use kurtosis, negentropy
      • minimize mutual info between components of $z$ - use KL, max entropyd