# Source code for ivis.readers.ms_casacore

# -*- coding: utf-8 -*-
import os, re, glob, contextlib, numpy as np
from dataclasses import dataclass
from typing import Iterator, Tuple, List, Union, Sequence, Optional
from concurrent.futures import ProcessPoolExecutor, as_completed
from multiprocessing import get_context
from pathlib import Path

from casacore.tables import table
from astropy.constants import c as c_light
from astropy.coordinates import Angle, SkyCoord
import astropy.units as u
from astropy.wcs import WCS
from astropy.io.fits import Header

from ivis.logger import logger
from ivis.readers.base import Reader
from ivis.types import VisIData

# -------------------- utils --------------------

@contextlib.contextmanager
def _quiet_tables():
    """Silence casacore 'Successful readonly open...' spam while opening tables."""
    with open(os.devnull, "w") as devnull:
        with contextlib.redirect_stdout(devnull), contextlib.redirect_stderr(devnull):
            yield

def _list_block_dirs(ms_root: str) -> list[str]:
    """
    Return a list of block directories under ms_root. A block directory is any
    directory (ms_root itself or its immediate subdirs) that contains at least one *.ms.

    Accepts a single *.ms path (treated as a one-MS block).
    """
    # --- NEW: allow passing a single MS directly
    if ms_root.lower().endswith(".ms"):
        if not os.path.isdir(ms_root):
            raise FileNotFoundError(f"MS not found: {ms_root}")
        return [ms_root]

    blocks = []

    # root itself as a block?
    if glob.glob(os.path.join(ms_root, "*.ms")):
        blocks.append(ms_root)

    # immediate subdirs
    for name in sorted(os.listdir(ms_root)):
        sub = os.path.join(ms_root, name)
        if os.path.isdir(sub) and glob.glob(os.path.join(sub, "*.ms")):
            blocks.append(sub)

    if not blocks:
        raise FileNotFoundError(
            f"No *.ms found in {ms_root} or its immediate subdirectories."
        )
    return blocks


def _check_same_freq_grid(blocks: list["VisIData"]) -> None:
    """Ensure all blocks have identical frequency arrays (required for concat)."""
    ref = blocks[0].frequency
    for i, b in enumerate(blocks[1:], start=1):
        if b.frequency.shape != ref.shape or not np.allclose(b.frequency, ref, rtol=0, atol=1e-6):
            raise ValueError(
                f"Frequency grids differ between blocks (block 0 vs {i}). "
                "Use the same SPW/chan_sel or resample first."
            )

def _natkey(name: str):
    # Natural sort helper: "file2" < "file10"
    return [int(s) if s.isdigit() else s.lower() for s in re.split(r'(\d+)', name)]


def _beamkey_by_c10(path: str):
    """
    Prefer sorting by the number after 'C10_' in the filename, e.g. MW-C10_5_... -> 5.
    Fallback to a general natural sort by full filename.
    """
    name = Path(path).name
    m = re.search(r'[Cc]10[_-](\d+)', name)
    if m:
        return (0, int(m.group(1)), name.lower())
    # optional: also recognize 'BEAM_###'
    m2 = re.search(r'[Bb][Ee][Aa][Mm][_-]?(\d+)', name)
    if m2:
        return (0, int(m2.group(1)), name.lower())
    return (1, *_natkey(name))  # fallback: natural sort on the whole name


def _list_ms_sorted(ms_dir: str):
    """List .ms directories in the same order as raw `ls` (lexicographic).
    Accepts either a directory or a single .ms path.
    """
    # --- NEW: allow passing a single MS directly
    if ms_dir.lower().endswith(".ms"):
        if not os.path.isdir(ms_dir):
            raise FileNotFoundError(f"MS not found: {ms_dir}")
        return [ms_dir]

    # --- Original behavior (unchanged)
    items = []
    with os.scandir(ms_dir) as it:
        for de in it:
            if de.name.startswith('.'):
                continue
            if de.is_dir() and de.name.lower().endswith('.ms'):
                items.append(os.path.join(ms_dir, de.name))
    if not items:
        raise FileNotFoundError(f"No .ms found under {ms_dir}")
    return sorted(items)  # plain lexicographic, like `ls`


def _list_ms_sorted_natural(ms_dir: str):
    """List .ms directories and sort by beam number (C10_#), then natural name."""
    items = []
    with os.scandir(ms_dir) as it:
        for de in it:
            if de.name.startswith('.'):
                continue
            if de.is_dir() and de.name.lower().endswith('.ms'):
                items.append(os.path.join(ms_dir, de.name))
    if not items:
        raise FileNotFoundError(f"No .ms found under {ms_dir}")
    return sorted(items, key=_beamkey_by_c10)


def _list_block_dirs_sorted(ms_root: str):
    out = []
    with os.scandir(ms_root) as it:
        for de in it:
            if de.name.startswith('.'):
                continue
            if de.is_dir():
                out.append(os.path.join(ms_root, de.name))
    return sorted(out, key=lambda p: _natkey(Path(p).name))


def _phasecenter(ms_path: str) -> SkyCoord:
    """
    Return the phase centre stored in row 0 of the MS FIELD subtable.

    Reads ``PHASE_DIR[0, 0, :]`` as (RA, Dec) in radians. The second index is
    presumably the zeroth-order polynomial term of the direction. The
    SkyCoord is built directly from the radian values, which avoids the
    precision loss and overhead of a round-trip through sexagesimal strings.

    NOTE(review): the coordinates are labelled ICRS; the MS records the actual
    direction reference frame (e.g. J2000) in the column's MEASINFO keyword,
    which is not consulted here.
    """
    with _quiet_tables():
        with table(f"{ms_path}/FIELD", readonly=True) as t:
            ra_rad, dec_rad = t.getcol("PHASE_DIR")[0, 0, :]
    return SkyCoord(ra=float(ra_rad) * u.rad, dec=float(dec_rad) * u.rad, frame="icrs")


