Source code for xcal.estimate

import sys
import atexit
import warnings
import numpy as np
import copy

import torch
import torch.optim as optim
from torch.multiprocessing import Pool
import torch.multiprocessing as mp

# Process-wide torch.multiprocessing configuration, applied as a side effect of
# importing this module.
# Share tensors between worker processes using torch's 'file_system' strategy.
mp.set_sharing_strategy('file_system')
# Use the 'spawn' start method; force=True replaces any start method set earlier.
mp.set_start_method("spawn", force=True)

import logging
from xcal.opt._pytorch_lbfgs.functions.LBFGS import FullBatchLBFGS as NNAT_LBFGS
from xcal._utils import *
from xcal.models import get_merged_params_list, get_concatenated_params_list, denormalize_parameter_as_tuple, clamp_with_grad
def weighted_mse_loss(input, target, weight):
    """Return half of the weighted mean squared error.

    Computes ``0.5 * mean(weight * (input - target) ** 2)``; the three tensors
    broadcast against each other under the usual torch rules.
    """
    squared_residual = (input - target).pow(2)
    return 0.5 * (weight * squared_residual).mean()


def calc_forward_matrix(homogenous_vol_masks, lac_vs_energies, forward_projector, slices=None):
    """Calculate the forward matrix for a combination of multiple solid objects.

    Each homogeneous object mask is forward-projected once to obtain its line path
    length. The path length is scaled by the object's linear attenuation coefficient
    (LAC) at every energy, the results are summed over objects, and the total is
    exponentiated: ``exp(-sum_m L_m * mu_m(E))``.

    Args:
        homogenous_vol_masks (list of numpy.ndarray): Each 3D array is the mask of
            one homogeneous, pure object.
        lac_vs_energies (list of numpy.ndarray): One 1D array per mask holding that
            material's LAC sampled at the spectral energy bins. Must have the same
            length as ``homogenous_vol_masks``.
        forward_projector (object): Object with a ``forward(mask)`` method that takes
            a 3D volume mask and returns the photon line path length with
            dimensions (views, rows, columns).
        slices (tuple of slice, or slice, optional): Index applied to the output of
            ``forward_projector.forward`` to reduce memory usage. Each element
            corresponds to one of (views, rows, columns). A single slice applies to
            the views axis only. If None, the whole projection is used.

    Returns:
        numpy.ndarray: The forward matrix with dimensions
        (views, rows, columns, energies), restricted by ``slices`` if given.

    Raises:
        ValueError: If no masks are given, or if the number of masks and the number
            of LAC curves differ (``zip`` would otherwise silently drop the extras).
    """
    if len(homogenous_vol_masks) != len(lac_vs_energies):
        raise ValueError(
            f"Got {len(homogenous_vol_masks)} masks but {len(lac_vs_energies)} LAC curves; "
            "they must match one-to-one."
        )
    if len(homogenous_vol_masks) == 0:
        raise ValueError("At least one homogenous volume mask is required.")
    if slices is not None and not isinstance(slices, tuple):
        slices = (slices,)

    # Accumulate the total line attenuation integral incrementally instead of
    # stacking every per-material (E, V, R, C) array first.
    total_lai = None
    for mask, lac in zip(homogenous_vol_masks, lac_vs_energies):
        path_length = forward_projector.forward(mask)
        if slices is not None:
            path_length = path_length[slices]
        lac = np.asarray(lac)
        # (E, 1, 1, 1) * (1, V, R, C) -> (E, V, R, C)
        contribution = path_length[np.newaxis] * lac[:, np.newaxis, np.newaxis, np.newaxis]
        total_lai = contribution if total_lai is None else total_lai + contribution

    # Move the energy axis last: (V, R, C, E).
    forward_matrix = np.exp(-total_lai.transpose((1, 2, 3, 0)))
    return forward_matrix
