# Source code for xcal.models

import numpy as np
import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

from xcal.chem_consts._consts_from_table import get_mass_absp_c_vs_E
from xcal.chem_consts._periodictabledata import atom_weights, ptableinverse
from xcal.dict_gen import gen_fltr_res, gen_scint_cvt_func


def transpose_first_to_last(x):
    """Move the leading dimension of ``x`` to the end.

    A tensor of shape ``(d0, d1, ..., dk)`` becomes a permuted view of shape
    ``(d1, ..., dk, d0)``. A 0-d tensor has no dimension to move and is
    returned unchanged (previously this raised ``IndexError``).
    """
    if x.ndimension() == 0:
        return x
    dims = list(range(x.ndimension()))
    dims.append(dims.pop(0))
    return x.permute(dims)

def transpose_last_to_first(x):
    """Move the trailing dimension of ``x`` to the front.

    Inverse of ``transpose_first_to_last``: shape ``(d0, ..., dk-1, dk)``
    becomes a permuted view of shape ``(dk, d0, ..., dk-1)``. A 0-d tensor is
    returned unchanged (previously this raised ``IndexError``).
    """
    if x.ndimension() == 0:
        return x
    dims = list(range(x.ndimension()))
    dims.insert(0, dims.pop(-1))
    return x.permute(dims)

def linear_interp(x, xp, fp):
    """
    Performs piecewise-linear interpolation with linear extrapolation.

    Query points outside ``[xp[0], xp[-1]]`` are extrapolated from the first or
    last segment; nothing is clamped and no error is raised for them.

    Args:
        x (torch.Tensor): The x-coordinates at which to evaluate the interpolated values.
        xp (torch.Tensor): The x-coordinates of the data points, sorted ascending
                           (required by ``torch.searchsorted``).
        fp (torch.Tensor): The y-coordinates of the data points (same shape as xp).

    Returns:
        torch.Tensor: The interpolated values, with the same shape as ``x``. With a
                      single data point the result is that point's value everywhere.

    Raises:
        ValueError: If ``xp`` is empty.
    """
    if len(xp) == 0:
        raise ValueError("xp must contain at least one data point.")
    if len(xp) == 1:
        # A single sample defines a constant; the segment logic below would clamp
        # the index to -1 and divide 0 by 0, producing NaN.
        return fp[0].expand(torch.as_tensor(x).shape).clone()

    # Left end of the segment [xp[idx], xp[idx + 1]] used for each query point
    # (searchsorted returns the first index with xp[i] >= x, hence the -1).
    idx = torch.searchsorted(xp, x) - 1
    # Clamp so out-of-range points use the first/last segment (extrapolation).
    idx = idx.clamp(0, len(xp) - 2)

    # Compute the slope of the segments
    slope = (fp[idx + 1] - fp[idx]) / (xp[idx + 1] - xp[idx])

    # Evaluate the line segment at x
    return fp[idx] + slope * (x - xp[idx])

class Interp2D:
    def __init__(self, x, y, z):
        """
        Initialize the Interp2D class for performing bilinear interpolation on a 2D grid.

        Args:
            x (torch.Tensor): A 2-D tensor representing the x-coordinates of the grid points,
                              with shape (M, N). The x-coordinate varies along dimension 0 and is
                              constant along dimension 1 (e.g. the first output of
                              ``torch.meshgrid(..., indexing='ij')``).
            y (torch.Tensor): A 2-D tensor representing the y-coordinates of the grid points,
                              with shape (M, N). The y-coordinate varies along dimension 1 and is
                              constant along dimension 0.
            z (torch.Tensor): An N-D tensor of z-values corresponding to the grid points
                              defined by x and y coordinates, with the first two dimensions
                              matching the shape of x and y (M, N), and any additional dimensions
                              representing different variables or measurements at each grid point.
        """
        self.x = x
        self.y = y
        self.z = z

    def __call__(self, new_x, new_y):
        """
        Perform bilinear interpolation to find z-values at a new (x, y) coordinate.

        Args:
            new_x (torch.Tensor): A scalar tensor representing the new x-coordinate where
                                  the z-value is to be interpolated. A tensor of query points
                                  with the same shape as ``new_y`` is also accepted.
            new_y (torch.Tensor): A scalar tensor representing the new y-coordinate where
                                  the z-value is to be interpolated.

        Returns:
            torch.Tensor: The interpolated z-values, clamped to be non-negative, with shape
                          ``new_x.shape + z.shape[2:]`` (so a scalar query on a 3-D ``z``
                          returns a 1-D tensor, and on a 2-D ``z`` a 0-d tensor).

        Raises:
            ValueError: If the new_x or new_y values are outside the range of the original
                        x or y grid coordinates.
        """
        in_x = torch.all(torch.logical_and(new_x >= self.x.min(), new_x <= self.x.max()))
        in_y = torch.all(torch.logical_and(new_y >= self.y.min(), new_y <= self.y.max()))
        if not (in_x and in_y):
            raise ValueError("The new_x or new_y values are outside the range of x or y.")

        # 1-D grid axes: x varies along dim 0 (length M), y along dim 1 (length N).
        # contiguous() avoids searchsorted's slow path/warning on the strided column.
        x_axis = self.x[:, 0].contiguous()
        y_axis = self.y[0, :]

        # Lower-corner index of the enclosing cell. Each index is clamped against
        # the length of its OWN axis; clamping x by N and y by M (as before) picked
        # the wrong cell on non-square grids.
        x_indices = torch.clamp(torch.searchsorted(x_axis, new_x) - 1, 0, x_axis.numel() - 2)
        y_indices = torch.clamp(torch.searchsorted(y_axis, new_y) - 1, 0, y_axis.numel() - 2)

        # Calculate the four corner points for bilinear interpolation
        x0 = x_axis[x_indices]
        x1 = x_axis[x_indices + 1]
        y0 = y_axis[y_indices]
        y1 = y_axis[y_indices + 1]

        # Corner z-values; each has shape x_indices.shape + z.shape[2:]
        z00 = self.z[x_indices, y_indices]
        z01 = self.z[x_indices, y_indices + 1]
        z10 = self.z[x_indices + 1, y_indices]
        z11 = self.z[x_indices + 1, y_indices + 1]

        # Append singleton dims so per-point weights broadcast over z's trailing
        # dims (this replaces the transposes, which failed for a 2-D z).
        trailing = (1,) * (self.z.dim() - 2)

        def _bcast(w):
            # Reshape a per-query weight so it broadcasts against a corner z-value.
            return w.reshape(w.shape + trailing)

        # Compute the weights for bilinear interpolation
        w00 = _bcast((x1 - new_x) * (y1 - new_y))
        w01 = _bcast((x1 - new_x) * (new_y - y0))
        w10 = _bcast((new_x - x0) * (y1 - new_y))
        w11 = _bcast((new_x - x0) * (new_y - y0))
        cell_area = _bcast((x1 - x0) * (y1 - y0))

        # Perform bilinear interpolation
        interpolated_z = (w00 * z00 + w01 * z01 + w10 * z10 + w11 * z11) / cell_area
        return torch.clamp(interpolated_z, min=0)  # Ensure non-negativity