def _freqs(ms_path: str, rest_freq: float | u.Quantity = 1.42040575177e9 * u.Hz):
    """
    Read channel frequencies from a MeasurementSet and convert them to velocities.

    Parameters
    ----------
    ms_path : str
        Path to the MeasurementSet.
    rest_freq : float or Quantity, optional
        Rest frequency of the spectral line. A plain number (or 0-d array) is
        taken to be in Hz; a Quantity such as ``1.42 * u.GHz`` is converted to
        Hz. Default: 1.42040575177 GHz (H I 21 cm line).

    Returns
    -------
    freqs : ndarray (float64), shape (nchan,)
        Channel frequencies [Hz] of spectral window 0.
    vel : ndarray (float64), shape (nchan,)
        Radio-convention velocities ``c * (f0 - f) / f0`` [km/s].

    Notes
    -----
    Only row 0 of the SPECTRAL_WINDOW subtable is read. The result is always
    a 1-D channel axis, even for an MS holding several spectral windows.
    NOTE(review): frequencies are returned in whatever reference frame the MS
    stores them in (MEAS_FREQ_REF is not consulted).
    """
    with _quiet_tables():
        with table(f"{ms_path}/SPECTRAL_WINDOW", readonly=True) as t:
            # Read a single cell: works even if SPWs have different channel counts.
            freqs = np.atleast_1d(
                np.asarray(t.getcell("CHAN_FREQ", 0), dtype=np.float64)
            )  # Hz

    # Normalise rest_freq to a Quantity in Hz (accepts floats, 0-d arrays and Quantities).
    rest_q = u.Quantity(rest_freq, u.Hz)

    vel_q = (rest_q - freqs * u.Hz) / rest_q * c_light          # m/s
    vel = vel_q.to_value(u.km / u.s).astype(np.float64)        # km/s

    return freqs, vel

# -------- helpers --------

def _centers_equal(c0, c1, tol_deg: float) -> bool:
    """
    Compare 'center' objects with a tiny tolerance in degrees.
    Supports:
      - SkyCoord-like with .ra.deg/.dec.deg
      - (ra_deg, dec_deg) tuples/lists
      - Fallback to equality for strings/other identifiers.
    """
    try:
        if hasattr(c0, "ra") and hasattr(c0, "dec") and hasattr(c1, "ra") and hasattr(c1, "dec"):
            return (abs(float(c0.ra.deg)  - float(c1.ra.deg))  <= tol_deg and
                    abs(float(c0.dec.deg) - float(c1.dec.deg)) <= tol_deg)
        if isinstance(c0, (tuple, list)) and isinstance(c1, (tuple, list)) and len(c0) == 2 and len(c1) == 2:
            return (abs(float(c0[0]) - float(c1[0])) <= tol_deg and
                    abs(float(c0[1]) - float(c1[1])) <= tol_deg)
    except Exception:
        pass
    return c0 == c1

def _norm_radius(radius) -> Angle:
    """
    Coerce a search radius to an astropy ``Angle``.

    - ``Angle`` instances are returned unchanged.
    - Other ``Quantity`` objects are wrapped with ``Angle(quantity)``, so
      they need an angular unit.
    - Anything else is read as a plain number of degrees.
    """
    if not isinstance(radius, Angle):
        if isinstance(radius, u.Quantity):
            radius = Angle(radius)
        else:
            radius = Angle(radius, unit=u.deg)  # bare number: degrees
    return radius

def _within_radius(center: SkyCoord, radius: Angle, coord: SkyCoord) -> bool:
    """Return whether ``coord`` lies within ``radius`` of ``center``, both taken in ICRS."""
    separation = coord.icrs.separation(center.icrs)
    return separation <= radius

def _normalize_target_header_from_header(hdr: Header):
    """
    Extract the celestial WCS and the 2-D image shape from a FITS header.

    The header may also describe spectral/Stokes axes; only the celestial
    sub-WCS (``WCS(hdr).celestial``) is kept.

    Returns
    -------
    (WCS, (ny, nx))
        ``nx`` is read from ``NAXIS1`` and ``ny`` from ``NAXIS2``.
    """
    celestial = WCS(hdr).celestial
    nx, ny = (int(hdr[key]) for key in ("NAXIS1", "NAXIS2"))
    return celestial, (ny, nx)

def _coord_in_image(wcs_cel: WCS, shape: tuple, coord: SkyCoord) -> bool:
    """
    Test whether ``coord`` projects onto the image grid of ``wcs_cel``.

    ``shape`` is ``(ny, nx)``. The position (converted to ICRS) is projected
    with ``world_to_pixel``, which uses zero-based pixel coordinates. It is
    accepted when x lies in ``[0, nx - 1]`` and y in ``[0, ny - 1]``, i.e.
    between the first and last pixel centres. A position that projects to
    NaN fails both comparisons and is rejected.
    """
    ny, nx = shape
    px, py = wcs_cel.world_to_pixel(coord.icrs)
    x_ok = 0 <= px <= nx - 1
    return x_ok and (0 <= py <= ny - 1)


def _normalize_beam_sel(beam_sel, nbeam_total: int) -> np.ndarray:
    """
    Return a sorted, unique array of beam indices in [0, nbeam_total).
    Accepts: None, int, slice, Sequence[int].
    """
    if beam_sel is None:
        return np.arange(nbeam_total, dtype=int)

    if isinstance(beam_sel, int):
        idx = np.array([beam_sel], dtype=int)

    elif isinstance(beam_sel, slice):
        idx = np.arange(nbeam_total, dtype=int)[beam_sel]

    else:
        idx = np.asarray(list(beam_sel), dtype=int)

    if idx.size == 0:
        raise ValueError("beam_sel selects 0 beams")

    if np.any(idx < 0) or np.any(idx >= nbeam_total):
        raise IndexError(f"beam_sel has indices outside [0, {nbeam_total-1}]")

    # unique + sorted (keeps deterministic ordering)
    return np.unique(idx)


# ----------------------------- Public loader ------------------------------

def _getcolslice_chans(T, colname: str, ch0: int, ch1: int) -> np.ndarray:
    """
    Read channels ``ch0..ch1`` (inclusive) of array column ``colname`` for all rows.

    The result is always 3-D ``(nrow, nchan, npol)``. If the reader returns a
    2-D ``(nrow, npol)`` array for a one-channel slice, a length-1 channel
    axis is inserted, so callers never have to guess the rank.
    """
    arr = T.getcolslice(colname, blc=[ch0, 0], trc=[ch1, -1])
    if arr.ndim == 2:
        arr = arr[:, None, :]
    return arr