def fit_cell(energies, nrads, forward_matrices, spec_models, params, weights=None, learning_rate=0.02,
             max_iterations=5000, stop_threshold=1e-3, optimizer_type='NNAT_LBFGS', loss_type='transmission'):
    """Fit the continuous parameters for one combination of discrete parameters.

    Arguments are the same as param_based_spec_estimate. ``spec_models`` and
    ``params`` are deep-copied, so the caller's objects are not modified.

    Args:
        energies (torch.Tensor): Energy grid at which every spectral component is
            evaluated and over which spectra are integrated (``torch.trapz``).
        nrads (list of torch.Tensor): Normalized radiographs, one per dataset. They
            are compared against the predicted transmission reshaped to (-1, 1).
        forward_matrices (list of torch.Tensor): One forward matrix per dataset
            whose last axis is the energy axis.
        spec_models (list of list): For each dataset, the component models whose
            outputs are multiplied together to form the spectrum.
        params (dict): Parameters handed to each model's ``set_params``. Tuple values
            are denormalized with ``denormalize_parameter_as_tuple`` when logged.
        weights (list of torch.Tensor, optional): Per-dataset weights for the
            'transmission' loss. Defaults to ``1 / nrad``, the same default as
            ``Estimate.add_data``.
        learning_rate (float): Optimizer learning rate.
        max_iterations (int): Maximum number of iterations.
        stop_threshold (float): Stop once no parameter (clamped to [0, 1]) moved by
            more than this norm in one iteration.
        optimizer_type (str): 'Adam' or 'NNAT_LBFGS'.
        loss_type (str): 'transmission' (weighted MSE of transmission), 'attmse'
            (MSE of -log transmission) or 'least_square' (unweighted MSE of
            transmission).

    Returns:
        tuple: ``(stopped_iteration, final_cost, params)``.

    Raises:
        ValueError: If ``optimizer_type`` or ``loss_type`` is not supported. This
            replaces ``sys.exit``, which would kill a Pool worker and leave the
            parent waiting on ``AsyncResult.get()`` forever.
    """
    logger = logging.getLogger(str(mp.current_process().pid))

    def print(*args, **kwargs):
        # Route progress messages to this process's logger (configured by init_logging).
        message = ' '.join(map(str, args))
        logger.info(message)

    def print_params(params):
        for key, value in sorted(params.items()):
            if isinstance(value, tuple):
                dv = denormalize_parameter_as_tuple(value)
                dd = torch.clamp(dv[0], dv[1], dv[2])
                print(f"{key}: {dd.numpy()}")
            else:
                print(f"{key}: {value}")
        print()

    # Fail fast on bad options instead of erroring deep inside the optimization loop.
    if loss_type not in ('transmission', 'attmse', 'least_square'):
        raise ValueError(
            f"loss_type should be 'transmission', 'attmse' or 'least_square'. Given {loss_type!r}."
        )
    if optimizer_type not in ('Adam', 'NNAT_LBFGS'):
        raise ValueError(f"The optimizer type {optimizer_type} is not supported.")

    spec_models = [[copy.deepcopy(cm) for cm in component_models] for component_models in spec_models]
    params = copy.deepcopy(params)
    if weights is None:
        weights = [1.0 / yy for yy in nrads]

    parameters = []
    for component_models in spec_models:
        for cm in component_models:
            cm.set_params(params)
            parameters += list(cm.parameters())
    # Drop shared tensors while keeping a deterministic order (set() does not).
    parameters = list(dict.fromkeys(parameters))

    loss = torch.nn.MSELoss()
    use_lbfgs = optimizer_type == 'NNAT_LBFGS'
    if use_lbfgs:
        iter_prt = 5
        optimizer = NNAT_LBFGS(parameters, lr=learning_rate)
    else:
        iter_prt = 50
        optimizer = optim.Adam(parameters, lr=learning_rate)

    def closure():
        if torch.is_grad_enabled():
            optimizer.zero_grad()
        cost = 0
        for yy, FF, ww, component_models in zip(nrads, forward_matrices, weights, spec_models):
            spec = component_models[0](energies)
            for cm in component_models[1:]:
                spec = spec * cm(energies)
            # Out-of-place normalization: dividing a model output in place can
            # invalidate a tensor autograd saved for the backward pass.
            spec = spec / torch.trapz(spec, energies)
            trans_value = torch.trapz(FF * spec, energies, dim=-1).reshape((-1, 1))
            if loss_type == 'transmission':
                sub_cost = weighted_mse_loss(trans_value, yy, ww)
            elif loss_type == 'attmse':
                sub_cost = 0.5 * loss(-torch.log(trans_value), -torch.log(yy))
            else:  # 'least_square'
                sub_cost = 0.5 * loss(trans_value, yy)
            cost += sub_cost
        # For NNAT_LBFGS the gradient is computed by the caller instead.
        if cost.requires_grad and not use_lbfgs:
            cost.backward()
        return cost

    cost = np.inf
    iteration = 0  # returned unchanged when max_iterations < 1
    print('Start Estimation.')
    for iteration in range(1, max_iterations + 1):
        if iteration % iter_prt == 0:
            print('Iteration:', iteration)

        cost = closure()
        if torch.isnan(cost):
            print('Meet NaN!!')
            for component_models in spec_models:
                for cm in component_models:
                    print(cm.get_params())
            return iteration, closure().item(), params

        if use_lbfgs:
            # closure() skips backward for NNAT_LBFGS, so populate the gradient here.
            cost.backward()

        for component_models in spec_models:
            for cm in component_models:
                if check_gradients_for_nan(cm):
                    return iteration, closure().item(), params

        if iteration == 1:
            with torch.no_grad():
                print('Initial cost: %e' % (closure().item()))

        # Snapshot the parameters before the update for the stopping test.
        old_params = [parameter.data.clone() for parameter in parameters]

        if use_lbfgs:
            options = {'closure': closure, 'current_loss': cost, 'max_ls': 100, 'damping': False}
            cost = optimizer.step(options=options)[0]
        else:
            optimizer.step()

        with torch.no_grad():
            if iteration % iter_prt == 0:
                print('Cost:', cost.item())
                print_params(params)

            # Stop once no parameter moved by more than stop_threshold.
            small_update = not any(
                torch.norm(parameter.data.clamp(0, 1) - old_param.clamp(0, 1)) > stop_threshold
                for parameter, old_param in zip(parameters, old_params)
            )
            if small_update:
                print(f"Stopping at epoch {iteration} because updates are too small.")
                print('Cost:', cost.item())
                print_params(params)
                break

    # float() handles both a tensor cost and the initial np.inf.
    return iteration, float(cost), params


