Source code for ivis.utils.fourier

# -*- coding: utf-8 -*-
"""
Power Spectrum and Image Processing Utilities
---------------------------------------------

This module provides tools to analyze the spatial frequency content of 2D images,
including power spectrum estimation, cross-spectra, autocorrelation functions,
and image preprocessing utilities such as apodization and padding.

Originally adapted from a package by J.-F. Robitaille.

Functions
---------

- ``powspec(image, reso=1, autocorr=False, **kwargs)``
    Computes the azimuthally averaged power spectrum of a 2D image, optionally returning
    the autocorrelation or cross-spectrum if a second image is passed via ``im2``.

- ``cross_spec(image, image2, reso=1)``
    Computes the cross-power spectrum between two 2D images.

- ``kgrid(image)``
    Returns a normalized radial spatial frequency grid (k) corresponding to the FFT bins
    for a given image shape.

- ``apodize(radius, shape)``
    Returns a 2D cosine taper mask to apodize the edges of an image.

- ``apodize_1d(radius, shape)``
    Returns a 1D cosine taper to apodize the edges of a 1D array (e.g. spectrum).

- ``padding(input, fact)``
    Pads a 2D image symmetrically by a given fractional amount (``fact``), centering
    the original image in a larger zero-filled array.

- ``depad(input, fact)``
    Crops a previously padded 2D image, removing the borders added by ``padding``.

Author: Adapted from J.-F. Robitaille's package by Antoine Marchal (who mostly added docstrings for Read the Docs).
"""

import numpy as np
import torch

def powspec_torch(image, pixel_size_arcsec=1.0, nbins=None, autocorr=False, im2=None):
    """
    Compute the azimuthally averaged 1D power spectrum of a 2D tensor.

    Parameters
    ----------
    image : torch.Tensor
        Input 2D image of shape (H, W).
    pixel_size_arcsec : float, optional
        Pixel size in arcseconds. It is converted to arcminutes, so the
        returned frequencies are in cycles per arcminute.
    nbins : int, optional
        Number of linearly spaced radial bins between the smallest non-zero
        and the largest radial frequency. Defaults to ``max(H, W) // 2``.
    autocorr : bool, optional
        If True, the 2D spectrum is inverse-transformed before the radial
        averaging.
    im2 : torch.Tensor, optional
        Second image with the same shape as ``image``. When given, the real
        part of the cross-spectrum is averaged instead of the auto-spectrum.

    Returns
    -------
    bin_centers : torch.Tensor
        Centres of the radial frequency bins, shape ``(nbins,)``.
    power_1d : torch.Tensor
        Mean of the 2D spectrum in each bin, shape ``(nbins,)``. Bins that
        contain no pixel are left at zero. The zero-frequency (DC) pixel is
        not included in any bin.

    Raises
    ------
    ValueError
        If ``image`` is not 2D or ``im2`` has a different shape.
    """
    device = image.device
    H, W = image.shape
    if im2 is not None and im2.shape != image.shape:
        raise ValueError("im2 must have the same shape as image")
    if nbins is None:
        nbins = max(H, W) // 2

    # Centred (fftshift-ed) frequency axes in cycles per arcminute.
    dx_arcmin = pixel_size_arcsec / 60.0
    fx = torch.fft.fftshift(torch.fft.fftfreq(W, d=dx_arcmin, device=device))
    fy = torch.fft.fftshift(torch.fft.fftfreq(H, d=dx_arcmin, device=device))
    kx, ky = torch.meshgrid(fx, fy, indexing='xy')  # both of shape (H, W)
    k_flat = torch.sqrt(kx**2 + ky**2).flatten()

    ft1 = torch.fft.fft2(image)
    ft2 = ft1 if im2 is None else torch.fft.fft2(im2)
    ps2d = torch.real(ft1 * torch.conj(ft2)) / (H * W)
    if autocorr:
        ps2d = torch.fft.ifft2(ps2d).real

    # Shift the spectrum so that it lines up with the centred k grid.
    ps_flat = torch.fft.fftshift(ps2d).flatten()

    non_dc = k_flat > 0
    kmin = k_flat[non_dc].min().item()
    kmax = k_flat.max().item()
    bins = torch.linspace(kmin, kmax, nbins + 1, device=device)
    bin_centers = 0.5 * (bins[:-1] + bins[1:])

    # right=True gives half-open bins [edge_i, edge_{i+1}), so pixels sitting
    # exactly on kmin land in bin 0 instead of being dropped.  The clamp
    # closes the last bin on the right so the kmax pixels are kept as well.
    inds = torch.bucketize(k_flat, bins, right=True) - 1
    inds = inds.clamp(min=0, max=nbins - 1)

    sel_inds = inds[non_dc]
    sel_pow = ps_flat[non_dc]

    # index_add keeps the averaging differentiable w.r.t. the spectrum.
    sums = torch.zeros(nbins, dtype=sel_pow.dtype, device=device).index_add(
        0, sel_inds, sel_pow
    )
    counts = torch.bincount(sel_inds, minlength=nbins)

    # Empty bins: 0 / 1 -> 0, matching the zero-initialised original output.
    power_1d = (sums / counts.clamp(min=1)).to(torch.get_default_dtype())
    return bin_centers, power_1d