def _getcol_channels(T, colname: str, chan_idx: np.ndarray) -> np.ndarray:
    """
    Read the channels listed in ``chan_idx`` (in that order) as ``(nrow, nchan, npol)``.

    A single channel, or a run of consecutive increasing channels, is read
    with one slice. Any other selection is read channel by channel and
    stitched together along the channel axis.
    """
    if chan_idx.size == 1 or np.all(np.diff(chan_idx) == 1):
        return _getcolslice_chans(T, colname, int(chan_idx[0]), int(chan_idx[-1]))
    return np.concatenate(
        [_getcolslice_chans(T, colname, ch, ch) for ch in chan_idx.tolist()], axis=1
    )


def _read_one_ms(ms_path: str,
                 chan_idx: np.ndarray,
                 keep_autocorr: bool,
                 prefer_weight_spectrum: bool):
    """
    Worker: read a single .ms, return per-beam arrays (no UV filtering here).

    Parameters
    ----------
    ms_path : str
        Path of the MeasurementSet.
    chan_idx : np.ndarray of int
        Absolute channel indices to read, in output order (non-empty).
    keep_autocorr : bool
        If False, rows with ANTENNA1 == ANTENNA2 are dropped.
    prefer_weight_spectrum : bool
        If True and the table has WEIGHT_SPECTRUM, per-channel noise is
        1/sqrt(weight). Otherwise the per-row SIGMA column is used and
        repeated across channels.

    Returns
    -------
    dict
        - ``uu``, ``vv``, ``ww``: float32 per-row UVW (meters).
        - ``I`` (complex64), ``sI`` (float32), ``fI`` (bool): each shaped
          (nrow, nchan).
        - ``center``: from ``_phasecenter``.
        - ``nrow``: number of rows kept.
    """
    logger.info(f"    [MS] Opening: {ms_path}")

    with _quiet_tables():
        with table(ms_path, readonly=True) as T:
            UVW = T.getcol("UVW")                   # (nrow, 3) meters
            A1 = T.getcol("ANTENNA1")
            A2 = T.getcol("ANTENNA2")
            has_ws = ("WEIGHT_SPECTRUM" in T.colnames()) and prefer_weight_spectrum

            # Read only the selected channels, always as (nrow, nchan, npol).
            DATA = _getcol_channels(T, "DATA", chan_idx)
            FLAG = _getcol_channels(T, "FLAG", chan_idx)
            W = _getcol_channels(T, "WEIGHT_SPECTRUM", chan_idx) if has_ws else None
            SIGMA = None if has_ws else T.getcol("SIGMA")         # (nrow, npol)

    # Row mask: remove autocorrelations only (per-channel uv mask happens in parent).
    row_mask = np.ones(UVW.shape[0], dtype=bool) if keep_autocorr else (A1 != A2)

    UVW = UVW[row_mask]
    DATA = DATA[row_mask]
    FLAG = FLAG[row_mask]
    if has_ws:
        W = W[row_mask]
    else:
        SIGMA = SIGMA[row_mask]

    nchan = DATA.shape[1]
    npol = DATA.shape[2]
    eps = 1e-12  # weight floor: zero weights give a huge finite sigma, not inf
    if npol == 1:
        I = DATA[..., 0]                 # (nrow, nchan)
        fI = FLAG[..., 0]
        if has_ws:
            sI = 1.0 / np.sqrt(np.maximum(W[..., 0], eps))
        else:
            sI = np.repeat(SIGMA[:, 0][:, None], nchan, 1)
    else:
        # NOTE(review): assumes the two parallel-hand products are the first
        # and last correlations (e.g. XX..YY or RR..LL); CORR_TYPE is not checked.
        p0, p1 = 0, -1
        I = 0.5 * (DATA[..., p0] + DATA[..., p1])
        fI = FLAG[..., p0] | FLAG[..., p1]
        # I = (P0 + P1) / 2  =>  sigma_I = 0.5 * sqrt(sigma0**2 + sigma1**2)
        if has_ws:
            sig0 = 1.0 / np.sqrt(np.maximum(W[..., p0], eps))
            sig1 = 1.0 / np.sqrt(np.maximum(W[..., p1], eps))
            sI = 0.5 * np.sqrt(sig0**2 + sig1**2)
        else:
            row_sI = 0.5 * np.sqrt(SIGMA[:, p0]**2 + SIGMA[:, p1]**2)
            sI = np.repeat(row_sI[:, None], nchan, 1)

    center = _phasecenter(ms_path)

    # Return row-major per beam (channel-major is done in parent after UV mask).
    out = {
        "uu": UVW[:, 0].astype(np.float32),
        "vv": UVW[:, 1].astype(np.float32),
        "ww": UVW[:, 2].astype(np.float32),
        "I":  I.astype(np.complex64),        # (nrow, nchan)
        "sI": sI.astype(np.float32),         # (nrow, nchan)
        "fI": fI.astype(bool),               # (nrow, nchan)
        "center": center,
        "nrow": UVW.shape[0],
    }
    logger.info(f"    [MS] Done: {os.path.basename(ms_path)}  rows={out['nrow']}")
    return out


def _resolve_chan_idx(chan_sel, nchan_total: int) -> np.ndarray:
    """
    Convert ``chan_sel`` into an explicit 1-D array of absolute channel indices.

    - ``None`` selects every channel.
    - A slice is applied to ``range(nchan_total)``.
    - Any other value (int, list, array) is used to index
      ``range(nchan_total)``. Negative entries therefore become absolute
      indices, which the MS slice readers require, and out-of-range entries
      raise ``IndexError``.
    """
    all_idx = np.arange(nchan_total, dtype=int)
    if chan_sel is None:
        chan_idx = all_idx
    elif isinstance(chan_sel, slice):
        chan_idx = all_idx[chan_sel]
    else:
        chan_idx = all_idx[np.atleast_1d(np.asarray(chan_sel, dtype=int))]
    if chan_idx.size == 0:
        raise ValueError(f"chan_sel selects 0 channels; available nchan={nchan_total}")
    return chan_idx