def init_logging(filename, num_processes):
    """Configure the logger used by ``fit_cell`` in the current process.

    Used as the Pool initializer in ``Estimate.fit``. The logger is named after
    the process pid, which matches the name ``fit_cell`` looks up.

    Args:
        filename (str or None): Log-file prefix. If None, messages go to a
            ``logging.StreamHandler``. Otherwise they go to
            ``f"{filename}_{pid % num_processes}.log"``.
        num_processes (int): Number of pool processes, used to derive the log-file suffix.
    """
    worker_id = mp.current_process().pid
    logger = logging.getLogger(str(worker_id))
    logger.setLevel(logging.INFO)
    if filename is None:
        handler = logging.StreamHandler()
    else:
        # NOTE(review): pid % num_processes is not guaranteed to be unique across
        # workers, so two workers may append to the same file. TODO confirm whether
        # a per-pid or per-worker-index name is acceptable to users.
        handler = logging.FileHandler(f"{filename}_{worker_id % num_processes}.log")
    formatter = logging.Formatter('%(asctime)s - %(message)s')
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    # Register a cleanup function to close the logger when the process exits.
    atexit.register(close_logging, logger)


def close_logging(logger):
    """Close and detach every handler currently attached to ``logger``."""
    # Iterate over a copy because removeHandler mutates logger.handlers.
    for handler in logger.handlers[:]:
        handler.close()
        logger.removeHandler(handler)
