Source code for ivis.models.base
from abc import ABC, abstractmethod
import torch
import numpy as np
[docs]
class BaseModel(ABC):
    """
    Abstract base class for IViS-compatible imaging models.

    Subclasses must implement :meth:`loss`, which returns the scalar
    objective together with its gradient for a flattened parameter vector.
    ``loss`` is the only abstract method, so a subclass that defines it can
    be instantiated.

    Notes
    -----
    A ``forward()`` method (simulating model visibilities from image
    parameters) is sketched below but is currently disabled. It is not part
    of the enforced interface, and subclasses are not required to provide it.
    """

    @abstractmethod
    def loss(self, x: np.ndarray, *args) -> tuple[float, np.ndarray]:
        """
        Compute scalar loss and gradient for optimization.

        Parameters
        ----------
        x : np.ndarray
            Flattened parameter vector.
        *args
            Additional model-specific arguments; their meaning is defined by
            each concrete subclass.

        Returns
        -------
        loss : float
            Scalar loss.
        grad : np.ndarray
            Flattened gradient (presumably the same length as ``x`` --
            confirm against the optimizer that consumes it).
        """
        pass

    # NOTE: ``forward`` is currently disabled. Re-enabling it with
    # ``@abstractmethod`` would make every existing subclass that does not
    # define ``forward`` impossible to instantiate, so do that only together
    # with updating all subclasses.
    #
    # @abstractmethod
    # def forward(self, x: np.ndarray, *args) -> np.ndarray:
    #     """
    #     Simulate model visibilities from image parameters.
    #
    #     Parameters
    #     ----------
    #     x : np.ndarray
    #         Sky model or parameters (flattened or shaped).
    #
    #     Returns
    #     -------
    #     model_vis : np.ndarray
    #         Predicted complex visibilities.
    #     """
    #     pass