def _uv_in_range(uu_m, vv_m, frequency, uvmin, uvmax) -> np.ndarray:
    """
    Return a (nrow, nchan) mask of samples whose uv distance in wavelengths lies in [uvmin, uvmax].

    ``uu_m``/``vv_m`` are per-row uv coordinates in meters. The uv-plane
    distance sqrt(u**2 + v**2) is scaled by ``frequency / c`` for every
    channel.
    """
    uvdist_m = np.hypot(np.asarray(uu_m, dtype=np.float64),
                        np.asarray(vv_m, dtype=np.float64))               # (nrow,)
    scale = np.asarray(frequency, dtype=np.float64) / c_light.value        # (nchan,) 1/m
    uvdist_lam = uvdist_m[:, None] * scale[None, :]                        # (nrow, nchan)
    return (uvdist_lam >= uvmin) & (uvdist_lam <= uvmax)


def _process_context():
    """Multiprocessing context for the per-MS pool: 'fork' where available, else the platform default."""
    import multiprocessing
    if "fork" in multiprocessing.get_all_start_methods():
        return get_context("fork")
    return get_context()  # e.g. Windows, where only 'spawn' exists


def read_ms_block_I(
    ms_dir: str,
    uvmin: float = 0.0,                  # in wavelengths
    uvmax: float = float("inf"),         # in wavelengths
    chan_sel=None,                       # None | slice | int | list[int] | np.ndarray[int]
    rest_freq: float = 1.42040575177e9,  # HI rest frequency as default value in unit of Hz
    keep_autocorr: bool = False,
    prefer_weight_spectrum: bool = True,
    n_workers: int = 0,                  # 0/1 = serial; >1 = parallel per-MS
    target_center: "SkyCoord | None" = None,
    target_radius: "Angle | u.Quantity | float | None" = None,  # float => degrees
    beam_sel=None,                       # int | list[int] | slice | None
) -> "VisIData":
    """
    Load a directory of .ms files (one per beam) into an I-only, channel-major VisIData.

    Beam bookkeeping
    ----------------
    - ``beam_sel`` picks which MSs make up the block. MSs are taken in the
      lexicographic order of ``_list_ms_sorted``; unselected MSs get no slot.
    - ``target_center`` together with ``target_radius`` skips reading DATA for
      selected beams whose phase centre lies outside the radius. Such beams
      keep their slot, so beam order is preserved, but the slot is empty
      (``nvis=0``, all flags True). Phase centres are recorded for all
      selected beams.

    Channel selection
    -----------------
    ``chan_sel`` may be None, a slice, an int, or a sequence of ints.
    Negative indices count from the end of the spectral window.

    UV selection
    ------------
    ``uvmin``/``uvmax`` are in wavelengths. They are applied per channel to
    the uv-plane distance sqrt(u**2 + v**2) scaled by ``frequency / c``.
    Out-of-range samples are flagged and given ``sigma = inf``; they are not
    removed, so ``nvis`` counts all rows read after the autocorrelation cut.

    Returns
    -------
    VisIData
        - ``data_I``/``sigma_I``/``flag_I``: shape (nchan, nbeam, nvis_max).
        - ``uu``/``vv``/``ww``: shape (nbeam, nvis_max), in meters.
        - Padding beyond ``nvis[b]`` is zero with ``flag_I = True``.
    """
    # ---------------- 1) Deterministic beam list ----------------
    ms_list = _list_ms_sorted(ms_dir)  # or _list_ms_sorted_natural(ms_dir)
    nbeam_total = len(ms_list)
    sel = _normalize_beam_sel(beam_sel, nbeam_total)
    ms_list = [ms_list[i] for i in sel]
    nbeam = len(ms_list)
    logger.info(
        f"[BLOCK] Loading {nbeam}/{nbeam_total} beam(s) from: {ms_dir} "
    )

    # ---------------- 2) Channel selection ----------------
    all_freq, all_vel = _freqs(ms_list[0], rest_freq)
    chan_idx = _resolve_chan_idx(chan_sel, all_freq.size)
    frequency = all_freq[chan_idx]   # (nchan,)
    velocity = all_vel[chan_idx]     # (nchan,)
    nchan = int(frequency.size)

    # ---------------- 3) Decide which beams to actually read ----------------
    centers = [_phasecenter(ms) for ms in ms_list]  # cheap FIELD reads
    keep_mask = [True] * nbeam
    radius = None
    if target_center is not None and target_radius is not None:
        radius = _norm_radius(target_radius)
        for i, (ms, c) in enumerate(zip(ms_list, centers)):
            if not _within_radius(target_center, radius, c):
                keep_mask[i] = False
                logger.info(f"[SKIP-READ] {os.path.basename(ms)} outside {radius.to_string()} of "
                            f"{target_center.to_string('hmsdms')}")

    # ---------------- 4) Read only the kept beams (preserve positions) ----------------
    per_beam = [None] * nbeam  # None -> empty slot
    to_read = [i for i in range(nbeam) if keep_mask[i]]
    if n_workers and n_workers > 1:
        logger.info(f"[BLOCK] Parallel read with {n_workers} workers (order-preserving; selective)")
        with ProcessPoolExecutor(max_workers=n_workers, mp_context=_process_context()) as ex:
            futs = {
                ex.submit(_read_one_ms, ms_list[i], chan_idx, keep_autocorr, prefer_weight_spectrum): i
                for i in to_read
            }
            for fut in as_completed(futs):
                per_beam[futs[fut]] = fut.result()  # slot by beam index, not completion order
    else:
        logger.info("[BLOCK] Serial read (order-preserving; selective)")
        for i in to_read:
            per_beam[i] = _read_one_ms(ms_list[i], chan_idx, keep_autocorr, prefer_weight_spectrum)

    # ---------------- 5) Pack to dense (nchan, nbeam, nvis_max) ----------------
    nvis = np.array([0 if out is None else out["uu"].shape[0] for out in per_beam],
                    dtype=np.int32)
    nvis_max = int(nvis.max())  # can be 0 if all beams skipped

    data_I = np.zeros((nchan, nbeam, nvis_max), dtype=np.complex64)
    sigma_I = np.zeros((nchan, nbeam, nvis_max), dtype=np.float32)
    flag_I = np.ones((nchan, nbeam, nvis_max), dtype=bool)
    uu = np.zeros((nbeam, nvis_max), dtype=np.float32)
    vv = np.zeros_like(uu)
    ww = np.zeros_like(uu)

    for b, out in enumerate(per_beam):
        if out is None:
            continue  # empty slot: zeros with all flags True
        nv = int(nvis[b])
        in_rng = _uv_in_range(out["uu"], out["vv"], frequency, uvmin, uvmax)  # (nv, nchan)
        uu[b, :nv] = out["uu"]
        vv[b, :nv] = out["vv"]
        ww[b, :nv] = out["ww"]
        # Transpose (nv, nchan) -> (nchan, nv) for the channel-major layout.
        data_I[:, b, :nv] = out["I"].T
        sigma_I[:, b, :nv] = np.where(in_rng, out["sI"], np.inf).T
        flag_I[:, b, :nv] = (out["fI"] | ~in_rng).T

    # Element-wise fill guarantees a 1-D object array of SkyCoord objects.
    centers_arr = np.empty(nbeam, dtype=object)
    for i, c in enumerate(centers):
        centers_arr[i] = c

    kept_count = int(sum(keep_mask))
    logger.info(f"[BLOCK] Done: nchan={nchan}, nbeam={nbeam}, nvis_max={nvis_max} "
                f"(read {kept_count}/{nbeam} beams)")

    return VisIData(
        frequency=frequency,
        velocity=velocity,
        centers=centers_arr,      # beams in _list_ms_sorted (lexicographic) order
        nvis=nvis,
        uu=uu, vv=vv, ww=ww,      # meters (scaled to λ on demand)
        data_I=data_I,
        sigma_I=sigma_I,
        flag_I=flag_I,
    )


