# Source code for ivis.imager

# -*- coding: utf-8 -*-
"""
Imager module for joint deconvolution using GPU-accelerated optimization.

This module provides the `Imager3D` and `Imager` classes, which perform
non-linear optimization combining interferometric and single-dish data.

Author: Antoine Marchal
"""
import os
import glob
import sys
import numpy as np
from astropy import units as u
from astropy import wcs
from astropy.coordinates import SkyCoord
from astropy.io import fits
from scipy import optimize
from radio_beam import Beam
import torch
from tqdm import tqdm as tqdm
from numpy.fft import fft2
from torch.fft import fft2 as tfft2
from dataclasses import dataclass
from reproject import reproject_interp
from scipy.optimize import fmin_l_bfgs_b
import time

from ivis.logger import logger
from ivis.utils import dunits, dutils
from ivis.optim.solvers import (
    optimize_scipy_lbfgsb,
    optimize_torch_cg,
    optimize_torch_fista,
    optimize_torch_lbfgs,
)

#------------------------------------
#------------ Imager3D --------------
#------------------------------------
class Imager3D:
    """
    GPU-accelerated imager for joint deconvolution of interferometric and
    single-dish data, using the VisIData dataclass.

    Parameters
    ----------
    vis_data : VisIData
        Visibility container forwarded to the model operators and solvers.
        Its ``frequency`` attribute is read for the 'K' unit conversion.
    pb : array_like
        Primary-beam model (cast to float32 before use).
    grid : array_like
        Grid array for SIN projection evaluation (cast to float32 before use).
    sd : array_like
        Single-dish map used for the short-spacing constraint.
    beam_sd : radio_beam.Beam
        Beam of the single-dish map.
    hdr : mapping
        FITS header; ``CDELT2``, ``NAXIS1`` and ``NAXIS2`` are read.
    init_params : ndarray
        Initial image parameters (not flattened); their shape is used as the
        parameter shape by the solvers.
    max_its : int
        Maximum number of iterations for the optimizer.
    lambda_sd : float
        Weight of the single-dish term; 0 disables the short-spacing correction.
    positivity : bool
        Whether to enforce a positivity constraint (evaluated for truthiness).
    cost_device, optim_device : str, int or torch.device, optional
        Device specifications resolved through ``dutils.get_device``
        (default "auto").
    beam_workers : int, optional
        Number of beam workers, forwarded to the model objective.
    """

    # Solver names accepted by ``process`` for each positivity setting.
    # With positivity enabled, "LBFGS" selects the SciPy L-BFGS-B back-end.
    _BOUNDED_SOLVERS = ("LBFGS", "FISTA")
    _UNBOUNDED_SOLVERS = ("LBFGS", "CG", "FISTA")

    def __init__(self, vis_data, pb, grid, sd, beam_sd, hdr, init_params,
                 max_its, lambda_sd, positivity, cost_device="auto",
                 optim_device="auto", beam_workers=0):
        self.vis_data = vis_data
        self.pb = pb
        self.grid = grid
        self.sd = sd
        self.beam_sd = beam_sd
        self.hdr = hdr
        self.init_params = init_params
        self.max_its = max_its
        self.lambda_sd = lambda_sd
        self.positivity = positivity
        self.beam_workers = beam_workers

        # Device resolution (and its logging) lives in dutils.
        self.cost_device = dutils.get_device(cost_device)
        self.optim_device = dutils.get_device(optim_device)

        logger.info("[Initialize Imager3D ]")
        logger.info(f"Number of iterations to be performed by the optimizer: {self.max_its}")
        if self.lambda_sd == 0:
            logger.warning("lambda_sd = 0 — No short-spacing correction.")
        if self.positivity:
            logger.info("Optimizer bounded - Positivity == True")
        else:
            logger.info("Optimizer not bounded - Positivity == False")

    @classmethod
    def _select_solver(cls, solver, positivity):
        """
        Normalize a solver name for the given positivity setting.

        Returns
        -------
        (name, recognized) : tuple of (str, bool)
            ``name`` is the upper-cased solver name when it is supported for
            this positivity setting, otherwise "LBFGS"; ``recognized`` tells
            whether the requested solver was supported as-is.
        """
        name = str(solver).upper()
        supported = cls._BOUNDED_SOLVERS if positivity else cls._UNBOUNDED_SOLVERS
        if name in supported:
            return name, True
        return "LBFGS", False

    def _cell_size(self):
        """Return the pixel size ``CDELT2`` as an astropy Quantity in arcsec."""
        return (self.hdr["CDELT2"] * u.deg).to(u.arcsec)

    def forward_model(self, model, x=None):
        """
        Compute model visibilities with ``model.forward``.

        Parameters
        ----------
        model : object
            Model instance implementing ``forward(...)``.
        x : ndarray or None, optional
            Image parameters to project; ``self.init_params`` when None.

        Returns
        -------
        Whatever ``model.forward`` returns.

        Raises
        ------
        ValueError
            If ``model`` is None.
        """
        if model is None:
            raise ValueError("Must pass a model instance to `forward_model()`.")
        x_model = self.init_params if x is None else x
        return model.forward(
            x=x_model,
            vis_data=self.vis_data,
            pb=np.asarray(self.pb, dtype=np.float32),
            device=self.cost_device,
            cell_size=self._cell_size().value,
            grid_array=np.asarray(self.grid, dtype=np.float32),
        )

    def adjoint_model(self, model, vis=None, return_real=False):
        """
        Apply the model's backward/adjoint operator to visibilities.

        ``model.backward`` is used when the model defines it, otherwise
        ``model.adjoint``. ``vis`` is forwarded unchanged, including when it
        is None; the model is then presumably expected to fall back to the
        visibilities held in ``vis_data`` (confirm in the model implementation).

        Raises
        ------
        ValueError
            If ``model`` is None.
        """
        if model is None:
            raise ValueError("Must pass a model instance to `adjoint_model()`.")
        # Prefer backward() when available, fall back to adjoint().
        operator = model.backward if hasattr(model, "backward") else model.adjoint
        return operator(
            vis=vis,
            vis_data=self.vis_data,
            pb=np.asarray(self.pb, dtype=np.float32),
            device=self.cost_device,
            cell_size=self._cell_size().value,
            grid_array=np.asarray(self.grid, dtype=np.float32),
            x_shape=self.init_params.shape,
            return_real=return_real,
        )

    def process(self, model=None, solver="LBFGS", units="Jy/arcsec^2",
                history_size=10, dtype=torch.float32):
        """
        Run the joint deconvolution and return the image in the requested unit.

        Parameters
        ----------
        model : object
            Model instance consumed by the ``ivis.optim.solvers`` back-ends.
        solver : str, optional
            Case-insensitive solver name.
            - ``positivity`` truthy: "FISTA" (PyTorch, called with
              ``positivity=True``) or "LBFGS" (SciPy L-BFGS-B with [0, inf)
              bounds).
            - ``positivity`` falsy: "LBFGS" (PyTorch), "CG" or "FISTA".
            Unsupported names fall back to "LBFGS" with a warning.
        units : str, optional
            'Jy/arcsec^2' (default), 'K', or (not recommended) 'Jy/beam'.
            Unknown units return the result in Jy/arcsec^2.
        history_size : int, optional
            History size of the PyTorch LBFGS solver (unused by other paths).
        dtype : torch.dtype, optional
            Parameter dtype passed to the PyTorch solvers.

        Devices
        -------
        - optim_device: where the PyTorch parameters/optimizer live.
        - cost_device : where the model objective is evaluated.

        Returns
        -------
        Reconstructed image in the requested unit.

        Raises
        ------
        ValueError
            If ``model`` is None.
        """
        if model is None:
            raise ValueError("Must pass a model instance to `process()`.")

        cost_dev = self.cost_device
        optim_dev = self.optim_device

        # --- Report the requested output unit ---
        if units == "Jy/arcsec^2":
            logger.info("Units of output: Jy/arcsec^2.")
        elif units == "K":
            logger.info("Units of output: K (using frequency in Hz).")
        elif units == "Jy/beam":
            logger.warning(
                "Units of output: Jy/beam. "
                "Note: this is not the preferred unit in IViS, as the effective beam "
                "depends on regularization. Unlike CLEAN, the model is not reconvolved "
                "with a Gaussian restoring beam. "
                "We recommend using 'K' units for diffuse extended emission."
            )
        else:
            logger.error(
                f"Unknown unit '{units}'. Units must be 'Jy/arcsec^2', 'K', or (not recommended) 'Jy/beam'. "
                "Defaulting to Jy/arcsec^2."
            )

        # --- Resolve the solver before any expensive set-up ---
        solver_name, recognized = self._select_solver(solver, self.positivity)
        if not recognized:
            logger.warning(
                f"Solver '{solver}' is not available with positivity={bool(self.positivity)}; "
                "falling back to LBFGS."
            )

        # --- Image/grid params ---
        cell_size = self._cell_size()
        shape = (self.hdr["NAXIS2"], self.hdr["NAXIS1"])
        tapper = dutils.apodize(0.98, shape)

        # --- Fourier-domain regularisation kernel and single-dish beam ---
        fftkernel = abs(fft2(dutils.laplacian(shape)))
        bmaj_pix = self.beam_sd.major.to(u.deg).value / cell_size.to(u.deg).value
        fftbeam = abs(fft2(dutils.gauss_beam(bmaj_pix, shape, FWHM=True)))

        # --- FFT of the single-dish map, scaled by the pixel area (arcsec^2) ---
        fftsd = cell_size.value**2 * tfft2(torch.from_numpy(np.float32(self.sd))).cpu().numpy()

        # --- Inputs shared by every solver ---
        params = dict(
            vis_data=self.vis_data,
            pb=np.asarray(self.pb, dtype=np.float32),
            fftbeam=np.asarray(fftbeam, dtype=np.float32),
            fftsd=np.asarray(fftsd, dtype=np.complex64),
            tapper=np.asarray(tapper, dtype=np.float32),
            lambda_sd=self.lambda_sd,
            fftkernel=np.asarray(fftkernel, dtype=np.float32),
            cell_size=cell_size.value,
            grid_array=np.asarray(self.grid, dtype=np.float32),
            beam_workers=self.beam_workers,
        )
        param_shape = self.init_params.shape

        if self.positivity:
            if solver_name == "FISTA":
                flat = optimize_torch_fista(
                    model=model,
                    x_init=self.init_params,
                    dtype=dtype,
                    max_its=self.max_its,
                    cost_dev=cost_dev,
                    optim_dev=optim_dev,
                    params=params,
                    positivity=True,
                )
                result = flat.reshape(param_shape)
            else:
                # SciPy L-BFGS-B: float64 start point with explicit [0, inf) bounds.
                x0 = self.init_params.ravel().astype(np.float64)
                raw_bounds = dutils.ROHSA_bounds(param_shape, lb_amp=0, ub_amp=np.inf)
                bounds64 = [(float(lo), float(hi)) for (lo, hi) in raw_bounds]
                result = optimize_scipy_lbfgsb(
                    model=model,
                    x0=x0,
                    bounds64=bounds64,
                    param_shape=param_shape,
                    max_its=self.max_its,
                    cost_dev=cost_dev,
                    optim_dev=optim_dev,
                    params=params,
                )
        else:
            if solver_name == "CG":
                flat = optimize_torch_cg(
                    model=model,
                    x_init=self.init_params,
                    dtype=dtype,
                    max_its=self.max_its,
                    cost_dev=cost_dev,
                    optim_dev=optim_dev,
                    params=params,
                )
            elif solver_name == "FISTA":
                flat = optimize_torch_fista(
                    model=model,
                    x_init=self.init_params,
                    dtype=dtype,
                    max_its=self.max_its,
                    cost_dev=cost_dev,
                    optim_dev=optim_dev,
                    params=params,
                    positivity=False,
                )
            else:
                flat = optimize_torch_lbfgs(
                    model=model,
                    x_init=self.init_params,
                    dtype=dtype,
                    history_size=history_size,
                    max_its=self.max_its,
                    cost_dev=cost_dev,
                    optim_dev=optim_dev,
                    params=params,
                )
            result = flat.reshape(param_shape)

        # NOTE: assumes I = 1/2 (XX + YY); with ASKAP's I = XX + YY convention
        # (no 1/2 factor) the output must be multiplied by 2.

        # --- Unit conversion (result is in Jy/arcsec^2) ---
        if units == "Jy/arcsec^2":
            output = result
        elif units == "Jy/beam":
            assumed_fwhm_pix = 3  # hard-coded restoring-beam FWHM, in pixels
            logger.warning(
                f"Converting to Jy/beam assuming a restoring beam of "
                f"{assumed_fwhm_pix} × cell_size = "
                f"{assumed_fwhm_pix * cell_size:.3f} FWHM."
            )
            beam_r = Beam(assumed_fwhm_pix * cell_size,
                          assumed_fwhm_pix * cell_size, 1.e-12 * u.deg)
            output = result * beam_r.sr.to(u.arcsec**2).value
        elif units == "K":
            nu_Hz = self.vis_data.frequency
            output = dunits.jy_per_arcsec2_to_K(result, nu_Hz)
        else:
            logger.warning("Unknown unit type. Returning result in Jy/arcsec^2.")
            output = result

        logger.info("Successful run. Please clap.")
        return output