class Interp1D:
    def __init__(self, x, y):
        """
        Initialize the Interp1D class for performing piecewise-linear interpolation.

        Args:
            x (torch.Tensor): A 1-D, ascending tensor of x-coordinates, representing the
                              positions at which the y-values are known.
            y (torch.Tensor): An N-D tensor of y-coordinates corresponding to x. The first
                              dimension of y must match the length of x, and any additional
                              dimensions represent different sets of values to interpolate.
        """
        self.x = x
        self.y = y

    def __call__(self, new_x):
        """
        Perform linear interpolation to find y-values at new_x, a scalar or 1-D tensor
        of new x-coordinates.

        Args:
            new_x (torch.Tensor): A scalar or 1-D tensor of new x-coordinates where y-values
                                  are to be interpolated. Non-tensor inputs (e.g. a Python
                                  float) are converted to a tensor with ``x``'s dtype.

        Returns:
            torch.Tensor: Interpolated y-values clamped to be non-negative, with shape
                          ``new_x.shape + y.shape[1:]``. With a single data point in ``x``,
                          that point's y-value is returned unclamped (repeated once per
                          query for a 1-D ``new_x``).

        Raises:
            ValueError: If any of the new_x values are outside the range of the original x
                        coordinates (not checked when ``x`` has a single point).
        """
        if not isinstance(new_x, torch.Tensor):
            new_x = torch.as_tensor(new_x, dtype=self.x.dtype, device=self.x.device)

        # Handle the special case of a single data point: no interpolation needed.
        if len(self.x) == 1:
            if new_x.dim() == 0:
                # len() is undefined for a 0-d tensor; return the lone sample instead.
                return self.y[0].clone()
            return self.y.repeat(len(new_x), *([1] * (self.y.dim() - 1)))

        if not torch.all(torch.logical_and(new_x >= self.x.min(), new_x <= self.x.max())):
            raise ValueError("Some values in new_x are outside the range of x.")

        # Upper end of the bracketing segment [x[i-1], x[i]]; clamping to [1, len-1]
        # keeps both neighbours valid, including new_x == x[0].
        indices = torch.searchsorted(self.x, new_x)
        indices = torch.clamp(indices, 1, len(self.x) - 1)

        x0 = self.x[indices - 1]
        x1 = self.x[indices]
        y0 = self.y[indices - 1]  # shape: new_x.shape + y.shape[1:]
        y1 = self.y[indices]
        alpha = (new_x - x0) / (x1 - x0)
        # Append singleton dims so the per-point weight broadcasts over y's trailing
        # dims; this replaces the transposes, which failed for a 1-D y and scalar new_x.
        alpha = alpha.reshape(alpha.shape + (1,) * (self.y.dim() - 1))

        # Perform linear interpolation
        interpolated_y = y0 + alpha * (y1 - y0)
        return torch.clamp(interpolated_y, min=0)  # Ensure non-negativity

class ClampFunction(torch.autograd.Function):
    """Clamp with a straight-through gradient.

    The forward pass clamps ``input`` to ``[min, max]``. The backward pass hands the
    incoming gradient to ``input`` unchanged, so elements outside the bounds still
    receive gradient (``torch.clamp`` would give them zero). The bounds receive no
    gradient.
    """

    @staticmethod
    def forward(ctx, input, min, max):
        """Return ``input.clamp(min, max)``.

        Nothing is saved on ``ctx``: the straight-through backward never reads the
        input, so saving it would only keep the tensor alive until backward runs.
        """
        return input.clamp(min, max)

    @staticmethod
    def backward(ctx, grad_output):
        """Pass ``grad_output`` through to ``input``; ``min`` and ``max`` get ``None``."""
        return grad_output.clone(), None, None


def clamp_with_grad(input, min, max):
    """Clamp ``input`` to ``[min, max]`` with a straight-through gradient.

    Functional wrapper around ``ClampFunction.apply``: values are clamped in the
    forward pass, while the gradient reaches ``input`` unmodified.
    """
    straight_through_clamp = ClampFunction.apply
    return straight_through_clamp(input, min, max)

