Source code for ivis.models.utils.tensor_ops
import torch
[docs]
def format_input_tensor(input_tensor: torch.Tensor) -> torch.Tensor:
    """
    Format an input tensor for PyTorch's grid_sample.

    Leading singleton dimensions are added so the result has the
    (N, C, H, W) layout that grid_sample expects:

    * 2D ``(H, W)``    -> ``(1, 1, H, W)``
    * 3D ``(C, H, W)`` -> ``(1, C, H, W)`` (the leading dim becomes channels)
    * 4D or higher     -> returned unchanged (the same tensor object)

    Note that only a 2D input is guaranteed to come out with N=1 and C=1.
    The new dimensions are added with ``unsqueeze``, which returns a view,
    so no data is copied and the result shares storage with the input.

    Parameters
    ----------
    input_tensor : torch.Tensor
        A tensor with at least 2 dimensions: 2D, 3D, or already 4D
        (higher-rank tensors are passed through untouched).

    Returns
    -------
    formatted_tensor : torch.Tensor
        Tensor reshaped for use with grid_sample.

    Raises
    ------
    ValueError
        If ``input_tensor`` has fewer than 2 dimensions. Such a tensor has no
        (H, W) spatial extent, and grid_sample would reject it anyway.
    """
    ndim = input_tensor.dim()
    if ndim < 2:
        # Fail early with a clear message instead of handing back an
        # unformatted tensor that breaks later inside grid_sample.
        raise ValueError(
            "format_input_tensor expects a tensor with at least 2 dimensions, "
            f"got shape {tuple(input_tensor.shape)}"
        )
    if ndim == 2:
        # (H, W) -> (1, 1, H, W): add both batch and channel axes.
        return input_tensor.unsqueeze(0).unsqueeze(0)
    if ndim == 3:
        # (C, H, W) -> (1, C, H, W): add only the batch axis.
        return input_tensor.unsqueeze(0)
    # Already batched (4D, or 5D volumetric input); leave as-is.
    return input_tensor