# #------------------------------------ # #------------ Imager3D -------------- # #------------------------------------ # class Imager3D: # """ # GPU-accelerated imager for joint deconvolution of interferometric # and single-dish data, using the new VisIData dataclass. # """ # def __init__(self, vis_data, pb, grid, sd, beam_sd, hdr, # init_params, max_its, lambda_sd, positivity, cost_device="auto", # optim_device="auto", beam_workers=0): # self.vis_data = vis_data # self.pb = pb # self.grid = grid # self.sd = sd # self.beam_sd = beam_sd # self.hdr = hdr # self.init_params = init_params # self.max_its = max_its # self.lambda_sd = lambda_sd # self.positivity = positivity # self.beam_workers = beam_workers # self.cost_device = self.get_device(cost_device) # self.optim_device = self.get_device(optim_device) # logger.info("[Initialize Imager3D ]") # logger.info(f"Number of iterations to be performed by the optimizer: {self.max_its}") # if self.lambda_sd == 0: # logger.warning("lambda_sd = 0 — No short-spacing correction.") # if self.positivity == True: # logger.info('Optimizer bounded - Positivity == True') # # logger.warning('Optimizer bounded - Because there is noise in the data, it is generally not recommanded to add a positivity constaint.') # else: # logger.info('Optimizer not bounded - Positivity == False') # @staticmethod # def get_device(spec="auto") -> torch.device: # """ # Resolve a compute device from a flexible spec: # - "auto" -> cuda:0 if available; else mps (Apple Silicon); else cpu # - "cpu" -> cpu # - "cuda" -> cuda:0 (if available) # - "cuda:i" -> that specific GPU index (if available) # - "mps" -> Apple Metal Performance Shaders (if available) # - int i -> cuda:i if CUDA available; else cpu # - torch.device -> returned as-is # """ # # passthrough # if isinstance(spec, torch.device): # return spec # # int -> cuda:i if possible # if isinstance(spec, int): # if spec >= 0 and torch.cuda.is_available(): # idx = int(spec) # if idx < 
torch.cuda.device_count(): # logger.info(f"Using GPU cuda:{idx} ({torch.cuda.get_device_name(idx)})") # return torch.device(f"cuda:{idx}") # else: # logger.warning(f"Requested cuda:{idx} but only {torch.cuda.device_count()} device(s); using cuda:0") # return torch.device("cuda:0") # logger.info("CUDA unavailable or invalid index; using CPU.") # return torch.device("cpu") # # string path # if isinstance(spec, str): # s = spec.strip().lower() # if s in ("auto", ""): # if torch.cuda.is_available(): # logger.info(f"Using GPU (auto) cuda:0 ({torch.cuda.get_device_name(0)})") # return torch.device("cuda:0") # if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): # logger.info("Using Apple MPS (auto).") # return torch.device("mps") # logger.info("Using CPU (auto).") # return torch.device("cpu") # if s == "cpu": # logger.info("Using CPU (user-specified).") # return torch.device("cpu") # if s == "mps": # if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): # logger.info("Using Apple MPS (user-specified).") # return torch.device("mps") # logger.warning("MPS requested but not available; falling back to CPU.") # return torch.device("cpu") # if s.startswith("cuda"): # if not torch.cuda.is_available(): # logger.warning("CUDA requested but not available; falling back to CPU.") # return torch.device("cpu") # # allow "cuda" or "cuda:i" # idx = 0 # if ":" in s: # try: # idx = int(s.split(":", 1)[1]) # except Exception: # logger.warning(f"Could not parse device index from '{spec}', defaulting to cuda:0.") # idx = 0 # if idx < torch.cuda.device_count(): # logger.info(f"Using GPU cuda:{idx} ({torch.cuda.get_device_name(idx)})") # return torch.device(f"cuda:{idx}") # else: # logger.warning(f"Requested cuda:{idx} but only {torch.cuda.device_count()} device(s); using cuda:0") # return torch.device("cuda:0") # logger.warning(f"Unrecognized device spec '{spec}'; defaulting to CPU.") # return torch.device("cpu") # def forward_model(self, model): # """ # 
Compute model visibilities from the current image parameters # using the given model's forward operator. # """ # if model is None: # raise ValueError("Must pass a model instance to `forward_model()`.") # cell_size = (self.hdr["CDELT2"] * u.deg).to(u.arcsec) # pb_native = np.asarray(self.pb, dtype=np.float32) # grid_native = np.asarray(self.grid, dtype=np.float32) # return model.forward( # x=self.init_params, # vis_data=self.vis_data, # pb=pb_native, # device=self.cost_device, # cell_size=cell_size.value, # grid_array=grid_native # ) # def process(self, model=None, units="Jy/arcsec^2", # history_size=10, dtype=torch.float32): # """ # Devices # ------- # - ``optim_device``: where PyTorch LBFGS params/optimizer live # - ``cost_device`` : where ``model.objective()`` runs # Rules # ----- # - ``positivity=True`` (bounded): SciPy L-BFGS-B (CPU-only). # - ``positivity=False`` (unbounded): PyTorch LBFGS on optim_device; cost on cost_device. # Notes # ----- # ``objective()`` is expected to call ``backward()`` internally (unchanged). # """ # if model is None: # raise ValueError("Must pass a model instance to `process()`.") # cost_dev = self.cost_device # optim_dev = self.optim_device # # --- Units logger --- # if units == "Jy/arcsec^2": # logger.info("Units of output: Jy/arcsec^2.") # elif units == "K": # logger.info("Units of output: K (using frequency in Hz).") # elif units == "Jy/beam": # logger.warning( # "Units of output: Jy/beam. " # "Note: this is not the preferred unit in IViS, as the effective beam " # "depends on regularization. Unlike CLEAN, the model is not reconvolved " # "with a Gaussian restoring beam. " # "We recommend using 'K' units for diffuse extended emission." # ) # else: # logger.error( # f"Unknown unit '{units}'. Units must be 'Jy/arcsec^2', 'K', or (not recommended) 'Jy/beam'. " # "Defaulting to Jy/arcsec^2." 
# ) # # ---- local helper: format GPU mem for any CUDA device ---- # def _gpu_mem_str(dev: torch.device) -> str: # if dev.type != "cuda": # return "" # idx = dev.index if (dev.index is not None) else torch.cuda.current_device() # alloc = torch.cuda.memory_allocated(idx) / 1024**2 # reserved = torch.cuda.memory_reserved(idx) / 1024**2 # peak = torch.cuda.max_memory_allocated(idx) / 1024**2 # total = torch.cuda.get_device_properties(idx).total_memory / 1024**2 # return f"GPU[{idx}]: {alloc:.2f} MB alloc, {reserved:.2f} MB reserved, {peak:.2f} MB peak, {total:.2f} MB total" # # --- Image/grid params --- # cell_size = (self.hdr["CDELT2"] * u.deg).to(u.arcsec) # shape = (self.hdr["NAXIS2"], self.hdr["NAXIS1"]) # tapper = dutils.apodize(0.98, shape) # # --- FFT beam for reg --- # kernel_map = dutils.laplacian(shape) # fftkernel = abs(fft2(kernel_map)) # bmaj_pix = self.beam_sd.major.to(u.deg).value / cell_size.to(u.deg).value # beam = dutils.gauss_beam(bmaj_pix, shape, FWHM=True) # fftbeam = abs(fft2(beam)) # # --- FFT single-dish --- # fftsd = cell_size.value**2 * tfft2(torch.from_numpy(np.float32(self.sd))).cpu().numpy() # # --- Common params --- # params = dict( # vis_data=self.vis_data, # pb=np.asarray(self.pb, dtype=np.float32), # fftbeam=np.asarray(fftbeam, dtype=np.float32), # fftsd=np.asarray(fftsd, dtype=np.complex64), # tapper=np.asarray(tapper, dtype=np.float32), # lambda_sd=self.lambda_sd, # fftkernel=np.asarray(fftkernel, dtype=np.float32), # cell_size=cell_size.value, # grid_array=np.asarray(self.grid, dtype=np.float32), # beam_workers=self.beam_workers, # ) # param_shape = self.init_params.shape # # --- SciPy path (positivity -> bounded -> CPU only) # if getattr(self, "positivity", False): # from scipy.optimize import fmin_l_bfgs_b # if optim_dev.type == "cuda": # logger.info("positivity=True with optim_device on CUDA → falling back to CPU for SciPy L-BFGS-B.") # x0 = self.init_params.ravel().astype(np.float64) # raw_bounds = 
dutils.ROHSA_bounds(param_shape, lb_amp=0, ub_amp=np.inf) # bounds64 = [(float(lo), float(hi)) for (lo, hi) in raw_bounds] # def fun_and_grad(x): # # cost on cost_dev, but SciPy itself is CPU # f, g = model.loss( # x, shape=param_shape, device=cost_dev, # jac=True, **params # ) # return float(f), np.ascontiguousarray(g, dtype=np.float64) # logger.info(f"Starting optimisation: SciPy L-BFGS-B (CPU optimizer), cost on {cost_dev}") # x_opt, f_opt, info = fmin_l_bfgs_b( # fun_and_grad, x0, bounds=bounds64, # m=7, pgtol=1e-8, factr=1e7, maxls=20, # maxiter=int(self.max_its), iprint=25, # ) # result = x_opt.reshape(param_shape) # # --- PyTorch path (unbounded LBFGS on optim_device; cost on cost_device) # else: # # Reset peak stats on any CUDA devices we might use # for dev in [cost_dev, optim_dev]: # if dev.type == "cuda": # idx = dev.index if dev.index is not None else torch.cuda.current_device() # torch.cuda.reset_peak_memory_stats(idx) # logger.info( # f"Starting optimisation: PyTorch LBFGS on {optim_dev} (unconstrained); " # f"cost on {cost_dev}" # ) # x_param = torch.tensor(self.init_params, dtype=dtype, device=optim_dev, requires_grad=True) # opt = torch.optim.LBFGS( # [x_param], # lr=1.0, # max_iter=int(self.max_its), # history_size=history_size, # line_search_fn="strong_wolfe", # tolerance_grad=1e-8, # tolerance_change=0.0, # ) # def closure(): # opt.zero_grad(set_to_none=True) # if cost_dev == optim_dev: # # Evaluate directly on x_param; objective() will call backward() # loss = model.objective( # x_param, device=cost_dev, **params # ) # if x_param.grad is None: # raise RuntimeError("objective() did not produce gradients on x_param.") # # --- log now (graph already freed by backward, but param still live) --- # mem_bits = [] # if cost_dev.type == "cuda": # torch.cuda.synchronize(cost_dev) # mem_bits.append(_gpu_mem_str(cost_dev)) # if optim_dev.type == "cuda" and (optim_dev.index != cost_dev.index): # torch.cuda.synchronize(optim_dev) # 
mem_bits.append(_gpu_mem_str(optim_dev)) # mem_info = " | ".join(mem_bits) # else: # # Cross-device: leaf copy on cost_dev; DO NOT free before logging # x_for_cost = x_param.detach().to(cost_dev).requires_grad_(True) # loss = model.objective( # x_for_cost, device=cost_dev, **params # ) # if x_for_cost.grad is None: # raise RuntimeError("objective() did not produce gradients on x_for_cost.") # # --- log BEFORE copying grad back / freeing x_for_cost --- # mem_bits = [] # if cost_dev.type == "cuda": # torch.cuda.synchronize(cost_dev) # mem_bits.append(_gpu_mem_str(cost_dev)) # if optim_dev.type == "cuda" and (optim_dev.index != cost_dev.index): # torch.cuda.synchronize(optim_dev) # mem_bits.append(_gpu_mem_str(optim_dev)) # mem_info = " | ".join(mem_bits) # # Now move grad off GPU and free the leaf # x_param.grad = x_for_cost.grad.to(optim_dev) # del x_for_cost # release AFTER logging so allocated != 0 # logger.info( # f"[PID {os.getpid()}] Iter cost: {float(loss.detach().cpu()):.6e} " # f"(optim_dev={optim_dev}, cost_dev={cost_dev})" # + (f" | {mem_info}" if mem_info else "") # ) # return loss # objective already did backward() # import time # if cost_dev.type == "cuda": # torch.cuda.synchronize(cost_dev) # t0 = time.perf_counter() # final_loss = opt.step(closure) # if cost_dev.type == "cuda": # torch.cuda.synchronize(cost_dev) # elapsed = time.perf_counter() - t0 # if optim_dev == cost_dev: # end_mem_info = _gpu_mem_str(cost_dev) if cost_dev.type == "cuda" else "" # else: # end_mem_info = " | ".join( # [_gpu_mem_str(d) for d in (cost_dev, optim_dev) if d.type == "cuda"] # ) # logger.info( # f"[Timing] LBFGS (optim_dev={optim_dev}, cost_dev={cost_dev}) " # f"took {elapsed:.2f} s; final loss={float(final_loss):.6g}" # + (f" | {end_mem_info}" if end_mem_info else "") # ) # result = x_param.detach().cpu().numpy().reshape(param_shape) # logger.warning("If you are using ASKAP's convention I = XX + YY (with no 1/2 factor) then multiply the output by 2. 
Here assuming I = 1/2 (XX + YY).") # # --- Unit conversion --- # if units == "Jy/arcsec^2": # output = result # elif units == "Jy/beam": # assumed_fwhm_pix = 3 #hard-coded value # logger.warning( # f"Converting to Jy/beam assuming a restoring beam of " # f"{assumed_fwhm_pix} × cell_size = " # f"{assumed_fwhm_pix * cell_size:.3f} FWHM." # ) # beam_r = Beam(assumed_fwhm_pix * cell_size, # assumed_fwhm_pix * cell_size, 1.e-12 * u.deg) # output = result * beam_r.sr.to(u.arcsec**2).value # elif units == "K": # nu_Hz = self.vis_data.frequency # output = dunits.jy_per_arcsec2_to_K(result, nu_Hz) # else: # logger.warning("Unknown unit type. Returning result in Jy/arcsec^2.") # output = result # #Success logger # logger.info("Successful run. Please clap.") # return output #------------------------------------ #------------ Imager ---------------- #------------------------------------
class Imager:
    """
    A GPU-accelerated imager for joint deconvolution of interferometric
    and single-dish data.

    Parameters
    ----------
    vis_data : object
        Visibility data structure. The attributes read by this class are
        ``uu``, ``vv``, ``ww``, ``data``, ``beam``, ``sigma``, ``coords``
        and ``frequency``.
    pb : ndarray
        Primary beam model array.
    grid : ndarray
        Grid array for SIN projection evaluation.
    sd : ndarray
        Single-dish map used for zero-spacing constraint.
    beam_sd : radio_beam.Beam
        Beam object for the single-dish map.
    hdr : dict
        FITS header; ``CDELT2``, ``NAXIS1`` and ``NAXIS2`` are read.
    init_params : ndarray
        Initial parameters (not flattened).
    max_its : int
        Maximum number of iterations for the optimizer.
    lambda_sd : float
        Regularization strength for the single-dish constraint.
    positivity : bool
        Whether to enforce a positivity constraint during optimization.
    device : int or str
        0 requests the GPU (cuda:0, with CPU fallback); any other value
        selects the CPU.
    beam_workers : int
        Number of workers for parallel beam convolution.
    """

    def __init__(self, vis_data, pb, grid, sd, beam_sd, hdr, init_params,
                 max_its, lambda_sd, positivity, device, beam_workers):
        super().__init__()
        self.vis_data = vis_data
        self.pb = pb
        self.grid = grid
        self.sd = sd
        self.beam_sd = beam_sd
        self.hdr = hdr
        self.init_params = init_params
        self.max_its = max_its
        self.lambda_sd = lambda_sd
        self.positivity = positivity
        self.beam_workers = beam_workers

        logger.info("[Initialize Imager ]")
        logger.info(f"Number of iterations to be performed by the optimizer: {self.max_its}")

        # Logger for hyper-parameters
        if self.lambda_sd == 0:
            logger.warning("lambda_sd = 0 - No short spacing correction (ignoring single dish data).")

        # get_device() falls back on CPU if CUDA is not found and logs the
        # selected device (including the GPU name).
        self.device = self.get_device(device)
        # get_device() returns a torch.device, which never compares equal to a
        # plain int or string, so test its .type instead.
        if self.device.type == "cpu":
            logger.info(f"Using {self.beam_workers} workers for beam parallelisation.")

    @staticmethod
    def get_device(user_device):
        """
        Select the compute device (CPU or GPU) from the user request.

        Parameters
        ----------
        user_device : int or str
            0 to request the GPU (cuda:0); any other value selects the CPU.

        Returns
        -------
        torch.device
            cuda:0 when requested and usable, otherwise the CPU.
        """
        if user_device != 0:
            logger.info("Using CPU.")
            return torch.device("cpu")
        # GPU requested: any RuntimeError (CUDA missing or failing to
        # initialise) results in a CPU fallback with a warning.
        try:
            if not torch.cuda.is_available():
                raise RuntimeError("CUDA not available.")
            device = torch.device("cuda:0")
            logger.info(f"Using GPU device: {torch.cuda.get_device_name(0)}")
        except RuntimeError as e:
            logger.warning(f"{e} Falling back on CPU.")
            device = torch.device("cpu")
        return device

    def process_beam_positions(self):
        """
        Locate each beam's visibilities in the dataset.

        Assumes ``vis_data.beam`` holds non-negative integer beam indices
        (required by ``np.bincount``) and that there are
        ``len(vis_data.coords)`` beams.

        Returns
        -------
        idmin : ndarray of int
            Index of the first visibility of each beam (0 for a beam with no
            visibilities).
        idmax : ndarray of int
            Number of visibilities of each beam. Despite its name this is a
            count, not an index: when a beam's visibilities are stored
            contiguously, its exclusive end index is ``idmin + idmax``.
        """
        nb = len(self.vis_data.coords)
        beams = self.vis_data.beam

        # First occurrence of every beam index present in the data.
        present, first_idx = np.unique(beams, return_index=True)
        # Number of visibilities per beam (0 for absent beams).
        counts = np.bincount(beams, minlength=nb)

        idmin = np.zeros(nb, dtype=int)
        idmax = np.zeros(nb, dtype=int)
        idmin[present] = first_idx
        idmax[present] = counts[present]
        return idmin, idmax

    def forward_model(self, model, x=None):
        """
        Compute model visibilities from an input image using the provided
        model's forward operator.

        Parameters
        ----------
        model : object
            A model instance (e.g., ClassicIViS) that implements a
            ``.forward(...)`` method.
        x : np.ndarray or None, optional
            Image parameters to forward-project. If None, uses
            ``self.init_params``.

        Returns
        -------
        Whatever ``model.forward`` returns (model visibilities).

        Raises
        ------
        ValueError
            If no model is provided.

        Notes
        -----
        ``uu`` and ``vv`` are converted to radians per pixel using ``CDELT2``;
        ``ww`` is passed unconverted, as in ``process``.
        """
        if model is None:
            raise ValueError("You must pass a model instance to `forward_model()`.")

        cell_size = (self.hdr["CDELT2"] * u.deg).to(u.arcsec)

        # Convert λ to radians per pixel
        uu_radpix = dunits._lambda_to_radpix(self.vis_data.uu, cell_size)
        vv_radpix = dunits._lambda_to_radpix(self.vis_data.vv, cell_size)

        idmin, idmax = self.process_beam_positions()
        x_model = self.init_params if x is None else x

        return model.forward(
            x=x_model,
            data=self.vis_data.data,
            uu=uu_radpix,
            vv=vv_radpix,
            ww=self.vis_data.ww,
            pb=np.asarray(self.pb, dtype=np.float32),
            idmina=idmin,
            idmaxa=idmax,
            device=self.device,
            cell_size=cell_size.value,
            grid_array=np.asarray(self.grid, dtype=np.float32),
        )

    def process(self, model=None, units="Jy/arcsec^2", disk=False):
        """
        Run the imaging optimization and return the image in the requested unit.

        Parameters
        ----------
        model : object
            An imaging model implementing ``.loss(x, ...)`` that returns
            ``(cost, gradient)`` as required by ``scipy.optimize.minimize``
            with ``jac=True``.
        units : str
            Output unit: 'Jy/arcsec^2', 'Jy/beam' or 'K'. 'Jy/beam' and 'K'
            assume a restoring beam of 3 x cell_size. Any other value returns
            the result in Jy/arcsec^2 (with a warning).
        disk : bool, optional
            Currently unused.

        Returns
        -------
        result : ndarray
            Restored image (shape ``init_params.shape``) in the requested unit.

        Raises
        ------
        ValueError
            If ``model`` is None.
        TypeError
            If ``model`` has no ``loss`` method.
        """
        # Validate inputs before the expensive pre-computation below.
        if model is None:
            raise ValueError(
                "Missing model input: you must pass a model instance "
                "(e.g., ClassicIViS) to `process()`."
            )
        if not hasattr(model, "loss"):
            raise TypeError(
                "Invalid model type: provided model does not implement a `.loss(x, ...)` "
                "method compatible with scipy.optimize.minimize."
            )
        if units not in ("Jy/arcsec^2", "Jy/beam", "K"):
            logger.warning(
                f"Unknown unit '{units}'. Units must be 'Jy/arcsec^2', 'Jy/beam' or 'K'. "
                "Defaulting to Jy/arcsec^2."
            )

        # Image parameters
        cell_size = (self.hdr["CDELT2"] * u.deg).to(u.arcsec)
        img_shape = (self.hdr["NAXIS2"], self.hdr["NAXIS1"])
        param_shape = self.init_params.shape  # (H, W) or (2, H, W)
        device = self.device

        # Tapper for apodization
        tapper = dutils.apodize(0.98, img_shape)

        # Convert lambda to radians per pixel (ww is passed unconverted)
        uu_radpix = dunits._lambda_to_radpix(self.vis_data.uu, cell_size)
        vv_radpix = dunits._lambda_to_radpix(self.vis_data.vv, cell_size)

        # Fourier-domain Laplacian kernel for the regularization
        fftkernel = abs(fft2(dutils.laplacian(img_shape)))

        # Fourier-domain single-dish beam (FWHM converted to pixels)
        bmaj_pix = self.beam_sd.major.to(u.deg).value / cell_size.to(u.deg).value
        fftbeam = abs(fft2(dutils.gauss_beam(bmaj_pix, img_shape, FWHM=True)))

        # FFT of the single-dish map, scaled by the pixel area (arcsec^2)
        fftsd = cell_size.value**2 * tfft2(torch.from_numpy(np.float32(self.sd))).numpy()

        # Per-beam start indices and visibility counts
        idmin, idmax = self.process_beam_positions()

        # Bounds for the optimisation: [0, inf) with positivity, unbounded otherwise
        lb_amp = 0 if self.positivity else -np.inf
        bounds = dutils.ROHSA_bounds(data_shape=param_shape, lb_amp=lb_amp, ub_amp=np.inf)

        logger.info('Starting optimisation (using LBFGS-B)')
        if self.positivity:
            logger.info('Optimizer bounded - Positivity == True')
            logger.warning('Optimizer bounded - Because there is noise in the data, it is generally not recommanded to add a positivity constaint.')
        else:
            logger.info('Optimizer not bounded - Positivity == False')

        # Type conversions done once, outside the objective
        params = dict(
            beam=np.asarray(self.vis_data.beam, dtype=np.float32),
            fftbeam=np.asarray(fftbeam, dtype=np.float32),
            data=np.asarray(self.vis_data.data, dtype=np.complex64),
            uu=np.asarray(uu_radpix, dtype=np.float32),
            vv=np.asarray(vv_radpix, dtype=np.float32),
            ww=np.asarray(self.vis_data.ww, dtype=np.float32),  # not radpix - original
            pb=np.asarray(self.pb, dtype=np.float32),
            idmina=np.asarray(idmin, dtype=np.int32),
            idmaxa=np.asarray(idmax, dtype=np.int32),
            sigma=np.asarray(self.vis_data.sigma, dtype=np.float32),
            fftsd=np.asarray(fftsd, dtype=np.complex64),
            tapper=np.asarray(tapper, dtype=np.float32),
            lambda_sd=self.lambda_sd,
            fftkernel=np.asarray(fftkernel, dtype=np.float32),
            cell_size=cell_size.value,
            grid_array=np.asarray(self.grid, dtype=np.float32),
            beam_workers=self.beam_workers,
        )
        x0 = self.init_params.ravel().astype(np.float32)

        def objective_flat(x):
            # Returns (cost, gradient): minimize() is called with jac=True.
            return model.loss(x, shape=param_shape, device=device, **params)

        options = {
            'maxiter': self.max_its,
            'maxfun': int(1e6),
            'iprint': 25,
        }

        opt_output = optimize.minimize(
            objective_flat, x0, jac=True, tol=1.e-8,
            bounds=bounds, method='L-BFGS-B', options=options,
        )
        result = np.reshape(opt_output.x, param_shape)
        logger.warning("multiply by 2 for ASKAP.")

        # ---- Unit conversion (result is in Jy/arcsec^2) ----
        if units not in ("Jy/beam", "K"):
            return result

        logger.info("assuming a synthesized beam of 3 x cell_size")
        beam_r = Beam(3 * cell_size, 3 * cell_size, 1.e-12 * u.deg)
        # Jy/arcsec^2 -> Jy/beam
        result_jy_beam = result * beam_r.sr.to(u.arcsec**2).value
        if units == "Jy/beam":
            return result_jy_beam
        # Jy/beam -> K at the first frequency of the dataset
        nu = self.vis_data.frequency[0] * u.Hz
        return (result_jy_beam * u.Jy).to(u.K, u.brightness_temperature(nu, beam_r)).value