def _merge_blocks_per_beam(blocks, center_tol_deg: float):
    """
    Concatenate visibilities beam-by-beam across blocks that share one beam layout.

    Every block must have the same number of beams as block 0, with matching
    centres (within ``center_tol_deg``). Otherwise ``ValueError`` is raised.
    For each beam, the valid visibilities of block 0, block 1, ... are
    appended in that order. The padded tail stays zero with ``flag_I = True``.
    """
    first = blocks[0]
    nbeams = first.uu.shape[0]
    for bi, blk in enumerate(blocks[1:], start=1):
        if blk.uu.shape[0] != nbeams:
            raise ValueError(f"Block {bi} has {blk.uu.shape[0]} beams; expected {nbeams}.")
        for j in range(nbeams):
            c0, c1 = first.centers[j], blk.centers[j]
            if not _centers_equal(c0, c1, tol_deg=center_tol_deg):
                raise ValueError(f"Beam center mismatch at beam {j}: "
                                 f"block0={c0} vs block{bi}={c1}")

    nchan = first.frequency.size
    # Merged nvis per beam = sum over blocks
    nvis = np.array([sum(int(blk.nvis[j]) for blk in blocks) for j in range(nbeams)],
                    dtype=np.int32)
    nvis_max = int(nvis.max())

    data_I = np.zeros((nchan, nbeams, nvis_max), dtype=np.complex64)
    sigma_I = np.zeros((nchan, nbeams, nvis_max), dtype=np.float32)
    flag_I = np.ones((nchan, nbeams, nvis_max), dtype=bool)
    uu = np.zeros((nbeams, nvis_max), dtype=np.float32)
    vv = np.zeros_like(uu)
    ww = np.zeros_like(uu)

    for j in range(nbeams):
        w = 0  # write cursor along the visibility axis of beam j
        for blk in blocks:
            nv = int(blk.nvis[j])
            if nv <= 0:
                continue
            uu[j, w:w + nv] = blk.uu[j, :nv]
            vv[j, w:w + nv] = blk.vv[j, :nv]
            ww[j, w:w + nv] = blk.ww[j, :nv]
            data_I[:, j, w:w + nv] = blk.data_I[:, j, :nv]
            sigma_I[:, j, w:w + nv] = blk.sigma_I[:, j, :nv]
            flag_I[:, j, w:w + nv] = blk.flag_I[:, j, :nv]
            w += nv

    return VisIData(
        frequency=first.frequency,
        velocity=first.velocity,
        centers=np.array(first.centers, dtype=object),  # preserve block-0 order
        nvis=nvis,
        uu=uu, vv=vv, ww=ww,
        data_I=data_I,
        sigma_I=sigma_I,
        flag_I=flag_I,
    )


def read_ms_blocks_I(
    ms_root: str,
    uvmin: float = 0.0,
    uvmax: float = float("inf"),
    chan_sel=None,                       # None | slice | list[int] | np.ndarray[int]
    rest_freq: float = 1.42040575177e9,  # HI rest frequency as default value in unit of Hz
    keep_autocorr: bool = False,
    prefer_weight_spectrum: bool = True,
    mode: str = "merge",                 # "merge" | "stack" | "separate"
    n_workers: int = 0,                  # 0/1 = serial; >1 = parallel per-MS
    target_center: "SkyCoord | None" = None,
    target_radius: "Angle | u.Quantity | float | None" = None,
    center_tol_deg: float = 1e-12,       # tolerance for center equality when mode="merge"
    beam_sel=None,
) -> "VisIData | List[VisIData]":
    """
    Load multiple blocks of observations under ``ms_root`` and combine them.

    Modes
    -----
    - ``merge``: concatenate visibilities per beam across blocks. All blocks
      must have the same beams, in the same order, with matching centres.
    - ``stack``: place the beams of all blocks side by side (block 0's beams
      first).
    - ``separate``: return a list of ``VisIData``, one per block.

    Block discovery policy
    ----------------------
    - If ``ms_root`` itself contains ``*.ms`` directly, it is a block.
    - Any immediate subdirectory of ``ms_root`` that contains ``*.ms`` is
      also a block.

    Every block is read with ``read_ms_block_I`` using the same selection
    arguments.

    Returns
    -------
    VisIData | list[VisIData]
        - ``merge``: as many beams as block 0.
        - ``stack``: the sum of beams across blocks.
        - ``separate``: a list of ``VisIData`` objects.

    Raises
    ------
    ValueError
        For an unknown ``mode``, which is checked before any data is read;
        for blocks with different frequency grids (merge/stack); or for
        mismatched beam layouts (merge).
    """
    # Validate cheaply before touching disk: reading every block can be expensive.
    if mode not in ("merge", "stack", "separate"):
        raise ValueError('mode must be one of: "merge", "stack", "separate"')

    # ---------- discover and read blocks ----------
    block_dirs = _list_block_dirs(ms_root)
    if not block_dirs:
        raise FileNotFoundError(f"No blocks found under {ms_root}")

    blocks = []
    for bdir in block_dirs:
        logger.info(f"[BLOCK] Loading block from: {bdir}")
        blocks.append(read_ms_block_I(
            bdir,
            uvmin=uvmin, uvmax=uvmax,
            chan_sel=chan_sel,
            rest_freq=rest_freq,
            keep_autocorr=keep_autocorr,
            prefer_weight_spectrum=prefer_weight_spectrum,
            n_workers=n_workers,
            target_center=target_center,
            target_radius=target_radius,
            beam_sel=beam_sel,
        ))

    if mode == "separate":
        return blocks

    # ---------- common sanity: same spectral grid ----------
    _check_same_freq_grid(blocks)

    if mode == "stack":
        # Stacking blocks is exactly concatenating same-channel VisIData along beams.
        return concat_visidata_slabs(blocks)

    return _merge_blocks_per_beam(blocks, center_tol_deg)


