import ctk
from utils import ui_helper_utils, measurements2D_utils
from utils.config import module_path
import time
import os
import traceback
import json
import slicer
import numpy as np
from utils.rano_utils import run_segmentation
from utils.config import debug
class SegmentationMixin:
    """
    Mixin class for the segmentation functionality of the RANO module.
    This class handles the segmentation of the input volumes and the loading of the results into Slicer.
    It also handles the progress bar and the cancellation of the segmentation process.

    The host class is expected to provide ``self.ui`` (with ``affineregCheckBox``, ``affineregCheckBox_t2``
    and ``SegmentSelectorWidget``) and ``self._parameterNode``; they are used here but not defined here.
    """

    def __init__(self):
        # time.time() at which the currently running segmentation of timepoint1 / timepoint2 was first
        # seen to make progress; None while no run of that timepoint is being timed.
        self.start_time_t1 = None
        self.start_time_t2 = None

    @staticmethod
    def _elapsedSeconds(start_time, now=None):
        """
        Return the whole number of seconds elapsed since start_time.
        Args:
            start_time: A time.time() timestamp, or None if the start time is unknown.
            now: Optional timestamp to measure against (defaults to time.time()).
        Returns:
            int seconds (truncated), or None when start_time is None.
        """
        if start_time is None:
            return None
        if now is None:
            now = time.time()
        return int(now - start_time)

    @staticmethod
    def _silentConsoleLogLevel():
        """
        Return the ctk.ctkErrorLogLevel value used to silence the Python console:
        'None' if that level exists, otherwise 'Unknown'.
        Raises:
            ValueError: if neither level is defined in ctk.ctkErrorLogLevel.
        """
        logLevels = {name: getattr(ctk.ctkErrorLogLevel, name) for name in dir(ctk.ctkErrorLogLevel)
                     if type(getattr(ctk.ctkErrorLogLevel, name)).__name__ == 'LogLevel'}
        for name in ("None", "Unknown"):
            if name in logLevels:
                return logLevels[name]
        raise ValueError("Neither 'None' nor 'Unknown' logLevel was available in ctk.ctkErrorLogLevel")

    def onCliNodeStatusUpdate(self, cliNode, event, progressBar, task_dir, tmp_path_out, output_segmentation,
                              input_volume_list, timepoint, original_log_level=32):
        """
        Callback function to handle the status update of the CLI node.

        While the CLI runs, the progress is mirrored in progressBar (capped at 99 until completion) and the
        per-timepoint timer is started; once the CLI reports 100 % progress (and debug is off) the Python
        console log level is lowered to suppress stderr output from the CLI. On completion the elapsed time
        is printed, the result is loaded via onSegmentationCliNodeSuccess and the log level is restored.
        On cancellation the timer is reset and the log level is restored.
        Args:
            cliNode: The CLI node that is being observed.
            event: The event that triggered the callback.
            progressBar: The progress bar to update.
            task_dir: The task directory for the segmentation.
            tmp_path_out: The temporary output path for the segmentation.
            output_segmentation: The output segmentation node.
            input_volume_list: The list of input volume nodes.
            timepoint: The timepoint for the segmentation ('timepoint1' or 'timepoint2').
            original_log_level: The original log level for the Slicer application.
        """
        if debug:
            print("Received an event '%s' from class '%s'" % (event, cliNode.GetClassName()))

        start_time_attr = 'start_time_t1' if timepoint == 'timepoint1' else 'start_time_t2'
        # Read the status once; status flags are bit masks (e.g. CompletedWithErrors = Completed | ErrorsMask),
        # so test them with '&' rather than '=='.
        status = cliNode.GetStatus()

        if status & cliNode.Completed:
            progressBar.setValue(100)
            start_time = getattr(self, start_time_attr, None)
            # reset so that the next run of this timepoint is timed from its own start
            setattr(self, start_time_attr, None)
            elapsed = self._elapsedSeconds(start_time)
            if elapsed is None:
                print(f"Segmentation for {timepoint} completed.")
            else:
                print(f"Segmentation for {timepoint} completed after {elapsed} seconds.")
            if debug:
                print("CLI output text: \n" + cliNode.GetOutputText(), flush=True)
            # NOTE(review): completion with errors is deliberately not treated as fatal here;
            # onSegmentationCliNodeSuccess reports a missing output file instead.
            # The CLI error text is not printed because stderr is printed by VTK to the console already.
            try:
                self.onSegmentationCliNodeSuccess(input_volume_list, output_segmentation,
                                                  task_dir, timepoint, tmp_path_out)
            finally:
                # restore the original log level even if loading the result fails
                slicer.app.setPythonConsoleLogLevel(original_log_level)
        elif status & cliNode.Cancelled:
            # the run was aborted: forget its start time and undo any log-level change made at 100 %
            setattr(self, start_time_attr, None)
            slicer.app.setPythonConsoleLogLevel(original_log_level)
        elif cliNode.IsA('vtkMRMLCommandLineModuleNode'):
            # progress update from the running CLI (not the final one)
            if not getattr(self, start_time_attr, None):
                setattr(self, start_time_attr, time.time())
            progress = cliNode.GetProgress()
            progressBar.setValue(min(progress, 99.0))
            if progress == 100 and not debug:
                # disable stderr output from VTK etc that come from the segmentation CLI
                slicer.app.setPythonConsoleLogLevel(self._silentConsoleLogLevel())

    def onSegmentationCliNodeSuccess(self, input_volume_list, output_segmentation, task_dir,
                                     timepoint, tmp_path_out):
        """
        Callback function to handle the success of the CLI node.

        Loads the output label map into output_segmentation, with one segment per predicted label
        (background '0' included but hidden). Segments are renamed using <task_dir>/config/label_names.json,
        and the views of the timepoint's row are updated. The segmentation is attached to a transform
        node named "Transform_<timepoint>_channel1". That node is the inverted registration transform when
        affine registration is enabled for this timepoint, and an identity transform otherwise.
        Args:
            input_volume_list: The list of input volume nodes.
            output_segmentation: The output segmentation node.
            task_dir: The task directory for the segmentation.
            timepoint: The timepoint for the segmentation ('timepoint1' or 'timepoint2').
            tmp_path_out: The temporary output path for the segmentation.
        """
        # Each timepoint has its own affine-registration checkbox. Read it once so that the output path,
        # the transform that is loaded and the transform that is applied all agree (the loop below calls
        # processEvents, during which the checkbox could otherwise change).
        affine_checkbox = self.ui.affineregCheckBox if timepoint == 'timepoint1' else self.ui.affineregCheckBox_t2
        use_registration = affine_checkbox.checked

        # get output files created by external inference process
        # depending on whether registration is required or not, the output file is in a different location
        registered_dir = os.path.join(tmp_path_out, "preprocessed", "registered")
        tmp_transform_file_img0 = os.path.join(registered_dir, "image_0000",
                                               "img_tmp_0000_ANTsregistered_0GenericAffine.mat")
        if use_registration:
            # want to load the segmentation in the segmentation input template space
            tmp_file_path_out = os.path.join(registered_dir, "output.nii.gz")
        else:
            tmp_file_path_out = os.path.join(tmp_path_out, "output.nii.gz")

        if not os.path.exists(tmp_file_path_out):
            slicer.util.errorDisplay("Output segmentation file not found: " + tmp_file_path_out)
            return

        if use_registration:
            transformNode = slicer.util.loadTransform(tmp_transform_file_img0)
        else:
            # no registration was run, so there is no transform file to load
            print("Transformation file not found: " + tmp_transform_file_img0)
            print("Creating identity transform instead")
            transformNode = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLTransformNode")
        transformNode.SetName("Transform_" + timepoint + "_channel1")
        # invert the transform and update the node name to reflect the inversion
        transformNode.Inverse()
        transformNode.InverseName()

        # delete the previous reference image of the output segmentation (the label volume of an earlier run)
        previousReferenceVolumeNode = output_segmentation.GetNodeReference('referenceImageGeometryRef')
        if previousReferenceVolumeNode:
            slicer.mrmlScene.RemoveNode(previousReferenceVolumeNode)

        # load the file as a labelVolumeNode and convert it into segments (background included)
        loadedLabelVolumeNode = slicer.util.loadLabelVolume(tmp_file_path_out, properties={"show": False})
        loadedLabelVolumeNode.SetHideFromEditors(1)  # hide the volume from the subject hierarchy
        segmentation = self.ImportLabelmapToSegmentationNodeWithBackgroundSegment(loadedLabelVolumeNode,
                                                                                  output_segmentation)
        # keep the label volume as the reference image geometry of the output segmentation
        output_segmentation.SetReferenceImageGeometryParameterFromVolumeNode(loadedLabelVolumeNode)

        # rename the segments from their label value to the structure name of the task
        label_names_path = os.path.join(task_dir, "config", "label_names.json")
        with open(label_names_path, encoding="utf-8") as jsonfile:
            label_dict = json.load(jsonfile)
        for segment_id in list(segmentation.GetSegmentIDs()):
            slicer.app.processEvents()  # to keep the GUI responsive
            segment = segmentation.GetSegment(segment_id)
            # the import helper names every segment after its original label value ('0', '1', ...)
            label_value = segment.GetName()
            label_name = label_dict.get(label_value)
            if label_name is None:
                print(f"No name for label {label_value} in {label_names_path}; keeping the numeric segment name")
                continue
            segment.SetName(label_name)

        # render in 3D
        output_segmentation.CreateClosedSurfaceRepresentation()
        # make sure there is a displayNode
        if not output_segmentation.GetDisplayNode():
            output_segmentation.CreateDefaultDisplayNodes()
        displayNode = output_segmentation.GetDisplayNode()
        # hide the background segment
        for hidden_segment_id in ['0']:
            displayNode.SetSegmentVisibility(hidden_segment_id, False)
        displayNode.SetOpacity3D(0.5)

        # show the first volume in the 2D views of the "timepoint" row
        first_input_volume = input_volume_list[0]
        ui_helper_utils.UIHelperMixin.setBackgroundVolumes(first_input_volume, timepoint)
        measurements2D_utils.Measurements2DMixin.setViews(output_segmentation, timepoint)
        # don't show the output volume in the 2D views, since the segmentation will be shown. The output volume
        # is just kept as a reference image for the segmentation
        slicer.util.setSliceViewerLayers(label=None)

        # set flag to indicate that the segmentation was successfully computed and loaded
        self._parameterNode.SetParameter(f"segmentation_loaded_{timepoint}", "true")

        # Attach this run's transform. It is the inverted registration transform, or the identity when
        # registration is disabled; the identity also replaces any transform left from an earlier run.
        output_segmentation.SetAndObserveTransformNodeID(transformNode.GetID())
        self.setDefaultSegmentFor2DMeasurements("ETC")

    def onCancel(self, cliNode, progressBar):
        """
        Cancel the running segmentation CLI and close its progress bar.
        Args:
            cliNode: The CLI node to cancel.
            progressBar: The progress bar shown for this CLI run.
        """
        cliNode.Cancel()
        progressBar.close()

    @staticmethod
    def ImportLabelmapToSegmentationNodeWithBackgroundSegment(loadedLabelVolumeNode, output_segmentation):
        """
        Import the labelVolumeNode into the segmentation node. The labels in the labelVolumeNode are increased by 1
        temporarily to include the background label (0) and then decreased by 1 again to have the original labels.
        (This is because slicer.modules.segmentations.logic().ImportLabelmapToSegmentationNode ignores the background label)
        Finally, the segmentIDs are reduced by 1 to have the correct segmentIDs in the segmentation node.
        Args:
            loadedLabelVolumeNode: The label volume node to import.
            output_segmentation: The output segmentation node to import the label volume into.
        Returns:
            The vtkSegmentation of output_segmentation. Every segment's ID and name is set to
            str(int(imported name) - 1), i.e. the original label value ('0' for background).
        """
        # increase values by 1 so that background (0) becomes a regular label
        # NOTE(review): a label equal to the array dtype's maximum would wrap around here; this assumes
        # small label values — confirm for new models.
        labelArray = slicer.util.arrayFromVolume(loadedLabelVolumeNode)
        labelArray += 1
        slicer.util.arrayFromVolumeModified(loadedLabelVolumeNode)

        # convert the labels in the labelVolumeNode into segments of the output segmentation node
        segmentation = output_segmentation.GetSegmentation()
        segmentation.RemoveAllSegments()
        try:
            slicer.modules.segmentations.logic().ImportLabelmapToSegmentationNode(loadedLabelVolumeNode,
                                                                                  output_segmentation)
        finally:
            # decrease values by 1 again, even if the import failed
            labelArray -= 1
            slicer.util.arrayFromVolumeModified(loadedLabelVolumeNode)

        # reduce the segmentIDs by 1 (iterate over a snapshot because segments are added/removed)
        for segID in list(segmentation.GetSegmentIDs()):
            old_segment = segmentation.GetSegment(segID)
            new_segID = str(int(old_segment.GetName()) - 1)
            # first change the name of the segment to the new ID
            old_segment.SetName(new_segID)
            # add segment with new ID (and new name), then remove it under its old ID
            segmentation.AddSegment(old_segment, new_segID)
            segmentation.RemoveSegment(segID)
        return segmentation

    def setDefaultSegmentFor2DMeasurements(self, defaultSegmentName="ETC"):
        """
        Set the default segment for 2D measurements in the segment selector widget.
        The output segmentations of timepoint1 and timepoint2 are checked in that order; the first one that
        contains a segment named defaultSegmentName becomes the current node and segment of the selector.
        Nothing changes if no such segment exists.
        Args:
            defaultSegmentName: The name of the default segment to set.
        """
        for paramName in ("outputSegmentation", "outputSegmentation_t2"):
            segNode = self._parameterNode.GetNodeReference(paramName)
            if not segNode:
                continue
            # an empty ID means the segment name was not found
            segmentId = segNode.GetSegmentation().GetSegmentIdBySegmentName(defaultSegmentName)
            if segmentId:
                self.ui.SegmentSelectorWidget.setCurrentNode(segNode)
                self.ui.SegmentSelectorWidget.setCurrentSegmentID(segmentId)
                break

    @staticmethod
    def get_task_dir(model_key, parameterNode):
        """
        Get the task directory for the given model key.
        Args:
            model_key: The key of the model to get the task directory for ("" means no model).
            parameterNode: The parameter node for the RANO module; its "ModelInfo" parameter is a JSON
                object mapping str(model_key) to the task directory.
        Returns:
            The task directory for the given model key, or "" when model_key is "".
        Raises:
            KeyError: if str(model_key) is not present in ModelInfo.
        """
        if model_key == "":
            return ""
        model_info = json.loads(parameterNode.GetParameter("ModelInfo"))
        return model_info[str(model_key)]