import os
import numpy as np
import torch
from torch.fft import fft2 as tfft2
from astropy.constants import c as c_light
from ivis.logger import logger
from ivis.models.base import BaseModel
from ivis.models.operators import forward_beam, resolve_pb_grid_lists
from ivis.models.utils.gpu import print_gpu_memory
# LRSB: Low-Rank Spectral Basis
[docs]
class LRSB(BaseModel):
"""
Low-rank spectral basis model driven by a user-supplied basis matrix.
The basis must have shape (nbasis, nchan).
"""
def __init__(
self,
basis,
lambda_r=1.0,
lambda_pos=0.0,
conj_data=True,
assume_channel_invariant_operator=False,
reference_channel=0,
):
basis_arr = np.asarray(basis, dtype=np.float32)
if basis_arr.ndim != 2:
raise ValueError("basis must have shape (nbasis, nchan).")
self.lambda_r = float(lambda_r)
self.lambda_pos = float(lambda_pos)
self.conj_data = conj_data
self.assume_channel_invariant_operator = bool(assume_channel_invariant_operator)
self.reference_channel = int(reference_channel)
self._basis_np = basis_arr.astype(np.float32, copy=False)
self._basis_cache = {}
@property
def nbasis(self):
return int(self._basis_np.shape[0])
@property
def nchan(self):
return int(self._basis_np.shape[1])
def _get_basis(self, device, dtype):
cache_key = (str(torch.device(device)), str(dtype))
cached = self._basis_cache.get(cache_key)
if cached is not None:
return cached
basis = torch.from_numpy(self._basis_np).to(device=device, dtype=dtype)
self._basis_cache[cache_key] = basis
return basis
[docs]
def reconstruct_cube(self, x, device=None, return_numpy=False):
if torch.is_tensor(x):
coeffs = x
if device is not None:
coeffs = coeffs.to(device)
else:
target_device = torch.device(device) if device is not None else torch.device("cpu")
coeffs = torch.as_tensor(x, device=target_device).float()
if coeffs.shape[0] != self.nbasis:
raise ValueError(
f"Expected coeffs.shape[0] == nbasis={self.nbasis}, got {coeffs.shape[0]}"
)
basis = self._get_basis(device=coeffs.device, dtype=coeffs.dtype)
cube = torch.einsum("kc,khw->chw", basis, coeffs)
if return_numpy:
return cube.detach().cpu().numpy()
return cube
[docs]
def reconstruct_cube_from_coeffs(self, coeffs, device=None, return_numpy=True):
return self.reconstruct_cube(coeffs, device=device, return_numpy=return_numpy)
def _lambda_r_for_basis(self, basis_index):
return float(self.lambda_r)
def _weights_for_channel(self, basis, channel_index):
return basis[:, channel_index]
def _reference_operator_inputs(self, vis_data, beam_index):
if self.reference_channel < 0 or self.reference_channel >= self.nchan:
raise ValueError(
f"reference_channel={self.reference_channel} is outside [0, {self.nchan - 1}]"
)
nv = int(vis_data.nvis[beam_index])
scale = vis_data.frequency[self.reference_channel] / c_light.value
uu_ref = vis_data.uu[beam_index, :nv] * scale
vv_ref = vis_data.vv[beam_index, :nv] * scale
ww_ref = vis_data.ww[beam_index, :nv] * scale
return uu_ref, vv_ref, ww_ref, nv
def _build_invariant_beam_cache(
self,
x,
vis_data,
beam_index,
device,
primary_beam_list,
grid_list,
cell_size,
):
dev = torch.device(device)
uu_ref, vv_ref, ww_ref, nref = self._reference_operator_inputs(vis_data, beam_index)
hk_list = []
for k in range(self.nbasis):
hk_full = forward_beam(
x2d=x[k],
primary_beam=primary_beam_list[beam_index],
grid=grid_list[beam_index],
uu=uu_ref,
vv=vv_ref,
ww=ww_ref,
cell_size=cell_size,
device=dev,
)
if hk_full.numel() != nref:
raise ValueError(
f"Reference forward model size {hk_full.numel()} does not match beam visibility count {nref} for beam {beam_index}."
)
hk_list.append(hk_full)
return torch.stack(hk_list, dim=0)
def _prepare_channel_beam_blocks(self, vis_data, device):
dev = torch.device(device)
nchan, nbeam, _ = vis_data.data_I.shape
blocks = [[None for _ in range(nbeam)] for _ in range(nchan)]
for c in range(nchan):
for b in range(nbeam):
nv = int(vis_data.nvis[b])
good_np = ~np.asarray(vis_data.flag_I[c, b, :nv], dtype=bool)
if not np.any(good_np):
continue
I, sI, uu, vv, ww = vis_data.slice_chan_beam_I(c, b)
blocks[c][b] = {
"I": I,
"sI": sI,
"uu": uu,
"vv": vv,
"ww": ww,
"good_np": good_np,
"good_t": torch.as_tensor(good_np, device=dev, dtype=torch.bool),
}
return blocks
[docs]
def loss(self, x, shape, device, vis_data, **kwargs):
dev = torch.device(device)
u = x.reshape(shape)
u = torch.from_numpy(u).to(dev).requires_grad_(True)
L = self.objective(
x=u,
vis_data=vis_data,
device=device,
**kwargs,
)
grad = u.grad.detach().cpu().numpy().astype(x.dtype)
if dev.type == "cuda":
allocated = torch.cuda.memory_allocated(dev) / 1024**2
reserved = torch.cuda.memory_reserved(dev) / 1024**2
total = torch.cuda.get_device_properties(dev).total_memory / 1024**2
logger.info(
f"[PID {os.getpid()}] Total cost: {np.format_float_scientific(L.item(), precision=5)} | "
f"GPU: {allocated:.2f} MB allocated, {reserved:.2f} MB reserved, {total:.2f} MB total"
)
else:
logger.info(
f"[PID {os.getpid()}] Total cost: {np.format_float_scientific(L.item(), precision=5)}"
)
return L.item(), grad.ravel()
@torch.no_grad()
def forward(
self,
x,
vis_data,
device,
primary_beam_list=None,
primary_beam=None,
pb_list=None,
grid_list=None,
pb=None,
grid_array=None,
cell_size=None,
fill_flagged="zero",
):
dev = torch.device(device)
x = torch.as_tensor(x, device=dev).float()
primary_beam_list, grid_list = resolve_pb_grid_lists(
vis_data,
pb_list=primary_beam_list if primary_beam_list is not None else pb_list,
grid_list=grid_list,
pb=primary_beam if primary_beam is not None else pb,
grid_array=grid_array,
)
if not hasattr(vis_data, "data_I"):
raise ValueError("vis_data must have data_I to infer the output cube shape.")
nchan, nbeam, nvis = vis_data.data_I.shape
if nchan != self.nchan:
raise ValueError(f"Expected nchan={self.nchan} from basis, got {nchan}.")
if x.shape[0] != self.nbasis:
raise ValueError(f"Expected x.shape[0] == nbasis={self.nbasis}, got {x.shape[0]}")
basis = self._get_basis(device=dev, dtype=x.dtype)
out = np.zeros((nchan, nbeam, nvis), dtype=np.complex64)
has_flags = hasattr(vis_data, "flag_I")
blocks = self._prepare_channel_beam_blocks(vis_data=vis_data, device=dev)
if self.assume_channel_invariant_operator:
for b in range(nbeam):
hk_stack = self._build_invariant_beam_cache(
x=x,
vis_data=vis_data,
beam_index=b,
device=dev,
primary_beam_list=primary_beam_list,
grid_list=grid_list,
cell_size=cell_size,
)
for c in range(nchan):
block = blocks[c][b]
if block is None:
continue
weights = basis[:, c]
model_vis = torch.einsum(
"k,kn->n",
weights.to(hk_stack.dtype),
hk_stack[:, block["good_t"]],
)
good = block["good_np"]
model_vis = model_vis.detach().cpu().numpy().astype(np.complex64)
nv = int(vis_data.nvis[b])
if nv > nvis:
raise ValueError(f"vis_data.nvis[{b}]={nv} exceeds cube width {nvis}.")
if model_vis.size != int(np.count_nonzero(good)):
raise ValueError(
f"Model visibility count {model_vis.size} does not match "
f"unflagged slot count {int(np.count_nonzero(good))} for channel {c}, beam {b}."
)
out[c, b, :nv][good] = model_vis
if has_flags and fill_flagged == "zero":
fl = np.asarray(vis_data.flag_I[c, b], dtype=bool)
if fl.shape == out[c, b].shape:
out[c, b][fl] = 0.0
del hk_stack
else:
for c in range(nchan):
weights = basis[:, c]
x2d = torch.sum(x * weights[:, None, None], dim=0)
for b in range(nbeam):
block = blocks[c][b]
if block is None:
continue
model_vis = forward_beam(
x2d=x2d,
primary_beam=primary_beam_list[b],
grid=grid_list[b],
uu=block["uu"],
vv=block["vv"],
ww=block["ww"],
cell_size=cell_size,
device=dev,
)
good = block["good_np"]
model_vis = model_vis.detach().cpu().numpy().astype(np.complex64)
nv = int(vis_data.nvis[b])
if nv > nvis:
raise ValueError(f"vis_data.nvis[{b}]={nv} exceeds cube width {nvis}.")
if model_vis.size != int(np.count_nonzero(good)):
raise ValueError(
f"Model visibility count {model_vis.size} does not match "
f"unflagged slot count {int(np.count_nonzero(good))} for channel {c}, beam {b}."
)
out[c, b, :nv][good] = model_vis
if has_flags and fill_flagged == "zero":
fl = np.asarray(vis_data.flag_I[c, b], dtype=bool)
if fl.shape == out[c, b].shape:
out[c, b][fl] = 0.0
return out
[docs]
def objective(
self,
x,
vis_data,
device,
primary_beam_list=None,
primary_beam=None,
pb_list=None,
grid_list=None,
pb=None,
grid_array=None,
cell_size=None,
fftsd=None,
fftbeam=None,
tapper=None,
lambda_sd=0.0,
lambda_pos=None,
fftkernel=None,
beam_workers=4,
verbose=False,
**_,
):
dev = torch.device(device)
if x.ndim != 3:
raise ValueError(f"x must have shape (nbasis, H, W), got {tuple(x.shape)}")
if x.shape[0] != self.nbasis:
raise ValueError(f"Expected x.shape[0] == nbasis={self.nbasis}, got {x.shape[0]}")
dtype = x.dtype
primary_beam_list, grid_list = resolve_pb_grid_lists(
vis_data,
pb_list=primary_beam_list if primary_beam_list is not None else pb_list,
grid_list=grid_list,
pb=primary_beam if primary_beam is not None else pb,
grid_array=grid_array,
)
nchan = vis_data.frequency.shape[0]
nbeam = vis_data.uu.shape[0]
if nchan != self.nchan:
raise ValueError(f"Expected nchan={self.nchan} from basis, got {nchan}.")
loss = torch.tensor(0.0, device=dev, dtype=dtype)
lambda_pos = self.lambda_pos if lambda_pos is None else float(lambda_pos)
tapper_t = torch.from_numpy(tapper).to(dev, dtype=dtype) if tapper is not None else None
fftbeam_t = torch.from_numpy(fftbeam).to(dev) if fftbeam is not None else None
fftkernel_t = torch.from_numpy(fftkernel).to(dev) if fftkernel is not None else None
basis = self._get_basis(device=dev, dtype=dtype)
blocks = self._prepare_channel_beam_blocks(vis_data=vis_data, device=dev)
if self.assume_channel_invariant_operator:
need_sd_fft = lambda_sd > 0.0 and fftsd is not None and fftbeam is not None
channel_cache = []
for c in range(nchan):
weights = self._weights_for_channel(basis, c)
need_x2d = (lambda_pos > 0.0) or need_sd_fft
x2d = torch.sum(x * weights[:, None, None], dim=0) if need_x2d else None
xfft2 = tfft2(x2d * tapper_t) if need_sd_fft else None
channel_cache.append((weights, x2d, xfft2))
if lambda_pos > 0.0:
loss = loss + lambda_pos * torch.sum(torch.clamp(-x2d, min=0.0) ** 2)
if need_sd_fft:
fftsd_c = fftsd if fftsd.ndim == 2 else fftsd[c]
fftsd_t = torch.from_numpy(fftsd_c).to(dev)
model_sd = (cell_size**2) * xfft2 * fftbeam_t
Lsd = 0.5 * (
torch.nansum((model_sd.real - fftsd_t.real) ** 2)
+ torch.nansum((model_sd.imag - fftsd_t.imag) ** 2)
) * lambda_sd
loss = loss + Lsd
for b in range(nbeam):
hk_stack = self._build_invariant_beam_cache(
x=x,
vis_data=vis_data,
beam_index=b,
device=dev,
primary_beam_list=primary_beam_list,
grid_list=grid_list,
cell_size=cell_size,
)
for c in range(nchan):
block = blocks[c][b]
if block is None:
continue
weights, _, _ = channel_cache[c]
model_vis = torch.einsum(
"k,kn->n",
weights.to(hk_stack.dtype),
hk_stack[:, block["good_t"]],
)
I_use = block["I"].conj() if self.conj_data else block["I"]
vis_real = torch.from_numpy(I_use.real).to(dev)
vis_imag = torch.from_numpy(I_use.imag).to(dev)
sig = torch.from_numpy(block["sI"]).to(dev)
sig = torch.clamp(sig, min=1e-6)
residual_real = (model_vis.real - vis_real) / sig
residual_imag = (model_vis.imag - vis_imag) / sig
J = torch.sum(residual_real**2 + residual_imag**2)
loss = loss + 0.5 * J
if verbose:
print_gpu_memory(device)
del hk_stack
else:
for c in range(nchan):
weights = self._weights_for_channel(basis, c)
need_sd_fft = lambda_sd > 0.0 and fftsd is not None and fftbeam is not None
x2d = torch.sum(x * weights[:, None, None], dim=0)
if lambda_pos > 0.0:
loss = loss + lambda_pos * torch.sum(torch.clamp(-x2d, min=0.0) ** 2)
xfft2 = None
if need_sd_fft:
xfft2 = tfft2(x2d * tapper_t)
for b in range(nbeam):
block = blocks[c][b]
if block is None:
continue
model_vis = forward_beam(
x2d=x2d,
primary_beam=primary_beam_list[b],
grid=grid_list[b],
uu=block["uu"],
vv=block["vv"],
ww=block["ww"],
cell_size=cell_size,
device=dev,
)
I_use = block["I"].conj() if self.conj_data else block["I"]
vis_real = torch.from_numpy(I_use.real).to(dev)
vis_imag = torch.from_numpy(I_use.imag).to(dev)
sig = torch.from_numpy(block["sI"]).to(dev)
sig = torch.clamp(sig, min=1e-6)
residual_real = (model_vis.real - vis_real) / sig
residual_imag = (model_vis.imag - vis_imag) / sig
J = torch.sum(residual_real**2 + residual_imag**2)
loss = loss + 0.5 * J
if verbose:
print_gpu_memory(device)
if lambda_sd > 0.0 and fftsd is not None:
fftsd_c = fftsd if fftsd.ndim == 2 else fftsd[c]
fftsd_t = torch.from_numpy(fftsd_c).to(dev)
model_sd = (cell_size**2) * xfft2 * fftbeam_t
Lsd = 0.5 * (
torch.nansum((model_sd.real - fftsd_t.real) ** 2)
+ torch.nansum((model_sd.imag - fftsd_t.imag) ** 2)
) * lambda_sd
loss = loss + Lsd
if self.lambda_r > 0.0 and fftkernel is not None:
for k in range(self.nbasis):
lambda_r_k = self._lambda_r_for_basis(k)
if lambda_r_k <= 0.0:
continue
coeff_fft2 = tfft2(x[k] * tapper_t)
conv = (cell_size**2) * coeff_fft2 * fftkernel_t
Lr = 0.5 * torch.nansum(torch.abs(conv) ** 2) * lambda_r_k
loss = loss + Lr
loss.backward()
return loss
[docs]
class LRSB_C(LRSB):
"""
LRSB variant with explicit continuum basis functions.
This augments the learned line basis with fixed smooth spectral modes.
By default, it adds a single flat continuum mode psi_0(nu) = 1.
"""
def __init__(
self,
basis,
continuum_basis=None,
continuum_order=0,
frequency=None,
reference_frequency=None,
continuum_only_channels=None,
lambda_r_line_factor=1.0,
lambda_r_cont_factor=1.0,
**kwargs,
):
line_basis = np.asarray(basis, dtype=np.float32)
if line_basis.ndim != 2:
raise ValueError("basis must have shape (nbasis, nchan).")
continuum_arr = self._prepare_continuum_basis(
nchan=line_basis.shape[1],
continuum_basis=continuum_basis,
continuum_order=continuum_order,
frequency=frequency,
reference_frequency=reference_frequency,
line_basis=line_basis,
)
self._line_basis_np = line_basis.astype(np.float32, copy=False)
self._continuum_basis_np = continuum_arr.astype(np.float32, copy=False)
self._continuum_order = int(continuum_arr.shape[0] - 1)
self._reference_frequency = (
None if reference_frequency is None else float(reference_frequency)
)
self._continuum_only_channels = self._prepare_channel_mask(
nchan=line_basis.shape[1],
channels=continuum_only_channels,
)
self.lambda_r_line_factor = float(lambda_r_line_factor)
self.lambda_r_cont_factor = float(lambda_r_cont_factor)
hybrid_basis = np.concatenate((self._line_basis_np, self._continuum_basis_np), axis=0)
super().__init__(basis=hybrid_basis, **kwargs)
@staticmethod
def _prepare_continuum_basis(
nchan,
continuum_basis,
continuum_order,
frequency,
reference_frequency,
line_basis,
):
if continuum_basis is None:
order = int(continuum_order)
if order < 0:
raise ValueError("continuum_order must be >= 0.")
if order == 0:
continuum_arr = np.ones((1, nchan), dtype=np.float32)
else:
if frequency is None:
raise ValueError("frequency is required when continuum_order > 0.")
freq = np.asarray(frequency, dtype=np.float32)
if freq.ndim != 1 or freq.shape[0] != nchan:
raise ValueError(f"frequency must have shape ({nchan},), got {freq.shape}.")
nu_ref = float(np.mean(freq) if reference_frequency is None else reference_frequency)
if nu_ref == 0.0:
raise ValueError("reference_frequency must be non-zero.")
xnu = (freq - nu_ref) / nu_ref
continuum_arr = np.stack([xnu**m for m in range(order + 1)], axis=0).astype(
np.float32,
copy=False,
)
else:
continuum_arr = np.asarray(continuum_basis, dtype=np.float32)
if continuum_arr.ndim == 1:
continuum_arr = continuum_arr[None, :]
if continuum_arr.ndim != 2:
raise ValueError("continuum_basis must have shape (ncont, nchan).")
if continuum_arr.shape[1] != nchan:
raise ValueError(
f"continuum_basis must have nchan={nchan}, got {continuum_arr.shape[1]}."
)
return continuum_arr
@staticmethod
def _prepare_channel_mask(nchan, channels):
mask = np.zeros(nchan, dtype=bool)
if channels is None:
return mask
idx = np.asarray(channels)
if idx.dtype == bool:
if idx.shape != (nchan,):
raise ValueError(
f"Boolean continuum_only_channels mask must have shape ({nchan},), got {idx.shape}."
)
return idx.astype(bool, copy=True)
idx = np.asarray(idx, dtype=np.int64).ravel()
if idx.size == 0:
return mask
if np.any(idx < 0) or np.any(idx >= nchan):
raise ValueError(f"continuum_only_channels must be within [0, {nchan - 1}].")
mask[idx] = True
return mask
@property
def line_nbasis(self):
return int(self._line_basis_np.shape[0])
@property
def continuum_nbasis(self):
return int(self._continuum_basis_np.shape[0])
@property
def continuum_basis(self):
return self._continuum_basis_np
@property
def continuum_order(self):
return self._continuum_order
@property
def reference_frequency(self):
return self._reference_frequency
@property
def continuum_only_channels(self):
return self._continuum_only_channels.copy()
def _lambda_r_for_basis(self, basis_index):
if basis_index < self.line_nbasis:
return float(self.lambda_r) * self.lambda_r_line_factor
return float(self.lambda_r) * self.lambda_r_cont_factor
[docs]
def split_coeffs(self, x):
if torch.is_tensor(x):
return x[: self.line_nbasis], x[self.line_nbasis :]
coeffs = np.asarray(x)
return coeffs[: self.line_nbasis], coeffs[self.line_nbasis :]
[docs]
def reconstruct_line_cube(self, x, device=None, return_numpy=False):
line_coeffs, _ = self.split_coeffs(x)
line_model = LRSB(basis=self._line_basis_np)
return line_model.reconstruct_cube(line_coeffs, device=device, return_numpy=return_numpy)
[docs]
def reconstruct_continuum_cube(self, x, device=None, return_numpy=False):
_, continuum_coeffs = self.split_coeffs(x)
continuum_model = LRSB(basis=self._continuum_basis_np)
return continuum_model.reconstruct_cube(
continuum_coeffs, device=device, return_numpy=return_numpy
)
def _weights_for_channel(self, basis, channel_index):
weights = basis[:, channel_index]
if not self._continuum_only_channels[channel_index]:
return weights
line_zeros = torch.zeros(
self.line_nbasis,
device=weights.device,
dtype=weights.dtype,
)
return torch.cat((line_zeros, weights[self.line_nbasis :]), dim=0)
[docs]
class LRSBMemory(LRSB):
"""
Memory-streaming LRSB variant.
LRSB stores a smaller coefficient cube than Classic3D, but its objective
still accumulates one large autograd graph by default. This variant
backpropagates independent loss blocks as soon as they are computed.
"""
def _backward_loss(self, loss, loss_value):
loss.backward()
return loss_value + loss.detach()
def _channel_image(self, x, weights):
return torch.sum(x * weights[:, None, None], dim=0)
[docs]
def objective(
self,
x,
vis_data,
device,
primary_beam_list=None,
primary_beam=None,
pb_list=None,
grid_list=None,
pb=None,
grid_array=None,
cell_size=None,
fftsd=None,
fftbeam=None,
tapper=None,
lambda_sd=0.0,
lambda_pos=None,
fftkernel=None,
beam_workers=4,
verbose=False,
**_,
):
dev = torch.device(device)
if x.ndim != 3:
raise ValueError(f"x must have shape (nbasis, H, W), got {tuple(x.shape)}")
if x.shape[0] != self.nbasis:
raise ValueError(f"Expected x.shape[0] == nbasis={self.nbasis}, got {x.shape[0]}")
x.requires_grad_(True)
if x.is_leaf and x.grad is not None:
x.grad.zero_()
dtype = x.dtype
primary_beam_list, grid_list = resolve_pb_grid_lists(
vis_data,
pb_list=primary_beam_list if primary_beam_list is not None else pb_list,
grid_list=grid_list,
pb=primary_beam if primary_beam is not None else pb,
grid_array=grid_array,
)
nchan = vis_data.frequency.shape[0]
nbeam = vis_data.uu.shape[0]
if nchan != self.nchan:
raise ValueError(f"Expected nchan={self.nchan} from basis, got {nchan}.")
loss_value = torch.zeros((), device=dev, dtype=dtype)
lambda_pos = self.lambda_pos if lambda_pos is None else float(lambda_pos)
tapper_t = torch.from_numpy(tapper).to(dev, dtype=dtype) if tapper is not None else None
fftbeam_t = torch.from_numpy(fftbeam).to(dev) if fftbeam is not None else None
fftkernel_t = torch.from_numpy(fftkernel).to(dev) if fftkernel is not None else None
basis = self._get_basis(device=dev, dtype=dtype)
blocks = self._prepare_channel_beam_blocks(vis_data=vis_data, device=dev)
need_sd_fft = lambda_sd > 0.0 and fftsd is not None and fftbeam is not None
for c in range(nchan):
weights = self._weights_for_channel(basis, c)
if lambda_pos > 0.0:
x2d = self._channel_image(x, weights)
Lpos = lambda_pos * torch.sum(torch.clamp(-x2d, min=0.0) ** 2)
loss_value = self._backward_loss(Lpos, loss_value)
del x2d, Lpos
if need_sd_fft:
fftsd_c = fftsd if fftsd.ndim == 2 else fftsd[c]
fftsd_t = torch.from_numpy(fftsd_c).to(dev)
x2d = self._channel_image(x, weights)
xfft2 = tfft2(x2d * tapper_t)
model_sd = (cell_size**2) * xfft2 * fftbeam_t
Lsd = 0.5 * (
torch.nansum((model_sd.real - fftsd_t.real) ** 2)
+ torch.nansum((model_sd.imag - fftsd_t.imag) ** 2)
) * lambda_sd
loss_value = self._backward_loss(Lsd, loss_value)
del fftsd_t, x2d, xfft2, model_sd, Lsd
if self.assume_channel_invariant_operator:
for b in range(nbeam):
hk_stack = self._build_invariant_beam_cache(
x=x,
vis_data=vis_data,
beam_index=b,
device=dev,
primary_beam_list=primary_beam_list,
grid_list=grid_list,
cell_size=cell_size,
)
beam_loss = torch.zeros((), device=dev, dtype=dtype)
for c in range(nchan):
block = blocks[c][b]
if block is None:
continue
weights = self._weights_for_channel(basis, c)
model_vis = torch.einsum(
"k,kn->n",
weights.to(hk_stack.dtype),
hk_stack[:, block["good_t"]],
)
I_use = block["I"].conj() if self.conj_data else block["I"]
vis_real = torch.from_numpy(I_use.real).to(dev)
vis_imag = torch.from_numpy(I_use.imag).to(dev)
sig = torch.from_numpy(block["sI"]).to(dev)
sig = torch.clamp(sig, min=1e-6)
residual_real = (model_vis.real - vis_real) / sig
residual_imag = (model_vis.imag - vis_imag) / sig
J = torch.sum(residual_real**2 + residual_imag**2)
beam_loss = beam_loss + 0.5 * J
if verbose:
print_gpu_memory(device)
del model_vis, vis_real, vis_imag, sig, residual_real, residual_imag, J
if beam_loss.requires_grad:
loss_value = self._backward_loss(beam_loss, loss_value)
else:
loss_value = loss_value + beam_loss.detach()
del hk_stack, beam_loss
else:
for c in range(nchan):
weights = self._weights_for_channel(basis, c)
for b in range(nbeam):
block = blocks[c][b]
if block is None:
continue
x2d = self._channel_image(x, weights)
model_vis = forward_beam(
x2d=x2d,
primary_beam=primary_beam_list[b],
grid=grid_list[b],
uu=block["uu"],
vv=block["vv"],
ww=block["ww"],
cell_size=cell_size,
device=dev,
)
I_use = block["I"].conj() if self.conj_data else block["I"]
vis_real = torch.from_numpy(I_use.real).to(dev)
vis_imag = torch.from_numpy(I_use.imag).to(dev)
sig = torch.from_numpy(block["sI"]).to(dev)
sig = torch.clamp(sig, min=1e-6)
residual_real = (model_vis.real - vis_real) / sig
residual_imag = (model_vis.imag - vis_imag) / sig
J = torch.sum(residual_real**2 + residual_imag**2)
block_loss = 0.5 * J
loss_value = self._backward_loss(block_loss, loss_value)
if verbose:
print_gpu_memory(device)
del (
x2d,
model_vis,
vis_real,
vis_imag,
sig,
residual_real,
residual_imag,
J,
block_loss,
)
if self.lambda_r > 0.0 and fftkernel is not None:
for k in range(self.nbasis):
lambda_r_k = self._lambda_r_for_basis(k)
if lambda_r_k <= 0.0:
continue
coeff_fft2 = tfft2(x[k] * tapper_t)
conv = (cell_size**2) * coeff_fft2 * fftkernel_t
Lr = 0.5 * torch.nansum(torch.abs(conv) ** 2) * lambda_r_k
loss_value = self._backward_loss(Lr, loss_value)
del coeff_fft2, conv, Lr
return loss_value
[docs]
class LRSB_CMemory(LRSBMemory, LRSB_C):
"""
Memory-streaming LRSB_C variant.
This combines the hybrid line+continuum basis construction from LRSB_C
with the blockwise backward pass from LRSBMemory.
"""
def __init__(
self,
basis,
continuum_basis=None,
continuum_order=0,
frequency=None,
reference_frequency=None,
continuum_only_channels=None,
lambda_r_line_factor=1.0,
lambda_r_cont_factor=1.0,
**kwargs,
):
LRSB_C.__init__(
self,
basis=basis,
continuum_basis=continuum_basis,
continuum_order=continuum_order,
frequency=frequency,
reference_frequency=reference_frequency,
continuum_only_channels=continuum_only_channels,
lambda_r_line_factor=lambda_r_line_factor,
lambda_r_cont_factor=lambda_r_cont_factor,
**kwargs,
)