def normalize_tuple_as_parameter(tuple_value):
    """
    Map an ``(initial value, lower bound, upper bound)`` tuple onto a normalized scalar.

    Missing bounds (``None``) default to the initial value. When both bounds are
    equal the quantity is treated as fixed, and a plain (non-Parameter) float32
    tensor holding ``1`` is returned in place of the normalized value. Otherwise the
    initial value is rescaled to ``(initial - lower) / (upper - lower)`` and wrapped
    in a float32 PyTorch ``Parameter``.

    Args:
        tuple_value (tuple): A tuple containing the initial value, lower bound, and upper bound.

    Returns:
        tuple: ``(normalized value, lower bound, upper bound)``; the normalized value is a
        ``Parameter`` scalar unless the bounds coincide.

    Raises:
        ValueError: If the initial value is ``None``.
    """
    start, lo, hi = tuple_value
    if start is None:
        raise ValueError("initial_value tuple_value[0] cannot be None.")
    lo = start if lo is None else lo
    hi = start if hi is None else hi

    if lo == hi:
        # Fixed quantity: nothing to optimize.
        return (torch.tensor(1, dtype=torch.float32), lo, hi)

    scaled = (start - lo) / (hi - lo)
    return (Parameter(torch.tensor(scaled, dtype=torch.float32)), lo, hi)


def denormalize_parameter_as_tuple(tuple_value):
    """
    Map a normalized ``(value, lower bound, upper bound)`` tuple back to physical units.

    With distinct bounds the value is rescaled as ``value * (upper - lower) + lower``.
    With equal bounds (a fixed quantity) it is simply multiplied by the lower bound.

    Args:
        tuple_value (tuple): ``(normalized PyTorch Parameter, lower bound, upper bound)``.

    Returns:
        tuple: ``(denormalized value, lower bound, upper bound)``.

    Raises:
        ValueError: If either bound is ``None``.
    """
    value, lo, hi = tuple_value
    if lo is None or hi is None:
        raise ValueError("lower_bound tuple_value[1] and upper_bound tuple_value[2] cannot be None.")

    if lo == hi:
        physical = value * lo
    else:
        physical = value * (hi - lo) + lo
    return (physical, lo, hi)

def merge_dicts(list1, list2):
    """Join two lists of dictionaries on their shared keys.

    Every pair ``(d1, d2)`` from ``list1 x list2`` whose values agree on all
    common keys yields ``{**d1, **d2}``. Pairs with no common keys always match.
    The output order follows ``list1`` in the outer loop and ``list2`` in the inner loop.
    """
    return [
        {**left, **right}
        for left in list1
        for right in list2
        if all(left[k] == right[k] for k in left.keys() & right.keys())
    ]
def get_merged_params_list(lists):
    """Fold ``merge_dicts`` over a sequence of parameter-combination lists.

    Each input list describes the admissible parameter combinations of one
    component. The result keeps every combination that agrees on all keys shared
    between components.

    Empty input lists contribute no parameters and are skipped. Once folding has
    started, an empty intermediate result (no mutually consistent combination) is
    kept. Previously the next list silently replaced it, discarding the
    constraints of all earlier lists.

    Args:
        lists (iterable of list[dict]): Parameter-combination lists to merge.

    Returns:
        list[dict]: The merged combinations, or ``[]`` if every input list is empty or
        no consistent combination exists. A lone non-empty input list is returned
        as-is (not copied).
    """
    merged_params_list = None  # None means "no non-empty list seen yet"
    for params_list in lists:
        if len(params_list) == 0:
            continue
        if merged_params_list is None:
            merged_params_list = params_list
        else:
            merged_params_list = merge_dicts(merged_params_list, params_list)
    return [] if merged_params_list is None else merged_params_list

def pair_params_list(list1, list2):
    """
    Form the Cartesian product of two parameter-combination lists.

    Every dictionary of ``list1`` is combined with every dictionary of ``list2``
    into ``{**params1, **params2}``. On a duplicate key the value from ``list2``
    wins, while the key keeps its position from ``params1``. Keys are used as-is
    (no prefixing).

    Args:
        list1 (list[dict]): First list of parameter dictionaries (outer loop).
        list2 (list[dict]): Second list of parameter dictionaries (inner loop).

    Returns:
        list[dict]: ``len(list1) * len(list2)`` combined dictionaries.
    """
    return [{**params1, **params2} for params1 in list1 for params2 in list2]

def get_concatenated_params_list(lists):
    """Fold ``pair_params_list`` (Cartesian product) over parameter-combination lists.

    Empty input lists contribute no parameters and are skipped. Previously, an
    empty list after the first one made the running product ``[]``, and the next
    list then silently replaced it, dropping all earlier combinations.

    Args:
        lists (iterable of list[dict]): Parameter-combination lists to combine.

    Returns:
        list[dict]: All combined dictionaries, or ``[]`` if every input list is empty.
        A lone non-empty input list is returned as-is (not copied).
    """
    merged_params_list = None  # None means "no non-empty list seen yet"
    for params_list in lists:
        if len(params_list) == 0:
            continue
        if merged_params_list is None:
            merged_params_list = params_list
        else:
            merged_params_list = pair_params_list(merged_params_list, params_list)
    return [] if merged_params_list is None else merged_params_list

