Source code for ivis.models.classic3D

import os
import numpy as np
import torch
from torch.fft import fft2 as tfft2

from ivis.logger import logger
from ivis.models.base import BaseModel
from ivis.models.operators import (
    backward_beam,
    forward_beam,
    resolve_pb_grid_lists,
)
from ivis.models.utils.gpu import print_gpu_memory


class Classic3D(BaseModel):
    """Cube model evaluated one (channel, beam) block at a time.

    ``objective`` back-propagates every data block and every regularisation
    term as soon as it has been computed, so only one block's autograd graph
    exists at any moment.  ``loss`` wraps ``objective`` behind a NumPy
    in / NumPy out interface and processes the cube one channel at a time.
    """

    # Keyword arguments of ``objective`` that may carry a leading channel
    # axis (``objective`` indexes them with the channel number when their
    # rank equals the rank of ``x``).
    _PER_CHANNEL_KWARGS = ("fftsd", "fftbeam", "tapper", "fftkernel")

    def __init__(self, lambda_r=1, use_2pi=True, conj_data=True):
        """Store the model settings.

        Parameters
        ----------
        lambda_r : float
            Weight of the ``fftkernel`` regularisation term in ``objective``.
        use_2pi : bool
            Stored on the instance; not read by any method of this class.
        conj_data : bool
            When true, ``objective`` conjugates the data before forming
            the residual.
        """
        self.lambda_r = lambda_r
        self.use_2pi = use_2pi
        self.conj_data = conj_data
        # Handed to ``backward_beam`` as ``cache_store``.
        self._interp_cache = {}

    @classmethod
    def _channel_kwargs(cls, kwargs, c, cube_shape):
        """Return ``kwargs`` with per-channel cubes narrowed to channel ``c``.

        ``loss`` gives ``objective`` a one-channel slice of the cube, and
        ``objective`` then indexes every per-channel array with the local
        channel number 0.  Arrays that span the whole cube (same rank as the
        cube and the same leading length) are therefore sliced to
        ``[c:c + 1]`` so that local index 0 is the current channel.
        Everything else (2-D arrays, length-1 leading axes, non-array
        values) is passed through unchanged.
        """
        narrowed = dict(kwargs)
        for key in cls._PER_CHANNEL_KWARGS:
            value = narrowed.get(key)
            if (
                isinstance(value, np.ndarray)
                and value.ndim == len(cube_shape)
                and value.shape[0] == cube_shape[0]
            ):
                narrowed[key] = value[c:c + 1]
        return narrowed

    @staticmethod
    def _unflagged_mask(vis_data, c, b, nv):
        """Boolean mask of the unflagged slots among the first ``nv``.

        If ``vis_data`` has no ``flag_I`` attribute, every slot counts as
        unflagged.
        """
        if hasattr(vis_data, "flag_I"):
            return ~np.asarray(vis_data.flag_I[c, b, :nv], dtype=bool)
        return np.ones(nv, dtype=bool)

    def loss(self, x, shape, device, vis_data, **kwargs):
        """Evaluate the total cost and its gradient for a flat parameter array.

        Parameters
        ----------
        x : numpy.ndarray
            Parameters; reshaped to ``shape`` before use.
        shape : tuple
            Cube shape; the first axis is iterated channel by channel.
        device : str or torch.device
            Device the per-channel tensors are moved to.
        vis_data : object
            Must provide ``iter_single_channel(copy=False)`` yielding
            ``(channel_index, single_channel_data)`` pairs.
        **kwargs
            Forwarded to ``objective``.  Arrays listed in
            ``_PER_CHANNEL_KWARGS`` that span the whole cube are narrowed to
            the current channel first.

        Returns
        -------
        tuple
            ``(total_cost, flattened_gradient)`` where the gradient has the
            dtype of ``x``.
        """
        dev = torch.device(device)
        x_cpu = x.reshape(shape)
        grad_cpu = np.zeros_like(x_cpu, dtype=x.dtype)
        loss_scalar = 0.0

        for c, vis_data_c in vis_data.iter_single_channel(copy=False):
            u = torch.from_numpy(x_cpu[c:c + 1]).to(dev).requires_grad_(True)
            L = self.objective(
                x=u,
                vis_data=vis_data_c,
                device=device,
                **self._channel_kwargs(kwargs, c, x_cpu.shape),
            )
            # ``u.grad`` stays None when no term called backward(); the
            # gradient of this channel is then the zeros already in place.
            if u.grad is not None:
                grad_cpu[c] = u.grad[0].detach().cpu().numpy().astype(x.dtype)
            loss_scalar += L.item()

            # Drop the per-channel tensors before moving to the next channel.
            del u, L
            if dev.type == "cuda":
                torch.cuda.empty_cache()

        if dev.type == "cuda":
            allocated = torch.cuda.memory_allocated(dev) / 1024**2
            reserved = torch.cuda.memory_reserved(dev) / 1024**2
            total = torch.cuda.get_device_properties(dev).total_memory / 1024**2
            logger.info(
                f"[PID {os.getpid()}] Total cost: {np.format_float_scientific(loss_scalar, precision=5)} | "
                f"GPU: {allocated:.2f} MB allocated, {reserved:.2f} MB reserved, {total:.2f} MB total"
            )
        else:
            logger.info(
                f"[PID {os.getpid()}] Total cost: {np.format_float_scientific(loss_scalar, precision=5)}"
            )

        return loss_scalar, grad_cpu.ravel()

    @torch.no_grad()
    def forward(
        self,
        x,
        vis_data,
        device,
        primary_beam_list=None,
        primary_beam=None,
        pb_list=None,
        grid_list=None,
        pb=None,
        grid_array=None,
        cell_size=None,
        fill_flagged="zero",
    ):
        """Apply ``forward_beam`` to every (channel, beam) block of ``x``.

        Parameters
        ----------
        x : array-like
            Cube indexed by channel on its first axis; converted to a
            float32 tensor on ``device``.
        vis_data : object
            Must provide ``data_I`` (3-D, used for the output shape),
            ``nvis`` and ``iter_chan_beam_I()``; ``flag_I`` is optional.
        device : str or torch.device
            Device used for the computation.
        primary_beam_list, pb_list, primary_beam, pb, grid_list, grid_array
            Alternative spellings handed to ``resolve_pb_grid_lists``;
            ``primary_beam_list`` wins over ``pb_list`` and ``primary_beam``
            wins over ``pb``.
        cell_size : float, optional
            Forwarded to ``forward_beam``.
        fill_flagged : str
            With ``"zero"`` (default) and flags present, flagged slots are
            explicitly set to 0.

        Returns
        -------
        numpy.ndarray
            ``complex64`` array with the shape of ``vis_data.data_I``.

        Raises
        ------
        ValueError
            If ``vis_data`` lacks ``data_I``, if ``nvis[b]`` exceeds the
            cube width, or if the number of model values differs from the
            number of unflagged slots.
        """
        dev = torch.device(device)
        x = torch.as_tensor(x, device=dev).float()

        primary_beam_list, grid_list = resolve_pb_grid_lists(
            vis_data,
            pb_list=primary_beam_list if primary_beam_list is not None else pb_list,
            grid_list=grid_list,
            pb=primary_beam if primary_beam is not None else pb,
            grid_array=grid_array,
        )

        if not hasattr(vis_data, "data_I"):
            raise ValueError("vis_data must have data_I to infer the output cube shape.")

        nchan, nbeam, nvis = vis_data.data_I.shape
        out = np.zeros((nchan, nbeam, nvis), dtype=np.complex64)
        has_flags = hasattr(vis_data, "flag_I")

        for c, b, Icb, sI, uu, vv, ww in vis_data.iter_chan_beam_I():
            model_vis = forward_beam(
                x2d=x[c],
                primary_beam=primary_beam_list[b],
                grid=grid_list[b],
                uu=uu,
                vv=vv,
                ww=ww,
                cell_size=cell_size,
                device=dev,
            ).detach().cpu().numpy().astype(np.complex64)

            nv = int(vis_data.nvis[b])
            if nv > nvis:
                raise ValueError(f"vis_data.nvis[{b}]={nv} exceeds cube width {nvis}.")

            # Model values are scattered into the unflagged slots only.
            good = self._unflagged_mask(vis_data, c, b, nv)
            if model_vis.size != int(np.count_nonzero(good)):
                raise ValueError(
                    f"Model visibility count {model_vis.size} does not match "
                    f"unflagged slot count {int(np.count_nonzero(good))} for channel {c}, beam {b}."
                )
            out[c, b, :nv][good] = model_vis

            if has_flags and fill_flagged == "zero":
                fl = np.asarray(vis_data.flag_I[c, b], dtype=bool)
                if fl.shape == out[c, b].shape:
                    out[c, b][fl] = 0.0

        return out

    @torch.no_grad()
    def backward(
        self,
        vis,
        vis_data,
        device,
        x_shape=None,
        primary_beam_list=None,
        primary_beam=None,
        pb_list=None,
        grid_list=None,
        pb=None,
        grid_array=None,
        cell_size=None,
        return_real=False,
    ):
        """Accumulate ``backward_beam`` of every (channel, beam) block.

        Parameters
        ----------
        vis : array-like or None
            Either an array shaped like ``vis_data.data_I`` (cube layout,
            from which unflagged slots are picked) or anything else, which
            is flattened and consumed block by block in iteration order.
            ``None`` means ``vis_data.data_I``.
        vis_data : object
            Must provide ``data_I``, ``nvis`` and ``iter_chan_beam_I()``;
            ``flag_I`` is optional.
        device : str or torch.device
            Device used for the computation.
        x_shape : tuple, optional
            ``(nchan, H, W)``.  When omitted it is built from
            ``vis_data.data_I.shape[0]`` and the shape of the first
            primary beam.
        primary_beam_list, pb_list, primary_beam, pb, grid_list, grid_array
            Alternative spellings handed to ``resolve_pb_grid_lists``.
        cell_size : float, optional
            Forwarded to ``backward_beam``.
        return_real : bool
            Return only the real part, as ``float32``.

        Returns
        -------
        numpy.ndarray
            Array of shape ``x_shape``; ``complex64`` unless
            ``return_real`` is set.

        Raises
        ------
        ValueError
            If the cube shape cannot be inferred or is not 3-D, or if a
            flat input is shorter or longer than ``vis_data`` expects.
        """
        dev = torch.device(device)

        primary_beam_list, grid_list = resolve_pb_grid_lists(
            vis_data,
            pb_list=primary_beam_list if primary_beam_list is not None else pb_list,
            grid_list=grid_list,
            pb=primary_beam if primary_beam is not None else pb,
            grid_array=grid_array,
        )

        if x_shape is None:
            if vis_data is None or not hasattr(vis_data, "data_I"):
                raise ValueError("Need x_shape or vis_data.data_I to infer image cube shape.")
            x_shape = (vis_data.data_I.shape[0],) + tuple(np.asarray(primary_beam_list[0]).shape)

        if len(x_shape) != 3:
            raise ValueError(f"x_shape must be (nchan, H, W), got {x_shape}")
        nchan, height, width = x_shape

        vis_in = vis_data.data_I if vis is None else vis
        use_cube = hasattr(vis_in, "shape") and tuple(vis_in.shape) == tuple(vis_data.data_I.shape)
        use_flat = not use_cube
        if use_flat:
            flat_vis = np.asarray(vis_in, dtype=np.complex64).reshape(-1)
            offset = 0

        result = torch.zeros((nchan, height, width), dtype=torch.complex64, device=dev)

        for c, b, Icb, sI, uu, vv, ww in vis_data.iter_chan_beam_I():
            if use_cube:
                nv = int(vis_data.nvis[b])
                good = self._unflagged_mask(vis_data, c, b, nv)
                y = np.asarray(vis_in[c, b, :nv], dtype=np.complex64)[good]
            else:
                # Flat layout: blocks are concatenated in iteration order.
                block_size = Icb.size
                y = flat_vis[offset:offset + block_size]
                if y.size != block_size:
                    raise ValueError("Flat visibility vector is shorter than expected from vis_data.")
                offset += block_size

            result[c] = result[c] + backward_beam(
                y=y,
                primary_beam=primary_beam_list[b],
                grid=grid_list[b],
                uu=uu,
                vv=vv,
                ww=ww,
                cell_size=cell_size,
                image_shape=(height, width),
                device=dev,
                cache_store=self._interp_cache,
            )

        if use_flat and offset != flat_vis.size:
            raise ValueError("Flat visibility vector is longer than expected from vis_data.")

        result = result.real if return_real else result
        return result.detach().cpu().numpy().astype(np.float32 if return_real else np.complex64)

    def objective(
        self,
        x,
        vis_data,
        device,
        primary_beam_list=None,
        primary_beam=None,
        pb_list=None,
        grid_list=None,
        pb=None,
        grid_array=None,
        cell_size=None,
        fftsd=None,
        fftbeam=None,
        tapper=None,
        lambda_sd=0.0,
        lambda_pos=0.0,
        fftkernel=None,
        beam_workers=4,
        verbose=False,
        **_,
    ):
        """Accumulate the cost into ``x.grad`` term by term and return its value.

        Each term calls ``backward()`` on its own, so gradients add up in
        ``x.grad`` while at most one term's graph is alive.

        Parameters
        ----------
        x : torch.Tensor
            Cube indexed by channel on its first axis; ``requires_grad`` is
            switched on and an existing leaf gradient is zeroed.
        vis_data : object
            Must provide ``iter_chan_beam_I()`` yielding
            ``(c, b, I, sI, uu, vv, ww)`` with NumPy ``I`` and ``sI``.
        device : str or torch.device
            Device the data arrays are moved to.
        primary_beam_list, pb_list, primary_beam, pb, grid_list, grid_array
            Alternative spellings handed to ``resolve_pb_grid_lists``.
        cell_size : float
            Forwarded to ``forward_beam``; its square scales the FFT terms.
        fftsd, fftbeam, tapper : numpy.ndarray, optional
            Inputs of the ``lambda_sd`` term.  Arrays whose rank equals the
            rank of ``x`` are indexed by channel; others are shared.
        lambda_sd : float
            Weight of the ``fftsd`` term; the term is skipped unless it is
            positive and ``fftsd`` is given.
        fftkernel : numpy.ndarray, optional
            Kernel of the ``self.lambda_r`` term; the term is skipped
            unless ``self.lambda_r`` is positive and the kernel is given.
        lambda_pos, beam_workers
            Accepted but not used by this implementation.
        verbose : bool
            Call ``print_gpu_memory`` after every data block.
        **_
            Extra keyword arguments are ignored.

        Returns
        -------
        torch.Tensor
            Detached 0-d tensor holding the summed cost.
        """
        x.requires_grad_(True)
        if x.is_leaf and x.grad is not None:
            x.grad.zero_()

        primary_beam_list, grid_list = resolve_pb_grid_lists(
            vis_data,
            pb_list=primary_beam_list if primary_beam_list is not None else pb_list,
            grid_list=grid_list,
            pb=primary_beam if primary_beam is not None else pb,
            grid_array=grid_array,
        )

        loss_value = torch.zeros((), dtype=x.dtype, device=device)

        # Data term: half the sum of squared, sigma-weighted residuals.
        for c, b, I, sI, uu, vv, ww in vis_data.iter_chan_beam_I():
            model_vis = forward_beam(
                x2d=x[c],
                primary_beam=primary_beam_list[b],
                grid=grid_list[b],
                uu=uu,
                vv=vv,
                ww=ww,
                cell_size=cell_size,
                device=device,
            )

            I_use = I.conj() if self.conj_data else I
            vis_real = torch.from_numpy(I_use.real).to(device)
            vis_imag = torch.from_numpy(I_use.imag).to(device)
            sig = torch.from_numpy(sI).to(device)

            residual_real = (model_vis.real - vis_real) / sig
            residual_imag = (model_vis.imag - vis_imag) / sig
            J = torch.sum(residual_real**2 + residual_imag**2)

            block_loss = 0.5 * J
            # Back-propagate now so this block's graph can be released.
            block_loss.backward()
            loss_value = loss_value + block_loss.detach()

            if verbose:
                print_gpu_memory(device)

            del model_vis, vis_real, vis_imag, sig, residual_real, residual_imag, J, block_loss

        # Term comparing the tapered, beam-multiplied FFT of x with fftsd.
        if lambda_sd > 0.0 and fftsd is not None:
            fftsd_t = torch.from_numpy(fftsd).to(device)
            fftbeam_t = torch.from_numpy(fftbeam).to(device)
            tapper_t = torch.from_numpy(tapper).to(device)

            for c in range(x.shape[0]):
                fftsd_c = fftsd_t[c] if fftsd_t.ndim == x.ndim else fftsd_t
                fftbeam_c = fftbeam_t[c] if fftbeam_t.ndim == x.ndim else fftbeam_t
                tapper_c = tapper_t[c] if tapper_t.ndim == x.ndim else tapper_t

                xfft2 = tfft2(x[c] * tapper_c)
                model_sd = (cell_size**2) * xfft2 * fftbeam_c
                Lsd = 0.5 * (
                    torch.nansum((model_sd.real - fftsd_c.real) ** 2)
                    + torch.nansum((model_sd.imag - fftsd_c.imag) ** 2)
                ) * lambda_sd
                Lsd.backward()
                loss_value = loss_value + Lsd.detach()

                del fftsd_c, fftbeam_c, tapper_c, xfft2, model_sd, Lsd

            del fftsd_t, fftbeam_t, tapper_t

        # Term penalising the kernel-multiplied FFT of the tapered x.
        if self.lambda_r > 0.0 and fftkernel is not None:
            tapper_t = torch.from_numpy(tapper).to(device)
            fftkernel_t = torch.from_numpy(fftkernel).to(device)

            for c in range(x.shape[0]):
                fftkernel_c = fftkernel_t[c] if fftkernel_t.ndim == x.ndim else fftkernel_t
                tapper_c = tapper_t[c] if tapper_t.ndim == x.ndim else tapper_t

                xfft2 = tfft2(x[c] * tapper_c)
                conv = (cell_size**2) * xfft2 * fftkernel_c
                Lr = 0.5 * torch.nansum(torch.abs(conv) ** 2) * self.lambda_r
                Lr.backward()
                loss_value = loss_value + Lr.detach()

                del fftkernel_c, tapper_c, xfft2, conv, Lr

            del tapper_t, fftkernel_t

        return loss_value
