# 分割指标计算全家桶 (all-in-one toolkit of segmentation evaluation metrics)

import glob
import os
import shutil
from concurrent.futures import ThreadPoolExecutor, as_completed

import numpy as np
import SimpleITK as sitk
from scipy.spatial.distance import directed_hausdorff
from scipy.ndimage import distance_transform_edt
from collections import defaultdict

from tqdm import tqdm


def load_nifti_image(file_path):
    """
    Read an image file with SimpleITK and return its voxel data as an array.

    The file is read with ``sitk.ReadImage`` and converted with
    ``sitk.GetArrayFromImage``. That conversion orders the array axes in
    reverse of SimpleITK's (x, y, z) indexing, i.e. (z, y, x) for a 3-D volume.
    """
    image = sitk.ReadImage(file_path)
    voxels = sitk.GetArrayFromImage(image)
    return voxels


def get_image_size(file_path):
    """
    Return the per-axis voxel count of the image stored at ``file_path``.

    The whole image is loaded with ``sitk.ReadImage``. The returned tuple is
    in SimpleITK index order (x, y, z for a 3-D volume).
    """
    image = sitk.ReadImage(file_path)
    size = image.GetSize()
    return size


def compute_hd95(label1, label2, spacing=None):
    """
    Compute the 95th-percentile Hausdorff Distance (HD95) between two binary masks.

    Every boundary voxel of each mask contributes the Euclidean distance to
    the nearest boundary voxel of the other mask. Distances from both
    directions are pooled, and HD95 is the 95th percentile of that pool.
    Any non-zero voxel counts as foreground.

    Args:
        label1: Array mask; non-zero voxels are foreground.
        label2: Array mask with the same shape as ``label1``.
        spacing: Optional voxel spacing (scalar, or one value per array axis)
            so distances come out in physical units. ``None`` measures in voxels.

    Returns:
        The HD95 value, or ``np.inf`` if either mask is empty.
    """
    from scipy.ndimage import binary_erosion, generate_binary_structure

    mask1 = np.asarray(label1, dtype=bool)
    mask2 = np.asarray(label2, dtype=bool)
    if not mask1.any() or not mask2.any():
        return np.inf  # distance to an empty set is undefined

    # Boundary = foreground voxels removed by one erosion with a
    # face-connected structuring element.
    footprint = generate_binary_structure(mask1.ndim, 1)
    border1 = mask1 ^ binary_erosion(mask1, structure=footprint)
    border2 = mask2 ^ binary_erosion(mask2, structure=footprint)

    # The EDT of a boundary's complement holds, at every voxel, the distance
    # to the nearest voxel of that boundary; sample it at the other boundary.
    d12 = distance_transform_edt(~border2, sampling=spacing)[border1]
    d21 = distance_transform_edt(~border1, sampling=spacing)[border2]

    return np.percentile(np.concatenate([d12, d21]), 95)


def compute_dice(label1, label2):
    """
    Compute the Dice coefficient between two binary label images.

    Non-zero voxels are foreground, so 0/1 and 0/255 masks give the same result.

    Returns:
        2 * |A & B| / (|A| + |B|), or NaN if both masks are empty.
    """
    mask1 = np.asarray(label1, dtype=bool)
    mask2 = np.asarray(label2, dtype=bool)
    # Boolean logic instead of `label1 * label2`: multiplying raw uint8 masks
    # with a foreground value other than 1 (e.g. 255) overflows.
    intersection = np.count_nonzero(mask1 & mask2)
    sum_labels = np.count_nonzero(mask1) + np.count_nonzero(mask2)
    if sum_labels == 0:
        return np.nan  # Dice is undefined when both labels are empty
    return (2. * intersection) / sum_labels