class Base_Spec_Model(Module):
    def __init__(self, params_list=None):
        """Base class for all spectral components in xcal.

        Args:
            params_list (list, optional): List of dictionaries containing possible discrete and
                continuous parameters combinations. All dictionaries share the same keywords.
                Each dictionary contains both discrete and continuous parameters. Continuous
                parameters should be specified as tuples with the format (initial value,
                lower bound, upper bound), while discrete parameters can be directly specified.
                ``None`` (the default) means no parameters.
        """
        super().__init__()
        cls = self.__class__
        # Per-class instance counter used for the prefix, e.g. "Filter_1". Only the
        # class's own namespace is consulted: hasattr() would also see a counter
        # inherited from a parent class and continue the parent's numbering.
        if '_count' not in cls.__dict__:
            cls._count = 0
        cls._count += 1
        self.prefix = f"{cls.__name__}_{cls._count}"

        # params_list contains all possible discrete parameters combinations and related
        # continuous parameters. Subclass keys are prefixed with the instance prefix;
        # tuple values are stored as normalized (value, lower, upper) tuples.
        self._params_list = []
        for params in (params_list if params_list is not None else []):
            new_params = {}
            for key, value in params.items():
                if cls.__name__ != 'Base_Spec_Model':
                    modified_key = f"{self.prefix}_{key}"
                else:
                    modified_key = f"{key}"
                if isinstance(value, tuple):
                    new_params[modified_key] = normalize_tuple_as_parameter(value)
                else:
                    new_params[modified_key] = value
            self._params_list.append(new_params)
        self._init_estimates()

    def set_spectrum(self, energies, sp):
        """Store a reference spectrum that ``forward`` interpolates.

        Args:
            energies (numpy.array): A numpy array containing the X-ray energies of a
                poly-energetic source in units of keV.
            sp (numpy.array): Spectrum values at ``energies``.
        """
        self.ref_sp_energies = torch.tensor(energies)
        self.ref_sp = torch.tensor(sp)

    def forward(self, energies):
        """Placeholder forward method.

        Args:
            energies (numpy.array): A numpy array containing the X-ray energies of a
                poly-energetic source in units of keV.

        Returns:
            torch.Tensor: The reference spectrum linearly interpolated at ``energies`` if one
            was set with ``set_spectrum``; otherwise a message is printed and a tensor of ones
            is returned.
        """
        energies = torch.tensor(energies)
        # Check if ref_sp_energies and ref_sp attributes are set
        if hasattr(self, 'ref_sp_energies') and hasattr(self, 'ref_sp'):
            return linear_interp(energies, self.ref_sp_energies, self.ref_sp)
        # Handle the case where ref_sp is not set
        print("ref_sp_energies or ref_sp is not set.")
        return torch.ones(len(energies))

    def _init_estimates(self):
        """Initialize estimates (and same-named attributes) from the first dictionary in params_list."""
        self.estimates = {}
        if len(self._params_list) == 0:
            return
        for key, value in self._params_list[0].items():
            self.estimates[key] = value
            if isinstance(value, tuple):
                # Registering the normalized value as an attribute makes nn.Module track it
                # when it is a Parameter.
                setattr(self, key, self.estimates[key][0])
            else:
                setattr(self, key, value)

    def set_params(self, params):
        """Set estimates from a dictionary of parameters.

        Keys not already present in ``self.estimates`` are ignored. Tuples whose first
        element is not a tensor are normalized first; other tuples are assumed to already
        be in (normalized value, lower, upper) form.

        Args:
            params (dict): Dictionary containing parameters.
        """
        for key, value in params.items():
            if key not in self.estimates.keys():
                continue
            if isinstance(value, tuple):
                if not isinstance(value[0], torch.Tensor):
                    normalized_value = normalize_tuple_as_parameter(value)
                    setattr(self, key, normalized_value[0])
                    self.estimates[key] = normalized_value
                else:
                    setattr(self, key, value[0])
                    self.estimates[key] = value
            else:
                setattr(self, key, value)
                self.estimates[key] = value

    def get_params(self):
        """Read estimated parameters as a dictionary.

        Returns:
            dict: Continuous parameters denormalized and clamped to their bounds (with a
            straight-through gradient); discrete parameters as stored.
        """
        display_estimates = {}
        for key, value in self.estimates.items():
            if isinstance(value, tuple):
                dv = denormalize_parameter_as_tuple(value)
                display_estimates[key] = clamp_with_grad(dv[0], dv[1], dv[2])
            else:
                display_estimates[key] = value
        return display_estimates
def first_nonzero_from_right(arr):
    """
    Finds the index of the first non-zero element from right to left in a 1D NumPy array.

    Args:
        arr (numpy.ndarray): A 1D NumPy array.

    Returns:
        int: The index of the first non-zero element from the right. If the array is empty
        or no non-zero element is found, returns -1.
    """
    if len(arr) == 0:
        # np.argmax raises on an empty sequence; an empty array has no non-zero element.
        return -1
    # Reverse the array and use np.argmax to find the first non-zero element from the right
    rev_index = np.argmax(arr[::-1] != 0)
    # argmax returns 0 when nothing is True, so check whether all elements are zero
    if arr[::-1][rev_index] == 0:
        return -1
    # Convert the reverse index to the original index
    return len(arr) - 1 - rev_index