class Classic3DHighMemory(Classic3D):
    """
    Classic3D variant that accumulates the full objective graph before backward.

    This can be faster for small problems, but peak memory grows with the number
    of beam blocks processed by each objective call.
    """

    def objective(
        self,
        x,
        vis_data,
        device,
        primary_beam_list=None,
        primary_beam=None,
        pb_list=None,
        grid_list=None,
        pb=None,
        grid_array=None,
        cell_size=None,
        fftsd=None,
        fftbeam=None,
        tapper=None,
        lambda_sd=0.0,
        lambda_pos=0.0,
        fftkernel=None,
        beam_workers=4,
        verbose=False,
        **_,
    ):
        """Sum every cost term into one graph, then back-propagate once.

        Parameters
        ----------
        x : torch.Tensor
            Cube indexed by channel on its first axis; ``requires_grad`` is
            switched on and an existing leaf gradient is zeroed.
        vis_data : object
            Must provide ``iter_chan_beam_I()`` yielding
            ``(c, b, I, sI, uu, vv, ww)`` with NumPy ``I`` and ``sI``.
        device : str or torch.device
            Device the data arrays are moved to.
        primary_beam_list, pb_list, primary_beam, pb, grid_list, grid_array
            Alternative spellings handed to ``resolve_pb_grid_lists``.
        cell_size : float
            Forwarded to ``forward_beam``; its square scales the FFT terms.
        fftsd, fftbeam, tapper : numpy.ndarray, optional
            Inputs of the ``lambda_sd`` term; multiplied against the whole
            of ``x`` at once, so they must broadcast with it.
        lambda_sd : float
            Weight of the ``fftsd`` term; the term is skipped unless it is
            positive and ``fftsd`` is given.
        fftkernel : numpy.ndarray, optional
            Kernel of the ``self.lambda_r`` term; the term is skipped
            unless ``self.lambda_r`` is positive and the kernel is given.
        lambda_pos, beam_workers
            Accepted but not used by this implementation.
        verbose : bool
            Call ``print_gpu_memory`` after every data block.
        **_
            Extra keyword arguments are ignored.

        Returns
        -------
        torch.Tensor
            Detached 0-d tensor holding the summed cost.
        """
        x.requires_grad_(True)
        if x.is_leaf and x.grad is not None:
            x.grad.zero_()

        primary_beam_list, grid_list = resolve_pb_grid_lists(
            vis_data,
            pb_list=primary_beam_list if primary_beam_list is not None else pb_list,
            grid_list=grid_list,
            pb=primary_beam if primary_beam is not None else pb,
            grid_array=grid_array,
        )

        # Start from a tensor (not a Python float) so the value returned is
        # a tensor even when no term contributes.
        loss = torch.zeros((), dtype=x.dtype, device=device)

        # Data term: half the sum of squared, sigma-weighted residuals.
        for c, b, I, sI, uu, vv, ww in vis_data.iter_chan_beam_I():
            model_vis = forward_beam(
                x2d=x[c],
                primary_beam=primary_beam_list[b],
                grid=grid_list[b],
                uu=uu,
                vv=vv,
                ww=ww,
                cell_size=cell_size,
                device=device,
            )

            I_use = I.conj() if self.conj_data else I
            vis_real = torch.from_numpy(I_use.real).to(device)
            vis_imag = torch.from_numpy(I_use.imag).to(device)
            sig = torch.from_numpy(sI).to(device)

            residual_real = (model_vis.real - vis_real) / sig
            residual_imag = (model_vis.imag - vis_imag) / sig
            J = torch.sum(residual_real**2 + residual_imag**2)
            loss = loss + 0.5 * J

            if verbose:
                print_gpu_memory(device)

        # Term comparing the tapered, beam-multiplied FFT of x with fftsd.
        if lambda_sd > 0.0 and fftsd is not None:
            fftsd_t = torch.from_numpy(fftsd).to(device)
            fftbeam_t = torch.from_numpy(fftbeam).to(device)
            tapper_t = torch.from_numpy(tapper).to(device)

            xfft2 = tfft2(x * tapper_t)
            model_sd = (cell_size**2) * xfft2 * fftbeam_t
            Lsd = 0.5 * (
                torch.nansum((model_sd.real - fftsd_t.real) ** 2)
                + torch.nansum((model_sd.imag - fftsd_t.imag) ** 2)
            ) * lambda_sd
            loss = loss + Lsd

        # Term penalising the kernel-multiplied FFT of the tapered x.
        if self.lambda_r > 0.0 and fftkernel is not None:
            tapper_t = torch.from_numpy(tapper).to(device)
            fftkernel_t = torch.from_numpy(fftkernel).to(device)

            xfft2 = tfft2(x * tapper_t)
            conv = (cell_size**2) * xfft2 * fftkernel_t
            Lr = 0.5 * torch.nansum(torch.abs(conv) ** 2) * self.lambda_r
            loss = loss + Lr

        # With no contributing term there is no graph to differentiate;
        # calling backward() would raise, so the gradient is left as is.
        if loss.requires_grad:
            loss.backward()

        # Return a detached value, like the parent implementation does.
        return loss.detach()
