Augmentations

adjust_brightness(image, delta)

Shifts the brightness of an RGB image by a given amount.

adjust_contrast(image, factor, *[, channel_axis])

Adjusts the contrast of an RGB image by a given multiplicative amount.

adjust_gamma(image, gamma, *[, gain, ...])

Adjusts the gamma of an RGB image.

adjust_hue(image, delta, *[, channel_axis])

Adjusts the hue of an RGB image by a given additive amount.

adjust_saturation(image, factor, *[, ...])

Adjusts the saturation of an RGB image by a given multiplicative amount.

affine_transform(image, matrix, *[, offset, ...])

Applies an affine transformation given by matrix.

center_crop(image, height, width, *[, ...])

Crops an image to the given size keeping the same center of the original.

elastic_deformation(key, image, alpha, sigma, *)

Applies an elastic deformation to the given image.

flip_left_right(image, *[, channel_axis])

Flips an image along the horizontal axis.

flip_up_down(image, *[, channel_axis])

Flips an image along the vertical axis.

gaussian_blur(image, sigma, kernel_size, *)

Applies gaussian blur (convolution with a Gaussian kernel).

pad_to_size(image, target_height, ...[, ...])

Pads an image to the given size keeping the original image centered.

random_brightness(key, image, max_delta)

adjust_brightness(...) with random delta in [-max_delta, max_delta).

random_contrast(key, image, lower, upper, *)

adjust_contrast(...) with random factor in [lower, upper).

random_crop(key, image, crop_sizes)

Crop images randomly to specified sizes.

random_flip_left_right(key, image, *[, ...])

Applies flip_left_right with a given probability.

random_flip_up_down(key, image, *[, probability])

Applies flip_up_down with a given probability.

random_gamma(key, image, min_gamma, max_gamma, *)

adjust_gamma(...) with random gamma in [min_gamma, max_gamma).

random_hue(key, image, max_delta, *[, ...])

adjust_hue(...) with random delta in [-max_delta, max_delta).

random_saturation(key, image, lower, upper, *)

adjust_saturation(...) with random factor in [lower, upper).

resize_with_crop_or_pad(image, ...[, ...])

Crops and/or pads an image to a target width and height.

rotate(image, angle, *[, order, mode, cval])

Rotates an image around its center using interpolation.

rot90(image[, k, channel_axis])

Rotates an image counter-clockwise by 90 degrees.

solarize(image, threshold)

Applies solarization to an image.

adjust_brightness

dm_pix.adjust_brightness(image, delta)[source]

Shifts the brightness of an RGB image by a given amount.

This is equivalent to tf.image.adjust_brightness.

Parameters:
  • image (chex.Array) – an RGB image, given as a float tensor in [0, 1].

  • delta (chex.Numeric) – the (additive) amount to shift each channel by.

Return type:

chex.Array

Returns:

The brightness-adjusted image. May be outside of the [0, 1] range.
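As a rough illustration of the additive shift described above, here is a NumPy sketch rather than the library's implementation. The name `shift_brightness` is hypothetical. It shows why the result can leave the [0, 1] range and how a caller might clip it afterwards.

```python
import numpy as np

def shift_brightness(image, delta):
    # Hypothetical helper: add the same offset to every channel value.
    return np.asarray(image, dtype=np.float32) + delta

image = np.full((2, 2, 3), 0.9, dtype=np.float32)
brighter = shift_brightness(image, 0.2)   # values are now about 1.1
clipped = np.clip(brighter, 0.0, 1.0)     # bring values back into [0, 1] if needed
```

Whether to clip depends on the rest of the pipeline; the function documented here leaves that choice to the caller.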

adjust_contrast

dm_pix.adjust_contrast(image, factor, *, channel_axis=-1)[source]

Adjusts the contrast of an RGB image by a given multiplicative amount.

This is equivalent to tf.image.adjust_contrast.

Parameters:
  • image (chex.Array) – an RGB image, given as a float tensor in [0, 1].

  • factor (chex.Numeric) – the (multiplicative) amount to adjust contrast by.

  • channel_axis (int) – the index of the channel axis.

Return type:

chex.Array

Returns:

The contrast-adjusted image. May be outside of the [0, 1] range.

adjust_gamma

dm_pix.adjust_gamma(image, gamma, *, gain=1.0, assume_in_bounds=False)[source]

Adjusts the gamma of an RGB image.

This is equivalent to tf.image.adjust_gamma, i.e. returns gain * image ** gamma.

Parameters:
  • image (chex.Array) – an RGB image, given as a [0-1] float tensor.

  • gamma (chex.Numeric) – the exponent to apply.

  • gain (chex.Numeric) – the (multiplicative) gain to apply.

  • assume_in_bounds (bool) – whether the input image should be assumed to have all values within [0, 1]. If False (default), the inputs will be clipped to that range to avoid NaNs.

Return type:

chex.Array

Returns:

The gamma-adjusted image.
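The formula gain * image ** gamma and the role of assume_in_bounds can be sketched in NumPy as follows. This is an illustration under the assumption that clipping happens before the power is taken; `gamma_sketch` is a hypothetical name, not part of the library.

```python
import numpy as np

def gamma_sketch(image, gamma, gain=1.0, assume_in_bounds=False):
    # Hypothetical helper mirroring the documented formula gain * image ** gamma.
    image = np.asarray(image, dtype=np.float32)
    if not assume_in_bounds:
        # A negative base with a fractional exponent would produce NaN.
        image = np.clip(image, 0.0, 1.0)
    return gain * image ** gamma

pixels = np.array([-0.1, 0.25, 1.0], dtype=np.float32)
out = gamma_sketch(pixels, gamma=0.5)   # the negative value is clipped to 0 first
```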

adjust_hue

dm_pix.adjust_hue(image, delta, *, channel_axis=-1)[source]

Adjusts the hue of an RGB image by a given additive amount.

This is equivalent to tf.image.adjust_hue when TF is running on GPU. When running on CPU, the results will be different if all RGB values for a pixel are outside of the [0, 1] range.

Parameters:
  • image (chex.Array) – an RGB image, given as a [0-1] float tensor.

  • delta (chex.Numeric) – the (additive) angle to shift hue by.

  • channel_axis (int) – the index of the channel axis.

Return type:

chex.Array

Returns:

The hue-adjusted image.
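To make the additive hue shift concrete, the sketch below uses Python's standard colorsys module on a single pixel. It assumes hue is expressed in [0, 1) and wraps around, which the text above does not state explicitly; `shift_hue` is a hypothetical helper.

```python
import colorsys

def shift_hue(rgb, delta):
    # Hypothetical helper for one pixel: rotate hue, keep saturation and value.
    h, s, v = colorsys.rgb_to_hsv(*rgb)
    return colorsys.hsv_to_rgb((h + delta) % 1.0, s, v)

red = (1.0, 0.0, 0.0)
cyan = shift_hue(red, 0.5)   # half a turn around the hue circle
```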

adjust_saturation

dm_pix.adjust_saturation(image, factor, *, channel_axis=-1)[source]

Adjusts the saturation of an RGB image by a given multiplicative amount.

This is equivalent to tf.image.adjust_saturation.

Parameters:
  • image (chex.Array) – an RGB image, given as a [0-1] float tensor.

  • factor (chex.Numeric) – the (multiplicative) amount to adjust saturation by.

  • channel_axis (int) – the index of the channel axis.

Return type:

chex.Array

Returns:

The saturation-adjusted image.

affine_transform

dm_pix.affine_transform(image, matrix, *, offset=0.0, order=1, mode='nearest', cval=0.0)[source]

Applies an affine transformation given by matrix.

Given an output image pixel index vector o, the pixel value is determined from the input image at position jnp.dot(matrix, o) + offset.

This does ‘pull’ (or ‘backward’) resampling, transforming the output space to the input to locate data. Affine transformations are often described in the ‘push’ (or ‘forward’) direction, transforming input to output. If you have a matrix for the ‘push’ transformation, use its inverse (jax.numpy.linalg.inv) in this function.

Parameters:
  • image (chex.Array) – a JAX array representing an image. Assumes that the image is either HWC or CHW.

  • matrix (chex.Array) –

    the inverse coordinate transformation matrix, mapping output coordinates to input coordinates. If ndim is the number of dimensions of input, the given matrix must have one of the following shapes:

    • (ndim, ndim): the linear transformation matrix for each output coordinate.

    • (ndim,): assume that the 2-D transformation matrix is diagonal, with the diagonal specified by the given value.

    • (ndim + 1, ndim + 1): assume that the transformation is specified using homogeneous coordinates [1]. In this case, any value passed to offset is ignored.

    • (ndim, ndim + 1): as above, but the bottom row of a homogeneous transformation matrix is always [0, 0, 0, 1], and may be omitted.

  • offset (Union[chex.Array, chex.Numeric]) – the offset into the array where the transform is applied. If a float, offset is the same for each axis. If an array, offset should contain one value for each axis.

  • order (int) – the order of the spline interpolation, default is 1. The order has to be in the range [0, 1]. Note that PIX interpolation will only be used for order=1, for other values we use jax.scipy.ndimage.map_coordinates.

  • mode (str) – the mode parameter determines how the input array is extended beyond its boundaries. Default is ‘nearest’. Modes ‘nearest’ and ‘constant’ use PIX interpolation, which is very fast on accelerators (especially on TPUs). For all other modes (‘wrap’, ‘mirror’ and ‘reflect’), we rely on jax.scipy.ndimage.map_coordinates, which is slow on accelerators, so use it with care.

  • cval (float) – value to fill past edges of input if mode is ‘constant’. Default is 0.0.

Return type:

chex.Array

Returns:

The input image transformed by the given matrix.

Example transformations:

Rotation:

>>> angle = jnp.pi / 4
>>> matrix = jnp.array([
...    [jnp.cos(angle), -jnp.sin(angle), 0],
...    [jnp.sin(angle), jnp.cos(angle), 0],
...    [0, 0, 1],
... ])
>>> result = dm_pix.affine_transform(image=image, matrix=matrix)

Translation can be expressed through either the matrix itself or the offset parameter:

>>> matrix = jnp.array([
...   [1, 0, 0, 25],
...   [0, 1, 0, 25],
...   [0, 0, 1, 0],
... ])
>>> result = dm_pix.affine_transform(image=image, matrix=matrix)
>>> # Or with offset:
>>> matrix = jnp.array([
...   [1, 0, 0],
...   [0, 1, 0],
...   [0, 0, 1],
... ])
>>> offset = jnp.array([25, 25, 0])
>>> result = dm_pix.affine_transform(
...     image=image, matrix=matrix, offset=offset)

Reflection:

>>> matrix = jnp.array([
...   [-1, 0, 0],
...   [0, 1, 0],
...   [0, 0, 1],
... ])
>>> result = dm_pix.affine_transform(image=image, matrix=matrix)

Scale:

>>> matrix = jnp.array([
...   [2, 0, 0],
...   [0, 1, 0],
...   [0, 0, 1],
... ])
>>> result = dm_pix.affine_transform(image=image, matrix=matrix)

Shear:

>>> matrix = jnp.array([
...   [1, 0.5, 0],
...   [0.5, 1, 0],
...   [0, 0, 1],
... ])
>>> result = dm_pix.affine_transform(image=image, matrix=matrix)

One can also combine different transformation matrices:

>>> matrix = rotation_matrix.dot(translation_matrix)
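The ‘pull’ convention described above can be checked with a small NumPy sketch. It works on a plain 2-D grid in homogeneous coordinates and ignores the channel axis, which is a simplification for illustration only. The point is that the matrix passed to this function is the inverse of the forward transform.

```python
import numpy as np

# Forward ("push") transform in homogeneous coordinates:
# move input pixels by +5 rows and +10 columns.
push = np.array([
    [1.0, 0.0, 5.0],
    [0.0, 1.0, 10.0],
    [0.0, 0.0, 1.0],
])

# The matrix expected for "pull" resampling is the inverse of the push matrix.
pull = np.linalg.inv(push)

# The output pixel at (row=7, col=12) reads from input position (2, 2).
output_index = np.array([7.0, 12.0, 1.0])
input_position = pull @ output_index
```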

center_crop

dm_pix.center_crop(image, height, width, *, channel_axis=-1)[source]

Crops an image to the given size keeping the same center of the original.

The target height/width can be greater than the current size of the image, in which case cropping is a no-op for that dimension.

In case of odd size along any dimension the bottom/right side gets the extra pixel.

Parameters:
  • image (chex.Array) – a JAX array representing an image. Assumes that the image is either …HWC or …CHW.

  • height (chex.Numeric) – target height to crop the image to.

  • width (chex.Numeric) – target width to crop the image to.

  • channel_axis (int) – the index of the channel axis.

Return type:

chex.Array

Returns:

The cropped image(s).

elastic_deformation

dm_pix.elastic_deformation(key, image, alpha, sigma, *, order=1, mode='nearest', cval=0.0, channel_axis=-1)[source]

Applies an elastic deformation to the given image.

Introduced by [Simard, 2003] and popularized by [Ronneberger, 2015]. Deforms images by moving pixels locally around using displacement fields.

Small sigma values (< 1.) give pixelated images, while higher values give water-like results. For sensible results, alpha should be between 5 and 10 times the value of sigma.

Parameters:
  • key (chex.PRNGKey) – a JAX RNG key.

  • image (chex.Array) – a JAX array representing an image. Assumes that the image is either HWC or CHW.

  • alpha (chex.Numeric) – strength of the distortion field. Higher values mean that pixels are moved further with respect to the distortion field’s direction.

  • sigma (chex.Numeric) – standard deviation of the gaussian kernel used to smooth the distortion fields.

  • order (int) – the order of the spline interpolation, default is 1. The order has to be in the range [0, 1]. Note that PIX interpolation will only be used for order=1, for other values we use jax.scipy.ndimage.map_coordinates.

  • mode (str) – the mode parameter determines how the input array is extended beyond its boundaries. Default is ‘nearest’. Modes ‘nearest’ and ‘constant’ use PIX interpolation, which is very fast on accelerators (especially on TPUs). For all other modes (‘wrap’, ‘mirror’ and ‘reflect’), we rely on jax.scipy.ndimage.map_coordinates, which is slow on accelerators, so use it with care.

  • cval (float) – value to fill past edges of input if mode is ‘constant’. Default is 0.0.

  • channel_axis (int) – the index of the channel axis.

Return type:

chex.Array

Returns:

The transformed image.

flip_left_right

dm_pix.flip_left_right(image, *, channel_axis=-1)[source]

Flips an image along the horizontal axis.

Assumes that the image is either …HWC or …CHW and flips the W axis.

Parameters:
  • image (chex.Array) – a JAX array representing an image. Assumes that the image is either …HWC or …CHW.

  • channel_axis (int) – the index of the channel axis.

Return type:

chex.Array

Returns:

The flipped image.

flip_up_down

dm_pix.flip_up_down(image, *, channel_axis=-1)[source]

Flips an image along the vertical axis.

Assumes that the image is either …HWC or …CHW, and flips the H axis.

Parameters:
  • image (chex.Array) – a JAX array representing an image. Assumes that the image is either …HWC or …CHW.

  • channel_axis (int) – the index of the channel axis.

Return type:

chex.Array

Returns:

The flipped image.

gaussian_blur

dm_pix.gaussian_blur(image, sigma, kernel_size, *, padding='SAME', channel_axis=-1)[source]

Applies gaussian blur (convolution with a Gaussian kernel).

Parameters:
  • image (chex.Array) – the input image, as a [0-1] float tensor. Should have 3 or 4 dimensions with two spatial dimensions.

  • sigma (float) – the standard deviation (in pixels) of the gaussian kernel.

  • kernel_size (float) – the size (in pixels) of the square gaussian kernel. Will be “rounded” to the next odd integer.

  • padding (str) – either “SAME” or “VALID”, passed to the underlying convolution.

  • channel_axis (int) – the index of the channel axis.

Return type:

chex.Array

Returns:

The blurred image.

pad_to_size

dm_pix.pad_to_size(image, target_height, target_width, *, mode='constant', pad_kwargs=None, channel_axis=-1)[source]

Pads an image to the given size keeping the original image centered.

For different padding methods and kwargs please see: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.pad.html

In case of odd size difference along any dimension the bottom/right side gets the extra padding pixel.

The target size can be smaller than the original size, which results in a no-op for that dimension.

Parameters:
  • image (chex.Array) – a JAX array representing an image. Assumes that the image is either …HWC or …CHW.

  • target_height (int) – target height to pad the image to.

  • target_width (int) – target width to pad the image to.

  • mode (str) – Mode for padding the images, see jax.numpy.pad for details. Default is constant.

  • pad_kwargs (Optional[Any]) – Keyword arguments to pass to jax.numpy.pad, see documentation for options.

  • channel_axis (int) – the index of the channel axis.

Return type:

chex.Array

Returns:

The padded image(s).
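The centering rule above (extra pixel on the bottom/right, no-op when the target is smaller) can be sketched with NumPy's pad. This is an illustration for an HWC image; `centered_pad_widths` is a hypothetical helper, not a library function.

```python
import numpy as np

def centered_pad_widths(size, target):
    # Hypothetical helper: split the padding so any odd pixel goes to the end.
    total = max(target - size, 0)   # a smaller target means no padding
    before = total // 2
    return before, total - before

image = np.ones((3, 4, 3), dtype=np.float32)     # HWC
pad_h = centered_pad_widths(image.shape[0], 6)   # 3 rows to add: 1 on top, 2 below
pad_w = centered_pad_widths(image.shape[1], 4)   # already wide enough: no padding
padded = np.pad(image, (pad_h, pad_w, (0, 0)), mode="constant")
```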

random_brightness

dm_pix.random_brightness(key, image, max_delta)[source]

adjust_brightness(…) with random delta in [-max_delta, max_delta).

Return type:

chex.Array

random_contrast

dm_pix.random_contrast(key, image, lower, upper, *, channel_axis=-1)[source]

adjust_contrast(…) with random factor in [lower, upper).

Return type:

chex.Array

random_crop

dm_pix.random_crop(key, image, crop_sizes)[source]

Crop images randomly to specified sizes.

Given an input image, it crops the image to the specified crop_sizes. If crop_sizes are smaller than the image’s sizes, the offset for cropping is chosen at random. To deterministically crop an image, please use jax.lax.dynamic_slice and specify offsets and crop sizes.

Parameters:
  • key (chex.PRNGKey) – key for pseudo-random number generator.

  • image (chex.Array) – a JAX array which represents an image.

  • crop_sizes (Sequence[int]) – a sequence of integers, each of which sequentially specifies the crop size along the corresponding dimension of the image. Sequence length must be identical to the rank of the image and the crop size should not be greater than the corresponding image dimension.

Return type:

chex.Array

Returns:

A cropped image, a JAX array whose shape is same as crop_sizes.
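The idea of picking a random offset per axis can be sketched in NumPy as below. This uses NumPy's random generator instead of a JAX PRNG key, and the uniform choice of offsets is an assumption made for illustration; `random_crop_sketch` is a hypothetical name.

```python
import numpy as np

def random_crop_sketch(rng, image, crop_sizes):
    # Hypothetical helper: pick a valid start index per axis, then slice.
    starts = [int(rng.integers(0, dim - size + 1))
              for dim, size in zip(image.shape, crop_sizes)]
    window = tuple(slice(s, s + size) for s, size in zip(starts, crop_sizes))
    return image[window]

rng = np.random.default_rng(0)
image = np.arange(8 * 8 * 3).reshape(8, 8, 3)
patch = random_crop_sketch(rng, image, (4, 4, 3))   # one crop size per axis
```

Note that crop_sizes has one entry per array dimension, so the channel axis is listed too (here with its full size, 3).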

random_flip_left_right

dm_pix.random_flip_left_right(key, image, *, probability=0.5)[source]

Applies flip_left_right with a given probability.

Parameters:
  • key (chex.PRNGKey) – a JAX RNG key.

  • image (chex.Array) – a JAX array representing an image. Assumes that the image is either …HWC or …CHW.

  • probability (chex.Numeric) – the probability of applying flip_left_right transform. Must be a value in [0, 1].

Return type:

chex.Array

Returns:

A left-right flipped image if condition is met, otherwise original image.

random_flip_up_down

dm_pix.random_flip_up_down(key, image, *, probability=0.5)[source]

Applies flip_up_down with a given probability.

Parameters:
  • key (chex.PRNGKey) – a JAX RNG key.

  • image (chex.Array) – a JAX array representing an image. Assumes that the image is either …HWC or …CHW.

  • probability (chex.Numeric) – the probability of applying flip_up_down transform. Must be a value in [0, 1].

Return type:

chex.Array

Returns:

An up-down flipped image if condition is met, otherwise original image.

random_gamma

dm_pix.random_gamma(key, image, min_gamma, max_gamma, *, gain=1, assume_in_bounds=False)[source]

adjust_gamma(…) with random gamma in [min_gamma, max_gamma).

Return type:

chex.Array

random_hue

dm_pix.random_hue(key, image, max_delta, *, channel_axis=-1)[source]

adjust_hue(…) with random delta in [-max_delta, max_delta).

Return type:

chex.Array

random_saturation

dm_pix.random_saturation(key, image, lower, upper, *, channel_axis=-1)[source]

adjust_saturation(…) with random factor in [lower, upper).

Return type:

chex.Array

resize_with_crop_or_pad

dm_pix.resize_with_crop_or_pad(image, target_height, target_width, *, pad_mode='constant', pad_kwargs=None, channel_axis=-1)[source]

Crops and/or pads an image to a target width and height.

Equivalent in functionality to tf.image.resize_with_crop_or_pad, but also allows padding methods other than zero padding.

Parameters:
  • image (chex.Array) – a JAX array representing an image. Assumes that the image is either …HWC or …CHW.

  • target_height (chex.Numeric) – target height to crop or pad the image to.

  • target_width (chex.Numeric) – target width to crop or pad the image to.

  • pad_mode (str) – mode for padding the images, see jax.numpy.pad for details. Default is constant.

  • pad_kwargs (Optional[Any]) – keyword arguments to pass to jax.numpy.pad, see documentation for options.

  • channel_axis (int) – the index of the channel axis.

Return type:

chex.Array

Returns:

The image(s) resized by crop or pad to the desired target size.

rotate

dm_pix.rotate(image, angle, *, order=1, mode='nearest', cval=0.0)[source]

Rotates an image around its center using interpolation.

Parameters:
  • image (chex.Array) – a JAX array representing an image. Assumes that the image is either HWC or CHW.

  • angle (float) – the counter-clockwise rotation angle in units of radians.

  • order (int) – the order of the spline interpolation, default is 1. The order has to be in the range [0, 1]. See affine_transform for details.

  • mode (str) – the mode parameter determines how the input array is extended beyond its boundaries. Default is ‘nearest’. See affine_transform for details.

  • cval (float) – value to fill past edges of input if mode is ‘constant’. Default is 0.0.

Return type:

chex.Array

Returns:

The rotated image.

rot90

dm_pix.rot90(image, k=1, *, channel_axis=-1)[source]

Rotates an image counter-clockwise by 90 degrees.

This is equivalent to tf.image.rot90. Assumes that the image is either …HWC or …CHW.

Parameters:
  • image (chex.Array) – an RGB image, given as a float tensor in [0, 1].

  • k (int) – the number of times the rotation is applied.

  • channel_axis (int) – the index of the channel axis.

Return type:

chex.Array

Returns:

The rotated image.

solarize

dm_pix.solarize(image, threshold)[source]

Applies solarization to an image.

All values above a given threshold will be inverted.

Parameters:
  • image (chex.Array) – an RGB image, given as a [0-1] float tensor.

  • threshold (chex.Numeric) – the threshold for inversion.

Return type:

chex.Array

Returns:

The solarized image.
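A minimal NumPy sketch of the inversion rule follows, assuming inversion means 1 - x for images in [0, 1]. How a value exactly equal to the threshold is treated is an assumption here (this sketch inverts it); `solarize_sketch` is a hypothetical name.

```python
import numpy as np

def solarize_sketch(image, threshold):
    # Hypothetical helper: keep values below the threshold, invert the rest.
    image = np.asarray(image, dtype=np.float32)
    return np.where(image < threshold, image, 1.0 - image)

pixels = np.array([0.1, 0.4, 0.8, 1.0], dtype=np.float32)
result = solarize_sketch(pixels, threshold=0.5)   # roughly [0.1, 0.4, 0.2, 0.0]
```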

Color conversions

hsl_to_rgb(image_hsl, *[, channel_axis])

Converts an image from HSL to RGB.

hsv_to_rgb(image_hsv, *[, channel_axis])

Converts an image from HSV to RGB.

rgb_to_hsl(image_rgb, *[, channel_axis])

Converts an image from RGB to HSL.

rgb_to_hsv(image_rgb, *[, channel_axis])

Converts an image from RGB to HSV.

rgb_to_grayscale(image, *[, keep_dims, ...])

Converts an image to a grayscale image using the luma value.

hsl_to_rgb

dm_pix.hsl_to_rgb(image_hsl, *, channel_axis=-1)[source]

Converts an image from HSL to RGB.

Parameters:
  • image_hsl (chex.Array) – an HSL image, with float values in range [0, 1]. Behavior outside of these bounds is not guaranteed.

  • channel_axis (int) – the channel axis. image_hsl should have 3 layers along this axis.

Return type:

chex.Array

Returns:

An RGB image, with float values in range [0, 1], stacked along channel_axis.

hsv_to_rgb

dm_pix.hsv_to_rgb(image_hsv, *, channel_axis=-1)[source]

Converts an image from HSV to RGB.

Parameters:
  • image_hsv (chex.Array) – an HSV image, with float values in range [0, 1]. Behavior outside of these bounds is not guaranteed.

  • channel_axis (int) – the channel axis. image_hsv should have 3 layers along this axis.

Return type:

chex.Array

Returns:

An RGB image, with float values in range [0, 1], stacked along channel_axis.

rgb_to_hsl

dm_pix.rgb_to_hsl(image_rgb, *, channel_axis=-1)[source]

Converts an image from RGB to HSL.

Parameters:
  • image_rgb (chex.Array) – an RGB image, with float values in range [0, 1]. Behavior outside of these bounds is not guaranteed.

  • channel_axis (int) – the channel axis. image_rgb should have 3 layers along this axis.

Return type:

chex.Array

Returns:

An HSL image, with float values in range [0, 1], stacked along channel_axis.

rgb_to_hsv

dm_pix.rgb_to_hsv(image_rgb, *, channel_axis=-1)[source]

Converts an image from RGB to HSV.

Parameters:
  • image_rgb (chex.Array) – an RGB image, with float values in range [0, 1]. Behavior outside of these bounds is not guaranteed.

  • channel_axis (int) – the channel axis. image_rgb should have 3 layers along this axis.

Return type:

chex.Array

Returns:

An HSV image, with float values in range [0, 1], stacked along channel_axis.

rgb_to_grayscale

dm_pix.rgb_to_grayscale(image, *, keep_dims=False, luma_standard='rec601', channel_axis=-1)[source]

Converts an image to a grayscale image using the luma value.

This is equivalent to tf.image.rgb_to_grayscale (when keep_dims=False).

Parameters:
  • image (chex.Array) – an RGB image, given as a float tensor in [0, 1].

  • keep_dims (bool) – if False (default), returns a tensor with a single channel. If True, will tile the resulting channel.

  • luma_standard – the luma standard to use, either “rec601”, “rec709” or “bt2001”. The default rec601 corresponds to TensorFlow’s.

  • channel_axis (int) – the index of the channel axis.

Return type:

chex.Array

Returns:

The grayscale image.
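For intuition, the luma computation with the Rec. 601 weights can be sketched in NumPy as below. The weights 0.299, 0.587 and 0.114 are the standard Rec. 601 values; the exact constants used by the library may differ in the last digits. `grayscale_sketch` is a hypothetical helper for channel-last images.

```python
import numpy as np

# Rec. 601 luma weights for the R, G and B channels.
REC601_WEIGHTS = np.array([0.299, 0.587, 0.114], dtype=np.float32)

def grayscale_sketch(image, keep_dims=False):
    # Hypothetical helper for channel-last RGB images.
    image = np.asarray(image, dtype=np.float32)
    luma = (image * REC601_WEIGHTS).sum(axis=-1, keepdims=True)
    # keep_dims=True tiles the single luma channel back to three channels.
    return np.tile(luma, 3) if keep_dims else luma

green = np.zeros((2, 2, 3), dtype=np.float32)
green[..., 1] = 1.0
gray = grayscale_sketch(green)                    # shape (2, 2, 1)
tiled = grayscale_sketch(green, keep_dims=True)   # shape (2, 2, 3)
```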

Depth and space transformations

depth_to_space(inputs, block_size)

Rearranges data from depth into blocks of spatial data.

space_to_depth(inputs, block_size)

Rearranges data from blocks of spatial data into depth.

depth_to_space

dm_pix.depth_to_space(inputs, block_size)[source]

Rearranges data from depth into blocks of spatial data.

Parameters:
  • inputs (chex.Array) – Array of shape [H, W, C] or [N, H, W, C]. The number of channels (depth dimension) must be divisible by block_size ** 2.

  • block_size (int) – Size of spatial blocks >= 2.

Return type:

chex.Array

Returns:

For inputs of shape [H, W, C] the output is a reshaped array of shape [H * B, W * B, C / (B ** 2)], where B is block_size. If there’s a leading batch dimension, it stays unchanged.

space_to_depth

dm_pix.space_to_depth(inputs, block_size)[source]

Rearranges data from blocks of spatial data into depth.

This is the reverse of depth_to_space.

Parameters:
  • inputs (chex.Array) – Array of shape [H, W, C] or [N, H, W, C]. The height and width must each be divisible by block_size.

  • block_size (int) – Size of spatial blocks >= 2.

Return type:

chex.Array

Returns:

For inputs of shape [H, W, C] the output is a reshaped array of shape [H / B, W / B, C * (B ** 2)], where B is block_size. If there’s a leading batch dimension, it stays unchanged.
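The shape arithmetic for both functions can be sketched with NumPy reshapes and transposes. The ordering of elements inside each block follows the convention of tf.nn.depth_to_space, which is an assumption here rather than something stated above. Both helper names are hypothetical.

```python
import numpy as np

def depth_to_space_sketch(x, b):
    # [H, W, C] -> [H * b, W * b, C / b**2]
    h, w, c = x.shape
    x = x.reshape(h, w, b, b, c // (b * b))
    x = x.transpose(0, 2, 1, 3, 4)   # interleave block rows with H, block cols with W
    return x.reshape(h * b, w * b, c // (b * b))

def space_to_depth_sketch(x, b):
    # [H, W, C] -> [H / b, W / b, C * b**2]
    h, w, c = x.shape
    x = x.reshape(h // b, b, w // b, b, c)
    x = x.transpose(0, 2, 1, 3, 4)   # gather each spatial block next to the channels
    return x.reshape(h // b, w // b, c * b * b)

x = np.arange(2 * 3 * 8).reshape(2, 3, 8)
y = depth_to_space_sketch(x, 2)          # shape (4, 6, 2)
x_back = space_to_depth_sketch(y, 2)     # shape (2, 3, 8), same values as x
```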

Interpolation functions

flat_nd_linear_interpolate(volume, ...[, ...])

Maps the input ND volume to coordinates by linear interpolation.

flat_nd_linear_interpolate_constant(volume, ...)

Maps volume by interpolation and returns a constant outside boundaries.

flat_nd_linear_interpolate

dm_pix.flat_nd_linear_interpolate(volume, coordinates, *, unflattened_vol_shape=None)[source]

Maps the input ND volume to coordinates by linear interpolation.

Parameters:
  • volume (chex.Array) – A volume (flat if unflattened_vol_shape is provided) where to query coordinates.

  • coordinates (chex.Array) – An array of shape (N, M_coordinates), where M_coordinates can be M-dimensional. If M_coordinates == 1, then coordinates.shape can simply be (N,), e.g. if N=3 and M_coordinates=1, this has the form (z, y, x).

  • unflattened_vol_shape (Optional[Sequence[int]]) – The shape of the volume before flattening. If provided, then volume must be pre-flattened.

Return type:

chex.Array

Returns:

The resulting mapped coordinates. The shape of the output is M_coordinates (derived from coordinates by dropping the first axis).
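As an illustration of linear interpolation at fractional coordinates, here is a 2-D (bilinear) NumPy sketch. It is not the library's implementation: it handles only 2-D volumes with at least two entries per axis, and coordinates are given with shape (2, M). `bilinear_sample` is a hypothetical name.

```python
import numpy as np

def bilinear_sample(volume, coordinates):
    # Hypothetical helper: sample a 2-D volume at fractional (row, col) positions.
    volume = np.asarray(volume, dtype=np.float64)
    rows, cols = np.asarray(coordinates, dtype=np.float64)
    # Lower integer corner of the cell around each query, kept inside the volume.
    r0 = np.clip(np.floor(rows).astype(int), 0, volume.shape[0] - 2)
    c0 = np.clip(np.floor(cols).astype(int), 0, volume.shape[1] - 2)
    # Fractional distance from the lower corner along each axis.
    fr = rows - r0
    fc = cols - c0
    top = volume[r0, c0] * (1 - fc) + volume[r0, c0 + 1] * fc
    bottom = volume[r0 + 1, c0] * (1 - fc) + volume[r0 + 1, c0 + 1] * fc
    return top * (1 - fr) + bottom * fr

volume = np.array([[0.0, 1.0], [2.0, 3.0]])
coords = np.array([[0.5, 0.0, 1.0],    # row coordinates
                   [0.5, 1.0, 0.0]])   # column coordinates
samples = bilinear_sample(volume, coords)
```

The first query sits in the middle of the four values and returns their average; the other two land exactly on grid points.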

flat_nd_linear_interpolate_constant

dm_pix.flat_nd_linear_interpolate_constant(volume, coordinates, *, cval=0.0, unflattened_vol_shape=None)[source]

Maps volume by interpolation and returns a constant outside boundaries.

Maps the input ND volume to coordinates by linear interpolation, but returns a constant value if the coordinates fall outside the volume boundary.

Parameters:
  • volume (chex.Array) – A volume (flat if unflattened_vol_shape is provided) where to query coordinates.

  • coordinates (chex.Array) – An array of shape (N, M_coordinates), where M_coordinates can be M-dimensional. If M_coordinates == 1, then coordinates.shape can simply be (N,), e.g. if N=3 and M_coordinates=1, this has the form (z, y, x).

  • cval (Optional[float]) – A constant value to map to for coordinates that fall outside the volume boundaries.

  • unflattened_vol_shape (Optional[Sequence[int]]) – The shape of the volume before flattening. If provided, then volume must be pre-flattened.

Return type:

chex.Array

Returns:

The resulting mapped coordinates. The shape of the output is M_coordinates (derived from coordinates by dropping the first axis).

Metrics

mae(a, b)

Returns the Mean Absolute Error between a and b.

mse(a, b)

Returns the Mean Squared Error between a and b.

psnr(a, b)

Returns the Peak Signal-to-Noise Ratio between a and b.

rmse(a, b)

Returns the Root Mean Squared Error between a and b.

simse(a, b)

Returns the Scale-Invariant Mean Squared Error between a and b.

ssim(a, b, *[, max_val, filter_size, ...])

Computes the structural similarity index (SSIM) between image pairs.

mae

dm_pix.mae(a, b)[source]

Returns the Mean Absolute Error between a and b.

Parameters:
  • a (chex.Array) – First image (or set of images).

  • b (chex.Array) – Second image (or set of images).

Return type:

chex.Numeric

Returns:

MAE between a and b.

mse

dm_pix.mse(a, b)[source]

Returns the Mean Squared Error between a and b.

Parameters:
  • a (chex.Array) – First image (or set of images).

  • b (chex.Array) – Second image (or set of images).

Return type:

chex.Numeric

Returns:

MSE between a and b.

psnr

dm_pix.psnr(a, b)[source]

Returns the Peak Signal-to-Noise Ratio between a and b.

Assumes that the dynamic range of the images (the difference between the maximum and the minimum allowed values) is 1.0.

Parameters:
  • a (chex.Array) – First image (or set of images).

  • b (chex.Array) – Second image (or set of images).

Return type:

chex.Numeric

Returns:

PSNR in decibels between a and b.
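With a dynamic range of 1.0, PSNR reduces to -10 * log10(MSE). The NumPy sketch below illustrates that relationship. It averages over all elements, whereas the library may reduce per image; both helper names are hypothetical.

```python
import numpy as np

def mse_sketch(a, b):
    # Mean of squared differences over every element.
    diff = np.asarray(a, dtype=np.float64) - np.asarray(b, dtype=np.float64)
    return np.mean(diff ** 2)

def psnr_sketch(a, b):
    # 10 * log10(MAX**2 / MSE) with MAX = 1.0
    return -10.0 * np.log10(mse_sketch(a, b))

a = np.zeros((4, 4, 3))
b = np.full((4, 4, 3), 0.1)
value = psnr_sketch(a, b)   # MSE is about 0.01, so PSNR is about 20 dB
```

Identical inputs give an MSE of zero, for which this formula yields infinity.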

rmse

dm_pix.rmse(a, b)[source]

Returns the Root Mean Squared Error between a and b.

Parameters:
  • a (chex.Array) – First image (or set of images).

  • b (chex.Array) – Second image (or set of images).

Return type:

chex.Numeric

Returns:

RMSE between a and b.

simse

dm_pix.simse(a, b)[source]

Returns the Scale-Invariant Mean Squared Error between a and b.

For each image pair, a scaling factor for b is computed as the solution to the following problem:

min_alpha || vec(a) - alpha * vec(b) ||_2^2

where a and b are flattened, i.e., vec(x) = np.ravel(x). The MSE between the optimally scaled b and a is returned: mse(a, alpha*b).

This is a scale-invariant metric, so for example: simse(x, y) == simse(x, y*5).

This metric was used in “Shape, Illumination, and Reflectance from Shading” by Barron and Malik, TPAMI, ‘15.

Parameters:
  • a (chex.Array) – First image (or set of images).

  • b (chex.Array) – Second image (or set of images).

Return type:

chex.Numeric

Returns:

SIMSE between a and b.
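The minimization above has the closed-form solution alpha = <a, b> / <b, b>, which the following NumPy sketch uses. It treats the inputs as a single image pair and is an illustration only; `simse_sketch` is a hypothetical name, and an all-zero b would cause a division by zero.

```python
import numpy as np

def simse_sketch(a, b):
    a = np.asarray(a, dtype=np.float64).ravel()
    b = np.asarray(b, dtype=np.float64).ravel()
    # Closed-form least-squares solution of min_alpha ||a - alpha * b||^2.
    alpha = np.dot(a, b) / np.dot(b, b)
    return np.mean((a - alpha * b) ** 2)

rng = np.random.default_rng(0)
x = rng.random((4, 4, 3))
y = rng.random((4, 4, 3))
base = simse_sketch(x, y)
scaled = simse_sketch(x, 5 * y)   # rescaling b leaves the metric unchanged
```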

ssim

dm_pix.ssim(a, b, *, max_val=1.0, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03, return_map=False, precision=Precision.HIGHEST, filter_fn=None)[source]

Computes the structural similarity index (SSIM) between image pairs.

This function is based on the standard SSIM implementation from: Z. Wang, A. C. Bovik, H. R. Sheikh and E. P. Simoncelli, “Image quality assessment: from error visibility to structural similarity”, in IEEE Transactions on Image Processing, vol. 13, no. 4, pp. 600-612, 2004.

This function was modeled after tf.image.ssim, and should produce comparable output.

Note: the true SSIM is only defined on grayscale. This function does not perform any colorspace transform. If the input is in a color space, then it will compute the average SSIM.

Parameters:
  • a (chex.Array) – First image (or set of images).

  • b (chex.Array) – Second image (or set of images).

  • max_val (float) – The maximum magnitude that a or b can have.

  • filter_size (int) – Window size (>= 1). Image dims must be at least this large.

  • filter_sigma (float) – The bandwidth of the Gaussian used for filtering (> 0.).

  • k1 (float) – One of the SSIM dampening parameters (> 0.).

  • k2 (float) – One of the SSIM dampening parameters (> 0.).

  • return_map (bool) – If True, will cause the per-pixel SSIM “map” to be returned.

  • precision – The numerical precision to use when performing convolution.

  • filter_fn (Optional[Callable[[chex.Array], chex.Array]]) – An optional argument for overriding the filter function used by SSIM, which would otherwise be a 2D Gaussian blur specified by filter_size and filter_sigma.

Return type:

chex.Numeric

Returns:

Each image’s mean SSIM, or a tensor of individual values if return_map is True.

Patch extraction functions

extract_patches(images, sizes, strides, rates, *)

Extract patches from images.

extract_patches

dm_pix.extract_patches(images, sizes, strides, rates, *, padding='VALID')[source]

Extract patches from images.

This function is a wrapper for jax.lax.conv_general_dilated_patches that conforms to the same interface as tf.image.extract_patches, except that this function supports arbitrary-dimensional images, not only 4D as in tf.image.extract_patches.

The function extracts patches of shape sizes from images in the same manner as a convolution with kernel of shape sizes, stride equal to strides, and the given padding scheme. The patches are stacked in the channel dimension.

Parameters:
  • images (chex.Array) – input batch of images of shape [B, H, W, …, C].

  • sizes (Sequence[int]) – size of the extracted patches. Must be [1, size_rows, size_cols, …, 1].

  • strides (Sequence[int]) – how far the centers of two consecutive patches are in the images. Must be [1, stride_rows, stride_cols, …, 1].

  • rates (Sequence[int]) – sampling rate. Must be [1, rate_rows, rate_cols, …, 1]. This is the input stride, specifying how far two consecutive patch samples are in the input. Equivalent to extracting patches with patch_sizes_eff = patch_sizes + (patch_sizes - 1) * (rates - 1), followed by subsampling them spatially by a factor of rates. This is equivalent to rate in dilated (a.k.a. Atrous) convolutions.

  • padding (str) – the type of padding algorithm to use.

Return type:

jnp.ndarray

Returns:

Tensor of shape [B, patch_rows, patch_cols, …, size_rows * size_cols * … * C].
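The effective patch size formula for rates, and the resulting number of patches along one spatial axis, can be worked through in plain Python. The patch count assumes 'VALID' padding with the usual convolution arithmetic; both helper names are hypothetical.

```python
def effective_patch_size(size, rate):
    # patch_sizes_eff = patch_sizes + (patch_sizes - 1) * (rates - 1)
    return size + (size - 1) * (rate - 1)

def num_patches_valid(image_dim, size, stride, rate):
    # Hypothetical helper: count window positions that fit entirely inside the image.
    eff = effective_patch_size(size, rate)
    return max((image_dim - eff) // stride + 1, 0)

eff = effective_patch_size(3, 2)        # a 3-wide patch at rate 2 spans 5 pixels
rows = num_patches_valid(10, 3, 2, 2)   # windows start at rows 0, 2 and 4
```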