def prepare_for_interpolation(src_spec_list, kV_index=None):
    """
    Prepare the source spectral list for interpolation over voltage.

    ``src_spec_list[i]`` is the spectrum for the i-th voltage, with all spectra on a
    shared energy grid. The cut-off of each spectrum is taken as the index of its last
    non-zero bin. For every bin ``v`` between the cut-off ``v0`` of spectrum ``i`` and the
    cut-off ``v1`` of spectrum ``i + 1``, spectrum ``i`` is overwritten with
    ``-r / (1 - r) * f1[v]``, where ``r = (v - v0) / (v1 - v0)`` and ``f1`` is spectrum ``i + 1``.
    A linear blend ``(1 - t) * f0 + t * f1`` is then exactly zero at bin ``v`` when
    ``t == r``.

    Args:
        src_spec_list (array-like): 2-D array of source spectral responses, one row per
            voltage (rows ordered by voltage).
        kV_index (list, optional): Ignored; the cut-off indices are always recomputed from
            the spectra. Kept for backward compatibility.

    Returns:
        numpy.ndarray: A modified copy ready for interpolation. The input is never modified.
        Integer inputs are promoted to float so the negative extension values are not
        truncated.
    """
    # Deep copy: list.copy() on a list of arrays is shallow and let the loop below
    # mutate the caller's rows.
    modified_src_spec_list = np.array(src_spec_list, copy=True)
    if not np.issubdtype(modified_src_spec_list.dtype, np.floating):
        modified_src_spec_list = modified_src_spec_list.astype(float)
    kV_index = [first_nonzero_from_right(spec) for spec in modified_src_spec_list]
    for sid, m_src_spec in enumerate(modified_src_spec_list[:-1]):
        v0 = kV_index[sid]
        v1 = kV_index[sid + 1]
        f1 = modified_src_spec_list[sid + 1]
        # v < v1 throughout this range, so r < 1 and the division is safe.
        for v in range(v0, v1):
            r = (v - float(v0)) / (v1 - float(v0))
            m_src_spec[v] = -r / (1 - r) * f1[v]
    return modified_src_spec_list


def philibert_absorption_correction_factor(voltage, sin_psi, energies):
    """Philibert self-absorption correction factor for a tungsten (Z = 74) target.

    Computes ``1 / (1 + chi) * 1 / (1 + h * chi)`` with ``chi = mu / kappa / sin_psi``, where:

    - ``kappa = 4.0e5 / (voltage**1.65 - E**1.65)`` for every energy bin except the last,
      whose ``kappa`` is set to infinity (factor 1).
    - ``h = h_local / (1 + h_local)`` with ``h_local = 1.2 * A / Z**2``.
    - ``mu`` is ``get_mass_absp_c_vs_E`` evaluated for tungsten at ``energies``.

    NOTE(review): presumably ``voltage`` and ``energies`` are in kV / keV — confirm.
    """
    Z = 74  # Tungsten
    target_material = ptableinverse[Z]
    PhilibertConstant = 4.0e5
    PhilibertExponent = 1.65
    h_local = 1.2 * atom_weights[target_material] / (Z ** 2)
    h_factor = h_local / (1.0 + h_local)
    kVp_e165 = voltage ** PhilibertExponent
    kappa = torch.zeros((energies.shape))
    if not isinstance(energies, torch.Tensor):
        energies = torch.tensor(energies)
    kappa[:-1] = (PhilibertConstant / (kVp_e165 - energies ** PhilibertExponent)[:-1])
    kappa[-1] = np.inf
    mu = torch.tensor(get_mass_absp_c_vs_E(ptableinverse[Z], energies))  # cm^-1
    return (1 + mu / kappa / sin_psi) ** -1 * (1 + h_factor * mu / kappa / sin_psi) ** -1


def takeoff_angle_conversion_factor(voltage, sin_psi_cur, sin_psi_new, energies):
    """Ratio of the Philibert factors at the new and the current take-off-angle sines."""
    if not isinstance(sin_psi_cur, torch.Tensor):
        sin_psi_cur = torch.tensor(sin_psi_cur)
    if not isinstance(sin_psi_new, torch.Tensor):
        sin_psi_new = torch.tensor(sin_psi_new)
    return (philibert_absorption_correction_factor(voltage, sin_psi_new, energies)
            / philibert_absorption_correction_factor(voltage, sin_psi_cur, energies))


def angle_sin(psi, torch_mode=False):
    """Sine of an angle given in degrees, using torch (``torch_mode=True``) or numpy."""
    if torch_mode:
        return torch.sin(psi * torch.pi / 180.0)
    return np.sin(psi * np.pi / 180.0)