def compute_asd(label1, label2, spacing=None):
    """
    Compute the symmetric Average Surface Distance (ASD) between two binary masks.

    For every boundary voxel of each mask, take the Euclidean distance to the
    nearest boundary voxel of the other mask. The ASD is the mean of the two
    directed mean distances. Any non-zero voxel counts as foreground.

    Args:
        label1: Array mask; non-zero voxels are foreground.
        label2: Array mask with the same shape as ``label1``.
        spacing: Optional voxel spacing (scalar, or one value per array axis)
            so distances come out in physical units. ``None`` measures in voxels.

    Returns:
        Tuple ``(asd, std)``. ``std`` is the standard deviation of the surface
        distances pooled from both directions. Returns ``(np.inf, np.inf)`` if
        either mask is empty.
    """
    from scipy.ndimage import binary_erosion, generate_binary_structure

    mask1 = np.asarray(label1, dtype=bool)
    mask2 = np.asarray(label2, dtype=bool)
    if not mask1.any() or not mask2.any():
        return np.inf, np.inf  # distance to an empty set is undefined

    # Boundary = foreground voxels removed by one erosion with a
    # face-connected structuring element.
    footprint = generate_binary_structure(mask1.ndim, 1)
    border1 = mask1 ^ binary_erosion(mask1, structure=footprint)
    border2 = mask2 ^ binary_erosion(mask2, structure=footprint)

    # Measure to the other mask's *surface*, not its foreground. Otherwise a
    # surface voxel lying inside the other object would score 0.
    sds1 = distance_transform_edt(~border2, sampling=spacing)[border1]
    sds2 = distance_transform_edt(~border1, sampling=spacing)[border2]

    asd = (np.mean(sds1) + np.mean(sds2)) / 2
    return asd, np.std(np.concatenate([sds1, sds2]))


def compute_iou(label1, label2):
    """
    Compute the Intersection over Union (IoU) between two binary label images.

    Non-zero voxels are foreground, so 0/1 and 0/255 masks give the same result.

    Returns:
        |A & B| / |A | B|, or NaN if both masks are empty.
    """
    mask1 = np.asarray(label1, dtype=bool)
    mask2 = np.asarray(label2, dtype=bool)
    # Boolean logic avoids the uint8 overflow of `label1 * label2` when the
    # foreground value is not 1.
    intersection = np.count_nonzero(mask1 & mask2)
    union = np.count_nonzero(mask1 | mask2)
    if union == 0:
        return np.nan  # IoU is undefined when both labels are empty
    return intersection / union


def compute_tpr(label1, label2):
    """
    Compute the True Positive Rate (TPR), also called sensitivity or recall.

    ``label1`` is treated as the reference and ``label2`` as the prediction.
    Non-zero voxels are foreground.

    Returns:
        TP / (TP + FN), or NaN if the reference is empty.
    """
    reference = np.asarray(label1, dtype=bool)
    prediction = np.asarray(label2, dtype=bool)
    # Boolean logic instead of `label1 * (1 - label2)`: that arithmetic wraps
    # around on uint8 masks whose foreground value is not 1.
    tp = np.count_nonzero(reference & prediction)
    fn = np.count_nonzero(reference & ~prediction)
    if (tp + fn) == 0:
        return np.nan
    return tp / (tp + fn)


def compute_precision(label1, label2):
    """
    Compute the Precision (positive predictive value).

    ``label1`` is treated as the reference and ``label2`` as the prediction.
    Non-zero voxels are foreground.

    Returns:
        TP / (TP + FP), or NaN if the prediction is empty.
    """
    reference = np.asarray(label1, dtype=bool)
    prediction = np.asarray(label2, dtype=bool)
    # Boolean logic instead of `(1 - label1) * label2`: that arithmetic wraps
    # around on uint8 masks whose foreground value is not 1.
    tp = np.count_nonzero(reference & prediction)
    fp = np.count_nonzero(~reference & prediction)
    if (tp + fp) == 0:
        return np.nan
    return tp / (tp + fp)


def compute_f1(label1, label2):
    """
    Compute the F1-score: the harmonic mean of precision and TPR.

    ``label1`` is treated as the reference and ``label2`` as the prediction.
    Non-zero voxels are foreground.

    Returns:
        NaN if precision or TPR is undefined (prediction or reference empty).
        Otherwise 2*TP / (2*TP + FP + FN), which is 0.0 when the two masks
        do not overlap at all.
    """
    reference = np.asarray(label1, dtype=bool)
    prediction = np.asarray(label2, dtype=bool)
    tp = np.count_nonzero(reference & prediction)
    fp = np.count_nonzero(~reference & prediction)
    fn = np.count_nonzero(reference & ~prediction)
    if (tp + fp) == 0 or (tp + fn) == 0:
        return np.nan  # precision or TPR undefined
    # Algebraically equal to 2PR / (P + R). Disjoint masks (P = R = 0) give 0,
    # matching Dice, instead of a NaN that nanmean would silently drop.
    return 2 * tp / (2 * tp + fp + fn)


def compute_voe(label1, label2):
    """
    Compute the Volume Overlap Error (VOE), defined here as ``1 - IoU``.

    Whatever ``compute_iou`` returns is complemented, so a NaN IoU yields a
    NaN VOE.
    """
    overlap = compute_iou(label1, label2)
    return 1 - overlap