# ------------------------- channel-slab iterator --------------------------
def iter_channel_slabs(
    ms_dir: str,
    uvmin: float = 0.0,
    uvmax: float = float("inf"),
    chan_sel=None,                       # None | slice | int | list[int] | np.ndarray[int]
    rest_freq: float = 1.42040575177e9,  # HI rest frequency as default value in unit of Hz
    slab: int = 64,                      # max channels per slab
    keep_autocorr: bool = False,
    prefer_weight_spectrum: bool = True,
    n_workers: int = 0,                  # 0/1 = serial; >1 = parallel per-MS
    beam_sel=None,
    target_center: "SkyCoord | None" = None,
    target_radius: "Angle | u.Quantity | float | None" = None,
) -> Iterator[Tuple[int, int, VisIData]]:
    """
    Yield contiguous channel slabs so a big cube can be streamed with low RAM.

    The selected channels are split into runs of consecutive channel numbers,
    each at most ``slab`` long. A non-contiguous selection therefore gives
    shorter slabs. Each run is loaded with ``read_ms_block_I``.

    Negative entries in ``chan_sel`` count from the end of the spectral
    window. ``target_center``/``target_radius`` are forwarded to
    ``read_ms_block_I`` to skip beams outside the radius.

    Yields
    ------
    (start, stop, visI)
        ``start``/``stop`` are absolute channel indices into the SPW (Python
        slice semantics). ``visI`` is a VisIData with shape
        ``(stop - start, nbeam, nvis_max)``.
    """
    ms_list = _list_ms_sorted(ms_dir)
    all_freq, _ = _freqs(ms_list[0], rest_freq)
    all_idx = np.arange(all_freq.size, dtype=int)

    # Normalize chan_sel -> explicit absolute indices (negatives resolved, range-checked)
    if chan_sel is None:
        sel_idx = all_idx
    elif isinstance(chan_sel, slice):
        sel_idx = all_idx[chan_sel]
    else:
        sel_idx = all_idx[np.atleast_1d(np.asarray(chan_sel, dtype=int))]

    if sel_idx.size == 0:
        return  # nothing to do

    i = 0
    n = sel_idx.size
    while i < n:
        # grow a contiguous run up to 'slab' channels
        j = i + 1
        while j < n and sel_idx[j] == sel_idx[j - 1] + 1 and (j - i) < slab:
            j += 1
        start = int(sel_idx[i])
        stop = int(sel_idx[j - 1] + 1)  # slice end (exclusive)

        visI = read_ms_block_I(
            ms_dir,
            uvmin=uvmin, uvmax=uvmax,
            chan_sel=slice(start, stop),
            rest_freq=rest_freq,
            keep_autocorr=keep_autocorr,
            prefer_weight_spectrum=prefer_weight_spectrum,
            n_workers=n_workers,
            target_center=target_center,
            target_radius=target_radius,
            beam_sel=beam_sel,
        )
        yield start, stop, visI
        i = j


def iter_blocks_chan_beam_I(
    ms_root: str,
    uvmin: float = 0.0,
    uvmax: float = float("inf"),
    chan_sel=None,
    rest_freq: float = 1.42040575177e9,  # HI rest frequency as default value in unit of Hz
    keep_autocorr: bool = False,
    prefer_weight_spectrum: bool = True,
    n_workers: int = 0,                  # 0/1 = serial; >1 = parallel per-MS
    beam_sel=None,
):
    """
    Stream over (block_index, c, b, I, sI, uu, vv, ww) without concatenating.

    Each block from ``_list_block_dirs`` is loaded in full with
    ``read_ms_block_I``. Every tuple produced by its ``iter_chan_beam_I()``
    is yielded, prefixed by the block index. The block is dropped before the
    next one is read.
    """
    read_kwargs = dict(
        uvmin=uvmin,
        uvmax=uvmax,
        chan_sel=chan_sel,
        rest_freq=rest_freq,
        keep_autocorr=keep_autocorr,
        prefer_weight_spectrum=prefer_weight_spectrum,
        n_workers=n_workers,
        beam_sel=beam_sel,
    )
    for block_index, block_dir in enumerate(_list_block_dirs(ms_root)):
        block_vis = read_ms_block_I(block_dir, **read_kwargs)
        for chan, beam, vis_I, sig_I, u_m, v_m, w_m in block_vis.iter_chan_beam_I():
            yield block_index, chan, beam, vis_I, sig_I, u_m, v_m, w_m
        del block_vis