class Reflection_Source_Analytical(Base_Spec_Model):
    def __init__(self, voltage, takeoff_angle, single_takeoff_angle=True):
        """
        A template source model designed specifically for reflection sources, including all
        necessary methods.

        Args:
            voltage (tuple): (initial value, lower bound, upper bound) for the source voltage.
                The initial value must not be None. It will not be optimized when lower == upper.
            takeoff_angle (tuple): (initial value, lower bound, upper bound) for the takeoff
                angle, in degrees. The initial value must not be None. It will not be optimized
                when lower == upper.
            single_takeoff_angle (bool, optional): Determines whether the takeoff angle is same
                for all instances. If set to True (default), the same takeoff angle is applied to
                all instances of Reflection_Source_Analytical. If set to False, each instance may
                have a distinct takeoff angle, with different prefix.
        """
        params_list = [{'voltage': voltage, 'takeoff_angle': takeoff_angle}]
        super().__init__(params_list)
        self.single_takeoff_angle = single_takeoff_angle
        if self.single_takeoff_angle:
            # Re-key the takeoff angle by class name so every instance uses the same key.
            for params in self._params_list:
                params[f"{self.__class__.__name__}_takeoff_angle"] = params.pop(f"{self.prefix}_takeoff_angle")
            self._init_estimates()

    def set_src_spec_list(self, src_spec_list, src_voltage_list, ref_takeoff_angle):
        """Set source spectra for interpolation, which will be used only by forward function.

        Args:
            src_spec_list (numpy.ndarray): This array contains the reference X-ray source spectra.
                Each spectrum corresponds to the ref_takeoff_angle and one of the source voltages
                from src_voltage_list.
            src_voltage_list (numpy.ndarray): This is a sorted array containing the source
                voltages, each corresponding to a specific reference X-ray source spectrum.
            ref_takeoff_angle (float): The anode take-off angle, in degrees, used to generate
                the reference X-ray spectra.
        """
        self.src_spec_list = np.array(src_spec_list)
        self.src_voltage_list = np.array(src_voltage_list)
        modified_src_spec_list = prepare_for_interpolation(self.src_spec_list)
        self.src_spec_interp_func_over_v = Interp1D(
            torch.tensor(self.src_voltage_list, dtype=torch.float32),
            torch.tensor(modified_src_spec_list, dtype=torch.float32))
        self.ref_takeoff_angle = ref_takeoff_angle

    def forward(self, energies):
        """
        Takes X-ray energies and returns the source spectrum.

        The spectrum is interpolated over voltage, then rescaled from the reference
        take-off angle to the current one with ``takeoff_angle_conversion_factor``.

        Args:
            energies (torch.Tensor): A tensor containing the X-ray energies of a poly-energetic
                source in units of keV.

        Returns:
            torch.Tensor: The source response.
        """
        voltage = self.get_params()[f"{self.prefix}_voltage"]
        src_spec = self.src_spec_interp_func_over_v(voltage)
        if self.single_takeoff_angle:
            takeoff_angle = self.get_params()[f"{self.__class__.__name__}_takeoff_angle"]
        else:
            takeoff_angle = self.get_params()[f"{self.prefix}_takeoff_angle"]
        sin_psi_cur = angle_sin(self.ref_takeoff_angle, torch_mode=False)
        sin_psi_new = angle_sin(takeoff_angle, torch_mode=True)
        src_spec = src_spec * takeoff_angle_conversion_factor(voltage, sin_psi_cur, sin_psi_new, energies)
        return src_spec
class Reflection_Source(Base_Spec_Model):
    def __init__(self, voltage, takeoff_angle, single_takeoff_angle=True):
        """
        A template source model designed specifically for reflection sources, including all
        necessary methods.

        Args:
            voltage (tuple): (initial value, lower bound, upper bound) for the source voltage.
                The initial value must not be None. It will not be optimized when lower == upper.
            takeoff_angle (tuple): (initial value, lower bound, upper bound) for the takeoff
                angle, in degrees. The initial value must not be None. It will not be optimized
                when lower == upper.
            single_takeoff_angle (bool, optional): Determines whether the takeoff angle is same
                for all instances. If set to True (default), the same takeoff angle is applied to
                all instances of Reflection_Source. If set to False, each instance may have a
                distinct takeoff angle, with different prefix.
        """
        params_list = [{'voltage': voltage, 'takeoff_angle': takeoff_angle}]
        super().__init__(params_list)
        self.single_takeoff_angle = single_takeoff_angle
        if self.single_takeoff_angle:
            # Re-key the takeoff angle by class name so every instance uses the same key.
            # NOTE(review): the attribute f"{prefix}_takeoff_angle" registered by the first
            # _init_estimates() call (inside super().__init__) remains on the module.
            for params in self._params_list:
                params[f"{self.__class__.__name__}_takeoff_angle"] = params.pop(f"{self.prefix}_takeoff_angle")
            self._init_estimates()

    def set_src_spec_list(self, energies, src_spec_list, voltages, takeoff_angles):
        """Set source spectra for interpolation, which will be used only by forward function.

        Args:
            energies (numpy.ndarray): Energies vector for lookup table.
            src_spec_list (numpy.ndarray): NVoltages * NAngles * NEnergies. Each spectrum
                corresponds to one of the source voltages from ``voltages`` and one of the
                takeoff angles from ``takeoff_angles``.
            voltages (numpy.ndarray): This is a sorted array containing the source voltages.
            takeoff_angles (numpy.ndarray): Sorted list of the anode take-off angles, in degrees.
        """
        self.energies = torch.tensor(energies, dtype=torch.float32)
        self.src_spec_list = np.array(src_spec_list)
        self.voltages = np.array(voltages)
        self.takeoff_angles = np.array(takeoff_angles)

        # Work on a copy of the array form: a plain list argument does not support
        # [:, tti] slicing, and the caller's data must not be modified.
        modified_src_spec_list = self.src_spec_list.copy()
        for tti in range(len(self.takeoff_angles)):
            modified_src_spec_list[:, tti] = prepare_for_interpolation(modified_src_spec_list[:, tti])

        # Generate 2D grids for x (voltage, dim 0) and y (angle, dim 1) coordinates
        V, T = torch.meshgrid(torch.tensor(self.voltages, dtype=torch.float32),
                              torch.tensor(self.takeoff_angles, dtype=torch.float32), indexing='ij')
        Z = torch.tensor(modified_src_spec_list, dtype=torch.float32)
        n_voltages, n_angles = V.shape
        if n_voltages == 1 and n_angles > 1:
            # Only one voltage -> 1D interpolation along angle
            self.src_spec_interp_func = Interp1D(T[0], Z[0])
        elif n_angles == 1 and n_voltages > 1:
            # Only one takeoff angle -> 1D interpolation along voltage
            self.src_spec_interp_func = Interp1D(V[:, 0].contiguous(), Z[:, 0])
        elif n_angles > 1 and n_voltages > 1:
            self.src_spec_interp_func = Interp2D(V, T, Z)
        else:
            # A single spectrum: nothing to interpolate over.
            self.src_spec_interp_func = None

    def forward(self, energies):
        """
        Takes X-ray energies and returns the source spectrum.

        Args:
            energies (torch.Tensor): A tensor containing the X-ray energies of a poly-energetic
                source in units of keV.

        Returns:
            torch.Tensor: The source response.
        """
        params = self.get_params()
        voltage = params[f"{self.prefix}_voltage"]
        if self.single_takeoff_angle:
            takeoff_angle = params[f"{self.__class__.__name__}_takeoff_angle"]
        else:
            takeoff_angle = params[f"{self.prefix}_takeoff_angle"]

        if isinstance(self.src_spec_interp_func, Interp1D):
            if len(self.voltages) == 1:
                # Only one voltage, interpolate over angle
                src_spec = self.src_spec_interp_func(takeoff_angle)
            else:
                # Only one takeoff angle, interpolate over voltage
                src_spec = self.src_spec_interp_func(voltage)
        elif isinstance(self.src_spec_interp_func, Interp2D):
            src_spec = self.src_spec_interp_func(voltage, takeoff_angle)
        else:
            # Fixed spectrum. It must be a tensor: Interp1D below operates on torch
            # tensors and failed on the raw numpy row.
            src_spec = torch.tensor(self.src_spec_list[0, 0], dtype=torch.float32)

        energies = torch.tensor(energies, dtype=torch.float32) if not isinstance(energies, torch.Tensor) else energies
        src_interp_E_func = Interp1D(self.energies, src_spec)
        return src_interp_E_func(energies)