class Estimate():
    def __init__(self, energies):
        """Structured parameter estimation over one or more datasets.

        Input arguments are split into a data domain (``add_data``) and an
        optimization domain (``fit``) to avoid repeating inputs. Discrete and
        continuous parameters are estimated within one framework: every
        combination of discrete parameters is fitted independently and the
        lowest-cost result is kept.

        Args:
            energies (numpy.ndarray): X-ray energies of a poly-energetic source in
                units of keV.
        """
        self.energies = torch.tensor(energies, dtype=torch.float32)
        self.nrads = []
        self.forward_matrices = []
        self.spec_models = []
        self.weights = []

    def add_data(self, nrad, forward_matrix, component_models, weight=None):
        """Add one dataset. Datasets may be scanned with different X-ray system settings.

        Args:
            nrad (numpy.ndarray): Normalized radiograph with dimensions
                [N_views, N_rows, N_cols]. Stored flattened to shape (-1, 1).
            forward_matrix (numpy.ndarray): Forward matrix corresponding to ``nrad``
                with dimensions [N_views, N_rows, N_cols, N_energy_bins]. See
                ``calc_forward_matrix`` to compute one from 3D masks of
                homogeneous objects.
            component_models (list): Spectral component models (e.g. instances of
                Base_Spec_Model) whose product forms the spectrum for this dataset.
            weight (numpy.ndarray, optional): Weight for each pixel of ``nrad``.
                Defaults to ``1 / nrad``.
        """
        self.nrads.append(torch.tensor(nrad.reshape((-1, 1)), dtype=torch.float32))
        self.num_sp_datasets = len(self.nrads)
        self.forward_matrices.append(torch.tensor(forward_matrix, dtype=torch.float32))
        self.spec_models.append(component_models)
        if weight is None:
            weight = 1.0 / self.nrads[-1]
        else:
            weight = torch.tensor(weight.reshape((-1, 1)), dtype=torch.float32)
        self.weights.append(weight)

    @staticmethod
    def _best_result_index(cost_list):
        """Return the index of the lowest cost, ignoring NaN costs.

        ``np.argmin`` returns the position of the first NaN if any is present,
        which would select a diverged fit as the best one. If every cost is NaN,
        index 0 is returned, which is what ``np.argmin`` produced.

        Raises:
            ValueError: If ``cost_list`` is empty.
        """
        costs = np.asarray(cost_list, dtype=float)
        if costs.size == 0:
            raise ValueError("No estimation results to select from.")
        if np.all(np.isnan(costs)):
            return 0
        return int(np.nanargmin(costs))

    def fit(self, learning_rate=0.001, max_iterations=5000, stop_threshold=1e-4, optimizer_type='Adam',
            loss_type='transmission', logpath=None, num_processes=1):
        """Estimate both discrete and continuous parameters.

        Each combination of discrete parameters is fitted with ``fit_cell`` in a
        process pool. The result with the lowest non-NaN cost is stored in
        ``self.params`` and applied to every component model.

        Args:
            learning_rate (float, optional): [Default=0.001] Learning rate for the
                optimization process.
            max_iterations (int, optional): [Default=5000] Maximum number of iterations.
            stop_threshold (float, optional): [Default=1e-4] Optimization stops once
                no parameter (clamped to [0, 1]) changes by more than this norm in
                one iteration. With 0.0 it runs up to ``max_iterations`` and stops
                early only if no parameter changes at all.
            optimizer_type (str, optional): [Default='Adam'] 'Adam' or 'NNAT_LBFGS'.
                Use 'Adam' without an accurate initial guess. Otherwise 'NNAT_LBFGS'
                can converge faster.
            loss_type (str, optional): [Default='transmission'] 'transmission'
                (weighted MSE of transmission), 'attmse' (MSE of -log transmission)
                or 'least_square' (unweighted MSE of transmission).
            logpath (str, optional): [Default=None] Log-file prefix for the workers.
                If None, workers log to a stream handler.
            num_processes (int, optional): [Default=1] Number of processes to use
                for parallel computation.
        """
        # One parameter list per dataset, then all discrete combinations across datasets.
        concatenate_params_list = [
            get_concatenated_params_list([cm._params_list for cm in concatenate_models])
            for concatenate_models in self.spec_models
        ]
        merged_params_list = get_merged_params_list(concatenate_params_list)

        # Use a multiprocessing pool to optimize each discrete combination in parallel.
        with Pool(processes=num_processes, initializer=init_logging, initargs=(logpath, num_processes)) as pool:
            result_objects = [
                pool.apply_async(
                    fit_cell,
                    args=(self.energies, self.nrads, self.forward_matrices, self.spec_models, params,
                          self.weights, learning_rate, max_iterations, stop_threshold, optimizer_type,
                          loss_type),
                )
                for params in merged_params_list
            ]
            print('Number of cases for different discrete parameters:', len(result_objects))
            # Gather results while the pool is still alive.
            results = [r.get() for r in result_objects]

        cost_list = [res[1] for res in results]
        optimal_cost_ind = self._best_result_index(cost_list)
        best_params = results[optimal_cost_ind][2]
        self.params = best_params
        self.results = results
        for component_models in self.spec_models:
            for cm in component_models:
                cm.set_params(best_params)

    def get_spec_models(self):
        """Obtain the spectral models (optimized after ``fit``).

        Returns:
            list: A list of component lists. Each component list contains all
            components used to scan the corresponding radiograph.
        """
        return self.spec_models

    def get_spectra(self):
        """Obtain the system response for each added ``nrad``.

        Each response is the product of its component models evaluated at
        ``self.energies``. It is not normalized.

        Returns:
            list: A list of system responses (torch tensors).
        """
        spec_list = []
        for sms in self.spec_models:
            est_sp = torch.ones(self.energies.shape)
            for sm in sms:
                est_sp *= sm(self.energies)
            spec_list.append(est_sp)
        return spec_list

    def get_params(self):
        """Read the estimated parameters as a dictionary (available after ``fit``).

        Tuple-valued (continuous) parameters are denormalized and clamped. Other
        values are returned unchanged.

        Returns:
            dict: Estimated parameters.
        """
        display_estimates = {}
        for key, value in self.params.items():
            if isinstance(value, tuple):
                dv = denormalize_parameter_as_tuple(value)
                display_estimates[key] = clamp_with_grad(dv[0], dv[1], dv[2])
            else:
                display_estimates[key] = value
        return display_estimates

    def get_all_estimates(self):
        """Return the result of every discrete-parameter combination fitted by ``fit``.

        Each tuple has three elements:
            1. Stopped iteration: the iteration at which the optimization stopped.
            2. Cost value: the final cost for that combination.
            3. A dictionary of estimated parameters: keys are parameter names and
               values are the corresponding discrete or continuous values.

        Returns:
            List[Tuple[int, float, Dict[str, Union[int, float]]]]: One tuple per
            combination, in the order the combinations were submitted.
        """
        return self.results