def iter_blocks_channel_slabs(
    ms_root: str,
    uvmin: float = 0.0,
    uvmax: float = float("inf"),
    chan_sel=None,
    slab: int = 64,
    keep_autocorr: bool = False,
    prefer_weight_spectrum: bool = True,
    n_workers: int = 0,
    concat: bool = False,                # if True, concat slabs across blocks before yielding
    beam_sel=None,
    rest_freq: float = 1.42040575177e9,  # HI rest frequency in Hz
):
    """
    Yield channel slabs block by block, or beam-concatenated slabs across blocks.

    If ``concat=False``, yields ``(bi, block_dir, c0, c1, visI)`` for every
    slab of every block, all of block 0's slabs first. The caller should
    ``del visI`` when done with it.

    If ``concat=True``, one slab iterator per block is advanced in lockstep.
    For each channel range it yields ``(c0, c1, visI_concat)``, where
    ``visI_concat`` holds the beams of all blocks for that range (block 0's
    beams first), as produced by ``concat_visidata_slabs``. This matches
    ``mode="stack"`` of ``read_ms_blocks_I``, but streamed. Only the current
    slab of each block is held in memory.

    Raises
    ------
    ValueError
        With ``concat=True``, if blocks produce different numbers of slabs or
        different channel ranges (their channel grids must match).
    """
    block_dirs = _list_block_dirs(ms_root)
    slab_kwargs = dict(
        uvmin=uvmin, uvmax=uvmax,
        chan_sel=chan_sel,
        rest_freq=rest_freq,
        slab=slab,
        keep_autocorr=keep_autocorr,
        prefer_weight_spectrum=prefer_weight_spectrum,
        n_workers=n_workers,
        beam_sel=beam_sel,
    )

    if not concat:
        for bi, bdir in enumerate(block_dirs):
            for c0, c1, visI in iter_channel_slabs(bdir, **slab_kwargs):
                yield bi, bdir, c0, c1, visI  # caller should del visI when done
        return

    from itertools import zip_longest

    _end = object()  # marks an exhausted block iterator
    slab_iters = [iter_channel_slabs(bdir, **slab_kwargs) for bdir in block_dirs]
    for per_block in zip_longest(*slab_iters, fillvalue=_end):
        if any(item is _end for item in per_block):
            raise ValueError("Blocks produced different numbers of channel slabs; "
                             "their channel grids must match for concat=True.")
        ranges = {(item[0], item[1]) for item in per_block}
        if len(ranges) != 1:
            raise ValueError(f"Blocks produced mismatched channel ranges {sorted(ranges)}; "
                             "their channel grids must match for concat=True.")
        c0, c1 = per_block[0][0], per_block[0][1]
        vis_concat = concat_visidata_slabs([item[2] for item in per_block])
        del per_block  # release per-block slabs before handing out the concatenation
        yield c0, c1, vis_concat


def concat_visidata_slabs(slabs: list[VisIData]) -> VisIData:
    """
    Stack several VisIData sharing one channel axis into a single VisIData.

    - Beams are appended in input order: all beams of ``slabs[0]``, then
      those of ``slabs[1]``, and so on.
    - Each beam's first ``nvis`` visibilities are copied into an output
      padded to the largest ``nvis`` over all inputs. Padding stays zero with
      ``flag_I = True``.
    - ``frequency`` and ``velocity`` are taken from ``slabs[0]``; the other
      inputs are not checked.

    Raises
    ------
    ValueError
        If ``slabs`` is empty.
    """
    if not slabs:
        raise ValueError("No slabs to concatenate")

    first = slabs[0]
    n_chan = first.frequency.shape[0]
    beam_counts = [s.uu.shape[0] for s in slabs]
    n_beam = sum(beam_counts)
    width = max(int(s.nvis.max()) for s in slabs)

    data_I = np.zeros((n_chan, n_beam, width), dtype=np.complex64)
    sigma_I = np.zeros((n_chan, n_beam, width), dtype=np.float32)
    flag_I = np.ones((n_chan, n_beam, width), dtype=bool)
    uu = np.zeros((n_beam, width), dtype=np.float32)
    vv = np.zeros_like(uu)
    ww = np.zeros_like(uu)
    nvis = np.zeros((n_beam,), dtype=np.int32)
    centers = np.empty((n_beam,), dtype=object)

    dst = 0  # running output beam index
    for src, n_src_beams in zip(slabs, beam_counts):
        for j in range(n_src_beams):
            nv = int(src.nvis[j])
            nvis[dst] = nv
            centers[dst] = src.centers[j]
            uu[dst, :nv] = src.uu[j, :nv]
            vv[dst, :nv] = src.vv[j, :nv]
            ww[dst, :nv] = src.ww[j, :nv]
            data_I[:, dst, :nv] = src.data_I[:, j, :nv]
            sigma_I[:, dst, :nv] = src.sigma_I[:, j, :nv]
            flag_I[:, dst, :nv] = src.flag_I[:, j, :nv]
            dst += 1

    return VisIData(
        frequency=first.frequency,
        velocity=first.velocity,
        centers=centers,
        nvis=nvis,
        uu=uu, vv=vv, ww=ww,
        data_I=data_I,
        sigma_I=sigma_I,
        flag_I=flag_I,
    )


def iter_blocks_chan_beam_via_slabs(
    ms_root: str,
    uvmin: float = 0.0,
    uvmax: float = float("inf"),
    chan_sel=None,
    slab: int = 64,
    keep_autocorr: bool = False,
    prefer_weight_spectrum: bool = True,
    n_workers: int = 0,                  # 0/1 = serial; >1 = parallel per-MS
    beam_sel=None,
    rest_freq: float = 1.42040575177e9,  # HI rest frequency in Hz
):
    """
    Yield (bi, c_abs, b, I, sI, uu, vv, ww), streaming through slabs per block.

    ``c_abs`` is the absolute channel index in the SPW. Each block is read
    slab by slab with ``iter_channel_slabs``, forwarding ``rest_freq`` and
    ``n_workers``. (c, b) pairs whose ``slice_chan_beam_I`` returns an empty
    ``I`` are skipped. Each slab is dropped before the next one is read.
    """
    block_dirs = _list_block_dirs(ms_root)
    for bi, bdir in enumerate(block_dirs):
        for c0, c1, visI in iter_channel_slabs(
            bdir,
            uvmin=uvmin, uvmax=uvmax,
            chan_sel=chan_sel,
            rest_freq=rest_freq,
            slab=slab,
            keep_autocorr=keep_autocorr,
            prefer_weight_spectrum=prefer_weight_spectrum,
            n_workers=n_workers,
            beam_sel=beam_sel,
        ):
            nchan_slab = visI.data_I.shape[0]
            nbeam = visI.uu.shape[0]
            # iterate tiny chunks from the slab
            for c_rel in range(nchan_slab):
                c_abs = c0 + c_rel
                for b in range(nbeam):
                    I, sI, uu, vv, ww = visI.slice_chan_beam_I(c_rel, b)
                    if I.size:
                        yield bi, c_abs, b, I, sI, uu, vv, ww
            del visI  # free slab memory