class Transmission_Source(Base_Spec_Model):
    def __init__(self, voltage, target_thickness, single_target_thickness=True):
        """
        A template source model for transmission sources, driven by a lookup table of
        reference spectra over source voltage and target thickness.

        Args:
            voltage (tuple): (initial value, lower bound, upper bound) for the source voltage.
                The initial value must not be None. It will not be optimized when lower == upper.
            target_thickness (tuple): (initial value, lower bound, upper bound) for the target
                thickness. The initial value must not be None. It will not be optimized when
                lower == upper.
            single_target_thickness (bool, optional): If True (default), the target-thickness
                parameter is keyed by the class name, so all Transmission_Source instances use
                the same key. If False, it is keyed by the per-instance prefix.
        """
        params_list = [{'voltage': voltage, 'target_thickness': target_thickness}]
        super().__init__(params_list)
        self.single_target_thickness = single_target_thickness
        if self.single_target_thickness:
            for params in self._params_list:
                params[f"{self.__class__.__name__}_target_thickness"] = params.pop(f"{self.prefix}_target_thickness")
            self._init_estimates()

    def set_src_spec_list(self, energies, src_spec_list, voltages, target_thicknesses):
        """Set source spectra for interpolation, which will be used only by forward function.

        Args:
            energies (numpy.ndarray): Energies vector for lookup table.
            src_spec_list (numpy.ndarray): NVoltages * NThicknesses * NEnergies reference
                spectra; entry ``[i, j]`` corresponds to ``voltages[i]`` and
                ``target_thicknesses[j]``.
            voltages (numpy.ndarray): Sorted array of source voltages.
            target_thicknesses (numpy.ndarray): Sorted array of target thicknesses.
        """
        self.energies = torch.tensor(energies, dtype=torch.float32)
        self.src_spec_list = np.array(src_spec_list)
        self.voltages = np.array(voltages)
        self.target_thicknesses = np.array(target_thicknesses)

        # Work on a copy of the array form (list inputs cannot be [:, j]-sliced).
        modified_src_spec_list = self.src_spec_list.copy()
        for tti in range(len(self.target_thicknesses)):
            modified_src_spec_list[:, tti] = prepare_for_interpolation(modified_src_spec_list[:, tti])

        # Generate 2D grids for x (voltage, dim 0) and y (thickness, dim 1) coordinates
        V, T = torch.meshgrid(torch.tensor(self.voltages, dtype=torch.float32),
                              torch.tensor(self.target_thicknesses, dtype=torch.float32), indexing='ij')
        Z = torch.tensor(modified_src_spec_list, dtype=torch.float32)
        # Bilinear interpolation needs at least two samples on each axis (otherwise
        # it divides by a zero cell size); fall back as Reflection_Source does.
        n_voltages, n_thicknesses = V.shape
        if n_voltages > 1 and n_thicknesses > 1:
            self.src_spec_interp_func = Interp2D(V, T, Z)
        elif n_voltages > 1:
            self.src_spec_interp_func = Interp1D(V[:, 0].contiguous(), Z[:, 0])
        elif n_thicknesses > 1:
            self.src_spec_interp_func = Interp1D(T[0], Z[0])
        else:
            self.src_spec_interp_func = None

    def forward(self, energies):
        """
        Takes X-ray energies and returns the source spectrum.

        Args:
            energies (torch.Tensor): A tensor containing the X-ray energies of a poly-energetic
                source in units of keV.

        Returns:
            torch.Tensor: The source response.
        """
        params = self.get_params()
        voltage = params[f"{self.prefix}_voltage"]
        if self.single_target_thickness:
            target_thickness = params[f"{self.__class__.__name__}_target_thickness"]
        else:
            target_thickness = params[f"{self.prefix}_target_thickness"]

        if isinstance(self.src_spec_interp_func, Interp2D):
            src_spec = self.src_spec_interp_func(voltage, target_thickness)
        elif isinstance(self.src_spec_interp_func, Interp1D):
            if len(self.voltages) == 1:
                src_spec = self.src_spec_interp_func(target_thickness)
            else:
                src_spec = self.src_spec_interp_func(voltage)
        else:
            # Single voltage and single thickness: one fixed spectrum.
            src_spec = torch.tensor(self.src_spec_list[0, 0], dtype=torch.float32)

        energies = torch.tensor(energies, dtype=torch.float32) if not isinstance(energies, torch.Tensor) else energies
        src_interp_E_func = Interp1D(self.energies, src_spec)
        return src_interp_E_func(energies)