# import os # import numpy as np # import torch # from torch.fft import fft2 as tfft2 # from ivis.logger import logger # from ivis.models.base import BaseModel # from ivis.models.operators import ( # backward_beam, # forward_beam, # resolve_pb_grid_lists, # ) # from ivis.models.utils.gpu import print_gpu_memory # class Classic3D(BaseModel): # def __init__(self, lambda_r=1, Nw=None, use_2pi=True, conj_data=True): # self.lambda_r = lambda_r # self.Nw = None # self.use_2pi = use_2pi # self.conj_data = conj_data # self._interp_cache = {} # def loss(self, x, shape, device, vis_data, **kwargs): # u = x.reshape(shape) # u = torch.from_numpy(u).to(device).requires_grad_(True) # L = self.objective(x=u, vis_data=vis_data, device=device, **kwargs) # grad = u.grad.cpu().numpy().astype(x.dtype) # if torch.device(device).type == "cuda": # allocated = torch.cuda.memory_allocated(device) / 1024**2 # reserved = torch.cuda.memory_reserved(device) / 1024**2 # total = torch.cuda.get_device_properties(device).total_memory / 1024**2 # logger.info( # f"[PID {os.getpid()}] Total cost: {np.format_float_scientific(L.item(), precision=5)} | " # f"GPU: {allocated:.2f} MB allocated, {reserved:.2f} MB reserved, {total:.2f} MB total" # ) # else: # logger.info(f"[PID {os.getpid()}] Total cost: {np.format_float_scientific(L.item(), precision=5)}") # return L.item(), grad.ravel() # @torch.no_grad() # def forward( # self, # x, # vis_data, # device, # primary_beam_list=None, # primary_beam=None, # pb_list=None, # grid_list=None, # pb=None, # grid_array=None, # cell_size=None, # fill_flagged="zero", # ): # dev = torch.device(device) # x = torch.as_tensor(x, device=dev).float() # primary_beam_list, grid_list = resolve_pb_grid_lists( # vis_data, # pb_list=primary_beam_list if primary_beam_list is not None else pb_list, # grid_list=grid_list, # pb=primary_beam if primary_beam is not None else pb, # grid_array=grid_array, # ) # if not hasattr(vis_data, "data_I"): # raise ValueError("vis_data must have 
data_I to infer the output cube shape.") # nchan, nbeam, nvis = vis_data.data_I.shape # out = np.zeros((nchan, nbeam, nvis), dtype=np.complex64) # has_flags = hasattr(vis_data, "flag_I") # for c, b, Icb, sI, uu, vv, ww in vis_data.iter_chan_beam_I(): # model_vis = forward_beam( # x2d=x[c], # primary_beam=primary_beam_list[b], # grid=grid_list[b], # uu=uu, # vv=vv, # ww=ww, # cell_size=cell_size, # device=dev, # ).detach().cpu().numpy().astype(np.complex64) # nv = int(vis_data.nvis[b]) # if nv > nvis: # raise ValueError(f"vis_data.nvis[{b}]={nv} exceeds cube width {nvis}.") # good = ~np.asarray(vis_data.flag_I[c, b, :nv], dtype=bool) # if model_vis.size != int(np.count_nonzero(good)): # raise ValueError( # f"Model visibility count {model_vis.size} does not match " # f"unflagged slot count {int(np.count_nonzero(good))} for channel {c}, beam {b}." # ) # out[c, b, :nv][good] = model_vis # if has_flags and fill_flagged == "zero": # fl = np.asarray(vis_data.flag_I[c, b], dtype=bool) # if fl.shape == out[c, b].shape: # out[c, b][fl] = 0.0 # return out # @torch.no_grad() # def backward( # self, # vis, # vis_data, # device, # x_shape=None, # primary_beam_list=None, # primary_beam=None, # pb_list=None, # grid_list=None, # pb=None, # grid_array=None, # cell_size=None, # return_real=False, # ): # dev = torch.device(device) # primary_beam_list, grid_list = resolve_pb_grid_lists( # vis_data, # pb_list=primary_beam_list if primary_beam_list is not None else pb_list, # grid_list=grid_list, # pb=primary_beam if primary_beam is not None else pb, # grid_array=grid_array, # ) # if x_shape is None: # if vis_data is None or not hasattr(vis_data, "data_I"): # raise ValueError("Need x_shape or vis_data.data_I to infer image cube shape.") # x_shape = (vis_data.data_I.shape[0],) + tuple(np.asarray(primary_beam_list[0]).shape) # if len(x_shape) != 3: # raise ValueError(f"x_shape must be (nchan, H, W), got {x_shape}") # nchan, height, width = x_shape # vis_in = vis_data.data_I if vis is None 
else vis # use_cube = hasattr(vis_in, "shape") and tuple(vis_in.shape) == tuple(vis_data.data_I.shape) # use_flat = not use_cube # if use_flat: # flat_vis = np.asarray(vis_in, dtype=np.complex64).reshape(-1) # offset = 0 # result = torch.zeros((nchan, height, width), dtype=torch.complex64, device=dev) # for c, b, Icb, sI, uu, vv, ww in vis_data.iter_chan_beam_I(): # if use_cube: # nv = int(vis_data.nvis[b]) # flg = np.asarray(vis_data.flag_I[c, b, :nv], dtype=bool) # y = np.asarray(vis_in[c, b, :nv], dtype=np.complex64)[~flg] # else: # block_size = Icb.size # y = flat_vis[offset:offset + block_size] # if y.size != block_size: # raise ValueError("Flat visibility vector is shorter than expected from vis_data.") # offset += block_size # result[c] = result[c] + backward_beam( # y=y, # primary_beam=primary_beam_list[b], # grid=grid_list[b], # uu=uu, # vv=vv, # ww=ww, # cell_size=cell_size, # image_shape=(height, width), # device=dev, # cache_store=self._interp_cache, # ) # if use_flat and offset != flat_vis.size: # raise ValueError("Flat visibility vector is longer than expected from vis_data.") # result = result.real if return_real else result # return result.detach().cpu().numpy().astype(np.float32 if return_real else np.complex64) # def objective( # self, # x, # vis_data, # device, # primary_beam_list=None, # primary_beam=None, # pb_list=None, # grid_list=None, # pb=None, # grid_array=None, # cell_size=None, # fftsd=None, # fftbeam=None, # tapper=None, # lambda_sd=0.0, # lambda_pos=0.0, # fftkernel=None, # beam_workers=4, # verbose=False, # **_, # ): # x.requires_grad_(True) # if x.is_leaf and x.grad is not None: # x.grad.zero_() # primary_beam_list, grid_list = resolve_pb_grid_lists( # vis_data, # pb_list=primary_beam_list if primary_beam_list is not None else pb_list, # grid_list=grid_list, # pb=primary_beam if primary_beam is not None else pb, # grid_array=grid_array, # ) # loss_scalar = 0.0 # for c, b, I, sI, uu, vv, ww in vis_data.iter_chan_beam_I(): # model_vis 
= forward_beam( # x2d=x[c], # primary_beam=primary_beam_list[b], # grid=grid_list[b], # uu=uu, # vv=vv, # ww=ww, # cell_size=cell_size, # device=device, # ) # I_use = I.conj() if self.conj_data else I # vis_real = torch.from_numpy(I_use.real).to(device) # vis_imag = torch.from_numpy(I_use.imag).to(device) # sig = torch.from_numpy(sI).to(device) # residual_real = (model_vis.real - vis_real) / sig # residual_imag = (model_vis.imag - vis_imag) / sig # J = torch.sum(residual_real**2 + residual_imag**2) # L = 0.5 * J # L.backward(retain_graph=True) # loss_scalar += L.item() # if verbose: # print_gpu_memory(device) # if lambda_sd > 0.0 and fftsd is not None: # fftsd_t = torch.from_numpy(fftsd).to(device) # fftbeam_t = torch.from_numpy(fftbeam).to(device) # tapper_t = torch.from_numpy(tapper).to(device) # xfft2 = tfft2(x * tapper_t) # model_sd = (cell_size**2) * xfft2 * fftbeam_t # Lsd = 0.5 * (torch.nansum((model_sd.real - fftsd_t.real) ** 2) + # torch.nansum((model_sd.imag - fftsd_t.imag) ** 2)) * lambda_sd # Lsd.backward(retain_graph=True) # loss_scalar += Lsd.item() # if self.lambda_r > 0.0 and fftkernel is not None: # tapper_t = torch.from_numpy(tapper).to(device) # fftkernel_t = torch.from_numpy(fftkernel).to(device) # xfft2 = tfft2(x * tapper_t) # conv = (cell_size**2) * xfft2 * fftkernel_t # Lr = 0.5 * torch.nansum(torch.abs(conv) ** 2) * self.lambda_r # Lr.backward() # loss_scalar += Lr.item() # return torch.tensor(loss_scalar)