class CasacoreReader:
    """
    Concrete Reader backed by casacore.

    A thin object wrapper around this module's functions. The constructor
    stores default values for ``prefer_weight_spectrum``, ``keep_autocorr``
    and ``n_workers``. The read methods use these defaults unless a keyword
    argument of the same name is passed.
    """

    def __init__(self, *, prefer_weight_spectrum: bool = True,
                 keep_autocorr: bool = False, n_workers: int = 0):
        self.prefer_weight_spectrum = prefer_weight_spectrum
        self.keep_autocorr = keep_autocorr
        self.n_workers = n_workers

    # --- optional helpers (not required by the Protocol) ---

    def list_ms(self, ms_dir: str) -> List[str]:
        """Return the MS paths of ``ms_dir`` as listed by ``_list_ms_sorted``."""
        return _list_ms_sorted(ms_dir)

    def freq_grid(self, ms_dir: str, rest_freq: float):
        """Return the channel frequencies [Hz] of the first MS found in ``ms_dir``."""
        ms_paths = self.list_ms(ms_dir)
        if len(ms_paths) == 0:
            raise FileNotFoundError(f"No .ms found in {ms_dir}")
        frequencies = _freqs(ms_paths[0], rest_freq)[0]
        return frequencies

    # --- Protocol methods ---

    def _shared_options(self, kwargs: dict) -> dict:
        """Keyword options common to every read call, falling back to instance defaults."""
        return dict(
            uvmin=kwargs.get("uvmin", 0.0),
            uvmax=kwargs.get("uvmax", float("inf")),
            chan_sel=kwargs.get("chan_sel"),
            rest_freq=kwargs.get("rest_freq", 1.42040575177e9),
            keep_autocorr=kwargs.get("keep_autocorr", self.keep_autocorr),
            prefer_weight_spectrum=kwargs.get("prefer_weight_spectrum", self.prefer_weight_spectrum),
            n_workers=kwargs.get("n_workers", self.n_workers),
            beam_sel=kwargs.get("beam_sel"),
        )

    def read_blocks_I(self, ms_root: str, **kwargs) -> Union["VisIData", List["VisIData"]]:
        """Load every block under ``ms_root`` via ``read_ms_blocks_I``."""
        options = self._shared_options(kwargs)
        options.update(
            mode=kwargs.get("mode", "merge"),
            center_tol_deg=kwargs.get("center_tol_deg", 1e-12),
            target_center=kwargs.get("target_center"),
            target_radius=kwargs.get("target_radius"),
        )
        return read_ms_blocks_I(ms_root, **options)

    def iter_channel_slabs(self, ms_dir: str, **kwargs) -> Iterator[Tuple[int, int, VisIData]]:
        """Stream channel slabs of one block via the module-level ``iter_channel_slabs``."""
        options = self._shared_options(kwargs)
        options["slab"] = kwargs.get("slab", 64)
        return iter_channel_slabs(ms_dir, **options)


# ------------------------------- demo -------------------------------------
if __name__ == "__main__":
    # Usage: python ms_casacore.py <ms_root> [merge|slabs-concat|slabs-per-block]
    # The default demo ("merge") is the one the previous hard-coded script ran
    # before it halted.
    import sys

    HI_REST_FREQ_HZ = 1.42040575177e9  # HI rest frequency in Hz

    def _demo_merge(ms_dir: str) -> None:
        """Merge channels 0..3 of every block into one VisIData and walk its (chan, beam) slices."""
        print("Test #Concat all")
        vis = read_ms_blocks_I(
            ms_root=ms_dir,
            uvmin=20.0,
            uvmax=5000.0,
            chan_sel=slice(0, 4),
            rest_freq=HI_REST_FREQ_HZ,
            keep_autocorr=False,
            prefer_weight_spectrum=True,
            mode="merge",
            n_workers=4,
        )
        for _c, _b, _Ib, _sI, _uu, _vv, _ww in vis.iter_chan_beam_I():
            pass

    def _demo_slabs_concat(ms_dir: str) -> None:
        """Option A: stream channel slabs with blocks concatenated (moderate RAM, simple)."""
        print("Test # Option A — Stream slabs per block (moderate RAM, simple)")
        for c0, c1, visI in iter_blocks_channel_slabs(
            ms_root=ms_dir,
            uvmin=20,
            uvmax=5000,
            chan_sel=slice(0, 128),
            slab=16,
            concat=True,
            n_workers=4,
        ):
            # Count all unflagged visibilities in this slab
            total_vis = np.count_nonzero(~visI.flag_I)
            logger.info(
                f"Concat slab [{c0}:{c1}) -> {visI.data_I.shape}, total unflagged vis={total_vis}"
            )
            for rel_c in range(visI.frequency.shape[0]):
                logger.info(
                    f"Get single channel {rel_c} from slab [{c0}:{c1}) -> {visI.data_I.shape}"
                )
                _chan_vis = visI.single_channel(rel_c, copy=False)  # per-channel view
            del visI  # release the slab before the next one is read

    def _demo_slabs_per_block(ms_dir: str) -> None:
        """Stream channel slabs block by block, reporting which block each slab came from."""
        for bi, bdir, c0, c1, visI in iter_blocks_channel_slabs(
            ms_root=ms_dir,
            uvmin=20,
            uvmax=5000,
            chan_sel=slice(0, 128),
            slab=16,
        ):
            logger.info(f"[block {bi}] {os.path.basename(bdir)} slab [{c0}:{c1}) -> {visI.data_I.shape}")
            for rel_c in range(visI.frequency.shape[0]):
                logger.info(f"Get single channel {rel_c} from slab [{c0}:{c1}) -> {visI.data_I.shape}")
                _chan_vis = visI.single_channel(rel_c, copy=False)  # per-channel view
            del visI  # release the slab before the next one is read

    _DEMOS = {
        "merge": _demo_merge,
        "slabs-concat": _demo_slabs_concat,
        "slabs-per-block": _demo_slabs_per_block,
    }

    if len(sys.argv) < 2:
        raise SystemExit(f"usage: {sys.argv[0]} <ms_root> [{'|'.join(_DEMOS)}]")
    _ms_dir = sys.argv[1]
    _which = sys.argv[2] if len(sys.argv) > 2 else "merge"
    if _which not in _DEMOS:
        raise SystemExit(f"unknown demo {_which!r}; choose one of: {', '.join(_DEMOS)}")
    _DEMOS[_which](_ms_dir)

# Superseded, commented-out reader variants (without beam selection, without
# parallel reads, and an older "concat"-mode multi-block loader) were removed
# from this module; recover them from version control if needed.