class Filter(Base_Spec_Model):
    def __init__(self, materials, thickness):
        """
        A template filter model based on Beer's Law and NIST mass attenuation coefficients.

        Args:
            materials (list): Candidate filter materials; each should be an instance carrying
                a formula and a density.
            thickness (tuple or list): Either one (initial value, lower bound, upper bound)
                tuple shared by every material, or a list with one such entry per material
                (same length as ``materials``). A tuple of all None values is rejected. The
                thickness is not optimized when lower == upper.

        Raises:
            ValueError: For an all-None tuple or a length mismatch.
            TypeError: If ``thickness`` is neither a tuple nor a list.
        """
        if isinstance(thickness, tuple):
            if all(t is None for t in thickness):
                raise ValueError("Thickness tuple cannot have all None values.")
            pairs = ((mat, thickness) for mat in materials)
        elif isinstance(thickness, list):
            if len(thickness) != len(materials):
                raise ValueError("Length of thickness list must match length of materials list.")
            pairs = zip(materials, thickness)
        else:
            raise TypeError("Thickness must be either a tuple or a list.")
        super().__init__([{'material': mat, 'thickness': th} for mat, th in pairs])

    def forward(self, energies):
        """
        Compute the filter response at the given X-ray energies.

        Args:
            energies (torch.Tensor): X-ray energies of a poly-energetic source, in keV.
                Non-tensor input is converted to a float32 tensor.

        Returns:
            torch.Tensor: ``gen_fltr_res`` evaluated for the current material and thickness.
        """
        current = self.get_params()
        mat = current[f"{self.prefix}_material"]
        th = current[f"{self.prefix}_thickness"]
        if not isinstance(energies, torch.Tensor):
            energies = torch.tensor(energies, dtype=torch.float32)
        return gen_fltr_res(energies, mat, th)
class Scintillator(Base_Spec_Model):
    def __init__(self, thickness, materials, device=None, dtype=None):
        """
        A template scintillator model based on Beer's Law, NIST mass attenuation
        coefficients, and mass energy-absorption coefficients.

        Args:
            thickness (tuple): (initial value, lower bound, upper bound) for the scintillator
                thickness, shared by every candidate material. The initial value must not be
                None. It will not be optimized when lower == upper.
            materials (list): Candidate scintillator materials; each should be an instance
                carrying a formula and a density.
            device: Accepted but currently unused.
            dtype: Accepted but currently unused.
        """
        combos = []
        for mat in materials:
            combos.append({'material': mat, 'thickness': thickness})
        super().__init__(combos)

    def forward(self, energies):
        """
        Compute the scintillator conversion function at the given X-ray energies.

        Args:
            energies (torch.Tensor): X-ray energies of a poly-energetic source, in keV.
                Non-tensor input is converted to a float32 tensor.

        Returns:
            torch.Tensor: ``gen_scint_cvt_func`` evaluated for the current material and thickness.
        """
        current = self.get_params()
        mat = current[f"{self.prefix}_material"]
        th = current[f"{self.prefix}_thickness"]
        if not isinstance(energies, torch.Tensor):
            energies = torch.tensor(energies, dtype=torch.float32)
        return gen_scint_cvt_func(energies, mat, th)
class Scintillator_MCNP(Base_Spec_Model):
    def __init__(self, thickness):
        """
        Scintillator model that interpolates a lookup table over scintillator thickness.

        Args:
            thickness (tuple): (initial value, lower bound, upper bound) for the scintillator
                thickness. The initial value must not be None. If the lower bound equals the
                upper bound, the thickness is not optimized.
        """
        super().__init__([{'thickness': thickness}])

    def set_scint_spec_list(self, scint_spec_list, thicknesses):
        """
        Set the lookup table used by ``forward``.

        Args:
            scint_spec_list (numpy.ndarray): 2-D array of reference scintillator spectra; row
                ``i`` corresponds to ``thicknesses[i]``.
            thicknesses (numpy.ndarray): Sorted 1-D array of scintillator thicknesses.

        The table is stored and interpolated as ``-log(1 - spectrum)``; ``forward`` maps the
        interpolated value back with ``1 - exp(-value)``.
        """
        self.scint_spec_list = np.array(scint_spec_list)
        self.thicknesses = np.array(thicknesses)
        self.log_scint_spec_list = np.array([-np.log(1 - row) for row in scint_spec_list])
        thickness_axis = torch.tensor(self.thicknesses, dtype=torch.float32)
        log_table = torch.tensor(self.log_scint_spec_list, dtype=torch.float32)
        self.scint_spec_interp_func_over_th = Interp1D(thickness_axis, log_table)

    def forward(self, energies):
        """
        Compute the scintillator response for the given X-ray energies.

        Args:
            energies (torch.Tensor): X-ray energies of a poly-energetic response, in keV.
                Non-tensor input is converted to a float32 tensor.

        Returns:
            torch.Tensor: ``(1 - exp(-interpolated log table)) * energies``, the interpolated
            response weighted by energy (energy-integrating detector).
        """
        if not isinstance(energies, torch.Tensor):
            energies = torch.tensor(energies, dtype=torch.float32)
        th = self.get_params()[f"{self.prefix}_thickness"]
        log_response = self.scint_spec_interp_func_over_th(th)
        return (1 - torch.exp(-log_response)) * energies