def compute_metrics_per_label(nii1, nii2):
    """
    Compute every evaluation metric between a reference mask and a predicted mask.

    Despite the name, both volumes are reduced to a single binary foreground
    (every non-zero voxel), so the metrics describe that one foreground region.

    Args:
        nii1: Path to the reference (label) image file.
        nii2: Path to the predicted (output) image file.

    Returns:
        dict mapping metric name ("HD95", "Dice", "ASD_mean", "ASD_std", "IoU",
        "TPR", "Precision", "F1", "VOE") to its value.
    """
    # Binarize both volumes the same way, so a 0/255 mask is handled like a
    # 0/1 mask whichever side it is on. uint8 (not bool) keeps the arrays in
    # the 0/1 integer form the metric helpers were written for.
    img1 = (load_nifti_image(nii1) != 0).astype(np.uint8)
    img2 = (load_nifti_image(nii2) != 0).astype(np.uint8)

    asd_mean, asd_std = compute_asd(img1, img2)
    return {
        "HD95": compute_hd95(img1, img2),
        "Dice": compute_dice(img1, img2),
        "ASD_mean": asd_mean,
        "ASD_std": asd_std,
        "IoU": compute_iou(img1, img2),
        "TPR": compute_tpr(img1, img2),
        "Precision": compute_precision(img1, img2),
        "F1": compute_f1(img1, img2),
        "VOE": compute_voe(img1, img2),
    }


def match_by_size(label_dir, output_dir):
    """
    Pair each label file with the prediction file that overlaps it best.

    For every ``*.nii.gz`` in ``label_dir``, the candidates are the
    ``*.nii.gz`` files in ``output_dir`` with the same image size. Among them,
    the one with the highest Dice (non-zero voxels as foreground) is chosen.
    Ties go to the first file in sorted order. A label is omitted if no
    candidate gives a defined (non-NaN) Dice. Several labels may map to the
    same output file.

    Returns:
        dict mapping label file path -> matched output file path.
    """
    # Sort so that iteration order and tie-breaking are deterministic.
    label_files = sorted(glob.glob(os.path.join(label_dir, '*.nii.gz')))
    output_files = sorted(glob.glob(os.path.join(output_dir, '*.nii.gz')))

    output_sizes = {f: get_image_size(f) for f in output_files}

    matched_pairs = {}
    for label_file in label_files:
        label_size = get_image_size(label_file)
        candidates = [f for f, size in output_sizes.items() if size == label_size]
        if not candidates:
            continue

        # Load the label once per label file rather than once per candidate,
        # and binarize both sides so 0/255 masks compare correctly.
        label_mask = load_nifti_image(label_file) != 0
        best_match = None
        best_dice = -1.0
        for output_file in candidates:
            dice = compute_dice(label_mask, load_nifti_image(output_file) != 0)
            if dice > best_dice:  # NaN (both masks empty) never compares greater
                best_dice = dice
                best_match = output_file
        if best_match is not None:
            matched_pairs[label_file] = best_match

    return matched_pairs


def main(label_dir=r'D:\label', output_dir=r'D:\output'):
    """
    Match label/prediction files and print per-file and average metrics.

    Args:
        label_dir: Directory holding the reference ``*.nii.gz`` masks.
        output_dir: Directory holding the predicted ``*.nii.gz`` masks.
    """
    matched_pairs = match_by_size(label_dir, output_dir)

    all_metrics = defaultdict(list)

    for label_file, output_file in matched_pairs.items():
        metrics = compute_metrics_per_label(label_file, output_file)
        label_name = os.path.basename(label_file)

        print(f"File: {label_name}")
        for metric_name, metric_value in metrics.items():
            print(f"{metric_name}: {metric_value}")
            all_metrics[metric_name].append(metric_value)
        print("-" * 20)

    print("Average Metrics:")
    # nanmean skips NaN (undefined) entries, but an inf value from an empty
    # mask (HD95 / ASD) still makes that metric's average inf.
    for metric_name, metric_values in all_metrics.items():
        avg_metric = np.nanmean(metric_values)
        print(f"{metric_name}: {avg_metric}")


# Run only when executed as a script, not on import.
if __name__ == "__main__":
    main()

# “分割指标计算全家桶”的一个回复 (web-page residue: "One reply to 分割指标计算全家桶")

# 发表评论 (web-page residue: "Leave a comment")