# -*- coding: utf-8 -*-
"""
===================================
Utility Functions for IViS Imaging
===================================
This module provides a collection of utility functions used in IViS
for radio interferometric reconstruction, model fitting, and image-domain operations.
It includes tools for:
- WCS construction and reprojection using `astropy.wcs`
- Edge apodization (cosine taper) for windowing
- Construction of Laplacian kernels for regularization
- Synthetic Gaussian beam generation and injection
- Coordinate grid generation for PyTorch-based warping
- Elliptical Gaussian fitting for beam or source characterization
Functions
---------
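- gpu_mem_str: Format the CUDA memory usage of a device as a string.
- get_device: Resolve a torch.device from a flexible spec ("auto", "cpu", "cuda:i", "mps", int).
- downsample_pb: Downsample a real primary-beam array with bilinear interpolation.
- downsample_grid: Downsample an array carrying a trailing (real, imag) axis.
- downsample_hdr: Copy a FITS-like header with NAXIS and CDELT rescaled.
- wcs2D: Build a 2D WCS object from a FITS header.
- apodize: Create a cosine taper for edge apodization.
- ROHSA_kernel: Return a discrete Laplacian kernel used in ROHSA.
- laplacian: Create a Laplacian kernel padded into a full-size map.
- ROHSA_bounds: Generate lower and upper parameter bounds.
- gauss_beam: Generate a normalized 2D Gaussian kernel.
- get_grid: Generate a sampling grid for torch.nn.functional.grid_sample.
- format_input_tensor: Reshape tensors for compatibility with PyTorch ops.
- insert_elliptical_gaussian_source: Inject an elliptical Gaussian model into an image.
- fit_elliptical_gaussian: Fit an elliptical Gaussian and visualize the result.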
Dependencies
------------
- numpy
- torch
- matplotlib
- astropy
Author
------
Antoine Marchal, 2024–2025
"""
import numpy as np
import torch
import torch.nn.functional as F
from astropy.wcs.utils import pixel_to_pixel
from astropy import wcs
from astropy.modeling import models, fitting
import matplotlib.pyplot as plt
from ivis.logger import logger
[docs]
def gpu_mem_str(dev: torch.device) -> str:
    """
    Summarize the CUDA memory usage of a device as a one-line string.

    Parameters
    ----------
    dev : torch.device
        Device to inspect. When its index is None, the current CUDA
        device index is used.

    Returns
    -------
    str
        "GPU[i]: ... MB alloc, ... MB reserved, ... MB peak, ... MB total"
        (byte counts divided by 1024**2), or "" when `dev` is not a CUDA device.
    """
    if dev.type == "cuda":
        bytes_per_unit = 1024**2
        idx = torch.cuda.current_device() if dev.index is None else dev.index

        alloc = torch.cuda.memory_allocated(idx) / bytes_per_unit
        reserved = torch.cuda.memory_reserved(idx) / bytes_per_unit
        peak = torch.cuda.max_memory_allocated(idx) / bytes_per_unit
        total = torch.cuda.get_device_properties(idx).total_memory / bytes_per_unit

        fields = [
            f"{alloc:.2f} MB alloc",
            f"{reserved:.2f} MB reserved",
            f"{peak:.2f} MB peak",
            f"{total:.2f} MB total",
        ]
        return f"GPU[{idx}]: " + ", ".join(fields)

    # Non-CUDA devices have nothing to report.
    return ""
[docs]
def get_device(spec="auto") -> torch.device:
    """
    Resolve a compute device from a flexible spec.

    Parameters
    ----------
    spec : torch.device, int or str, optional
        - torch.device -> returned as-is
        - int i        -> cuda:i if CUDA is available and i >= 0
                          (cuda:0 with a warning if i is out of range);
                          otherwise cpu
        - "auto" / ""  -> cuda:0 if available; else mps; else cpu
        - "cpu"        -> cpu
        - "mps"        -> Apple MPS if available; else cpu
        - "cuda"       -> cuda:0 if CUDA is available; else cpu
        - "cuda:i"     -> cuda:i if CUDA is available and i is a valid
                          device ordinal; cuda:0 with a warning if i cannot
                          be parsed or is out of range; cpu without CUDA
        Strings are stripped and lower-cased before matching. Anything
        else falls back to cpu with a warning.

    Returns
    -------
    torch.device
        The resolved device. The choice is reported through the module logger.
    """
    # Passthrough: an explicit device is trusted as-is.
    if isinstance(spec, torch.device):
        return spec

    # Integer spec: CUDA ordinal when possible.
    if isinstance(spec, int):
        if spec >= 0 and torch.cuda.is_available():
            return _cuda_or_first(int(spec))
        logger.info("CUDA unavailable or invalid index; using CPU.")
        return torch.device("cpu")

    # String spec.
    if isinstance(spec, str):
        s = spec.strip().lower()

        if s in ("auto", ""):
            if torch.cuda.is_available():
                logger.info(f"Using GPU (auto) cuda:0 ({torch.cuda.get_device_name(0)})")
                return torch.device("cuda:0")
            if _mps_available():
                logger.info("Using Apple MPS (auto).")
                return torch.device("mps")
            logger.info("Using CPU (auto).")
            return torch.device("cpu")

        if s == "cpu":
            logger.info("Using CPU (user-specified).")
            return torch.device("cpu")

        if s == "mps":
            if _mps_available():
                logger.info("Using Apple MPS (user-specified).")
                return torch.device("mps")
            logger.warning("MPS requested but not available; falling back to CPU.")
            return torch.device("cpu")

        if s.startswith("cuda"):
            if not torch.cuda.is_available():
                logger.warning("CUDA requested but not available; falling back to CPU.")
                return torch.device("cpu")
            idx = 0
            if ":" in s:
                try:
                    idx = int(s.split(":", 1)[1])
                except ValueError:
                    logger.warning(
                        f"Could not parse device index from '{spec}', defaulting to cuda:0."
                    )
                    idx = 0
            return _cuda_or_first(idx)

    logger.warning(f"Unrecognized device spec '{spec}'; defaulting to CPU.")
    return torch.device("cpu")


def _mps_available() -> bool:
    """Return True when this torch build exposes an available MPS backend."""
    return hasattr(torch.backends, "mps") and torch.backends.mps.is_available()


def _cuda_or_first(idx: int) -> torch.device:
    """Return cuda:idx when idx is a valid device ordinal, else warn and return cuda:0."""
    n_dev = torch.cuda.device_count()
    # The lower bound matters: a negative index (e.g. from "cuda:-1") would
    # otherwise pass the range check and produce an invalid torch.device.
    if 0 <= idx < n_dev:
        logger.info(f"Using GPU cuda:{idx} ({torch.cuda.get_device_name(idx)})")
        return torch.device(f"cuda:{idx}")
    logger.warning(
        f"Requested cuda:{idx} but only {n_dev} device(s); using cuda:0"
    )
    return torch.device("cuda:0")
def _to_nchw(arr: np.ndarray, expect_complex: bool = False) -> torch.Tensor:
"""
Convert numpy array into 4D NCHW tensor for interpolation.
- If expect_complex=False: arr is (H,W), (B,H,W), or (N,B,H,W)
returns (N,1,H,W).
- If expect_complex=True: arr has trailing axis 2 for (real,imag)
e.g. (H,W,2), (B,H,W,2), or (N,1,H,W,2).
returns (N,2,H,W).
"""
arr = np.array(arr, dtype=np.float32, order="C", copy=False)
t = torch.from_numpy(arr)
if expect_complex:
if t.ndim == 3 and t.shape[-1] == 2:
# (H,W,2) → (1,2,H,W)
t = t.permute(2, 0, 1).unsqueeze(0)
elif t.ndim == 4 and t.shape[-1] == 2:
# (B,H,W,2) → (B,2,H,W)
t = t.permute(0, 3, 1, 2)
elif t.ndim == 5 and t.shape[1] == 1 and t.shape[-1] == 2:
# (N,1,H,W,2) → (N,2,H,W)
t = t.squeeze(1).permute(0, 3, 1, 2)
else:
raise ValueError(f"Unsupported complex shape {arr.shape}")
else:
if t.ndim == 2:
# (H,W) → (1,1,H,W)
t = t.unsqueeze(0).unsqueeze(0)
elif t.ndim == 3:
# (B,H,W) → (B,1,H,W)
t = t.unsqueeze(1)
elif t.ndim == 4 and t.shape[1] == 1:
# (N,1,H,W) already fine
pass
else:
raise ValueError(f"Unsupported real shape {arr.shape}")
return t.float()
[docs]
def downsample_pb(pb: np.ndarray, factor: int) -> np.ndarray:
    """
    Downsample a real primary-beam array with bilinear interpolation.

    Parameters
    ----------
    pb : ndarray
        Real array in one of the layouts accepted by `_to_nchw`
        with expect_complex=False.
    factor : int
        Downsampling factor; the interpolation scale is 1.0 / factor.

    Returns
    -------
    ndarray
        Downsampled array. Leading axes are dropped only when they have
        length 1: a 2D input comes back 2D, while an input with a batch
        axis longer than 1 keeps the 4D (B,1,H',W') layout.
        NOTE(review): confirm that callers expect the singleton channel
        axis for batched input.
    """
    batch = _to_nchw(pb, expect_complex=False)
    reduced = F.interpolate(
        batch,
        scale_factor=1.0 / factor,
        mode="bilinear",
        align_corners=False,
    )
    # squeeze(0) is a no-op unless the leading axis has length 1.
    reduced = reduced.squeeze(0)
    reduced = reduced.squeeze(0)
    return reduced.cpu().numpy()
[docs]
def downsample_grid(grid: np.ndarray, factor: int) -> np.ndarray:
    """
    Downsample a uv-grid array that carries a trailing (real, imag) axis.

    Parameters
    ----------
    grid : ndarray
        Array in one of the layouts accepted by `_to_nchw` with
        expect_complex=True, i.e. (..., 2).
    factor : int
        Downsampling factor; the interpolation scale is 1.0 / factor.

    Returns
    -------
    ndarray
        Bilinearly interpolated array of shape (N, H', W', 2); the real and
        imaginary parts are interpolated as two independent channels.
    """
    channels_first = _to_nchw(grid, expect_complex=True)
    reduced = F.interpolate(
        channels_first,
        scale_factor=1.0 / factor,
        mode="bilinear",
        align_corners=False,
    )
    # Move the (real, imag) channel back to the last axis: (N,2,H,W) -> (N,H,W,2).
    channels_last = reduced.permute(0, 2, 3, 1)
    return channels_last.cpu().numpy()
[docs]
def downsample_hdr(hdr, factor: int):
    """
    Return a copy of a FITS-like header rescaled for a downsampled image.

    Parameters
    ----------
    hdr : dict-like
        Header providing NAXIS1, NAXIS2, CDELT1 and CDELT2 and a `copy()` method.
    factor : int
        Downsampling factor.

    Returns
    -------
    dict-like
        Copy of `hdr` with NAXIS1/NAXIS2 floor-divided by `factor` and
        CDELT1/CDELT2 multiplied by `factor`. The input header is not modified.

    Notes
    -----
    NOTE(review): CRPIX1/CRPIX2 are left untouched. Confirm whether callers
    rescale the reference pixel themselves.
    """
    scaled = hdr.copy()
    # Fewer pixels along each axis ...
    for key in ("NAXIS1", "NAXIS2"):
        scaled[key] //= factor
    # ... each one covering a proportionally larger angular step.
    for key in ("CDELT1", "CDELT2"):
        scaled[key] *= factor
    return scaled
[docs]
def wcs2D(hdr):
    """
    Build a 2D WCS object from the celestial keywords of a FITS header.

    Only CRPIXi, CDELTi, CRVALi and CTYPEi (i = 1, 2) are read; any other
    WCS keyword present in the header is ignored.

    Parameters
    ----------
    hdr : dict-like
        FITS header containing the WCS keywords listed above.

    Returns
    -------
    w : astropy.wcs.WCS
        2D WCS object.
    """
    def axis_pair(keyword):
        # Values of `keyword` for axis 1 and axis 2, in that order.
        return [hdr[keyword + '1'], hdr[keyword + '2']]

    w = wcs.WCS(naxis=2)
    w.wcs.crpix = axis_pair('CRPIX')
    w.wcs.cdelt = np.array(axis_pair('CDELT'))
    w.wcs.crval = axis_pair('CRVAL')
    w.wcs.ctype = axis_pair('CTYPE')
    return w
[docs]
def apodize(radius, shape):
    """
    Create a cosine taper for edge apodization (after J.-F. Robitaille's package).

    Parameters
    ----------
    radius : float
        Fraction of each axis kept outside the rising edge; must be strictly
        between 0 and 1. The taper width along an axis of length n is
        n - fix(radius * n) pixels on each side.
    shape : tuple
        Shape of the taper (ny, nx).

    Returns
    -------
    taper : ndarray or None
        Array of shape (ny, nx), ready to be multiplied with an image to
        apodize its edges. It is the outer product of two 1D quarter-cosine
        tapers. If `radius` is out of range an error message is printed and
        None is returned.
    """
    ny = shape[0]
    nx = shape[1]

    if (radius >= 1) or (radius <= 0.):
        print('Error: radius must be lower than 1 and greater than 0.')
        return None

    # Width (in pixels) of the tapered band on each side of each axis.
    dni = int(nx - np.fix(radius * nx))
    dnj = int(ny - np.fix(radius * ny))

    tap1d_x = _edge_taper_1d(nx, dni)
    tap1d_y = _edge_taper_1d(ny, dnj)

    # taper[j, i] = tap1d_y[j] * tap1d_x[i]
    return np.outer(tap1d_y, tap1d_x)


def _edge_taper_1d(n, width):
    """Return a length-n window of ones with quarter-cosine ramps of `width` pixels at both ends."""
    taper = np.ones(n)
    if width == 1:
        # A one-pixel ramp has no interior sample: the general formula would
        # divide by width - 1 == 0 and fill the edges with NaN. Taper the
        # single edge pixel to zero, like the outermost pixel of wider ramps.
        taper[:1] = 0.
        taper[n - 1:] = 0.
        return taper
    t = 1. * np.arange(width) / (width - 1)  # runs from 0 to 1 across the ramp
    taper[0:width] = np.cos(3. * np.pi / 2. + np.pi / 2. * t)  # rising edge, 0 -> 1
    taper[n - width:] = np.cos(0. + np.pi / 2. * t)            # falling edge, 1 -> 0
    return taper
[docs]
def ROHSA_kernel():
    """
    Return the 3x3 Laplacian-like stencil used in ROHSA.

    Returns
    -------
    kernel : ndarray
        3x3 float array: 4 at the center, -1 at the four edge-adjacent
        neighbours, 0 at the corners, all divided by 4.
    """
    stencil = np.zeros((3, 3))
    stencil[1, :] = -1   # left / right neighbours
    stencil[:, 1] = -1   # top / bottom neighbours
    stencil[1, 1] = 4    # central weight
    return stencil / 4.
[docs]
def laplacian(shape):
    """
    Embed the ROHSA Laplacian kernel at the center of a zero map.

    Parameters
    ----------
    shape : tuple
        Shape of the output Laplacian map (ny, nx); both must be >= 3.

    Returns
    -------
    kernel_map : ndarray
        Array of the given shape, zero everywhere except for the 3x3 kernel
        returned by `ROHSA_kernel`, whose central weight sits at pixel
        (ny // 2, nx // 2).

    Raises
    ------
    ValueError
        If the map is too small to hold the 3x3 kernel.
    """
    ny, nx = shape
    if ny < 3 or nx < 3:
        # The slice assignment below would fail with an opaque broadcast error.
        raise ValueError(f"shape must be at least (3, 3) to hold the kernel, got {shape}")

    kernel_map = np.zeros(shape)
    center_y = ny // 2
    center_x = nx // 2
    kernel_map[center_y - 1:center_y + 2, center_x - 1:center_x + 2] = ROHSA_kernel()
    return kernel_map
[docs]
def ROHSA_bounds(data_shape, lb_amp, ub_amp):
    """
    Build per-parameter (lower, upper) bounds for the ROHSA optimizer.

    Parameters
    ----------
    data_shape : tuple
        Shape of the model parameter array.
    lb_amp : float
        Lower bound applied to every parameter.
    ub_amp : float
        Upper bound applied to every parameter.

    Returns
    -------
    bounds : ndarray
        Array of shape (N, 2) with N = np.prod(data_shape); row k holds the
        (lower, upper) pair of the k-th flattened parameter.
    """
    lower = np.full(data_shape, lb_amp).reshape(-1)
    upper = np.full(data_shape, ub_amp).reshape(-1)
    # One row per parameter: column 0 = lower bound, column 1 = upper bound.
    return np.stack((lower, upper), axis=1)
[docs]
def gauss_beam(sigma, shape, FWHM=False):
    """
    Generate a circularly symmetric 2D Gaussian kernel of unit sum.

    Parameters
    ----------
    sigma : float
        Standard deviation of the Gaussian in pixels (or FWHM if `FWHM=True`).
    shape : tuple
        Shape of the output map (ny, nx).
    FWHM : bool, optional
        If True, `sigma` is interpreted as the FWHM and converted with
        sigma = FWHM / (2 sqrt(2 ln 2)). Default is False.

    Returns
    -------
    gauss : ndarray
        Array of shape (ny, nx), normalized so that it sums to 1 and peaked
        at pixel (ny // 2, nx // 2).
    """
    ny, nx = shape

    # Pixel offsets from the central pixel n // 2 of each axis (n/2 for even
    # n, (n-1)/2 for odd n). Each axis must be centred with its own length,
    # otherwise the beam is off-centre on non-square maps.
    dx = np.arange(nx) - nx // 2
    dy = np.arange(ny) - ny // 2
    r2 = dy[:, np.newaxis]**2. + dx[np.newaxis, :]**2.

    if FWHM:
        sigma = sigma / (2. * np.sqrt(2. * np.log(2.)))

    gauss = np.exp(-0.5 * r2 / sigma**2.)
    gauss /= np.sum(gauss)
    return gauss
[docs]
def get_grid(shape_input_tensor, wcs_in, wcs_out, shape_out):
    """
    Create a normalized sampling grid for spatial reprojection in PyTorch.

    Every pixel of the output image is mapped through the two WCS objects to
    its (x, y) position in the input image, and that position is rescaled to
    [-1, 1] as x / (W_in - 1) * 2 - 1 and y / (H_in - 1) * 2 - 1.

    Parameters
    ----------
    shape_input_tensor : tuple
        Shape of the input tensor: (B, C, H_in, W_in).
    wcs_in : astropy.wcs.WCS
        WCS of the input image.
    wcs_out : astropy.wcs.WCS
        Target WCS for the output grid.
    shape_out : tuple
        Shape of the output image (H_out, W_out).

    Returns
    -------
    grid : torch.Tensor
        float32 tensor of shape (1, H_out, W_out, 2). grid[..., 0] is the
        normalized x (width) coordinate and grid[..., 1] the normalized y
        (height) coordinate, the component order expected by
        torch.nn.functional.grid_sample. The normalization maps pixel
        centres 0 and n - 1 to -1 and +1.
        NOTE(review): this matches grid_sample's align_corners=True
        convention; confirm that callers pass that flag.
    """
    h_out, w_out = shape_out[0], shape_out[1]
    h_in, w_in = shape_input_tensor[2], shape_input_tensor[3]

    # Output pixel coordinates; with indexing='xy' both arrays are (H_out, W_out).
    x_out, y_out = torch.meshgrid(
        torch.arange(w_out, dtype=torch.float32),
        torch.arange(h_out, dtype=torch.float32),
        indexing='xy'
    )

    # Output pixel -> world -> input pixel. Results come back in (x, y) order.
    x_in, y_in = pixel_to_pixel(wcs_out, wcs_in, x_out.numpy(), y_out.numpy())

    # x runs along the width and y along the height, so each must be
    # normalized by its own axis length (this only matters when H_in != W_in).
    x_norm = torch.as_tensor(x_in, dtype=torch.float32) / (w_in - 1) * 2 - 1
    y_norm = torch.as_tensor(y_in, dtype=torch.float32) / (h_in - 1) * 2 - 1

    # (H_out, W_out, 2) with (x, y) last, plus a leading batch axis.
    return torch.stack((x_norm, y_norm), dim=-1).unsqueeze(0)
[docs]
def insert_elliptical_gaussian_source(shape, cell_size, flux_jy=1.0,
                                      fwhm_maj_arcsec=15.0, fwhm_min_arcsec=7.5,
                                      pa_deg=0.0, center=None):
    """
    Generate a sky model containing a single elliptical Gaussian source.

    Parameters
    ----------
    shape : tuple
        Output image shape (ny, nx).
    cell_size : float
        Pixel size in arcsec.
    flux_jy : float, optional
        Total integrated flux of the source in Jy. Default is 1.0.
    fwhm_maj_arcsec : float, optional
        FWHM of the major axis in arcsec. Default is 15.0.
    fwhm_min_arcsec : float, optional
        FWHM of the minor axis in arcsec. Default is 7.5.
    pa_deg : float, optional
        Angle of the major axis in degrees, measured from the x-axis towards
        the y-axis. Default is 0.0.
    center : tuple, optional
        Pixel coordinates of the source center (y, x). Default is
        (ny // 2, nx // 2).

    Returns
    -------
    sky_model : ndarray
        2D float32 image in surface-brightness units (Jy/arcsec²): the sum
        over all pixels times cell_size**2 equals `flux_jy`.
    """
    ny, nx = shape
    cy, cx = (ny // 2, nx // 2) if center is None else center

    # FWHM (arcsec) -> sigma (pixels)
    sigma_major = fwhm_maj_arcsec / 2.3548 / cell_size
    sigma_minor = fwhm_min_arcsec / 2.3548 / cell_size

    # Pixel offsets from the source center
    rows, cols = np.indices((ny, nx))
    dx = cols - cx
    dy = rows - cy

    # Project the offsets onto the major / minor axes of the ellipse
    angle = np.deg2rad(pa_deg)
    cos_a = np.cos(angle)
    sin_a = np.sin(angle)
    along_major = dx * cos_a + dy * sin_a
    along_minor = dy * cos_a - dx * sin_a

    profile = np.exp(
        -0.5 * ((along_major / sigma_major)**2 + (along_minor / sigma_minor)**2)
    )
    # Normalize so that sum(profile) * pixel area = 1, i.e. per arcsec²
    profile = profile / (np.sum(profile) * cell_size**2)

    return (profile * flux_jy).astype(np.float32)
[docs]
def fit_elliptical_gaussian(cutout, pixel_scale_arcsec=1.0, plot=True):
    """
    Fit an elliptical Gaussian to an image cutout.

    Returns the integrated flux, the major/minor FWHM in arcsec and the
    orientation of the major axis, and optionally displays the fit.

    Parameters
    ----------
    cutout : 2D ndarray
        Image array containing a single source, in units of Jy/arcsec^2.
    pixel_scale_arcsec : float
        Pixel size in arcsec/pixel.
    plot : bool, optional
        If True (default), show a two-panel figure with the cutout and the
        fitted model contours via `plt.show()`. Set to False for
        non-interactive use.

    Returns
    -------
    flux : float
        Integrated flux in Jy, computed as amplitude * 2 pi sigma_x sigma_y
        with the sigmas in arcsec.
    Bmaj : float
        FWHM of major axis in arcsec.
    Bmin : float
        FWHM of minor axis in arcsec.
    theta : float
        Angle of the major axis in degrees (CCW from +x). This is the fitted
        model angle, plus 90 degrees when the fitted y-axis is the longer
        one; it is not wrapped into any particular range.
    """
    y, x = np.mgrid[:cutout.shape[0], :cutout.shape[1]]

    # Initial guess: circular Gaussian at the cutout center
    sigma_guess = 1.5
    gauss_init = models.Gaussian2D(
        amplitude=np.max(cutout),
        x_mean=cutout.shape[1] / 2,
        y_mean=cutout.shape[0] / 2,
        x_stddev=sigma_guess,
        y_stddev=sigma_guess,
        theta=0.0
    )
    fitter = fitting.LevMarLSQFitter()
    fitted = fitter(gauss_init, x, y, cutout)

    # Fitted widths: pixels -> arcsec
    sigma_x_arcsec = fitted.x_stddev.value * pixel_scale_arcsec
    sigma_y_arcsec = fitted.y_stddev.value * pixel_scale_arcsec
    theta = np.rad2deg(fitted.theta.value)  # radians -> degrees

    # FWHM = 2.3548 * sigma
    FWHM_x = 2.3548 * sigma_x_arcsec
    FWHM_y = 2.3548 * sigma_y_arcsec

    # Integrated flux in Jy = amp (Jy/arcsec²) × 2πσxσy [in arcsec²]
    flux = fitted.amplitude.value * 2 * np.pi * sigma_x_arcsec * sigma_y_arcsec

    # Ensure Bmaj ≥ Bmin
    if FWHM_x >= FWHM_y:
        Bmaj, Bmin = FWHM_x, FWHM_y
    else:
        Bmaj, Bmin = FWHM_y, FWHM_x
        theta += 90.0  # the major axis is the model's y-axis

    if plot:
        _plot_gaussian_fit(cutout, fitted(x, y), flux, Bmaj, Bmin, theta)

    return flux, Bmaj, Bmin, theta


def _plot_gaussian_fit(cutout, model_data, flux, Bmaj, Bmin, theta):
    """Show the cutout next to the cutout overlaid with the fitted model contours."""
    plt.figure(figsize=(10, 4))

    plt.subplot(1, 2, 1)
    plt.imshow(cutout, origin='lower', cmap='inferno')
    plt.colorbar(label='Jy/arcsec$^{2}$')
    plt.title('Input Cutout')

    plt.subplot(1, 2, 2)
    image = plt.imshow(cutout, origin='lower', cmap='inferno')
    plt.contour(model_data, levels=5, colors='white', linewidths=1)
    # Bind the colorbar to the image explicitly so that it does not depend on
    # which artist pyplot considers current after the contour call.
    plt.colorbar(image, label='Jy/arcsec²')
    plt.title('Fitted Gaussian Contours')

    plt.suptitle(f"Flux = {flux:.3f} Jy Bmaj = {Bmaj:.2f}\" Bmin = {Bmin:.2f}\" PA = {theta:.1f}°")
    plt.tight_layout()
    plt.show()