def powspec(image, reso=1, autocorr=False, **kwargs):
    """
    Compute the azimuthally averaged 1D power spectrum or autocorrelation
    of a 2D image.

    Parameters
    ----------
    image : ndarray
        Input 2D image.
    reso : float, optional
        Pixel size in spatial units (e.g. arcmin).
    autocorr : bool, optional
        If True, return the autocorrelation function instead of the power
        spectrum.
    im2 : ndarray, optional (passed via kwargs)
        Second image, same shape as ``image``, for computing the
        cross-spectrum. ``None`` is treated as "not given".

    Returns
    -------
    tab_k : ndarray
        Radial coordinate of each ring: spatial frequency
        ``ring / (n * reso)`` for the spectrum, lag ``ring * reso`` for the
        autocorrelation, where ``n`` is the larger image dimension.
    spec_k : ndarray
        Azimuthal average over each ring (real part for a cross-spectrum).
        Ring 0 and the outermost ring(s) are not returned.

    Raises
    ------
    ValueError
        If ``image`` is not 2D or ``im2`` has a different shape.
    """
    image = np.asarray(image)
    if image.ndim != 2:
        raise ValueError("image must be a 2D array")
    ny, nx = image.shape
    na, nb = float(nx), float(ny)
    nf = max(na, nb)
    k_crit = nf / 2

    imft = np.fft.fft2(image)
    im2 = kwargs.get('im2')
    if im2 is not None:
        im2 = np.asarray(im2)
        if im2.shape != image.shape:
            raise ValueError("im2 must have the same shape as image")
        ps2d = imft * np.conj(np.fft.fft2(im2)) / (na * nb)
    else:
        ps2d = np.abs(imft) ** 2 / (na * nb)
    del imft

    if autocorr:
        ps2d = np.fft.ifft2(ps2d)
    # Take the real part explicitly rather than relying on an implicit
    # (warning-raising) complex -> float cast when averaging.
    ps2d = np.real(ps2d)

    # Frequency axes in FFT order (zero first): build the centred axis, then
    # roll the zero to index 0.  Centre index is n // 2 for both parities.
    fx = np.roll((np.arange(na) - nx // 2) / na, (nx + 1) // 2)
    fy = np.roll((np.arange(nb) - ny // 2) / nb, (ny + 1) // 2)
    k_mod = np.round(np.sqrt(fx[np.newaxis, :] ** 2 + fy[:, np.newaxis] ** 2) * nf)
    ring = k_mod.astype(np.intp).ravel()

    # Rings 1 .. stop-1 are returned; the long image axis guarantees each of
    # them contains at least one pixel, so the division below is safe.
    stop = (int(nf) - 1) // 2
    counts = np.bincount(ring, minlength=stop + 1)
    sums = np.bincount(ring, weights=ps2d.ravel(), minlength=stop + 1)

    spec_k = sums[1:stop] / counts[1:stop]
    kval = np.arange(1, stop).astype(float)

    if not autocorr:
        tab_k = kval / (k_crit * 2. * reso)
    else:
        tab_k = kval * reso
    return tab_k, spec_k
def kgrid(image):
    """
    Build the radial spatial-frequency grid matching an image's shape.

    The grid is centred: the zero frequency sits at index ``n // 2`` along
    each axis, and frequencies are expressed in cycles per pixel.

    Parameters
    ----------
    image : ndarray
        Input 2D image; only its shape is used.

    Returns
    -------
    kmat : ndarray
        2D array of radial spatial frequencies (normalized), with shape
        ``(image.shape[0], image.shape[1])``.
    """
    n_rows, n_cols = image.shape[0], image.shape[1]

    # n // 2 is the centre index for even and odd lengths alike.
    col_freq = (np.arange(float(n_cols)) - n_cols // 2) / float(n_cols)
    row_freq = (np.arange(float(n_rows)) - n_rows // 2) / float(n_rows)

    return np.sqrt(col_freq[np.newaxis, :] ** 2 + row_freq[:, np.newaxis] ** 2)
def cross_spec(image, image2, reso=1):
    """
    Compute the 1D cross power spectrum between two 2D images.

    Parameters
    ----------
    image : ndarray
        First 2D image.
    image2 : ndarray
        Second 2D image, same shape as ``image``.
    reso : float, optional
        Pixel size in spatial units (e.g. arcmin). Accepted for signature
        compatibility; it does not affect the returned values.

    Returns
    -------
    spec_k : ndarray
        Azimuthally averaged real part of the cross power spectrum. Ring 0
        and the outermost ring(s) are not returned.

    Raises
    ------
    ValueError
        If ``image`` is not 2D or ``image2`` has a different shape.
    """
    image = np.asarray(image)
    image2 = np.asarray(image2)
    if image.ndim != 2:
        raise ValueError("image must be a 2D array")
    if image2.shape != image.shape:
        raise ValueError("image2 must have the same shape as image")

    ny, nx = image.shape
    na, nb = float(nx), float(ny)
    nf = max(na, nb)

    # The cross-spectrum is complex; only its real part is averaged.  Taking
    # it explicitly avoids an implicit complex -> float cast (ComplexWarning).
    cross = np.fft.fft2(image) * np.conj(np.fft.fft2(image2)) / (na * nb)
    ps2d = np.real(cross)

    # Frequency axes in FFT order (zero first): build the centred axis, then
    # roll the zero to index 0.  Centre index is n // 2 for both parities.
    fx = np.roll((np.arange(na) - nx // 2) / na, (nx + 1) // 2)
    fy = np.roll((np.arange(nb) - ny // 2) / nb, (ny + 1) // 2)
    k_mod = np.round(np.sqrt(fx[np.newaxis, :] ** 2 + fy[:, np.newaxis] ** 2) * nf)
    ring = k_mod.astype(np.intp).ravel()

    # Rings 1 .. stop-1 are returned; the long image axis guarantees each of
    # them contains at least one pixel, so the division below is safe.
    stop = (int(nf) - 1) // 2
    counts = np.bincount(ring, minlength=stop + 1)
    sums = np.bincount(ring, weights=ps2d.ravel(), minlength=stop + 1)
    return sums[1:stop] / counts[1:stop]
def apodize_1d(radius, shape):
    """
    Create a 1D cosine taper window for edge apodization.

    Parameters
    ----------
    radius : float
        Controls the taper width (0 < radius < 1). Each edge ramp spans
        ``shape - fix(radius * shape)`` samples, so values closer to 1 give
        narrower tapers.
    shape : int
        Total length of the 1D array.

    Returns
    -------
    tap1d_x : ndarray
        1D tapering function: rises from 0 to 1 at the start, is 1 in the
        middle, and falls back to 0 at the end.

    Raises
    ------
    ValueError
        If ``radius`` is not strictly between 0 and 1.
    """
    if not (0 < radius < 1):
        raise ValueError("radius must be between 0 and 1")

    nx = shape
    width = int(nx - np.fix(radius * nx))

    if width > 1:
        rise_t = fall_t = np.arange(width) / (width - 1)
    else:
        # A one-sample ramp has no interior: pin it to the edge value
        # instead of evaluating 0 / 0 (which produced NaN).
        rise_t = np.zeros(width)
        fall_t = np.ones(width)

    tap1d_x = np.ones(nx)
    # Multiply instead of assign so that, when the two ramps overlap
    # (radius < 0.5), neither overwrites the other and the window stays
    # symmetric.  Without overlap this equals plain assignment.
    tap1d_x[:width] *= np.cos(3 * np.pi / 2 + np.pi / 2 * rise_t)
    tap1d_x[nx - width:] *= np.cos(np.pi / 2 * fall_t)
    return tap1d_x
def _cosine_taper_axis(radius, n):
    """Return a length-``n`` window with cosine ramps at both ends."""
    width = int(n - np.fix(radius * n))

    if width > 1:
        rise_t = fall_t = np.arange(width) / (width - 1)
    else:
        # A one-sample ramp has no interior: pin it to the edge value
        # instead of evaluating 0 / 0 (which produced NaN).
        rise_t = np.zeros(width)
        fall_t = np.ones(width)

    window = np.ones(n)
    # Multiply instead of assign so overlapping ramps (radius < 0.5) combine
    # symmetrically; without overlap this equals plain assignment.
    window[:width] *= np.cos(3 * np.pi / 2 + np.pi / 2 * rise_t)
    window[n - width:] *= np.cos(np.pi / 2 * fall_t)
    return window


def apodize(radius, shape):
    """
    Create a 2D cosine taper window for edge apodization.

    Parameters
    ----------
    radius : float
        Controls the taper width (0 < radius < 1). Along an axis of length
        ``n`` each edge ramp spans ``n - fix(radius * n)`` samples, so values
        closer to 1 give narrower tapers.
    shape : tuple of int
        Shape of the 2D image (ny, nx).

    Returns
    -------
    tapper : ndarray
        2D tapering function of shape (ny, nx), the outer product of the
        vertical and horizontal 1D tapers.

    Raises
    ------
    ValueError
        If ``radius`` is not strictly between 0 and 1.
    """
    if not (0 < radius < 1):
        raise ValueError("radius must be between 0 and 1")

    ny, nx = shape
    return np.outer(_cosine_taper_axis(radius, ny), _cosine_taper_axis(radius, nx))
def padding(input, fact):
    """
    Pad a 2D image with zeros, keeping the original centred.

    Parameters
    ----------
    input : ndarray
        Original 2D image.
    fact : float
        Non-negative padding fraction. The output is ``(1 + fact)`` times the
        input size along each axis (truncated to an integer), so e.g. 0.2
        adds about 20% in total, i.e. roughly 10% on each side.

    Returns
    -------
    output : ndarray
        Padded image. Real input yields a float64 array; complex input keeps
        its imaginary part.

    Raises
    ------
    ValueError
        If ``fact`` is negative or ``input`` is not 2D.
    """
    if fact < 0:
        raise ValueError("fact must be non-negative")

    input = np.asarray(input)
    height, width = input.shape
    y, x = int(height * (1. + fact)), int(width * (1. + fact))

    # Promote with float64 so real inputs behave as before while complex
    # inputs are no longer truncated to their real part.
    output = np.zeros((y, x), dtype=np.result_type(input.dtype, np.float64))

    xpos = int(x / 2 - width / 2)
    ypos = int(y / 2 - height / 2)
    output[ypos:height + ypos, xpos:width + xpos] = input
    return output
def _unpadded_extent(padded, fact):
    """Return the pre-padding length that ``padding`` turns into ``padded``."""
    guess = int(padded / (1. + fact))
    # padding() computes int(n * (1 + fact)), which is strictly increasing in
    # n for fact >= 0, so at most one candidate can reproduce ``padded``.
    for n in (guess, guess + 1, guess - 1):
        if n >= 0 and int(n * (1. + fact)) == padded:
            return n
    # Size not reachable by padding(): fall back to the historical estimate.
    return guess + 1


def depad(input, fact):
    """
    Crop a padded 2D image to its original size.

    Parameters
    ----------
    input : ndarray
        Padded 2D image.
    fact : float
        Non-negative padding fraction used during padding.

    Returns
    -------
    output : ndarray
        Central crop of ``input``. When ``input`` has a size that
        ``padding(..., fact)`` can produce, the crop has exactly the
        pre-padding dimensions and position.

    Raises
    ------
    ValueError
        If ``fact`` is negative.
    """
    if fact < 0:
        raise ValueError("fact must be non-negative")

    height, width = input.shape[0], input.shape[1]
    y = _unpadded_extent(height, fact)
    x = _unpadded_extent(width, fact)

    # Same offset formula as padding(), so the crop lands on the original.
    xpos = int(width / 2 - x / 2)
    ypos = int(height / 2 - y / 2)
    return input[ypos:y + ypos, xpos:x + xpos]