# Source code for utils.measurements2D_utils

"""
This module contains the Measurements2DMixin class, which is used to perform 2D measurements on
lesions in medical images. The class provides methods for calculating, displaying, and managing
lesion measurements, as well as handling user interactions with the GUI.
"""
import functools
from collections import defaultdict

import numpy as np
from nibabel.orientations import aff2axcodes, io_orientation

import qt
import slicer
import vtk
from utils import response_classification_utils

from slicer.util import *
import utils.ui_helper_utils as ui_helper_utils
import utils.results_table_utils as results_table_utils

from utils.rano_utils import get_instance_segmentation_by_connected_component_analysis, \
    match_instance_segmentations_by_IoU, get_max_orthogonal_line_product_coords, \
    get_ijk_to_world_matrix, transform_world_to_ijk_coord, transform_ijk_to_world_coord, \
    point_closest_to_two_lines, circle_opening_on_slices_perpendicular_to_axis, sphere_opening, find_closest_plane

from utils.config import debug


[docs] class Measurements2DMixin: """ This mixin class provides methods for performing 2D measurements on lesions in medical images. It includes methods for calculating, displaying, and managing lesion measurements, as well as handling user interactions with the GUI. """ def __init__(self): self.instance_segmentations_matched = None """List of instance segmentations (numpy arrays) with matching labels across time points.""" self.resampledSegNodes = None """List of segmentation nodes (vtkMRMLSegmentationNode) containing the matched instance segmentations.""" self.resampledVolumeNodes = None """List of instance segmentations (vtkMRMLLabelMapVolumeNode) transformed and resampled to the reference input volume space.""" self.previous_timepoint_orientation = {} """Dictionary to store the orientation of the lesions in the previous timepoint to enable consistent 2D measurement orientation across timepoints.""" self.previous_timepoint_center = {} """Dictionary to store the center of the lesions in the previous timepoint to enable consistent 2D measurement slices across timepoints.""" self.observations = [] """Store observers for the line nodes to handle user interactions.""" self.observations2 = [] """Store observers for the line nodes to handle user interactions."""
[docs] def onCalc2DButton(self): """ This method is called when the "Calculate 2D" button is pressed. It performs the following steps: 1. Get the selected segmentations and instance segmentations. 2. Match the instance segmentations across timepoints. 3. Transform and resample the instance segmentations in the original reference image space. 4. For each timepoint and lesion, place the RANO lines. 5. Evaluate the 2D measurements and store the results in a dictionary. 6. Display the results in the UI. 7. Update the line pair UI list. 8. Calculate the results table. """ if debug: print("Calc 2D button pressed") # determine the method to use for 2D measurements method2DmeasComboBox = self.ui.method2DmeasComboBox.currentText def create_lineNodePairs(lesion_stats): """ From line coordinates stored in lesion_stats, create lineNodePairs that are used to display the lines in the UI views. Args: lesion_stats: dictionary containing the line coordinates for each lesion and timepoint Returns: lineNodePairs: LineNodePairList containing the line node pairs for each lesion and timepoint """ lineNodePairs = LineNodePairList() for les_idx in lesion_stats: for tp in lesion_stats[les_idx]: coords = lesion_stats[les_idx][tp]["coords"] if len(coords) == 0: continue lineNodePair = LineNodePair(lesion_idx=les_idx, timepoint=tp) lineNodePair.set_coords(coords) lineNodePairs.append(lineNodePair) return lineNodePairs def setLinePairViews(lineNodePairs): """ Set the views for the line node pairs in the UI, i.e. it will make sure that lines of timepoint1 are shown in the timepoint1 views and lines of timepoint2 are shown in the timepoint2 views. It will also center the views of each timepoint on the first available line node pair in the list. 
Args: lineNodePairs: LineNodePairList containing the line node pairs for each lesion and timepoint """ tp_view_set = [] # keep track of which timepoint views have already been set for pair in lineNodePairs: les_idx = pair.lesion_idx tp = pair.timepoint # make sure the line node pair is displayed in the correct view self.setViews(pair, tp) # center on the first available line node pair in the list if tp not in tp_view_set: # to set views for this timepoint only once tp_view_set.append(tp) self.centerTimepointViewsOnCenterPoint(pair, tp) def get_binary_semantic_segmentations(segNodes, segmentId): """ Extract the binary semantic segmentations of a segment from the segmentation nodes as numpy arrays. Args: segNodes: list of segmentation nodes for each time point segmentId: ID of the segment to extract Returns: binary_semantic_segmentations: list of binary semantic segmentations of the segment (numpy arrays) """ binary_semantic_segmentations = [] for node in segNodes: node.CreateBinaryLabelmapRepresentation() refVol = node.GetNodeReference(slicer.vtkMRMLSegmentationNode.GetReferenceImageGeometryReferenceRole()) if not segmentId in node.GetSegmentation().GetSegmentIDs(): # if the segmentId does not exist, we want an empty binary segmentation (all zeros) bin_sem_seg = np.zeros_like(arrayFromVolume(refVol)) else: bin_sem_seg = arrayFromVolume(refVol) == int(segmentId) binary_semantic_segmentations.append(bin_sem_seg) return binary_semantic_segmentations def get_instance_segmentations(binary_segmentations): """ Convert the binary segmentations to instance segmentations. Currently, this is done by connected component analysis. Args: binary_segmentations: list of binary segmentations (numpy arrays) Returns: instance_segmentations: list of instance segmentations (numpy arrays) for each time point. Note that at the moment, the instance segmentations are not matched across time points, i.e., the labels do not correspond to the same instance across time points. 
""" instance_segmentations = [] for seg in binary_segmentations: instance_seg = get_instance_segmentation_by_connected_component_analysis(seg) instance_segmentations.append(instance_seg) return instance_segmentations def matched_instances_across_timepoints(instance_segmentations): """ Match the instance segmentations across time points. For example, as an output, if the first timepoint has instances 1, 2, 3 and the second timepoint can have instance 1, 3, 4 then instance 2 is missing in the second timepoint (disappeared) and instance 4 is a new instance in the second timepoint. Args: instance_segmentations: list of instance segmentations (numpy arrays) for each time point Returns: instance_segmentations_matched: list of instance segmentations (numpy arrays) with matching labels across time points. """ instance_segmentations_matched = match_instance_segmentations_by_IoU(instance_segmentations) return instance_segmentations_matched def transform_to_and_resample_in_original_img_space(instance_segmentations_matched, segNodes): """ The instance_segmentations_matched need to be transformed and resampled in the reference spaces, i.e., seg1 to the reference space of timepoint1 and seg2 to the reference space of timepoint2. This is so that the RANO lines can be placed on slices of the original input volumes. Currently, the reference space are given by the channel1 input volumes, but this can be changed in the future. This means, the RANO lines are placed on slices of the channel1 input volumes (reference space). Steps: 1. Create a segmentation node for each instance segmentation and place instance segments in there. Then apply the transform to the segmentation node. This will give a segmentation node in the original space, but with spacing 1x1x1. 2. Resample the image to the original reference input volume space. 
Args: instance_segmentations_matched: list of instance segmentations (numpy arrays) with matching labels across time points segNodes: list of segmentation nodes for each time point Returns: resampledVolumeNodes: list of vtkMRMLLabelMapVolumeNode objects with the resampled instance segmentations """ assert(len(instance_segmentations_matched) == len(segNodes)), "Number of instance segmentations and segmentation nodes must be equal" num_timepoints = len(segNodes) # get the transforms and reference volumes for each timepoint (currently only channel1) transformNodes = [slicer.util.getNode(f"Transform_timepoint{i + 1}_channel1 (-)") for i in range(num_timepoints)] referenceVolumeNodes = [self._parameterNode.GetNodeReference(f"InputVolume_channel1_t{i + 1}") for i in range(num_timepoints)] resampledSegNodes = [] resampledVolumeNodes = [] for i, (seg, segNode, transformNode, referenceVolumeNode) in enumerate(zip(instance_segmentations_matched, segNodes, transformNodes, referenceVolumeNodes)): # new segmentation node to store the matched instance segmentations newSegNodeName = f"matched_instance_segmentation_t{i + 1}" # if the segmentation node already exists, remove it if slicer.mrmlScene.GetFirstNodeByName(newSegNodeName): slicer.mrmlScene.RemoveNode(slicer.mrmlScene.GetFirstNodeByName(newSegNodeName)) newSegNode = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLSegmentationNode", f"matched_instance_segmentation_t{i + 1}") referenceImageVolumeNode = segNode.GetNodeReference('referenceImageGeometryRef') newSegNode.SetReferenceImageGeometryParameterFromVolumeNode(referenceImageVolumeNode) # add the segmentations to the new segmentation node for lab in np.unique(seg): if lab == 0: continue slicer.util.updateSegmentBinaryLabelmapFromArray(narray=np.array(seg == lab).astype(np.uint8), segmentationNode=newSegNode, segmentId=str(lab), referenceVolumeNode=referenceImageVolumeNode ) # set the name of the new segment as Les 1, Les 2, etc. 
newSegNode.GetSegmentation().GetSegment(str(lab)).SetName(f"Les {lab}") # render in 3D newSegNode.CreateClosedSurfaceRepresentation() # make sure there is a displayNode newSegNode.CreateDefaultDisplayNodes() # display the segmentations on the corresponding views only self.setViews(newSegNode, f"timepoint{i + 1}") # hide the old and new segmentation nodes since we want to show the resampled labelmap volumes instead segNode.GetDisplayNode().SetVisibility(False) newSegNode.GetDisplayNode().SetVisibility(False) # apply the transform to the new segmentation node if transformNode: newSegNode.SetAndObserveTransformNodeID(transformNode.GetID()) # 2. resample the segmentation to the original input volume space # get all segmentIds from the new segmentation node segmentIds = newSegNode.GetSegmentation().GetSegmentIDs() if segmentIds: # create new labelmap volume node newLabelMapVolumeName = f"matched_instance_segmentation_t{i + 1}_resampled" # if it already exists, remove it if slicer.mrmlScene.GetFirstNodeByName(newLabelMapVolumeName): slicer.mrmlScene.RemoveNode(slicer.mrmlScene.GetFirstNodeByName(newLabelMapVolumeName)) labelmapVolumeNode = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLLabelMapVolumeNode", newLabelMapVolumeName) slicer.vtkSlicerSegmentationsModuleLogic.ExportSegmentsToLabelmapNode(newSegNode, segmentIds, labelmapVolumeNode, referenceVolumeNode) # set the views ui_helper_utils.UIHelperMixin.setLabelVolumes(labelmapVolumeNode, f"timepoint{i + 1}") else: print(f"No segments found in {newSegNode.GetName()}") labelmapVolumeNode = None resampledSegNodes.append(newSegNode) resampledVolumeNodes.append(labelmapVolumeNode) self.resampledSegNodes = resampledSegNodes return resampledVolumeNodes def open_instance_segmentation(instance_seg, radius=3): """ Open the instance segmentation using the method specified in the method2DmeasComboBox. 
Args: instance_seg: instance segmentation (numpy array) radius: radius of the opening operation Returns: seg_open: opened instance segmentation (numpy array) """ seg_open = np.zeros_like(instance_seg) for lab in np.unique(instance_seg): if lab == 0: continue if method2DmeasComboBox == "RANO_open2D": seg_open += lab * circle_opening_on_slices_perpendicular_to_axis(instance_seg, axes=[0, 1, 2], labels=[lab], radius=radius, ) elif method2DmeasComboBox == "RANO_open3D": seg_open += lab * sphere_opening(instance_seg, labels=[lab], radius=radius) else: raise ValueError( f"No implementation available for opening the segmentations with method {method2DmeasComboBox}") return seg_open def open_instance_segmentations(instance_segmentations, opening_radius): """ Loop over the instance segmentations and perform morphological opening on each segmentation. Args: instance_segmentations: list of instance segmentations (numpy arrays) opening_radius: radius of the opening operation Returns: opened_segmentations: list of opened instance segmentations (numpy arrays) """ # open the segmentations opened_segmentations = [] for seg in instance_segmentations: opened_seg = open_instance_segmentation(seg, radius=opening_radius) opened_segmentations.append(opened_seg.astype(np.uint8)) return opened_segmentations def evaluate_instance_segmentation(resampledVolumeNode): """ Evaluate the instance segmentation. Specifically, retrieve the coordinates of the 2D measurements considering the specified options (orientation, slice consistency, etc.) and the method selected in the UI. Also, calculate the volume of the lesions. Args: resampledVolumeNode: vtkMRMLLabelMapVolumeNode containing the matched instance segmentation. Returns: lesion_dict: dictionary containing the line pair coordinates and volume of each lesion for the current timepoint. """ lesion_dict = defaultdict(lambda: {"coords": [], "volume": np.nan}) if not resampledVolumeNode: print("resampledVolumeNode is None. 
Assuming no lesions for this timepoint.") return lesion_dict instance_segmentation = arrayFromVolume(resampledVolumeNode) for lab in np.unique(instance_segmentation): if lab == 0: # skip background label continue bin_seg = np.array(instance_segmentation == lab).astype(np.uint8) if method2DmeasComboBox in ["RANO", "RANO_open2D", "RANO_open3D"]: # store the orientation of the lesion in the current timepoint orientation_consistency_across_timepoints = self._parameterNode.GetParameter("orient_cons_tp") == "true" previous_timepoint_orientation_current_lesion = self.previous_timepoint_orientation.get(lab, None) if orientation_consistency_across_timepoints and previous_timepoint_orientation_current_lesion: valid_orientations = [previous_timepoint_orientation_current_lesion] else: valid_orientations = [orien for orien in ["sagittal", "coronal", "axial"] if self._parameterNode.GetParameter(orien) == "true"] # store the slice number of the lesion in the current timepoint force_close_slice_across_timepoints = self._parameterNode.GetParameter("same_slc_tp") == "true" center_IJK = None if force_close_slice_across_timepoints and previous_timepoint_orientation_current_lesion: previous_timepoint_center_current_lesion = self.previous_timepoint_center.get(lab, None) center_world = previous_timepoint_center_current_lesion # need to convert world center point to IJK point worldToIJK = get_ijk_to_world_matrix(resampledVolumeNode) worldToIJK.Invert() center_IJK = transform_world_to_ijk_coord(center_world, worldToIJK) def convert_to_IJK_axis(orientation, resampledVolumeNode): """ Convert the anatomical orientation to the corresponding IJK axis index using the IJK to RAS matrix. 
Args: orientation: anatomical orientation ("sagittal", "coronal", "axial") resampledVolumeNode: vtkMRMLLabelMapVolumeNode containing the matched instance segmentations Returns: ijk_axis_idx: IJK axis index corresponding to the anatomical orientation """ # mapping from anatomical orientations to anatomical axis indices in 3D slicer (RAS) orientation_to_world_axis_idx = {"sagittal": 0, "coronal": 1, "axial": 2} world_axis_idx = orientation_to_world_axis_idx[orientation] # get the current world axis index # get the RAS to IJK matrix ijkToWorld = get_ijk_to_world_matrix(resampledVolumeNode) ijkToWorld = slicer.util.arrayFromVTKMatrix(ijkToWorld) worldToIJK = np.linalg.inv(ijkToWorld) # convert from world axis index to IJK axis index ornt = io_orientation(worldToIJK) # mapping from RAS axes to IJK axes ijk_axis_idx = int(ornt[world_axis_idx][0]) # [0] picks the corresponding axis in the IJK space ([1] picks the direction) return ijk_axis_idx # need to convert world orientation to IJK axis valid_axes_IJK = [convert_to_IJK_axis(orien, resampledVolumeNode) for orien in valid_orientations] # need to go back to numpy order (KJI) ijk_axis_idx_to_kji = {0: 2, 1: 1, 2: 0} # convert 0 to 2 and 2 to 0 and leave 1 as is valid_axes_IJK = [ijk_axis_idx_to_kji[idx] for idx in valid_axes_IJK] # get the line pair coordinates for the current lesion coords_world = get_max_orthogonal_line_product_coords(bin_seg, valid_axes_IJK, center_IJK, ijkToWorld=get_ijk_to_world_matrix(resampledVolumeNode)) if len(coords_world) > 0: world_normal_axis_idx = find_closest_plane(coords_world) this_timepoint_orientation_curr_lesion = {0: 'sagittal', 1: 'coronal', 2: 'axial'}[world_normal_axis_idx] # get the center of the line pair in IJK space to allow for consistent slice selection across timepoints center_world = point_closest_to_two_lines(coords_world) this_timepoint_center_world_curr_lesion = center_world # store the orientation and center of the lesion in the current timepoint 
self.previous_timepoint_orientation[lab] = this_timepoint_orientation_curr_lesion self.previous_timepoint_center[lab] = this_timepoint_center_world_curr_lesion volume = np.sum(bin_seg) elif method2DmeasComboBox == "Random": coords_world = np.random.randint(0, 100, (2, 2, 3)) # 2 lines, 2 points, 3 coordinates volume = np.random.rand(2, 2, 3) * 100 else: raise ValueError(f"Method {method2DmeasComboBox} not recognized") lesion_dict[lab] = {"coords": coords_world, "volume": volume} return lesion_dict def display_opened_segmentations(opened_segmentations, segNodes): """ The opened segmentations (numpy arrays) are added to the segmentation nodes as new segments. Args: opened_segmentations: list of opened segmentations (numpy arrays) segNodes: list of segmentation nodes for each time point """ for i, (seg, segNode) in enumerate(zip(opened_segmentations, segNodes)): nextSegmentId = str(int(max(segNode.GetSegmentation().GetSegmentIDs())) + 1) # temporarily apply transform to the reference image of the segmentation node # this is required because segNode itself has been transformed to the original image space, so the reference image # needs to be transformed to the original image space as well transformNode = slicer.util.getNode(f"Transform_timepoint{i + 1}_channel1 (-)") if transformNode: referenceVolumeNode = segNode.GetNodeReference('referenceImageGeometryRef') referenceVolumeNode.SetAndObserveTransformNodeID(transformNode.GetID()) slicer.util.updateSegmentBinaryLabelmapFromArray(narray=seg, segmentationNode=segNode, segmentId=nextSegmentId, referenceVolumeNode=segNode.GetNodeReference( 'referenceImageGeometryRef')) # undo the transform of the reference image (we want to keep it in the segmentation model space to # avoid unnecessary resampling) if transformNode: referenceVolumeNode.SetAndObserveTransformNodeID(None) # set the name of the new segment segmentIdSelectedForOpening = self.ui.SegmentSelectorWidget.currentSegmentID() segmentNameSelectedForOpening = 
segNode.GetSegmentation().GetSegment(segmentIdSelectedForOpening).GetName() segNode.GetSegmentation().GetSegment(nextSegmentId).SetName(f"{segmentNameSelectedForOpening}_opened") def get_lesion_stats(resampledVolumeNodes): """ For each timepoint, evaluate the instance segmentation and store the line pair coordinates and volume of the lesions in a dictionary. Args: resampledVolumeNodes: list of vtkMRMLLabelMapVolumeNode containing the matched instance segmentations Returns: lesion_stats: dictionary of dictionaries containing the line pair coordinates and volume of each lesion (key=lesion_idx) for each timepoint (key=timepoint). """ lesion_stats = defaultdict(lambda: defaultdict(lambda: {})) # reset the previous timepoint orientation and center, so the previous run does not affect the current run self.previous_timepoint_orientation = {} self.previous_timepoint_center = {} for tp, resampledVolumeNode in enumerate(resampledVolumeNodes): tp = f"timepoint{tp + 1}" lesion_dict_tp = evaluate_instance_segmentation(resampledVolumeNode) for les_idx in lesion_dict_tp: lesion_stats[les_idx][tp] = lesion_dict_tp[les_idx] return lesion_stats segNodes = [self.ui.outputSelector.currentNode(), self.ui.outputSelector_t2.currentNode()] binary_semantic_segmentations = get_binary_semantic_segmentations(segNodes, segmentId=self.ui.SegmentSelectorWidget.currentSegmentID()) instance_segmentations = get_instance_segmentations(binary_semantic_segmentations) instance_segmentations_matched = matched_instances_across_timepoints(instance_segmentations) if method2DmeasComboBox in ["RANO_open2D", "RANO_open3D"]: opened_segmentations = open_instance_segmentations(instance_segmentations_matched, opening_radius=int(self.ui.radius_spinbox.value)) display_opened_segmentations(opened_segmentations, segNodes) instance_segmentations_matched = opened_segmentations self.instance_segmentations_matched = instance_segmentations_matched # transform the instance segmentations to the original reference image 
space self.resampledVolumeNodes = transform_to_and_resample_in_original_img_space(instance_segmentations_matched, segNodes) # hide the resampled labelmap volumes self.ui.toggleShowInstanceSegPushButton.setChecked(False) self.onToggleShowInstanceSegButton() lesion_stats = get_lesion_stats(self.resampledVolumeNodes) # remove the previous lines del self.lineNodePairs[:] self.lineNodePairs = create_lineNodePairs(lesion_stats) setLinePairViews(self.lineNodePairs) # center views on first volume (reference volume) self.onShowChannelButton(True, timepoint='timepoint1', inputSelector=self.ui.inputSelector_channel1_t1) self.onShowChannelButton(True, timepoint='timepoint2', inputSelector=self.ui.inputSelector_channel1_t2) self.update_linepair_ui_list() # algorithms to set enhancing, measurable, target lesions self.lineNodePairs.decide_enhancing() self.lineNodePairs.decide_measurable() self.lineNodePairs.decide_target() # calculate the results table results_table_utils.ResultsTableMixin.calculate_results_table(self.lineNodePairs)
[docs] def onToggleShowInstanceSegButton(self): """ This method is called when the "Show/Hide Lesions" button is pressed. It shows or hides the lesions in the slice views based on the resampled labelmap volumes. In the 3D views they are shown based on the resampled segmentation nodes (because the resampled labelmap volumes are not displayed smoothly in 3D, but voxelized). """ # show or hide the label volumes via the slicecontrollerwidget for timepoint in ["timepoint1", "timepoint2"]: if timepoint == 'timepoint1': viewnames = ["Red", "Yellow", "Green"] tp = 0 elif timepoint == 'timepoint2': viewnames = ["Red_2", "Yellow_2", "Green_2"] tp = 1 else: raise ValueError("timepoint must be 'timepoint1' or 'timepoint2'") checked = self.ui.toggleShowInstanceSegPushButton.checked for viewname in viewnames: compositeNode = slicer.app.layoutManager().sliceWidget(viewname).sliceLogic().GetSliceCompositeNode() controller = slicer.app.layoutManager().sliceWidget(viewname).sliceController() if compositeNode.GetLabelVolumeID() not in [node.GetID() for node in self.resampledVolumeNodes if node]: pass # controller.setLabelMapHidden(True) else: compositeNode.SetLabelOpacity(0.5) if checked: controller.setLabelMapHidden(False) else: controller.setLabelMapHidden(True) # show or hide the 3D view node = self.resampledSegNodes[tp] if self.resampledSegNodes else None if node: if checked: node.GetDisplayNode().SetVisibility(True) node.GetDisplayNode().SetVisibility2D(False) node.GetDisplayNode().SetVisibility3D(True) # set opacity node.GetDisplayNode().SetOpacity(0.5) else: node.GetDisplayNode().SetVisibility(False)
[docs] def onAddLinePairButton(self, timepoint): """ This method is called when the "Add Lines t1" or "Add Lines t2" button is pressed. It allows the user to add a new line pair for the selected timepoint by placing two lines in the slice views of the corresponding timepoint. The lines are added to the lineNodePairs list and displayed in the UI. Args: timepoint: the timepoint for which the line pair is added (e.g., "timepoint1" or "timepoint2") """ if debug: print("Add line pair button pressed") lesion_idx = int(self.ui.add_line_lesidx_spinBox.value) # check if the lesion index and timepoint already exist in the lineNodePairs for pair in self.lineNodePairs: if pair.lesion_idx == lesion_idx and pair.timepoint == timepoint: # show a message box msgBox = qt.QMessageBox() msgBox.setText(f"Line pair for Lesion Index {lesion_idx} and {timepoint} already exists") msgBox.exec() return newLineNodePair = LineNodePair(lesion_idx=lesion_idx, timepoint=timepoint) newLineNode1, newLineNode2 = newLineNodePair self.setViews(newLineNodePair, timepoint) # go into placement mode for the new line persistentPlaceMode = 1 # need to be in persistent mode to place the second line after first slicer.modules.markups.logic().StartPlaceMode(persistentPlaceMode) slicer.modules.markups.logic().SetActiveListID(newLineNode1) def place_line2(lineNode1, arg2): if lineNode1.GetNumberOfControlPoints() == 2: # go into placement mode for the new line slicer.modules.markups.logic().SetActiveListID(newLineNode2) def end_placement(lineNode2, arg2): if int(lineNode2.GetNumberOfControlPoints()) == 2: slicer.modules.markups.logic().StartPlaceMode(0) # exit placement mode def add_linePair(lineNode2): if lineNode2.GetNumberOfControlPoints() == 2: line_lenghts = newLineNodePair.get_line_lengths() newLineNodePair.measurable = True if all( [l > 10 for l in line_lenghts]) else False self.lineNodePairs.append(newLineNodePair) self.update_linepair_ui_list() # set the views again after both lines have been defined so 
they can be removed from the views in which they are not orthogonal self.setViews(newLineNodePair, timepoint) for observedNode, observer in self.observations2: observedNode.RemoveObserver(observer) add_linePair(lineNode2) # add callback to end placement mode AFTER the second line is placed self.observations2.append([newLineNode2, newLineNode2.AddObserver(newLineNode2.PointPositionDefinedEvent, end_placement)]) def onPointRemovedEvent(lineNode, event): if slicer.app.applicationLogic().GetInteractionNode().GetCurrentInteractionMode() == 2: # schedule removal the line pair if the user cancels the placement of either line # can't remove the lines immediately because one of the lines is the observer caller # remove observers for observedNode, observer in self.observations2: observedNode.RemoveObserver(observer) qt.QTimer.singleShot(0, lambda: slicer.mrmlScene.RemoveNode(newLineNode1)) qt.QTimer.singleShot(0, lambda: slicer.mrmlScene.RemoveNode(newLineNode2)) # add callback to place the second line when the first line is placed # first remove previous observers for observedNode, observer in self.observations2: observedNode.RemoveObserver(observer) self.observations2.append([newLineNode1, newLineNode1.AddObserver(newLineNode1.PointPositionDefinedEvent, place_line2)]) #self.observations2.append([newLineNode2, newLineNode2.AddObserver(newLineNode2.PointPositionDefinedEvent, add_linePair)]) # add callback to remove the line pair if the user cancels the placement of either line self.observations2.append([newLineNode1, newLineNode1.AddObserver(newLineNode1.PointRemovedEvent, onPointRemovedEvent)]) self.observations2.append([newLineNode2, newLineNode2.AddObserver(newLineNode2.PointRemovedEvent, onPointRemovedEvent)])
[docs] def update_linepair_ui_list(self): """ This method updates the UI list of line pairs by populating the table with the lesion index, timepoint, and whether the lesion is enhancing, measurable, and target. It also sets the background color of the rows such that rows of the same lesion index are grouped together for better readability. """ def onCellClicked(row, col): """Called when a cell is clicked.""" les_idx = int(self.ui.tableWidget.item(row, 0).text()) tp = self.ui.tableWidget.item(row, 1).text() # center timepoint's views on the clicked lesion for pair in self.lineNodePairs: if pair.lesion_idx == les_idx and pair.timepoint == tp: self.centerTimepointViewsOnCenterPoint(pair, tp) # update the menu table col_name_to_idx = {"Lesion Index": 0, "Timepoint": 1, "Enhancing": 2, "Measurable": 3, "Target": 4, " ": 5} tableWidget = self.ui.tableWidget tableWidget.setColumnCount(len(col_name_to_idx)) tableWidget.setHorizontalHeaderLabels(list(col_name_to_idx.keys())) tableWidget.setRowCount(0) tableWidget.clearContents() tableWidget.verticalHeader().setVisible(False) tableWidget.cellClicked.connect(onCellClicked) prev_les_idx = None color1 = qt.QColor(240, 240, 240) color2 = qt.QColor(255, 255, 255) prev_color = color1 # sort the lineNodePairs by lesion index and timepoint self.lineNodePairs = self.lineNodePairs.custom_sort(key=lambda x: (int(x.lesion_idx), x.timepoint)) for pair_idx, pair in enumerate(self.lineNodePairs): les_idx = int(pair.lesion_idx) tp = pair.timepoint # insert new row row_count = tableWidget.rowCount assert row_count == pair_idx, f"Row count: {row_count}, pair_idx: {pair_idx}, should be equal" tableWidget.insertRow(row_count) def setRowColor(rowIdx, color): """Set the background color of all cells in the row.""" for col in range(tableWidget.columnCount): item = tableWidget.item(rowIdx, col) if item: item.setBackground(qt.QBrush(color)) item.setTextAlignment(qt.Qt.AlignCenter) # define the contents of the columns tableWidget.setItem(row_count, 
col_name_to_idx["Lesion Index"], qt.QTableWidgetItem(str(les_idx))) tableWidget.setItem(row_count, col_name_to_idx["Timepoint"], qt.QTableWidgetItem(tp)) # place tick boxes in the "Enhancing", "Measurable", "Target" columns for col_name in ["Enhancing", "Measurable", "Target"]: checkbox = qt.QCheckBox() def is_valid_target_selection(checkbox, les_idx, tp, _): # make sure selected target lesion is measurable for pair in self.lineNodePairs: if pair.lesion_idx == les_idx and pair.timepoint == tp: if not pair.measurable and checkbox.isChecked(): msgBox = qt.QMessageBox() msgBox.setText("Target lesion must be measurable") msgBox.exec() checkbox.setChecked(False) return False # count the number of enhancing and non-enhancing target lesions for current timepoint counter_enhancing = 0 counter_non_enhancing = 0 for pair in self.lineNodePairs: if not pair.timepoint == tp: continue if pair.enhancing and pair.target: counter_enhancing += 1 elif not pair.enhancing and pair.target: counter_non_enhancing += 1 err = None if counter_enhancing > 0 and counter_non_enhancing > 0: if counter_enhancing > 2 or counter_non_enhancing > 2: err = (f"Only 2 enhancing and 2 non-enhancing target lesions are allowed, but tried to add " f"{counter_enhancing} enhancing and {counter_non_enhancing} non-enhancing target lesions " f"for timepoint {tp}") elif counter_enhancing > 3: err = (f"Only 3 enhancing target lesions are allowed, but tried to add {counter_enhancing} " f"for timepoint {tp}") elif counter_non_enhancing > 3: err = (f"Only 3 non-enhancing target lesions are allowed, but tried to add {counter_non_enhancing} " f"for timepoint {tp}") if err: msgBox = qt.QMessageBox() msgBox.setText(err) msgBox.exec() checkbox.setChecked(False) return False return True if col_name == "Target": checkbox.clicked.connect(functools.partial(is_valid_target_selection, checkbox, les_idx, tp)) # add callback to update the lineNodePair when the checkbox is clicked def onCheckboxToggled(checkbox, les_idx, tp, 
col_name, _): if debug: print(f"Checkbox toggled to {checkbox.checked} for line pair of lesion index {les_idx}, timepoint {tp}, and column {col_name}") for pair in self.lineNodePairs: if pair.lesion_idx == les_idx and pair.timepoint == tp: setattr(pair, col_name.lower(), checkbox.checked) response_classification_utils.ResponseClassificationMixin.update_response_assessment(self.ui, self.lineNodePairs) checkbox.toggled.connect(functools.partial(onCheckboxToggled, checkbox, les_idx, tp, col_name)) #tableWidget.setCellWidget(row_count, col_name_to_idx[col_name], checkbox) # make sure the checkbox is centered# make sure the checkbox is centered widget = qt.QWidget() layout = qt.QHBoxLayout(widget) layout.addWidget(checkbox) layout.setAlignment(qt.Qt.AlignCenter) layout.setContentsMargins(0, 0, 0, 0) tableWidget.setCellWidget(row_count, col_name_to_idx[col_name], widget) checkbox.setChecked(getattr(pair, col_name.lower())) # place push button in the "Delete" column deleteButton = qt.QPushButton() deleteButton.setIcon(qt.QIcon(self.resourcePath('Icons/trash.png'))) def deleteLinePair(index, _): if debug: print(f"Deleting line pair at index {index}") self.lineNodePairs.pop(index) self.update_linepair_ui_list() deleteButton.clicked.connect(functools.partial(deleteLinePair, pair_idx)) tableWidget.setCellWidget(row_count, col_name_to_idx[" "], deleteButton) # make the row narrow tableWidget.horizontalHeader().resizeSection(col_name_to_idx[" "], 10) other_color = color1 if prev_color == color2 else color2 selected_color = prev_color if (not prev_les_idx or les_idx == prev_les_idx) else other_color prev_color = selected_color prev_les_idx = les_idx setRowColor(row_count, selected_color) # make all rows narrow row_height = 20 for row in range(tableWidget.rowCount): tableWidget.setRowHeight(row, row_height) # if the text font is too large, the actual row height will be larger than the set row height row_height = tableWidget.rowHeight(0) # make the columns fit the content 
tableWidget.resizeColumnsToContents() self.ui.tableWidget.setFixedHeight((row_height + 1) * (tableWidget.rowCount+1) + 2) response_classification_utils.ResponseClassificationMixin.update_response_assessment(self.ui, self.lineNodePairs) response_classification_utils.ResponseClassificationMixin.update_overall_response_params(self.ui) response_classification_utils.ResponseClassificationMixin.update_overall_response_status(self.ui)
def coords_ijk_to_world(self, coords_ijk, node):
    """
    Convert the end points of two lines from the IJK (voxel index) space of ``node`` to world coordinates.

    Args:
        coords_ijk: indexable as ``coords_ijk[line][point]``; only lines 0 and 1 and points 0 and 1
            are used. May be empty.
        node: the node whose IJK-to-world matrix (from ``get_ijk_to_world_matrix``) is applied.

    Returns:
        ``[]`` if ``coords_ijk`` is empty, otherwise a numpy array indexed as [line, point, coordinate]
        (2 lines, 2 points, 3 coordinates).
    """
    if len(coords_ijk) == 0:
        return []

    ijk_to_world = get_ijk_to_world_matrix(node)
    # collect the four end points first; any further lines/points are ignored
    end_points = [(coords_ijk[line_idx][0], coords_ijk[line_idx][1]) for line_idx in (0, 1)]
    return np.array([
        [transform_ijk_to_world_coord(point, ijk_to_world) for point in line_points]
        for line_points in end_points
    ])
@staticmethod
def setViews(node, timepoint):
    """
    Restrict the views in which a node (or the nodes of a LineNodePair) are displayed.

    The node is shown only in the views of the given timepoint. These are the slice views
    Red/Yellow/Green for 'timepoint1' or Red_2/Yellow_2/Green_2 for 'timepoint2', plus the 3D view
    whose view node has the singleton tag "view3d_1" or "view3d_2".

    For a LineNodePair whose control points are all defined (``get_coords`` contains no NaN), the
    slice views are reduced to a single one chosen from the axis returned by ``find_closest_plane``:
    0 -> Yellow, 1 -> Green, 2 -> Red. The pair's text fiducial node, if any, receives the same
    views as its lines.

    Args:
        node: a LineNodePair, or any node whose ``GetDisplayNode()`` supports ``SetViewNodeIDs``
        timepoint: 'timepoint1' or 'timepoint2'

    Raises:
        ValueError: if ``timepoint`` is not one of the two supported values, or if no 3D view with
            the expected singleton tag exists in the current layout.
    """
    for tp_name, slice_view_names, tag_3d in (
            ('timepoint1', ["Red", "Yellow", "Green"], "view3d_1"),
            ('timepoint2', ["Red_2", "Yellow_2", "Green_2"], "view3d_2"),
    ):
        if timepoint == tp_name:
            break
    else:
        raise ValueError(f"timepoint must be 'timepoint1' or 'timepoint2' but is {timepoint}")

    is_line_pair = isinstance(node, LineNodePair)
    if is_line_pair:
        coords = node.get_coords()
        if not np.isnan(coords).any():
            axis = find_closest_plane(coords)
            # (axis returned by find_closest_plane, index into slice_view_names of the view to keep)
            for plane_axis, view_index in ((0, 1), (1, 2), (2, 0)):
                if axis == plane_axis:
                    slice_view_names = [slice_view_names[view_index]]
                    break

    layout_manager = slicer.app.layoutManager()
    view_node_ids = [
        layout_manager.sliceWidget(name).sliceLogic().GetSliceNode().GetID()
        for name in slice_view_names
    ]

    # the 3D widget of this timepoint is identified by the singleton tag of its view node
    three_d_idx = next(
        (i for i in range(layout_manager.threeDViewCount)
         if layout_manager.threeDWidget(i).threeDView().mrmlViewNode().GetSingletonTag() == tag_3d),
        None,
    )
    if three_d_idx is None:
        raise ValueError(f"Could not find the 3D view with singleton tag {tag_3d}")
    view_node_ids.append(layout_manager.threeDWidget(three_d_idx).threeDView().mrmlViewNode().GetID())

    if not is_line_pair:
        node.GetDisplayNode().SetViewNodeIDs(view_node_ids)
        return
    for line_node in node:
        line_node.GetDisplayNode().SetViewNodeIDs(view_node_ids)
    if node.fiducialNodeForText:
        node.fiducialNodeForText.GetDisplayNode().SetViewNodeIDs(view_node_ids)
@staticmethod
def centerTimepointViewsOnFirstMarkupPoint(lineNode, tp):
    """
    Center the slice views and cameras of a timepoint on the first control point of a line node.

    Args:
        lineNode: the line node to center on (its control point 0 is used, in world coordinates)
        tp: the timepoint to center on (e.g., "timepoint1" or "timepoint2")
    """
    position = lineNode.GetNthControlPointPositionWorld(0)
    Measurements2DMixin._centerTimepointViewsOnPosition(position, tp)

@staticmethod
def centerTimepointViewsOnCenterPoint(lineNodePair, tp):
    """
    Center the slice views and cameras of a timepoint on the center point of a line node pair.

    The center is the point closest to both lines, as computed by ``point_closest_to_two_lines``.

    Args:
        lineNodePair: the line node pair to center on
        tp: the timepoint to center on (e.g., "timepoint1" or "timepoint2")
    """
    center = point_closest_to_two_lines(lineNodePair.get_coords())
    Measurements2DMixin._centerTimepointViewsOnPosition(center, tp)

@staticmethod
def _centerTimepointViewsOnPosition(position, tp):
    """
    Jump the slice views of the timepoint's view group to ``position`` and point their cameras at it.

    View group 0 is used for 'timepoint1'; any other ``tp`` value selects view group 1.

    Args:
        position: world position (three coordinates) to center on
        tp: the timepoint whose views are updated
    """
    viewgroup_to_set = 0 if tp == 'timepoint1' else 1
    for sliceNode in slicer.util.getNodesByClass('vtkMRMLSliceNode'):
        if sliceNode.GetViewGroup() == viewgroup_to_set:
            sliceNode.JumpSliceByCentering(*position)
    for viewNode in slicer.util.getNodesByClass('vtkMRMLViewNode'):
        if viewNode.GetViewGroup() != viewgroup_to_set:
            continue
        camera = slicer.modules.cameras.logic().GetViewActiveCameraNode(viewNode)
        # GetViewActiveCameraNode may return None (no camera for this view); skip instead of crashing
        if camera is not None:
            camera.SetFocalPoint(*position)
class LineNodePair(list):
    """
    A pair of markups line nodes holding the two orthogonal diameters of one lesion at one timepoint.

    The class inherits from list, so ``pair[0]``, ``pair[1]`` and iteration give direct access to
    the two line nodes. Besides the line nodes, it owns a fiducial node whose (glyph-less) control
    point carries a text label with the line lengths. It also keeps the enhancing / measurable /
    target flags used for the response assessment.

    Whenever one of the line nodes is modified, the line colors, the results table and the text
    label are refreshed (see ``uponLineNodeModifiedEvent``).

    Args:
        lesion_idx: the index of the lesion
        timepoint: the timepoint for which the line nodes are created (e.g., "timepoint1" or "timepoint2")
        enhancing: whether the lesion is enhancing or not (default: True)
        measurable: whether the lesion is measurable or not (default: True)
        target: whether the lesion is a target lesion or not (default: False)
    """

    def __init__(self, lesion_idx, timepoint, enhancing=True, measurable=True, target=False):
        lineNode1, lineNode2 = self.create_twoLineNodes(lesion_idx, timepoint)
        super().__init__([lineNode1, lineNode2])

        self.lesion_idx = lesion_idx
        """Lesion index"""
        self.timepoint = timepoint
        """Timepoint for which the line node pair is created"""
        self.enhancing = enhancing
        """Whether the lesion is enhancing or not"""
        self.measurable = measurable
        """Whether the lesion is measurable or not"""
        self.target = target
        """Whether the lesion is a target lesion or not"""
        # must run after lesion_idx/timepoint are set: the fiducial node name is derived from them
        self.fiducialNodeForText = self.create_fiducialNodeFor_text()
        """Fiducial node for the text label of the line nodes"""
        self.observations = []
        """List of [observed node, observer tag] entries; removed again in cleanup()"""
        # keep colors, results table and text label in sync with the line nodes
        for lineNode in (lineNode1, lineNode2):
            observer_tag = lineNode.AddObserver(vtk.vtkCommand.ModifiedEvent, self.uponLineNodeModifiedEvent)
            self.observations.append([lineNode, observer_tag])

    def set_coords(self, coords):
        """
        Move the control points of both line nodes to the given world coordinates and refresh the pair.

        Args:
            coords: one entry per line node (looping over ``self`` yields the line nodes). Each entry
                is array-like with one row of world coordinates per control point, i.e. the layout
                returned by ``get_coords``.
        """
        for lineNode, line_coords in zip(self, coords):
            slicer.util.updateMarkupsControlPointsFromArray(lineNode, np.asarray(line_coords))
        # explicitly refresh colors, results table and text label for the new coordinates
        self.uponLineNodeModifiedEvent(n=None, e=None)

    def get_coords(self):
        """
        Return the control point positions of both lines in world coordinates.

        Returns:
            numpy array of shape (2, 2, 3), indexed as [line, point, coordinate]. The entries of a
            line that does not (yet) have both control points are NaN.
        """
        coords = np.full((2, 2, 3), np.nan)
        for i, lineNode in enumerate(self):
            if lineNode.ControlPointExists(0) and lineNode.ControlPointExists(1):
                coords[i] = [lineNode.GetNthControlPointPositionWorld(j) for j in range(2)]
        return coords

    def get_line_lengths(self):
        """Return the lengths ``(len1, len2)`` of the two lines in world units."""
        return self[0].GetLineLengthWorld(), self[1].GetLineLengthWorld()

    def get_line_length_product(self):
        """Return the bidimensional product, i.e. the product of the two line lengths in world units."""
        len1, len2 = self.get_line_lengths()
        return len1 * len2

    @staticmethod
    def create_twoLineNodes(les_idx, timepoint):
        """
        Create the two line nodes of a line pair and configure their display properties.

        Existing nodes with the same names (e.g. from an earlier run) are removed first. The lines
        get a thicker line, smaller control point glyphs and a hidden properties label.

        Args:
            les_idx: the index of the lesion
            timepoint: the timepoint for which the line nodes are created (e.g., "timepoint1" or "timepoint2")

        Returns:
            tuple ``(lineNode1, lineNode2)`` of the new vtkMRMLMarkupsLineNode objects
        """
        names = [f"l{i}_les{int(les_idx)}_{timepoint.replace('timepoint', 't')}" for i in (1, 2)]

        for name in names:
            existing = slicer.mrmlScene.GetFirstNodeByName(name)
            if existing:
                slicer.mrmlScene.RemoveNode(existing)

        lineNodes = []
        for name in names:
            lineNode = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLMarkupsLineNode", name)
            displayNode = lineNode.GetDisplayNode()
            displayNode.SetLineThickness(0.35)  # thicker line
            displayNode.SetGlyphScale(1)  # smaller control points
            displayNode.SetPropertiesLabelVisibility(0)  # length text is drawn by the text fiducial instead
            lineNodes.append(lineNode)
        return tuple(lineNodes)

    def create_fiducialNodeFor_text(self):
        """
        Create the fiducial node that carries the text label of this line pair.

        The node has a single control point whose glyph is hidden, so that only its label is
        visible. An existing node with the same name is removed first. The label starts out empty
        and is filled in by ``annotate_with_text``.

        Returns:
            the new vtkMRMLMarkupsFiducialNode
        """
        fiducialNodeName = f"text_fiducial_les{int(self.lesion_idx)}_{self.timepoint.replace('timepoint', 't')}"
        existing = slicer.mrmlScene.GetFirstNodeByName(fiducialNodeName)
        if existing:
            slicer.mrmlScene.RemoveNode(existing)
        fiducialNode = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLMarkupsFiducialNode", fiducialNodeName)

        displayNode = fiducialNode.GetDisplayNode()
        # hide the control point glyph in both the relative and the absolute sizing mode
        displayNode.SetGlyphScale(0)
        displayNode.SetGlyphSize(0)
        fiducialNode.AddControlPoint(0, 0, 0)
        fiducialNode.SetNthControlPointLabel(0, "")
        displayNode.SetTextScale(3)
        displayNode.GetTextProperty().ShadowOff()
        return fiducialNode

    @staticmethod
    def set_color_depending_on_orthogonality(n, e, lineNode1, lineNode2, fiducialNodeForText=None):
        """
        Color both lines green if they are orthogonal (within 1 degree) and red otherwise.

        Nothing happens until both lines have exactly two control points. A zero-length line has no
        direction and is treated as not orthogonal. The text fiducial, if given, gets a darker
        version of the same color.

        Args:
            n: observer caller argument (unused; kept for the observer-style signature)
            e: observer event argument (unused; kept for the observer-style signature)
            lineNode1: the first line node
            lineNode2: the second line node
            fiducialNodeForText: the fiducial node for the text label of the line nodes (optional)
        """
        if debug:
            print("Setting color depending on orthogonality")
        line1 = np.array(
            [lineNode1.GetNthControlPointPositionWorld(i) for i in range(lineNode1.GetNumberOfControlPoints())])
        line2 = np.array(
            [lineNode2.GetNthControlPointPositionWorld(i) for i in range(lineNode2.GetNumberOfControlPoints())])
        if not (len(line1) == 2 and len(line2) == 2):
            return

        dir1 = line1[-1] - line1[0]
        dir2 = line2[-1] - line2[0]
        norm1 = np.linalg.norm(dir1)
        norm2 = np.linalg.norm(dir2)

        tolerance_deg = 1
        if norm1 > 0 and norm2 > 0:
            # rounding can push |cos| marginally above 1, where arccos would return nan
            cos_angle = np.clip(np.dot(dir1 / norm1, dir2 / norm2), -1.0, 1.0)
            angle_rad = np.arccos(cos_angle)
            is_orthogonal = np.deg2rad(90 - tolerance_deg) < angle_rad < np.deg2rad(90 + tolerance_deg)
        else:
            is_orthogonal = False

        color = (0, 1, 0) if is_orthogonal else (1, 0, 0)  # green / red
        lineNode1.GetDisplayNode().SetSelectedColor(color)
        lineNode2.GetDisplayNode().SetSelectedColor(color)
        if fiducialNodeForText:
            fiducialNodeForText.GetDisplayNode().SetSelectedColor(tuple(v / 1.2 for v in color))

    def annotate_with_text(self):
        """
        Place the text fiducial at the center of the line pair and label it with both line lengths.

        The label position is chosen as follows:
        - the point closest to both lines, if both lines are complete;
        - otherwise the midpoint of the first line, if that line is complete;
        - otherwise the origin.

        The label reads "Les <idx>: <len1> x <len2>" with world lengths to one decimal; a line
        without both control points counts as length 0. If both lengths are 0, the label is cleared
        so that no stale measurement is shown at the fallback position.
        """
        coords = self.get_coords()
        if not np.isnan(coords).any():
            center_point = point_closest_to_two_lines(coords)
        elif not np.isnan(coords[0]).any():
            center_point = coords[0].mean(axis=0)
        else:
            center_point = [0., 0., 0.]
        self.fiducialNodeForText.SetNthControlPointPosition(0, center_point[0], center_point[1], center_point[2])

        def length_or_zero(lineNode):
            # only query the length of lines that have both control points
            if lineNode.ControlPointExists(0) and lineNode.ControlPointExists(1):
                return lineNode.GetLineLengthWorld()
            return 0.0

        line1LengthWorld = length_or_zero(self[0])
        line2LengthWorld = length_or_zero(self[1])
        if line1LengthWorld or line2LengthWorld:
            label = f"Les {int(self.lesion_idx)}: {line1LengthWorld:.1f} x {line2LengthWorld:.1f}"
        else:
            label = ""
        self.fiducialNodeForText.SetNthControlPointLabel(0, label)

    def uponLineNodeModifiedEvent(self, n, e):
        """
        Observer callback for ModifiedEvent of the two line nodes.

        It updates the line colors depending on orthogonality, recalculates the results table of the
        RANO widget with all line node pairs, and refreshes the text label.

        Args:
            n: the caller passed by the VTK observer (the modified line node); None when called directly
            e: the event name passed by the VTK observer; None when called directly
        """
        self.set_color_depending_on_orthogonality(n, e, self[0], self[1], self.fiducialNodeForText)
        slicer.modules.RANOWidget.calculate_results_table(slicer.modules.RANOWidget.lineNodePairs)
        self.annotate_with_text()

    def cleanup(self):
        """
        Remove the observers, the text fiducial node and both line nodes of this pair from the scene.
        """
        for observedNode, observer in self.observations:
            observedNode.RemoveObserver(observer)
        slicer.mrmlScene.RemoveNode(self.fiducialNodeForText)
        slicer.mrmlScene.RemoveNode(self[0])
        slicer.mrmlScene.RemoveNode(self[1])

    def __repr__(self):
        return (f"LineNodePair(lesion_idx={self.lesion_idx}, timepoint={self.timepoint}, "
                f"lineNode1={self[0]}, lineNode2={self[1]})")
[docs] class LineNodePairList(list):
    """
    A list of LineNodePair objects (one entry per lesion and timepoint).

    Inherits from list, so the pairs can be indexed, iterated and sorted like a normal list.
    Removing pairs via ``del`` or ``pop`` also calls their ``cleanup()``, which removes their line
    and text nodes from the MRML scene. Removal then triggers ``uponModified()``, which refreshes
    the results table, the response assessment and the line pair UI list of the RANO widget.

    The class also provides helpers to:
    - flag pairs as enhancing / measurable / target;
    - count (new, disappeared) target and measurable lesions;
    - compute the sums of bidimensional products used by the response assessment.
    """
[docs] def __delitem__(self, index): """ Makes sure that line nodes contained in the LineNodePairList are removed from the scene when removed from the list """ if isinstance(index, slice): # Get all items that will be deleted items = self[index] for item in items: item.cleanup() else: self[index].cleanup() super().__delitem__(index) self.uponModified()
[docs] def pop(self, index): """ Makes sure that line nodes contained in the LineNodePairList are removed from the scene when popped from the list """ self[index].cleanup() out = super().pop(index) self.uponModified() return out
def uponModified(self):
    """
    Refresh everything that depends on the line node pairs after this list changed.

    Using the RANO widget registered as ``slicer.modules.RANOWidget``, this runs three steps in order:
    1. recompute the results table;
    2. update the response assessment;
    3. rebuild the line pair UI list.
    """
    rano_widget = slicer.modules.RANOWidget
    rano_widget.calculate_results_table(self)
    rano_widget.update_response_assessment(rano_widget.ui, self)
    rano_widget.update_linepair_ui_list()
# make sure that the list is returned as a LineNodePairList when sorted
[docs] def custom_sort(self, *args, **kwargs): """ Sort the list of LineNodePair objects and return a new LineNodePairList object. """ sorted_items = sorted(self, *args, **kwargs) return LineNodePairList(sorted_items)
def decide_enhancing(self):
    """
    Mark every line node pair as enhancing, then refresh the dependent UI.

    This is placeholder logic: for now, all lesions are treated as enhancing.
    """
    for line_pair in self:
        line_pair.enhancing = True
    self.uponModified()
def decide_measurable(self, min_length=10):
    """
    Flag each line node pair as measurable if both of its lines are longer than ``min_length``.

    The lengths come from ``GetLineLengthWorld``, so they are in world units (millimetres in
    3D Slicer), not pixels/voxels. The comparison is strict: a line exactly ``min_length`` long
    does not qualify. Afterwards, ``uponModified`` refreshes the dependent UI.

    Args:
        min_length: strict lower bound for both line lengths (default: 10)
    """
    for pair in self:
        pair.measurable = bool(pair[0].GetLineLengthWorld() > min_length
                               and pair[1].GetLineLengthWorld() > min_length)
    self.uponModified()
def decide_target(self, strategy="two_largest_enhancing"):
    """
    Decide which lesions are target lesions, based on the baseline (timepoint1) measurements.

    All target flags are reset first. Candidates are then ranked by the product of their two line
    lengths, largest first, and only measurable lesions are eligible. Supported strategies:

    - "two_largest_enhancing": the two largest measurable enhancing lesions
    - "three_largest_enhancing": the three largest measurable enhancing lesions
    - "two_largest_enhancing_and_two_largest_non_enhancing": the two largest measurable enhancing
      lesions plus the two largest measurable non-enhancing lesions

    Any other strategy value selects no target lesions. Pairs at timepoint2 inherit the target flag
    of the timepoint1 pair with the same lesion index; pairs without such a counterpart stay
    non-target. Afterwards, ``uponModified`` refreshes the dependent UI.

    Args:
        strategy: the strategy to use for selecting the target lesions
    """
    for pair in self:
        pair.target = False

    # rank by bidimensional product, largest first
    sorted_list = self.custom_sort(key=lambda x: x[0].GetLineLengthWorld() * x[1].GetLineLengthWorld(),
                                   reverse=True)
    sorted_list_t1 = [pair for pair in sorted_list if pair.timepoint == "timepoint1"]
    sorted_list_t2 = [pair for pair in sorted_list if pair.timepoint == "timepoint2"]

    if strategy in ("two_largest_enhancing", "three_largest_enhancing"):
        num_max = 2 if strategy == "two_largest_enhancing" else 3
        candidates = [pair for pair in sorted_list_t1 if pair.measurable and pair.enhancing]
        for pair in candidates[:num_max]:
            pair.target = True
    elif strategy == "two_largest_enhancing_and_two_largest_non_enhancing":
        measurable_t1 = [pair for pair in sorted_list_t1 if pair.measurable]
        for is_enhancing in (True, False):
            same_kind = [pair for pair in measurable_t1 if bool(pair.enhancing) == is_enhancing]
            for pair in same_kind[:2]:
                pair.target = True

    # timepoint2 pairs follow the decision made for the same lesion at timepoint1
    target_by_lesion_t1 = {pair.lesion_idx: pair.target for pair in sorted_list_t1}
    for pair_t2 in sorted_list_t2:
        if pair_t2.lesion_idx in target_by_lesion_t1:
            pair_t2.target = target_by_lesion_t1[pair_t2.lesion_idx]

    self.uponModified()
def get_number_of_targets(self):
    """
    Count the distinct target lesions.

    A lesion that is flagged as target at both timepoints appears twice in the list, but it is
    counted only once (by lesion index).
    """
    distinct_target_lesions = {pair.lesion_idx for pair in self if pair.target}
    return len(distinct_target_lesions)
def get_number_of_new_target_lesions(self):
    """
    Count lesions that are target lesions at timepoint2 but not at timepoint1 (compared by lesion index).
    """
    def target_lesions_at(timepoint):
        return {pair.lesion_idx for pair in self if pair.target and pair.timepoint == timepoint}

    baseline_targets = target_lesions_at('timepoint1')
    followup_targets = target_lesions_at('timepoint2')
    return len(followup_targets - baseline_targets)
def get_number_of_disappeared_target_lesions(self):
    """
    Count lesions that are target lesions at timepoint1 but no longer at timepoint2 (compared by lesion index).
    """
    def target_lesions_at(timepoint):
        return {pair.lesion_idx for pair in self if pair.target and pair.timepoint == timepoint}

    baseline_targets = target_lesions_at('timepoint1')
    followup_targets = target_lesions_at('timepoint2')
    return len(baseline_targets - followup_targets)
def get_number_of_new_measurable_lesions(self):
    """
    Count lesions that are measurable at timepoint2 but not at timepoint1 (compared by lesion index).
    """
    def measurable_lesions_at(timepoint):
        return {pair.lesion_idx for pair in self if pair.measurable and pair.timepoint == timepoint}

    baseline_measurable = measurable_lesions_at('timepoint1')
    followup_measurable = measurable_lesions_at('timepoint2')
    return len(followup_measurable - baseline_measurable)
def get_sum_of_bidimensional_products(self, timepoint):
    """
    Return the sum of bidimensional products of all target lesions at ``timepoint``.

    Each lesion contributes the product of its two line lengths (``get_line_length_product``).
    Non-target lesions and lesions at other timepoints are ignored. Returns 0 if nothing matches.

    Args:
        timepoint: e.g. "timepoint1" or "timepoint2"
    """
    return sum(
        pair.get_line_length_product()
        for pair in self
        if pair.target and pair.timepoint == timepoint
    )
def get_rel_area_change(self):
    """
    Return the relative change of the sum of bidimensional products of the target lesions.

    The value is computed from timepoint1 to timepoint2 as (sum_t2 - sum_t1) / sum_t1.

    If there is no baseline sum (sum_t1 == 0, e.g. no target lesions at timepoint1), the ratio is
    undefined. Instead of raising ZeroDivisionError, IEEE float semantics are used:
    - +inf / -inf if the timepoint2 sum is positive / negative;
    - nan if both sums are 0.
    """
    sum_prod_t1 = self.get_sum_of_bidimensional_products("timepoint1")
    sum_prod_t2 = self.get_sum_of_bidimensional_products("timepoint2")
    if sum_prod_t1 == 0:
        if sum_prod_t2 == 0:
            return float("nan")
        return float("inf") if sum_prod_t2 > 0 else float("-inf")
    return (sum_prod_t2 - sum_prod_t1) / sum_prod_t1