def least_squares_estimation(energies, A_np, y_np, x_init_np, weights_np=None, num_iterations=1000,
                             learning_rate=1e-3, smoothness_lambda=0.01, non_neg_lambda=0.01,
                             change_lambda=0.01, change_scale=10000, change_threshold=0.001,
                             stop_threshold=1e-5):
    """Estimate ``x`` in ``y ≈ trapz(A * x, energies)`` with the Adam optimizer.

    The loss is a weighted mean squared error plus optional penalties:
    smoothness, non-negativity / sum-to-one, and deviation from the initial
    estimate. Non-negativity and sum-to-one are encouraged through a penalty only;
    the returned ``x`` is not projected onto that set. Optimization stops early
    when one update changes ``x`` by less than ``stop_threshold`` (Euclidean norm).
    The loss is printed every 100 iterations.

    Args:
        energies (np.ndarray): Integration grid along the last axis of ``A_np``, shape (n,).
        A_np (np.ndarray): Model coefficients, shape (m, n) (or (..., n); the leading
            axes are flattened into m observations).
        y_np (np.ndarray): Observations, m values in any shape (flattened to (m,)).
        x_init_np (np.ndarray): Initial estimate of ``x``, shape (n,).
        weights_np (np.ndarray, optional): Per-observation weights, m values
            (flattened). If None, equal weighting is used.
        num_iterations (int): Maximum number of iterations.
        learning_rate (float): Adam learning rate.
        smoothness_lambda (float): Weight of ``sum((x[i] - x[i+1])**2)``. Disabled if <= 0.
        non_neg_lambda (float): Weight of the non-negativity and sum-to-one penalty.
        change_lambda (float): Weight of the deviation-from-initial-estimate penalty.
        change_scale (float): Deviations ``|x - x_init|`` up to
            ``change_scale * x_init + change_threshold`` are not penalized.
        change_threshold (float): Absolute slack added to the allowed deviation.
        stop_threshold (float): Stop when ``||x_new - x_old|| < stop_threshold``.

    Returns:
        np.ndarray: The estimated parameters ``x``, shape (n,).
    """
    # Convert numpy arrays to PyTorch tensors.
    energies = torch.tensor(energies, dtype=torch.float32)
    A = torch.tensor(A_np, dtype=torch.float32)
    # Column vector to match y_pred below; an (m,) target against an (m, 1)
    # prediction would broadcast into an (m, m) residual matrix.
    y = torch.tensor(y_np, dtype=torch.float32).reshape((-1, 1))
    x_init = torch.tensor(x_init_np, dtype=torch.float32)
    x = torch.tensor(x_init_np, dtype=torch.float32, requires_grad=True)
    if weights_np is not None:
        weights = torch.tensor(weights_np, dtype=torch.float32).reshape((-1, 1))
    else:
        weights = torch.ones_like(y)
    threshold = torch.tensor(change_threshold)

    optimizer = torch.optim.Adam([x], lr=learning_rate)

    # Previous iterate for the stopping test. It starts at the initial estimate so
    # the first check measures the first update rather than ||x||.
    x_old = x.detach().clone()

    print('Start Estimation.')
    for iteration in range(num_iterations):
        optimizer.zero_grad()
        y_pred = torch.trapz(A * x, energies, dim=-1).reshape((-1, 1))
        data_loss = torch.mean(weights * (y_pred - y) ** 2)
        loss = data_loss

        if smoothness_lambda > 0:
            smoothness_loss = torch.sum((x[:-1] - x[1:]) ** 2)
            loss = loss + smoothness_lambda * smoothness_loss

        # Soft non-negativity and sum-to-one constraints.
        non_neg_loss = torch.sum(torch.relu(-x) ** 2) + (torch.sum(x) - 1) ** 2
        loss = loss + non_neg_lambda * non_neg_loss

        # Linear penalty on deviations beyond change_scale * x_init + change_threshold.
        change_penalty = torch.sum(
            torch.max(torch.abs(x - x_init) - change_scale * x_init, threshold) - threshold
        )
        loss = loss + change_lambda * change_penalty

        loss.backward()
        optimizer.step()

        with torch.no_grad():
            if torch.norm(x - x_old) < stop_threshold:
                break
            x_old = x.clone()

        if (iteration + 1) % 100 == 0:
            print(f"Iteration {iteration + 1}: Loss = {loss.item()}")
            with torch.no_grad():
                print(f"forward loss: {data_loss.detach()}; non-negative loss: {non_neg_lambda * non_neg_loss}; change penalty loss: {change_lambda * change_penalty}")

    return x.detach().numpy()